
Commit

likelihood
MaxMax2016 authored Nov 13, 2023
1 parent ef2cea4 commit d15f3cd
Showing 1 changed file with 92 additions and 86 deletions.
pitch/diffusion.py (+92, -86)
@@ -9,24 +9,6 @@ def forward(self, x):
         return x * torch.tanh(torch.nn.functional.softplus(x))
 
 
-class Upsample(BaseModule):
-    def __init__(self, dim):
-        super(Upsample, self).__init__()
-        self.conv = torch.nn.ConvTranspose2d(dim, dim, (1,4), (1,2), 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
-class Downsample(BaseModule):
-    def __init__(self, dim):
-        super(Downsample, self).__init__()
-        self.conv = torch.nn.Conv2d(dim, dim, (1,3), (1,2), 1)
-
-    def forward(self, x):
-        return self.conv(x)
-
-
 class Rezero(BaseModule):
     def __init__(self, fn):
         super(Rezero, self).__init__()
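Aside: the removed Upsample and Downsample resample only the last (time) axis with stride (1, 2). A minimal shape check with arbitrary sizes shows the round trip; note that padding=1 also pads the 1-tall feature axis, since the kernels have height 1:

    import torch

    down = torch.nn.Conv2d(8, 8, (1, 3), (1, 2), 1)         # the removed Downsample conv
    up = torch.nn.ConvTranspose2d(8, 8, (1, 4), (1, 2), 1)  # the removed Upsample conv

    x = torch.randn(1, 8, 1, 100)
    y = down(x)
    print(y.shape)       # torch.Size([1, 8, 3, 50]): time halved, feature axis padded from 1 to 3
    print(up(y).shape)   # torch.Size([1, 8, 1, 100]): shape restored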
@@ -117,15 +99,15 @@ def forward(self, x, scale=1000):
 
 
 class GradLogPEstimator2d(BaseModule):
-    def __init__(self, dim, c_dim, n_mels, dim_mults=(1, 2, 4), groups=8, pe_scale=1000):
+    def __init__(self, n_feat, n_cond, dim, dim_mults=(1, 2, 4), groups=8, pe_scale=1000):
         super(GradLogPEstimator2d, self).__init__()
         self.dim = dim
         self.dim_mults = dim_mults
         self.groups = groups
         self.pe_scale = pe_scale
 
-        self.cond = torch.nn.Sequential(torch.nn.Conv1d(c_dim, dim * 4, 1), Mish(),
-                                        torch.nn.Conv1d(dim * 4, n_mels, 1))
+        self.cond = torch.nn.Sequential(torch.nn.Conv1d(n_cond, dim * 4, 1), Mish(),
+                                        torch.nn.Conv1d(dim * 4, n_feat, 1))
         self.time_pos_emb = SinusoidalPosEmb(dim)
         self.mlp = torch.nn.Sequential(torch.nn.Linear(dim, dim * 4), Mish(),
                                        torch.nn.Linear(dim * 4, dim))
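The cond head projects the conditioning channels onto the feature axis so the condition can later be stacked with mu and x. A sketch with hypothetical sizes, substituting torch.nn.Mish for the repository's own Mish:

    import torch

    n_cond, n_feat, dim, T = 256, 1, 64, 100   # hypothetical sizes
    cond = torch.nn.Sequential(
        torch.nn.Conv1d(n_cond, dim * 4, 1), torch.nn.Mish(),
        torch.nn.Conv1d(dim * 4, n_feat, 1))

    c = torch.randn(2, n_cond, T)
    print(cond(c).shape)   # torch.Size([2, 1, 100]): ready to stack with mu and x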
@@ -142,7 +124,7 @@ def __init__(self, dim, c_dim, n_mels, dim_mults=(1, 2, 4), groups=8, pe_scale=1
                 ResnetBlock(dim_in, dim_out, time_emb_dim=dim),
                 ResnetBlock(dim_out, dim_out, time_emb_dim=dim),
                 Residual(Rezero(LinearAttention(dim_out))),
-                Downsample(dim_out) if not is_last else torch.nn.Identity()]))
+                torch.nn.Identity()]))
 
         mid_dim = dims[-1]
         self.mid_block1 = ResnetBlock(mid_dim, mid_dim, time_emb_dim=dim)
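The dims/in_out construction sits outside this hunk. In Grad-TTS-style U-Nets it is conventionally derived from dim_mults as below; this is a sketch of that convention, not the file's hidden lines (the 3 input channels correspond to the stack of mu, x and c in forward):

    dim, dim_mults = 64, (1, 2, 4)
    dims = [3, *(dim * m for m in dim_mults)]   # 3 input channels: [mu, x, c]
    in_out = list(zip(dims[:-1], dims[1:]))
    print(in_out)   # [(3, 64), (64, 128), (128, 256)]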
@@ -151,14 +133,14 @@ def __init__(self, dim, c_dim, n_mels, dim_mults=(1, 2, 4), groups=8, pe_scale=1
 
         for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):  # 2 ups
             self.ups.append(torch.nn.ModuleList([
-                ResnetBlock(dim_out * 2, dim_in, time_emb_dim=dim),
+                ResnetBlock(dim_out, dim_in, time_emb_dim=dim),
                 ResnetBlock(dim_in, dim_in, time_emb_dim=dim),
                 Residual(Rezero(LinearAttention(dim_in))),
-                Upsample(dim_in)]))
+                torch.nn.Identity()]))
         self.final_block = Block(dim, dim)
         self.final_conv = torch.nn.Conv2d(dim, 1, 1)
 
-    def forward(self, c, x, mask, mu, t):
+    def forward(self, x, mask, mu, c, t):
 
         t = self.time_pos_emb(t, scale=self.pe_scale)
         t = self.mlp(t)
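With the reordered signature, a call to the estimator looks as follows. The sizes are assumptions chosen to be consistent with the stacking and masking shown below, and they assume the hidden context applies self.cond to c before the stack:

    import torch

    estimator = GradLogPEstimator2d(n_feat=1, n_cond=256, dim=64)  # hypothetical sizes
    B, T = 2, 100
    x = torch.randn(B, 1, T)      # noisy f0
    mu = torch.randn(B, 1, T)     # predicted mean
    c = torch.randn(B, 256, T)    # conditioning features
    mask = torch.ones(B, 1, T)
    t = torch.rand(B)
    score = estimator(x, mask, mu, c, t)   # [B, 1, T] after the final squeeze(1)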
@@ -167,100 +149,124 @@ def forward(self, c, x, mask, mu, t):
         x = torch.stack([mu, x, c], 1)
         mask = mask.unsqueeze(1)
 
         hiddens = []
-        masks = [mask]
         for resnet1, resnet2, attn, downsample in self.downs:
-            mask_down = masks[-1]
-            x = resnet1(x, mask_down, t)
-            x = resnet2(x, mask_down, t)
+            x = resnet1(x, mask, t)
+            x = resnet2(x, mask, t)
             x = attn(x)
             hiddens.append(x)
-            x = downsample(x * mask_down)
-            masks.append(mask_down[:, :, :, ::2])
+            x = downsample(x * mask)
 
-        masks = masks[:-1]
-        mask_mid = masks[-1]
-        x = self.mid_block1(x, mask_mid, t)
+        x = self.mid_block1(x, mask, t)
         x = self.mid_attn(x)
-        x = self.mid_block2(x, mask_mid, t)
+        x = self.mid_block2(x, mask, t)
 
         for resnet1, resnet2, attn, upsample in self.ups:
-            mask_up = masks.pop()
-            x = torch.cat((x, hiddens.pop()), dim=1)
-            x = resnet1(x, mask_up, t)
-            x = resnet2(x, mask_up, t)
+            x = resnet1(x, mask, t)
+            x = resnet2(x, mask, t)
             x = attn(x)
-            x = upsample(x * mask_up)
+            x = upsample(x * mask)
 
         x = self.final_block(x, mask)
         output = self.final_conv(x * mask)
 
         return (output * mask).squeeze(1)
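The stacked input turns the three [B, n_feat, T] streams into one 3-channel 2-D map, and the unsqueezed mask broadcasts across channels and feature bins; a quick shape check with hypothetical sizes:

    import torch

    B, n_feat, T = 2, 1, 100                 # hypothetical sizes
    mu = torch.randn(B, n_feat, T)
    x = torch.randn(B, n_feat, T)
    c = torch.randn(B, n_feat, T)            # condition already projected to n_feat
    inp = torch.stack([mu, x, c], 1)         # [B, 3, n_feat, T]
    mask = torch.ones(B, 1, T).unsqueeze(1)  # [B, 1, 1, T]
    print((inp * mask).shape)                # torch.Size([2, 3, 1, 100])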


-def get_noise(t, beta_init, beta_term, cumulative=False):
-    if cumulative:
-        noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)
-    else:
-        noise = beta_init + (beta_term - beta_init)*t
-    return noise
-
-
 class Diffusion(BaseModule):
-    def __init__(self, n_mels, dim, c_dim, beta_min=0.05, beta_max=20, pe_scale=1000):
+    def __init__(self, n_feat, n_cond, dim, beta_min=0.05, beta_max=20, pe_scale=1000):
         super(Diffusion, self).__init__()
-        self.n_mels = n_mels
+        self.estimator = GradLogPEstimator2d(n_feat, n_cond, dim, pe_scale=pe_scale)
+        self.n_feat = n_feat
         self.beta_min = beta_min
         self.beta_max = beta_max
-        self.estimator = GradLogPEstimator2d(dim, c_dim, n_mels, pe_scale=pe_scale)
 
-    def forward_diffusion(self, mel, mask, mu, t):
-        time = t.unsqueeze(-1).unsqueeze(-1)
-        cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
-        mean = mel*torch.exp(-0.5*cum_noise) + mu*(1.0 - torch.exp(-0.5*cum_noise))
-        variance = 1.0 - torch.exp(-cum_noise)
-        z = torch.randn(mel.shape, dtype=mel.dtype, device=mel.device,
-                        requires_grad=False)
-        xt = mean + z * torch.sqrt(variance)
-        return xt * mask, z * mask
+    def get_beta(self, t):
+        beta = self.beta_min + (self.beta_max - self.beta_min) * t
+        return beta
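get_beta is the linear schedule beta(t) = beta_min + (beta_max - beta_min) * t; its exact integral is what get_gamma uses below. A quick check with the defaults beta_min=0.05, beta_max=20:

    beta_min, beta_max = 0.05, 20.0

    def beta(t):                 # instantaneous noise rate
        return beta_min + (beta_max - beta_min) * t

    def beta_integral(s, t):     # exact integral of beta over [s, t]
        return (beta_min + 0.5 * (beta_max - beta_min) * (t + s)) * (t - s)

    print(beta(0.0), beta(1.0))        # 0.05 20.0
    print(beta_integral(0.0, 1.0))     # 10.025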

+    def get_gamma(self, s, t, p=1.0, use_torch=False):
+        beta_integral = self.beta_min + 0.5 * (self.beta_max - self.beta_min) * (t + s)
+        beta_integral *= (t - s)
+        if use_torch:
+            gamma = torch.exp(-0.5 * p * beta_integral).unsqueeze(-1).unsqueeze(-1)
+        else:
+            gamma = math.exp(-0.5 * p * beta_integral)
+        return gamma
+
+    def get_mu(self, s, t):
+        a = self.get_gamma(s, t)
+        b = 1.0 - self.get_gamma(0, s, p=2.0)
+        c = 1.0 - self.get_gamma(0, t, p=2.0)
+        return a * b / c
+
+    def get_nu(self, s, t):
+        a = self.get_gamma(0, s)
+        b = 1.0 - self.get_gamma(s, t, p=2.0)
+        c = 1.0 - self.get_gamma(0, t, p=2.0)
+        return a * b / c
+
+    def get_sigma(self, s, t):
+        a = 1.0 - self.get_gamma(0, s, p=2.0)
+        b = 1.0 - self.get_gamma(s, t, p=2.0)
+        c = 1.0 - self.get_gamma(0, t, p=2.0)
+        return math.sqrt(a * b / c)
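These coefficients match the Gaussian bridge q(x_s | x_t, x_0) of a variance-preserving process with gamma(s, t) = exp(-0.5 * integral of beta over [s, t]). Under that reading they satisfy two consistency identities, checked numerically below with standalone copies of the methods:

    import math

    beta_min, beta_max = 0.05, 20.0

    def gamma(s, t, p=1.0):
        # exp(-0.5 * p * integral of beta over [s, t]), as in Diffusion.get_gamma
        bi = (beta_min + 0.5 * (beta_max - beta_min) * (t + s)) * (t - s)
        return math.exp(-0.5 * p * bi)

    def mu(s, t):
        return gamma(s, t) * (1.0 - gamma(0, s, p=2.0)) / (1.0 - gamma(0, t, p=2.0))

    def nu(s, t):
        return gamma(0, s) * (1.0 - gamma(s, t, p=2.0)) / (1.0 - gamma(0, t, p=2.0))

    def sigma(s, t):
        return math.sqrt((1.0 - gamma(0, s, p=2.0)) * (1.0 - gamma(s, t, p=2.0))
                         / (1.0 - gamma(0, t, p=2.0)))

    s, t = 0.4, 0.5
    # marginal consistency: the two-step mean recovers E[x_s | x_0]
    assert abs(mu(s, t) * gamma(0, t) + nu(s, t) - gamma(0, s)) < 1e-12
    # law of total variance
    assert abs(sigma(s, t) ** 2 + mu(s, t) ** 2 * (1.0 - gamma(0, t, p=2.0))
               - (1.0 - gamma(0, s, p=2.0))) < 1e-12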

     @torch.no_grad()
-    def reverse_diffusion(self, c, z, mask, mu, n_timesteps, stoc=False):
+    def reverse_diffusion(self, z, mask, mu, mu_c, n_timesteps):
         h = 1.0 / n_timesteps
         xt = z * mask
 
         for i in range(n_timesteps):
-            t = (1.0 - (i + 0.5)*h) * torch.ones(z.shape[0], dtype=z.dtype,
-                                                 device=z.device)
-            time = t.unsqueeze(-1).unsqueeze(-1)
-            noise_t = get_noise(time, self.beta_min, self.beta_max,
-                                cumulative=False)
-            if stoc:  # adds stochastic term
-                dxt_det = 0.5 * (mu - xt) - self.estimator(c, xt, mask, mu, t)
-                dxt_det = dxt_det * noise_t * h
-                dxt_stoc = torch.randn(z.shape, dtype=z.dtype, device=z.device,
-                                       requires_grad=False)
-                dxt_stoc = dxt_stoc * torch.sqrt(noise_t * h)
-                dxt = dxt_det + dxt_stoc
-            else:
-                dxt = 0.5 * (mu - xt - self.estimator(c, xt, mask, mu, t))
-                dxt = dxt * noise_t * h
+            t = 1.0 - i * h
+            time = t * torch.ones(z.shape[0], dtype=z.dtype, device=z.device)
+            beta_t = self.get_beta(t)
+
+            kappa = self.get_gamma(0, t - h) * (1.0 - self.get_gamma(t - h, t, p=2.0))
+            kappa /= (self.get_gamma(0, t) * beta_t * h)
+            kappa -= 1.0
+            omega = self.get_nu(t - h, t) / self.get_gamma(0, t)
+            omega += self.get_mu(t - h, t)
+            omega -= (0.5 * beta_t * h + 1.0)
+            sigma = self.get_sigma(t - h, t)
+
+            dxt = (mu - xt) * (0.5 * beta_t * h + omega)
+            dxt -= (self.estimator(xt, mask, mu, mu_c, time)) * (1.0 + kappa) * (beta_t * h)
+            dxt += torch.randn_like(z, device=z.device) * sigma
             xt = (xt - dxt) * mask
 
         return xt

     @torch.no_grad()
-    def forward(self, c, z, mask, mu, n_timesteps, stoc=False):
-        return self.reverse_diffusion(c, z, mask, mu, n_timesteps, stoc)
+    def forward(self, z, mask, mu, mu_c, n_timesteps):
+        return self.reverse_diffusion(z, mask, mu, mu_c, n_timesteps)
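A minimal sampling sketch. Everything named here is an assumption about the calling convention: mu is the coarse f0 prediction, mu_c the conditioning features, and the prior at t=1 is N(mu, I), which matches forward_diffusion below, whose mean decays to mu and whose variance approaches 1:

    import torch

    B, n_feat, n_cond, T = 1, 1, 256, 200          # hypothetical sizes
    diffusion = Diffusion(n_feat, n_cond, dim=64)

    mu = torch.zeros(B, n_feat, T)                 # coarse f0 prediction (assumed input)
    mu_c = torch.randn(B, n_cond, T)               # conditioning features (assumed input)
    mask = torch.ones(B, 1, T)                     # all frames valid
    z = mu + torch.randn_like(mu)                  # sample from the t=1 prior N(mu, I)
    f0 = diffusion(z, mask, mu, mu_c, n_timesteps=50)   # [B, n_feat, T]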

+    # train: "mel" is the ground-truth f0
+    def get_noise(self, t, beta_init, beta_term, cumulative=False):
+        if cumulative:
+            noise = beta_init*t + 0.5*(beta_term - beta_init)*(t**2)
+        else:
+            noise = beta_init + (beta_term - beta_init)*t
+        return noise
+
+    def forward_diffusion(self, mel, mask, mu, t):
+        time = t.unsqueeze(-1).unsqueeze(-1)
+        cum_noise = self.get_noise(time, self.beta_min, self.beta_max, cumulative=True)
+        mean = mel*torch.exp(-0.5*cum_noise) + mu*(1.0 - torch.exp(-0.5*cum_noise))
+        variance = 1.0 - torch.exp(-cum_noise)
+        z = torch.randn(mel.shape, dtype=mel.dtype, device=mel.device,
+                        requires_grad=False)
+        xt = mean + z * torch.sqrt(variance)
+        return xt * mask, z * mask
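Endpoint arithmetic for forward_diffusion with the default schedule: the data weight exp(-0.5 * cum_noise) decays from 1 toward exp(-5.0125), roughly 0.007 at t=1, while the variance 1 - exp(-cum_noise) rises toward 1, so x_1 is essentially N(mu, I):

    import math

    beta_min, beta_max = 0.05, 20.0
    cum = lambda t: beta_min * t + 0.5 * (beta_max - beta_min) * t ** 2

    for t in (0.01, 0.5, 1.0):
        w = math.exp(-0.5 * cum(t))     # weight on the ground-truth f0
        var = 1.0 - math.exp(-cum(t))   # noise variance
        print(f"t={t}: data weight {w:.3f}, variance {var:.3f}")
    # t=1.0 gives data weight ~0.007 and variance ~1.000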

-    def loss_t(self, c, mel, mask, mu, t):
+    def loss_t(self, mel, mask, mu, mu_c, t):
         xt, z = self.forward_diffusion(mel, mask, mu, t)
         time = t.unsqueeze(-1).unsqueeze(-1)
-        cum_noise = get_noise(time, self.beta_min, self.beta_max, cumulative=True)
-        noise_estimation = self.estimator(c, xt, mask, mu, t)
+        cum_noise = self.get_noise(time, self.beta_min, self.beta_max, cumulative=True)
+        noise_estimation = self.estimator(xt, mask, mu, mu_c, t)
         noise_estimation *= torch.sqrt(1.0 - torch.exp(-cum_noise))
-        loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask)*self.n_mels)
+        loss = torch.sum((noise_estimation + z)**2) / (torch.sum(mask)*self.n_feat)
         return loss, xt

-    def compute_loss(self, c, mel, mask, mu, offset=1e-5):
+    def compute_loss(self, mel, mask, mu, mu_c, offset=1e-5):
         t = torch.rand(mel.shape[0], dtype=mel.dtype, device=mel.device, requires_grad=False)
         t = torch.clamp(t, offset, 1.0 - offset)
-        return self.loss_t(c, mel, mask, mu, t)
+        return self.loss_t(mel, mask, mu, mu_c, t)
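And a matching training-step sketch, using the same hypothetical shapes as the sampling example; f0_gt stands in for the "mel" argument:

    import torch

    B, n_feat, n_cond, T = 1, 1, 256, 200
    diffusion = Diffusion(n_feat, n_cond, dim=64)

    f0_gt = torch.randn(B, n_feat, T)     # ground-truth f0 (the "mel" argument)
    mu = torch.zeros(B, n_feat, T)
    mu_c = torch.randn(B, n_cond, T)
    mask = torch.ones(B, 1, T)

    loss, xt = diffusion.compute_loss(f0_gt, mask, mu, mu_c)
    loss.backward()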
