From f6981ff0b9a0fc68e5b85b744b10893ee677a1ab Mon Sep 17 00:00:00 2001 From: Anton Dubovik Date: Wed, 21 Feb 2024 12:17:35 +0000 Subject: [PATCH] feat: supported tokenize and truncate_prompt endpoints (#50) --- aidial_sdk/application.py | 193 ++++++++++++-------- aidial_sdk/chat_completion/__init__.py | 12 ++ aidial_sdk/chat_completion/base.py | 18 +- aidial_sdk/chat_completion/chunks.py | 6 +- aidial_sdk/chat_completion/request.py | 27 +-- aidial_sdk/chat_completion/response.py | 14 +- aidial_sdk/deployment/__init__.py | 0 aidial_sdk/deployment/from_request_mixin.py | 74 ++++++++ aidial_sdk/deployment/rate.py | 7 + aidial_sdk/deployment/tokenize.py | 39 ++++ aidial_sdk/deployment/truncate_prompt.py | 26 +++ aidial_sdk/telemetry/types.py | 2 +- aidial_sdk/utils/pydantic.py | 6 + pyproject.toml | 2 +- tests/applications/echo_application.py | 47 +++++ tests/applications/noop_application.py | 9 + tests/test_discarded_messages.py | 19 +- tests/test_errors.py | 27 +++ tests/test_rate_response.py | 34 ++-- tests/test_tokenize.py | 65 +++++++ tests/test_truncate_prompt.py | 108 +++++++++++ tests/utils/endpoint_test.py | 57 ++++++ tests/utils/errors.py | 52 ++++++ tests/utils/tokenization.py | 136 ++++++++++++++ 24 files changed, 851 insertions(+), 129 deletions(-) create mode 100644 aidial_sdk/deployment/__init__.py create mode 100644 aidial_sdk/deployment/from_request_mixin.py create mode 100644 aidial_sdk/deployment/rate.py create mode 100644 aidial_sdk/deployment/tokenize.py create mode 100644 aidial_sdk/deployment/truncate_prompt.py create mode 100644 aidial_sdk/utils/pydantic.py create mode 100644 tests/applications/echo_application.py create mode 100644 tests/applications/noop_application.py create mode 100644 tests/test_tokenize.py create mode 100644 tests/test_truncate_prompt.py create mode 100644 tests/utils/endpoint_test.py create mode 100644 tests/utils/errors.py create mode 100644 tests/utils/tokenization.py diff --git a/aidial_sdk/application.py b/aidial_sdk/application.py index ff392a0..59adfd1 100644 --- a/aidial_sdk/application.py +++ b/aidial_sdk/application.py @@ -1,16 +1,19 @@ import logging.config -from json import JSONDecodeError -from typing import Dict, Optional, Union +from typing import Dict, Optional, Type, TypeVar from fastapi import FastAPI, HTTPException, Request from fastapi.responses import JSONResponse, Response, StreamingResponse from aidial_sdk.chat_completion.base import ChatCompletion -from aidial_sdk.chat_completion.request import RateRequest from aidial_sdk.chat_completion.request import Request as ChatCompletionRequest from aidial_sdk.chat_completion.response import ( Response as ChatCompletionResponse, ) +from aidial_sdk.deployment.from_request_mixin import FromRequestMixin +from aidial_sdk.deployment.rate import RateRequest +from aidial_sdk.deployment.tokenize import TokenizeRequest +from aidial_sdk.deployment.truncate_prompt import TruncatePromptRequest +from aidial_sdk.exceptions import HTTPException as DIALException from aidial_sdk.header_propagator import HeaderPropagator from aidial_sdk.pydantic_v1 import ValidationError from aidial_sdk.telemetry.types import TelemetryConfig @@ -21,6 +24,8 @@ logging.config.dictConfig(LogConfig().dict()) +RequestType = TypeVar("RequestType", bound=FromRequestMixin) + class DIALApp(FastAPI): chat_completion_impls: Dict[str, ChatCompletion] = {} @@ -61,7 +66,29 @@ def __init__( methods=["POST"], ) - self.add_exception_handler(HTTPException, DIALApp._exception_handler) + self.add_api_route( + 
"/openai/deployments/{deployment_id}/tokenize", + self._endpoint_factory("tokenize", TokenizeRequest), + methods=["POST"], + ) + + self.add_api_route( + "/openai/deployments/{deployment_id}/truncate_prompt", + self._endpoint_factory("truncate_prompt", TruncatePromptRequest), + methods=["POST"], + ) + + self.add_exception_handler( + ValidationError, DIALApp._pydantic_validation_exception_handler + ) + + self.add_exception_handler( + HTTPException, DIALApp._fastapi_exception_handler + ) + + self.add_exception_handler( + DIALException, DIALApp._dial_exception_handler + ) def configure_telemetry(self, config: TelemetryConfig): try: @@ -79,58 +106,55 @@ def add_chat_completion( ) -> None: self.chat_completion_impls[deployment_name] = impl + def _endpoint_factory( + self, endpoint: str, request_type: Type["RequestType"] + ): + async def _handler( + deployment_id: str, original_request: Request + ) -> Response: + set_log_deployment(deployment_id) + deployment = self._get_deployment(deployment_id) + + request = await request_type.from_request(original_request) + + endpoint_impl = getattr(deployment, endpoint, None) + if not endpoint_impl: + raise self._get_missing_endpoint_error(endpoint) + + try: + response = await endpoint_impl(request) + except NotImplementedError: + raise self._get_missing_endpoint_error(endpoint) + + response_json = response.dict() + log_debug(f"response [{endpoint}]: {response_json}") + return JSONResponse(content=response_json) + + return _handler + async def _rate_response( self, deployment_id: str, original_request: Request ) -> Response: set_log_deployment(deployment_id) - impl = self._get_deployment(deployment_id) + deployment = self._get_deployment(deployment_id) - if isinstance(impl, JSONResponse): - return impl + request = await RateRequest.from_request(original_request) - body = await DIALApp._get_json_body(original_request) - if isinstance(body, JSONResponse): - return body - log_debug(f"request: {body}") - - try: - request = RateRequest(**body) - except ValidationError as e: - return DIALApp._get_validation_error_response(e) - - await impl.rate_response(request) + await deployment.rate_response(request) return Response(status_code=200) async def _chat_completion( self, deployment_id: str, original_request: Request ) -> Response: set_log_deployment(deployment_id) - impl = self._get_deployment(deployment_id) + deployment = self._get_deployment(deployment_id) - if isinstance(impl, JSONResponse): - return impl - - body = await DIALApp._get_json_body(original_request) - if isinstance(body, JSONResponse): - return body - - headers = original_request.headers - try: - request = ChatCompletionRequest( - **body, - api_key=headers["Api-Key"], - jwt=headers.get("Authorization"), - deployment_id=deployment_id, - api_version=original_request.query_params.get("api-version"), - headers=headers, - ) - except ValidationError as e: - return DIALApp._get_validation_error_response(e) - - log_debug(f"request: {body}") + request = await ChatCompletionRequest.from_request(original_request) response = ChatCompletionResponse(request) - first_chunk = await response._generator(impl.chat_completion, request) + first_chunk = await response._generator( + deployment.chat_completion, request + ) if request.stream: return StreamingResponse( @@ -138,65 +162,72 @@ async def _chat_completion( media_type="text/event-stream", ) else: - response_body = await merge_chunks( + response_json = await merge_chunks( response._generate_stream(first_chunk) ) - log_debug(f"response: {response_body}") - return 
JSONResponse(content=response_body) + log_debug(f"response: {response_json}") + return JSONResponse(content=response_json) + + @staticmethod + async def _healthcheck() -> JSONResponse: + return JSONResponse(content={"status": "ok"}) - def _get_deployment( - self, deployment_id: str - ) -> Union[ChatCompletion, JSONResponse]: + def _get_deployment(self, deployment_id: str) -> ChatCompletion: impl = self.chat_completion_impls.get(deployment_id, None) if not impl: - return JSONResponse( + raise DIALException( status_code=404, - content=json_error( - message="The API deployment for this resource does not exist.", - code="deployment_not_found", - ), + code="deployment_not_found", + message="The API deployment for this resource does not exist.", ) + return impl @staticmethod - async def _get_json_body(request: Request) -> Union[dict, JSONResponse]: - try: - return await request.json() - except JSONDecodeError as e: - return JSONResponse( - status_code=400, - content=json_error( - message=f"Your request contained invalid JSON: {str(e.msg)}", - type="invalid_request_error", - ), - ) + def _get_missing_endpoint_error(endpoint: str) -> DIALException: + return DIALException( + status_code=404, + code="endpoint_not_found", + message=f"The deployment doesn't implement '{endpoint}' endpoint.", + ) @staticmethod - def _get_validation_error_response( - e: ValidationError, + def _pydantic_validation_exception_handler( + request: Request, exc: Exception ) -> JSONResponse: - error = e.errors()[0] - path = ".".join(map(str, e.errors()[0]["loc"])) + assert isinstance(exc, ValidationError) + + error = exc.errors()[0] + path = ".".join(map(str, error["loc"])) + message = f"Your request contained invalid structure on path {path}. {error['msg']}" return JSONResponse( status_code=400, - content=json_error( - message=f"Your request contained invalid structure on path {path}. 
{error['msg']}", - type="invalid_request_error", - ), + content=json_error(message=message, type="invalid_request_error"), ) @staticmethod - async def _healthcheck() -> JSONResponse: - return JSONResponse(content={"status": "ok"}) + def _fastapi_exception_handler( + request: Request, exc: Exception + ) -> JSONResponse: + assert isinstance(exc, HTTPException) + return JSONResponse( + status_code=exc.status_code, + content=exc.detail, + ) @staticmethod - def _exception_handler(request: Request, exc: Exception): - if isinstance(exc, HTTPException): - return JSONResponse( - status_code=exc.status_code, - content=exc.detail, - ) - else: - raise exc + def _dial_exception_handler( + request: Request, exc: Exception + ) -> JSONResponse: + assert isinstance(exc, DIALException) + return JSONResponse( + status_code=exc.status_code, + content=json_error( + message=exc.message, + type=exc.type, + param=exc.param, + code=exc.code, + ), + ) diff --git a/aidial_sdk/chat_completion/__init__.py b/aidial_sdk/chat_completion/__init__.py index 07e1ccd..03d5c30 100644 --- a/aidial_sdk/chat_completion/__init__.py +++ b/aidial_sdk/chat_completion/__init__.py @@ -17,3 +17,15 @@ ) from aidial_sdk.chat_completion.response import Response from aidial_sdk.chat_completion.stage import Stage +from aidial_sdk.deployment.tokenize import ( + TokenizeError, + TokenizeRequest, + TokenizeResponse, + TokenizeSuccess, +) +from aidial_sdk.deployment.truncate_prompt import ( + TruncatePromptError, + TruncatePromptRequest, + TruncatePromptResponse, + TruncatePromptSuccess, +) diff --git a/aidial_sdk/chat_completion/base.py b/aidial_sdk/chat_completion/base.py index 4f4d476..69ad845 100644 --- a/aidial_sdk/chat_completion/base.py +++ b/aidial_sdk/chat_completion/base.py @@ -1,7 +1,13 @@ from abc import ABC, abstractmethod -from aidial_sdk.chat_completion.request import RateRequest, Request +from aidial_sdk.chat_completion.request import Request from aidial_sdk.chat_completion.response import Response +from aidial_sdk.deployment.rate import RateRequest +from aidial_sdk.deployment.tokenize import TokenizeRequest, TokenizeResponse +from aidial_sdk.deployment.truncate_prompt import ( + TruncatePromptRequest, + TruncatePromptResponse, +) class ChatCompletion(ABC): @@ -13,3 +19,13 @@ async def chat_completion( async def rate_response(self, request: RateRequest) -> None: """Implement rate response logic""" + + async def tokenize(self, request: TokenizeRequest) -> TokenizeResponse: + """Implement tokenize logic""" + raise NotImplementedError() + + async def truncate_prompt( + self, request: TruncatePromptRequest + ) -> TruncatePromptResponse: + """Implement truncate prompt logic""" + raise NotImplementedError() diff --git a/aidial_sdk/chat_completion/chunks.py b/aidial_sdk/chat_completion/chunks.py index 8dd6abc..5183bb2 100644 --- a/aidial_sdk/chat_completion/chunks.py +++ b/aidial_sdk/chat_completion/chunks.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, Optional +from typing import Any, Dict, List, Optional from aidial_sdk.chat_completion.enums import FinishReason, Status from aidial_sdk.pydantic_v1 import BaseModel, root_validator @@ -457,9 +457,9 @@ def to_dict(self): class DiscardedMessagesChunk(BaseChunk): - discarded_messages: int + discarded_messages: List[int] - def __init__(self, discarded_messages: int): + def __init__(self, discarded_messages: List[int]): self.discarded_messages = discarded_messages def to_dict(self): diff --git a/aidial_sdk/chat_completion/request.py 
b/aidial_sdk/chat_completion/request.py index 71b126a..91a38ba 100644 --- a/aidial_sdk/chat_completion/request.py +++ b/aidial_sdk/chat_completion/request.py @@ -1,20 +1,15 @@ from enum import Enum from typing import Any, Dict, List, Literal, Mapping, Optional, Union +from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin from aidial_sdk.pydantic_v1 import ( - BaseModel, ConstrainedFloat, ConstrainedInt, ConstrainedList, - Field, PositiveInt, StrictStr, ) - - -class ExtraForbidModel(BaseModel): - class Config: - extra = "forbid" +from aidial_sdk.utils.pydantic import ExtraForbidModel class Attachment(ExtraForbidModel): @@ -112,7 +107,7 @@ class ToolChoice(ExtraForbidModel): function: FunctionChoice -class Request(ExtraForbidModel): +class AzureChatCompletionRequest(ExtraForbidModel): model: Optional[StrictStr] = None messages: List[Message] functions: Optional[List[Function]] = None @@ -128,19 +123,17 @@ class Request(ExtraForbidModel): n: Optional[N] = None stop: Optional[Union[StrictStr, Stop]] = None max_tokens: Optional[PositiveInt] = None - max_prompt_tokens: Optional[PositiveInt] = None presence_penalty: Optional[Penalty] = None frequency_penalty: Optional[Penalty] = None logit_bias: Optional[Mapping[int, float]] = None user: Optional[StrictStr] = None - api_key: StrictStr - jwt: Optional[StrictStr] = None - deployment_id: StrictStr - api_version: Optional[StrictStr] = None - headers: Mapping[StrictStr, StrictStr] + +class ChatCompletionRequest(AzureChatCompletionRequest): + model: Optional[StrictStr] = None + addons: Optional[List[Addon]] = None + max_prompt_tokens: Optional[PositiveInt] = None -class RateRequest(ExtraForbidModel): - response_id: StrictStr = Field(None, alias="responseId") - rate: bool = False +class Request(ChatCompletionRequest, FromRequestDeploymentMixin): + pass diff --git a/aidial_sdk/chat_completion/response.py b/aidial_sdk/chat_completion/response.py index fdf99f4..e309886 100644 --- a/aidial_sdk/chat_completion/response.py +++ b/aidial_sdk/chat_completion/response.py @@ -1,6 +1,14 @@ import asyncio from time import time -from typing import Any, AsyncGenerator, Callable, Coroutine, Dict, Optional +from typing import ( + Any, + AsyncGenerator, + Callable, + Coroutine, + Dict, + List, + Optional, +) from uuid import uuid4 from fastapi import HTTPException @@ -204,7 +212,7 @@ async def _generate_stream( async def _generator( self, - producer: Callable[[Any, Any], Coroutine[Any, Any, Any]], + producer: Callable[[Request, "Response"], Coroutine[Any, Any, Any]], request: Request, ) -> BaseChunk: self.user_task = asyncio.create_task(producer(request, self)) @@ -285,7 +293,7 @@ def add_usage_per_model( ) self._last_usage_per_model_index += 1 - def set_discarded_messages(self, discarded_messages: int): + def set_discarded_messages(self, discarded_messages: List[int]): self._generation_started = True if self._discarded_messages_generated: diff --git a/aidial_sdk/deployment/__init__.py b/aidial_sdk/deployment/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/aidial_sdk/deployment/from_request_mixin.py b/aidial_sdk/deployment/from_request_mixin.py new file mode 100644 index 0000000..38cdbed --- /dev/null +++ b/aidial_sdk/deployment/from_request_mixin.py @@ -0,0 +1,74 @@ +from abc import ABC, abstractmethod +from json import JSONDecodeError +from typing import Any, Mapping, Optional, Type, TypeVar + +from fastapi import Request + +from aidial_sdk.exceptions import HTTPException as DIALException +from aidial_sdk.pydantic_v1 
import StrictStr +from aidial_sdk.utils.logging import log_debug +from aidial_sdk.utils.pydantic import ExtraForbidModel + +T = TypeVar("T", bound="FromRequestMixin") + + +class FromRequestMixin(ABC, ExtraForbidModel): + @classmethod + @abstractmethod + async def from_request(cls: Type[T], request: Request) -> T: + pass + + +class FromRequestBasicMixin(FromRequestMixin): + @classmethod + async def from_request(cls, request: Request): + return cls(**(await _get_request_body(request))) + + +class FromRequestDeploymentMixin(FromRequestMixin): + api_key: StrictStr + jwt: Optional[StrictStr] = None + deployment_id: StrictStr + api_version: Optional[StrictStr] = None + headers: Mapping[StrictStr, StrictStr] + + @classmethod + async def from_request(cls, request: Request): + deployment_id = request.path_params.get("deployment_id") + if deployment_id is None or not isinstance(deployment_id, str): + raise DIALException( + status_code=404, + type="invalid_path", + message="Invalid path", + ) + + headers = request.headers + api_key = headers.get("Api-Key") + if api_key is None: + raise DIALException( + status_code=400, + type="invalid_request_error", + message="Api-Key header is required", + ) + + return cls( + **(await _get_request_body(request)), + api_key=api_key, + jwt=headers.get("Authorization"), + deployment_id=deployment_id, + api_version=request.query_params.get("api-version"), + headers=headers, + ) + + +async def _get_request_body(request: Request) -> Any: + try: + body = await request.json() + log_debug(f"request: {body}") + return body + except JSONDecodeError as e: + raise DIALException( + status_code=400, + type="invalid_request_error", + message=f"Your request contained invalid JSON: {e.msg}", + ) diff --git a/aidial_sdk/deployment/rate.py b/aidial_sdk/deployment/rate.py new file mode 100644 index 0000000..c81b97c --- /dev/null +++ b/aidial_sdk/deployment/rate.py @@ -0,0 +1,7 @@ +from aidial_sdk.deployment.from_request_mixin import FromRequestBasicMixin +from aidial_sdk.pydantic_v1 import Field, StrictStr + + +class RateRequest(FromRequestBasicMixin): + response_id: StrictStr = Field(None, alias="responseId") + rate: bool = False diff --git a/aidial_sdk/deployment/tokenize.py b/aidial_sdk/deployment/tokenize.py new file mode 100644 index 0000000..d52688b --- /dev/null +++ b/aidial_sdk/deployment/tokenize.py @@ -0,0 +1,39 @@ +from typing import List, Literal, Union + +from aidial_sdk.chat_completion.request import ChatCompletionRequest +from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin +from aidial_sdk.pydantic_v1 import BaseModel + + +class TokenizeInputRequest(BaseModel): + type: Literal["request"] = "request" + value: ChatCompletionRequest + + +class TokenizeInputString(BaseModel): + type: Literal["string"] = "string" + value: str + + +TokenizeInput = Union[TokenizeInputRequest, TokenizeInputString] + + +class TokenizeRequest(FromRequestDeploymentMixin): + inputs: List[TokenizeInput] + + +class TokenizeSuccess(BaseModel): + status: Literal["success"] = "success" + token_count: int + + +class TokenizeError(BaseModel): + status: Literal["error"] = "error" + error: str + + +TokenizeOutput = Union[TokenizeSuccess, TokenizeError] + + +class TokenizeResponse(BaseModel): + outputs: List[TokenizeOutput] diff --git a/aidial_sdk/deployment/truncate_prompt.py b/aidial_sdk/deployment/truncate_prompt.py new file mode 100644 index 0000000..042c4ad --- /dev/null +++ b/aidial_sdk/deployment/truncate_prompt.py @@ -0,0 +1,26 @@ +from typing import List, Literal, Union + 
+from aidial_sdk.chat_completion.request import ChatCompletionRequest +from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin +from aidial_sdk.pydantic_v1 import BaseModel + + +class TruncatePromptRequest(FromRequestDeploymentMixin): + inputs: List[ChatCompletionRequest] + + +class TruncatePromptSuccess(BaseModel): + status: Literal["success"] = "success" + discarded_messages: List[int] + + +class TruncatePromptError(BaseModel): + status: Literal["error"] = "error" + error: str + + +TruncatePromptResult = Union[TruncatePromptSuccess, TruncatePromptError] + + +class TruncatePromptResponse(BaseModel): + outputs: List[TruncatePromptResult] diff --git a/aidial_sdk/telemetry/types.py b/aidial_sdk/telemetry/types.py index 8d66170..613e4ac 100644 --- a/aidial_sdk/telemetry/types.py +++ b/aidial_sdk/telemetry/types.py @@ -1,6 +1,6 @@ from typing import Optional -from pydantic import BaseModel +from aidial_sdk.pydantic_v1 import BaseModel class TracingConfig(BaseModel): diff --git a/aidial_sdk/utils/pydantic.py b/aidial_sdk/utils/pydantic.py new file mode 100644 index 0000000..7b641fb --- /dev/null +++ b/aidial_sdk/utils/pydantic.py @@ -0,0 +1,6 @@ +from aidial_sdk.pydantic_v1 import BaseModel + + +class ExtraForbidModel(BaseModel): + class Config: + extra = "forbid" diff --git a/pyproject.toml b/pyproject.toml index cae8199..e3bce81 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -74,7 +74,7 @@ reportUnusedVariable = "error" reportIncompatibleMethodOverride = "error" exclude = [ ".git", - ".venv", + "**/.venv", ".nox", ".pytest_cache", "**/__pycache__", diff --git a/tests/applications/echo_application.py b/tests/applications/echo_application.py new file mode 100644 index 0000000..44bed79 --- /dev/null +++ b/tests/applications/echo_application.py @@ -0,0 +1,47 @@ +from typing_extensions import override + +from aidial_sdk.chat_completion import ChatCompletion, Request, Response +from aidial_sdk.deployment.tokenize import TokenizeRequest, TokenizeResponse +from aidial_sdk.deployment.truncate_prompt import ( + TruncatePromptRequest, + TruncatePromptResponse, +) +from tests.utils.tokenization import ( + default_truncate_prompt, + make_batched_tokenize, + make_batched_truncate_prompt, + word_count_request, + word_count_tokenize, +) + + +class EchoApplication(ChatCompletion): + model_max_prompt_tokens: int + + def __init__(self, model_max_prompt_tokens: int): + self.model_max_prompt_tokens = model_max_prompt_tokens + + async def chat_completion( + self, request: Request, response: Response + ) -> None: + response.set_response_id("test_id") + response.set_created(0) + + content = request.messages[-1].content or "" + + with response.create_single_choice() as choice: + choice.append_content(content) + + @override + async def tokenize(self, request: TokenizeRequest) -> TokenizeResponse: + return make_batched_tokenize(word_count_tokenize)(request) + + @override + async def truncate_prompt( + self, request: TruncatePromptRequest + ) -> TruncatePromptResponse: + return make_batched_truncate_prompt( + lambda req: default_truncate_prompt( + req, word_count_request, self.model_max_prompt_tokens + ) + )(request) diff --git a/tests/applications/noop_application.py b/tests/applications/noop_application.py new file mode 100644 index 0000000..e842369 --- /dev/null +++ b/tests/applications/noop_application.py @@ -0,0 +1,9 @@ +from aidial_sdk.chat_completion import ChatCompletion, Request, Response + + +class NoopApplication(ChatCompletion): + async def chat_completion( + self, request: 
Request, response: Response + ) -> None: + with response.create_single_choice(): + pass diff --git a/tests/test_discarded_messages.py b/tests/test_discarded_messages.py index e5a442c..34256a7 100644 --- a/tests/test_discarded_messages.py +++ b/tests/test_discarded_messages.py @@ -7,6 +7,8 @@ from aidial_sdk import DIALApp, HTTPException from aidial_sdk.chat_completion import ChatCompletion, Request, Response +DISCARDED_MESSAGES = list(range(0, 12)) + def test_discarded_messages_returned(): dial_app = DIALApp() @@ -15,7 +17,7 @@ def test_discarded_messages_returned(): async def chat_completion_side_effect(_, res: Response) -> None: with res.create_single_choice(): pass - res.set_discarded_messages(12) + res.set_discarded_messages(DISCARDED_MESSAGES) chat_completion.chat_completion.side_effect = chat_completion_side_effect dial_app.add_chat_completion("test_app", chat_completion) @@ -30,7 +32,10 @@ async def chat_completion_side_effect(_, res: Response) -> None: headers={"Api-Key": "TEST_API_KEY"}, ) - assert response.json()["statistics"]["discarded_messages"] == 12 + assert ( + response.json()["statistics"]["discarded_messages"] + == DISCARDED_MESSAGES + ) def test_discarded_messages_returned_as_last_chunk_in_stream(): @@ -44,7 +49,7 @@ async def chat_completion_side_effect(_, res: Response) -> None: with res.create_single_choice(): pass - res.set_discarded_messages(12) + res.set_discarded_messages(DISCARDED_MESSAGES) chat_completion.chat_completion.side_effect = chat_completion_side_effect dial_app.add_chat_completion("test_app", chat_completion) @@ -95,7 +100,7 @@ def identity(data: str): { "choices": [{"index": 0, "finish_reason": "stop", "delta": {}}], "usage": None, - "statistics": {"discarded_messages": 12}, + "statistics": {"discarded_messages": DISCARDED_MESSAGES}, "id": "test_id", "created": 123, "object": "chat.completion.chunk", @@ -113,10 +118,10 @@ def test_discarded_messages_is_set_twice(): with response.create_single_choice(): pass - response.set_discarded_messages(1) + response.set_discarded_messages(DISCARDED_MESSAGES) with pytest.raises(HTTPException): - response.set_discarded_messages(1) + response.set_discarded_messages(DISCARDED_MESSAGES) def test_discarded_messages_is_set_before_choice(): @@ -124,4 +129,4 @@ def test_discarded_messages_is_set_before_choice(): response = Response(request) with pytest.raises(HTTPException): - response.set_discarded_messages(1) + response.set_discarded_messages(DISCARDED_MESSAGES) diff --git a/tests/test_errors.py b/tests/test_errors.py index 07e4d48..b971de5 100644 --- a/tests/test_errors.py +++ b/tests/test_errors.py @@ -5,6 +5,7 @@ from aidial_sdk import DIALApp from tests.applications.broken_application import BrokenApplication +from tests.applications.noop_application import NoopApplication from tests.applications.runtime_broken_application import ( RuntimeBrokenApplication, ) @@ -18,6 +19,15 @@ } } +API_KEY_IS_MISSING = { + "error": { + "code": None, + "message": "Api-Key header is required", + "param": None, + "type": "invalid_request_error", + } +} + error_testdata = [ ( "sdk_exception", @@ -151,3 +161,20 @@ def test_runtime_streaming_error(type, response_status_code, response_content): assert json.loads(data) == response_content elif index == 8: assert data == "[DONE]" + + +def test_no_api_key(): + dial_app = DIALApp() + dial_app.add_chat_completion("test_app", NoopApplication()) + + test_app = TestClient(dial_app) + + response = test_app.post( + "/openai/deployments/test_app/chat/completions", + json={ + "messages": [{"role": 
"user", "content": "test"}], + "stream": False, + }, + ) + + assert response.status_code == 400 and response.json() == API_KEY_IS_MISSING diff --git a/tests/test_rate_response.py b/tests/test_rate_response.py index 6d110f5..ced3412 100644 --- a/tests/test_rate_response.py +++ b/tests/test_rate_response.py @@ -1,21 +1,25 @@ -from starlette.testclient import TestClient +from typing import List -from aidial_sdk import DIALApp -from tests.applications.single_choice_application import SingleChoiceApplication +import pytest +from tests.applications.noop_application import NoopApplication +from tests.utils.endpoint_test import TestCase, run_endpoint_test +from tests.utils.errors import extra_fields_error -def test_rate_response(): - dial_app = DIALApp() - dial_app.add_chat_completion("test_app", SingleChoiceApplication()) +RATE_REQUEST_OK1 = {} +RATE_REQUEST_OK2 = {"responseId": "123", "rate": True} +RATE_REQUEST_FAIL = {"foo": "bar"} - test_app = TestClient(dial_app) - response = test_app.post( - "/openai/deployments/test_app/rate", - json={ - "responseId": "123", - "rate": True, - }, - ) +noop = NoopApplication() - assert response.status_code == 200 +testcases: List[TestCase] = [ + TestCase(noop, "rate", RATE_REQUEST_OK2, None), + TestCase(noop, "rate", RATE_REQUEST_OK1, None), + TestCase(noop, "rate", RATE_REQUEST_FAIL, extra_fields_error("foo")), +] + + +@pytest.mark.parametrize("testcase", testcases) +def test_rate_endpoint(testcase: TestCase): + run_endpoint_test(testcase) diff --git a/tests/test_tokenize.py b/tests/test_tokenize.py new file mode 100644 index 0000000..f898fe6 --- /dev/null +++ b/tests/test_tokenize.py @@ -0,0 +1,65 @@ +from typing import List + +import pytest + +from tests.applications.echo_application import EchoApplication +from tests.applications.noop_application import NoopApplication +from tests.utils.endpoint_test import TestCase, run_endpoint_test +from tests.utils.errors import ( + bad_request_error, + not_implemented_error, + route_not_found_error, +) + +CHAT_COMPLETION_REQUEST = { + "messages": [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "ping"}, + {"role": "assistant", "content": "pong"}, + {"role": "user", "content": "hello"}, + ], +} + +TOKENIZE_REQUEST_OK1 = { + "inputs": [ + {"type": "request", "value": CHAT_COMPLETION_REQUEST}, + {"type": "string", "value": "test string"}, + ] +} +TOKENIZE_RESPONSE_OK1 = { + "outputs": [ + {"status": "success", "token_count": 4}, + {"status": "success", "token_count": 2}, + ] +} + +TOKENIZE_REQUEST_OK2 = {"inputs": []} +TOKENIZE_RESPONSE_OK2 = {"outputs": []} + +TOKENIZE_REQUEST_FAIL = {"inputs": [{}]} + +noop = NoopApplication() +echo = EchoApplication + +testcases: List[TestCase] = [ + TestCase( + noop, + "tokenize", + TOKENIZE_REQUEST_OK1, + not_implemented_error("tokenize"), + ), + TestCase(noop, "tokenizer", TOKENIZE_REQUEST_OK1, route_not_found_error), + TestCase(echo(0), "tokenize", TOKENIZE_REQUEST_OK1, TOKENIZE_RESPONSE_OK1), + TestCase(echo(0), "tokenize", TOKENIZE_REQUEST_OK2, TOKENIZE_RESPONSE_OK2), + TestCase( + echo(0), + "tokenize", + TOKENIZE_REQUEST_FAIL, + bad_request_error("inputs.0.value"), + ), +] + + +@pytest.mark.parametrize("testcase", testcases) +def test_tokenize(testcase: TestCase): + run_endpoint_test(testcase) diff --git a/tests/test_truncate_prompt.py b/tests/test_truncate_prompt.py new file mode 100644 index 0000000..707db99 --- /dev/null +++ b/tests/test_truncate_prompt.py @@ -0,0 +1,108 @@ +from typing import List, Optional + +import pytest + +from 
tests.applications.echo_application import EchoApplication +from tests.applications.noop_application import NoopApplication +from tests.utils.endpoint_test import TestCase, run_endpoint_test +from tests.utils.errors import not_implemented_error, route_not_found_error + +CHAT_COMPLETION_REQUEST = { + "messages": [ + {"role": "system", "content": "system"}, + {"role": "user", "content": "ping"}, + {"role": "assistant", "content": "pong"}, + {"role": "user", "content": "hello"}, + ], +} + + +def create_request(max_prompt_tokens: Optional[int]): + return { + "inputs": [ + { + **CHAT_COMPLETION_REQUEST, + "max_prompt_tokens": max_prompt_tokens, + } + ] + } + + +def create_response( + model_max_prompt_tokens: int, max_prompt_tokens: Optional[int] +): + if max_prompt_tokens is None: + if model_max_prompt_tokens >= 4: + return { + "outputs": [{"status": "success", "discarded_messages": []}] + } + else: + return { + "outputs": [ + { + "status": "error", + "error": "Token count of all messages (4) exceeds " + f"the model maximum prompt tokens ({model_max_prompt_tokens}).", + } + ] + } + + if max_prompt_tokens == 1: + return { + "outputs": [ + { + "status": "error", + "error": "Token count of the last user message and all " + "system messages (2) exceeds the maximum prompt tokens (1).", + } + ] + } + if max_prompt_tokens == 2: + return { + "outputs": [{"status": "success", "discarded_messages": [1, 2]}] + } + if max_prompt_tokens == 3: + return {"outputs": [{"status": "success", "discarded_messages": [1]}]} + return {"outputs": [{"status": "success", "discarded_messages": []}]} + + +noop = NoopApplication() +echo = EchoApplication + +testcases: List[TestCase] = [ + TestCase( + noop, + "truncate_prompt", + create_request(None), + not_implemented_error("truncate_prompt"), + ), + TestCase( + noop, + "truncate_prompts", + create_request(None), + route_not_found_error, + ), + *[ + TestCase( + echo(4), + "truncate_prompt", + create_request(max_prompt_tokens), + create_response(4, max_prompt_tokens), + ) + for max_prompt_tokens in range(1, 6) + ], + *[ + TestCase( + echo(model_limit), + "truncate_prompt", + create_request(None), + create_response(model_limit, None), + ) + for model_limit in [3, 4] + ], +] + + +@pytest.mark.parametrize("testcase", testcases) +def test_truncate_prompt(testcase: TestCase): + run_endpoint_test(testcase) diff --git a/tests/utils/endpoint_test.py b/tests/utils/endpoint_test.py new file mode 100644 index 0000000..48cc91f --- /dev/null +++ b/tests/utils/endpoint_test.py @@ -0,0 +1,57 @@ +from typing import Union + +from starlette.testclient import TestClient + +from aidial_sdk import DIALApp +from aidial_sdk.chat_completion.base import ChatCompletion +from tests.utils.errors import Error + + +class TestCase: + __test__ = False + + app: ChatCompletion + endpoint: str + request: dict + response: Union[Error, dict, None] + + def __init__( + self, + app: ChatCompletion, + endpoint: str, + request: dict, + response: Union[Error, dict, None], + ): + self.app = app + self.endpoint = endpoint + self.request = request + self.response = response + + +def run_endpoint_test(testcase: TestCase): + dial_app = DIALApp() + dial_app.add_chat_completion("test_app", testcase.app) + + test_app = TestClient(dial_app) + + actual_response = test_app.post( + f"/openai/deployments/test_app/{testcase.endpoint}", + json=testcase.request, + headers={"Api-Key": "TEST_API_KEY"}, + ) + + if actual_response.text == "": + actual_response_body = None + else: + actual_response_body = actual_response.json() + + 
expected_response = testcase.response + if isinstance(expected_response, Error): + expected_response_code = expected_response.code + expected_response_body = expected_response.error + else: + expected_response_code = 200 + expected_response_body = expected_response + + assert actual_response.status_code == expected_response_code + assert actual_response_body == expected_response_body diff --git a/tests/utils/errors.py b/tests/utils/errors.py new file mode 100644 index 0000000..16a317b --- /dev/null +++ b/tests/utils/errors.py @@ -0,0 +1,52 @@ +from aidial_sdk.pydantic_v1 import BaseModel + + +class Error(BaseModel): + code: int + error: dict + + +def bad_request_error(path: str) -> Error: + return Error( + code=400, + error={ + "error": { + "code": None, + "message": f"Your request contained invalid structure on path {path}. field required", + "param": None, + "type": "invalid_request_error", + } + }, + ) + + +def not_implemented_error(endpoint: str) -> Error: + return Error( + code=404, + error={ + "error": { + "message": f"The deployment doesn't implement '{endpoint}' endpoint.", + "type": "runtime_error", + "code": "endpoint_not_found", + "param": None, + } + }, + ) + + +def extra_fields_error(path: str) -> Error: + return Error( + code=400, + error={ + "error": { + "code": None, + "message": f"Your request contained invalid structure on path {path}. " + "extra fields not permitted", + "param": None, + "type": "invalid_request_error", + } + }, + ) + + +route_not_found_error: Error = Error(code=404, error={"detail": "Not Found"}) diff --git a/tests/utils/tokenization.py b/tests/utils/tokenization.py new file mode 100644 index 0000000..1428356 --- /dev/null +++ b/tests/utils/tokenization.py @@ -0,0 +1,136 @@ +from typing import Callable, Optional, Set + +from aidial_sdk.chat_completion.request import ( + ChatCompletionRequest, + Message, + Role, +) +from aidial_sdk.deployment.tokenize import ( + TokenizeInput, + TokenizeOutput, + TokenizeRequest, + TokenizeResponse, + TokenizeSuccess, +) +from aidial_sdk.deployment.truncate_prompt import ( + TruncatePromptError, + TruncatePromptRequest, + TruncatePromptResponse, + TruncatePromptResult, + TruncatePromptSuccess, +) + + +def word_count_string(string: str) -> int: + return len(string.split()) + + +def word_count_message(message: Message) -> int: + return word_count_string(message.content or "") + + +def word_count_request(request: ChatCompletionRequest) -> int: + return sum(map(word_count_message, request.messages)) + + +def word_count_tokenize(request: TokenizeInput) -> TokenizeOutput: + if request.type == "request": + token_count = word_count_request(request.value) + elif request.type == "string": + token_count = word_count_string(request.value) + else: + raise ValueError(f"Unknown tokenize input type: {request.type}") + + return TokenizeSuccess(token_count=token_count) + + +def make_batched_tokenize( + tokenize: Callable[[TokenizeInput], TokenizeOutput] +) -> Callable[[TokenizeRequest], TokenizeResponse]: + def ret(request: TokenizeRequest) -> TokenizeResponse: + return TokenizeResponse( + outputs=[tokenize(inp) for inp in request.inputs] + ) + + return ret + + +def default_truncate_prompt( + request: ChatCompletionRequest, + count_request_tokens: Callable[[ChatCompletionRequest], int], + model_max_prompt_tokens: int, +) -> TruncatePromptResult: + def _count_tokens_selected(indices: Set[int]) -> int: + messages = [ + message + for idx, message in enumerate(request.messages) + if idx in indices + ] + sub_request = 
request.copy(update={"messages": messages}) + return count_request_tokens(sub_request) + + all_indices = set(range(0, len(request.messages))) + + max_prompt_tokens: Optional[int] = request.max_prompt_tokens + if max_prompt_tokens is None: + token_count = _count_tokens_selected(all_indices) + if token_count > model_max_prompt_tokens: + return TruncatePromptError( + error=f"Token count of all messages ({token_count}) exceeds" + f" the model maximum prompt tokens ({model_max_prompt_tokens}).", + ) + return TruncatePromptSuccess(discarded_messages=[]) + + token_count: int = 0 + found_user_message = False + selected_indices: Set[int] = set() + + for idx in reversed(range(0, len(request.messages))): + message = request.messages[idx] + + is_user_message = message.role == Role.USER + is_last_user_message = not found_user_message and is_user_message + found_user_message = found_user_message or is_user_message + + is_message_required = ( + message.role == Role.SYSTEM or is_last_user_message + ) + + if not is_message_required: + continue + + selected_indices.add(idx) + token_count = _count_tokens_selected(selected_indices) + + if token_count > max_prompt_tokens: + return TruncatePromptError( + error="Token count of the last user message and all system messages " + f"({token_count}) exceeds the maximum prompt tokens ({max_prompt_tokens}).", + ) + + for idx in reversed(range(0, len(request.messages))): + if idx in selected_indices: + continue + + new_token_count = _count_tokens_selected({*selected_indices, idx}) + if new_token_count > max_prompt_tokens: + break + + selected_indices.add(idx) + token_count = new_token_count + + discarded_indices = all_indices - selected_indices + return TruncatePromptSuccess( + discarded_messages=list(sorted(discarded_indices)) + ) + + +def make_batched_truncate_prompt( + truncate_prompt: Callable[[ChatCompletionRequest], TruncatePromptResult], +) -> Callable[[TruncatePromptRequest], TruncatePromptResponse]: + def ret(request: TruncatePromptRequest) -> TruncatePromptResponse: + return TruncatePromptResponse( + outputs=[truncate_prompt(req) for req in request.inputs] + ) + + return ret