Skip to content

Commit

Permalink
Differentiate the RHS in multiplication instead of cloning by introdu…
Browse files Browse the repository at this point in the history
…cing placeholders
  • Loading branch information
PetroZarytskyi committed Aug 14, 2024
1 parent 135f6f9 commit 8cb3fc5
Show file tree
Hide file tree
Showing 6 changed files with 94 additions and 60 deletions.
17 changes: 11 additions & 6 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ namespace clad {
/// the reverse mode we also accumulate Stmts for the reverse pass which
/// will be executed on return.
std::vector<Stmts> m_Reverse;
/// Stores all expressions used as placeholders which have to be
/// reset later.
std::set<const clang::Expr*> m_Placeholders;
/// Storing expressions to delete/free memory in the reverse pass.
Stmts m_DeallocExprs;
/// Stack is used to pass the arguments (dfdx) to further nodes
Expand Down Expand Up @@ -273,13 +276,14 @@ namespace clad {
bool isInsideLoop;
bool isFnScope;
bool needsUpdate;
clang::Expr* Placeholder;
DelayedStoreResult(ReverseModeVisitor& pV, StmtDiff pResult,
clang::VarDecl* pDeclaration, bool pIsConstant,
bool pIsInsideLoop, bool pIsFnScope,
bool pNeedsUpdate = false)
clang::VarDecl* pDeclaration, bool pIsInsideLoop,
bool pIsFnScope, bool pNeedsUpdate = false,
clang::Expr* pPlaceholder = nullptr)
: V(pV), Result(pResult), Declaration(pDeclaration),
isConstant(pIsConstant), isInsideLoop(pIsInsideLoop),
isFnScope(pIsFnScope), needsUpdate(pNeedsUpdate) {}
isInsideLoop(pIsInsideLoop), isFnScope(pIsFnScope),
needsUpdate(pNeedsUpdate), Placeholder(pPlaceholder) {}
void Finalize(clang::Expr* New);
};

Expand All @@ -292,7 +296,8 @@ namespace clad {
/// This is what DelayedGlobalStoreAndRef does. E is expected to be the
/// original (uncloned) expression.
DelayedStoreResult DelayedGlobalStoreAndRef(clang::Expr* E,
llvm::StringRef prefix = "_t");
llvm::StringRef prefix = "_t",
bool forceNoRecompute = false);

