Skip to content

Commit

Permalink
refactor: Update pipeline API handler to be configurable from outside
Browse files Browse the repository at this point in the history
  • Loading branch information
rapsealk committed Jan 15, 2025
1 parent 98649d4 commit d9b5e97
Show file tree
Hide file tree
Showing 2 changed files with 29 additions and 23 deletions.
30 changes: 12 additions & 18 deletions src/ai/backend/web/proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import json
import logging
import random
from typing import Optional, Tuple, Union, cast
from typing import Iterable, Optional, Tuple, Union, cast

import aiohttp
from aiohttp import web
Expand Down Expand Up @@ -150,23 +150,21 @@ async def decrypt_payload(request: web.Request, handler) -> web.StreamResponse:
return await handler(request)


async def web_handler(request: web.Request, *, is_anonymous=False) -> web.StreamResponse:
async def web_handler(
request: web.Request,
*,
is_anonymous: bool = False,
override_api_endpoint: Optional[str] = None,
extra_forwarding_headers: Iterable[str] | None = None,
) -> web.StreamResponse:
stats: WebStats = request.app["stats"]
stats.active_proxy_api_handlers.add(asyncio.current_task()) # type: ignore
config = request.app["config"]
path = request.match_info.get("path", "")
proxy_path, _, real_path = request.path.lstrip("/").partition("/")
if proxy_path == "pipeline":
pipeline_config = config["pipeline"]
if not pipeline_config:
raise RuntimeError("'pipeline' config must be set to handle pipeline requests.")
endpoint = pipeline_config["endpoint"]
log.info(f"WEB_HANDLER: {request.path} -> {endpoint}/{real_path}")
is_anonymous = not real_path.lstrip("/").startswith("login")
if is_anonymous:
api_session = await asyncio.shield(get_anonymous_session(request))
api_session = await asyncio.shield(get_anonymous_session(request, override_api_endpoint))
else:
api_session = await asyncio.shield(get_api_session(request))
api_session = await asyncio.shield(get_api_session(request, override_api_endpoint))
extra_forwarding_headers = extra_forwarding_headers or []
try:
async with api_session:
# We perform request signing by ourselves using the HTTP session data,
Expand Down Expand Up @@ -202,13 +200,9 @@ async def web_handler(request: web.Request, *, is_anonymous=False) -> web.Stream
api_rqst.headers["Content-Length"] = request.headers["Content-Length"]
if "Content-Length" in request.headers and secure_context:
api_rqst.headers["Content-Length"] = str(decrypted_payload_length)
for hdr in HTTP_HEADERS_TO_FORWARD:
for hdr in {*HTTP_HEADERS_TO_FORWARD, *extra_forwarding_headers}:
if request.headers.get(hdr) is not None:
api_rqst.headers[hdr] = request.headers[hdr]
if proxy_path == "pipeline" and real_path.rstrip("/") == "login":
api_rqst.headers["X-BackendAI-SessionID"] = request.headers.get(
"X-BackendAI-SessionID", ""
)
# Uploading request body happens at the entering of the block,
# and downloading response body happens in the read loop inside.
async with api_rqst.fetch() as up_resp:
Expand Down
22 changes: 17 additions & 5 deletions src/ai/backend/web/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -654,6 +654,17 @@ async def server_main(
anon_web_handler = partial(web_handler, is_anonymous=True)
anon_web_plugin_handler = partial(web_plugin_handler, is_anonymous=True)

pipeline_api_endpoint = config["pipeline"]["endpoint"]
pipeline_handler = partial(
web_handler, is_anonymous=True, override_api_endpoint=pipeline_api_endpoint
)
pipeline_login_handler = partial(
web_handler,
is_anonymous=False,
override_api_endpoint=pipeline_api_endpoint,
extra_forwarding_headers={"X-BackendAI-SessionID"},
)

app.router.add_route("HEAD", "/func/{path:folders/_/tus/upload/.*$}", anon_web_plugin_handler)
app.router.add_route("PATCH", "/func/{path:folders/_/tus/upload/.*$}", anon_web_plugin_handler)
app.router.add_route(
Expand Down Expand Up @@ -689,11 +700,12 @@ async def server_main(
cors.add(app.router.add_route("PATCH", "/func/{path:.*$}", web_handler))
cors.add(app.router.add_route("DELETE", "/func/{path:.*$}", web_handler))
cors.add(app.router.add_route("GET", "/pipeline/{path:stream/.*$}", websocket_handler))
cors.add(app.router.add_route("GET", "/pipeline/{path:.*$}", web_handler))
cors.add(app.router.add_route("PUT", "/pipeline/{path:.*$}", web_handler))
cors.add(app.router.add_route("POST", "/pipeline/{path:.*$}", web_handler))
cors.add(app.router.add_route("PATCH", "/pipeline/{path:.*$}", web_handler))
cors.add(app.router.add_route("DELETE", "/pipeline/{path:.*$}", web_handler))
cors.add(app.router.add_route("POST", "/pipeline/login/", pipeline_login_handler))
cors.add(app.router.add_route("GET", "/pipeline/{path:.*$}", pipeline_handler))
cors.add(app.router.add_route("PUT", "/pipeline/{path:.*$}", pipeline_handler))
cors.add(app.router.add_route("POST", "/pipeline/{path:.*$}", pipeline_handler))
cors.add(app.router.add_route("PATCH", "/pipeline/{path:.*$}", pipeline_handler))
cors.add(app.router.add_route("DELETE", "/pipeline/{path:.*$}", pipeline_handler))
if config["service"]["mode"] == "webui":
cors.add(app.router.add_route("GET", "/config.ini", config_ini_handler))
cors.add(app.router.add_route("GET", "/config.toml", config_toml_handler))
Expand Down

0 comments on commit d9b5e97

Please sign in to comment.