diff --git a/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp b/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp index 72440ef969..f67e1bba1f 100644 --- a/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp +++ b/dpctl/tensor/libtensor/include/utils/sycl_alloc_utils.hpp @@ -30,6 +30,7 @@ #include #include #include +#include #include #include "sycl/sycl.hpp" @@ -140,12 +141,50 @@ smart_malloc_host(std::size_t count, return smart_malloc(count, q, sycl::usm::alloc::host, propList); } +namespace +{ +template struct valid_smart_ptr : public std::false_type +{ +}; + +template +struct valid_smart_ptr &> + : public std::is_same +{ +}; + +template +struct valid_smart_ptr> + : public std::is_same +{ +}; + +// base case +template struct all_valid_smart_ptrs +{ + static constexpr bool value = true; +}; + +template +struct all_valid_smart_ptrs +{ + static constexpr bool value = valid_smart_ptr::value && + (all_valid_smart_ptrs::value); +}; +} // namespace + template sycl::event async_smart_free(sycl::queue &exec_q, const std::vector &depends, Args &&...args) { constexpr std::size_t n = sizeof...(Args); + static_assert( + n > 0, "async_smart_free requires at least one smart pointer argument"); + + static_assert( + all_valid_smart_ptrs::value, + "async_smart_free requires unique_ptr created with smart_malloc"); std::vector ptrs; ptrs.reserve(n);