diff --git a/src/dioptra/restapi/resource/__init__.py b/src/dioptra/restapi/resource/__init__.py new file mode 100644 index 000000000..7015b8853 --- /dev/null +++ b/src/dioptra/restapi/resource/__init__.py @@ -0,0 +1,28 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. 
+# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The job endpoint subpackage.""" + +from .dependencies import bind_dependencies, register_providers +from .errors import register_error_handlers +from .routes import register_routes + +__all__ = [ + "bind_dependencies", + "register_error_handlers", + "register_providers", + "register_routes", +] diff --git a/src/dioptra/restapi/resource/controller.py b/src/dioptra/restapi/resource/controller.py new file mode 100644 index 000000000..7068e8133 --- /dev/null +++ b/src/dioptra/restapi/resource/controller.py @@ -0,0 +1,114 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. 
+# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The module defining the job endpoints.""" +from __future__ import annotations + +import uuid +from typing import List, Optional + +import structlog +from flask_accepts import accepts, responds +from flask_restx import Namespace, Resource +from injector import inject +from structlog.stdlib import BoundLogger + +from dioptra.restapi.utils import as_api_parser + +from .errors import JobDoesNotExistError, JobSubmissionError +from .model import Job, JobForm, JobFormData +from .schema import JobSchema, job_submit_form_schema +from .service import JobService + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + +api: Namespace = Namespace( + "Job", + description="Job submission and management operations", +) + + +@api.route("/") +class JobResource(Resource): + """Shows a list of all jobs, and lets you POST to create new jobs.""" + + @inject + def __init__( + self, + *args, + job_service: JobService, + **kwargs, + ) -> None: + self._job_service = job_service + super().__init__(*args, **kwargs) + + @responds(schema=JobSchema(many=True), api=api) + def get(self) -> List[Job]: + """Gets a list of all submitted jobs.""" + log: BoundLogger = LOGGER.new( + request_id=str(uuid.uuid4()), resource="job", request_type="GET" + ) # noqa: F841 + log.info("Request received") + return self._job_service.get_all(log=log) + + @api.expect(as_api_parser(api, job_submit_form_schema)) + @accepts(job_submit_form_schema, api=api) + @responds(schema=JobSchema, api=api) + def post(self) -> Job: + """Creates a new job via a job submission form with an attached file.""" + log: BoundLogger = LOGGER.new( + request_id=str(uuid.uuid4()), resource="job", request_type="POST" + ) # noqa: F841 + job_form: JobForm = JobForm() + + log.info("Request received") + + if not job_form.validate_on_submit(): + log.error("Form validation failed") + raise JobSubmissionError + + log.info("Form validation 
successful") + job_form_data: JobFormData = self._job_service.extract_data_from_form( + job_form=job_form, + log=log, + ) + return self._job_service.submit(job_form_data=job_form_data, log=log) + + +@api.route("/") +@api.param("jobId", "A string specifying a job's UUID.") +class JobIdResource(Resource): + """Shows a single job.""" + + @inject + def __init__(self, *args, job_service: JobService, **kwargs) -> None: + self._job_service = job_service + super().__init__(*args, **kwargs) + + @responds(schema=JobSchema, api=api) + def get(self, jobId: str) -> Job: + """Gets a job by its unique identifier.""" + log: BoundLogger = LOGGER.new( + request_id=str(uuid.uuid4()), resource="jobId", request_type="GET" + ) # noqa: F841 + log.info("Request received", job_id=jobId) + job: Optional[Job] = self._job_service.get_by_id(jobId, log=log) + + if job is None: + log.error("Job not found", job_id=jobId) + raise JobDoesNotExistError + + return job diff --git a/src/dioptra/restapi/resource/dependencies.py b/src/dioptra/restapi/resource/dependencies.py new file mode 100644 index 000000000..6bcb3009e --- /dev/null +++ b/src/dioptra/restapi/resource/dependencies.py @@ -0,0 +1,96 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. 
To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Binding configurations to shared services using dependency injection.""" +from __future__ import annotations + +import os +from dataclasses import dataclass +from typing import Any, Callable, List, Optional + +from boto3.session import Session +from botocore.client import BaseClient +from injector import Binder, Module, provider +from redis import Redis + +from dioptra.restapi.shared.request_scope import request +from dioptra.restapi.shared.rq.service import RQService + +from .schema import JobFormSchema + + +class JobFormSchemaModule(Module): + @provider + def provide_job_form_schema_module(self) -> JobFormSchema: + return JobFormSchema() + + +@dataclass +class RQServiceConfiguration(object): + redis: Redis + run_mlflow: str + + +class RQServiceModule(Module): + @request + @provider + def provide_rq_service_module( + self, configuration: RQServiceConfiguration + ) -> RQService: + return RQService(redis=configuration.redis, run_mlflow=configuration.run_mlflow) + + +def _bind_rq_service_configuration(binder: Binder): + redis_conn: Redis = Redis.from_url(os.getenv("RQ_REDIS_URI", "redis://")) + run_mlflow: str = "dioptra.rq.tasks.run_mlflow_task" + + configuration: RQServiceConfiguration = RQServiceConfiguration( + redis=redis_conn, + run_mlflow=run_mlflow, + ) + + binder.bind(RQServiceConfiguration, to=configuration, scope=request) + + +def _bind_s3_service_configuration(binder: Binder) -> None: + s3_endpoint_url: Optional[str] = os.getenv("MLFLOW_S3_ENDPOINT_URL") + + s3_session: Session = Session() + s3_client: BaseClient = s3_session.client("s3", endpoint_url=s3_endpoint_url) + + 
binder.bind(Session, to=s3_session, scope=request) + binder.bind(BaseClient, to=s3_client, scope=request) + + +def bind_dependencies(binder: Binder) -> None: + """Binds interfaces to implementations within the main application. + + Args: + binder: A :py:class:`~injector.Binder` object. + """ + _bind_rq_service_configuration(binder) + _bind_s3_service_configuration(binder) + + +def register_providers(modules: List[Callable[..., Any]]) -> None: + """Registers type providers within the main application. + + Args: + modules: A list of callables used for configuring the dependency injection + environment. + """ + modules.append(JobFormSchemaModule) + modules.append(RQServiceModule) diff --git a/src/dioptra/restapi/resource/errors.py b/src/dioptra/restapi/resource/errors.py new file mode 100644 index 000000000..e053c17e8 --- /dev/null +++ b/src/dioptra/restapi/resource/errors.py @@ -0,0 +1,59 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. 
#
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
# https://creativecommons.org/licenses/by/4.0/legalcode
"""Error handlers for the job endpoints."""
from __future__ import annotations

from flask_restx import Api


class JobDoesNotExistError(Exception):
    """The requested job does not exist."""


class JobSubmissionError(Exception):
    """The job submission form contains invalid parameters."""


class JobWorkflowUploadError(Exception):
    """The service for storing the uploaded workflow file is unavailable."""


def register_error_handlers(api: Api) -> None:
    """Registers a handler on `api` for each exception defined in this module.

    Every handler ignores the exception instance and returns a
    ``(payload, status_code)`` pair whose payload has a single ``"message"`` key.

    Args:
        api: The object whose ``errorhandler`` decorator is used for registration.
    """

    @api.errorhandler(JobDoesNotExistError)
    def handle_job_does_not_exist_error(error):
        payload = {"message": "Not Found - The requested job does not exist"}
        return payload, 404

    @api.errorhandler(JobSubmissionError)
    def handle_job_submission_error(error):
        message = (
            "Bad Request - The job submission form contains "
            "invalid parameters. Please verify and resubmit."
        )
        return {"message": message}, 400

    @api.errorhandler(JobWorkflowUploadError)
    def handle_job_workflow_upload_error(error):
        message = (
            "Service Unavailable - Unable to store the "
            "workflow file after upload. Please try again "
            "later."
        )
        return {"message": message}, 503


# diff --git a/src/dioptra/restapi/resource/interface.py b/src/dioptra/restapi/resource/interface.py
# new file mode 100644
# index 000000000..f4ea82804
# --- /dev/null
# +++ b/src/dioptra/restapi/resource/interface.py
# @@ -0,0 +1,74 @@
# This Software (Dioptra) is being made available as a public service by the
# National Institute of Standards and Technology (NIST), an Agency of the United
# States Department of Commerce. This software was developed in part by employees of
# NIST and in part by NIST contractors. Copyright in portions of this software that
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
# to Title 17 United States Code Section 105, works of NIST employees are not
# subject to copyright protection in the United States.
However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The interfaces for creating and updating |Job| objects. + +.. |Job| replace:: :py:class:`~.model.Job` +""" +from __future__ import annotations + +import datetime +from typing import Optional + +from typing_extensions import TypedDict + + +class JobInterface(TypedDict, total=False): + """The interface for constructing a new |Job| object. + + Attributes: + job_id: A UUID that identifies the job. + mlflow_run_id: A UUID that identifies the MLFlow run associated with the job. + experiment_id: An integer identifying a registered experiment. + queue_id: An integer identifying a registered queue. + created_on: The date and time the job was created. + last_modified: The date and time the job was last modified. + timeout: The maximum alloted time for a job before it times out and is stopped. + workflow_uri: The URI pointing to the tarball archive or zip file uploaded with + the job. + entry_point: The name of the entry point in the MLproject file to run. + entry_point_kwargs: A string listing parameter values to pass to the entry point + for the job. The list of parameters is specified using the following format: + `-P param1=value1 -P param2=value2`. + status: The current status of the job. The allowed values are: `queued`, + `started`, `deferred`, `finished`, `failed`. + depends_on: A UUID for a previously submitted job to set as a dependency for the + current job. 
+ """ + + job_id: str + mlflow_run_id: Optional[str] + experiment_id: int + queue_id: int + created_on: datetime.datetime + last_modified: datetime.datetime + timeout: Optional[str] + workflow_uri: str + entry_point: str + entry_point_kwargs: Optional[str] + status: str + depends_on: Optional[str] + + +class JobUpdateInterface(TypedDict, total=False): + """The interface for updating a |Job| object. + + Attributes: + status: The current status of the job. The allowed values are: `queued`, + `started`, `deferred`, `finished`, `failed`. + """ + + status: str diff --git a/src/dioptra/restapi/resource/model.py b/src/dioptra/restapi/resource/model.py new file mode 100644 index 000000000..ce0ef43f1 --- /dev/null +++ b/src/dioptra/restapi/resource/model.py @@ -0,0 +1,248 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. 
+# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The data models for the job endpoint objects.""" +from __future__ import annotations + +import datetime +from typing import Optional + +from flask_wtf import FlaskForm +from flask_wtf.file import FileAllowed, FileField, FileRequired +from typing_extensions import TypedDict +from werkzeug.datastructures import FileStorage +from wtforms.fields import StringField +from wtforms.validators import UUID, InputRequired +from wtforms.validators import Optional as OptionalField +from wtforms.validators import Regexp, ValidationError + +from dioptra.restapi.app import db +from dioptra.restapi.utils import slugify + +from .interface import JobUpdateInterface + +from ..SharedResource import SharedResource + +job_statuses = db.Table( + "job_statuses", db.Column("status", db.String(255), primary_key=True) +) + + +class Resource(db.Model): + """The Resource table. + + Attributes: + resource_id: A UUID that identifies the resource. + creator_id: A UUID that identifies the user that created the resource. + owner_id: A UUID that identifies the group that owns the resource. + created_on: The date and time the resource was created. + last_modified: The date and time the resource was last modified. 
+ """ + + __tablename__ = "Resources" + + resource_id = db.Column(db.String(36), primary_key=True) + """A UUID that identifies the Resource.""" + + creator_id= db.Column(db.BigInteger(), db.ForeignKey("users.user_id"), index= True) + owner_id= db.Column(db.BigInteger(), db.ForeignKey("groups.group_id"), index= True) + created_on = db.Column(db.DateTime()) + last_modified = db.Column(db.DateTime()) + deleted = db.Column(db.Boolean) + + + creator = db.relationship('User', foreign_keys=[creator_id]) + owner = db.relationship('Group', foreign_keys=[owner_id]) + + + @property + def shares(self): + """List of groups that the resource is shared with.""" + # Define a relationship with SharedPrototypeResource, if needed + return SharedResource.query.filter_by(resource_id=self.resource_id).all() + + def check_permission(self, user: User, action: str) -> bool: + """Check if the user has permission to perform the specified action. + + Args: + user: The user to check. + action: The action to check. + + Returns: + True if the user has permission to perform the action, False otherwise. + """ + + membership = GroupMemberships.query.filter_by(user_id = user.user_id) + #next((x for x in self.owner.users if x.user_id == user.id), None) + + if membership is None: + return False + + return cast(bool, getattr(membership, action)) + + + + def update(self, changes: JobUpdateInterface): + """Updates the record. + + Args: + changes: A :py:class:`~.interface.JobUpdateInterface` dictionary containing + record updates. + """ + self.last_modified = datetime.datetime.now() + + for key, val in changes.items(): + setattr(self, key, val) + + return self + + +class JobForm(FlaskForm): + """The job submission form. + + Attributes: + experiment_name: The name of a registered experiment. + queue: The name of an active queue. + timeout: The maximum alloted time for a job before it times out and is stopped. + If omitted, the job timeout will default to 24 hours. 
+ entry_point: The name of the entry point in the MLproject file to run. + entry_point_kwargs: A list of entry point parameter values to use for the job. + The list is a string with the following format: `-P param1=value1 + -P param2=value2`. If omitted, the default values in the MLproject file will + be used. + depends_on: A job UUID to set as a dependency for this new job. The new job will + not run until this job completes successfully. If omitted, then the new job + will start as soon as computing resources are available. + workflow: A tarball archive or zip file containing, at a minimum, a MLproject + file and its associated entry point scripts. + """ + + experiment_name = StringField( + "Name of Experiment", + validators=[InputRequired()], + description="The name of a registered experiment.", + ) + queue = StringField( + "Queue", validators=[InputRequired()], description="The name of an active queue" + ) + timeout = StringField( + "Job Timeout", + validators=[OptionalField(), Regexp(r"\d+?[dhms]")], + description="The maximum alloted time for a job before it times out and " + "is stopped. If omitted, the job timeout will default to 24 hours.", + ) + entry_point = StringField( + "MLproject Entry Point", + validators=[InputRequired()], + description="The name of the entry point in the MLproject file to run.", + ) + entry_point_kwargs = StringField( + "MLproject Parameter Overrides", + validators=[OptionalField()], + description="A list of entry point parameter values to use for the job. The " + 'list is a string with the following format: "-P param1=value1 ' + '-P param2=value2". If omitted, the default values in the MLproject file will ' + "be used.", + ) + depends_on = StringField( + "Job Dependency", + validators=[OptionalField(), UUID()], + description="A job UUID to set as a dependency for this new job. The new job " + "will not run until this job completes successfully. 
If omitted, then the new " + "job will start as soon as computing resources are available.", + ) + workflow = FileField( + validators=[ + FileRequired(), + FileAllowed(["tar", "tgz", "bz2", "gz", "xz", "zip"]), + ], + description="A tarball archive or zip file containing, at a minimum, a " + "MLproject file and its associated entry point scripts.", + ) + + def validate_experiment_name(self, field: StringField) -> None: + """Validates that the experiment is registered and not deleted. + + Args: + field: The form field for `experiment_name`. + """ + from dioptra.restapi.models import Experiment + + standardized_name: str = slugify(field.data) + + if ( + Experiment.query.filter_by(name=standardized_name, is_deleted=False).first() + is None + ): + raise ValidationError( + f"Bad Request - The experiment {standardized_name} does not exist. " + "Please check spelling and resubmit." + ) + + def validate_queue(self, field: StringField) -> None: + """Validates that the queue is registered, active and not deleted. + + Args: + field: The form field for `queue`. + """ + from dioptra.restapi.models import Queue, QueueLock + + standardized_name: str = slugify(field.data) + queue: Optional[Queue] = ( + Queue.query.outerjoin(QueueLock, Queue.queue_id == QueueLock.queue_id) + .filter( + Queue.name == standardized_name, + QueueLock.queue_id == None, # noqa: E711 + Queue.is_deleted == False, # noqa: E712 + ) + .first() + ) + + if queue is None: + raise ValidationError( + f"Bad Request - The queue {standardized_name} is not valid. " + "Please check spelling and resubmit." + ) + + +class JobFormData(TypedDict, total=False): + """The data extracted from the job submission form. + + Attributes: + experiment_id: An integer identifying the registered experiment. + experiment_name: The name of the registered experiment. + queue_id: An integer identifying a registered queue. + queue: The name of an active queue. + timeout: The maximum alloted time for a job before it times out and is stopped. 
+ entry_point: The name of the entry point in the MLproject file to run. + entry_point_kwargs: A list of entry point parameter values to use for the job. + The list is a string with the following format: `-P param1=value1 + -P param2=value2`. + depends_on: A job UUID to set as a dependency for this new job. The new job will + not run until this job completes successfully. + workflow: A tarball archive or zip file containing, at a minimum, a MLproject + file and its associated entry point scripts. + """ + + experiment_id: int + experiment_name: str + queue_id: int + queue: str + timeout: Optional[str] + entry_point: str + entry_point_kwargs: Optional[str] + depends_on: Optional[str] + workflow: FileStorage diff --git a/src/dioptra/restapi/resource/routes.py b/src/dioptra/restapi/resource/routes.py new file mode 100644 index 000000000..247e2c850 --- /dev/null +++ b/src/dioptra/restapi/resource/routes.py @@ -0,0 +1,41 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. 
#
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
# https://creativecommons.org/licenses/by/4.0/legalcode
"""Methods for registering the job endpoint routes with the main application.

