-
Notifications
You must be signed in to change notification settings - Fork 306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Generalize arithmetic ops to more combinations of scalars and arrays #782
base: master
Are you sure you want to change the base?
Conversation
It appears that Rust 1.37 has a bug that prevents this PR from working properly. Fortunately, this bug isn't present in the latest stable compiler, but we'll need to wait until we bump the minimum required Rust version before merging this PR. |
Delightful that the |
This change has two benefits: * The new implementation applies to more combinations of types. For example, it now applies to `&Array2<f32>` and `Complex<f32>`. * The new implementation avoids cloning the elements twice, and it avoids iterating over the elements twice. (The old implementation called `.to_owned()` followed by the arithmetic operation, while the new implementation clones the elements and performs the arithmetic operation in the same iteration.) On my machine, this change improves the performance for both contiguous and discontiguous arrays. (`scalar_add_1/2` go from ~530 ns/iter to ~380 ns/iter, and `scalar_add_strided_1/2` go from ~1540 ns/iter to ~1420 ns/iter.)
This doesn't have a noticeable impact on the results of the `scalar_add_2` and `scalar_add_strided_2` benchmarks.
19e35d3
to
b2a7d0b
Compare
Rebased to current master |
$scalar: Clone + $trt<A, Output=B>, | ||
A: Clone, | ||
S: Data<Elem=A>, | ||
D: Dimension, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This impl somehow now breaks Rust -- see the failed tests -- and causes a recursion errror - for an expression that has type f32
+ f32
which is quite strange/scary(!)
--> tests/oper.rs:159:48
|
159 | .fold(f32::zero(), |acc, (&x, &y)| acc + x * y)
| ^
|
= help: consider adding a `#![recursion_limit="256"]` attribute to your crate (`oper`)
= note: required because of the requirements on the impl of `Add<&ndarray::ArrayBase<_, _>>` for `f32`
= note: required because of the requirements on the impl of `Add<&ndarray::ArrayBase<_, _>>` for `f32`
= note: required because of the requirements on the impl of `Add<&ndarray::ArrayBase<_, _>>` for `f32`
= note: required because of the requirements on the impl of `Add<&ndarray::ArrayBase<_, _>>` for `f32`
= note: required because of the requirements on the impl of `Add<&ndarray::ArrayBase<_, _>>` for `f32`
= note: required because of the requirements on the impl of `Add<&ndarray::ArrayBase<_, _>>` for `f32`
= note: required because of the requirements on the impl of `Add<&ndarray::ArrayBase<_, _>>` for `f32`
= note: required because of the requirements on the impl of `Add<&ndarray::ArrayBase<_, _>>` for `f32`
Unsure if this is a Rust bug - for example that the impl is accepted(?), but I think this impl is too general and has infinite descent.
Given the question if f32
implements Add<&ArrayBase<S, D>>
look for other impl that has f32: Add<A>
where S: Data<Elem=A>
which looks recursive, is that it?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It looks like a compiler bug to me. As you point out, the expression involves only f32
, but for some reason, the error message indicates that one of the arguments is an array. It's also interesting that on my machine with Rust 1.48.0, the error message is slightly different, saying "impl of Add<ndarray::ArrayBase<_, _>>
for f32
" instead of the error message in your comment "impl of Add<&ndarray::ArrayBase<_, _>>
for f32
". (Note the &
.)
The function fails to compile (with the same error message) even after adding type annotations:
fn reference_dot<'a, V1, V2>(a: V1, b: V2) -> f32
where
V1: AsArray<'a, f32>,
V2: AsArray<'a, f32>,
{
let a: ArrayView1<'a, f32> = a.into();
let b: ArrayView1<'a, f32> = b.into();
a.iter()
.zip(b.iter())
.fold(f32::zero(), |acc: f32, (&x, &y): (&f32, &f32)| acc + x * y)
}
but if I remove the + x * y
, it compiles successfully:
fn reference_dot<'a, V1, V2>(a: V1, b: V2) -> f32
where
V1: AsArray<'a, f32>,
V2: AsArray<'a, f32>,
{
let a: ArrayView1<'a, f32> = a.into();
let b: ArrayView1<'a, f32> = b.into();
a.iter()
.zip(b.iter())
.fold(f32::zero(), |acc: f32, (&x, &y): (&f32, &f32)| acc)
}
I don't see any reason other than a compiler bug for the first function to fail to compile when the second one compiles without errors, since the type annotations confirm that the closure is operating only on f32
values.
This also compiles successfully:
fn reference_dot2<'a>(a: ArrayView1<'a, f32>, b: ArrayView1<'a, f32>) -> f32 {
a.iter()
.zip(b.iter())
.fold(f32::zero(), |acc: f32, (&x, &y): (&f32, &f32)| acc + x * y)
}
so the bug involves the .into()
calls in some way. It's surprising that adding explicit type annotations for the results of the .into()
calls, as in the first example, doesn't work around the bug.
Fwiw, I don't think impl<'a, A, S, D, B> $trt<&'a ArrayBase<S, D>> for $scalar
is infinitely recursive, since AFAIK it's not possible to have an array of (arrays of (arrays of (arrays of ... [infinite depth]))). The innermost array type can only have an element type that's not an array. You're right that there is recursion if you're dealing with arrays of arrays, but that's the correct behavior, and the recursion is not infinite.
For the particular function we're looking at, the impl doesn't apply, and I don't think the compiler should be trying to apply it. (I think it should only apply the impl if it knows the RHS has some type &ArrayBase<?S, ?D>
, where ?S
and ?D
are inference variables.)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting, the test runners for cross_test, stable, mips vs i686 disagree with each other about the error too, in the same way, even if they both use Rust 1.48
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I reported the issue (with a simplified example) at rust-lang/rust#80542.
I have considered the question of deprecating scalars as left hand side (LHS) operands. The reason would be because their implementation does not fit well with how trait impls are normally written, and the inevitable asymmetry between array + scalar and scalar + array in terms of which types are accepted. |
I agree that the implementations we have are somewhat unsatisfying, but IMO they're useful enough to keep. I would guess that the vast majority of users are dealing with the element types we implement the operators for, probably mostly I suppose an alternative option to the existing impls would be a Scalar(2.) / array which would work with any element type but would be less intuitive and would make expressions more verbose. I'm not sure a |
I think I have found out that if this (ugly) workaround is applied, the ScalarOperand trait is not needed anymore - meaning an unrestricted I think that Scalar is a lot better than |
Benchmark and performance improvements are being included by using PR #890, that supersedes just the first commit and the |
This PR generalizes the existing implementations of arithmetic operations between scalars and arrays to more combinations of types. This is especially useful for operations between complex and real types.
A couple of notes:
This removes the special handling of commutative operations. (Before, commutative operations with the scalar on the left side were implemented by calling the operation with the scalar on the right side.) IMO, implementations for more combinations of types are more important than possible differences in compile time due to reusing implementations.
The new
&arr (op) scalar
implementation brings a performance boost.An alternative approach for the "scalar on lhs" operations would be to add more implementations for specific combinations of types, e.g.
f32
andComplex<f32>
. I chose the generic approach instead for its conciseness and flexibility.This change is backwards compatible, except for possible changes in type inference due to the implementations for more combinations of types.
Fixes #781.