From 048d149f9aee7c530b3f622c9b1ae9d8e284e87b Mon Sep 17 00:00:00 2001 From: Divi kumar Date: Mon, 7 Aug 2023 12:36:38 -0700 Subject: [PATCH] Add test and simplify play ht synthesizer config --- tests/synthesizer/conftest.py | 67 ++++++++++++++++++- tests/synthesizer/test_play_ht.py | 61 +++++++++++++++++ vocode/streaming/models/synthesizer.py | 5 +- .../synthesizer/play_ht_synthesizer.py | 12 +--- 4 files changed, 129 insertions(+), 16 deletions(-) create mode 100644 tests/synthesizer/test_play_ht.py diff --git a/tests/synthesizer/conftest.py b/tests/synthesizer/conftest.py index cc304a00a..040ebeb18 100644 --- a/tests/synthesizer/conftest.py +++ b/tests/synthesizer/conftest.py @@ -1,12 +1,17 @@ import pytest from aioresponses import aioresponses, CallbackResult from vocode.streaming.models.audio_encoding import AudioEncoding -from vocode.streaming.models.synthesizer import ElevenLabsSynthesizerConfig +from vocode.streaming.models.synthesizer import ElevenLabsSynthesizerConfig, PlayHtSynthesizerConfig import re from vocode.streaming.synthesizer.eleven_labs_synthesizer import ( ElevenLabsSynthesizer, ELEVEN_LABS_BASE_URL, ) +from vocode.streaming.synthesizer.play_ht_synthesizer import ( + PlayHtSynthesizer, + TTS_ENDPOINT +) + import re from tests.streaming.data.loader import get_audio_path import asyncio @@ -16,9 +21,10 @@ "audio_encoding": AudioEncoding.LINEAR16} MOCK_API_KEY = "my_api_key" +MOCK_USER_ID = "my_user_id" -def create_request_handler(optimize_streaming_latency=False): +def create_eleven_labs_request_handler(optimize_streaming_latency=False): def request_handler(url, headers, **kwargs): if optimize_streaming_latency and not re.search(r"optimize_streaming_latency=\d", url): raise Exception("optimize_streaming_latency not found in url") @@ -35,7 +41,7 @@ def mock_eleven_labs_api(): with aioresponses() as m: pattern = re.compile( rf"{re.escape(ELEVEN_LABS_BASE_URL)}text-to-speech/\w+") - m.post(pattern, callback=create_request_handler()) + m.post(pattern, callback=create_eleven_labs_request_handler()) yield m @@ -60,3 +66,58 @@ async def fixture_eleven_labs_synthesizer_env_api_key(): os.environ["ELEVEN_LABS_API_KEY"] = MOCK_API_KEY return ElevenLabsSynthesizer(ElevenLabsSynthesizerConfig(**params)) + + +# PlayHT Setup + +def create_play_ht_request_handler(): + def request_handler(url, headers, **kwargs): + if headers["Authorization"] != f"Bearer {MOCK_API_KEY}": + return CallbackResult(status=401) + if headers["X-User-ID"] != MOCK_USER_ID: + return CallbackResult(status=401) + with open(get_audio_path("fake_audio.mp3"), "rb") as audio_file: + return CallbackResult(content_type="audio/mpeg", body=audio_file.read()) + + return request_handler + + +@pytest.fixture +def mock_play_ht_api(): + with aioresponses() as m: + m.post(TTS_ENDPOINT, callback=create_play_ht_request_handler()) + yield m + + +@pytest.fixture(scope="module") +async def fixture_play_ht_synthesizer_with_api_key(): + params = DEFAULT_PARAMS.copy() + params["api_key"] = MOCK_API_KEY + params["user_id"] = MOCK_USER_ID + return PlayHtSynthesizer(PlayHtSynthesizerConfig(**params)) + + +@pytest.fixture(scope="module") +async def fixture_play_ht_synthesizer_wrong_api_key(): + params = DEFAULT_PARAMS.copy() + params["api_key"] = "wrong_api_key" + params["user_id"] = MOCK_USER_ID + return PlayHtSynthesizer(PlayHtSynthesizerConfig(**params)) + +@pytest.fixture(scope="module") +async def fixture_play_ht_synthesizer_wrong_user_id(): + params = DEFAULT_PARAMS.copy() + params["api_key"] = MOCK_API_KEY + params["user_id"] = "wrong_api_key" + return PlayHtSynthesizer(PlayHtSynthesizerConfig(**params)) + + +@pytest.fixture(scope="module") +async def fixture_play_ht_synthesizer_env_api_key(): + params = DEFAULT_PARAMS.copy() + import os + + os.environ["PLAY_HT_API_KEY"] = MOCK_API_KEY + os.environ["PLAY_HT_USER_ID"] = MOCK_USER_ID + return PlayHtSynthesizer(PlayHtSynthesizerConfig(**params)) + diff --git a/tests/synthesizer/test_play_ht.py b/tests/synthesizer/test_play_ht.py new file mode 100644 index 000000000..59b78cfeb --- /dev/null +++ b/tests/synthesizer/test_play_ht.py @@ -0,0 +1,61 @@ +import asyncio + +from aioresponses import aioresponses +from pydub import AudioSegment +import pytest + +from vocode.streaming.models.message import BaseMessage +from vocode.streaming.synthesizer.base_synthesizer import SynthesisResult +from vocode.streaming.synthesizer.play_ht_synthesizer import PlayHtSynthesizer + + +async def assert_synthesis_result_valid(synthesizer: PlayHtSynthesizer): + response = await synthesizer.create_speech(BaseMessage(text="Hello, world!"), 1024) + assert isinstance(response, SynthesisResult) + assert response.chunk_generator is not None + audio = AudioSegment.empty() + async for chunk in response.chunk_generator: + audio += AudioSegment( + chunk.chunk, + frame_rate=synthesizer.synthesizer_config.sampling_rate, + sample_width=2, + channels=1, + ) + + +@pytest.mark.asyncio +async def test_with_api_key( + fixture_play_ht_synthesizer_with_api_key: PlayHtSynthesizer, + mock_play_ht_api: aioresponses, +): + await assert_synthesis_result_valid(await fixture_play_ht_synthesizer_with_api_key) + + +@pytest.mark.asyncio +async def test_with_wrong_api_key( + fixture_play_ht_synthesizer_wrong_api_key: PlayHtSynthesizer, + mock_play_ht_api: aioresponses, +): + with pytest.raises(Exception, match="Play.ht API error status code 401"): + await (await fixture_play_ht_synthesizer_wrong_api_key).create_speech( + BaseMessage(text="Hello, world!"), 1024 + ) + +@pytest.mark.asyncio +async def test_with_wrong_user_id( + fixture_play_ht_synthesizer_wrong_user_id: PlayHtSynthesizer, + mock_play_ht_api: aioresponses, +): + with pytest.raises(Exception, match="Play.ht API error status code 401"): + await (await fixture_play_ht_synthesizer_wrong_user_id).create_speech( + BaseMessage(text="Hello, world!"), 1024 + ) + + + +@pytest.mark.asyncio +async def test_with_env_api_key( + fixture_play_ht_synthesizer_env_api_key: PlayHtSynthesizer, + mock_play_ht_api: aioresponses, +): + await assert_synthesis_result_valid(await fixture_play_ht_synthesizer_env_api_key) diff --git a/vocode/streaming/models/synthesizer.py b/vocode/streaming/models/synthesizer.py index effb23600..305517ebe 100644 --- a/vocode/streaming/models/synthesizer.py +++ b/vocode/streaming/models/synthesizer.py @@ -163,10 +163,9 @@ def override_voice_id_with_prompt(cls, voice_id, values): class PlayHtSynthesizerConfig(SynthesizerConfig, type=SynthesizerType.PLAY_HT.value): + api_key: Optional[str] = None + user_id: Optional[str] = None voice_id: str = PLAYHT_DEFAULT_VOICE_ID - speed: Optional[int] = None - seed: Optional[int] = None - temperature: Optional[int] = None class CoquiTTSSynthesizerConfig( diff --git a/vocode/streaming/synthesizer/play_ht_synthesizer.py b/vocode/streaming/synthesizer/play_ht_synthesizer.py index 410919167..0fb476f34 100644 --- a/vocode/streaming/synthesizer/play_ht_synthesizer.py +++ b/vocode/streaming/synthesizer/play_ht_synthesizer.py @@ -27,15 +27,13 @@ class PlayHtSynthesizer(BaseSynthesizer[PlayHtSynthesizerConfig]): def __init__( self, synthesizer_config: PlayHtSynthesizerConfig, - api_key: Optional[str] = None, - user_id: Optional[str] = None, logger: Optional[logging.Logger] = None, aiohttp_session: Optional[ClientSession] = None, ): super().__init__(synthesizer_config, aiohttp_session) self.synthesizer_config = synthesizer_config - self.api_key = api_key or getenv("PLAY_HT_API_KEY") - self.user_id = user_id or getenv("PLAY_HT_USER_ID") + self.api_key = synthesizer_config.api_key or getenv("PLAY_HT_API_KEY") + self.user_id = synthesizer_config.user_id or getenv("PLAY_HT_USER_ID") if not self.api_key or not self.user_id: raise ValueError( "You must set the PLAY_HT_API_KEY and PLAY_HT_USER_ID environment variables" @@ -58,12 +56,6 @@ async def create_speech( "text": message.text, "sample_rate": self.synthesizer_config.sampling_rate, } - if self.synthesizer_config.speed: - body["speed"] = self.synthesizer_config.speed - if self.synthesizer_config.seed: - body["seed"] = self.synthesizer_config.seed - if self.synthesizer_config.temperature: - body["temperature"] = self.synthesizer_config.temperature create_speech_span = tracer.start_span( f"synthesizer.{SynthesizerType.PLAY_HT.value.split('_', 1)[-1]}.create_total",