-
Notifications
You must be signed in to change notification settings - Fork 102
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add missing file; add checks for MRHS.
- Loading branch information
1 parent
a3e2f4a
commit 979ba1a
Showing
2 changed files
with
58 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
#include <util_quda.h> | ||
#include <reference_wrapper_helper.h> | ||
#include <color_spinor_field.h> | ||
#include <array> | ||
#include <algorithm> | ||
|
||
namespace quda { | ||
|
||
template <int... Ints> struct IntList { | ||
}; | ||
|
||
template <int... Values> | ||
auto sort_values() { | ||
std::array<int, sizeof...(Values)> arr = {Values...}; | ||
// std::sort is NOT constexpr until C++20 | ||
std::sort(arr.begin(), arr.end()); | ||
return arr; | ||
} | ||
|
||
inline int round_to_nearest_instantiated_nVec(int input_nVec) { | ||
// clang-format off | ||
auto sorted_nVecs = sort_values<@QUDA_MULTIGRID_MRHS_LIST@>(); | ||
// clang-format on | ||
for (int nVec: sorted_nVecs) { | ||
if (input_nVec <= nVec) { | ||
return nVec; | ||
} | ||
} | ||
errorQuda("No instantiated nVec able to contain input nVec = %d", input_nVec); | ||
return 0; | ||
} | ||
|
||
template <class F> auto create_color_spinor_copy(cvector_ref<F> &fs, QudaFieldOrder order) | ||
{ | ||
ColorSpinorParam param(fs[0]); | ||
int nVec = round_to_nearest_instantiated_nVec(fs.size()); | ||
param.nColor = fs[0].Ncolor() * nVec; | ||
param.nVec = nVec; | ||
param.create = QUDA_NULL_FIELD_CREATE; | ||
param.fieldOrder = order; | ||
return ColorSpinorField(param); | ||
} | ||
|
||
inline auto create_color_spinor_copy(const ColorSpinorField &f, QudaFieldOrder order) | ||
{ | ||
ColorSpinorParam param(f); | ||
param.create = QUDA_NULL_FIELD_CREATE; | ||
param.fieldOrder = order; | ||
return ColorSpinorField(param); | ||
} | ||
|
||
} |