Skip to content

Commit

Permalink
Add an exception for when the user tries to scale with cufft
Browse files Browse the repository at this point in the history
  • Loading branch information
FMarno committed Apr 20, 2023
1 parent 1d20490 commit 3840668
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
6 changes: 6 additions & 0 deletions src/dft/backends/cufft/commit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,12 @@ class cufft_commit final : public dft::detail::commit_impl<prec, dom> {
// this could be a recommit
clean_plans();

if (config_values.fwd_scale != 1.0 || config_values.bwd_scale != 1.0) {
throw mkl::unimplemented(
"dft/backends/cufft", __FUNCTION__,
"cuFFT does not support values other than 1 for FORWARD/BACKWARD_SCALE");
}

// The cudaStream for the plan is set at execution time so the interop handler can pick the stream.
constexpr cufftType fwd_type = [] {
if constexpr (dom == dft::domain::COMPLEX) {
Expand Down
12 changes: 8 additions & 4 deletions tests/unit_tests/dft/source/descriptor_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -429,9 +429,6 @@ inline void recommit_values(sycl::queue& sycl_queue) {
std::vector<test_params> argument_groups{
// not changeable
// FORWARD_DOMAIN, PRECISION, DIMENSION, COMMIT_STATUS
{ std::make_pair(config_param::LENGTHS, std::int64_t{ 10 }),
std::make_pair(config_param::FORWARD_SCALE, PrecisionType(1.2)),
std::make_pair(config_param::BACKWARD_SCALE, PrecisionType(3.4)) },
{ std::make_pair(config_param::NUMBER_OF_TRANSFORMS, std::int64_t{ 5 }),
std::make_pair(config_param::COMPLEX_STORAGE, config_value::COMPLEX_COMPLEX),
std::make_pair(config_param::REAL_STORAGE, config_value::REAL_REAL),
Expand All @@ -444,7 +441,10 @@ inline void recommit_values(sycl::queue& sycl_queue) {
{ std::make_pair(config_param::WORKSPACE, config_value::ALLOW),
std::make_pair(config_param::ORDERING, config_value::ORDERED),
std::make_pair(config_param::TRANSPOSE, bool{ false }),
std::make_pair(config_param::PACKED_FORMAT, config_value::CCE_FORMAT) }
std::make_pair(config_param::PACKED_FORMAT, config_value::CCE_FORMAT) },
{ std::make_pair(config_param::LENGTHS, std::int64_t{ 10 }),
std::make_pair(config_param::FORWARD_SCALE, PrecisionType( 1.2 )),
std::make_pair(config_param::BACKWARD_SCALE, PrecisionType( 3.4 )) }
};

for (std::size_t i = 0; i < argument_groups.size(); i += 1) {
Expand All @@ -455,6 +455,10 @@ inline void recommit_values(sycl::queue& sycl_queue) {
try {
commit_descriptor(descriptor, sycl_queue);
}
catch (oneapi::mkl::unimplemented e) {
std::cout << "unimplemented exception at index " << i << " with error : " << e.what()
<< "\ncontinuing...\n";
}
catch (oneapi::mkl::exception& e) {
FAIL() << "exception at index " << i << " with error : " << e.what();
}
Expand Down

0 comments on commit 3840668

Please sign in to comment.