Skip to content

Commit

Permalink
Merge pull request #592 from chachaleo/refactor/NN
Browse files Browse the repository at this point in the history
refactor nn: helpers and enum
  • Loading branch information
raphaelDkhn authored Apr 22, 2024
2 parents 62d55e9 + f0e7287 commit 2a2c8e5
Show file tree
Hide file tree
Showing 23 changed files with 401 additions and 760 deletions.
3 changes: 2 additions & 1 deletion src/operators/nn.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ mod core;
mod implementations;
mod functional;
mod common;
mod helpers;

use orion::operators::nn::common::{AUTO_PAD, POOLING_TYPE};
use orion::operators::nn::common::{AUTO_PAD, MODE, PADDING_MODE, POOLING_TYPE};

use orion::operators::nn::core::NNTrait;

Expand Down
14 changes: 14 additions & 0 deletions src/operators/nn/common.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -12,3 +12,17 @@ enum POOLING_TYPE {
LPPOOL,
MAX,
}

/// Interpolation mode selector (passed as `mode` to `NNTrait::grid_sample`).
#[derive(Copy, Drop)]
enum MODE {
    // Nearest-neighbour sampling.
    NEAREST,
    // Linear (bilinear in 2-D) interpolation.
    LINEAR,
    // Cubic interpolation.
    CUBIC,
}

/// Out-of-bounds handling selector (passed as `padding_mode` to `NNTrait::grid_sample`).
#[derive(Copy, Drop)]
enum PADDING_MODE {
    // Out-of-bounds locations read as zero.
    ZEROS,
    // Out-of-bounds locations clamp to the border value.
    BORDER,
    // Out-of-bounds locations reflect back into the tensor.
    REFLECTION,
}
8 changes: 4 additions & 4 deletions src/operators/nn/core.cairo
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use orion::operators::tensor::core::Tensor;
use orion::operators::nn::AUTO_PAD;
use orion::operators::nn::{AUTO_PAD, MODE, PADDING_MODE};

/// Trait
///
Expand Down Expand Up @@ -1087,7 +1087,7 @@ trait NNTrait<T> {
X: @Tensor<T>,
W: @Tensor<T>,
B: Option<@Tensor<T>>,
auto_pad: Option<orion::operators::nn::functional::conv_transpose::AUTO_PAD>,
auto_pad: Option<AUTO_PAD>,
dilations: Option<Span<usize>>,
group: Option<usize>,
kernel_shape: Option<Span<usize>>,
Expand Down Expand Up @@ -1302,8 +1302,8 @@ trait NNTrait<T> {
X: @Tensor<T>,
grid: @Tensor<T>,
align_corner: Option<usize>,
mode: Option<orion::operators::nn::functional::grid_sample::MODE>,
padding_mode: Option<orion::operators::nn::functional::grid_sample::PADDING_MODE>,
mode: Option<MODE>,
padding_mode: Option<PADDING_MODE>,
) -> Tensor<T>;
///
/// # NNTrait::max_pool
Expand Down
52 changes: 10 additions & 42 deletions src/operators/nn/functional/col2im.cairo
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use orion::numbers::NumberTrait;
use orion::operators::tensor::core::{stride};
use orion::operators::tensor::core::{stride, unravel_index};
use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor,};
use orion::operators::vec::{NullableVec, NullableVecImpl};
use orion::operators::nn::helpers::{is_out, prod};

fn col2im<T, MAG, +TensorTrait<T>, +NumberTrait<T, MAG>, +Copy<T>, +Drop<T>, +Add<T>, +Mul<T>,>(
fn col2im<T, MAG, +TensorTrait<T>, +NumberTrait<T, MAG>, +Copy<T>, +Drop<T>, +Add<T>, +MulEq<T>,>(
data: @Tensor<T>,
image_shape: Span<usize>,
block_shape: Span<usize>,
Expand Down Expand Up @@ -53,7 +54,7 @@ fn col2im<T, MAG, +TensorTrait<T>, +NumberTrait<T, MAG>, +Copy<T>, +Drop<T>, +Ad
},
};

let bl = prod(block_shape, 0);
let bl = prod(block_shape);
let C = *(*data).shape.at(1) / bl;

