Skip to content

Commit

Permalink
fix: adds requried
Browse files Browse the repository at this point in the history
  • Loading branch information
nazarfil committed Sep 29, 2024
1 parent 1c8e6a2 commit e5e356c
Show file tree
Hide file tree
Showing 6 changed files with 83 additions and 58 deletions.
6 changes: 3 additions & 3 deletions openhexa/cli/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

from openhexa.cli.settings import settings
from openhexa.sdk.pipelines import get_local_workspace_config
from openhexa.sdk.pipelines.runtime import get_pipeline_metadata
from openhexa.sdk.pipelines.runtime import get_pipeline
from openhexa.utils import create_requests_session, stringcase


Expand Down Expand Up @@ -195,7 +195,7 @@ def list_pipelines():
return data["pipelines"]["items"]


def get_pipeline(pipeline_code: str) -> dict[str, typing.Any]:
def get_pipeline_from_code(pipeline_code: str) -> dict[str, typing.Any]:
"""Get a single pipeline."""
if settings.current_workspace is None:
raise NoActiveWorkspaceError
Expand Down Expand Up @@ -543,7 +543,7 @@ def upload_pipeline(
raise NoActiveWorkspaceError

directory = pipeline_directory_path.absolute()
pipeline = get_pipeline_metadata(directory)
pipeline = get_pipeline(directory)
zip_file = generate_zip_file(directory)

if settings.debug:
Expand Down
10 changes: 5 additions & 5 deletions openhexa/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,15 @@
download_pipeline_sourcecode,
ensure_is_pipeline_dir,
get_library_versions,
get_pipeline,
get_pipeline_from_code,
get_workspace,
list_pipelines,
run_pipeline,
upload_pipeline,
)
from openhexa.cli.settings import settings, setup_logging
from openhexa.sdk.pipelines.exceptions import PipelineNotFound
from openhexa.sdk.pipelines.runtime import get_pipeline_metadata
from openhexa.sdk.pipelines.runtime import get_pipeline


