From bc48268160b551e95cf8b06d1b2de4b3a53f866f Mon Sep 17 00:00:00 2001 From: James Osborn Date: Sat, 4 Jan 2025 16:52:57 -0600 Subject: [PATCH] pass KernelOps into SharedMemoryCache constructor --- include/dslash_helper.cuh | 11 +- include/gauge_fix_ovr_hit_devf.cuh | 84 +++++++++------ include/kernels/block_transpose.cuh | 21 ++-- include/kernels/coarse_op_kernel.cuh | 83 ++++++++++----- include/kernels/color_spinor_pack.cuh | 49 +++++---- include/kernels/covariant_derivative.cuh | 2 +- include/kernels/dslash_clover_helper.cuh | 17 ++- include/kernels/dslash_coarse.cuh | 34 +++--- include/kernels/dslash_domain_wall_4d.cuh | 2 +- .../dslash_domain_wall_4d_fused_m5.cuh | 50 +++++---- include/kernels/dslash_domain_wall_5d.cuh | 2 +- include/kernels/dslash_domain_wall_m5.cuh | 100 +++++++++++++----- include/kernels/dslash_mdw_fused.cuh | 6 +- include/kernels/dslash_mobius_eofa.cuh | 24 +++-- .../kernels/dslash_ndeg_twisted_clover.cuh | 29 +++-- ...ash_ndeg_twisted_clover_preconditioned.cuh | 25 +++-- include/kernels/dslash_ndeg_twisted_mass.cuh | 2 +- ...slash_ndeg_twisted_mass_preconditioned.cuh | 21 ++-- include/kernels/dslash_staggered.cuh | 2 +- .../dslash_twisted_clover_preconditioned.cuh | 2 +- include/kernels/dslash_twisted_mass.cuh | 2 +- .../dslash_twisted_mass_preconditioned.cuh | 2 +- include/kernels/dslash_wilson.cuh | 4 +- include/kernels/dslash_wilson_clover.cuh | 2 +- .../dslash_wilson_clover_hasenbusch_twist.cuh | 4 +- ...clover_hasenbusch_twist_preconditioned.cuh | 2 +- .../dslash_wilson_clover_preconditioned.cuh | 2 +- include/kernels/gauge_fix_ovr.cuh | 45 +++++--- include/kernels/laplace.cuh | 4 +- include/kernels/madwf_transfer.cuh | 15 ++- include/kernels/staggered_quark_smearing.cuh | 2 +- include/targets/generic/kernel_ops.h | 12 ++- .../generic/shared_memory_cache_helper.h | 8 +- 33 files changed, 445 insertions(+), 225 deletions(-) diff --git a/include/dslash_helper.cuh b/include/dslash_helper.cuh index 6b4747f39e..eef1c9c7fc 100644 --- a/include/dslash_helper.cuh +++ b/include/dslash_helper.cuh @@ -11,6 +11,7 @@ #include #include #include +#include constexpr quda::use_kernel_arg_p use_kernel_arg = quda::use_kernel_arg_p::TRUE; @@ -660,17 +661,21 @@ namespace quda are reserved for data packing, which may include communication to neighboring processes. */ - template struct dslash_functor { + template struct dslash_functor : getKernelOps { const typename Arg::Arg &arg; static constexpr int nParity = Arg::nParity; static constexpr bool dagger = Arg::dagger; static constexpr KernelType kernel_type = Arg::kernel_type; static constexpr const char *filename() { return Arg::D::filename(); } - constexpr dslash_functor(const Arg &arg) : arg(arg.arg) { } + using typename getKernelOps::KernelOpsT; + template + constexpr dslash_functor(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg.arg) + { + } __forceinline__ __device__ void operator()(int, int s, int parity) { - typename Arg::D dslash(arg); + typename Arg::D dslash(*this); // for full fields set parity from z thread index else use arg setting if (nParity == 1) parity = arg.parity; diff --git a/include/gauge_fix_ovr_hit_devf.cuh b/include/gauge_fix_ovr_hit_devf.cuh index 6d0e5838e2..c8eefbbd00 100644 --- a/include/gauge_fix_ovr_hit_devf.cuh +++ b/include/gauge_fix_ovr_hit_devf.cuh @@ -40,20 +40,29 @@ namespace quda { } } + template struct GaugeFixHitDims { + static constexpr dim3 dims(dim3 block) + { + block.y = N; + return block; + } + }; + /** * Device function to perform gauge fixing with overrelxation. - * Uses 8 threads per lattice site, the reduction is performed by shared memory without using atomicadd. - * This implementation needs 8x more shared memory than the implementation using atomicadd + * Uses 4 threads per lattice site, the reduction is performed by shared memory using atomicadd. */ - template - inline __device__ void GaugeFixHit_AtomicAdd(Matrix,nColor> &link, const Float relax_boost, int mu) + template using GaugeFixHit_AtomicAddOps = KernelOps>>; + template + inline __device__ void GaugeFixHit_AtomicAdd(Matrix, nColor> &link, const Float relax_boost, int mu, + const Ftor &ftor) { auto blockSize = target::block_dim().x; auto tid = target::thread_idx().x; //Container for the four real parameters of SU(2) subgroup in shared memory - SharedMemoryCache cache; - auto elems = cache.data(); + SharedMemoryCache> cache(ftor); + Float *elems = cache.data(); //initialize shared memory if (mu < 4) elems[mu * blockSize + tid] = 0.0; @@ -138,17 +147,20 @@ namespace quda { /** * Device function to perform gauge fixing with overrelxation. - * Uses 4 threads per lattice site, the reduction is performed by shared memory using atomicadd. + * Uses 4*8 threads per lattice site, the reduction is performed by shared memory without using atomicadd. + * This implementation needs 8x more shared memory than the implementation using atomicadd */ - template - inline __device__ void GaugeFixHit_NoAtomicAdd(Matrix,nColor> &link, const Float relax_boost, int mu) + template using GaugeFixHit_NoAtomicAddOps = KernelOps>>; + template + inline __device__ void GaugeFixHit_NoAtomicAdd(Matrix, nColor> &link, const Float relax_boost, int mu, + const Ftor &ftor) { auto blockSize = target::block_dim().x; auto tid = target::thread_idx().x; //Container for the four real parameters of SU(2) subgroup in shared memory - SharedMemoryCache cache; - auto elems = cache.data(); + SharedMemoryCache> cache(ftor); + Float *elems = &(*cache.data())[0]; //Loop over all SU(2) subroups of SU(N) //#pragma unroll @@ -228,15 +240,18 @@ namespace quda { * Uses 8 treads per lattice site, the reduction is performed by shared memory without using atomicadd. * This implementation uses the same amount of shared memory as the atomicadd implementation with more thread block synchronization */ - template - inline __device__ void GaugeFixHit_NoAtomicAdd_LessSM(Matrix,nColor> &link, const Float relax_boost, int mu) + template + using GaugeFixHit_NoAtomicAdd_LessSMOps = KernelOps>>; + template + inline __device__ void GaugeFixHit_NoAtomicAdd_LessSM(Matrix, nColor> &link, const Float relax_boost, + int mu, const Ftor &ftor) { auto blockSize = target::block_dim().x; auto tid = target::thread_idx().x; //Container for the four real parameters of SU(2) subgroup in shared memory - SharedMemoryCache cache; - auto elems = cache.data(); + SharedMemoryCache> cache(ftor); + Float *elems = cache.data(); //Loop over all SU(2) subroups of SU(N) //#pragma unroll @@ -323,18 +338,20 @@ namespace quda { /** * Device function to perform gauge fixing with overrelxation. * Uses 8 threads per lattice site, the reduction is performed by shared memory without using atomicadd. - * This implementation needs 8x more shared memory than the implementation using atomicadd + * This implementation needs 8x more shared memory than the implementation using atomicadd */ - template - inline __device__ void GaugeFixHit_AtomicAdd(Matrix,nColor> &link, Matrix,nColor> &link1, - const Float relax_boost, int mu) + template using GaugeFixHit_AtomicAdd2Ops = KernelOps>>; + template + inline __device__ void GaugeFixHit_AtomicAdd(Matrix, nColor> &link, + Matrix, nColor> &link1, const Float relax_boost, int mu, + const Ftor &ftor) { auto blockSize = target::block_dim().x; auto tid = target::thread_idx().x; //Container for the four real parameters of SU(2) subgroup in shared memory - SharedMemoryCache cache; - auto elems = cache.data(); + SharedMemoryCache> cache(ftor); + Float *elems = cache.data(); //initialize shared memory if (mu < 4) elems[mu * blockSize + tid] = 0.0; @@ -408,16 +425,19 @@ namespace quda { * Device function to perform gauge fixing with overrelxation. * Uses 4 threads per lattice site, the reduction is performed by shared memory using atomicadd. */ - template - inline __device__ void GaugeFixHit_NoAtomicAdd(Matrix,nColor> &link, Matrix,nColor> &link1, - const Float relax_boost, int mu) + template + using GaugeFixHit_NoAtomicAdd2Ops = KernelOps>>; + template + inline __device__ void GaugeFixHit_NoAtomicAdd(Matrix, nColor> &link, + Matrix, nColor> &link1, const Float relax_boost, int mu, + const Ftor &ftor) { auto blockSize = target::block_dim().x; auto tid = target::thread_idx().x; //Container for the four real parameters of SU(2) subgroup in shared memory - SharedMemoryCache cache; - auto elems = cache.data(); + SharedMemoryCache> cache(ftor); + Float *elems = cache.data(); //Loop over all SU(2) subroups of SU(N) //#pragma unroll @@ -485,15 +505,19 @@ namespace quda { * Uses 4 threads per lattice site, the reduction is performed by shared memory without using atomicadd. * This implementation uses the same amount of shared memory as the atomicadd implementation with more thread block synchronization */ - template - inline __device__ void GaugeFixHit_NoAtomicAdd_LessSM(Matrix,nColor> &link, Matrix,nColor> &link1, const Float relax_boost, int mu) + template + using GaugeFixHit_NoAtomicAdd_LessSM2Ops = KernelOps>>; + template + inline __device__ void GaugeFixHit_NoAtomicAdd_LessSM(Matrix, nColor> &link, + Matrix, nColor> &link1, const Float relax_boost, + int mu, const Ftor &ftor) { auto blockSize = target::block_dim().x; auto tid = target::thread_idx().x; //Container for the four real parameters of SU(2) subgroup in shared memory - SharedMemoryCache cache; - auto elems = cache.data(); + SharedMemoryCache> cache(ftor); + Float *elems = cache.data(); //Loop over all SU(2) subroups of SU(N) //#pragma unroll diff --git a/include/kernels/block_transpose.cuh b/include/kernels/block_transpose.cuh index 91aaec4d63..c9b04522a0 100644 --- a/include/kernels/block_transpose.cuh +++ b/include/kernels/block_transpose.cuh @@ -42,11 +42,7 @@ namespace quda } }; - template struct BlockTransposeKernel { - const Arg &arg; - constexpr BlockTransposeKernel(const Arg &arg) : arg(arg) { } - static constexpr const char *filename() { return KERNEL_FILE; } - + template struct BlockTransposeKernelOps { struct CacheDims { static constexpr dim3 dims(dim3 block) { @@ -55,6 +51,19 @@ namespace quda return block; } }; + using color_spinor_t = ColorSpinor; + using CacheT = SharedMemoryCache; + using Ops = KernelOps; + }; + + template struct BlockTransposeKernel : BlockTransposeKernelOps::Ops { + const Arg &arg; + using typename BlockTransposeKernelOps::Ops::KernelOpsT; + template + constexpr BlockTransposeKernel(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg) + { + } + static constexpr const char *filename() { return KERNEL_FILE; } /** @brief Transpose between the two different orders of batched colorspinor fields: @@ -69,7 +78,7 @@ namespace quda int parity = parity_color / Arg::nColor; using color_spinor_t = ColorSpinor; - SharedMemoryCache cache; + typename BlockTransposeKernelOps::CacheT cache {*this}; int x_offset = target::block_dim().x * target::block_idx().x; int v_offset = target::block_dim().y * target::block_idx().y; diff --git a/include/kernels/coarse_op_kernel.cuh b/include/kernels/coarse_op_kernel.cuh index 03f9d4b75e..b1115916a8 100644 --- a/include/kernels/coarse_op_kernel.cuh +++ b/include/kernels/coarse_op_kernel.cuh @@ -11,6 +11,7 @@ #include #include #include +#include namespace quda { @@ -303,7 +304,7 @@ namespace quda { #pragma unroll for (int s = 0; s < uvSpin; s++) { if constexpr (Arg::compute_max) { - uv_max = fmax(UV[s].abs_max(), uv_max); + uv_max = max(UV[s].abs_max(), uv_max); } else { UV[s].saveCS(arg.UV, 0, 0, parity, x_cb, s, i0, j0); } @@ -375,7 +376,7 @@ namespace quda { #pragma unroll for (int s = 0; s < uvSpin; s++) { if constexpr (Arg::compute_max) { - uv_max = fmax(UV[s].abs_max(), uv_max); + uv_max = max(UV[s].abs_max(), uv_max); } else { UV[s].saveCS(arg.UV, 0, 0, parity, x_cb, s, i0, j0); } @@ -494,7 +495,7 @@ namespace quda { #pragma unroll for (int s = 0; s < uvSpin; s++) { if constexpr (Arg::compute_max) { - uv_max = fmax(UV[s].abs_max(), uv_max); + uv_max = max(UV[s].abs_max(), uv_max); } else { UV[s].saveCS(arg.UV, 0, 0, parity, x_cb, s, i0, j0); } @@ -598,7 +599,7 @@ namespace quda { #pragma unroll for (int s = 0; s < uvSpin; s++) { if constexpr (Arg::compute_max) { - uv_max = fmax(UV[s].abs_max(), uv_max); + uv_max = max(UV[s].abs_max(), uv_max); } else { UV[s].saveCS(arg.UV, 0, 0, parity, x_cb, s, i0, j0); } @@ -751,8 +752,8 @@ namespace quda { for (int s = 0; s < Arg::fineSpin / 2; s++) { #pragma unroll for (int ic = 0; ic < Arg::fineColor; ic++) { - auto abs_max = fmax(abs(AV(s, ic).real()), abs(AV(s, ic).imag())); - max = fmax(abs_max, max); + auto abs_max = quda::max(abs(AV(s, ic).real()), abs(AV(s, ic).imag())); + max = quda::max(abs_max, max); } } atomic_fetch_abs_max(arg.max, max); @@ -887,8 +888,8 @@ namespace quda { for (int s = 0; s < Arg::fineSpin / 2; s++) { #pragma unroll for (int c = 0; c < Arg::fineColor; c++) { - auto abs_max = fmax(abs(AV(s, c).real()), abs(AV(s, c).imag())); - max = fmax(abs_max, max); + auto abs_max = quda::max(abs(AV(s, c).real()), abs(AV(s, c).imag())); + max = quda::max(abs_max, max); } } atomic_fetch_abs_max(arg.max, max); @@ -913,8 +914,8 @@ namespace quda { for (int s = 0; s < Arg::fineSpin / 2; s++) { #pragma unroll for (int c = 0; c < Arg::fineColor; c++) { - auto abs_max = fmax(abs(AV(s, c).real()), abs(AV(s, c).imag())); - max = fmax(abs_max, max); + auto abs_max = quda::max(abs(AV(s, c).real()), abs(AV(s, c).imag())); + max = quda::max(abs_max, max); } } atomic_fetch_abs_max(arg.max, max); @@ -1019,8 +1020,8 @@ namespace quda { real max = static_cast(0.0); #pragma unroll for (int ic_f = 0; ic_f < Arg::fineColor; ic_f++) { - auto abs_max = fmax(abs(out(0, ic_f).real()), abs(out(0, ic_f).imag())); - max = fmax(abs_max, max); + auto abs_max = quda::max(abs(out(0, ic_f).real()), abs(out(0, ic_f).imag())); + max = quda::max(abs_max, max); } atomic_fetch_abs_max(arg.max, max); } @@ -1392,14 +1393,18 @@ namespace quda { using CacheT = complex[Arg::max_color_height_per_block][Arg::max_color_width_per_block][4] [Arg::coarseSpin][Arg::coarseSpin]; template using Cache = SharedMemoryCache, DimsStatic<2, 1, 1>>; + template using Ops = KernelOps>; - template - inline __device__ void operator()(VUV &vuv, bool isDiagonal, int coarse_x_cb, int coarse_parity, int i0, int j0, int parity, const Pack &pack, const Arg &arg) + template + inline __device__ void operator()(VUV &vuv, bool isDiagonal, int coarse_x_cb, int coarse_parity, int i0, int j0, + int parity, const Pack &pack, const Ftor &ftor) { + using Arg = typename Ftor::Arg; + const Arg &arg = ftor.arg; using real = typename Arg::Float; using TileType = typename Arg::vuvTileType; const int dim_index = arg.dim_index % arg.Y_atomic.geometry; - Cache cache; + Cache cache {ftor}; auto &X = cache.data()[0]; auto &Y = cache.data()[1]; @@ -1489,16 +1494,25 @@ namespace quda { } }; - template - __device__ __host__ void storeCoarseSharedAtomic(VUV &vuv, bool isDiagonal, int coarse_x_cb, int coarse_parity, int i0, int j0, int parity, const Arg &arg) + template + __device__ __host__ void storeCoarseSharedAtomic(VUV &vuv, bool isDiagonal, int coarse_x_cb, int coarse_parity, + int i0, int j0, int parity, const Ftor &ftor) { + using Arg = typename Ftor::Arg; + const Arg &arg = ftor.arg; switch (arg.dir) { case QUDA_BACKWARDS: - target::dispatch(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, parity, Pack(), arg); break; + target::dispatch(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, parity, + Pack(), ftor); + break; case QUDA_FORWARDS: - target::dispatch(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, parity, Pack(), arg); break; + target::dispatch(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, parity, + Pack(), ftor); + break; case QUDA_IN_PLACE: - target::dispatch(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, parity, Pack(), arg); break; + target::dispatch(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, parity, + Pack(), ftor); + break; default: break;// do nothing } @@ -1584,9 +1598,12 @@ namespace quda { } - template - __device__ __host__ void computeVUV(const Arg &arg, int parity, int x_cb, int i0, int j0, int parity_coarse_, int coarse_x_cb_) + template + __device__ __host__ void computeVUV(const Ftor &ftor, int parity, int x_cb, int i0, int j0, int parity_coarse_, + int coarse_x_cb_) { + using Arg = typename Ftor::Arg; + const Arg &arg = ftor.arg; using real = typename Arg::Float; constexpr int nDim = 4; int coord[QUDA_MAX_DIM]; @@ -1618,7 +1635,7 @@ namespace quda { } if (arg.shared_atomic) - storeCoarseSharedAtomic(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, parity, arg); + storeCoarseSharedAtomic(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, parity, ftor); else storeCoarseGlobalAtomic(vuv, isDiagonal, coarse_x_cb, coarse_parity, i0, j0, arg); } @@ -1681,11 +1698,15 @@ namespace quda { } }; - template struct compute_vuv { + template struct compute_vuv : storeCoarseSharedAtomic_impl::Ops { + using Arg = Arg_; static constexpr int nFace = 1; const Arg &arg; static constexpr const char *filename() { return KERNEL_FILE; } - constexpr compute_vuv(const Arg &arg) : arg(arg) { } + using typename storeCoarseSharedAtomic_impl::Ops::KernelOpsT; + template constexpr compute_vuv(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) + { + } /** 3-d parallelism @@ -1703,15 +1724,19 @@ namespace quda { if (c_col >= arg.vuvTile.N_tiles) return; if (!arg.shared_atomic && x_cb >= arg.fineVolumeCB) return; - computeVUV(arg, parity, x_cb, c_row * arg.vuvTile.M, c_col * arg.vuvTile.N, parity_coarse, x_coarse_cb); + computeVUV(*this, parity, x_cb, c_row * arg.vuvTile.M, c_col * arg.vuvTile.N, parity_coarse, x_coarse_cb); } }; - template struct compute_vlv { + template struct compute_vlv : storeCoarseSharedAtomic_impl::Ops { + using Arg = Arg_; static constexpr int nFace = 3; const Arg &arg; static constexpr const char *filename() { return KERNEL_FILE; } - constexpr compute_vlv(const Arg &arg) : arg(arg) { } + using typename storeCoarseSharedAtomic_impl::Ops::KernelOpsT; + template constexpr compute_vlv(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) + { + } /** 3-d parallelism @@ -1729,7 +1754,7 @@ namespace quda { if (c_col >= arg.vuvTile.N_tiles) return; if (!arg.shared_atomic && x_cb >= arg.fineVolumeCB) return; - computeVUV(arg, parity, x_cb, c_row * arg.vuvTile.M, c_col * arg.vuvTile.N, parity_coarse, x_coarse_cb); + computeVUV(*this, parity, x_cb, c_row * arg.vuvTile.M, c_col * arg.vuvTile.N, parity_coarse, x_coarse_cb); } }; diff --git a/include/kernels/color_spinor_pack.cuh b/include/kernels/color_spinor_pack.cuh index b3637e1652..9228098c30 100644 --- a/include/kernels/color_spinor_pack.cuh +++ b/include/kernels/color_spinor_pack.cuh @@ -163,8 +163,9 @@ namespace quda { }; template struct site_max { - template inline auto operator()(typename Arg::real thread_max, Arg &) + template inline auto operator()(typename Ftor::Arg::real thread_max, Ftor &) { + using Arg = typename Ftor::Arg; // on the host we require that both spin and color are fully thread local constexpr int Ms = spins_per_thread(Arg::nSpin); constexpr int Mc = colors_per_thread(Arg::nColor); @@ -187,12 +188,15 @@ namespace quda { return block; } }; + template using Cache = SharedMemoryCache>; + template using Ops = KernelOps>; - template __device__ inline auto operator()(typename Arg::real thread_max, Arg &) + template __device__ inline auto operator()(typename Ftor::Arg::real thread_max, const Ftor &ftor) { + using Arg = typename Ftor::Arg; using real = typename Arg::real; constexpr int color_spin_threads = CacheDims::color_spin_threads; - SharedMemoryCache> cache; + Cache cache {ftor}; cache.save(thread_max); cache.sync(); real this_site_max = static_cast(0); @@ -205,21 +209,23 @@ namespace quda { } }; - template __device__ __host__ inline std::enable_if_t - compute_site_max(const Arg &, int, int, int, int, int) + template + __device__ __host__ inline std::enable_if_t + compute_site_max(const Ftor &, int, int, int, int, int) { - return static_cast(1.0); // dummy return for non-block float + return static_cast(1.0); // dummy return for non-block float } /** Compute the max element over the spin-color components of a given site. */ - template __device__ __host__ inline std::enable_if_t - compute_site_max(const Arg &arg, int src_idx, int x_cb, int spinor_parity, int spin_block, int color_block) + template + __device__ __host__ inline std::enable_if_t + compute_site_max(const Ftor &ftor, int src_idx, int x_cb, int spinor_parity, int spin_block, int color_block) { - using real = typename Arg::real; - const int Ms = spins_per_thread(Arg::nSpin); - const int Mc = colors_per_thread(Arg::nColor); + using real = typename Ftor::Arg::real; + const int Ms = spins_per_thread(Ftor::Arg::nSpin); + const int Mc = colors_per_thread(Ftor::Arg::nColor); complex thread_max = {0.0, 0.0}; #pragma unroll @@ -228,13 +234,13 @@ namespace quda { #pragma unroll for (int color_local=0; color_local z = arg.in[src_idx](spinor_parity, x_cb, s, c); - thread_max.real(std::max(thread_max.real(), std::abs(z.real()))); - thread_max.imag(std::max(thread_max.imag(), std::abs(z.imag()))); + complex z = ftor.arg.in[src_idx](spinor_parity, x_cb, s, c); + thread_max.real(max(thread_max.real(), abs(z.real()))); + thread_max.imag(max(thread_max.imag(), abs(z.imag()))); } } - return target::dispatch(std::max(thread_max.real(), thread_max.imag()), arg); + return target::dispatch(max(thread_max.real(), thread_max.imag()), ftor); } /** @@ -288,9 +294,16 @@ namespace quda { } } - template struct GhostPacker { + template + using GhostPackerOps = std::conditional_t::Ops, NoKernelOps>; + + template struct GhostPacker : GhostPackerOps { + using Arg = Arg_; const Arg &arg; - constexpr GhostPacker(const Arg &arg) : arg(arg) {} + using typename GhostPackerOps::KernelOpsT; + template constexpr GhostPacker(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) + { + } static constexpr const char *filename() { return KERNEL_FILE; } __device__ __host__ void operator()(int tid, int spin_color_block, int parity) @@ -309,7 +322,7 @@ namespace quda { int src_idx; int x_cb = indexFromFaceIndex(src_idx, dim, dir, ghost_idx, parity, arg); - auto max = compute_site_max(arg, src_idx, x_cb, spinor_parity, spin_block, color_block); + auto max = compute_site_max(*this, src_idx, x_cb, spinor_parity, spin_block, color_block); #pragma unroll for (int spin_local=0; spin_local struct covDev : dslash_default { const Arg &arg; - constexpr covDev(const Arg &arg) : arg(arg) {} + template constexpr covDev(const Ftor &ftor) : arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation template diff --git a/include/kernels/dslash_clover_helper.cuh b/include/kernels/dslash_clover_helper.cuh index b141aa9234..94b6bce0d9 100644 --- a/include/kernels/dslash_clover_helper.cuh +++ b/include/kernels/dslash_clover_helper.cuh @@ -181,18 +181,26 @@ namespace quda { arg.out[src_idx](x_cb, spinor_parity) = out; } }; - + + template + using NdegTwistCloverApplyOps + = KernelOps>>; + // if (!inverse) apply (Clover + i*a*gamma_5*tau_3 + b*epsilon*tau_1) to the input spinor // else apply (Clover + i*a*gamma_5*tau_3 + b*epsilon*tau_1)/(Clover^2 + a^2 - b^2) to the input spinor // noting that appropriate signs are carried by a and b depending on inverse - template struct NdegTwistCloverApply { + template struct NdegTwistCloverApply : NdegTwistCloverApplyOps { static constexpr int N = Arg::nColor * Arg::nSpin / 2; using real = typename Arg::real; using fermion = ColorSpinor; using half_fermion = ColorSpinor; using Mat = HMatrix; const Arg &arg; - constexpr NdegTwistCloverApply(const Arg &arg) : arg(arg) {} + using typename NdegTwistCloverApplyOps::KernelOpsT; + template + constexpr NdegTwistCloverApply(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) + { + } static constexpr const char* filename() { return KERNEL_FILE; } __device__ __host__ inline void operator()(int x_cb, int src_flavor, int parity) @@ -215,7 +223,7 @@ namespace quda { Mat A = arg.clover(x_cb, clover_parity, chirality); - SharedMemoryCache cache; + SharedMemoryCache cache {*this}; half_fermion in_chi[n_flavor]; // flavor array of chirally projected fermion #pragma unroll @@ -266,5 +274,4 @@ namespace quda { arg.out[src_idx](my_flavor_idx, spinor_parity) = out; } }; - } diff --git a/include/kernels/dslash_coarse.cuh b/include/kernels/dslash_coarse.cuh index 18dd1da1a5..798e29fb8e 100644 --- a/include/kernels/dslash_coarse.cuh +++ b/include/kernels/dslash_coarse.cuh @@ -290,16 +290,15 @@ namespace quda { } template struct dim_collapse { - template void operator()(T &out, int, int, const Arg &arg) - { - out *= -arg.kappa; - } + template void operator()(T &out, int, int, const Ftor &ftor) { out *= -ftor.arg.kappa; } }; template <> struct dim_collapse { - template __device__ __host__ inline void operator()(T &out, int dir, int dim, const Arg &arg) + template + __device__ __host__ inline void operator()(T &out, int dir, int dim, const Ftor &ftor) { - SharedMemoryCache cache; + using Arg = typename Ftor::Arg; + SharedMemoryCache cache {ftor}; // only need to write to shared memory if not master thread if (dim > 0 || dir) cache.save(out); @@ -319,14 +318,25 @@ namespace quda { out += cache.load_z(target::thread_idx().z + d * 2 + 1); } - out *= -arg.kappa; + out *= -ftor.arg.kappa; } } }; - template struct CoarseDslash { + template struct CoarseDslashParams { + static constexpr int Mc = colors_per_thread(Arg::nColor, Arg::dim_stride); + using array_t = array, Mc>; + using Ops = KernelOps, op_warp_combine>; + }; + + template struct CoarseDslash : CoarseDslashParams::Ops { + using Arg = Arg_; const Arg &arg; - constexpr CoarseDslash(const Arg &arg) : arg(arg) {} + using typename CoarseDslashParams::Ops::KernelOpsT; + template + constexpr CoarseDslash(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg) + { + } static constexpr const char *filename() { return KERNEL_FILE; } __device__ __host__ inline void operator()(int x_cb_color_offset, int src_parity, int sMd) @@ -347,7 +357,7 @@ namespace quda { int parity = (arg.nParity == 2) ? (src_parity / arg.n_src) : arg.parity; // z thread dimension is (( s*(Nc/Mc) + color_block )*dim_thread_split + dim)*2 + dir - constexpr int Mc = colors_per_thread(Arg::nColor, Arg::dim_stride); + constexpr int Mc = CoarseDslashParams::Mc; int dir = sMd & 1; int sMdim = sMd >> 1; int dim = sMdim % Arg::dim_stride; @@ -355,11 +365,11 @@ namespace quda { int s = sM / (Arg::nColor/Mc); int color_block = (sM % (Arg::nColor/Mc)) * Mc; - array, Mc> out{ }; + typename CoarseDslashParams::array_t out {}; if (Arg::dslash) { applyDslash(out, dim, dir, x_cb, src_idx, parity, s, color_block, color_offset, arg); - target::dispatch(out, dir, dim, arg); + target::dispatch(out, dir, dim, *this); } if (doBulk() && Arg::clover && dir==0 && dim==0) applyClover(out, arg, x_cb, src_idx, parity, s, color_block, color_offset); diff --git a/include/kernels/dslash_domain_wall_4d.cuh b/include/kernels/dslash_domain_wall_4d.cuh index 4fbb511230..bbba5f0798 100644 --- a/include/kernels/dslash_domain_wall_4d.cuh +++ b/include/kernels/dslash_domain_wall_4d.cuh @@ -29,7 +29,7 @@ namespace quda struct domainWall4D : dslash_default { const Arg &arg; - constexpr domainWall4D(const Arg &arg) : arg(arg) {} + template constexpr domainWall4D(const Ftor &ftor) : arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation template diff --git a/include/kernels/dslash_domain_wall_4d_fused_m5.cuh b/include/kernels/dslash_domain_wall_4d_fused_m5.cuh index 06ef52324b..cd9eefe4af 100644 --- a/include/kernels/dslash_domain_wall_4d_fused_m5.cuh +++ b/include/kernels/dslash_domain_wall_4d_fused_m5.cuh @@ -56,13 +56,17 @@ namespace quda } }; - template - struct domainWall4DFusedM5 : dslash_default { + constexpr bool domainWall4DFusedM5shared = true; // Use shared memory + template + struct domainWall4DFusedM5 : dslash_default, d5Params::Ops { + using Arg = Arg_; static constexpr Dslash5Type dslash5_type = Arg::type; + static constexpr bool shared = domainWall4DFusedM5shared; const Arg &arg; - constexpr domainWall4DFusedM5(const Arg &arg) : arg(arg) { } + using typename d5Params::Ops::KernelOpsT; + template constexpr domainWall4DFusedM5(const Ftor &ftor) : KernelOpsT(ftor), arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation template @@ -96,8 +100,8 @@ namespace quda */ if (Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS_PRE) { constexpr bool sync = false; - out = d5(arg, stencil_out, my_spinor_parity, 0, s, - src_idx); + out = d5(*this, stencil_out, + my_spinor_parity, 0, s, src_idx); } } @@ -111,8 +115,8 @@ namespace quda */ if (active) { constexpr bool sync = false; - out = variableInv(arg, stencil_out, my_spinor_parity, - 0, s, src_idx); + out = variableInv( + *this, stencil_out, my_spinor_parity, 0, s, src_idx); } Vector aggregate_external; @@ -138,8 +142,8 @@ namespace quda constexpr bool sync = true; constexpr bool this_dagger = true; // Then we apply the second m5inv-dag - out = variableInv(arg, out, my_spinor_parity, 0, - s, src_idx); + out = variableInv( + *this, out, my_spinor_parity, 0, s, src_idx); } } else if (Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS @@ -156,16 +160,16 @@ namespace quda if (Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS_PRE_M5_MOB) { constexpr bool sync = false; - out = d5( - arg, stencil_out, my_spinor_parity, 0, s, src_idx); + out = d5( + *this, stencil_out, my_spinor_parity, 0, s, src_idx); } } if (xpay && mykernel_type == INTERIOR_KERNEL) { Vector x = arg.x[src_idx](xs, my_spinor_parity); constexpr bool sync_m5mob = Arg::dslash5_type == Dslash5Type::DSLASH5_MOBIUS ? false : true; - x = d5( - arg, x, my_spinor_parity, 0, s, src_idx); + x = d5( + *this, x, my_spinor_parity, 0, s, src_idx); out = x + arg.a_5[s] * out; } else if (mykernel_type != INTERIOR_KERNEL && active) { Vector x = arg.out[src_idx](xs, my_spinor_parity); @@ -183,8 +187,8 @@ namespace quda if (Arg::dslash5_type == Dslash5Type::M5_INV_MOBIUS) { // Apply the m5inv. constexpr bool sync = false; - out = variableInv(arg, stencil_out, my_spinor_parity, - 0, s, src_idx); + out = variableInv( + *this, stencil_out, my_spinor_parity, 0, s, src_idx); } if (xpay && mykernel_type == INTERIOR_KERNEL) { @@ -204,12 +208,12 @@ namespace quda if (Arg::dslash5_type == Dslash5Type::M5_INV_MOBIUS_M5_PRE) { // Apply the m5inv. constexpr bool sync_m5inv = false; - out = variableInv(arg, out, my_spinor_parity, - 0, s, src_idx); + out = variableInv( + *this, out, my_spinor_parity, 0, s, src_idx); // Apply the m5pre. constexpr bool sync_m5pre = true; - out = d5(arg, out, my_spinor_parity, 0, s, - src_idx); + out = d5(*this, out, my_spinor_parity, + 0, s, src_idx); } /****** @@ -218,12 +222,12 @@ namespace quda if (Arg::dslash5_type == Dslash5Type::M5_PRE_MOBIUS_M5_INV) { // Apply the m5pre. constexpr bool sync_m5pre = false; - out = d5(arg, out, my_spinor_parity, 0, s, - src_idx); + out = d5(*this, out, my_spinor_parity, + 0, s, src_idx); // Apply the m5inv. constexpr bool sync_m5inv = true; - out = variableInv(arg, out, my_spinor_parity, - 0, s, src_idx); + out = variableInv( + *this, out, my_spinor_parity, 0, s, src_idx); } } } diff --git a/include/kernels/dslash_domain_wall_5d.cuh b/include/kernels/dslash_domain_wall_5d.cuh index 3737bacb5c..97c8040af8 100644 --- a/include/kernels/dslash_domain_wall_5d.cuh +++ b/include/kernels/dslash_domain_wall_5d.cuh @@ -32,7 +32,7 @@ namespace quda struct domainWall5D : dslash_default { const Arg &arg; - constexpr domainWall5D(const Arg &arg) : arg(arg) {} + template constexpr domainWall5D(const Ftor &ftor) : arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation constexpr QudaPCType pc_type() const { return QUDA_5D_PC; } diff --git a/include/kernels/dslash_domain_wall_m5.cuh b/include/kernels/dslash_domain_wall_m5.cuh index 62bf0d27d6..9ea0419b8a 100644 --- a/include/kernels/dslash_domain_wall_m5.cuh +++ b/include/kernels/dslash_domain_wall_m5.cuh @@ -209,9 +209,17 @@ namespace quda } }; - template - __device__ __host__ inline Vector d5(const Arg &arg, const Vector &in, int parity, int x_cb, int s, int src_idx) + template struct d5Params { + using Vec = ColorSpinor; + using Cache = SharedMemoryCache; + using Ops = std::conditional_t, NoKernelOps>; + }; + + template + __device__ __host__ inline Vector d5(const Ftor &ftor, const Vector &in, int parity, int x_cb, int s, int src_idx) { + const Arg &arg = ftor.arg; int local_src_idx = target::thread_idx().y / arg.Ls; using real = typename Arg::real; constexpr bool is_variable = true; @@ -219,22 +227,23 @@ namespace quda Vector out; - if (mobius_m5::use_half_vector()) { + if constexpr (mobius_m5::use_half_vector()) { // if using shared-memory caching then load spinor field for my site into cache typedef ColorSpinor HalfVector; - SharedMemoryCache cache; + using Cache = std::conditional_t, const Ftor &>; + Cache cache {ftor}; { // forwards direction constexpr int proj_dir = dagger ? +1 : -1; - if (shared) { - if (sync) { cache.sync(); } + if constexpr (shared) { + if constexpr (sync) { cache.sync(); } cache.save(in.project(4, proj_dir)); cache.sync(); } const int fwd_s = (s + 1) % arg.Ls; const int fwd_idx = fwd_s * arg.volume_4d_cb + x_cb; HalfVector half_in; - if (shared) { + if constexpr (shared) { half_in = cache.load(threadIdx.x, local_src_idx * arg.Ls + fwd_s, parity); } else { Vector full_in = arg.in[src_idx](fwd_idx, parity); @@ -249,7 +258,7 @@ namespace quda { // backwards direction constexpr int proj_dir = dagger ? -1 : +1; - if (shared) { + if constexpr (shared) { cache.sync(); cache.save(in.project(4, proj_dir)); cache.sync(); @@ -257,7 +266,7 @@ namespace quda const int back_s = (s + arg.Ls - 1) % arg.Ls; const int back_idx = back_s * arg.volume_4d_cb + x_cb; HalfVector half_in; - if (shared) { + if constexpr (shared) { half_in = cache.load(threadIdx.x, local_src_idx * arg.Ls + back_s, parity); } else { Vector full_in = arg.in[src_idx](back_idx, parity); @@ -273,8 +282,10 @@ namespace quda } else { // use_half_vector // if using shared-memory caching then load spinor field for my site into cache - SharedMemoryCache cache; - if (shared) { + using Cache = std::conditional_t, const Ftor &>; + Cache cache {ftor}; + + if constexpr (shared) { if (sync) { cache.sync(); } cache.save(in); cache.sync(); @@ -319,9 +330,14 @@ namespace quda return out; } - template struct dslash5 { + template struct dslash5 : d5Params::Ops { + using Arg = Arg_; const Arg &arg; - constexpr dslash5(const Arg &arg) : arg(arg) { } + using typename d5Params::Ops::KernelOpsT; + template + constexpr dslash5(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg) + { + } static constexpr const char *filename() { return KERNEL_FILE; } /** @@ -342,7 +358,7 @@ namespace quda constexpr bool sync = false; constexpr bool shared = false; - Vector out = d5(arg, Vector(), parity, x_cb, s, src_idx); + Vector out = d5(*this, Vector(), parity, x_cb, s, src_idx); if (Arg::xpay) { if (Arg::type == Dslash5Type::DSLASH5_DWF) { @@ -361,6 +377,12 @@ namespace quda } }; + template struct constantInvParams { + using Vec = ColorSpinor; + using Cache = SharedMemoryCache; + using Ops = std::conditional_t, NoKernelOps>; + }; + /** @brief Apply the M5 inverse operator at a given site on the lattice. This is the original algorithm as described in Kim and @@ -376,18 +398,21 @@ namespace quda @param[in] x_b Checkerboarded 4-d space-time index @param[in] s_ Ls dimension coordinate */ - template - __device__ __host__ inline Vector constantInv(const Arg &arg, const Vector &in, int parity, int x_cb, int s_, + template + __device__ __host__ inline Vector constantInv(const Ftor &ftor, const Vector &in, int parity, int x_cb, int s_, int src_idx) { + using Arg = typename Ftor::Arg; + const Arg &arg = ftor.arg; int local_src_idx = target::thread_idx().y / arg.Ls; using real = typename Arg::real; const auto k = arg.kappa; const auto inv = arg.inv; // if using shared-memory caching then load spinor field for my site into cache - SharedMemoryCache cache; - if (shared) { + using Cache = std::conditional_t, const Ftor &>; + Cache cache {ftor}; + if constexpr (shared) { // cache.save(arg.in(s_ * arg.volume_4d_cb + x_cb, parity)); if (sync) { cache.sync(); } cache.save(in); @@ -419,6 +444,12 @@ namespace quda return out; } + template struct variableInvParams { + using Vec = ColorSpinor; + using Cache = SharedMemoryCache; + using Ops = std::conditional_t, NoKernelOps>; + }; + /** @brief Apply the M5 inverse operator at a given site on the lattice. This is an alternative algorithm that is applicable to @@ -436,10 +467,11 @@ namespace quda @param[in] x_b Checkerboarded 4-d space-time index @param[in] s_ Ls dimension coordinate */ - template - __device__ __host__ inline Vector variableInv(const Arg &arg, const Vector &in, int parity, int x_cb, int s_, + template + __device__ __host__ inline Vector variableInv(const Ftor &ftor, const Vector &in, int parity, int x_cb, int s_, int src_idx) { + const Arg &arg = ftor.arg; int local_src_idx = target::thread_idx().y / arg.Ls; constexpr int nSpin = 4; using real = typename Arg::real; @@ -447,8 +479,9 @@ namespace quda coeff_type::value, Arg> coeff(arg); Vector out; - if (mobius_m5::use_half_vector()) { - SharedMemoryCache cache; + if constexpr (mobius_m5::use_half_vector()) { + using Cache = std::conditional_t, const Ftor &>; + Cache cache {ftor}; { // first do R constexpr int proj_dir = dagger ? -1 : +1; @@ -507,7 +540,8 @@ namespace quda out += l.reconstruct(4, proj_dir); } } else { // use_half_vector - SharedMemoryCache cache; + using Cache = std::conditional_t, const Ftor &>; + Cache cache {ftor}; if (shared) { if (sync) { cache.sync(); } cache.save(in); @@ -562,9 +596,19 @@ namespace quda @brief Functor for applying the M5 inverse operator @param[in] arg Argument struct containing any meta data and accessors */ - template struct dslash5inv { + template struct dslash5invParams { + using Ops = std::conditional_t::Ops, + typename constantInvParams::Ops>; + }; + + template struct dslash5inv : dslash5invParams::Ops { + using Arg = Arg_; const Arg &arg; - constexpr dslash5inv(const Arg &arg) : arg(arg) {} + using typename dslash5invParams::Ops::KernelOpsT; + template + constexpr dslash5inv(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg) + { + } static constexpr const char *filename() { return KERNEL_FILE; } /** @@ -587,10 +631,10 @@ namespace quda Vector in = arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity); Vector out; constexpr bool sync = false; - if (mobius_m5::var_inverse()) { // zMobius, must call variableInv - out = variableInv(arg, in, parity, x_cb, s, src_idx); + if constexpr (mobius_m5::var_inverse()) { // zMobius, must call variableInv + out = variableInv(*this, in, parity, x_cb, s, src_idx); } else { - out = constantInv(arg, in, parity, x_cb, s, src_idx); + out = constantInv(*this, in, parity, x_cb, s, src_idx); } if (Arg::xpay) { diff --git a/include/kernels/dslash_mdw_fused.cuh b/include/kernels/dslash_mdw_fused.cuh index 67f98b30cc..35b3fe905a 100644 --- a/include/kernels/dslash_mdw_fused.cuh +++ b/include/kernels/dslash_mdw_fused.cuh @@ -279,7 +279,7 @@ namespace quda { @brief Tensor core kernel for applying Wilson hopping term and then the beta + alpha * M5inv operator The integer kernel types corresponds to the enum MdwfFusedDslashType. */ - template struct FusedMobiusDslash { + template struct FusedMobiusDslash : KernelOps> { Arg &arg; constexpr FusedMobiusDslash(Arg &arg) : arg(arg) {} static constexpr const char *filename() { return KERNEL_FILE; } @@ -292,7 +292,7 @@ namespace quda { constexpr int Ls = Arg::Ls; const int explicit_parity = arg.nParity == 2 ? arg.parity : 0; - SharedMemoryCache cache; + SharedMemoryCache cache {*this}; static_assert(Arg::block_dim_x * Ls / 32 < 32, "Number of threads in a threadblock should be less than 1024."); @@ -441,7 +441,7 @@ namespace quda { } // while } }; - + #endif // QUDA_MMA_AVAILABLE } diff --git a/include/kernels/dslash_mobius_eofa.cuh b/include/kernels/dslash_mobius_eofa.cuh index c46ffa1d62..49e65da6d7 100644 --- a/include/kernels/dslash_mobius_eofa.cuh +++ b/include/kernels/dslash_mobius_eofa.cuh @@ -92,6 +92,8 @@ namespace quda } }; + template + using eofa_dslash5Ops = KernelOps>>; /** @brief Apply the D5 operator at given site @param[in] arg Argument struct containing any meta data and accessors @@ -99,9 +101,13 @@ namespace quda @param[in] x_cb Checkerboarded 4-d space-time index @param[in] s Ls dimension coordinate */ - template struct eofa_dslash5 { + template struct eofa_dslash5 : eofa_dslash5Ops { const Arg &arg; - constexpr eofa_dslash5(const Arg &arg) : arg(arg) {} + using typename eofa_dslash5Ops::KernelOpsT; + template + constexpr eofa_dslash5(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) + { + } static constexpr const char *filename() { return KERNEL_FILE; } __device__ __host__ inline void operator()(int x_cb, int src_s, int parity) @@ -112,7 +118,7 @@ namespace quda int src_idx = src_s / arg.Ls; int s = src_s % arg.Ls; - SharedMemoryCache cache; + SharedMemoryCache cache {*this}; Vector out; cache.save(arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity)); @@ -167,6 +173,8 @@ namespace quda } }; + template + using eofa_dslash5invOps = KernelOps>>; /** @brief Apply the M5 inverse operator at a given site on the lattice. This is the original algorithm as described in Kim and @@ -179,9 +187,13 @@ namespace quda @param[in] x_cb Checkerboarded 4-d space-time index @param[in] s Ls dimension coordinate */ - template struct eofa_dslash5inv { + template struct eofa_dslash5inv : eofa_dslash5invOps { const Arg &arg; - constexpr eofa_dslash5inv(const Arg &arg) : arg(arg) {} + using typename eofa_dslash5invOps::KernelOpsT; + template + constexpr eofa_dslash5inv(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) + { + } static constexpr const char *filename() { return KERNEL_FILE; } __device__ __host__ inline void operator()(int x_cb, int src_s, int parity) @@ -193,7 +205,7 @@ namespace quda int s = src_s % arg.Ls; const auto sherman_morrison = arg.sherman_morrison; - SharedMemoryCache cache; + SharedMemoryCache cache {*this}; cache.save(arg.in[src_idx](s * arg.volume_4d_cb + x_cb, parity)); cache.sync(); diff --git a/include/kernels/dslash_ndeg_twisted_clover.cuh b/include/kernels/dslash_ndeg_twisted_clover.cuh index fdc4c5cb99..0036e9d71c 100644 --- a/include/kernels/dslash_ndeg_twisted_clover.cuh +++ b/include/kernels/dslash_ndeg_twisted_clover.cuh @@ -36,14 +36,22 @@ namespace quda checkLocation(U, A); } }; - + + template struct nDegTwistedCloverParams { + using real = typename mapper::type; + using Vec = ColorSpinor; + using Cache = SharedMemoryCache; + using Ops = std::conditional_t, NoKernelOps>; + }; + template - struct nDegTwistedClover : dslash_default { - + struct nDegTwistedClover : dslash_default, nDegTwistedCloverParams::Ops { + const Arg &arg; - constexpr nDegTwistedClover(const Arg &arg) : arg(arg) {} + using typename nDegTwistedCloverParams::Ops::KernelOpsT; + template constexpr nDegTwistedClover(const Ftor &ftor) : KernelOpsT(ftor), arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation - + /** @brief Apply the non-degenerate twisted-clover dslash out(x) = M*in = a * D * in + (A(x) + i*b*gamma_5*tau_3 + c*tau_1)*x @@ -58,29 +66,28 @@ namespace quda int src_idx = src_flavor / 2; int flavor = src_flavor % 2; - bool active = mykernel_type == EXTERIOR_KERNEL_ALL ? false : true; // is thread active (non-trival for fused kernel only) int thread_dim; // which dimension is thread working on (fused kernel only) auto coord = getCoords(arg, idx, flavor, parity, thread_dim); - + const int my_spinor_parity = nParity == 2 ? parity : 0; const int my_flavor_idx = coord.x_cb + flavor * arg.dc.volume_4d_cb; Vector out; - + // defined in dslash_wilson.cuh applyWilson(out, arg, coord, parity, idx, thread_dim, active, src_idx); - if (mykernel_type == INTERIOR_KERNEL) { + if constexpr (mykernel_type == INTERIOR_KERNEL) { // apply the chiral and flavor twists // use consistent load order across s to ensure better cache locality Vector x = arg.x[src_idx](my_flavor_idx, my_spinor_parity); - SharedMemoryCache cache; + SharedMemoryCache cache {*this}; cache.save(x); x.toRel(); // switch to chiral basis - + Vector tmp; #pragma unroll for (int chirality = 0; chirality < 2; chirality++) { diff --git a/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh b/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh index 01bcfd7088..fc8b04ab67 100644 --- a/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh +++ b/include/kernels/dslash_ndeg_twisted_clover_preconditioned.cuh @@ -40,13 +40,24 @@ namespace quda } }; + template struct nDegTwistedCloverPreconditionedParams { + using real = typename mapper::type; + using Vec = ColorSpinor; + using Cache = SharedMemoryCache; + using Ops = KernelOps; + }; + template - struct nDegTwistedCloverPreconditioned : dslash_default { - + struct nDegTwistedCloverPreconditioned : dslash_default, nDegTwistedCloverPreconditionedParams::Ops { + const Arg &arg; - constexpr nDegTwistedCloverPreconditioned(const Arg &arg) : arg(arg) {} + using typename nDegTwistedCloverPreconditionedParams::Ops::KernelOpsT; + template + constexpr nDegTwistedCloverPreconditioned(const Ftor &ftor) : KernelOpsT(ftor), arg(ftor.arg) + { + } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation - + /** @brief Apply the preconditioned twisted-clover dslash out(x) = M*in = a*(C + i*b*gamma_5*tau_3 + c*tau_1)/(C^2 + b^2 - c^2)*D*x ( xpay == false ) @@ -93,7 +104,7 @@ namespace quda int chirality = flavor; // relabel flavor as chirality - SharedMemoryCache cache; + SharedMemoryCache cache {*this}; auto swizzle = [&](HalfVector x[2], int chirality) { if (chirality == 0) @@ -121,7 +132,7 @@ namespace quda A_chi[flavor_] += arg.c * out_chi[1 - flavor_]; } - if (arg.dynamic_clover) { + if constexpr (Arg::dynamic_clover) { HMat A2 = A.square(); A2 += arg.b2_minus_c2; Cholesky, Arg::nColor * Arg::nSpin / 2> cholesky(A2); @@ -142,7 +153,7 @@ namespace quda Vector tmp = out_chi[0].chiral_reconstruct(0) + out_chi[1].chiral_reconstruct(1); tmp.toNonRel(); // switch back to non-chiral basis - if (xpay) { + if constexpr (xpay) { Vector x = arg.x[src_idx](my_flavor_idx, my_spinor_parity); out = x + arg.a * tmp; } else { diff --git a/include/kernels/dslash_ndeg_twisted_mass.cuh b/include/kernels/dslash_ndeg_twisted_mass.cuh index e4df183846..3e1167c602 100644 --- a/include/kernels/dslash_ndeg_twisted_mass.cuh +++ b/include/kernels/dslash_ndeg_twisted_mass.cuh @@ -27,7 +27,7 @@ namespace quda struct nDegTwistedMass : dslash_default { const Arg &arg; - constexpr nDegTwistedMass(const Arg &arg) : arg(arg) {} + template constexpr nDegTwistedMass(const Ftor &ftor) : arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation /** diff --git a/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh b/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh index 4bdb432039..568dd7e6ed 100644 --- a/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh +++ b/include/kernels/dslash_ndeg_twisted_mass_preconditioned.cuh @@ -38,11 +38,21 @@ namespace quda } }; + template struct nDegTwistedMassPreconditionedParams { + using real = typename mapper::type; + using Vec = ColorSpinor; + using Cache = SharedMemoryCache; + using Ops = std::conditional_t, NoKernelOps>; + }; + template - struct nDegTwistedMassPreconditioned : dslash_default { + struct nDegTwistedMassPreconditioned : dslash_default, nDegTwistedMassPreconditionedParams::Ops { const Arg &arg; - constexpr nDegTwistedMassPreconditioned(const Arg &arg) : arg(arg) {} + using typename nDegTwistedMassPreconditionedParams::Ops::KernelOpsT; + template constexpr nDegTwistedMassPreconditioned(const Ftor &ftor) : KernelOpsT(ftor), arg(ftor.arg) + { + } constexpr int twist_pack() const { return (!Arg::asymmetric && dagger) ? 2 : 0; } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation @@ -97,9 +107,9 @@ namespace quda Vector x = arg.out[src_idx](my_flavor_idx, my_spinor_parity); out += x; } - - if (!dagger || Arg::asymmetric) { // apply A^{-1} to D*in - SharedMemoryCache cache; + + if constexpr (!dagger || Arg::asymmetric) { // apply A^{-1} to D*in + SharedMemoryCache cache {*this}; if (isComplete(arg, coord) && active) { // to apply the preconditioner we need to put "out" in shared memory so the other flavor can access it cache.save(out); @@ -116,7 +126,6 @@ namespace quda if (mykernel_type != EXTERIOR_KERNEL_ALL || active) arg.out[src_idx](my_flavor_idx, my_spinor_parity) = out; } - }; } // namespace quda diff --git a/include/kernels/dslash_staggered.cuh b/include/kernels/dslash_staggered.cuh index ec09ab1622..4b903ba55d 100644 --- a/include/kernels/dslash_staggered.cuh +++ b/include/kernels/dslash_staggered.cuh @@ -203,7 +203,7 @@ namespace quda struct staggered : dslash_default { const Arg &arg; - constexpr staggered(const Arg &arg) : arg(arg) {} + template constexpr staggered(const Ftor &ftor) : arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation template diff --git a/include/kernels/dslash_twisted_clover_preconditioned.cuh b/include/kernels/dslash_twisted_clover_preconditioned.cuh index 99db81ee60..ded25c9f82 100644 --- a/include/kernels/dslash_twisted_clover_preconditioned.cuh +++ b/include/kernels/dslash_twisted_clover_preconditioned.cuh @@ -41,7 +41,7 @@ namespace quda struct twistedCloverPreconditioned : dslash_default { const Arg &arg; - constexpr twistedCloverPreconditioned(const Arg &arg) : arg(arg) {} + template constexpr twistedCloverPreconditioned(const Ftor &ftor) : arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation /** diff --git a/include/kernels/dslash_twisted_mass.cuh b/include/kernels/dslash_twisted_mass.cuh index bed5cf9369..e8c9a77d35 100644 --- a/include/kernels/dslash_twisted_mass.cuh +++ b/include/kernels/dslash_twisted_mass.cuh @@ -25,7 +25,7 @@ namespace quda struct twistedMass : dslash_default { const Arg &arg; - constexpr twistedMass(const Arg &arg) : arg(arg) {} + template constexpr twistedMass(const Ftor &ftor) : arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation /** diff --git a/include/kernels/dslash_twisted_mass_preconditioned.cuh b/include/kernels/dslash_twisted_mass_preconditioned.cuh index 3fc7e9d42d..00aed4345a 100644 --- a/include/kernels/dslash_twisted_mass_preconditioned.cuh +++ b/include/kernels/dslash_twisted_mass_preconditioned.cuh @@ -138,7 +138,7 @@ namespace quda struct twistedMassPreconditioned : dslash_default { const Arg &arg; - constexpr twistedMassPreconditioned(const Arg &arg) : arg(arg) {} + template constexpr twistedMassPreconditioned(const Ftor &ftor) : arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation constexpr int twist_pack() const { return (!Arg::asymmetric && dagger) ? 1 : 0; } diff --git a/include/kernels/dslash_wilson.cuh b/include/kernels/dslash_wilson.cuh index 584999d52f..163e5d36cc 100644 --- a/include/kernels/dslash_wilson.cuh +++ b/include/kernels/dslash_wilson.cuh @@ -162,7 +162,7 @@ namespace quda template struct wilson : dslash_default { const Arg &arg; - constexpr wilson(const Arg &arg) : arg(arg) {} + template constexpr wilson(const Ftor &ftor) : arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation // out(x) = M*in = (-D + m) * in(x-mu) @@ -175,7 +175,7 @@ namespace quda bool active = mykernel_type == EXTERIOR_KERNEL_ALL ? false : true; // is thread active (non-trival for fused kernel only) int thread_dim; // which dimension is thread working on (fused kernel only) - + auto coord = getCoords(arg, idx, 0, parity, thread_dim); const int my_spinor_parity = nParity == 2 ? parity : 0; diff --git a/include/kernels/dslash_wilson_clover.cuh b/include/kernels/dslash_wilson_clover.cuh index cb4c75a86b..f774938ade 100644 --- a/include/kernels/dslash_wilson_clover.cuh +++ b/include/kernels/dslash_wilson_clover.cuh @@ -39,7 +39,7 @@ namespace quda struct wilsonClover : dslash_default { const Arg &arg; - constexpr wilsonClover(const Arg &arg) : arg(arg) {} + template constexpr wilsonClover(const Ftor &ftor) : arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation /** diff --git a/include/kernels/dslash_wilson_clover_hasenbusch_twist.cuh b/include/kernels/dslash_wilson_clover_hasenbusch_twist.cuh index 994fd2caf1..9e9163136f 100644 --- a/include/kernels/dslash_wilson_clover_hasenbusch_twist.cuh +++ b/include/kernels/dslash_wilson_clover_hasenbusch_twist.cuh @@ -37,9 +37,9 @@ namespace quda struct cloverHasenbusch : dslash_default { const Arg &arg; - constexpr cloverHasenbusch(const Arg &arg) : arg(arg) {} + template constexpr cloverHasenbusch(const Ftor &ftor) : arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation - + /** @brief Apply the Wilson-clover dslash out(x) = M*in = A(x)*x(x) + D * in(x-mu) diff --git a/include/kernels/dslash_wilson_clover_hasenbusch_twist_preconditioned.cuh b/include/kernels/dslash_wilson_clover_hasenbusch_twist_preconditioned.cuh index e4bb7113f0..e2d1e6a44c 100644 --- a/include/kernels/dslash_wilson_clover_hasenbusch_twist_preconditioned.cuh +++ b/include/kernels/dslash_wilson_clover_hasenbusch_twist_preconditioned.cuh @@ -39,7 +39,7 @@ namespace quda struct cloverHasenbuschPreconditioned : dslash_default { const Arg &arg; - constexpr cloverHasenbuschPreconditioned(const Arg &arg) : arg(arg) {} + template constexpr cloverHasenbuschPreconditioned(const Ftor &ftor) : arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation /** diff --git a/include/kernels/dslash_wilson_clover_preconditioned.cuh b/include/kernels/dslash_wilson_clover_preconditioned.cuh index 86d5c71534..c2d1636241 100644 --- a/include/kernels/dslash_wilson_clover_preconditioned.cuh +++ b/include/kernels/dslash_wilson_clover_preconditioned.cuh @@ -37,7 +37,7 @@ namespace quda struct wilsonCloverPreconditioned : dslash_default { const Arg &arg; - constexpr wilsonCloverPreconditioned(const Arg &arg) : arg(arg) {} + template constexpr wilsonCloverPreconditioned(const Ftor &ftor) : arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation /** diff --git a/include/kernels/gauge_fix_ovr.cuh b/include/kernels/gauge_fix_ovr.cuh index c43649c15e..2424ad9ecc 100644 --- a/include/kernels/gauge_fix_ovr.cuh +++ b/include/kernels/gauge_fix_ovr.cuh @@ -126,12 +126,25 @@ namespace quda { } }; + template + using computeFixOps2 + = std::conditional_t, + std::conditional_t, + GaugeFixHit_NoAtomicAdd_LessSM2Ops>>; + template + using computeFixOps = std::conditional_t< + Arg::type == 0, GaugeFixHit_NoAtomicAddOps, + std::conditional_t< + Arg::type == 1, GaugeFixHit_AtomicAddOps, + std::conditional_t, computeFixOps2>>>; /** * @brief Perform gauge fixing with overrelaxation */ - template struct computeFix { + template struct computeFix : computeFixOps { const Arg &arg; - constexpr computeFix(const Arg &arg) : arg(arg) {} + using typename computeFixOps::KernelOpsT; + template + constexpr computeFix(const Arg &arg, const Ops &...ops) : KernelOpsT(ops...), arg(arg) { } static constexpr const char *filename() { return KERNEL_FILE; } __device__ inline void operator()(int idx, int mu) @@ -162,7 +175,7 @@ namespace quda { X[dr] += 2 * arg.border[dr]; } - if (Arg::type < 3) { + if constexpr (Arg::type < 3) { // 8 threads per lattice site int dim = mu; if (dim >= 4) { @@ -178,16 +191,19 @@ namespace quda { idx = (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]) >> 1; Link link = arg.u(dim, idx, parity); - switch (Arg::type) { + if constexpr (Arg::type == 0) { // 8 threads per lattice site, the reduction is performed by shared memory without using atomicadd. // this implementation needs 8x more shared memory than the implementation using atomicadd - case 0: GaugeFixHit_NoAtomicAdd(link, arg.relax_boost, mu); break; + GaugeFixHit_NoAtomicAdd(link, arg.relax_boost, mu, *this); + } + if constexpr (Arg::type == 1) { // 8 threads per lattice site, the reduction is performed by shared memory using atomicadd - case 1: GaugeFixHit_AtomicAdd(link, arg.relax_boost, mu); break; + GaugeFixHit_AtomicAdd(link, arg.relax_boost, mu, *this); + } + if constexpr (Arg::type == 2) { // 8 threads per lattice site, the reduction is performed by shared memory without using atomicadd. // uses the same amount of shared memory as the atomicadd implementation with more thread block synchronization - case 2: GaugeFixHit_NoAtomicAdd_LessSM(link, arg.relax_boost, mu); break; - default: break; + GaugeFixHit_NoAtomicAdd_LessSM(link, arg.relax_boost, mu, *this); } arg.u(dim, idx, parity) = link; @@ -205,16 +221,19 @@ namespace quda { int idx1 = (((x[3] * X[2] + x[2]) * X[1] + x[1]) * X[0] + x[0]) >> 1; Link link1 = arg.u(mu, idx1, 1 - parity); - switch (Arg::type) { + if constexpr (Arg::type == 3) { // 4 threads per lattice site, the reduction is performed by shared memory without using atomicadd. // this implementation needs 4x more shared memory than the implementation using atomicadd - case 3: GaugeFixHit_NoAtomicAdd(link, link1, arg.relax_boost, mu); break; + GaugeFixHit_NoAtomicAdd(link, link1, arg.relax_boost, mu, *this); + } + if constexpr (Arg::type == 4) { // 4 threads per lattice site, the reduction is performed by shared memory using atomicadd - case 4: GaugeFixHit_AtomicAdd(link, link1, arg.relax_boost, mu); break; + GaugeFixHit_AtomicAdd(link, link1, arg.relax_boost, mu, *this); + } + if constexpr (Arg::type == 5) { // 4 threads per lattice site, the reduction is performed by shared memory without using atomicadd. // uses the same amount of shared memory as the atomicadd implementation with more thread block synchronization - case 5: GaugeFixHit_NoAtomicAdd_LessSM(link, link1, arg.relax_boost, mu); break; - default: break; + GaugeFixHit_NoAtomicAdd_LessSM(link, link1, arg.relax_boost, mu, *this); } arg.u(mu, idx, parity) = link; diff --git a/include/kernels/laplace.cuh b/include/kernels/laplace.cuh index a3c2b1b377..e7833dae79 100644 --- a/include/kernels/laplace.cuh +++ b/include/kernels/laplace.cuh @@ -134,12 +134,12 @@ namespace quda } } } - + // out(x) = M*in template struct laplace : dslash_default { const Arg &arg; - constexpr laplace(const Arg &arg) : arg(arg) {} + template constexpr laplace(const Ftor &ftor) : arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation template diff --git a/include/kernels/madwf_transfer.cuh b/include/kernels/madwf_transfer.cuh index 96383ec35b..616d6a40c0 100644 --- a/include/kernels/madwf_transfer.cuh +++ b/include/kernels/madwf_transfer.cuh @@ -90,10 +90,19 @@ namespace quda } }; - template struct Transfer5D { + template struct Transfer5DParams { + using Cache = SharedMemoryCache; + using Ops = KernelOps; + }; + + template struct Transfer5D : Transfer5DParams::Ops { const Arg &arg; - constexpr Transfer5D(const Arg &arg) : arg(arg) { } + using typename Transfer5DParams::Ops::KernelOpsT; + template + constexpr Transfer5D(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg) + { + } static constexpr const char *filename() { return KERNEL_FILE; } /** @@ -116,7 +125,7 @@ namespace quda const matrix_t *wm_p = arg.wm_p; int thread_idx = target::thread_idx().y * target::block_dim().x + target::thread_idx().x; - SharedMemoryCache cache; + typename Transfer5DParams::Cache cache {*this}; while (thread_idx < static_cast(Ls_out * Ls_in * sizeof(matrix_t) / sizeof(real))) { cache.data()[thread_idx] = reinterpret_cast(wm_p)[thread_idx]; thread_idx += target::block_dim().y * target::block_dim().x; diff --git a/include/kernels/staggered_quark_smearing.cuh b/include/kernels/staggered_quark_smearing.cuh index 85e1790318..7d198b5388 100644 --- a/include/kernels/staggered_quark_smearing.cuh +++ b/include/kernels/staggered_quark_smearing.cuh @@ -161,7 +161,7 @@ namespace quda struct staggered_qsmear : dslash_default { const Arg &arg; - constexpr staggered_qsmear(const Arg &arg) : arg(arg) { } + template constexpr staggered_qsmear(const Ftor &ftor) : arg(ftor.arg) { } static constexpr const char *filename() { return KERNEL_FILE; } // this file name - used for run-time compilation template diff --git a/include/targets/generic/kernel_ops.h b/include/targets/generic/kernel_ops.h index ea7d22daaa..7e3888e846 100644 --- a/include/targets/generic/kernel_ops.h +++ b/include/targets/generic/kernel_ops.h @@ -108,9 +108,7 @@ namespace quda that need tagging. This can be used as an alternative in cases where the operations are only conditionally used. */ - struct NoKernelOps { - using KernelOpsT = NoKernelOps; - }; + using NoKernelOps = KernelOps<>; /** @brief getKernelOps is used to get the KernelOps type from a @@ -213,4 +211,12 @@ namespace quda using type = T; }; + // forward declarations of op types to be defined by target + struct op_blockSync; + template struct op_warp_combine; + + // only types for convenience + using only_blockSync = KernelOps; + template using only_warp_combine = KernelOps>; + } // namespace quda diff --git a/include/targets/generic/shared_memory_cache_helper.h b/include/targets/generic/shared_memory_cache_helper.h index c4b466fb57..86e603e6d2 100644 --- a/include/targets/generic/shared_memory_cache_helper.h +++ b/include/targets/generic/shared_memory_cache_helper.h @@ -1,5 +1,6 @@ #pragma once +#include // for atom_t #include #include #include @@ -90,13 +91,18 @@ namespace quda /** @brief Constructor for SharedMemoryCache. */ - constexpr SharedMemoryCache() : block(D::dims(target::block_dim())), stride(block.x * block.y * block.z) + template + constexpr SharedMemoryCache(const KernelOps &ops, Arg... arg) : + Smem(ops), block(D::dims(target::block_dim(), arg...)), stride(block.x * block.y * block.z) { + checkKernelOps>(ops); // sanity check static_assert(shared_mem_size(dim3 {32, 16, 8}) == Smem::get_offset(dim3 {32, 16, 8}) + SizeDims::size(dim3 {32, 16, 8}) * sizeof(T)); } + constexpr SharedMemoryCache(const SharedMemoryCache &) = delete; + /** @brief Grab the raw base address to shared memory. */