Skip to content

Commit

Permalink
Tpetra: fixup do not preserve the sign max abs operations
Browse files Browse the repository at this point in the history
  • Loading branch information
dalg24 committed Jan 29, 2023
1 parent ed2d18e commit 110d828
Showing 1 changed file with 18 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -787,23 +787,35 @@ outOfBounds (const IntegerType x, const IntegerType exclusiveUpperBound)

struct AbsMaxOp {
template <class Scalar>
struct WrapScalarAndCompareAbsMax{
struct AbsMaxHelper{
Scalar value;
private:
friend KOKKOS_FUNCTION bool operator<(WrapScalarAndCompareAbsMax const& lhs, WrapScalarAndCompareAbsMax const& rhs) { return Kokkos::ArithTraits<Scalar>::abs(lhs.value) < Kokkos::ArithTraits<Scalar>::abs(rhs.value); }
friend KOKKOS_FUNCTION bool operator>(WrapScalarAndCompareAbsMax const& lhs, WrapScalarAndCompareAbsMax const& rhs) { return Kokkos::ArithTraits<Scalar>::abs(lhs.value) > Kokkos::ArithTraits<Scalar>::abs(rhs.value); }

KOKKOS_FUNCTION AbsMaxHelper& operator+=(AbsMaxHelper const& rhs) {
auto lhs_abs_value = Kokkos::ArithTraits<Scalar>::abs(value);
auto rhs_abs_value = Kokkos::ArithTraits<Scalar>::abs(rhs.value);
value = lhs_abs_value > rhs_abs_value ? lhs_abs_value : rhs_abs_value;
return *this;
}

KOKKOS_FUNCTION AbsMaxHelper operator+(AbsMaxHelper const& rhs) const {
AbsMaxHelper ret = *this;
ret += rhs;
return ret;
}
};

template <typename SC>
KOKKOS_INLINE_FUNCTION
void operator() (atomic_tag, SC& dst, const SC& src) const {
Kokkos::atomic_max(reinterpret_cast<WrapScalarAndCompareAbsMax<SC>*>(&dst), WrapScalarAndCompareAbsMax<SC>{src});
Kokkos::atomic_add(reinterpret_cast<AbsMaxHelper<SC>*>(&dst), AbsMaxHelper<SC>{src});
}

template <typename SC>
KOKKOS_INLINE_FUNCTION
void operator() (nonatomic_tag, SC& dst, const SC& src) const {
if (Kokkos::ArithTraits<SC>::abs(dst) < Kokkos::ArithTraits<SC>::abs(src)) dst = src;
auto dst_abs_value = Kokkos::ArithTraits<SC>::abs(dst);
auto src_abs_value = Kokkos::ArithTraits<SC>::abs(src);
dst = dst_abs_value > src_abs_value ? dst_abs_value : src_abs_value;
}
};

Expand Down

0 comments on commit 110d828

Please sign in to comment.