Skip to content

Commit

Permalink
feat: add integration support for vox implant (#97)
Browse files Browse the repository at this point in the history
  • Loading branch information
quitrk authored Sep 26, 2024
1 parent 1edf916 commit 9b3344f
Show file tree
Hide file tree
Showing 9 changed files with 1,448 additions and 1,097 deletions.
2,288 changes: 1,223 additions & 1,065 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ aiofiles = "^23.2.1"
pydantic = "^2.9.1"
langchain = "^0.3.0"
langchain-openai = "^0.2.0"
av = "^12.3.0"
pybase64 = "^1.4.0"

[build-system]
build-backend = "poetry.core.masonry.api"
Expand Down
63 changes: 32 additions & 31 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -2,89 +2,90 @@ aiofiles==23.2.1 ; python_version >= "3.11" and python_version < "3.12"
aiohttp==3.9.5 ; python_version >= "3.11" and python_version < "3.12"
aiosignal==1.3.1 ; python_version >= "3.11" and python_version < "3.12"
annotated-types==0.7.0 ; python_version >= "3.11" and python_version < "3.12"
anyio==4.4.0 ; python_version >= "3.11" and python_version < "3.12"
anyio==4.5.0 ; python_version >= "3.11" and python_version < "3.12"
async-lru==2.0.4 ; python_version >= "3.11" and python_version < "3.12"
async-timeout==4.0.3 ; python_version >= "3.11" and python_full_version <= "3.11.2"
attrs==24.2.0 ; python_version >= "3.11" and python_version < "3.12"
av==12.3.0 ; python_version >= "3.11" and python_version < "3.12"
boto3==1.34.156 ; python_version >= "3.11" and python_version < "3.12"
botocore==1.34.156 ; python_version >= "3.11" and python_version < "3.12"
certifi==2024.7.4 ; python_version >= "3.11" and python_version < "3.12"
cffi==1.17.0 ; python_version >= "3.11" and python_version < "3.12" and platform_python_implementation != "PyPy"
boto3==1.35.23 ; python_version >= "3.11" and python_version < "3.12"
botocore==1.35.23 ; python_version >= "3.11" and python_version < "3.12"
certifi==2024.8.30 ; python_version >= "3.11" and python_version < "3.12"
cffi==1.17.1 ; python_version >= "3.11" and python_version < "3.12" and platform_python_implementation != "PyPy"
charset-normalizer==3.3.2 ; python_version >= "3.11" and python_version < "3.12"
click==8.1.7 ; python_version >= "3.11" and python_version < "3.12"
colorama==0.4.6 ; python_version >= "3.11" and python_version < "3.12" and (sys_platform == "win32" or platform_system == "Windows")
coloredlogs==15.0.1 ; python_version >= "3.11" and python_version < "3.12"
cryptography==43.0.0 ; python_version >= "3.11" and python_version < "3.12"
ctranslate2==4.3.1 ; python_version >= "3.11" and python_version < "3.12"
cryptography==43.0.1 ; python_version >= "3.11" and python_version < "3.12"
ctranslate2==4.4.0 ; python_version >= "3.11" and python_version < "3.12"
distro==1.9.0 ; python_version >= "3.11" and python_version < "3.12"
fastapi-versionizer==3.0.4 ; python_version >= "3.11" and python_version < "3.12"
fastapi==0.109.0 ; python_version >= "3.11" and python_version < "3.12"
faster-whisper==1.0.3 ; python_version >= "3.11" and python_version < "3.12"
filelock==3.15.4 ; python_version >= "3.11" and python_version < "3.12"
filelock==3.16.1 ; python_version >= "3.11" and python_version < "3.12"
flatbuffers==24.3.25 ; python_version >= "3.11" and python_version < "3.12"
frozenlist==1.4.1 ; python_version >= "3.11" and python_version < "3.12"
fsspec==2024.6.1 ; python_version >= "3.11" and python_version < "3.12"
fsspec==2024.9.0 ; python_version >= "3.11" and python_version < "3.12"
greenlet==3.1.0 ; python_version < "3.12" and (platform_machine == "aarch64" or platform_machine == "ppc64le" or platform_machine == "x86_64" or platform_machine == "amd64" or platform_machine == "AMD64" or platform_machine == "win32" or platform_machine == "WIN32") and python_version >= "3.11"
h11==0.14.0 ; python_version >= "3.11" and python_version < "3.12"
httpcore==1.0.5 ; python_version >= "3.11" and python_version < "3.12"
httptools==0.6.1 ; python_version >= "3.11" and python_version < "3.12"
httpx==0.27.0 ; python_version >= "3.11" and python_version < "3.12"
huggingface-hub==0.24.5 ; python_version >= "3.11" and python_version < "3.12"
httpx==0.27.2 ; python_version >= "3.11" and python_version < "3.12"
huggingface-hub==0.25.0 ; python_version >= "3.11" and python_version < "3.12"
humanfriendly==10.0 ; python_version >= "3.11" and python_version < "3.12"
idna==3.7 ; python_version >= "3.11" and python_version < "3.12"
idna==3.10 ; python_version >= "3.11" and python_version < "3.12"
jinja2==3.1.4 ; python_version >= "3.11" and python_version < "3.12"
jiter==0.5.0 ; python_version >= "3.11" and python_version < "3.12"
jmespath==1.0.1 ; python_version >= "3.11" and python_version < "3.12"
jsonpatch==1.33 ; python_version >= "3.11" and python_version < "3.12"
jsonpointer==3.0.0 ; python_version >= "3.11" and python_version < "3.12"
langchain-core==0.3.0 ; python_version >= "3.11" and python_version < "3.12"
langchain-core==0.3.2 ; python_version >= "3.11" and python_version < "3.12"
langchain-openai==0.2.0 ; python_version >= "3.11" and python_version < "3.12"
langchain-text-splitters==0.3.0 ; python_version >= "3.11" and python_version < "3.12"
langchain==0.3.0 ; python_version >= "3.11" and python_version < "3.12"
langsmith==0.1.120 ; python_version >= "3.11" and python_version < "3.12"
langsmith==0.1.125 ; python_version >= "3.11" and python_version < "3.12"
markupsafe==2.1.5 ; python_version >= "3.11" and python_version < "3.12"
mpmath==1.3.0 ; python_version >= "3.11" and python_version < "3.12"
multidict==6.0.5 ; python_version >= "3.11" and python_version < "3.12"
multidict==6.1.0 ; python_version >= "3.11" and python_version < "3.12"
natsort==8.4.0 ; python_version >= "3.11" and python_version < "3.12"
networkx==3.3 ; python_version >= "3.11" and python_version < "3.12"
numpy==1.26.4 ; python_version >= "3.11" and python_version < "3.12"
onnxruntime==1.18.1 ; python_version >= "3.11" and python_version < "3.12"
openai==1.40.1 ; python_version >= "3.11" and python_version < "3.12"
onnxruntime==1.19.2 ; python_version >= "3.11" and python_version < "3.12"
openai==1.46.1 ; python_version >= "3.11" and python_version < "3.12"
orjson==3.10.7 ; python_version >= "3.11" and python_version < "3.12"
packaging==23.2 ; python_version >= "3.11" and python_version < "3.12"
packaging==24.1 ; python_version >= "3.11" and python_version < "3.12"
prometheus-client==0.19.0 ; python_version >= "3.11" and python_version < "3.12"
prometheus-fastapi-instrumentator==6.1.0 ; python_version >= "3.11" and python_version < "3.12"
protobuf==5.27.3 ; python_version >= "3.11" and python_version < "3.12"
protobuf==5.28.2 ; python_version >= "3.11" and python_version < "3.12"
pybase64==1.4.0 ; python_version >= "3.11" and python_version < "3.12"
pycparser==2.22 ; python_version >= "3.11" and python_version < "3.12" and platform_python_implementation != "PyPy"
pydantic-core==2.23.3 ; python_version >= "3.11" and python_version < "3.12"
pydantic==2.9.1 ; python_version >= "3.11" and python_version < "3.12"
pydantic-core==2.23.4 ; python_version >= "3.11" and python_version < "3.12"
pydantic==2.9.2 ; python_version >= "3.11" and python_version < "3.12"
pyjwt[crypto]==2.9.0 ; python_version >= "3.11" and python_version < "3.12"
pyreadline3==3.4.1 ; sys_platform == "win32" and python_version >= "3.11" and python_version < "3.12"
pyreadline3==3.5.4 ; sys_platform == "win32" and python_version >= "3.11" and python_version < "3.12"
python-dateutil==2.9.0.post0 ; python_version >= "3.11" and python_version < "3.12"
python-dotenv==1.0.1 ; python_version >= "3.11" and python_version < "3.12"
pyyaml==6.0.2 ; python_version >= "3.11" and python_version < "3.12"
redis==5.0.1 ; python_version >= "3.11" and python_version < "3.12"
regex==2024.9.11 ; python_version >= "3.11" and python_version < "3.12"
requests==2.32.3 ; python_version >= "3.11" and python_version < "3.12"
s3transfer==0.10.2 ; python_version >= "3.11" and python_version < "3.12"
setuptools==72.1.0 ; python_version >= "3.11" and python_version < "3.12"
setuptools==75.1.0 ; python_version >= "3.11" and python_version < "3.12"
six==1.16.0 ; python_version >= "3.11" and python_version < "3.12"
sniffio==1.3.1 ; python_version >= "3.11" and python_version < "3.12"
sqlalchemy==2.0.34 ; python_version >= "3.11" and python_version < "3.12"
sqlalchemy==2.0.35 ; python_version >= "3.11" and python_version < "3.12"
starlette==0.35.1 ; python_version >= "3.11" and python_version < "3.12"
sympy==1.13.1 ; python_version >= "3.11" and python_version < "3.12"
sympy==1.13.3 ; python_version >= "3.11" and python_version < "3.12"
tenacity==8.5.0 ; python_version >= "3.11" and python_version < "3.12"
tiktoken==0.7.0 ; python_version >= "3.11" and python_version < "3.12"
tokenizers==0.15.2 ; python_version >= "3.11" and python_version < "3.12"
tokenizers==0.20.0 ; python_version >= "3.11" and python_version < "3.12"
torch==2.0.1 ; python_version >= "3.11" and python_version < "3.12"
torchaudio==2.0.2 ; python_version >= "3.11" and python_version < "3.12"
tqdm==4.66.5 ; python_version >= "3.11" and python_version < "3.12"
typing-extensions==4.12.2 ; python_version >= "3.11" and python_version < "3.12"
urllib3==2.2.2 ; python_version >= "3.11" and python_version < "3.12"
urllib3==2.2.3 ; python_version >= "3.11" and python_version < "3.12"
uuid6==2024.7.10 ; python_version >= "3.11" and python_version < "3.12"
uvicorn[standard]==0.29.0 ; python_version >= "3.11" and python_version < "3.12"
uvloop==0.19.0 ; (sys_platform != "win32" and sys_platform != "cygwin") and platform_python_implementation != "PyPy" and python_version >= "3.11" and python_version < "3.12"
watchfiles==0.23.0 ; python_version >= "3.11" and python_version < "3.12"
websockets==12.0 ; python_version >= "3.11" and python_version < "3.12"
yarl==1.9.4 ; python_version >= "3.11" and python_version < "3.12"
uvloop==0.20.0 ; (sys_platform != "win32" and sys_platform != "cygwin") and platform_python_implementation != "PyPy" and python_version >= "3.11" and python_version < "3.12"
watchfiles==0.24.0 ; python_version >= "3.11" and python_version < "3.12"
websockets==13.0.1 ; python_version >= "3.11" and python_version < "3.12"
yarl==1.11.1 ; python_version >= "3.11" and python_version < "3.12"
2 changes: 2 additions & 0 deletions skynet/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,10 @@ async def lifespan(main_app: FastAPI):

if 'streaming_whisper' in modules:
from skynet.modules.stt.streaming_whisper.app import app as streaming_whisper_app
from skynet.modules.stt.vox.app import app as vox_app

main_app.mount('/streaming-whisper', streaming_whisper_app)
main_app.mount('/vox', vox_app)

if 'summaries:dispatcher' in modules:
from skynet.modules.ttt.summaries.app import app as summaries_app, app_startup as summaries_startup
Expand Down
2 changes: 1 addition & 1 deletion skynet/modules/stt/streaming_whisper/cfg.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,4 @@ def get_device() -> str:
download_root=whisper_model_path,
)

