From fc3b35cbb03712082ca8c91e3571a0319cac2c06 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Thu, 1 Aug 2024 19:50:22 +0300 Subject: [PATCH] Support bitwise, shift, comparison, remainder, not operators. This commit adds support for bitwise, shift, comparison, remainder, and bitwise not operators. Shift operators are considered differentiable since they essentially represent multiplication by ``2^n`` or ``2^-n``, where ``n`` is the RHS of the shift operators ``<<`` and ``>>``. Not operators are considered differentiable as well because they represent ``2^n - 1 - x`` or ``- 1 - x`` (depending on whether the type is signed) so the derivative is ``-_d_x``. Other operators have unclear differentiable effects and so they are considered non-differentiable. Fixes #381. --- lib/Differentiator/BaseForwardModeVisitor.cpp | 24 ++++++++++++++----- test/FirstDerivative/UnsupportedOpsWarn.C | 17 +++++-------- 2 files changed, 24 insertions(+), 17 deletions(-) diff --git a/lib/Differentiator/BaseForwardModeVisitor.cpp b/lib/Differentiator/BaseForwardModeVisitor.cpp index f06120c6c..c53757261 100644 --- a/lib/Differentiator/BaseForwardModeVisitor.cpp +++ b/lib/Differentiator/BaseForwardModeVisitor.cpp @@ -801,8 +801,7 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) { if ((condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) || condUO) { condDiff = Visit(cond); - if (condDiff.getExpr_dx() && - (!isUnusedResult(condDiff.getExpr_dx()) || condUO)) + if (condDiff.getExpr_dx() && (!isUnusedResult(condDiff.getExpr_dx()))) cond = BuildOp(BO_Comma, BuildParens(condDiff.getExpr_dx()), BuildParens(condDiff.getExpr())); else @@ -1381,7 +1380,15 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { } else if (opKind == UnaryOperatorKind::UO_AddrOf) { return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx())); } else if (opKind == UnaryOperatorKind::UO_LNot) { - return StmtDiff(op, diff.getExpr_dx()); + Expr* zero = getZeroInit(UnOp->getType()); + if (diff.getExpr_dx() && !isUnusedResult(diff.getExpr_dx())) + return {BuildOp(BO_Comma, BuildParens(diff.getExpr_dx()), op), zero}; + return {op, zero}; + } else if (opKind == UnaryOperatorKind::UO_Not) { + // ~x is 2^n - 1 - x for unsigned types and -x - 1 for the signed ones. + // Either way, taking a derivative gives us -_d_x. + Expr* derivedOp = BuildOp(UO_Minus, diff.getExpr_dx()); + return {op, derivedOp}; } else { unsupportedOpWarn(UnOp->getEndLoc()); auto zero = @@ -1497,7 +1504,8 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { } else opDiff = BuildOp(BO_Comma, BuildParens(Ldiff.getExpr()), BuildParens(Rdiff.getExpr_dx())); - } else if (BinOp->isLogicalOp()) { + } else if (BinOp->isLogicalOp() || BinOp->isBitwiseOp() || + BinOp->isComparisonOp() || opCode == BO_Rem) { // For (A && B) return ((dA, A) && (dB, B)) to ensure correct evaluation and // correct derivative execution. auto buildOneSide = [this](StmtDiff& Xdiff) { @@ -1514,8 +1522,12 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) { // Since the both parts are included in the opDiff, there's no point in // including it as a Stmt_dx. Moreover, the fact that Stmt_dx is left - // nullptr is used for treating expressions like ((A && B) && C) correctly. - return StmtDiff(opDiff, nullptr); + // zero is used for treating expressions like ((A && B) && C) correctly. + return StmtDiff(opDiff, getZeroInit(BinOp->getType())); + } else if (BinOp->isShiftOp()) { + // Shifting is essentially multiplicating the LHS by 2^RHS (or 2^-RHS). + // We should do the same to the derivarive. + opDiff = BuildOp(opCode, Ldiff.getExpr_dx(), Rdiff.getExpr()); } else { // FIXME: add support for other binary operators unsupportedOpWarn(BinOp->getEndLoc()); diff --git a/test/FirstDerivative/UnsupportedOpsWarn.C b/test/FirstDerivative/UnsupportedOpsWarn.C index 0f59ac961..551391407 100644 --- a/test/FirstDerivative/UnsupportedOpsWarn.C +++ b/test/FirstDerivative/UnsupportedOpsWarn.C @@ -6,12 +6,10 @@ //CHECK-NOT: {{.*error|warning|note:.*}} int binOpWarn_0(int x){ - return x << 1; // expected-warning {{attempt to differentiate unsupported operator, derivative set to 0}} + return x << 1; // expected-warning {{attempt to differentiate unsupported operator, ignored.}} set to 0}} } -// CHECK: int binOpWarn_0_darg0(int x) { -// CHECK-NEXT: int _d_x = 1; -// CHECK-NEXT: return 0; +// CHECK: void binOpWarn_0_grad(int x, int *_d_x) { // CHECK-NEXT: } @@ -23,17 +21,14 @@ int binOpWarn_1(int x){ // CHECK-NEXT: } int unOpWarn_0(int x){ - return ~x; // expected-warning {{attempt to differentiate unsupported operator, derivative set to 0}} + return ~x; // expected-warning {{attempt to differentiate unsupported operator, ignored.}} set to 0}} } -// CHECK: int unOpWarn_0_darg0(int x) { -// CHECK-NEXT: int _d_x = 1; -// CHECK-NEXT: return 0; +// CHECK: void unOpWarn_0_grad(int x, int *_d_x) { // CHECK-NEXT: } int main(){ - - clad::differentiate(binOpWarn_0, 0); + clad::gradient(binOpWarn_0); clad::gradient(binOpWarn_1); - clad::differentiate(unOpWarn_0, 0); + clad::gradient(unOpWarn_0); }