From 08f1611490221271f30a07966d8826297340f4aa Mon Sep 17 00:00:00 2001 From: karthik2804 Date: Mon, 16 Sep 2024 12:58:21 +0200 Subject: [PATCH] address some PR comments Signed-off-by: karthik2804 --- Cargo.lock | 144 ++++++++++---------- crates/factor-llm/src/spin.rs | 18 +-- crates/llm-local/src/lib.rs | 8 +- crates/llm-local/src/llama.rs | 23 +++- crates/llm-local/src/token_output_stream.rs | 35 ++--- crates/llm-local/src/utils.rs | 25 ---- crates/runtime-config/src/lib.rs | 10 +- crates/runtime-factors/src/build.rs | 4 - crates/runtime-factors/src/lib.rs | 3 +- 9 files changed, 110 insertions(+), 160 deletions(-) delete mode 100644 crates/llm-local/src/utils.rs diff --git a/Cargo.lock b/Cargo.lock index a89c0ec9e1..152de41793 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -377,7 +377,7 @@ checksum = "30c5ef0ede93efbf733c1a727f3b6b5a1060bbedd5600183e66f6e4be4af0ec5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -444,7 +444,7 @@ checksum = "16e62a023e7c117e27523144c5d2459f4397fcc3cab0085af8e2224f643a0193" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -474,7 +474,7 @@ checksum = "6e0c28dcc82d7c8ead5cb13beb15405b57b8546e93215673ff8ca0349a028107" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -698,7 +698,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -847,7 +847,7 @@ checksum = "0cc8b54b395f2fcfbb3d90c47b01c7f444d94d05bdeb775811dec868ac3bbc26" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -885,14 +885,14 @@ dependencies = [ "cudarc", "gemm", "half", - "memmap2 0.9.4", + "memmap2 0.9.5", "metal", "num-traits 0.2.18", "num_cpus", "rand 0.8.5", "rand_distr", "rayon", - "safetensors 0.4.4", + "safetensors 0.4.5", "thiserror", "yoke", "zip", @@ -928,7 +928,7 @@ dependencies = [ "metal", "num-traits 0.2.18", "rayon", - "safetensors 0.4.4", + "safetensors 0.4.5", "serde 1.0.197", "thiserror", ] @@ -1225,7 +1225,7 @@ dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -1728,7 +1728,7 @@ dependencies = [ "proc-macro2", "quote", "strsim 0.11.1", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -1750,7 +1750,7 @@ checksum = "733cabb43482b1a1b53eee8583c2b9e8684d592215ea83efd305dd31bc2f0178" dependencies = [ "darling_core 0.20.9", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -1802,7 +1802,7 @@ checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -1816,11 +1816,11 @@ dependencies = [ [[package]] name = "derive_builder" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0350b5cb0331628a5916d6c5c0b72e97393b8b6b03b47a9284f4e7f5a405ffd7" +checksum = "cd33f37ee6a119146a1781d3356a7c26028f83d779b2e04ecd45fdc75c76877b" dependencies = [ - "derive_builder_macro 0.20.0", + "derive_builder_macro 0.20.1", ] [[package]] @@ -1837,14 +1837,14 @@ dependencies = [ [[package]] name = "derive_builder_core" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d48cda787f839151732d396ac69e3473923d54312c070ee21e9effcaa8ca0b1d" +checksum = "7431fa049613920234f22c47fdc33e6cf3ee83067091ea4277a3f8c4587aae38" dependencies = [ "darling 0.20.9", "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -1859,12 +1859,12 @@ dependencies = [ [[package]] name = "derive_builder_macro" -version = "0.20.0" +version = "0.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "206868b8242f27cecce124c19fd88157fbd0dd334df2587f36417bafbc85097b" +checksum = "4abae7035bf79b9877b779505d8cf3749285b80c43941eda66604841889451dc" dependencies = [ - "derive_builder_core 0.20.0", - "syn 2.0.75", + "derive_builder_core 0.20.1", + "syn 2.0.77", ] [[package]] @@ -1992,7 +1992,7 @@ checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -2124,14 +2124,14 @@ dependencies = [ [[package]] name = "enum-as-inner" -version = "0.6.0" +version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ffccbb6966c05b32ef8fbac435df276c4ae4d3dc55a8cd0eb9745e6c12f546a" +checksum = "a1e6a265c649f3f5979b601d26f1d05ada116434c87741c9493cb56218f76cbc" dependencies = [ - "heck 0.4.1", + "heck 0.5.0", "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -2152,7 +2152,7 @@ checksum = "5c785274071b1b420972453b306eeca06acf4633829db4223b58a2a8c5953bc4" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -2270,7 +2270,7 @@ dependencies = [ "prettyplease", "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -2437,7 +2437,7 @@ checksum = "1a5c6c585bc94aaf2c7b51dd4c2ba22680844aba4c687be581871a6f518c5742" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -2590,7 +2590,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -3929,7 +3929,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" dependencies = [ "cfg-if", - "windows-targets 0.48.5", + "windows-targets 0.52.6", ] [[package]] @@ -4174,7 +4174,7 @@ dependencies = [ "proc-macro2", "quote", "regex-syntax 0.6.29", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -4189,7 +4189,7 @@ dependencies = [ "proc-macro2", "quote", "regex-syntax 0.8.3", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -4319,9 +4319,9 @@ dependencies = [ [[package]] name = "memmap2" -version = "0.9.4" +version = "0.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe751422e4a8caa417e13c3ea66452215d7d63e19e604f4980461212f3ae1322" +checksum = "fd3f7eed9d3848f8b98834af67102b720745c4ec028fcd0aa0239277e7de374f" dependencies = [ "libc", "stable_deref_trait", @@ -4401,7 +4401,7 @@ checksum = "49e7bc1560b95a3c4a25d03de42fe76ca718ab92d1a22a55b9b4cf67b3ae635c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -4412,7 +4412,7 @@ checksum = "dcf09caffaac8068c346b6df2a7fc27a177fd20b39421a39ce0a211bde679a6c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -4488,7 +4488,7 @@ checksum = "a7ce64b975ed4f123575d11afd9491f2e37bbd5813fbfbc0f09ae1fbddea74e0" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -4807,7 +4807,7 @@ dependencies = [ "proc-macro-crate", "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -5018,7 +5018,7 @@ checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -5339,7 +5339,7 @@ dependencies = [ "pest_meta", "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -5419,7 +5419,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -5570,7 +5570,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8d3928fb5db768cb86f891ff014f0144589297e3c6a1aba6ed7cecfdace270c7" dependencies = [ "proc-macro2", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -5692,7 +5692,7 @@ dependencies = [ "prost", "prost-types", "regex", - "syn 2.0.75", + "syn 2.0.77", "tempfile", ] @@ -5706,7 +5706,7 @@ dependencies = [ "itertools 0.12.1", "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -6501,9 +6501,9 @@ dependencies = [ [[package]] name = "safetensors" -version = "0.4.4" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7725d4d98fa515472f43a6e2bbf956c48e06b89bb50593a040e5945160214450" +checksum = "44560c11236a6130a46ce36c836a62936dc81ebf8c36a37947423571be0e55b6" dependencies = [ "serde 1.0.197", "serde_json", @@ -6685,7 +6685,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -6699,9 +6699,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.125" +version = "1.0.128" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83c8e735a073ccf5be70aa8066aa984eaf2fa000db6c8d0100ae605b366d31ed" +checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" dependencies = [ "itoa", "memchr", @@ -6747,7 +6747,7 @@ checksum = "6c64451ba24fc7a6a2d60fc75dd9c83c90903b19028d4eff35e88fc1e86564e9" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -6798,7 +6798,7 @@ dependencies = [ "darling 0.20.9", "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -7524,7 +7524,7 @@ dependencies = [ "expander", "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -8158,9 +8158,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.75" +version = "2.0.77" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f6af063034fc1935ede7be0122941bafa9bacb949334d090b77ca98b5817c7d9" +checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" dependencies = [ "proc-macro2", "quote", @@ -8202,7 +8202,7 @@ checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -8354,7 +8354,7 @@ version = "0.0.0" dependencies = [ "heck 0.4.1", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -8427,7 +8427,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -8527,7 +8527,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e500fad1dd3af3d626327e6a3fe5050e664a6eaa4708b8ca92f1794aaf73e6fd" dependencies = [ "aho-corasick", - "derive_builder 0.20.0", + "derive_builder 0.20.1", "esaxx-rs", "getrandom 0.2.12", "indicatif", @@ -8588,7 +8588,7 @@ checksum = "693d596312e88961bc67d7f1f97af8a70227d9f90c31bba5806eec004978d752" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -8870,7 +8870,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -8956,7 +8956,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97fee6b57c6a41524a810daee9286c02d7752c4253064d0b05472833a438f675" dependencies = [ "cfg-if", - "rand 0.7.3", + "rand 0.8.5", "static_assertions", ] @@ -9419,7 +9419,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", "wasm-bindgen-shared", ] @@ -9453,7 +9453,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -9829,7 +9829,7 @@ dependencies = [ "anyhow", "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", "wasmtime-component-util", "wasmtime-wit-bindgen", "wit-parser 0.209.1", @@ -9956,7 +9956,7 @@ checksum = "de5a9bc4f44ceeb168e9e8e3be4e0b4beb9095b468479663a9e24c667e36826f" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -10226,7 +10226,7 @@ dependencies = [ "proc-macro2", "quote", "shellexpand 2.1.2", - "syn 2.0.75", + "syn 2.0.77", "witx", ] @@ -10238,7 +10238,7 @@ checksum = "cc26129a8aea20b62c961d1b9ab4a3c3b56b10042ed85d004f8678af0f21ba6e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", "wiggle-generate", ] @@ -10777,7 +10777,7 @@ checksum = "28cc31741b18cb6f1d5ff12f5b7523e3d6eb0852bbbad19d73905511d9849b95" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", "synstructure 0.13.1", ] @@ -10864,7 +10864,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", ] [[package]] @@ -10884,7 +10884,7 @@ checksum = "0ea7b4a3637ea8669cedf0f1fd5c286a17f3de97b8dd5a70a6c167a1730e63a5" dependencies = [ "proc-macro2", "quote", - "syn 2.0.75", + "syn 2.0.77", "synstructure 0.13.1", ] diff --git a/crates/factor-llm/src/spin.rs b/crates/factor-llm/src/spin.rs index 00d1c2126c..ec0f8d6345 100644 --- a/crates/factor-llm/src/spin.rs +++ b/crates/factor-llm/src/spin.rs @@ -44,7 +44,6 @@ mod local { /// The default engine creator for the LLM factor when used in the Spin CLI. pub fn default_engine_creator( state_dir: Option, - use_gpu: bool, ) -> anyhow::Result { #[cfg(feature = "llm")] let engine = { @@ -53,11 +52,11 @@ pub fn default_engine_creator( Some(ref dir) => dir.clone(), None => std::env::current_dir().context("failed to get current working directory")?, }; - spin_llm_local::LocalLlmEngine::new(models_dir_parent.join("ai-models"), use_gpu) + spin_llm_local::LocalLlmEngine::new(models_dir_parent.join("ai-models")) }; #[cfg(not(feature = "llm"))] let engine = { - let _ = (state_dir, use_gpu); + let _ = (state_dir); noop::NoopLlmEngine }; let engine = Arc::new(Mutex::new(engine)) as Arc>; @@ -91,7 +90,6 @@ impl LlmEngine for RemoteHttpLlmEngine { pub fn runtime_config_from_toml( table: &impl GetTomlValue, state_dir: Option, - use_gpu: bool, ) -> anyhow::Result> { let Some(value) = table.get("llm_compute") else { return Ok(None); @@ -99,7 +97,7 @@ pub fn runtime_config_from_toml( let config: LlmCompute = value.clone().try_into()?; Ok(Some(RuntimeConfig { - engine: config.into_engine(state_dir, use_gpu)?, + engine: config.into_engine(state_dir)?, })) } @@ -111,19 +109,15 @@ pub enum LlmCompute { } impl LlmCompute { - fn into_engine( - self, - state_dir: Option, - use_gpu: bool, - ) -> anyhow::Result>> { + fn into_engine(self, state_dir: Option) -> anyhow::Result>> { let engine: Arc> = match self { #[cfg(not(feature = "llm"))] LlmCompute::Spin => { - let _ = (state_dir, use_gpu); + let _ = (state_dir); Arc::new(Mutex::new(noop::NoopLlmEngine)) } #[cfg(feature = "llm")] - LlmCompute::Spin => default_engine_creator(state_dir, use_gpu)?.create(), + LlmCompute::Spin => default_engine_creator(state_dir)?.create(), LlmCompute::RemoteHttp(config) => Arc::new(Mutex::new(RemoteHttpLlmEngine::new( config.url, config.auth_token, diff --git a/crates/llm-local/src/lib.rs b/crates/llm-local/src/lib.rs index fbe4bde2c2..2edbf9af5c 100644 --- a/crates/llm-local/src/lib.rs +++ b/crates/llm-local/src/lib.rs @@ -1,7 +1,6 @@ mod bert; mod llama; mod token_output_stream; -mod utils; use anyhow::Context; use bert::{BertModel, Config}; @@ -24,7 +23,6 @@ type ModelName = String; #[derive(Clone)] pub struct LocalLlmEngine { registry: PathBuf, - _use_gpu: bool, inferencing_models: HashMap>, embeddings_models: HashMap>, } @@ -61,7 +59,6 @@ impl LocalLlmEngine { prompt: String, params: wasi_llm::InferencingParams, ) -> Result { - // return self.inference(model).await; let model = self.inferencing_model(model).await?; model @@ -83,10 +80,9 @@ impl LocalLlmEngine { } impl LocalLlmEngine { - pub fn new(registry: PathBuf, _use_gpu: bool) -> Self { + pub fn new(registry: PathBuf) -> Self { Self { registry, - _use_gpu, inferencing_models: Default::default(), embeddings_models: Default::default(), } @@ -145,8 +141,6 @@ impl LocalLlmEngine { &mut self, model: wasi_llm::InferencingModel, ) -> Result, wasi_llm::Error> { - // let use_gpu = self.use_gpu; - let model = match self.inferencing_models.entry(model.clone()) { Entry::Occupied(o) => o.get().clone(), Entry::Vacant(v) => { diff --git a/crates/llm-local/src/llama.rs b/crates/llm-local/src/llama.rs index be59c1c413..612493b7c9 100644 --- a/crates/llm-local/src/llama.rs +++ b/crates/llm-local/src/llama.rs @@ -1,5 +1,5 @@ -use crate::{token_output_stream, utils::load_safetensors, CachedInferencingModel}; -use anyhow::{anyhow, Result}; +use crate::{token_output_stream, CachedInferencingModel}; +use anyhow::{anyhow, Context, Result}; use candle::{utils, Device, Tensor}; use candle_nn::VarBuilder; use candle_transformers::{ @@ -7,9 +7,10 @@ use candle_transformers::{ models::llama::{self, Cache, Config, Llama, LlamaConfig}, }; use rand::{RngCore, SeedableRng}; +use serde::Deserialize; use spin_core::async_trait; use spin_world::v2::llm::{self as wasi_llm, InferencingUsage}; -use std::{fs, path::Path, sync::Arc}; +use std::{collections::HashMap, fs, path::Path, sync::Arc}; use tokenizers::Tokenizer; const TOKENIZER_FILENAME: &str = "tokenizer.json"; @@ -166,3 +167,19 @@ impl CachedInferencingModel for LlamaModels { }) } } + +#[derive(Deserialize)] +struct SafeTensorsJson { + weight_map: HashMap, +} + +fn load_safetensors(model_dir: &Path, json_file: &str) -> Result> { + let json_file = model_dir.join(json_file); + let json_file = std::fs::File::open(&json_file) + .with_context(|| format!("Error while opening {json_file:?}"))?; + + let json: SafeTensorsJson = serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?; + let mut safetensors_files = Vec::new(); + safetensors_files.extend(json.weight_map.values().map(|v| model_dir.join(v))); + Ok(safetensors_files) +} diff --git a/crates/llm-local/src/token_output_stream.rs b/crates/llm-local/src/token_output_stream.rs index 34af97ca55..48e60ad0f2 100644 --- a/crates/llm-local/src/token_output_stream.rs +++ b/crates/llm-local/src/token_output_stream.rs @@ -1,6 +1,9 @@ -// Implementation for TokenOutputStream Code is borrow from -// https://github.com/huggingface/candle/blob/main/candle-examples/src/token_output_stream.rs -// (Commit SHA 4fd00b890036ef67391a9cc03f896247d0a75711) +/// Implementation for TokenOutputStream Code is borrowed from +/// https://github.com/huggingface/candle/blob/main/candle-examples/src/token_output_stream.rs +/// (Commit SHA 4fd00b890036ef67391a9cc03f896247d0a75711) +/// +/// This is a wrapper around a tokenizer to ensure that tokens can be returned to the user in a +/// streaming way rather than having to wait for the full decoding. pub struct TokenOutputStream { tokenizer: tokenizers::Tokenizer, tokens: Vec, @@ -18,10 +21,6 @@ impl TokenOutputStream { } } - pub fn _into_inner(self) -> tokenizers::Tokenizer { - self.tokenizer - } - fn decode(&self, tokens: &[u32]) -> anyhow::Result { match self.tokenizer.decode(tokens, true) { Ok(str) => Ok(str), @@ -40,10 +39,10 @@ impl TokenOutputStream { self.tokens.push(token); let text = self.decode(&self.tokens[self.prev_index..])?; if text.len() > prev_text.len() && text.chars().last().unwrap().is_alphanumeric() { - let text = text.split_at(prev_text.len()); + let (_, text) = text.split_at(prev_text.len()); self.prev_index = self.current_index; self.current_index = self.tokens.len(); - Ok(Some(text.1.to_string())) + Ok(Some(text.to_string())) } else { Ok(None) } @@ -64,22 +63,4 @@ impl TokenOutputStream { Ok(None) } } - - pub fn _decode_all(&self) -> anyhow::Result { - self.decode(&self.tokens) - } - - pub fn _get_token(&self, token_s: &str) -> Option { - self.tokenizer.get_vocab(true).get(token_s).copied() - } - - pub fn _tokenizer(&self) -> &tokenizers::Tokenizer { - &self.tokenizer - } - - pub fn _clear(&mut self) { - self.tokens.clear(); - self.prev_index = 0; - self.current_index = 0; - } } diff --git a/crates/llm-local/src/utils.rs b/crates/llm-local/src/utils.rs deleted file mode 100644 index d50d10ec81..0000000000 --- a/crates/llm-local/src/utils.rs +++ /dev/null @@ -1,25 +0,0 @@ -use candle::Result; -use std::path::Path; - -pub fn load_safetensors(model_dir: &Path, json_file: &str) -> Result> { - let json_file = model_dir.join(json_file); - let json_file = std::fs::File::open(json_file)?; - let json: serde_json::Value = - serde_json::from_reader(&json_file).map_err(candle::Error::wrap)?; - let weight_map = match json.get("weight_map") { - None => candle::bail!("no weight map in {json_file:?}"), - Some(serde_json::Value::Object(map)) => map, - Some(_) => candle::bail!("weight map in {json_file:?} is not a map"), - }; - let mut safetensors_files = std::collections::HashSet::new(); - for value in weight_map.values() { - if let Some(file) = value.as_str() { - safetensors_files.insert(file.to_string()); - } - } - let safetensors_files = safetensors_files - .iter() - .map(|v| model_dir.join(v)) - .collect::>(); - Ok(safetensors_files) -} diff --git a/crates/runtime-config/src/lib.rs b/crates/runtime-config/src/lib.rs index 3b7b8039b6..a64e256ad0 100644 --- a/crates/runtime-config/src/lib.rs +++ b/crates/runtime-config/src/lib.rs @@ -97,7 +97,6 @@ where local_app_dir: Option, provided_state_dir: UserProvidedPath, provided_log_dir: UserProvidedPath, - use_gpu: bool, ) -> anyhow::Result { let toml = match runtime_config_path { Some(runtime_config_path) => { @@ -119,14 +118,13 @@ where let toml_resolver = TomlResolver::new(&toml, local_app_dir, provided_state_dir, provided_log_dir); - Self::new(toml_resolver, runtime_config_path, use_gpu) + Self::new(toml_resolver, runtime_config_path) } /// Creates a new resolved runtime configuration from a TOML table. pub fn new( toml_resolver: TomlResolver<'_>, runtime_config_path: Option<&Path>, - use_gpu: bool, ) -> anyhow::Result { let runtime_config_dir = runtime_config_path .and_then(Path::parent) @@ -142,7 +140,6 @@ where &key_value_config_resolver, tls_resolver.as_ref(), &sqlite_config_resolver, - use_gpu, ); let runtime_config: T = source.try_into().map_err(Into::into)?; @@ -275,7 +272,6 @@ pub struct TomlRuntimeConfigSource<'a, 'b> { key_value: &'a key_value::RuntimeConfigResolver, tls: Option<&'a SpinTlsRuntimeConfig>, sqlite: &'a sqlite::RuntimeConfigResolver, - use_gpu: bool, } impl<'a, 'b> TomlRuntimeConfigSource<'a, 'b> { @@ -284,14 +280,12 @@ impl<'a, 'b> TomlRuntimeConfigSource<'a, 'b> { key_value: &'a key_value::RuntimeConfigResolver, tls: Option<&'a SpinTlsRuntimeConfig>, sqlite: &'a sqlite::RuntimeConfigResolver, - use_gpu: bool, ) -> Self { Self { toml: toml_resolver, key_value, tls, sqlite, - use_gpu, } } } @@ -338,7 +332,7 @@ impl FactorRuntimeConfigSource for TomlRuntimeConfigSource< impl FactorRuntimeConfigSource for TomlRuntimeConfigSource<'_, '_> { fn get_runtime_config(&mut self) -> anyhow::Result> { - llm::runtime_config_from_toml(&self.toml.table, self.toml.state_dir()?, self.use_gpu) + llm::runtime_config_from_toml(&self.toml.table, self.toml.state_dir()?) } } diff --git a/crates/runtime-factors/src/build.rs b/crates/runtime-factors/src/build.rs index 01772e858e..a1a297ed03 100644 --- a/crates/runtime-factors/src/build.rs +++ b/crates/runtime-factors/src/build.rs @@ -22,14 +22,11 @@ impl RuntimeFactorsBuilder for FactorsBuilder { config: &FactorsConfig, args: &Self::CliArgs, ) -> anyhow::Result<(Self::Factors, Self::RuntimeConfig)> { - // Hardcode `use_gpu` to true for now - let use_gpu = true; let runtime_config = ResolvedRuntimeConfig::::from_file( config.runtime_config_file.clone().as_deref(), config.local_app_dir.clone().map(PathBuf::from), config.state_dir.clone(), config.log_dir.clone(), - use_gpu, )?; runtime_config.summarize(config.runtime_config_file.as_deref()); @@ -40,7 +37,6 @@ impl RuntimeFactorsBuilder for FactorsBuilder { args.allow_transient_write, runtime_config.key_value_resolver.clone(), runtime_config.sqlite_resolver.clone(), - use_gpu, ) .context("failed to create factors")?; Ok((factors, runtime_config)) diff --git a/crates/runtime-factors/src/lib.rs b/crates/runtime-factors/src/lib.rs index 276c3c8360..7186278ae8 100644 --- a/crates/runtime-factors/src/lib.rs +++ b/crates/runtime-factors/src/lib.rs @@ -42,7 +42,6 @@ impl TriggerFactors { allow_transient_writes: bool, default_key_value_label_resolver: impl spin_factor_key_value::DefaultLabelResolver + 'static, default_sqlite_label_resolver: impl spin_factor_sqlite::DefaultLabelResolver + 'static, - use_gpu: bool, ) -> anyhow::Result { Ok(Self { wasi: wasi_factor(working_dir, allow_transient_writes), @@ -56,7 +55,7 @@ impl TriggerFactors { pg: OutboundPgFactor::new(), mysql: OutboundMysqlFactor::new(), llm: LlmFactor::new( - spin_factor_llm::spin::default_engine_creator(state_dir, use_gpu) + spin_factor_llm::spin::default_engine_creator(state_dir) .context("failed to configure LLM factor")?, ), })