diff --git a/src/Operators/finitedifference.jl b/src/Operators/finitedifference.jl index 5427b4fcb7..b89b50a322 100644 --- a/src/Operators/finitedifference.jl +++ b/src/Operators/finitedifference.jl @@ -3420,7 +3420,6 @@ function strip_space(bc::StencilBroadcasted{Style}, parent_space) where {Style} ) end - function Base.copyto!( out::Field, bc::Union{ @@ -3438,57 +3437,70 @@ function Base.copyto!( Nh = 1 end (li, lw, rw, ri) = bounds = window_bounds(space, bc) - nnodes = ri - li + 1 - args = ( - strip_space(out, space), - strip_space(bc, space), + ninteriornodes = rw - lw + 1 + + max_threads = 256 + nitemsbdy = Nq * Nq * Nh # # of independent boundary items + nitemsint = ninteriornodes * Nq * Nq * Nh # # of independent interior items + (nthreadsbdy, nblocksbdy) = Spaces._configure_threadblock(nitemsbdy) + (nthreadsint, nblocksint) = Spaces._configure_threadblock(nitemsint) + isnotperiodic = !Topologies.isperiodic(Spaces.vertical_topology(space)) + strip_space_out = strip_space(out, space) + strip_space_bc = strip_space(bc, space) + # left and right windows, if applicable + isnotperiodic && + @cuda threads = (nthreadsbdy,) blocks = (nblocksbdy,) copyto_stencil_bdy_kernel!( + strip_space_out, + strip_space_bc, + axes(out), + bounds, + Nq, + Nh, + ) + # interior nodes + @cuda threads = (nthreadsint,) blocks = (nblocksint,) copyto_stencil_interior_kernel!( + strip_space_out, + strip_space_bc, axes(out), bounds, - nnodes, + ninteriornodes, Nq, Nh, ) - kernel = @cuda launch = false copyto_stencil_kernel!(args...) - kernel_config = CUDA.launch_configuration(kernel.fun) - max_threads = kernel_config.threads - nitems = nnodes * Nq * Nq * Nh - nthreads = min(max_threads, nitems) - nblocks = cld(nitems, nthreads) - # executed - @cuda threads = (nthreads,) blocks = (nblocks,) copyto_stencil_kernel!( - args..., - ) - return out end -function copyto_stencil_kernel!(out, bc, space, bds, nnodes, Nq, Nh) +function copyto_stencil_bdy_kernel!(out, bc, space, bds, Nq, Nh) gid = threadIdx().x + (blockIdx().x - 1) * blockDim().x - if gid ≤ nnodes * Nq * Nq * Nh + if gid ≤ Nq * Nq * Nh (li, lw, rw, ri) = bds - h = cld(gid, nnodes * Nq * Nq) - j = cld(gid - (h - 1) * nnodes * Nq * Nq, nnodes * Nq) - i = cld( - gid - (h - 1) * nnodes * Nq * Nq - (j - 1) * nnodes * Nq, - nnodes, - ) - ndidx = - gid - (h - 1) * nnodes * Nq * Nq - (j - 1) * nnodes * Nq - - (i - 1) * nnodes + hidx = Spaces._get_idx((Nq, Nq, Nh), gid) + lbw = LeftBoundaryWindow{Spaces.left_boundary_name(space)}() + rbw = RightBoundaryWindow{Spaces.right_boundary_name(space)}() + @inbounds for idx in li:(lw - 1) + setidx!(space, out, idx, hidx, getidx(space, bc, lbw, idx, hidx)) + end + @inbounds for idx in (rw + 1):ri + setidx!(space, out, idx, hidx, getidx(space, bc, rbw, idx, hidx)) + end + end + return nothing +end + +function copyto_stencil_interior_kernel!(out, bc, space, bds, nnodes, Nq, Nh) + gid = threadIdx().x + (blockIdx().x - 1) * blockDim().x + if gid ≤ nnodes * Nq * Nq * Nh + (_, lw, rw, _) = bds + (ndidx, i, j, h) = Spaces._get_idx((nnodes, Nq, Nq, Nh), gid) hidx = (i, j, h) - #apply_stencil!(space, out, bc, hidx, bds) - fun = - !Topologies.isperiodic(Spaces.vertical_topology(space)) ? - ( - ndidx ≤ lw - 1 ? - LeftBoundaryWindow{Spaces.left_boundary_name(space)}() : - ( - ndidx ≥ rw + 1 ? - RightBoundaryWindow{Spaces.right_boundary_name(space)}() : - Interior() - ) - ) : Interior() - setidx!(space, out, ndidx, hidx, getidx(space, bc, fun, ndidx, hidx)) + ndidx += lw - 1 + setidx!( + space, + out, + ndidx, + hidx, + getidx(space, bc, Interior(), ndidx, hidx), + ) end return nothing end diff --git a/src/Spaces/dss_cuda.jl b/src/Spaces/dss_cuda.jl index f32a0d63ab..e20c974400 100644 --- a/src/Spaces/dss_cuda.jl +++ b/src/Spaces/dss_cuda.jl @@ -1,12 +1,16 @@ _max_threads_cuda() = 256 -function _configure_threadblock(nitems) - nthreads = min(_max_threads_cuda(), nitems) + +function _configure_threadblock(max_threads, nitems) + nthreads = min(max_threads, nitems) nblocks = cld(nitems, nthreads) return (nthreads, nblocks) end +_configure_threadblock(nitems) = + _configure_threadblock(_max_threads_cuda(), nitems) + function dss_load_perimeter_data!( ::ClimaComms.CUDADevice, dss_buffer::DSSBuffer,