Skip to content

Commit

Permalink
Merge pull request #333 from egraphs-good/oflatt-assume-propogate2
Browse files Browse the repository at this point in the history
Add rules that propogate assume nodes to create unique contexts
  • Loading branch information
oflatt authored Feb 8, 2024
2 parents a083fcd + 7cfe528 commit 43d0592
Show file tree
Hide file tree
Showing 10 changed files with 257 additions and 13 deletions.
15 changes: 14 additions & 1 deletion tree_assume/src/ast.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,12 @@ macro_rules! parallel {
pub use parallel;

pub fn parallel_vec(es: impl IntoIterator<Item = RcExpr>) -> RcExpr {
es.into_iter().fold(empty(), |acc, x| push_par(x, acc))
let mut iter = es.into_iter();
if let Some(e) = iter.next() {
iter.fold(single(e), |acc, x| push_par(x, acc))
} else {
empty()
}
}

pub fn tlet(lhs: RcExpr, rhs: RcExpr) -> RcExpr {
Expand Down Expand Up @@ -210,6 +215,14 @@ pub fn inloop(e1: RcExpr, e2: RcExpr) -> Assumption {
Assumption::InLoop(e1, e2)
}

pub fn inif(is_then: bool, pred: RcExpr) -> Assumption {
Assumption::InIf(is_then, pred)
}

pub fn infunc(name: &str) -> Assumption {
Assumption::InFunc(name.to_string())
}

pub fn assume(assumption: Assumption, body: RcExpr) -> RcExpr {
RcExpr::new(Expr::Assume(assumption, body))
}
6 changes: 4 additions & 2 deletions tree_assume/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ mod optimizations;
pub mod schema;
pub mod schema_helpers;
mod to_egglog;
pub(crate) mod utility;

pub type Result = std::result::Result<(), egglog::Error>;

Expand All @@ -29,7 +30,7 @@ pub fn egglog_test(
let result = interpret(&prog, input.clone());
assert_eq!(
result, expected,
"Program {:?} produced {} instead of expected {}",
"Program {:?}\nproduced:\n{}\ninstead of expected:\n{}",
prog, result, expected
);
}
Expand All @@ -38,7 +39,8 @@ pub fn egglog_test(
"{}\n{build}\n{}\n{check}\n",
[
include_str!("schema.egg"),
include_str!("optimizations/constant_fold.egg")
include_str!("optimizations/constant_fold.egg"),
include_str!("utility/assume.egg"),
]
.join("\n"),
include_str!("schedule.egg"),
Expand Down
4 changes: 2 additions & 2 deletions tree_assume/src/schedule.egg
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
(run-schedule
(repeat 6
(repeat 100 assume)
(saturate always-run)
(saturate error-checking)
(run constant_fold)
))
(saturate constant_fold)))
6 changes: 6 additions & 0 deletions tree_assume/src/schema.egg
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@
; The term is in a loop with `input` and `pred_output`.
; input pred_output
(InLoop Expr Expr)
; name of the function
(InFunc String)
; Branch of the switch and the predicate
(InSwitch i64 Expr)
; If the predicate was true, and the predicate
(InIf bool Expr)
; Other assumptions are possible, but not supported yet.
; For example:
; A boolean predicate is true.
Expand Down
2 changes: 2 additions & 0 deletions tree_assume/src/schema.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ pub type RcExpr = Rc<Expr>;
pub enum Assumption {
InLet(RcExpr),
InLoop(RcExpr, RcExpr),
InFunc(String),
InIf(bool, RcExpr),
}

#[derive(Debug, Clone, PartialEq, Eq)]
Expand Down
22 changes: 14 additions & 8 deletions tree_assume/src/schema_helpers.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,14 +35,20 @@ impl Expr {
}

pub fn to_program(self: &RcExpr, input_ty: Type, output_ty: Type) -> TreeProgram {
TreeProgram {
entry: RcExpr::new(Expr::Function(
"main".to_string(),
input_ty,
output_ty,
self.clone(),
)),
functions: vec![],
match self.as_ref() {
Expr::Function(..) => TreeProgram {
entry: self.clone(),
functions: vec![],
},
_ => TreeProgram {
entry: RcExpr::new(Expr::Function(
"main".to_string(),
input_ty,
output_ty,
self.clone(),
)),
functions: vec![],
},
}
}
}
Expand Down
9 changes: 9 additions & 0 deletions tree_assume/src/to_egglog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,15 @@ impl Assumption {
let rhs = rhs.to_egglog_internal(term_dag);
term_dag.app("InLoop".into(), vec![lhs, rhs])
}
Assumption::InFunc(name) => {
let name_lit = term_dag.lit(Literal::String(name.into()));
term_dag.app("InFunc".into(), vec![name_lit])
}
Assumption::InIf(is_then, pred) => {
let pred = pred.to_egglog_internal(term_dag);
let is_then = term_dag.lit(Literal::Bool(*is_then));
term_dag.app("InIf".into(), vec![is_then, pred])
}
}
}
}
Expand Down
82 changes: 82 additions & 0 deletions tree_assume/src/utility/assume.egg
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
; This file propogates assume nodes top-down from functions.
; It gives each program path a unique equality relation.
; This can be quite expensive, so be careful running these rules.

(ruleset assume)

(sort AssumeList)

; In order to saturate and not create unecessary contexts, we need to collapse
; duplicate assumptions.
; For example, the rewrite over `Function` could create many
; `(Assume (InFunc name) (Assume (InFunc name) ...))` nestings
(rewrite (Assume assumption (Assume assumption rest))
(Assume assumption rest)
:ruleset assume)

