diff --git a/README.md b/README.md index b74489e..fc3d717 100644 --- a/README.md +++ b/README.md @@ -197,8 +197,6 @@ BenchmarkTools.Trial: 10000 samples with 10 evaluations. ### Local per-thread storage (`threadlocal`) -**Warning: this feature is likely broken!** - You also can define local storage for each thread, providing a vector containing each of the local storages at the end. ```julia @@ -234,6 +232,25 @@ julia> let Float16[83.0, 90.0, 27.0, 65.0] ``` +### `reduction` +The `reduction` keyword enables reduction of an already initialized `isbits` variable with certain supported associative operations (see [docs](https://JuliaSIMD.github.io/Polyester.jl/stable)), such that the transition from serialized code is as simple as adding the `@batch` macro. Contrary to `threadlocal` this does not incur any additional allocations + +```julia +julia> function batch_reduction() + y1 = 0 + y2 = 1 + @batch reduction=((+, y1), (*, y2)) for i in 1:9 + y1 += i + y2 *= i + end + y1, y2 + end +julia> batch_reduction() +(45, 362880) +julia> @allocated batch_reduction() +0 +``` + ## Disabling Polyester threads When running many repetitions of a Polyester-multithreaded function (e.g. in an embarrassingly parallel problem that repeatedly executes a small already Polyester-multithreaded function), it can be beneficial to disable Polyester (the inner multithreaded loop) and multithread only at the outer level (e.g. with `Base.Threads`). This can be done with the `disable_polyester_threads` context manager. In the expandable section below you can see examples with benchmarks. diff --git a/src/Polyester.jl b/src/Polyester.jl index 7399664..ef324e5 100644 --- a/src/Polyester.jl +++ b/src/Polyester.jl @@ -6,6 +6,7 @@ end using ThreadingUtilities import StaticArrayInterface const ArrayInterface = StaticArrayInterface +using Base.Cartesian: @nexprs using StaticArrayInterface: static_length, static_step, static_first, static_size using StrideArraysCore: object_and_preserve using ManualMemory: Reference @@ -23,6 +24,15 @@ using CPUSummary: num_cores export batch, @batch, disable_polyester_threads +const SUPPORTED_REDUCE_OPS = (:+, :*, :min, :max, :&, :|) +initializer(::typeof(+), ::T) where {T} = zero(T) +initializer(::typeof(+), ::Bool) = zero(Int) +initializer(::typeof(*), ::T) where {T} = one(T) +initializer(::typeof(min), ::T) where {T} = typemax(T) +initializer(::typeof(max), ::T) where {T} = typemin(T) +initializer(::typeof(&), ::Bool) = true +initializer(::typeof(|), ::Bool) = false + include("batch.jl") include("closure.jl") @@ -38,10 +48,4 @@ function reset_threads!() foreach(ThreadingUtilities.checktask, eachindex(ThreadingUtilities.TASKS)) return nothing end - -# y = rand(1) -# x = rand(1) -# @batch for i ∈ eachindex(y,x) -# y[i] = sin(x[i]) -# end end diff --git a/src/batch.jl b/src/batch.jl index fcc236c..99f0353 100644 --- a/src/batch.jl +++ b/src/batch.jl @@ -1,27 +1,38 @@ -struct BatchClosure{F,A,C} # C is a Val{Bool} triggering local storage - f::F +# S is a Val{Bool} indicating whether we will need to load the thread index +# C is a Tuple{...} containing the types of the reduction variables +struct BatchClosure{F,A,S,C} + f::F end -function (b::BatchClosure{F,A,C})(p::Ptr{UInt}) where {F,A,C} +function (b::BatchClosure{F,A,S,C})(p::Ptr{UInt}) where {F,A,S,C} (offset, args) = ThreadingUtilities.load(p, A, 2 * sizeof(UInt)) (offset, start) = ThreadingUtilities.load(p, UInt, offset) (offset, stop) = ThreadingUtilities.load(p, UInt, offset) - if C + if C === Tuple{} && !S + b.f(args, (start + one(UInt)) % Int, stop % Int) + elseif C === Tuple{} && S ((offset, i) = ThreadingUtilities.load(p, UInt, offset)) b.f(args, (start + one(UInt)) % Int, stop % Int, i % Int) + elseif C !== Tuple{} && !S + ((offset, reducinits) = ThreadingUtilities.load(p, C, offset)) + reducres = b.f(args, (start + one(UInt)) % Int, stop % Int, reducinits) + ThreadingUtilities.store!(p, reducres, offset) else - b.f(args, (start + one(UInt)) % Int, stop % Int) + ((offset, i) = ThreadingUtilities.load(p, UInt, offset)) + ((offset, reducinits) = ThreadingUtilities.load(p, C, offset)) + reducres = b.f(args, (start + one(UInt)) % Int, stop % Int, i % Int, reducinits) + ThreadingUtilities.store!(p, reducres, offset) end ThreadingUtilities._atomic_store!(p, ThreadingUtilities.SPIN) nothing end -@generated function batch_closure(f::F, args::A, ::Val{C}) where {F,A,C} +@generated function batch_closure(f::F, args::A, ::Val{S}, reducinits::C) where {F,A,S,C} q = if Base.issingletontype(F) - bc = BatchClosure{F,A,C}(F.instance) + bc = BatchClosure{F,A,S,C}(F.instance) :(@cfunction($bc, Cvoid, (Ptr{UInt},))) else quote - bc = BatchClosure{F,A,C}(f) + bc = BatchClosure{F,A,S,C}(f) @cfunction($(Expr(:$, :bc)), Cvoid, (Ptr{UInt},)) end end @@ -32,17 +43,46 @@ end # @cfunction($bc, Cvoid, (Ptr{UInt},)) # end +@inline function load_threadlocals(tid, argtup::A, ::Val{S}, reductup::C) where {A,S,C} + p = ThreadingUtilities.taskpointer(tid) + (offset, _) = ThreadingUtilities.load(p, UInt, sizeof(UInt)) + (offset, _) = ThreadingUtilities.load(p, A, offset) + (offset, _) = ThreadingUtilities.load(p, UInt, offset) + (offset, _) = ThreadingUtilities.load(p, UInt, offset) + if S + (offset, _) = ThreadingUtilities.load(p, UInt, offset) + end + (offset, _) = ThreadingUtilities.load(p, C, offset) + (offset, reducvals) = ThreadingUtilities.load(p, C, offset) + return reducvals +end + +@inline function setup_batch!( + p::Ptr{UInt}, + fptr::Ptr{Cvoid}, + argtup, + start::UInt, + stop::UInt, +) + offset = ThreadingUtilities.store!(p, fptr, sizeof(UInt)) + offset = ThreadingUtilities.store!(p, argtup, offset) + offset = ThreadingUtilities.store!(p, start, offset) + offset = ThreadingUtilities.store!(p, stop, offset) + nothing +end @inline function setup_batch!( p::Ptr{UInt}, fptr::Ptr{Cvoid}, argtup, start::UInt, stop::UInt, + i_or_reductup, ) offset = ThreadingUtilities.store!(p, fptr, sizeof(UInt)) offset = ThreadingUtilities.store!(p, argtup, offset) offset = ThreadingUtilities.store!(p, start, offset) offset = ThreadingUtilities.store!(p, stop, offset) + offset = ThreadingUtilities.store!(p, i_or_reductup, offset) nothing end @inline function setup_batch!( @@ -52,12 +92,14 @@ end start::UInt, stop::UInt, i::UInt, + reductup, ) offset = ThreadingUtilities.store!(p, fptr, sizeof(UInt)) offset = ThreadingUtilities.store!(p, argtup, offset) offset = ThreadingUtilities.store!(p, start, offset) offset = ThreadingUtilities.store!(p, stop, offset) offset = ThreadingUtilities.store!(p, i, offset) + offset = ThreadingUtilities.store!(p, reductup, offset) nothing end @inline function launch_batched_thread!(cfunc, tid, argtup, start, stop) @@ -66,7 +108,20 @@ end setup_batch!(p, fptr, argtup, start, stop) end end -@inline function launch_batched_thread!(cfunc, tid, argtup, start, stop, i) +@inline function launch_batched_thread!(cfunc, tid, argtup, start, stop, i_or_reductup) + fptr = Base.unsafe_convert(Ptr{Cvoid}, cfunc) + ThreadingUtilities.launch( + tid, + fptr, + argtup, + start, + stop, + i_or_reductup, + ) do p, fptr, argtup, start, stop, i_or_reductup + setup_batch!(p, fptr, argtup, start, stop, i_or_reductup) + end +end +@inline function launch_batched_thread!(cfunc, tid, argtup, start, stop, i, reductup) fptr = Base.unsafe_convert(Ptr{Cvoid}, cfunc) ThreadingUtilities.launch( tid, @@ -75,8 +130,9 @@ end start, stop, i, - ) do p, fptr, argtup, start, stop, i - setup_batch!(p, fptr, argtup, start, stop, i) + reductup + ) do p, fptr, argtup, start, stop, i, reductup + setup_batch!(p, fptr, argtup, start, stop, i, reductup) end end _extract_params(::Type{T}) where {T<:Tuple} = T.parameters @@ -111,7 +167,9 @@ end @generated function _batch_no_reserve( f!::F, - threadlocal::Val{thread_local}, + needtid::Val{S}, + reducops::Tuple{Vararg{Any,C}}, + reducinits::Tuple{Vararg{Any,C}}, threadmask_tuple::NTuple{N}, nthread_tuple, torelease_tuple, @@ -119,7 +177,7 @@ end Nd, ulen, args::Vararg{Any,K}, -) where {F,K,N,thread_local} +) where {F,K,N,S,C} q = quote $(Expr(:meta, :inline)) # threads = UnsignedIteratorEarlyStop(threadmask, nthread) @@ -127,16 +185,54 @@ end # nthread_total = sum(nthread_tuple) Ndp = Nd + one(Nd) end - launch_quote = if thread_local - :(launch_batched_thread!(cfunc, tid, argtup, start, stop, i % UInt)) + C !== 0 && push!( + q.args, + quote + @nexprs $C j -> RVAR_j = reducinits[j] + end + ) + launch_quote = if S + if C === 0 + :(launch_batched_thread!(cfunc, tid, argtup, start, stop, tid % UInt)) + else + :(launch_batched_thread!(cfunc, tid, argtup, start, stop, tid % UInt, reducinits)) + end + else + if C === 0 + :(launch_batched_thread!(cfunc, tid, argtup, start, stop)) + else + :(launch_batched_thread!(cfunc, tid, argtup, start, stop, reducinits)) + end + end + f_quote = Expr(:call, :f!, :arguments, :((start + one(UInt)) % Int), :(ulen % Int)) + S && push!(f_quote.args, :((sum(nthread_tuple) + 1) % Int)) + C !== 0 && push!(f_quote.args, :reducinits) + rem_quote = Expr(:block, :(thread_results = $f_quote)) + if C !== 0 + push!( + rem_quote.args, + :(@nexprs $C j -> RVAR_j = reducops[j](RVAR_j, thread_results[j])) + ) + end + update_retv = if C === 0 + Expr(:block) else - :(launch_batched_thread!(cfunc, tid, argtup, start, stop)) + quote + thread_results = load_threadlocals(tid, argtup, needtid, reducinits) + @nexprs $C j -> RVAR_j = reducops[j](RVAR_j, thread_results[j]) + end end - rem_quote = if thread_local - :(f!(arguments, (start + one(UInt)) % Int, ulen % Int, (sum(nthread_tuple) + 1) % Int)) + ret_quote = Expr(:return) + if C === 0 + push!(ret_quote.args, nothing) else - :(f!(arguments, (start + one(UInt)) % Int, ulen % Int)) + redtup = Expr(:tuple) + for j in 1:C + push!(redtup.args, Symbol("RVAR_", j)) + end + push!(ret_quote.args, redtup) end + block = quote start = zero(UInt) tid = 0x00000000 @@ -161,17 +257,18 @@ end for (threadmask, nthread) ∈ zip(threadmask_tuple, nthread_tuple) tm = mask(UnsignedIteratorEarlyStop(threadmask, nthread)) while tm ≠ zero(tm) - # assume(tm ≠ zero(tm)) + # assume(tm ≠ zero(tm)) tz = trailing_zeros(tm) % UInt32 tz += 0x00000001 tm >>>= tz tid += tz # @show tid, ThreadingUtilities._atomic_state(tid) ThreadingUtilities.wait(tid) + $update_retv end end free_threads!(torelease_tuple) - nothing + $ret_quote end gcpr = Expr(:gc_preserve, block, :cfunc) argt = Expr(:tuple) @@ -182,10 +279,9 @@ end q.args, :(arguments = $argt), :(argtup = Reference(arguments)), - :(cfunc = batch_closure(f!, argtup, Val{$thread_local}())), + :(cfunc = batch_closure(f!, argtup, Val{$S}(), reducinits)), gcpr, ) - push!(q.args, nothing) q end @@ -195,19 +291,24 @@ end args::Vararg{Any,K}, ) where {F,K} - batch(f!, Val{false}(), (len, nbatches), args...) + batch(f!, Val{false}(), (), (), (len, nbatches), args...) end @inline function batch( f!::F, - threadlocal::Val{thread_local}, + needtid::Val{S}, + reducops::Tuple{Vararg{Any,C}}, + reducinits::Tuple{Vararg{Any,C}}, (len, nbatches)::Tuple{Vararg{Union{StaticInt,Integer},2}}, args::Vararg{Any,K}, -) where {F,K,thread_local} - len > 0 || return +) where {F,K,C,S} + len > 0 || return reducinits + for var in reducinits + @assert isbits(var) + end if (nbatches > len) if (typeof(nbatches) !== typeof(len)) - return batch(f!, threadlocal, (len, len), args...) + return batch(f!, reducops, reducinits, needtid, (len, len), args...) end nbatches = len end @@ -218,19 +319,32 @@ end nthread = sum(nthreads) if nthread % Int32 ≤ zero(Int32) @label SERIAL - if thread_local - f!(args, one(Int), ulen % Int, 1) + if S + if C === 0 + reducres = f!(args, one(Int), ulen % Int, 1) + return reducres + else + reducres = f!(args, one(Int), ulen % Int, 1, reducinits) + return reducres + end else - f!(args, one(Int), ulen % Int) + if C === 0 + reducres = f!(args, one(Int), ulen % Int) + return reducres + else + reducres = f!(args, one(Int), ulen % Int, reducinits) + return reducres + end end - return end nbatch = nthread + one(nthread) Nd = Base.udiv_int(ulen, nbatch % UInt) # reasonable for `ulen` to be ≥ 2^32 Nr = (ulen - Nd * nbatch) % Int _batch_no_reserve( f!, - threadlocal, + needtid, + reducops, + reducinits, map(mask, threads), nthreads, torelease, @@ -244,7 +358,9 @@ function batch( f!::F, (len, nbatches, reserve_per_worker)::Tuple{Vararg{Union{StaticInt,Integer},3}}, args::Vararg{Any,K}; - threadlocal::Val{thread_local} = Val(false), -) where {F,K,thread_local} - batch(f!, threadlocal, (len, nbatches), args...) + needtid::Val{S} = Val(false), + reducops::Tuple{Vararg{Any,C}} = (), + reducinits::Tuple{Vararg{Any,C}} = (), +) where {F,K,C,S} + batch(f!, needtid, reducops, reducinits, (len, nbatches), args...) end diff --git a/src/closure.jl b/src/closure.jl index 6929ff2..2196815 100644 --- a/src/closure.jl +++ b/src/closure.jl @@ -225,7 +225,7 @@ function makestatic!(expr) end expr end -function enclose(exorig::Expr, minbatchsize, per::Symbol, threadlocal_tuple, stride, mod) +function enclose(exorig::Expr, minbatchsize, per, threadlocal, reduction, stride, mod) Meta.isexpr(exorig, :for, 2) || throw(ArgumentError("Expression invalid; should be a for loop.")) ex = copy(exorig) @@ -237,6 +237,7 @@ function enclose(exorig::Expr, minbatchsize, per::Symbol, threadlocal_tuple, str loop_offs = Symbol("##LOOPOFFSET##") innerloop = Symbol("##inner##loop##") rcombiner = Symbol("##split##recombined##") + reduction_op, reduction_var = reduction threadlocal_var = Symbol("threadlocal") #FIXME: don't do this? per = stride ? :thread : per @@ -244,6 +245,12 @@ function enclose(exorig::Expr, minbatchsize, per::Symbol, threadlocal_tuple, str arguments = Symbol[innerloop, rcombiner]#loop_offs, loop_step] defined = Dict{Symbol,Symbol}(loop_offs => loop_offs, loop_step => loop_step) threadlocal_var_gen = getgensym!(defined, threadlocal_var) + reduction_var_gen = Expr(:tuple) + if reduction_var !== Tuple{}() + for i ∈ eachindex(reduction_var) + push!(reduction_var_gen.args, getgensym!(defined, reduction_var[i])) + end + end define_induction_variables!(arguments, defined, ex, mod) firstloop = ex.args[1] if firstloop.head === :block @@ -340,39 +347,71 @@ function enclose(exorig::Expr, minbatchsize, per::Symbol, threadlocal_tuple, str push!(threadtup.args, :(min($il, $num_thread_expr))) end closure = Symbol("##closure##") - threadlocal, threadlocal_type = threadlocal_tuple - threadlocal_var_single = gensym(threadlocal_var) - q_single = symbolsubs(exorig, threadlocal_var, threadlocal_var_single) donothing = Expr(:block) + return_quote = Expr(:return) + # threadlocal stuff + threadlocal_var_single = gensym(threadlocal_var) + threadlocal_val, threadlocal_type = threadlocal + q_single = threadlocal_val === Symbol("") ? exorig : + symbolsubs(exorig, threadlocal_var, threadlocal_var_single) + # threadlocal_type = getfield(mod, threadlocal_type) + threadlocal_accum = Symbol("##THREADLOCAL##ACCUM##") threadlocal_init_single = - threadlocal === Symbol("") ? donothing : :($threadlocal_var_single = $threadlocal) - threadlocal_repack_single = - threadlocal === Symbol("") ? donothing : :($threadlocal_var_single) - threadlocal_single_store = - threadlocal === Symbol("") ? donothing : + threadlocal_val === Symbol("") ? donothing : + :($threadlocal_var_single = $threadlocal_val) + threadlocal_return_single = + threadlocal_val === Symbol("") ? donothing : :($threadlocal_var_single) + threadlocal_vect_single = + threadlocal_val === Symbol("") ? donothing : :($(esc(threadlocal_var)) = [single_thread_result]) - threadlocal_init1 = - threadlocal === Symbol("") ? donothing : - :($threadlocal_var = Vector{$threadlocal_type}(undef, 0)) - threadlocal_init2 = - threadlocal === Symbol("") ? donothing : - :(resize!($(esc(threadlocal_var)), max(1, $(threadtup.args[2])))) + threadlocal_init = + threadlocal_val === Symbol("") ? donothing : quote + $(esc(threadlocal_accum)) = + Vector{$threadlocal_type}(undef, max(1, $(threadtup.args[2]))) + end + threadlocal_vect = + threadlocal_val === Symbol("") ? donothing : + :($(esc(threadlocal_var)) = multi_thread_result) threadlocal_get = - threadlocal === Symbol("") ? donothing : - :($threadlocal_var_gen = $threadlocal::$threadlocal_type) + threadlocal_val === Symbol("") ? donothing : + :($threadlocal_var_gen::$threadlocal_type = $threadlocal_val) threadlocal_set = - threadlocal === Symbol("") ? donothing : - :($threadlocal_var[var"##THREAD##"] = $threadlocal_var_gen) - push!(q.args, threadlocal_init2) - args = Expr(:tuple, Symbol("##LOOPOFFSET##"), Symbol("##LOOP_STEP##")) - closure_args = if threadlocal !== Symbol("") || stride - :($args, var"##SUBSTART##"::Int, var"##SUBSTOP##"::Int, var"##THREAD##"::Int) + threadlocal_val === Symbol("") ? donothing : + :($threadlocal_accum[var"##THREAD##"] = $threadlocal_var_gen) + threadlocal_return = + threadlocal_val === Symbol("") ? donothing : :($threadlocal_accum) + threadlocal_val !== Symbol("") && push!(q.args, threadlocal_init) + # reduction stuff + reduction_ops = Expr(:tuple) + reduction_vars = Expr(:tuple) + reduction_inits = Expr(:tuple) + reduction_set = Expr(:block) + for i in eachindex(reduction_var) + op = getfield(Polyester, reduction_op[i]) + var = esc(reduction_var[i]) + init = :(initializer($op, $var)) + push!(reduction_ops.args, op) + push!(reduction_vars.args, var) + push!(reduction_inits.args, init) + push!(reduction_set.args, :($var = $op($var, reduction_final[$i]))) + end + reduction_init = + reduction_var === Tuple{}() ? donothing : + :($reduction_var_gen = var"##REDUCTION##INIT##") + if reduction_var !== Tuple{}() + push!(return_quote.args, reduction_var_gen) else - :($args, var"##SUBSTART##"::Int, var"##SUBSTOP##"::Int) + push!(return_quote.args, nothing) + end + + args = Expr(:tuple, Symbol("##LOOPOFFSET##"), Symbol("##LOOP_STEP##")) + closure_args = Expr(:tuple, args, :(var"##SUBSTART##"::Int), :(var"##SUBSTOP##"::Int)) + if threadlocal_val !== Symbol("") || stride + push!(closure_args.args, :(var"##THREAD##"::Int)) end + reduction_var !== Tuple{}() && push!(closure_args.args, Symbol("##REDUCTION##INIT##")) if stride # we are to do length(var"##SUBSTART##":var"##SUBSTOP##") iterations - # loop_start_expr = :(var"##THREAD##" * var"##LOOP_STEP##" + var"##LOOPOFFSET##" - var"##LOOP_STEP##") loop_stop_expr = :($loopstart + (var"##SUBSTOP##" - var"##SUBSTART##") * var"##STEP##") @@ -394,21 +433,24 @@ function enclose(exorig::Expr, minbatchsize, per::Symbol, threadlocal_tuple, str local $loop_stop = $loop_stop_expr # $(stride ? :(@show $loopstart, $loop_stop) : nothing) $threadlocal_get + $reduction_init @inbounds begin $excomb end $threadlocal_set - nothing + $return_quote end end end push!(q.args, esc(closureq)) - batchcall = if threadlocal !== Symbol("") || stride + batchcall = if threadlocal_val !== Symbol("") || stride Expr( :call, batch, esc(closure), Val(true), + reduction_ops, + reduction_inits, threadtup, Symbol("##LOOPOFFSET##"), Symbol("##LOOP_STEP##"), @@ -419,6 +461,8 @@ function enclose(exorig::Expr, minbatchsize, per::Symbol, threadlocal_tuple, str batch, esc(closure), Val(false), + reduction_ops, + reduction_inits, threadtup, Symbol("##LOOPOFFSET##"), Symbol("##LOOP_STEP##"), @@ -428,6 +472,10 @@ function enclose(exorig::Expr, minbatchsize, per::Symbol, threadlocal_tuple, str push!(args.args, get(defined, a, a)) push!(batchcall.args, esc(a)) end + if threadlocal_val !== Symbol("") + push!(args.args, threadlocal_accum) + push!(batchcall.args, esc(threadlocal_accum)) + end push!(q.args, batchcall) quote var"##NUM#THREADS##" = $(Threads.nthreads)() @@ -441,16 +489,18 @@ function enclose(exorig::Expr, minbatchsize, per::Symbol, threadlocal_tuple, str single_thread_result = begin $(esc(threadlocal_init_single)) # Initialize threadlocal storage $(esc(q_single)) - $(esc(threadlocal_repack_single)) + $(esc(threadlocal_return_single)) end - # Put the single-thread threadlocal storage in a single-element Vector - $threadlocal_single_store + $threadlocal_vect_single else - $(esc(threadlocal_init1)) - let - $q + multi_thread_result = let + reduction_final = $q + $reduction_set + $(esc(threadlocal_return)) end + $threadlocal_vect end + nothing end end @@ -461,16 +511,28 @@ Evaluate the loop on multiple threads. @batch minbatch=N for i in Iter; ...; end -Create a thread-local storage used in the loop. +Evaluate at least N iterations per thread. Will use at most `length(Iter) ÷ N` threads. @batch threadlocal=init() for i in Iter; ...; end +Create a thread-local storage used in the loop. + The `init` function will be called at the start at each thread. `threadlocal` will refer to storage local for the thread. At the end of the loop, a `threadlocal` vector containing all the thread-local values will be available. A type can be specified with `threadlocal=init()::Type`. -Evaluate at least N iterations per thread. Will use at most `length(Iter) ÷ N` threads. + @batch reduction=((op1, var1), (op2, var2), ...) for i in Iter; ...; end + +Perform OpenMP-esque reduction on the `isbits` variables `var1`, `var2`, `...` using the +operations `op1`, `op2`, `...` . The variables have to be initialized before the loop and +cannot be a fieldname like `x.y` or `x[i]`. +Supported operations are `+`, `*`, `min`, `max`, `&`, and `|`. The type does not have +to be provided, since it is already inferred from the initialized variables---**caution has +to be taken to ensure that the type remains consistent throughout the loop**. +While `threadlocal` can do the same thing, `reduction` does not incur additional allocations +and is generally more efficient for its purpose. It is up to the user to ensure that there +are no data dependencies between iterations, which could lead to incorrect results. @batch per=core for i in Iter; ...; end @batch per=thread for i in Iter; ...; end @@ -499,14 +561,17 @@ You can pass both `per=(core/thread)` and `minbatch=N` options at the same time, @batch stride=true for i in Iter; ...; end -This may be better for load balancing if iterations close to each other take a similar amount of time, but iterations far apart take different lengths of time. Setting this also forces `per=thread`. The default is `stride=false`. +This may be better for load balancing if iterations close to each other take a similar +amount of time, but iterations far apart take different lengths of time. Setting this also +forces `per=thread`. The default is `stride=false`. """ macro batch(ex) enclose( macroexpand(__module__, ex), 1, :unspecified, - (Symbol(""), :Any), + (Symbol(""), :Any), # threadlocal: var, type + (Tuple{}(), Tuple{}()), # reduction: ops, vars false, __module__, ) @@ -515,20 +580,35 @@ function interpret_kwarg( arg, minbatch = 1, per = :unspecified, - threadlocal = (Symbol(""), :Any), + threadlocal = (Symbol(""), :Any), # var, type + reduction = (Tuple{}(), Tuple{}()), # ops, vars stride = false, ) a = arg.args[1] v = arg.args[2] - if a === :reserve - @warn "reserve has been deprecated" - @assert v ≥ 0 - reserve_per = v - elseif a === :minbatch + if a === :minbatch minbatch = v elseif a === :per per = v::Symbol @assert (per === :core) | (per === :thread) + elseif a === :reduction + @assert Meta.isexpr(v, :tuple) && v.head == :tuple + if Meta.isexpr(v.args[1], :tuple, 2) + for red in v.args + @assert Meta.isexpr(red, :tuple, 2) && red.head == :tuple + end + reducops = ntuple(length(v.args)) do i + v.args[i].args[1] + end + @assert SUPPORTED_REDUCE_OPS ⊇ reducops "Unsupported reduction operation." + reducvars = ntuple(length(v.args)) do i + v.args[i].args[2] + end + @assert allunique(reducvars) + reduction = (reducops, reducvars) + else + reduction = ((v.args[1],), (v.args[2],)) + end elseif a === :threadlocal if Meta.isexpr(v, :(::), 2) && v.head == :(::) threadlocal = (v.args[1], v.args[2]) @@ -540,37 +620,90 @@ function interpret_kwarg( else throw(ArgumentError("kwarg $(a) not recognized.")) end - minbatch, per, threadlocal, stride + minbatch, per, threadlocal, reduction, stride end macro batch(arg1, ex) - minbatch, per, threadlocal, stride = interpret_kwarg(arg1) + minbatch, per, threadlocal, reduction, stride = interpret_kwarg(arg1) per = per === :unspecified ? (stride ? :thread : :core) : per - enclose(macroexpand(__module__, ex), minbatch, per, threadlocal, stride, __module__) + enclose( + macroexpand(__module__, ex), + minbatch, + per, + threadlocal, + reduction, + stride, + __module__ + ) end macro batch(arg1, arg2, ex) - minbatch, per, threadlocal, stride = interpret_kwarg(arg1) - minbatch, per, threadlocal, stride = - interpret_kwarg(arg2, minbatch, per, threadlocal, stride) + minbatch, per, threadlocal, reduction, stride = interpret_kwarg(arg1) + minbatch, per, threadlocal, reduction, stride = + interpret_kwarg(arg2, minbatch, per, threadlocal, reduction, stride) per = per === :unspecified ? (stride ? :thread : :core) : per - enclose(macroexpand(__module__, ex), minbatch, per, threadlocal, stride, __module__) + enclose( + macroexpand(__module__, ex), + minbatch, + per, + threadlocal, + reduction, + stride, + __module__ + ) end macro batch(arg1, arg2, arg3, ex) - minbatch, per, threadlocal, stride = interpret_kwarg(arg1) - minbatch, per, threadlocal, stride = - interpret_kwarg(arg2, minbatch, per, threadlocal, stride) - minbatch, per, threadlocal, stride = - interpret_kwarg(arg3, minbatch, per, threadlocal, stride) + minbatch, per, threadlocal, reduction, stride = interpret_kwarg(arg1) + minbatch, per, threadlocal, reduction, stride = + interpret_kwarg(arg2, minbatch, per, threadlocal, reduction, stride) + minbatch, per, threadlocal, reduction, stride = + interpret_kwarg(arg3, minbatch, per, threadlocal, reduction, stride) per = per === :unspecified ? (stride ? :thread : :core) : per - enclose(macroexpand(__module__, ex), minbatch, per, threadlocal, stride, __module__) + enclose( + macroexpand(__module__, ex), + minbatch, + per, + threadlocal, + reduction, + stride, + __module__ + ) end macro batch(arg1, arg2, arg3, arg4, ex) - minbatch, per, threadlocal, stride = interpret_kwarg(arg1) - minbatch, per, threadlocal, stride = - interpret_kwarg(arg2, minbatch, per, threadlocal, stride) - minbatch, per, threadlocal, stride = - interpret_kwarg(arg3, minbatch, per, threadlocal, stride) - minbatch, per, threadlocal, stride = - interpret_kwarg(arg3, minbatch, per, threadlocal, stride) + minbatch, per, threadlocal, reduction, stride = interpret_kwarg(arg1) + minbatch, per, threadlocal, reduction, stride = + interpret_kwarg(arg2, minbatch, per, threadlocal, reduction, stride) + minbatch, per, threadlocal, reduction, stride = + interpret_kwarg(arg3, minbatch, per, threadlocal, reduction, stride) + minbatch, per, threadlocal, reduction, stride = + interpret_kwarg(arg4, minbatch, per, threadlocal, reduction, stride) per = per === :unspecified ? (stride ? :thread : :core) : per - enclose(macroexpand(__module__, ex), minbatch, per, threadlocal, stride, __module__) + enclose( + macroexpand(__module__, ex), + minbatch, + per, + threadlocal, + reduction, + stride, + __module__ + ) +end +macro batch(arg1, arg2, arg3, arg4, arg5, ex) + minbatch, per, threadlocal, reduction, stride = interpret_kwarg(arg1) + minbatch, per, threadlocal, reduction, stride = + interpret_kwarg(arg2, minbatch, per, threadlocal, reduction, stride) + minbatch, per, threadlocal, reduction, stride = + interpret_kwarg(arg3, minbatch, per, threadlocal, reduction, stride) + minbatch, per, threadlocal, reduction, stride = + interpret_kwarg(arg4, minbatch, per, threadlocal, reduction, stride) + minbatch, per, threadlocal, reduction, stride = + interpret_kwarg(arg5, minbatch, per, threadlocal, reduction, stride) + per = per === :unspecified ? (stride ? :thread : :core) : per + enclose( + macroexpand(__module__, ex), + minbatch, + per, + threadlocal, + reduction, + stride, + __module__ + ) end diff --git a/test/runtests.jl b/test/runtests.jl index 4f077dd..40f0eff 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -297,7 +297,7 @@ Base.eachindex(e::Iterators.Enumerate{LazyTree{T}}) where {T} = eachindex(e.itr) evt = 5 end -@testset "local thread storate" begin +@testset "threadlocal storage" begin local1 = let @batch threadlocal = 0 for i = 0:9 threadlocal += 1 @@ -376,6 +376,120 @@ end @test allocated(f) < 300 + 40 * Threads.nthreads() end +@testset "reduction" begin + local1 = let + red = 0 + @batch reduction = (+, red) for i = 0:9 + red += 1 + end + red + end + local2 = let + red = 0 + @batch minbatch = 5 reduction = (+, red) for i = 0:9 + red += 1 + end + red + end + local3 = let + red = 0 + @batch per = core reduction = (+, red) for i = 0:9 + red += 1 + end + red + end + local4 = let + red = 0 + @batch per = core minbatch = 100 reduction = (+, red) for i = 0:9 + red += 1 + end + red + end + local5 = let + red = 0 + @batch minbatch = 100 stride = true reduction = (+, red) for i = 0:9 + red += 1 + end + red + end + myinitA() = 0 + local6 = let + red = myinitA() + @batch reduction = (+, red) for i = 0:9 + red += 1 + end + red + end + local7, local8 = let + red = 0 + @batch minbatch = 100 stride = true reduction = (+,red) threadlocal = red for i = 0:9 + red += 1 + threadlocal += 1 + end + red, threadlocal[1] + end + localsr = let # stride + reduction w/o minbatch + red = 0 + @batch stride = true reduction = (+, red) for i = 0:9 + red += 1 + end + red + end + @test local1==local2==local3==local4==local5==local6==local7==local8==localsr + # check different operations + local9 = let + red = 1.0 + @batch reduction = (*,red) for i = 1:100 + red *= 4i^2 / (4i^2 - 1) + end + 2red + end + @test local9 ≈ 2prod(4i^2 / (4i^2 - 1) for i = 1:100) + # multiple reductions + local10, local11, local12 = let + red1 = 0 + red2 = 0 + red3 = 0 + @batch reduction = ((+,red1), (+,red2), (+,red3)) for i = 0:9 + red1 += 1 + red2 += 1 + red3 -= 1 + end + red1, red2, red3 + end + @test local10 == local11 == -local12 + # check for name interference with threadlocal (used to error on single threaded runs) + function f() + n = 1000 + threadlocal = false + @batch minbatch = 10 reduction = (+,threadlocal) for i = 1:n + threadlocal += true + end + return threadlocal + end + allocated(f::F) where {F} = @allocated f() + inferred(f::F) where {F} = try @inferred f(); true catch; false end + @test allocated(f) == 0 + @test inferred(f) == true + # remaining supported operations + arr = rand(10) + local13, local14, local15, local16 = let arr = arr + red1 = true + red2 = false + red3 = typemax(eltype(arr)) + red4 = typemin(eltype(arr)) + @batch reduction = ((&,red1), (|,red2), (min,red3), (max,red4)) for x in arr + red1 &= x > 0.5 + red2 |= x > 0.5 + red3 = min(red3, x) + red4 = max(red4, x) + end + red1, red2, red3, red4 + end + @test (local13, local14, local15, local16) == + (mapreduce(x->x>0.5, &, arr), mapreduce(x->x>0.5, |, arr), minimum(arr), maximum(arr)) +end + @testset "locks and refvalues" begin a = Ref(0.0) l = Threads.SpinLock() @@ -596,5 +710,5 @@ end if VERSION ≥ v"1.6" println("Package tests complete. Running `Aqua` checks.") - Aqua.test_all(Polyester) + Aqua.test_all(Polyester; deps_compat = (check_extras=false,)) end