Skip to content

Commit

Permalink
use dependencies to get test and worker, raise exception instances
Browse files Browse the repository at this point in the history
  • Loading branch information
elfkuzco committed Jun 19, 2024
1 parent 68602d6 commit 98b39c3
Show file tree
Hide file tree
Showing 16 changed files with 172 additions and 154 deletions.
2 changes: 1 addition & 1 deletion backend/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@ dependencies = [
"beautifulsoup4==4.12.3",
"requests==2.32.3",
"pycountry==24.6.1",
"httpx==0.27.0",
"cryptography==42.0.8",
"PyJWT==2.8.0",
]
Expand Down Expand Up @@ -55,6 +54,7 @@ test = [
"coverage==7.4.1",
"Faker==25.8.0",
"paramiko==3.4.0",
"httpx==0.27.0",
]
dev = [
"pre-commit==3.6.0",
Expand Down
2 changes: 1 addition & 1 deletion backend/src/mirrors_qa_backend/cryptography.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def verify_signed_message(public_key: bytes, signature: bytes, message: bytes) -
try:
pem_public_key = serialization.load_pem_public_key(public_key)
except Exception as exc:
raise PEMPublicKeyLoadError from exc
raise PEMPublicKeyLoadError("Unable to load public key") from exc

try:
pem_public_key.verify( # pyright: ignore
Expand Down
14 changes: 1 addition & 13 deletions backend/src/mirrors_qa_backend/db/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,7 @@
from sqlalchemy.orm import sessionmaker

from mirrors_qa_backend import logger
from mirrors_qa_backend.db import (
mirrors,
models,
tests,
worker,
)
from mirrors_qa_backend.db import mirrors, models
from mirrors_qa_backend.extract import get_current_mirrors
from mirrors_qa_backend.settings import Settings

Expand Down Expand Up @@ -63,10 +58,3 @@ def initialize_mirrors() -> None:
f"Added {result.nb_mirrors_added} mirrors. "
f"Disabled {result.nb_mirrors_disabled} mirrors."
)


__all__ = [
"tests",
"worker",
"mirrors",
]
4 changes: 2 additions & 2 deletions backend/src/mirrors_qa_backend/db/exceptions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
class ModelDoesNotExistError(Exception):
"""A database model does not exist."""
class RecordDoesNotExistError(Exception):
"""A database record does not exist."""

def __init__(self, message: str, *args: object) -> None:
super().__init__(message, *args)
Expand Down
20 changes: 10 additions & 10 deletions backend/src/mirrors_qa_backend/db/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from sqlalchemy.orm import Session as OrmSession

from mirrors_qa_backend.db import models
from mirrors_qa_backend.db.exceptions import ModelDoesNotExistError
from mirrors_qa_backend.db.exceptions import RecordDoesNotExistError
from mirrors_qa_backend.enums import SortDirectionEnum, StatusEnum, TestSortColumnEnum
from mirrors_qa_backend.settings import Settings

Expand All @@ -25,7 +25,7 @@ def filter_test(
*,
worker_id: str | None = None,
country: str | None = None,
status: list[StatusEnum] | None = None,
statuses: list[StatusEnum] | None = None,
) -> bool:
"""Checks if a test has the same attribute as the provided attribute.
Expand All @@ -36,7 +36,7 @@ def filter_test(
return False
if country is not None and test.country != country:
return False
if status is not None and test.status not in status:
if statuses is not None and test.status not in statuses:
return False
return True

Expand All @@ -52,16 +52,16 @@ def list_tests(
*,
worker_id: str | None = None,
country: str | None = None,
status: list[StatusEnum] | None = None,
statuses: list[StatusEnum] | None = None,
page_num: int = 1,
page_size: int = Settings.MAX_PAGE_SIZE,
sort_column: TestSortColumnEnum = TestSortColumnEnum.requested_on,
sort_direction: SortDirectionEnum = SortDirectionEnum.asc,
) -> TestListResult:

# If no status is provided, populate status with all the allowed values
if status is None:
status = list(StatusEnum)
if statuses is None:
statuses = list(StatusEnum)

if sort_direction == SortDirectionEnum.asc:
direction = asc
Expand All @@ -86,9 +86,9 @@ def list_tests(
query = (
select(func.count().over().label("total_records"), models.Test)
.where(
(models.Test.worker_id == worker_id) | (worker_id == None), # noqa
(models.Test.country == country) | (country == None), # noqa
(models.Test.status.in_(status)),
(models.Test.worker_id == worker_id) | (worker_id is None),
(models.Test.country == country) | (country is None),
(models.Test.status.in_(statuses)),
)
.order_by(*order_by)
.offset((page_num - 1) * page_size)
Expand Down Expand Up @@ -127,7 +127,7 @@ def create_or_update_test(
else:
test = get_test(session, test_id)
if test is None:
raise ModelDoesNotExistError(f"Test with id: {test_id} does not exist.")
raise RecordDoesNotExistError(f"Test with id: {test_id} does not exist.")

# If a value is provided, it takes precedence over the default value of the model
test.worker_id = worker_id if worker_id else test.worker_id
Expand Down
10 changes: 2 additions & 8 deletions backend/src/mirrors_qa_backend/main.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
from contextlib import asynccontextmanager

from fastapi import Depends, FastAPI
from fastapi import FastAPI

from mirrors_qa_backend import db
from mirrors_qa_backend.routes import auth, tests
from mirrors_qa_backend.routes.dependencies import verify_authorization_header


@asynccontextmanager
Expand All @@ -15,12 +14,7 @@ async def lifespan(_: FastAPI):


def create_app(*, debug: bool = True):
app = FastAPI(
debug=debug,
docs_url="/",
lifespan=lifespan,
dependencies=[Depends(verify_authorization_header)],
)
app = FastAPI(debug=debug, docs_url="/", lifespan=lifespan)

app.include_router(router=tests.router)
app.include_router(router=auth.router)
Expand Down
3 changes: 0 additions & 3 deletions backend/src/mirrors_qa_backend/routes/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +0,0 @@
from mirrors_qa_backend.routes.dependencies import CurrentWorker, DbSession

__all__ = ["DbSession", "CurrentWorker"]
34 changes: 19 additions & 15 deletions backend/src/mirrors_qa_backend/routes/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,11 @@

from fastapi import APIRouter, Header

from mirrors_qa_backend import cryptography, db, logger, schemas
from mirrors_qa_backend.routes import DbSession, http_errors
from mirrors_qa_backend import cryptography, logger, schemas
from mirrors_qa_backend.db import worker
from mirrors_qa_backend.exceptions import PEMPublicKeyLoadError
from mirrors_qa_backend.routes import http_errors
from mirrors_qa_backend.routes.dependencies import DbSession
from mirrors_qa_backend.settings import Settings

router = APIRouter(prefix="/auth", tags=["auth"])
Expand All @@ -20,23 +23,23 @@ def authenticate_worker(
Header(description="message (format): worker_id:timestamp (UTC ISO)"),
],
x_sshauth_signature: Annotated[
str, Header(description="bas64 string of signature")
str, Header(description="signature, base64-encoded")
],
) -> schemas.Token:
"""Authenticate using signed message and generate tokens."""
try:
signature = base64.standard_b64decode(x_sshauth_signature)
except binascii.Error:
except binascii.Error as exc:
raise http_errors.BadRequestError(
"Invalid signature format (not base64)"
) from None
) from exc

try:
# decode message: worker_id:timestamp(UTC ISO)
worker_id, timestamp_str = x_sshauth_message.split(":", 1)
timestamp = datetime.datetime.fromisoformat(timestamp_str)
except ValueError:
raise http_errors.BadRequestError("Invalid message format.") from None
except ValueError as exc:
raise http_errors.BadRequestError("Invalid message format.") from exc

# verify timestamp is less than MESSAGE_VALIDITY
if (
Expand All @@ -48,20 +51,21 @@ def authenticate_worker(
)

# verify worker with worker_id exists in database
worker = db.worker.get_worker(session, worker_id)
if worker is None:
raise http_errors.UnauthorizedError
db_worker = worker.get_worker(session, worker_id)
if db_worker is None:
raise http_errors.UnauthorizedError()

# verify signature of message with worker's public keys
try:
cryptography.verify_signed_message(
bytes(worker.pubkey_pkcs8, encoding="ascii"),
if not cryptography.verify_signed_message(
bytes(db_worker.pubkey_pkcs8, encoding="ascii"),
signature,
bytes(x_sshauth_message, encoding="ascii"),
)
except Exception:
):
raise http_errors.UnauthorizedError()
except PEMPublicKeyLoadError as exc:
logger.exception("error while verifying message using public key")
raise http_errors.ServerError from None
raise http_errors.ForbiddenError("Unable to load public_key") from exc

# generate tokens
access_token = cryptography.generate_access_token(worker_id)
Expand Down
70 changes: 38 additions & 32 deletions backend/src/mirrors_qa_backend/routes/dependencies.py
Original file line number Diff line number Diff line change
@@ -1,57 +1,63 @@
from typing import Annotated

import jwt
from fastapi import Depends, Header
from fastapi import Depends, Header, Path
from jwt import exceptions as jwt_exceptions
from pydantic import UUID4
from pydantic import ValidationError as PydanticValidationError
from sqlalchemy.orm import Session

from mirrors_qa_backend import db, schemas
from mirrors_qa_backend.db import gen_dbsession, models
from mirrors_qa_backend import schemas
from mirrors_qa_backend.db import gen_dbsession, models, tests, worker
from mirrors_qa_backend.routes import http_errors
from mirrors_qa_backend.settings import Settings

DbSession = Annotated[Session, Depends(gen_dbsession)]


def verify_authorization_header(
authorization: Annotated[str | None, Header()] = None
) -> schemas.JWTClaims | None:
if authorization is None:
return None

def get_current_worker(
session: DbSession,
authorization: Annotated[str, Header()] = "",
) -> models.Worker:
header_parts = authorization.split(" ")
if len(header_parts) != 2 or header_parts[0] != "Bearer": # noqa
raise http_errors.UnauthorizedError
if len(header_parts) != 2 or header_parts[0] != "Bearer": # noqa: PLR2004
raise http_errors.UnauthorizedError()

token = header_parts[1]
try:
claims = jwt.decode(token, Settings.JWT_SECRET, algorithms=["HS256"])
except jwt_exceptions.ExpiredSignatureError:
raise http_errors.UnauthorizedError("Token has expired.") from None
except (jwt_exceptions.InvalidTokenError, jwt_exceptions.PyJWTError):
raise http_errors.UnauthorizedError from None
jwt_claims = jwt.decode(token, Settings.JWT_SECRET, algorithms=["HS256"])
except jwt_exceptions.ExpiredSignatureError as exc:
raise http_errors.UnauthorizedError("Token has expired.") from exc
except (jwt_exceptions.InvalidTokenError, jwt_exceptions.PyJWTError) as exc:
raise http_errors.UnauthorizedError from exc

try:
claims = schemas.JWTClaims(**claims)
except PydanticValidationError:
raise http_errors.UnauthorizedError from None
return claims


def get_current_worker(
session: DbSession,
claims: Annotated[schemas.JWTClaims | None, Depends(verify_authorization_header)],
) -> models.Worker:
if claims is None:
raise http_errors.UnauthorizedError
claims = schemas.JWTClaims(**jwt_claims)
except PydanticValidationError as exc:
raise http_errors.UnauthorizedError from exc

# At this point, we know that the JWT is all OK and we can
# trust the data in it. We extract the worker_id from the claims
worker = db.worker.get_worker(session, claims.subject)
if worker is None:
raise http_errors.UnauthorizedError
return worker
db_worker = worker.get_worker(session, claims.subject)
if db_worker is None:
raise http_errors.UnauthorizedError()
return db_worker


CurrentWorker = Annotated[models.Worker, Depends(get_current_worker)]


def get_test(session: DbSession, test_id: Annotated[UUID4, Path()]) -> models.Test:
"""Fetches the test specified in the request."""
test = tests.get_test(session, test_id)
if test is None:
raise http_errors.NotFoundError(f"Test with id {test_id} does not exist.")
return test


GetTest = Annotated[models.Test, Depends(get_test)]


def verify_worker_owns_test(worker: CurrentWorker, test: GetTest):
if test.worker_id != worker.id:
raise http_errors.UnauthorizedError("Insufficient privileges to update test.")
11 changes: 11 additions & 0 deletions backend/src/mirrors_qa_backend/routes/http_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@ def __init__(self, message: Any = None) -> None:
)


class ForbiddenError(HTTPException):
def __init__(self, message: Any = None) -> None:
if message is None:
message = "Identity unknown to server."
super().__init__(
status_code=status.HTTP_403_FORBIDDEN,
detail=message,
headers={"WWW-Authenticate": "Bearer"},
)


class NotFoundError(HTTPException):
def __init__(self, message: Any) -> None:
super().__init__(status_code=status.HTTP_404_NOT_FOUND, detail=message)
Expand Down
Loading

0 comments on commit 98b39c3

Please sign in to comment.