Skip to content

Commit

Permalink
Support bitwise, shift, comparison, remainder, not operators.
Browse files Browse the repository at this point in the history
This commit adds support for bitwise, shift, comparison, remainder, and bitwise not operators. Shift operators are considered differentiable since they essentially represent multiplication by ``2^n`` or ``2^-n``, where ``n`` is the RHS of the shift operators ``<<`` and ``>>``. Not operators are considered differentiable as well because they represent ``2^n - 1 - x`` or ``- 1 - x`` (depending on whether the type is signed) so the derivative is ``-_d_x``. Other operators have unclear differentiable effects and so they are considered non-differentiable.
Fixes vgvassilev#381.
  • Loading branch information
PetroZarytskyi committed Aug 5, 2024
1 parent 1fc6c6c commit fc3b35c
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 17 deletions.
24 changes: 18 additions & 6 deletions lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -801,8 +801,7 @@ StmtDiff BaseForwardModeVisitor::VisitForStmt(const ForStmt* FS) {
if ((condBO && (condBO->isLogicalOp() || condBO->isAssignmentOp())) ||
condUO) {
condDiff = Visit(cond);
if (condDiff.getExpr_dx() &&
(!isUnusedResult(condDiff.getExpr_dx()) || condUO))
if (condDiff.getExpr_dx() && (!isUnusedResult(condDiff.getExpr_dx())))
cond = BuildOp(BO_Comma, BuildParens(condDiff.getExpr_dx()),
BuildParens(condDiff.getExpr()));
else
Expand Down Expand Up @@ -1381,7 +1380,15 @@ StmtDiff BaseForwardModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) {
} else if (opKind == UnaryOperatorKind::UO_AddrOf) {
return StmtDiff(op, BuildOp(opKind, diff.getExpr_dx()));
} else if (opKind == UnaryOperatorKind::UO_LNot) {
return StmtDiff(op, diff.getExpr_dx());
Expr* zero = getZeroInit(UnOp->getType());
if (diff.getExpr_dx() && !isUnusedResult(diff.getExpr_dx()))
return {BuildOp(BO_Comma, BuildParens(diff.getExpr_dx()), op), zero};
return {op, zero};
} else if (opKind == UnaryOperatorKind::UO_Not) {
// ~x is 2^n - 1 - x for unsigned types and -x - 1 for the signed ones.
// Either way, taking a derivative gives us -_d_x.
Expr* derivedOp = BuildOp(UO_Minus, diff.getExpr_dx());
return {op, derivedOp};
} else {
unsupportedOpWarn(UnOp->getEndLoc());
auto zero =
Expand Down Expand Up @@ -1497,7 +1504,8 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) {
} else
opDiff = BuildOp(BO_Comma, BuildParens(Ldiff.getExpr()),
BuildParens(Rdiff.getExpr_dx()));
} else if (BinOp->isLogicalOp()) {
} else if (BinOp->isLogicalOp() || BinOp->isBitwiseOp() ||
BinOp->isComparisonOp() || opCode == BO_Rem) {
// For (A && B) return ((dA, A) && (dB, B)) to ensure correct evaluation and
// correct derivative execution.
auto buildOneSide = [this](StmtDiff& Xdiff) {
Expand All @@ -1514,8 +1522,12 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) {

// Since the both parts are included in the opDiff, there's no point in
// including it as a Stmt_dx. Moreover, the fact that Stmt_dx is left
// nullptr is used for treating expressions like ((A && B) && C) correctly.
return StmtDiff(opDiff, nullptr);
// zero is used for treating expressions like ((A && B) && C) correctly.
return StmtDiff(opDiff, getZeroInit(BinOp->getType()));
} else if (BinOp->isShiftOp()) {
// Shifting is essentially multiplicating the LHS by 2^RHS (or 2^-RHS).
// We should do the same to the derivarive.
opDiff = BuildOp(opCode, Ldiff.getExpr_dx(), Rdiff.getExpr());
} else {
// FIXME: add support for other binary operators
unsupportedOpWarn(BinOp->getEndLoc());
Expand Down
17 changes: 6 additions & 11 deletions test/FirstDerivative/UnsupportedOpsWarn.C
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,10 @@
//CHECK-NOT: {{.*error|warning|note:.*}}

int binOpWarn_0(int x){
return x << 1; // expected-warning {{attempt to differentiate unsupported operator, derivative set to 0}}
return x << 1; // expected-warning {{attempt to differentiate unsupported operator, ignored.}} set to 0}}
}

// CHECK: int binOpWarn_0_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK-NEXT: return 0;
// CHECK: void binOpWarn_0_grad(int x, int *_d_x) {
// CHECK-NEXT: }


Expand All @@ -23,17 +21,14 @@ int binOpWarn_1(int x){
// CHECK-NEXT: }

int unOpWarn_0(int x){
return ~x; // expected-warning {{attempt to differentiate unsupported operator, derivative set to 0}}
return ~x; // expected-warning {{attempt to differentiate unsupported operator, ignored.}} set to 0}}
}

// CHECK: int unOpWarn_0_darg0(int x) {
// CHECK-NEXT: int _d_x = 1;
// CHECK-NEXT: return 0;
// CHECK: void unOpWarn_0_grad(int x, int *_d_x) {
// CHECK-NEXT: }

int main(){

clad::differentiate(binOpWarn_0, 0);
clad::gradient(binOpWarn_0);
clad::gradient(binOpWarn_1);
clad::differentiate(unOpWarn_0, 0);
clad::gradient(unOpWarn_0);
}

0 comments on commit fc3b35c

Please sign in to comment.