Skip to content

Commit

Permalink
fix(create-cart-diagram): use functions instead of tasks
Browse files Browse the repository at this point in the history
  • Loading branch information
yannforget committed Nov 3, 2024
1 parent 6b68b31 commit dc5bec4
Showing 1 changed file with 18 additions and 12 deletions.
30 changes: 18 additions & 12 deletions create-cart-diagram/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from datetime import datetime
from pathlib import Path
from typing import Optional
from typing import Optional, Tuple

from openhexa.sdk import Dataset, current_run, parameter, pipeline, workspace
from pathways.typing.mermaid import cart_diagram
Expand Down Expand Up @@ -35,7 +35,9 @@ def create_cart_diagram(
):
"""Create a CART diagram from CART outputs."""

data = load_dataset(dataset=cart_outputs, version_name=version_name)
urban, rural, version = load_dataset(
dataset=cart_outputs, version_name=version_name
)

if output_dir:
output_dir = Path(workspace.files_path, output_dir)
Expand All @@ -46,20 +48,21 @@ def create_cart_diagram(
"data",
"output",
"cart_diagram",
data["version"],
version,
datetime.now().strftime("%Y-%m-%d_%H:%M:%S"),
)

generate_diagram(
urban_cart=data["urban"],
rural_cart=data["rural"],
urban_cart=urban,
rural_cart=rural,
output_dir=output_dir,
version_name=data["version"],
version_name=version,
)


@create_cart_diagram.task
def load_dataset(dataset: Dataset, version_name: str | None = None) -> dict:
def load_dataset(
dataset: Dataset, version_name: str | None = None
) -> Tuple[list[dict], list[dict], str]:
"""Load urban and rural JSON files from dataset.
Parameters
Expand All @@ -71,8 +74,12 @@ def load_dataset(dataset: Dataset, version_name: str | None = None) -> dict:
Return
------
dict
A dictionary containing the urban and rural JSON files (with `urban` and `rural` keys).
list[dict]
The urban JSON-like CART data
list[dict]
The rural JSON-like CART data
str
The name of the dataset version
"""
ds: Dataset = None

Expand Down Expand Up @@ -110,10 +117,9 @@ def load_dataset(dataset: Dataset, version_name: str | None = None) -> dict:
current_run.log_error(msg)
raise FileNotFoundError(msg)

return {"urban": urban, "rural": rural, "version": ds.name}
return urban, rural, ds.name


@create_cart_diagram.task
def generate_diagram(
urban_cart: list[dict], rural_cart: list[dict], output_dir: Path, version_name: str
):
Expand Down

0 comments on commit dc5bec4

Please sign in to comment.