Merge pull request #29 from bokutotu/feat/dropout
Feat/dropout
bokutotu authored Aug 7, 2024
2 parents 3849e4a + 412087b commit f48f88f
Showing 11 changed files with 1,026 additions and 4 deletions.
5 changes: 2 additions & 3 deletions zenu-autograd/src/functions/conv2d.rs
@@ -351,11 +351,10 @@ pub fn conv2d<T: Num, D: Device>(
     config: Option<Conv2dConfigs<T>>,
 ) -> Variable<T, D> {
     let y = conv2d_inner(x, filter, stride, padding, config);
-    let y = match bias {
+    match bias {
         Some(bias) => conv2d_bias(y, bias),
         None => y,
-    };
-    y
+    }
 }
 
 fn deconv2d_inner<T: Num, D: Device>(
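Since `match` is an expression in Rust, the intermediate `let y = ...; y` binding was redundant; the match result can be the function's return value directly. A standalone sketch of the same pattern (the `apply` and `add_bias` names are hypothetical, not from this crate):

// Minimal illustration of the refactor above: `match` is an expression,
// so its value can be returned directly with no intermediate binding.
fn add_bias(y: i32, b: i32) -> i32 {
    y + b
}

fn apply(y: i32, bias: Option<i32>) -> i32 {
    match bias {
        Some(b) => add_bias(y, b), // analogous to conv2d_bias(y, bias)
        None => y,
    }
}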
194 changes: 194 additions & 0 deletions zenu-autograd/src/functions/dropout.rs
@@ -0,0 +1,194 @@
use std::{cell::RefCell, rc::Rc};

use zenu_matrix::{
    device::Device,
    dim::DimDyn,
    nn::dropout::{dropout as forward, dropout_grad as grad, DropoutState},
    num::Num,
};

use crate::{creator::alloc::alloc, is_train, Function, Variable, VariableWeak};

#[derive(Clone)]
pub struct DropoutConfig<T: Num, D: Device> {
    inner: Rc<RefCell<DropoutState<T, D>>>,
}

impl<T: Num, D: Device> DropoutConfig<T, D> {
    pub fn new(rate: f32) -> Self {
        let inner = Rc::new(RefCell::new(DropoutState::new(rate)));
        Self { inner }
    }

    pub fn gpu_init(&self, shape: DimDyn) {
        self.inner.borrow_mut().gpu_init(shape);
    }
}

struct DropoutForward<T: Num, D: Device> {
    config: DropoutConfig<T, D>,
    input: Variable<T, D>,
    output: VariableWeak<T, D>,
}

struct DropoutBackward<T: Num, D: Device> {
    config: DropoutConfig<T, D>,
    output_grad: Variable<T, D>,
    input_grad: VariableWeak<T, D>,
}

impl<T: Num, D: Device> Function<T, D> for DropoutForward<T, D> {
    fn forward(&self) {
        let input = self.input.get_data();
        let output = self.output.upgrade().unwrap();
        let mut output = output.get_data_mut();
        let mut config = self.config.inner.borrow_mut();
        output.to_ref_mut().copy_from(&forward(&input, &mut config));
    }

    fn backward(&self) {
        let output = self.output.upgrade().unwrap();
        let output_grad = output.get_grad().unwrap();
        let input_grad = dropout_backward(output_grad, self.config.clone());
        self.input.set_grad(input_grad);
    }

    fn get_inputs(&self) -> Vec<Variable<T, D>> {
        vec![self.input.clone()]
    }
}

impl<T: Num, D: Device> Function<T, D> for DropoutBackward<T, D> {
    fn forward(&self) {
        let output_grad = self.output_grad.get_data();
        let input_grad = self.input_grad.upgrade().unwrap();
        let mut input_grad = input_grad.get_data_mut();
        let config = self.config.inner.borrow();
        input_grad
            .to_ref_mut()
            .copy_from(&grad(&output_grad, &config));
    }

    fn backward(&self) {
        todo!();
    }

    fn get_inputs(&self) -> Vec<Variable<T, D>> {
        vec![self.output_grad.clone()]
    }
}

pub fn dropout<T: Num, D: Device>(
    input: Variable<T, D>,
    rate: f32,
    config: Option<DropoutConfig<T, D>>,
) -> Variable<T, D> {
    if !is_train() {
        return input;
    }
    let output = alloc(input.get_shape());

    let config = match config {
        Some(config) => config,
        None => DropoutConfig::new(rate),
    };

    let dropout = DropoutForward {
        config,
        input,
        output: output.clone().downgrade(),
    };

    dropout.forward();

    output.set_creator(Rc::new(RefCell::new(Box::new(dropout))));

    output.set_name("dropout");

    output
}

fn dropout_backward<T: Num, D: Device>(
    output_grad: Variable<T, D>,
    config: DropoutConfig<T, D>,
) -> Variable<T, D> {
    let input_grad = alloc(output_grad.get_shape());

    let dropout = DropoutBackward {
        config,
        output_grad,
        input_grad: input_grad.clone().downgrade(),
    };

    dropout.forward();

    input_grad.set_creator(Rc::new(RefCell::new(Box::new(dropout))));

    input_grad.set_name("dropout_grad");

    input_grad
}

#[cfg(test)]
mod dropout {
    use zenu_matrix::device::{cpu::Cpu, Device};
    use zenu_test::run_test;

    use crate::creator::rand::normal;

    use super::dropout;

    fn dropout_4d_train<D: Device>() {
        let input = normal::<f32, _, D>(1f32, 1f32, None, [3, 3, 3, 3]);
        let output = dropout(input.clone(), 0.8, None);
        output.backward();

        let input_mat_cpu = {
            let input = input.get_data().clone();
            input.to::<Cpu>()
        };
        let output_mat_cpu = {
            let output = output.get_data().clone();
            output.to::<Cpu>()
        };

        let mask = {
            let s = output_mat_cpu.as_slice();
            s.iter().map(|&x| (x != 0f32)).collect::<Vec<_>>()
        };

        let output_slice = output_mat_cpu.as_slice();
        let input_slice = input_mat_cpu.as_slice();

        for idx in 0..output_slice.len() {
            if !mask[idx] {
                assert_eq!(output_slice[idx], 0f32);
            } else {
                let diff = output_slice[idx] - input_slice[idx] / 0.2;
                assert!(
                    diff.abs() < 1e-5,
                    "idx : {} output : {} input slice: {} diff :{}",
                    idx,
                    output_slice[idx],
                    input_slice[idx],
                    diff
                );
            }
        }

        let input_grad = input.get_grad().unwrap();
        let input_grad_cpu = {
            let input_grad = input_grad.get_data().clone();
            input_grad.to::<Cpu>()
        };

        for idx in 0..output_slice.len() {
            if !mask[idx] {
                assert_eq!(input_grad_cpu.as_slice()[idx], 0f32);
            } else {
                assert_eq!(input_grad_cpu.as_slice()[idx], 1f32 / (1f32 - 0.8));
            }
        }
    }
    run_test!(dropout_4d_train, dropout_4d_train_cpu, dropout_4d_train_gpu);
}
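A note on the test's constants: with `rate = 0.8`, inverted dropout scales each surviving element by 1 / (1 - 0.8) = 5, so the test compares the output against `input / 0.2` and expects a gradient of exactly 5 for kept elements and 0 for dropped ones. The sketch below shows how the new API might be called from outside the crate; the module paths are assumed from the in-crate imports in this diff, and it presumes a training-mode context, since `dropout` returns its input unchanged when `is_train()` is false:

// Usage sketch only; paths assumed from the crate-internal imports above.
use zenu_autograd::creator::rand::normal;
use zenu_autograd::functions::dropout::{dropout, DropoutConfig};
use zenu_matrix::device::cpu::Cpu;

fn main() {
    // Random 4-D input, as in the test: mean 1.0, std 1.0, no fixed seed.
    let input = normal::<f32, _, Cpu>(1f32, 1f32, None, [3, 3, 3, 3]);

    // Passing `Some(config)` reuses one DropoutState across calls
    // (useful for the GPU path, which `gpu_init` prepares);
    // passing `None` builds a fresh config from `rate`.
    let config = DropoutConfig::<f32, Cpu>::new(0.5);
    let output = dropout(input.clone(), 0.5, Some(config));

    // Backward fills `input.get_grad()`: zero where elements were dropped,
    // 1 / (1 - rate) where they were kept.
    output.backward();
    let _grad = input.get_grad().unwrap();
}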
1 change: 1 addition & 0 deletions zenu-autograd/src/functions/mod.rs
@@ -17,6 +17,7 @@ pub mod broadcast;
 pub mod clip;
 pub mod conv2d;
 pub mod cosh;
+pub mod dropout;
 pub mod exp;
 pub mod flatten;
 pub mod log;
