-
Notifications
You must be signed in to change notification settings - Fork 306
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Pairwise summation #577
base: master
Are you sure you want to change the base?
Pairwise summation #577
Conversation
… across implementations
Nice. It's not obvious to me that we should use the current implementation, this, or other compensated summing algorithm in the default sums.
|
The only other compensated summation algorithm that comes to my mind is Kahan's summation algorithm, but its performance overhead is not negligible, hence I wouldn't provide it as a default implementation. Re 1. - Wouldn't using TypeId + 'static enable this implementation only for float primitives, Re. 3 - Ok, I'll work on it and add them to the PR. |
This reverts commit 3085194.
Running these benchmarks I get:
Never done serious benchmarking before, so take these results with a grain of salt. |
@LukeMathWalker Thanks for working on this! Pairwise summation is better than the current implementation. I would hope we'd be able to use pairwise summation without significant performance cost. I'm surprised that the differences in some of those benchmarks are as large as they are. A few thoughts:
|
We will still fall short on new types wrapping integer values, but this seems like a better solution. If we plan on specializing, we might also take advantage of the fact that, depending on the byte size, we might get more integers packed together using SIMD compared to floats. But I guess this falls under the broader discussion on how to leverage SIMD instructions in
I tried it out and this is the performance I get:
Good catch - I'll fix it.
Would it really be beneficial? When doing long products what I usually fear is overflowing/underflowing. To get around that, I use the exp-ln trick, as in rust-ndarray/ndarray-stats#20 for |
I took another look at pub fn sum_axis(&self, axis: Axis) -> Array<A, D::Smaller>
where A: Clone + Zero + Add<Output=A>,
D: RemoveAxis,
{
let n = self.len_of(axis);
let stride = self.strides()[axis.index()];
if self.ndim() == 2 && stride == 1 {
// contiguous along the axis we are summing
let mut res = Array::zeros(self.raw_dim().remove_axis(axis));
let ax = axis.index();
for (i, elt) in enumerate(&mut res) {
*elt = self.index_axis(Axis(1 - ax), i).sum();
}
res
} else if self.len_of(axis) <= numeric_util::NO_SIMD_NAIVE_SUM_THRESHOLD {
self.fold_axis(axis, A::zero(), |acc, x| acc.clone() + x.clone())
} else {
let (v1, v2) = self.view().split_at(axis, n / 2);
v1.sum_axis(axis) + v2.sum_axis(axis)
}
}
Actually, there's another issue I didn't catch earlier. In the base case, I also noticed that we technically need to implement pub fn unrolled_fold<A, I, F>(mut xs: &[A], init: I, f: F) -> A
where A: Clone,
I: Fn() -> A,
F: Fn(A, A) -> A,
{
// eightfold unrolled so that floating point can be vectorized
// (even with strict floating point accuracy semantics)
let (mut p0, mut p1, mut p2, mut p3,
mut p4, mut p5, mut p6, mut p7) =
(init(), init(), init(), init(),
init(), init(), init(), init());
while xs.len() >= 8 {
p0 = f(p0, xs[0].clone());
p1 = f(p1, xs[1].clone());
p2 = f(p2, xs[2].clone());
p3 = f(p3, xs[3].clone());
p4 = f(p4, xs[4].clone());
p5 = f(p5, xs[5].clone());
p6 = f(p6, xs[6].clone());
p7 = f(p7, xs[7].clone());
xs = &xs[8..];
}
let (q0, q1, q2, q3) = (f(p0, p4), f(p1, p5), f(p2, p6), f(p3, p7));
let (r0, r1) = (f(q0, q2), f(q1, q3));
let unrolled = f(r0, r1);
// make it clear to the optimizer that this loop is short
// and can not be autovectorized.
let mut partial = init();
for i in 0..xs.len() {
if i >= 7 { break; }
partial = f(partial.clone(), xs[i].clone())
}
f(unrolled, partial)
} The performance difference is within the By the way, after more thought, I'd prefer to rename
I assumed that floating-point multiplication would have the same issue with error accumulation as addition, but I don't really know. We can keep |
This is slightly faster and is easier to understand.
I noticed one more issue -- in the non-contiguous case, the current The only other things I'd like to see are:
Edit: I noticed that LukeMathWalker#3 causes a performance regression in the |
Improve pairwise summation
I looked at LukeMathWalker#3 and it seems good to me - I'll merge it and in the meanwhile you can polish the edits required to revert the performance regression you have observed. I'll work on getting more tests in there and some significant benchmarks with integers. The only thing about LukeMathWalker#3 that I was slightly confused about is: let cap = len.saturating_sub(1) / NAIVE_SUM_THRESHOLD + 1; // ceiling of division Why are we subtracting 1 from |
Re-running benchmarks: Desktop (AMD):
Laptop (Intel):
|
This causes no change in performance according to the relevant benchmarks in `bench1`.
This improves the performance of `sum` when an axis is contiguous but the array as a whole is not contiguous.
Okay, I've added my changes in LukeMathWalker#4. They should improve performance on the
The goal of this line is to determine the ceiling of let cap = len / NAIVE_SUM_THRESHOLD + if len % NAIVE_SUM_THRESHOLD != 0 { 1 } else { 0 }; The benchmarks are interesting, especially the difference between the two machines. All the results look favorable except for the discontiguous cases. I'm not really sure why they're so much worse. I've tried running |
Improve pairwise summation
The new formulation for I have reran the benchmarks with your latest changes - the inner discontinuous case is the only one that is clearly suffering right now. Desktop (AMD):
(Intel benchmarks coming later) |
What's the status of this PR? |
Motivation
The naive summation algorithm can cause significant loss in precision when summing several floating point numbers - pairwise summation mitigates the issue without adding significant computational overhead.
Ideally, this should only be used when dealing with floats, but unfortunately we can't specialize on the argument type. Mitigating precision loss in floating point sums, on the other hand, is of paramount importance for a lot of algorithms relying on it as a building block - given that the overhead is minimal, I'd argue that the precision benefit is sufficient to adopt it as standard summation method.
Public API
Nothing changes.
Technicalities
I haven't modified the overall structure in
sum
andsum_axis
: I have just provided some helper functions that isolate the pairwise summation algorithm.The function that sums slices leverages the available
unrolled_fold
to vectorise the sum once the vector length is below the size threhold.The function that sums iterators falls back on the slices summation function as soon as the first pass is over, to recover vectorisation given that the array of partial sums is contiguous in memory (only relevant for array with more than 512^2 elements).
The function that sums an iterator of arrays is self-recursive.
It would probably be possible to share the pagination logic (to sum in blocks of 512) between the two functions operating on iterators, but I haven't been able to do it so far.