diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 6ce0ab525..b3beb8163 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -746,7 +746,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } StmtDiff ReverseModeVisitor::VisitCompoundStmt(const CompoundStmt* CS) { - beginScope(Scope::DeclScope); + int scopeFlags = Scope::DeclScope; + // If this is the outermost compound statement of the function, + // propagate the function scope. + if (getCurrentScope()==m_DerivativeFnScope) + scopeFlags |= Scope::FnScope; + beginScope(scopeFlags); beginBlock(direction::forward); beginBlock(direction::reverse); for (Stmt* S : CS->body()) { @@ -2519,12 +2524,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, VarDeclDiff ReverseModeVisitor::DifferentiateVarDecl(const VarDecl* VD) { StmtDiff initDiff; Expr* VDDerivedInit = nullptr; - // We take the parent of the current scope because the main compound - // statement of the function has its own scope as well. - bool isInFunctionGlobalScope = getCurrentScope()->getParent()==m_DerivativeFnScope; + // reverse_mode_forward_pass does not have a reverse pass so declarations + // don't have to be moved to the function global scope. + bool moveToFunctionScope = getCurrentScope()->isFunctionScope() || m_Mode == DiffMode::reverse_mode_forward_pass; auto VDDerivedType = ComputeAdjointType(VD->getType()); auto VDCloneType = CloneType(VD->getType()); - if (!isInFunctionGlobalScope) + if (!moveToFunctionScope) VDCloneType = VDDerivedType; bool isDerivativeOfRefType = VD->getType()->isReferenceType(); VarDecl* VDDerived = nullptr; @@ -2537,7 +2542,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, VDDerived = BuildGlobalVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(), VDDerivedInit, false, nullptr, clang::VarDecl::InitializationStyle::CallInit); - if (!isInFunctionGlobalScope) + if (!moveToFunctionScope) initDiff = VDDerivedInit; } else { // If VD is a reference to a local variable, then the initial value is set @@ -2662,7 +2667,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, derivedVDE = BuildOp(UnaryOperatorKind::UO_Deref, derivedVDE); } - if (isDerivativeOfRefType && !isInFunctionGlobalScope) + if (isDerivativeOfRefType && !moveToFunctionScope) VDClone = BuildGlobalVarDecl(VDCloneType, VD->getNameAsString(), BuildOp(UnaryOperatorKind::UO_AddrOf, initDiff.getExpr()), VD->isDirectInit()); @@ -2728,9 +2733,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, llvm::SmallVector declsDiff; // Need to put array decls inlined. llvm::SmallVector localDeclsDiff; - // We take the parent of the current scope because the main compound - // statement of the function has its own scope as well. - bool isInFunctionGlobalScope = getCurrentScope()->getParent()==m_DerivativeFnScope; + // reverse_mode_forward_pass does not have a reverse pass so declarations + // don't have to be moved to the function global scope. + bool moveToFunctionScope = getCurrentScope()->isFunctionScope() || m_Mode == DiffMode::reverse_mode_forward_pass; // For each variable declaration v, create another declaration _d_v to // store derivatives for potential reassignments. E.g. // double y = x; @@ -2763,7 +2768,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // This part in necessary to replace local variables inside loops // with function globals and replace initializations with - if (!isInFunctionGlobalScope) { + if (!moveToFunctionScope) { auto* decl = VDDiff.getDecl(); // The same variable will be assigned with new values every // loop iteration so the const qualifier must be dropped. @@ -2836,7 +2841,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // This part in necessary to replace local variables inside loops // with function globals and replace initializations with assignments. - if (!isInFunctionGlobalScope) { + if (!moveToFunctionScope) { addToBlock(DSClone, m_Globals); Stmt* initAssignments = MakeCompoundStmt(inits); initAssignments = unwrapIfSingleStmt(initAssignments);