Skip to content

Commit

Permalink
chore(restapi): Moved building task engine dict from worfkflow servic…
Browse files Browse the repository at this point in the history
…e to lib
  • Loading branch information
andrewhand committed Oct 31, 2024
1 parent d8c11c2 commit 5a24b92
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 35 deletions.
37 changes: 37 additions & 0 deletions src/dioptra/restapi/v1/workflows/lib/export_task_engine_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,43 @@ def export_task_engine_yaml(
return task_yaml_path


def build_task_engine_dict_for_validation(
plugins: list[Any],
parameters: dict[str, Any],
task_graph: str,
) -> dict[str, Any]:
tasks: dict[str, Any] = {}
parameter_types: dict[str, Any] = {}
for plugin in plugins:
for plugin_file in plugin['plugin_files']:
for task in plugin_file.tasks:
input_parameters = task.input_parameters
output_parameters = task.output_parameters
tasks[task.plugin_task_name] = {
"plugin": _build_plugin_field(plugin['plugin'], plugin_file, task),
}
if input_parameters:
tasks[task.plugin_task_name]["inputs"] = _build_task_inputs(
input_parameters
)
if output_parameters:
tasks[task.plugin_task_name]["outputs"] = _build_task_outputs(
output_parameters
)
for param in input_parameters + output_parameters:
name = param.parameter_type.name
if name not in BUILTIN_TYPES:
parameter_types[name] = param.parameter_type.structure

task_engine_dict = {
"types": parameter_types,
"parameters": parameters,
"tasks": tasks,
"graph": cast(dict[str, Any], yaml.safe_load(task_graph)),
}
return task_engine_dict


def build_task_engine_dict(
entrypoint: models.EntryPoint,
entry_point_plugin_files: list[models.EntryPointPluginFile],
Expand Down
41 changes: 6 additions & 35 deletions src/dioptra/restapi/v1/workflows/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@

from .lib import views
from .lib.package_job_files import package_job_files
from .lib.export_task_engine_yaml import build_task_engine_dict_for_validation
from .schema import FileTypes

LOGGER: BoundLogger = structlog.stdlib.get_logger()
Expand All @@ -40,11 +41,6 @@
validate,
is_valid,
)
from dioptra.restapi.v1.workflows.lib.export_task_engine_yaml import (
_build_plugin_field,
_build_task_inputs,
_build_task_outputs,
)


class JobFilesDownloadService(object):
Expand Down Expand Up @@ -128,37 +124,12 @@ def validate(
log.debug("Validate a entrypoint workflow", task_graph=task_graph, plugin_ids=plugin_ids, entrypoint_parameters=entrypoint_parameters)

parameters = {param['name']: param['default_value'] for param in entrypoint_parameters}

tasks: dict[str, Any] = {}
parameter_types: dict[str, Any] = {}
plugins = self._plugin_id_service.get(plugin_ids)
for plugin in plugins:
for plugin_file in plugin['plugin_files']:
for task in plugin_file.tasks:
input_parameters = task.input_parameters
output_parameters = task.output_parameters
tasks[task.plugin_task_name] = {
"plugin": _build_plugin_field(plugin['plugin'], plugin_file, task),
}
if input_parameters:
tasks[task.plugin_task_name]["inputs"] = _build_task_inputs(
input_parameters
)
if output_parameters:
tasks[task.plugin_task_name]["outputs"] = _build_task_outputs(
output_parameters
)
for param in input_parameters + output_parameters:
name = param.parameter_type.name
if name not in BUILTIN_TYPES:
parameter_types[name] = param.parameter_type.structure

task_engine_dict = {
"types": parameter_types,
"parameters": parameters,
"tasks": tasks,
"graph": cast(dict[str, Any], yaml.safe_load(task_graph)),
}
task_engine_dict = build_task_engine_dict_for_validation(
plugins=plugins,
parameters=parameters,
task_graph=task_graph
)
valid = is_valid(task_engine_dict)

if valid:
Expand Down

0 comments on commit 5a24b92

Please sign in to comment.