[BACKPORT] Add retry logic in the scheduler for updating trigger timeouts in case of deadlocks. (apache#41429) (apache#42651)

* Add retry logic in the scheduler for updating trigger timeouts in case of deadlocks. (apache#41429)

* Add retry in update trigger timeout

* add ut for these cases

* use OperationalError in ut to describe deadlock scenarios

* [MINOR] add newsfragment for this PR

* [MINOR] refactor UT for mypy check

(cherry picked from commit 00589cf)

* Fix type-ignore comment for typing changes (apache#42656)

---------

Co-authored-by: TakawaAkirayo <[email protected]>
Co-authored-by: Tzu-ping Chung <[email protected]>
3 people authored Oct 2, 2024
1 parent f7556c4 commit fcbf6fe
Showing 3 changed files with 99 additions and 18 deletions.
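
As context for the diff below: the commit replaces the scheduler's single UPDATE with a loop over run_with_db_retries (airflow.utils.retries), which returns a tenacity Retrying object. The sketch below reproduces that retry-on-deadlock pattern with plain tenacity; the helper function, session, and exact backoff settings are illustrative assumptions, not the scheduler's actual configuration (the real defaults live in airflow/utils/retries.py).

from sqlalchemy.exc import OperationalError
from tenacity import Retrying, retry_if_exception_type, stop_after_attempt, wait_random_exponential


def update_with_retries(session, statement, max_retries: int = 3) -> int:
    """Run a DB update, retrying when the database raises a transient error such as a deadlock.

    ``session`` is anything with an ``execute()`` method (e.g. a SQLAlchemy Session) and
    ``statement`` is the UPDATE to run; both are placeholders for this sketch.
    """
    for attempt in Retrying(
        retry=retry_if_exception_type(OperationalError),  # deadlocks surface as OperationalError
        wait=wait_random_exponential(multiplier=0.5, max=5),  # randomized exponential backoff
        stop=stop_after_attempt(max_retries),
        reraise=True,  # after the final attempt, re-raise the original error
    ):
        with attempt:
            # Each iteration is one attempt; the context manager records the
            # outcome and lets Retrying decide whether to try again.
            return session.execute(statement).rowcount
    return 0  # not reached: the loop either returns or re-raises

In the scheduler change below, the same loop wraps the UPDATE that fails timed-out deferred task instances, with max_retries defaulting to MAX_DB_RETRIES.
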
36 changes: 20 additions & 16 deletions airflow/jobs/scheduler_job_runner.py
@@ -1923,23 +1923,27 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int:
         return len(to_reset)
 
     @provide_session
-    def check_trigger_timeouts(self, session: Session = NEW_SESSION) -> None:
+    def check_trigger_timeouts(
+        self, max_retries: int = MAX_DB_RETRIES, session: Session = NEW_SESSION
+    ) -> None:
         """Mark any "deferred" task as failed if the trigger or execution timeout has passed."""
-        num_timed_out_tasks = session.execute(
-            update(TI)
-            .where(
-                TI.state == TaskInstanceState.DEFERRED,
-                TI.trigger_timeout < timezone.utcnow(),
-            )
-            .values(
-                state=TaskInstanceState.SCHEDULED,
-                next_method="__fail__",
-                next_kwargs={"error": "Trigger/execution timeout"},
-                trigger_id=None,
-            )
-        ).rowcount
-        if num_timed_out_tasks:
-            self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks)
+        for attempt in run_with_db_retries(max_retries, logger=self.log):
+            with attempt:
+                num_timed_out_tasks = session.execute(
+                    update(TI)
+                    .where(
+                        TI.state == TaskInstanceState.DEFERRED,
+                        TI.trigger_timeout < timezone.utcnow(),
+                    )
+                    .values(
+                        state=TaskInstanceState.SCHEDULED,
+                        next_method="__fail__",
+                        next_kwargs={"error": "Trigger/execution timeout"},
+                        trigger_id=None,
+                    )
+                ).rowcount
+                if num_timed_out_tasks:
+                    self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks)
 
     # [START find_zombies]
     def _find_zombies(self) -> None:
1 change: 1 addition & 0 deletions newsfragments/41429.improvement.rst
@@ -0,0 +1 @@
+Add ``run_with_db_retries`` when the scheduler marks a timed-out deferred task as failed, to tolerate database deadlocks.
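
The deadlock handling described above is exercised by the new test in tests/jobs/test_scheduler_job.py below, which wraps the real session in a MagicMock and makes execute() raise OperationalError for the first attempts before delegating to the real object. A stripped-down, Airflow-independent sketch of that mocking pattern (class and function names here are illustrative, not from the codebase):

from unittest.mock import MagicMock

from sqlalchemy.exc import OperationalError


class RealSession:
    """Stand-in for a SQLAlchemy session; only execute() matters for this sketch."""

    def execute(self, statement):
        return f"executed: {statement}"


def flaky_session(real, failures: int):
    """Wrap ``real`` so the first ``failures`` calls to execute() raise OperationalError."""
    wrapper = MagicMock(wraps=real)
    call_count = 0

    def side_effect(*args, **kwargs):
        nonlocal call_count
        if call_count < failures:
            call_count += 1
            # SQLAlchemy's OperationalError takes (statement, params, orig)
            raise OperationalError("any_statement", "any_params", "any_orig")
        return real.execute(*args, **kwargs)

    wrapper.execute.side_effect = side_effect
    return wrapper


if __name__ == "__main__":
    session = flaky_session(RealSession(), failures=2)
    for i in range(3):
        try:
            print(session.execute("UPDATE task_instance SET state = 'scheduled'"))
        except OperationalError:
            print(f"attempt {i}: simulated deadlock")
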
80 changes: 78 additions & 2 deletions tests/jobs/test_scheduler_job.py
@@ -144,7 +144,7 @@ def clean_db():
     @pytest.fixture(autouse=True)
     def per_test(self) -> Generator:
         self.clean_db()
-        self.job_runner = None
+        self.job_runner: SchedulerJobRunner | None = None
 
         yield
 
@@ -5227,6 +5227,82 @@ def test_timeout_triggers(self, dag_maker):
         assert ti1.next_method == "__fail__"
         assert ti2.state == State.DEFERRED
 
+    def test_retry_on_db_error_when_update_timeout_triggers(self, dag_maker):
+        """
+        Test that updating timeout triggers is retried on DB errors such as deadlocks.
+        """
+        from sqlalchemy.exc import OperationalError
+
+        retry_times = 3
+
+        session = settings.Session()
+        # Create the test DAG and task
+        with dag_maker(
+            dag_id="test_retry_on_db_error_when_update_timeout_triggers",
+            start_date=DEFAULT_DATE,
+            schedule="@once",
+            max_active_runs=1,
+            session=session,
+        ):
+            EmptyOperator(task_id="dummy1")
+
+        # Make the first few DB calls fail, staying within the retry budget
+        might_fail_session = MagicMock(wraps=session)
+
+        def check_if_trigger_timeout(max_retries: int):
+            def make_side_effect():
+                call_count = 0
+
+                def side_effect(*args, **kwargs):
+                    nonlocal call_count
+                    if call_count < retry_times - 1:
+                        call_count += 1
+                        raise OperationalError("any_statement", "any_params", "any_orig")
+                    else:
+                        return session.execute(*args, **kwargs)
+
+                return side_effect
+
+            might_fail_session.execute.side_effect = make_side_effect()
+
+            try:
+                # Create a Task Instance for the task that is allegedly deferred
+                # but past its timeout, and one that is still good.
+                # We don't actually need a linked trigger here; the code doesn't check.
+                dr1 = dag_maker.create_dagrun()
+                dr2 = dag_maker.create_dagrun(
+                    run_id="test2", execution_date=DEFAULT_DATE + datetime.timedelta(seconds=1)
+                )
+                ti1 = dr1.get_task_instance("dummy1", session)
+                ti2 = dr2.get_task_instance("dummy1", session)
+                ti1.state = State.DEFERRED
+                ti1.trigger_timeout = timezone.utcnow() - datetime.timedelta(seconds=60)
+                ti2.state = State.DEFERRED
+                ti2.trigger_timeout = timezone.utcnow() + datetime.timedelta(seconds=60)
+                session.flush()
+
+                # Boot up the scheduler and make it check timeouts
+                scheduler_job = Job()
+                self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)
+
+                self.job_runner.check_trigger_timeouts(max_retries=max_retries, session=might_fail_session)
+
+                # Make sure that TI1 is now scheduled to fail, and 2 wasn't touched
+                session.refresh(ti1)
+                session.refresh(ti2)
+                assert ti1.state == State.SCHEDULED
+                assert ti1.next_method == "__fail__"
+                assert ti2.state == State.DEFERRED
+            finally:
+                self.clean_db()
+
+        # Positive case: the update fails twice, then succeeds within the max retry count
+        check_if_trigger_timeout(retry_times)
+
+        # Negative case: no retries, the statement executes only once and the error propagates
+        with pytest.raises(OperationalError):
+            check_if_trigger_timeout(1)
+
     def test_find_zombies_nothing(self):
         executor = MockExecutor(do_update=False)
         scheduler_job = Job(executor=executor)
@@ -5504,7 +5580,7 @@ def spy(*args, **kwargs):
         def watch_set_state(dr: DagRun, state, **kwargs):
             if state in (DagRunState.SUCCESS, DagRunState.FAILED):
                 # Stop the scheduler
-                self.job_runner.num_runs = 1  # type: ignore[attr-defined]
+                self.job_runner.num_runs = 1  # type: ignore[union-attr]
             orig_set_state(dr, state, **kwargs)  # type: ignore[call-arg]
 
         def watch_heartbeat(*args, **kwargs):