feat: limit size of model dialogue with addons (#58)
* Limit size of model dialogue with addons.
Oleksii-Klimov authored Jan 18, 2024
1 parent c725a33 commit b4bbbd0
Showing 6 changed files with 242 additions and 17 deletions.
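
In short: the commit introduces a max_addons_dialogue_tokens budget (hard-coded to 1000 for now, with a TODO to make it a request parameter), threads it through RunTool into ToolsChain as a max_tokens cap on completions, and adds a ModelRequestLimiter hook to run_chat so that, once the addon dialogue outgrows the budget, the last tool exchange is dropped and the model is asked for a best-effort, tool-free answer.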
12 changes: 10 additions & 2 deletions aidial_assistant/application/assistant_application.py
@@ -254,10 +254,15 @@ async def _run_native_tools_chat(
         request: Request,
         response: Response,
     ):
+        # TODO: Add max_addons_dialogue_tokens as a request parameter
+        max_addons_dialogue_tokens = 1000
+
         def create_command_tool(
             plugin: PluginInfo,
         ) -> Tuple[CommandConstructor, ChatCompletionToolParam]:
-            return lambda: RunTool(model, plugin), _construct_tool(
+            return lambda: RunTool(
+                model, plugin, max_addons_dialogue_tokens
+            ), _construct_tool(
                 plugin.info.ai_plugin.name_for_model,
                 plugin.info.ai_plugin.description_for_human,
             )
@@ -275,7 +280,10 @@ def create_command_tool(
         finish_reason = FinishReason.STOP
         messages = convert_commands_to_tools(parse_history(request.messages))
         try:
-            await chain.run_chat(messages, callback)
+            model_request_limiter = AddonsDialogueLimiter(
+                max_addons_dialogue_tokens, model
+            )
+            await chain.run_chat(messages, callback, model_request_limiter)
        except ReasonLengthException:
            finish_reason = FinishReason.LENGTH
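
AddonsDialogueLimiter itself is not visible in the hunks shown here. A minimal sketch of what a ModelRequestLimiter implementation along these lines could look like — the count_tokens helper is a hypothetical stand-in for whatever token-counting API ModelClient actually exposes:

from openai.types.chat import ChatCompletionMessageParam
from typing_extensions import override

from aidial_assistant.chain.command_chain import (
    LimitExceededException,
    ModelRequestLimiter,
)
from aidial_assistant.model.model_client import ModelClient


class AddonsDialogueLimiter(ModelRequestLimiter):
    def __init__(self, max_dialogue_tokens: int, model: ModelClient):
        self.max_dialogue_tokens = max_dialogue_tokens
        self.model = model
        self._initial_size: int | None = None

    @override
    async def verify_limit(self, messages: list[ChatCompletionMessageParam]):
        # count_tokens is a hypothetical helper; any way of measuring
        # the prompt size would work here.
        size = await self.model.count_tokens(messages)
        if self._initial_size is None:
            # First call: remember the size of the original request so that
            # only the addon dialogue appended on top of it counts later.
            self._initial_size = size
            return
        if size - self._initial_size > self.max_dialogue_tokens:
            raise LimitExceededException(
                f"Addons dialogue size exceeded the limit of "
                f"{self.max_dialogue_tokens} tokens."
            )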
7 changes: 5 additions & 2 deletions aidial_assistant/commands/run_tool.py
@@ -64,9 +64,12 @@ def _construct_tool(op: APIOperation) -> ChatCompletionToolParam:


 class RunTool(Command):
-    def __init__(self, model: ModelClient, plugin: PluginInfo):
+    def __init__(
+        self, model: ModelClient, plugin: PluginInfo, max_completion_tokens: int
+    ):
         self.model = model
         self.plugin = plugin
+        self.max_completion_tokens = max_completion_tokens

     @staticmethod
     def token():
@@ -91,7 +94,7 @@ def create_command_tool(op: APIOperation) -> CommandTool:
             name: create_command_tool(op) for name, op in ops.items()
         }

-        chain = ToolsChain(self.model, commands)
+        chain = ToolsChain(self.model, commands, self.max_completion_tokens)

         messages = [
             system_message(self.plugin.info.ai_plugin.description_for_model),
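
Note that RunTool forwards the same budget into the nested ToolsChain it builds for the addon, so the dialogue an addon holds with its own tools is capped by the same max_addons_dialogue_tokens value passed down from the application layer.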
52 changes: 40 additions & 12 deletions aidial_assistant/tools_chain/tools_chain.py
@@ -11,7 +11,11 @@

 from aidial_assistant.chain.callbacks.chain_callback import ChainCallback
 from aidial_assistant.chain.callbacks.command_callback import CommandCallback
-from aidial_assistant.chain.command_chain import CommandConstructor
+from aidial_assistant.chain.command_chain import (
+    CommandConstructor,
+    LimitExceededException,
+    ModelRequestLimiter,
+)
 from aidial_assistant.chain.command_result import (
     CommandInvocation,
     CommandResult,
@@ -119,52 +123,76 @@ def on_tool_calls(


 class ToolsChain:
-    def __init__(self, model: ModelClient, commands: CommandToolDict):
+    def __init__(
+        self,
+        model: ModelClient,
+        commands: CommandToolDict,
+        max_completion_tokens: int | None = None,
+    ):
         self.model = model
         self.commands = commands
+        self.model_extra_args = (
+            {}
+            if max_completion_tokens is None
+            else {"max_tokens": max_completion_tokens}
+        )

     async def run_chat(
         self,
         messages: list[ChatCompletionMessageParam],
         callback: ChainCallback,
+        model_request_limiter: ModelRequestLimiter | None = None,
     ):
         result_callback = callback.result_callback()
-        dialogue: list[ChatCompletionMessageParam] = []
+        last_message_block_length = 0
         tools = [tool for _, tool in self.commands.values()]
+        all_messages = messages.copy()
         while True:
             tool_calls_callback = ToolCallsCallback()
             try:
+                if model_request_limiter:
+                    await model_request_limiter.verify_limit(all_messages)
+
                 async for chunk in self.model.agenerate(
-                    messages + dialogue, tool_calls_callback, tools=tools
+                    all_messages,
+                    tool_calls_callback,
+                    tools=tools,
+                    **self.model_extra_args,
                 ):
                     result_callback.on_result(chunk)
-            except BadRequestError as e:
-                if len(dialogue) == 0 or e.code == "429":
+            except (BadRequestError, LimitExceededException) as e:
+                if (
+                    last_message_block_length == 0
+                    or isinstance(e, BadRequestError)
+                    and e.code == "429"
+                ):
                     raise

                 # If the dialog size exceeds model context size then remove last message block
                 # and try again without tools.
-                dialogue = dialogue[:-last_message_block_length]
+                all_messages = all_messages[:-last_message_block_length]
                 async for chunk in self.model.agenerate(
-                    messages + dialogue, tool_calls_callback
+                    all_messages, tool_calls_callback
                 ):
                     result_callback.on_result(chunk)
                 break

             if not tool_calls_callback.tool_calls:
                 break

-            dialogue.append(
+            previous_message_count = len(all_messages)
+            all_messages.append(
                 tool_calls_message(
                     tool_calls_callback.tool_calls,
                 )
             )
-            result_messages = await self._run_tools(
+            all_messages += await self._run_tools(
                 tool_calls_callback.tool_calls, callback
             )
-            dialogue.extend(result_messages)
-            last_message_block_length = len(result_messages) + 1
+
+            last_message_block_length = (
+                len(all_messages) - previous_message_count
+            )

     def _create_command(self, name: str) -> Command:
         if name not in self.commands:
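
The net effect in run_chat: the limit is verified before every model round-trip, and a LimitExceededException (or a non-429 BadRequestError) raised after at least one tool exchange triggers the fallback — the last tool-call/result block is trimmed off and the model is queried once more without tools for a best-effort answer. A hedged usage sketch, assuming a limiter implementation such as the AddonsDialogueLimiter sketched above and model, commands, messages, callback already in scope:

chain = ToolsChain(model, commands, max_completion_tokens=1000)
limiter = AddonsDialogueLimiter(1000, model)  # any ModelRequestLimiter will do

await chain.run_chat(messages, callback, limiter)

Since model_extra_args stays empty when max_completion_tokens is None, existing ToolsChain(model, commands) callers keep their behavior unchanged.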
2 changes: 1 addition & 1 deletion aidial_assistant/utils/exceptions.py
@@ -29,7 +29,7 @@ def _to_http_exception(e: Exception) -> HTTPException:
     if isinstance(e, APIError):
         raise HTTPException(
             message=e.message,
-            status_code=getattr(e, "status_code") or 500,
+            status_code=getattr(e, "status_code", None) or 500,
             type=e.type or "runtime_error",
             code=e.code,
             param=e.param,
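
The one-line change above is a defensive fix: the two-argument form of getattr raises AttributeError when the attribute is missing (not every APIError subclass carries a status_code), while the three-argument form returns the default instead. A quick illustration:

class FakeAPIError:
    """Stand-in for an APIError subclass that lacks status_code."""


e = FakeAPIError()

# Old form: getattr(e, "status_code") raises AttributeError here.
# New form: returns None, and `or 500` supplies the fallback.
assert (getattr(e, "status_code", None) or 500) == 500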
68 changes: 68 additions & 0 deletions tests/unit_tests/tools_chain/test_tools_chain_best_effort.py
@@ -0,0 +1,68 @@
import json

import pytest
from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionMessageToolCallParam,
)
from openai.types.chat.chat_completion_message_tool_call_param import Function

from aidial_assistant.tools_chain.tools_chain import ToolsChain
from aidial_assistant.utils.open_ai import (
construct_tool,
tool_calls_message,
tool_message,
user_message,
)
from tests.utils.mocks import (
TestChainCallback,
TestCommand,
TestModelClient,
TestModelRequestLimiter,
)

TEST_COMMAND_NAME = "<test command>"
TOOL_ID = "<tool id>"
TOOL_RESPONSE = "<tool response>"
BEST_EFFORT_RESPONSE = "<best effort response>"


@pytest.mark.asyncio
async def test_model_request_limit_exceeded():
messages: list[ChatCompletionMessageParam] = [user_message("<query>")]
command_args = {"<test argument>": "<test value>"}
tool_calls = [
ChatCompletionMessageToolCallParam(
id=TOOL_ID,
function=Function(
name=TEST_COMMAND_NAME,
arguments=json.dumps(command_args),
),
type="function",
)
]
tool = construct_tool(TEST_COMMAND_NAME, "", {}, [])
model = TestModelClient(
tool_calls={
TestModelClient.agenerate_key(messages, tools=[tool]): tool_calls
},
results={TestModelClient.agenerate_key(messages): BEST_EFFORT_RESPONSE},
)

messages_with_dialogue = messages + [
tool_calls_message(tool_calls=tool_calls),
tool_message(TOOL_RESPONSE, TOOL_ID),
]
model_request_limiter = TestModelRequestLimiter(messages_with_dialogue)
callback = TestChainCallback()
command = TestCommand(
{TestCommand.execute_key(command_args): TOOL_RESPONSE}
)
tools_chain = ToolsChain(
model,
commands={TEST_COMMAND_NAME: (lambda: command, tool)},
)

await tools_chain.run_chat(messages, callback, model_request_limiter)

assert callback.mock_result_callback.result == BEST_EFFORT_RESPONSE
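
The final assertion exercises the full fallback path: TestModelRequestLimiter raises once the tool exchange has been appended to the history, ToolsChain trims that message block away and re-queries the model without tools, so the callback accumulates BEST_EFFORT_RESPONSE.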
118 changes: 118 additions & 0 deletions tests/utils/mocks.py
@@ -0,0 +1,118 @@
import json
from typing import Any, AsyncIterator
from unittest.mock import MagicMock, Mock

from openai.types.chat import (
ChatCompletionMessageParam,
ChatCompletionMessageToolCallParam,
)
from typing_extensions import override

from aidial_assistant.chain.callbacks.chain_callback import ChainCallback
from aidial_assistant.chain.callbacks.command_callback import CommandCallback
from aidial_assistant.chain.callbacks.result_callback import ResultCallback
from aidial_assistant.chain.command_chain import (
LimitExceededException,
ModelRequestLimiter,
)
from aidial_assistant.commands.base import (
Command,
ExecutionCallback,
ResultObject,
ResultType,
)
from aidial_assistant.model.model_client import (
ExtraResultsCallback,
ModelClient,
)


class TestModelRequestLimiter(ModelRequestLimiter):
def __init__(self, exception_trigger: list[ChatCompletionMessageParam]):
self.exception_trigger = exception_trigger

async def verify_limit(self, messages: list[ChatCompletionMessageParam]):
if messages == self.exception_trigger:
raise LimitExceededException()


class TestModelClient(ModelClient):
def __init__(
self,
tool_calls: dict[str, list[ChatCompletionMessageToolCallParam]],
results: dict[str, str],
):
super().__init__(Mock(), {})
self.tool_calls = tool_calls
self.results = results

@override
async def agenerate(
self,
messages: list[ChatCompletionMessageParam],
extra_results_callback: ExtraResultsCallback | None = None,
**kwargs,
) -> AsyncIterator[str]:
args = TestModelClient.agenerate_key(messages, **kwargs)
if extra_results_callback and args in self.tool_calls:
extra_results_callback.on_tool_calls(self.tool_calls[args])
return

if args in self.results:
yield self.results[args]
return

assert False, f"Unexpected arguments: {args}"

@staticmethod
def agenerate_key(
messages: list[ChatCompletionMessageParam], **kwargs
) -> str:
return json.dumps({"messages": messages, **kwargs})


class TestCommand(Command):
def __init__(self, results: dict[str, str]):
self.results = results

@staticmethod
def token() -> str:
return "test-command"

@override
async def execute(
self, args: dict[str, Any], execution_callback: ExecutionCallback
) -> ResultObject:
args_string = TestCommand.execute_key(args)
assert args_string in self.results, f"Unexpected argument: {args}"

return ResultObject(ResultType.TEXT, self.results[args_string])

@staticmethod
def execute_key(args: dict[str, Any]) -> str:
return json.dumps({"args": args})


class TestResultCallback(ResultCallback):
def __init__(self):
self.result: str = ""

def on_result(self, chunk: str):
self.result += chunk


class TestChainCallback(ChainCallback):
def __init__(self):
self.mock_result_callback = TestResultCallback()

def command_callback(self) -> CommandCallback:
return MagicMock(spec=CommandCallback)

def on_state(self, request: str, response: str):
pass

def result_callback(self) -> ResultCallback:
return self.mock_result_callback

def on_error(self, title: str, error: str):
pass
