Skip to content

Commit

Permalink
refactor: Update pipeline websocket handler
Browse files Browse the repository at this point in the history
  • Loading branch information
rapsealk committed Jan 15, 2025
1 parent d9b5e97 commit fe42788
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 18 deletions.
19 changes: 6 additions & 13 deletions src/ai/backend/web/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,16 +154,16 @@ async def web_handler(
request: web.Request,
*,
is_anonymous: bool = False,
override_api_endpoint: Optional[str] = None,
api_endpoint: Optional[str] = None,
extra_forwarding_headers: Iterable[str] | None = None,
) -> web.StreamResponse:
stats: WebStats = request.app["stats"]
stats.active_proxy_api_handlers.add(asyncio.current_task()) # type: ignore
path = request.match_info.get("path", "")
if is_anonymous:
api_session = await asyncio.shield(get_anonymous_session(request, override_api_endpoint))
api_session = await asyncio.shield(get_anonymous_session(request, api_endpoint))
else:
api_session = await asyncio.shield(get_api_session(request, override_api_endpoint))
api_session = await asyncio.shield(get_api_session(request, api_endpoint))
extra_forwarding_headers = extra_forwarding_headers or []
try:
async with api_session:
Expand Down Expand Up @@ -326,15 +326,16 @@ async def web_plugin_handler(request, *, is_anonymous=False) -> web.StreamRespon
)


async def websocket_handler(request, *, is_anonymous=False) -> web.StreamResponse:
async def websocket_handler(
request, *, is_anonymous=False, api_endpoint: Optional[str] = None
) -> web.StreamResponse:
stats: WebStats = request.app["stats"]
stats.active_proxy_websocket_handlers.add(asyncio.current_task()) # type: ignore
path = request.match_info["path"]
session = await get_session(request)
app = request.query.get("app")

# Choose a specific Manager endpoint for persistent web app connection.
api_endpoint = None
should_save_session = False
configured_endpoints = request.app["config"]["api"]["endpoint"]
if session.get("api_endpoints", {}).get(app):
Expand All @@ -348,14 +349,6 @@ async def websocket_handler(request, *, is_anonymous=False) -> web.StreamRespons
session["api_endpoints"][app] = str(api_endpoint)
should_save_session = True

proxy_path, _, real_path = request.path.lstrip("/").partition("/")
if proxy_path == "pipeline":
pipeline_config = request.app["config"]["pipeline"]
if not pipeline_config:
raise RuntimeError("'pipeline' config must be set to handle pipeline requests.")
endpoint = pipeline_config["endpoint"].with_scheme("ws")
log.info(f"WEBSOCKET_HANDLER {request.path} -> {endpoint}/{real_path}")
is_anonymous = True
if is_anonymous:
api_session = await asyncio.shield(get_anonymous_session(request, api_endpoint))
else:
Expand Down
11 changes: 6 additions & 5 deletions src/ai/backend/web/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -655,15 +655,16 @@ async def server_main(
anon_web_plugin_handler = partial(web_plugin_handler, is_anonymous=True)

pipeline_api_endpoint = config["pipeline"]["endpoint"]
pipeline_handler = partial(
web_handler, is_anonymous=True, override_api_endpoint=pipeline_api_endpoint
)
pipeline_handler = partial(web_handler, is_anonymous=True, api_endpoint=pipeline_api_endpoint)
pipeline_login_handler = partial(
web_handler,
is_anonymous=False,
override_api_endpoint=pipeline_api_endpoint,
api_endpoint=pipeline_api_endpoint,
extra_forwarding_headers={"X-BackendAI-SessionID"},
)
pipeline_websocket_handler = partial(
websocket_handler, is_anonymous=True, api_endpoint=pipeline_api_endpoint.with_scheme("ws")
)

app.router.add_route("HEAD", "/func/{path:folders/_/tus/upload/.*$}", anon_web_plugin_handler)
app.router.add_route("PATCH", "/func/{path:folders/_/tus/upload/.*$}", anon_web_plugin_handler)
Expand Down Expand Up @@ -699,7 +700,7 @@ async def server_main(
cors.add(app.router.add_route("POST", "/func/{path:.*$}", web_handler))
cors.add(app.router.add_route("PATCH", "/func/{path:.*$}", web_handler))
cors.add(app.router.add_route("DELETE", "/func/{path:.*$}", web_handler))
cors.add(app.router.add_route("GET", "/pipeline/{path:stream/.*$}", websocket_handler))
cors.add(app.router.add_route("GET", "/pipeline/{path:stream/.*$}", pipeline_websocket_handler))
cors.add(app.router.add_route("POST", "/pipeline/login/", pipeline_login_handler))
cors.add(app.router.add_route("GET", "/pipeline/{path:.*$}", pipeline_handler))
cors.add(app.router.add_route("PUT", "/pipeline/{path:.*$}", pipeline_handler))
Expand Down

0 comments on commit fe42788

Please sign in to comment.