Skip to content

Commit

Permalink
Merge branch 'develop' into feature/sycl
Browse files Browse the repository at this point in the history
  • Loading branch information
jcosborn committed Dec 10, 2024
2 parents 7c24446 + a54595d commit 94a0c38
Show file tree
Hide file tree
Showing 27 changed files with 378 additions and 221 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ include/jitify_options.hpp
.tags*
autom4te.cache/*
.vscode
cmake/CPM_*.cmake
20 changes: 9 additions & 11 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,12 @@ if(QUDA_MAX_MULTI_BLAS_N GREATER 32)
message(SEND_ERROR "Maximum QUDA_MAX_MULTI_BLAS_N is 32.")
endif()

# For now we only support register tiles for the staggered dslash operators
set(QUDA_MAX_MULTI_RHS_TILE "1" CACHE STRING "maximum tile size for MRHS kernels (staggered only)")
if(QUDA_MAX_MULTI_RHS_TILE GREATER QUDA_MAX_MULTI_RHS)
message(SEND_ERROR "QUDA_MAX_MULTI_RHS_TILE is greater than QUDA_MAX_MULTI_RHS")
endif()

set(QUDA_PRECISION
"14"
CACHE STRING "which precisions to instantiate in QUDA (4-bit number - double, single, half, quarter)")
Expand Down Expand Up @@ -275,6 +281,7 @@ mark_as_advanced(QUDA_ALTERNATIVE_I_TO_F)

mark_as_advanced(QUDA_MAX_MULTI_BLAS_N)
mark_as_advanced(QUDA_MAX_MULTI_RHS)
mark_as_advanced(QUDA_MAX_MULTI_RHS_TILE)
mark_as_advanced(QUDA_PRECISION)
mark_as_advanced(QUDA_RECONSTRUCT)
mark_as_advanced(QUDA_CLOVER_CHOLESKY_PROMOTE)
Expand Down Expand Up @@ -420,21 +427,12 @@ if(QUDA_DOWNLOAD_EIGEN)
CPMAddPackage(
NAME Eigen
VERSION ${QUDA_EIGEN_VERSION}
URL https://gitlab.com/libeigen/eigen/-/archive/${QUDA_EIGEN_VERSION}/eigen-${QUDA_EIGEN_VERSION}.tar.bz2
URL_HASH SHA256=B4C198460EBA6F28D34894E3A5710998818515104D6E74E5CC331CE31E46E626
URL https://gitlab.com/libeigen/eigen/-/archive/e67c494cba7180066e73b9f6234d0b2129f1cdf5.tar.bz2
URL_HASH SHA256=98d244932291506b75c4ae7459af29b1112ea3d2f04660686a925d9ef6634583
DOWNLOAD_ONLY YES
SYSTEM YES)
target_include_directories(Eigen SYSTEM INTERFACE ${Eigen_SOURCE_DIR})
install(DIRECTORY ${Eigen_SOURCE_DIR}/Eigen TYPE INCLUDE)

# Eigen 3.4 needs to be patched on Neon with nvc++
if (${CMAKE_CXX_COMPILER_ID} MATCHES "NVHPC")
set(CMAKE_PATCH_EIGEN OFF CACHE BOOL "Internal use only; do not modify")
if (NOT CMAKE_PATCH_EIGEN)
execute_process(COMMAND patch -N "${Eigen_SOURCE_DIR}/Eigen/src/Core/arch/NEON/Complex.h" "${CMAKE_SOURCE_DIR}/cmake/eigen34_neon.diff")
set(CMAKE_PATCH_EIGEN ON CACHE BOOL "Internal use only; do not modify" FORCE)
endif()
endif()
else()
# fall back to using find_package
find_package(Eigen QUIET)
Expand Down
8 changes: 0 additions & 8 deletions cmake/eigen34_neon.diff

This file was deleted.

38 changes: 16 additions & 22 deletions include/complex_quda.h
Original file line number Diff line number Diff line change
Expand Up @@ -363,23 +363,19 @@ struct complex
typedef ValueType value_type;

// Constructors
__host__ __device__ inline complex<ValueType>(const ValueType &re = ValueType(), const ValueType &im = ValueType())
__host__ __device__ inline complex(const ValueType &re = ValueType(), const ValueType &im = ValueType())
{
real(re);
imag(im);
}

template <class X>
__host__ __device__
inline complex<ValueType>(const complex<X> & z)
template <class X> __host__ __device__ inline complex(const complex<X> &z)
{
real(z.real());
imag(z.imag());
}

template <class X>
__host__ __device__
inline complex<ValueType>(const std::complex<X> & z)
template <class X> __host__ __device__ inline complex(const std::complex<X> &z)
{
real(z.real());
imag(z.imag());
Expand Down Expand Up @@ -439,12 +435,11 @@ struct complex
template <> struct complex<float> : public float2 {
public:
typedef float value_type;
complex<float>() = default;
constexpr complex<float>(const float &re, const float &im = float()) : float2 {re, im} { }
complex() = default;
constexpr complex(const float &re, const float &im = float()) : float2 {re, im} { }

template <typename X>
constexpr complex<float>(const std::complex<X> &z) :
float2 {static_cast<float>(z.real()), static_cast<float>(z.imag())}
constexpr complex(const std::complex<X> &z) : float2 {static_cast<float>(z.real()), static_cast<float>(z.imag())}
{
}

Expand Down Expand Up @@ -503,16 +498,15 @@ template <> struct complex<float> : public float2 {
template <> struct complex<double> : public double2 {
public:
typedef double value_type;
complex<double>() = default;
constexpr complex<double>(const double &re, const double &im = double()) : double2 {re, im} { }
complex() = default;
constexpr complex(const double &re, const double &im = double()) : double2 {re, im} { }

template <typename X>
constexpr complex<double>(const std::complex<X> &z) :
double2 {static_cast<double>(z.real()), static_cast<double>(z.imag())}
constexpr complex(const std::complex<X> &z) : double2 {static_cast<double>(z.real()), static_cast<double>(z.imag())}
{
}

template <typename T> __host__ __device__ inline complex<double> &operator=(const complex<T> &z)
template <typename T> __host__ __device__ inline complex &operator=(const complex<T> &z)
{
real(z.real());
imag(z.imag());
Expand Down Expand Up @@ -575,9 +569,9 @@ template <> struct complex<int8_t> : public char2 {
public:
typedef int8_t value_type;

complex<int8_t>() = default;
complex() = default;

constexpr complex<int8_t>(const int8_t &re, const int8_t &im = int8_t()) : char2 {re, im} { }
constexpr complex(const int8_t &re, const int8_t &im = int8_t()) : char2 {re, im} { }

__host__ __device__ inline complex<int8_t> &operator+=(const complex<int8_t> &z)
{
Expand Down Expand Up @@ -611,9 +605,9 @@ struct complex <short> : public short2
public:
typedef short value_type;

complex<short>() = default;
complex() = default;

constexpr complex<short>(const short &re, const short &im = short()) : short2 {re, im} { }
constexpr complex(const short &re, const short &im = short()) : short2 {re, im} { }

__host__ __device__ inline complex<short> &operator+=(const complex<short> &z)
{
Expand Down Expand Up @@ -647,9 +641,9 @@ struct complex <int> : public int2
public:
typedef int value_type;

complex<int>() = default;
complex() = default;

constexpr complex<int>(const int &re, const int &im = int()) : int2 {re, im} { }
constexpr complex(const int &re, const int &im = int()) : int2 {re, im} { }

__host__ __device__ inline complex<int> &operator+=(const complex<int> &z)
{
Expand Down
12 changes: 11 additions & 1 deletion include/dslash.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,10 @@ namespace quda
if (arg.xpay) strcat(aux_base, ",xpay");
if (arg.dagger) strcat(aux_base, ",dagger");
setRHSstring(aux_base, in.size());
strcat(aux_base, ",n_rhs_tile=");
char tile_str[16];
i32toa(tile_str, Arg::n_src_tile);
strcat(aux_base, tile_str);
}

/**
Expand Down Expand Up @@ -329,7 +333,13 @@ namespace quda

Dslash(Arg &arg, cvector_ref<ColorSpinorField> &out, cvector_ref<const ColorSpinorField> &in,
const ColorSpinorField &halo, const std::string &app_base = "") :
TunableKernel3D(in[0], halo.X(4), arg.nParity), arg(arg), out(out), in(in), halo(halo), nDimComms(4), dslashParam(arg)
TunableKernel3D(in[0], (halo.X(4) + Arg::n_src_tile - 1) / Arg::n_src_tile, arg.nParity),
arg(arg),
out(out),
in(in),
halo(halo),
nDimComms(4),
dslashParam(arg)
{
if (checkLocation(out, in) == QUDA_CPU_FIELD_LOCATION)
errorQuda("CPU Fields not supported in Dslash framework yet");
Expand Down
15 changes: 7 additions & 8 deletions include/dslash_helper.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,7 @@
#include <tune_quda.h>
#include <kernel_ops.h>

#if defined(_NVHPC_CUDA)
#include <constant_kernel_arg.h>
constexpr quda::use_kernel_arg_p use_kernel_arg = quda::use_kernel_arg_p::FALSE;
#else
constexpr quda::use_kernel_arg_p use_kernel_arg = quda::use_kernel_arg_p::TRUE;
#endif

#include <kernel.h>

Expand Down Expand Up @@ -242,11 +237,12 @@ namespace quda
return true;
}

template <typename Float_, int nDim_> struct DslashArg {
template <typename Float_, int nDim_, int n_src_tile_ = 1> struct DslashArg {

using Float = Float_;
using real = typename mapper<Float>::type;
static constexpr int nDim = nDim_;
static constexpr int n_src_tile = n_src_tile_; // how many RHS per thread

const int parity; // only use this for single parity fields
const int nParity; // number of parities we're working on
Expand All @@ -270,6 +266,7 @@ namespace quda
int threadDimMapLower[4];
int threadDimMapUpper[4];

int_fastdiv n_src;
int_fastdiv Ls;

// these are set with symmetric preconditioned twisted-mass dagger
Expand Down Expand Up @@ -328,6 +325,7 @@ namespace quda
exterior_threads(0),
threadDimMapLower {},
threadDimMapUpper {},
n_src(in.size()),
Ls(halo.X(4) / in.size()),
twist_a(0.0),
twist_b(0.0),
Expand Down Expand Up @@ -651,8 +649,9 @@ namespace quda
Arg arg;

dslash_functor_arg(const Arg &arg, unsigned int threads_x) :
kernel_param(dim3(threads_x, arg.dc.Ls, arg.nParity)),
arg(arg) { }
kernel_param(dim3(threads_x, (arg.dc.Ls + Arg::n_src_tile - 1) / Arg::n_src_tile, arg.nParity)), arg(arg)
{
}
};

/**
Expand Down
4 changes: 0 additions & 4 deletions include/eigen_helper.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,6 @@
#define EIGEN_USE_BLAS
#endif

#if defined(__NVCOMPILER) // WAR for nvc++ until we update to latest Eigen
#define EIGEN_DONT_VECTORIZE
#endif

#include <math.h>

// hide annoying warning
Expand Down
26 changes: 25 additions & 1 deletion include/field_cache.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ namespace quda {
*/
template <typename T>
struct FieldKey {
std::string volume; /** volume kstring */
std::string volume; /** volume string */
std::string aux; /** auxiliary string */

FieldKey() = default;
Expand Down Expand Up @@ -78,6 +78,18 @@ namespace quda {
*/
FieldTmp(const FieldKey<T> &key, const typename T::param_type &param);

/**
@brief Create a field temporary that corresponds to the field
constructed from the param struct. If a matching field is
present in the cache, it will be popped from the cache. If no
such temporary exists a temporary will be allocated.
@param[in] param Parameter structure used to allocate
the temporary
*/
FieldTmp(typename T::param_type param);

/**
@brief Copy constructor is deleted to prevent accidental cache
bloat
Expand Down Expand Up @@ -111,6 +123,18 @@ namespace quda {
*/
template <typename T> auto getFieldTmp(const T &a) { return FieldTmp<T>(a); }

/**
@brief Get a field temporary that corresponds to the field
constructed from the param struct. If a matching field is
present in the cache, it will be popped from the cache. If no
such temporary exists, a temporary will be allocated. When the
destructor for the FieldTmp is called, e.g., the returned object
goes out of scope, the temporary will be pushed onto the cache.
@tparam T Field type of the temporary; T appears only in a
non-deduced context here, so it must be specified explicitly
by the caller
@param[in] param Parameter structure describing the temporary
we wish to obtain
*/
template <typename T> auto getFieldTmp(const typename T::param_type &param) { return FieldTmp<T>(param); }

/**
@brief Get a vector of field temporaries that are identical to
the vector instance argument. If enough matching fields are
Expand Down
Loading

0 comments on commit 94a0c38

Please sign in to comment.