-
Notifications
You must be signed in to change notification settings - Fork 546
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[draft] Support STT with Google realtime API #1321
base: main
Are you sure you want to change the base?
Changes from 16 commits
2ba5598
4889301
06d60ba
4800261
97f5040
f5249f4
61da56a
a88281b
3ac01ac
aee4c1c
f42cbe1
99c4fa7
3911037
7eb6766
e8617e8
a6b378b
c624fba
fc10e5e
9e3b020
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -311,19 +311,34 @@ def _input_speech_transcription_completed(ev: _InputTranscriptionProto): | |
alternatives=[stt.SpeechData(language="", text=ev.transcript)], | ||
) | ||
) | ||
user_msg = ChatMessage.create( | ||
text=ev.transcript, role="user", id=ev.item_id | ||
) | ||
if self._model.capabilities.supports_truncate: | ||
user_msg = ChatMessage.create( | ||
Comment on lines
+314
to
+315
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why is this only done when it supports truncate? it seems you are trying to update an item, instead of truncate? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Some methods are not implemented in Gemini. We maintain remoteconversations in OpenAI, but not in Gemini. We should prevent invoking those methods when using Gemini. The purpose of supports_truncate is to differentiate between that |
||
text=ev.transcript, role="user", id=ev.item_id | ||
) | ||
|
||
self._session._update_conversation_item_content( | ||
ev.item_id, user_msg.content | ||
) | ||
self._session._update_conversation_item_content( | ||
ev.item_id, user_msg.content | ||
) | ||
|
||
self._emit_speech_committed("user", ev.transcript) | ||
|
||
self.emit("user_speech_committed", user_msg) | ||
logger.debug( | ||
"committed user speech", | ||
extra={"user_transcript": ev.transcript}, | ||
@self._session.on("agent_speech_transcription_completed") | ||
def _agent_speech_transcription_completed(ev: _InputTranscriptionProto): | ||
self._agent_stt_forwarder.update( | ||
stt.SpeechEvent( | ||
type=stt.SpeechEventType.FINAL_TRANSCRIPT, | ||
alternatives=[stt.SpeechData(language="", text=ev.transcript)], | ||
) | ||
) | ||
self._emit_speech_committed("agent", ev.transcript) | ||
|
||
# Similar to _input_speech_started, this handles updating the state to "listening" when the agent's speech is complete. | ||
# However, since Gemini doesn't support VAD events, we are not emitting the `user_started_speaking` event here. | ||
@self._session.on("agent_speech_completed") | ||
def _agent_speech_completed(): | ||
self._update_state("listening") | ||
if self._playing_handle is not None and not self._playing_handle.done(): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. can you include comments on why this is needed? |
||
self._playing_handle.interrupt() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Why we should interrupt here? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Because we call this function when speech is interrupted as well. They likely made some changes, but now Gemini returns |
||
|
||
@self._session.on("input_speech_started") | ||
def _input_speech_started(): | ||
|
@@ -365,9 +380,9 @@ async def _run_task(delay: float) -> None: | |
await asyncio.sleep(delay) | ||
|
||
if self._room.isconnected(): | ||
await self._room.local_participant.set_attributes( | ||
{ATTRIBUTE_AGENT_STATE: state} | ||
) | ||
await self._room.local_participant.set_attributes({ | ||
ATTRIBUTE_AGENT_STATE: state | ||
}) | ||
|
||
if self._update_state_task is not None: | ||
self._update_state_task.cancel() | ||
|
@@ -378,6 +393,17 @@ async def _run_task(delay: float) -> None: | |
async def _main_task(self) -> None: | ||
self._update_state("initializing") | ||
self._audio_source = rtc.AudioSource(24000, 1) | ||
track = rtc.LocalAudioTrack.create_audio_track( | ||
"assistant_voice", self._audio_source | ||
) | ||
self._agent_publication = await self._room.local_participant.publish_track( | ||
track, rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE) | ||
) | ||
self._agent_stt_forwarder = transcription.STTSegmentsForwarder( | ||
room=self._room, | ||
participant=self._room.local_participant, | ||
track=track, | ||
) | ||
self._agent_playout = agent_playout.AgentPlayout( | ||
audio_source=self._audio_source | ||
) | ||
|
@@ -395,39 +421,21 @@ def _on_playout_stopped(interrupted: bool) -> None: | |
if interrupted: | ||
collected_text += "..." | ||
|
||
msg = ChatMessage.create( | ||
text=collected_text, | ||
role="assistant", | ||
id=self._playing_handle.item_id, | ||
) | ||
if self._model.capabilities.supports_truncate: | ||
if self._model.capabilities.supports_truncate and collected_text: | ||
msg = ChatMessage.create( | ||
text=collected_text, | ||
role="assistant", | ||
id=self._playing_handle.item_id, | ||
) | ||
self._session._update_conversation_item_content( | ||
self._playing_handle.item_id, msg.content | ||
) | ||
|
||
if interrupted: | ||
self.emit("agent_speech_interrupted", msg) | ||
else: | ||
self.emit("agent_speech_committed", msg) | ||
|
||
logger.debug( | ||
"committed agent speech", | ||
extra={ | ||
"agent_transcript": collected_text, | ||
"interrupted": interrupted, | ||
}, | ||
) | ||
self._emit_speech_committed("agent", collected_text, interrupted) | ||
|
||
self._agent_playout.on("playout_started", _on_playout_started) | ||
self._agent_playout.on("playout_stopped", _on_playout_stopped) | ||
|
||
track = rtc.LocalAudioTrack.create_audio_track( | ||
"assistant_voice", self._audio_source | ||
) | ||
self._agent_publication = await self._room.local_participant.publish_track( | ||
track, rtc.TrackPublishOptions(source=rtc.TrackSource.SOURCE_MICROPHONE) | ||
) | ||
|
||
await self._agent_publication.wait_for_subscription() | ||
|
||
bstream = utils.audio.AudioByteStream( | ||
|
@@ -497,3 +505,22 @@ def _ensure_session(self) -> aiohttp.ClientSession: | |
self._http_session = utils.http_context.http_session() | ||
|
||
return self._http_session | ||
|
||
def _emit_speech_committed( | ||
self, speaker: Literal["user", "agent"], msg: str, interrupted: bool = False | ||
): | ||
if speaker == "user": | ||
self.emit("user_speech_committed", msg) | ||
else: | ||
if interrupted: | ||
self.emit("agent_speech_interrupted", msg) | ||
else: | ||
self.emit("agent_speech_committed", msg) | ||
|
||
logger.debug( | ||
f"committed {speaker} speech", | ||
extra={ | ||
f"{speaker}_transcript": msg, | ||
"interrupted": interrupted, | ||
}, | ||
) |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,22 @@ | ||
from __future__ import annotations | ||
|
||
import inspect | ||
import json | ||
from typing import Any, Dict, List, Literal, Sequence, Union | ||
|
||
from livekit.agents import llm | ||
|
||
from google.genai import types # type: ignore | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The code here is hard to follow, really sad we don't have types (it's unclear what is the structure of the dicts) |
||
|
||
__all__ = [ | ||
"ClientEvents", | ||
"LiveAPIModels", | ||
"ResponseModality", | ||
"Voice", | ||
"_build_gemini_ctx", | ||
"_build_tools", | ||
] | ||
|
||
LiveAPIModels = Literal["gemini-2.0-flash-exp"] | ||
|
||
Voice = Literal["Puck", "Charon", "Kore", "Fenrir", "Aoede"] | ||
|
@@ -77,3 +89,35 @@ def _build_tools(fnc_ctx: Any) -> List[types.FunctionDeclarationDict]: | |
function_declarations.append(func_decl) | ||
|
||
return function_declarations | ||
|
||
|
||
def _build_gemini_ctx(chat_ctx: llm.ChatContext) -> List[types.Content]: | ||
content = None | ||
turns = [] | ||
|
||
for msg in chat_ctx.messages: | ||
role = None | ||
if msg.role == "assistant": | ||
role = "model" | ||
elif msg.role in {"system", "user"}: | ||
role = "user" | ||
elif msg.role == "tool": | ||
continue | ||
|
||
if content and content.role == role: | ||
if isinstance(msg.content, str): | ||
content.parts.append(types.Part(text=msg.content)) | ||
elif isinstance(msg.content, dict): | ||
content.parts.append(types.Part(text=json.dumps(msg.content))) | ||
elif isinstance(msg.content, list): | ||
for item in msg.content: | ||
if isinstance(item, str): | ||
content.parts.append(types.Part(text=item)) | ||
else: | ||
content = types.Content( | ||
parts=[types.Part(text=msg.content)], | ||
role=role, | ||
) | ||
turns.append(content) | ||
|
||
return turns |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
does the last message have to be user.. in order for gemini to respond first?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No, it can be either
assistant
oruser
.