Skip to content

Commit

Permalink
Product of references in different scope fix (vgvassilev#1030)
Browse files Browse the repository at this point in the history
  • Loading branch information
ovdiiuv authored Aug 8, 2024
1 parent b0995ff commit 1b81084
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 5 deletions.
10 changes: 9 additions & 1 deletion lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2292,7 +2292,15 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
StmtDiff RResult;
// If R has no side effects, it can be just cloned
// (no need to store it).
if (!ShouldRecompute(R)) {

// Check if the local variable declaration is reference type, since it is
// moved to the global scope and the right side should be recomputed
bool promoteToFnScope = false;
if (auto* RDeclRef = dyn_cast<DeclRefExpr>(R->IgnoreImplicit()))
promoteToFnScope = RDeclRef->getDecl()->getType()->isReferenceType() &&
!getCurrentScope()->isFunctionScope();

if (!ShouldRecompute(R) || promoteToFnScope) {
RDelayed = std::unique_ptr<DelayedStoreResult>(
new DelayedStoreResult(DelayedGlobalStoreAndRef(R)));
RResult = RDelayed->Result;
Expand Down
10 changes: 6 additions & 4 deletions test/Gradient/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -2689,7 +2689,7 @@ double fn34(double x, double y){
double r = 0;
double a[] = {y, x*y, x*x + y};
for(auto& i: a){
r+=i;
r+=i*i;
}
return r;
}
Expand Down Expand Up @@ -2724,7 +2724,7 @@ double fn34(double x, double y){
//CHECK-NEXT: clad::push(_t3, _d_i);
//CHECK-NEXT: }
//CHECK-NEXT: clad::push(_t1, r);
//CHECK-NEXT: r += *i;
//CHECK-NEXT: r += *i * *i;
//CHECK-NEXT: }
//CHECK-NEXT: _d_r += 1;
//CHECK-NEXT: for (; _t0; _t0--) {
Expand All @@ -2737,7 +2737,8 @@ double fn34(double x, double y){
//CHECK-NEXT: {
//CHECK-NEXT: r = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_r;
//CHECK-NEXT: *_d_i += _r_d0;
//CHECK-NEXT: *_d_i += _r_d0 * *i;
//CHECK-NEXT: *_d_i += *i * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }
Expand All @@ -2751,6 +2752,7 @@ double fn34(double x, double y){
//CHECK-NEXT: }
//CHECK-NEXT: }


double fn35(double x, double y){
double a[] = {1, 2, 3};
double sum = 0;
Expand Down Expand Up @@ -2901,7 +2903,7 @@ int main() {
TEST_2(fn32, 3, 5); // CHECK-EXEC: {45.00, 27.00}
TEST_2(fn33, 3, 5); // CHECK-EXEC: {15.00, 9.00}

TEST_2(fn34, 5, 2); // CHECK-EXEC: {12.00, 7.00}
TEST_2(fn34, 2, 2); // CHECK-EXEC: {64.00, 32.00}
TEST_2(fn35, 1, 1); // CHECK-EXEC: {1.89, 0.00}
}

Expand Down

0 comments on commit 1b81084

Please sign in to comment.