Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allows the conversation state manager to wait on utterances sent to the call #308

Merged
Merged 3 commits on Jul 27, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion playground/streaming/agent/chat.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ async def sender():
None, lambda: input("Human: ")
)
agent.consume_nonblocking(
agent.interruptible_event_factory.create(
agent.interruptible_event_factory.create_interruptible_event(
TranscriptionAgentInput(
transcription=Transcription(
message=message, confidence=1.0, is_final=True
Expand Down
4 changes: 3 additions & 1 deletion playground/streaming/benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,9 @@ async def run_agents():
),
conversation_id=0,
)
agent.consume_nonblocking(agent.interruptible_event_factory.create(message))
agent.consume_nonblocking(
agent.interruptible_event_factory.create_interruptible_event(message)
)

while True:
try:
Expand Down
8 changes: 5 additions & 3 deletions vocode/streaming/action/base_action.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Dict, Generic, Type, TypeVar
from typing import Any, Dict, Generic, Type, TypeVar, TYPE_CHECKING
from vocode.streaming.action.utils import exclude_keys_recursive
from vocode.streaming.models.actions import (
ActionConfig,
Expand All @@ -8,7 +8,9 @@
ParametersType,
ResponseType,
)
from vocode.streaming.utils.state_manager import ConversationStateManager

if TYPE_CHECKING:
from vocode.streaming.utils.state_manager import ConversationStateManager

ActionConfigType = TypeVar("ActionConfigType", bound=ActionConfig)

Expand All @@ -29,7 +31,7 @@ def __init__(
self.is_interruptible = is_interruptible

def attach_conversation_state_manager(
self, conversation_state_manager: ConversationStateManager
self, conversation_state_manager: "ConversationStateManager"
):
self.conversation_state_manager = conversation_state_manager

Expand Down
42 changes: 30 additions & 12 deletions vocode/streaming/agent/base_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,16 @@
import json
import logging
import random
from typing import AsyncGenerator, Generator, Generic, Optional, Tuple, TypeVar, Union
from typing import (
AsyncGenerator,
Generator,
Generic,
Optional,
Tuple,
TypeVar,
Union,
TYPE_CHECKING,
)
import typing
from opentelemetry import trace
from opentelemetry.trace import Span
Expand Down Expand Up @@ -33,13 +42,16 @@
from vocode.streaming.utils import remove_non_letters_digits
from vocode.streaming.utils.goodbye_model import GoodbyeModel
from vocode.streaming.models.transcript import Transcript
from vocode.streaming.utils.state_manager import ConversationStateManager
from vocode.streaming.utils.worker import (
InterruptibleAgentResponseEvent,
InterruptibleEvent,
InterruptibleEventFactory,
InterruptibleWorker,
)

if TYPE_CHECKING:
from vocode.streaming.utils.state_manager import ConversationStateManager

tracer = trace.get_tracer(__name__)
AGENT_TRACE_NAME = "agent"

Expand Down Expand Up @@ -128,7 +140,7 @@ def __init__(
InterruptibleEvent[AgentInput]
] = asyncio.Queue()
self.output_queue: asyncio.Queue[
InterruptibleEvent[AgentResponse]
InterruptibleAgentResponseEvent[AgentResponse]
] = asyncio.Queue()
AbstractAgent.__init__(self, agent_config=agent_config)
InterruptibleWorker.__init__(
Expand Down Expand Up @@ -167,7 +179,7 @@ def attach_conversation_state_manager(
def start(self):
super().start()
if self.agent_config.initial_message is not None:
self.produce_interruptible_event_nonblocking(
self.produce_interruptible_agent_response_event_nonblocking(
AgentResponseMessage(message=self.agent_config.initial_message),
is_interruptible=False,
)
Expand All @@ -180,7 +192,9 @@ def get_input_queue(
) -> asyncio.Queue[InterruptibleEvent[AgentInput]]:
return self.input_queue

def get_output_queue(self) -> asyncio.Queue[InterruptibleEvent[AgentResponse]]:
def get_output_queue(
self,
) -> asyncio.Queue[InterruptibleAgentResponseEvent[AgentResponse]]:
return self.output_queue

def create_goodbye_detection_task(self, message: str) -> asyncio.Task:
Expand Down Expand Up @@ -214,7 +228,7 @@ async def handle_generate_response(
if is_first_response:
agent_span_first.end()
is_first_response = False
self.produce_interruptible_event_nonblocking(
self.produce_interruptible_agent_response_event_nonblocking(
AgentResponseMessage(message=BaseMessage(text=response)),
is_interruptible=self.agent_config.allow_agent_to_be_cut_off,
)
Expand All @@ -240,7 +254,7 @@ async def handle_respond(
response = None
return True
if response:
self.produce_interruptible_event_nonblocking(
self.produce_interruptible_agent_response_event_nonblocking(
AgentResponseMessage(message=BaseMessage(text=response)),
is_interruptible=self.agent_config.allow_agent_to_be_cut_off,
)
Expand Down Expand Up @@ -288,7 +302,9 @@ async def process(self, item: InterruptibleEvent[AgentInput]):
transcription.message
)
if self.agent_config.send_filler_audio:
self.produce_interruptible_event_nonblocking(AgentResponseFillerAudio())
self.produce_interruptible_agent_response_event_nonblocking(
AgentResponseFillerAudio()
)
self.logger.debug("Responding to transcription")
should_stop = False
if self.agent_config.generate_responses:
Expand All @@ -302,7 +318,9 @@ async def process(self, item: InterruptibleEvent[AgentInput]):

if should_stop:
self.logger.debug("Agent requested to stop")
self.produce_interruptible_event_nonblocking(AgentResponseStop())
self.produce_interruptible_agent_response_event_nonblocking(
AgentResponseStop()
)
return
if goodbye_detected_task:
try:
Expand All @@ -311,7 +329,7 @@ async def process(self, item: InterruptibleEvent[AgentInput]):
)
if goodbye_detected:
self.logger.debug("Goodbye detected, ending conversation")
self.produce_interruptible_event_nonblocking(
self.produce_interruptible_agent_response_event_nonblocking(
AgentResponseStop()
)
return
Expand Down Expand Up @@ -339,7 +357,7 @@ def call_function(self, function_call: FunctionCall, agent_input: AgentInput):
params = json.loads(function_call.arguments)
if "user_message" in params:
user_message = params["user_message"]
self.produce_interruptible_event_nonblocking(
self.produce_interruptible_agent_response_event_nonblocking(
AgentResponseMessage(message=BaseMessage(text=user_message))
)
action_input: ActionInput
Expand All @@ -362,7 +380,7 @@ def call_function(self, function_call: FunctionCall, agent_input: AgentInput):
agent_input.conversation_id,
params,
)
event = self.interruptible_event_factory.create(
event = self.interruptible_event_factory.create_interruptible_event(
action_input, is_interruptible=action.is_interruptible
)
assert self.transcript is not None
Expand Down
11 changes: 7 additions & 4 deletions vocode/streaming/agent/websocket_user_implemented_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,10 @@
import logging
from typing import Dict
from vocode.streaming.transcriber.base_transcriber import Transcription
from vocode.streaming.utils.worker import InterruptibleEvent
from vocode.streaming.utils.worker import (
InterruptibleAgentResponseEvent,
InterruptibleEvent,
)
import websockets
from websockets.client import (
connect,
Expand Down Expand Up @@ -33,7 +36,7 @@

class WebSocketUserImplementedAgent(BaseAgent[WebSocketUserImplementedAgentConfig]):
input_queue: asyncio.Queue[InterruptibleEvent[AgentInput]]
output_queue: asyncio.Queue[InterruptibleEvent[AgentResponse]]
output_queue: asyncio.Queue[InterruptibleAgentResponseEvent[AgentResponse]]

def __init__(
self,
Expand Down Expand Up @@ -75,7 +78,7 @@ def _handle_incoming_socket_message(self, message: WebSocketAgentMessage) -> Non
raise Exception("Unknown Socket message type")

self.logger.info("Putting interruptible agent response event in output queue")
self.produce_interruptible_event_nonblocking(
self.produce_interruptible_agent_response_event_nonblocking(
agent_response, self.get_agent_config().allow_agent_to_be_cut_off
)

Expand Down Expand Up @@ -161,5 +164,5 @@ async def receiver(ws: WebSocketClientProtocol) -> None:
await asyncio.gather(sender(ws), receiver(ws))

def terminate(self):
self.output_queue.put_nowait(AgentResponseStop())
self.produce_interruptible_agent_response_event_nonblocking(AgentResponseStop())
super().terminate()
Loading
Loading