From 1e1b5ab70381ddc294cbf55af64bb7751302d942 Mon Sep 17 00:00:00 2001 From: abyrne Date: Thu, 30 Nov 2023 15:12:58 -0500 Subject: [PATCH] refactor: remove dependency on wtforms from the experiment endpoint --- src/dioptra/restapi/experiment/controller.py | 49 ++++---------- .../restapi/experiment/dependencies.py | 11 --- src/dioptra/restapi/experiment/interface.py | 53 --------------- src/dioptra/restapi/experiment/model.py | 49 +------------- src/dioptra/restapi/experiment/schema.py | 67 +------------------ src/dioptra/restapi/experiment/service.py | 23 +------ src/dioptra/restapi/models.py | 8 +-- .../conftest.py | 2 +- tests/unit/restapi/conftest.py | 4 -- 9 files changed, 22 insertions(+), 244 deletions(-) delete mode 100644 src/dioptra/restapi/experiment/interface.py diff --git a/src/dioptra/restapi/experiment/controller.py b/src/dioptra/restapi/experiment/controller.py index b7926e40e..193823c04 100644 --- a/src/dioptra/restapi/experiment/controller.py +++ b/src/dioptra/restapi/experiment/controller.py @@ -29,20 +29,11 @@ from injector import inject from structlog.stdlib import BoundLogger -from dioptra.restapi.utils import as_api_parser - -from .errors import ExperimentDoesNotExistError, ExperimentRegistrationError -from .interface import ExperimentUpdateInterface -from .model import ( - Experiment, - ExperimentRegistrationForm, - ExperimentRegistrationFormData, -) -from .schema import ( - ExperimentRegistrationSchema, - ExperimentSchema, - ExperimentUpdateSchema, -) +from dioptra.restapi.utils import slugify + +from .errors import ExperimentDoesNotExistError +from .model import Experiment +from .schema import ExperimentSchema from .service import ExperimentService LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -73,33 +64,21 @@ def get(self) -> List[Experiment]: return self._experiment_service.get_all(log=log) @login_required - @api.expect(as_api_parser(api, ExperimentRegistrationSchema)) - @accepts(ExperimentRegistrationSchema, api=api) + @accepts(schema=ExperimentSchema, api=api) @responds(schema=ExperimentSchema, api=api) def post(self) -> Experiment: """Creates a new experiment via an experiment registration form.""" log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), resource="experiment", request_type="POST" ) # noqa: F841 - experiment_registration_form: ExperimentRegistrationForm = ( - ExperimentRegistrationForm() - ) log.info("Request received") - if not experiment_registration_form.validate_on_submit(): - log.error("Form validation failed") - raise ExperimentRegistrationError + parsed_obj = request.parsed_obj # type: ignore - log.info("Form validation successful") - experiment_registration_form_data: ExperimentRegistrationFormData = ( - self._experiment_service.extract_data_from_form( - experiment_registration_form=experiment_registration_form, log=log - ) - ) - return self._experiment_service.create( - experiment_registration_form_data=experiment_registration_form_data, log=log - ) + name = slugify(str(parsed_obj["name"])) + + return self._experiment_service.create(experiment_name=name, log=log) @api.route("/") @@ -144,14 +123,14 @@ def delete(self, experimentId: int) -> Response: return jsonify(dict(status="Success", id=id)) @login_required - @accepts(schema=ExperimentUpdateSchema, api=api) + @accepts(schema=ExperimentSchema, api=api) @responds(schema=ExperimentSchema, api=api) def put(self, experimentId: int) -> Experiment: """Modifies an experiment by its unique identifier.""" log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), resource="experimentId", request_type="PUT" ) # noqa: F841 - changes: ExperimentUpdateInterface = request.parsed_obj # type: ignore + changes: dict = request.parsed_obj # type: ignore experiment: Optional[Experiment] = self._experiment_service.get_by_id( experimentId, log=log ) @@ -219,14 +198,14 @@ def delete(self, experimentName: str) -> Response: return jsonify(dict(status="Success", id=id)) @login_required - @accepts(schema=ExperimentUpdateSchema, api=api) + @accepts(schema=ExperimentSchema, api=api) @responds(schema=ExperimentSchema, api=api) def put(self, experimentName: str) -> Experiment: """Modifies an experiment by its unique name.""" log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), resource="experimentName", request_type="PUT" ) # noqa: F841 - changes: ExperimentUpdateInterface = request.parsed_obj # type: ignore + changes: dict = request.parsed_obj # type: ignore experiment: Optional[Experiment] = self._experiment_service.get_by_name( experiment_name=experimentName, log=log ) diff --git a/src/dioptra/restapi/experiment/dependencies.py b/src/dioptra/restapi/experiment/dependencies.py index af134fce6..3d7891681 100644 --- a/src/dioptra/restapi/experiment/dependencies.py +++ b/src/dioptra/restapi/experiment/dependencies.py @@ -24,16 +24,6 @@ from dioptra.restapi.shared.request_scope import request -from .schema import ExperimentRegistrationFormSchema - - -class ExperimentRegistrationFormSchemaModule(Module): - @provider - def provide_experiment_registration_form_schema_module( - self, - ) -> ExperimentRegistrationFormSchema: - return ExperimentRegistrationFormSchema() - class MLFlowClientModule(Module): @request @@ -60,5 +50,4 @@ def register_providers(modules: List[Callable[..., Any]]) -> None: modules: A list of callables used for configuring the dependency injection environment. """ - modules.append(ExperimentRegistrationFormSchemaModule) modules.append(MLFlowClientModule) diff --git a/src/dioptra/restapi/experiment/interface.py b/src/dioptra/restapi/experiment/interface.py deleted file mode 100644 index 68bd480cb..000000000 --- a/src/dioptra/restapi/experiment/interface.py +++ /dev/null @@ -1,53 +0,0 @@ -# This Software (Dioptra) is being made available as a public service by the -# National Institute of Standards and Technology (NIST), an Agency of the United -# States Department of Commerce. This software was developed in part by employees of -# NIST and in part by NIST contractors. Copyright in portions of this software that -# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant -# to Title 17 United States Code Section 105, works of NIST employees are not -# subject to copyright protection in the United States. However, NIST may hold -# international copyright in software created by its employees and domestic -# copyright (or licensing rights) in portions of software that were assigned or -# licensed to NIST. To the extent that NIST holds copyright in this software, it is -# being made available under the Creative Commons Attribution 4.0 International -# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts -# of the software developed or licensed by NIST. -# -# ACCESS THE FULL CC BY 4.0 LICENSE HERE: -# https://creativecommons.org/licenses/by/4.0/legalcode -"""The interfaces for creating and updating |Experiment| objects. - -.. |Experiment| replace:: :py:class:`~.model.Experiment` -""" -from __future__ import annotations - -import datetime - -from typing_extensions import TypedDict - - -class ExperimentInterface(TypedDict, total=False): - """The interface for constructing a new |Experiment| object. - - Attributes: - experiment_id: An integer identifying a registered experiment. - created_on: The date and time the experiment was created. - last_modified: The date and time the experiment was last modified. - name: The name of the experiment. - """ - - experiment_id: int - created_on: datetime.datetime - last_modified: datetime.datetime - name: str - - -class ExperimentUpdateInterface(TypedDict, total=False): - """The interface for updating an |Experiment| object. - - Attributes: - name: The name of the experiment. - is_deleted: A boolean that indicates if the experiment record is deleted. - """ - - name: str - is_deleted: bool diff --git a/src/dioptra/restapi/experiment/model.py b/src/dioptra/restapi/experiment/model.py index 86cc92c2a..dcfde0359 100644 --- a/src/dioptra/restapi/experiment/model.py +++ b/src/dioptra/restapi/experiment/model.py @@ -18,16 +18,9 @@ from __future__ import annotations import datetime - -from flask_wtf import FlaskForm -from typing_extensions import TypedDict -from wtforms.fields import StringField -from wtforms.validators import InputRequired, ValidationError +from typing import Any, Dict from dioptra.restapi.app import db -from dioptra.restapi.utils import slugify - -from .interface import ExperimentUpdateInterface class Experiment(db.Model): @@ -53,7 +46,7 @@ class Experiment(db.Model): jobs = db.relationship("Job", back_populates="experiment") - def update(self, changes: ExperimentUpdateInterface): + def update(self, changes: Dict[str, Any]): """Updates the record. Args: @@ -66,41 +59,3 @@ def update(self, changes: ExperimentUpdateInterface): setattr(self, key, val) return self - - -class ExperimentRegistrationForm(FlaskForm): - """The experiment registration form. - - Attributes: - name: The name to register as a new experiment. - """ - - name = StringField("Name of Experiment", validators=[InputRequired()]) - - def validate_name(self, field: StringField) -> None: - """Validates that the experiment does not exist in the registry. - - Args: - field: The form field for `name`. - """ - - standardized_name: str = slugify(field.data) - - if ( - Experiment.query.filter_by(name=standardized_name, is_deleted=False).first() - is not None - ): - raise ValidationError( - "Bad Request - An experiment is already registered under " - f"the name {standardized_name}. Please select another and resubmit." - ) - - -class ExperimentRegistrationFormData(TypedDict, total=False): - """The data extracted from the experiment registration form. - - Attributes: - name: The name of the experiment. - """ - - name: str diff --git a/src/dioptra/restapi/experiment/schema.py b/src/dioptra/restapi/experiment/schema.py index 7dcf4680e..d548cdc98 100644 --- a/src/dioptra/restapi/experiment/schema.py +++ b/src/dioptra/restapi/experiment/schema.py @@ -23,16 +23,10 @@ """ from __future__ import annotations -from typing import Any, Dict +from marshmallow import Schema, fields -from marshmallow import Schema, fields, post_dump, pre_dump - -from dioptra.restapi.utils import ParametersSchema, slugify - -from .model import ( +from .model import ( # ExperimentRegistrationForm,; ExperimentRegistrationFormData, Experiment, - ExperimentRegistrationForm, - ExperimentRegistrationFormData, ) @@ -65,60 +59,3 @@ class ExperimentSchema(Schema): name = fields.String( attribute="name", metadata=dict(description="The name of the experiment.") ) - - -class ExperimentUpdateSchema(Schema): - """The schema for the data used to update an |Experiment| object. - - Attributes: - name: The new name of the experiment. Must be unique. - """ - - __model__ = Experiment - - name = fields.String( - attribute="name", - metadata=dict(description="The new name of the experiment. Must be unique."), - ) - - -class ExperimentRegistrationFormSchema(Schema): - """The schema for the information stored in an experiment registration form. - - Attributes: - name: The name of the experiment. Must be unique. - """ - - __model__ = ExperimentRegistrationFormData - - name = fields.String( - attribute="name", - required=True, - metadata=dict(description="The name of the experiment. Must be unique."), - ) - - @pre_dump - def extract_data_from_form( - self, data: ExperimentRegistrationForm, many: bool, **kwargs - ) -> Dict[str, Any]: - """Extracts data from the |ExperimentRegistrationForm| for validation.""" - - return {"name": slugify(data.name.data)} - - @post_dump - def serialize_object( - self, data: Dict[str, Any], many: bool, **kwargs - ) -> ExperimentRegistrationFormData: - """Makes an |ExperimentRegistrationFormData| object from the validated data.""" - return self.__model__(**data) - - -ExperimentRegistrationSchema: list[ParametersSchema] = [ - dict( - name="name", - type=str, - location="form", - required=True, - help="The name of the experiment. Must be unique.", - ) -] diff --git a/src/dioptra/restapi/experiment/service.py b/src/dioptra/restapi/experiment/service.py index 7cca47c92..9bb499224 100644 --- a/src/dioptra/restapi/experiment/service.py +++ b/src/dioptra/restapi/experiment/service.py @@ -35,12 +35,7 @@ ExperimentMLFlowTrackingDoesNotExistError, ExperimentMLFlowTrackingRegistrationError, ) -from .model import ( - Experiment, - ExperimentRegistrationForm, - ExperimentRegistrationFormData, -) -from .schema import ExperimentRegistrationFormSchema +from .model import Experiment LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -49,19 +44,16 @@ class ExperimentService(object): @inject def __init__( self, - experiment_registration_form_schema: ExperimentRegistrationFormSchema, mlflow_tracking_service: MLFlowTrackingService, ) -> None: - self._experiment_registration_form_schema = experiment_registration_form_schema self._mlflow_tracking_service = mlflow_tracking_service def create( self, - experiment_registration_form_data: ExperimentRegistrationFormData, + experiment_name: str, **kwargs, ) -> Experiment: log: BoundLogger = kwargs.get("log", LOGGER.new()) - experiment_name: str = experiment_registration_form_data["name"] if self.get_by_name(experiment_name, log=log) is not None: raise ExperimentAlreadyExistsError @@ -158,14 +150,3 @@ def get_by_name(experiment_name: str, **kwargs) -> Optional[Experiment]: return Experiment.query.filter_by( # type: ignore name=experiment_name, is_deleted=False ).first() - - def extract_data_from_form( - self, experiment_registration_form: ExperimentRegistrationForm, **kwargs - ) -> ExperimentRegistrationFormData: - log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 - - data: ExperimentRegistrationFormData = ( - self._experiment_registration_form_schema.dump(experiment_registration_form) - ) - - return data diff --git a/src/dioptra/restapi/models.py b/src/dioptra/restapi/models.py index 925b9c928..5f9cb8d27 100644 --- a/src/dioptra/restapi/models.py +++ b/src/dioptra/restapi/models.py @@ -17,11 +17,7 @@ """A module of reexports of the application's data models and forms.""" from __future__ import annotations -from .experiment.model import ( - Experiment, - ExperimentRegistrationForm, - ExperimentRegistrationFormData, -) +from .experiment.model import Experiment from .job.model import Job, JobForm, JobFormData from .queue.model import Queue, QueueLock from .task_plugin.model import ( @@ -33,8 +29,6 @@ __all__ = [ "Experiment", - "ExperimentRegistrationForm", - "ExperimentRegistrationFormData", "Job", "JobForm", "JobFormData", diff --git a/tests/cookiecutter_dioptra_deployment/conftest.py b/tests/cookiecutter_dioptra_deployment/conftest.py index 72a3f9937..3e14e2449 100644 --- a/tests/cookiecutter_dioptra_deployment/conftest.py +++ b/tests/cookiecutter_dioptra_deployment/conftest.py @@ -117,7 +117,7 @@ def context(): "image": "node", "namespace": "", "tag": "latest", - "registry": "" + "registry": "", }, "redis": { "image": "redis", diff --git a/tests/unit/restapi/conftest.py b/tests/unit/restapi/conftest.py index 711d2a95a..8e10a1360 100644 --- a/tests/unit/restapi/conftest.py +++ b/tests/unit/restapi/conftest.py @@ -94,9 +94,6 @@ def task_plugin_archive(): @pytest.fixture def dependency_modules() -> List[Any]: - from dioptra.restapi.experiment.dependencies import ( - ExperimentRegistrationFormSchemaModule, - ) from dioptra.restapi.job.dependencies import ( JobFormSchemaModule, RQServiceConfiguration, @@ -137,7 +134,6 @@ def configure(binder: Binder) -> None: return [ configure, - ExperimentRegistrationFormSchemaModule(), JobFormSchemaModule(), PasswordServiceModule(), RQServiceModule(),