Skip to content

Commit

Permalink
more cleanups and fixes
Browse files Browse the repository at this point in the history
Signed-off-by: karthik2804 <[email protected]>
  • Loading branch information
karthik2804 committed Sep 17, 2024
1 parent 3352e7e commit 6afbcfb
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 14 deletions.
3 changes: 1 addition & 2 deletions crates/llm-local/src/llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,6 @@ impl CachedInferencingModel for LlamaModels {
}
// Decode the token and add it to the output.
if let Some(t) = tokenizer.next_token(next_token)? {
print!("{}", t);
output_text.push_str(&t);
}
}
Expand All @@ -183,7 +182,7 @@ impl CachedInferencingModel for LlamaModels {
}

/// Loads a list of SafeTensors file paths from a given model directory and
/// path to the model index JSON file.
/// path to the model index JSON file relative to the model folder.
fn load_safetensors(model_dir: &Path, json_file: &str) -> Result<Vec<std::path::PathBuf>> {
let json_file = model_dir.join(json_file);
let json_file = std::fs::File::open(json_file)?;
Expand Down
30 changes: 18 additions & 12 deletions crates/llm-local/src/token_output_stream.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
/// Implementation for TokenOutputStream Code is borrowed from
/// https://github.com/huggingface/candle/blob/main/candle-examples/src/token_output_stream.rs
/// (Commit SHA 4fd00b890036ef67391a9cc03f896247d0a75711)
///
/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a
/// streaming way rather than having to wait for the full decoding.
/// Implementation for TokenOutputStream Code is borrowed from
///
/// Borrowed from https://github.com/huggingface/candle/blob/main/candle-examples/src/token_output_stream.rs
/// (Commit SHA 4fd00b890036ef67391a9cc03f896247d0a75711)
pub struct TokenOutputStream {
tokenizer: tokenizers::Tokenizer,
tokens: Vec<u32>,
Expand All @@ -21,16 +21,10 @@ impl TokenOutputStream {
}
}

fn decode(&self, tokens: &[u32]) -> anyhow::Result<String> {
match self.tokenizer.decode(tokens, true) {
Ok(str) => Ok(str),
Err(err) => anyhow::bail!("cannot decode: {err}"),
}
}

/// Processes the next token in the sequence, decodes the current token stream,
/// and returns any newly decoded text.
/// https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68
///
/// Based on the following code: <https://github.com/huggingface/text-generation-inference/blob/5ba53d44a18983a4de32d122f4cb46f4a17d9ef6/server/text_generation_server/models/model.py#L68>
pub fn next_token(&mut self, token: u32) -> anyhow::Result<Option<String>> {
let prev_text = if self.tokens.is_empty() {
String::new()
Expand All @@ -50,6 +44,12 @@ impl TokenOutputStream {
}
}

/// Decodes the remaining tokens and returns any new text found.
///
/// This function decodes tokens from `self.prev_index` to the end and
/// compares it with the previously decoded portion (from `self.prev_index`
/// to `self.current_index`). If new text is found, it returns the
/// additional part as `Some(String)`. Otherwise, returns `None`.
pub fn decode_rest(&self) -> anyhow::Result<Option<String>> {
let prev_text = if self.tokens.is_empty() {
String::new()
Expand All @@ -65,4 +65,10 @@ impl TokenOutputStream {
Ok(None)
}
}

fn decode(&self, tokens: &[u32]) -> anyhow::Result<String> {
self.tokenizer
.decode(tokens, true)
.context("failed to decode token stream")
}
}

0 comments on commit 6afbcfb

Please sign in to comment.