Adding custom model loading
WeberJulian committed Nov 29, 2023
1 parent a0cc660 commit 818a108
Showing 4 changed files with 28 additions and 6 deletions.
15 changes: 14 additions & 1 deletion README.md
@@ -46,8 +46,21 @@ docker build -t xtts-stream . -f Dockerfile.cuda121
 2. Run the server container:
 
 ```bash
-$ docker run --gpus=all -e COQUI_TOS_AGREED=1 --rm -p 8000:80 xtts-stream
+$ docker run --gpus all -e COQUI_TOS_AGREED=1 --rm -p 8000:80 xtts-stream
 ```
 
 Setting the `COQUI_TOS_AGREED` environment variable to `1` indicates you have read and agreed to
 the terms of the [CPML license](https://coqui.ai/cpml).
+
+2. (bis) Run the server container with your own model:
+
+```bash
+docker run -v /path/to/model/folder:/app/tts_models --gpus all --rm -p 8000:80 xtts-stream
+```
+
+Make sure the model folder contains the following files:
+- `config.json`
+- `model.pth`
+- `vocab.json`
+
+(Fine-tuned XTTS models are also under the [CPML license](https://coqui.ai/cpml))
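
Before mounting a folder into the container, the file requirements above can be checked with a short script. This is a minimal sketch: the required file names (`config.json`, `model.pth`, `vocab.json`) come from the README, while `validate_model_folder` and the example path are illustrative.

```python
import os

# Files the server expects in a custom XTTS model folder (per the README).
REQUIRED_FILES = ("config.json", "model.pth", "vocab.json")

def validate_model_folder(folder):
    """Return the list of required XTTS files missing from `folder`."""
    return [name for name in REQUIRED_FILES
            if not os.path.isfile(os.path.join(folder, name))]

missing = validate_model_folder("/path/to/model/folder")
if missing:
    print("Model folder is incomplete, missing:", ", ".join(missing))
```

Running this before `docker run` surfaces an incomplete folder immediately, rather than at server startup inside the container.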
1 change: 1 addition & 0 deletions server/Dockerfile
@@ -11,6 +11,7 @@ RUN python -m pip install --use-deprecated=legacy-resolver -r requirements.txt \
     && python -m pip cache purge
 
 RUN python -m unidic download
+RUN mkdir -p /app/tts_models
 
 COPY main.py .
 ENV NVIDIA_DISABLE_REQUIRE=1
1 change: 1 addition & 0 deletions server/Dockerfile.cuda121
@@ -11,6 +11,7 @@ RUN python -m pip install --use-deprecated=legacy-resolver -r requirements.txt \
     && python -m pip cache purge
 
 RUN python -m unidic download
+RUN mkdir -p /app/tts_models
 
 COPY main.py .
 
17 changes: 12 additions & 5 deletions server/main.py
@@ -23,11 +23,18 @@
 torch.set_num_threads(int(os.environ.get("NUM_THREADS", "2")))
 device = torch.device("cuda")
 
-model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
-print("Downloading XTTS Model:", model_name, flush=True)
-ModelManager().download_model(model_name)
-model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
-print("XTTS Model downloaded", flush=True)
+custom_model_path = os.environ.get("CUSTOM_MODEL_PATH", "/app/tts_models")
+
+if os.path.exists(custom_model_path) and os.path.isfile(custom_model_path + "/config.json"):
+    model_path = custom_model_path
+    print("Loading custom model from", model_path, flush=True)
+else:
+    print("Loading default model", flush=True)
+    model_name = "tts_models/multilingual/multi-dataset/xtts_v2"
+    print("Downloading XTTS Model:", model_name, flush=True)
+    ModelManager().download_model(model_name)
+    model_path = os.path.join(get_user_data_dir("tts"), model_name.replace("/", "--"))
+    print("XTTS Model downloaded", flush=True)
 
 print("Loading XTTS", flush=True)
 config = XttsConfig()
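
The model-selection logic introduced in `server/main.py` can be sketched in isolation: prefer a mounted custom model when its `config.json` is present, otherwise fall back to the stock download path. This is a hedged sketch, not the commit's code verbatim: `resolve_model_path` is a hypothetical helper, the download step is replaced by a boolean flag, and `os.path.isfile` subsumes the separate `os.path.exists` check from the diff.

```python
import os

def resolve_model_path(default_path, env=os.environ):
    """Prefer a mounted custom model when its config.json exists;
    otherwise signal that the default model should be downloaded."""
    custom = env.get("CUSTOM_MODEL_PATH", "/app/tts_models")
    if os.path.isfile(os.path.join(custom, "config.json")):
        return custom, True    # custom model found, no download needed
    return default_path, False  # caller downloads the default XTTS model

path, is_custom = resolve_model_path("/data/xtts_v2", env={})
print("custom" if is_custom else "default", path)
```

Keeping the decision in one function makes the branch easy to unit-test by passing a fake `env` mapping instead of patching `os.environ`.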
