From 4fc282e1e62986dc68b20b65dbb73a6881275d93 Mon Sep 17 00:00:00 2001 From: Enno Hermann Date: Thu, 7 Nov 2024 00:33:54 +0100 Subject: [PATCH] refactor: use external package for monotonic alignment --- TTS/tts/utils/helpers.py | 42 +++++++++----------- TTS/tts/utils/monotonic_align/__init__.py | 0 TTS/tts/utils/monotonic_align/core.pyx | 47 ----------------------- pyproject.toml | 1 + 4 files changed, 20 insertions(+), 70 deletions(-) delete mode 100644 TTS/tts/utils/monotonic_align/__init__.py delete mode 100644 TTS/tts/utils/monotonic_align/core.pyx diff --git a/TTS/tts/utils/helpers.py b/TTS/tts/utils/helpers.py index 7429d0fcc8..ab7730ad8e 100644 --- a/TTS/tts/utils/helpers.py +++ b/TTS/tts/utils/helpers.py @@ -1,14 +1,10 @@ import numpy as np import torch +from monotonic_alignment_search import maximum_path as maximum_path_cython from scipy.stats import betabinom from torch.nn import functional as F -try: - from TTS.tts.utils.monotonic_align.core import maximum_path_c - - CYTHON = True -except ModuleNotFoundError: - CYTHON = False +CYTHON = True class StandardScaler: @@ -174,23 +170,23 @@ def maximum_path(value, mask): return maximum_path_numpy(value, mask) -def maximum_path_cython(value, mask): - """Cython optimised version. - Shapes: - - value: :math:`[B, T_en, T_de]` - - mask: :math:`[B, T_en, T_de]` - """ - value = value * mask - device = value.device - dtype = value.dtype - value = value.data.cpu().numpy().astype(np.float32) - path = np.zeros_like(value).astype(np.int32) - mask = mask.data.cpu().numpy() - - t_x_max = mask.sum(1)[:, 0].astype(np.int32) - t_y_max = mask.sum(2)[:, 0].astype(np.int32) - maximum_path_c(path, value, t_x_max, t_y_max) - return torch.from_numpy(path).to(device=device, dtype=dtype) +# def maximum_path_cython(value, mask): +# """Cython optimised version. +# Shapes: +# - value: :math:`[B, T_en, T_de]` +# - mask: :math:`[B, T_en, T_de]` +# """ +# value = value * mask +# device = value.device +# dtype = value.dtype +# value = value.data.cpu().numpy().astype(np.float32) +# path = np.zeros_like(value).astype(np.int32) +# mask = mask.data.cpu().numpy() + +# t_x_max = mask.sum(1)[:, 0].astype(np.int32) +# t_y_max = mask.sum(2)[:, 0].astype(np.int32) +# maximum_path_c(path, value, t_x_max, t_y_max) +# return torch.from_numpy(path).to(device=device, dtype=dtype) def maximum_path_numpy(value, mask, max_neg_val=None): diff --git a/TTS/tts/utils/monotonic_align/__init__.py b/TTS/tts/utils/monotonic_align/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/TTS/tts/utils/monotonic_align/core.pyx b/TTS/tts/utils/monotonic_align/core.pyx deleted file mode 100644 index 091fcc3a50..0000000000 --- a/TTS/tts/utils/monotonic_align/core.pyx +++ /dev/null @@ -1,47 +0,0 @@ -import numpy as np - -cimport cython -cimport numpy as np - -from cython.parallel import prange - - -@cython.boundscheck(False) -@cython.wraparound(False) -cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil: - cdef int x - cdef int y - cdef float v_prev - cdef float v_cur - cdef float tmp - cdef int index = t_x - 1 - - for y in range(t_y): - for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)): - if x == y: - v_cur = max_neg_val - else: - v_cur = value[x, y-1] - if x == 0: - if y == 0: - v_prev = 0. - else: - v_prev = max_neg_val - else: - v_prev = value[x-1, y-1] - value[x, y] = max(v_cur, v_prev) + value[x, y] - - for y in range(t_y - 1, -1, -1): - path[index, y] = 1 - if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]): - index = index - 1 - - -@cython.boundscheck(False) -@cython.wraparound(False) -cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil: - cdef int b = values.shape[0] - - cdef int i - for i in prange(b, nogil=True): - maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val) diff --git a/pyproject.toml b/pyproject.toml index 23387fd37d..d13e2145d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ dependencies = [ # Coqui stack "coqui-tts-trainer>=0.1.4,<0.2.0", "coqpit>=0.0.16", + "monotonic-alignment-search>=0.1.0", # Gruut + supported languages "gruut[de,es,fr]>=2.4.0", # Tortoise