Skip to content

Commit

Permalink
Merge pull request #257 from pipecat-ai/aleix/cancel-all-tasks-when-i…
Browse files Browse the repository at this point in the history
…nterrutpted

cancel all tasks when interrupted
  • Loading branch information
aconchillo authored Jun 25, 2024
2 parents 83d1931 + 38aee7d commit 84074e9
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 12 deletions.
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,10 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Fixed

- Fixed an issue with asynchronous STT services (Deepgram and Azure) that could
cause static audio issues and interruptions to not work properly when dealing
with multiple LLMs sentences.

- Fixed an issue that could mix new LLM responses with previous ones when
handling interruptions.

Expand Down
43 changes: 40 additions & 3 deletions src/pipecat/services/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,18 @@
from PIL import Image
from typing import AsyncGenerator

from pipecat.frames.frames import AudioRawFrame, CancelFrame, EndFrame, ErrorFrame, Frame, StartFrame, SystemFrame, TranscriptionFrame, URLImageRawFrame
from pipecat.frames.frames import (
AudioRawFrame,
CancelFrame,
EndFrame,
ErrorFrame,
Frame,
StartFrame,
StartInterruptionFrame,
StopInterruptionFrame,
SystemFrame,
TranscriptionFrame,
URLImageRawFrame)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import AIService, TTSService, ImageGenService
from pipecat.services.openai import BaseOpenAILLMService
Expand All @@ -34,7 +45,7 @@
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use Azure TTS, you need to `pip install pipecat-ai[azure]`. Also, set `AZURE_SPEECH_API_KEY` and `AZURE_SPEECH_REGION` environment variables.")
"In order to use Azure, you need to `pip install pipecat-ai[azure]`. Also, set `AZURE_SPEECH_API_KEY` and `AZURE_SPEECH_REGION` environment variables.")
raise Exception(f"Missing module: {e}")


Expand Down Expand Up @@ -123,12 +134,18 @@ def __init__(
speech_config=speech_config, audio_config=audio_config)
self._speech_recognizer.recognized.connect(self._on_handle_recognized)

# This event will be used to ignore out-of-band transcriptions while we
# are interrupted.
self._is_interrupted_event = asyncio.Event()

self._create_push_task()

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

if isinstance(frame, SystemFrame):
if isinstance(frame, StartInterruptionFrame) or isinstance(frame, StopInterruptionFrame):
await self._handle_interruptions(frame)
elif isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
elif isinstance(frame, AudioRawFrame):
self._audio_stream.write(frame.audio)
Expand All @@ -148,6 +165,23 @@ async def cancel(self, frame: CancelFrame):
self._push_frame_task.cancel()
await self._push_frame_task

async def _handle_interruptions(self, frame: Frame):
    """Handle start/stop interruption frames for the STT push pipeline.

    On StartInterruptionFrame: mark the interrupted state (so out-of-band
    transcriptions get dropped), tear down the ordered push task so no
    stale frames reach downstream, forward the interruption frame
    immediately out-of-band, then start a fresh queue and task.
    On StopInterruptionFrame: clear the interrupted state so
    transcriptions flow again.
    """
    if isinstance(frame, StopInterruptionFrame):
        # Interruption is over; we can accept transcriptions again.
        self._is_interrupted_event.clear()
        return

    if isinstance(frame, StartInterruptionFrame):
        # While this event is set, out-of-band transcriptions are ignored.
        self._is_interrupted_event.set()
        # Cancel the push task; this stops frames flowing downstream.
        self._push_frame_task.cancel()
        await self._push_frame_task
        # Deliver the interruption frame immediately, bypassing the
        # (now torn down) ordered push queue.
        await self.push_frame(frame)
        # Fresh queue and task for subsequent frames.
        self._create_push_task()

def _create_push_task(self):
    """(Re)create the frame queue and the task that drains it downstream."""
    loop = self.get_event_loop()
    self._push_queue = asyncio.Queue()
    self._push_frame_task = loop.create_task(self._push_frame_task_handler())
Expand All @@ -163,6 +197,9 @@ async def _push_frame_task_handler(self):
break

def _on_handle_recognized(self, event):
if self._is_interrupted_event.is_set():
return

if event.result.reason == ResultReason.RecognizedSpeech and len(event.result.text) > 0:
direction = FrameDirection.DOWNSTREAM
frame = TranscriptionFrame(event.result.text, "", int(time.time_ns() / 1000000))
Expand Down
51 changes: 43 additions & 8 deletions src/pipecat/services/deepgram.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,20 +18,29 @@
Frame,
InterimTranscriptionFrame,
StartFrame,
StartInterruptionFrame,
StopInterruptionFrame,
SystemFrame,
TranscriptionFrame)
from pipecat.processors.frame_processor import FrameDirection
from pipecat.services.ai_services import AIService, TTSService

