diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 488eaf6ba..9164ed983 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -994,30 +994,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, auto* BeginDeclRef = cast(BeginExpr); Expr* d_BeginDeclRef = m_Variables[BeginDeclRef->getDecl()]; - auto* RangeExpr = - cast(cast(VisitRange.getStmt())->getLHS()); - - Expr* RangeInit = Clone(FRS->getRangeInit()); - Expr* AssignRange = - BuildOp(BO_Assign, RangeExpr, BuildOp(UO_AddrOf, RangeInit)); - Expr* AssignBegin = - BuildOp(BO_Assign, BeginDeclRef, BuildOp(UO_Deref, RangeExpr)); - addToCurrentBlock(AssignRange); - addToCurrentBlock(AssignBegin); + addToCurrentBlock(VisitRange.getStmt()); + addToCurrentBlock(VisitBegin.getStmt()); const auto* EndDecl = cast(FRS->getEndStmt()->getSingleDecl()); - Expr* EndInit = cast(EndDecl->getInit())->getRHS(); QualType EndType = CloneType(EndDecl->getType()); std::string EndName = EndDecl->getNameAsString(); - Expr* EndAssign = BuildOp(BO_Add, BuildOp(UO_Deref, RangeExpr), EndInit); + Expr* EndInit = Visit(EndDecl->getInit()).getExpr(); VarDecl* EndVarDecl = - BuildGlobalVarDecl(EndType, EndName, EndAssign, /*DirectInit=*/false); - DeclStmt* AssignEnd = BuildDeclStmt(EndVarDecl); - - addToCurrentBlock(AssignEnd); - auto* AssignEndVarDecl = - cast(cast(AssignEnd)->getSingleDecl()); - DeclRefExpr* EndExpr = BuildDeclRef(AssignEndVarDecl); + BuildGlobalVarDecl(EndType, EndName, EndInit, /*DirectInit=*/false); + addToCurrentBlock(BuildDeclStmt(EndVarDecl)); + DeclRefExpr* EndExpr = BuildDeclRef(EndVarDecl); Expr* IncBegin = BuildOp(UO_PreInc, BeginDeclRef); beginBlock(direction::forward); @@ -1036,14 +1023,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Expr* ForwardCond = BuildOp(BO_NE, BeginDeclRef, EndExpr); // Add item assignment statement to the body. const Stmt* body = FRS->getBody(); - StmtDiff bodyDiff = Visit(body); + StmtDiff bodyDiff = + DifferentiateLoopBody(body, loopCounter, nullptr, nullptr, + /*isForLoop=*/true); StmtDiff storeLoop = StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl())); StmtDiff storeAdjLoop = StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl_dx())); addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl_dx())); - Expr* CounterIncrement = loopCounter.getCounterIncrement(); Expr* LoopInit = LoopVDDiff.getDecl()->getInit(); LoopVDDiff.getDecl()->setInit(getZeroInit(LoopVDDiff.getDecl()->getType())); @@ -1058,7 +1046,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } beginBlock(direction::forward); - addToCurrentBlock(CounterIncrement); addToCurrentBlock(AdjLoopVDAddAssign); addToCurrentBlock(AssignLoop); addToCurrentBlock(storeLoop.getStmt()); diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index d648a1149..2d5d957b1 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -2717,12 +2717,12 @@ double fn34(double x, double y){ //CHECK-NEXT: double *i = 0; //CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { //CHECK-NEXT: { -//CHECK-NEXT: _t0++; //CHECK-NEXT: _d_i = &*_d___begin1; //CHECK-NEXT: i = &*__begin10; //CHECK-NEXT: clad::push(_t2, i); //CHECK-NEXT: clad::push(_t3, _d_i); //CHECK-NEXT: } +//CHECK-NEXT: _t0++; //CHECK-NEXT: clad::push(_t1, r); //CHECK-NEXT: r += *i; //CHECK-NEXT: } @@ -2784,12 +2784,12 @@ double fn35(double x, double y){ //CHECK-NEXT: double i = 0; //CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { //CHECK-NEXT: { -//CHECK-NEXT: _t0++; //CHECK-NEXT: _d_i = *_d___begin1; //CHECK-NEXT: i = *__begin10; //CHECK-NEXT: clad::push(_t3, i); //CHECK-NEXT: clad::push(_t4, _d_i); //CHECK-NEXT: } +//CHECK-NEXT: _t0++; //CHECK-NEXT: clad::push(_t1, sum); //CHECK-NEXT: clad::push(_t2, sin(i)); //CHECK-NEXT: sum += clad::back(_t2) * x; @@ -2816,6 +2816,65 @@ double fn35(double x, double y){ //CHECK-NEXT: } //CHECK-NEXT: } +double fn36(double x, double y) { + double range[] = {x, 4., y}; + double sum = 0; + for (auto elem : range) + sum += elem; + return sum; +} + +//CHECK: void fn36_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: unsigned {{int|long}} _t0; +//CHECK-NEXT: double (*_d___range1)[3] = 0; +//CHECK-NEXT: double (*__range10)[3] = {}; +//CHECK-NEXT: double *_d___begin1 = 0; +//CHECK-NEXT: double *__begin10 = 0; +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: clad::tape _t2 = {}; +//CHECK-NEXT: clad::tape _t3 = {}; +//CHECK-NEXT: double _d_range[3] = {0}; +//CHECK-NEXT: double range[3] = {x, 4., y}; +//CHECK-NEXT: double _d_sum = 0; +//CHECK-NEXT: double sum = 0; +//CHECK-NEXT: _t0 = {{0U|0UL}}; +//CHECK-NEXT: _d___range1 = &_d_range; +//CHECK-NEXT: _d___begin1 = *_d___range1; +//CHECK-NEXT: __range10 = ⦥ +//CHECK-NEXT: __begin10 = *__range10; +//CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}}; +//CHECK-NEXT: double _d_elem = 0; +//CHECK-NEXT: double elem = 0; +//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { +//CHECK-NEXT: { +//CHECK-NEXT: _d_elem = *_d___begin1; +//CHECK-NEXT: elem = *__begin10; +//CHECK-NEXT: clad::push(_t2, elem); +//CHECK-NEXT: clad::push(_t3, _d_elem); +//CHECK-NEXT: } +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, sum); +//CHECK-NEXT: sum += elem; +//CHECK-NEXT: } +//CHECK-NEXT: _d_sum += 1; +//CHECK-NEXT: for (; _t0; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: { +//CHECK-NEXT: _d___begin1--; +//CHECK-NEXT: elem = clad::pop(_t2); +//CHECK-NEXT: _d_elem = clad::pop(_t3); +//CHECK-NEXT: } +//CHECK-NEXT: sum = clad::pop(_t1); +//CHECK-NEXT: double _r_d0 = _d_sum; +//CHECK-NEXT: _d_elem += _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: *_d___begin1 += _d_elem; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: *_d_x += _d_range[0]; +//CHECK-NEXT: *_d_y += _d_range[2]; +//CHECK-NEXT: } +//CHECK-NEXT: } #define TEST(F, x) { \ result[0] = 0; \ @@ -2903,6 +2962,7 @@ int main() { TEST_2(fn34, 5, 2); // CHECK-EXEC: {12.00, 7.00} TEST_2(fn35, 1, 1); // CHECK-EXEC: {1.89, 0.00} + TEST_2(fn36, 6, 3); // CHECK-EXEC: {1.00, 1.00} } //CHECK: void sq_pullback(double x, double _d_y, double *_d_x) {