Skip to content

Commit

Permalink
Mock Enzyme plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
vchuravy committed Oct 2, 2024
1 parent 052a118 commit e00cfb7
Show file tree
Hide file tree
Showing 2 changed files with 161 additions and 2 deletions.
142 changes: 140 additions & 2 deletions test/plugin_testsetup.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ struct NeverInlineMeta <: InlineStateMeta end
import GPUCompiler: abstract_call_known, GPUInterpreter
import Core.Compiler: CallMeta, Effects, NoCallInfo, ArgInfo,
StmtInfo, AbsIntState, EFFECTS_TOTAL,
MethodResultPure
MethodResultPure, CallInfo, IRCode

function abstract_call_known(meta::InlineStateMeta, interp::GPUInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
Expand Down Expand Up @@ -69,5 +69,143 @@ function inlining_handler(meta::InlineStateMeta, interp::GPUInterpreter, @nospec
return nothing
end

struct MockEnzymeMeta end

end
# Having to define this function is annoying
# introduce `abstract type InferenceMeta`
function inlining_handler(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(atype), callinfo)
return nothing
end

function autodiff end

import GPUCompiler: DeferredCallInfo
struct AutodiffCallInfo <: CallInfo
rt
info::DeferredCallInfo
end

function abstract_call_known(meta::MockEnzymeMeta, interp::GPUInterpreter, @nospecialize(f),
arginfo::ArgInfo, si::StmtInfo, sv::AbsIntState, max_methods::Int)
(; fargs, argtypes) = arginfo

if f === autodiff
if length(argtypes) <= 1
@static if VERSION < v"1.11.0-"
return CallMeta(Union{}, Effects(), NoCallInfo())
else
return CallMeta(Union{}, Union{}, Effects(), NoCallInfo())
end
end

other_fargs = fargs === nothing ? nothing : fargs[2:end]
other_arginfo = ArgInfo(other_fargs, argtypes[2:end])
call = Core.Compiler.abstract_call(interp, other_arginfo, si, sv, max_methods)
callinfo = DeferredCallInfo(MockEnzymeMeta(), call.rt, call.info)

# Real Enzyme must compute `rt` and `exct` according to enzyme semantics
# and likely perform a unwrapping of fargs...
rt = call.rt

# TODO: Edges? Effects?
@static if VERSION < v"1.11.0-"
# Can't use call.effects since otherwise this call might be just replaced with rt
return CallMeta(rt, Effects(), AutodiffCallInfo(rt, callinfo))
else
return CallMeta(rt, call.exct, Effects(), AutodiffCallInfo(rt, callinfo))
end
end

return nothing
end

import Core.Compiler: insert_node!, NewInstruction, ReturnNode, Instruction, InliningState, Signature

# We really need a Compiler stdlib
Base.getindex(ir::IRCode, i) = Core.Compiler.getindex(ir, i)
Base.setindex!(inst::Instruction, val, i) = Core.Compiler.setindex!(inst, val, i)

const FlagType = VERSION >= v"1.11.0-" ? UInt32 : UInt8
function Core.Compiler.handle_call!(todo::Vector{Pair{Int,Any}}, ir::IRCode, stmt_idx::Int,
stmt::Expr, info::AutodiffCallInfo, flag::FlagType,
sig::Signature, state::InliningState)

# Goal:
# The IR we want to inline here is:
# unpack the args ..
# ptr = gpuc.deferred(MockEnzymeMeta(), f, primal_args...)
# ret = ccall("extern __autodiff", llvmcall, RT, Tuple{Ptr{Cvoid, args...}}, ptr, adjoint_args...)

# 0. Obtain primal mi from DeferredCallInfo
# TODO: remove this code duplication
deferred_info = info.info
minfo = deferred_info.info
results = minfo.results
if length(results.matches) != 1
return nothing
end
match = only(results.matches)

# lookup the target mi with correct edge tracking
# TODO: Effects?
case = Core.Compiler.compileable_specialization(
match, Core.Compiler.Effects(), Core.Compiler.InliningEdgeTracker(state), info)
@assert case isa Core.Compiler.InvokeCase
@assert stmt.head === :call

# Now create the IR we want to inline
ir = Core.Compiler.IRCode() # contains a placeholder
args = [Core.Compiler.Argument(i) for i in 3:length(stmt.args)]
idx = 0

# 0. Enzyme proper: Desugar args
primal_args = args
primal_argtypes = match.spec_types.parameters[2:end]

adjoint_rt = info.rt
adjoint_args = args # TODO
adjoint_argtypes = primal_argtypes

# 1: Since Julia's inliner goes bottom up we need to pretend that we inlined the deferred call
expr = Expr(:foreigncall,
"extern gpuc.lookup",
Ptr{Cvoid},
Core.svec(Any, Any, Any, primal_argtypes...), # Must use Any for MethodInstance or ftype
0,
QuoteNode(:llvmcall),
deferred_info.meta,
case.invoke,
primal_args...
)
ptr = insert_node!(ir, (idx += 1), NewInstruction(expr, Ptr{Cvoid}))

# 2. Call to magic `__autodiff`
expr = Expr(:foreigncall,
"extern __autodiff",
adjoint_rt,
Core.svec(Any, Ptr{Cvoid}, adjoint_argtypes...),
0,
QuoteNode(:llvmcall),
ptr,
adjoint_args...
)
ret = insert_node!(ir, idx, NewInstruction(expr, adjoint_rt))

# Finally replace placeholder return
ir[Core.SSAValue(1)][:inst] = Core.ReturnNode(ret)
ir[Core.SSAValue(1)][:type] = Ptr{Cvoid}

ir = Core.Compiler.compact!(ir)

# which mi to use here?
# push inlining todos
# TODO: Effects
# aviatesk mentioned using inlining_policy instead...
itodo = Core.Compiler.InliningTodo(case.invoke, ir, Core.Compiler.Effects())
@assert itodo.linear_inline_eligible
push!(todo, (stmt_idx=>itodo))

return nothing
end

end #module
21 changes: 21 additions & 0 deletions test/ptx_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -504,4 +504,25 @@ end
ir = sprint(io->PTX.code_llvm(io, kernel_inline, Tuple{Ptr{Int64}, Int64}, meta=Plugin.NeverInlineMeta()))
@test occursin("call fastcc i64 @julia_inline", ir)
end

@testset "Mock Enzyme" begin
function f(x)
x^2
end

function kernel(a, x)
y = Plugin.autodiff(f, x)
unsafe_store!(a, y)
nothing
end


@show PTX.code_typed(kernel, Tuple{Ptr{Float64}, Float64}, meta=Plugin.MockEnzymeMeta())

# FIXME: the fact that meta is necessary here almost invalidates that extension mechanism
# we somehow need to be able to add this kind of "autodiff" abs int handling.
ir = sprint(io->PTX.code_llvm(io, kernel, Tuple{Ptr{Float64}, Float64}, meta=Plugin.MockEnzymeMeta()))
@test occursin("call double @__autodiff", ir)
end

end #testitem

0 comments on commit e00cfb7

Please sign in to comment.