feat: integrate the task engine with the backend rq worker
This update adds a new function that can be called through Redis Queue (RQ) to run an experiment
via the task engine. It also adds a Python implementation for syncing the task plugins at the start
of the job, removing the need for the AWS CLI tool. Unit tests have been added for these new
features.

A new, temporary endpoint `/job/newTaskEngine` has been added for creating jobs that use the new
"native" task engine integration.
chisholm authored and jkglasbrenner committed Nov 27, 2023
1 parent 965a8a4 commit a76b919
Showing 24 changed files with 1,133 additions and 112 deletions.
2 changes: 1 addition & 1 deletion docker/shellscripts/entrypoint-worker.m4
@@ -85,7 +85,7 @@ start_rq() {
   ${job_queues}"
 
   cd ${dioptra_workdir}
-  python -m ${rq_worker_module} worker\
+  PYTHONPATH="${DIOPTRA_PLUGIN_DIR}" python -m ${rq_worker_module} worker\
     --url ${rq_redis_uri}\
     --results-ttl ${rq_results_ttl}\
     ${job_queues}
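Setting PYTHONPATH to the plugin directory is what lets the RQ worker import the task plugins that
the new sync step downloads at the start of the job. A minimal sketch of the effect (the package
name is illustrative):

    import importlib

    # With PYTHONPATH="${DIOPTRA_PLUGIN_DIR}" exported by the entrypoint, a
    # package synced into the plugin directory resolves like any other import:
    plugins = importlib.import_module("dioptra_custom")  # illustrative package name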
12 changes: 8 additions & 4 deletions src/dioptra/mlflow_plugins/dioptra_clients.py
@@ -51,8 +51,11 @@ def get_active_job(self) -> Optional[Dict[str, Any]]:
         if self.job_id is None:
             return None
 
+        return self.get_job(self.job_id)
+
+    def get_job(self, job_id) -> Dict[str, Any]:
         with self.app.app_context():
-            job: Job = Job.query.get(self.job_id)
+            job: Job = Job.query.get(job_id)
             return {
                 "job_id": job.job_id,
                 "queue": job.queue.name,
@@ -68,8 +71,7 @@ def update_active_job_status(self, status: str) -> None:
 
     def update_job_status(self, job_id: str, status: str) -> None:
         LOGGER.info(
-            f"=== Updating job status for job with ID '{self.job_id}' to "
-            f"{status} ==="
+            f"=== Updating job status for job with ID '{job_id}' to {status} ==="
         )
 
         with self.app.app_context():
@@ -90,9 +92,11 @@ def set_mlflow_run_id_in_db(self, run_id: str) -> None:
             return None
 
         LOGGER.info("=== Setting MLFlow run ID in the Dioptra database ===")
+        self.set_mlflow_run_id_for_job(run_id, self.job_id)
+
+    def set_mlflow_run_id_for_job(self, run_id: str, job_id: str) -> None:
         with self.app.app_context():
-            job = Job.query.get(self.job_id)
+            job = Job.query.get(job_id)
             job.update(changes={"mlflow_run_id": run_id})
 
         try:
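The refactor above follows one pattern: each active-job convenience method now delegates to a
variant that takes an explicit job ID, so the same logic serves any job. A usage sketch (the
client class name is an assumption; its definition is outside this excerpt):

    # "DioptraDatabaseClient" is an assumed name for the class defined in this module
    client = DioptraDatabaseClient()

    client.update_job_status(job_id="some-job-uuid", status="started")  # any job by ID
    client.update_active_job_status(status="started")  # delegates via client.job_id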
27 changes: 26 additions & 1 deletion src/dioptra/restapi/job/controller.py
@@ -20,6 +20,7 @@
 import uuid
 from typing import List, Optional
 
+import flask
 import structlog
 from flask_accepts import accepts, responds
 from flask_login import login_required
@@ -31,7 +32,7 @@
 
 from .errors import JobDoesNotExistError, JobSubmissionError
 from .model import Job, JobForm, JobFormData
-from .schema import JobSchema, job_submit_form_schema
+from .schema import JobSchema, TaskEngineSubmission, job_submit_form_schema
 from .service import JobService
 
 LOGGER: BoundLogger = structlog.stdlib.get_logger()
@@ -116,3 +117,27 @@ def get(self, jobId: str) -> Job:
             raise JobDoesNotExistError
 
         return job
+
+
+@api.route("/newTaskEngine")
+class TaskEngineResource(Resource):
+    @inject
+    def __init__(self, job_service: JobService, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self._job_service = job_service
+
+    @accepts(schema=TaskEngineSubmission, api=api)
+    @responds(schema=JobSchema, api=api)
+    def post(self) -> Job:
+        post_obj = flask.request.parsed_obj  # type: ignore
+
+        new_job = self._job_service.submit_task_engine(
+            queue_name=post_obj["queue"],
+            experiment_name=post_obj["experimentName"],
+            experiment_description=post_obj["experimentDescription"],
+            global_parameters=post_obj.get("globalParameters"),
+            timeout=post_obj.get("timeout"),
+            depends_on=post_obj.get("dependsOn"),
+        )
+
+        return new_job
11 changes: 8 additions & 3 deletions src/dioptra/restapi/job/dependencies.py
@@ -42,6 +42,7 @@ def provide_job_form_schema_module(self) -> JobFormSchema:
 class RQServiceConfiguration(object):
     redis: Redis
     run_mlflow: str
+    run_task_engine: str
 
 
 class RQServiceModule(Module):
@@ -50,16 +51,20 @@ class RQServiceModule(Module):
     def provide_rq_service_module(
         self, configuration: RQServiceConfiguration
     ) -> RQService:
-        return RQService(redis=configuration.redis, run_mlflow=configuration.run_mlflow)
+        return RQService(
+            redis=configuration.redis,
+            run_mlflow=configuration.run_mlflow,
+            run_task_engine=configuration.run_task_engine,
+        )
 
 
 def _bind_rq_service_configuration(binder: Binder):
     redis_conn: Redis = Redis.from_url(os.getenv("RQ_REDIS_URI", "redis://"))
     run_mlflow: str = "dioptra.rq.tasks.run_mlflow_task"
+    run_task_engine: str = "dioptra.rq.tasks.run_task_engine_task"
 
     configuration: RQServiceConfiguration = RQServiceConfiguration(
-        redis=redis_conn,
-        run_mlflow=run_mlflow,
+        redis=redis_conn, run_mlflow=run_mlflow, run_task_engine=run_task_engine
     )
 
     binder.bind(RQServiceConfiguration, to=configuration, scope=request)
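For readers unfamiliar with the injector bindings, the equivalent manual wiring is short (a
sketch; in the app this happens per request through the module above):

    from redis import Redis

    from dioptra.restapi.shared.rq.service import RQService

    # Mirrors _bind_rq_service_configuration plus the provider method above.
    service = RQService(
        redis=Redis.from_url("redis://"),  # the RQ_REDIS_URI fallback from the diff
        run_mlflow="dioptra.rq.tasks.run_mlflow_task",
        run_task_engine="dioptra.rq.tasks.run_task_engine_task",
    )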
8 changes: 8 additions & 0 deletions src/dioptra/restapi/job/errors.py
@@ -32,6 +32,10 @@ class JobWorkflowUploadError(Exception):
     """The service for storing the uploaded workflow file is unavailable."""
 
 
+class InvalidExperimentDescriptionError(Exception):
+    """The experiment description failed validation."""
+
+
 def register_error_handlers(api: Api) -> None:
     @api.errorhandler(JobDoesNotExistError)
     def handle_job_does_not_exist_error(error):
@@ -57,3 +61,7 @@ def handle_job_workflow_upload_error(error):
             },
             503,
         )
+
+    @api.errorhandler(InvalidExperimentDescriptionError)
+    def handle_invalid_experiment_description_error(error):
+        return {"message": "The experiment description is invalid!"}, 400
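With this handler registered, a submission whose experiment description fails validation receives
an HTTP 400 response with the JSON body {"message": "The experiment description is invalid!"}.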
41 changes: 41 additions & 0 deletions src/dioptra/restapi/job/schema.py
@@ -227,6 +227,47 @@ def serialize_object(
         return self.__model__(**data)
 
 
+class TaskEngineSubmission(Schema):
+    queue = fields.String(
+        required=True,
+        metadata={"description": "The name of an active queue."},
+    )
+
+    experimentName = fields.String(
+        required=True,
+        metadata={"description": "The name of a registered experiment."},
+    )
+
+    experimentDescription = fields.Dict(
+        keys=fields.String(),
+        required=True,
+        metadata={"description": "A declarative experiment description."},
+    )
+
+    globalParameters = fields.Dict(
+        keys=fields.String(),
+        metadata={"description": "Global parameters for this task engine job."},
+    )
+
+    timeout = fields.String(
+        metadata={
+            "description": "The maximum allotted time for a job before it times"
+            " out and is stopped. If omitted, the job timeout"
+            " will default to 24 hours.",
+        },
+    )
+
+    dependsOn = fields.String(
+        metadata={
+            "description": "A job UUID to set as a dependency for this new job."
+            " The new job will not run until this job completes"
+            " successfully. If omitted, then the new job will"
+            " start as soon as computing resources are"
+            " available.",
+        },
+    )
+
+
 job_submit_form_schema: list[ParametersSchema] = [
     dict(
         name="experiment_name",
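A quick validation sketch against this schema (the queue and experiment names are made up, and the
near-empty description is only a stand-in; the task engine applies its own, stricter validation
server-side):

    from dioptra.restapi.job.schema import TaskEngineSubmission

    payload = {
        "queue": "tensorflow_cpu",      # hypothetical queue name
        "experimentName": "mnist",      # hypothetical experiment name
        "experimentDescription": {"tasks": {}, "graph": {}},  # stand-in document
        "globalParameters": {"epochs": 10},
        "timeout": "24h",
    }
    parsed = TaskEngineSubmission().load(payload)  # raises ValidationError on bad input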
69 changes: 67 additions & 2 deletions src/dioptra/restapi/job/service.py
@@ -18,9 +18,10 @@
 from __future__ import annotations
 
 import datetime
+import json
 import uuid
 from pathlib import Path
-from typing import List, Optional, cast
+from typing import Any, List, Mapping, Optional, cast
 
 import structlog
 from injector import inject
@@ -34,8 +35,9 @@
 from dioptra.restapi.queue.service import QueueNameService
 from dioptra.restapi.shared.rq.service import RQService
 from dioptra.restapi.shared.s3.service import S3Service
+from dioptra.task_engine.validation import is_valid
 
-from .errors import JobWorkflowUploadError
+from .errors import InvalidExperimentDescriptionError, JobWorkflowUploadError
 from .model import Job, JobForm, JobFormData
 from .schema import JobFormSchema
 
@@ -152,6 +154,69 @@ def submit(self, job_form_data: JobFormData, **kwargs) -> Job:
 
         return new_job
 
+    def submit_task_engine(
+        self,
+        queue_name: str,
+        experiment_name: str,
+        experiment_description: Mapping[str, Any],
+        global_parameters: Optional[Mapping[str, Any]] = None,
+        timeout: Optional[str] = None,
+        depends_on: Optional[str] = None,
+    ) -> Job:
+        from dioptra.restapi.models import Experiment, Queue
+
+        log: BoundLogger = LOGGER.new()
+
+        experiment: Optional[Experiment] = self._experiment_service.get_by_name(
+            experiment_name=experiment_name, log=log
+        )
+
+        if experiment is None:
+            raise ExperimentDoesNotExistError
+
+        queue = cast(
+            Queue,
+            self._queue_name_service.get(
+                queue_name, unlocked_only=True, error_if_not_found=True, log=log
+            ),
+        )
+
+        if not is_valid(experiment_description):
+            raise InvalidExperimentDescriptionError
+
+        job_id = str(uuid.uuid4())
+        timestamp = datetime.datetime.now()
+
+        new_job = Job(
+            job_id=job_id,
+            experiment_id=experiment.experiment_id,
+            queue_id=queue.queue_id,
+            created_on=timestamp,
+            last_modified=timestamp,
+            timeout=timeout,
+            depends_on=depends_on,
+        )
+
+        if global_parameters is not None:
+            new_job.entry_point_kwargs = json.dumps(global_parameters)
+
+        db.session.add(new_job)
+        db.session.commit()
+
+        self._rq_service.submit_task_engine_job(
+            job_id=job_id,
+            queue=queue_name,
+            experiment_id=experiment.experiment_id,
+            experiment_description=experiment_description,
+            global_parameters=global_parameters,
+            depends_on=depends_on,
+            timeout=timeout,
+        )
+
+        log.info("Job submission successful", job_id=job_id)
+
+        return new_job
+
     def _upload_workflow(self, job_form_data: JobFormData, **kwargs) -> Optional[str]:
         log: BoundLogger = kwargs.get("log", LOGGER.new())
 
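Note the ordering in submit_task_engine: the experiment and queue are resolved first, the
description is rejected early when is_valid() fails, and the Job row is committed before the RQ
job is enqueued, so the database record already exists when a worker picks the job up.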
48 changes: 46 additions & 2 deletions src/dioptra/restapi/shared/rq/service.py
@@ -16,7 +16,7 @@
 # https://creativecommons.org/licenses/by/4.0/legalcode
 from __future__ import annotations
 
-from typing import Optional, Union
+from typing import Any, Mapping, Optional, Union
 
 import structlog
 from redis import Redis
@@ -32,9 +32,10 @@
 
 
 class RQService(object):
-    def __init__(self, redis: Redis, run_mlflow: str) -> None:
+    def __init__(self, redis: Redis, run_mlflow: str, run_task_engine: str) -> None:
         self._redis = redis
         self._run_mlflow = run_mlflow
+        self._run_task_engine = run_task_engine
 
     def get_job_status(self, job: Job, **kwargs) -> str:
         log: BoundLogger = kwargs.get("log", LOGGER.new())
@@ -105,3 +106,46 @@ def submit_mlflow_job(
         )
 
         return result
+
+    def submit_task_engine_job(
+        self,
+        job_id: str,
+        queue: str,
+        experiment_id: int,
+        experiment_description: Mapping[str, Any],
+        global_parameters: Optional[Mapping[str, Any]] = None,
+        depends_on: Optional[str] = None,
+        timeout: Optional[str] = None,
+    ):
+        log: BoundLogger = LOGGER.new()
+
+        job_dependency: Optional[RQJob] = None
+        if depends_on is not None:
+            job_dependency = self.get_rq_job(depends_on)
+
+        if global_parameters is None:
+            global_parameters = {}
+
+        cmd_kwargs = {
+            "experiment_id": experiment_id,
+            "experiment_desc": experiment_description,
+            "global_parameters": global_parameters,
+        }
+
+        log.info(
+            "Enqueuing job",
+            function=self._run_task_engine,
+            job_id=job_id,
+            cmd_kwargs=cmd_kwargs,
+            timeout=timeout,
+            depends_on=job_dependency,
+        )
+
+        q: RQQueue = RQQueue(queue, default_timeout=24 * 3600, connection=self._redis)
+        q.enqueue(
+            self._run_task_engine,
+            job_id=job_id,
+            kwargs=cmd_kwargs,
+            timeout=timeout,
+            depends_on=job_dependency,
+        )
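Because enqueue() passes cmd_kwargs as the job's keyword arguments, the worker-side function must
accept those names. A hedged reconstruction of its signature (the real implementation lives in
src/dioptra/rq/tasks/run_task_engine.py, which is among the changed files not shown in this
excerpt):

    from typing import Any, Mapping

    def run_task_engine_task(
        experiment_id: int,
        experiment_desc: Mapping[str, Any],
        global_parameters: Mapping[str, Any],
    ) -> None:
        """Sketch only: sync the task plugins, then run the described
        experiment via the task engine for the given Dioptra experiment."""
        ...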
3 changes: 2 additions & 1 deletion src/dioptra/rq/tasks/__init__.py
@@ -15,5 +15,6 @@
 # ACCESS THE FULL CC BY 4.0 LICENSE HERE:
 # https://creativecommons.org/licenses/by/4.0/legalcode
 from .run_mlflow import run_mlflow_task
+from .run_task_engine import run_task_engine_task
 
-__all__ = ["run_mlflow_task"]
+__all__ = ["run_mlflow_task", "run_task_engine_task"]