From 4844790e10b1c4a82a2ea81106b821de82fd7c96 Mon Sep 17 00:00:00 2001 From: "Soumick Chatterjee, PhD" Date: Wed, 19 Jul 2023 19:56:06 +0200 Subject: [PATCH] predict overridden in rhvae (#98) * predict overridden in rhvae * rhvae predict corrected --------- Co-authored-by: soumick.chatterjee --- src/pythae/models/rhvae/rhvae_model.py | 66 ++++++++++++++++++++++++++ 1 file changed, 66 insertions(+) diff --git a/src/pythae/models/rhvae/rhvae_model.py b/src/pythae/models/rhvae/rhvae_model.py index 9b3cbfad..b37eabfc 100644 --- a/src/pythae/models/rhvae/rhvae_model.py +++ b/src/pythae/models/rhvae/rhvae_model.py @@ -263,6 +263,72 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: return output + def predict(self, inputs: torch.Tensor) -> ModelOutput: + """The input data is encoded and decoded without computing loss + + Args: + inputs (torch.Tensor): The input data to be reconstructed, as well as to generate the embedding. + + Returns: + ModelOutput: An instance of ModelOutput containing reconstruction, raw embedding (output of encoder), and the final embedding (output of metric) + """ + encoder_output = self.encoder(inputs) + mu, log_var = encoder_output.embedding, encoder_output.log_covariance + + std = torch.exp(0.5 * log_var) + z0, _ = self._sample_gauss(mu, std) + + z = z0 + + G = self.G(z) + G_inv = self.G_inv(z) + L = torch.linalg.cholesky(G) + + G_log_det = -torch.logdet(G_inv) + + gamma = torch.randn_like(z0, device=inputs.device) + rho = gamma / self.beta_zero_sqrt + beta_sqrt_old = self.beta_zero_sqrt + + # sample \rho from N(0, G) + rho = (L @ rho.unsqueeze(-1)).squeeze(-1) + + recon_x = self.decoder(z)["reconstruction"] + + for k in range(self.n_lf): + + # perform leapfrog steps + + # step 1 + rho_ = self._leap_step_1(recon_x, inputs, z, rho, G_inv, G_log_det) + + # step 2 + z = self._leap_step_2(recon_x, inputs, z, rho_, G_inv, G_log_det) + + recon_x = self.decoder(z)["reconstruction"] + + # compute metric value on new z using final metric + G = self.G(z) + G_inv = self.G_inv(z) + + G_log_det = -torch.logdet(G_inv) + + # step 3 + rho__ = self._leap_step_3(recon_x, inputs, z, rho_, G_inv, G_log_det) + + # tempering + beta_sqrt = self._tempering(k + 1, self.n_lf) + rho = (beta_sqrt_old / beta_sqrt) * rho__ + beta_sqrt_old = beta_sqrt + + output = ModelOutput( + recon_x=recon_x, + raw_embedding=encoder_output.embedding, + embedding=z if self.n_lf > 0 else encoder_output.embedding, + ) + + return output + def _leap_step_1(self, recon_x, x, z, rho, G_inv, G_log_det, steps=3): """ Resolves first equation of generalized leapfrog integrator