diff --git a/candle-examples/examples/simple-training/main.rs b/candle-examples/examples/simple-training/main.rs index edec2e9270..35b938e8da 100644 --- a/candle-examples/examples/simple-training/main.rs +++ b/candle-examples/examples/simple-training/main.rs @@ -1,16 +1,130 @@ -// This should rearch 91.5% accuracy. +// This should reach 91.5% accuracy. #[cfg(feature = "mkl")] extern crate intel_mkl_src; -use anyhow::Result; -use candle::{DType, Var, D}; -use candle_nn::{loss, ops}; +use candle::{DType, Device, Result, Shape, Tensor, Var, D}; +use candle_nn::{loss, ops, Linear}; +use std::sync::{Arc, Mutex}; const IMAGE_DIM: usize = 784; const LABELS: usize = 10; -pub fn main() -> Result<()> { +struct TensorData { + tensors: std::collections::HashMap, + pub dtype: DType, + pub device: Device, +} + +// A variant of candle_nn::VarBuilder for initializing variables before training. +#[derive(Clone)] +struct VarStore { + data: Arc>, + path: Vec, +} + +impl VarStore { + fn new(dtype: DType, device: Device) -> Self { + let data = TensorData { + tensors: std::collections::HashMap::new(), + dtype, + device, + }; + Self { + data: Arc::new(Mutex::new(data)), + path: vec![], + } + } + + fn pp(&self, s: &str) -> Self { + let mut path = self.path.clone(); + path.push(s.to_string()); + Self { + data: self.data.clone(), + path, + } + } + + fn get>(&self, shape: S, tensor_name: &str) -> Result { + let shape = shape.into(); + let path = if self.path.is_empty() { + tensor_name.to_string() + } else { + [&self.path.join("."), tensor_name].join(".") + }; + let mut tensor_data = self.data.lock().unwrap(); + if let Some(tensor) = tensor_data.tensors.get(&path) { + let tensor_shape = tensor.shape(); + if &shape != tensor_shape { + candle::bail!("shape mismatch on {path}: {shape:?} <> {tensor_shape:?}") + } + return Ok(tensor.as_tensor().clone()); + } + // TODO: Proper initialization using the `Init` enum. + let var = Var::zeros(shape, tensor_data.dtype, &tensor_data.device)?; + let tensor = var.as_tensor().clone(); + tensor_data.tensors.insert(path, var); + Ok(tensor) + } + + fn all_vars(&self) -> Vec { + let tensor_data = self.data.lock().unwrap(); + #[allow(clippy::map_clone)] + tensor_data + .tensors + .values() + .map(|c| c.clone()) + .collect::>() + } +} + +fn linear(dim1: usize, dim2: usize, vs: VarStore) -> Result { + let ws = vs.get((dim2, dim1), "weight")?; + let bs = vs.get(dim2, "bias")?; + Ok(Linear::new(ws, Some(bs))) +} + +#[allow(unused)] +struct LinearModel { + linear: Linear, +} + +#[allow(unused)] +impl LinearModel { + fn new(vs: VarStore) -> Result { + let linear = linear(IMAGE_DIM, LABELS, vs)?; + Ok(Self { linear }) + } + + fn forward(&self, xs: &Tensor) -> Result { + self.linear.forward(xs) + } +} + +#[allow(unused)] +struct Mlp { + ln1: Linear, + ln2: Linear, +} + +#[allow(unused)] +impl Mlp { + fn new(vs: VarStore) -> Result { + let ln1 = linear(IMAGE_DIM, 100, vs.pp("ln1"))?; + let ln2 = linear(100, LABELS, vs.pp("ln2"))?; + Ok(Self { ln1, ln2 }) + } + + fn forward(&self, xs: &Tensor) -> Result { + let xs = self.ln1.forward(xs)?; + let xs = xs.relu()?; + self.ln2.forward(&xs) + } +} + +pub fn main() -> anyhow::Result<()> { let dev = candle::Device::cuda_if_available(0)?; + + // Load the dataset let m = candle_nn::vision::mnist::load_dir("data")?; println!("train-images: {:?}", m.train_images.shape()); println!("train-labels: {:?}", m.train_labels.shape()); @@ -19,18 +133,23 @@ pub fn main() -> Result<()> { let train_labels = m.train_labels; let train_images = m.train_images; let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?; - let ws = Var::zeros((IMAGE_DIM, LABELS), DType::F32, &dev)?; - let bs = Var::zeros(LABELS, DType::F32, &dev)?; - let sgd = candle_nn::SGD::new(&[&ws, &bs], 1.0); + + let vs = VarStore::new(DType::F32, dev); + let model = LinearModel::new(vs.clone())?; + // let model = Mlp::new(vs)?; + + let all_vars = vs.all_vars(); + let all_vars = all_vars.iter().collect::>(); + let sgd = candle_nn::SGD::new(&all_vars, 1.0); let test_images = m.test_images; let test_labels = m.test_labels.to_dtype(DType::U32)?; for epoch in 1..200 { - let logits = train_images.matmul(&ws)?.broadcast_add(&bs)?; + let logits = model.forward(&train_images)?; let log_sm = ops::log_softmax(&logits, D::Minus1)?; let loss = loss::nll(&log_sm, &train_labels)?; sgd.backward_step(&loss)?; - let test_logits = test_images.matmul(&ws)?.broadcast_add(&bs)?; + let test_logits = model.forward(&test_images)?; let sum_ok = test_logits .argmax(D::Minus1)? .eq(&test_labels)? diff --git a/candle-nn/src/var_builder.rs b/candle-nn/src/var_builder.rs index 5c222bf6f5..be1380b70c 100644 --- a/candle-nn/src/var_builder.rs +++ b/candle-nn/src/var_builder.rs @@ -209,6 +209,7 @@ impl<'a> VarBuilder<'a> { }; Ok(tensor) } + pub fn get>(&self, s: S, tensor_name: &str) -> Result { let data = self.data.as_ref(); let s: Shape = s.into();