Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multithreading support #244

Open
wants to merge 12 commits into
base: ale/3.0
Choose a base branch
from
4 changes: 2 additions & 2 deletions .github/workflows/benchmark_pr.yml
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ jobs:
ls -l ~/.julia/bin
mkdir results
benchpkg Metatheory \
--rev="${{github.event.repository.default_branch}},${{github.event.pull_request.head.sha}}" \
--rev="${{github.event.pull_request.base.sha}},${{github.event.pull_request.head.sha}}" \
--url=${{ github.event.repository.clone_url }} \
--bench-on="${{github.event.pull_request.head.sha}}" \
--output-dir=results/ --tune
Expand All @@ -68,7 +68,7 @@ jobs:
- name: Create markdown table from benchmarks
run: |
julia --project=egg-benchmark/scripts egg-benchmark/scripts/load_results.jl \
-b ${{github.event.pull_request.head.sha}} -b "${{github.event.repository.default_branch}}" \
-b ${{github.event.pull_request.head.sha}} -b "${{github.event.pull_request.base.sha}}" \
--mt-results=results/ \
--egg-results=egg-benchmark/target/criterion \
-o table.md
Expand Down
2 changes: 2 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ jobs:
${{ runner.os }}-
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
JULIA_NUM_THREADS: 20
- uses: julia-actions/julia-processcoverage@v1
- uses: codecov/codecov-action@v4
with:
Expand Down
3 changes: 2 additions & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# 3.0
- Updated TermInterface to 1.0.1
- Use a custom per-thread stack for e-matching

