From 30cc978335230c6fda9c313f9998cfb299470e23 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Mon, 22 Apr 2024 10:39:41 +0900 Subject: [PATCH 1/6] Add debug info to fprt runtime calls --- enzyme/Enzyme/Enzyme.cpp | 1 + enzyme/Enzyme/EnzymeLogic.cpp | 56 +++- enzyme/Enzyme/EnzymeLogic.h | 6 + enzyme/include/enzyme/fprt/mpfr-test.h | 306 ++++++++++++++++++ enzyme/include/enzyme/fprt/mpfr.h | 58 ++-- enzyme/test/Integration/Truncate/simple.cpp | 3 + .../Integration/Truncate/truncate-all.cpp | 19 +- 7 files changed, 404 insertions(+), 45 deletions(-) create mode 100644 enzyme/include/enzyme/fprt/mpfr-test.h diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index e2b90c52c511..13cd4d747cb9 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -24,6 +24,7 @@ // //===----------------------------------------------------------------------===// #include +#include #if LLVM_VERSION_MAJOR >= 16 #define private public diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 96f5bc0e917d..0a8edc526890 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -42,6 +42,9 @@ #include "llvm/IR/LLVMContext.h" #include "llvm/Support/ErrorHandling.h" #include +#include +#include +#include #if LLVM_VERSION_MAJOR >= 16 #define private public @@ -5025,6 +5028,8 @@ class TruncateUtils { Type *fromType; Type *toType; LLVMContext &ctx; + EnzymeLogic &Logic; + Value *NullPtr; private: std::string getOriginalFPRTName(std::string Name) { @@ -5077,22 +5082,24 @@ class TruncateUtils { CallInst *createFPRTGeneric(llvm::IRBuilderBase &B, std::string Name, const SmallVectorImpl &ArgsIn, - llvm::Type *RetTy) { + llvm::Type *RetTy, Value *LocStr) { SmallVector Args(ArgsIn.begin(), ArgsIn.end()); Args.push_back(B.getInt64(truncation.getTo().exponentWidth)); Args.push_back(B.getInt64(truncation.getTo().significandWidth)); Args.push_back(B.getInt64(truncation.getMode())); + Args.push_back(LocStr); auto FprtFunc = getFPRTFunc(Name, Args, RetTy); return cast(B.CreateCall(FprtFunc, Args)); } public: - TruncateUtils(FloatTruncation truncation, Module *M) - : truncation(truncation), M(M), ctx(M->getContext()) { + TruncateUtils(FloatTruncation truncation, Module *M, EnzymeLogic &Logic) + : truncation(truncation), M(M), ctx(M->getContext()), Logic(Logic) { fromType = truncation.getFromType(ctx); toType = truncation.getToType(ctx); if (fromType == toType) assert(truncation.isToFPRT()); + NullPtr = ConstantPointerNull::get(PointerType::get(ctx, 0)); } Type *getFromType() { return fromType; } @@ -5103,23 +5110,51 @@ class TruncateUtils { assert(V->getType() == getFromType()); SmallVector Args; Args.push_back(V); - return createFPRTGeneric(B, "const", Args, getToType()); + return createFPRTGeneric(B, "const", Args, getToType(), NullPtr); } CallInst *createFPRTNewCall(llvm::IRBuilderBase &B, Value *V) { assert(V->getType() == getFromType()); SmallVector Args; Args.push_back(V); - return createFPRTGeneric(B, "new", Args, getToType()); + return createFPRTGeneric(B, "new", Args, getToType(), NullPtr); } CallInst *createFPRTGetCall(llvm::IRBuilderBase &B, Value *V) { SmallVector Args; Args.push_back(V); - return createFPRTGeneric(B, "get", Args, getToType()); + return createFPRTGeneric(B, "get", Args, getToType(), NullPtr); } CallInst *createFPRTDeleteCall(llvm::IRBuilderBase &B, Value *V) { SmallVector Args; Args.push_back(V); - return createFPRTGeneric(B, "delete", Args, B.getVoidTy()); + return createFPRTGeneric(B, "delete", Args, B.getVoidTy(), NullPtr); + } + // This will result in a unique string for each location, which means the + // runtime can check whether two operations are the same with a simple pointer + // comparison. However, we need LTO for this to be the case across different + // compilation units. + GlobalValue *getUniquedLocStr(Instruction &I) { + auto M = I.getParent()->getParent()->getParent(); + std::string FileName = M->getName().str(); + + unsigned LineNo = 0; + unsigned ColNo = 0; + if (I.getDebugLoc().get()) { + LineNo = I.getDebugLoc().getLine(); + ColNo = I.getDebugLoc().getCol(); + } + + auto Key = std::make_tuple(FileName, LineNo, ColNo); + auto It = Logic.UniqDebugLocStrs.find(Key); + + if (It != Logic.UniqDebugLocStrs.end()) + return It->second; + + std::string LocStr = + FileName + ":" + std::to_string(LineNo) + ":" + std::to_string(ColNo); + auto GV = createPrivateGlobalForString(*M, LocStr, true); + Logic.UniqDebugLocStrs[Key] = GV; + + return GV; } CallInst *createFPRTOpCall(llvm::IRBuilderBase &B, llvm::Instruction &I, llvm::Type *RetTy, @@ -5146,7 +5181,7 @@ class TruncateUtils { llvm_unreachable("Unexpected instruction for conversion to FPRT"); } createOriginalFPRTFunc(I, Name, ArgsIn, RetTy); - return createFPRTGeneric(B, Name, ArgsIn, RetTy); + return createFPRTGeneric(B, Name, ArgsIn, RetTy, getUniquedLocStr(I)); } }; @@ -5165,7 +5200,7 @@ class TruncateGenerator : public llvm::InstVisitor, TruncateGenerator(ValueToValueMapTy &originalToNewFn, FloatTruncation truncation, Function *oldFunc, Function *newFunc, EnzymeLogic &Logic) - : TruncateUtils(truncation, newFunc->getParent()), + : TruncateUtils(truncation, newFunc->getParent(), Logic), originalToNewFn(originalToNewFn), truncation(truncation), oldFunc(oldFunc), newFunc(newFunc), mode(truncation.getMode()), Logic(Logic), ctx(newFunc->getContext()) {} @@ -5559,7 +5594,8 @@ bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, Value *converted = nullptr; auto truncation = FloatTruncation(from, to, TruncMemMode); - TruncateUtils TU(truncation, B.GetInsertBlock()->getParent()->getParent()); + TruncateUtils TU(truncation, B.GetInsertBlock()->getParent()->getParent(), + *this); if (isTruncate) converted = TU.createFPRTNewCall(B, v); else diff --git a/enzyme/Enzyme/EnzymeLogic.h b/enzyme/Enzyme/EnzymeLogic.h index ae23fb74a781..fc2555c658fb 100644 --- a/enzyme/Enzyme/EnzymeLogic.h +++ b/enzyme/Enzyme/EnzymeLogic.h @@ -31,6 +31,7 @@ #define ENZYME_LOGIC_H #include +#include #include #include @@ -404,9 +405,14 @@ struct FloatTruncation { std::string mangleFrom() const { return from.to_string(); } }; +typedef std::map, + llvm::GlobalValue *> + UniqDebugLocStrsTy; + class EnzymeLogic { public: PreProcessCache PPC; + UniqDebugLocStrsTy UniqDebugLocStrs; /// \p PostOpt is whether to perform basic /// optimization of the function after synthesis diff --git a/enzyme/include/enzyme/fprt/mpfr-test.h b/enzyme/include/enzyme/fprt/mpfr-test.h new file mode 100644 index 000000000000..1b24cf88fb09 --- /dev/null +++ b/enzyme/include/enzyme/fprt/mpfr-test.h @@ -0,0 +1,306 @@ +//===- fprt/mpfr - MPFR wrappers ---------------------------------------===// +// +// Enzyme Project +// +// Part of the Enzyme Project, under the Apache License v2.0 with LLVM +// Exceptions. See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +// If using this code in an academic setting, please cite the following: +// @incollection{enzymeNeurips, +// title = {Instead of Rewriting Foreign Code for Machine Learning, +// Automatically Synthesize Fast Gradients}, +// author = {Moses, William S. and Churavy, Valentin}, +// booktitle = {Advances in Neural Information Processing Systems 33}, +// year = {2020}, +// note = {To appear in}, +// } +// +//===----------------------------------------------------------------------===// +// +// This file contains easy to use wrappers around MPFR functions. +// +//===----------------------------------------------------------------------===// +#ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ +#define __ENZYME_RUNTIME_ENZYME_MPFR__ + +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// TODO s +// +// (for MPFR ver. 2.1) +// +// We need to set the range of the allowed exponent using `mpfr_set_emin` and +// `mpfr_set_emax`. (This means we can also play with whether the range is +// centered around 0 (1?) or somewhere else) +// +// (also these need to be mutex'ed as the exponent change is global in mpfr and +// not float-specific) ... (mpfr seems to have thread safe mode - check if it is +// enabled or if it is enabled by default) +// +// For that we need to do this check: +// If the user changes the exponent range, it is her/his responsibility to +// check that all current floating-point variables are in the new allowed +// range (for example using mpfr_check_range), otherwise the subsequent +// behavior will be undefined, in the sense of the ISO C standard. +// +// MPFR docs state the following: +// Note: Overflow handling is still experimental and currently implemented +// partially. If an overflow occurs internally at the wrong place, anything +// can happen (crash, wrong results, etc). +// +// Which we would like to avoid somehow. +// +// MPFR also has this limitation that we need to address for accurate +// simulation: +// [...] subnormal numbers are not implemented. +// +// TODO we need to provide f32 versions, and also instrument the +// truncation/expansion between f32/f64/etc + +#define __ENZYME_MPFR_ATTRIBUTES __attribute__((weak)) +#define __ENZYME_MPFR_ORIGINAL_ATTRIBUTES __attribute__((weak)) +#define __ENZYME_MPFR_DEFAULT_ROUNDING_MODE GMP_RNDN + +static bool __enzyme_fprt_is_mem_mode(int64_t mode) { return mode & 0b0001; } +static bool __enzyme_fprt_is_op_mode(int64_t mode) { return mode & 0b0010; } + +typedef struct { + mpfr_t v; +} __enzyme_fp; + +static double __enzyme_fprt_ptr_to_double(__enzyme_fp *p) { + return *((double *)(&p)); +} +static __enzyme_fp *__enzyme_fprt_double_to_ptr(double d) { + return *((__enzyme_fp **)(&d)); +} + +__ENZYME_MPFR_ATTRIBUTES +double __enzyme_fprt_64_52_get(double _a, int64_t exponent, int64_t significand, + int64_t mode, char *loc) { + printf("%p, %s\n", loc, loc); + __enzyme_fp *a = __enzyme_fprt_double_to_ptr(_a); + return mpfr_get_d(a->v, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); +} + +__ENZYME_MPFR_ATTRIBUTES +double __enzyme_fprt_64_52_new(double _a, int64_t exponent, int64_t significand, + int64_t mode, char *loc) { + printf("%p, %s\n", loc, loc); + __enzyme_fp *a = (__enzyme_fp *)malloc(sizeof(__enzyme_fp)); + mpfr_init2(a->v, significand); + mpfr_set_d(a->v, _a, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); + return __enzyme_fprt_ptr_to_double(a); +} + +__ENZYME_MPFR_ATTRIBUTES +double __enzyme_fprt_64_52_const(double _a, int64_t exponent, + int64_t significand, int64_t mode, char *loc) { + printf("%p, %s\n", loc, loc); + // TODO This should really be called only once for an appearance in the code, + // currently it is called every time a flop uses a constant. + return __enzyme_fprt_64_52_new(_a, exponent, significand, mode, loc); +} + +__ENZYME_MPFR_ATTRIBUTES +__enzyme_fp *__enzyme_fprt_64_52_new_intermediate(int64_t exponent, + int64_t significand, + int64_t mode, char *loc) { + printf("%p, %s\n", loc, loc); + __enzyme_fp *a = (__enzyme_fp *)malloc(sizeof(__enzyme_fp)); + mpfr_init2(a->v, significand); + return a; +} + +__ENZYME_MPFR_ATTRIBUTES +void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, + int64_t mode, char *loc) { + printf("%p, %s\n", loc, loc); + free(__enzyme_fprt_double_to_ptr(a)); +} + +#define __ENZYME_MPFR_SINGOP(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ + RET, MPFR_GET, ARG1, MPFR_SET_ARG1, \ + ROUNDING_MODE) \ + __ENZYME_MPFR_ATTRIBUTES \ + RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ + ARG1 a, int64_t exponent, int64_t significand, int64_t mode, \ + char *loc) { \ + printf("%p, %s\n", loc, loc); \ + if (__enzyme_fprt_is_op_mode(mode)) { \ + mpfr_t ma, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mc); \ + return c; \ + } else if (__enzyme_fprt_is_mem_mode(mode)) { \ + __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ + __enzyme_fp *mc = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + mpfr_##MPFR_FUNC_NAME(mc->v, ma->v, ROUNDING_MODE); \ + return __enzyme_fprt_ptr_to_double(mc); \ + } else { \ + abort(); \ + } \ + } + +// TODO this is a bit sketchy if the user cast their float to int before calling +// this. We need to detect these patterns +#define __ENZYME_MPFR_BIN_INT(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, \ + FROM_TYPE, RET, MPFR_GET, ARG1, MPFR_SET_ARG1, \ + ARG2, ROUNDING_MODE) \ + __ENZYME_MPFR_ATTRIBUTES \ + RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode, \ + char *loc) { \ + printf("%p, %s\n", loc, loc); \ + if (__enzyme_fprt_is_op_mode(mode)) { \ + mpfr_t ma, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, b, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mc); \ + return c; \ + } else if (__enzyme_fprt_is_mem_mode(mode)) { \ + __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ + __enzyme_fp *mc = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + mpfr_##MPFR_FUNC_NAME(mc->v, ma->v, b, ROUNDING_MODE); \ + return __enzyme_fprt_ptr_to_double(mc); \ + } else { \ + abort(); \ + } \ + } + +#define __ENZYME_MPFR_BIN(OP_TYPE, LLVM_OP_NAME, MPFR_FUNC_NAME, FROM_TYPE, \ + RET, MPFR_GET, ARG1, MPFR_SET_ARG1, ARG2, \ + MPFR_SET_ARG2, ROUNDING_MODE) \ + __ENZYME_MPFR_ATTRIBUTES \ + RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode, \ + char *loc) { \ + printf("%p, %s\n", loc, loc); \ + if (__enzyme_fprt_is_op_mode(mode)) { \ + mpfr_t ma, mb, mc; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mb, significand); \ + mpfr_init2(mc, significand); \ + mpfr_set_##MPFR_SET_ARG1(ma, a, ROUNDING_MODE); \ + mpfr_set_##MPFR_SET_ARG2(mb, b, ROUNDING_MODE); \ + mpfr_##MPFR_FUNC_NAME(mc, ma, mb, ROUNDING_MODE); \ + RET c = mpfr_get_##MPFR_GET(mc, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mb); \ + mpfr_clear(mc); \ + return c; \ + } else if (__enzyme_fprt_is_mem_mode(mode)) { \ + __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ + __enzyme_fp *mb = __enzyme_fprt_double_to_ptr(b); \ + __enzyme_fp *mc = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ + mpfr_##MPFR_FUNC_NAME(mc->v, ma->v, mb->v, ROUNDING_MODE); \ + return __enzyme_fprt_ptr_to_double(mc); \ + } else { \ + abort(); \ + } \ + } + +#define __ENZYME_MPFR_FMULADD(LLVM_OP_NAME, FROM_TYPE, TYPE, MPFR_TYPE, \ + LLVM_TYPE, ROUNDING_MODE) \ + __ENZYME_MPFR_ATTRIBUTES \ + TYPE __enzyme_fprt_##FROM_TYPE##_intr_##LLVM_OP_NAME##_##LLVM_TYPE( \ + TYPE a, TYPE b, TYPE c, int64_t exponent, int64_t significand, \ + int64_t mode, char *loc) { \ + printf("%p, %s\n", loc, loc); \ + if (__enzyme_fprt_is_op_mode(mode)) { \ + mpfr_t ma, mb, mc, mmul, madd; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mb, significand); \ + mpfr_init2(mc, significand); \ + mpfr_init2(mmul, significand); \ + mpfr_init2(madd, significand); \ + mpfr_set_##MPFR_TYPE(ma, a, ROUNDING_MODE); \ + mpfr_set_##MPFR_TYPE(mb, b, ROUNDING_MODE); \ + mpfr_set_##MPFR_TYPE(mc, c, ROUNDING_MODE); \ + mpfr_mul(mmul, ma, mb, ROUNDING_MODE); \ + mpfr_add(madd, mmul, mc, ROUNDING_MODE); \ + TYPE res = mpfr_get_##MPFR_TYPE(madd, ROUNDING_MODE); \ + mpfr_clear(ma); \ + mpfr_clear(mb); \ + mpfr_clear(mc); \ + mpfr_clear(mmul); \ + mpfr_clear(madd); \ + return res; \ + } else if (__enzyme_fprt_is_mem_mode(mode)) { \ + __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ + __enzyme_fp *mb = __enzyme_fprt_double_to_ptr(b); \ + __enzyme_fp *mc = __enzyme_fprt_double_to_ptr(c); \ + double mmul = __enzyme_fprt_##FROM_TYPE##_binop_fmul( \ + __enzyme_fprt_ptr_to_double(ma), __enzyme_fprt_ptr_to_double(mb), \ + exponent, significand, mode, loc); \ + double madd = __enzyme_fprt_##FROM_TYPE##_binop_fadd( \ + mmul, __enzyme_fprt_ptr_to_double(mc), exponent, significand, mode, \ + loc); \ + return madd; \ + } else { \ + abort(); \ + } \ + } + +// TODO This does not currently make distinctions between ordered/unordered. +#define __ENZYME_MPFR_FCMP_IMPL(NAME, ORDERED, CMP, FROM_TYPE, TYPE, MPFR_GET, \ + ROUNDING_MODE) \ + __ENZYME_MPFR_ATTRIBUTES \ + bool __enzyme_fprt_##FROM_TYPE##_fcmp_##NAME( \ + TYPE a, TYPE b, int64_t exponent, int64_t significand, int64_t mode, \ + char *loc) { \ + printf("%p, %s\n", loc, loc); \ + if (__enzyme_fprt_is_op_mode(mode)) { \ + mpfr_t ma, mb; \ + mpfr_init2(ma, significand); \ + mpfr_init2(mb, significand); \ + mpfr_set_##MPFR_GET(ma, a, ROUNDING_MODE); \ + mpfr_set_##MPFR_GET(mb, b, ROUNDING_MODE); \ + int ret = mpfr_cmp(ma, mb); \ + mpfr_clear(ma); \ + mpfr_clear(mb); \ + return ret CMP; \ + } else if (__enzyme_fprt_is_mem_mode(mode)) { \ + __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ + __enzyme_fp *mb = __enzyme_fprt_double_to_ptr(b); \ + int ret = mpfr_cmp(ma->v, mb->v); \ + return ret CMP; \ + } else { \ + abort(); \ + } \ + } + +__ENZYME_MPFR_ORIGINAL_ATTRIBUTES +bool __enzyme_fprt_original_64_52_intr_llvm_is_fpclass_f64(double a, + int32_t tests); +__ENZYME_MPFR_ATTRIBUTES bool +__enzyme_fprt_64_52_intr_llvm_is_fpclass_f64(double a, int32_t tests) { + return __enzyme_fprt_original_64_52_intr_llvm_is_fpclass_f64(a, tests); +} + +#include "flops.def" + +#ifdef __cplusplus +} +#endif + +#endif // #ifndef __ENZYME_RUNTIME_ENZYME_MPFR__ diff --git a/enzyme/include/enzyme/fprt/mpfr.h b/enzyme/include/enzyme/fprt/mpfr.h index a75cfbd84f15..53c97e0c213f 100644 --- a/enzyme/include/enzyme/fprt/mpfr.h +++ b/enzyme/include/enzyme/fprt/mpfr.h @@ -32,18 +32,6 @@ extern "C" { #endif -// TODO TODO TODO -// TODO TODO TODO -// TODO TODO TODO -// TODO TODO TODO -// TODO TODO TODO -// I dont think we intercept comparisons - we most definitely should. -// TODO TODO TODO -// TODO TODO TODO -// TODO TODO TODO -// TODO TODO TODO -// TODO TODO TODO - // TODO s // // (for MPFR ver. 2.1) @@ -73,9 +61,6 @@ extern "C" { // simulation: // [...] subnormal numbers are not implemented. // -// TODO maybe take debug info as parameter - then we can emit warnings or tie -// operations to source location -// // TODO we need to provide f32 versions, and also instrument the // truncation/expansion between f32/f64/etc @@ -99,14 +84,14 @@ static __enzyme_fp *__enzyme_fprt_double_to_ptr(double d) { __ENZYME_MPFR_ATTRIBUTES double __enzyme_fprt_64_52_get(double _a, int64_t exponent, int64_t significand, - int64_t mode) { + int64_t mode, char *loc) { __enzyme_fp *a = __enzyme_fprt_double_to_ptr(_a); return mpfr_get_d(a->v, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); } __ENZYME_MPFR_ATTRIBUTES double __enzyme_fprt_64_52_new(double _a, int64_t exponent, int64_t significand, - int64_t mode) { + int64_t mode, char *loc) { __enzyme_fp *a = (__enzyme_fp *)malloc(sizeof(__enzyme_fp)); mpfr_init2(a->v, significand); mpfr_set_d(a->v, _a, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); @@ -115,16 +100,16 @@ double __enzyme_fprt_64_52_new(double _a, int64_t exponent, int64_t significand, __ENZYME_MPFR_ATTRIBUTES double __enzyme_fprt_64_52_const(double _a, int64_t exponent, - int64_t significand, int64_t mode) { + int64_t significand, int64_t mode, char *loc) { // TODO This should really be called only once for an appearance in the code, // currently it is called every time a flop uses a constant. - return __enzyme_fprt_64_52_new(_a, exponent, significand, mode); + return __enzyme_fprt_64_52_new(_a, exponent, significand, mode, loc); } __ENZYME_MPFR_ATTRIBUTES __enzyme_fp *__enzyme_fprt_64_52_new_intermediate(int64_t exponent, int64_t significand, - int64_t mode) { + int64_t mode, char *loc) { __enzyme_fp *a = (__enzyme_fp *)malloc(sizeof(__enzyme_fp)); mpfr_init2(a->v, significand); return a; @@ -132,7 +117,7 @@ __enzyme_fp *__enzyme_fprt_64_52_new_intermediate(int64_t exponent, __ENZYME_MPFR_ATTRIBUTES void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, - int64_t mode) { + int64_t mode, char *loc) { free(__enzyme_fprt_double_to_ptr(a)); } @@ -141,7 +126,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, ROUNDING_MODE) \ __ENZYME_MPFR_ATTRIBUTES \ RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, int64_t exponent, int64_t significand, int64_t mode) { \ + ARG1 a, int64_t exponent, int64_t significand, int64_t mode, \ + char *loc) { \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mc; \ mpfr_init2(ma, significand); \ @@ -154,8 +140,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, return c; \ } else if (__enzyme_fprt_is_mem_mode(mode)) { \ __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ - __enzyme_fp *mc = \ - __enzyme_fprt_64_52_new_intermediate(exponent, significand, mode); \ + __enzyme_fp *mc = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ mpfr_##MPFR_FUNC_NAME(mc->v, ma->v, ROUNDING_MODE); \ return __enzyme_fprt_ptr_to_double(mc); \ } else { \ @@ -170,7 +156,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, ARG2, ROUNDING_MODE) \ __ENZYME_MPFR_ATTRIBUTES \ RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode) { \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode, \ + char *loc) { \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mc; \ mpfr_init2(ma, significand); \ @@ -183,8 +170,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, return c; \ } else if (__enzyme_fprt_is_mem_mode(mode)) { \ __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ - __enzyme_fp *mc = \ - __enzyme_fprt_64_52_new_intermediate(exponent, significand, mode); \ + __enzyme_fp *mc = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ mpfr_##MPFR_FUNC_NAME(mc->v, ma->v, b, ROUNDING_MODE); \ return __enzyme_fprt_ptr_to_double(mc); \ } else { \ @@ -197,7 +184,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, MPFR_SET_ARG2, ROUNDING_MODE) \ __ENZYME_MPFR_ATTRIBUTES \ RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ - ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode) { \ + ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode, \ + char *loc) { \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mb, mc; \ mpfr_init2(ma, significand); \ @@ -214,8 +202,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, } else if (__enzyme_fprt_is_mem_mode(mode)) { \ __enzyme_fp *ma = __enzyme_fprt_double_to_ptr(a); \ __enzyme_fp *mb = __enzyme_fprt_double_to_ptr(b); \ - __enzyme_fp *mc = \ - __enzyme_fprt_64_52_new_intermediate(exponent, significand, mode); \ + __enzyme_fp *mc = __enzyme_fprt_64_52_new_intermediate( \ + exponent, significand, mode, loc); \ mpfr_##MPFR_FUNC_NAME(mc->v, ma->v, mb->v, ROUNDING_MODE); \ return __enzyme_fprt_ptr_to_double(mc); \ } else { \ @@ -228,7 +216,7 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, __ENZYME_MPFR_ATTRIBUTES \ TYPE __enzyme_fprt_##FROM_TYPE##_intr_##LLVM_OP_NAME##_##LLVM_TYPE( \ TYPE a, TYPE b, TYPE c, int64_t exponent, int64_t significand, \ - int64_t mode) { \ + int64_t mode, char *loc) { \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mb, mc, mmul, madd; \ mpfr_init2(ma, significand); \ @@ -254,9 +242,10 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, __enzyme_fp *mc = __enzyme_fprt_double_to_ptr(c); \ double mmul = __enzyme_fprt_##FROM_TYPE##_binop_fmul( \ __enzyme_fprt_ptr_to_double(ma), __enzyme_fprt_ptr_to_double(mb), \ - exponent, significand, mode); \ + exponent, significand, mode, loc); \ double madd = __enzyme_fprt_##FROM_TYPE##_binop_fadd( \ - mmul, __enzyme_fprt_ptr_to_double(mc), exponent, significand, mode); \ + mmul, __enzyme_fprt_ptr_to_double(mc), exponent, significand, mode, \ + loc); \ return madd; \ } else { \ abort(); \ @@ -268,7 +257,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, ROUNDING_MODE) \ __ENZYME_MPFR_ATTRIBUTES \ bool __enzyme_fprt_##FROM_TYPE##_fcmp_##NAME( \ - TYPE a, TYPE b, int64_t exponent, int64_t significand, int64_t mode) { \ + TYPE a, TYPE b, int64_t exponent, int64_t significand, int64_t mode, \ + char *loc) { \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mb; \ mpfr_init2(ma, significand); \ diff --git a/enzyme/test/Integration/Truncate/simple.cpp b/enzyme/test/Integration/Truncate/simple.cpp index 635a2e3bc04c..5e819f5a4897 100644 --- a/enzyme/test/Integration/Truncate/simple.cpp +++ b/enzyme/test/Integration/Truncate/simple.cpp @@ -5,6 +5,9 @@ // RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -DTRUNC_MEM -DTRUNC_OP -O2 %s -o %s.a.out %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -lm -lmpfr && %s.a.out ; fi // RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -g -DTRUNC_MEM -DTRUNC_OP -O2 %s -o %s.a.out %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -lm -lmpfr && %s.a.out ; fi +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -DTRUNC_MEM -DTRUNC_OP -O2 %s -o %s.a.out %newLoadClangEnzyme -include enzyme/fprt/mpfr-test.h -lm -lmpfr && %s.a.out ; fi +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -g -DTRUNC_MEM -DTRUNC_OP -O2 %s -o %s.a.out %newLoadClangEnzyme -include enzyme/fprt/mpfr-test.h -lm -lmpfr && %s.a.out ; fi + #include #include "../test_utils.h" diff --git a/enzyme/test/Integration/Truncate/truncate-all.cpp b/enzyme/test/Integration/Truncate/truncate-all.cpp index d5038d4750cb..87a45ea0d416 100644 --- a/enzyme/test/Integration/Truncate/truncate-all.cpp +++ b/enzyme/test/Integration/Truncate/truncate-all.cpp @@ -16,12 +16,29 @@ // RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -include enzyme/fprt/mpfr.h -O3 %s -o %s.a.out %newLoadClangEnzyme -mllvm --enzyme-truncate-all="11-52to3-7" -lmpfr -lm && %s.a.out | FileCheck --check-prefix TO_3_7 %s; fi // TO_3_7: 897581056.000000 +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -g -include enzyme/fprt/mpfr-test.h -O3 %s -o %s.a.out %newLoadClangEnzyme -mllvm --enzyme-truncate-all="11-52to3-7" -lmpfr -lm && %s.a.out | FileCheck --check-prefix CHECK-LOCS %s; fi +// CHECK-LOCS: 0x[[op1:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op1loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op2:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op2loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op3:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op3loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op4:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op4loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op5:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op5loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op6:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op6loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op7:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op7loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op1]], {{.*}}truncate-all.cpp:[[op1loc]] +// CHECK-LOCS-NEXT: 0x[[op2]], {{.*}}truncate-all.cpp:[[op2loc]] +// CHECK-LOCS-NEXT: 0x[[op3]], {{.*}}truncate-all.cpp:[[op3loc]] +// CHECK-LOCS-NEXT: 0x[[op4]], {{.*}}truncate-all.cpp:[[op4loc]] +// CHECK-LOCS-NEXT: 0x[[op5]], {{.*}}truncate-all.cpp:[[op5loc]] +// CHECK-LOCS-NEXT: 0x[[op6]], {{.*}}truncate-all.cpp:[[op6loc]] +// CHECK-LOCS-NEXT: 0x[[op7]], {{.*}}truncate-all.cpp:[[op7loc]] + + #include #include "../test_utils.h" -#define N 10 +#define N 6 #define floatty double From 386e432d5b783e15f228346ab8f0134defcd5a6d Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Mon, 22 Apr 2024 11:02:53 +0900 Subject: [PATCH 2/6] Also log op name --- enzyme/include/enzyme/fprt/mpfr-test.h | 30 +++++++++++++++++--------- 1 file changed, 20 insertions(+), 10 deletions(-) diff --git a/enzyme/include/enzyme/fprt/mpfr-test.h b/enzyme/include/enzyme/fprt/mpfr-test.h index 1b24cf88fb09..b2d365eef1b6 100644 --- a/enzyme/include/enzyme/fprt/mpfr-test.h +++ b/enzyme/include/enzyme/fprt/mpfr-test.h @@ -85,7 +85,8 @@ static __enzyme_fp *__enzyme_fprt_double_to_ptr(double d) { __ENZYME_MPFR_ATTRIBUTES double __enzyme_fprt_64_52_get(double _a, int64_t exponent, int64_t significand, int64_t mode, char *loc) { - printf("%p, %s\n", loc, loc); + if (loc) + printf("%p, %s\n", loc, loc); __enzyme_fp *a = __enzyme_fprt_double_to_ptr(_a); return mpfr_get_d(a->v, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); } @@ -93,7 +94,8 @@ double __enzyme_fprt_64_52_get(double _a, int64_t exponent, int64_t significand, __ENZYME_MPFR_ATTRIBUTES double __enzyme_fprt_64_52_new(double _a, int64_t exponent, int64_t significand, int64_t mode, char *loc) { - printf("%p, %s\n", loc, loc); + if (loc) + printf("%p, %s\n", loc, loc); __enzyme_fp *a = (__enzyme_fp *)malloc(sizeof(__enzyme_fp)); mpfr_init2(a->v, significand); mpfr_set_d(a->v, _a, __ENZYME_MPFR_DEFAULT_ROUNDING_MODE); @@ -103,7 +105,8 @@ double __enzyme_fprt_64_52_new(double _a, int64_t exponent, int64_t significand, __ENZYME_MPFR_ATTRIBUTES double __enzyme_fprt_64_52_const(double _a, int64_t exponent, int64_t significand, int64_t mode, char *loc) { - printf("%p, %s\n", loc, loc); + if (loc) + printf("%p, %s\n", loc, loc); // TODO This should really be called only once for an appearance in the code, // currently it is called every time a flop uses a constant. return __enzyme_fprt_64_52_new(_a, exponent, significand, mode, loc); @@ -113,7 +116,8 @@ __ENZYME_MPFR_ATTRIBUTES __enzyme_fp *__enzyme_fprt_64_52_new_intermediate(int64_t exponent, int64_t significand, int64_t mode, char *loc) { - printf("%p, %s\n", loc, loc); + if (loc) + printf("%p, %s\n", loc, loc); __enzyme_fp *a = (__enzyme_fp *)malloc(sizeof(__enzyme_fp)); mpfr_init2(a->v, significand); return a; @@ -122,7 +126,8 @@ __enzyme_fp *__enzyme_fprt_64_52_new_intermediate(int64_t exponent, __ENZYME_MPFR_ATTRIBUTES void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, int64_t mode, char *loc) { - printf("%p, %s\n", loc, loc); + if (loc) + printf("%p, %s\n", loc, loc); free(__enzyme_fprt_double_to_ptr(a)); } @@ -133,7 +138,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ ARG1 a, int64_t exponent, int64_t significand, int64_t mode, \ char *loc) { \ - printf("%p, %s\n", loc, loc); \ + if (loc) \ + printf("%p, %s, %s\n", loc, #LLVM_OP_NAME, loc); \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mc; \ mpfr_init2(ma, significand); \ @@ -164,7 +170,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode, \ char *loc) { \ - printf("%p, %s\n", loc, loc); \ + if (loc) \ + printf("%p, %s, %s\n", loc, #LLVM_OP_NAME, loc); \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mc; \ mpfr_init2(ma, significand); \ @@ -193,7 +200,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, RET __enzyme_fprt_##FROM_TYPE##_##OP_TYPE##_##LLVM_OP_NAME( \ ARG1 a, ARG2 b, int64_t exponent, int64_t significand, int64_t mode, \ char *loc) { \ - printf("%p, %s\n", loc, loc); \ + if (loc) \ + printf("%p, %s, %s\n", loc, #LLVM_OP_NAME, loc); \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mb, mc; \ mpfr_init2(ma, significand); \ @@ -225,7 +233,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, TYPE __enzyme_fprt_##FROM_TYPE##_intr_##LLVM_OP_NAME##_##LLVM_TYPE( \ TYPE a, TYPE b, TYPE c, int64_t exponent, int64_t significand, \ int64_t mode, char *loc) { \ - printf("%p, %s\n", loc, loc); \ + if (loc) \ + printf("%p, %s, %s\n", loc, #LLVM_OP_NAME, loc); \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mb, mc, mmul, madd; \ mpfr_init2(ma, significand); \ @@ -268,7 +277,8 @@ void __enzyme_fprt_64_52_delete(double a, int64_t exponent, int64_t significand, bool __enzyme_fprt_##FROM_TYPE##_fcmp_##NAME( \ TYPE a, TYPE b, int64_t exponent, int64_t significand, int64_t mode, \ char *loc) { \ - printf("%p, %s\n", loc, loc); \ + if (loc) \ + printf("%p, %s, %s\n", loc, "fcmp" #NAME, loc); \ if (__enzyme_fprt_is_op_mode(mode)) { \ mpfr_t ma, mb; \ mpfr_init2(ma, significand); \ From 21b3fe5e1efc060066d0a2c730cbb1a988046665 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Mon, 22 Apr 2024 11:10:27 +0900 Subject: [PATCH 3/6] Fix tests --- enzyme/Enzyme/EnzymeLogic.cpp | 2 +- enzyme/test/Enzyme/Truncate/cmp.ll | 2 +- enzyme/test/Enzyme/Truncate/const.ll | 6 +++--- enzyme/test/Enzyme/Truncate/intrinsic.ll | 24 ++++++++++++------------ enzyme/test/Enzyme/Truncate/simple.ll | 6 +++--- enzyme/test/Enzyme/Truncate/value.ll | 4 ++-- 6 files changed, 22 insertions(+), 22 deletions(-) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 0a8edc526890..eb313b009306 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -5099,7 +5099,7 @@ class TruncateUtils { toType = truncation.getToType(ctx); if (fromType == toType) assert(truncation.isToFPRT()); - NullPtr = ConstantPointerNull::get(PointerType::get(ctx, 0)); + NullPtr = ConstantPointerNull::get(getDefaultAnonymousTapeType(ctx)); } Type *getFromType() { return fromType; } diff --git a/enzyme/test/Enzyme/Truncate/cmp.ll b/enzyme/test/Enzyme/Truncate/cmp.ll index d33c40d7de11..15140bdb5f75 100644 --- a/enzyme/test/Enzyme/Truncate/cmp.ll +++ b/enzyme/test/Enzyme/Truncate/cmp.ll @@ -29,7 +29,7 @@ entry: } ; CHECK: define internal i1 @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y) { -; CHECK-NEXT: %res = call i1 @__enzyme_fprt_64_52_fcmp_olt(double %x, double %y, i64 8, i64 23, i64 1) +; CHECK-NEXT: %res = call i1 @__enzyme_fprt_64_52_fcmp_olt(double %x, double %y, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) ; CHECK-NEXT: ret i1 %res ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/Truncate/const.ll b/enzyme/test/Enzyme/Truncate/const.ll index 25c5c5ee4c3b..b90b20615a93 100644 --- a/enzyme/test/Enzyme/Truncate/const.ll +++ b/enzyme/test/Enzyme/Truncate/const.ll @@ -23,12 +23,12 @@ entry: } ; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x) { -; CHECK-NEXT: %1 = call double @__enzyme_fprt_64_52_const(double 1.000000e+00, i64 8, i64 23, i64 1) -; CHECK-NEXT: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %x, double %1, i64 8, i64 23, i64 1) +; CHECK-NEXT: %1 = call double @__enzyme_fprt_64_52_const(double 1.000000e+00, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) +; CHECK-NEXT: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %x, double %1, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) ; CHECK-NEXT: ret double %res ; CHECK-NEXT: } ; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to11_7_f(double %x) { -; CHECK-NEXT: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %x, double 1.000000e+00, i64 3, i64 7, i64 2) +; CHECK-NEXT: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %x, double 1.000000e+00, i64 3, i64 7, i64 2, {{.*}}i8{{.*}}) ; CHECK-NEXT: ret double %res ; CHECK-NEXT: } diff --git a/enzyme/test/Enzyme/Truncate/intrinsic.ll b/enzyme/test/Enzyme/Truncate/intrinsic.ll index 3e5fe36b3784..a5899e75c68c 100644 --- a/enzyme/test/Enzyme/Truncate/intrinsic.ll +++ b/enzyme/test/Enzyme/Truncate/intrinsic.ll @@ -42,28 +42,28 @@ entry: } ; CHECK: define internal double @__enzyme_done_truncate_mem_func_64_52to32_23_f(double %x, double %y) { -; CHECK-DAG: %1 = call double @__enzyme_fprt_64_52_func_pow(double %x, double %y, i64 8, i64 23, i64 1) -; CHECK-DAG: %2 = call double @__enzyme_fprt_64_52_intr_llvm_pow_f64(double %x, double %y, i64 8, i64 23, i64 1) -; CHECK-DAG: %3 = call double @__enzyme_fprt_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 8, i64 23, i64 1) -; CHECK-DAG: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %2, double %3, i64 8, i64 23, i64 1) +; CHECK-DAG: %1 = call double @__enzyme_fprt_64_52_func_pow(double %x, double %y, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) +; CHECK-DAG: %2 = call double @__enzyme_fprt_64_52_intr_llvm_pow_f64(double %x, double %y, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) +; CHECK-DAG: %3 = call double @__enzyme_fprt_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) +; CHECK-DAG: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %2, double %3, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) ; CHECK-DAG: call void @llvm.nvvm.barrier0() ; CHECK-DAG: ret double %res ; CHECK-DAG: } ; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to32_23_f(double %x, double %y) { -; CHECK-DAG: %1 = call double @__enzyme_fprt_64_52_func_pow(double %x, double %y, i64 8, i64 23, i64 2) -; CHECK-DAG: %2 = call double @__enzyme_fprt_64_52_intr_llvm_pow_f64(double %x, double %y, i64 8, i64 23, i64 2) -; CHECK-DAG: %3 = call double @__enzyme_fprt_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 8, i64 23, i64 2) -; CHECK-DAG: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %2, double %3, i64 8, i64 23, i64 2) +; CHECK-DAG: %1 = call double @__enzyme_fprt_64_52_func_pow(double %x, double %y, i64 8, i64 23, i64 2, {{.*}}i8{{.*}}) +; CHECK-DAG: %2 = call double @__enzyme_fprt_64_52_intr_llvm_pow_f64(double %x, double %y, i64 8, i64 23, i64 2, {{.*}}i8{{.*}}) +; CHECK-DAG: %3 = call double @__enzyme_fprt_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 8, i64 23, i64 2, {{.*}}i8{{.*}}) +; CHECK-DAG: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %2, double %3, i64 8, i64 23, i64 2, {{.*}}i8{{.*}}) ; CHECK-DAG: call void @llvm.nvvm.barrier0() ; CHECK-DAG: ret double %res ; CHECK-DAG: } ; CHECK: define internal double @__enzyme_done_truncate_op_func_64_52to11_7_f(double %x, double %y) { -; CHECK-DAG: %1 = call double @__enzyme_fprt_64_52_func_pow(double %x, double %y, i64 3, i64 7, i64 2) -; CHECK-DAG: %2 = call double @__enzyme_fprt_64_52_intr_llvm_pow_f64(double %x, double %y, i64 3, i64 7, i64 2) -; CHECK-DAG: %3 = call double @__enzyme_fprt_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 3, i64 7, i64 2) -; CHECK-DAG: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %2, double %3, i64 3, i64 7, i64 2) +; CHECK-DAG: %1 = call double @__enzyme_fprt_64_52_func_pow(double %x, double %y, i64 3, i64 7, i64 2, {{.*}}i8{{.*}}) +; CHECK-DAG: %2 = call double @__enzyme_fprt_64_52_intr_llvm_pow_f64(double %x, double %y, i64 3, i64 7, i64 2, {{.*}}i8{{.*}}) +; CHECK-DAG: %3 = call double @__enzyme_fprt_64_52_intr_llvm_powi_f64_i16(double %x, i16 2, i64 3, i64 7, i64 2, {{.*}}i8{{.*}}) +; CHECK-DAG: %res = call double @__enzyme_fprt_64_52_binop_fadd(double %2, double %3, i64 3, i64 7, i64 2, {{.*}}i8{{.*}}) ; CHECK-DAG: call void @llvm.nvvm.barrier0() ; CHECK-DAG: ret double %res ; CHECK-DAG: } diff --git a/enzyme/test/Enzyme/Truncate/simple.ll b/enzyme/test/Enzyme/Truncate/simple.ll index 747e268ae381..cd94c87aba46 100644 --- a/enzyme/test/Enzyme/Truncate/simple.ll +++ b/enzyme/test/Enzyme/Truncate/simple.ll @@ -36,21 +36,21 @@ entry: ; CHECK: define internal void @__enzyme_done_truncate_mem_func_64_52to32_23_f(double* %x) { ; CHECK-DAG: %y = load double, double* %x, align 8 -; CHECK-DAG: %m = call double @__enzyme_fprt_64_52_binop_fmul(double %y, double %y, i64 8, i64 23, i64 1) +; CHECK-DAG: %m = call double @__enzyme_fprt_64_52_binop_fmul(double %y, double %y, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) ; CHECK-DAG: store double %m, double* %x, align 8 ; CHECK-DAG: ret void ; CHECK-DAG: } ; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52to32_23_f(double* %x) { ; CHECK-DAG: %y = load double, double* %x, align 8 -; CHECK-DAG: %m = call double @__enzyme_fprt_64_52_binop_fmul(double %y, double %y, i64 8, i64 23, i64 2) +; CHECK-DAG: %m = call double @__enzyme_fprt_64_52_binop_fmul(double %y, double %y, i64 8, i64 23, i64 2, {{.*}}i8{{.*}}) ; CHECK-DAG: store double %m, double* %x, align 8 ; CHECK-DAG: ret void ; CHECK-DAG: } ; CHECK: define internal void @__enzyme_done_truncate_op_func_64_52to11_7_f(double* %x) { ; CHECK-DAG: %y = load double, double* %x, align 8 -; CHECK-DAG: %m = call double @__enzyme_fprt_64_52_binop_fmul(double %y, double %y, i64 3, i64 7, i64 2) +; CHECK-DAG: %m = call double @__enzyme_fprt_64_52_binop_fmul(double %y, double %y, i64 3, i64 7, i64 2, {{.*}}i8{{.*}}) ; CHECK-DAG: store double %m, double* %x, align 8 ; CHECK-DAG: ret void ; CHECK-DAG: } diff --git a/enzyme/test/Enzyme/Truncate/value.ll b/enzyme/test/Enzyme/Truncate/value.ll index fa79e93440bb..1722b4fc1efc 100644 --- a/enzyme/test/Enzyme/Truncate/value.ll +++ b/enzyme/test/Enzyme/Truncate/value.ll @@ -18,10 +18,10 @@ entry: ; CHECK: define double @expand_tester(double %a, double* %c) { ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = call double @__enzyme_fprt_64_52_get(double %a, i64 8, i64 23, i64 1) +; CHECK-NEXT: %0 = call double @__enzyme_fprt_64_52_get(double %a, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) ; CHECK-NEXT: ret double %0 ; CHECK: define double @truncate_tester(double %a) { ; CHECK-NEXT: entry: -; CHECK-NEXT: %0 = call double @__enzyme_fprt_64_52_new(double %a, i64 8, i64 23, i64 1) +; CHECK-NEXT: %0 = call double @__enzyme_fprt_64_52_new(double %a, i64 8, i64 23, i64 1, {{.*}}i8{{.*}}) ; CHECK-NEXT: ret double %0 From ff5e1b4cd95e13dd191c02459292e5e394f1e456 Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Mon, 22 Apr 2024 12:13:14 +0900 Subject: [PATCH 4/6] Fix file location --- enzyme/Enzyme/EnzymeLogic.cpp | 11 ++++--- .../Truncate/truncate-all-header.h | 15 ++++++++++ .../Integration/Truncate/truncate-all.cpp | 30 +++++++++++-------- 3 files changed, 39 insertions(+), 17 deletions(-) create mode 100644 enzyme/test/Integration/Truncate/truncate-all-header.h diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index eb313b009306..60de63e0e59f 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -5134,13 +5134,16 @@ class TruncateUtils { // compilation units. GlobalValue *getUniquedLocStr(Instruction &I) { auto M = I.getParent()->getParent()->getParent(); - std::string FileName = M->getName().str(); + std::string FileName = "unknown"; unsigned LineNo = 0; unsigned ColNo = 0; - if (I.getDebugLoc().get()) { - LineNo = I.getDebugLoc().getLine(); - ColNo = I.getDebugLoc().getCol(); + + DILocation *DL = I.getDebugLoc(); + if (DL) { + FileName = DL->getFilename(); + LineNo = DL->getLine(); + ColNo = DL->getColumn(); } auto Key = std::make_tuple(FileName, LineNo, ColNo); diff --git a/enzyme/test/Integration/Truncate/truncate-all-header.h b/enzyme/test/Integration/Truncate/truncate-all-header.h new file mode 100644 index 000000000000..3fd9f0780365 --- /dev/null +++ b/enzyme/test/Integration/Truncate/truncate-all-header.h @@ -0,0 +1,15 @@ +#ifndef TRUNCATE_ALL_HEADER_H_ +#define TRUNCATE_ALL_HEADER_H_ + +#include + +#define N 6 + +#define floatty double + +__attribute__((noinline)) static +floatty intrinsics2(floatty a, floatty b) { + return sin(a) * cos(b); +} + +#endif // TRUNCATE_ALL_HEADER_H_ diff --git a/enzyme/test/Integration/Truncate/truncate-all.cpp b/enzyme/test/Integration/Truncate/truncate-all.cpp index 87a45ea0d416..818c2c603cac 100644 --- a/enzyme/test/Integration/Truncate/truncate-all.cpp +++ b/enzyme/test/Integration/Truncate/truncate-all.cpp @@ -20,29 +20,23 @@ // CHECK-LOCS: 0x[[op1:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op1loc:.*]] // CHECK-LOCS-NEXT: 0x[[op2:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op2loc:.*]] // CHECK-LOCS-NEXT: 0x[[op3:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op3loc:.*]] -// CHECK-LOCS-NEXT: 0x[[op4:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op4loc:.*]] -// CHECK-LOCS-NEXT: 0x[[op5:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op5loc:.*]] -// CHECK-LOCS-NEXT: 0x[[op6:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op6loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op4:[0-9a-f]*]], {{.*}}truncate-all-header.h:[[op4loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op5:[0-9a-f]*]], {{.*}}truncate-all-header.h:[[op5loc:.*]] +// CHECK-LOCS-NEXT: 0x[[op6:[0-9a-f]*]], {{.*}}truncate-all-header.h:[[op6loc:.*]] // CHECK-LOCS-NEXT: 0x[[op7:[0-9a-f]*]], {{.*}}truncate-all.cpp:[[op7loc:.*]] // CHECK-LOCS-NEXT: 0x[[op1]], {{.*}}truncate-all.cpp:[[op1loc]] // CHECK-LOCS-NEXT: 0x[[op2]], {{.*}}truncate-all.cpp:[[op2loc]] // CHECK-LOCS-NEXT: 0x[[op3]], {{.*}}truncate-all.cpp:[[op3loc]] -// CHECK-LOCS-NEXT: 0x[[op4]], {{.*}}truncate-all.cpp:[[op4loc]] -// CHECK-LOCS-NEXT: 0x[[op5]], {{.*}}truncate-all.cpp:[[op5loc]] -// CHECK-LOCS-NEXT: 0x[[op6]], {{.*}}truncate-all.cpp:[[op6loc]] +// CHECK-LOCS-NEXT: 0x[[op4]], {{.*}}truncate-all-header.h:[[op4loc]] +// CHECK-LOCS-NEXT: 0x[[op5]], {{.*}}truncate-all-header.h:[[op5loc]] +// CHECK-LOCS-NEXT: 0x[[op6]], {{.*}}truncate-all-header.h:[[op6loc]] // CHECK-LOCS-NEXT: 0x[[op7]], {{.*}}truncate-all.cpp:[[op7loc]] -#include - +#include "truncate-all-header.h" #include "../test_utils.h" -#define N 6 - -#define floatty double - - __attribute__((noinline)) floatty simple_add(floatty a, floatty b) { return a + b; @@ -52,6 +46,13 @@ floatty intrinsics(floatty a, floatty b) { return sqrt(a) * pow(b, 2); } __attribute__((noinline)) +floatty compute2(floatty *A, floatty *B, floatty *C, int n) { + for (int i = 0; i < n; i++) { + C[i] = A[i] / 2 + intrinsics2(A[i], simple_add(B[i] * 10000, 0.000001)); + } + return C[0]; +} +__attribute__((noinline)) floatty compute(floatty *A, floatty *B, floatty *C, int n) { for (int i = 0; i < n; i++) { C[i] = A[i] / 2 + intrinsics(A[i], simple_add(B[i] * 10000, 0.000001)); @@ -69,6 +70,9 @@ int main() { B[i] = 1 + i % 3; } + compute2(A, B, C, N); + for (int i = 0; i < N; i++) + C[i] = 0; compute(A, B, C, N); printf("%f\n", C[5]); } From 0fcd564641843920ea425c233352053001e4abfb Mon Sep 17 00:00:00 2001 From: Ivan Radanov Ivanov Date: Mon, 22 Apr 2024 15:39:34 +0900 Subject: [PATCH 5/6] Fix older llvm vers --- enzyme/Enzyme/EnzymeLogic.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index 60de63e0e59f..dbf1710e46b9 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -5087,7 +5087,12 @@ class TruncateUtils { Args.push_back(B.getInt64(truncation.getTo().exponentWidth)); Args.push_back(B.getInt64(truncation.getTo().significandWidth)); Args.push_back(B.getInt64(truncation.getMode())); +#if LLVM_VERSION_MAJOR <= 14 + Args.push_back(B.CreateBitCast(LocStr, NullPtr->getType())); +#else Args.push_back(LocStr); +#endif + auto FprtFunc = getFPRTFunc(Name, Args, RetTy); return cast(B.CreateCall(FprtFunc, Args)); } From dc65cf2583aabe7ed1f53cd2a25ff8be6d612a95 Mon Sep 17 00:00:00 2001 From: "Ivan R. Ivanov" Date: Sun, 28 Apr 2024 09:35:53 -0700 Subject: [PATCH 6/6] [Truncate] Handle casts and emit warnings (#1846) * Handle casts and emit warnings * only emit warning/error if trunc from type is used --- enzyme/Enzyme/EnzymeLogic.cpp | 81 ++++++++++--------- enzyme/test/Integration/Truncate/warnings.cpp | 53 ++++++++++++ 2 files changed, 95 insertions(+), 39 deletions(-) create mode 100644 enzyme/test/Integration/Truncate/warnings.cpp diff --git a/enzyme/Enzyme/EnzymeLogic.cpp b/enzyme/Enzyme/EnzymeLogic.cpp index dbf1710e46b9..fbdbc01257fb 100644 --- a/enzyme/Enzyme/EnzymeLogic.cpp +++ b/enzyme/Enzyme/EnzymeLogic.cpp @@ -5213,31 +5213,37 @@ class TruncateGenerator : public llvm::InstVisitor, oldFunc(oldFunc), newFunc(newFunc), mode(truncation.getMode()), Logic(Logic), ctx(newFunc->getContext()) {} - void checkHandled(llvm::Instruction &inst) { - // TODO - // if (all_of(inst.getOperandList(), - // [&](Use *use) { return use->get()->getType() == fromType; })) - // todo(inst); - } + void todo(llvm::Instruction &I) { + if (all_of(I.operands(), + [&](Use &U) { return U.get()->getType() != fromType; })) + return; - // TODO - void handleTrunc(); - void hendleIntToFloat(); - void handleFloatToInt(); + switch (mode) { + case TruncMemMode: + EmitFailure("FPEscaping", I.getDebugLoc(), &I, "FP value escapes!"); + break; + case TruncOpMode: + case TruncOpFullModuleMode: + EmitWarning( + "UnhandledTrunc", I, + "Operation not handled - it will be executed in the original way.", + I); + break; + default: + llvm_unreachable("Unknown trunc mode"); + } + } - void visitInstruction(llvm::Instruction &inst) { + void visitInstruction(llvm::Instruction &I) { using namespace llvm; - // TODO explicitly handle all instructions rather than using the catch all - // below - - switch (inst.getOpcode()) { + switch (I.getOpcode()) { // #include "InstructionDerivatives.inc" default: break; } - checkHandled(inst); + todo(I); } Value *truncate(IRBuilder<> &B, Value *v) { @@ -5264,21 +5270,6 @@ class TruncateGenerator : public llvm::InstVisitor, llvm_unreachable("Unknown trunc mode"); } - void todo(llvm::Instruction &I) { - std::string s; - llvm::raw_string_ostream ss(s); - ss << "cannot handle unknown instruction\n" << I; - if (CustomErrorHandler) { - IRBuilder<> Builder2(getNewFromOriginal(&I)); - CustomErrorHandler(ss.str().c_str(), wrap(&I), ErrorType::NoTruncate, - this, nullptr, wrap(&Builder2)); - return; - } else { - EmitFailure("NoTruncate", I.getDebugLoc(), &I, ss.str()); - return; - } - } - void visitAllocaInst(llvm::AllocaInst &I) { return; } void visitICmpInst(llvm::ICmpInst &I) { return; } void visitFCmpInst(llvm::FCmpInst &CI) { @@ -5327,10 +5318,28 @@ class TruncateGenerator : public llvm::InstVisitor, void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) { return; } void visitPHINode(llvm::PHINode &phi) { return; } void visitCastInst(llvm::CastInst &CI) { + // TODO Try to follow fps through trunc/exts switch (mode) { case TruncMemMode: { - if (CI.getSrcTy() == getFromType() || CI.getDestTy() == getFromType()) - todo(CI); + auto newI = getNewFromOriginal(&CI); + auto newSrc = newI->getOperand(0); + if (CI.getSrcTy() == getFromType()) { + IRBuilder<> B(newI); + if (isa(newSrc)) + return; + newI->setOperand(0, createFPRTGetCall(B, newSrc)); + EmitWarning("FPNoFollow", CI, "Will not follow FP through this cast.", + CI); + } else if (CI.getDestTy() == getFromType()) { + IRBuilder<> B(newI->getNextNode()); + EmitWarning("FPNoFollow", CI, "Will not follow FP through this cast.", + CI); + auto nres = createFPRTNewCall(B, newI); + nres->takeName(newI); + nres->copyIRFlags(newI); + newI->replaceUsesWithIf(nres, + [&](Use &U) { return U.getUser() != nres; }); + } return; } case TruncOpMode: @@ -5585,12 +5594,6 @@ class TruncateGenerator : public llvm::InstVisitor, } return; } - void visitFPTruncInst(FPTruncInst &I) { return; } - void visitFPExtInst(FPExtInst &I) { return; } - void visitFPToUIInst(FPToUIInst &I) { return; } - void visitFPToSIInst(FPToSIInst &I) { return; } - void visitUIToFPInst(UIToFPInst &I) { return; } - void visitSIToFPInst(SIToFPInst &I) { return; } }; bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v, diff --git a/enzyme/test/Integration/Truncate/warnings.cpp b/enzyme/test/Integration/Truncate/warnings.cpp new file mode 100644 index 000000000000..4f730a433549 --- /dev/null +++ b/enzyme/test/Integration/Truncate/warnings.cpp @@ -0,0 +1,53 @@ +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_MEM -O2 %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi +// RUN: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_MEM -O2 -g %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi +// COM: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_OP -O2 %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi +// COM: if [ %llvmver -ge 12 ] && [ %hasMPFR == "yes" ] ; then %clang -c -DTRUNC_OP -O2 -g %s -o /dev/null -emit-llvm %newLoadClangEnzyme -include enzyme/fprt/mpfr.h -Xclang -verify -Rpass=enzyme; fi + +#include + +#define FROM 64 +#define TO 32 + +double bithack(double a) { + return *((int64_t *)&a) + 1; // expected-remark {{Will not follow FP through this cast.}}, expected-remark {{Will not follow FP through this cast.}} +} +__attribute__((noinline)) +float truncf(double a) { + return (float)a; // expected-remark {{Will not follow FP through this cast.}} +} + +double intrinsics(double a, double b) { + return bithack(a) * truncf(b); // expected-remark {{Will not follow FP through this cast.}} +} + +typedef double (*fty)(double *, double *, double *, int); + +typedef double (*fty2)(double, double); + +extern fty __enzyme_truncate_mem_func_2(...); +extern fty2 __enzyme_truncate_mem_func(...); +extern fty __enzyme_truncate_op_func_2(...); +extern fty2 __enzyme_truncate_op_func(...); +extern double __enzyme_truncate_mem_value(...); +extern double __enzyme_expand_mem_value(...); + + +int main() { + #ifdef TRUNC_MEM + { + double a = 2; + double b = 3; + a = __enzyme_truncate_mem_value(a, FROM, TO); + b = __enzyme_truncate_mem_value(b, FROM, TO); + double trunc = __enzyme_expand_mem_value(__enzyme_truncate_mem_func(intrinsics, FROM, TO)(a, b), FROM, TO); + } + #endif + #ifdef TRUNC_OP + { + double a = 2; + double b = 3; + double trunc = __enzyme_truncate_op_func(intrinsics, FROM, TO)(a, b); + } + #endif + +}