From e25ff2ea296b5284546a8ba49823f2e62cb2ace6 Mon Sep 17 00:00:00 2001 From: Wenjun Si <wenjun.swj@alibaba-inc.com> Date: Mon, 24 Jan 2022 19:33:08 +0800 Subject: [PATCH] TRY fix fault inject --- mars/deploy/oscar/tests/test_cmdline.py | 68 ++++---- mars/services/scheduling/api/oscar.py | 14 -- .../services/scheduling/supervisor/manager.py | 47 ++++-- .../services/scheduling/tests/test_service.py | 3 - mars/services/scheduling/worker/__init__.py | 2 +- .../worker/{exec => execution}/__init__.py | 1 + .../worker/{exec => execution}/actor.py | 146 ++++++++++++++---- .../worker/{exec => execution}/core.py | 15 +- .../worker/{exec => execution}/prepare.py | 0 .../{exec => execution}/tests/__init__.py | 0 .../{exec => execution}/tests/test_exec.py | 0 .../{exec => execution}/tests/test_prepare.py | 0 mars/services/scheduling/worker/service.py | 2 +- mars/services/task/supervisor/stage.py | 4 - .../services/tests/fault_injection_manager.py | 3 + mars/services/tests/fault_injection_patch.py | 70 ++++++++- 16 files changed, 265 insertions(+), 110 deletions(-) rename mars/services/scheduling/worker/{exec => execution}/__init__.py (94%) rename mars/services/scheduling/worker/{exec => execution}/actor.py (82%) rename mars/services/scheduling/worker/{exec => execution}/core.py (81%) rename mars/services/scheduling/worker/{exec => execution}/prepare.py (100%) rename mars/services/scheduling/worker/{exec => execution}/tests/__init__.py (100%) rename mars/services/scheduling/worker/{exec => execution}/tests/test_exec.py (100%) rename mars/services/scheduling/worker/{exec => execution}/tests/test_prepare.py (100%) diff --git a/mars/deploy/oscar/tests/test_cmdline.py b/mars/deploy/oscar/tests/test_cmdline.py index 836a5d3b01..e71b6054af 100644 --- a/mars/deploy/oscar/tests/test_cmdline.py +++ b/mars/deploy/oscar/tests/test_cmdline.py @@ -111,38 +111,6 @@ def _get_labelled_port(label=None, create=True): supervisor_cmd_start = [sys.executable, "-m", "mars.deploy.oscar.supervisor"] worker_cmd_start = [sys.executable, "-m", "mars.deploy.oscar.worker"] -start_params = { - "bare_start": [ - supervisor_cmd_start, - worker_cmd_start - + [ - "--config-file", - os.path.join(os.path.dirname(__file__), "local_test_config.yml"), - ], - False, - ], - "with_supervisors": [ - supervisor_cmd_start - + [ - "-e", - lambda: f'127.0.0.1:{_get_labelled_port("supervisor")}', - "-w", - lambda: str(_get_labelled_port("web")), - "--n-process", - "2", - ], - worker_cmd_start - + [ - "-e", - lambda: f"127.0.0.1:{get_next_port(occupy=True)}", - "-s", - lambda: f'127.0.0.1:{_get_labelled_port("supervisor")}', - "--config-file", - os.path.join(os.path.dirname(__file__), "local_test_config.yml"), - ], - True, - ], -} def _reload_args(args): @@ -159,8 +127,40 @@ def _reload_args(args): @pytest.mark.parametrize( "supervisor_args,worker_args,use_web_addr", - list(start_params.values()), - ids=list(start_params.keys()), + [ + pytest.param( + supervisor_cmd_start, + worker_cmd_start + + [ + "--config-file", + os.path.join(os.path.dirname(__file__), "local_test_config.yml"), + ], + False, + id="bare_start", + ), + pytest.param( + supervisor_cmd_start + + [ + "-e", + lambda: f'127.0.0.1:{_get_labelled_port("supervisor")}', + "-w", + lambda: str(_get_labelled_port("web")), + "--n-process", + "2", + ], + worker_cmd_start + + [ + "-e", + lambda: f"127.0.0.1:{get_next_port(occupy=True)}", + "-s", + lambda: f'127.0.0.1:{_get_labelled_port("supervisor")}', + "--config-file", + os.path.join(os.path.dirname(__file__), "local_test_config.yml"), + ], + True, + id="with_supervisors", + ), + ], ) @flaky(max_runs=10, rerun_filter=lambda err, *_: issubclass(err[0], _rerun_errors)) def test_cmdline_run(supervisor_args, worker_args, use_web_addr): diff --git a/mars/services/scheduling/api/oscar.py b/mars/services/scheduling/api/oscar.py index e2d5862d64..fb8c4f8b5c 100644 --- a/mars/services/scheduling/api/oscar.py +++ b/mars/services/scheduling/api/oscar.py @@ -117,20 +117,6 @@ async def cancel_subtasks( """ await self._manager_ref.cancel_subtasks(subtask_ids, kill_timeout=kill_timeout) - async def finish_subtasks(self, subtask_ids: List[str], schedule_next: bool = True): - """ - Mark subtasks as finished, letting scheduling service to schedule - next tasks in the ready queue - - Parameters - ---------- - subtask_ids - ids of subtasks to mark as finished - schedule_next - whether to schedule succeeding subtasks - """ - await self._manager_ref.finish_subtasks(subtask_ids, schedule_next) - class MockSchedulingAPI(SchedulingAPI): @classmethod diff --git a/mars/services/scheduling/supervisor/manager.py b/mars/services/scheduling/supervisor/manager.py index 149b5fd035..758e91debd 100644 --- a/mars/services/scheduling/supervisor/manager.py +++ b/mars/services/scheduling/supervisor/manager.py @@ -95,8 +95,7 @@ async def __post_create__(self): AssignerActor.gen_uid(self._session_id), address=self.address ) - @alru_cache - async def _get_task_api(self): + async def _get_task_api(self) -> TaskAPI: return await TaskAPI.create(self._session_id, self.address) def _put_subtask_with_priority(self, subtask: Subtask, priority: Tuple = None): @@ -272,21 +271,47 @@ async def update_subtask_priorities( @alru_cache(maxsize=10000) async def _get_execution_ref(self, address: str): - from ..worker.exec import SubtaskExecutionActor + from ..worker.execution import SubtaskExecutionActor return await mo.actor_ref(SubtaskExecutionActor.default_uid(), address=address) - async def finish_subtasks(self, subtask_ids: List[str], schedule_next: bool = True): - band_tasks = defaultdict(lambda: 0) - for subtask_id in subtask_ids: - subtask_info = self._subtask_infos.pop(subtask_id, None) + async def set_subtask_results( + self, subtask_results: List[SubtaskResult], source_bands: List[BandType] + ): + delays = [] + task_api = await self._get_task_api() + for result, band in zip(subtask_results, source_bands): + if result.status == SubtaskStatus.errored: + subtask_info = self._subtask_infos.get(result.subtask_id) + if ( + subtask_info is not None + and subtask_info.subtask.retryable + and subtask_info.num_reschedules < subtask_info.max_reschedules + and isinstance(result.error, (MarsError, OSError)) + ): + subtask_info.num_reschedules += 1 + logger.warning( + "Resubmit subtask %s at attempt %d", + subtask_info.subtask.subtask_id, + subtask_info.num_reschedules, + ) + execution_ref = await self._get_execution_ref(band[0]) + await execution_ref.submit_subtasks.tell( + [subtask_info.subtask], + [subtask_info.priority], + self.address, + band[1], + ) + continue + + subtask_info = self._subtask_infos.pop(result.subtask_id, None) if subtask_info is not None: - self._subtask_summaries[subtask_id] = subtask_info.to_summary( + self._subtask_summaries[result.subtask_id] = subtask_info.to_summary( is_finished=True ) - if schedule_next: - for band in subtask_info.submitted_bands: - band_tasks[band] += 1 + delays.append(task_api.set_subtask_result.delay(result)) + + await task_api.set_subtask_result.batch(*delays) def _get_subtasks_by_ids(self, subtask_ids: List[str]) -> List[Optional[Subtask]]: subtasks = [] diff --git a/mars/services/scheduling/tests/test_service.py b/mars/services/scheduling/tests/test_service.py index 06e0e150c0..0362d4c065 100644 --- a/mars/services/scheduling/tests/test_service.py +++ b/mars/services/scheduling/tests/test_service.py @@ -171,7 +171,6 @@ async def test_schedule_success(actor_pools): subtask.expect_bands = [(worker_pool.external_address, "numa-0")] await scheduling_api.add_subtasks([subtask], [(0,)]) await task_manager_ref.wait_subtask_result(subtask.subtask_id) - await scheduling_api.finish_subtasks([subtask.subtask_id]) result_key = next(subtask.chunk_graph.iter_indep(reverse=True)).key result = await storage_api.get(result_key) @@ -197,7 +196,6 @@ def _remote_fun(secs): async def _waiter_fun(subtask_id): await task_manager_ref.wait_subtask_result(subtask_id) - await scheduling_api.finish_subtasks([subtask_id]) finish_ids.append(subtask_id) finish_time.append(time.time()) @@ -245,7 +243,6 @@ def _remote_fun(secs): async def _waiter_fun(subtask_id): await task_manager_ref.wait_subtask_result(subtask_id) - await scheduling_api.finish_subtasks([subtask_id]) subtasks = [] wait_tasks = [] diff --git a/mars/services/scheduling/worker/__init__.py b/mars/services/scheduling/worker/__init__.py index 369b0adbfb..d4f84b6125 100644 --- a/mars/services/scheduling/worker/__init__.py +++ b/mars/services/scheduling/worker/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .exec import SubtaskExecutionActor +from .execution import SubtaskExecutionActor from .queues import SubtaskExecutionQueueActor, SubtaskPrepareQueueActor from .quota import QuotaActor, MemQuotaActor, WorkerQuotaManagerActor from .service import SchedulingWorkerService diff --git a/mars/services/scheduling/worker/exec/__init__.py b/mars/services/scheduling/worker/execution/__init__.py similarity index 94% rename from mars/services/scheduling/worker/exec/__init__.py rename to mars/services/scheduling/worker/execution/__init__.py index d6091cde2e..83f5970252 100644 --- a/mars/services/scheduling/worker/exec/__init__.py +++ b/mars/services/scheduling/worker/execution/__init__.py @@ -13,4 +13,5 @@ # limitations under the License. from .actor import SubtaskExecutionActor +from .core import SubtaskExecutionInfo from .prepare import SubtaskPreparer diff --git a/mars/services/scheduling/worker/exec/actor.py b/mars/services/scheduling/worker/execution/actor.py similarity index 82% rename from mars/services/scheduling/worker/exec/actor.py rename to mars/services/scheduling/worker/execution/actor.py index d8776a0fac..bcdd97295d 100644 --- a/mars/services/scheduling/worker/exec/actor.py +++ b/mars/services/scheduling/worker/execution/actor.py @@ -26,7 +26,6 @@ from ....cluster import ClusterAPI from ....core import ActorCallback from ....subtask import Subtask, SubtaskAPI, SubtaskResult, SubtaskStatus -from ....task import TaskAPI from ..queues import SubtaskPrepareQueueActor, SubtaskExecutionQueueActor from ..quota import QuotaActor from ..slotmanager import SlotManagerActor @@ -102,6 +101,16 @@ async def _get_band_quota_ref( ) -> Union[mo.ActorRef, QuotaActor]: return await mo.actor_ref(QuotaActor.gen_uid(band_name), address=self.address) + @staticmethod + @alru_cache(cache_exceptions=False) + async def _get_manager_ref(session_id: str, supervisor_address: str): + from ...supervisor.manager import SubtaskManagerActor + + return await mo.actor_ref( + uid=SubtaskManagerActor.gen_uid(session_id), + address=supervisor_address, + ) + def _build_subtask_info( self, subtask: Subtask, @@ -109,12 +118,20 @@ def _build_subtask_info( supervisor_address: str, band_name: str, ) -> SubtaskExecutionInfo: + subtask_max_retries = ( + subtask.extra_config.get("subtask_max_retries") + if subtask.extra_config + else None + ) + if subtask_max_retries is None: + subtask_max_retries = self._subtask_max_retries + subtask_info = SubtaskExecutionInfo( subtask, priority, supervisor_address=supervisor_address, band_name=band_name, - max_retries=self._subtask_max_retries, + max_retries=subtask_max_retries, ) subtask_info.result = SubtaskResult( subtask_id=subtask.subtask_id, @@ -216,8 +233,15 @@ async def submit_subtasks( subtask = self._subtask_caches[subtask].subtask except KeyError: subtask = self._subtask_executions[subtask].subtask - if subtask.subtask_id in self._subtask_executions: - continue + try: + info = self._subtask_executions[subtask.subtask_id] + if info.result.status not in ( + SubtaskStatus.cancelled, + SubtaskStatus.errored, + ): + continue + except KeyError: + pass subtask_info = self._build_subtask_info( subtask, @@ -252,18 +276,19 @@ async def _dequeue_subtask_ids(self, queue_ref, subtask_ids: List[str]): infos_to_report.append(subtask_info) await self._report_subtask_results(infos_to_report) - @staticmethod - async def _report_subtask_results(subtask_infos: List[SubtaskExecutionInfo]): + async def _report_subtask_results(self, subtask_infos: List[SubtaskExecutionInfo]): if not subtask_infos: return - task_api = await TaskAPI.create( - subtask_infos[0].result.session_id, subtask_infos[0].supervisor_address + try: + manager_ref = await self._get_manager_ref( + subtask_infos[0].result.session_id, subtask_infos[0].supervisor_address + ) + except mo.ActorNotExist: + return + await manager_ref.set_subtask_results.tell( + [info.result for info in subtask_infos], + [(self.address, info.band_name) for info in subtask_infos], ) - batch = [ - task_api.set_subtask_result.delay(subtask_info.result) - for subtask_info in subtask_infos - ] - await task_api.set_subtask_result.batch(*batch) async def cancel_subtasks( self, subtask_ids: List[str], kill_timeout: Optional[int] = 5 @@ -289,6 +314,7 @@ async def cancel_subtasks( self.uncache_subtasks(subtask_ids) + infos_to_report = [] for subtask_id in subtask_ids: try: subtask_info = self._subtask_executions[subtask_id] @@ -296,6 +322,8 @@ async def cancel_subtasks( continue if not subtask_info.result.status.is_done: self._fill_result_with_exc(subtask_info, exc_cls=asyncio.CancelledError) + infos_to_report.append(subtask_info) + await self._report_subtask_results(infos_to_report) async def wait_subtasks(self, subtask_ids: List[str]): infos = [ @@ -307,6 +335,28 @@ async def wait_subtasks(self, subtask_ids: List[str]): yield asyncio.wait([info.finish_future for info in infos]) raise mo.Return([info.result for info in infos]) + def _create_subtask_with_exception(self, subtask_id, coro): + info = self._subtask_executions[subtask_id] + + async def _run_with_exception_handling(): + try: + return await coro + except: # noqa: E722 # nosec # pylint: disable=bare-except + self._fill_result_with_exc(info) + await self._report_subtask_results([info]) + await self._prepare_queue_ref.release_slot( + info.subtask.subtask_id, errors="ignore" + ) + await self._execution_queue_ref.release_slot( + info.subtask.subtask_id, errors="ignore" + ) + for aio_task in info.aio_tasks: + if aio_task is not asyncio.current_task(): + aio_task.cancel() + + task = asyncio.create_task(_run_with_exception_handling()) + info.aio_tasks.append(task) + async def handle_prepare_queue(self, band_name: str): while True: try: @@ -322,8 +372,8 @@ async def handle_prepare_queue(self, band_name: str): continue logger.debug(f"Obtained subtask {subtask_id} from prepare queue") - subtask_info.aio_tasks.append( - asyncio.create_task(self._prepare_subtask_with_retry(subtask_info)) + self._create_subtask_with_exception( + subtask_id, self._prepare_subtask_with_retry(subtask_info) ) async def handle_execute_queue(self, band_name: str): @@ -355,8 +405,8 @@ async def handle_execute_queue(self, band_name: str): c.key in self._pred_key_mapping_dag for c in subtask_info.subtask.chunk_graph.result_chunks ) - subtask_info.aio_tasks.append( - asyncio.create_task(self._execute_subtask_with_retry(subtask_info)) + self._create_subtask_with_exception( + subtask_id, self._execute_subtask_with_retry(subtask_info) ) async def _prepare_subtask_once(self, subtask_info: SubtaskExecutionInfo): @@ -438,7 +488,20 @@ async def _execute_subtask_with_retry(self, subtask_info: SubtaskExecutionInfo): subtask_info, max_retries=subtask_info.max_retries if subtask.retryable else 0, ) - except: # noqa: E722 # nosec # pylint: disable=bare-except + except Exception as ex: # noqa: E722 # nosec # pylint: disable=bare-except + if not subtask.retryable: + unretryable_op = [ + chunk.op + for chunk in subtask.chunk_graph + if not getattr(chunk.op, "retryable", True) + ] + logger.exception( + "Run subtask failed due to %r, the subtask %s is " + "not retryable, it contains unretryable op: %r", + ex, + subtask.subtask_id, + unretryable_op, + ) self._fill_result_with_exc(subtask_info) finally: self._subtask_executions.pop(subtask.subtask_id, None) @@ -454,18 +517,19 @@ async def _execute_subtask_with_retry(self, subtask_info: SubtaskExecutionInfo): return subtask_info.result @classmethod - async def _call_with_retry( + def _log_subtask_retry( cls, - target_func: Callable, subtask_info: SubtaskExecutionInfo, - max_retries: Optional[int] = None, + target_func: Callable, + trial: int, + exc_info: Tuple, + retry: bool = True, ): subtask = subtask_info.subtask - max_retries = max_retries or subtask_info.max_retries - - def log_func(trial: int, exc_info: Tuple, retry: bool = True): - subtask_info.num_retries = trial - if retry: + max_retries = subtask_info.max_retries + subtask_info.num_retries = trial + if retry: + if trial < max_retries - 1: logger.error( "Rerun %s of subtask %s at attempt %d due to %s", target_func, @@ -475,22 +539,42 @@ def log_func(trial: int, exc_info: Tuple, retry: bool = True): ) else: logger.exception( - "Failed to rerun the %s of subtask %s, " - "num_retries: %s, max_retries: %s", - target_func, - subtask.subtask_id, + "Exceed max rerun (%s / %s): %s of subtask %s due to %s", trial + 1, max_retries, + target_func, + subtask.subtask_id, + exc_info[1], exc_info=exc_info, ) + else: + logger.exception( + "Failed to rerun %s of subtask %s due to unhandled exception: %s", + target_func, + subtask.subtask_id, + exc_info[1], + exc_info=exc_info, + ) + + @classmethod + async def _call_with_retry( + cls, + target_func: Callable, + subtask_info: SubtaskExecutionInfo, + max_retries: Optional[int] = None, + ): + subtask_info.max_retries = max_retries or subtask_info.max_retries if subtask_info.max_retries <= 1: return await target_func(subtask_info) else: + err_callback = functools.partial( + cls._log_subtask_retry, subtask_info, target_func + ) return await call_with_retry( functools.partial(target_func, subtask_info), max_retries=max_retries, - error_callback=log_func, + error_callback=err_callback, ) @classmethod diff --git a/mars/services/scheduling/worker/exec/core.py b/mars/services/scheduling/worker/execution/core.py similarity index 81% rename from mars/services/scheduling/worker/exec/core.py rename to mars/services/scheduling/worker/execution/core.py index dcdcb891de..01cab8e170 100644 --- a/mars/services/scheduling/worker/exec/core.py +++ b/mars/services/scheduling/worker/execution/core.py @@ -58,12 +58,17 @@ async def call_with_retry( try: return await async_fun() except (OSError, MarsError): + exc_info_raw = sys.exc_info() + exc_info = error_callback(trial=trial, exc_info=exc_info_raw, retry=True) + exc_info = exc_info or exc_info_raw + if trial >= max_retries - 1: - error_callback(trial=trial, exc_info=sys.exc_info(), retry=False) - raise - error_callback(trial=trial, exc_info=sys.exc_info(), retry=True) + raise exc_info[1].with_traceback(exc_info[-1]) except asyncio.CancelledError: raise except: # noqa: E722 # nosec # pylint: disable=bare-except - error_callback(trial=trial, exc_info=sys.exc_info(), retry=False) - raise + exc_info_raw = sys.exc_info() + exc_info = error_callback(trial=trial, exc_info=exc_info_raw, retry=False) + exc_info = exc_info or exc_info_raw + + raise exc_info[1].with_traceback(exc_info[-1]) diff --git a/mars/services/scheduling/worker/exec/prepare.py b/mars/services/scheduling/worker/execution/prepare.py similarity index 100% rename from mars/services/scheduling/worker/exec/prepare.py rename to mars/services/scheduling/worker/execution/prepare.py diff --git a/mars/services/scheduling/worker/exec/tests/__init__.py b/mars/services/scheduling/worker/execution/tests/__init__.py similarity index 100% rename from mars/services/scheduling/worker/exec/tests/__init__.py rename to mars/services/scheduling/worker/execution/tests/__init__.py diff --git a/mars/services/scheduling/worker/exec/tests/test_exec.py b/mars/services/scheduling/worker/execution/tests/test_exec.py similarity index 100% rename from mars/services/scheduling/worker/exec/tests/test_exec.py rename to mars/services/scheduling/worker/execution/tests/test_exec.py diff --git a/mars/services/scheduling/worker/exec/tests/test_prepare.py b/mars/services/scheduling/worker/execution/tests/test_prepare.py similarity index 100% rename from mars/services/scheduling/worker/exec/tests/test_prepare.py rename to mars/services/scheduling/worker/execution/tests/test_prepare.py diff --git a/mars/services/scheduling/worker/service.py b/mars/services/scheduling/worker/service.py index 1ed6e4eaae..a71694d7c7 100644 --- a/mars/services/scheduling/worker/service.py +++ b/mars/services/scheduling/worker/service.py @@ -15,10 +15,10 @@ from .... import oscar as mo from ....utils import calc_size_by_str from ...core import AbstractService +from .execution import SubtaskExecutionActor from .slotmanager import SlotManagerActor from .queues import SubtaskPrepareQueueActor, SubtaskExecutionQueueActor from .quota import WorkerQuotaManagerActor -from .exec import SubtaskExecutionActor class SchedulingWorkerService(AbstractService): diff --git a/mars/services/task/supervisor/stage.py b/mars/services/task/supervisor/stage.py index 4afd314f16..9d6a5aa64a 100644 --- a/mars/services/task/supervisor/stage.py +++ b/mars/services/task/supervisor/stage.py @@ -152,9 +152,6 @@ async def set_subtask_result(self, result: SubtaskResult): await self._update_chunks_meta(self.chunk_graph) # tell scheduling to finish subtasks - await self._scheduling_api.finish_subtasks( - [result.subtask_id], schedule_next=not error_or_cancelled - ) if self.result.status != TaskStatus.terminated: self.result = TaskResult( self.task.task_id, @@ -196,7 +193,6 @@ async def set_subtask_result(self, result: SubtaskResult): # all predecessors finished to_schedule_subtasks.append(succ_subtask) await self._schedule_subtasks(to_schedule_subtasks) - await self._scheduling_api.finish_subtasks([result.subtask_id]) async def run(self): if len(self.subtask_graph) == 0: diff --git a/mars/services/tests/fault_injection_manager.py b/mars/services/tests/fault_injection_manager.py index 06fa7fb310..0b29d10881 100644 --- a/mars/services/tests/fault_injection_manager.py +++ b/mars/services/tests/fault_injection_manager.py @@ -13,6 +13,7 @@ # limitations under the License. import enum +import logging import os import uuid from abc import ABC, abstractmethod @@ -20,6 +21,8 @@ from ...core.base import MarsError from ..session import SessionAPI +logger = logging.getLogger(__name__) + class ExtraConfigKey: FAULT_INJECTION_MANAGER_NAME = "fault_injection_manager_name" diff --git a/mars/services/tests/fault_injection_patch.py b/mars/services/tests/fault_injection_patch.py index 6c410a4278..7e2e07fe68 100644 --- a/mars/services/tests/fault_injection_patch.py +++ b/mars/services/tests/fault_injection_patch.py @@ -12,15 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. -from typing import Union, Any, Dict +import sys +from typing import Any, Callable, Dict, Tuple, Union from ... import oscar as mo from ...core import OperandType from ...lib.aio import alru_cache from ...tests.core import patch_cls, patch_super as super from ..session import SessionAPI -from ..scheduling.worker.exec import SubtaskExecutionActor -from ..subtask import Subtask +from ..scheduling.worker.execution import SubtaskExecutionActor, SubtaskExecutionInfo from ..subtask.worker.processor import SubtaskProcessor from ..tests.fault_injection_manager import ( AbstractFaultInjectionManager, @@ -44,14 +44,38 @@ async def _get_fault_injection_manager_ref( async def _get_session_api(supervisor_address: str): return await SessionAPI.create(supervisor_address) - async def internal_run_subtask(self, subtask: Subtask, band_name: str): + async def _execute_subtask_once(self, subtask_info: SubtaskExecutionInfo): + try: + return await super()._execute_subtask_once(subtask_info) + except: # noqa: E722 # nosec # pylint: disable=bare-except + exc_info = sys.exc_info() + subtask = subtask_info.subtask + if not subtask.retryable: + unretryable_op = [ + chunk.op + for chunk in subtask.chunk_graph + if not getattr(chunk.op, "retryable", True) + ] + message = ( + f"Run subtask failed due to {exc_info[1]}, " + f"the subtask {subtask.subtask_id} is not retryable, " + f"it contains unretryable op: {unretryable_op!r}" + ) + _UnretryableException = type( + "_UnretryableException", (exc_info[0],), {} + ) + raise _UnretryableException(message).with_traceback(exc_info[-1]) + else: + raise + + async def _execute_subtask_with_retry(self, subtask_info: SubtaskExecutionInfo): + subtask = subtask_info.subtask # fault injection if subtask.extra_config: fault_injection_manager_name = subtask.extra_config.get( ExtraConfigKey.FAULT_INJECTION_MANAGER_NAME ) if fault_injection_manager_name is not None: - subtask_info = self._subtask_info[subtask.subtask_id] fault_injection_manager = await self._get_fault_injection_manager_ref( subtask_info.supervisor_address, subtask.session_id, @@ -61,7 +85,41 @@ async def internal_run_subtask(self, subtask: Subtask, band_name: str): FaultPosition.ON_RUN_SUBTASK, {"subtask": subtask} ) handle_fault(fault) - return super().internal_run_subtask(subtask, band_name) + return await super()._execute_subtask_with_retry(subtask_info) + + @classmethod + def _log_subtask_retry( + cls, + subtask_info: SubtaskExecutionInfo, + target_func: Callable, + trial: int, + exc_info: Tuple, + retry: bool = True, + ): + exc_info = ( + super()._log_subtask_retry( + subtask_info, target_func, trial, exc_info, retry=retry + ) + or exc_info + ) + + if retry: + if trial < subtask_info.max_retries - 1: + return exc_info + else: + _ExceedMaxRerun = type("_ExceedMaxRerun", (exc_info[0],), {}) + return ( + _ExceedMaxRerun, + _ExceedMaxRerun(str(exc_info[1])).with_traceback(exc_info[-1]), + exc_info[-1], + ) + else: + _UnhandledException = type("_UnhandledException", (exc_info[0],), {}) + return ( + _UnhandledException, + _UnhandledException(str(exc_info[1])).with_traceback(exc_info[-1]), + exc_info[-1], + ) @patch_cls(SubtaskProcessor)