Skip to content

Commit

Permalink
Final examples.
Browse files Browse the repository at this point in the history
  • Loading branch information
kurtschelfthout committed May 19, 2024
1 parent 75ad393 commit b4396c7
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions examples/fancy_indexing.rs
Original file line number Diff line number Diff line change
Expand Up @@ -171,4 +171,48 @@ fn main() {

// then sum reduce the result, and remove the remaining original dimension
let_example!(sum_result, mul_result.sum(&[0]).squeeze(&Axes::Axis(0)));

// vix
// we want: t.vix(i, j)
let_example!(t, TrI::new(&[4, 6], &(1..25).collect::<Vec<_>>()));
let_example!(i, TrI::new(&[2], &[2, 0]));
let_example!(j, TrI::new(&[2], &[1, 0]));

// first convert the indexing tensor i to a one hot tensor i_one_hot
let_example!(
i_range,
TrI::new(&[4], (0..4).collect::<Vec<_>>().as_slice())
);
let_example!(i, i.reshape(&[2, 1]));
// now i_one_hot has shape [2, 4]
let_example!(i_one_hot, i.eq(&i_range).cast::<i32>());

// multiply the one hot tensor with the original tensor
// but first reshape so the right dimensions are broadcasted
let_example!(i_one_hot, i_one_hot.reshape(&[2, 4, 1]));
let_example!(mul_result, t.mul(&i_one_hot));

// then sum reduce the result, and remove the remaining original dimension
let_example!(t, mul_result.sum(&[1]).squeeze(&Axes::Axis(1)));

// then do the same for j
let_example!(
j_range,
TrI::new(&[6], (0..6).collect::<Vec<_>>().as_slice())
);
let_example!(j, j.reshape(&[2, 1]));
// j_one_hot has shape [2,6]...
let_example!(j_one_hot, j.eq(&j_range).cast::<i32>());

// the same shape as t! So we can just multiply...
let_example!(mul_result, t.mul(&j_one_hot));
// and sum:
let_example!(sum_result, mul_result.sum(&[1]).squeeze(&Axes::Axis(1)));

// masks
let_example!(
b,
TrB::new(&[2, 3], &[false, false, true, false, true, false])
);
let_example!(i_b, TrI::new(&[2], &[2, 4]));
}

0 comments on commit b4396c7

Please sign in to comment.