-
Notifications
You must be signed in to change notification settings - Fork 90
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #113 from gudgud96/dev-vqt
feat: Add Variable-Q Transform
- Loading branch information
Showing
4 changed files
with
290 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -11,3 +11,4 @@ | |
from .griffin_lim import * | ||
from .mel import * | ||
from .stft import * | ||
from .vqt import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,202 @@ | ||
import torch | ||
import torch.nn as nn | ||
import numpy as np | ||
from time import time | ||
from ..librosa_functions import * | ||
from ..utils import * | ||
|
||
|
||
class VQT(torch.nn.Module):
    """Variable-Q Transform (VQT) layer computed octave by octave.

    The input is optionally downsampled once up front ("early
    downsampling"), then repeatedly downsampled by two. For each octave a
    fresh bank of kernels is requested from ``create_cqt_kernels`` (with the
    configured ``gamma``) and applied with ``get_cqt_complex``. The
    per-octave results are concatenated along the frequency axis.

    Parameters
    ----------
    sr : int
        Sampling rate of the input audio.
    hop_length : int
        Hop size for the top octave; it is floor-halved for every lower
        octave.
    fmin : float
        Frequency of the lowest bin.
    fmax : float or None
        Accepted for signature compatibility; not read by this class.
    n_bins : int
        Total number of frequency bins to return.
    filter_scale : float
        Numerator of the Q factor,
        ``Q = filter_scale / (2 ** (1 / bins_per_octave) - 1)``.
    bins_per_octave : int
        Number of bins per octave.
    norm : bool
        Stored on the instance; not otherwise read by this class.
    basis_norm : int
        Forwarded to ``create_cqt_kernels`` as ``norm``.
    gamma : float
        Forwarded to ``create_cqt_kernels``; also used when computing the
        per-bin lengths ``ceil(Q * sr / (freq + gamma / alpha))``.
    window : str
        Accepted for signature compatibility; not read by this class.
    pad_mode : str
        ``'constant'`` (zero padding) or ``'reflect'``.
    earlydownsample : bool
        Whether to ask ``get_early_downsample_params`` for an initial
        downsampling stage. The helper may switch this back off.
    trainable : bool
        Only affects the ``'Magnitude'`` output: when true, ``1e-8`` is added
        under the square root.
    output_format : str
        Default output format: ``'Magnitude'``, ``'Complex'`` or ``'Phase'``.
    verbose : bool
        Print progress messages while building the filters.

    Raises
    ------
    ValueError
        If the top bin frequency exceeds the Nyquist frequency ``sr / 2``.
    """

    def __init__(
        self,
        sr=22050,
        hop_length=512,
        fmin=32.70,
        fmax=None,
        n_bins=84,
        filter_scale=1,
        bins_per_octave=12,
        norm=True,
        basis_norm=1,
        gamma=0,
        window='hann',
        pad_mode='reflect',
        earlydownsample=True,
        trainable=False,
        output_format='Magnitude',
        verbose=True
    ):
        super().__init__()

        self.norm = norm
        self.hop_length = hop_length
        self.pad_mode = pad_mode
        self.n_bins = n_bins
        self.earlydownsample = earlydownsample
        self.trainable = trainable
        self.output_format = output_format
        self.filter_scale = filter_scale
        self.bins_per_octave = bins_per_octave
        self.sr = sr
        self.gamma = gamma
        self.basis_norm = basis_norm

        # Q is needed for the early-downsampling parameters and the lengths.
        Q = float(filter_scale) / (2 ** (1 / bins_per_octave) - 1)

        # Low-pass filter applied before every downsample-by-two step.
        if verbose:
            print("Creating low pass filter ...", end='\r')
        start = time()
        lowpass_filter = torch.tensor(
            create_lowpass_filter(
                band_center=0.50,
                kernelLength=256,
                transitionBandwidth=0.001,
            )
        )
        self.register_buffer('lowpass_filter', lowpass_filter[None, None, :])
        if verbose:
            print("Low pass filter created, time used = {:.4f} seconds".format(time() - start))

        n_filters = min(bins_per_octave, n_bins)
        self.n_filters = n_filters
        self.n_octaves = int(np.ceil(float(n_bins) / bins_per_octave))
        if verbose:
            print("num_octave = ", self.n_octaves)

        # Lowest frequency of the top octave, before the adjustment below.
        self.fmin_t = fmin * 2 ** (self.n_octaves - 1)
        remainder = n_bins % bins_per_octave

        # Frequency of the top bin; a partial top octave holds only
        # `remainder` bins.
        top_bin_index = (bins_per_octave if remainder == 0 else remainder) - 1
        fmax_t = self.fmin_t * 2 ** (top_bin_index / bins_per_octave)

        # Re-anchor the top octave so that it ends exactly on the top bin.
        self.fmin_t = fmax_t / 2 ** (1 - 1 / bins_per_octave)
        if fmax_t > sr / 2:
            raise ValueError(
                'The top bin {}Hz has exceeded the Nyquist frequency, '
                'please reduce the n_bins'.format(fmax_t)
            )

        if self.earlydownsample:
            if verbose:
                print("Creating early downsampling filter ...", end='\r')
            start = time()
            # The helper returns the adjusted rate and hop, and may turn
            # early downsampling back off.
            (
                sr,
                self.hop_length,
                self.downsample_factor,
                early_downsample_filter,
                self.earlydownsample,
            ) = get_early_downsample_params(
                sr, hop_length, fmax_t, Q, self.n_octaves, verbose
            )
            self.register_buffer('early_downsample_filter', early_downsample_filter)

            if verbose:
                print(
                    "Early downsampling filter created, "
                    "time used = {:.4f} seconds".format(time() - start)
                )
        else:
            self.downsample_factor = 1.

        # Normalisation in `forward` needs a length for every bin, not just
        # the top octave, so the frequencies are computed here directly.
        alpha = 2.0 ** (1.0 / bins_per_octave) - 1.0
        # `np.float` was removed in NumPy 1.24; the builtin is equivalent.
        freqs = fmin * 2.0 ** (np.r_[0:n_bins] / float(bins_per_octave))
        self.frequencies = freqs
        lengths = np.ceil(Q * sr / (freqs + gamma / alpha))

        # Longest window (depends on gamma), rounded up to a power of two.
        max_len = int(max(lengths))
        self.n_fft = int(2 ** (np.ceil(np.log2(max_len))))

        # The misspelt buffer name is kept so existing state_dicts and
        # attribute accesses keep working.
        self.register_buffer('lenghts', torch.tensor(lengths).float())

    def forward(self, x, output_format=None, normalization_type='librosa'):
        """
        Convert a batch of waveforms to VQT spectrograms.

        Parameters
        ----------
        x : torch tensor
            Input signal should be in either of the following shapes.\n
            1. ``(len_audio)``\n
            2. ``(num_audio, len_audio)``\n
            3. ``(num_audio, 1, len_audio)``
            It will be automatically broadcast to the right shape
        output_format : str or None
            ``'Magnitude'``, ``'Complex'`` or ``'Phase'``. Falls back to the
            format given at construction when ``None`` (or empty).
        normalization_type : str
            ``'librosa'`` scales every bin by the square root of its length,
            ``'convolutional'`` leaves the values as they are, and ``'wrap'``
            multiplies them by two.

        Returns
        -------
        torch tensor
            ``'Complex'`` and ``'Phase'`` return a tensor whose last
            dimension holds the two components (real, imaginary).
            ``'Magnitude'`` reduces that last dimension away.

        Raises
        ------
        ValueError
            If ``normalization_type``, ``output_format`` or ``self.pad_mode``
            is not one of the supported options.
        """
        output_format = output_format or self.output_format

        # Fail fast on bad options instead of after the expensive transform.
        if normalization_type not in ('librosa', 'convolutional', 'wrap'):
            raise ValueError("The normalization_type %r is not part of our current options." % normalization_type)
        if output_format not in ('Magnitude', 'Complex', 'Phase'):
            raise ValueError("The output_format %r is not part of our current options." % output_format)
        if self.pad_mode not in ('constant', 'reflect'):
            raise ValueError("The pad_mode %r is not part of our current options." % self.pad_mode)

        x = broadcast_dim(x)
        if self.earlydownsample:
            x = downsampling_by_n(x, self.early_downsample_filter, self.downsample_factor)

        hop = self.hop_length
        # NOTE(review): this starts from the original sampling rate even when
        # early downsampling was applied, whereas __init__ computes the bin
        # lengths from the downsampled rate -- confirm this is intended.
        my_sr = self.sr
        # Loop-invariant, so computed once.
        Q = float(self.filter_scale) / (2 ** (1 / self.bins_per_octave) - 1)

        x_down = x
        octaves = []  # ordered from the lowest octave to the highest

        for i in range(self.n_octaves):
            if i > 0:
                # Each lower octave runs on a signal at half the rate.
                x_down = downsampling_by_2(x_down, self.lowpass_filter)
                my_sr /= 2
                hop //= 2

            # NOTE(review): self.n_fft is overwritten on every call, as in
            # the original implementation.
            basis, self.n_fft, _, _ = create_cqt_kernels(
                Q,
                my_sr,
                self.fmin_t * 2 ** -i,
                self.n_filters,
                self.bins_per_octave,
                norm=self.basis_norm,
                topbin_check=False,
                gamma=self.gamma,
            )

            # Build the kernels on the input's device so CUDA inputs work.
            kernels_real = torch.tensor(
                basis.real.astype(np.float32), device=x_down.device
            ).unsqueeze(1)
            kernels_imag = torch.tensor(
                basis.imag.astype(np.float32), device=x_down.device
            ).unsqueeze(1)

            pad_width = kernels_real.shape[-1] // 2
            if self.pad_mode == 'constant':
                padding = nn.ConstantPad1d(pad_width, 0)
            else:  # 'reflect', validated above
                padding = nn.ReflectionPad1d(pad_width)

            cur_vqt = get_cqt_complex(x_down, kernels_real, kernels_imag, hop, padding)
            # Octaves are processed from the top down, so prepend each one.
            octaves.insert(0, cur_vqt)

        vqt = torch.cat(octaves, dim=1)
        vqt = vqt[:, -self.n_bins:, :]  # Removing unwanted bottom bins
        vqt = vqt * self.downsample_factor

        # Normalize again to get same result as librosa
        if normalization_type == 'librosa':
            vqt = vqt * torch.sqrt(self.lenghts.view(-1, 1, 1))
        elif normalization_type == 'wrap':
            vqt *= 2
        # 'convolutional' needs no rescaling.

        if output_format == 'Magnitude':
            power = vqt.pow(2).sum(-1)
            if self.trainable:
                # Epsilon under the square root, applied in trainable mode.
                power = power + 1e-8
            return torch.sqrt(power)

        if output_format == 'Complex':
            return vqt

        # output_format == 'Phase'
        angle = torch.atan2(vqt[:, :, :, 1], vqt[:, :, :, 0])
        return torch.stack((torch.cos(angle), torch.sin(angle)), -1)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,75 @@ | ||
