diff --git a/compiler/rustc_codegen_ssa/src/back/write.rs b/compiler/rustc_codegen_ssa/src/back/write.rs index fda8330ea8f11..f72a31d2c236b 100644 --- a/compiler/rustc_codegen_ssa/src/back/write.rs +++ b/compiler/rustc_codegen_ssa/src/back/write.rs @@ -118,6 +118,7 @@ pub struct ModuleConfig { pub inline_threshold: Option, pub emit_lifetime_markers: bool, pub llvm_plugins: Vec, + pub autodiff: Vec, } impl ModuleConfig { @@ -259,6 +260,7 @@ impl ModuleConfig { inline_threshold: sess.opts.cg.inline_threshold, emit_lifetime_markers: sess.emit_lifetime_markers(), llvm_plugins: if_regular!(sess.opts.unstable_opts.llvm_plugins.clone(), vec![]), + autodiff: if_regular!(sess.opts.unstable_opts.autodiff.clone(), vec![]), } } diff --git a/compiler/rustc_interface/src/tests.rs b/compiler/rustc_interface/src/tests.rs index 04a7714d4137e..6e27eafd63f59 100644 --- a/compiler/rustc_interface/src/tests.rs +++ b/compiler/rustc_interface/src/tests.rs @@ -729,6 +729,7 @@ fn test_unstable_options_tracking_hash() { // Make sure that changing a [TRACKED] option changes the hash. // tidy-alphabetical-start + tracked!(autodiff, vec![String::from("ad_flags")]); tracked!(allow_features, Some(vec![String::from("lang_items")])); tracked!(always_encode_mir, true); tracked!(asm_comments, true); diff --git a/compiler/rustc_session/src/config.rs b/compiler/rustc_session/src/config.rs index 2219fd5e951a8..00afe0a8f8594 100644 --- a/compiler/rustc_session/src/config.rs +++ b/compiler/rustc_session/src/config.rs @@ -174,6 +174,32 @@ pub enum InstrumentCoverage { Off, } +/// The different settings that the `-Z ad` flag can have. +#[derive(Clone, Copy, PartialEq, Hash, Debug)] +pub enum AutoDiff { + None, + PrintTA, + PrintAA, + PrintPerf, + Print, + PrintModBefore, + PrintModAfterOpts, + PrintModAfterEnzyme, + + /// Enzyme's loose type debug helper (can cause incorrect gradients) + LooseTypes, + /// Output a Module using __enzyme calls to prepare it for opt + enzyme pass usage + OPT, + + /// More flags + NoModOptAfter, + EnableFncOpt, + NoVecUnroll, + NoSafetyChecks, + Inline, + AltPipeline, +} + /// Settings for `-Z instrument-xray` flag. #[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)] pub struct InstrumentXRay { @@ -2781,6 +2807,19 @@ pub fn build_session_options(early_dcx: &mut EarlyDiagCtxt, matches: &getopts::M } } + // Check for unstable values of -Z ad" + //match cg.autodiff { + // AutoDiff::None => {} + // // unstable values + // _ => { + // if !unstable_opts.unstable_options { + // early_dcx.early_fatal( + // "`-Z ad` requires `-Z unstable-options`", + // ); + // } + // } + //} + if let Ok(graphviz_font) = std::env::var("RUSTC_GRAPHVIZ_FONT") { unstable_opts.graphviz_font = graphviz_font; } diff --git a/compiler/rustc_session/src/options.rs b/compiler/rustc_session/src/options.rs index 10a4bdb94d46f..4d5a032d61be1 100644 --- a/compiler/rustc_session/src/options.rs +++ b/compiler/rustc_session/src/options.rs @@ -917,6 +917,42 @@ mod parse { } } + //pub(crate) fn parse_autodiff( + // slot: &mut AutoDiff, + // v: Option<&str>, + //) -> bool { + // if v.is_none() { + // *slot = AutoDiff::None; + // return true; + // } + + // let Some(v) = v else { + // *slot = AutoDiff::None; + // return true; + // }; + + // *slot = match v { + // "None" => AutoDiff::None, + // "PrintTA" => AutoDiff::PrintTA, + // "PrintAA" => AutoDiff::PrintAA, + // "PrintPerf" => AutoDiff::PrintPerf, + // "Print" => AutoDiff::Print, + // "PrintModBefore" => AutoDiff::PrintModBefore, + // "PrintModAfterOpts" => AutoDiff::PrintModAfterOpts, + // "PrintModAfterEnzyme" => AutoDiff::PrintModAfterEnzyme, + // "LooseTypes" => AutoDiff::LooseTypes, + // "OPT" => AutoDiff::OPT, + // "NoModOptAfter" => AutoDiff::NoModOptAfter, + // "EnableFncOpt" => AutoDiff::EnableFncOpt, + // "NoVecUnroll" => AutoDiff::NoVecUnroll, + // "NoSafetyChecks" => AutoDiff::NoSafetyChecks, + // "Inline" => AutoDiff::Inline, + // "AltPipeline" => AutoDiff::AltPipeline, + // _ => return false, + // }; + // true + //} + pub(crate) fn parse_instrument_coverage( slot: &mut InstrumentCoverage, v: Option<&str>, @@ -1544,6 +1580,8 @@ options! { either `loaded` or `not-loaded`."), assume_incomplete_release: bool = (false, parse_bool, [TRACKED], "make cfg(version) treat the current version as incomplete (default: no)"), + autodiff: Vec = (Vec::new(), parse_list, [TRACKED], + "a list autodiff flags to enable (space separated)"), #[rustc_lint_opt_deny_field_access("use `Session::binary_dep_depinfo` instead of this field")] binary_dep_depinfo: bool = (false, parse_bool, [TRACKED], "include artifacts (sysroot, crate dependencies) used during compilation in dep-info \