Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add debug info to fprt runtime calls #1843

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
//
//===----------------------------------------------------------------------===//
#include <llvm/Config/llvm-config.h>
#include <llvm/IR/GlobalValue.h>

#if LLVM_VERSION_MAJOR >= 16
#define private public
Expand Down
145 changes: 96 additions & 49 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,9 @@
#include "llvm/IR/LLVMContext.h"
#include "llvm/Support/ErrorHandling.h"
#include <cmath>
#include <llvm-c/Core.h>
#include <llvm/Transforms/Instrumentation.h>
#include <tuple>

#if LLVM_VERSION_MAJOR >= 16
#define private public
Expand Down Expand Up @@ -5025,6 +5028,8 @@ class TruncateUtils {
Type *fromType;
Type *toType;
LLVMContext &ctx;
EnzymeLogic &Logic;
Value *NullPtr;

private:
std::string getOriginalFPRTName(std::string Name) {
Expand Down Expand Up @@ -5077,22 +5082,29 @@ class TruncateUtils {

CallInst *createFPRTGeneric(llvm::IRBuilderBase &B, std::string Name,
const SmallVectorImpl<Value *> &ArgsIn,
llvm::Type *RetTy) {
llvm::Type *RetTy, Value *LocStr) {
SmallVector<Value *, 5> Args(ArgsIn.begin(), ArgsIn.end());
Args.push_back(B.getInt64(truncation.getTo().exponentWidth));
Args.push_back(B.getInt64(truncation.getTo().significandWidth));
Args.push_back(B.getInt64(truncation.getMode()));
#if LLVM_VERSION_MAJOR <= 14
Args.push_back(B.CreateBitCast(LocStr, NullPtr->getType()));
#else
Args.push_back(LocStr);
#endif

auto FprtFunc = getFPRTFunc(Name, Args, RetTy);
return cast<CallInst>(B.CreateCall(FprtFunc, Args));
}

public:
TruncateUtils(FloatTruncation truncation, Module *M)
: truncation(truncation), M(M), ctx(M->getContext()) {
TruncateUtils(FloatTruncation truncation, Module *M, EnzymeLogic &Logic)
: truncation(truncation), M(M), ctx(M->getContext()), Logic(Logic) {
fromType = truncation.getFromType(ctx);
toType = truncation.getToType(ctx);
if (fromType == toType)
assert(truncation.isToFPRT());
NullPtr = ConstantPointerNull::get(getDefaultAnonymousTapeType(ctx));
}

Type *getFromType() { return fromType; }
Expand All @@ -5103,23 +5115,54 @@ class TruncateUtils {
assert(V->getType() == getFromType());
SmallVector<Value *, 1> Args;
Args.push_back(V);
return createFPRTGeneric(B, "const", Args, getToType());
return createFPRTGeneric(B, "const", Args, getToType(), NullPtr);
}
CallInst *createFPRTNewCall(llvm::IRBuilderBase &B, Value *V) {
assert(V->getType() == getFromType());
SmallVector<Value *, 1> Args;
Args.push_back(V);
return createFPRTGeneric(B, "new", Args, getToType());
return createFPRTGeneric(B, "new", Args, getToType(), NullPtr);
}
CallInst *createFPRTGetCall(llvm::IRBuilderBase &B, Value *V) {
SmallVector<Value *, 1> Args;
Args.push_back(V);
return createFPRTGeneric(B, "get", Args, getToType());
return createFPRTGeneric(B, "get", Args, getToType(), NullPtr);
}
CallInst *createFPRTDeleteCall(llvm::IRBuilderBase &B, Value *V) {
SmallVector<Value *, 1> Args;
Args.push_back(V);
return createFPRTGeneric(B, "delete", Args, B.getVoidTy());
return createFPRTGeneric(B, "delete", Args, B.getVoidTy(), NullPtr);
}
// This will result in a unique string for each location, which means the
// runtime can check whether two operations are the same with a simple pointer
// comparison. However, we need LTO for this to be the case across different
// compilation units.
GlobalValue *getUniquedLocStr(Instruction &I) {
auto M = I.getParent()->getParent()->getParent();

std::string FileName = "unknown";
unsigned LineNo = 0;
unsigned ColNo = 0;

DILocation *DL = I.getDebugLoc();
if (DL) {
FileName = DL->getFilename();
LineNo = DL->getLine();
ColNo = DL->getColumn();
}

auto Key = std::make_tuple(FileName, LineNo, ColNo);
auto It = Logic.UniqDebugLocStrs.find(Key);

if (It != Logic.UniqDebugLocStrs.end())
return It->second;

std::string LocStr =
FileName + ":" + std::to_string(LineNo) + ":" + std::to_string(ColNo);
auto GV = createPrivateGlobalForString(*M, LocStr, true);
Logic.UniqDebugLocStrs[Key] = GV;

return GV;
}
CallInst *createFPRTOpCall(llvm::IRBuilderBase &B, llvm::Instruction &I,
llvm::Type *RetTy,
Expand All @@ -5146,7 +5189,7 @@ class TruncateUtils {
llvm_unreachable("Unexpected instruction for conversion to FPRT");
}
createOriginalFPRTFunc(I, Name, ArgsIn, RetTy);
return createFPRTGeneric(B, Name, ArgsIn, RetTy);
return createFPRTGeneric(B, Name, ArgsIn, RetTy, getUniquedLocStr(I));
}
};

Expand All @@ -5165,36 +5208,42 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
TruncateGenerator(ValueToValueMapTy &originalToNewFn,
FloatTruncation truncation, Function *oldFunc,
Function *newFunc, EnzymeLogic &Logic)
: TruncateUtils(truncation, newFunc->getParent()),
: TruncateUtils(truncation, newFunc->getParent(), Logic),
originalToNewFn(originalToNewFn), truncation(truncation),
oldFunc(oldFunc), newFunc(newFunc), mode(truncation.getMode()),
Logic(Logic), ctx(newFunc->getContext()) {}

void checkHandled(llvm::Instruction &inst) {
// TODO
// if (all_of(inst.getOperandList(),
// [&](Use *use) { return use->get()->getType() == fromType; }))
// todo(inst);
}
void todo(llvm::Instruction &I) {
if (all_of(I.operands(),
[&](Use &U) { return U.get()->getType() != fromType; }))
return;

// TODO
void handleTrunc();
void hendleIntToFloat();
void handleFloatToInt();
switch (mode) {
case TruncMemMode:
EmitFailure("FPEscaping", I.getDebugLoc(), &I, "FP value escapes!");
break;
case TruncOpMode:
case TruncOpFullModuleMode:
EmitWarning(
"UnhandledTrunc", I,
"Operation not handled - it will be executed in the original way.",
I);
break;
default:
llvm_unreachable("Unknown trunc mode");
}
}

void visitInstruction(llvm::Instruction &inst) {
void visitInstruction(llvm::Instruction &I) {
using namespace llvm;

// TODO explicitly handle all instructions rather than using the catch all
// below

switch (inst.getOpcode()) {
switch (I.getOpcode()) {
// #include "InstructionDerivatives.inc"
default:
break;
}

checkHandled(inst);
todo(I);
}

Value *truncate(IRBuilder<> &B, Value *v) {
Expand All @@ -5221,21 +5270,6 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
llvm_unreachable("Unknown trunc mode");
}

void todo(llvm::Instruction &I) {
std::string s;
llvm::raw_string_ostream ss(s);
ss << "cannot handle unknown instruction\n" << I;
if (CustomErrorHandler) {
IRBuilder<> Builder2(getNewFromOriginal(&I));
CustomErrorHandler(ss.str().c_str(), wrap(&I), ErrorType::NoTruncate,
this, nullptr, wrap(&Builder2));
return;
} else {
EmitFailure("NoTruncate", I.getDebugLoc(), &I, ss.str());
return;
}
}

void visitAllocaInst(llvm::AllocaInst &I) { return; }
void visitICmpInst(llvm::ICmpInst &I) { return; }
void visitFCmpInst(llvm::FCmpInst &CI) {
Expand Down Expand Up @@ -5284,10 +5318,28 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
void visitGetElementPtrInst(llvm::GetElementPtrInst &gep) { return; }
void visitPHINode(llvm::PHINode &phi) { return; }
void visitCastInst(llvm::CastInst &CI) {
// TODO Try to follow fps through trunc/exts
switch (mode) {
case TruncMemMode: {
if (CI.getSrcTy() == getFromType() || CI.getDestTy() == getFromType())
todo(CI);
auto newI = getNewFromOriginal(&CI);
auto newSrc = newI->getOperand(0);
if (CI.getSrcTy() == getFromType()) {
IRBuilder<> B(newI);
if (isa<Constant>(newSrc))
return;
newI->setOperand(0, createFPRTGetCall(B, newSrc));
EmitWarning("FPNoFollow", CI, "Will not follow FP through this cast.",
CI);
} else if (CI.getDestTy() == getFromType()) {
IRBuilder<> B(newI->getNextNode());
EmitWarning("FPNoFollow", CI, "Will not follow FP through this cast.",
CI);
auto nres = createFPRTNewCall(B, newI);
nres->takeName(newI);
nres->copyIRFlags(newI);
newI->replaceUsesWithIf(nres,
[&](Use &U) { return U.getUser() != nres; });
}
return;
}
case TruncOpMode:
Expand Down Expand Up @@ -5542,12 +5594,6 @@ class TruncateGenerator : public llvm::InstVisitor<TruncateGenerator>,
}
return;
}
void visitFPTruncInst(FPTruncInst &I) { return; }
void visitFPExtInst(FPExtInst &I) { return; }
void visitFPToUIInst(FPToUIInst &I) { return; }
void visitFPToSIInst(FPToSIInst &I) { return; }
void visitUIToFPInst(UIToFPInst &I) { return; }
void visitSIToFPInst(SIToFPInst &I) { return; }
};

bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v,
Expand All @@ -5559,7 +5605,8 @@ bool EnzymeLogic::CreateTruncateValue(RequestContext context, Value *v,

Value *converted = nullptr;
auto truncation = FloatTruncation(from, to, TruncMemMode);
TruncateUtils TU(truncation, B.GetInsertBlock()->getParent()->getParent());
TruncateUtils TU(truncation, B.GetInsertBlock()->getParent()->getParent(),
*this);
if (isTruncate)
converted = TU.createFPRTNewCall(B, v);
else
Expand Down
6 changes: 6 additions & 0 deletions enzyme/Enzyme/EnzymeLogic.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#define ENZYME_LOGIC_H

#include <algorithm>
#include <map>
#include <set>
#include <utility>

Expand Down Expand Up @@ -404,9 +405,14 @@ struct FloatTruncation {
std::string mangleFrom() const { return from.to_string(); }
};

typedef std::map<std::tuple<std::string, unsigned, unsigned>,
llvm::GlobalValue *>
UniqDebugLocStrsTy;

class EnzymeLogic {
public:
PreProcessCache PPC;
UniqDebugLocStrsTy UniqDebugLocStrs;

/// \p PostOpt is whether to perform basic
/// optimization of the function after synthesis
Expand Down
Loading
Loading