Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

1.11: more methodinstance stuff #1989

Merged
merged 7 commits into from
Oct 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ BFloat16s = "0.2, 0.3, 0.4, 0.5"
CEnum = "0.4, 0.5"
ChainRulesCore = "1"
EnzymeCore = "0.8.4"
Enzyme_jll = "0.0.155"
Enzyme_jll = "0.0.156"
GPUCompiler = "0.21, 0.22, 0.23, 0.24, 0.25, 0.26, 0.27, 1"
LLVM = "6.1, 7, 8, 9"
LogExpFunctions = "0.3"
Expand Down
24 changes: 0 additions & 24 deletions src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3816,30 +3816,6 @@ function enzyme!(
LLVM.API.LLVMValueRef,
)
),
"julia.gc_loaded" => @cfunction(
inoutgcloaded_rule,
UInt8,
(
Cint,
API.CTypeTreeRef,
Ptr{API.CTypeTreeRef},
Ptr{API.IntList},
Csize_t,
LLVM.API.LLVMValueRef,
)
),
"julia.pointer_from_objref" => @cfunction(
inout_rule,
UInt8,
(
Cint,
API.CTypeTreeRef,
Ptr{API.CTypeTreeRef},
Ptr{API.IntList},
Csize_t,
LLVM.API.LLVMValueRef,
)
),
"jl_inactive_inout" => @cfunction(
inout_rule,
UInt8,
Expand Down
19 changes: 9 additions & 10 deletions src/rules/llvmrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -628,13 +628,11 @@ function arraycopy_common(fwd, B, orig, shadowsrc, gutils, shadowdst; len=nothin

if memory
if fwd
shadowsrc = inttoptr!(B, memoryptr, LLVM.PointerType(LLVM.IntType(8)))
lookup_src = false
shadowsrc = invert_pointer(gutils, memoryptr, B)
else
shadowsrc = invert_pointer(gutils, shadowsrc, B)
if !fwd
shadowsrc = lookup_value(gutils, shadowsrc, B)
end
shadowsrc = invert_pointer(gutils, shadowsrc, B)
shadowsrc = lookup_value(gutils, shadowsrc, B)
end
else
shadowsrc = invert_pointer(gutils, shadowsrc, B)
Expand Down Expand Up @@ -674,12 +672,13 @@ function arraycopy_common(fwd, B, orig, shadowsrc, gutils, shadowdst; len=nothin
# src already has done the lookup from the argument
shadowsrc0 = if lookup_src
if memory
# TODO this may not be at the same offset as the start of the copy, e.g. get_memory_data(src) != memoryptr
get_memory_data(B, evsrc)
else
get_array_data(B, evsrc)
end
else
evsrc
inttoptr!(B, evsrc, LLVM.PointerType(LLVM.IntType(8)))
end

shadowdst0 = if memory
Expand Down Expand Up @@ -781,7 +780,7 @@ end
false,
) #=lookup=#
if is_constant_value(gutils, origops[1])
elSize = get_array_elsz(B, ev)
elSize = get_memory_elsz(B, ev)
elSize = LLVM.zext!(B, elSize, LLVM.IntType(8 * sizeof(Csize_t)))
length = LLVM.mul!(B, len, elSize)
bt = GPUCompiler.backtrace(orig)
Expand All @@ -792,7 +791,7 @@ end
GPUCompiler.@safe_warn "TODO forward zero-set of memorycopy used memset rather than runtime type $btstr"
LLVM.memset!(
B,
ev2,
inttoptr!(B, ev2, LLVM.PointerType(LLVM.IntType(8))),
LLVM.ConstantInt(i8, 0, false),
length,
algn,
Expand Down Expand Up @@ -838,7 +837,7 @@ end
shadowres = LLVM.Value(unsafe_load(shadowR))

len = new_from_original(gutils, origops[3])
memoryptr = new_from_original(gutils, origops[2])
memoryptr = origops[2]
arraycopy_common(true, B, orig, origops[1], gutils, shadowres; len, memoryptr)
end

Expand All @@ -849,7 +848,7 @@ end
origops = LLVM.operands(orig)
if !is_constant_value(gutils, origops[1]) && !is_constant_value(gutils, orig)
len = new_from_original(gutils, origops[3])
memoryptr = new_from_original(gutils, origops[2])
memoryptr = origops[2]
arraycopy_common(false, B, orig, origops[1], gutils, nothing; len, memoryptr)
end

Expand Down
40 changes: 0 additions & 40 deletions src/rules/typerules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,43 +92,3 @@ function inoutcopyslice_rule(
end
return UInt8(false)
end

function inoutgcloaded_rule(
direction::Cint,
ret::API.CTypeTreeRef,
args::Ptr{API.CTypeTreeRef},
known_values::Ptr{API.IntList},
numArgs::Csize_t,
val::LLVM.API.LLVMValueRef,
)::UInt8
if numArgs != 1
return UInt8(false)
end
inst = LLVM.Instruction(val)

legal, typ = abs_typeof(inst)

if legal
if (direction & API.DOWN) != 0
ctx = LLVM.context(inst)
dl = string(LLVM.datalayout(LLVM.parent(LLVM.parent(LLVM.parent(inst)))))
if GPUCompiler.deserves_retbox(typ)
typ = Ptr{typ}
end
rest = typetree(typ, ctx, dl)
changed, legal = API.EnzymeCheckedMergeTypeTree(ret, rest)
@assert legal
end
return UInt8(false)
end

if (direction & API.UP) != 0
changed, legal = API.EnzymeCheckedMergeTypeTree(unsafe_load(args, 2), ret)
@assert legal
end
if (direction & API.DOWN) != 0
changed, legal = API.EnzymeCheckedMergeTypeTree(ret, unsafe_load(args, 2))
@assert legal
end
return UInt8(false)
end
75 changes: 62 additions & 13 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,24 +253,73 @@ export codegen_world_age

if VERSION >= v"1.11.0-DEV.1552"


const prevmethodinstance = GPUCompiler.generic_methodinstance

function methodinstance_generator(world::UInt, source, self, ft::Type, tt::Type)
@nospecialize
@assert Core.Compiler.isType(ft) && Core.Compiler.isType(tt)
ft = ft.parameters[1]
tt = tt.parameters[1]

stub = Core.GeneratedFunctionStub(identity, Core.svec(:methodinstance, :ft, :tt), Core.svec())

# look up the method match
method_error = :(throw(MethodError(ft, tt, $world)))
sig = Tuple{ft, tt.parameters...}
min_world = Ref{UInt}(typemin(UInt))
max_world = Ref{UInt}(typemax(UInt))
match = ccall(:jl_gf_invoke_lookup_worlds, Any,
(Any, Any, Csize_t, Ref{Csize_t}, Ref{Csize_t}),
sig, #=mt=# nothing, world, min_world, max_world)
match === nothing && return stub(world, source, method_error)

# look up the method and code instance
mi = ccall(:jl_specializations_get_linfo, Ref{MethodInstance},
(Any, Any, Any), match.method, match.spec_types, match.sparams)
ci = Core.Compiler.retrieve_code_info(mi, world)

# prepare a new code info
new_ci = copy(ci)
empty!(new_ci.code)
empty!(new_ci.codelocs)
empty!(new_ci.linetable)
empty!(new_ci.ssaflags)
new_ci.ssavaluetypes = 0

# propagate edge metadata
new_ci.min_world = min_world[]
new_ci.max_world = max_world[]
new_ci.edges = MethodInstance[mi]

# prepare the slots
new_ci.slotnames = Symbol[Symbol("#self#"), :ft, :tt]
new_ci.slotflags = UInt8[0x00 for i = 1:3]

# return the method instance
push!(new_ci.code, Core.Compiler.ReturnNode(mi))
push!(new_ci.ssaflags, 0x00)
push!(new_ci.linetable, GPUCompiler.@LineInfoNode(methodinstance))
push!(new_ci.codelocs, 1)
new_ci.ssavaluetypes += 1

return new_ci
end

@eval function prevmethodinstance(ft, tt)
$(Expr(:meta, :generated_only))
$(Expr(:meta, :generated, methodinstance_generator))
end

# XXX: version of Base.method_instance that uses a function type
@inline function my_methodinstance(@nospecialize(ft::Type), @nospecialize(tt::Type),
world::Integer=tls_world_age())
sig = GPUCompiler.signature_type_by_tt(ft, tt)
# @assert Base.isdispatchtuple(sig) # JuliaLang/julia#52233

mi = ccall(:jl_method_lookup_by_tt, Any,
(Any, Csize_t, Any),
sig, world, #=method_table=# nothing)
mi === nothing && throw(MethodError(ft, tt, world))
mi = mi::MethodInstance

# `jl_method_lookup_by_tt` and `jl_method_lookup` can return a unspecialized mi
if !Base.isdispatchtuple(mi.specTypes)
mi = Core.Compiler.specialize_method(mi.def, sig, mi.sparam_vals)::MethodInstance
if Base.isdispatchtuple(sig) # JuliaLang/julia#52233
return GPUCompiler.methodinstance(ft, tt, world)
else
return prevmethodinstance(ft, tt, world)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You just want to call mi = Core.Compiler.specialize_method(mi.def, sig, mi.sparam_vals)::MethodInstance in this case, but you need a call to which to obtain def

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Edit: Maybe? This is all pretty much unsound.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I mean happy to defer that to the experts, it's now just falling back to the old implementation

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Old from 1.10 in gpucompiler*** while concurrently I opened a PR to gpucompiler jl to actually fix and not have us have to vendor it

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

end

return mi
end
else
import GPUCompiler: methodinstance as my_methodinstance
Expand Down
4 changes: 2 additions & 2 deletions test/internal_rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@ module InternalRules

using Enzyme
using Enzyme.EnzymeRules
using EnzymeTestUtils
using FiniteDifferences
using LinearAlgebra
using SparseArrays
using Test
Expand Down Expand Up @@ -155,6 +153,7 @@ function tr_solv(A, B, uplo, trans, diag, idx)
end


using FiniteDifferences
@testset "Reverse triangular solve" begin
A = [0.7550523937508613 0.7979976952197996 0.29318222271218364; 0.4416768066117529 0.4335305304334933 0.8895389673238051; 0.07752980210005678 0.05978245503334367 0.4504482683752542]
B = [0.10527381151977078 0.5450388247476627 0.3179106723232359 0.43919576779182357 0.20974326586875847; 0.7551160501548224 0.049772782182839426 0.09284926395551141 0.07862188927391855 0.17346407477062986; 0.6258040138863172 0.5928022963567454 0.24251650865340169 0.6626410383247967 0.32752198021506784]
Expand Down Expand Up @@ -576,6 +575,7 @@ end
@test Enzyme.gradient(Reverse, chol_upper, x)[1] ≈ [0.05270807565639728, 0.0, 0.0, 0.0, 0.9999999999999999, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]
end

using EnzymeTestUtils
@testset "Linear solve for triangular matrices" begin
@testset for T in (UpperTriangular, LowerTriangular, UnitUpperTriangular, UnitLowerTriangular),
TE in (Float64, ComplexF64), sizeB in ((3,), (3, 3))
Expand Down
Loading