Skip to content

Commit

Permalink
Add basic retry loop to account for max_submit functionality
Browse files Browse the repository at this point in the history
Use a retry loop to iterate from the running state back to the waiting state.
It includes a simple test that checks whether a job has been started several
times. max_submit is a parameter of Job.__call__ that is passed on from the scheduler.

Additionally, the function driver.finish will implement the basic clean-up
functionality. For the local driver, it makes sure that all tasks have
been awaited correctly.
  • Loading branch information
xjules committed Dec 15, 2023
1 parent 25bd439 commit 37bf484
Show file tree
Hide file tree
Showing 5 changed files with 70 additions and 20 deletions.
9 changes: 5 additions & 4 deletions src/ert/scheduler/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,7 @@
import asyncio
from abc import ABC, abstractmethod
from enum import Enum
from typing import (
Optional,
Tuple,
)
from typing import Optional, Tuple


class JobEvent(Enum):
Expand Down Expand Up @@ -48,3 +45,7 @@ async def kill(self, iens: int) -> None:
@abstractmethod
async def poll(self) -> None:
    """Check for new job events.

    Abstract: every concrete driver has to provide its own implementation.
    """

@abstractmethod
async def finish(self) -> None:
    """Ensure that every job / realization has completed.

    Abstract: every concrete driver has to provide its own implementation.
    """
32 changes: 27 additions & 5 deletions src/ert/scheduler/job.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import asyncio
import logging
from enum import Enum
from typing import TYPE_CHECKING

Expand All @@ -16,6 +17,8 @@
from ert.ensemble_evaluator._builder._realization import Realization
from ert.scheduler.scheduler import Scheduler

logger = logging.getLogger(__name__)


class State(str, Enum):
WAITING = "WAITING"
Expand Down Expand Up @@ -62,12 +65,8 @@ def iens(self) -> int:
def driver(self) -> Driver:
    """Return the driver held by this job's scheduler."""
    scheduler = self._scheduler
    return scheduler.driver

async def __call__(
self, start: asyncio.Event, sem: asyncio.BoundedSemaphore
) -> None:
await start.wait()
async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None:
await sem.acquire()

