Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add NPU Support #777

Merged
merged 10 commits into from
Oct 10, 2024
8 changes: 5 additions & 3 deletions ChatTTS/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,11 @@ def _load(
vocos = (
Vocos(feature_extractor=feature_extractor, backbone=backbone, head=head)
.to(
# vocos on mps will crash, use cpu fallback
# Vocos on mps will crash, use cpu fallback.
# Plus, complex dtype used in the decode process of Vocos is not supported in torch_npu now,
# so we put this calculation of data on CPU instead of NPU.
"cpu"
if "mps" in str(device)
if "mps" in str(device) or "npu" in str(device)
else device
)
.eval()
Expand Down Expand Up @@ -422,7 +424,7 @@ def _infer(

@torch.inference_mode()
def _vocos_decode(self, spec: torch.Tensor) -> np.ndarray:
if "mps" in str(self.device):
if "mps" in str(self.device) or "npu" in str(self.device):
return self.vocos.decode(spec.cpu()).cpu().numpy()
else:
return self.vocos.decode(spec).cpu().numpy()
Expand Down
36 changes: 35 additions & 1 deletion ChatTTS/utils/gpu.py
fumiama marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
import torch

try:
import torch_npu
except ImportError:
pass

from .log import logger


Expand All @@ -21,6 +26,26 @@ def select_device(min_memory=2047, experimental=False):
device = torch.device("cpu")
else:
device = torch.device(f"cuda:{selected_gpu}")
elif _is_torch_npu_available():
"""
Using Ascend NPU to accelerate the process of inferencing when GPU is not found.
"""
selected_npu = 0
max_free_memory = -1
for i in range(torch.npu.device_count()):
props = torch.npu.get_device_properties(i)
free_memory = props.total_memory - torch.npu.memory_reserved(i)
if max_free_memory < free_memory:
selected_npu = i
max_free_memory = free_memory
free_memory_mb = max_free_memory / (1024 * 1024)
if free_memory_mb < min_memory:
logger.get_logger().warning(
f"NPU {selected_npu} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU."
)
device = torch.device("cpu")
else:
device = torch.device(f"npu:{selected_npu}")
elif torch.backends.mps.is_available():
"""
Currently MPS is slower than CPU while needs more memory and core utility,
Expand All @@ -34,7 +59,16 @@ def select_device(min_memory=2047, experimental=False):
logger.get_logger().info("found Apple GPU, but use CPU.")
device = torch.device("cpu")
else:
logger.get_logger().warning("no GPU found, use CPU instead")
logger.get_logger().warning("no GPU or NPU found, use CPU instead")
device = torch.device("cpu")

return device


def _is_torch_npu_available():
try:
fumiama marked this conversation as resolved.
Show resolved Hide resolved
# will raise a AttributeError if torch_npu is not imported or a RuntimeError if no NPU found
_ = torch.npu.device_count()
return torch.npu.is_available()
except (AttributeError, RuntimeError):
return False
Loading