Skip to content

Commit

Permalink
Tests for run job task type in for each task
Browse files Browse the repository at this point in the history
  • Loading branch information
riccamini committed Dec 12, 2024
1 parent a688ce1 commit ebb12b1
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 2 deletions.
32 changes: 32 additions & 0 deletions tests/codegen/expected_bundles/local_bundle_foreach_task.yml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@
brickflow_start_time: "{{start_time}}"
brickflow_task_key: "{{task_key}}"
brickflow_task_retry_count: "{{task_retry_count}}"
looped_parameter: "{{input}}"
notebook_path: test_databricks_bundle.py
source: WORKSPACE # TODO check if this needs to be GIT for BF tasks
job_cluster_key: sample_job_cluster
Expand All @@ -115,6 +116,37 @@
job_cluster_key: sample_job_cluster
libraries:
- jar: dbfs:/some/path/to/The.jar
- task_key: for_each_spark_python
depends_on:
- task_key: first_notebook
email_notifications: {}
for_each_task:
inputs: "[1,2,3]"
concurrency: 1
task:
"email_notifications": { }
"depends_on": [ ]
task_key: for_each_spark_python_nested
"libraries": []
spark_python_task:
python_file: "/Workspace/Users/${workspace.current_user.userName}/.brickflow_bundles/test-project/local/files/test-project/path/to/python_script.py"
parameters:
- "{{input}}"
source: WORKSPACE
job_cluster_key: sample_job_cluster
- task_key: for_each_run_job
depends_on:
- task_key: first_notebook
email_notifications: {}
for_each_task:
inputs: "[\"job_param_1\",\"job_param_2\"]"
concurrency: 1
task:
"depends_on": [ ]
"email_notifications": { }
task_key: for_each_run_job_nested
run_job_task:
job_id: 12345678901234.0
"timeout_seconds": null
"trigger": null
"webhook_notifications": null
Expand Down
26 changes: 24 additions & 2 deletions tests/codegen/sample_workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,8 +438,8 @@ def for_each_notebook():
for_each_task_inputs=["1", "2", "3"],
for_each_task_concurrency=1,
)
def for_each_bf_task():
print("This is a bf task!")
def for_each_bf_task(*, looped_parameter="{{input}}"):
print(f"This is a nested bf task running with input: {looped_parameter}")


@wf3.for_each_task(
Expand All @@ -453,3 +453,25 @@ def for_each_spark_jar():
main_class_name="com.example.MainClass",
parameters=["{{input}}"],
)


@wf3.for_each_task(
    depends_on=first_notebook,
    for_each_task_inputs="[1,2,3]",
    for_each_task_concurrency=1,
)
def for_each_spark_python():
    """Build the nested Spark Python task run once per for-each input.

    The ``{{input}}`` placeholder is substituted by Databricks with the
    current element of ``for_each_task_inputs`` at run time.
    """
    nested_task = SparkPythonTask(
        python_file="/test-project/path/to/python_script.py",
        source="WORKSPACE",
        parameters=["{{input}}"],
    )
    return nested_task


@wf3.for_each_task(
    depends_on=first_notebook,
    for_each_task_inputs='["job_param_1","job_param_2"]',
    for_each_task_concurrency=1,
)
def for_each_run_job():
    """Build the nested run-job task triggered once per for-each input."""
    nested_task = RunJobTask(job_name="some_job_name")
    return nested_task
6 changes: 6 additions & 0 deletions tests/codegen/test_databricks_bundle.py
Original file line number Diff line number Diff line change
Expand Up @@ -739,8 +739,10 @@ def test_generate_serverless_bundle_local(
"brickflow.context.ctx.get_current_timestamp",
MagicMock(return_value=1704067200000),
)
@patch("brickflow.codegen.databricks_bundle.MultiProjectManager")
def test_foreach_task(
self,
multi_project_manager_mock: Mock,
bf_version_mock: Mock,
dbutils: Mock,
sub_proc_mock: Mock,
Expand All @@ -751,6 +753,10 @@ def test_foreach_task(
bf_version_mock.return_value = "1.0.0"
workspace_client = get_workspace_client_mock()
get_job_id_mock.return_value = 12345678901234.0

multi_project_manager_mock.return_value.get_project.return_value = MagicMock(
path_from_repo_root_to_project_root="test-project"
)
# get caller part breaks here
with Project(
"test-project",
Expand Down

0 comments on commit ebb12b1

Please sign in to comment.