Skip to content

Commit

Permalink
refactor: improve type-strictness
Browse files Browse the repository at this point in the history
Done by enabling mypy-strict mode
  • Loading branch information
jorgenengelsen committed Sep 15, 2023
1 parent 4f208d8 commit 903e36e
Show file tree
Hide file tree
Showing 17 changed files with 85 additions and 65 deletions.
5 changes: 5 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,11 @@ repos:
- types-requests
- types-ujson
- types-toml
- types-click
- types-python-jose
- pymongo
- pydantic
- fastapi

# The path to the venv python interpreter differ between linux and windows. An if/else is used to find it on either.
- repo: local
Expand Down
8 changes: 6 additions & 2 deletions api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -65,11 +65,15 @@ requires = ["poetry-core>=1.0.0"]
build-backend = "poetry.core.masonry.api"

[tool.mypy]

plugins = ["pydantic.mypy"]

ignore_missing_imports = true
warn_return_any = true
warn_unused_configs = true
namespace_packages = true
explicit_package_bases = true
allow_subclassing_any = true

strict = true


[tool.ruff]
Expand Down
4 changes: 2 additions & 2 deletions api/src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,12 @@ def create_app() -> FastAPI:


@click.group()
def cli():
def cli() -> None:
pass


@cli.command()
def run():
def run() -> None:
import uvicorn

uvicorn.run(
Expand Down
2 changes: 1 addition & 1 deletion api/src/authentication/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


@cached(cache=TTLCache(maxsize=32, ttl=86400))
def fetch_openid_configuration() -> dict:
def fetch_openid_configuration() -> dict[str, str]:
try:
oid_conf_response = httpx.get(config.OAUTH_WELL_KNOWN)
oid_conf_response.raise_for_status()
Expand Down
5 changes: 2 additions & 3 deletions api/src/authentication/mock_token_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@
"""


def generate_mock_token(user: User = default_user):
def generate_mock_token(user: User = default_user) -> str:
"""
This function is for testing purposes only
Used for behave testing
Expand All @@ -64,5 +64,4 @@ def generate_mock_token(user: User = default_user):
"roles": user.roles,
"iss": "mock-auth-server",
}
token = jwt.encode(payload, mock_rsa_private_key, algorithm="RS256")
return token
return jwt.encode(payload, mock_rsa_private_key, algorithm="RS256")
13 changes: 8 additions & 5 deletions api/src/authentication/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from enum import IntEnum
from typing import Any

from pydantic import BaseModel, GetJsonSchemaHandler
from pydantic_core import core_schema
Expand All @@ -15,11 +16,11 @@ def check_privilege(self, required_level: "AccessLevel") -> bool:
return False

@classmethod
def __get_validators__(cls):
def __get_validators__(cls): # type:ignore
yield cls.validate

@classmethod
def validate(cls, v):
def validate(cls, v: str) -> "AccessLevel":
if isinstance(v, cls):
return v
try:
Expand All @@ -28,7 +29,9 @@ def validate(cls, v):
raise ValueError("invalid AccessLevel enum value ")

