diff --git a/brickflow/codegen/databricks_bundle.py b/brickflow/codegen/databricks_bundle.py index 484a8205..5ef0da15 100644 --- a/brickflow/codegen/databricks_bundle.py +++ b/brickflow/codegen/databricks_bundle.py @@ -392,6 +392,7 @@ def workflow_obj_to_schedule(workflow: Workflow) -> Optional[JobsSchedule]: return JobsSchedule( quartz_cron_expression=workflow.schedule_quartz_expression, timezone_id=workflow.timezone, + pause_status=workflow.schedule_pause_status, ) return None diff --git a/brickflow/engine/workflow.py b/brickflow/engine/workflow.py index 2a7a5568..2f11746e 100644 --- a/brickflow/engine/workflow.py +++ b/brickflow/engine/workflow.py @@ -29,6 +29,10 @@ from brickflow.engine.utils import wraps_keyerror +class WorkflowConfigError(Exception): + pass + + class NoWorkflowComputeError(Exception): pass @@ -109,6 +113,7 @@ class Workflow: _name: str schedule_quartz_expression: Optional[str] = None timezone: str = "UTC" + schedule_pause_status: str = "UNPAUSED" default_cluster: Optional[Cluster] = None clusters: List[Cluster] = field(default_factory=lambda: []) default_task_settings: TaskSettings = TaskSettings() @@ -155,6 +160,13 @@ def __post_init__(self) -> None: # the default cluster is set to the first cluster if it is not configured self.default_cluster = self.clusters[0] + self.schedule_pause_status = self.schedule_pause_status.upper() + allowed_scheduled_pause_statuses = ["PAUSED", "UNPAUSED"] + if self.schedule_pause_status not in allowed_scheduled_pause_statuses: + raise WorkflowConfigError( + f"schedule_pause_status must be one of {allowed_scheduled_pause_statuses}" + ) + # def __hash__(self) -> int: # import json # diff --git a/docs/workflows.md b/docs/workflows.md index 91201a60..30942d6f 100644 --- a/docs/workflows.md +++ b/docs/workflows.md @@ -15,6 +15,7 @@ wf = Workflow( # (1)! # Optional parameters below schedule_quartz_expression="0 0/20 0 ? * * *", # (4)! timezone="UTC", # (5)! + schedule_pause_status="PAUSED", # (15)! 
default_task_settings=TaskSettings( # (6)! email_notifications=EmailNotifications( on_start=["email@nike.com"], @@ -65,6 +66,7 @@ def task_function(*, test="var"): 12. Suffix for the name of the workflow 13. Define the common task parameters that can be used in all the tasks 14. Define a workflow task and associate it to the workflow +15. Define the schedule pause status: "PAUSED" or "UNPAUSED" (case-insensitive). Defaults to "UNPAUSED". ### Clusters diff --git a/tests/codegen/expected_bundles/dev_bundle_monorepo.yml b/tests/codegen/expected_bundles/dev_bundle_monorepo.yml index 5ce9a93e..f56dd142 100644 --- a/tests/codegen/expected_bundles/dev_bundle_monorepo.yml +++ b/tests/codegen/expected_bundles/dev_bundle_monorepo.yml @@ -29,6 +29,7 @@ environments: schedule: quartz_cron_expression: '* * * * *' timezone_id: UTC + pause_status: "UNPAUSED" tags: brickflow_project_name: test-project brickflow_deployment_mode: Databricks Asset Bundles diff --git a/tests/codegen/expected_bundles/dev_bundle_polyrepo.yml b/tests/codegen/expected_bundles/dev_bundle_polyrepo.yml index 2de3ff40..8171e12c 100644 --- a/tests/codegen/expected_bundles/dev_bundle_polyrepo.yml +++ b/tests/codegen/expected_bundles/dev_bundle_polyrepo.yml @@ -29,6 +29,7 @@ environments: schedule: quartz_cron_expression: '* * * * *' timezone_id: UTC + pause_status: "UNPAUSED" tags: brickflow_project_name: test-project brickflow_deployment_mode: Databricks Asset Bundles diff --git a/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml b/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml index c1423b3b..fac192de 100644 --- a/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml +++ b/tests/codegen/expected_bundles/dev_bundle_polyrepo_with_auto_libs.yml @@ -85,6 +85,7 @@ environments: schedule: quartz_cron_expression: '* * * * *' timezone_id: UTC + pause_status: "UNPAUSED" tags: brickflow_project_name: test-project brickflow_deployment_mode: Databricks Asset Bundles diff --git 
a/tests/codegen/expected_bundles/local_bundle.yml b/tests/codegen/expected_bundles/local_bundle.yml index c394bb92..0cbeaebd 100644 --- a/tests/codegen/expected_bundles/local_bundle.yml +++ b/tests/codegen/expected_bundles/local_bundle.yml @@ -26,6 +26,7 @@ environments: schedule: quartz_cron_expression: '* * * * *' timezone_id: UTC + pause_status: "UNPAUSED" tags: brickflow_project_name: test-project brickflow_deployment_mode: Databricks Asset Bundles diff --git a/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml b/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml index 94593ccf..e91e2279 100644 --- a/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml +++ b/tests/codegen/expected_bundles/local_bundle_prefix_suffix.yml @@ -26,6 +26,7 @@ environments: schedule: quartz_cron_expression: '* * * * *' timezone_id: UTC + pause_status: "UNPAUSED" tags: brickflow_project_name: test-project brickflow_deployment_mode: Databricks Asset Bundles diff --git a/tests/engine/test_workflow.py b/tests/engine/test_workflow.py index fbb52438..23f6cd98 100644 --- a/tests/engine/test_workflow.py +++ b/tests/engine/test_workflow.py @@ -16,6 +16,7 @@ ServicePrincipal, Workflow, NoWorkflowComputeError, + WorkflowConfigError, ) from tests.engine.sample_workflow import wf, task_function @@ -226,3 +227,31 @@ def test_another_workflow(self): assert len(wf1.graph.nodes) == 2 assert len(wf.graph.nodes) == 10 + + def test_schedule_run_status_workflow(self): + this_wf = Workflow("test", clusters=[Cluster("name", "spark", "vm-node")]) + assert this_wf.schedule_pause_status == "UNPAUSED" + + this_wf = Workflow( + "test", + clusters=[Cluster("name", "spark", "vm-node")], + schedule_pause_status="PAUSED", + ) + assert this_wf.schedule_pause_status == "PAUSED" + + this_wf = Workflow( + "test", + clusters=[Cluster("name", "spark", "vm-node")], + schedule_pause_status="paused", + ) + assert this_wf.schedule_pause_status == "PAUSED" + + with 
pytest.raises(WorkflowConfigError) as excinfo: + Workflow( + "test", + clusters=[Cluster("name", "spark", "vm-node")], + schedule_pause_status="invalid", + ) + assert "schedule_pause_status must be one of ['PAUSED', 'UNPAUSED']" == str( + excinfo.value + )