Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Don't split bot messages in transcript (Fixes #319) #337

141 changes: 131 additions & 10 deletions tests/streaming/agent/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
from typing import Any, Dict, List, Optional, Union
from openai.openai_object import OpenAIObject
from pydantic import BaseModel
from vocode.streaming.models.actions import FunctionCall
from vocode.streaming.models.actions import (
ActionConfig,
ActionInput,
ActionOutput,
FunctionCall,
)
import pytest
from vocode.streaming.agent.utils import collate_response_async, openai_get_tokens
from vocode.streaming.agent.utils import (
collate_response_async,
format_openai_chat_messages_from_transcript,
openai_get_tokens,
)
from vocode.streaming.models.events import Sender
from vocode.streaming.models.transcript import (
ActionFinish,
ActionStart,
Message,
Transcript,
)


async def _agen_from_list(l):
Expand Down Expand Up @@ -202,7 +218,7 @@ class StreamOpenAIResponseTestCase(BaseModel):
"finish_reason": None,
},
{"delta": {}, "finish_reason": "function_call"},
]
],
]

EXPECTED_SENTENCES = [
Expand Down Expand Up @@ -230,7 +246,7 @@ class StreamOpenAIResponseTestCase(BaseModel):
],
[
FunctionCall(name="wave_hello", arguments='{\n "name": "user"\n}'),
]
],
]

FUNCTIONS_INPUT = [
Expand Down Expand Up @@ -267,7 +283,7 @@ class StreamOpenAIResponseTestCase(BaseModel):
"finish_reason": None,
},
{"delta": {}, "finish_reason": "function_call"},
]
],
]

FUNCTIONS_OUTPUT = [
Expand All @@ -278,10 +294,10 @@ class StreamOpenAIResponseTestCase(BaseModel):
],
[
FunctionCall(name="wave_hello", arguments='{\n "name": "user"\n}'),
]

],
]


@pytest.mark.asyncio
async def test_stream_openai_response_async():
test_cases = [
Expand All @@ -290,18 +306,123 @@ async def test_stream_openai_response_async():
create_chatgpt_openai_object(**obj) for obj in openai_objects
],
expected_sentences=expected_sentences,
get_functions=any(isinstance(item, FunctionCall) for item in expected_sentences)
get_functions=any(
isinstance(item, FunctionCall) for item in expected_sentences
),
)
for openai_objects, expected_sentences in zip(
OPENAI_OBJECTS, EXPECTED_SENTENCES
)
]

for test_case in test_cases:
actual_sentences = []
async for sentence in collate_response_async(
openai_get_tokens(_agen_from_list(test_case.openai_objects)),
get_functions=test_case.get_functions
get_functions=test_case.get_functions,
):
actual_sentences.append(sentence)
assert actual_sentences == test_case.expected_sentences


def test_format_openai_chat_messages_from_transcript():
    """Verify rendering of a Transcript into OpenAI chat-completion messages.

    Scenarios covered:
    - consecutive BOT messages collapse into one "assistant" message,
    - an optional prompt preamble becomes a leading "system" message,
    - ActionStart/ActionFinish events map to an assistant ``function_call``
      message followed by a "function"-role result message.
    """
    # Each case is ((transcript, prompt_preamble), expected chat messages).
    cases = [
        (
            (
                Transcript(
                    event_logs=[
                        Message(sender=Sender.BOT, text="Hello!"),
                        Message(sender=Sender.BOT, text="How are you doing today?"),
                        Message(sender=Sender.HUMAN, text="I'm doing well, thanks!"),
                    ]
                ),
                "prompt preamble",
            ),
            [
                {"role": "system", "content": "prompt preamble"},
                {"role": "assistant", "content": "Hello! How are you doing today?"},
                {"role": "user", "content": "I'm doing well, thanks!"},
            ],
        ),
        (
            (
                Transcript(
                    event_logs=[
                        Message(sender=Sender.BOT, text="Hello!"),
                        Message(sender=Sender.BOT, text="How are you doing today?"),
                        Message(sender=Sender.HUMAN, text="I'm doing well, thanks!"),
                    ]
                ),
                None,
            ),
            [
                {"role": "assistant", "content": "Hello! How are you doing today?"},
                {"role": "user", "content": "I'm doing well, thanks!"},
            ],
        ),
        (
            (
                Transcript(
                    event_logs=[
                        Message(sender=Sender.BOT, text="Hello!"),
                        Message(sender=Sender.BOT, text="How are you doing today?"),
                    ]
                ),
                "prompt preamble",
            ),
            [
                {"role": "system", "content": "prompt preamble"},
                {"role": "assistant", "content": "Hello! How are you doing today?"},
            ],
        ),
        (
            (
                Transcript(
                    event_logs=[
                        Message(sender=Sender.BOT, text="Hello!"),
                        Message(
                            sender=Sender.HUMAN, text="Hello, what's the weather like?"
                        ),
                        ActionStart(
                            action_type="weather",
                            action_input=ActionInput(
                                action_config=ActionConfig(),
                                conversation_id="asdf",
                                params={},
                            ),
                        ),
                        ActionFinish(
                            action_type="weather",
                            action_output=ActionOutput(
                                action_type="weather", response={}
                            ),
                        ),
                    ]
                ),
                None,
            ),
            [
                {"role": "assistant", "content": "Hello!"},
                {
                    "role": "user",
                    "content": "Hello, what's the weather like?",
                },
                {
                    "role": "assistant",
                    "content": None,
                    "function_call": {
                        "name": "weather",
                        "arguments": "{}",
                    },
                },
                {
                    "role": "function",
                    "name": "weather",
                    "content": "{}",
                },
            ],
        ),
    ]

    for (transcript, preamble), expected in cases:
        actual = format_openai_chat_messages_from_transcript(transcript, preamble)
        assert actual == expected
28 changes: 27 additions & 1 deletion vocode/streaming/agent/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
import re
from typing import (
Dict,
Expand All @@ -18,6 +19,7 @@
from vocode.streaming.models.transcript import (
ActionFinish,
ActionStart,
EventLog,
Message,
Transcript,
)
Expand Down Expand Up @@ -116,7 +118,31 @@ def format_openai_chat_messages_from_transcript(
chat_messages: List[Dict[str, Optional[Any]]] = (
[{"role": "system", "content": prompt_preamble}] if prompt_preamble else []
)
for event_log in transcript.event_logs:

# merge consecutive bot messages
new_event_logs: List[EventLog] = []
ajar98 marked this conversation as resolved.
Show resolved Hide resolved
idx = 0
while idx < len(transcript.event_logs):
bot_messages_buffer: List[Message] = []
current_log = transcript.event_logs[idx]
while isinstance(current_log, Message) and current_log.sender == Sender.BOT:
bot_messages_buffer.append(current_log)
idx += 1
try:
current_log = transcript.event_logs[idx]
except IndexError:
break
if bot_messages_buffer:
merged_bot_message = deepcopy(bot_messages_buffer[-1])
merged_bot_message.text = " ".join(
event_log.text for event_log in bot_messages_buffer
)
new_event_logs.append(merged_bot_message)
else:
new_event_logs.append(current_log)
idx += 1

for event_log in new_event_logs:
if isinstance(event_log, Message):
chat_messages.append(
{
Expand Down
Loading