one_byte_s = 0.00003125 # the equivalent of one byte in seconds
one_byte_s = 0.00003125 # the equivalent of one byte in seconds for 16kHz audio, 2 bytes per sample, mono
91 changes: 91 additions & 0 deletions skynet/modules/stt/vox/app.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
import asyncio

import pybase64

from fastapi import FastAPI, WebSocket, WebSocketDisconnect

from skynet.logs import get_logger
from skynet.modules.stt.streaming_whisper.utils import utils
from skynet.modules.stt.vox.connection_manager import ConnectionManager
from skynet.modules.stt.vox.decoder import PcmaDecoder
from skynet.modules.stt.vox.resampler import PcmResampler

# Module-level logger for the vox websocket app.
log = get_logger(__name__)

# Shared manager used by the endpoint below to connect, process and disconnect sessions.
ws_connection_manager = ConnectionManager()
# FastAPI sub-application exposing the '/ws' websocket route.
app = FastAPI()
# References to in-flight processing tasks; each task is removed again by its done callback.
running_tasks = set()
# Output sample rate handed to the resampler before audio is passed on for transcription.
whisper_sampling_rate = 16000


@app.websocket('/ws')
async def websocket_endpoint(websocket: WebSocket, auth_token: str | None = None):
    """Consume JSON audio events from a websocket and schedule them for transcription.

    Two event types are handled:

    * ``start`` - optionally carries ``start.mediaFormat.sampleRate``.
    * ``media`` - carries ``media.tag`` (used as the participant id), a base64
      encoded ``media.payload`` and ``media.timestamp``.

    Payloads are buffered per participant. Once 50 chunks have been collected
    the buffer is decoded, resampled to ``whisper_sampling_rate`` and handed to
    ``ws_connection_manager.process`` in a background task.

    Args:
        websocket: The accepted websocket connection.
        auth_token: Optional token forwarded to the connection manager.
    """
    decoder = PcmaDecoder()
    session_id = utils.Uuid7().get()
    await ws_connection_manager.connect(websocket, session_id, auth_token)

    # Per-participant accumulation state: raw payload bytes, chunk count and header.
    data_map = dict()
    # Created lazily when the first full batch is ready.
    # NOTE(review): one resampler (and one decoder) is shared by every
    # participant of the session - confirm that interleaving streams is safe.
    resampler = None
    # NOTE(review): updated from the 'start' event but not consumed anywhere in
    # this function - confirm whether the decoder should be configured with it.
    sampling_rate = 8000

    while True:
        try:
            ws_data = await websocket.receive_json()

            event = ws_data.get('event')

            if event == 'start':
                try:
                    sampling_rate = ws_data['start']['mediaFormat']['sampleRate']
                except KeyError:
                    # The sample rate is optional; keep the default.
                    pass

            elif event == 'media':
                media = ws_data['media']
                participant_id: str = media['tag']

                if participant_id not in data_map:
                    # Header: participant id plus '|en', NUL-padded to 60 bytes.
                    header = (participant_id.encode() + '|en'.encode()).ljust(60, b'\0')
                    data_map[participant_id] = dict(raw=b'', chunks=0, header=header)

                payload = pybase64.b64decode(media['payload'])
                participant = data_map[participant_id]

                participant['raw'] += payload
                participant['chunks'] += 1

                if participant['chunks'] == 50:  # 50 chunks = 1 second
                    frames = decoder.decode(participant['raw'], media['timestamp'])

                    if resampler is None:
                        resampler = PcmResampler(
                            format='s16',
                            layout='mono',
                            rate=whisper_sampling_rate,
                        )

                    decoded_raw = b''.join(resampler.resample(frame) for frame in frames)

                    task = asyncio.create_task(
                        ws_connection_manager.process(
                            session_id, participant['header'] + decoded_raw, media['timestamp'], participant_id
                        )
                    )

                    # Keep a reference until the task finishes; discard() cannot
                    # raise from inside the done callback.
                    running_tasks.add(task)
                    task.add_done_callback(running_tasks.discard)

                    # Start a fresh batch for this participant.
                    participant['chunks'] = 0
                    participant['raw'] = b''

        except WebSocketDisconnect:
            ws_connection_manager.disconnect(session_id)
            data_map.clear()
            log.info(f'Session {session_id} has ended')
            break


