From 979ba1a047aa8c7e08d22630c20e3ad2f27edd9c Mon Sep 17 00:00:00 2001 From: Jiqun Tu Date: Tue, 3 Dec 2024 15:08:09 -0800 Subject: [PATCH] Add missing file; add checks for MRHS. --- lib/CMakeLists.txt | 6 +++++ lib/multigrid.in.hpp | 52 ++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 58 insertions(+) create mode 100644 lib/multigrid.in.hpp diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index 3da3fbbad4..017341adc5 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -191,6 +191,12 @@ configure_file(multigrid.in.hpp multigrid.hpp @ONLY) if(QUDA_MULTIGRID) string(REPLACE "," ";" QUDA_MULTIGRID_NVEC_LIST_SEMICOLON "${QUDA_MULTIGRID_NVEC_LIST}") string(REPLACE "," ";" QUDA_MULTIGRID_MRHS_LIST_SEMICOLON "${QUDA_MULTIGRID_MRHS_LIST}") + foreach(QUDA_MULTIGRID_MRHS ${QUDA_MULTIGRID_MRHS_LIST_SEMICOLON}) + math(EXPR MOD "${QUDA_MULTIGRID_MRHS} % 8") + if(NOT (${QUDA_MULTIGRID_MRHS} GREATER_EQUAL "8" AND ${QUDA_MULTIGRID_MRHS} LESS_EQUAL "64" AND ${MOD} EQUAL "0")) + message(FATAL_ERROR "${QUDA_MULTIGRID_MRHS} is not a valid value for QUDA_MULTIGRID_MRHS_LIST") + endif() + endforeach() foreach(QUDA_MULTIGRID_NVEC ${QUDA_MULTIGRID_NVEC_LIST_SEMICOLON}) configure_file(copy_gauge_mg.in.cu "copy_gauge_mg_${QUDA_MULTIGRID_NVEC}.cu" @ONLY) configure_file(extract_gauge_ghost_mg.in.cu "extract_gauge_ghost_mg_${QUDA_MULTIGRID_NVEC}.cu" @ONLY) diff --git a/lib/multigrid.in.hpp b/lib/multigrid.in.hpp new file mode 100644 index 0000000000..46da3568e0 --- /dev/null +++ b/lib/multigrid.in.hpp @@ -0,0 +1,52 @@ +#include +#include +#include +#include +#include + +namespace quda { + + template struct IntList { + }; + + template + auto sort_values() { + std::array arr = {Values...}; + // std::sort is NOT constexpr until C++20 + std::sort(arr.begin(), arr.end()); + return arr; + } + + inline int round_to_nearest_instantiated_nVec(int input_nVec) { + // clang-format off + auto sorted_nVecs = sort_values<@QUDA_MULTIGRID_MRHS_LIST@>(); + // clang-format on + for (int nVec: sorted_nVecs) { + if (input_nVec <= nVec) { + return nVec; + } + } + errorQuda("No instantiated nVec able to contain input nVec = %d", input_nVec); + return 0; + } + + template auto create_color_spinor_copy(cvector_ref &fs, QudaFieldOrder order) + { + ColorSpinorParam param(fs[0]); + int nVec = round_to_nearest_instantiated_nVec(fs.size()); + param.nColor = fs[0].Ncolor() * nVec; + param.nVec = nVec; + param.create = QUDA_NULL_FIELD_CREATE; + param.fieldOrder = order; + return ColorSpinorField(param); + } + + inline auto create_color_spinor_copy(const ColorSpinorField &f, QudaFieldOrder order) + { + ColorSpinorParam param(f); + param.create = QUDA_NULL_FIELD_CREATE; + param.fieldOrder = order; + return ColorSpinorField(param); + } + +}