@classmethod
def __get_pydantic_json_schema__(cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler):
def __get_pydantic_json_schema__(
cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler
) -> dict[str, Any]:
"""
Add a custom field type to the class representing the Enum's field names
Ref: https://pydantic-docs.helpmanual.io/usage/schema/#modifying-schema-in-custom-fields
Expand All @@ -50,7 +53,7 @@ class User(BaseModel):
roles: list[str] = []
scope: AccessLevel = AccessLevel.WRITE

def __hash__(self):
def __hash__(self) -> int:
return hash(type(self.user_id))


Expand All @@ -70,7 +73,7 @@ class ACL(BaseModel):
users: dict[str, AccessLevel] = {}
others: AccessLevel = AccessLevel.READ

def dict(self, **kwargs):
def dict(self, **kwargs: Any) -> dict[str, str | dict[str, AccessLevel | str]]:
return {
"owner": self.owner,
"roles": {k: v.name for k, v in self.roles.items()},
Expand Down
2 changes: 1 addition & 1 deletion api/src/common/exception_handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def validation_exception_handler(request: Request, exc: RequestValidationError)
)


def http_exception_handler(request: Request, exc: HTTPStatusError):
def http_exception_handler(request: Request, exc: HTTPStatusError) -> JSONResponse:
logger.error(exc)
return JSONResponse(
ErrorResponse(
Expand Down
16 changes: 8 additions & 8 deletions api/src/common/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class ErrorResponse(BaseModel):
type: str = "ApplicationException"
message: str = "The requested operation failed"
debug: str = "An unknown and unhandled exception occurred in the API"
extra: dict | None = None
extra: dict[str, str] | None = None


class ApplicationException(Exception):
Expand All @@ -26,13 +26,13 @@ class ApplicationException(Exception):
type: str = "ApplicationException"
message: str = "The requested operation failed"
debug: str = "An unknown and unhandled exception occurred in the API"
extra: dict | None = None
extra: dict[str, str] | None = None

def __init__(
self,
message: str = "The requested operation failed",
debug: str = "An unknown and unhandled exception occurred in the API",
extra: dict | None = None,
extra: dict[str, str] | None = None,
status: int = 500,
severity: ExceptionSeverity = ExceptionSeverity.ERROR,
):
Expand All @@ -43,7 +43,7 @@ def __init__(
self.extra = extra
self.severity = severity

def dict(self):
def dict(self) -> dict[str, int | str | dict[str, str] | None]:
return {
"status": self.status,
"type": self.type,
Expand All @@ -58,7 +58,7 @@ def __init__(
self,
message: str = "You do not have the required permissions",
debug: str = "Action denied because of insufficient permissions",
extra: dict | None = None,
extra: dict[str, str] | None = None,
):
super().__init__(message, debug, extra, request_status.HTTP_403_FORBIDDEN, severity=ExceptionSeverity.WARNING)
self.type = self.__class__.__name__
Expand All @@ -69,7 +69,7 @@ def __init__(
self,
message: str = "The requested resource could not be found",
debug: str = "The requested resource could not be found",
extra: dict | None = None,
extra: dict[str, str] | None = None,
):
super().__init__(message, debug, extra, request_status.HTTP_404_NOT_FOUND)
self.type = self.__class__.__name__
Expand All @@ -80,7 +80,7 @@ def __init__(
self,
message: str = "Invalid data for the operation",
debug: str = "Unable to complete the requested operation with the given input values.",
extra: dict | None = None,
extra: dict[str, str] | None = None,
):
super().__init__(message, debug, extra, request_status.HTTP_400_BAD_REQUEST)
self.type = self.__class__.__name__
Expand All @@ -91,7 +91,7 @@ def __init__(
self,
message: str = "The received data is invalid",
debug: str = "Values are invalid for requested operation.",
extra: dict | None = None,
extra: dict[str, str] | None = None,
):
super().__init__(message, debug, extra, request_status.HTTP_422_UNPROCESSABLE_ENTITY)
self.type = self.__class__.__name__
Expand Down
17 changes: 9 additions & 8 deletions api/src/common/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from opencensus.trace.samplers import ProbabilitySampler
from opencensus.trace.tracer import Tracer
from starlette.datastructures import MutableHeaders
from starlette.types import ASGIApp, Message, Receive, Scope, Send

from common.logger import logger
from config import config
Expand All @@ -15,20 +16,20 @@
# Middleware inheriting from the "BaseHTTPMiddleware" class does not work with Starlettes BackgroundTasks.
# see: https://github.com/encode/starlette/issues/919
class LocalLoggerMiddleware:
def __init__(self, app):
def __init__(self, app: ASGIApp):
self.app = app

async def __call__(self, scope, receive, send):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
return await self.app(scope, receive, send)

start_time = time.time()
process_time = ""
path = scope["path"]
method = scope["method"]
response = {}
response: Message = {}

async def inner_send(message):
async def inner_send(message: Message) -> None:
nonlocal process_time
nonlocal response
if message["type"] == "http.response.start":
Expand All @@ -49,19 +50,19 @@ class OpenCensusRequestLoggingMiddleware:
exporter = AzureExporter(connection_string=config.APPINSIGHTS_CONSTRING) if config.APPINSIGHTS_CONSTRING else None
sampler = ProbabilitySampler(1.0)

def __init__(self, app):
def __init__(self, app: ASGIApp):
self.app = app

async def __call__(self, scope, receive, send):
async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
if scope["type"] != "http":
return await self.app(scope, receive, send)

tracer = Tracer(exporter=self.exporter, sampler=self.sampler)

path = scope["path"]
response = {}
response: Message = {}

async def inner_send(message):
async def inner_send(message: Message) -> None:
nonlocal response
if message["type"] == "http.response.start":
response = message
Expand Down
2 changes: 1 addition & 1 deletion api/src/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class Config(BaseSettings):
MICROSOFT_AUTH_PROVIDER: str = "login.microsoftonline.com"


config = Config() # type: ignore[call-arg]
config = Config()

if config.AUTH_ENABLED and not all((config.OAUTH_AUTH_ENDPOINT, config.OAUTH_TOKEN_ENDPOINT, config.OAUTH_WELL_KNOWN)):
raise ValueError("Authentication was enabled, but some auth configuration parameters are missing")
Expand Down
15 changes: 9 additions & 6 deletions api/src/data_providers/clients/client_interface.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from abc import abstractmethod
from typing import Generic, TypeVar
from typing import Any, Generic, TypeVar

# Type definition for Model
M = TypeVar("M")

# Type definition for Unique Id
K = TypeVar("K")

# Type definition for filter
FilterDict = dict[str, Any]


class ClientInterface(Generic[M, K]):
@abstractmethod
Expand All @@ -30,21 +33,21 @@ def update(self, id: K, instance: M) -> M:
pass

@abstractmethod
def insert_many(self, instances: list[M]):
def insert_many(self, instances: list[M]) -> None:
pass

@abstractmethod
def delete_many(self, filter: dict):
def delete_many(self, filter: FilterDict) -> None:
pass

@abstractmethod
def find(self, filter: dict) -> M:
def find(self, filter: FilterDict) -> M:
pass

@abstractmethod
def find_one(self, filter: dict) -> M | None:
def find_one(self, filter: FilterDict) -> M | None:
pass

@abstractmethod
def delete_collection(self):
def delete_collection(self) -> None:
pass
33 changes: 18 additions & 15 deletions api/src/data_providers/clients/mongodb/mongo_database_client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from typing import Any

from pymongo.cursor import Cursor
from pymongo.database import Database
from pymongo.errors import DuplicateKeyError
from pymongo.mongo_client import MongoClient
from pymongo.results import DeleteResult, InsertManyResult

from common.exceptions import NotFoundException, ValidationException
from config import config
from data_providers.clients.client_interface import ClientInterface

MONGO_CLIENT: MongoClient = MongoClient(
MONGO_CLIENT: MongoClient[dict[str, Any]] = MongoClient(
host=config.MONGODB_HOSTNAME,
port=config.MONGODB_PORT,
username=config.MONGODB_USERNAME,
Expand All @@ -21,58 +24,58 @@


class MongoDatabaseClient(ClientInterface[dict, str]):
def __init__(self, collection_name: str, database_name: str, client: MongoClient = MONGO_CLIENT):
database: Database = client[database_name]
def __init__(self, collection_name: str, database_name: str, client: MongoClient[dict[str, Any]] = MONGO_CLIENT):
database: Database[dict[str, Any]] = client[database_name]
self.database = database
self.collection_name = collection_name
self.collection = database[collection_name]

def wipe_db(self):
def wipe_db(self) -> None:
databases = self.database.client.list_database_names()
databases_to_delete = [
database_name for database_name in databases if database_name not in ("admin", "config", "local")
] # Don't delete the mongo admin or local database
for database_name in databases_to_delete:
self.database.client.drop_database(database_name)

def delete_collection(self):
def delete_collection(self) -> None:
self.collection.drop()

def create(self, document: dict) -> dict:
def create(self, document: dict[str, Any]) -> dict[str, Any]:
try:
result = self.collection.insert_one(document)
return self.get(str(result.inserted_id))
except DuplicateKeyError:
raise ValidationException(message=f"The document with id '{document['_id']}' already exists")

def list_collection(self) -> list[dict]:
def list_collection(self) -> list[dict[str, Any]]:
return list(self.collection.find())

def get(self, uid: str) -> dict:
def get(self, uid: str) -> dict[str, Any]:
document = self.collection.find_one(filter={"_id": uid})
if document is None:
raise NotFoundException
else:
return dict(document)

def update(self, uid: str, document: dict) -> dict:
def update(self, uid: str, document: dict[str, Any]) -> dict[str, Any]:
if self.collection.find_one(filter={"_id": uid}) is None:
raise NotFoundException(extra={"uid": uid})
self.collection.replace_one({"_id": uid}, document)
return self.get(uid)

def delete(self, uid: str) -> bool:
result = self.collection.delete_one(filter={"_id": uid})
return result.deleted_count > 0 # type: ignore
return result.deleted_count > 0

def find(self, filter: dict) -> Cursor:
def find(self, filter: dict[str, Any]) -> Cursor[dict[str, Any]]:
return self.collection.find(filter=filter)

def find_one(self, filter: dict) -> dict | None:
return self.collection.find_one(filter=filter) # type: ignore
def find_one(self, filter: dict[str, Any]) -> dict[str, Any] | None:
return self.collection.find_one(filter=filter)

def insert_many(self, items: list[dict]):
def insert_many(self, items: list[dict[str, Any]]) -> InsertManyResult:
return self.collection.insert_many(items)

def delete_many(self, filter: dict):
def delete_many(self, filter: dict[str, Any]) -> DeleteResult:
return self.collection.delete_many(filter)
Loading

0 comments on commit 903e36e

Please sign in to comment.