examples(phi-3-vision): Simplify input processing with KV cache (#296)
With the KV cache, reconstructing the full input sequence is unnecessary; only the newly generated token needs to be processed on each iteration.
web3nomad authored Oct 15, 2024
1 parent cdd6be7 commit 87dc4f2
Showing 2 changed files with 15 additions and 20 deletions.
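
The pattern this commit adopts can be sketched roughly as follows. This is a minimal illustration only: the `Decoder` wrapper, its `step`/`embed` methods, and the tensor shapes are hypothetical placeholders for the ONNX sessions used in `main.rs`, not the example's actual API. The idea is that the prompt embeddings are fed once to prime the KV cache; every later iteration embeds and feeds only the single newly generated token, while the attention mask grows by one position per step.

```rust
use ndarray::{Array2, Array3};

struct Decoder;

impl Decoder {
    // Stand-in for one run of the ONNX text model: `embeds` is (1, new_tokens, hidden)
    // and the cached keys/values (kept inside this hypothetical struct) cover all
    // earlier positions. Returns the argmax token id for the last position.
    fn step(&mut self, _embeds: &Array3<f32>, _attention_mask: &Array2<i64>) -> i64 {
        unimplemented!("stand-in for the generation model session")
    }

    // Stand-in for the text-embedding model: embeds one token id as (1, 1, hidden).
    fn embed(&self, _token_id: i64) -> Array3<f32> {
        unimplemented!("stand-in for the text-embedding model session")
    }
}

fn generate(decoder: &mut Decoder, prompt_embeds: Array3<f32>, eos: i64, max_len: usize) -> Vec<i64> {
    let mut attention_mask: Array2<i64> = Array2::ones((1, prompt_embeds.shape()[1]));
    // First iteration: the full prompt is fed once and fills the KV cache.
    let mut next_embeds = prompt_embeds;
    let mut tokens = Vec::new();
    for _ in 0..max_len {
        let next_id = decoder.step(&next_embeds, &attention_mask);
        if next_id == eos {
            break;
        }
        tokens.push(next_id);
        // Every later iteration: embed only the newly generated token; the mask
        // still grows by one so it covers the positions already held in the cache.
        next_embeds = decoder.embed(next_id);
        attention_mask = Array2::ones((1, attention_mask.shape()[1] + 1));
    }
    tokens
}
```

Without this, each iteration would have to re-feed the whole accumulated embedding sequence, which is exactly what the concatenation code removed in the diff below was doing.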
4 changes: 2 additions & 2 deletions examples/phi-3-vision/README.md
@@ -35,8 +35,8 @@ This example currently only supports single image input.
The performance of ONNX-based LLM inference can be relatively slow, especially on CPU:

- On an Apple M1 Pro:
- For image+text input (about 300 tokens): ~5 seconds per output token
- For text-only input (about 10 tokens): ~200ms per output token
- For image+text input (about 300 tokens): ~7 tokens/s
- For text-only input (about 10 tokens): ~5 tokens/s

## Run this Example

31 changes: 13 additions & 18 deletions examples/phi-3-vision/src/main.rs
@@ -11,7 +11,7 @@ const VISION_MODEL_NAME: &'static str = "phi-3-v-128k-instruct-vision.onnx";
const TEXT_EMBEDDING_MODEL_NAME: &'static str = "phi-3-v-128k-instruct-text-embedding.onnx";
const GENERATION_MODEL_NAME: &'static str = "phi-3-v-128k-instruct-text.onnx";

const MAX_LENGTH: usize = 100; // max length of the generated text
const MAX_LENGTH: usize = 1000; // max length of the generated text
const EOS_TOKEN_ID: i64 = 32007; // <|end|>
const USER_TOKEN_ID: i64 = 32010; // <|user|>
const VOCAB_SIZE: usize = 32064;
@@ -108,7 +108,7 @@ pub async fn generate_text(
image: &Option<DynamicImage>,
text: &str
) -> Result<()> {
let (mut inputs_embeds, mut attention_mask) = {
let (inputs_embeds, mut attention_mask) = {
let visual_features = get_image_embedding(&vision_model, &image)?;
let prompt = format_chat_template(&image, text);
let encoding = tokenizer.encode(prompt, true).map_err(|e| anyhow::anyhow!("Error encoding: {:?}", e))?;
@@ -139,12 +139,13 @@ pub async fn generate_text(
// 4. Head size (96)
let mut past_key_values: Vec<Array4<f32>> = vec![Array4::zeros((1, 32, 0, 96)); 64];
let mut generated_tokens: Vec<i64> = Vec::new();
let mut next_inputs_embeds = inputs_embeds.clone();
// Loop until <|end|> token is generated or max length is reached
for _ in 0..MAX_LENGTH {
// Prepare model inputs
let model_inputs = {
let mut model_inputs = ort::inputs![
"inputs_embeds" => inputs_embeds.clone(),
"inputs_embeds" => next_inputs_embeds.clone(),
"attention_mask" => attention_mask.clone(),
]?;
for i in 0..32 {
@@ -176,27 +177,21 @@ pub async fn generate_text(
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap())
.unwrap()
.0 as i64;

if next_token_id == EOS_TOKEN_ID {
break;
}

generated_tokens.push(next_token_id);
// Log the generated text
let output_ids: Vec<u32> = generated_tokens.iter().map(|&id| id as u32).collect();
let generated_text = tokenizer.decode(&output_ids, false).unwrap();
tracing::info!("Generated text: {}", generated_text);

if next_token_id == EOS_TOKEN_ID {
break;
}

// Update inputs_embeds, attention_mask, and past_key_values for the next iteration
(inputs_embeds, attention_mask) = {
let new_token_id = Array2::from_elem((1, 1), next_token_id);
let new_token_embed = get_text_embedding(&text_embedding_model, &new_token_id)?;
// Merge the new token embedding with the previous embeddings
let mut combined_embeds = Array3::zeros((inputs_embeds.shape()[0], inputs_embeds.shape()[1] + 1, inputs_embeds.shape()[2]));
combined_embeds.slice_mut(s![.., ..inputs_embeds.shape()[1], ..]).assign(&inputs_embeds);
combined_embeds.slice_mut(s![.., inputs_embeds.shape()[1].., ..]).assign(&new_token_embed);
let new_attention_mask = Array2::ones((1, attention_mask.shape()[1] + 1));
(combined_embeds, new_attention_mask)
};
// Update next_inputs_embeds, attention_mask, and past_key_values for the next iteration
let new_token_id = Array2::from_elem((1, 1), next_token_id);
next_inputs_embeds = get_text_embedding(&text_embedding_model, &new_token_id)?;
attention_mask = Array2::ones((1, attention_mask.shape()[1] + 1));
for i in 0..32 {
past_key_values[i * 2] = model_outputs[format!("present.{}.key", i)]
.try_extract_tensor::<f32>()?
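
The collapsed lines above copy the model's `present.*` outputs back into `past_key_values` for the next step. A rough sketch of that bookkeeping, assuming the interleaved `[key0, value0, key1, value1, …]` layout implied by the 64-entry vector, and using a hypothetical `present_for_layer` accessor in place of the real ONNX output extraction (the hidden diff lines remain the authoritative version):

```rust
use ndarray::Array4;

// Per the shapes in main.rs: 32 decoder layers, 32 heads, head size 96.
// `present_for_layer` is a hypothetical accessor standing in for pulling
// `present.{i}.key` / `present.{i}.value` out of the ONNX model outputs.
fn update_cache(
    past_key_values: &mut [Array4<f32>], // 64 entries: [key0, value0, key1, value1, ...]
    present_for_layer: impl Fn(usize) -> (Array4<f32>, Array4<f32>),
) {
    for layer in 0..32 {
        // Each tensor is (batch = 1, heads = 32, seq_len_so_far, head_dim = 96);
        // seq_len grows by the number of tokens processed in this iteration.
        let (key, value) = present_for_layer(layer);
        past_key_values[layer * 2] = key;
        past_key_values[layer * 2 + 1] = value;
    }
}
```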
