Skip to content

Commit

Permalink
Add retry logic in the scheduler for updating trigger timeouts in case of deadlocks. (apache#41429)
Browse files Browse the repository at this point in the history

* Add retry in update trigger timeout

* add ut for these cases

* use OperationalError in ut to describe deadlock scenarios

* [MINOR] add newsfragment for this PR

* [MINOR] refactor UT for mypy check
  • Loading branch information
TakawaAkirayo authored Oct 2, 2024
1 parent 39d207b commit 00589cf
Show file tree
Hide file tree
Showing 3 changed files with 98 additions and 17 deletions.
36 changes: 20 additions & 16 deletions airflow/jobs/scheduler_job_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1884,23 +1884,27 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int:
return len(to_reset)

@provide_session
def check_trigger_timeouts(self, session: Session = NEW_SESSION) -> None:
def check_trigger_timeouts(
self, max_retries: int = MAX_DB_RETRIES, session: Session = NEW_SESSION
) -> None:
"""Mark any "deferred" task as failed if the trigger or execution timeout has passed."""
num_timed_out_tasks = session.execute(
update(TI)
.where(
TI.state == TaskInstanceState.DEFERRED,
TI.trigger_timeout < timezone.utcnow(),
)
.values(
state=TaskInstanceState.SCHEDULED,
next_method="__fail__",
next_kwargs={"error": "Trigger/execution timeout"},
trigger_id=None,
)
).rowcount
if num_timed_out_tasks:
self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks)
for attempt in run_with_db_retries(max_retries, logger=self.log):
with attempt:
num_timed_out_tasks = session.execute(
update(TI)
.where(
TI.state == TaskInstanceState.DEFERRED,
TI.trigger_timeout < timezone.utcnow(),
)
.values(
state=TaskInstanceState.SCHEDULED,
next_method="__fail__",
next_kwargs={"error": "Trigger/execution timeout"},
trigger_id=None,
)
).rowcount
if num_timed_out_tasks:
self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks)

# [START find_zombies]
def _find_zombies(self) -> None:
Expand Down
1 change: 1 addition & 0 deletions newsfragments/41429.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Use ``run_with_db_retries`` when the scheduler marks timed-out deferred tasks as failed, to tolerate transient database deadlocks.
78 changes: 77 additions & 1 deletion tests/jobs/test_scheduler_job.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def clean_db():
@pytest.fixture(autouse=True)
def per_test(self) -> Generator:
self.clean_db()
self.job_runner = None
self.job_runner: SchedulerJobRunner | None = None

yield

Expand Down Expand Up @@ -5192,6 +5192,82 @@ def test_timeout_triggers(self, dag_maker):
assert ti1.next_method == "__fail__"
assert ti2.state == State.DEFERRED

    def test_retry_on_db_error_when_update_timeout_triggers(self, dag_maker):
        """
        Tests that it will retry on DB error like deadlock when updating timeout triggers.
        """
        # Import locally so the OperationalError used to simulate a deadlock does
        # not leak into the module namespace of other tests.
        from sqlalchemy.exc import OperationalError

        # The mocked session fails (retry_times - 1) times, then succeeds, so a
        # scheduler configured with max_retries >= retry_times must end up succeeding.
        retry_times = 3

        session = settings.Session()
        # Create the test DAG and task
        with dag_maker(
            dag_id="test_retry_on_db_error_when_update_timeout_triggers",
            start_date=DEFAULT_DATE,
            schedule="@once",
            max_active_runs=1,
            session=session,
        ):
            EmptyOperator(task_id="dummy1")

        # Mock the db failure within retry times
        # wraps=session keeps all other Session behavior intact; only .execute
        # gets the failing side effect installed below.
        might_fail_session = MagicMock(wraps=session)

        def check_if_trigger_timeout(max_retries: int):
            # Runs one full scenario: install a fresh flaky side effect, build the
            # fixture data, and invoke check_trigger_timeouts with `max_retries`.
            def make_side_effect():
                # Closure counter so each scenario starts from zero failures.
                call_count = 0

                def side_effect(*args, **kwargs):
                    nonlocal call_count
                    if call_count < retry_times - 1:
                        # Simulate a deadlock-style DB error on the first
                        # (retry_times - 1) calls.
                        call_count += 1
                        raise OperationalError("any_statement", "any_params", "any_orig")
                    else:
                        # After the simulated failures, delegate to the real session.
                        return session.execute(*args, **kwargs)

                return side_effect

            might_fail_session.execute.side_effect = make_side_effect()

            try:
                # Create a Task Instance for the task that is allegedly deferred
                # but past its timeout, and one that is still good.
                # We don't actually need a linked trigger here; the code doesn't check.
                dr1 = dag_maker.create_dagrun()
                dr2 = dag_maker.create_dagrun(
                    run_id="test2", execution_date=DEFAULT_DATE + datetime.timedelta(seconds=1)
                )
                ti1 = dr1.get_task_instance("dummy1", session)
                ti2 = dr2.get_task_instance("dummy1", session)
                # ti1 timed out 60s ago; ti2's timeout is 60s in the future.
                ti1.state = State.DEFERRED
                ti1.trigger_timeout = timezone.utcnow() - datetime.timedelta(seconds=60)
                ti2.state = State.DEFERRED
                ti2.trigger_timeout = timezone.utcnow() + datetime.timedelta(seconds=60)
                session.flush()

                # Boot up the scheduler and make it check timeouts
                scheduler_job = Job()
                self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)

                self.job_runner.check_trigger_timeouts(max_retries=max_retries, session=might_fail_session)

                # Make sure that TI1 is now scheduled to fail, and 2 wasn't touched
                session.refresh(ti1)
                session.refresh(ti2)
                assert ti1.state == State.SCHEDULED
                assert ti1.next_method == "__fail__"
                assert ti2.state == State.DEFERRED
            finally:
                # Reset DB state so the second scenario below starts clean.
                self.clean_db()

        # Positive case, will retry until success before reach max retry times
        check_if_trigger_timeout(retry_times)

        # Negative case: no retries, execute only once.
        # With max_retries=1 the single attempt hits the simulated deadlock and
        # the OperationalError propagates to the caller.
        with pytest.raises(OperationalError):
            check_if_trigger_timeout(1)

def test_find_zombies_nothing(self):
executor = MockExecutor(do_update=False)
scheduler_job = Job(executor=executor)
Expand Down

0 comments on commit 00589cf

Please sign in to comment.