import pytest | ||
import librosa | ||
import torch | ||
import sys | ||
|
||
sys.path.insert(0, "./") | ||
|
||
import os | ||
|
||
dir_path = os.path.dirname(os.path.realpath(__file__)) | ||
|
||
from nnAudio.features import CQT2010v2, VQT | ||
import numpy as np | ||
from parameters import * | ||
import warnings | ||
|
||
gpu_idx = 0  # Index of the CUDA device to use when one is present

# The CPU is always tested; the selected GPU is added when CUDA is available.
device_args = ["cpu"]
if torch.cuda.is_available():
    device_args.append(f"cuda:{gpu_idx}")
else:
    warnings.warn("GPU is not avaliable, testing only on CPU")
|
||
# Load 5 seconds of librosa's bundled 'choice' example recording once at import
# time; the waveform `y` and its sampling rate `sr` are shared by the tests below.
y, sr = librosa.load(librosa.ex('choice'), duration=5)
|
||
@pytest.mark.parametrize("device", [*device_args])
def test_vqt_gamma_zero(device):
    """Check that VQT with ``gamma=0`` matches CQT2010v2 element for element.

    NOTE(review): ``device`` is received from the parametrisation but is not
    used; both layers and the input stay on their default device.
    """
    waveform = torch.tensor(y).unsqueeze(0)

    # Reference: nnAudio CQT
    cqt_layer = CQT2010v2(sr=sr, verbose=False)
    C2 = cqt_layer(waveform, output_format="Magnitude", normalization_type='librosa')
    C2 = C2.numpy().squeeze()

    # Candidate: nnAudio VQT without any bandwidth offset
    vqt_layer = VQT(sr=sr, gamma=0, verbose=False)
    V2 = vqt_layer(waveform, output_format="Magnitude", normalization_type='librosa')
    V2 = V2.numpy().squeeze()

    matches = (C2 == V2).all()
    assert matches == True
