Skip to content

Commit

Permalink
Adds sycl::vec overloads to abs, cos, expm1, log, log1p, and sqrt
Browse files Browse the repository at this point in the history
  • Loading branch information
ndgrigorian committed Nov 25, 2023
1 parent 5ec9fd5 commit a0959d0
Show file tree
Hide file tree
Showing 6 changed files with 122 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@ namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::type_utils::is_complex;
using dpctl::tensor::type_utils::vec_cast;

template <typename argT, typename resT> struct AbsFunctor
{

using is_constant = typename std::false_type;
// constexpr resT constant_value = resT{};
using supports_vec = typename std::false_type;
using supports_vec = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
using supports_sg_loadstore = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;

Expand Down Expand Up @@ -127,6 +129,40 @@ template <typename argT, typename resT> struct AbsFunctor
#endif
}
}

/// sycl::vec overload of the absolute-value functor.
///
/// Dispatches at compile time on argT:
///   - bool / unsigned types: the input vector is returned as is;
///   - other integral types:  sycl::abs is applied;
///   - everything else:       sycl::fabs is applied.
/// When the element type produced by the SYCL built-in is not resT, the
/// result is converted with vec_cast before being returned.
template <int vec_sz>
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
{
    if constexpr (std::is_same_v<argT, bool> ||
                  std::is_unsigned<argT>::value)
    {
        // std::is_unsigned only holds for integral types, so this branch
        // covers exactly the bool/unsigned integral inputs
        return in;
    }
    else if constexpr (std::is_integral<argT>::value) {
        auto const &abs_vals = sycl::abs(in);
        using abs_vecT =
            std::remove_cv_t<std::remove_reference_t<decltype(abs_vals)>>;
        using abs_elemT = typename abs_vecT::element_type;

        if constexpr (!std::is_same_v<resT, abs_elemT>) {
            // built-in produced a different element type: convert to resT
            return vec_cast<resT, abs_elemT, vec_sz>(abs_vals);
        }
        else {
            return abs_vals;
        }
    }
    else {
        auto const &fabs_vals = sycl::fabs(in);
        using fabs_vecT =
            std::remove_cv_t<std::remove_reference_t<decltype(fabs_vals)>>;
        using fabs_elemT = typename fabs_vecT::element_type;

        if constexpr (!std::is_same_v<resT, fabs_elemT>) {
            // built-in produced a different element type: convert to resT
            return vec_cast<resT, fabs_elemT, vec_sz>(fabs_vals);
        }
        else {
            return fabs_vals;
        }
    }
}
};

template <typename argT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::type_utils::is_complex;
using dpctl::tensor::type_utils::vec_cast;

template <typename argT, typename resT> struct CosFunctor
{
Expand All @@ -59,7 +60,8 @@ template <typename argT, typename resT> struct CosFunctor
// constant value, if constant
// constexpr resT constant_value = resT{};
// is function defined for sycl::vec
using supports_vec = typename std::false_type;
using supports_vec = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
// do both argT and resT support subgroup store/load operations
using supports_sg_loadstore = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
Expand Down Expand Up @@ -165,6 +167,20 @@ template <typename argT, typename resT> struct CosFunctor
return std::cos(in);
}
}

/// sycl::vec overload: applies sycl::cos to `in` and returns the result as
/// sycl::vec<resT, vec_sz>.
///
/// The element type of the sycl::cos result is deduced at compile time;
/// when it is not resT the result is converted with vec_cast.
template <int vec_sz>
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
{
    auto const &cos_vals = sycl::cos(in);
    using cos_vecT =
        std::remove_cv_t<std::remove_reference_t<decltype(cos_vals)>>;
    using cos_elemT = typename cos_vecT::element_type;

    if constexpr (!std::is_same_v<resT, cos_elemT>) {
        // element types disagree: convert to the requested result type
        return vec_cast<resT, cos_elemT, vec_sz>(cos_vals);
    }
    else {
        return cos_vals;
    }
}
};

template <typename argTy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::type_utils::is_complex;
using dpctl::tensor::type_utils::vec_cast;

template <typename argT, typename resT> struct Expm1Functor
{
Expand All @@ -60,7 +61,8 @@ template <typename argT, typename resT> struct Expm1Functor
// constant value, if constant
// constexpr resT constant_value = resT{};
// is function defined for sycl::vec
using supports_vec = typename std::false_type;
using supports_vec = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
// do both argT and resT support subgroup store/load operations
using supports_sg_loadstore = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
Expand Down Expand Up @@ -132,6 +134,20 @@ template <typename argT, typename resT> struct Expm1Functor
return std::expm1(in);
}
}

/// sycl::vec overload: applies sycl::expm1 to `in` and returns the result
/// as sycl::vec<resT, vec_sz>.
///
/// The element type of the sycl::expm1 result is deduced at compile time;
/// when it is not resT the result is converted with vec_cast.
template <int vec_sz>
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
{
    auto const &expm1_vals = sycl::expm1(in);
    using expm1_vecT =
        std::remove_cv_t<std::remove_reference_t<decltype(expm1_vals)>>;
    using expm1_elemT = typename expm1_vecT::element_type;

    if constexpr (!std::is_same_v<resT, expm1_elemT>) {
        // element types disagree: convert to the requested result type
        return vec_cast<resT, expm1_elemT, vec_sz>(expm1_vals);
    }
    else {
        return expm1_vals;
    }
}
};

