Skip to content

Commit

Permalink
Several fixes from Ben's review
Browse files Browse the repository at this point in the history
* Type cleanups
* Updating tests
* Fixing a sorting error
* Adding a `noop_error_handler`
* Minor fixes to simple_error_handler to match previous logic
  • Loading branch information
yadudoc committed Aug 11, 2023
1 parent d20ee05 commit 1e78474
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 11 deletions.
3 changes: 1 addition & 2 deletions parsl/executors/high_throughput/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from parsl.serialize import pack_apply_message, deserialize
from parsl.serialize.errors import SerializationError, DeserializationError
from parsl.app.errors import RemoteExceptionWrapper
from parsl.executors.base import ParslExecutor
from parsl.jobs.states import JobStatus
from parsl.executors.high_throughput import zmq_pipes
from parsl.executors.high_throughput import interchange
Expand Down Expand Up @@ -214,7 +213,7 @@ def __init__(self,
poll_period: int = 10,
address_probe_timeout: Optional[int] = None,
worker_logdir_root: Optional[str] = None,
block_error_handler: Union[bool, Callable[[ParslExecutor, Dict[str, JobStatus]], None]] = False):
block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]] = True):

logger.debug("Initializing HighThroughputExecutor")

Expand Down
17 changes: 12 additions & 5 deletions parsl/executors/status_handling.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from __future__ import annotations
import logging
import threading
from itertools import compress
Expand All @@ -9,6 +10,7 @@
from parsl.executors.base import ParslExecutor
from parsl.executors.errors import BadStateException, ScalingFailed
from parsl.jobs.states import JobStatus, JobState
from parsl.jobs.simple_error_handler import simple_error_handler, noop_error_handler
from parsl.providers.base import ExecutionProvider
from parsl.utils import AtomicIDCounter

Expand Down Expand Up @@ -45,10 +47,18 @@ class BlockProviderExecutor(ParslExecutor):
"""
def __init__(self, *,
provider: Optional[ExecutionProvider],
block_error_handler: Union[bool, Callable[[ParslExecutor, Dict[str, JobStatus]], None]]):
block_error_handler: Union[bool, Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]]):
super().__init__()
self._provider = provider
self.block_error_handler = block_error_handler
self.block_error_handler: Callable[[BlockProviderExecutor, Dict[str, JobStatus]], None]
if isinstance(block_error_handler, bool):
if block_error_handler:
self.block_error_handler = simple_error_handler
else:
self.block_error_handler = noop_error_handler
else:
self.block_error_handler = block_error_handler

# errors can happen during the submit call to the provider; this is used
# to keep track of such errors so that they can be handled in one place
# together with errors reported by status()
Expand Down Expand Up @@ -159,9 +169,6 @@ def handle_errors(self, status: Dict[str, JobStatus]) -> None:
scheme will be used.
:param status: status of all jobs launched by this executor
"""
if not self.block_error_handler:
return
assert isinstance(Callable, self.block_error_handler)
self.block_error_handler(self, status)

@property
Expand Down
9 changes: 8 additions & 1 deletion parsl/jobs/simple_error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,21 @@
from parsl.jobs.states import JobStatus, JobState


def noop_error_handler(executor: status_handling.BlockProviderExecutor, status: Dict[str, JobStatus], threshold: int = 3):
    """Block error handler that intentionally takes no action.

    It accepts the same arguments as the other handlers in this module so
    it can be installed wherever a handler callable is expected (for
    example when block error handling is switched off).

    :param executor: executor whose blocks are being reported on; unused.
    :param status: mapping of block/job id to its status; unused.
    :param threshold: accepted for signature parity only; unused.
    :return: always ``None``.
    """
    return None


