Skip to content

Commit

Permalink
Support for spark jar task type in for each task
Browse files Browse the repository at this point in the history
  • Loading branch information
riccamini committed Dec 10, 2024
1 parent b46bd32 commit a688ce1
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 0 deletions.
2 changes: 2 additions & 0 deletions brickflow/engine/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -593,6 +593,7 @@ def for_each_task(
name: Optional[str] = None,
task_settings: Optional[TaskSettings] = None,
depends_on: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
libraries: Optional[List[TaskLibrary]] = None,
if_else_outcome: Optional[Dict[Union[str, str], str]] = None,
for_each_task_inputs: Optional[str] = None,
for_each_task_concurrency: Optional[int] = 1,
Expand All @@ -603,6 +604,7 @@ def for_each_task(
task_type=TaskType.FOR_EACH_TASK,
task_settings=task_settings,
depends_on=depends_on,
libraries=libraries,
if_else_outcome=if_else_outcome,
for_each_task_inputs=for_each_task_inputs,
for_each_task_concurrency=for_each_task_concurrency,
Expand Down
18 changes: 18 additions & 0 deletions tests/codegen/expected_bundles/local_bundle_foreach_task.yml
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,24 @@
notebook_path: test_databricks_bundle.py
source: WORKSPACE # TODO check if this needs to be GIT for BF tasks
job_cluster_key: sample_job_cluster
- task_key: for_each_spark_jar
depends_on:
- task_key: for_each_bf_task
email_notifications: {}
for_each_task:
inputs: "[1,2,3]"
concurrency: 1
task:
"depends_on": [ ]
"email_notifications": { }
task_key: for_each_spark_jar_nested
spark_jar_task:
main_class_name: com.example.MainClass
parameters:
- "{{input}}"
job_cluster_key: sample_job_cluster
libraries:
- jar: dbfs:/some/path/to/The.jar
"timeout_seconds": null
"trigger": null
"webhook_notifications": null
Expand Down
13 changes: 13 additions & 0 deletions tests/codegen/sample_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,16 @@ def for_each_notebook():
)
def for_each_bf_task():
print("This is a bf task!")


@wf3.for_each_task(
    depends_on=for_each_bf_task,
    for_each_task_inputs="[1,2,3]",
    for_each_task_concurrency=1,
    libraries=[JarTaskLibrary(jar="dbfs:/some/path/to/The.jar")],
)
def for_each_spark_jar():
    """Sample for-each task whose nested task is a Spark JAR task.

    Registered on ``wf3`` to run after ``for_each_bf_task``, with the
    inputs ``"[1,2,3]"``, a concurrency of 1, and a single JAR library
    attached through the decorator's ``libraries`` argument.

    Returns:
        SparkJarTask: task definition for ``com.example.MainClass`` with
        ``"{{input}}"`` as its only parameter.
    """
    # NOTE(review): "{{input}}" is presumably the placeholder replaced with
    # the current iteration's value -- confirm against the Databricks docs.
    nested_jar_task = SparkJarTask(
        main_class_name="com.example.MainClass",
        parameters=["{{input}}"],
    )
    return nested_jar_task

0 comments on commit a688ce1

Please sign in to comment.