Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add test and simplify play ht synthesizer config #348

Merged
merged 3 commits into from
Aug 9, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
67 changes: 64 additions & 3 deletions tests/synthesizer/conftest.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,17 @@
import pytest
from aioresponses import aioresponses, CallbackResult
from vocode.streaming.models.audio_encoding import AudioEncoding
from vocode.streaming.models.synthesizer import ElevenLabsSynthesizerConfig
from vocode.streaming.models.synthesizer import ElevenLabsSynthesizerConfig, PlayHtSynthesizerConfig
import re
from vocode.streaming.synthesizer.eleven_labs_synthesizer import (
ElevenLabsSynthesizer,
ELEVEN_LABS_BASE_URL,
)
from vocode.streaming.synthesizer.play_ht_synthesizer import (
PlayHtSynthesizer,
TTS_ENDPOINT
)

import re
from tests.streaming.data.loader import get_audio_path
import asyncio
Expand All @@ -16,9 +21,10 @@
"audio_encoding": AudioEncoding.LINEAR16}

MOCK_API_KEY = "my_api_key"
MOCK_USER_ID = "my_user_id"


def create_request_handler(optimize_streaming_latency=False):
def create_eleven_labs_request_handler(optimize_streaming_latency=False):
def request_handler(url, headers, **kwargs):
if optimize_streaming_latency and not re.search(r"optimize_streaming_latency=\d", url):
raise Exception("optimize_streaming_latency not found in url")
Expand All @@ -35,7 +41,7 @@ def mock_eleven_labs_api():
with aioresponses() as m:
pattern = re.compile(
rf"{re.escape(ELEVEN_LABS_BASE_URL)}text-to-speech/\w+")
m.post(pattern, callback=create_request_handler())
m.post(pattern, callback=create_eleven_labs_request_handler())
yield m


Expand All @@ -60,3 +66,58 @@ async def fixture_eleven_labs_synthesizer_env_api_key():

os.environ["ELEVEN_LABS_API_KEY"] = MOCK_API_KEY
return ElevenLabsSynthesizer(ElevenLabsSynthesizerConfig(**params))


# PlayHT Setup

def create_play_ht_request_handler():
def request_handler(url, headers, **kwargs):
if headers["Authorization"] != f"Bearer {MOCK_API_KEY}":
return CallbackResult(status=401)
if headers["X-User-ID"] != MOCK_USER_ID:
return CallbackResult(status=401)
with open(get_audio_path("fake_audio.mp3"), "rb") as audio_file:
return CallbackResult(content_type="audio/mpeg", body=audio_file.read())

return request_handler


@pytest.fixture
def mock_play_ht_api():
with aioresponses() as m:
m.post(TTS_ENDPOINT, callback=create_play_ht_request_handler())
yield m


@pytest.fixture(scope="module")
async def fixture_play_ht_synthesizer_with_api_key():
params = DEFAULT_PARAMS.copy()
params["api_key"] = MOCK_API_KEY
params["user_id"] = MOCK_USER_ID
return PlayHtSynthesizer(PlayHtSynthesizerConfig(**params))


@pytest.fixture(scope="module")
async def fixture_play_ht_synthesizer_wrong_api_key():
params = DEFAULT_PARAMS.copy()
params["api_key"] = "wrong_api_key"
params["user_id"] = MOCK_USER_ID
return PlayHtSynthesizer(PlayHtSynthesizerConfig(**params))

@pytest.fixture(scope="module")
async def fixture_play_ht_synthesizer_wrong_user_id():
params = DEFAULT_PARAMS.copy()
params["api_key"] = MOCK_API_KEY
params["user_id"] = "wrong_api_key"
return PlayHtSynthesizer(PlayHtSynthesizerConfig(**params))


@pytest.fixture(scope="module")
async def fixture_play_ht_synthesizer_env_api_key():
params = DEFAULT_PARAMS.copy()
import os

os.environ["PLAY_HT_API_KEY"] = MOCK_API_KEY
os.environ["PLAY_HT_USER_ID"] = MOCK_USER_ID
return PlayHtSynthesizer(PlayHtSynthesizerConfig(**params))

