Skip to content

Commit

Permalink
Indexing examples.
Browse files Browse the repository at this point in the history
  • Loading branch information
kurtschelfthout committed May 12, 2024
1 parent 4e9dd36 commit 75ad393
Show file tree
Hide file tree
Showing 3 changed files with 149 additions and 15 deletions.
156 changes: 144 additions & 12 deletions examples/fancy_indexing.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,25 @@
use tensorken::{Axes, Cpu32, CpuI32};
use tensorken::{hd, tl, Axes, CpuBool, CpuI32, Ellipsis, NewAxis};

/// A macro to print the result of an expression and the expression itself.
macro_rules! with_shapes {
($t:expr, $e:expr) => {
println!(">>> {}", stringify!($e));
let input_shape = $t.shape();
let result: CpuI32 = $e;
let result_shape = result.shape();
println!("{:?} -> {:?}", input_shape, result_shape);
println!("{result}");
};
}

/// A macro to print the result of an expression and the expression itself.
macro_rules! do_example {
($e:expr) => {
println!(">>> {}", stringify!($e));
let result = $e;
println!("{result}");
};
}

/// A macro to print the result of an expression, the expression itself,
/// and bind the result to a variable.
Expand All @@ -8,33 +29,144 @@ macro_rules! let_example {
let $t = $e;
println!("{}", $t);
};
($t:ident, $e:expr, $debug:literal) => {
println!(">>> {}", stringify!(let $t = $e));
let $t = $e;
println!("{:?}", $t);
};
}

type Tr = Cpu32;
type TrI = CpuI32;
type TrB = CpuBool;

