Skip to content

Commit

Permalink
[DFT] Allow the descriptor to be modified and recommitted (#282)
Browse files Browse the repository at this point in the history
* Allow descriptors to be recommitted for certain params

* remove unneeded argument from commit_impl::commit

* formatting

* allow commit to change the queue used

* Wait on the old queue when it is removed from a commit_impl

Also report the error status in the CPU backend

* style change

* modified descriptor recommit tests

* made recommit test out-of-place for simplicity

* Add test for queue that goes out of scope

* use the same device for any new queues in tests

* update output strides for backward stage of out-of-place tests
  • Loading branch information
FMarno authored Mar 29, 2023
1 parent e3fe270 commit f58b93f
Show file tree
Hide file tree
Showing 21 changed files with 339 additions and 222 deletions.
8 changes: 8 additions & 0 deletions include/oneapi/mkl/dft/detail/commit_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,12 @@ enum class backend;

namespace oneapi::mkl::dft::detail {

enum class precision;
enum class domain;
template <precision prec, domain dom>
class dft_values;

template <precision prec, domain dom>
class commit_impl {
public:
commit_impl(sycl::queue queue, mkl::backend backend) : queue_(queue), backend_(backend) {}
Expand All @@ -51,6 +57,8 @@ class commit_impl {

virtual void* get_handle() noexcept = 0;

virtual void commit(const dft_values<prec, dom>&) = 0;

private:
mkl::backend backend_;
sycl::queue queue_;
Expand Down
8 changes: 4 additions & 4 deletions include/oneapi/mkl/dft/detail/descriptor_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ template <precision prec, domain dom>
class descriptor;

template <precision prec, domain dom>
inline commit_impl* get_commit(descriptor<prec, dom>& desc);
inline commit_impl<prec, dom>* get_commit(descriptor<prec, dom>& desc);

template <precision prec, domain dom>
class descriptor {
Expand Down Expand Up @@ -74,16 +74,16 @@ class descriptor {

private:
// Has a value when the descriptor is committed.
std::unique_ptr<commit_impl> pimpl_;
std::unique_ptr<commit_impl<prec, dom>> pimpl_;

// descriptor configuration values_ and structs
dft_values<prec, dom> values_;

friend commit_impl* get_commit<prec, dom>(descriptor<prec, dom>&);
friend commit_impl<prec, dom>* get_commit<prec, dom>(descriptor<prec, dom>&);
};

/// Returns a non-owning pointer to the descriptor's commit implementation.
///
/// The pointer is whatever `desc.pimpl_` currently holds; per the member's
/// own comment it only has a value once the descriptor has been committed,
/// so the result is null before that.
template <precision prec, domain dom>
inline commit_impl<prec, dom>* get_commit(descriptor<prec, dom>& desc) {
    return desc.pimpl_.get();
}

Expand Down
2 changes: 1 addition & 1 deletion include/oneapi/mkl/dft/detail/dft_ct.hxx
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
// Commit

template <dft::detail::precision prec, dft::detail::domain dom>
ONEMKL_EXPORT dft::detail::commit_impl *create_commit(
ONEMKL_EXPORT dft::detail::commit_impl<prec, dom> *create_commit(
const dft::detail::descriptor<prec, dom> &desc, sycl::queue &sycl_queue);

// BUFFER version
Expand Down
4 changes: 3 additions & 1 deletion include/oneapi/mkl/dft/detail/dft_loader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,15 @@ namespace mkl {
namespace dft {
namespace detail {

template <precision prec, domain dom>
class commit_impl;

template <precision prec, domain dom>
class descriptor;

template <precision prec, domain dom>
ONEMKL_EXPORT commit_impl* create_commit(const descriptor<prec, dom>& desc, sycl::queue& queue);
ONEMKL_EXPORT commit_impl<prec, dom>* create_commit(const descriptor<prec, dom>& desc,
sycl::queue& queue);

} // namespace detail
} // namespace dft
Expand Down
1 change: 1 addition & 0 deletions include/oneapi/mkl/dft/detail/mklcpu/onemkl_dft_mklcpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace dft {

namespace detail {
// Forward declarations
template <precision prec, domain dom>
class commit_impl;

template <precision prec, domain dom>
Expand Down
1 change: 1 addition & 0 deletions include/oneapi/mkl/dft/detail/mklgpu/onemkl_dft_mklgpu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ namespace dft {

namespace detail {
// Forward declarations
template <precision prec, domain dom>
class commit_impl;

template <precision prec, domain dom>
Expand Down
8 changes: 7 additions & 1 deletion src/dft/backends/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ namespace dft {

/// Commits (or re-commits) the descriptor on `queue`.
///
/// A backend implementation is created only when none exists yet or when
/// `queue` differs from the queue held by the current implementation;
/// otherwise the existing implementation is reused. In both cases the current
/// configuration values are then handed to the implementation's `commit`.
template <precision prec, domain dom>
void descriptor<prec, dom>::commit(sycl::queue &queue) {
    if (!pimpl_ || pimpl_->get_queue() != queue) {
        if (pimpl_) {
            // Wait on the queue being replaced before its implementation is destroyed.
            pimpl_->get_queue().wait();
        }
        pimpl_.reset(detail::create_commit(*this, queue));
    }
    pimpl_->commit(values_);
}
template void descriptor<precision::SINGLE, domain::COMPLEX>::commit(sycl::queue &);
template void descriptor<precision::SINGLE, domain::REAL>::commit(sycl::queue &);
Expand Down
38 changes: 21 additions & 17 deletions src/dft/backends/mklcpu/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,10 @@ namespace dft {
namespace mklcpu {

template <precision prec, domain dom>
class commit_derived_impl final : public detail::commit_impl {
class commit_derived_impl final : public detail::commit_impl<prec, dom> {
public:
commit_derived_impl(sycl::queue queue, const detail::dft_values<prec, dom>& config_values)
: detail::commit_impl(queue, backend::mklcpu) {
: detail::commit_impl<prec, dom>(queue, backend::mklcpu) {
DFT_ERROR status = DFT_NOTSET;
if (config_values.dimensions.size() == 1) {
status = DftiCreateDescriptor(&handle, get_precision(prec), get_domain(dom), 1,
Expand All @@ -55,16 +55,19 @@ class commit_derived_impl final : public detail::commit_impl {
config_values.dimensions.data());
}
if (status != DFTI_NO_ERROR) {
throw oneapi::mkl::exception("dft/backends/mklcpu", "commit",
"DftiCreateDescriptor failed");
throw oneapi::mkl::exception(
"dft/backends/mklcpu", "commit",
"DftiCreateDescriptor failed with status: " + std::to_string(status));
}
}

void commit(const detail::dft_values<prec, dom>& config_values) override {
set_value(handle, config_values);

status = DftiCommitDescriptor(handle);
auto status = DftiCommitDescriptor(handle);
if (status != DFTI_NO_ERROR) {
throw oneapi::mkl::exception("dft/backends/mklcpu", "commit",
"DftiCommitDescriptor failed");
throw oneapi::mkl::exception(
"dft/backends/mklcpu", "commit",
"DftiCommitDescriptor failed with status: " + std::to_string(status));
}
}

Expand Down Expand Up @@ -122,18 +125,19 @@ class commit_derived_impl final : public detail::commit_impl {
};

/// Creates the MKLCPU commit implementation for `desc` on `sycl_queue`.
///
/// Returns a raw pointer to a newly allocated `commit_derived_impl`; the
/// caller takes ownership (the descriptor stores it in a `std::unique_ptr`).
template <precision prec, domain dom>
detail::commit_impl<prec, dom>* create_commit(const descriptor<prec, dom>& desc,
                                              sycl::queue& sycl_queue) {
    return new commit_derived_impl<prec, dom>(sycl_queue, desc.get_values());
}

template detail::commit_impl* create_commit(const descriptor<precision::SINGLE, domain::REAL>&,
sycl::queue&);
template detail::commit_impl* create_commit(const descriptor<precision::SINGLE, domain::COMPLEX>&,
sycl::queue&);
template detail::commit_impl* create_commit(const descriptor<precision::DOUBLE, domain::REAL>&,
sycl::queue&);
template detail::commit_impl* create_commit(const descriptor<precision::DOUBLE, domain::COMPLEX>&,
sycl::queue&);
template detail::commit_impl<precision::SINGLE, domain::REAL>* create_commit(
const descriptor<precision::SINGLE, domain::REAL>&, sycl::queue&);
template detail::commit_impl<precision::SINGLE, domain::COMPLEX>* create_commit(
const descriptor<precision::SINGLE, domain::COMPLEX>&, sycl::queue&);
template detail::commit_impl<precision::DOUBLE, domain::REAL>* create_commit(
const descriptor<precision::DOUBLE, domain::REAL>&, sycl::queue&);
template detail::commit_impl<precision::DOUBLE, domain::COMPLEX>* create_commit(
const descriptor<precision::DOUBLE, domain::COMPLEX>&, sycl::queue&);

} // namespace mklcpu
} // namespace dft
Expand Down
8 changes: 7 additions & 1 deletion src/dft/backends/mklcpu/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ namespace dft {

/// Commits (or re-commits) the descriptor for the MKLCPU backend.
///
/// A backend implementation is created only when none exists yet or when the
/// selector's queue differs from the queue held by the current implementation;
/// otherwise the existing implementation is reused. In both cases the current
/// configuration values are then handed to the implementation's `commit`.
template <precision prec, domain dom>
void descriptor<prec, dom>::commit(backend_selector<backend::mklcpu> selector) {
    if (!pimpl_ || pimpl_->get_queue() != selector.get_queue()) {
        if (pimpl_) {
            // Wait on the queue being replaced before its implementation is destroyed.
            pimpl_->get_queue().wait();
        }
        // This is the mklcpu entry point, so the implementation must come from
        // the mklcpu factory (not mklgpu).
        pimpl_.reset(mklcpu::create_commit(*this, selector.get_queue()));
    }
    pimpl_->commit(values_);
}

template void descriptor<precision::SINGLE, domain::COMPLEX>::commit(
Expand Down
28 changes: 16 additions & 12 deletions src/dft/backends/mklgpu/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ namespace detail {

/// Commit impl class specialization for MKLGPU.
template <dft::detail::precision prec, dft::detail::domain dom>
class commit_derived_impl final : public dft::detail::commit_impl {
class commit_derived_impl final : public dft::detail::commit_impl<prec, dom> {
private:
// Equivalent MKLGPU precision and domain from OneMKL's precision / domain.
static constexpr dft::precision mklgpu_prec = to_mklgpu(prec);
Expand All @@ -60,19 +60,21 @@ class commit_derived_impl final : public dft::detail::commit_impl {

public:
commit_derived_impl(sycl::queue queue, const dft::detail::dft_values<prec, dom>& config_values)
: oneapi::mkl::dft::detail::commit_impl(queue, backend::mklgpu),
: oneapi::mkl::dft::detail::commit_impl<prec, dom>(queue, backend::mklgpu),
handle(config_values.dimensions) {
set_value(handle, config_values);
// MKLGPU does not throw an informative exception for the following:
if constexpr (prec == dft::detail::precision::DOUBLE) {
if (!queue.get_device().has(sycl::aspect::fp64)) {
throw mkl::exception("dft/backends/mklgpu", "commit",
"Device does not support double precision.");
}
}
}

virtual void commit(const dft::detail::dft_values<prec, dom>& config_values) override {
set_value(handle, config_values);
try {
handle.commit(queue);
handle.commit(this->get_queue());
}
catch (const std::exception& mkl_exception) {
// Catching the real MKL exception causes headaches with naming.
Expand Down Expand Up @@ -125,28 +127,30 @@ class commit_derived_impl final : public dft::detail::commit_impl {
throw mkl::invalid_argument("dft/backends/mklgpu", "commit",
"MKLGPU only supports non-transposed.");
}
desc.set_value(backend_param::PACKED_FORMAT,
to_mklgpu<onemkl_param::PACKED_FORMAT>(config.packed_format));
}
};
} // namespace detail

/// Creates the MKLGPU commit implementation for `desc` on `sycl_queue`.
///
/// Returns a raw pointer to a newly allocated `detail::commit_derived_impl`;
/// the caller takes ownership (the descriptor stores it in a `std::unique_ptr`).
template <dft::detail::precision prec, dft::detail::domain dom>
dft::detail::commit_impl<prec, dom>* create_commit(const dft::detail::descriptor<prec, dom>& desc,
                                                   sycl::queue& sycl_queue) {
    return new detail::commit_derived_impl<prec, dom>(sycl_queue, desc.get_values());
}

template dft::detail::commit_impl* create_commit(
template dft::detail::commit_impl<dft::detail::precision::SINGLE, dft::detail::domain::REAL>*
create_commit(
const dft::detail::descriptor<dft::detail::precision::SINGLE, dft::detail::domain::REAL>&,
sycl::queue&);
template dft::detail::commit_impl* create_commit(
template dft::detail::commit_impl<dft::detail::precision::SINGLE, dft::detail::domain::COMPLEX>*
create_commit(
const dft::detail::descriptor<dft::detail::precision::SINGLE, dft::detail::domain::COMPLEX>&,
sycl::queue&);
template dft::detail::commit_impl* create_commit(
template dft::detail::commit_impl<dft::detail::precision::DOUBLE, dft::detail::domain::REAL>*
create_commit(
const dft::detail::descriptor<dft::detail::precision::DOUBLE, dft::detail::domain::REAL>&,
sycl::queue&);
template dft::detail::commit_impl* create_commit(
template dft::detail::commit_impl<dft::detail::precision::DOUBLE, dft::detail::domain::COMPLEX>*
create_commit(
const dft::detail::descriptor<dft::detail::precision::DOUBLE, dft::detail::domain::COMPLEX>&,
sycl::queue&);

Expand Down
8 changes: 7 additions & 1 deletion src/dft/backends/mklgpu/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ namespace dft {

/// Commits (or re-commits) the descriptor for the MKLGPU backend.
///
/// A backend implementation is created only when none exists yet or when the
/// selector's queue differs from the queue held by the current implementation;
/// otherwise the existing implementation is reused. In both cases the current
/// configuration values are then handed to the implementation's `commit`.
template <precision prec, domain dom>
void descriptor<prec, dom>::commit(backend_selector<backend::mklgpu> selector) {
    if (!pimpl_ || pimpl_->get_queue() != selector.get_queue()) {
        if (pimpl_) {
            // Wait on the queue being replaced before its implementation is destroyed.
            pimpl_->get_queue().wait();
        }
        pimpl_.reset(mklgpu::create_commit(*this, selector.get_queue()));
    }
    pimpl_->commit(values_);
}

template void descriptor<precision::SINGLE, domain::COMPLEX>::commit(
Expand Down
4 changes: 0 additions & 4 deletions src/dft/descriptor.cxx
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,6 @@ void compute_default_strides(const std::vector<std::int64_t>& dimensions,

template <precision prec, domain dom>
void descriptor<prec, dom>::set_value(config_param param, ...) {
if (pimpl_) {
throw mkl::invalid_argument("DFT", "set_value",
"Cannot set value on committed descriptor.");
}
va_list vl;
va_start(vl, param);
switch (param) {
Expand Down
8 changes: 4 additions & 4 deletions src/dft/dft_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,28 +34,28 @@ static oneapi::mkl::detail::table_initializer<mkl::domain::dft, dft_function_tab
function_tables;

/// Loader entry point for single-precision complex descriptors.
///
/// Obtains a key from `get_device_id(sycl_queue)` and forwards to the
/// `create_commit_sycl_fz` entry of the function table selected by that key.
template <>
commit_impl<precision::SINGLE, domain::COMPLEX>* create_commit<precision::SINGLE, domain::COMPLEX>(
    const descriptor<precision::SINGLE, domain::COMPLEX>& desc, sycl::queue& sycl_queue) {
    auto libkey = get_device_id(sycl_queue);
    return function_tables[libkey].create_commit_sycl_fz(desc, sycl_queue);
}

/// Loader entry point for double-precision complex descriptors.
///
/// Obtains a key from `get_device_id(sycl_queue)` and forwards to the
/// `create_commit_sycl_dz` entry of the function table selected by that key.
template <>
commit_impl<precision::DOUBLE, domain::COMPLEX>* create_commit<precision::DOUBLE, domain::COMPLEX>(
    const descriptor<precision::DOUBLE, domain::COMPLEX>& desc, sycl::queue& sycl_queue) {
    auto libkey = get_device_id(sycl_queue);
    return function_tables[libkey].create_commit_sycl_dz(desc, sycl_queue);
}

/// Loader entry point for single-precision real descriptors.
///
/// Obtains a key from `get_device_id(sycl_queue)` and forwards to the
/// `create_commit_sycl_fr` entry of the function table selected by that key.
template <>
commit_impl<precision::SINGLE, domain::REAL>* create_commit<precision::SINGLE, domain::REAL>(
    const descriptor<precision::SINGLE, domain::REAL>& desc, sycl::queue& sycl_queue) {
    auto libkey = get_device_id(sycl_queue);
    return function_tables[libkey].create_commit_sycl_fr(desc, sycl_queue);
}

template <>
commit_impl* create_commit<precision::DOUBLE, domain::REAL>(
commit_impl<precision::DOUBLE, domain::REAL>* create_commit<precision::DOUBLE, domain::REAL>(
const descriptor<precision::DOUBLE, domain::REAL>& desc, sycl::queue& sycl_queue) {
auto libkey = get_device_id(sycl_queue);
return function_tables[libkey].create_commit_sycl_dr(desc, sycl_queue);
Expand Down
14 changes: 10 additions & 4 deletions src/dft/function_table.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,19 +35,25 @@

typedef struct {
int version;
oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl_fz)(
oneapi::mkl::dft::detail::commit_impl<oneapi::mkl::dft::precision::SINGLE,
oneapi::mkl::dft::domain::COMPLEX>* (
*create_commit_sycl_fz)(
const oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::SINGLE,
oneapi::mkl::dft::domain::COMPLEX>& desc,
sycl::queue& sycl_queue);
oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl_dz)(
oneapi::mkl::dft::detail::commit_impl<oneapi::mkl::dft::precision::DOUBLE,
oneapi::mkl::dft::domain::COMPLEX>* (
*create_commit_sycl_dz)(
const oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::DOUBLE,
oneapi::mkl::dft::domain::COMPLEX>& desc,
sycl::queue& sycl_queue);
oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl_fr)(
oneapi::mkl::dft::detail::commit_impl<oneapi::mkl::dft::precision::SINGLE,
oneapi::mkl::dft::domain::REAL>* (*create_commit_sycl_fr)(
const oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::SINGLE,
oneapi::mkl::dft::domain::REAL>& desc,
sycl::queue& sycl_queue);
oneapi::mkl::dft::detail::commit_impl* (*create_commit_sycl_dr)(
oneapi::mkl::dft::detail::commit_impl<oneapi::mkl::dft::precision::DOUBLE,
oneapi::mkl::dft::domain::REAL>* (*create_commit_sycl_dr)(
const oneapi::mkl::dft::descriptor<oneapi::mkl::dft::precision::DOUBLE,
oneapi::mkl::dft::domain::REAL>& desc,
sycl::queue& sycl_queue);
Expand Down
Loading

1 comment on commit f58b93f

@HaoweiZhangIntel
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it should be mklcpu instead of mklgpu @ src/dft/backends/mklcpu/descriptor.cpp:
pimpl_.reset(mklgpu::create_commit(*this, selector.get_queue()));

Please sign in to comment.