-
Notifications
You must be signed in to change notification settings - Fork 947
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* add olmo support * add olmo readme * Fix fmt. * Fix clippy. * Get olmo to work on cuda. --------- Co-authored-by: laurent <[email protected]>
- Loading branch information
1 parent
cfab6e7
commit 6cf82fd
Showing
4 changed files
with
658 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
# candle-olmo: Open Language Models designed to enable the science of language models | ||
|
||
OLMo is a series of Open Language Models designed to enable the science of language models. | ||
|
||
- **Project Page:** https://allenai.org/olmo | ||
- **Paper:** [Link](https://arxiv.org/abs/2402.00838) | ||
- **Technical blog post:** https://blog.allenai.org/olmo-open-language-model-87ccfc95f580 | ||
- **W&B Logs:** https://wandb.ai/ai2-llm/OLMo-1B/reports/OLMo-1B--Vmlldzo2NzY1Njk1 | ||
<!-- - **Press release:** TODO --> | ||
|
||
## Running the example | ||
|
||
```bash | ||
$ cargo run --example olmo --release -- --prompt "It is only with the heart that one can see rightly" | ||
|
||
avx: true, neon: false, simd128: false, f16c: true | ||
temp: 0.20 repeat-penalty: 1.10 repeat-last-n: 64 | ||
retrieved the files in 354.977µs | ||
loaded the model in 19.87779666s | ||
It is only with the heart that one can see rightly; what is essential is invisible to the eye. | ||
``` | ||
|
||
Various model sizes are available via the `--model` argument: `1b` (the default), `7b`, `7b-twin-2t`, and `1.7-7b`. | ||
|
||
```bash | ||
$ cargo run --example olmo --release -- --model 1.7-7b --prompt 'It is only with the heart that one can see rightly' | ||
|
||
avx: true, neon: false, simd128: false, f16c: true | ||
temp: 0.20 repeat-penalty: 1.10 repeat-last-n: 64 | ||
retrieved the files in 1.226087ms | ||
loaded the model in 171.274578609s | ||
It is only with the heart that one can see rightly; what is essential is invisible to the eye.” | ||
~ Antoine de Saint-Exupery, The Little Prince | ||
I am a big fan of this quote. It reminds me that I need to be open and aware of my surroundings in order to truly appreciate them. | ||
``` | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,284 @@ | ||
#[cfg(feature = "mkl")] | ||
extern crate intel_mkl_src; | ||
|
||
#[cfg(feature = "accelerate")] | ||
extern crate accelerate_src; | ||
|
||
use anyhow::{Error as E, Result}; | ||
use clap::{Parser, ValueEnum}; | ||
|
||
use candle_transformers::models::olmo::{Config, Model as OLMo}; | ||
|
||
use candle::{DType, Device, Tensor}; | ||
use candle_examples::token_output_stream::TokenOutputStream; | ||
use candle_nn::VarBuilder; | ||
use candle_transformers::generation::LogitsProcessor; | ||
use hf_hub::{api::sync::Api, Repo, RepoType}; | ||
use tokenizers::Tokenizer; | ||
|
||
enum Model { | ||
OLMo(OLMo), | ||
} | ||
|
||
struct TextGeneration { | ||
model: Model, | ||
device: Device, | ||
tokenizer: TokenOutputStream, | ||
logits_processor: LogitsProcessor, | ||
repeat_penalty: f32, | ||
repeat_last_n: usize, | ||
} | ||
|
||
impl TextGeneration { | ||
#[allow(clippy::too_many_arguments)] | ||
fn new( | ||
model: Model, | ||
tokenizer: Tokenizer, | ||
seed: u64, | ||
temp: Option<f64>, | ||
top_p: Option<f64>, | ||
repeat_penalty: f32, | ||
repeat_last_n: usize, | ||
device: &Device, | ||
) -> Self { | ||
let logits_processor = LogitsProcessor::new(seed, temp, top_p); | ||
Self { | ||
model, | ||
tokenizer: TokenOutputStream::new(tokenizer), | ||
logits_processor, | ||
repeat_penalty, | ||
repeat_last_n, | ||
device: device.clone(), | ||
} | ||
} | ||
|
||
fn run(&mut self, prompt: &str, sample_len: usize) -> Result<()> { | ||
use std::io::Write; | ||
self.tokenizer.clear(); | ||
let mut tokens = self | ||
.tokenizer | ||
.tokenizer() | ||
.encode(prompt, false) | ||
.map_err(E::msg)? | ||
.get_ids() | ||
.to_vec(); | ||
for &t in tokens.iter() { | ||
if let Some(t) = self.tokenizer.next_token(t)? { | ||
print!("{t}") | ||
} | ||
} | ||
std::io::stdout().flush()?; | ||
|
||
let mut generated_tokens = 0usize; | ||
let eos_token = match self.tokenizer.get_token("<|endoftext|>") { | ||
Some(token) => token, | ||
None => anyhow::bail!("cannot find the <|endoftext|> token"), | ||
}; | ||
let start_gen = std::time::Instant::now(); | ||
for index in 0..sample_len { | ||
let context_size = if index > 0 { 1 } else { tokens.len() }; | ||
let start_pos = tokens.len().saturating_sub(context_size); | ||
let ctxt = &tokens[start_pos..]; | ||
let input = Tensor::new(ctxt, &self.device)?.unsqueeze(0)?; | ||
let logits = match &mut self.model { | ||
Model::OLMo(m) => m.forward(&input, start_pos)?, | ||
}; | ||
let logits = logits.squeeze(0)?.squeeze(0)?.to_dtype(DType::F32)?; | ||
let logits = if self.repeat_penalty == 1. { | ||
logits | ||
} else { | ||
let start_at = tokens.len().saturating_sub(self.repeat_last_n); | ||
candle_transformers::utils::apply_repeat_penalty( | ||
&logits, | ||
self.repeat_penalty, | ||
&tokens[start_at..], | ||
)? | ||
}; | ||
|
||
let next_token = self.logits_processor.sample(&logits)?; | ||
tokens.push(next_token); | ||
generated_tokens += 1; | ||
if next_token == eos_token { | ||
break; | ||
} | ||
if let Some(t) = self.tokenizer.next_token(next_token)? { | ||
print!("{t}"); | ||
std::io::stdout().flush()?; | ||
} | ||
} | ||
let dt = start_gen.elapsed(); | ||
if let Some(rest) = self.tokenizer.decode_rest().map_err(E::msg)? { | ||
print!("{rest}"); | ||
} | ||
std::io::stdout().flush()?; | ||
println!( | ||
"\n{generated_tokens} tokens generated ({:.2} token/s)", | ||
generated_tokens as f64 / dt.as_secs_f64(), | ||
); | ||
Ok(()) | ||
} | ||
} | ||
|
||
#[derive(Clone, Copy, Debug, ValueEnum, PartialEq, Eq)] | ||
enum Which { | ||
#[value(name = "1b")] | ||
W1b, | ||
#[value(name = "7b")] | ||
W7b, | ||
#[value(name = "7b-twin-2t")] | ||
W7bTwin2T, | ||
#[value(name = "1.7-7b")] | ||
V1_7W7b, | ||
} | ||
|
||
#[derive(Parser, Debug)] | ||
#[command(author, version, about, long_about = None)] | ||
struct Args { | ||
/// Run on CPU rather than on GPU. | ||
#[arg(long)] | ||
cpu: bool, | ||
|
||
/// Enable tracing (generates a trace-timestamp.json file). | ||
#[arg(long)] | ||
tracing: bool, | ||
|
||
#[arg(long)] | ||
prompt: String, | ||
|
||
/// The temperature used to generate samples. | ||
#[arg(long)] | ||
temperature: Option<f64>, | ||
|
||
/// Nucleus sampling probability cutoff. | ||
#[arg(long)] | ||
top_p: Option<f64>, | ||
|
||
/// The seed to use when generating random samples. | ||
#[arg(long, default_value_t = 299792458)] | ||
seed: u64, | ||
|
||
/// The length of the sample to generate (in tokens). | ||
#[arg(long, short = 'n', default_value_t = 1000)] | ||
sample_len: usize, | ||
|
||
#[arg(long)] | ||
model_id: Option<String>, | ||
|
||
#[arg(long, default_value = "main")] | ||
revision: String, | ||
|
||
#[arg(long, default_value = "1b")] | ||
model: Which, | ||
|
||
#[arg(long)] | ||
tokenizer_file: Option<String>, | ||
|
||
#[arg(long)] | ||
weight_files: Option<String>, | ||
|
||
/// Penalty to be applied for repeating tokens, 1. means no penalty. | ||
#[arg(long, default_value_t = 1.1)] | ||
repeat_penalty: f32, | ||
|
||
/// The context size to consider for the repeat penalty. | ||
#[arg(long, default_value_t = 64)] | ||
repeat_last_n: usize, | ||
} | ||
|
||
fn main() -> Result<()> { | ||
use tracing_chrome::ChromeLayerBuilder; | ||
use tracing_subscriber::prelude::*; | ||
|
||
let args = Args::parse(); | ||
let _guard = if args.tracing { | ||
let (chrome_layer, guard) = ChromeLayerBuilder::new().build(); | ||
tracing_subscriber::registry().with(chrome_layer).init(); | ||
Some(guard) | ||
} else { | ||
None | ||
}; | ||
println!( | ||
"avx: {}, neon: {}, simd128: {}, f16c: {}", | ||
candle::utils::with_avx(), | ||
candle::utils::with_neon(), | ||
candle::utils::with_simd128(), | ||
candle::utils::with_f16c() | ||
); | ||
println!( | ||
"temp: {:.2} repeat-penalty: {:.2} repeat-last-n: {}", | ||
args.temperature.unwrap_or(0.), | ||
args.repeat_penalty, | ||
args.repeat_last_n | ||
); | ||
|
||
let start = std::time::Instant::now(); | ||
let api = Api::new()?; | ||
let model_id = match args.model_id { | ||
Some(model_id) => model_id, | ||
None => match args.model { | ||
Which::W1b => "allenai/OLMo-1B-hf".to_string(), | ||
Which::W7b => "allenai/OLMo-7B-hf".to_string(), | ||
Which::W7bTwin2T => "allenai/OLMo-7B-Twin-2T-hf".to_string(), | ||
Which::V1_7W7b => "allenai/OLMo-1.7-7B-hf".to_string(), | ||
}, | ||
}; | ||
|
||
let repo = api.repo(Repo::with_revision( | ||
model_id, | ||
RepoType::Model, | ||
args.revision, | ||
)); | ||
let tokenizer_filename = match args.tokenizer_file { | ||
Some(file) => std::path::PathBuf::from(file), | ||
None => repo.get("tokenizer.json")?, | ||
}; | ||
let filenames = match args.weight_files { | ||
Some(files) => files | ||
.split(',') | ||
.map(std::path::PathBuf::from) | ||
.collect::<Vec<_>>(), | ||
None => match args.model { | ||
Which::W1b => { | ||
vec![repo.get("model.safetensors")?] | ||
} | ||
_ => candle_examples::hub_load_safetensors(&repo, "model.safetensors.index.json")?, | ||
}, | ||
}; | ||
|
||
println!("retrieved the files in {:?}", start.elapsed()); | ||
let tokenizer = Tokenizer::from_file(tokenizer_filename).map_err(E::msg)?; | ||
|
||
let start = std::time::Instant::now(); | ||
let config = { | ||
let config_filename = repo.get("config.json")?; | ||
let config: Config = serde_json::from_slice(&std::fs::read(config_filename)?)?; | ||
config | ||
}; | ||
|
||
let device = candle_examples::device(args.cpu)?; | ||
let model = { | ||
let dtype = if device.is_cuda() { | ||
DType::BF16 | ||
} else { | ||
DType::F32 | ||
}; | ||
let vb = unsafe { VarBuilder::from_mmaped_safetensors(&filenames, dtype, &device)? }; | ||
let model = OLMo::new(&config, vb)?; | ||
Model::OLMo(model) | ||
}; | ||
|
||
println!("loaded the model in {:?}", start.elapsed()); | ||
|
||
let mut pipeline = TextGeneration::new( | ||
model, | ||
tokenizer, | ||
args.seed, | ||
args.temperature, | ||
args.top_p, | ||
args.repeat_penalty, | ||
args.repeat_last_n, | ||
&device, | ||
); | ||
pipeline.run(&args.prompt, args.sample_len)?; | ||
Ok(()) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.