From c35a82922821a6f3cd045a8476618a2af6f1f4a9 Mon Sep 17 00:00:00 2001 From: Sanghun Lee Date: Tue, 2 Jan 2024 18:34:38 +0900 Subject: [PATCH] Revert "remove middleware" This reverts commit 1dd2c331de6fbff1931cbd14abab193d2873d76b. --- src/ai/backend/manager/api/service.py | 9 ++++++++- src/ai/backend/manager/api/utils.py | 4 ++-- 2 files changed, 10 insertions(+), 3 deletions(-) diff --git a/src/ai/backend/manager/api/service.py b/src/ai/backend/manager/api/service.py index 84ab17d03d..c375491837 100644 --- a/src/ai/backend/manager/api/service.py +++ b/src/ai/backend/manager/api/service.py @@ -61,6 +61,7 @@ from .types import CORSOptions, WebMiddleware from .utils import ( check_api_params_v2, + convert_response, get_access_key_scopes, get_user_uuid_scopes, undefined, @@ -912,6 +913,12 @@ async def clear_error(request: web.Request) -> web.Response: return web.Response(status=204) +@web.middleware +async def response_middleware(request: web.Request, handler) -> web.StreamResponse: + result = await handler(request) + return convert_response(result) + + @attrs.define(slots=True, auto_attribs=True, init=False) class PrivateContext: database_ptask_group: aiotools.PersistentTaskGroup @@ -950,4 +957,4 @@ def create_app( cors.add(add_route("PUT", "/{service_id}/routings/{route_id}", update_route)) cors.add(add_route("DELETE", "/{service_id}/routings/{route_id}", delete_route)) cors.add(add_route("POST", "/{service_id}/token", generate_token)) - return app, [] + return app, [response_middleware] diff --git a/src/ai/backend/manager/api/utils.py b/src/ai/backend/manager/api/utils.py index c5111932a8..66ee50752f 100644 --- a/src/ai/backend/manager/api/utils.py +++ b/src/ai/backend/manager/api/utils.py @@ -217,12 +217,12 @@ async def wrapped(request: web.Request, *args: P.args, **kwargs: P.kwargs) -> TA def convert_response(response: TResponseModel | list | TAnyResponse) -> web.StreamResponse: match response: - case web_response.StreamResponse(): - return response case BaseModel(): return web.json_response(response.model_dump()) case list(): return web.json_response(TypeAdapter(list[TResponseModel]).dump_python(response)) + case web_response.StreamResponse(): + return response case _: raise RuntimeError(f"Unsupported response type ({type(response)})")