From 5a24b92ef765c537c3d7c895e9e5a34f88eb34da Mon Sep 17 00:00:00 2001 From: Andrew Hand Date: Thu, 31 Oct 2024 12:02:41 -0400 Subject: [PATCH] chore(restapi): Moved building task engine dict from worfkflow service to lib --- .../workflows/lib/export_task_engine_yaml.py | 37 +++++++++++++++++ src/dioptra/restapi/v1/workflows/service.py | 41 +++---------------- 2 files changed, 43 insertions(+), 35 deletions(-) diff --git a/src/dioptra/restapi/v1/workflows/lib/export_task_engine_yaml.py b/src/dioptra/restapi/v1/workflows/lib/export_task_engine_yaml.py index d89923b84..02d6c1210 100644 --- a/src/dioptra/restapi/v1/workflows/lib/export_task_engine_yaml.py +++ b/src/dioptra/restapi/v1/workflows/lib/export_task_engine_yaml.py @@ -83,6 +83,43 @@ def export_task_engine_yaml( return task_yaml_path +def build_task_engine_dict_for_validation( + plugins: list[Any], + parameters: dict[str, Any], + task_graph: str, +) -> dict[str, Any]: + tasks: dict[str, Any] = {} + parameter_types: dict[str, Any] = {} + for plugin in plugins: + for plugin_file in plugin['plugin_files']: + for task in plugin_file.tasks: + input_parameters = task.input_parameters + output_parameters = task.output_parameters + tasks[task.plugin_task_name] = { + "plugin": _build_plugin_field(plugin['plugin'], plugin_file, task), + } + if input_parameters: + tasks[task.plugin_task_name]["inputs"] = _build_task_inputs( + input_parameters + ) + if output_parameters: + tasks[task.plugin_task_name]["outputs"] = _build_task_outputs( + output_parameters + ) + for param in input_parameters + output_parameters: + name = param.parameter_type.name + if name not in BUILTIN_TYPES: + parameter_types[name] = param.parameter_type.structure + + task_engine_dict = { + "types": parameter_types, + "parameters": parameters, + "tasks": tasks, + "graph": cast(dict[str, Any], yaml.safe_load(task_graph)), + } + return task_engine_dict + + def build_task_engine_dict( entrypoint: models.EntryPoint, entry_point_plugin_files: list[models.EntryPointPluginFile], diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index 06638d359..5b6bd5c26 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -24,6 +24,7 @@ from .lib import views from .lib.package_job_files import package_job_files +from .lib.export_task_engine_yaml import build_task_engine_dict_for_validation from .schema import FileTypes LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -40,11 +41,6 @@ validate, is_valid, ) -from dioptra.restapi.v1.workflows.lib.export_task_engine_yaml import ( - _build_plugin_field, - _build_task_inputs, - _build_task_outputs, -) class JobFilesDownloadService(object): @@ -128,37 +124,12 @@ def validate( log.debug("Validate a entrypoint workflow", task_graph=task_graph, plugin_ids=plugin_ids, entrypoint_parameters=entrypoint_parameters) parameters = {param['name']: param['default_value'] for param in entrypoint_parameters} - - tasks: dict[str, Any] = {} - parameter_types: dict[str, Any] = {} plugins = self._plugin_id_service.get(plugin_ids) - for plugin in plugins: - for plugin_file in plugin['plugin_files']: - for task in plugin_file.tasks: - input_parameters = task.input_parameters - output_parameters = task.output_parameters - tasks[task.plugin_task_name] = { - "plugin": _build_plugin_field(plugin['plugin'], plugin_file, task), - } - if input_parameters: - tasks[task.plugin_task_name]["inputs"] = _build_task_inputs( - input_parameters - ) - if output_parameters: - tasks[task.plugin_task_name]["outputs"] = _build_task_outputs( - output_parameters - ) - for param in input_parameters + output_parameters: - name = param.parameter_type.name - if name not in BUILTIN_TYPES: - parameter_types[name] = param.parameter_type.structure - - task_engine_dict = { - "types": parameter_types, - "parameters": parameters, - "tasks": tasks, - "graph": cast(dict[str, Any], yaml.safe_load(task_graph)), - } + task_engine_dict = build_task_engine_dict_for_validation( + plugins=plugins, + parameters=parameters, + task_graph=task_graph + ) valid = is_valid(task_engine_dict) if valid: