Skip to content

Commit

Permalink
migration to pydantic>=2.* (#105)
Browse files Browse the repository at this point in the history
* Unify gaussian likelihoods

* remove print

* `pydantic` migration

* fix VAEGAN and planar flows

* black and isort

* cleaning

* prepare release
  • Loading branch information
clementchadebec authored Sep 6, 2023
1 parent 337c2e1 commit 24939dc
Show file tree
Hide file tree
Showing 39 changed files with 187 additions and 125 deletions.
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
cloudpickle>=2.1.0
imageio
numpy>=1.19
pydantic==1.8.2
pydantic>=2.0
scikit-learn
scipy>=1.7.1
torch>=1.10.1
Expand Down
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="pythae",
version="0.1.1",
version="0.1.2",
author="Clement Chadebec (HekA team INRIA)",
author_email="[email protected]",
description="Unifying Generative Autoencoders in Python",
Expand All @@ -29,7 +29,7 @@
"cloudpickle>=2.1.0",
"imageio",
"numpy>=1.19",
"pydantic==1.8.2",
"pydantic>=2.0",
"scikit-learn",
"scipy>=1.7.1",
"torch>=1.10.1",
Expand Down
13 changes: 8 additions & 5 deletions src/pythae/models/adversarial_ae/adversarial_ae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,11 +154,14 @@ def loss_function(self, recon_x, x, z, z_prior):

if self.model_config.reconstruction_loss == "mse":

recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
recon_loss = (
0.5
* F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
)

elif self.model_config.reconstruction_loss == "bce":

Expand Down
4 changes: 2 additions & 2 deletions src/pythae/models/base/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def reconstruct(self, inputs: torch.Tensor):
torch.Tensor: A tensor of shape [B x input_dim] containing the reconstructed samples.
"""
return self(DatasetOutput(data=inputs)).recon_x

def embed(self, inputs: torch.Tensor) -> torch.Tensor:
"""Return the embeddings of the input data.
Expand All @@ -127,7 +127,7 @@ def embed(self, inputs: torch.Tensor) -> torch.Tensor:
torch.Tensor: A tensor of shape [B x latent_dim] containing the embeddings.
"""
return self(DatasetOutput(data=inputs)).z

def predict(self, inputs: torch.Tensor) -> ModelOutput:
"""The input data is encoded and decoded without computing loss
Expand Down
13 changes: 8 additions & 5 deletions src/pythae/models/beta_tc_vae/beta_tc_vae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,14 @@ def loss_function(self, recon_x, x, mu, log_var, z, dataset_size):

if self.model_config.reconstruction_loss == "mse":

recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
recon_loss = (
0.5
* F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
)

elif self.model_config.reconstruction_loss == "bce":

Expand Down
13 changes: 8 additions & 5 deletions src/pythae/models/beta_vae/beta_vae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,11 +84,14 @@ def loss_function(self, recon_x, x, mu, log_var, z):

if self.model_config.reconstruction_loss == "mse":

recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
recon_loss = (
0.5
* F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
)

elif self.model_config.reconstruction_loss == "bce":

Expand Down
3 changes: 2 additions & 1 deletion src/pythae/models/ciwae/ciwae_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,5 +19,6 @@ class CIWAEConfig(VAEConfig):
number_samples: int = 10
beta: float = 0.5

def __post_init_post_parse__(self):
def __post_init__(self):
super().__post_init__()
assert 0 <= self.beta <= 1, f"Beta parameter must be in [0-1]. Got {self.beta}."
17 changes: 10 additions & 7 deletions src/pythae/models/ciwae/ciwae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,16 @@ def loss_function(self, recon_x, x, mu, log_var, z):

if self.model_config.reconstruction_loss == "mse":

recon_loss = 0.5 * F.mse_loss(
recon_x,
x.reshape(recon_x.shape[0], -1)
.unsqueeze(1)
.repeat(1, self.n_samples, 1),
reduction="none",
).sum(dim=-1)
recon_loss = (
0.5
* F.mse_loss(
recon_x,
x.reshape(recon_x.shape[0], -1)
.unsqueeze(1)
.repeat(1, self.n_samples, 1),
reduction="none",
).sum(dim=-1)
)

elif self.model_config.reconstruction_loss == "bce":

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,14 @@ def loss_function(self, recon_x, x, mu, log_var, z, epoch):

if self.model_config.reconstruction_loss == "mse":

recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
recon_loss = (
0.5
* F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
)

elif self.model_config.reconstruction_loss == "bce":

Expand Down
1 change: 0 additions & 1 deletion src/pythae/models/factor_vae/factor_vae_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,4 +16,3 @@ class FactorVAEConfig(VAEConfig):
"""
gamma: float = 2.0
uses_default_discriminator: bool = True
discriminator_input_dim: int = None
15 changes: 9 additions & 6 deletions src/pythae/models/factor_vae/factor_vae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput:
recon_x=recon_x,
recon_x_indices=idx_1,
z=z,
z_bis_permuted=z_bis_permuted
z_bis_permuted=z_bis_permuted,
)

return output
Expand All @@ -137,11 +137,14 @@ def loss_function(self, recon_x, x, mu, log_var, z, z_bis_permuted):

if self.model_config.reconstruction_loss == "mse":

recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
recon_loss = (
0.5
* F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
)

elif self.model_config.reconstruction_loss == "bce":

Expand Down
13 changes: 8 additions & 5 deletions src/pythae/models/info_vae/info_vae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,11 +92,14 @@ def loss_function(self, recon_x, x, z, z_prior, mu, log_var):

if self.model_config.reconstruction_loss == "mse":

recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
recon_loss = (
0.5
* F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
)

elif self.model_config.reconstruction_loss == "bce":

Expand Down
17 changes: 10 additions & 7 deletions src/pythae/models/iwae/iwae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,13 +94,16 @@ def loss_function(self, recon_x, x, mu, log_var, z):

if self.model_config.reconstruction_loss == "mse":

recon_loss = 0.5 * F.mse_loss(
recon_x,
x.reshape(recon_x.shape[0], -1)
.unsqueeze(1)
.repeat(1, self.n_samples, 1),
reduction="none",
).sum(dim=-1)
recon_loss = (
0.5
* F.mse_loss(
recon_x,
x.reshape(recon_x.shape[0], -1)
.unsqueeze(1)
.repeat(1, self.n_samples, 1),
reduction="none",
).sum(dim=-1)
)

elif self.model_config.reconstruction_loss == "bce":

Expand Down
19 changes: 11 additions & 8 deletions src/pythae/models/miwae/miwae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,14 +103,17 @@ def loss_function(self, recon_x, x, mu, log_var, z):

if self.model_config.reconstruction_loss == "mse":

recon_loss = 0.5 * F.mse_loss(
recon_x,
x.reshape(recon_x.shape[0], -1)
.unsqueeze(1)
.unsqueeze(1)
.repeat(1, self.gradient_n_estimates, self.n_samples, 1),
reduction="none",
).sum(dim=-1)
recon_loss = (
0.5
* F.mse_loss(
recon_x,
x.reshape(recon_x.shape[0], -1)
.unsqueeze(1)
.unsqueeze(1)
.repeat(1, self.gradient_n_estimates, self.n_samples, 1),
reduction="none",
).sum(dim=-1)
)

elif self.model_config.reconstruction_loss == "bce":

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,8 @@ class PixelCNNConfig(BaseNFConfig):
n_layers: int = 10
kernel_size: int = 5

def __post_init_post_parse__(self):
def __post_init__(self):
super().__post_init__()
assert (
self.kernel_size % 2 == 1
), f"Wrong kernel size provided. The kernel size must be odd. Got {self.kernel_size}."
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@ class PlanarFlowConfig(BaseNFConfig):

activation: str = "tanh"

def __post_init_post_parse__(self):
def __post_init__(self):
super().__post_init__()
assert self.activation in ["linear", "tanh", "elu"], (
f"'{self.activation}' doesn't correspond to an activation handled by the model. "
"Available activations ['linear', 'tanh', 'elu']"
Expand Down
19 changes: 11 additions & 8 deletions src/pythae/models/piwae/piwae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,14 +111,17 @@ def loss_function(self, recon_x, x, mu, log_var, z):

if self.model_config.reconstruction_loss == "mse":

recon_loss = 0.5 * F.mse_loss(
recon_x,
x.reshape(recon_x.shape[0], -1)
.unsqueeze(1)
.unsqueeze(1)
.repeat(1, self.gradient_n_estimates, self.n_samples, 1),
reduction="none",
).sum(dim=-1)
recon_loss = (
0.5
* F.mse_loss(
recon_x,
x.reshape(recon_x.shape[0], -1)
.unsqueeze(1)
.unsqueeze(1)
.repeat(1, self.gradient_n_estimates, self.n_samples, 1),
reduction="none",
).sum(dim=-1)
)

elif self.model_config.reconstruction_loss == "bce":

Expand Down
13 changes: 8 additions & 5 deletions src/pythae/models/pvae/pvae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,11 +128,14 @@ def loss_function(self, recon_x, x, z, qz_x):

if self.model_config.reconstruction_loss == "mse":

recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
recon_loss = (
0.5
* F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
)

elif self.model_config.reconstruction_loss == "bce":

Expand Down
4 changes: 2 additions & 2 deletions src/pythae/models/rhvae/rhvae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ def predict(self, inputs: torch.Tensor) -> ModelOutput:
z = self._leap_step_2(recon_x, inputs, z, rho_, G_inv, G_log_det)

recon_x = self.decoder(z)["reconstruction"]

# compute metric value on new z using final metric
G = self.G(z)
G_inv = self.G_inv(z)
Expand All @@ -328,7 +328,7 @@ def predict(self, inputs: torch.Tensor) -> ModelOutput:
)

return output

def _leap_step_1(self, recon_x, x, z, rho, G_inv, G_log_det, steps=3):
"""
Resolves first equation of generalized leapfrog integrator
Expand Down
13 changes: 8 additions & 5 deletions src/pythae/models/svae/svae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,11 +101,14 @@ def loss_function(self, recon_x, x, loc, concentration, z):

if self.model_config.reconstruction_loss == "mse":

recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
recon_loss = (
0.5
* F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
)

elif self.model_config.reconstruction_loss == "bce":

Expand Down
15 changes: 9 additions & 6 deletions src/pythae/models/vae/vae_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,11 +100,14 @@ def forward(self, inputs: BaseDataset, **kwargs):
def loss_function(self, recon_x, x, mu, log_var, z):

if self.model_config.reconstruction_loss == "mse":
recon_loss = 0.5 * F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
recon_loss = (
0.5
* F.mse_loss(
recon_x.reshape(x.shape[0], -1),
x.reshape(x.shape[0], -1),
reduction="none",
).sum(dim=-1)
)

elif self.model_config.reconstruction_loss == "bce":

Expand Down Expand Up @@ -161,7 +164,7 @@ def get_nll(self, data, n_samples=1, batch_size=100):
log_q_z_given_x = -0.5 * (
log_var + (z - mu) ** 2 / torch.exp(log_var)
).sum(dim=-1)
log_p_z = -0.5 * (z**2).sum(dim=-1)
log_p_z = -0.5 * (z ** 2).sum(dim=-1)

recon_x = self.decoder(z)["reconstruction"]

Expand Down
Loading

0 comments on commit 24939dc

Please sign in to comment.