Skip to content

Commit

Permalink
The MMA kernels will now choose the smallest instantiated nVec that is larger than the actual nVec when creating the MMA-ordered fields.
Browse files Browse the repository at this point in the history
  • Loading branch information
hummingtree committed Dec 3, 2024
1 parent abe7854 commit a3e2f4a
Show file tree
Hide file tree
Showing 5 changed files with 8 additions and 60 deletions.
2 changes: 2 additions & 0 deletions include/quda_define.h.in
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,8 @@
#define GPU_MULTIGRID
#endif

#cmakedefine QUDA_ENABLE_MMA

#ifdef QUDA_MULTIGRID

#cmakedefine QUDA_MULTIGRID_MMA_SETUP_TYPE
Expand Down
1 change: 1 addition & 0 deletions lib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,7 @@ configure_file(color_spinor_pack.in.cu color_spinor_pack.cu @ONLY)
configure_file(color_spinor_util.in.cu color_spinor_util.cu @ONLY)
configure_file(dslash_coarse_mma.in.hpp dslash_coarse_mma.hpp @ONLY)
configure_file(block_transpose.in.cu block_transpose.cu @ONLY)
configure_file(multigrid.in.hpp multigrid.hpp @ONLY)

if(QUDA_MULTIGRID)
string(REPLACE "," ";" QUDA_MULTIGRID_NVEC_LIST_SEMICOLON "${QUDA_MULTIGRID_NVEC_LIST}")
Expand Down
15 changes: 1 addition & 14 deletions lib/dslash_coarse.in.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
#include "multigrid.h"
#include <multigrid.hpp>
#include <dirac_quda.h>

