<!DOCTYPE html>
<html>
<head>
<meta charset="utf-8">
<!-- Meta tags for social media banners; these should be filled in appropriately as they are your "business card" -->
<!-- Replace the content tag with appropriate information -->
<meta name="description" content="Fast Best-of-N Decoding via Speculative Rejection">
<meta property="og:title" content="Speculative Rejection"/>
<meta property="og:description" content="Fast Best-of-N Decoding via Speculative Rejection"/>
<!-- Path to banner image, should be in the path listed below. Optimal dimensions are 1200x630 -->
<meta property="og:image" content="static/images/proj_fig.png" />
<meta property="og:image:width" content="1200"/>
<meta property="og:image:height" content="630"/>
<meta name="twitter:title" content="Speculative Rejection">
<meta name="twitter:description" content="Fast Best-of-N Decoding via Speculative Rejection">
<!-- Path to banner image, should be in the path listed below. Optimal dimensions are 1200x600 -->
<meta name="twitter:image" content="static/images/proj_fig.png">
<meta name="twitter:card" content="summary_large_image">
<!-- Keywords for your paper to be indexed by-->
<meta name="keywords" content="KV Cache">
<meta name="viewport" content="width=device-width, initial-scale=1">
<title>Fast Best-of-N Decoding via Speculative Rejection</title>
<link rel="icon" type="image/x-icon" href="static/images/Speculative Rejection.png">
<link href="https://fonts.googleapis.com/css?family=Google+Sans|Noto+Sans|Castoro"
rel="stylesheet">
<link rel="stylesheet" href="static/css/bulma.min.css">
<link rel="stylesheet" href="static/css/bulma-carousel.min.css">
<link rel="stylesheet" href="static/css/bulma-slider.min.css">
<link rel="stylesheet" href="static/css/fontawesome.all.min.css">
<link rel="stylesheet"
href="https://cdn.jsdelivr.net/gh/jpswalsh/academicons@1/css/academicons.min.css">
<link rel="stylesheet" href="static/css/index.css">
<script src="https://ajax.googleapis.com/ajax/libs/jquery/3.5.1/jquery.min.js"></script>
<script src="https://documentcloud.adobe.com/view-sdk/main.js"></script>
<script defer src="static/js/fontawesome.all.min.js"></script>
<script src="static/js/bulma-carousel.min.js"></script>
<script src="static/js/bulma-slider.min.js"></script>
<script src="static/js/index.js"></script>
<script type="text/x-mathjax-config">
MathJax.Hub.Config({tex2jax: {inlineMath: [['$','$'], ['\\(','\\)']]}});
</script>
<script type="text/javascript"
src="http://cdnjs.cloudflare.com/ajax/libs/mathjax/2.7.1/MathJax.js?config=TeX-AMS-MML_HTMLorMML">
</script>
</head>
<body>
<section class="hero">
<div class="hero-body">
<div class="container is-max-desktop">
<div class="columns is-centered">
<div class="column has-text-centered">
<h1 class="title is-2 publication-title" style="display: inline;"> Fast Best-of-N Decoding via Speculative Rejection</h1>
<br><br>
<div class="is-size-5 publication-authors">
<!-- Paper authors -->
<span class="author-block">
<a href="https://github.com/preminstrel">Hanshi Sun</a></b><sup>1*</sup>,</span>
<span class="author-block">
Momin Haider<sup>2*</sup>,
</span>
<span class="author-block">
<a href="https://rqzhangberkeley.github.io/">Ruiqi Zhang</a></b><sup>3*</sup>,</span>
<span class="author-block">
Huitao Yang<sup>5</sup>,
</span>
<span class="author-block">
<a href="https://ece.princeton.edu/people/jiahao-qiu">Jiahao Qiu</a></b><sup>4</sup>,
</span> <br>
<span class="author-block">
<a href="https://mingyin0312.github.io/">Ming Yin</a></b><sup>4</sup>,
</span>
<span class="author-block">
<a href="https://mwang.princeton.edu/">Mengdi Wang</a></b><sup>4</sup>,
</span>
<span class="author-block">
<a href="https://people.eecs.berkeley.edu/~bartlett/">Peter Bartlett</a></b><sup>3</sup>,
</span>
<span class="author-block">
<a href="https://azanette.com/">Andrea Zanette</a></b><sup>1*</sup>
</span>
</div>
<div class="is-size-5 publication-authors">
<span class="affliation"><small><sup>1</sup>Carnegie Mellon University
<sup>2</sup>University of Virginia
<sup>3</sup>UC Berkeley<br>
<sup>4</sup>Princeton University
<sup>5</sup>Fudan University</small></span>
<span class="eql-cntrb"><small><br><sup>*</sup>Core Contributors</small></span>
</div>
<div class="column has-text-centered">
<!-- ArXiv abstract Link -->
<span class="link-block">
<a href="https://arxiv.org/abs/2410.20290" target="_blank"
class="external-link button is-normal is-rounded is-dark">
<span class="icon">
<i class="ai ai-arxiv"></i>
</span>
<span>arXiv</span>
</a>
</span>
<!-- Github link -->
<span class="link-block">
<a href="https://github.com/Zanette-Labs/SpeculativeRejection" target="_blank"
class="external-link button is-normal is-rounded is-dark">
<span class="icon">
<i class="fab fa-github"></i>
</span>
<span>Code</span>
</a>
</span>
</div>
</div>
</div>
</div>
</div>
</div>
</section>
<!-- Paper abstract -->
<section class="section hero is-light">
<div class="container is-max-desktop">
<div class="columns is-centered has-text-centered">
<div class="column is-four-fifths">
<h2 class="title is-3" style="text-align: center;"><img src="static/images/Align.png" style="height: 43px; display: inline; vertical-align:text-top;"/> Introduction</h2>
<div class="content has-text-justified">
<p>
The safe and effective deployment of LLMs involves a critical step called alignment, which ensures that the model's responses are in accordance with human preferences. Techniques like <a href="https://proceedings.neurips.cc/paper_files/paper/2023/hash/a85b405ed65c6477a4fe8302b5e06ce7-Abstract-Conference.html" style="color: #209CEE;">DPO</a>, <a href="https://arxiv.org/abs/1707.06347" style="color: #209CEE;">PPO</a>, and their variants align LLMs by changing the pre-trained model weights during a phase called post-training. While predominant, these post-training methods <b>add substantial complexity</b> before LLMs can be deployed. Inference-time alignment methods avoid the complex post-training step and instead bias the generation towards responses that are aligned with human preferences. The best-known inference-time alignment method, called <a href="https://proceedings.neurips.cc/paper/2020/hash/1f89885d556929e98d3ef9b86448f951-Abstract.html" style="color: #209CEE;">Best-of-N</a>, is as effective as the state-of-the-art post-training procedures. Unfortunately, Best-of-N <b>requires vastly more resources at inference time</b> than standard decoding strategies, which makes it not computationally viable. We introduce <b>Speculative Rejection, a computationally viable inference-time alignment algorithm</b>. It generates high-scoring responses according to a given reward model, like Best-of-N does, while being <b>between 16 and 32 times more computationally efficient</b>.
</p>
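<p>
For concreteness, Best-of-N simply samples N complete responses and returns the one the reward model scores highest. Below is a minimal sketch of this baseline, assuming hypothetical <code>generate</code> and <code>reward</code> helpers; it is illustrative rather than the interface of our released code.
</p>
<pre><code>def best_of_n(prompt, n, generate, reward):
    # Sample n complete responses independently from the language model.
    candidates = [generate(prompt) for _ in range(n)]
    # Score every finished response with the reward model.
    scores = [reward(prompt, c) for c in candidates]
    # Return the highest-scoring response.
    best = max(range(n), key=lambda i: scores[i])
    return candidates[best]
</code></pre>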
<div class="figure">
<img src="static/images/perf_rm.png" alt="Retrieval-based Drafting" height="400" />
</div>
<br>
<p>We evaluate the effectiveness of Speculative Rejection on the <a href="https://github.com/tatsu-lab/alpaca_eval" style="color: #209CEE;">AlpacaFarm-Eval</a> dataset using various generative models and reward models. The numbers indicate N for Best-of-N and the rejection rate α for Speculative Rejection. <b>Our method consistently achieves higher reward scores with fewer computational resources than Best-of-N</b>.</p>
</div>
</div>
</div>
</div>
</section>
<!-- Speculative Rejection -->
<section class="section hero is-light">
<div class="container is-max-desktop">
<div class="columns is-centered">
<div class="column is-four-fifths">
<h2 class="title is-3" style="text-align: center;"><img src="static/images/rej.png" style="height: 50px; display: inline; vertical-align: middle;"/> Speculative Rejection</h2>
<div class="content has-text-justified">
<p>
Speculative Rejection is based on the observation that the <b>reward function used for scoring the utterances can distinguish high-quality responses from low-quality ones at an early stage of the generation</b>. In other words, we observe that <b>the scores of partial utterances are positively correlated with the scores of full utterances</b>. As illustrated in the figure, this insight enables us to identify, during generation, utterances that are unlikely to achieve high scores upon completion, allowing us to <b>halt their generation early</b>.
</p>
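<p>
One simple way to quantify this observation, assuming per-candidate partial and final reward scores are available, is to check how well the ranking induced by the partial scores predicts the ranking induced by the final scores. The sketch below computes a Spearman-style rank correlation and is illustrative rather than the exact statistic reported in the paper.
</p>
<pre><code>import numpy as np

def rank_correlation(partial_scores, final_scores):
    # Convert both score lists to ranks (argsort of argsort), then take the
    # Pearson correlation of the ranks, i.e. a Spearman correlation (no ties).
    partial_ranks = np.argsort(np.argsort(partial_scores)).astype(float)
    final_ranks = np.argsort(np.argsort(final_scores)).astype(float)
    return float(np.corrcoef(partial_ranks, final_ranks)[0, 1])
</code></pre>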
<div class="figure">
<img src="static/images/spr.png" alt="Speculative Rejection System" height="400" />
</div>
<br>
<p>
Speculative Rejection begins with <b>a very large batch size, effectively simulating the initial phases of Best-of-N with a large N (e.g., 5000) on a single accelerator</b>. This increases the likelihood that the initial batch will contain several generations that lead to high-quality responses once fully generated. However, such a large batch size would eventually exhaust the GPU memory during the later stages of auto-regressive generation. To address this, Speculative Rejection queries the reward model <b>multiple times throughout the generation process</b>, attempting to infer which responses are <b>unlikely to score high upon completion</b>. Our method dynamically <b>reduces the batch size</b>, <b>preventing memory exhaustion</b> while <b>ensuring that only the most promising responses are fully generated</b>.
</p>
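<p>
The decoding loop can be summarized as follows. This is a simplified sketch under assumed <code>generate_step</code> and <code>reward</code> helpers; the exact batch sizes, rejection schedule, and stopping criteria differ in the actual implementation.
</p>
<pre><code>def speculative_rejection(prompt, generate_step, reward,
                          n_init=5000, alpha=0.5,
                          max_new_tokens=2048, check_every=256):
    # Start with a very large batch, as in the early phase of Best-of-N.
    batch = [prompt] * n_init
    for _ in range(0, max_new_tokens, check_every):
        # Extend every surviving candidate by check_every tokens.
        batch = [generate_step(c, check_every) for c in batch]
        if len(batch) == 1:
            continue
        # Score the partial responses and reject the bottom alpha fraction,
        # freeing memory for the remaining, more promising candidates.
        scores = [reward(prompt, c) for c in batch]
        cutoff = sorted(scores)[int(alpha * len(scores))]
        batch = [c for c, s in zip(batch, scores) if s >= cutoff]
    # Among the fully generated survivors, return the highest-scoring one.
    final_scores = [reward(prompt, c) for c in batch]
    best = max(range(len(batch)), key=lambda i: final_scores[i])
    return batch[best]
</code></pre>
<p>
The batch is at its largest while the sequences are still short; rejecting candidates before the later, memory-hungry stages of auto-regressive generation is what keeps the large initial N feasible on a single accelerator.
</p>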
</div>
</div>
</div>
</div>
</section>
<section class="section hero is-light">
<div class="container is-max-desktop">
<div class="columns is-centered has-text-centered">
<div class="column is-four-fifths">
<h2 class="title is-3" style="text-align: center;"><img src="static/images/gpt.png" style="height: 50px; display: inline; vertical-align: middle;"/> Win-rate Evaluation by GPT-4-Turbo</h2>
<div class="content has-text-justified">
<p>
To further validate the generation quality, we evaluate both the <b>win-rate</b> and the <b>length-controlled (LC) win-rate</b> using GPT-4-Turbo with <a href="https://github.com/tatsu-lab/alpaca_eval" style="color: #209CEE;">AlpacaEval</a>. For each measurement, the win-rate baseline is Bo120. As shown in the table, <b>Speculative Rejection maintains generation quality while achieving a notable speedup across various settings</b> for the <a href="https://huggingface.co/mistralai/Mistral-7B-v0.3/tree/main" style="color: #209CEE;">Mistral-7B</a>, <a href="https://huggingface.co/meta-llama/Meta-Llama-3-8B" style="color: #209CEE;">Llama-3-8B</a>, and <a href="https://huggingface.co/meta-llama/Meta-Llama-3-8B-Instruct" style="color: #209CEE;">Llama-3-8B-Instruct</a> models, scored by the reward model <a href="https://huggingface.co/RLHFlow/ArmoRM-Llama3-8B-v0.1" style="color: #209CEE;">ArmoRM-Llama-3-8B</a> and evaluated using GPT-4-Turbo. "WR" refers to win-rate, and "LC-WR" refers to length-controlled win-rate.
</p>
<br>
<table style="width:100%; border-collapse:collapse; text-align:center;">
<thead>
<tr>
<th rowspan="2" >Methods</th>
<th colspan="2">Mistral-7B</th>
<th colspan="2">Llama-3-8B</th>
<th colspan="2">Llama-3-8B-Instruct</th>
<th colspan="2">Average</th>
</tr>
<tr>
<th>WR</th>
<th>LC-WR</th>
<th>WR</th>
<th>LC-WR</th>
<th>WR</th>
<th>LC-WR</th>
<th>WR</th>
<th>LC-WR</th>
</tr>
</thead>
<tbody>
<tr>
<td>Bo120</td>
<td>50.00</td><td>50.00</td>
<td>50.00</td><td>50.00</td>
<td>50.00</td><td>50.00</td>
<td>50.00</td><td>50.00</td>
</tr>
<tr>
<td>Bo240</td>
<td>60.69</td><td>60.07</td>
<td>50.45</td><td>50.27</td>
<td>49.92</td><td>52.89</td>
<td>53.69</td><td>54.41</td>
</tr>
<tr>
<td>Bo480</td>
<td>61.28</td><td>61.84</td>
<td>58.90</td><td>59.93</td>
<td>50.49</td><td>53.11</td>
<td>56.89</td><td>58.29</td>
</tr>
<tr>
<td>Bo960</td>
<td>67.50</td><td>68.07</td>
<td>59.20</td><td>60.26</td>
<td>50.39</td><td>51.64</td>
<td>59.03</td><td>59.99</td>
</tr>
<tr>
<td>Bo1920</td>
<td>75.20</td><td>76.27</td>
<td>60.57</td><td>61.05</td>
<td>51.86</td><td>53.13</td>
<td>62.54</td><td>63.48</td>
</tr>
<tr>
<td>Bo3840</td>
<td><strong>76.13</strong></td><td><strong>77.21</strong></td>
<td>59.19</td><td>57.91</td>
<td>53.36</td><td>54.01</td>
<td>62.89</td><td>63.04</td>
</tr>
<tr style="background-color:#e0f7fa;">
<td>Ours (α=0.5)</td>
<td>69.42</td><td>73.31</td>
<td><strong>73.60</strong></td><td><strong>77.91</strong></td>
<td><strong>55.50</strong></td><td><strong>58.80</strong></td>
<td><strong>66.17</strong></td><td><strong>70.01</strong></td>
</tr>
</tbody>
</table>
</div>
</div>
</div>
</div>
</section>
<section class="section hero is-light">
<div class="container is-max-desktop">
<div class="columns is-centered has-text-centered">
<div class="column is-four-fifths">
<h2 class="title is-3"><img src="static/images/Telescope.png" style="height: 50px; display: inline; vertical-align: middle;"/> Conclusion and Future Work</h2>
<div class="content has-text-justified">
<p>
Speculative Rejection is a general-purpose technique to accelerate reward-oriented decoding from LLMs. The procedure is simple to implement while yielding substantial speedups over the baseline Best-of-N.
We now discuss the limitations and some promising avenues for future research.
<br>
<br>
<b>Prompt-dependent Stopping.</b>
Our implementation of Speculative Rejection leverages statistical correlations to early-stop trajectories that are deemed unpromising. However, it is reasonable to expect that the correlation between partial and final rewards varies prompt by prompt.
For a target level of normalized score, early stopping can be more aggressive for some prompts and less aggressive for others.
This consideration suggests that setting the rejection rate <b>adaptively</b> can potentially achieve higher speedup and normalized score across prompts.
We leave this opportunity for future research.
<br>
<br>
<b>Reward Models as Value Functions.</b>
Our method leverages the statistical correlation between the reward values at the decision tokens and upon termination. Concurrently, recent literature also suggests training reward models as value functions.
Doing so would enable reward models to predict the <b>expected</b> score upon completion at any point during the generation, making them much more accurate models for our purposes. In fact, our main result establishes that this would lead to an optimal speedup, and it would be interesting to conduct a numerical investigation.
</p>
</div>
</div>
</div>
</div>
</section>
<!--BibTex citation -->
<section class="section" id="BibTeX">
<div class="container is-max-desktop content">
<h2 class="title">BibTeX</h2>
<pre><code>@article{sun2024fast,
title={Fast Best-of-N Decoding via Speculative Rejection},
author={Sun, Hanshi and Haider, Momin and Zhang, Ruiqi and Yang, Huitao and Qiu, Jiahao and Yin, Ming and Wang, Mengdi and Bartlett, Peter and Zanette, Andrea},
journal={arXiv preprint arXiv:2410.20290},
year={2024}
}</code></pre>
</div>
</section>
<!--End BibTex citation -->
<footer class="footer">
<div class="container">
<div class="columns is-centered">
<div class="column is-8">
<div class="content">
<p>
This page was built using the <a href="https://github.com/eliahuhorwitz/Academic-project-page-template" target="_blank">Academic Project Page Template</a>, which was adapted from the <a href="https://nerfies.github.io" target="_blank">Nerfies</a> project page.
You are free to borrow the source code of this website; we just ask that you link back to this page in the footer. <br> This website is licensed under a <a rel="license" href="http://creativecommons.org/licenses/by-sa/4.0/" target="_blank">Creative
Commons Attribution-ShareAlike 4.0 International License</a>. The icons were created by GPT-4.
</p>
</div>
</div>
</div>
</div>
</footer>
</body>
</html>