[NeurIPS 24 Spotlight] MaskLLM: Learnable Semi-structured Sparsity for Large Language Models

NVlabs/MaskLLM

MaskLLM

MaskLLM: Learnable Semi-structured Sparsity for Large Language Models

- NeurIPS 2024 Spotlight -

Gongfan Fang, Hongxu Yin, Saurav Muralidharan, Greg Heinrich
Jeff Pool, Jan Kautz, Pavlo Molchanov, Xinchao Wang

NVIDIA Research, National University of Singapore

📄 [ArXiv] | 🎯 [Project Page] | 📎 [License]

0. What is MaskLLM

This work introduces MaskLLM, a learnable pruning method that establishes semi-structured (or "N:M") sparsity in LLMs to reduce computational overhead during inference. The method is scalable and stands to benefit from larger training datasets.
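To make the "N:M" pattern concrete, the following is a minimal sketch in plain Python, not code from this repository. It assumes the usual reading of N:M sparsity: in every group of M consecutive weights, at most N are non-zero. The helper names `candidate_masks` and `is_n_m_sparse` are illustrative.

```python
from itertools import combinations


def candidate_masks(n=2, m=4):
    """All binary masks of length m that keep exactly n weights."""
    masks = []
    for kept in combinations(range(m), n):
        masks.append(tuple(1 if i in kept else 0 for i in range(m)))
    return masks


def is_n_m_sparse(row, n=2, m=4):
    """True if every group of m consecutive values has at most n non-zeros."""
    if len(row) % m != 0:
        return False
    return all(
        sum(1 for v in row[i:i + m] if v != 0.0) <= n
        for i in range(0, len(row), m)
    )


masks_2_4 = candidate_masks(2, 4)
print(len(masks_2_4))  # 6 ways to keep 2 of 4 weights
print(is_n_m_sparse([0.5, 0.0, 0.0, -1.2, 0.0, 0.3, 0.7, 0.0]))  # True
```

For the 2:4 pattern used below, each group of four weights therefore has six possible masks, and half of the weights are removed.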

[Figure: Scalability]

1. Run MaskLLM with Megatron-LM

The following section provides an example of running MaskLLM on LLaMA-2/3 on a single node with 8 GPUs. The LLaMA model is sharded across the 8 GPUs with tensor parallelism and takes ~40GB per GPU for end-to-end training.

1.1 Docker Image

Docker is required for Megatron-LM. We use the official PyTorch Docker image pytorch:24.01-py3 from NVIDIA NGC as the base image. If you cannot use Docker, please refer to the official setup instructions in Megatron-LM. Run the following command to download and start the Docker container and mount your home directory.

docker run --gpus all --ipc=host --ulimit memlock=-1 --ulimit stack=67108864 -v $HOME:$HOME -it --rm nvcr.io/nvidia/pytorch:24.01-py3

1.2 Prepare LLaMA Checkpoints

In the container, we need to download the LLaMA checkpoints and convert them to Megatron format.

Download Huggingface Checkpoints

Install basic dependencies.

pip install transformers accelerate datasets SentencePiece wandb tqdm ninja tensorboardx==2.6 pulp timm einops

The following scripts download all HF checkpoints and save them to ./assets/checkpoints.

python scripts/tools/download_llama2_7b_hf.py 
python scripts/tools/download_llama2_13b_hf.py
python scripts/tools/download_llama3_8b_hf.py

assets
├── checkpoints
│   ├── llama2_13b_hf
│   ├── llama2_7b_hf
│   └── llama3_8b_hf

Convert HF to Megatron

Convert the downloaded HF checkpoint to Megatron format, with tp=8 for tensor parallelism.

bash scripts/tools/convert_llama2_7b_hf_to_megatron.sh 
bash scripts/tools/convert_llama2_13b_hf_to_megatron.sh 
bash scripts/tools/convert_llama3_8b_hf_to_megatron.sh

assets/
├── checkpoints
│   ├── llama2_13b_hf
│   ├── llama2_13b_megatron_tp8 # <= Megatron format
│   ├── llama2_7b_hf
│   ├── llama2_7b_megatron_tp8
│   ├── llama3_8b_hf
│   └── llama3_8b_megatron_tp8

Evaluate the dense model. The evaluation scripts take four arguments: the checkpoint path, the model size (7b/8b/13b), the tensor parallelism size (8), and the sparsity mode (dense or sparse).

bash scripts/ppl/evaluate_llama2_wikitext2.sh assets/checkpoints/llama2_7b_megatron_tp8 7b 8 dense
bash scripts/ppl/evaluate_llama2_wikitext2.sh assets/checkpoints/llama2_13b_megatron_tp8 13b 8 dense
bash scripts/ppl/evaluate_llama3_wikitext2.sh assets/checkpoints/llama3_8b_megatron_tp8 8b 8 dense

# Outputs for LLaMA-2 7B:
validation results on WIKITEXT2 | avg loss: 1.6323E+00 | ppl: 5.1155E+00 | adjusted ppl: 5.1155E+00 | token ratio: 1.0 |

# Outputs for LLaMA-2 13B:
validation results on WIKITEXT2 | avg loss: 1.5202E+00 | ppl: 4.5730E+00 | adjusted ppl: 4.5730E+00 | token ratio: 1.0 |

# Outputs for LLaMA-3 8B:
validation results on WIKITEXT2 | avg loss: 1.7512E+00 | ppl: 5.7615E+00 | adjusted ppl: 5.7615E+00 | token ratio: 1.0 |
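As a sanity check on the logs above, the reported perplexity is consistent with the exponential of the average loss. The snippet below is a small sketch of that relationship. The clamp at 20 and the use of the token ratio for the adjusted perplexity are assumptions about how the evaluation computes these values, and `perplexity` is an illustrative helper, not a function from the scripts.

```python
import math


def perplexity(avg_loss, token_ratio=1.0):
    """exp(avg_loss * token_ratio), clamped to avoid overflow (assumed behavior)."""
    return math.exp(min(20.0, avg_loss * token_ratio))


# (avg loss, reported ppl) copied from the evaluation logs above
reported = {
    "LLaMA-2 7B": (1.6323, 5.1155),
    "LLaMA-2 13B": (1.5202, 4.5730),
    "LLaMA-3 8B": (1.7512, 5.7615),
}

for name, (loss, ppl) in reported.items():
    print(f"{name}: exp({loss}) = {perplexity(loss):.4f} (reported {ppl})")
```

Small differences in the last digit are expected because the logged loss is rounded. With a token ratio of 1.0, the adjusted perplexity equals the plain perplexity, as in the logs.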

1.3 Pre-tokenize C4 Data for Megatron

Our paper uses a blended internal dataset for training. For reproducibility, we provide an example of learning masks on a subset of the public allenai/c4 dataset. The corresponding results can be found in Appendix D of our paper. Please see docs/preprocess_c4.md for instructions.

1.4 Generate prior masks

We encourage starting training from a prior mask generated by SparseGPT, Wanda, or magnitude pruning. The following scripts prune a LLaMA-2 7B model with 2:4 patterns. For SparseGPT, the weight update is disabled by default; add the argument --update-weight if necessary. More scripts for LLaMA-2 13B and LLaMA-3 8B are available at scripts/oneshot.

# <= SparseGPT mask
bash scripts/oneshot/run_llama2_7b_prune_tp8.sh hessian # --update-weight 
# <= Magnitude mask
bash scripts/oneshot/run_llama2_7b_prune_tp8.sh magnitude # --update-weight 
# <= Wanda mask
bash scripts/oneshot/run_llama2_7b_prune_tp8.sh wanda # --update-weight 

The pruned LLaMA model will contain additional .mask parameters in its sparse linear layers, such as module.language_model.encoder.layers.31.mlp.dense_h_to_4h.mask.

output/
├── oneshot_pruning
│   ├── checkpoint
│   │   ├── llama2-7b-tp8.sparse.nmprune.sp0.5hessian.ex0
│   │   ├── llama2-7b-tp8.sparse.nmprune.sp0.5magnitude.ex0
│   │   └── llama2-7b-tp8.sparse.nmprune.sp0.5wanda.ex0
│   ├── llama2-7b-tp8.sparse.nmprune.sp0.5hessian.ex0.log
│   ├── llama2-7b-tp8.sparse.nmprune.sp0.5magnitude.ex0.log
│   └── llama2-7b-tp8.sparse.nmprune.sp0.5wanda.ex0.log
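For intuition about what a prior mask contains, here is a minimal sketch of a magnitude-based 2:4 mask on a flat list of weights. It is a toy illustration, not the implementation behind the scripts above; `magnitude_mask` is a hypothetical helper. SparseGPT and Wanda rank weights with different importance scores, but the resulting mask has the same N:M shape.

```python
def magnitude_mask(row, n=2, m=4):
    """Keep the n largest-magnitude weights in each group of m consecutive weights."""
    if len(row) % m != 0:
        raise ValueError("row length must be a multiple of m")
    mask = []
    for start in range(0, len(row), m):
        group = row[start:start + m]
        # Indices of the n largest |w| in the group; ties are broken by position.
        keep = sorted(range(m), key=lambda i: -abs(group[i]))[:n]
        mask.extend(1 if i in keep else 0 for i in range(m))
    return mask


weights = [0.9, -0.1, 0.05, -1.3, 0.2, 0.4, -0.6, 0.01]
mask = magnitude_mask(weights)
pruned = [w if keep else 0.0 for w, keep in zip(weights, mask)]

print(mask)                    # [1, 0, 0, 1, 0, 1, 1, 0]
print(pruned)                  # [0.9, 0.0, 0.0, -1.3, 0.0, 0.4, -0.6, 0.0]
print(sum(mask) / len(mask))   # 0.5, i.e. 50% of the weights are kept
```

The kept fraction of 0.5 presumably corresponds to the sp0.5 tag in the checkpoint names above.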

To evaluate the pruned model:

bash scripts/ppl/evaluate_llama2_wikitext2.sh output/oneshot_pruning/checkpoint/llama2-7b-tp8.sparse.nmprune.sp0.5hessian.ex0 7b 8 sparse

1.5 MaskLLM Training

[Figure: Mask Sampling Visualization]

By default, the script loads the SparseGPT prior. To load other masks, modify the path in the script. The script takes one argument: 0 starts the initial training, and 1 continues training from the latest checkpoint.

# Initial training with a prior mask. 
# By default, the script will load output/oneshot_pruning/checkpoint/llama2-7b-tp8.sparse.nmprune.sp0.5hessian.ex0 as the mask prior
bash scripts/learnable_sparsity/llama2_7b_mask_only_tp8_c4.sh 0 

# Pass the argument 1 to continue the training from the latest checkpoint
bash scripts/learnable_sparsity/llama2_7b_mask_only_tp8_c4.sh 1
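The idea of learning masks by sampling can be sketched as follows. This is a simplified, stdlib-only illustration that assumes a Gumbel-softmax relaxation over the six 2:4 candidate masks of one weight group. The names `sample_soft_mask`, `logits`, and `tau` are illustrative and are not taken from the training script, which operates on full tensors inside Megatron-LM.

```python
import math
import random
from itertools import combinations

# The six 2:4 candidates: every way to keep 2 of 4 weights.
CANDIDATES = [
    tuple(1 if i in kept else 0 for i in range(4))
    for kept in combinations(range(4), 2)
]


def sample_soft_mask(logits, tau=1.0, rng=random):
    """Return (probabilities over candidates, soft mask of length 4)."""
    noisy = []
    for logit in logits:
        # Gumbel(0, 1) noise: -log(-log(u)) with u ~ Uniform(0, 1)
        u = min(max(rng.random(), 1e-12), 1.0 - 1e-12)
        noisy.append((logit - math.log(-math.log(u))) / tau)
    # Numerically stable softmax over the perturbed logits
    peak = max(noisy)
    exps = [math.exp(v - peak) for v in noisy]
    total = sum(exps)
    probs = [e / total for e in exps]
    # The soft mask is the probability-weighted average of the candidates
    soft_mask = [sum(p * c[i] for p, c in zip(probs, CANDIDATES)) for i in range(4)]
    return probs, soft_mask


rng = random.Random(0)
logits = [3.0, 0.0, 0.0, 0.0, 0.0, 0.0]  # e.g. a prior favoring the first candidate
probs, soft_mask = sample_soft_mask(logits, tau=0.5, rng=rng)
print(probs)      # six probabilities that sum to 1
print(soft_mask)  # four entries that sum to 2
```

In this sketch a prior mask would enter as a larger initial logit for the matching candidate, and a lower temperature `tau` pushes the sample closer to a single hard mask.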

1.6 Trim the checkpoint

For inference, we only need the winner masks, i.e., those with the highest probability. The following command trims the checkpoint and removes unnecessary components.

python tool_trim_learnable_sparsity.py --ckpt_dir output/checkpoints/llama2-7b-tp8-mask-only-c4-singlenode/train_iters_2000/ckpt/iter_0002000 
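Conceptually, picking the winner mask is an argmax over the learned candidate scores of each weight group. The snippet below is a toy sketch of that selection under this assumption. It does not reproduce tool_trim_learnable_sparsity.py, and `winner_masks` is a hypothetical helper.

```python
from itertools import combinations

# The six 2:4 candidates, in the order produced by itertools.combinations.
CANDIDATES = [
    tuple(1 if i in kept else 0 for i in range(4))
    for kept in combinations(range(4), 2)
]


def winner_masks(logits_per_group, candidates=CANDIDATES):
    """For each group, return the candidate mask with the highest score."""
    winners = []
    for logits in logits_per_group:
        # The argmax of the logits is also the argmax of the softmax probabilities.
        best = max(range(len(logits)), key=lambda k: logits[k])
        winners.append(candidates[best])
    return winners


learned = [
    [0.1, 2.0, -1.0, 0.0, 0.3, 0.2],  # group 0: candidate 1 wins
    [0.0, 0.0, 0.0, 0.0, 0.0, 3.0],   # group 1: candidate 5 wins
]
print(winner_masks(learned))  # [(1, 0, 1, 0), (0, 0, 1, 1)]
```

After this step only one binary mask per group needs to be stored, which is why the trimmed checkpoint can drop the remaining learnable components.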

Set the content of latest_checkpointed_iteration.txt to release so that the trimmed checkpoint is loaded. This sets up a clean checkpoint with additional .mask parameters for each sparse layer.
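The tracker file can be edited by hand or with a few lines of Python. The sketch below assumes that latest_checkpointed_iteration.txt sits directly in the checkpoint root passed to the evaluation script (the .../ckpt/ directory); check your output folder before relying on this. `mark_release` is an illustrative helper, and the demo runs against a temporary directory so it has no side effects.

```python
import tempfile
from pathlib import Path


def mark_release(ckpt_root):
    """Write 'release' into the tracker file under ckpt_root (assumed location)."""
    tracker = Path(ckpt_root) / "latest_checkpointed_iteration.txt"
    tracker.write_text("release")
    return tracker


# Demo on a throwaway directory that mimics a checkpoint root.
with tempfile.TemporaryDirectory() as tmp:
    tracker = Path(tmp) / "latest_checkpointed_iteration.txt"
    tracker.write_text("2000")          # what training would have left behind
    mark_release(tmp)
    print(tracker.read_text())          # release
```

For a real run, the call would point at your own checkpoint root, for example mark_release("output/checkpoints/llama2-7b-tp8-mask-only-c4-singlenode/train_iters_2000/ckpt").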

1.7 Evaluate the MaskLLM Model

# For llama2 7b & 13b
bash scripts/ppl/evaluate_llama2_wikitext2.sh output/checkpoints/llama2-7b-tp8-mask-only-c4-singlenode/train_iters_2000/ckpt/ 7b 8 sparse

bash scripts/ppl/evaluate_llama2_wikitext2.sh output/checkpoints/llama2-13b-tp8-mask-only-c4-singlenode/train_iters_2000/ckpt/ 13b 8 sparse

# For llama3 8b
bash scripts/ppl/evaluate_llama3_wikitext2.sh output/checkpoints/llama3-8b-tp8-mask-only-c4-singlenode/train_iters_2000/ckpt/ 8b 8 sparse

1.8 Export to HF (Optional)

Please see docs/export_hf.md for instructions on exporting sparse models to Huggingface.

2. Key Results

[Figure: experimental results]

3. BibTeX

@article{fang2024maskllm,
  title={MaskLLM: Learnable Semi-Structured Sparsity for Large Language Models},
  author={Fang, Gongfan and Yin, Hongxu and Muralidharan, Saurav and Heinrich, Greg and Pool, Jeff and Kautz, Jan and Molchanov, Pavlo and Wang, Xinchao},
  journal={arXiv preprint arXiv:2409.17481},
  year={2024}
}