Skip to content

Commit

Permalink
Add test and simplify play ht synthesizer config
Browse files Browse the repository at this point in the history
  • Loading branch information
Divi kumar authored and Divi kumar committed Aug 7, 2023
1 parent 746241b commit 048d149
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 16 deletions.
67 changes: 64 additions & 3 deletions tests/synthesizer/conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import pytest
from aioresponses import aioresponses, CallbackResult
from vocode.streaming.models.audio_encoding import AudioEncoding
from vocode.streaming.models.synthesizer import ElevenLabsSynthesizerConfig
from vocode.streaming.models.synthesizer import ElevenLabsSynthesizerConfig, PlayHtSynthesizerConfig
import re
from vocode.streaming.synthesizer.eleven_labs_synthesizer import (
ElevenLabsSynthesizer,
ELEVEN_LABS_BASE_URL,
)
from vocode.streaming.synthesizer.play_ht_synthesizer import (
PlayHtSynthesizer,
TTS_ENDPOINT
)

import re
from tests.streaming.data.loader import get_audio_path
import asyncio
Expand All @@ -16,9 +21,10 @@
"audio_encoding": AudioEncoding.LINEAR16}

MOCK_API_KEY = "my_api_key"
MOCK_USER_ID = "my_user_id"


def create_request_handler(optimize_streaming_latency=False):
def create_eleven_labs_request_handler(optimize_streaming_latency=False):
def request_handler(url, headers, **kwargs):
if optimize_streaming_latency and not re.search(r"optimize_streaming_latency=\d", url):
raise Exception("optimize_streaming_latency not found in url")
Expand All @@ -35,7 +41,7 @@ def mock_eleven_labs_api():
with aioresponses() as m:
pattern = re.compile(
rf"{re.escape(ELEVEN_LABS_BASE_URL)}text-to-speech/\w+")
m.post(pattern, callback=create_request_handler())
m.post(pattern, callback=create_eleven_labs_request_handler())
yield m


Expand All @@ -60,3 +66,58 @@ async def fixture_eleven_labs_synthesizer_env_api_key():

os.environ["ELEVEN_LABS_API_KEY"] = MOCK_API_KEY
return ElevenLabsSynthesizer(ElevenLabsSynthesizerConfig(**params))


# PlayHT Setup

def create_play_ht_request_handler():
    """Build the aioresponses callback that stands in for the Play.ht TTS endpoint.

    The returned callback replies with status 401 unless the request headers
    carry ``Authorization: Bearer <MOCK_API_KEY>`` and ``X-User-ID: <MOCK_USER_ID>``.
    For a correctly authenticated request it replies with the bytes of the
    ``fake_audio.mp3`` fixture, served as ``audio/mpeg``.
    """

    def request_handler(url, headers, **kwargs):
        # Treat a missing header (or headers=None) exactly like a wrong
        # credential: answer 401 rather than crash the mock with a
        # KeyError/TypeError that would hide the real assertion failure.
        supplied = headers or {}
        if supplied.get("Authorization") != f"Bearer {MOCK_API_KEY}":
            return CallbackResult(status=401)
        if supplied.get("X-User-ID") != MOCK_USER_ID:
            return CallbackResult(status=401)
        with open(get_audio_path("fake_audio.mp3"), "rb") as audio_file:
            return CallbackResult(content_type="audio/mpeg", body=audio_file.read())

    return request_handler


@pytest.fixture
def mock_play_ht_api():
    """Yield an ``aioresponses`` mock whose POSTs to ``TTS_ENDPOINT`` hit the fake Play.ht handler."""
    handler = create_play_ht_request_handler()
    with aioresponses() as mocked_api:
        mocked_api.post(TTS_ENDPOINT, callback=handler)
        yield mocked_api


@pytest.fixture(scope="module")
async def fixture_play_ht_synthesizer_with_api_key():
    """Build a PlayHtSynthesizer whose config carries the mocked API key and user id."""
    config = PlayHtSynthesizerConfig(
        **{**DEFAULT_PARAMS, "api_key": MOCK_API_KEY, "user_id": MOCK_USER_ID}
    )
    return PlayHtSynthesizer(config)


