From 87136c7cbf2f966dced604e36a2dde82a129285d Mon Sep 17 00:00:00 2001 From: Eivind Jahren Date: Fri, 1 Nov 2024 10:10:11 +0100 Subject: [PATCH] Add timeout for server --- src/_ert/forward_model_runner/client.py | 20 ++++++++++++----- src/ert/ensemble_evaluator/evaluator.py | 21 ++++++++++++++++++ src/ert/scheduler/job.py | 6 ++++- tests/ert/ui_tests/cli/test_cli.py | 29 ++++++++++++++++++++++++- 4 files changed, 68 insertions(+), 8 deletions(-) diff --git a/src/_ert/forward_model_runner/client.py b/src/_ert/forward_model_runner/client.py index 3db67b00b76..2566ca005f8 100644 --- a/src/_ert/forward_model_runner/client.py +++ b/src/_ert/forward_model_runner/client.py @@ -27,6 +27,10 @@ class ClientConnectionClosedOK(Exception): class Client: + DEFAULT_MAX_RETRIES = 10 + DEFAULT_TIMEOUT_MULTIPLIER = 5 + CONNECTION_TIMEOUT = 60 + def __enter__(self) -> Self: return self @@ -49,9 +53,13 @@ def __init__( url: str, token: Optional[str] = None, cert: Optional[Union[str, bytes]] = None, - max_retries: int = 10, - timeout_multiplier: int = 5, + max_retries: Optional[int] = None, + timeout_multiplier: Optional[int] = None, ) -> None: + if max_retries is None: + max_retries = self.DEFAULT_MAX_RETRIES + if timeout_multiplier is None: + timeout_multiplier = self.DEFAULT_TIMEOUT_MULTIPLIER if url is None: raise ValueError("url was None") self.url = url @@ -82,10 +90,10 @@ async def get_websocket(self) -> WebSocketClientProtocol: self.url, ssl=self._ssl_context, extra_headers=self._extra_headers, - open_timeout=60, - ping_timeout=60, - ping_interval=60, - close_timeout=60, + open_timeout=self.CONNECTION_TIMEOUT, + ping_timeout=self.CONNECTION_TIMEOUT, + ping_interval=self.CONNECTION_TIMEOUT, + close_timeout=self.CONNECTION_TIMEOUT, ) async def _send(self, msg: AnyStr) -> None: diff --git a/src/ert/ensemble_evaluator/evaluator.py b/src/ert/ensemble_evaluator/evaluator.py index 0651d86699f..3855ec85cac 100644 --- a/src/ert/ensemble_evaluator/evaluator.py +++ b/src/ert/ensemble_evaluator/evaluator.py @@ -385,8 +385,24 @@ async def _start_running(self) -> None: ) ) + CLOSE_SERVER_TIMEOUT = 60 + + async def _wait_for_stopped_server(self) -> None: + """ + When the ensemble is done, we wait for the server to stop + with a timeout. + """ + try: + await asyncio.wait_for( + self._server_done.wait(), timeout=self.CLOSE_SERVER_TIMEOUT + ) + except asyncio.TimeoutError: + print("Timeout server done") + self._server_done.set() + async def _monitor_and_handle_tasks(self) -> None: pending: Iterable[asyncio.Task[None]] = self._ee_tasks + stop_timeout_task: Optional[asyncio.Task[None]] = None while True: done, pending = await asyncio.wait( @@ -407,8 +423,13 @@ async def _monitor_and_handle_tasks(self) -> None: ) raise task_exception elif task.get_name() == "server_task": + if stop_timeout_task: + stop_timeout_task.cancel() return elif task.get_name() == "ensemble_task": + stop_timeout_task = asyncio.create_task( + self._wait_for_stopped_server() + ) continue else: msg = ( diff --git a/src/ert/scheduler/job.py b/src/ert/scheduler/job.py index 2d817de6efa..a3055f83d1a 100644 --- a/src/ert/scheduler/job.py +++ b/src/ert/scheduler/job.py @@ -59,6 +59,8 @@ class Job: (LSF, PBS, SLURM, etc.) """ + DEFAULT_CHECKSUM_TIMEOUT = 120 + def __init__(self, scheduler: Scheduler, real: Realization) -> None: self.real = real self.state = JobState.WAITING @@ -188,8 +190,10 @@ async def _max_runtime_task(self) -> None: self.returncode.cancel() async def _verify_checksum( - self, checksum_lock: asyncio.Lock, timeout: int = 120 + self, checksum_lock: asyncio.Lock, timeout: Optional[int] = None ) -> None: + if timeout is None: + timeout = self.DEFAULT_CHECKSUM_TIMEOUT # Wait for job runpath to be in the checksum dictionary runpath = self.real.run_arg.runpath while runpath not in self._scheduler.checksum: diff --git a/tests/ert/ui_tests/cli/test_cli.py b/tests/ert/ui_tests/cli/test_cli.py index ed45c984885..5ebb9e8d63b 100644 --- a/tests/ert/ui_tests/cli/test_cli.py +++ b/tests/ert/ui_tests/cli/test_cli.py @@ -6,16 +6,18 @@ from datetime import datetime from pathlib import Path from textwrap import dedent -from unittest.mock import Mock, call +from unittest.mock import Mock, call, patch import numpy as np import pandas as pd import pytest +import websockets.exceptions import xtgeo from resdata.summary import Summary import _ert.threading import ert.shared +from _ert.forward_model_runner.client import Client from ert import LibresFacade, ensemble_evaluator from ert.cli.main import ErtCliError from ert.config import ( @@ -24,6 +26,7 @@ ErtConfig, ) from ert.enkf_main import sample_prior +from ert.ensemble_evaluator import EnsembleEvaluator from ert.mode_definitions import ( ENSEMBLE_EXPERIMENT_MODE, ENSEMBLE_SMOOTHER_MODE, @@ -31,6 +34,7 @@ ITERATIVE_ENSEMBLE_SMOOTHER_MODE, TEST_RUN_MODE, ) +from ert.scheduler.job import Job from ert.storage import open_storage from .run_cli import run_cli @@ -928,3 +932,26 @@ def test_tracking_missing_ecl(monkeypatch, tmp_path, caplog): f"Expected file {case}.UNSMRY not created by forward model!\nExpected " f"file {case}.SMSPEC not created by forward model!" ) in caplog.messages + + +@pytest.mark.usefixtures("copy_poly_case") +def test_that_connection_errors_do_not_effect_final_result( + monkeypatch: pytest.MonkeyPatch, +): + monkeypatch.setattr(Client, "DEFAULT_MAX_RETRIES", 0) + monkeypatch.setattr(Client, "DEFAULT_TIMEOUT_MULTIPLIER", 0) + monkeypatch.setattr(Client, "CONNECTION_TIMEOUT", 1) + monkeypatch.setattr(EnsembleEvaluator, "CLOSE_SERVER_TIMEOUT", 0) + monkeypatch.setattr(Job, "DEFAULT_CHECKSUM_TIMEOUT", 0) + + def raise_connection_error(*args, **kwargs): + raise websockets.exceptions.ConnectionClosedError(None, None) + + with patch( + "ert.ensemble_evaluator.evaluator.dispatch_event_from_json", + raise_connection_error, + ): + run_cli( + ENSEMBLE_EXPERIMENT_MODE, + "poly.ert", + )