template <typename argTy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::type_utils::is_complex;
using dpctl::tensor::type_utils::vec_cast;

template <typename argT, typename resT> struct LogFunctor
{
Expand All @@ -60,7 +61,8 @@ template <typename argT, typename resT> struct LogFunctor
// constant value, if constant
// constexpr resT constant_value = resT{};
// is function defined for sycl::vec
using supports_vec = typename std::false_type;
using supports_vec = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
// do both argT and resT support subgroup store/load operations
using supports_sg_loadstore = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
Expand All @@ -79,6 +81,20 @@ template <typename argT, typename resT> struct LogFunctor
return std::log(in);
}
}

/// sycl::vec overload: applies sycl::log to `in` and returns the result as
/// sycl::vec<resT, vec_sz>.
///
/// The element type of the sycl::log result is deduced at compile time;
/// when it is not resT the result is converted with vec_cast.
template <int vec_sz>
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
{
    auto const &log_vals = sycl::log(in);
    using log_vecT =
        std::remove_cv_t<std::remove_reference_t<decltype(log_vals)>>;
    using log_elemT = typename log_vecT::element_type;

    if constexpr (!std::is_same_v<resT, log_elemT>) {
        // element types disagree: convert to the requested result type
        return vec_cast<resT, log_elemT, vec_sz>(log_vals);
    }
    else {
        return log_vals;
    }
}
};

template <typename argTy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::type_utils::is_complex;
using dpctl::tensor::type_utils::vec_cast;

// TODO: evaluate precision against alternatives
template <typename argT, typename resT> struct Log1pFunctor
Expand All @@ -60,7 +61,8 @@ template <typename argT, typename resT> struct Log1pFunctor
// constant value, if constant
// constexpr resT constant_value = resT{};
// is function defined for sycl::vec
using supports_vec = typename std::false_type;
using supports_vec = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
// do both argT and resT support subgroup store/load operations
using supports_sg_loadstore = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
Expand Down Expand Up @@ -99,6 +101,20 @@ template <typename argT, typename resT> struct Log1pFunctor
return std::log1p(in);
}
}

/// sycl::vec overload: applies sycl::log1p to `in` and returns the result
/// as sycl::vec<resT, vec_sz>.
///
/// The element type of the sycl::log1p result is deduced at compile time;
/// when it is not resT the result is converted with vec_cast.
template <int vec_sz>
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
{
    auto const &log1p_vals = sycl::log1p(in);
    using log1p_vecT =
        std::remove_cv_t<std::remove_reference_t<decltype(log1p_vals)>>;
    using log1p_elemT = typename log1p_vecT::element_type;

    if constexpr (!std::is_same_v<resT, log1p_elemT>) {
        // element types disagree: convert to the requested result type
        return vec_cast<resT, log1p_elemT, vec_sz>(log1p_vals);
    }
    else {
        return log1p_vals;
    }
}
};

template <typename argTy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ namespace py = pybind11;
namespace td_ns = dpctl::tensor::type_dispatch;

using dpctl::tensor::type_utils::is_complex;
using dpctl::tensor::type_utils::vec_cast;

template <typename argT, typename resT> struct SqrtFunctor
{
Expand All @@ -62,7 +63,8 @@ template <typename argT, typename resT> struct SqrtFunctor
// constant value, if constant
// constexpr resT constant_value = resT{};
// is function defined for sycl::vec
using supports_vec = typename std::false_type;
using supports_vec = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
// do both argT and resT support subgroup store/load operations
using supports_sg_loadstore = typename std::negation<
std::disjunction<is_complex<resT>, is_complex<argT>>>;
Expand Down Expand Up @@ -263,6 +265,20 @@ template <typename argT, typename resT> struct SqrtFunctor
? csqrt_finite_unscaled(x, y)
: csqrt_finite_scaled(x, y);
}

/// sycl::vec overload: applies sycl::sqrt to `in` and returns the result as
/// sycl::vec<resT, vec_sz>.
///
/// The element type of the sycl::sqrt result is deduced at compile time;
/// when it is not resT the result is converted with vec_cast.
template <int vec_sz>
sycl::vec<resT, vec_sz> operator()(const sycl::vec<argT, vec_sz> &in)
{
    auto const &sqrt_vals = sycl::sqrt(in);
    using sqrt_vecT =
        std::remove_cv_t<std::remove_reference_t<decltype(sqrt_vals)>>;
    using sqrt_elemT = typename sqrt_vecT::element_type;

    if constexpr (!std::is_same_v<resT, sqrt_elemT>) {
        // element types disagree: convert to the requested result type
        return vec_cast<resT, sqrt_elemT, vec_sz>(sqrt_vals);
    }
    else {
        return sqrt_vals;
    }
}
};

template <typename argTy,
Expand Down

0 comments on commit a0959d0

Please sign in to comment.