From e25ff2ea296b5284546a8ba49823f2e62cb2ace6 Mon Sep 17 00:00:00 2001
From: Wenjun Si <wenjun.swj@alibaba-inc.com>
Date: Mon, 24 Jan 2022 19:33:08 +0800
Subject: [PATCH] TRY fix fault inject

---
 mars/deploy/oscar/tests/test_cmdline.py       |  68 ++++----
 mars/services/scheduling/api/oscar.py         |  14 --
 .../services/scheduling/supervisor/manager.py |  47 ++++--
 .../services/scheduling/tests/test_service.py |   3 -
 mars/services/scheduling/worker/__init__.py   |   2 +-
 .../worker/{exec => execution}/__init__.py    |   1 +
 .../worker/{exec => execution}/actor.py       | 146 ++++++++++++++----
 .../worker/{exec => execution}/core.py        |  15 +-
 .../worker/{exec => execution}/prepare.py     |   0
 .../{exec => execution}/tests/__init__.py     |   0
 .../{exec => execution}/tests/test_exec.py    |   0
 .../{exec => execution}/tests/test_prepare.py |   0
 mars/services/scheduling/worker/service.py    |   2 +-
 mars/services/task/supervisor/stage.py        |   4 -
 .../services/tests/fault_injection_manager.py |   3 +
 mars/services/tests/fault_injection_patch.py  |  70 ++++++++-
 16 files changed, 265 insertions(+), 110 deletions(-)
 rename mars/services/scheduling/worker/{exec => execution}/__init__.py (94%)
 rename mars/services/scheduling/worker/{exec => execution}/actor.py (82%)
 rename mars/services/scheduling/worker/{exec => execution}/core.py (81%)
 rename mars/services/scheduling/worker/{exec => execution}/prepare.py (100%)
 rename mars/services/scheduling/worker/{exec => execution}/tests/__init__.py (100%)
 rename mars/services/scheduling/worker/{exec => execution}/tests/test_exec.py (100%)
 rename mars/services/scheduling/worker/{exec => execution}/tests/test_prepare.py (100%)

diff --git a/mars/deploy/oscar/tests/test_cmdline.py b/mars/deploy/oscar/tests/test_cmdline.py
index 836a5d3b01..e71b6054af 100644
--- a/mars/deploy/oscar/tests/test_cmdline.py
+++ b/mars/deploy/oscar/tests/test_cmdline.py
@@ -111,38 +111,6 @@ def _get_labelled_port(label=None, create=True):
 
 supervisor_cmd_start = [sys.executable, "-m", "mars.deploy.oscar.supervisor"]
 worker_cmd_start = [sys.executable, "-m", "mars.deploy.oscar.worker"]
-start_params = {
-    "bare_start": [
-        supervisor_cmd_start,
-        worker_cmd_start
-        + [
-            "--config-file",
-            os.path.join(os.path.dirname(__file__), "local_test_config.yml"),
-        ],
-        False,
-    ],
-    "with_supervisors": [
-        supervisor_cmd_start
-        + [
-            "-e",
-            lambda: f'127.0.0.1:{_get_labelled_port("supervisor")}',
-            "-w",
-            lambda: str(_get_labelled_port("web")),
-            "--n-process",
-            "2",
-        ],
-        worker_cmd_start
-        + [
-            "-e",
-            lambda: f"127.0.0.1:{get_next_port(occupy=True)}",
-            "-s",
-            lambda: f'127.0.0.1:{_get_labelled_port("supervisor")}',
-            "--config-file",
-            os.path.join(os.path.dirname(__file__), "local_test_config.yml"),
-        ],
-        True,
-    ],
-}
 
 
 def _reload_args(args):
