Skip to content

Commit

Permalink
refactor: remove dependency on wtforms from the experiment endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
alexb1200 authored Dec 13, 2023
1 parent bab6efb commit 491d324
Show file tree
Hide file tree
Showing 14 changed files with 32 additions and 274 deletions.
2 changes: 1 addition & 1 deletion examples/scripts/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def register_experiment(self, name: str) -> dict[str, Any]:

response = requests.post(
self.experiment_endpoint,
data=experiment_registration_form,
json=experiment_registration_form,
)

return response.json()
Expand Down
2 changes: 1 addition & 1 deletion src/dioptra/client/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ def register_experiment(self, name: str) -> dict[str, Any]:

response = requests.post(
self.experiment_endpoint,
data=experiment_registration_form,
json=experiment_registration_form,
)

return cast(dict[str, Any], response.json())
Expand Down
49 changes: 14 additions & 35 deletions src/dioptra/restapi/experiment/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,11 @@
from injector import inject
from structlog.stdlib import BoundLogger

from dioptra.restapi.utils import as_api_parser

from .errors import ExperimentDoesNotExistError, ExperimentRegistrationError
from .interface import ExperimentUpdateInterface
from .model import (
Experiment,
ExperimentRegistrationForm,
ExperimentRegistrationFormData,
)
from .schema import (
ExperimentRegistrationSchema,
ExperimentSchema,
ExperimentUpdateSchema,
)
from dioptra.restapi.utils import slugify

from .errors import ExperimentDoesNotExistError
from .model import Experiment
from .schema import ExperimentSchema
from .service import ExperimentService

LOGGER: BoundLogger = structlog.stdlib.get_logger()
Expand Down Expand Up @@ -73,33 +64,21 @@ def get(self) -> List[Experiment]:
return self._experiment_service.get_all(log=log)

@login_required
@api.expect(as_api_parser(api, ExperimentRegistrationSchema))
@accepts(ExperimentRegistrationSchema, api=api)
@accepts(schema=ExperimentSchema, api=api)
@responds(schema=ExperimentSchema, api=api)
def post(self) -> Experiment:
"""Creates a new experiment via an experiment registration form."""
log: BoundLogger = LOGGER.new(
request_id=str(uuid.uuid4()), resource="experiment", request_type="POST"
) # noqa: F841
experiment_registration_form: ExperimentRegistrationForm = (
ExperimentRegistrationForm()
)

log.info("Request received")

if not experiment_registration_form.validate_on_submit():
log.error("Form validation failed")
raise ExperimentRegistrationError
parsed_obj = request.parsed_obj # type: ignore

log.info("Form validation successful")
experiment_registration_form_data: ExperimentRegistrationFormData = (
self._experiment_service.extract_data_from_form(
experiment_registration_form=experiment_registration_form, log=log
)
)
return self._experiment_service.create(
experiment_registration_form_data=experiment_registration_form_data, log=log
)
name = slugify(str(parsed_obj["name"]))

return self._experiment_service.create(experiment_name=name, log=log)


@api.route("/<int:experimentId>")
Expand Down Expand Up @@ -144,14 +123,14 @@ def delete(self, experimentId: int) -> Response:
return jsonify(dict(status="Success", id=id))

@login_required
@accepts(schema=ExperimentUpdateSchema, api=api)
@accepts(schema=ExperimentSchema, api=api)
@responds(schema=ExperimentSchema, api=api)
def put(self, experimentId: int) -> Experiment:
"""Modifies an experiment by its unique identifier."""
log: BoundLogger = LOGGER.new(
request_id=str(uuid.uuid4()), resource="experimentId", request_type="PUT"
) # noqa: F841
changes: ExperimentUpdateInterface = request.parsed_obj # type: ignore
changes: dict = request.parsed_obj # type: ignore
experiment: Optional[Experiment] = self._experiment_service.get_by_id(
experimentId, log=log
)
Expand Down Expand Up @@ -219,14 +198,14 @@ def delete(self, experimentName: str) -> Response:
return jsonify(dict(status="Success", id=id))

@login_required
@accepts(schema=ExperimentUpdateSchema, api=api)
@accepts(schema=ExperimentSchema, api=api)
@responds(schema=ExperimentSchema, api=api)
def put(self, experimentName: str) -> Experiment:
"""Modifies an experiment by its unique name."""
log: BoundLogger = LOGGER.new(
request_id=str(uuid.uuid4()), resource="experimentName", request_type="PUT"
) # noqa: F841
changes: ExperimentUpdateInterface = request.parsed_obj # type: ignore
changes: dict = request.parsed_obj # type: ignore
experiment: Optional[Experiment] = self._experiment_service.get_by_name(
experiment_name=experimentName, log=log
)
Expand Down
11 changes: 0 additions & 11 deletions src/dioptra/restapi/experiment/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,6 @@

from dioptra.restapi.shared.request_scope import request

