
Commit

09/06/2024 Update
* Corrected classifier-free guidance method.
Augus1999 authored Jun 9, 2024
1 parent 798ff07 commit 79068bc
Showing 1 changed file: bayesianflow_for_chem/model.py (13 additions, 19 deletions)
@@ -300,16 +300,24 @@ def calc_cts_alpha(self, t: Tensor) -> Tensor:
         return 2 * a / b / self.K
 
     def discrete_output_distribution(
-        self, theta: Tensor, t: Tensor, y: Optional[Tensor]
+        self, theta: Tensor, t: Tensor, y: Optional[Tensor], w: Optional[float]
     ) -> Tensor:
         """
         :param theta: input distribution; shape: (n_b, n_t, n_vocab)
         :param t: continuous time in [0, 1]; shape: (n_b, 1)
         :param y: conditioning vector; shape: (n_b, 1, n_f)
+        :param w: guidance strength controlling the conditional generation
         :return: output distribution; shape: (n_b, n_t, n_vocab)
         """
         theta = 2 * theta - 1  # rescale to [-1, 1]
-        return softmax(self.forward(theta, t, None, y), -1)
+        if w is None:
+            return softmax(self.forward(theta, t, None, y), -1)
+        elif y is None:
+            return softmax(self.forward(theta, t, None, None), -1)
+        else:
+            p_cond = self.forward(theta, t, None, y)
+            p_uncond = self.forward(theta, t, None, None)
+            return softmax((1 + w) * p_cond - w * p_uncond, -1)
 
     def cts_loss(self, x: Tensor, t: Tensor, y: Optional[Tensor]) -> Tensor:
         """
@@ -325,7 +333,7 @@ def cts_loss(self, x: Tensor, t: Tensor, y: Optional[Tensor]) -> Tensor:
         mu = beta * (self.K * e_x - 1)
         sigma = (beta * self.K).sqrt()
         theta = softmax(mu + sigma * torch.randn_like(mu), -1)
-        e_hat = self.discrete_output_distribution(theta, t, y)
+        e_hat = self.discrete_output_distribution(theta, t, y, None)
         cts_loss = self.K * (e_x - e_hat).pow(2) * self.calc_cts_alpha(t)[..., None]
         return cts_loss.mean()
 
@@ -374,29 +382,15 @@ def sample(
         )
         for i in torch.linspace(1, sample_step, sample_step, device=self.beta.device):
             t = (i - 1).view(1, 1).repeat(batch_size, 1) / sample_step
-            if y is None:
-                p = self.discrete_output_distribution(theta, t, None)
-            else:
-                p = (1 + guidance_strength) * self.discrete_output_distribution(
-                    theta, t, y
-                ) - guidance_strength * self.discrete_output_distribution(
-                    theta, t, None
-                )
+            p = self.discrete_output_distribution(theta, t, y, guidance_strength)
             alpha = self.calc_discrete_alpha(t, t + 1 / sample_step)[..., None]
             e_k = nn.functional.one_hot(torch.argmax(p, -1), self.K).float()
             mu = alpha * (self.K * e_k - 1)
             sigma = (alpha * self.K).sqrt()
             theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
             theta = theta / theta.sum(-1, True)
         t_final = torch.ones((batch_size, 1), device=self.beta.device)
-        if y is None:
-            p = self.discrete_output_distribution(theta, t_final, None)
-        else:
-            p = (1 + guidance_strength) * self.discrete_output_distribution(
-                theta, t_final, y
-            ) - guidance_strength * self.discrete_output_distribution(
-                theta, t_final, None
-            )
+        p = self.discrete_output_distribution(theta, t_final, y, guidance_strength)
         return torch.argmax(p, -1)
 
     def inference(self, x: Tensor, mlp: nn.Module) -> Tensor:
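With the guidance handled inside discrete_output_distribution, each step of the sampling loop reduces to: query the (possibly guided) output distribution, take its argmax as a one-hot estimate, and fold the accuracy alpha added at that step back into the input distribution theta. A self-contained sketch of one such update, using placeholder shapes and a random stand-in for the network output:

    import torch
    from torch import nn
    from torch.nn.functional import softmax

    K = 4                                   # illustrative vocabulary size
    theta = torch.full((2, 5, K), 1.0 / K)  # illustrative uniform starting distribution
    p = softmax(torch.randn(2, 5, K), -1)   # stand-in for the guided output distribution
    alpha = torch.rand(2, 1, 1)             # accuracy added at this step

    e_k = nn.functional.one_hot(torch.argmax(p, -1), K).float()
    mu = alpha * (K * e_k - 1)
    sigma = (alpha * K).sqrt()
    theta = (mu + sigma * torch.randn_like(mu)).exp() * theta
    theta = theta / theta.sum(-1, True)     # renormalise the updated input distribution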
