diff --git a/.github/workflows/benchmark_pr.yml b/.github/workflows/benchmark_pr.yml index 587fb8d4..223d456c 100644 --- a/.github/workflows/benchmark_pr.yml +++ b/.github/workflows/benchmark_pr.yml @@ -47,7 +47,7 @@ jobs: ls -l ~/.julia/bin mkdir results benchpkg Metatheory \ - --rev="${{github.event.repository.default_branch}},${{github.event.pull_request.head.sha}}" \ + --rev="${{github.event.pull_request.base.sha}},${{github.event.pull_request.head.sha}}" \ --url=${{ github.event.repository.clone_url }} \ --bench-on="${{github.event.pull_request.head.sha}}" \ --output-dir=results/ --tune @@ -68,7 +68,7 @@ jobs: - name: Create markdown table from benchmarks run: | julia --project=egg-benchmark/scripts egg-benchmark/scripts/load_results.jl \ - -b ${{github.event.pull_request.head.sha}} -b "${{github.event.repository.default_branch}}" \ + -b ${{github.event.pull_request.head.sha}} -b "${{github.event.pull_request.base.sha}}" \ --mt-results=results/ \ --egg-results=egg-benchmark/target/criterion \ -o table.md diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 125f32f3..5c100471 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,6 +38,8 @@ jobs: ${{ runner.os }}- - uses: julia-actions/julia-buildpkg@v1 - uses: julia-actions/julia-runtest@v1 + env: + JULIA_NUM_THREADS: 20 - uses: julia-actions/julia-processcoverage@v1 - uses: codecov/codecov-action@v4 with: diff --git a/NEWS.md b/NEWS.md index 623e2ab3..0c49b75a 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,5 +1,6 @@ # 3.0 - Updated TermInterface to 1.0.1 +- Use a custom per-thread stack for e-matching # 2.0 - No longer dispatch against types, but instead dispatch against objects. @@ -30,4 +31,4 @@ Metatheory.jl + SymbolicUtils.jl = ❤️ - Removed `@metatheory_init` - Rules now support type and function predicates as in SymbolicUtils.jl - Redesigned the library -- Introduced `@timerewrite` to time the execution of classical rewriting systems. \ No newline at end of file +- Introduced `@timerewrite` to time the execution of classical rewriting systems. diff --git a/src/EGraphs/saturation.jl b/src/EGraphs/saturation.jl index 41b2702f..8afb219a 100644 --- a/src/EGraphs/saturation.jl +++ b/src/EGraphs/saturation.jl @@ -77,6 +77,7 @@ function eqsat_search!( @debug "SEARCHING" + stack = get_local_stack() for (rule_idx, rule) in enumerate(theory) prev_matches = n_matches @timeit report.to string(rule_idx) begin @@ -91,14 +92,14 @@ function eqsat_search!( for i in ids_left cansearch(scheduler, rule_idx, i) || continue - n_matches += rule.ematcher_left!(g, rule_idx, i, rule.stack, ematch_buffer) + n_matches += rule.ematcher_left!(g, rule_idx, i, stack, ematch_buffer) inform!(scheduler, rule_idx, i, n_matches) end if is_bidirectional(rule) for i in ids_right cansearch(scheduler, rule_idx, i) || continue - n_matches += rule.ematcher_right!(g, rule_idx, i, rule.stack, ematch_buffer) + n_matches += rule.ematcher_right!(g, rule_idx, i, stack, ematch_buffer) inform!(scheduler, rule_idx, i, n_matches) end end @@ -123,21 +124,21 @@ end instantiate_enode!(bindings, @nospecialize(g::EGraph), p::PatVar)::Id = v_pair_first(bindings[p.idx]) function instantiate_enode!(bindings, g::EGraph{ExpressionType}, p::PatExpr)::Id where {ExpressionType} add_constant_hashed!(g, p.head, p.head_hash) - + n = copy(p.n) for i in v_children_range(p.n) - @inbounds p.n[i] = instantiate_enode!(bindings, g, p.children[i - VECEXPR_META_LENGTH]) + @inbounds n[i] = instantiate_enode!(bindings, g, p.children[i - VECEXPR_META_LENGTH]) end - add!(g, p.n, true) + add!(g, n, false) end function instantiate_enode!(bindings, g::EGraph{Expr}, p::PatExpr)::Id add_constant_hashed!(g, p.quoted_head, p.quoted_head_hash) v_set_head!(p.n, p.quoted_head_hash) - + n = copy(p.n) for i in v_children_range(p.n) - @inbounds p.n[i] = instantiate_enode!(bindings, g, p.children[i - VECEXPR_META_LENGTH]) + @inbounds n[i] = instantiate_enode!(bindings, g, p.children[i - VECEXPR_META_LENGTH]) end - add!(g, p.n, true) + add!(g, n, false) end """ diff --git a/src/Rules.jl b/src/Rules.jl index d32e064e..7d3c0850 100644 --- a/src/Rules.jl +++ b/src/Rules.jl @@ -1,5 +1,6 @@ module Rules +using Base.Threads using TermInterface using AutoHashEquals using Metatheory.Patterns @@ -16,7 +17,8 @@ export RewriteRule, Theory, direct, direct_left_to_right, - direct_right_to_left + direct_right_to_left, + get_local_stack const STACK_SIZE = 512 @@ -27,14 +29,14 @@ Rules defined as with the --> are called *directed rewrite* rules. Application of a *directed rewrite* rule is a replacement of the `left` pattern with the `right` substitution, with the correct instantiation -of pattern variables. +of pattern variables. ```julia @rule ~a * ~b --> ~b * ~a ``` -An *equational rule* is a symbolic substitution rule with operator `==` that -can be rewritten bidirectionally. Therefore, it can only be used +An *equational rule* is a symbolic substitution rule with operator `==` that +can be rewritten bidirectionally. Therefore, it can only be used with the EGraphs backend. ```julia @@ -43,7 +45,7 @@ with the EGraphs backend. Rules defined with the `!=` act as *anti*-rules for checking contradictions in e-graph rewriting. If two terms, corresponding to the left and right hand side of an -*anti-rule* are found in an `EGraph`, saturation is halted immediately. +*anti-rule* are found in an `EGraph`, saturation is halted immediately. ```julia !a != a @@ -71,11 +73,34 @@ Base.@kwdef struct RewriteRule{Op<:Function} ematcher_right!::Union{Nothing,Function} = nothing matcher_left::Function matcher_right::Union{Nothing,Function} = nothing - stack::OptBuffer{UInt16} = OptBuffer{UInt16}(STACK_SIZE) lhs_original = nothing rhs_original = nothing end +const THREAD_STACKS = OptBuffer{UInt16}[] +""" +Retrieve the per-thread stack thread used for program counters in matching. + +We need a stack for each thread so that multithreading works correctly. + +Modeled off [Julia's global RNG](https://github.com/JuliaLang/julia/blob/bc4b2e848400764e389c825b57d1481ed76f4d85/stdlib/Random/src/RNGs.jl) +""" +@inline get_local_stack() = get_local_stack(Threads.threadid()) +@noinline function get_local_stack(tid::Int) + @assert 0 < tid <= length(THREAD_STACKS) + if @inbounds isassigned(THREAD_STACKS, tid) + @inbounds stack = THREAD_STACKS[tid] + else + stack = OptBuffer{UInt16}(STACK_SIZE) + @inbounds THREAD_STACKS[tid] = stack + end + return stack +end + +function __init__() + resize!(empty!(THREAD_STACKS), Threads.nthreads()) +end + function --> end const DirectedRule = RewriteRule{typeof(-->)} const EqualityRule = RewriteRule{typeof(==)} @@ -99,8 +124,8 @@ function Base.show(io::IO, r::RewriteRule) end -(r::DirectedRule)(term) = r.matcher_left(term, (bindings...) -> instantiate(term, r.right, bindings), r.stack) -(r::DynamicRule)(term) = r.matcher_left(term, (bindings...) -> r.right(term, nothing, bindings...), r.stack) +(r::DirectedRule)(term) = r.matcher_left(term, (bindings...) -> instantiate(term, r.right, bindings), get_local_stack()) +(r::DynamicRule)(term) = r.matcher_left(term, (bindings...) -> r.right(term, nothing, bindings...), get_local_stack()) # --------------------- # Theories @@ -165,10 +190,10 @@ function Base.inv(r::RewriteRule) end """ -Turns an EqualityRule into a DirectedRule. For example, +Turns an EqualityRule into a DirectedRule. For example, ```julia -direct(@rule f(~x) == g(~x)) == f(~x) --> g(~x) +direct(@rule f(~x) == g(~x)) == f(~x) --> g(~x) ``` """ function direct(r::EqualityRule) @@ -176,10 +201,10 @@ function direct(r::EqualityRule) end """ -Turns an EqualityRule into a DirectedRule, but right to left. For example, +Turns an EqualityRule into a DirectedRule, but right to left. For example, ```julia -direct(@rule f(~x) == g(~x)) == g(~x) --> f(~x) +direct(@rule f(~x) == g(~x)) == g(~x) --> f(~x) ``` """ direct_right_to_left(r::EqualityRule) = inv(direct(r)) diff --git a/test/egraphs/concurrency.jl b/test/egraphs/concurrency.jl new file mode 100644 index 00000000..307538be --- /dev/null +++ b/test/egraphs/concurrency.jl @@ -0,0 +1,28 @@ +# MT currently does not support thread-parallel saturation with a shared state. +# But it should be possible to saturate independent egraphs withing separate threads. + +using Test, Metatheory + +function run_eq() + theory = @theory a b c begin + a + b == b + a + a + (b + c) == (a + b) + c + end + + g = EGraph{Expr}(:(1 + (2 + (3 + (4 + (5 + 6)))))); + saturate!(g, theory, SaturationParams(timeout=100)) + end + + function test_threads() + @assert Threads.nthreads() > 1 # this test is only useful in multi-threaded scenarios. + + # run equality saturation in parallel threads (no shared state) + Threads.@threads for _ in 1:1000 + run_eq() + end + true + end + +@testset "Concurrency" begin + @test test_threads() +end \ No newline at end of file diff --git a/test/egraphs/ematch.jl b/test/egraphs/ematch.jl index d2c5fb6d..995ebba0 100644 --- a/test/egraphs/ematch.jl +++ b/test/egraphs/ematch.jl @@ -13,48 +13,48 @@ b = OptBuffer{UInt128}(10) r = @rule 2 --> true g = EGraph(2) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 end @testset "Composite Ground Terms" begin r = @rule f(2, 3) --> true g = EGraph(:(f(2, 3))) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 - @test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0 - @test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 + @test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 0 + @test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 0 g = EGraph(:(f(2, 4))) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 - @test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0 - @test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0 + @test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 0 + @test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 0 r = @rule f(2, h(3, 4)) --> true g = EGraph(:(f(2, h(3, 4)))) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 - @test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0 - @test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 + @test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 0 + @test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 0 end @testset "Pattern Variables" begin g = EGraph(:(f(2, 1))) r = @rule ~a --> true - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 - @test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 1 - @test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 1 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 + @test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 1 + @test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 1 end @testset "Type Assertions" begin r = @rule ~a::Int --> true g = EGraph(:(f(2, 1))) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0 g = EGraph(:3) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 new_id = addexpr!(g, :f) union!(g, g.root, new_id) @@ -62,7 +62,7 @@ end new_id = addexpr!(g, 4) union!(g, g.root, new_id) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 2 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 2 end @testset "Predicate Assertions" begin @@ -78,13 +78,13 @@ end end g = EGraph(:(f(2, 1))) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0 g = EGraph(:2) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 g = EGraph(:3) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0 new_id = addexpr!(g, :f) union!(g, g.root, new_id) @@ -92,7 +92,7 @@ end new_id = addexpr!(g, 4) union!(g, g.root, new_id) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 end @@ -100,18 +100,18 @@ end g = EGraph(:(f(2, 1))) r = @rule f(2, ~a) --> true - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 - @test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0 - @test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 + @test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 0 + @test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 0 r = @rule f(~a, ~a) --> true - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0 g = EGraph(:(f(2, 2))) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 g = EGraph(:(f(h(3, 4), h(3, 4)))) - @test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1 + @test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1 end