Skip to content

Commit

Permalink
Replace rustformers/llm with candle
Browse files Browse the repository at this point in the history
Signed-off-by: karthik2804 <[email protected]>
  • Loading branch information
karthik2804 committed Sep 19, 2024
1 parent dada4d8 commit 4e40481
Show file tree
Hide file tree
Showing 9 changed files with 901 additions and 764 deletions.
1,113 changes: 594 additions & 519 deletions Cargo.lock

Large diffs are not rendered by default.

18 changes: 6 additions & 12 deletions crates/factor-llm/src/spin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ mod local {
/// The default engine creator for the LLM factor when used in the Spin CLI.
pub fn default_engine_creator(
state_dir: Option<PathBuf>,
use_gpu: bool,
) -> anyhow::Result<impl LlmEngineCreator + 'static> {
#[cfg(feature = "llm")]
let engine = {
Expand All @@ -53,11 +52,11 @@ pub fn default_engine_creator(
Some(ref dir) => dir.clone(),
None => std::env::current_dir().context("failed to get current working directory")?,
};
spin_llm_local::LocalLlmEngine::new(models_dir_parent.join("ai-models"), use_gpu)
spin_llm_local::LocalLlmEngine::new(models_dir_parent.join("ai-models"))
};
#[cfg(not(feature = "llm"))]
let engine = {
let _ = (state_dir, use_gpu);
let _ = (state_dir);
noop::NoopLlmEngine
};
let engine = Arc::new(Mutex::new(engine)) as Arc<Mutex<dyn LlmEngine>>;
Expand Down Expand Up @@ -91,15 +90,14 @@ impl LlmEngine for RemoteHttpLlmEngine {
pub fn runtime_config_from_toml(
table: &impl GetTomlValue,
state_dir: Option<PathBuf>,
use_gpu: bool,
) -> anyhow::Result<Option<RuntimeConfig>> {
let Some(value) = table.get("llm_compute") else {
return Ok(None);
};
let config: LlmCompute = value.clone().try_into()?;

Ok(Some(RuntimeConfig {
engine: config.into_engine(state_dir, use_gpu)?,
engine: config.into_engine(state_dir)?,
}))
}

Expand All @@ -111,19 +109,15 @@ pub enum LlmCompute {
}

impl LlmCompute {
fn into_engine(
self,
state_dir: Option<PathBuf>,
use_gpu: bool,
) -> anyhow::Result<Arc<Mutex<dyn LlmEngine>>> {
fn into_engine(self, state_dir: Option<PathBuf>) -> anyhow::Result<Arc<Mutex<dyn LlmEngine>>> {
let engine: Arc<Mutex<dyn LlmEngine>> = match self {
#[cfg(not(feature = "llm"))]
LlmCompute::Spin => {
let _ = (state_dir, use_gpu);
let _ = (state_dir);
Arc::new(Mutex::new(noop::NoopLlmEngine))
}
#[cfg(feature = "llm")]
LlmCompute::Spin => default_engine_creator(state_dir, use_gpu)?.create(),
LlmCompute::Spin => default_engine_creator(state_dir)?.create(),
LlmCompute::RemoteHttp(config) => Arc::new(Mutex::new(RemoteHttpLlmEngine::new(
config.url,
config.auth_token,
Expand Down
24 changes: 11 additions & 13 deletions crates/llm-local/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,29 @@ authors = { workspace = true }
edition = { workspace = true }

[dependencies]
anyhow = { workspace = true }
candle = { git = "https://github.com/huggingface/candle", rev = "b80348d22f8f0dadb6cc4101bde031d5de69a9a5", package = "candle-core" }
candle-nn = { git = "https://github.com/huggingface/candle", rev = "b80348d22f8f0dadb6cc4101bde031d5de69a9a5" }
chrono = "0.4"
llm = { git = "https://github.com/rustformers/llm", rev = "2f6ffd4435799ceaa1d1bcb5a8790e5b3e0c5663", features = [
"tokenizers-remote",
"llama",
], default-features = false }
lru = "0.12"
anyhow = "1.0"
candle = { git = "https://github.com/huggingface/candle", rev = "e3261216b157a7305c18ccdd766b6e2a41afe483", package = "candle-core" }
candle-nn = { git = "https://github.com/huggingface/candle", rev = "e3261216b157a7305c18ccdd766b6e2a41afe483" }
candle-transformers = { git = "https://github.com/huggingface/candle", rev = "e3261216b157a7305c18ccdd766b6e2a41afe483" }
chrono = "0.4.26"
lru = "0.9.0"
num_cpus = "1"
rand = { workspace = true }
safetensors = "0.3.3"
serde = { workspace = true }
serde_json = "1.0.125"
spin-common = { path = "../common" }
spin-core = { path = "../core" }
spin-world = { path = "../world" }
terminal = { path = "../terminal" }
tokenizers = "0.13.4"
tokio = { version = "1", features = ["macros", "sync"] }
tokenizers = "0.19.1"
tokio = { version = "1.32.0", features = ["macros", "sync"] }
tracing = { workspace = true }

[features]
default = []
metal = ["llm/metal"]
cublas = ["llm/cublas"]
metal = ["candle/metal", "candle-nn/metal", "candle-transformers/metal"]
cublas = ["candle/cuda", "candle-nn/cuda", "candle-transformers/cuda"]

[lints]
workspace = true
2 changes: 1 addition & 1 deletion crates/llm-local/src/bert.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
///
/// TODO: Remove this file when a new release of Candle makes it obsolete.
use anyhow::{bail, Result};
use candle::{DType, Tensor};
use candle::{DType, Module, Tensor};
use candle_nn::{Embedding, VarBuilder};
use serde::Deserialize;

Expand Down
Loading

0 comments on commit 4e40481

Please sign in to comment.