Skip to content

Commit

Permalink
Switch CLArray to SVM. (#230)
Browse files Browse the repository at this point in the history
This makes it possible to derive device pointers from CLArrays,
a necessity for being able to pass arbitrary Julia objects to kernels.
  • Loading branch information
maleadt authored Sep 12, 2024
1 parent 0909a30 commit 1eb4e59
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 25 deletions.
20 changes: 10 additions & 10 deletions src/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,15 @@ export CLArray, CLMatrix, CLVector, buffer

mutable struct CLArray{T, N} <: AbstractArray{T, N}
ctx::cl.Context
buffer::cl.Buffer{T}
buffer::cl.SVMBuffer{T} # XXX: support regular buffers too?
size::NTuple{N, Int}

function CLArray{T,N}(::UndefInitializer, dims::Dims{N};
host=:rw, device=:rw) where {T,N}
buf = cl.Buffer{T}(prod(dims); host, device)
function CLArray{T,N}(::UndefInitializer, dims::Dims{N}; access=:rw) where {T,N}
buf = cl.SVMBuffer{T}(prod(dims), access)
new(cl.context(), buf, dims)
end

function CLArray{T,N}(buf::cl.Buffer, dims::Dims) where {T,N}
function CLArray{T,N}(buf::cl.SVMBuffer, dims::Dims) where {T,N}
new(cl.context(), buf, dims)
end
end
Expand Down Expand Up @@ -59,9 +58,10 @@ end

## array interface

context(A::CLArray) = A.ctx
buffer(A::CLArray) = A.buffer
Base.pointer(A::CLArray) = A.buffer.id
context(A::CLArray) = cl.context(A.buffer)

Base.pointer(A::CLArray, i::Integer=1) = pointer(buffer(A), i)
Base.eltype(A::CLArray{T, N}) where {T, N} = T
Base.size(A::CLArray) = A.size
Base.size(A::CLArray, dim::Integer) = A.size[dim]
Expand All @@ -78,9 +78,9 @@ end
## conversions

function CLArray(hostarray::AbstractArray{T,N}; kwargs...) where {T, N}
buf = cl.Buffer(hostarray; kwargs...)
sz = size(hostarray)
CLArray{T,N}(buf, sz)
arr = CLArray{T,N}(undef, size(hostarray); kwargs...)
copyto!(arr, hostarray)
return arr
end

function Base.Array(A::CLArray{T,N}) where {T, N}
Expand Down
11 changes: 4 additions & 7 deletions test/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,20 @@ using LinearAlgebra
@testset "CLArray" begin
@testset "constructors" begin
@test CLArray{Float32,1}(undef, 1) isa CLArray{Float32,1}
@test CLArray{Float32,1}(undef, 1; device=:r) isa CLArray{Float32,1}
@test CLArray{Float32,1}(undef, 1; host=:r) isa CLArray{Float32,1}
@test CLArray{Float32,1}(undef, 1; access=:r) isa CLArray{Float32,1}

@test CLArray{Float32}(undef, 1, 2) isa CLArray{Float32,2}
@test CLArray{Float32}(undef, 1, 2; device=:r) isa CLArray{Float32,2}
@test CLArray{Float32}(undef, 1, 2; host=:r) isa CLArray{Float32,2}
@test CLArray{Float32}(undef, 1, 2; access=:r) isa CLArray{Float32,2}

@test CLArray{Float32}(undef, (1, 2)) isa CLArray{Float32,2}
@test CLArray{Float32}(undef, (1, 2); device=:r) isa CLArray{Float32,2}
@test CLArray{Float32}(undef, (1, 2); host=:r) isa CLArray{Float32,2}
@test CLArray{Float32}(undef, (1, 2); access=:r) isa CLArray{Float32,2}

hostarray = rand(Float32, 128*64)
A = CLArray(hostarray)
@test A isa CLArray{Float32,1}
@test Array(A) == hostarray

B = CLArray(hostarray; device=:r, host=:rw)
B = CLArray(hostarray; access=:r)
@test B isa CLArray{Float32,1}
@test Array(B) == hostarray

Expand Down
10 changes: 5 additions & 5 deletions test/behaviour.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ end
__kernel void part3(__global const float *a,
__global const float *b,
__global float *c,
__constant struct Params* test)
__global struct Params* test)
{
int gid = get_global_id(0);
c[gid] = test->A * a[gid] + test->B * b[gid] + test->C;
Expand Down Expand Up @@ -205,11 +205,11 @@ end
P = [Params(0.5, 10.0, [0.0, 0.0], 3)]

#TODO: constructor for single immutable types.., check if passed parameter isbits
P_arr = CLArray(P; device=:r)
P_arr = CLArray(P; access=:r)

X_arr = CLArray(X; device=:r)
Y_arr = CLArray(Y; device=:r)
R_arr = CLArray{Float32}(undef, 10; device=:w)
X_arr = CLArray(X; access=:r)
Y_arr = CLArray(Y; access=:r)
R_arr = CLArray{Float32}(undef, 10; access=:w)

global_size = size(X)
cl.clcall(part3, Tuple{Ptr{Float32}, Ptr{Float32}, Ptr{Float32}, Ptr{Params}},
Expand Down
6 changes: 3 additions & 3 deletions test/kernel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@

h_ones = ones(Float32, count)

A = CLArray(h_ones; device=:r)
B = CLArray(h_ones; device=:r)
C = CLArray{Float32}(undef, count; device=:w)
A = CLArray(h_ones; access=:r)
B = CLArray(h_ones; access=:r)
C = CLArray{Float32}(undef, count; access=:w)

# we use julia's index by one convention
@test cl.set_arg!(k, 1, buffer(A)) != nothing
Expand Down

0 comments on commit 1eb4e59

Please sign in to comment.