try:
await self._send(State.SUBMITTING)
await self.driver.submit(
Expand All @@ -81,6 +80,7 @@ async def __call__(
while not self.returncode.done():
await asyncio.sleep(0.01)
returncode = await self.returncode

if (
returncode == 0
and forward_model_ok(self.real.run_arg).status
Expand All @@ -89,6 +89,8 @@ async def __call__(
await self._send(State.COMPLETED)
else:
await self._send(State.FAILED)
self.returncode = asyncio.Future()
self.started = asyncio.Event()

except asyncio.CancelledError:
await self._send(State.ABORTING)
Expand All @@ -99,6 +101,26 @@ async def __call__(
finally:
sem.release()

async def __call__(
    self, start: asyncio.Event, sem: asyncio.BoundedSemaphore, max_submit: int = 2
) -> None:
    """Run this job, resubmitting it after a failure up to ``max_submit`` times.

    Args:
        start: Event that is awaited before the first submission.
        sem: Semaphore handed to every ``_submit_and_run_once`` call.
        max_submit: Maximum number of submission attempts.
    """
    await start.wait()

    for attempt in range(1, max_submit + 1):
        await self._submit_and_run_once(sem)

        # A resolved return code or an abort ends the retry loop.
        # NOTE(review): this relies on ``_submit_and_run_once`` leaving
        # ``self.returncode`` unresolved after a failed run -- confirm.
        if self.returncode.done() or self.aborted.is_set():
            break

        # Only announce a resubmission when another attempt really follows;
        # after the last attempt the error below is the only message.
        if attempt < max_submit:
            message = f"Realization: {self.iens} failed, resubmitting"
            logger.warning(message)
    else:
        # The loop ran out of attempts without reaching ``break``.
        message = (
            f"Realization: {self.iens} "
            f"failed after reaching max submit {max_submit}"
        )
        logger.error(message)

async def _send(self, state: State) -> None:
status = STATE_TO_LEGACY[state]
event = CloudEvent(
Expand Down
6 changes: 6 additions & 0 deletions src/ert/scheduler/local_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,22 @@ def __init__(self) -> None:
self._tasks: MutableMapping[int, asyncio.Task[None]] = {}

async def submit(self, iens: int, executable: str, /, *args: str, cwd: str) -> None:
    """Start ``executable`` for realization ``iens`` as a new asyncio task.

    ``self.kill(iens)`` is awaited first, then the new task is stored in
    ``self._tasks`` under ``iens``.
    """
    await self.kill(iens)
    runner = self._wait_until_finish(iens, executable, *args, cwd=cwd)
    self._tasks[iens] = asyncio.create_task(runner)

async def kill(self, iens: int) -> None:
    """Cancel the task registered for ``iens``, await it and forget it.

    Returns silently when a ``KeyError`` is raised, e.g. because no task is
    registered for ``iens``.
    """
    try:
        task = self._tasks[iens]
        task.cancel()
        await task
        del self._tasks[iens]
    except KeyError:
        return

async def finish(self) -> None:
    """Await every task currently stored in ``self._tasks``."""
    outstanding = self._tasks.values()
    await asyncio.gather(*outstanding)

async def _wait_until_finish(
self, iens: int, executable: str, /, *args: str, cwd: str
) -> None:
Expand Down
21 changes: 11 additions & 10 deletions src/ert/scheduler/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,7 @@
import ssl
import threading
from dataclasses import asdict
from typing import (
TYPE_CHECKING,
Any,
Callable,
Iterable,
MutableMapping,
Optional,
)
from typing import TYPE_CHECKING, Any, Callable, Iterable, MutableMapping, Optional

from pydantic.dataclasses import dataclass
from websockets import Headers
Expand Down Expand Up @@ -53,6 +46,8 @@ def __init__(self, driver: Optional[Driver] = None) -> None:

self._events: Optional[asyncio.Queue[Any]] = None
self._cancelled = False
# will be read from QueueConfig
self._max_submit: int = 2

self._ee_uri = ""
self._ens_id = ""
Expand Down Expand Up @@ -131,14 +126,20 @@ async def execute(
cancel_when_execute_is_done(self.driver.poll())

start = asyncio.Event()
sem = asyncio.BoundedSemaphore(semaphore._initial_value if semaphore else 10) # type: ignore
sem = asyncio.BoundedSemaphore(
semaphore._initial_value if semaphore else 10 # type: ignore
)
for iens, job in self._jobs.items():
self._tasks[iens] = asyncio.create_task(job(start, sem))
self._tasks[iens] = asyncio.create_task(
job(start, sem, self._max_submit)
)

start.set()
for task in self._tasks.values():
await task

await self.driver.finish()

if self._cancelled:
return EVTYPE_ENSEMBLE_CANCELLED

Expand Down
22 changes: 21 additions & 1 deletion tests/unit_tests/scheduler/test_scheduler.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
import json
import shutil
from dataclasses import asdict
from pathlib import Path
from typing import Sequence

Expand Down Expand Up @@ -108,3 +107,24 @@ async def test_cancel(tmp_path: Path, realization):

assert (tmp_path / "a").exists()
assert not (tmp_path / "b").exists()


@pytest.mark.parametrize("max_submit", [1, 2, 3])
async def test_that_max_submit_was_reached(tmp_path: Path, realization, max_submit):
    """A step that always fails is run exactly ``max_submit`` times."""
    # Every run bumps the counter kept in the file ``cnt`` and then exits with 1.
    script = "[ -f cnt ] && echo $(( $(cat cnt) + 1 )) > cnt || echo 1 > cnt; exit 1"
    failing_step = create_bash_step(script)
    realization.forward_models = [failing_step]

    sch = scheduler.Scheduler()
    sch._max_submit = max_submit
    sch.add_realization(realization, callback_timeout=lambda _: None)

    create_jobs_json(tmp_path, [failing_step])
    sch.add_dispatch_information_to_jobs_file()

    result = await sch.execute()
    assert result == EVTYPE_ENSEMBLE_STOPPED

    # The counter file records how many times the script was started.
    counter_text = (tmp_path / "cnt").read_text()
    assert counter_text == f"{max_submit}\n"

0 comments on commit 37bf484

Please sign in to comment.