Skip to content

Commit

Permalink
Add timeout for server
Browse files Browse the repository at this point in the history
  • Loading branch information
eivindjahren committed Nov 1, 2024
1 parent c34060c commit 87136c7
Show file tree
Hide file tree
Showing 4 changed files with 68 additions and 8 deletions.
20 changes: 14 additions & 6 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@ class ClientConnectionClosedOK(Exception):


class Client:
DEFAULT_MAX_RETRIES = 10
DEFAULT_TIMEOUT_MULTIPLIER = 5
CONNECTION_TIMEOUT = 60

def __enter__(self) -> Self:
return self

Expand All @@ -49,9 +53,13 @@ def __init__(
url: str,
token: Optional[str] = None,
cert: Optional[Union[str, bytes]] = None,
max_retries: int = 10,
timeout_multiplier: int = 5,
max_retries: Optional[int] = None,
timeout_multiplier: Optional[int] = None,
) -> None:
if max_retries is None:
max_retries = self.DEFAULT_MAX_RETRIES
if timeout_multiplier is None:
timeout_multiplier = self.DEFAULT_TIMEOUT_MULTIPLIER
if url is None:
raise ValueError("url was None")
self.url = url
Expand Down Expand Up @@ -82,10 +90,10 @@ async def get_websocket(self) -> WebSocketClientProtocol:
self.url,
ssl=self._ssl_context,
extra_headers=self._extra_headers,
open_timeout=60,
ping_timeout=60,
ping_interval=60,
close_timeout=60,
open_timeout=self.CONNECTION_TIMEOUT,
ping_timeout=self.CONNECTION_TIMEOUT,
ping_interval=self.CONNECTION_TIMEOUT,
close_timeout=self.CONNECTION_TIMEOUT,
)

async def _send(self, msg: AnyStr) -> None:
Expand Down
21 changes: 21 additions & 0 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,8 +385,24 @@ async def _start_running(self) -> None:
)
)

CLOSE_SERVER_TIMEOUT = 60

async def _wait_for_stopped_server(self) -> None:
"""
When the ensemble is done, we wait for the server to stop
with a timeout.
"""
try:
await asyncio.wait_for(
self._server_done.wait(), timeout=self.CLOSE_SERVER_TIMEOUT
)
except asyncio.TimeoutError:
print("Timeout server done")
self._server_done.set()

async def _monitor_and_handle_tasks(self) -> None:
pending: Iterable[asyncio.Task[None]] = self._ee_tasks
stop_timeout_task: Optional[asyncio.Task[None]] = None

while True:
done, pending = await asyncio.wait(
Expand All @@ -407,8 +423,13 @@ async def _monitor_and_handle_tasks(self) -> None:
)
raise task_exception
elif task.get_name() == "server_task":
if stop_timeout_task:
stop_timeout_task.cancel()
return
elif task.get_name() == "ensemble_task":
stop_timeout_task = asyncio.create_task(
self._wait_for_stopped_server()
)
continue
else:
msg = (
Expand Down
6 changes: 5 additions & 1 deletion src/ert/scheduler/job.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ class Job:
(LSF, PBS, SLURM, etc.)
"""

DEFAULT_CHECKSUM_TIMEOUT = 120

def __init__(self, scheduler: Scheduler, real: Realization) -> None:
self.real = real
self.state = JobState.WAITING
Expand Down Expand Up @@ -188,8 +190,10 @@ async def _max_runtime_task(self) -> None:
self.returncode.cancel()

async def _verify_checksum(
self, checksum_lock: asyncio.Lock, timeout: int = 120
self, checksum_lock: asyncio.Lock, timeout: Optional[int] = None
) -> None:
if timeout is None:
timeout = self.DEFAULT_CHECKSUM_TIMEOUT
# Wait for job runpath to be in the checksum dictionary
runpath = self.real.run_arg.runpath
while runpath not in self._scheduler.checksum:
Expand Down
29 changes: 28 additions & 1 deletion tests/ert/ui_tests/cli/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,18 @@
from datetime import datetime
from pathlib import Path
from textwrap import dedent
from unittest.mock import Mock, call
from unittest.mock import Mock, call, patch

import numpy as np
import pandas as pd
import pytest
import websockets.exceptions
import xtgeo
from resdata.summary import Summary

import _ert.threading
import ert.shared
from _ert.forward_model_runner.client import Client
from ert import LibresFacade, ensemble_evaluator
from ert.cli.main import ErtCliError
from ert.config import (
Expand All @@ -24,13 +26,15 @@
ErtConfig,
)
from ert.enkf_main import sample_prior
from ert.ensemble_evaluator import EnsembleEvaluator
from ert.mode_definitions import (
ENSEMBLE_EXPERIMENT_MODE,
ENSEMBLE_SMOOTHER_MODE,
ES_MDA_MODE,
ITERATIVE_ENSEMBLE_SMOOTHER_MODE,
TEST_RUN_MODE,
)
from ert.scheduler.job import Job
from ert.storage import open_storage

from .run_cli import run_cli
Expand Down Expand Up @@ -928,3 +932,26 @@ def test_tracking_missing_ecl(monkeypatch, tmp_path, caplog):
f"Expected file {case}.UNSMRY not created by forward model!\nExpected "
f"file {case}.SMSPEC not created by forward model!"
) in caplog.messages


@pytest.mark.usefixtures("copy_poly_case")
def test_that_connection_errors_do_not_effect_final_result(
monkeypatch: pytest.MonkeyPatch,
):
monkeypatch.setattr(Client, "DEFAULT_MAX_RETRIES", 0)
monkeypatch.setattr(Client, "DEFAULT_TIMEOUT_MULTIPLIER", 0)
monkeypatch.setattr(Client, "CONNECTION_TIMEOUT", 1)
monkeypatch.setattr(EnsembleEvaluator, "CLOSE_SERVER_TIMEOUT", 0)
monkeypatch.setattr(Job, "DEFAULT_CHECKSUM_TIMEOUT", 0)

def raise_connection_error(*args, **kwargs):
raise websockets.exceptions.ConnectionClosedError(None, None)

with patch(
"ert.ensemble_evaluator.evaluator.dispatch_event_from_json",
raise_connection_error,
):
run_cli(
ENSEMBLE_EXPERIMENT_MODE,
"poly.ert",
)

0 comments on commit 87136c7

Please sign in to comment.