diff --git a/examples/fancy_indexing.rs b/examples/fancy_indexing.rs index 82635c3..d7abe7c 100644 --- a/examples/fancy_indexing.rs +++ b/examples/fancy_indexing.rs @@ -1,4 +1,25 @@ -use tensorken::{Axes, Cpu32, CpuI32}; +use tensorken::{hd, tl, Axes, CpuBool, CpuI32, Ellipsis, NewAxis}; + +/// A macro to print the result of an expression and the expression itself. +macro_rules! with_shapes { + ($t:expr, $e:expr) => { + println!(">>> {}", stringify!($e)); + let input_shape = $t.shape(); + let result: CpuI32 = $e; + let result_shape = result.shape(); + println!("{:?} -> {:?}", input_shape, result_shape); + println!("{result}"); + }; +} + +/// A macro to print the result of an expression and the expression itself. +macro_rules! do_example { + ($e:expr) => { + println!(">>> {}", stringify!($e)); + let result = $e; + println!("{result}"); + }; +} /// A macro to print the result of an expression, the expression itself, /// and bind the result to a variable. @@ -8,33 +29,144 @@ macro_rules! let_example { let $t = $e; println!("{}", $t); }; - ($t:ident, $e:expr, $debug:literal) => { - println!(">>> {}", stringify!(let $t = $e)); - let $t = $e; - println!("{:?}", $t); - }; } -type Tr = Cpu32; type TrI = CpuI32; +type TrB = CpuBool; fn main() { - let_example!(t, Tr::linspace(1., 15.0, 15u8).reshape(&[5, 3])); + let_example!(t, TrI::new(&[4, 3, 2], &(1..25).collect::>())); + + // basic indexing + do_example!(t.ix3(0, 1, 2).to_scalar()); + with_shapes!(t, t.ix1(1)); + + do_example!(t + .ix3(t.shape()[0] - 1, t.shape()[1] - 1, t.shape()[2] - 1) + .to_scalar()); + do_example!(t.ix3(tl(0), tl(0), tl(0)).to_scalar()); + + with_shapes!(t, t.ix3(1, .., ..)); + with_shapes!(t, t.ix3(.., .., 2)); + + with_shapes!(t, t.ix2(1, Ellipsis)); + with_shapes!(t, t.ix2(Ellipsis, 2)); + + with_shapes!(t, t.ix3(1..3, 1..2, ..)); + with_shapes!(t, t.ix3(..3, 1.., ..)); + with_shapes!(t, t.ix3(..tl(0), ..tl(1), ..)); + with_shapes!(t, t.ix3(hd(3)..hd(1), hd(2)..hd(1), ..)); + + with_shapes!(t, t.ix4(.., .., NewAxis, ..)); + + with_shapes!(t, t.ix4(0, 1.., NewAxis, ..tl(0))); + + // fancy indexing with int tensors + let_example!(i, TrI::new(&[2], &[1, 2])); + with_shapes!(t, t.oix1(&i)); + + let_example!(i, TrI::new(&[4], &[3, 0, 2, 1])); + with_shapes!(t, t.oix1(&i)); + + let_example!(i, TrI::new(&[4], &[0, 0, 1, 1])); + with_shapes!(t, t.oix1(&i)); + + let_example!(i, TrI::new(&[6], &[1; 6])); + with_shapes!(t, t.oix1(&i)); + + let_example!(i, TrI::new(&[2, 2], &[0, 1, 1, 0])); + with_shapes!(t, t.oix1(&i)); + + let_example!(i, TrI::new(&[2, 2], &[0, 1, 1, 0])); + with_shapes!(t, t.oix3(.., .., &i)); + + // oix + with_shapes!(t, t.ix3(1..3, 1..2, ..)); + let_example!(i1, TrI::new(&[2], &[1, 2])); + let_example!(i2, TrI::new(&[1], &[1])); + let_example!(i3, TrI::new(&[2], &[0, 1])); + with_shapes!(t, t.oix3(&i1, &i2, &i3)); + + let_example!(i1, TrI::new(&[2], &[3, 0])); + let_example!(i2, TrI::new(&[2], &[2, 0])); + let_example!(i3, TrI::new(&[2], &[1, 0])); + with_shapes!(t, t.oix3(&i1, &i2, &i3)); + + let_example!(i1, TrI::new(&[2, 2], &[3, 3, 0, 0])); + let_example!(i2, TrI::new(&[2], &[2, 0])); + let_example!(i3, TrI::new(&[2, 2], &[1, 0, 1, 0])); + with_shapes!(t, t.oix3(&i1, &i2, &i3)); + + // vix + let_example!(t, TrI::new(&[3, 3], &(1..10).collect::>())); + let_example!(i1, TrI::new(&[2], &[0, 2])); + let_example!(i2, TrI::new(&[2], &[0, 2])); + with_shapes!(t, t.oix2(&i1, &i2)); + with_shapes!(t, t.vix2(&i1, &i2)); + + let_example!(i1, TrI::new(&[2, 2], &[0, 0, 2, 2])); + let_example!(i2, TrI::new(&[2, 2], &[0, 2, 0, 2])); + with_shapes!(t, t.vix2(&i1, &i2)); + + let_example!(t, TrI::new(&[4, 3, 2], &(1..25).collect::>())); + let_example!(i1, TrI::new(&[2], &[0, 2])); + let_example!(i2, TrI::new(&[2], &[0, 1])); + with_shapes!(t, t.vix3(&i1, .., &i2)); + with_shapes!(t, t.oix3(&i1, .., &i2)); + + // masks + let_example!(t, TrI::new(&[4, 3, 2], &(1..25).collect::>())); + let_example!(i, TrB::new(&[4], &[false, false, true, false])); + with_shapes!(t, t.oix1(&i)); + + let_example!(i, t.eq(&TrI::new(&[2], &[1, 2]))); + with_shapes!(t, t.oix1(&i)); + + // basic and fancy indexing compose + let_example!(t, TrI::new(&[4, 3, 2], &(1..25).collect::>())); + let_example!(i1, TrI::new(&[2], &[0, 2])); + let_example!(i2, TrI::new(&[2], &[0, 1])); + with_shapes!(t, t.vix3(&i1, ..2, &i2)); + with_shapes!(t, t.vix3(.., ..2, ..).vix3(&i1, .., &i2)); + + // oix + + // one dim + let_example!(t, TrI::new(&[4], &(1..5).collect::>())); let_example!(i, TrI::new(&[2], &[2, 0])); + // first convert the indexing tensor i to a one hot tensor i_one_hot + let_example!( + i_range, + TrI::new(&[4, 1], (0..4).collect::>().as_slice()) + ); + let_example!(i_one_hot, i.eq(&i_range).cast::()); + + // then multiply the one hot tensor with the original tensor + // but first reshape a bit so the right dimensions are broadcasted + let_example!(t, t.reshape(&[4, 1])); + + let_example!(mul_result, t.mul(&i_one_hot)); + + // then sum reduce the result, and remove the remaining original dimension + let_example!(sum_result, mul_result.sum(&[0]).squeeze(&Axes::Axis(0))); + + // two dim // we want: t.oix(i, ..) + let_example!(t, TrI::new(&[4, 2], &(1..9).collect::>())); + let_example!(i, TrI::new(&[2], &[2, 0])); // first convert the indexing tensor i to a one hot tensor i_one_hot let_example!( i_range, - TrI::new(&[5, 1], (0..5).collect::>().as_slice()) + TrI::new(&[4, 1], (0..4).collect::>().as_slice()) ); - let_example!(i_one_hot, i.eq(&i_range).cast::()); + let_example!(i_one_hot, i.eq(&i_range).cast::()); // then multiply the one hot tensor with the original tensor // but first reshape a bit so the right dimensions are broadcasted - let_example!(t, t.reshape(&[5, 1, 3])); - let_example!(i_one_hot, i_one_hot.reshape(&[5, 2, 1])); + let_example!(t, t.reshape(&[4, 1, 2])); + let_example!(i_one_hot, i_one_hot.reshape(&[4, 2, 1])); let_example!(mul_result, t.mul(&i_one_hot)); // then sum reduce the result, and remove the remaining original dimension diff --git a/examples/tour.rs b/examples/tour.rs index dbcfe4d..80edd0d 100644 --- a/examples/tour.rs +++ b/examples/tour.rs @@ -85,14 +85,13 @@ fn main() { let_example!(t, &Tr::new(&[3, 2], &[2.0, 1.0, 4.0, 2.0, 8.0, 4.0])); do_example!(t.crop(&[(0, 2), (1, 2)])); - - let_example!(t, &Tr::new(&[3, 2], &[2.0, 1.0, 4.0, 2.0, 8.0, 4.0])); + do_example!(t.reshape(&[1, 6])); do_example!(t.pad(&[(1, 2), (1, 3)])); let_example!(t, &Tr::linspace(1.0, 24.0, 24u8).reshape(&[2, 3, 4])); do_example!(t.ix1(1)); do_example!(t.ix2(1, 0)); - do_example!(t.ix2(..tl(1), Ellipsis)); + do_example!(t.ix2(..tl(0), Ellipsis)); // hd counts from the front, tl from the back do_example!(t.ix2(Ellipsis, hd(1)..tl(1))); // invert the range to flip. First bound is still inclusive, second exclusive. diff --git a/src/indexing.rs b/src/indexing.rs index c2bb32c..baffee4 100644 --- a/src/indexing.rs +++ b/src/indexing.rs @@ -60,8 +60,11 @@ pub enum IndexElement { Single(SingleIndex), // A range of elements in an axis. The second element is inclusive if the bool is true. Slice(SingleIndex, SingleIndex, bool), + // Create a new axis with size 1. NewAxis, + // Keep the remaining dimensions as is. Ellipsis, + // Fancy index - mask or int tensor. Fancy(Fancy), }