From 06277279f7d680dd5a8291d0e18fa0ebf3371c9b Mon Sep 17 00:00:00 2001 From: 0X0StradSong Date: Mon, 1 Jul 2024 16:57:37 +0000 Subject: [PATCH] clean codebase & stabilize vtrust --- tts_rater/rater.py | 68 +--------------------------------------------- 1 file changed, 1 insertion(+), 67 deletions(-) diff --git a/tts_rater/rater.py b/tts_rater/rater.py index 08c5d50..886631f 100644 --- a/tts_rater/rater.py +++ b/tts_rater/rater.py @@ -215,51 +215,6 @@ def compute_dns_mos_loss(audio_paths, batch_size): return [-x for x in dns_mos_results] -# =================== PANN MMD loss =================== -# speaker_pann_embeds = torch.load(os.path.join(script_dir, "pann/pann_embeds.pth"), map_location="cuda") -# pann_model = PANNModel() - -# def compute_mmd(a_x: torch.Tensor, b_y: torch.Tensor): -# _SIGMA = 10 -# _SCALE = 1000 - -# a_x = a_x.double() -# b_y = b_y.double() - -# a_x_sqnorms = torch.sum(a_x**2, dim=1) -# b_y_sqnorms = torch.sum(b_y**2, dim=1) - -# gamma = 1 / (2 * _SIGMA**2) - -# k_xx = torch.mean(torch.exp(-gamma * (-2 * (a_x @ a_x.T) + a_x_sqnorms[:, None] + a_x_sqnorms[None, :]))) -# k_xy = torch.mean(torch.exp(-gamma * (-2 * (a_x @ b_y.T) + a_x_sqnorms[:, None] + b_y_sqnorms[None, :]))) -# k_yy = torch.mean(torch.exp(-gamma * (-2 * (b_y @ b_y.T) + b_y_sqnorms[:, None] + b_y_sqnorms[None, :]))) - -# return _SCALE * (k_xx + k_yy - 2 * k_xy) - - -# def compute_pann_mmd_loss(audio_paths: list[str], speaker: str = "p374"): - -# n_samples = len(audio_paths) -# waveforms = [load_wav_file(fname, 32000) for fname in audio_paths] - -# embeddings = [] -# for audio in tqdm(waveforms): -# audio = torch.Tensor(audio).cuda() -# embedding = pann_model.get_embedding(audio[None])[0] -# embeddings.append(embedding) -# embeddings = torch.stack(embeddings, dim=0) -# embeddings_reverse = embeddings.flip(0) -# embeddings_cat = torch.cat([embeddings, embeddings_reverse], dim=1).reshape(-1, embeddings.shape[-1]) - -# mmd_losses = [] -# for idx in range(n_samples): -# 
sampled_embeddings = embeddings_cat[idx: idx + 16] -# mmd = compute_mmd(speaker_pann_embeds[speaker], sampled_embeddings) -# mmd_losses.append(mmd.item()) - -# return mmd_losses - # =================== Anti Spoofing loss =================== speaker_antispoofing_embeds = torch.load(os.path.join(script_dir, "rawnet/antispoofing_embeds.pth"), map_location="cuda") antispoofing_model = AntiSpoofingInference() @@ -358,25 +313,6 @@ def get_normalized_scores(raw_errs: dict[str, float]): normalized_scores[key] = np.clip(1 - normalized_err, 1e-6, 1.0) return normalized_scores - -# def compute_sharpe_ratios(scores: list[float]) -> list[float]: -# # Jackknife estimate of the Sharpe ratio. -# n = len(scores) -# sharpe_ratios = [] -# for ii in range(n): -# scores_jack = scores[:ii] + scores[ii + 1 :] -# mean_jack = np.mean(scores_jack) -# std_jack = np.std(scores_jack, ddof=1) -# sharpe_jack = mean_jack / std_jack - -# if mean_jack < 1e-6 and std_jack == 0.0: -# sharpe_jack = 0.0 - -# sharpe_ratios.append(sharpe_jack) - -# return sharpe_ratios - - def rate( ckpt_path, speaker="p374", @@ -451,7 +387,7 @@ def rate_( keys = list(norm_dict.keys()) norm_scores = [] for ii in range(samples): - norm_score = np.prod([norm_dict[k][ii] for k in keys]) + norm_score = norm_dict["word_error_rate"][ii] * (norm_dict["tone_color"][ii] + norm_dict["antispoofing"][ii] + norm_dict["judge_scores"][ii]) norm_scores.append(norm_score) return norm_scores, norm_dict @@ -466,7 +402,6 @@ def rate_( parser.add_argument("--speaker", type=str, default="p374") parser.add_argument("--seed", type=int, default=0) parser.add_argument("--samples", type=int, default=64) - parser.add_argument("--n_bootstrap", type=int, default=256) parser.add_argument("--batch_size", type=int, default=16) args = parser.parse_args() @@ -477,7 +412,6 @@ def rate_( args.seed, args.samples, args.batch_size, - args.n_bootstrap, args.use_tmpdir, ) )