
Sirius is an efficient correction mechanism that significantly boosts Contextual Sparsity models on reasoning tasks while maintaining their efficiency gains.


Infini-AI-Lab/Sirius


Sirius: Contextual Sparsity with Correction for Efficient LLMs

¹Carnegie Mellon University, ²Stevens Institute of Technology, ³Meta AI (FAIR)
[Paper] | [Blog]

Problem of Contextual Sparsity

Contextual Sparsity Weakness in Complex Reasoning Tasks
In this paper, we comprehensively evaluate Contextual Sparsity (CS) models on a variety of complex generation tasks. CS models are evaluated at their default sparsity (50% neuron sparsity). The evaluation yields the following takeaways:
  1. CS models work well on prompt-understanding tasks, e.g., text summarization (CNN/DailyMail) and conversational question answering (CoQA).
  2. CS models perform significantly worse on generation tasks that require complex reasoning (GSM8K) or knowledge (MMLU-FLAN-COT).
  3. At a similar parameter count, the degradation on complex reasoning tasks grows for better-trained models.

Effectiveness of Sirius

Sirius is proposed to fix this weakness of CS on complex reasoning generation tasks while maintaining the efficiency of CS models. Sirius is evaluated on 6 models and 8 complex tasks spanning arithmetic reasoning, commonsense reasoning, and code generation (see the paper for more detailed results). Below we briefly show results for Llama-3-8B-Instruct on GSM8K, CSQA, and HumanEval.

Llama-3-8B-Instruct with Sirius: Effectiveness on Different Complex Tasks

| Task | Sparsity | Full Perf. | Sparse Perf. | Sparse Density | Sirius Perf. | AAL | Effective Density |
|---|---|---|---|---|---|---|---|
| GSM8K | CSparse | 0.7536 | 0.3844 | 0.65 | 0.7051 (8) | 15.22/16 | 0.706 |
| GSM8K | FSparse | 0.7536 | 0.5868 | 0.76 | 0.7278 (4) | 15.37/16 | 0.807 |
| CSQA | CSparse | 0.7073 | 0.6470 | 0.58 | 0.7076 (8) | 14.76/16 | 0.657 |
| CSQA | FSparse | 0.7073 | 0.6158 | 0.72 | 0.7043 (8) | 15.66/16 | 0.753 |
| HumanEval | CSparse | 0.561 | 0.207 | 0.65 | 0.524 (8) | 14.67/16 | 0.733 |
| HumanEval | FSparse | 0.561 | 0.457 | 0.76 | 0.616 (6) | 15.42/16 | 0.804 |

CSparse: coarse-grained sparsity; FSparse: fine-grained sparsity.

On-Chip Generation Speedup

Sirius delivers the promised latency reduction. We focus on the generation-only setting. Below are partial results for the on-chip setting on GSM-8K-COT; see the paper for the detailed experimental setup.

Llama-3-8B-Instruct On-Chip Wallclock Latency Speedup

| Setting | Perf. (A40/L40) | A40 Latency | A40 Ratio to Full | L40 Latency | L40 Ratio to Full | Perf. (A100) | A100 Latency | A100 Ratio to Full |
|---|---|---|---|---|---|---|---|---|
| Coarse-grained Sparsity | 0.3601 | 20.7 | 0.85 | 15.6 | 0.67 | 0.3601 | 9.6 | 0.72 |
| Sirius | 0.7127 | 24.1 | 0.77 | 18.2 | 0.78 | 0.7089 | 11.1 | 0.83 |
| Full | 0.7612 | 30.9 | 1.0 | 23.2 | 1.0 | 0.7612 | 13.3 | 1.0 |

Offloading Setting Speedup

We also show the speedup for Llama-3-70B-Instruct on GSM-8K-COT in the offloading setting, with a PCIe bandwidth of 25 GB/s.

Llama-3-70B-Instruct with Offloading

| Setting | Sparse | Sirius | Full |
|---|---|---|---|
| Performance | 0.7407 | 0.8719 | 0.9014 |
| Latency (s) | 3.57 | 3.68 | 5.72 |
| Ratio to Full | 0.6241 | 0.6434 | 1.0 |
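The "Ratio to Full" row in the offloading table is each setting's latency divided by the full model's latency, so lower is better. A minimal sketch that reproduces this arithmetic from the reported numbers (the dictionary and function name are illustrative, not part of the repository):

```python
# Reported Llama-3-70B-Instruct offloading latencies in seconds (from the table).
latency_s = {"Sparse": 3.57, "Sirius": 3.68, "Full": 5.72}


def ratio_to_full(latencies: dict) -> dict:
    """Divide each setting's latency by the full model's latency."""
    full = latencies["Full"]
    return {name: round(t / full, 4) for name, t in latencies.items()}


ratios = ratio_to_full(latency_s)
print(ratios)  # {'Sparse': 0.6241, 'Sirius': 0.6434, 'Full': 1.0}
```

Inverting a ratio gives the speedup over the full model, e.g. 5.72 / 3.68 ≈ 1.55× for Sirius.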

Overview of the Code

Environment Setup

pip install -r requirements.txt 
pip install flash-attn --no-build-isolation 

One package requirement to note: because Sirius uses torch.compile to optimize inference latency, we strictly require PyTorch version 2.3.0.
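Since a mismatched PyTorch version is easy to miss, a launch script could fail fast before any model is loaded. The helper below is a hypothetical sketch (not part of the repository) that reads the installed `torch` version through the standard-library `importlib.metadata` and compares it with 2.3.0, ignoring a local build tag such as `+cu121`:

```python
from importlib.metadata import PackageNotFoundError, version

REQUIRED_TORCH = "2.3.0"  # version pinned by the README because of torch.compile


def is_supported_torch(ver: str) -> bool:
    """True if `ver` is PyTorch 2.3.0, ignoring a local build tag like '+cu121'."""
    return ver.split("+", 1)[0] == REQUIRED_TORCH


def check_torch() -> None:
    """Raise RuntimeError unless the installed torch package is exactly 2.3.0."""
    try:
        installed = version("torch")
    except PackageNotFoundError as err:
        raise RuntimeError(f"PyTorch is not installed; expected torch=={REQUIRED_TORCH}") from err
    if not is_supported_torch(installed):
        raise RuntimeError(f"Found torch {installed}; expected torch=={REQUIRED_TORCH}")
```

Calling `check_torch()` at the top of an entry script turns a silent version mismatch into an immediate, readable error.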

Testing Sirius Effectiveness and the Efficiency Metric AAL

This section presents code for testing only effectiveness and the efficiency metric AAL (Average Advance Length). This implementation is not optimized for speed; for wallclock measurements, see Speedup On-Chip.
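The README does not spell out the correction loop, so the following is only a simplified sketch under stated assumptions, not the repository's implementation. It assumes that the sparse model drafts up to `kernel_size` tokens, that the full model accepts a drafted token when it assigns it a probability of at least `thr`, and that the first rejected token is replaced by the full model's own greedy choice. AAL is computed here as tokens gained per full-model check, so in this sketch it is at most `kernel_size`. KV-cache rewriting, the tree (`widthtree`), and batching are omitted, and all function names and toy models are hypothetical:

```python
from typing import Callable, List, Tuple

Token = int
NextFn = Callable[[List[Token]], Token]          # greedy next-token function
ProbFn = Callable[[List[Token], Token], float]   # full-model probability of a token


def draft_and_correct(sparse_next: NextFn, full_next: NextFn, full_prob: ProbFn,
                      prompt: List[Token], max_new_tokens: int,
                      kernel_size: int = 16, thr: float = 0.1) -> Tuple[List[Token], float]:
    """Toy sparse-draft / full-check loop (assumed acceptance rule). Returns (new tokens, AAL)."""
    out = list(prompt)
    advances: List[int] = []  # tokens gained per full-model check
    while len(out) - len(prompt) < max_new_tokens:
        budget = min(kernel_size, max_new_tokens - (len(out) - len(prompt)))
        # 1) The sparse model drafts up to `budget` tokens greedily.
        draft: List[Token] = []
        for _ in range(budget):
            draft.append(sparse_next(out + draft))
        # 2) The full model checks the draft left to right (one parallel pass in practice).
        kept: List[Token] = []
        for tok in draft:
            ctx = out + kept
            if full_prob(ctx, tok) >= thr:
                kept.append(tok)
            else:
                kept.append(full_next(ctx))  # correct the first rejected token
                break                        # and discard the rest of the draft
        out.extend(kept)
        advances.append(len(kept))
    aal = sum(advances) / len(advances) if advances else 0.0
    return out[len(prompt):], aal


# Hypothetical toy models over digits 0-9: the full model always counts up by one;
# the sparse model makes a mistake whenever the previous token is 6.
def full_next(ctx: List[Token]) -> Token:
    return (ctx[-1] + 1) % 10


def full_prob(ctx: List[Token], tok: Token) -> float:
    return 0.9 if tok == full_next(ctx) else 0.02


def sparse_next(ctx: List[Token]) -> Token:
    return 0 if ctx[-1] == 6 else (ctx[-1] + 1) % 10


tokens, aal = draft_and_correct(sparse_next, full_next, full_prob,
                                prompt=[0], max_new_tokens=20, kernel_size=8, thr=0.1)
print(tokens)                # [1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
print(f"AAL = {aal:.2f}/8")  # AAL = 5.00/8
```

In this toy run the output matches what the full model alone would produce greedily, while the full model is consulted only four times for twenty tokens; a larger AAL means fewer full-model checks per generated token.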
  • GSM-8K, GSM-8K-COT, CNN/DailyMail, MMLU-FLAN-COT
    We base our implementation on LM Evaluation Harness, since it supports these tasks. The essential code is in the "Miscellaneous" folder. To run Sirius on various Hugging Face models, use the following commands.
cd Miscellaneous 
# Full model 
accelerate launch --main_process_port <main_port> --num_processes <num_procs> --num_machines <num_node> main.py --model xhf --model_args pretrained=<huggingface-token-model>,griffin=False,check=False --tasks <task_name> --batch_size 1
# Coarse-grained Sparsity 
accelerate launch --main_process_port <main_port> --num_processes <num_procs> --num_machines <num_node> main.py --model xhf --model_args pretrained=<huggingface-token-model>,griffin=True,check=False --tasks <task_name> --batch_size 1 
# Fine-grained Sparsity 
accelerate launch --main_process_port <main_port> --num_processes <num_procs> --num_machines <num_node> main.py --model xhf --model_args pretrained=<huggingface-token-model>,cats=True,check=False --tasks <task_name> --batch_size 1
# Sirius with Sparse 
accelerate launch --main_process_port <main_port> --num_processes <num_procs> --num_machines <num_node> main.py --model xhf --model_args pretrained=<huggingface-token-model>,griffin=True,check=True,kernel_size=<kernel_size>,widthtree=<width_tree>,patternstrict=True,thr=0.1 --tasks <task_name> --batch_size 1 

To turn Sirius on, set check=True. Set cats=True for fine-grained sparsity or griffin=True for coarse-grained sparsity. Importantly, fine-grained sparsity here selects neurons by top-k rather than by the threshold used in https://arxiv.org/abs/2404.08763. That implementation is not open-sourced, and a fixed threshold cannot guarantee the same neuron sparsity level across multiple different generation datasets.
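To illustrate why a top-k budget keeps the sparsity level fixed while a magnitude threshold does not, here is a toy sketch over single activation vectors. It assumes neurons are ranked by activation magnitude; it is not the repository's kernel, and the activation values are made up for illustration:

```python
from typing import List


def topk_keep(acts: List[float], density: float) -> List[bool]:
    """Keep the k = round(density * n) largest-magnitude activations."""
    k = max(1, round(density * len(acts)))
    order = sorted(range(len(acts)), key=lambda i: abs(acts[i]), reverse=True)
    kept = set(order[:k])
    return [i in kept for i in range(len(acts))]


def threshold_keep(acts: List[float], t: float) -> List[bool]:
    """Keep every activation whose magnitude is at least t."""
    return [abs(a) >= t for a in acts]


def density(mask: List[bool]) -> float:
    """Fraction of neurons kept."""
    return sum(mask) / len(mask)


# Two hypothetical inputs: one with large activations, one with uniformly small ones.
acts_a = [0.9, -0.05, 0.4, 0.01, -0.7, 0.2, 0.03, -0.3]
acts_b = [0.08, -0.02, 0.06, 0.01, -0.09, 0.05, 0.03, -0.04]

for acts in (acts_a, acts_b):
    print(density(topk_keep(acts, 0.5)), density(threshold_keep(acts, 0.1)))
# 0.5 0.625
# 0.5 0.0
```

With top-k, both inputs run at exactly 50% density; with a fixed threshold of 0.1, the density swings from 62.5% to 0% depending on the input, which is the comparability problem described above.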

When using cats=True with Sirius and widthtree > 1, patternstrict must be set to True.
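The flag rules for the "Miscellaneous" commands can be collected in one place. The sketch below is a hypothetical helper (not part of the repository) that assembles the `--model_args` string and the `accelerate launch` command, rejecting the combinations the README rules out. The model ID, task name `gsm8k`, port, and `kernel_size=16` in the example are placeholder values:

```python
import shlex
from typing import List, Optional


def build_model_args(pretrained: str, sparsity: str, check: bool = False,
                     kernel_size: Optional[int] = None, widthtree: int = 1,
                     patternstrict: bool = True, thr: float = 0.1) -> str:
    """Assemble the --model_args string used by the Miscellaneous commands."""
    flags = {"full": "griffin=False", "coarse": "griffin=True", "fine": "cats=True"}
    if sparsity not in flags:
        raise ValueError(f"sparsity must be one of {sorted(flags)}")
    parts = [f"pretrained={pretrained}", flags[sparsity], f"check={check}"]
    if check:  # Sirius correction on top of a sparse model
        if sparsity == "full":
            raise ValueError("check=True needs a sparse model (coarse or fine)")
        if kernel_size is None:
            raise ValueError("check=True needs kernel_size")
        if sparsity == "fine" and widthtree > 1 and not patternstrict:
            raise ValueError("cats=True with widthtree > 1 requires patternstrict=True")
        parts += [f"kernel_size={kernel_size}", f"widthtree={widthtree}",
                  f"patternstrict={patternstrict}", f"thr={thr}"]
    return ",".join(parts)


def build_command(model_args: str, task: str, main_port: int,
                  num_processes: int = 1, num_machines: int = 1) -> str:
    """Full accelerate launch command, mirroring the README template."""
    argv: List[str] = ["accelerate", "launch", "--main_process_port", str(main_port),
                       "--num_processes", str(num_processes),
                       "--num_machines", str(num_machines),
                       "main.py", "--model", "xhf", "--model_args", model_args,
                       "--tasks", task, "--batch_size", "1"]
    return shlex.join(argv)


# Placeholder values for illustration only.
example = build_model_args("meta-llama/Meta-Llama-3-8B-Instruct", "coarse",
                           check=True, kernel_size=16, widthtree=1)
cmd = build_command(example, "gsm8k", main_port=29500)
print(cmd)
```

Validating in Python rather than by hand keeps the cats/widthtree/patternstrict rule from being silently violated when sweeping many configurations.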

  • For commonsense reasoning tasks, we follow the Chain-of-Thought work (https://arxiv.org/abs/2201.11903) to convert the multiple-choice datasets CSQA, StrategyQA, Date, and Sports into generation tasks. The essential code is in the "CommonSenseReasoning" folder.
cd CommonSenseReasoning
# Sirius with Sparse 
accelerate launch --main_process_port <main_port> --num_processes <num_proc> main.py --tasks <task_name> --model <huggingface_token> --shotfive --cats --check --kernel_size <kernel_size> --spr <sparsity> --thr <threshold> --widthtree <widthtree> --patternstrict

Add --cats for fine-grained sparsity or --griffin for coarse-grained sparsity; add neither for the full model. Add --check to use the full model for correction; without it, no correction is applied. Again, when using --cats with <widthtree> > 1, --patternstrict must be added. --shotfive uses 5 few-shot examples, which is the setting in which the measurements are performed.
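Turning a multiple-choice item into a generation task means rendering the choices into the prompt and then parsing a free-form answer back out. The sketch below assumes the style of the Chain-of-Thought exemplars ("Q: ... Answer Choices: (a) ... A: ... So the answer is (e)."); the repository's exact template and parser may differ, and both function names are hypothetical:

```python
import re
import string
from typing import List, Optional


def format_mc_question(question: str, choices: List[str]) -> str:
    """Render a multiple-choice item as a CoT-style generation prompt (assumed format)."""
    opts = " ".join(f"({letter}) {choice}"
                    for letter, choice in zip(string.ascii_lowercase, choices))
    return f"Q: {question} Answer Choices: {opts}\nA:"


def extract_answer(generation: str) -> Optional[str]:
    """Return the normalized text after the first 'the answer is', or None."""
    m = re.search(r"the answer is\s*([^\n]*)", generation, flags=re.IGNORECASE)
    if m is None:
        return None
    ans = m.group(1).strip().rstrip(".").strip()
    letter = re.fullmatch(r"\(([a-z])\)", ans, flags=re.IGNORECASE)  # "(b)" -> "b"
    return letter.group(1).lower() if letter else ans.lower()
```

Taking the first match, and only the rest of that line, ignores any extra "Q: ..." turns the model may generate after answering.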

  • For code generation tasks (HumanEval and MBPP+), the essential code is in the "CodeGeneration" folder.
cd CodeGeneration
accelerate launch --num_processes <num_proc> main.py \
  --model <huggingface-token-model> \
  --tasks <task_name> \
  --do_sample False \
  --n_samples 1 \
  --batch_size 1 \
  --max_length_generation 512 \
  --enable_epatches \
  --cats \
  --allow_code_execution \
  --spr <sparsity> \
  --widthtree <widthtree> \
  --check \
  --kernelsize <kernel_size> \
  --thr <threshold> \
  --patternstrict

In our code, we use only greedy decoding, so --do_sample is False. As above, add --cats for fine-grained sparsity or --griffin for coarse-grained sparsity; add neither for the full model. Add --check to use the full model for correction; without it, no correction is applied. Again, when using --cats with <widthtree> > 1, --patternstrict must be added. For <task_name>, only humaneval and mbppplus are supported.
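The same rules apply to the CodeGeneration script, which uses plain CLI switches instead of a `--model_args` string. A hypothetical sketch (not part of the repository) that builds the argument list, keeps greedy decoding, restricts tasks to humaneval and mbppplus, and enforces the --patternstrict rule; the model ID and numeric values in the example are placeholders:

```python
from typing import List, Optional

SUPPORTED_TASKS = {"humaneval", "mbppplus"}
SPARSITY_FLAG = {"fine": ["--cats"], "coarse": ["--griffin"], "full": []}


def build_codegen_argv(model: str, task: str, sparsity: str, check: bool = False,
                       spr: Optional[float] = None, widthtree: int = 1,
                       kernelsize: Optional[int] = None, thr: Optional[float] = None,
                       patternstrict: bool = True, num_processes: int = 1) -> List[str]:
    """Assemble the CodeGeneration/main.py launch command as an argv list."""
    if task not in SUPPORTED_TASKS:
        raise ValueError(f"task must be one of {sorted(SUPPORTED_TASKS)}")
    if sparsity not in SPARSITY_FLAG:
        raise ValueError(f"sparsity must be one of {sorted(SPARSITY_FLAG)}")
    argv = ["accelerate", "launch", "--num_processes", str(num_processes), "main.py",
            "--model", model, "--tasks", task,
            "--do_sample", "False",  # greedy decoding only
            "--n_samples", "1", "--batch_size", "1", "--max_length_generation", "512",
            "--enable_epatches", "--allow_code_execution", *SPARSITY_FLAG[sparsity]]
    if spr is not None:
        argv += ["--spr", str(spr)]
    if check:
        if sparsity == "full" or kernelsize is None or thr is None:
            raise ValueError("--check needs a sparse model, kernelsize and thr")
        if sparsity == "fine" and widthtree > 1 and not patternstrict:
            raise ValueError("--cats with widthtree > 1 requires --patternstrict")
        argv += ["--check", "--widthtree", str(widthtree),
                 "--kernelsize", str(kernelsize), "--thr", str(thr)]
        if patternstrict:
            argv.append("--patternstrict")
    return argv


# Placeholder values for illustration only.
argv = build_codegen_argv("meta-llama/Meta-Llama-3-8B-Instruct", "humaneval", "fine",
                          check=True, widthtree=4, kernelsize=16, thr=0.1)
print(" ".join(argv))
```

Returning a list rather than a string lets the caller pass it straight to `subprocess.run(argv)` without shell quoting.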

Speedup On-Chip

This section presents code for measuring the wallclock speedup of Sirius.
cd Miscellaneous

No Tree

python main.py --model xhf --model_args pretrained=<huggingface-token-model>,griffin=True,check=True,kernel_size=<kernel_size>,widthtree=<width_tree>,patternstrict=True,thr=0.05,mode=wallclock_notree --tasks <task_name> --batch_size 1 

With Tree

Currently, only a tree width of four is supported.
python main.py --model xhf --model_args pretrained=<huggingface-token-model>,griffin=True,check=True,kernel_size=<kernel_size>,widthtree=<width_tree>,patternstrict=True,thr=0.05,mode=wallclock_tree --tasks <task_name> --batch_size 1 

Offloading

python main.py --model xhf --model_args pretrained=<huggingface-token-model>,griffin=True,check=True,kernel_size=<kernel_size>,widthtree=<width_tree>,patternstrict=True,thr=0.05,mode=wallclock_70b --tasks <task_name> --batch_size 1 
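The three wallclock commands differ only in the `mode` value. A hypothetical sketch (not part of the repository) that maps each setting to its mode, enforces the tree-width-of-four restriction, and reuses thr=0.05 from the commands; the model ID, task name `gsm8k_cot`, and `kernel_size=16` are placeholder values:

```python
import shlex

# Mode names taken from the three wallclock commands.
WALLCLOCK_MODES = {"notree": "wallclock_notree", "tree": "wallclock_tree",
                   "offload": "wallclock_70b"}


def wallclock_command(pretrained: str, task: str, setting: str,
                      kernel_size: int, widthtree: int, thr: float = 0.05) -> str:
    """Build the Miscellaneous/main.py wallclock command for one setting."""
    if setting not in WALLCLOCK_MODES:
        raise ValueError(f"setting must be one of {sorted(WALLCLOCK_MODES)}")
    if setting == "tree" and widthtree != 4:
        raise ValueError("the tree mode currently supports only widthtree=4")
    model_args = ",".join([
        f"pretrained={pretrained}", "griffin=True", "check=True",
        f"kernel_size={kernel_size}", f"widthtree={widthtree}",
        "patternstrict=True", f"thr={thr}", f"mode={WALLCLOCK_MODES[setting]}",
    ])
    return shlex.join(["python", "main.py", "--model", "xhf", "--model_args", model_args,
                       "--tasks", task, "--batch_size", "1"])


# Placeholder values for illustration only.
cmd = wallclock_command("meta-llama/Meta-Llama-3-8B-Instruct", "gsm8k_cot", "tree",
                        kernel_size=16, widthtree=4)
print(cmd)
```

The offload setting corresponds to the Llama-3-70B-Instruct measurement, so it would be paired with that model ID.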
