Skip to content

Commit

Permalink
Merge pull request #890 from rust-ndarray/jt-scalar-ops-improvement
Browse files Browse the repository at this point in the history
Scalar + &array and &array + scalar performance improvements
  • Loading branch information
bluss authored Jan 10, 2021
2 parents 3f10677 + dcf38e8 commit 6483ef5
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 4 deletions.
16 changes: 16 additions & 0 deletions benches/bench1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -433,6 +433,22 @@ fn scalar_add_2(bench: &mut test::Bencher) {
bench.iter(|| n + &a);
}

#[bench]
fn scalar_add_strided_1(bench: &mut test::Bencher) {
let a =
Array::from_shape_fn((64, 64 * 2), |(i, j)| (i * 64 + j) as f32).slice_move(s![.., ..;2]);
let n = 1.;
bench.iter(|| &a + n);
}

#[bench]
fn scalar_add_strided_2(bench: &mut test::Bencher) {
let a =
Array::from_shape_fn((64, 64 * 2), |(i, j)| (i * 64 + j) as f32).slice_move(s![.., ..;2]);
let n = 1.;
bench.iter(|| n + &a);
}

#[bench]
fn scalar_sub_1(bench: &mut test::Bencher) {
let a = Array::<f32, _>::zeros((64, 64));
Expand Down
8 changes: 4 additions & 4 deletions src/impl_ops.rs
Original file line number Diff line number Diff line change
Expand Up @@ -159,8 +159,8 @@ impl<'a, A, S, D, B> $trt<B> for &'a ArrayBase<S, D>
B: ScalarOperand,
{
type Output = Array<A, D>;
fn $mth(self, x: B) -> Array<A, D> {
self.to_owned().$mth(x)
fn $mth(self, x: B) -> Self::Output {
self.map(move |elt| elt.clone() $operator x.clone())
}
}
);
Expand Down Expand Up @@ -210,11 +210,11 @@ impl<'a, S, D> $trt<&'a ArrayBase<S, D>> for $scalar
D: Dimension,
{
type Output = Array<$scalar, D>;
fn $mth(self, rhs: &ArrayBase<S, D>) -> Array<$scalar, D> {
fn $mth(self, rhs: &ArrayBase<S, D>) -> Self::Output {
if_commutative!($commutative {
rhs.$mth(self)
} or {
self.$mth(rhs.to_owned())
rhs.map(move |elt| self.clone() $operator elt.clone())
})
}
}
Expand Down

0 comments on commit 6483ef5

Please sign in to comment.