Nunchaku is an inference engine designed for 4-bit diffusion models, as demonstrated in our paper SVDQuant. Please check DeepCompressor for the quantization library.
- [Nov 7, 2024] 🔥 Our latest W4A4 Diffusion model quantization work SVDQuant is publicly released! Check DeepCompressor for the quantization library.
SVDQuant is a post-training quantization technique for 4-bit weights and activations that preserves visual fidelity well. On the 12B FLUX.1-dev model, it achieves a 3.6× memory reduction compared to the BF16 model. By eliminating CPU offloading, it offers an 8.7× speedup over the 16-bit model on a 16GB laptop 4090 GPU and is 3× faster than the NF4 W4A16 baseline. On PixArt-Σ, it demonstrates significantly better visual quality than other W4A4 and even W4A8 baselines. "E2E" means the end-to-end latency, including the text encoder and VAE decoder.
SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models
Muyang Li*, Yujun Lin*, Zhekai Zhang*, Tianle Cai, Xiuyu Li, Junxian Guo, Enze Xie, Chenlin Meng, Jun-Yan Zhu, and Song Han
MIT, NVIDIA, CMU, Princeton, UC Berkeley, SJTU, and Pika Labs
Overview of SVDQuant. Stage1: Originally, both the activation
(a) Naïvely running the low-rank branch with rank 32 introduces 57% latency overhead, due to the extra read of 16-bit inputs in Down Projection and the extra write of 16-bit outputs in Up Projection. Nunchaku reduces this overhead with kernel fusion. (b) The Down Projection and Quantize kernels use the same input, while the Up Projection and 4-Bit Compute kernels share the same output. To reduce data movement overhead, we fuse the first two kernels together and the latter two kernels together.
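As a rough illustration of why a low-rank branch helps, the toy sketch below keeps a low-rank component of a weight matrix in full precision and quantizes only the residual to 4 bits, in the spirit of "absorbing outliers by low-rank components". It is not Nunchaku's kernel code. The simplifications are assumptions made for brevity: rank 1 instead of 32, power iteration instead of a full SVD, a single per-tensor symmetric INT4 scale, and weights only (activations are left unquantized).

```python
import math

def matvec(m, v):
    return [sum(a * b for a, b in zip(row, v)) for row in m]

def transpose(m):
    return [list(col) for col in zip(*m)]

def normalize(v):
    n = math.sqrt(sum(x * x for x in v))
    return [x / n for x in v]

def top_singular_triplet(w, iters=50):
    """Power iteration on W^T W: (sigma, u, v) for the largest singular value."""
    wt = transpose(w)
    v = normalize([1.0] * len(w[0]))
    for _ in range(iters):
        v = normalize(matvec(wt, matvec(w, v)))
    wv = matvec(w, v)
    sigma = math.sqrt(sum(x * x for x in wv))
    u = [x / sigma for x in wv]
    return sigma, u, v

def fake_quant_int4(m):
    """Symmetric per-tensor INT4 quantize-dequantize (integer levels in [-7, 7])."""
    scale = max(abs(x) for row in m for x in row) / 7.0
    return [[round(x / scale) * scale for x in row] for row in m]

def mse(a, b):
    n = len(a) * len(a[0])
    return sum((x - y) ** 2 for ra, rb in zip(a, b) for x, y in zip(ra, rb)) / n

# Toy 4x4 weight: small values everywhere plus one large "outlier" column.
small = [[((i * 4 + j) % 7 - 3) * 0.1 for j in range(4)] for i in range(4)]
outlier = [1.0, -0.5, 0.75, 0.25]
W = [[small[i][j] + (8.0 * outlier[i] if j == 0 else 0.0) for j in range(4)]
     for i in range(4)]

# Direct 4-bit quantization: the outlier column dictates the scale,
# so the small entries are rounded away.
err_direct = mse(W, fake_quant_int4(W))

# Low-rank path: keep a rank-1 component in full precision and
# quantize only the residual, whose value range is much smaller.
sigma, u, v = top_singular_triplet(W)
low_rank = [[sigma * u[i] * v[j] for j in range(4)] for i in range(4)]
residual = [[W[i][j] - low_rank[i][j] for j in range(4)] for i in range(4)]
q_res = fake_quant_int4(residual)
W_hat = [[low_rank[i][j] + q_res[i][j] for j in range(4)] for i in range(4)]
err_low_rank = mse(W, W_hat)

print(f"direct INT4 MSE:            {err_direct:.6f}")
print(f"rank-1 + INT4 residual MSE: {err_low_rank:.6f}")
```

On this toy matrix the second error is smaller, because the rank-1 term absorbs the outlier column. The price is the extra branch at inference time, which is the overhead that the kernel fusion described above targets.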
SVDQuant reduces the model size of the 12B FLUX.1 by 3.6×. Additionally, Nunchaku further cuts the memory usage of the 16-bit model by 3.5× and delivers 3.0× speedups over the NF4 W4A16 baseline on both desktop and laptop NVIDIA RTX 4090 GPUs. Remarkably, on the laptop 4090, it achieves a total speedup of 10.1× by eliminating CPU offloading.
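A back-of-envelope calculation helps put the 3.6× figure in context. The sketch below counts only weight storage for a 12B-parameter model and ignores quantization scales and other metadata. Attributing the gap between the ideal 4× and the reported 3.6× to parts that stay in 16 bits (for example the low-rank branch) is an assumption, not something the text above states.

```python
PARAMS = 12e9  # "12B" parameters, as stated in the text

def weight_gb(num_params: float, bits_per_weight: float) -> float:
    """Weight storage in GB (1 GB = 1e9 bytes), ignoring scales and metadata."""
    return num_params * bits_per_weight / 8 / 1e9

bf16_gb = weight_gb(PARAMS, 16)      # 16-bit baseline
int4_gb = weight_gb(PARAMS, 4)       # every weight stored in 4 bits
ideal_ratio = bf16_gb / int4_gb      # upper bound on the size reduction

# Size implied by the reported 3.6x reduction, for comparison with the ideal.
implied_gb = bf16_gb / 3.6

print(f"BF16 weights:        {bf16_gb:.1f} GB")
print(f"pure 4-bit weights:  {int4_gb:.1f} GB (ideal {ideal_ratio:.1f}x)")
print(f"implied by 3.6x:     {implied_gb:.2f} GB")
```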
Note:
- For Windows users, please refer to this issue for instructions.
- We currently support only NVIDIA GPUs with architectures sm_86 (Ampere: RTX 3090, A6000), sm_89 (Ada: RTX 4090), and sm_80 (A100). See this issue for more details.
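To check whether a machine falls into the supported list before building, something like the following may help. It is a minimal sketch: `is_supported` and `current_gpu_capability` are hypothetical helper names (not part of nunchaku), and the set of capabilities is copied from the note above.

```python
# Compute capabilities from the note above: sm_80, sm_86, sm_89.
SUPPORTED_CAPABILITIES = {(8, 0), (8, 6), (8, 9)}

def is_supported(capability) -> bool:
    """True if a (major, minor) compute capability is in the supported list."""
    return tuple(capability) in SUPPORTED_CAPABILITIES

def current_gpu_capability():
    """Return (major, minor) for GPU 0, or None if PyTorch/CUDA is unavailable."""
    try:
        import torch
    except ImportError:
        return None
    if not torch.cuda.is_available():
        return None
    return torch.cuda.get_device_capability(0)

cap = current_gpu_capability()
if cap is None:
    print("No CUDA GPU visible to PyTorch")
else:
    status = "supported" if is_supported(cap) else "not in the supported list"
    print(f"sm_{cap[0]}{cap[1]}: {status}")
```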
- Install dependencies:
  conda create -n nunchaku python=3.11
  conda activate nunchaku
  pip install torch==2.4.1 torchvision==0.19.1 torchaudio==2.4.1 --index-url https://download.pytorch.org/whl/cu121
  pip install diffusers ninja wheel transformers accelerate sentencepiece protobuf
  pip install huggingface_hub peft opencv-python einops gradio spaces GPUtil
- Install the nunchaku package: Make sure you have gcc/g++>=11. If you don't, you can install it via Conda:
  conda install -c conda-forge gxx=11 gcc=11
  Then build the package from source:
  git clone https://github.com/mit-han-lab/nunchaku.git
  cd nunchaku
  git submodule init
  git submodule update
  pip install -e .
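If the build fails, the compiler version is one of the first things worth checking. The sketch below reads the major version from `gcc -dumpversion` and compares it with the gcc/g++>=11 requirement stated above. `gcc_major` and `installed_gcc_major` are hypothetical helper names, and the script assumes the compiler on `PATH` is the one the build will use.

```python
import re
import shutil
import subprocess

REQUIRED_MAJOR = 11  # from the gcc/g++>=11 requirement above

def gcc_major(version_text: str) -> int:
    """Parse the major version from output such as '11' or '11.4.0'."""
    match = re.match(r"\s*(\d+)", version_text)
    if match is None:
        raise ValueError(f"unrecognized version string: {version_text!r}")
    return int(match.group(1))

def installed_gcc_major(compiler: str = "gcc"):
    """Return the major version of `compiler`, or None if it cannot be determined."""
    if shutil.which(compiler) is None:
        return None
    try:
        result = subprocess.run(
            [compiler, "-dumpversion"], capture_output=True, text=True
        )
        return gcc_major(result.stdout)
    except (OSError, ValueError):
        return None

major = installed_gcc_major()
if major is None:
    print("gcc not found or version not readable")
elif major >= REQUIRED_MAJOR:
    print(f"gcc {major} satisfies >= {REQUIRED_MAJOR}")
else:
    print(f"gcc {major} is older than {REQUIRED_MAJOR}; consider the Conda install above")
```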
In example.py, we provide a minimal script for running the INT4 FLUX.1-schnell model with Nunchaku.
import torch

from nunchaku.pipelines import flux as nunchaku_flux

pipeline = nunchaku_flux.from_pretrained(
    "black-forest-labs/FLUX.1-schnell",
    torch_dtype=torch.bfloat16,
    qmodel_path="mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors",  # download from Huggingface
).to("cuda")
image = pipeline("A cat holding a sign that says hello world", num_inference_steps=4, guidance_scale=0).images[0]
image.save("example.png")
Specifically, nunchaku shares the same APIs as diffusers and can be used in a similar way. The FLUX.1-dev model can be loaded in the same way by replacing all occurrences of schnell with dev.
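The schnell-to-dev substitution can be written out as below. This is a sketch rather than a tested recipe: `flux_model_ids` and `load_flux` are hypothetical helper names, and the dev checkpoint path is simply what the substitution yields, so whether that exact file exists is an assumption. The `from_pretrained` call mirrors the example above.

```python
SCHNELL_BASE = "black-forest-labs/FLUX.1-schnell"
SCHNELL_QMODEL = "mit-han-lab/svdquant-models/svdq-int4-flux.1-schnell.safetensors"

def flux_model_ids(variant: str = "schnell") -> tuple[str, str]:
    """Return (base model id, quantized checkpoint path) for 'schnell' or 'dev'."""
    if variant not in ("schnell", "dev"):
        raise ValueError(f"unknown variant: {variant!r}")
    return (
        SCHNELL_BASE.replace("schnell", variant),
        SCHNELL_QMODEL.replace("schnell", variant),
    )

def load_flux(variant: str = "schnell"):
    """Build the pipeline as in example.py; needs nunchaku and a supported GPU."""
    import torch
    from nunchaku.pipelines import flux as nunchaku_flux

    base, qmodel = flux_model_ids(variant)
    return nunchaku_flux.from_pretrained(
        base, torch_dtype=torch.bfloat16, qmodel_path=qmodel
    ).to("cuda")

print(flux_model_ids("dev"))
```

Calling `load_flux("dev")` would then return a pipeline used like the schnell one. The `num_inference_steps=4, guidance_scale=0` settings in example.py are specific to schnell; suitable values for dev are not given here and should be taken from the FLUX.1-dev model card.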
cd app/t2i
python run_gradio.py
- The demo also defaults to the FLUX.1-schnell model. To switch to the FLUX.1-dev model, use -m dev.
- By default, the Gemma-2B model is loaded as a safety checker. To disable this feature and save GPU memory, use --no-safety-checker.
- To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying --use-qencoder.
- By default, only the INT4 DiT is loaded. Use -p int4 bf16 to add a BF16 DiT for side-by-side comparison, or -p bf16 to load only the BF16 model.
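The flags above can be combined on one command line. To make their shapes concrete, here is a hypothetical argparse mirror of them. It is not the actual run_gradio.py parser: the `dest` names, the `choices` lists, and the defaults are assumptions inferred from the descriptions above, and only the flag spellings come from the text.

```python
import argparse

def build_parser() -> argparse.ArgumentParser:
    """Hypothetical mirror of the t2i demo flags described above."""
    parser = argparse.ArgumentParser(description="Sketch of the t2i demo flags")
    # -m selects the model variant; schnell is the stated default.
    parser.add_argument("-m", dest="model", choices=["schnell", "dev"], default="schnell")
    # -p takes one or more precisions, e.g. "-p int4 bf16" for side-by-side comparison.
    parser.add_argument("-p", dest="precisions", nargs="+",
                        choices=["int4", "bf16"], default=["int4"])
    parser.add_argument("--no-safety-checker", action="store_true")
    parser.add_argument("--use-qencoder", action="store_true")
    return parser

# Equivalent to: python run_gradio.py -m dev -p int4 bf16 --use-qencoder
args = build_parser().parse_args(["-m", "dev", "-p", "int4", "bf16", "--use-qencoder"])
print(args.model, args.precisions, args.no_safety_checker, args.use_qencoder)
```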
cd app/i2i
python run_gradio.py
- Similarly, the demo loads the Gemma-2B model as a safety checker by default. To disable this feature, use --no-safety-checker.
- To further reduce GPU memory usage, you can enable the W4A16 text encoder by specifying --use-qencoder.
- By default, we use our INT4 model. Use -p bf16 to switch to the BF16 model.
Please refer to app/t2i/README.md for instructions on reproducing our paper's quality results and benchmarking inference latency.
- Easy installation
- Comfy UI node
- Customized LoRA conversion instructions
- Customized model quantization instructions
- Modularization
- ControlNet and IP-Adapter integration
- Mochi and CogVideoX support
- Metal backend
If you find nunchaku useful or relevant to your research, please cite our paper:
@article{li2024svdquant,
  title={SVDQuant: Absorbing Outliers by Low-Rank Components for 4-Bit Diffusion Models},
  author={Li*, Muyang and Lin*, Yujun and Zhang*, Zhekai and Cai, Tianle and Li, Xiuyu and Guo, Junxian and Xie, Enze and Meng, Chenlin and Zhu, Jun-Yan and Han, Song},
  journal={arXiv preprint arXiv:2411.05007},
  year={2024}
}
- Efficient Spatially Sparse Inference for Conditional GANs and Diffusion Models, NeurIPS 2022 & T-PAMI 2023
- SmoothQuant: Accurate and Efficient Post-Training Quantization for Large Language Models, ICML 2023
- Q-Diffusion: Quantizing Diffusion Models, ICCV 2023
- AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration, MLSys 2024
- DistriFusion: Distributed Parallel Inference for High-Resolution Diffusion Models, CVPR 2024
- QServe: W4A8KV4 Quantization and System Co-design for Efficient LLM Serving, ArXiv 2024
- SANA: Efficient High-Resolution Image Synthesis with Linear Diffusion Transformers, ArXiv 2024
We thank MIT-IBM Watson AI Lab, MIT and Amazon Science Hub, MIT AI Hardware Program, National Science Foundation, Packard Foundation, Dell, LG, Hyundai, and Samsung for supporting this research. We thank NVIDIA for donating the DGX server.
We use img2img-turbo to train the sketch-to-image LoRA. Our text-to-image and sketch-to-image UIs are built upon playground-v2.5 and img2img-turbo, respectively. Our safety checker is borrowed from hart.
Nunchaku is also inspired by many open-source libraries, including (but not limited to) TensorRT-LLM, vLLM, QServe, AWQ, FlashAttention-2, and Atom.