diff --git a/robyn/authentication.py b/robyn/authentication.py index dc01edf74..b2f6e8345 100644 --- a/robyn/authentication.py +++ b/robyn/authentication.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Optional +from typing import Any, Coroutine, Optional, Union from robyn.robyn import Headers, Identity, Request, Response from robyn.status_codes import HTTP_401_UNAUTHORIZED @@ -64,7 +64,7 @@ def unauthorized_response(self) -> Response: ) @abstractmethod - def authenticate(self, request: Request) -> Optional[Identity]: + def authenticate(self, request: Request) -> Union[Optional[Identity], Coroutine[Any, Any, Optional[Identity]]]: """ Authenticates the user. :param request: The request object. diff --git a/robyn/router.py b/robyn/router.py index 83af186f9..dc4b40d15 100644 --- a/robyn/router.py +++ b/robyn/router.py @@ -289,16 +289,32 @@ def add_auth_middleware(self, endpoint: str): injected_dependencies: dict = {} def decorator(handler): - @wraps(handler) - def inner_handler(request: Request, *args): - if not self.authentication_handler: - raise AuthenticationNotConfiguredError() - identity = self.authentication_handler.authenticate(request) - if identity is None: - return self.authentication_handler.unauthorized_response - request.identity = identity - - return request + if inspect.iscoroutinefunction(self.authentication_handler.authenticate): + + @wraps(handler) + async def inner_handler(request: Request, *args): + if not self.authentication_handler: + raise AuthenticationNotConfiguredError() + + identity = await self.authentication_handler.authenticate(request) + if identity is None: + return self.authentication_handler.unauthorized_response + request.identity = identity + + return request + else: + + @wraps(handler) + def inner_handler(request: Request, *args): + if not self.authentication_handler: + raise AuthenticationNotConfiguredError() + + identity = self.authentication_handler.authenticate(request) + if identity is None: + return self.authentication_handler.unauthorized_response + request.identity = identity + + return request self.add_route( MiddlewareType.BEFORE_REQUEST,