Skip to content

Commit

Permalink
Merge branch 'v1-entrypoint-workflow-yaml-validation' of https://gith…
Browse files Browse the repository at this point in the history
…ub.com/usnistgov/dioptra into v1-entrypoint-workflow-yaml-validation
  • Loading branch information
andrewhand committed Oct 21, 2024
2 parents fbbe3b8 + 4c9a5a8 commit 6ae8a80
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 0 deletions.
34 changes: 34 additions & 0 deletions src/dioptra/restapi/v1/workflows/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,40 @@ def get(self):
)


@api.route("/entrypointValidate")
class EntrypointValidateEndpoint(Resource):
@inject
def __init__(
self, entrypoint_validate_service: EntrypointValidateService, *args, **kwargs
) -> None:
"""Initialize the workflow resource.
All arguments are provided via dependency injection.
Args:
entrypoint_validate_service: A EntrypointValidateService object.
"""
self._entrypoint_validate_service = entrypoint_validate_service
super().__init__(*args, **kwargs)

@login_required
@accepts(schema=EntrypointWorkflowSchema, api=api)
def post(self):
"""Validates the workflow for a entrypoint.""" # noqa: B950
log = LOGGER.new(
request_id=str(uuid.uuid4()), resource="Workflows", request_type="POST"
)
parsed_obj = request.parsed_obj # type: ignore
task_graph = parsed_obj["task_graph"]
plugin_ids = parsed_obj["plugin_ids"]
parameters = parsed_obj["parameters"]
return self._entrypoint_validate_service.validate(
task_graph=task_graph,
plugin_ids=plugin_ids,
entrypoint_parameters=parameters,
log=log,
)

@api.route("/entrypointValidate")
class EntrypointValidateEndpoint(Resource):
@inject
Expand Down
25 changes: 25 additions & 0 deletions src/dioptra/restapi/v1/workflows/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@

from dioptra.restapi.v1.entrypoints.schema import EntrypointParameterSchema

from dioptra.restapi.v1.entrypoints.schema import EntrypointParameterSchema


class FileTypes(Enum):
TAR_GZ = "tar_gz"
Expand Down Expand Up @@ -89,3 +91,26 @@ class EntrypointWorkflowSchema(Schema):
many=True,
metadata=dict(description="List of parameters for the entrypoint."),
)


class EntrypointWorkflowSchema(Schema):
"""The YAML that represents the Entrypoint Workflow."""

taskGraph = fields.String(
attribute="task_graph",
metadata=dict(description="Task graph of the Entrypoint resource."),
required=True,
)
pluginIds = fields.List(
fields.Integer(),
attribute="plugin_ids",
data_key="plugins",
metadata=dict(description="List of plugin files for the entrypoint."),
load_only=True,
)
parameters = fields.Nested(
EntrypointParameterSchema,
attribute="parameters",
many=True,
metadata=dict(description="List of parameters for the entrypoint."),
)

0 comments on commit 6ae8a80

Please sign in to comment.