diff --git a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp index 97768fc8e9..1ef724eebc 100644 --- a/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp +++ b/dpctl/tensor/libtensor/include/kernels/elementwise_functions/sin.hpp @@ -58,7 +58,8 @@ template struct SinFunctor // constant value, if constant // constexpr resT constant_value = resT{}; // is function defined for sycl::vec - using supports_vec = typename std::false_type; + using supports_vec = typename std::negation< + std::disjunction, is_complex>>; // do both argTy and resTy support sugroup store/load operation using supports_sg_loadstore = typename std::negation< std::disjunction, is_complex>>; @@ -181,6 +182,20 @@ template struct SinFunctor return std::sin(in); } } + + template + sycl::vec operator()(const sycl::vec &in) + { + auto const &res_vec = sycl::sin(in); + using deducedT = typename std::remove_cv_t< + std::remove_reference_t>::element_type; + if constexpr (std::is_same_v) { + return res_vec; + } + else { + return vec_cast(res_vec); + } + } }; template