diff --git a/examples/scripts/client.py b/examples/scripts/client.py index 1114929e1..ca1e1de7b 100644 --- a/examples/scripts/client.py +++ b/examples/scripts/client.py @@ -145,7 +145,7 @@ def register_experiment(self, name: str) -> dict[str, Any]: response = requests.post( self.experiment_endpoint, - data=experiment_registration_form, + json=experiment_registration_form, ) return response.json() diff --git a/src/dioptra/client/_client.py b/src/dioptra/client/_client.py index 39f20c3cb..37d269baf 100644 --- a/src/dioptra/client/_client.py +++ b/src/dioptra/client/_client.py @@ -580,7 +580,7 @@ def register_experiment(self, name: str) -> dict[str, Any]: response = requests.post( self.experiment_endpoint, - data=experiment_registration_form, + json=experiment_registration_form, ) return cast(dict[str, Any], response.json()) diff --git a/src/dioptra/restapi/experiment/controller.py b/src/dioptra/restapi/experiment/controller.py index b7926e40e..193823c04 100644 --- a/src/dioptra/restapi/experiment/controller.py +++ b/src/dioptra/restapi/experiment/controller.py @@ -29,20 +29,11 @@ from injector import inject from structlog.stdlib import BoundLogger -from dioptra.restapi.utils import as_api_parser - -from .errors import ExperimentDoesNotExistError, ExperimentRegistrationError -from .interface import ExperimentUpdateInterface -from .model import ( - Experiment, - ExperimentRegistrationForm, - ExperimentRegistrationFormData, -) -from .schema import ( - ExperimentRegistrationSchema, - ExperimentSchema, - ExperimentUpdateSchema, -) +from dioptra.restapi.utils import slugify + +from .errors import ExperimentDoesNotExistError +from .model import Experiment +from .schema import ExperimentSchema from .service import ExperimentService LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -73,33 +64,21 @@ def get(self) -> List[Experiment]: return self._experiment_service.get_all(log=log) @login_required - @api.expect(as_api_parser(api, ExperimentRegistrationSchema)) - @accepts(ExperimentRegistrationSchema, api=api) + @accepts(schema=ExperimentSchema, api=api) @responds(schema=ExperimentSchema, api=api) def post(self) -> Experiment: """Creates a new experiment via an experiment registration form.""" log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), resource="experiment", request_type="POST" ) # noqa: F841 - experiment_registration_form: ExperimentRegistrationForm = ( - ExperimentRegistrationForm() - ) log.info("Request received") - if not experiment_registration_form.validate_on_submit(): - log.error("Form validation failed") - raise ExperimentRegistrationError + parsed_obj = request.parsed_obj # type: ignore - log.info("Form validation successful") - experiment_registration_form_data: ExperimentRegistrationFormData = ( - self._experiment_service.extract_data_from_form( - experiment_registration_form=experiment_registration_form, log=log - ) - ) - return self._experiment_service.create( - experiment_registration_form_data=experiment_registration_form_data, log=log - ) + name = slugify(str(parsed_obj["name"])) + + return self._experiment_service.create(experiment_name=name, log=log) @api.route("/") @@ -144,14 +123,14 @@ def delete(self, experimentId: int) -> Response: return jsonify(dict(status="Success", id=id)) @login_required - @accepts(schema=ExperimentUpdateSchema, api=api) + @accepts(schema=ExperimentSchema, api=api) @responds(schema=ExperimentSchema, api=api) def put(self, experimentId: int) -> Experiment: """Modifies an experiment by its unique identifier.""" log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), resource="experimentId", request_type="PUT" ) # noqa: F841 - changes: ExperimentUpdateInterface = request.parsed_obj # type: ignore + changes: dict = request.parsed_obj # type: ignore experiment: Optional[Experiment] = self._experiment_service.get_by_id( experimentId, log=log ) @@ -219,14 +198,14 @@ def delete(self, experimentName: str) -> Response: return jsonify(dict(status="Success", id=id)) @login_required - @accepts(schema=ExperimentUpdateSchema, api=api) + @accepts(schema=ExperimentSchema, api=api) @responds(schema=ExperimentSchema, api=api) def put(self, experimentName: str) -> Experiment: """Modifies an experiment by its unique name.""" log: BoundLogger = LOGGER.new( request_id=str(uuid.uuid4()), resource="experimentName", request_type="PUT" ) # noqa: F841 - changes: ExperimentUpdateInterface = request.parsed_obj # type: ignore + changes: dict = request.parsed_obj # type: ignore experiment: Optional[Experiment] = self._experiment_service.get_by_name( experiment_name=experimentName, log=log ) diff --git a/src/dioptra/restapi/experiment/dependencies.py b/src/dioptra/restapi/experiment/dependencies.py index af134fce6..3d7891681 100644 --- a/src/dioptra/restapi/experiment/dependencies.py +++ b/src/dioptra/restapi/experiment/dependencies.py @@ -24,16 +24,6 @@ from dioptra.restapi.shared.request_scope import request -from .schema import ExperimentRegistrationFormSchema - - -class ExperimentRegistrationFormSchemaModule(Module): - @provider - def provide_experiment_registration_form_schema_module( - self, - ) -> ExperimentRegistrationFormSchema: - return ExperimentRegistrationFormSchema() - class MLFlowClientModule(Module): @request @@ -60,5 +50,4 @@ def register_providers(modules: List[Callable[..., Any]]) -> None: modules: A list of callables used for configuring the dependency injection environment. """ - modules.append(ExperimentRegistrationFormSchemaModule) modules.append(MLFlowClientModule) diff --git a/src/dioptra/restapi/experiment/interface.py b/src/dioptra/restapi/experiment/interface.py deleted file mode 100644 index 68bd480cb..000000000 --- a/src/dioptra/restapi/experiment/interface.py +++ /dev/null @@ -1,53 +0,0 @@ -# This Software (Dioptra) is being made available as a public service by the -# National Institute of Standards and Technology (NIST), an Agency of the United -# States Department of Commerce. This software was developed in part by employees of -# NIST and in part by NIST contractors. Copyright in portions of this software that -# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant -# to Title 17 United States Code Section 105, works of NIST employees are not -# subject to copyright protection in the United States. However, NIST may hold -# international copyright in software created by its employees and domestic -# copyright (or licensing rights) in portions of software that were assigned or -# licensed to NIST. To the extent that NIST holds copyright in this software, it is -# being made available under the Creative Commons Attribution 4.0 International -# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts -# of the software developed or licensed by NIST. -# -# ACCESS THE FULL CC BY 4.0 LICENSE HERE: -# https://creativecommons.org/licenses/by/4.0/legalcode -"""The interfaces for creating and updating |Experiment| objects. - -.. |Experiment| replace:: :py:class:`~.model.Experiment` -""" -from __future__ import annotations - -import datetime - -from typing_extensions import TypedDict - - -class ExperimentInterface(TypedDict, total=False): - """The interface for constructing a new |Experiment| object. - - Attributes: - experiment_id: An integer identifying a registered experiment. - created_on: The date and time the experiment was created. - last_modified: The date and time the experiment was last modified. - name: The name of the experiment. - """ - - experiment_id: int - created_on: datetime.datetime - last_modified: datetime.datetime - name: str - - -class ExperimentUpdateInterface(TypedDict, total=False): - """The interface for updating an |Experiment| object. - - Attributes: - name: The name of the experiment. - is_deleted: A boolean that indicates if the experiment record is deleted. - """ - - name: str - is_deleted: bool diff --git a/src/dioptra/restapi/experiment/model.py b/src/dioptra/restapi/experiment/model.py index 86cc92c2a..dcfde0359 100644 --- a/src/dioptra/restapi/experiment/model.py +++ b/src/dioptra/restapi/experiment/model.py @@ -18,16 +18,9 @@ from __future__ import annotations import datetime - -from flask_wtf import FlaskForm -from typing_extensions import TypedDict -from wtforms.fields import StringField -from wtforms.validators import InputRequired, ValidationError +from typing import Any, Dict from dioptra.restapi.app import db -from dioptra.restapi.utils import slugify - -from .interface import ExperimentUpdateInterface class Experiment(db.Model): @@ -53,7 +46,7 @@ class Experiment(db.Model): jobs = db.relationship("Job", back_populates="experiment") - def update(self, changes: ExperimentUpdateInterface): + def update(self, changes: Dict[str, Any]): """Updates the record. Args: @@ -66,41 +59,3 @@ def update(self, changes: ExperimentUpdateInterface): setattr(self, key, val) return self - - -class ExperimentRegistrationForm(FlaskForm): - """The experiment registration form. - - Attributes: - name: The name to register as a new experiment. - """ - - name = StringField("Name of Experiment", validators=[InputRequired()]) - - def validate_name(self, field: StringField) -> None: - """Validates that the experiment does not exist in the registry. - - Args: - field: The form field for `name`. - """ - - standardized_name: str = slugify(field.data) - - if ( - Experiment.query.filter_by(name=standardized_name, is_deleted=False).first() - is not None - ): - raise ValidationError( - "Bad Request - An experiment is already registered under " - f"the name {standardized_name}. Please select another and resubmit." - ) - - -class ExperimentRegistrationFormData(TypedDict, total=False): - """The data extracted from the experiment registration form. - - Attributes: - name: The name of the experiment. - """ - - name: str diff --git a/src/dioptra/restapi/experiment/schema.py b/src/dioptra/restapi/experiment/schema.py index 7dcf4680e..eed245452 100644 --- a/src/dioptra/restapi/experiment/schema.py +++ b/src/dioptra/restapi/experiment/schema.py @@ -23,17 +23,7 @@ """ from __future__ import annotations -from typing import Any, Dict - -from marshmallow import Schema, fields, post_dump, pre_dump - -from dioptra.restapi.utils import ParametersSchema, slugify - -from .model import ( - Experiment, - ExperimentRegistrationForm, - ExperimentRegistrationFormData, -) +from marshmallow import Schema, fields class ExperimentSchema(Schema): @@ -46,79 +36,23 @@ class ExperimentSchema(Schema): name: The name of the experiment. """ - __model__ = Experiment - experimentId = fields.Integer( attribute="experiment_id", metadata=dict(description="An integer identifying a registered experiment."), + dump_only=True, ) createdOn = fields.DateTime( attribute="created_on", metadata=dict(description="The date and time the experiment was created."), + dump_only=True, ) lastModified = fields.DateTime( attribute="last_modified", metadata=dict( description="The date and time the experiment was last modified." ), + dump_only=True, ) name = fields.String( attribute="name", metadata=dict(description="The name of the experiment.") ) - - -class ExperimentUpdateSchema(Schema): - """The schema for the data used to update an |Experiment| object. - - Attributes: - name: The new name of the experiment. Must be unique. - """ - - __model__ = Experiment - - name = fields.String( - attribute="name", - metadata=dict(description="The new name of the experiment. Must be unique."), - ) - - -class ExperimentRegistrationFormSchema(Schema): - """The schema for the information stored in an experiment registration form. - - Attributes: - name: The name of the experiment. Must be unique. - """ - - __model__ = ExperimentRegistrationFormData - - name = fields.String( - attribute="name", - required=True, - metadata=dict(description="The name of the experiment. Must be unique."), - ) - - @pre_dump - def extract_data_from_form( - self, data: ExperimentRegistrationForm, many: bool, **kwargs - ) -> Dict[str, Any]: - """Extracts data from the |ExperimentRegistrationForm| for validation.""" - - return {"name": slugify(data.name.data)} - - @post_dump - def serialize_object( - self, data: Dict[str, Any], many: bool, **kwargs - ) -> ExperimentRegistrationFormData: - """Makes an |ExperimentRegistrationFormData| object from the validated data.""" - return self.__model__(**data) - - -ExperimentRegistrationSchema: list[ParametersSchema] = [ - dict( - name="name", - type=str, - location="form", - required=True, - help="The name of the experiment. Must be unique.", - ) -] diff --git a/src/dioptra/restapi/experiment/service.py b/src/dioptra/restapi/experiment/service.py index 7cca47c92..9bb499224 100644 --- a/src/dioptra/restapi/experiment/service.py +++ b/src/dioptra/restapi/experiment/service.py @@ -35,12 +35,7 @@ ExperimentMLFlowTrackingDoesNotExistError, ExperimentMLFlowTrackingRegistrationError, ) -from .model import ( - Experiment, - ExperimentRegistrationForm, - ExperimentRegistrationFormData, -) -from .schema import ExperimentRegistrationFormSchema +from .model import Experiment LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -49,19 +44,16 @@ class ExperimentService(object): @inject def __init__( self, - experiment_registration_form_schema: ExperimentRegistrationFormSchema, mlflow_tracking_service: MLFlowTrackingService, ) -> None: - self._experiment_registration_form_schema = experiment_registration_form_schema self._mlflow_tracking_service = mlflow_tracking_service def create( self, - experiment_registration_form_data: ExperimentRegistrationFormData, + experiment_name: str, **kwargs, ) -> Experiment: log: BoundLogger = kwargs.get("log", LOGGER.new()) - experiment_name: str = experiment_registration_form_data["name"] if self.get_by_name(experiment_name, log=log) is not None: raise ExperimentAlreadyExistsError @@ -158,14 +150,3 @@ def get_by_name(experiment_name: str, **kwargs) -> Optional[Experiment]: return Experiment.query.filter_by( # type: ignore name=experiment_name, is_deleted=False ).first() - - def extract_data_from_form( - self, experiment_registration_form: ExperimentRegistrationForm, **kwargs - ) -> ExperimentRegistrationFormData: - log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 - - data: ExperimentRegistrationFormData = ( - self._experiment_registration_form_schema.dump(experiment_registration_form) - ) - - return data diff --git a/src/dioptra/restapi/models.py b/src/dioptra/restapi/models.py index 925b9c928..5f9cb8d27 100644 --- a/src/dioptra/restapi/models.py +++ b/src/dioptra/restapi/models.py @@ -17,11 +17,7 @@ """A module of reexports of the application's data models and forms.""" from __future__ import annotations -from .experiment.model import ( - Experiment, - ExperimentRegistrationForm, - ExperimentRegistrationFormData, -) +from .experiment.model import Experiment from .job.model import Job, JobForm, JobFormData from .queue.model import Queue, QueueLock from .task_plugin.model import ( @@ -33,8 +29,6 @@ __all__ = [ "Experiment", - "ExperimentRegistrationForm", - "ExperimentRegistrationFormData", "Job", "JobForm", "JobFormData", diff --git a/tests/cookiecutter_dioptra_deployment/conftest.py b/tests/cookiecutter_dioptra_deployment/conftest.py index 72a3f9937..3e14e2449 100644 --- a/tests/cookiecutter_dioptra_deployment/conftest.py +++ b/tests/cookiecutter_dioptra_deployment/conftest.py @@ -117,7 +117,7 @@ def context(): "image": "node", "namespace": "", "tag": "latest", - "registry": "" + "registry": "", }, "redis": { "image": "redis", diff --git a/tests/unit/restapi/conftest.py b/tests/unit/restapi/conftest.py index 711d2a95a..8e10a1360 100644 --- a/tests/unit/restapi/conftest.py +++ b/tests/unit/restapi/conftest.py @@ -94,9 +94,6 @@ def task_plugin_archive(): @pytest.fixture def dependency_modules() -> List[Any]: - from dioptra.restapi.experiment.dependencies import ( - ExperimentRegistrationFormSchemaModule, - ) from dioptra.restapi.job.dependencies import ( JobFormSchemaModule, RQServiceConfiguration, @@ -137,7 +134,6 @@ def configure(binder: Binder) -> None: return [ configure, - ExperimentRegistrationFormSchemaModule(), JobFormSchemaModule(), PasswordServiceModule(), RQServiceModule(), diff --git a/tests/unit/rq/tasks/test_run_task_engine.py b/tests/unit/rq/tasks/test_run_task_engine.py index 0dde0064e..c6576f8c2 100644 --- a/tests/unit/rq/tasks/test_run_task_engine.py +++ b/tests/unit/rq/tasks/test_run_task_engine.py @@ -111,20 +111,11 @@ def dioptra_set_run_id_for_job(self, run_id, job_id): # silly experiment which calls a function which does nothing silly_experiment = { # Match these up with global_experiment_params above - "parameters": { - "param1": {"type": "integer"}, - "param2": {"type": "string"} - }, + "parameters": {"param1": {"type": "integer"}, "param2": {"type": "string"}}, "tasks": { - "silly": { - "plugin": "tests.unit.rq.tasks.test_run_task_engine.silly_plugin" - } + "silly": {"plugin": "tests.unit.rq.tasks.test_run_task_engine.silly_plugin"} }, - "graph": { - "step1": { - "silly": [] - } - } + "graph": {"step1": {"silly": []}}, } # Split the builtins plugins listing into two pages, to test paging. diff --git a/tests/unit/sdk/utilities/paths/test_clear_directory.py b/tests/unit/sdk/utilities/paths/test_clear_directory.py index 201d2d2cc..a3394c7af 100644 --- a/tests/unit/sdk/utilities/paths/test_clear_directory.py +++ b/tests/unit/sdk/utilities/paths/test_clear_directory.py @@ -2,11 +2,7 @@ def test_clear_directory(tmp_path): - files = [ - "file1", - "dir1/file2", - "dir1/dir2/file3" - ] + files = ["file1", "dir1/file2", "dir1/dir2/file3"] # Make some files for f in files: diff --git a/tests/unit/task_engine/test_task_engine.py b/tests/unit/task_engine/test_task_engine.py index 00f692ca0..72d3a3560 100644 --- a/tests/unit/task_engine/test_task_engine.py +++ b/tests/unit/task_engine/test_task_engine.py @@ -258,9 +258,7 @@ def test_globals_dict_nodefault() -> None: "graph": {"step1": {"add": [1, "$global_in"]}}, } - dioptra.task_engine.task_engine.run_experiment( - desc, {"global_in": 2} - ) + dioptra.task_engine.task_engine.run_experiment(desc, {"global_in": 2}) assert _output == 3 @@ -278,9 +276,7 @@ def test_globals_dict_default() -> None: "graph": {"step1": {"add": [1, "$global_in"]}}, } - dioptra.task_engine.task_engine.run_experiment( - desc, {"global_in": 2} - ) + dioptra.task_engine.task_engine.run_experiment(desc, {"global_in": 2}) assert _output == 3