If you just want a quick start, run the commands below (although it is recommended to read the rest of this README first, as it gives a quick overview of how the model works before you use it):
```sh
git clone --depth=1 https://github.com/Pythonic456/vae-latentspace-explorer
cd vae-latentspace-explorer
pip3 install torch numba Pillow numpy tqdm
python3 train.py
python3 tkload.py
```

Note that the `PIL` module is provided by the `Pillow` package, and `tkinter` ships with Python itself rather than via pip (on Debian/Ubuntu it comes from the `python3-tk` system package).
The Variational Autoencoder (VAE) is a generative neural network architecture used for unsupervised learning and data generation. It belongs to the family of autoencoders, which are neural networks designed to encode input data into a compact latent representation and then decode it back into the original data.
The VAE consists of three main parts: an encoder, a decoder, and a latent space.
- Encoder
The encoder network takes input data, in this case, images, and maps them to a lower-dimensional latent space. The encoder typically comprises convolutional layers followed by fully connected layers. In the provided code, the encoder consists of convolutional layers that progressively reduce the spatial dimensions of the input image while increasing the number of feature channels. The final fully connected layers produce two vectors: the mean (mu) and the log variance (logvar) of the latent space representation.
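As a rough PyTorch sketch of such an encoder (the channel counts, the 64x64 image size, and the latent dimension below are illustrative assumptions, not the exact values used in `train.py`):

```python
import torch
import torch.nn as nn

class Encoder(nn.Module):
    """Maps a 3x64x64 image to the mean and log-variance of the latent distribution."""
    def __init__(self, latent_dim=32):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=4, stride=2, padding=1),   # 64x64 -> 32x32
            nn.ReLU(),
            nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1),  # 32x32 -> 16x16
            nn.ReLU(),
            nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1), # 16x16 -> 8x8
            nn.ReLU(),
        )
        self.fc_mu = nn.Linear(128 * 8 * 8, latent_dim)
        self.fc_logvar = nn.Linear(128 * 8 * 8, latent_dim)

    def forward(self, x):
        h = self.conv(x).flatten(start_dim=1)  # (N, 128*8*8)
        return self.fc_mu(h), self.fc_logvar(h)
```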
- Decoder
The decoder network takes a point from the latent space (a sampled vector) and reconstructs the original data. Like the encoder, the decoder also includes fully connected layers and convolutional transpose layers. The output of the decoder is an image that should closely resemble the input data.
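The decoder mirrors the encoder. A matching sketch under the same assumed sizes:

```python
import torch.nn as nn

class Decoder(nn.Module):
    """Reconstructs a 3x64x64 image from a sampled latent vector."""
    def __init__(self, latent_dim=32):
        super().__init__()
        self.fc = nn.Linear(latent_dim, 128 * 8 * 8)
        self.deconv = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=4, stride=2, padding=1),  # 8x8 -> 16x16
            nn.ReLU(),
            nn.ConvTranspose2d(64, 32, kernel_size=4, stride=2, padding=1),   # 16x16 -> 32x32
            nn.ReLU(),
            nn.ConvTranspose2d(32, 3, kernel_size=4, stride=2, padding=1),    # 32x32 -> 64x64
            nn.Sigmoid(),  # squash pixel values into [0, 1]
        )

    def forward(self, z):
        h = self.fc(z).view(-1, 128, 8, 8)
        return self.deconv(h)
```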
- Latent Space
The latent space is a lower-dimensional representation of the input data. It's sampled from a multivariate Gaussian distribution parametrized by the mean (mu) and log variance (logvar) produced by the encoder. The VAE's unique feature is that it learns not only to map data to this latent space but also to generate data from it.
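Sampling the latent vector is typically done with the reparameterization trick: the sample is rewritten as a deterministic function of `mu`, `logvar`, and standard Gaussian noise, so gradients can flow back through the encoder. A minimal sketch:

```python
import torch

def reparameterize(mu, logvar):
    """Draw z ~ N(mu, sigma^2) while keeping mu and logvar differentiable."""
    std = torch.exp(0.5 * logvar)  # logvar = log(sigma^2), so sigma = exp(logvar / 2)
    eps = torch.randn_like(std)    # eps ~ N(0, I), sampled outside the graph
    return mu + eps * std
```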
Training a VAE involves two main components: a reconstruction loss and a regularization term.
- Reconstruction Loss
The reconstruction loss measures how well the decoder can reconstruct the input data from the latent space. In the provided code, you can choose between different reconstruction losses: binary cross-entropy, mean squared error, or a combination of the two (see the loss sketch below).
- Regularization (KL Divergence)
To encourage the learned latent space to follow a Gaussian distribution, a regularization term is added to the loss function. This term, the Kullback-Leibler (KL) divergence, penalizes deviation of the encoder's output distribution from a unit Gaussian. It helps the VAE disentangle and organize information in the latent space.
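Combining the two terms gives the usual VAE objective. A sketch using binary cross-entropy for the reconstruction term and the closed-form KL divergence for a diagonal Gaussian (how the terms are chosen or combined in `train.py` depends on the options you pick):

```python
import torch
import torch.nn.functional as F

def vae_loss(recon_x, x, mu, logvar):
    """Reconstruction loss plus KL divergence to a unit Gaussian."""
    # How well the decoder output matches the input, summed over all pixels.
    recon = F.binary_cross_entropy(recon_x, x, reduction="sum")
    # Closed-form KL( N(mu, sigma^2) || N(0, I) ) for a diagonal Gaussian.
    kld = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return recon + kld
```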
Once the VAE is trained, you can sample points from the learned latent space and use the decoder to generate new data samples. The provided code includes a GUI application that allows you to interactively explore the latent space and see how different points in the space correspond to generated images.
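In code, drawing a new sample amounts to decoding a random latent vector; the GUI does the same thing with slider-controlled coordinates instead of random ones. A sketch reusing the illustrative `Decoder` class from above:

```python
import torch

decoder = Decoder(latent_dim=32)  # in practice, load the trained weights first

with torch.no_grad():
    z = torch.randn(1, 32)   # a random point drawn from the unit Gaussian prior
    image = decoder(z)       # tensor of shape (1, 3, 64, 64), values in [0, 1]
```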
- Train the VAE: Set the appropriate parameters (input folder, image size, latent space size, number of epochs, etc.) in `train.py`, then run the training code: `python3 train.py`. Some example training images are provided in `in/`; they come from https://thispersondoesnotexist.com/.
- Explore the Latent Space: Run `python3 tkload.py`. You can adjust the sliders to navigate through the latent space and visualize the generated images. Make sure to set the parameters at the start of `tkload.py` to match the ones you used to train the model in `train.py`.
- Save Generated Images: If you find interesting images, you can save them using the "Save" button in the GUI.