Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Try using new websockets asyncio #9080

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/_ert/forward_model_runner/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Any, AnyStr, Optional, Union

from typing_extensions import Self
from websockets.client import WebSocketClientProtocol, connect
from websockets.asyncio.client import ClientConnection, connect
from websockets.datastructures import Headers
from websockets.exceptions import (
ConnectionClosedError,
Expand Down Expand Up @@ -74,14 +74,14 @@ def __init__(

self._max_retries = max_retries
self._timeout_multiplier = timeout_multiplier
self.websocket: Optional[WebSocketClientProtocol] = None
self.websocket: Optional[ClientConnection] = None
self.loop = new_event_loop()

async def get_websocket(self) -> WebSocketClientProtocol:
async def get_websocket(self) -> ClientConnection:
return await connect(
self.url,
ssl=self._ssl_context,
extra_headers=self._extra_headers,
additional_headers=self._extra_headers,
open_timeout=60,
ping_timeout=60,
ping_interval=60,
Expand Down
49 changes: 24 additions & 25 deletions src/ert/ensemble_evaluator/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,10 @@
get_args,
)

import websockets
from pydantic_core._pydantic_core import ValidationError
from websockets.datastructures import Headers, HeadersLike
from websockets.asyncio.server import ServerConnection, serve
from websockets.exceptions import ConnectionClosedError
from websockets.server import WebSocketServerProtocol
from websockets.http11 import Request, Response

from _ert.events import (
EESnapshot,
Expand Down Expand Up @@ -70,7 +69,7 @@ def __init__(self, ensemble: Ensemble, config: EvaluatorServerConfig):

self._loop: Optional[asyncio.AbstractEventLoop] = None

self._clients: Set[WebSocketServerProtocol] = set()
self._clients: Set[ServerConnection] = set()
self._dispatchers_connected: asyncio.Queue[None] = asyncio.Queue()

self._events: asyncio.Queue[Event] = asyncio.Queue()
Expand Down Expand Up @@ -207,14 +206,12 @@ def ensemble(self) -> Ensemble:
return self._ensemble

@contextmanager
def store_client(
self, websocket: WebSocketServerProtocol
) -> Generator[None, None, None]:
def store_client(self, websocket: ServerConnection) -> Generator[None, None, None]:
self._clients.add(websocket)
yield
self._clients.remove(websocket)

async def handle_client(self, websocket: WebSocketServerProtocol) -> None:
async def handle_client(self, websocket: ServerConnection) -> None:
with self.store_client(websocket):
current_snapshot_dict = self._ensemble.snapshot.to_dict()
event: Event = EESnapshot(
Expand All @@ -240,7 +237,7 @@ async def count_dispatcher(self) -> AsyncIterator[None]:
await self._dispatchers_connected.get()
self._dispatchers_connected.task_done()

async def handle_dispatch(self, websocket: WebSocketServerProtocol) -> None:
async def handle_dispatch(self, websocket: ServerConnection) -> None:
async with self.count_dispatcher():
try:
async for raw_msg in websocket:
Expand Down Expand Up @@ -283,32 +280,34 @@ async def forward_checksum(self, event: Event) -> None:
await self._events_to_send.put(event)
await self._manifest_queue.put(event)

async def connection_handler(self, websocket: WebSocketServerProtocol) -> None:
path = websocket.path
elements = path.split("/")
if elements[1] == "client":
await self.handle_client(websocket)
elif elements[1] == "dispatch":
await self.handle_dispatch(websocket)
async def connection_handler(self, websocket: ServerConnection) -> None:
if websocket.request is not None:
path = websocket.request.path
elements = path.split("/")
if elements[1] == "client":
await self.handle_client(websocket)
elif elements[1] == "dispatch":
await self.handle_dispatch(websocket)
else:
logger.info(f"Connection attempt to unknown path: {path}.")
else:
logger.info(f"Connection attempt to unknown path: {path}.")
logger.info("No request to handle.")

async def process_request(
self, path: str, request_headers: Headers
) -> Optional[Tuple[HTTPStatus, HeadersLike, bytes]]:
if request_headers.get("token") != self._config.token:
return HTTPStatus.UNAUTHORIZED, {}, b""
if path == "/healthcheck":
return HTTPStatus.OK, {}, b""
self, connection: ServerConnection, request: Request
) -> Optional[Response]:
if request.headers.get("token") != self._config.token:
return connection.respond(HTTPStatus.UNAUTHORIZED, "")
if request.path == "/healthcheck":
return connection.respond(HTTPStatus.OK, "OK\n")
return None

async def _server(self) -> None:
async with websockets.serve(
async with serve(
self.connection_handler,
sock=self._config.get_socket(),
ssl=self._config.get_server_ssl_context(),
process_request=self.process_request,
max_queue=None,
max_size=2**26,
ping_timeout=60,
ping_interval=60,
Expand Down
8 changes: 4 additions & 4 deletions src/ert/ensemble_evaluator/monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import TYPE_CHECKING, Any, AsyncGenerator, Final, Optional, Union

from aiohttp import ClientError
from websockets import ConnectionClosed, Headers, WebSocketClientProtocol
from websockets.client import connect
from websockets import ConnectionClosed, Headers
from websockets.asyncio.client import ClientConnection, connect

from _ert.events import (
EETerminated,
Expand Down Expand Up @@ -36,7 +36,7 @@ def __init__(self, ee_con_info: "EvaluatorConnectionInfo") -> None:
self._ee_con_info = ee_con_info
self._id = str(uuid.uuid1()).split("-", maxsplit=1)[0]
self._event_queue: asyncio.Queue[Union[Event, EventSentinel]] = asyncio.Queue()
self._connection: Optional[WebSocketClientProtocol] = None
self._connection: Optional[ClientConnection] = None
self._receiver_task: Optional[asyncio.Task[None]] = None
self._connected: asyncio.Event = asyncio.Event()
self._connection_timeout: float = 120.0
Expand Down Expand Up @@ -137,7 +137,7 @@ async def _receiver(self) -> None:
async for conn in connect(
self._ee_con_info.client_uri,
ssl=tls,
extra_headers=headers,
additional_headers=headers,
max_size=2**26,
max_queue=500,
open_timeout=5,
Expand Down
8 changes: 4 additions & 4 deletions tests/ert/unit_tests/ensemble_evaluator/test_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from urllib.parse import urlparse

import pytest
from websockets import server
from websockets.asyncio import server
from websockets.exceptions import ConnectionClosedOK

from _ert.events import EEUserCancel, EEUserDone, event_from_json
Expand All @@ -15,9 +15,9 @@
async def _mock_ws(
set_when_done: asyncio.Event, handler, ee_config: EvaluatorConnectionInfo
):
async def process_request(path, request_headers):
if path == "/healthcheck":
return HTTPStatus.OK, {}, b""
async def process_request(path, connection, request):
if request.path == "/healthcheck":
return connection.respond(HTTPStatus.OK, "")

url = urlparse(ee_config.url)
async with server.serve(
Expand Down
Loading