diff --git a/packages/tpetra/core/src/kokkos_refactor/Tpetra_KokkosRefactor_Details_MultiVectorDistObjectKernels.hpp b/packages/tpetra/core/src/kokkos_refactor/Tpetra_KokkosRefactor_Details_MultiVectorDistObjectKernels.hpp index b27678ff3b13..d947a2364648 100644 --- a/packages/tpetra/core/src/kokkos_refactor/Tpetra_KokkosRefactor_Details_MultiVectorDistObjectKernels.hpp +++ b/packages/tpetra/core/src/kokkos_refactor/Tpetra_KokkosRefactor_Details_MultiVectorDistObjectKernels.hpp @@ -787,23 +787,35 @@ outOfBounds (const IntegerType x, const IntegerType exclusiveUpperBound) struct AbsMaxOp { template - struct WrapScalarAndCompareAbsMax{ + struct AbsMaxHelper{ Scalar value; - private: - friend KOKKOS_FUNCTION bool operator<(WrapScalarAndCompareAbsMax const& lhs, WrapScalarAndCompareAbsMax const& rhs) { return Kokkos::ArithTraits::abs(lhs.value) < Kokkos::ArithTraits::abs(rhs.value); } - friend KOKKOS_FUNCTION bool operator>(WrapScalarAndCompareAbsMax const& lhs, WrapScalarAndCompareAbsMax const& rhs) { return Kokkos::ArithTraits::abs(lhs.value) > Kokkos::ArithTraits::abs(rhs.value); } + + KOKKOS_FUNCTION AbsMaxHelper& operator+=(AbsMaxHelper const& rhs) { + auto lhs_abs_value = Kokkos::ArithTraits::abs(value); + auto rhs_abs_value = Kokkos::ArithTraits::abs(rhs.value); + value = lhs_abs_value > rhs_abs_value ? lhs_abs_value : rhs_abs_value; + return *this; + } + + KOKKOS_FUNCTION AbsMaxHelper operator+(AbsMaxHelper const& rhs) const { + AbsMaxHelper ret = *this; + ret += rhs; + return ret; + } }; template KOKKOS_INLINE_FUNCTION void operator() (atomic_tag, SC& dst, const SC& src) const { - Kokkos::atomic_max(reinterpret_cast*>(&dst), WrapScalarAndCompareAbsMax{src}); + Kokkos::atomic_add(reinterpret_cast*>(&dst), AbsMaxHelper{src}); } template KOKKOS_INLINE_FUNCTION void operator() (nonatomic_tag, SC& dst, const SC& src) const { - if (Kokkos::ArithTraits::abs(dst) < Kokkos::ArithTraits::abs(src)) dst = src; + auto dst_abs_value = Kokkos::ArithTraits::abs(dst); + auto src_abs_value = Kokkos::ArithTraits::abs(src); + dst = dst_abs_value > src_abs_value ? dst_abs_value : src_abs_value; } };