@@ -159,8 +127,40 @@ def _reload_args(args):
 
 @pytest.mark.parametrize(
     "supervisor_args,worker_args,use_web_addr",
-    list(start_params.values()),
-    ids=list(start_params.keys()),
+    [
+        pytest.param(
+            supervisor_cmd_start,
+            worker_cmd_start
+            + [
+                "--config-file",
+                os.path.join(os.path.dirname(__file__), "local_test_config.yml"),
+            ],
+            False,
+            id="bare_start",
+        ),
+        pytest.param(
+            supervisor_cmd_start
+            + [
+                "-e",
+                lambda: f'127.0.0.1:{_get_labelled_port("supervisor")}',
+                "-w",
+                lambda: str(_get_labelled_port("web")),
+                "--n-process",
+                "2",
+            ],
+            worker_cmd_start
+            + [
+                "-e",
+                lambda: f"127.0.0.1:{get_next_port(occupy=True)}",
+                "-s",
+                lambda: f'127.0.0.1:{_get_labelled_port("supervisor")}',
+                "--config-file",
+                os.path.join(os.path.dirname(__file__), "local_test_config.yml"),
+            ],
+            True,
+            id="with_supervisors",
+        ),
+    ],
 )
 @flaky(max_runs=10, rerun_filter=lambda err, *_: issubclass(err[0], _rerun_errors))
 def test_cmdline_run(supervisor_args, worker_args, use_web_addr):
