diff --git a/parsl/executors/high_throughput/executor.py b/parsl/executors/high_throughput/executor.py index 7f4408eb1c..69af77a86b 100644 --- a/parsl/executors/high_throughput/executor.py +++ b/parsl/executors/high_throughput/executor.py @@ -15,7 +15,6 @@ from parsl.serialize import pack_apply_message, deserialize from parsl.serialize.errors import SerializationError, DeserializationError from parsl.app.errors import RemoteExceptionWrapper -from parsl.executors.base import ParslExecutor from parsl.jobs.states import JobStatus from parsl.executors.high_throughput import zmq_pipes from parsl.executors.high_throughput import interchange @@ -214,7 +213,7 @@ def __init__(self, poll_period: int = 10, address_probe_timeout: Optional[int] = None, worker_logdir_root: Optional[str] = None, - block_error_handler: Union[bool, Callable[[ParslExecutor, Dict[str, JobStatus]], None]] = False): + block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]] = True): logger.debug("Initializing HighThroughputExecutor") diff --git a/parsl/executors/status_handling.py b/parsl/executors/status_handling.py index e5988bec91..16450d472d 100644 --- a/parsl/executors/status_handling.py +++ b/parsl/executors/status_handling.py @@ -1,3 +1,4 @@ +from __future__ import annotations import logging import threading from itertools import compress @@ -9,6 +10,7 @@ from parsl.executors.base import ParslExecutor from parsl.executors.errors import BadStateException, ScalingFailed from parsl.jobs.states import JobStatus, JobState +from parsl.jobs.simple_error_handler import simple_error_handler, noop_error_handler from parsl.providers.base import ExecutionProvider from parsl.utils import AtomicIDCounter @@ -45,10 +47,18 @@ class BlockProviderExecutor(ParslExecutor): """ def __init__(self, *, provider: Optional[ExecutionProvider], - block_error_handler: Union[bool, Callable[[ParslExecutor, Dict[str, JobStatus]], None]]): + block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]]): super().__init__() self._provider = provider - self.block_error_handler = block_error_handler + self.block_error_handler: Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None] + if isinstance(block_error_handler, bool): + if block_error_handler: + self.block_error_handler = simple_error_handler + else: + self.block_error_handler = noop_error_handler + else: + self.block_error_handler = block_error_handler + # errors can happen during the submit call to the provider; this is used # to keep track of such errors so that they can be handled in one place # together with errors reported by status() @@ -159,9 +169,6 @@ def handle_errors(self, status: Dict[str, JobStatus]) -> None: scheme will be used. :param status: status of all jobs launched by this executor """ - if not self.block_error_handler: - return - assert isinstance(Callable, self.block_error_handler) self.block_error_handler(self, status) @property diff --git a/parsl/jobs/simple_error_handler.py b/parsl/jobs/simple_error_handler.py index 550c883787..2bd91a8c4b 100644 --- a/parsl/jobs/simple_error_handler.py +++ b/parsl/jobs/simple_error_handler.py @@ -6,14 +6,21 @@ from parsl.jobs.states import JobStatus, JobState +def noop_error_handler(executor: status_handling.BlockProviderExecutor, status: Dict[str, JobStatus], threshold: int = 3): + pass + + def simple_error_handler(executor: status_handling.BlockProviderExecutor, status: Dict[str, JobStatus], threshold: int = 3): (total_jobs, failed_jobs) = _count_jobs(status) + if hasattr(executor.provider, "init_blocks"): + threshold = max(1, executor.provider.init_blocks) + if total_jobs >= threshold and failed_jobs == total_jobs: executor.set_bad_state_and_fail_all(_get_error(status)) def windowed_error_handler(executor: status_handling.BlockProviderExecutor, status: Dict[str, JobStatus], threshold: int = 3): - sorted_status = [(key, status[key]) for key in sorted(status)] + sorted_status = [(key, status[key]) for key in sorted(status, key=lambda x: int(x))] current_window = dict(sorted_status[-threshold:]) total, failed = _count_jobs(current_window) if failed == threshold: diff --git a/parsl/tests/test_scaling/test_block_error_handler.py b/parsl/tests/test_scaling/test_block_error_handler.py index ed9d476190..37241eb756 100644 --- a/parsl/tests/test_scaling/test_block_error_handler.py +++ b/parsl/tests/test_scaling/test_block_error_handler.py @@ -1,9 +1,10 @@ import pytest from parsl.executors import HighThroughputExecutor +from parsl.providers import LocalProvider from unittest.mock import Mock from parsl.jobs.states import JobStatus, JobState -from parsl.jobs.simple_error_handler import simple_error_handler, windowed_error_handler +from parsl.jobs.simple_error_handler import simple_error_handler, windowed_error_handler, noop_error_handler from functools import partial @@ -11,7 +12,7 @@ def test_block_error_handler_false(): mock = Mock() htex = HighThroughputExecutor(block_error_handler=False) - assert htex.block_error_handler is False + assert htex.block_error_handler is noop_error_handler htex.set_bad_state_and_fail_all = mock bad_jobs = {'1': JobStatus(JobState.FAILED), @@ -44,7 +45,9 @@ def test_block_error_handler_mock(): @pytest.mark.local def test_simple_error_handler(): - htex = HighThroughputExecutor(block_error_handler=simple_error_handler) + htex = HighThroughputExecutor(block_error_handler=simple_error_handler, + provider=LocalProvider(init_blocks=3)) + assert htex.block_error_handler is simple_error_handler bad_state_mock = Mock() @@ -108,6 +111,31 @@ def test_windowed_error_handler(): bad_state_mock.assert_called() +@pytest.mark.local +def test_windowed_error_handler_sorting(): + htex = HighThroughputExecutor(block_error_handler=windowed_error_handler) + assert htex.block_error_handler is windowed_error_handler + + bad_state_mock = Mock() + htex.set_bad_state_and_fail_all = bad_state_mock + + bad_jobs = {'8': JobStatus(JobState.FAILED), + '9': JobStatus(JobState.FAILED), + '10': JobStatus(JobState.FAILED), + '11': JobStatus(JobState.COMPLETED), + '12': JobStatus(JobState.COMPLETED)} + htex.handle_errors(bad_jobs) + bad_state_mock.assert_not_called() + + bad_jobs = {'8': JobStatus(JobState.COMPLETED), + '9': JobStatus(JobState.FAILED), + '21': JobStatus(JobState.FAILED), + '22': JobStatus(JobState.FAILED), + '10': JobStatus(JobState.FAILED)} + htex.handle_errors(bad_jobs) + bad_state_mock.assert_called() + + @pytest.mark.local def test_windowed_error_handler_with_threshold(): error_handler = partial(windowed_error_handler, threshold=2)