Skip to content

Commit

Permalink
Store and restore reference arguments only if they are lvalue and non…
Browse files Browse the repository at this point in the history
…-const.

Currently, in the reverse mode, the argument is stored when the function parameter is of a reference type and the argument is not a temporary expression (e.g. ``0`` when bound to a ``const int&`` parameter). A better approach to this is checking if the parameter is an lvalue non-const reference. This lets us avoid unnecessary stores and errors:
```
double g(const double& t) {...}

double f (...) {
  double x = ...;
  const double y = ...;
  g(x);   // No need to store x since the function isn't able to modify it
  g(y);   // An attempt to restore y will lead to an error because it's a const
}
```
This commit fixes the issue in the ReverseModeVisitor and TBRAnalyzer.
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Oct 15, 2024
1 parent fc6d311 commit 332358e
Show file tree
Hide file tree
Showing 4 changed files with 126 additions and 172 deletions.
5 changes: 3 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1850,8 +1850,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// may be changed since we have no way to determine otherwise.
// FIXME: We cannot use GlobalStoreAndRef to store a whole array so now
// arrays are not stored.
bool passByRef = PVD->getType()->isReferenceType() &&
!isa<MaterializeTemporaryExpr>(arg);
QualType paramTy = PVD->getType();
bool passByRef = paramTy->isLValueReferenceType() &&
!paramTy.getNonReferenceType().isConstQualified();
Expr* argDiffStore;
if (passByRef && !argDiff.getExpr()->isEvaluatable(m_Context))
argDiffStore =
Expand Down
11 changes: 8 additions & 3 deletions lib/Differentiator/TBRAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -758,11 +758,16 @@ bool TBRAnalyzer::VisitCallExpr(clang::CallExpr* CE) {
setMode(Mode::kMarkingMode | Mode::kNonLinearMode);
for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) {
clang::Expr* arg = CE->getArg(i);
bool passByRef = false;
QualType paramTy;
if (noHiddenParam)
passByRef = FD->getParamDecl(i)->getType()->isReferenceType();
paramTy = FD->getParamDecl(i)->getType();
else if (i != 0)
passByRef = FD->getParamDecl(i - 1)->getType()->isReferenceType();
paramTy = FD->getParamDecl(i - 1)->getType();

bool passByRef = false;
if (!paramTy.isNull())
passByRef = paramTy->isLValueReferenceType() &&
!paramTy.getNonReferenceType().isConstQualified();
setMode(Mode::kMarkingMode | Mode::kNonLinearMode);
TraverseStmt(arg);
resetMode();
Expand Down
32 changes: 9 additions & 23 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -340,11 +340,9 @@ double fn9(double x, double y) {
}

// CHECK:void fn9_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: double _t0 = y;
// CHECK-NEXT: {
// CHECK-NEXT: y = _t0;
// CHECK-NEXT: double _r0 = 0.;
// CHECK-NEXT: custom_max_pullback(x * y, _t0, 1, &_r0, &*_d_y);
// CHECK-NEXT: custom_max_pullback(x * y, y, 1, &_r0, &*_d_y);
// CHECK-NEXT: *_d_x += _r0 * y;
// CHECK-NEXT: *_d_y += x * _r0;
// CHECK-NEXT: }
Expand All @@ -362,42 +360,36 @@ double fn10(double x, double y) {
// CHECK-NEXT: double _d_out = 0.;
// CHECK-NEXT: double out = x;
// CHECK-NEXT: double _t0 = out;
// CHECK-NEXT: double _t1 = out;
// CHECK-NEXT: out = std::max(out, 0.);
// CHECK-NEXT: double _t2 = out;
// CHECK-NEXT: double _t3 = out;
// CHECK-NEXT: double _t1 = out;
// CHECK-NEXT: out = std::min(out, 10.);
// CHECK-NEXT: double _t4 = out;
// CHECK-NEXT: double _t5 = out;
// CHECK-NEXT: double _t2 = out;
// CHECK-NEXT: out = std::clamp(out, 3., 7.);
// CHECK-NEXT: {
// CHECK-NEXT: _d_out += 1 * y;
// CHECK-NEXT: *_d_y += out * 1;
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: out = _t4;
// CHECK-NEXT: out = _t2;
// CHECK-NEXT: double _r_d2 = _d_out;
// CHECK-NEXT: _d_out = 0.;
// CHECK-NEXT: out = _t5;
// CHECK-NEXT: double _r2 = 0.;
// CHECK-NEXT: double _r3 = 0.;
// CHECK-NEXT: clad::custom_derivatives::std::clamp_pullback(_t5, 3., 7., _r_d2, &_d_out, &_r2, &_r3);
// CHECK-NEXT: clad::custom_derivatives::std::clamp_pullback(out, 3., 7., _r_d2, &_d_out, &_r2, &_r3);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: out = _t2;
// CHECK-NEXT: out = _t1;
// CHECK-NEXT: double _r_d1 = _d_out;
// CHECK-NEXT: _d_out = 0.;
// CHECK-NEXT: out = _t3;
// CHECK-NEXT: double _r1 = 0.;
// CHECK-NEXT: clad::custom_derivatives::std::min_pullback(_t3, 10., _r_d1, &_d_out, &_r1);
// CHECK-NEXT: clad::custom_derivatives::std::min_pullback(out, 10., _r_d1, &_d_out, &_r1);
// CHECK-NEXT: }
// CHECK-NEXT: {
// CHECK-NEXT: out = _t0;
// CHECK-NEXT: double _r_d0 = _d_out;
// CHECK-NEXT: _d_out = 0.;
// CHECK-NEXT: out = _t1;
// CHECK-NEXT: double _r0 = 0.;
// CHECK-NEXT: clad::custom_derivatives::std::max_pullback(_t1, 0., _r_d0, &_d_out, &_r0);
// CHECK-NEXT: clad::custom_derivatives::std::max_pullback(out, 0., _r_d0, &_d_out, &_r0);
// CHECK-NEXT: }
// CHECK-NEXT: *_d_x += _d_out;
// CHECK-NEXT: }
Expand Down Expand Up @@ -428,13 +420,7 @@ double fn11(double x, double y) {
}

// CHECK: void fn11_grad(double x, double y, double *_d_x, double *_d_y) {
// CHECK-NEXT: double _t0 = x;
// CHECK-NEXT: double _t1 = y;
// CHECK-NEXT: {
// CHECK-NEXT: x = _t0;
// CHECK-NEXT: y = _t1;
// CHECK-NEXT: clad::custom_derivatives::n1::sum_pullback(_t0, _t1, 1, &*_d_x, &*_d_y);
// CHECK-NEXT: }
// CHECK-NEXT: clad::custom_derivatives::n1::sum_pullback(x, y, 1, &*_d_x, &*_d_y);
// CHECK-NEXT: }

double do_nothing(double* u, double* v, double* w) {
Expand Down
Loading

0 comments on commit 332358e

Please sign in to comment.