.. |Api| replace:: :py:class:`flask_restx.Api`
.. |Flask| replace:: :py:class:`flask.Flask`
"""
from __future__ import annotations

from flask import Flask
from flask_restx import Api

# The final path segment under which this package's namespace is mounted.
BASE_ROUTE: str = "job"


def register_routes(api: Api, app: Flask, root: str = "api") -> None:
    """Mounts this package's endpoint namespace on the main REST API.

    The namespace is added under the path ``/<root>/<BASE_ROUTE>``.

    Args:
        api: The main REST |Api| object.
        app: The main |Flask| application. It is accepted but not used here.
        root: The root path for the registration prefix of the namespace. The default
            is `"api"`.
    """
    # Imported inside the function, as in the original, so the controller module is
    # only loaded when routes are actually registered.
    from .controller import api as namespace

    mount_path: str = f"/{root}/{BASE_ROUTE}"
    api.add_namespace(namespace, path=mount_path)


# diff --git a/src/dioptra/restapi/resource/schema.py b/src/dioptra/restapi/resource/schema.py
# new file mode 100644
# index 000000000..84e4480d3
# --- /dev/null
# +++ b/src/dioptra/restapi/resource/schema.py
# @@ -0,0 +1,286 @@
# This Software (Dioptra) is being made available as a public service by the
# National Institute of Standards and Technology (NIST), an Agency of the United
# States Department of Commerce. This software was developed in part by employees of
# NIST and in part by NIST contractors. Copyright in portions of this software that
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
# to Title 17 United States Code Section 105, works of NIST employees are not
# subject to copyright protection in the United States. However, NIST may hold
# international copyright in software created by its employees and domestic
# copyright (or licensing rights) in portions of software that were assigned or
# licensed to NIST. To the extent that NIST holds copyright in this software, it is
# being made available under the Creative Commons Attribution 4.0 International
# license (CC BY 4.0).
The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The schemas for serializing/deserializing the job endpoint objects. + +.. |Job| replace:: :py:class:`~.model.Job` +.. |JobForm| replace:: :py:class:`~.model.JobForm` +.. |JobFormData| replace:: :py:class:`~.model.JobFormData` +""" +from __future__ import annotations + +from typing import Any, Dict + +from marshmallow import Schema, fields, post_dump, post_load, pre_dump, validate +from werkzeug.datastructures import FileStorage + +from dioptra.restapi.utils import ParametersSchema, slugify + +from .model import Job, JobForm, JobFormData + + +class JobSchema(Schema): + """The schema for the data stored in a |Job| object. + + Attributes: + jobId: A UUID that identifies the job. + mlflowRunId: A UUID that identifies the MLFlow run associated with the job. + experimentId: An integer identifying a registered experiment. + queueId: An integer identifying a registered queue. + createdOn: The date and time the job was created. + lastModified: The date and time the job was last modified. + timeout: The maximum alloted time for a job before it times out and is stopped. + workflowUri: The URI pointing to the tarball archive or zip file uploaded with + the job. + entryPoint: The name of the entry point in the MLproject file to run. + entryPointKwargs: A string listing parameter values to pass to the entry point + for the job. The list of parameters is specified using the following format: + `-P param1=value1 -P param2=value2`. + dependsOn: A UUID for a previously submitted job to set as a dependency for the + current job. + status: The current status of the job. The allowed values are: `queued`, + `started`, `deferred`, `finished`, `failed`. 
+ """ + + __model__ = Job + + jobId = fields.String( + attribute="job_id", metadata=dict(description="A UUID that identifies the job.") + ) + mlflowRunId = fields.String( + attribute="mlflow_run_id", + allow_none=True, + metadata=dict( + description="A UUID that identifies the MLFLow run associated with the " + "job.", + ), + ) + experimentId = fields.Integer( + attribute="experiment_id", + metadata=dict(description="An integer identifying a registered experiment."), + ) + queueId = fields.Integer( + attribute="queue_id", + metadata=dict(description="An integer identifying a registered queue."), + ) + createdOn = fields.DateTime( + attribute="created_on", + metadata=dict(description="The date and time the job was created."), + ) + lastModified = fields.DateTime( + attribute="last_modified", + metadata=dict(description="The date and time the job was last modified."), + ) + timeout = fields.String( + attribute="timeout", + allow_none=True, + metadata=dict( + description="The maximum alloted time for a job before it times out and " + "is stopped.", + ), + ) + workflowUri = fields.String( + attribute="workflow_uri", + metadata=dict( + description="The URI pointing to the tarball archive or zip file uploaded " + "with the job.", + ), + ) + entryPoint = fields.String( + attribute="entry_point", + metadata=dict( + description="The name of the entry point in the MLproject file to run.", + ), + ) + entryPointKwargs = fields.String( + attribute="entry_point_kwargs", + allow_none=True, + metadata=dict( + description="A string listing parameter values to pass to the entry point " + "for the job. 
The list of parameters is specified using the following " + 'format: "-P param1=value1 -P param2=value2".', + ), + ) + dependsOn = fields.String( + attribute="depends_on", + allow_none=True, + metadata=dict( + description="A UUID for a previously submitted job to set as a dependency " + "for the current job.", + ), + ) + status = fields.String( + validate=validate.OneOf( + ["queued", "started", "deferred", "finished", "failed"], + ), + metadata=dict( + description="The current status of the job. The allowed values are: " + "queued, started, deferred, finished, failed.", + ), + ) + + @post_load + def deserialize_object(self, data: Dict[str, Any], many: bool, **kwargs) -> Job: + """Creates a |Job| object from the validated data.""" + return self.__model__(**data) + + +class JobFormSchema(Schema): + """The schema for the information stored in a submitted job form. + + Attributes: + experiment_name: The name of a registered experiment. + queue: The name of an active queue. + timeout: The maximum alloted time for a job before it times out and is stopped. + If omitted, the job timeout will default to 24 hours. + entry_point: The name of the entry point in the MLproject file to run. + entry_point_kwargs: A list of entry point parameter values to use for the job. + The list is a string with the following format: `-P param1=value1 + -P param2=value2`. If omitted, the default values in the MLproject file will + be used. + depends_on: A job UUID to set as a dependency for this new job. The new job will + not run until this job completes successfully. If omitted, then the new job + will start as soon as computing resources are available. + workflow: A tarball archive or zip file containing, at a minimum, a MLproject + file and its associated entry point scripts. 
+ """ + + __model__ = JobFormData + + experiment_name = fields.String( + required=True, metadata=dict(description="The name of a registered experiment.") + ) + queue = fields.String( + required=True, metadata=dict(description="The name of an active queue") + ) + timeout = fields.String( + allow_none=True, + metadata=dict( + description="The maximum alloted time for a job before it times out and " + "is stopped. If omitted, the job timeout will default to 24 hours.", + ), + ) + entry_point = fields.String( + required=True, + metadata=dict( + description="The name of the entry point in the MLproject file to run.", + ), + ) + entry_point_kwargs = fields.String( + allow_none=True, + metadata=dict( + description="A list of entry point parameter values to use for the job. " + 'The list is a string with the following format: "-P param1=value1 ' + '-P param2=value2". If omitted, the default values in the MLproject ' + "file will be used.", + ), + ) + depends_on = fields.String( + allow_none=True, + metadata=dict( + description="A job UUID to set as a dependency for this new job. The new " + "job will not run until this job completes successfully. 
If omitted, then " + "the new job will start as soon as computing resources are available.", + ), + ) + workflow = fields.Raw( + metadata=dict( + description="A tarball archive or zip file containing, at a minimum, a " + "MLproject file and its associated entry point scripts.", + ), + ) + + @pre_dump + def extract_data_from_form( + self, data: JobForm, many: bool, **kwargs + ) -> Dict[str, Any]: + """Extracts data from the |JobForm| for validation.""" + + return { + "experiment_name": slugify(data.experiment_name.data), + "queue": slugify(data.queue.data), + "timeout": data.timeout.data or None, + "entry_point": data.entry_point.data, + "entry_point_kwargs": data.entry_point_kwargs.data or None, + "depends_on": data.depends_on.data or None, + "workflow": data.workflow.data, + } + + @post_dump + def serialize_object( + self, data: Dict[str, Any], many: bool, **kwargs + ) -> JobFormData: + """Creates a |JobFormData| object from the validated data.""" + return self.__model__(**data) + + +job_submit_form_schema: list[ParametersSchema] = [ + dict( + name="experiment_name", + type=str, + location="form", + required=True, + help="The name of a registered experiment.", + ), + dict( + name="queue", + type=str, + location="form", + required=True, + help="The name of an active queue.", + ), + dict( + name="timeout", + type=str, + location="form", + required=False, + help="The maximum alloted time for a job before it times out and is stopped. " + "If omitted, the job timeout will default to 24 hours.", + ), + dict( + name="entry_point", + type=str, + location="form", + required=True, + help="The name of the entry point in the MLproject file to run.", + ), + dict( + name="entry_point_kwargs", + type=str, + location="form", + required=False, + help="A list of entry point parameter values to use for the job. The list is " + 'a string with the following format: "-P param1=value1 -P param2=value2". 
' + "If omitted, the default values in the MLproject file will be used.", + ), + dict( + name="depends_on", + type=str, + location="form", + required=False, + help="A job UUID to set as a dependency for this new job. The new job will " + "not run until this job completes successfully. If omitted, then the new job " + "will start as soon as computing resources are available.", + ), + dict( + name="workflow", + type=FileStorage, + location="files", + required=True, + help="A tarball archive or zip file containing, at a minimum, a MLproject file " + "and its associated entry point scripts.", + ), +] diff --git a/src/dioptra/restapi/resource/service.py b/src/dioptra/restapi/resource/service.py new file mode 100644 index 000000000..67b60a6f2 --- /dev/null +++ b/src/dioptra/restapi/resource/service.py @@ -0,0 +1,168 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. 
+# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""The server-side functions that perform job endpoint operations.""" +from __future__ import annotations + +import datetime +import uuid +from pathlib import Path +from typing import List, Optional + +import structlog +from injector import inject +from rq.job import Job as RQJob +from structlog.stdlib import BoundLogger +from werkzeug.utils import secure_filename + +from dioptra.restapi.app import db +from dioptra.restapi.experiment.errors import ExperimentDoesNotExistError +from dioptra.restapi.experiment.service import ExperimentService +from dioptra.restapi.queue.errors import QueueDoesNotExistError +from dioptra.restapi.queue.service import QueueService +from dioptra.restapi.shared.rq.service import RQService +from dioptra.restapi.shared.s3.service import S3Service + +from .errors import JobWorkflowUploadError +from .model import Job, JobForm, JobFormData +from .schema import JobFormSchema + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + + +class JobService(object): + @inject + def __init__( + self, + job_form_schema: JobFormSchema, + rq_service: RQService, + s3_service: S3Service, + experiment_service: ExperimentService, + queue_service: QueueService, + ) -> None: + self._job_form_schema = job_form_schema + self._rq_service = rq_service + self._s3_service = s3_service + self._experiment_service = experiment_service + self._queue_service = queue_service + + @staticmethod + def create(job_form_data: JobFormData, **kwargs) -> Job: + log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 + timestamp = datetime.datetime.now() + + return Job( + experiment_id=job_form_data["experiment_id"], + queue_id=job_form_data["queue_id"], + created_on=timestamp, + last_modified=timestamp, + timeout=job_form_data.get("timeout"), + entry_point=job_form_data["entry_point"], + entry_point_kwargs=job_form_data.get("entry_point_kwargs"), + 
depends_on=job_form_data.get("depends_on"), + ) + + @staticmethod + def get_all(**kwargs) -> List[Job]: + log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 + + return Job.query.all() # type: ignore + + @staticmethod + def get_by_id(job_id: str, **kwargs) -> Job: + log: BoundLogger = kwargs.get("log", LOGGER.new()) # noqa: F841 + + return Job.query.get(job_id) # type: ignore + + def extract_data_from_form(self, job_form: JobForm, **kwargs) -> JobFormData: + from dioptra.restapi.models import Experiment, Queue + + log: BoundLogger = kwargs.get("log", LOGGER.new()) + + job_form_data: JobFormData = self._job_form_schema.dump(job_form) + + experiment: Optional[Experiment] = self._experiment_service.get_by_name( + job_form_data["experiment_name"], log=log + ) + + if experiment is None: + raise ExperimentDoesNotExistError + + queue: Optional[Queue] = self._queue_service.get_unlocked_by_name( + job_form_data["queue"], log=log + ) + + if queue is None: + raise QueueDoesNotExistError + + job_form_data["experiment_id"] = experiment.experiment_id + job_form_data["queue_id"] = queue.queue_id + + return job_form_data + + def submit(self, job_form_data: JobFormData, **kwargs) -> Job: + log: BoundLogger = kwargs.get("log", LOGGER.new()) + + workflow_uri: Optional[str] = self._upload_workflow(job_form_data, log=log) + + if workflow_uri is None: + log.error( + "Failed to upload workflow to backend storage", + workflow_filename=secure_filename( + job_form_data["workflow"].filename or "" + ), + ) + raise JobWorkflowUploadError + + new_job: Job = self.create(job_form_data, log=log) + new_job.workflow_uri = workflow_uri + + rq_job: RQJob = self._rq_service.submit_mlflow_job( + queue=job_form_data["queue"], + workflow_uri=new_job.workflow_uri, + experiment_id=new_job.experiment_id, + entry_point=new_job.entry_point, + entry_point_kwargs=new_job.entry_point_kwargs, + depends_on=new_job.depends_on, + timeout=new_job.timeout, + log=log, + ) + + new_job.job_id = rq_job.get_id() 
+ + db.session.add(new_job) + db.session.commit() + + log.info("Job submission successful", job_id=new_job.job_id) + + return new_job + + def _upload_workflow(self, job_form_data: JobFormData, **kwargs) -> Optional[str]: + log: BoundLogger = kwargs.get("log", LOGGER.new()) + + upload_dir = Path(uuid.uuid4().hex) + workflow_filename = upload_dir / secure_filename( + job_form_data["workflow"].filename or "" + ) + + workflow_uri: Optional[str] = self._s3_service.upload( + fileobj=job_form_data["workflow"], + bucket="workflow", + key=str(workflow_filename), + log=log, + ) + + return workflow_uri