diff --git a/include/oneapi/mkl/dft/detail/commit_impl.hpp b/include/oneapi/mkl/dft/detail/commit_impl.hpp index afa8238eb..a2b84e3e8 100644 --- a/include/oneapi/mkl/dft/detail/commit_impl.hpp +++ b/include/oneapi/mkl/dft/detail/commit_impl.hpp @@ -28,12 +28,14 @@ #endif #include "oneapi/mkl/detail/backends.hpp" +#include "types_impl.hpp" namespace oneapi { namespace mkl { namespace dft { namespace detail { +template class commit_impl { public: commit_impl(sycl::queue queue, mkl::backend backend) @@ -41,7 +43,7 @@ class commit_impl { backend_(backend), status(false) {} - commit_impl(const commit_impl& other) = default; + commit_impl(const commit_impl& other) = delete; virtual ~commit_impl() = default; @@ -55,6 +57,8 @@ class commit_impl { virtual void* get_handle() = 0; + virtual void commit(sycl::queue& queue, dft_values&) = 0; + protected: bool status; mkl::backend backend_; diff --git a/include/oneapi/mkl/dft/detail/descriptor_impl.hpp b/include/oneapi/mkl/dft/detail/descriptor_impl.hpp index 0b388e60d..7ff6e72d5 100644 --- a/include/oneapi/mkl/dft/detail/descriptor_impl.hpp +++ b/include/oneapi/mkl/dft/detail/descriptor_impl.hpp @@ -41,7 +41,7 @@ template class descriptor; template -inline commit_impl* get_commit(descriptor& desc); +inline commit_impl* get_commit(descriptor& desc); template class descriptor { @@ -68,22 +68,22 @@ class descriptor { void commit(backend_selector selector); #endif - dft_values get_values() { + dft_values& get_values() { return values_; }; private: // Has a value when the descriptor is committed. - std::unique_ptr pimpl_; + std::unique_ptr> pimpl_; // descriptor configuration values_ and structs dft_values values_; - friend commit_impl* get_commit(descriptor&); + friend commit_impl* get_commit(descriptor&); }; template -inline commit_impl* get_commit(descriptor& desc) { +inline commit_impl* get_commit(descriptor& desc) { return desc.pimpl_.get(); } diff --git a/include/oneapi/mkl/dft/detail/dft_ct.hxx b/include/oneapi/mkl/dft/detail/dft_ct.hxx index 5c06c9efb..1c5c29dca 100644 --- a/include/oneapi/mkl/dft/detail/dft_ct.hxx +++ b/include/oneapi/mkl/dft/detail/dft_ct.hxx @@ -20,8 +20,8 @@ // Commit template -ONEMKL_EXPORT dft::detail::commit_impl *create_commit(dft::detail::descriptor &desc, - sycl::queue &sycl_queue); +ONEMKL_EXPORT dft::detail::commit_impl *create_commit( + dft::detail::descriptor &desc, sycl::queue &sycl_queue); // BUFFER version diff --git a/include/oneapi/mkl/dft/detail/dft_loader.hpp b/include/oneapi/mkl/dft/detail/dft_loader.hpp index 0793658de..f173aa40a 100644 --- a/include/oneapi/mkl/dft/detail/dft_loader.hpp +++ b/include/oneapi/mkl/dft/detail/dft_loader.hpp @@ -34,13 +34,15 @@ namespace mkl { namespace dft { namespace detail { +template class commit_impl; template class descriptor; template -ONEMKL_EXPORT commit_impl* create_commit(descriptor& desc, sycl::queue& queue); +ONEMKL_EXPORT commit_impl* create_commit(descriptor& desc, + sycl::queue& queue); } // namespace detail } // namespace dft diff --git a/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp b/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp index d8a824875..cfd2c6d99 100644 --- a/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp +++ b/include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp @@ -35,6 +35,7 @@ namespace dft { namespace detail { // Forward declarations +template class commit_impl; template diff --git a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp index b6227aa7b..a03e235c7 100644 --- a/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp +++ b/include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp @@ -35,6 +35,7 @@ namespace dft { namespace detail { // Forward declarations +template class commit_impl; template diff --git a/include/oneapi/mkl/dft/detail/types_impl.hpp b/include/oneapi/mkl/dft/detail/types_impl.hpp index d0555409f..de89c6f4b 100644 --- a/include/oneapi/mkl/dft/detail/types_impl.hpp +++ b/include/oneapi/mkl/dft/detail/types_impl.hpp @@ -23,6 +23,7 @@ #include #include #include +#include namespace oneapi { namespace mkl { @@ -100,6 +101,8 @@ class dft_values { using real_t = std::conditional_t; public: + std::set changed; + std::vector input_strides; std::vector output_strides; real_t bwd_scale; @@ -116,7 +119,6 @@ class dft_values { bool transpose; config_value packed_format; std::vector dimensions; - std::int64_t rank; }; } // namespace detail diff --git a/src/dft/backends/descriptor.cpp b/src/dft/backends/descriptor.cpp index 105dcef74..cc0234606 100644 --- a/src/dft/backends/descriptor.cpp +++ b/src/dft/backends/descriptor.cpp @@ -28,7 +28,10 @@ namespace dft { template void descriptor::commit(sycl::queue &queue) { - pimpl_.reset(detail::create_commit(*this, queue)); + if (!pimpl_) { + pimpl_.reset(detail::create_commit(*this, queue)); + } + pimpl_->commit(queue, values_); } template void descriptor::commit(sycl::queue &); template void descriptor::commit(sycl::queue &); diff --git a/src/dft/backends/mklcpu/commit.cpp b/src/dft/backends/mklcpu/commit.cpp index e3c993b69..d458328ef 100644 --- a/src/dft/backends/mklcpu/commit.cpp +++ b/src/dft/backends/mklcpu/commit.cpp @@ -40,12 +40,12 @@ namespace dft { namespace mklcpu { template -class commit_derived_impl : public detail::commit_impl { +class commit_derived_impl final : public detail::commit_impl { public: - commit_derived_impl(sycl::queue queue, detail::dft_values config_values) - : detail::commit_impl(queue, backend::mklcpu), - status(DFT_NOTSET) { - if (config_values.rank == 1) { + commit_derived_impl(sycl::queue queue, const detail::dft_values& config_values) + : detail::commit_impl(queue, backend::mklcpu) { + DFT_ERROR status = DFT_NOTSET; + if (config_values.dimensions.size() == 1) { status = DftiCreateDescriptor(&handle, get_precision(prec), get_domain(dom), config_values.rank, config_values.dimensions[0]); } @@ -54,14 +54,16 @@ class commit_derived_impl : public detail::commit_impl { config_values.rank, config_values.dimensions.data()); } if (status != DFTI_NO_ERROR) { - throw oneapi::mkl::exception("dft", "commit", "DftiCreateDescriptor failed"); + throw oneapi::mkl::exception("dft/backends/mklcpu", "commit", + "DftiCreateDescriptor failed"); } + } + void commit(sycl::queue&, detail::dft_values& config_values) override { set_value(handle, config_values); - - status = DftiCommitDescriptor(handle); - if (status != DFTI_NO_ERROR) { - throw oneapi::mkl::exception("dft", "commit", "DftiCommitDescriptor failed"); + if (auto status = DftiCommitDescriptor(handle); status != DFTI_NO_ERROR) { + throw oneapi::mkl::exception("dft/backends/mklcpu", "commit", + "DftiCommitDescriptor failed"); } } @@ -74,7 +76,6 @@ class commit_derived_impl : public detail::commit_impl { } private: - DFT_ERROR status; DFTI_DESCRIPTOR_HANDLE handle = nullptr; constexpr DFTI_CONFIG_VALUE get_domain(domain d) { @@ -96,44 +97,132 @@ class commit_derived_impl : public detail::commit_impl { } template - DFT_ERROR set_value_item(DFTI_DESCRIPTOR_HANDLE hand, enum DFTI_CONFIG_PARAM name, - Args... args) { - DFT_ERROR value_err = DFT_NOTSET; - value_err = DftiSetValue(hand, name, args...); - if (value_err != DFTI_NO_ERROR) { - throw oneapi::mkl::exception("dft", "set_value_item", std::to_string(name)); + void set_value_item(DFTI_DESCRIPTOR_HANDLE hand, enum DFTI_CONFIG_PARAM name, Args... args) { + if (auto ret = DftiSetValue(hand, name, args...); ret != DFTI_NO_ERROR) { + throw oneapi::mkl::exception( + "dft/backends/mklcpu", "set_value_item", + "name: " + std::to_string(name) + " error: " + std::to_string(ret)); } - - return value_err; } - void set_value(DFTI_DESCRIPTOR_HANDLE& descHandle, detail::dft_values config) { - set_value_item(descHandle, DFTI_INPUT_STRIDES, config.input_strides.data()); - set_value_item(descHandle, DFTI_OUTPUT_STRIDES, config.output_strides.data()); - set_value_item(descHandle, DFTI_BACKWARD_SCALE, config.bwd_scale); - set_value_item(descHandle, DFTI_FORWARD_SCALE, config.fwd_scale); - set_value_item(descHandle, DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms); - set_value_item(descHandle, DFTI_INPUT_DISTANCE, config.fwd_dist); - set_value_item(descHandle, DFTI_OUTPUT_DISTANCE, config.bwd_dist); - set_value_item( - descHandle, DFTI_PLACEMENT, - (config.placement == config_value::INPLACE) ? DFTI_INPLACE : DFTI_NOT_INPLACE); + void set_value(DFTI_DESCRIPTOR_HANDLE& descHandle, detail::dft_values& config) { + using onemkl_param = dft::detail::config_param; + using backend_param = dft::config_param; + + // The following are read-only: + // Dimension, forward domain, precision, commit status. + // Lengths are supplied at descriptor construction time. + + for (const auto change : config.changed) { + switch (change) { + case onemkl_param::FORWARD_DOMAIN: + throw mkl::invalid_argument("dft/backends/mklgpu", "commit", + "cannot update FORWARD_DOMAIN parameter."); + break; + case onemkl_param::DIMENSION: + throw mkl::invalid_argument("dft/backends/mklgpu", "commit", + "cannot update DIMENSION parameter."); + break; + case onemkl_param::LENGTHS: + throw mkl::invalid_argument("dft/backends/mklgpu", "commit", + "cannot update LENGTHS parameter."); + break; + case onemkl_param::PRECISION: + throw mkl::invalid_argument("dft/backends/mklgpu", "commit", + "cannot update PRECISION parameter."); + break; + case onemkl_param::FORWARD_SCALE: + set_value_item(descHandle, DFTI_FORWARD_SCALE, config.fwd_scale); + break; + case onemkl_param::BACKWARD_SCALE: + set_value_item(descHandle, DFTI_BACKWARD_SCALE, config.bwd_scale); + break; + case onemkl_param::NUMBER_OF_TRANSFORMS: + set_value_item(descHandle, DFTI_NUMBER_OF_TRANSFORMS, + config.number_of_transforms); + break; + case onemkl_param::COMPLEX_STORAGE: + throw mkl::invalid_argument( + "dft/backends/mklgpu", "commit", + "MKLCPU does not support the COMPLEX_STORAGE parameter."); + break; + case onemkl_param::REAL_STORAGE: + throw mkl::invalid_argument( + "dft/backends/mklgpu", "commit", + "MKLCPU does not support the REAL_STORAGE parameter."); + break; + case onemkl_param::CONJUGATE_EVEN_STORAGE: + throw mkl::invalid_argument( + "dft/backends/mklgpu", "commit", + "MKLCPU does not support the CONJUGATE_EVEN_STORAGE parameter."); + break; + case onemkl_param::PLACEMENT: + set_value_item(descHandle, DFTI_PLACEMENT, + (config.placement == config_value::INPLACE) ? DFTI_INPLACE + : DFTI_NOT_INPLACE); + break; + case onemkl_param::INPUT_STRIDES: + set_value_item(descHandle, DFTI_INPUT_STRIDES, config.input_strides.data()); + break; + case onemkl_param::OUTPUT_STRIDES: + set_value_item(descHandle, DFTI_OUTPUT_STRIDES, config.output_strides.data()); + break; + case onemkl_param::FWD_DISTANCE: + // TODO forward distance vs input distance + set_value_item(descHandle, DFTI_INPUT_DISTANCE, config.fwd_dist); + break; + case onemkl_param::BWD_DISTANCE: + // TODO backward distance vs output distance + set_value_item(descHandle, DFTI_OUTPUT_DISTANCE, config.bwd_dist); + break; + case onemkl_param::WORKSPACE: + // Setting the workspace causes an FFT_INVALID_DESCRIPTOR. + throw mkl::invalid_argument("dft/backends/mklgpu", "commit", + "MKLCPU does not support the WORKSPACE parameter."); + break; + case onemkl_param::ORDERING: + // Setting the ordering causes an FFT_INVALID_DESCRIPTOR. Check that default is used: + if (config.ordering != dft::detail::config_value::ORDERED) { + throw mkl::invalid_argument("dft/backends/mklgpu", "commit", + "MKLCPU only supports ordered ordering."); + } + break; + case onemkl_param::TRANSPOSE: + // Setting the transpose causes an FFT_INVALID_DESCRIPTOR. Check that default is used: + if (config.transpose != false) { + throw mkl::invalid_argument("dft/backends/mklgpu", "commit", + "MKLCPU only supports non-transposed."); + } + break; + case onemkl_param::PACKED_FORMAT: + throw mkl::invalid_argument( + "dft/backends/mklgpu", "commit", + "MKLCPU does not support the PACKED_FORMAT parameter."); + break; + case onemkl_param::COMMIT_STATUS: + throw mkl::invalid_argument("dft/backends/mklgpu", "commit", + "cannot update COMMIT_STATUS parameter."); + break; + } + } + config.changed.clear(); } }; template -detail::commit_impl* create_commit(descriptor& desc, sycl::queue& sycl_queue) { +detail::commit_impl* create_commit(descriptor& desc, + sycl::queue& sycl_queue) { return new commit_derived_impl(sycl_queue, desc.get_values()); } -template detail::commit_impl* create_commit(descriptor&, - sycl::queue&); -template detail::commit_impl* create_commit(descriptor&, - sycl::queue&); -template detail::commit_impl* create_commit(descriptor&, - sycl::queue&); -template detail::commit_impl* create_commit(descriptor&, - sycl::queue&); +template detail::commit_impl* create_commit( + descriptor&, sycl::queue&); +template detail::commit_impl* create_commit( + descriptor&, sycl::queue&); +template detail::commit_impl* create_commit( + descriptor&, sycl::queue&); +template detail::commit_impl* create_commit( + descriptor&, sycl::queue&); } // namespace mklcpu } // namespace dft diff --git a/src/dft/backends/mklcpu/descriptor.cpp b/src/dft/backends/mklcpu/descriptor.cpp index 48791d468..6d6620998 100644 --- a/src/dft/backends/mklcpu/descriptor.cpp +++ b/src/dft/backends/mklcpu/descriptor.cpp @@ -28,7 +28,10 @@ namespace dft { template void descriptor::commit(backend_selector selector) { - pimpl_.reset(mklcpu::create_commit(*this, selector.get_queue())); + if (!pimpl_) { + pimpl_.reset(mklcpu::create_commit(*this, selector.get_queue())); + } + pimpl_->commit(selector.get_queue(), values_); } template void descriptor::commit( diff --git a/src/dft/backends/mklgpu/commit.cpp b/src/dft/backends/mklgpu/commit.cpp index 08c454c22..7a489e76c 100644 --- a/src/dft/backends/mklgpu/commit.cpp +++ b/src/dft/backends/mklgpu/commit.cpp @@ -51,7 +51,7 @@ namespace detail { /// Commit impl class specialization for MKLGPU. template -class commit_derived_impl : public dft::detail::commit_impl { +class commit_derived_impl final : public dft::detail::commit_impl { private: // Equivalent MKLGPU precision and domain from OneMKL's precision / domain. static constexpr dft::precision mklgpu_prec = to_mklgpu(prec); @@ -59,23 +59,26 @@ class commit_derived_impl : public dft::detail::commit_impl { using mklgpu_descriptor_t = dft::descriptor; public: - commit_derived_impl(sycl::queue queue, dft::detail::dft_values config_values) - : oneapi::mkl::dft::detail::commit_impl(queue, backend::mklgpu), + commit_derived_impl(sycl::queue queue, const dft::detail::dft_values& config_values) + : oneapi::mkl::dft::detail::commit_impl(queue, backend::mklgpu), handle(config_values.dimensions) { - set_value(handle, config_values); // MKLGPU does not throw an informative exception for the following: if constexpr (prec == dft::detail::precision::DOUBLE) { if (!queue.get_device().has(sycl::aspect::fp64)) { - throw mkl::exception("DFT", "commit", "Device does not support double precision."); + throw mkl::exception("dft/backends/mklgpu", "commit", "Device does not support double precision."); } } + } + virtual void commit(sycl::queue& queue, + dft::detail::dft_values& config_values) override { + set_value(handle, config_values); try { handle.commit(queue); } catch (const std::exception& mkl_exception) { // Catching the real MKL exception causes headaches with naming. - throw mkl::exception("DFT", "commit", mkl_exception.what()); + throw mkl::exception("dft/backends/mklgpu", "commit", mkl_exception.what()); } } @@ -89,61 +92,129 @@ class commit_derived_impl : public dft::detail::commit_impl { // The native MKLGPU class. mklgpu_descriptor_t handle; - void set_value(mklgpu_descriptor_t& desc, dft::detail::dft_values config) { + void set_value(mklgpu_descriptor_t& desc, dft::detail::dft_values& config) { using onemkl_param = dft::detail::config_param; using backend_param = dft::config_param; // The following are read-only: // Dimension, forward domain, precision, commit status. // Lengths are supplied at descriptor construction time. - desc.set_value(backend_param::FORWARD_SCALE, config.fwd_scale); - desc.set_value(backend_param::BACKWARD_SCALE, config.bwd_scale); - desc.set_value(backend_param::NUMBER_OF_TRANSFORMS, config.number_of_transforms); - desc.set_value(backend_param::COMPLEX_STORAGE, - to_mklgpu(config.complex_storage)); - if (config.real_storage != dft::detail::config_value::REAL_REAL) { - throw mkl::invalid_argument("DFT", "commit", - "MKLGPU only supports real-real real storage."); - } - desc.set_value(backend_param::CONJUGATE_EVEN_STORAGE, - to_mklgpu(config.conj_even_storage)); - desc.set_value(backend_param::PLACEMENT, - to_mklgpu(config.placement)); - desc.set_value(backend_param::INPUT_STRIDES, config.input_strides.data()); - desc.set_value(backend_param::OUTPUT_STRIDES, config.output_strides.data()); - desc.set_value(backend_param::FWD_DISTANCE, config.fwd_dist); - desc.set_value(backend_param::BWD_DISTANCE, config.bwd_dist); - // Setting the workspace causes an FFT_INVALID_DESCRIPTOR. - // Setting the ordering causes an FFT_INVALID_DESCRIPTOR. Check that default is used: - if (config.ordering != dft::detail::config_value::ORDERED) { - throw mkl::invalid_argument("DFT", "commit", "MKLGPU only supports ordered ordering."); - } - // Setting the transpose causes an FFT_INVALID_DESCRIPTOR. Check that default is used: - if (config.transpose != false) { - throw mkl::invalid_argument("DFT", "commit", "MKLGPU only supports non-transposed."); + + for (const auto change : config.changed) { + switch (change) { + case onemkl_param::FORWARD_DOMAIN: + throw mkl::invalid_argument("dft/backends/mklgpu", "commit", + "cannot update FORWARD_DOMAIN parameter."); + break; + case onemkl_param::DIMENSION: + throw mkl::invalid_argument("dft/backends/mklgpu", "commit", + "cannot update DIMENSION parameter."); + break; + case onemkl_param::LENGTHS: + throw mkl::invalid_argument("dft/backends/mklgpu", "commit", + "cannot update LENGTHS parameter."); + break; + case onemkl_param::PRECISION: + throw mkl::invalid_argument("dft/backends/mklgpu", "commit", + "cannot update PRECISION parameter."); + break; + case onemkl_param::FORWARD_SCALE: + desc.set_value(backend_param::FORWARD_SCALE, config.fwd_scale); + break; + case onemkl_param::BACKWARD_SCALE: + desc.set_value(backend_param::BACKWARD_SCALE, config.bwd_scale); + break; + case onemkl_param::NUMBER_OF_TRANSFORMS: + desc.set_value(backend_param::NUMBER_OF_TRANSFORMS, + config.number_of_transforms); + break; + case onemkl_param::COMPLEX_STORAGE: + desc.set_value( + backend_param::COMPLEX_STORAGE, + to_mklgpu(config.complex_storage)); + break; + case onemkl_param::REAL_STORAGE: + if (config.real_storage != dft::detail::config_value::REAL_REAL) { + throw mkl::invalid_argument("dft/backends/mklgpu", "commit", + "MKLGPU only supports real-real real storage."); + } + break; + case onemkl_param::CONJUGATE_EVEN_STORAGE: + desc.set_value( + backend_param::CONJUGATE_EVEN_STORAGE, + to_mklgpu(config.conj_even_storage)); + break; + case onemkl_param::PLACEMENT: + desc.set_value(backend_param::PLACEMENT, + to_mklgpu(config.placement)); + break; + case onemkl_param::INPUT_STRIDES: + desc.set_value(backend_param::INPUT_STRIDES, config.input_strides.data()); + break; + case onemkl_param::OUTPUT_STRIDES: + desc.set_value(backend_param::OUTPUT_STRIDES, config.output_strides.data()); + break; + case onemkl_param::FWD_DISTANCE: + desc.set_value(backend_param::FWD_DISTANCE, config.fwd_dist); + break; + case onemkl_param::BWD_DISTANCE: + desc.set_value(backend_param::BWD_DISTANCE, config.bwd_dist); + break; + case onemkl_param::WORKSPACE: + // Setting the workspace causes an FFT_INVALID_DESCRIPTOR. Check that default is used: + if (config.workspace != dft::detail::config_value::ALLOW) { + throw mkl::invalid_argument( + "dft/backends/mklgpu", "commit", + "MKLGPU does not support the WORKSPACE parameter."); + } + break; + case onemkl_param::ORDERING: + // Setting the ordering causes an FFT_INVALID_DESCRIPTOR. Check that default is used: + if (config.ordering != dft::detail::config_value::ORDERED) { + throw mkl::invalid_argument("dft/backends/mklgpu", "commit", + "MKLGPU only supports ordered ordering."); + } + break; + case onemkl_param::TRANSPOSE: + // Setting the transpose causes an FFT_INVALID_DESCRIPTOR. Check that default is used: + if (config.transpose != false) { + throw mkl::invalid_argument("dft/backends/mklgpu", "commit", + "MKLGPU only supports non-transposed."); + } + break; + case onemkl_param::PACKED_FORMAT: + desc.set_value(backend_param::PACKED_FORMAT, + to_mklgpu(config.packed_format)); + break; + case onemkl_param::COMMIT_STATUS: + throw mkl::invalid_argument("dft/backends/mklgpu", "commit", + "cannot update COMMIT_STATUS parameter."); + break; + } } - desc.set_value(backend_param::PACKED_FORMAT, - to_mklgpu(config.packed_format)); + config.changed.clear(); } }; } // namespace detail template -dft::detail::commit_impl* create_commit(dft::detail::descriptor& desc, - sycl::queue& sycl_queue) { +dft::detail::commit_impl* create_commit(dft::detail::descriptor& desc, + sycl::queue& sycl_queue) { return new detail::commit_derived_impl(sycl_queue, desc.get_values()); } -template dft::detail::commit_impl* create_commit( - dft::detail::descriptor&, - sycl::queue&); -template dft::detail::commit_impl* create_commit( +template dft::detail::commit_impl* +create_commit(dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( dft::detail::descriptor&, sycl::queue&); -template dft::detail::commit_impl* create_commit( - dft::detail::descriptor&, - sycl::queue&); -template dft::detail::commit_impl* create_commit( +template dft::detail::commit_impl* +create_commit(dft::detail::descriptor&, + sycl::queue&); +template dft::detail::commit_impl* +create_commit( dft::detail::descriptor&, sycl::queue&); diff --git a/src/dft/backends/mklgpu/descriptor.cpp b/src/dft/backends/mklgpu/descriptor.cpp index 99b7be745..cedf20685 100644 --- a/src/dft/backends/mklgpu/descriptor.cpp +++ b/src/dft/backends/mklgpu/descriptor.cpp @@ -28,7 +28,10 @@ namespace dft { template void descriptor::commit(backend_selector selector) { - pimpl_.reset(mklgpu::create_commit(*this, selector.get_queue())); + if (!pimpl_) { + pimpl_.reset(mklgpu::create_commit(*this, selector.get_queue())); + } + pimpl_->commit(selector.get_queue(), values_); } template void descriptor::commit( diff --git a/src/dft/descriptor.cxx b/src/dft/descriptor.cxx index 378d9d031..d0a6a18c7 100644 --- a/src/dft/descriptor.cxx +++ b/src/dft/descriptor.cxx @@ -44,10 +44,6 @@ void compute_default_strides(const std::vector& dimensions, template void descriptor::set_value(config_param param, ...) { - if (pimpl_) { - throw mkl::invalid_argument("DFT", "set_value", - "Cannot set value on committed descriptor."); - } va_list vl; va_start(vl, param); switch (param) { @@ -58,13 +54,7 @@ void descriptor::set_value(config_param param, ...) { throw mkl::invalid_argument("DFT", "set_value", "Read-only parameter."); break; case config_param::LENGTHS: { - if (values_.rank == 1) { - std::int64_t length = va_arg(vl, std::int64_t); - detail::set_value(values_, &length); - } - else { - detail::set_value(values_, va_arg(vl, std::int64_t*)); - } + throw mkl::invalid_argument("DFT", "set_value", "Read-only parameter."); break; } case config_param::PRECISION: @@ -113,6 +103,9 @@ void descriptor::set_value(config_param param, ...) { case config_param::TRANSPOSE: detail::set_value(values_, va_arg(vl, int)); break; + case config_param::WORKSPACE: + detail::set_value(values_, va_arg(vl, config_value)); + break; case config_param::PACKED_FORMAT: detail::set_value(values_, va_arg(vl, config_value)); break; @@ -151,7 +144,20 @@ descriptor::descriptor(std::vector dimensions) { values_.transpose = false; values_.packed_format = config_value::CCE_FORMAT; values_.dimensions = std::move(dimensions); - values_.rank = values_.dimensions.size(); + values_.changed.insert({ + config_param::BACKWARD_SCALE, + config_param::FORWARD_SCALE, + config_param::NUMBER_OF_TRANSFORMS, + config_param::FWD_DISTANCE, + config_param::BWD_DISTANCE, + config_param::PLACEMENT, + config_param::COMPLEX_STORAGE, + config_param::REAL_STORAGE, + config_param::CONJUGATE_EVEN_STORAGE, + config_param::ORDERING, + config_param::TRANSPOSE, + config_param::PACKED_FORMAT, + }); } template @@ -174,7 +180,9 @@ void descriptor::get_value(config_param param, ...) { va_start(vl, param); switch (param) { case config_param::FORWARD_DOMAIN: *va_arg(vl, dft::domain*) = dom; break; - case config_param::DIMENSION: *va_arg(vl, std::int64_t*) = values_.rank; break; + case config_param::DIMENSION: + *va_arg(vl, std::int64_t*) = static_cast(values_.dimensions.size()); + break; case config_param::LENGTHS: std::copy(values_.dimensions.begin(), values_.dimensions.end(), va_arg(vl, std::int64_t*)); diff --git a/src/dft/descriptor_config_helper.hpp b/src/dft/descriptor_config_helper.hpp index 0f932a9ec..32b8fe68b 100644 --- a/src/dft/descriptor_config_helper.hpp +++ b/src/dft/descriptor_config_helper.hpp @@ -94,18 +94,9 @@ PARAM_TYPE_HELPER(config_param::COMMIT_STATUS, config_value) template void set_value(dft_values& vals, param_type_helper_t, Param>&& set_val) { + vals.changed.insert(Param); if constexpr (Param == config_param::LENGTHS) { - int rank = vals.rank; - if (set_val == nullptr) { - throw mkl::invalid_argument("DFT", "set_value", "Given nullptr."); - } - for (int i{ 0 }; i < rank; ++i) { - if (set_val[i] <= 0) { - throw mkl::invalid_argument("DFT", "set_value", - "Invalid length value (negative or 0)."); - } - } - std::copy(set_val, set_val + rank, vals.dimensions.begin()); + throw mkl::invalid_argument("DFT", "set_value", "Read-only parameter."); } else if constexpr (Param == config_param::PRECISION) { throw mkl::invalid_argument("DFT", "set_value", "Read-only parameter."); @@ -159,18 +150,16 @@ void set_value(dft_values& vals, } } else if constexpr (Param == config_param::INPUT_STRIDES) { - int rank = vals.rank; if (set_val == nullptr) { throw mkl::invalid_argument("DFT", "set_value", "Given nullptr."); } - std::copy(set_val, set_val + vals.rank + 1, vals.input_strides.begin()); + std::copy(set_val, set_val + vals.dimensions.size() + 1, vals.input_strides.begin()); } else if constexpr (Param == config_param::OUTPUT_STRIDES) { - int rank = vals.rank; if (set_val == nullptr) { throw mkl::invalid_argument("DFT", "set_value", "Given nullptr."); } - std::copy(set_val, set_val + vals.rank + 1, vals.output_strides.begin()); + std::copy(set_val, set_val + vals.dimensions.size() + 1, vals.output_strides.begin()); } else if constexpr (Param == config_param::FWD_DISTANCE) { vals.fwd_dist = set_val; diff --git a/src/dft/dft_loader.cpp b/src/dft/dft_loader.cpp index 4a595263e..fcfb607ec 100644 --- a/src/dft/dft_loader.cpp +++ b/src/dft/dft_loader.cpp @@ -34,28 +34,28 @@ static oneapi::mkl::detail::table_initializer -commit_impl* create_commit( +commit_impl* create_commit( descriptor& desc, sycl::queue& sycl_queue) { auto libkey = get_device_id(sycl_queue); return function_tables[libkey].create_commit_sycl_fz(desc, sycl_queue); } template <> -commit_impl* create_commit( +commit_impl* create_commit( descriptor& desc, sycl::queue& sycl_queue) { auto libkey = get_device_id(sycl_queue); return function_tables[libkey].create_commit_sycl_dz(desc, sycl_queue); } template <> -commit_impl* create_commit( +commit_impl* create_commit( descriptor& desc, sycl::queue& sycl_queue) { auto libkey = get_device_id(sycl_queue); return function_tables[libkey].create_commit_sycl_fr(desc, sycl_queue); } template <> -commit_impl* create_commit( +commit_impl* create_commit( descriptor& desc, sycl::queue& sycl_queue) { auto libkey = get_device_id(sycl_queue); return function_tables[libkey].create_commit_sycl_dr(desc, sycl_queue); diff --git a/src/dft/function_table.hpp b/src/dft/function_table.hpp index f64c77f0d..4c630662d 100644 --- a/src/dft/function_table.hpp +++ b/src/dft/function_table.hpp @@ -35,19 +35,25 @@ typedef struct { int version; - oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl_fz)( + oneapi::mkl::dft::detail::commit_impl* ( + *create_commit_sycl_fz)( oneapi::mkl::dft::descriptor& desc, sycl::queue& sycl_queue); - oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl_dz)( + oneapi::mkl::dft::detail::commit_impl* ( + *create_commit_sycl_dz)( oneapi::mkl::dft::descriptor& desc, sycl::queue& sycl_queue); - oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl_fr)( + oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl_fr)( oneapi::mkl::dft::descriptor& desc, sycl::queue& sycl_queue); - oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl_dr)( + oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl_dr)( oneapi::mkl::dft::descriptor& desc, sycl::queue& sycl_queue); diff --git a/tests/unit_tests/dft/include/compute_inplace.hpp b/tests/unit_tests/dft/include/compute_inplace.hpp index a78da23dd..2390e7af1 100644 --- a/tests/unit_tests/dft/include/compute_inplace.hpp +++ b/tests/unit_tests/dft/include/compute_inplace.hpp @@ -31,6 +31,8 @@ int DFT_Test::test_in_place_buffer() { descriptor_t descriptor{ size }; descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::INPLACE); + descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size)); + commit_descriptor(descriptor, sycl_queue); const size_t container_size = domain == oneapi::mkl::dft::domain::REAL ? conjugate_even_size : size; @@ -39,7 +41,6 @@ int DFT_Test::test_in_place_buffer() { std::copy(input.cbegin(), input.cend(), inout_host.begin()); sycl::buffer inout_buf{ inout_host.data(), sycl::range<1>(container_size) }; - commit_descriptor(descriptor, sycl_queue); try { oneapi::mkl::dft::compute_forward(descriptor, inout_buf); @@ -66,15 +67,9 @@ int DFT_Test::test_in_place_buffer() { inout_host.size(), abs_error_margin, rel_error_margin, std::cout)); } - descriptor_t descriptor_back{ size }; - descriptor_back.set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::INPLACE); - descriptor_back.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size)); - commit_descriptor(descriptor_back, sycl_queue); - try { - oneapi::mkl::dft::compute_backward, - FwdInputType>(descriptor_back, inout_buf); + oneapi::mkl::dft::compute_backward, + FwdInputType>(descriptor, inout_buf); } catch (oneapi::mkl::unimplemented &e) { std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; @@ -98,6 +93,7 @@ int DFT_Test::test_in_place_USM() { descriptor_t descriptor{ size }; descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::INPLACE); + descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size)); commit_descriptor(descriptor, sycl_queue); const size_t container_size = @@ -134,17 +130,11 @@ int DFT_Test::test_in_place_USM() { abs_error_margin, rel_error_margin, std::cout)); } - descriptor_t descriptor_back{ size }; - descriptor_back.set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::INPLACE); - descriptor_back.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size)); - commit_descriptor(descriptor_back, sycl_queue); - try { std::vector dependencies; sycl::event done = - oneapi::mkl::dft::compute_backward, - FwdInputType>(descriptor_back, inout.data()); + oneapi::mkl::dft::compute_backward, + FwdInputType>(descriptor, inout.data()); done.wait(); } catch (oneapi::mkl::unimplemented &e) { diff --git a/tests/unit_tests/dft/include/compute_inplace_real_real.hpp b/tests/unit_tests/dft/include/compute_inplace_real_real.hpp index 3b9878932..8987aab47 100644 --- a/tests/unit_tests/dft/include/compute_inplace_real_real.hpp +++ b/tests/unit_tests/dft/include/compute_inplace_real_real.hpp @@ -37,6 +37,7 @@ int DFT_Test::test_in_place_real_real_USM() { oneapi::mkl::dft::config_value::INPLACE); descriptor.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, oneapi::mkl::dft::config_value::REAL_REAL); + descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size)); commit_descriptor(descriptor, sycl_queue); auto ua_input = usm_allocator_t(cxt, *dev); @@ -51,19 +52,9 @@ int DFT_Test::test_in_place_real_real_USM() { descriptor, inout_re.data(), inout_im.data(), dependencies); done.wait(); - descriptor_t descriptor_back{ size }; - - descriptor_back.set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::INPLACE); - descriptor_back.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, - oneapi::mkl::dft::config_value::REAL_REAL); - descriptor_back.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size)); - commit_descriptor(descriptor_back, sycl_queue); - - done = - oneapi::mkl::dft::compute_backward, - PrecisionType>(descriptor_back, inout_re.data(), - inout_im.data(), dependencies); + done = oneapi::mkl::dft::compute_backward, + PrecisionType>(descriptor, inout_re.data(), + inout_im.data(), dependencies); done.wait(); } catch (oneapi::mkl::unimplemented &e) { @@ -92,6 +83,7 @@ int DFT_Test::test_in_place_real_real_buffer() { oneapi::mkl::dft::config_value::INPLACE); descriptor.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, oneapi::mkl::dft::config_value::REAL_REAL); + descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size)); commit_descriptor(descriptor, sycl_queue); sycl::buffer inout_re_buf{ input_re.data(), sycl::range<1>(size) }; @@ -100,18 +92,8 @@ int DFT_Test::test_in_place_real_real_buffer() { oneapi::mkl::dft::compute_forward(descriptor, inout_re_buf, inout_im_buf); - descriptor_t descriptor_back{ size }; - - descriptor_back.set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::INPLACE); - descriptor_back.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, - oneapi::mkl::dft::config_value::REAL_REAL); - descriptor_back.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size)); - commit_descriptor(descriptor_back, sycl_queue); - - oneapi::mkl::dft::compute_backward, - PrecisionType>(descriptor_back, inout_re_buf, - inout_im_buf); + oneapi::mkl::dft::compute_backward, + PrecisionType>(descriptor, inout_re_buf, inout_im_buf); } catch (oneapi::mkl::unimplemented &e) { std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; diff --git a/tests/unit_tests/dft/include/compute_out_of_place.hpp b/tests/unit_tests/dft/include/compute_out_of_place.hpp index 2e9f005ec..daeddf04c 100644 --- a/tests/unit_tests/dft/include/compute_out_of_place.hpp +++ b/tests/unit_tests/dft/include/compute_out_of_place.hpp @@ -34,14 +34,9 @@ int DFT_Test::test_out_of_place_buffer() { descriptor_t descriptor{ size }; descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); + descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size)); commit_descriptor(descriptor, sycl_queue); - descriptor_t descriptor_back{ size }; - descriptor_back.set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::NOT_INPLACE); - descriptor_back.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size)); - commit_descriptor(descriptor_back, sycl_queue); - std::vector fwd_data(input); std::vector bwd_data(bwd_size, 0); @@ -66,9 +61,9 @@ int DFT_Test::test_out_of_place_buffer() { } try { - oneapi::mkl::dft::compute_backward, - FwdOutputType, FwdInputType>(descriptor_back, - bwd_buf, fwd_buf); + oneapi::mkl::dft::compute_backward, + FwdOutputType, FwdInputType>(descriptor, bwd_buf, + fwd_buf); } catch (oneapi::mkl::unimplemented &e) { std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; @@ -93,14 +88,9 @@ int DFT_Test::test_out_of_place_USM() { descriptor_t descriptor{ size }; descriptor.set_value(oneapi::mkl::dft::config_param::PLACEMENT, oneapi::mkl::dft::config_value::NOT_INPLACE); + descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size)); commit_descriptor(descriptor, sycl_queue); - descriptor_t descriptor_back{ size }; - descriptor_back.set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::NOT_INPLACE); - descriptor_back.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size)); - commit_descriptor(descriptor_back, sycl_queue); - auto ua_input = usm_allocator_t(cxt, *dev); auto ua_output = usm_allocator_t(cxt, *dev); @@ -121,8 +111,8 @@ int DFT_Test::test_out_of_place_USM() { rel_error_margin, std::cout)); try { - oneapi::mkl::dft::compute_backward, - FwdOutputType, FwdInputType>(descriptor_back, bwd.data(), + oneapi::mkl::dft::compute_backward, + FwdOutputType, FwdInputType>(descriptor, bwd.data(), fwd.data(), no_dependencies) .wait(); } diff --git a/tests/unit_tests/dft/include/compute_out_of_place_real_real.hpp b/tests/unit_tests/dft/include/compute_out_of_place_real_real.hpp index e3148cdb9..922af1257 100644 --- a/tests/unit_tests/dft/include/compute_out_of_place_real_real.hpp +++ b/tests/unit_tests/dft/include/compute_out_of_place_real_real.hpp @@ -37,6 +37,7 @@ int DFT_Test::test_out_of_place_real_real_USM() { oneapi::mkl::dft::config_value::NOT_INPLACE); descriptor.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, oneapi::mkl::dft::config_value::REAL_REAL); + descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size)); commit_descriptor(descriptor, sycl_queue); auto ua_input = usm_allocator_t(cxt, *dev); @@ -58,20 +59,9 @@ int DFT_Test::test_out_of_place_real_real_USM() { descriptor, in_re.data(), in_im.data(), out_re.data(), out_im.data(), dependencies); done.wait(); - descriptor_t descriptor_back{ size }; - - descriptor_back.set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::NOT_INPLACE); - descriptor_back.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, - oneapi::mkl::dft::config_value::REAL_REAL); - descriptor_back.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size)); - commit_descriptor(descriptor_back, sycl_queue); - - done = - oneapi::mkl::dft::compute_backward, - PrecisionType, PrecisionType>( - descriptor_back, out_re.data(), out_im.data(), out_back_re.data(), - out_back_im.data()); + done = oneapi::mkl::dft::compute_backward, + PrecisionType, PrecisionType>( + descriptor, out_re.data(), out_im.data(), out_back_re.data(), out_back_im.data()); done.wait(); } catch (oneapi::mkl::unimplemented &e) { @@ -100,6 +90,7 @@ int DFT_Test::test_out_of_place_real_real_buffer() { oneapi::mkl::dft::config_value::NOT_INPLACE); descriptor.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, oneapi::mkl::dft::config_value::REAL_REAL); + descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size)); commit_descriptor(descriptor, sycl_queue); sycl::buffer in_dev_re{ input_re.data(), sycl::range<1>(size) }; @@ -112,18 +103,9 @@ int DFT_Test::test_out_of_place_real_real_buffer() { oneapi::mkl::dft::compute_forward( descriptor, in_dev_re, in_dev_im, out_dev_re, out_dev_im); - descriptor_t descriptor_back{ size }; - - descriptor_back.set_value(oneapi::mkl::dft::config_param::PLACEMENT, - oneapi::mkl::dft::config_value::NOT_INPLACE); - descriptor_back.set_value(oneapi::mkl::dft::config_param::COMPLEX_STORAGE, - oneapi::mkl::dft::config_value::REAL_REAL); - descriptor_back.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / size)); - commit_descriptor(descriptor_back, sycl_queue); - - oneapi::mkl::dft::compute_backward, + oneapi::mkl::dft::compute_backward, PrecisionType, PrecisionType>( - descriptor_back, out_dev_re, out_dev_im, out_back_dev_re, out_back_dev_im); + descriptor, out_dev_re, out_dev_im, out_back_dev_re, out_back_dev_im); } catch (oneapi::mkl::unimplemented &e) { std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl; diff --git a/tests/unit_tests/dft/source/descriptor_tests.cpp b/tests/unit_tests/dft/source/descriptor_tests.cpp index 7bd955b2c..d0103ad8c 100644 --- a/tests/unit_tests/dft/source/descriptor_tests.cpp +++ b/tests/unit_tests/dft/source/descriptor_tests.cpp @@ -19,6 +19,7 @@ #include #include +#include #if __has_include() #include @@ -37,64 +38,6 @@ namespace { constexpr std::int64_t default_1d_lengths = 4; const std::vector default_3d_lengths{ 124, 5, 3 }; -template -inline void set_and_get_lengths(sycl::queue& sycl_queue) { - /* Negative Testing */ - { - oneapi::mkl::dft::descriptor descriptor{ default_3d_lengths }; - EXPECT_THROW(descriptor.set_value(oneapi::mkl::dft::config_param::LENGTHS, nullptr), - oneapi::mkl::invalid_argument); - } - - /* 1D */ - { - oneapi::mkl::dft::descriptor descriptor{ default_1d_lengths }; - - std::int64_t lengths_value{ 0 }; - std::int64_t new_lengths{ 2345 }; - std::int64_t dimensions_before_set{ 0 }; - std::int64_t dimensions_after_set{ 0 }; - - descriptor.get_value(oneapi::mkl::dft::config_param::LENGTHS, &lengths_value); - descriptor.get_value(oneapi::mkl::dft::config_param::DIMENSION, &dimensions_before_set); - EXPECT_EQ(default_1d_lengths, lengths_value); - EXPECT_EQ(dimensions_before_set, 1); - - descriptor.set_value(oneapi::mkl::dft::config_param::LENGTHS, new_lengths); - descriptor.get_value(oneapi::mkl::dft::config_param::LENGTHS, &lengths_value); - descriptor.get_value(oneapi::mkl::dft::config_param::DIMENSION, &dimensions_after_set); - EXPECT_EQ(new_lengths, lengths_value); - EXPECT_EQ(dimensions_before_set, dimensions_after_set); - - commit_descriptor(descriptor, sycl_queue); - } - - /* >= 2D */ - { - const std::int64_t dimensions = 3; - - oneapi::mkl::dft::descriptor descriptor{ default_3d_lengths }; - - std::vector lengths_value(3); - std::vector new_lengths{ 1, 2, 7 }; - std::int64_t dimensions_before_set{ 0 }; - std::int64_t dimensions_after_set{ 0 }; - - descriptor.get_value(oneapi::mkl::dft::config_param::LENGTHS, lengths_value.data()); - descriptor.get_value(oneapi::mkl::dft::config_param::DIMENSION, &dimensions_before_set); - - EXPECT_EQ(default_3d_lengths, lengths_value); - EXPECT_EQ(dimensions, dimensions_before_set); - - descriptor.set_value(oneapi::mkl::dft::config_param::LENGTHS, new_lengths.data()); - descriptor.get_value(oneapi::mkl::dft::config_param::LENGTHS, lengths_value.data()); - descriptor.get_value(oneapi::mkl::dft::config_param::DIMENSION, &dimensions_after_set); - - EXPECT_EQ(new_lengths, lengths_value); - EXPECT_EQ(dimensions_before_set, dimensions_after_set); - } -} - template inline void set_and_get_strides(sycl::queue& sycl_queue) { oneapi::mkl::dft::descriptor descriptor{ default_3d_lengths }; @@ -355,11 +298,22 @@ inline void get_readonly_values(sycl::queue& sycl_queue) { descriptor.get_value(oneapi::mkl::dft::config_param::PRECISION, &precision_value); EXPECT_EQ(precision_value, precision); + std::int64_t lengths_value{ 0 }; + descriptor.get_value(oneapi::mkl::dft::config_param::LENGTHS, &lengths_value); + EXPECT_EQ(default_1d_lengths, lengths_value); + std::int64_t dimension_value; descriptor.get_value(oneapi::mkl::dft::config_param::DIMENSION, &dimension_value); EXPECT_EQ(dimension_value, 1); oneapi::mkl::dft::descriptor descriptor3D{ default_3d_lengths }; + + std::array lengths_value_3d; + descriptor3D.get_value(oneapi::mkl::dft::config_param::LENGTHS, lengths_value_3d.data()); + EXPECT_EQ(default_3d_lengths[0], lengths_value_3d[0]); + EXPECT_EQ(default_3d_lengths[1], lengths_value_3d[1]); + EXPECT_EQ(default_3d_lengths[2], lengths_value_3d[2]); + descriptor3D.get_value(oneapi::mkl::dft::config_param::DIMENSION, &dimension_value); EXPECT_EQ(dimension_value, 3); @@ -390,6 +344,11 @@ inline void set_readonly_values(sycl::queue& sycl_queue) { oneapi::mkl::dft::precision::DOUBLE), oneapi::mkl::invalid_argument); + EXPECT_THROW(descriptor.set_value(oneapi::mkl::dft::config_param::LENGTHS, std::int64_t{ 10 }), + oneapi::mkl::invalid_argument); + std::array new_lengths{ 8, 8, 8 }; + EXPECT_THROW(descriptor.set_value(oneapi::mkl::dft::config_param::LENGTHS, new_lengths.data()), + oneapi::mkl::invalid_argument); std::int64_t set_dimension{ 3 }; EXPECT_THROW(descriptor.set_value(oneapi::mkl::dft::config_param::DIMENSION, set_dimension), oneapi::mkl::invalid_argument); @@ -404,6 +363,56 @@ inline void set_readonly_values(sycl::queue& sycl_queue) { commit_descriptor(descriptor, sycl_queue); } +template +inline void recommit_values(sycl::queue& sycl_queue) { + using oneapi::mkl::dft::config_param; + using oneapi::mkl::dft::config_value; + using PrecisionType = + typename std::conditional_t; + using value = std::variant; + + // this will hold a param to change and the value to change it to + using test_params = std::pair; + + oneapi::mkl::dft::descriptor descriptor{ default_1d_lengths }; + EXPECT_NO_THROW(commit_descriptor(descriptor, sycl_queue)); + + std::array strides{ 0, 1 }; + + std::vector arguments{ + // not changeable + // FORWARD_DOMAIN, PRECISION , DIMENSION, LENGTHS, COMMIT_STATUS + std::make_pair(config_param::FORWARD_SCALE, PrecisionType{ 1.2 }), + std::make_pair(config_param::BACKWARD_SCALE, PrecisionType{ 3.4 }), + std::make_pair(config_param::NUMBER_OF_TRANSFORMS, std::int64_t{ 5 }), + std::make_pair(config_param::COMPLEX_STORAGE, config_value::COMPLEX_COMPLEX), + std::make_pair(config_param::REAL_STORAGE, config_value::REAL_REAL), + std::make_pair(config_param::CONJUGATE_EVEN_STORAGE, config_value::COMPLEX_COMPLEX), + std::make_pair(config_param::PLACEMENT, config_value::INPLACE), + std::make_pair(config_param::INPUT_STRIDES, strides.data()), + std::make_pair(config_param::OUTPUT_STRIDES, strides.data()), + std::make_pair(config_param::FWD_DISTANCE, std::int64_t{ 6 }), + std::make_pair(config_param::BWD_DISTANCE, std::int64_t{ 7 }), + std::make_pair(config_param::WORKSPACE, config_value::ALLOW), + std::make_pair(config_param::ORDERING, config_value::ORDERED), + std::make_pair(config_param::TRANSPOSE, bool{ false }), + std::make_pair(config_param::PACKED_FORMAT, config_value::CCE_FORMAT) + }; + + for (int i = 0; i < arguments.size(); i += 1) { + std::visit( + [&arguments, &descriptor, i](auto&& a) { descriptor.set_value(arguments[i].first, a); }, + arguments[i].second); + try { + commit_descriptor(descriptor, sycl_queue); + } + catch (oneapi::mkl::exception& e) { + FAIL() << "exception at index " << i << " with error : " << e.what(); + } + } +} + template int test(sycl::device* dev) { sycl::queue sycl_queue(*dev, exception_handler); @@ -415,11 +424,11 @@ int test(sycl::device* dev) { } } - set_and_get_lengths(sycl_queue); set_and_get_strides(sycl_queue); set_and_get_values(sycl_queue); get_readonly_values(sycl_queue); set_readonly_values(sycl_queue); + recommit_values(sycl_queue); return !::testing::Test::HasFailure(); } diff --git a/tests/unit_tests/main_test.cpp b/tests/unit_tests/main_test.cpp index 2172a579f..6ac1a9565 100644 --- a/tests/unit_tests/main_test.cpp +++ b/tests/unit_tests/main_test.cpp @@ -102,9 +102,9 @@ int main(int argc, char** argv) { for (auto dev : plat_devs) { try { /* Do not test for OpenCL backend on GPU */ - if (dev.is_gpu() && plat.get_info().find( - "OpenCL") != std::string::npos) - continue; + // if (dev.is_gpu() && plat.get_info().find( + // "OpenCL") != std::string::npos) + // continue; if (unique_devices.find(dev.get_info()) == unique_devices.end()) { unique_devices.insert(dev.get_info());