From 98b39c3d9493aa114ca009a5f4b0bd2ebb9e7a75 Mon Sep 17 00:00:00 2001 From: Uchechukwu Orji Date: Wed, 19 Jun 2024 11:44:52 +0100 Subject: [PATCH] use dependencies to get test and worker, raise exception instances --- backend/pyproject.toml | 2 +- .../src/mirrors_qa_backend/cryptography.py | 2 +- backend/src/mirrors_qa_backend/db/__init__.py | 14 +--- .../src/mirrors_qa_backend/db/exceptions.py | 4 +- backend/src/mirrors_qa_backend/db/tests.py | 20 +++--- backend/src/mirrors_qa_backend/main.py | 10 +-- .../src/mirrors_qa_backend/routes/__init__.py | 3 - backend/src/mirrors_qa_backend/routes/auth.py | 34 +++++---- .../mirrors_qa_backend/routes/dependencies.py | 70 ++++++++++--------- .../mirrors_qa_backend/routes/http_errors.py | 11 +++ .../src/mirrors_qa_backend/routes/tests.py | 37 +++++----- backend/src/mirrors_qa_backend/schemas.py | 2 +- backend/tests/conftest.py | 31 ++++++-- backend/tests/db/test_tests.py | 42 +++++------ backend/tests/routes/conftest.py | 15 ++++ backend/tests/routes/test_tests_endpoints.py | 29 ++------ 16 files changed, 172 insertions(+), 154 deletions(-) diff --git a/backend/pyproject.toml b/backend/pyproject.toml index ff89c62..7de43c6 100644 --- a/backend/pyproject.toml +++ b/backend/pyproject.toml @@ -20,7 +20,6 @@ dependencies = [ "beautifulsoup4==4.12.3", "requests==2.32.3", "pycountry==24.6.1", - "httpx==0.27.0", "cryptography==42.0.8", "PyJWT==2.8.0", ] @@ -55,6 +54,7 @@ test = [ "coverage==7.4.1", "Faker==25.8.0", "paramiko==3.4.0", + "httpx==0.27.0", ] dev = [ "pre-commit==3.6.0", diff --git a/backend/src/mirrors_qa_backend/cryptography.py b/backend/src/mirrors_qa_backend/cryptography.py index a4b6f42..d4286c4 100644 --- a/backend/src/mirrors_qa_backend/cryptography.py +++ b/backend/src/mirrors_qa_backend/cryptography.py @@ -15,7 +15,7 @@ def verify_signed_message(public_key: bytes, signature: bytes, message: bytes) - try: pem_public_key = serialization.load_pem_public_key(public_key) except Exception as exc: - raise PEMPublicKeyLoadError from exc + raise PEMPublicKeyLoadError("Unable to load public key") from exc try: pem_public_key.verify( # pyright: ignore diff --git a/backend/src/mirrors_qa_backend/db/__init__.py b/backend/src/mirrors_qa_backend/db/__init__.py index 718e2b3..4336f88 100644 --- a/backend/src/mirrors_qa_backend/db/__init__.py +++ b/backend/src/mirrors_qa_backend/db/__init__.py @@ -7,12 +7,7 @@ from sqlalchemy.orm import sessionmaker from mirrors_qa_backend import logger -from mirrors_qa_backend.db import ( - mirrors, - models, - tests, - worker, -) +from mirrors_qa_backend.db import mirrors, models from mirrors_qa_backend.extract import get_current_mirrors from mirrors_qa_backend.settings import Settings @@ -63,10 +58,3 @@ def initialize_mirrors() -> None: f"Added {result.nb_mirrors_added} mirrors. " f"Disabled {result.nb_mirrors_disabled} mirrors." ) - - -__all__ = [ - "tests", - "worker", - "mirrors", -] diff --git a/backend/src/mirrors_qa_backend/db/exceptions.py b/backend/src/mirrors_qa_backend/db/exceptions.py index 71ec2c2..ec251c9 100644 --- a/backend/src/mirrors_qa_backend/db/exceptions.py +++ b/backend/src/mirrors_qa_backend/db/exceptions.py @@ -1,5 +1,5 @@ -class ModelDoesNotExistError(Exception): - """A database model does not exist.""" +class RecordDoesNotExistError(Exception): + """A database record does not exist.""" def __init__(self, message: str, *args: object) -> None: super().__init__(message, *args) diff --git a/backend/src/mirrors_qa_backend/db/tests.py b/backend/src/mirrors_qa_backend/db/tests.py index 309e8af..4aefd1e 100644 --- a/backend/src/mirrors_qa_backend/db/tests.py +++ b/backend/src/mirrors_qa_backend/db/tests.py @@ -7,7 +7,7 @@ from sqlalchemy.orm import Session as OrmSession from mirrors_qa_backend.db import models -from mirrors_qa_backend.db.exceptions import ModelDoesNotExistError +from mirrors_qa_backend.db.exceptions import RecordDoesNotExistError from mirrors_qa_backend.enums import SortDirectionEnum, StatusEnum, TestSortColumnEnum from mirrors_qa_backend.settings import Settings @@ -25,7 +25,7 @@ def filter_test( *, worker_id: str | None = None, country: str | None = None, - status: list[StatusEnum] | None = None, + statuses: list[StatusEnum] | None = None, ) -> bool: """Checks if a test has the same attribute as the provided attribute. @@ -36,7 +36,7 @@ def filter_test( return False if country is not None and test.country != country: return False - if status is not None and test.status not in status: + if statuses is not None and test.status not in statuses: return False return True @@ -52,7 +52,7 @@ def list_tests( *, worker_id: str | None = None, country: str | None = None, - status: list[StatusEnum] | None = None, + statuses: list[StatusEnum] | None = None, page_num: int = 1, page_size: int = Settings.MAX_PAGE_SIZE, sort_column: TestSortColumnEnum = TestSortColumnEnum.requested_on, @@ -60,8 +60,8 @@ def list_tests( ) -> TestListResult: # If no status is provided, populate status with all the allowed values - if status is None: - status = list(StatusEnum) + if statuses is None: + statuses = list(StatusEnum) if sort_direction == SortDirectionEnum.asc: direction = asc @@ -86,9 +86,9 @@ def list_tests( query = ( select(func.count().over().label("total_records"), models.Test) .where( - (models.Test.worker_id == worker_id) | (worker_id == None), # noqa - (models.Test.country == country) | (country == None), # noqa - (models.Test.status.in_(status)), + (models.Test.worker_id == worker_id) | (worker_id is None), + (models.Test.country == country) | (country is None), + (models.Test.status.in_(statuses)), ) .order_by(*order_by) .offset((page_num - 1) * page_size) @@ -127,7 +127,7 @@ def create_or_update_test( else: test = get_test(session, test_id) if test is None: - raise ModelDoesNotExistError(f"Test with id: {test_id} does not exist.") + raise RecordDoesNotExistError(f"Test with id: {test_id} does not exist.") # If a value is provided, it takes precedence over the default value of the model test.worker_id = worker_id if worker_id else test.worker_id diff --git a/backend/src/mirrors_qa_backend/main.py b/backend/src/mirrors_qa_backend/main.py index 1b93e7e..df17118 100644 --- a/backend/src/mirrors_qa_backend/main.py +++ b/backend/src/mirrors_qa_backend/main.py @@ -1,10 +1,9 @@ from contextlib import asynccontextmanager -from fastapi import Depends, FastAPI +from fastapi import FastAPI from mirrors_qa_backend import db from mirrors_qa_backend.routes import auth, tests -from mirrors_qa_backend.routes.dependencies import verify_authorization_header @asynccontextmanager @@ -15,12 +14,7 @@ async def lifespan(_: FastAPI): def create_app(*, debug: bool = True): - app = FastAPI( - debug=debug, - docs_url="/", - lifespan=lifespan, - dependencies=[Depends(verify_authorization_header)], - ) + app = FastAPI(debug=debug, docs_url="/", lifespan=lifespan) app.include_router(router=tests.router) app.include_router(router=auth.router) diff --git a/backend/src/mirrors_qa_backend/routes/__init__.py b/backend/src/mirrors_qa_backend/routes/__init__.py index 89a5dc8..e69de29 100644 --- a/backend/src/mirrors_qa_backend/routes/__init__.py +++ b/backend/src/mirrors_qa_backend/routes/__init__.py @@ -1,3 +0,0 @@ -from mirrors_qa_backend.routes.dependencies import CurrentWorker, DbSession - -__all__ = ["DbSession", "CurrentWorker"] diff --git a/backend/src/mirrors_qa_backend/routes/auth.py b/backend/src/mirrors_qa_backend/routes/auth.py index 771e9e4..59ace0e 100644 --- a/backend/src/mirrors_qa_backend/routes/auth.py +++ b/backend/src/mirrors_qa_backend/routes/auth.py @@ -5,8 +5,11 @@ from fastapi import APIRouter, Header -from mirrors_qa_backend import cryptography, db, logger, schemas -from mirrors_qa_backend.routes import DbSession, http_errors +from mirrors_qa_backend import cryptography, logger, schemas +from mirrors_qa_backend.db import worker +from mirrors_qa_backend.exceptions import PEMPublicKeyLoadError +from mirrors_qa_backend.routes import http_errors +from mirrors_qa_backend.routes.dependencies import DbSession from mirrors_qa_backend.settings import Settings router = APIRouter(prefix="/auth", tags=["auth"]) @@ -20,23 +23,23 @@ def authenticate_worker( Header(description="message (format): worker_id:timestamp (UTC ISO)"), ], x_sshauth_signature: Annotated[ - str, Header(description="bas64 string of signature") + str, Header(description="signature, base64-encoded") ], ) -> schemas.Token: """Authenticate using signed message and generate tokens.""" try: signature = base64.standard_b64decode(x_sshauth_signature) - except binascii.Error: + except binascii.Error as exc: raise http_errors.BadRequestError( "Invalid signature format (not base64)" - ) from None + ) from exc try: # decode message: worker_id:timestamp(UTC ISO) worker_id, timestamp_str = x_sshauth_message.split(":", 1) timestamp = datetime.datetime.fromisoformat(timestamp_str) - except ValueError: - raise http_errors.BadRequestError("Invalid message format.") from None + except ValueError as exc: + raise http_errors.BadRequestError("Invalid message format.") from exc # verify timestamp is less than MESSAGE_VALIDITY if ( @@ -48,20 +51,21 @@ def authenticate_worker( ) # verify worker with worker_id exists in database - worker = db.worker.get_worker(session, worker_id) - if worker is None: - raise http_errors.UnauthorizedError + db_worker = worker.get_worker(session, worker_id) + if db_worker is None: + raise http_errors.UnauthorizedError() # verify signature of message with worker's public keys try: - cryptography.verify_signed_message( - bytes(worker.pubkey_pkcs8, encoding="ascii"), + if not cryptography.verify_signed_message( + bytes(db_worker.pubkey_pkcs8, encoding="ascii"), signature, bytes(x_sshauth_message, encoding="ascii"), - ) - except Exception: + ): + raise http_errors.UnauthorizedError() + except PEMPublicKeyLoadError as exc: logger.exception("error while verifying message using public key") - raise http_errors.ServerError from None + raise http_errors.ForbiddenError("Unable to load public_key") from exc # generate tokens access_token = cryptography.generate_access_token(worker_id) diff --git a/backend/src/mirrors_qa_backend/routes/dependencies.py b/backend/src/mirrors_qa_backend/routes/dependencies.py index e02a6e9..f17c01b 100644 --- a/backend/src/mirrors_qa_backend/routes/dependencies.py +++ b/backend/src/mirrors_qa_backend/routes/dependencies.py @@ -1,57 +1,63 @@ from typing import Annotated import jwt -from fastapi import Depends, Header +from fastapi import Depends, Header, Path from jwt import exceptions as jwt_exceptions +from pydantic import UUID4 from pydantic import ValidationError as PydanticValidationError from sqlalchemy.orm import Session -from mirrors_qa_backend import db, schemas -from mirrors_qa_backend.db import gen_dbsession, models +from mirrors_qa_backend import schemas +from mirrors_qa_backend.db import gen_dbsession, models, tests, worker from mirrors_qa_backend.routes import http_errors from mirrors_qa_backend.settings import Settings DbSession = Annotated[Session, Depends(gen_dbsession)] -def verify_authorization_header( - authorization: Annotated[str | None, Header()] = None -) -> schemas.JWTClaims | None: - if authorization is None: - return None - +def get_current_worker( + session: DbSession, + authorization: Annotated[str, Header()] = "", +) -> models.Worker: header_parts = authorization.split(" ") - if len(header_parts) != 2 or header_parts[0] != "Bearer": # noqa - raise http_errors.UnauthorizedError + if len(header_parts) != 2 or header_parts[0] != "Bearer": # noqa: PLR2004 + raise http_errors.UnauthorizedError() token = header_parts[1] try: - claims = jwt.decode(token, Settings.JWT_SECRET, algorithms=["HS256"]) - except jwt_exceptions.ExpiredSignatureError: - raise http_errors.UnauthorizedError("Token has expired.") from None - except (jwt_exceptions.InvalidTokenError, jwt_exceptions.PyJWTError): - raise http_errors.UnauthorizedError from None + jwt_claims = jwt.decode(token, Settings.JWT_SECRET, algorithms=["HS256"]) + except jwt_exceptions.ExpiredSignatureError as exc: + raise http_errors.UnauthorizedError("Token has expired.") from exc + except (jwt_exceptions.InvalidTokenError, jwt_exceptions.PyJWTError) as exc: + raise http_errors.UnauthorizedError from exc try: - claims = schemas.JWTClaims(**claims) - except PydanticValidationError: - raise http_errors.UnauthorizedError from None - return claims - - -def get_current_worker( - session: DbSession, - claims: Annotated[schemas.JWTClaims | None, Depends(verify_authorization_header)], -) -> models.Worker: - if claims is None: - raise http_errors.UnauthorizedError + claims = schemas.JWTClaims(**jwt_claims) + except PydanticValidationError as exc: + raise http_errors.UnauthorizedError from exc # At this point, we know that the JWT is all OK and we can # trust the data in it. We extract the worker_id from the claims - worker = db.worker.get_worker(session, claims.subject) - if worker is None: - raise http_errors.UnauthorizedError - return worker + db_worker = worker.get_worker(session, claims.subject) + if db_worker is None: + raise http_errors.UnauthorizedError() + return db_worker CurrentWorker = Annotated[models.Worker, Depends(get_current_worker)] + + +def get_test(session: DbSession, test_id: Annotated[UUID4, Path()]) -> models.Test: + """Fetches the test specified in the request.""" + test = tests.get_test(session, test_id) + if test is None: + raise http_errors.NotFoundError(f"Test with id {test_id} does not exist.") + return test + + +GetTest = Annotated[models.Test, Depends(get_test)] + + +def verify_worker_owns_test(worker: CurrentWorker, test: GetTest): + if test.worker_id != worker.id: + raise http_errors.UnauthorizedError("Insufficient privileges to update test.") diff --git a/backend/src/mirrors_qa_backend/routes/http_errors.py b/backend/src/mirrors_qa_backend/routes/http_errors.py index 800c3f1..e1967c8 100644 --- a/backend/src/mirrors_qa_backend/routes/http_errors.py +++ b/backend/src/mirrors_qa_backend/routes/http_errors.py @@ -21,6 +21,17 @@ def __init__(self, message: Any = None) -> None: ) +class ForbiddenError(HTTPException): + def __init__(self, message: Any = None) -> None: + if message is None: + message = "Identity unknown to server." + super().__init__( + status_code=status.HTTP_403_FORBIDDEN, + detail=message, + headers={"WWW-Authenticate": "Bearer"}, + ) + + class NotFoundError(HTTPException): def __init__(self, message: Any) -> None: super().__init__(status_code=status.HTTP_404_NOT_FOUND, detail=message) diff --git a/backend/src/mirrors_qa_backend/routes/tests.py b/backend/src/mirrors_qa_backend/routes/tests.py index 5ab1448..2dfb953 100644 --- a/backend/src/mirrors_qa_backend/routes/tests.py +++ b/backend/src/mirrors_qa_backend/routes/tests.py @@ -1,12 +1,17 @@ from typing import Annotated -from fastapi import APIRouter, Query +from fastapi import APIRouter, Depends, Query from fastapi import status as status_codes -from pydantic import UUID4 -from mirrors_qa_backend import db, schemas, serializer +from mirrors_qa_backend import schemas, serializer +from mirrors_qa_backend.db import tests from mirrors_qa_backend.enums import SortDirectionEnum, StatusEnum, TestSortColumnEnum -from mirrors_qa_backend.routes import CurrentWorker, DbSession, http_errors +from mirrors_qa_backend.routes.dependencies import ( + CurrentWorker, + DbSession, + GetTest, + verify_worker_owns_test, +) from mirrors_qa_backend.settings import Settings router = APIRouter(prefix="/tests", tags=["tests"]) @@ -31,11 +36,11 @@ def list_tests( sort_by: Annotated[TestSortColumnEnum, Query()] = TestSortColumnEnum.requested_on, order: Annotated[SortDirectionEnum, Query()] = SortDirectionEnum.asc, ) -> schemas.TestsList: - result = db.tests.list_tests( + result = tests.list_tests( session, worker_id=worker_id, country=country, - status=status, + statuses=status, page_size=page_size, page_num=page_num, sort_column=sort_by, @@ -59,10 +64,7 @@ def list_tests( }, }, ) -def get_test(test_id: UUID4, session: DbSession) -> schemas.Test: - test = db.tests.get_test(session, test_id) - if test is None: - raise http_errors.NotFoundError(f"Test with id '{test_id}' does not exist.") +def get_test(test: GetTest) -> schemas.Test: return serializer.serialize_test(test) @@ -72,26 +74,19 @@ def get_test(test_id: UUID4, session: DbSession) -> schemas.Test: responses={ status_codes.HTTP_200_OK: {"description": "Update the details of a test."}, }, + dependencies=[Depends(verify_worker_owns_test)], ) def update_test( session: DbSession, worker: CurrentWorker, - test_id: UUID4, + test: GetTest, update: schemas.UpdateTestModel, ) -> schemas.Test: data = update.model_dump(exclude_unset=True) body = schemas.UpdateTestModel().model_copy(update=data) - # Ensure that the worker is the one who the test belongs to - test = db.tests.get_test(session, test_id) - if test is None: - raise http_errors.NotFoundError(f"Test with id {test_id} does not exist.") - - if test.worker_id != worker.id: - raise http_errors.UnauthorizedError("Insufficient privileges to update test.") - - updated_test = db.tests.create_or_update_test( + updated_test = tests.create_or_update_test( session, - test_id=test_id, + test_id=test.id, worker_id=worker.id, status=body.status, error=body.error, diff --git a/backend/src/mirrors_qa_backend/schemas.py b/backend/src/mirrors_qa_backend/schemas.py index 5795b4a..16d61cc 100644 --- a/backend/src/mirrors_qa_backend/schemas.py +++ b/backend/src/mirrors_qa_backend/schemas.py @@ -74,7 +74,7 @@ def calculate_pagination_metadata( return Paginator( total_records=total_records, first_page=1, - page_size=page_size, + page_size=min(page_size, total_records), current_page=current_page, last_page=math.ceil(total_records / page_size), ) diff --git a/backend/tests/conftest.py b/backend/tests/conftest.py index 4ffb092..1c20b96 100644 --- a/backend/tests/conftest.py +++ b/backend/tests/conftest.py @@ -1,3 +1,5 @@ +import base64 +import datetime from collections.abc import Generator from typing import Any @@ -10,6 +12,7 @@ from faker.providers import DynamicProvider from sqlalchemy.orm import Session as OrmSession +from mirrors_qa_backend.cryptography import sign_message from mirrors_qa_backend.db import Session, models from mirrors_qa_backend.enums import StatusEnum @@ -26,8 +29,14 @@ def dbsession() -> Generator[OrmSession, None, None]: @pytest.fixture -def fake_data(faker: Faker) -> Faker: - """Adds additional providers to Faker.""" +def data_gen(faker: Faker) -> Faker: + """Adds additional providers to faker. + + Registers test_country and test_status as providers. + data_gen.test_status() returns a status. + data_gen.test_country() returns a country. + All other providers from Faker can be used accordingly. + """ test_status_provider = DynamicProvider( provider_name="test_status", elements=list(StatusEnum), @@ -49,7 +58,7 @@ def fake_data(faker: Faker) -> Faker: @pytest.fixture def tests( - dbsession: OrmSession, fake_data: Faker, worker: models.Worker, request: Any + dbsession: OrmSession, data_gen: Faker, worker: models.Worker, request: Any ) -> list[models.Test]: """Adds tests to the database using the num_test mark.""" mark = request.node.get_closest_marker("num_tests") @@ -60,8 +69,8 @@ def tests( tests = [ models.Test( - status=fake_data.test_status(), - country=fake_data.test_country(), + status=data_gen.test_status(), + country=data_gen.test_country(), ) for _ in range(num_tests) ] @@ -96,3 +105,15 @@ def worker(public_key: RSAPublicKey, dbsession: OrmSession) -> models.Worker: ) dbsession.add(worker) return worker + + +@pytest.fixture +def auth_message(worker: models.Worker) -> str: + return f"{worker.id}:{datetime.datetime.now(datetime.UTC).isoformat()}" + + +@pytest.fixture +def x_sshauth_signature(private_key: RSAPrivateKey, auth_message: str) -> str: + """Sign a message using RSA private key and encode it in base64""" + signature = sign_message(private_key, bytes(auth_message, encoding="ascii")) + return base64.b64encode(signature).decode() diff --git a/backend/tests/db/test_tests.py b/backend/tests/db/test_tests.py index 8200928..11b1dc5 100644 --- a/backend/tests/db/test_tests.py +++ b/backend/tests/db/test_tests.py @@ -5,21 +5,21 @@ from faker import Faker from sqlalchemy.orm import Session as OrmSession -from mirrors_qa_backend import db from mirrors_qa_backend.db import models +from mirrors_qa_backend.db import tests as db_tests from mirrors_qa_backend.enums import StatusEnum @pytest.mark.num_tests(1) def test_get_test(dbsession: OrmSession, tests: list[models.Test]): test = tests[0] - result = db.tests.get_test(dbsession, test.id) + result = db_tests.get_test(dbsession, test.id) assert result is not None assert result.id == test.id @pytest.mark.parametrize( - ["worker_id", "country", "status", "expect"], + ["worker_id", "country", "statuses", "expected"], [ (None, None, None, True), ("worker_id", None, None, False), @@ -32,19 +32,21 @@ def test_basic_filter( dbsession: OrmSession, worker_id: str | None, country: str | None, - status: list[StatusEnum] | None, - expect: bool, # noqa + statuses: list[StatusEnum] | None, + expected: bool, # noqa: FBT001 ): - test = db.tests.create_or_update_test(dbsession, status=StatusEnum.PENDING) + test = db_tests.create_or_update_test(dbsession, status=StatusEnum.PENDING) assert ( - db.tests.filter_test(test, worker_id=worker_id, country=country, status=status) - == expect + db_tests.filter_test( + test, worker_id=worker_id, country=country, statuses=statuses + ) + == expected ) @pytest.mark.num_tests @pytest.mark.parametrize( - ["worker_id", "country", "status"], + ["worker_id", "country", "statuses"], [ (None, None, None), (None, "Nigeria", None), @@ -57,39 +59,39 @@ def test_list_tests( tests: list[models.Test], worker_id: str | None, country: str | None, - status: list[StatusEnum] | None, + statuses: list[StatusEnum] | None, ): filtered_tests = [ test for test in tests - if db.tests.filter_test( - test, worker_id=worker_id, country=country, status=status + if db_tests.filter_test( + test, worker_id=worker_id, country=country, statuses=statuses ) ] - result = db.tests.list_tests( - dbsession, worker_id=worker_id, country=country, status=status + result = db_tests.list_tests( + dbsession, worker_id=worker_id, country=country, statuses=statuses ) assert len(filtered_tests) == result.nb_tests @pytest.mark.num_tests(1) -def test_update_test(dbsession: OrmSession, tests: list[models.Test], fake_data: Faker): +def test_update_test(dbsession: OrmSession, tests: list[models.Test], data_gen: Faker): test_id = tests[0].id download_size = 1_000_000 duration = 1_000 latency = 100 speed = download_size / duration update_values = { - "status": fake_data.test_status(), - "country": fake_data.test_country(), + "status": data_gen.test_status(), + "country": data_gen.test_country(), "download_size": download_size, "duration": duration, "speed": speed, - "ip_address": IPv4Address(fake_data.ipv4()), - "started_on": fake_data.date_time(datetime.UTC), + "ip_address": IPv4Address(data_gen.ipv4()), + "started_on": data_gen.date_time(datetime.UTC), "latency": latency, } - updated_test = db.tests.create_or_update_test(dbsession, test_id, **update_values) # type: ignore + updated_test = db_tests.create_or_update_test(dbsession, test_id, **update_values) # type: ignore for key, value in update_values.items(): if hasattr(updated_test, key): assert getattr(updated_test, key) == value diff --git a/backend/tests/routes/conftest.py b/backend/tests/routes/conftest.py index dacdc97..914665d 100644 --- a/backend/tests/routes/conftest.py +++ b/backend/tests/routes/conftest.py @@ -17,3 +17,18 @@ def test_dbsession() -> Generator[OrmSession, None, None]: app.dependency_overrides[gen_dbsession] = test_dbsession return TestClient(app=app) + + +@pytest.fixture +def access_token( + auth_message: str, x_sshauth_signature: str, client: TestClient +) -> str: + response = client.post( + "auth/authenticate", + headers={ + "Content-type": "application/json", + "X-SSHAuth-Message": auth_message, + "X-SSHAuth-Signature": x_sshauth_signature, + }, + ) + return response.json()["access_token"] diff --git a/backend/tests/routes/test_tests_endpoints.py b/backend/tests/routes/test_tests_endpoints.py index c119c04..2f5b0fc 100644 --- a/backend/tests/routes/test_tests_endpoints.py +++ b/backend/tests/routes/test_tests_endpoints.py @@ -1,13 +1,9 @@ -import base64 -import datetime import uuid import pytest -from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey from fastapi import status from fastapi.testclient import TestClient -from mirrors_qa_backend.cryptography import sign_message from mirrors_qa_backend.db import models @@ -48,30 +44,19 @@ def test_tests_list(client: TestClient, tests: list[models.Test]): (False, status.HTTP_401_UNAUTHORIZED), ], ) -def test_test_patch_success( - worker: models.Worker, - private_key: RSAPrivateKey, - client: TestClient, +def test_test_patch( tests: list[models.Test], - with_auth: bool, # noqa + client: TestClient, + access_token: str, + with_auth: bool, # noqa: FBT001 expected_status: int, ): test = tests[0] headers = {"Content-type": "application/json"} if with_auth: - message = f"{worker.id}:{datetime.datetime.now(datetime.UTC).isoformat()}" - signature = sign_message(private_key, bytes(message, encoding="ascii")) - x_sshauth_signature = base64.b64encode(signature).decode() - response = client.post( - "/auth/authenticate", - headers={ - "Content-type": "application/json", - "X-SSHAuth-Message": message, - "X-SSHAuth-Signature": x_sshauth_signature, - }, - ) - access_token = response.json()["access_token"] headers["Authorization"] = f"Bearer {access_token}" - response = client.patch(f"/tests/{test.id}", headers=headers, json={}) + response = client.patch( + f"/tests/{test.id}", headers=headers, json={"status": test.status.name} + ) assert response.status_code == expected_status