feat: implement history truncation (#27)
* Implement history truncation.
Oleksii-Klimov authored Dec 5, 2023
1 parent 248bee4 commit 5f0bee0
Showing 13 changed files with 701 additions and 234 deletions.
85 changes: 52 additions & 33 deletions aidial_assistant/application/assistant_application.py
@@ -1,13 +1,11 @@
 import logging
 from pathlib import Path
 
-from aidial_sdk import HTTPException
 from aidial_sdk.chat_completion import FinishReason
 from aidial_sdk.chat_completion.base import ChatCompletion
-from aidial_sdk.chat_completion.request import Addon, Request
+from aidial_sdk.chat_completion.request import Addon, Message, Request, Role
 from aidial_sdk.chat_completion.response import Response
 from aiohttp import hdrs
-from openai import InvalidRequestError, OpenAIError
 
 from aidial_assistant.application.args import parse_args
 from aidial_assistant.application.assistant_callback import (
@@ -22,21 +20,24 @@
 from aidial_assistant.chain.model_client import (
     ModelClient,
     ReasonLengthException,
-    UsagePublisher,
 )
 from aidial_assistant.commands.reply import Reply
 from aidial_assistant.commands.run_plugin import PluginInfo, RunPlugin
+from aidial_assistant.utils.exceptions import (
+    RequestParameterValidationError,
+    unhandled_exception_handler,
+)
 from aidial_assistant.utils.open_ai_plugin import (
     AddonTokenSource,
     get_open_ai_plugin_info,
     get_plugin_auth,
 )
-from aidial_assistant.utils.state import parse_history
+from aidial_assistant.utils.state import State, parse_history
 
 logger = logging.getLogger(__name__)
 
 
-def get_request_args(request: Request) -> dict[str, str]:
+def _get_request_args(request: Request) -> dict[str, str]:
     args = {
         "model": request.model,
         "temperature": request.temperature,
@@ -51,21 +52,43 @@ def get_request_args(request: Request) -> dict[str, str]:
     return {k: v for k, v in args.items() if v is not None}
 
 
-def _extract_addon_url(addon: Addon) -> str:
-    if addon.url is None:
-        raise InvalidRequestError("Missing required addon url.", param="")
-
-    return addon.url
+def _validate_addons(addons: list[Addon] | None):
+    if addons and any(addon.url is None for addon in addons):
+        for index, addon in enumerate(addons):
+            if addon.url is None:
+                raise RequestParameterValidationError(
+                    f"Missing required addon url at index {index}.",
+                    param="addons",
+                )
+
+
+def _validate_messages(messages: list[Message]) -> None:
+    if not messages:
+        raise RequestParameterValidationError(
+            "Message list cannot be empty.", param="messages"
+        )
+
+    if messages[-1].role != Role.USER:
+        raise RequestParameterValidationError(
+            "Last message must be from the user.", param="messages"
+        )
+
+
+def _validate_request(request: Request) -> None:
+    _validate_messages(request.messages)
+    _validate_addons(request.addons)
 
 
 class AssistantApplication(ChatCompletion):
     def __init__(self, config_dir: Path):
         self.args = parse_args(config_dir)
 
+    @unhandled_exception_handler
     async def chat_completion(
         self, request: Request, response: Response
     ) -> None:
-        chat_args = self.args.openai_conf.dict() | get_request_args(request)
+        _validate_request(request)
+        chat_args = self.args.openai_conf.dict() | _get_request_args(request)
 
         model = ModelClient(
             model_args=chat_args
@@ -77,10 +100,8 @@ async def chat_completion(
             buffer_size=self.args.chat_conf.buffer_size,
         )
 
-        addons = (
-            [_extract_addon_url(addon) for addon in request.addons]
-            if request.addons
-            else []
+        addons: list[str] = (
+            [addon.url for addon in request.addons] if request.addons else []  # type: ignore
         )
         token_source = AddonTokenSource(request.headers, addons)
 
@@ -103,16 +124,12 @@ async def chat_completion(
                 or info.ai_plugin.description_for_human
             )
 
-        usage_publisher = UsagePublisher()
         command_dict: CommandDict = {
-            RunPlugin.token(): lambda: RunPlugin(model, tools, usage_publisher),
+            RunPlugin.token(): lambda: RunPlugin(model, tools),
             Reply.token(): Reply,
         }
         chain = CommandChain(
-            model_client=model,
-            name="ASSISTANT",
-            command_dict=command_dict,
-            usage_publisher=usage_publisher,
+            model_client=model, name="ASSISTANT", command_dict=command_dict
         )
         history = History(
             assistant_system_message_template=MAIN_SYSTEM_DIALOG_MESSAGE.build(
@@ -123,6 +140,14 @@ async def chat_completion(
             ),
             scoped_messages=parse_history(request.messages),
         )
+        discarded_messages: int | None = None
+        if request.max_prompt_tokens is not None:
+            original_size = history.user_message_count
+            history = await history.truncate(request.max_prompt_tokens, model)
+            truncated_size = history.user_message_count
+            discarded_messages = original_size - truncated_size
+        # TODO: else compare the history size to the max prompt tokens of the underlying model
+
         choice = response.create_single_choice()
         choice.open()
 
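Note on the truncation block above: History.truncate(request.max_prompt_tokens, model) returns a shortened history, and discarded_messages is derived by comparing user_message_count before and after. The truncation algorithm itself lives in aidial_assistant/chain/history.py (among the 13 changed files, but not shown on this page). The sketch below only illustrates the contract; Msg, count_tokens, and truncate are hypothetical stand-ins, not the repo's API:

from dataclasses import dataclass


@dataclass
class Msg:
    role: str
    content: str


def count_tokens(messages: list[Msg]) -> int:
    # Stand-in for a real tokenizer; assume roughly 4 characters per token.
    return sum(len(m.content) // 4 + 1 for m in messages)


def truncate(messages: list[Msg], max_prompt_tokens: int) -> tuple[list[Msg], int]:
    # Drop the oldest non-system messages until the prompt fits the budget,
    # and report how many were discarded.
    kept = list(messages)
    discarded = 0
    while count_tokens(kept) > max_prompt_tokens and len(kept) > 1:
        drop_index = 1 if kept[0].role == "system" else 0
        del kept[drop_index]
        discarded += 1
    return kept, discarded

The discarded count is then surfaced to the client via response.set_discarded_messages, so callers can tell how much leading context was dropped to fit the token budget.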
@@ -132,19 +157,13 @@ async def chat_completion(
             await chain.run_chat(history, callback)
         except ReasonLengthException:
             finish_reason = FinishReason.LENGTH
-        except OpenAIError as e:
-            if e.error:
-                raise HTTPException(
-                    e.error.message,
-                    status_code=e.http_status or 500,
-                    code=e.error.code,
-                )
 
-            raise
-        choice.set_state(callback.get_state())
+        if callback.invocations:
+            choice.set_state(State(invocations=callback.invocations))
 
         choice.close(finish_reason)
 
-        response.set_usage(
-            usage_publisher.prompt_tokens, usage_publisher.completion_tokens
-        )
+        response.set_usage(model.prompt_tokens, model.completion_tokens)
+
+        if discarded_messages is not None:
+            response.set_discarded_messages(discarded_messages)
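Usage accounting moves in the same direction: the standalone UsagePublisher is removed, and ModelClient is now expected to accumulate prompt_tokens and completion_tokens itself, which chat_completion reads once at the end. Below is a minimal sketch of that pattern, under the assumption that usage arrives with each model call; SketchModelClient is a made-up name, and the real accounting lives in aidial_assistant/chain/model_client.py, which is not shown here:

import asyncio
from typing import AsyncIterator


class SketchModelClient:
    # Accumulates usage across calls instead of publishing it to a helper.
    def __init__(self) -> None:
        self.prompt_tokens = 0
        self.completion_tokens = 0

    async def agenerate(self, messages: list[str]) -> AsyncIterator[str]:
        # A real client would stream chunks from the model and add the usage
        # reported by the API; word counts stand in for tokens here.
        self.prompt_tokens += sum(len(m.split()) for m in messages)
        reply = "a short reply"
        self.completion_tokens += len(reply.split())
        yield reply


async def main() -> None:
    client = SketchModelClient()
    async for _ in client.agenerate(["hello world"]):
        pass
    print(client.prompt_tokens, client.completion_tokens)  # 2 3


asyncio.run(main())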
8 changes: 5 additions & 3 deletions aidial_assistant/application/assistant_callback.py
@@ -13,7 +13,7 @@
 from aidial_assistant.chain.callbacks.result_callback import ResultCallback
 from aidial_assistant.commands.base import ExecutionCallback, ResultObject
 from aidial_assistant.commands.run_plugin import RunPlugin
-from aidial_assistant.utils.state import Invocation, State
+from aidial_assistant.utils.state import Invocation
 
 
 class PluginNameArgCallback(ArgCallback):
@@ -113,6 +113,7 @@ def __init__(self, choice: Choice):
         self.choice = choice
         self._invocations: list[Invocation] = []
         self._invocation_index: int = -1
+        self._discarded_messages: int = 0
 
     @override
     def command_callback(self) -> CommandCallback:
@@ -138,5 +139,6 @@ def on_error(self, title: str, error: str):
         stage.append_content(f"Error: {error}\n")
         stage.close(Status.FAILED)
 
-    def get_state(self):
-        return State(invocations=self._invocations)
+    @property
+    def invocations(self) -> list[Invocation]:
+        return self._invocations
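The callback no longer assembles the State itself: it exposes the collected invocations through a read-only property, and the application (see assistant_application.py above) builds State(invocations=...) and sets it only when at least one addon was actually invoked. A self-contained analogue of the pattern; InvocationCollector is a hypothetical name, not the repo's class:

class InvocationCollector:
    def __init__(self) -> None:
        self._invocations: list[dict] = []

    def on_invocation(self, invocation: dict) -> None:
        self._invocations.append(invocation)

    @property
    def invocations(self) -> list[dict]:
        # Read-only view; the caller decides whether to persist it.
        return self._invocations


collector = InvocationCollector()
collector.on_invocation({"index": 0, "request": "{}", "response": "{}"})
if collector.invocations:  # skip setting state when nothing was invoked
    state = {"invocations": collector.invocations}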
34 changes: 13 additions & 21 deletions aidial_assistant/chain/command_chain.py
@@ -18,11 +18,7 @@
 )
 from aidial_assistant.chain.dialogue import Dialogue
 from aidial_assistant.chain.history import History
-from aidial_assistant.chain.model_client import (
-    Message,
-    ModelClient,
-    UsagePublisher,
-)
+from aidial_assistant.chain.model_client import Message, ModelClient
 from aidial_assistant.chain.model_response_reader import (
     AssistantProtocolException,
     CommandsReader,
@@ -48,23 +44,17 @@
 CommandDict = dict[str, CommandConstructor]
 
 
-class MaxRetryCountExceededException(Exception):
-    pass
-
-
 class CommandChain:
     def __init__(
         self,
         name: str,
         model_client: ModelClient,
         command_dict: CommandDict,
-        usage_publisher: UsagePublisher,
         max_retry_count: int = DEFAULT_MAX_RETRY_COUNT,
     ):
         self.name = name
         self.model_client = model_client
         self.command_dict = command_dict
-        self.usage_publisher = usage_publisher
         self.max_retry_count = max_retry_count
 
     def _log_message(self, role: Role, content: str):
@@ -81,8 +71,7 @@ async def run_chat(self, history: History, callback: ChainCallback):
             messages = history.to_protocol_messages()
             while True:
                 pair = await self._run_with_protocol_failure_retries(
-                    callback,
-                    self._reinforce_json_format(messages + dialogue.messages),
+                    callback, messages + dialogue.messages
                 )
 
                 if pair is None:
@@ -96,7 +85,7 @@ async def run_chat(self, history: History, callback: ChainCallback):
                     dialogue,
                 )
                 if not dialogue.is_empty()
-                else history.to_client_messages()
+                else history.to_user_messages()
             )
             await self._generate_result(messages, callback)
         except InvalidRequestError as e:
@@ -112,15 +101,14 @@
     async def _run_with_protocol_failure_retries(
         self, callback: ChainCallback, messages: list[Message]
     ) -> Tuple[str, str] | None:
+        self._log_messages(messages)
         last_error: Exception | None = None
         try:
-            retry: int = 0
-            self._log_messages(messages)
             retries = Dialogue()
             while True:
                 chunk_stream = CumulativeStream(
                     self.model_client.agenerate(
-                        messages + retries.messages, self.usage_publisher
+                        self._reinforce_json_format(messages + retries.messages)
                     )
                 )
                 try:
@@ -139,15 +127,17 @@ async def _run_with_protocol_failure_retries(
             except (JsonParsingException, AssistantProtocolException) as e:
                 logger.exception("Failed to process model response")
 
+                retry_count = len(retries.messages) // 2
                 callback.on_error(
-                    "Error" if retry == 0 else f"Error (retry {retry})",
+                    "Error"
+                    if retry_count == 0
+                    else f"Error (retry {retry_count})",
                     "The model failed to construct addon request.",
                 )
 
-                if retry >= self.max_retry_count:
+                if retry_count >= self.max_retry_count:
                     raise
 
-                retry += 1
                 last_error = e
                 retries.append(
                     chunk_stream.buffer,
@@ -157,6 +147,8 @@ async def _run_with_protocol_failure_retries(
                 self._log_message(Role.ASSISTANT, chunk_stream.buffer)
         except InvalidRequestError as e:
             if last_error:
+                # Retries can increase the prompt size, which may lead to token overflow.
+                # Thus, if the original error was a protocol error, it should be thrown instead.
                 raise last_error
 
             callback.on_error("Error", str(e))
@@ -211,7 +203,7 @@ def _create_command(self, name: str) -> Command:
     async def _generate_result(
         self, messages: list[Message], callback: ChainCallback
     ):
-        stream = self.model_client.agenerate(messages, self.usage_publisher)
+        stream = self.model_client.agenerate(messages)
 
         await CommandChain._to_result(stream, callback.result_callback())
 
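Worth noting in the retry hunks above: the explicit retry counter and the unused MaxRetryCountExceededException are gone, and the retry count is now derived as len(retries.messages) // 2. That works because every failed round appends an assistant/user message pair to the retries dialogue. A toy illustration of the invariant; this Dialogue is a stand-in, not aidial_assistant.chain.dialogue.Dialogue:

class Dialogue:
    # Stand-in: records (role, content) pairs the way a retry dialogue would.
    def __init__(self) -> None:
        self.messages: list[tuple[str, str]] = []

    def append(self, assistant_content: str, user_content: str) -> None:
        self.messages.append(("assistant", assistant_content))
        self.messages.append(("user", user_content))


retries = Dialogue()
assert len(retries.messages) // 2 == 0  # first failure reports plain "Error"
retries.append("<malformed reply>", "Fix your answer to follow the protocol.")
assert len(retries.messages) // 2 == 1  # next failure reports "Error (retry 1)"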