Skip to content

Commit

Permalink
Fix swapped input and output strides
Browse files Browse the repository at this point in the history
  • Loading branch information
FMarno committed Mar 2, 2023
1 parent 152731d commit 6358f18
Showing 1 changed file with 9 additions and 6 deletions.
15 changes: 9 additions & 6 deletions src/dft/backends/cufft/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ namespace detail {
template <dft::precision prec, dft::domain dom>
class cufft_commit final : public dft::detail::commit_impl {
private:
// For real to complex transforms, the "type" arg also encodes the direction (e.g. CUFFT_R2C vs CUFFT_C2R) in the plan so we must have one for each direction.
// We also need this because oneMKL uses a directionless "FWD_DISTANCE" and "BWD_DISTANCE" while cuFFT uses a directional "idist" and "odist".
// plans[0] is forward, plans[1] is backward
std::array<cufftHandle, 2> plans;

Expand All @@ -54,7 +56,7 @@ class cufft_commit final : public dft::detail::commit_impl {
}

// The cudaStream for the plan is set at execution time so the interop handler can pick the stream.
const cufftType fwd_type = [] {
constexpr cufftType fwd_type = [] {
if constexpr (dom == dft::domain::COMPLEX) {
if constexpr (prec == dft::precision::SINGLE) {
return CUFFT_C2C;
Expand All @@ -72,7 +74,7 @@ class cufft_commit final : public dft::detail::commit_impl {
}
}
}();
const cufftType bwd_type = [] {
constexpr cufftType bwd_type = [] {
if constexpr (dom == dft::domain::COMPLEX) {
if constexpr (prec == dft::precision::SINGLE) {
return CUFFT_C2C;
Expand Down Expand Up @@ -131,15 +133,16 @@ class cufft_commit final : public dft::detail::commit_impl {
batch // batch
);

// flip fwd_distance and bwd_distance because cuFFt uses input distance and output distance.
// backward plan
cufftPlanMany(&plans[1], // plan
rank, // rank
n_copy.data(), // n
onembed.data(), // inembed
ostride, // istride
inembed.data(), // inembed
istride, // istride
bwd_dist, // idist
inembed.data(), // onembed
istride, // ostride
onembed.data(), // onembed
ostride, // ostride
fwd_dist, // odist
bwd_type, // type
batch // batch
Expand Down

0 comments on commit 6358f18

Please sign in to comment.