Skip to content

Commit

Permalink
Unify gaussian likelihoods (#104)
Browse files Browse the repository at this point in the history
* Unify gaussian likelihoods

* remove print
  • Loading branch information
clementchadebec authored Sep 6, 2023
1 parent ca55647 commit 337c2e1
Show file tree
Hide file tree
Showing 17 changed files with 17 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/pythae/models/adversarial_ae/adversarial_ae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def loss_function(self, recon_x, x, z, z_prior):

if self.model_config.reconstruction_loss == "mse":

recon_loss = F.mse_loss(
recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/beta_tc_vae/beta_tc_vae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def loss_function(self, recon_x, x, mu, log_var, z, dataset_size):

if self.model_config.reconstruction_loss == "mse":

recon_loss = F.mse_loss(
recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/beta_vae/beta_vae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ def loss_function(self, recon_x, x, mu, log_var, z):

if self.model_config.reconstruction_loss == "mse":

recon_loss = F.mse_loss(
recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/ciwae/ciwae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def loss_function(self, recon_x, x, mu, log_var, z):

if self.model_config.reconstruction_loss == "mse":

recon_loss = F.mse_loss(
recon_loss = 0.5 * F.mse_loss(
recon_x,
x.reshape(recon_x.shape[0], -1)
.unsqueeze(1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def loss_function(self, recon_x, x, mu, log_var, z, epoch):

if self.model_config.reconstruction_loss == "mse":

recon_loss = F.mse_loss(
recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/factor_vae/factor_vae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def loss_function(self, recon_x, x, mu, log_var, z, z_bis_permuted):

if self.model_config.reconstruction_loss == "mse":

recon_loss = F.mse_loss(
recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/info_vae/info_vae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def loss_function(self, recon_x, x, z, z_prior, mu, log_var):

if self.model_config.reconstruction_loss == "mse":

recon_loss = F.mse_loss(
recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/iwae/iwae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def loss_function(self, recon_x, x, mu, log_var, z):

if self.model_config.reconstruction_loss == "mse":

recon_loss = F.mse_loss(
recon_loss = 0.5 * F.mse_loss(
recon_x,
x.reshape(recon_x.shape[0], -1)
.unsqueeze(1)
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/miwae/miwae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ def loss_function(self, recon_x, x, mu, log_var, z):

if self.model_config.reconstruction_loss == "mse":

recon_loss = F.mse_loss(
recon_loss = 0.5 * F.mse_loss(
recon_x,
x.reshape(recon_x.shape[0], -1)
.unsqueeze(1)
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/piwae/piwae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def loss_function(self, recon_x, x, mu, log_var, z):

if self.model_config.reconstruction_loss == "mse":

recon_loss = F.mse_loss(
recon_loss = 0.5 * F.mse_loss(
recon_x,
x.reshape(recon_x.shape[0], -1)
.unsqueeze(1)
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/pvae/pvae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ def loss_function(self, recon_x, x, z, qz_x):

if self.model_config.reconstruction_loss == "mse":

recon_loss = F.mse_loss(
recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/svae/svae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def loss_function(self, recon_x, x, loc, concentration, z):

if self.model_config.reconstruction_loss == "mse":

recon_loss = F.mse_loss(
recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
Expand Down
3 changes: 1 addition & 2 deletions src/pythae/models/vae/vae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,7 @@ def forward(self, inputs: BaseDataset, **kwargs):
def loss_function(self, recon_x, x, mu, log_var, z):

if self.model_config.reconstruction_loss == "mse":

recon_loss = F.mse_loss(
recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/vae_gan/vae_gan_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,7 +192,7 @@ def loss_function(self, recon_x, x, z, z_prior, mu, log_var):
)[f"embedding_layer_{self.reconstruction_layer}"]

# MSE in feature space
recon_loss = F.mse_loss(
recon_loss = 0.5 * F.mse_loss(
true_discr_layer.reshape(N, -1),
recon_discr_layer.reshape(N, -1),
reduction="none",
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/vae_iaf/vae_iaf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def loss_function(self, recon_x, x, mu, log_var, z0, zk, log_abs_det_jac):

if self.model_config.reconstruction_loss == "mse":

recon_loss = F.mse_loss(
recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/vae_lin_nf/vae_lin_nf_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def loss_function(self, recon_x, x, mu, log_var, z0, zk, log_abs_det_jac):

if self.model_config.reconstruction_loss == "mse":

recon_loss = F.mse_loss(
recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
Expand Down
2 changes: 1 addition & 1 deletion src/pythae/models/vamp/vamp_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ def loss_function(self, recon_x, x, mu, log_var, z, epoch):

if self.model_config.reconstruction_loss == "mse":

recon_loss = F.mse_loss(
recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
Expand Down

0 comments on commit 337c2e1

Please sign in to comment.