diff --git a/docs/src/egraphs.md b/docs/src/egraphs.md index 9fe1d229..fa6a80a7 100644 --- a/docs/src/egraphs.md +++ b/docs/src/egraphs.md @@ -269,7 +269,7 @@ The `EGraph{E,A}` type is parametrized by the expression type `E` and the The following functions define an interface for analyses based on multiple dispatch: -* [make(g::EGraph{ExprType, AnalysisType}, n)](@ref) should take an e-node `n::VecExpr` and return a value from the analysis domain. +* [make(g::EGraph{ExprType, AnalysisType}, n, md)](@ref) should take an e-node `n::VecExpr`, and metadata `md` from an expression (possibly `noting`), and return a value from the analysis domain. * [join(x::AnalysisType, y::AnalysisType)](@ref) should return the semilattice join of `x` and `y` in the analysis domain (e.g. *given two analyses value from ENodes in the same EClass, which one should I choose?* or *how should they be merged?*).`Base.isless` must be defined. * [modify!(g::EGraph{ExprType, AnalysisType}, eclass::EClass{AnalysisType})](@ref) Can be optionally implemented. This can be used modify an EClass `egraph[eclass.id]` on-the-fly during an e-graph saturation iteration, given its analysis value, typically by adding an e-node. @@ -325,7 +325,9 @@ From the definition of an e-node, we know that children of e-nodes are always ID to e-classes in the `EGraph`. ```@example custom_analysis -function EGraphs.make(g::EGraph{ExpressionType,OddEvenAnalysis}, op, n::VecExpr) where {ExpressionType} +function EGraphs.make(g::EGraph{ExpressionType,OddEvenAnalysis}, op, n::VecExpr, md) where {ExpressionType} + # metadata `md` is not used in this instance. + v_isexpr(n) || return odd_even_base_case(op) # The e-node is not a literal value, # Let's consider only binary function call terms. diff --git a/src/EGraphs/egraph.jl b/src/EGraphs/egraph.jl index a4f0e5f3..c38ea512 100644 --- a/src/EGraphs/egraph.jl +++ b/src/EGraphs/egraph.jl @@ -23,9 +23,9 @@ when two eclasses are being merged or the analysis is being constructed. function join end """ - make(g::EGraph{ExpressionType, AnalysisType}, n::VecExpr)::AnalysisType where {ExpressionType} + make(g::EGraph{ExpressionType, AnalysisType}, n::VecExpr, md::Any)::AnalysisType where {ExpressionType} -Given an e-node `n`, `make` should return the corresponding analysis value. +Given an e-node `n`, and metadata extracted from an expression, `make` should return the corresponding analysis value. Implementations need to handle the default case where metadata is `nothing`. """ function make end @@ -168,7 +168,7 @@ EGraph{ExpressionType}(e; kwargs...) where {ExpressionType} = EGraph{ExpressionT EGraph(e; kwargs...) = EGraph{typeof(e),Nothing}(e; kwargs...) # Fallback implementation for analysis methods make and modify -@inline make(::EGraph, ::VecExpr) = nothing +@inline make(::EGraph, ::VecExpr, md) = nothing @inline modify!(::EGraph, ::EClass{Analysis}) where {Analysis} = nothing @inline get_constant(@nospecialize(g::EGraph), hash::UInt64) = g.constants[hash] @@ -252,7 +252,7 @@ end """ Inserts an e-node in an [`EGraph`](@ref) """ -function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool)::Id where {ExpressionType,Analysis} +function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, md, should_copy::Bool)::Id where {ExpressionType,Analysis} canonicalize!(g, n) id = get(g.memo, n, zero(Id)) @@ -273,7 +273,7 @@ function add!(g::EGraph{ExpressionType,Analysis}, n::VecExpr, should_copy::Bool) g.memo[n] = id add_class_by_op(g, n, id) - eclass = EClass{Analysis}(id, VecExpr[copy(n)], Pair{VecExpr,Id}[], make(g, n)) + eclass = EClass{Analysis}(id, VecExpr[copy(n)], Pair{VecExpr,Id}[], make(g, n, md)) g.classes[IdKey(id)] = eclass modify!(g, eclass) push!(g.pending, n => id) @@ -301,8 +301,9 @@ insert the literal into the [`EGraph`](@ref). function addexpr!(g::EGraph, se)::Id se isa EClass && return se.id e = preprocess(se) + md = metadata(e) - isexpr(e) || return add!(g, VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)]), false) + isexpr(e) || return add!(g, VecExpr(Id[Id(0), Id(0), Id(0), add_constant!(g, e)]), md, false) args = iscall(e) ? arguments(e) : children(e) ar = length(args) @@ -317,7 +318,7 @@ function addexpr!(g::EGraph, se)::Id @inbounds n[i] = addexpr!(g, args[i - VECEXPR_META_LENGTH]) end - add!(g, n, false) + add!(g, n, md, false) end """ @@ -424,8 +425,9 @@ function process_unions!(g::EGraph{ExpressionType,AnalysisType})::Int where {Exp eclass_id = find(g, eclass_id) eclass_id_key = IdKey(eclass_id) eclass = g.classes[eclass_id_key] + md = eclass.data - node_data = make(g, node) + node_data = make(g, node, md) if !isnothing(node_data) if !isnothing(eclass.data) joined_data = join(eclass.data, node_data) @@ -471,7 +473,7 @@ end function check_analysis(g) for (id, eclass) in g.classes isnothing(eclass.data) && continue - pass = mapreduce(x -> make(g, x), (x, y) -> join(x, y), eclass) + pass = mapreduce(x -> make(g, x, x.data), (x, y) -> join(x, y), eclass) @assert eclass.data == pass end true diff --git a/src/EGraphs/extract.jl b/src/EGraphs/extract.jl index 85186b5c..540885c3 100644 --- a/src/EGraphs/extract.jl +++ b/src/EGraphs/extract.jl @@ -15,12 +15,12 @@ function Extractor(g::EGraph, cost_function::Function, cost_type = Float64) extractor end -function extract_expr_recursive(g::EGraph{T}, n::VecExpr, get_node::Function) where {T} +function extract_expr_recursive(g::EGraph{T, A}, n::VecExpr, get_node::Function) where {T, A} h = get_constant(g, v_head(n)) v_isexpr(n) || return h children = map(c -> extract_expr_recursive(g, c, get_node), get_node.(v_children(n))) - # TODO metadata? - maketerm(T, h, children, nothing) + md = metadata(g[lookup(g, n)].data) + maketerm(T, h, children, md) end function extract_expr_recursive(g::EGraph{Expr}, n::VecExpr, get_node::Function) diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index fcc3555f..2aa01675 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -117,7 +117,7 @@ end function instantiate_enode!(bindings, @nospecialize(g::EGraph), p::PatLiteral)::Id add_constant_hashed!(g, p.value, v_head(p.n)) - add!(g, p.n, true) + add!(g, p.n, nothing, true) end instantiate_enode!(bindings, @nospecialize(g::EGraph), p::PatVar)::Id = v_pair_first(bindings[p.idx]) @@ -127,7 +127,7 @@ function instantiate_enode!(bindings, g::EGraph{ExpressionType}, p::PatExpr)::Id for i in v_children_range(p.n) @inbounds p.n[i] = instantiate_enode!(bindings, g, p.children[i - VECEXPR_META_LENGTH]) end - add!(g, p.n, true) + add!(g, p.n, nothing, true) end function instantiate_enode!(bindings, g::EGraph{Expr}, p::PatExpr)::Id @@ -137,7 +137,7 @@ function instantiate_enode!(bindings, g::EGraph{Expr}, p::PatExpr)::Id for i in v_children_range(p.n) @inbounds p.n[i] = instantiate_enode!(bindings, g, p.children[i - VECEXPR_META_LENGTH]) end - add!(g, p.n, true) + add!(g, p.n, nothing, true) end """ diff --git a/test/egraphs/analysis.jl b/test/egraphs/analysis.jl index 092031f6..1266e0b4 100644 --- a/test/egraphs/analysis.jl +++ b/test/egraphs/analysis.jl @@ -15,7 +15,7 @@ Base.:(*)(a::NumberFoldAnalysis, b::NumberFoldAnalysis) = NumberFoldAnalysis(a.n Base.:(+)(a::NumberFoldAnalysis, b::NumberFoldAnalysis) = NumberFoldAnalysis(a.n + b.n) # This should be auto-generated by a macro -function EGraphs.make(g::EGraph{ExpressionType,NumberFoldAnalysis}, n::VecExpr) where {ExpressionType} +function EGraphs.make(g::EGraph{ExpressionType,NumberFoldAnalysis}, n::VecExpr, md) where {ExpressionType} h = get_constant(g, v_head(n)) v_isexpr(n) || return h isa Number ? NumberFoldAnalysis(h) : nothing if v_iscall(n) && v_arity(n) == 2 diff --git a/test/integration/cas.jl b/test/integration/cas.jl index 08a34749..23151e83 100644 --- a/test/integration/cas.jl +++ b/test/integration/cas.jl @@ -215,7 +215,7 @@ end # ex = rewrite(ex, canonical_t; clean=false) -function EGraphs.make(g::EGraph{Expr,Type}, n::VecExpr) +function EGraphs.make(g::EGraph{Expr,Type}, n::VecExpr, md) h = get_constant(g, v_head(n)) v_isexpr(n) || return (h in (:im, im) ? Complex : typeof(h)) v_iscall(n) || return (Any) diff --git a/test/tutorials/custom_types.jl b/test/tutorials/custom_types.jl index 69c7febc..243a5b8e 100644 --- a/test/tutorials/custom_types.jl +++ b/test/tutorials/custom_types.jl @@ -62,13 +62,23 @@ ex = :(a[b]) # `metadata` should return the extra metadata. If you have many fields, i suggest using a `NamedTuple`. -# TermInterface.metadata(e::MyExpr) = e.foo +TermInterface.metadata(e::MyExpr) = e.foo + +struct MetadataAnalysis + metadata +end + +# Extract metadata from the analysis. The metadata is used in maketerm to reconstruct the expression. +TermInterface.metadata(ma::MetadataAnalysis) = ma.metadata -# struct MetadataAnalysis -# metadata -# end -# function EGraphs.make(g::EGraph{MyExprHead,MetadataAnalysis}, n::VecExpr) = +function EGraphs.join(a::MetadataAnalysis, b::MetadataAnalysis) + return a # Some fancy join here. +end + +function EGraphs.make(g::EGraph{MyExpr,MetadataAnalysis}, n::VecExpr, md) + isnothing(md) ? md : MetadataAnalysis(md) +end # Additionally, you can override `EGraphs.preprocess` on your custom expression # to pre-process any expression before insertion in the E-Graph. @@ -97,14 +107,12 @@ hcall = MyExpr(:h, [4], "hello") ex = MyExpr(:f, [MyExpr(:z, [2]), hcall]) # We use the first type parameter an existing e-graph to inform the system about # the *default* type of expressions that we want newly added expressions to have. -g = EGraph{MyExpr}(ex) +g = EGraph{MyExpr, MetadataAnalysis}(ex) # Now let's test that it works. saturate!(g, t) -# TODO metadata -# expected = MyExpr(:f, [MyExpr(:h, [4], "HELLO")], "") -expected = MyExpr(:f, [MyExpr(:h, [4], "")], "") +expected = MyExpr(:f, [MyExpr(:h, [4], "HELLO")], "") extracted = extract!(g, astsize) @test expected == extracted diff --git a/test/tutorials/lambda_theory.jl b/test/tutorials/lambda_theory.jl index 9f74a725..8df1eea8 100644 --- a/test/tutorials/lambda_theory.jl +++ b/test/tutorials/lambda_theory.jl @@ -104,7 +104,7 @@ const LambdaAnalysis = Set{Symbol} getdata(eclass) = eclass.data -function EGraphs.make(g::EGraph{ExprType,LambdaAnalysis}, n::VecExpr) where {ExprType} +function EGraphs.make(g::EGraph{ExprType,LambdaAnalysis}, n::VecExpr, md) where {ExprType} v_isexpr(n) || return LambdaAnalysis() if v_iscall(n) h = v_head(n)