From e5e356c8dd0d07548fbfc32d6cc2cd0ac39d322b Mon Sep 17 00:00:00 2001 From: nazarfil Date: Sun, 29 Sep 2024 20:05:54 +0200 Subject: [PATCH] fix: adds requried --- openhexa/cli/api.py | 6 +-- openhexa/cli/cli.py | 10 ++--- openhexa/sdk/pipelines/parameter.py | 15 ++++++++ openhexa/sdk/pipelines/pipeline.py | 8 ++++ openhexa/sdk/pipelines/runtime.py | 45 +++++++++++------------ tests/test_ast.py | 57 ++++++++++++++++------------- 6 files changed, 83 insertions(+), 58 deletions(-) diff --git a/openhexa/cli/api.py b/openhexa/cli/api.py index 24685a6..24994eb 100644 --- a/openhexa/cli/api.py +++ b/openhexa/cli/api.py @@ -21,7 +21,7 @@ from openhexa.cli.settings import settings from openhexa.sdk.pipelines import get_local_workspace_config -from openhexa.sdk.pipelines.runtime import get_pipeline_metadata +from openhexa.sdk.pipelines.runtime import get_pipeline from openhexa.utils import create_requests_session, stringcase @@ -195,7 +195,7 @@ def list_pipelines(): return data["pipelines"]["items"] -def get_pipeline(pipeline_code: str) -> dict[str, typing.Any]: +def get_pipeline_from_code(pipeline_code: str) -> dict[str, typing.Any]: """Get a single pipeline.""" if settings.current_workspace is None: raise NoActiveWorkspaceError @@ -543,7 +543,7 @@ def upload_pipeline( raise NoActiveWorkspaceError directory = pipeline_directory_path.absolute() - pipeline = get_pipeline_metadata(directory) + pipeline = get_pipeline(directory) zip_file = generate_zip_file(directory) if settings.debug: diff --git a/openhexa/cli/cli.py b/openhexa/cli/cli.py index 6ff05bc..30cf788 100644 --- a/openhexa/cli/cli.py +++ b/openhexa/cli/cli.py @@ -21,7 +21,7 @@ download_pipeline_sourcecode, ensure_is_pipeline_dir, get_library_versions, - get_pipeline, + get_pipeline_from_code, get_workspace, list_pipelines, run_pipeline, @@ -29,7 +29,7 @@ ) from openhexa.cli.settings import settings, setup_logging from openhexa.sdk.pipelines.exceptions import PipelineNotFound -from openhexa.sdk.pipelines.runtime import get_pipeline_metadata +from openhexa.sdk.pipelines.runtime import get_pipeline def validate_url(ctx, param, value): @@ -283,7 +283,7 @@ def pipelines_push( ensure_is_pipeline_dir(path) try: - pipeline = get_pipeline_metadata(path) + pipeline = get_pipeline(path) except PipelineNotFound: _terminate( f"❌ No function with openhexa.sdk pipeline decorator found in {click.style(path, bold=True)}.", @@ -296,7 +296,7 @@ def pipelines_push( if settings.debug: click.echo(workspace_pipelines) - if get_pipeline(pipeline.code) is None: + if get_pipeline_from_code(pipeline.code) is None: click.echo( f"Pipeline {click.style(pipeline.code, bold=True)} does not exist in workspace {click.style(workspace, bold=True)}" ) @@ -374,7 +374,7 @@ def pipelines_delete(code: str): err=True, ) else: - pipeline = get_pipeline(code) + pipeline = get_pipeline_from_code(code) if pipeline is None: _terminate( f"❌ Pipeline {click.style(code, bold=True)} does not exist in workspace {click.style(settings.current_workspace, bold=True)}" diff --git a/openhexa/sdk/pipelines/parameter.py b/openhexa/sdk/pipelines/parameter.py index 2d5e65e..1a902df 100644 --- a/openhexa/sdk/pipelines/parameter.py +++ b/openhexa/sdk/pipelines/parameter.py @@ -429,6 +429,21 @@ def validate(self, value: typing.Any) -> typing.Any: else: return self._validate_single(value) + def to_dict(self) -> dict[str, typing.Any]: + """Return a dictionary representation of the Parameter instance.""" + print(f"Multiple : {self.required}, required: {self.required}") + return { + "code": self.code, + "type": self.type.spec_type, + "name": self.name, + "choices": self.choices, + "help": self.help, + "default": self.default, + "required": self.required, + "multiple": self.multiple, + } + + def _validate_single(self, value: typing.Any): # Normalize empty values to None and handles default normalized_value = self.type.normalize(value) diff --git a/openhexa/sdk/pipelines/pipeline.py b/openhexa/sdk/pipelines/pipeline.py index b9b3a6c..be03834 100644 --- a/openhexa/sdk/pipelines/pipeline.py +++ b/openhexa/sdk/pipelines/pipeline.py @@ -167,6 +167,14 @@ def parameters_spec(self) -> list[dict[str, typing.Any]]: """Return the individual specifications of all the parameters of this pipeline.""" return [arg.parameter_spec() for arg in self.parameters] + def to_dict(self): + return { + "code": self.code, + "name": self.name, + "parameters": [p.to_dict() for p in self.parameters], + "timeout": self.timeout, + } + def _get_available_tasks(self) -> list[Task]: return [task for task in self.tasks if task.is_ready()] diff --git a/openhexa/sdk/pipelines/runtime.py b/openhexa/sdk/pipelines/runtime.py index ee70a3c..780eafd 100644 --- a/openhexa/sdk/pipelines/runtime.py +++ b/openhexa/sdk/pipelines/runtime.py @@ -15,7 +15,7 @@ import requests from openhexa.sdk.pipelines.exceptions import InvalidParameterError, PipelineNotFound -from openhexa.sdk.pipelines.parameter import TYPES_BY_PYTHON_TYPE +from openhexa.sdk.pipelines.parameter import TYPES_BY_PYTHON_TYPE, Parameter from openhexa.sdk.pipelines.utils import validate_pipeline_parameter_code from .pipeline import Pipeline @@ -57,16 +57,6 @@ class Argument: types: list[typing.Any] = field(default_factory=list) -@dataclass -class PipelineSpecs: - """Specification of a pipeline.""" - - code: string - name: string - timeout: int = None - parameters: list[PipelineParameterSpecs] = field(default_factory=list) - - def import_pipeline(pipeline_dir_path: str): """Import pipeline code within provided path using importlib.""" pipeline_dir = os.path.abspath(pipeline_dir_path) @@ -131,7 +121,7 @@ def _get_decorator_arg_value(decorator, arg: Argument, index: int): return None -def _get_decorator_spec(decorator, args: tuple[Argument], key=None): +def _get_decorator_spec(decorator, args: tuple[Argument]): d = {"name": decorator.func.id, "args": {}} for i, arg in enumerate(args): @@ -140,8 +130,8 @@ def _get_decorator_spec(decorator, args: tuple[Argument], key=None): return d -def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs: - """Return the pipeline metadata from the pipeline code. +def get_pipeline(pipeline_path: Path) -> Pipeline: + """Return the pipeline with metadata and parameters from the pipeline code. Args: pipeline_path (Path): Path to the pipeline directory @@ -154,7 +144,7 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs: Returns ------- - typing.Tuple[PipelineSpecs, typing.List[PipelineParameterSpecs]]: A tuple containing the pipeline specs and the list of parameters specs. + Pipeline: The pipeline object with parameters and metadata. """ tree = ast.parse(open(Path(pipeline_path) / "pipeline.py").read()) pipeline = None @@ -174,7 +164,7 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs: Argument("timeout", [ast.Constant]), ), ) - pipeline = PipelineSpecs(**pipeline_decorator_spec["args"]) + pipelines_parameters = [] for parameter_decorator in _get_decorators_by_name(node, "parameter"): param_decorator_spec = _get_decorator_spec( parameter_decorator, @@ -189,14 +179,21 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs: Argument("multiple", [ast.Constant]), ), ) - try: - args = param_decorator_spec["args"] - inst = TYPES_BY_PYTHON_TYPE[args["type"]]() - args["type"] = inst.spec_type - - pipeline.parameters.append(PipelineParameterSpecs(**args)) - except KeyError: - raise InvalidParameterError(f"Invalid parameter type {args['type']}") + args = param_decorator_spec["args"] + inst = TYPES_BY_PYTHON_TYPE[args["type"]]() + args["type"] = inst.expected_type + parameter = Parameter( + code=args.get("code"), + name=args.get("name"), + type=args.get("type"), + choices=args.get("choices"), + help=args.get("help"), + default=args.get("default"), + required=args.get("required") if args.get("required") is not None else True, + multiple=args.get("multiple") if args.get("multiple") is not None else False,) + pipelines_parameters.append(parameter) + + pipeline = Pipeline(parameters=pipelines_parameters, function=None, **pipeline_decorator_spec["args"]) if pipeline is None: raise PipelineNotFound("No function with openhexa.sdk pipeline decorator found.") diff --git a/tests/test_ast.py b/tests/test_ast.py index 6216f88..4c54745 100644 --- a/tests/test_ast.py +++ b/tests/test_ast.py @@ -1,11 +1,12 @@ """Tests related to the parsing of the pipeline code.""" +import json import tempfile from dataclasses import asdict from unittest import TestCase from openhexa.sdk.pipelines.exceptions import InvalidParameterError, PipelineNotFound -from openhexa.sdk.pipelines.runtime import get_pipeline_metadata +from openhexa.sdk.pipelines.runtime import get_pipeline class AstTest(TestCase): @@ -18,7 +19,7 @@ def test_pipeline_not_found(self): f.write("print('hello')") with self.assertRaises(PipelineNotFound): - get_pipeline_metadata(tmpdirname) + get_pipeline(tmpdirname) def test_pipeline_no_parameters(self): """The file contains a @pipeline decorator but no parameters.""" @@ -35,9 +36,9 @@ def test_pipeline_no_parameters(self): ] ) ) - pipeline = get_pipeline_metadata(tmpdirname) + pipeline = get_pipeline(tmpdirname) self.assertEqual( - asdict(pipeline), {"code": "test", "name": "Test pipeline", "parameters": [], "timeout": None} + pipeline.to_dict(), {"code": "test", "name": "Test pipeline", "parameters": [], "timeout": None} ) def test_pipeline_with_args(self): @@ -55,9 +56,9 @@ def test_pipeline_with_args(self): ] ) ) - pipeline = get_pipeline_metadata(tmpdirname) + pipeline = get_pipeline(tmpdirname) self.assertEqual( - asdict(pipeline), {"code": "test", "name": "Test pipeline", "parameters": [], "timeout": None} + pipeline.to_dict(), {"code": "test", "name": "Test pipeline", "parameters": [], "timeout": None} ) def test_pipeline_with_invalid_parameter_args(self): @@ -77,7 +78,7 @@ def test_pipeline_with_invalid_parameter_args(self): ) ) with self.assertRaises(ValueError): - get_pipeline_metadata(tmpdirname) + get_pipeline(tmpdirname) def test_pipeline_with_invalid_pipeline_args(self): """The file contains a @pipeline decorator with invalid value.""" @@ -97,7 +98,7 @@ def test_pipeline_with_invalid_pipeline_args(self): ) ) with self.assertRaises(ValueError): - get_pipeline_metadata(tmpdirname) + get_pipeline(tmpdirname) def test_pipeline_with_int_param(self): """The file contains a @pipeline decorator and a @parameter decorator with an int.""" @@ -116,9 +117,9 @@ def test_pipeline_with_int_param(self): ] ) ) - pipeline = get_pipeline_metadata(tmpdirname) + pipeline = get_pipeline(tmpdirname) self.assertEqual( - asdict(pipeline), + pipeline.to_dict(), { "code": "test", "name": "Test pipeline", @@ -147,7 +148,7 @@ def test_pipeline_with_multiple_param(self): [ "from openhexa.sdk.pipelines import pipeline, parameter", "", - "@parameter('test_param', name='Test Param', type=int, default=42, help='Param help', multiple=True)", + "@parameter('test_param', name='Test Param', type=int, default=[42], help='Param help', multiple=True)", "@pipeline('test', 'Test pipeline')", "def test_pipeline():", " pass", @@ -155,9 +156,9 @@ def test_pipeline_with_multiple_param(self): ] ) ) - pipeline = get_pipeline_metadata(tmpdirname) + pipeline = get_pipeline(tmpdirname) self.assertEqual( - asdict(pipeline), + pipeline.to_dict(), { "code": "test", "name": "Test pipeline", @@ -168,7 +169,7 @@ def test_pipeline_with_multiple_param(self): "code": "test_param", "type": "int", "name": "Test Param", - "default": 42, + "default": [42], "help": "Param help", "required": True, } @@ -195,12 +196,14 @@ def test_pipeline_with_dataset(self): ] ) ) - pipeline = get_pipeline_metadata(tmpdirname) + pipeline = get_pipeline(tmpdirname) self.assertEqual( - asdict(pipeline), + pipeline.to_dict(), { "code": "test", + "function": None, "name": "Test pipeline", + "tasks": [], "parameters": [ { "choices": None, @@ -234,9 +237,9 @@ def test_pipeline_with_choices(self): ] ) ) - pipeline = get_pipeline_metadata(tmpdirname) + pipeline = get_pipeline(tmpdirname) self.assertEqual( - asdict(pipeline), + pipeline.to_dict(), { "code": "test", "name": "Test pipeline", @@ -271,9 +274,11 @@ def test_pipeline_with_timeout(self): ] ) ) - pipeline = get_pipeline_metadata(tmpdirname) + pipeline = get_pipeline(tmpdirname) self.assertEqual( - asdict(pipeline), {"code": "test", "name": "Test pipeline", "parameters": [], "timeout": 42} + pipeline.to_dict(), + {"code": "test", "function": None, "name": "Test pipeline", "parameters": [], "timeout": 42, + "tasks": []} ) def test_pipeline_with_bool(self): @@ -293,9 +298,9 @@ def test_pipeline_with_bool(self): ] ) ) - pipeline = get_pipeline_metadata(tmpdirname) + pipeline = get_pipeline(tmpdirname) self.assertEqual( - asdict(pipeline), + pipeline.to_dict(), { "code": "test", "name": "Test pipeline", @@ -333,9 +338,9 @@ def test_pipeline_with_multiple_parameters(self): ] ) ) - pipeline = get_pipeline_metadata(tmpdirname) + pipeline = get_pipeline(tmpdirname) self.assertEqual( - asdict(pipeline), + pipeline.to_dict(), { "code": "test", "name": "Test pipeline", @@ -382,5 +387,5 @@ def test_pipeline_with_unsupported_parameter(self): ] ) ) - with self.assertRaises(InvalidParameterError): - get_pipeline_metadata(tmpdirname) + with self.assertRaises(KeyError): + get_pipeline(tmpdirname)