Skip to content

Commit

Permalink
Apply the Arg::check_bounds to the other kernels.
Browse files (browse the repository at this point in the history)
  • Loading branch information
hummingtree committed Jan 3, 2025
1 parent df97e1a commit b0b4355
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 30 deletions.
33 changes: 21 additions & 12 deletions include/targets/cuda/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,16 @@ namespace quda

     auto i = threadIdx.x + blockIdx.x * blockDim.x;

-    while (i < arg.threads.x) {
+    if constexpr (Arg::check_bounds) {
+      while (i < arg.threads.x) {
+        f(i);
+        if (grid_stride)
+          i += gridDim.x * blockDim.x;
+        else
+          break;
+      }
+    } else {
       f(i);
-      if (grid_stride)
-        i += gridDim.x * blockDim.x;
-      else
-        break;
     }
   }

Expand Down Expand Up @@ -161,15 +165,20 @@ namespace quda
     auto i = threadIdx.x + blockIdx.x * blockDim.x;
     auto j = threadIdx.y + blockIdx.y * blockDim.y;
     auto k = threadIdx.z + blockIdx.z * blockDim.z;
-    if (j >= arg.threads.y) return;
-    if (k >= arg.threads.z) return;

-    while (i < arg.threads.x) {
+    if constexpr (Arg::check_bounds) {
+      if (j >= arg.threads.y) return;
+      if (k >= arg.threads.z) return;
+
+      while (i < arg.threads.x) {
+        f(i, j, k);
+        if (grid_stride)
+          i += gridDim.x * blockDim.x;
+        else
+          break;
+      }
+    } else {
       f(i, j, k);
-      if (grid_stride)
-        i += gridDim.x * blockDim.x;
-      else
-        break;
     }
   }

Expand Down
50 changes: 32 additions & 18 deletions include/targets/hip/kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,16 @@ namespace quda

     auto i = threadIdx.x + blockIdx.x * blockDim.x;

-    while (i < arg.threads.x) {
+    if constexpr (Arg::check_bounds) {
+      while (i < arg.threads.x) {
+        f(i);
+        if (grid_stride)
+          i += gridDim.x * blockDim.x;
+        else
+          break;
+      }
+    } else {
       f(i);
-      if (grid_stride)
-        i += gridDim.x * blockDim.x;
-      else
-        break;
     }
   }

Expand All @@ -43,14 +47,19 @@ namespace quda

     auto i = threadIdx.x + blockIdx.x * blockDim.x;
     auto j = threadIdx.y + blockIdx.y * blockDim.y;
-    if (j >= arg.threads.y) return;

-    while (i < arg.threads.x) {
+    if constexpr (Arg::check_bounds) {
+      if (j >= arg.threads.y) return;
+
+      while (i < arg.threads.x) {
+        f(i, j);
+        if (grid_stride)
+          i += gridDim.x * blockDim.x;
+        else
+          break;
+      }
+    } else {
       f(i, j);
-      if (grid_stride)
-        i += gridDim.x * blockDim.x;
-      else
-        break;
     }
   }

Expand All @@ -76,15 +85,20 @@ namespace quda
     auto i = threadIdx.x + blockIdx.x * blockDim.x;
     auto j = threadIdx.y + blockIdx.y * blockDim.y;
     auto k = threadIdx.z + blockIdx.z * blockDim.z;
-    if (j >= arg.threads.y) return;
-    if (k >= arg.threads.z) return;

-    while (i < arg.threads.x) {
+    if constexpr (Arg::check_bounds) {
+      if (j >= arg.threads.y) return;
+      if (k >= arg.threads.z) return;
+
+      while (i < arg.threads.x) {
+        f(i, j, k);
+        if (grid_stride)
+          i += gridDim.x * blockDim.x;
+        else
+          break;
+      }
+    } else {
       f(i, j, k);
-      if (grid_stride)
-        i += gridDim.x * blockDim.x;
-      else
-        break;
     }
   }

Expand Down

0 comments on commit b0b4355

Please sign in to comment.