diff --git a/src/ert/scheduler/driver.py b/src/ert/scheduler/driver.py index 90baea0a5c0..7c15c51fe9b 100644 --- a/src/ert/scheduler/driver.py +++ b/src/ert/scheduler/driver.py @@ -3,10 +3,7 @@ import asyncio from abc import ABC, abstractmethod from enum import Enum -from typing import ( - Optional, - Tuple, -) +from typing import Optional, Tuple class JobEvent(Enum): @@ -48,3 +45,7 @@ async def kill(self, iens: int) -> None: @abstractmethod async def poll(self) -> None: """Poll for new job events""" + + @abstractmethod + async def finish(self) -> None: + """make sure that all the jobs / realizations are complete.""" diff --git a/src/ert/scheduler/job.py b/src/ert/scheduler/job.py index 022b878f73e..b1a2c2f33d5 100644 --- a/src/ert/scheduler/job.py +++ b/src/ert/scheduler/job.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import logging from enum import Enum from typing import TYPE_CHECKING @@ -16,6 +17,8 @@ from ert.ensemble_evaluator._builder._realization import Realization from ert.scheduler.scheduler import Scheduler +logger = logging.getLogger(__name__) + class State(str, Enum): WAITING = "WAITING" @@ -62,12 +65,8 @@ def iens(self) -> int: def driver(self) -> Driver: return self._scheduler.driver - async def __call__( - self, start: asyncio.Event, sem: asyncio.BoundedSemaphore - ) -> None: - await start.wait() + async def _submit_and_run_once(self, sem: asyncio.BoundedSemaphore) -> None: await sem.acquire() - try: await self._send(State.SUBMITTING) await self.driver.submit( @@ -81,6 +80,7 @@ async def __call__( while not self.returncode.done(): await asyncio.sleep(0.01) returncode = await self.returncode + if ( returncode == 0 and forward_model_ok(self.real.run_arg).status @@ -89,6 +89,8 @@ async def __call__( await self._send(State.COMPLETED) else: await self._send(State.FAILED) + self.returncode = asyncio.Future() + self.started = asyncio.Event() except asyncio.CancelledError: await self._send(State.ABORTING) @@ -99,6 +101,26 @@ async def __call__( finally: sem.release() + async def __call__( + self, start: asyncio.Event, sem: asyncio.BoundedSemaphore, max_submit: int = 2 + ) -> None: + await start.wait() + + for _ in range(max_submit): + await self._submit_and_run_once(sem) + + if self.returncode.done() or self.aborted.is_set(): + break + else: + message = f"Realization: {self.iens} failed, resubmitting" + logger.warning(message) + else: + message = ( + f"Realization: {self.iens} " + f"failed after reaching max submit {max_submit}" + ) + logger.error(message) + async def _send(self, state: State) -> None: status = STATE_TO_LEGACY[state] event = CloudEvent( diff --git a/src/ert/scheduler/local_driver.py b/src/ert/scheduler/local_driver.py index 010da709ae6..e88ae5453be 100644 --- a/src/ert/scheduler/local_driver.py +++ b/src/ert/scheduler/local_driver.py @@ -13,6 +13,7 @@ def __init__(self) -> None: self._tasks: MutableMapping[int, asyncio.Task[None]] = {} async def submit(self, iens: int, executable: str, /, *args: str, cwd: str) -> None: + await self.kill(iens) self._tasks[iens] = asyncio.create_task( self._wait_until_finish(iens, executable, *args, cwd=cwd) ) @@ -20,9 +21,14 @@ async def submit(self, iens: int, executable: str, /, *args: str, cwd: str) -> N async def kill(self, iens: int) -> None: try: self._tasks[iens].cancel() + await self._tasks[iens] + del self._tasks[iens] except KeyError: return + async def finish(self) -> None: + await asyncio.gather(*self._tasks.values()) + async def _wait_until_finish( self, iens: int, executable: str, /, *args: str, cwd: str ) -> None: diff --git a/src/ert/scheduler/scheduler.py b/src/ert/scheduler/scheduler.py index 5b9a66e6874..d0a2bbc5388 100644 --- a/src/ert/scheduler/scheduler.py +++ b/src/ert/scheduler/scheduler.py @@ -7,14 +7,7 @@ import ssl import threading from dataclasses import asdict -from typing import ( - TYPE_CHECKING, - Any, - Callable, - Iterable, - MutableMapping, - Optional, -) +from typing import TYPE_CHECKING, Any, Callable, Iterable, MutableMapping, Optional from pydantic.dataclasses import dataclass from websockets import Headers @@ -53,6 +46,8 @@ def __init__(self, driver: Optional[Driver] = None) -> None: self._events: Optional[asyncio.Queue[Any]] = None self._cancelled = False + # will be read from QueueConfig + self._max_submit: int = 2 self._ee_uri = "" self._ens_id = "" @@ -131,14 +126,20 @@ async def execute( cancel_when_execute_is_done(self.driver.poll()) start = asyncio.Event() - sem = asyncio.BoundedSemaphore(semaphore._initial_value if semaphore else 10) # type: ignore + sem = asyncio.BoundedSemaphore( + semaphore._initial_value if semaphore else 10 # type: ignore + ) for iens, job in self._jobs.items(): - self._tasks[iens] = asyncio.create_task(job(start, sem)) + self._tasks[iens] = asyncio.create_task( + job(start, sem, self._max_submit) + ) start.set() for task in self._tasks.values(): await task + await self.driver.finish() + if self._cancelled: return EVTYPE_ENSEMBLE_CANCELLED diff --git a/tests/unit_tests/scheduler/test_scheduler.py b/tests/unit_tests/scheduler/test_scheduler.py index 3b071e109a4..45ec2cc2aaf 100644 --- a/tests/unit_tests/scheduler/test_scheduler.py +++ b/tests/unit_tests/scheduler/test_scheduler.py @@ -1,7 +1,6 @@ import asyncio import json import shutil -from dataclasses import asdict from pathlib import Path from typing import Sequence @@ -108,3 +107,24 @@ async def test_cancel(tmp_path: Path, realization): assert (tmp_path / "a").exists() assert not (tmp_path / "b").exists() + + +@pytest.mark.parametrize( + "max_submit", + [ + (1), + (2), + (3), + ], +) +async def test_that_max_submit_was_reached(tmp_path: Path, realization, max_submit): + script = "[ -f cnt ] && echo $(( $(cat cnt) + 1 )) > cnt || echo 1 > cnt; exit 1" + step = create_bash_step(script) + realization.forward_models = [step] + sch = scheduler.Scheduler() + sch._max_submit = max_submit + sch.add_realization(realization, callback_timeout=lambda _: None) + create_jobs_json(tmp_path, [step]) + sch.add_dispatch_information_to_jobs_file() + assert await sch.execute() == EVTYPE_ENSEMBLE_STOPPED + assert (tmp_path / "cnt").read_text() == f"{max_submit}\n"