diff --git a/brickflow_plugins/airflow/operators/external_tasks.py b/brickflow_plugins/airflow/operators/external_tasks.py
index 5f72f60c..30385f6e 100644
--- a/brickflow_plugins/airflow/operators/external_tasks.py
+++ b/brickflow_plugins/airflow/operators/external_tasks.py
@@ -7,6 +7,7 @@
 import requests
 from airflow.models import Connection
 from airflow.sensors.base import BaseSensorOperator
+from airflow.exceptions import AirflowSensorTimeout
 from requests.adapters import HTTPAdapter
 from requests.packages.urllib3.util.retry import Retry
 from requests import HTTPError
@@ -217,8 +218,9 @@ def __init__(
         self.latest = latest
         self.poke_interval = poke_interval
         self._poke_count = 0
+        self._start_time = time.time()
 
-    def get_execution_stats(self):
+    def get_execution_stats(self, execution_date: datetime):
         """Function to get the execution stats for task_id within a execution delta window
 
         Returns:
@@ -231,7 +233,7 @@ def get_execution_stats(self):
         external_dag_id = self.external_dag_id
         external_task_id = self.external_task_id
         execution_delta = self.execution_delta
-        execution_window_tz = (datetime.now() + execution_delta).strftime(
+        execution_window_tz = (execution_date + execution_delta).strftime(
             "%Y-%m-%dT%H:%M:%SZ"
         )
         headers = {
@@ -257,17 +259,20 @@ def get_execution_stats(self):
             + f"/dagRuns?execution_date_gte={execution_window_tz}"
         )
         log.info(f"URL to poke for dag runs {url}")
-        response = requests.request("GET", url, headers=headers)
-        if response.status_code == 401:
-            raise Exception(
-                f"No Runs found for {external_dag_id} dag after {execution_window_tz}, Please check upstream dag"
-            )
+        response = requests.request("GET", url, headers=headers, verify=False)
         response.raise_for_status()
-        list_of_dictionaries = response.json()
+        list_of_dictionaries = response.json()["dag_runs"]
         list_of_dictionaries = sorted(
             list_of_dictionaries, key=lambda k: k["execution_date"], reverse=True
         )
+
+        if len(list_of_dictionaries) == 0:
+            log.info(
+                f"No Runs found for {external_dag_id} dag after {execution_window_tz}, please check upstream dag"
+            )
+            return "none"
+
         if af_version.startswith("1."):
             # For airflow 1.X Execution date is needed to check the status of the task
             dag_run_id = list_of_dictionaries[0]["execution_date"]
@@ -294,7 +299,7 @@ def get_execution_stats(self):
             + f"/dagRuns/{dag_run_id}/taskInstances/{external_task_id}"
         )
         log.info(f"Pinging airflow API {task_url} for task status ")
-        task_response = requests.request("GET", task_url, headers=headers)
+        task_response = requests.request("GET", task_url, headers=headers, verify=False)
         task_response.raise_for_status()
         task_state = task_response.json()["state"]
         return task_state
@@ -302,9 +307,16 @@
     def poke(self, context):
         log.info(f"executing poke.. {self._poke_count}")
         self._poke_count = self._poke_count + 1
-        logging.info("Poking.. {0} round".format(str(self._poke_count)))
-        task_status = self.get_execution_stats()
-        log.info(f"task_status= {task_status}")
+        log.info("Poking.. {0} round".format(str(self._poke_count)))
+
+        # Execution date is extracted from the context and is based on the task schedule, e.g.
+        # 0 0 1 ? * MON-SAT * -> 2024-01-01T01:00:00.000000+00:00
+        # This means that the relative delta between the workflow execution and the target Airflow DAG run always stays the same.
+        execution_date = parse(context["execution_date"])
+        log.info(f"Execution date derived from context: {execution_date}")
+
+        task_status = self.get_execution_stats(execution_date=execution_date)
+        log.info(f"task_status={task_status}")
         return task_status
 
     def execute(self, context):
@@ -321,9 +333,9 @@ def execute(self, context):
         external_dag_id = self.external_dag_id
         external_task_id = self.external_task_id
         execution_delta = self.execution_delta
-        execution_window_tz = (datetime.now() + execution_delta).strftime(
-            "%Y-%m-%dT%H:%M:%SZ"
-        )
+        execution_window_tz = (
+            parse(context["execution_date"]) + execution_delta
+        ).strftime("%Y-%m-%dT%H:%M:%SZ")
         log.info(
             f"Executing TaskDependency Sensor Operator to check successful run for {external_dag_id} dag, task {external_task_id} after {execution_window_tz} "
         )
@@ -331,12 +343,16 @@
         while status not in allowed_states:
             status = self.poke(context)
             if status == "failed":
+                # Log that the upstream task failed, but do not fail this task; keep poking until success or timeout
                 log.error(
                     f"Upstream dag {external_dag_id} failed at {external_task_id} task "
                 )
-                raise Exception("Upstream Dag Failed")
+                time.sleep(self.poke_interval)
             elif status != "success":
                 time.sleep(self.poke_interval)
+
+            if (time.time() - self._start_time) > self.timeout:
+                raise AirflowSensorTimeout("The job has timed out")
 
         log.info(f"Upstream Dag {external_dag_id} is successful")
diff --git a/tests/airflow_plugins/test_task_dependency.py b/tests/airflow_plugins/test_task_dependency.py
new file mode 100644
index 00000000..9c176581
--- /dev/null
+++ b/tests/airflow_plugins/test_task_dependency.py
@@ -0,0 +1,130 @@
+from datetime import timedelta
+
+import pytest
+from requests.exceptions import HTTPError
+from requests_mock.mocker import Mocker as RequestsMocker
+
+from brickflow_plugins.airflow.operators.external_tasks import (
+    TaskDependencySensor,
+    AirflowProxyOktaClusterAuth,
+    log,
+    AirflowSensorTimeout,
+)
+
+BASE_URL = "https://42.airflow.my-org.com/foo"
+
+
+class TestTaskDependencySensor:
+    log.propagate = True
+
+    @pytest.fixture(autouse=True, name="api", scope="class")
+    def mock_api(self):
+        """
+        End-to-end test scenario for Airflow v2 API
+        """
+        rm = RequestsMocker()
+        # DAG Run Endpoint
+        rm.register_uri(
+            method="GET",
+            url=f"{BASE_URL}/api/v1/dags/test-dag/dagRuns?execution_date_gte=2024-01-01T00:00:00Z",
+            response_list=[
+                # Test 1: No Run
+                {"json": {"dag_runs": [], "total_entries": 0}, "status_code": int(200)},
+                # Test 2: Run Exists
+                {
+                    "json": {
+                        "dag_runs": [
+                            {
+                                "conf": {},
+                                "dag_id": "test-dag",
+                                "dag_run_id": "manual__2024-01-01T01:00:00.000000+00:00",
+                                "end_date": "2024-01-01T01:10:00.000000+00:00",
+                                "execution_date": "2024-01-01T01:00:00.000000+00:00",
+                                "external_trigger": True,
+                                "logical_date": "2024-01-01T01:00:00.000000+00:00",
+                                "start_date": "2024-01-01T01:00:00.000000+00:00",
+                                "state": "success",
+                            },
+                        ],
+                        "total_entries": 1,
+                    },
+                    "status_code": int(200),
+                },
+                # Tests 3 and 4 reuse the last response: requests_mock repeats it once the list is exhausted
+            ],
+        )
+        # Task Instance Endpoint
+        rm.register_uri(
+            method="GET",
+            url=(
+                f"{BASE_URL}"
+                f"/api/v1/dags/test-dag/dagRuns/manual__2024-01-01T01:00:00.000000+00:00/taskInstances/test-task"
+            ),
+            response_list=[
+                {"json": {"state": "running"}, "status_code": int(200)},
+                {"json": {"state": "failed"}, "status_code": int(200)},
+                {"json": {"state": "success"}, "status_code": int(200)},
+            ],
+        )
+        yield rm
+
+    @pytest.fixture()
+    def sensor(self, mocker):
+        auth_mock = mocker.MagicMock(spec=AirflowProxyOktaClusterAuth)
+        auth_mock.get_access_token.return_value = "foo"
+        auth_mock.get_airflow_api_url.return_value = BASE_URL
+        auth_mock.get_version.return_value = "2.0.2"
+
+        yield TaskDependencySensor(
+            external_dag_id="test-dag",
+            external_task_id="test-task",
+            allowed_states=["success"],
+            execution_delta=timedelta(**{"hours": -3}),
+            airflow_auth=auth_mock,
+            task_id="foo",
+            poke_interval=1,
+        )
+
+    def test_api_airflow_v2(self, api, caplog, sensor):
+        # Scenario:
+        # Airflow API is poked 4 times
+        #   1. No Run
+        #   2. Run Exists - Task is in Running State
+        #   3. Run Exists - Task is in Failed State
+        #   4. Run Exists - Task is in Success State
+        with api:
+            sensor.execute(context={"execution_date": "2024-01-01T03:00:00Z"})
+
+        assert (
+            "No Runs found for test-dag dag after 2024-01-01T00:00:00Z, please check upstream dag"
+            in caplog.text
+        )
+        assert "task_status=running" in caplog.text
+        assert "task_status=failed" in caplog.text
+        assert "task_status=success" in caplog.text
+        assert "Poking.. 4 round" in caplog.text
+
+    def test_non_200(self, sensor):
+        rm = RequestsMocker()
+        rm.get(
+            f"{BASE_URL}/api/v1/dags/test-dag/dagRuns?execution_date_gte=2024-01-01T00:00:00Z",
+            status_code=404,
+        )
+
+        with pytest.raises(HTTPError):
+            with rm:
+                sensor.execute(context={"execution_date": "2024-01-01T03:00:00Z"})
+
+    def test_timeout(self, sensor):
+        sensor.timeout = 1
+
+        rm = RequestsMocker()
+        rm.get(
+            url=f"{BASE_URL}/api/v1/dags/test-dag/dagRuns?execution_date_gte=2024-01-01T00:00:00Z",
+            status_code=200,
+            json={"dag_runs": [], "total_entries": 0},
+        )
+
+        with pytest.raises(AirflowSensorTimeout):
+            with rm:
+                sensor.execute(context={"execution_date": "2024-01-01T03:00:00Z"})
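Note for reviewers: a minimal sketch (not part of the patch) of what the new execution-window computation does. It uses dateutil's parse purely for illustration and the same values the tests above assert against.

# Minimal sketch (not part of the patch): the dagRuns query window is now derived
# from the task context instead of datetime.now(). dateutil's parse is used here
# only for illustration; the values mirror the test fixtures above.
from datetime import timedelta

from dateutil.parser import parse

execution_date = parse("2024-01-01T03:00:00Z")  # context["execution_date"]
execution_delta = timedelta(hours=-3)           # sensor argument; negative = look back
execution_window_tz = (execution_date + execution_delta).strftime("%Y-%m-%dT%H:%M:%SZ")

# The sensor then polls .../dagRuns?execution_date_gte=<window> until the upstream
# task reaches an allowed state, or raises AirflowSensorTimeout once self.timeout elapses.
assert execution_window_tz == "2024-01-01T00:00:00Z"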