Skip to content

Commit

Permalink
Merge branch 'dev' into fix_norm_bug
Browse files — browse the repository at this point in the history
  • Loading branch information
fumiama authored Oct 15, 2024
2 parents 1b9a66b + 7859309 commit 1ef796c
Show file tree
Hide file tree
Showing 9 changed files with 65 additions and 20 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/checksum.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ on:

jobs:
checksum:
runs-on: ubuntu-latest
runs-on: ubuntu-24.04
steps:
- uses: actions/checkout@v4

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/close-issue.yml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ on:

jobs:
close-issues:
runs-on: ubuntu-latest
runs-on: ubuntu-24.04
permissions:
issues: write
pull-requests: write
Expand Down
10 changes: 9 additions & 1 deletion .github/workflows/pull-format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ jobs:
# This workflow closes invalid PR
change-or-close-pr:
# The type of runner that the job will run on
runs-on: ubuntu-latest
runs-on: ubuntu-24.04
permissions: write-all

# Steps represent a sequence of tasks that will be executed as part of the job
Expand Down Expand Up @@ -63,6 +63,14 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5

- name: Create venv
run: python3 -m venv .venv

- name: Activate venv
run: |
. .venv/bin/activate
echo PATH=$PATH >> $GITHUB_ENV
- name: Install Black
run: pip install "black[jupyter]"

Expand Down
8 changes: 8 additions & 0 deletions .github/workflows/push-format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,14 @@ jobs:
- name: Set up Python
uses: actions/setup-python@v5

- name: Create venv
run: python3 -m venv .venv

- name: Activate venv
run: |
. .venv/bin/activate
echo PATH=$PATH >> $GITHUB_ENV
- name: Install Black
run: pip install "black[jupyter]"

Expand Down
8 changes: 8 additions & 0 deletions .github/workflows/unitest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ jobs:
run: |
sudo apt-get install -y portaudio19-dev python3-pyaudio
- name: Create venv
run: python3 -m venv .venv

- name: Activate venv
run: |
. .venv/bin/activate
echo PATH=$PATH >> $GITHUB_ENV
- name: Test Install
run: pip install .

Expand Down
8 changes: 5 additions & 3 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,11 @@ def _load(
vocos = (
Vocos(feature_extractor=feature_extractor, backbone=backbone, head=head)
.to(
# vocos on mps will crash, use cpu fallback
# Vocos on mps will crash, use cpu fallback.
# Plus, complex dtype used in the decode process of Vocos is not supported in torch_npu now,
# so we put this calculation of data on CPU instead of NPU.
"cpu"
if "mps" in str(device)
if "mps" in str(device) or "npu" in str(device)
else device
)
.eval()
Expand Down Expand Up @@ -422,7 +424,7 @@ def _infer(

@torch.inference_mode()
def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
if "mps" in str(self.device):
if "mps" in str(self.device) or "npu" in str(self.device):
return self.vocos.decode(spec.cpu()).cpu().numpy()
else:
return self.vocos.decode(spec).cpu().numpy()
Expand Down
8 changes: 4 additions & 4 deletions ChatTTS/norm.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,15 @@
from .utils import del_all


@jit
@jit(nopython=True)
def _find_index(table: np.ndarray, val: np.uint16):
for i in range(table.size):
if table[i] == val:
return i
return -1


@jit
@jit(nopython=True)
def _fast_replace(
table: np.ndarray, text: bytes
) -> Tuple[np.ndarray, List[Tuple[str, str]]]:
Expand All @@ -34,7 +34,7 @@ def _fast_replace(
return result, replaced_words


@jit
@jit(nopython=True)
def _split_tags(text: str) -> Tuple[List[str], List[str]]:
texts: List[str] = []
tags: List[str] = []
Expand All @@ -57,7 +57,7 @@ def _split_tags(text: str) -> Tuple[List[str], List[str]]:
return texts, tags


@jit
@jit(nopython=True)
def _combine_tags(texts: List[str], tags: List[str]) -> str:
text = ""
for t in texts:
Expand Down
37 changes: 28 additions & 9 deletions ChatTTS/utils/gpu.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,36 @@
import torch

try:
import torch_npu
except ImportError:
pass

from .log import logger


def select_device(min_memory=2047, experimental=False):
if torch.cuda.is_available():
selected_gpu = 0
has_cuda = torch.cuda.is_available()
if has_cuda or _is_torch_npu_available():
provider = torch.cuda if has_cuda else torch.npu
"""
Using Ascend NPU to accelerate the process of inferencing when GPU is not found.
"""
dev_idx = 0
max_free_memory = -1
for i in range(torch.cuda.device_count()):
props = torch.cuda.get_device_properties(i)
free_memory = props.total_memory - torch.cuda.memory_reserved(i)
for i in range(provider.device_count()):
props = provider.get_device_properties(i)
free_memory = props.total_memory - provider.memory_reserved(i)
if max_free_memory < free_memory:
selected_gpu = i
dev_idx = i
max_free_memory = free_memory
free_memory_mb = max_free_memory / (1024 * 1024)
if free_memory_mb < min_memory:
logger.get_logger().warning(
f"GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU."
f"{provider.device(dev_idx)} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU."
)
device = torch.device("cpu")
else:
device = torch.device(f"cuda:{selected_gpu}")
device = provider._get_device(dev_idx)
elif torch.backends.mps.is_available():
"""
Currently MPS is slower than CPU while needing more memory and core utility,
Expand All @@ -34,7 +44,16 @@ def select_device(min_memory=2047, experimental=False):
logger.get_logger().info("found Apple GPU, but use CPU.")
device = torch.device("cpu")
else:
logger.get_logger().warning("no GPU found, use CPU instead")
logger.get_logger().warning("no GPU or NPU found, use CPU instead")
device = torch.device("cpu")

return device


def _is_torch_npu_available():
try:
# will raise an AttributeError if torch_npu is not imported, or a RuntimeError if no NPU is found
_ = torch.npu.device_count()
return torch.npu.is_available()
except (AttributeError, RuntimeError):
return False
2 changes: 1 addition & 1 deletion tools/audio/np.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from numba import jit


@jit
@jit(nopython=True)
def float_to_int16(audio: np.ndarray) -> np.ndarray:
am = int(math.ceil(float(np.abs(audio).max())) * 32768)
am = 32767 * 32768 // am
Expand Down

0 comments on commit 1ef796c

Please sign in to comment.