This repository contains the official implementation of our Consistency Flow Matching.
Consistency Flow Matching: Defining Straight Flows with Velocity Consistency
Ling Yang, Zixiang Zhang, Zhilong Zhang, Xingchao Liu, Minkai Xu, Wentao Zhang, Chenlin Meng, Stefano Ermon, Bin Cui
Peking University, University of Texas at Austin, Stanford University, Pika Labs
Flow matching (FM) is a general framework for defining probability paths via Ordinary Differential Equations (ODEs) to transform between noise and data samples. Recent approaches attempt to straighten these flow trajectories to generate high-quality samples with fewer function evaluations, typically through iterative rectification methods or optimal transport solutions. In this paper, we introduce Consistency Flow Matching (Consistency-FM), a novel FM method that explicitly enforces self-consistency in the velocity field. Consistency-FM directly defines straight flows starting from different times to the same endpoint, imposing constraints on their velocity values. Additionally, we propose a multi-segment training approach for Consistency-FM to enhance expressiveness, achieving a better trade-off between sampling quality and speed. Experiments demonstrate that our Consistency-FM significantly improves training efficiency, converging 4.4x faster than consistency models and 1.7x faster than rectified flow models while achieving better generation quality.
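For orientation, below is a minimal PyTorch-style sketch of the velocity-consistency objective described above, based on our reading of the paper; the function and argument names (v_theta, v_ema, delta, alpha) are illustrative, and the training code in this repository is the authoritative implementation.

import torch

def consistency_fm_loss(v_theta, v_ema, x0, x1, delta=1e-2, alpha=1e-5):
    # x0: noise batch, x1: data batch. x_t = (1 - t) * x0 + t * x1 is the
    # straight-line interpolant, and f(x_t, t) = x_t + (1 - t) * v(x_t, t)
    # predicts the trajectory endpoint at t = 1. v_ema is a stop-gradient
    # (EMA) copy of v_theta; delta and alpha are illustrative defaults.
    b = x0.shape[0]
    t = torch.rand(b, device=x0.device) * (1.0 - delta)  # sample t in [0, 1 - delta]
    r = t + delta                                         # a slightly later time
    t_ = t.view(b, *([1] * (x0.dim() - 1)))
    r_ = r.view(b, *([1] * (x0.dim() - 1)))

    xt = (1.0 - t_) * x0 + t_ * x1
    xr = (1.0 - r_) * x0 + r_ * x1

    vt = v_theta(xt, t)
    with torch.no_grad():
        vr = v_ema(xr, r)

    ft = xt + (1.0 - t_) * vt   # endpoint predicted from time t
    fr = xr + (1.0 - r_) * vr   # endpoint predicted from time r

    # Velocity consistency: match both the predicted endpoints and the velocities.
    return ((ft - fr) ** 2).mean() + alpha * ((vt - vr) ** 2).mean()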
Run the following commands to install the dependencies:
conda create -n cfm python=3.8
conda activate cfm
pip install torch==1.11.0+cu113 torchvision==0.12.0+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
pip install tensorflow==2.9.0 tensorflow-probability==0.12.2 tensorflow-gan==2.0.0 tensorflow-datasets==4.6.0
pip install -U jax==0.3.4 jaxlib==0.3.2+cuda11.cudnn82 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html
pip install -r requirements.txt
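Optionally, you can verify that the GPU builds of PyTorch and JAX were installed correctly with these standard checks:

python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
python -c "import jax; print(jax.__version__, jax.devices())"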
Run the following commands to train Consistency-FM from scratch:
python ./main.py --config ./configs/consistencyfm/cifar10_gaussian_ddpmpp.py --eval_folder eval --mode train --workdir ./logs/cifar10 --config.consistencyfm.boundary 0 --config.training.n_iters 100001
# After the above training process completes, run the command below to continue.
python ./main.py --config ./configs/consistencyfm/cifar10_gaussian_ddpmpp.py --eval_folder eval --mode train --workdir ./logs/cifar10 --config.consistencyfm.boundary 0.9 --config.training.n_iters 200001
- --config : The configuration file for this run.
- --eval_folder : The generated images and other files for each evaluation during training will be stored in ./workdir/eval_folder. In this command, that is ./logs/cifar10/eval/.
- --mode : Mode selection for main.py. Select from train or eval.
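Note that any field defined in the config file can also be overridden on the command line with the --config.<section>.<field> syntax used above; for example (illustrative, assuming the field exists in the chosen config), --config.training.batch_size 64 would change the training batch size.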
We follow the evaluation pipeline of Score SDE. You can download cifar10_stats.npz and save it to assets/stats/.
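Optionally, you can sanity-check that the stats file is readable before launching evaluation (the path matches the location above; we simply list the stored arrays rather than assuming their names):

import numpy as np

# List the arrays stored in the pre-computed CIFAR-10 statistics file.
stats = np.load("assets/stats/cifar10_stats.npz")
print({name: stats[name].shape for name in stats.files})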
Then run
python ./main.py --config ./configs/consistencyfm/cifar10_gaussian_ddpmpp.py --eval_folder eval/NFE=2 --mode eval --config.eval.enable_sampling --config.eval.batch_size 1024 --config.eval.num_samples 50000 --config.sampling.sample_N 2 \
--config.eval.begin_ckpt 34 \
--config.eval.end_ckpt 40 \
--workdir ./logs/cifar10
which uses a batch size of 1024 to sample 50000 images per checkpoint, evaluating checkpoints 34 through 40 (starting from checkpoint-34.pth), and computes the FID and IS.
For a quick start, we have provided CIFAR-10 checkpoints at this link.
To train Consistency-FM on CelebA-HQ (256x256), follow these steps:
python ./main.py --config ./configs/consistencyfm/celeba_hq_pytorch_gaussian.py --eval_folder eval --mode train --workdir ./logs/celebahq --config.consistencyfm.boundary 0 --config.training.n_iters 150001 --config.training.data_dir path_to_celebahq
# After the above training process completes, run the command below to continue.
python ./main.py --config ./configs/consistencyfm/celeba_hq_pytorch_gaussian.py --eval_folder eval --mode train --workdir ./logs/celebahq --config.consistencyfm.boundary 0.9 --config.training.n_iters 250001 --config.training.data_dir path_to_celebahq
To sample images, run:
python ./main.py --config ./configs/consistencyfm/celeba_hq_pytorch_gaussian.py --mode eval --config.eval.enable_figures_only \
--config.training.data_dir path_to_celebahq \
--config.eval.num_samples 200 \
--config.eval.batch_size 25 \
--eval_folder fig_only/NFE=6 --config.sampling.sample_N 6 \
--workdir ./logs/celebahq \
--config.eval.begin_ckpt 0 --config.eval.end_ckpt 100
To evaluate the model, follow the instructions in On Aliased Resizing and Surprising Subtleties in GAN Evaluation to generate the custom stats files (see the clean-fid sketch after the command below), and then use the command below to calculate the FID score:
python ./main.py --config ./configs/consistencyfm/celeba_hq_pytorch_gaussian.py --mode eval --config.eval.enable_sampling --config.eval.end_ckpt 100 --config.eval.clean_fid.enabled True \
--config.training.data_dir path_to_celebahq \
--config.eval.num_samples 50000 \
--config.eval.batch_size 100 \
--workdir ./logs/celebahq \
--eval_folder eval_cleanfid/samples=10K/NFE=6 \
--config.eval.begin_ckpt 51 \
--config.sampling.sample_N 6 \
--config.eval.clean_fid.custom_stat.dataset_name1 custom_cleanfid_stats_name
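For reference, the custom stats referenced by --config.eval.clean_fid.custom_stat.dataset_name1 can typically be generated once with the clean-fid package; the sketch below is illustrative, with path_to_celebahq and custom_cleanfid_stats_name as placeholders matching the flags above:

from cleanfid import fid

# One-off step: register reference statistics for the CelebA-HQ image folder
# under the name that is later passed on the command line.
fid.make_custom_stats("custom_cleanfid_stats_name", "path_to_celebahq", mode="clean")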
For AFHQ-CAT, simply replace ./configs/consistencyfm/celeba_hq_pytorch_gaussian.py with ./configs/consistencyfm/afhq_cat_pytorch_gaussian.py.
If you use this code or find our work related to yours, please cite us:
@article{yang2024consistencyfm,
title={Consistency Flow Matching: Defining Straight Flows with Velocity Consistency},
author={Yang, Ling and Zhang, Zixiang and Zhang, Zhilong and Liu, Xingchao and Xu, Minkai and Zhang, Wentao and Meng, Chenlin and Ermon, Stefano and Cui, Bin},
journal={arXiv preprint arXiv:2407.02398},
year={2024}
}
Thanks to RectifiedFlow and TorchCFM for providing their implementations, which have significantly contributed to this codebase.