diff --git a/CMakeLists.txt b/CMakeLists.txt index abcf27de99..4043970062 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -151,8 +151,8 @@ option(QUDA_DIRAC_CLOVER_HASENBUSCH "build clover Hasenbusch twist operators" ${ option(QUDA_DIRAC_NDEG_TWISTED_MASS "build non-degenerate twisted mass Dirac operators" ${QUDA_DIRAC_DEFAULT}) option(QUDA_DIRAC_NDEG_TWISTED_CLOVER "build non-degenerate twisted clover Dirac operators" ${QUDA_DIRAC_DEFAULT}) option(QUDA_DIRAC_LAPLACE "build laplace operator" ${QUDA_DIRAC_DEFAULT}) - option(QUDA_DIRAC_DISTANCE_PRECONDITIONING "build code for distance preconditioned Wilson/clover Dirac operators" OFF) +set(QUDA_DOMAIN_DECOMPOSITION "0" CACHE STRING "which domain decomposition to instantiate in QUDA (1-bit number - RedBlack)") option(QUDA_COVDEV "build code for covariant derivative" OFF) diff --git a/ci/docker/Dockerfile.build b/ci/docker/Dockerfile.build index 5ca1d7a3dc..3019b28c38 100644 --- a/ci/docker/Dockerfile.build +++ b/ci/docker/Dockerfile.build @@ -39,6 +39,7 @@ RUN QUDA_TEST_GRID_SIZE="1 1 1 2" cmake -S /quda/src \ -DQUDA_MULTIGRID_NVEC_LIST=6 \ -DQUDA_MDW_FUSED_LS_LIST=4 \ -DQUDA_MPI=ON \ + -DQUDA_DSLASH_DISTANCE=1 \ -DQUDA_DIRAC_DEFAULT_OFF=ON \ -DQUDA_DIRAC_WILSON=ON \ -DQUDA_DIRAC_CLOVER=ON \ diff --git a/include/color_spinor_field.h b/include/color_spinor_field.h index 9edc8fbdef..6ce0282049 100644 --- a/include/color_spinor_field.h +++ b/include/color_spinor_field.h @@ -150,6 +150,7 @@ namespace quda int composite_dim = 0; // e.g., number of eigenvectors in the set bool is_component = false; int component_id = 0; // eigenvector index + DDParam dd {}; /** If using CUDA native fields, this function will ensure that the @@ -367,6 +368,9 @@ namespace quda // CompositeColorSpinorField components; + /** Domain decomposition options */ + DDParam dd {}; + /** Compute the required extended ghost zone sizes and offsets @param[in] nFace The depth of the halo @@ -449,6 +453,31 @@ namespace quda */ void copy(const ColorSpinorField &src); + /** + @brief Project the field to a domain determined by DDParam + */ + void projectDD(); + + /** + @brief Returns DDParam (const version) + */ + const DDParam& DD() const { return dd; } + + /** + @brief Returns DDParam (non const version) + */ + DDParam& DD() { return dd; } + + /** + @brief Sets DDParam from a given DDParam + */ + void DD(const DDParam &in) { dd = in; } + + /** + @brief Sets DDParam from a given list of options (DD flags) + */ + template void DD(const quda::DD &flag, const Args &...args) { dd.set(flag, args...); } + /** @brief Zero all elements of this field */ @@ -993,6 +1022,8 @@ namespace quda void *Dst = nullptr, const void *Src = nullptr); void genericSource(ColorSpinorField &a, QudaSourceType sourceType, int x, int s, int c); + + void genericProjectDD(ColorSpinorField &a); int genericCompare(const ColorSpinorField &a, const ColorSpinorField &b, int tol); /** diff --git a/include/declare_enum.h b/include/declare_enum.h new file mode 100644 index 0000000000..69b721b343 --- /dev/null +++ b/include/declare_enum.h @@ -0,0 +1,60 @@ +/* + * A macro that declares an `enum class` as well as a `to_string` function for the enums. + * The enum has also a default value `size` that measures the size of the enum. + * + * Credit: https://stackoverflow.com/a/71375077/12084612 + * ------- + * License: CC BY-SA 4.0 + * -------- + * Usage: + * ------ + * + * DECLARE_ENUM(WeekEnum, Mon, Tue, Wed, Thu, Fri, Sat, Sun,); + * + * int main() + * { + * WeekEnum weekDay = WeekEnum::Wed; + * std::cout << to_string(weekDay) << std::endl; // prints Wed + * std::cout << to_string(WeekEnum::Sat) << std::endl; // prints Sat + * std::cout << to_string((int) WeekEnum::size) << std::endl; // prints 7 + * return 0; + * } + * + */ + +#pragma once +#include +#include +#include +#include +#include + +// Add the definition of this method into a cpp file. (only the declaration in the header) +static inline const std::vector get_enum_names(const std::string &en_key, const std::string &en_str) +{ + static std::unordered_map> en_names_map; + const auto it = en_names_map.find(en_key); + if (it != en_names_map.end()) return it->second; + + constexpr auto delim(','); + std::vector en_names; + std::size_t start {}; + auto end = en_str.find(delim); + while (end != std::string::npos) { + while (en_str[start] == ' ') ++start; + en_names.push_back(en_str.substr(start, end - start)); + start = end + 1; + end = en_str.find(delim, start); + } + while (en_str[start] == ' ') ++start; + en_names.push_back(en_str.substr(start)); + return en_names_map.emplace(en_key, std::move(en_names)).first->second; +} + +#define DECLARE_ENUM(ENUM_NAME, ...) \ + enum class ENUM_NAME : unsigned int { __VA_ARGS__ size }; \ + inline std::string to_string(ENUM_NAME en) \ + { \ + const auto names = get_enum_names(#ENUM_NAME, #__VA_ARGS__); \ + return names[static_cast(en)]; \ + } diff --git a/include/domain_decomposition.h b/include/domain_decomposition.h new file mode 100644 index 0000000000..24e653ac37 --- /dev/null +++ b/include/domain_decomposition.h @@ -0,0 +1,153 @@ +#pragma once + +#include "declare_enum.h" + +namespace quda +{ + + // using namespace quda; + + DECLARE_ENUM(DD, // name of the enum class + + reset, // No domain decomposition. It sets all flags to zero. + + red_black_type, // Flags used by red_black + red_active, // if red blocks are active + black_active, // if black blocks are active + no_block_hopping, // if hopping between red and black is allowed + ); + + // Params for domain decompation + struct DDParam { + + QudaDDType type = QUDA_DD_NO; + array(DD::size)> flags = {}; // the default value of all flags is 0 + array block_dim = {}; // the size of the block per direction + + // Default constructor + DDParam() = default; + + // returns false if in use + constexpr bool operator!() const { return type == QUDA_DD_NO; } + + // returns value of given flag + constexpr bool is(const DD &flag) const { return flags[(int)flag]; } + + // sets given flag to true + constexpr void set(const DD &flag) + { + flags[(int)flag] = true; + + if ((int)flag == (int)DD::reset) { +#pragma unroll + for (auto i = 0u; i < (int)DD::size; i++) flags[i] = 0; + type = QUDA_DD_NO; + } else if ((int)flag >= (int)DD::red_black_type) { + type = QUDA_DD_RED_BLACK; + } + } + + template constexpr void set(const DD &flag, const Args &...args) + { + set(flag); + set(args...); + } + + // Pretty print the args struct + void print() const + { + if (not *this) { + printfQuda("DD not in use\n"); + return; + } + printfQuda("Printing DDParam\n"); + for (int i = 0; i < (int)DD::size; i++) + printfQuda("flags[DD::%s] = %s\n", to_string((DD)i).c_str(), flags[i] ? "true" : "false"); + for (int i = 0; i < QUDA_MAX_DIM; i++) printfQuda("block_dim[%d] = %d\n", i, static_cast(block_dim[i])); + } + + // Checks if this matches to given DDParam + template inline bool check(const F &field, bool verbose = false) const + { + if (not *this) return true; + + if (type == QUDA_DD_RED_BLACK) { + for (int i = 0; i < field.Ndim(); i++) { + if (block_dim[i] < 0) { + if (verbose) printfQuda("block_dim[%d] = %d is negative\n", i, block_dim[i]); + return false; + } + if (block_dim[i] > 0) { + int globalDim = comm_dim(i) * field.full_dim(i); + if (globalDim % block_dim[i] != 0) { + if (verbose) printfQuda("block_dim[%d] = %d does not divide %d \n", i, block_dim[i], globalDim); + return false; + } + if ((globalDim / block_dim[i]) % 2 != 0) { + if (verbose) + printfQuda("block_dim[%d] = %d does not divide %d **evenly** \n", i, block_dim[i], globalDim); + return false; + } + } + } + if (block_dim[0] % 2) { + if (verbose) printfQuda("block_dim[0] = %d must be even \n", block_dim[0]); + return false; + } + } + + return true; + } + + // Checks if this matches to given DDParam + inline bool match(const DDParam &dd, bool verbose = false) const + { + // if one of the two is not in use we return true, i.e. one of the two is a full field + if (not *this or not dd) return true; + + // false if type does not match + if (type != dd.type) { + if (verbose) printfQuda("DD type do not match (%d != %d)\n", type, dd.type); + return false; + } + + if (type == QUDA_DD_RED_BLACK) { + for (int i = 0; i < QUDA_MAX_DIM; i++) + if (block_dim[i] != dd.block_dim[i]) { + if (verbose) printfQuda("block_dim[%d] = %d != %d \n", i, block_dim[i], dd.block_dim[i]); + return false; + } + if (is(DD::no_block_hopping) != dd.is(DD::no_block_hopping)) { + if (verbose) printfQuda("no_block_hopping do not match.\n"); + return false; + } + } + + return true; + } + + // Checks if this is equal to given DDParam + inline bool operator==(const DDParam &dd) const + { + // if both are not in use we return true + if (not *this and not dd) return true; + + // false if type does not match + if (type != dd.type) return false; + + // checking all flags matches (note this should be actually type-wise) + for (int i = 0; i < (int)DD::size; i++) + if (flags[i] != dd.flags[i]) return false; + + // checking block_dim matches when needed + if (type == QUDA_DD_RED_BLACK) + for (int i = 0; i < QUDA_MAX_DIM; i++) + if (block_dim[i] != dd.block_dim[i]) return false; + + return true; + } + + inline bool operator!=(const DDParam &dd) const { return !(*this == dd); } + }; + +} // namespace quda diff --git a/include/domain_decomposition_helper.cuh b/include/domain_decomposition_helper.cuh new file mode 100644 index 0000000000..2b7ab5acea --- /dev/null +++ b/include/domain_decomposition_helper.cuh @@ -0,0 +1,112 @@ +#pragma once + +#include +#include + +namespace quda +{ + + // No DD (use also as a template for required functions) + struct DDNo { + + // Initialization of input parameters from ColorSpinorField + DDNo(const DDParam &dd) + { + if (dd.type != QUDA_DD_NO) { errorQuda("Unsupported type %d\n", dd.type); } + } + + // Only DDNo returns true. All others return false + constexpr bool operator!() const { return true; } + + // Whether comms are required along given direction + template constexpr bool commDim(int, const DDArg &, const Arg &) const + { + return true; + } + + // Whether field at given coord is zero + template constexpr bool isZero(const Coord &) const { return false; } + + // Whether do hopping with field at neighboring coord + template constexpr bool doHopping(const Coord &, int, int) const { return true; } + }; + + // Red-black Block DD + struct DDRedBlack { + + const int_fastdiv block_dim[QUDA_MAX_DIM]; // the size of the block per direction + const bool red_active; // if red blocks are active + const bool black_active; // if black blocks are active + const bool block_hopping; // if hopping between red and black is allowed + + DDRedBlack(const DDParam &dd) : + block_dim {dd.block_dim[0], dd.block_dim[1], dd.block_dim[2], dd.block_dim[3]}, + red_active(dd.type == QUDA_DD_NO or dd.is(DD::red_active)), + black_active(dd.type == QUDA_DD_NO or dd.is(DD::black_active)), + block_hopping(dd.type == QUDA_DD_NO or not dd.is(DD::no_block_hopping)) + { + if (dd.type != QUDA_DD_NO and dd.type != QUDA_DD_RED_BLACK) { errorQuda("Unsupported type %d", dd.type); } + } + + constexpr bool operator!() const { return false; } + + // Whether comms are required along given direction + template constexpr bool commDim(int d, const DDArg &dd, const Arg &arg) const + { + if (not red_active and not black_active) return false; + if (not dd.red_active and not dd.black_active) return false; + if (arg.dim[d] % block_dim[d] == 0) { + if (not red_active and not dd.red_active) return false; + if (not black_active and not dd.black_active) return false; + if (not block_hopping and not dd.block_hopping) return false; + } + return true; + } + + // Computes block_parity: 0 = red, 1 = black + template constexpr bool block_parity(const Coord &x) const + { + int block_parity = 0; + for (int i = 0; i < x.size(); i++) { + if (block_dim[i] > 0) block_parity += x.gx[i] / block_dim[i]; + } + return block_parity % 2 == 1; + } + + template constexpr bool on_border(const Coord &x, int mu, int dist) const + { + if (block_dim[mu] == 0) return false; + int x_mu = x.gx[mu] + dist; + if (x_mu < 0) x_mu += x.gDim[mu]; + if (x_mu >= x.gDim[mu]) x_mu -= x.gDim[mu]; + return x.gx[mu] / block_dim[mu] != x_mu / block_dim[mu]; + } + + template constexpr bool isZero(const Coord &x) const + { + bool is_black = block_parity(x); + bool is_red = not is_black; + + if (is_red and red_active) return false; + if (is_black and black_active) return false; + return true; + } + + template constexpr bool doHopping(const Coord &x, int mu, int dist) const + { + bool is_black = block_parity(x); + bool is_red = !is_black; + bool is_border = on_border(x, mu, dist); + + if (!is_border) { // Within block + if (is_red and red_active) return true; + if (is_black and black_active) return true; + } else if (block_hopping) { // Between blocks + if (is_red and black_active) return true; + if (is_black and red_active) return true; + } + return false; + } + }; + +} // namespace quda diff --git a/include/dslash_helper.cuh b/include/dslash_helper.cuh index 6b4747f39e..4990abc333 100644 --- a/include/dslash_helper.cuh +++ b/include/dslash_helper.cuh @@ -11,6 +11,7 @@ #include #include #include +#include constexpr quda::use_kernel_arg_p use_kernel_arg = quda::use_kernel_arg_p::TRUE; @@ -99,6 +100,7 @@ namespace quda { constexpr auto nDim = Arg::nDim; Coord coord; + for (auto i = 0; i < nDim; i++) coord.gDim[i] = arg.gDim[i]; dim = kernel_type; // keep compiler happy // only for 5-d checkerboarding where we need to include the fifth dimension @@ -150,6 +152,7 @@ namespace quda coordsFromFaceIndex(coord.X, coord.x_cb, coord, idx, face_num, parity, arg); } } + for (int i = 0; i < nDim; i++) { coord.gx[i] = arg.commCoord[i] + coord.x[i]; } coord.s = s; return coord; } @@ -236,7 +239,7 @@ namespace quda return true; } - template struct DslashArg { + template struct DslashArg { using Float = Float_; using real = typename mapper::type; @@ -250,9 +253,13 @@ namespace quda const int_fastdiv X0h; const int_fastdiv dim[5]; // full lattice dimensions + const int gDim[5]; // global full lattice dimensions const int volumeCB; // checkerboarded volume int commDim[4]; // whether a given dimension is partitioned or not (potentially overridden for Schwarz) + const int commCoord[5]; + const int globalDim3; + const bool dagger; // dagger const bool xpay; // whether we are doing xpay or not @@ -283,6 +290,10 @@ namespace quda int exterior_dims; // dimension to run in the exterior Dslash int exterior_blocks; + DDArg dd_out; + DDArg dd_in; + DDArg dd_x; + // for shmem ... static constexpr bool packkernel = false; void *packBuffer[4 * QUDA_MAX_DIM]; @@ -316,7 +327,10 @@ namespace quda reconstruct(U.Reconstruct()), X0h(nParity == 2 ? in.X(0) / 2 : in.X(0)), dim {(3 - nParity) * in.X(0), in.X(1), in.X(2), in.X(3), in.Ndim() == 5 ? in.X(4) : 1}, + gDim {comm_dim(0) * dim[0], comm_dim(1) * dim[1], comm_dim(2) * dim[2], comm_dim(3) * dim[3], dim[4]}, volumeCB(in.VolumeCB()), + commCoord {comm_coord(0) * dim[0], comm_coord(1) * dim[1], comm_coord(2) * dim[2], comm_coord(3) * dim[3], dim[4]}, + globalDim3(comm_dim(3) * this->dim[3]), dagger(dagger), xpay(xpay), kernel_type(INTERIOR_KERNEL), @@ -336,6 +350,9 @@ namespace quda pack_blocks(0), exterior_dims(0), exterior_blocks(0), + dd_out(out.DD()), + dd_in(in.DD()), + dd_x(x.DD()), #ifndef NVSHMEM_COMMS counter(0) #else @@ -354,10 +371,11 @@ namespace quda if (in[i].data() == out[i].data()) errorQuda("Aliasing pointers"); checkOrder(out, in, x); // check all orders match checkLocation(out, in, x, U); // check all locations match + checkDD(out, in, x); // check all DD match checkNative(in, U); for (int d = 0; d < 4; d++) { - commDim[d] = (comm_override[d] == 0) ? 0 : comm_dim_partitioned(d); + commDim[d] = (comm_override[d] == 0) ? 0 : (comm_dim_partitioned(d) * dd_out.commDim(d, dd_in, *this)); } if (in.Location() == QUDA_CUDA_FIELD_LOCATION) { @@ -413,7 +431,8 @@ namespace quda } }; - template std::ostream &operator<<(std::ostream &out, const DslashArg &arg) + template + std::ostream &operator<<(std::ostream &out, const DslashArg &arg) { out << "parity = " << arg.parity << std::endl; out << "nParity = " << arg.nParity << std::endl; diff --git a/include/enum_quda.h b/include/enum_quda.h index 462ce98cf0..ed0cb1897b 100644 --- a/include/enum_quda.h +++ b/include/enum_quda.h @@ -372,16 +372,18 @@ typedef enum QudaFieldCreate_s { QUDA_INVALID_FIELD_CREATE = QUDA_INVALID_ENUM } QudaFieldCreate; -typedef enum QudaGammaBasis_s { // gamj=((top 2 rows)(bottom 2 rows)) s1,s2,s3 are Pauli spin matrices, 1 is 2x2 identity - QUDA_DEGRAND_ROSSI_GAMMA_BASIS, // gam1=((0,i*s1)(-i*s1,0)) gam2=((0,-i*s2)(i*s2,0)) gam3=((0,i*s3)(-i*s3,0)) gam4=((0,1)(1,0)) gam5=((-1,0)(0,1)) - QUDA_UKQCD_GAMMA_BASIS, // gam1=((0,i*s1)(-i*s1,0)) gam2=((0,i*s2)(-i*s2,0)) gam3=((0,i*s3)(-i*s3,0)) gam4=((1,0)(0,-1)) gam5=((0,-1)(-1,0)) - QUDA_CHIRAL_GAMMA_BASIS, // gam1=((0,-i*s1)(i*s1,0)) gam2=((0,-i*s2)(i*s2,0)) gam3=((0,-i*s3)(i*s3,0)) gam4=((0,-1)(-1,0))gam5=((1,0)(0,-1)) - QUDA_DIRAC_PAULI_GAMMA_BASIS, // gam1=((0,-i*s1)(i*s1,0)) gam2=((0,-i*s2)(i*s2,0)) gam3=((0,-i*s3)(i*s3,0)) gam4=((1,0)(0,-1)) gam5=((0,1)(1,0)) - QUDA_INVALID_GAMMA_BASIS = QUDA_INVALID_ENUM // gam5=gam4*gam1*gam2*gam3 +typedef enum QudaGammaBasis_s { // gamj=((top 2 rows)(bottom 2 rows)) s1,s2,s3 are Pauli spin matrices, 1 is 2x2 identity + QUDA_DEGRAND_ROSSI_GAMMA_BASIS, // gam1=((0,i*s1)(-i*s1,0)) gam2=((0,-i*s2)(i*s2,0)) gam3=((0,i*s3)(-i*s3,0)) + // gam4=((0,1)(1,0)) gam5=((-1,0)(0,1)) + QUDA_UKQCD_GAMMA_BASIS, // gam1=((0,i*s1)(-i*s1,0)) gam2=((0,i*s2)(-i*s2,0)) gam3=((0,i*s3)(-i*s3,0)) gam4=((1,0)(0,-1)) gam5=((0,-1)(-1,0)) + QUDA_CHIRAL_GAMMA_BASIS, // gam1=((0,-i*s1)(i*s1,0)) gam2=((0,-i*s2)(i*s2,0)) gam3=((0,-i*s3)(i*s3,0)) gam4=((0,-1)(-1,0))gam5=((1,0)(0,-1)) + QUDA_DIRAC_PAULI_GAMMA_BASIS, // gam1=((0,-i*s1)(i*s1,0)) gam2=((0,-i*s2)(i*s2,0)) gam3=((0,-i*s3)(i*s3,0)) + // gam4=((1,0)(0,-1)) gam5=((0,1)(1,0)) + QUDA_INVALID_GAMMA_BASIS = QUDA_INVALID_ENUM // gam5=gam4*gam1*gam2*gam3 } QudaGammaBasis; - // Dirac-Pauli -> DeGrand-Rossi T = i/sqrt(2)*((s2,-s2)(s2,s2)) field_DR = T * field_DP - // UKQCD -> DeGrand-Rossi T = i/sqrt(2)*((-s2,-s2)(-s2,s2)) field_DR = T * field_UK - // Chiral -> DeGrand-Rossi T = i*((0,-s2)(s2,0)) field_DR = T * field_chiral +// Dirac-Pauli -> DeGrand-Rossi T = i/sqrt(2)*((s2,-s2)(s2,s2)) field_DR = T * field_DP +// UKQCD -> DeGrand-Rossi T = i/sqrt(2)*((-s2,-s2)(-s2,s2)) field_DR = T * field_UK +// Chiral -> DeGrand-Rossi T = i*((0,-s2)(s2,0)) field_DR = T * field_chiral typedef enum QudaSourceType_s { QUDA_POINT_SOURCE, QUDA_RANDOM_SOURCE, @@ -636,6 +638,8 @@ typedef enum QudaExtLibType_s { QUDA_EXTLIB_INVALID = QUDA_INVALID_ENUM } QudaExtLibType; +typedef enum QudaDDType_s { QUDA_DD_NO, QUDA_DD_RED_BLACK, QUDA_DD_INVALID = QUDA_INVALID_ENUM } QudaDDType; + typedef enum QudaWFlowStepType_s { WFLOW_STEP_W1, WFLOW_STEP_W2, diff --git a/include/enum_quda_fortran.h b/include/enum_quda_fortran.h index 8874959d7b..faf68bf914 100644 --- a/include/enum_quda_fortran.h +++ b/include/enum_quda_fortran.h @@ -544,3 +544,8 @@ #define QUDA_CUSOLVE_EXTLIB 0 #define QUDA_EIGEN_EXTLIB 1 #define QUDA_EXTLIB_INVALID QUDA_INVALID_ENUM + +#define QudaDDType integer(4) +#define QUDA_DD_NO 0 +#define QUDA_DD_RED_BLACK 1 +#define QUDA_DD_INVALID QUDA_INVALID_ENUM diff --git a/include/index_helper.cuh b/include/index_helper.cuh index cf1faf72b0..2f8a4a3d91 100644 --- a/include/index_helper.cuh +++ b/include/index_helper.cuh @@ -230,12 +230,15 @@ namespace quda { template struct Coord { - int x[nDim]; // nDim lattice coordinates + array x = {}; // nDim lattice coordinates + array gx = {}; // nDim global lattice coordinates + array gDim = {}; // global lattice dimensions int x_cb; // checkerboard lattice site index int s; // fifth dimension coord int X; // full lattice site index constexpr const int& operator[](int i) const { return x[i]; } constexpr int& operator[](int i) { return x[i]; } + constexpr int size() const { return nDim; } }; /** diff --git a/include/instantiate.h b/include/instantiate.h index 8eee6ad269..9b40ede127 100644 --- a/include/instantiate.h +++ b/include/instantiate.h @@ -84,6 +84,24 @@ namespace quda } } + /** + @brief precision_type_mapper Struct used to convert QudaPrecision to data-type. + */ + template struct precision_type_mapper { + }; + template <> struct precision_type_mapper { + using type = double; + }; + template <> struct precision_type_mapper { + using type = float; + }; + template <> struct precision_type_mapper { + using type = short; + }; + template <> struct precision_type_mapper { + using type = int8_t; + }; + /** @brief Helper function for returning if a given reconstruct is enabled @tparam reconstruct The reconstruct requested @@ -97,6 +115,20 @@ namespace quda template <> constexpr bool is_enabled() { return (QUDA_RECONSTRUCT & 1) ? true : false; } template <> constexpr bool is_enabled() { return true; } + /** + @brief Helper function for returning if a given domain decomposition is enabled + @tparam DD The domain decomposition requested + @return True if enabled, false if not + */ + constexpr bool is_enabled(QudaDDType DD) + { + switch (DD) { + case QUDA_DD_NO: return true; + case QUDA_DD_RED_BLACK: return (QUDA_DOMAIN_DECOMPOSITION & 1) ? true : false; + default: return false; + } + } + struct ReconstructFull { static constexpr std::array recon = {QUDA_RECONSTRUCT_NO, QUDA_RECONSTRUCT_13, QUDA_RECONSTRUCT_12, QUDA_RECONSTRUCT_9, QUDA_RECONSTRUCT_8, QUDA_RECONSTRUCT_10}; diff --git a/include/instantiate_dslash.h b/include/instantiate_dslash.h index eab0ead243..587f36e7ee 100644 --- a/include/instantiate_dslash.h +++ b/include/instantiate_dslash.h @@ -5,6 +5,7 @@ #include #include #include +#include namespace quda { @@ -16,24 +17,24 @@ namespace quda @param[in] U Gauge field @param[in] args Additional arguments for different dslash kernels */ - template