From 13523ac5e5036a0ddc8edb65e752e252bd725774 Mon Sep 17 00:00:00 2001 From: Dennis Yatunin Date: Tue, 10 Sep 2024 11:04:05 -0700 Subject: [PATCH] Fix inlining of generated functions --- src/UnrolledUtilities.jl | 10 ++-- src/generatively_unrolled_functions.jl | 78 +++++++++++++++++--------- test/test_and_analyze.jl | 12 ++-- 3 files changed, 60 insertions(+), 40 deletions(-) diff --git a/src/UnrolledUtilities.jl b/src/UnrolledUtilities.jl index 95f1b83..56ba671 100644 --- a/src/UnrolledUtilities.jl +++ b/src/UnrolledUtilities.jl @@ -56,10 +56,10 @@ include("generatively_unrolled_functions.jl") error("unrolled_applyat has detected an out-of-bounds index") @inline unrolled_reduce(op, itr, init) = - (rec_unroll(itr) ? rec_unrolled_reduce : gen_unrolled_reduce)(op, itr, init) -@inline unrolled_reduce(op, itr; init = NoInit()) = isempty(itr) && init isa NoInit ? error("unrolled_reduce requires an init value for empty iterators") : + (rec_unroll(itr) ? rec_unrolled_reduce : gen_unrolled_reduce)(op, itr, init) +@inline unrolled_reduce(op, itr; init = NoInit()) = unrolled_reduce(op, itr, init) # TODO: Figure out why unrolled_reduce(op, Val(N), init) compiles faster than @@ -67,11 +67,9 @@ include("generatively_unrolled_functions.jl") # parametrization test in ClimaAtmos, to the point where the StaticOneTo version # completely hangs while the Val version compiles in only a few seconds. @inline unrolled_reduce(op, val_N::Val, init) = - val_unrolled_reduce(op, val_N, init) -@inline unrolled_reduce(op, val_N::Val; init = NoInit()) = val_N isa Val{0} && init isa NoInit ? - error("unrolled_reduce requires an init value for empty iterators") : - unrolled_reduce(op, val_N, init) + error("unrolled_reduce requires an init value for Val(0)") : + val_unrolled_reduce(op, val_N, init) @inline unrolled_mapreduce(f, op, itrs...; init = NoInit()) = unrolled_reduce(op, Iterators.map(f, itrs...), init) diff --git a/src/generatively_unrolled_functions.jl b/src/generatively_unrolled_functions.jl index 53c1c85..45cb9fd 100644 --- a/src/generatively_unrolled_functions.jl +++ b/src/generatively_unrolled_functions.jl @@ -1,39 +1,55 @@ -@inline @generated _gen_unrolled_any(::Val{N}, f, itr) where {N} = - Expr(:||, (:(f(generic_getindex(itr, $n))) for n in 1:N)...) +@generated _gen_unrolled_any(::Val{N}, f, itr) where {N} = Expr( + :block, + Expr(:meta, :inline), + Expr(:||, (:(f(generic_getindex(itr, $n))) for n in 1:N)...), +) @inline gen_unrolled_any(f, itr) = _gen_unrolled_any(Val(length(itr)), f, itr) -@inline @generated _gen_unrolled_all(::Val{N}, f, itr) where {N} = - Expr(:&&, (:(f(generic_getindex(itr, $n))) for n in 1:N)...) +@generated _gen_unrolled_all(::Val{N}, f, itr) where {N} = Expr( + :block, + Expr(:meta, :inline), + Expr(:&&, (:(f(generic_getindex(itr, $n))) for n in 1:N)...), +) @inline gen_unrolled_all(f, itr) = _gen_unrolled_all(Val(length(itr)), f, itr) -@inline @generated _gen_unrolled_foreach(::Val{N}, f, itr) where {N} = - Expr(:block, (:(f(generic_getindex(itr, $n))) for n in 1:N)..., nothing) +@generated _gen_unrolled_foreach(::Val{N}, f, itr) where {N} = Expr( + :block, + Expr(:meta, :inline), + (:(f(generic_getindex(itr, $n))) for n in 1:N)..., + nothing, +) @inline gen_unrolled_foreach(f, itr) = _gen_unrolled_foreach(Val(length(itr)), f, itr) -@inline @generated _gen_unrolled_map(::Val{N}, f, itr) where {N} = - Expr(:tuple, (:(f(generic_getindex(itr, $n))) for n in 1:N)...) +@generated _gen_unrolled_map(::Val{N}, f, itr) where {N} = Expr( + :block, + Expr(:meta, :inline), + Expr(:tuple, (:(f(generic_getindex(itr, $n))) for n in 1:N)...), +) @inline gen_unrolled_map(f, itr) = _gen_unrolled_map(Val(length(itr)), f, itr) -@inline @generated _gen_unrolled_applyat(::Val{N}, f, n, itr) where {N} = Expr( +@generated _gen_unrolled_applyat(::Val{N}, f, n, itr) where {N} = Expr( :block, + Expr(:meta, :inline), (:(n == $n && return f(generic_getindex(itr, $n))) for n in 1:N)..., :(unrolled_applyat_bounds_error()), ) # This block gets optimized into a switch instruction during LLVM codegen. @inline gen_unrolled_applyat(f, n, itr) = _gen_unrolled_applyat(Val(length(itr)), f, n, itr) -@inline @generated _gen_unrolled_reduce(::Val{N}, op, itr, init) where {N} = +@generated _gen_unrolled_reduce(::Val{N}, op, itr, init) where {N} = Expr( + :block, + Expr(:meta, :inline), foldl( - init <: NoInit ? (2:N) : (1:N); + (op_expr, n) -> :(op($op_expr, generic_getindex(itr, $n))), + (init <: NoInit ? 2 : 1):N; init = init <: NoInit ? :(generic_getindex(itr, 1)) : :init, - ) do prev_op_expr, n - :(op($prev_op_expr, generic_getindex(itr, $n))) - end # Use foldl instead of reduce to guarantee left associativity. + ), # Use foldl instead of reduce to guarantee left associativity. +) @inline gen_unrolled_reduce(op, itr, init) = _gen_unrolled_reduce(Val(length(itr)), op, itr, init) -@inline @generated function _gen_unrolled_accumulate( +@generated function _gen_unrolled_accumulate( ::Val{N}, op, itr, @@ -43,12 +59,16 @@ first_item_expr = :(generic_getindex(itr, 1)) init_expr = init <: NoInit ? first_item_expr : :(op(init, $first_item_expr)) transformed_exprs_and_op_exprs = - accumulate(1:N; init = (nothing, init_expr)) do (_, prev_op_expr), n + accumulate(1:N; init = (nothing, init_expr)) do (_, op_expr), n var = gensym() - op_expr = :(op($var, generic_getindex(itr, $(n + 1)))) - (:($var = $prev_op_expr; transform($var)), op_expr) + next_op_expr = :(op($var, generic_getindex(itr, $(n + 1)))) + (:($var = $op_expr; transform($var)), next_op_expr) end - return Expr(:tuple, Iterators.map(first, transformed_exprs_and_op_exprs)...) + return Expr( + :block, + Expr(:meta, :inline), + Expr(:tuple, Iterators.map(first, transformed_exprs_and_op_exprs)...), + ) end @inline gen_unrolled_accumulate(op, itr, init, transform) = _gen_unrolled_accumulate(Val(length(itr)), op, itr, init, transform) @@ -56,16 +76,18 @@ end # TODO: The following is experimental and will likely be removed in the future. # For some reason, combining these two methods into one (or combining them with # the method for gen_unrolled_reduce defined above) causes compilation of the -# non-orographic gravity wave parametrization test in ClimaAtmos to hang. Even -# more bizarrely, using the assignment form of the first method definition below -# (as opposed to the function syntax used here) causes compilation to hang as -# well. This behavior has not yet been replicated in a minimal working example. -@inline @generated function val_unrolled_reduce(op, ::Val{N}, init) where {N} +# non-orographic gravity wave parametrization test in ClimaAtmos to hang. +# Wrapping the first method's result in a block and adding an inline annotation +# also causes compilation to hang. Even using the assignment form of the first +# method definition below (as opposed to the function syntax used here) causes +# it to hang. This has not yet been replicated in a minimal working example. +@generated function val_unrolled_reduce(op, ::Val{N}, init) where {N} return foldl((:init, 1:N...)) do prev_op_expr, item_expr :(op($prev_op_expr, $item_expr)) end end -@inline @generated val_unrolled_reduce(op, ::Val{N}, ::NoInit) where {N} = - foldl(1:N) do prev_op_expr, item_expr - :(op($prev_op_expr, $item_expr)) - end +@generated val_unrolled_reduce(op, ::Val{N}, ::NoInit) where {N} = Expr( + :block, + Expr(:meta, :inline), + foldl((op_expr, item_expr) -> :(op($op_expr, $item_expr)), 1:N), +) diff --git a/test/test_and_analyze.jl b/test/test_and_analyze.jl index b1beb7d..400746e 100644 --- a/test/test_and_analyze.jl +++ b/test/test_and_analyze.jl @@ -877,7 +877,7 @@ title = "Very Long Iterators" comparison_table_dict = (comparison_table_dicts[title] = OrderedDict()) @testset "unrolled functions of Tuples vs. StaticOneTos" begin - for itr in (ntuple(identity, 2000), StaticOneTo(2000), StaticOneTo(8186)) + for itr in (ntuple(identity, 2000), StaticOneTo(2000), StaticOneTo(8185)) @test_unrolled (itr,) unrolled_reduce(+, itr) reduce(+, itr) "Ints" @test_unrolled( (itr,), @@ -885,15 +885,15 @@ comparison_table_dict = (comparison_table_dicts[title] = OrderedDict()) mapreduce(log, +, itr), "Ints", ) - end # These can each take 40 seconds to compile for ntuple(identity, 8186). - for itr in (ntuple(identity, 8187), StaticOneTo(8187)) + end # These can each take 40 seconds to compile for ntuple(identity, 8185). + for itr in (ntuple(identity, 8186), StaticOneTo(8186)) @test_throws "gc handles" unrolled_reduce(+, itr) @test_throws "gc handles" unrolled_mapreduce(log, +, itr) end # TODO: Why does the compiler throw an error when generating functions that - # get unrolled into more than 8186 lines of LLVM code? + # get unrolled into more than 8185 lines of LLVM code? - for itr in (StaticOneTo(8186), StaticOneTo(8187)) + for itr in (StaticOneTo(8185), StaticOneTo(8186)) @test_unrolled( (itr,), unrolled_reduce(+, Val(length(itr))), @@ -902,7 +902,7 @@ comparison_table_dict = (comparison_table_dicts[title] = OrderedDict()) ) end @test_throws "gc handles" unrolled_reduce(+, Val(8188)) - # TODO: Why is the limit 8187 for the Val version of unrolled_reduce? + # TODO: Why is the limit 8186 for the Val version of unrolled_reduce? end title = "Generative vs. Recursive Unrolling"