Skip to content

Commit

Permalink
fix: fixes the default value to be of type list
Browse files Browse the repository at this point in the history
Closes HEXA-1037 : Default values for parameters set as "multiple" fails on run of the pipeline
  • Loading branch information
nazarfil committed Sep 26, 2024
1 parent 80a73be commit b76cdab
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 6 deletions.
11 changes: 9 additions & 2 deletions openhexa/sdk/pipelines/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,8 +487,15 @@ def _validate_default(self, default: typing.Any, multiple: bool):
except ParameterValueError:
raise InvalidParameterError(f"The default value for {self.code} is not valid.")

if self.choices is not None and default not in self.choices:
raise InvalidParameterError(f"The default value for {self.code} is not included in the provided choices.")
if self.choices is not None:
if isinstance(default, list):
if not all(d in self.choices for d in default):
raise InvalidParameterError(
f"The default list of values for {self.code} is not included in the provided choices."
)
elif default not in self.choices:
raise InvalidParameterError(f"The default value for {self.code} is not included in the provided choices.")


def parameter_spec(self) -> dict[str, typing.Any]:
"""Build specification for this parameter, to be provided to the OpenHEXA backend."""
Expand Down
10 changes: 7 additions & 3 deletions openhexa/sdk/pipelines/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,12 @@ class PipelineParameterSpecs:

def __post_init__(self):
"""Validate the parameter and set default values."""
if self.default and self.choices and self.default not in self.choices:
raise ValueError(f"Default value '{self.default}' not in choices {self.choices}")
if self.default and self.choices:
if isinstance(self.default, list):
if not all(d in self.choices for d in self.default):
raise ValueError(f"Default list of values {self.default} not in choices {self.choices}")
elif self.default not in self.choices:
raise ValueError(f"Default value '{self.default}' not in choices {self.choices}")
validate_pipeline_parameter_code(self.code)
if self.required is None:
self.required = True
Expand Down Expand Up @@ -180,7 +184,7 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs:
Argument("name", [ast.Constant]),
Argument("choices", [ast.List]),
Argument("help", [ast.Constant]),
Argument("default", [ast.Constant]),
Argument("default", [ast.Constant, ast.List]),
Argument("required", [ast.Constant]),
Argument("multiple", [ast.Constant]),
),
Expand Down
1 change: 1 addition & 0 deletions openhexa/utils/stringcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
Coming from https://github.com/okunishinishi/python-stringcase
"""

import re


Expand Down
2 changes: 1 addition & 1 deletion tests/test_parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,7 @@ def test_parameter_validate_multiple():
assert parameter_3.validate([]) == []

# choices
parameter_4 = Parameter("arg4", type=str, choices=["ab", "cd"], multiple=True)
parameter_4 = Parameter("arg4", type=str, default=["ab"], choices=["ab", "cd"], multiple=True)
assert parameter_4.validate(["ab"]) == ["ab"]
assert parameter_4.validate(["ab", "cd"]) == ["ab", "cd"]
with pytest.raises(ParameterValueError):
Expand Down

0 comments on commit b76cdab

Please sign in to comment.