Add ResNet models and PixelCNN (#19)
* added tests + raise custom errors on wrong flow names in vae_nf_models

* Add "VAE_NF" to models __init__

* Add utils import in vae_nf_models

* fixed additional utils import

* Simplified imports in vae_nf_utils

* fixed bad import (Encoder_VAE_Conv)

* minor changes

* structure idea

* Add MADE, MAF and correct typos

* add NFWrapper

* fix flip

* add working MAF

* add working IAF

* black and isort cleanup

* doc clean up and addition of NFs

* add VAE_IAF + tests

* add working VAE_IAF

* change gitignore

* work on Linear VAE flow

* add VAE with linear normalizing flows

* Add VAE_LinNF showcase + black + isort

* update docs

* add demo + VQVAE sampling showcase

* update overview notebook

* cleaning

* first work

* add MNIST resnet and clean up tests

* add CIFAR and CELEBA resnets

* fix test

* add sigmoid in resnet

* add ema

* add resnets

* minor fix

* minor typo

* clean up + pixelCNN

* clean up docs

* increase nn test coverage

* minor changes

* clean up imports

* apply black

* check BetaTCVAE

* add nll computation in IWAE

* check WAE, INFOVAE, MSSSIM VAE

* check VAMP and SVAE

* clean up trainers

* Check AAE & VAEGAN

* Start checking HVAE

* check HVAE

* update gitignore

* update docs

* update .gitignore

* fix typo in trainer

* equalize vq vae decoder and encoder with peers

* remove no_grad in predict for HVAE

* add stopping flag in while

* revert trainer zero_grads

* fix coupled adversarial trainer test and typos

* add predict tests

* add AutoModel and tests

* fix some imports

* fix issue in wae

* fix issue in info_vae

* add GenerationPipeline

* apply black

* fix progress bar in eval and speed up eval steps

* minor change

* fix device issue

* fix device issue

* update sampler API

* add detach in sampler except

* add detach in sampler except

* change device allocation aae loss

* change device allocation vaegan loss

* fix IWAE loss issue

* fix tqdm on colab

* fix tqdm notebook

* update .gitignore

* update gitignore

* allow non square latent dim in vqvae

* remove comments

* update notebooks

* update gitignore

* update notebooks

* update AutoModel

* Update notebooks

* Update setup and req

* Update README and doc

* update notebooks and gitignore

* update notebooks

* remove unused import

* fix nll in hvae

* update notebook

* Update notebooks

* update notebooks

* update notebooks

* update notebooks

* update setup

* isort and black

Co-authored-by: louis-j-vincent <[email protected]>
clementchadebec and louis-j-vincent authored Jun 14, 2022
1 parent 340bb5b commit e5fe37c
Showing 239 changed files with 13,552 additions and 2,393 deletions.
17 changes: 17 additions & 0 deletions .gitignore
@@ -31,6 +31,23 @@ examples/scripts/reproducibility/*
wandb
examples/notebooks/dummy_output_dir/
examples/notebooks/models_training/dummy_output_dir/
logs/
downloads_data/
*wandb*
*.slurm
configs
examples/scripts/generation_jz.py
examples/scripts/classification_jz.py
examples/scripts/latent_dim_sensi.py
examples/scripts/reconstruction_jz.py
examples/scripts/jz_array_training.py
examples/scripts/interpolation_jz.py
examples/scripts/kmeans_jz.py
*.csv
results.ipynb
examples/scripts/artifacts/*
examples/scripts/plots/*
examples/notebooks/my_model_with_custom_archi/



55 changes: 45 additions & 10 deletions README.md
@@ -96,8 +96,8 @@ Below is the list of the models currently implemented in the library.
| VAMP prior sampler (VAMPSampler) | VAMP | [link](https://arxiv.org/abs/1705.07120) | [link](https://github.com/jmtomczak/vae_vampprior) |
| Manifold sampler (RHVAESampler) | RHVAE | [link](https://arxiv.org/abs/2105.00026) | [link](https://github.com/clementchadebec/pyraug)|
| Masked Autoregressive Flow Sampler (MAFSampler) | all models | [link](https://arxiv.org/abs/1705.07057v4) | [link](https://github.com/gpapamak/maf) |
| Inverse Autoregressive Flow Sampler (IAFSampler) | all models | [link](https://arxiv.org/abs/1606.04934) | [link](https://github.com/openai/iaf) |

| Inverse Autoregressive Flow Sampler (IAFSampler) | all models | [link](https://arxiv.org/abs/1606.04934) | [link](https://github.com/openai/iaf) |
| PixelCNN (PixelCNNSampler) | VQVAE | [link](https://arxiv.org/abs/1606.05328) | |
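
The new `PixelCNNSampler` works with `VQVAE` models. Below is a minimal sketch of how it could be used, assuming it follows the same build/fit/sample pattern as the other fitted samplers shown later in this README (the configuration and paths are placeholders):

```python
>>> from pythae.models import AutoModel
>>> from pythae.samplers import PixelCNNSampler, PixelCNNSamplerConfig
>>> # Reload a trained VQVAE (the PixelCNNSampler only applies to VQVAE models)
>>> my_trained_vqvae = AutoModel.load_from_folder(
...     'path/to/your/trained/vqvae'
... )
>>> # Build the sampler with a default (assumed) configuration
>>> my_sampler = PixelCNNSampler(
...     sampler_config=PixelCNNSamplerConfig(),
...     model=my_trained_vqvae
... )
>>> # Fit the autoregressive prior on training data, then sample
>>> my_sampler.fit(train_data) # train_data: your training set, as in the examples below
>>> gen_data = my_sampler.sample(num_samples=50)
```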

## Launching a model training

@@ -150,13 +150,47 @@ See [README.md](https://github.com/clementchadebec/benchmark_VAE/tree/main/examp

## Launching data generation

To launch the data generation process from a trained model, you only need to build your sampler. For instance, to generate new data with your sampler, run the following.
### Using the `GenerationPipeline`

The easiest way to launch data generation from a trained model is to use the built-in `GenerationPipeline` provided in Pythae. Say you want to generate 100 samples using a `MAFSampler`; all you have to do is 1) reload the trained model, 2) define the sampler's configuration and 3) create and launch the `GenerationPipeline` as follows:

```python
>>> from pythae.models import AutoModel
>>> from pythae.samplers import MAFSamplerConfig
>>> from pythae.pipelines import GenerationPipeline
>>> from pythae.trainers import BaseTrainerConfig # used below to fit the sampler
>>> # Retrieve the trained model
>>> my_trained_vae = AutoModel.load_from_folder(
... 'path/to/your/trained/model'
... )
>>> my_sampler_config = MAFSamplerConfig(
... n_made_blocks=2,
... n_hidden_in_made=3,
... hidden_size=128
... )
>>> # Build the pipeline
>>> pipe = GenerationPipeline(
... model=my_trained_vae,
... sampler_config=my_sampler_config
... )
>>> # Launch data generation
>>> generated_samples = pipe(
... num_samples=100,
... return_gen=True, # If false returns nothing
... train_data=train_data, # Needed to fit the sampler
... eval_data=eval_data, # Needed to fit the sampler
... training_config=BaseTrainerConfig(num_epochs=200) # TrainingConfig to use to fit the sampler
... )
```

### Using the Samplers

Alternatively, you can launch the data generation process from a trained model directly with the sampler. For instance, to generate new data with your sampler, run the following.

```python
>>> from pythae.models import VAE
>>> from pythae.models import AutoModel
>>> from pythae.samplers import NormalSampler
>>> # Retrieve the trained model
>>> my_trained_vae = VAE.load_from_folder(
>>> my_trained_vae = AutoModel.load_from_folder(
... 'path/to/your/trained/model'
... )
>>> # Define your sampler
@@ -175,10 +209,10 @@ If you set `output_dir` to a specific path, the generated images will be saved a
The samplers can be used with any model as long as it is suitable. For instance, a `GaussianMixtureSampler` instance can be used to generate from any model, but a `VAMPSampler` will only be usable with a `VAMP` model. Check [here](#available-samplers) to see which ones apply to your model. Be careful: some samplers, such as the `GaussianMixtureSampler`, may need to be fitted by calling the `fit` method before use. Below is an example for the `GaussianMixtureSampler`.

```python
>>> from pythae.models import VAE
>>> from pythae.models import AutoModel
>>> from pythae.samplers import GaussianMixtureSampler, GaussianMixtureSamplerConfig
>>> # Retrieve the trained model
>>> my_trained_vae = VAE.load_from_folder(
>>> my_trained_vae = AutoModel.load_from_folder(
... 'path/to/your/trained/model'
... )
>>> # Define your sampler
@@ -200,6 +234,7 @@ The samplers can be used with any model as long as it is suited. For instance, a
... )
```


## Define your own Autoencoder architecture

Pythae lets you define your own neural networks within the VAE models. For instance, if you want to train a Wasserstein AE with a specific encoder and decoder, you can do the following:
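
A minimal sketch of what such a custom encoder and decoder could look like is given below; it assumes the `BaseEncoder`/`BaseDecoder` interface and the `ModelOutput` container exposed by the library, and the layer sizes are illustrative only:

```python
>>> import torch
>>> from pythae.models.nn import BaseEncoder, BaseDecoder
>>> from pythae.models.base.base_utils import ModelOutput

>>> class My_Encoder(BaseEncoder):
...     def __init__(self, config):
...         BaseEncoder.__init__(self)
...         # flatten the input and map it to the latent space
...         self.layers = torch.nn.Sequential(
...             torch.nn.Linear(28 * 28, 400),
...             torch.nn.ReLU(),
...             torch.nn.Linear(400, config.latent_dim)
...         )
...
...     def forward(self, x):
...         # an AE-type encoder (e.g. for a WAE) only needs to return the embedding
...         return ModelOutput(embedding=self.layers(x.reshape(x.shape[0], -1)))

>>> class My_Decoder(BaseDecoder):
...     def __init__(self, config):
...         BaseDecoder.__init__(self)
...         # map the latent code back to a flattened image and reshape it
...         self.layers = torch.nn.Sequential(
...             torch.nn.Linear(config.latent_dim, 400),
...             torch.nn.ReLU(),
...             torch.nn.Linear(400, 28 * 28),
...             torch.nn.Sigmoid()
...         )
...
...     def forward(self, z):
...         return ModelOutput(reconstruction=self.layers(z).reshape(z.shape[0], 1, 28, 28))
```

The resulting `My_Encoder` and `My_Decoder` instances can then be passed to the model through its `encoder` and `decoder` arguments.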
@@ -264,9 +299,9 @@ You can also find predefined neural network architectures for the most common da

```python
>>> from pythae.models.nn.benchmarks.mnist import (
... Encoder_AE_MNIST, # For AE based model (only return embeddings)
... Encoder_VAE_MNIST, # For VAE based model (return embeddings and log_covariances)
... Decoder_AE_MNIST
... Encoder_Conv_AE_MNIST, # For AE based model (only return embeddings)
... Encoder_Conv_VAE_MNIST, # For VAE based model (return embeddings and log_covariances)
... Decoder_Conv_AE_MNIST
... )
```
Replace *mnist* with *cifar* or *celeba* to access the other neural nets.
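
The ResNet counterparts added in this commit can be plugged into a model the same way. Below is a minimal sketch; the MNIST ResNet class names are assumed by analogy with the CIFAR/CELEBA ones documented in this commit, and the configuration values are illustrative:

```python
>>> from pythae.models import VAE, VAEConfig
>>> from pythae.models.nn.benchmarks.mnist import (
...     Encoder_ResNet_VAE_MNIST, # assumed name, mirroring the CIFAR/CELEBA ResNets
...     Decoder_ResNet_AE_MNIST
... )
>>> config = VAEConfig(input_dim=(1, 28, 28), latent_dim=16)
>>> model = VAE(
...     model_config=config,
...     encoder=Encoder_ResNet_VAE_MNIST(config), # pass the predefined nets to the model
...     decoder=Decoder_ResNet_AE_MNIST(config)
... )
```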
11 files renamed without changes.
9 changes: 6 additions & 3 deletions docs/source/conf.py
@@ -15,7 +15,7 @@

# -- Path setup --------------------------------------------------------------

needs_sphinx = "1.6"
needs_sphinx = "1.3"

sys.path.insert(0, os.path.abspath("../../src/"))

@@ -31,7 +31,7 @@
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = extensions = [
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.autosummary",
"sphinx.ext.doctest",
@@ -41,9 +41,12 @@
"sphinx.ext.napoleon",
"sphinx.ext.viewcode",
"sphinx.ext.autosectionlabel",
"sphinxcontrib.bibtex",
"sphinxcontrib.bibtex"
]


suppress_warnings = ['autosectionlabel.*']

bibtex_bibfiles = ["references.bib"]


31 changes: 30 additions & 1 deletion docs/source/index.rst
@@ -18,4 +18,33 @@ we can conduct benchmark analysis and reproducible research!
models/pythae.models
samplers/pythae.samplers
trainers/pythae.trainer
pipelines/pythae.pipelines
pipelines/pythae.pipelines

Setup
~~~~~~~~~~~~~

To install the latest version of this library, run the following using ``pip``

.. code-block:: bash

    $ pip install git+https://github.com/clementchadebec/benchmark_VAE.git

or alternatively you can clone the github repo to access tests, tutorials and scripts.

.. code-block:: bash

    $ git clone https://github.com/clementchadebec/benchmark_VAE.git

and install the library

.. code-block:: bash

    $ cd benchmark_VAE
    $ pip install -e .

If you clone pythae's repository you will have access to the following:

- ``docs``: The folder in which the documentation can be retrieved.
- ``tests``: pythae's unit tests, run with pytest.
- ``examples``: A list of ``ipynb`` tutorials and scripts describing the main functionalities of pythae.
- ``src/pythae``: The main library which can be installed with ``pip``.
4 changes: 0 additions & 4 deletions docs/source/models/autoencoders/models.rst
@@ -1,6 +1,3 @@
.. _pythae_models:


**********************************
Autoencoders
**********************************
@@ -31,7 +28,6 @@ Autoencoders
rae_gp
rae_l2
rhvae
:nosignatures:

.. automodule::
pythae.models
13 changes: 0 additions & 13 deletions docs/source/models/autoencoders/wae_mmd.rst

This file was deleted.

18 changes: 18 additions & 0 deletions docs/source/models/nn/celeba/convnets.rst
@@ -0,0 +1,18 @@
********************************
ConvNets
********************************

.. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_Conv_AE_CELEBA
:members:

.. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_Conv_VAE_CELEBA
:members:

.. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_Conv_SVAE_CELEBA
:members:

.. autoclass:: pythae.models.nn.benchmarks.celeba.Decoder_Conv_AE_CELEBA
:members:

.. autoclass:: pythae.models.nn.benchmarks.celeba.Discriminator_Conv_CELEBA
:members:
35 changes: 35 additions & 0 deletions docs/source/models/nn/celeba/pythae_benchmarks_nn_celeba.rst
@@ -0,0 +1,35 @@
**********************************
CELEBA
**********************************

.. automodule::
pythae.models.nn.benchmarks.celeba

.. toctree::
:maxdepth: 1

convnets
resnets

ConvNets
~~~~~~~~~~~~~~~

.. autosummary::
~pythae.models.nn.benchmarks.celeba.Encoder_Conv_AE_CELEBA
~pythae.models.nn.benchmarks.celeba.Encoder_Conv_VAE_CELEBA
~pythae.models.nn.benchmarks.celeba.Encoder_Conv_SVAE_CELEBA
~pythae.models.nn.benchmarks.celeba.Decoder_Conv_AE_CELEBA
~pythae.models.nn.benchmarks.celeba.Discriminator_Conv_CELEBA
:nosignatures:

ResNets
~~~~~~~~~~~~~~~

.. autosummary::
~pythae.models.nn.benchmarks.celeba.Encoder_ResNet_AE_CELEBA
~pythae.models.nn.benchmarks.celeba.Encoder_ResNet_VAE_CELEBA
~pythae.models.nn.benchmarks.celeba.Encoder_ResNet_SVAE_CELEBA
~pythae.models.nn.benchmarks.celeba.Encoder_ResNet_VQVAE_CELEBA
~pythae.models.nn.benchmarks.celeba.Decoder_ResNet_AE_CELEBA
~pythae.models.nn.benchmarks.celeba.Decoder_ResNet_VQVAE_CELEBA
:nosignatures:
21 changes: 21 additions & 0 deletions docs/source/models/nn/celeba/resnets.rst
@@ -0,0 +1,21 @@
********************************
ResNets
********************************

.. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_ResNet_AE_CELEBA
:members:

.. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_ResNet_VAE_CELEBA
:members:

.. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_ResNet_SVAE_CELEBA
:members:

.. autoclass:: pythae.models.nn.benchmarks.celeba.Encoder_ResNet_VQVAE_CELEBA
:members:

.. autoclass:: pythae.models.nn.benchmarks.celeba.Decoder_ResNet_AE_CELEBA
:members:

.. autoclass:: pythae.models.nn.benchmarks.celeba.Decoder_ResNet_VQVAE_CELEBA
:members:
18 changes: 18 additions & 0 deletions docs/source/models/nn/cifar/convnets.rst
@@ -0,0 +1,18 @@
********************************
ConvNets
********************************

.. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_Conv_AE_CIFAR
:members:

.. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_Conv_VAE_CIFAR
:members:

.. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_Conv_SVAE_CIFAR
:members:

.. autoclass:: pythae.models.nn.benchmarks.cifar.Decoder_Conv_AE_CIFAR
:members:

.. autoclass:: pythae.models.nn.benchmarks.cifar.Discriminator_Conv_CIFAR
:members:
35 changes: 35 additions & 0 deletions docs/source/models/nn/cifar/pythae_benchmarks_nn_cifar.rst
@@ -0,0 +1,35 @@
**********************************
CIFAR
**********************************

.. automodule::
pythae.models.nn.benchmarks.cifar

.. toctree::
:maxdepth: 1

convnets
resnets

ConvNets
~~~~~~~~~~~~~~~

.. autosummary::
~pythae.models.nn.benchmarks.cifar.Encoder_Conv_AE_CIFAR
~pythae.models.nn.benchmarks.cifar.Encoder_Conv_VAE_CIFAR
~pythae.models.nn.benchmarks.cifar.Encoder_Conv_SVAE_CIFAR
~pythae.models.nn.benchmarks.cifar.Decoder_Conv_AE_CIFAR
~pythae.models.nn.benchmarks.cifar.Discriminator_Conv_CIFAR
:nosignatures:

ResNets
~~~~~~~~~~~~~~~

.. autosummary::
~pythae.models.nn.benchmarks.cifar.Encoder_ResNet_AE_CIFAR
~pythae.models.nn.benchmarks.cifar.Encoder_ResNet_VAE_CIFAR
~pythae.models.nn.benchmarks.cifar.Encoder_ResNet_SVAE_CIFAR
~pythae.models.nn.benchmarks.cifar.Encoder_ResNet_VQVAE_CIFAR
~pythae.models.nn.benchmarks.cifar.Decoder_ResNet_AE_CIFAR
~pythae.models.nn.benchmarks.cifar.Decoder_ResNet_VQVAE_CIFAR
:nosignatures:
21 changes: 21 additions & 0 deletions docs/source/models/nn/cifar/resnets.rst
@@ -0,0 +1,21 @@
********************************
ResNets
********************************

.. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_ResNet_AE_CIFAR
:members:

.. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_ResNet_VAE_CIFAR
:members:

.. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_ResNet_SVAE_CIFAR
:members:

.. autoclass:: pythae.models.nn.benchmarks.cifar.Encoder_ResNet_VQVAE_CIFAR
:members:

.. autoclass:: pythae.models.nn.benchmarks.cifar.Decoder_ResNet_AE_CIFAR
:members:

.. autoclass:: pythae.models.nn.benchmarks.cifar.Decoder_ResNet_VQVAE_CIFAR
:members:
18 changes: 18 additions & 0 deletions docs/source/models/nn/mnist/convnets.rst
@@ -0,0 +1,18 @@
********************************
ConvNets
********************************

.. autoclass:: pythae.models.nn.benchmarks.mnist.Encoder_Conv_AE_MNIST
:members:

.. autoclass:: pythae.models.nn.benchmarks.mnist.Encoder_Conv_VAE_MNIST
:members:

.. autoclass:: pythae.models.nn.benchmarks.mnist.Encoder_Conv_SVAE_MNIST
:members:

.. autoclass:: pythae.models.nn.benchmarks.mnist.Decoder_Conv_AE_MNIST
:members:

.. autoclass:: pythae.models.nn.benchmarks.mnist.Discriminator_Conv_MNIST
:members: