Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Hack in metadata for MT3 #242

Draft
wants to merge 3 commits into
base: ale/3.0
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions docs/src/egraphs.md
Original file line number Diff line number Diff line change
Expand Up @@ -269,7 +269,7 @@ The `EGraph{E,A}` type is parametrized by the expression type `E` and the

The following functions define an interface for analyses based on multiple dispatch:

* [make(g::EGraph{ExprType, AnalysisType}, n)](@ref) should take an e-node `n::VecExpr` and return a value from the analysis domain.
* [make(g::EGraph{ExprType, AnalysisType}, n, md)](@ref) should take an e-node `n::VecExpr`, and metadata `md` from an expression (possibly `noting`), and return a value from the analysis domain.
* [join(x::AnalysisType, y::AnalysisType)](@ref) should return the semilattice join of `x` and `y` in the analysis domain (e.g. *given two analyses value from ENodes in the same EClass, which one should I choose?* or *how should they be merged?*).`Base.isless` must be defined.
* [modify!(g::EGraph{ExprType, AnalysisType}, eclass::EClass{AnalysisType})](@ref) Can be optionally implemented. This can be used modify an EClass `egraph[eclass.id]` on-the-fly during an e-graph saturation iteration, given its analysis value, typically by adding an e-node.

Expand Down Expand Up @@ -325,7 +325,9 @@ From the definition of an e-node, we know that children of e-nodes are always ID
to e-classes in the `EGraph`.

