-
Notifications
You must be signed in to change notification settings - Fork 16
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
contrib: Add audio synthesis from spectrogram
- Fast Griffin-Lim algorithm - Neural vocoder based on https://github.com/kan-bayashi/ParallelWaveGAN
- Loading branch information
1 parent
ff94a2a
commit 9a6bfe8
Showing
6 changed files
with
656 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from .vocoder import Vocoder | ||
from .parametric import fast_griffin_lim, FGLA | ||
from .legacy import Converter |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,84 @@ | ||
import typing | ||
from functools import partial | ||
|
||
import numpy as np | ||
import torch | ||
from paderbox.transform.module_resample import resample_sox | ||
import padertorch as pt | ||
|
||
|
||
class Synthesis(pt.Configurable): | ||
sampling_rate: int | ||
|
||
def __init__( | ||
self, | ||
postprocessing: typing.Optional[typing.Callable] = None, | ||
): | ||
super().__init__() | ||
self.postprocessing = postprocessing | ||
|
||
def __call__( | ||
self, | ||
time_signal: typing.Union[ | ||
np.ndarray, torch.Tensor, typing.List[np.ndarray], | ||
typing.List[torch.Tensor] | ||
], | ||
target_sampling_rate: typing.Optional[int] = None, | ||
) -> typing.Union[ | ||
np.ndarray, torch.Tensor, typing.List[np.ndarray], | ||
typing.List[torch.Tensor] | ||
]: | ||
if self.postprocessing is not None: | ||
if isinstance(time_signal, list) or time_signal.ndim == 2: | ||
time_signal = list(map(self.postprocessing, time_signal)) | ||
else: | ||
time_signal = self.postprocessing(time_signal) | ||
return self.resample(time_signal, target_sampling_rate) | ||
|
||
def _resample( | ||
self, | ||
wav: typing.Union[np.ndarray, torch.Tensor], | ||
target_sampling_rate: typing.Optional[int] = None, | ||
) -> typing.Union[np.ndarray, torch.Tensor]: | ||
to_torch = False | ||
if ( | ||
target_sampling_rate is None | ||
or target_sampling_rate == self.sampling_rate | ||
): | ||
return wav | ||
if isinstance(wav, torch.Tensor): | ||
to_torch = True | ||
wav = pt.utils.to_numpy(wav, detach=True) | ||
wav = resample_sox( | ||
wav, | ||
in_rate=self.sampling_rate, | ||
out_rate=target_sampling_rate | ||
) | ||
if to_torch: | ||
wav = torch.from_numpy(wav) | ||
return wav | ||
|
||
def resample( | ||
self, | ||
wav: typing.Union[ | ||
np.ndarray, torch.Tensor, typing.List[np.ndarray], | ||
typing.List[torch.Tensor] | ||
], | ||
target_sampling_rate: typing.Optional[int] = None, | ||
) -> typing.Union[ | ||
np.ndarray, torch.Tensor, typing.List[np.ndarray], | ||
typing.List[torch.Tensor] | ||
]: | ||
if isinstance(wav, list) or wav.ndim == 2: | ||
wav = list(map( | ||
partial( | ||
self._resample, target_sampling_rate=target_sampling_rate | ||
), wav | ||
)) | ||
try: | ||
m = np if isinstance(wav[0], np.ndarray) else torch | ||
wav = m.stack(wav) | ||
except (ValueError, RuntimeError): | ||
pass | ||
return wav | ||
return self._resample(wav, target_sampling_rate=target_sampling_rate) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .griffin_lim import fast_griffin_lim, FGLA |
214 changes: 214 additions & 0 deletions
214
padertorch/contrib/mk/synthesis/parametric/griffin_lim.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,214 @@ | ||
import typing | ||
|
||
import numpy as np | ||
import torch | ||
from paderbox.transform import STFT as pbSTFT | ||
import padertorch as pt | ||
from padertorch.ops import STFT as ptSTFT | ||
|
||
from ..base import Synthesis | ||
|
||
|
||
__all__ = [ | ||
'fast_griffin_lim', | ||
'FGLA', | ||
] | ||
|
||
|
||
def reshape_complex(signal, complex_representation): | ||
if complex_representation in (None, 'complex'): | ||
return signal | ||
if complex_representation == 'stacked': | ||
signal = torch.stack( | ||
(signal.real, signal.imag), dim=-1 | ||
) | ||
else: | ||
signal = torch.cat( | ||
(signal.real, signal.imag), dim=-1 | ||
) | ||
return signal | ||
|
||
|
||
def griffin_lim_step( | ||
a: typing.Union[np.ndarray, torch.Tensor], | ||
reconstruction_stft: typing.Union[np.ndarray, torch.Tensor], | ||
stft: typing.Union[pbSTFT, ptSTFT], | ||
backend=None, | ||
): | ||
""" | ||
Args: | ||
a: | ||
reconstruction_stft: | ||
stft: | ||
backend: | ||
Returns: | ||
""" | ||
if backend is None: | ||
if isinstance(a, np.ndarray): | ||
backend = np | ||
else: | ||
backend = torch | ||
|
||
# From paderbox.transform.module_phase_reconstruction | ||
reconstruction_angle = backend.angle(reconstruction_stft) | ||
proposal_spec = a * backend.exp(1.0j * reconstruction_angle) # P_A | ||
|
||
audio = stft.inverse( | ||
reshape_complex( | ||
proposal_spec, getattr(stft, 'complex_representation', None) | ||
) | ||
) # P_C | ||
stft_signal = stft(audio) | ||
if isinstance(stft_signal, np.ndarray): | ||
return stft_signal, audio | ||
if stft.complex_representation != 'complex': | ||
if stft.complex_representation == 'stacked': | ||
stft_signal = stft_signal[..., 0] + 1j * stft_signal[..., 1] | ||
else: | ||
size = stft_signal.shape[-1] | ||
stft_signal = ( | ||
stft_signal[..., :size//2] + 1j * stft_signal[..., size//2:] | ||
) | ||
return stft_signal, audio | ||
|
||
|
||
def fast_griffin_lim( | ||
a: typing.Union[np.ndarray, torch.Tensor], | ||
stft: [pbSTFT, ptSTFT], | ||
alpha=0.95, | ||
iterations=100, | ||
atol: float = 0.1, | ||
verbose=False, | ||
x=None, | ||
): | ||
"""Griffin-Lim algorithm with momentum for phase retrieval [1]. | ||
>>> f_0 = 200 # Hz | ||
>>> f_s = 16_000 # Hz | ||
>>> t = np.linspace(0, 1, num=f_s) | ||
>>> sine = np.sin(2*np.pi*f_0*t) | ||
>>> sine.shape | ||
(16000,) | ||
>>> stft = STFT(256, 1024, window_length=None, window='hann', pad=True, fading='half') | ||
>>> mag_spec = np.abs(stft(sine)) | ||
>>> mag_spec.shape | ||
(63, 513) | ||
>>> sine_hat = fast_griffin_lim(mag_spec, stft) | ||
>>> sine_hat.shape | ||
(16128,) | ||
[1]: Peer, Tal, Simon Welker, and Timo Gerkmann. "Beyond Griffin-LIM: | ||
Improved Iterative Phase Retrieval for Speech." 2022 International | ||
Workshop on Acoustic Signal Enhancement (IWAENC). IEEE, 2022. | ||
Args: | ||
a: Magnitude spectrogram of shape (*, num_frames, stft.size//2+1) | ||
stft: paderbox.transform.module_stft.STFT instance | ||
alpha: Momentum for GLA acceleration, where 0 <= alpha <= 1 | ||
iterations: Number of optimization iterations | ||
atol: | ||
verbose: If True, print the reconstruction error after each iteration step | ||
x: Optional complex STFT output from a different phase retrieval algorithm | ||
""" | ||
if isinstance(a, np.ndarray): | ||
backend = np | ||
else: | ||
backend = torch | ||
|
||
if x is None: | ||
# Random phase initialization | ||
if backend is np: | ||
angle = np.random.uniform( | ||
low=-np.pi, high=np.pi, size=a.shape | ||
) | ||
else: | ||
angle = torch.rand(a.shape).to(a.device) * 2 * torch.pi - torch.pi | ||
else: | ||
assert x.dtype in (np.complex64, np.complex128, torch.complex64), x.dtype | ||
angle = backend.angle(x) | ||
|
||
with torch.no_grad(): | ||
x = a * backend.exp(1.0j * angle) | ||
y = x | ||
for n in range(iterations): | ||
x_, _ = griffin_lim_step(a, y, stft) | ||
y = x_ + alpha * (x_ - x) | ||
x = x_ | ||
reconstruction_magnitude = backend.abs(x) | ||
diff = (backend.sqrt( | ||
backend.mean((reconstruction_magnitude - a) ** 2) | ||
)) | ||
if verbose: | ||
print( | ||
'Reconstruction iteration: {}/{} RMSE: {} '.format( | ||
n, iterations, diff | ||
) | ||
) | ||
if diff < atol: | ||
break | ||
angle = backend.angle(x) | ||
x = a * backend.exp(1.0j * angle) | ||
signal = stft.inverse( | ||
reshape_complex(x, getattr(stft, 'complex_representation', None)) | ||
) | ||
return signal | ||
|
||
|
||
class FGLA(Synthesis): | ||
"""Phase reconstruction using the Griffin-Lim algorithm (FGLA). | ||
""" | ||
def __init__( | ||
self, | ||
sampling_rate: int, | ||
stft: typing.Union[pbSTFT, ptSTFT], | ||
alpha: float = .95, | ||
iterations: int = 30, | ||
atol: float = 0.1, | ||
): | ||
""" | ||
Args: | ||
sampling_rate: Sampling rate of the synthesized signal | ||
stft: paderbox or padertorch STFT instance that was used to obtain | ||
the magnitude spectrogram | ||
alpha: See fast_griffin_lim | ||
iterations: See fast_griffin_lim | ||
atol: See fast_griffin_lim | ||
""" | ||
self.sampling_rate = sampling_rate | ||
self.stft = stft | ||
self.alpha = alpha | ||
self.iterations = iterations | ||
self.atol = atol | ||
|
||
def __call__( | ||
self, | ||
mag_spec: typing.Union[np.ndarray, torch.Tensor], | ||
sequence_lengths: typing.Optional[typing.List[int]] = None, | ||
target_sampling_rate: typing.Optional[int] = None, | ||
) -> typing.Union[torch.Tensor, np.ndarray]: | ||
""" | ||
Args: | ||
mag_spec: Magnitude spectrogram of shape | ||
(*, num_frames, stft.size//2+1) | ||
sequence_lengths: Ignored | ||
target_sampling_rate: If not None, resample to | ||
`target_sampling_rate` | ||
Returns: np.ndarray or torch.Tensor | ||
The synthesized waveform | ||
""" | ||
del sequence_lengths | ||
if isinstance(mag_spec, np.ndarray) and isinstance(self.stft, ptSTFT): | ||
mag_spec = pt.data.example_to_device(mag_spec) | ||
elif ( | ||
isinstance(mag_spec, torch.Tensor) | ||
and isinstance(self.stft, pbSTFT) | ||
): | ||
mag_spec = pt.utils.to_numpy(mag_spec, detach=True) | ||
|
||
signal = fast_griffin_lim( | ||
mag_spec, self.stft, self.alpha, self.iterations, self.atol | ||
) | ||
return self._resample(signal, target_sampling_rate) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
from .pwg import Vocoder |
Oops, something went wrong.