From faea324c6f2c70569cf1dc7705a8ecba7fc7a97e Mon Sep 17 00:00:00 2001 From: Oleksii-Klimov <133792808+Oleksii-Klimov@users.noreply.github.com> Date: Tue, 28 Nov 2023 15:18:49 +0000 Subject: [PATCH] Fix history formatting. (#34) --- aidial_assistant/chain/command_chain.py | 14 +++++---- aidial_assistant/chain/command_result.py | 9 ++++++ aidial_assistant/chain/history.py | 18 +++++++----- tests/unit_tests/chain/test_history.py | 37 ++++++++++++++++++++++++ 4 files changed, 65 insertions(+), 13 deletions(-) create mode 100644 tests/unit_tests/chain/test_history.py diff --git a/aidial_assistant/chain/command_chain.py b/aidial_assistant/chain/command_chain.py index 03b8bcf..244d4d5 100644 --- a/aidial_assistant/chain/command_chain.py +++ b/aidial_assistant/chain/command_chain.py @@ -1,6 +1,6 @@ import json import logging -from typing import Any, AsyncIterator, Callable, Tuple +from typing import Any, AsyncIterator, Callable, Tuple, cast from aidial_sdk.chat_completion.request import Role from openai import InvalidRequestError @@ -10,8 +10,10 @@ from aidial_assistant.chain.callbacks.command_callback import CommandCallback from aidial_assistant.chain.callbacks.result_callback import ResultCallback from aidial_assistant.chain.command_result import ( + CommandInvocation, CommandResult, Status, + commands_to_text, responses_to_text, ) from aidial_assistant.chain.dialogue import Dialogue @@ -127,7 +129,7 @@ async def _run_with_protocol_failure_retries( ) if responses: - request_text = json.dumps({"commands": commands}) + request_text = commands_to_text(commands) response_text = responses_to_text(responses) callback.on_state(request_text, response_text) @@ -162,12 +164,12 @@ async def _run_with_protocol_failure_retries( async def _run_commands( self, chunk_stream: AsyncIterator[str], callback: ChainCallback - ) -> Tuple[list[dict[str, Any]], list[CommandResult]]: + ) -> Tuple[list[CommandInvocation], list[CommandResult]]: char_stream = CharacterStream(chunk_stream) await skip_to_json_start(char_stream) async with JsonParser.parse(char_stream) as root_node: - commands: list[dict[str, Any]] = [] + commands: list[CommandInvocation] = [] responses: list[CommandResult] = [] request_reader = CommandsReader(root_node) async for invocation in request_reader.parse_invocations(): @@ -190,7 +192,9 @@ async def _run_commands( command_name, command, args, callback ) - commands.append(invocation.node.value()) + commands.append( + cast(CommandInvocation, invocation.node.value()) + ) responses.append(response) return commands, responses diff --git a/aidial_assistant/chain/command_result.py b/aidial_assistant/chain/command_result.py index e79605e..133685d 100644 --- a/aidial_assistant/chain/command_result.py +++ b/aidial_assistant/chain/command_result.py @@ -16,5 +16,14 @@ class CommandResult(TypedDict): error messages for the failed one.""" +class CommandInvocation(TypedDict): + command: str + args: list[str] + + def responses_to_text(responses: List[CommandResult]) -> str: return json.dumps({"responses": responses}) + + +def commands_to_text(commands: List[CommandInvocation]) -> str: + return json.dumps({"commands": commands}) diff --git a/aidial_assistant/chain/history.py b/aidial_assistant/chain/history.py index 6ff0645..e7040db 100644 --- a/aidial_assistant/chain/history.py +++ b/aidial_assistant/chain/history.py @@ -1,10 +1,13 @@ -import json from enum import Enum from aidial_sdk.chat_completion import Role from jinja2 import Template from pydantic import BaseModel +from aidial_assistant.chain.command_result import ( + CommandInvocation, + commands_to_text, +) from aidial_assistant.chain.dialogue import Dialogue from aidial_assistant.chain.model_client import Message from aidial_assistant.commands.reply import Reply @@ -54,13 +57,12 @@ def to_protocol_messages(self) -> list[Message]: elif scope == MessageScope.USER and message.role == Role.ASSISTANT: # Clients see replies in plain text, but the model should understand how to reply appropriately. - content = json.dumps( - { - "commands": { - "command": Reply.token(), - "args": [message.content], - } - } + content = commands_to_text( + [ + CommandInvocation( + command=Reply.token(), args=[message.content] + ) + ] ) messages.append(Message.assistant(content=content)) else: diff --git a/tests/unit_tests/chain/test_history.py b/tests/unit_tests/chain/test_history.py new file mode 100644 index 0000000..ae80332 --- /dev/null +++ b/tests/unit_tests/chain/test_history.py @@ -0,0 +1,37 @@ +from jinja2 import Template + +from aidial_assistant.chain.history import History, MessageScope, ScopedMessage +from aidial_assistant.chain.model_client import Message + +SYSTEM_MESSAGE = "" +USER_MESSAGE = "" +ASSISTANT_MESSAGE = "" + + +def test_protocol_messages(): + history = History( + assistant_system_message_template=Template( + "system message={{system_prefix}}" + ), + best_effort_template=Template(""), + scoped_messages=[ + ScopedMessage( + scope=MessageScope.USER, message=Message.system(SYSTEM_MESSAGE) + ), + ScopedMessage( + scope=MessageScope.USER, message=Message.user(USER_MESSAGE) + ), + ScopedMessage( + scope=MessageScope.USER, + message=Message.assistant(ASSISTANT_MESSAGE), + ), + ], + ) + + assert history.to_protocol_messages() == [ + Message.system(f"system message={SYSTEM_MESSAGE}"), + Message.user(USER_MESSAGE), + Message.assistant( + f'{{"commands": [{{"command": "reply", "args": ["{ASSISTANT_MESSAGE}"]}}]}}' + ), + ]