feat: limit addons dialogue size (#31)

* Limit addons dialogue size.

Oleksii-Klimov authored Dec 5, 2023
1 parent 5f0bee0 commit fedb23f
Showing 16 changed files with 460 additions and 116 deletions.
35 changes: 35 additions & 0 deletions aidial_assistant/application/addons_dialogue_limiter.py
@@ -0,0 +1,35 @@
from typing_extensions import override

from aidial_assistant.chain.command_chain import (
LimitExceededException,
ModelRequestLimiter,
)
from aidial_assistant.model.model_client import Message, ModelClient


class AddonsDialogueLimiter(ModelRequestLimiter):
def __init__(self, max_dialogue_tokens: int, model_client: ModelClient):
self.max_dialogue_tokens = max_dialogue_tokens
self.model_client = model_client

self._dialogue_tokens = 0
self._initial_tokens: int | None = None

@override
async def verify_limit(self, messages: list[Message]):
if self._initial_tokens is None:
self._initial_tokens = await self.model_client.count_tokens(
messages
)
return

self._dialogue_tokens = (
await self.model_client.count_tokens(messages)
- self._initial_tokens
)

if self._dialogue_tokens > self.max_dialogue_tokens:
raise LimitExceededException(
f"Addons dialogue limit exceeded. Max tokens: {self.max_dialogue_tokens},"
f" actual tokens: {self._dialogue_tokens}."
)
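
For orientation, a minimal sketch of the limiter's contract: the first verify_limit call only records a baseline token count, and every later call measures the dialogue's growth against that baseline, raising LimitExceededException once it exceeds max_dialogue_tokens. The word-counting client below is a hypothetical stand-in for the real ModelClient:

import asyncio

from aidial_assistant.application.addons_dialogue_limiter import (
    AddonsDialogueLimiter,
)
from aidial_assistant.chain.command_chain import LimitExceededException
from aidial_assistant.model.model_client import Message


class _WordCountClient:
    # Hypothetical stand-in for ModelClient: counts whitespace-separated
    # words instead of real model tokens.
    async def count_tokens(self, messages: list[Message]) -> int:
        return sum(len(message.content.split()) for message in messages)


async def main() -> None:
    limiter = AddonsDialogueLimiter(
        max_dialogue_tokens=5, model_client=_WordCountClient()  # type: ignore
    )

    # First call: records the baseline (2 "tokens") and returns early.
    await limiter.verify_limit([Message.user("initial prompt")])

    # Second call: 8 - 2 = 6 dialogue tokens > 5, so the limiter raises.
    try:
        await limiter.verify_limit(
            [Message.user("initial prompt followed by a long addon exchange")]
        )
    except LimitExceededException as e:
        print(e)


asyncio.run(main())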
24 changes: 18 additions & 6 deletions aidial_assistant/application/assistant_application.py
@@ -7,6 +7,9 @@
from aidial_sdk.chat_completion.response import Response
from aiohttp import hdrs

from aidial_assistant.application.addons_dialogue_limiter import (
AddonsDialogueLimiter,
)
from aidial_assistant.application.args import parse_args
from aidial_assistant.application.assistant_callback import (
AssistantChainCallback,
@@ -17,12 +20,12 @@
)
from aidial_assistant.chain.command_chain import CommandChain, CommandDict
from aidial_assistant.chain.history import History
from aidial_assistant.chain.model_client import (
from aidial_assistant.commands.reply import Reply
from aidial_assistant.commands.run_plugin import PluginInfo, RunPlugin
from aidial_assistant.model.model_client import (
ModelClient,
ReasonLengthException,
)
from aidial_assistant.commands.reply import Reply
from aidial_assistant.commands.run_plugin import PluginInfo, RunPlugin
from aidial_assistant.utils.exceptions import (
RequestParameterValidationError,
unhandled_exception_handler,
@@ -124,8 +127,12 @@ async def chat_completion(
or info.ai_plugin.description_for_human
)

# TODO: Add max_addons_dialogue_tokens as a request parameter
max_addons_dialogue_tokens = 1000
command_dict: CommandDict = {
RunPlugin.token(): lambda: RunPlugin(model, tools),
RunPlugin.token(): lambda: RunPlugin(
model, tools, max_addons_dialogue_tokens
),
Reply.token(): Reply,
}
chain = CommandChain(
@@ -154,7 +161,10 @@
callback = AssistantChainCallback(choice)
finish_reason = FinishReason.STOP
try:
await chain.run_chat(history, callback)
model_request_limiter = AddonsDialogueLimiter(
max_addons_dialogue_tokens, model
)
await chain.run_chat(history, callback, model_request_limiter)
except ReasonLengthException:
finish_reason = FinishReason.LENGTH

@@ -163,7 +173,9 @@

choice.close(finish_reason)

response.set_usage(model.prompt_tokens, model.completion_tokens)
response.set_usage(
model.total_prompt_tokens, model.total_completion_tokens
)

if discarded_messages is not None:
response.set_discarded_messages(discarded_messages)
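
Two details are worth noting in the wiring above: the hard-coded max_addons_dialogue_tokens = 1000 now feeds both RunPlugin and the new AddonsDialogueLimiter, and usage reporting switches from per-request prompt_tokens/completion_tokens to total_prompt_tokens/total_completion_tokens. The rename suggests the client aggregates usage across every model call made during the turn (the main chain plus any addon dialogues); the sketch below shows that accumulation pattern under that assumption, not the actual ModelClient internals:

class UsageTotals:
    # Hypothetical sketch of the accumulate-across-requests pattern implied
    # by total_prompt_tokens/total_completion_tokens; the real ModelClient
    # internals are not part of this diff.
    def __init__(self) -> None:
        self.total_prompt_tokens = 0
        self.total_completion_tokens = 0

    def add_request_usage(
        self, prompt_tokens: int, completion_tokens: int
    ) -> None:
        # Every model call in the turn adds to the totals, so
        # response.set_usage reports the cost of the whole turn.
        self.total_prompt_tokens += prompt_tokens
        self.total_completion_tokens += completion_tokens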
78 changes: 60 additions & 18 deletions aidial_assistant/chain/command_chain.py
@@ -1,5 +1,6 @@
import json
import logging
from abc import ABC, abstractmethod
from typing import Any, AsyncIterator, Callable, Tuple, cast

from aidial_sdk.chat_completion.request import Role
@@ -16,9 +17,8 @@
commands_to_text,
responses_to_text,
)
from aidial_assistant.chain.dialogue import Dialogue
from aidial_assistant.chain.dialogue import Dialogue, DialogueTurn
from aidial_assistant.chain.history import History
from aidial_assistant.chain.model_client import Message, ModelClient
from aidial_assistant.chain.model_response_reader import (
AssistantProtocolException,
CommandsReader,
@@ -30,6 +30,7 @@
from aidial_assistant.json_stream.json_node import JsonNode
from aidial_assistant.json_stream.json_parser import JsonParser
from aidial_assistant.json_stream.json_string import JsonString
from aidial_assistant.model.model_client import Message, ModelClient
from aidial_assistant.utils.stream import CumulativeStream

logger = logging.getLogger(__name__)
@@ -44,17 +45,33 @@
CommandDict = dict[str, CommandConstructor]


class LimitExceededException(Exception):
pass


class ModelRequestLimiter(ABC):
@abstractmethod
async def verify_limit(self, messages: list[Message]):
pass


class CommandChain:
def __init__(
self,
name: str,
model_client: ModelClient,
command_dict: CommandDict,
max_completion_tokens: int | None = None,
max_retry_count: int = DEFAULT_MAX_RETRY_COUNT,
):
self.name = name
self.model_client = model_client
self.command_dict = command_dict
self.model_extra_args = (
{}
if max_completion_tokens is None
else {"max_tokens": max_completion_tokens}
)
self.max_retry_count = max_retry_count

def _log_message(self, role: Role, content: str):
@@ -65,19 +82,26 @@ def _log_messages(self, messages: list[Message]):
for message in messages:
self._log_message(message.role, message.content)

async def run_chat(self, history: History, callback: ChainCallback):
async def run_chat(
self,
history: History,
callback: ChainCallback,
model_request_limiter: ModelRequestLimiter | None = None,
):
dialogue = Dialogue()
try:
messages = history.to_protocol_messages()
while True:
pair = await self._run_with_protocol_failure_retries(
callback, messages + dialogue.messages
dialogue_turn = await self._run_with_protocol_failure_retries(
callback,
messages + dialogue.messages,
model_request_limiter,
)

if pair is None:
if dialogue_turn is None:
break

dialogue.append(pair[0], pair[1])
dialogue.append(dialogue_turn)
except (JsonParsingException, AssistantProtocolException):
messages = (
history.to_best_effort_messages(
@@ -88,27 +112,39 @@ async def run_chat(self, history: History, callback: ChainCallback):
else history.to_user_messages()
)
await self._generate_result(messages, callback)
except InvalidRequestError as e:
if dialogue.is_empty() or e.code == "429":
except (InvalidRequestError, LimitExceededException) as e:
if dialogue.is_empty() or (
isinstance(e, InvalidRequestError) and e.code == "429"
):
raise

# Assuming the context length is exceeded
dialogue.pop()
# TODO: Limit the error message size. The error message should not exceed reserved assistant overheads.
await self._generate_result(
history.to_best_effort_messages(str(e), dialogue), callback
)

async def _run_with_protocol_failure_retries(
self, callback: ChainCallback, messages: list[Message]
) -> Tuple[str, str] | None:
self,
callback: ChainCallback,
messages: list[Message],
model_request_limiter: ModelRequestLimiter | None = None,
) -> DialogueTurn | None:
last_error: Exception | None = None
try:
self._log_messages(messages)
retries = Dialogue()
while True:
all_messages = self._reinforce_json_format(
messages + retries.messages
)
if model_request_limiter:
await model_request_limiter.verify_limit(all_messages)

chunk_stream = CumulativeStream(
self.model_client.agenerate(
self._reinforce_json_format(messages + retries.messages)
all_messages, **self.model_extra_args # type: ignore
)
)
try:
@@ -121,13 +157,16 @@ async def _run_with_protocol_failure_retries(
response_text = responses_to_text(responses)

callback.on_state(request_text, response_text)
return request_text, response_text
return DialogueTurn(
assistant_message=request_text,
user_message=response_text,
)

return None
break
except (JsonParsingException, AssistantProtocolException) as e:
logger.exception("Failed to process model response")

retry_count = len(retries.messages) // 2
retry_count = retries.dialogue_turn_count()
callback.on_error(
"Error"
if retry_count == 0
Expand All @@ -140,12 +179,15 @@ async def _run_with_protocol_failure_retries(

last_error = e
retries.append(
chunk_stream.buffer,
"Failed to parse JSON commands: " + str(e),
DialogueTurn(
assistant_message=chunk_stream.buffer,
user_message="Failed to parse JSON commands: "
+ str(e),
)
)
finally:
self._log_message(Role.ASSISTANT, chunk_stream.buffer)
except InvalidRequestError as e:
except (InvalidRequestError, LimitExceededException) as e:
if last_error:
# Retries can increase the prompt size, which may lead to token overflow.
# Thus, if the original error was a protocol error, it should be thrown instead.
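
The new ModelRequestLimiter abstraction keeps CommandChain agnostic of the limiting policy: run_chat calls verify_limit on the outgoing messages before each model request, and a LimitExceededException is handled like a context-length InvalidRequestError: the last dialogue turn is dropped and the chain falls back to a best-effort reply. To illustrate the contract, here is a hypothetical limiter that caps the number of model requests instead of the token delta:

from typing_extensions import override

from aidial_assistant.chain.command_chain import (
    LimitExceededException,
    ModelRequestLimiter,
)
from aidial_assistant.model.model_client import Message


class RequestCountLimiter(ModelRequestLimiter):
    # Hypothetical alternative policy: cap how many model requests a single
    # chain run may issue, regardless of token usage.
    def __init__(self, max_requests: int):
        self.max_requests = max_requests
        self._requests = 0

    @override
    async def verify_limit(self, messages: list[Message]):
        self._requests += 1
        if self._requests > self.max_requests:
            raise LimitExceededException(
                f"Model request limit exceeded. Max requests: {self.max_requests}."
            )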
18 changes: 14 additions & 4 deletions aidial_assistant/chain/dialogue.py
@@ -1,17 +1,27 @@
from aidial_assistant.chain.model_client import Message
from pydantic import BaseModel

from aidial_assistant.model.model_client import Message


class DialogueTurn(BaseModel):
assistant_message: str
user_message: str


class Dialogue:
def __init__(self):
self.messages: list[Message] = []

def append(self, assistant_message: str, user_message: str):
self.messages.append(Message.assistant(assistant_message))
self.messages.append(Message.user(user_message))
def append(self, dialogue_turn: DialogueTurn):
self.messages.append(Message.assistant(dialogue_turn.assistant_message))
self.messages.append(Message.user(dialogue_turn.user_message))

def pop(self):
self.messages.pop()
self.messages.pop()

def is_empty(self):
return not self.messages

def dialogue_turn_count(self):
return len(self.messages) // 2
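
Replacing the two positional strings in append with a DialogueTurn model makes turns explicit and lets the retry loop report dialogue_turn_count() directly instead of computing len(messages) // 2 inline. A quick sketch of the new API (the message contents are illustrative):

from aidial_assistant.chain.dialogue import Dialogue, DialogueTurn

dialogue = Dialogue()
dialogue.append(
    DialogueTurn(
        assistant_message='{"commands": [{"command": "reply"}]}',
        user_message='{"responses": [{"status": "SUCCESS"}]}',
    )
)
assert dialogue.dialogue_turn_count() == 1

dialogue.pop()  # removes the whole turn: one assistant + one user message
assert dialogue.is_empty()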
52 changes: 17 additions & 35 deletions aidial_assistant/chain/history.py
@@ -3,20 +3,15 @@
from aidial_sdk.chat_completion import Role
from jinja2 import Template
from pydantic import BaseModel
from typing_extensions import override

from aidial_assistant.application.prompts import ENFORCE_JSON_FORMAT_TEMPLATE
from aidial_assistant.chain.command_result import (
CommandInvocation,
commands_to_text,
)
from aidial_assistant.chain.dialogue import Dialogue
from aidial_assistant.chain.model_client import (
ExtraResultsCallback,
Message,
ModelClient,
ReasonLengthException,
)
from aidial_assistant.commands.reply import Reply
from aidial_assistant.model.model_client import Message, ModelClient


class ContextLengthExceeded(Exception):
@@ -33,17 +28,16 @@ class ScopedMessage(BaseModel):
message: Message


class ModelExtraResultsCallback(ExtraResultsCallback):
def __init__(self):
self._discarded_messages: int | None = None

@override
def on_discarded_messages(self, discarded_messages: int):
self._discarded_messages = discarded_messages

@property
def discarded_messages(self) -> int | None:
return self._discarded_messages
def enforce_json_format(messages: list[Message]) -> list[Message]:
last_message = messages[-1]
return messages[:-1] + [
Message(
role=last_message.role,
content=ENFORCE_JSON_FORMAT_TEMPLATE.render(
response=last_message.content
),
),
]


class History:
@@ -128,28 +122,16 @@ def to_best_effort_messages(
async def truncate(
self, max_prompt_tokens: int, model_client: ModelClient
) -> "History":
extra_results_callback = ModelExtraResultsCallback()
# TODO: This will be replaced with a dedicated truncation call on model client once implemented.
stream = model_client.agenerate(
discarded_messages = await model_client.get_discarded_messages(
self.to_protocol_messages(),
extra_results_callback,
max_prompt_tokens=max_prompt_tokens,
max_tokens=1,
max_prompt_tokens,
)
try:
async for _ in stream:
pass
except ReasonLengthException:
# Expected for max_tokens=1
pass

if extra_results_callback.discarded_messages:

if discarded_messages > 0:
return History(
assistant_system_message_template=self.assistant_system_message_template,
best_effort_template=self.best_effort_template,
scoped_messages=self._skip_messages(
extra_results_callback.discarded_messages
),
scoped_messages=self._skip_messages(discarded_messages),
)

return self
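
Two changes land in history.py: a module-level enforce_json_format helper is added, and truncate drops the ModelExtraResultsCallback workaround in favor of asking the client directly via get_discarded_messages(messages, max_prompt_tokens) instead of probing with a max_tokens=1 generation. A small sketch of the helper (assuming the package is importable; the message content is illustrative):

from aidial_assistant.chain.history import enforce_json_format
from aidial_assistant.model.model_client import Message

messages = [Message.user("What is the weather today?")]
# Only the last message is rewrapped: its content is re-rendered through
# ENFORCE_JSON_FORMAT_TEMPLATE, reminding the model to answer in the
# assistant's JSON protocol.
wrapped = enforce_json_format(messages)
print(wrapped[-1].content)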