Skip to content

Commit

Permalink
Generalize the logic for differentiating range-based for loops
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Aug 7, 2024
1 parent b66b769 commit 5bd119b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 24 deletions.
31 changes: 9 additions & 22 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -994,30 +994,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto* BeginDeclRef = cast<DeclRefExpr>(BeginExpr);
Expr* d_BeginDeclRef = m_Variables[BeginDeclRef->getDecl()];

auto* RangeExpr =
cast<DeclRefExpr>(cast<BinaryOperator>(VisitRange.getStmt())->getLHS());

Expr* RangeInit = Clone(FRS->getRangeInit());
Expr* AssignRange =
BuildOp(BO_Assign, RangeExpr, BuildOp(UO_AddrOf, RangeInit));
Expr* AssignBegin =
BuildOp(BO_Assign, BeginDeclRef, BuildOp(UO_Deref, RangeExpr));
addToCurrentBlock(AssignRange);
addToCurrentBlock(AssignBegin);
addToCurrentBlock(VisitRange.getStmt());
addToCurrentBlock(VisitBegin.getStmt());
const auto* EndDecl = cast<VarDecl>(FRS->getEndStmt()->getSingleDecl());

Expr* EndInit = cast<BinaryOperator>(EndDecl->getInit())->getRHS();
QualType EndType = CloneType(EndDecl->getType());
std::string EndName = EndDecl->getNameAsString();
Expr* EndAssign = BuildOp(BO_Add, BuildOp(UO_Deref, RangeExpr), EndInit);
Expr* EndInit = Visit(EndDecl->getInit()).getExpr();
VarDecl* EndVarDecl =
BuildGlobalVarDecl(EndType, EndName, EndAssign, /*DirectInit=*/false);
DeclStmt* AssignEnd = BuildDeclStmt(EndVarDecl);

addToCurrentBlock(AssignEnd);
auto* AssignEndVarDecl =
cast<VarDecl>(cast<DeclStmt>(AssignEnd)->getSingleDecl());
DeclRefExpr* EndExpr = BuildDeclRef(AssignEndVarDecl);
BuildGlobalVarDecl(EndType, EndName, EndInit, /*DirectInit=*/false);
addToCurrentBlock(BuildDeclStmt(EndVarDecl));
DeclRefExpr* EndExpr = BuildDeclRef(EndVarDecl);
Expr* IncBegin = BuildOp(UO_PreInc, BeginDeclRef);

beginBlock(direction::forward);
Expand All @@ -1036,14 +1023,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* ForwardCond = BuildOp(BO_NE, BeginDeclRef, EndExpr);
// Add item assignment statement to the body.
const Stmt* body = FRS->getBody();
StmtDiff bodyDiff = Visit(body);
StmtDiff bodyDiff =
DifferentiateLoopBody(body, loopCounter, nullptr, nullptr,
/*isForLoop=*/true);

StmtDiff storeLoop = StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl()));
StmtDiff storeAdjLoop =
StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl_dx()));

addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl_dx()));
Expr* CounterIncrement = loopCounter.getCounterIncrement();

Expr* LoopInit = LoopVDDiff.getDecl()->getInit();
LoopVDDiff.getDecl()->setInit(getZeroInit(LoopVDDiff.getDecl()->getType()));
Expand All @@ -1058,7 +1046,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

beginBlock(direction::forward);
addToCurrentBlock(CounterIncrement);
addToCurrentBlock(AdjLoopVDAddAssign);
addToCurrentBlock(AssignLoop);
addToCurrentBlock(storeLoop.getStmt());
Expand Down
4 changes: 2 additions & 2 deletions test/Gradient/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -2717,12 +2717,12 @@ double fn34(double x, double y){
//CHECK-NEXT: double *i = 0;
//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) {
//CHECK-NEXT: {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: _d_i = &*_d___begin1;
//CHECK-NEXT: i = &*__begin10;
//CHECK-NEXT: clad::push(_t2, i);
//CHECK-NEXT: clad::push(_t3, _d_i);
//CHECK-NEXT: }
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, r);
//CHECK-NEXT: r += *i;
//CHECK-NEXT: }
Expand Down Expand Up @@ -2784,12 +2784,12 @@ double fn35(double x, double y){
//CHECK-NEXT: double i = 0;
//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) {
//CHECK-NEXT: {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: _d_i = *_d___begin1;
//CHECK-NEXT: i = *__begin10;
//CHECK-NEXT: clad::push(_t3, i);
//CHECK-NEXT: clad::push(_t4, _d_i);
//CHECK-NEXT: }
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: clad::push(_t2, sin(i));
//CHECK-NEXT: sum += clad::back(_t2) * x;
Expand Down

0 comments on commit 5bd119b

Please sign in to comment.