def validate_url(ctx, param, value):
Expand Down Expand Up @@ -283,7 +283,7 @@ def pipelines_push(
ensure_is_pipeline_dir(path)

try:
pipeline = get_pipeline_metadata(path)
pipeline = get_pipeline(path)
except PipelineNotFound:
_terminate(
f"❌ No function with openhexa.sdk pipeline decorator found in {click.style(path, bold=True)}.",
Expand All @@ -296,7 +296,7 @@ def pipelines_push(
if settings.debug:
click.echo(workspace_pipelines)

if get_pipeline(pipeline.code) is None:
if get_pipeline_from_code(pipeline.code) is None:
click.echo(
f"Pipeline {click.style(pipeline.code, bold=True)} does not exist in workspace {click.style(workspace, bold=True)}"
)
Expand Down Expand Up @@ -374,7 +374,7 @@ def pipelines_delete(code: str):
err=True,
)
else:
pipeline = get_pipeline(code)
pipeline = get_pipeline_from_code(code)
if pipeline is None:
_terminate(
f"❌ Pipeline {click.style(code, bold=True)} does not exist in workspace {click.style(settings.current_workspace, bold=True)}"
Expand Down
15 changes: 15 additions & 0 deletions openhexa/sdk/pipelines/parameter.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,6 +429,21 @@ def validate(self, value: typing.Any) -> typing.Any:
else:
return self._validate_single(value)

def to_dict(self) -> dict[str, typing.Any]:
"""Return a dictionary representation of the Parameter instance."""
print(f"Multiple : {self.required}, required: {self.required}")
return {
"code": self.code,
"type": self.type.spec_type,
"name": self.name,
"choices": self.choices,
"help": self.help,
"default": self.default,
"required": self.required,
"multiple": self.multiple,
}


def _validate_single(self, value: typing.Any):
# Normalize empty values to None and handles default
normalized_value = self.type.normalize(value)
Expand Down
8 changes: 8 additions & 0 deletions openhexa/sdk/pipelines/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,14 @@ def parameters_spec(self) -> list[dict[str, typing.Any]]:
"""Return the individual specifications of all the parameters of this pipeline."""
return [arg.parameter_spec() for arg in self.parameters]

def to_dict(self):
return {
"code": self.code,
"name": self.name,
"parameters": [p.to_dict() for p in self.parameters],
"timeout": self.timeout,
}

def _get_available_tasks(self) -> list[Task]:
return [task for task in self.tasks if task.is_ready()]

Expand Down
45 changes: 21 additions & 24 deletions openhexa/sdk/pipelines/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import requests

from openhexa.sdk.pipelines.exceptions import InvalidParameterError, PipelineNotFound
from openhexa.sdk.pipelines.parameter import TYPES_BY_PYTHON_TYPE
from openhexa.sdk.pipelines.parameter import TYPES_BY_PYTHON_TYPE, Parameter
from openhexa.sdk.pipelines.utils import validate_pipeline_parameter_code

from .pipeline import Pipeline
Expand Down Expand Up @@ -57,16 +57,6 @@ class Argument:
types: list[typing.Any] = field(default_factory=list)


@dataclass
class PipelineSpecs:
"""Specification of a pipeline."""

code: string
name: string
timeout: int = None
parameters: list[PipelineParameterSpecs] = field(default_factory=list)


def import_pipeline(pipeline_dir_path: str):
"""Import pipeline code within provided path using importlib."""
pipeline_dir = os.path.abspath(pipeline_dir_path)
Expand Down Expand Up @@ -131,7 +121,7 @@ def _get_decorator_arg_value(decorator, arg: Argument, index: int):
return None


def _get_decorator_spec(decorator, args: tuple[Argument], key=None):
def _get_decorator_spec(decorator, args: tuple[Argument]):
d = {"name": decorator.func.id, "args": {}}

for i, arg in enumerate(args):
Expand All @@ -140,8 +130,8 @@ def _get_decorator_spec(decorator, args: tuple[Argument], key=None):
return d


def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs:
"""Return the pipeline metadata from the pipeline code.
def get_pipeline(pipeline_path: Path) -> Pipeline:
"""Return the pipeline with metadata and parameters from the pipeline code.
Args:
pipeline_path (Path): Path to the pipeline directory
Expand All @@ -154,7 +144,7 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs:
Returns
-------
typing.Tuple[PipelineSpecs, typing.List[PipelineParameterSpecs]]: A tuple containing the pipeline specs and the list of parameters specs.
Pipeline: The pipeline object with parameters and metadata.
"""
tree = ast.parse(open(Path(pipeline_path) / "pipeline.py").read())
pipeline = None
Expand All @@ -174,7 +164,7 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs:
Argument("timeout", [ast.Constant]),
),
)
pipeline = PipelineSpecs(**pipeline_decorator_spec["args"])
pipelines_parameters = []
for parameter_decorator in _get_decorators_by_name(node, "parameter"):
param_decorator_spec = _get_decorator_spec(
parameter_decorator,
Expand All @@ -189,14 +179,21 @@ def get_pipeline_metadata(pipeline_path: Path) -> PipelineSpecs:
Argument("multiple", [ast.Constant]),
),
)
try:
args = param_decorator_spec["args"]
inst = TYPES_BY_PYTHON_TYPE[args["type"]]()
args["type"] = inst.spec_type

pipeline.parameters.append(PipelineParameterSpecs(**args))
except KeyError:
raise InvalidParameterError(f"Invalid parameter type {args['type']}")
args = param_decorator_spec["args"]
inst = TYPES_BY_PYTHON_TYPE[args["type"]]()
args["type"] = inst.expected_type
parameter = Parameter(
code=args.get("code"),
name=args.get("name"),
type=args.get("type"),
choices=args.get("choices"),
help=args.get("help"),
default=args.get("default"),
required=args.get("required") if args.get("required") is not None else True,
multiple=args.get("multiple") if args.get("multiple") is not None else False,)
pipelines_parameters.append(parameter)

pipeline = Pipeline(parameters=pipelines_parameters, function=None, **pipeline_decorator_spec["args"])

if pipeline is None:
raise PipelineNotFound("No function with openhexa.sdk pipeline decorator found.")
Expand Down
57 changes: 31 additions & 26 deletions tests/test_ast.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
"""Tests related to the parsing of the pipeline code."""

import json
import tempfile
from dataclasses import asdict
from unittest import TestCase

from openhexa.sdk.pipelines.exceptions import InvalidParameterError, PipelineNotFound
from openhexa.sdk.pipelines.runtime import get_pipeline_metadata
from openhexa.sdk.pipelines.runtime import get_pipeline


class AstTest(TestCase):
Expand All @@ -18,7 +19,7 @@ def test_pipeline_not_found(self):
f.write("print('hello')")

with self.assertRaises(PipelineNotFound):
get_pipeline_metadata(tmpdirname)
get_pipeline(tmpdirname)

def test_pipeline_no_parameters(self):
"""The file contains a @pipeline decorator but no parameters."""
Expand All @@ -35,9 +36,9 @@ def test_pipeline_no_parameters(self):
]
)
)
pipeline = get_pipeline_metadata(tmpdirname)
pipeline = get_pipeline(tmpdirname)
self.assertEqual(
asdict(pipeline), {"code": "test", "name": "Test pipeline", "parameters": [], "timeout": None}
pipeline.to_dict(), {"code": "test", "name": "Test pipeline", "parameters": [], "timeout": None}
)

def test_pipeline_with_args(self):
Expand All @@ -55,9 +56,9 @@ def test_pipeline_with_args(self):
]
)
)
pipeline = get_pipeline_metadata(tmpdirname)
pipeline = get_pipeline(tmpdirname)
self.assertEqual(
asdict(pipeline), {"code": "test", "name": "Test pipeline", "parameters": [], "timeout": None}
pipeline.to_dict(), {"code": "test", "name": "Test pipeline", "parameters": [], "timeout": None}
)

def test_pipeline_with_invalid_parameter_args(self):
Expand All @@ -77,7 +78,7 @@ def test_pipeline_with_invalid_parameter_args(self):
)
)
with self.assertRaises(ValueError):
get_pipeline_metadata(tmpdirname)
get_pipeline(tmpdirname)

def test_pipeline_with_invalid_pipeline_args(self):
"""The file contains a @pipeline decorator with invalid value."""
Expand All @@ -97,7 +98,7 @@ def test_pipeline_with_invalid_pipeline_args(self):
)
)
with self.assertRaises(ValueError):
get_pipeline_metadata(tmpdirname)
get_pipeline(tmpdirname)

def test_pipeline_with_int_param(self):
"""The file contains a @pipeline decorator and a @parameter decorator with an int."""
Expand All @@ -116,9 +117,9 @@ def test_pipeline_with_int_param(self):
]
)
)
pipeline = get_pipeline_metadata(tmpdirname)
pipeline = get_pipeline(tmpdirname)
self.assertEqual(
asdict(pipeline),
pipeline.to_dict(),
{
"code": "test",
"name": "Test pipeline",
Expand Down Expand Up @@ -147,17 +148,17 @@ def test_pipeline_with_multiple_param(self):
[
"from openhexa.sdk.pipelines import pipeline, parameter",
"",
"@parameter('test_param', name='Test Param', type=int, default=42, help='Param help', multiple=True)",
"@parameter('test_param', name='Test Param', type=int, default=[42], help='Param help', multiple=True)",
"@pipeline('test', 'Test pipeline')",
"def test_pipeline():",
" pass",
"",
]
)
)
pipeline = get_pipeline_metadata(tmpdirname)
pipeline = get_pipeline(tmpdirname)
self.assertEqual(
asdict(pipeline),
pipeline.to_dict(),
{
"code": "test",
"name": "Test pipeline",
Expand All @@ -168,7 +169,7 @@ def test_pipeline_with_multiple_param(self):
"code": "test_param",
"type": "int",
"name": "Test Param",
"default": 42,
"default": [42],
"help": "Param help",
"required": True,
}
Expand All @@ -195,12 +196,14 @@ def test_pipeline_with_dataset(self):
]
)
)
pipeline = get_pipeline_metadata(tmpdirname)
pipeline = get_pipeline(tmpdirname)
self.assertEqual(
asdict(pipeline),
pipeline.to_dict(),
{
"code": "test",
"function": None,
"name": "Test pipeline",
"tasks": [],
"parameters": [
{
"choices": None,
Expand Down Expand Up @@ -234,9 +237,9 @@ def test_pipeline_with_choices(self):
]
)
)
pipeline = get_pipeline_metadata(tmpdirname)
pipeline = get_pipeline(tmpdirname)
self.assertEqual(
asdict(pipeline),
pipeline.to_dict(),
{
"code": "test",
"name": "Test pipeline",
Expand Down Expand Up @@ -271,9 +274,11 @@ def test_pipeline_with_timeout(self):
]
)
)
pipeline = get_pipeline_metadata(tmpdirname)
pipeline = get_pipeline(tmpdirname)
self.assertEqual(
asdict(pipeline), {"code": "test", "name": "Test pipeline", "parameters": [], "timeout": 42}
pipeline.to_dict(),
{"code": "test", "function": None, "name": "Test pipeline", "parameters": [], "timeout": 42,
"tasks": []}
)

def test_pipeline_with_bool(self):
Expand All @@ -293,9 +298,9 @@ def test_pipeline_with_bool(self):
]
)
)
pipeline = get_pipeline_metadata(tmpdirname)
pipeline = get_pipeline(tmpdirname)
self.assertEqual(
asdict(pipeline),
pipeline.to_dict(),
{
"code": "test",
"name": "Test pipeline",
Expand Down Expand Up @@ -333,9 +338,9 @@ def test_pipeline_with_multiple_parameters(self):
]
)
)
pipeline = get_pipeline_metadata(tmpdirname)
pipeline = get_pipeline(tmpdirname)
self.assertEqual(
asdict(pipeline),
pipeline.to_dict(),
{
"code": "test",
"name": "Test pipeline",
Expand Down Expand Up @@ -382,5 +387,5 @@ def test_pipeline_with_unsupported_parameter(self):
]
)
)
with self.assertRaises(InvalidParameterError):
get_pipeline_metadata(tmpdirname)
with self.assertRaises(KeyError):
get_pipeline(tmpdirname)

0 comments on commit e5e356c

Please sign in to comment.