diff --git a/src/main.rs b/src/main.rs index e3dd1c0..6d6704d 100644 --- a/src/main.rs +++ b/src/main.rs @@ -4,6 +4,8 @@ pub use extract::*; use egraph_serialize::*; +mod to_egraph_serialized; + use indexmap::IndexMap; use ordered_float::NotNan; @@ -128,6 +130,8 @@ fn main() { .unwrap() .unwrap_or_else(|| "out.json".into()); + let pruned_filename: Option = args.opt_value_from_str("--pruned").unwrap(); + let filename: String = args.free_from_str().unwrap(); let rest = args.finish(); @@ -152,6 +156,12 @@ fn main() { result.check(&egraph); + if let Some(pruned_filename) = pruned_filename { + let egraph = to_egraph_serialized::get_term(&egraph, &result); + egraph.to_json_file(pruned_filename.clone()).unwrap(); + println!("Wrote pruned egraph to {}", pruned_filename.display()); + } + let tree = result.tree_cost(&egraph, &egraph.root_eclasses); let dag = result.dag_cost(&egraph, &egraph.root_eclasses); diff --git a/src/to_egraph_serialized.rs b/src/to_egraph_serialized.rs new file mode 100644 index 0000000..2d30973 --- /dev/null +++ b/src/to_egraph_serialized.rs @@ -0,0 +1,54 @@ +use egraph_serialize::{ClassId, NodeId}; +use indexmap::IndexMap; + +use crate::ExtractionResult; + +pub fn get_term( + egraph: &egraph_serialize::EGraph, + result: &ExtractionResult, +) -> egraph_serialize::EGraph { + let choices = &result.choices; + assert!( + egraph.root_eclasses.len() == 1, + "expected exactly one root eclass", + ); + let root_cid = egraph.root_eclasses[0].clone(); + let mut result_egraph = egraph_serialize::EGraph::default(); + // populate_egraph(egraph, &mut result_egraph, choices, root_cid); + for cid in choices.keys() { + let node = &choices[cid]; + // add the node to the result egraph + if !result_egraph.nodes.contains_key(node) { + let mut new_node = egraph.nodes[node].clone(); + new_node.children = egraph.nodes[node] + .children + .iter() + .map(|child| choices[egraph.nid_to_cid(&child)].clone()) + .collect(); + + result_egraph.add_node(node.clone(), new_node); + } + } + + // find number of eclasses in the original egraph + let mut eclasses = std::collections::HashSet::new(); + for enode in egraph.nodes.values() { + eclasses.insert(enode.eclass.clone()); + } + result_egraph.root_eclasses = egraph.root_eclasses.clone(); + result_egraph +} + +fn populate_egraph( + egraph: &egraph_serialize::EGraph, + result_egraph: &mut egraph_serialize::EGraph, + choices: &IndexMap, + cid: ClassId, +) { + // get the node for the eclass + let node = &choices[&cid]; + // add the node to the result egraph + if !result_egraph.nodes.contains_key(node) { + result_egraph.add_node(node.clone(), egraph.nodes[node].clone()); + } +}