From fcbf6fe926435c8bed5d7f0f0becbd4540613c10 Mon Sep 17 00:00:00 2001
From: Shahar Epstein <60007259+shahar1@users.noreply.github.com>
Date: Wed, 2 Oct 2024 20:44:29 +0300
Subject: [PATCH] [BACKPORT] Add retry logic in the scheduler for updating
 trigger timeouts in case of deadlocks. (#41429) (#42651)

* Add retry logic in the scheduler for updating trigger timeouts in case of
  deadlocks. (#41429)

* Add retry in update trigger timeout

* add ut for these cases

* use OperationalError in ut to describe deadlock scenarios

* [MINOR] add newsfragment for this PR

* [MINOR] refactor UT for mypy check

(cherry picked from commit 00589cf8fe8faffa6f994e9b85717cb2babc1631)

* Fix type-ignore comment for typing changes (#42656)

---------

Co-authored-by: TakawaAkirayo <153728772+TakawaAkirayo@users.noreply.github.com>
Co-authored-by: Tzu-ping Chung
---
 airflow/jobs/scheduler_job_runner.py | 36 +++++++------
 newsfragments/41429.improvement.rst  |  1 +
 tests/jobs/test_scheduler_job.py     | 80 +++++++++++++++++++++++++++-
 3 files changed, 99 insertions(+), 18 deletions(-)
 create mode 100644 newsfragments/41429.improvement.rst

diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index cb82ec8c50d2..3d2872081082 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -1923,23 +1923,27 @@ def adopt_or_reset_orphaned_tasks(self, session: Session = NEW_SESSION) -> int:
         return len(to_reset)
 
     @provide_session
-    def check_trigger_timeouts(self, session: Session = NEW_SESSION) -> None:
+    def check_trigger_timeouts(
+        self, max_retries: int = MAX_DB_RETRIES, session: Session = NEW_SESSION
+    ) -> None:
         """Mark any "deferred" task as failed if the trigger or execution timeout has passed."""
-        num_timed_out_tasks = session.execute(
-            update(TI)
-            .where(
-                TI.state == TaskInstanceState.DEFERRED,
-                TI.trigger_timeout < timezone.utcnow(),
-            )
-            .values(
-                state=TaskInstanceState.SCHEDULED,
-                next_method="__fail__",
-                next_kwargs={"error": "Trigger/execution timeout"},
-                trigger_id=None,
-            )
-        ).rowcount
-        if num_timed_out_tasks:
-            self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks)
+        for attempt in run_with_db_retries(max_retries, logger=self.log):
+            with attempt:
+                num_timed_out_tasks = session.execute(
+                    update(TI)
+                    .where(
+                        TI.state == TaskInstanceState.DEFERRED,
+                        TI.trigger_timeout < timezone.utcnow(),
+                    )
+                    .values(
+                        state=TaskInstanceState.SCHEDULED,
+                        next_method="__fail__",
+                        next_kwargs={"error": "Trigger/execution timeout"},
+                        trigger_id=None,
+                    )
+                ).rowcount
+                if num_timed_out_tasks:
+                    self.log.info("Timed out %i deferred tasks without fired triggers", num_timed_out_tasks)
 
     # [START find_zombies]
     def _find_zombies(self) -> None:
diff --git a/newsfragments/41429.improvement.rst b/newsfragments/41429.improvement.rst
new file mode 100644
index 000000000000..6d04d5dfe61a
--- /dev/null
+++ b/newsfragments/41429.improvement.rst
@@ -0,0 +1 @@
+Add ``run_with_db_retries`` when the scheduler updates the deferred Task as failed to tolerate database deadlock issues.
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index 2e96728d5eca..7e06ecd0d01f 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -144,7 +144,7 @@ def clean_db():
     @pytest.fixture(autouse=True)
     def per_test(self) -> Generator:
         self.clean_db()
-        self.job_runner = None
+        self.job_runner: SchedulerJobRunner | None = None
 
         yield
 
@@ -5227,6 +5227,82 @@ def test_timeout_triggers(self, dag_maker):
         assert ti1.next_method == "__fail__"
         assert ti2.state == State.DEFERRED
 
+    def test_retry_on_db_error_when_update_timeout_triggers(self, dag_maker):
+        """
+        Tests that it will retry on DB error like deadlock when updating timeout triggers.
+        """
+        from sqlalchemy.exc import OperationalError
+
+        retry_times = 3
+
+        session = settings.Session()
+        # Create the test DAG and task
+        with dag_maker(
+            dag_id="test_retry_on_db_error_when_update_timeout_triggers",
+            start_date=DEFAULT_DATE,
+            schedule="@once",
+            max_active_runs=1,
+            session=session,
+        ):
+            EmptyOperator(task_id="dummy1")
+
+        # Mock the db failure within retry times
+        might_fail_session = MagicMock(wraps=session)
+
+        def check_if_trigger_timeout(max_retries: int):
+            def make_side_effect():
+                call_count = 0
+
+                def side_effect(*args, **kwargs):
+                    nonlocal call_count
+                    if call_count < retry_times - 1:
+                        call_count += 1
+                        raise OperationalError("any_statement", "any_params", "any_orig")
+                    else:
+                        return session.execute(*args, **kwargs)
+
+                return side_effect
+
+            might_fail_session.execute.side_effect = make_side_effect()
+
+            try:
+                # Create a Task Instance for the task that is allegedly deferred
+                # but past its timeout, and one that is still good.
+                # We don't actually need a linked trigger here; the code doesn't check.
+                dr1 = dag_maker.create_dagrun()
+                dr2 = dag_maker.create_dagrun(
+                    run_id="test2", execution_date=DEFAULT_DATE + datetime.timedelta(seconds=1)
+                )
+                ti1 = dr1.get_task_instance("dummy1", session)
+                ti2 = dr2.get_task_instance("dummy1", session)
+                ti1.state = State.DEFERRED
+                ti1.trigger_timeout = timezone.utcnow() - datetime.timedelta(seconds=60)
+                ti2.state = State.DEFERRED
+                ti2.trigger_timeout = timezone.utcnow() + datetime.timedelta(seconds=60)
+                session.flush()
+
+                # Boot up the scheduler and make it check timeouts
+                scheduler_job = Job()
+                self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)
+
+                self.job_runner.check_trigger_timeouts(max_retries=max_retries, session=might_fail_session)
+
+                # Make sure that TI1 is now scheduled to fail, and 2 wasn't touched
+                session.refresh(ti1)
+                session.refresh(ti2)
+                assert ti1.state == State.SCHEDULED
+                assert ti1.next_method == "__fail__"
+                assert ti2.state == State.DEFERRED
+            finally:
+                self.clean_db()
+
+        # Positive case, will retry until success before reach max retry times
+        check_if_trigger_timeout(retry_times)
+
+        # Negative case: no retries, execute only once.
+        with pytest.raises(OperationalError):
+            check_if_trigger_timeout(1)
+
     def test_find_zombies_nothing(self):
         executor = MockExecutor(do_update=False)
         scheduler_job = Job(executor=executor)
@@ -5504,7 +5580,7 @@ def spy(*args, **kwargs):
         def watch_set_state(dr: DagRun, state, **kwargs):
             if state in (DagRunState.SUCCESS, DagRunState.FAILED):
                 # Stop the scheduler
-                self.job_runner.num_runs = 1  # type: ignore[attr-defined]
+                self.job_runner.num_runs = 1  # type: ignore[union-attr]
             orig_set_state(dr, state, **kwargs)  # type: ignore[call-arg]
 
         def watch_heartbeat(*args, **kwargs):
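
For readers unfamiliar with ``run_with_db_retries`` (from ``airflow.utils.retries``): it is a thin
wrapper around ``tenacity``, and iterating it yields attempt context managers, so the wrapped
database statement is simply re-executed when a transient error such as a deadlock is raised. Below
is a minimal, self-contained sketch of that same pattern written against ``tenacity`` directly; the
function name ``run_update_with_deadlock_retries`` and its parameters are illustrative only and are
not part of this patch.

    # Illustrative sketch: mirrors the retry-on-deadlock pattern that
    # check_trigger_timeouts() uses in this patch; this is not Airflow code.
    import logging

    from sqlalchemy.exc import OperationalError
    from tenacity import (
        Retrying,
        before_sleep_log,
        retry_if_exception_type,
        stop_after_attempt,
        wait_exponential,
    )

    log = logging.getLogger(__name__)


    def run_update_with_deadlock_retries(do_update, max_retries: int = 3):
        """Run ``do_update()`` and retry it when the database raises OperationalError."""
        # Retrying() yields attempt context managers; a body that raises
        # OperationalError is retried up to ``max_retries`` attempts, and the
        # last exception is re-raised once attempts are exhausted (reraise=True).
        for attempt in Retrying(
            stop=stop_after_attempt(max_retries),
            retry=retry_if_exception_type(OperationalError),
            wait=wait_exponential(multiplier=0.5, max=5),
            before_sleep=before_sleep_log(log, logging.DEBUG),
            reraise=True,
        ):
            with attempt:
                return do_update()

With ``max_retries=1`` there is a single attempt and effectively no retry, which is the behaviour
the negative case in the test hunk above exercises via ``check_if_trigger_timeout(1)``.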