Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor Experiment endpoint and remove wtforms dependency #328

Merged
merged 3 commits into from
Dec 13, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/scripts/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def register_experiment(self, name: str) -> dict[str, Any]:

response = requests.post(
self.experiment_endpoint,
data=experiment_registration_form,
json=experiment_registration_form,
)

return response.json()
Expand Down
2 changes: 1 addition & 1 deletion src/dioptra/client/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,7 +580,7 @@ def register_experiment(self, name: str) -> dict[str, Any]:

response = requests.post(
self.experiment_endpoint,
data=experiment_registration_form,
json=experiment_registration_form,
)

return cast(dict[str, Any], response.json())
Expand Down
49 changes: 14 additions & 35 deletions src/dioptra/restapi/experiment/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,20 +29,11 @@
from injector import inject
from structlog.stdlib import BoundLogger

from dioptra.restapi.utils import as_api_parser

from .errors import ExperimentDoesNotExistError, ExperimentRegistrationError
from .interface import ExperimentUpdateInterface
from .model import (
Experiment,
ExperimentRegistrationForm,
ExperimentRegistrationFormData,
)
from .schema import (
ExperimentRegistrationSchema,
ExperimentSchema,
ExperimentUpdateSchema,
)
from dioptra.restapi.utils import slugify

from .errors import ExperimentDoesNotExistError
from .model import Experiment
from .schema import ExperimentSchema
from .service import ExperimentService

LOGGER: BoundLogger = structlog.stdlib.get_logger()
Expand Down Expand Up @@ -73,33 +64,21 @@ def get(self) -> List[Experiment]:
return self._experiment_service.get_all(log=log)

@login_required
@api.expect(as_api_parser(api, ExperimentRegistrationSchema))
@accepts(ExperimentRegistrationSchema, api=api)
@accepts(schema=ExperimentSchema, api=api)
@responds(schema=ExperimentSchema, api=api)
def post(self) -> Experiment:
"""Creates a new experiment via an experiment registration form."""
log: BoundLogger = LOGGER.new(
request_id=str(uuid.uuid4()), resource="experiment", request_type="POST"
) # noqa: F841
experiment_registration_form: ExperimentRegistrationForm = (
ExperimentRegistrationForm()
)

log.info("Request received")

if not experiment_registration_form.validate_on_submit():
log.error("Form validation failed")
raise ExperimentRegistrationError
parsed_obj = request.parsed_obj # type: ignore

log.info("Form validation successful")
experiment_registration_form_data: ExperimentRegistrationFormData = (
self._experiment_service.extract_data_from_form(
experiment_registration_form=experiment_registration_form, log=log
)
)
return self._experiment_service.create(
experiment_registration_form_data=experiment_registration_form_data, log=log
)
name = slugify(str(parsed_obj["name"]))

return self._experiment_service.create(experiment_name=name, log=log)


@api.route("/<int:experimentId>")
Expand Down Expand Up @@ -144,14 +123,14 @@ def delete(self, experimentId: int) -> Response:
return jsonify(dict(status="Success", id=id))

@login_required
@accepts(schema=ExperimentUpdateSchema, api=api)
@accepts(schema=ExperimentSchema, api=api)
@responds(schema=ExperimentSchema, api=api)
def put(self, experimentId: int) -> Experiment:
"""Modifies an experiment by its unique identifier."""
log: BoundLogger = LOGGER.new(
request_id=str(uuid.uuid4()), resource="experimentId", request_type="PUT"
) # noqa: F841
changes: ExperimentUpdateInterface = request.parsed_obj # type: ignore
changes: dict = request.parsed_obj # type: ignore
experiment: Optional[Experiment] = self._experiment_service.get_by_id(
experimentId, log=log
)
Expand Down Expand Up @@ -219,14 +198,14 @@ def delete(self, experimentName: str) -> Response:
return jsonify(dict(status="Success", id=id))

@login_required
@accepts(schema=ExperimentUpdateSchema, api=api)
@accepts(schema=ExperimentSchema, api=api)
@responds(schema=ExperimentSchema, api=api)
def put(self, experimentName: str) -> Experiment:
"""Modifies an experiment by its unique name."""
log: BoundLogger = LOGGER.new(
request_id=str(uuid.uuid4()), resource="experimentName", request_type="PUT"
) # noqa: F841
changes: ExperimentUpdateInterface = request.parsed_obj # type: ignore
changes: dict = request.parsed_obj # type: ignore
experiment: Optional[Experiment] = self._experiment_service.get_by_name(
experiment_name=experimentName, log=log
)
Expand Down
11 changes: 0 additions & 11 deletions src/dioptra/restapi/experiment/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,16 +24,6 @@

from dioptra.restapi.shared.request_scope import request

from .schema import ExperimentRegistrationFormSchema