; ################### start top-down assumptions

(rewrite
(Function name in_ty out_ty out)
(Function name in_ty out_ty
(Assume (InFunc name)
out))
:ruleset assume)


; ################### operations
(rewrite (Assume asum (Bop op c1 c2))
(Bop op (Assume asum c1) (Assume asum c2))
:ruleset assume)
(rewrite (Assume assum (Uop op c1))
(Uop op (Assume assum c1))
:ruleset assume)
(rewrite (Assume assum (Get expr index))
(Get (Assume assum expr) index)
:ruleset assume)
(rewrite (Assume assum (Alloc expr ty))
(Alloc (Assume assum expr) ty)
:ruleset assume)
(rewrite (Assume assum (Call name expr))
(Call name (Assume assum expr))
:ruleset assume)

; ################### tuple operations
(rewrite (Assume assum (Single expr))
(Single (Assume assum expr))
:ruleset assume)
(rewrite (Assume assum (Concat order e1 e2))
(Concat order (Assume assum e1) (Assume assum e2))
:ruleset assume)

; #################### control flow

; assumptions, predicate, cases, current case
(function SwitchAssume (AssumeList Expr ListExpr i64) ListExpr :unextractable)

(rewrite (Assume assum (If pred then else))
(If (Assume assum pred)
(Assume
(InIf true (Assume assum pred)) then)
(Assume
(InIf false (Assume assum pred)) else))
:ruleset assume)

(rewrite (Assume assum (Let inputs body))
(Let
(Assume assum inputs)
(Assume
(InLet (Assume assum inputs))
body))
:ruleset assume)


(rule ((= lhs (Assume assum (DoWhile inputs pred_outputs))))
((union lhs
(DoWhile
(Assume assum inputs)
(Assume
(InLoop (Assume assum inputs) pred_outputs)
pred_outputs))))
:ruleset assume)

123 changes: 123 additions & 0 deletions tree_assume/src/utility/assume.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
#[cfg(test)]
use crate::{egglog_test, interpreter::Value, schema::Constant};

#[test]
fn test_assume_in_func() -> crate::Result {
use crate::ast::*;
let expr = function("main", intt(), intt(), int(2));
let expected = function("main", intt(), intt(), assume(infunc("main"), int(2)));
egglog_test(
&format!("{expr}"),
&format!("(check (= {expr} {expected}))"),
vec![
expr.to_program(emptyt(), intt()),
expected.to_program(emptyt(), intt()),
],
Value::Tuple(vec![]),
Value::Const(Constant::Int(2)),
)
}

#[test]
fn test_assume_two_lets() -> crate::Result {
use crate::ast::*;
let expr = function(
"main",
intt(),
intt(),
tlet(int(1), tlet(add(arg(), arg()), mul(arg(), int(2)))),
);
let int1 = assume(infunc("main"), int(1));
let arg1 = assume(inlet(int1.clone()), arg());
let addarg1 = add(arg1.clone(), arg1.clone());
let int2 = assume(inlet(addarg1.clone()), int(2));
let arg2 = assume(inlet(addarg1.clone()), arg());
let expr2 = function(
"main",
intt(),
intt(),
tlet(
int1,
tlet(
add(arg1.clone(), arg1.clone()),
mul(arg2.clone(), int2.clone()),
),
),
);

egglog_test(
&format!("{expr}"),
&format!("(check (= {expr} {expr2}))"),
vec![
expr.to_program(emptyt(), intt()),
expr2.to_program(emptyt(), intt()),
],
Value::Tuple(vec![]),
Value::Const(Constant::Int(4)),
)
}

#[test]
fn test_switch_contexts() -> crate::Result {
use crate::ast::*;
let expr = function("main", intt(), intt(), tif(ttrue(), int(1), int(2)));
let pred = assume(infunc("main"), ttrue());
let expr2 = function(
"main",
intt(),
intt(),
tif(
pred.clone(),
assume(inif(true, pred.clone()), int(1)),
assume(inif(false, pred.clone()), int(2)),
),
);
egglog_test(
&format!("{expr}"),
&format!("(check (= {expr} {expr2}))"),
vec![
expr.to_program(emptyt(), intt()),
expr2.to_program(emptyt(), intt()),
],
Value::Tuple(vec![]),
Value::Const(Constant::Int(1)),
)
}

#[test]
fn test_dowhile_cycle_assume() -> crate::Result {
use crate::ast::*;
// loop runs one iteration and returns 3
let myloop = dowhile(single(int(2)), parallel!(tfalse(), int(3)));
let expr = function("main", intt(), intt(), myloop);

let int2 = single(assume(infunc("main"), int(2)));
let inner_assume = inloop(int2.clone(), parallel!(tfalse(), int(3)));
let expr2 = function(
"main",
intt(),
intt(),
dowhile(
int2.clone(),
parallel!(
assume(inner_assume.clone(), tfalse()),
assume(inner_assume.clone(), int(3)),
),
),
);

egglog_test(
&format!(
"{expr}
(union {} {expr})",
single(int(3))
),
&format!("(check (= {expr} {expr2}))"),
vec![
expr.to_program(emptyt(), intt()),
expr2.to_program(emptyt(), intt()),
],
Value::Tuple(vec![]),
Value::Tuple(vec![Value::Const(Constant::Int(3))]),
)
}
1 change: 1 addition & 0 deletions tree_assume/src/utility/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
pub(crate) mod assume;

0 comments on commit 43d0592

Please sign in to comment.