__all__ = ['app']
43 changes: 43 additions & 0 deletions skynet/modules/stt/vox/connection_manager.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from typing import List

from fastapi import WebSocketDisconnect

from skynet.logs import get_logger
from skynet.modules.stt.streaming_whisper.connection_manager import ConnectionManager as BaseConnectionManager
from skynet.modules.stt.streaming_whisper.utils.utils import TranscriptionResponse

log = get_logger(__name__)


class ConnectionManager(BaseConnectionManager):
    """Connection manager that tags transcription results with a participant id."""

    async def process(self, session_id: str, buffer: bytes, timestamp: int, tag: str):
        """Run one audio buffer through the session and send back its final results.

        Args:
            session_id: Key of the session in ``self.connections``.
            buffer: Audio bytes handed to the session's ``process`` method.
            timestamp: Timestamp forwarded to the session and echoed in responses.
            tag: Participant identifier echoed in responses.
        """
        log.debug(f'Processing chunk for session {session_id}')

        if session_id not in self.connections:
            log.warning(f'No such session id {session_id}, the connection was probably closed.')
            return

        transcriptions: List[TranscriptionResponse] = await self.connections[session_id].process(buffer, timestamp)

        if transcriptions is None:
            return

        for result in transcriptions:
            # Only results marked as final are logged and sent.
            if result.type != 'final':
                continue

            log.info(f'Participant {tag} result: {result.text}')
            await self.send(session_id, result, timestamp, tag)

    async def send(self, session_id: str, result: TranscriptionResponse | None, timestamp: int, tag: str):
        """Send a single transcription result over the session's websocket.

        A ``None`` result is ignored. A ``WebSocketDisconnect`` disconnects the
        session; any other exception is logged and swallowed.
        """
        if result is None:
            return

        try:
            send_json = self.connections[session_id].ws.send_json
            message = {
                'timestamp': timestamp,
                'tag': tag,
                'final': result.text,
                'language': 'en',
            }
            await send_json(message)
        except WebSocketDisconnect as e:
            log.warning(f'Session {session_id}: the connection was closed before sending all results: {e}')
            self.disconnect(session_id)
        except Exception as ex:
            log.error(f'Session {session_id}: exception while sending transcription results {ex}')
