From 9e3fabf6a8c5460d28dcb31564eda456706ccd1d Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Tue, 15 Oct 2024 20:49:35 +0300 Subject: [PATCH] Support operators defined outside of classes --- lib/Differentiator/ReverseModeVisitor.cpp | 21 +++++++++++---------- 1 file changed, 11 insertions(+), 10 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 5cfc2ab0d..6eb058942 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1771,14 +1771,17 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // statements there later. std::size_t insertionPoint = getCurrentBlock(direction::reverse).size(); - bool isCXXOperatorCall = isa(CE); + const auto* MD = dyn_cast(FD); + // Method operators have a base like methods do but it's included in the + // call arguments so we have to shift the indexing of call arguments. + bool isMethodOperatorCall = MD && isa(CE); - for (std::size_t i = static_cast(isCXXOperatorCall), + for (std::size_t i = static_cast(isMethodOperatorCall), e = CE->getNumArgs(); i != e; ++i) { const Expr* arg = CE->getArg(i); - const auto* PVD = - FD->getParamDecl(i - static_cast(isCXXOperatorCall)); + const auto* PVD = FD->getParamDecl( + i - static_cast(isMethodOperatorCall)); StmtDiff argDiff{}; // We do not need to create result arg for arguments passed by reference // because the derivatives of arguments passed by reference are directly @@ -1887,7 +1890,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, Expr* baseExpr = nullptr; // If it has more args or f_darg0 was not found, we look for its pullback // function. - const auto* MD = dyn_cast(FD); if (!OverloadedDerivedFn) { size_t idx = 0; @@ -1949,7 +1951,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (pullback) pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs() - - static_cast(isCXXOperatorCall), + static_cast(isMethodOperatorCall), pullback); // Try to find it in builtin derivatives @@ -2147,7 +2149,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, BuildOp(UnaryOperatorKind::UO_AddrOf, derivedBase, Loc)); } - for (std::size_t i = static_cast(isCXXOperatorCall), + for (std::size_t i = static_cast(isMethodOperatorCall), e = CE->getNumArgs(); i != e; ++i) { const Expr* arg = CE->getArg(i); @@ -2172,7 +2174,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(resValue, resAdjoint, resAdjoint); } // Recreate the original call expression. - if (const auto* OCE = dyn_cast(CE)) { + if (isMethodOperatorCall) { + const auto* OCE = cast(CE); auto* FD = const_cast( dyn_cast(OCE->getCalleeDecl())); @@ -2198,8 +2201,6 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, CallArgs, Loc) .get(); return StmtDiff(call); - - return {}; } Expr* ReverseModeVisitor::GetMultiArgCentralDiffCall(