From 039b74c069a9df243551c839065093a909b21486 Mon Sep 17 00:00:00 2001 From: Mikita Sakalouski <38785549+mikita-sakalouski@users.noreply.github.com> Date: Mon, 21 Oct 2024 10:58:04 +0200 Subject: [PATCH] [fix] Adjust logic for applying outcome for IF_ELSE_CONDITION tasks (#173) * feat: enhance if-else condition handling in Task class --- brickflow/engine/task.py | 18 +++-- .../expected_bundles/dev_bundle_monorepo.yml | 51 ++++++++++++++ .../expected_bundles/dev_bundle_polyrepo.yml | 51 ++++++++++++++ .../dev_bundle_polyrepo_with_auto_libs.yml | 57 +++++++++++++++ .../codegen/expected_bundles/local_bundle.yml | 69 ++++++++++++++++++- .../local_bundle_prefix_suffix.yml | 51 ++++++++++++++ tests/codegen/sample_workflows.py | 35 ++++++++++ 7 files changed, 324 insertions(+), 8 deletions(-) diff --git a/brickflow/engine/task.py b/brickflow/engine/task.py index a1494821..d8d9ccdb 100644 --- a/brickflow/engine/task.py +++ b/brickflow/engine/task.py @@ -822,15 +822,21 @@ def parents(self) -> List[str]: @property def depends_on_names(self) -> Iterator[Dict[str, Optional[str]]]: for i in self.depends_on: - if self.if_else_outcome: - outcome = list(self.if_else_outcome.values())[0] + task_name = i.__name__ if callable(i) and hasattr(i, "__name__") else str(i) + if ( + self.workflow.get_task(task_name).task_type + == TaskType.IF_ELSE_CONDITION_TASK + and self.if_else_outcome + ): + outcome = self.if_else_outcome.get(task_name) + if not outcome: + raise ValueError( + f"Task {task_name} is an if else condition task and does not have an outcome" + ) else: outcome = None - if callable(i) and hasattr(i, "__name__"): - yield {i.__name__: outcome} - else: - yield {str(i): outcome} + yield {task_name: outcome} @property def databricks_task_type_str(self) -> str: diff --git a/tests/codegen/expected_bundles/dev_bundle_monorepo.yml b/tests/codegen/expected_bundles/dev_bundle_monorepo.yml index 7a87bc49..808ca8bf 100644 --- a/tests/codegen/expected_bundles/dev_bundle_monorepo.yml +++ b/tests/codegen/expected_bundles/dev_bundle_monorepo.yml @@ -132,6 +132,45 @@ targets: retry_on_timeout: null task_key: spark_python_task_a timeout_seconds: null + - depends_on: + - task_key: spark_python_task_a + - outcome: "false" + task_key: condition_task_test2 + email_notifications: {} + existing_cluster_id: existing_cluster_id + libraries: + - pypi: + package: koheesio + repo: null + max_retries: null + min_retry_interval_millis: null + retry_on_timeout: null + spark_python_task: + parameters: ["--param1", "World!", "all_tasks1", "test", "all_tasks3", "123"] + python_file: ./products/test-project/spark/python/src/run_task.py + source: GIT + task_key: spark_python_task_depended + timeout_seconds: null + - depends_on: + - outcome: "true" + task_key: condtion_task_test + - outcome: "false" + task_key: condition_task_test2 + email_notifications: {} + existing_cluster_id: existing_cluster_id + libraries: + - pypi: + package: koheesio + repo: null + max_retries: null + min_retry_interval_millis: null + retry_on_timeout: null + spark_python_task: + parameters: ["--param1", "World!", "all_tasks1", "test", "all_tasks3", "123"] + python_file: ./products/test-project/spark/python/src/run_task.py + source: GIT + task_key: spark_python_task_depended2 + timeout_seconds: null - depends_on: - task_key: notebook_task_a email_notifications: {} @@ -443,6 +482,18 @@ targets: right: "2" task_key: "condtion_task_test" timeout_seconds: null + - depends_on: + - task_key: "sample_sql_task_query" + email_notifications: {} + max_retries: null + min_retry_interval_millis: null + retry_on_timeout: null + condition_task: + left: "1" + op: "EQUAL_TO" + right: "1" + task_key: "condition_task_test2" + timeout_seconds: null pipelines: test_hello_world: catalog: null diff --git a/tests/codegen/expected_bundles/dev_bundle_polyrepo.yml b/tests/codegen/expected_bundles/dev_bundle_polyrepo.yml index ce833ea4..845d8d40 100644 --- a/tests/codegen/expected_bundles/dev_bundle_polyrepo.yml +++ b/tests/codegen/expected_bundles/dev_bundle_polyrepo.yml @@ -132,6 +132,45 @@ targets: retry_on_timeout: null task_key: spark_python_task_a timeout_seconds: null + - depends_on: + - task_key: spark_python_task_a + - outcome: "false" + task_key: condition_task_test2 + email_notifications: {} + existing_cluster_id: existing_cluster_id + libraries: + - pypi: + package: koheesio + repo: null + max_retries: null + min_retry_interval_millis: null + retry_on_timeout: null + spark_python_task: + parameters: ["--param1", "World!", "all_tasks1", "test", "all_tasks3", "123"] + python_file: ./products/test-project/spark/python/src/run_task.py + source: GIT + task_key: spark_python_task_depended + timeout_seconds: null + - depends_on: + - outcome: "true" + task_key: condtion_task_test + - outcome: "false" + task_key: condition_task_test2 + email_notifications: {} + existing_cluster_id: existing_cluster_id + libraries: + - pypi: + package: koheesio + repo: null + max_retries: null + min_retry_interval_millis: null + retry_on_timeout: null + spark_python_task: + parameters: ["--param1", "World!", "all_tasks1", "test", "all_tasks3", "123"] + python_file: ./products/test-project/spark/python/src/run_task.py + source: GIT + task_key: spark_python_task_depended2 + timeout_seconds: null - depends_on: - task_key: notebook_task_a email_notifications: {} @@ -443,6 +482,18 @@ targets: right: "2" task_key: "condtion_task_test" timeout_seconds: null + - depends_on: + - task_key: "sample_sql_task_query" + email_notifications: {} + max_retries: null + min_retry_interval_millis: null + retry_on_timeout: null + condition_task: + left: "1" + op: "EQUAL_TO" + right: "1" + task_key: "condition_task_test2" + timeout_seconds: null pipelines: test_hello_world: catalog: null diff --git a/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml b/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml index d5f37d6f..a190f3f2 100644 --- a/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml +++ b/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml @@ -217,6 +217,51 @@ targets: retry_on_timeout: null task_key: spark_python_task_a timeout_seconds: null + - depends_on: + - task_key: spark_python_task_a + - outcome: "false" + task_key: condition_task_test2 + email_notifications: {} + existing_cluster_id: existing_cluster_id + libraries: + - pypi: + package: koheesio + repo: null + - pypi: + package: "brickflows==0.1.0" + repo: null + max_retries: null + min_retry_interval_millis: null + retry_on_timeout: null + spark_python_task: + parameters: ["--param1", "World!", "all_tasks1", "test", "all_tasks3", "123"] + python_file: ./products/test-project/spark/python/src/run_task.py + source: GIT + task_key: spark_python_task_depended + timeout_seconds: null + - depends_on: + - outcome: "true" + task_key: condtion_task_test + - outcome: "false" + task_key: condition_task_test2 + email_notifications: {} + existing_cluster_id: existing_cluster_id + libraries: + - pypi: + package: koheesio + repo: null + - pypi: + package: "brickflows==0.1.0" + repo: null + max_retries: null + min_retry_interval_millis: null + retry_on_timeout: null + spark_python_task: + parameters: ["--param1", "World!", "all_tasks1", "test", "all_tasks3", "123"] + python_file: ./products/test-project/spark/python/src/run_task.py + source: GIT + task_key: spark_python_task_depended2 + timeout_seconds: null - depends_on: - task_key: notebook_task_a email_notifications: {} @@ -552,6 +597,18 @@ targets: right: "2" task_key: "condtion_task_test" timeout_seconds: null + - depends_on: + - task_key: "sample_sql_task_query" + email_notifications: {} + max_retries: null + min_retry_interval_millis: null + retry_on_timeout: null + condition_task: + left: "1" + op: "EQUAL_TO" + right: "1" + task_key: "condition_task_test2" + timeout_seconds: null pipelines: test_hello_world: catalog: null diff --git a/tests/codegen/expected_bundles/local_bundle.yml b/tests/codegen/expected_bundles/local_bundle.yml index 9fb26203..2c3a750b 100644 --- a/tests/codegen/expected_bundles/local_bundle.yml +++ b/tests/codegen/expected_bundles/local_bundle.yml @@ -57,6 +57,19 @@ "retry_on_timeout": null "task_key": "condtion_task_test" "timeout_seconds": null + - "condition_task": + "left": "1" + "op": "EQUAL_TO" + "right": "1" + "depends_on": + - "outcome": null + "task_key": "sample_sql_task_query" + "email_notifications": {} + "max_retries": null + "min_retry_interval_millis": null + "retry_on_timeout": null + "task_key": "condition_task_test2" + "timeout_seconds": null - "depends_on": [] "email_notifications": {} "existing_cluster_id": "existing_cluster_id" @@ -236,15 +249,65 @@ "email_notifications": {} "existing_cluster_id": "existing_cluster_id" "libraries": - - "pypi": + - "pypi": + "package": "koheesio" + "repo": null + "max_retries": null + "min_retry_interval_millis": null + "retry_on_timeout": null + "spark_python_task": + "parameters": + - "--param1" + - "World!" + - "all_tasks1" + - "test" + - "all_tasks3" + - "123" + "python_file": "/Workspace/Users/${workspace.current_user.userName}/.brickflow_bundles/test-project/local/files/spark/python/src/run_task.py" + "source": "WORKSPACE" + "task_key": "spark_python_task_a" + "timeout_seconds": null + - "depends_on": + - "outcome": null + "task_key": "spark_python_task_a" + - "outcome": "false" + "task_key": "condition_task_test2" + "email_notifications": {} + "existing_cluster_id": "existing_cluster_id" + "libraries": + - "pypi": "package": "koheesio" "repo": null "max_retries": null "min_retry_interval_millis": null "retry_on_timeout": null "spark_python_task": + "parameters": + - "--param1" + - "World!" + - "all_tasks1" + - "test" + - "all_tasks3" + - "123" "python_file": "/Workspace/Users/${workspace.current_user.userName}/.brickflow_bundles/test-project/local/files/spark/python/src/run_task.py" "source": "WORKSPACE" + "task_key": "spark_python_task_depended" + "timeout_seconds": null + - "depends_on": + - "outcome": "true" + "task_key": "condtion_task_test" + - "outcome": "false" + "task_key": "condition_task_test2" + "email_notifications": {} + "existing_cluster_id": "existing_cluster_id" + "libraries": + - "pypi": + "package": "koheesio" + "repo": null + "max_retries": null + "min_retry_interval_millis": null + "retry_on_timeout": null + "spark_python_task": "parameters": - "--param1" - "World!" @@ -252,7 +315,9 @@ - "test" - "all_tasks3" - "123" - "task_key": "spark_python_task_a" + "python_file": "/Workspace/Users/${workspace.current_user.userName}/.brickflow_bundles/test-project/local/files/spark/python/src/run_task.py" + "source": "WORKSPACE" + "task_key": "spark_python_task_depended2" "timeout_seconds": null - "depends_on": [] "email_notifications": {} diff --git a/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml b/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml index 66cecb1c..179de816 100644 --- a/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml +++ b/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml @@ -129,6 +129,45 @@ targets: retry_on_timeout: null task_key: spark_python_task_a timeout_seconds: null + - depends_on: + - task_key: spark_python_task_a + - outcome: "false" + task_key: condition_task_test2 + email_notifications: {} + existing_cluster_id: existing_cluster_id + libraries: + - pypi: + package: koheesio + repo: null + max_retries: null + min_retry_interval_millis: null + retry_on_timeout: null + spark_python_task: + python_file: /Workspace/Users/${workspace.current_user.userName}/.brickflow_bundles/test-project/local/files/spark/python/src/run_task.py + parameters: ["--param1", "World!", "all_tasks1", "test", "all_tasks3", "123"] + source: WORKSPACE + task_key: spark_python_task_depended + timeout_seconds: null + - depends_on: + - outcome: "true" + task_key: condtion_task_test + - outcome: "false" + task_key: condition_task_test2 + email_notifications: {} + existing_cluster_id: existing_cluster_id + libraries: + - pypi: + package: koheesio + repo: null + max_retries: null + min_retry_interval_millis: null + retry_on_timeout: null + spark_python_task: + python_file: /Workspace/Users/${workspace.current_user.userName}/.brickflow_bundles/test-project/local/files/spark/python/src/run_task.py + parameters: ["--param1", "World!", "all_tasks1", "test", "all_tasks3", "123"] + source: WORKSPACE + task_key: spark_python_task_depended2 + timeout_seconds: null - depends_on: - task_key: notebook_task_a email_notifications: {} @@ -440,6 +479,18 @@ targets: right: "2" task_key: "condtion_task_test" timeout_seconds: null + - depends_on: + - task_key: "sample_sql_task_query" + email_notifications: {} + max_retries: null + min_retry_interval_millis: null + retry_on_timeout: null + condition_task: + left: "1" + op: "EQUAL_TO" + right: "1" + task_key: "condition_task_test2" + timeout_seconds: null pipelines: test_hello_world: catalog: null diff --git a/tests/codegen/sample_workflows.py b/tests/codegen/sample_workflows.py index 78137cdf..724423f4 100644 --- a/tests/codegen/sample_workflows.py +++ b/tests/codegen/sample_workflows.py @@ -174,6 +174,41 @@ def condtion_task_test() -> any: ) +@wf.if_else_condition_task(depends_on=[sample_sql_task_query]) +def condition_task_test2() -> any: + return IfElseConditionTask( + left="1", + op="==", + right="1", + ) + + +@wf.spark_python_task( + libraries=[PypiTaskLibrary(package="koheesio")], + depends_on=[spark_python_task_a, condition_task_test2], + if_else_outcome={"condition_task_test2": "false"}, +) +def spark_python_task_depended(): + return SparkPythonTask( + python_file="./products/test-project/spark/python/src/run_task.py", + source="GIT", + parameters=["--param1", "World!"], + ) # type: ignore + + +@wf.spark_python_task( + libraries=[PypiTaskLibrary(package="koheesio")], + depends_on=[condtion_task_test, condition_task_test2], + if_else_outcome={"condtion_task_test": "true", "condition_task_test2": "false"}, +) +def spark_python_task_depended2(): + return SparkPythonTask( + python_file="./products/test-project/spark/python/src/run_task.py", + source="GIT", + parameters=["--param1", "World!"], + ) # type: ignore + + @wf.dlt_task def dlt_pipeline(): # pass