Skip to content

Commit

Permalink
Merge pull request #35 from claroty/feature/abuddi/upgrade-sqlalchemy…
Browse files Browse the repository at this point in the history
…-to-1.4.0-remove-asyncalchemy

Upgrade to sqlalchemy 1.4.0
  • Loading branch information
omerabuddi authored Aug 21, 2023
2 parents 8f9407d + a2e28fc commit 9d1f038
Show file tree
Hide file tree
Showing 8 changed files with 252 additions and 83 deletions.
2 changes: 2 additions & 0 deletions jwthenticator/consts.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,11 @@ def days_to_seconds(days: int) -> int:

# DB consts
DB_CONNECTOR = env.str("DB_CONNECTOR", "postgresql+pg8000")
ASYNC_DB_CONNECTOR = env.str("ASYNC_DB_CONNECTOR", "postgresql+asyncpg")
DB_USER = env.str("DB_USER", "postgres")
DB_PASS = env.str("DB_PASS", "")
DB_HOST = env.str("DB_HOST", "localhost")
DB_NAME = env.str("DB_NAME", "jwthenticator")

DB_URI = env.str("DB_URI", f"{DB_CONNECTOR}://{DB_USER}:{DB_PASS}@{DB_HOST}/{DB_NAME}")
ASYNC_DB_URI = env.str("ASYNC_URI", f"{ASYNC_DB_CONNECTOR}://{DB_USER}:{DB_PASS}@{DB_HOST}/{DB_NAME}")
29 changes: 16 additions & 13 deletions jwthenticator/keys.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,13 @@
from hashlib import sha512
from uuid import UUID

from asyncalchemy import create_session_factory
from sqlalchemy import select, func

from jwthenticator.utils import create_async_session_factory
from jwthenticator.schemas import KeyData
from jwthenticator.models import Base, KeyInfo
from jwthenticator.exceptions import InvalidKeyError
from jwthenticator.consts import KEY_EXPIRY, DB_URI
from jwthenticator.consts import KEY_EXPIRY, ASYNC_DB_URI


class KeyManager:
Expand All @@ -19,7 +20,7 @@ class KeyManager:
"""

def __init__(self) -> None:
self.session_factory = create_session_factory(DB_URI, Base)
self.async_session_factory = create_async_session_factory(ASYNC_DB_URI, Base)
self.key_schema = KeyData.Schema()


Expand All @@ -43,19 +44,19 @@ async def create_key(self, key: str, identifier: UUID, expires_at: Optional[date
key_hash=key_hash,
identifier=identifier
)
async with self.session_factory() as session:
await session.add(key_obj)
async with self.async_session_factory() as session:
async with session.begin():
session.add(key_obj)
return True


async def check_key_exists(self, key_hash: str) -> bool:
"""
Check if a key exists in DB.
"""
async with self.session_factory() as session:
if await session.query(KeyInfo).filter_by(key_hash=key_hash).count() == 1:
return True
return False
async with self.async_session_factory() as session:
query = select(func.count(KeyInfo.id)).where(KeyInfo.key_hash == key_hash)
return (await session.scalar(query)) == 1


async def update_key_expiry(self, key_hash: str, expires_at: datetime) -> bool:
Expand All @@ -64,8 +65,9 @@ async def update_key_expiry(self, key_hash: str, expires_at: datetime) -> bool:
"""
if not await self.check_key_exists(key_hash):
raise InvalidKeyError("Invalid key")
async with self.session_factory() as session:
key_info_obj = await session.query(KeyInfo).filter_by(key_hash=key_hash).first()
async with self.async_session_factory() as session:
query = select(KeyInfo).where(KeyInfo.key_hash == key_hash)
key_info_obj = await session.scalar(query)
key_info_obj.expires_at = expires_at
return True

