From b76cdabc56284855f305d97134540f6107dc5c0e Mon Sep 17 00:00:00 2001 From: nazarfil Date: Mon, 23 Sep 2024 14:17:28 +0200 Subject: [PATCH] fix: fixes the default value to be of type list Closes HEXA-1037 : Default values for parameters set as "multiple" fails on run of the pipeline --- openhexa/sdk/pipelines/parameter.py | 11 +++++++++-- openhexa/sdk/pipelines/runtime.py | 10 +++++++--- openhexa/utils/stringcase.py | 1 + tests/test_parameter.py | 2 +- 4 files changed, 18 insertions(+), 6 deletions(-) diff --git a/openhexa/sdk/pipelines/parameter.py b/openhexa/sdk/pipelines/parameter.py index 203e217..862ee23 100644 --- a/openhexa/sdk/pipelines/parameter.py +++ b/openhexa/sdk/pipelines/parameter.py @@ -487,8 +487,15 @@ def _validate_default(self, default: typing.Any, multiple: bool): except ParameterValueError: raise InvalidParameterError(f"The default value for {self.code} is not valid.") - if self.choices is not None and default not in self.choices: - raise InvalidParameterError(f"The default value for {self.code} is not included in the provided choices.") + if self.choices is not None: + if isinstance(default, list): + if not all(d in self.choices for d in default): + raise InvalidParameterError( + f"The default list of values for {self.code} is not included in the provided choices." + ) + elif default not in self.choices: + raise InvalidParameterError(f"The default value for {self.code} is not included in the provided choices.") + def parameter_spec(self) -> dict[str, typing.Any]: """Build specification for this parameter, to be provided to the OpenHEXA backend.""" diff --git a/openhexa/sdk/pipelines/runtime.py b/openhexa/sdk/pipelines/runtime.py index 9b7d54e..ee70a3c 100644 --- a/openhexa/sdk/pipelines/runtime.py +++ b/openhexa/sdk/pipelines/runtime.py @@ -36,8 +36,12 @@ class PipelineParameterSpecs: def __post_init__(self): """Validate the parameter and set default values.""" - if self.default and self.choices and self.default not in self.choices: - raise ValueError(f"Default value '{self.default}' not in choices {self.choices}") + if self.default and self.choices: + if isinstance(self.default, list): + if not all(d in self.choices for d in self.default): + raise ValueError(f"Default list of values {self.default} not in choices {self.choices}") + elif self.default not in self.choices: + raise ValueError(f"Default value '{self.default}' not in choices {self.choices}") validate_pipeline_parameter_code(self.code) if self.required is None: self.required = True @@ -180,7 +184,7 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs: Argument("name", [ast.Constant]), Argument("choices", [ast.List]), Argument("help", [ast.Constant]), - Argument("default", [ast.Constant]), + Argument("default", [ast.Constant, ast.List]), Argument("required", [ast.Constant]), Argument("multiple", [ast.Constant]), ), diff --git a/openhexa/utils/stringcase.py b/openhexa/utils/stringcase.py index e79a274..f2acdbe 100644 --- a/openhexa/utils/stringcase.py +++ b/openhexa/utils/stringcase.py @@ -2,6 +2,7 @@ Coming from https://github.com/okunishinishi/python-stringcase """ + import re diff --git a/tests/test_parameter.py b/tests/test_parameter.py index afc3d09..6cb8efc 100644 --- a/tests/test_parameter.py +++ b/tests/test_parameter.py @@ -307,7 +307,7 @@ def test_parameter_validate_multiple(): assert parameter_3.validate([]) == [] # choices - parameter_4 = Parameter("arg4", type=str, choices=["ab", "cd"], multiple=True) + parameter_4 = Parameter("arg4", type=str, default=["ab"], choices=["ab", "cd"], multiple=True) assert parameter_4.validate(["ab"]) == ["ab"] assert parameter_4.validate(["ab", "cd"]) == ["ab", "cd"] with pytest.raises(ParameterValueError):