Skip to content

Commit

Permalink
Add SVM Buffer type.
Browse files Browse the repository at this point in the history
  • Loading branch information
maleadt committed Sep 12, 2024
1 parent c544319 commit a4534b4
Show file tree
Hide file tree
Showing 4 changed files with 164 additions and 10 deletions.
1 change: 1 addition & 0 deletions lib/CL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ include("cmdqueue.jl")
include("event.jl")
include("memory.jl")
include("buffer.jl")
include("svm.jl")
include("program.jl")
include("kernel.jl")

Expand Down
18 changes: 8 additions & 10 deletions lib/buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,13 +119,9 @@ end
# Convenience form: copy `N` elements from `src` to `dst`, starting at element 1 of
# both buffers. Keyword arguments are forwarded unchanged to the offset-taking method.
function Base.unsafe_copyto!(dst::Buffer, src::Buffer, N; kwargs...)
    return unsafe_copyto!(dst, 1, src, 1, N; kwargs...)
end

# map a a buffer into the host address space and return a (pinned) array and an event
function unsafe_map!(b::Buffer{T}, dims::Dims, flags=:rw; offset::Integer=0,
# map a buffer into the host address space and return a (pinned) array and an event
function unsafe_map!(b::Buffer{T}, dims::Dims, flags=:rw; offset::Integer=1,
blocking::Bool=false, wait_for::Vector{Event}=Event[]) where {T}
if length(b) < prod(dims) + offset
throw(ArgumentError("Buffer length must be greater than or
equal to prod(dims) + offset"))
end
n_evts = length(wait_for)
evt_ids = isempty(wait_for) ? C_NULL : [pointer(evt) for evt in wait_for]
flags = if flags == :rw
Expand All @@ -140,8 +136,9 @@ function unsafe_map!(b::Buffer{T}, dims::Dims, flags=:rw; offset::Integer=0,
nbytes = prod(dims) * sizeof(T)
ret_evt = Ref{cl_event}()
status = Ref{Cint}()
byteoffset = (offset - 1) * sizeof(T)
mapped = clEnqueueMapBuffer(queue(), b, blocking,
flags, offset, nbytes,
flags, byteoffset, nbytes,
n_evts, evt_ids, ret_evt, status)
if status[] != CL_SUCCESS
throw(CLError(status[]))
Expand All @@ -155,7 +152,7 @@ function unsafe_unmap!(b::Buffer{T}, a::Array{T}; wait_for::Vector{Event}=Event[
n_evts = length(wait_for)
evt_ids = isempty(wait_for) ? C_NULL : [pointer(evt) for evt in wait_for]
ret_evt = Ref{cl_event}()
clEnqueueUnmapMemObject(queue(), b, a, n_evts, evt_ids, ret_evt)
clEnqueueUnmapMemObject(queue(), b, pointer(a), n_evts, evt_ids, ret_evt)
return Event(ret_evt[])
end

Expand All @@ -167,10 +164,11 @@ function unsafe_fill!(b::Buffer{T}, pattern::T, offset::Integer, N::Integer;
ret_evt = Ref{cl_event}()
nbytes = N * sizeof(T)
nbytes_pattern = sizeof(T)
byteoffset = (offset - 1) * sizeof(T)
@assert nbytes_pattern > 0
clEnqueueFillBuffer(queue(), b, [pattern],
nbytes_pattern, offset, nbytes,
nbytes_pattern, byteoffset, nbytes,
n_evts, evt_ids, ret_evt)
@return_event ret_evt[]
end
unsafe_fill!(b::Buffer, pattern, N::Integer) = unsafe_fill!(b, pattern, 0, N)
unsafe_fill!(b::Buffer, pattern, N::Integer) = unsafe_fill!(b, pattern, 1, N)
112 changes: 112 additions & 0 deletions lib/svm.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
"""
    SVMBuffer{T}(len::Integer, typ::Symbol=:rw; fine_grained=false, alignment=nothing)

A shared virtual memory (SVM) allocation of `len` elements of type `T`, obtained with
`clSVMAlloc` in the context that is active when the constructor runs.

# Arguments
- `len`: number of elements to allocate.
- `typ`: access type; `:rw`, `:r` or `:w`.

# Keywords
- `fine_grained`: when `true`, `CL_MEM_SVM_FINE_GRAIN_BUFFER` is added to the flags.
- `alignment`: alignment forwarded to `clSVMAlloc`; `nothing` forwards `0`.

# Throws
- `ArgumentError` if `typ` is not one of the supported access types.
- `ErrorException` if a non-empty allocation is requested and `clSVMAlloc` returns a
  null pointer.

The memory is released with `clSVMFree` by a finalizer attached to the object.
"""
mutable struct SVMBuffer{T}
    ptr::Ptr{T}
    len::Int

    function SVMBuffer{T}(len::Integer, typ::Symbol=:rw;
                          fine_grained=false, alignment=nothing) where {T}
        flags = if typ == :rw
            CL_MEM_READ_WRITE
        elseif typ == :r
            CL_MEM_READ_ONLY
        elseif typ == :w
            CL_MEM_WRITE_ONLY
        else
            throw(ArgumentError("Invalid access type"))
        end

        if fine_grained
            flags |= CL_MEM_SVM_FINE_GRAIN_BUFFER
        end

        # Remember the allocating context: the finalizer can run at an arbitrary later
        # point, when `context()` may no longer refer to the context that owns `ptr`.
        ctx = context()
        nbytes = len * sizeof(T)
        ptr = clSVMAlloc(ctx, flags, nbytes, something(alignment, 0))

        # clSVMAlloc signals failure by returning a null pointer. A zero-byte request
        # also yields null, so only treat null as an error for non-empty requests.
        if nbytes > 0 && ptr == C_NULL
            error("clSVMAlloc failed to allocate $nbytes bytes of shared virtual memory")
        end

        obj = new{T}(ptr, len)
        finalizer(obj) do x
            # TODO: asynchronous free using clEnqueueSVMFree?
            clSVMFree(ctx, x)
        end

        return obj
    end
end

# Convert the buffer's base address to a pointer of the requested element type `T`
# (the target type of the conversion, not necessarily the buffer's own element type).
function Base.unsafe_convert(::Type{Ptr{T}}, x::SVMBuffer) where {T}
    return convert(Ptr{T}, x.ptr)
end

# Address of the `i`-th element (1-based) of the buffer.
@inline function Base.pointer(x::SVMBuffer{T}, i::Integer=1) where T
    base = Base.unsafe_convert(Ptr{T}, x)
    return base + (i-1)*sizeof(T)
end

# Array-like queries: an SVM buffer is treated as a one-dimensional sequence of
# `len` elements of type `T`.
Base.ndims(b::SVMBuffer) = 1
# `eltype` is defined on the type, so `eltype(typeof(buf))` returns `T` as well;
# `eltype(buf)` keeps working through Base's fallback `eltype(x) = eltype(typeof(x))`.
Base.eltype(::Type{SVMBuffer{T}}) where {T} = T
Base.length(b::SVMBuffer{T}) where {T} = b.len
# Size of the allocation in bytes.
Base.sizeof(b::SVMBuffer{T}) where {T} = b.len * sizeof(T)


## memory operations

# these generally only make sense for coarse-grained SVM buffers;
# fine-grained buffers can just be used directly.

# Copies involving SVM buffers: host array -> SVM, SVM -> host array, and SVM -> SVM.
# The same pair of methods is generated for each (source, destination) combination.
for (srcty, dstty) in ((:Array, :SVMBuffer), (:SVMBuffer, :Array), (:SVMBuffer, :SVMBuffer))
    @eval begin
        # Enqueue a copy of `N` elements from `src` (starting at element `src_off`) to
        # `dst` (starting at element `dst_off`). With `blocking=true` the flag is passed
        # on to `clEnqueueSVMMemcpy`; `wait_for` lists events the copy must wait on.
        # The resulting event is returned via `@return_nanny_event` together with `dst`.
        function Base.unsafe_copyto!(dst::$dstty{T}, dst_off::Int, src::$srcty{T},
                                     src_off::Int, N::Int; blocking::Bool=false,
                                     wait_for::Vector{Event}=Event[]) where T
            bytecount = N * sizeof(T)
            nwait = length(wait_for)
            wait_ptrs = isempty(wait_for) ? C_NULL : [pointer(evt) for evt in wait_for]
            evt_ref = Ref{cl_event}()
            clEnqueueSVMMemcpy(queue(), blocking, pointer(dst, dst_off),
                               pointer(src, src_off), bytecount, nwait, wait_ptrs,
                               evt_ref)
            @return_nanny_event(evt_ref[], dst)
        end

        # Convenience form: copy `N` elements starting at element 1 of both arguments.
        function Base.unsafe_copyto!(dst::$dstty, src::$srcty, N; kwargs...)
            return unsafe_copyto!(dst, 1, src, 1, N; kwargs...)
        end
    end
end

# Map a region of an SVM buffer for host access.
#
# Arguments:
# - `b`: the SVM buffer to map.
# - `dims`: dimensions of the returned array; `prod(dims)` elements are mapped.
# - `flags`: requested access, one of `:rw` (default), `:r` or `:w`.
#
# Keywords:
# - `offset`: 1-based index of the first element to map.
# - `blocking`: forwarded to `clEnqueueSVMMap`.
# - `wait_for`: events the map operation must wait on.
#
# Returns a tuple of a non-owning `Array` wrapping the mapped region and the `Event`
# for the map operation. No bounds checking is performed (hence `unsafe_`).
#
# Throws `ArgumentError` when `flags` is not one of the supported values.
function unsafe_map!(b::SVMBuffer{T}, dims::Dims, flags=:rw; offset::Integer=1,
                     blocking::Bool=false, wait_for::Vector{Event}=Event[]) where {T}
    # Use a separate name for the bitmask so `flags` keeps its type, and so the
    # error message below reports what the caller actually passed.
    map_flags = if flags == :rw
        CL_MAP_READ | CL_MAP_WRITE
    elseif flags == :r
        CL_MAP_READ
    elseif flags == :w
        CL_MAP_WRITE
    else
        throw(ArgumentError("unsafe_map! can have flags of :r, :w, or :rw, got $(repr(flags))"))
    end

    n_evts = length(wait_for)
    evt_ids = isempty(wait_for) ? C_NULL : [pointer(evt) for evt in wait_for]
    nbytes = prod(dims) * sizeof(T)
    ptr = pointer(b, offset)
    ret_evt = Ref{cl_event}()
    clEnqueueSVMMap(queue(), blocking, map_flags, ptr, nbytes,
                    n_evts, evt_ids, ret_evt)

    # `own=false`: the memory stays owned by the SVM buffer, not by the array.
    return unsafe_wrap(Array, ptr, dims; own=false), Event(ret_evt[])
end

# Enqueue an unmap of the SVM region backing host array `a` and return the event for
# that operation. `b` is only used for dispatch; the region is identified by the
# address of `a`. `wait_for` lists events the unmap must wait on.
function unsafe_unmap!(b::SVMBuffer{T}, a::Array{T}; wait_for::Vector{Event}=Event[]) where {T}
    nwait = length(wait_for)
    wait_ptrs = isempty(wait_for) ? C_NULL : [pointer(evt) for evt in wait_for]
    evt_ref = Ref{cl_event}()
    clEnqueueSVMUnmap(queue(), pointer(a), nwait, wait_ptrs, evt_ref)
    return Event(evt_ref[])
end

# Enqueue a fill of `N` elements of `b`, starting at element `offset` (1-based), with
# copies of `pattern`. `wait_for` lists events the fill must wait on. The resulting
# event is returned through `@return_event`.
function unsafe_fill!(b::SVMBuffer{T}, pattern::T, offset::Integer, N::Integer;
                      wait_for::Vector{Event}=Event[]) where {T}
    # The pattern is a single element, so its size is the element size.
    nbytes_pattern = sizeof(T)
    @assert nbytes_pattern > 0

    fill_bytes = N * sizeof(T)
    nwait = length(wait_for)
    wait_ptrs = isempty(wait_for) ? C_NULL : [pointer(evt) for evt in wait_for]
    evt_ref = Ref{cl_event}()
    clEnqueueSVMMemFill(queue(), pointer(b, offset), [pattern],
                        nbytes_pattern, fill_bytes,
                        nwait, wait_ptrs, evt_ref)
    @return_event evt_ref[]
end

# Convenience form: fill the first `N` elements of `b`.
function unsafe_fill!(b::SVMBuffer, pattern, N::Integer)
    return unsafe_fill!(b, pattern, 1, N)
end
43 changes: 43 additions & 0 deletions test/buffer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -60,3 +60,46 @@
@test arr == [42,42,42]
end
end


@testset "SVMBuffer" begin
    # construction and array-like queries
    let svm = cl.SVMBuffer{Int}(1)
        @test ndims(svm) == 1
        @test eltype(svm) == Int
        @test length(svm) == 1
        @test sizeof(svm) == sizeof(Int)
    end

    # round-trip a value through the buffer with blocking copies
    let svm = cl.SVMBuffer{Int}(1)
        unsafe_copyto!(svm, [42], 1; blocking=true)

        host = [0]
        cl.unsafe_copyto!(host, svm, 1; blocking=true)
        @test host == [42]
    end

    # map the buffer, modify it through the mapped array, unmap, and read it back
    let svm = cl.SVMBuffer{Int}(1)
        unsafe_copyto!(svm, [42], 1; blocking=true)

        mapped, map_evt = cl.unsafe_map!(svm, (1,), :rw)
        wait(map_evt)
        @test mapped[] == 42
        mapped[] = 100
        wait(cl.unsafe_unmap!(svm, mapped))

        readback = [0]
        cl.unsafe_copyto!(readback, svm, 1; blocking=true)
        @test readback == [100]
    end

    # fill the whole buffer with a pattern
    let svm = cl.SVMBuffer{Int}(3)
        cl.unsafe_fill!(svm, 42, 3)
        host = Vector{Int}(undef, 3)
        unsafe_copyto!(host, svm, 3; blocking=true)
        @test host == [42,42,42]
    end
end

0 comments on commit a4534b4

Please sign in to comment.