Skip to content

Commit

Permalink
Add missing file; add checks for MRHS.
Browse files Browse the repository at this point in the history
  • Loading branch information
hummingtree committed Dec 3, 2024
1 parent a3e2f4a commit 979ba1a
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 0 deletions.
6 changes: 6 additions & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,12 @@ configure_file(multigrid.in.hpp multigrid.hpp @ONLY)
if(QUDA_MULTIGRID)
string(REPLACE "," ";" QUDA_MULTIGRID_NVEC_LIST_SEMICOLON "${QUDA_MULTIGRID_NVEC_LIST}")
string(REPLACE "," ";" QUDA_MULTIGRID_MRHS_LIST_SEMICOLON "${QUDA_MULTIGRID_MRHS_LIST}")
foreach(QUDA_MULTIGRID_MRHS ${QUDA_MULTIGRID_MRHS_LIST_SEMICOLON})
math(EXPR MOD "${QUDA_MULTIGRID_MRHS} % 8")
if(NOT (${QUDA_MULTIGRID_MRHS} GREATER_EQUAL "8" AND ${QUDA_MULTIGRID_MRHS} LESS_EQUAL "64" AND ${MOD} EQUAL "0"))
message(FATAL_ERROR "${QUDA_MULTIGRID_MRHS} is not a valid value for QUDA_MULTIGRID_MRHS_LIST")
endif()
endforeach()
foreach(QUDA_MULTIGRID_NVEC ${QUDA_MULTIGRID_NVEC_LIST_SEMICOLON})
configure_file(copy_gauge_mg.in.cu "copy_gauge_mg_${QUDA_MULTIGRID_NVEC}.cu" @ONLY)
configure_file(extract_gauge_ghost_mg.in.cu "extract_gauge_ghost_mg_${QUDA_MULTIGRID_NVEC}.cu" @ONLY)
Expand Down
52 changes: 52 additions & 0 deletions lib/multigrid.in.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
#include <util_quda.h>
#include <reference_wrapper_helper.h>
#include <color_spinor_field.h>
#include <array>
#include <algorithm>

namespace quda {

template <int... Ints> struct IntList {
};

template <int... Values>
auto sort_values() {
std::array<int, sizeof...(Values)> arr = {Values...};
// std::sort is NOT constexpr until C++20
std::sort(arr.begin(), arr.end());
return arr;
}

inline int round_to_nearest_instantiated_nVec(int input_nVec) {
// clang-format off
auto sorted_nVecs = sort_values<@QUDA_MULTIGRID_MRHS_LIST@>();
// clang-format on
for (int nVec: sorted_nVecs) {
if (input_nVec <= nVec) {
return nVec;
}
}
errorQuda("No instantiated nVec able to contain input nVec = %d", input_nVec);
return 0;
}

template <class F> auto create_color_spinor_copy(cvector_ref<F> &fs, QudaFieldOrder order)
{
ColorSpinorParam param(fs[0]);
int nVec = round_to_nearest_instantiated_nVec(fs.size());
param.nColor = fs[0].Ncolor() * nVec;
param.nVec = nVec;
param.create = QUDA_NULL_FIELD_CREATE;
param.fieldOrder = order;
return ColorSpinorField(param);
}

inline auto create_color_spinor_copy(const ColorSpinorField &f, QudaFieldOrder order)
{
ColorSpinorParam param(f);
param.create = QUDA_NULL_FIELD_CREATE;
param.fieldOrder = order;
return ColorSpinorField(param);
}

}

0 comments on commit 979ba1a

Please sign in to comment.