-
Notifications
You must be signed in to change notification settings - Fork 33
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
abyrne
committed
Oct 10, 2023
1 parent
9cac4e4
commit 4d3265b
Showing
9 changed files
with
1,114 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,28 @@ | ||
# This Software (Dioptra) is being made available as a public service by the | ||
# National Institute of Standards and Technology (NIST), an Agency of the United | ||
# States Department of Commerce. This software was developed in part by employees of | ||
# NIST and in part by NIST contractors. Copyright in portions of this software that | ||
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant | ||
# to Title 17 United States Code Section 105, works of NIST employees are not | ||
# subject to copyright protection in the United States. However, NIST may hold | ||
# international copyright in software created by its employees and domestic | ||
# copyright (or licensing rights) in portions of software that were assigned or | ||
# licensed to NIST. To the extent that NIST holds copyright in this software, it is | ||
# being made available under the Creative Commons Attribution 4.0 International | ||
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts | ||
# of the software developed or licensed by NIST. | ||
# | ||
# ACCESS THE FULL CC BY 4.0 LICENSE HERE: | ||
# https://creativecommons.org/licenses/by/4.0/legalcode | ||
"""The job endpoint subpackage.""" | ||
|
||
from .dependencies import bind_dependencies, register_providers | ||
from .errors import register_error_handlers | ||
from .routes import register_routes | ||
|
||
__all__ = [ | ||
"bind_dependencies", | ||
"register_error_handlers", | ||
"register_providers", | ||
"register_routes", | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,114 @@ | ||
# This Software (Dioptra) is being made available as a public service by the | ||
# National Institute of Standards and Technology (NIST), an Agency of the United | ||
# States Department of Commerce. This software was developed in part by employees of | ||
# NIST and in part by NIST contractors. Copyright in portions of this software that | ||
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant | ||
# to Title 17 United States Code Section 105, works of NIST employees are not | ||
# subject to copyright protection in the United States. However, NIST may hold | ||
# international copyright in software created by its employees and domestic | ||
# copyright (or licensing rights) in portions of software that were assigned or | ||
# licensed to NIST. To the extent that NIST holds copyright in this software, it is | ||
# being made available under the Creative Commons Attribution 4.0 International | ||
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts | ||
# of the software developed or licensed by NIST. | ||
# | ||
# ACCESS THE FULL CC BY 4.0 LICENSE HERE: | ||
# https://creativecommons.org/licenses/by/4.0/legalcode | ||
"""The module defining the job endpoints.""" | ||
from __future__ import annotations | ||
|
||
import uuid | ||
from typing import List, Optional | ||
|
||
import structlog | ||
from flask_accepts import accepts, responds | ||
from flask_restx import Namespace, Resource | ||
from injector import inject | ||
from structlog.stdlib import BoundLogger | ||
|
||
from dioptra.restapi.utils import as_api_parser | ||
|
||
from .errors import JobDoesNotExistError, JobSubmissionError | ||
from .model import Job, JobForm, JobFormData | ||
from .schema import JobSchema, job_submit_form_schema | ||
from .service import JobService | ||
|
||
LOGGER: BoundLogger = structlog.stdlib.get_logger() | ||
|
||
api: Namespace = Namespace( | ||
"Job", | ||
description="Job submission and management operations", | ||
) | ||
|
||
|
||
@api.route("/") | ||
class JobResource(Resource): | ||
"""Shows a list of all jobs, and lets you POST to create new jobs.""" | ||
|
||
@inject | ||
def __init__( | ||
self, | ||
*args, | ||
job_service: JobService, | ||
**kwargs, | ||
) -> None: | ||
self._job_service = job_service | ||
super().__init__(*args, **kwargs) | ||
|
||
@responds(schema=JobSchema(many=True), api=api) | ||
def get(self) -> List[Job]: | ||
"""Gets a list of all submitted jobs.""" | ||
log: BoundLogger = LOGGER.new( | ||
request_id=str(uuid.uuid4()), resource="job", request_type="GET" | ||
) # noqa: F841 | ||
log.info("Request received") | ||
return self._job_service.get_all(log=log) | ||
|
||
@api.expect(as_api_parser(api, job_submit_form_schema)) | ||
@accepts(job_submit_form_schema, api=api) | ||
@responds(schema=JobSchema, api=api) | ||
def post(self) -> Job: | ||
"""Creates a new job via a job submission form with an attached file.""" | ||
log: BoundLogger = LOGGER.new( | ||
request_id=str(uuid.uuid4()), resource="job", request_type="POST" | ||
) # noqa: F841 | ||
job_form: JobForm = JobForm() | ||
|
||
log.info("Request received") | ||
|
||
if not job_form.validate_on_submit(): | ||
log.error("Form validation failed") | ||
raise JobSubmissionError | ||
|
||
log.info("Form validation successful") | ||
job_form_data: JobFormData = self._job_service.extract_data_from_form( | ||
job_form=job_form, | ||
log=log, | ||
) | ||
return self._job_service.submit(job_form_data=job_form_data, log=log) | ||
|
||
|
||
@api.route("/<string:jobId>") | ||
@api.param("jobId", "A string specifying a job's UUID.") | ||
class JobIdResource(Resource): | ||
"""Shows a single job.""" | ||
|
||
@inject | ||
def __init__(self, *args, job_service: JobService, **kwargs) -> None: | ||
self._job_service = job_service | ||
super().__init__(*args, **kwargs) | ||
|
||
@responds(schema=JobSchema, api=api) | ||
def get(self, jobId: str) -> Job: | ||
"""Gets a job by its unique identifier.""" | ||
log: BoundLogger = LOGGER.new( | ||
request_id=str(uuid.uuid4()), resource="jobId", request_type="GET" | ||
) # noqa: F841 | ||
log.info("Request received", job_id=jobId) | ||
job: Optional[Job] = self._job_service.get_by_id(jobId, log=log) | ||
|
||
if job is None: | ||
log.error("Job not found", job_id=jobId) | ||
raise JobDoesNotExistError | ||
|
||
return job |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# This Software (Dioptra) is being made available as a public service by the | ||
# National Institute of Standards and Technology (NIST), an Agency of the United | ||
# States Department of Commerce. This software was developed in part by employees of | ||
# NIST and in part by NIST contractors. Copyright in portions of this software that | ||
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant | ||
# to Title 17 United States Code Section 105, works of NIST employees are not | ||
# subject to copyright protection in the United States. However, NIST may hold | ||
# international copyright in software created by its employees and domestic | ||
# copyright (or licensing rights) in portions of software that were assigned or | ||
# licensed to NIST. To the extent that NIST holds copyright in this software, it is | ||
# being made available under the Creative Commons Attribution 4.0 International | ||
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts | ||
# of the software developed or licensed by NIST. | ||
# | ||
# ACCESS THE FULL CC BY 4.0 LICENSE HERE: | ||
# https://creativecommons.org/licenses/by/4.0/legalcode | ||
"""Binding configurations to shared services using dependency injection.""" | ||
from __future__ import annotations | ||
|
||
import os | ||
from dataclasses import dataclass | ||
from typing import Any, Callable, List, Optional | ||
|
||
from boto3.session import Session | ||
from botocore.client import BaseClient | ||
from injector import Binder, Module, provider | ||
from redis import Redis | ||
|
||
from dioptra.restapi.shared.request_scope import request | ||
from dioptra.restapi.shared.rq.service import RQService | ||
|
||
from .schema import JobFormSchema | ||
|
||
|
||
class JobFormSchemaModule(Module): | ||
@provider | ||
def provide_job_form_schema_module(self) -> JobFormSchema: | ||
return JobFormSchema() | ||
|
||
|
||
@dataclass | ||
class RQServiceConfiguration(object): | ||
redis: Redis | ||
run_mlflow: str | ||
|
||
|
||
class RQServiceModule(Module): | ||
@request | ||
@provider | ||
def provide_rq_service_module( | ||
self, configuration: RQServiceConfiguration | ||
) -> RQService: | ||
return RQService(redis=configuration.redis, run_mlflow=configuration.run_mlflow) | ||
|
||
|
||
def _bind_rq_service_configuration(binder: Binder): | ||
redis_conn: Redis = Redis.from_url(os.getenv("RQ_REDIS_URI", "redis://")) | ||
run_mlflow: str = "dioptra.rq.tasks.run_mlflow_task" | ||
|
||
configuration: RQServiceConfiguration = RQServiceConfiguration( | ||
redis=redis_conn, | ||
run_mlflow=run_mlflow, | ||
) | ||
|
||
binder.bind(RQServiceConfiguration, to=configuration, scope=request) | ||
|
||
|
||
def _bind_s3_service_configuration(binder: Binder) -> None: | ||
s3_endpoint_url: Optional[str] = os.getenv("MLFLOW_S3_ENDPOINT_URL") | ||
|
||
s3_session: Session = Session() | ||
s3_client: BaseClient = s3_session.client("s3", endpoint_url=s3_endpoint_url) | ||
|
||
binder.bind(Session, to=s3_session, scope=request) | ||
binder.bind(BaseClient, to=s3_client, scope=request) | ||
|
||
|
||
def bind_dependencies(binder: Binder) -> None: | ||
"""Binds interfaces to implementations within the main application. | ||
Args: | ||
binder: A :py:class:`~injector.Binder` object. | ||
""" | ||
_bind_rq_service_configuration(binder) | ||
_bind_s3_service_configuration(binder) | ||
|
||
|
||
def register_providers(modules: List[Callable[..., Any]]) -> None: | ||
"""Registers type providers within the main application. | ||
Args: | ||
modules: A list of callables used for configuring the dependency injection | ||
environment. | ||
""" | ||
modules.append(JobFormSchemaModule) | ||
modules.append(RQServiceModule) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
# This Software (Dioptra) is being made available as a public service by the | ||
# National Institute of Standards and Technology (NIST), an Agency of the United | ||
# States Department of Commerce. This software was developed in part by employees of | ||
# NIST and in part by NIST contractors. Copyright in portions of this software that | ||
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant | ||
# to Title 17 United States Code Section 105, works of NIST employees are not | ||
# subject to copyright protection in the United States. However, NIST may hold | ||
# international copyright in software created by its employees and domestic | ||
# copyright (or licensing rights) in portions of software that were assigned or | ||
# licensed to NIST. To the extent that NIST holds copyright in this software, it is | ||
# being made available under the Creative Commons Attribution 4.0 International | ||
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts | ||
# of the software developed or licensed by NIST. | ||
# | ||
# ACCESS THE FULL CC BY 4.0 LICENSE HERE: | ||
# https://creativecommons.org/licenses/by/4.0/legalcode | ||
"""Error handlers for the job endpoints.""" | ||
from __future__ import annotations | ||
|
||
from flask_restx import Api | ||
|
||
|
||
class JobDoesNotExistError(Exception): | ||
"""The requested job does not exist.""" | ||
|
||
|
||
class JobSubmissionError(Exception): | ||
"""The job submission form contains invalid parameters.""" | ||
|
||
|
||
class JobWorkflowUploadError(Exception): | ||
"""The service for storing the uploaded workfile file is unavailable.""" | ||
|
||
|
||
def register_error_handlers(api: Api) -> None: | ||
@api.errorhandler(JobDoesNotExistError) | ||
def handle_job_does_not_exist_error(error): | ||
return {"message": "Not Found - The requested job does not exist"}, 404 | ||
|
||
@api.errorhandler(JobSubmissionError) | ||
def handle_job_submission_error(error): | ||
return ( | ||
{ | ||
"message": "Bad Request - The job submission form contains " | ||
"invalid parameters. Please verify and resubmit." | ||
}, | ||
400, | ||
) | ||
|
||
@api.errorhandler(JobWorkflowUploadError) | ||
def handle_job_workflow_upload_error(error): | ||
return ( | ||
{ | ||
"message": "Service Unavailable - Unable to store the " | ||
"workflow file after upload. Please try again " | ||
"later." | ||
}, | ||
503, | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
# This Software (Dioptra) is being made available as a public service by the | ||
# National Institute of Standards and Technology (NIST), an Agency of the United | ||
# States Department of Commerce. This software was developed in part by employees of | ||
# NIST and in part by NIST contractors. Copyright in portions of this software that | ||
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant | ||
# to Title 17 United States Code Section 105, works of NIST employees are not | ||
# subject to copyright protection in the United States. However, NIST may hold | ||
# international copyright in software created by its employees and domestic | ||
# copyright (or licensing rights) in portions of software that were assigned or | ||
# licensed to NIST. To the extent that NIST holds copyright in this software, it is | ||
# being made available under the Creative Commons Attribution 4.0 International | ||
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts | ||
# of the software developed or licensed by NIST. | ||
# | ||
# ACCESS THE FULL CC BY 4.0 LICENSE HERE: | ||
# https://creativecommons.org/licenses/by/4.0/legalcode | ||
"""The interfaces for creating and updating |Job| objects. | ||
.. |Job| replace:: :py:class:`~.model.Job` | ||
""" | ||
from __future__ import annotations | ||
|
||
import datetime | ||
from typing import Optional | ||
|
||
from typing_extensions import TypedDict | ||
|
||
|
||
class JobInterface(TypedDict, total=False): | ||
"""The interface for constructing a new |Job| object. | ||
Attributes: | ||
job_id: A UUID that identifies the job. | ||
mlflow_run_id: A UUID that identifies the MLFlow run associated with the job. | ||
experiment_id: An integer identifying a registered experiment. | ||
queue_id: An integer identifying a registered queue. | ||
created_on: The date and time the job was created. | ||
last_modified: The date and time the job was last modified. | ||
timeout: The maximum alloted time for a job before it times out and is stopped. | ||
workflow_uri: The URI pointing to the tarball archive or zip file uploaded with | ||
the job. | ||
entry_point: The name of the entry point in the MLproject file to run. | ||
entry_point_kwargs: A string listing parameter values to pass to the entry point | ||
for the job. The list of parameters is specified using the following format: | ||
`-P param1=value1 -P param2=value2`. | ||
status: The current status of the job. The allowed values are: `queued`, | ||
`started`, `deferred`, `finished`, `failed`. | ||
depends_on: A UUID for a previously submitted job to set as a dependency for the | ||
current job. | ||
""" | ||
|
||
job_id: str | ||
mlflow_run_id: Optional[str] | ||
experiment_id: int | ||
queue_id: int | ||
created_on: datetime.datetime | ||
last_modified: datetime.datetime | ||
timeout: Optional[str] | ||
workflow_uri: str | ||
entry_point: str | ||
entry_point_kwargs: Optional[str] | ||
status: str | ||
depends_on: Optional[str] | ||
|
||
|
||
class JobUpdateInterface(TypedDict, total=False): | ||
"""The interface for updating a |Job| object. | ||
Attributes: | ||
status: The current status of the job. The allowed values are: `queued`, | ||
`started`, `deferred`, `finished`, `failed`. | ||
""" | ||
|
||
status: str |
Oops, something went wrong.