Skip to content

Commit

Permalink
Don't split bot messages in transcript (Fixes #319) (#337)
Browse files Browse the repository at this point in the history
* Initial fix

* Fix after merge

* Add message_ids everywhere

* Revert past commits

* Add format_openai_chat_messages_from_transcript solution

* Deepcopy before changing attribute

* Add type

* add comments, clean up code

* checkpoint

* adds test for format_openai_chat_messages

---------

Co-authored-by: Ajay Raj <[email protected]>
  • Loading branch information
HHousen and ajar98 authored Aug 7, 2023
1 parent 76fc257 commit 806788b
Show file tree
Hide file tree
Showing 2 changed files with 158 additions and 11 deletions.
141 changes: 131 additions & 10 deletions tests/streaming/agent/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,25 @@
from typing import Any, Dict, List, Optional, Union
from openai.openai_object import OpenAIObject
from pydantic import BaseModel
from vocode.streaming.models.actions import FunctionCall
from vocode.streaming.models.actions import (
ActionConfig,
ActionInput,
ActionOutput,
FunctionCall,
)
import pytest
from vocode.streaming.agent.utils import collate_response_async, openai_get_tokens
from vocode.streaming.agent.utils import (
collate_response_async,
format_openai_chat_messages_from_transcript,
openai_get_tokens,
)
from vocode.streaming.models.events import Sender
from vocode.streaming.models.transcript import (
ActionFinish,
ActionStart,
Message,
Transcript,
)


async def _agen_from_list(l):
Expand Down Expand Up @@ -202,7 +218,7 @@ class StreamOpenAIResponseTestCase(BaseModel):
"finish_reason": None,
},
{"delta": {}, "finish_reason": "function_call"},
]
],
]

EXPECTED_SENTENCES = [
Expand Down Expand Up @@ -230,7 +246,7 @@ class StreamOpenAIResponseTestCase(BaseModel):
],
[
FunctionCall(name="wave_hello", arguments='{\n "name": "user"\n}'),
]
],
]

FUNCTIONS_INPUT = [
Expand Down Expand Up @@ -267,7 +283,7 @@ class StreamOpenAIResponseTestCase(BaseModel):
"finish_reason": None,
},
{"delta": {}, "finish_reason": "function_call"},
]
],
]

FUNCTIONS_OUTPUT = [
Expand All @@ -278,10 +294,10 @@ class StreamOpenAIResponseTestCase(BaseModel):
],
[
FunctionCall(name="wave_hello", arguments='{\n "name": "user"\n}'),
]

],
]


@pytest.mark.asyncio
async def test_stream_openai_response_async():
test_cases = [
Expand All @@ -290,18 +306,123 @@ async def test_stream_openai_response_async():
create_chatgpt_openai_object(**obj) for obj in openai_objects
],
expected_sentences=expected_sentences,
get_functions=any(isinstance(item, FunctionCall) for item in expected_sentences)
get_functions=any(
isinstance(item, FunctionCall) for item in expected_sentences
),
)
for openai_objects, expected_sentences in zip(
OPENAI_OBJECTS, EXPECTED_SENTENCES
)
]

for test_case in test_cases:
actual_sentences = []
async for sentence in collate_response_async(
openai_get_tokens(_agen_from_list(test_case.openai_objects)),
get_functions=test_case.get_functions
get_functions=test_case.get_functions,
):
actual_sentences.append(sentence)
assert actual_sentences == test_case.expected_sentences


def test_format_openai_chat_messages_from_transcript():
    """Check transcript -> OpenAI chat-message formatting.

    Covers: merging of consecutive BOT messages into one assistant message,
    optional system preamble, and translation of ActionStart/ActionFinish
    events into function_call / function messages.
    """

    def _bot_bot_human_transcript():
        # Build a fresh transcript per case so no case can observe
        # mutations made while formatting another.
        return Transcript(
            event_logs=[
                Message(sender=Sender.BOT, text="Hello!"),
                Message(sender=Sender.BOT, text="How are you doing today?"),
                Message(sender=Sender.HUMAN, text="I'm doing well, thanks!"),
            ]
        )

    system_message = {"role": "system", "content": "prompt preamble"}
    merged_bot_message = {
        "role": "assistant",
        "content": "Hello! How are you doing today?",
    }
    human_message = {"role": "user", "content": "I'm doing well, thanks!"}

    # Case 1: preamble present; the two consecutive bot messages merge.
    assert format_openai_chat_messages_from_transcript(
        _bot_bot_human_transcript(), "prompt preamble"
    ) == [system_message, merged_bot_message, human_message]

    # Case 2: same transcript, no preamble -> no system message.
    assert format_openai_chat_messages_from_transcript(
        _bot_bot_human_transcript(), None
    ) == [merged_bot_message, human_message]

    # Case 3: bot-only transcript with a preamble.
    bot_only_transcript = Transcript(
        event_logs=[
            Message(sender=Sender.BOT, text="Hello!"),
            Message(sender=Sender.BOT, text="How are you doing today?"),
        ]
    )
    assert format_openai_chat_messages_from_transcript(
        bot_only_transcript, "prompt preamble"
    ) == [system_message, merged_bot_message]

    # Case 4: action events become an assistant function_call followed by a
    # function-role result message.
    action_transcript = Transcript(
        event_logs=[
            Message(sender=Sender.BOT, text="Hello!"),
            Message(sender=Sender.HUMAN, text="Hello, what's the weather like?"),
            ActionStart(
                action_type="weather",
                action_input=ActionInput(
                    action_config=ActionConfig(),
                    conversation_id="asdf",
                    params={},
                ),
            ),
            ActionFinish(
                action_type="weather",
                action_output=ActionOutput(action_type="weather", response={}),
            ),
        ]
    )
    assert format_openai_chat_messages_from_transcript(action_transcript, None) == [
        {"role": "assistant", "content": "Hello!"},
        {
            "role": "user",
            "content": "Hello, what's the weather like?",
        },
        {
            "role": "assistant",
            "content": None,
            "function_call": {
                "name": "weather",
                "arguments": "{}",
            },
        },
        {
            "role": "function",
            "name": "weather",
            "content": "{}",
        },
    ]
28 changes: 27 additions & 1 deletion vocode/streaming/agent/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from copy import deepcopy
import re
from typing import (
Dict,
Expand All @@ -18,6 +19,7 @@
from vocode.streaming.models.transcript import (
ActionFinish,
ActionStart,
EventLog,
Message,
Transcript,
)
Expand Down Expand Up @@ -116,7 +118,31 @@ def format_openai_chat_messages_from_transcript(
chat_messages: List[Dict[str, Optional[Any]]] = (
[{"role": "system", "content": prompt_preamble}] if prompt_preamble else []
)
for event_log in transcript.event_logs:

# merge consecutive bot messages
new_event_logs: List[EventLog] = []
idx = 0
while idx < len(transcript.event_logs):
bot_messages_buffer: List[Message] = []
current_log = transcript.event_logs[idx]
while isinstance(current_log, Message) and current_log.sender == Sender.BOT:
bot_messages_buffer.append(current_log)
idx += 1
try:
current_log = transcript.event_logs[idx]
except IndexError:
break
if bot_messages_buffer:
merged_bot_message = deepcopy(bot_messages_buffer[-1])
merged_bot_message.text = " ".join(
event_log.text for event_log in bot_messages_buffer
)
new_event_logs.append(merged_bot_message)
else:
new_event_logs.append(current_log)
idx += 1

for event_log in new_event_logs:
if isinstance(event_log, Message):
chat_messages.append(
{
Expand Down

0 comments on commit 806788b

Please sign in to comment.