33 changes: 33 additions & 0 deletions skynet/modules/stt/vox/decoder.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import fractions
import time
from typing import List, Optional

from av import CodecContext
from av.audio.codeccontext import AudioCodecContext
from av.audio.frame import AudioFrame
from av.packet import Packet

SAMPLE_RATE = 8000
TIME_BASE = fractions.Fraction(1, 8000)


class PcmDecoder:
    """Thin wrapper around a PyAV audio codec context opened for decoding."""

    def __init__(self, codec_name: str) -> None:
        """Create the codec context and configure it as s16, mono, ``SAMPLE_RATE``.

        Args:
            codec_name: Codec name passed to ``CodecContext.create`` in read mode.
        """
        self.codec: AudioCodecContext = CodecContext.create(codec_name, "r")
        self.codec.format = "s16"
        self.codec.layout = "mono"
        self.codec.sample_rate = SAMPLE_RATE

    def decode(self, data: bytes, timestamp: Optional[int] = None) -> List[AudioFrame]:
        """Decode one buffer of encoded audio into frames.

        Args:
            data: Encoded audio bytes wrapped into a single packet.
            timestamp: Presentation timestamp for the packet. Only ``None``
                triggers the fallback; ``0`` is kept as a valid timestamp.

        Returns:
            The frames produced by the codec for this packet.
        """
        packet = Packet(data)
        # An explicit None check is needed: `timestamp or ...` would replace a
        # legitimate timestamp of 0 with the current wall-clock time.
        # NOTE(review): the fallback is whole seconds since the epoch while
        # TIME_BASE is 1/8000 - confirm the intended unit.
        packet.pts = timestamp if timestamp is not None else int(time.time())
        packet.time_base = TIME_BASE
        return self.codec.decode(packet)


