Skip to content

Commit

Permalink
Make const loop variables global and drop const.
Browse files Browse the repository at this point in the history
We didn't take into account const variables inside loops can be re-initialized on different loop iterations so we have to store them just like non-const variables. const should be dropped to allow us to replace initializations with assignments.
Fixes vgvassilev#667.
  • Loading branch information
PetroZarytskyi committed Dec 12, 2023
1 parent ec5ea41 commit d057047
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 4 deletions.
13 changes: 9 additions & 4 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2692,9 +2692,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// assignments. This is a temporary measure to avoid the bug that arises
// from overwriting local variables on different loop passes.
if (isInsideLoop) {
if (VD->getType()->isBuiltinType() &&
!VD->getType().isConstQualified()) {
if (VD->getType()->isBuiltinType()) {
auto* decl = VDDiff.getDecl();
/// The same variable will be assigned with new values every
/// loop iteration so the const qualifier must be dropped.
if (decl->getType().isConstQualified()) {
QualType nonConstType =
getNonConstType(decl->getType(), m_Context, m_Sema);
decl->setType(nonConstType);
}
if (decl->getInit()) {
auto* declRef = BuildDeclRef(decl);
auto pushPop =
Expand Down Expand Up @@ -2744,8 +2750,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
/// overwriting local variables on different loop passes.
if (isInsideLoop) {
if (auto* VD = dyn_cast<VarDecl>(decls[0])) {
if (VD->getType()->isBuiltinType() &&
!VD->getType().isConstQualified()) {
if (VD->getType()->isBuiltinType()) {
addToBlock(DSClone, m_Globals);
Stmt* initAssignments = MakeCompoundStmt(inits);
initAssignments = unwrapIfSingleStmt(initAssignments);
Expand Down
50 changes: 50 additions & 0 deletions test/Gradient/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,55 @@ double f5(double x){
//CHECK-NEXT: }
//CHECK-NEXT: }

double f_const_local(double x) {
double res = 0;
for (int i = 0; i < 3; ++i) {
const double n = x + i;
res += x * n;
}
return res;
} // == 3x^2 + 3x

//CHECK: void f_const_local_grad(double x, clad::array_ref<double> _d_x) {
//CHECK-NEXT: double _d_res = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double _d_n = 0;
//CHECK-NEXT: double n = 0;
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: double res = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < 3; ++i) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, n) , n = x + i;
//CHECK-NEXT: clad::push(_t2, res);
//CHECK-NEXT: res += x * n;
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_res += 1;
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: --i;
//CHECK-NEXT: {
//CHECK-NEXT: res = clad::pop(_t2);
//CHECK-NEXT: double _r_d0 = _d_res;
//CHECK-NEXT: _d_res += _r_d0;
//CHECK-NEXT: double _r0 = _r_d0 * n;
//CHECK-NEXT: * _d_x += _r0;
//CHECK-NEXT: double _r1 = x * _r_d0;
//CHECK-NEXT: _d_n += _r1;
//CHECK-NEXT: _d_res -= _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: * _d_x += _d_n;
//CHECK-NEXT: _d_i += _d_n;
//CHECK-NEXT: _d_n = 0;
//CHECK-NEXT: n = clad::pop(_t1);
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT:}

double f_sum(double *p, int n) {
double s = 0;
for (int i = 0; i < n; i++)
Expand Down Expand Up @@ -1715,6 +1764,7 @@ int main() {
TEST(f3, 3); // CHECK-EXEC: {6.00}
TEST(f4, 3); // CHECK-EXEC: {27.00}
TEST(f5, 3); // CHECK-EXEC: {1.00}
TEST(f_const_local, 3); // CHECK-EXEC: {21.00}

double p[] = { 1, 2, 3, 4, 5 };

Expand Down

0 comments on commit d057047

Please sign in to comment.