Skip to content

Commit

Permalink
Use the new websockets asyncio api
Browse files Browse the repository at this point in the history
  • Loading branch information
larsevj committed Oct 30, 2024
1 parent 6153918 commit 59751a7
Show file tree
Hide file tree
Showing 4 changed files with 36 additions and 37 deletions.
8 changes: 4 additions & 4 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, AnyStr, Optional, Union

from typing_extensions import Self
from websockets.client import WebSocketClientProtocol, connect
from websockets.asyncio.client import ClientConnection, connect
from websockets.datastructures import Headers
from websockets.exceptions import (
ConnectionClosedError,
Expand Down Expand Up @@ -74,14 +74,14 @@ def __init__(

self._max_retries = max_retries
self._timeout_multiplier = timeout_multiplier
self.websocket: Optional[WebSocketClientProtocol] = None
self.websocket: Optional[ClientConnection] = None
self.loop = new_event_loop()

async def get_websocket(self) -> WebSocketClientProtocol:
async def get_websocket(self) -> ClientConnection:
return await connect(
self.url,
ssl=self._ssl_context,
extra_headers=self._extra_headers,
additional_headers=self._extra_headers,
open_timeout=60,
ping_timeout=60,
ping_interval=60,
Expand Down
49 changes: 24 additions & 25 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@
get_args,
)

import websockets
from pydantic_core._pydantic_core import ValidationError
from websockets.datastructures import Headers, HeadersLike
from websockets.asyncio.server import ServerConnection, serve
from websockets.exceptions import ConnectionClosedError
from websockets.server import WebSocketServerProtocol
from websockets.http11 import Request, Response

from _ert.events import (
EESnapshot,
Expand Down Expand Up @@ -70,7 +69,7 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):

self._loop: Optional[asyncio.AbstractEventLoop] = None

self._clients: Set[WebSocketServerProtocol] = set()
self._clients: Set[ServerConnection] = set()
self._dispatchers_connected: asyncio.Queue[None] = asyncio.Queue()

self._events: asyncio.Queue[Event] = asyncio.Queue()
Expand Down Expand Up @@ -207,14 +206,12 @@ def ensemble(self) -> Ensemble:
return self._ensemble

@contextmanager
def store_client(
self, websocket: WebSocketServerProtocol
) -> Generator[None, None, None]:
def store_client(self, websocket: ServerConnection) -> Generator[None, None, None]:
self._clients.add(websocket)
yield
self._clients.remove(websocket)

async def handle_client(self, websocket: WebSocketServerProtocol) -> None:
async def handle_client(self, websocket: ServerConnection) -> None:
with self.store_client(websocket):
current_snapshot_dict = self._ensemble.snapshot.to_dict()
event: Event = EESnapshot(
Expand All @@ -240,7 +237,7 @@ async def count_dispatcher(self) -> AsyncIterator[None]:
await self._dispatchers_connected.get()
self._dispatchers_connected.task_done()

async def handle_dispatch(self, websocket: WebSocketServerProtocol) -> None:
async def handle_dispatch(self, websocket: ServerConnection) -> None:
async with self.count_dispatcher():
try:
async for raw_msg in websocket:
Expand Down Expand Up @@ -283,32 +280,34 @@ async def forward_checksum(self, event: Event) -> None:
await self._events_to_send.put(event)
await self._manifest_queue.put(event)

async def connection_handler(self, websocket: WebSocketServerProtocol) -> None:
path = websocket.path
elements = path.split("/")
if elements[1] == "client":
await self.handle_client(websocket)
elif elements[1] == "dispatch":
await self.handle_dispatch(websocket)
async def connection_handler(self, websocket: ServerConnection) -> None:
if websocket.request is not None:
path = websocket.request.path
elements = path.split("/")
if elements[1] == "client":
await self.handle_client(websocket)
elif elements[1] == "dispatch":
await self.handle_dispatch(websocket)
else:
logger.info(f"Connection attempt to unknown path: {path}.")
else:
logger.info(f"Connection attempt to unknown path: {path}.")
logger.info("No request to handle.")

async def process_request(
self, path: str, request_headers: Headers
) -> Optional[Tuple[HTTPStatus, HeadersLike, bytes]]:
if request_headers.get("token") != self._config.token:
return HTTPStatus.UNAUTHORIZED, {}, b""
if path == "/healthcheck":
return HTTPStatus.OK, {}, b""
self, connection: ServerConnection, request: Request
) -> Optional[Response]:
if request.headers.get("token") != self._config.token:
return connection.respond(HTTPStatus.UNAUTHORIZED, "")
if request.path == "/healthcheck":
return connection.respond(HTTPStatus.OK, "OK\n")
return None

async def _server(self) -> None:
async with websockets.serve(
async with serve(
self.connection_handler,
sock=self._config.get_socket(),
ssl=self._config.get_server_ssl_context(),
process_request=self.process_request,
max_queue=None,
max_size=2**26,
ping_timeout=60,
ping_interval=60,
Expand Down
8 changes: 4 additions & 4 deletions src/ert/ensemble_evaluator/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import TYPE_CHECKING, Any, AsyncGenerator, Final, Optional, Union

from aiohttp import ClientError
from websockets import ConnectionClosed, Headers, WebSocketClientProtocol
from websockets.client import connect
from websockets import ConnectionClosed, Headers
from websockets.asyncio.client import ClientConnection, connect

from _ert.events import (
EETerminated,
Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(self, ee_con_info: "EvaluatorConnectionInfo") -> None:
self._ee_con_info = ee_con_info
self._id = str(uuid.uuid1()).split("-", maxsplit=1)[0]
self._event_queue: asyncio.Queue[Union[Event, EventSentinel]] = asyncio.Queue()
self._connection: Optional[WebSocketClientProtocol] = None
self._connection: Optional[ClientConnection] = None
self._receiver_task: Optional[asyncio.Task[None]] = None
self._connected: asyncio.Event = asyncio.Event()
self._connection_timeout: float = 120.0
Expand Down Expand Up @@ -137,7 +137,7 @@ async def _receiver(self) -> None:
async for conn in connect(
self._ee_con_info.client_uri,
ssl=tls,
extra_headers=headers,
additional_headers=headers,
max_size=2**26,
max_queue=500,
open_timeout=5,
Expand Down
8 changes: 4 additions & 4 deletions tests/ert/unit_tests/ensemble_evaluator/test_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from urllib.parse import urlparse

import pytest
from websockets import server
from websockets.asyncio import server
from websockets.exceptions import ConnectionClosedOK

from _ert.events import EEUserCancel, EEUserDone, event_from_json
Expand All @@ -15,9 +15,9 @@
async def _mock_ws(
set_when_done: asyncio.Event, handler, ee_config: EvaluatorConnectionInfo
):
async def process_request(path, request_headers):
if path == "/healthcheck":
return HTTPStatus.OK, {}, b""
async def process_request(path, connection, request):
if request.path == "/healthcheck":
return connection.respond(HTTPStatus.OK, "")

url = urlparse(ee_config.url)
async with server.serve(
Expand Down

0 comments on commit 59751a7

Please sign in to comment.