fn main() {
let_example!(t, Tr::linspace(1., 15.0, 15u8).reshape(&[5, 3]));
let_example!(t, TrI::new(&[4, 3, 2], &(1..25).collect::<Vec<_>>()));

// basic indexing
do_example!(t.ix3(0, 1, 2).to_scalar());
with_shapes!(t, t.ix1(1));

do_example!(t
.ix3(t.shape()[0] - 1, t.shape()[1] - 1, t.shape()[2] - 1)
.to_scalar());
do_example!(t.ix3(tl(0), tl(0), tl(0)).to_scalar());

with_shapes!(t, t.ix3(1, .., ..));
with_shapes!(t, t.ix3(.., .., 2));

with_shapes!(t, t.ix2(1, Ellipsis));
with_shapes!(t, t.ix2(Ellipsis, 2));

with_shapes!(t, t.ix3(1..3, 1..2, ..));
with_shapes!(t, t.ix3(..3, 1.., ..));
with_shapes!(t, t.ix3(..tl(0), ..tl(1), ..));
with_shapes!(t, t.ix3(hd(3)..hd(1), hd(2)..hd(1), ..));

with_shapes!(t, t.ix4(.., .., NewAxis, ..));

with_shapes!(t, t.ix4(0, 1.., NewAxis, ..tl(0)));

// fancy indexing with int tensors
let_example!(i, TrI::new(&[2], &[1, 2]));
with_shapes!(t, t.oix1(&i));

let_example!(i, TrI::new(&[4], &[3, 0, 2, 1]));
with_shapes!(t, t.oix1(&i));

let_example!(i, TrI::new(&[4], &[0, 0, 1, 1]));
with_shapes!(t, t.oix1(&i));

let_example!(i, TrI::new(&[6], &[1; 6]));
with_shapes!(t, t.oix1(&i));

let_example!(i, TrI::new(&[2, 2], &[0, 1, 1, 0]));
with_shapes!(t, t.oix1(&i));

let_example!(i, TrI::new(&[2, 2], &[0, 1, 1, 0]));
with_shapes!(t, t.oix3(.., .., &i));

// oix
with_shapes!(t, t.ix3(1..3, 1..2, ..));
let_example!(i1, TrI::new(&[2], &[1, 2]));
let_example!(i2, TrI::new(&[1], &[1]));
let_example!(i3, TrI::new(&[2], &[0, 1]));
with_shapes!(t, t.oix3(&i1, &i2, &i3));

let_example!(i1, TrI::new(&[2], &[3, 0]));
let_example!(i2, TrI::new(&[2], &[2, 0]));
let_example!(i3, TrI::new(&[2], &[1, 0]));
with_shapes!(t, t.oix3(&i1, &i2, &i3));

let_example!(i1, TrI::new(&[2, 2], &[3, 3, 0, 0]));
let_example!(i2, TrI::new(&[2], &[2, 0]));
let_example!(i3, TrI::new(&[2, 2], &[1, 0, 1, 0]));
with_shapes!(t, t.oix3(&i1, &i2, &i3));

// vix
let_example!(t, TrI::new(&[3, 3], &(1..10).collect::<Vec<_>>()));
let_example!(i1, TrI::new(&[2], &[0, 2]));
let_example!(i2, TrI::new(&[2], &[0, 2]));
with_shapes!(t, t.oix2(&i1, &i2));
with_shapes!(t, t.vix2(&i1, &i2));

let_example!(i1, TrI::new(&[2, 2], &[0, 0, 2, 2]));
let_example!(i2, TrI::new(&[2, 2], &[0, 2, 0, 2]));
with_shapes!(t, t.vix2(&i1, &i2));

let_example!(t, TrI::new(&[4, 3, 2], &(1..25).collect::<Vec<_>>()));
let_example!(i1, TrI::new(&[2], &[0, 2]));
let_example!(i2, TrI::new(&[2], &[0, 1]));
with_shapes!(t, t.vix3(&i1, .., &i2));
with_shapes!(t, t.oix3(&i1, .., &i2));

// masks
let_example!(t, TrI::new(&[4, 3, 2], &(1..25).collect::<Vec<_>>()));
let_example!(i, TrB::new(&[4], &[false, false, true, false]));
with_shapes!(t, t.oix1(&i));

let_example!(i, t.eq(&TrI::new(&[2], &[1, 2])));
with_shapes!(t, t.oix1(&i));

// basic and fancy indexing compose
let_example!(t, TrI::new(&[4, 3, 2], &(1..25).collect::<Vec<_>>()));
let_example!(i1, TrI::new(&[2], &[0, 2]));
let_example!(i2, TrI::new(&[2], &[0, 1]));
with_shapes!(t, t.vix3(&i1, ..2, &i2));
with_shapes!(t, t.vix3(.., ..2, ..).vix3(&i1, .., &i2));

// oix

// one dim
let_example!(t, TrI::new(&[4], &(1..5).collect::<Vec<_>>()));
let_example!(i, TrI::new(&[2], &[2, 0]));

// first convert the indexing tensor i to a one hot tensor i_one_hot
let_example!(
i_range,
TrI::new(&[4, 1], (0..4).collect::<Vec<_>>().as_slice())
);
let_example!(i_one_hot, i.eq(&i_range).cast::<i32>());

// then multiply the one hot tensor with the original tensor
// but first reshape a bit so the right dimensions are broadcasted
let_example!(t, t.reshape(&[4, 1]));

let_example!(mul_result, t.mul(&i_one_hot));

// then sum reduce the result, and remove the remaining original dimension
let_example!(sum_result, mul_result.sum(&[0]).squeeze(&Axes::Axis(0)));

// two dim
// we want: t.oix(i, ..)
let_example!(t, TrI::new(&[4, 2], &(1..9).collect::<Vec<_>>()));
let_example!(i, TrI::new(&[2], &[2, 0]));

// first convert the indexing tensor i to a one hot tensor i_one_hot
let_example!(
i_range,
TrI::new(&[5, 1], (0..5).collect::<Vec<_>>().as_slice())
TrI::new(&[4, 1], (0..4).collect::<Vec<_>>().as_slice())
);
let_example!(i_one_hot, i.eq(&i_range).cast::<f32>());
let_example!(i_one_hot, i.eq(&i_range).cast::<i32>());

// then multiply the one hot tensor with the original tensor
// but first reshape a bit so the right dimensions are broadcasted
let_example!(t, t.reshape(&[5, 1, 3]));
let_example!(i_one_hot, i_one_hot.reshape(&[5, 2, 1]));
let_example!(t, t.reshape(&[4, 1, 2]));
let_example!(i_one_hot, i_one_hot.reshape(&[4, 2, 1]));
let_example!(mul_result, t.mul(&i_one_hot));

// then sum reduce the result, and remove the remaining original dimension
Expand Down
5 changes: 2 additions & 3 deletions examples/tour.rs
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,13 @@ fn main() {

let_example!(t, &Tr::new(&[3, 2], &[2.0, 1.0, 4.0, 2.0, 8.0, 4.0]));
do_example!(t.crop(&[(0, 2), (1, 2)]));

let_example!(t, &Tr::new(&[3, 2], &[2.0, 1.0, 4.0, 2.0, 8.0, 4.0]));
do_example!(t.reshape(&[1, 6]));
do_example!(t.pad(&[(1, 2), (1, 3)]));

let_example!(t, &Tr::linspace(1.0, 24.0, 24u8).reshape(&[2, 3, 4]));
do_example!(t.ix1(1));
do_example!(t.ix2(1, 0));
do_example!(t.ix2(..tl(1), Ellipsis));
do_example!(t.ix2(..tl(0), Ellipsis));
// hd counts from the front, tl from the back
do_example!(t.ix2(Ellipsis, hd(1)..tl(1)));
// invert the range to flip. First bound is still inclusive, second exclusive.
Expand Down
3 changes: 3 additions & 0 deletions src/indexing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,11 @@ pub enum IndexElement<I: DiffableOps> {
Single(SingleIndex),
// A range of elements in an axis. The second element is inclusive if the bool is true.
Slice(SingleIndex, SingleIndex, bool),
// Create a new axis with size 1.
NewAxis,
// Keep the remaining dimensions as is.
Ellipsis,
// Fancy index - mask or int tensor.
Fancy(Fancy<I>),
}

Expand Down

0 comments on commit 75ad393

Please sign in to comment.