Skip to content

Commit

Permalink
Correctly handle ref-type and array-type decls.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Jan 31, 2024
1 parent bec7856 commit 2a399a0
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 31 deletions.
67 changes: 38 additions & 29 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1296,7 +1296,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// to the underlying struct of a lambda.
if (VD->getDeclContext() != m_Sema.CurContext)
clonedDRE = cast<DeclRefExpr>(BuildDeclRef(VD));
if (DRE->getDecl()->getType()->isReferenceType())
if (DRE->getDecl()->getType()->isReferenceType()
&& clonedDRE->getType()->isPointerType())
clonedDRE = BuildOp(UO_Deref, clonedDRE);
if (isVectorValued) {
if (m_VectorOutput.size() <= outputArrayCursor)
Expand Down Expand Up @@ -2515,7 +2516,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VarDeclDiff ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
StmtDiff initDiff;
Expr* VDDerivedInit = nullptr;
// FIXME: find a more reliable way to determine if the declaration
// is in the function global scope.
bool isInFunctionGlobalScope = m_Reverse.size() <= 2;
auto VDDerivedType = ComputeAdjointType(VD->getType());
auto VDCloneType = CloneType(VD->getType());
if (!isInFunctionGlobalScope)
VDCloneType = VDDerivedType;
bool isDerivativeOfRefType = VD->getType()->isReferenceType();
VarDecl* VDDerived = nullptr;
bool isPointerType = VD->getType()->isPointerType();
Expand All @@ -2527,6 +2534,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VDDerived = BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(),
VDDerivedInit, false, nullptr,
clang::VarDecl::InitializationStyle::CallInit);
if (!isInFunctionGlobalScope)
initDiff = VDDerivedInit;
} else {
// If VD is a reference to a local variable, then the initial value is set
Expand Down Expand Up @@ -2596,7 +2604,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VDDerivedInit = getZeroInit(VDDerivedType);
}
VDDerived = BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(),
VDDerivedInit);
VDDerivedInit, false, nullptr, VD->getInitStyle());
}

// If `VD` is a reference to a local variable, then it is already
Expand Down Expand Up @@ -2638,24 +2646,26 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// FIXME: Add extra parantheses if derived variable pointer is pointing to a
// class type object.
if (isDerivativeOfRefType) {
VDClone = BuildGlobalVarDecl(VDDerivedType, VD->getNameAsString(),
BuildOp(UnaryOperatorKind::UO_AddrOf,
initDiff.getExpr()), VD->isDirectInit());
Expr* assignDerivativeE =
BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE,
BuildOp(UnaryOperatorKind::UO_AddrOf,
initDiff.getForwSweepExpr_dx()));
BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE,
BuildOp(UnaryOperatorKind::UO_AddrOf,
initDiff.getForwSweepExpr_dx()));
addToCurrentBlock(assignDerivativeE);
if (isInsideLoop) {
StmtDiff pushPop = StoreAndRestore(derivedVDE, /*prefix=*/"_t", /*force=*/true);
addToCurrentBlock(pushPop.getExpr(), direction::forward);
m_LoopBlock.back().push_back(pushPop.getExpr_dx());
}
derivedVDE = BuildOp(UnaryOperatorKind::UO_Deref, derivedVDE);
} else {
VDClone = BuildGlobalVarDecl(VDDerivedType, VD->getNameAsString(),
initDiff.getExpr(), VD->isDirectInit(), nullptr, VarDecl::InitializationStyle::CallInit);
}

