diff --git a/include/gauge_field.h b/include/gauge_field.h index e4ec3ae09d..b1d0311709 100644 --- a/include/gauge_field.h +++ b/include/gauge_field.h @@ -385,6 +385,13 @@ namespace quda { */ virtual void copy(const GaugeField &src) = 0; + /** + * @brief Generic gauge field shift + * @param[in] src Source from which we are shifting (extended field in case of MPI) + * @param[in] dx Host array of shifts to apply to the field + */ + virtual void shift(const GaugeField &src, const array &dx) = 0; + /** @brief Compute the L1 norm of the field @param[in] dim Which dimension we are taking the norm of (dim=-1 mean all dimensions) @@ -535,6 +542,13 @@ namespace quda { */ void copy(const GaugeField &src); + /** + * @brief Generic gauge field shift + * @param[in] src Source from which we are shifting (extended field in case of MPI) + * @param[in] dx Host array of shifts to apply to the field + */ + void shift(const GaugeField &src, const array &dx); + /** @brief Download into this field from a CPU field @param[in] cpu The CPU field source @@ -672,6 +686,13 @@ namespace quda { */ void copy(const GaugeField &src); + /** + * @brief Generic gauge field shift + * @param[in] src Source from which we are shifting (extended field in case of MPI) + * @param[in] dx Host array of shifts to apply to the field + */ + void shift(const GaugeField &src, const array &dx); + void* Gauge_p() { return gauge; } const void* Gauge_p() const { return gauge; } @@ -864,4 +885,12 @@ namespace quda { #define checkReconstruct(...) Reconstruct_(__func__, __FILE__, __LINE__, __VA_ARGS__) + /** + * @brief Generic gauge field shift + * @param[out] dst Gauge field to store output + * @param[in] srd Source from which we are shifting (extended field in case of MPI) + * @param[in] dx Host array of shifts to apply to the field + */ + void gaugeShift(GaugeField &dst, const GaugeField &src, const array &dx); + } // namespace quda diff --git a/include/kernels/gauge_shift.cuh b/include/kernels/gauge_shift.cuh new file mode 100644 index 0000000000..83d7518358 --- /dev/null +++ b/include/kernels/gauge_shift.cuh @@ -0,0 +1,61 @@ +#pragma once + +#include +#include +#include +#include + +namespace quda +{ + + template struct GaugeShiftArg : kernel_param<> { + using Float = Float_; + static constexpr int nColor = nColor_; + static_assert(nColor == 3, "Only nColor=3 enabled at this time"); + typedef typename gauge_mapper::type Gauge; + + Gauge out; + const Gauge in; + + int S[4]; // the regular volume parameters + int X[4]; // the regular volume parameters + int E[4]; // the extended volume parameters + int border[4]; // radius of border + int P; // change of parity + + GaugeShiftArg(GaugeField &out, const GaugeField &in, const array &dx) : + kernel_param(dim3(in.VolumeCB(), 2, in.Geometry())), out(out), in(in) + { + P = 0; + for (int i = 0; i < 4; i++) { + S[i] = dx[i]; + X[i] = out.X()[i]; + E[i] = in.X()[i]; + border[i] = (E[i] - X[i]) / 2; + P += dx[i]; + } + P = std::abs(P) % 2; + } + }; + + template struct GaugeShift { + const Arg &arg; + constexpr GaugeShift(const Arg &arg) : arg(arg) { } + static constexpr const char *filename() { return KERNEL_FILE; } + + __device__ __host__ void operator()(int x_cb, int parity, int dir) + { + using real = typename Arg::Float; + typedef Matrix, Arg::nColor> Link; + + int x[4] = {0, 0, 0, 0}; + getCoords(x, x_cb, arg.X, parity); + for (int dr = 0; dr < 4; ++dr) x[dr] += arg.border[dr]; // extended grid coordinates + int nbr_oddbit = arg.P == 1 ? (parity ^ 1) : parity; + + Link link = arg.in(dir, linkIndexShift(x, arg.S, arg.E), nbr_oddbit); + arg.out(dir, x_cb, parity) = link; + } + }; + +} // namespace quda diff --git a/lib/CMakeLists.txt b/lib/CMakeLists.txt index dbd731d521..5038beb292 100644 --- a/lib/CMakeLists.txt +++ b/lib/CMakeLists.txt @@ -73,6 +73,7 @@ set (QUDA_OBJS copy_gauge_half.cu copy_gauge_quarter.cu copy_gauge.cpp copy_clover.cu copy_gauge_offset.cu copy_color_spinor_offset.cu copy_clover_offset.cu + gauge_shift.cu staggered_oprod.cu clover_trace_quda.cu hisq_paths_force_quda.cu unitarize_force_quda.cu unitarize_links_quda.cu milc_interface.cpp diff --git a/lib/cpu_gauge_field.cpp b/lib/cpu_gauge_field.cpp index f4b27109a8..ee4f8ec30d 100644 --- a/lib/cpu_gauge_field.cpp +++ b/lib/cpu_gauge_field.cpp @@ -334,6 +334,26 @@ namespace quda { } } + void cpuGaugeField::shift(const GaugeField &src, const array &dx) + { + for (int i = 0; i < this->nDim; i++) { + if (dx[i] != 0) break; + // if zero shift, we simply copy + if (i == this->nDim - 1) return this->copy(src); + } + if (this == &src) errorQuda("Cannot copy in itself"); + + checkField(src); + + // TODO: check src extension (needs to be enough for shifting) + + if (typeid(src) == typeid(cudaGaugeField)) { + errorQuda("Not Implemented"); + } else { + errorQuda("Not compatible type"); + } + } + void cpuGaugeField::setGauge(void **gauge_) { if(create != QUDA_REFERENCE_FIELD_CREATE) { diff --git a/lib/cuda_gauge_field.cpp b/lib/cuda_gauge_field.cpp index 568209a74f..d9f4988c34 100644 --- a/lib/cuda_gauge_field.cpp +++ b/lib/cuda_gauge_field.cpp @@ -610,6 +610,25 @@ namespace quda { qudaDeviceSynchronize(); // include sync here for accurate host-device profiling } + void cudaGaugeField::shift(const GaugeField &src, const array &dx) + { + for (int i = 0; i < this->nDim; i++) { + if (dx[i] != 0) break; + if (i == this->nDim - 1) return this->copy(src); + } + if (this == &src) errorQuda("Cannot copy in itself"); + + checkField(src); + + // TODO: check src extension (needs to be enough for shifting) + + if (typeid(src) == typeid(cudaGaugeField)) { + gaugeShift(*this, src, dx); + } else { + errorQuda("Not compatible type"); + } + } + void cudaGaugeField::loadCPUField(const cpuGaugeField &cpu) { copy(cpu); qudaDeviceSynchronize(); diff --git a/lib/gauge_shift.cu b/lib/gauge_shift.cu new file mode 100644 index 0000000000..4e3cf239fc --- /dev/null +++ b/lib/gauge_shift.cu @@ -0,0 +1,53 @@ +#include +#include +#include +#include + +namespace quda +{ + + template class ShiftGauge : public TunableKernel3D + { + GaugeField &out; + const GaugeField ∈ + const array &dx; + unsigned int minThreads() const { return in.VolumeCB(); } + + public: + ShiftGauge(GaugeField &out, const GaugeField &in, const array &dx) : + TunableKernel3D(in, 2, in.Geometry()), out(out), in(in), dx(dx) + { + strcat(aux, ",shift="); + for (int i = 0; i < in.Ndim(); i++) { strcat(aux, std::to_string(dx[i]).c_str()); } + strcat(aux, comm_dim_partitioned_string()); + apply(device::get_default_stream()); + } + + void apply(const qudaStream_t &stream) + { + TuneParam tp = tuneLaunch(*this, getTuning(), getVerbosity()); + launch(tp, stream, GaugeShiftArg(out, in, dx)); + } + + void preTune() { } + void postTune() { } + + long long flops() const { return in.Volume() * 4; } + long long bytes() const { return in.Bytes(); } + }; + + void gaugeShift(GaugeField &out, const GaugeField &in, const array &dx) + { + checkPrecision(in, out); + checkLocation(in, out); + checkReconstruct(in, out); + + if (out.Geometry() != in.Geometry()) { + errorQuda("Field geometries %d %d do not match", out.Geometry(), in.Geometry()); + } + + // gauge field must be passed as first argument so we peel off its reconstruct type + instantiate(out, in, dx); + } + +} // namespace quda