diff --git a/requirements.txt b/requirements.txt
index 56f56dfc..1119db57 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,7 @@
 cloudpickle>=2.1.0
 imageio
 numpy>=1.19
-pydantic==1.8.2
+pydantic>=2.0
 scikit-learn
 scipy>=1.7.1
 torch>=1.10.1
diff --git a/setup.py b/setup.py
index 4f37cb22..1e5b496e 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
 setup(
     name="pythae",
-    version="0.1.1",
+    version="0.1.2",
     author="Clement Chadebec (HekA team INRIA)",
     author_email="clement.chadebec@inria.fr",
     description="Unifying Generative Autoencoders in Python",
@@ -29,7 +29,7 @@
         "cloudpickle>=2.1.0",
         "imageio",
         "numpy>=1.19",
-        "pydantic==1.8.2",
+        "pydantic>=2.0",
         "scikit-learn",
         "scipy>=1.7.1",
         "torch>=1.10.1",
diff --git a/src/pythae/models/adversarial_ae/adversarial_ae_model.py b/src/pythae/models/adversarial_ae/adversarial_ae_model.py
index 8689c04a..61125393 100644
--- a/src/pythae/models/adversarial_ae/adversarial_ae_model.py
+++ b/src/pythae/models/adversarial_ae/adversarial_ae_model.py
@@ -154,11 +154,14 @@ def loss_function(self, recon_x, x, z, z_prior):
 
         if self.model_config.reconstruction_loss == "mse":
 
-            recon_loss = 0.5 * F.mse_loss(
-                recon_x.reshape(x.shape[0], -1),
-                x.reshape(x.shape[0], -1),
-                reduction="none",
-            ).sum(dim=-1)
+            recon_loss = (
+                0.5
+                * F.mse_loss(
+                    recon_x.reshape(x.shape[0], -1),
+                    x.reshape(x.shape[0], -1),
+                    reduction="none",
+                ).sum(dim=-1)
+            )
 
         elif self.model_config.reconstruction_loss == "bce":
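Note: the `recon_loss` hunks above recur in nearly every model file below and are Black-style line wrapping only; the computed value is unchanged. A minimal self-contained check with dummy tensors (the tensor shapes here are illustrative, not from the patch):

```python
import torch
import torch.nn.functional as F

x = torch.randn(8, 3, 4, 4)  # dummy batch
recon_x = torch.randn_like(x)

# old formatting
before = 0.5 * F.mse_loss(
    recon_x.reshape(x.shape[0], -1),
    x.reshape(x.shape[0], -1),
    reduction="none",
).sum(dim=-1)

# new formatting
after = (
    0.5
    * F.mse_loss(
        recon_x.reshape(x.shape[0], -1),
        x.reshape(x.shape[0], -1),
        reduction="none",
    ).sum(dim=-1)
)

assert torch.equal(before, after)  # identical per-sample losses, shape [8]
```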
""" return self(DatasetOutput(data=inputs)).z - + def predict(self, inputs: torch.Tensor) -> ModelOutput: """The input data is encoded and decoded without computing loss diff --git a/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py b/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py index 51decd17..f88ce29b 100644 --- a/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py +++ b/src/pythae/models/beta_tc_vae/beta_tc_vae_model.py @@ -92,11 +92,14 @@ def loss_function(self, recon_x, x, mu, log_var, z, dataset_size): if self.model_config.reconstruction_loss == "mse": - recon_loss = 0.5 * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) + recon_loss = ( + 0.5 + * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + ) elif self.model_config.reconstruction_loss == "bce": diff --git a/src/pythae/models/beta_vae/beta_vae_model.py b/src/pythae/models/beta_vae/beta_vae_model.py index 82bae115..37f4b051 100644 --- a/src/pythae/models/beta_vae/beta_vae_model.py +++ b/src/pythae/models/beta_vae/beta_vae_model.py @@ -84,11 +84,14 @@ def loss_function(self, recon_x, x, mu, log_var, z): if self.model_config.reconstruction_loss == "mse": - recon_loss = 0.5 * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) + recon_loss = ( + 0.5 + * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + ) elif self.model_config.reconstruction_loss == "bce": diff --git a/src/pythae/models/ciwae/ciwae_config.py b/src/pythae/models/ciwae/ciwae_config.py index 99f9eb9d..381055d8 100644 --- a/src/pythae/models/ciwae/ciwae_config.py +++ b/src/pythae/models/ciwae/ciwae_config.py @@ -19,5 +19,6 @@ class CIWAEConfig(VAEConfig): number_samples: int = 10 beta: float = 0.5 - def __post_init_post_parse__(self): + def __post_init__(self): + super().__post_init__() assert 0 <= self.beta <= 1, f"Beta parameter must be in [0-1]. Got {self.beta}." 
diff --git a/src/pythae/models/ciwae/ciwae_model.py b/src/pythae/models/ciwae/ciwae_model.py
index 5df8dad9..7323bb53 100644
--- a/src/pythae/models/ciwae/ciwae_model.py
+++ b/src/pythae/models/ciwae/ciwae_model.py
@@ -95,13 +95,16 @@ def loss_function(self, recon_x, x, mu, log_var, z):
 
         if self.model_config.reconstruction_loss == "mse":
 
-            recon_loss = 0.5 * F.mse_loss(
-                recon_x,
-                x.reshape(recon_x.shape[0], -1)
-                .unsqueeze(1)
-                .repeat(1, self.n_samples, 1),
-                reduction="none",
-            ).sum(dim=-1)
+            recon_loss = (
+                0.5
+                * F.mse_loss(
+                    recon_x,
+                    x.reshape(recon_x.shape[0], -1)
+                    .unsqueeze(1)
+                    .repeat(1, self.n_samples, 1),
+                    reduction="none",
+                ).sum(dim=-1)
+            )
 
         elif self.model_config.reconstruction_loss == "bce":
diff --git a/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py b/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py
index 01e3c358..2ab9976a 100644
--- a/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py
+++ b/src/pythae/models/disentangled_beta_vae/disentangled_beta_vae_model.py
@@ -92,11 +92,14 @@ def loss_function(self, recon_x, x, mu, log_var, z, epoch):
 
         if self.model_config.reconstruction_loss == "mse":
 
-            recon_loss = 0.5 * F.mse_loss(
-                recon_x.reshape(x.shape[0], -1),
-                x.reshape(x.shape[0], -1),
-                reduction="none",
-            ).sum(dim=-1)
+            recon_loss = (
+                0.5
+                * F.mse_loss(
+                    recon_x.reshape(x.shape[0], -1),
+                    x.reshape(x.shape[0], -1),
+                    reduction="none",
+                ).sum(dim=-1)
+            )
 
         elif self.model_config.reconstruction_loss == "bce":
diff --git a/src/pythae/models/factor_vae/factor_vae_config.py b/src/pythae/models/factor_vae/factor_vae_config.py
index df445a70..0ea822fa 100644
--- a/src/pythae/models/factor_vae/factor_vae_config.py
+++ b/src/pythae/models/factor_vae/factor_vae_config.py
@@ -16,4 +16,3 @@ class FactorVAEConfig(VAEConfig):
     """
     gamma: float = 2.0
     uses_default_discriminator: bool = True
-    discriminator_input_dim: int = None
diff --git a/src/pythae/models/factor_vae/factor_vae_model.py b/src/pythae/models/factor_vae/factor_vae_model.py
index 3c0199f3..c2af1700 100644
--- a/src/pythae/models/factor_vae/factor_vae_model.py
+++ b/src/pythae/models/factor_vae/factor_vae_model.py
@@ -126,7 +126,7 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput:
             recon_x=recon_x,
             recon_x_indices=idx_1,
             z=z,
-            z_bis_permuted=z_bis_permuted
+            z_bis_permuted=z_bis_permuted,
         )
 
         return output
@@ -137,11 +137,14 @@ def loss_function(self, recon_x, x, mu, log_var, z, z_bis_permuted):
 
         if self.model_config.reconstruction_loss == "mse":
 
-            recon_loss = 0.5 * F.mse_loss(
-                recon_x.reshape(x.shape[0], -1),
-                x.reshape(x.shape[0], -1),
-                reduction="none",
-            ).sum(dim=-1)
+            recon_loss = (
+                0.5
+                * F.mse_loss(
+                    recon_x.reshape(x.shape[0], -1),
+                    x.reshape(x.shape[0], -1),
+                    reduction="none",
+                ).sum(dim=-1)
+            )
 
         elif self.model_config.reconstruction_loss == "bce":
diff --git a/src/pythae/models/info_vae/info_vae_model.py b/src/pythae/models/info_vae/info_vae_model.py
index 33243e68..207bc3f5 100644
--- a/src/pythae/models/info_vae/info_vae_model.py
+++ b/src/pythae/models/info_vae/info_vae_model.py
@@ -92,11 +92,14 @@ def loss_function(self, recon_x, x, z, z_prior, mu, log_var):
 
         if self.model_config.reconstruction_loss == "mse":
 
-            recon_loss = 0.5 * F.mse_loss(
-                recon_x.reshape(x.shape[0], -1),
-                x.reshape(x.shape[0], -1),
-                reduction="none",
-            ).sum(dim=-1)
+            recon_loss = (
+                0.5
+                * F.mse_loss(
+                    recon_x.reshape(x.shape[0], -1),
+                    x.reshape(x.shape[0], -1),
+                    reduction="none",
+                ).sum(dim=-1)
+            )
 
         elif self.model_config.reconstruction_loss == "bce":
diff --git a/src/pythae/models/iwae/iwae_model.py b/src/pythae/models/iwae/iwae_model.py
index 2be61225..06b72335 100644
--- a/src/pythae/models/iwae/iwae_model.py
+++ b/src/pythae/models/iwae/iwae_model.py
@@ -94,13 +94,16 @@ def loss_function(self, recon_x, x, mu, log_var, z):
 
         if self.model_config.reconstruction_loss == "mse":
 
-            recon_loss = 0.5 * F.mse_loss(
-                recon_x,
-                x.reshape(recon_x.shape[0], -1)
-                .unsqueeze(1)
-                .repeat(1, self.n_samples, 1),
-                reduction="none",
-            ).sum(dim=-1)
+            recon_loss = (
+                0.5
+                * F.mse_loss(
+                    recon_x,
+                    x.reshape(recon_x.shape[0], -1)
+                    .unsqueeze(1)
+                    .repeat(1, self.n_samples, 1),
+                    reduction="none",
+                ).sum(dim=-1)
+            )
 
         elif self.model_config.reconstruction_loss == "bce":
diff --git a/src/pythae/models/miwae/miwae_model.py b/src/pythae/models/miwae/miwae_model.py
index 2c2cc861..768fae0e 100644
--- a/src/pythae/models/miwae/miwae_model.py
+++ b/src/pythae/models/miwae/miwae_model.py
@@ -103,14 +103,17 @@ def loss_function(self, recon_x, x, mu, log_var, z):
 
         if self.model_config.reconstruction_loss == "mse":
 
-            recon_loss = 0.5 * F.mse_loss(
-                recon_x,
-                x.reshape(recon_x.shape[0], -1)
-                .unsqueeze(1)
-                .unsqueeze(1)
-                .repeat(1, self.gradient_n_estimates, self.n_samples, 1),
-                reduction="none",
-            ).sum(dim=-1)
+            recon_loss = (
+                0.5
+                * F.mse_loss(
+                    recon_x,
+                    x.reshape(recon_x.shape[0], -1)
+                    .unsqueeze(1)
+                    .unsqueeze(1)
+                    .repeat(1, self.gradient_n_estimates, self.n_samples, 1),
+                    reduction="none",
+                ).sum(dim=-1)
+            )
 
         elif self.model_config.reconstruction_loss == "bce":
diff --git a/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_config.py b/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_config.py
index 8bac4c23..17b14df8 100644
--- a/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_config.py
+++ b/src/pythae/models/normalizing_flows/pixelcnn/pixelcnn_config.py
@@ -18,7 +18,8 @@ class PixelCNNConfig(BaseNFConfig):
     n_layers: int = 10
     kernel_size: int = 5
 
-    def __post_init_post_parse__(self):
+    def __post_init__(self):
+        super().__post_init__()
         assert (
             self.kernel_size % 2 == 1
         ), f"Wrong kernel size provided. The kernel size must be odd. Got {self.kernel_size}."
diff --git a/src/pythae/models/normalizing_flows/planar_flow/planar_flow_config.py b/src/pythae/models/normalizing_flows/planar_flow/planar_flow_config.py
index 747ddf2d..95ebca12 100644
--- a/src/pythae/models/normalizing_flows/planar_flow/planar_flow_config.py
+++ b/src/pythae/models/normalizing_flows/planar_flow/planar_flow_config.py
@@ -15,7 +15,8 @@ class PlanarFlowConfig(BaseNFConfig):
 
     activation: str = "tanh"
 
-    def __post_init_post_parse__(self):
+    def __post_init__(self):
+        super().__post_init__()
         assert self.activation in ["linear", "tanh", "elu"], (
             f"'{self.activation}' doesn't correspond to an activation handled by the model. "
" "Available activations ['linear', 'tanh', 'elu']" diff --git a/src/pythae/models/piwae/piwae_model.py b/src/pythae/models/piwae/piwae_model.py index 9787131f..82dbc174 100644 --- a/src/pythae/models/piwae/piwae_model.py +++ b/src/pythae/models/piwae/piwae_model.py @@ -111,14 +111,17 @@ def loss_function(self, recon_x, x, mu, log_var, z): if self.model_config.reconstruction_loss == "mse": - recon_loss = 0.5 * F.mse_loss( - recon_x, - x.reshape(recon_x.shape[0], -1) - .unsqueeze(1) - .unsqueeze(1) - .repeat(1, self.gradient_n_estimates, self.n_samples, 1), - reduction="none", - ).sum(dim=-1) + recon_loss = ( + 0.5 + * F.mse_loss( + recon_x, + x.reshape(recon_x.shape[0], -1) + .unsqueeze(1) + .unsqueeze(1) + .repeat(1, self.gradient_n_estimates, self.n_samples, 1), + reduction="none", + ).sum(dim=-1) + ) elif self.model_config.reconstruction_loss == "bce": diff --git a/src/pythae/models/pvae/pvae_model.py b/src/pythae/models/pvae/pvae_model.py index 9282b341..1fbe6498 100644 --- a/src/pythae/models/pvae/pvae_model.py +++ b/src/pythae/models/pvae/pvae_model.py @@ -128,11 +128,14 @@ def loss_function(self, recon_x, x, z, qz_x): if self.model_config.reconstruction_loss == "mse": - recon_loss = 0.5 * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) + recon_loss = ( + 0.5 + * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + ) elif self.model_config.reconstruction_loss == "bce": diff --git a/src/pythae/models/rhvae/rhvae_model.py b/src/pythae/models/rhvae/rhvae_model.py index b37eabfc..f692b1a5 100644 --- a/src/pythae/models/rhvae/rhvae_model.py +++ b/src/pythae/models/rhvae/rhvae_model.py @@ -306,7 +306,7 @@ def predict(self, inputs: torch.Tensor) -> ModelOutput: z = self._leap_step_2(recon_x, inputs, z, rho_, G_inv, G_log_det) recon_x = self.decoder(z)["reconstruction"] - + # compute metric value on new z using final metric G = self.G(z) G_inv = self.G_inv(z) @@ -328,7 +328,7 @@ def predict(self, inputs: torch.Tensor) -> ModelOutput: ) return output - + def _leap_step_1(self, recon_x, x, z, rho, G_inv, G_log_det, steps=3): """ Resolves first equation of generalized leapfrog integrator diff --git a/src/pythae/models/svae/svae_model.py b/src/pythae/models/svae/svae_model.py index 167571b3..c2ac697f 100644 --- a/src/pythae/models/svae/svae_model.py +++ b/src/pythae/models/svae/svae_model.py @@ -101,11 +101,14 @@ def loss_function(self, recon_x, x, loc, concentration, z): if self.model_config.reconstruction_loss == "mse": - recon_loss = 0.5 * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) + recon_loss = ( + 0.5 + * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + ) elif self.model_config.reconstruction_loss == "bce": diff --git a/src/pythae/models/vae/vae_model.py b/src/pythae/models/vae/vae_model.py index 62bffb31..2fea4eea 100644 --- a/src/pythae/models/vae/vae_model.py +++ b/src/pythae/models/vae/vae_model.py @@ -100,11 +100,14 @@ def forward(self, inputs: BaseDataset, **kwargs): def loss_function(self, recon_x, x, mu, log_var, z): if self.model_config.reconstruction_loss == "mse": - recon_loss = 0.5 * F.mse_loss( - recon_x.reshape(x.shape[0], -1), - x.reshape(x.shape[0], -1), - reduction="none", - ).sum(dim=-1) + recon_loss = ( + 0.5 + * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + 
+                ).sum(dim=-1)
+            )
 
         elif self.model_config.reconstruction_loss == "bce":
@@ -161,7 +164,7 @@ def get_nll(self, data, n_samples=1, batch_size=100):
                 log_q_z_given_x = -0.5 * (
                     log_var + (z - mu) ** 2 / torch.exp(log_var)
                 ).sum(dim=-1)
-                log_p_z = -0.5 * (z**2).sum(dim=-1)
+                log_p_z = -0.5 * (z ** 2).sum(dim=-1)
 
                 recon_x = self.decoder(z)["reconstruction"]
diff --git a/src/pythae/models/vae_gan/vae_gan_model.py b/src/pythae/models/vae_gan/vae_gan_model.py
index 65732763..1ccea46a 100644
--- a/src/pythae/models/vae_gan/vae_gan_model.py
+++ b/src/pythae/models/vae_gan/vae_gan_model.py
@@ -192,11 +192,14 @@ def loss_function(self, recon_x, x, z, z_prior, mu, log_var):
             )[f"embedding_layer_{self.reconstruction_layer}"]
 
             # MSE in feature space
-            recon_loss = 0.5 * F.mse_loss(
-                true_discr_layer.reshape(N, -1),
-                recon_discr_layer.reshape(N, -1),
-                reduction="none",
-            ).sum(dim=-1)
+            recon_loss = (
+                0.5
+                * F.mse_loss(
+                    true_discr_layer.reshape(N, -1),
+                    recon_discr_layer.reshape(N, -1),
+                    reduction="none",
+                ).sum(dim=-1)
+            )
 
         encoder_loss = KLD + recon_loss
diff --git a/src/pythae/models/vae_iaf/vae_iaf_model.py b/src/pythae/models/vae_iaf/vae_iaf_model.py
index 1359471d..1979dfa4 100644
--- a/src/pythae/models/vae_iaf/vae_iaf_model.py
+++ b/src/pythae/models/vae_iaf/vae_iaf_model.py
@@ -106,11 +106,14 @@ def loss_function(self, recon_x, x, mu, log_var, z0, zk, log_abs_det_jac):
 
         if self.model_config.reconstruction_loss == "mse":
 
-            recon_loss = 0.5 * F.mse_loss(
-                recon_x.reshape(x.shape[0], -1),
-                x.reshape(x.shape[0], -1),
-                reduction="none",
-            ).sum(dim=-1)
+            recon_loss = (
+                0.5
+                * F.mse_loss(
+                    recon_x.reshape(x.shape[0], -1),
+                    x.reshape(x.shape[0], -1),
+                    reduction="none",
+                ).sum(dim=-1)
+            )
 
         elif self.model_config.reconstruction_loss == "bce":
diff --git a/src/pythae/models/vae_lin_nf/vae_lin_nf_config.py b/src/pythae/models/vae_lin_nf/vae_lin_nf_config.py
index f8129132..959c9561 100644
--- a/src/pythae/models/vae_lin_nf/vae_lin_nf_config.py
+++ b/src/pythae/models/vae_lin_nf/vae_lin_nf_config.py
@@ -20,7 +20,8 @@ class VAE_LinNF_Config(VAEConfig):
 
     flows: List[str] = field(default_factory=lambda: ["Planar", "Planar"])
 
-    def __post_init_post_parse__(self):
+    def __post_init__(self):
+        super().__post_init__()
         for i, f in enumerate(self.flows):
             assert f in ["Planar", "Radial"], (
                 f"Flow name number {i+1}: '{f}' doesn't correspond "
diff --git a/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py b/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py
index 03c426ce..a034fda6 100644
--- a/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py
+++ b/src/pythae/models/vae_lin_nf/vae_lin_nf_model.py
@@ -119,11 +119,14 @@ def loss_function(self, recon_x, x, mu, log_var, z0, zk, log_abs_det_jac):
 
         if self.model_config.reconstruction_loss == "mse":
 
-            recon_loss = 0.5 * F.mse_loss(
-                recon_x.reshape(x.shape[0], -1),
-                x.reshape(x.shape[0], -1),
-                reduction="none",
-            ).sum(dim=-1)
+            recon_loss = (
+                0.5
+                * F.mse_loss(
+                    recon_x.reshape(x.shape[0], -1),
+                    x.reshape(x.shape[0], -1),
+                    reduction="none",
+                ).sum(dim=-1)
+            )
 
         elif self.model_config.reconstruction_loss == "bce":
diff --git a/src/pythae/models/vamp/vamp_model.py b/src/pythae/models/vamp/vamp_model.py
index 64997033..a85f7062 100644
--- a/src/pythae/models/vamp/vamp_model.py
+++ b/src/pythae/models/vamp/vamp_model.py
@@ -104,11 +104,14 @@ def loss_function(self, recon_x, x, mu, log_var, z, epoch):
 
         if self.model_config.reconstruction_loss == "mse":
 
-            recon_loss = 0.5 * F.mse_loss(
-                recon_x.reshape(x.shape[0], -1),
-                x.reshape(x.shape[0], -1),
reduction="none", - ).sum(dim=-1) + recon_loss = ( + 0.5 + * F.mse_loss( + recon_x.reshape(x.shape[0], -1), + x.reshape(x.shape[0], -1), + reduction="none", + ).sum(dim=-1) + ) elif self.model_config.reconstruction_loss == "bce": diff --git a/src/pythae/models/vq_vae/vq_vae_config.py b/src/pythae/models/vq_vae/vq_vae_config.py index 86cc25b6..ff9c41ee 100644 --- a/src/pythae/models/vq_vae/vq_vae_config.py +++ b/src/pythae/models/vq_vae/vq_vae_config.py @@ -23,7 +23,8 @@ class VQVAEConfig(AEConfig): use_ema: bool = False decay: float = 0.99 - def __post_init_post_parse__(self): + def __post_init__(self): + super().__post_init__() if self.use_ema: assert 0 <= self.decay <= 1, ( "The decay in the EMA update must be in [0, 1]. " f"Got {self.decay}." diff --git a/src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler_config.py b/src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler_config.py index 11908976..46f7b6ff 100644 --- a/src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler_config.py +++ b/src/pythae/samplers/pixelcnn_sampler/pixelcnn_sampler_config.py @@ -16,7 +16,8 @@ class PixelCNNSamplerConfig(BaseSamplerConfig): n_layers: int = 10 kernel_size: int = 5 - def __post_init_post_parse__(self): + def __post_init__(self): + super().__post_init__() assert ( self.kernel_size % 2 == 1 ), f"Wrong kernel size provided. The kernel size must be odd. Got {self.kernel_size}." diff --git a/src/pythae/trainers/adversarial_trainer/adversarial_trainer_config.py b/src/pythae/trainers/adversarial_trainer/adversarial_trainer_config.py index 460f3700..37233581 100644 --- a/src/pythae/trainers/adversarial_trainer/adversarial_trainer_config.py +++ b/src/pythae/trainers/adversarial_trainer/adversarial_trainer_config.py @@ -74,8 +74,9 @@ class AdversarialTrainerConfig(BaseTrainerConfig): autoencoder_learning_rate: float = 1e-4 discriminator_learning_rate: float = 1e-4 - def __post_init_post_parse__(self): + def __post_init__(self): """Check compatibilty""" + super().__post_init__() # Autoencoder optimizer and scheduler try: diff --git a/src/pythae/trainers/base_trainer/base_training_config.py b/src/pythae/trainers/base_trainer/base_training_config.py index c36ec170..e2300db8 100644 --- a/src/pythae/trainers/base_trainer/base_training_config.py +++ b/src/pythae/trainers/base_trainer/base_training_config.py @@ -51,7 +51,7 @@ class BaseTrainerConfig(BaseConfig): amp (bool): Whether to use auto mixed precision in training. 
            Default: False
     """
 
-    output_dir: str = None
+    output_dir: Union[str, None] = None
     per_device_train_batch_size: int = 64
     per_device_eval_batch_size: int = 64
     num_epochs: int = 100
@@ -76,6 +76,7 @@ class BaseTrainerConfig(BaseConfig):
     amp: bool = False
 
     def __post_init__(self):
+        """Check compatibility and set up distributed training"""
         super().__post_init__()
         env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
         if self.local_rank == -1 and env_local_rank != -1:
@@ -99,8 +100,6 @@ def __post_init__(self):
             self.master_port = env_master_port
             os.environ["MASTER_PORT"] = self.master_port
 
-    def __post_init_post_parse__(self):
-        """Check compatibilty"""
 
         try:
             import torch.optim as optim
diff --git a/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer_config.py b/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer_config.py
index 73f9a8d3..7eb747c7 100644
--- a/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer_config.py
+++ b/src/pythae/trainers/coupled_optimizer_adversarial_trainer/coupled_optimizer_adversarial_trainer_config.py
@@ -91,8 +91,9 @@ class CoupledOptimizerAdversarialTrainerConfig(BaseTrainerConfig):
     decoder_learning_rate: float = 1e-4
     discriminator_learning_rate: float = 1e-4
 
-    def __post_init_post_parse__(self):
+    def __post_init__(self):
         """Check compatibilty"""
+        super().__post_init__()
 
         # Encoder optimizer and scheduler
         try:
diff --git a/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer_config.py b/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer_config.py
index e458442d..95508a78 100644
--- a/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer_config.py
+++ b/src/pythae/trainers/coupled_optimizer_trainer/coupled_optimizer_trainer_config.py
@@ -75,8 +75,9 @@ class CoupledOptimizerTrainerConfig(BaseTrainerConfig):
     encoder_learning_rate: float = 1e-4
     decoder_learning_rate: float = 1e-4
 
-    def __post_init_post_parse__(self):
+    def __post_init__(self):
         """Check compatibilty"""
+        super().__post_init__()
 
         # encoder optimizer and scheduler
         try:
diff --git a/tests/test_CIWAE.py b/tests/test_CIWAE.py
index d1b0a827..e537c893 100644
--- a/tests/test_CIWAE.py
+++ b/tests/test_CIWAE.py
@@ -4,6 +4,7 @@
 import pytest
 import torch
 
+from pydantic import ValidationError
 from pythae.customexception import BadInheritanceError
 from pythae.models import CIWAE, AutoModel, CIWAEConfig
 from pythae.models.base.base_utils import ModelOutput
@@ -69,7 +70,7 @@ def test_build_model(self, model_configs):
             ]
         )
 
-        with pytest.raises(AssertionError):
+        with pytest.raises(ValidationError):
            CIWAEConfig(beta=1.2)
 
     def test_raises_bad_inheritance(self, model_configs, bad_net):
diff --git a/tests/test_PixelCNN.py b/tests/test_PixelCNN.py
index 4bb3cd49..3bfaaf32 100644
--- a/tests/test_PixelCNN.py
+++ b/tests/test_PixelCNN.py
@@ -5,6 +5,7 @@
 import pytest
 import torch
 
+from pydantic import ValidationError
 from pythae.models import AutoModel
 from pythae.models.base.base_utils import ModelOutput
 from pythae.models.normalizing_flows import PixelCNN, PixelCNNConfig
@@ -38,7 +39,7 @@ def test_build_model(self, model_configs):
             ]
         )
 
-        with pytest.raises(AssertionError):
+        with pytest.raises(ValidationError):
             conf = PixelCNNConfig(kernel_size=2)
 
     def test_raises_no_input_output_dim(self, model_configs_no_input_output_dim):
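Note: the `output_dir: str = None` -> `Union[str, None] = None` change above is also pydantic v2 behavior: v1 implicitly widened an annotation to `Optional` when the default was `None`, while v2 enforces the annotation as written, so passing `output_dir=None` explicitly would otherwise fail validation. The earlier removal of `discriminator_input_dim: int = None` from `FactorVAEConfig` is plausibly motivated by the same rule. A minimal sketch (the `TrainerConfigSketch` class is hypothetical):

```python
from typing import Union

from pydantic.dataclasses import dataclass


@dataclass
class TrainerConfigSketch:
    # pydantic v1 read `output_dir: str = None` as Optional[str]; v2 does not,
    # so None must be allowed explicitly in the annotation.
    output_dir: Union[str, None] = None


TrainerConfigSketch(output_dir=None)  # valid only because None is in the annotation
```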
diff --git a/tests/test_VAEGAN.py b/tests/test_VAEGAN.py
index 8cd2b2b4..cfba43e3 100644
--- a/tests/test_VAEGAN.py
+++ b/tests/test_VAEGAN.py
@@ -4,7 +4,6 @@
 import numpy as np
 import pytest
 import torch
-from torch.optim import Adam
 
 from pythae.customexception import BadInheritanceError
 from pythae.models import VAEGAN, AutoModel, VAEGANConfig
@@ -99,7 +98,7 @@ def test_build_model(self, model_configs):
         conf = VAEGANConfig(
             input_dim=(1, 2, 18),
             latent_dim=5,
-            reconstruction_layer=5 + np.random.rand(),
+            reconstruction_layer=5 + np.random.randint(1, 3),
         )
         with pytest.raises(AssertionError):
             a = VAEGAN(conf)
diff --git a/tests/test_VAE_LinFlow.py b/tests/test_VAE_LinFlow.py
index 4dc3421b..2f264246 100644
--- a/tests/test_VAE_LinFlow.py
+++ b/tests/test_VAE_LinFlow.py
@@ -4,6 +4,7 @@
 import pytest
 import torch
 
+from pydantic import ValidationError
 from pythae.customexception import BadInheritanceError
 from pythae.models import AutoModel, VAE_LinNF, VAE_LinNF_Config
 from pythae.models.base.base_utils import ModelOutput
@@ -64,7 +65,7 @@ def bad_net(self):
 
     def test_raises_wrong_flows(self):
 
-        with pytest.raises(AssertionError):
+        with pytest.raises(ValidationError):
             conf = VAE_LinNF_Config(
                 input_dim=(1, 28), latent_dim=5, flows=["Planar", "WrongFlow"]
             )
diff --git a/tests/test_VQVAE.py b/tests/test_VQVAE.py
index 9ed177ab..799afdd6 100644
--- a/tests/test_VQVAE.py
+++ b/tests/test_VQVAE.py
@@ -4,6 +4,7 @@
 import pytest
 import torch
 
+from pydantic import ValidationError
 from pythae.customexception import BadInheritanceError
 from pythae.models import VQVAE, AutoModel, VQVAEConfig
 from pythae.models.base.base_utils import ModelOutput
@@ -68,7 +69,7 @@ def test_build_model(self, model_configs):
             ]
         )
 
-        with pytest.raises(AssertionError):
+        with pytest.raises(ValidationError):
             VQVAEConfig(decay=10, use_ema=True)
 
     def build_quantizer(self, model_configs):
diff --git a/tests/test_config.py b/tests/test_config.py
index f7268735..16e58262 100644
--- a/tests/test_config.py
+++ b/tests/test_config.py
@@ -138,7 +138,7 @@ def training_configs(self, request):
 
     def test_save_json(self, tmpdir, training_configs):
         tmpdir.mkdir("dummy_folder")
-        dir_path = dir_path = os.path.join(tmpdir, "dummy_folder")
+        dir_path = os.path.join(tmpdir, "dummy_folder")
 
         training_configs.save_json(dir_path, "dummy_json")
diff --git a/tests/test_planar_flow.py b/tests/test_planar_flow.py
index 5f7338a3..5759ca04 100644
--- a/tests/test_planar_flow.py
+++ b/tests/test_planar_flow.py
@@ -5,6 +5,7 @@
 import pytest
 import torch
 
+from pydantic import ValidationError
 from pythae.data.datasets import BaseDataset
 from pythae.models import AutoModel
 from pythae.models.base.base_utils import ModelOutput
@@ -46,7 +47,7 @@ def test_raises_no_input_output_dim(self, model_configs_no_input_output_dim):
         model = PlanarFlow(model_configs_no_input_output_dim)
 
     def test_raises_wrong_activation(self):
-        with pytest.raises(AssertionError):
+        with pytest.raises(ValidationError):
             conf = PlanarFlowConfig(input_dim=(1,), activation="relu")
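Note: the `reconstruction_layer=5 + np.random.rand()` -> `5 + np.random.randint(1, 3)` change in `test_VAEGAN.py` looks pydantic-related as well: assuming `reconstruction_layer` is declared as an `int`, v1 truncated floats on coercion, whereas v2 rejects floats with a fractional part, so the old float value would now raise `ValidationError` before reaching the `AssertionError` the test actually targets. A hedged sketch (the `LayerSketch` class is hypothetical):

```python
import numpy as np
from pydantic import ValidationError
from pydantic.dataclasses import dataclass


@dataclass
class LayerSketch:
    reconstruction_layer: int = 3  # assumed declaration, mirroring VAEGANConfig


LayerSketch(reconstruction_layer=5 + np.random.randint(1, 3))  # int: accepted

try:
    LayerSketch(reconstruction_layer=5 + np.random.rand())  # e.g. 5.42
except ValidationError as err:
    print(err)  # v2 rejects non-integral floats for int fields
```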