From 65aba5990f090ea696a1a8dc7ebec0fe496019ad Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 25 Apr 2024 12:08:10 +0900 Subject: [PATCH 1/2] [Truncate] Corrently handle constant returns --- enzyme/Enzyme/EnzymeLogic.cpp | 10 +++++++++- enzyme/test/Integration/Truncate/simple.cpp | 18 +++++++++--------- 2 files changed, 18 insertions(+), 10 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 96f5bc0e917d..30a89330b0fc 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -5281,6 +5281,7 @@ class TruncateGenerator : public llvm::InstVisitor, SI.isVolatile(), SI.getOrdering(), SI.getSyncScopeID(), /*mask=*/nullptr); } + // TODO Is there a possibility we GEP a const and get a FP value? void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) { return; } void visitPHINode(llvm::PHINode &phi) { return; } void visitCastInst(llvm::CastInst &CI) { @@ -5431,11 +5432,18 @@ class TruncateGenerator : public llvm::InstVisitor, newI->eraseFromParent(); return true; } + void visitIntrinsicInst(llvm::IntrinsicInst &II) { handleIntrinsic(II, II.getIntrinsicID()); } - void visitReturnInst(llvm::ReturnInst &I) { return; } + void visitReturnInst(llvm::ReturnInst &I) { + auto newI = cast(getNewFromOriginal(&I)); + IRBuilder<> B(newI); + if (isa(newI->getOperand(0))) + newI->setOperand(0, createFPRTConstCall(B, newI->getReturnValue())); + return; + } void visitBranchInst(llvm::BranchInst &I) { return; } void visitSwitchInst(llvm::SwitchInst &I) { return; } diff --git a/enzyme/test/Integration/Truncate/simple.cpp b/enzyme/test/Integration/Truncate/simple.cpp index 635a2e3bc04c..d40c1620ffd2 100644 --- a/enzyme/test/Integration/Truncate/simple.cpp +++ b/enzyme/test/Integration/Truncate/simple.cpp @@ -89,16 +89,16 @@ int main() { double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(intrinsics, FROM, TO)(a, b), FROM, TO); APPROX_EQ(trunc, truth, 1e-5); } + { + double a = 2; + double b = 3; + double truth = constt(a, b); + a = __enzyme_truncate_mem_value(a, FROM, TO); + b = __enzyme_truncate_mem_value(b, FROM, TO); + double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(constt, FROM, TO)(a, b), FROM, TO); + APPROX_EQ(trunc, truth, 1e-5); + } #endif - // { - // double a = 2; - // double b = 3; - // double truth = intrinsics(a, b); - // a = __enzyme_truncate_mem_value(a, FROM, TO); - // b = __enzyme_truncate_mem_value(b, FROM, TO); - // double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(constt, FROM, TO)(a, b), FROM, TO); - // APPROX_EQ(trunc, truth, 1e-5); - // } #ifdef TRUNC_OP { From 16960d347b93941bab778d977363c948c70a9783 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Thu, 25 Apr 2024 16:08:05 +0900 Subject: [PATCH 2/2] fix --- enzyme/Enzyme/EnzymeLogic.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 30a89330b0fc..73086399390d 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -5439,6 +5439,8 @@ class TruncateGenerator : public llvm::InstVisitor, void visitReturnInst(llvm::ReturnInst &I) { auto newI = cast(getNewFromOriginal(&I)); + if (newI->getNumOperands() == 0) + return; IRBuilder<> B(newI); if (isa(newI->getOperand(0))) newI->setOperand(0, createFPRTConstCall(B, newI->getReturnValue()));