diff --git a/playground/streaming/agent/chat.py b/playground/streaming/agent/chat.py index 2685c0fdd..d539ac474 100644 --- a/playground/streaming/agent/chat.py +++ b/playground/streaming/agent/chat.py @@ -2,8 +2,16 @@ import typing from dotenv import load_dotenv from playground.streaming.tracing_utils import make_parser_and_maybe_trace +from pydantic import BaseModel +from vocode.streaming.action.base_action import BaseAction +from vocode.streaming.action.factory import ActionFactory from vocode.streaming.action.worker import ActionsWorker -from vocode.streaming.models.actions import ActionType +from vocode.streaming.models.actions import ( + ActionConfig, + ActionInput, + ActionOutput, + ActionType, +) from vocode.streaming.models.agent import ChatGPTAgentConfig from vocode.streaming.models.transcript import Transcript from vocode.streaming.utils.state_manager import ConversationStateManager @@ -22,6 +30,46 @@ from vocode.streaming.utils import create_conversation_id +class ShoutActionConfig(ActionConfig): + type = "shout" + num_exclamation_marks: int + + +class ShoutActionParameters(BaseModel): + name: str + + +class ShoutActionResponse(BaseModel): + success: bool + + +class ShoutAction( + BaseAction[ShoutActionConfig, ShoutActionParameters, ShoutActionResponse] +): + description: str = "Shouts someone's name" + parameters_type: typing.Type[ShoutActionParameters] = ShoutActionParameters + response_type: typing.Type[ShoutActionResponse] = ShoutActionResponse + + async def run( + self, action_input: ActionInput[ShoutActionParameters] + ) -> ActionOutput[ShoutActionResponse]: + print( + f"HI THERE {action_input.params.name}{self.action_config.num_exclamation_marks * '!'}" + ) + return ActionOutput( + action_type=self.action_config.type, + response=ShoutActionResponse(success=True), + ) + + +class ShoutActionFactory(ActionFactory): + def create_action(self, action_config: ActionConfig) -> BaseAction: + if isinstance(action_config, ShoutActionConfig): + return ShoutAction(action_config, should_respond=True) + else: + raise Exception("Invalid action type") + + class DummyConversationManager(ConversationStateManager): pass @@ -96,9 +144,12 @@ async def agent_main(): # Replace with your agent! agent = ChatGPTAgent( ChatGPTAgentConfig( - prompt_preamble="the assistant is ready to help you send emails", - actions=[ActionType.NYLAS_SEND_EMAIL.value], - ) + prompt_preamble="have a conversation", + actions=[ + ShoutActionConfig(num_exclamation_marks=3), + ], + ), + action_factory=ShoutActionFactory(), ) agent.attach_conversation_state_manager(DummyConversationManager(conversation=None)) agent.attach_transcript(transcript) diff --git a/vocode/streaming/action/base_action.py b/vocode/streaming/action/base_action.py index 81324b8d6..201b6ac63 100644 --- a/vocode/streaming/action/base_action.py +++ b/vocode/streaming/action/base_action.py @@ -1,6 +1,7 @@ -from typing import Any, Dict, Generic, Optional, Type +from typing import Any, Dict, Generic, Type, TypeVar from vocode.streaming.action.utils import exclude_keys_recursive from vocode.streaming.models.actions import ( + ActionConfig, ActionInput, ActionOutput, ActionType, @@ -9,12 +10,20 @@ ) from vocode.streaming.utils.state_manager import ConversationStateManager +ActionConfigType = TypeVar("ActionConfigType", bound=ActionConfig) -class BaseAction(Generic[ParametersType, ResponseType]): + +class BaseAction(Generic[ActionConfigType, ParametersType, ResponseType]): description: str = "" - action_type: str = ActionType.BASE.value - def __init__(self, should_respond: bool = False, quiet: bool = False, is_interruptible: bool = True): + def __init__( + self, + action_config: ActionConfigType, + should_respond: bool = False, + quiet: bool = False, + is_interruptible: bool = True, + ): + self.action_config = action_config self.should_respond = should_respond self.quiet = quiet self.is_interruptible = is_interruptible @@ -47,7 +56,7 @@ def get_openai_function(self): parameters_schema["required"].append("user_message") return { - "name": self.action_type, + "name": self.action_config.type, "description": self.description, "parameters": parameters_schema, } @@ -60,7 +69,7 @@ def create_action_input( if "user_message" in params: del params["user_message"] return ActionInput( - action_type=self.action_type, + action_config=self.action_config, conversation_id=conversation_id, params=self.parameters_type(**params), ) diff --git a/vocode/streaming/action/factory.py b/vocode/streaming/action/factory.py index dd7eb81fc..053acf52d 100644 --- a/vocode/streaming/action/factory.py +++ b/vocode/streaming/action/factory.py @@ -1,11 +1,14 @@ from vocode.streaming.action.base_action import BaseAction -from vocode.streaming.action.nylas_send_email import NylasSendEmail -from vocode.streaming.models.actions import ActionType +from vocode.streaming.action.nylas_send_email import ( + NylasSendEmail, + NylasSendEmailActionConfig, +) +from vocode.streaming.models.actions import ActionConfig class ActionFactory: - def create_action(self, action_type: str) -> BaseAction: - if action_type == ActionType.NYLAS_SEND_EMAIL: - return NylasSendEmail(should_respond=True) + def create_action(self, action_config: ActionConfig) -> BaseAction: + if isinstance(action_config, NylasSendEmailActionConfig): + return NylasSendEmail(action_config, should_respond=True) else: raise Exception("Invalid action type") diff --git a/vocode/streaming/action/nylas_send_email.py b/vocode/streaming/action/nylas_send_email.py index 654a98a18..004b48d8b 100644 --- a/vocode/streaming/action/nylas_send_email.py +++ b/vocode/streaming/action/nylas_send_email.py @@ -2,7 +2,16 @@ from pydantic import BaseModel, Field import os from vocode.streaming.action.base_action import BaseAction -from vocode.streaming.models.actions import ActionInput, ActionOutput, ActionType +from vocode.streaming.models.actions import ( + ActionConfig, + ActionInput, + ActionOutput, + ActionType, +) + + +class NylasSendEmailActionConfig(ActionConfig): + type = ActionType.NYLAS_SEND_EMAIL class NylasSendEmailParameters(BaseModel): @@ -15,9 +24,12 @@ class NylasSendEmailResponse(BaseModel): success: bool -class NylasSendEmail(BaseAction[NylasSendEmailParameters, NylasSendEmailResponse]): +class NylasSendEmail( + BaseAction[ + NylasSendEmailActionConfig, NylasSendEmailParameters, NylasSendEmailResponse + ] +): description: str = "Sends an email using Nylas API." - action_type: str = ActionType.NYLAS_SEND_EMAIL.value parameters_type: Type[NylasSendEmailParameters] = NylasSendEmailParameters response_type: Type[NylasSendEmailResponse] = NylasSendEmailResponse @@ -45,6 +57,6 @@ async def run( draft.send() return ActionOutput( - action_type=action_input.action_type, + action_type=self.action_config.type, response=NylasSendEmailResponse(success=True), ) diff --git a/vocode/streaming/action/phone_call_action.py b/vocode/streaming/action/phone_call_action.py index 18b846f73..1dc2efc2f 100644 --- a/vocode/streaming/action/phone_call_action.py +++ b/vocode/streaming/action/phone_call_action.py @@ -1,5 +1,5 @@ from typing import Dict, Any -from vocode.streaming.action.base_action import BaseAction +from vocode.streaming.action.base_action import ActionConfigType, BaseAction from vocode.streaming.models.actions import ( ActionInput, ActionOutput, @@ -10,14 +10,14 @@ ) -class VonagePhoneCallAction(BaseAction[ParametersType, ResponseType]): +class VonagePhoneCallAction(BaseAction[ActionConfigType, ParametersType, ResponseType]): def create_phone_call_action_input( self, conversation_id: str, params: Dict[str, Any], vonage_uuid: str ) -> VonagePhoneCallActionInput[ParametersType]: if "user_message" in params: del params["user_message"] return VonagePhoneCallActionInput( - action_type=self.action_type, + action_config=self.action_config, conversation_id=conversation_id, params=self.parameters_type(**params), vonage_uuid=vonage_uuid, @@ -28,14 +28,14 @@ def get_vonage_uuid(self, action_input: ActionInput[ParametersType]) -> str: return action_input.vonage_uuid -class TwilioPhoneCallAction(BaseAction[ParametersType, ResponseType]): +class TwilioPhoneCallAction(BaseAction[ActionConfigType, ParametersType, ResponseType]): def create_phone_call_action_input( self, conversation_id: str, params: Dict[str, Any], twilio_sid: str ) -> TwilioPhoneCallActionInput[ParametersType]: if "user_message" in params: del params["user_message"] return TwilioPhoneCallActionInput( - action_type=self.action_type, + action_config=self.action_config, conversation_id=conversation_id, params=self.parameters_type(**params), twilio_sid=twilio_sid, diff --git a/vocode/streaming/action/worker.py b/vocode/streaming/action/worker.py index b47d3b5f8..4bc5f309d 100644 --- a/vocode/streaming/action/worker.py +++ b/vocode/streaming/action/worker.py @@ -38,7 +38,7 @@ def attach_conversation_state_manager( async def process(self, item: InterruptibleEvent[ActionInput]): action_input = item.payload - action = self.action_factory.create_action(action_input.action_type) + action = self.action_factory.create_action(action_input.action_config) action.attach_conversation_state_manager(self.conversation_state_manager) action_output = await action.run(action_input) self.produce_interruptible_event_nonblocking( diff --git a/vocode/streaming/agent/base_agent.py b/vocode/streaming/agent/base_agent.py index 97e7851c6..42a56fe1f 100644 --- a/vocode/streaming/agent/base_agent.py +++ b/vocode/streaming/agent/base_agent.py @@ -15,6 +15,7 @@ VonagePhoneCallAction, ) from vocode.streaming.models.actions import ( + ActionConfig, ActionInput, ActionOutput, FunctionCall, @@ -315,8 +316,22 @@ async def process(self, item: InterruptibleEvent[AgentInput]): except asyncio.CancelledError: pass + def _get_action_config(self, function_name: str) -> Optional[ActionConfig]: + if self.agent_config.actions is None: + return None + for action_config in self.agent_config.actions: + if action_config.type == function_name: + return action_config + return None + def call_function(self, function_call: FunctionCall, agent_input: AgentInput): - action = self.action_factory.create_action(function_call.name) + action_config = self._get_action_config(function_call.name) + if action_config is None: + self.logger.error( + f"Function {function_call.name} not found in agent config, skipping" + ) + return + action = self.action_factory.create_action(action_config) params = json.loads(function_call.arguments) if "user_message" in params: user_message = params["user_message"] diff --git a/vocode/streaming/agent/chat_gpt_agent.py b/vocode/streaming/agent/chat_gpt_agent.py index 8892cb73f..3f0cc68c3 100644 --- a/vocode/streaming/agent/chat_gpt_agent.py +++ b/vocode/streaming/agent/chat_gpt_agent.py @@ -64,8 +64,8 @@ def get_functions(self): if not self.action_factory: return None return [ - self.action_factory.create_action(action_type).get_openai_function() - for action_type in self.agent_config.actions + self.action_factory.create_action(action_config).get_openai_function() + for action_config in self.agent_config.actions ] def get_chat_parameters(self, messages: Optional[List] = None): diff --git a/vocode/streaming/models/actions.py b/vocode/streaming/models/actions.py index 7a5bf1e82..ba57c4674 100644 --- a/vocode/streaming/models/actions.py +++ b/vocode/streaming/models/actions.py @@ -3,6 +3,10 @@ from pydantic import BaseModel +class ActionConfig(BaseModel): + type: str + + class ActionType(str, Enum): BASE = "action_base" NYLAS_SEND_EMAIL = "action_nylas_send_email" @@ -12,17 +16,20 @@ class ActionType(str, Enum): class ActionInput(BaseModel, Generic[ParametersType]): - action_type: str + action_config: ActionConfig conversation_id: str params: ParametersType + class FunctionFragment(BaseModel): - name: str - arguments: str + name: str + arguments: str + class FunctionCall(BaseModel): - name: str - arguments: str + name: str + arguments: str + class VonagePhoneCallActionInput(ActionInput[ParametersType]): vonage_uuid: str diff --git a/vocode/streaming/models/agent.py b/vocode/streaming/models/agent.py index 0c7c51e45..a4d6b8730 100644 --- a/vocode/streaming/models/agent.py +++ b/vocode/streaming/models/agent.py @@ -2,6 +2,7 @@ from enum import Enum from pydantic import validator +from vocode.streaming.models.actions import ActionConfig from vocode.streaming.models.message import BaseMessage from .model import TypedModel, BaseModel @@ -68,7 +69,8 @@ class AgentConfig(TypedModel, type=AgentType.BASE.value): send_filler_audio: Union[bool, FillerAudioConfig] = False webhook_config: Optional[WebhookConfig] = None track_bot_sentiment: bool = False - actions: Optional[List[str]] = None + actions: Optional[List[ActionConfig]] = None + class CutOffResponse(BaseModel): messages: List[BaseMessage] = [BaseMessage(text="Sorry?")] @@ -94,7 +96,6 @@ class ChatGPTAgentConfig(AgentConfig, type=AgentType.CHAT_GPT.value): vector_db_config: Optional[VectorDBConfig] = None - class ChatAnthropicAgentConfig(AgentConfig, type=AgentType.CHAT_ANTHROPIC.value): prompt_preamble: str model_name: str = CHAT_ANTHROPIC_DEFAULT_MODEL_NAME @@ -107,7 +108,6 @@ class ChatVertexAIAgentConfig(AgentConfig, type=AgentType.CHAT_VERTEX_AI.value): generate_responses: bool = False # Google Vertex AI doesn't support streaming - class InformationRetrievalAgentConfig( AgentConfig, type=AgentType.INFORMATION_RETRIEVAL.value ): diff --git a/vocode/streaming/models/transcript.py b/vocode/streaming/models/transcript.py index 556be78c4..5d233eba2 100644 --- a/vocode/streaming/models/transcript.py +++ b/vocode/streaming/models/transcript.py @@ -128,7 +128,7 @@ def add_action_start_log(self, action_input: ActionInput, conversation_id: str): self.event_logs.append( ActionStart( action_input=action_input, - action_type=action_input.action_type, + action_type=action_input.action_config.type, timestamp=timestamp, ) )