Skip to content

Commit

Permalink
Generalize the logic for differentiating range-based for loops. Fixes v…
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Aug 8, 2024
1 parent 01787b4 commit 618f9c7
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 24 deletions.
31 changes: 9 additions & 22 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -994,30 +994,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
auto* BeginDeclRef = cast<DeclRefExpr>(BeginExpr);
Expr* d_BeginDeclRef = m_Variables[BeginDeclRef->getDecl()];

auto* RangeExpr =
cast<DeclRefExpr>(cast<BinaryOperator>(VisitRange.getStmt())->getLHS());

Expr* RangeInit = Clone(FRS->getRangeInit());
Expr* AssignRange =
BuildOp(BO_Assign, RangeExpr, BuildOp(UO_AddrOf, RangeInit));
Expr* AssignBegin =
BuildOp(BO_Assign, BeginDeclRef, BuildOp(UO_Deref, RangeExpr));
addToCurrentBlock(AssignRange);
addToCurrentBlock(AssignBegin);
addToCurrentBlock(VisitRange.getStmt());
addToCurrentBlock(VisitBegin.getStmt());
const auto* EndDecl = cast<VarDecl>(FRS->getEndStmt()->getSingleDecl());

Expr* EndInit = cast<BinaryOperator>(EndDecl->getInit())->getRHS();
QualType EndType = CloneType(EndDecl->getType());
std::string EndName = EndDecl->getNameAsString();
Expr* EndAssign = BuildOp(BO_Add, BuildOp(UO_Deref, RangeExpr), EndInit);
Expr* EndInit = Visit(EndDecl->getInit()).getExpr();
VarDecl* EndVarDecl =
BuildGlobalVarDecl(EndType, EndName, EndAssign, /*DirectInit=*/false);
DeclStmt* AssignEnd = BuildDeclStmt(EndVarDecl);

addToCurrentBlock(AssignEnd);
auto* AssignEndVarDecl =
cast<VarDecl>(cast<DeclStmt>(AssignEnd)->getSingleDecl());
DeclRefExpr* EndExpr = BuildDeclRef(AssignEndVarDecl);
BuildGlobalVarDecl(EndType, EndName, EndInit, /*DirectInit=*/false);
addToCurrentBlock(BuildDeclStmt(EndVarDecl));
DeclRefExpr* EndExpr = BuildDeclRef(EndVarDecl);
Expr* IncBegin = BuildOp(UO_PreInc, BeginDeclRef);

beginBlock(direction::forward);
Expand All @@ -1036,14 +1023,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* ForwardCond = BuildOp(BO_NE, BeginDeclRef, EndExpr);
// Add item assignment statement to the body.
const Stmt* body = FRS->getBody();
StmtDiff bodyDiff = Visit(body);
StmtDiff bodyDiff =
DifferentiateLoopBody(body, loopCounter, nullptr, nullptr,
/*isForLoop=*/true);

StmtDiff storeLoop = StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl()));
StmtDiff storeAdjLoop =
StoreAndRestore(BuildDeclRef(LoopVDDiff.getDecl_dx()));

addToCurrentBlock(BuildDeclStmt(LoopVDDiff.getDecl_dx()));
Expr* CounterIncrement = loopCounter.getCounterIncrement();

Expr* LoopInit = LoopVDDiff.getDecl()->getInit();
LoopVDDiff.getDecl()->setInit(getZeroInit(LoopVDDiff.getDecl()->getType()));
Expand All @@ -1058,7 +1046,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

beginBlock(direction::forward);
addToCurrentBlock(CounterIncrement);
addToCurrentBlock(AdjLoopVDAddAssign);
addToCurrentBlock(AssignLoop);
addToCurrentBlock(storeLoop.getStmt());
Expand Down
64 changes: 62 additions & 2 deletions test/Gradient/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -2717,12 +2717,12 @@ double fn34(double x, double y){
//CHECK-NEXT: double *i = 0;
//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) {
//CHECK-NEXT: {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: _d_i = &*_d___begin1;
//CHECK-NEXT: i = &*__begin10;
//CHECK-NEXT: clad::push(_t2, i);
//CHECK-NEXT: clad::push(_t3, _d_i);
//CHECK-NEXT: }
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, r);
//CHECK-NEXT: r += *i;
//CHECK-NEXT: }
Expand Down Expand Up @@ -2784,12 +2784,12 @@ double fn35(double x, double y){
//CHECK-NEXT: double i = 0;
//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) {
//CHECK-NEXT: {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: _d_i = *_d___begin1;
//CHECK-NEXT: i = *__begin10;
//CHECK-NEXT: clad::push(_t3, i);
//CHECK-NEXT: clad::push(_t4, _d_i);
//CHECK-NEXT: }
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: clad::push(_t2, sin(i));
//CHECK-NEXT: sum += clad::back(_t2) * x;
Expand All @@ -2816,6 +2816,65 @@ double fn35(double x, double y){
//CHECK-NEXT: }
//CHECK-NEXT: }

double fn36(double x, double y) {
double range[] = {x, 4., y};
double sum = 0;
for (auto elem : range)
sum += elem;
return sum;
}

//CHECK: void fn36_grad(double x, double y, double *_d_x, double *_d_y) {
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: double (*_d___range1)[3] = 0;
//CHECK-NEXT: double (*__range10)[3] = {};
//CHECK-NEXT: double *_d___begin1 = 0;
//CHECK-NEXT: double *__begin10 = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: clad::tape<double> _t3 = {};
//CHECK-NEXT: double _d_range[3] = {0};
//CHECK-NEXT: double range[3] = {x, 4., y};
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: _d___range1 = &_d_range;
//CHECK-NEXT: _d___begin1 = *_d___range1;
//CHECK-NEXT: __range10 = &range;
//CHECK-NEXT: __begin10 = *__range10;
//CHECK-NEXT: double *__end10 = *__range10 + {{3|3L}};
//CHECK-NEXT: double _d_elem = 0;
//CHECK-NEXT: double elem = 0;
//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) {
//CHECK-NEXT: {
//CHECK-NEXT: _d_elem = *_d___begin1;
//CHECK-NEXT: elem = *__begin10;
//CHECK-NEXT: clad::push(_t2, elem);
//CHECK-NEXT: clad::push(_t3, _d_elem);
//CHECK-NEXT: }
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: sum += elem;
//CHECK-NEXT: }
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: {
//CHECK-NEXT: {
//CHECK-NEXT: _d___begin1--;
//CHECK-NEXT: elem = clad::pop(_t2);
//CHECK-NEXT: _d_elem = clad::pop(_t3);
//CHECK-NEXT: }
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: _d_elem += _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: *_d___begin1 += _d_elem;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: *_d_x += _d_range[0];
//CHECK-NEXT: *_d_y += _d_range[2];
//CHECK-NEXT: }
//CHECK-NEXT: }

#define TEST(F, x) { \
result[0] = 0; \
Expand Down Expand Up @@ -2903,6 +2962,7 @@ int main() {

TEST_2(fn34, 5, 2); // CHECK-EXEC: {12.00, 7.00}
TEST_2(fn35, 1, 1); // CHECK-EXEC: {1.89, 0.00}
TEST_2(fn36, 6, 3); // CHECK-EXEC: {1.00, 1.00}
}

//CHECK: void sq_pullback(double x, double _d_y, double *_d_x) {
Expand Down

0 comments on commit 618f9c7

Please sign in to comment.