Skip to content

Commit

Permalink
Feature if else task (#136)
Browse files Browse the repository at this point in the history
* fix: added if/else task with unit test cases

* fix: added examples and docs strings and renamed wf name

* fix: added docs for if/else

* fix: updated docs

* fix: added contributors names

* fix: updated code changeds for if task, updated examples and docs

* fix: updated comments

* fix: updated init file with enum and updated task file
  • Loading branch information
RajuGujjalapati authored Jun 17, 2024
1 parent 06d7a62 commit d579492
Show file tree
Hide file tree
Showing 14 changed files with 307 additions and 7 deletions.
1 change: 1 addition & 0 deletions CONTRIBUTORS.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ Thanks to the contributors who helped on this project apart from the authors
* [Dmitrii Grigorev](https://www.linkedin.com/in/dmitrii-grigorev-074739135/)
* [Chanukya Konuganti](https://www.linkedin.com/in/chanukyakonuganti/)
* [Maxim Mityutko](https://www.linkedin.com/in/mityutko/)
* [Raju Gujjalapati](https://in.linkedin.com/in/raju-gujjalapati-470a88171)

# Honorary Mentions
Thanks to the team below for invaluable insights and support throughout the initial release of this project
Expand Down
4 changes: 4 additions & 0 deletions brickflow/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,7 @@ def get_bundles_project_env() -> str:
TaskResponse,
BrickflowTriggerRule,
TaskRunCondition,
Operator,
BrickflowTaskEnvVars,
StorageBasedTaskLibrary,
JarTaskLibrary,
Expand All @@ -277,6 +278,7 @@ def get_bundles_project_env() -> str:
SparkJarTask,
RunJobTask,
SqlTask,
IfElseConditionTask,
)
from brickflow.engine.compute import Cluster, Runtimes
from brickflow.engine.project import Project
Expand Down Expand Up @@ -306,6 +308,7 @@ def get_bundles_project_env() -> str:
"TaskResponse",
"BrickflowTriggerRule",
"TaskRunCondition",
"Operator",
"BrickflowTaskEnvVars",
"StorageBasedTaskLibrary",
"JarTaskLibrary",
Expand All @@ -323,6 +326,7 @@ def get_bundles_project_env() -> str:
"DLTChannels",
"Cluster",
"Runtimes",
"IfElseConditionTask",
"Project",
"SqlTask",
"_ilog",
Expand Down
41 changes: 39 additions & 2 deletions brickflow/codegen/databricks_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
JobsTasksRunJobTask,
JobsTasksSparkJarTask,
JobsTasksSqlTask,
JobsTasksConditionTask,
Resources,
Workspace,
Bundle,
Expand Down Expand Up @@ -531,6 +532,28 @@ def _build_native_sql_file_task(
task_key=task_name,
)

def _build_native_condition_task(
self,
task_name: str,
task: Task,
task_settings: TaskSettings,
depends_on: List[JobsTasksDependsOn],
) -> JobsTasks:
try:
condition_task: JobsTasksConditionTask = task.task_func()
except Exception as e:
print(e)
raise ValueError(
f"Error while building If/else task {task_name}. "
f"Make sure {task_name} returns a JobsTasksConditionTask object."
) from e
return JobsTasks(
**task_settings.to_tf_dict(), # type: ignore
condition_task=condition_task,
depends_on=depends_on,
task_key=task_name,
)

def _build_dlt_task(
self,
task_name: str,
Expand Down Expand Up @@ -559,7 +582,14 @@ def workflow_obj_to_tasks(
for task_name, task in workflow.tasks.items():
# TODO: DLT
# pipeline_task: Pipeline = self._create_dlt_notebooks(stack, task)
depends_on = [JobsTasksDependsOn(task_key=f) for f in task.depends_on_names] # type: ignore
if task.depends_on_names:
depends_on = [
JobsTasksDependsOn(task_key=depends_key, outcome=expected_outcome)
for i in task.depends_on_names
for depends_key, expected_outcome in i.items()
] # type: ignore
else:
depends_on = []
libraries = TaskLibrary.unique_libraries(
task.libraries + (self.project.libraries or [])
)
Expand Down Expand Up @@ -600,12 +630,19 @@ def workflow_obj_to_tasks(
)
)
elif task.task_type == TaskType.SQL:
# native run job task
# native SQL task
tasks.append(
self._build_native_sql_file_task(
task_name, task, task_settings, depends_on
)
)
elif task.task_type == TaskType.IF_ELSE_CONDITION_TASK:
# native If/Else task
tasks.append(
self._build_native_condition_task(
task_name, task, task_settings, depends_on
)
)
else:
# brickflow entrypoint task
task_obj = JobsTasks(
Expand Down
62 changes: 59 additions & 3 deletions brickflow/engine/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
JobsTasksSqlTaskDashboard,
JobsTasksSqlTaskFile,
JobsTasksSqlTaskQuery,
JobsTasksConditionTask,
)
from brickflow.cli.projects import DEFAULT_BRICKFLOW_VERSION_MODE
from brickflow.context import (
Expand Down Expand Up @@ -117,6 +118,7 @@ class TaskType(Enum):
NOTEBOOK_TASK = "notebook_task"
SPARK_JAR_TASK = "spark_jar_task"
RUN_JOB_TASK = "run_job_task"
IF_ELSE_CONDITION_TASK = "condition_task"


class TaskRunCondition(Enum):
Expand All @@ -128,6 +130,15 @@ class TaskRunCondition(Enum):
ALL_FAILED = "ALL_FAILED"


class Operator(Enum):
EQUAL_TO = "=="
NOT_EQUAL = "!="
GREATER_THAN = ">"
LESS_THAN = "<"
GREATER_THAN_OR_EQUAL = ">="
LESS_THAN_OR_EQUAL = "<="


@dataclass(frozen=True)
class TaskLibrary:
@staticmethod
Expand Down Expand Up @@ -585,6 +596,45 @@ def __init__(self, *args: Any, **kwds: Any):
)


class IfElseConditionTask(JobsTasksConditionTask):
"""
The IfElseConditionTask class is designed to handle conditional tasks in a workflow.
An instance of IfElseConditionTask represents a conditional task that compares two values (left and right)
using a specified operator. The operator can be one of the following: "==", "!=", ">", "<", ">=", "<=".
The IfElseConditionTask class provides additional functionality for
mapping the provided operator to a specific operation. For example,
the operator "==" is mapped to "EQUAL_TO", and the operator "!=" is mapped to "NOT_EQUAL".
Attributes:
left (str): The left operand in the condition.
right (str): The right operand in the condition.
It can be one of the following: "==", "!=", ">", "<", ">=", "<=".
op (Operator): The operation corresponding to the operator.
It is determined based on the operator.
Examples:
Below are the different ways in which the IfElseConditionTask class
can be used inside a workflow (if_else_condition_task).:
1. IfElseConditionTask(left="value1", right="value2", op="==")
2. IfElseConditionTask(left="value1", right="value2", op="!=")
3. IfElseConditionTask(left="value1", right="value2", op=">")
4. IfElseConditionTask(left="value1", right="value2", op="<")
5. IfElseConditionTask(left="value1", right="value2", op=">=")
6. IfElseConditionTask(left="value1", right="value2", op="<=")
"""

left: str
right: str
op: str

def __init__(self, *args: Any, **kwds: Any):
super().__init__(*args, **kwds)
self.left = kwds["left"]
self.right = kwds["right"]
self.op = str(Operator(self.op).name)


class DefaultBrickflowTaskPluginImpl(BrickflowTaskPluginSpec):
@staticmethod
@brickflow_task_plugin_impl
Expand Down Expand Up @@ -712,6 +762,7 @@ class Task:
custom_execute_callback: Optional[Callable] = None
ensure_brickflow_plugins: bool = False
health: Optional[List[JobsTasksHealthRules]] = None
if_else_outcome: Optional[Dict[Union[str, str], str]] = None

def __post_init__(self) -> None:
self.is_valid_task_signature()
Expand All @@ -725,12 +776,17 @@ def parents(self) -> List[str]:
return list(self.workflow.parents(self.task_id))

@property
def depends_on_names(self) -> Iterator[str]:
def depends_on_names(self) -> Iterator[Dict[str, Optional[str]]]:
for i in self.depends_on:
if self.if_else_outcome:
outcome = list(self.if_else_outcome.values())[0]
else:
outcome = None

if callable(i) and hasattr(i, "__name__"):
yield i.__name__
yield {i.__name__: outcome}
else:
yield str(i)
yield {str(i): outcome}

@property
def databricks_task_type_str(self) -> str:
Expand Down
31 changes: 31 additions & 0 deletions brickflow/engine/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,6 +298,7 @@ def _add_task(
custom_execute_callback: Optional[Callable] = None,
task_settings: Optional[TaskSettings] = None,
ensure_brickflow_plugins: bool = False,
if_else_outcome: Optional[Dict[Union[str, str], str]] = None,
) -> None:
if self.task_exists(task_id):
raise TaskAlreadyExistsError(
Expand Down Expand Up @@ -339,6 +340,7 @@ def _add_task(
task_settings=task_settings,
custom_execute_callback=custom_execute_callback,
ensure_brickflow_plugins=ensure_plugins,
if_else_outcome=if_else_outcome,
)

# attempt to create task object before adding to graph
Expand All @@ -353,13 +355,15 @@ def dlt_task(
name: Optional[str] = None,
task_settings: Optional[TaskSettings] = None,
depends_on: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
if_else_outcome: Optional[Dict[Union[str, str], str]] = None,
) -> Callable:
return self.task(
task_func,
name,
task_type=TaskType.DLT,
task_settings=task_settings,
depends_on=depends_on,
if_else_outcome=if_else_outcome,
)

def notebook_task(
Expand All @@ -370,6 +374,7 @@ def notebook_task(
libraries: Optional[List[TaskLibrary]] = None,
task_settings: Optional[TaskSettings] = None,
depends_on: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
if_else_outcome: Optional[Dict[Union[str, str], str]] = None,
) -> Callable:
return self.task(
task_func,
Expand All @@ -379,6 +384,7 @@ def notebook_task(
task_type=TaskType.NOTEBOOK_TASK,
task_settings=task_settings,
depends_on=depends_on,
if_else_outcome=if_else_outcome,
)

def spark_jar_task(
Expand All @@ -389,6 +395,7 @@ def spark_jar_task(
libraries: Optional[List[TaskLibrary]] = None,
task_settings: Optional[TaskSettings] = None,
depends_on: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
if_else_outcome: Optional[Dict[Union[str, str], str]] = None,
) -> Callable:
return self.task(
task_func,
Expand All @@ -398,6 +405,7 @@ def spark_jar_task(
task_type=TaskType.SPARK_JAR_TASK,
task_settings=task_settings,
depends_on=depends_on,
if_else_outcome=if_else_outcome,
)

def run_job_task(
Expand All @@ -406,13 +414,15 @@ def run_job_task(
name: Optional[str] = None,
task_settings: Optional[TaskSettings] = None,
depends_on: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
if_else_outcome: Optional[Dict[Union[str, str], str]] = None,
) -> Callable:
return self.task(
task_func,
name,
task_type=TaskType.RUN_JOB_TASK,
task_settings=task_settings,
depends_on=depends_on,
if_else_outcome=if_else_outcome,
)

def sql_task(
Expand All @@ -421,13 +431,32 @@ def sql_task(
name: Optional[str] = None,
task_settings: Optional[TaskSettings] = None,
depends_on: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
if_else_outcome: Optional[Dict[Union[str, str], str]] = None,
) -> Callable:
return self.task(
task_func,
name,
task_type=TaskType.SQL,
task_settings=task_settings,
depends_on=depends_on,
if_else_outcome=if_else_outcome,
)

def if_else_condition_task(
self,
task_func: Optional[Callable] = None,
name: Optional[str] = None,
task_settings: Optional[TaskSettings] = None,
depends_on: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
if_else_outcome: Optional[Dict[Union[str, str], str]] = None,
) -> Callable:
return self.task(
task_func,
name,
task_type=TaskType.IF_ELSE_CONDITION_TASK,
task_settings=task_settings,
depends_on=depends_on,
if_else_outcome=if_else_outcome,
)

def task(
Expand All @@ -442,6 +471,7 @@ def task(
custom_execute_callback: Optional[Callable] = None,
task_settings: Optional[TaskSettings] = None,
ensure_brickflow_plugins: bool = False,
if_else_outcome: Optional[Dict[Union[str, str], str]] = None,
) -> Callable:
if len(self.tasks) >= self.max_tasks_in_workflow:
raise ValueError(
Expand All @@ -464,6 +494,7 @@ def task_wrapper(f: Callable) -> Callable:
custom_execute_callback=custom_execute_callback,
task_settings=task_settings,
ensure_brickflow_plugins=ensure_brickflow_plugins,
if_else_outcome=if_else_outcome,
)

@functools.wraps(f)
Expand Down
Loading

0 comments on commit d579492

Please sign in to comment.