Skip to content

Commit

Permalink
Speed up sharing-aware extractor and fix spelling (#11)
Browse files Browse the repository at this point in the history
* Faster greedy sharing-aware extractor

* fix spelling
  • Loading branch information
TrevorHansen authored Oct 23, 2023
1 parent 9fb66c6 commit f168a55
Show file tree
Hide file tree
Showing 2 changed files with 60 additions and 38 deletions.
12 changes: 6 additions & 6 deletions plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ def process(js, extractors=[]):
assert len(extractors) == 2
e1, e2 = extractors

e1_cummulative=0
e2_cummulative=0
e1_cumulative=0
e2_cumulative=0

summaries = {}

Expand All @@ -45,8 +45,8 @@ def process(js, extractors=[]):
dag_ratio = d[e1]["dag"] / d[e2]["dag"]
micros_ratio = max(1, d[e1]["micros"]) / max(1, d[e2]["micros"])

e1_cummulative += d[e1]["micros"];
e2_cummulative += d[e2]["micros"];
e1_cumulative += d[e1]["micros"];
e2_cumulative += d[e2]["micros"];

summaries[name] = {
"tree": tree_ratio,
Expand All @@ -57,8 +57,8 @@ def process(js, extractors=[]):
print(f"Error processing {name}")
raise e

print(f"Cummulative time for {e1}: {e1_cummulative/1000:.0f}ms")
print(f"Cummulative time for {e2}: {e2_cummulative/1000:.0f}ms")
print(f"cumulative time for {e1}: {e1_cumulative/1000:.0f}ms")
print(f"cumulative time for {e2}: {e2_cumulative/1000:.0f}ms")

print(f"{e1} / {e2}")

Expand Down
86 changes: 54 additions & 32 deletions src/extract/greedy_dag_1.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,57 +17,79 @@ impl FasterGreedyDagExtractor {
egraph: &EGraph,
node_id: NodeId,
costs: &HashMap<ClassId, CostSet>,
best_cost: Cost,
) -> CostSet {
let node = &egraph[&node_id];

let cid = egraph.nid_to_cid(&node_id);
// No children -> easy.
if node.children.is_empty() {
return CostSet {
costs: std::collections::HashMap::default(),
total: node.cost,
choice: node_id.clone(),
};
}

let mut desc = 0;
let mut children_cost = Cost::default();
for child in &node.children {
let child_cid = egraph.nid_to_cid(child);
let cs = costs.get(child_cid).unwrap();
desc += cs.costs.len();
children_cost += cs.total;
// Get unique classes of children.
let mut childrens_classes = node
.children
.iter()
.map(|c| egraph.nid_to_cid(&c).clone())
.collect::<Vec<ClassId>>();
childrens_classes.sort();
childrens_classes.dedup();

let first_cost = costs.get(&childrens_classes[0]).unwrap();

if childrens_classes.len() == 1 && (node.cost + first_cost.total > best_cost) {
// Shortcut. Can't be cheaper so return junk.
return CostSet {
costs: std::collections::HashMap::default(),
total: INFINITY,
choice: node_id.clone(),
};
}

let mut cost_set = CostSet {
costs: std::collections::HashMap::with_capacity(desc),
total: Cost::default(),
choice: node_id.clone(),
};
// Clone the biggest set and insert the others into it.
let id_of_biggest = childrens_classes
.iter()
.max_by_key(|s| costs.get(s).unwrap().costs.len())
.unwrap();
let mut result = costs.get(&id_of_biggest).unwrap().costs.clone();
for child_cid in &childrens_classes {
if child_cid == id_of_biggest {
continue;
}

for child in &node.children {
let child_cid = egraph.nid_to_cid(child);
cost_set
.costs
.extend(costs.get(child_cid).unwrap().costs.clone());
let next_cost = &costs.get(child_cid).unwrap().costs;
for (key, value) in next_cost.iter() {
result.insert(key.clone(), value.clone());
}
}

let contains = cost_set.costs.contains_key(&cid.clone());
cost_set.costs.insert(cid.clone(), node.cost); // this node.
let cid = egraph.nid_to_cid(&node_id);
let contains = result.contains_key(&cid);
result.insert(cid.clone(), node.cost);

if contains {
cost_set.total = INFINITY;
let result_cost = if contains {
INFINITY
} else {
if cost_set.costs.len() == desc + 1 {
// No extra duplicates are found, so the cost is the current
// nodes cost + the children's cost.
cost_set.total = children_cost + node.cost;
} else {
cost_set.total = cost_set.costs.values().sum();
}
result.values().sum()
};

cost_set
return CostSet {
costs: result,
total: result_cost,
choice: node_id.clone(),
};
}
}

impl FasterGreedyDagExtractor {
fn check(egraph: &EGraph, node_id: NodeId, costs: &HashMap<ClassId, CostSet>) {
let cid = egraph.nid_to_cid(&node_id);
let previous = costs.get(cid).unwrap().total;
let cs = Self::calculate_cost_set(egraph, node_id, costs);
let cs = Self::calculate_cost_set(egraph, node_id, costs, INFINITY);
println!("{} {}", cs.total, previous);
assert!(cs.total >= previous);
}
Expand Down Expand Up @@ -114,7 +136,7 @@ impl Extractor for FasterGreedyDagExtractor {
prev_cost = lookup.unwrap().total;
}

let cost_set = Self::calculate_cost_set(egraph, node_id.clone(), &costs);
let cost_set = Self::calculate_cost_set(egraph, node_id.clone(), &costs, prev_cost);
if cost_set.total < prev_cost {
costs.insert(class_id.clone(), cost_set);
analysis_pending.extend(parents[class_id].iter().cloned());
Expand Down

0 comments on commit f168a55

Please sign in to comment.