This repository contains an implementation of n-gram decoding (aka prompt lookup decoding) for faster LLM inference.

This exploration aims to understand the use of n-grams for lossless acceleration of LLM inference, as proposed in [Prompt Lookup Decoding](https://github.com/apoorvumang/prompt-lookup-decoding/) and [Inference with Reference: Lossless Acceleration of Large Language Models (LLMA)](https://arxiv.org/abs/2304.04487).
Combining the core ideas of both works, I explored the following algorithm:
- Match the n-grams in the prompt against the tokens in the input sequence and obtain `K` candidate tokens.
- If multiple candidates are found, select the set with the most candidate tokens; in case of a tie, one is selected at random.
- If no candidate tokens are identified, default to single-step greedy decoding.
- Repeat the above steps until either the maximum number of tokens `n` is reached or the `EOS` token (e.g., `<|eot_id|>`) is generated.

> **Note**
> The number of tokens generated per step in n-gram decoding ranges from `1` to `K + 1`.
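For concreteness, here is a minimal sketch of a single verification step that keeps the decoding lossless. It assumes a Hugging Face-style causal LM whose forward pass returns `logits`; the helper name `verify_candidate` is hypothetical and not part of this repository.

```python
import torch

@torch.no_grad()
def verify_candidate(model, input_ids: torch.Tensor, candidate: torch.Tensor) -> torch.Tensor:
    """Accept the longest prefix of `candidate` that matches greedy decoding."""
    # One forward pass over the current sequence plus the drafted tokens.
    seq = torch.cat([input_ids, candidate], dim=1)
    logits = model(seq).logits  # shape: (1, seq_len, vocab_size)
    # preds[i] is the model's greedy next token after the first L + i tokens,
    # where L = input_ids.size(1); there are K + 1 such predictions.
    preds = logits[0, input_ids.size(1) - 1 :].argmax(dim=-1)
    n_accepted = 0
    for i in range(candidate.size(1)):
        if preds[i].item() != candidate[0, i].item():
            break
        n_accepted += 1
    # Accepted draft tokens plus one token the model produced itself:
    # between 1 and K + 1 new tokens per step, as noted above. With an
    # empty candidate this reduces to a single greedy decoding step.
    new_tokens = preds[: n_accepted + 1].unsqueeze(0)
    return torch.cat([input_ids, new_tokens], dim=1)
```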
This project uses uv for dependency management. To install uv, run one of the following commands:
```bash
# On macOS and Linux.
curl -LsSf https://astral.sh/uv/install.sh | sh

# On Windows.
powershell -c "irm https://astral.sh/uv/install.ps1 | iex"

# With pip.
pip install uv

# With pipx.
pipx install uv

# With Homebrew.
brew install uv

# With Pacman.
pacman -S uv
```
Thereafter, install the rest of the dependencies using uv:
```bash
# create a virtual env
uv venv

# install dependencies
uv pip install -r requirements.txt
```
> **Note**
> Currently, the script only supports the `Meta-Llama-3.1-8B-Instruct` model.
```bash
# check cli options
python main.py --help

usage: main.py [-h] [--model MODEL] --decoding-method {greedy,ngram}

optional arguments:
  -h, --help            show this help message and exit
  --model MODEL
  --decoding-method {greedy,ngram}
```
Run the LLM inference comparison script:
```bash
# ngram decoding
python main.py --model meta-llama/Meta-Llama-3.1-8B-Instruct \
    --decoding-method ngram

# greedy decoding
python main.py --model meta-llama/Meta-Llama-3.1-8B-Instruct \
    --decoding-method greedy
```
The following results were obtained on an A100 GPU with 40 GB of RAM, using the following settings:

- `ngrams_size = 3`
- `K = 10`
- `n = 400`
(Demo video: `ngram-vs-greedy.mp4`)
Using the following example prompt:
<|start_header_id|>user<|end_header_id|>
Code:
```python
def generate_candidate_tokens(
input_ids: torch.Tensor, n_grams: torch.Tensor, ngrams_size: int, K: int
):
# unfold the tensor into windows of `pattern_len + following_elements_count`
window = input_ids.unfold(dimension=1, size=ngrams_size, step=1)
# compare each window with the pattern (only the parts corresponding to the pattern)
matching_window_indices = (window == n_grams).all(dim=2)
# extract the indices where there are matches
matching_indices = matching_window_indices.nonzero(as_tuple=True)[1]
# find candidates with the longest length
# based on: https://arxiv.org/pdf/2304.04487
# we choose the candidate with the longest length at random if there are multiple candidates
candidates = []
max_length = K
for idx in matching_indices:
start_idx = idx + ngrams_size
end_idx = start_idx + K
candidate = input_ids[0, start_idx : min(end_idx, input_ids.size(1))]
length = len(candidate)
if length == max_length:
candidates.append(candidate)
else:
# we do not consider prefix with no candidates
if length > max_length:
max_length = length
candidates = [candidate]
if candidates:
chosen_candidate = candidates[np.random.randint(len(candidates))]
else:
chosen_candidate = torch.tensor([], dtype=torch.long, device=input_ids.device)
return chosen_candidate.unsqueeze(dim=0)
```
Question: Can you rename the variable 'candidates' to 'candidates_tokens'?
Modified code:
<|start_header_id|>assistant<|end_header_id|>
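As a quick sanity check, the matching logic above can be exercised on toy tensors. The values below are illustrative only, and assume `generate_candidate_tokens` (plus `numpy` as `np`, which it uses for the random tie-break) is in scope:

```python
import numpy as np
import torch

# The n-gram [5, 6, 7] occurs twice, so two length-K candidate
# continuations are found and one is chosen at random.
input_ids = torch.tensor([[1, 5, 6, 7, 8, 9, 5, 6, 7, 2, 3]])
n_grams = torch.tensor([[5, 6, 7]])

candidate = generate_candidate_tokens(input_ids, n_grams, ngrams_size=3, K=2)
print(candidate)  # tensor([[8, 9]]) or tensor([[2, 3]])
```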
The following timings are observed:
| Decoding Method | Time Taken (s) | Tokens/sec | Speedup |
| --- | --- | --- | --- |
| Greedy Decoding | 26.4 | 14.0 | 1x |
| Ngram Decoding | 12.8 | 28.9 | ~2x |
In this simple demonstration experiment, we achieved results comparable to those of the original Prompt Lookup Decoding implementation and the figures reported in LLMA Decoding, both of which show approximately a 2-3x speedup over greedy decoding.
```bibtex
@misc{saxena2023prompt,
  title  = {Prompt Lookup Decoding},
  author = {Apoorv Saxena},
  year   = {2023},
  month  = {November},
  url    = {https://github.com/apoorvumang/prompt-lookup-decoding/}
}

@misc{yang2023inferencereferencelosslessacceleration,
  title         = {Inference with Reference: Lossless Acceleration of Large Language Models},
  author        = {Nan Yang and Tao Ge and Liang Wang and Binxing Jiao and Daxin Jiang and Linjun Yang and Rangan Majumder and Furu Wei},
  year          = {2023},
  eprint        = {2304.04487},
  archivePrefix = {arXiv},
  primaryClass  = {cs.CL},
  url           = {https://arxiv.org/abs/2304.04487}
}
```
The n-gram decoding implementation is built upon the following repository: