From 0f5c0cc5aa361e5bea7294c02502028904a0305f Mon Sep 17 00:00:00 2001 From: clementchadebec <47564971+clementchadebec@users.noreply.github.com> Date: Fri, 22 Jul 2022 10:46:15 +0200 Subject: [PATCH] Interpolations (#38) * add interpolate and reconstruct methods * update doc * update tests * update demo * black & isort * prepare release * update README --- README.md | 2 +- docs/source/index.rst | 6 ++ .../source/models/autoencoders/auto_model.rst | 10 ++ docs/source/models/autoencoders/models.rst | 2 + docs/source/models/pythae.models.rst | 1 + .../adversarial_ae_training.ipynb | 80 ++++++++++++++++ .../models_training/ae_training.ipynb | 80 ++++++++++++++++ .../beta_tc_vae_training.ipynb | 80 ++++++++++++++++ .../models_training/beta_vae_training.ipynb | 80 ++++++++++++++++ .../disentangled_beta_vae_training.ipynb | 80 ++++++++++++++++ .../models_training/factor_vae_training.ipynb | 82 +++++++++++++++- .../models_training/hvae_training.ipynb | 80 ++++++++++++++++ .../models_training/info_vae_training.ipynb | 80 ++++++++++++++++ .../models_training/iwae_training.ipynb | 82 +++++++++++++++- .../ms_ssim_vae_training.ipynb | 80 ++++++++++++++++ .../models_training/rae_gp_training.ipynb | 80 ++++++++++++++++ .../models_training/rae_l2_training.ipynb | 80 ++++++++++++++++ .../models_training/rhvae_training.ipynb | 80 ++++++++++++++++ .../models_training/svae_training.ipynb | 80 ++++++++++++++++ .../models_training/vae_iaf_training.ipynb | 80 ++++++++++++++++ .../models_training/vae_lin_nf_training.ipynb | 80 ++++++++++++++++ .../models_training/vae_training.ipynb | 80 ++++++++++++++++ .../models_training/vaegan_training.ipynb | 82 +++++++++++++++- .../models_training/vamp_training.ipynb | 80 ++++++++++++++++ .../models_training/vqvae_training.ipynb | 95 +++++++++++++------ .../models_training/wae_training.ipynb | 80 ++++++++++++++++ setup.py | 2 +- src/pythae/models/auto_model/__init__.py | 2 +- src/pythae/models/auto_model/auto_model.py | 4 +- src/pythae/models/base/base_model.py | 65 ++++++++++++- tests/test_AE.py | 54 +++++++++++ tests/test_Adversarial_AE.py | 55 +++++++++++ tests/test_BetaTCVAE.py | 55 +++++++++++ tests/test_BetaVAE.py | 55 +++++++++++ tests/test_DisentangledBetaVAE.py | 55 +++++++++++ tests/test_FactorVAE.py | 55 +++++++++++ tests/test_HVAE.py | 55 +++++++++++ tests/test_IWAE.py | 55 +++++++++++ tests/test_MSSSIMVAE.py | 50 ++++++++++ tests/test_RHVAE.py | 55 +++++++++++ tests/test_SVAE.py | 55 +++++++++++ tests/test_VAE.py | 55 +++++++++++ tests/test_VAEGAN.py | 55 +++++++++++ tests/test_VAE_IAF.py | 55 +++++++++++ tests/test_VAE_LinFlow.py | 56 +++++++++++ tests/test_VAMP.py | 56 +++++++++++ tests/test_VQVAE.py | 55 +++++++++++ tests/test_WAE_MMD.py | 55 +++++++++++ tests/test_info_vae_mmd.py | 54 +++++++++++ tests/test_rae_gp.py | 56 +++++++++++ tests/test_rae_l2.py | 55 +++++++++++ 51 files changed, 2909 insertions(+), 37 deletions(-) create mode 100644 docs/source/models/autoencoders/auto_model.rst diff --git a/README.md b/README.md index 551494cf..6812c6a1 100644 --- a/README.md +++ b/README.md @@ -307,7 +307,7 @@ And now build the model You can also find predefined neural network architectures for the most common data sets (*i.e.* MNIST, CIFAR, CELEBA ...) that can be loaded as follows ```python ->>> for pythae.models.nn.benchmark.mnist import ( +>>> from pythae.models.nn.benchmark.mnist import ( ... Encoder_Conv_AE_MNIST, # For AE based model (only return embeddings) ... Encoder_Conv_VAE_MNIST, # For VAE based model (return embeddings and log_covariances) ... Decoder_Conv_AE_MNIST diff --git a/docs/source/index.rst b/docs/source/index.rst index 2242c3d2..fd430758 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -23,6 +23,12 @@ we can conduct benchmark analysis and reproducible research! Setup ~~~~~~~~~~~~~ +To install the latest stable release of this library run the following using ``pip`` + +.. code-block:: bash + + $ pip install pythae + To install the latest version of this library run the following using ``pip`` .. code-block:: bash diff --git a/docs/source/models/autoencoders/auto_model.rst b/docs/source/models/autoencoders/auto_model.rst new file mode 100644 index 00000000..1d94343a --- /dev/null +++ b/docs/source/models/autoencoders/auto_model.rst @@ -0,0 +1,10 @@ +********************************** +AutoModel +********************************** + + +.. automodule:: + pythae.models.auto_model + +.. autoclass:: pythae.models.AutoModel + :members: diff --git a/docs/source/models/autoencoders/models.rst b/docs/source/models/autoencoders/models.rst index 0e322cb1..d2480d38 100644 --- a/docs/source/models/autoencoders/models.rst +++ b/docs/source/models/autoencoders/models.rst @@ -7,6 +7,7 @@ Autoencoders :maxdepth: 1 baseAE + auto_model ae vae betavae @@ -37,6 +38,7 @@ Available Models .. autosummary:: ~pythae.models.BaseAE + ~pythae.models.AutoModel ~pythae.models.AE ~pythae.models.VAE ~pythae.models.BetaVAE diff --git a/docs/source/models/pythae.models.rst b/docs/source/models/pythae.models.rst index 65a53412..2da47100 100644 --- a/docs/source/models/pythae.models.rst +++ b/docs/source/models/pythae.models.rst @@ -21,6 +21,7 @@ Available Autoencoders .. autosummary:: ~pythae.models.BaseAE + ~pythae.models.AutoModel ~pythae.models.AE ~pythae.models.VAE ~pythae.models.BetaVAE diff --git a/examples/notebooks/models_training/adversarial_ae_training.ipynb b/examples/notebooks/models_training/adversarial_ae_training.ipynb index c78b6784..a28d80d4 100644 --- a/examples/notebooks/models_training/adversarial_ae_training.ipynb +++ b/examples/notebooks/models_training/adversarial_ae_training.ipynb @@ -239,6 +239,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { diff --git a/examples/notebooks/models_training/ae_training.ipynb b/examples/notebooks/models_training/ae_training.ipynb index a7211cf7..68cc252c 100644 --- a/examples/notebooks/models_training/ae_training.ipynb +++ b/examples/notebooks/models_training/ae_training.ipynb @@ -238,6 +238,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { diff --git a/examples/notebooks/models_training/beta_tc_vae_training.ipynb b/examples/notebooks/models_training/beta_tc_vae_training.ipynb index bf56d084..cf82e9a2 100644 --- a/examples/notebooks/models_training/beta_tc_vae_training.ipynb +++ b/examples/notebooks/models_training/beta_tc_vae_training.ipynb @@ -242,6 +242,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { diff --git a/examples/notebooks/models_training/beta_vae_training.ipynb b/examples/notebooks/models_training/beta_vae_training.ipynb index e0a26b89..f1fff78c 100644 --- a/examples/notebooks/models_training/beta_vae_training.ipynb +++ b/examples/notebooks/models_training/beta_vae_training.ipynb @@ -240,6 +240,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { diff --git a/examples/notebooks/models_training/disentangled_beta_vae_training.ipynb b/examples/notebooks/models_training/disentangled_beta_vae_training.ipynb index d38423e8..58bb2eaf 100644 --- a/examples/notebooks/models_training/disentangled_beta_vae_training.ipynb +++ b/examples/notebooks/models_training/disentangled_beta_vae_training.ipynb @@ -242,6 +242,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { diff --git a/examples/notebooks/models_training/factor_vae_training.ipynb b/examples/notebooks/models_training/factor_vae_training.ipynb index 7080f2d8..134c5a76 100644 --- a/examples/notebooks/models_training/factor_vae_training.ipynb +++ b/examples/notebooks/models_training/factor_vae_training.ipynb @@ -240,6 +240,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { @@ -260,7 +340,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.8.13" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/hvae_training.ipynb b/examples/notebooks/models_training/hvae_training.ipynb index 06397d2a..6323b1bd 100644 --- a/examples/notebooks/models_training/hvae_training.ipynb +++ b/examples/notebooks/models_training/hvae_training.ipynb @@ -241,6 +241,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { diff --git a/examples/notebooks/models_training/info_vae_training.ipynb b/examples/notebooks/models_training/info_vae_training.ipynb index e823e729..72e1745c 100644 --- a/examples/notebooks/models_training/info_vae_training.ipynb +++ b/examples/notebooks/models_training/info_vae_training.ipynb @@ -249,6 +249,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { diff --git a/examples/notebooks/models_training/iwae_training.ipynb b/examples/notebooks/models_training/iwae_training.ipynb index 43893dd0..cc2f4d81 100644 --- a/examples/notebooks/models_training/iwae_training.ipynb +++ b/examples/notebooks/models_training/iwae_training.ipynb @@ -239,6 +239,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { @@ -259,7 +339,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.8.13" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/ms_ssim_vae_training.ipynb b/examples/notebooks/models_training/ms_ssim_vae_training.ipynb index 4753245e..b6238259 100644 --- a/examples/notebooks/models_training/ms_ssim_vae_training.ipynb +++ b/examples/notebooks/models_training/ms_ssim_vae_training.ipynb @@ -240,6 +240,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { diff --git a/examples/notebooks/models_training/rae_gp_training.ipynb b/examples/notebooks/models_training/rae_gp_training.ipynb index aa32febb..4e8433a0 100644 --- a/examples/notebooks/models_training/rae_gp_training.ipynb +++ b/examples/notebooks/models_training/rae_gp_training.ipynb @@ -239,6 +239,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { diff --git a/examples/notebooks/models_training/rae_l2_training.ipynb b/examples/notebooks/models_training/rae_l2_training.ipynb index 9b06efb6..e90a62ca 100644 --- a/examples/notebooks/models_training/rae_l2_training.ipynb +++ b/examples/notebooks/models_training/rae_l2_training.ipynb @@ -241,6 +241,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { diff --git a/examples/notebooks/models_training/rhvae_training.ipynb b/examples/notebooks/models_training/rhvae_training.ipynb index 0e45f546..085c92c2 100644 --- a/examples/notebooks/models_training/rhvae_training.ipynb +++ b/examples/notebooks/models_training/rhvae_training.ipynb @@ -243,6 +243,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { diff --git a/examples/notebooks/models_training/svae_training.ipynb b/examples/notebooks/models_training/svae_training.ipynb index 684b4e41..67a3770e 100644 --- a/examples/notebooks/models_training/svae_training.ipynb +++ b/examples/notebooks/models_training/svae_training.ipynb @@ -180,6 +180,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { diff --git a/examples/notebooks/models_training/vae_iaf_training.ipynb b/examples/notebooks/models_training/vae_iaf_training.ipynb index 4842a1bc..e017fcb5 100644 --- a/examples/notebooks/models_training/vae_iaf_training.ipynb +++ b/examples/notebooks/models_training/vae_iaf_training.ipynb @@ -240,6 +240,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { diff --git a/examples/notebooks/models_training/vae_lin_nf_training.ipynb b/examples/notebooks/models_training/vae_lin_nf_training.ipynb index d2533ebb..9632579c 100644 --- a/examples/notebooks/models_training/vae_lin_nf_training.ipynb +++ b/examples/notebooks/models_training/vae_lin_nf_training.ipynb @@ -239,6 +239,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { diff --git a/examples/notebooks/models_training/vae_training.ipynb b/examples/notebooks/models_training/vae_training.ipynb index 6df02af7..91e16caf 100644 --- a/examples/notebooks/models_training/vae_training.ipynb +++ b/examples/notebooks/models_training/vae_training.ipynb @@ -238,6 +238,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { diff --git a/examples/notebooks/models_training/vaegan_training.ipynb b/examples/notebooks/models_training/vaegan_training.ipynb index 6858176a..1daa665d 100644 --- a/examples/notebooks/models_training/vaegan_training.ipynb +++ b/examples/notebooks/models_training/vaegan_training.ipynb @@ -244,6 +244,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { @@ -264,7 +344,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.8.12" + "version": "3.8.13" }, "orig_nbformat": 4 }, diff --git a/examples/notebooks/models_training/vamp_training.ipynb b/examples/notebooks/models_training/vamp_training.ipynb index e83db94f..ebfbbb11 100644 --- a/examples/notebooks/models_training/vamp_training.ipynb +++ b/examples/notebooks/models_training/vamp_training.ipynb @@ -239,6 +239,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { diff --git a/examples/notebooks/models_training/vqvae_training.ipynb b/examples/notebooks/models_training/vqvae_training.ipynb index e74f4c2c..910bd449 100644 --- a/examples/notebooks/models_training/vqvae_training.ipynb +++ b/examples/notebooks/models_training/vqvae_training.ipynb @@ -127,7 +127,11 @@ "metadata": {}, "outputs": [], "source": [ - "recon = trained_model({'data': eval_dataset[:50]}).recon_x.detach().cpu()" + "import torch\n", + "from pythae.samplers import PixelCNNSampler, PixelCNNSamplerConfig\n", + "from pythae.trainers import BaseTrainerConfig\n", + "sampler_config = PixelCNNSamplerConfig(n_layers=3, kernel_size=5) \n", + "pixelcnn_sampler = PixelCNNSampler(model=trained_model, sampler_config=sampler_config)" ] }, { @@ -136,15 +140,19 @@ "metadata": {}, "outputs": [], "source": [ - "import matplotlib.pyplot as plt\n", - "\n", - "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", - "\n", - "for i in range(5):\n", - " for j in range(5):\n", - " axes[i][j].imshow(recon[i*5 +j].cpu().squeeze(0), cmap='gray')\n", - " axes[i][j].axis('off')\n", - "plt.tight_layout(pad=0.)" + "pixelcnn_sampler.fit(train_data=torch.tensor(train_dataset), eval_data=torch.tensor(eval_dataset), training_config=BaseTrainerConfig(num_epochs=30, learning_rate=1e-4))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "gen_data = pixelcnn_sampler.sample(\n", + " num_samples=100,\n", + " #output_dir='generated/mnist/vae_2_stage_mnist'\n", + ")" ] }, { @@ -157,24 +165,26 @@ "\n", "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", "\n", + "\n", "for i in range(5):\n", " for j in range(5):\n", - " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].imshow(gen_data[i*5 +j].cpu().squeeze(0), cmap='gray')\n", " axes[i][j].axis('off')\n", "plt.tight_layout(pad=0.)" ] }, { - "cell_type": "code", - "execution_count": null, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "import torch\n", - "from pythae.samplers import PixelCNNSampler, PixelCNNSamplerConfig\n", - "from pythae.trainers import BaseTrainerConfig\n", - "sampler_config = PixelCNNSamplerConfig(n_layers=3, kernel_size=5) \n", - "pixelcnn_sampler = PixelCNNSampler(model=trained_model, sampler_config=sampler_config)" + "## ... the other samplers work the same" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" ] }, { @@ -183,7 +193,7 @@ "metadata": {}, "outputs": [], "source": [ - "pixelcnn_sampler.fit(train_data=torch.tensor(train_dataset), eval_data=torch.tensor(eval_dataset), training_config=BaseTrainerConfig(num_epochs=30, learning_rate=1e-4))" + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" ] }, { @@ -192,10 +202,14 @@ "metadata": {}, "outputs": [], "source": [ - "gen_data = pixelcnn_sampler.sample(\n", - " num_samples=100,\n", - " #output_dir='generated/mnist/vae_2_stage_mnist'\n", - ")" + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" ] }, { @@ -204,14 +218,12 @@ "metadata": {}, "outputs": [], "source": [ - "import matplotlib.pyplot as plt\n", - "\n", + "# show the true data\n", "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", "\n", - "\n", "for i in range(5):\n", " for j in range(5):\n", - " axes[i][j].imshow(gen_data[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", " axes[i][j].axis('off')\n", "plt.tight_layout(pad=0.)" ] @@ -220,7 +232,32 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "## ... the other samplers work the same" + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" ] } ], diff --git a/examples/notebooks/models_training/wae_training.ipynb b/examples/notebooks/models_training/wae_training.ipynb index 1a3c111d..830ea999 100644 --- a/examples/notebooks/models_training/wae_training.ipynb +++ b/examples/notebooks/models_training/wae_training.ipynb @@ -241,6 +241,86 @@ "source": [ "## ... the other samplers work the same" ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing reconstructions" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "reconstructions = model.reconstruct(eval_dataset[:25]).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show reconstructions\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(reconstructions[i*5 + j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show the true data\n", + "fig, axes = plt.subplots(nrows=5, ncols=5, figsize=(10, 10))\n", + "\n", + "for i in range(5):\n", + " for j in range(5):\n", + " axes[i][j].imshow(eval_dataset[i*5 +j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Visualizing interpolations" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "interpolations = model.interpolate(eval_dataset[:5], eval_dataset[5:10], granularity=10).detach().cpu()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# show interpolations\n", + "fig, axes = plt.subplots(nrows=5, ncols=10, figsize=(10, 5))\n", + "\n", + "for i in range(5):\n", + " for j in range(10):\n", + " axes[i][j].imshow(interpolations[i, j].cpu().squeeze(0), cmap='gray')\n", + " axes[i][j].axis('off')\n", + "plt.tight_layout(pad=0.)" + ] } ], "metadata": { diff --git a/setup.py b/setup.py index 3cfee814..734962be 100644 --- a/setup.py +++ b/setup.py @@ -5,7 +5,7 @@ setup( name="pythae", - version="0.0.5", + version="0.0.6", author="Clement Chadebec (HekA team INRIA)", author_email="clement.chadebec@inria.fr", description="Unifying Generative Autoencoders in Python", diff --git a/src/pythae/models/auto_model/__init__.py b/src/pythae/models/auto_model/__init__.py index 1db22166..96f7c6a4 100644 --- a/src/pythae/models/auto_model/__init__.py +++ b/src/pythae/models/auto_model/__init__.py @@ -1,4 +1,4 @@ -"""Utils class allowing to reload any :class:`pythae.models` automatically witht the following +"""Utils class allowing to reload any :class:`pythae.models` automatically with the following lines of code. .. code-block:: diff --git a/src/pythae/models/auto_model/auto_model.py b/src/pythae/models/auto_model/auto_model.py index 93024af1..8ac6788e 100644 --- a/src/pythae/models/auto_model/auto_model.py +++ b/src/pythae/models/auto_model/auto_model.py @@ -13,6 +13,8 @@ class AutoModel(nn.Module): + "Utils class allowing to reload any :class:`pythae.models` automatically" + def __init__(self) -> None: super().__init__() @@ -185,7 +187,7 @@ def load_from_hf_hub( cls, hf_hub_path: str, allow_pickle: bool = False ): # pragma: no cover """Class method to be used to load a automaticaly a pretrained model from the Hugging Face - hub + hub Args: hf_hub_path (str): The path where the model should have been be saved on the diff --git a/src/pythae/models/base/base_model.py b/src/pythae/models/base/base_model.py index 99c5a3f1..8ba8f013 100644 --- a/src/pythae/models/base/base_model.py +++ b/src/pythae/models/base/base_model.py @@ -19,7 +19,12 @@ from ..nn import BaseDecoder, BaseEncoder from ..nn.default_architectures import Decoder_AE_MLP from .base_config import BaseAEConfig, EnvironmentConfig -from .base_utils import CPU_Unpickler, ModelOutput, hf_hub_is_available, model_card_template +from .base_utils import ( + CPU_Unpickler, + ModelOutput, + hf_hub_is_available, + model_card_template, +) logger = logging.getLogger(__name__) console = logging.StreamHandler() @@ -100,6 +105,64 @@ def forward(self, inputs: BaseDataset, **kwargs) -> ModelOutput: ``loss = model_output.loss``""" raise NotImplementedError() + def reconstruct(self, inputs: torch.Tensor): + """This function returns the reconstructions of given input data. + + Args: + inputs (torch.Tensor): The inputs data to be reconstructed of shape [B x input_dim] + ending_inputs (torch.Tensor): The starting inputs in the interpolation of shape + + Returns: + torch.Tensor: A tensor of shape [B x input_dim] containing the reconstructed samples. + """ + return self({"data": inputs, "data_bis": inputs}).recon_x + + def interpolate( + self, + starting_inputs: torch.Tensor, + ending_inputs: torch.Tensor, + granularity: int = 10, + ): + """This function performs a linear interpolation in the latent space of the autoencoder + from starting inputs to ending inputs. It returns the interpolation trajectories. + + Args: + starting_inputs (torch.Tensor): The starting inputs in the interpolation of shape + [B x input_dim] + ending_inputs (torch.Tensor): The starting inputs in the interpolation of shape + [B x input_dim] + granularity (int): The granularity of the interpolation. + + Returns: + torch.Tensor: A tensor of shape [B x granularity x input_dim] containing the + interpolation trajectories. + """ + assert starting_inputs.shape[0] == ending_inputs.shape[0], ( + "The number of starting_inputs should equal the number of ending_inputs. Got " + f"{starting_inputs.shape[0]} sampler for starting_inputs and {ending_inputs.shape[0]} " + "for endinging_inputs." + ) + + starting_z = self({"data": starting_inputs, "data_bis": starting_inputs}).z + ending_z = self({"data": ending_inputs, "data_bis": ending_inputs}).z + t = torch.linspace(0, 1, granularity).to(starting_inputs.device) + intep_line = ( + torch.kron( + starting_z.reshape(starting_z.shape[0], -1), (1 - t).unsqueeze(-1) + ) + + torch.kron(ending_z.reshape(ending_z.shape[0], -1), t.unsqueeze(-1)) + ).reshape((starting_z.shape[0] * t.shape[0],) + (starting_z.shape[1:])) + + decoded_line = self.decoder(intep_line).reconstruction.reshape( + ( + starting_inputs.shape[0], + t.shape[0], + ) + + (starting_inputs.shape[1:]) + ) + + return decoded_line + def update(self): """Method that allows model update during the training (at the end of a training epoch) diff --git a/tests/test_AE.py b/tests/test_AE.py index 5d79c5c3..4f5917f6 100644 --- a/tests/test_AE.py +++ b/tests/test_AE.py @@ -289,6 +289,60 @@ def test_model_train_output(self, ae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture() + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return AE(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return AE(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape @pytest.mark.slow class Test_AE_Training: diff --git a/tests/test_Adversarial_AE.py b/tests/test_Adversarial_AE.py index 0acbc8fa..fbfabc00 100644 --- a/tests/test_Adversarial_AE.py +++ b/tests/test_Adversarial_AE.py @@ -419,6 +419,61 @@ def test_model_train_output(self, adversarial_ae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def adversarial_ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return Adversarial_AE(model_configs) + + + def test_interpolate(self, adversarial_ae, demo_data, granularity): + with pytest.raises(AssertionError): + adversarial_ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = adversarial_ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def adversarial_ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return Adversarial_AE(model_configs) + + + def test_reconstruct(self, adversarial_ae, demo_data): + + recon = adversarial_ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + @pytest.mark.slow class Test_Adversarial_AE_Training: diff --git a/tests/test_BetaTCVAE.py b/tests/test_BetaTCVAE.py index 2e3806f9..177b4903 100644 --- a/tests/test_BetaTCVAE.py +++ b/tests/test_BetaTCVAE.py @@ -300,6 +300,61 @@ def test_model_train_output(self, betavae, demo_data): assert out.recon_x.shape == demo_data["data"].shape +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return BetaTCVAE(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return BetaTCVAE(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + class Test_NLL_Compute: @pytest.fixture def demo_data(self): diff --git a/tests/test_BetaVAE.py b/tests/test_BetaVAE.py index dca0be15..ccafa4c1 100644 --- a/tests/test_BetaVAE.py +++ b/tests/test_BetaVAE.py @@ -291,6 +291,61 @@ def test_model_train_output(self, betavae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return BetaVAE(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return BetaVAE(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + class Test_NLL_Compute: @pytest.fixture diff --git a/tests/test_DisentangledBetaVAE.py b/tests/test_DisentangledBetaVAE.py index 606a9ac3..59a958ee 100644 --- a/tests/test_DisentangledBetaVAE.py +++ b/tests/test_DisentangledBetaVAE.py @@ -310,6 +310,61 @@ def test_model_train_output(self, betavae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return DisentangledBetaVAE(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return DisentangledBetaVAE(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + class Test_NLL_Compute: @pytest.fixture diff --git a/tests/test_FactorVAE.py b/tests/test_FactorVAE.py index 94e02980..a12b6fc6 100644 --- a/tests/test_FactorVAE.py +++ b/tests/test_FactorVAE.py @@ -334,6 +334,61 @@ def test_model_train_output(self, factor_ae, demo_data): assert not torch.equal(out.z, out.z_bis_permuted) +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return FactorVAE(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return FactorVAE(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + class Test_NLL_Compute: @pytest.fixture diff --git a/tests/test_HVAE.py b/tests/test_HVAE.py index bb7d7282..aad458f7 100644 --- a/tests/test_HVAE.py +++ b/tests/test_HVAE.py @@ -323,6 +323,61 @@ def test_nll_compute(self, hvae, demo_data, nll_params): assert isinstance(nll, float) assert nll < 0 +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return HVAE(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return HVAE(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + @pytest.mark.slow class Test_HVAE_Training: diff --git a/tests/test_IWAE.py b/tests/test_IWAE.py index 2cf9654f..9b076f29 100644 --- a/tests/test_IWAE.py +++ b/tests/test_IWAE.py @@ -299,6 +299,61 @@ def test_model_train_output(self, iwae, demo_data): ) +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return IWAE(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return IWAE(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + class Test_NLL_Compute: @pytest.fixture def demo_data(self): diff --git a/tests/test_MSSSIMVAE.py b/tests/test_MSSSIMVAE.py index 8e8534ac..0eb9e6d7 100644 --- a/tests/test_MSSSIMVAE.py +++ b/tests/test_MSSSIMVAE.py @@ -299,6 +299,56 @@ def test_model_train_output(self, msssim_vae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return MSSSIM_VAE(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return MSSSIM_VAE(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape class Test_NLL_Compute: @pytest.fixture diff --git a/tests/test_RHVAE.py b/tests/test_RHVAE.py index 3acb0521..80d95cb1 100644 --- a/tests/test_RHVAE.py +++ b/tests/test_RHVAE.py @@ -442,6 +442,61 @@ def test_model_output(self, rhvae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return RHVAE(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return RHVAE(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + class Test_NLL_Compute: @pytest.fixture diff --git a/tests/test_SVAE.py b/tests/test_SVAE.py index 2aeee030..e3090f92 100644 --- a/tests/test_SVAE.py +++ b/tests/test_SVAE.py @@ -289,6 +289,61 @@ def test_model_train_output(self, svae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return SVAE(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return SVAE(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + class Test_NLL_Compute: @pytest.fixture diff --git a/tests/test_VAE.py b/tests/test_VAE.py index 904dc8d2..e0521f20 100644 --- a/tests/test_VAE.py +++ b/tests/test_VAE.py @@ -289,6 +289,61 @@ def test_model_train_output(self, vae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return VAE(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return VAE(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + class Test_NLL_Compute: @pytest.fixture diff --git a/tests/test_VAEGAN.py b/tests/test_VAEGAN.py index 8f6c6634..410127ef 100644 --- a/tests/test_VAEGAN.py +++ b/tests/test_VAEGAN.py @@ -416,6 +416,61 @@ def test_model_train_output(self, vaegan, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return VAEGAN(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return VAEGAN(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + class Test_NLL_Compute: @pytest.fixture diff --git a/tests/test_VAE_IAF.py b/tests/test_VAE_IAF.py index e7a673a7..09eb68ca 100644 --- a/tests/test_VAE_IAF.py +++ b/tests/test_VAE_IAF.py @@ -295,6 +295,61 @@ def test_model_train_output(self, vae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return VAE_IAF(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return VAE_IAF(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + class Test_NLL_Compute: @pytest.fixture diff --git a/tests/test_VAE_LinFlow.py b/tests/test_VAE_LinFlow.py index 48abb79d..f392c46a 100644 --- a/tests/test_VAE_LinFlow.py +++ b/tests/test_VAE_LinFlow.py @@ -309,6 +309,62 @@ def test_model_train_output(self, vae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return VAE_LinNF(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return VAE_LinNF(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + + class Test_NLL_Compute: @pytest.fixture diff --git a/tests/test_VAMP.py b/tests/test_VAMP.py index 224906d4..3a747bfd 100644 --- a/tests/test_VAMP.py +++ b/tests/test_VAMP.py @@ -293,6 +293,62 @@ def test_model_train_output(self, vamp, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return VAMP(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return VAMP(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + + class Test_NLL_Compute: @pytest.fixture diff --git a/tests/test_VQVAE.py b/tests/test_VQVAE.py index c5f8ebd7..b053c414 100644 --- a/tests/test_VQVAE.py +++ b/tests/test_VQVAE.py @@ -306,6 +306,61 @@ def test_model_train_output(self, vae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return VQVAE(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return VQVAE(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + @pytest.mark.slow class Test_VQVAETraining: diff --git a/tests/test_WAE_MMD.py b/tests/test_WAE_MMD.py index e561a196..ad7e8125 100644 --- a/tests/test_WAE_MMD.py +++ b/tests/test_WAE_MMD.py @@ -291,6 +291,61 @@ def test_model_train_output(self, wae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return WAE_MMD(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return WAE_MMD(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + @pytest.mark.slow class Test_WAE_MMD_Training: diff --git a/tests/test_info_vae_mmd.py b/tests/test_info_vae_mmd.py index 2ebb5ee9..fbbe33a7 100644 --- a/tests/test_info_vae_mmd.py +++ b/tests/test_info_vae_mmd.py @@ -302,6 +302,60 @@ def test_model_train_output(self, info_vae_mmd, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return INFOVAE_MMD(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return INFOVAE_MMD(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape class Test_NLL_Compute: @pytest.fixture diff --git a/tests/test_rae_gp.py b/tests/test_rae_gp.py index 57b831b9..a2b1a61c 100644 --- a/tests/test_rae_gp.py +++ b/tests/test_rae_gp.py @@ -291,6 +291,62 @@ def test_model_train_output(self, rae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return RAE_GP(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return RAE_GP(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + + @pytest.mark.slow class Test_RAE_GP_Training: diff --git a/tests/test_rae_l2.py b/tests/test_rae_l2.py index 4b148237..fa09bbaa 100644 --- a/tests/test_rae_l2.py +++ b/tests/test_rae_l2.py @@ -295,6 +295,61 @@ def test_model_train_output(self, rae, demo_data): assert out.z.shape[0] == demo_data["data"].shape[0] assert out.recon_x.shape == demo_data["data"].shape +class Test_Model_interpolate: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture() + def granularity(self): + return int(torch.randint(1, 10, (1,))) + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return RAE_L2(model_configs) + + + def test_interpolate(self, ae, demo_data, granularity): + with pytest.raises(AssertionError): + ae.interpolate(demo_data, demo_data[1:], granularity) + + interp = ae.interpolate(demo_data, demo_data, granularity) + + assert tuple(interp.shape) == (demo_data.shape[0], granularity,) + (demo_data.shape[1:]) + +class Test_Model_reconstruct: + @pytest.fixture( + params=[ + torch.randn(3, 2, 3, 1), + torch.randn(3, 2, 2), + torch.load(os.path.join(PATH, "data/mnist_clean_train_dataset_sample"))[ + : + ]['data'] + ] + ) + def demo_data(self, request): + return request.param + + @pytest.fixture + def ae(self, model_configs, demo_data): + model_configs.input_dim = tuple(demo_data[0].shape) + return RAE_L2(model_configs) + + + def test_reconstruct(self, ae, demo_data): + + recon = ae.reconstruct(demo_data) + assert tuple(recon.shape) == demo_data.shape + @pytest.mark.slow class Test_RAE_L2_Training: