diff --git a/src/dft/backends/cufft/commit.cpp b/src/dft/backends/cufft/commit.cpp index cb4d3dca4..81d14bc7c 100644 --- a/src/dft/backends/cufft/commit.cpp +++ b/src/dft/backends/cufft/commit.cpp @@ -41,6 +41,8 @@ namespace detail { template class cufft_commit final : public dft::detail::commit_impl { private: + // For real to complex transforms, the "type" arg also encodes the direction (e.g. CUFFT_R2C vs CUFFT_C2R) in the plan so we must have one for each direction. + // We also need this because oneMKL uses a directionless "FWD_DISTANCE" and "BWD_DISTANCE" while cuFFT uses a directional "idist" and "odist". // plans[0] is forward, plans[1] is backward std::array plans; @@ -54,7 +56,7 @@ class cufft_commit final : public dft::detail::commit_impl { } // The cudaStream for the plan is set at execution time so the interop handler can pick the stream. - const cufftType fwd_type = [] { + constexpr cufftType fwd_type = [] { if constexpr (dom == dft::domain::COMPLEX) { if constexpr (prec == dft::precision::SINGLE) { return CUFFT_C2C; @@ -72,7 +74,7 @@ class cufft_commit final : public dft::detail::commit_impl { } } }(); - const cufftType bwd_type = [] { + constexpr cufftType bwd_type = [] { if constexpr (dom == dft::domain::COMPLEX) { if constexpr (prec == dft::precision::SINGLE) { return CUFFT_C2C; @@ -131,15 +133,16 @@ class cufft_commit final : public dft::detail::commit_impl { batch // batch ); + // flip fwd_distance and bwd_distance because cuFFt uses input distance and output distance. // backward plan cufftPlanMany(&plans[1], // plan rank, // rank n_copy.data(), // n - onembed.data(), // inembed - ostride, // istride + inembed.data(), // inembed + istride, // istride bwd_dist, // idist - inembed.data(), // onembed - istride, // ostride + onembed.data(), // onembed + ostride, // ostride fwd_dist, // odist bwd_type, // type batch // batch