Support Transcriber session with Google realtime API #1321

Merged · 31 commits · Jan 20, 2025
5 changes: 5 additions & 0 deletions .changeset/slimy-candles-agree.md
@@ -0,0 +1,5 @@
---
"livekit-plugins-google": patch
---

support transcriber session for user/agent audio
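In practice, this changeset means committed speech for both speakers is re-emitted by MultimodalAgent (see multimodal_agent.py below). A minimal listener sketch, assuming MultimodalAgent keeps its EventEmitter-style on() decorator and that, after this change, these events carry the plain transcript string passed to _emit_speech_committed:

from livekit.agents import multimodal


def register_transcript_logging(agent: multimodal.MultimodalAgent) -> None:
    # Event names come from multimodal_agent.py in this PR; the str payload is
    # what _emit_speech_committed forwards after this change.
    @agent.on("user_speech_committed")
    def _on_user_speech(transcript: str) -> None:
        print("user:", transcript)

    @agent.on("agent_speech_committed")
    def _on_agent_speech(transcript: str) -> None:
        print("agent:", transcript)

    @agent.on("agent_speech_interrupted")
    def _on_agent_interrupted(transcript: str) -> None:
        print("agent (interrupted):", transcript)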
16 changes: 14 additions & 2 deletions examples/multimodal-agent/gemini_agent.py
@@ -50,13 +50,25 @@ async def get_weather(
await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
participant = await ctx.wait_for_participant()

# chat_ctx serves as the initial context; if provided, the agent will start the conversation first
chat_ctx = llm.ChatContext()
chat_ctx.append(text="What is LiveKit?", role="user")
chat_ctx.append(
text="LiveKit is the platform for building realtime AI. The main use cases are to build AI voice agents. LiveKit also powers livestreaming apps, robotics, and video conferencing.",
role="assistant",
)
chat_ctx.append(text="What is the LiveKit Agents framework?", role="user")

agent = multimodal.MultimodalAgent(
model=google.beta.realtime.RealtimeModel(
voice="Charon",
voice="Puck",
temperature=0.8,
instructions="You are a helpful assistant",
instructions="""
You are a helpful assistant
Here are some helpful information about LiveKit and its products and services:
- LiveKit is the platform for building realtime AI. The main use cases are to build AI voice agents. LiveKit also powers livestreaming apps, robotics, and video conferencing.
- LiveKit provides an Agents framework for building server-side AI agents, client SDKs for building frontends, and LiveKit Cloud is a global network that transports voice, video, and data traffic in realtime.
""",
),
fnc_ctx=fnc_ctx,
chat_ctx=chat_ctx,
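Condensed, the updated example boils down to the sketch below. The imports, the entrypoint wrapper, and the agent.start(...) call are carried over from the rest of the example file (fnc_ctx is omitted here), so treat them as assumptions rather than part of this diff:

from livekit.agents import AutoSubscribe, JobContext, llm, multimodal
from livekit.plugins import google


async def entrypoint(ctx: JobContext) -> None:
    await ctx.connect(auto_subscribe=AutoSubscribe.AUDIO_ONLY)
    participant = await ctx.wait_for_participant()

    # Seed the conversation; when chat_ctx is provided, the agent speaks first.
    chat_ctx = llm.ChatContext()
    chat_ctx.append(text="What is LiveKit?", role="user")
    chat_ctx.append(
        text="LiveKit is the platform for building realtime AI.", role="assistant"
    )
    chat_ctx.append(text="What is the LiveKit Agents framework?", role="user")

    agent = multimodal.MultimodalAgent(
        model=google.beta.realtime.RealtimeModel(
            voice="Puck",
            temperature=0.8,
            instructions="You are a helpful assistant",
        ),
        chat_ctx=chat_ctx,
    )
    agent.start(ctx.room, participant)  # assumed: mirrors the other multimodal examples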
1 change: 0 additions & 1 deletion livekit-agents/livekit/agents/cli/log.py
@@ -19,7 +19,6 @@
"watchfiles",
"anthropic",
"websockets.client",
"root",
]


99 changes: 62 additions & 37 deletions livekit-agents/livekit/agents/multimodal/multimodal_agent.py
@@ -324,19 +324,32 @@ def _input_speech_transcription_completed(ev: _InputTranscriptionProto):
alternatives=[stt.SpeechData(language="", text=ev.transcript)],
)
)
-user_msg = ChatMessage.create(
-    text=ev.transcript, role="user", id=ev.item_id
-)
+if self._model.capabilities.supports_truncate:
+    user_msg = ChatMessage.create(
+        text=ev.transcript, role="user", id=ev.item_id
+    )

-self._session._update_conversation_item_content(
-    ev.item_id, user_msg.content
-)
+    self._session._update_conversation_item_content(
+        ev.item_id, user_msg.content
+    )

-self.emit("user_speech_committed", user_msg)
-logger.debug(
-    "committed user speech",
-    extra={"user_transcript": ev.transcript},
-)
+self._emit_speech_committed("user", ev.transcript)

@self._session.on("agent_speech_transcription_completed")
def _agent_speech_transcription_completed(ev: _InputTranscriptionProto):
self._agent_stt_forwarder.update(
stt.SpeechEvent(
type=stt.SpeechEventType.FINAL_TRANSCRIPT,
alternatives=[stt.SpeechData(language="", text=ev.transcript)],
)
)
self._emit_speech_committed("agent", ev.transcript)

# Similar to _input_speech_started, this updates the state to "listening" once the agent's speech is complete.
# However, since Gemini doesn't support VAD events, we don't emit the `user_started_speaking` event here.
@self._session.on("agent_speech_stopped")
def _agent_speech_stopped():
self.interrupt()

@self._session.on("input_speech_started")
def _input_speech_started():
@@ -360,12 +373,12 @@ def _metrics_collected(metrics: MultimodalLLMMetrics):
self.emit("metrics_collected", metrics)

def interrupt(self) -> None:
-self._session.cancel_response()

if self._playing_handle is not None and not self._playing_handle.done():
self._playing_handle.interrupt()

if self._model.capabilities.supports_truncate:
+self._session.cancel_response()  # Only supported by OpenAI

self._session._truncate_conversation_item(
item_id=self._playing_handle.item_id,
content_index=self._playing_handle.content_index,
@@ -405,6 +418,17 @@ async def _run_task(delay: float) -> None:
async def _main_task(self) -> None:
self._update_state("initializing")
self._audio_source = rtc.AudioSource(24000, 1)
track = rtc.LocalAudioTrack.create_audio_track(
"assistant_voice", self._audio_source
)
self._agent_publication = await self._room.local_participant.publish_track(
track, rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
)
self._agent_stt_forwarder = transcription.STTSegmentsForwarder(
room=self._room,
participant=self._room.local_participant,
track=track,
)
self._agent_playout = agent_playout.AgentPlayout(
audio_source=self._audio_source
)
@@ -422,39 +446,21 @@ def _on_playout_stopped(interrupted: bool) -> None:
if interrupted:
collected_text += "..."

-msg = ChatMessage.create(
-    text=collected_text,
-    role="assistant",
-    id=self._playing_handle.item_id,
-)
-if self._model.capabilities.supports_truncate:
+if self._model.capabilities.supports_truncate and collected_text:
+    msg = ChatMessage.create(
+        text=collected_text,
+        role="assistant",
+        id=self._playing_handle.item_id,
+    )
    self._session._update_conversation_item_content(
        self._playing_handle.item_id, msg.content
    )

-if interrupted:
-    self.emit("agent_speech_interrupted", msg)
-else:
-    self.emit("agent_speech_committed", msg)
-
-logger.debug(
-    "committed agent speech",
-    extra={
-        "agent_transcript": collected_text,
-        "interrupted": interrupted,
-    },
-)
+self._emit_speech_committed("agent", collected_text, interrupted)

self._agent_playout.on("playout_started", _on_playout_started)
self._agent_playout.on("playout_stopped", _on_playout_stopped)

-track = rtc.LocalAudioTrack.create_audio_track(
-    "assistant_voice", self._audio_source
-)
-self._agent_publication = await self._room.local_participant.publish_track(
-    track, rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
-)

await self._agent_publication.wait_for_subscription()

bstream = utils.audio.AudioByteStream(
@@ -524,3 +530,22 @@ def _ensure_session(self) -> aiohttp.ClientSession:
self._http_session = utils.http_context.http_session()

return self._http_session

def _emit_speech_committed(
self, speaker: Literal["user", "agent"], msg: str, interrupted: bool = False
):
if speaker == "user":
self.emit("user_speech_committed", msg)
else:
if interrupted:
self.emit("agent_speech_interrupted", msg)
else:
self.emit("agent_speech_committed", msg)

logger.debug(
f"committed {speaker} speech",
extra={
f"{speaker}_transcript": msg,
"interrupted": interrupted,
},
)
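The agent-side transcription path added in _main_task above publishes the agent's audio track and attaches an STTSegmentsForwarder to it, so Gemini's transcripts of the agent's speech reach the room like regular STT output. A standalone sketch of that pattern, assuming an already-connected rtc.Room (every call mirrors the diff above):

from livekit import rtc
from livekit.agents import stt, transcription


async def setup_agent_transcription(room: rtc.Room) -> transcription.STTSegmentsForwarder:
    # Publish the agent's synthesized audio, exactly as _main_task does.
    audio_source = rtc.AudioSource(24000, 1)
    track = rtc.LocalAudioTrack.create_audio_track("assistant_voice", audio_source)
    await room.local_participant.publish_track(
        track, rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE)
    )

    # Forward transcripts of the agent's own speech to the room.
    forwarder = transcription.STTSegmentsForwarder(
        room=room,
        participant=room.local_participant,
        track=track,
    )

    # Each final transcript from the Gemini transcriber session is pushed as a
    # FINAL_TRANSCRIPT event, as in _agent_speech_transcription_completed above.
    forwarder.update(
        stt.SpeechEvent(
            type=stt.SpeechEventType.FINAL_TRANSCRIPT,
            alternatives=[stt.SpeechData(language="", text="Hello from the agent")],
        )
    )
    return forwarder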
livekit-plugins/livekit-plugins-google/livekit/plugins/google/_utils.py
@@ -3,7 +3,7 @@
import base64
import inspect
import json
-from typing import Any, Dict, List, Optional, Union, cast, get_args, get_origin
+from typing import Any, Dict, List, Optional, get_args, get_origin

from livekit import rtc
from livekit.agents import llm, utils
@@ -88,18 +88,16 @@ def _build_tools(fnc_ctx: Any) -> List[types.FunctionDeclaration]:

def _build_gemini_ctx(
chat_ctx: llm.ChatContext, cache_key: Any
-) -> tuple[
-    Union[types.ContentListUnion, types.ContentListUnionDict], Optional[types.Part]
-]:
+) -> tuple[list[types.Content], Optional[types.Content]]:
turns: list[types.Content] = []
current_content: Optional[types.Content] = None
-system_instruction: Optional[types.Part] = None
+system_instruction: Optional[types.Content] = None
current_role: Optional[str] = None

for msg in chat_ctx.messages:
if msg.role == "system":
if isinstance(msg.content, str):
-system_instruction = types.Part(text=msg.content)
+system_instruction = types.Content(parts=[types.Part(text=msg.content)])
continue

if msg.role == "assistant":
@@ -169,8 +167,7 @@ def _build_gemini_ctx(
_build_gemini_image_part(item, cache_key)
)

-ctx = cast(Union[types.ContentListUnion, types.ContentListUnionDict], turns)
-return ctx, system_instruction
+return turns, system_instruction


def _build_gemini_image_part(image: llm.ChatImage, cache_key: Any) -> types.Part:
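For reference, the reworked _build_gemini_ctx now returns plain lists of types.Content. A small sketch of the new contract; the import path and the cache_key value are assumptions, since the helper is internal to livekit-plugins-google:

from livekit.agents import llm

# Assumed import path for the internal helper shown above.
from livekit.plugins.google._utils import _build_gemini_ctx

chat_ctx = llm.ChatContext()
chat_ctx.append(text="You are a helpful assistant", role="system")
chat_ctx.append(text="What is LiveKit?", role="user")

turns, system_instruction = _build_gemini_ctx(chat_ctx, cache_key="example")
# turns: list[types.Content] holding the user/assistant turns
# system_instruction: Optional[types.Content] (previously Optional[types.Part])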
livekit-plugins/livekit-plugins-google/livekit/plugins/google/beta/realtime/api_proto.py
@@ -4,13 +4,13 @@

from google.genai import types

-from ..._utils import _build_tools
+from ..._utils import _build_gemini_ctx, _build_tools

LiveAPIModels = Literal["gemini-2.0-flash-exp"]

Voice = Literal["Puck", "Charon", "Kore", "Fenrir", "Aoede"]

__all__ = ["_build_tools", "ClientEvents"]
__all__ = ["_build_tools", "ClientEvents", "_build_gemini_ctx"]

ClientEvents = Union[
types.ContentListUnion,