@pytest.fixture(scope="module")
async def fixture_play_ht_synthesizer_wrong_api_key():
    """Build a PlayHtSynthesizer with a valid user id but an API key that is not MOCK_API_KEY."""
    overrides = {"api_key": "wrong_api_key", "user_id": MOCK_USER_ID}
    config = PlayHtSynthesizerConfig(**{**DEFAULT_PARAMS, **overrides})
    return PlayHtSynthesizer(config)

@pytest.fixture(scope="module")
async def fixture_play_ht_synthesizer_wrong_user_id():
    """Build a PlayHtSynthesizer with a valid API key but a user id that is not MOCK_USER_ID."""
    params = DEFAULT_PARAMS.copy()
    params["api_key"] = MOCK_API_KEY
    # Previously the copy-pasted literal "wrong_api_key" was used here; name the
    # bad value after the field it actually populates so failures read clearly.
    params["user_id"] = "wrong_user_id"
    return PlayHtSynthesizer(PlayHtSynthesizerConfig(**params))


@pytest.fixture(scope="module")
async def fixture_play_ht_synthesizer_env_api_key():
    """Build a PlayHtSynthesizer that takes its credentials from the environment.

    The config carries no ``api_key``/``user_id``; ``PLAY_HT_API_KEY`` and
    ``PLAY_HT_USER_ID`` are set only while the synthesizer is constructed and
    are then put back to their previous state, so they do not leak into other
    tests running in the same process.
    """
    import os

    params = DEFAULT_PARAMS.copy()
    mocked_env = {
        "PLAY_HT_API_KEY": MOCK_API_KEY,
        "PLAY_HT_USER_ID": MOCK_USER_ID,
    }
    # Remember what was there before (None means "was not set").
    saved_env = {name: os.environ.get(name) for name in mocked_env}
    os.environ.update(mocked_env)
    try:
        # NOTE(review): relies on the synthesizer reading the env vars in its
        # constructor, as the visible __init__ does - confirm nothing re-reads
        # them later.
        return PlayHtSynthesizer(PlayHtSynthesizerConfig(**params))
    finally:
        for name, previous in saved_env.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous

61 changes: 61 additions & 0 deletions tests/synthesizer/test_play_ht.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import asyncio

from aioresponses import aioresponses
from pydub import AudioSegment
import pytest

from vocode.streaming.models.message import BaseMessage
from vocode.streaming.synthesizer.base_synthesizer import SynthesisResult
from vocode.streaming.synthesizer.play_ht_synthesizer import PlayHtSynthesizer


async def assert_synthesis_result_valid(synthesizer: PlayHtSynthesizer):
    """Synthesize a short message and check that a usable result comes back.

    Asserts that ``create_speech`` returns a ``SynthesisResult`` with a chunk
    generator, then drains the generator, wrapping every chunk in a pydub
    ``AudioSegment`` (sample width 2, one channel, the configured sampling
    rate). No assertion is made on the accumulated audio itself.
    """
    result = await synthesizer.create_speech(BaseMessage(text="Hello, world!"), 1024)
    assert isinstance(result, SynthesisResult)
    assert result.chunk_generator is not None

    combined = AudioSegment.empty()
    async for piece in result.chunk_generator:
        segment = AudioSegment(
            piece.chunk,
            sample_width=2,
            channels=1,
            frame_rate=synthesizer.synthesizer_config.sampling_rate,
        )
        combined = combined + segment


@pytest.mark.asyncio
async def test_with_api_key(
    fixture_play_ht_synthesizer_with_api_key: PlayHtSynthesizer,
    mock_play_ht_api: aioresponses,
):
    """Credentials supplied through the config produce a valid synthesis result."""
    synthesizer = await fixture_play_ht_synthesizer_with_api_key
    await assert_synthesis_result_valid(synthesizer)


