diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 96f5bc0e917d..73086399390d 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -5281,6 +5281,7 @@ class TruncateGenerator : public llvm::InstVisitor, SI.isVolatile(), SI.getOrdering(), SI.getSyncScopeID(), /*mask=*/nullptr); } + // TODO Is there a possibility we GEP a const and get a FP value? void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) { return; } void visitPHINode(llvm::PHINode &phi) { return; } void visitCastInst(llvm::CastInst &CI) { @@ -5431,11 +5432,20 @@ class TruncateGenerator : public llvm::InstVisitor, newI->eraseFromParent(); return true; } + void visitIntrinsicInst(llvm::IntrinsicInst &II) { handleIntrinsic(II, II.getIntrinsicID()); } - void visitReturnInst(llvm::ReturnInst &I) { return; } + void visitReturnInst(llvm::ReturnInst &I) { + auto newI = cast(getNewFromOriginal(&I)); + if (newI->getNumOperands() == 0) + return; + IRBuilder<> B(newI); + if (isa(newI->getOperand(0))) + newI->setOperand(0, createFPRTConstCall(B, newI->getReturnValue())); + return; + } void visitBranchInst(llvm::BranchInst &I) { return; } void visitSwitchInst(llvm::SwitchInst &I) { return; } diff --git a/enzyme/test/Integration/Truncate/simple.cpp b/enzyme/test/Integration/Truncate/simple.cpp index 635a2e3bc04c..d40c1620ffd2 100644 --- a/enzyme/test/Integration/Truncate/simple.cpp +++ b/enzyme/test/Integration/Truncate/simple.cpp @@ -89,16 +89,16 @@ int main() { double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(intrinsics, FROM, TO)(a, b), FROM, TO); APPROX_EQ(trunc, truth, 1e-5); } + { + double a = 2; + double b = 3; + double truth = constt(a, b); + a = __enzyme_truncate_mem_value(a, FROM, TO); + b = __enzyme_truncate_mem_value(b, FROM, TO); + double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(constt, FROM, TO)(a, b), FROM, TO); + APPROX_EQ(trunc, truth, 1e-5); + } #endif - // { - // double a = 2; - // double b = 3; - // double truth = intrinsics(a, b); - // a = __enzyme_truncate_mem_value(a, FROM, TO); - // b = __enzyme_truncate_mem_value(b, FROM, TO); - // double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(constt, FROM, TO)(a, b), FROM, TO); - // APPROX_EQ(trunc, truth, 1e-5); - // } #ifdef TRUNC_OP {