From fe42788d6f9ed627cdef7e2f37bd9aa46349eb32 Mon Sep 17 00:00:00 2001 From: Jeongseok Kang Date: Wed, 15 Jan 2025 14:19:31 +0900 Subject: [PATCH] refactor: Update pipeline websocket handler --- src/ai/backend/web/proxy.py | 19 ++++++------------- src/ai/backend/web/server.py | 11 ++++++----- 2 files changed, 12 insertions(+), 18 deletions(-) diff --git a/src/ai/backend/web/proxy.py b/src/ai/backend/web/proxy.py index fa7549ee0f..c689e3d9f7 100644 --- a/src/ai/backend/web/proxy.py +++ b/src/ai/backend/web/proxy.py @@ -154,16 +154,16 @@ async def web_handler( request: web.Request, *, is_anonymous: bool = False, - override_api_endpoint: Optional[str] = None, + api_endpoint: Optional[str] = None, extra_forwarding_headers: Iterable[str] | None = None, ) -> web.StreamResponse: stats: WebStats = request.app["stats"] stats.active_proxy_api_handlers.add(asyncio.current_task()) # type: ignore path = request.match_info.get("path", "") if is_anonymous: - api_session = await asyncio.shield(get_anonymous_session(request, override_api_endpoint)) + api_session = await asyncio.shield(get_anonymous_session(request, api_endpoint)) else: - api_session = await asyncio.shield(get_api_session(request, override_api_endpoint)) + api_session = await asyncio.shield(get_api_session(request, api_endpoint)) extra_forwarding_headers = extra_forwarding_headers or [] try: async with api_session: @@ -326,7 +326,9 @@ async def web_plugin_handler(request, *, is_anonymous=False) -> web.StreamRespon ) -async def websocket_handler(request, *, is_anonymous=False) -> web.StreamResponse: +async def websocket_handler( + request, *, is_anonymous=False, api_endpoint: Optional[str] = None +) -> web.StreamResponse: stats: WebStats = request.app["stats"] stats.active_proxy_websocket_handlers.add(asyncio.current_task()) # type: ignore path = request.match_info["path"] @@ -334,7 +336,6 @@ async def websocket_handler(request, *, is_anonymous=False) -> web.StreamRespons app = request.query.get("app") # Choose a specific Manager endpoint for persistent web app connection. - api_endpoint = None should_save_session = False configured_endpoints = request.app["config"]["api"]["endpoint"] if session.get("api_endpoints", {}).get(app): @@ -348,14 +349,6 @@ async def websocket_handler(request, *, is_anonymous=False) -> web.StreamRespons session["api_endpoints"][app] = str(api_endpoint) should_save_session = True - proxy_path, _, real_path = request.path.lstrip("/").partition("/") - if proxy_path == "pipeline": - pipeline_config = request.app["config"]["pipeline"] - if not pipeline_config: - raise RuntimeError("'pipeline' config must be set to handle pipeline requests.") - endpoint = pipeline_config["endpoint"].with_scheme("ws") - log.info(f"WEBSOCKET_HANDLER {request.path} -> {endpoint}/{real_path}") - is_anonymous = True if is_anonymous: api_session = await asyncio.shield(get_anonymous_session(request, api_endpoint)) else: diff --git a/src/ai/backend/web/server.py b/src/ai/backend/web/server.py index 59dd1dfa85..df002e853c 100644 --- a/src/ai/backend/web/server.py +++ b/src/ai/backend/web/server.py @@ -655,15 +655,16 @@ async def server_main( anon_web_plugin_handler = partial(web_plugin_handler, is_anonymous=True) pipeline_api_endpoint = config["pipeline"]["endpoint"] - pipeline_handler = partial( - web_handler, is_anonymous=True, override_api_endpoint=pipeline_api_endpoint - ) + pipeline_handler = partial(web_handler, is_anonymous=True, api_endpoint=pipeline_api_endpoint) pipeline_login_handler = partial( web_handler, is_anonymous=False, - override_api_endpoint=pipeline_api_endpoint, + api_endpoint=pipeline_api_endpoint, extra_forwarding_headers={"X-BackendAI-SessionID"}, ) + pipeline_websocket_handler = partial( + websocket_handler, is_anonymous=True, api_endpoint=pipeline_api_endpoint.with_scheme("ws") + ) app.router.add_route("HEAD", "/func/{path:folders/_/tus/upload/.*$}", anon_web_plugin_handler) app.router.add_route("PATCH", "/func/{path:folders/_/tus/upload/.*$}", anon_web_plugin_handler) @@ -699,7 +700,7 @@ async def server_main( cors.add(app.router.add_route("POST", "/func/{path:.*$}", web_handler)) cors.add(app.router.add_route("PATCH", "/func/{path:.*$}", web_handler)) cors.add(app.router.add_route("DELETE", "/func/{path:.*$}", web_handler)) - cors.add(app.router.add_route("GET", "/pipeline/{path:stream/.*$}", websocket_handler)) + cors.add(app.router.add_route("GET", "/pipeline/{path:stream/.*$}", pipeline_websocket_handler)) cors.add(app.router.add_route("POST", "/pipeline/login/", pipeline_login_handler)) cors.add(app.router.add_route("GET", "/pipeline/{path:.*$}", pipeline_handler)) cors.add(app.router.add_route("PUT", "/pipeline/{path:.*$}", pipeline_handler))