namespace quda
{

/**
  @brief Empty tag type carrying a compile-time pack of integers.  It holds no
  data; the integers live only in the type and are peeled off one at a time by
  the overloads that take an IntList<first, rest...> argument.
*/
template <int... Values> struct IntList { };

#if defined(QUDA_MMA_AVAILABLE)
template <bool dagger, int Nc, int nVec, int... N>
void ApplyCoarseMma(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &inA,
Expand Down Expand Up @@ -80,17 +78,6 @@ namespace quda
return output;
}

/**
  @brief Build a new ColorSpinorField, parameterized from fs[0], that is sized
  for a padded set of vectors.  The vector count is fs.size() rounded up to the
  next multiple of 8, and the color count is fs[0].Ncolor() times that padded
  count.  The field is created with QUDA_NULL_FIELD_CREATE.
  @param[in] fs Set of fields; fs[0] seeds the parameters (fs must be non-empty)
  @param[in] order Field order assigned to the new field
  @return The newly constructed ColorSpinorField
*/
template <class F> auto create_color_spinor_copy(cvector_ref<F> &fs, QudaFieldOrder order)
{
  // round the number of fields up to a multiple of 8
  const int padded_n_vec = ((fs.size() + 7) / 8) * 8;

  ColorSpinorParam copy_param(fs[0]);
  copy_param.fieldOrder = order;
  copy_param.create = QUDA_NULL_FIELD_CREATE;
  copy_param.nVec = padded_n_vec;
  copy_param.nColor = fs[0].Ncolor() * padded_n_vec;
  return ColorSpinorField(copy_param);
}

// Apply the coarse Dirac matrix to a coarse grid vector
// out(x) = M*in = X*in - kappa*\sum_mu Y_{-\mu}(x)in(x+mu) + Y^\dagger_mu(x-mu)in(x-mu)
// or
Expand Down
25 changes: 2 additions & 23 deletions lib/prolongator.in.cpp
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
#include "multigrid.h"
#include <multigrid.hpp>
#include <blas_quda.h>

namespace quda
{

/**
  @brief Empty tag type carrying a compile-time pack of integers.  It holds no
  data; the integers live only in the type and are peeled off one at a time by
  the overloads that take an IntList<first, rest...> argument.
*/
template <int... Values> struct IntList { };

template <int fineColor, int coarseColor, int nVec, int... N>
void ProlongateMma2(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
const int *fine_to_coarse, const int *const *spin_map, int parity, IntList<nVec, N...>) {
Expand All @@ -21,25 +19,6 @@ namespace quda
}
}

/**
  @brief Build a new ColorSpinorField, parameterized from fs[0], that is sized
  for a padded set of vectors.  The vector count is fs.size() rounded up to the
  next multiple of 8, and the color count is fs[0].Ncolor() times that padded
  count.  The field is created with QUDA_NULL_FIELD_CREATE.
  @param[in] fs Set of fields; fs[0] seeds the parameters (fs must be non-empty)
  @param[in] order Field order assigned to the new field
  @return The newly constructed ColorSpinorField
*/
template <class F> auto create_color_spinor_copy(cvector_ref<F> &fs, QudaFieldOrder order)
{
  // round the number of fields up to a multiple of 8
  const int padded_n_vec = ((fs.size() + 7) / 8) * 8;

  ColorSpinorParam copy_param(fs[0]);
  copy_param.fieldOrder = order;
  copy_param.create = QUDA_NULL_FIELD_CREATE;
  copy_param.nVec = padded_n_vec;
  copy_param.nColor = fs[0].Ncolor() * padded_n_vec;
  return ColorSpinorField(copy_param);
}

static auto create_color_spinor_copy(const ColorSpinorField &f, QudaFieldOrder order)
{
ColorSpinorParam param(f);
param.create = QUDA_NULL_FIELD_CREATE;
param.fieldOrder = order;
return ColorSpinorField(param);
}

template <bool use_mma, int fineColor, int coarseColor, int... N>
void Prolongate2(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, const ColorSpinorField &v,
const int *fine_to_coarse, const int *const *spin_map, int parity, IntList<coarseColor, N...>)
Expand Down Expand Up @@ -104,7 +83,7 @@ namespace quda
// clang-format off
IntList<@QUDA_MULTIGRID_NC_NVEC_LIST@> fineColors;
// clang-format on
if (use_mma && in.size() % 8 == 0) {
if (use_mma) {
// use MMA
Prolongate<true>(out, in, v, fine_to_coarse, spin_map, parity, fineColors);
} else {
Expand Down
25 changes: 2 additions & 23 deletions lib/restrictor.in.cpp
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
#include "multigrid.h"
#include <multigrid.hpp>

namespace quda
{

/**
  @brief Empty tag type carrying a compile-time pack of integers.  It holds no
  data; the integers live only in the type and are peeled off one at a time by
  the overloads that take an IntList<first, rest...> argument.
*/
template <int... Values> struct IntList { };

template <int fineColor, int coarseColor, int nVec, int... N>
void RestrictMma2(ColorSpinorField &out, const ColorSpinorField &in, const ColorSpinorField &v,
const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity, IntList<nVec, N...>) {
Expand All @@ -20,25 +18,6 @@ namespace quda
}
}

/**
  @brief Build a new ColorSpinorField, parameterized from fs[0], that is sized
  for a padded set of vectors.  The vector count is fs.size() rounded up to the
  next multiple of 8, and the color count is fs[0].Ncolor() times that padded
  count.  The field is created with QUDA_NULL_FIELD_CREATE.
  @param[in] fs Set of fields; fs[0] seeds the parameters (fs must be non-empty)
  @param[in] order Field order assigned to the new field
  @return The newly constructed ColorSpinorField
*/
template <class F> auto create_color_spinor_copy(cvector_ref<F> &fs, QudaFieldOrder order)
{
  // round the number of fields up to a multiple of 8
  const int padded_n_vec = ((fs.size() + 7) / 8) * 8;

  ColorSpinorParam copy_param(fs[0]);
  copy_param.fieldOrder = order;
  copy_param.create = QUDA_NULL_FIELD_CREATE;
  copy_param.nVec = padded_n_vec;
  copy_param.nColor = fs[0].Ncolor() * padded_n_vec;
  return ColorSpinorField(copy_param);
}

static auto create_color_spinor_copy(const ColorSpinorField &f, QudaFieldOrder order)
{
ColorSpinorParam param(f);
param.create = QUDA_NULL_FIELD_CREATE;
param.fieldOrder = order;
return ColorSpinorField(param);
}

template <bool use_mma, int fineColor, int coarseColor, int... N>
void Restrict2(cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in, const ColorSpinorField &v,
const int *fine_to_coarse, const int *coarse_to_fine, const int *const *spin_map, int parity, IntList<coarseColor, N...>)
Expand Down Expand Up @@ -104,7 +83,7 @@ namespace quda
IntList<@QUDA_MULTIGRID_NC_NVEC_LIST@> fineColors;
// clang-format on

if (use_mma && in.size() % 8 == 0) {
if (use_mma) {
Restrict<true>(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity, fineColors);
} else {
Restrict<false>(out, in, v, fine_to_coarse, coarse_to_fine, spin_map, parity, fineColors);
Expand Down

0 comments on commit a3e2f4a

Please sign in to comment.