Ming Yin4, Mengdi Wang4, Peter Bartlett3, Andrea Zanette1*
4Princeton University 5Fudan University
# create env
conda create -n SpecRej python=3.10 -y
conda activate SpecRej
# install packages
pip install -r requirements.txt
pip install flash-attn --no-build-isolation
pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.3/
First, run the Best-of-N baselines and Speculative Rejection. The commands below are examples of running Best-of-120, Best-of-960, and Speculative Rejection (alpha=0.5) with Meta-Llama-3-8B as the language model and ArmoRM-Llama3-8B-v0.1 as the reward model. For larger N (e.g., Best-of-3840), change the seed across multiple runs on 8 H100 GPUs and merge the results with postprocess/concat_json.py.
# Best-of-120
accelerate launch --num_processes 1 --num_machines 1 --gpu_ids 1 --machine_rank 0 --mixed_precision no --dynamo_backend no \
main.py --output_folder ./archive/Bo120_Meta-Llama-3-8B_ArmoRM-Llama3-8B-v0.1_0 \
--llm_name Meta-Llama-3-8B --reward_model_name ArmoRM-Llama3-8B-v0.1 \
--max_tokens 8000 --batch_size 120 --seed 0
# ... (Best-of-240, Best-of-480)
# Best-of-960
accelerate launch --multi_gpu --num_processes 8 --num_machines 1 --gpu_ids 0,1,2,3,4,5,6,7 --machine_rank 0 --mixed_precision no \
--dynamo_backend no main.py --output_folder ./archive/Bo960_Meta-Llama-3-8B_ArmoRM-Llama3-8B-v0.1_0 \
--llm_name Meta-Llama-3-8B --reward_model_name ArmoRM-Llama3-8B-v0.1 \
--max_tokens 8000 --batch_size 120 --seed 0
# Speculative Rejection (alpha=0.5)
accelerate launch --num_processes 1 --num_machines 1 --gpu_ids 0 --machine_rank 0 --mixed_precision no --dynamo_backend no \
main.py --output_folder ./archive/SpR_alpha_0.5_Meta-Llama-3-8B_ArmoRM-Llama3-8B-v0.1_0 \
--llm_name Meta-Llama-3-8B --reward_model_name ArmoRM-Llama3-8B-v0.1 \
--max_tokens 8000 --seed 0 \
--speculative_rejection --alpha 0.5
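For larger N, the multi-seed runs could be scripted as a loop. The sketch below is only an illustration: it reuses the Best-of-960 flags shown above (8 processes with a batch size of 120 each, so four seeds give 4 x 960 = 3840 samples), and it assumes that the trailing number in the output folder name is the seed. The launch_seeds function and the RUN variable are hypothetical names, not part of the repository. By default the commands are printed rather than executed, and the merge step with postprocess/concat_json.py is not shown.

```shell
# Hypothetical helper: one Best-of-960 run per seed passed as an argument.
# Assumption: the _<seed> suffix of the output folder mirrors the examples above.
LLM=Meta-Llama-3-8B
RM=ArmoRM-Llama3-8B-v0.1
# RUN defaults to echo (dry run); export RUN= (empty) to actually launch the jobs.
RUN=${RUN-echo}

launch_seeds() {
  for seed in "$@"; do
    $RUN accelerate launch --multi_gpu --num_processes 8 --num_machines 1 \
      --gpu_ids 0,1,2,3,4,5,6,7 --machine_rank 0 --mixed_precision no \
      --dynamo_backend no main.py \
      --output_folder "./archive/Bo960_${LLM}_${RM}_${seed}" \
      --llm_name "$LLM" --reward_model_name "$RM" \
      --max_tokens 8000 --batch_size 120 --seed "$seed"
  done
}

# Four seeds x 960 samples = 3840 samples per prompt.
launch_seeds 0 1 2 3
```

Reviewing the printed commands before setting RUN to empty is a cheap way to catch a wrong folder name or GPU list before spending GPU hours.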
After gathering the results under the archive folder, we can evaluate the efficiency of the Best-of-N baselines and Speculative Rejection using the following command.
# make sure the args in the script are correct first
python postprocess/plot_compare.py
Once we have all the outputs from the Best-of-N baselines and Speculative Rejection, we can evaluate the win-rate using alpaca_eval.
First, we need to gather the best utterances from the outputs of the Best-of-N baselines and Speculative Rejection and merge the outputs for win-rate evaluation.
# gather best answers
python postprocess/gather_best_ans.py
# merge json files for win-rate evaluation
python postprocess/merge_json.py
Then, we can evaluate the win-rate using the following command.
export OPENAI_API_KEY=YOUR_API_KEY
alpaca_eval make_leaderboard --leaderboard_path leader_board.csv \
--all_model_outputs win_rate/Meta-Llama-3-8B_ArmoRM-Llama3-8B-v0.1_compare.json \
--reference_outputs win_rate/Meta-Llama-3-8B_ArmoRM-Llama3-8B-v0.1_ref.json \
--output_path leader_board --fn_metric 'get_length_controlled_winrate' \
--sort_by 'length_controlled_winrate' --is_overwrite_leaderboard
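Because the win-rate evaluation calls a paid API, it can be worth checking the inputs before launching it. The following is a minimal sketch of such a check, not something shipped with the repository: check_winrate_inputs is a hypothetical helper, and the two file paths are simply the ones used in the command above.

```shell
# Hypothetical pre-flight helper (not part of the repository).
# Returns non-zero if OPENAI_API_KEY is empty or any listed file is missing or empty.
check_winrate_inputs() {
  local status=0 f
  if [ -z "${OPENAI_API_KEY:-}" ]; then
    echo "OPENAI_API_KEY is not set" >&2
    status=1
  fi
  for f in "$@"; do
    if [ ! -s "$f" ]; then
      echo "missing or empty: $f" >&2
      status=1
    fi
  done
  return "$status"
}

check_winrate_inputs \
  win_rate/Meta-Llama-3-8B_ArmoRM-Llama3-8B-v0.1_compare.json \
  win_rate/Meta-Llama-3-8B_ArmoRM-Llama3-8B-v0.1_ref.json \
  || echo "fix the issues above before running alpaca_eval"
```

The check only confirms that the files exist and are non-empty; it does not validate their JSON content.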
If you find Speculative Rejection useful or relevant to your project and research, please kindly cite our paper:
@article{sun2024fast,
title={Fast Best-of-N Decoding via Speculative Rejection},
author={Sun, Hanshi and Haider, Momin and Zhang, Ruiqi and Yang, Huitao and Qiu, Jiahao and Yin, Ming and Wang, Mengdi and Bartlett, Peter and Zanette, Andrea},
journal={arXiv preprint arXiv:2410.20290},
year={2024}
}