diff --git a/Project.toml b/Project.toml index 49f1217..ef3b247 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "oneAPI" uuid = "8f75cd03-7ff8-4ecb-9b8f-daf728133b1b" authors = ["Tim Besard "] -version = "1.6.1" +version = "2.0.0" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" @@ -17,6 +17,7 @@ NEO_jll = "700fe977-ac61-5f37-bbc8-c6c4b2b6a9fd" Preferences = "21216c6a-2e73-6563-6e65-726566657250" Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +SPIRVIntrinsics = "71d1d633-e7e8-4a92-83a1-de8814b09ba8" SPIRV_LLVM_Translator_unified_jll = "85f0d8ed-5b39-5caa-b1ae-7472de402361" SPIRV_Tools_jll = "6ac6d60f-d740-5983-97d7-a4482c0689f4" SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" @@ -36,6 +37,7 @@ KernelAbstractions = "0.9.1" LLVM = "6, 7, 8, 9" NEO_jll = "=24.26.30049" Preferences = "1" +SPIRVIntrinsics = "0.2" SPIRV_LLVM_Translator_unified_jll = "0.4" SpecialFunctions = "1.3, 2" StaticArrays = "1" diff --git a/src/device/array.jl b/src/device/array.jl index 8ad97ca..753d453 100644 --- a/src/device/array.jl +++ b/src/device/array.jl @@ -1,6 +1,6 @@ # Contiguous on-device arrays -export oneDeviceArray, oneDeviceVector, oneDeviceMatrix +export oneDeviceArray, oneDeviceVector, oneDeviceMatrix, oneLocalArray ## construction @@ -240,3 +240,16 @@ function Base.reinterpret(::Type{T}, a::oneDeviceArray{S,N,A}) where {T,S,N,A} osize = tuple(size1, Base.tail(isize)...) return oneDeviceArray{T,N,A}(osize, reinterpret(LLVMPtr{T,A}, a.ptr), a.maxsize) end + + +## local memory + +export oneLocalArray + +@inline function oneLocalArray(::Type{T}, dims) where {T} + len = prod(dims) + # NOTE: this relies on const-prop to forward the literal length to the generator. + # maybe we should include the size in the type, like StaticArrays does? + ptr = emit_localmemory(T, Val(len)) + oneDeviceArray(dims, ptr) +end diff --git a/src/device/opencl/atomic.jl b/src/device/opencl/atomic.jl deleted file mode 100644 index 97b316e..0000000 --- a/src/device/opencl/atomic.jl +++ /dev/null @@ -1,264 +0,0 @@ -# Atomic Functions - -# TODO: support for 64-bit atomics via atom_cmpxchg (from cl_khr_int64_base_atomics) - -# "atomic operations on 32-bit signed, unsigned integers and single precision -# floating-point to locations in __global or __local memory" - -const atomic_integer_types = [UInt32, Int32] -# TODO: 64-bit atomics with ZE_DEVICE_MODULE_FLAG_INT64_ATOMICS -# TODO: additional floating-point atomics with ZE_extension_float_atomics -const atomic_memory_types = [AS.Local, AS.Global] - - -# generically typed - -for gentype in atomic_integer_types, as in atomic_memory_types -@eval begin - -@device_function atomic_add!(p::LLVMPtr{$gentype,$as}, val::$gentype) = - @builtin_ccall("atomic_add", $gentype, - (LLVMPtr{$gentype,$as}, $gentype), p, val) - -@device_function atomic_sub!(p::LLVMPtr{$gentype,$as}, val::$gentype) = - @builtin_ccall("atomic_sub", $gentype, - (LLVMPtr{$gentype,$as}, $gentype), p, val) - -@device_function atomic_inc!(p::LLVMPtr{$gentype,$as}) = - @builtin_ccall("atomic_inc", $gentype, (LLVMPtr{$gentype,$as},), p) - -@device_function atomic_dec!(p::LLVMPtr{$gentype,$as}) = - @builtin_ccall("atomic_dec", $gentype, (LLVMPtr{$gentype,$as},), p) - -@device_function atomic_min!(p::LLVMPtr{$gentype,$as}, val::$gentype) = - @builtin_ccall("atomic_min", $gentype, - (LLVMPtr{$gentype,$as}, $gentype), p, val) - -@device_function atomic_max!(p::LLVMPtr{$gentype,$as}, val::$gentype) = - @builtin_ccall("atomic_max", $gentype, - (LLVMPtr{$gentype,$as}, $gentype), p, val) - -@device_function atomic_and!(p::LLVMPtr{$gentype,$as}, val::$gentype) = - @builtin_ccall("atomic_and", $gentype, - (LLVMPtr{$gentype,$as}, $gentype), p, val) - -@device_function atomic_or!(p::LLVMPtr{$gentype,$as}, val::$gentype) = - @builtin_ccall("atomic_or", $gentype, - (LLVMPtr{$gentype,$as}, $gentype), p, val) - -@device_function atomic_xor!(p::LLVMPtr{$gentype,$as}, val::$gentype) = - @builtin_ccall("atomic_xor", $gentype, - (LLVMPtr{$gentype,$as}, $gentype), p, val) - -@device_function atomic_xchg!(p::LLVMPtr{$gentype,$as}, val::$gentype) = - @builtin_ccall("atomic_xchg", $gentype, - (LLVMPtr{$gentype,$as}, $gentype), p, val) - -@device_function atomic_cmpxchg!(p::LLVMPtr{$gentype,$as}, cmp::$gentype, val::$gentype) = - @builtin_ccall("atomic_cmpxchg", $gentype, - (LLVMPtr{$gentype,$as}, $gentype, $gentype), p, cmp, val) - -end -end - - -# specifically typed - -for as in atomic_memory_types -@eval begin - -@device_function atomic_xchg!(p::LLVMPtr{Float32,$as}, val::Float32) = - @builtin_ccall("atomic_xchg", Float32, (LLVMPtr{Float32,$as}, Float32,), p, val) - -# XXX: why is only xchg supported on floats? isn't it safe for cmpxchg too, -# which should only perform bitwise comparisons? -@device_function atomic_cmpxchg!(p::LLVMPtr{Float32,$as}, cmp::Float32, val::Float32) = - reinterpret(Float32, atomic_cmpxchg!(reinterpret(LLVMPtr{UInt32,$as}, p), - reinterpret(UInt32, cmp), - reinterpret(UInt32, val))) - -end -end - - - -# documentation - -""" -Read the 32-bit value (referred to as `old`) stored at location pointed by `p`. -Compute `old + val` and store result at location pointed by `p`. The function -returns `old`. -""" -atomic_add! - -""" -Read the 32-bit value (referred to as `old`) stored at location pointed by `p`. -Compute `old - val` and store result at location pointed by `p`. The function -returns `old`. -""" -atomic_sub! - -""" -Swaps the old value stored at location `p` with new value given by `val`. -Returns old value. -""" -atomic_xchg! - -""" -Read the 32-bit value (referred to as `old`) stored at location pointed by `p`. -Compute (`old` + 1) and store result at location pointed by `p`. The function -returns `old`. -""" -atomic_inc! - -""" -Read the 32-bit value (referred to as `old`) stored at location pointed by `p`. -Compute (`old` - 1) and store result at location pointed by `p`. The function -returns `old`. -""" -atomic_dec! - -""" -Read the 32-bit value (referred to as `old`) stored at location pointed by `p`. -Compute `(old == cmp) ? val : old` and store result at location pointed by `p`. -The function returns `old`. -""" -atomic_cmpxchg! - -""" -Read the 32-bit value (referred to as `old`) stored at location pointed by `p`. -Compute `min(old, val)` and store minimum value at location pointed by `p`. The -function returns `old`. -""" -atomic_min! - -""" -Read the 32-bit value (referred to as `old`) stored at location pointed by `p`. -Compute `max(old, val)` and store maximum value at location pointed by `p`. The -function returns `old`. -""" -atomic_max - -""" -Read the 32-bit value (referred to as `old`) stored at location pointed by `p`. -Compute `old & val` and store result at location pointed by `p`. The function -returns `old`. -""" -atomic_and! - -""" -Read the 32-bit value (referred to as `old`) stored at location pointed by `p`. -Compute `old | val` and store result at location pointed by `p`. The function -returns `old`. -""" -atomic_or! - -""" -Read the 32-bit value (referred to as `old`) stored at location pointed by `p`. -Compute `old ^ val` and store result at location pointed by `p`. The function -returns `old`. -""" -atomic_xor! - - - -# -# High-level interface -# - -# prototype of a high-level interface for performing atomic operations on arrays -# -# this design could be generalized by having atomic {field,array}{set,ref} accessors, as -# well as acquire/release operations to implement the fallback functionality where any -# operation can be applied atomically. - -const inplace_ops = Dict( - :(+=) => :(+), - :(-=) => :(-), - :(*=) => :(*), - :(/=) => :(/), - :(÷=) => :(÷), - :(&=) => :(&), - :(|=) => :(|), - :(⊻=) => :(⊻), -) - -struct AtomicError <: Exception - msg::AbstractString -end - -Base.showerror(io::IO, err::AtomicError) = - print(io, "AtomicError: ", err.msg) - -""" - @atomic a[I] = op(a[I], val) - @atomic a[I] ...= val - -Atomically perform a sequence of operations that loads an array element `a[I]`, performs the -operation `op` on that value and a second value `val`, and writes the result back to the -array. This sequence can be written out as a regular assignment, in which case the same -array element should be used in the left and right hand side of the assignment, or as an -in-place application of a known operator. In both cases, the array reference should be pure -and not induce any side-effects. - -!!! warn - This interface is experimental, and might change without warning. Use the lower-level - `atomic_...!` functions for a stable API. -""" -macro atomic(ex) - # decode assignment and call - if ex.head == :(=) - ref = ex.args[1] - rhs = ex.args[2] - Meta.isexpr(rhs, :call) || throw(AtomicError("right-hand side of an @atomic assignment should be a call")) - op = rhs.args[1] - if rhs.args[2] != ref - throw(AtomicError("right-hand side of a non-inplace @atomic assignment should reference the left-hand side")) - end - val = rhs.args[3] - elseif haskey(inplace_ops, ex.head) - op = inplace_ops[ex.head] - ref = ex.args[1] - val = ex.args[2] - else - throw(AtomicError("unknown @atomic expression")) - end - - # decode array expression - Meta.isexpr(ref, :ref) || throw(AtomicError("@atomic should be applied to an array reference expression")) - array = ref.args[1] - indices = Expr(:tuple, ref.args[2:end]...) - - esc(quote - $atomic_arrayset($array, $indices, $op, $val) - end) -end - -# FIXME: make this respect the indexing style -@inline atomic_arrayset(A::AbstractArray{T}, Is::Tuple, op::Function, val) where {T} = - atomic_arrayset(A, Base._to_linear_index(A, Is...), op, convert(T, val)) - -# native atomics -for (op,impl) in [(+) => atomic_add!, - (-) => atomic_sub!, - (&) => atomic_and!, - (|) => atomic_or!, - (⊻) => atomic_xor!, - Base.max => atomic_max!, - Base.min => atomic_min!] - @eval @inline atomic_arrayset(A::oneDeviceArray{T}, I::Integer, ::typeof($op), - val::T) where {T <: Union{Int32,UInt32}} = - $impl(pointer(A, I), val) -end - -# fallback using compare-and-swap -function atomic_arrayset(A::AbstractArray{T}, I::Integer, op::Function, val) where {T} - ptr = pointer(A, I) - old = Base.unsafe_load(ptr, 1) - while true - cmp = old - new = convert(T, op(old, val)) - old = atomic_cmpxchg!(ptr, cmp, new) - (old == cmp) && return new - end -end diff --git a/src/device/opencl/integer.jl b/src/device/opencl/integer.jl deleted file mode 100644 index edbd4bb..0000000 --- a/src/device/opencl/integer.jl +++ /dev/null @@ -1,53 +0,0 @@ -# Integer Functions - -# TODO: vector types -const generic_integer_types = [Int8, UInt8, Int16, UInt16, Int32, UInt32, Int64, UInt64] - - -# generically typed - -for gentype in generic_integer_types -@eval begin - -@device_override Base.abs(x::$gentype) = @builtin_ccall("abs", $gentype, ($gentype,), x) -@device_function abs_diff(x::$gentype, y::$gentype) = @builtin_ccall("abs_diff", $gentype, ($gentype, $gentype), x, y) - -@device_function add_sat(x::$gentype, y::$gentype) = @builtin_ccall("add_sat", $gentype, ($gentype, $gentype), x, y) -@device_function hadd(x::$gentype, y::$gentype) = @builtin_ccall("hadd", $gentype, ($gentype, $gentype), x, y) -@device_function rhadd(x::$gentype, y::$gentype) = @builtin_ccall("rhadd", $gentype, ($gentype, $gentype), x, y) - -@device_override Base.clamp(x::$gentype, minval::$gentype, maxval::$gentype) = @builtin_ccall("clamp", $gentype, ($gentype, $gentype, $gentype), x, minval, maxval) - -@device_function clz(x::$gentype) = @builtin_ccall("clz", $gentype, ($gentype,), x) -@device_function ctz(x::$gentype) = @builtin_ccall("ctz", $gentype, ($gentype,), x) - -@device_function mad_hi(a::$gentype, b::$gentype, c::$gentype) = @builtin_ccall("mad_hi", $gentype, ($gentype, $gentype, $gentype), a, b, c) -@device_function mad_sat(a::$gentype, b::$gentype, c::$gentype) = @builtin_ccall("mad_sat", $gentype, ($gentype, $gentype, $gentype), a, b, c) - -# XXX: these definitions introduce ambiguities -#@device_override Base.max(x::$gentype, y::$gentype) = @builtin_ccall("max", $gentype, ($gentype, $gentype), x, y) -#@device_override Base.min(x::$gentype, y::$gentype) = @builtin_ccall("min", $gentype, ($gentype, $gentype), x, y) - -@device_function mul_hi(x::$gentype, y::$gentype) = @builtin_ccall("mul_hi", $gentype, ($gentype, $gentype), x, y) - -@device_function rotate(v::$gentype, i::$gentype) = @builtin_ccall("rotate", $gentype, ($gentype, $gentype), v, i) - -@device_function sub_sat(x::$gentype, y::$gentype) = @builtin_ccall("sub_sat", $gentype, ($gentype, $gentype), x, y) - -@device_function popcount(x::$gentype) = @builtin_ccall("popcount", $gentype, ($gentype,), x) - -@device_function mad24(x::$gentype, y::$gentype, z::$gentype) = @builtin_ccall("mad24", $gentype, ($gentype, $gentype, $gentype), x, y, z) -@device_function mul24(x::$gentype, y::$gentype) = @builtin_ccall("mul24", $gentype, ($gentype, $gentype), x, y) - -end -end - - -# specifically typed - -@device_function upsample(hi::Cchar, lo::Cuchar) = @builtin_ccall("upsample", Cshort, (Cchar, Cuchar), hi, lo) -upsample(hi::Cuchar, lo::Cuchar) = @builtin_ccall("upsample", Cushort, (Cuchar, Cuchar), hi, lo) -upsample(hi::Cshort, lo::Cushort) = @builtin_ccall("upsample", Cint, (Cshort, Cushort), hi, lo) -upsample(hi::Cushort, lo::Cushort) = @builtin_ccall("upsample", Cuint, (Cushort, Cushort), hi, lo) -upsample(hi::Cint, lo::Cuint) = @builtin_ccall("upsample", Clong, (Cint, Cuint), hi, lo) -upsample(hi::Cuint, lo::Cuint) = @builtin_ccall("upsample", Culong, (Cuint, Cuint), hi, lo) diff --git a/src/device/opencl/math.jl b/src/device/opencl/math.jl deleted file mode 100644 index 1e4c2a9..0000000 --- a/src/device/opencl/math.jl +++ /dev/null @@ -1,214 +0,0 @@ -# Math Functions - -# TODO: vector types -const generic_types = [Float32,Float64] -const generic_types_float = [Float32] -const generic_types_double = [Float64] - - -# generically typed - -for gentype in generic_types -@eval begin - -@device_override Base.acos(x::$gentype) = @builtin_ccall("acos", $gentype, ($gentype,), x) -@device_override Base.acosh(x::$gentype) = @builtin_ccall("acosh", $gentype, ($gentype,), x) -@device_function acospi(x::$gentype) = @builtin_ccall("acospi", $gentype, ($gentype,), x) - -@device_override Base.asin(x::$gentype) = @builtin_ccall("asin", $gentype, ($gentype,), x) -@device_override Base.asinh(x::$gentype) = @builtin_ccall("asinh", $gentype, ($gentype,), x) -@device_function asinpi(x::$gentype) = @builtin_ccall("asinpi", $gentype, ($gentype,), x) - -@device_override Base.atan(y_over_x::$gentype) = @builtin_ccall("atan", $gentype, ($gentype,), y_over_x) -@device_override Base.atan(y::$gentype, x::$gentype) = @builtin_ccall("atan2", $gentype, ($gentype, $gentype), y, x) -@device_override Base.atanh(x::$gentype) = @builtin_ccall("atanh", $gentype, ($gentype,), x) -@device_function atanpi(x::$gentype) = @builtin_ccall("atanpi", $gentype, ($gentype,), x) -@device_function atanpi(y::$gentype, x::$gentype) = @builtin_ccall("atan2pi", $gentype, ($gentype, $gentype), y, x) - -@device_override Base.cbrt(x::$gentype) = @builtin_ccall("cbrt", $gentype, ($gentype,), x) - -@device_override Base.ceil(x::$gentype) = @builtin_ccall("ceil", $gentype, ($gentype,), x) - -@device_override Base.copysign(x::$gentype, y::$gentype) = @builtin_ccall("copysign", $gentype, ($gentype, $gentype), x, y) - -@device_override Base.cos(x::$gentype) = @builtin_ccall("cos", $gentype, ($gentype,), x) -@device_override Base.cosh(x::$gentype) = @builtin_ccall("cosh", $gentype, ($gentype,), x) -@device_function cospi(x::$gentype) = @builtin_ccall("cospi", $gentype, ($gentype,), x) - -@device_override SpecialFunctions.erfc(x::$gentype) = @builtin_ccall("erfc", $gentype, ($gentype,), x) -@device_override SpecialFunctions.erf(x::$gentype) = @builtin_ccall("erf", $gentype, ($gentype,), x) - -@device_override Base.exp(x::$gentype) = @builtin_ccall("exp", $gentype, ($gentype,), x) -@device_override Base.exp2(x::$gentype) = @builtin_ccall("exp2", $gentype, ($gentype,), x) -@device_override Base.exp10(x::$gentype) = @builtin_ccall("exp10", $gentype, ($gentype,), x) -@device_override Base.expm1(x::$gentype) = @builtin_ccall("expm1", $gentype, ($gentype,), x) - -@device_override Base.abs(x::$gentype) = @builtin_ccall("fabs", $gentype, ($gentype,), x) - -@device_function dim(x::$gentype, y::$gentype) = @builtin_ccall("fdim", $gentype, ($gentype, $gentype), x, y) - -@device_override Base.floor(x::$gentype) = @builtin_ccall("floor", $gentype, ($gentype,), x) - -@device_override Base.fma(a::$gentype, b::$gentype, c::$gentype) = @builtin_ccall("fma", $gentype, ($gentype, $gentype, $gentype), a, b, c) - -@device_override Base.max(x::$gentype, y::$gentype) = @builtin_ccall("fmax", $gentype, ($gentype, $gentype), x, y) - -@device_override Base.min(x::$gentype, y::$gentype) = @builtin_ccall("fmin", $gentype, ($gentype, $gentype), x, y) - -# NOTE: Julia's mod behaves differently than fmod -#@device_override Base.mod(x::$gentype, y::$gentype) = @builtin_ccall("fmod", $gentype, ($gentype, $gentype), x, y) -# fract(x::$gentype, $gentype *iptr) = @builtin_ccall("fract", $gentype, ($gentype, $gentype *), x, iptr) - -@device_override Base.hypot(x::$gentype, y::$gentype) = @builtin_ccall("hypot", $gentype, ($gentype, $gentype), x, y) - -@device_override SpecialFunctions.loggamma(x::$gentype) = @builtin_ccall("lgamma", $gentype, ($gentype,), x) - -@device_override Base.log(x::$gentype) = @builtin_ccall("log", $gentype, ($gentype,), x) -@device_override Base.log2(x::$gentype) = @builtin_ccall("log2", $gentype, ($gentype,), x) -@device_override Base.log10(x::$gentype) = @builtin_ccall("log10", $gentype, ($gentype,), x) -@device_override Base.log1p(x::$gentype) = @builtin_ccall("log1p", $gentype, ($gentype,), x) -@device_function logb(x::$gentype) = @builtin_ccall("logb", $gentype, ($gentype,), x) - -@device_function mad(a::$gentype, b::$gentype, c::$gentype) = @builtin_ccall("mad", $gentype, ($gentype, $gentype, $gentype), a, b, c) - -@device_function maxmag(x::$gentype, y::$gentype) = @builtin_ccall("maxmag", $gentype, ($gentype, $gentype), x, y) -@device_function minmag(x::$gentype, y::$gentype) = @builtin_ccall("minmag", $gentype, ($gentype, $gentype), x, y) - -# modf(x::$gentype, $gentype *iptr) = @builtin_ccall("modf", $gentype, ($gentype, $gentype *), x, iptr) - -@device_function nextafter(x::$gentype, y::$gentype) = @builtin_ccall("nextafter", $gentype, ($gentype, $gentype), x, y) - -@device_override Base.:(^)(x::$gentype, y::$gentype) = @builtin_ccall("pow", $gentype, ($gentype, $gentype), x, y) -@device_function powr(x::$gentype, y::$gentype) = @builtin_ccall("powr", $gentype, ($gentype, $gentype), x, y) - -@device_override Base.rem(x::$gentype, y::$gentype) = @builtin_ccall("remainder", $gentype, ($gentype, $gentype), x, y) - -@device_function rint(x::$gentype) = @builtin_ccall("rint", $gentype, ($gentype,), x) - -@device_override Base.round(x::$gentype) = @builtin_ccall("round", $gentype, ($gentype,), x) - -@device_function rsqrt(x::$gentype) = @builtin_ccall("rsqrt", $gentype, ($gentype,), x) - -@device_override Base.sin(x::$gentype) = @builtin_ccall("sin", $gentype, ($gentype,), x) -@device_override function Base.sincos(x::$gentype) - cosval = Ref{$gentype}() - sinval = GC.@preserve cosval begin - ptr = Base.unsafe_convert(Ptr{$gentype}, cosval) - llvm_ptr = reinterpret(LLVMPtr{$gentype, AS.Private}, ptr) - @builtin_ccall("sincos", $gentype, ($gentype, LLVMPtr{$gentype, AS.Private}), x, llvm_ptr) - end - return sinval, cosval[] -end -@device_override Base.sinh(x::$gentype) = @builtin_ccall("sinh", $gentype, ($gentype,), x) -@device_function sinpi(x::$gentype) = @builtin_ccall("sinpi", $gentype, ($gentype,), x) - -@device_override Base.sqrt(x::$gentype) = @builtin_ccall("sqrt", $gentype, ($gentype,), x) - -@device_override Base.tan(x::$gentype) = @builtin_ccall("tan", $gentype, ($gentype,), x) -@device_override Base.tanh(x::$gentype) = @builtin_ccall("tanh", $gentype, ($gentype,), x) -@device_function tanpi(x::$gentype) = @builtin_ccall("tanpi", $gentype, ($gentype,), x) - -@device_override SpecialFunctions.gamma(x::$gentype) = @builtin_ccall("tgamma", $gentype, ($gentype,), x) - -@device_override Base.trunc(x::$gentype) = @builtin_ccall("trunc", $gentype, ($gentype,), x) - -end -end - - -# generically typed -- only floats - -for gentypef in generic_types_float - -if gentypef !== Float32 -@eval begin -@device_override Base.max(x::$gentypef, y::Float32) = @builtin_ccall("fmax", $gentypef, ($gentypef, Float32), x, y) -@device_override Base.min(x::$gentypef, y::Float32) = @builtin_ccall("fmin", $gentypef, ($gentypef, Float32), x, y) -end -end - -end - - -# generically typed -- only doubles - -for gentyped in generic_types_double - -if gentyped !== Float64 -@eval begin -@device_override Base.min(x::$gentyped, y::Float64) = @builtin_ccall("fmin", $gentyped, ($gentyped, Float64), x, y) -@device_override Base.max(x::$gentyped, y::Float64) = @builtin_ccall("fmax", $gentyped, ($gentyped, Float64), x, y) -end -end - -end - - -# specifically typed - -# frexp(x::Float32{n}, Int32{n} *exp) = @builtin_ccall("frexp", Float32{n}, (Float32{n}, Int32{n} *), x, exp) -# frexp(x::Float32, Int32 *exp) = @builtin_ccall("frexp", Float32, (Float32, Int32 *), x, exp) -# frexp(x::Float64{n}, Int32{n} *exp) = @builtin_ccall("frexp", Float64{n}, (Float64{n}, Int32{n} *), x, exp) -# frexp(x::Float64, Int32 *exp) = @builtin_ccall("frexp", Float64, (Float64, Int32 *), x, exp) - -# ilogb(x::Float32{n}) = @builtin_ccall("ilogb", Int32{n}, (Float32{n},), x) -@device_function ilogb(x::Float32) = @builtin_ccall("ilogb", Int32, (Float32,), x) -# ilogb(x::Float64{n}) = @builtin_ccall("ilogb", Int32{n}, (Float64{n},), x) -@device_function ilogb(x::Float64) = @builtin_ccall("ilogb", Int32, (Float64,), x) - -# ldexp(x::Float32{n}, k::Int32{n}) = @builtin_ccall("ldexp", Float32{n}, (Float32{n}, Int32{n}), x, k) -# ldexp(x::Float32{n}, k::Int32) = @builtin_ccall("ldexp", Float32{n}, (Float32{n}, Int32), x, k) -@device_override Base.ldexp(x::Float32, k::Int32) = @builtin_ccall("ldexp", Float32, (Float32, Int32), x, k) -# ldexp(x::Float64{n}, k::Int32{n}) = @builtin_ccall("ldexp", Float64{n}, (Float64{n}, Int32{n}), x, k) -# ldexp(x::Float64{n}, k::Int32) = @builtin_ccall("ldexp", Float64{n}, (Float64{n}, Int32), x, k) -@device_override Base.ldexp(x::Float64, k::Int32) = @builtin_ccall("ldexp", Float64, (Float64, Int32), x, k) - -# lgamma_r(x::Float32{n}, Int32{n} *signp) = @builtin_ccall("lgamma_r", Float32{n}, (Float32{n}, Int32{n} *), x, signp) -# lgamma_r(x::Float32, Int32 *signp) = @builtin_ccall("lgamma_r", Float32, (Float32, Int32 *), x, signp) -# lgamma_r(x::Float64{n}, Int32{n} *signp) = @builtin_ccall("lgamma_r", Float64{n}, (Float64{n}, Int32{n} *), x, signp) -# Float64 lgamma_r(x::Float64, Int32 *signp) = @builtin_ccall("lgamma_r", Float64, (Float64, Int32 *), x, signp) - -# nan(nancode::uintn) = @builtin_ccall("nan", Float32{n}, (uintn,), nancode) -@device_function nan(nancode::UInt32) = @builtin_ccall("nan", Float32, (UInt32,), nancode) -# nan(nancode::UInt64{n}) = @builtin_ccall("nan", Float64{n}, (UInt64{n},), nancode) -@device_function nan(nancode::UInt64) = @builtin_ccall("nan", Float64, (UInt64,), nancode) - -# pown(x::Float32{n}, y::Int32{n}) = @builtin_ccall("pown", Float32{n}, (Float32{n}, Int32{n}), x, y) -@device_override Base.:(^)(x::Float32, y::Int32) = @builtin_ccall("pown", Float32, (Float32, Int32), x, y) -# pown(x::Float64{n}, y::Int32{n}) = @builtin_ccall("pown", Float64{n}, (Float64{n}, Int32{n}), x, y) -@device_override Base.:(^)(x::Float64, y::Int32) = @builtin_ccall("pown", Float64, (Float64, Int32), x, y) - -# remquo(x::Float32{n}, y::Float32{n}, Int32{n} *quo) = @builtin_ccall("remquo", Float32{n}, (Float32{n}, Float32{n}, Int32{n} *), x, y, quo) -# remquo(x::Float32, y::Float32, Int32 *quo) = @builtin_ccall("remquo", Float32, (Float32, Float32, Int32 *), x::Float32, y, quo) -# remquo(x::Float64{n}, y::Float64{n}, Int32{n} *quo) = @builtin_ccall("remquo", Float64{n}, (Float64{n}, Float64{n}, Int32{n} *), x, y, quo) -# remquo(x::Float64, y::Float64, Int32 *quo) = @builtin_ccall("remquo", Float64, (Float64, Float64, Int32 *), x, y, quo) - -# rootn(x::Float32{n}, y::Int32{n}) = @builtin_ccall("rootn", Float32{n}, (Float32{n}, Int32{n}), x, y) -@device_function rootn(x::Float32, y::Int32) = @builtin_ccall("rootn", Float32, (Float32, Int32), x, y) -# rootn(x::Float64{n}, y::Int32{n}) = @builtin_ccall("rootn", Float64{n}, (Float64{n}, Int32{n}), x, y) -# rootn(x::Float64, y::Int32) = @builtin_ccall("rootn", Float64{n}, (Float64, Int32), x, y) - - -# TODO: half and native - -function _mulhi(a::Int64, b::Int64) - shift = sizeof(a) * 4 - mask = typemax(UInt32) - a1, a2 = (a >> shift), a & mask - b1, b2 = (b >> shift), b & mask - a1b1, a1b2, a2b1 = a1*b1, a1*b2, a2*b1 - t1 = a1b2 + _mulhi(a2 % UInt32, b2 % UInt32) - t2 = a2b1 + (t1 & mask) - a1b1 + (t1 >> shift) + (t2 >> shift) -end -@static if isdefined(Base.MultiplicativeInverses, :_mul_high) - _mulhi(a::T, b::T) where {T<:Union{Signed, Unsigned}} = Base.MultiplicativeInverses._mul_high(a, b) - @device_override Base.MultiplicativeInverses._mul_high(a::Int64, b::Int64) = _mulhi(a, b) -else - _mulhi(a::T, b::T) where {T<:Union{Signed, Unsigned}} = ((widen(a)*b) >>> (sizeof(a)*8)) % T - @device_override function Base.div(a::Int64, b::Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}) - x = _mulhi(a, b.multiplier) - x += (a*b.addmul) % Int64 - ifelse(abs(b.divisor) == 1, a*b.divisor, (signbit(x) + (x >> b.shift)) % Int64) - end -end diff --git a/src/device/opencl/memory.jl b/src/device/opencl/memory.jl deleted file mode 100644 index cf62f59..0000000 --- a/src/device/opencl/memory.jl +++ /dev/null @@ -1,56 +0,0 @@ -# Shared Memory (part of B.2) - -export @LocalMemory, oneLocalArray - -@inline function oneLocalArray(::Type{T}, dims) where {T} - len = prod(dims) - # NOTE: this relies on const-prop to forward the literal length to the generator. - # maybe we should include the size in the type, like StaticArrays does? - ptr = emit_localmemory(T, Val(len)) - oneDeviceArray(dims, ptr) -end - -macro LocalMemory(T, dims) - Base.depwarn("@LocalMemory is deprecated, please use the oneLocalArray function", :oneLocalArray) - - quote - oneLocalArray($(esc(T)), $(esc(dims))) - end -end - -# get a pointer to local memory, with known (static) or zero length (dynamic) -@generated function emit_localmemory(::Type{T}, ::Val{len}=Val(0)) where {T,len} - Context() do ctx - # XXX: as long as LLVMPtr is emitted as i8*, it doesn't make sense to type the GV - eltyp = convert(LLVMType, LLVM.Int8Type()) - T_ptr = convert(LLVMType, LLVMPtr{T,AS.Local}) - - # create a function - llvm_f, _ = create_function(T_ptr) - - # create the global variable - mod = LLVM.parent(llvm_f) - gv_typ = LLVM.ArrayType(eltyp, len * sizeof(T)) - gv = GlobalVariable(mod, gv_typ, "local_memory", AS.Local) - if len > 0 - linkage!(gv, LLVM.API.LLVMInternalLinkage) - initializer!(gv, null(gv_typ)) - end - # TODO: Make the alignment configurable - alignment!(gv, Base.datatype_alignment(T)) - - # generate IR - IRBuilder() do builder - entry = BasicBlock(llvm_f, "entry") - position!(builder, entry) - - ptr = gep!(builder, gv_typ, gv, [ConstantInt(0), ConstantInt(0)]) - - untyped_ptr = bitcast!(builder, ptr, T_ptr) - - ret!(builder, untyped_ptr) - end - - call_function(llvm_f, LLVMPtr{T,AS.Local}) - end -end diff --git a/src/device/opencl/printf.jl b/src/device/opencl/printf.jl deleted file mode 100644 index 726ef6f..0000000 --- a/src/device/opencl/printf.jl +++ /dev/null @@ -1,205 +0,0 @@ -# printf - -# Formatted Output (B.17) - -@generated function promote_c_argument(arg) - # > When a function with a variable-length argument list is called, the variable - # > arguments are passed using C's old ``default argument promotions.'' These say that - # > types char and short int are automatically promoted to int, and type float is - # > automatically promoted to double. Therefore, varargs functions will never receive - # > arguments of type char, short int, or float. - - if arg == Cchar || arg == Cshort || arg == Cuchar || arg == Cushort - return :(Cint(arg)) - elseif arg == Cfloat - return :(Cdouble(arg)) - else - return :(arg) - end -end - -macro printf(fmt::String, args...) - fmt_val = Val(Symbol(fmt)) - - return :(emit_printf($fmt_val, $(map(arg -> :(promote_c_argument($arg)), esc.(args))...))) -end - -@generated function emit_printf(::Val{fmt}, argspec...) where {fmt} - arg_exprs = [:( argspec[$i] ) for i in 1:length(argspec)] - arg_types = [argspec...] - - Context() do ctx - T_void = LLVM.VoidType() - T_int32 = LLVM.Int32Type() - T_pint8 = LLVM.PointerType(LLVM.Int8Type()) - - # create functions - param_types = LLVMType[convert(LLVMType, typ) for typ in arg_types] - llvm_f, _ = create_function(T_int32, param_types) - mod = LLVM.parent(llvm_f) - - # generate IR - IRBuilder() do builder - entry = BasicBlock(llvm_f, "entry") - position!(builder, entry) - - str = globalstring_ptr!(builder, String(fmt)) - - # invoke printf and return - printf_typ = LLVM.FunctionType(T_int32, [T_pint8]; vararg=true) - printf = LLVM.Function(mod, "printf", printf_typ) - push!(function_attributes(printf), EnumAttribute("nobuiltin")) - chars = call!(builder, printf_typ, printf, [str, parameters(llvm_f)...]) - - ret!(builder, chars) - end - - call_function(llvm_f, Int32, Tuple{arg_types...}, arg_exprs...) - end -end - - -## print-like functionality - -# simple conversions, defining an expression and the resulting argument type. nothing fancy, -# `@print` pretty directly maps to `@printf`; we should just support `write(::IO)`. -const print_conversions = Dict( - Float32 => (x->:(Float64($x)), Float64), - Ptr{<:Any} => (x->:(convert(Ptr{Cvoid}, $x)), Ptr{Cvoid}), - Bool => (x->:(Int32($x)), Int32), -) - -# format specifiers -const print_specifiers = Dict( - # integers - Int16 => "%hd", - Int32 => "%d", - Int64 => Sys.iswindows() ? "%lld" : "%ld", - UInt16 => "%hu", - UInt32 => "%u", - UInt64 => Sys.iswindows() ? "%llu" : "%lu", - - # floating-point - Float64 => "%f", - - # other - Cchar => "%c", - Ptr{Cvoid} => "%p", -) - -@generated function _print(parts...) - fmt = "" - args = Expr[] - - for i in 1:length(parts) - part = :(parts[$i]) - T = parts[i] - - # put literals directly in the format string - if T <: Val - fmt *= string(T.parameters[1]) - continue - end - - # try to convert arguments if they are not supported directly - if !haskey(print_specifiers, T) - for Tmatch in keys(print_conversions) - if T <: Tmatch - conv, T = print_conversions[Tmatch] - part = conv(part) - break - end - end - end - - # render the argument - if haskey(print_specifiers, T) - fmt *= print_specifiers[T] - push!(args, part) - elseif T <: String - @error("@print does not support non-literal strings") - else - @error("@print does not support values of type $T") - end - end - - quote - Base.@_inline_meta - @printf($fmt, $(args...)) - end -end - -""" - @print(xs...) - @println(xs...) - -Print a textual representation of values `xs` to standard output from the GPU. The -functionality builds on `@printf`, and is intended as a more use friendly alternative of -that API. However, that also means there's only limited support for argument types, handling -16/32/64 signed and unsigned integers, 32 and 64-bit floating point numbers, `Cchar`s and -pointers. For more complex output, use `@printf` directly. - -Limited string interpolation is also possible: - -```julia - @print("Hello, World ", 42, "\\n") - @print "Hello, World \$(42)\\n" -``` -""" -macro print(parts...) - args = Union{Val,Expr,Symbol}[] - - parts = [parts...] - while true - isempty(parts) && break - - part = popfirst!(parts) - - # handle string interpolation - if isa(part, Expr) && part.head == :string - parts = vcat(part.args, parts) - continue - end - - # expose literals to the generator by using Val types - if isbits(part) # literal numbers, etc - push!(args, Val(part)) - elseif isa(part, QuoteNode) # literal symbols - push!(args, Val(part.value)) - elseif isa(part, String) # literal strings need to be interned - push!(args, Val(Symbol(part))) - else # actual values that will be passed to printf - push!(args, part) - end - end - - quote - _print($(map(esc, args)...)) - end -end - -@doc (@doc @print) -> -macro println(parts...) - esc(quote - oneAPI.@print($(parts...), "\n") - end) -end - -""" - @show(ex) - -GPU analog of `Base.@show`. It comes with the same type restrictions as [`@printf`](@ref). - -```julia -@show threadIdx().x -``` -""" -macro show(exs...) - blk = Expr(:block) - for ex in exs - push!(blk.args, :(oneAPI.@println($(sprint(Base.show_unquoted,ex)*" = "), - begin local value = $(esc(ex)) end))) - end - isempty(exs) || push!(blk.args, :value) - blk -end diff --git a/src/device/opencl/synchronization.jl b/src/device/opencl/synchronization.jl deleted file mode 100644 index 83bdad5..0000000 --- a/src/device/opencl/synchronization.jl +++ /dev/null @@ -1,21 +0,0 @@ -# Synchronization Functions - -export barrier - -const cl_mem_fence_flags = UInt32 -const CLK_LOCAL_MEM_FENCE = cl_mem_fence_flags(1) -const CLK_GLOBAL_MEM_FENCE = cl_mem_fence_flags(2) - -#barrier(flags=0) = @builtin_ccall("barrier", Cvoid, (UInt32,), flags) -barrier(flags=0) = Base.llvmcall((""" - declare void @_Z7barrierj(i32) #0 - define void @entry(i32 %0) #1 { - call void @_Z7barrierj(i32 %0) - ret void - } - attributes #0 = { convergent } - attributes #1 = { alwaysinline } - """, "entry"), - Cvoid, Tuple{Int32}, convert(Int32, flags)) -push!(opencl_builtins, "_Z7barrierj") -# TODO: add support for attributes to @builting_ccall/LLVM.@typed_ccall diff --git a/src/device/opencl/work_item.jl b/src/device/opencl/work_item.jl deleted file mode 100644 index e5c829a..0000000 --- a/src/device/opencl/work_item.jl +++ /dev/null @@ -1,30 +0,0 @@ -# Work-Item Functions - -export get_work_dim, - get_global_size, get_global_id, - get_local_size, get_enqueued_local_size, get_local_id, - get_num_groups, get_group_id, - get_global_offset, - get_global_linear_id, get_local_linear_id - -# NOTE: these functions now unsafely truncate to Int to avoid top bit checks. -# we should probably use range metadata instead. - -# TODO: 1-indexed dimension selection? - -get_work_dim() = @builtin_ccall("get_work_dim", UInt32, ()) % Int - -get_global_size(dimindx::Integer=0) = @builtin_ccall("get_global_size", UInt, (UInt32,), dimindx) % Int -get_global_id(dimindx::Integer=0) = @builtin_ccall("get_global_id", UInt, (UInt32,), dimindx) % Int + 1 - -get_local_size(dimindx::Integer=0) = @builtin_ccall("get_local_size", UInt, (UInt32,), dimindx) % Int -get_enqueued_local_size(dimindx::Integer=0) = @builtin_ccall("get_enqueued_local_size", UInt, (UInt32,), dimindx) % Int -get_local_id(dimindx::Integer=0) = @builtin_ccall("get_local_id", UInt, (UInt32,), dimindx) % Int + 1 - -get_num_groups(dimindx::Integer=0) = @builtin_ccall("get_num_groups", UInt, (UInt32,), dimindx) % Int -get_group_id(dimindx::Integer=0) = @builtin_ccall("get_group_id", UInt, (UInt32,), dimindx) % Int + 1 - -get_global_offset(dimindx::Integer=0) = @builtin_ccall("get_global_offset", UInt, (UInt32,), dimindx) % Int + 1 - -get_global_linear_id() = @builtin_ccall("get_global_linear_id", UInt, ()) % Int + 1 -get_local_linear_id() = @builtin_ccall("get_local_linear_id", UInt, ()) % Int + 1 diff --git a/src/device/pointer.jl b/src/device/pointer.jl deleted file mode 100644 index 116ca22..0000000 --- a/src/device/pointer.jl +++ /dev/null @@ -1,18 +0,0 @@ -# oneAPI-specific operations on pointers with address spaces - -## adrspace aliases - -export AS - -module AS - -const Private = 0 -const Global = 1 -const Constant = 2 -const Local = 3 -const Generic = 4 -const Input = 5 -const Output = 6 -const Count = 7 - -end diff --git a/src/device/utils.jl b/src/device/utils.jl deleted file mode 100644 index 60ff7a2..0000000 --- a/src/device/utils.jl +++ /dev/null @@ -1,110 +0,0 @@ - -const opencl_builtins = String["printf"] - -# OpenCL functions need to be mangled according to the C++ Itanium spec. We implement a very -# limited version of that spec here, just enough to support OpenCL built-ins. -# -# This macro also keeps track of called builtins, generating `ccall("extern...", llvmcall)` -# expressions for them (so that we can exclude them during IR verification). -macro builtin_ccall(name, ret, argtypes, args...) - @assert Meta.isexpr(argtypes, :tuple) - argtypes = argtypes.args - - function mangle(T::Type) - if T == Cint - "i" - elseif T == Cuint - "j" - elseif T == Clong - "l" - elseif T == Culong - "m" - elseif T == Clonglong - "x" - elseif T == Culonglong - "y" - elseif T == Cshort - "s" - elseif T == Cushort - "t" - elseif T == Cchar - "c" - elseif T == Cuchar - "h" - elseif T == Cfloat - "f" - elseif T == Cdouble - "d" - elseif T <: LLVMPtr - elt, as = T.parameters - - # mangle address space - ASstr = if as == AS.Global - "CLglobal" - #elseif as == AS.Global_device - # "CLdevice" - #elseif as == AS.Global_host - # "CLhost" - elseif as == AS.Local - "CLlocal" - elseif as == AS.Constant - "CLconstant" - elseif as == AS.Private - "CLprivate" - elseif as == AS.Generic - "CLgeneric" - else - error("Unknown address space $AS") - end - - # encode as vendor qualifier - ASstr = "U" * string(length(ASstr)) * ASstr - - # XXX: where does the V come from? - "P" * ASstr * "V" * mangle(elt) - else - error("Unknown type $T") - end - end - - # C++-style mangling; very limited to just support these intrinsics - # TODO: generalize for use with other intrinsics? do we need to mangle those? - mangled = "_Z$(length(name))$name" - for t in argtypes - # with `@eval @builtin_ccall`, we get actual types in the ast, otherwise symbols - t = (isa(t, Symbol) || isa(t, Expr)) ? eval(t) : t - mangled *= mangle(t) - end - - push!(opencl_builtins, mangled) - esc(quote - @typed_ccall($mangled, llvmcall, $ret, ($(argtypes...),), $(args...)) - end) -end - - -## device overrides - -# local method table for device functions -Base.Experimental.@MethodTable(method_table) - -macro device_override(ex) - esc(quote - Base.Experimental.@overlay($method_table, $ex) - end) -end - -macro device_function(ex) - ex = macroexpand(__module__, ex) - def = ExprTools.splitdef(ex) - - # generate a function that errors - def[:body] = quote - error("This function is not intended for use on the CPU") - end - - esc(quote - $(ExprTools.combinedef(def)) - @device_override $ex - end) -end diff --git a/src/mapreduce.jl b/src/mapreduce.jl index bf60f90..7f3d2e5 100644 --- a/src/mapreduce.jl +++ b/src/mapreduce.jl @@ -6,8 +6,8 @@ # Reduce a value across a group, using local memory for communication @inline function reduce_group(op, val::T, neutral, ::Val{maxitems}) where {T, maxitems} - items = get_local_size(0) - item = get_local_id(0) + items = get_local_size() + item = get_local_id() # local mem for a complete reduction shared = oneLocalArray(T, (maxitems,)) @@ -47,10 +47,10 @@ Base.@propagate_inbounds _map_getindex(args::Tuple{}, I) = () function partial_mapreduce_device(f, op, neutral, maxitems, Rreduce, Rother, R, As...) # decompose the 1D hardware indices into separate ones for reduction (across items # and possibly groups if it doesn't fit) and other elements (remaining groups) - localIdx_reduce = get_local_id(0) - localDim_reduce = get_local_size(0) - groupIdx_reduce, groupIdx_other = fldmod1(get_group_id(0), length(Rother)) - groupDim_reduce = get_num_groups(0) ÷ length(Rother) + localIdx_reduce = get_local_id() + localDim_reduce = get_local_size() + groupIdx_reduce, groupIdx_other = fldmod1(get_group_id(), length(Rother)) + groupDim_reduce = get_num_groups() ÷ length(Rother) # group-based indexing into the values outside of the reduction dimension # (that means we can safely synchronize items within this group) diff --git a/src/oneAPI.jl b/src/oneAPI.jl index 259fea7..fd0e048 100644 --- a/src/oneAPI.jl +++ b/src/oneAPI.jl @@ -25,18 +25,12 @@ include("../lib/level-zero/oneL0.jl") using .oneL0 functional() = oneL0.functional[] -# device functionality (needs to be loaded first, because of generated functions) -include("device/utils.jl") -include("device/pointer.jl") -include("device/array.jl") +# device functionality +import SPIRVIntrinsics +SPIRVIntrinsics.@import_all +SPIRVIntrinsics.@reexport_public include("device/runtime.jl") -include("device/opencl/work_item.jl") -include("device/opencl/synchronization.jl") -include("device/opencl/memory.jl") -include("device/opencl/printf.jl") -include("device/opencl/math.jl") -include("device/opencl/integer.jl") -include("device/opencl/atomic.jl") +include("device/array.jl") include("device/quirks.jl") # essential stuff diff --git a/src/oneAPIKernels.jl b/src/oneAPIKernels.jl index 4620395..66729b5 100644 --- a/src/oneAPIKernels.jl +++ b/src/oneAPIKernels.jl @@ -113,32 +113,32 @@ end ## Indexing Functions @device_override @inline function KA.__index_Local_Linear(ctx) - return get_local_id(0) + return get_local_id() end @device_override @inline function KA.__index_Group_Linear(ctx) - return get_group_id(0) + return get_group_id() end @device_override @inline function KA.__index_Global_Linear(ctx) - return get_global_id(0) + return get_global_id() end @device_override @inline function KA.__index_Local_Cartesian(ctx) - @inbounds KA.workitems(KA.__iterspace(ctx))[get_local_id(0)] + @inbounds KA.workitems(KA.__iterspace(ctx))[get_local_id()] end @device_override @inline function KA.__index_Group_Cartesian(ctx) - @inbounds KA.blocks(KA.__iterspace(ctx))[get_group_id(0)] + @inbounds KA.blocks(KA.__iterspace(ctx))[get_group_id()] end @device_override @inline function KA.__index_Global_Cartesian(ctx) - return @inbounds KA.expand(KA.__iterspace(ctx), get_group_id(0), get_local_id(0)) + return @inbounds KA.expand(KA.__iterspace(ctx), get_group_id(), get_local_id()) end @device_override @inline function KA.__validindex(ctx) if KA.__dynamic_checkbounds(ctx) - I = @inbounds KA.expand(KA.__iterspace(ctx), get_group_id(0), get_local_id(0)) + I = @inbounds KA.expand(KA.__iterspace(ctx), get_group_id(), get_local_id()) return I in KA.__ndrange(ctx) else return true diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl index 2800fc5..832ee70 100644 --- a/test/device/intrinsics.jl +++ b/test/device/intrinsics.jl @@ -1,17 +1,17 @@ @testset "work items" begin @on_device get_work_dim() - @on_device get_global_size(0) - @on_device get_global_id(0) + @on_device get_global_size() + @on_device get_global_id() - @on_device get_local_size(0) - @on_device get_enqueued_local_size(0) - @on_device get_local_id(0) + @on_device get_local_size() + @on_device get_enqueued_local_size() + @on_device get_local_id() - @on_device get_num_groups(0) - @on_device get_group_id(0) + @on_device get_num_groups() + @on_device get_group_id() - @on_device get_global_offset(0) + @on_device get_global_offset() @on_device get_global_linear_id() @on_device get_local_linear_id() @@ -218,7 +218,7 @@ end @testset "statically typed" begin function kernel(d, n) - t = get_local_id(0) + t = get_local_id() tr = n-t+1 s = oneLocalArray(Float32, 1024) @@ -244,7 +244,7 @@ end float64_supported && push!(typs, Float64) @testset for typ in typs function kernel(d::oneDeviceArray{T}, n) where {T} - t = get_local_id(0) + t = get_local_id() tr = n-t+1 s = oneLocalArray(T, 1024) diff --git a/test/execution.jl b/test/execution.jl index cddfc8a..2ebde5f 100644 --- a/test/execution.jl +++ b/test/execution.jl @@ -160,7 +160,7 @@ len = prod(dims) @testset "manually allocated" begin function kernel(input, output) - i = get_global_id(0) + i = get_global_id() val = input[i] output[i] = val @@ -181,8 +181,8 @@ end @testset "scalar through single-value array" begin function kernel(a, x) - i = get_global_id(0) - max = get_global_size(0) + i = get_global_id() + max = get_global_size() if i == max _val = a[i] x[] = _val @@ -204,8 +204,8 @@ end @testset "scalar through single-value array, using device function" begin @noinline child(a, i) = a[i] function parent(a, x) - i = get_global_id(0) - max = get_global_size(0) + i = get_global_id() + max = get_global_size() if i == max _val = child(a, i) x[] = _val @@ -259,7 +259,7 @@ end @eval struct ExecGhost end function kernel(ghost, a, b, c) - i = get_global_id(0) + i = get_global_id() c[i] = a[i] + b[i] return end @@ -270,7 +270,7 @@ end # bug: ghost type function parameters confused aggregate type rewriting function kernel(ghost, out, aggregate) - i = get_global_id(0) + i = get_global_id() out[i] = aggregate[1] return end