Skip to content

Commit

Permalink
Propagate FnScope to the first compound statement and detect the func…
Browse files Browse the repository at this point in the history
…tion global scope based on getCurrentScope.
  • Loading branch information
PetroZarytskyi committed Feb 6, 2024
1 parent 87331f4 commit bd0bb57
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -746,7 +746,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

StmtDiff ReverseModeVisitor::VisitCompoundStmt(const CompoundStmt* CS) {
beginScope(Scope::DeclScope);
int scopeFlags = Scope::DeclScope;
// If this is the outermost compound statement of the function,
// propagate the function scope.
if (getCurrentScope()==m_DerivativeFnScope)
scopeFlags |= Scope::FnScope;
beginScope(scopeFlags);
beginBlock(direction::forward);
beginBlock(direction::reverse);
for (Stmt* S : CS->body()) {
Expand Down Expand Up @@ -2519,12 +2524,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VarDeclDiff ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
StmtDiff initDiff;
Expr* VDDerivedInit = nullptr;
// We take the parent of the current scope because the main compound
// statement of the function has its own scope as well.
bool isInFunctionGlobalScope = getCurrentScope()->getParent()==m_DerivativeFnScope;
// reverse_mode_forward_pass does not have a reverse pass so declarations
// don't have to be moved to the function global scope.
bool moveToFunctionScope = getCurrentScope()->isFunctionScope() || m_Mode == DiffMode::reverse_mode_forward_pass;
auto VDDerivedType = ComputeAdjointType(VD->getType());
auto VDCloneType = CloneType(VD->getType());
if (!isInFunctionGlobalScope)
if (!moveToFunctionScope)
VDCloneType = VDDerivedType;
bool isDerivativeOfRefType = VD->getType()->isReferenceType();
VarDecl* VDDerived = nullptr;
Expand All @@ -2537,7 +2542,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VDDerived = BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(),
VDDerivedInit, false, nullptr,
clang::VarDecl::InitializationStyle::CallInit);
if (!isInFunctionGlobalScope)
if (!moveToFunctionScope)
initDiff = VDDerivedInit;
} else {
// If VD is a reference to a local variable, then the initial value is set
Expand Down Expand Up @@ -2662,7 +2667,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
derivedVDE = BuildOp(UnaryOperatorKind::UO_Deref, derivedVDE);
}

if (isDerivativeOfRefType && !isInFunctionGlobalScope)
if (isDerivativeOfRefType && !moveToFunctionScope)
VDClone = BuildGlobalVarDecl(VDCloneType, VD->getNameAsString(),
BuildOp(UnaryOperatorKind::UO_AddrOf,
initDiff.getExpr()), VD->isDirectInit());
Expand Down Expand Up @@ -2728,9 +2733,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
llvm::SmallVector<Decl*, 4> declsDiff;
// Need to put array decls inlined.
llvm::SmallVector<Decl*, 4> localDeclsDiff;
// We take the parent of the current scope because the main compound
// statement of the function has its own scope as well.
bool isInFunctionGlobalScope = getCurrentScope()->getParent()==m_DerivativeFnScope;
// reverse_mode_forward_pass does not have a reverse pass so declarations
// don't have to be moved to the function global scope.
bool moveToFunctionScope = getCurrentScope()->isFunctionScope() || m_Mode == DiffMode::reverse_mode_forward_pass;
// For each variable declaration v, create another declaration _d_v to
// store derivatives for potential reassignments. E.g.
// double y = x;
Expand Down Expand Up @@ -2763,7 +2768,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// This part in necessary to replace local variables inside loops
// with function globals and replace initializations with
if (!isInFunctionGlobalScope) {
if (!moveToFunctionScope) {
auto* decl = VDDiff.getDecl();
// The same variable will be assigned with new values every
// loop iteration so the const qualifier must be dropped.
Expand Down Expand Up @@ -2836,7 +2841,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// This part in necessary to replace local variables inside loops
// with function globals and replace initializations with assignments.
if (!isInFunctionGlobalScope) {
if (!moveToFunctionScope) {
addToBlock(DSClone, m_Globals);
Stmt* initAssignments = MakeCompoundStmt(inits);
initAssignments = unwrapIfSingleStmt(initAssignments);
Expand Down

0 comments on commit bd0bb57

Please sign in to comment.