Improve the mnist training example. #276

Merged (4 commits, Jul 29, 2023)
44 changes: 30 additions & 14 deletions candle-core/src/device.rs
@@ -116,46 +116,62 @@ impl Device {
}
}

pub(crate) fn rand_uniform<T: crate::FloatDType>(
pub(crate) fn rand_uniform_f64(
&self,
lo: T,
up: T,
lo: f64,
up: f64,
shape: &Shape,
dtype: DType,
) -> Result<Storage> {
let lo = lo.to_f64();
let up = up.to_f64();
match self {
Device::Cpu => {
let storage = CpuDevice.rand_uniform(shape, T::DTYPE, lo, up)?;
let storage = CpuDevice.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_uniform(shape, T::DTYPE, lo, up)?;
let storage = device.rand_uniform(shape, dtype, lo, up)?;
Ok(Storage::Cuda(storage))
}
}
}

pub(crate) fn rand_normal<T: crate::FloatDType>(
pub(crate) fn rand_uniform<T: crate::FloatDType>(
&self,
mean: T,
std: T,
lo: T,
up: T,
shape: &Shape,
) -> Result<Storage> {
self.rand_uniform_f64(lo.to_f64(), up.to_f64(), shape, T::DTYPE)
}

pub(crate) fn rand_normal_f64(
&self,
mean: f64,
std: f64,
shape: &Shape,
dtype: DType,
) -> Result<Storage> {
let mean = mean.to_f64();
let std = std.to_f64();
match self {
Device::Cpu => {
let storage = CpuDevice.rand_normal(shape, T::DTYPE, mean, std)?;
let storage = CpuDevice.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cpu(storage))
}
Device::Cuda(device) => {
let storage = device.rand_normal(shape, T::DTYPE, mean, std)?;
let storage = device.rand_normal(shape, dtype, mean, std)?;
Ok(Storage::Cuda(storage))
}
}
}

pub(crate) fn rand_normal<T: crate::FloatDType>(
&self,
mean: T,
std: T,
shape: &Shape,
) -> Result<Storage> {
self.rand_normal_f64(mean.to_f64(), std.to_f64(), shape, T::DTYPE)
}

pub(crate) fn ones(&self, shape: &Shape, dtype: DType) -> Result<Storage> {
match self {
Device::Cpu => {
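
This device.rs change splits each generic sampler into a non-generic `*_f64` core that takes the target `DType` as a value, keeping the old generic methods as thin wrappers; that is what later lets the `Init` enum choose a dtype at run time. A minimal standalone sketch of the pattern (the types below are stand-ins, not candle's own):

```rust
// Standalone sketch of the refactor, with stand-in types rather than candle's:
// a non-generic f64-based core keyed on a runtime `DType`, plus a thin generic
// wrapper that keeps the original statically-typed call sites working.
#[derive(Clone, Copy, Debug)]
enum DType {
    F32,
    F64,
}

trait FloatDType: Copy {
    const DTYPE: DType;
    fn to_f64(self) -> f64;
}

impl FloatDType for f32 {
    const DTYPE: DType = DType::F32;
    fn to_f64(self) -> f64 {
        self as f64
    }
}

impl FloatDType for f64 {
    const DTYPE: DType = DType::F64;
    fn to_f64(self) -> f64 {
        self
    }
}

// Non-generic core: all the sampling logic lives here.
fn rand_uniform_f64(lo: f64, up: f64, dtype: DType) {
    println!("sample U({lo}, {up}) as {dtype:?}");
}

// Thin generic wrapper, matching the shape of the original API.
fn rand_uniform<T: FloatDType>(lo: T, up: T) {
    rand_uniform_f64(lo.to_f64(), up.to_f64(), T::DTYPE)
}

fn main() {
    rand_uniform(0f32, 1f32); // dtype inferred from T
    rand_uniform_f64(-1.0, 1.0, DType::F64); // dtype chosen at runtime
}
```
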
38 changes: 38 additions & 0 deletions candle-core/src/tensor.rs
@@ -245,6 +245,20 @@ impl Tensor {
Ok(from_storage(storage, s, none, is_variable))
}

pub(crate) fn rand_f64_impl<S: Into<Shape>>(
lo: f64,
up: f64,
s: S,
dtype: DType,
device: &Device,
is_variable: bool,
) -> Result<Self> {
let s = s.into();
let storage = device.rand_uniform_f64(lo, up, &s, dtype)?;
let none = BackpropOp::none();
Ok(from_storage(storage, s, none, is_variable))
}

/// Creates a new tensor initialized with values sampled uniformly between `lo` and `up`.
pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
lo: T,
@@ -268,6 +282,20 @@ impl Tensor {
Ok(from_storage(storage, s, none, is_variable))
}

pub(crate) fn randn_f64_impl<S: Into<Shape>>(
mean: f64,
std: f64,
s: S,
dtype: DType,
device: &Device,
is_variable: bool,
) -> Result<Self> {
let s = s.into();
let storage = device.rand_normal_f64(mean, std, &s, dtype)?;
let none = BackpropOp::none();
Ok(from_storage(storage, s, none, is_variable))
}

/// Creates a new tensor initialized with values sampled from a normal distribution with the
/// specified `mean` and standard deviation `std`.
pub fn randn<S: Into<Shape>, T: crate::FloatDType>(
@@ -1448,6 +1476,16 @@ impl Tensor {
}
}

/// Create a variable based on the values currently stored in a tensor. The storage is always
/// copied.
pub(crate) fn make_var(&self) -> Result<Tensor> {
let shape = self.shape().clone();
let mut storage = self.device().zeros(&shape, self.dtype())?;
self.storage()
.copy_strided_src(&mut storage, 0, self.layout())?;
Ok(from_storage(storage, shape, BackpropOp::none(), true))
}

// TODO: Do we want to allow target shape using -1 on some dimensions?
/// Reshape returns a tensor with the target shape provided that the number of elements of the
/// original tensor is the same.
27 changes: 27 additions & 0 deletions candle-core/src/variable.rs
@@ -34,6 +34,33 @@ impl Var {
Ok(Self(inner))
}

pub fn from_tensor(t: &Tensor) -> Result<Self> {
let inner = t.make_var()?;
Ok(Self(inner))
}

pub fn rand_f64<S: Into<Shape>>(
lo: f64,
up: f64,
s: S,
dtype: DType,
device: &Device,
) -> Result<Self> {
let inner = Tensor::rand_f64_impl(lo, up, s, dtype, device, true)?;
Ok(Self(inner))
}

pub fn randn_f64<S: Into<Shape>>(
mean: f64,
std: f64,
s: S,
dtype: DType,
device: &Device,
) -> Result<Self> {
let inner = Tensor::randn_f64_impl(mean, std, s, dtype, device, true)?;
Ok(Self(inner))
}

pub fn rand<S: Into<Shape>, T: crate::FloatDType>(
lo: T,
up: T,
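
The new `Var` constructors (`from_tensor`, `rand_f64`, `randn_f64`) become the building blocks for `Init::var` further down. A hedged usage sketch written against the signatures shown in this diff; exact paths and helpers may differ in other candle versions:

```rust
use candle::{DType, Device, Result, Tensor, Var};

fn main() -> Result<()> {
    let dev = Device::Cpu;

    // Uniform and normal initialization with the dtype chosen at run time.
    let w = Var::rand_f64(-0.1, 0.1, (10, 784), DType::F32, &dev)?;
    let b = Var::randn_f64(0.0, 0.02, 10, DType::F32, &dev)?;

    // Promote an existing tensor to a trainable variable; the storage is copied.
    let t = Tensor::zeros(10, DType::F32, &dev)?;
    let v = Var::from_tensor(&t)?;

    println!(
        "{:?} {:?} {:?}",
        w.as_tensor().shape(),
        b.as_tensor().shape(),
        v.as_tensor().shape()
    );
    Ok(())
}
```
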
88 changes: 64 additions & 24 deletions candle-examples/examples/simple-training/main.rs
@@ -2,8 +2,10 @@
#[cfg(feature = "mkl")]
extern crate intel_mkl_src;

use clap::{Parser, ValueEnum};

use candle::{DType, Device, Result, Shape, Tensor, Var, D};
use candle_nn::{loss, ops, Linear};
use candle_nn::{loss, ops, Init, Linear};
use std::sync::{Arc, Mutex};

const IMAGE_DIM: usize = 784;
@@ -44,7 +46,7 @@ impl VarStore {
}
}

fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str) -> Result<Tensor> {
fn get<S: Into<Shape>>(&self, shape: S, tensor_name: &str, init: Init) -> Result<Tensor> {
let shape = shape.into();
let path = if self.path.is_empty() {
tensor_name.to_string()
@@ -59,8 +61,7 @@ impl VarStore {
}
return Ok(tensor.as_tensor().clone());
}
// TODO: Proper initialization using the `Init` enum.
let var = Var::zeros(shape, tensor_data.dtype, &tensor_data.device)?;
let var = init.var(shape, tensor_data.dtype, &tensor_data.device)?;
let tensor = var.as_tensor().clone();
tensor_data.tensors.insert(path, var);
Ok(tensor)
@@ -77,21 +78,36 @@ impl VarStore {
}
}

fn linear(dim1: usize, dim2: usize, vs: VarStore) -> Result<Linear> {
let ws = vs.get((dim2, dim1), "weight")?;
let bs = vs.get(dim2, "bias")?;
fn linear_z(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
let ws = vs.get((out_dim, in_dim), "weight", candle_nn::init::ZERO)?;
let bs = vs.get(out_dim, "bias", candle_nn::init::ZERO)?;
Ok(Linear::new(ws, Some(bs)))
}

fn linear(in_dim: usize, out_dim: usize, vs: VarStore) -> Result<Linear> {
let init_ws = candle_nn::init::DEFAULT_KAIMING_NORMAL;
let ws = vs.get((out_dim, in_dim), "weight", init_ws)?;
let bound = 1. / (in_dim as f64).sqrt();
let init_bs = Init::Uniform {
lo: -bound,
up: bound,
};
let bs = vs.get(out_dim, "bias", init_bs)?;
Ok(Linear::new(ws, Some(bs)))
}

#[allow(unused)]
trait Model: Sized {
fn new(vs: VarStore) -> Result<Self>;
fn forward(&self, xs: &Tensor) -> Result<Tensor>;
}

struct LinearModel {
linear: Linear,
}

#[allow(unused)]
impl LinearModel {
impl Model for LinearModel {
fn new(vs: VarStore) -> Result<Self> {
let linear = linear(IMAGE_DIM, LABELS, vs)?;
let linear = linear_z(IMAGE_DIM, LABELS, vs)?;
Ok(Self { linear })
}

@@ -100,14 +116,12 @@ impl LinearModel {
}
}

#[allow(unused)]
struct Mlp {
ln1: Linear,
ln2: Linear,
}

#[allow(unused)]
impl Mlp {
impl Model for Mlp {
fn new(vs: VarStore) -> Result<Self> {
let ln1 = linear(IMAGE_DIM, 100, vs.pp("ln1"))?;
let ln2 = linear(100, LABELS, vs.pp("ln2"))?;
@@ -121,26 +135,22 @@ impl Mlp {
}
}

pub fn main() -> anyhow::Result<()> {
fn training_loop<M: Model>(
m: candle_nn::vision::Dataset,
learning_rate: f64,
) -> anyhow::Result<()> {
let dev = candle::Device::cuda_if_available(0)?;

// Load the dataset
let m = candle_nn::vision::mnist::load_dir("data")?;
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape());
println!("test-labels: {:?}", m.test_labels.shape());
let train_labels = m.train_labels;
let train_images = m.train_images;
let train_labels = train_labels.to_dtype(DType::U32)?.unsqueeze(1)?;

let vs = VarStore::new(DType::F32, dev);
let model = LinearModel::new(vs.clone())?;
// let model = Mlp::new(vs)?;
let model = M::new(vs.clone())?;

let all_vars = vs.all_vars();
let all_vars = all_vars.iter().collect::<Vec<_>>();
let sgd = candle_nn::SGD::new(&all_vars, 1.0);
let sgd = candle_nn::SGD::new(&all_vars, learning_rate);
let test_images = m.test_images;
let test_labels = m.test_labels.to_dtype(DType::U32)?;
for epoch in 1..200 {
Expand All @@ -165,3 +175,33 @@ pub fn main() -> anyhow::Result<()> {
}
Ok(())
}

#[derive(ValueEnum, Clone)]
enum WhichModel {
Linear,
Mlp,
}

#[derive(Parser)]
struct Args {
#[clap(value_enum, default_value_t = WhichModel::Linear)]
model: WhichModel,

#[arg(long)]
learning_rate: Option<f64>,
}

pub fn main() -> anyhow::Result<()> {
let args = Args::parse();
// Load the dataset
let m = candle_nn::vision::mnist::load_dir("data")?;
println!("train-images: {:?}", m.train_images.shape());
println!("train-labels: {:?}", m.train_labels.shape());
println!("test-images: {:?}", m.test_images.shape());
println!("test-labels: {:?}", m.test_labels.shape());

match args.model {
WhichModel::Linear => training_loop::<LinearModel>(m, args.learning_rate.unwrap_or(1.)),
WhichModel::Mlp => training_loop::<Mlp>(m, args.learning_rate.unwrap_or(0.01)),
}
}
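
The entry point is now split from the training loop: a small `Model` trait lets `training_loop::<M>` be monomorphized for either network, and clap picks the variant and learning rate (defaulting to 1.0 for the linear model and 0.01 for the MLP). A standalone sketch of that dispatch pattern, with toy types in place of the candle ones:

```rust
// Toy sketch of the trait-based dispatch the reworked example uses:
// one generic training loop, instantiated per model chosen on the CLI.
trait Model: Sized {
    fn new() -> Self;
    fn forward(&self, x: f32) -> f32;
}

struct LinearModel {
    w: f32,
}

struct Mlp {
    w1: f32,
    w2: f32,
}

impl Model for LinearModel {
    fn new() -> Self {
        Self { w: 0.0 }
    }
    fn forward(&self, x: f32) -> f32 {
        self.w * x
    }
}

impl Model for Mlp {
    fn new() -> Self {
        Self { w1: 0.1, w2: 0.1 }
    }
    fn forward(&self, x: f32) -> f32 {
        (self.w1 * x).max(0.0) * self.w2
    }
}

fn training_loop<M: Model>(learning_rate: f64) {
    let model = M::new();
    let _y = model.forward(1.0);
    println!("training with lr = {learning_rate}");
}

fn main() {
    training_loop::<LinearModel>(1.0);
    training_loop::<Mlp>(0.01);
}
```
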
40 changes: 36 additions & 4 deletions candle-nn/src/init.rs
@@ -1,7 +1,7 @@
//! Variable initialization.
// This is based on:
// https://github.com/pytorch/pytorch/blob/07107919297db3f8ab37f11c12666b6d6d5f692e/torch/nn/init.py#
use candle::Shape;
use candle::{DType, Device, Result, Shape, Tensor, Var};

/// Number of features as input or output of a layer.
/// In Kaiming initialization, choosing `FanIn` preserves
@@ -91,11 +91,11 @@ pub enum Init {
fan: FanInOut,
non_linearity: NonLinearity,
},

/// Orthogonal initialization
Orthogonal { gain: f64 },
}

pub const ZERO: Init = Init::Const(0.);
pub const ONE: Init = Init::Const(1.);

pub const DEFAULT_KAIMING_UNIFORM: Init = Init::Kaiming {
dist: NormalOrUniform::Uniform,
fan: FanInOut::FanIn,
@@ -107,3 +107,35 @@ pub const DEFAULT_KAIMING_NORMAL: Init = Init::Kaiming {
fan: FanInOut::FanIn,
non_linearity: NonLinearity::ReLU,
};

impl Init {
/// Creates a new tensor with the specified shape, device, and initialization.
pub fn var<S: Into<Shape>>(&self, s: S, dtype: DType, device: &Device) -> Result<Var> {
match self {
Self::Const(v) if *v == 0. => Var::zeros(s, dtype, device),
Self::Const(v) if *v == 1. => Var::ones(s, dtype, device),
Self::Const(cst) => {
Var::from_tensor(&Tensor::ones(s, dtype, device)?.affine(*cst, 0.)?)
}
Self::Uniform { lo, up } => Var::rand_f64(*lo, *up, s, dtype, device),
Self::Randn { mean, stdev } => Var::randn_f64(*mean, *stdev, s, dtype, device),
Self::Kaiming {
dist,
fan,
non_linearity,
} => {
let s = s.into();
let fan = fan.for_shape(&s);
let gain = non_linearity.gain();
let std = gain / (fan as f64).sqrt();
match dist {
NormalOrUniform::Uniform => {
let bound = 3f64.sqrt() * std;
Var::rand_f64(-bound, bound, s, dtype, device)
}
NormalOrUniform::Normal => Var::randn_f64(0., std, s, dtype, device),
}
}
}
}
}
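
The Kaiming branch computes `std = gain / sqrt(fan)` and, for the uniform variant, a bound of `sqrt(3) * std` so the resulting uniform distribution keeps that standard deviation. A hedged usage sketch of `Init::var` with the constants added here, written against the API as it appears in this diff:

```rust
use candle::{DType, Device, Result};
use candle_nn::Init;

fn main() -> Result<()> {
    let dev = Device::Cpu;

    // Kaiming-normal weights for a 784 -> 100 layer: std = gain / sqrt(fan_in).
    let w = candle_nn::init::DEFAULT_KAIMING_NORMAL.var((100, 784), DType::F32, &dev)?;

    // Uniform bias in [-1/sqrt(fan_in), 1/sqrt(fan_in)], mirroring the example's
    // `linear` helper.
    let bound = 1. / (784f64).sqrt();
    let b = Init::Uniform { lo: -bound, up: bound }.var(100, DType::F32, &dev)?;

    println!("{:?} {:?}", w.as_tensor().shape(), b.as_tensor().shape());
    Ok(())
}
```
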
1 change: 1 addition & 0 deletions candle-nn/src/lib.rs
@@ -15,6 +15,7 @@ pub mod vision;
pub use activation::Activation;
pub use conv::{Conv1d, Conv1dConfig};
pub use embedding::Embedding;
pub use init::Init;
pub use layer_norm::LayerNorm;
pub use linear::Linear;
pub use optim::SGD;