struct CladTapeResult {
ReverseModeVisitor& V;
Expand Down
114 changes: 68 additions & 46 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2296,25 +2296,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// to reduce cloning complexity and only clones once. Storing it in a
// global variable allows to save current result and make it accessible
// in the reverse pass.
std::unique_ptr<DelayedStoreResult> RDelayed;
StmtDiff RResult;
// If R has no side effects, it can be just cloned
// (no need to store it).

// Check if the local variable declaration is reference type, since it is
// moved to the global scope and the right side should be recomputed
bool promoteToFnScope = false;
if (auto* RDeclRef = dyn_cast<DeclRefExpr>(R->IgnoreImplicit()))
promoteToFnScope = RDeclRef->getDecl()->getType()->isReferenceType() &&
!getCurrentScope()->isFunctionScope();

if (!ShouldRecompute(R) || promoteToFnScope) {
RDelayed = std::unique_ptr<DelayedStoreResult>(
new DelayedStoreResult(DelayedGlobalStoreAndRef(R)));
RResult = RDelayed->Result;
} else {
RResult = StmtDiff(Clone(R));
}
DelayedStoreResult RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff& RResult = RDelayed.Result;

Expr* dl = nullptr;
if (dfdx())
Expand All @@ -2336,30 +2319,27 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
/*force=*/true);
Stmt* LPop = endBlock(direction::reverse);
Expr::EvalResult dummy;
if (RDelayed ||
!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context)) {
if (!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context) ||
RDelayed.needsUpdate) {
Expr* dr = nullptr;
if (dfdx())
dr = BuildOp(BO_Mul, LStored.getRevSweepAsExpr(), dfdx());
Rdiff = Visit(R, dr);
// Assign right multiplier's variable with R.
if (RDelayed)
RDelayed->Finalize(Rdiff.getExpr());
RDelayed.Finalize(Rdiff.getExpr());
}
addToCurrentBlock(utils::unwrapIfSingleStmt(LPop), direction::reverse);
std::tie(Ldiff, Rdiff) =
std::make_pair(LStored.getExpr(), RResult.getExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LStored, RResult);
} else if (opCode == BO_Div) {
// xi = xl / xr
// dxi/xl = 1 / xr
// df/dxl += df/dxi * dxi/xl = df/dxi * (1/xr)
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
Expr* RStored =
StoreAndRef(RResult.getRevSweepAsExpr(), direction::reverse);
auto RDelayed = DelayedGlobalStoreAndRef(R, /*prefix=*/"_t",
/*forceNoRecompute=*/true);
StmtDiff& RResult = RDelayed.Result;
Expr* dl = nullptr;
if (dfdx())
dl = BuildOp(BO_Div, dfdx(), RStored);
dl = BuildOp(BO_Div, dfdx(), RResult.getExpr());
Ldiff = Visit(L, dl);
StmtDiff LStored = Ldiff;
// Catch the pop statement and emit it after
Expand All @@ -2377,10 +2357,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// df/dxl += df/dxi * dxi/xr = df/dxi * (-xl /(xr * xr))
// Wrap R * R in parentheses: (R * R). otherwise code like 1 / R * R is
// produced instead of 1 / (R * R).
if (!RDelayed.isConstant) {
Expr::EvalResult dummy;
if (!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context) ||
RDelayed.needsUpdate) {
Expr* dr = nullptr;
if (dfdx()) {
Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored));
Expr* RxR = BuildParens(
BuildOp(BO_Mul, RResult.getExpr(), RResult.getExpr()));
dr = BuildOp(BO_Mul, dfdx(),
BuildOp(UO_Minus,
BuildParens(BuildOp(
Expand All @@ -2391,8 +2374,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
RDelayed.Finalize(Rdiff.getExpr());
}
addToCurrentBlock(utils::unwrapIfSingleStmt(LPop), direction::reverse);
std::tie(Ldiff, Rdiff) =
std::make_pair(LStored.getExpr(), RResult.getExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LStored, RResult);
} else if (BinOp->isAssignmentOp()) {
if (L->isModifiableLvalue(m_Context) != Expr::MLV_Valid) {
diag(DiagnosticsEngine::Warning,
Expand Down Expand Up @@ -2588,14 +2570,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* zero = getZeroInit(ResultRef->getType());
addToCurrentBlock(BuildOp(BO_Assign, ResultRef, zero),
direction::reverse);
auto RDelayed = DelayedGlobalStoreAndRef(R);
StmtDiff RResult = RDelayed.Result;
auto RDelayed = DelayedGlobalStoreAndRef(R, /*prefix=*/"_t",
/*forceNoRecompute=*/true);
StmtDiff& RResult = RDelayed.Result;
Expr* RStored =
StoreAndRef(RResult.getRevSweepAsExpr(), direction::reverse);
addToCurrentBlock(BuildOp(BO_AddAssign, ResultRef,
BuildOp(BO_Div, oldValue, RStored)),
direction::reverse);
if (!RDelayed.isConstant) {
Expr::EvalResult dummy;
if (!clad_compat::Expr_EvaluateAsConstantExpr(R, dummy, m_Context) ||
RDelayed.needsUpdate) {
if (isInsideLoop)
addToCurrentBlock(LCloned, direction::forward);
Expr* RxR = BuildParens(BuildOp(BO_Mul, RStored, RStored));
Expand All @@ -2607,7 +2592,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
valueForRevPass = BuildOp(BO_Div, Rdiff.getRevSweepAsExpr(),
Ldiff.getRevSweepAsExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, RResult.getExpr());
std::tie(Ldiff, Rdiff) = std::make_pair(LCloned, RResult);
} else
llvm_unreachable("unknown assignment opCode");
if (m_ExternalSource)
Expand Down Expand Up @@ -3368,8 +3353,39 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

void ReverseModeVisitor::DelayedStoreResult::Finalize(Expr* New) {
if (isConstant || !needsUpdate)
class PlaceholderReplacer
: public RecursiveASTVisitor<PlaceholderReplacer> {
public:
const Expr* placeholder;
Expr* newExpr{nullptr};
PlaceholderReplacer(const Expr* Placeholder) : placeholder(Placeholder) {}

bool VisitExpr(Expr* E) {
for (auto iter = E->child_begin(), e = E->child_end(); iter != e;
++iter)
if (*iter == placeholder)
*iter = newExpr;
else
TraverseStmt(*iter);
return true;
}
};

if (!needsUpdate)
return;

if (Placeholder) {
PlaceholderReplacer repl(Placeholder);
repl.newExpr = New;
for (Stmt* S : V.getCurrentBlock(direction::forward))
repl.TraverseStmt(S);
for (Stmt* S : V.getCurrentBlock(direction::reverse))
repl.TraverseStmt(S);
Result = New;
V.m_Placeholders.erase(Placeholder);
return;
}

if (isInsideLoop) {
auto* Push = cast<CallExpr>(Result.getExpr());
unsigned lastArg = Push->getNumArgs() - 1;
Expand All @@ -3385,22 +3401,30 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

ReverseModeVisitor::DelayedStoreResult
ReverseModeVisitor::DelayedGlobalStoreAndRef(Expr* E,
llvm::StringRef prefix) {
ReverseModeVisitor::DelayedGlobalStoreAndRef(Expr* E, llvm::StringRef prefix,
bool forceNoRecompute) {
assert(E && "must be provided");
if (!UsefulToStore(E)) {
StmtDiff Ediff = Visit(E);
Expr::EvalResult evalRes;
bool isConst =
clad_compat::Expr_EvaluateAsConstantExpr(E, evalRes, m_Context);
return DelayedStoreResult{*this,
Ediff,
/*Declaration=*/nullptr,
/*isConstant=*/isConst,
/*isInsideLoop=*/false,
/*isFnScope=*/false,
/*pNeedsUpdate=*/false};
}
if (!forceNoRecompute && ShouldRecompute(E)) {
Expr* PH = ConstantFolder::synthesizeLiteral(E->getType(), m_Context, 1);
m_Placeholders.insert(PH);
return DelayedStoreResult{*this,
StmtDiff{PH, nullptr, nullptr, PH},
/*Declaration=*/nullptr,
/*isInsideLoop*/ false,
/*isFnScope=*/false,
/*pNeedsUpdate=*/true,
/*pPlaceholder=*/PH};
}
if (isInsideLoop) {
Expr* dummy = E;
auto CladTape = MakeCladTapeFor(dummy);
Expand All @@ -3409,7 +3433,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return DelayedStoreResult{*this,
StmtDiff{Push, nullptr, nullptr, Pop},
/*Declaration=*/nullptr,
/*isConstant=*/false,
/*isInsideLoop=*/true,
/*isFnScope=*/false,
/*pNeedsUpdate=*/true};
Expand All @@ -3425,7 +3448,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return DelayedStoreResult{*this,
StmtDiff{Ref, nullptr, nullptr, Ref},
/*Declaration=*/VD,
/*isConstant=*/false,
/*isInsideLoop=*/false,
/*isFnScope=*/isFnScope,
/*pNeedsUpdate=*/true};
Expand Down
8 changes: 7 additions & 1 deletion test/ErrorEstimation/LoopsAndArrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -226,8 +226,8 @@ double func5(double* x, double* y, double* output) {

//CHECK: void func5_grad(double *x, double *y, double *output, double *_d_x, double *_d_y, double *_d_output, double &_final_error) {
//CHECK-NEXT: unsigned {{int|long}} output_size = 0;
//CHECK-NEXT: unsigned {{int|long}} x_size = 0;
//CHECK-NEXT: unsigned {{int|long}} y_size = 0;
//CHECK-NEXT: unsigned {{int|long}} x_size = 0;
//CHECK-NEXT: double _ret_value0 = 0;
//CHECK-NEXT: double _t0 = output[0];
//CHECK-NEXT: output[0] = x[1] * y[2] - x[2] * y[1];
Expand All @@ -249,10 +249,12 @@ double func5(double* x, double* y, double* output) {
//CHECK-NEXT: output[2] = _t2;
//CHECK-NEXT: double _r_d2 = _d_output[2];
//CHECK-NEXT: _d_output[2] = 0;
//CHECK-NEXT: y_size = std::max(y_size, 1);
//CHECK-NEXT: _d_x[0] += _r_d2 * y[1];
//CHECK-NEXT: x_size = std::max(x_size, 0);
//CHECK-NEXT: _d_y[1] += x[0] * _r_d2;
//CHECK-NEXT: y_size = std::max(y_size, 1);
//CHECK-NEXT: x_size = std::max(x_size, 1);
//CHECK-NEXT: _d_y[0] += -_r_d2 * x[1];
//CHECK-NEXT: y_size = std::max(y_size, 0);
//CHECK-NEXT: _d_x[1] += y[0] * -_r_d2;
Expand All @@ -264,10 +266,12 @@ double func5(double* x, double* y, double* output) {
//CHECK-NEXT: output[1] = _t1;
//CHECK-NEXT: double _r_d1 = _d_output[1];
//CHECK-NEXT: _d_output[1] = 0;
//CHECK-NEXT: y_size = std::max(y_size, 0);
//CHECK-NEXT: _d_x[2] += _r_d1 * y[0];
//CHECK-NEXT: x_size = std::max(x_size, 2);
//CHECK-NEXT: _d_y[0] += x[2] * _r_d1;
//CHECK-NEXT: y_size = std::max(y_size, 0);
//CHECK-NEXT: y_size = std::max(y_size, 2);
//CHECK-NEXT: _d_x[0] += -_r_d1 * y[2];
//CHECK-NEXT: x_size = std::max(x_size, 0);
//CHECK-NEXT: _d_y[2] += x[0] * -_r_d1;
Expand All @@ -279,10 +283,12 @@ double func5(double* x, double* y, double* output) {
//CHECK-NEXT: output[0] = _t0;
//CHECK-NEXT: double _r_d0 = _d_output[0];
//CHECK-NEXT: _d_output[0] = 0;
//CHECK-NEXT: y_size = std::max(y_size, 2);
//CHECK-NEXT: _d_x[1] += _r_d0 * y[2];
//CHECK-NEXT: x_size = std::max(x_size, 1);
//CHECK-NEXT: _d_y[2] += x[1] * _r_d0;
//CHECK-NEXT: y_size = std::max(y_size, 2);
//CHECK-NEXT: y_size = std::max(y_size, 1);
//CHECK-NEXT: _d_x[2] += -_r_d0 * y[1];
//CHECK-NEXT: x_size = std::max(x_size, 2);
//CHECK-NEXT: _d_y[1] += x[2] * -_r_d0;
Expand Down
3 changes: 2 additions & 1 deletion test/ErrorEstimation/LoopsAndArraysExec.C
Original file line number Diff line number Diff line change
Expand Up @@ -72,8 +72,8 @@ double mulSum(float* a, float* b, int n) {
//CHECK-NEXT: int _d_j = 0;
//CHECK-NEXT: int j = 0;
//CHECK-NEXT: clad::tape<double> _t3 = {};
//CHECK-NEXT: unsigned {{int|long}} a_size = 0;
//CHECK-NEXT: unsigned {{int|long}} b_size = 0;
//CHECK-NEXT: unsigned {{int|long}} a_size = 0;
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}};
Expand Down Expand Up @@ -111,6 +111,7 @@ double mulSum(float* a, float* b, int n) {
//CHECK-NEXT: _final_error += std::abs(_d_sum * sum * {{.+}});
//CHECK-NEXT: sum = clad::pop(_t3);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: b_size = std::max(b_size, j);
//CHECK-NEXT: _d_a[i] += _r_d0 * b[j];
//CHECK-NEXT: a_size = std::max(a_size, i);
//CHECK-NEXT: _d_b[j] += a[i] * _r_d0;
Expand Down
2 changes: 1 addition & 1 deletion test/Gradient/NonDifferentiable.C
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ int main() {
// CHECK-NEXT: SimpleFunctions1 _d_obj({});
// CHECK-NEXT: SimpleFunctions1 obj(2, 3);
// CHECK-NEXT: {
// CHECK-NEXT: *_d_obj.x_pointer += 1 * (*obj.y_pointer);
// CHECK-NEXT: *_d_obj.x_pointer += 1 * *obj.y_pointer;
// CHECK-NEXT: *_d_i += 1 * j;
// CHECK-NEXT: *_d_j += i * 1;
// CHECK-NEXT: }
Expand Down
10 changes: 5 additions & 5 deletions test/Gradient/Pointers.C
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ double minimalPointer(double x) {
// CHECK-NEXT: double *_d_p = &*_d_x;
// CHECK-NEXT: double *const p = &x;
// CHECK-NEXT: double _t0 = *p;
// CHECK-NEXT: *p = *p * (*p);
// CHECK-NEXT: *p = *p * *p;
// CHECK-NEXT: *_d_p += 1;
// CHECK-NEXT: {
// CHECK-NEXT: *p = _t0;
// CHECK-NEXT: double _r_d0 = *_d_p;
// CHECK-NEXT: *_d_p = 0;
// CHECK-NEXT: *_d_p += _r_d0 * (*p);
// CHECK-NEXT: *_d_p += _r_d0 * *p;
// CHECK-NEXT: *_d_p += *p * _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: }
Expand Down Expand Up @@ -87,7 +87,7 @@ double arrayPointer(const double* arr) {
// CHECK-NEXT: _d_p = _d_p - 2;
// CHECK-NEXT: p = p - 2;
// CHECK-NEXT: double _t11 = sum;
// CHECK-NEXT: sum += 5 * (*p);
// CHECK-NEXT: sum += 5 * *p;
// CHECK-NEXT: _d_sum += 1;
// CHECK-NEXT: {
// CHECK-NEXT: sum = _t11;
Expand Down Expand Up @@ -170,7 +170,7 @@ double pointerParam(const double* arr, size_t n) {
// CHECK-NEXT: clad::push(_t1, _d_j);
// CHECK-NEXT: clad::push(_t3, j) , j = &i;
// CHECK-NEXT: clad::push(_t4, sum);
// CHECK-NEXT: sum += arr[0] * (*j);
// CHECK-NEXT: sum += arr[0] * *j;
// CHECK-NEXT: clad::push(_t5, arr);
// CHECK-NEXT: clad::push(_t6, _d_arr);
// CHECK-NEXT: _d_arr = _d_arr + 1;
Expand All @@ -191,7 +191,7 @@ double pointerParam(const double* arr, size_t n) {
// CHECK-NEXT: {
// CHECK-NEXT: sum = clad::pop(_t4);
// CHECK-NEXT: double _r_d0 = _d_sum;
// CHECK-NEXT: _d_arr[0] += _r_d0 * (*j);
// CHECK-NEXT: _d_arr[0] += _r_d0 * *j;
// CHECK-NEXT: *_t2 += arr[0] * _r_d0;
// CHECK-NEXT: }
// CHECK-NEXT: j = clad::pop(_t3);
Expand Down

0 comments on commit 8cb3fc5

Please sign in to comment.