From 41bbb4215dbb77e31c9747e30b64216bb40982ca Mon Sep 17 00:00:00 2001 From: IvanKirpichnikov Date: Sun, 1 Sep 2024 12:47:49 +0300 Subject: [PATCH] add param inject for auto inject --- src/dishka/integrations/aiogram.py | 7 +++++-- src/dishka/integrations/aiogram_dialog.py | 5 ++--- src/dishka/integrations/aiohttp.py | 25 ++++++++++++++++------- src/dishka/integrations/base.py | 1 + src/dishka/integrations/click.py | 15 +++++++++----- src/dishka/integrations/faststream.py | 21 ++++++++++--------- src/dishka/integrations/flask.py | 17 ++++++++++----- src/dishka/integrations/sanic.py | 18 +++++++++++----- 8 files changed, 72 insertions(+), 37 deletions(-) diff --git a/src/dishka/integrations/aiogram.py b/src/dishka/integrations/aiogram.py index aadf53a7..adbf9d7b 100644 --- a/src/dishka/integrations/aiogram.py +++ b/src/dishka/integrations/aiogram.py @@ -15,7 +15,7 @@ from aiogram.types import TelegramObject from dishka import AsyncContainer, FromDishka -from .base import is_dishka_injected, wrap_injection +from .base import InjectDecorator, is_dishka_injected, wrap_injection P = ParamSpec("P") T = TypeVar("T") @@ -56,6 +56,9 @@ async def __call__( class AutoInjectMiddleware(BaseMiddleware): + def __init__(self, inject_decorator: InjectDecorator = inject) -> None: + self.inject_decorator = inject_decorator + async def __call__( self, handler: Callable[[TelegramObject, dict[str, Any]], Awaitable[Any]], @@ -67,7 +70,7 @@ async def __call__( return await handler(event, data) new_handler = HandlerObject( - callback=inject(old_handler.callback), + callback=self.inject_decorator(old_handler.callback), filters=old_handler.filters, flags=old_handler.flags, ) diff --git a/src/dishka/integrations/aiogram_dialog.py b/src/dishka/integrations/aiogram_dialog.py index 4d9a49d8..05539d55 100644 --- a/src/dishka/integrations/aiogram_dialog.py +++ b/src/dishka/integrations/aiogram_dialog.py @@ -1,13 +1,12 @@ __all__ = ["inject"] from collections.abc import Callable -from typing import Any, Final, ParamSpec, TypeVar, cast +from typing import Any, Final, TypeVar, cast from dishka import AsyncContainer from dishka.integrations.base import wrap_injection T = TypeVar("T") -P = ParamSpec("P") TWO: Final = 2 CONTAINER_NAME: Final = "dishka_container" @@ -26,7 +25,7 @@ def _container_getter( return cast(AsyncContainer, container) -def inject(func: Callable[P, T]) -> Callable[P, T]: +def inject(func: Callable[..., T]) -> Callable[..., T]: return wrap_injection( func=func, is_async=True, diff --git a/src/dishka/integrations/aiohttp.py b/src/dishka/integrations/aiohttp.py index add68bba..42e91131 100644 --- a/src/dishka/integrations/aiohttp.py +++ b/src/dishka/integrations/aiohttp.py @@ -16,7 +16,11 @@ from aiohttp.web_response import StreamResponse from dishka import AsyncContainer, FromDishka, Scope -from dishka.integrations.base import is_dishka_injected, wrap_injection +from dishka.integrations.base import ( + InjectDecorator, + is_dishka_injected, + wrap_injection, +) T = TypeVar("T") DISHKA_CONTAINER_KEY: Final = web.AppKey("dishka_container", AsyncContainer) @@ -60,22 +64,28 @@ async def container_middleware( return await handler(request) -def _inject_routes(router: web.UrlDispatcher) -> None: +def _inject_routes( + router: web.UrlDispatcher, + inject_decorator: InjectDecorator, +) -> None: for route in router.routes(): - _inject_route(route) + _inject_route(route, inject_decorator) for resource in router.resources(): for route in resource._routes: # type: ignore[attr-defined] # noqa: SLF001 - _inject_route(route) + _inject_route(route, inject_decorator) -def _inject_route(route: web.AbstractRoute) -> None: +def _inject_route( + route: web.AbstractRoute, + inject_decorator: InjectDecorator, +) -> None: if not is_dishka_injected(route.handler): # typing.cast is used because AbstractRoute._handler # is Handler or Type[AbstractView] route._handler = cast( # noqa: SLF001 AiohttpHandler, - inject(route.handler), + inject_decorator(route.handler), ) @@ -88,10 +98,11 @@ def setup_dishka( app: Application, *, auto_inject: bool = False, + inject_decorator: InjectDecorator = inject, ) -> None: app[DISHKA_CONTAINER_KEY] = container app.middlewares.append(container_middleware) app.on_shutdown.append(_on_shutdown) if auto_inject: - _inject_routes(app.router) + _inject_routes(app.router, inject_decorator) diff --git a/src/dishka/integrations/base.py b/src/dishka/integrations/base.py index 7971e083..83f456cb 100644 --- a/src/dishka/integrations/base.py +++ b/src/dishka/integrations/base.py @@ -34,6 +34,7 @@ type | Sequence[type], FromDishka | FromComponent, ) +InjectDecorator: TypeAlias = Callable[[Callable[..., Any]], Any] def default_parse_dependency( diff --git a/src/dishka/integrations/click.py b/src/dishka/integrations/click.py index eb6cd954..16f997aa 100644 --- a/src/dishka/integrations/click.py +++ b/src/dishka/integrations/click.py @@ -15,7 +15,7 @@ ) from dishka import Container, FromDishka -from .base import is_dishka_injected, wrap_injection +from .base import InjectDecorator, is_dishka_injected, wrap_injection T = TypeVar("T") CONTAINER_NAME: Final = "dishka_container" @@ -32,16 +32,20 @@ def inject(func: Callable[..., T]) -> Callable[..., T]: ) -def _inject_commands(context: Context, command: Command | None) -> None: +def _inject_commands( + context: Context, + command: Command | None, + inject_decorator: InjectDecorator, +) -> None: if isinstance(command, Command) and not is_dishka_injected( command.callback, # type: ignore[arg-type] ): - command.callback = inject(command.callback) # type: ignore[arg-type] + command.callback = inject_decorator(command.callback) # type: ignore[arg-type] if isinstance(command, Group): for command_name in command.list_commands(context): child_command = command.get_command(context, command_name) - _inject_commands(context, child_command) + _inject_commands(context, child_command, inject_decorator) def setup_dishka( @@ -50,6 +54,7 @@ def setup_dishka( *, finalize_container: bool = True, auto_inject: bool = False, + inject_decorator: InjectDecorator = inject, ) -> None: context.meta[CONTAINER_NAME] = container @@ -57,4 +62,4 @@ def setup_dishka( context.call_on_close(container.close) if auto_inject: - _inject_commands(context, context.command) + _inject_commands(context, context.command, inject_decorator) diff --git a/src/dishka/integrations/faststream.py b/src/dishka/integrations/faststream.py index 3b08e42c..bf2f41c8 100644 --- a/src/dishka/integrations/faststream.py +++ b/src/dishka/integrations/faststream.py @@ -17,7 +17,7 @@ from faststream.utils.context import ContextRepo from dishka import AsyncContainer, FromDishka, Provider, Scope, from_context -from dishka.integrations.base import wrap_injection +from dishka.integrations.base import InjectDecorator, wrap_injection T = TypeVar("T") P = ParamSpec("P") @@ -30,6 +30,14 @@ class FastStreamProvider(Provider): FASTSTREAM_OLD_MIDDLEWARES = __version__ < "0.5" +def inject(func: Callable[P, T]) -> Callable[P, T]: + return wrap_injection( + func=func, + is_async=True, + container_getter=lambda *_: context.get_local("dishka"), + ) + + class _DishkaBaseMiddleware(BaseMiddleware): def __init__(self, container: AsyncContainer) -> None: self.container = container @@ -84,6 +92,7 @@ def setup_dishka( *, finalize_container: bool = True, auto_inject: bool = False, + inject_decorator: InjectDecorator = inject, ) -> None: assert app.broker, "You can't patch FastStream application without broker" # noqa: S101 @@ -127,14 +136,6 @@ def setup_dishka( if auto_inject: app.broker._call_decorators = ( # noqa: SLF001 - inject, + inject_decorator, *app.broker._call_decorators, # noqa: SLF001 ) - - -def inject(func: Callable[P, T]) -> Callable[P, T]: - return wrap_injection( - func=func, - is_async=True, - container_getter=lambda *_: context.get_local("dishka"), - ) diff --git a/src/dishka/integrations/flask.py b/src/dishka/integrations/flask.py index ee79574d..1ad5d0b0 100644 --- a/src/dishka/integrations/flask.py +++ b/src/dishka/integrations/flask.py @@ -12,7 +12,7 @@ from flask.typing import RouteCallable from dishka import Container, FromDishka -from .base import is_dishka_injected, wrap_injection +from .base import InjectDecorator, is_dishka_injected, wrap_injection T = TypeVar("T") P = ParamSpec("P") @@ -38,12 +38,18 @@ def exit_request(self, *_args: Any, **_kwargs: Any) -> None: g.dishka_container.close() -def _inject_routes(scaffold: Scaffold) -> None: +def _inject_routes( + scaffold: Scaffold, + inject_decorator: InjectDecorator, +) -> None: for key, func in scaffold.view_functions.items(): if not is_dishka_injected(func): # typing.cast is applied because there # are RouteCallable objects in dict value - scaffold.view_functions[key] = cast(RouteCallable, inject(func)) + scaffold.view_functions[key] = cast( + RouteCallable, + inject_decorator(func), + ) def setup_dishka( @@ -51,11 +57,12 @@ def setup_dishka( app: Flask, *, auto_inject: bool = False, + inject_decorator: InjectDecorator = inject, ) -> None: middleware = ContainerMiddleware(container) app.before_request(middleware.enter_request) app.teardown_appcontext(middleware.exit_request) if auto_inject: - _inject_routes(app) + _inject_routes(app, inject_decorator) for blueprint in app.blueprints.values(): - _inject_routes(blueprint) + _inject_routes(blueprint, inject_decorator) diff --git a/src/dishka/integrations/sanic.py b/src/dishka/integrations/sanic.py index 2fcd4ff5..4a185144 100644 --- a/src/dishka/integrations/sanic.py +++ b/src/dishka/integrations/sanic.py @@ -12,7 +12,11 @@ from sanic_routing import Route from dishka import AsyncContainer, FromDishka -from dishka.integrations.base import is_dishka_injected, wrap_injection +from dishka.integrations.base import ( + InjectDecorator, + is_dishka_injected, + wrap_injection, +) def inject(func: RouteHandler) -> RouteHandler: @@ -38,10 +42,13 @@ async def on_response(self, request: Request, _: HTTPResponse) -> None: await request.ctx.dishka_container.close() -def _inject_routes(routes: Iterable[Route]) -> None: +def _inject_routes( + routes: Iterable[Route], + inject_decorator: InjectDecorator = inject, +) -> None: for route in routes: if not is_dishka_injected(route.handler): - route.handler = inject(route.handler) + route.handler = inject_decorator(route.handler) def setup_dishka( @@ -49,12 +56,13 @@ def setup_dishka( app: Sanic[Any, Any], *, auto_inject: bool = False, + inject_decorator: InjectDecorator = inject, ) -> None: middleware = ContainerMiddleware(container) app.on_request(middleware.on_request) app.on_response(middleware.on_response) # type: ignore[no-untyped-call] if auto_inject: - _inject_routes(app.router.routes) + _inject_routes(app.router.routes, inject_decorator) for blueprint in app.blueprints.values(): - _inject_routes(blueprint.routes) + _inject_routes(blueprint.routes, inject_decorator)