@pytest.mark.asyncio
async def test_with_wrong_api_key(
    fixture_play_ht_synthesizer_wrong_api_key: PlayHtSynthesizer,
    mock_play_ht_api: aioresponses,
):
    """A wrong API key makes ``create_speech`` raise the Play.ht 401 error."""
    message = BaseMessage(text="Hello, world!")
    with pytest.raises(Exception, match="Play.ht API error status code 401"):
        # Awaiting the fixture stays inside the raises-block, as before.
        synthesizer = await fixture_play_ht_synthesizer_wrong_api_key
        await synthesizer.create_speech(message, 1024)

@pytest.mark.asyncio
async def test_with_wrong_user_id(
    fixture_play_ht_synthesizer_wrong_user_id: PlayHtSynthesizer,
    mock_play_ht_api: aioresponses,
):
    """A wrong user id makes ``create_speech`` raise the Play.ht 401 error."""
    message = BaseMessage(text="Hello, world!")
    with pytest.raises(Exception, match="Play.ht API error status code 401"):
        # Awaiting the fixture stays inside the raises-block, as before.
        synthesizer = await fixture_play_ht_synthesizer_wrong_user_id
        await synthesizer.create_speech(message, 1024)



@pytest.mark.asyncio
async def test_with_env_api_key(
    fixture_play_ht_synthesizer_env_api_key: PlayHtSynthesizer,
    mock_play_ht_api: aioresponses,
):
    """Credentials taken from environment variables produce a valid synthesis result."""
    synthesizer = await fixture_play_ht_synthesizer_env_api_key
    await assert_synthesis_result_valid(synthesizer)
5 changes: 2 additions & 3 deletions vocode/streaming/models/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,9 @@ def override_voice_id_with_prompt(cls, voice_id, values):


class PlayHtSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.PLAY_HT.value):
    """Settings for the Play.ht synthesizer."""

    # Play.ht credentials. When left as None, PlayHtSynthesizer falls back to
    # the PLAY_HT_API_KEY / PLAY_HT_USER_ID environment variables.
    api_key: Optional[str] = None
    user_id: Optional[str] = None
    # Voice to synthesize with; defaults to PLAYHT_DEFAULT_VOICE_ID.
    voice_id: str = PLAYHT_DEFAULT_VOICE_ID
    # NOTE(review): optional generation settings - confirm that
    # PlayHtSynthesizer.create_speech still forwards these in the request body.
    speed: Optional[int] = None
    seed: Optional[int] = None
    temperature: Optional[int] = None


class CoquiTTSSynthesizerConfig(
Expand Down
12 changes: 2 additions & 10 deletions vocode/streaming/synthesizer/play_ht_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,13 @@ class PlayHtSynthesizer(BaseSynthesizer[PlayHtSynthesizerConfig]):
def __init__(
self,
synthesizer_config: PlayHtSynthesizerConfig,
api_key: Optional[str] = None,
user_id: Optional[str] = None,
logger: Optional[logging.Logger] = None,
aiohttp_session: Optional[ClientSession] = None,
):
super().__init__(synthesizer_config, aiohttp_session)
self.synthesizer_config = synthesizer_config
self.api_key = api_key or getenv("PLAY_HT_API_KEY")
self.user_id = user_id or getenv("PLAY_HT_USER_ID")
self.api_key = synthesizer_config.api_key or getenv("PLAY_HT_API_KEY")
self.user_id = synthesizer_config.user_id or getenv("PLAY_HT_USER_ID")
if not self.api_key or not self.user_id:
raise ValueError(
"You must set the PLAY_HT_API_KEY and PLAY_HT_USER_ID environment variables"
Expand All @@ -58,12 +56,6 @@ async def create_speech(
"text": message.text,
"sample_rate": self.synthesizer_config.sampling_rate,
}
if self.synthesizer_config.speed:
body["speed"] = self.synthesizer_config.speed
if self.synthesizer_config.seed:
body["seed"] = self.synthesizer_config.seed
if self.synthesizer_config.temperature:
body["temperature"] = self.synthesizer_config.temperature

create_speech_span = tracer.start_span(
f"synthesizer.{SynthesizerType.PLAY_HT.value.split('_', 1)[-1]}.create_total",
Expand Down

0 comments on commit 048d149

Please sign in to comment.