
Merge pull request #31 from 0X0StradSong/main
clean codebase & stabilize vtrust
ctlllll authored Jul 1, 2024
2 parents ff7a75d + 0627727 commit c0c6b99
Showing 1 changed file with 1 addition and 67 deletions.
68 changes: 1 addition & 67 deletions tts_rater/rater.py
@@ -215,51 +215,6 @@ def compute_dns_mos_loss(audio_paths, batch_size):
return [-x for x in dns_mos_results]


# =================== PANN MMD loss ===================
# speaker_pann_embeds = torch.load(os.path.join(script_dir, "pann/pann_embeds.pth"), map_location="cuda")
# pann_model = PANNModel()

# def compute_mmd(a_x: torch.Tensor, b_y: torch.Tensor):
# _SIGMA = 10
# _SCALE = 1000

# a_x = a_x.double()
# b_y = b_y.double()

# a_x_sqnorms = torch.sum(a_x**2, dim=1)
# b_y_sqnorms = torch.sum(b_y**2, dim=1)

# gamma = 1 / (2 * _SIGMA**2)

# k_xx = torch.mean(torch.exp(-gamma * (-2 * (a_x @ a_x.T) + a_x_sqnorms[:, None] + a_x_sqnorms[None, :])))
# k_xy = torch.mean(torch.exp(-gamma * (-2 * (a_x @ b_y.T) + a_x_sqnorms[:, None] + b_y_sqnorms[None, :])))
# k_yy = torch.mean(torch.exp(-gamma * (-2 * (b_y @ b_y.T) + b_y_sqnorms[:, None] + b_y_sqnorms[None, :])))

# return _SCALE * (k_xx + k_yy - 2 * k_xy)


# def compute_pann_mmd_loss(audio_paths: list[str], speaker: str = "p374"):

# n_samples = len(audio_paths)
# waveforms = [load_wav_file(fname, 32000) for fname in audio_paths]

# embeddings = []
# for audio in tqdm(waveforms):
# audio = torch.Tensor(audio).cuda()
# embedding = pann_model.get_embedding(audio[None])[0]
# embeddings.append(embedding)
# embeddings = torch.stack(embeddings, dim=0)
# embeddings_reverse = embeddings.flip(0)
# embeddings_cat = torch.cat([embeddings, embeddings_reverse], dim=1).reshape(-1, embeddings.shape[-1])

# mmd_losses = []
# for idx in range(n_samples):
# sampled_embeddings = embeddings_cat[idx: idx + 16]
# mmd = compute_mmd(speaker_pann_embeds[speaker], sampled_embeddings)
# mmd_losses.append(mmd.item())

# return mmd_losses
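For context, the commented-out block removed above estimated a Gaussian-kernel (RBF) maximum mean discrepancy between reference PANN embeddings and embeddings of the generated audio. A self-contained sketch of the same estimator, using torch.cdist and toy low-dimensional vectors in place of real PANN embeddings (the constants mirror _SIGMA and _SCALE above; everything else is illustrative):

import torch

def rbf_mmd(x: torch.Tensor, y: torch.Tensor, sigma: float = 10.0, scale: float = 1000.0) -> torch.Tensor:
    # Biased (V-statistic) MMD^2 estimate with kernel k(a, b) = exp(-||a - b||^2 / (2 * sigma^2)).
    x, y = x.double(), y.double()
    gamma = 1.0 / (2.0 * sigma ** 2)
    k_xx = torch.exp(-gamma * torch.cdist(x, x).pow(2)).mean()
    k_yy = torch.exp(-gamma * torch.cdist(y, y).pow(2)).mean()
    k_xy = torch.exp(-gamma * torch.cdist(x, y).pow(2)).mean()
    return scale * (k_xx + k_yy - 2.0 * k_xy)

# Toy check: samples from the same distribution give a small MMD, shifted samples a much larger one.
ref = torch.randn(64, 8)
print(rbf_mmd(ref, torch.randn(64, 8)).item())        # small
print(rbf_mmd(ref, torch.randn(64, 8) + 5.0).item())  # much larger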

# =================== Anti Spoofing loss ===================
speaker_antispoofing_embeds = torch.load(os.path.join(script_dir, "rawnet/antispoofing_embeds.pth"), map_location="cuda")
antispoofing_model = AntiSpoofingInference()
@@ -358,25 +313,6 @@ def get_normalized_scores(raw_errs: dict[str, float]):
normalized_scores[key] = np.clip(1 - normalized_err, 1e-6, 1.0)
return normalized_scores
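For reference, the clip on the line above maps a normalized error of 0 to a score of 1.0 and floors any error of 1 or more at 1e-6. A quick numeric check with hypothetical error values:

import numpy as np

for normalized_err in (0.0, 0.25, 1.0, 3.0):
    print(normalized_err, np.clip(1 - normalized_err, 1e-6, 1.0))
# prints 1.0, 0.75, 1e-06, 1e-06 respectively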


# def compute_sharpe_ratios(scores: list[float]) -> list[float]:
# # Jackknife estimate of the Sharpe ratio.
# n = len(scores)
# sharpe_ratios = []
# for ii in range(n):
# scores_jack = scores[:ii] + scores[ii + 1 :]
# mean_jack = np.mean(scores_jack)
# std_jack = np.std(scores_jack, ddof=1)
# sharpe_jack = mean_jack / std_jack

# if mean_jack < 1e-6 and std_jack == 0.0:
# sharpe_jack = 0.0

# sharpe_ratios.append(sharpe_jack)

# return sharpe_ratios
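The commented-out helper removed above computed a leave-one-out (jackknife) Sharpe ratio per sample. A compact NumPy illustration of that resampling idea on hypothetical scores, shown only to document what the deleted code measured:

import numpy as np

scores = np.array([0.8, 0.9, 1.0, 0.7])        # hypothetical per-sample scores
for ii in range(len(scores)):
    rest = np.delete(scores, ii)               # leave observation ii out
    print(ii, rest.mean() / rest.std(ddof=1))  # jackknife Sharpe ratio for that fold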


def rate(
ckpt_path,
speaker="p374",
@@ -451,7 +387,7 @@ def rate_(
keys = list(norm_dict.keys())
norm_scores = []
for ii in range(samples):
norm_score = np.prod([norm_dict[k][ii] for k in keys])
norm_score = norm_dict["word_error_rate"][ii] * (norm_dict["tone_color"][ii] + norm_dict["antispoofing"][ii] + norm_dict["judge_scores"][ii])
norm_scores.append(norm_score)

return norm_scores, norm_dict
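The single added line above replaces the plain product over all normalized metrics with word_error_rate gating the sum of tone_color, antispoofing, and judge_scores. A minimal sketch with hypothetical per-sample values standing in for the real norm_dict:

norm_dict = {
    "word_error_rate": [0.9, 0.5],
    "tone_color": [0.8, 0.7],
    "antispoofing": [0.9, 0.6],
    "judge_scores": [0.7, 0.4],
}

norm_scores = []
for ii in range(2):
    # WER multiplies (gates) the additive combination of the remaining metrics.
    wer = norm_dict["word_error_rate"][ii]
    rest = norm_dict["tone_color"][ii] + norm_dict["antispoofing"][ii] + norm_dict["judge_scores"][ii]
    norm_scores.append(wer * rest)

print(norm_scores)  # roughly [2.16, 0.85]

Compared with the old product, a single near-zero non-WER metric no longer collapses the whole score, which presumably is part of what the commit message means by stabilizing vtrust.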
@@ -466,7 +402,6 @@ def rate_(
parser.add_argument("--speaker", type=str, default="p374")
parser.add_argument("--seed", type=int, default=0)
parser.add_argument("--samples", type=int, default=64)
parser.add_argument("--n_bootstrap", type=int, default=256)
parser.add_argument("--batch_size", type=int, default=16)
args = parser.parse_args()

@@ -477,7 +412,6 @@ def rate_(
args.seed,
args.samples,
args.batch_size,
args.n_bootstrap,
args.use_tmpdir,
)
)
