-
Notifications
You must be signed in to change notification settings - Fork 170
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add ResNet models and PixelCNN (#19)
* added tests + raise custom errors on wronf flow names in vae_nf_models * Add "VAE_NF" to models __init__ * Add utils import in vae_nf_models * fixed additionnal utils import * Simplified imports in vae_nf_utils * fixed bad import (Encoder_VAE_Conv) * minor changes * structure idea * Add MADE, MAF and correct typos * add NFWrapper * fix flip * add working MAF * add working IAF * black and isort cleanup * doc clean up and addition of NFs * add VAE_IAF + tests * add working VAE_IAF * change gitignor * work on Linear VAE flow * add VAE with linear normalizing flows * Add VAE_LinNF showcase + black + isort * update docs * add demo + VQVAE sampling showcase * update overview notebook * cleaning * first work * add MNIST resnet and clean up tests * add CIFAR and CELEBA resnets * fix test * add sigmoid in resnet * add ema * add resnets * minor fix * minor typo * clean up + pixelCNN * clean up docs * increase nn test coverage * minor changes * clean up imports * apply black * cehck BetaTCVAE * add nll computation in IWAE * check WAE, INFOVAE MSSSI VAE * check VAMP and SVAE * clean up trainers * Check AAE & VAEGAN * Start checking HVAE * check HVAE * update gitignore * udpte docs * update .gitignore * fix typo in trainer * equalize vq vae decoder and encoder with peers * remove no_grad in predict for HVAE * add stopping flag in while * revert trainer zero_grads * fix coupled adversariel trainer test and typos * add predict tests * add AutoModel and tests * fix some imports * fix issue in wae * fix issue in info_vae * add GenerationPipeline * apply black * fix progress bar in eval and speed up eval steps * minor change * fix device issue * fix device issue * update sampler API * add detach in sampler except * add detach in sampler except * change device allocation aae loss * change device allocation vaegan loss * fix IWAE loss issue * fix tqdm on colab * fix tqdm notebook * update .gitignore * update gitignore * allow non square latent dim in vqvae * remove comments * update notebooks * update gitignore * update notebooks * update AutoModel * Uodate notebooks * Update setup and req * Update README and doc * udpate notebooks and gitignore * update notebooks * remove unused import * fix nl hvae * update notebook * Update notebooks * update notebooks * uopdate notebooks * update notebooks * update setup * isort and black Co-authored-by: louis-j-vincent <[email protected]>
- Loading branch information
1 parent
340bb5b
commit e5fe37c
Showing
239 changed files
with
13,552 additions
and
2,393 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
File renamed without changes.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
******************************** | ||
ConvNets | ||
******************************** | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_Conv_AE_CELEBA | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_Conv_VAE_CELEBA | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_Conv_SVAE_CELEBA | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.celeba.Decoder_Conv_AE_CELEBA | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.celeba.Discriminator_Conv_CELEBA | ||
:members: |
35 changes: 35 additions & 0 deletions
35
docs/source/models/nn/celeba/pythae_benchmarks_nn_celeba.rst
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
********************************** | ||
CELEBA | ||
********************************** | ||
|
||
.. automodule:: | ||
pythae.models.nn.benchmarks.celeba | ||
|
||
.. toctree:: | ||
:maxdepth: 1 | ||
|
||
convnets | ||
resnets | ||
|
||
ConvNets | ||
~~~~~~~~~~~~~~~ | ||
|
||
.. autosummary:: | ||
~pythae.models.nn.benchmarks.celeba.Encoder_Conv_AE_CELEBA | ||
~pythae.models.nn.benchmarks.celeba.Encoder_Conv_VAE_CELEBA | ||
~pythae.models.nn.benchmarks.celeba.Encoder_Conv_SVAE_CELEBA | ||
~pythae.models.nn.benchmarks.celeba.Decoder_Conv_AE_CELEBA | ||
~pythae.models.nn.benchmarks.celeba.Discriminator_Conv_CELEBA | ||
:nosignatures: | ||
|
||
ResNets | ||
~~~~~~~~~~~~~~~ | ||
|
||
.. autosummary:: | ||
~pythae.models.nn.benchmarks.celeba.Encoder_ResNet_AE_CELEBA | ||
~pythae.models.nn.benchmarks.celeba.Encoder_ResNet_VAE_CELEBA | ||
~pythae.models.nn.benchmarks.celeba.Encoder_ResNet_SVAE_CELEBA | ||
~pythae.models.nn.benchmarks.celeba.Encoder_ResNet_VQVAE_CELEBA | ||
~pythae.models.nn.benchmarks.celeba.Decoder_ResNet_AE_CELEBA | ||
~pythae.models.nn.benchmarks.celeba.Decoder_ResNet_VQVAE_CELEBA | ||
:nosignatures: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
******************************** | ||
ResNets | ||
******************************** | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_ResNet_AE_CELEBA | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_ResNet_VAE_CELEBA | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_ResNet_SVAE_CELEBA | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_ResNet_VQVAE_CELEBA | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.celeba.Decoder_ResNet_AE_CELEBA | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.celeba.Decoder_ResNet_VQVAE_CELEBA | ||
:members: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
******************************** | ||
ConvNets | ||
******************************** | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_Conv_AE_CIFAR | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_Conv_VAE_CIFAR | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_Conv_SVAE_CIFAR | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.cifar.Decoder_Conv_AE_CIFAR | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.cifar.Discriminator_Conv_CIFAR | ||
:members: |
35 changes: 35 additions & 0 deletions
35
docs/source/models/nn/cifar/pythae_benchmarks_nn_cifar.rst
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
********************************** | ||
CIFAR | ||
********************************** | ||
|
||
.. automodule:: | ||
pythae.models.nn.benchmarks.cifar | ||
|
||
.. toctree:: | ||
:maxdepth: 1 | ||
|
||
convnets | ||
resnets | ||
|
||
ConvNets | ||
~~~~~~~~~~~~~~~ | ||
|
||
.. autosummary:: | ||
~pythae.models.nn.benchmarks.cifar.Encoder_Conv_AE_CIFAR | ||
~pythae.models.nn.benchmarks.cifar.Encoder_Conv_VAE_CIFAR | ||
~pythae.models.nn.benchmarks.cifar.Encoder_Conv_SVAE_CIFAR | ||
~pythae.models.nn.benchmarks.cifar.Decoder_Conv_AE_CIFAR | ||
~pythae.models.nn.benchmarks.cifar.Discriminator_Conv_CIFAR | ||
:nosignatures: | ||
|
||
ResNets | ||
~~~~~~~~~~~~~~~ | ||
|
||
.. autosummary:: | ||
~pythae.models.nn.benchmarks.cifar.Encoder_ResNet_AE_CIFAR | ||
~pythae.models.nn.benchmarks.cifar.Encoder_ResNet_VAE_CIFAR | ||
~pythae.models.nn.benchmarks.cifar.Encoder_ResNet_SVAE_CIFAR | ||
~pythae.models.nn.benchmarks.cifar.Encoder_ResNet_VQVAE_CIFAR | ||
~pythae.models.nn.benchmarks.cifar.Decoder_ResNet_AE_CIFAR | ||
~pythae.models.nn.benchmarks.cifar.Decoder_ResNet_VQVAE_CIFAR | ||
:nosignatures: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@ | ||
******************************** | ||
ResNets | ||
******************************** | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_ResNet_AE_CIFAR | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_ResNet_VAE_CIFAR | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_ResNet_SVAE_CIFAR | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_ResNet_VQVAE_CIFAR | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.cifar.Decoder_ResNet_AE_CIFAR | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.cifar.Decoder_ResNet_VQVAE_CIFAR | ||
:members: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
******************************** | ||
ConvNets | ||
******************************** | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.mnist.Encoder_Conv_AE_MNIST | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.mnist.Encoder_Conv_VAE_MNIST | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.mnist.Encoder_Conv_SVAE_MNIST | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.mnist.Decoder_Conv_AE_MNIST | ||
:members: | ||
|
||
.. autoclass:: pythae.models.nn.benchmarks.mnist.Discriminator_Conv_MNIST | ||
:members: |
Oops, something went wrong.