From 36f05d0656487aff9d031e82658025af2ee9be61 Mon Sep 17 00:00:00 2001 From: Trevor Hansen Date: Thu, 28 Dec 2023 16:37:01 +1100 Subject: [PATCH] better ilp extractor --- src/extract/ilp_cbc.rs | 232 ++++++++++++++--------------------------- src/main.rs | 2 + 2 files changed, 78 insertions(+), 156 deletions(-) diff --git a/src/extract/ilp_cbc.rs b/src/extract/ilp_cbc.rs index c48b04b..90e7462 100644 --- a/src/extract/ilp_cbc.rs +++ b/src/extract/ilp_cbc.rs @@ -1,11 +1,17 @@ -use core::panic; +/* An ILP extractor that returns the optimal DAG-extraction. + +This extractor is simple so that it's easy to see that it's correct. + +If the timeout is reached, it will return the result of the faster-greedy-dag extractor. +*/ + +// Without a timeout, some will take > 10 hours to finish. +const SOLVING_TIME_LIMIT_SECONDS: u64 = 10; use super::*; use coin_cbc::{Col, Model, Sense}; use indexmap::IndexSet; -const INITIALISE_WITH_BOTTOM_UP: bool = false; - struct ClassVars { active: Col, nodes: Vec, @@ -17,20 +23,14 @@ impl Extractor for CbcExtractor { fn extract(&self, egraph: &EGraph, roots: &[ClassId]) -> ExtractionResult { let mut model = Model::default(); - let true_literal = model.add_binary(); - model.set_col_lower(true_literal, 1.0); + model.set_parameter("seconds", &SOLVING_TIME_LIMIT_SECONDS.to_string()); let vars: IndexMap = egraph .classes() .values() .map(|class| { let cvars = ClassVars { - active: if roots.contains(&class.id) { - // Roots must be active. - true_literal - } else { - model.add_binary() - }, + active: model.add_binary(), nodes: class.nodes.iter().map(|_| model.add_binary()).collect(), }; (class.id.clone(), cvars) @@ -56,60 +56,24 @@ impl Extractor for CbcExtractor { .collect::>() }; - let mut intersection: IndexSet = - childrens_classes_var(egraph[class_id].nodes[0].clone()); - - for node in &egraph[class_id].nodes[1..] 
{ - intersection = intersection - .intersection(&childrens_classes_var(node.clone())) - .cloned() - .collect(); - } - - // A class being active implies that all in the intersection - // of it's children are too. - for c in &intersection { - let row = model.add_row(); - model.set_row_upper(row, 0.0); - model.set_weight(row, class.active, 1.0); - model.set_weight(row, *c, -1.0); - } - for (node_id, &node_active) in egraph[class_id].nodes.iter().zip(&class.nodes) { for child_active in childrens_classes_var(node_id.clone()) { // node active implies child active, encoded as: // node_active <= child_active // node_active - child_active <= 0 - if !intersection.contains(&child_active) { - let row = model.add_row(); - model.set_row_upper(row, 0.0); - model.set_weight(row, node_active, 1.0); - model.set_weight(row, child_active, -1.0); - } + let row = model.add_row(); + model.set_row_upper(row, 0.0); + model.set_weight(row, node_active, 1.0); + model.set_weight(row, child_active, -1.0); } } } model.set_obj_sense(Sense::Minimize); for class in egraph.classes().values() { - let min_cost = class - .nodes - .iter() - .map(|n_id| egraph[n_id].cost) - .min() - .unwrap_or(Cost::default()) - .into_inner(); - - // Most helpful when the members of the class all have the same cost. - // For example if the members' costs are [1,1,1], three terms get - // replaced by one in the objective function. 
- if min_cost != 0.0 { - model.set_obj_coeff(vars[&class.id].active, min_cost); - } - for (node_id, &node_active) in class.nodes.iter().zip(&vars[&class.id].nodes) { let node = &egraph[node_id]; - let node_cost = node.cost.into_inner() - min_cost; + let node_cost = node.cost.into_inner(); assert!(node_cost >= 0.0); if node_cost != 0.0 { @@ -118,36 +82,11 @@ impl Extractor for CbcExtractor { } } - // set initial solution based on bottom up extractor - if INITIALISE_WITH_BOTTOM_UP { - let initial_result = super::bottom_up::BottomUpExtractor.extract(egraph, roots); - for (class, class_vars) in egraph.classes().values().zip(vars.values()) { - if let Some(node_id) = initial_result.choices.get(&class.id) { - model.set_col_initial_solution(class_vars.active, 1.0); - for col in &class_vars.nodes { - model.set_col_initial_solution(*col, 0.0); - } - let node_idx = class.nodes.iter().position(|n| n == node_id).unwrap(); - model.set_col_initial_solution(class_vars.nodes[node_idx], 1.0); - } else { - model.set_col_initial_solution(class_vars.active, 0.0); - } - } + for root in roots { + model.set_col_lower(vars[root].active, 1.0); } - let mut banned_cycles: IndexSet<(ClassId, usize)> = Default::default(); - find_cycles(egraph, |id, i| { - banned_cycles.insert((id, i)); - }); - for (class_id, class_vars) in &vars { - for (i, &node_active) in class_vars.nodes.iter().enumerate() { - if banned_cycles.contains(&(class_id.clone(), i)) { - model.set_col_upper(node_active, 0.0); - model.set_col_lower(node_active, 0.0); - } - } - } - log::info!("@blocked {}", banned_cycles.len()); + block_cycles(&mut model, &vars, &egraph); let solution = model.solve(); log::info!( @@ -157,6 +96,13 @@ impl Extractor for CbcExtractor { solution.raw().obj_value(), ); + if solution.raw().status() != coin_cbc::raw::Status::Finished { + let initial_result = + super::faster_greedy_dag::FasterGreedyDagExtractor.extract(egraph, roots); + log::info!("Unfinished CBC solution"); + return initial_result; + } + let 
mut result = ExtractionResult::default(); for (id, var) in &vars { @@ -172,96 +118,70 @@ impl Extractor for CbcExtractor { } } - let cycles = result.find_cycles(egraph, roots); - assert!(cycles.is_empty()); return result; } } -// from @khaki3 -// fixes bug in egg 0.9.4's version -// https://github.com/egraphs-good/egg/issues/207#issuecomment-1264737441 -fn find_cycles(egraph: &EGraph, mut f: impl FnMut(ClassId, usize)) { - let mut pending: IndexMap> = IndexMap::default(); - - let mut order: IndexMap = IndexMap::default(); - - let mut memo: IndexMap<(ClassId, usize), bool> = IndexMap::default(); - - let mut stack: Vec<(ClassId, usize)> = vec![]; - - let n2c = |nid: &NodeId| egraph.nid_to_cid(nid); - - for class in egraph.classes().values() { - let id = &class.id; - for (i, node_id) in egraph[id].nodes.iter().enumerate() { - let node = &egraph[node_id]; - for child in &node.children { - let child = n2c(child).clone(); - pending - .entry(child) - .or_insert_with(Vec::new) - .push((id.clone(), i)); - } +fn block_cycles(model: &mut Model, vars: &IndexMap, egraph: &EGraph) { + let mut levels: IndexMap = Default::default(); + for c in vars.keys() { + levels.insert(c.clone(), model.add_integer()); + } - if node.is_leaf() { - stack.push((id.clone(), i)); - } + // If n.variable is true, opposite_col will be false and vice versa. 
+ let mut opposite: IndexMap = Default::default(); + for c in vars.values() { + for n in &c.nodes { + let opposite_col = model.add_binary(); + opposite.insert(*n, opposite_col); + let row = model.add_row(); + model.set_row_equal(row, 1.0); + model.set_weight(row, opposite_col, 1.0); + model.set_weight(row, *n, 1.0); } } - let mut count = 0; + for (class_id, c) in vars { + model.set_col_lower(*levels.get(class_id).unwrap(), 0.0); + model.set_col_upper(*levels.get(class_id).unwrap(), vars.len() as f64); - while let Some((id, i)) = stack.pop() { - if memo.get(&(id.clone(), i)).is_some() { - continue; - } + for i in 0..c.nodes.len() { + let n_id = &egraph[class_id].nodes[i]; + let n = &egraph[n_id]; + let var = c.nodes[i]; - let node_id = &egraph[&id].nodes[i]; - let node = &egraph[node_id]; - let mut update = false; - - if node.is_leaf() { - update = true; - } else if node.children.iter().all(|x| order.get(n2c(x)).is_some()) { - if let Some(ord) = order.get(&id) { - update = node.children.iter().all(|x| &order[n2c(x)] < ord); - if !update { - memo.insert((id, i), false); - continue; - } - } else { - update = true; + let children_classes = n + .children + .iter() + .map(|n| egraph[n].eclass.clone()) + .collect::>(); + + if children_classes.contains(class_id) { + // Self loop - disable this node. + // This is clumsier than calling set_col_lower(var,0.0), + // but means it'll be infeasible (rather than producing an + // incorrect solution) if var corresponds to a root node. 
+ let row = model.add_row(); + model.set_weight(row, var, 1.0); + model.set_row_equal(row, 0.0); + continue; } - } - if update { - if order.get(&id).is_none() { - if egraph[node_id].is_leaf() { - order.insert(id.clone(), 0); - } else { - order.insert(id.clone(), count); - count += 1; - } - } - memo.insert((id.clone(), i), true); - if let Some(mut v) = pending.remove(&id) { - stack.append(&mut v); - stack.sort(); - stack.dedup(); - }; - } - } + for cc in children_classes { + assert!(*levels.get(class_id).unwrap() != *levels.get(&cc).unwrap()); - for class in egraph.classes().values() { - let id = &class.id; - for (i, node) in class.nodes.iter().enumerate() { - if let Some(true) = memo.get(&(id.clone(), i)) { - continue; + let row = model.add_row(); + model.set_row_upper(row, -1.0); + model.set_weight(row, *levels.get(class_id).unwrap(), 1.0); + model.set_weight(row, *levels.get(&cc).unwrap(), -1.0); + + // If n.variable is 0, then disable the constraint. + model.set_weight( + row, + *opposite.get(&var).unwrap(), + -((vars.len() + 1) as f64), + ); } - assert!(!egraph[node].is_leaf()); - f(id.clone(), i); } } - assert!(pending.is_empty()); } diff --git a/src/main.rs b/src/main.rs index bddc552..ddaed2f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -36,6 +36,8 @@ fn main() { "global-greedy-dag", extract::global_greedy_dag::GlobalGreedyDagExtractor.boxed(), ), + #[cfg(feature = "ilp-cbc")] + ("ilp-cbc", extract::ilp_cbc::CbcExtractor.boxed()), ] .into_iter() .collect();