From 8d2cedb040cbd754b37b6d8cc3fb21b7c18a7c28 Mon Sep 17 00:00:00 2001 From: Andrew Hand Date: Tue, 10 Sep 2024 11:40:47 -0400 Subject: [PATCH 01/12] build(restapi): Add workflow controller for entrypoint workflow validation --- .../restapi/v1/workflows/controller.py | 39 ++++++++++++++++++- 1 file changed, 37 insertions(+), 2 deletions(-) diff --git a/src/dioptra/restapi/v1/workflows/controller.py b/src/dioptra/restapi/v1/workflows/controller.py index 428619cdc..b8eecd4f6 100644 --- a/src/dioptra/restapi/v1/workflows/controller.py +++ b/src/dioptra/restapi/v1/workflows/controller.py @@ -25,8 +25,8 @@ from injector import inject from structlog.stdlib import BoundLogger -from .schema import FileTypes, JobFilesDownloadQueryParametersSchema -from .service import JobFilesDownloadService +from .schema import FileTypes, JobFilesDownloadQueryParametersSchema, EntrypointWorkflowSchema +from .service import JobFilesDownloadService, EntrypointValidateService LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -78,3 +78,38 @@ def get(self): mimetype=mimetype[parsed_query_params["file_type"]], download_name=download_name[parsed_query_params["file_type"]], ) + + +@api.route("/entrypointValidate") +class EntrypointValidateEndpoint(Resource): + @inject + def __init__( + self, entrypoint_validate_service: EntrypointValidateService, *args, **kwargs + ) -> None: + """Initialize the workflow resource. + + All arguments are provided via dependency injection. + + Args: + entrypoint_validate_service: A EntrypointValidateService object. + """ + self._entrypoint_validate_service = entrypoint_validate_service + super().__init__(*args, **kwargs) + + @login_required + @accepts(entrypoint_workflow_schema=EntrypointWorkflowSchema, api=api) + def post(self): + """Validates the workflow for a entrypoint.""" # noqa: B950 + log = LOGGER.new( + request_id=str(uuid.uuid4()), resource="Workflows", request_type="POST" + ) + parsed_obj = request.parsed_obj # type: ignore + task_graph = parsed_obj["task_graph"] + plugin_ids = parsed_obj["plugin_ids"] + parameters = parsed_obj["parameters"] + return self._entrypoint_validate_service.validate( + task_graph=task_graph, + plugin_ids=plugin_ids, + parameters=parameters, + log=log, + ) \ No newline at end of file From b8326ef7767c7a76bef43de20056a51f90471513 Mon Sep 17 00:00:00 2001 From: Andrew Hand Date: Tue, 10 Sep 2024 11:41:00 -0400 Subject: [PATCH 02/12] build(restapi): Add workflow schema for entrypoint workflow validation --- src/dioptra/restapi/v1/workflows/schema.py | 25 ++++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/src/dioptra/restapi/v1/workflows/schema.py b/src/dioptra/restapi/v1/workflows/schema.py index 92ea28ec7..df9364e75 100644 --- a/src/dioptra/restapi/v1/workflows/schema.py +++ b/src/dioptra/restapi/v1/workflows/schema.py @@ -19,6 +19,8 @@ from marshmallow import Schema, fields +from dioptra.restapi.entrypoints import EntrypointParameterSchema + class FileTypes(Enum): TAR_GZ = "tar_gz" @@ -41,3 +43,26 @@ class JobFilesDownloadQueryParametersSchema(Schema): by_value=True, default=FileTypes.TAR_GZ.value, ) + + +class EntrypointWorkflowSchema(Schema): + """The YAML that represents the Entrypoint Workflow.""" + + taskGraph = fields.String( + attribute="task_graph", + metadata=dict(description="Task graph of the Entrypoint resource."), + required=True, + ) + pluginIds = fields.List( + fields.Integer(), + attribute="plugin_ids", + data_key="plugins", + metadata=dict(description="List of plugin files for the entrypoint."), + load_only=True, + ) + parameters = fields.Nested( + EntrypointParameterSchema, + attribute="parameters", + many=True, + metadata=dict(description="List of parameters for the entrypoint."), + ) From b2b928656a6b08264efe5ed7e5879090b0a072ca Mon Sep 17 00:00:00 2001 From: Andrew Hand Date: Tue, 10 Sep 2024 11:41:12 -0400 Subject: [PATCH 03/12] build(restapi): Add workflow service for entrypoint workflow validation --- src/dioptra/restapi/v1/workflows/service.py | 65 +++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index 3150c831d..ec887a06f 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -19,10 +19,16 @@ import structlog from structlog.stdlib import BoundLogger +from injector import inject from .lib import package_job_files, views from .schema import FileTypes +from dioptra.restapi.db import db, models +from dioptra.restapi.v1.plugins.service import PluginIdsService +from dioptra.restapi.v1.workflows.lib.export_task_engine_yaml import extract_tasks +from dioptra.task_engine.validation import validate, is_valid + LOGGER: BoundLogger = structlog.stdlib.get_logger() RESOURCE_TYPE: Final[str] = "workflow" @@ -60,3 +66,62 @@ def get(self, job_id: int, file_type: FileTypes, **kwargs) -> IO[bytes]: file_type=file_type, logger=log, ) + + +class EntrypointValidateService(object): + """""" + + @inject + def __init__( + self, + plugin_ids_service: PluginIdsService, + ) -> None: + """Initialize the entrypoint service. + + All arguments are provided via dependency injection. + + Args: + plugin_ids_service: A PluginIdsService object. + """ + self._plugin_ids_service = plugin_ids_service + + def validate( + self, + task_graph: str, + plugin_ids: list[int], + parameters: dict[str: str]\ + ) -> dict[str, str]: + """Validate a entrypoint workflow before the entrypoint is created. + + Args: + task_graph: + plugin_ids: + parameters: + + Returns: + + + Raises: + + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.debug("Validate a entrypoint workflow", task_graph=task_graph, plugin_ids=plugin_ids, entrypoint_parameters=parameters) + + entry_point_plugin_files = self._plugin_ids_service.get(plugin_ids, error_if_not_found=True) + tasks, parameter_types = extract_tasks(entry_point_plugin_files) + task_engine_dict = { + "types": parameter_types, + "parameters": parameters, + "tasks": tasks, + "graph": task_graph, + } + valid = is_valid(task_engine_dict) + if valid: + return {"status": "Success", "valid": valid} + else: + issues = validate(task_engine_dict) + return { + "status": "Success", + "valid": valid, + "issues": issues, + } \ No newline at end of file From d4990206daa4a9b5b6142c8fdb3ee322fe453431 Mon Sep 17 00:00:00 2001 From: Andrew Hand Date: Tue, 10 Sep 2024 11:42:52 -0400 Subject: [PATCH 04/12] build(restapi): Added kwargs to entrypoint workflow validation service Added kwargs for use of log --- src/dioptra/restapi/v1/workflows/service.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index ec887a06f..d2695664f 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -89,7 +89,8 @@ def validate( self, task_graph: str, plugin_ids: list[int], - parameters: dict[str: str]\ + parameters: dict[str: str], + **kwargs, ) -> dict[str, str]: """Validate a entrypoint workflow before the entrypoint is created. From 74232b9bd10a143d51b8b5e21e3ed0935e4da22c Mon Sep 17 00:00:00 2001 From: Andrew Hand Date: Tue, 10 Sep 2024 12:03:20 -0400 Subject: [PATCH 05/12] tests(restapi): Started tests for entrypoint workflow validation --- tests/unit/restapi/v1/test_workflows.py | 92 +++++++++++++++++++++++++ 1 file changed, 92 insertions(+) create mode 100644 tests/unit/restapi/v1/test_workflows.py diff --git a/tests/unit/restapi/v1/test_workflows.py b/tests/unit/restapi/v1/test_workflows.py new file mode 100644 index 000000000..037dcb207 --- /dev/null +++ b/tests/unit/restapi/v1/test_workflows.py @@ -0,0 +1,92 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +"""Test suite for entrypoint operations. + +This module contains a set of tests that validate the CRUD operations and additional +functionalities for the entrypoint entity. The tests ensure that the entrypoints can be +registered, renamed, deleted, and locked/unlocked as expected through the REST API. +""" +from typing import Any + +from flask.testing import FlaskClient +from flask_sqlalchemy import SQLAlchemy +from werkzeug.test import TestResponse + +from dioptra.restapi.routes import V1_WORKFLOWS_ROUTE, V1_ROOT + +from ..lib import actions, asserts, helpers + + +# -- Actions --------------------------------------------------------------------------- + + +def validate_entrypoint_workflow( + client: FlaskClient, + task_graph: str, + plugin_ids: list[int], + entrypoint_parameters: dict[str, str], +) -> TestResponse: + """""" + payload: dict[str, Any] = { + "taskGraph" : task_graph, + "pluginIds": plugin_ids, + "parameters": entrypoint_parameters, + } + + return client.post( + f"/{V1_ROOT}/{V1_WORKFLOWS_ROUTE}/entrypointValidate", + json=payload, + follow_redirects=True, + ) + + +# -- Assertions ------------------------------------------------------------------------ + + +def assert_entrypoint_workflow_is_valid( + client: FlaskClient, + task_graph: str, + plugin_ids: list[int], + entrypoint_parameters: dict[str, str], + ) -> None: + response = validate_entrypoint_workflow( + client, + task_graph=task_graph, + plugin_ids=plugin_ids, + entrypoint_parameters=entrypoint_parameters, + ) + assert response.status_code == 200 and response.valid == True + + +# -- Tests ----------------------------------------------------------------------------- + + +def test_entrypoint_workflow_validation( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], +) -> None: + """""" + task_graph = "" + plugin_ids = [] + entrypoint_parameters = [] + assert_entrypoint_workflow_is_valid( + client, + task_graph=task_graph, + plugin_ids=plugin_ids, + entrypoint_parameters=entrypoint_parameters, + ) From 610ad5bd92e4cd42932439faac2e5fd72138914a Mon Sep 17 00:00:00 2001 From: Andrew Hand Date: Tue, 10 Sep 2024 13:14:56 -0400 Subject: [PATCH 06/12] build(restapi): Fixed entrypoint workflow service Changed a dependacy from plugin id service to plugin file id service. --- src/dioptra/restapi/v1/workflows/service.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index d2695664f..7f14a015c 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -25,7 +25,7 @@ from .schema import FileTypes from dioptra.restapi.db import db, models -from dioptra.restapi.v1.plugins.service import PluginIdsService +from dioptra.restapi.v1.plugins.service import PluginIdFileIdService from dioptra.restapi.v1.workflows.lib.export_task_engine_yaml import extract_tasks from dioptra.task_engine.validation import validate, is_valid @@ -74,7 +74,7 @@ class EntrypointValidateService(object): @inject def __init__( self, - plugin_ids_service: PluginIdsService, + plugin_file_id_service: PluginIdFileIdService, ) -> None: """Initialize the entrypoint service. @@ -83,7 +83,7 @@ def __init__( Args: plugin_ids_service: A PluginIdsService object. """ - self._plugin_ids_service = plugin_ids_service + self._plugin_file_id_service = plugin_file_id_service def validate( self, @@ -108,8 +108,13 @@ def validate( log: BoundLogger = kwargs.get("log", LOGGER.new()) log.debug("Validate a entrypoint workflow", task_graph=task_graph, plugin_ids=plugin_ids, entrypoint_parameters=parameters) - entry_point_plugin_files = self._plugin_ids_service.get(plugin_ids, error_if_not_found=True) - tasks, parameter_types = extract_tasks(entry_point_plugin_files) + entrypoint_plugin_files = [] + for plugin_id in plugin_ids: + plugin_files, _ = self._plugin_id_file_service.get(plugin_id) + for plugin_file in plugin_files: + entrypoint_plugin_files.append(plugin_file) + + tasks, parameter_types = extract_tasks(entrypoint_plugin_files) task_engine_dict = { "types": parameter_types, "parameters": parameters, From c74542c3d6dd2f1f5687a172c06e969a21b75789 Mon Sep 17 00:00:00 2001 From: Andrew Hand Date: Tue, 10 Sep 2024 13:15:31 -0400 Subject: [PATCH 07/12] tests(restapi): Updates to test workflow --- tests/unit/restapi/v1/test_workflows.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/tests/unit/restapi/v1/test_workflows.py b/tests/unit/restapi/v1/test_workflows.py index 037dcb207..944497079 100644 --- a/tests/unit/restapi/v1/test_workflows.py +++ b/tests/unit/restapi/v1/test_workflows.py @@ -20,6 +20,7 @@ functionalities for the entrypoint entity. The tests ensure that the entrypoints can be registered, renamed, deleted, and locked/unlocked as expected through the REST API. """ +import textwrap from typing import Any from flask.testing import FlaskClient @@ -81,7 +82,13 @@ def test_entrypoint_workflow_validation( auth_account: dict[str, Any], ) -> None: """""" - task_graph = "" + task_graph = textwrap.dedent( + """# my entrypoint graph + graph: + message: + my_entrypoint: $name + """ + ) plugin_ids = [] entrypoint_parameters = [] assert_entrypoint_workflow_is_valid( From 46c920b46655588da18738fde13c9371300e6fb3 Mon Sep 17 00:00:00 2001 From: Andrew Hand Date: Fri, 13 Sep 2024 11:50:15 -0400 Subject: [PATCH 08/12] build(restapi): Finalized working entrypoint validation --- .../restapi/v1/workflows/controller.py | 4 +- src/dioptra/restapi/v1/workflows/schema.py | 2 +- src/dioptra/restapi/v1/workflows/service.py | 73 ++++++++++++++----- 3 files changed, 58 insertions(+), 21 deletions(-) diff --git a/src/dioptra/restapi/v1/workflows/controller.py b/src/dioptra/restapi/v1/workflows/controller.py index b8eecd4f6..b6d21a3d5 100644 --- a/src/dioptra/restapi/v1/workflows/controller.py +++ b/src/dioptra/restapi/v1/workflows/controller.py @@ -97,7 +97,7 @@ def __init__( super().__init__(*args, **kwargs) @login_required - @accepts(entrypoint_workflow_schema=EntrypointWorkflowSchema, api=api) + @accepts(schema=EntrypointWorkflowSchema, api=api) def post(self): """Validates the workflow for a entrypoint.""" # noqa: B950 log = LOGGER.new( @@ -110,6 +110,6 @@ def post(self): return self._entrypoint_validate_service.validate( task_graph=task_graph, plugin_ids=plugin_ids, - parameters=parameters, + entrypoint_parameters=parameters, log=log, ) \ No newline at end of file diff --git a/src/dioptra/restapi/v1/workflows/schema.py b/src/dioptra/restapi/v1/workflows/schema.py index df9364e75..7c22aecb3 100644 --- a/src/dioptra/restapi/v1/workflows/schema.py +++ b/src/dioptra/restapi/v1/workflows/schema.py @@ -19,7 +19,7 @@ from marshmallow import Schema, fields -from dioptra.restapi.entrypoints import EntrypointParameterSchema +from dioptra.restapi.v1.entrypoints.schema import EntrypointParameterSchema class FileTypes(Enum): diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index 7f14a015c..0bffd73b3 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -15,19 +15,31 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode """The server-side functions that perform workflows endpoint operations.""" -from typing import IO, Final +from pathlib import Path +from typing import IO, Final, Any, cast import structlog from structlog.stdlib import BoundLogger from injector import inject +import yaml from .lib import package_job_files, views from .schema import FileTypes -from dioptra.restapi.db import db, models -from dioptra.restapi.v1.plugins.service import PluginIdFileIdService +from dioptra.restapi.db import db +from dioptra.restapi.v1 import utils +from dioptra.restapi.v1.plugins.service import PluginIdFileService, PluginIdsService +from dioptra.task_engine.type_registry import BUILTIN_TYPES from dioptra.restapi.v1.workflows.lib.export_task_engine_yaml import extract_tasks -from dioptra.task_engine.validation import validate, is_valid +from dioptra.task_engine.validation import ( + validate, + is_valid, +) +from dioptra.restapi.v1.workflows.lib.export_task_engine_yaml import ( + _build_plugin_field, + _build_task_inputs, + _build_task_outputs, +) LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -74,7 +86,8 @@ class EntrypointValidateService(object): @inject def __init__( self, - plugin_file_id_service: PluginIdFileIdService, + plugin_id_service: PluginIdsService, + plugin_id_file_service: PluginIdFileService, ) -> None: """Initialize the entrypoint service. @@ -83,15 +96,16 @@ def __init__( Args: plugin_ids_service: A PluginIdsService object. """ - self._plugin_file_id_service = plugin_file_id_service + self._plugin_id_service = plugin_id_service + self._plugin_id_file_service = plugin_id_file_service def validate( self, task_graph: str, plugin_ids: list[int], - parameters: dict[str: str], + entrypoint_parameters: list[dict[str, Any]], **kwargs, - ) -> dict[str, str]: + ) -> dict[str, Any]: """Validate a entrypoint workflow before the entrypoint is created. Args: @@ -106,28 +120,51 @@ def validate( """ log: BoundLogger = kwargs.get("log", LOGGER.new()) - log.debug("Validate a entrypoint workflow", task_graph=task_graph, plugin_ids=plugin_ids, entrypoint_parameters=parameters) + log.debug("Validate a entrypoint workflow", task_graph=task_graph, plugin_ids=plugin_ids, entrypoint_parameters=entrypoint_parameters) + + parameters = {param['name']: param['default_value'] for param in entrypoint_parameters} + + tasks: dict[str, Any] = {} + parameter_types: dict[str, Any] = {} + plugins = self._plugin_id_service.get(plugin_ids) + for plugin in plugins: + for plugin_file in plugin['plugin_files']: + # print(plugin_file.tasks) + for task in plugin_file.tasks: + input_parameters = task.input_parameters + output_parameters = task.output_parameters + tasks[task.plugin_task_name] = { + "plugin": _build_plugin_field(plugin['plugin'], plugin_file, task), + } + if input_parameters: + tasks[task.plugin_task_name]["inputs"] = _build_task_inputs( + input_parameters + ) + if output_parameters: + tasks[task.plugin_task_name]["outputs"] = _build_task_outputs( + output_parameters + ) + for param in input_parameters + output_parameters: + name = param.parameter_type.name + if name not in BUILTIN_TYPES: + parameter_types[name] = param.parameter_type.structure - entrypoint_plugin_files = [] - for plugin_id in plugin_ids: - plugin_files, _ = self._plugin_id_file_service.get(plugin_id) - for plugin_file in plugin_files: - entrypoint_plugin_files.append(plugin_file) - - tasks, parameter_types = extract_tasks(entrypoint_plugin_files) task_engine_dict = { "types": parameter_types, "parameters": parameters, "tasks": tasks, - "graph": task_graph, + "graph": cast(dict[str, Any], yaml.safe_load(task_graph)), } + print (task_engine_dict) valid = is_valid(task_engine_dict) + if valid: return {"status": "Success", "valid": valid} else: issues = validate(task_engine_dict) + print(issues) return { "status": "Success", "valid": valid, - "issues": issues, + # "issues": issues, } \ No newline at end of file From 8034c52d5dbc1aa8627c32a8864a6780e621d17d Mon Sep 17 00:00:00 2001 From: Andrew Hand Date: Fri, 13 Sep 2024 11:50:43 -0400 Subject: [PATCH 09/12] tests(restapi): Added more to test entrypoint workflow validation --- tests/unit/restapi/v1/test_workflows.py | 58 +++++++++++++++++++++---- 1 file changed, 50 insertions(+), 8 deletions(-) diff --git a/tests/unit/restapi/v1/test_workflows.py b/tests/unit/restapi/v1/test_workflows.py index 944497079..0dff7a59b 100644 --- a/tests/unit/restapi/v1/test_workflows.py +++ b/tests/unit/restapi/v1/test_workflows.py @@ -39,12 +39,12 @@ def validate_entrypoint_workflow( client: FlaskClient, task_graph: str, plugin_ids: list[int], - entrypoint_parameters: dict[str, str], + entrypoint_parameters: list[dict[str, Any]], ) -> TestResponse: """""" payload: dict[str, Any] = { "taskGraph" : task_graph, - "pluginIds": plugin_ids, + "plugins": plugin_ids, "parameters": entrypoint_parameters, } @@ -62,7 +62,7 @@ def assert_entrypoint_workflow_is_valid( client: FlaskClient, task_graph: str, plugin_ids: list[int], - entrypoint_parameters: dict[str, str], + entrypoint_parameters: list[dict[str, Any]], ) -> None: response = validate_entrypoint_workflow( client, @@ -70,7 +70,8 @@ def assert_entrypoint_workflow_is_valid( plugin_ids=plugin_ids, entrypoint_parameters=entrypoint_parameters, ) - assert response.status_code == 200 and response.valid == True + # print(response.get_json()) + assert response.status_code == 200 and response.get_json()['valid'] == True # -- Tests ----------------------------------------------------------------------------- @@ -82,14 +83,55 @@ def test_entrypoint_workflow_validation( auth_account: dict[str, Any], ) -> None: """""" + plugin_response = actions.register_plugin( + client, + name="hello_world", + description="The hello world plugin.", + group_id=auth_account["default_group_id"], + ).get_json() + plugin_file_contents = textwrap.dedent( + """"from dioptra import pyplugs + + @pyplugs.register + def hello_world(name: str) -> str: + return f'Hello, {name}!'" + """ + ) + plugin_file_tasks = [ + { + "name": "hello_world", + "inputParams": [ + { + "name": "name", + "parameterType": 2, + "required": True, + }, + ], + "outputParams": [ + { + "name": "greeting", + "parameterType": 2, + }, + ], + }, + ] + plugin_file_response = actions.register_plugin_file( + client, + plugin_id=plugin_response["id"], + filename="tasks.py", + description="The task plugin file for hello world.", + contents=plugin_file_contents, + tasks = plugin_file_tasks, + ).get_json() task_graph = textwrap.dedent( """# my entrypoint graph - graph: - message: - my_entrypoint: $name + hello_step: + hello_world: + name: $name """ ) - plugin_ids = [] + + plugin_ids = [plugin_response["id"]] entrypoint_parameters = [] assert_entrypoint_workflow_is_valid( client, From 07208b181be3d6cfdbaaf8f173f03dd196150681 Mon Sep 17 00:00:00 2001 From: Andrew Hand Date: Fri, 13 Sep 2024 12:00:32 -0400 Subject: [PATCH 10/12] tests(restapi): Working test case for validate entrypoint workflow --- tests/unit/restapi/v1/test_workflows.py | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/tests/unit/restapi/v1/test_workflows.py b/tests/unit/restapi/v1/test_workflows.py index 0dff7a59b..5c37164a0 100644 --- a/tests/unit/restapi/v1/test_workflows.py +++ b/tests/unit/restapi/v1/test_workflows.py @@ -132,7 +132,13 @@ def hello_world(name: str) -> str: ) plugin_ids = [plugin_response["id"]] - entrypoint_parameters = [] + entrypoint_parameters = [ + { + "name" : "name", + "defaultValue": "User", + "parameterType": "string", + }, + ] assert_entrypoint_workflow_is_valid( client, task_graph=task_graph, From 220689ae9f226146880ca305f7e0b1f133a0ad95 Mon Sep 17 00:00:00 2001 From: Andrew Hand Date: Fri, 4 Oct 2024 10:14:39 -0400 Subject: [PATCH 11/12] tests(restapi): more test for workflow validation errors --- tests/unit/restapi/v1/test_workflows.py | 257 +++++++++++++++++++++++- 1 file changed, 256 insertions(+), 1 deletion(-) diff --git a/tests/unit/restapi/v1/test_workflows.py b/tests/unit/restapi/v1/test_workflows.py index 5c37164a0..d422e6352 100644 --- a/tests/unit/restapi/v1/test_workflows.py +++ b/tests/unit/restapi/v1/test_workflows.py @@ -70,10 +70,26 @@ def assert_entrypoint_workflow_is_valid( plugin_ids=plugin_ids, entrypoint_parameters=entrypoint_parameters, ) - # print(response.get_json()) assert response.status_code == 200 and response.get_json()['valid'] == True +def assert_entrypoint_workflow_has_errors( + client: FlaskClient, + task_graph: str, + plugin_ids: list[int], + entrypoint_parameters: list[dict[str, Any]], + expected_message: str, +) -> None: + response = validate_entrypoint_workflow( + client, + task_graph=task_graph, + plugin_ids=plugin_ids, + entrypoint_parameters=entrypoint_parameters, + ) + # print(response.get_json()['message']) + assert response.status_code == 422 and response.get_json()['message'] == expected_message + + # -- Tests ----------------------------------------------------------------------------- @@ -145,3 +161,242 @@ def hello_world(name: str) -> str: plugin_ids=plugin_ids, entrypoint_parameters=entrypoint_parameters, ) + + +def test_entrypoint_workflow_validation_has_semantic_error( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], +) -> None: + """""" + plugin_response = actions.register_plugin( + client, + name="hello_world", + description="The hello world plugin.", + group_id=auth_account["default_group_id"], + ).get_json() + plugin_file_contents = textwrap.dedent( + """"from dioptra import pyplugs + + @pyplugs.register + def hello_world(name: str) -> str: + return f'Hello, {name}!'" + """ + ) + plugin_file_tasks = [ + { + "name": "hello_world", + "inputParams": [ + { + "name": "name", + "parameterType": 2, + "required": True, + }, + ], + "outputParams": [ + { + "name": "greeting", + "parameterType": 2, + }, + ], + }, + ] + plugin_file_response = actions.register_plugin_file( + client, + plugin_id=plugin_response["id"], + filename="tasks.py", + description="The task plugin file for hello world.", + contents=plugin_file_contents, + tasks = plugin_file_tasks, + ).get_json() + task_graph = textwrap.dedent( + """# my entrypoint graph + hello_step: + hello_wrld: + name: $name + """ + ) # task graph is wrong, hello_wrld is not the task plugin + + plugin_ids = [plugin_response["id"]] + entrypoint_parameters = [ + { + "name" : "name", + "defaultValue": "User", + "parameterType": "string", + }, + ] + expected_message = "[ValidationIssue(IssueType.SEMANTIC, IssueSeverity.ERROR, 'In step \"hello_step\": unrecognized task plugin: hello_wrld')]" + assert_entrypoint_workflow_has_errors( + client, + task_graph=task_graph, + plugin_ids=plugin_ids, + entrypoint_parameters=entrypoint_parameters, + expected_message=expected_message, + ) + + +def test_entrypoint_workflow_validation_has_schema_error( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], +) -> None: + """""" + plugin_response = actions.register_plugin( + client, + name="hello_world", + description="The hello world plugin.", + group_id=auth_account["default_group_id"], + ).get_json() + plugin_file_contents = textwrap.dedent( + """"from dioptra import pyplugs + + @pyplugs.register + def hello_world(name: str) -> str: + return f'Hello, {name}!'" + """ + ) + plugin_file_tasks = [ + { + "name": "hello_world", + "inputtParams": [ + { + "name": "name", + "parameterType": 2, + "required": True, + }, + ], + "outputParams": [ + { + "name": "greeting", + "parameterType": 2, + }, + ], + }, + ] # plugin file tasks is wrong, inputtParams is not an accepted feild + plugin_file_response = actions.register_plugin_file( + client, + plugin_id=plugin_response["id"], + filename="tasks.py", + description="The task plugin file for hello world.", + contents=plugin_file_contents, + tasks = plugin_file_tasks, + ).get_json() + task_graph = textwrap.dedent( + """# my entrypoint graph + hello_step: + hello_world: + name: $name + """ + ) + + plugin_ids = [plugin_response["id"]] + entrypoint_parameters = [ + { + "name" : "name", + "defaultValue": "User", + "parameterType": "string", + }, + ] + expected_message = "[ValidationIssue(IssueType.SCHEMA, IssueSeverity.ERROR, 'In tasks section: {} should be non-empty')]" + assert_entrypoint_workflow_has_errors( + client, + task_graph=task_graph, + plugin_ids=plugin_ids, + entrypoint_parameters=entrypoint_parameters, + expected_message=expected_message, + ) + + +def test_entrypoint_workflow_validation_has_type_error( + client: FlaskClient, + db: SQLAlchemy, + auth_account: dict[str, Any], +) -> None: + """""" + plugin_response = actions.register_plugin( + client, + name="hello_world", + description="The hello world plugin.", + group_id=auth_account["default_group_id"], + ).get_json() + plugin_file_contents = textwrap.dedent( + """"from dioptra import pyplugs + + @pyplugs.register + def hello_world(name: str) -> str: + print(f'Hello, {name}!') + return name + + @pyplugs.register + def goodbye(name_: str) -> str: + return f'Goodbye, {name_}!' + """ + ) + plugin_file_tasks = [ + { + "name": "hello_world", + "inputParams": [ + { + "name": "name", + "parameterType": 2, + "required": True, + }, + ], + "outputParams": [ + { + "name": "name_", + "parameterType": 2, + }, + ], + }, + { + "name": "goodbye", + "inputParams": [ + { + "name": "name_", + "parameterType": 2, + "required": True, + }, + ], + "outputParams": [ + { + "name": "goodbye", + "parameterType": 2, + }, + ], + }, + ] + plugin_file_response = actions.register_plugin_file( + client, + plugin_id=plugin_response["id"], + filename="tasks.py", + description="The task plugin file for hello world.", + contents=plugin_file_contents, + tasks = plugin_file_tasks, + ).get_json() + task_graph = textwrap.dedent( + """# my entrypoint graph + hello_step: + hello_world: + name: $name + goodbye: + greeting: $greeting + """ + ) + + plugin_ids = [plugin_response["id"]] + entrypoint_parameters = [ + { + "name" : "name", + "defaultValue": 2, + "parameterType": "string", + }, + ] + expected_message = "[ValidationIssue(IssueType.SCHEMA, IssueSeverity.ERROR, 'In tasks section: {} should be non-empty')]" + assert_entrypoint_workflow_has_errors( + client, + task_graph=task_graph, + plugin_ids=plugin_ids, + entrypoint_parameters=entrypoint_parameters, + expected_message=expected_message, + ) \ No newline at end of file From 58831247bb3b74f9ed02bc280e71d5ca627ab1bf Mon Sep 17 00:00:00 2001 From: Andrew Hand Date: Fri, 4 Oct 2024 10:15:00 -0400 Subject: [PATCH 12/12] build(restapi): add errors to workflow validation --- src/dioptra/restapi/errors.py | 1 + src/dioptra/restapi/v1/workflows/errors.py | 11 +++++++++++ src/dioptra/restapi/v1/workflows/service.py | 11 +++-------- 3 files changed, 15 insertions(+), 8 deletions(-) diff --git a/src/dioptra/restapi/errors.py b/src/dioptra/restapi/errors.py index ba72c03a5..1ba834aaf 100644 --- a/src/dioptra/restapi/errors.py +++ b/src/dioptra/restapi/errors.py @@ -104,3 +104,4 @@ def register_error_handlers(api: Api) -> None: v1.queues.errors.register_error_handlers(api) v1.tags.errors.register_error_handlers(api) v1.users.errors.register_error_handlers(api) + v1.workflows.errors.register_error_handlers(api) diff --git a/src/dioptra/restapi/v1/workflows/errors.py b/src/dioptra/restapi/v1/workflows/errors.py index c5723ec4a..e482bb05b 100644 --- a/src/dioptra/restapi/v1/workflows/errors.py +++ b/src/dioptra/restapi/v1/workflows/errors.py @@ -28,6 +28,10 @@ class JobExperimentDoesNotExistError(Exception): """The experiment associated with the job does not exist.""" +class EntrypointWorkflowValidationIssue(Exception): + """The entrypoint workflow yaml has issues.""" + + def register_error_handlers(api: Api) -> None: @api.errorhandler(JobEntryPointDoesNotExistError) def handle_experiment_job_does_not_exist_error(error): @@ -39,3 +43,10 @@ def handle_experiment_does_not_exist_error(error): "message": "Not Found - The experiment associated with the job does not " "exist" }, 404 + + @api.errorhandler(EntrypointWorkflowValidationIssue) + def handle_entrypoint_workflow_validation_error(error): + issues = error.args + print (issues) + message = "\n".join(str(issue) for issue in issues) + return {"message": message}, 422 \ No newline at end of file diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index 0bffd73b3..deb6ea009 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -25,6 +25,7 @@ from .lib import package_job_files, views from .schema import FileTypes +from .errors import EntrypointWorkflowValidationIssue from dioptra.restapi.db import db from dioptra.restapi.v1 import utils @@ -129,7 +130,6 @@ def validate( plugins = self._plugin_id_service.get(plugin_ids) for plugin in plugins: for plugin_file in plugin['plugin_files']: - # print(plugin_file.tasks) for task in plugin_file.tasks: input_parameters = task.input_parameters output_parameters = task.output_parameters @@ -155,16 +155,11 @@ def validate( "tasks": tasks, "graph": cast(dict[str, Any], yaml.safe_load(task_graph)), } - print (task_engine_dict) valid = is_valid(task_engine_dict) if valid: return {"status": "Success", "valid": valid} else: issues = validate(task_engine_dict) - print(issues) - return { - "status": "Success", - "valid": valid, - # "issues": issues, - } \ No newline at end of file + log.debug("Entrypoint workflow validation failed", issues=issues) + raise EntrypointWorkflowValidationIssue(issues) \ No newline at end of file