adds is_interruptible granularity to individual messages
ajar98 committed Sep 6, 2023
1 parent c31cddf commit 19f0e0e
Showing 6 changed files with 27 additions and 17 deletions.
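
In short: every agent's generate_response now yields (content, is_interruptible) tuples instead of bare strings, so interruptibility is decided per message rather than once per agent config. Normal sentences are yielded with True; cut-off responses (the agent acknowledging it was interrupted) are yielded with False so they cannot themselves be cut off. AgentInput also gains an optional agent_response_tracker event.

A minimal consumer sketch of the new contract (illustrative only, not part of the commit; `agent` stands in for any concrete agent instance):

    async def consume(agent, human_input: str, conversation_id: str) -> None:
        # Each yielded item is now a (content, is_interruptible) tuple.
        async for content, is_interruptible in agent.generate_response(
            human_input, conversation_id
        ):
            label = "interruptible" if is_interruptible else "protected"
            print(f"{label}: {content}")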
4 changes: 2 additions & 2 deletions vocode/streaming/agent/anthropic_agent.py
@@ -91,7 +91,7 @@ async def generate_response(
         human_input,
         conversation_id: str,
         is_interrupt: bool = False,
-    ) -> AsyncGenerator[str, None]:
+    ) -> AsyncGenerator[Tuple[str, bool], None]:
         self.memory.chat_memory.messages.append(HumanMessage(content=human_input))

         bot_memory_message = AIMessage(content="")
@@ -115,7 +115,7 @@ async def generate_response(
             if sentence:
                 bot_memory_message.content = bot_memory_message.content + sentence
                 buffer = remainder
-                yield sentence
+                yield sentence, True
                 continue

     def update_last_bot_message_on_cut_off(self, message: str):
14 changes: 11 additions & 3 deletions vocode/streaming/agent/base_agent.py
@@ -66,6 +66,10 @@ class AgentInput(TypedModel, type=AgentInputType.BASE.value):
     conversation_id: str
     vonage_uuid: Optional[str]
     twilio_sid: Optional[str]
+    agent_response_tracker: Optional[asyncio.Event] = None
+
+    class Config:
+        arbitrary_types_allowed = True


 class TranscriptionAgentInput(AgentInput, type=AgentInputType.TRANSCRIPTION.value):
@@ -213,7 +217,7 @@ async def handle_generate_response(
         )
         is_first_response = True
         function_call = None
-        async for response in responses:
+        async for response, is_interruptible in responses:
             if isinstance(response, FunctionCall):
                 function_call = response
                 continue
@@ -222,7 +226,9 @@ async def handle_generate_response(
                 is_first_response = False
             self.produce_interruptible_agent_response_event_nonblocking(
                 AgentResponseMessage(message=BaseMessage(text=response)),
-                is_interruptible=self.agent_config.allow_agent_to_be_cut_off,
+                is_interruptible=self.agent_config.allow_agent_to_be_cut_off
+                and is_interruptible,
+                agent_response_tracker=agent_input.agent_response_tracker,
             )
             # TODO: implement should_stop for generate_responses
             agent_span.end()
@@ -429,5 +435,7 @@ def generate_response(
         human_input,
         conversation_id: str,
         is_interrupt: bool = False,
-    ) -> AsyncGenerator[Union[str, FunctionCall], None]:
+    ) -> AsyncGenerator[
+        Tuple[Union[str, FunctionCall], bool], None
+    ]:  # tuple of the content and whether it is interruptible
         raise NotImplementedError
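
The new agent_response_tracker field on AgentInput is a plain asyncio.Event, which is why the pydantic Config with arbitrary_types_allowed = True is added alongside it; base_agent threads it into produce_interruptible_agent_response_event_nonblocking so the originator of an input can be signaled about the agent's response. A hedged sketch of the caller side (the factory and enqueue call below are assumptions, not from this diff):

    import asyncio

    async def send_and_wait(agent, make_transcription_input):
        # Attach an Event to the input; the pipeline forwards it with the
        # agent response event and (presumably) sets it once that response
        # has been processed downstream.
        tracker = asyncio.Event()
        agent_input = make_transcription_input(agent_response_tracker=tracker)
        agent.consume_nonblocking(agent_input)  # hypothetical enqueue call
        await tracker.wait()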
12 changes: 7 additions & 5 deletions vocode/streaming/agent/chat_gpt_agent.py
@@ -69,7 +69,9 @@ def get_functions(self):
             for action_config in self.agent_config.actions
         ]

-    def get_chat_parameters(self, messages: Optional[List] = None):
+    def get_chat_parameters(
+        self, messages: Optional[List] = None, use_functions: bool = True
+    ):
         assert self.transcript is not None
         messages = messages or format_openai_chat_messages_from_transcript(
             self.transcript, self.agent_config.prompt_preamble
@@ -86,7 +88,7 @@ def get_chat_parameters(self, messages: Optional[List] = None):
         else:
             parameters["model"] = self.agent_config.model_name

-        if self.functions:
+        if use_functions and self.functions:
             parameters["functions"] = self.functions

         return parameters
@@ -134,10 +136,10 @@ async def generate_response(
         human_input: str,
         conversation_id: str,
         is_interrupt: bool = False,
-    ) -> AsyncGenerator[Union[str, FunctionCall], None]:
+    ) -> AsyncGenerator[Tuple[Union[str, FunctionCall], bool], None]:
         if is_interrupt and self.agent_config.cut_off_response:
             cut_off_response = self.get_cut_off_response()
-            yield cut_off_response
+            yield cut_off_response, False
             return
         assert self.transcript is not None

@@ -174,4 +176,4 @@ async def generate_response(
         async for message in collate_response_async(
             openai_get_tokens(stream), get_functions=True
         ):
-            yield message
+            yield message, True
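
get_chat_parameters also gains a use_functions flag, so a caller can build OpenAI chat parameters with the function definitions left out even when the agent has actions configured. A hypothetical call site (chat_gpt_agent stands in for a ChatGPTAgent instance):

    # With use_functions=False the "functions" key is omitted from the
    # parameters even if self.functions is populated.
    parameters = chat_gpt_agent.get_chat_parameters(use_functions=False)
    assert "functions" not in parameters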
4 changes: 2 additions & 2 deletions vocode/streaming/agent/echo_agent.py
@@ -17,8 +17,8 @@ async def generate_response(
         human_input,
         conversation_id: str,
         is_interrupt: bool = False,
-    ) -> AsyncGenerator[str, None]:
-        yield human_input
+    ) -> AsyncGenerator[Tuple[str, bool], None]:
+        yield human_input, True

     def update_last_bot_message_on_cut_off(self, message: str):
         pass
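
EchoAgent is the smallest illustration of the new contract: it echoes the input back as a single interruptible message. A sketch of driving it directly (the config import and constructor arguments are assumptions, not shown in this diff):

    import asyncio

    from vocode.streaming.agent.echo_agent import EchoAgent
    from vocode.streaming.models.agent import EchoAgentConfig

    async def main():
        agent = EchoAgent(EchoAgentConfig())
        async for message, is_interruptible in agent.generate_response(
            "hello", conversation_id="test"
        ):
            print(message, is_interruptible)  # -> hello True

    asyncio.run(main())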
4 changes: 2 additions & 2 deletions vocode/streaming/agent/llamacpp_agent.py
@@ -135,7 +135,7 @@ async def generate_response(
         human_input: str,
         conversation_id: str,
         is_interrupt: bool = False,
-    ) -> AsyncGenerator[str, None]:
+    ) -> AsyncGenerator[Tuple[str, bool], None]:
         asyncio.get_event_loop().run_in_executor(
             self.thread_pool_executor,
             lambda input: self.conversation.predict(input=input),
@@ -145,4 +145,4 @@ async def generate_response(
         async for message in collate_response_async(
             self.llamacpp_get_tokens(),
         ):
-            yield str(message)
+            yield str(message), True
6 changes: 3 additions & 3 deletions vocode/streaming/agent/llm_agent.py
@@ -123,12 +123,12 @@ async def generate_response(
         human_input,
         conversation_id: str,
         is_interrupt: bool = False,
-    ) -> AsyncGenerator[str, None]:
+    ) -> AsyncGenerator[Tuple[str, bool], None]:
         self.logger.debug("LLM generating response to human input")
         if is_interrupt and self.agent_config.cut_off_response:
             cut_off_response = self.get_cut_off_response()
             self.memory.append(self.get_memory_entry(human_input, cut_off_response))
-            yield cut_off_response
+            yield cut_off_response, False
             return
         self.memory.append(self.get_memory_entry(human_input, ""))
         if self.is_first_response and self.first_response:
@@ -146,7 +146,7 @@ async def generate_response(
             sentence = re.sub(r"^\s+(.*)", r" \1", sentence)
             response_buffer += sentence
             self.memory[-1] = self.get_memory_entry(human_input, response_buffer)
-            yield sentence
+            yield sentence, True

     def update_last_bot_message_on_cut_off(self, message: str):
         last_message = self.memory[-1]
