Skip to content

Commit

Permalink
refactor: remove dependency on wtforms from the experiment endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
abyrne authored and jkglasbrenner committed Nov 30, 2023
1 parent e67864e commit 1e1b5ab
Show file tree
Hide file tree
Showing 9 changed files with 22 additions and 244 deletions.
49 changes: 14 additions & 35 deletions src/dioptra/restapi/experiment/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,11 @@
from injector import inject
from structlog.stdlib import BoundLogger

from dioptra.restapi.utils import as_api_parser

from .errors import ExperimentDoesNotExistError, ExperimentRegistrationError
from .interface import ExperimentUpdateInterface
from .model import (
Experiment,
ExperimentRegistrationForm,
ExperimentRegistrationFormData,
)
from .schema import (
ExperimentRegistrationSchema,
ExperimentSchema,
ExperimentUpdateSchema,
)
from dioptra.restapi.utils import slugify

from .errors import ExperimentDoesNotExistError
from .model import Experiment
from .schema import ExperimentSchema
from .service import ExperimentService

LOGGER: BoundLogger = structlog.stdlib.get_logger()
Expand Down Expand Up @@ -73,33 +64,21 @@ def get(self) -> List[Experiment]:
return self._experiment_service.get_all(log=log)

@login_required
@api.expect(as_api_parser(api, ExperimentRegistrationSchema))
@accepts(ExperimentRegistrationSchema, api=api)
@accepts(schema=ExperimentSchema, api=api)
@responds(schema=ExperimentSchema, api=api)
def post(self) -> Experiment:
"""Creates a new experiment via an experiment registration form."""
log: BoundLogger = LOGGER.new(
request_id=str(uuid.uuid4()), resource="experiment", request_type="POST"
) # noqa: F841
experiment_registration_form: ExperimentRegistrationForm = (
ExperimentRegistrationForm()
)

log.info("Request received")

if not experiment_registration_form.validate_on_submit():
log.error("Form validation failed")
raise ExperimentRegistrationError
parsed_obj = request.parsed_obj # type: ignore

log.info("Form validation successful")
experiment_registration_form_data: ExperimentRegistrationFormData = (
self._experiment_service.extract_data_from_form(
experiment_registration_form=experiment_registration_form, log=log
)
)
return self._experiment_service.create(
experiment_registration_form_data=experiment_registration_form_data, log=log
)
name = slugify(str(parsed_obj["name"]))

return self._experiment_service.create(experiment_name=name, log=log)


@api.route("/<int:experimentId>")
Expand Down Expand Up @@ -144,14 +123,14 @@ def delete(self, experimentId: int) -> Response:
return jsonify(dict(status="Success", id=id))

@login_required
@accepts(schema=ExperimentUpdateSchema, api=api)
@accepts(schema=ExperimentSchema, api=api)
@responds(schema=ExperimentSchema, api=api)
def put(self, experimentId: int) -> Experiment:
"""Modifies an experiment by its unique identifier."""
log: BoundLogger = LOGGER.new(
request_id=str(uuid.uuid4()), resource="experimentId", request_type="PUT"
) # noqa: F841
changes: ExperimentUpdateInterface = request.parsed_obj # type: ignore
changes: dict = request.parsed_obj # type: ignore
experiment: Optional[Experiment] = self._experiment_service.get_by_id(
experimentId, log=log
)
Expand Down Expand Up @@ -219,14 +198,14 @@ def delete(self, experimentName: str) -> Response:
return jsonify(dict(status="Success", id=id))

@login_required
@accepts(schema=ExperimentUpdateSchema, api=api)
@accepts(schema=ExperimentSchema, api=api)
@responds(schema=ExperimentSchema, api=api)
def put(self, experimentName: str) -> Experiment:
"""Modifies an experiment by its unique name."""
log: BoundLogger = LOGGER.new(
request_id=str(uuid.uuid4()), resource="experimentName", request_type="PUT"
) # noqa: F841
changes: ExperimentUpdateInterface = request.parsed_obj # type: ignore
changes: dict = request.parsed_obj # type: ignore
experiment: Optional[Experiment] = self._experiment_service.get_by_name(
experiment_name=experimentName, log=log
)
Expand Down
11 changes: 0 additions & 11 deletions src/dioptra/restapi/experiment/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,6 @@

from dioptra.restapi.shared.request_scope import request

from .schema import ExperimentRegistrationFormSchema


class ExperimentRegistrationFormSchemaModule(Module):
@provider
def provide_experiment_registration_form_schema_module(
self,
) -> ExperimentRegistrationFormSchema:
return ExperimentRegistrationFormSchema()


class MLFlowClientModule(Module):
@request
Expand All @@ -60,5 +50,4 @@ def register_providers(modules: List[Callable[..., Any]]) -> None:
modules: A list of callables used for configuring the dependency injection
environment.
"""
modules.append(ExperimentRegistrationFormSchemaModule)
modules.append(MLFlowClientModule)
53 changes: 0 additions & 53 deletions src/dioptra/restapi/experiment/interface.py

This file was deleted.

49 changes: 2 additions & 47 deletions src/dioptra/restapi/experiment/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,9 @@
from __future__ import annotations

import datetime

from flask_wtf import FlaskForm
from typing_extensions import TypedDict
from wtforms.fields import StringField
from wtforms.validators import InputRequired, ValidationError
from typing import Any, Dict

from dioptra.restapi.app import db
from dioptra.restapi.utils import slugify

from .interface import ExperimentUpdateInterface


class Experiment(db.Model):
Expand All @@ -53,7 +46,7 @@ class Experiment(db.Model):

jobs = db.relationship("Job", back_populates="experiment")

def update(self, changes: ExperimentUpdateInterface):
def update(self, changes: Dict[str, Any]):
"""Updates the record.
Args:
Expand All @@ -66,41 +59,3 @@ def update(self, changes: ExperimentUpdateInterface):
setattr(self, key, val)

