Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: fixes the default value to be of type list #211

Merged
merged 8 commits into from
Oct 2, 2024
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions openhexa/cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import os
import tempfile
import typing
from dataclasses import asdict
from importlib.metadata import version
from pathlib import Path
from zipfile import ZipFile
Expand All @@ -21,7 +20,7 @@

from openhexa.cli.settings import settings
from openhexa.sdk.pipelines import get_local_workspace_config
from openhexa.sdk.pipelines.runtime import get_pipeline_metadata
from openhexa.sdk.pipelines.runtime import get_pipeline
from openhexa.utils import create_requests_session, stringcase


Expand Down Expand Up @@ -195,7 +194,7 @@ def list_pipelines():
return data["pipelines"]["items"]


def get_pipeline(pipeline_code: str) -> dict[str, typing.Any]:
def get_pipeline_from_code(pipeline_code: str) -> dict[str, typing.Any]:
"""Get a single pipeline."""
if settings.current_workspace is None:
raise NoActiveWorkspaceError
Expand Down Expand Up @@ -543,7 +542,7 @@ def upload_pipeline(
raise NoActiveWorkspaceError

directory = pipeline_directory_path.absolute()
pipeline = get_pipeline_metadata(directory)
pipeline = get_pipeline(directory)
zip_file = generate_zip_file(directory)

if settings.debug:
Expand Down Expand Up @@ -574,7 +573,7 @@ def upload_pipeline(
"description": description,
"externalLink": link,
"zipfile": base64_content,
"parameters": [asdict(p) for p in pipeline.parameters],
"parameters": [p.to_dict() for p in pipeline.parameters],
"timeout": pipeline.timeout,
}
},
Expand Down
10 changes: 5 additions & 5 deletions openhexa/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@
download_pipeline_sourcecode,
ensure_is_pipeline_dir,
get_library_versions,
get_pipeline,
get_pipeline_from_code,
get_workspace,
list_pipelines,
run_pipeline,
upload_pipeline,
)
from openhexa.cli.settings import settings, setup_logging
from openhexa.sdk.pipelines.exceptions import PipelineNotFound
from openhexa.sdk.pipelines.runtime import get_pipeline_metadata
from openhexa.sdk.pipelines.runtime import get_pipeline


