Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: supported tokenize and truncate_prompt endpoints #50

Merged
merged 15 commits into from
Feb 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 112 additions & 81 deletions aidial_sdk/application.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import logging.config
from json import JSONDecodeError
from typing import Dict, Optional, Union
from typing import Dict, Optional, Type, TypeVar

from fastapi import FastAPI, HTTPException, Request
from fastapi.responses import JSONResponse, Response, StreamingResponse

from aidial_sdk.chat_completion.base import ChatCompletion
from aidial_sdk.chat_completion.request import RateRequest
from aidial_sdk.chat_completion.request import Request as ChatCompletionRequest
from aidial_sdk.chat_completion.response import (
Response as ChatCompletionResponse,
)
from aidial_sdk.deployment.from_request_mixin import FromRequestMixin
from aidial_sdk.deployment.rate import RateRequest
from aidial_sdk.deployment.tokenize import TokenizeRequest
from aidial_sdk.deployment.truncate_prompt import TruncatePromptRequest
from aidial_sdk.exceptions import HTTPException as DIALException
from aidial_sdk.header_propagator import HeaderPropagator
from aidial_sdk.pydantic_v1 import ValidationError
from aidial_sdk.telemetry.types import TelemetryConfig
Expand All @@ -21,6 +24,8 @@

logging.config.dictConfig(LogConfig().dict())

RequestType = TypeVar("RequestType", bound=FromRequestMixin)


class DIALApp(FastAPI):
chat_completion_impls: Dict[str, ChatCompletion] = {}
Expand Down Expand Up @@ -61,7 +66,29 @@ def __init__(
methods=["POST"],
)

self.add_exception_handler(HTTPException, DIALApp._exception_handler)
self.add_api_route(
"/openai/deployments/{deployment_id}/tokenize",
self._endpoint_factory("tokenize", TokenizeRequest),
methods=["POST"],
)

self.add_api_route(
"/openai/deployments/{deployment_id}/truncate_prompt",
self._endpoint_factory("truncate_prompt", TruncatePromptRequest),
methods=["POST"],
)

self.add_exception_handler(
ValidationError, DIALApp._pydantic_validation_exception_handler
)

self.add_exception_handler(
HTTPException, DIALApp._fastapi_exception_handler
)

self.add_exception_handler(
DIALException, DIALApp._dial_exception_handler
)

def configure_telemetry(self, config: TelemetryConfig):
try:
Expand All @@ -79,124 +106,128 @@ def add_chat_completion(
) -> None:
self.chat_completion_impls[deployment_name] = impl

def _endpoint_factory(
    self, endpoint: str, request_type: Type["RequestType"]
):
    """Build a route handler that dispatches to a deployment method.

    Args:
        endpoint: Name of the attribute looked up on the deployment with
            ``getattr``; also used in the log and error messages.
        request_type: Request model class whose ``from_request`` coroutine
            is awaited with the incoming HTTP request.

    Returns:
        An async handler taking ``deployment_id`` and ``original_request``
        and returning a ``JSONResponse`` built from ``response.dict()``.
    """

    async def _handler(
        deployment_id: str, original_request: Request
    ) -> Response:
        set_log_deployment(deployment_id)
        deployment = self._get_deployment(deployment_id)

        # The request is parsed before the endpoint lookup, so a parsing
        # failure is raised even when the endpoint is missing.
        request = await request_type.from_request(original_request)

        endpoint_impl = getattr(deployment, endpoint, None)
        if not endpoint_impl:
            raise self._get_missing_endpoint_error(endpoint)

        try:
            response = await endpoint_impl(request)
        except NotImplementedError as exc:
            # An implementation that raises NotImplementedError is reported
            # the same way as a missing attribute. Chain explicitly so the
            # original error stays attached as the cause.
            raise self._get_missing_endpoint_error(endpoint) from exc

        response_json = response.dict()
        log_debug(f"response [{endpoint}]: {response_json}")
        return JSONResponse(content=response_json)

    return _handler

async def _rate_response(
self, deployment_id: str, original_request: Request
) -> Response:
set_log_deployment(deployment_id)
impl = self._get_deployment(deployment_id)
deployment = self._get_deployment(deployment_id)

if isinstance(impl, JSONResponse):
return impl
request = await RateRequest.from_request(original_request)

body = await DIALApp._get_json_body(original_request)
if isinstance(body, JSONResponse):
return body
log_debug(f"request: {body}")

try:
request = RateRequest(**body)
except ValidationError as e:
return DIALApp._get_validation_error_response(e)

await impl.rate_response(request)
await deployment.rate_response(request)
return Response(status_code=200)

async def _chat_completion(
self, deployment_id: str, original_request: Request
) -> Response:
set_log_deployment(deployment_id)
impl = self._get_deployment(deployment_id)
deployment = self._get_deployment(deployment_id)

