Skip to content

Commit

Permalink
Use TaskFactory instead of Task
Browse files Browse the repository at this point in the history
  • Loading branch information
davidbrochart committed Apr 18, 2024
1 parent 4aace77 commit c15709a
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 57 deletions.
11 changes: 9 additions & 2 deletions plugins/yjs/fps_yjs/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
from __future__ import annotations

from asphalt.core import Component, add_resource, get_resource, start_service_task
from asphalt.core import (
Component,
add_resource,
get_resource,
start_background_task_factory,
start_service_task,
)

from jupyverse_api.app import App
from jupyverse_api.auth import Auth
Expand All @@ -18,7 +24,8 @@ async def start(self) -> None:
contents = await get_resource(Contents, wait=True) # type: ignore[type-abstract]
lifespan = await get_resource(Lifespan, wait=True)

yjs = _Yjs(app, auth, contents, lifespan)
task_factory = await start_background_task_factory("yjs_tasks")
yjs = _Yjs(app, auth, contents, lifespan, task_factory)
add_resource(yjs, types=Yjs)

await start_service_task(yjs.start, "Room manager", teardown_action=yjs.stop)
Expand Down
98 changes: 43 additions & 55 deletions plugins/yjs/fps_yjs/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,9 @@
from typing import Dict
from uuid import uuid4

