From b4396c701dc80d85711209acb6b8dd0ca5f5be97 Mon Sep 17 00:00:00 2001 From: Kurt Schelfthout Date: Sun, 19 May 2024 15:58:17 +0100 Subject: [PATCH] Final examples. --- examples/fancy_indexing.rs | 44 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) diff --git a/examples/fancy_indexing.rs b/examples/fancy_indexing.rs index d7abe7c..a4f0cb2 100644 --- a/examples/fancy_indexing.rs +++ b/examples/fancy_indexing.rs @@ -171,4 +171,48 @@ fn main() { // then sum reduce the result, and remove the remaining original dimension let_example!(sum_result, mul_result.sum(&[0]).squeeze(&Axes::Axis(0))); + + // vix + // we want: t.vix(i, j) + let_example!(t, TrI::new(&[4, 6], &(1..25).collect::>())); + let_example!(i, TrI::new(&[2], &[2, 0])); + let_example!(j, TrI::new(&[2], &[1, 0])); + + // first convert the indexing tensor i to a one hot tensor i_one_hot + let_example!( + i_range, + TrI::new(&[4], (0..4).collect::>().as_slice()) + ); + let_example!(i, i.reshape(&[2, 1])); + // now i_one_hot has shape [2, 4] + let_example!(i_one_hot, i.eq(&i_range).cast::()); + + // multiply the one hot tensor with the original tensor + // but first reshape so the right dimensions are broadcasted + let_example!(i_one_hot, i_one_hot.reshape(&[2, 4, 1])); + let_example!(mul_result, t.mul(&i_one_hot)); + + // then sum reduce the result, and remove the remaining original dimension + let_example!(t, mul_result.sum(&[1]).squeeze(&Axes::Axis(1))); + + // then do the same for j + let_example!( + j_range, + TrI::new(&[6], (0..6).collect::>().as_slice()) + ); + let_example!(j, j.reshape(&[2, 1])); + // j_one_hot has shape [2,6]... + let_example!(j_one_hot, j.eq(&j_range).cast::()); + + // the same shape as t! So we can just multiply... + let_example!(mul_result, t.mul(&j_one_hot)); + // and sum: + let_example!(sum_result, mul_result.sum(&[1]).squeeze(&Axes::Axis(1))); + + // masks + let_example!( + b, + TrB::new(&[2, 3], &[false, false, true, false, true, false]) + ); + let_example!(i_b, TrI::new(&[2], &[2, 4])); }