Skip to content

Commit

Permalink
Add basic retry loop to account for max_submit functionality
Browse files Browse the repository at this point in the history
Use while retry to iterate from running to waiting states. It includes a
simple test to check if job has started 3 times.
  • Loading branch information
xjules committed Dec 12, 2023
1 parent 374c4b7 commit 6de8e2d
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 39 deletions.
83 changes: 53 additions & 30 deletions src/ert/scheduler/job.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import logging
from enum import Enum
from typing import TYPE_CHECKING

Expand All @@ -16,6 +17,8 @@
from ert.ensemble_evaluator._builder._realization import Realization
from ert.scheduler.scheduler import Scheduler

logger = logging.getLogger(__name__)


class State(str, Enum):
WAITING = "WAITING"
Expand Down Expand Up @@ -66,36 +69,56 @@ async def __call__(
self, start: asyncio.Event, sem: asyncio.BoundedSemaphore
) -> None:
await start.wait()
await sem.acquire()

try:
await self._send(State.SUBMITTING)
await self.driver.submit(
self.real.iens, self.real.job_script, cwd=self.real.run_arg.runpath
)

await self._send(State.STARTING)
await self.started.wait()

await self._send(State.RUNNING)
returncode = await self.returncode
if (
returncode == 0
and forward_model_ok(self.real.run_arg).status
== LoadStatus.LOAD_SUCCESSFUL
):
await self._send(State.COMPLETED)
else:
await self._send(State.FAILED)

except asyncio.CancelledError:
await self._send(State.ABORTING)
await self.driver.kill(self.iens)

await self.aborted.wait()
await self._send(State.ABORTED)
finally:
sem.release()
retries = 0
retry: bool = True
while retry:
retry = False
await sem.acquire()
try:
await self._send(State.SUBMITTING)
await self.driver.submit(
self.real.iens, self.real.job_script, cwd=self.real.run_arg.runpath
)

await self._send(State.STARTING)
await self.started.wait()

await self._send(State.RUNNING)
returncode = await self.returncode
# we need to make sure that the task has finished too
await self.driver.wait_to_finish(self.real.iens)

if (
returncode == 0
and forward_model_ok(self.real.run_arg).status
== LoadStatus.LOAD_SUCCESSFUL
):
await self._send(State.COMPLETED)
else:
await self._send(State.FAILED)
retries += 1
retry = retries < self._scheduler._max_submit
if retry:
message = f"Realization: {self.iens} failed, resubmitting"
logger.warning(message)
print(message)
else:
message = (
f"Realization: {self.iens} "
"failed after reaching max submit "
f"{self._scheduler._max_submit}!"
)
print(message)
logger.error(message)

except asyncio.CancelledError:
await self._send(State.ABORTING)
await self.driver.kill(self.iens)

await self.aborted.wait()
await self._send(State.ABORTED)
finally:
sem.release()

async def _send(self, state: State) -> None:
status = STATE_TO_LEGACY[state]
Expand Down
4 changes: 4 additions & 0 deletions src/ert/scheduler/local_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,10 @@ async def kill(self, iens: int) -> None:
except KeyError:
return

async def wait_to_finish(self, iens: int):
# we might need to do some timeout here I guess
await self._tasks[iens]

async def _wait_until_finish(
self, iens: int, executable: str, /, *args: str, cwd: str
) -> None:
Expand Down
10 changes: 2 additions & 8 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,7 @@
import ssl
import threading
from dataclasses import asdict
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
MutableMapping,
Optional,
)
from typing import TYPE_CHECKING, Any, Callable, Iterable, MutableMapping, Optional

from pydantic.dataclasses import dataclass
from websockets import Headers
Expand Down Expand Up @@ -51,6 +44,7 @@ def __init__(self, driver: Optional[Driver] = None) -> None:
self._tasks: MutableMapping[int, asyncio.Task[None]] = {}

self._events: Optional[asyncio.Queue[Any]] = None
self._max_submit: int = 2

self._ee_uri = ""
self._ens_id = ""
Expand Down
17 changes: 16 additions & 1 deletion tests/unit_tests/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
import asyncio
import json
import os
import shutil
from dataclasses import asdict
from pathlib import Path
from textwrap import dedent
from typing import Sequence
from unittest.mock import patch

import pytest

Expand Down Expand Up @@ -108,3 +110,16 @@ async def test_cancel(tmp_path: Path, realization):

assert (tmp_path / "a").exists()
assert not (tmp_path / "b").exists()


async def test_that_max_submit_was_reached(tmp_path: Path, realization):
script = "[ -f cnt ] && echo $(( $(cat cnt) + 1 )) > cnt || echo 1 > cnt; exit 1"
step = create_bash_step(script)
realization.forward_models = [step]
sch = scheduler.Scheduler()
sch._max_submit = 3
sch.add_realization(realization, callback_timeout=lambda _: None)
create_jobs_json(tmp_path, [step])
sch.add_dispatch_information_to_jobs_file()
assert await sch.execute() == EVTYPE_ENSEMBLE_STOPPED
assert (tmp_path / "cnt").read_text() == "3\n"

0 comments on commit 6de8e2d

Please sign in to comment.