Skip to content

Commit

Permalink
Fix calling conv attribution mismatch (#1296)
Browse files Browse the repository at this point in the history
* Fix calling conv attribution mismatch

* Fix missing store bug
  • Loading branch information
wsmoses authored Jun 23, 2023
1 parent 11a9451 commit 3c0014a
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 83 deletions.
164 changes: 88 additions & 76 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -719,95 +719,106 @@ void calculateUnusedValuesInFunction(
}
}

std::function<bool(const llvm::Value *)> isNoNeed =
[&](const llvm::Value *v) {
auto Obj = getBaseObject(v);
if (Obj != v)
return isNoNeed(Obj);
if (auto C = dyn_cast<LoadInst>(v))
return isNoNeed(C->getOperand(0));
else if (auto arg = dyn_cast<Argument>(v)) {
auto act = constant_args[arg->getArgNo()];
if (act == DIFFE_TYPE::DUP_NONEED) {
return true;
}
} else if (isa<AllocaInst>(v) || isAllocationCall(v, TLI)) {
if (!gutils->isConstantValue(const_cast<Value *>(v))) {
std::set<const Value *> done;
std::deque<const Value *> todo = {v};
bool legal = true;
while (todo.size()) {
const Value *cur = todo.back();
todo.pop_back();
if (done.count(cur))
continue;
done.insert(cur);
std::function<bool(const llvm::Value *)> isNoNeed = [&](const llvm::Value
*v) {
auto Obj = getBaseObject(v);
if (Obj != v)
return isNoNeed(Obj);
if (auto C = dyn_cast<LoadInst>(v))
return isNoNeed(C->getOperand(0));
else if (auto arg = dyn_cast<Argument>(v)) {
auto act = constant_args[arg->getArgNo()];
if (act == DIFFE_TYPE::DUP_NONEED) {
return true;
}
} else if (isa<AllocaInst>(v) || isAllocationCall(v, TLI)) {
if (!gutils->isConstantValue(const_cast<Value *>(v))) {
std::set<const Value *> done;
std::deque<const Value *> todo = {v};
bool legal = true;
while (todo.size()) {
const Value *cur = todo.back();
todo.pop_back();
if (done.count(cur))
continue;
done.insert(cur);

if (unnecessaryValues.count(cur))
continue;
if (unnecessaryValues.count(cur))
continue;

for (auto u : cur->users()) {
if (auto SI = dyn_cast<StoreInst>(u)) {
if (SI->getValueOperand() != cur)
continue;
}
if (auto I = dyn_cast<Instruction>(u)) {
if (unnecessaryInstructions.count(I))
continue;
if (isDeallocationCall(I, TLI))
continue;
for (auto u : cur->users()) {
if (auto SI = dyn_cast<StoreInst>(u)) {
if (SI->getValueOperand() != cur) {
continue;
}
}
if (auto I = dyn_cast<Instruction>(u)) {
if (unnecessaryInstructions.count(I)) {
if (!DifferentialUseAnalysis::is_use_directly_needed_in_reverse(
gutils, cur, I, oldUnreachable)) {
continue;
}
if (auto II = dyn_cast<IntrinsicInst>(u);
II && isIntelSubscriptIntrinsic(*II)) {
todo.push_back(&*u);
} else if (auto CI = dyn_cast<CallInst>(u)) {
bool writeOnlyNoCapture = true;
if (shouldDisableNoWrite(CI)) {
writeOnlyNoCapture = false;
}
}
if (isDeallocationCall(I, TLI)) {
continue;
}
}
if (auto II = dyn_cast<IntrinsicInst>(u);
II && isIntelSubscriptIntrinsic(*II)) {
todo.push_back(&*u);
continue;
} else if (auto CI = dyn_cast<CallInst>(u)) {
if (getFuncNameFromCall(CI) == "julia.write_barrier") {
continue;
}
bool writeOnlyNoCapture = true;
if (shouldDisableNoWrite(CI)) {
writeOnlyNoCapture = false;
}
#if LLVM_VERSION_MAJOR >= 14
for (size_t i = 0; i < CI->arg_size(); i++)
for (size_t i = 0; i < CI->arg_size(); i++)
#else
for (size_t i = 0; i < CI->getNumArgOperands(); i++)
for (size_t i = 0; i < CI->getNumArgOperands(); i++)
#endif
{
if (cur == CI->getArgOperand(i)) {
if (!isNoCapture(CI, i)) {
writeOnlyNoCapture = false;
break;
}
if (!isWriteOnly(CI, i)) {
writeOnlyNoCapture = false;
break;
}
}
{
if (cur == CI->getArgOperand(i)) {
if (!isNoCapture(CI, i)) {
writeOnlyNoCapture = false;
break;
}
// Don't need the primal argument if it is write only and
// not captured
if (writeOnlyNoCapture) {
continue;
if (!isWriteOnly(CI, i)) {
writeOnlyNoCapture = false;
break;
}
}
if (isa<CastInst>(u) || isa<GetElementPtrInst>(u) ||
isa<PHINode>(u)) {
todo.push_back(&*u);
} else {
legal = false;
break;
}
}
// Don't need the primal argument if it is write only and
// not captured
if (writeOnlyNoCapture) {
continue;
}
}
if (legal) {
return true;
if (isa<CastInst>(u) || isa<GetElementPtrInst>(u) ||
isa<PHINode>(u)) {
todo.push_back(&*u);
continue;
} else {
legal = false;
break;
}
}
} else if (auto II = dyn_cast<IntrinsicInst>(v);
II && isIntelSubscriptIntrinsic(*II)) {
unsigned int ptrArgIdx = 3;
return isNoNeed(II->getOperand(ptrArgIdx));
}
return false;
};
if (legal) {
return true;
}
}
} else if (auto II = dyn_cast<IntrinsicInst>(v);
II && isIntelSubscriptIntrinsic(*II)) {
unsigned int ptrArgIdx = 3;
return isNoNeed(II->getOperand(ptrArgIdx));
}
return false;
};

calculateUnusedValues(
func, unnecessaryValues, unnecessaryInstructions, returnValue,
Expand Down Expand Up @@ -935,9 +946,10 @@ void calculateUnusedValuesInFunction(
}

if (auto si = dyn_cast<StoreInst>(inst)) {
bool nnop = isNoNeed(si->getPointerOperand());
if (isa<UndefValue>(si->getValueOperand()))
return UseReq::Recur;
if (isNoNeed(si->getPointerOperand()))
if (nnop)
return UseReq::Recur;
}

Expand Down
28 changes: 21 additions & 7 deletions enzyme/Enzyme/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -1419,8 +1419,13 @@ static inline bool isReadOnly(const llvm::CallInst *call, ssize_t arg = -1) {
return true;

if (auto F = getFunctionFromCall(call)) {
if (isReadOnly(F, arg))
return true;
// Do not use function attrs for if different calling conv, such as a julia
// call wrapping args into an array. This is because the wrapped array
// may be nocapure/readonly, but the actual arg (which will be put in the
// array) may not be.
if (F->getCallingConv() == call->getCallingConv())
if (isReadOnly(F, arg))
return true;
}
return false;
}
Expand Down Expand Up @@ -1459,7 +1464,12 @@ static inline bool isWriteOnly(const llvm::CallInst *call, ssize_t arg = -1) {
#endif

if (auto F = getFunctionFromCall(call)) {
return isWriteOnly(F, arg);
// Do not use function attrs for if different calling conv, such as a julia
// call wrapping args into an array. This is because the wrapped array
// may be nocapure/readonly, but the actual arg (which will be put in the
// array) may not be.
if (F->getCallingConv() == call->getCallingConv())
return isWriteOnly(F, arg);
}
return false;
}
Expand All @@ -1476,10 +1486,14 @@ static inline bool isNoCapture(const llvm::CallInst *call, size_t idx) {
if (call->doesNotCapture(idx))
return true;

auto F = getFunctionFromCall(call);
if (F) {
if (F->hasParamAttribute(idx, llvm::Attribute::NoCapture))
return true;
if (auto F = getFunctionFromCall(call)) {
// Do not use function attrs for if different calling conv, such as a julia
// call wrapping args into an array. This is because the wrapped array
// may be nocapure/readonly, but the actual arg (which will be put in the
// array) may not be.
if (F->getCallingConv() == call->getCallingConv())
if (F->hasParamAttribute(idx, llvm::Attribute::NoCapture))
return true;
// if (F->getAttributes().hasParamAttribute(idx, "enzyme_NoCapture"))
// return true;
}
Expand Down
56 changes: 56 additions & 0 deletions enzyme/test/Enzyme/ReverseMode/missingstore.ll
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
; RUN: if [ %llvmver -lt 16 ]; then %opt < %s %loadEnzyme -enzyme-preopt=false -enzyme -mem2reg -instsimplify -simplifycfg -S | FileCheck %s; fi
; RUN: %opt < %s %newLoadEnzyme -enzyme-preopt=false -passes="enzyme,function(mem2reg,instsimplify,%simplifycfg)" -S | FileCheck %s

source_filename = "text"
target datalayout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128-ni:10:11:12:13"
target triple = "x86_64-pc-linux-gnu"

declare i8* @__enzyme_virtualreverse(...)

define i8* @dsquare(double %x) {
entry:
%i = tail call i8* (...) @__enzyme_virtualreverse({} addrspace(10)* ({ [1 x {} addrspace(10)*] } addrspace(11)*)* nonnull @julia__foldl_impl_3869)
ret i8* %i
}

define internal fastcc nonnull {} addrspace(10)* @julia__foldl_impl_3869({ [1 x {} addrspace(10)*] } addrspace(11)* nocapture nonnull readonly align 8 dereferenceable(8) %arg) {
top:
%i11 = call noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj(i8* null, i64 8)
%i12 = bitcast { [1 x {} addrspace(10)*] } addrspace(11)* %arg to i64 addrspace(11)*
%i13 = bitcast {} addrspace(10)* %i11 to i64 addrspace(10)*
%i14 = load i64, i64 addrspace(11)* %i12, align 8
store i64 %i14, i64 addrspace(10)* %i13, align 8
call void @jl_invoke({} addrspace(10)* nonnull %i11)
ret {} addrspace(10)* null
}

; Function Attrs: nofree
define void @jl_invoke({} addrspace(10)* nocapture readonly %y) {
bb:
ret void
}

; Function Attrs: inaccessiblememonly allocsize(1)
declare noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj(i8*, i64)

; CHECK: define internal fastcc void @diffejulia__foldl_impl_3869({ [1 x {} addrspace(10)*] } addrspace(11)* nocapture readonly align 8 dereferenceable(8) %arg, { [1 x {} addrspace(10)*] } addrspace(11)* nocapture align 8 %"arg'", i8* %tapeArg)
; CHECK-NEXT: top:
; CHECK-NEXT: %0 = bitcast i8* %tapeArg to { i64, i64 }*
; CHECK-NEXT: %truetape = load { i64, i64 }, { i64, i64 }* %0
; CHECK-NEXT: tail call void @free(i8* nonnull %tapeArg)
; CHECK-NEXT: %i11 = call noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj(i8* null, i64 8)
; CHECK-NEXT: %"i11'mi" = call noalias nonnull {} addrspace(10)* @julia.gc_alloc_obj(i8* null, i64 8)
; CHECK-NEXT: %1 = bitcast {} addrspace(10)* %"i11'mi" to i8 addrspace(10)*
; CHECK-NEXT: call void @llvm.memset.p10i8.i64(i8 addrspace(10)* nonnull dereferenceable(8) dereferenceable_or_null(8) %1, i8 0, i64 8, i1 false)
; CHECK-NEXT: %"i13'ipc" = bitcast {} addrspace(10)* %"i11'mi" to i64 addrspace(10)*
; CHECK-NEXT: %i13 = bitcast {} addrspace(10)* %i11 to i64 addrspace(10)*
; CHECK-NEXT: %"i14'il_phi" = extractvalue { i64, i64 } %truetape, 0
; CHECK-NEXT: %i14 = extractvalue { i64, i64 } %truetape, 1
; CHECK-NEXT: store i64 %"i14'il_phi", i64 addrspace(10)* %"i13'ipc", align 8
; THE CRITICAL PART OF THIS TEST IS ENSURING THIS STORE EXISTS
; CHECK-NEXT: store i64 %i14, i64 addrspace(10)* %i13, align 8
; CHECK-NEXT: call void @diffejl_invoke({} addrspace(10)* %i11, {} addrspace(10)* %"i11'mi")
; CHECK-NEXT: ret void
; CHECK-NEXT: }


0 comments on commit 3c0014a

Please sign in to comment.