Skip to content

Commit

Permalink
Merge pull request #1520 from lattice/feature/sycl-merge
Browse files Browse the repository at this point in the history
clean up some casting in tests
  • Loading branch information
weinbe2 authored Nov 22, 2024
2 parents 9fa8615 + 3debb29 commit 1efcbeb
Show file tree
Hide file tree
Showing 7 changed files with 102 additions and 113 deletions.
11 changes: 5 additions & 6 deletions tests/contract_ft_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -171,19 +171,18 @@ inline int launch_contract_test(const QudaContractType cType, const std::array<i

fill_buffers<Float, 2>(buffs, X, dof);

for (int s = 0; s < nprops; ++s, off += spinor_field_floats * sizeof(Float)) {
spinorX[s] = (void *)((uintptr_t)buffs[0].data() + off);
spinorY[s] = (void *)((uintptr_t)buffs[1].data() + off);
for (int s = 0; s < nprops; ++s, off += spinor_field_floats) {
spinorX[s] = static_cast<void *>(buffs[0].data() + off);
spinorY[s] = static_cast<void *>(buffs[1].data() + off);
}
// Perform GPU contraction:
void *d_result_ = static_cast<void *>(d_result.data());

contractFTQuda(spinorX.data(), spinorY.data(), &d_result_, cType, (void *)(&cs_param), src_colors, X.data(),
source_position.data(), n_mom, mom.data(), fft_type.data());
// Check results:
int faults
= contractionFT_reference<Float>((Float **)spinorX.data(), (Float **)spinorY.data(), d_result.data(), cType,
src_colors, X.data(), source_position.data(), n_mom, mom.data(), fft_type.data());
int faults = contractionFT_reference<Float>(spinorX.data(), spinorY.data(), d_result.data(), cType, src_colors,
X.data(), source_position.data(), n_mom, mom.data(), fft_type.data());

return faults;
}
Expand Down
2 changes: 1 addition & 1 deletion tests/gauge_alg_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ struct GaugeAlgTest : public ::testing::TestWithParam<test_t> {
#ifndef QUDA_BUILD_NATIVE_FFT // skip FFT tests if FFT not available
const ::testing::TestInfo *const test_info = ::testing::UnitTest::GetInstance()->current_test_info();
const char *name = test_info->name();
if (strcmp(name, "Landau_FFT") == 0 || strcmp(name, "Coulomb_FFT") == 0) {
if (strncmp(name, "Landau_FFT", 10) == 0 || strncmp(name, "Coulomb_FFT", 11) == 0) {
execute = false;
GTEST_SKIP();
}
Expand Down
8 changes: 4 additions & 4 deletions tests/host_reference/contract_ft_reference.h
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ template <typename Float> inline void FourierPhase(Float z[2], const Float theta
};

template <typename Float>
void contractFTHost(Float **h_prop_array_flavor_1, Float **h_prop_array_flavor_2, double *h_result,
void contractFTHost(void **h_prop_array_flavor_1, void **h_prop_array_flavor_2, double *h_result,
const QudaContractType cType, const int src_colors, const int *X, const int *const source_position,
const int n_mom, const int *const mom_modes, const QudaFFTSymmType *const fft_type)
{
Expand Down Expand Up @@ -126,8 +126,8 @@ void contractFTHost(Float **h_prop_array_flavor_1, Float **h_prop_array_flavor_2
for (int c1 = 0; c1 < src_colors; c1++) {
// color contraction
size_t off = nSpin * 3 * 2 * (Vh * parity + cb_idx);
contractColors<Float>(h_prop_array_flavor_1[s1 * src_colors + c1] + off,
h_prop_array_flavor_2[s2 * src_colors + c1] + off, nSpin, M.data());
contractColors<Float>(static_cast<Float *>(h_prop_array_flavor_1[s1 * src_colors + c1]) + off,
static_cast<Float *>(h_prop_array_flavor_2[s2 * src_colors + c1]) + off, nSpin, M.data());

// apply gamma matrices here

Expand Down Expand Up @@ -158,7 +158,7 @@ void contractFTHost(Float **h_prop_array_flavor_1, Float **h_prop_array_flavor_2
};

template <typename Float>
int contractionFT_reference(Float **spinorX, Float **spinorY, const double *const d_result, const QudaContractType cType,
int contractionFT_reference(void **spinorX, void **spinorY, const double *const d_result, const QudaContractType cType,
const int src_colors, const int *X, const int *const source_position, const int n_mom,
const int *const mom_modes, const QudaFFTSymmType *const fft_type)
{
Expand Down
43 changes: 21 additions & 22 deletions tests/host_reference/gauge_force_reference.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -426,7 +426,7 @@ static void update_gauge(su3_matrix *gauge, int dir, su3_matrix **sitelink, su3_
/* This function only computes one direction @dir
*
*/
void gauge_force_reference_dir(void *refMom, int dir, double eb3, void *const *sitelink, void *const *sitelink_ex,
void gauge_force_reference_dir(void *refMom, int dir, double eb3, quda::GaugeField &u, quda::GaugeField &u_ex,
QudaPrecision prec, int **path_dir, int *length, void *loop_coeff, int num_paths,
const lattice_t &lat, bool compute_force)
{
Expand All @@ -437,26 +437,30 @@ void gauge_force_reference_dir(void *refMom, int dir, double eb3, void *const *s
for (int i = 0; i < num_paths; i++) {
if (prec == QUDA_DOUBLE_PRECISION) {
double *my_loop_coeff = (double *)loop_coeff;
compute_path_product((dsu3_matrix *)staple, (dsu3_matrix **)sitelink_ex, path_dir[i], length[i], my_loop_coeff[i],
dir, lat);
compute_path_product((dsu3_matrix *)staple, u_ex.data_array<dsu3_matrix *>().data, path_dir[i], length[i],
my_loop_coeff[i], dir, lat);
} else {
float *my_loop_coeff = (float *)loop_coeff;
compute_path_product((fsu3_matrix *)staple, (fsu3_matrix **)sitelink_ex, path_dir[i], length[i], my_loop_coeff[i],
dir, lat);
compute_path_product((fsu3_matrix *)staple, u_ex.data_array<fsu3_matrix *>().data, path_dir[i], length[i],
my_loop_coeff[i], dir, lat);
}
}

if (compute_force) {
if (prec == QUDA_DOUBLE_PRECISION) {
update_mom((danti_hermitmat *)refMom, dir, (dsu3_matrix **)sitelink, (dsu3_matrix *)staple, (double)eb3, lat);
update_mom((danti_hermitmat *)refMom, dir, u.data_array<dsu3_matrix *>().data, (dsu3_matrix *)staple, (double)eb3,
lat);
} else {
update_mom((fanti_hermitmat *)refMom, dir, (fsu3_matrix **)sitelink, (fsu3_matrix *)staple, (float)eb3, lat);
update_mom((fanti_hermitmat *)refMom, dir, u.data_array<fsu3_matrix *>().data, (fsu3_matrix *)staple, (float)eb3,
lat);
}
} else {
if (prec == QUDA_DOUBLE_PRECISION) {
update_gauge((dsu3_matrix *)refMom, dir, (dsu3_matrix **)sitelink, (dsu3_matrix *)staple, (double)eb3, lat);
update_gauge((dsu3_matrix *)refMom, dir, u.data_array<dsu3_matrix *>().data, (dsu3_matrix *)staple, (double)eb3,
lat);
} else {
update_gauge((fsu3_matrix *)refMom, dir, (fsu3_matrix **)sitelink, (fsu3_matrix *)staple, (float)eb3, lat);
update_gauge((fsu3_matrix *)refMom, dir, u.data_array<fsu3_matrix *>().data, (fsu3_matrix *)staple, (float)eb3,
lat);
}
}
host_free(staple);
Expand All @@ -465,8 +469,6 @@ void gauge_force_reference_dir(void *refMom, int dir, double eb3, void *const *s
void gauge_force_reference(void *refMom, double eb3, quda::GaugeField &u, int ***path_dir, int *length,
void *loop_coeff, int num_paths, bool compute_force)
{
void *sitelink[] = {u.data(0), u.data(1), u.data(2), u.data(3)};

// created extended field
quda::lat_dim_t R;
for (int d = 0; d < 4; d++) R[d] = 2 * quda::comm_dim_partitioned(d);
Expand All @@ -475,13 +477,12 @@ void gauge_force_reference(void *refMom, double eb3, quda::GaugeField &u, int **
param.gauge_order = QUDA_QDP_GAUGE_ORDER;
param.t_boundary = QUDA_PERIODIC_T;

auto qdp_ex = quda::createExtendedGauge((void **)sitelink, param, R);
auto qdp_ex = quda::createExtendedGauge(u.data_array().data, param, R);
lattice_t lat(*qdp_ex);

void *sitelink_ex[] = {qdp_ex->data(0), qdp_ex->data(1), qdp_ex->data(2), qdp_ex->data(3)};
for (int dir = 0; dir < 4; dir++) {
gauge_force_reference_dir(refMom, dir, eb3, sitelink, sitelink_ex, u.Precision(), path_dir[dir], length, loop_coeff,
num_paths, lat, compute_force);
gauge_force_reference_dir(refMom, dir, eb3, u, *qdp_ex, u.Precision(), path_dir[dir], length, loop_coeff, num_paths,
lat, compute_force);
}

delete qdp_ex;
Expand All @@ -490,29 +491,27 @@ void gauge_force_reference(void *refMom, double eb3, quda::GaugeField &u, int **
void gauge_loop_trace_reference(quda::GaugeField &u, std::vector<quda::Complex> &loop_traces, double factor,
int **input_path, int *length, double *path_coeff, int num_paths)
{
void *sitelink[] = {u.data(0), u.data(1), u.data(2), u.data(3)};

// create extended field
quda::lat_dim_t R;
for (int d = 0; d < 4; d++) R[d] = 2 * quda::comm_dim_partitioned(d);
QudaGaugeParam param = newQudaGaugeParam();
setGaugeParam(param);
param.gauge_order = QUDA_QDP_GAUGE_ORDER;
param.t_boundary = QUDA_PERIODIC_T;

auto qdp_ex = quda::createExtendedGauge((void **)sitelink, param, R);
auto qdp_ex = quda::createExtendedGauge(u.data_array().data, param, R);
lattice_t lat(*qdp_ex);
void *sitelink_ex[] = {qdp_ex->data(0), qdp_ex->data(1), qdp_ex->data(2), qdp_ex->data(3)};

std::vector<double> loop_tr_dbl(2 * num_paths);

for (int i = 0; i < num_paths; i++) {
if (u.Precision() == QUDA_DOUBLE_PRECISION) {
dcomplex tr = compute_loop_trace((dsu3_matrix **)sitelink_ex, input_path[i], length[i], path_coeff[i], lat);
dcomplex tr
= compute_loop_trace(qdp_ex->data_array<dsu3_matrix *>().data, input_path[i], length[i], path_coeff[i], lat);
loop_tr_dbl[2 * i] = factor * tr.real;
loop_tr_dbl[2 * i + 1] = factor * tr.imag;
} else {
dcomplex tr = compute_loop_trace((fsu3_matrix **)sitelink_ex, input_path[i], length[i], path_coeff[i], lat);
dcomplex tr
= compute_loop_trace(qdp_ex->data_array<fsu3_matrix *>().data, input_path[i], length[i], path_coeff[i], lat);
loop_tr_dbl[2 * i] = factor * tr.real;
loop_tr_dbl[2 * i + 1] = factor * tr.imag;
}
Expand Down
Loading

0 comments on commit 1efcbeb

Please sign in to comment.