This is an unofficial PyTorch (Lightning) implementation of EDM (*Elucidating the Design Space of Diffusion-Based Generative Models*) and EDM2 (*Analyzing and Improving the Training Dynamics of Diffusion Models*).
- Config G.
- Post-hoc EMA (see the sketch below).
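Post-hoc EMA (from the EDM2 paper) stores periodic model snapshots during training and reconstructs an EMA profile with an arbitrary decay length *after* training, as a linear combination of those snapshots. Below is a minimal sketch of the reconstruction step only, assuming the combination weights have already been solved for; the helper name is illustrative, not this repo's API.

```python
import torch

def combine_snapshots(snapshot_paths, weights):
    """Reconstruct a post-hoc EMA model as a weighted sum of stored
    snapshots: theta_ema = sum_i w_i * theta_i. In EDM2 the weights come
    from a small least-squares fit against the target EMA profile;
    solving for them is omitted here."""
    avg = None
    for path, w in zip(snapshot_paths, weights):
        state = torch.load(path, map_location="cpu")
        if avg is None:
            avg = {k: w * v.double() for k, v in state.items()}
        else:
            for k, v in state.items():
                avg[k] += w * v.double()
    return {k: v.float() for k, v in avg.items()}
```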
git clone https://github.com/YichengDWu/tinyedm.git
cd tinyedm && pip install .
python experiments/train.py --config-name=mnist
python experiments/train.py --config-name=cifar10
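Training is configured with Hydra (hence `--config-name`), so individual settings can also be overridden directly on the command line. The key name below is purely illustrative; check the repo's config files for the actual names:

python experiments/train.py --config-name=cifar10 trainer.max_epochs=1700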
To download the ImageNet dataset, follow these steps:
- Visit the ImageNet website: http://www.image-net.org/
- Register for an account and request access to the dataset.
- Once approved, follow the instructions provided by ImageNet to download the dataset.
After downloading the ImageNet dataset, extract the files to a directory. When running the latent extraction script, pass the path to this directory with the `--data-dir` option.
For example:
python src/tinyedm/datamodules/extract_latents.py --data-dir ./datasets/imagenet/train --out-dir ./datasets/imagenet/latents/train
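For context, EDM2 trains its ImageNet models on latents from a pretrained VAE rather than on raw pixels. Here is a minimal sketch of what such an extraction step does, assuming the Stable Diffusion VAE from `diffusers`; this repo's script may use a different encoder or scaling.

```python
import torch
from diffusers import AutoencoderKL

# Pretrained SD VAE (the same family of autoencoder used by EDM2).
vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse").eval().cuda()

@torch.no_grad()
def encode_batch(images):
    """images: float tensor in [-1, 1], shape (N, 3, H, W)."""
    latents = vae.encode(images.cuda()).latent_dist.sample()
    return latents * vae.config.scaling_factor  # standard SD latent scaling
```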
python src/tinyedm/generate.py \
--ckpt_path /path/to/checkpoint.ckpt \
--load_ema \
--output_dir /path/to/output \
--num_samples 50000 \
--image_size 32 \
--num_classes 10 \
--batch_size 128 \
--num_workers 16 \
--num_steps 32
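For reference, `--num_steps` controls the number of sampling steps of the deterministic second-order (Heun) sampler from the EDM paper (Algorithm 1). A self-contained sketch, assuming `denoise(x, sigma)` wraps the network with EDM preconditioning; the function names are illustrative:

```python
import torch

@torch.no_grad()
def edm_sampler(denoise, shape, num_steps=32, sigma_min=0.002,
                sigma_max=80.0, rho=7.0, device="cuda"):
    """Deterministic Heun sampler from EDM (Karras et al., 2022), Alg. 1."""
    # Time-step discretization: sigma_i on the EDM rho-schedule.
    i = torch.arange(num_steps, device=device)
    sigmas = (sigma_max ** (1 / rho) + i / (num_steps - 1)
              * (sigma_min ** (1 / rho) - sigma_max ** (1 / rho))) ** rho
    sigmas = torch.cat([sigmas, torch.zeros(1, device=device)])  # sigma_N = 0

    x = torch.randn(shape, device=device) * sigmas[0]
    for j in range(num_steps):
        s, s_next = sigmas[j], sigmas[j + 1]
        d = (x - denoise(x, s)) / s        # dx/dsigma at sigma_j
        x_euler = x + (s_next - s) * d     # Euler step
        if s_next > 0:                     # second-order correction
            d_next = (x_euler - denoise(x_euler, s_next)) / s_next
            x = x + (s_next - s) * 0.5 * (d + d_next)
        else:
            x = x_euler
    return x
```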
Dataset | Params | Type | Epochs | FID |
---|---|---|---|---|
CIFAR-10 | 35.6 M | unconditional | 1700 | 4.0 |
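FID is conventionally computed on 50,000 generated samples (matching `--num_samples 50000` above). One way to reproduce such a number is the `clean-fid` package; note the repo does not specify which FID implementation produced the figure above, and different evaluators can differ slightly.

```python
from cleanfid import fid

# Compare a folder of generated images against the precomputed
# CIFAR-10 training statistics shipped with clean-fid.
score = fid.compute_fid("/path/to/output", dataset_name="cifar10",
                        dataset_res=32, dataset_split="train")
print(f"FID: {score:.2f}")
```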
- FP16 mixed-precision training on CIFAR-10 sometimes overflows, so this implementation uses bf16 mixed precision instead (e.g., Lightning's `precision="bf16-mixed"`), which may cost some model accuracy.
- The scale factors of the long skip connections are learned by a small network, inspired by *ScaleLong: Towards More Stable Training of Diffusion Model via Scaling Network Long Skip Connection* (a sketch follows this list). Experiments show this improves results.
- Multi-task learning as described in the paper did not yield any improvement in my experiments; it may pay off only in longer training runs, but I do not have the compute to verify this.
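Below is a minimal sketch of the ScaleLong-style learned skip scaling mentioned above: a small network predicts one scale per long skip connection, zero-initialized so the scales start at 1 (identity). The architecture and conditioning are assumptions, not this repo's exact implementation.

```python
import torch
from torch import nn

class SkipScaleNet(nn.Module):
    """Small network predicting scale factors for the long skip
    connections, inspired by ScaleLong. Hypothetical sketch: here the
    scales are predicted from the noise-level embedding."""

    def __init__(self, emb_dim: int, num_skips: int):
        super().__init__()
        self.net = nn.Sequential(
            nn.Linear(emb_dim, emb_dim),
            nn.SiLU(),
            nn.Linear(emb_dim, num_skips),
        )
        # Zero-init the last layer so every scale is exactly 1 at the start.
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

    def forward(self, emb: torch.Tensor) -> torch.Tensor:
        # Returns (N, num_skips); multiply each skip tensor by its scale.
        return 1.0 + self.net(emb)
```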