From cb69ca8fd890ac9bdb1e1e318e593e95fb456cd7 Mon Sep 17 00:00:00 2001 From: Ajay Raj Date: Thu, 27 Jul 2023 12:24:13 -0700 Subject: [PATCH] allows action to wait on user message (#312) --- vocode/streaming/action/base_action.py | 7 ++++++- vocode/streaming/action/phone_call_action.py | 15 +++++++++++++-- vocode/streaming/agent/base_agent.py | 19 ++++++++++++++----- vocode/streaming/models/actions.py | 5 +++++ 4 files changed, 38 insertions(+), 8 deletions(-) diff --git a/vocode/streaming/action/base_action.py b/vocode/streaming/action/base_action.py index 0c2cfbd71..63a8fbb11 100644 --- a/vocode/streaming/action/base_action.py +++ b/vocode/streaming/action/base_action.py @@ -1,3 +1,4 @@ +import asyncio from typing import Any, Dict, Generic, Type, TypeVar, TYPE_CHECKING from vocode.streaming.action.utils import exclude_keys_recursive from vocode.streaming.models.actions import ( @@ -55,7 +56,9 @@ def get_openai_function(self): parameters_schema["properties"][ "user_message" ] = self._user_message_param_info() - parameters_schema["required"].append("user_message") + required = parameters_schema.get("required", []) + required.append("user_message") + parameters_schema["required"] = required return { "name": self.action_config.type, @@ -67,6 +70,7 @@ def create_action_input( self, conversation_id: str, params: Dict[str, Any], + user_message_tracker: asyncio.Event, ) -> ActionInput[ParametersType]: if "user_message" in params: del params["user_message"] @@ -74,6 +78,7 @@ def create_action_input( action_config=self.action_config, conversation_id=conversation_id, params=self.parameters_type(**params), + user_message_tracker=user_message_tracker, ) def _user_message_param_info(self): diff --git a/vocode/streaming/action/phone_call_action.py b/vocode/streaming/action/phone_call_action.py index 1dc2efc2f..7310fadbb 100644 --- a/vocode/streaming/action/phone_call_action.py +++ b/vocode/streaming/action/phone_call_action.py @@ -1,3 +1,4 @@ +import asyncio from typing import Dict, Any from vocode.streaming.action.base_action import ActionConfigType, BaseAction from vocode.streaming.models.actions import ( @@ -12,7 +13,11 @@ class VonagePhoneCallAction(BaseAction[ActionConfigType, ParametersType, ResponseType]): def create_phone_call_action_input( - self, conversation_id: str, params: Dict[str, Any], vonage_uuid: str + self, + conversation_id: str, + params: Dict[str, Any], + vonage_uuid: str, + user_message_tracker: asyncio.Event, ) -> VonagePhoneCallActionInput[ParametersType]: if "user_message" in params: del params["user_message"] @@ -21,6 +26,7 @@ def create_phone_call_action_input( conversation_id=conversation_id, params=self.parameters_type(**params), vonage_uuid=vonage_uuid, + user_message_tracker=user_message_tracker, ) def get_vonage_uuid(self, action_input: ActionInput[ParametersType]) -> str: @@ -30,7 +36,11 @@ def get_vonage_uuid(self, action_input: ActionInput[ParametersType]) -> str: class TwilioPhoneCallAction(BaseAction[ActionConfigType, ParametersType, ResponseType]): def create_phone_call_action_input( - self, conversation_id: str, params: Dict[str, Any], twilio_sid: str + self, + conversation_id: str, + params: Dict[str, Any], + twilio_sid: str, + user_message_tracker: asyncio.Event, ) -> TwilioPhoneCallActionInput[ParametersType]: if "user_message" in params: del params["user_message"] @@ -39,6 +49,7 @@ def create_phone_call_action_input( conversation_id=conversation_id, params=self.parameters_type(**params), twilio_sid=twilio_sid, + user_message_tracker=user_message_tracker, ) def get_twilio_sid(self, action_input: ActionInput[ParametersType]) -> str: diff --git a/vocode/streaming/agent/base_agent.py b/vocode/streaming/agent/base_agent.py index 2a347d6ed..d471bda84 100644 --- a/vocode/streaming/agent/base_agent.py +++ b/vocode/streaming/agent/base_agent.py @@ -235,7 +235,7 @@ async def handle_generate_response( # TODO: implement should_stop for generate_responses agent_span.end() if function_call and self.agent_config.actions is not None: - self.call_function(function_call, agent_input) + await self.call_function(function_call, agent_input) return False async def handle_respond( @@ -346,7 +346,7 @@ def _get_action_config(self, function_name: str) -> Optional[ActionConfig]: return action_config return None - def call_function(self, function_call: FunctionCall, agent_input: AgentInput): + async def call_function(self, function_call: FunctionCall, agent_input: AgentInput): action_config = self._get_action_config(function_call.name) if action_config is None: self.logger.error( @@ -357,8 +357,10 @@ def call_function(self, function_call: FunctionCall, agent_input: AgentInput): params = json.loads(function_call.arguments) if "user_message" in params: user_message = params["user_message"] + user_message_tracker = asyncio.Event() self.produce_interruptible_agent_response_event_nonblocking( - AgentResponseMessage(message=BaseMessage(text=user_message)) + AgentResponseMessage(message=BaseMessage(text=user_message)), + agent_response_tracker=user_message_tracker, ) action_input: ActionInput if isinstance(action, VonagePhoneCallAction): @@ -366,19 +368,26 @@ def call_function(self, function_call: FunctionCall, agent_input: AgentInput): agent_input.vonage_uuid is not None ), "Cannot use VonagePhoneCallActionFactory unless the attached conversation is a VonageCall" action_input = action.create_phone_call_action_input( - agent_input.conversation_id, params, agent_input.vonage_uuid + agent_input.conversation_id, + params, + agent_input.vonage_uuid, + user_message_tracker, ) elif isinstance(action, TwilioPhoneCallAction): assert ( agent_input.twilio_sid is not None ), "Cannot use TwilioPhoneCallActionFactory unless the attached conversation is a TwilioCall" action_input = action.create_phone_call_action_input( - agent_input.conversation_id, params, agent_input.twilio_sid + agent_input.conversation_id, + params, + agent_input.twilio_sid, + user_message_tracker, ) else: action_input = action.create_action_input( agent_input.conversation_id, params, + user_message_tracker, ) event = self.interruptible_event_factory.create_interruptible_event( action_input, is_interruptible=action.is_interruptible diff --git a/vocode/streaming/models/actions.py b/vocode/streaming/models/actions.py index cc764a048..5f9a046fd 100644 --- a/vocode/streaming/models/actions.py +++ b/vocode/streaming/models/actions.py @@ -1,3 +1,4 @@ +import asyncio from enum import Enum from typing import Generic, TypeVar from pydantic import BaseModel @@ -20,6 +21,10 @@ class ActionInput(BaseModel, Generic[ParametersType]): action_config: ActionConfig conversation_id: str params: ParametersType + user_message_tracker: asyncio.Event + + class Config: + arbitrary_types_allowed = True class FunctionFragment(BaseModel):