from .schema import ExperimentRegistrationFormSchema


class ExperimentRegistrationFormSchemaModule(Module):
@provider
def provide_experiment_registration_form_schema_module(
self,
) -> ExperimentRegistrationFormSchema:
return ExperimentRegistrationFormSchema()


class MLFlowClientModule(Module):
@request
Expand All @@ -60,5 +50,4 @@ def register_providers(modules: List[Callable[..., Any]]) -> None:
modules: A list of callables used for configuring the dependency injection
environment.
"""
modules.append(ExperimentRegistrationFormSchemaModule)
modules.append(MLFlowClientModule)
53 changes: 0 additions & 53 deletions src/dioptra/restapi/experiment/interface.py

This file was deleted.

49 changes: 2 additions & 47 deletions src/dioptra/restapi/experiment/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,9 @@
from __future__ import annotations

import datetime

from flask_wtf import FlaskForm
from typing_extensions import TypedDict
from wtforms.fields import StringField
from wtforms.validators import InputRequired, ValidationError
from typing import Any, Dict

from dioptra.restapi.app import db
from dioptra.restapi.utils import slugify

from .interface import ExperimentUpdateInterface


class Experiment(db.Model):
Expand All @@ -53,7 +46,7 @@ class Experiment(db.Model):

jobs = db.relationship("Job", back_populates="experiment")

def update(self, changes: ExperimentUpdateInterface):
def update(self, changes: Dict[str, Any]):
"""Updates the record.
Args:
Expand All @@ -66,41 +59,3 @@ def update(self, changes: ExperimentUpdateInterface):
setattr(self, key, val)

return self


class ExperimentRegistrationForm(FlaskForm):
"""The experiment registration form.
Attributes:
name: The name to register as a new experiment.
"""

name = StringField("Name of Experiment", validators=[InputRequired()])

def validate_name(self, field: StringField) -> None:
"""Validates that the experiment does not exist in the registry.
Args:
field: The form field for `name`.
"""

standardized_name: str = slugify(field.data)

if (
Experiment.query.filter_by(name=standardized_name, is_deleted=False).first()
is not None
):
raise ValidationError(
"Bad Request - An experiment is already registered under "
f"the name {standardized_name}. Please select another and resubmit."
)


class ExperimentRegistrationFormData(TypedDict, total=False):
"""The data extracted from the experiment registration form.
Attributes:
name: The name of the experiment.
"""

name: str
74 changes: 4 additions & 70 deletions src/dioptra/restapi/experiment/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,7 @@
"""
from __future__ import annotations

from typing import Any, Dict

from marshmallow import Schema, fields, post_dump, pre_dump

from dioptra.restapi.utils import ParametersSchema, slugify

from .model import (
Experiment,
ExperimentRegistrationForm,
ExperimentRegistrationFormData,
)
from marshmallow import Schema, fields


class ExperimentSchema(Schema):
Expand All @@ -46,79 +36,23 @@ class ExperimentSchema(Schema):
name: The name of the experiment.
"""

__model__ = Experiment

experimentId = fields.Integer(
attribute="experiment_id",
metadata=dict(description="An integer identifying a registered experiment."),
dump_only=True,
)
createdOn = fields.DateTime(
attribute="created_on",
metadata=dict(description="The date and time the experiment was created."),
dump_only=True,
)
lastModified = fields.DateTime(
attribute="last_modified",
metadata=dict(
description="The date and time the experiment was last modified."
),
dump_only=True,
)
name = fields.String(
attribute="name", metadata=dict(description="The name of the experiment.")
)


class ExperimentUpdateSchema(Schema):
"""The schema for the data used to update an |Experiment| object.
Attributes:
name: The new name of the experiment. Must be unique.
"""

__model__ = Experiment

name = fields.String(
attribute="name",
metadata=dict(description="The new name of the experiment. Must be unique."),
)


class ExperimentRegistrationFormSchema(Schema):
"""The schema for the information stored in an experiment registration form.
Attributes:
name: The name of the experiment. Must be unique.
"""

__model__ = ExperimentRegistrationFormData

name = fields.String(
attribute="name",
required=True,
metadata=dict(description="The name of the experiment. Must be unique."),
)

@pre_dump
def extract_data_from_form(
self, data: ExperimentRegistrationForm, many: bool, **kwargs
) -> Dict[str, Any]:
"""Extracts data from the |ExperimentRegistrationForm| for validation."""

return {"name": slugify(data.name.data)}

@post_dump
def serialize_object(
self, data: Dict[str, Any], many: bool, **kwargs
) -> ExperimentRegistrationFormData:
"""Makes an |ExperimentRegistrationFormData| object from the validated data."""
return self.__model__(**data)


ExperimentRegistrationSchema: list[ParametersSchema] = [
dict(
name="name",
type=str,
location="form",
required=True,
help="The name of the experiment. Must be unique.",
)
]
Loading

0 comments on commit 491d324

Please sign in to comment.