Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

create_response and cancel_response APIs for MultimodalAgent #1359

Merged
merged 12 commits into from
Jan 18, 2025
27 changes: 27 additions & 0 deletions examples/multimodal-agent/openai_manual_vad/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# Push to Talk Example

This example demonstrates how to manually control the VAD of the OpenAI realtime agent using LiveKit's [RPC functionality](https://docs.livekit.io/home/client/data/rpc/).

## How It Works

1. The agent sets a `supports-ptt` participant attribute (value `"1"`) when accepting the job, to indicate it supports push-to-talk functionality
2. The agent registers an RPC method `ptt` that handles push/release actions
3. When the button is pressed, the frontend sends an RPC call with `push` payload to interrupt the agent
4. When the button is released, the frontend sends an RPC call with `release` payload to commit the audio buffer

## Frontend Integration

A complete frontend implementation can be found in the [voice-assistant-frontend](https://github.com/livekit-examples/voice-assistant-frontend) repository. The frontend will:

1. Check for the `supports-ptt` attribute on the agent
2. If PTT is supported, enable the push-to-talk button
3. Send RPC calls to the agent when the button is pressed/released

## Running the Example

1. Start the agent:
```bash
python push_to_talk.py dev
```

2. Run the frontend application from [voice-assistant-frontend](https://github.com/livekit-examples/voice-assistant-frontend)
75 changes: 75 additions & 0 deletions examples/multimodal-agent/openai_manual_vad/push_to_talk.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from __future__ import annotations

import asyncio
import logging

from dotenv import load_dotenv
from livekit.agents import (
JobContext,
JobRequest,
WorkerOptions,
cli,
llm,
multimodal,
)
from livekit.plugins import openai

load_dotenv()

logger = logging.getLogger("my-worker")
logger.setLevel(logging.INFO)


async def entrypoint(ctx: JobContext):
    """Connect to the room and run a push-to-talk multimodal agent.

    Server-side VAD is disabled on the realtime model (``turn_detection=None``);
    turns are driven manually by the ``ptt`` RPC method instead.
    """
    logger.info("starting entrypoint")

    await ctx.connect()
    participant = await ctx.wait_for_participant()
    agent = multimodal.MultimodalAgent(
        model=openai.realtime.RealtimeModel(
            voice="alloy",
            temperature=0.8,
            instructions="You are a helpful assistant",
            # Disable server-side VAD so the frontend's PTT button controls turns.
            turn_detection=None,
        ),
    )
    agent.start(ctx.room, participant)

    # Hold strong references to fire-and-forget tasks: the event loop keeps
    # only weak references, so an unreferenced task may be garbage-collected
    # before it finishes (see asyncio.create_task documentation).
    pending_tasks: set[asyncio.Task] = set()

    @ctx.room.local_participant.register_rpc_method("ptt")
    async def handle_ptt(data):
        # "push" interrupts the agent; "release" commits the audio buffer
        # and asks for a reply (cancelling any reply already in flight).
        logger.info(f"Received PTT action: {data.payload}")
        if data.payload == "push":
            agent.interrupt()
        elif data.payload == "release":
            agent.generate_reply(on_duplicate="cancel_existing")
        return "ok"

    @agent.on("agent_speech_committed")
    @agent.on("agent_speech_interrupted")
    def _on_agent_speech_created(msg: llm.ChatMessage):
        # example of truncating the chat context
        max_ctx_len = 10
        chat_ctx = agent.chat_ctx_copy()
        if len(chat_ctx.messages) > max_ctx_len:
            chat_ctx.messages = chat_ctx.messages[-max_ctx_len:]
        # NOTE: The `set_chat_ctx` function will attempt to synchronize changes made
        # to the local chat context with the server instead of completely replacing it,
        # provided that the message IDs are consistent.
        task = asyncio.create_task(agent.set_chat_ctx(chat_ctx))
        pending_tasks.add(task)
        task.add_done_callback(pending_tasks.discard)


async def handle_request(request: JobRequest) -> None:
    """Accept the job, advertising push-to-talk support to the frontend."""
    # The frontend checks for "supports-ptt" before enabling the PTT button.
    agent_attributes = {"supports-ptt": "1"}
    await request.accept(
        identity="ptt-agent",
        attributes=agent_attributes,
    )


if __name__ == "__main__":
    # Register both the entrypoint and the request handler so the worker
    # advertises PTT support when it accepts a job.
    worker_options = WorkerOptions(
        entrypoint_fnc=entrypoint,
        request_fnc=handle_request,
    )
    cli.run_app(worker_options)
28 changes: 26 additions & 2 deletions livekit-agents/livekit/agents/multimodal/multimodal_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
"user_stopped_speaking",
"agent_started_speaking",
"agent_stopped_speaking",
"input_speech_committed",
"user_speech_committed",
"agent_speech_committed",
"agent_speech_interrupted",
Expand Down Expand Up @@ -97,7 +96,20 @@ def _push_audio(self, frame: rtc.AudioFrame) -> None: ...
def fnc_ctx(self) -> llm.FunctionContext | None: ...
@fnc_ctx.setter
def fnc_ctx(self, value: llm.FunctionContext | None) -> None: ...

# Return a copy of the session's chat context.
def chat_ctx_copy(self) -> llm.ChatContext: ...

# Abort the model response currently in flight, if any.
def cancel_response(self) -> None: ...

# Ask the model to generate a response; `on_duplicate` selects how a
# response already in flight is handled.
def create_response(
    self,
    on_duplicate: Literal[
        "cancel_existing", "cancel_new", "keep_both"
    ] = "keep_both",
) -> None: ...

# Commit buffered input audio as a completed user turn (needed when
# server-side VAD is disabled).
def commit_audio_buffer(self) -> None: ...

# True when server-side voice activity detection drives turn taking.
@property
def server_vad_enabled(self) -> bool: ...

# Internal hook — presumably recovers when the server produced a
# text-only response for `item_id`; confirm against the implementation.
def _recover_from_text_response(self, item_id: str) -> None: ...
def _update_conversation_item_content(
self,
Expand Down Expand Up @@ -303,7 +315,6 @@ def _input_speech_committed():
alternatives=[stt.SpeechData(language="", text="")],
)
)
self.emit("input_speech_committed")

@self._session.on("input_speech_transcription_completed")
def _input_speech_transcription_completed(ev: _InputTranscriptionProto):
Expand Down Expand Up @@ -349,6 +360,8 @@ def _metrics_collected(metrics: MultimodalLLMMetrics):
self.emit("metrics_collected", metrics)

def interrupt(self) -> None:
self._session.cancel_response()

if self._playing_handle is not None and not self._playing_handle.done():
self._playing_handle.interrupt()

Expand All @@ -360,6 +373,17 @@ def interrupt(self) -> None:
)
self._update_state("listening")

