-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: limit size of model dialogue with addons (#58)
* Limit size of model dialogue with addons.
- Loading branch information
1 parent
c725a33
commit b4bbbd0
Showing
6 changed files
with
242 additions
and
17 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
68 changes: 68 additions & 0 deletions
68
tests/unit_tests/tools_chain/test_tools_chain_best_effort.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,68 @@ | ||
import json | ||
|
||
import pytest | ||
from openai.types.chat import ( | ||
ChatCompletionMessageParam, | ||
ChatCompletionMessageToolCallParam, | ||
) | ||
from openai.types.chat.chat_completion_message_tool_call_param import Function | ||
|
||
from aidial_assistant.tools_chain.tools_chain import ToolsChain | ||
from aidial_assistant.utils.open_ai import ( | ||
construct_tool, | ||
tool_calls_message, | ||
tool_message, | ||
user_message, | ||
) | ||
from tests.utils.mocks import ( | ||
TestChainCallback, | ||
TestCommand, | ||
TestModelClient, | ||
TestModelRequestLimiter, | ||
) | ||
|
||
# Sentinel values shared by the tests below; the angle-bracket style marks
# them as opaque placeholders whose exact content is irrelevant — they only
# need to round-trip unchanged through the chain under test.
TEST_COMMAND_NAME = "<test command>"
TOOL_ID = "<tool id>"
TOOL_RESPONSE = "<tool response>"
BEST_EFFORT_RESPONSE = "<best effort response>"
|
||
|
||
@pytest.mark.asyncio
async def test_model_request_limit_exceeded():
    """When the limiter rejects the dialogue grown by a tool round-trip,
    the chain falls back to a plain (tool-free) best-effort completion.
    """
    query_messages: list[ChatCompletionMessageParam] = [
        user_message("<query>")
    ]
    arguments = {"<test argument>": "<test value>"}
    expected_tool_calls = [
        ChatCompletionMessageToolCallParam(
            id=TOOL_ID,
            function=Function(
                name=TEST_COMMAND_NAME,
                arguments=json.dumps(arguments),
            ),
            type="function",
        )
    ]
    tool = construct_tool(TEST_COMMAND_NAME, "", {}, [])
    # First model call (with tools) reports the tool calls; the fallback
    # call (no tools kwarg) streams the best-effort answer.
    model = TestModelClient(
        tool_calls={
            TestModelClient.agenerate_key(
                query_messages, tools=[tool]
            ): expected_tool_calls
        },
        results={
            TestModelClient.agenerate_key(
                query_messages
            ): BEST_EFFORT_RESPONSE
        },
    )

    # The limiter trips exactly when the dialogue contains the tool
    # request/response pair appended to the original query.
    dialogue = query_messages + [
        tool_calls_message(tool_calls=expected_tool_calls),
        tool_message(TOOL_RESPONSE, TOOL_ID),
    ]
    limiter = TestModelRequestLimiter(dialogue)
    callback = TestChainCallback()
    command = TestCommand(
        {TestCommand.execute_key(arguments): TOOL_RESPONSE}
    )
    chain = ToolsChain(
        model,
        commands={TEST_COMMAND_NAME: (lambda: command, tool)},
    )

    await chain.run_chat(query_messages, callback, limiter)

    assert callback.mock_result_callback.result == BEST_EFFORT_RESPONSE
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
import json | ||
from typing import Any, AsyncIterator | ||
from unittest.mock import MagicMock, Mock | ||
|
||
from openai.types.chat import ( | ||
ChatCompletionMessageParam, | ||
ChatCompletionMessageToolCallParam, | ||
) | ||
from typing_extensions import override | ||
|
||
from aidial_assistant.chain.callbacks.chain_callback import ChainCallback | ||
from aidial_assistant.chain.callbacks.command_callback import CommandCallback | ||
from aidial_assistant.chain.callbacks.result_callback import ResultCallback | ||
from aidial_assistant.chain.command_chain import ( | ||
LimitExceededException, | ||
ModelRequestLimiter, | ||
) | ||
from aidial_assistant.commands.base import ( | ||
Command, | ||
ExecutionCallback, | ||
ResultObject, | ||
ResultType, | ||
) | ||
from aidial_assistant.model.model_client import ( | ||
ExtraResultsCallback, | ||
ModelClient, | ||
) | ||
|
||
|
||
class TestModelRequestLimiter(ModelRequestLimiter):
    """Limiter stub that fails once a pre-configured dialogue appears."""

    def __init__(self, exception_trigger: list[ChatCompletionMessageParam]):
        self.exception_trigger = exception_trigger

    async def verify_limit(self, messages: list[ChatCompletionMessageParam]):
        # Simulate exhausting the request budget precisely when the
        # trigger message sequence is submitted; otherwise allow the call.
        if self.exception_trigger == messages:
            raise LimitExceededException()
|
||
|
||
class TestModelClient(ModelClient):
    """ModelClient stub driven entirely by canned responses.

    Each call is keyed by the JSON serialization of its arguments: a hit
    in ``tool_calls`` is reported through the extra-results callback, a
    hit in ``results`` is streamed back as text, and anything else fails
    the test immediately.
    """

    def __init__(
        self,
        tool_calls: dict[str, list[ChatCompletionMessageToolCallParam]],
        results: dict[str, str],
    ):
        super().__init__(Mock(), {})
        self.tool_calls = tool_calls
        self.results = results

    @override
    async def agenerate(
        self,
        messages: list[ChatCompletionMessageParam],
        extra_results_callback: ExtraResultsCallback | None = None,
        **kwargs,
    ) -> AsyncIterator[str]:
        key = TestModelClient.agenerate_key(messages, **kwargs)

        # Tool calls take precedence, but only when the caller supplied a
        # callback capable of receiving them.
        if extra_results_callback is not None and key in self.tool_calls:
            extra_results_callback.on_tool_calls(self.tool_calls[key])
            return

        canned_text = self.results.get(key)
        if canned_text is not None:
            yield canned_text
            return

        assert False, f"Unexpected arguments: {key}"

    @staticmethod
    def agenerate_key(
        messages: list[ChatCompletionMessageParam], **kwargs
    ) -> str:
        # Deterministic lookup key: the messages plus any extra kwargs.
        return json.dumps({"messages": messages, **kwargs})
|
||
|
||
class TestCommand(Command):
    """Command stub returning canned text results keyed by arguments."""

    def __init__(self, results: dict[str, str]):
        self.results = results

    @staticmethod
    def token() -> str:
        return "test-command"

    @override
    async def execute(
        self, args: dict[str, Any], execution_callback: ExecutionCallback
    ) -> ResultObject:
        key = TestCommand.execute_key(args)
        # Fail loudly on an unexpected invocation rather than fabricating
        # a result the test never registered.
        assert key in self.results, f"Unexpected argument: {args}"
        return ResultObject(ResultType.TEXT, self.results[key])

    @staticmethod
    def execute_key(args: dict[str, Any]) -> str:
        # Deterministic lookup key derived from the command arguments.
        return json.dumps({"args": args})
|
||
|
||
class TestResultCallback(ResultCallback):
    """Accumulates streamed result chunks into a single string."""

    def __init__(self):
        # Full concatenation of everything passed to on_result so far.
        self.result: str = ""

    def on_result(self, chunk: str):
        self.result = self.result + chunk
|
||
|
||
class TestChainCallback(ChainCallback):
    """ChainCallback stub: captures the final result, ignores the rest."""

    def __init__(self):
        # Exposed so tests can assert on the accumulated result text.
        self.mock_result_callback = TestResultCallback()

    def command_callback(self) -> CommandCallback:
        # A fresh autospecced mock per command; interactions with it are
        # not asserted by these tests.
        return MagicMock(spec=CommandCallback)

    def on_state(self, request: str, response: str):
        # State snapshots are irrelevant here.
        pass

    def result_callback(self) -> ResultCallback:
        return self.mock_result_callback

    def on_error(self, title: str, error: str):
        # Errors surface through assertions elsewhere.
        pass