Skip to content

Commit

Permalink
Merge pull request trilinos#11506 from dalg24/tpetra_atomic_max_abs
Browse files Browse the repository at this point in the history
Tpetra: refactor atomic maximum absolute value of scalar (compatibility with Kokkos)
  • Loading branch information
csiefer2 authored Jan 31, 2023
2 parents 18f4069 + 110d828 commit 62bb6ac
Showing 1 changed file with 24 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@

#include "Kokkos_Core.hpp"
#include "Kokkos_ArithTraits.hpp"
#include "impl/Kokkos_Atomic_Generic.hpp"
#include <sstream>
#include <stdexcept>

Expand Down Expand Up @@ -786,28 +785,37 @@ outOfBounds (const IntegerType x, const IntegerType exclusiveUpperBound)
}
};

// Kokkos::Impl::atomic_fetch_oper wants a class like this.
template<class Scalar1, class Scalar2>
struct AbsMaxOper {
KOKKOS_INLINE_FUNCTION
static Scalar1 apply(const Scalar1& val1, const Scalar2& val2) {
const auto val1_abs = Kokkos::ArithTraits<Scalar1>::abs(val1);
const auto val2_abs = Kokkos::ArithTraits<Scalar2>::abs(val2);
return val1_abs > val2_abs ? Scalar1(val1_abs) : Scalar1(val2_abs);
}
};

struct AbsMaxOp {
template <class Scalar>
struct AbsMaxHelper{
Scalar value;

KOKKOS_FUNCTION AbsMaxHelper& operator+=(AbsMaxHelper const& rhs) {
auto lhs_abs_value = Kokkos::ArithTraits<Scalar>::abs(value);
auto rhs_abs_value = Kokkos::ArithTraits<Scalar>::abs(rhs.value);
value = lhs_abs_value > rhs_abs_value ? lhs_abs_value : rhs_abs_value;
return *this;
}

KOKKOS_FUNCTION AbsMaxHelper operator+(AbsMaxHelper const& rhs) const {
AbsMaxHelper ret = *this;
ret += rhs;
return ret;
}
};

template <typename SC>
KOKKOS_INLINE_FUNCTION
void operator() (atomic_tag, SC& dest, const SC& src) const {
Kokkos::Impl::atomic_fetch_oper (AbsMaxOper<SC, SC> (), &dest, src);
void operator() (atomic_tag, SC& dst, const SC& src) const {
Kokkos::atomic_add(reinterpret_cast<AbsMaxHelper<SC>*>(&dst), AbsMaxHelper<SC>{src});
}

template <typename SC>
KOKKOS_INLINE_FUNCTION
void operator() (nonatomic_tag, SC& dest, const SC& src) const {
dest = AbsMaxOper<SC, SC> ().apply (dest, src);
void operator() (nonatomic_tag, SC& dst, const SC& src) const {
auto dst_abs_value = Kokkos::ArithTraits<SC>::abs(dst);
auto src_abs_value = Kokkos::ArithTraits<SC>::abs(src);
dst = dst_abs_value > src_abs_value ? dst_abs_value : src_abs_value;
}
};

Expand Down

0 comments on commit 62bb6ac

Please sign in to comment.