class PcmaDecoder(PcmDecoder):
    """``PcmDecoder`` preconfigured with the ``pcm_alaw`` codec."""

    def __init__(self) -> None:
        """Initialise the base decoder with the fixed codec name."""
        codec_name = "pcm_alaw"
        super().__init__(codec_name)


__all__ = ["PcmaDecoder"]
21 changes: 21 additions & 0 deletions skynet/modules/stt/vox/resampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from av.audio.frame import AudioFrame
from av.audio.resampler import AudioResampler


class PcmResampler:
    """Wrap a PyAV ``AudioResampler`` and return resampled audio as raw bytes."""

    def __init__(self, **kwargs) -> None:
        """Build the underlying resampler.

        Args:
            **kwargs: Forwarded unchanged to ``AudioResampler``; the ``layout``
                and ``format`` entries, when given, are also stored on the instance.
        """
        self.resampler = AudioResampler(**kwargs)
        self.layout = kwargs.get("layout")
        self.format = kwargs.get("format")

    def resample(self, frame: AudioFrame) -> bytes:
        """Resample ``frame`` and return the concatenated bytes of each output frame's first plane."""
        converted = self.resampler.resample(frame)
        return b''.join(bytes(out_frame.planes[0]) for out_frame in converted)


__all__ = ["PcmResampler"]

0 comments on commit 9b3344f

Please sign in to comment.