diff --git a/crates/llm-local/src/lib.rs b/crates/llm-local/src/lib.rs
index 1a10552f4..e7d658a02 100644
--- a/crates/llm-local/src/lib.rs
+++ b/crates/llm-local/src/lib.rs
@@ -23,7 +23,7 @@ type ModelName = String;
 #[derive(Clone)]
 pub struct LocalLlmEngine {
     registry: PathBuf,
-    inferencing_models: HashMap<ModelName, Arc<dyn CachedInferencingModel>>,
+    inferencing_models: HashMap<ModelName, Arc<dyn InferencingModel>>,
     embeddings_models: HashMap<ModelName, Arc<(tokenizers::Tokenizer, BertModel)>>,
 }

@@ -43,11 +43,11 @@ impl FromStr for InferencingModelArch {
     }
 }

-/// `CachedInferencingModel` implies that the model is prepared and cached after
-/// loading, allowing faster future requests by avoiding repeated file reads
-/// and decoding. This trait does not specify anything about if the results are cached.
+/// A model that is prepared and cached after loading.
+///
+/// This trait does not specify anything about whether the results are cached.
 #[async_trait]
-trait CachedInferencingModel: Send + Sync {
+trait InferencingModel: Send + Sync {
     async fn infer(
         &self,
         prompt: String,
@@ -143,7 +143,7 @@ impl LocalLlmEngine {
     async fn inferencing_model(
         &mut self,
         model: wasi_llm::InferencingModel,
-    ) -> Result<Arc<dyn CachedInferencingModel>, wasi_llm::Error> {
+    ) -> Result<Arc<dyn InferencingModel>, wasi_llm::Error> {
         let model = match self.inferencing_models.entry(model.clone()) {
             Entry::Occupied(o) => o.get().clone(),
             Entry::Vacant(v) => {
@@ -328,7 +328,6 @@ fn load_tokenizer(tokenizer_file: &Path) -> anyhow::Result<tokenizers::Tokenizer

 fn load_model(model_file: &Path) -> anyhow::Result<BertModel> {
     let device = &candle::Device::Cpu;
-    // TODO: Check if there is a safe way to load the model from the file
     let data = std::fs::read(model_file)?;
     let tensors = load_buffer(&data, device)?;
     let vb = VarBuilder::from_tensors(tensors, DType::F32, device);
diff --git a/crates/llm-local/src/llama.rs b/crates/llm-local/src/llama.rs
index 06166e3e2..96c50440c 100644
--- a/crates/llm-local/src/llama.rs
+++ b/crates/llm-local/src/llama.rs
@@ -1,5 +1,5 @@
-use crate::{token_output_stream, CachedInferencingModel};
-use anyhow::{anyhow, bail, Result};
+use crate::{token_output_stream, InferencingModel};
+use anyhow::{anyhow, bail, Context, Result};
 use candle::{safetensors::load_buffer, utils, Device, Tensor};
 use candle_nn::VarBuilder;
 use candle_transformers::{
@@ -76,7 +76,7 @@ impl LlamaModels {
 }

 #[async_trait]
-impl CachedInferencingModel for LlamaModels {
+impl InferencingModel for LlamaModels {
     async fn infer(
         &self,
         prompt: String,
@@ -185,7 +185,8 @@ impl CachedInferencingModel for LlamaModels {
 /// path to the model index JSON file relative to the model folder.
 fn load_safetensors(model_dir: &Path, json_file: &str) -> Result<Vec<PathBuf>> {
     let json_file = model_dir.join(json_file);
-    let json_file = std::fs::File::open(json_file)?;
+    let json_file = std::fs::File::open(&json_file)
+        .with_context(|| format!("Could not read model index file: {json_file:?}"))?;
     let json: serde_json::Value =
         serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?;
     let weight_map = match json.get("weight_map") {
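
Not part of the patch: a minimal, self-contained sketch of the anyhow `with_context` pattern adopted in `load_safetensors` above. The closure form defers building the message until an error actually occurs; the `open_index` helper and the file name are hypothetical, for illustration only.

```rust
use anyhow::{Context, Result};
use std::fs::File;
use std::path::Path;

/// Opens a file, attaching a descriptive message to any error.
/// The closure passed to `with_context` runs only on the error path,
/// so the success path pays no formatting cost.
fn open_index(path: &Path) -> Result<File> {
    File::open(path).with_context(|| format!("Could not read model index file: {path:?}"))
}

fn main() {
    // Hypothetical path; with a missing file this prints the context message
    // followed by the underlying io::Error in the chain.
    if let Err(e) = open_index(Path::new("model.safetensors.index.json")) {
        eprintln!("{e:#}");
    }
}
```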