def validate_url(ctx, param, value):
Expand Down Expand Up @@ -283,7 +283,7 @@ def pipelines_push(
ensure_is_pipeline_dir(path)

try:
pipeline = get_pipeline_metadata(path)
pipeline = get_pipeline(path)
except PipelineNotFound:
_terminate(
f"❌ No function with openhexa.sdk pipeline decorator found in {click.style(path, bold=True)}.",
Expand All @@ -296,7 +296,7 @@ def pipelines_push(
if settings.debug:
click.echo(workspace_pipelines)

if get_pipeline(pipeline.code) is None:
if get_pipeline_from_code(pipeline.code) is None:
click.echo(
f"Pipeline {click.style(pipeline.code, bold=True)} does not exist in workspace {click.style(workspace, bold=True)}"
)
Expand Down Expand Up @@ -374,7 +374,7 @@ def pipelines_delete(code: str):
err=True,
)
else:
pipeline = get_pipeline(code)
pipeline = get_pipeline_from_code(code)
if pipeline is None:
_terminate(
f"❌ Pipeline {click.style(code, bold=True)} does not exist in workspace {click.style(settings.current_workspace, bold=True)}"
Expand Down
25 changes: 23 additions & 2 deletions openhexa/sdk/pipelines/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,19 @@ def validate(self, value: typing.Any) -> typing.Any:
else:
return self._validate_single(value)

def to_dict(self) -> dict[str, typing.Any]:
"""Return a dictionary representation of the Parameter instance."""
return {
"code": self.code,
"type": self.type.spec_type,
"name": self.name,
"choices": self.choices,
"help": self.help,
"default": self.default,
"required": self.required,
"multiple": self.multiple,
}

def _validate_single(self, value: typing.Any):
# Normalize empty values to None and handles default
normalized_value = self.type.normalize(value)
Expand Down Expand Up @@ -487,8 +500,16 @@ def _validate_default(self, default: typing.Any, multiple: bool):
except ParameterValueError:
raise InvalidParameterError(f"The default value for {self.code} is not valid.")

if self.choices is not None and default not in self.choices:
raise InvalidParameterError(f"The default value for {self.code} is not included in the provided choices.")
if self.choices is not None:
if isinstance(default, list):
if not all(d in self.choices for d in default):
raise InvalidParameterError(
f"The default list of values for {self.code} is not included in the provided choices."
)
elif default not in self.choices:
raise InvalidParameterError(
f"The default value for {self.code} is not included in the provided choices."
)

def parameter_spec(self) -> dict[str, typing.Any]:
"""Build specification for this parameter, to be provided to the OpenHEXA backend."""
Expand Down
11 changes: 11 additions & 0 deletions openhexa/sdk/pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,17 @@ def parameters_spec(self) -> list[dict[str, typing.Any]]:
"""Return the individual specifications of all the parameters of this pipeline."""
return [arg.parameter_spec() for arg in self.parameters]

def to_dict(self):
"""Return a dictionary representation of the pipeline."""
return {
"code": self.code,
"name": self.name,
"parameters": [p.to_dict() for p in self.parameters],
"timeout": self.timeout,
"function": self.function.__dict__ if self.function else None,
"tasks": [t.__dict__ for t in self.tasks],
}

def _get_available_tasks(self) -> list[Task]:
return [task for task in self.tasks if task.is_ready()]

Expand Down
75 changes: 24 additions & 51 deletions openhexa/sdk/pipelines/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,37 +14,12 @@

import requests

from openhexa.sdk.pipelines.exceptions import InvalidParameterError, PipelineNotFound
from openhexa.sdk.pipelines.parameter import TYPES_BY_PYTHON_TYPE
from openhexa.sdk.pipelines.utils import validate_pipeline_parameter_code
from openhexa.sdk.pipelines.exceptions import PipelineNotFound
from openhexa.sdk.pipelines.parameter import TYPES_BY_PYTHON_TYPE, Parameter

from .pipeline import Pipeline


@dataclass
class PipelineParameterSpecs:
"""Specification of a pipeline parameter."""

code: string
type: string
name: string
choices: list[typing.Union[str, int, float]]
help: string
default: typing.Any
required: bool = True
multiple: bool = False

def __post_init__(self):
"""Validate the parameter and set default values."""
if self.default and self.choices and self.default not in self.choices:
raise ValueError(f"Default value '{self.default}' not in choices {self.choices}")
validate_pipeline_parameter_code(self.code)
if self.required is None:
self.required = True
if self.multiple is None:
self.multiple = False


@dataclass
class Argument:
"""Argument of a decorator."""
Expand All @@ -53,16 +28,6 @@ class Argument:
types: list[typing.Any] = field(default_factory=list)


@dataclass
class PipelineSpecs:
"""Specification of a pipeline."""

code: string
name: string
timeout: int = None
parameters: list[PipelineParameterSpecs] = field(default_factory=list)


def import_pipeline(pipeline_dir_path: str):
"""Import pipeline code within provided path using importlib."""
pipeline_dir = os.path.abspath(pipeline_dir_path)
Expand Down Expand Up @@ -127,7 +92,7 @@ def _get_decorator_arg_value(decorator, arg: Argument, index: int):
return None


def _get_decorator_spec(decorator, args: tuple[Argument], key=None):
def _get_decorator_spec(decorator, args: tuple[Argument]):
d = {"name": decorator.func.id, "args": {}}

for i, arg in enumerate(args):
Expand All @@ -136,8 +101,8 @@ def _get_decorator_spec(decorator, args: tuple[Argument], key=None):
return d


def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs:
"""Return the pipeline metadata from the pipeline code.
def get_pipeline(pipeline_path: Path) -> Pipeline:
"""Return the pipeline with metadata and parameters from the pipeline code.

Args:
pipeline_path (Path): Path to the pipeline directory
Expand All @@ -150,7 +115,7 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs:

Returns
-------
typing.Tuple[PipelineSpecs, typing.List[PipelineParameterSpecs]]: A tuple containing the pipeline specs and the list of parameters specs.
Pipeline: The pipeline object with parameters and metadata.
"""
tree = ast.parse(open(Path(pipeline_path) / "pipeline.py").read())
pipeline = None
Expand All @@ -170,7 +135,7 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs:
Argument("timeout", [ast.Constant]),
),
)
pipeline = PipelineSpecs(**pipeline_decorator_spec["args"])
pipelines_parameters = []
for parameter_decorator in _get_decorators_by_name(node, "parameter"):
param_decorator_spec = _get_decorator_spec(
parameter_decorator,
Expand All @@ -180,19 +145,27 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs:
Argument("name", [ast.Constant]),
Argument("choices", [ast.List]),
Argument("help", [ast.Constant]),
Argument("default", [ast.Constant]),
Argument("default", [ast.Constant, ast.List]),
nazarfil marked this conversation as resolved.
Show resolved Hide resolved
Argument("required", [ast.Constant]),
Argument("multiple", [ast.Constant]),
),
)
try:
args = param_decorator_spec["args"]
inst = TYPES_BY_PYTHON_TYPE[args["type"]]()
args["type"] = inst.spec_type

pipeline.parameters.append(PipelineParameterSpecs(**args))
except KeyError:
raise InvalidParameterError(f"Invalid parameter type {args['type']}")
args = param_decorator_spec["args"]
inst = TYPES_BY_PYTHON_TYPE[args["type"]]()
args["type"] = inst.expected_type
parameter = Parameter(
code=args["code"],
name=args.get("name"),
nazarfil marked this conversation as resolved.
Show resolved Hide resolved
type=args.get("type"),
choices=args.get("choices"),
help=args.get("help"),
default=args.get("default"),
required=args.get("required") if args.get("required") is not None else True,
nazarfil marked this conversation as resolved.
Show resolved Hide resolved
multiple=args.get("multiple") if args.get("multiple") is not None else False,
nazarfil marked this conversation as resolved.
Show resolved Hide resolved
)
pipelines_parameters.append(parameter)

pipeline = Pipeline(parameters=pipelines_parameters, function=None, **pipeline_decorator_spec["args"])
nazarfil marked this conversation as resolved.
Show resolved Hide resolved

if pipeline is None:
raise PipelineNotFound("No function with openhexa.sdk pipeline decorator found.")
Expand Down
1 change: 1 addition & 0 deletions openhexa/utils/stringcase.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

Coming from https://github.com/okunishinishi/python-stringcase
"""

import re


Expand Down
Loading