class ExperimentRegistrationFormSchemaModule(Module):
@provider
def provide_experiment_registration_form_schema_module(
self,
) -> ExperimentRegistrationFormSchema:
return ExperimentRegistrationFormSchema()


class MLFlowClientModule(Module):
@request
Expand All @@ -60,5 +50,4 @@ def register_providers(modules: List[Callable[..., Any]]) -> None:
modules: A list of callables used for configuring the dependency injection
environment.
"""
modules.append(ExperimentRegistrationFormSchemaModule)
modules.append(MLFlowClientModule)
53 changes: 0 additions & 53 deletions src/dioptra/restapi/experiment/interface.py

This file was deleted.

49 changes: 2 additions & 47 deletions src/dioptra/restapi/experiment/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,16 +18,9 @@
from __future__ import annotations

import datetime

from flask_wtf import FlaskForm
from typing_extensions import TypedDict
from wtforms.fields import StringField
from wtforms.validators import InputRequired, ValidationError
from typing import Any, Dict

from dioptra.restapi.app import db
from dioptra.restapi.utils import slugify

from .interface import ExperimentUpdateInterface


class Experiment(db.Model):
Expand All @@ -53,7 +46,7 @@ class Experiment(db.Model):

jobs = db.relationship("Job", back_populates="experiment")

def update(self, changes: ExperimentUpdateInterface):
def update(self, changes: Dict[str, Any]):
"""Updates the record.

Args:
Expand All @@ -66,41 +59,3 @@ def update(self, changes: ExperimentUpdateInterface):
setattr(self, key, val)

return self


class ExperimentRegistrationForm(FlaskForm):
"""The experiment registration form.

Attributes:
name: The name to register as a new experiment.
"""

name = StringField("Name of Experiment", validators=[InputRequired()])

def validate_name(self, field: StringField) -> None:
"""Validates that the experiment does not exist in the registry.

Args:
field: The form field for `name`.
"""

standardized_name: str = slugify(field.data)

if (
Experiment.query.filter_by(name=standardized_name, is_deleted=False).first()
is not None
):
raise ValidationError(
"Bad Request - An experiment is already registered under "
f"the name {standardized_name}. Please select another and resubmit."
)


class ExperimentRegistrationFormData(TypedDict, total=False):
"""The data extracted from the experiment registration form.

Attributes:
name: The name of the experiment.
"""

name: str
74 changes: 4 additions & 70 deletions src/dioptra/restapi/experiment/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,7 @@
"""
from __future__ import annotations

from typing import Any, Dict

from marshmallow import Schema, fields, post_dump, pre_dump

from dioptra.restapi.utils import ParametersSchema, slugify

from .model import (
Experiment,
ExperimentRegistrationForm,
ExperimentRegistrationFormData,
)
from marshmallow import Schema, fields


class ExperimentSchema(Schema):
Expand All @@ -46,79 +36,23 @@ class ExperimentSchema(Schema):
name: The name of the experiment.
"""

__model__ = Experiment

experimentId = fields.Integer(
attribute="experiment_id",
metadata=dict(description="An integer identifying a registered experiment."),
dump_only=True,
)
createdOn = fields.DateTime(
attribute="created_on",
metadata=dict(description="The date and time the experiment was created."),
dump_only=True,
)
lastModified = fields.DateTime(
attribute="last_modified",
metadata=dict(
description="The date and time the experiment was last modified."
),
dump_only=True,
)
name = fields.String(
jkglasbrenner marked this conversation as resolved.
Show resolved Hide resolved
attribute="name", metadata=dict(description="The name of the experiment.")
)


class ExperimentUpdateSchema(Schema):
"""The schema for the data used to update an |Experiment| object.

Attributes:
name: The new name of the experiment. Must be unique.
"""

__model__ = Experiment

name = fields.String(
attribute="name",
metadata=dict(description="The new name of the experiment. Must be unique."),
)


class ExperimentRegistrationFormSchema(Schema):
"""The schema for the information stored in an experiment registration form.

Attributes:
name: The name of the experiment. Must be unique.
"""

__model__ = ExperimentRegistrationFormData

name = fields.String(
attribute="name",
required=True,
metadata=dict(description="The name of the experiment. Must be unique."),
)

@pre_dump
def extract_data_from_form(
self, data: ExperimentRegistrationForm, many: bool, **kwargs
) -> Dict[str, Any]:
"""Extracts data from the |ExperimentRegistrationForm| for validation."""

return {"name": slugify(data.name.data)}

@post_dump
def serialize_object(
self, data: Dict[str, Any], many: bool, **kwargs
) -> ExperimentRegistrationFormData:
"""Makes an |ExperimentRegistrationFormData| object from the validated data."""
return self.__model__(**data)


ExperimentRegistrationSchema: list[ParametersSchema] = [
dict(
name="name",
type=str,
location="form",
required=True,
help="The name of the experiment. Must be unique.",
)
]
Loading