Commit

TaskDependencySensor: do not fail if upstream failed / derive execution time from context (#132)

* sensor improvements

* timeout is handled natively

* tests

---------

Co-authored-by: Maxim Mityutko <[email protected]>
maxim-mityutko and Maxim Mityutko authored Jun 6, 2024
1 parent 8986064 commit 06d7a62
Showing 2 changed files with 162 additions and 16 deletions.
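For orientation, a hedged usage sketch of the changed behavior (DAG/task names, auth object, and timing values are illustrative, not taken from this commit): the sensor now keeps poking after an upstream failure and only fails once its timeout elapses.

from datetime import timedelta
from brickflow_plugins.airflow.operators.external_tasks import TaskDependencySensor

sensor = TaskDependencySensor(
    task_id="wait_for_upstream",
    external_dag_id="upstream-dag",
    external_task_id="upstream-task",
    allowed_states=["success"],
    execution_delta=timedelta(hours=-3),  # look back 3 hours from this run's execution date
    airflow_auth=airflow_auth,  # assumed: a configured AirflowProxyOktaClusterAuth instance
    poke_interval=60,
    timeout=3600,  # standard BaseSensorOperator timeout, now enforced inside execute()
)
sensor.execute(context={"execution_date": "2024-01-01T03:00:00Z"})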
48 changes: 32 additions & 16 deletions brickflow_plugins/airflow/operators/external_tasks.py
@@ -7,6 +7,7 @@
 import requests
 from airflow.models import Connection
 from airflow.sensors.base import BaseSensorOperator
+from airflow.exceptions import AirflowSensorTimeout
 from requests.adapters import HTTPAdapter
 from requests.packages.urllib3.util.retry import Retry
 from requests import HTTPError
@@ -217,8 +218,9 @@ def __init__(
         self.latest = latest
         self.poke_interval = poke_interval
         self._poke_count = 0
+        self._start_time = time.time()

-    def get_execution_stats(self):
+    def get_execution_stats(self, execution_date: datetime):
         """Function to get the execution stats for task_id within an execution delta window

         Returns:
@@ -231,7 +233,7 @@ def get_execution_stats(self):
         external_dag_id = self.external_dag_id
         external_task_id = self.external_task_id
         execution_delta = self.execution_delta
-        execution_window_tz = (datetime.now() + execution_delta).strftime(
+        execution_window_tz = (execution_date + execution_delta).strftime(
             "%Y-%m-%dT%H:%M:%SZ"
         )
         headers = {
@@ -257,17 +259,20 @@ def get_execution_stats(self):
             + f"/dagRuns?execution_date_gte={execution_window_tz}"
         )
         log.info(f"URL to poke for dag runs {url}")
-        response = requests.request("GET", url, headers=headers)
-        if response.status_code == 401:
-            raise Exception(
-                f"No Runs found for {external_dag_id} dag after {execution_window_tz}, Please check upstream dag"
-            )
-        list_of_dictionaries = response.json()
+        response = requests.request("GET", url, headers=headers, verify=False)
+        response.raise_for_status()
+
+        list_of_dictionaries = response.json()["dag_runs"]
         list_of_dictionaries = sorted(
             list_of_dictionaries, key=lambda k: k["execution_date"], reverse=True
         )

+        if len(list_of_dictionaries) == 0:
+            log.info(
+                f"No Runs found for {external_dag_id} dag after {execution_window_tz}, please check upstream dag"
+            )
+            return "none"
+
         if af_version.startswith("1."):
             # For airflow 1.X Execution date is needed to check the status of the task
             dag_run_id = list_of_dictionaries[0]["execution_date"]
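The 401 special case above is replaced by generic HTTP error handling. A minimal standalone sketch of the resulting flow, assuming bearer-token headers and the Airflow v2 REST API shape exercised by the tests below (base URL and token are placeholders):

import requests

def latest_run_state(base_url: str, token: str, dag_id: str, window_start: str) -> str:
    # window_start e.g. "2024-01-01T00:00:00Z"
    url = f"{base_url}/api/v1/dags/{dag_id}/dagRuns?execution_date_gte={window_start}"
    response = requests.get(url, headers={"Authorization": f"Bearer {token}"})
    response.raise_for_status()  # any 4xx/5xx -> requests.HTTPError (see test_non_200)
    runs = sorted(
        response.json()["dag_runs"], key=lambda k: k["execution_date"], reverse=True
    )
    if not runs:
        return "none"  # sentinel: no runs in the window yet, caller keeps poking
    return runs[0]["state"]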
@@ -294,17 +299,24 @@ def get_execution_stats(self):
             + f"/dagRuns/{dag_run_id}/taskInstances/{external_task_id}"
         )
         log.info(f"Pinging airflow API {task_url} for task status ")
-        task_response = requests.request("GET", task_url, headers=headers)
+        task_response = requests.request("GET", task_url, headers=headers, verify=False)
+        task_response.raise_for_status()
         task_state = task_response.json()["state"]
         return task_state

     def poke(self, context):
-        log.info(f"executing poke.. {self._poke_count}")
         self._poke_count = self._poke_count + 1
-        logging.info("Poking.. {0} round".format(str(self._poke_count)))
-        task_status = self.get_execution_stats()
-        log.info(f"task_status= {task_status}")
+        log.info("Poking.. {0} round".format(str(self._poke_count)))
+
+        # Execution date is extracted from context and will be based on the task schedule, e.g.
+        # 0 0 1 ? * MON-SAT * -> 2024-01-01T01:00:00.000000+00:00
+        # This means that the relative delta between workflow execution and target Airflow DAG always stays the same.
+        execution_date = parse(context["execution_date"])
+        log.info(f"Execution date derived from context: {execution_date}")
+
+        task_status = self.get_execution_stats(execution_date=execution_date)
+        log.info(f"task_status={task_status}")
         return task_status

     def execute(self, context):
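A worked example of the date arithmetic above, assuming parse is dateutil's parser (values mirror the tests below): deriving the window from the context's execution_date keeps the lookback fixed per run, whereas the old datetime.now() drifted with every poke.

from datetime import timedelta
from dateutil.parser import parse

execution_date = parse("2024-01-01T03:00:00Z")  # from the task context
execution_delta = timedelta(hours=-3)
window = (execution_date + execution_delta).strftime("%Y-%m-%dT%H:%M:%SZ")
assert window == "2024-01-01T00:00:00Z"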
@@ -321,22 +333,26 @@ def execute(self, context):
         external_dag_id = self.external_dag_id
         external_task_id = self.external_task_id
         execution_delta = self.execution_delta
-        execution_window_tz = (datetime.now() + execution_delta).strftime(
-            "%Y-%m-%dT%H:%M:%SZ"
-        )
+        execution_window_tz = (
+            parse(context["execution_date"]) + execution_delta
+        ).strftime("%Y-%m-%dT%H:%M:%SZ")
         log.info(
             f"Executing TaskDependency Sensor Operator to check successful run for {external_dag_id} dag, task {external_task_id} after {execution_window_tz} "
         )
         status = ""
         while status not in allowed_states:
             status = self.poke(context)
             if status == "failed":
+                # Log the fact that upstream failed, however do not fail the task and continue poking until timeout
                 log.error(
                     f"Upstream dag {external_dag_id} failed at {external_task_id} task "
                 )
-                raise Exception("Upstream Dag Failed")
-            time.sleep(self.poke_interval)
+                time.sleep(self.poke_interval)
+            elif status != "success":
+                time.sleep(self.poke_interval)
+
+            if (time.time() - self._start_time) > self.timeout:
+                raise AirflowSensorTimeout("The job has timed out")
         log.info(f"Upstream Dag {external_dag_id} is successful")


130 changes: 130 additions & 0 deletions tests/airflow_plugins/test_task_dependency.py
@@ -0,0 +1,130 @@
from datetime import timedelta

import pytest
from requests.exceptions import HTTPError
from requests_mock.mocker import Mocker as RequestsMocker

from brickflow_plugins.airflow.operators.external_tasks import (
    TaskDependencySensor,
    AirflowProxyOktaClusterAuth,
    log,
    AirflowSensorTimeout,
)

BASE_URL = "https://42.airflow.my-org.com/foo"


class TestTaskDependencySensor:
    log.propagate = True

    @pytest.fixture(autouse=True, name="api", scope="class")
    def mock_api(self):
        """
        End-to-end test scenario for Airflow v2 API
        """
        rm = RequestsMocker()
        # DAG Run Endpoint
        rm.register_uri(
            method="GET",
            url=f"{BASE_URL}/api/v1/dags/test-dag/dagRuns?execution_date_gte=2024-01-01T00:00:00Z",
            response_list=[
                # Test 1: No Run
                {"json": {"dag_runs": [], "total_entries": 0}, "status_code": int(200)},
                # Test 2: Run Exists
                {
                    "json": {
                        "dag_runs": [
                            {
                                "conf": {},
                                "dag_id": "test-dag",
                                "dag_run_id": "manual__2024-01-01T01:00:00.000000+00:00",
                                "end_date": "2024-01-01T01:10:00.000000+00:00",
                                "execution_date": "2024-01-01T01:00:00.000000+00:00",
                                "external_trigger": True,
                                "logical_date": "2024-01-01T01:00:00.000000+00:00",
                                "start_date": "2024-01-01T01:00:00.000000+00:00",
                                "state": "success",
                            },
                        ],
                        "total_entries": 1,
                    },
                    "status_code": int(200),
                },
                # Test 2: No Run
            ],
        )
        # Task Instance Endpoint
        rm.register_uri(
            method="GET",
            url=(
                f"{BASE_URL}"
                f"/api/v1/dags/test-dag/dagRuns/manual__2024-01-01T01:00:00.000000+00:00/taskInstances/test-task"
            ),
            response_list=[
                {"json": {"state": "running"}, "status_code": int(200)},
                {"json": {"state": "failed"}, "status_code": int(200)},
                {"json": {"state": "success"}, "status_code": int(200)},
            ],
        )
        yield rm

    @pytest.fixture()
    def sensor(self, mocker):
        auth_mock = mocker.MagicMock(spec=AirflowProxyOktaClusterAuth)
        auth_mock.get_access_token.return_value = "foo"
        auth_mock.get_airflow_api_url.return_value = BASE_URL
        auth_mock.get_version.return_value = "2.0.2"

        yield TaskDependencySensor(
            external_dag_id="test-dag",
            external_task_id="test-task",
            allowed_states=["success"],
            execution_delta=timedelta(**{"hours": -3}),
            airflow_auth=auth_mock,
            task_id="foo",
            poke_interval=1,
        )

    def test_api_airflow_v2(self, api, caplog, sensor):
        # Scenario
        # Airflow API poked 4 times
        # 1. No Run
        # 2. Run Exists - Task is in Running State
        # 3. Run Exists - Task is in Failed State
        # 4. Run Exists - Task is in Success State
        with api:
            sensor.execute(context={"execution_date": "2024-01-01T03:00:00Z"})

        assert (
            "No Runs found for test-dag dag after 2024-01-01T00:00:00Z, please check upstream dag"
            in caplog.text
        )
        assert "task_status=running" in caplog.text
        assert "task_status=failed" in caplog.text
        assert "task_status=success" in caplog.text
        assert "Poking.. 4 round" in caplog.text

    def test_non_200(self, sensor):
        rm = RequestsMocker()
        rm.get(
            f"{BASE_URL}/api/v1/dags/test-dag/dagRuns?execution_date_gte=2024-01-01T00:00:00Z",
            status_code=404,
        )

        with pytest.raises(HTTPError):
            with rm:
                sensor.execute(context={"execution_date": "2024-01-01T03:00:00Z"})

    def test_timeout(self, sensor):
        sensor.timeout = 1

        rm = RequestsMocker()
        rm.get(
            url=f"{BASE_URL}/api/v1/dags/test-dag/dagRuns?execution_date_gte=2024-01-01T00:00:00Z",
            status_code=200,
            json={"dag_runs": [], "total_entries": 0},
        )

        with pytest.raises(AirflowSensorTimeout):
            with rm:
                sensor.execute(context={"execution_date": "2024-01-01T03:00:00Z"})
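One requests_mock detail the class-scoped fixture relies on (library behavior, not part of this commit): a response_list is consumed in order and its last entry repeats once exhausted, which is why the fourth poke in test_api_airflow_v2 still sees the successful run. A self-contained illustration:

import requests
from requests_mock import Mocker

with Mocker() as rm:
    rm.get("http://test.invalid/ping", response_list=[{"json": 1}, {"json": 2}])
    assert [requests.get("http://test.invalid/ping").json() for _ in range(3)] == [1, 2, 2]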
