From 19f0e0e999599fac63b05c4b575ba0be4204e98f Mon Sep 17 00:00:00 2001
From: Ajay Raj
Date: Tue, 5 Sep 2023 17:19:27 -0700
Subject: [PATCH] adds is_interruptible granularity to individual messages

---
 vocode/streaming/agent/anthropic_agent.py |  4 ++--
 vocode/streaming/agent/base_agent.py      | 14 +++++++++++---
 vocode/streaming/agent/chat_gpt_agent.py  | 12 +++++++-----
 vocode/streaming/agent/echo_agent.py      |  4 ++--
 vocode/streaming/agent/llamacpp_agent.py  |  4 ++--
 vocode/streaming/agent/llm_agent.py       |  6 +++---
 6 files changed, 27 insertions(+), 17 deletions(-)

diff --git a/vocode/streaming/agent/anthropic_agent.py b/vocode/streaming/agent/anthropic_agent.py
index e1821c3ed..617751238 100644
--- a/vocode/streaming/agent/anthropic_agent.py
+++ b/vocode/streaming/agent/anthropic_agent.py
@@ -91,7 +91,7 @@ async def generate_response(
         human_input,
         conversation_id: str,
         is_interrupt: bool = False,
-    ) -> AsyncGenerator[str, None]:
+    ) -> AsyncGenerator[Tuple[str, bool], None]:
         self.memory.chat_memory.messages.append(HumanMessage(content=human_input))

         bot_memory_message = AIMessage(content="")
@@ -115,7 +115,7 @@ async def generate_response(
             if sentence:
                 bot_memory_message.content = bot_memory_message.content + sentence
                 buffer = remainder
-                yield sentence
+                yield sentence, True
             continue

     def update_last_bot_message_on_cut_off(self, message: str):
diff --git a/vocode/streaming/agent/base_agent.py b/vocode/streaming/agent/base_agent.py
index bcede5d77..8c2b5c2f3 100644
--- a/vocode/streaming/agent/base_agent.py
+++ b/vocode/streaming/agent/base_agent.py
@@ -66,6 +66,10 @@ class AgentInput(TypedModel, type=AgentInputType.BASE.value):
     conversation_id: str
     vonage_uuid: Optional[str]
     twilio_sid: Optional[str]
+    agent_response_tracker: Optional[asyncio.Event] = None
+
+    class Config:
+        arbitrary_types_allowed = True


 class TranscriptionAgentInput(AgentInput, type=AgentInputType.TRANSCRIPTION.value):
@@ -213,7 +217,7 @@ async def handle_generate_response(
         )
         is_first_response = True
         function_call = None
-        async for response in responses:
+        async for response, is_interruptible in responses:
             if isinstance(response, FunctionCall):
                 function_call = response
                 continue
@@ -222,7 +226,9 @@ async def handle_generate_response(
                 is_first_response = False
             self.produce_interruptible_agent_response_event_nonblocking(
                 AgentResponseMessage(message=BaseMessage(text=response)),
-                is_interruptible=self.agent_config.allow_agent_to_be_cut_off,
+                is_interruptible=self.agent_config.allow_agent_to_be_cut_off
+                and is_interruptible,
+                agent_response_tracker=agent_input.agent_response_tracker,
             )
         # TODO: implement should_stop for generate_responses
         agent_span.end()
@@ -429,5 +435,7 @@ def generate_response(
         human_input,
         conversation_id: str,
         is_interrupt: bool = False,
-    ) -> AsyncGenerator[Union[str, FunctionCall], None]:
+    ) -> AsyncGenerator[
+        Tuple[Union[str, FunctionCall], bool], None
+    ]:  # tuple of the content and whether it is interruptible
         raise NotImplementedError
diff --git a/vocode/streaming/agent/chat_gpt_agent.py b/vocode/streaming/agent/chat_gpt_agent.py
index 7a4ee7240..1d1620d4f 100644
--- a/vocode/streaming/agent/chat_gpt_agent.py
+++ b/vocode/streaming/agent/chat_gpt_agent.py
@@ -69,7 +69,9 @@ def get_functions(self):
             for action_config in self.agent_config.actions
         ]

-    def get_chat_parameters(self, messages: Optional[List] = None):
+    def get_chat_parameters(
+        self, messages: Optional[List] = None, use_functions: bool = True
+    ):
         assert self.transcript is not None
         messages = messages or format_openai_chat_messages_from_transcript(
             self.transcript, self.agent_config.prompt_preamble
@@ -86,7 +88,7 @@ def get_chat_parameters(self, messages: Optional[List] = None):
         else:
             parameters["model"] = self.agent_config.model_name

-        if self.functions:
+        if use_functions and self.functions:
             parameters["functions"] = self.functions

         return parameters
@@ -134,10 +136,10 @@ async def generate_response(
         human_input: str,
         conversation_id: str,
         is_interrupt: bool = False,
-    ) -> AsyncGenerator[Union[str, FunctionCall], None]:
+    ) -> AsyncGenerator[Tuple[Union[str, FunctionCall], bool], None]:
         if is_interrupt and self.agent_config.cut_off_response:
             cut_off_response = self.get_cut_off_response()
-            yield cut_off_response
+            yield cut_off_response, False
             return

         assert self.transcript is not None
@@ -174,4 +176,4 @@ async def generate_response(
         async for message in collate_response_async(
             openai_get_tokens(stream), get_functions=True
         ):
-            yield message
+            yield message, True
diff --git a/vocode/streaming/agent/echo_agent.py b/vocode/streaming/agent/echo_agent.py
index 103356652..cd4930dff 100644
--- a/vocode/streaming/agent/echo_agent.py
+++ b/vocode/streaming/agent/echo_agent.py
@@ -17,8 +17,8 @@ async def generate_response(
         human_input,
         conversation_id: str,
         is_interrupt: bool = False,
-    ) -> AsyncGenerator[str, None]:
-        yield human_input
+    ) -> AsyncGenerator[Tuple[str, bool], None]:
+        yield human_input, True

     def update_last_bot_message_on_cut_off(self, message: str):
         pass
diff --git a/vocode/streaming/agent/llamacpp_agent.py b/vocode/streaming/agent/llamacpp_agent.py
index 10c88566f..ae938172e 100644
--- a/vocode/streaming/agent/llamacpp_agent.py
+++ b/vocode/streaming/agent/llamacpp_agent.py
@@ -135,7 +135,7 @@ async def generate_response(
         human_input: str,
         conversation_id: str,
         is_interrupt: bool = False,
-    ) -> AsyncGenerator[str, None]:
+    ) -> AsyncGenerator[Tuple[str, bool], None]:
         asyncio.get_event_loop().run_in_executor(
             self.thread_pool_executor,
             lambda input: self.conversation.predict(input=input),
@@ -145,4 +145,4 @@ async def generate_response(
         async for message in collate_response_async(
             self.llamacpp_get_tokens(),
         ):
-            yield str(message)
+            yield str(message), True
diff --git a/vocode/streaming/agent/llm_agent.py b/vocode/streaming/agent/llm_agent.py
index 5f2ccf9c6..a3d510da2 100644
--- a/vocode/streaming/agent/llm_agent.py
+++ b/vocode/streaming/agent/llm_agent.py
@@ -123,12 +123,12 @@ async def generate_response(
         human_input,
         conversation_id: str,
         is_interrupt: bool = False,
-    ) -> AsyncGenerator[str, None]:
+    ) -> AsyncGenerator[Tuple[str, bool], None]:
         self.logger.debug("LLM generating response to human input")
         if is_interrupt and self.agent_config.cut_off_response:
             cut_off_response = self.get_cut_off_response()
             self.memory.append(self.get_memory_entry(human_input, cut_off_response))
-            yield cut_off_response
+            yield cut_off_response, False
             return
         self.memory.append(self.get_memory_entry(human_input, ""))
         if self.is_first_response and self.first_response:
@@ -146,7 +146,7 @@ async def generate_response(
             sentence = re.sub(r"^\s+(.*)", r" \1", sentence)
             response_buffer += sentence
             self.memory[-1] = self.get_memory_entry(human_input, response_buffer)
-            yield sentence
+            yield sentence, True

     def update_last_bot_message_on_cut_off(self, message: str):
         last_message = self.memory[-1]
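
Note for downstream agents (not part of the patch): after this change, every generate_response implementation yields (message, is_interruptible) tuples instead of bare strings, and handle_generate_response ANDs the per-message flag with agent_config.allow_agent_to_be_cut_off before emitting the agent response event. The sketch below is a minimal, hypothetical example of a custom agent adopting the new contract; the DisclaimerAgent name and its messages are illustrative, it assumes the RespondAgent base class from base_agent.py, and the agent-config wiring (generate_responses=True, constructor arguments) is omitted.

from typing import AsyncGenerator, Tuple

from vocode.streaming.agent.base_agent import RespondAgent


class DisclaimerAgent(RespondAgent):
    """Hypothetical agent illustrating the (message, is_interruptible) contract."""

    async def generate_response(
        self,
        human_input,
        conversation_id: str,
        is_interrupt: bool = False,
    ) -> AsyncGenerator[Tuple[str, bool], None]:
        # Each yielded item is (message, is_interruptible). Upstream,
        # handle_generate_response combines the flag with
        # agent_config.allow_agent_to_be_cut_off, so a False here means the
        # message plays to completion even when cut-offs are enabled.
        yield "Please note that this call may be recorded.", False  # must play in full
        yield f"You said: {human_input}. How can I help?", True  # may be interrupted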