diff --git a/airflow/providers/databricks/operators/databricks_workflow.py b/airflow/providers/databricks/operators/databricks_workflow.py
index 15333dc69118..6df8e2d025ce 100644
--- a/airflow/providers/databricks/operators/databricks_workflow.py
+++ b/airflow/providers/databricks/operators/databricks_workflow.py
@@ -52,7 +52,7 @@ class WorkflowRunMetadata:
     """
 
     conn_id: str
-    job_id: str
+    job_id: int
     run_id: int
 
 
@@ -116,6 +116,7 @@ def __init__(
         self.notebook_params = notebook_params or {}
         self.tasks_to_convert = tasks_to_convert or []
         self.relevant_upstreams = [task_id]
+        self.workflow_run_metadata: WorkflowRunMetadata | None = None
         super().__init__(task_id=task_id, **kwargs)
 
     def _get_hook(self, caller: str) -> DatabricksHook:
@@ -212,12 +213,34 @@ def execute(self, context: Context) -> Any:
 
         self._wait_for_job_to_start(run_id)
 
+        self.workflow_run_metadata = WorkflowRunMetadata(
+            self.databricks_conn_id,
+            job_id,
+            run_id,
+        )
+
         return {
             "conn_id": self.databricks_conn_id,
             "job_id": job_id,
             "run_id": run_id,
         }
 
+    def on_kill(self) -> None:
+        if self.workflow_run_metadata:
+            run_id = self.workflow_run_metadata.run_id
+            job_id = self.workflow_run_metadata.job_id
+
+            self._hook.cancel_run(run_id)
+            self.log.info(
+                "Run: %(run_id)s of job_id: %(job_id)s was requested to be cancelled.",
+                {"run_id": run_id, "job_id": job_id},
+            )
+        else:
+            self.log.error(
+                "Workflow run metadata is not populated, so the run was not cancelled. This could be "
+                "because the workflow was never started or the workflow creation process failed."
+            )
+
 
 class DatabricksWorkflowTaskGroup(TaskGroup):
     """
diff --git a/tests/providers/databricks/operators/test_databricks_workflow.py b/tests/providers/databricks/operators/test_databricks_workflow.py
index 4c3f54b800ae..fbc429ed1d9a 100644
--- a/tests/providers/databricks/operators/test_databricks_workflow.py
+++ b/tests/providers/databricks/operators/test_databricks_workflow.py
@@ -28,6 +28,7 @@
 from airflow.providers.databricks.hooks.databricks import RunLifeCycleState
 from airflow.providers.databricks.operators.databricks_workflow import (
     DatabricksWorkflowTaskGroup,
+    WorkflowRunMetadata,
     _CreateDatabricksWorkflowOperator,
     _flatten_node,
 )
@@ -59,6 +60,11 @@ def mock_task_group():
     return mock_group
 
 
+@pytest.fixture
+def mock_workflow_run_metadata():
+    return MagicMock(spec=WorkflowRunMetadata)
+
+
 def test_flatten_node():
     """Test that _flatten_node returns a flat list of operators."""
     task_group = MagicMock(spec=DatabricksWorkflowTaskGroup)
@@ -231,3 +237,19 @@ def test_task_group_root_tasks_set_upstream_to_operator(mock_databricks_workflow
     create_operator_instance = mock_databricks_workflow_operator.return_value
 
     task1.set_upstream.assert_called_once_with(create_operator_instance)
+
+
+def test_on_kill(mock_databricks_hook, context, mock_workflow_run_metadata):
+    """Test that on_kill cancels the tracked Databricks workflow run."""
+    operator = _CreateDatabricksWorkflowOperator(task_id="test_task", databricks_conn_id="databricks_default")
+    operator.workflow_run_metadata = mock_workflow_run_metadata
+
+    RUN_ID = 789
+
+    mock_workflow_run_metadata.conn_id = operator.databricks_conn_id
+    mock_workflow_run_metadata.job_id = 123
+    mock_workflow_run_metadata.run_id = RUN_ID
+
+    operator.on_kill()
+
+    operator._hook.cancel_run.assert_called_once_with(RUN_ID)
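
Below is a minimal sketch, outside the diff itself, of how the new on_kill() path can be exercised end to end. It assumes _hook is a functools.cached_property on the operator (so it can be overridden on the instance); the task_id "launch_workflow", job_id 123, and run_id 789 are illustrative values, with a MagicMock standing in for the real DatabricksHook:

    from unittest.mock import MagicMock

    from airflow.providers.databricks.operators.databricks_workflow import (
        WorkflowRunMetadata,
        _CreateDatabricksWorkflowOperator,
    )

    op = _CreateDatabricksWorkflowOperator(
        task_id="launch_workflow", databricks_conn_id="databricks_default"
    )
    # execute() normally records this after launching the job; it is set
    # directly here so the cancellation path can run in isolation.
    op.workflow_run_metadata = WorkflowRunMetadata(
        conn_id="databricks_default", job_id=123, run_id=789
    )
    op._hook = MagicMock()  # illustrative stand-in for the real DatabricksHook

    op.on_kill()  # cancels the tracked run via the hook
    op._hook.cancel_run.assert_called_once_with(789)

Because workflow_run_metadata defaults to None in __init__, on_kill() degrades to an error log rather than raising if the task is killed before the Databricks job was ever launched.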