Skip to content

Commit

Permalink
add RUSTFLAGS version of enzyme flags
Browse files Browse the repository at this point in the history
  • Loading branch information
ZuseZ4 committed Aug 11, 2024
1 parent a55f6f1 commit 024c9ad
Show file tree
Hide file tree
Showing 4 changed files with 80 additions and 0 deletions.
2 changes: 2 additions & 0 deletions compiler/rustc_codegen_ssa/src/back/write.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,6 +118,7 @@ pub struct ModuleConfig {
pub inline_threshold: Option<u32>,
pub emit_lifetime_markers: bool,
pub llvm_plugins: Vec<String>,
pub autodiff: Vec<String>,
}

impl ModuleConfig {
Expand Down Expand Up @@ -259,6 +260,7 @@ impl ModuleConfig {
inline_threshold: sess.opts.cg.inline_threshold,
emit_lifetime_markers: sess.emit_lifetime_markers(),
llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]),
autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]),
}
}

Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_interface/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -729,6 +729,7 @@ fn test_unstable_options_tracking_hash() {

// Make sure that changing a [TRACKED] option changes the hash.
// tidy-alphabetical-start
tracked!(autodiff, vec![String::from("ad_flags")]);
tracked!(allow_features, Some(vec![String::from("lang_items")]));
tracked!(always_encode_mir, true);
tracked!(asm_comments, true);
Expand Down
39 changes: 39 additions & 0 deletions compiler/rustc_session/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,32 @@ pub enum InstrumentCoverage {
Off,
}

/// The different settings that the `-Z ad` flag can have.
#[derive(Clone, Copy, PartialEq, Hash, Debug)]
pub enum AutoDiff {
None,
PrintTA,
PrintAA,
PrintPerf,
Print,
PrintModBefore,
PrintModAfterOpts,
PrintModAfterEnzyme,

/// Enzyme's loose type debug helper (can cause incorrect gradients)
LooseTypes,
/// Output a Module using __enzyme calls to prepare it for opt + enzyme pass usage
OPT,

/// More flags
NoModOptAfter,
EnableFncOpt,
NoVecUnroll,
NoSafetyChecks,
Inline,
AltPipeline,
}

/// Settings for `-Z instrument-xray` flag.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
pub struct InstrumentXRay {
Expand Down Expand Up @@ -2781,6 +2807,19 @@ pub fn build_session_options(early_dcx: &mut EarlyDiagCtxt, matches: &getopts::M
}
}

// Check for unstable values of -Z ad"
//match cg.autodiff {
// AutoDiff::None => {}
// // unstable values
// _ => {
// if !unstable_opts.unstable_options {
// early_dcx.early_fatal(
// "`-Z ad` requires `-Z unstable-options`",
// );
// }
// }
//}

if let Ok(graphviz_font) = std::env::var("RUSTC_GRAPHVIZ_FONT") {
unstable_opts.graphviz_font = graphviz_font;
}
Expand Down
38 changes: 38 additions & 0 deletions compiler/rustc_session/src/options.rs
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,42 @@ mod parse {
}
}

//pub(crate) fn parse_autodiff(
// slot: &mut AutoDiff,
// v: Option<&str>,
//) -> bool {
// if v.is_none() {
// *slot = AutoDiff::None;
// return true;
// }

// let Some(v) = v else {
// *slot = AutoDiff::None;
// return true;
// };

// *slot = match v {
// "None" => AutoDiff::None,
// "PrintTA" => AutoDiff::PrintTA,
// "PrintAA" => AutoDiff::PrintAA,
// "PrintPerf" => AutoDiff::PrintPerf,
// "Print" => AutoDiff::Print,
// "PrintModBefore" => AutoDiff::PrintModBefore,
// "PrintModAfterOpts" => AutoDiff::PrintModAfterOpts,
// "PrintModAfterEnzyme" => AutoDiff::PrintModAfterEnzyme,
// "LooseTypes" => AutoDiff::LooseTypes,
// "OPT" => AutoDiff::OPT,
// "NoModOptAfter" => AutoDiff::NoModOptAfter,
// "EnableFncOpt" => AutoDiff::EnableFncOpt,
// "NoVecUnroll" => AutoDiff::NoVecUnroll,
// "NoSafetyChecks" => AutoDiff::NoSafetyChecks,
// "Inline" => AutoDiff::Inline,
// "AltPipeline" => AutoDiff::AltPipeline,
// _ => return false,
// };
// true
//}

pub(crate) fn parse_instrument_coverage(
slot: &mut InstrumentCoverage,
v: Option<&str>,
Expand Down Expand Up @@ -1544,6 +1580,8 @@ options! {
either `loaded` or `not-loaded`."),
assume_incomplete_release: bool = (false, parse_bool, [TRACKED],
"make cfg(version) treat the current version as incomplete (default: no)"),
autodiff: Vec<String> = (Vec::new(), parse_list, [TRACKED],
"a list autodiff flags to enable (space separated)"),
#[rustc_lint_opt_deny_field_access("use `Session::binary_dep_depinfo` instead of this field")]
binary_dep_depinfo: bool = (false, parse_bool, [TRACKED],
"include artifacts (sysroot, crate dependencies) used during compilation in dep-info \
Expand Down

0 comments on commit 024c9ad

Please sign in to comment.