from anyio import TASK_STATUS_IGNORED, Event, create_task_group, sleep
from anyio.abc import TaskGroup, TaskStatus
from anyio import TASK_STATUS_IGNORED, sleep
from anyio.abc import TaskStatus
from asphalt.core import TaskFactory, TaskHandle
from fastapi import (
HTTPException,
Request,
Expand Down Expand Up @@ -50,30 +51,33 @@ def __init__(
auth: Auth,
contents: Contents,
lifespan: Lifespan,
task_factory: TaskFactory,
) -> None:
super().__init__(app=app, auth=auth)
self.contents = contents
self.lifespan = lifespan
self.task_factory = task_factory
if Widgets is None:
self.widgets = None
else:
self.widgets = Widgets() # type: ignore

async def start(self, *, task_status: TaskStatus[None] = TASK_STATUS_IGNORED) -> None:
async with create_task_group() as tg:
self._task_group = tg
self.room_manager = RoomManager(self.contents, self.lifespan, tg)
await tg.start(self.room_manager.websocket_server.start)
tg.start_soon(self.room_manager.on_shutdown)
task_status.started()
self.room_manager = RoomManager(self.contents, self.lifespan, self.task_factory)
await self.task_factory.start_task(
self.room_manager.websocket_server.start,
"WebSocket server",
)
self.task_factory.start_task_soon(self.room_manager.on_shutdown)
task_status.started()

async def stop(self) -> None:
for watcher in self.room_manager.watchers.values():
watcher.cancel()
for saver in self.room_manager.savers.values():
saver.cancel()
for cleaner in self.room_manager.cleaners.values():
cleaner.cancel()
for task in (
list(self.room_manager.watchers.values()) +
list(self.room_manager.savers.values()) +
list(self.room_manager.cleaners.values())
):
task.cancel()

async def collaboration_room_websocket(
self,
Expand Down Expand Up @@ -162,18 +166,18 @@ class RoomManager:
contents: Contents
lifespan: Lifespan
documents: Dict[str, YBaseDoc]
watchers: Dict[str, Task]
savers: Dict[str, Task]
cleaners: Dict[YRoom, Task]
watchers: Dict[str, TaskHandle]
savers: Dict[str, TaskHandle]
cleaners: Dict[YRoom, TaskHandle]
last_modified: Dict[str, datetime]
websocket_server: JupyterWebsocketServer
room_lock: ResourceLock
_task_group: TaskGroup
task_factory: TaskFactory

def __init__(self, contents: Contents, lifespan: Lifespan, task_group: TaskGroup):
def __init__(self, contents: Contents, lifespan: Lifespan, task_factory: TaskFactory):
self.contents = contents
self.lifespan = lifespan
self._task_group = task_group
self.task_factory = task_factory
self.documents = {} # a dictionary of room_name:document
self.watchers = {} # a dictionary of file_id:task
self.savers = {} # a dictionary of file_id:task
Expand All @@ -199,6 +203,9 @@ async def serve(self, websocket: YWebsocket, permissions) -> None:
# cleaning the room was scheduled because there was no client left
# cancel that since there is a new client
self.cleaners[room].cancel()
await self.cleaners[room].wait_finished()
if room in self.cleaners:
del self.cleaners[room]
if not room.ready:
file_path = await self.contents.file_id_manager.get_path(file_id)
logger.info(f"Opening collaboration room: {websocket.path} ({file_path})")
Expand Down Expand Up @@ -240,16 +247,18 @@ async def serve(self, websocket: YWebsocket, permissions) -> None:
)
# update the document when file changes
if file_id not in self.watchers:
self.watchers[file_id] = Task(
self.watch_file(file_format, file_id, document), self._task_group
self.watchers[file_id] = self.task_factory.start_task_soon(
lambda: self.watch_file(file_format, file_id, document),
f"Watch file {file_id}"
)

await self.websocket_server.serve(websocket, self.lifespan.shutdown_request)

if is_stored_document and not room.clients:
# no client in this room after we disconnect
self.cleaners[room] = Task(
self.maybe_clean_room(room, websocket.path), self._task_group
self.cleaners[room] = self.task_factory.start_task_soon(
lambda: self.maybe_clean_room(room, websocket.path),
f"Clean room {websocket.path}"
)

async def filter_message(self, can_write: bool, message: bytes) -> bool:
Expand Down Expand Up @@ -299,6 +308,8 @@ async def watch_file(self, file_format: str, file_id: str, document: YBaseDoc) -
file_path = new_file_path
# break
await self.maybe_load_file(file_format, file_path, file_id)
if file_id in self.watchers:
del self.watchers[file_id]

async def maybe_load_file(self, file_format: str, file_path: str, file_id: str) -> None:
model = await self.contents.read_content(file_path, False)
Expand Down Expand Up @@ -329,8 +340,9 @@ def on_document_change(
)
if file_id in self.savers:
self.savers[file_id].cancel()
self.savers[file_id] = Task(
self.maybe_save_document(file_id, file_type, file_format, document), self._task_group
self.savers[file_id] = self.task_factory.start_task_soon(
lambda: self.maybe_save_document(file_id, file_type, file_format, document),
f"Save file {file_id}"
)

async def maybe_save_document(
Expand Down Expand Up @@ -380,11 +392,15 @@ async def maybe_clean_room(self, room, ws_path: str) -> None:
documents = [v for k, v in self.documents.items() if k.split(":", 2)[2] == file_id]
if not documents:
self.watchers[file_id].cancel()
del self.watchers[file_id]
await self.watchers[file_id].wait_finished()
if file_id in self.watchers:
del self.watchers[file_id]
room_name = self.websocket_server.get_room_name(room)
self.websocket_server.delete_room(room=room)
file_path = await self.get_file_path(file_id, document)
logger.info(f"Closing collaboration room: {room_name} ({file_path})")
if room in self.cleaners:
del self.cleaners[room]


class JupyterWebsocketServer(WebsocketServer):
Expand All @@ -402,31 +418,3 @@ async def get_room(self, ws_path: str, ydoc: Doc | None = None) -> YRoom:
room = self.rooms[ws_path]
await self.start_room(room)
return room


class Task:
def __init__(self, coro, task_group: TaskGroup, cancel_event: Event | None = None):
self._coro = coro
self._cancel_event = cancel_event
self.cancelled = Event()
self.finished = Event()
task_group.start_soon(self.run)

def cancel(self):
self.cancelled.set()

async def run(self):
async with create_task_group() as tg:
tg.start_soon(self._run, tg)
tg.start_soon(self._check_cancellation, self.cancelled, tg)
if self._cancel_event is not None:
tg.start_soon(self._check_cancellation, self._cancel_event, tg)
self.finished.set()

async def _run(self, tg: TaskGroup):
await self._coro
tg.cancel_scope.cancel()

async def _check_cancellation(self, cancel_event, tg: TaskGroup):
await cancel_event.wait()
tg.cancel_scope.cancel()

0 comments on commit c15709a

Please sign in to comment.