From a688ce1c2118c05c24926849ac50cd316ba269fd Mon Sep 17 00:00:00 2001 From: riccamini Date: Tue, 10 Dec 2024 12:36:09 +0100 Subject: [PATCH] Support for spark jar task type in for each task --- brickflow/engine/workflow.py | 2 ++ .../local_bundle_foreach_task.yml | 18 ++++++++++++++++++ tests/codegen/sample_workflows.py | 13 +++++++++++++ 3 files changed, 33 insertions(+) diff --git a/brickflow/engine/workflow.py b/brickflow/engine/workflow.py index d74b8e2..97925da 100644 --- a/brickflow/engine/workflow.py +++ b/brickflow/engine/workflow.py @@ -593,6 +593,7 @@ def for_each_task( name: Optional[str] = None, task_settings: Optional[TaskSettings] = None, depends_on: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None, + libraries: Optional[List[TaskLibrary]] = None, if_else_outcome: Optional[Dict[Union[str, str], str]] = None, for_each_task_inputs: Optional[str] = None, for_each_task_concurrency: Optional[int] = 1, @@ -603,6 +604,7 @@ def for_each_task( task_type=TaskType.FOR_EACH_TASK, task_settings=task_settings, depends_on=depends_on, + libraries=libraries, if_else_outcome=if_else_outcome, for_each_task_inputs=for_each_task_inputs, for_each_task_concurrency=for_each_task_concurrency, diff --git a/tests/codegen/expected_bundles/local_bundle_foreach_task.yml b/tests/codegen/expected_bundles/local_bundle_foreach_task.yml index 5d563f3..0b83da5 100644 --- a/tests/codegen/expected_bundles/local_bundle_foreach_task.yml +++ b/tests/codegen/expected_bundles/local_bundle_foreach_task.yml @@ -97,6 +97,24 @@ notebook_path: test_databricks_bundle.py source: WORKSPACE # TODO check if this needs to be GIT for BF tasks job_cluster_key: sample_job_cluster + - task_key: for_each_spark_jar + depends_on: + - task_key: for_each_bf_task + email_notifications: {} + for_each_task: + inputs: "[1,2,3]" + concurrency: 1 + task: + "depends_on": [ ] + "email_notifications": { } + task_key: for_each_spark_jar_nested + spark_jar_task: + main_class_name: com.example.MainClass + parameters: + - "{{input}}" + job_cluster_key: sample_job_cluster + libraries: + - jar: dbfs:/some/path/to/The.jar "timeout_seconds": null "trigger": null "webhook_notifications": null diff --git a/tests/codegen/sample_workflows.py b/tests/codegen/sample_workflows.py index c654c09..4f0d44e 100644 --- a/tests/codegen/sample_workflows.py +++ b/tests/codegen/sample_workflows.py @@ -440,3 +440,16 @@ def for_each_notebook(): ) def for_each_bf_task(): print("This is a bf task!") + + +@wf3.for_each_task( + depends_on=for_each_bf_task, + for_each_task_inputs="[1,2,3]", + for_each_task_concurrency=1, + libraries=[JarTaskLibrary(jar="dbfs:/some/path/to/The.jar")], +) +def for_each_spark_jar(): + return SparkJarTask( + main_class_name="com.example.MainClass", + parameters=["{{input}}"], + )