diff --git a/.github/workflows/checksum.yml b/.github/workflows/checksum.yml
index 162c43e12..4a92578a7 100644
--- a/.github/workflows/checksum.yml
+++ b/.github/workflows/checksum.yml
@@ -4,7 +4,7 @@ on:
 
 jobs:
   checksum:
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-24.04
 
     steps:
       - uses: actions/checkout@v4
diff --git a/.github/workflows/close-issue.yml b/.github/workflows/close-issue.yml
index 32c54d6ae..37e526eb3 100644
--- a/.github/workflows/close-issue.yml
+++ b/.github/workflows/close-issue.yml
@@ -5,7 +5,7 @@ on:
 
 jobs:
   close-issues:
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-24.04
     permissions:
       issues: write
       pull-requests: write
diff --git a/.github/workflows/pull-format.yml b/.github/workflows/pull-format.yml
index 57a783bd9..b5ebf5826 100644
--- a/.github/workflows/pull-format.yml
+++ b/.github/workflows/pull-format.yml
@@ -8,7 +8,7 @@ jobs:
   # This workflow closes invalid PR
   change-or-close-pr:
     # The type of runner that the job will run on
-    runs-on: ubuntu-latest
+    runs-on: ubuntu-24.04
     permissions: write-all
 
     # Steps represent a sequence of tasks that will be executed as part of the job
@@ -63,6 +63,14 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
 
+      - name: Create venv
+        run: python3 -m venv .venv
+
+      - name: Activate venv
+        run: |
+          . .venv/bin/activate
+          echo PATH=$PATH >> $GITHUB_ENV
+
       - name: Install Black
         run: pip install "black[jupyter]"
 
diff --git a/.github/workflows/push-format.yml b/.github/workflows/push-format.yml
index 15fe6caca..0d3a9b6f2 100644
--- a/.github/workflows/push-format.yml
+++ b/.github/workflows/push-format.yml
@@ -24,6 +24,14 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
 
+      - name: Create venv
+        run: python3 -m venv .venv
+
+      - name: Activate venv
+        run: |
+          . .venv/bin/activate
+          echo PATH=$PATH >> $GITHUB_ENV
+
       - name: Install Black
         run: pip install "black[jupyter]"
 
diff --git a/.github/workflows/unitest.yml b/.github/workflows/unitest.yml
index 77f44c33d..e0395e813 100644
--- a/.github/workflows/unitest.yml
+++ b/.github/workflows/unitest.yml
@@ -25,6 +25,14 @@ jobs:
         run: |
           sudo apt-get install -y portaudio19-dev python3-pyaudio
 
+      - name: Create venv
+        run: python3 -m venv .venv
+
+      - name: Activate venv
+        run: |
+          . .venv/bin/activate
+          echo PATH=$PATH >> $GITHUB_ENV
+
       - name: Test Install
         run: pip install .
 
diff --git a/ChatTTS/core.py b/ChatTTS/core.py
index c38ad8957..4447dca54 100644
--- a/ChatTTS/core.py
+++ b/ChatTTS/core.py
@@ -253,9 +253,11 @@ def _load(
         vocos = (
             Vocos(feature_extractor=feature_extractor, backbone=backbone, head=head)
             .to(
-                # vocos on mps will crash, use cpu fallback
+                # Vocos on mps will crash, so fall back to CPU.
+                # Also, the complex dtype used in Vocos's decode step is not yet
+                # supported by torch_npu, so we run this computation on CPU instead of NPU.
"cpu" - if "mps" in str(device) + if "mps" in str(device) or "npu" in str(device) else device ) .eval() @@ -422,7 +424,7 @@ def _infer( @torch.inference_mode() def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray: - if "mps" in str(self.device): + if "mps" in str(self.device) or "npu" in str(self.device): return self.vocos.decode(spec.cpu()).cpu().numpy() else: return self.vocos.decode(spec).cpu().numpy() diff --git a/ChatTTS/norm.py b/ChatTTS/norm.py index 854c42a25..355a59cb7 100644 --- a/ChatTTS/norm.py +++ b/ChatTTS/norm.py @@ -10,7 +10,7 @@ from .utils import del_all -@jit +@jit(nopython=True) def _find_index(table: np.ndarray, val: np.uint16): for i in range(table.size): if table[i] == val: @@ -18,7 +18,7 @@ def _find_index(table: np.ndarray, val: np.uint16): return -1 -@jit +@jit(nopython=True) def _fast_replace( table: np.ndarray, text: bytes ) -> Tuple[np.ndarray, List[Tuple[str, str]]]: @@ -34,7 +34,7 @@ def _fast_replace( return result, replaced_words -@jit +@jit(nopython=True) def _split_tags(text: str) -> Tuple[List[str], List[str]]: texts: List[str] = [] tags: List[str] = [] @@ -57,7 +57,7 @@ def _split_tags(text: str) -> Tuple[List[str], List[str]]: return texts, tags -@jit +@jit(nopython=True) def _combine_tags(texts: List[str], tags: List[str]) -> str: text = "" for t in texts: diff --git a/ChatTTS/utils/gpu.py b/ChatTTS/utils/gpu.py index 58aeb3eea..40698be94 100644 --- a/ChatTTS/utils/gpu.py +++ b/ChatTTS/utils/gpu.py @@ -1,26 +1,36 @@ import torch +try: + import torch_npu +except ImportError: + pass + from .log import logger def select_device(min_memory=2047, experimental=False): - if torch.cuda.is_available(): - selected_gpu = 0 + has_cuda = torch.cuda.is_available() + if has_cuda or _is_torch_npu_available(): + provider = torch.cuda if has_cuda else torch.npu + """ + Using Ascend NPU to accelerate the process of inferencing when GPU is not found. + """ + dev_idx = 0 max_free_memory = -1 - for i in range(torch.cuda.device_count()): - props = torch.cuda.get_device_properties(i) - free_memory = props.total_memory - torch.cuda.memory_reserved(i) + for i in range(provider.device_count()): + props = provider.get_device_properties(i) + free_memory = props.total_memory - provider.memory_reserved(i) if max_free_memory < free_memory: - selected_gpu = i + dev_idx = i max_free_memory = free_memory free_memory_mb = max_free_memory / (1024 * 1024) if free_memory_mb < min_memory: logger.get_logger().warning( - f"GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU." + f"{provider.device(dev_idx)} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU." 
) device = torch.device("cpu") else: - device = torch.device(f"cuda:{selected_gpu}") + device = provider._get_device(dev_idx) elif torch.backends.mps.is_available(): """ Currently MPS is slower than CPU while needs more memory and core utility, @@ -34,7 +44,16 @@ def select_device(min_memory=2047, experimental=False): logger.get_logger().info("found Apple GPU, but use CPU.") device = torch.device("cpu") else: - logger.get_logger().warning("no GPU found, use CPU instead") + logger.get_logger().warning("no GPU or NPU found, use CPU instead") device = torch.device("cpu") return device + + +def _is_torch_npu_available(): + try: + # will raise a AttributeError if torch_npu is not imported or a RuntimeError if no NPU found + _ = torch.npu.device_count() + return torch.npu.is_available() + except (AttributeError, RuntimeError): + return False diff --git a/tools/audio/np.py b/tools/audio/np.py index a1aee2047..b0e082fd8 100644 --- a/tools/audio/np.py +++ b/tools/audio/np.py @@ -4,7 +4,7 @@ from numba import jit -@jit +@jit(nopython=True) def float_to_int16(audio: np.ndarray) -> np.ndarray: am = int(math.ceil(float(np.abs(audio).max())) * 32768) am = 32767 * 32768 // am