
Fast Best-of-N Decoding via Speculative Rejection

fast inference-time alignment

Hanshi Sun1*, Momin Haider2*, Ruiqi Zhang3*, Huitao Yang5, Jiahao Qiu4,
Ming Yin4, Mengdi Wang4, Peter Bartlett3, Andrea Zanette1*
1Carnegie Mellon University 2University of Virginia 3UC Berkeley
4Princeton University 5Fudan University
[Paper] | [Blog]

Environment Setup

# create env
conda create -n SpecRej python=3.10 -y
conda activate SpecRej

# install packages
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/

Efficiency Evaluation

First, we need to run the Best-of-N baselines and Speculative Rejection. The following commands show examples of running Best-of-120, Best-of-960, and Speculative Rejection (alpha=0.5) with Meta-Llama-3-8B as the generator and ArmoRM-Llama3-8B-v0.1 as the reward model. For larger N (e.g., Best-of-3840), we can launch multiple runs on 8 H100 GPUs with different seeds and merge their results using postprocess/concat_json.py (a merging sketch follows the Best-of-960 example below).

# Best-of-120
accelerate launch --num_processes 1 --num_machines 1 --gpu_ids 1 --machine_rank 0 --mixed_precision no --dynamo_backend no \
main.py --output_folder ./archive/Bo120_Meta-Llama-3-8B_ArmoRM-Llama3-8B-v0.1_0 \
--llm_name Meta-Llama-3-8B --reward_model_name ArmoRM-Llama3-8B-v0.1 \
--max_tokens 8000 --batch_size 120 --seed 0 

# ... (Best-of-240, Best-of-480)

# Best-of-960
accelerate launch --multi_gpu --num_processes 8 --num_machines 1 --gpu_ids 0,1,2,3,4,5,6,7 --machine_rank 0 --mixed_precision no \
--dynamo_backend no main.py --output_folder ./archive/Bo960_Meta-Llama-3-8B_ArmoRM-Llama3-8B-v0.1_0 \
--llm_name Meta-Llama-3-8B --reward_model_name ArmoRM-Llama3-8B-v0.1 \
--max_tokens 8000 --batch_size 120 --seed 0 
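
For larger N such as Best-of-3840, the Best-of-960 command above can be launched several times with different --seed values (and correspondingly numbered output folders), and the per-run outputs concatenated. The minimal sketch below illustrates that concatenation step only; the folder names and the assumption that each output file holds a list of records are illustrative, and the actual interface of postprocess/concat_json.py may differ.

import json
from pathlib import Path

# Hypothetical layout: four Best-of-960 runs launched with --seed 0..3,
# each writing its outputs under ./archive/Bo960_..._<seed>
run_folders = [
    Path(f"./archive/Bo960_Meta-Llama-3-8B_ArmoRM-Llama3-8B-v0.1_{seed}")
    for seed in range(4)
]

merged = []
for folder in run_folders:
    for json_file in sorted(folder.glob("*.json")):
        with open(json_file) as f:
            # assumes each output file holds a list of generation records
            merged.extend(json.load(f))

with open("./archive/Bo3840_Meta-Llama-3-8B_ArmoRM-Llama3-8B-v0.1_merged.json", "w") as f:
    json.dump(merged, f, indent=2)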

# Speculative Rejection (alpha=0.5)
accelerate launch --num_processes 1 --num_machines 1 --gpu_ids 0 --machine_rank 0 --mixed_precision no --dynamo_backend no \
main.py --output_folder ./archive/SpR_alpha_0.5_Meta-Llama-3-8B_ArmoRM-Llama3-8B-v0.1_0 \
--llm_name Meta-Llama-3-8B --reward_model_name ArmoRM-Llama3-8B-v0.1 \
--max_tokens 8000 --seed 0 \
--speculative_rejection --alpha 0.5
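
Here --alpha controls the rejection rate: at each rejection round the reward model scores the partial generations and the lower-scoring fraction is dropped, so the remaining compute is spent on the most promising continuations. The following schematic sketch illustrates that idea only; generate_chunk and partial_reward are hypothetical stand-ins for the LLM and the reward model, not the repo's actual API.

def speculative_rejection(prompt, generate_chunk, partial_reward,
                          n_initial=120, alpha=0.5, max_new_tokens=8000,
                          chunk_size=512):
    """Schematic sketch of the speculative-rejection loop (illustrative only).

    generate_chunk(text, n_tokens) -> text extended by up to n_tokens
    partial_reward(text) -> reward-model score of a (possibly partial) response
    """
    candidates = [prompt] * n_initial   # start from a large batch of partial responses
    generated = 0
    while generated < max_new_tokens and len(candidates) > 1:
        # extend every surviving candidate by one chunk of tokens
        candidates = [generate_chunk(c, chunk_size) for c in candidates]
        generated += chunk_size
        # score the partial responses and reject the bottom alpha fraction
        scores = [partial_reward(c) for c in candidates]
        keep = max(1, int(len(candidates) * (1 - alpha)))
        ranked = sorted(zip(scores, candidates), key=lambda p: p[0], reverse=True)
        candidates = [c for _, c in ranked[:keep]]
    # return the highest-scoring remaining response
    return max(candidates, key=partial_reward)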

After gathering the results under the archive folder, we can evaluate the efficiency of the Best-of-N baselines and Speculative Rejection using the following command.

# make sure the args in the script are correct first
python postprocess/plot_compare.py

Win-rate Evaluation

Once we have all the outputs from the Best-of-N baselines and Speculative Rejection, we can evaluate the win-rate using alpaca_eval.

First, we need to gather the best utterances from the outputs of the Best-of-N baselines and Speculative Rejection, then merge them for win-rate evaluation.

# gather best answers
python postprocess/gather_best_ans.py

# merge json files for win-rate evaluation
python postprocess/merge_json.py
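
For reference, gathering the best utterances amounts to keeping, for each prompt, the single response with the highest reward-model score from each method's outputs. The minimal sketch below shows that selection step; the field names ('prompt', 'response', 'score') and file paths are assumptions, and postprocess/gather_best_ans.py may organize its inputs and outputs differently.

import json

def gather_best(records):
    """Keep the highest-scoring response per prompt.
    Assumes each record carries 'prompt', 'response', and 'score' fields."""
    best = {}
    for rec in records:
        prompt = rec["prompt"]
        if prompt not in best or rec["score"] > best[prompt]["score"]:
            best[prompt] = rec
    return list(best.values())

# Illustrative usage on one run's outputs (hypothetical file path)
with open("./archive/Bo120_Meta-Llama-3-8B_ArmoRM-Llama3-8B-v0.1_0/outputs.json") as f:
    records = json.load(f)
with open("win_rate/Bo120_best_answers.json", "w") as f:
    json.dump(gather_best(records), f, indent=2)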

Then, we can evaluate the win-rate using the following command.

export OPENAI_API_KEY=YOUR_API_KEY

alpaca_eval make_leaderboard --leaderboard_path leader_board.csv \
    --all_model_outputs win_rate/Meta-Llama-3-8B_ArmoRM-Llama3-8B-v0.1_compare.json \
    --reference_outputs win_rate/Meta-Llama-3-8B_ArmoRM-Llama3-8B-v0.1_ref.json \
    --output_path leader_board \
    --fn_metric 'get_length_controlled_winrate' \
    --sort_by 'length_controlled_winrate' \
    --is_overwrite_leaderboard

Citation

If you find Speculative Rejection useful or relevant to your project and research, please cite our paper:

@article{sun2024fast,
  title={Fast Best-of-N Decoding via Speculative Rejection},
  author={Sun, Hanshi and Haider, Momin and Zhang, Ruiqi and Yang, Huitao and Qiu, Jiahao and Yin, Ming and Wang, Mengdi and Bartlett, Peter and Zanette, Andrea},
  journal={arXiv preprint arXiv:2410.20290},
  year={2024}
}