def simple_error_handler(executor: status_handling.BlockProviderExecutor, status: Dict[str, JobStatus], threshold: int = 3):
    """Fail the executor once enough jobs exist and all of them have failed.

    :param executor: executor to put into the bad state; its ``provider``
        is also inspected for an ``init_blocks`` attribute.
    :param status: mapping of job id to status, tallied by ``_count_jobs``.
    :param threshold: minimum number of jobs before the check can trigger.
        NOTE: this value is overridden with ``max(1, init_blocks)``
        whenever the provider exposes ``init_blocks``.
    """
    job_counts = _count_jobs(status)
    total_jobs, failed_jobs = job_counts

    # A provider-declared init_blocks takes precedence over the caller's
    # threshold; the floor of 1 keeps init_blocks=0 from triggering on an
    # empty status report.
    if hasattr(executor.provider, "init_blocks"):
        threshold = max(1, executor.provider.init_blocks)

    every_job_failed = failed_jobs == total_jobs
    if not (total_jobs >= threshold and every_job_failed):
        return

    executor.set_bad_state_and_fail_all(_get_error(status))


def windowed_error_handler(executor: status_handling.BlockProviderExecutor, status: Dict[str, JobStatus], threshold: int = 3):
sorted_status = [(key, status[key]) for key in sorted(status)]
sorted_status = [(key, status[key]) for key in sorted(status, key=lambda x: int(x))]
current_window = dict(sorted_status[-threshold:])
total, failed = _count_jobs(current_window)
if failed == threshold:
Expand Down
34 changes: 31 additions & 3 deletions parsl/tests/test_scaling/test_block_error_handler.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,18 @@
import pytest

from parsl.executors import HighThroughputExecutor
from parsl.providers import LocalProvider
from unittest.mock import Mock
from parsl.jobs.states import JobStatus, JobState
from parsl.jobs.simple_error_handler import simple_error_handler, windowed_error_handler
from parsl.jobs.simple_error_handler import simple_error_handler, windowed_error_handler, noop_error_handler
from functools import partial


@pytest.mark.local
def test_block_error_handler_false():
mock = Mock()
htex = HighThroughputExecutor(block_error_handler=False)
assert htex.block_error_handler is False
assert htex.block_error_handler is noop_error_handler
htex.set_bad_state_and_fail_all = mock

bad_jobs = {'1': JobStatus(JobState.FAILED),
Expand Down Expand Up @@ -44,7 +45,9 @@ def test_block_error_handler_mock():

@pytest.mark.local
def test_simple_error_handler():
htex = HighThroughputExecutor(block_error_handler=simple_error_handler)
htex = HighThroughputExecutor(block_error_handler=simple_error_handler,
provider=LocalProvider(init_blocks=3))

assert htex.block_error_handler is simple_error_handler

bad_state_mock = Mock()
Expand Down Expand Up @@ -108,6 +111,31 @@ def test_windowed_error_handler():
bad_state_mock.assert_called()


@pytest.mark.local
def test_windowed_error_handler_sorting():
    """Check that the windowed handler orders job ids numerically.

    The ids are strings, so a plain string sort would place '10' before
    '8' and '9'. The second scenario only trips the handler when the most
    recent window is chosen by integer order of the ids.
    """
    htex = HighThroughputExecutor(block_error_handler=windowed_error_handler)
    assert htex.block_error_handler is windowed_error_handler

    fail_all_spy = Mock()
    htex.set_bad_state_and_fail_all = fail_all_spy

    # Newest ids by numeric order ('10', '11', '12') include completed
    # jobs, so the executor must not be failed.
    first_report = [('8', JobState.FAILED),
                    ('9', JobState.FAILED),
                    ('10', JobState.FAILED),
                    ('11', JobState.COMPLETED),
                    ('12', JobState.COMPLETED)]
    htex.handle_errors({job_id: JobStatus(state) for job_id, state in first_report})
    fail_all_spy.assert_not_called()

    # Ids are inserted out of order on purpose; numerically the newest
    # three ('10', '21', '22') have all failed, so the handler should fire.
    second_report = [('8', JobState.COMPLETED),
                     ('9', JobState.FAILED),
                     ('21', JobState.FAILED),
                     ('22', JobState.FAILED),
                     ('10', JobState.FAILED)]
    htex.handle_errors({job_id: JobStatus(state) for job_id, state in second_report})
    fail_all_spy.assert_called()


@pytest.mark.local
def test_windowed_error_handler_with_threshold():
error_handler = partial(windowed_error_handler, threshold=2)
Expand Down

0 comments on commit 1e78474

Please sign in to comment.