Skip to content

Commit

Permalink
adds compute total chars interface (#683)
Browse files Browse the repository at this point in the history
  • Loading branch information
ajar98 authored Aug 16, 2024
1 parent 6e6f37a commit 9cda52c
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 13 deletions.
48 changes: 36 additions & 12 deletions vocode/streaming/synthesizer/azure_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,10 @@ async def get_phrase_filler_audios(self) -> List[FillerAudio]:
audio_data = open(filler_audio_path, "rb").read()
else:
logger.debug(f"Generating filler audio for {filler_phrase.text}")
ssml = self.create_ssml(filler_phrase.text)
ssml = self.create_ssml(
message=filler_phrase.text, synthesizer_config=self.synthesizer_config
)
self.total_chars += self.get_total_chars_from_ssml(ssml)
result = await asyncio.get_event_loop().run_in_executor(
self.thread_pool_executor, self.synthesizer.speak_ssml, ssml
)
Expand Down Expand Up @@ -176,16 +179,35 @@ def add_marks(self, message: str, index=0) -> str:
def word_boundary_cb(self, evt, pool):
pool.add(evt)

def create_ssml(self, message: str) -> str:
voice_language_code = self.synthesizer_config.voice_name[:5]
@classmethod
def compute_total_chars(
    cls, message: BaseMessage, synthesizer_config: AzureSynthesizerConfig
) -> int:
    """Return the character count for ``message`` without synthesizing it.

    An ``SSMLMessage`` contributes its own ``ssml`` payload unchanged; any
    other message has its ``text`` wrapped via ``create_ssml`` using the
    supplied ``synthesizer_config``. The resulting SSML is then measured
    with ``get_total_chars_from_ssml``.
    """
    if isinstance(message, SSMLMessage):
        # Caller already provided SSML; do not rebuild it.
        markup = message.ssml
    else:
        markup = cls.create_ssml(message=message.text, synthesizer_config=synthesizer_config)
    return cls.get_total_chars_from_ssml(markup)

@staticmethod
def get_total_chars_from_ssml(ssml: str) -> int:
    """Return the length of the first capture group of the voice regex.

    Searches ``ssml`` with ``_AZURE_INSIDE_VOICE_REGEX`` (``re.DOTALL`` so
    that ``.`` also spans newlines). Returns ``0`` when the pattern does
    not match at all.
    """
    # NOTE(review): the pattern is defined elsewhere in the module;
    # presumably group 1 is the content inside the <voice> element -- confirm.
    inside_voice = re.search(_AZURE_INSIDE_VOICE_REGEX, ssml, re.DOTALL)
    if inside_voice is None:
        return 0
    return len(inside_voice.group(1))

@staticmethod
def create_ssml(message: str, synthesizer_config: AzureSynthesizerConfig) -> str:
voice_language_code = synthesizer_config.voice_name[:5]
ssml_root = ElementTree.fromstring(
f'<speak version="1.0" xmlns="https://www.w3.org/2001/10/synthesis" xml:lang="{voice_language_code}"></speak>'
)
voice = ElementTree.SubElement(ssml_root, "voice")
voice.set("name", self.voice_name)
if self.synthesizer_config.language_code != "en-US":
voice.set("name", synthesizer_config.voice_name)
if synthesizer_config.language_code != "en-US":
lang = ElementTree.SubElement(voice, "{%s}lang" % NAMESPACES.get(""))
lang.set("xml:lang", self.synthesizer_config.language_code)
lang.set("xml:lang", synthesizer_config.language_code)
voice_root = lang
else:
voice_root = voice
Expand All @@ -198,13 +220,10 @@ def create_ssml(self, message: str) -> str:
silence.set("value", "500ms")
silence.set("type", "Tailing-exact")
prosody = ElementTree.SubElement(voice_root, "prosody")
prosody.set("pitch", f"{self.pitch}%")
prosody.set("rate", f"{self.rate}%")
prosody.set("pitch", f"{synthesizer_config.pitch}%")
prosody.set("rate", f"{synthesizer_config.rate}%")
prosody.text = message.strip()
ssml = ElementTree.tostring(ssml_root, encoding="unicode")
regmatch = re.search(_AZURE_INSIDE_VOICE_REGEX, ssml, re.DOTALL)
if regmatch:
self.total_chars += len(regmatch.group(1))
return ssml

def synthesize_ssml(self, ssml: str) -> speechsdk.AudioDataStream:
Expand Down Expand Up @@ -293,7 +312,12 @@ async def chunk_generator(
self.synthesizer.synthesis_word_boundary.connect(
lambda event: self.word_boundary_cb(event, word_boundary_event_pool)
)
ssml = message.ssml if isinstance(message, SSMLMessage) else self.create_ssml(message.text)
ssml = (
message.ssml
if isinstance(message, SSMLMessage)
else self.create_ssml(message=message.text, synthesizer_config=self.synthesizer_config)
)
self.total_chars += self.get_total_chars_from_ssml(ssml)
audio_data_stream = await asyncio.get_event_loop().run_in_executor(
self.thread_pool_executor, self.synthesize_ssml, ssml
)
Expand Down
8 changes: 7 additions & 1 deletion vocode/streaming/synthesizer/base_synthesizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ def create_synthesis_result(self, chunk_size) -> SynthesisResult:

class BaseSynthesizer(Generic[SynthesizerConfigType]):
streaming_conversation: "StreamingConversation"
total_chars: int

def __init__(
self,
Expand Down Expand Up @@ -277,9 +278,14 @@ def get_typing_noise_filler_audio(self) -> FillerAudio:
seconds_per_chunk=2,
)

def get_cost(self) -> float:
@classmethod
def get_cost(cls, total_chars) -> float:
    """Return the cost for ``total_chars``; not implemented on this class.

    Always raises ``NotImplementedError`` here.
    NOTE(review): presumably subclasses override this with provider
    pricing, and ``total_chars`` is a character count (the class declares
    ``total_chars: int``) -- confirm against the concrete synthesizers.
    """
    raise NotImplementedError

@classmethod
def compute_total_chars(cls, message: BaseMessage, synthesizer_config: SynthesizerConfigType):
    """Return the number of characters in ``message.text``.

    ``synthesizer_config`` is accepted but not used by this implementation.
    """
    text = message.text
    return len(text)

async def set_filler_audios(self, filler_audio_config: FillerAudioConfig):
if filler_audio_config.use_phrases:
self.filler_audios = await self.get_phrase_filler_audios()
Expand Down

0 comments on commit 9cda52c

Please sign in to comment.