if (isDerivativeOfRefType && !isInFunctionGlobalScope)
VDClone = BuildGlobalVarDecl(VDCloneType, VD->getNameAsString(),
BuildOp(UnaryOperatorKind::UO_AddrOf,
initDiff.getExpr()), VD->isDirectInit());
else
VDClone = BuildGlobalVarDecl(VDCloneType, VD->getNameAsString(),
initDiff.getExpr(), VD->isDirectInit(), nullptr, VD->getInitStyle());
if (isPointerType) {
Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign,
derivedVDE, initDiff.getExpr_dx());
Expand Down Expand Up @@ -2715,6 +2725,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
llvm::SmallVector<Decl*, 4> declsDiff;
// Need to put array decls inlined.
llvm::SmallVector<Decl*, 4> localDeclsDiff;
// FIXME: find a more reliable way to determine if the declaration
// is in the function global scope.
bool isInFunctionGlobalScope = m_Reverse.size() <= 2;
// For each variable declaration v, create another declaration _d_v to
// store derivatives for potential reassignments. E.g.
// double y = x;
Expand Down Expand Up @@ -2742,17 +2755,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// double _d_y = x; // copied from original funcion, collides with
// _d_y
// }
if (VDDiff.getDecl()->getDeclName() != VD->getDeclName() || VD->getType()->isReferenceType())
if (VDDiff.getDecl()->getDeclName() != VD->getDeclName() || VD->getType()!=VDDiff.getDecl()->getType())
m_DeclReplacements[VD] = VDDiff.getDecl();

// FIXME: This part in necessary to replace local variables inside loops
// This part in necessary to replace local variables inside loops
// with function globals and replace initializations with
// assignments. This is a temporary measure to avoid the bug that arises
// from overwriting local variables on different loop passes.
if (m_Reverse.size() > 2) {
if (!isInFunctionGlobalScope) {
auto* decl = VDDiff.getDecl();
/// The same variable will be assigned with new values every
/// loop iteration so the const qualifier must be dropped.
// The same variable will be assigned with new values every
// loop iteration so the const qualifier must be dropped.
if (decl->getType().isConstQualified()) {
QualType nonConstType =
getNonConstType(decl->getType(), m_Context, m_Sema);
Expand All @@ -2766,13 +2777,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse);
auto* assignment = BuildOp(BO_Assign, declRef, decl->getInit());
inits.push_back(BuildOp(BO_Comma, pushPop.getExpr(), assignment));
if (isa<ArrayType>(VD->getType()))
decl->setInit(nullptr);
if (isa<ArrayType>(VD->getType())) {
decl->setInitStyle(VarDecl::InitializationStyle::CallInit);
decl->setInit(Clone(VDDiff.getDecl_dx()->getInit()));
} else {
decl->setInit(getZeroInit(VD->getType()));
}
}
if (isa<ArrayType>(VD->getType()))
decl->setInitStyle(VarDecl::InitializationStyle::CallInit);
else
decl->setInit(getZeroInit(VD->getType()));
}

decls.push_back(VDDiff.getDecl());
Expand Down Expand Up @@ -2805,11 +2816,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
m_ExternalSource->ActBeforeFinalizingVisitDeclStmt(decls, declsDiff);
}

/// FIXME: This part in necessary to replace local variables inside loops
/// with function globals and replace initializations with assignments.
/// This is a temporary measure to avoid the bug that arises from
/// overwriting local variables on different loop passes.
if (m_Reverse.size() > 2) {
// This part in necessary to replace local variables inside loops
// with function globals and replace initializations with assignments.
if (!isInFunctionGlobalScope) {
addToBlock(DSClone, m_Globals);
Stmt* initAssignments = MakeCompoundStmt(inits);
initAssignments = unwrapIfSingleStmt(initAssignments);
Expand Down
5 changes: 3 additions & 2 deletions lib/Differentiator/StmtClone.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -522,8 +522,9 @@ bool ReferencesUpdater::VisitDeclRefExpr(DeclRefExpr* DRE) {
auto it = m_DeclReplacements.find(VD);
if (it != std::end(m_DeclReplacements)) {
DRE->setDecl(it->second);
if (it->second->getType()!=DRE->getType())
DRE->setType(it->second->getType());
if (it->second->getType().getNonReferenceType()!=DRE->getType()) {
DRE->setType(it->second->getType().getNonReferenceType());
}
}
}

Expand Down

0 comments on commit 2a399a0

Please sign in to comment.