[FEATURE] 163 - For each task #185

Merged

Changes from all commits (30 commits)
09bf43e - Support for foreach task (riccamini, Dec 2, 2024)
8e7c423 - tests (riccamini, Dec 3, 2024)
cb39e5e - refactor: DatabricksBundleCodegen to improve task handling and stream… (mikita-sakalouski, Dec 3, 2024)
8d99954 - fix: update task builder function to use nested task type for improve… (mikita-sakalouski, Dec 3, 2024)
aaf1578 - Merge pull request #1 from riccamini/feature/mikita_contribution (riccamini, Dec 3, 2024)
afbd444 - Builders in _get_task_builder can be resolved also from task class (riccamini, Dec 4, 2024)
953f79c - Fix tests and serialization issue with ForEachTask custom class (riccamini, Dec 5, 2024)
0285f59 - Support for brickflow task type in for each task (riccamini, Dec 6, 2024)
197e868 - Setting push_return_value in task response to False if task type is f… (riccamini, Dec 9, 2024)
b46bd32 - Refactoring for each task build function to allow for validation of t… (riccamini, Dec 9, 2024)
a688ce1 - Support for spark jar task type in for each task (riccamini, Dec 10, 2024)
ebb12b1 - Tests for run job task type in for each task (riccamini, Dec 12, 2024)
de85a5a - Tests for sql task type in for each task (riccamini, Dec 12, 2024)
4dba431 - feat: add JobsTasksForEachTaskConfigs for improved task configuration… (mikita-sakalouski, Dec 16, 2024)
2acaa85 - Merge remote-tracking branch 'origin/feature/163-for-each-task-suppor… (mikita-sakalouski, Dec 16, 2024)
a80d229 - feat: add JobsTasksForEachTaskConfigs for for-each task configuration… (mikita-sakalouski, Dec 16, 2024)
5a67000 - feat: implement model validation for ForEachTask configuration inputs… (mikita-sakalouski, Dec 16, 2024)
dbf98e0 - feat: simplify ForEachTask initialization by using task configuration… (mikita-sakalouski, Dec 16, 2024)
aa2420b - fix format (mikita-sakalouski, Dec 16, 2024)
d51ce57 - fix format (mikita-sakalouski, Dec 16, 2024)
4421ee9 - Removed TODO, nested task name is not exposed (riccamini, Dec 16, 2024)
a1b8ff9 - Documentation of for each task type (riccamini, Dec 16, 2024)
7deec96 - For each task examples (riccamini, Dec 16, 2024)
abbf90b - Merge pull request #2 from riccamini/feature/for_each_model_settings (riccamini, Dec 16, 2024)
4ff9c13 - Moved for each task config validation up (riccamini, Dec 16, 2024)
b9131f2 - Updates to doc and examples after introduction of for each task confi… (riccamini, Dec 16, 2024)
b2b915e - chore: update Python version to 3.9 and refine configuration files (mikita-sakalouski, Dec 17, 2024)
c5b495c - chore: update documentation dependencies and versions in pyproject.toml (mikita-sakalouski, Dec 17, 2024)
d5c25b9 - Fixed formatting of bullet points in doc (riccamini, Dec 18, 2024)
23c86a1 - Update brickflow project conf in for each examples (riccamini, Dec 18, 2024)
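
Taken together, these commits let a Brickflow task loop a nested task over a list of inputs using Databricks' native for-each task. A rough sketch of the user-facing surface implied by the diffs below; the decorator signature here is hypothetical and should be checked against the merged docs and examples:

from brickflow import Workflow
from brickflow.engine.task import JobsTasksForEachTaskConfigs, TaskType

wf = Workflow("for_each_demo")

# Hypothetical decorator usage: run the decorated task once per input element,
# with at most two iterations executing in parallel.
@wf.task(
    task_type=TaskType.FOR_EACH_TASK,
    for_each_task_conf=JobsTasksForEachTaskConfigs(
        inputs=["a", "b", "c"],  # non-strings are JSON-encoded by a validator
        concurrency=2,
    ),
)
def process_element():
    ...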
2 changes: 1 addition & 1 deletion .github/workflows/onpush.yml
@@ -17,7 +17,7 @@ jobs:
     strategy:
       max-parallel: 2
       matrix:
-        python-version: [ '3.8' ]
+        python-version: [ '3.9' ]
         os: [ ubuntu-latest ]
 
     steps:
7 changes: 5 additions & 2 deletions brickflow/bundles/model.py
@@ -6,7 +6,7 @@
 
 from typing import Any, Dict, List, Optional, Union
 
-from pydantic import BaseModel, Extra, Field, constr
+from pydantic import BaseModel, Field, constr, InstanceOf
 from typing_extensions import Literal
 
 
@@ -1174,7 +1174,10 @@ class Config:
 
 
 class JobsTasksForEachTask(BaseModel):
-    pass
+    inputs: str
+    concurrency: int
+    task: JobsTasks
 
 
 class JobsTasksHealthRules(BaseModel):
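
With the regenerated model above, JobsTasksForEachTask carries real fields instead of a bare pass. A minimal sketch of constructing it directly, assuming only task_key is required on the nested JobsTasks payload:

from brickflow.bundles.model import JobsTasks, JobsTasksForEachTask

for_each = JobsTasksForEachTask(
    inputs='["us-east-1", "us-west-2"]',  # JSON-encoded array to iterate over
    concurrency=2,  # maximum iterations running in parallel
    task=JobsTasks(task_key="deploy_region_nested"),  # the nested task definition
)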
191 changes: 148 additions & 43 deletions brickflow/codegen/databricks_bundle.py
@@ -60,6 +60,13 @@
 )
 from brickflow.engine.task import (
     DLTPipeline,
+    ForEachTask,
+    IfElseConditionTask,
+    NotebookTask,
+    RunJobTask,
+    SparkJarTask,
+    SparkPythonTask,
+    SqlTask,
     TaskLibrary,
     TaskSettings,
     filter_bf_related_libraries,
@@ -498,12 +505,11 @@ def adjust_file_path(self, file_path: str) -> str:
         return file_path
 
     def task_to_task_obj(self, task: Task) -> JobsTasksNotebookTask:
-        if task.task_type in [TaskType.BRICKFLOW_TASK, TaskType.CUSTOM_PYTHON_TASK]:
-            generated_path = handle_mono_repo_path(self.project, self.env)
-            return JobsTasksNotebookTask(
-                **task.get_obj_dict(generated_path),
-                source=self.adjust_source(),
-            )
+        generated_path = handle_mono_repo_path(self.project, self.env)
+        return JobsTasksNotebookTask(
+            **task.get_obj_dict(generated_path),
+            source=self.adjust_source(),
+        )
 
     def workflow_obj_to_pipelines(self, workflow: Workflow) -> Dict[str, Pipelines]:
         pipelines_dict = {}
@@ -760,6 +766,70 @@ def _build_dlt_task(
             task_key=task_name,
         )
 
+    def _build_native_for_each_task(
+        self,
+        task_name: str,
+        task: Task,
+        task_libraries: List[JobsTasksLibraries],
+        task_settings: TaskSettings,
+        depends_on: List[JobsTasksDependsOn],
+        **kwargs: Any,
+    ) -> JobsTasks:
+        supported_task_types = (
+            TaskType.NOTEBOOK_TASK,
+            TaskType.SPARK_JAR_TASK,
+            TaskType.SPARK_PYTHON_TASK,
+            TaskType.RUN_JOB_TASK,
+            TaskType.SQL,
+            TaskType.BRICKFLOW_TASK,  # Accounts for brickflow entrypoint tasks
+        )
+
+        if task.for_each_task_conf is None:
+            raise ValueError(
+                f"Error while building for each task {task_name}. "
+                f"Make sure {task_name} has a for_each_task_conf attribute."
+            )
+
+        nested_task = task.task_func()
+        task_type = self._get_task_type(nested_task)
+
+        try:
+            assert task_type in supported_task_types
+        except AssertionError as e:
+            raise ValueError(
+                f"Error while building for each task {task_name}. Make sure the nested task type is one of "
+                f"{', '.join(t.name for t in supported_task_types)}."
+            ) from e
+
+        builder_func = self._get_task_builder(task_type=task_type)
+
+        workflow: Optional[Workflow] = kwargs.get("workflow")
+        # Currently the inner task name is not exposed; a parameter would have to be added to the
+        # for_each_task decorator to let users configure it.
+        nested_task_jt = builder_func(
+            task_name=f"{task_name}_nested",
+            task=task,
+            workflow=workflow,
+            task_libraries=task_libraries,
+            task_settings=task_settings,
+            depends_on=[],
+        )
+
+        for_each_task = ForEachTask(
+            configs=task.for_each_task_conf,
+            task=nested_task_jt,
+        )
+
+        # No cluster or libraries are specified, as a for_each_task cannot have them!
+        jt = JobsTasks(
+            **task_settings.to_tf_dict(),
+            for_each_task=for_each_task,
+            depends_on=depends_on,
+            task_key=task_name,
+        )
+
+        return jt
+
     def _build_brickflow_entrypoint_task(
         self,
         task_name: str,
@@ -771,7 +841,7 @@ def _build_brickflow_entrypoint_task(
     ) -> JobsTasks:
         task_obj = JobsTasks(
             **{
-                task.databricks_task_type_str: self.task_to_task_obj(task),
+                TaskType.NOTEBOOK_TASK.value: self.task_to_task_obj(task),
                 **task_settings.to_tf_dict(),
             },  # type: ignore
             depends_on=depends_on,
@@ -791,56 +861,91 @@ def _build_brickflow_entrypoint_task(
         )
         return task_obj
 
-    def workflow_obj_to_tasks(
-        self, workflow: Workflow
-    ) -> List[Union[JobsTasks, Pipelines]]:
-        tasks = []
+    def _get_task_type(self, task: Any) -> TaskType:
+        """Resolves the task type given the task object"""
+
+        map_task_class_to_task_type: Dict[typing.Type, TaskType] = {
+            DLTPipeline: TaskType.DLT,
+            NotebookTask: TaskType.NOTEBOOK_TASK,
+            SparkJarTask: TaskType.SPARK_JAR_TASK,
+            SparkPythonTask: TaskType.SPARK_PYTHON_TASK,
+            RunJobTask: TaskType.RUN_JOB_TASK,
+            SqlTask: TaskType.SQL,
+            IfElseConditionTask: TaskType.IF_ELSE_CONDITION_TASK,
+            ForEachTask: TaskType.FOR_EACH_TASK,
+        }
+
+        # Brickflow tasks do not have a dedicated task class, so everything else is matched to it
+        return map_task_class_to_task_type.get(type(task), TaskType.BRICKFLOW_TASK)
+
+    def _get_task_builder(self, task_type: Optional[TaskType] = None) -> Callable[..., Any]:
         map_task_type_to_builder: Dict[TaskType, Callable[..., Any]] = {
             TaskType.BRICKFLOW_TASK: self._build_brickflow_entrypoint_task,
             TaskType.DLT: self._build_dlt_task,
             TaskType.NOTEBOOK_TASK: self._build_native_notebook_task,
             TaskType.SPARK_JAR_TASK: self._build_native_spark_jar_task,
             TaskType.SPARK_PYTHON_TASK: self._build_native_spark_python_task,
             TaskType.RUN_JOB_TASK: self._build_native_run_job_task,
             TaskType.SQL: self._build_native_sql_file_task,
             TaskType.IF_ELSE_CONDITION_TASK: self._build_native_condition_task,
+            TaskType.FOR_EACH_TASK: self._build_native_for_each_task,
             TaskType.CUSTOM_PYTHON_TASK: self._build_brickflow_entrypoint_task,
         }
 
-        for task_name, task in workflow.tasks.items():
-            builder_func = map_task_type_to_builder.get(
-                task.task_type, self._build_brickflow_entrypoint_task
-            )
-
-            # TODO: DLT
-            # pipeline_task: Pipeline = self._create_dlt_notebooks(stack, task)
-            if task.depends_on_names:
-                depends_on = [
-                    JobsTasksDependsOn(task_key=depends_key, outcome=expected_outcome)
-                    for i in task.depends_on_names
-                    for depends_key, expected_outcome in i.items()
-                ]  # type: ignore
-            else:
-                depends_on = []
-
-            libraries = TaskLibrary.unique_libraries(
-                task.libraries + (self.project.libraries or [])
-            )
-            if workflow.enable_plugins is True:
-                libraries = filter_bf_related_libraries(libraries)
-                libraries += get_brickflow_libraries(workflow.enable_plugins)
-
-            task_libraries = [
-                JobsTasksLibraries(**library.dict) for library in libraries
-            ]  # type: ignore
-            task_settings = workflow.default_task_settings.merge(task.task_settings)  # type: ignore
-            task = builder_func(
-                task_name=task_name,
-                task=task,
-                workflow=workflow,
-                task_libraries=task_libraries,
-                task_settings=task_settings,
-                depends_on=depends_on,
-            )
-            tasks.append(task)
+        builder = map_task_type_to_builder.get(task_type, None)
+        if builder is None:
+            raise ValueError("No builder found for the given task or task class")
+        return builder
+
+    def _build_task(
+        self, build_func: Callable, workflow: Workflow, task_name: str, task: Task
+    ) -> Union[JobsTasks, Pipelines]:
+        # TODO: DLT
+        # pipeline_task: Pipeline = self._create_dlt_notebooks(stack, task)
+        if task.depends_on_names:
+            depends_on = [
+                JobsTasksDependsOn(task_key=depends_key, outcome=expected_outcome)
+                for i in task.depends_on_names
+                for depends_key, expected_outcome in i.items()
+            ]  # type: ignore
+        else:
+            depends_on = []
+
+        libraries = TaskLibrary.unique_libraries(
+            task.libraries + (self.project.libraries or [])
+        )
+        if workflow.enable_plugins is True:
+            libraries = filter_bf_related_libraries(libraries)
+            libraries += get_brickflow_libraries(workflow.enable_plugins)
+
+        task_libraries = [JobsTasksLibraries(**library.dict) for library in libraries]  # type: ignore
+        task_settings = workflow.default_task_settings.merge(task.task_settings)  # type: ignore
+        task = build_func(
+            task_name=task_name,
+            task=task,
+            workflow=workflow,
+            task_libraries=task_libraries,
+            task_settings=task_settings,
+            depends_on=depends_on,
+        )
+
+        return task
+
+    def workflow_obj_to_tasks(
+        self, workflow: Workflow
+    ) -> List[Union[JobsTasks, Pipelines]]:
+        tasks = []
+
+        for task_name, task in workflow.tasks.items():
+            build_func = self._get_task_builder(task_type=task.task_type)
+            tasks.append(
+                self._build_task(
+                    build_func=build_func,
+                    workflow=workflow,
+                    task_name=task_name,
+                    task=task,
+                )
+            )
 
         tasks.sort(key=lambda t: (t.task_key is None, t.task_key))

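The refactor above splits the old monolithic loop into three steps: resolve a TaskType from the task object (_get_task_type), look up its builder (_get_task_builder), and invoke it (_build_task). A self-contained sketch of that lookup-and-build pattern, with illustrative names rather than brickflow's actual classes:

from enum import Enum
from typing import Any, Callable, Dict


class TaskKind(Enum):
    NOTEBOOK = "notebook_task"
    FOR_EACH = "for_each_task"


def build_notebook(name: str) -> Dict[str, Any]:
    return {"task_key": name, "notebook_task": {}}


def build_for_each(name: str) -> Dict[str, Any]:
    # Mirrors _build_native_for_each_task: a for-each task wraps a nested task.
    return {"task_key": name, "for_each_task": {"task": build_notebook(f"{name}_nested")}}


BUILDERS: Dict[TaskKind, Callable[[str], Dict[str, Any]]] = {
    TaskKind.NOTEBOOK: build_notebook,
    TaskKind.FOR_EACH: build_for_each,
}


def get_builder(kind: TaskKind) -> Callable[[str], Dict[str, Any]]:
    builder = BUILDERS.get(kind)
    if builder is None:
        raise ValueError("No builder found for the given task kind")
    return builder


print(get_builder(TaskKind.FOR_EACH)("loop_task"))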
45 changes: 44 additions & 1 deletion brickflow/engine/task.py
@@ -28,6 +28,7 @@
 
 import pluggy
 from decouple import config
+from pydantic import BaseModel, Field, field_validator, model_validator
 
 from brickflow import (
     BrickflowDefaultEnvs,
@@ -38,6 +39,7 @@
 )
 from brickflow.bundles.model import (
     JobsTasksConditionTask,
+    JobsTasksForEachTask,
     JobsTasksHealthRules,
     JobsTasksNotebookTask,
     JobsTasksNotificationSettings,
@@ -123,6 +125,7 @@ class TaskType(Enum):
     SPARK_PYTHON_TASK = "spark_python_task"
     RUN_JOB_TASK = "run_job_task"
     IF_ELSE_CONDITION_TASK = "condition_task"
+    FOR_EACH_TASK = "for_each_task"
 
 
 class TaskRunCondition(Enum):
@@ -493,6 +496,44 @@ def __init__(self, **kwargs: Any) -> None:
         self.python_file = kwargs.get("python_file", None)
 
 
+class JobsTasksForEachTaskConfigs(BaseModel):
+    inputs: str = Field(..., description="The input data for the task.")
+    concurrency: int = Field(
+        default=1, description="Number of iterations that can run in parallel."
+    )
+
+    @field_validator("inputs", mode="before")
+    @classmethod
+    def validate_inputs(cls, inputs: Any) -> str:
+        if not isinstance(inputs, str):
+            inputs = json.dumps(inputs)
+        return inputs
+
+
+class ForEachTask(JobsTasksForEachTask):
+    """
+    The ForEachTask class provides iteration of a task over a list of inputs. The looped task can be executed
+    concurrently based on the concurrency value provided.
+
+    Attributes:
+        inputs (str): Array for the task to iterate on. This can be a JSON string or a reference to an array
+            parameter.
+        concurrency (int): An optional maximum allowed number of concurrent runs of the task. Set this value if
+            you want to be able to execute multiple runs of the task concurrently.
+        task (Any): The task that will be run for each element in the array.
+    """
+
+    configs: JobsTasksForEachTaskConfigs
+    task: Any
+
+    @model_validator(mode="before")
+    def validate_configs(self) -> "ForEachTask":
+        # In "before" mode the validator receives the raw input dict, so the
+        # configs values are copied into the fields the bundle model expects.
+        self["inputs"] = self["configs"].inputs  # type: ignore
+        self["concurrency"] = self["configs"].concurrency  # type: ignore
+
+        return self
+
+
 class RunJobTask(JobsTasksRunJobTask):
     """
     The RunJobTask class is designed to handle the execution of a specific job in a Databricks workspace.
@@ -704,10 +745,11 @@ def task_execute(task: "Task", workflow: "Workflow") -> TaskResponse:
     else:
         kwargs = task.get_runtime_parameter_values()
     try:
+        # Task return value cannot be pushed if we are in a for each task (not allowed by Databricks)
        return TaskResponse(
            task.task_func(**kwargs),
            user_code_error=None,
-            push_return_value=True,
+            push_return_value=task.task_type != TaskType.FOR_EACH_TASK,
            input_kwargs=kwargs,
        )
    except Exception as e:
@@ -807,6 +849,7 @@ class Task:
     ensure_brickflow_plugins: bool = False
     health: Optional[List[JobsTasksHealthRules]] = None
     if_else_outcome: Optional[Dict[Union[str, str], str]] = None
+    for_each_task_conf: Optional[JobsTasksForEachTaskConfigs] = None
 
     def __post_init__(self) -> None:
         self.is_valid_task_signature()
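
A quick check of the new configuration model in this file: the inputs field validator JSON-encodes non-string values, while strings (for example a task-values reference) pass through unchanged. Assumes a brickflow build that includes this PR:

import json

from brickflow.engine.task import JobsTasksForEachTaskConfigs

# A list is serialized to a JSON string by the "before" validator.
conf = JobsTasksForEachTaskConfigs(inputs=[1, 2, 3], concurrency=2)
assert conf.inputs == json.dumps([1, 2, 3])  # '[1, 2, 3]'
assert conf.concurrency == 2

# A plain string passes through unchanged.
ref = JobsTasksForEachTaskConfigs(inputs="{{tasks.gen_inputs.values.regions}}")
assert ref.inputs == "{{tasks.gen_inputs.values.regions}}"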