compute_labels.py
import re
import os
import json
import torch
import numpy as np
import pandas as pd
import librosa as lr
from glob import glob
from multiprocessing.pool import Pool
from functools import partial
from tqdm import tqdm
from torchmetrics.functional.audio.pesq import perceptual_evaluation_speech_quality as pesq
from torchmetrics.functional.audio.stoi import short_time_objective_intelligibility as stoi
from torchmetrics.functional.audio import scale_invariant_signal_distortion_ratio as si_sdr
from torchmetrics.functional.audio import scale_invariant_signal_noise_ratio as si_snr


# code from https://github.com/facebookresearch/denoiser
def match_dns(noisy, clean):
    """match_dns.
    Match noisy and clean DNS dataset filenames.

    :param noisy: list of the noisy filenames
    :param clean: list of the clean filenames
    """
    noisydict = {}
    extra_noisy = []
    for path, size in noisy:
        match = re.search(r'fileid_(\d+)\.wav$', path)
        if match is None:
            # maybe we are mixing some other dataset in
            extra_noisy.append((path, size))
        else:
            noisydict[match.group(1)] = (path, size)
    noisy[:] = []
    extra_clean = []
    copied = list(clean)
    clean[:] = []
    for path, size in copied:
        match = re.search(r'fileid_(\d+)\.wav$', path)
        if match is None:
            extra_clean.append((path, size))
        else:
            noisy.append(noisydict[match.group(1)])
            clean.append((path, size))
    extra_noisy.sort()
    extra_clean.sort()
    clean += extra_clean
    noisy += extra_noisy
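
# Illustration with hypothetical filenames: given
#   noisy = [('n/noisy_a_fileid_1.wav', 100), ('n/noisy_b_fileid_0.wav', 90)]
#   clean = [('c/clean_fileid_0.wav', 90), ('c/clean_fileid_1.wav', 100)]
# match_dns reorders both lists in place so that noisy[i] and clean[i] share
# the same fileid; entries without a fileid are sorted and appended at the end.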


# code from https://github.com/facebookresearch/denoiser
def match_files(noisy, clean, matching="sort"):
    """match_files.
    Sort files to match noisy and clean filenames.

    :param noisy: list of the noisy filenames
    :param clean: list of the clean filenames
    :param matching: the matching strategy; "sort" and "dns" are supported
    """
    if matching == "dns":
        # dns dataset filenames don't match when sorted, we have to manually match them
        match_dns(noisy, clean)
    elif matching == "sort":
        noisy.sort()
        clean.sort()
    else:
        raise ValueError(f"Invalid value for matching: {matching}")


def load_filelists(egs_path, matching="sort"):
    """Load the noisy/clean file lists from the egs JSON files and match them up."""
    noisy_json = os.path.join(egs_path, 'noisy.json')
    clean_json = os.path.join(egs_path, 'clean.json')
    with open(noisy_json, 'r') as f:
        noisy = json.load(f)
    with open(clean_json, 'r') as f:
        clean = json.load(f)
    match_files(noisy, clean, matching)
    return list(zip(noisy, clean))
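
# Assumed egs format (inferred from how match_dns unpacks entries; it matches
# the egs files produced by https://github.com/facebookresearch/denoiser):
# noisy.json and clean.json each hold a list of [path, length] pairs, e.g.
#   [["/data/noisy/noisy_fileid_0.wav", 480000], ...]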


def compute_metrics(paths, win_len, win_hop):
    """Slice a noisy/clean pair into windows and compute quality metrics per window."""
    noisy_path, clean_path = paths
    x_noisy, sr_noisy = lr.load(noisy_path, sr=None)
    x_clean, sr_clean = lr.load(clean_path, sr=None)
    assert sr_noisy == sr_clean
    # convert the window length and hop from seconds to samples
    win_len = lr.time_to_samples(win_len, sr=sr_noisy)
    win_hop = lr.time_to_samples(win_hop, sr=sr_noisy)
    assert x_noisy.shape[0] >= win_len, f"File {noisy_path} is too short: {x_noisy.shape[0]} vs {win_len}"
    # frame both signals into overlapping windows of shape (n_windows, win_len)
    xx_noisy = lr.util.frame(x_noisy, frame_length=win_len, hop_length=win_hop, axis=0)
    xx_clean = lr.util.frame(x_clean, frame_length=win_len, hop_length=win_hop, axis=0)
    xx_noisy = torch.tensor(xx_noisy)
    xx_clean = torch.tensor(xx_clean)
    try:
        labels = pd.DataFrame({
            # wideband PESQ expects a 16 kHz sampling rate
            "pesq": pesq(xx_noisy, xx_clean, fs=sr_noisy, mode="wb").numpy(),
            "stoi": stoi(xx_noisy, xx_clean, fs=sr_noisy).numpy(),
            "si_sdr": si_sdr(xx_noisy, xx_clean).numpy(),
            "si_snr": si_snr(xx_noisy, xx_clean).numpy()
        })
    except Exception:
        # a metric can fail on some windows (e.g. PESQ on near-silence); skip the pair
        labels = None
    return labels, tuple(os.path.basename(p) for p in [noisy_path, clean_path])
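
# Usage sketch for a single pair (hypothetical paths): with 9 s windows hopped
# every 2 s, a 15 s file yields 4 windows and `labels` gets one row per window.
#   labels, (noisy_name, clean_name) = compute_metrics(
#       ('noisy_fileid_0.wav', 'clean_fileid_0.wav'), win_len=9.0, win_hop=2.0)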


def main(egs_path, output_path, win_len=9.0, win_hop=2.0, n_jobs=12, matching="dns"):
    # load file list; each entry is ((noisy_path, size), (clean_path, size))
    files_list = load_filelists(egs_path, matching=matching)
    files_list = [(fl[0][0], fl[1][0]) for fl in files_list]  # keep only the paths
    print("File pairs found:", len(files_list))
    print("Showing first 5 pairs:", files_list[:5])
    # compute metrics (multi-processing)
    compute_metrics_partial = partial(compute_metrics, win_len=win_len, win_hop=win_hop)
    data = {}
    with Pool(n_jobs) as p:
        for metrics, k in tqdm(p.imap(compute_metrics_partial, files_list), total=len(files_list)):
            if metrics is not None:
                data[k] = metrics
    # save metrics as one dataframe indexed by (noisy_filename, clean_filename, segment)
    data_df = pd.concat(data, names=['noisy_filename', 'clean_filename', 'segment'])
    df_path = os.path.join(output_path, f"metrics_{win_len}_{win_hop}.pkl")
    data_df.to_pickle(df_path)
    print(f"Metrics saved at {df_path}")


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser(description='Compute speech quality labels for a given DNS-type dataset.')
    parser.add_argument('egs_path', type=str, help='Path containing the egs noisy.json/clean.json files')
    parser.add_argument('output_path', type=str, help='Path to save the metrics dataframe at')
    parser.add_argument('--win_len', type=float, default=9.0, help='Window length in seconds')
    parser.add_argument('--win_hop', type=float, default=2.0, help='Window hop in seconds')
    parser.add_argument('--n_jobs', type=int, default=12, help='Number of parallel jobs')
    parser.add_argument('--matching', type=str, default='dns', choices=['dns', 'sort'], help='Matching strategy for noisy and clean files')
    args = parser.parse_args()
    main(args.egs_path, args.output_path, args.win_len, args.win_hop, args.n_jobs, args.matching)
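
# Example invocation (hypothetical paths), assuming egs_path contains the
# noisy.json/clean.json pair described above:
#   python compute_labels.py egs/dns/tr out/ --win_len 9 --win_hop 2 --matching dns
# This writes out/metrics_9.0_2.0.pkl: a dataframe indexed by
# (noisy_filename, clean_filename, segment) with pesq/stoi/si_sdr/si_snr columns.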