Skip to content

Commit

Permalink
makes synthesis result worker thread async
Browse files Browse the repository at this point in the history
  • Loading branch information
ajar98 committed Aug 1, 2023
1 parent 8fb11e5 commit 5ba1782
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 14 deletions.
2 changes: 1 addition & 1 deletion vocode/streaming/output_device/blocking_speaker_output.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def __init__(
def start(self):
ThreadAsyncWorker.start(self)

def _run_loop(self):
async def _run_loop(self):
while not self._ended:
try:
chunk = self.input_janus_queue.sync_q.get(timeout=1)
Expand Down
2 changes: 1 addition & 1 deletion vocode/streaming/output_device/file_output_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def __init__(self, input_queue: Queue, wave) -> None:
super().__init__(input_queue)
self.wav = wave

def _run_loop(self):
async def _run_loop(self):
while True:
try:
block = self.input_janus_queue.sync_q.get()
Expand Down
3 changes: 2 additions & 1 deletion vocode/streaming/streaming_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@
InterruptibleEventFactory,
InterruptibleAgentResponseEvent,
InterruptibleWorker,
ThreadInterruptibleAgentResponseWorker,
)

OutputDeviceType = TypeVar("OutputDeviceType", bound=BaseOutputDevice)
Expand Down Expand Up @@ -286,7 +287,7 @@ async def process(self, item: InterruptibleAgentResponseEvent[AgentResponse]):
except asyncio.CancelledError:
pass

class SynthesisResultsWorker(InterruptibleAgentResponseWorker):
class SynthesisResultsWorker(ThreadInterruptibleAgentResponseWorker):
"""Plays SynthesisResults from the output queue on the output device"""

def __init__(
Expand Down
2 changes: 1 addition & 1 deletion vocode/streaming/transcriber/azure_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ def recognized_sentence_stream(self, evt):
Transcription(message=evt.result.text, confidence=1.0, is_final=False)
)

def _run_loop(self):
async def _run_loop(self):
stream = self.generator()

def stop_cb(evt):
Expand Down
3 changes: 2 additions & 1 deletion vocode/streaming/transcriber/base_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
tracer = trace.get_tracer(__name__)
meter = metrics.get_meter(__name__)


class Transcription(BaseModel):
message: str
confidence: float
Expand Down Expand Up @@ -87,7 +88,7 @@ def __init__(
ThreadAsyncWorker.__init__(self, self.input_queue, self.output_queue)
AbstractTranscriber.__init__(self, transcriber_config)

def _run_loop(self):
async def _run_loop(self):
raise NotImplementedError

def send_audio(self, chunk):
Expand Down
2 changes: 1 addition & 1 deletion vocode/streaming/transcriber/google_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def create_google_streaming_config(self):
interim_results=True,
)

def _run_loop(self):
async def _run_loop(self):
stream = self.generator()
requests = (
self.speech.StreamingRecognizeRequest(audio_content=content)
Expand Down
2 changes: 1 addition & 1 deletion vocode/streaming/transcriber/whisper_cpp_transcriber.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def create_new_buffer(self):
wav.setframerate(self.transcriber_config.sampling_rate)
return wav, buffer

def _run_loop(self):
async def _run_loop(self):
in_memory_wav, audio_buffer = self.create_new_buffer()
message_buffer = ""
while not self._ended:
Expand Down
61 changes: 54 additions & 7 deletions vocode/streaming/utils/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,26 +35,33 @@ def produce_nonblocking(self, item):
async def _run_loop(self):
raise NotImplementedError

async def get_next_event(self) -> WorkerInputType:
return await self.input_queue.get()

def terminate(self):
if self.worker_task:
return self.worker_task.cancel()

return False


class ThreadAsyncWorker(AsyncWorker):
class ThreadAsyncWorker(AsyncWorker[WorkerInputType]):
def __init__(
self,
input_queue: asyncio.Queue,
input_queue: asyncio.Queue[WorkerInputType],
output_queue: asyncio.Queue = asyncio.Queue(),
) -> None:
super().__init__(input_queue, output_queue)
AsyncWorker.__init__(self, input_queue, output_queue)
self.worker_thread: Optional[threading.Thread] = None
self.input_janus_queue: janus.Queue = janus.Queue()
self.output_janus_queue: janus.Queue = janus.Queue()
self.loop_task: Optional[asyncio.Task] = None
self.loop = asyncio.new_event_loop()

def start(self) -> asyncio.Task:
self.worker_thread = threading.Thread(target=self._run_loop)
self.worker_thread = threading.Thread(
target=self._run_loop_sync, args=(self.loop,)
)
self.worker_thread.start()
self.worker_task = asyncio.create_task(self.run_thread_forwarding())
return self.worker_task
Expand All @@ -78,11 +85,21 @@ async def _forward_from_thead(self):
item = await self.output_janus_queue.async_q.get()
self.output_queue.put_nowait(item)

def _run_loop(self):
def _run_loop_sync(self, loop: asyncio.AbstractEventLoop):
asyncio.set_event_loop(loop)
self.loop_task = loop.create_task(self._run_loop())
loop.run_until_complete(self.loop_task)

async def _run_loop(self):
raise NotImplementedError

async def get_next_event(self) -> WorkerInputType:
return self.input_janus_queue.sync_q.get()

def terminate(self):
return super().terminate()
super().terminate()
self.loop.call_soon_threadsafe(self.loop_task.cancel)
self.loop.call_soon_threadsafe(self.loop.stop)


class AsyncQueueWorker(AsyncWorker):
Expand Down Expand Up @@ -208,7 +225,7 @@ def produce_interruptible_agent_response_event_nonblocking(
async def _run_loop(self):
# TODO Implement concurrency with max_nb_of_thread
while True:
item = await self.input_queue.get()
item = await self.get_next_event()
if item.is_interrupted():
continue
self.interruptible_event = item
Expand Down Expand Up @@ -245,7 +262,37 @@ def cancel_current_task(self):
return False


class ThreadInterruptibleWorker(
ThreadAsyncWorker[InterruptibleEventType],
InterruptibleWorker[InterruptibleEventType],
):
def __init__(
self,
input_queue: asyncio.Queue[InterruptibleEventType],
output_queue: asyncio.Queue = asyncio.Queue(),
interruptible_event_factory: InterruptibleEventFactory = InterruptibleEventFactory(),
max_concurrency=2,
) -> None:
InterruptibleWorker.__init__(
self,
input_queue,
output_queue,
interruptible_event_factory,
max_concurrency,
)
ThreadAsyncWorker.__init__(self, input_queue, output_queue)

async def _run_loop(self):
return await InterruptibleWorker._run_loop(self)


class InterruptibleAgentResponseWorker(
InterruptibleWorker[InterruptibleAgentResponseEvent]
):
pass


class ThreadInterruptibleAgentResponseWorker(
ThreadInterruptibleWorker[InterruptibleAgentResponseEvent]
):
pass

0 comments on commit 5ba1782

Please sign in to comment.