-
Notifications
You must be signed in to change notification settings - Fork 7
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add integration support for vox implant (#97)
- Loading branch information
Showing
9 changed files
with
1,448 additions
and
1,097 deletions.
There are no files selected for viewing
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,91 @@ | ||
import asyncio | ||
|
||
import pybase64 | ||
|
||
from fastapi import FastAPI, WebSocket, WebSocketDisconnect | ||
|
||
from skynet.logs import get_logger | ||
from skynet.modules.stt.streaming_whisper.utils import utils | ||
from skynet.modules.stt.vox.connection_manager import ConnectionManager | ||
from skynet.modules.stt.vox.decoder import PcmaDecoder | ||
from skynet.modules.stt.vox.resampler import PcmResampler | ||
|
||
# Module-level logger named after this module.
log = get_logger(__name__)

# Shared across every websocket session served by this app.
ws_connection_manager = ConnectionManager()
app = FastAPI()

# Strong references to in-flight transcription tasks; the endpoint adds each
# task here and removes it again from a done-callback.
running_tasks = set()

# Target rate (Hz) passed to the resampler before audio is handed to the
# connection manager.
whisper_sampling_rate = 16_000
|
||
|
||
@app.websocket('/ws')
async def websocket_endpoint(websocket: WebSocket, auth_token: str | None = None):
    """Serve one websocket session, batching participant audio for transcription.

    The client sends JSON messages carrying an ``event`` field:

    * ``start`` - may carry ``start.mediaFormat.sampleRate``.
    * ``media`` - carries ``media.tag`` (used as the participant id),
      ``media.payload`` (base64 audio) and ``media.timestamp``.

    Payloads are accumulated per participant. Every 50 chunks the buffered
    bytes are decoded, resampled to ``whisper_sampling_rate``, prefixed with a
    60-byte null-padded ``<tag>|en`` header and handed to
    ``ws_connection_manager.process`` in a background task.

    Args:
        websocket: The websocket for this session.
        auth_token: Optional token forwarded to the connection manager.
    """
    decoder = PcmaDecoder()
    session_id = utils.Uuid7().get()
    await ws_connection_manager.connect(websocket, session_id, auth_token)

    data_map = dict()
    # Built lazily when the first full batch is ready; there is no point in
    # constructing a resampler for a session that never sends enough audio.
    resampler = None
    # NOTE(review): updated from the 'start' event but not read anywhere in
    # this handler - confirm whether the decoder should be configured from it.
    sampling_rate = 8000

    while True:
        try:
            ws_data = await websocket.receive_json()
            event = ws_data.get('event')

            if event == 'start':
                try:
                    sampling_rate = ws_data['start']['mediaFormat']['sampleRate']
                except KeyError:
                    # Best effort: keep the default when the field is absent.
                    log.debug(f'Session {session_id}: no sample rate in start event, keeping {sampling_rate}')

            elif event == 'media':
                media = ws_data['media']
                participant_id: str = media['tag']

                if participant_id not in data_map:
                    header = (participant_id.encode() + '|en'.encode()).ljust(60, b'\0')
                    data_map[participant_id] = dict(raw=b'', chunks=0, header=header)

                payload = pybase64.b64decode(media['payload'])
                participant = data_map[participant_id]

                participant['raw'] += payload
                participant['chunks'] += 1

                if participant['chunks'] == 50:  # 50 chunks = 1 second
                    frames = decoder.decode(participant['raw'], media['timestamp'])

                    resampler = resampler or PcmResampler(
                        format='s16',
                        layout='mono',
                        rate=whisper_sampling_rate,
                    )

                    # Join once instead of growing a bytes object per frame.
                    decoded_raw = b''.join(resampler.resample(frame) for frame in frames)

                    task = asyncio.create_task(
                        ws_connection_manager.process(
                            session_id, participant['header'] + decoded_raw, media['timestamp'], participant_id
                        )
                    )

                    # Keep a strong reference until the task completes so it is
                    # not garbage-collected mid-flight.
                    running_tasks.add(task)
                    task.add_done_callback(running_tasks.discard)

                    participant['chunks'] = 0
                    participant['raw'] = b''

        except WebSocketDisconnect:
            ws_connection_manager.disconnect(session_id)
            data_map.clear()
            log.info(f'Session {session_id} has ended')
            break
|
||
|
||
__all__ = ['app'] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
from typing import List | ||
|
||
from fastapi import WebSocketDisconnect | ||
|
||
from skynet.logs import get_logger | ||
from skynet.modules.stt.streaming_whisper.connection_manager import ConnectionManager as BaseConnectionManager | ||
from skynet.modules.stt.streaming_whisper.utils.utils import TranscriptionResponse | ||
|
||
log = get_logger(__name__) | ||
|
||
|
||
class ConnectionManager(BaseConnectionManager):
    """Connection manager that tags transcription results with a participant tag."""

    async def process(self, session_id: str, buffer: bytes, timestamp: int, tag: str):
        """Run one audio buffer through the session and forward final results.

        Args:
            session_id: Key of the session in ``self.connections``.
            buffer: Audio bytes passed to the session's ``process``.
            timestamp: Timestamp passed along with the buffer and echoed back.
            tag: Participant tag echoed back with each result.
        """
        log.debug(f'Processing chunk for session {session_id}')

        if session_id not in self.connections:
            log.warning(f'No such session id {session_id}, the connection was probably closed.')
            return

        session = self.connections[session_id]
        results: List[TranscriptionResponse] = await session.process(buffer, timestamp)

        if results is None:
            return

        for item in results:
            # Only results marked 'final' are logged and sent.
            if item.type != 'final':
                continue
            log.info(f'Participant {tag} result: {item.text}')
            await self.send(session_id, item, timestamp, tag)

    async def send(self, session_id: str, result: TranscriptionResponse | None, timestamp: int, tag: str):
        """Send one transcription result as JSON over the session's websocket.

        Does nothing when ``result`` is ``None``. A ``WebSocketDisconnect``
        disconnects the session; any other exception is logged and swallowed.
        """
        if result is None:
            return

        try:
            message = {
                'timestamp': timestamp,
                'tag': tag,
                'final': result.text,
                'language': 'en',
            }
            await self.connections[session_id].ws.send_json(message)
        except WebSocketDisconnect as e:
            log.warning(f'Session {session_id}: the connection was closed before sending all results: {e}')
            self.disconnect(session_id)
        except Exception as ex:
            log.error(f'Session {session_id}: exception while sending transcription results {ex}')
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
import fractions | ||
import time | ||
from typing import List, Optional | ||
|
||
from av import CodecContext | ||
from av.audio.codeccontext import AudioCodecContext | ||
from av.audio.frame import AudioFrame | ||
from av.packet import Packet | ||
|
||
# Sample rate, in Hz, set on the decoder's codec context.
SAMPLE_RATE = 8000
# One tick per sample. Derived from SAMPLE_RATE so the two cannot drift apart.
TIME_BASE = fractions.Fraction(1, SAMPLE_RATE)
|
||
|
||
class PcmDecoder:
    """Decode raw audio bytes into PyAV audio frames with a named codec.

    The codec context is opened in read mode and configured for mono,
    signed 16-bit samples at ``SAMPLE_RATE``.
    """

    def __init__(self, codec_name: str) -> None:
        """Create the codec context for ``codec_name``."""
        self.codec: AudioCodecContext = CodecContext.create(codec_name, "r")
        self.codec.format = "s16"
        self.codec.layout = "mono"
        self.codec.sample_rate = SAMPLE_RATE

    def decode(self, data: bytes, timestamp: Optional[int] = None) -> List[AudioFrame]:
        """Wrap ``data`` in a packet and decode it.

        Args:
            data: Encoded audio bytes for a single packet.
            timestamp: Presentation timestamp for the packet. When ``None``,
                the current wall-clock time in whole seconds is used.

        Returns:
            The frames produced by the codec for this packet.
        """
        packet = Packet(data)
        # Explicit None check: 0 is a legitimate timestamp and must not be
        # replaced by the wall-clock fallback.
        packet.pts = int(time.time()) if timestamp is None else timestamp
        packet.time_base = TIME_BASE
        return self.codec.decode(packet)
|
||
|
||
class PcmaDecoder(PcmDecoder):
    """``PcmDecoder`` bound to the ``pcm_alaw`` codec."""

    def __init__(self) -> None:
        super().__init__(codec_name="pcm_alaw")
|
||
|
||
__all__ = ["PcmaDecoder"] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
from av.audio.frame import AudioFrame | ||
from av.audio.resampler import AudioResampler | ||
|
||
|
||
class PcmResampler:
    """Thin wrapper around ``AudioResampler`` that returns raw bytes."""

    def __init__(self, **kwargs) -> None:
        """Build the underlying resampler.

        Args:
            **kwargs: Forwarded unchanged to ``AudioResampler``. The
                ``layout`` and ``format`` entries, if given, are also kept
                on the instance.
        """
        self.resampler = AudioResampler(**kwargs)
        self.layout = kwargs.get("layout")
        self.format = kwargs.get("format")

    def resample(self, frame: AudioFrame) -> bytes:
        """Resample ``frame`` and return the bytes of each output frame's first plane.

        Returns:
            The concatenated plane bytes, or ``b''`` when the resampler
            yields no frames for this input.
        """
        # Single join instead of repeatedly growing a bytes object.
        return b''.join(bytes(resampled.planes[0]) for resampled in self.resampler.resample(frame))
|
||
|
||
__all__ = ["PcmResampler"] |