return self


class ExperimentRegistrationForm(FlaskForm):
"""The experiment registration form.
Attributes:
name: The name to register as a new experiment.
"""

name = StringField("Name of Experiment", validators=[InputRequired()])

def validate_name(self, field: StringField) -> None:
"""Validates that the experiment does not exist in the registry.
Args:
field: The form field for `name`.
"""

standardized_name: str = slugify(field.data)

if (
Experiment.query.filter_by(name=standardized_name, is_deleted=False).first()
is not None
):
raise ValidationError(
"Bad Request - An experiment is already registered under "
f"the name {standardized_name}. Please select another and resubmit."
)


class ExperimentRegistrationFormData(TypedDict, total=False):
"""The data extracted from the experiment registration form.
Attributes:
name: The name of the experiment.
"""

name: str
67 changes: 2 additions & 65 deletions src/dioptra/restapi/experiment/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,10 @@
"""
from __future__ import annotations

from typing import Any, Dict
from marshmallow import Schema, fields

from marshmallow import Schema, fields, post_dump, pre_dump

from dioptra.restapi.utils import ParametersSchema, slugify

from .model import (
from .model import ( # ExperimentRegistrationForm,; ExperimentRegistrationFormData,
Experiment,
ExperimentRegistrationForm,
ExperimentRegistrationFormData,
)


Expand Down Expand Up @@ -65,60 +59,3 @@ class ExperimentSchema(Schema):
name = fields.String(
attribute="name", metadata=dict(description="The name of the experiment.")
)


class ExperimentUpdateSchema(Schema):
"""The schema for the data used to update an |Experiment| object.
Attributes:
name: The new name of the experiment. Must be unique.
"""

__model__ = Experiment

name = fields.String(
attribute="name",
metadata=dict(description="The new name of the experiment. Must be unique."),
)


class ExperimentRegistrationFormSchema(Schema):
"""The schema for the information stored in an experiment registration form.
Attributes:
name: The name of the experiment. Must be unique.
"""

__model__ = ExperimentRegistrationFormData

name = fields.String(
attribute="name",
required=True,
metadata=dict(description="The name of the experiment. Must be unique."),
)

@pre_dump
def extract_data_from_form(
self, data: ExperimentRegistrationForm, many: bool, **kwargs
) -> Dict[str, Any]:
"""Extracts data from the |ExperimentRegistrationForm| for validation."""

return {"name": slugify(data.name.data)}

@post_dump
def serialize_object(
self, data: Dict[str, Any], many: bool, **kwargs
) -> ExperimentRegistrationFormData:
"""Makes an |ExperimentRegistrationFormData| object from the validated data."""
return self.__model__(**data)


ExperimentRegistrationSchema: list[ParametersSchema] = [
dict(
name="name",
type=str,
location="form",
required=True,
help="The name of the experiment. Must be unique.",
)
]
23 changes: 2 additions & 21 deletions src/dioptra/restapi/experiment/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,7 @@
ExperimentMLFlowTrackingDoesNotExistError,
ExperimentMLFlowTrackingRegistrationError,
)
from .model import (
Experiment,
ExperimentRegistrationForm,
ExperimentRegistrationFormData,
)
from .schema import ExperimentRegistrationFormSchema
from .model import Experiment

LOGGER: BoundLogger = structlog.stdlib.get_logger()

Expand All @@ -49,19 +44,16 @@ class ExperimentService(object):
@inject
def __init__(
self,
experiment_registration_form_schema: ExperimentRegistrationFormSchema,
mlflow_tracking_service: MLFlowTrackingService,
) -> None:
self._experiment_registration_form_schema = experiment_registration_form_schema
self._mlflow_tracking_service = mlflow_tracking_service

def create(
self,
experiment_registration_form_data: ExperimentRegistrationFormData,
experiment_name: str,
**kwargs,
) -> Experiment:
log: BoundLogger = kwargs.get("log", LOGGER.new())
experiment_name: str = experiment_registration_form_data["name"]

if self.get_by_name(experiment_name, log=log) is not None:
raise ExperimentAlreadyExistsError
Expand Down Expand Up @@ -158,14 +150,3 @@ def get_by_name(experiment_name: str, **kwargs) -> Optional[Experiment]:
return Experiment.query.filter_by( # type: ignore
name=experiment_name, is_deleted=False
).first()

def extract_data_from_form(
self, experiment_registration_form: ExperimentRegistrationForm, **kwargs
) -> ExperimentRegistrationFormData:
log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841

data: ExperimentRegistrationFormData = (
self._experiment_registration_form_schema.dump(experiment_registration_form)
)

return data
8 changes: 1 addition & 7 deletions src/dioptra/restapi/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,7 @@
"""A module of reexports of the application's data models and forms."""
from __future__ import annotations

from .experiment.model import (
Experiment,
ExperimentRegistrationForm,
ExperimentRegistrationFormData,
)
from .experiment.model import Experiment
from .job.model import Job, JobForm, JobFormData
from .queue.model import Queue, QueueLock
from .task_plugin.model import (
Expand All @@ -33,8 +29,6 @@

__all__ = [
"Experiment",
"ExperimentRegistrationForm",
"ExperimentRegistrationFormData",
"Job",
"JobForm",
"JobFormData",
Expand Down
Loading

0 comments on commit 1e1b5ab

Please sign in to comment.