# 2.0
- No longer dispatch against types, but instead dispatch against objects.
Expand Down Expand Up @@ -30,4 +31,4 @@ Metatheory.jl + SymbolicUtils.jl = ❤️
- Removed `@metatheory_init`
- Rules now support type and function predicates as in SymbolicUtils.jl
- Redesigned the library
- Introduced `@timerewrite` to time the execution of classical rewriting systems.
- Introduced `@timerewrite` to time the execution of classical rewriting systems.
17 changes: 9 additions & 8 deletions src/EGraphs/saturation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ function eqsat_search!(


@debug "SEARCHING"
stack = get_local_stack()
for (rule_idx, rule) in enumerate(theory)
prev_matches = n_matches
@timeit report.to string(rule_idx) begin
Expand All @@ -91,14 +92,14 @@ function eqsat_search!(

for i in ids_left
cansearch(scheduler, rule_idx, i) || continue
n_matches += rule.ematcher_left!(g, rule_idx, i, rule.stack, ematch_buffer)
n_matches += rule.ematcher_left!(g, rule_idx, i, stack, ematch_buffer)
inform!(scheduler, rule_idx, i, n_matches)
end

if is_bidirectional(rule)
for i in ids_right
cansearch(scheduler, rule_idx, i) || continue
n_matches += rule.ematcher_right!(g, rule_idx, i, rule.stack, ematch_buffer)
n_matches += rule.ematcher_right!(g, rule_idx, i, stack, ematch_buffer)
inform!(scheduler, rule_idx, i, n_matches)
end
end
Expand All @@ -123,21 +124,21 @@ end
instantiate_enode!(bindings, @nospecialize(g::EGraph), p::PatVar)::Id = v_pair_first(bindings[p.idx])
function instantiate_enode!(bindings, g::EGraph{ExpressionType}, p::PatExpr)::Id where {ExpressionType}
add_constant_hashed!(g, p.head, p.head_hash)

n = copy(p.n)
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems to me like not copying would be a bug, even in the case of non-multithreading!

for i in v_children_range(p.n)
@inbounds p.n[i] = instantiate_enode!(bindings, g, p.children[i - VECEXPR_META_LENGTH])
@inbounds n[i] = instantiate_enode!(bindings, g, p.children[i - VECEXPR_META_LENGTH])
end
add!(g, p.n, true)
add!(g, n, false)
end

function instantiate_enode!(bindings, g::EGraph{Expr}, p::PatExpr)::Id
add_constant_hashed!(g, p.quoted_head, p.quoted_head_hash)
v_set_head!(p.n, p.quoted_head_hash)

n = copy(p.n)
for i in v_children_range(p.n)
@inbounds p.n[i] = instantiate_enode!(bindings, g, p.children[i - VECEXPR_META_LENGTH])
@inbounds n[i] = instantiate_enode!(bindings, g, p.children[i - VECEXPR_META_LENGTH])
end
add!(g, p.n, true)
add!(g, n, false)
end

"""
Expand Down
49 changes: 37 additions & 12 deletions src/Rules.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
module Rules

using Base.Threads
using TermInterface
using AutoHashEquals
using Metatheory.Patterns
Expand All @@ -16,7 +17,8 @@ export RewriteRule,
Theory,
direct,
direct_left_to_right,
direct_right_to_left
direct_right_to_left,
get_local_stack

const STACK_SIZE = 512

Expand All @@ -27,14 +29,14 @@ Rules defined as with the --> are
called *directed rewrite* rules. Application of a *directed rewrite* rule
is a replacement of the `left` pattern with
the `right` substitution, with the correct instantiation
of pattern variables.
of pattern variables.

```julia
@rule ~a * ~b --> ~b * ~a
```

An *equational rule* is a symbolic substitution rule with operator `==` that
can be rewritten bidirectionally. Therefore, it can only be used
An *equational rule* is a symbolic substitution rule with operator `==` that
can be rewritten bidirectionally. Therefore, it can only be used
with the EGraphs backend.

```julia
Expand All @@ -43,7 +45,7 @@ with the EGraphs backend.

Rules defined with the `!=` act as *anti*-rules for checking contradictions in e-graph
rewriting. If two terms, corresponding to the left and right hand side of an
*anti-rule* are found in an `EGraph`, saturation is halted immediately.
*anti-rule* are found in an `EGraph`, saturation is halted immediately.

```julia
!a != a
Expand Down Expand Up @@ -71,11 +73,34 @@ Base.@kwdef struct RewriteRule{Op<:Function}
ematcher_right!::Union{Nothing,Function} = nothing
matcher_left::Function
matcher_right::Union{Nothing,Function} = nothing
stack::OptBuffer{UInt16} = OptBuffer{UInt16}(STACK_SIZE)
lhs_original = nothing
rhs_original = nothing
end

const THREAD_STACKS = OptBuffer{UInt16}[]
"""
Retrieve the per-thread stack thread used for program counters in matching.

We need a stack for each thread so that multithreading works correctly.

Modeled off [Julia's global RNG](https://github.com/JuliaLang/julia/blob/bc4b2e848400764e389c825b57d1481ed76f4d85/stdlib/Random/src/RNGs.jl)
"""
@inline get_local_stack() = get_local_stack(Threads.threadid())
@noinline function get_local_stack(tid::Int)
@assert 0 < tid <= length(THREAD_STACKS)
if @inbounds isassigned(THREAD_STACKS, tid)
@inbounds stack = THREAD_STACKS[tid]
else
stack = OptBuffer{UInt16}(STACK_SIZE)
@inbounds THREAD_STACKS[tid] = stack
end
return stack
end

function __init__()
resize!(empty!(THREAD_STACKS), Threads.nthreads())
end

function --> end
const DirectedRule = RewriteRule{typeof(-->)}
const EqualityRule = RewriteRule{typeof(==)}
Expand All @@ -99,8 +124,8 @@ function Base.show(io::IO, r::RewriteRule)
end


(r::DirectedRule)(term) = r.matcher_left(term, (bindings...) -> instantiate(term, r.right, bindings), r.stack)
(r::DynamicRule)(term) = r.matcher_left(term, (bindings...) -> r.right(term, nothing, bindings...), r.stack)
(r::DirectedRule)(term) = r.matcher_left(term, (bindings...) -> instantiate(term, r.right, bindings), get_local_stack())
(r::DynamicRule)(term) = r.matcher_left(term, (bindings...) -> r.right(term, nothing, bindings...), get_local_stack())

# ---------------------
# Theories
Expand Down Expand Up @@ -165,21 +190,21 @@ function Base.inv(r::RewriteRule)
end

"""
Turns an EqualityRule into a DirectedRule. For example,
Turns an EqualityRule into a DirectedRule. For example,

```julia
direct(@rule f(~x) == g(~x)) == f(~x) --> g(~x)
direct(@rule f(~x) == g(~x)) == f(~x) --> g(~x)
```
"""
function direct(r::EqualityRule)
RewriteRule(r.name, -->, (getfield(r, k) for k in fieldnames(DirectedRule)[3:end])...)
end

"""
Turns an EqualityRule into a DirectedRule, but right to left. For example,
Turns an EqualityRule into a DirectedRule, but right to left. For example,

```julia
direct(@rule f(~x) == g(~x)) == g(~x) --> f(~x)
direct(@rule f(~x) == g(~x)) == g(~x) --> f(~x)
```
"""
direct_right_to_left(r::EqualityRule) = inv(direct(r))
Expand Down
28 changes: 28 additions & 0 deletions test/egraphs/concurrency.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# MT currently does not support thread-parallel saturation with a shared state.
# But it should be possible to saturate independent egraphs withing separate threads.

using Test, Metatheory

function run_eq()
theory = @theory a b c begin
a + b == b + a
a + (b + c) == (a + b) + c
end

g = EGraph{Expr}(:(1 + (2 + (3 + (4 + (5 + 6))))));
saturate!(g, theory, SaturationParams(timeout=100))
end

function test_threads()
@assert Threads.nthreads() > 1 # this test is only useful in multi-threaded scenarios.

# run equality saturation in parallel threads (no shared state)
Threads.@threads for _ in 1:1000
run_eq()
end
true
end

@testset "Concurrency" begin
@test test_threads()
end
52 changes: 26 additions & 26 deletions test/egraphs/ematch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,56 +13,56 @@ b = OptBuffer{UInt128}(10)
r = @rule 2 --> true
g = EGraph(2)

@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1
end

@testset "Composite Ground Terms" begin
r = @rule f(2, 3) --> true
g = EGraph(:(f(2, 3)))

@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0
@test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1
@test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 0
@test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 0

g = EGraph(:(f(2, 4)))

@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0
@test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0
@test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0
@test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 0
@test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 0


r = @rule f(2, h(3, 4)) --> true
g = EGraph(:(f(2, h(3, 4))))

@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0
@test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1
@test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 0
@test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 0
end

@testset "Pattern Variables" begin
g = EGraph(:(f(2, 1)))
r = @rule ~a --> true

@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 1
@test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 1
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1
@test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 1
@test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 1
end

@testset "Type Assertions" begin
r = @rule ~a::Int --> true
g = EGraph(:(f(2, 1)))
@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0

g = EGraph(:3)
@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1

new_id = addexpr!(g, :f)
union!(g, g.root, new_id)

new_id = addexpr!(g, 4)
union!(g, g.root, new_id)

@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 2
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 2
end

@testset "Predicate Assertions" begin
Expand All @@ -78,40 +78,40 @@ end
end

g = EGraph(:(f(2, 1)))
@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0

g = EGraph(:2)
@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1

g = EGraph(:3)
@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0

new_id = addexpr!(g, :f)
union!(g, g.root, new_id)

new_id = addexpr!(g, 4)
union!(g, g.root, new_id)

@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1
end


@testset "Non-Ground Terms" begin
g = EGraph(:(f(2, 1)))
r = @rule f(2, ~a) --> true

@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, Id(1), r.stack, b) == 0
@test r.ematcher_left!(g, 0, Id(2), r.stack, b) == 0
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1
@test r.ematcher_left!(g, 0, Id(1), get_local_stack(), b) == 0
@test r.ematcher_left!(g, 0, Id(2), get_local_stack(), b) == 0

r = @rule f(~a, ~a) --> true
@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 0
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 0

g = EGraph(:(f(2, 2)))
@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1

g = EGraph(:(f(h(3, 4), h(3, 4))))
@test r.ematcher_left!(g, 0, g.root, r.stack, b) == 1
@test r.ematcher_left!(g, 0, g.root, get_local_stack(), b) == 1
end


Expand Down
Loading