Commit

pass KernelOps into SharedMemoryCache constructor
jcosborn committed Jan 4, 2025
1 parent b58f1ec commit bc48268
Showing 33 changed files with 445 additions and 225 deletions.
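
The change, applied across all 33 files, follows one pattern: a kernel functor declares the kernel operations it needs (for example a SharedMemoryCache), inherits the corresponding KernelOps type, and passes itself into the SharedMemoryCache constructor, so the shared-memory requirement becomes part of the functor's type instead of being discovered inside the kernel body. For orientation, below is a minimal host-side C++ sketch of the idea; all types here are invented stand-ins, not QUDA's actual KernelOps or SharedMemoryCache.

    #include <cstdio>
    #include <vector>

    // Stand-in for the KernelOps base: it carries the shared-memory pointer
    // that the launch machinery set up for this kernel.
    struct KernelOpsModel {
      float *smem = nullptr;
    };

    // Stand-in for SharedMemoryCache: the new-style constructor takes any
    // ops-carrying object and borrows its allocation.
    struct CacheModel {
      float *buf;
      template <typename Ops> explicit CacheModel(const Ops &ops) : buf(ops.smem) { }
      float *data() const { return buf; }
    };

    // The functor inherits its ops and hands itself (*this) to the cache,
    // mirroring "SharedMemoryCache<...> cache(ftor);" in the diffs below.
    struct FunctorModel : KernelOpsModel {
      explicit FunctorModel(float *smem_) { smem = smem_; }
      void operator()(int i) const
      {
        CacheModel cache(*this);
        cache.data()[i] = 2.0f * i;
      }
    };

    int main()
    {
      std::vector<float> fake_smem(8); // host vector standing in for shared memory
      FunctorModel f(fake_smem.data());
      for (int i = 0; i < 8; i++) f(i);
      std::printf("%g\n", fake_smem[3]); // prints 6
    }

The three files shown below (dslash_helper.cuh, gauge_fix_ovr_hit_devf.cuh, block_transpose.cuh) are representative of the rest.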
11 changes: 8 additions & 3 deletions include/dslash_helper.cuh
@@ -11,6 +11,7 @@
 #include <shmem_pack_helper.cuh>
 #include <kernel_helper.h>
 #include <tune_quda.h>
+#include <kernel_ops.h>
 
 constexpr quda::use_kernel_arg_p use_kernel_arg = quda::use_kernel_arg_p::TRUE;
 
@@ -660,17 +661,21 @@ namespace quda
      are reserved for data packing, which may include communication to
      neighboring processes.
   */
-  template <typename Arg> struct dslash_functor {
+  template <typename Arg> struct dslash_functor : getKernelOps<typename Arg::D> {
     const typename Arg::Arg &arg;
     static constexpr int nParity = Arg::nParity;
     static constexpr bool dagger = Arg::dagger;
     static constexpr KernelType kernel_type = Arg::kernel_type;
     static constexpr const char *filename() { return Arg::D::filename(); }
-    constexpr dslash_functor(const Arg &arg) : arg(arg.arg) { }
+    using typename getKernelOps<typename Arg::D>::KernelOpsT;
+    template <typename... OpsArgs>
+    constexpr dslash_functor(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg.arg)
+    {
+    }
 
     __forceinline__ __device__ void operator()(int, int s, int parity)
     {
-      typename Arg::D dslash(arg);
+      typename Arg::D dslash(*this);
       // for full fields set parity from z thread index else use arg setting
       if (nParity == 1) parity = arg.parity;
 
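The dslash functor shows the forwarding half of the pattern: it inherits getKernelOps<typename Arg::D>, re-exports the base's KernelOpsT, and forwards any trailing constructor arguments to it, so the inner dslash object constructed from *this sees the same ops state. A compile-only sketch of those mechanics follows; EmptyOps and getKernelOpsModel are invented stand-ins (QUDA's getKernelOps<D> derives the ops list from D itself).

    // Invented stand-ins: EmptyOps models a functor with no declared ops.
    struct EmptyOps {
      constexpr EmptyOps() = default;
    };

    template <typename D> struct getKernelOpsModel {
      using KernelOpsT = EmptyOps; // QUDA computes this from D's declarations
    };

    template <typename D> struct DslashFunctorModel : getKernelOpsModel<D>::KernelOpsT {
      using KernelOpsT = typename getKernelOpsModel<D>::KernelOpsT;
      int value;
      // Forward whatever ops-construction arguments the launch layer supplies.
      template <typename... OpsArgs>
      constexpr DslashFunctorModel(int v, const OpsArgs &...ops) : KernelOpsT(ops...), value(v) { }
    };

    // With no ops declared, the pack is empty and construction still works.
    static_assert(DslashFunctorModel<int>(7).value == 7, "ops forwarding");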
84 changes: 54 additions & 30 deletions include/gauge_fix_ovr_hit_devf.cuh
@@ -40,20 +40,29 @@ namespace quda {
     }
   }
 
+  template <int N> struct GaugeFixHitDims {
+    static constexpr dim3 dims(dim3 block)
+    {
+      block.y = N;
+      return block;
+    }
+  };
+
   /**
    * Device function to perform gauge fixing with overrelaxation.
-   * Uses 8 threads per lattice site, the reduction is performed by shared memory without using atomicadd.
-   * This implementation needs 8x more shared memory than the implementation using atomicadd
+   * Uses 4 threads per lattice site, the reduction is performed by shared memory using atomicadd.
    */
-  template <typename Float, int gauge_dir, int nColor>
-  inline __device__ void GaugeFixHit_AtomicAdd(Matrix<complex<Float>,nColor> &link, const Float relax_boost, int mu)
+  template <typename Float> using GaugeFixHit_AtomicAddOps = KernelOps<SharedMemoryCache<Float, GaugeFixHitDims<4>>>;
+  template <typename Float, int gauge_dir, int nColor, typename Ftor>
+  inline __device__ void GaugeFixHit_AtomicAdd(Matrix<complex<Float>, nColor> &link, const Float relax_boost, int mu,
+                                               const Ftor &ftor)
   {
     auto blockSize = target::block_dim().x;
     auto tid = target::thread_idx().x;
 
     //Container for the four real parameters of SU(2) subgroup in shared memory
-    SharedMemoryCache<Float> cache;
-    auto elems = cache.data();
+    SharedMemoryCache<Float, GaugeFixHitDims<4>> cache(ftor);
+    Float *elems = cache.data();
 
     //initialize shared memory
     if (mu < 4) elems[mu * blockSize + tid] = 0.0;
@@ -138,17 +147,20 @@

   /**
    * Device function to perform gauge fixing with overrelaxation.
-   * Uses 4 threads per lattice site, the reduction is performed by shared memory using atomicadd.
+   * Uses 4*8 threads per lattice site, the reduction is performed by shared memory without using atomicadd.
+   * This implementation needs 8x more shared memory than the implementation using atomicadd
    */
-  template <typename Float, int gauge_dir, int nColor>
-  inline __device__ void GaugeFixHit_NoAtomicAdd(Matrix<complex<Float>,nColor> &link, const Float relax_boost, int mu)
+  template <typename Float> using GaugeFixHit_NoAtomicAddOps = KernelOps<SharedMemoryCache<array<Float, 4>>>;
+  template <typename Float, int gauge_dir, int nColor, typename Ftor>
+  inline __device__ void GaugeFixHit_NoAtomicAdd(Matrix<complex<Float>, nColor> &link, const Float relax_boost, int mu,
+                                                 const Ftor &ftor)
   {
     auto blockSize = target::block_dim().x;
     auto tid = target::thread_idx().x;
 
     //Container for the four real parameters of SU(2) subgroup in shared memory
-    SharedMemoryCache<Float> cache;
-    auto elems = cache.data();
+    SharedMemoryCache<array<Float, 4>> cache(ftor);
+    Float *elems = &(*cache.data())[0];
 
     //Loop over all SU(2) subgroups of SU(N)
     //#pragma unroll
@@ -228,15 +240,18 @@
    * Uses 8 threads per lattice site, the reduction is performed by shared memory without using atomicadd.
    * This implementation uses the same amount of shared memory as the atomicadd implementation with more thread block synchronization
    */
-  template <typename Float, int gauge_dir, int nColor>
-  inline __device__ void GaugeFixHit_NoAtomicAdd_LessSM(Matrix<complex<Float>,nColor> &link, const Float relax_boost, int mu)
+  template <typename Float>
+  using GaugeFixHit_NoAtomicAdd_LessSMOps = KernelOps<SharedMemoryCache<Float, GaugeFixHitDims<4>>>;
+  template <typename Float, int gauge_dir, int nColor, typename Ftor>
+  inline __device__ void GaugeFixHit_NoAtomicAdd_LessSM(Matrix<complex<Float>, nColor> &link, const Float relax_boost,
+                                                        int mu, const Ftor &ftor)
   {
     auto blockSize = target::block_dim().x;
     auto tid = target::thread_idx().x;
 
     //Container for the four real parameters of SU(2) subgroup in shared memory
-    SharedMemoryCache<Float> cache;
-    auto elems = cache.data();
+    SharedMemoryCache<Float, GaugeFixHitDims<4>> cache(ftor);
+    Float *elems = cache.data();
 
     //Loop over all SU(2) subgroups of SU(N)
     //#pragma unroll
@@ -323,18 +338,20 @@
   /**
    * Device function to perform gauge fixing with overrelaxation.
    * Uses 8 threads per lattice site, the reduction is performed by shared memory without using atomicadd.
-   * This implementation needs 8x more shared memory than the implementation using atomicadd
+   * This implementation needs 8x more shared memory than the implementation using atomicadd
    */
-  template <typename Float, int gauge_dir, int nColor>
-  inline __device__ void GaugeFixHit_AtomicAdd(Matrix<complex<Float>,nColor> &link, Matrix<complex<Float>,nColor> &link1,
-                                               const Float relax_boost, int mu)
+  template <typename Float> using GaugeFixHit_AtomicAdd2Ops = KernelOps<SharedMemoryCache<Float, GaugeFixHitDims<4>>>;
+  template <typename Float, int gauge_dir, int nColor, typename Ftor>
+  inline __device__ void GaugeFixHit_AtomicAdd(Matrix<complex<Float>, nColor> &link,
+                                               Matrix<complex<Float>, nColor> &link1, const Float relax_boost, int mu,
+                                               const Ftor &ftor)
   {
     auto blockSize = target::block_dim().x;
     auto tid = target::thread_idx().x;
 
     //Container for the four real parameters of SU(2) subgroup in shared memory
-    SharedMemoryCache<Float> cache;
-    auto elems = cache.data();
+    SharedMemoryCache<Float, GaugeFixHitDims<4>> cache(ftor);
+    Float *elems = cache.data();
 
     //initialize shared memory
     if (mu < 4) elems[mu * blockSize + tid] = 0.0;
@@ -408,16 +425,19 @@
    * Device function to perform gauge fixing with overrelaxation.
    * Uses 4 threads per lattice site, the reduction is performed by shared memory using atomicadd.
    */
-  template <typename Float, int gauge_dir, int nColor>
-  inline __device__ void GaugeFixHit_NoAtomicAdd(Matrix<complex<Float>,nColor> &link, Matrix<complex<Float>,nColor> &link1,
-                                                 const Float relax_boost, int mu)
+  template <typename Float>
+  using GaugeFixHit_NoAtomicAdd2Ops = KernelOps<SharedMemoryCache<Float, GaugeFixHitDims<16>>>;
+  template <typename Float, int gauge_dir, int nColor, typename Ftor>
+  inline __device__ void GaugeFixHit_NoAtomicAdd(Matrix<complex<Float>, nColor> &link,
+                                                 Matrix<complex<Float>, nColor> &link1, const Float relax_boost, int mu,
+                                                 const Ftor &ftor)
   {
     auto blockSize = target::block_dim().x;
     auto tid = target::thread_idx().x;
 
     //Container for the four real parameters of SU(2) subgroup in shared memory
-    SharedMemoryCache<Float> cache;
-    auto elems = cache.data();
+    SharedMemoryCache<Float, GaugeFixHitDims<16>> cache(ftor);
+    Float *elems = cache.data();
 
     //Loop over all SU(2) subgroups of SU(N)
     //#pragma unroll
@@ -485,15 +505,19 @@
    * Uses 4 threads per lattice site, the reduction is performed by shared memory without using atomicadd.
    * This implementation uses the same amount of shared memory as the atomicadd implementation with more thread block synchronization
    */
-  template <typename Float, int gauge_dir, int nColor>
-  inline __device__ void GaugeFixHit_NoAtomicAdd_LessSM(Matrix<complex<Float>,nColor> &link, Matrix<complex<Float>,nColor> &link1, const Float relax_boost, int mu)
+  template <typename Float>
+  using GaugeFixHit_NoAtomicAdd_LessSM2Ops = KernelOps<SharedMemoryCache<Float, GaugeFixHitDims<4>>>;
+  template <typename Float, int gauge_dir, int nColor, typename Ftor>
+  inline __device__ void GaugeFixHit_NoAtomicAdd_LessSM(Matrix<complex<Float>, nColor> &link,
+                                                        Matrix<complex<Float>, nColor> &link1, const Float relax_boost,
+                                                        int mu, const Ftor &ftor)
   {
     auto blockSize = target::block_dim().x;
     auto tid = target::thread_idx().x;
 
     //Container for the four real parameters of SU(2) subgroup in shared memory
-    SharedMemoryCache<Float> cache;
-    auto elems = cache.data();
+    SharedMemoryCache<Float, GaugeFixHitDims<4>> cache(ftor);
+    Float *elems = cache.data();
 
     //Loop over all SU(2) subgroups of SU(N)
     //#pragma unroll
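A note on the GaugeFixHitDims<N> helper introduced above: its dims() hook pins the block's y extent to N, which, assuming the cache sizes its storage from those dims (as the constructor change suggests), yields N * blockDim.x cache elements, one row per SU(2) parameter. That is the layout behind the elems[mu * blockSize + tid] indexing used by the initialization code. A small self-contained check of the arithmetic, with an invented Dim3 stand-in:

    #include <cassert>

    struct Dim3 {
      unsigned x, y, z;
    };

    template <int N> struct GaugeFixHitDimsModel {
      static constexpr Dim3 dims(Dim3 block)
      {
        block.y = N; // N rows of blockDim.x elements in the cache
        return block;
      }
    };

    int main()
    {
      constexpr unsigned blockSize = 128; // one thread per site in x
      constexpr Dim3 b = GaugeFixHitDimsModel<4>::dims({blockSize, 1, 1});
      static_assert(b.y == 4, "four SU(2) parameters per site");
      // Every (mu, tid) pair indexes inside the 4 * blockSize element cache.
      for (unsigned mu = 0; mu < b.y; mu++)
        for (unsigned tid = 0; tid < blockSize; tid++)
          assert(mu * blockSize + tid < b.x * b.y);
      return 0;
    }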
21 changes: 15 additions & 6 deletions include/kernels/block_transpose.cuh
@@ -42,11 +42,7 @@ namespace quda
     }
   };
 
-  template <typename Arg> struct BlockTransposeKernel {
-    const Arg &arg;
-    constexpr BlockTransposeKernel(const Arg &arg) : arg(arg) { }
-    static constexpr const char *filename() { return KERNEL_FILE; }
-
+  template <typename Arg> struct BlockTransposeKernelOps {
     struct CacheDims {
       static constexpr dim3 dims(dim3 block)
       {
@@ -55,6 +51,19 @@
         return block;
       }
     };
+    using color_spinor_t = ColorSpinor<typename Arg::real, 1, Arg::nSpin>;
+    using CacheT = SharedMemoryCache<color_spinor_t, CacheDims>;
+    using Ops = KernelOps<CacheT>;
+  };
+
+  template <typename Arg> struct BlockTransposeKernel : BlockTransposeKernelOps<Arg>::Ops {
+    const Arg &arg;
+    using typename BlockTransposeKernelOps<Arg>::Ops::KernelOpsT;
+    template <typename... OpsArgs>
+    constexpr BlockTransposeKernel(const Arg &arg, const OpsArgs &...ops) : KernelOpsT(ops...), arg(arg)
+    {
+    }
+    static constexpr const char *filename() { return KERNEL_FILE; }
 
     /**
        @brief Transpose between the two different orders of batched colorspinor fields:
Expand All @@ -69,7 +78,7 @@ namespace quda
       int parity = parity_color / Arg::nColor;
       using color_spinor_t = ColorSpinor<typename Arg::real, 1, Arg::nSpin>;
 
-      SharedMemoryCache<color_spinor_t, CacheDims> cache;
+      typename BlockTransposeKernelOps<Arg>::CacheT cache {*this};
 
       int x_offset = target::block_dim().x * target::block_idx().x;
       int v_offset = target::block_dim().y * target::block_idx().y;
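BlockTransposeKernel shows the remaining half of the pattern: the cache element type and block dimensions are hoisted into BlockTransposeKernelOps<Arg>, so the ops list KernelOps<CacheT> is visible at the type level (for example, to size shared memory at launch) before the functor is ever constructed. A compile-only sketch of that traits hoisting, again with invented stand-in names:

    #include <type_traits>

    template <typename T> struct SharedMemoryCacheModel {
      using value_type = T;
    };
    template <typename... Ts> struct KernelOpsModel { };

    // The ops/cache types live outside the functor, keyed only on Arg.
    template <typename Arg> struct TransposeOpsModel {
      using element_t = typename Arg::real;
      using CacheT = SharedMemoryCacheModel<element_t>;
      using Ops = KernelOpsModel<CacheT>;
    };

    // The functor inherits the ops; its operator() would build
    // typename TransposeOpsModel<Arg>::CacheT from *this, as in the diff.
    template <typename Arg> struct TransposeKernelModel : TransposeOpsModel<Arg>::Ops { };

    struct ArgModel {
      using real = float;
    };
    static_assert(std::is_same<TransposeOpsModel<ArgModel>::CacheT::value_type, float>::value,
                  "launch code can inspect the cache type without the functor");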