Skip to content

Commit

Permalink
more options for server
Browse files Browse the repository at this point in the history
  • Loading branch information
gorkemgoknar committed Oct 30, 2023
1 parent d76c97c commit 7472a76
Showing 1 changed file with 8 additions and 3 deletions.
11 changes: 8 additions & 3 deletions server/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,11 @@
from fastapi import (
FastAPI,
UploadFile,
Body,
)
from pydantic import BaseModel
from fastapi.responses import StreamingResponse

from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from TTS.utils.generic_utils import get_user_data_dir
Expand Down Expand Up @@ -112,9 +114,11 @@ class StreamingInputs(BaseModel):
"ja",
]
add_wav_header: bool = True
stream_chunk_size: str = "20"
decoder: str = "ne_hifigan"


async def predict_streaming_generator(parsed_input: StreamingInputs):
async def predict_streaming_generator(parsed_input: dict = Body(...)):
speaker_embedding = (
torch.tensor(parsed_input.speaker_embedding).unsqueeze(0).unsqueeze(-1)
)
Expand All @@ -123,14 +127,15 @@ async def predict_streaming_generator(parsed_input: StreamingInputs):
)
text = parsed_input.text
language = parsed_input.language
decoder = parsed_input.get("decoder", "ne_hifigan")
decoder = parsed_input.decoder

if decoder not in ["ne_hifigan","hifigan"]:
decoder = "ne_hifigan"

stream_chunk_size = int(parsed_input.get("stream_chunk_size", "20"))
stream_chunk_size = int(parsed_input.stream_chunk_size)
add_wav_header = parsed_input.add_wav_header


chunks = model.inference_stream(text, language, gpt_cond_latent, speaker_embedding, decoder=decoder,stream_chunk_size=stream_chunk_size)
for i, chunk in enumerate(chunks):
chunk = postprocess(chunk)
Expand Down

0 comments on commit 7472a76

Please sign in to comment.