Skip to content

Commit

Permalink
remove Box::clone on cache
Browse files Browse the repository at this point in the history
Signed-off-by: karthik2804 <[email protected]>
  • Loading branch information
karthik2804 committed Aug 27, 2024
1 parent d48367a commit 8a822de
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions crates/llm-local/src/llama.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ pub fn auto_device() -> Result<Device> {
pub(crate) struct LlamaModels {
model: Arc<Llama>,
config: Config,
cache: Box<Cache>,
cache: Cache,
tokenizer: Tokenizer,
device: Device,
}
Expand All @@ -47,8 +47,8 @@ impl LlamaModels {
let tokenizer = Tokenizer::from_file(tokenizer_path).map_err(|e| anyhow!(e.to_string()))?;
let config: LlamaConfig = serde_json::from_slice(&fs::read(config_path)?)?;

// flash attention is supposed to minimize memory read and writes
let config = config.into_config(true);
// TODO: flash attention is supposed to minimize memory reads and writes - do we want to turn it on?
let config = config.into_config(false);
let cache = llama::Cache::new(true, dtype, &config, &device)?;

let safetensor_files = load_safetensors(&model_dir, MODEL_SAFETENSORS_INDEX)?;
Expand All @@ -60,7 +60,7 @@ impl LlamaModels {
Ok(Self {
model: Arc::new(model),
config,
cache: Box::new(cache),
cache,
tokenizer,
device,
})
Expand All @@ -77,7 +77,7 @@ impl CachedInferencingModel for LlamaModels {
let model = Arc::clone(&self.model);
let config = &self.config;
let tokenizer = self.tokenizer.clone();
let mut cache = Box::clone(&self.cache);
let mut cache = self.cache.clone();
let eos_token_id = config.clone().eos_token_id.or_else(|| {
tokenizer
.token_to_id(EOS_TOKEN)
Expand Down

0 comments on commit 8a822de

Please sign in to comment.