diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index 9045b461..e34f4f36 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -13,6 +13,7 @@ Thanks to the contributors who helped on this project apart from the authors * [Dmitrii Grigorev](https://www.linkedin.com/in/dmitrii-grigorev-074739135/) * [Chanukya Konuganti](https://www.linkedin.com/in/chanukyakonuganti/) * [Maxim Mityutko](https://www.linkedin.com/in/mityutko/) +* [Raju Gujjalapati](https://in.linkedin.com/in/raju-gujjalapati-470a88171) # Honorary Mentions Thanks to the team below for invaluable insights and support throughout the initial release of this project diff --git a/brickflow/__init__.py b/brickflow/__init__.py index cd08b962..c9c7e024 100644 --- a/brickflow/__init__.py +++ b/brickflow/__init__.py @@ -260,6 +260,7 @@ def get_bundles_project_env() -> str: TaskResponse, BrickflowTriggerRule, TaskRunCondition, + Operator, BrickflowTaskEnvVars, StorageBasedTaskLibrary, JarTaskLibrary, @@ -277,6 +278,7 @@ def get_bundles_project_env() -> str: SparkJarTask, RunJobTask, SqlTask, + IfElseConditionTask, ) from brickflow.engine.compute import Cluster, Runtimes from brickflow.engine.project import Project @@ -306,6 +308,7 @@ def get_bundles_project_env() -> str: "TaskResponse", "BrickflowTriggerRule", "TaskRunCondition", + "Operator", "BrickflowTaskEnvVars", "StorageBasedTaskLibrary", "JarTaskLibrary", @@ -323,6 +326,7 @@ def get_bundles_project_env() -> str: "DLTChannels", "Cluster", "Runtimes", + "IfElseConditionTask", "Project", "SqlTask", "_ilog", diff --git a/brickflow/codegen/databricks_bundle.py b/brickflow/codegen/databricks_bundle.py index 5608ce06..85cdc879 100644 --- a/brickflow/codegen/databricks_bundle.py +++ b/brickflow/codegen/databricks_bundle.py @@ -40,6 +40,7 @@ JobsTasksRunJobTask, JobsTasksSparkJarTask, JobsTasksSqlTask, + JobsTasksConditionTask, Resources, Workspace, Bundle, @@ -531,6 +532,28 @@ def _build_native_sql_file_task( task_key=task_name, ) + def _build_native_condition_task( + 
self, + task_name: str, + task: Task, + task_settings: TaskSettings, + depends_on: List[JobsTasksDependsOn], + ) -> JobsTasks: + try: + condition_task: JobsTasksConditionTask = task.task_func() + except Exception as e: + print(e) + raise ValueError( + f"Error while building If/else task {task_name}. " + f"Make sure {task_name} returns a JobsTasksConditionTask object." + ) from e + return JobsTasks( + **task_settings.to_tf_dict(), # type: ignore + condition_task=condition_task, + depends_on=depends_on, + task_key=task_name, + ) + def _build_dlt_task( self, task_name: str, @@ -559,7 +582,14 @@ def workflow_obj_to_tasks( for task_name, task in workflow.tasks.items(): # TODO: DLT # pipeline_task: Pipeline = self._create_dlt_notebooks(stack, task) - depends_on = [JobsTasksDependsOn(task_key=f) for f in task.depends_on_names] # type: ignore + if task.depends_on_names: + depends_on = [ + JobsTasksDependsOn(task_key=depends_key, outcome=expected_outcome) + for i in task.depends_on_names + for depends_key, expected_outcome in i.items() + ] # type: ignore + else: + depends_on = [] libraries = TaskLibrary.unique_libraries( task.libraries + (self.project.libraries or []) ) @@ -600,12 +630,19 @@ def workflow_obj_to_tasks( ) ) elif task.task_type == TaskType.SQL: - # native run job task + # native SQL task tasks.append( self._build_native_sql_file_task( task_name, task, task_settings, depends_on ) ) + elif task.task_type == TaskType.IF_ELSE_CONDITION_TASK: + # native If/Else task + tasks.append( + self._build_native_condition_task( + task_name, task, task_settings, depends_on + ) + ) else: # brickflow entrypoint task task_obj = JobsTasks( diff --git a/brickflow/engine/task.py b/brickflow/engine/task.py index aafe163e..0c7cb2ed 100644 --- a/brickflow/engine/task.py +++ b/brickflow/engine/task.py @@ -47,6 +47,7 @@ JobsTasksSqlTaskDashboard, JobsTasksSqlTaskFile, JobsTasksSqlTaskQuery, + JobsTasksConditionTask, ) from brickflow.cli.projects import DEFAULT_BRICKFLOW_VERSION_MODE 
from brickflow.context import ( @@ -117,6 +118,7 @@ class TaskType(Enum): NOTEBOOK_TASK = "notebook_task" SPARK_JAR_TASK = "spark_jar_task" RUN_JOB_TASK = "run_job_task" + IF_ELSE_CONDITION_TASK = "condition_task" class TaskRunCondition(Enum): @@ -128,6 +130,15 @@ class TaskRunCondition(Enum): ALL_FAILED = "ALL_FAILED" +class Operator(Enum): + EQUAL_TO = "==" + NOT_EQUAL = "!=" + GREATER_THAN = ">" + LESS_THAN = "<" + GREATER_THAN_OR_EQUAL = ">=" + LESS_THAN_OR_EQUAL = "<=" + + @dataclass(frozen=True) class TaskLibrary: @staticmethod @@ -585,6 +596,45 @@ def __init__(self, *args: Any, **kwds: Any): ) +class IfElseConditionTask(JobsTasksConditionTask): + """ + The IfElseConditionTask class is designed to handle conditional tasks in a workflow. + An instance of IfElseConditionTask represents a conditional task that compares two values (left and right) + using a specified operator. The operator can be one of the following: "==", "!=", ">", "<", ">=", "<=". + + The IfElseConditionTask class provides additional functionality for + mapping the provided operator to a specific operation. For example, + the operator "==" is mapped to "EQUAL_TO", and the operator "!=" is mapped to "NOT_EQUAL". + + Attributes: + left (str): The left operand in the condition. + right (str): The right operand in the condition. + It can be one of the following: "==", "!=", ">", "<", ">=", "<=". + op (Operator): The operation corresponding to the operator. + It is determined based on the operator. + + Examples: + Below are the different ways in which the IfElseConditionTask class + can be used inside a workflow (if_else_condition_task).: + 1. IfElseConditionTask(left="value1", right="value2", op="==") + 2. IfElseConditionTask(left="value1", right="value2", op="!=") + 3. IfElseConditionTask(left="value1", right="value2", op=">") + 4. IfElseConditionTask(left="value1", right="value2", op="<") + 5. IfElseConditionTask(left="value1", right="value2", op=">=") + 6. 
IfElseConditionTask(left="value1", right="value2", op="<=") + """ + + left: str + right: str + op: str + + def __init__(self, *args: Any, **kwds: Any): + super().__init__(*args, **kwds) + self.left = kwds["left"] + self.right = kwds["right"] + self.op = str(Operator(self.op).name) + + class DefaultBrickflowTaskPluginImpl(BrickflowTaskPluginSpec): @staticmethod @brickflow_task_plugin_impl @@ -712,6 +762,7 @@ class Task: custom_execute_callback: Optional[Callable] = None ensure_brickflow_plugins: bool = False health: Optional[List[JobsTasksHealthRules]] = None + if_else_outcome: Optional[Dict[Union[str, str], str]] = None def __post_init__(self) -> None: self.is_valid_task_signature() @@ -725,12 +776,17 @@ def parents(self) -> List[str]: return list(self.workflow.parents(self.task_id)) @property - def depends_on_names(self) -> Iterator[str]: + def depends_on_names(self) -> Iterator[Dict[str, Optional[str]]]: for i in self.depends_on: + if self.if_else_outcome: + outcome = list(self.if_else_outcome.values())[0] + else: + outcome = None + if callable(i) and hasattr(i, "__name__"): - yield i.__name__ + yield {i.__name__: outcome} else: - yield str(i) + yield {str(i): outcome} @property def databricks_task_type_str(self) -> str: diff --git a/brickflow/engine/workflow.py b/brickflow/engine/workflow.py index 69e29472..cda9b51f 100644 --- a/brickflow/engine/workflow.py +++ b/brickflow/engine/workflow.py @@ -298,6 +298,7 @@ def _add_task( custom_execute_callback: Optional[Callable] = None, task_settings: Optional[TaskSettings] = None, ensure_brickflow_plugins: bool = False, + if_else_outcome: Optional[Dict[Union[str, str], str]] = None, ) -> None: if self.task_exists(task_id): raise TaskAlreadyExistsError( @@ -339,6 +340,7 @@ def _add_task( task_settings=task_settings, custom_execute_callback=custom_execute_callback, ensure_brickflow_plugins=ensure_plugins, + if_else_outcome=if_else_outcome, ) # attempt to create task object before adding to graph @@ -353,6 +355,7 @@ def 
dlt_task( name: Optional[str] = None, task_settings: Optional[TaskSettings] = None, depends_on: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None, + if_else_outcome: Optional[Dict[Union[str, str], str]] = None, ) -> Callable: return self.task( task_func, @@ -360,6 +363,7 @@ def dlt_task( task_type=TaskType.DLT, task_settings=task_settings, depends_on=depends_on, + if_else_outcome=if_else_outcome, ) def notebook_task( @@ -370,6 +374,7 @@ def notebook_task( libraries: Optional[List[TaskLibrary]] = None, task_settings: Optional[TaskSettings] = None, depends_on: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None, + if_else_outcome: Optional[Dict[Union[str, str], str]] = None, ) -> Callable: return self.task( task_func, @@ -379,6 +384,7 @@ def notebook_task( task_type=TaskType.NOTEBOOK_TASK, task_settings=task_settings, depends_on=depends_on, + if_else_outcome=if_else_outcome, ) def spark_jar_task( @@ -389,6 +395,7 @@ def spark_jar_task( libraries: Optional[List[TaskLibrary]] = None, task_settings: Optional[TaskSettings] = None, depends_on: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None, + if_else_outcome: Optional[Dict[Union[str, str], str]] = None, ) -> Callable: return self.task( task_func, @@ -398,6 +405,7 @@ def spark_jar_task( task_type=TaskType.SPARK_JAR_TASK, task_settings=task_settings, depends_on=depends_on, + if_else_outcome=if_else_outcome, ) def run_job_task( @@ -406,6 +414,7 @@ def run_job_task( name: Optional[str] = None, task_settings: Optional[TaskSettings] = None, depends_on: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None, + if_else_outcome: Optional[Dict[Union[str, str], str]] = None, ) -> Callable: return self.task( task_func, @@ -413,6 +422,7 @@ def run_job_task( task_type=TaskType.RUN_JOB_TASK, task_settings=task_settings, depends_on=depends_on, + if_else_outcome=if_else_outcome, ) def sql_task( @@ -421,6 +431,7 @@ def sql_task( name: Optional[str] = None, task_settings: 
Optional[TaskSettings] = None, depends_on: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None, + if_else_outcome: Optional[Dict[Union[str, str], str]] = None, ) -> Callable: return self.task( task_func, @@ -428,6 +439,24 @@ def sql_task( task_type=TaskType.SQL, task_settings=task_settings, depends_on=depends_on, + if_else_outcome=if_else_outcome, + ) + + def if_else_condition_task( + self, + task_func: Optional[Callable] = None, + name: Optional[str] = None, + task_settings: Optional[TaskSettings] = None, + depends_on: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None, + if_else_outcome: Optional[Dict[Union[str, str], str]] = None, + ) -> Callable: + return self.task( + task_func, + name, + task_type=TaskType.IF_ELSE_CONDITION_TASK, + task_settings=task_settings, + depends_on=depends_on, + if_else_outcome=if_else_outcome, ) def task( @@ -442,6 +471,7 @@ def task( custom_execute_callback: Optional[Callable] = None, task_settings: Optional[TaskSettings] = None, ensure_brickflow_plugins: bool = False, + if_else_outcome: Optional[Dict[Union[str, str], str]] = None, ) -> Callable: if len(self.tasks) >= self.max_tasks_in_workflow: raise ValueError( @@ -464,6 +494,7 @@ def task_wrapper(f: Callable) -> Callable: custom_execute_callback=custom_execute_callback, task_settings=task_settings, ensure_brickflow_plugins=ensure_brickflow_plugins, + if_else_outcome=if_else_outcome, ) @functools.wraps(f) diff --git a/docs/tasks.md b/docs/tasks.md index 57f1120e..3958ea8b 100644 --- a/docs/tasks.md +++ b/docs/tasks.md @@ -398,6 +398,60 @@ def sample_sql_dashboard_task() -> any: ``` +#### If/Else Task +The `IfElseConditionTask` class is used to create conditional tasks in the workflow. It can be used to create tasks with a left operand, a right operand, and an operator. + +The `IfElseConditionTask` is used as a decorator in conjunction with the `if_else_condition_task` method of a `Workflow` instance. This method registers the task within the workflow. 
+
+`IfElseConditionTask` class can accept the following as inputs:
+- **left[Optional]**: A string representing the left operand in the condition.
+- **right[Optional]**: A string representing the right operand in the condition.
+- **operator[Optional]**: A string representing the operator used in the condition. It can be one of the following: "==", "!=", ">", "<", ">=", "<=".
+
+Here's an example of how to use the `IfElseConditionTask` type:
+
+```python
+# Example 1: creating an if/else task with some params
+@wf.if_else_condition_task
+def sample_if_else_condition_task():
+    return IfElseConditionTask(
+        left="value1", right="value2", op="=="
+    )
+# Let me walk you through how we can make use of the if/else condition task. We created a task with the name `sample_if_else_condition_task`, and it will return either true or false. Based on the returned bool, we decide which task to run next. Check the examples below.
+
+# Example 2: creating an if/else task that depends on example 1; this task only triggers if example 1 returns true.
+@wf.if_else_condition_task(depends_on="sample_if_else_condition_task", name="new_conditon_task", if_else_outcome={"sample_if_else_condition_task":"true"})
+def sample_condition_true():
+    return IfElseConditionTask(
+        left='{{job.id}}',
+        op="==",
+        right='{{job.id}}')
+
+'''
+    Now I created one more condition task (you can create any task type). Since my new task named `new_conditon_task` depends on `sample_if_else_condition_task` (the if/else task), this task will trigger only if the parent task succeeds (returns true), because I specified
+if_else_outcome={"sample_if_else_condition_task":"true"}. To see the false case, see the example below.
+'''
+# Example 3: this task will trigger only when the example 1 task fails (returns false).
+@wf.if_else_condition_task(depends_on="sample_if_else_condition_task", name="example_task_3", if_else_outcome={"sample_if_else_condition_task":"false"})
+def sample_condition_false():
+    return IfElseConditionTask(
+        left='{{job.trigger.type}}',
+        op="==",
+        right='{{job.trigger.type}}')
+
+# Note: Just as a task can have multiple deps, we can keep multiple entries in if_else_outcome:
+# Ex: if_else_outcome={"task1":"false", "task2":"true"}
+
+# Example 4: creating a SQL Alert Task that depends on example_task_3; the SQL task below will trigger only if the above task returns true.
+@wf.sql_task(depends_on="example_task_3", if_else_outcome={"example_task_3":"true"})
+def sample_sql_alert() ->any:
+    # it automatically validates user emails
+    return SqlTask(alert_id="ALERT_ID", pause_subscriptions=False, subscriptions={"usernames":["YOUR_EMAIL", 'YOUR_EMAIL']} ,warehouse_id="WAREHOUSE_ID")
+# Note: Since a SQL task doesn't return any bool, we can't make use of the if_else_outcome param for tasks that depend on a SQL task
+```
+
+
+
 ### Trigger rules
 
 There are two types of trigger rules that can be applied on a task.
It can be either ALL_SUCCESS or NONE_FAILED diff --git a/examples/brickflow_examples/workflows/demo_wf.py b/examples/brickflow_examples/workflows/demo_wf.py index 5c7e9fc4..9269cecb 100644 --- a/examples/brickflow_examples/workflows/demo_wf.py +++ b/examples/brickflow_examples/workflows/demo_wf.py @@ -15,6 +15,7 @@ RunJobTask, SparkJarTask, JarTaskLibrary, + IfElseConditionTask, ) from brickflow_plugins import ( TaskDependencySensor, @@ -464,7 +465,42 @@ def sample_sql_dashboard() -> any: ) -@wf.task(depends_on=airflow_autosys_sensor) +@wf.if_else_condition_task(depends_on=sample_sql_dashboard) +def sample_condition_task1(): + return IfElseConditionTask(left="1", op="==", right="2") + + +@wf.if_else_condition_task( + depends_on=sample_condition_task1, + name="new_conditon_tasl", + if_else_outcome={"sample_condition_task1": "true"}, +) +def sample_condition_task2(): + return IfElseConditionTask(left="{{job.id}}", op="==", right="{{job.id}}") + + +@wf.if_else_condition_task( + depends_on=["sample_condition_task1", "new_conditon_tasl"], + name="new_conditon_taslq", + if_else_outcome={"sample_condition_task1": "false", "new_conditon_tasl": "true"}, +) +def sample_condition_task4(): + return IfElseConditionTask(left="2", op="==", right="4") + + +@wf.if_else_condition_task( + depends_on=["sample_condition_task1", "new_conditon_tasl"], + name="new_conditon_tasly", + if_else_outcome={"sample_condition_task1": "true", "new_conditon_tasl": "true"}, +) +def sample_condition_task5(): + return IfElseConditionTask(left="2", op="==", right="4") + + +@wf.task( + depends_on=sample_condition_task5, + if_else_outcome={"sample_condition_task5": "true"}, +) def end(): pass diff --git a/tests/codegen/expected_bundles/dev_bundle_monorepo.yml b/tests/codegen/expected_bundles/dev_bundle_monorepo.yml index c73bb076..d0dd28d0 100644 --- a/tests/codegen/expected_bundles/dev_bundle_monorepo.yml +++ b/tests/codegen/expected_bundles/dev_bundle_monorepo.yml @@ -381,6 +381,18 @@ targets: 
warehouse_id: "your_warehouse_id" task_key: "sample_sql_task_query" timeout_seconds: null + - depends_on: + - task_key: "sample_sql_task_query" + email_notifications: {} + max_retries: null + min_retry_interval_millis: null + retry_on_timeout: null + condition_task: + left: "1" + op: "EQUAL_TO" + right: "2" + task_key: "condtion_task_test" + timeout_seconds: null pipelines: test_hello_world: catalog: null diff --git a/tests/codegen/expected_bundles/dev_bundle_polyrepo.yml b/tests/codegen/expected_bundles/dev_bundle_polyrepo.yml index c8ca5d48..384c4f1e 100644 --- a/tests/codegen/expected_bundles/dev_bundle_polyrepo.yml +++ b/tests/codegen/expected_bundles/dev_bundle_polyrepo.yml @@ -381,6 +381,18 @@ targets: warehouse_id: "your_warehouse_id" task_key: "sample_sql_task_query" timeout_seconds: null + - depends_on: + - task_key: "sample_sql_task_query" + email_notifications: {} + max_retries: null + min_retry_interval_millis: null + retry_on_timeout: null + condition_task: + left: "1" + op: "EQUAL_TO" + right: "2" + task_key: "condtion_task_test" + timeout_seconds: null pipelines: test_hello_world: catalog: null diff --git a/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml b/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml index 12be15ee..24c1b0ab 100644 --- a/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml +++ b/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml @@ -476,6 +476,18 @@ targets: warehouse_id: "your_warehouse_id" task_key: "sample_sql_task_query" timeout_seconds: null + - depends_on: + - task_key: "sample_sql_task_query" + email_notifications: {} + max_retries: null + min_retry_interval_millis: null + retry_on_timeout: null + condition_task: + left: "1" + op: "EQUAL_TO" + right: "2" + task_key: "condtion_task_test" + timeout_seconds: null pipelines: test_hello_world: catalog: null diff --git a/tests/codegen/expected_bundles/local_bundle.yml 
b/tests/codegen/expected_bundles/local_bundle.yml index 617a36de..14a60e56 100644 --- a/tests/codegen/expected_bundles/local_bundle.yml +++ b/tests/codegen/expected_bundles/local_bundle.yml @@ -377,6 +377,18 @@ targets: warehouse_id: "your_warehouse_id" task_key: "sample_sql_task_query" timeout_seconds: null + - depends_on: + - task_key: "sample_sql_task_query" + email_notifications: {} + max_retries: null + min_retry_interval_millis: null + retry_on_timeout: null + condition_task: + left: "1" + op: "EQUAL_TO" + right: "2" + task_key: "condtion_task_test" + timeout_seconds: null pipelines: test_hello_world: catalog: null diff --git a/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml b/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml index 647c74b2..12cb7037 100644 --- a/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml +++ b/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml @@ -378,6 +378,18 @@ targets: warehouse_id: "your_warehouse_id" task_key: "sample_sql_task_query" timeout_seconds: null + - depends_on: + - task_key: "sample_sql_task_query" + email_notifications: {} + max_retries: null + min_retry_interval_millis: null + retry_on_timeout: null + condition_task: + left: "1" + op: "EQUAL_TO" + right: "2" + task_key: "condtion_task_test" + timeout_seconds: null pipelines: test_hello_world: catalog: null diff --git a/tests/codegen/sample_workflow.py b/tests/codegen/sample_workflow.py index 2776873b..16fda518 100644 --- a/tests/codegen/sample_workflow.py +++ b/tests/codegen/sample_workflow.py @@ -11,6 +11,7 @@ SparkJarTask, TaskSettings, TaskRunCondition, + IfElseConditionTask, ) from brickflow.engine.workflow import Workflow, WorkflowPermissions, User @@ -123,6 +124,15 @@ def sample_sql_dashboard() -> any: ) +@wf.if_else_condition_task(depends_on=[sample_sql_task_query]) +def condtion_task_test() -> any: + return IfElseConditionTask( + left="1", + op="==", + right="2", + ) + + @wf.dlt_task def dlt_pipeline(): # pass 
diff --git a/tests/engine/test_task.py b/tests/engine/test_task.py index 2070b891..daac8435 100644 --- a/tests/engine/test_task.py +++ b/tests/engine/test_task.py @@ -6,7 +6,11 @@ import pytest from deepdiff import DeepDiff from brickflow.engine.utils import get_job_id -from brickflow import BrickflowProjectDeploymentSettings, SparkJarTask +from brickflow import ( + BrickflowProjectDeploymentSettings, + SparkJarTask, + IfElseConditionTask, +) from brickflow.context import ( ctx, BRANCH_SKIP_EXCEPT, @@ -531,6 +535,13 @@ def test_without_params_spark_jar(self): assert task.jar_uri == "test_uri" assert task.parameters is None + def if_else_condition_task(self): + # Test the __init__ method + instance = IfElseConditionTask(left="left_value", right="right_value", op="==") + assert instance.left == "left_value" + assert instance.right == "right_value" + assert instance.op == "==" + @patch("brickflow.bundles.model.JobsTasksSqlTaskAlert") @patch("brickflow.engine.task.SqlTask") def test_alert_creation(self, mock_sql_task, mock_alert):