Skip to content

Commit

Permalink
Instantiate atomic reduction templates for min/max ops for double/flo…
Browse files Browse the repository at this point in the history
…at types

Added entries for float and double types to TypePairSupportDataForCompReductionAtomic
as spotted by @ndgrigorian in the PR review.

Also moved comments around.
  • Loading branch information
oleksandr-pavlyk committed Nov 3, 2023
1 parent 41ec378 commit 645044a
Showing 1 changed file with 10 additions and 9 deletions.
19 changes: 10 additions & 9 deletions dpctl/tensor/libtensor/include/kernels/reductions.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -2247,11 +2247,10 @@ template <typename argTy, typename outTy>
struct TypePairSupportDataForCompReductionAtomic
{

/* value if true a kernel for <argTy, outTy> must be instantiated, false
/* value is true if a kernel for <argTy, outTy> must be instantiated, false
* otherwise */
static constexpr bool is_defined = std::disjunction< // disjunction is C++17
// feature, supported
// by DPC++
// disjunction is C++17 feature, supported by DPC++
static constexpr bool is_defined = std::disjunction<
// input int32
td_ns::TypePairDefinedEntry<argTy, std::int32_t, outTy, std::int32_t>,
// input uint32
Expand All @@ -2260,6 +2259,10 @@ struct TypePairSupportDataForCompReductionAtomic
td_ns::TypePairDefinedEntry<argTy, std::int64_t, outTy, std::int64_t>,
// input uint64
td_ns::TypePairDefinedEntry<argTy, std::uint64_t, outTy, std::uint64_t>,
// input float
td_ns::TypePairDefinedEntry<argTy, float, outTy, float>,
// input double
td_ns::TypePairDefinedEntry<argTy, double, outTy, double>,
// fall-through
td_ns::NotDefinedEntry>::is_defined;
};
Expand All @@ -2268,19 +2271,17 @@ template <typename argTy, typename outTy>
struct TypePairSupportDataForCompReductionTemps
{

static constexpr bool is_defined = std::disjunction< // disjunction is C++17
// feature, supported
// by DPC++ input bool
// disjunction is C++17 feature, supported by DPC++
static constexpr bool is_defined = std::disjunction<
// input bool
td_ns::TypePairDefinedEntry<argTy, bool, outTy, bool>,
// input int8_t
td_ns::TypePairDefinedEntry<argTy, std::int8_t, outTy, std::int8_t>,

// input uint8_t
td_ns::TypePairDefinedEntry<argTy, std::uint8_t, outTy, std::uint8_t>,

// input int16_t
td_ns::TypePairDefinedEntry<argTy, std::int16_t, outTy, std::int16_t>,

// input uint16_t
td_ns::TypePairDefinedEntry<argTy, std::uint16_t, outTy, std::uint16_t>,

Expand Down

0 comments on commit 645044a

Please sign in to comment.