diff --git a/mars/services/scheduling/api/oscar.py b/mars/services/scheduling/api/oscar.py
index e2d5862d64..fb8c4f8b5c 100644
--- a/mars/services/scheduling/api/oscar.py
+++ b/mars/services/scheduling/api/oscar.py
@@ -117,20 +117,6 @@ async def cancel_subtasks(
         """
         await self._manager_ref.cancel_subtasks(subtask_ids, kill_timeout=kill_timeout)
 
-    async def finish_subtasks(self, subtask_ids: List[str], schedule_next: bool = True):
-        """
-        Mark subtasks as finished, letting scheduling service to schedule
-        next tasks in the ready queue
-
-        Parameters
-        ----------
-        subtask_ids
-            ids of subtasks to mark as finished
-        schedule_next
-            whether to schedule succeeding subtasks
-        """
-        await self._manager_ref.finish_subtasks(subtask_ids, schedule_next)
-
 
 class MockSchedulingAPI(SchedulingAPI):
     @classmethod
diff --git a/mars/services/scheduling/supervisor/manager.py b/mars/services/scheduling/supervisor/manager.py
index 149b5fd035..758e91debd 100644
--- a/mars/services/scheduling/supervisor/manager.py
+++ b/mars/services/scheduling/supervisor/manager.py
@@ -95,8 +95,7 @@ async def __post_create__(self):
             AssignerActor.gen_uid(self._session_id), address=self.address
         )
 
-    @alru_cache
-    async def _get_task_api(self):
+    async def _get_task_api(self) -> TaskAPI:
         return await TaskAPI.create(self._session_id, self.address)
 
     def _put_subtask_with_priority(self, subtask: Subtask, priority: Tuple = None):
@@ -272,21 +271,47 @@ async def update_subtask_priorities(
 
     @alru_cache(maxsize=10000)
     async def _get_execution_ref(self, address: str):
-        from ..worker.exec import SubtaskExecutionActor
+        from ..worker.execution import SubtaskExecutionActor
 
         return await mo.actor_ref(SubtaskExecutionActor.default_uid(), address=address)
 
-    async def finish_subtasks(self, subtask_ids: List[str], schedule_next: bool = True):
-        band_tasks = defaultdict(lambda: 0)
-        for subtask_id in subtask_ids:
-            subtask_info = self._subtask_infos.pop(subtask_id, None)
+    async def set_subtask_results(
+        self, subtask_results: List[SubtaskResult], source_bands: List[BandType]
+    ):
+        delays = []
+        task_api = await self._get_task_api()
+        for result, band in zip(subtask_results, source_bands):
+            if result.status == SubtaskStatus.errored:
+                subtask_info = self._subtask_infos.get(result.subtask_id)
+                if (
+                    subtask_info is not None
+                    and subtask_info.subtask.retryable
+                    and subtask_info.num_reschedules < subtask_info.max_reschedules
+                    and isinstance(result.error, (MarsError, OSError))
+                ):
+                    subtask_info.num_reschedules += 1
+                    logger.warning(
+                        "Resubmit subtask %s at attempt %d",
+                        subtask_info.subtask.subtask_id,
+                        subtask_info.num_reschedules,
+                    )
+                    execution_ref = await self._get_execution_ref(band[0])
+                    await execution_ref.submit_subtasks.tell(
+                        [subtask_info.subtask],
+                        [subtask_info.priority],
+                        self.address,
+                        band[1],
+                    )
+                    continue
+
+            subtask_info = self._subtask_infos.pop(result.subtask_id, None)
             if subtask_info is not None:
-                self._subtask_summaries[subtask_id] = subtask_info.to_summary(
+                self._subtask_summaries[result.subtask_id] = subtask_info.to_summary(
                     is_finished=True
                 )
-                if schedule_next:
-                    for band in subtask_info.submitted_bands:
-                        band_tasks[band] += 1
+            delays.append(task_api.set_subtask_result.delay(result))
+
+        await task_api.set_subtask_result.batch(*delays)
 
     def _get_subtasks_by_ids(self, subtask_ids: List[str]) -> List[Optional[Subtask]]:
         subtasks = []
diff --git a/mars/services/scheduling/tests/test_service.py b/mars/services/scheduling/tests/test_service.py
index 06e0e150c0..0362d4c065 100644
--- a/mars/services/scheduling/tests/test_service.py
+++ b/mars/services/scheduling/tests/test_service.py
@@ -171,7 +171,6 @@ async def test_schedule_success(actor_pools):
     subtask.expect_bands = [(worker_pool.external_address, "numa-0")]
     await scheduling_api.add_subtasks([subtask], [(0,)])
     await task_manager_ref.wait_subtask_result(subtask.subtask_id)
-    await scheduling_api.finish_subtasks([subtask.subtask_id])
 
     result_key = next(subtask.chunk_graph.iter_indep(reverse=True)).key
     result = await storage_api.get(result_key)
@@ -197,7 +196,6 @@ def _remote_fun(secs):
 
     async def _waiter_fun(subtask_id):
         await task_manager_ref.wait_subtask_result(subtask_id)
-        await scheduling_api.finish_subtasks([subtask_id])
         finish_ids.append(subtask_id)
         finish_time.append(time.time())
 
@@ -245,7 +243,6 @@ def _remote_fun(secs):
 
     async def _waiter_fun(subtask_id):
         await task_manager_ref.wait_subtask_result(subtask_id)
-        await scheduling_api.finish_subtasks([subtask_id])
 
     subtasks = []
     wait_tasks = []
diff --git a/mars/services/scheduling/worker/__init__.py b/mars/services/scheduling/worker/__init__.py
index 369b0adbfb..d4f84b6125 100644
--- a/mars/services/scheduling/worker/__init__.py
+++ b/mars/services/scheduling/worker/__init__.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .exec import SubtaskExecutionActor
+from .execution import SubtaskExecutionActor
 from .queues import SubtaskExecutionQueueActor, SubtaskPrepareQueueActor
 from .quota import QuotaActor, MemQuotaActor, WorkerQuotaManagerActor
 from .service import SchedulingWorkerService
diff --git a/mars/services/scheduling/worker/exec/__init__.py b/mars/services/scheduling/worker/execution/__init__.py
similarity index 94%
rename from mars/services/scheduling/worker/exec/__init__.py
rename to mars/services/scheduling/worker/execution/__init__.py
index d6091cde2e..83f5970252 100644
--- a/mars/services/scheduling/worker/exec/__init__.py
+++ b/mars/services/scheduling/worker/execution/__init__.py
@@ -13,4 +13,5 @@
 # limitations under the License.
 
 from .actor import SubtaskExecutionActor
+from .core import SubtaskExecutionInfo
 from .prepare import SubtaskPreparer
diff --git a/mars/services/scheduling/worker/exec/actor.py b/mars/services/scheduling/worker/execution/actor.py
similarity index 82%
rename from mars/services/scheduling/worker/exec/actor.py
rename to mars/services/scheduling/worker/execution/actor.py
index d8776a0fac..bcdd97295d 100644
--- a/mars/services/scheduling/worker/exec/actor.py
+++ b/mars/services/scheduling/worker/execution/actor.py
@@ -26,7 +26,6 @@
 from ....cluster import ClusterAPI
 from ....core import ActorCallback
 from ....subtask import Subtask, SubtaskAPI, SubtaskResult, SubtaskStatus
-from ....task import TaskAPI
 from ..queues import SubtaskPrepareQueueActor, SubtaskExecutionQueueActor
 from ..quota import QuotaActor
 from ..slotmanager import SlotManagerActor
@@ -102,6 +101,16 @@ async def _get_band_quota_ref(
     ) -> Union[mo.ActorRef, QuotaActor]:
         return await mo.actor_ref(QuotaActor.gen_uid(band_name), address=self.address)
 
+    @staticmethod
+    @alru_cache(cache_exceptions=False)
+    async def _get_manager_ref(session_id: str, supervisor_address: str):
+        from ...supervisor.manager import SubtaskManagerActor
+
+        return await mo.actor_ref(
+            uid=SubtaskManagerActor.gen_uid(session_id),
+            address=supervisor_address,
+        )
+
     def _build_subtask_info(
         self,
         subtask: Subtask,
@@ -109,12 +118,20 @@ def _build_subtask_info(
         supervisor_address: str,
         band_name: str,
     ) -> SubtaskExecutionInfo:
+        subtask_max_retries = (
+            subtask.extra_config.get("subtask_max_retries")
+            if subtask.extra_config
+            else None
+        )
+        if subtask_max_retries is None:
+            subtask_max_retries = self._subtask_max_retries
+
         subtask_info = SubtaskExecutionInfo(
             subtask,
             priority,
             supervisor_address=supervisor_address,
             band_name=band_name,
-            max_retries=self._subtask_max_retries,
+            max_retries=subtask_max_retries,
         )
         subtask_info.result = SubtaskResult(
             subtask_id=subtask.subtask_id,
@@ -216,8 +233,15 @@ async def submit_subtasks(
                     subtask = self._subtask_caches[subtask].subtask
                 except KeyError:
                     subtask = self._subtask_executions[subtask].subtask
-            if subtask.subtask_id in self._subtask_executions:
-                continue
+            try:
+                info = self._subtask_executions[subtask.subtask_id]
+                if info.result.status not in (
+                    SubtaskStatus.cancelled,
+                    SubtaskStatus.errored,
+                ):
+                    continue
+            except KeyError:
+                pass
 
             subtask_info = self._build_subtask_info(
                 subtask,
@@ -252,18 +276,19 @@ async def _dequeue_subtask_ids(self, queue_ref, subtask_ids: List[str]):
                 infos_to_report.append(subtask_info)
         await self._report_subtask_results(infos_to_report)
 
-    @staticmethod
-    async def _report_subtask_results(subtask_infos: List[SubtaskExecutionInfo]):
+    async def _report_subtask_results(self, subtask_infos: List[SubtaskExecutionInfo]):
         if not subtask_infos:
             return
-        task_api = await TaskAPI.create(
-            subtask_infos[0].result.session_id, subtask_infos[0].supervisor_address
+        try:
+            manager_ref = await self._get_manager_ref(
+                subtask_infos[0].result.session_id, subtask_infos[0].supervisor_address
+            )
+        except mo.ActorNotExist:
+            return
+        await manager_ref.set_subtask_results.tell(
+            [info.result for info in subtask_infos],
+            [(self.address, info.band_name) for info in subtask_infos],
         )
-        batch = [
-            task_api.set_subtask_result.delay(subtask_info.result)
-            for subtask_info in subtask_infos
-        ]
-        await task_api.set_subtask_result.batch(*batch)
 
     async def cancel_subtasks(
         self, subtask_ids: List[str], kill_timeout: Optional[int] = 5
@@ -289,6 +314,7 @@ async def cancel_subtasks(
 
         self.uncache_subtasks(subtask_ids)
 
+        infos_to_report = []
         for subtask_id in subtask_ids:
             try:
                 subtask_info = self._subtask_executions[subtask_id]
@@ -296,6 +322,8 @@ async def cancel_subtasks(
                 continue
             if not subtask_info.result.status.is_done:
                 self._fill_result_with_exc(subtask_info, exc_cls=asyncio.CancelledError)
+                infos_to_report.append(subtask_info)
+        await self._report_subtask_results(infos_to_report)
 
     async def wait_subtasks(self, subtask_ids: List[str]):
         infos = [
@@ -307,6 +335,28 @@ async def wait_subtasks(self, subtask_ids: List[str]):
             yield asyncio.wait([info.finish_future for info in infos])
         raise mo.Return([info.result for info in infos])
 
+    def _create_subtask_with_exception(self, subtask_id, coro):
+        info = self._subtask_executions[subtask_id]
+
+        async def _run_with_exception_handling():
+            try:
+                return await coro
+            except:  # noqa: E722  # nosec  # pylint: disable=bare-except
+                self._fill_result_with_exc(info)
+                await self._report_subtask_results([info])
+                await self._prepare_queue_ref.release_slot(
+                    info.subtask.subtask_id, errors="ignore"
+                )
+                await self._execution_queue_ref.release_slot(
+                    info.subtask.subtask_id, errors="ignore"
+                )
+                for aio_task in info.aio_tasks:
+                    if aio_task is not asyncio.current_task():
+                        aio_task.cancel()
+
+        task = asyncio.create_task(_run_with_exception_handling())
+        info.aio_tasks.append(task)
+
     async def handle_prepare_queue(self, band_name: str):
         while True:
             try:
@@ -322,8 +372,8 @@ async def handle_prepare_queue(self, band_name: str):
                 continue
 
             logger.debug(f"Obtained subtask {subtask_id} from prepare queue")
-            subtask_info.aio_tasks.append(
-                asyncio.create_task(self._prepare_subtask_with_retry(subtask_info))
+            self._create_subtask_with_exception(
+                subtask_id, self._prepare_subtask_with_retry(subtask_info)
             )
 
     async def handle_execute_queue(self, band_name: str):
@@ -355,8 +405,8 @@ async def handle_execute_queue(self, band_name: str):
                 c.key in self._pred_key_mapping_dag
                 for c in subtask_info.subtask.chunk_graph.result_chunks
             )
-            subtask_info.aio_tasks.append(
-                asyncio.create_task(self._execute_subtask_with_retry(subtask_info))
+            self._create_subtask_with_exception(
+                subtask_id, self._execute_subtask_with_retry(subtask_info)
             )
 
     async def _prepare_subtask_once(self, subtask_info: SubtaskExecutionInfo):
@@ -438,7 +488,20 @@ async def _execute_subtask_with_retry(self, subtask_info: SubtaskExecutionInfo):
                 subtask_info,
                 max_retries=subtask_info.max_retries if subtask.retryable else 0,
             )
-        except:  # noqa: E722  # nosec  # pylint: disable=bare-except
+        except Exception as ex:  # noqa: E722  # nosec  # pylint: disable=bare-except
+            if not subtask.retryable:
+                unretryable_op = [
+                    chunk.op
+                    for chunk in subtask.chunk_graph
+                    if not getattr(chunk.op, "retryable", True)
+                ]
+                logger.exception(
+                    "Run subtask failed due to %r, the subtask %s is "
+                    "not retryable, it contains unretryable op: %r",
+                    ex,
+                    subtask.subtask_id,
+                    unretryable_op,
+                )
             self._fill_result_with_exc(subtask_info)
         finally:
             self._subtask_executions.pop(subtask.subtask_id, None)
@@ -454,18 +517,19 @@ async def _execute_subtask_with_retry(self, subtask_info: SubtaskExecutionInfo):
         return subtask_info.result
 
     @classmethod
-    async def _call_with_retry(
+    def _log_subtask_retry(
         cls,
-        target_func: Callable,
         subtask_info: SubtaskExecutionInfo,
-        max_retries: Optional[int] = None,
+        target_func: Callable,
+        trial: int,
+        exc_info: Tuple,
+        retry: bool = True,
     ):
         subtask = subtask_info.subtask
-        max_retries = max_retries or subtask_info.max_retries
-
-        def log_func(trial: int, exc_info: Tuple, retry: bool = True):
-            subtask_info.num_retries = trial
-            if retry:
+        max_retries = subtask_info.max_retries
+        subtask_info.num_retries = trial
+        if retry:
+            if trial < max_retries - 1:
                 logger.error(
                     "Rerun %s of subtask %s at attempt %d due to %s",
                     target_func,
@@ -475,22 +539,42 @@ def log_func(trial: int, exc_info: Tuple, retry: bool = True):
                 )
             else:
                 logger.exception(
-                    "Failed to rerun the %s of subtask %s, "
-                    "num_retries: %s, max_retries: %s",
-                    target_func,
-                    subtask.subtask_id,
+                    "Exceed max rerun (%s / %s): %s of subtask %s due to %s",
                     trial + 1,
                     max_retries,
+                    target_func,
+                    subtask.subtask_id,
+                    exc_info[1],
                     exc_info=exc_info,
                 )
+        else:
+            logger.exception(
+                "Failed to rerun %s of subtask %s due to unhandled exception: %s",
+                target_func,
+                subtask.subtask_id,
+                exc_info[1],
+                exc_info=exc_info,
+            )
+
+    @classmethod
+    async def _call_with_retry(
+        cls,
+        target_func: Callable,
+        subtask_info: SubtaskExecutionInfo,
+        max_retries: Optional[int] = None,
+    ):
+        subtask_info.max_retries = max_retries or subtask_info.max_retries
 
         if subtask_info.max_retries <= 1:
             return await target_func(subtask_info)
         else:
+            err_callback = functools.partial(
+                cls._log_subtask_retry, subtask_info, target_func
+            )
             return await call_with_retry(
                 functools.partial(target_func, subtask_info),
                 max_retries=max_retries,
-                error_callback=log_func,
+                error_callback=err_callback,
             )
 
     @classmethod
diff --git a/mars/services/scheduling/worker/exec/core.py b/mars/services/scheduling/worker/execution/core.py
similarity index 81%
rename from mars/services/scheduling/worker/exec/core.py
rename to mars/services/scheduling/worker/execution/core.py
index dcdcb891de..01cab8e170 100644
--- a/mars/services/scheduling/worker/exec/core.py
+++ b/mars/services/scheduling/worker/execution/core.py
@@ -58,12 +58,17 @@ async def call_with_retry(
         try:
             return await async_fun()
         except (OSError, MarsError):
+            exc_info_raw = sys.exc_info()
+            exc_info = error_callback(trial=trial, exc_info=exc_info_raw, retry=True)
+            exc_info = exc_info or exc_info_raw
+
             if trial >= max_retries - 1:
-                error_callback(trial=trial, exc_info=sys.exc_info(), retry=False)
-                raise
-            error_callback(trial=trial, exc_info=sys.exc_info(), retry=True)
+                raise exc_info[1].with_traceback(exc_info[-1])
         except asyncio.CancelledError:
             raise
         except:  # noqa: E722  # nosec  # pylint: disable=bare-except
-            error_callback(trial=trial, exc_info=sys.exc_info(), retry=False)
-            raise
+            exc_info_raw = sys.exc_info()
+            exc_info = error_callback(trial=trial, exc_info=exc_info_raw, retry=False)
+            exc_info = exc_info or exc_info_raw
+
+            raise exc_info[1].with_traceback(exc_info[-1])
diff --git a/mars/services/scheduling/worker/exec/prepare.py b/mars/services/scheduling/worker/execution/prepare.py
similarity index 100%
rename from mars/services/scheduling/worker/exec/prepare.py
rename to mars/services/scheduling/worker/execution/prepare.py
diff --git a/mars/services/scheduling/worker/exec/tests/__init__.py b/mars/services/scheduling/worker/execution/tests/__init__.py
similarity index 100%
rename from mars/services/scheduling/worker/exec/tests/__init__.py
rename to mars/services/scheduling/worker/execution/tests/__init__.py
diff --git a/mars/services/scheduling/worker/exec/tests/test_exec.py b/mars/services/scheduling/worker/execution/tests/test_exec.py
similarity index 100%
rename from mars/services/scheduling/worker/exec/tests/test_exec.py
rename to mars/services/scheduling/worker/execution/tests/test_exec.py
diff --git a/mars/services/scheduling/worker/exec/tests/test_prepare.py b/mars/services/scheduling/worker/execution/tests/test_prepare.py
similarity index 100%
rename from mars/services/scheduling/worker/exec/tests/test_prepare.py
rename to mars/services/scheduling/worker/execution/tests/test_prepare.py
diff --git a/mars/services/scheduling/worker/service.py b/mars/services/scheduling/worker/service.py
index 1ed6e4eaae..a71694d7c7 100644
--- a/mars/services/scheduling/worker/service.py
+++ b/mars/services/scheduling/worker/service.py
@@ -15,10 +15,10 @@
 from .... import oscar as mo
 from ....utils import calc_size_by_str
 from ...core import AbstractService
+from .execution import SubtaskExecutionActor
 from .slotmanager import SlotManagerActor
 from .queues import SubtaskPrepareQueueActor, SubtaskExecutionQueueActor
 from .quota import WorkerQuotaManagerActor
-from .exec import SubtaskExecutionActor
 
 
 class SchedulingWorkerService(AbstractService):
diff --git a/mars/services/task/supervisor/stage.py b/mars/services/task/supervisor/stage.py
index 4afd314f16..9d6a5aa64a 100644
--- a/mars/services/task/supervisor/stage.py
+++ b/mars/services/task/supervisor/stage.py
@@ -152,9 +152,6 @@ async def set_subtask_result(self, result: SubtaskResult):
                 await self._update_chunks_meta(self.chunk_graph)
 
             # tell scheduling to finish subtasks
-            await self._scheduling_api.finish_subtasks(
-                [result.subtask_id], schedule_next=not error_or_cancelled
-            )
             if self.result.status != TaskStatus.terminated:
                 self.result = TaskResult(
                     self.task.task_id,
@@ -196,7 +193,6 @@ async def set_subtask_result(self, result: SubtaskResult):
                     # all predecessors finished
                     to_schedule_subtasks.append(succ_subtask)
             await self._schedule_subtasks(to_schedule_subtasks)
-            await self._scheduling_api.finish_subtasks([result.subtask_id])
 
     async def run(self):
         if len(self.subtask_graph) == 0:
diff --git a/mars/services/tests/fault_injection_manager.py b/mars/services/tests/fault_injection_manager.py
index 06fa7fb310..0b29d10881 100644
--- a/mars/services/tests/fault_injection_manager.py
+++ b/mars/services/tests/fault_injection_manager.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import enum
+import logging
 import os
 import uuid
 from abc import ABC, abstractmethod
@@ -20,6 +21,8 @@
 from ...core.base import MarsError
 from ..session import SessionAPI
 
+logger = logging.getLogger(__name__)
+
 
 class ExtraConfigKey:
     FAULT_INJECTION_MANAGER_NAME = "fault_injection_manager_name"
diff --git a/mars/services/tests/fault_injection_patch.py b/mars/services/tests/fault_injection_patch.py
index 6c410a4278..7e2e07fe68 100644
--- a/mars/services/tests/fault_injection_patch.py
+++ b/mars/services/tests/fault_injection_patch.py
@@ -12,15 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Union, Any, Dict
+import sys
+from typing import Any, Callable, Dict, Tuple, Union
 
 from ... import oscar as mo
 from ...core import OperandType
 from ...lib.aio import alru_cache
 from ...tests.core import patch_cls, patch_super as super
 from ..session import SessionAPI
-from ..scheduling.worker.exec import SubtaskExecutionActor
-from ..subtask import Subtask
+from ..scheduling.worker.execution import SubtaskExecutionActor, SubtaskExecutionInfo
 from ..subtask.worker.processor import SubtaskProcessor
 from ..tests.fault_injection_manager import (
     AbstractFaultInjectionManager,
@@ -44,14 +44,38 @@ async def _get_fault_injection_manager_ref(
     async def _get_session_api(supervisor_address: str):
         return await SessionAPI.create(supervisor_address)
 
-    async def internal_run_subtask(self, subtask: Subtask, band_name: str):
+    async def _execute_subtask_once(self, subtask_info: SubtaskExecutionInfo):
+        try:
+            return await super()._execute_subtask_once(subtask_info)
+        except:  # noqa: E722  # nosec  # pylint: disable=bare-except
+            exc_info = sys.exc_info()
+            subtask = subtask_info.subtask
+            if not subtask.retryable:
+                unretryable_op = [
+                    chunk.op
+                    for chunk in subtask.chunk_graph
+                    if not getattr(chunk.op, "retryable", True)
+                ]
+                message = (
+                    f"Run subtask failed due to {exc_info[1]}, "
+                    f"the subtask {subtask.subtask_id} is not retryable, "
+                    f"it contains unretryable op: {unretryable_op!r}"
+                )
+                _UnretryableException = type(
+                    "_UnretryableException", (exc_info[0],), {}
+                )
+                raise _UnretryableException(message).with_traceback(exc_info[-1])
+            else:
+                raise
+
+    async def _execute_subtask_with_retry(self, subtask_info: SubtaskExecutionInfo):
+        subtask = subtask_info.subtask
         # fault injection
         if subtask.extra_config:
             fault_injection_manager_name = subtask.extra_config.get(
                 ExtraConfigKey.FAULT_INJECTION_MANAGER_NAME
             )
             if fault_injection_manager_name is not None:
-                subtask_info = self._subtask_info[subtask.subtask_id]
                 fault_injection_manager = await self._get_fault_injection_manager_ref(
                     subtask_info.supervisor_address,
                     subtask.session_id,
@@ -61,7 +85,41 @@ async def internal_run_subtask(self, subtask: Subtask, band_name: str):
                     FaultPosition.ON_RUN_SUBTASK, {"subtask": subtask}
                 )
                 handle_fault(fault)
-        return super().internal_run_subtask(subtask, band_name)
+        return await super()._execute_subtask_with_retry(subtask_info)
+
+    @classmethod
+    def _log_subtask_retry(
+        cls,
+        subtask_info: SubtaskExecutionInfo,
+        target_func: Callable,
+        trial: int,
+        exc_info: Tuple,
+        retry: bool = True,
+    ):
+        exc_info = (
+            super()._log_subtask_retry(
+                subtask_info, target_func, trial, exc_info, retry=retry
+            )
+            or exc_info
+        )
+
+        if retry:
+            if trial < subtask_info.max_retries - 1:
+                return exc_info
+            else:
+                _ExceedMaxRerun = type("_ExceedMaxRerun", (exc_info[0],), {})
+                return (
+                    _ExceedMaxRerun,
+                    _ExceedMaxRerun(str(exc_info[1])).with_traceback(exc_info[-1]),
+                    exc_info[-1],
+                )
+        else:
+            _UnhandledException = type("_UnhandledException", (exc_info[0],), {})
+            return (
+                _UnhandledException,
+                _UnhandledException(str(exc_info[1])).with_traceback(exc_info[-1]),
+                exc_info[-1],
+            )
 
 
 @patch_cls(SubtaskProcessor)