```@example custom_analysis
function EGraphs.make(g::EGraph{ExpressionType,OddEvenAnalysis}, op, n::VecExpr) where {ExpressionType}
function EGraphs.make(g::EGraph{ExpressionType,OddEvenAnalysis}, op, n::VecExpr, md) where {ExpressionType}
# metadata `md` is not used in this instance.

v_isexpr(n) || return odd_even_base_case(op)
# The e-node is not a literal value,
# Let's consider only binary function call terms.
Expand Down
20 changes: 11 additions & 9 deletions src/EGraphs/egraph.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,9 @@ when two eclasses are being merged or the analysis is being constructed.
function join end

"""
make(g::EGraph{ExpressionType, AnalysisType}, n::VecExpr)::AnalysisType where {ExpressionType}
make(g::EGraph{ExpressionType, AnalysisType}, n::VecExpr, md::Any)::AnalysisType where {ExpressionType}

Given an e-node `n`, `make` should return the corresponding analysis value.
Given an e-node `n`, and metadata extracted from an expression, `make` should return the corresponding analysis value. Implementations need to handle the default case where metadata is `nothing`.
"""
function make end

Expand Down Expand Up @@ -168,7 +168,7 @@ EGraph{ExpressionType}(e; kwargs...) where {ExpressionType} = EGraph{ExpressionT
EGraph(e; kwargs...) = EGraph{typeof(e),Nothing}(e; kwargs...)

# Fallback implementation for analysis methods make and modify
@inline make(::EGraph, ::VecExpr) = nothing
@inline make(::EGraph, ::VecExpr, md) = nothing
@inline modify!(::EGraph, ::EClass{Analysis}) where {Analysis} = nothing

@inline get_constant(@nospecialize(g::EGraph), hash::UInt64) = g.constants[hash]
Expand Down Expand Up @@ -252,7 +252,7 @@ end
"""
Inserts an e-node in an [`EGraph`](@ref)
"""
function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)::Id where {ExpressionType,Analysis}
function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, md, should_copy::Bool)::Id where {ExpressionType,Analysis}
canonicalize!(g, n)

id = get(g.memo, n, zero(Id))
Expand All @@ -273,7 +273,7 @@ function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)
g.memo[n] = id

add_class_by_op(g, n, id)
eclass = EClass{Analysis}(id, VecExpr[copy(n)], Pair{VecExpr,Id}[], make(g, n))
eclass = EClass{Analysis}(id, VecExpr[copy(n)], Pair{VecExpr,Id}[], make(g, n, md))
g.classes[IdKey(id)] = eclass
modify!(g, eclass)
push!(g.pending, n => id)
Expand Down Expand Up @@ -301,8 +301,9 @@ insert the literal into the [`EGraph`](@ref).
function addexpr!(g::EGraph, se)::Id
se isa EClass && return se.id
e = preprocess(se)
md = metadata(e)

isexpr(e) || return add!(g, VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)]), false)
isexpr(e) || return add!(g, VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)]), md, false)

args = iscall(e) ? arguments(e) : children(e)
ar = length(args)
Expand All @@ -317,7 +318,7 @@ function addexpr!(g::EGraph, se)::Id
@inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH])
end

add!(g, n, false)
add!(g, n, md, false)
end

"""
Expand Down Expand Up @@ -424,8 +425,9 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp
eclass_id = find(g, eclass_id)
eclass_id_key = IdKey(eclass_id)
eclass = g.classes[eclass_id_key]
md = eclass.data

node_data = make(g, node)
node_data = make(g, node, md)
if !isnothing(node_data)
if !isnothing(eclass.data)
joined_data = join(eclass.data, node_data)
Expand Down Expand Up @@ -471,7 +473,7 @@ end
function check_analysis(g)
for (id, eclass) in g.classes
isnothing(eclass.data) && continue
pass = mapreduce(x -> make(g, x), (x, y) -> join(x, y), eclass)
pass = mapreduce(x -> make(g, x, x.data), (x, y) -> join(x, y), eclass)
@assert eclass.data == pass
end
true
Expand Down
6 changes: 3 additions & 3 deletions src/EGraphs/extract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,12 @@ function Extractor(g::EGraph, cost_function::Function, cost_type = Float64)
extractor
end

function extract_expr_recursive(g::EGraph{T}, n::VecExpr, get_node::Function) where {T}
function extract_expr_recursive(g::EGraph{T, A}, n::VecExpr, get_node::Function) where {T, A}
h = get_constant(g, v_head(n))
v_isexpr(n) || return h
children = map(c -> extract_expr_recursive(g, c, get_node), get_node.(v_children(n)))
# TODO metadata?
maketerm(T, h, children, nothing)
md = metadata(g[lookup(g, n)].data)
maketerm(T, h, children, md)
end

function extract_expr_recursive(g::EGraph{Expr}, n::VecExpr, get_node::Function)
Expand Down
6 changes: 3 additions & 3 deletions src/EGraphs/saturation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ end

function instantiate_enode!(bindings, @nospecialize(g::EGraph), p::PatLiteral)::Id
add_constant_hashed!(g, p.value, v_head(p.n))
add!(g, p.n, true)
add!(g, p.n, nothing, true)
end

instantiate_enode!(bindings, @nospecialize(g::EGraph), p::PatVar)::Id = v_pair_first(bindings[p.idx])
Expand All @@ -127,7 +127,7 @@ function instantiate_enode!(bindings, g::EGraph{ExpressionType}, p::PatExpr)::Id
for i in v_children_range(p.n)
@inbounds p.n[i] = instantiate_enode!(bindings, g, p.children[i - VECEXPR_META_LENGTH])
end
add!(g, p.n, true)
add!(g, p.n, nothing, true)
end

function instantiate_enode!(bindings, g::EGraph{Expr}, p::PatExpr)::Id
Expand All @@ -137,7 +137,7 @@ function instantiate_enode!(bindings, g::EGraph{Expr}, p::PatExpr)::Id
for i in v_children_range(p.n)
@inbounds p.n[i] = instantiate_enode!(bindings, g, p.children[i - VECEXPR_META_LENGTH])
end
add!(g, p.n, true)
add!(g, p.n, nothing, true)
end

"""
Expand Down
2 changes: 1 addition & 1 deletion test/egraphs/analysis.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ Base.:(*)(a::NumberFoldAnalysis, b::NumberFoldAnalysis) = NumberFoldAnalysis(a.n
Base.:(+)(a::NumberFoldAnalysis, b::NumberFoldAnalysis) = NumberFoldAnalysis(a.n + b.n)

# This should be auto-generated by a macro
function EGraphs.make(g::EGraph{ExpressionType,NumberFoldAnalysis}, n::VecExpr) where {ExpressionType}
function EGraphs.make(g::EGraph{ExpressionType,NumberFoldAnalysis}, n::VecExpr, md) where {ExpressionType}
h = get_constant(g, v_head(n))
v_isexpr(n) || return h isa Number ? NumberFoldAnalysis(h) : nothing
if v_iscall(n) && v_arity(n) == 2
Expand Down
2 changes: 1 addition & 1 deletion test/integration/cas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ end
# ex = rewrite(ex, canonical_t; clean=false)


function EGraphs.make(g::EGraph{Expr,Type}, n::VecExpr)
function EGraphs.make(g::EGraph{Expr,Type}, n::VecExpr, md)
h = get_constant(g, v_head(n))
v_isexpr(n) || return (h in (:im, im) ? Complex : typeof(h))
v_iscall(n) || return (Any)
Expand Down
26 changes: 17 additions & 9 deletions test/tutorials/custom_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -62,13 +62,23 @@ ex = :(a[b])


# `metadata` should return the extra metadata. If you have many fields, i suggest using a `NamedTuple`.
# TermInterface.metadata(e::MyExpr) = e.foo
TermInterface.metadata(e::MyExpr) = e.foo

struct MetadataAnalysis
metadata
end

# Extract metadata from the analysis. The metadata is used in maketerm to reconstruct the expression.
TermInterface.metadata(ma::MetadataAnalysis) = ma.metadata

# struct MetadataAnalysis
# metadata
# end

# function EGraphs.make(g::EGraph{MyExprHead,MetadataAnalysis}, n::VecExpr) =
function EGraphs.join(a::MetadataAnalysis, b::MetadataAnalysis)
return a # Some fancy join here.
end

function EGraphs.make(g::EGraph{MyExpr,MetadataAnalysis}, n::VecExpr, md)
isnothing(md) ? md : MetadataAnalysis(md)
end

# Additionally, you can override `EGraphs.preprocess` on your custom expression
# to pre-process any expression before insertion in the E-Graph.
Expand Down Expand Up @@ -97,14 +107,12 @@ hcall = MyExpr(:h, [4], "hello")
ex = MyExpr(:f, [MyExpr(:z, [2]), hcall])
# We use the first type parameter an existing e-graph to inform the system about
# the *default* type of expressions that we want newly added expressions to have.
g = EGraph{MyExpr}(ex)
g = EGraph{MyExpr, MetadataAnalysis}(ex)

# Now let's test that it works.
saturate!(g, t)

# TODO metadata
# expected = MyExpr(:f, [MyExpr(:h, [4], "HELLO")], "")
expected = MyExpr(:f, [MyExpr(:h, [4], "")], "")
expected = MyExpr(:f, [MyExpr(:h, [4], "HELLO")], "")

extracted = extract!(g, astsize)
@test expected == extracted
Expand Down
2 changes: 1 addition & 1 deletion test/tutorials/lambda_theory.jl
Original file line number Diff line number Diff line change
Expand Up @@ -104,7 +104,7 @@ const LambdaAnalysis = Set{Symbol}

getdata(eclass) = eclass.data

function EGraphs.make(g::EGraph{ExprType,LambdaAnalysis}, n::VecExpr) where {ExprType}
function EGraphs.make(g::EGraph{ExprType,LambdaAnalysis}, n::VecExpr, md) where {ExprType}
v_isexpr(n) || return LambdaAnalysis()
if v_iscall(n)
h = v_head(n)
Expand Down