Skip to content

Commit

Permalink
refactor: remove dependency on wtforms from the job endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
abyrne authored and jkglasbrenner committed Dec 14, 2023
1 parent ed91431 commit 6376fa5
Show file tree
Hide file tree
Showing 15 changed files with 194 additions and 692 deletions.
10 changes: 5 additions & 5 deletions examples/scripts/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,18 +170,18 @@ def submit_job(
queue: str = "tensorflow_cpu",
timeout: str = "24h",
) -> dict[str, Any]:
job_form = {
"experiment_name": experiment_name,
job_form: dict[str, Any] = {
"experimentName": experiment_name,
"queue": queue,
"timeout": timeout,
"entry_point": entry_point,
"entryPoint": entry_point,
}

if entry_point_kwargs is not None:
job_form["entry_point_kwargs"] = entry_point_kwargs
job_form["entryPointKwargs"] = entry_point_kwargs

if depends_on is not None:
job_form["depends_on"] = depends_on
job_form["dependsOn"] = depends_on

workflows_file = Path(workflows_file)

Expand Down
10 changes: 5 additions & 5 deletions src/dioptra/client/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -667,18 +667,18 @@ def submit_job(
See https://pages.nist.gov/dioptra/user-guide/api-reference-restapi.html
for more information on Dioptra's REST api.
"""
job_form = {
"experiment_name": experiment_name,
job_form: dict[str, Any] = {
"experimentName": experiment_name,
"queue": queue,
"timeout": timeout,
"entry_point": entry_point,
"entryPoint": entry_point,
}

if entry_point_kwargs is not None:
job_form["entry_point_kwargs"] = entry_point_kwargs
job_form["entryPointKwargs"] = entry_point_kwargs

if depends_on is not None:
job_form["depends_on"] = depends_on
job_form["dependsOn"] = depends_on

workflows_file = Path(workflows_file)

Expand Down
38 changes: 21 additions & 17 deletions src/dioptra/restapi/job/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,17 +22,18 @@

import flask
import structlog
from flask import request
from flask_accepts import accepts, responds
from flask_login import login_required
from flask_restx import Namespace, Resource
from injector import inject
from structlog.stdlib import BoundLogger

from dioptra.restapi.utils import as_api_parser
from dioptra.restapi.utils import as_api_parser, as_parameters_schema_list

from .errors import JobDoesNotExistError, JobSubmissionError
from .model import Job, JobForm, JobFormData
from .schema import JobSchema, TaskEngineSubmission, job_submit_form_schema
from .errors import JobDoesNotExistError
from .model import Job
from .schema import JobSchema, TaskEngineSubmission
from .service import JobService

LOGGER: BoundLogger = structlog.stdlib.get_logger()
Expand Down Expand Up @@ -68,28 +69,31 @@ def get(self) -> List[Job]:
return self._job_service.get_all(log=log)

@login_required
@api.expect(as_api_parser(api, job_submit_form_schema))
@accepts(job_submit_form_schema, api=api)
@api.expect(
as_api_parser(
api,
as_parameters_schema_list(JobSchema, operation="load", location="form"),
)
)
@accepts(form_schema=JobSchema, api=api)
@responds(schema=JobSchema, api=api)
def post(self) -> Job:
"""Creates a new job via a job submission form with an attached file."""
log: BoundLogger = LOGGER.new(
request_id=str(uuid.uuid4()), resource="job", request_type="POST"
) # noqa: F841
job_form: JobForm = JobForm()

parsed_obj = request.parsed_form # type: ignore
log.info("Request received")

if not job_form.validate_on_submit():
log.error("Form validation failed")
raise JobSubmissionError

log.info("Form validation successful")
job_form_data: JobFormData = self._job_service.extract_data_from_form(
job_form=job_form,
return self._job_service.submit(
queue_name=parsed_obj["queue"],
experiment_name=parsed_obj["experiment_name"],
timeout=parsed_obj["timeout"],
entry_point=parsed_obj["entry_point"],
entry_point_kwargs=parsed_obj["entry_point_kwargs"],
depends_on=parsed_obj["depends_on"],
workflow=request.files["workflow"],
log=log,
)
return self._job_service.submit(job_form_data=job_form_data, log=log)


@api.route("/<string:jobId>")
Expand Down
9 changes: 0 additions & 9 deletions src/dioptra/restapi/job/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,6 @@
from dioptra.restapi.shared.request_scope import request
from dioptra.restapi.shared.rq.service import RQService

from .schema import JobFormSchema


class JobFormSchemaModule(Module):
@provider
def provide_job_form_schema_module(self) -> JobFormSchema:
return JobFormSchema()


@dataclass
class RQServiceConfiguration(object):
Expand Down Expand Up @@ -97,5 +89,4 @@ def register_providers(modules: List[Callable[..., Any]]) -> None:
modules: A list of callables used for configuring the dependency injection
environment.
"""
modules.append(JobFormSchemaModule)
modules.append(RQServiceModule)
74 changes: 0 additions & 74 deletions src/dioptra/restapi/job/interface.py

This file was deleted.

154 changes: 2 additions & 152 deletions src/dioptra/restapi/job/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,21 +18,9 @@
from __future__ import annotations

import datetime
from typing import Optional

from flask_wtf import FlaskForm
from flask_wtf.file import FileAllowed, FileField, FileRequired
from typing_extensions import TypedDict
from werkzeug.datastructures import FileStorage
from wtforms.fields import StringField
from wtforms.validators import UUID, InputRequired
from wtforms.validators import Optional as OptionalField
from wtforms.validators import Regexp, ValidationError
from typing import Any

from dioptra.restapi.app import db
from dioptra.restapi.utils import slugify

from .interface import JobUpdateInterface

job_statuses = db.Table(
"job_statuses", db.Column("status", db.String(255), primary_key=True)
Expand Down Expand Up @@ -89,7 +77,7 @@ class Job(db.Model):
experiment = db.relationship("Experiment", back_populates="jobs")
queue = db.relationship("Queue", back_populates="jobs")

def update(self, changes: JobUpdateInterface):
def update(self, changes: dict[str, Any]):
"""Updates the record.
Args:
Expand All @@ -102,141 +90,3 @@ def update(self, changes: JobUpdateInterface):
setattr(self, key, val)

return self


class JobForm(FlaskForm):
"""The job submission form.
Attributes:
experiment_name: The name of a registered experiment.
queue: The name of an active queue.
timeout: The maximum alloted time for a job before it times out and is stopped.
If omitted, the job timeout will default to 24 hours.
entry_point: The name of the entry point in the MLproject file to run.
entry_point_kwargs: A list of entry point parameter values to use for the job.
The list is a string with the following format: `-P param1=value1
-P param2=value2`. If omitted, the default values in the MLproject file will
be used.
depends_on: A job UUID to set as a dependency for this new job. The new job will
not run until this job completes successfully. If omitted, then the new job
will start as soon as computing resources are available.
workflow: A tarball archive or zip file containing, at a minimum, a MLproject
file and its associated entry point scripts.
"""

experiment_name = StringField(
"Name of Experiment",
validators=[InputRequired()],
description="The name of a registered experiment.",
)
queue = StringField(
"Queue", validators=[InputRequired()], description="The name of an active queue"
)
timeout = StringField(
"Job Timeout",
validators=[OptionalField(), Regexp(r"\d+?[dhms]")],
description="The maximum alloted time for a job before it times out and "
"is stopped. If omitted, the job timeout will default to 24 hours.",
)
entry_point = StringField(
"MLproject Entry Point",
validators=[InputRequired()],
description="The name of the entry point in the MLproject file to run.",
)
entry_point_kwargs = StringField(
"MLproject Parameter Overrides",
validators=[OptionalField()],
description="A list of entry point parameter values to use for the job. The "
'list is a string with the following format: "-P param1=value1 '
'-P param2=value2". If omitted, the default values in the MLproject file will '
"be used.",
)
depends_on = StringField(
"Job Dependency",
validators=[OptionalField(), UUID()],
description="A job UUID to set as a dependency for this new job. The new job "
"will not run until this job completes successfully. If omitted, then the new "
"job will start as soon as computing resources are available.",
)
workflow = FileField(
validators=[
FileRequired(),
FileAllowed(["tar", "tgz", "bz2", "gz", "xz", "zip"]),
],
description="A tarball archive or zip file containing, at a minimum, a "
"MLproject file and its associated entry point scripts.",
)

def validate_experiment_name(self, field: StringField) -> None:
"""Validates that the experiment is registered and not deleted.
Args:
field: The form field for `experiment_name`.
"""
from dioptra.restapi.models import Experiment

standardized_name: str = slugify(field.data)

if (
Experiment.query.filter_by(name=standardized_name, is_deleted=False).first()
is None
):
raise ValidationError(
f"Bad Request - The experiment {standardized_name} does not exist. "
"Please check spelling and resubmit."
)

def validate_queue(self, field: StringField) -> None:
"""Validates that the queue is registered, active and not deleted.
Args:
field: The form field for `queue`.
"""
from dioptra.restapi.models import Queue, QueueLock

standardized_name: str = slugify(field.data)
queue: Optional[Queue] = (
Queue.query.outerjoin(QueueLock, Queue.queue_id == QueueLock.queue_id)
.filter(
Queue.name == standardized_name,
QueueLock.queue_id == None, # noqa: E711
Queue.is_deleted == False, # noqa: E712
)
.first()
)

if queue is None:
raise ValidationError(
f"Bad Request - The queue {standardized_name} is not valid. "
"Please check spelling and resubmit."
)


class JobFormData(TypedDict, total=False):
"""The data extracted from the job submission form.
Attributes:
experiment_id: An integer identifying the registered experiment.
experiment_name: The name of the registered experiment.
queue_id: An integer identifying a registered queue.
queue: The name of an active queue.
timeout: The maximum alloted time for a job before it times out and is stopped.
entry_point: The name of the entry point in the MLproject file to run.
entry_point_kwargs: A list of entry point parameter values to use for the job.
The list is a string with the following format: `-P param1=value1
-P param2=value2`.
depends_on: A job UUID to set as a dependency for this new job. The new job will
not run until this job completes successfully.
workflow: A tarball archive or zip file containing, at a minimum, a MLproject
file and its associated entry point scripts.
"""

experiment_id: int
experiment_name: str
queue_id: int
queue: str
timeout: Optional[str]
entry_point: str
entry_point_kwargs: Optional[str]
depends_on: Optional[str]
workflow: FileStorage
Loading

0 comments on commit 6376fa5

Please sign in to comment.