let mut new_shape: Array<i32> = array![
Expand Down Expand Up @@ -158,15 +159,15 @@ fn col2im_naive_implementation<
let mut data_im = NullableVecImpl::new();
data_im.set(*image_shape.at(0) * *stride_img.at(0) - 1, NumberTrait::zero());

let kernel_size = prod(kernel_shape, 0);
let col_size = prod(dim_col, 0);
let kernel_size = prod(kernel_shape);
let col_size = prod(dim_col);
let mut c_col = 0;
while c_col != kernel_size {
let offset = get_indices(c_col, kernel_shape).span();
let offset = unravel_index(c_col, kernel_shape);

let mut col = 0;
while col != col_size {
let ind_col = get_indices(col, dim_col).span();
let ind_col = unravel_index(col, dim_col);
let mut ind_im: Array<usize> = array![];
let mut i = 0;
while i != n_dims {
Expand Down Expand Up @@ -218,7 +219,7 @@ fn col2im_shape_check<T, +TensorTrait<T>, +Copy<T>, +Drop<T>,>(
) {
let n_input_plane = *(*X).shape.at(0);

let kernel_size = prod(kernel_shape, 0);
let kernel_size = prod(kernel_shape);

assert(n_input_plane % kernel_size == 0, 'wrong input dimension');

Expand All @@ -240,7 +241,7 @@ fn col2im_shape_check<T, +TensorTrait<T>, +Copy<T>, +Drop<T>,>(
i += 1;
};

let block_size = prod(n_blocks.span(), 0);
let block_size = prod(n_blocks.span());

assert(input_length == block_size, 'input_length != block_size');
}
Expand All @@ -267,36 +268,3 @@ fn get_indices(index: usize, shape: Span<usize>,) -> Array<usize> {

new_res
}

/// Returns `true` if any index in `ind` lies outside the bounds given by `shape`.
///
/// # Arguments
/// * `ind` - one index per dimension.
/// * `shape` - exclusive upper bound per dimension (assumed same length as `ind` —
///   TODO confirm callers guarantee this).
///
/// # Returns
/// `true` as soon as some `ind[n] >= shape[n]`, `false` if every index is in range.
fn is_out(ind: Span<usize>, shape: Span<usize>,) -> bool {
    let mut n = 0;
    let is_out = loop {
        if n == ind.len() {
            break false;
        }
        // Fix: the original also tested `i < 0`, which is dead code — `ind`
        // holds `usize` values and can never be negative. Only the upper
        // bound needs checking.
        if *ind.at(n) >= *shape.at(n) {
            break true;
        }
        n += 1;
    };

    is_out
}

/// Multiplies together the elements of `pA` from index `start` to the end.
/// Returns `NumberTrait::one()` when the range is empty.
fn prod<T, MAG, +Drop<T>, +Copy<T>, +NumberTrait<T, MAG>, +TensorTrait<T>, +Mul<T>,>(
    pA: Span<T>, start: usize
) -> T {
    let mut acc: T = NumberTrait::one();
    let mut idx = start;
    loop {
        if idx == pA.len() {
            break;
        }
        acc = acc * (*pA.at(idx));
        idx += 1;
    };

    acc
}
163 changes: 3 additions & 160 deletions src/operators/nn/functional/conv.cairo
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ use orion::numbers::{U32IntoI32, I32IntoU32, I32Div, I32Number};
use orion::operators::tensor::{TensorTrait, Tensor, U32Tensor,};
use orion::operators::vec::{NullableVec, NullableVecImpl};
use orion::operators::tensor::core::{stride};

use orion::operators::nn::helpers::{cartesian, arange, max_in_tensor, min_in_tensor, dot};
use orion::operators::nn::AUTO_PAD;


Expand Down Expand Up @@ -230,7 +230,8 @@ fn conv<
}

// group == 1
if *dilations.at(0) != 1 || min(dilations.clone()) != max(dilations.clone()) {
if *dilations.at(0) != 1
    || min_in_tensor(dilations.clone()) != max_in_tensor(dilations.clone()) {
// computation of the dilated kernel
let nd = dilations.len();
let mut new_kernel_shape: Array<usize> = array![];
Expand Down Expand Up @@ -1213,161 +1214,3 @@ fn r_index_check(r_index: Span<i32>, shape_out: Span<usize>) -> bool {
flag
}

/// Product of the elements of `pA`, skipping the first `start` entries.
/// An empty range yields the multiplicative identity.
fn prod<T, MAG, +Drop<T>, +Copy<T>, +NumberTrait<T, MAG>, +TensorTrait<T>, +Mul<T>,>(
    pA: Span<T>, start: usize
) -> T {
    let len = pA.len();
    let mut result: T = NumberTrait::one();
    let mut pos = start;
    while pos != len {
        result = result * (*pA.at(pos));
        pos += 1;
    };

    result
}

/// Smallest value in a non-empty span.
///
/// # Panics
/// Panics with 'span cannot be empty' when `a` has no elements.
fn min(mut a: Span<usize>) -> usize {
    assert(a.len() > 0, 'span cannot be empty');

    let mut smallest = *a.at(0);
    let mut i = 1;
    while i != a.len() {
        let candidate = *a.at(i);
        if candidate < smallest {
            smallest = candidate;
        }
        i += 1;
    };

    smallest
}

/// Largest value in a non-empty span.
///
/// # Panics
/// Panics with 'span cannot be empty' when `a` has no elements.
fn max(mut a: Span<usize>) -> usize {
    assert(a.len() > 0, 'span cannot be empty');

    let mut largest = *a.at(0);
    let mut i = 1;
    while i != a.len() {
        let candidate = *a.at(i);
        if candidate > largest {
            largest = candidate;
        }
        i += 1;
    };

    largest
}

/// Builds the sequence `start, start + step, ... ` strictly below `end`.
///
/// # Panics
/// Panics with 'incompatible step value' when `end - start` is not a
/// multiple of `step` (mirrors the original contract; note this also
/// implies `step != 0` for non-empty ranges).
fn arange(start: usize, end: usize, step: usize) -> Span<usize> {
    assert((end - start) % step == 0, 'incompatible step value');

    let mut values: Array<usize> = array![];
    let mut current = start;
    loop {
        if current >= end {
            break;
        }
        values.append(current);
        current += step;
    };

    values.span()
}


/// Cartesian product of a list of index spans: returns all `n` combinations,
/// where `n` is the product of the input lengths. Each returned span has one
/// element drawn from each input span, varying the last input fastest
/// (row-major order).
fn cartesian(mut arrays: Span<Span<usize>>,) -> Span<Span<usize>> {
    // n = total number of combinations (product of all input lengths).
    // NOTE(review): `arrays.len() - 1` underflows if `arrays` is empty —
    // callers presumably guarantee at least one input span; confirm.
    let mut n = 1;
    let mut i = arrays.len() - 1;
    loop {
        n = n * (*(arrays.at(i))).len();
        if i == 0 {
            break;
        }
        i -= 1;
    };

    // Cache each input's length for the tiling step below.
    let mut i = 0;
    let mut size_arrays: Array<usize> = array![];
    while i != arrays.len() {
        size_arrays.append((*(arrays.at(i))).len());
        i += 1;
    };

    let size_arrays = size_arrays.span();
    let mut output_arrays = array![];
    let mut m = n;

    // For input i: repeat each element m times (m shrinks by each length as we
    // go), then tile the result with repeat_2 so every column reaches length n.
    let mut i = 0;
    while i != arrays.len() {
        m = m / (*(arrays.at(i))).len();
        let mut out = repeat(*(arrays.at(i)), m);
        out = repeat_2(out, size_arrays, i);

        output_arrays.append(out);
        i += 1;
    };

    let output_arrays = output_arrays.span();

    // Transpose: row i of the result takes element i from each stretched column.
    let mut i = 0;
    let mut ret = ArrayTrait::new();
    while i != n {
        let mut j = 0;
        let mut x: Array<usize> = array![];
        while j != arrays.len() {
            x.append(*(output_arrays.at(j)).at(i));
            j += 1;
        };

        ret.append(x.span());
        i += 1;
    };

    ret.span()
}

/// Tiles `array` by appending full copies of its current contents, once per
/// preceding input length in `size_array` (helper for `cartesian`).
/// After each pass the working size grows by the factor `size_array[index-1-i]`,
/// so the result length is `array.len() * product(size_array[0..index])`.
fn repeat_2(mut array: Array<usize>, size_array: Span<usize>, index: usize) -> Array<usize> {
    let mut size = array.len();
    let mut i = 0;
    while i != index {
        let mut j = 1;
        // Append (factor - 1) extra copies of the first `size` elements.
        // Reading `array.at(k)` for k < size while appending is safe because
        // only indices beyond `size` are being added.
        while j != *size_array.at(index - 1 - i) {
            let mut k = 0;
            while k != size {
                array.append(*array.at(k));
                k += 1;
            };

            j += 1;
        };

        // The tile just built becomes the unit copied on the next pass.
        size = size * *size_array.at(index - 1 - i);
        i += 1;
    };

    array
}

/// Repeats each element of `array` `m` times consecutively, e.g.
/// `[a, b]` with `m = 2` becomes `[a, a, b, b]`.
fn repeat(array: Span<usize>, m: usize,) -> Array<usize> {
    let mut result: Array<usize> = array![];
    let mut idx = 0;
    while idx != array.len() {
        let value = *array.at(idx);
        let mut copies = 0;
        while copies != m {
            result.append(value);
            copies += 1;
        };

        idx += 1;
    };

    result
}

/// Dot product of two spans: sum over i of `a[i] * b[i]`.
/// Iterates over `a`'s length; `b` is assumed at least as long —
/// TODO confirm callers guarantee matching lengths.
fn dot<
    T, MAG, +Drop<T>, +Copy<T>, +NumberTrait<T, MAG>, +Add<T>, +TensorTrait<T>, +AddEq<T>, +Mul<T>,
>(
    a: Span<T>, b: Span<T>
) -> T {
    let len = a.len();
    let mut acc: T = NumberTrait::zero();
    let mut k = 0;
    while k != len {
        acc += *a.at(k) * *b.at(k);
        k += 1;
    };

    acc
}
Loading

0 comments on commit 2a2c8e5

Please sign in to comment.