From 736999a92afae6d0da0495ece72a5e0360a1e17f Mon Sep 17 00:00:00 2001 From: "James K. Glasbrenner" Date: Tue, 17 Oct 2023 14:51:06 -0400 Subject: [PATCH] refactor: reorganize and cleanup the queue restapi endpoint Migrate towards a behavior-driven style of testing for the REST endpoints and away from heavy mocking. Add docstrings to the refactored queue service methods. --- src/dioptra/restapi/job/service.py | 22 +- src/dioptra/restapi/models.py | 9 +- src/dioptra/restapi/queue/controller.py | 197 +++------- src/dioptra/restapi/queue/dependencies.py | 16 +- src/dioptra/restapi/queue/errors.py | 18 +- src/dioptra/restapi/queue/interface.py | 65 --- src/dioptra/restapi/queue/model.py | 59 +-- src/dioptra/restapi/queue/schema.py | 112 ++---- src/dioptra/restapi/queue/service.py | 338 +++++++++++----- tests/unit/restapi/conftest.py | 22 +- tests/unit/restapi/queue/__init__.py | 16 - tests/unit/restapi/queue/conftest.py | 66 ---- tests/unit/restapi/queue/test_controller.py | 413 -------------------- tests/unit/restapi/queue/test_interface.py | 85 ---- tests/unit/restapi/queue/test_model.py | 64 --- tests/unit/restapi/queue/test_schema.py | 172 -------- tests/unit/restapi/queue/test_service.py | 269 ------------- tests/unit/restapi/test_queue.py | 381 ++++++++++++++++++ 18 files changed, 741 insertions(+), 1583 deletions(-) delete mode 100644 src/dioptra/restapi/queue/interface.py delete mode 100644 tests/unit/restapi/queue/__init__.py delete mode 100644 tests/unit/restapi/queue/conftest.py delete mode 100644 tests/unit/restapi/queue/test_controller.py delete mode 100644 tests/unit/restapi/queue/test_interface.py delete mode 100644 tests/unit/restapi/queue/test_model.py delete mode 100644 tests/unit/restapi/queue/test_schema.py delete mode 100644 tests/unit/restapi/queue/test_service.py create mode 100644 tests/unit/restapi/test_queue.py diff --git a/src/dioptra/restapi/job/service.py b/src/dioptra/restapi/job/service.py index 67b60a6f2..4c63215c7 100644 --- a/src/dioptra/restapi/job/service.py +++ b/src/dioptra/restapi/job/service.py @@ -20,7 +20,7 @@ import datetime import uuid from pathlib import Path -from typing import List, Optional +from typing import List, Optional, cast import structlog from injector import inject @@ -31,8 +31,7 @@ from dioptra.restapi.app import db from dioptra.restapi.experiment.errors import ExperimentDoesNotExistError from dioptra.restapi.experiment.service import ExperimentService -from dioptra.restapi.queue.errors import QueueDoesNotExistError -from dioptra.restapi.queue.service import QueueService +from dioptra.restapi.queue.service import QueueNameService from dioptra.restapi.shared.rq.service import RQService from dioptra.restapi.shared.s3.service import S3Service @@ -51,13 +50,13 @@ def __init__( rq_service: RQService, s3_service: S3Service, experiment_service: ExperimentService, - queue_service: QueueService, + queue_name_service: QueueNameService, ) -> None: self._job_form_schema = job_form_schema self._rq_service = rq_service self._s3_service = s3_service self._experiment_service = experiment_service - self._queue_service = queue_service + self._queue_name_service = queue_name_service @staticmethod def create(job_form_data: JobFormData, **kwargs) -> Job: @@ -101,13 +100,16 @@ def extract_data_from_form(self, job_form: JobForm, **kwargs) -> JobFormData: if experiment is None: raise ExperimentDoesNotExistError - queue: Optional[Queue] = self._queue_service.get_unlocked_by_name( - job_form_data["queue"], log=log + queue = cast( + Queue, + self._queue_name_service.get( + job_form_data["queue"], + unlocked_only=True, + error_if_not_found=True, + log=log, + ), ) - if queue is None: - raise QueueDoesNotExistError - job_form_data["experiment_id"] = experiment.experiment_id job_form_data["queue_id"] = queue.queue_id diff --git a/src/dioptra/restapi/models.py b/src/dioptra/restapi/models.py index 4c2ee788e..e5f9c8a2a 100644 --- a/src/dioptra/restapi/models.py +++ b/src/dioptra/restapi/models.py @@ -23,12 +23,7 @@ ExperimentRegistrationFormData, ) from .job.model import Job, JobForm, JobFormData -from .queue.model import ( - Queue, - QueueLock, - QueueRegistrationForm, - QueueRegistrationFormData, -) +from .queue.model import Queue, QueueLock from .task_plugin.model import ( TaskPlugin, TaskPluginUploadForm, @@ -44,8 +39,6 @@ "JobForm", "JobFormData", "Queue", - "QueueRegistrationForm", - "QueueRegistrationFormData", "QueueLock", "TaskPlugin", "TaskPluginUploadForm", diff --git a/src/dioptra/restapi/queue/controller.py b/src/dioptra/restapi/queue/controller.py index 13ac3c73f..be37d4767 100644 --- a/src/dioptra/restapi/queue/controller.py +++ b/src/dioptra/restapi/queue/controller.py @@ -18,23 +18,25 @@ from __future__ import annotations import uuid -from typing import List, Optional +from typing import Any, cast import structlog -from flask import jsonify, request -from flask.wrappers import Response +from flask import request from flask_accepts import accepts, responds from flask_restx import Namespace, Resource from injector import inject from structlog.stdlib import BoundLogger -from dioptra.restapi.utils import as_api_parser +from dioptra.restapi.utils import slugify -from .errors import QueueDoesNotExistError, QueueRegistrationError -from .interface import QueueUpdateInterface -from .model import Queue, QueueRegistrationForm, QueueRegistrationFormData -from .schema import QueueNameUpdateSchema, QueueRegistrationSchema, QueueSchema -from .service import QueueService +from .model import Queue +from .schema import ( + IdStatusResponseSchema, + NameStatusResponseSchema, + QueueEditableSchema, + QueueSchema, +) +from .service import QueueNameService, QueueService LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -54,39 +56,25 @@ def __init__(self, *args, queue_service: QueueService, **kwargs) -> None: super().__init__(*args, **kwargs) @responds(schema=QueueSchema(many=True), api=api) - def get(self) -> List[Queue]: - """Gets a list of all registered queues.""" + def get(self) -> list[Queue]: + """Gets a list of all active queues.""" log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), resource="queue", request_type="GET" ) # noqa: F841 log.info("Request received") return self._queue_service.get_all_unlocked(log=log) - @api.expect(as_api_parser(api, QueueRegistrationSchema)) - @accepts(QueueRegistrationSchema, api=api) + @accepts(schema=QueueEditableSchema, api=api) @responds(schema=QueueSchema, api=api) def post(self) -> Queue: - """Creates a new queue via a queue registration form.""" + """Registers a new queue.""" log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), resource="queue", request_type="POST" ) # noqa: F841 - queue_registration_form: QueueRegistrationForm = QueueRegistrationForm() - log.info("Request received") - - if not queue_registration_form.validate_on_submit(): - log.error("Form validation failed") - raise QueueRegistrationError - - log.info("Form validation successful") - queue_registration_form_data: QueueRegistrationFormData = ( - self._queue_service.extract_data_from_form( - queue_registration_form=queue_registration_form, log=log - ) - ) - return self._queue_service.create( - queue_registration_form_data=queue_registration_form_data, log=log - ) + parsed_obj = request.parsed_obj # type: ignore + name = slugify(str(parsed_obj["name"])) + return self._queue_service.create(name=name, log=log) @api.route("/") @@ -106,43 +94,29 @@ def get(self, queueId: int) -> Queue: request_id=str(uuid.uuid4()), resource="queueId", request_type="GET" ) # noqa: F841 log.info("Request received", queue_id=queueId) - queue: Optional[Queue] = self._queue_service.get_by_id(queueId, log=log) - - if queue is None: - log.error("Queue not found", queue_id=queueId) - raise QueueDoesNotExistError - - return queue + return cast( + Queue, self._queue_service.get(queueId, error_if_not_found=True, log=log) + ) - def delete(self, queueId: int) -> Response: + @responds(schema=IdStatusResponseSchema, api=api) + def delete(self, queueId: int) -> dict[str, Any]: """Deletes a queue by its unique identifier.""" log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), resource="queueId", request_type="DELETE" ) # noqa: F841 log.info("Request received", queue_id=queueId) - id: List[int] = self._queue_service.delete_queue(queueId, log=log) + return self._queue_service.delete(queueId, log=log) - return jsonify(dict(status="Success", id=id)) - - @accepts(schema=QueueNameUpdateSchema, api=api) + @accepts(schema=QueueEditableSchema, api=api) @responds(schema=QueueSchema, api=api) def put(self, queueId: int) -> Queue: """Modifies a queue by its unique identifier.""" log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), resource="queueId", request_type="PUT" ) # noqa: F841 - changes: QueueUpdateInterface = request.parsed_obj # type: ignore - queue: Optional[Queue] = self._queue_service.get_by_id(queueId, log=log) - - if queue is None: - log.error("Queue not found", queue_id=queueId) - raise QueueDoesNotExistError - - queue = self._queue_service.rename_queue( - queue=queue, new_name=changes["name"], log=log - ) - - return queue + parsed_obj = request.parsed_obj # type: ignore + new_name = slugify(str(parsed_obj["name"])) + return self._queue_service.rename(queueId, new_name=new_name, log=log) @api.route("//lock") @@ -155,37 +129,23 @@ def __init__(self, *args, queue_service: QueueService, **kwargs) -> None: self._queue_service = queue_service super().__init__(*args, **kwargs) - def delete(self, queueId: int) -> Response: + @responds(schema=IdStatusResponseSchema, api=api) + def delete(self, queueId: int) -> dict[str, Any]: """Removes the lock from the queue (id reference) if it exists.""" log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), resource="QueueIdLock", request_type="DELETE" ) # noqa: F841 log.info("Request received", queue_id=queueId) - queue: Optional[Queue] = self._queue_service.get_by_id(queueId, log=log) - - if queue is None: - log.error("Queue not found", queue_id=queueId) - raise QueueDoesNotExistError + return self._queue_service.unlock(queueId, log=log) - id: List[int] = self._queue_service.unlock_queue(queue, log=log) - - return jsonify(dict(status="Success", id=id)) - - def put(self, queueId: int) -> Queue: + @responds(schema=IdStatusResponseSchema, api=api) + def put(self, queueId: int) -> dict[str, Any]: """Locks the queue (id reference) if it is unlocked.""" log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), resource="QueueIdLock", request_type="PUT" ) # noqa: F841 log.info("Request received", queue_id=queueId) - queue: Optional[Queue] = self._queue_service.get_by_id(queueId, log=log) - - if queue is None: - log.error("Queue not found", queue_id=queueId) - raise QueueDoesNotExistError - - id: List[int] = self._queue_service.lock_queue(queue, log=log) - - return jsonify(dict(status="Success", id=id)) # type: ignore + return self._queue_service.lock(queueId, log=log) @api.route("/name/") @@ -194,8 +154,8 @@ class QueueNameResource(Resource): """Shows a single queue (name reference) and lets you modify and delete it.""" @inject - def __init__(self, *args, queue_service: QueueService, **kwargs) -> None: - self._queue_service = queue_service + def __init__(self, *args, queue_name_service: QueueNameService, **kwargs) -> None: + self._queue_name_service = queue_name_service super().__init__(*args, **kwargs) @responds(schema=QueueSchema, api=api) @@ -204,18 +164,16 @@ def get(self, queueName: str) -> Queue: log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), resource="queueName", request_type="GET" ) # noqa: F841 - log.info("Request received", queue_name=queueName) - queue: Optional[Queue] = self._queue_service.get_by_name( - queue_name=queueName, log=log + log.info("Request received", queue_name=slugify(queueName)) + return cast( + Queue, + self._queue_name_service.get( + slugify(queueName), error_if_not_found=True, log=log + ), ) - if queue is None: - log.error("Queue not found", queue_name=queueName) - raise QueueDoesNotExistError - - return queue - - def delete(self, queueName: str) -> Response: + @responds(schema=NameStatusResponseSchema, api=api) + def delete(self, queueName: str) -> dict[str, Any]: """Deletes a queue by its unique name.""" log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), @@ -223,39 +181,8 @@ def delete(self, queueName: str) -> Response: queue_name=queueName, request_type="DELETE", ) # noqa: F841 - log.info("Request received") - queue: Optional[Queue] = self._queue_service.get_by_name( - queue_name=queueName, log=log - ) - - if queue is None: - return jsonify(dict(status="Success", id=[])) - - id: List[int] = self._queue_service.delete_queue(queue_id=queue.queue_id) - - return jsonify(dict(status="Success", id=id)) - - @accepts(schema=QueueNameUpdateSchema, api=api) - @responds(schema=QueueSchema, api=api) - def put(self, queueName: str) -> Queue: - """Modifies a queue by its unique name.""" - log: BoundLogger = LOGGER.new( - request_id=str(uuid.uuid4()), resource="queueName", request_type="PUT" - ) # noqa: F841 - changes: QueueUpdateInterface = request.parsed_obj # type: ignore - queue: Optional[Queue] = self._queue_service.get_by_name( - queue_name=queueName, log=log - ) - - if queue is None: - log.error("Queue not found", queue_name=queueName) - raise QueueDoesNotExistError - - queue = self._queue_service.rename_queue( - queue=queue, new_name=changes["name"], log=log - ) - - return queue + log.info("Request received", queue_name=slugify(queueName)) + return self._queue_name_service.delete(slugify(queueName), log=log) @api.route("/name//lock") @@ -264,11 +191,12 @@ class QueueNameLockResource(Resource): """Lets you put a lock on a queue (name reference) and lets you delete it.""" @inject - def __init__(self, *args, queue_service: QueueService, **kwargs) -> None: - self._queue_service = queue_service + def __init__(self, *args, queue_name_service: QueueNameService, **kwargs) -> None: + self._queue_name_service = queue_name_service super().__init__(*args, **kwargs) - def delete(self, queueName: str) -> Response: + @responds(schema=NameStatusResponseSchema, api=api) + def delete(self, queueName: str) -> dict[str, Any]: """Removes the lock from the queue (name reference) if it exists.""" log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), @@ -276,30 +204,13 @@ def delete(self, queueName: str) -> Response: request_type="DELETE", ) # noqa: F841 log.info("Request received", queue_name=queueName) - queue: Optional[Queue] = self._queue_service.get_by_name(queueName, log=log) + return self._queue_name_service.unlock(queueName, log=log) - if queue is None: - log.error("Queue not found", queue_name=queueName) - raise QueueDoesNotExistError - - id: List[int] = self._queue_service.unlock_queue(queue, log=log) - name: List[str] = [queueName] if id else [] - - return jsonify(dict(status="Success", name=name)) - - def put(self, queueName: str) -> Queue: + @responds(schema=NameStatusResponseSchema, api=api) + def put(self, queueName: str) -> dict[str, Any]: """Locks the queue (name reference) if it is unlocked.""" log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), resource="QueueNameLock", request_type="PUT" ) # noqa: F841 log.info("Request received", queue_name=queueName) - queue: Optional[Queue] = self._queue_service.get_by_name(queueName, log=log) - - if queue is None: - log.error("Queue not found", queue_name=queueName) - raise QueueDoesNotExistError - - id: List[int] = self._queue_service.lock_queue(queue, log=log) - name: List[str] = [queueName] if id else [] - - return jsonify(dict(status="Success", name=name)) # type: ignore + return self._queue_name_service.lock(queueName, log=log) diff --git a/src/dioptra/restapi/queue/dependencies.py b/src/dioptra/restapi/queue/dependencies.py index 08bf07216..8de4637ed 100644 --- a/src/dioptra/restapi/queue/dependencies.py +++ b/src/dioptra/restapi/queue/dependencies.py @@ -17,19 +17,19 @@ """Binding configurations to shared services using dependency injection.""" from __future__ import annotations -from typing import Any, Callable, List +from typing import Any, Callable from injector import Binder, Module, provider -from .schema import QueueRegistrationFormSchema +from .service import QueueNameService -class QueueRegistrationFormSchemaModule(Module): +class QueueNameServiceModule(Module): @provider - def provide_queue_registration_form_schema_module( + def provide_queue_name_service_module( self, - ) -> QueueRegistrationFormSchema: - return QueueRegistrationFormSchema() + ) -> QueueNameService: + return QueueNameService() def bind_dependencies(binder: Binder) -> None: @@ -41,11 +41,11 @@ def bind_dependencies(binder: Binder) -> None: pass -def register_providers(modules: List[Callable[..., Any]]) -> None: +def register_providers(modules: list[Callable[..., Any]]) -> None: """Registers type providers within the main application. Args: modules: A list of callables used for configuring the dependency injection environment. """ - modules.append(QueueRegistrationFormSchemaModule) + modules.append(QueueNameServiceModule) diff --git a/src/dioptra/restapi/queue/errors.py b/src/dioptra/restapi/queue/errors.py index 7f1588033..b7fe2f04f 100644 --- a/src/dioptra/restapi/queue/errors.py +++ b/src/dioptra/restapi/queue/errors.py @@ -28,8 +28,8 @@ class QueueDoesNotExistError(Exception): """The requested queue does not exist.""" -class QueueRegistrationError(Exception): - """The queue registration form contains invalid parameters.""" +class QueueLockedError(Exception): + """The requested queue is locked.""" def register_error_handlers(api: Api) -> None: @@ -37,6 +37,10 @@ def register_error_handlers(api: Api) -> None: def handle_queue_does_not_exist_error(error): return {"message": "Not Found - The requested queue does not exist"}, 404 + @api.errorhandler(QueueLockedError) + def handle_queue_locked_error(error): + return {"message": "Forbidden - The requested queue is locked."}, 403 + @api.errorhandler(QueueAlreadyExistsError) def handle_queue_already_exists_error(error): return ( @@ -46,13 +50,3 @@ def handle_queue_already_exists_error(error): }, 400, ) - - @api.errorhandler(QueueRegistrationError) - def handle_queue_registration_error(error): - return ( - { - "message": "Bad Request - The queue registration form contains " - "invalid parameters. Please verify and resubmit." - }, - 400, - ) diff --git a/src/dioptra/restapi/queue/interface.py b/src/dioptra/restapi/queue/interface.py deleted file mode 100644 index c22ff778c..000000000 --- a/src/dioptra/restapi/queue/interface.py +++ /dev/null @@ -1,65 +0,0 @@ -# This Software (Dioptra) is being made available as a public service by the -# National Institute of Standards and Technology (NIST), an Agency of the United -# States Department of Commerce. This software was developed in part by employees of -# NIST and in part by NIST contractors. Copyright in portions of this software that -# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant -# to Title 17 United States Code Section 105, works of NIST employees are not -# subject to copyright protection in the United States. However, NIST may hold -# international copyright in software created by its employees and domestic -# copyright (or licensing rights) in portions of software that were assigned or -# licensed to NIST. To the extent that NIST holds copyright in this software, it is -# being made available under the Creative Commons Attribution 4.0 International -# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts -# of the software developed or licensed by NIST. -# -# ACCESS THE FULL CC BY 4.0 LICENSE HERE: -# https://creativecommons.org/licenses/by/4.0/legalcode -"""The interfaces for creating and updating |Queue| objects. - -.. |Queue| replace:: :py:class:`~.model.Queue` -""" -from __future__ import annotations - -import datetime - -from typing_extensions import TypedDict - - -class QueueInterface(TypedDict, total=False): - """The interface for constructing a new |Queue| object. - - Attributes: - queue_id: An integer identifying a registered queue. - created_on: The date and time the queue was created. - last_modified: The date and time the queue was last modified. - name: The name of the queue. - """ - - queue_id: int - created_on: datetime.datetime - last_modified: datetime.datetime - name: str - - -class QueueLockInterface(TypedDict, total=False): - """The interface for constructing a new |QueueLock| object. - - Attributes: - queue_id: An integer identifying a registered queue. - created_on: The date and time the queue lock was created. - """ - - queue_id: int - created_on: datetime.datetime - - -class QueueUpdateInterface(TypedDict, total=False): - """The interface for updating a |Queue| object. - - Attributes: - name: The name of the queue. - is_deleted: A boolean that indicates if the queue record is deleted. - """ - - name: str - is_deleted: bool diff --git a/src/dioptra/restapi/queue/model.py b/src/dioptra/restapi/queue/model.py index bd11a69a5..c55b520a1 100644 --- a/src/dioptra/restapi/queue/model.py +++ b/src/dioptra/restapi/queue/model.py @@ -18,17 +18,9 @@ from __future__ import annotations import datetime -from typing import Optional - -from flask_wtf import FlaskForm -from typing_extensions import TypedDict -from wtforms.fields import StringField -from wtforms.validators import InputRequired, ValidationError +from typing import Any from dioptra.restapi.app import db -from dioptra.restapi.utils import slugify - -from .interface import QueueUpdateInterface class QueueLock(db.Model): @@ -78,19 +70,18 @@ class Queue(db.Model): @classmethod def next_id(cls) -> int: """Generates the next id in the sequence.""" - queue: Optional[Queue] = cls.query.order_by(cls.queue_id.desc()).first() + queue: Queue | None = cls.query.order_by(cls.queue_id.desc()).first() if queue is None: return 1 return int(queue.queue_id) + 1 - def update(self, changes: QueueUpdateInterface): + def update(self, changes: dict[str, Any]) -> Queue: """Updates the record. Args: - changes: A :py:class:`~.interface.QueueUpdateInterface` dictionary - containing record updates. + changes: A dictionary containing record updates. """ self.last_modified = datetime.datetime.now() @@ -98,45 +89,3 @@ def update(self, changes: QueueUpdateInterface): setattr(self, key, val) return self - - -class QueueRegistrationForm(FlaskForm): - """The queue registration form. - - Attributes: - name: The name to register as a new queue. - """ - - name = StringField( - "Name of Queue", - validators=[InputRequired()], - description="The name to register as a new queue.", - ) - - def validate_name(self, field: StringField) -> None: - """Validates that the queue does not exist in the registry. - - Args: - field: The form field for `name`. - """ - - standardized_name: str = slugify(field.data) - - if ( - Queue.query.filter_by(name=standardized_name, is_deleted=False).first() - is not None - ): - raise ValidationError( - "Bad Request - A queue is already registered under the name " - f"{standardized_name}. Please select another and resubmit." - ) - - -class QueueRegistrationFormData(TypedDict, total=False): - """The data extracted from the queue registration form. - - Attributes: - name: The name of the queue. - """ - - name: str diff --git a/src/dioptra/restapi/queue/schema.py b/src/dioptra/restapi/queue/schema.py index b2a229668..7add9b6e1 100644 --- a/src/dioptra/restapi/queue/schema.py +++ b/src/dioptra/restapi/queue/schema.py @@ -17,52 +17,14 @@ """The schemas for serializing/deserializing the queue endpoint objects. .. |Queue| replace:: :py:class:`~.model.Queue` -.. |QueueLock| replace:: :py:class:`~.model.QueueLock` -.. |QueueRegistrationForm| replace:: :py:class:`~.model.QueueRegistrationForm` -.. |QueueRegistrationFormData| replace:: :py:class:`~.model.QueueRegistrationFormData` """ from __future__ import annotations -from typing import Any, Dict - -from marshmallow import Schema, fields, post_dump, pre_dump - -from dioptra.restapi.utils import ParametersSchema, slugify - -from .model import Queue, QueueLock, QueueRegistrationForm, QueueRegistrationFormData - - -class QueueLockSchema(Schema): - """The schema for the data stored in a |QueueLock| object. - - Attributes: - queueId: An integer identifying a registered queue. - createdOn: The date and time the queue lock was created. - """ - - __model__ = QueueLock - - queueId = fields.Integer( - attribute="queue_id", - metadata=dict(description="An integer identifying a registered queue."), - ) - createdOn = fields.DateTime( - attribute="created_on", - metadata=dict(description="The date and time the queue lock was created."), - ) +from marshmallow import Schema, fields class QueueSchema(Schema): - """The schema for the data stored in a |Queue| object. - - Attributes: - queueId: An integer identifying a registered queue. - createdOn: The date and time the queue was created. - lastModified: The date and time the queue was last modified. - name: The name of the queue. - """ - - __model__ = Queue + """The schema for the data stored in a |Queue| object.""" queueId = fields.Integer( attribute="queue_id", @@ -81,58 +43,42 @@ class QueueSchema(Schema): ) -class QueueNameUpdateSchema(Schema): - """The schema for the data used to update a |Queue| object. - - Attributes: - name: The new name for the queue. Must be unique. - """ - - __model__ = Queue +class QueueEditableSchema(Schema): + """The settable fields on a |Queue| object.""" name = fields.String( attribute="name", - metadata=dict(description="The new name for the queue. Must be unique."), + metadata=dict(description="The name of the queue. Must be unique."), ) -class QueueRegistrationFormSchema(Schema): - """The schema for the information stored in a queue registration form. - - Attributes: - name: The name of the queue. Must be unique. - """ +class IdStatusResponseSchema(Schema): + """A simple response for reporting a status for one or more objects.""" - __model__ = QueueRegistrationFormData - - name = fields.String( - attribute="name", - required=True, - metadata=dict(description="The name of the queue. Must be unique."), + status = fields.String( + attribute="status", + metadata=dict(description="The status of the request."), + ) + id = fields.List( + fields.Integer(), + attribute="id", + metadata=dict( + description="A list of integers identifying the affected object(s)." + ), ) - @pre_dump - def extract_data_from_form( - self, data: QueueRegistrationForm, many: bool, **kwargs - ) -> Dict[str, Any]: - """Extracts data from the |QueueRegistrationForm| for validation.""" - - return {"name": slugify(data.name.data)} - - @post_dump - def serialize_object( - self, data: Dict[str, Any], many: bool, **kwargs - ) -> QueueRegistrationFormData: - """Creates a |QueueRegistrationFormData| object from the validated data.""" - return self.__model__(**data) +class NameStatusResponseSchema(Schema): + """A simple response for reporting a status for one or more objects.""" -QueueRegistrationSchema: list[ParametersSchema] = [ - dict( - name="name", - type=str, - location="form", - required=True, - help="The name of the queue. Must be unique.", + status = fields.String( + attribute="status", + metadata=dict(description="The status of the request."), + ) + name = fields.List( + fields.String(), + attribute="name", + metadata=dict( + description="A list of names identifying the affected object(s)." + ), ) -] diff --git a/src/dioptra/restapi/queue/service.py b/src/dioptra/restapi/queue/service.py index 228c0e781..483c3e9b2 100644 --- a/src/dioptra/restapi/queue/service.py +++ b/src/dioptra/restapi/queue/service.py @@ -18,7 +18,7 @@ from __future__ import annotations import datetime -from typing import List, Optional +from typing import Any, cast import structlog from injector import inject @@ -26,197 +26,321 @@ from dioptra.restapi.app import db -from .errors import QueueAlreadyExistsError -from .model import Queue, QueueLock, QueueRegistrationForm, QueueRegistrationFormData -from .schema import QueueRegistrationFormSchema +from .errors import QueueAlreadyExistsError, QueueDoesNotExistError, QueueLockedError +from .model import Queue, QueueLock LOGGER: BoundLogger = structlog.stdlib.get_logger() class QueueService(object): + """The service methods for registering and managing queues by their unique id.""" + @inject def __init__( self, - queue_registration_form_schema: QueueRegistrationFormSchema, + name_service: QueueNameService, ) -> None: - self._queue_registration_form_schema = queue_registration_form_schema + """Initialize the queue service. + + All arguments are provided via dependency injection. + + Args: + name_service: The queue name service. + """ + self._name_service = name_service def create( self, - queue_registration_form_data: QueueRegistrationFormData, + name: str, **kwargs, ) -> Queue: + """Create a new queue. + + Args: + name: The name of the queue. + + Returns: + The newly created queue object. + + Raises: + QueueAlreadyExistsError: If a queue with the given name already exists. + """ log: BoundLogger = kwargs.get("log", LOGGER.new()) - queue_name: str = queue_registration_form_data["name"] - if self.get_by_name(queue_name, log=log) is not None: + if self._name_service.get(name, log=log) is not None: raise QueueAlreadyExistsError timestamp = datetime.datetime.now() new_queue: Queue = Queue( queue_id=Queue.next_id(), - name=queue_name, + name=name, created_on=timestamp, last_modified=timestamp, ) db.session.add(new_queue) db.session.commit() - log.info( "Queue registration successful", queue_id=new_queue.queue_id, name=new_queue.name, ) - return new_queue - @staticmethod - def lock_queue(queue: Queue, **kwargs) -> List[int]: + def get( + self, + queue_id: int, + unlocked_only: bool = False, + error_if_not_found: bool = False, + **kwargs, + ) -> Queue | None: + """Fetch a queue by its unique id. + + Args: + queue_id: The unique id of the queue. + unlocked_only: If True, raise an error if the queue is locked. Defaults to + False. + error_if_not_found: If True, raise an error if the queue is not found. + Defaults to False. + + Returns: + The queue object if found, otherwise None. + + Raises: + QueueDoesNotExistError: If the queue is not found and `error_if_not_found` + is True. + QueueLockedError: If the queue is locked and `unlocked_only` is True. + """ log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.info("Get queue by id", queue_id=queue_id) + queue = Queue.query.filter_by(queue_id=queue_id, is_deleted=False).first() - if queue.lock: - return [] + if queue is None: + if error_if_not_found: + log.error("Queue not found", queue_id=queue_id) + raise QueueDoesNotExistError - queue.lock.append(QueueLock()) - db.session.commit() + return None - log.info("Queue locked", queue_id=queue.queue_id) + if queue.lock and unlocked_only: + log.error("Queue is locked", queue_id=queue_id) + raise QueueLockedError + + return cast(Queue, queue) - return [queue.queue_id] + def get_all(self, **kwargs) -> list[Queue]: + """Fetch the list of all queues. - @staticmethod - def unlock_queue(queue: Queue, **kwargs) -> List[int]: + Returns: + A list of queues. + """ log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.info("Get full list of queues") + return Queue.query.filter_by(is_deleted=False).all() # type: ignore - if not queue.lock: - return [] + def get_all_unlocked(self, **kwargs) -> list[Queue]: + """Fetch the list of all unlocked queues. - db.session.delete(queue.lock[0]) + Returns: + A list of queues. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.info("Get full list of unlocked queues") + return ( # type: ignore + Queue.query.outerjoin(QueueLock, Queue.queue_id == QueueLock.queue_id) + .filter( + QueueLock.queue_id == None, # noqa: E711 + Queue.is_deleted == False, # noqa: E712 + ) + .all() + ) + + def rename(self, queue_id: int, new_name: str, **kwargs) -> Queue: + """Rename a queue. + + Args: + queue_id: The unique id of the queue. + new_name: The new name of the queue. + + Returns: + The updated queue object. + + Raises: + QueueDoesNotExistError: If the queue is not found. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + queue = cast(Queue, self.get(queue_id, error_if_not_found=True, log=log)) + queue.update(changes={"name": new_name}) db.session.commit() + log.info("Queue renamed", queue_id=queue.queue_id, new_name=new_name) + return queue - log.info("Queue unlocked", queue_id=queue.queue_id) + def delete(self, queue_id: int, **kwargs) -> dict[str, Any]: + """Delete a queue. - return [queue.queue_id] + Args: + queue_id: The unique id of the queue. - def delete_queue(self, queue_id: int, **kwargs) -> List[int]: + Returns: + A dictionary reporting the status of the request. + """ log: BoundLogger = kwargs.get("log", LOGGER.new()) - queue: Optional[Queue] = self.get_by_id(queue_id=queue_id) - if queue is None: - return [] + if (queue := self.get(queue_id, log=log)) is None: + return {"status": "Success", "id": []} queue.update(changes={"is_deleted": True}) db.session.commit() - log.info("Queue deleted", queue_id=queue_id) + return {"status": "Success", "id": [queue_id]} + + def lock(self, queue_id: int, **kwargs) -> dict[str, Any]: + """Lock a queue. - return [queue_id] + Args: + queue_id: The unique id of the queue. - def rename_queue(self, queue: Queue, new_name: str, **kwargs) -> Queue: + Returns: + A dictionary reporting the status of the request. + + Raises: + QueueDoesNotExistError: If the queue is not found. + """ log: BoundLogger = kwargs.get("log", LOGGER.new()) - queue.update(changes={"name": new_name}) + if (queue := self.get(queue_id, error_if_not_found=True, log=log)) is None: + return {"status": "Success", "id": []} + + queue.lock.append(QueueLock()) db.session.commit() + log.info("Queue locked", queue_id=queue.queue_id) + return {"status": "Success", "id": [queue.queue_id]} - log.info("Queue renamed", queue_id=queue.queue_id, new_name=new_name) + def unlock(self, queue_id: int, **kwargs) -> dict[str, Any]: + """Unlock a queue. - return queue + Args: + queue_id: The unique id of the queue. + + Returns: + A dictionary reporting the status of the request. - @staticmethod - def get_all(**kwargs) -> List[Queue]: + Raises: + QueueDoesNotExistError: If the queue is not found. + """ log: BoundLogger = kwargs.get("log", LOGGER.new()) - log.info("Get full list of queues") + if (queue := self.get(queue_id, error_if_not_found=True, log=log)) is None: + return {"status": "Success", "id": []} - return Queue.query.filter_by(is_deleted=False).all() # type: ignore + db.session.delete(queue.lock[0]) + db.session.commit() + log.info("Queue unlocked", queue_id=queue.queue_id) + return {"status": "Success", "id": [queue.queue_id]} - @staticmethod - def get_all_unlocked(**kwargs) -> List[Queue]: + +class QueueNameService(object): + """The service methods for managing queues by their name.""" + + def get( + self, + queue_name: str, + unlocked_only: bool = False, + error_if_not_found: bool = False, + **kwargs, + ) -> Queue | None: + """Fetch a queue by its name. + + Args: + queue_name: The name of the queue. + unlocked_only: If True, raise an error if the queue is locked. Defaults to + False. + error_if_not_found: If True, raise an error if the queue is not found. + Defaults to False. + + Returns: + The queue object if found, otherwise None. + + Raises: + QueueDoesNotExistError: If the queue is not found and `error_if_not_found` + is True. + QueueLockedError: If the queue is locked and `unlocked_only` is True. + """ log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.info("Get queue by name", queue_name=queue_name) + queue = Queue.query.filter_by(name=queue_name, is_deleted=False).first() - log.info("Get full list of unlocked queues") + if queue is None: + if error_if_not_found: + log.error("Queue not found", name=queue_name) + raise QueueDoesNotExistError - return ( # type: ignore - Queue.query.outerjoin(QueueLock, Queue.queue_id == QueueLock.queue_id) - .filter( - QueueLock.queue_id == None, # noqa: E711 - Queue.is_deleted == False, # noqa: E712 - ) - .all() - ) + return None - @staticmethod - def get_all_locked(**kwargs) -> List[Queue]: - log: BoundLogger = kwargs.get("log", LOGGER.new()) + if queue.lock and unlocked_only: + log.error("Queue is locked", name=queue_name) + raise QueueLockedError - log.info("Get full list of locked queues") + return cast(Queue, queue) - return ( # type: ignore - Queue.query.join(QueueLock) - .filter(Queue.is_deleted == False) # noqa: E712 - .all() - ) + def delete(self, name: str, **kwargs) -> dict[str, Any]: + """Delete a queue. + + Args: + name: The name of the queue. - @staticmethod - def get_by_id(queue_id: int, **kwargs) -> Optional[Queue]: + Returns: + A dictionary reporting the status of the request. + """ log: BoundLogger = kwargs.get("log", LOGGER.new()) - log.info("Get queue by id", queue_id=queue_id) + if (queue := self.get(name, log=log)) is None: + return {"status": "Success", "name": []} - return Queue.query.filter_by( # type: ignore - queue_id=queue_id, is_deleted=False - ).first() + queue.update(changes={"is_deleted": True}) + db.session.commit() + log.info("Queue deleted", name=name) + return {"status": "Success", "name": [name]} - @staticmethod - def get_by_name(queue_name: str, **kwargs) -> Optional[Queue]: - log: BoundLogger = kwargs.get("log", LOGGER.new()) + def lock(self, name: str, **kwargs) -> dict[str, Any]: + """Lock a queue. - log.info("Get queue by name", queue_name=queue_name) + Args: + name: The name of the queue. - return Queue.query.filter_by( # type: ignore - name=queue_name, is_deleted=False - ).first() + Returns: + A dictionary reporting the status of the request. - @staticmethod - def get_unlocked_by_id(queue_id: int, **kwargs) -> Optional[Queue]: + Raises: + QueueDoesNotExistError: If the queue is not found. + """ log: BoundLogger = kwargs.get("log", LOGGER.new()) - log.info("Get unlocked queue by id", queue_id=queue_id) + if (queue := self.get(name, error_if_not_found=True, log=log)) is None: + return {"status": "Success", "name": []} - return ( # type: ignore - Queue.query.outerjoin(QueueLock, Queue.queue_id == QueueLock.queue_id) - .filter( - Queue.queue_id == queue_id, - QueueLock.queue_id == None, # noqa: E711 - Queue.is_deleted == False, # noqa: E712 - ) - .first() - ) + queue.lock.append(QueueLock()) + db.session.commit() + log.info("Queue locked", name=queue.name) + return {"status": "Success", "name": [queue.name]} - @staticmethod - def get_unlocked_by_name(queue_name: str, **kwargs) -> Optional[Queue]: - log: BoundLogger = kwargs.get("log", LOGGER.new()) + def unlock(self, name: str, **kwargs) -> dict[str, Any]: + """Unlock a queue. - log.info("Get unlocked queue by name", queue_name=queue_name) + Args: + name: The name of the queue. - return ( # type: ignore - Queue.query.outerjoin(QueueLock, Queue.queue_id == QueueLock.queue_id) - .filter( - Queue.name == queue_name, - QueueLock.queue_id == None, # noqa: E711 - Queue.is_deleted == False, # noqa: E712 - ) - .first() - ) + Returns: + A dictionary reporting the status of the request. - def extract_data_from_form( - self, queue_registration_form: QueueRegistrationForm, **kwargs - ) -> QueueRegistrationFormData: + Raises: + QueueDoesNotExistError: If the queue is not found. + """ log: BoundLogger = kwargs.get("log", LOGGER.new()) - log.info("Extract data from queue registration form") - data: QueueRegistrationFormData = self._queue_registration_form_schema.dump( - queue_registration_form - ) + if (queue := self.get(name, error_if_not_found=True, log=log)) is None: + return {"status": "Success", "name": []} - return data + db.session.delete(queue.lock[0]) + db.session.commit() + log.info("Queue unlocked", name=name) + return {"status": "Success", "name": [queue.name]} diff --git a/tests/unit/restapi/conftest.py b/tests/unit/restapi/conftest.py index c7ddd45c4..26661d9e3 100644 --- a/tests/unit/restapi/conftest.py +++ b/tests/unit/restapi/conftest.py @@ -18,18 +18,19 @@ import io import tarfile -from typing import Any, BinaryIO, List +from typing import Any, BinaryIO, Iterable, List import pytest from _pytest.monkeypatch import MonkeyPatch from boto3.session import Session from botocore.client import BaseClient from flask import Flask +from flask.testing import FlaskClient from flask_restx import Api from flask_sqlalchemy import SQLAlchemy from injector import Binder, Injector from redis import Redis -from dioptra.restapi.utils import setup_injection + from dioptra.restapi.shared.request_scope import request @@ -96,7 +97,6 @@ def dependency_modules() -> List[Any]: RQServiceConfiguration, RQServiceModule, ) - from dioptra.restapi.queue.dependencies import QueueRegistrationFormSchemaModule from dioptra.restapi.task_plugin.dependencies import ( TaskPluginUploadFormSchemaModule, ) @@ -135,7 +135,6 @@ def configure(binder: Binder) -> None: ExperimentRegistrationFormSchemaModule(), JobFormSchemaModule(), PasswordServiceModule(), - QueueRegistrationFormSchemaModule(), RQServiceModule(), TaskPluginUploadFormSchemaModule(), UserRegistrationFormSchemaModule(), @@ -148,7 +147,9 @@ def dependency_injector(dependency_modules: List[Any]) -> Injector: @pytest.fixture -def app(dependency_modules: List[Any], monkeypatch: MonkeyPatch) -> Flask: +def app( + dependency_modules: List[Any], monkeypatch: MonkeyPatch +) -> Iterable[Flask]: import dioptra.restapi.routes from dioptra.restapi import create_app @@ -168,12 +169,14 @@ def register_test_routes(api: Api, app: Flask) -> None: attach_task_plugin(api, app) attach_user(api, app) - monkeypatch.setattr(dioptra.restapi.routes, "register_routes", register_test_routes) + monkeypatch.setattr( + dioptra.restapi.routes, "register_routes", register_test_routes + ) injector = Injector(dependency_modules) app = create_app(env="test", injector=injector) - return app + yield app @pytest.fixture @@ -203,3 +206,8 @@ def seed_database(db): ], ) db.session.commit() + + +@pytest.fixture +def client(app: Flask) -> FlaskClient: + return app.test_client() diff --git a/tests/unit/restapi/queue/__init__.py b/tests/unit/restapi/queue/__init__.py deleted file mode 100644 index ab0a41a34..000000000 --- a/tests/unit/restapi/queue/__init__.py +++ /dev/null @@ -1,16 +0,0 @@ -# This Software (Dioptra) is being made available as a public service by the -# National Institute of Standards and Technology (NIST), an Agency of the United -# States Department of Commerce. This software was developed in part by employees of -# NIST and in part by NIST contractors. Copyright in portions of this software that -# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant -# to Title 17 United States Code Section 105, works of NIST employees are not -# subject to copyright protection in the United States. However, NIST may hold -# international copyright in software created by its employees and domestic -# copyright (or licensing rights) in portions of software that were assigned or -# licensed to NIST. To the extent that NIST holds copyright in this software, it is -# being made available under the Creative Commons Attribution 4.0 International -# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts -# of the software developed or licensed by NIST. -# -# ACCESS THE FULL CC BY 4.0 LICENSE HERE: -# https://creativecommons.org/licenses/by/4.0/legalcode diff --git a/tests/unit/restapi/queue/conftest.py b/tests/unit/restapi/queue/conftest.py deleted file mode 100644 index c4465d16e..000000000 --- a/tests/unit/restapi/queue/conftest.py +++ /dev/null @@ -1,66 +0,0 @@ -# This Software (Dioptra) is being made available as a public service by the -# National Institute of Standards and Technology (NIST), an Agency of the United -# States Department of Commerce. This software was developed in part by employees of -# NIST and in part by NIST contractors. Copyright in portions of this software that -# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant -# to Title 17 United States Code Section 105, works of NIST employees are not -# subject to copyright protection in the United States. However, NIST may hold -# international copyright in software created by its employees and domestic -# copyright (or licensing rights) in portions of software that were assigned or -# licensed to NIST. To the extent that NIST holds copyright in this software, it is -# being made available under the Creative Commons Attribution 4.0 International -# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts -# of the software developed or licensed by NIST. -# -# ACCESS THE FULL CC BY 4.0 LICENSE HERE: -# https://creativecommons.org/licenses/by/4.0/legalcode -from __future__ import annotations - -import datetime - -import pytest - -from dioptra.restapi.models import Queue, QueueLock - - -@pytest.fixture -def default_queues(db): - tf_cpu_queue: Queue = Queue( - queue_id=1, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - last_modified=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - name="tensorflow_cpu", - ) - tf_gpu_queue: Queue = Queue( - queue_id=2, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - last_modified=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - name="tensorflow_gpu", - ) - tf_cpu_dev_queue: Queue = Queue( - queue_id=3, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - last_modified=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - name="tensorflow_cpu_dev", - ) - tf_gpu_dev_queue: Queue = Queue( - queue_id=4, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - last_modified=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - name="tensorflow_gpu_dev", - ) - db.session.add(tf_cpu_queue) - db.session.add(tf_gpu_queue) - db.session.add(tf_cpu_dev_queue) - db.session.add(tf_gpu_dev_queue) - db.session.commit() - - -@pytest.fixture -def default_queues_with_locks(db, default_queues): - queue_lock: QueueLock = QueueLock( - queue_id=4, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - ) - db.session.add(queue_lock) - db.session.commit() diff --git a/tests/unit/restapi/queue/test_controller.py b/tests/unit/restapi/queue/test_controller.py deleted file mode 100644 index bb70400e2..000000000 --- a/tests/unit/restapi/queue/test_controller.py +++ /dev/null @@ -1,413 +0,0 @@ -# This Software (Dioptra) is being made available as a public service by the -# National Institute of Standards and Technology (NIST), an Agency of the United -# States Department of Commerce. This software was developed in part by employees of -# NIST and in part by NIST contractors. Copyright in portions of this software that -# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant -# to Title 17 United States Code Section 105, works of NIST employees are not -# subject to copyright protection in the United States. However, NIST may hold -# international copyright in software created by its employees and domestic -# copyright (or licensing rights) in portions of software that were assigned or -# licensed to NIST. To the extent that NIST holds copyright in this software, it is -# being made available under the Creative Commons Attribution 4.0 International -# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts -# of the software developed or licensed by NIST. -# -# ACCESS THE FULL CC BY 4.0 LICENSE HERE: -# https://creativecommons.org/licenses/by/4.0/legalcode -from __future__ import annotations - -import datetime -from typing import Any, Dict, List - -import pytest -import structlog -from _pytest.monkeypatch import MonkeyPatch -from flask import Flask -from flask_sqlalchemy import SQLAlchemy -from freezegun import freeze_time -from structlog.stdlib import BoundLogger - -from dioptra.restapi.models import Queue -from dioptra.restapi.queue.routes import BASE_ROUTE as QUEUE_BASE_ROUTE -from dioptra.restapi.queue.service import QueueService - -LOGGER: BoundLogger = structlog.stdlib.get_logger() - - -@pytest.fixture -def queue_registration_request() -> Dict[str, Any]: - return {"name": "tensorflow_cpu"} - - -def test_queue_resource_get(app: Flask, monkeypatch: MonkeyPatch) -> None: - def mockgetallunlocked(self, *args, **kwargs) -> List[Queue]: - LOGGER.info("Mocking QueueService.get_all_unlocked()") - queue: Queue = Queue( - queue_id=1, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - last_modified=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - name="tensorflow_cpu", - ) - return [queue] - - monkeypatch.setattr(QueueService, "get_all_unlocked", mockgetallunlocked) - - with app.test_client() as client: - response: List[Dict[str, Any]] = client.get( - f"/api/{QUEUE_BASE_ROUTE}/" - ).get_json() - - expected: List[Dict[str, Any]] = [ - { - "queueId": 1, - "createdOn": "2020-08-17T18:46:28.717559", - "lastModified": "2020-08-17T18:46:28.717559", - "name": "tensorflow_cpu", - } - ] - - assert response == expected - - -@freeze_time("2020-08-17T18:46:28.717559") -def test_queue_resource_post( - app: Flask, - db: SQLAlchemy, - queue_registration_request: Dict[str, Any], - monkeypatch: MonkeyPatch, -) -> None: - def mockcreate(*args, **kwargs) -> Queue: - LOGGER.info("Mocking QueueService.create()") - timestamp = datetime.datetime.now() - return Queue( - queue_id=1, - created_on=timestamp, - last_modified=timestamp, - name="tensorflow_cpu", - ) - - monkeypatch.setattr(QueueService, "create", mockcreate) - - with app.test_client() as client: - response: Dict[str, Any] = client.post( - f"/api/{QUEUE_BASE_ROUTE}/", - content_type="multipart/form-data", - data=queue_registration_request, - follow_redirects=True, - ).get_json() - LOGGER.info("Response received", response=response) - - expected: Dict[str, Any] = { - "queueId": 1, - "createdOn": "2020-08-17T18:46:28.717559", - "lastModified": "2020-08-17T18:46:28.717559", - "name": "tensorflow_cpu", - } - - assert response == expected - - -def test_queue_id_resource_get(app: Flask, monkeypatch: MonkeyPatch) -> None: - def mockgetbyid(self, queue_id: str, *args, **kwargs) -> Queue: - LOGGER.info("Mocking QueueService.get_by_id()") - return Queue( - queue_id=queue_id, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - last_modified=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - name="tensorflow_cpu", - ) - - monkeypatch.setattr(QueueService, "get_by_id", mockgetbyid) - queue_id: int = 1 - - with app.test_client() as client: - response: Dict[str, Any] = client.get( - f"/api/{QUEUE_BASE_ROUTE}/{queue_id}" - ).get_json() - - expected: Dict[str, Any] = { - "queueId": 1, - "createdOn": "2020-08-17T18:46:28.717559", - "lastModified": "2020-08-17T18:46:28.717559", - "name": "tensorflow_cpu", - } - - assert response == expected - - -def test_queue_id_resource_put(app: Flask, monkeypatch: MonkeyPatch) -> None: - def mockgetbyid(self, queue_id: str, *args, **kwargs) -> Queue: - LOGGER.info("Mocking QueueService.get_by_id()") - return Queue( - queue_id=queue_id, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - last_modified=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - name="tensorflow_cpu", - ) - - def mockrenamequeue(self, queue: Queue, new_name: str, *args, **kwargs) -> Queue: - LOGGER.info("Mocking QueueService.rename_queue()", new_name=new_name) - queue.name = new_name - queue.last_modified = datetime.datetime(2020, 8, 17, 20, 0, 0, 0) - return queue - - monkeypatch.setattr(QueueService, "get_by_id", mockgetbyid) - monkeypatch.setattr(QueueService, "rename_queue", mockrenamequeue) - queue_id: int = 1 - payload: Dict[str, Any] = {"name": "tf_cpu"} - - with app.test_client() as client: - response: Dict[str, Any] = client.put( - f"/api/{QUEUE_BASE_ROUTE}/{queue_id}", - json=payload, - ).get_json() - - expected: Dict[str, Any] = { - "queueId": 1, - "createdOn": "2020-08-17T18:46:28.717559", - "lastModified": "2020-08-17T20:00:00", - "name": "tf_cpu", - } - - assert response == expected - - -def test_queue_id_resource_delete(app: Flask, monkeypatch: MonkeyPatch) -> None: - def mockdeletequeue(self, queue_id: int, *args, **kwargs) -> List[int]: - LOGGER.info("Mocking QueueService.delete_queue()") - return [queue_id] - - monkeypatch.setattr(QueueService, "delete_queue", mockdeletequeue) - queue_id: int = 1 - - with app.test_client() as client: - response: Dict[str, Any] = client.delete( - f"/api/{QUEUE_BASE_ROUTE}/{queue_id}" - ).get_json() - - expected: Dict[str, Any] = { - "status": "Success", - "id": [1], - } - - assert response == expected - - -def test_queue_id_lock_resource_delete(app: Flask, monkeypatch: MonkeyPatch) -> None: - def mockgetbyid(self, queue_id: str, *args, **kwargs) -> Queue: - LOGGER.info("Mocking QueueService.get_by_id()") - return Queue( - queue_id=queue_id, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - last_modified=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - name="tensorflow_cpu", - ) - - def mockunlockqueue(self, queue: Queue, *args, **kwargs) -> List[int]: - LOGGER.info("Mocking QueueService.unlock_queue()") - return [queue.queue_id] - - monkeypatch.setattr(QueueService, "get_by_id", mockgetbyid) - monkeypatch.setattr(QueueService, "unlock_queue", mockunlockqueue) - queue_id: int = 1 - - with app.test_client() as client: - response: Dict[str, Any] = client.delete( - f"/api/{QUEUE_BASE_ROUTE}/{queue_id}/lock" - ).get_json() - - expected: Dict[str, Any] = { - "status": "Success", - "id": [1], - } - - assert response == expected - - -def test_queue_id_lock_resource_put(app: Flask, monkeypatch: MonkeyPatch) -> None: - def mockgetbyid(self, queue_id: str, *args, **kwargs) -> Queue: - LOGGER.info("Mocking QueueService.get_by_id()") - return Queue( - queue_id=queue_id, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - last_modified=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - name="tensorflow_cpu", - ) - - def mocklockqueue(self, queue: Queue, *args, **kwargs) -> List[int]: - LOGGER.info("Mocking QueueService.lock_queue()") - return [queue.queue_id] - - monkeypatch.setattr(QueueService, "get_by_id", mockgetbyid) - monkeypatch.setattr(QueueService, "lock_queue", mocklockqueue) - queue_id: int = 1 - - with app.test_client() as client: - response: Dict[str, Any] = client.put( - f"/api/{QUEUE_BASE_ROUTE}/{queue_id}/lock" - ).get_json() - - expected: Dict[str, Any] = { - "status": "Success", - "id": [1], - } - - assert response == expected - - -def test_queue_name_resource_get(app: Flask, monkeypatch: MonkeyPatch) -> None: - def mockgetbyname(self, queue_name: str, *args, **kwargs) -> Queue: - LOGGER.info("Mocking QueueService.get_by_name()") - return Queue( - queue_id=1, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - last_modified=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - name=queue_name, - ) - - monkeypatch.setattr(QueueService, "get_by_name", mockgetbyname) - queue_name: str = "tensorflow_cpu" - - with app.test_client() as client: - response: Dict[str, Any] = client.get( - f"/api/{QUEUE_BASE_ROUTE}/name/{queue_name}" - ).get_json() - - expected: Dict[str, Any] = { - "queueId": 1, - "createdOn": "2020-08-17T18:46:28.717559", - "lastModified": "2020-08-17T18:46:28.717559", - "name": "tensorflow_cpu", - } - - assert response == expected - - -def test_queue_name_resource_put(app: Flask, monkeypatch: MonkeyPatch) -> None: - def mockgetbyname(self, queue_name: str, *args, **kwargs) -> Queue: - LOGGER.info("Mocking QueueService.get_by_name()") - return Queue( - queue_id=1, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - last_modified=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - name=queue_name, - ) - - def mockrenamequeue(self, queue: Queue, new_name: str, *args, **kwargs) -> Queue: - LOGGER.info("Mocking QueueService.rename_queue()", new_name=new_name) - queue.name = new_name - queue.last_modified = datetime.datetime(2020, 8, 17, 20, 0, 0, 0) - return queue - - monkeypatch.setattr(QueueService, "get_by_name", mockgetbyname) - monkeypatch.setattr(QueueService, "rename_queue", mockrenamequeue) - queue_name: str = "tensorflow_cpu" - payload: Dict[str, Any] = {"name": "tf_cpu"} - - with app.test_client() as client: - response: Dict[str, Any] = client.put( - f"/api/{QUEUE_BASE_ROUTE}/name/{queue_name}", - json=payload, - ).get_json() - - expected: Dict[str, Any] = { - "queueId": 1, - "createdOn": "2020-08-17T18:46:28.717559", - "lastModified": "2020-08-17T20:00:00", - "name": "tf_cpu", - } - - assert response == expected - - -def test_queue_name_resource_delete(app: Flask, monkeypatch: MonkeyPatch) -> None: - def mockgetbyname(self, queue_name: str, *args, **kwargs) -> Queue: - LOGGER.info("Mocking QueueService.get_by_name()") - return Queue( - queue_id=1, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - last_modified=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - name=queue_name, - ) - - def mockdeletequeue(self, queue_id: int, *args, **kwargs) -> List[int]: - LOGGER.info("Mocking QueueService.delete_queue()") - return [queue_id] - - monkeypatch.setattr(QueueService, "get_by_name", mockgetbyname) - monkeypatch.setattr(QueueService, "delete_queue", mockdeletequeue) - queue_name: str = "tensorflow_cpu" - - with app.test_client() as client: - response: Dict[str, Any] = client.delete( - f"/api/{QUEUE_BASE_ROUTE}/name/{queue_name}" - ).get_json() - - expected: Dict[str, Any] = { - "status": "Success", - "id": [1], - } - - assert response == expected - - -def test_queue_name_lock_resource_delete(app: Flask, monkeypatch: MonkeyPatch) -> None: - def mockgetbyname(self, queue_name: str, *args, **kwargs) -> Queue: - LOGGER.info("Mocking QueueService.get_by_name()") - return Queue( - queue_id=1, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - last_modified=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - name=queue_name, - ) - - def mockunlockqueue(self, queue: Queue, *args, **kwargs) -> List[int]: - LOGGER.info("Mocking QueueService.unlock_queue()") - return [queue.queue_id] - - monkeypatch.setattr(QueueService, "get_by_name", mockgetbyname) - monkeypatch.setattr(QueueService, "unlock_queue", mockunlockqueue) - queue_name: str = "tensorflow_cpu" - - with app.test_client() as client: - response: Dict[str, Any] = client.delete( - f"/api/{QUEUE_BASE_ROUTE}/name/{queue_name}/lock" - ).get_json() - - expected: Dict[str, Any] = { - "status": "Success", - "name": [queue_name], - } - - assert response == expected - - -def test_queue_name_lock_resource_put(app: Flask, monkeypatch: MonkeyPatch) -> None: - def mockgetbyname(self, queue_name: str, *args, **kwargs) -> Queue: - LOGGER.info("Mocking QueueService.get_by_name()") - return Queue( - queue_id=1, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - last_modified=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - name=queue_name, - ) - - def mocklockqueue(self, queue: Queue, *args, **kwargs) -> List[int]: - LOGGER.info("Mocking QueueService.lock_queue()") - return [queue.queue_id] - - monkeypatch.setattr(QueueService, "get_by_name", mockgetbyname) - monkeypatch.setattr(QueueService, "lock_queue", mocklockqueue) - queue_name: str = "tensorflow_cpu" - - with app.test_client() as client: - response: Dict[str, Any] = client.put( - f"/api/{QUEUE_BASE_ROUTE}/name/{queue_name}/lock" - ).get_json() - - expected: Dict[str, Any] = { - "status": "Success", - "name": [queue_name], - } - - assert response == expected diff --git a/tests/unit/restapi/queue/test_interface.py b/tests/unit/restapi/queue/test_interface.py deleted file mode 100644 index ffe9b8253..000000000 --- a/tests/unit/restapi/queue/test_interface.py +++ /dev/null @@ -1,85 +0,0 @@ -# This Software (Dioptra) is being made available as a public service by the -# National Institute of Standards and Technology (NIST), an Agency of the United -# States Department of Commerce. This software was developed in part by employees of -# NIST and in part by NIST contractors. Copyright in portions of this software that -# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant -# to Title 17 United States Code Section 105, works of NIST employees are not -# subject to copyright protection in the United States. However, NIST may hold -# international copyright in software created by its employees and domestic -# copyright (or licensing rights) in portions of software that were assigned or -# licensed to NIST. To the extent that NIST holds copyright in this software, it is -# being made available under the Creative Commons Attribution 4.0 International -# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts -# of the software developed or licensed by NIST. -# -# ACCESS THE FULL CC BY 4.0 LICENSE HERE: -# https://creativecommons.org/licenses/by/4.0/legalcode -from __future__ import annotations - -import datetime - -import pytest -import structlog -from structlog.stdlib import BoundLogger - -from dioptra.restapi.models import Queue, QueueLock -from dioptra.restapi.queue.interface import ( - QueueInterface, - QueueLockInterface, - QueueUpdateInterface, -) - -LOGGER: BoundLogger = structlog.stdlib.get_logger() - - -@pytest.fixture -def queue_interface() -> QueueInterface: - return QueueInterface( - queue_id=1, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - last_modified=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - name="tensorflow_cpu", - ) - - -@pytest.fixture -def queue_lock_interface() -> QueueLockInterface: - return QueueLockInterface( - queue_id=1, created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559) - ) - - -@pytest.fixture -def queue_update_interface() -> QueueUpdateInterface: - return QueueUpdateInterface(name="tensorflow_cpu_dev") - - -def test_QueueInterface_create(queue_interface: QueueInterface) -> None: - assert isinstance(queue_interface, dict) - - -def test_QueueLockInterface_create(queue_lock_interface: QueueLockInterface) -> None: - assert isinstance(queue_lock_interface, dict) - - -def test_QueueUpdateInterface_create( - queue_update_interface: QueueUpdateInterface, -) -> None: - assert isinstance(queue_update_interface, dict) - - -def test_QueueInterface_works(queue_interface: QueueInterface) -> None: - queue: Queue = Queue(**queue_interface) - assert isinstance(queue, Queue) - - -def test_QueueLockInterface_works(queue_lock_interface: QueueLockInterface) -> None: - queue_lock: QueueLock = QueueLock(**queue_lock_interface) - assert isinstance(queue_lock, QueueLock) - - -def test_QueueUpdateInterface_works( - queue_update_interface: QueueUpdateInterface, -) -> None: - queue: Queue = Queue(**queue_update_interface) - assert isinstance(queue, Queue) diff --git a/tests/unit/restapi/queue/test_model.py b/tests/unit/restapi/queue/test_model.py deleted file mode 100644 index f1c33172e..000000000 --- a/tests/unit/restapi/queue/test_model.py +++ /dev/null @@ -1,64 +0,0 @@ -# This Software (Dioptra) is being made available as a public service by the -# National Institute of Standards and Technology (NIST), an Agency of the United -# States Department of Commerce. This software was developed in part by employees of -# NIST and in part by NIST contractors. Copyright in portions of this software that -# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant -# to Title 17 United States Code Section 105, works of NIST employees are not -# subject to copyright protection in the United States. However, NIST may hold -# international copyright in software created by its employees and domestic -# copyright (or licensing rights) in portions of software that were assigned or -# licensed to NIST. To the extent that NIST holds copyright in this software, it is -# being made available under the Creative Commons Attribution 4.0 International -# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts -# of the software developed or licensed by NIST. -# -# ACCESS THE FULL CC BY 4.0 LICENSE HERE: -# https://creativecommons.org/licenses/by/4.0/legalcode -from __future__ import annotations - -import datetime - -import pytest -import structlog -from structlog.stdlib import BoundLogger - -from dioptra.restapi.models import Queue, QueueLock, QueueRegistrationFormData - -LOGGER: BoundLogger = structlog.stdlib.get_logger() - - -@pytest.fixture -def queue_lock() -> QueueLock: - return QueueLock( - queue_id=1, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - ) - - -@pytest.fixture -def queue() -> Queue: - return Queue( - queue_id=1, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - last_modified=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - name="tensorflow_cpu", - ) - - -@pytest.fixture -def queue_registration_form_data() -> QueueRegistrationFormData: - return QueueRegistrationFormData(name="tensorflow_cpu") - - -def test_QueueLock_create(queue_lock: QueueLock) -> None: - assert isinstance(queue_lock, QueueLock) - - -def test_Queue_create(queue: Queue) -> None: - assert isinstance(queue, Queue) - - -def test_QueueRegistrationFormData_create( - queue_registration_form_data: QueueRegistrationFormData, -) -> None: - assert isinstance(queue_registration_form_data, dict) diff --git a/tests/unit/restapi/queue/test_schema.py b/tests/unit/restapi/queue/test_schema.py deleted file mode 100644 index 457a9fd01..000000000 --- a/tests/unit/restapi/queue/test_schema.py +++ /dev/null @@ -1,172 +0,0 @@ -# This Software (Dioptra) is being made available as a public service by the -# National Institute of Standards and Technology (NIST), an Agency of the United -# States Department of Commerce. This software was developed in part by employees of -# NIST and in part by NIST contractors. Copyright in portions of this software that -# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant -# to Title 17 United States Code Section 105, works of NIST employees are not -# subject to copyright protection in the United States. However, NIST may hold -# international copyright in software created by its employees and domestic -# copyright (or licensing rights) in portions of software that were assigned or -# licensed to NIST. To the extent that NIST holds copyright in this software, it is -# being made available under the Creative Commons Attribution 4.0 International -# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts -# of the software developed or licensed by NIST. -# -# ACCESS THE FULL CC BY 4.0 LICENSE HERE: -# https://creativecommons.org/licenses/by/4.0/legalcode -from __future__ import annotations - -import datetime -from typing import Any, Dict - -import pytest -import structlog -from flask import Flask -from structlog.stdlib import BoundLogger - -from dioptra.restapi.models import Queue, QueueRegistrationForm -from dioptra.restapi.queue.interface import ( - QueueInterface, - QueueLockInterface, - QueueUpdateInterface, -) -from dioptra.restapi.queue.model import QueueLock -from dioptra.restapi.queue.schema import ( - QueueLockSchema, - QueueNameUpdateSchema, - QueueRegistrationFormSchema, - QueueSchema, -) - -LOGGER: BoundLogger = structlog.stdlib.get_logger() - - -@pytest.fixture -def queue_registration_form(app: Flask) -> QueueRegistrationForm: - with app.test_request_context(): - form = QueueRegistrationForm(data={"name": "tensorflow_cpu"}) - - return form - - -@pytest.fixture -def queue_lock_schema() -> QueueLockSchema: - return QueueLockSchema() - - -@pytest.fixture -def queue_schema() -> QueueSchema: - return QueueSchema() - - -@pytest.fixture -def queue_name_update_schema() -> QueueNameUpdateSchema: - return QueueNameUpdateSchema() - - -@pytest.fixture -def queue_registration_form_schema() -> QueueRegistrationFormSchema: - return QueueRegistrationFormSchema() - - -def test_QueueLockSchema_create(queue_lock_schema: QueueLockSchema) -> None: - assert isinstance(queue_lock_schema, QueueLockSchema) - - -def test_QueueNameUpdateSchema_create( - queue_name_update_schema: QueueNameUpdateSchema, -) -> None: - assert isinstance(queue_name_update_schema, QueueNameUpdateSchema) - - -def test_QueueSchema_create(queue_schema: QueueSchema) -> None: - assert isinstance(queue_schema, QueueSchema) - - -def test_QueueRegistrationFormSchema_create( - queue_registration_form_schema: QueueRegistrationFormSchema, -) -> None: - assert isinstance(queue_registration_form_schema, QueueRegistrationFormSchema) - - -def test_QueueSchema_load_works(queue_schema: QueueSchema) -> None: - queue: QueueInterface = queue_schema.load( - { - "queueId": 1, - "createdOn": "2020-08-17T18:46:28.717559", - "lastModified": "2020-08-17T18:46:28.717559", - "name": "tensorflow_cpu", - } - ) - - assert queue["queue_id"] == 1 - assert queue["created_on"] == datetime.datetime(2020, 8, 17, 18, 46, 28, 717559) - assert queue["last_modified"] == datetime.datetime(2020, 8, 17, 18, 46, 28, 717559) - assert queue["name"] == "tensorflow_cpu" - - -def test_QueueLockSchema_load_works(queue_lock_schema: QueueLockSchema) -> None: - queue_lock: QueueLockInterface = queue_lock_schema.load( - {"queueId": 1, "createdOn": "2020-08-17T18:46:28.717559"} - ) - - assert queue_lock["queue_id"] == 1 - assert queue_lock["created_on"] == datetime.datetime( - 2020, 8, 17, 18, 46, 28, 717559 - ) - - -def test_QueueNameUpdateSchema_load_works( - queue_name_update_schema: QueueNameUpdateSchema, -) -> None: - queue: QueueUpdateInterface = queue_name_update_schema.load( - {"name": "tensorflow_cpu"} - ) - - assert queue["name"] == "tensorflow_cpu" - - -def test_QueueSchema_dump_works(queue_schema: QueueSchema) -> None: - queue: Queue = Queue( - queue_id=1, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - last_modified=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - name="tensorflow_cpu", - ) - queue_serialized: Dict[str, Any] = queue_schema.dump(queue) - - assert queue_serialized["queueId"] == 1 - assert queue_serialized["createdOn"] == "2020-08-17T18:46:28.717559" - assert queue_serialized["lastModified"] == "2020-08-17T18:46:28.717559" - assert queue_serialized["name"] == "tensorflow_cpu" - - -def test_QueueLockSchema_dump_works(queue_lock_schema: QueueLockSchema) -> None: - queue_lock: QueueLock = QueueLock( - queue_id=1, - created_on=datetime.datetime(2020, 8, 17, 18, 46, 28, 717559), - ) - queue_lock_serialized: Dict[str, Any] = queue_lock_schema.dump(queue_lock) - - assert queue_lock_serialized["queueId"] == 1 - assert queue_lock_serialized["createdOn"] == "2020-08-17T18:46:28.717559" - - -def test_QueueNameUpdateSchema_dump_works( - queue_name_update_schema: QueueNameUpdateSchema, -) -> None: - queue: Queue = Queue(name="tensorflow_cpu") - queue_name_updated_serialized: Dict[str, Any] = queue_name_update_schema.dump(queue) - - assert queue_name_updated_serialized["name"] == "tensorflow_cpu" - - -def test_QueueRegistrationFormSchema_dump_works( - queue_registration_form: QueueRegistrationForm, - queue_registration_form_schema: QueueRegistrationFormSchema, -) -> None: - queue_serialized: Dict[str, Any] = queue_registration_form_schema.dump( - queue_registration_form - ) - - assert queue_serialized["name"] == "tensorflow_cpu" diff --git a/tests/unit/restapi/queue/test_service.py b/tests/unit/restapi/queue/test_service.py deleted file mode 100644 index 26297dece..000000000 --- a/tests/unit/restapi/queue/test_service.py +++ /dev/null @@ -1,269 +0,0 @@ -# This Software (Dioptra) is being made available as a public service by the -# National Institute of Standards and Technology (NIST), an Agency of the United -# States Department of Commerce. This software was developed in part by employees of -# NIST and in part by NIST contractors. Copyright in portions of this software that -# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant -# to Title 17 United States Code Section 105, works of NIST employees are not -# subject to copyright protection in the United States. However, NIST may hold -# international copyright in software created by its employees and domestic -# copyright (or licensing rights) in portions of software that were assigned or -# licensed to NIST. To the extent that NIST holds copyright in this software, it is -# being made available under the Creative Commons Attribution 4.0 International -# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts -# of the software developed or licensed by NIST. -# -# ACCESS THE FULL CC BY 4.0 LICENSE HERE: -# https://creativecommons.org/licenses/by/4.0/legalcode -from __future__ import annotations - -import datetime -from typing import List, Optional, Set - -import pytest -import structlog -from flask import Flask -from flask_sqlalchemy import SQLAlchemy -from freezegun import freeze_time -from structlog.stdlib import BoundLogger - -from dioptra.restapi.models import ( - Queue, - QueueRegistrationForm, - QueueRegistrationFormData, -) -from dioptra.restapi.queue.errors import QueueAlreadyExistsError -from dioptra.restapi.queue.service import QueueService - -LOGGER: BoundLogger = structlog.stdlib.get_logger() - - -@pytest.fixture -def queue_registration_form(app: Flask) -> QueueRegistrationForm: - with app.test_request_context(): - form = QueueRegistrationForm(data={"name": "tensorflow_cpu"}) - - return form - - -@pytest.fixture -def queue_registration_form_data() -> QueueRegistrationFormData: - return QueueRegistrationFormData(name="tensorflow_cpu") - - -@pytest.fixture -def queue_registration_form_data2() -> QueueRegistrationFormData: - return QueueRegistrationFormData(name="tensorflow_gpu") - - -@pytest.fixture -def queue_service(dependency_injector) -> QueueService: - return dependency_injector.get(QueueService) - - -@freeze_time("2020-08-17T18:46:28.717559") -def test_create( - db: SQLAlchemy, - queue_service: QueueService, - queue_registration_form_data: QueueRegistrationFormData, - queue_registration_form_data2: QueueRegistrationFormData, -): - queue: Queue = queue_service.create( - queue_registration_form_data=queue_registration_form_data - ) - queue2: Queue = queue_service.create( - queue_registration_form_data=queue_registration_form_data2 - ) - - assert queue.queue_id == 1 - assert queue.name == "tensorflow_cpu" - assert queue.created_on == datetime.datetime(2020, 8, 17, 18, 46, 28, 717559) - assert queue.last_modified == datetime.datetime(2020, 8, 17, 18, 46, 28, 717559) - - assert queue2.queue_id == 2 - assert queue2.name == "tensorflow_gpu" - assert queue2.created_on == datetime.datetime(2020, 8, 17, 18, 46, 28, 717559) - assert queue2.last_modified == datetime.datetime(2020, 8, 17, 18, 46, 28, 717559) - - with pytest.raises(QueueAlreadyExistsError): - queue_service.create(queue_registration_form_data=queue_registration_form_data) - - -def test_delete_queue( - db: SQLAlchemy, - queue_service: QueueService, - default_queues, -): - tf_cpu_queue_id: List[int] = queue_service.delete_queue(1) - assert tf_cpu_queue_id[0] == 1 - - tf_cpu_queue: Queue = Queue.query.filter_by(queue_id=1).first() - assert tf_cpu_queue.is_deleted - - -def test_rename_queue( - db: SQLAlchemy, - queue_service: QueueService, - default_queues, -): - tf_cpu_queue: Queue = Queue.query.filter_by(queue_id=1, is_deleted=False).first() - assert tf_cpu_queue.name == "tensorflow_cpu" - - tf_cpu_queue = queue_service.rename_queue(tf_cpu_queue, new_name="tf_cpu") - tf_cpu_queue_updated_query: Queue = Queue.query.filter_by( - queue_id=1, is_deleted=False - ).first() - - assert tf_cpu_queue.name == tf_cpu_queue_updated_query.name - assert tf_cpu_queue_updated_query.name == "tf_cpu" - - -@freeze_time("2020-08-17T18:46:28.717559") -def test_lock_queue(db: SQLAlchemy, queue_service: QueueService, default_queues): - tf_cpu_dev_queue: Queue = Queue.query.filter_by( - queue_id=3, is_deleted=False - ).first() - assert not tf_cpu_dev_queue.lock - - tf_cpu_dev_queue_id: int = tf_cpu_dev_queue.queue_id - response: List[int] = queue_service.lock_queue(queue=tf_cpu_dev_queue) - - assert Queue.query.filter_by(queue_id=3, is_deleted=False).first().lock - assert response[0] == tf_cpu_dev_queue_id - - -@freeze_time("2020-08-17T18:46:28.717559") -def test_unlock_queue( - db: SQLAlchemy, queue_service: QueueService, default_queues_with_locks -): - tf_gpu_dev_queue: Queue = Queue.query.filter_by( - queue_id=4, is_deleted=False - ).first() - assert tf_gpu_dev_queue.lock - - tf_gpu_dev_queue_id: int = tf_gpu_dev_queue.queue_id - response: List[int] = queue_service.unlock_queue(queue=tf_gpu_dev_queue) - - assert not Queue.query.filter_by(queue_id=4, is_deleted=False).first().lock - assert response[0] == tf_gpu_dev_queue_id - - -@freeze_time("2020-08-17T18:46:28.717559") -def test_get_by_id(db: SQLAlchemy, queue_service: QueueService): - timestamp: datetime.datetime = datetime.datetime.now() - - new_queue: Queue = Queue( - queue_id=1, name="tensorflow_cpu", created_on=timestamp, last_modified=timestamp - ) - - db.session.add(new_queue) - db.session.commit() - - queue: Queue = queue_service.get_by_id(1) - - assert queue == new_queue - - -@freeze_time("2020-08-17T18:46:28.717559") -def test_get_by_name(db: SQLAlchemy, queue_service: QueueService): - timestamp: datetime.datetime = datetime.datetime.now() - - new_queue: Queue = Queue( - name="tensorflow_cpu", created_on=timestamp, last_modified=timestamp - ) - - db.session.add(new_queue) - db.session.commit() - - queue: Queue = queue_service.get_by_name("tensorflow_cpu") - - assert queue == new_queue - - -@freeze_time("2020-08-17T18:46:28.717559") -def test_get_unlocked_by_id( - db: SQLAlchemy, queue_service: QueueService, default_queues_with_locks -): - tf_cpu_dev_queue: Optional[Queue] = queue_service.get_unlocked_by_id(3) - tf_gpu_dev_queue: Optional[Queue] = queue_service.get_unlocked_by_id(4) - - assert tf_cpu_dev_queue - assert tf_cpu_dev_queue.queue_id == 3 - assert tf_cpu_dev_queue.name == "tensorflow_cpu_dev" - assert tf_gpu_dev_queue is None - - -@freeze_time("2020-08-17T18:46:28.717559") -def test_get_unlocked_by_name( - db: SQLAlchemy, queue_service: QueueService, default_queues_with_locks -): - tf_cpu_dev_queue: Optional[Queue] = queue_service.get_unlocked_by_name( - "tensorflow_cpu_dev" - ) - tf_gpu_dev_queue: Optional[Queue] = queue_service.get_unlocked_by_name( - "tensorflow_gpu_dev" - ) - - assert tf_cpu_dev_queue - assert tf_cpu_dev_queue.queue_id == 3 - assert tf_cpu_dev_queue.name == "tensorflow_cpu_dev" - assert tf_gpu_dev_queue is None - - -@freeze_time("2020-08-17T18:46:28.717559") -def test_get_all(db: SQLAlchemy, queue_service: QueueService): - timestamp: datetime.datetime = datetime.datetime.now() - - new_queue1: Queue = Queue( - queue_id=1, name="tensorflow_cpu", created_on=timestamp, last_modified=timestamp - ) - new_queue2: Queue = Queue( - queue_id=2, name="tensorflow_gpu", created_on=timestamp, last_modified=timestamp - ) - - db.session.add(new_queue1) - db.session.add(new_queue2) - db.session.commit() - - results: List[Queue] = queue_service.get_all() - - assert len(results) == 2 - assert new_queue1 in results and new_queue2 in results - assert new_queue1.queue_id == 1 - assert new_queue2.queue_id == 2 - - -def test_get_all_unlocked( - db: SQLAlchemy, queue_service: QueueService, default_queues_with_locks -): - results: List[Queue] = queue_service.get_all_unlocked() - queue_names: Set[str] = {queue.name for queue in results} - queue_name_diff: Set[str] = queue_names.difference( - {"tensorflow_cpu", "tensorflow_gpu", "tensorflow_cpu_dev"} - ) - - assert len(results) == 3 - assert len(queue_name_diff) == 0 - - -def test_get_all_locked( - db: SQLAlchemy, queue_service: QueueService, default_queues_with_locks -): - results: List[Queue] = queue_service.get_all_locked() - queue_names: Set[str] = {queue.name for queue in results} - queue_name_diff: Set[str] = queue_names.difference({"tensorflow_gpu_dev"}) - - assert len(results) == 1 - assert len(queue_name_diff) == 0 - - -def test_extract_data_from_form( - queue_service: QueueService, - queue_registration_form: QueueRegistrationForm, -): - queue_registration_form_data: QueueRegistrationFormData = ( - queue_service.extract_data_from_form( - queue_registration_form=queue_registration_form - ) - ) - - assert queue_registration_form_data["name"] == "tensorflow_cpu" diff --git a/tests/unit/restapi/test_queue.py b/tests/unit/restapi/test_queue.py new file mode 100644 index 000000000..93269a286 --- /dev/null +++ b/tests/unit/restapi/test_queue.py @@ -0,0 +1,381 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Test suite for queue operations. + +This module contains a set of tests that validate the CRUD operations and additional +functionalities for the queue entity. The tests ensure that the queues can be +registered, renamed, deleted, and locked/unlocked as expected through the REST API. +""" +from __future__ import annotations + +from typing import Any + +from flask.testing import FlaskClient +from flask_sqlalchemy import SQLAlchemy +from werkzeug.test import TestResponse + +from dioptra.restapi.queue.routes import BASE_ROUTE as QUEUE_BASE_ROUTE + + +def register_queue(client: FlaskClient, name: str) -> TestResponse: + """Register a queue using the API. + + Args: + client: The Flask test client. + name: The name to assign to the new queue. + + Returns: + The response from the API. + """ + return client.post( + f"/api/{QUEUE_BASE_ROUTE}/", + json={"name": name}, + follow_redirects=True, + ) + + +def rename_queue( + client: FlaskClient, + queue_id: int, + new_name: str, +) -> TestResponse: + """Rename a queue using the API. + + Args: + client: The Flask test client. + queue_id: The id of the queue to rename. + new_name: The new name to assign to the queue. + + Returns: + The response from the API. + """ + return client.put( + f"/api/{QUEUE_BASE_ROUTE}/{queue_id}", + json={"name": new_name}, + follow_redirects=True, + ) + + +def delete_queue( + client: FlaskClient, + queue_id: int, +) -> TestResponse: + """Delete a queue using the API. + + Args: + client: The Flask test client. + queue_id: The id of the queue to delete. + + Returns: + The response from the API. + """ + return client.delete( + f"/api/{QUEUE_BASE_ROUTE}/{queue_id}", + follow_redirects=True, + ) + + +def lock_queue( + client: FlaskClient, + queue_id: int, +) -> TestResponse: + """Lock a queue using the API. + + Args: + client: The Flask test client. + queue_id: The id of the queue to lock. + + Returns: + The response from the API. + """ + return client.put( + f"/api/{QUEUE_BASE_ROUTE}/{queue_id}/lock", + follow_redirects=True, + ) + + +def unlock_queue( + client: FlaskClient, + queue_id: int, +) -> TestResponse: + """Unlock a queue using the API. + + Args: + client: The Flask test client. + queue_id: The id of the queue to unlock. + + Returns: + The response from the API. + """ + return client.delete( + f"/api/{QUEUE_BASE_ROUTE}/{queue_id}/lock", + follow_redirects=True, + ) + + +def assert_retrieving_queue_by_id_works( + client: FlaskClient, + queue_id: int, + expected: dict[str, Any], +) -> None: + """Assert that retrieving a queue by id works. + + Args: + client: The Flask test client. + queue_id: The id of the queue to retrieve. + expected: The expected response from the API. + + Raises: + AssertionError: If the response status code is not 200 or if the API response + does not match the expected response. + """ + response = client.get(f"/api/{QUEUE_BASE_ROUTE}/{queue_id}", follow_redirects=True) + assert response.status_code == 200 and response.get_json() == expected + + +def assert_retrieving_queue_by_name_works( + client: FlaskClient, + queue_name: str, + expected: dict[str, Any], +) -> None: + """Assert that retrieving a queue by name works. + + Args: + client: The Flask test client. + queue_name: The name of the queue to retrieve. + expected: The expected response from the API. + + Raises: + AssertionError: If the response status code is not 200 or if the API response + does not match the expected response. + """ + response = client.get( + f"/api/{QUEUE_BASE_ROUTE}/name/{queue_name}", follow_redirects=True + ) + assert response.status_code == 200 and response.get_json() == expected + + +def assert_retrieving_all_queues_works( + client: FlaskClient, + expected: list[dict[str, Any]], +) -> None: + """Assert that retrieving all queues works. + + Args: + client: The Flask test client. + expected: The expected response from the API. + + Raises: + AssertionError: If the response status code is not 200 or if the API response + does not match the expected response. + """ + response = client.get(f"/api/{QUEUE_BASE_ROUTE}", follow_redirects=True) + assert response.status_code == 200 and response.get_json() == expected + + +def assert_registering_existing_queue_name_fails( + client: FlaskClient, name: str +) -> None: + """Assert that registering a queue with an existing name fails. + + Args: + client: The Flask test client. + name: The name to assign to the new queue. + + Raises: + AssertionError: If the response status code is not 400. + """ + response = register_queue(client, name=name) + assert response.status_code == 400 + + +def assert_queue_name_matches_expected_name( + client: FlaskClient, queue_id: int, expected_name: str +) -> None: + """Assert that the name of a queue matches the expected name. + + Args: + client: The Flask test client. + queue_id: The id of the queue to retrieve. + expected_name: The expected name of the queue. + + Raises: + AssertionError: If the response status code is not 200 or if the name of the + queue does not match the expected name. + """ + response = client.get( + f"/api/{QUEUE_BASE_ROUTE}/{queue_id}", + follow_redirects=True, + ) + assert response.status_code == 200 and response.get_json()["name"] == expected_name + + +def assert_queue_is_not_found( + client: FlaskClient, + queue_id: int, +) -> None: + """Assert that a queue is not found. + + Args: + client: The Flask test client. + queue_id: The id of the queue to retrieve. + + Raises: + AssertionError: If the response status code is not 404. + """ + response = client.get( + f"/api/{QUEUE_BASE_ROUTE}/{queue_id}", + follow_redirects=True, + ) + assert response.status_code == 404 + + +def assert_queue_count_matches_expected_count( + client: FlaskClient, + expected: int, +) -> None: + """Assert that the number of queues matches the expected number. + + Args: + client: The Flask test client. + expected: The expected number of queues. + + Raises: + AssertionError: If the response status code is not 200 or if the number of + queues does not match the expected number. + """ + response = client.get( + f"/api/{QUEUE_BASE_ROUTE}", + follow_redirects=True, + ) + assert len(response.get_json()) == expected + + +def test_queue_registration(client: FlaskClient, db: SQLAlchemy) -> None: + """Test that queues can be registered and retrieved using the API. + + This test validates the following sequence of actions: + + - A user registers two queues, "tensorflow_cpu" and "tensorflow_gpu" + - The user is able to retrieve information about each queue using either the + queue id or the unique queue name. + - The user is able to retrieve a list of all registered queues. + - In all cases, the returned information matches the information that was provided + during registration. + """ + queue1_response = register_queue(client, name="tensorflow_cpu") + queue2_response = register_queue(client, name="pytorch_cpu") + queue1_expected = queue1_response.get_json() + queue2_expected = queue2_response.get_json() + queue_expected_list = [queue1_expected, queue2_expected] + assert_retrieving_queue_by_id_works( + client, queue_id=queue1_expected["queueId"], expected=queue1_expected + ) + assert_retrieving_queue_by_name_works( + client, queue_name=queue1_expected["name"], expected=queue1_expected + ) + assert_retrieving_queue_by_id_works( + client, queue_id=queue2_expected["queueId"], expected=queue2_expected + ) + assert_retrieving_queue_by_name_works( + client, queue_name=queue2_expected["name"], expected=queue2_expected + ) + assert_retrieving_all_queues_works(client, expected=queue_expected_list) + + +def test_cannot_register_existing_queue_name( + client: FlaskClient, db: SQLAlchemy +) -> None: + """Test that registering a queue with an existing name fails. + + This test validates the following sequence of actions: + + - A user registers a queue named "tensorflow_cpu" + - The user attempts to register a second queue with the same name, which fails. + """ + queue_name = "tensorflow_cpu" + register_queue(client, name="tensorflow_cpu") + assert_registering_existing_queue_name_fails(client, name=queue_name) + + +def test_queue_renaming(client: FlaskClient, db: SQLAlchemy) -> None: + """Test that a queue can be renamed. + + This test validates the following sequence of actions: + + - A user registers a queue named "tensorflow_cpu" + - The user is able to retrieve information about the "tensorflow_cpu" queue that + matches the information that was provided during registration. + - The user renames this same queue to "tensorflow_gpu" + - The user retrieves information about the same queue and it reflects the name + change. + """ + queue_name = "tensorflow_cpu" + updated_queue_name = "tensorflow_gpu" + registration_response = register_queue(client, name=queue_name) + queue_json = registration_response.get_json() + assert_queue_name_matches_expected_name( + client, queue_id=queue_json["queueId"], expected_name=queue_name + ) + rename_queue(client, queue_id=queue_json["queueId"], new_name=updated_queue_name) + assert_queue_name_matches_expected_name( + client, queue_id=queue_json["queueId"], expected_name=updated_queue_name + ) + + +def test_queue_deleting(client: FlaskClient, db: SQLAlchemy) -> None: + """Test that a queue can be deleted. + + This test validates the following sequence of actions: + + - A user registers a queue named "tensorflow_cpu" + - The user is able to retrieve information about the "tensorflow_cpu" queue that + matches the information that was provided during registration. + - The user deletes the "tensorflow_cpu" queue + - The user attempts to retrieve information about the "tensorflow_cpu" queue, which + is no longer found. + """ + queue_name = "tensorflow_cpu" + registration_response = register_queue(client, name=queue_name) + queue_json = registration_response.get_json() + assert_retrieving_queue_by_id_works( + client, queue_id=queue_json["queueId"], expected=queue_json + ) + delete_queue(client, queue_id=queue_json["queueId"]) + assert_queue_is_not_found(client, queue_id=queue_json["queueId"]) + + +def test_queue_locking(client: FlaskClient, db: SQLAlchemy) -> None: + """Test that a queue can be locked. + + This test validates the following sequence of actions: + + - A user registers two queues, "tensorflow_cpu" and "tensorflow_gpu". + - The user requests a list of all queues, which returns a list of length 2. + - The user locks the "tensorflow_gpu" queue. + - The user requests a list of all queues, which returns a list of length 1. + - The user unlocks the "tensorflow_gpu" queue. + - The user requests a list of all queues, which returns a list of length 2. + """ + register_queue(client, name="tensorflow_cpu") + response = register_queue(client, name="tensorflow_gpu") + tensorflow_gpu_queue_id = response.get_json()["queueId"] + assert_queue_count_matches_expected_count(client, expected=2) + lock_queue(client, queue_id=tensorflow_gpu_queue_id) + assert_queue_count_matches_expected_count(client, expected=1) + unlock_queue(client, queue_id=tensorflow_gpu_queue_id) + assert_queue_count_matches_expected_count(client, expected=2)