Skip to content

Commit

Permalink
creates ActionConfigs for preset configuration on Actions (#300)
Browse files Browse the repository at this point in the history
* adds notion of actionconfigs

* fix import

* adds playground action
  • Loading branch information
ajar98 authored Jul 25, 2023
1 parent 7b543f5 commit 84f852f
Show file tree
Hide file tree
Showing 11 changed files with 134 additions and 37 deletions.
59 changes: 55 additions & 4 deletions playground/streaming/agent/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,16 @@
import typing
from dotenv import load_dotenv
from playground.streaming.tracing_utils import make_parser_and_maybe_trace
from pydantic import BaseModel
from vocode.streaming.action.base_action import BaseAction
from vocode.streaming.action.factory import ActionFactory
from vocode.streaming.action.worker import ActionsWorker
from vocode.streaming.models.actions import ActionType
from vocode.streaming.models.actions import (
ActionConfig,
ActionInput,
ActionOutput,
ActionType,
)
from vocode.streaming.models.agent import ChatGPTAgentConfig
from vocode.streaming.models.transcript import Transcript
from vocode.streaming.utils.state_manager import ConversationStateManager
Expand All @@ -22,6 +30,46 @@
from vocode.streaming.utils import create_conversation_id


class ShoutActionConfig(ActionConfig):
type = "shout"
num_exclamation_marks: int


class ShoutActionParameters(BaseModel):
name: str


class ShoutActionResponse(BaseModel):
success: bool


class ShoutAction(
BaseAction[ShoutActionConfig, ShoutActionParameters, ShoutActionResponse]
):
description: str = "Shouts someone's name"
parameters_type: typing.Type[ShoutActionParameters] = ShoutActionParameters
response_type: typing.Type[ShoutActionResponse] = ShoutActionResponse

async def run(
self, action_input: ActionInput[ShoutActionParameters]
) -> ActionOutput[ShoutActionResponse]:
print(
f"HI THERE {action_input.params.name}{self.action_config.num_exclamation_marks * '!'}"
)
return ActionOutput(
action_type=self.action_config.type,
response=ShoutActionResponse(success=True),
)


class ShoutActionFactory(ActionFactory):
def create_action(self, action_config: ActionConfig) -> BaseAction:
if isinstance(action_config, ShoutActionConfig):
return ShoutAction(action_config, should_respond=True)
else:
raise Exception("Invalid action type")


class DummyConversationManager(ConversationStateManager):
pass

Expand Down Expand Up @@ -96,9 +144,12 @@ async def agent_main():
# Replace with your agent!
agent = ChatGPTAgent(
ChatGPTAgentConfig(
prompt_preamble="the assistant is ready to help you send emails",
actions=[ActionType.NYLAS_SEND_EMAIL.value],
)
prompt_preamble="have a conversation",
actions=[
ShoutActionConfig(num_exclamation_marks=3),
],
),
action_factory=ShoutActionFactory(),
)
agent.attach_conversation_state_manager(DummyConversationManager(conversation=None))
agent.attach_transcript(transcript)
Expand Down
21 changes: 15 additions & 6 deletions vocode/streaming/action/base_action.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, Generic, Optional, Type
from typing import Any, Dict, Generic, Type, TypeVar
from vocode.streaming.action.utils import exclude_keys_recursive
from vocode.streaming.models.actions import (
ActionConfig,
ActionInput,
ActionOutput,
ActionType,
Expand All @@ -9,12 +10,20 @@
)
from vocode.streaming.utils.state_manager import ConversationStateManager

ActionConfigType = TypeVar("ActionConfigType", bound=ActionConfig)

class BaseAction(Generic[ParametersType, ResponseType]):

class BaseAction(Generic[ActionConfigType, ParametersType, ResponseType]):
description: str = ""
action_type: str = ActionType.BASE.value

def __init__(self, should_respond: bool = False, quiet: bool = False, is_interruptible: bool = True):
def __init__(
self,
action_config: ActionConfigType,
should_respond: bool = False,
quiet: bool = False,
is_interruptible: bool = True,
):
self.action_config = action_config
self.should_respond = should_respond
self.quiet = quiet
self.is_interruptible = is_interruptible
Expand Down Expand Up @@ -47,7 +56,7 @@ def get_openai_function(self):
parameters_schema["required"].append("user_message")

