diff --git a/server/main.py b/server/main.py
index 2a2ca91..ccbae2c 100644
--- a/server/main.py
+++ b/server/main.py
@@ -10,9 +10,11 @@ from fastapi import (
     FastAPI,
     UploadFile,
+    Body,
 )
 from pydantic import BaseModel
 from fastapi.responses import StreamingResponse
+
 from TTS.tts.configs.xtts_config import XttsConfig
 from TTS.tts.models.xtts import Xtts
 from TTS.utils.generic_utils import get_user_data_dir
@@ -112,9 +114,11 @@ class StreamingInputs(BaseModel):
         "ja",
     ]
     add_wav_header: bool = True
+    stream_chunk_size: str = "20"
+    decoder: str = "ne_hifigan"
 
 
-async def predict_streaming_generator(parsed_input: StreamingInputs):
+async def predict_streaming_generator(parsed_input: dict = Body(...)):
     speaker_embedding = (
         torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
     )
@@ -123,14 +127,15 @@ async def predict_streaming_generator(parsed_input: StreamingInputs):
     )
     text = parsed_input.text
     language = parsed_input.language
-    decoder = parsed_input.get("decoder", "ne_hifigan")
+    decoder = parsed_input.decoder
 
     if decoder not in ["ne_hifigan", "hifigan"]:
         decoder = "ne_hifigan"
 
-    stream_chunk_size = int(parsed_input.get("stream_chunk_size", "20"))
+    stream_chunk_size = int(parsed_input.stream_chunk_size)
     add_wav_header = parsed_input.add_wav_header
+
     chunks = model.inference_stream(text, language, gpt_cond_latent, speaker_embedding, decoder=decoder, stream_chunk_size=stream_chunk_size)
 
     for i, chunk in enumerate(chunks):
         chunk = postprocess(chunk)
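
For reference, a minimal client-side sketch of how the two new StreamingInputs fields might be exercised. The endpoint URL ("http://localhost:8000/tts_stream") and the placeholder conditioning shapes (512 floats for the speaker embedding, rows of 1024 for the GPT conditioning latent) are assumptions for illustration, not part of the diff above; real values would come from a prior speaker-conditioning call.

import requests

# Placeholder conditioning inputs; in practice these come from a prior
# speaker-conditioning request, not hard-coded zeros (shapes assumed here).
payload = {
    "text": "Hello there.",
    "language": "en",
    "speaker_embedding": [0.0] * 512,
    "gpt_cond_latent": [[0.0] * 1024],
    "add_wav_header": True,
    # Fields added by this diff:
    "stream_chunk_size": "20",  # smaller values trade throughput for latency
    "decoder": "ne_hifigan",    # server falls back to "ne_hifigan" if unrecognized
}

# Stream the WAV bytes to disk as they arrive (endpoint name assumed).
with requests.post("http://localhost:8000/tts_stream", json=payload, stream=True) as resp:
    resp.raise_for_status()
    with open("out.wav", "wb") as f:
        for audio_chunk in resp.iter_content(chunk_size=4096):
            f.write(audio_chunk)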