Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add param inject for auto inject #237 #238

Open
wants to merge 2 commits into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/dishka/integrations/aiogram.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from aiogram.types import TelegramObject

from dishka import AsyncContainer, FromDishka, Provider, Scope, from_context
from .base import is_dishka_injected, wrap_injection
from .base import InjectDecorator, is_dishka_injected, wrap_injection

P = ParamSpec("P")
T = TypeVar("T")
Expand Down Expand Up @@ -61,6 +61,9 @@ async def __call__(


class AutoInjectMiddleware(BaseMiddleware):
def __init__(self, inject_decorator: InjectDecorator = inject) -> None:
self.inject_decorator = inject_decorator

async def __call__(
self,
handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]],
Expand All @@ -72,7 +75,7 @@ async def __call__(
return await handler(event, data)

new_handler = HandlerObject(
callback=inject(old_handler.callback),
callback=self.inject_decorator(old_handler.callback),
filters=old_handler.filters,
flags=old_handler.flags,
)
Expand Down
5 changes: 2 additions & 3 deletions src/dishka/integrations/aiogram_dialog.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
__all__ = ["inject"]

from collections.abc import Callable
from typing import Any, Final, ParamSpec, TypeVar, cast
from typing import Any, Final, TypeVar, cast

from dishka import AsyncContainer
from dishka.integrations.base import wrap_injection

T = TypeVar("T")
P = ParamSpec("P")
TWO: Final = 2
CONTAINER_NAME: Final = "dishka_container"

Expand All @@ -26,7 +25,7 @@ def _container_getter(
return cast(AsyncContainer, container)


def inject(func: Callable[P, T]) -> Callable[P, T]:
def inject(func: Callable[..., T]) -> Callable[..., T]:
return wrap_injection(
func=func,
is_async=True,
Expand Down
25 changes: 18 additions & 7 deletions src/dishka/integrations/aiohttp.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,11 @@
from aiohttp.web_response import StreamResponse

from dishka import AsyncContainer, FromDishka, Provider, Scope, from_context
from dishka.integrations.base import is_dishka_injected, wrap_injection
from dishka.integrations.base import (
InjectDecorator,
is_dishka_injected,
wrap_injection,
)

T = TypeVar("T")
DISHKA_CONTAINER_KEY: Final = web.AppKey("dishka_container", AsyncContainer)
Expand Down Expand Up @@ -65,22 +69,28 @@ async def container_middleware(
return await handler(request)


def _inject_routes(router: web.UrlDispatcher) -> None:
def _inject_routes(
router: web.UrlDispatcher,
inject_decorator: InjectDecorator,
) -> None:
for route in router.routes():
_inject_route(route)
_inject_route(route, inject_decorator)

for resource in router.resources():
for route in resource._routes: # type: ignore[attr-defined] # noqa: SLF001
_inject_route(route)
_inject_route(route, inject_decorator)


def _inject_route(route: web.AbstractRoute) -> None:
def _inject_route(
route: web.AbstractRoute,
inject_decorator: InjectDecorator,
) -> None:
if not is_dishka_injected(route.handler):
# typing.cast is used because AbstractRoute._handler
# is Handler or Type[AbstractView]
route._handler = cast( # noqa: SLF001
AiohttpHandler,
inject(route.handler),
inject_decorator(route.handler),
)


Expand All @@ -93,10 +103,11 @@ def setup_dishka(
app: Application,
*,
auto_inject: bool = False,
inject_decorator: InjectDecorator = inject,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do not like inject_decorator as a separate parameter here as it is only applicable if auto_inject is enabled.

) -> None:
app[DISHKA_CONTAINER_KEY] = container
app.middlewares.append(container_middleware)
app.on_shutdown.append(_on_shutdown)

if auto_inject:
_inject_routes(app.router)
_inject_routes(app.router, inject_decorator)
1 change: 1 addition & 0 deletions src/dishka/integrations/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
type | Sequence[type],
FromDishka | _FromComponent,
)
InjectDecorator: TypeAlias = Callable[[Callable[..., Any]], Any]


def default_parse_dependency(
Expand Down
15 changes: 10 additions & 5 deletions src/dishka/integrations/click.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
)

from dishka import Container, FromDishka
from .base import is_dishka_injected, wrap_injection
from .base import InjectDecorator, is_dishka_injected, wrap_injection

T = TypeVar("T")
CONTAINER_NAME: Final = "dishka_container"
Expand All @@ -32,16 +32,20 @@ def inject(func: Callable[..., T]) -> Callable[..., T]:
)


def _inject_commands(context: Context, command: Command | None) -> None:
def _inject_commands(
context: Context,
command: Command | None,
inject_decorator: InjectDecorator,
) -> None:
if isinstance(command, Command) and not is_dishka_injected(
command.callback, # type: ignore[arg-type]
):
command.callback = inject(command.callback) # type: ignore[arg-type]
command.callback = inject_decorator(command.callback) # type: ignore[arg-type]

if isinstance(command, Group):
for command_name in command.list_commands(context):
child_command = command.get_command(context, command_name)
_inject_commands(context, child_command)
_inject_commands(context, child_command, inject_decorator)


def setup_dishka(
Expand All @@ -50,11 +54,12 @@ def setup_dishka(
*,
finalize_container: bool = True,
auto_inject: bool = False,
inject_decorator: InjectDecorator = inject,
) -> None:
context.meta[CONTAINER_NAME] = container

if finalize_container:
context.call_on_close(container.close)

if auto_inject:
_inject_commands(context, context.command)
_inject_commands(context, context.command, inject_decorator)
21 changes: 11 additions & 10 deletions src/dishka/integrations/faststream.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from faststream.utils.context import ContextRepo

from dishka import AsyncContainer, FromDishka, Provider, Scope, from_context
from dishka.integrations.base import wrap_injection
from dishka.integrations.base import InjectDecorator, wrap_injection

T = TypeVar("T")
P = ParamSpec("P")
Expand All @@ -31,6 +31,14 @@ class FastStreamProvider(Provider):
FASTSTREAM_OLD_MIDDLEWARES = __version__ < "0.5"


def inject(func: Callable[P, T]) -> Callable[P, T]:
return wrap_injection(
func=func,
is_async=True,
container_getter=lambda *_: context.get_local("dishka"),
)


class _DishkaBaseMiddleware(BaseMiddleware):
def __init__(self, container: AsyncContainer) -> None:
self.container = container
Expand Down Expand Up @@ -85,6 +93,7 @@ def setup_dishka(
*,
finalize_container: bool = True,
auto_inject: bool = False,
inject_decorator: InjectDecorator = inject,
) -> None:
assert app.broker, "You can't patch FastStream application without broker" # noqa: S101

Expand Down Expand Up @@ -128,14 +137,6 @@ def setup_dishka(

if auto_inject:
app.broker._call_decorators = ( # noqa: SLF001
inject,
inject_decorator,
*app.broker._call_decorators, # noqa: SLF001
)


def inject(func: Callable[P, T]) -> Callable[P, T]:
return wrap_injection(
func=func,
is_async=True,
container_getter=lambda *_: context.get_local("dishka"),
)
17 changes: 12 additions & 5 deletions src/dishka/integrations/flask.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from flask.typing import RouteCallable

from dishka import Container, FromDishka, Provider, Scope, from_context
from .base import is_dishka_injected, wrap_injection
from .base import InjectDecorator, is_dishka_injected, wrap_injection

T = TypeVar("T")
P = ParamSpec("P")
Expand Down Expand Up @@ -43,24 +43,31 @@ def exit_request(self, *_args: Any, **_kwargs: Any) -> None:
g.dishka_container.close()


def _inject_routes(scaffold: Scaffold) -> None:
def _inject_routes(
scaffold: Scaffold,
inject_decorator: InjectDecorator,
) -> None:
for key, func in scaffold.view_functions.items():
if not is_dishka_injected(func):
# typing.cast is applied because there
# are RouteCallable objects in dict value
scaffold.view_functions[key] = cast(RouteCallable, inject(func))
scaffold.view_functions[key] = cast(
RouteCallable,
inject_decorator(func),
)


def setup_dishka(
container: Container,
app: Flask,
*,
auto_inject: bool = False,
inject_decorator: InjectDecorator = inject,
) -> None:
middleware = ContainerMiddleware(container)
app.before_request(middleware.enter_request)
app.teardown_appcontext(middleware.exit_request)
if auto_inject:
_inject_routes(app)
_inject_routes(app, inject_decorator)
for blueprint in app.blueprints.values():
_inject_routes(blueprint)
_inject_routes(blueprint, inject_decorator)
21 changes: 15 additions & 6 deletions src/dishka/integrations/sanic.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,13 @@
from sanic.models.handler_types import RouteHandler
from sanic_routing import Route

from dishka import AsyncContainer, FromDishka, Provider, Scope, from_context
from dishka.integrations.base import is_dishka_injected, wrap_injection
from dishka import AsyncContainer, Provider, Scope, from_context
from dishka.integrations.base import (
FromDishka,
InjectDecorator,
is_dishka_injected,
wrap_injection,
)


def inject(func: RouteHandler) -> RouteHandler:
Expand Down Expand Up @@ -43,23 +48,27 @@ async def on_response(self, request: Request, _: HTTPResponse) -> None:
await request.ctx.dishka_container.close()


def _inject_routes(routes: Iterable[Route]) -> None:
def _inject_routes(
routes: Iterable[Route],
inject_decorator: InjectDecorator = inject,
) -> None:
for route in routes:
if not is_dishka_injected(route.handler):
route.handler = inject(route.handler)
route.handler = inject_decorator(route.handler)


def setup_dishka(
container: AsyncContainer,
app: Sanic[Any, Any],
*,
auto_inject: bool = False,
inject_decorator: InjectDecorator = inject,
) -> None:
middleware = ContainerMiddleware(container)
app.on_request(middleware.on_request)
app.on_response(middleware.on_response) # type: ignore[no-untyped-call]

if auto_inject:
_inject_routes(app.router.routes)
_inject_routes(app.router.routes, inject_decorator)
for blueprint in app.blueprints.values():
_inject_routes(blueprint.routes)
_inject_routes(blueprint.routes, inject_decorator)