Skip to content

Commit

Permalink
Support operators defined outside of classes
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Oct 15, 2024
1 parent 3cf9f46 commit 9e3fabf
Showing 1 changed file with 11 additions and 10 deletions.
21 changes: 11 additions & 10 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1771,14 +1771,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// statements there later.
std::size_t insertionPoint = getCurrentBlock(direction::reverse).size();

bool isCXXOperatorCall = isa<CXXOperatorCallExpr>(CE);
const auto* MD = dyn_cast<CXXMethodDecl>(FD);
// Method operators have a base like methods do but it's included in the
// call arguments so we have to shift the indexing of call arguments.
bool isMethodOperatorCall = MD && isa<CXXOperatorCallExpr>(CE);

for (std::size_t i = static_cast<std::size_t>(isCXXOperatorCall),
for (std::size_t i = static_cast<std::size_t>(isMethodOperatorCall),
e = CE->getNumArgs();
i != e; ++i) {
const Expr* arg = CE->getArg(i);
const auto* PVD =
FD->getParamDecl(i - static_cast<unsigned long>(isCXXOperatorCall));
const auto* PVD = FD->getParamDecl(
i - static_cast<unsigned long>(isMethodOperatorCall));
StmtDiff argDiff{};
// We do not need to create result arg for arguments passed by reference
// because the derivatives of arguments passed by reference are directly
Expand Down Expand Up @@ -1887,7 +1890,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* baseExpr = nullptr;
// If it has more args or f_darg0 was not found, we look for its pullback
// function.
const auto* MD = dyn_cast<CXXMethodDecl>(FD);
if (!OverloadedDerivedFn) {
size_t idx = 0;

Expand Down Expand Up @@ -1949,7 +1951,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

if (pullback)
pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs() -
static_cast<int>(isCXXOperatorCall),
static_cast<int>(isMethodOperatorCall),
pullback);

// Try to find it in builtin derivatives
Expand Down Expand Up @@ -2147,7 +2149,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
BuildOp(UnaryOperatorKind::UO_AddrOf, derivedBase, Loc));
}

for (std::size_t i = static_cast<std::size_t>(isCXXOperatorCall),
for (std::size_t i = static_cast<std::size_t>(isMethodOperatorCall),
e = CE->getNumArgs();
i != e; ++i) {
const Expr* arg = CE->getArg(i);
Expand All @@ -2172,7 +2174,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(resValue, resAdjoint, resAdjoint);
} // Recreate the original call expression.

if (const auto* OCE = dyn_cast<CXXOperatorCallExpr>(CE)) {
if (isMethodOperatorCall) {
const auto* OCE = cast<CXXOperatorCallExpr>(CE);
auto* FD = const_cast<CXXMethodDecl*>(
dyn_cast<CXXMethodDecl>(OCE->getCalleeDecl()));

Expand All @@ -2198,8 +2201,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
CallArgs, Loc)
.get();
return StmtDiff(call);

return {};
}

Expr* ReverseModeVisitor::GetMultiArgCentralDiffCall(
Expand Down

0 comments on commit 9e3fabf

Please sign in to comment.