Skip to content

Commit

Permalink
WIP: enable recommit in cuda backend
Browse files Browse the repository at this point in the history
  • Loading branch information
FMarno committed Mar 30, 2023
1 parent 46b450a commit e6478b4
Show file tree
Hide file tree
Showing 7 changed files with 88 additions and 42 deletions.
1 change: 1 addition & 0 deletions include/oneapi/mkl/dft/detail/cufft/onemkl_dft_cufft.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ namespace oneapi::mkl::dft {

namespace detail {
// Forward declarations
template <precision prec, domain dom>
class commit_impl;

template <precision prec, domain dom>
Expand Down
99 changes: 65 additions & 34 deletions src/dft/backends/cufft/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,10 @@

#include <array>
#include <algorithm>
#include <optional>

#include <cufft.h>
#include <cuda.h>

#include "oneapi/mkl/dft/detail/commit_impl.hpp"
#include "oneapi/mkl/dft/detail/descriptor_impl.hpp"
Expand All @@ -39,21 +41,35 @@ namespace detail {

/// Commit impl class specialization for cuFFT.
template <dft::precision prec, dft::domain dom>
class cufft_commit final : public dft::detail::commit_impl {
class cufft_commit final : public dft::detail::commit_impl<prec, dom> {
private:
// For real to complex transforms, the "type" arg also encodes the direction (e.g. CUFFT_R2C vs CUFFT_C2R) in the plan so we must have one for each direction.
// We also need this because oneMKL uses a directionless "FWD_DISTANCE" and "BWD_DISTANCE" while cuFFT uses a directional "idist" and "odist".
// plans[0] is forward, plans[1] is backward
std::array<cufftHandle, 2> plans;
std::optional<std::array<cufftHandle, 2>> plans = std::nullopt;

public:
cufft_commit(sycl::queue& queue, const dft::detail::dft_values<prec, dom>& config_values)
: oneapi::mkl::dft::detail::commit_impl(queue, backend::cufft) {
: oneapi::mkl::dft::detail::commit_impl<prec, dom>(queue, backend::cufft) {
if constexpr (prec == dft::detail::precision::DOUBLE) {
if (!queue.get_device().has(sycl::aspect::fp64)) {
throw mkl::exception("DFT", "commit", "Device does not support double precision.");
}
}
}

void clean_plans() {
if (plans) {
cufftDestroy(plans.value()[0]);
cufftDestroy(plans.value()[1]);
plans = std::nullopt;
}
}

void commit(const dft::detail::dft_values<prec, dom>& config_values) override {
cuCtxSynchronize();
// this could be a recommit
clean_plans();

// The cudaStream for the plan is set at execution time so the interop handler can pick the stream.
constexpr cufftType fwd_type = [] {
Expand Down Expand Up @@ -119,63 +135,78 @@ class cufft_commit final : public dft::detail::commit_impl {
onembed[1] = config_values.output_strides[1] / onembed[2];
}

cufftHandle fwd_plan, bwd_plan;

// forward plan
cufftPlanMany(&plans[0], // plan
rank, // rank
n_copy.data(), // n
inembed.data(), // inembed
istride, // istride
fwd_dist, // idist
onembed.data(), // onembed
ostride, // ostride
bwd_dist, // odist
fwd_type, // type
batch // batch
auto res = cufftPlanMany(&fwd_plan, // plan
rank, // rank
n_copy.data(), // n
inembed.data(), // inembed
istride, // istride
fwd_dist, // idist
onembed.data(), // onembed
ostride, // ostride
bwd_dist, // odist
fwd_type, // type
batch // batch
);

if (res != CUFFT_SUCCESS) {
std::cout << "bad fwd\n";
}

// Flip fwd_distance and bwd_distance because cuFFT uses input distance and output distance.
// backward plan
cufftPlanMany(&plans[1], // plan
rank, // rank
n_copy.data(), // n
inembed.data(), // inembed
istride, // istride
bwd_dist, // idist
onembed.data(), // onembed
ostride, // ostride
fwd_dist, // odist
bwd_type, // type
batch // batch
res = cufftPlanMany(&bwd_plan, // plan
rank, // rank
n_copy.data(), // n
inembed.data(), // inembed
istride, // istride
bwd_dist, // idist
onembed.data(), // onembed
ostride, // ostride
fwd_dist, // odist
bwd_type, // type
batch // batch
);
if (res != CUFFT_SUCCESS) {
std::cout << "bad bwd\n";
}
plans = { fwd_plan, bwd_plan };

cuCtxSynchronize();
}

~cufft_commit() override {
cufftDestroy(plans[0]);
cufftDestroy(plans[1]);
clean_plans();
}

void* get_handle() noexcept override {
return plans.data();
return plans.value().data();
}
};
} // namespace detail

template <dft::precision prec, dft::domain dom>
dft::detail::commit_impl* create_commit(const dft::detail::descriptor<prec, dom>& desc,
sycl::queue& sycl_queue) {
dft::detail::commit_impl<prec, dom>* create_commit(const dft::detail::descriptor<prec, dom>& desc,
sycl::queue& sycl_queue) {
return new detail::cufft_commit<prec, dom>(sycl_queue, desc.get_values());
}

template dft::detail::commit_impl* create_commit(
template dft::detail::commit_impl<dft::detail::precision::SINGLE, dft::detail::domain::REAL>*
create_commit(
const dft::detail::descriptor<dft::detail::precision::SINGLE, dft::detail::domain::REAL>&,
sycl::queue&);
template dft::detail::commit_impl* create_commit(
template dft::detail::commit_impl<dft::detail::precision::SINGLE, dft::detail::domain::COMPLEX>*
create_commit(
const dft::detail::descriptor<dft::detail::precision::SINGLE, dft::detail::domain::COMPLEX>&,
sycl::queue&);
template dft::detail::commit_impl* create_commit(
template dft::detail::commit_impl<dft::detail::precision::DOUBLE, dft::detail::domain::REAL>*
create_commit(
const dft::detail::descriptor<dft::detail::precision::DOUBLE, dft::detail::domain::REAL>&,
sycl::queue&);
template dft::detail::commit_impl* create_commit(
template dft::detail::commit_impl<dft::detail::precision::DOUBLE, dft::detail::domain::COMPLEX>*
create_commit(
const dft::detail::descriptor<dft::detail::precision::DOUBLE, dft::detail::domain::COMPLEX>&,
sycl::queue&);

Expand Down
8 changes: 7 additions & 1 deletion src/dft/backends/cufft/descriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,13 @@ namespace dft {

template <precision prec, domain dom>
void descriptor<prec, dom>::commit(backend_selector<backend::cufft> selector) {
pimpl_.reset(cufft::create_commit(*this, selector.get_queue()));
if (!pimpl_ || pimpl_->get_queue() != selector.get_queue()) {
if (pimpl_) {
pimpl_->get_queue().wait();
}
pimpl_.reset(cufft::create_commit(*this, selector.get_queue()));
}
pimpl_->commit(values_);
}

template void descriptor<precision::SINGLE, domain::COMPLEX>::commit(
Expand Down
3 changes: 2 additions & 1 deletion src/dft/backends/cufft/execute_helper.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,8 @@
namespace oneapi::mkl::dft::cufft::detail {

template <dft::precision prec, dft::domain dom>
inline dft::detail::commit_impl *checked_get_commit(dft::detail::descriptor<prec, dom> &desc) {
inline dft::detail::commit_impl<prec, dom> *checked_get_commit(
dft::detail::descriptor<prec, dom> &desc) {
auto commit_handle = dft::detail::get_commit(desc);
if (commit_handle == nullptr || commit_handle->get_backend() != backend::cufft) {
throw mkl::invalid_argument("dft/backends/cufft", "get_commit",
Expand Down
5 changes: 5 additions & 0 deletions tests/unit_tests/dft/include/compute_inplace.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ int DFT_Test<precision, domain>::test_in_place_buffer() {
complex_strides.data());
}

std::cout << "descriptor\n";
commit_descriptor(descriptor, sycl_queue);

std::vector<FwdInputType> inout_host(container_size_total, 0);
Expand Down Expand Up @@ -150,12 +151,16 @@ int DFT_Test<precision, domain>::test_in_place_buffer() {
complex_strides.data());
descriptor.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES,
real_strides.data());
std::cout << "recommitting\n";
commit_descriptor(descriptor, sycl_queue);
std::cout << "recommitted\n";
}

try {
sycl_queue.wait_and_throw();
oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>,
FwdInputType>(descriptor, inout_buf);
sycl_queue.wait_and_throw();
}
catch (oneapi::mkl::unimplemented& e) {
std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl;
Expand Down
8 changes: 3 additions & 5 deletions tests/unit_tests/dft/include/compute_out_of_place.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ int DFT_Test<precision, domain>::test_out_of_place_buffer() {
descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches);
descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_elements);
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, backward_distance);
descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / forward_elements));
if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
const auto complex_strides = get_conjugate_even_complex_strides(sizes);
descriptor.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES,
Expand Down Expand Up @@ -109,7 +108,7 @@ int DFT_Test<precision, domain>::test_out_of_place_buffer() {
}

// account for scaling that occurs during DFT
std::for_each(input.begin(), input.end(), [this](auto &x) { x *= size; });
std::for_each(input.begin(), input.end(), [this](auto &x) { x *= forward_elements; });

EXPECT_TRUE(check_equal_vector(fwd_data.data(), input.data(), input.size(), abs_error_margin,
rel_error_margin, std::cout));
Expand All @@ -132,7 +131,6 @@ int DFT_Test<precision, domain>::test_out_of_place_USM() {
descriptor.set_value(oneapi::mkl::dft::config_param::NUMBER_OF_TRANSFORMS, batches);
descriptor.set_value(oneapi::mkl::dft::config_param::FWD_DISTANCE, forward_elements);
descriptor.set_value(oneapi::mkl::dft::config_param::BWD_DISTANCE, backward_distance);
descriptor.set_value(oneapi::mkl::dft::config_param::BACKWARD_SCALE, (1.0 / forward_elements));
if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
const auto complex_strides = get_conjugate_even_complex_strides(sizes);
descriptor.set_value(oneapi::mkl::dft::config_param::OUTPUT_STRIDES,
Expand Down Expand Up @@ -184,15 +182,15 @@ int DFT_Test<precision, domain>::test_out_of_place_USM() {
oneapi::mkl::dft::compute_backward<std::remove_reference_t<decltype(descriptor)>,
FwdOutputType, FwdInputType>(descriptor, bwd.data(),
fwd.data(), no_dependencies)
.wait();
.wait_and_throw();
}
catch (oneapi::mkl::unimplemented &e) {
std::cout << "Skipping test because: \"" << e.what() << "\"" << std::endl;
return test_skipped;
}

// account for scaling that occurs during DFT
std::for_each(input.begin(), input.end(), [this](auto &x) { x *= size; });
std::for_each(input.begin(), input.end(), [this](auto &x) { x *= forward_elements; });

EXPECT_TRUE(check_equal_vector(fwd.data(), input.data(), input.size(), abs_error_margin,
rel_error_margin, std::cout));
Expand Down
6 changes: 5 additions & 1 deletion tests/unit_tests/dft/include/compute_tester.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,11 @@ struct DFT_Test {
// improvement here if test performance becomes an issue.
out_host_ref = std::vector<FwdOutputType>(size_total);

rand_vector(input, size_total);
//rand_vector(input, size_total);
for (int i = 0; i < input.size(); ++i) {
input[i] = i;
}

if constexpr (domain == oneapi::mkl::dft::domain::REAL) {
for (int i = 0; i < input.size(); ++i) {
input_re[i] = { input[i] };
Expand Down

0 comments on commit e6478b4

Please sign in to comment.