From 79174a70320c5bc2ad21f51033f523ae27eede52 Mon Sep 17 00:00:00 2001
From: Michael Kuhlmann
Date: Thu, 4 Apr 2024 12:46:57 +0200
Subject: [PATCH 1/3] contrib: Add fast recursive scandir

Recursively find all files matching the extensions.
Taken from https://stackoverflow.com/a/59803793/16085876
---
 padertorch/contrib/mk/io.py | 22 ++++++++++++++++++++++
 1 file changed, 22 insertions(+)
 create mode 100644 padertorch/contrib/mk/io.py

diff --git a/padertorch/contrib/mk/io.py b/padertorch/contrib/mk/io.py
new file mode 100644
index 00000000..76417aa9
--- /dev/null
+++ b/padertorch/contrib/mk/io.py
@@ -0,0 +1,22 @@
+import os
+from pathlib import Path
+from typing import List
+
+
+# https://stackoverflow.com/a/59803793/16085876
+def run_fast_scandir(dir: Path, ext: List[str]):
+    subfolders, files = [], []
+
+    for f in os.scandir(dir):
+        if f.is_dir():
+            subfolders.append(f.path)
+        if f.is_file():
+            if os.path.splitext(f.name)[1].lower() in ext:
+                files.append(Path(f.path))
+
+
+    for dir in list(subfolders):
+        sf, f = run_fast_scandir(dir, ext)
+        subfolders.extend(sf)
+        files.extend(f)
+    return subfolders, files

From ff94a2a4c5f88121c01c95a7b5fc3a061fbd2467 Mon Sep 17 00:00:00 2001
From: Michael Kuhlmann
Date: Thu, 4 Apr 2024 14:27:07 +0200
Subject: [PATCH 2/3] contrib: Add tensorboard utils

---
 padertorch/contrib/mk/tbx_utils.py | 93 ++++++++++++++++++++++++++++++
 1 file changed, 93 insertions(+)
 create mode 100644 padertorch/contrib/mk/tbx_utils.py

diff --git a/padertorch/contrib/mk/tbx_utils.py b/padertorch/contrib/mk/tbx_utils.py
new file mode 100644
index 00000000..4c593e01
--- /dev/null
+++ b/padertorch/contrib/mk/tbx_utils.py
@@ -0,0 +1,93 @@
+import typing
+
+import numpy as np
+from padertorch.utils import to_numpy
+from padertorch.summary.tbx_utils import spectrogram_to_image
+import torch
+from torch import Tensor
+from torchvision.utils import make_grid
+
+
+def tensor_to_image(signal: Tensor, input_type: str):
+    x = to_numpy(signal, detach=True)
+    if input_type == 'image':
+        x = (x * 255).astype(np.uint8)
+    elif input_type == 'spectrogram':
+        x = spectrogram_to_image(
+            x.transpose(-1, -2), batch_first=None, log=False
+        )
+    else:
+        raise ValueError(f'Unknown input type {input_type}')
+    return x
+
+
+def batch_image_to_grid(
+    batch_image: torch.Tensor,
+    input_shape_format: str = 'b c h w',
+    height_axis: str = 'h',
+    width_axis: str = 'w',
+    stack: typing.Optional[str] = None,
+    origin: str = 'upper',
+    normalize: bool = True,
+    scale_each: bool = False,
+):
+    """
+    >>> batch_image = torch.rand(4, 3, 32, 32)
+    >>> grid = batch_image_to_grid(batch_image)
+    >>> grid.shape
+    torch.Size([3, 138, 36])
+    >>> grid = batch_image_to_grid(\
+        torch.rand(4, 32, 32),\
+        input_shape_format='b h w'\
+    )
+    >>> grid.shape
+    torch.Size([138, 36])
+
+    Args:
+        batch_image: Batched images of shape (batch, channel, height, width)
+            or (batch, height, width).
+        input_shape_format: Format of the input shape. Should be a string of
+            space-separated dimension names, e.g., 'b c h w'.
+        height_axis: Name of the height (frequency) axis.
+        width_axis: Name of the width (time) axis.
+        stack: How to stack the images: `height_axis` stacks them along the
+            height axis (vertically), `width_axis` along the width axis
+            (horizontally).
+        origin: Origin of the plot. Can be `'upper'` or `'lower'`.
+        normalize: See make_grid
+        scale_each: See make_grid
+    """
+    if origin not in ('upper', 'lower'):
+        raise ValueError(f'"origin" should be "upper" or "lower" but got {origin}')
+
+    if stack is None:
+        stack = height_axis
+
+    if stack not in (height_axis, width_axis):
+        raise ValueError(
+            f'"stack" should be "{height_axis}" or '
+            f'"{width_axis}" but got {stack}'
+        )
+
+    dims = input_shape_format.split()
+    if len(dims) != batch_image.ndim:
+        raise ValueError(f'Shape format {input_shape_format} does not match input shape {batch_image.shape}')
+
+    if batch_image.ndim == 3:
+        # Add channel dimension
+        batch_image = batch_image.unsqueeze(1)
+        dims.insert(1, 'c')
+
+    if origin == 'lower':
+        # Reverse the order of the height (frequency) dimension
+        batch_image = batch_image.flip(dims.index(height_axis))
+
+    grid = make_grid(
+        batch_image,
+        normalize=normalize,
+        scale_each=scale_each,
+        nrow=1 if stack == height_axis else batch_image.shape[0],
+    )
+    if batch_image.shape[1] == 1:
+        # Remove color dimension
+        grid = grid[0]
+    return grid

From 9a6bfe85693cbff26ce8dfca4a0c088e471fe00c Mon Sep 17 00:00:00 2001
From: Michael Kuhlmann
Date: Thu, 4 Apr 2024 17:09:49 +0200
Subject: [PATCH 3/3] contrib: Add audio synthesis from spectrogram

- Fast Griffin-Lim algorithm
- Neural vocoder based on https://github.com/kan-bayashi/ParallelWaveGAN
---
 padertorch/contrib/mk/synthesis/__init__.py  |   3 +
 padertorch/contrib/mk/synthesis/base.py      |  84 +++++
 .../mk/synthesis/parametric/__init__.py      |   1 +
 .../mk/synthesis/parametric/griffin_lim.py   | 214 +++++++++++
 .../contrib/mk/synthesis/vocoder/__init__.py |   1 +
 .../contrib/mk/synthesis/vocoder/pwg.py      | 353 ++++++++++++++++++
 6 files changed, 656 insertions(+)
 create mode 100644 padertorch/contrib/mk/synthesis/__init__.py
 create mode 100644 padertorch/contrib/mk/synthesis/base.py
 create mode 100644 padertorch/contrib/mk/synthesis/parametric/__init__.py
 create mode 100644 padertorch/contrib/mk/synthesis/parametric/griffin_lim.py
 create mode 100644 padertorch/contrib/mk/synthesis/vocoder/__init__.py
 create mode 100644 padertorch/contrib/mk/synthesis/vocoder/pwg.py

diff --git a/padertorch/contrib/mk/synthesis/__init__.py b/padertorch/contrib/mk/synthesis/__init__.py
new file mode 100644
index 00000000..f5eeac9c
--- /dev/null
+++ b/padertorch/contrib/mk/synthesis/__init__.py
@@ -0,0 +1,3 @@
+from .vocoder import Vocoder
+from .parametric import fast_griffin_lim, FGLA
+from .legacy import Converter
diff --git a/padertorch/contrib/mk/synthesis/base.py b/padertorch/contrib/mk/synthesis/base.py
new file mode 100644
index 00000000..c1f34d9b
--- /dev/null
+++ b/padertorch/contrib/mk/synthesis/base.py
@@ -0,0 +1,84 @@
+import typing
+from functools import partial
+
+import numpy as np
+import torch
+from paderbox.transform.module_resample import resample_sox
+import padertorch as pt
+
+
+class Synthesis(pt.Configurable):
+    sampling_rate: int
+
+    def __init__(
+        self,
+        postprocessing: typing.Optional[typing.Callable] = None,
+    ):
+        super().__init__()
+        self.postprocessing = postprocessing
+
+    def __call__(
+        self,
+        time_signal: typing.Union[
+            np.ndarray, torch.Tensor, typing.List[np.ndarray],
+            typing.List[torch.Tensor]
+        ],
+        target_sampling_rate: typing.Optional[int] = None,
+    ) -> typing.Union[
+        np.ndarray, torch.Tensor, typing.List[np.ndarray],
+        typing.List[torch.Tensor]
+    ]:
+        if self.postprocessing is not None:
+            if isinstance(time_signal, list) or time_signal.ndim == 2:
+                time_signal = list(map(self.postprocessing, time_signal))
+            else:
+                time_signal = self.postprocessing(time_signal)
+        return self.resample(time_signal, target_sampling_rate)
+
+    def _resample(
+        self,
+        wav: typing.Union[np.ndarray, torch.Tensor],
+        target_sampling_rate: typing.Optional[int] = None,
+    ) -> typing.Union[np.ndarray, torch.Tensor]:
+        to_torch = False
+        if (
+            target_sampling_rate is None
+            or target_sampling_rate == self.sampling_rate
+        ):
+            return wav
+        if isinstance(wav, torch.Tensor):
+            to_torch = True
+            wav = pt.utils.to_numpy(wav, detach=True)
+        wav = resample_sox(
+            wav,
+            in_rate=self.sampling_rate,
+            out_rate=target_sampling_rate
+        )
+        if to_torch:
+            wav = torch.from_numpy(wav)
+        return wav
+
+    def resample(
+        self,
+        wav: typing.Union[
+            np.ndarray, torch.Tensor, typing.List[np.ndarray],
+            typing.List[torch.Tensor]
+        ],
+        target_sampling_rate: typing.Optional[int] = None,
+    ) -> typing.Union[
+        np.ndarray, torch.Tensor, typing.List[np.ndarray],
+        typing.List[torch.Tensor]
+    ]:
+        if isinstance(wav, list) or wav.ndim == 2:
+            wav = list(map(
+                partial(
+                    self._resample, target_sampling_rate=target_sampling_rate
+                ), wav
+            ))
+            try:
+                m = np if isinstance(wav[0], np.ndarray) else torch
+                wav = m.stack(wav)
+            except (ValueError, RuntimeError):
+                pass
+            return wav
+        return self._resample(wav, target_sampling_rate=target_sampling_rate)
diff --git a/padertorch/contrib/mk/synthesis/parametric/__init__.py b/padertorch/contrib/mk/synthesis/parametric/__init__.py
new file mode 100644
index 00000000..a0477f32
--- /dev/null
+++ b/padertorch/contrib/mk/synthesis/parametric/__init__.py
@@ -0,0 +1 @@
+from .griffin_lim import fast_griffin_lim, FGLA
diff --git a/padertorch/contrib/mk/synthesis/parametric/griffin_lim.py b/padertorch/contrib/mk/synthesis/parametric/griffin_lim.py
new file mode 100644
index 00000000..c0a6f378
--- /dev/null
+++ b/padertorch/contrib/mk/synthesis/parametric/griffin_lim.py
@@ -0,0 +1,214 @@
+import typing
+
+import numpy as np
+import torch
+from paderbox.transform import STFT as pbSTFT
+import padertorch as pt
+from padertorch.ops import STFT as ptSTFT
+
+from ..base import Synthesis
+
+
+__all__ = [
+    'fast_griffin_lim',
+    'FGLA',
+]
+
+
+def reshape_complex(signal, complex_representation):
+    if complex_representation in (None, 'complex'):
+        return signal
+    if complex_representation == 'stacked':
+        signal = torch.stack(
+            (signal.real, signal.imag), dim=-1
+        )
+    else:
+        signal = torch.cat(
+            (signal.real, signal.imag), dim=-1
+        )
+    return signal
+
+
+def griffin_lim_step(
+    a: typing.Union[np.ndarray, torch.Tensor],
+    reconstruction_stft: typing.Union[np.ndarray, torch.Tensor],
+    stft: typing.Union[pbSTFT, ptSTFT],
+    backend=None,
+):
+    """
+    Args:
+        a: Target magnitude spectrogram.
+        reconstruction_stft: Complex STFT estimate of the current iterate.
+        stft: STFT instance used for analysis and synthesis.
+        backend: `np` or `torch`. If None, inferred from the type of `a`.
+
+    Returns: Updated complex STFT estimate and the corresponding time signal.
+
+    """
+    if backend is None:
+        if isinstance(a, np.ndarray):
+            backend = np
+        else:
+            backend = torch
+
+    # From paderbox.transform.module_phase_reconstruction
+    reconstruction_angle = backend.angle(reconstruction_stft)
+    proposal_spec = a * backend.exp(1.0j * reconstruction_angle)  # P_A
+
+    audio = stft.inverse(
+        reshape_complex(
+            proposal_spec, getattr(stft, 'complex_representation', None)
+        )
+    )  # P_C
+    stft_signal = stft(audio)
+    if isinstance(stft_signal, np.ndarray):
+        return stft_signal, audio
+    if stft.complex_representation != 'complex':
+        if stft.complex_representation == 'stacked':
+            stft_signal = stft_signal[..., 0] + 1j * stft_signal[..., 1]
+        else:
+            size = stft_signal.shape[-1]
+            stft_signal = (
+                stft_signal[..., :size//2] + 1j * stft_signal[..., size//2:]
+            )
+    return stft_signal, audio
+
+
+def fast_griffin_lim(
+    a: typing.Union[np.ndarray, torch.Tensor],
+    stft: typing.Union[pbSTFT, ptSTFT],
+    alpha=0.95,
+    iterations=100,
+    atol: float = 0.1,
+    verbose=False,
+    x=None,
+):
+    """Griffin-Lim algorithm with momentum for phase retrieval [1].
+
+    >>> f_0 = 200 # Hz
+    >>> f_s = 16_000 # Hz
+    >>> t = np.linspace(0, 1, num=f_s)
+    >>> sine = np.sin(2*np.pi*f_0*t)
+    >>> sine.shape
+    (16000,)
+    >>> stft = pbSTFT(256, 1024, window_length=None, window='hann', pad=True, fading='half')
+    >>> mag_spec = np.abs(stft(sine))
+    >>> mag_spec.shape
+    (63, 513)
+    >>> sine_hat = fast_griffin_lim(mag_spec, stft)
+    >>> sine_hat.shape
+    (16128,)
+
+    [1]: Peer, Tal, Simon Welker, and Timo Gerkmann. "Beyond Griffin-Lim:
+        Improved Iterative Phase Retrieval for Speech." 2022 International
+        Workshop on Acoustic Signal Enhancement (IWAENC). IEEE, 2022.
+
+    Args:
+        a: Magnitude spectrogram of shape (*, num_frames, stft.size//2+1)
+        stft: paderbox.transform.module_stft.STFT instance
+        alpha: Momentum for GLA acceleration, where 0 <= alpha <= 1
+        iterations: Number of optimization iterations
+        atol: Absolute tolerance of the magnitude RMSE for early stopping
+        verbose: If True, print the reconstruction error after each iteration step
+        x: Optional complex STFT output from a different phase retrieval algorithm
+    """
+    if isinstance(a, np.ndarray):
+        backend = np
+    else:
+        backend = torch
+
+    if x is None:
+        # Random phase initialization
+        if backend is np:
+            angle = np.random.uniform(
+                low=-np.pi, high=np.pi, size=a.shape
+            )
+        else:
+            angle = torch.rand(a.shape).to(a.device) * 2 * torch.pi - torch.pi
+    else:
+        assert x.dtype in (np.complex64, np.complex128, torch.complex64), x.dtype
+        angle = backend.angle(x)
+
+    with torch.no_grad():
+        x = a * backend.exp(1.0j * angle)
+        y = x
+        for n in range(iterations):
+            x_, _ = griffin_lim_step(a, y, stft)
+            y = x_ + alpha * (x_ - x)
+            x = x_
+            reconstruction_magnitude = backend.abs(x)
+            diff = (backend.sqrt(
+                backend.mean((reconstruction_magnitude - a) ** 2)
+            ))
+            if verbose:
+                print(
+                    'Reconstruction iteration: {}/{} RMSE: {} '.format(
+                        n, iterations, diff
+                    )
+                )
+            if diff < atol:
+                break
+        angle = backend.angle(x)
+        x = a * backend.exp(1.0j * angle)
+        signal = stft.inverse(
+            reshape_complex(x, getattr(stft, 'complex_representation', None))
+        )
+    return signal
+
+
+class FGLA(Synthesis):
+    """Phase reconstruction using the fast Griffin-Lim algorithm (FGLA).
+ """ + def __init__( + self, + sampling_rate: int, + stft: typing.Union[pbSTFT, ptSTFT], + alpha: float = .95, + iterations: int = 30, + atol: float = 0.1, + ): + """ + Args: + sampling_rate: Sampling rate of the synthesized signal + stft: paderbox or padertorch STFT instance that was used to obtain + the magnitude spectrogram + alpha: See fast_griffin_lim + iterations: See fast_griffin_lim + atol: See fast_griffin_lim + """ + self.sampling_rate = sampling_rate + self.stft = stft + self.alpha = alpha + self.iterations = iterations + self.atol = atol + + def __call__( + self, + mag_spec: typing.Union[np.ndarray, torch.Tensor], + sequence_lengths: typing.Optional[typing.List[int]] = None, + target_sampling_rate: typing.Optional[int] = None, + ) -> typing.Union[torch.Tensor, np.ndarray]: + """ + Args: + mag_spec: Magnitude spectrogram of shape + (*, num_frames, stft.size//2+1) + sequence_lengths: Ignored + target_sampling_rate: If not None, resample to + `target_sampling_rate` + + Returns: np.ndarray or torch.Tensor + The synthesized waveform + """ + del sequence_lengths + if isinstance(mag_spec, np.ndarray) and isinstance(self.stft, ptSTFT): + mag_spec = pt.data.example_to_device(mag_spec) + elif ( + isinstance(mag_spec, torch.Tensor) + and isinstance(self.stft, pbSTFT) + ): + mag_spec = pt.utils.to_numpy(mag_spec, detach=True) + + signal = fast_griffin_lim( + mag_spec, self.stft, self.alpha, self.iterations, self.atol + ) + return self._resample(signal, target_sampling_rate) diff --git a/padertorch/contrib/mk/synthesis/vocoder/__init__.py b/padertorch/contrib/mk/synthesis/vocoder/__init__.py new file mode 100644 index 00000000..0fc6a034 --- /dev/null +++ b/padertorch/contrib/mk/synthesis/vocoder/__init__.py @@ -0,0 +1 @@ +from .pwg import Vocoder diff --git a/padertorch/contrib/mk/synthesis/vocoder/pwg.py b/padertorch/contrib/mk/synthesis/vocoder/pwg.py new file mode 100644 index 00000000..f3a1773d --- /dev/null +++ b/padertorch/contrib/mk/synthesis/vocoder/pwg.py @@ -0,0 +1,353 @@ +import os +from pathlib import Path +import typing +from collections import namedtuple +import natsort +from distutils.version import LooseVersion +import io +import importlib.util +import tempfile + +import numpy as np +import torch +from einops import rearrange +import yaml +try: + from parallel_wavegan.utils import load_model, download_pretrained_model +except ImportError: + raise ImportError( + '`parallel_wavegan` package was not found. ' + 'To install it, see here: ' + 'https://github.com/kan-bayashi/ParallelWaveGAN' + ) + +from paderbox.io import load_yaml +import padertorch as pt + +from ..base import Synthesis + + +def _pwg_load_model(checkpoint, config=None, stats=None, consider_mpi=False): + """ + Copy of `parallel_wavegan.utils.load_model` with MPI support + + Args: + checkpoint (str): Checkpoint path. + config (dict): Configuration dict. + stats (str): Statistics file path. 
+        consider_mpi (bool): If True, read the checkpoint on the MPI master
+            and broadcast it to all workers.
+
+    Returns: The generator model with the checkpoint weights loaded.
+
+    """
+    # load config if not provided
+    if config is None:
+        dirname = os.path.dirname(checkpoint)
+        config = os.path.join(dirname, "config.yml")
+        with open(config) as f:
+            config = yaml.load(f, Loader=yaml.Loader)
+
+    # lazy load for circular error
+    import parallel_wavegan.models
+
+    # get model and load parameters
+    model_class = getattr(
+        parallel_wavegan.models,
+        config.get("generator_type", "ParallelWaveGANGenerator"),
+    )
+    # workaround for typo #295
+    generator_params = {
+        k.replace("upsample_kernal_sizes", "upsample_kernel_sizes"): v
+        for k, v in config["generator_params"].items()
+    }
+    model = model_class(**generator_params)
+    if consider_mpi:
+        checkpoint_content = None
+        import dlp_mpi
+        if dlp_mpi.IS_MASTER:
+            checkpoint_content = Path(checkpoint).read_bytes()
+        checkpoint_content = dlp_mpi.bcast(checkpoint_content)
+        _checkpoint = torch.load(
+            io.BytesIO(checkpoint_content), map_location="cpu")
+    else:
+        _checkpoint = torch.load(checkpoint, map_location="cpu")
+    model.load_state_dict(_checkpoint["model"]["generator"])
+
+    # check stats existence
+    if stats is None:
+        dirname = os.path.dirname(checkpoint)
+        if config["format"] == "hdf5":
+            ext = "h5"
+        else:
+            ext = "npy"
+        if os.path.exists(os.path.join(dirname, f"stats.{ext}")):
+            stats = os.path.join(dirname, f"stats.{ext}")
+
+    # load stats
+    if stats is not None:
+        model.register_stats(stats)
+
+    # add pqmf if needed
+    if config["generator_params"]["out_channels"] > 1:
+        # lazy load for circular error
+        from parallel_wavegan.layers import PQMF
+
+        pqmf_params = {}
+        if LooseVersion(config.get("version", "0.1.0")) <= LooseVersion(
+                "0.4.2"):
+            # For compatibility, here we set default values in version <= 0.4.2
+            pqmf_params.update(taps=62, cutoff_ratio=0.15, beta=9.0)
+        model.pqmf = PQMF(
+            subbands=config["generator_params"]["out_channels"],
+            **config.get("pqmf_params", pqmf_params),
+        )
+
+    return model
+
+
+def load_vocoder_model(
+    vocoder_base_path: typing.Union[str, Path], config_name: str = 'config.yml',
+    vocoder_stats: str = 'stats.h5', vocoder_checkpoint: str = None,
+    consider_mpi=False,
+):
+    """
+    Load a pre-trained vocoder model from
+    https://github.com/kan-bayashi/ParallelWaveGAN#results.
+
+    Args:
+        vocoder_base_path: Filepath to the vocoder folder containing the
+            checkpoint file, config and normalization statistics.
+        config_name: Filename of the config file.
+        vocoder_stats: Filename of the normalization statistics.
+        vocoder_checkpoint: Filename of the checkpoint that should be loaded.
+            If None, choose the latest checkpoint in `vocoder_base_path`.
+        consider_mpi: If True, reduce IO load on workers
+
+    Returns:
+        Loaded model and sampling rate of the synthesized audios
+    """
+
+    vocoder_base_path = Path(vocoder_base_path)
+    if vocoder_checkpoint is None:
+        ckpt_files = natsort.natsorted(list(map(
+            lambda p: str(p), vocoder_base_path.glob('*.pkl'))))
+        vocoder_checkpoint = ckpt_files[-1]
+        print(f'Loading vocoder checkpoint {vocoder_checkpoint}')
+
+    if consider_mpi:
+        import dlp_mpi
+        config = None
+        if dlp_mpi.IS_MASTER:
+            config = load_yaml(vocoder_base_path / config_name)
+        config = dlp_mpi.bcast(config)
+    else:
+        config = load_yaml(vocoder_base_path / config_name)
+
+    vocoder = _pwg_load_model(
+        vocoder_checkpoint, config,
+        stats=str(vocoder_base_path / vocoder_stats),
+        consider_mpi=consider_mpi
+    )
+    vocoder.remove_weight_norm()
+    vocoder = vocoder.eval()
+    audio_params = namedtuple(
+        'AudioParams', ['sampling_rate', 'shift', 'window_length'])
+    window_length = config['win_length']
+    if window_length is None:
+        window_length = config['fft_size']
+    window_length = window_length / config['sampling_rate'] * 1000
+    shift = config['hop_size'] / config['sampling_rate'] * 1000
+    return vocoder, audio_params(
+        config['sampling_rate'], shift, window_length
+    )
+
+
+class Vocoder(Synthesis):
+    """
+    Neural vocoder wrapping any model from https://github.com/kan-bayashi/ParallelWaveGAN
+
+    Provides an easy-to-use interface to download vocoders and to perform
+    waveform synthesis from log-mel spectrograms. Vocoder models are
+    identified by a vocoder tag of the form
+    <database>_<vocoder_model>.<version>. The __call__ takes a log-mel
+    spectrogram (np.ndarray or torch.Tensor), a list of sequence lengths
+    (in case of a batched input), the desired output sampling rate, and any
+    optional *keyword arguments* to control the synthesis.
+
+    Attributes:
+        sampling_rate (int): Sampling rate of the data the vocoder was
+            trained on
+    """
+    def __init__(
+        self,
+        database: str = 'libritts',
+        pwg_base_dir: typing.Optional[typing.Union[str, Path]] = \
+            os.environ.get('PWG_BASE_DIR', None),
+        vocoder_model: str = 'hifigan',
+        vocoder_tag: typing.Optional[str] = None,
+        normalize_before: bool = False,
+        device: typing.Union[str, int] = 'cpu',
+        consider_mpi: bool = False,
+        batch_axis: int = 0,
+        sequence_axis: int = -1,
+        postprocessing: typing.Optional[typing.Callable] = None,
+    ):
+        """
+        Args:
+            database: The database the vocoder was trained on (see
+                https://github.com/kan-bayashi/ParallelWaveGAN). Used to
+                infer the vocoder model. If `vocoder_tag` is not None, this
+                argument will be ignored
+            pwg_base_dir: Path to folder where vocoders will be downloaded to.
+                Vocoders will be dumped as directories indexed by their
+                vocoder tag. If a vocoder is already downloaded, it will be
+                loaded from disk instead of being downloaded again
+            vocoder_model: Type of vocoder architecture as specified in the
+                vocoder tag, e.g., "parallel_wavegan" or "hifigan". See
+                https://github.com/kan-bayashi/ParallelWaveGAN for available
+                architectures. If `vocoder_tag` is not None, this
+                argument will be ignored
+            vocoder_tag: Vocoder tag of the form
+                <database>_<vocoder_model>.<version>. If not None, will
+                first look under `pwg_base_dir` and then try to download it
+                from https://github.com/kan-bayashi/ParallelWaveGAN
+            normalize_before: If True, perform z-normalization with vocoder
+                train statistics. If False, `mel_spec` should be normalized
+                with test statistics. Defaults to False
+            device: Device (CPU, GPU) used for inference
+            consider_mpi: If True, load the weights on the master and
+                distribute to all workers
+            batch_axis: Axis along which the batches are stacked. If the input
+                to __call__ is 2-dimensional, `batch_axis` will be ignored
+            sequence_axis: Axis that contains time information
+            postprocessing: Optional postprocessing function that is applied
+                to the synthesized waveform
+        """
+        super().__init__(postprocessing=postprocessing)
+        self.pwg_base_dir = pwg_base_dir
+        self.vocoder_tag = vocoder_tag
+        self.normalize_before = normalize_before
+        self.device = device
+        self.batch_axis = batch_axis
+        self.sequence_axis = sequence_axis
+
+        if self.vocoder_tag is None:
+            self.vocoder_tag = str(Path(
+                '_'.join((database, vocoder_model))).with_suffix('.v1'))
+        if self.pwg_base_dir is None:
+            self.pwg_base_dir = Path(tempfile.gettempdir()) / 'pwg_models'
+        else:
+            self.pwg_base_dir = Path(self.pwg_base_dir)
+        if not (self.pwg_base_dir / self.vocoder_tag).exists():
+            # Download vocoder and store it under `self.pwg_base_dir`
+            if consider_mpi:
+                try:
+                    import dlp_mpi
+                except ImportError as e:
+                    raise ImportError(
+                        'Could not import dlp_mpi.\n'
+                        'Please install it or set consider_mpi=False'
+                    ) from e
+                if dlp_mpi.IS_MASTER:
+                    self._download()
+                dlp_mpi.barrier()
+            else:
+                self._download()
+        self.vocoder_model, vocoder_audio_params = load_vocoder_model(
+            self.pwg_base_dir / self.vocoder_tag, consider_mpi=consider_mpi)
+        self.sampling_rate = vocoder_audio_params.sampling_rate
+        self.vocoder_model.to(self.device).eval()
+
+    def _download(self):
+        try:
+            download_pretrained_model(self.vocoder_tag, str(self.pwg_base_dir))
+            print(f'Downloaded {self.vocoder_tag} to {self.pwg_base_dir}')
+        except KeyError as e:
+            raise KeyError(
+                f'Could not find {self.vocoder_tag} in pretrained models!\n'
+                'list(self.pwg_base_dir.iterdir()): '
+                f'{[p for p in self.pwg_base_dir.iterdir() if p.is_dir()]}\n'
+                'Check parallel_wavegan.PRETRAINED_MODEL_LIST or '
+                'https://github.com/kan-bayashi/ParallelWaveGAN#results for '
+                'more pretrained models.\n'
+                'You can specify a pretrained model with the vocoder_tag '
+                'argument.'
+            ) from e
+
+    def __call__(
+        self,
+        mel_spec: typing.Union[torch.Tensor, np.ndarray],
+        sequence_lengths: typing.Optional[typing.List[int]] = None,
+        target_sampling_rate: typing.Optional[int] = None,
+    ) -> typing.Union[
+        torch.Tensor, np.ndarray, typing.List[np.ndarray],
+        typing.List[torch.Tensor]
+    ]:
+        """
+        Synthesize waveform from log-mel spectrogram with a neural vocoder.
+
+        Args:
+            mel_spec: (Batched) mel-spectrograms where shape must match as
+                specified by `self.batch_axis` and `self.sequence_axis`
+            sequence_lengths: Lengths of the sequences in a batched input
+            target_sampling_rate: By default, the vocoder produces waveforms
+                with the sampling rate seen during training.
+                If not None, resample to `target_sampling_rate`
+
+        Returns: torch.Tensor or np.ndarray
+            Synthesized waveform
+        """
+        to_numpy = isinstance(mel_spec, np.ndarray)
+        if to_numpy:
+            mel_spec = torch.from_numpy(mel_spec).to(self.device)
+        mel_spec = mel_spec.squeeze()
+        sequence_axis = self.sequence_axis % mel_spec.ndim
+        batch_axis = self.batch_axis % mel_spec.ndim
+        if mel_spec.ndim == 2:
+            mel_spec = torch.moveaxis(mel_spec, sequence_axis, 0)
+            with torch.no_grad():
+                y = self.vocoder_model.inference(
+                    mel_spec, normalize_before=self.normalize_before
+                ).view(-1)
+            if to_numpy:
+                y = pt.utils.to_numpy(y, detach=True)
+        elif mel_spec.ndim == 3:
+            feature_axis = set(range(3)).difference(
+                {batch_axis, sequence_axis}
+            ).pop()
+            shape = list(map(
+                lambda t: t[1],
+                sorted(
+                    zip(
+                        [batch_axis, sequence_axis, feature_axis],
+                        ['b', 't', 'f']
+                    ), key=lambda t: t[0]
+                )
+            ))
+            mel_spec = rearrange(mel_spec, f"{' '.join(shape)} -> b t f")
+            if sequence_lengths is None:
+                sequence_lengths = [mel_spec.shape[1]] * mel_spec.shape[0]
+            with torch.no_grad():
+                y = []
+                for _mel_spec, seq_len in zip(mel_spec, sequence_lengths):
+                    y_ = (
+                        self.vocoder_model.inference(
+                            _mel_spec[:seq_len],
+                            normalize_before=self.normalize_before
+                        ).view(-1)
+                    )
+                    if to_numpy:
+                        y_ = pt.utils.to_numpy(y_, detach=True)
+                    y.append(y_)
+            try:
+                if to_numpy:
+                    y = np.stack(y)
+                else:
+                    y = torch.stack(y)
+            except (ValueError, RuntimeError):
+                pass
+        else:
+            raise TypeError(
+                'Expected 2- or 3-dim. spectrogram but got '
+                f'{mel_spec.ndim}-dim. input with shape {mel_spec.shape}'
+            )
+        return super(Vocoder, self).__call__(y, target_sampling_rate)
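
A minimal usage sketch for the batch_image_to_grid helper from patch 2, e.g. for TensorBoard image logging. The spectrogram-like input shape and the SummaryWriter call mentioned in the comment are illustrative assumptions, not taken from the patches:

    import torch
    from padertorch.contrib.mk.tbx_utils import batch_image_to_grid

    # Hypothetical batch of 4 single-channel spectrogram images: (batch, height, width)
    batch = torch.rand(4, 80, 200)
    grid = batch_image_to_grid(
        batch,
        input_shape_format='b h w',
        origin='lower',  # flip the frequency axis so low frequencies are at the bottom
    )
    # `grid` is 2-D because the color dimension is removed for single-channel
    # input; it can be logged, e.g., with SummaryWriter.add_image(tag, grid,
    # dataformats='HW').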
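
A sketch of magnitude-only resynthesis with the new FGLA class from patch 3, mirroring the fast_griffin_lim doctest. It assumes paderbox is available and that the imports in synthesis/__init__.py resolve (parallel_wavegan installed, .legacy module present):

    import numpy as np
    from paderbox.transform import STFT
    from padertorch.contrib.mk.synthesis import FGLA

    stft = STFT(256, 1024, window_length=None, window='hann', pad=True, fading='half')
    fgla = FGLA(sampling_rate=16_000, stft=stft, iterations=30)

    t = np.linspace(0, 1, num=16_000)
    mag_spec = np.abs(stft(np.sin(2 * np.pi * 200 * t)))  # (num_frames, 513)
    wav = fgla(mag_spec)  # random phase init, 30 FGLA iterations, iSTFT
    wav_8k = fgla(mag_spec, target_sampling_rate=8_000)  # resampling via paderbox's sox backend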
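
A sketch of neural vocoding with the new Vocoder class. It assumes parallel_wavegan is installed, the pretrained 'libritts_hifigan.v1' model can be downloaded (or is cached under PWG_BASE_DIR), and that a properly extracted log-mel spectrogram is available; the random input below is only a placeholder:

    import numpy as np
    from padertorch.contrib.mk.synthesis import Vocoder

    vocoder = Vocoder(
        database='libritts',
        vocoder_model='hifigan',  # resolves to the 'libritts_hifigan.v1' tag
        normalize_before=True,    # apply the vocoder's training statistics to the input
    )
    # Placeholder log-mel input of shape (mel bins, frames);
    # the default sequence_axis=-1 expects time to be the last axis.
    log_mel = np.random.randn(80, 100).astype(np.float32)
    wav = vocoder(log_mel)  # numpy in, numpy out, at the vocoder's native rate
    wav_16k = vocoder(log_mel, target_sampling_rate=16_000)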