diff --git a/src/Operators/finitedifference.jl b/src/Operators/finitedifference.jl index 1b96727eda..5427b4fcb7 100644 --- a/src/Operators/finitedifference.jl +++ b/src/Operators/finitedifference.jl @@ -3437,19 +3437,21 @@ function Base.copyto!( Nq = 1 Nh = 1 end - bounds = window_bounds(space, bc) + (li, lw, rw, ri) = bounds = window_bounds(space, bc) + nnodes = ri - li + 1 args = ( strip_space(out, space), strip_space(bc, space), axes(out), bounds, + nnodes, Nq, Nh, ) kernel = @cuda launch = false copyto_stencil_kernel!(args...) kernel_config = CUDA.launch_configuration(kernel.fun) max_threads = kernel_config.threads - nitems = Nq * Nq * Nh + nitems = nnodes * Nq * Nq * Nh nthreads = min(max_threads, nitems) nblocks = cld(nitems, nthreads) # executed @@ -3460,14 +3462,33 @@ function Base.copyto!( return out end -function copyto_stencil_kernel!(out, bc, space, bds, Nq, Nh) +function copyto_stencil_kernel!(out, bc, space, bds, nnodes, Nq, Nh) gid = threadIdx().x + (blockIdx().x - 1) * blockDim().x - if gid ≤ Nq * Nq * Nh - h = cld(gid, Nq * Nq) - j = cld(gid - (h - 1) * Nq * Nq, Nq) - i = gid - (h - 1) * Nq * Nq - (j - 1) * Nq + if gid ≤ nnodes * Nq * Nq * Nh + (li, lw, rw, ri) = bds + h = cld(gid, nnodes * Nq * Nq) + j = cld(gid - (h - 1) * nnodes * Nq * Nq, nnodes * Nq) + i = cld( + gid - (h - 1) * nnodes * Nq * Nq - (j - 1) * nnodes * Nq, + nnodes, + ) + ndidx = + gid - (h - 1) * nnodes * Nq * Nq - (j - 1) * nnodes * Nq - + (i - 1) * nnodes hidx = (i, j, h) - apply_stencil!(space, out, bc, hidx, bds) + #apply_stencil!(space, out, bc, hidx, bds) + fun = + !Topologies.isperiodic(Spaces.vertical_topology(space)) ? + ( + ndidx ≤ lw - 1 ? + LeftBoundaryWindow{Spaces.left_boundary_name(space)}() : + ( + ndidx ≥ rw + 1 ? + RightBoundaryWindow{Spaces.right_boundary_name(space)}() : + Interior() + ) + ) : Interior() + setidx!(space, out, ndidx, hidx, getidx(space, bc, fun, ndidx, hidx)) end return nothing end