From c15709a3ff538bed4bc09136ff0741d76e3e375e Mon Sep 17 00:00:00 2001 From: David Brochart Date: Thu, 18 Apr 2024 09:55:44 +0200 Subject: [PATCH] Use TaskFactory instead of Task --- plugins/yjs/fps_yjs/main.py | 11 +++- plugins/yjs/fps_yjs/routes.py | 98 +++++++++++++++-------------------- 2 files changed, 52 insertions(+), 57 deletions(-) diff --git a/plugins/yjs/fps_yjs/main.py b/plugins/yjs/fps_yjs/main.py index b4cdbcc8..18c9daa8 100644 --- a/plugins/yjs/fps_yjs/main.py +++ b/plugins/yjs/fps_yjs/main.py @@ -1,6 +1,12 @@ from __future__ import annotations -from asphalt.core import Component, add_resource, get_resource, start_service_task +from asphalt.core import ( + Component, + add_resource, + get_resource, + start_background_task_factory, + start_service_task, +) from jupyverse_api.app import App from jupyverse_api.auth import Auth @@ -18,7 +24,8 @@ async def start(self) -> None: contents = await get_resource(Contents, wait=True) # type: ignore[type-abstract] lifespan = await get_resource(Lifespan, wait=True) - yjs = _Yjs(app, auth, contents, lifespan) + task_factory = await start_background_task_factory("yjs_tasks") + yjs = _Yjs(app, auth, contents, lifespan, task_factory) add_resource(yjs, types=Yjs) await start_service_task(yjs.start, "Room manager", teardown_action=yjs.stop) diff --git a/plugins/yjs/fps_yjs/routes.py b/plugins/yjs/fps_yjs/routes.py index 1a81bec4..4835953c 100644 --- a/plugins/yjs/fps_yjs/routes.py +++ b/plugins/yjs/fps_yjs/routes.py @@ -6,8 +6,9 @@ from typing import Dict from uuid import uuid4 -from anyio import TASK_STATUS_IGNORED, Event, create_task_group, sleep -from anyio.abc import TaskGroup, TaskStatus +from anyio import TASK_STATUS_IGNORED, sleep +from anyio.abc import TaskStatus +from asphalt.core import TaskFactory, TaskHandle from fastapi import ( HTTPException, Request, @@ -50,30 +51,33 @@ def __init__( auth: Auth, contents: Contents, lifespan: Lifespan, + task_factory: TaskFactory, ) -> None: super().__init__(app=app, auth=auth) self.contents = contents self.lifespan = lifespan + self.task_factory = task_factory if Widgets is None: self.widgets = None else: self.widgets = Widgets() # type: ignore async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None: - async with create_task_group() as tg: - self._task_group = tg - self.room_manager = RoomManager(self.contents, self.lifespan, tg) - await tg.start(self.room_manager.websocket_server.start) - tg.start_soon(self.room_manager.on_shutdown) - task_status.started() + self.room_manager = RoomManager(self.contents, self.lifespan, self.task_factory) + await self.task_factory.start_task( + self.room_manager.websocket_server.start, + "WebSocket server", + ) + self.task_factory.start_task_soon(self.room_manager.on_shutdown) + task_status.started() async def stop(self) -> None: - for watcher in self.room_manager.watchers.values(): - watcher.cancel() - for saver in self.room_manager.savers.values(): - saver.cancel() - for cleaner in self.room_manager.cleaners.values(): - cleaner.cancel() + for task in ( + list(self.room_manager.watchers.values()) + + list(self.room_manager.savers.values()) + + list(self.room_manager.cleaners.values()) + ): + task.cancel() async def collaboration_room_websocket( self, @@ -162,18 +166,18 @@ class RoomManager: contents: Contents lifespan: Lifespan documents: Dict[str, YBaseDoc] - watchers: Dict[str, Task] - savers: Dict[str, Task] - cleaners: Dict[YRoom, Task] + watchers: Dict[str, TaskHandle] + savers: Dict[str, TaskHandle] + cleaners: Dict[YRoom, TaskHandle] last_modified: Dict[str, datetime] websocket_server: JupyterWebsocketServer room_lock: ResourceLock - _task_group: TaskGroup + task_factory: TaskFactory - def __init__(self, contents: Contents, lifespan: Lifespan, task_group: TaskGroup): + def __init__(self, contents: Contents, lifespan: Lifespan, task_factory: TaskFactory): self.contents = contents self.lifespan = lifespan - self._task_group = task_group + self.task_factory = task_factory self.documents = {} # a dictionary of room_name:document self.watchers = {} # a dictionary of file_id:task self.savers = {} # a dictionary of file_id:task @@ -199,6 +203,9 @@ async def serve(self, websocket: YWebsocket, permissions) -> None: # cleaning the room was scheduled because there was no client left # cancel that since there is a new client self.cleaners[room].cancel() + await self.cleaners[room].wait_finished() + if room in self.cleaners: + del self.cleaners[room] if not room.ready: file_path = await self.contents.file_id_manager.get_path(file_id) logger.info(f"Opening collaboration room: {websocket.path} ({file_path})") @@ -240,16 +247,18 @@ async def serve(self, websocket: YWebsocket, permissions) -> None: ) # update the document when file changes if file_id not in self.watchers: - self.watchers[file_id] = Task( - self.watch_file(file_format, file_id, document), self._task_group + self.watchers[file_id] = self.task_factory.start_task_soon( + lambda: self.watch_file(file_format, file_id, document), + f"Watch file {file_id}" ) await self.websocket_server.serve(websocket, self.lifespan.shutdown_request) if is_stored_document and not room.clients: # no client in this room after we disconnect - self.cleaners[room] = Task( - self.maybe_clean_room(room, websocket.path), self._task_group + self.cleaners[room] = self.task_factory.start_task_soon( + lambda: self.maybe_clean_room(room, websocket.path), + f"Clean room {websocket.path}" ) async def filter_message(self, can_write: bool, message: bytes) -> bool: @@ -299,6 +308,8 @@ async def watch_file(self, file_format: str, file_id: str, document: YBaseDoc) - file_path = new_file_path # break await self.maybe_load_file(file_format, file_path, file_id) + if file_id in self.watchers: + del self.watchers[file_id] async def maybe_load_file(self, file_format: str, file_path: str, file_id: str) -> None: model = await self.contents.read_content(file_path, False) @@ -329,8 +340,9 @@ def on_document_change( ) if file_id in self.savers: self.savers[file_id].cancel() - self.savers[file_id] = Task( - self.maybe_save_document(file_id, file_type, file_format, document), self._task_group + self.savers[file_id] = self.task_factory.start_task_soon( + lambda: self.maybe_save_document(file_id, file_type, file_format, document), + f"Save file {file_id}" ) async def maybe_save_document( @@ -380,11 +392,15 @@ async def maybe_clean_room(self, room, ws_path: str) -> None: documents = [v for k, v in self.documents.items() if k.split(":", 2)[2] == file_id] if not documents: self.watchers[file_id].cancel() - del self.watchers[file_id] + await self.watchers[file_id].wait_finished() + if file_id in self.watchers: + del self.watchers[file_id] room_name = self.websocket_server.get_room_name(room) self.websocket_server.delete_room(room=room) file_path = await self.get_file_path(file_id, document) logger.info(f"Closing collaboration room: {room_name} ({file_path})") + if room in self.cleaners: + del self.cleaners[room] class JupyterWebsocketServer(WebsocketServer): @@ -402,31 +418,3 @@ async def get_room(self, ws_path: str, ydoc: Doc | None = None) -> YRoom: room = self.rooms[ws_path] await self.start_room(room) return room - - -class Task: - def __init__(self, coro, task_group: TaskGroup, cancel_event: Event | None = None): - self._coro = coro - self._cancel_event = cancel_event - self.cancelled = Event() - self.finished = Event() - task_group.start_soon(self.run) - - def cancel(self): - self.cancelled.set() - - async def run(self): - async with create_task_group() as tg: - tg.start_soon(self._run, tg) - tg.start_soon(self._check_cancellation, self.cancelled, tg) - if self._cancel_event is not None: - tg.start_soon(self._check_cancellation, self._cancel_event, tg) - self.finished.set() - - async def _run(self, tg: TaskGroup): - await self._coro - tg.cancel_scope.cancel() - - async def _check_cancellation(self, cancel_event, tg: TaskGroup): - await cancel_event.wait() - tg.cancel_scope.cancel()