Expand All @@ -76,7 +78,8 @@ async def get_key(self, key_hash: str) -> KeyData:
"""
if not await self.check_key_exists(key_hash):
raise InvalidKeyError("Invalid key")
async with self.session_factory() as session:
key_info_obj = await session.query(KeyInfo).filter_by(key_hash=key_hash).first()
async with self.async_session_factory() as session:
query = select(KeyInfo).where(KeyInfo.key_hash == key_hash)
key_info_obj = await session.scalar(query)
key_data_obj = self.key_schema.load((self.key_schema.dump(key_info_obj)))
return key_data_obj
2 changes: 2 additions & 0 deletions jwthenticator/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Union
from unittest.mock import patch

import nest_asyncio
from aiohttp import web, ClientSession
from aiohttp.client import ClientSession as ClientSessionType
from aiohttp.test_utils import AioHTTPTestCase, TestClient
Expand All @@ -21,6 +22,7 @@
SERVER_HOST = "127.0.0.1"
SERVER_URL = f"http://{SERVER_HOST}:{SERVER_PORT}"
CLIENT_PATCH_FILES = ["client.py"]
nest_asyncio.apply()


@authenticate(SERVER_URL)
Expand Down
2 changes: 2 additions & 0 deletions jwthenticator/tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from http import HTTPStatus
from unittest.mock import MagicMock

import nest_asyncio
from aiohttp.test_utils import AioHTTPTestCase
from aiohttp.web import Application
from jwt import PyJWKClient
Expand Down Expand Up @@ -36,6 +37,7 @@ async def get_application(self) -> Application:


def setup_class(self) -> None:
nest_asyncio.apply()
self.auth_request_schema = AuthRequest.Schema()
self.token_response_schema = TokenResponse.Schema()
self.refresh_request_schema = RefreshRequest.Schema()
Expand Down
28 changes: 15 additions & 13 deletions jwthenticator/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@
from uuid import UUID, uuid4

import jwt
from asyncalchemy import create_session_factory
from sqlalchemy import select, func

from jwthenticator.utils import create_async_session_factory
from jwthenticator.models import Base, RefreshTokenInfo
from jwthenticator.schemas import JWTPayloadData, RefreshTokenData
from jwthenticator.exceptions import InvalidTokenError, MissingJWTError
from jwthenticator.consts import JWT_ALGORITHM, REFRESH_TOKEN_EXPIRY, JWT_LEASE_TIME, JWT_AUDIENCE, DB_URI
from jwthenticator.consts import JWT_ALGORITHM, REFRESH_TOKEN_EXPIRY, JWT_LEASE_TIME, JWT_AUDIENCE, ASYNC_DB_URI

class TokenManager:
"""
Expand Down Expand Up @@ -39,8 +40,7 @@ def __init__(self, public_key: str, private_key: Optional[str] = None, algorithm
self.refresh_token_schema = RefreshTokenData.Schema()
self.jwt_payload_data_schema = JWTPayloadData.Schema()

self.session_factory = create_session_factory(DB_URI, Base)

self.async_session_factory = create_async_session_factory(ASYNC_DB_URI, Base)

async def create_access_token(self, identifier: UUID) -> str:
"""
Expand Down Expand Up @@ -87,13 +87,15 @@ async def create_refresh_token(self, key_id: int, expires_at: Optional[datetime]
raise Exception("Refresh token can't be created in the past")

refresh_token_str = sha512(uuid4().bytes).hexdigest()
async with self.session_factory() as session:
async with self.async_session_factory() as session:
refresh_token_info_obj = RefreshTokenInfo(
expires_at=expires_at,
token=refresh_token_str,
key_id=key_id
)
await session.add(refresh_token_info_obj)
session.add(refresh_token_info_obj)
await session.commit()
await session.refresh(refresh_token_info_obj)
await session.flush()
return refresh_token_str

Expand All @@ -102,10 +104,9 @@ async def check_refresh_token_exists(self, refresh_token: str) -> bool:
"""
Check if a refresh token exists in DB.
"""
async with self.session_factory() as session:
if await session.query(RefreshTokenInfo).filter_by(token=refresh_token).count() == 1:
return True
return False
async with self.async_session_factory() as session:
query = select(func.count(RefreshTokenInfo.id)).where(RefreshTokenInfo.token == refresh_token)
return (await session.scalar(query)) == 1


async def load_refresh_token(self, refresh_token: str) -> RefreshTokenData:
Expand All @@ -114,7 +115,8 @@ async def load_refresh_token(self, refresh_token: str) -> RefreshTokenData:
"""
if not await self.check_refresh_token_exists(refresh_token):
raise InvalidTokenError("Invalid refresh token")
async with self.session_factory() as session:
refresh_token_info_obj = await session.query(RefreshTokenInfo).filter_by(token=refresh_token).first()
refresh_token_data_obj = self.refresh_token_schema.load(self.refresh_token_schema.dump(refresh_token_info_obj))
async with self.async_session_factory() as session:
query = select(RefreshTokenInfo).where(RefreshTokenInfo.token == refresh_token)
refresh_token_info_obj = (await session.execute(query)).first()
refresh_token_data_obj = self.refresh_token_schema.load(self.refresh_token_schema.dump(refresh_token_info_obj[0]))
return refresh_token_data_obj
25 changes: 24 additions & 1 deletion jwthenticator/utils.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
from __future__ import absolute_import

import asyncio
from os.path import isfile
from typing import Tuple, Optional
from typing import Any, Dict, Tuple, Optional
from urllib.parse import urlparse

from jwt.utils import base64url_encode
from Cryptodome.PublicKey import RSA
from Cryptodome.Hash import SHA1
from sqlalchemy.orm import sessionmaker
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.ext.asyncio import create_async_engine, AsyncSession, AsyncEngine
from sqlalchemy.pool import NullPool

from jwthenticator.consts import RSA_KEY_STRENGTH, RSA_PUBLIC_KEY, RSA_PRIVATE_KEY, RSA_PUBLIC_KEY_PATH, RSA_PRIVATE_KEY_PATH

Expand Down Expand Up @@ -96,3 +101,21 @@ def fix_url_path(url: str) -> str:
the path will be removed by urljoin.
"""
return url if url.endswith("/") else url + "/"


async def create_base(engine: AsyncEngine, base: DeclarativeMeta) -> None:
async with engine.begin() as conn:
await conn.run_sync(base.metadata.create_all)


def create_async_session_factory(uri: str, base: Optional[DeclarativeMeta] = None, **engine_kwargs: Dict[Any, Any]) -> sessionmaker:
"""
:param uri: Database uniform resource identifier
:param base: Declarative SQLAlchemy class to base off table initialization
:param engine_kwargs: Arguments to pass to SQLAlchemy's engine initialization
:returns: :class:`.AsyncSession` factory
"""
engine = create_async_engine(uri, **engine_kwargs, poolclass=NullPool)
if base is not None:
asyncio.get_event_loop().run_until_complete(create_base(engine, base))
return sessionmaker(engine, expire_on_commit=False, class_=AsyncSession)
Loading

0 comments on commit 9d1f038

Please sign in to comment.