return {
"name": self.action_type,
"name": self.action_config.type,
"description": self.description,
"parameters": parameters_schema,
}
Expand All @@ -60,7 +69,7 @@ def create_action_input(
if "user_message" in params:
del params["user_message"]
return ActionInput(
action_type=self.action_type,
action_config=self.action_config,
conversation_id=conversation_id,
params=self.parameters_type(**params),
)
Expand Down
13 changes: 8 additions & 5 deletions vocode/streaming/action/factory.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from vocode.streaming.action.base_action import BaseAction
from vocode.streaming.action.nylas_send_email import NylasSendEmail
from vocode.streaming.models.actions import ActionType
from vocode.streaming.action.nylas_send_email import (
NylasSendEmail,
NylasSendEmailActionConfig,
)
from vocode.streaming.models.actions import ActionConfig


class ActionFactory:
def create_action(self, action_type: str) -> BaseAction:
if action_type == ActionType.NYLAS_SEND_EMAIL:
return NylasSendEmail(should_respond=True)
def create_action(self, action_config: ActionConfig) -> BaseAction:
if isinstance(action_config, NylasSendEmailActionConfig):
return NylasSendEmail(action_config, should_respond=True)
else:
raise Exception("Invalid action type")
20 changes: 16 additions & 4 deletions vocode/streaming/action/nylas_send_email.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,16 @@
from pydantic import BaseModel, Field
import os
from vocode.streaming.action.base_action import BaseAction
from vocode.streaming.models.actions import ActionInput, ActionOutput, ActionType
from vocode.streaming.models.actions import (
ActionConfig,
ActionInput,
ActionOutput,
ActionType,
)


class NylasSendEmailActionConfig(ActionConfig):
type = ActionType.NYLAS_SEND_EMAIL


class NylasSendEmailParameters(BaseModel):
Expand All @@ -15,9 +24,12 @@ class NylasSendEmailResponse(BaseModel):
success: bool


class NylasSendEmail(BaseAction[NylasSendEmailParameters, NylasSendEmailResponse]):
class NylasSendEmail(
BaseAction[
NylasSendEmailActionConfig, NylasSendEmailParameters, NylasSendEmailResponse
]
):
description: str = "Sends an email using Nylas API."
action_type: str = ActionType.NYLAS_SEND_EMAIL.value
parameters_type: Type[NylasSendEmailParameters] = NylasSendEmailParameters
response_type: Type[NylasSendEmailResponse] = NylasSendEmailResponse

Expand Down Expand Up @@ -45,6 +57,6 @@ async def run(
draft.send()

return ActionOutput(
action_type=action_input.action_type,
action_type=self.action_config.type,
response=NylasSendEmailResponse(success=True),
)
10 changes: 5 additions & 5 deletions vocode/streaming/action/phone_call_action.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from typing import Dict, Any
from vocode.streaming.action.base_action import BaseAction
from vocode.streaming.action.base_action import ActionConfigType, BaseAction
from vocode.streaming.models.actions import (
ActionInput,
ActionOutput,
Expand All @@ -10,14 +10,14 @@
)


class VonagePhoneCallAction(BaseAction[ParametersType, ResponseType]):
class VonagePhoneCallAction(BaseAction[ActionConfigType, ParametersType, ResponseType]):
def create_phone_call_action_input(
self, conversation_id: str, params: Dict[str, Any], vonage_uuid: str
) -> VonagePhoneCallActionInput[ParametersType]:
if "user_message" in params:
del params["user_message"]
return VonagePhoneCallActionInput(
action_type=self.action_type,
action_config=self.action_config,
conversation_id=conversation_id,
params=self.parameters_type(**params),
vonage_uuid=vonage_uuid,
Expand All @@ -28,14 +28,14 @@ def get_vonage_uuid(self, action_input: ActionInput[ParametersType]) -> str:
return action_input.vonage_uuid


class TwilioPhoneCallAction(BaseAction[ParametersType, ResponseType]):
class TwilioPhoneCallAction(BaseAction[ActionConfigType, ParametersType, ResponseType]):
def create_phone_call_action_input(
self, conversation_id: str, params: Dict[str, Any], twilio_sid: str
) -> TwilioPhoneCallActionInput[ParametersType]:
if "user_message" in params:
del params["user_message"]
return TwilioPhoneCallActionInput(
action_type=self.action_type,
action_config=self.action_config,
conversation_id=conversation_id,
params=self.parameters_type(**params),
twilio_sid=twilio_sid,
Expand Down
2 changes: 1 addition & 1 deletion vocode/streaming/action/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def attach_conversation_state_manager(

async def process(self, item: InterruptibleEvent[ActionInput]):
action_input = item.payload
action = self.action_factory.create_action(action_input.action_type)
action = self.action_factory.create_action(action_input.action_config)
action.attach_conversation_state_manager(self.conversation_state_manager)
action_output = await action.run(action_input)
self.produce_interruptible_event_nonblocking(
Expand Down
17 changes: 16 additions & 1 deletion vocode/streaming/agent/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
VonagePhoneCallAction,
)
from vocode.streaming.models.actions import (
ActionConfig,
ActionInput,
ActionOutput,
FunctionCall,
Expand Down Expand Up @@ -315,8 +316,22 @@ async def process(self, item: InterruptibleEvent[AgentInput]):
except asyncio.CancelledError:
pass

def _get_action_config(self, function_name: str) -> Optional[ActionConfig]:
if self.agent_config.actions is None:
return None
for action_config in self.agent_config.actions:
if action_config.type == function_name:
return action_config
return None

def call_function(self, function_call: FunctionCall, agent_input: AgentInput):
action = self.action_factory.create_action(function_call.name)
action_config = self._get_action_config(function_call.name)
if action_config is None:
self.logger.error(
f"Function {function_call.name} not found in agent config, skipping"
)
return
action = self.action_factory.create_action(action_config)
params = json.loads(function_call.arguments)
if "user_message" in params:
user_message = params["user_message"]
Expand Down
4 changes: 2 additions & 2 deletions vocode/streaming/agent/chat_gpt_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ def get_functions(self):
if not self.action_factory:
return None
return [
self.action_factory.create_action(action_type).get_openai_function()
for action_type in self.agent_config.actions
self.action_factory.create_action(action_config).get_openai_function()
for action_config in self.agent_config.actions
]

def get_chat_parameters(self, messages: Optional[List] = None):
Expand Down
17 changes: 12 additions & 5 deletions vocode/streaming/models/actions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@
from pydantic import BaseModel


class ActionConfig(BaseModel):
type: str


class ActionType(str, Enum):
BASE = "action_base"
NYLAS_SEND_EMAIL = "action_nylas_send_email"
Expand All @@ -12,17 +16,20 @@ class ActionType(str, Enum):


class ActionInput(BaseModel, Generic[ParametersType]):
action_type: str
action_config: ActionConfig
conversation_id: str
params: ParametersType


class FunctionFragment(BaseModel):
name: str
arguments: str
name: str
arguments: str


class FunctionCall(BaseModel):
name: str
arguments: str
name: str
arguments: str


class VonagePhoneCallActionInput(ActionInput[ParametersType]):
vonage_uuid: str
Expand Down
6 changes: 3 additions & 3 deletions vocode/streaming/models/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from enum import Enum

from pydantic import validator
from vocode.streaming.models.actions import ActionConfig

from vocode.streaming.models.message import BaseMessage
from .model import TypedModel, BaseModel
Expand Down Expand Up @@ -68,7 +69,8 @@ class AgentConfig(TypedModel, type=AgentType.BASE.value):
send_filler_audio: Union[bool, FillerAudioConfig] = False
webhook_config: Optional[WebhookConfig] = None
track_bot_sentiment: bool = False
actions: Optional[List[str]] = None
actions: Optional[List[ActionConfig]] = None


class CutOffResponse(BaseModel):
messages: List[BaseMessage] = [BaseMessage(text="Sorry?")]
Expand All @@ -94,7 +96,6 @@ class ChatGPTAgentConfig(AgentConfig, type=AgentType.CHAT_GPT.value):
vector_db_config: Optional[VectorDBConfig] = None



class ChatAnthropicAgentConfig(AgentConfig, type=AgentType.CHAT_ANTHROPIC.value):
prompt_preamble: str
model_name: str = CHAT_ANTHROPIC_DEFAULT_MODEL_NAME
Expand All @@ -107,7 +108,6 @@ class ChatVertexAIAgentConfig(AgentConfig, type=AgentType.CHAT_VERTEX_AI.value):
generate_responses: bool = False # Google Vertex AI doesn't support streaming



class InformationRetrievalAgentConfig(
AgentConfig, type=AgentType.INFORMATION_RETRIEVAL.value
):
Expand Down
2 changes: 1 addition & 1 deletion vocode/streaming/models/transcript.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def add_action_start_log(self, action_input: ActionInput, conversation_id: str):
self.event_logs.append(
ActionStart(
action_input=action_input,
action_type=action_input.action_type,
action_type=action_input.action_config.type,
timestamp=timestamp,
)
)
Expand Down

0 comments on commit 84f852f

Please sign in to comment.