Skip to content

Commit

Permalink
Improve cache index error message (#2138)
Browse files Browse the repository at this point in the history
* Improve cache index error message

* fix

* fix

* fix
  • Loading branch information
wsmoses authored Oct 29, 2024
1 parent fffdd2c commit e9e4ef3
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 8 deletions.
31 changes: 23 additions & 8 deletions enzyme/Enzyme/AdjointGenerator.h
Original file line number Diff line number Diff line change
Expand Up @@ -5493,14 +5493,29 @@ class AdjointGenerator : public llvm::InstVisitor<AdjointGenerator> {
It.setHeadBit(true);
BuilderZ.SetInsertPoint(It);
#endif
if (Mode == DerivativeMode::ReverseModeCombined)
cachereplace = newCall;
else
cachereplace = BuilderZ.CreatePHI(call.getType(), 1,
call.getName() + "_tmpcacheB");
cachereplace = gutils->cacheForReverse(
BuilderZ, cachereplace,
getIndex(&call, CacheType::Self, BuilderZ));
auto idx = getIndex(&call, CacheType::Self, BuilderZ);
if (idx == IndexMappingError) {
std::string str;
raw_string_ostream ss(str);
ss << "Failed to compute consistent cache index for operation: "
<< call << "\n";
if (CustomErrorHandler) {
CustomErrorHandler(str.c_str(), wrap(&call),
ErrorType::InternalError, nullptr, nullptr,
nullptr);
} else {
EmitFailure("GetIndexError", call.getDebugLoc(), &call,
ss.str());
}
} else {
if (Mode == DerivativeMode::ReverseModeCombined)
cachereplace = newCall;
else
cachereplace = BuilderZ.CreatePHI(
call.getType(), 1, call.getName() + "_tmpcacheB");
cachereplace =
gutils->cacheForReverse(BuilderZ, cachereplace, idx);
}
} else {
#if LLVM_VERSION_MAJOR >= 18
auto It = BuilderZ.GetInsertPoint();
Expand Down
1 change: 1 addition & 0 deletions enzyme/Enzyme/EnzymeLogic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1947,6 +1947,7 @@ const AugmentedReturn &EnzymeLogic::CreateAugmentedPrimal(
if (EmitNoDerivativeError(ss.str(), todiff, context)) {
auto newFunc = todiff;
std::map<AugmentedStruct, int> returnMapping;
returnMapping[AugmentedStruct::Return] = -1;
return insert_or_assign<AugmentedCacheKey, AugmentedReturn>(
AugmentedCachedFunctions, tup,
AugmentedReturn(newFunc, nullptr, {}, returnMapping, {}, {},
Expand Down

0 comments on commit e9e4ef3

Please sign in to comment.