Skip to content

Commit

Permalink
allows action to wait on user message (#312)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajar98 authored Jul 27, 2023
1 parent e007d77 commit cb69ca8
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 8 deletions.
7 changes: 6 additions & 1 deletion vocode/streaming/action/base_action.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Any, Dict, Generic, Type, TypeVar, TYPE_CHECKING
from vocode.streaming.action.utils import exclude_keys_recursive
from vocode.streaming.models.actions import (
Expand Down Expand Up @@ -55,7 +56,9 @@ def get_openai_function(self):
parameters_schema["properties"][
"user_message"
] = self._user_message_param_info()
parameters_schema["required"].append("user_message")
required = parameters_schema.get("required", [])
required.append("user_message")
parameters_schema["required"] = required

return {
"name": self.action_config.type,
Expand All @@ -67,13 +70,15 @@ def create_action_input(
self,
conversation_id: str,
params: Dict[str, Any],
user_message_tracker: asyncio.Event,
) -> ActionInput[ParametersType]:
if "user_message" in params:
del params["user_message"]
return ActionInput(
action_config=self.action_config,
conversation_id=conversation_id,
params=self.parameters_type(**params),
user_message_tracker=user_message_tracker,
)

def _user_message_param_info(self):
Expand Down
15 changes: 13 additions & 2 deletions vocode/streaming/action/phone_call_action.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from typing import Dict, Any
from vocode.streaming.action.base_action import ActionConfigType, BaseAction
from vocode.streaming.models.actions import (
Expand All @@ -12,7 +13,11 @@

class VonagePhoneCallAction(BaseAction[ActionConfigType, ParametersType, ResponseType]):
def create_phone_call_action_input(
self, conversation_id: str, params: Dict[str, Any], vonage_uuid: str
self,
conversation_id: str,
params: Dict[str, Any],
vonage_uuid: str,
user_message_tracker: asyncio.Event,
) -> VonagePhoneCallActionInput[ParametersType]:
if "user_message" in params:
del params["user_message"]
Expand All @@ -21,6 +26,7 @@ def create_phone_call_action_input(
conversation_id=conversation_id,
params=self.parameters_type(**params),
vonage_uuid=vonage_uuid,
user_message_tracker=user_message_tracker,
)

def get_vonage_uuid(self, action_input: ActionInput[ParametersType]) -> str:
Expand All @@ -30,7 +36,11 @@ def get_vonage_uuid(self, action_input: ActionInput[ParametersType]) -> str:

class TwilioPhoneCallAction(BaseAction[ActionConfigType, ParametersType, ResponseType]):
def create_phone_call_action_input(
self, conversation_id: str, params: Dict[str, Any], twilio_sid: str
self,
conversation_id: str,
params: Dict[str, Any],
twilio_sid: str,
user_message_tracker: asyncio.Event,
) -> TwilioPhoneCallActionInput[ParametersType]:
if "user_message" in params:
del params["user_message"]
Expand All @@ -39,6 +49,7 @@ def create_phone_call_action_input(
conversation_id=conversation_id,
params=self.parameters_type(**params),
twilio_sid=twilio_sid,
user_message_tracker=user_message_tracker,
)

def get_twilio_sid(self, action_input: ActionInput[ParametersType]) -> str:
Expand Down
19 changes: 14 additions & 5 deletions vocode/streaming/agent/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,7 @@ async def handle_generate_response(
# TODO: implement should_stop for generate_responses
agent_span.end()
if function_call and self.agent_config.actions is not None:
self.call_function(function_call, agent_input)
await self.call_function(function_call, agent_input)
return False

async def handle_respond(
Expand Down Expand Up @@ -346,7 +346,7 @@ def _get_action_config(self, function_name: str) -> Optional[ActionConfig]:
return action_config
return None

def call_function(self, function_call: FunctionCall, agent_input: AgentInput):
async def call_function(self, function_call: FunctionCall, agent_input: AgentInput):
action_config = self._get_action_config(function_call.name)
if action_config is None:
self.logger.error(
Expand All @@ -357,28 +357,37 @@ def call_function(self, function_call: FunctionCall, agent_input: AgentInput):
params = json.loads(function_call.arguments)
if "user_message" in params:
user_message = params["user_message"]
user_message_tracker = asyncio.Event()
self.produce_interruptible_agent_response_event_nonblocking(
AgentResponseMessage(message=BaseMessage(text=user_message))
AgentResponseMessage(message=BaseMessage(text=user_message)),
agent_response_tracker=user_message_tracker,
)
action_input: ActionInput
if isinstance(action, VonagePhoneCallAction):
assert (
agent_input.vonage_uuid is not None
), "Cannot use VonagePhoneCallActionFactory unless the attached conversation is a VonageCall"
action_input = action.create_phone_call_action_input(
agent_input.conversation_id, params, agent_input.vonage_uuid
agent_input.conversation_id,
params,
agent_input.vonage_uuid,
user_message_tracker,
)
elif isinstance(action, TwilioPhoneCallAction):
assert (
agent_input.twilio_sid is not None
), "Cannot use TwilioPhoneCallActionFactory unless the attached conversation is a TwilioCall"
action_input = action.create_phone_call_action_input(
agent_input.conversation_id, params, agent_input.twilio_sid
agent_input.conversation_id,
params,
agent_input.twilio_sid,
user_message_tracker,
)
else:
action_input = action.create_action_input(
agent_input.conversation_id,
params,
user_message_tracker,
)
event = self.interruptible_event_factory.create_interruptible_event(
action_input, is_interruptible=action.is_interruptible
Expand Down
5 changes: 5 additions & 0 deletions vocode/streaming/models/actions.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from enum import Enum
from typing import Generic, TypeVar
from pydantic import BaseModel
Expand All @@ -20,6 +21,10 @@ class ActionInput(BaseModel, Generic[ParametersType]):
action_config: ActionConfig
conversation_id: str
params: ParametersType
user_message_tracker: asyncio.Event

class Config:
arbitrary_types_allowed = True


class FunctionFragment(BaseModel):
Expand Down

0 comments on commit cb69ca8

Please sign in to comment.