From 123b0e8d6b5ec905a49bd73b5f018a510438a0a9 Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 16 Aug 2024 02:30:15 -0700 Subject: [PATCH] simplify ffi wrappers (#169) --- compiler/rustc_codegen_llvm/src/back/write.rs | 14 ----------- compiler/rustc_codegen_llvm/src/builder.rs | 4 ---- compiler/rustc_codegen_llvm/src/llvm/ffi.rs | 23 ++----------------- 3 files changed, 2 insertions(+), 39 deletions(-) diff --git a/compiler/rustc_codegen_llvm/src/back/write.rs b/compiler/rustc_codegen_llvm/src/back/write.rs index ba77a4f20413c..df8519cbc2c35 100644 --- a/compiler/rustc_codegen_llvm/src/back/write.rs +++ b/compiler/rustc_codegen_llvm/src/back/write.rs @@ -1019,16 +1019,6 @@ pub(crate) unsafe fn enzyme_ad( // A really simple check assert!(src_num_args <= target_num_args); - // create enzyme typetrees - let llvm_data_layout = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) }; - let llvm_data_layout = - std::str::from_utf8(unsafe { CStr::from_ptr(llvm_data_layout) }.to_bytes()) - .expect("got a non-UTF8 data-layout from LLVM"); - - let input_tts = - item.inputs.into_iter().map(|x| to_enzyme_typetree(x, llvm_data_layout, llcx)).collect(); - let output_tt = to_enzyme_typetree(item.output, llvm_data_layout, llcx); - let type_analysis: EnzymeTypeAnalysisRef = unsafe {CreateTypeAnalysis(logic_ref, std::ptr::null_mut(), std::ptr::null_mut(), 0)}; @@ -1066,8 +1056,6 @@ pub(crate) unsafe fn enzyme_ad( src_fnc, args_activity, ret_activity, - input_tts, - output_tt, void_ret, ), DiffMode::Reverse => enzyme_rust_reverse_diff( @@ -1076,8 +1064,6 @@ pub(crate) unsafe fn enzyme_ad( src_fnc, args_activity, ret_activity, - input_tts, - output_tt, ), _ => unreachable!(), }; diff --git a/compiler/rustc_codegen_llvm/src/builder.rs b/compiler/rustc_codegen_llvm/src/builder.rs index 0d91399215354..41d63a94a0341 100644 --- a/compiler/rustc_codegen_llvm/src/builder.rs +++ b/compiler/rustc_codegen_llvm/src/builder.rs @@ -87,9 +87,6 @@ pub fn add_tt2<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, fn_def: #[allow(unused)] pub fn add_opt_dbg_helper<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Context, val: &'ll Value, attrs: AutoDiffAttrs, i: usize) { - //pub mode: DiffMode, - //pub ret_activity: DiffActivity, - //pub input_activity: Vec, let inputs = attrs.input_activity; let outputs = attrs.ret_activity; let ad_name = match attrs.mode { @@ -136,7 +133,6 @@ pub fn add_opt_dbg_helper<'ll>(llmod: &'ll llvm::Module, llcx: &'ll llvm::Contex let num_args = llvm::LLVMCountParams(wrapper_fn); let mut args = Vec::with_capacity(num_args as usize + 1); args.push(val); - // metadata !"enzyme_const" let enzyme_const = llvm::LLVMMDStringInContext(llcx, "enzyme_const".as_ptr() as *const c_char, 12); let enzyme_out = llvm::LLVMMDStringInContext(llcx, "enzyme_out".as_ptr() as *const c_char, 10); let enzyme_dup = llvm::LLVMMDStringInContext(llcx, "enzyme_dup".as_ptr() as *const c_char, 10); diff --git a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs index baa729448684b..008825480e2d9 100644 --- a/compiler/rustc_codegen_llvm/src/llvm/ffi.rs +++ b/compiler/rustc_codegen_llvm/src/llvm/ffi.rs @@ -845,15 +845,12 @@ pub enum LLVMVerifierFailureAction { LLVMReturnStatusAction, } -#[allow(dead_code)] pub(crate) unsafe fn enzyme_rust_forward_diff( logic_ref: EnzymeLogicRef, type_analysis: EnzymeTypeAnalysisRef, fnc: &Value, input_diffactivity: Vec, ret_diffactivity: DiffActivity, - _input_tts: Vec, - _output_tt: TypeTree, void_ret: bool, ) -> (&Value, Vec) { let ret_activity = cdiffe_from(ret_diffactivity); @@ -882,9 +879,6 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( }; trace!("ret_primary_ret: {}", &ret_primary_ret); - //let mut args_tree = input_tts.iter().map(|x| x.inner).collect::>(); - //let mut args_tree = vec![TypeTree::new().inner; typetree.input_tt.len()]; - // We don't support volatile / extern / (global?) values. // Just because I didn't had time to test them, and it seems less urgent. let args_uncacheable = vec![0; input_activity.len()]; @@ -900,9 +894,6 @@ pub(crate) unsafe fn enzyme_rust_forward_diff( let tree_tmp = TypeTree::new(); let mut args_tree = vec![tree_tmp.inner; input_activity.len()]; - //let mut args_tree = vec![std::ptr::null_mut(); input_activity.len()]; - //let ret_tt = std::ptr::null_mut(); - //let mut args_tree = vec![TypeTree::new().inner; input_tts.len()]; let ret_tt = TypeTree::new(); let dummy_type = CFnTypeInfo { Arguments: args_tree.as_mut_ptr(), @@ -944,8 +935,6 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( fnc: &Value, rust_input_activity: Vec, ret_activity: DiffActivity, - input_tts: Vec, - _output_tt: TypeTree, ) -> (&Value, Vec) { let (primary_ret, ret_activity) = match ret_activity { DiffActivity::Const => (true, CDIFFE_TYPE::DFT_CONSTANT), @@ -971,8 +960,6 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( input_activity.push(cdiffe_from(x)); } - //let args_tree = input_tts.iter().map(|x| x.inner).collect::>(); - // We don't support volatile / extern / (global?) values. // Just because I didn't had time to test them, and it seems less urgent. let args_uncacheable = vec![0; input_activity.len()]; @@ -982,14 +969,11 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( assert!(num_fnc_args == input_activity.len() as u32); let kv_tmp = IntList { data: std::ptr::null_mut(), size: 0 }; - let mut known_values = vec![kv_tmp; input_tts.len()]; + let mut known_values = vec![kv_tmp; input_activity.len()]; let tree_tmp = TypeTree::new(); - let mut args_tree = vec![tree_tmp.inner; input_tts.len()]; - //let mut args_tree = vec![TypeTree::new().inner; input_tts.len()]; + let mut args_tree = vec![tree_tmp.inner; input_activity.len()]; let ret_tt = TypeTree::new(); - //let mut args_tree = vec![std::ptr::null_mut(); input_tts.len()]; - //let ret_tt = std::ptr::null_mut(); let dummy_type = CFnTypeInfo { Arguments: args_tree.as_mut_ptr(), Return: ret_tt.inner, @@ -1029,9 +1013,6 @@ pub(crate) unsafe fn enzyme_rust_reverse_diff( } extern "C" { - // TODO: can I just ignore the non void return - // EraseFromParent doesn't exist :( - //pub fn LLVMEraseFromParent(BB: &BasicBlock) -> &Value; // Enzyme pub fn LLVMRustAddFncParamAttr<'a>( F: &'a Value,