61 changes: 61 additions & 0 deletions tests/synthesizer/test_play_ht.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import asyncio

from aioresponses import aioresponses
from pydub import AudioSegment
import pytest

from vocode.streaming.models.message import BaseMessage
from vocode.streaming.synthesizer.base_synthesizer import SynthesisResult
from vocode.streaming.synthesizer.play_ht_synthesizer import PlayHtSynthesizer


async def assert_synthesis_result_valid(synthesizer: PlayHtSynthesizer):
response = await synthesizer.create_speech(BaseMessage(text="Hello, world!"), 1024)
assert isinstance(response, SynthesisResult)
assert response.chunk_generator is not None
audio = AudioSegment.empty()
async for chunk in response.chunk_generator:
audio += AudioSegment(
chunk.chunk,
frame_rate=synthesizer.synthesizer_config.sampling_rate,
sample_width=2,
channels=1,
)


@pytest.mark.asyncio
async def test_with_api_key(
fixture_play_ht_synthesizer_with_api_key: PlayHtSynthesizer,
mock_play_ht_api: aioresponses,
):
await assert_synthesis_result_valid(await fixture_play_ht_synthesizer_with_api_key)


@pytest.mark.asyncio
async def test_with_wrong_api_key(
fixture_play_ht_synthesizer_wrong_api_key: PlayHtSynthesizer,
mock_play_ht_api: aioresponses,
):
with pytest.raises(Exception, match="Play.ht API error status code 401"):
await (await fixture_play_ht_synthesizer_wrong_api_key).create_speech(
BaseMessage(text="Hello, world!"), 1024
)

@pytest.mark.asyncio
async def test_with_wrong_user_id(
fixture_play_ht_synthesizer_wrong_user_id: PlayHtSynthesizer,
mock_play_ht_api: aioresponses,
):
with pytest.raises(Exception, match="Play.ht API error status code 401"):
await (await fixture_play_ht_synthesizer_wrong_user_id).create_speech(
BaseMessage(text="Hello, world!"), 1024
)



@pytest.mark.asyncio
async def test_with_env_api_key(
fixture_play_ht_synthesizer_env_api_key: PlayHtSynthesizer,
mock_play_ht_api: aioresponses,
):
await assert_synthesis_result_valid(await fixture_play_ht_synthesizer_env_api_key)
4 changes: 3 additions & 1 deletion vocode/streaming/models/synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,10 +163,12 @@ def override_voice_id_with_prompt(cls, voice_id, values):


class PlayHtSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.PLAY_HT.value):
voice_id: str = PLAYHT_DEFAULT_VOICE_ID
api_key: Optional[str] = None
user_id: Optional[str] = None
speed: Optional[int] = None
seed: Optional[int] = None
temperature: Optional[int] = None
voice_id: str = PLAYHT_DEFAULT_VOICE_ID


class CoquiTTSSynthesizerConfig(
Expand Down
6 changes: 2 additions & 4 deletions vocode/streaming/synthesizer/play_ht_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,13 @@ class PlayHtSynthesizer(BaseSynthesizer[PlayHtSynthesizerConfig]):
def __init__(
self,
synthesizer_config: PlayHtSynthesizerConfig,
api_key: Optional[str] = None,
user_id: Optional[str] = None,
logger: Optional[logging.Logger] = None,
aiohttp_session: Optional[ClientSession] = None,
):
super().__init__(synthesizer_config, aiohttp_session)
self.synthesizer_config = synthesizer_config
self.api_key = api_key or getenv("PLAY_HT_API_KEY")
self.user_id = user_id or getenv("PLAY_HT_USER_ID")
self.api_key = synthesizer_config.api_key or getenv("PLAY_HT_API_KEY")
self.user_id = synthesizer_config.user_id or getenv("PLAY_HT_USER_ID")
if not self.api_key or not self.user_id:
raise ValueError(
"You must set the PLAY_HT_API_KEY and PLAY_HT_USER_ID environment variables"
Expand Down
Loading