from deepgram import (
DeepgramClient,
DeepgramClientOptions,
LiveTranscriptionEvents,
LiveOptions,
)

from loguru import logger

# See .env.example for Deepgram configuration needed
try:
from deepgram import (
DeepgramClient,
DeepgramClientOptions,
LiveTranscriptionEvents,
LiveOptions,
)
except ModuleNotFoundError as e:
logger.error(f"Exception: {e}")
logger.error(
"In order to use Deepgram, you need to `pip install pipecat-ai[deepgram]`. Also, set `DEEPGRAM_API_KEY` environment variable.")
raise Exception(f"Missing module: {e}")


class DeepgramTTSService(TTSService):

Expand Down Expand Up @@ -109,12 +118,18 @@ def __init__(self,
self._connection = self._client.listen.asynclive.v("1")
self._connection.on(LiveTranscriptionEvents.Transcript, self._on_message)

# This event will be used to ignore out-of-band transcriptions while we
# are interrupted.
self._is_interrupted_event = asyncio.Event()

self._create_push_task()

async def process_frame(self, frame: Frame, direction: FrameDirection):
await super().process_frame(frame, direction)

if isinstance(frame, SystemFrame):
if isinstance(frame, StartInterruptionFrame) or isinstance(frame, StopInterruptionFrame):
await self._handle_interruptions(frame)
elif isinstance(frame, SystemFrame):
await self.push_frame(frame, direction)
elif isinstance(frame, AudioRawFrame):
await self._connection.send(frame.audio)
Expand All @@ -137,6 +152,23 @@ async def cancel(self, frame: CancelFrame):
self._push_frame_task.cancel()
await self._push_frame_task

async def _handle_interruptions(self, frame: Frame):
    """Handle start/stop interruption frames for the STT push pipeline.

    On StartInterruptionFrame: mark the interrupted state (so out-of-band
    transcriptions get dropped), tear down the ordered push task so no
    stale frames reach downstream, forward the interruption frame
    immediately out-of-band, then start a fresh queue and task.
    On StopInterruptionFrame: clear the interrupted state so
    transcriptions flow again.
    """
    if isinstance(frame, StopInterruptionFrame):
        # Interruption is over; we can accept transcriptions again.
        self._is_interrupted_event.clear()
        return

    if isinstance(frame, StartInterruptionFrame):
        # While this event is set, out-of-band transcriptions are ignored.
        self._is_interrupted_event.set()
        # Cancel the push task; this stops frames flowing downstream.
        self._push_frame_task.cancel()
        await self._push_frame_task
        # Deliver the interruption frame immediately, bypassing the
        # (now torn down) ordered push queue.
        await self.push_frame(frame)
        # Fresh queue and task for subsequent frames.
        self._create_push_task()

def _create_push_task(self):
    """(Re)create the frame queue and the task that drains it downstream."""
    loop = self.get_event_loop()
    self._push_queue = asyncio.Queue()
    self._push_frame_task = loop.create_task(self._push_frame_task_handler())
Expand All @@ -152,6 +184,9 @@ async def _push_frame_task_handler(self):
break

async def _on_message(self, *args, **kwargs):
if self._is_interrupted_event.is_set():
return

result = kwargs["result"]
is_final = result.is_final
transcript = result.channel.alternatives[0].transcript
Expand Down
7 changes: 6 additions & 1 deletion src/pipecat/transports/base_input.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,10 +113,15 @@ async def _handle_interruptions(self, frame: Frame):
# Make sure we notify about interruptions quickly out-of-band
if isinstance(frame, UserStartedSpeakingFrame):
logger.debug("User started speaking")
# Cancel the task. This will stop pushing frames downstream.
self._push_frame_task.cancel()
await self._push_frame_task
self._create_push_task()
# Push an out-of-band frame (i.e. not using the ordered push
# frame task) to stop everything, especially at the output
# transport.
await self.push_frame(StartInterruptionFrame())
# Create a new queue and task.
self._create_push_task()
elif isinstance(frame, UserStoppedSpeakingFrame):
logger.debug("User stopped speaking")
await self.push_frame(StopInterruptionFrame())
Expand Down

0 comments on commit 84074e9

Please sign in to comment.