diff --git a/.flake8 b/.flake8 new file mode 100644 index 0000000..af28862 --- /dev/null +++ b/.flake8 @@ -0,0 +1,5 @@ +[flake8] +max-line-length = 88 +select = C,E,F,W,B,B9 +ignore = E203, E501, W503 +exclude = __init__.py \ No newline at end of file diff --git a/README.md b/README.md index a251dab..790f252 100644 --- a/README.md +++ b/README.md @@ -9,8 +9,8 @@ fastapi-cloudauth standardizes and simplifies the integration between FastAPI an ## Features * [X] Verify access/id token -* [X] Authenticate permission based on scope (or groups) within access token -* [X] Get login user info (name, email, etc.) within ID token +* [X] Authenticate permission based on scope (or groups) within access token and Extract user info +* [X] Get the detail of login user info (name, email, etc.) within ID token * [X] Dependency injection for verification/getting user, powered by [FastAPI](https://github.com/tiangolo/fastapi) * [X] Support for: * [X] [AWS Cognito](https://aws.amazon.com/jp/cognito/) @@ -35,7 +35,7 @@ $ pip install fastapi-cloudauth * Create a user assigned `read:users` permission in AWS Cognito * Get Access/ID token for the created user -NOTE: access token is valid for verification and scope-based authentication. ID token is valid for verification and getting user info from claims. +NOTE: access token is valid for verification, scope-based authentication and getting user info (optional). ID token is valid for verification and getting full user info from claims. ### Create it @@ -43,6 +43,7 @@ Create a file main.py with: ```python3 import os +from pydantic import BaseModel from fastapi import FastAPI, Depends from fastapi_cloudauth.cognito import Cognito, CognitoCurrentUser, CognitoClaims @@ -56,6 +57,16 @@ def secure(): return "Hello" +class AccessUser(BaseModel): + sub: str + + +@app.get("/access/") +def secure_access(current_user: AccessUser = Depends(auth.claim(AccessUser))): + # access token is valid and getting user info from access token + return f"Hello", {current_user.sub} + + get_current_user = CognitoCurrentUser( region=os.environ["REGION"], userPoolId=os.environ["USERPOOLID"] ) @@ -63,7 +74,7 @@ get_current_user = CognitoCurrentUser( @app.get("/user/") def secure_user(current_user: CognitoClaims = Depends(get_current_user)): - # ID token is valid + # ID token is valid and getting user info from ID token return f"Hello, {current_user.username}" ``` @@ -105,6 +116,7 @@ Create a file main.py with: ```python3 import os +from pydantic import BaseModel from fastapi import FastAPI, Depends from fastapi_cloudauth.auth0 import Auth0, Auth0CurrentUser, Auth0Claims @@ -119,12 +131,22 @@ def secure(): return "Hello" +class AccessUser(BaseModel): + sub: str + + +@app.get("/access/") +def secure_access(current_user: AccessUser = Depends(auth.claim(AccessUser))): + # access token is valid and getting user info from access token + return f"Hello", {current_user.sub} + + get_current_user = Auth0CurrentUser(domain=os.environ["DOMAIN"]) @app.get("/user/") def secure_user(current_user: Auth0Claims = Depends(get_current_user)): - # ID token is valid + # ID token is valid and getting user info from ID token return f"Hello, {current_user.username}" ``` @@ -153,15 +175,18 @@ get_current_user = FirebaseCurrentUser() @app.get("/user/") def secure_user(current_user: FirebaseClaims = Depends(get_current_user)): - # ID token is valid + # ID token is valid and getting user info from ID token return f"Hello, {current_user.user_id}" ``` Try to run the server and see interactive UI in the same way. -## Custom claims +## Additional User Information + +We can get values for the current user from access/ID token by writing a few lines. + +### Custom Claims -We can get values for the current user by writing a few lines. For Auth0, the ID token contains extra values as follows (Ref at [Auth0 official doc](https://auth0.com/docs/tokens)): ```json @@ -183,7 +208,7 @@ For Auth0, the ID token contains extra values as follows (Ref at [Auth0 official By default, `Auth0CurrentUser` gives `pydantic.BaseModel` object, which has `username` (name) and `email` fields. -Here is sample code for extracting extra user information (adding `user_id`): +Here is sample code for extracting extra user information (adding `user_id`) from ID token: ```python3 from pydantic import Field @@ -197,6 +222,31 @@ get_current_user = Auth0CurrentUser(domain=DOMAIN) get_current_user.user_info = CustomAuth0Claims # override user info model with a custom one. ``` +Or, we can also set new custom claims as follows: + +```python3 +get_user_detail = get_current_user.claim(CustomAuth0Claims) + +@app.get("/new/") +async def detail(user: CustomAuth0Claims = Depends(get_user_detail)): + return f"Hello, {user.user_id}" +``` + +### Raw payload + +If you doesn't require `pydantic` data serialization (validation), `FastAPI-CloudAuth` has a option to extract raw payload. + +All you need is: + +```python3 +get_raw_info = get_current_user.claim(None) + +@app.get("/new/") +async def raw_detail(user = Depends(get_raw_info)): + # user has all items (ex. iss, sub, aud, exp, ... it depends on passed token) + return f"Hello, {user.get('sub')}" +``` + ## Development - Contributing Please read the [CONTRIBUTING](../CONTRIBUTING.md) how to setup development environment and testing. diff --git a/docs/server/auth0.py b/docs/server/auth0.py new file mode 100644 index 0000000..29ae74a --- /dev/null +++ b/docs/server/auth0.py @@ -0,0 +1,44 @@ +import os +from pydantic import BaseModel +from fastapi import FastAPI, Depends +from fastapi_cloudauth.auth0 import Auth0, Auth0CurrentUser, Auth0Claims + +tags_metadata = [ + { + "name": "Auth0", + "description": "Operations with access/ID token, provided by Auth0.", + } +] + +app = FastAPI( + title="FastAPI CloudAuth Project", + description="Simple integration between FastAPI and cloud authentication services (AWS Cognito, Auth0, Firebase Authentication).", + openapi_tags=tags_metadata, +) + +auth = Auth0(domain=os.environ["AUTH0_DOMAIN"]) + + +@app.get("/", dependencies=[Depends(auth.scope("read:users"))], tags=["Auth0"]) +def secure(): + # access token is valid + return "Hello" + + +class AccessUser(BaseModel): + sub: str + + +@app.get("/access/", tags=["Auth0"]) +def secure_access(current_user: AccessUser = Depends(auth.claim(AccessUser))): + # access token is valid and getting user info from access token + return f"Hello", {current_user.sub} + + +get_current_user = Auth0CurrentUser(domain=os.environ["AUTH0_DOMAIN"]) + + +@app.get("/user/", tags=["Auth0"]) +def secure_user(current_user: Auth0Claims = Depends(get_current_user)): + # ID token is valid and getting user info from ID token + return f"Hello, {current_user.username}" diff --git a/docs/server/cognito.py b/docs/server/cognito.py new file mode 100644 index 0000000..36958fe --- /dev/null +++ b/docs/server/cognito.py @@ -0,0 +1,48 @@ +import os +from pydantic import BaseModel +from fastapi import FastAPI, Depends +from fastapi_cloudauth.cognito import Cognito, CognitoCurrentUser, CognitoClaims + +tags_metadata = [ + { + "name": "Cognito", + "description": "Operations with access/ID token, provided by AWS Cognito.", + } +] + +app = FastAPI( + title="FastAPI CloudAuth Project", + description="Simple integration between FastAPI and cloud authentication services (AWS Cognito, Auth0, Firebase Authentication).", + openapi_tags=tags_metadata, +) + +auth = Cognito( + region=os.environ["COGNITO_REGION"], userPoolId=os.environ["COGNITO_USERPOOLID"] +) + + +@app.get("/", dependencies=[Depends(auth.scope("read:users"))], tags=["Cognito"]) +def secure(): + # access token is valid + return "Hello" + + +class AccessUser(BaseModel): + sub: str + + +@app.get("/access/", tags=["Cognito"]) +def secure_access(current_user: AccessUser = Depends(auth.claim(AccessUser))): + # access token is valid and getting user info from access token + return f"Hello", {current_user.sub} + + +get_current_user = CognitoCurrentUser( + region=os.environ["COGNITO_REGION"], userPoolId=os.environ["COGNITO_USERPOOLID"] +) + + +@app.get("/user/", tags=["Cognito"]) +def secure_user(current_user: CognitoClaims = Depends(get_current_user)): + # ID token is valid and getting user info from ID token + return f"Hello, {current_user.username}" diff --git a/docs/server/firebase.py b/docs/server/firebase.py new file mode 100644 index 0000000..439344c --- /dev/null +++ b/docs/server/firebase.py @@ -0,0 +1,23 @@ +from fastapi import FastAPI, Depends +from fastapi_cloudauth.firebase import FirebaseCurrentUser, FirebaseClaims + +tags_metadata = [ + { + "name": "Firebase", + "description": "Operations with access/ID token, provided by Firebase Authentication.", + } +] + +app = FastAPI( + title="FastAPI CloudAuth Project", + description="Simple integration between FastAPI and cloud authentication services (AWS Cognito, Auth0, Firebase Authentication).", + openapi_tags=tags_metadata, +) + +get_current_user = FirebaseCurrentUser() + + +@app.get("/user/", tags=["Firebase"]) +def secure_user(current_user: FirebaseClaims = Depends(get_current_user)): + # ID token is valid and getting user info from ID token + return f"Hello, {current_user.user_id}" diff --git a/fastapi_cloudauth/auth0.py b/fastapi_cloudauth/auth0.py index 0e56fc3..2139e14 100644 --- a/fastapi_cloudauth/auth0.py +++ b/fastapi_cloudauth/auth0.py @@ -1,18 +1,29 @@ +from typing import Any, Optional + from pydantic import BaseModel, Field -from .base import TokenVerifier, TokenUserInfoGetter, JWKS + +from .base import ScopedAuth, UserInfoAuth +from .verification import JWKS -class Auth0(TokenVerifier): +class Auth0(ScopedAuth): """ Verify access token of auth0 """ - scope_key = "permissions" + user_info = None - def __init__(self, domain: str, *args, **kwargs): + def __init__( + self, + domain: str, + scope_key: Optional[str] = "permissions", + auto_error: bool = True, + ): url = f"https://{domain}/.well-known/jwks.json" jwks = JWKS.fromurl(url) - super().__init__(jwks, *args, **kwargs) + super().__init__( + jwks, scope_key=scope_key, auto_error=auto_error, + ) class Auth0Claims(BaseModel): @@ -20,14 +31,16 @@ class Auth0Claims(BaseModel): email: str = Field(None, alias="email") -class Auth0CurrentUser(TokenUserInfoGetter): +class Auth0CurrentUser(UserInfoAuth): """ Verify ID token and get user info of Auth0 """ user_info = Auth0Claims - def __init__(self, domain: str, *args, **kwargs): + def __init__( + self, domain: str, *args: Any, **kwargs: Any, + ): url = f"https://{domain}/.well-known/jwks.json" jwks = JWKS.fromurl(url) - super().__init__(jwks, *args, **kwargs) + super().__init__(jwks, *args, user_info=self.user_info, **kwargs) diff --git a/fastapi_cloudauth/base.py b/fastapi_cloudauth/base.py index 52fcedc..70f2015 100644 --- a/fastapi_cloudauth/base.py +++ b/fastapi_cloudauth/base.py @@ -1,161 +1,62 @@ -from typing import List, Dict, Optional, Any, Type -import requests +from abc import ABC, abstractmethod from copy import deepcopy -from jose import jwk, jwt -from jose.utils import base64url_decode -from jose.backends.base import Key +from typing import Any, Dict, Optional, Type, Union + from fastapi import Depends, HTTPException -from fastapi.security import HTTPBearer, HTTPAuthorizationCredentials +from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer +from jose import jwt # type: ignore from pydantic import BaseModel from pydantic.error_wrappers import ValidationError from starlette import status -NOT_AUTHENTICATED = "Not authenticated" -NO_PUBLICKEY = "JWK public Attribute for authorization token not found" -NOT_VERIFIED = "Not verified" -SCOPE_NOT_MATCHED = "Scope not matched" -NOT_VALIDATED_CLAIMS = "Validation Error for Claims" - - -class JWKS: - # keys: List[Dict[str, Any]] - keys: Dict[str, Key] +from fastapi_cloudauth.messages import NOT_AUTHENTICATED, NOT_VALIDATED_CLAIMS +from fastapi_cloudauth.verification import ( + JWKS, + JWKsVerifier, + ScopedJWKsVerifier, + Verifier, +) - def __init__(self, keys: Dict[str, Key]): - self.keys = keys - - @classmethod - def fromurl(cls, url: str): - """ - get and parse json into jwks from endpoint as follows, - https://xxx/.well-known/jwks.json - """ - # return cls.parse_obj(requests.get(url).json()) - jwks = requests.get(url).json() - - jwks = {_jwk["kid"]: jwk.construct(_jwk) for _jwk in jwks.get("keys", [])} - return cls(keys=jwks) - - @classmethod - def firebase(cls, url: str): - """ - get and parse json into jwks from endpoint for Firebase, - """ - certs = requests.get(url).json() - keys = { - kid: jwk.construct(publickey, algorithm="RS256") - for kid, publickey in certs.items() - } - return cls(keys=keys) +class CloudAuth(ABC): + @property + @abstractmethod + def verifier(self) -> Verifier: + """Composite Verifier class to verify jwt in HTTPAuthorizationCredentials""" + ... # pragma: no cover -class BaseTokenVerifier: - def __init__(self, jwks: JWKS, auto_error: bool = True, *args, **kwargs): - """ - auto-error: if False, return payload as b'null' for invalid token. - """ - self.jwks_to_key = jwks.keys + @verifier.setter + def verifier(self, instance: Verifier) -> None: + ... # pragma: no cover - self.scope_name: Optional[str] = None - self.auto_error = auto_error + @abstractmethod + async def call(self, http_auth: HTTPAuthorizationCredentials) -> Any: + """Define postprocess for verified token""" + ... # pragma: no cover - def clone(self): + def clone(self, instance: "CloudAuth") -> "CloudAuth": """create clone instanse""" - # In some case, self.jwks_to_key can't pickle (deepcopy). + # In some case, Verifier can't pickle (deepcopy). # Tempolary put it aside to deepcopy. Then, undo it at the last line. - jwks_to_key = self.jwks_to_key - self.jwks_to_key = {} - clone = deepcopy(self) - clone.jwks_to_key = jwks_to_key - - # undo original instanse - self.jwks_to_key = jwks_to_key - return clone - - def get_publickey(self, http_auth: HTTPAuthorizationCredentials): - token = http_auth.credentials - header = jwt.get_unverified_header(token) - kid = header.get("kid") - if not kid: - if self.auto_error: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=NOT_AUTHENTICATED - ) - else: - return None - publickey = self.jwks_to_key.get(kid) - if not publickey: - if self.auto_error: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=NO_PUBLICKEY, - ) - else: - return None - return publickey - - def verify_token(self, http_auth: HTTPAuthorizationCredentials) -> bool: - public_key = self.get_publickey(http_auth) - if not public_key: - # error handling is included in self.get_publickey - return False - - message, encoded_sig = http_auth.credentials.rsplit(".", 1) - decoded_sig = base64url_decode(encoded_sig.encode()) - is_verified = public_key.verify(message.encode(), decoded_sig) - - if not is_verified: - if self.auto_error: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=NOT_VERIFIED - ) - - return is_verified - - -class TokenVerifier(BaseTokenVerifier): - """ - Verify `Access token` and authorize it based on scope (or groups) - """ - - scope_key: Optional[str] = None - - def scope(self, scope_name: str): - """User-SCOPE verification Shortcut to pass it into dependencies. - Use as (`auth` is this instanse and `app` is fastapi.FastAPI instanse): - ``` - from fastapi import Depends - - @app.get("/", dependencies=[Depends(auth.scope("allowed scope"))]) - def api(): - return "hello" - ``` - """ - clone = self.clone() - clone.scope_name = scope_name - if not clone.scope_key: - raise AttributeError("declaire scope_key to set scope") + if not isinstance(instance, CloudAuth): + raise TypeError( + "Only subclass of CloudAuth can be cloned" + ) # pragma: no cover + + _verifier = instance.verifier + instance.verifier = None # type: ignore + clone = deepcopy(instance) + clone.verifier = _verifier.clone(_verifier) + instance.verifier = _verifier return clone - def verify_scope(self, http_auth: HTTPAuthorizationCredentials) -> bool: - claims = jwt.get_unverified_claims(http_auth.credentials) - scopes = claims.get(self.scope_key) - if isinstance(scopes, str): - scopes = {scope.strip() for scope in scopes.split()} - if scopes is None or self.scope_name not in scopes: - if self.auto_error: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=SCOPE_NOT_MATCHED, - ) - return False - return True - async def __call__( self, http_auth: Optional[HTTPAuthorizationCredentials] = Depends( HTTPBearer(auto_error=False) ), - ) -> Optional[bool]: - """User access-token verification Shortcut to pass it into dependencies. + ) -> Any: + """User access/ID-token verification Shortcut to pass it into dependencies. Use as (`auth` is this instanse and `app` is fastapi.FastAPI instanse): ``` from fastapi import Depends @@ -166,45 +67,76 @@ def api(): ``` """ if http_auth is None: - if self.auto_error: + if self.verifier.auto_error: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail=NOT_AUTHENTICATED ) else: return None - is_verified = self.verify_token(http_auth) + is_verified = self.verifier.verify_token(http_auth) if not is_verified: return None - if self.scope_name: - is_verified_scope = self.verify_scope(http_auth) - if not is_verified_scope: - return None - - return True + return await self.call(http_auth) -class TokenUserInfoGetter(BaseTokenVerifier): +class UserInfoAuth(CloudAuth): """ Verify `ID token` and extract user information """ user_info: Optional[Type[BaseModel]] = None - def __init__(self, *args, **kwargs): - if not self.user_info: - raise AttributeError( - "must assign custom pydantic.BaseModel into class attributes `user_info`" - ) - super().__init__(*args, **kwargs) - - async def __call__( + def __init__( self, - http_auth: Optional[HTTPAuthorizationCredentials] = Depends( - HTTPBearer(auto_error=False) - ), - ) -> Optional[BaseModel]: + jwks: JWKS, + *, + user_info: Optional[Type[BaseModel]] = None, + auto_error: bool = True, + **kwargs: Any + ) -> None: + + self.user_info = user_info + self.auto_error = auto_error + self._verifier = JWKsVerifier(jwks, auto_error=self.auto_error) + + @property + def verifier(self) -> JWKsVerifier: + return self._verifier + + @verifier.setter + def verifier(self, verifier: JWKsVerifier) -> None: + self._verifier = verifier + + def _clone(self) -> "UserInfoAuth": + cloned = super().clone(self) + if isinstance(cloned, UserInfoAuth): + return cloned + raise NotImplementedError # pragma: no cover + + def claim(self, schema: Optional[Type[BaseModel]] = None) -> "UserInfoAuth": + """User verification and validation shortcut to pass it into app arguments. + Use as (`auth` is this instanse and `app` is fastapi.FastAPI instanse): + ``` + from fastapi import Depends + from pydantic import BaseModel + + class CustomClaim(BaseModel): + sub: str + + @app.get("/") + def api(user: CustomClaim = Depends(auth.claim(CustomClaim))): + return CustomClaim + ``` + """ + clone = self._clone() + clone.user_info = schema + return clone + + async def call( + self, http_auth: HTTPAuthorizationCredentials + ) -> Optional[Union[BaseModel, Dict[str, Any]]]: """Get current user and verification with ID-token Shortcut. Use as (`Auth` is this subclass, `auth` is `Auth` instanse and `app` is fastapi.FastAPI instanse): ``` @@ -215,19 +147,139 @@ def api(current_user: Auth = Depends(auth)): return current_user ``` """ - if http_auth is None: + claims: Dict[str, Any] = jwt.get_unverified_claims(http_auth.credentials) + + if not self.user_info: + return claims + + try: + current_user = self.user_info.parse_obj(claims) + return current_user + except ValidationError: if self.auto_error: raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=NOT_AUTHENTICATED + status_code=status.HTTP_403_FORBIDDEN, detail=NOT_VALIDATED_CLAIMS, ) else: return None - is_verified = self.verify_token(http_auth) - if not is_verified: - return None - claims = jwt.get_unverified_claims(http_auth.credentials) +class ScopedAuth(CloudAuth): + """ + Verify `Access token` and authorize it based on scope (or groups) + """ + + _scope_key: Optional[str] = None + user_info: Optional[Type[BaseModel]] = None + + def __init__( + self, + jwks: JWKS, + user_info: Optional[Type[BaseModel]] = None, + scope_name: Optional[str] = None, + scope_key: Optional[str] = None, + auto_error: bool = True, + ): + self.user_info = user_info + self.auto_error = auto_error + self._scope_name = scope_name + if scope_key: + self._scope_key = scope_key + + self._verifier = ScopedJWKsVerifier( + jwks, + scope_name=self._scope_name, + scope_key=self._scope_key, + auto_error=self.auto_error, + ) + + @property + def verifier(self) -> ScopedJWKsVerifier: + return self._verifier + + @verifier.setter + def verifier(self, verifier: ScopedJWKsVerifier) -> None: + self._verifier = verifier + + @property + def scope_key(self) -> Optional[str]: + return self._scope_key + + @scope_key.setter + def scope_key(self, key: Optional[str]) -> None: + self._scope_key = key + self._verifier.scope_key = key + + @property + def scope_name(self) -> Optional[str]: + return self._scope_name + + @scope_name.setter + def scope_name(self, name: Optional[str]) -> None: + self._scope_name = name + self._verifier.scope_name = name + + def _clone(self) -> "ScopedAuth": + cloned = super().clone(self) + if isinstance(cloned, ScopedAuth): + return cloned + raise NotImplementedError # pragma: no cover + + def scope(self, scope_name: str) -> "ScopedAuth": + """User-SCOPE verification Shortcut to pass it into dependencies. + Use as (`auth` is this instanse and `app` is fastapi.FastAPI instanse): + ``` + from fastapi import Depends + + @app.get("/", dependencies=[Depends(auth.scope("allowed scope"))]) + def api(): + return "hello" + ``` + """ + clone = self._clone() + clone.scope_name = scope_name + if not clone.scope_key: + raise AttributeError("declaire scope_key to set scope") + return clone + + def claim(self, schema: Optional[Type[BaseModel]] = None) -> "ScopedAuth": + """User verification and validation shortcut to pass it into app arguments. + Use as (`auth` is this instanse and `app` is fastapi.FastAPI instanse): + ``` + from fastapi import Depends + from pydantic import BaseModel + + class CustomClaim(BaseModel): + sub: str + + @app.get("/") + def api(user: CustomClaim = Depends(auth.claim(CustomClaim))): + return CustomClaim + ``` + """ + clone = self._clone() + clone.user_info = schema + return clone + + async def call( + self, http_auth: HTTPAuthorizationCredentials + ) -> Optional[Union[Dict[str, Any], BaseModel, bool]]: + """User access-token verification Shortcut to pass it into dependencies. + Use as (`auth` is this instanse and `app` is fastapi.FastAPI instanse): + ``` + from fastapi import Depends + + @app.get("/", dependencies=[Depends(auth)]) + def api(): + return "hello" + ``` + """ + + claims: Dict[str, Any] = jwt.get_unverified_claims(http_auth.credentials) + + if not self.user_info: + return claims + try: current_user = self.user_info.parse_obj(claims) return current_user diff --git a/fastapi_cloudauth/cognito.py b/fastapi_cloudauth/cognito.py index 5b58af7..7f85080 100644 --- a/fastapi_cloudauth/cognito.py +++ b/fastapi_cloudauth/cognito.py @@ -1,18 +1,30 @@ +from typing import Any, Optional + from pydantic import BaseModel, Field -from .base import TokenVerifier, TokenUserInfoGetter, JWKS + +from .base import ScopedAuth, UserInfoAuth +from .verification import JWKS -class Cognito(TokenVerifier): +class Cognito(ScopedAuth): """ Verify access token of AWS Cognito """ - scope_key = "cognito:groups" + user_info = None - def __init__(self, region: str, userPoolId: str, *args, **kwargs): + def __init__( + self, + region: str, + userPoolId: str, + scope_key: Optional[str] = "cognito:groups", + auto_error: bool = True, + ): url = f"https://cognito-idp.{region}.amazonaws.com/{userPoolId}/.well-known/jwks.json" jwks = JWKS.fromurl(url) - super().__init__(jwks, *args, **kwargs) + super().__init__( + jwks, scope_key=scope_key, auto_error=auto_error, + ) class CognitoClaims(BaseModel): @@ -20,14 +32,16 @@ class CognitoClaims(BaseModel): email: str = Field(None, alias="email") -class CognitoCurrentUser(TokenUserInfoGetter): +class CognitoCurrentUser(UserInfoAuth): """ Verify ID token and get user info of AWS Cognito """ user_info = CognitoClaims - def __init__(self, region: str, userPoolId: str, *args, **kwargs): + def __init__( + self, region: str, userPoolId: str, *args: Any, **kwargs: Any, + ): url = f"https://cognito-idp.{region}.amazonaws.com/{userPoolId}/.well-known/jwks.json" jwks = JWKS.fromurl(url) - super().__init__(jwks, *args, **kwargs) + super().__init__(jwks, user_info=self.user_info, *args, **kwargs) diff --git a/fastapi_cloudauth/firebase.py b/fastapi_cloudauth/firebase.py index d1081ac..9dea230 100644 --- a/fastapi_cloudauth/firebase.py +++ b/fastapi_cloudauth/firebase.py @@ -1,5 +1,9 @@ +from typing import Any + from pydantic import BaseModel, Field -from .base import TokenUserInfoGetter, JWKS + +from .base import UserInfoAuth +from .verification import JWKS class FirebaseClaims(BaseModel): @@ -7,14 +11,14 @@ class FirebaseClaims(BaseModel): email: str = Field(None, alias="email") -class FirebaseCurrentUser(TokenUserInfoGetter): +class FirebaseCurrentUser(UserInfoAuth): """ Verify ID token and get user info of Firebase """ user_info = FirebaseClaims - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any): url = "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com" jwks = JWKS.firebase(url) - super().__init__(jwks, *args, **kwargs) + super().__init__(jwks, *args, user_info=self.user_info, **kwargs) diff --git a/fastapi_cloudauth/messages.py b/fastapi_cloudauth/messages.py new file mode 100644 index 0000000..673c554 --- /dev/null +++ b/fastapi_cloudauth/messages.py @@ -0,0 +1,5 @@ +NOT_AUTHENTICATED = "Not authenticated" +NO_PUBLICKEY = "JWK public Attribute for authorization token not found" +NOT_VERIFIED = "Not verified" +SCOPE_NOT_MATCHED = "Scope not matched" +NOT_VALIDATED_CLAIMS = "Validation Error for Claims" diff --git a/fastapi_cloudauth/verification.py b/fastapi_cloudauth/verification.py new file mode 100644 index 0000000..ea2bc2c --- /dev/null +++ b/fastapi_cloudauth/verification.py @@ -0,0 +1,179 @@ +from abc import ABC, abstractmethod +from copy import deepcopy +from typing import Any, Dict, Optional + +import requests +from fastapi import HTTPException +from fastapi.security import HTTPAuthorizationCredentials +from jose import jwk # type: ignore +from jose import jwt +from jose.backends.base import Key # type: ignore +from jose.utils import base64url_decode # type: ignore +from starlette import status + +from fastapi_cloudauth.messages import ( + NO_PUBLICKEY, + NOT_AUTHENTICATED, + NOT_VERIFIED, + SCOPE_NOT_MATCHED, +) + + +class Verifier(ABC): + @property + @abstractmethod + def auto_error(self) -> bool: + ... # pragma: no cover + + @abstractmethod + def verify_token(self, http_auth: HTTPAuthorizationCredentials) -> bool: + ... # pragma: no cover + + @abstractmethod + def clone(self, instance: "Verifier") -> "Verifier": + """create clone instanse""" + ... # pragma: no cover + + +class JWKS: + keys: Dict[str, Key] + + def __init__(self, keys: Dict[str, Key]): + self.keys = keys + + @classmethod + def fromurl(cls, url: str) -> "JWKS": + """ + get and parse json into jwks from endpoint as follows, + https://xxx/.well-known/jwks.json + """ + jwks = requests.get(url).json() + + jwks = {_jwk["kid"]: jwk.construct(_jwk) for _jwk in jwks.get("keys", [])} + return cls(keys=jwks) + + @classmethod + def firebase(cls, url: str) -> "JWKS": + """ + get and parse json into jwks from endpoint for Firebase, + """ + certs = requests.get(url).json() + keys = { + kid: jwk.construct(publickey, algorithm="RS256") + for kid, publickey in certs.items() + } + return cls(keys=keys) + + +class JWKsVerifier(Verifier): + def __init__( + self, jwks: JWKS, auto_error: bool = True, *args: Any, **kwargs: Any + ) -> None: + """ + auto-error: if False, return payload as b'null' for invalid token. + """ + self._jwks_to_key = jwks.keys + self._auto_error = auto_error + + @property + def auto_error(self) -> bool: + return self._auto_error + + @auto_error.setter + def auto_error(self, auto_error: bool) -> None: + self._auto_error = auto_error + + def _get_publickey(self, http_auth: HTTPAuthorizationCredentials) -> Optional[Key]: + token = http_auth.credentials + header = jwt.get_unverified_header(token) + kid = header.get("kid") + if not kid: + if self.auto_error: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=NOT_AUTHENTICATED + ) + else: + return None + publickey: Optional[Key] = self._jwks_to_key.get(kid) + if not publickey: + if self.auto_error: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=NO_PUBLICKEY, + ) + else: + return None + return publickey + + def verify_token(self, http_auth: HTTPAuthorizationCredentials) -> bool: + public_key = self._get_publickey(http_auth) + if not public_key: + # error handling is included in self.get_publickey + return False + + message, encoded_sig = http_auth.credentials.rsplit(".", 1) + decoded_sig = base64url_decode(encoded_sig.encode()) + is_verified: bool = public_key.verify(message.encode(), decoded_sig) + + if not is_verified: + if self.auto_error: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=NOT_VERIFIED + ) + + return is_verified + + def clone(self, instance: "JWKsVerifier") -> "JWKsVerifier": # type: ignore[override] + _jwks_to_key = instance._jwks_to_key + instance._jwks_to_key = {} + clone = deepcopy(instance) + clone._jwks_to_key = _jwks_to_key + instance._jwks_to_key = _jwks_to_key + return clone + + +class ScopedJWKsVerifier(JWKsVerifier): + def __init__( + self, + jwks: JWKS, + scope_name: Optional[str] = None, + scope_key: Optional[str] = None, + auto_error: bool = True, + *args: Any, + **kwargs: Any + ) -> None: + """ + auto-error: if False, return payload as b'null' for invalid token. + """ + super().__init__(jwks, auto_error=auto_error) + self.scope_name = scope_name + self.scope_key = scope_key + + def clone(self, instance: "ScopedJWKsVerifier") -> "ScopedJWKsVerifier": # type: ignore[override] + cloned = super().clone(instance) + if isinstance(cloned, ScopedJWKsVerifier): + return cloned + raise NotImplementedError # pragma: no cover + + def _verify_scope(self, http_auth: HTTPAuthorizationCredentials) -> bool: + claims = jwt.get_unverified_claims(http_auth.credentials) + scopes = claims.get(self.scope_key) + if isinstance(scopes, str): + scopes = {scope.strip() for scope in scopes.split()} + if scopes is None or self.scope_name not in scopes: + if self.auto_error: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, detail=SCOPE_NOT_MATCHED, + ) + return False + return True + + def verify_token(self, http_auth: HTTPAuthorizationCredentials) -> bool: + is_verified = super().verify_token(http_auth) + if not is_verified: + return False + + if self.scope_name: + is_verified_scope = self._verify_scope(http_auth) + return is_verified_scope + + return True diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000..9d6d3b0 --- /dev/null +++ b/mypy.ini @@ -0,0 +1,17 @@ +[mypy] + +# --strict +disallow_any_generics = True +disallow_subclassing_any = True +disallow_untyped_calls = True +disallow_untyped_defs = True +disallow_incomplete_defs = True +check_untyped_defs = True +disallow_untyped_decorators = True +no_implicit_optional = True +warn_redundant_casts = True +warn_unused_ignores = True +warn_return_any = True +implicit_reexport = False +strict_equality = True +# --strict end \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 5c2ccc3..4fb8b40 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -35,6 +35,7 @@ pytest-cov = "^2.10.0" flake8 = "^3.8.3" mypy = "^0.782" black = "^19.10b0" +isort = "^5.7.0" uvicorn = ">=0.12.0,<0.14.0" botocore = "^1.17.32" boto3 = "^1.14.32" @@ -42,6 +43,9 @@ authlib = "^0.15.2" firebase-admin = "^4.4.0" auth0-python = "^3.14.0" pytest-mock = "^3.5.1" +pytest-asyncio = "^0.14.0" +autoflake = "^1.4" + [build-system] requires = ["poetry>=0.12"] diff --git a/scripts/develop.sh b/scripts/develop.sh new file mode 100644 index 0000000..15b3756 --- /dev/null +++ b/scripts/develop.sh @@ -0,0 +1,6 @@ +#!/usr/bin/env bash + +source ./scripts/load_env.sh +uvicorn docs.server.auth0:app +uvicorn docs.server.cognito:app +uvicorn docs.server.firebase:app \ No newline at end of file diff --git a/scripts/format-imports.sh b/scripts/format-imports.sh new file mode 100644 index 0000000..899d832 --- /dev/null +++ b/scripts/format-imports.sh @@ -0,0 +1,6 @@ +#!/bin/sh -e +set -x + +# Sort imports one per line, so autoflake can remove unused imports +isort fastapi_cloudauth tests scripts --force-single-line-imports +sh ./scripts/format.sh \ No newline at end of file diff --git a/scripts/format.sh b/scripts/format.sh new file mode 100644 index 0000000..f2cf2ca --- /dev/null +++ b/scripts/format.sh @@ -0,0 +1,6 @@ +#!/bin/sh -e +set -x + +autoflake --remove-all-unused-imports --recursive --remove-unused-variables --in-place fastapi_cloudauth tests scripts --exclude=__init__.py +black fastapi_cloudauth tests scripts +isort fastapi_cloudauth tests scripts \ No newline at end of file diff --git a/scripts/lint.sh b/scripts/lint.sh new file mode 100644 index 0000000..94e6552 --- /dev/null +++ b/scripts/lint.sh @@ -0,0 +1,9 @@ +#!/usr/bin/env bash + +set -e +set -x + +mypy fastapi_cloudauth +flake8 fastapi_cloudauth tests +black fastapi_cloudauth tests --check +isort fastapi_cloudauth tests scripts --check-only \ No newline at end of file diff --git a/tests/conftest.py b/tests/conftest.py index 20ae728..9caa079 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,4 +1,3 @@ import pytest pytest.register_assert_rewrite("tests.helpers") - diff --git a/tests/helpers.py b/tests/helpers.py index 5547a46..46d5a19 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -1,5 +1,5 @@ -import json import base64 +import json class BaseTestCloudAuth: diff --git a/tests/test_auth0.py b/tests/test_auth0.py index 88af903..0cd300d 100644 --- a/tests/test_auth0.py +++ b/tests/test_auth0.py @@ -1,15 +1,16 @@ import os from sys import version_info as info -import requests from typing import Optional -from jose import jwt -from fastapi import Depends, FastAPI -from fastapi.testclient import TestClient + +import requests from auth0.v3.authentication import GetToken from auth0.v3.management import Auth0 as Auth0sdk +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient +from jose import jwt +from pydantic import BaseModel from fastapi_cloudauth.auth0 import Auth0, Auth0Claims, Auth0CurrentUser - from tests.helpers import BaseTestCloudAuth, decode_token DOMAIN = os.getenv("AUTH0_DOMAIN") @@ -81,7 +82,7 @@ def add_test_user( if scope: auth0.users.add_permissions( user_id, - [{"permission_name": scope, "resource_server_identifier": AUDIENCE,}], + [{"permission_name": scope, "resource_server_identifier": AUDIENCE}], ) @@ -107,12 +108,12 @@ def get_access_token( CLIENTID: Set client id of `Default App` in environment variable. See Applications in Auth0 dashboard CLIENT_SECRET: Set client secret of `Default App` in environment variable AUDIENCE: In Auth0 dashboard, create custom applications and API, - and add permission `read:test` into that API, + and add permission `read:test` into that API, and then copy the audience (identifier) in environment variable. NOTE: the followings setting in Auth0 dashboard is required - sidebar > Applications > settings > Advanced settings > grant: click `password` on - - top right icon > Set General > API Authorization Settings > Default Directory to Username-Password-Authentication + - top right icon > Set General > API Authorization Settings > Default Directory to Username-Password-Authentication """ resp = requests.post( f"https://{DOMAIN}/oauth/token", @@ -139,12 +140,12 @@ def get_id_token( CLIENTID: Set client id of `Default App` in environment variable. See Applications in Auth0 dashboard CLIENT_SECRET: Set client secret of `Default App` in environment variable AUDIENCE: In Auth0 dashboard, create custom applications and API, - and add permission `read:test` into that API, + and add permission `read:test` into that API, and then copy the audience (identifier) in environment variable. NOTE: the followings setting in Auth0 dashboard is required - sidebar > Applications > settings > Advanced settings > grant: click `password` on - - top right icon > Set General > API Authorization Settings > Default Directory to Username-Password-Authentication + - top right icon > Set General > API Authorization Settings > Default Directory to Username-Password-Authentication """ resp = requests.post( f"https://{DOMAIN}/oauth/token", @@ -220,13 +221,42 @@ class Auth0FakeCurrentUser(Auth0CurrentUser): ) @app.get("/") - async def secure(payload=Depends(auth)) -> bool: + async def secure(payload: bool = Depends(auth)) -> bool: return payload @app.get("/no-error/", dependencies=[Depends(auth_no_error)]) async def secure_no_error(payload=Depends(auth_no_error)) -> bool: return payload + class AccessClaim(BaseModel): + sub: str = None + + @app.get("/access/user") + async def secure_access_user( + payload: AccessClaim = Depends(auth.claim(AccessClaim)), + ): + assert isinstance(payload, AccessClaim) + return payload + + @app.get("/access/user/no-error/") + async def secure_access_user_no_error( + payload: AccessClaim = Depends(auth_no_error.claim(AccessClaim)), + ) -> Optional[AccessClaim]: + return payload + + class InvalidAccessClaim(BaseModel): + fake_field: str + + @app.get("/access/user/invalid") + async def invalid_access_user(payload=Depends(auth.claim(InvalidAccessClaim)),): + return payload # pragma: no cover + + @app.get("/access/user/invalid/no-error/") + async def invalid_access_user_no_error( + payload=Depends(auth_no_error.claim(InvalidAccessClaim)), + ) -> Optional[InvalidAccessClaim]: + assert payload is None + @app.get("/scope/") async def secure_scope(payload=Depends(auth.scope(self.scope))) -> bool: pass diff --git a/tests/test_base.py b/tests/test_base.py index ff120b7..07befd5 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -1,14 +1,16 @@ import pytest from fastapi import HTTPException from fastapi.security import HTTPAuthorizationCredentials +from pydantic import BaseModel -from fastapi_cloudauth.base import JWKS, TokenUserInfoGetter, TokenVerifier +from fastapi_cloudauth.base import ScopedAuth, UserInfoAuth +from fastapi_cloudauth.verification import JWKS @pytest.mark.unittest def test_raise_error_invalid_set_scope(): # scope_key is not declaired - token_verifier = TokenVerifier(jwks=JWKS(keys=[])) + token_verifier = ScopedAuth(jwks=JWKS(keys=[])) with pytest.raises(AttributeError): # raise AttributeError for invalid instanse attributes wrt scope token_verifier.scope("read:test") @@ -17,18 +19,22 @@ def test_raise_error_invalid_set_scope(): @pytest.mark.unittest def test_return_instance_with_scope(): # scope method return new instance to give it for Depends. - verifier = TokenVerifier(jwks=JWKS(keys=[])) - # must set scope_key (Inherit TokenVerifier and override scope_key attribute) + verifier = ScopedAuth(jwks=JWKS(keys=[])) + # must set scope_key (Inherit ScopedAuth and override scope_key attribute) scope_key = "dummy key" verifier.scope_key = scope_key scope_name = "required-scope" obj = verifier.scope(scope_name) - assert isinstance(obj, TokenVerifier) + assert isinstance(obj, ScopedAuth) assert obj.scope_key == scope_key, "scope_key mustn't be cleared." assert obj.scope_name == scope_name, "Must set scope_name in returned instanse." - assert obj.jwks_to_key == verifier.jwks_to_key, "return cloned objects" - assert obj.auto_error == verifier.auto_error, "return cloned objects" + assert ( + obj.verifier._jwks_to_key == verifier.verifier._jwks_to_key + ), "return cloned objects" + assert ( + obj.verifier.auto_error == verifier.verifier.auto_error + ), "return cloned objects" @pytest.mark.unittest @@ -42,27 +48,106 @@ def test_return_instance_with_scope(): ) def test_validation_scope(mocker, scopes): mocker.patch( - "fastapi_cloudauth.base.jwt.get_unverified_claims", + "fastapi_cloudauth.verification.jwt.get_unverified_claims", return_value={"dummy key": scopes}, ) - verifier = TokenVerifier(jwks=JWKS(keys=[])) + verifier = ScopedAuth(jwks=JWKS(keys=[])) scope_key = "dummy key" verifier.scope_key = scope_key scope_name = "user-assigned-scope" obj = verifier.scope(scope_name) - assert obj.verify_scope(HTTPAuthorizationCredentials(scheme="", credentials="")) + assert obj.verifier._verify_scope( + HTTPAuthorizationCredentials(scheme="", credentials="") + ) scope_name = "user-assigned-scope-invalid" obj = verifier.scope(scope_name) with pytest.raises(HTTPException): - obj.verify_scope(HTTPAuthorizationCredentials(scheme="", credentials="")) + obj.verifier._verify_scope( + HTTPAuthorizationCredentials(scheme="", credentials="") + ) - obj.auto_error = False - assert not obj.verify_scope(HTTPAuthorizationCredentials(scheme="", credentials="")) + obj.verifier.auto_error = False + assert not obj.verifier._verify_scope( + HTTPAuthorizationCredentials(scheme="", credentials="") + ) @pytest.mark.unittest -def test_forget_def_user_info(): - with pytest.raises(AttributeError): - TokenUserInfoGetter() +@pytest.mark.asyncio +@pytest.mark.parametrize( + "auth", [UserInfoAuth, ScopedAuth], +) +async def test_forget_def_user_info(auth): + dummy_http_auth = HTTPAuthorizationCredentials( + scheme="a", + credentials="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Im5hbWUiLCJpYXQiOjE1MTYyMzkwMjJ9.3ZEDmhWNZWDbJDPDlZX_I3oaalNYXdoT-bKLxIxQK4U", + ) + """If `.user_info` is None, return raw payload""" + get_current_user = auth(jwks=JWKS(keys=[])) + assert get_current_user.user_info is None + res = await get_current_user.call(dummy_http_auth) + assert res == {"sub": "1234567890", "name": "name", "iat": 1516239022} + + +@pytest.mark.unittest +@pytest.mark.asyncio +@pytest.mark.parametrize( + "auth", [UserInfoAuth, ScopedAuth], +) +async def test_assign_user_info(auth): + """three way to set user info schema + 1. pass it to arguments when create instance + 2. call `.claim` method and pass it to that arguments + 3. assign with `=` statements + """ + + class SubSchema(BaseModel): + sub: str + + class NameSchema(BaseModel): + name: str + + class IatSchema(BaseModel): + iat: int + + # authorized token + dummy_http_auth = HTTPAuthorizationCredentials( + scheme="a", + credentials="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Im5hbWUiLCJpYXQiOjE1MTYyMzkwMjJ9.3ZEDmhWNZWDbJDPDlZX_I3oaalNYXdoT-bKLxIxQK4U", + ) + + user = auth(jwks=JWKS(keys=[]), user_info=IatSchema) + assert await user.call(dummy_http_auth) == IatSchema(iat=1516239022) + + assert await user.claim(SubSchema).call(dummy_http_auth) == SubSchema( + sub="1234567890" + ) + + user.user_info = NameSchema + assert await user.call(dummy_http_auth) == NameSchema(name="name") + + +@pytest.mark.unittest +@pytest.mark.asyncio +@pytest.mark.parametrize( + "auth", [UserInfoAuth, ScopedAuth], +) +async def test_extract_raw_user_info(auth): + dummy_http_auth = HTTPAuthorizationCredentials( + scheme="a", + credentials="eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6Im5hbWUiLCJpYXQiOjE1MTYyMzkwMjJ9.3ZEDmhWNZWDbJDPDlZX_I3oaalNYXdoT-bKLxIxQK4U", + ) + + class NameSchema(BaseModel): + name: str + + get_current_user = auth(jwks=JWKS(keys=[]), user_info=NameSchema) + get_current_user.user_info = None + res = await get_current_user.call(dummy_http_auth) + assert res == {"sub": "1234567890", "name": "name", "iat": 1516239022} + + get_current_user = auth(jwks=JWKS(keys=[]), user_info=NameSchema) + res = await get_current_user.claim(None).call(dummy_http_auth) + assert res == {"sub": "1234567890", "name": "name", "iat": 1516239022} diff --git a/tests/test_cloudauth.py b/tests/test_cloudauth.py index 8fb647a..6a7e50a 100644 --- a/tests/test_cloudauth.py +++ b/tests/test_cloudauth.py @@ -1,14 +1,8 @@ -from typing import List - import pytest -from fastapi_cloudauth.base import ( - NOT_AUTHENTICATED, - NO_PUBLICKEY, - NOT_VERIFIED, - SCOPE_NOT_MATCHED, - NOT_VALIDATED_CLAIMS, -) +from fastapi_cloudauth.messages import (NO_PUBLICKEY, NOT_AUTHENTICATED, + NOT_VALIDATED_CLAIMS, NOT_VERIFIED, + SCOPE_NOT_MATCHED) from tests.helpers import assert_get_response from tests.test_auth0 import Auth0Client from tests.test_cognito import CognitoClient @@ -47,6 +41,12 @@ def success_case(self, path: str, token: str = None): client=self.client, endpoint=path, token=token, status_code=200 ) + def userinfo_success_case(self, path: str, token: str = None): + response = self.success_case(path, token) + for value in response.json().values(): + assert value, f"{response.content} failed to parse" + return response + def failure_case(self, path: str, token: str = None, detail=""): return assert_get_response( client=self.client, @@ -60,7 +60,6 @@ def test_valid_token(self): self.success_case("/", self.ACCESS_TOKEN) def test_no_token(self): - # handle in fastapi.security.HTTPBearer self.failure_case("/") # not auto_error self.success_case("no-error") @@ -98,6 +97,22 @@ def test_invalid_scope(self): self.failure_case("/scope/", self.ACCESS_TOKEN, detail=SCOPE_NOT_MATCHED) self.success_case("/scope/no-error/", self.ACCESS_TOKEN) + def test_valid_token_extraction(self): + self.userinfo_success_case("/access/user", self.ACCESS_TOKEN) + + def test_no_token_extraction(self): + self.failure_case("/access/user") + # not auto_error + self.success_case("/access/user/no-error") + + def test_insufficient_user_info_from_access_token(self): + # verified but token does not contains user info + self.failure_case( + "/access/user/invalid/", self.ACCESS_TOKEN, detail=NOT_VALIDATED_CLAIMS + ) + # not auto_error + self.success_case("/access/user/invalid/no-error", self.ACCESS_TOKEN) + class IdTokenTestCase(BaseTestCloudAuth): def success_case(self, path: str, token: str = None): diff --git a/tests/test_cognito.py b/tests/test_cognito.py index 21d99c9..5e57709 100644 --- a/tests/test_cognito.py +++ b/tests/test_cognito.py @@ -1,14 +1,15 @@ import os from sys import version_info as info +from typing import Optional + import boto3 from botocore.exceptions import ClientError -from typing import Optional from fastapi import Depends, FastAPI from fastapi.testclient import TestClient +from pydantic.main import BaseModel from fastapi_cloudauth import Cognito, CognitoCurrentUser from fastapi_cloudauth.cognito import CognitoClaims - from tests.helpers import BaseTestCloudAuth, decode_token REGION = os.getenv("COGNITO_REGION") @@ -40,19 +41,19 @@ def add_test_user( password="testPass1-", scope: Optional[str] = None, ): - resp = client.sign_up( + client.sign_up( ClientId=CLIENTID, Username=username, Password=password, - UserAttributes=[{"Name": "email", "Value": username},], + UserAttributes=[{"Name": "email", "Value": username}], ) - resp = client.admin_confirm_sign_up(UserPoolId=USERPOOLID, Username=username) + client.admin_confirm_sign_up(UserPoolId=USERPOOLID, Username=username) if scope: try: - resp = client.create_group(GroupName=scope, UserPoolId=USERPOOLID) - except ClientError as e: # pragma: no cover + client.create_group(GroupName=scope, UserPoolId=USERPOOLID) + except ClientError: # pragma: no cover pass # pragma: no cover - resp = client.admin_add_user_to_group( + client.admin_add_user_to_group( UserPoolId=USERPOOLID, Username=username, GroupName=scope, ) @@ -77,8 +78,8 @@ def delete_cognito_user( client, username=f"test_user{info.major}{info.minor}@example.com", ): try: - response = client.admin_delete_user(UserPoolId=USERPOOLID, Username=username) - except: # pragma: no cover + client.admin_delete_user(UserPoolId=USERPOOLID, Username=username) + except Exception: # pragma: no cover pass # pragma: no cover @@ -138,6 +139,35 @@ async def secure(payload=Depends(auth)) -> bool: async def secure_no_error(payload=Depends(auth_no_error)): assert payload is None + class AccessClaim(BaseModel): + sub: str = None + + @app.get("/access/user") + async def secure_access_user( + payload: AccessClaim = Depends(auth.claim(AccessClaim)), + ): + assert isinstance(payload, AccessClaim) + return payload + + @app.get("/access/user/no-error/") + async def secure_access_user_no_error( + payload: AccessClaim = Depends(auth_no_error.claim(AccessClaim)), + ) -> Optional[AccessClaim]: + return payload + + class InvalidAccessClaim(BaseModel): + fake_field: str + + @app.get("/access/user/invalid") + async def invalid_access_user(payload=Depends(auth.claim(InvalidAccessClaim)),): + return payload # pragma: no cover + + @app.get("/access/user/invalid/no-error/") + async def invalid_access_user_no_error( + payload=Depends(auth_no_error.claim(InvalidAccessClaim)), + ) -> Optional[InvalidAccessClaim]: + assert payload is None + @app.get("/scope/", dependencies=[Depends(auth.scope(self.scope))]) async def secure_scope() -> bool: pass @@ -190,4 +220,3 @@ def decode(self): # id token id_header, id_payload, *_ = decode_token(self.ID_TOKEN) assert id_payload.get("email") == self.user - diff --git a/tests/test_firebase.py b/tests/test_firebase.py index 1a50a18..100ac6a 100644 --- a/tests/test_firebase.py +++ b/tests/test_firebase.py @@ -1,22 +1,19 @@ -import os import base64 import json +import os import tempfile -from typing import Optional from sys import version_info as info +from typing import Optional +import firebase_admin +import requests from fastapi import Depends, FastAPI from fastapi.testclient import TestClient -from firebase_admin.auth import delete_user -import requests -import firebase_admin -from firebase_admin import auth -from firebase_admin import credentials +from firebase_admin import auth, credentials from fastapi_cloudauth import FirebaseCurrentUser from fastapi_cloudauth.firebase import FirebaseClaims - -from tests.helpers import assert_get_response, decode_token, BaseTestCloudAuth +from tests.helpers import BaseTestCloudAuth, decode_token API_KEY = os.getenv("FIREBASE_APIKEY") BASE64_CREDENTIAL = os.getenv("FIREBASE_BASE64_CREDENCIALS")