Skip to content

Commit

Permalink
Allow descriptors to be recommited for certain params
Browse files Browse the repository at this point in the history
  • Loading branch information
FMarno committed Feb 20, 2023
1 parent f7ba088 commit aad7522
Show file tree
Hide file tree
Showing 22 changed files with 413 additions and 278 deletions.
6 changes: 5 additions & 1 deletion include/oneapi/mkl/dft/detail/commit_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,22 @@
#endif

#include "oneapi/mkl/detail/backends.hpp"
#include "types_impl.hpp"

namespace oneapi {
namespace mkl {
namespace dft {
namespace detail {

template <precision prec, domain dom>
class commit_impl {
public:
commit_impl(sycl::queue queue, mkl::backend backend)
: queue_(queue),
backend_(backend),
status(false) {}

commit_impl(const commit_impl& other) = default;
commit_impl(const commit_impl& other) = delete;

virtual ~commit_impl() = default;

Expand All @@ -55,6 +57,8 @@ class commit_impl {

virtual void* get_handle() = 0;

virtual void commit(sycl::queue& queue, dft_values<prec, dom>&) = 0;

protected:
bool status;
mkl::backend backend_;
Expand Down
10 changes: 5 additions & 5 deletions include/oneapi/mkl/dft/detail/descriptor_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ template <precision prec, domain dom>
class descriptor;

template <precision prec, domain dom>
inline commit_impl* get_commit(descriptor<prec, dom>& desc);
inline commit_impl<prec, dom>* get_commit(descriptor<prec, dom>& desc);

template <precision prec, domain dom>
class descriptor {
Expand All @@ -68,22 +68,22 @@ class descriptor {
void commit(backend_selector<backend::mklgpu> selector);
#endif

dft_values<prec, dom> get_values() {
dft_values<prec, dom>& get_values() {
return values_;
};

private:
// Has a value when the descriptor is committed.
std::unique_ptr<commit_impl> pimpl_;
std::unique_ptr<commit_impl<prec, dom>> pimpl_;

// descriptor configuration values_ and structs
dft_values<prec, dom> values_;

friend commit_impl* get_commit<prec, dom>(descriptor<prec, dom>&);
friend commit_impl<prec, dom>* get_commit<prec, dom>(descriptor<prec, dom>&);
};

template <precision prec, domain dom>
inline commit_impl* get_commit(descriptor<prec, dom>& desc) {
inline commit_impl<prec, dom>* get_commit(descriptor<prec, dom>& desc) {
return desc.pimpl_.get();
}

Expand Down
4 changes: 2 additions & 2 deletions include/oneapi/mkl/dft/detail/dft_ct.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@
// Commit

template <dft::detail::precision prec, dft::detail::domain dom>
ONEMKL_EXPORT dft::detail::commit_impl *create_commit(dft::detail::descriptor<prec, dom> &desc,
sycl::queue &sycl_queue);
ONEMKL_EXPORT dft::detail::commit_impl<prec, dom> *create_commit(
dft::detail::descriptor<prec, dom> &desc, sycl::queue &sycl_queue);

// BUFFER version

Expand Down
4 changes: 3 additions & 1 deletion include/oneapi/mkl/dft/detail/dft_loader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ namespace mkl {
namespace dft {
namespace detail {

template <precision prec, domain dom>
class commit_impl;

template <precision prec, domain dom>
class descriptor;

template <precision prec, domain dom>
ONEMKL_EXPORT commit_impl* create_commit(descriptor<prec, dom>& desc, sycl::queue& queue);
ONEMKL_EXPORT commit_impl<prec, dom>* create_commit(descriptor<prec, dom>& desc,
sycl::queue& queue);

} // namespace detail
} // namespace dft
Expand Down
1 change: 1 addition & 0 deletions include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace dft {

namespace detail {
// Forward declarations
template <precision prec, domain dom>
class commit_impl;

template <precision prec, domain dom>
Expand Down
1 change: 1 addition & 0 deletions include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace dft {

namespace detail {
// Forward declarations
template <precision prec, domain dom>
class commit_impl;

template <precision prec, domain dom>
Expand Down
4 changes: 3 additions & 1 deletion include/oneapi/mkl/dft/detail/types_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <cstdint>
#include <vector>
#include <type_traits>
#include <set>

namespace oneapi {
namespace mkl {
Expand Down Expand Up @@ -100,6 +101,8 @@ class dft_values {
using real_t = std::conditional_t<prec == precision::SINGLE, float, double>;

public:
std::set<config_param> changed;

std::vector<std::int64_t> input_strides;
std::vector<std::int64_t> output_strides;
real_t bwd_scale;
Expand All @@ -116,7 +119,6 @@ class dft_values {
bool transpose;
config_value packed_format;
std::vector<std::int64_t> dimensions;
std::int64_t rank;
};

} // namespace detail
Expand Down
5 changes: 4 additions & 1 deletion src/dft/backends/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ namespace dft {

template <precision prec, domain dom>
void descriptor<prec, dom>::commit(sycl::queue &queue) {
pimpl_.reset(detail::create_commit(*this, queue));
if (!pimpl_) {
pimpl_.reset(detail::create_commit(*this, queue));
}
pimpl_->commit(queue, values_);
}
template void descriptor<precision::SINGLE, domain::COMPLEX>::commit(sycl::queue &);
template void descriptor<precision::SINGLE, domain::REAL>::commit(sycl::queue &);
Expand Down
167 changes: 128 additions & 39 deletions src/dft/backends/mklcpu/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,12 +40,12 @@ namespace dft {
namespace mklcpu {

template <precision prec, domain dom>
class commit_derived_impl : public detail::commit_impl {
class commit_derived_impl final : public detail::commit_impl<prec, dom> {
public:
commit_derived_impl(sycl::queue queue, detail::dft_values<prec, dom> config_values)
: detail::commit_impl(queue, backend::mklcpu),
status(DFT_NOTSET) {
if (config_values.rank == 1) {
commit_derived_impl(sycl::queue queue, const detail::dft_values<prec, dom>& config_values)
: detail::commit_impl<prec, dom>(queue, backend::mklcpu) {
DFT_ERROR status = DFT_NOTSET;
if (config_values.dimensions.size() == 1) {
status = DftiCreateDescriptor(&handle, get_precision(prec), get_domain(dom),
config_values.rank, config_values.dimensions[0]);
}
Expand All @@ -54,14 +54,16 @@ class commit_derived_impl : public detail::commit_impl {
config_values.rank, config_values.dimensions.data());
}
if (status != DFTI_NO_ERROR) {
throw oneapi::mkl::exception("dft", "commit", "DftiCreateDescriptor failed");
throw oneapi::mkl::exception("dft/backends/mklcpu", "commit",
"DftiCreateDescriptor failed");
}
}

void commit(sycl::queue&, detail::dft_values<prec, dom>& config_values) override {
set_value(handle, config_values);

status = DftiCommitDescriptor(handle);
if (status != DFTI_NO_ERROR) {
throw oneapi::mkl::exception("dft", "commit", "DftiCommitDescriptor failed");
if (auto status = DftiCommitDescriptor(handle); status != DFTI_NO_ERROR) {
throw oneapi::mkl::exception("dft/backends/mklcpu", "commit",
"DftiCommitDescriptor failed");
}
}

Expand All @@ -74,7 +76,6 @@ class commit_derived_impl : public detail::commit_impl {
}

private:
DFT_ERROR status;
DFTI_DESCRIPTOR_HANDLE handle = nullptr;

constexpr DFTI_CONFIG_VALUE get_domain(domain d) {
Expand All @@ -96,44 +97,132 @@ class commit_derived_impl : public detail::commit_impl {
}

template <typename... Args>
DFT_ERROR set_value_item(DFTI_DESCRIPTOR_HANDLE hand, enum DFTI_CONFIG_PARAM name,
Args... args) {
DFT_ERROR value_err = DFT_NOTSET;
value_err = DftiSetValue(hand, name, args...);
if (value_err != DFTI_NO_ERROR) {
throw oneapi::mkl::exception("dft", "set_value_item", std::to_string(name));
void set_value_item(DFTI_DESCRIPTOR_HANDLE hand, enum DFTI_CONFIG_PARAM name, Args... args) {
if (auto ret = DftiSetValue(hand, name, args...); ret != DFTI_NO_ERROR) {
throw oneapi::mkl::exception(
"dft/backends/mklcpu", "set_value_item",
"name: " + std::to_string(name) + " error: " + std::to_string(ret));
}

return value_err;
}

void set_value(DFTI_DESCRIPTOR_HANDLE& descHandle, detail::dft_values<prec, dom> config) {
set_value_item(descHandle, DFTI_INPUT_STRIDES, config.input_strides.data());
set_value_item(descHandle, DFTI_OUTPUT_STRIDES, config.output_strides.data());
set_value_item(descHandle, DFTI_BACKWARD_SCALE, config.bwd_scale);
set_value_item(descHandle, DFTI_FORWARD_SCALE, config.fwd_scale);
set_value_item(descHandle, DFTI_NUMBER_OF_TRANSFORMS, config.number_of_transforms);
set_value_item(descHandle, DFTI_INPUT_DISTANCE, config.fwd_dist);
set_value_item(descHandle, DFTI_OUTPUT_DISTANCE, config.bwd_dist);
set_value_item(
descHandle, DFTI_PLACEMENT,
(config.placement == config_value::INPLACE) ? DFTI_INPLACE : DFTI_NOT_INPLACE);
void set_value(DFTI_DESCRIPTOR_HANDLE& descHandle, detail::dft_values<prec, dom>& config) {
using onemkl_param = dft::detail::config_param;
using backend_param = dft::config_param;

// The following are read-only:
// Dimension, forward domain, precision, commit status.
// Lengths are supplied at descriptor construction time.

for (const auto change : config.changed) {
switch (change) {
case onemkl_param::FORWARD_DOMAIN:
throw mkl::invalid_argument("dft/backends/mklgpu", "commit",
"cannot update FORWARD_DOMAIN parameter.");
break;
case onemkl_param::DIMENSION:
throw mkl::invalid_argument("dft/backends/mklgpu", "commit",
"cannot update DIMENSION parameter.");
break;
case onemkl_param::LENGTHS:
throw mkl::invalid_argument("dft/backends/mklgpu", "commit",
"cannot update LENGTHS parameter.");
break;
case onemkl_param::PRECISION:
throw mkl::invalid_argument("dft/backends/mklgpu", "commit",
"cannot update PRECISION parameter.");
break;
case onemkl_param::FORWARD_SCALE:
set_value_item(descHandle, DFTI_FORWARD_SCALE, config.fwd_scale);
break;
case onemkl_param::BACKWARD_SCALE:
set_value_item(descHandle, DFTI_BACKWARD_SCALE, config.bwd_scale);
break;
case onemkl_param::NUMBER_OF_TRANSFORMS:
set_value_item(descHandle, DFTI_NUMBER_OF_TRANSFORMS,
config.number_of_transforms);
break;
case onemkl_param::COMPLEX_STORAGE:
throw mkl::invalid_argument(
"dft/backends/mklgpu", "commit",
"MKLCPU does not support the COMPLEX_STORAGE parameter.");
break;
case onemkl_param::REAL_STORAGE:
throw mkl::invalid_argument(
"dft/backends/mklgpu", "commit",
"MKLCPU does not support the REAL_STORAGE parameter.");
break;
case onemkl_param::CONJUGATE_EVEN_STORAGE:
throw mkl::invalid_argument(
"dft/backends/mklgpu", "commit",
"MKLCPU does not support the CONJUGATE_EVEN_STORAGE parameter.");
break;
case onemkl_param::PLACEMENT:
set_value_item(descHandle, DFTI_PLACEMENT,
(config.placement == config_value::INPLACE) ? DFTI_INPLACE
: DFTI_NOT_INPLACE);
break;
case onemkl_param::INPUT_STRIDES:
set_value_item(descHandle, DFTI_INPUT_STRIDES, config.input_strides.data());
break;
case onemkl_param::OUTPUT_STRIDES:
set_value_item(descHandle, DFTI_OUTPUT_STRIDES, config.output_strides.data());
break;
case onemkl_param::FWD_DISTANCE:
// TODO forward distance vs input distance
set_value_item(descHandle, DFTI_INPUT_DISTANCE, config.fwd_dist);
break;
case onemkl_param::BWD_DISTANCE:
// TODO backward distance vs output distance
set_value_item(descHandle, DFTI_OUTPUT_DISTANCE, config.bwd_dist);
break;
case onemkl_param::WORKSPACE:
// Setting the workspace causes an FFT_INVALID_DESCRIPTOR.
throw mkl::invalid_argument("dft/backends/mklgpu", "commit",
"MKLCPU does not support the WORKSPACE parameter.");
break;
case onemkl_param::ORDERING:
// Setting the ordering causes an FFT_INVALID_DESCRIPTOR. Check that default is used:
if (config.ordering != dft::detail::config_value::ORDERED) {
throw mkl::invalid_argument("dft/backends/mklgpu", "commit",
"MKLCPU only supports ordered ordering.");
}
break;
case onemkl_param::TRANSPOSE:
// Setting the transpose causes an FFT_INVALID_DESCRIPTOR. Check that default is used:
if (config.transpose != false) {
throw mkl::invalid_argument("dft/backends/mklgpu", "commit",
"MKLCPU only supports non-transposed.");
}
break;
case onemkl_param::PACKED_FORMAT:
throw mkl::invalid_argument(
"dft/backends/mklgpu", "commit",
"MKLCPU does not support the PACKED_FORMAT parameter.");
break;
case onemkl_param::COMMIT_STATUS:
throw mkl::invalid_argument("dft/backends/mklgpu", "commit",
"cannot update COMMIT_STATUS parameter.");
break;
}
}
config.changed.clear();
}
};

template <precision prec, domain dom>
detail::commit_impl* create_commit(descriptor<prec, dom>& desc, sycl::queue& sycl_queue) {
detail::commit_impl<prec, dom>* create_commit(descriptor<prec, dom>& desc,
sycl::queue& sycl_queue) {
return new commit_derived_impl<prec, dom>(sycl_queue, desc.get_values());
}

template detail::commit_impl* create_commit(descriptor<precision::SINGLE, domain::REAL>&,
sycl::queue&);
template detail::commit_impl* create_commit(descriptor<precision::SINGLE, domain::COMPLEX>&,
sycl::queue&);
template detail::commit_impl* create_commit(descriptor<precision::DOUBLE, domain::REAL>&,
sycl::queue&);
template detail::commit_impl* create_commit(descriptor<precision::DOUBLE, domain::COMPLEX>&,
sycl::queue&);
template detail::commit_impl<precision::SINGLE, domain::REAL>* create_commit(
descriptor<precision::SINGLE, domain::REAL>&, sycl::queue&);
template detail::commit_impl<precision::SINGLE, domain::COMPLEX>* create_commit(
descriptor<precision::SINGLE, domain::COMPLEX>&, sycl::queue&);
template detail::commit_impl<precision::DOUBLE, domain::REAL>* create_commit(
descriptor<precision::DOUBLE, domain::REAL>&, sycl::queue&);
template detail::commit_impl<precision::DOUBLE, domain::COMPLEX>* create_commit(
descriptor<precision::DOUBLE, domain::COMPLEX>&, sycl::queue&);

} // namespace mklcpu
} // namespace dft
Expand Down
5 changes: 4 additions & 1 deletion src/dft/backends/mklcpu/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,10 @@ namespace dft {

template <precision prec, domain dom>
void descriptor<prec, dom>::commit(backend_selector<backend::mklcpu> selector) {
pimpl_.reset(mklcpu::create_commit(*this, selector.get_queue()));
if (!pimpl_) {
pimpl_.reset(mklcpu::create_commit(*this, selector.get_queue()));
}
pimpl_->commit(selector.get_queue(), values_);
}

template void descriptor<precision::SINGLE, domain::COMPLEX>::commit(
Expand Down
Loading

0 comments on commit aad7522

Please sign in to comment.