Skip to content

Commit

Permalink
Merge pull request #412 from egraphs-good/ajpal-rs2bril
Browse files Browse the repository at this point in the history
Add Rust to Bril
  • Loading branch information
ajpal authored Apr 3, 2024
2 parents 43e4112 + b2e92a9 commit 1fbc7ec
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 31 deletions.
59 changes: 41 additions & 18 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

10 changes: 6 additions & 4 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,10 +20,12 @@ indexmap = "2.0"
fixedbitset = "0.4.2"
smallvec = "1.11.1"

bril2json = { git = "https://github.com/sampsyo/bril", rev = "b20cc2d" }
brilirs = { git = "https://github.com/sampsyo/bril", rev = "b20cc2d" }
bril-rs = { git = "https://github.com/sampsyo/bril", rev = "b20cc2d" }
brilift = { git = "https://github.com/sampsyo/bril", rev = "b20cc2d" }
syn = {version = "2.0", features = ["full", "extra-traits"]}
bril2json = { git = "https://github.com/ajpal/bril", rev = "af21f1c" }
brilirs = { git = "https://github.com/ajpal/bril", rev = "af21f1c" }
bril-rs = { git = "https://github.com/ajpal/bril", rev = "af21f1c" }
brilift = { git = "https://github.com/ajpal/bril", rev = "af21f1c" }
rs2bril = { git = "https://github.com/ajpal/bril", rev = "af21f1c", features = ["import"]}
ordered-float = { version = "3.7" }
serde_json = "1.0.103"
dot-structures = "0.1.1"
Expand Down
13 changes: 10 additions & 3 deletions src/main.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use clap::Parser;
use eggcc::util::{visualize, Run, RunType, TestProgram};
use std::path::PathBuf;
use std::{ffi::OsStr, path::PathBuf};

#[derive(Debug, Parser)]
struct Args {
Expand Down Expand Up @@ -44,7 +44,7 @@ fn main() {
let args = Args::parse();

if let Some(debug_dir) = args.debug_dir {
if let Result::Err(error) = visualize(TestProgram::File(args.file.clone()), debug_dir) {
if let Result::Err(error) = visualize(TestProgram::BrilFile(args.file.clone()), debug_dir) {
eprintln!("{}", error);
return;
}
Expand All @@ -58,8 +58,15 @@ fn main() {
return;
}

let file = match args.file.extension().and_then(OsStr::to_str) {
Some("rs") => TestProgram::RustFile(args.file.clone()),
Some("bril") => TestProgram::BrilFile(args.file.clone()),
Some(x) => panic!("unexpected file extension {x}"),
None => panic!("could not parse file extension"),
};

let run = Run {
prog_with_args: TestProgram::File(args.file.clone()).read_program(),
prog_with_args: file.read_program(),
test_type: args.run_mode,
interp: args.interp,
profile_out: args.profile_out,
Expand Down
27 changes: 22 additions & 5 deletions src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ use graphviz_rust::cmd::Format;
use graphviz_rust::exec;
use graphviz_rust::printer::PrinterContext;
use std::fmt::Debug;
use std::io::Read;
use std::{
ffi::OsStr,
fmt::{Display, Formatter},
Expand Down Expand Up @@ -159,7 +160,7 @@ where
pub enum RunType {
/// Do nothing to the input bril program besides parse it.
/// Output the original program.
Nothing,
Parse,
/// Convert the input bril program to the tree encoding, optimize the program
/// using egglog, and output the resulting bril program.
/// The default way to run this tool.
Expand Down Expand Up @@ -212,7 +213,7 @@ impl RunType {
/// that can be interpreted.
pub fn produces_interpretable(&self) -> bool {
match self {
RunType::Nothing => true,
RunType::Parse => true,
RunType::Optimize => true,
RunType::RvsdgConversion => false,
RunType::RvsdgRoundTrip => true,
Expand Down Expand Up @@ -242,14 +243,15 @@ pub struct ProgWithArguments {
#[derive(Clone)]
pub enum TestProgram {
Prog(ProgWithArguments),
File(PathBuf),
BrilFile(PathBuf),
RustFile(PathBuf),
}

impl TestProgram {
pub fn read_program(self) -> ProgWithArguments {
match self {
TestProgram::Prog(prog) => prog,
TestProgram::File(path) => {
TestProgram::BrilFile(path) => {
let program_read = std::fs::read_to_string(path.clone()).unwrap();
let args = Optimizer::parse_bril_args(&program_read);
let program = Optimizer::parse_bril(&program_read).unwrap();
Expand All @@ -261,6 +263,21 @@ impl TestProgram {
args,
}
}
TestProgram::RustFile(path) => {
let mut src = String::new();
let mut file = std::fs::File::open(path.clone()).unwrap();

file.read_to_string(&mut src).unwrap();
let syntax = syn::parse_file(&src).unwrap();
let name = path.display().to_string();
let program = rs2bril::from_file_to_program(syntax, false, Some(name.clone()));

ProgWithArguments {
program,
name,
args: vec![],
}
}
}
}
}
Expand Down Expand Up @@ -390,7 +407,7 @@ impl Run {
};

let (visualizations, interpretable_out) = match self.test_type {
RunType::Nothing => (
RunType::Parse => (
vec![],
Some(Interpretable::Bril(self.prog_with_args.program.clone())),
),
Expand Down
2 changes: 1 addition & 1 deletion tests/files.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ fn generate_tests(glob: &str) -> Vec<Trial> {

let snapshot = f.to_str().unwrap().contains("small");

for run in Run::all_configurations_for(TestProgram::File(f)) {
for run in Run::all_configurations_for(TestProgram::BrilFile(f)) {
mk_trial(run, snapshot);
}
}
Expand Down

0 comments on commit 1fbc7ec

Please sign in to comment.