if isinstance(impl, JSONResponse):
return impl

body = await DIALApp._get_json_body(original_request)
if isinstance(body, JSONResponse):
return body

headers = original_request.headers
try:
request = ChatCompletionRequest(
**body,
api_key=headers["Api-Key"],
jwt=headers.get("Authorization"),
deployment_id=deployment_id,
api_version=original_request.query_params.get("api-version"),
headers=headers,
)
except ValidationError as e:
return DIALApp._get_validation_error_response(e)

log_debug(f"request: {body}")
request = await ChatCompletionRequest.from_request(original_request)

response = ChatCompletionResponse(request)
first_chunk = await response._generator(impl.chat_completion, request)
first_chunk = await response._generator(
deployment.chat_completion, request
)

if request.stream:
return StreamingResponse(
response._generate_stream(first_chunk),
media_type="text/event-stream",
)
else:
response_body = await merge_chunks(
response_json = await merge_chunks(
response._generate_stream(first_chunk)
)

log_debug(f"response: {response_body}")
return JSONResponse(content=response_body)
log_debug(f"response: {response_json}")
return JSONResponse(content=response_json)

@staticmethod
async def _healthcheck() -> JSONResponse:
    """Return a JSON response with the body ``{"status": "ok"}``."""
    return JSONResponse(content={"status": "ok"})

def _get_deployment(
self, deployment_id: str
) -> Union[ChatCompletion, JSONResponse]:
def _get_deployment(self, deployment_id: str) -> ChatCompletion:
impl = self.chat_completion_impls.get(deployment_id, None)

if not impl:
return JSONResponse(
raise DIALException(
status_code=404,
content=json_error(
message="The API deployment for this resource does not exist.",
code="deployment_not_found",
),
code="deployment_not_found",
message="The API deployment for this resource does not exist.",
)

return impl

@staticmethod
async def _get_json_body(request: Request) -> Union[dict, JSONResponse]:
try:
return await request.json()
except JSONDecodeError as e:
return JSONResponse(
status_code=400,
content=json_error(
message=f"Your request contained invalid JSON: {str(e.msg)}",
type="invalid_request_error",
),
)
def _get_missing_endpoint_error(endpoint: str) -> DIALException:
    """Create the 404 ``endpoint_not_found`` error for *endpoint*.

    The exception is returned, not raised; the caller decides when to
    raise it.
    """
    error_fields = {
        "status_code": 404,
        "code": "endpoint_not_found",
        "message": f"The deployment doesn't implement '{endpoint}' endpoint.",
    }
    return DIALException(**error_fields)

@staticmethod
def _get_validation_error_response(
e: ValidationError,
def _pydantic_validation_exception_handler(
request: Request, exc: Exception
) -> JSONResponse:
error = e.errors()[0]
path = ".".join(map(str, e.errors()[0]["loc"]))
assert isinstance(exc, ValidationError)

error = exc.errors()[0]
path = ".".join(map(str, error["loc"]))
message = f"Your request contained invalid structure on path {path}. {error['msg']}"
return JSONResponse(
status_code=400,
content=json_error(
message=f"Your request contained invalid structure on path {path}. {error['msg']}",
type="invalid_request_error",
),
content=json_error(message=message, type="invalid_request_error"),
)

@staticmethod
async def _healthcheck() -> JSONResponse:
return JSONResponse(content={"status": "ok"})
def _fastapi_exception_handler(
    request: Request, exc: Exception
) -> JSONResponse:
    """Convert an ``HTTPException`` into a JSON response.

    The response reuses the exception's ``status_code`` and sends its
    ``detail`` as the body. ``request`` is not used.
    """
    # Narrows the type; the handler is registered for HTTPException.
    assert isinstance(exc, HTTPException)
    status, body = exc.status_code, exc.detail
    return JSONResponse(content=body, status_code=status)

@staticmethod
def _exception_handler(request: Request, exc: Exception):
if isinstance(exc, HTTPException):
return JSONResponse(
status_code=exc.status_code,
content=exc.detail,
)
else:
raise exc
def _dial_exception_handler(
    request: Request, exc: Exception
) -> JSONResponse:
    """Convert a ``DIALException`` into a JSON error response.

    The body is produced by ``json_error`` from the exception's
    ``message``, ``type``, ``param`` and ``code``; the HTTP status is the
    exception's ``status_code``. ``request`` is not used.
    """
    # Narrows the type; the handler is registered for DIALException.
    assert isinstance(exc, DIALException)
    status = exc.status_code
    payload = json_error(
        message=exc.message,
        type=exc.type,
        param=exc.param,
        code=exc.code,
    )
    return JSONResponse(status_code=status, content=payload)
12 changes: 12 additions & 0 deletions aidial_sdk/chat_completion/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,15 @@
)
from aidial_sdk.chat_completion.response import Response
from aidial_sdk.chat_completion.stage import Stage
from aidial_sdk.deployment.tokenize import (
TokenizeError,
TokenizeRequest,
TokenizeResponse,
TokenizeSuccess,
)
from aidial_sdk.deployment.truncate_prompt import (
TruncatePromptError,
TruncatePromptRequest,
TruncatePromptResponse,
TruncatePromptSuccess,
)
18 changes: 17 additions & 1 deletion aidial_sdk/chat_completion/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
from abc import ABC, abstractmethod

from aidial_sdk.chat_completion.request import RateRequest, Request
from aidial_sdk.chat_completion.request import Request
from aidial_sdk.chat_completion.response import Response
from aidial_sdk.deployment.rate import RateRequest
from aidial_sdk.deployment.tokenize import TokenizeRequest, TokenizeResponse
from aidial_sdk.deployment.truncate_prompt import (
TruncatePromptRequest,
TruncatePromptResponse,
)


class ChatCompletion(ABC):
Expand All @@ -13,3 +19,13 @@ async def chat_completion(

async def rate_response(self, request: RateRequest) -> None:
"""Implement rate response logic"""

async def tokenize(self, request: TokenizeRequest) -> TokenizeResponse:
    """Implement tokenize logic.

    Override in a subclass to support tokenization. This default
    implementation always raises ``NotImplementedError``.
    """
    raise NotImplementedError()
vladisavvv marked this conversation as resolved.
Show resolved Hide resolved

async def truncate_prompt(
    self, request: TruncatePromptRequest
) -> TruncatePromptResponse:
    """Implement truncate prompt logic.

    Override in a subclass to support prompt truncation. This default
    implementation always raises ``NotImplementedError``.
    """
    raise NotImplementedError()
vladisavvv marked this conversation as resolved.
Show resolved Hide resolved
6 changes: 3 additions & 3 deletions aidial_sdk/chat_completion/chunks.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, Optional
from typing import Any, Dict, List, Optional

from aidial_sdk.chat_completion.enums import FinishReason, Status
from aidial_sdk.pydantic_v1 import BaseModel, root_validator
Expand Down Expand Up @@ -457,9 +457,9 @@ def to_dict(self):


class DiscardedMessagesChunk(BaseChunk):
discarded_messages: int
discarded_messages: List[int]

def __init__(self, discarded_messages: int):
def __init__(self, discarded_messages: List[int]):
self.discarded_messages = discarded_messages

def to_dict(self):
Expand Down
27 changes: 10 additions & 17 deletions aidial_sdk/chat_completion/request.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,15 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Mapping, Optional, Union

from aidial_sdk.deployment.from_request_mixin import FromRequestDeploymentMixin
from aidial_sdk.pydantic_v1 import (
BaseModel,
ConstrainedFloat,
ConstrainedInt,
ConstrainedList,
Field,
PositiveInt,
StrictStr,
)


class ExtraForbidModel(BaseModel):
class Config:
extra = "forbid"
from aidial_sdk.utils.pydantic import ExtraForbidModel


class Attachment(ExtraForbidModel):
Expand Down Expand Up @@ -112,7 +107,7 @@ class ToolChoice(ExtraForbidModel):
function: FunctionChoice


class Request(ExtraForbidModel):
class AzureChatCompletionRequest(ExtraForbidModel):
model: Optional[StrictStr] = None
messages: List[Message]
functions: Optional[List[Function]] = None
Expand All @@ -128,19 +123,17 @@ class Request(ExtraForbidModel):
n: Optional[N] = None
stop: Optional[Union[StrictStr, Stop]] = None
max_tokens: Optional[PositiveInt] = None
max_prompt_tokens: Optional[PositiveInt] = None
presence_penalty: Optional[Penalty] = None
frequency_penalty: Optional[Penalty] = None
logit_bias: Optional[Mapping[int, float]] = None
user: Optional[StrictStr] = None

api_key: StrictStr
jwt: Optional[StrictStr] = None
deployment_id: StrictStr
api_version: Optional[StrictStr] = None
headers: Mapping[StrictStr, StrictStr]

class ChatCompletionRequest(AzureChatCompletionRequest):
model: Optional[StrictStr] = None
addons: Optional[List[Addon]] = None
max_prompt_tokens: Optional[PositiveInt] = None


class RateRequest(ExtraForbidModel):
response_id: StrictStr = Field(None, alias="responseId")
rate: bool = False
class Request(ChatCompletionRequest, FromRequestDeploymentMixin):
    # Adds no fields of its own: everything comes from the two bases.
    # NOTE(review): presumably FromRequestDeploymentMixin supplies the
    # `from_request` constructor used by the application — confirm.
    pass
Loading
Loading