def generate_reply(
    self,
    on_duplicate: Literal[
        "cancel_existing", "cancel_new", "keep_both"
    ] = "cancel_existing",
) -> None:
    """Ask the realtime session to generate a reply from the agent.

    `on_duplicate` controls what happens when a response is already in
    flight (default: cancel the existing one).
    """
    session = self._session
    # Without server-side VAD, pending input audio must be committed
    # manually before a response can be created.
    if not session.server_vad_enabled:
        session.commit_audio_buffer()
    session.create_response(on_duplicate=on_duplicate)

def _update_state(self, state: AgentState, delay: float = 0.0):
"""Set the current state of the agent"""

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,23 @@ def chat_ctx_copy(self) -> llm.ChatContext:
async def set_chat_ctx(self, ctx: llm.ChatContext) -> None:
self._chat_ctx = ctx.copy()

def cancel_response(self) -> None:
    # Base-class default: concrete realtime sessions that can abort an
    # in-flight response must override this.
    raise NotImplementedError("cancel_response is not supported yet")

def create_response(
    self,
    on_duplicate: Literal[
        "cancel_existing", "cancel_new", "keep_both"
    ] = "keep_both",
) -> None:
    # Base-class default: response creation is implemented by concrete
    # realtime sessions; `on_duplicate` selects how an in-flight
    # response is handled by those implementations.
    raise NotImplementedError("create_response is not supported yet")

def commit_audio_buffer(self) -> None:
    # Base-class default: only sessions with manual (non-server-VAD)
    # turn handling need to commit audio; override where supported.
    raise NotImplementedError("commit_audio_buffer is not supported yet")

@property
def server_vad_enabled(self) -> bool:
    """Whether server-side VAD drives turn detection for this session.

    Declared as a property to match the session protocol and its
    consumers — `MultimodalAgent.generate_reply` reads
    `self._session.server_vad_enabled` as an attribute, not a method
    call; without `@property` the bound method is always truthy and the
    manual `commit_audio_buffer` path would never run. Defaults to True
    for sessions that do not override it.
    """
    return True

@utils.log_exceptions(logger=logger)
async def _main_task(self):
@utils.log_exceptions(logger=logger)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1017,6 +1017,25 @@ def _validate_message(msg: llm.ChatMessage) -> bool:
# wait for all the futures to complete
await asyncio.gather(*_futs)

def cancel_response(self) -> None:
    """Cancel the server's in-flight response, if one is active."""
    # Nothing to do when no response is being generated.
    if not self._active_response_id:
        return
    self.response.cancel()

def create_response(
    self,
    on_duplicate: Literal[
        "cancel_existing", "cancel_new", "keep_both"
    ] = "keep_both",
) -> None:
    """Request a new model response from the realtime API.

    `on_duplicate` selects what to do if a response is already in
    flight: cancel the existing one, drop the new request, or keep both.
    """
    response_api = self.response
    response_api.create(on_duplicate=on_duplicate)

def commit_audio_buffer(self) -> None:
    """Commit pending input audio as a finished user turn."""
    audio_buffer = self.input_audio_buffer
    audio_buffer.commit()

@property
def server_vad_enabled(self) -> bool:
    """True when server-side turn detection (VAD) is configured."""
    turn_detection = self._opts.turn_detection
    return turn_detection is not None
davidzhao marked this conversation as resolved.
Show resolved Hide resolved

def _create_empty_user_audio_message(self, duration: float) -> llm.ChatMessage:
"""Create an empty audio message with the given duration."""
samples = int(duration * api_proto.SAMPLE_RATE)
Expand Down
Loading