diff --git a/src/ai/backend/web/proxy.py b/src/ai/backend/web/proxy.py index af70554ada..fa7549ee0f 100644 --- a/src/ai/backend/web/proxy.py +++ b/src/ai/backend/web/proxy.py @@ -5,7 +5,7 @@ import json import logging import random -from typing import Optional, Tuple, Union, cast +from typing import Iterable, Optional, Tuple, Union, cast import aiohttp from aiohttp import web @@ -150,23 +150,21 @@ async def decrypt_payload(request: web.Request, handler) -> web.StreamResponse: return await handler(request) -async def web_handler(request: web.Request, *, is_anonymous=False) -> web.StreamResponse: +async def web_handler( + request: web.Request, + *, + is_anonymous: bool = False, + override_api_endpoint: Optional[str] = None, + extra_forwarding_headers: Iterable[str] | None = None, +) -> web.StreamResponse: stats: WebStats = request.app["stats"] stats.active_proxy_api_handlers.add(asyncio.current_task()) # type: ignore - config = request.app["config"] path = request.match_info.get("path", "") - proxy_path, _, real_path = request.path.lstrip("/").partition("/") - if proxy_path == "pipeline": - pipeline_config = config["pipeline"] - if not pipeline_config: - raise RuntimeError("'pipeline' config must be set to handle pipeline requests.") - endpoint = pipeline_config["endpoint"] - log.info(f"WEB_HANDLER: {request.path} -> {endpoint}/{real_path}") - is_anonymous = not real_path.lstrip("/").startswith("login") if is_anonymous: - api_session = await asyncio.shield(get_anonymous_session(request)) + api_session = await asyncio.shield(get_anonymous_session(request, override_api_endpoint)) else: - api_session = await asyncio.shield(get_api_session(request)) + api_session = await asyncio.shield(get_api_session(request, override_api_endpoint)) + extra_forwarding_headers = extra_forwarding_headers or [] try: async with api_session: # We perform request signing by ourselves using the HTTP session data, @@ -202,13 +200,9 @@ async def web_handler(request: web.Request, *, is_anonymous=False) -> web.Stream api_rqst.headers["Content-Length"] = request.headers["Content-Length"] if "Content-Length" in request.headers and secure_context: api_rqst.headers["Content-Length"] = str(decrypted_payload_length) - for hdr in HTTP_HEADERS_TO_FORWARD: + for hdr in {*HTTP_HEADERS_TO_FORWARD, *extra_forwarding_headers}: if request.headers.get(hdr) is not None: api_rqst.headers[hdr] = request.headers[hdr] - if proxy_path == "pipeline" and real_path.rstrip("/") == "login": - api_rqst.headers["X-BackendAI-SessionID"] = request.headers.get( - "X-BackendAI-SessionID", "" - ) # Uploading request body happens at the entering of the block, # and downloading response body happens in the read loop inside. async with api_rqst.fetch() as up_resp: diff --git a/src/ai/backend/web/server.py b/src/ai/backend/web/server.py index 05ec559c05..59dd1dfa85 100644 --- a/src/ai/backend/web/server.py +++ b/src/ai/backend/web/server.py @@ -654,6 +654,17 @@ async def server_main( anon_web_handler = partial(web_handler, is_anonymous=True) anon_web_plugin_handler = partial(web_plugin_handler, is_anonymous=True) + pipeline_api_endpoint = config["pipeline"]["endpoint"] + pipeline_handler = partial( + web_handler, is_anonymous=True, override_api_endpoint=pipeline_api_endpoint + ) + pipeline_login_handler = partial( + web_handler, + is_anonymous=False, + override_api_endpoint=pipeline_api_endpoint, + extra_forwarding_headers={"X-BackendAI-SessionID"}, + ) + app.router.add_route("HEAD", "/func/{path:folders/_/tus/upload/.*$}", anon_web_plugin_handler) app.router.add_route("PATCH", "/func/{path:folders/_/tus/upload/.*$}", anon_web_plugin_handler) app.router.add_route( @@ -689,11 +700,12 @@ async def server_main( cors.add(app.router.add_route("PATCH", "/func/{path:.*$}", web_handler)) cors.add(app.router.add_route("DELETE", "/func/{path:.*$}", web_handler)) cors.add(app.router.add_route("GET", "/pipeline/{path:stream/.*$}", websocket_handler)) - cors.add(app.router.add_route("GET", "/pipeline/{path:.*$}", web_handler)) - cors.add(app.router.add_route("PUT", "/pipeline/{path:.*$}", web_handler)) - cors.add(app.router.add_route("POST", "/pipeline/{path:.*$}", web_handler)) - cors.add(app.router.add_route("PATCH", "/pipeline/{path:.*$}", web_handler)) - cors.add(app.router.add_route("DELETE", "/pipeline/{path:.*$}", web_handler)) + cors.add(app.router.add_route("POST", "/pipeline/login/", pipeline_login_handler)) + cors.add(app.router.add_route("GET", "/pipeline/{path:.*$}", pipeline_handler)) + cors.add(app.router.add_route("PUT", "/pipeline/{path:.*$}", pipeline_handler)) + cors.add(app.router.add_route("POST", "/pipeline/{path:.*$}", pipeline_handler)) + cors.add(app.router.add_route("PATCH", "/pipeline/{path:.*$}", pipeline_handler)) + cors.add(app.router.add_route("DELETE", "/pipeline/{path:.*$}", pipeline_handler)) if config["service"]["mode"] == "webui": cors.add(app.router.add_route("GET", "/config.ini", config_ini_handler)) cors.add(app.router.add_route("GET", "/config.toml", config_toml_handler))