|
||
|
||
@pytest.mark.parametrize("device", [*device_args])
def test_vqt(device):
    """Keep the nnAudio VQT within a baseline distance of ``librosa.vqt``.

    Some difference between the two implementations is expected (same as for
    the CQT), mainly because librosa uses float lengths while nnAudio uses
    integer lengths. This test only guards against the gap growing beyond the
    recorded per-gamma thresholds.

    NOTE(review): ``device`` is received from the parametrisation but is not
    used inside the test.
    """
    # Largest tolerated absolute difference for each gamma; gammas without an
    # entry use the fallback bound.
    max_diff_bound = {0: 0.6785, 1: 0.6510, 2: 0.5962, 5: 0.3695}
    fallback_bound = 0.1

    for gamma in [0, 1, 2, 5, 10]:
        # librosa vqt
        reference = np.abs(librosa.vqt(y, sr=sr, gamma=gamma))

        # nnAudio vqt
        layer = VQT(sr=sr, gamma=gamma, verbose=False)
        estimate = layer(
            torch.tensor(y).unsqueeze(0),
            output_format="Magnitude",
            normalization_type='librosa',
        )
        estimate = estimate.numpy().squeeze()

        abs_err = np.abs(reference - estimate)

        assert np.amin(abs_err) < 1e-8
        assert np.amax(abs_err) < max_diff_bound.get(gamma, fallback_bound)