From 71dd694fbdcc6d8b3d91c384e192d840669bcc58 Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Wed, 13 Sep 2023 01:38:51 +0530 Subject: [PATCH] Fix call expr to functor inside a function --- lib/Differentiator/ReverseModeVisitor.cpp | 73 +++++++++++++++-------- test/Gradient/Functors.C | 34 +++++++++++ 2 files changed, 81 insertions(+), 26 deletions(-) diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 9e00797f2..55256330e 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -470,7 +470,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (m_ExternalSource) m_ExternalSource->ActAfterParsingDiffArgs(request, args); - auto derivativeName = request.BaseFunctionName + "_pullback"; + auto derivativeName = + utils::ComputeEffectiveFnName(m_Function) + "_pullback"; auto DNI = utils::BuildDeclarationNameInfo(m_Sema, derivativeName); auto paramTypes = ComputeParamTypes(args); @@ -1412,12 +1413,20 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // statements there later. std::size_t insertionPoint = getCurrentBlock(direction::reverse).size(); + // `CXXOperatorCallExpr` have the `base` expression as the first argument. + size_t skipFirstArg = 0; + + // Here we do not need to check if FD is an instance method or a static + // method because C++ forbids creating operator overloads as static methods. + if (isa(CE) && isa(FD)) + skipFirstArg = 1; + // FIXME: We should add instructions for handling non-differentiable // arguments. Currently we are implicitly assuming function call only // contains differentiable arguments. - for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { + for (std::size_t i = skipFirstArg, e = CE->getNumArgs(); i != e; ++i) { const Expr* arg = CE->getArg(i); - auto PVD = FD->getParamDecl(i); + const auto* PVD = FD->getParamDecl(i - skipFirstArg); StmtDiff argDiff{}; bool passByRef = utils::IsReferenceOrPointerType(PVD->getType()); // We do not need to create result arg for arguments passed by reference @@ -1597,26 +1606,37 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, /// Add base derivative expression in the derived call output args list if /// `CE` is a call to an instance member function. - if (auto MCE = dyn_cast(CE)) { - baseDiff = Visit(MCE->getImplicitObjectArgument()); - StmtDiff baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr()); - if (isInsideLoop) { - addToCurrentBlock(baseDiffStore.getExpr()); - VarDecl* baseLocalVD = BuildVarDecl( - baseDiffStore.getExpr_dx()->getType(), - CreateUniqueIdentifier("_r"), baseDiffStore.getExpr_dx(), - /*DirectInit=*/false, /*TSI=*/nullptr, - VarDecl::InitializationStyle::CInit); - auto& block = getCurrentBlock(direction::reverse); - block.insert(block.begin() + insertionPoint, - BuildDeclStmt(baseLocalVD)); - insertionPoint += 1; - Expr* baseLocalE = BuildDeclRef(baseLocalVD); - baseDiffStore = {baseDiffStore.getExpr(), baseLocalE}; + if (const auto* MD = dyn_cast(FD)) { + if (MD->isInstance()) { + const Expr* baseOriginalE = nullptr; + if (const auto* MCE = dyn_cast(CE)) + baseOriginalE = MCE->getImplicitObjectArgument(); + else if (const auto* OCE = dyn_cast(CE)) + baseOriginalE = OCE->getArg(0); + + baseDiff = Visit(baseOriginalE); + StmtDiff baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr()); + if (isInsideLoop) { + addToCurrentBlock(baseDiffStore.getExpr()); + VarDecl* baseLocalVD = BuildVarDecl( + baseDiffStore.getExpr_dx()->getType(), + CreateUniqueIdentifier("_r"), baseDiffStore.getExpr_dx(), + /*DirectInit=*/false, /*TSI=*/nullptr, + VarDecl::InitializationStyle::CInit); + auto& block = getCurrentBlock(direction::reverse); + block.insert(block.begin() + insertionPoint, + BuildDeclStmt(baseLocalVD)); + insertionPoint += 1; + Expr* baseLocalE = BuildDeclRef(baseLocalVD); + baseDiffStore = {baseDiffStore.getExpr(), baseLocalE}; + } + baseDiff = {baseDiffStore.getExpr_dx(), baseDiff.getExpr_dx()}; + Expr* baseDerivative = baseDiff.getExpr_dx(); + if (!baseDerivative->getType()->isPointerType()) + baseDerivative = + BuildOp(UnaryOperatorKind::UO_AddrOf, baseDerivative); + DerivedCallOutputArgs.push_back(baseDerivative); } - baseDiff = {baseDiffStore.getExpr_dx(), baseDiff.getExpr_dx()}; - DerivedCallOutputArgs.push_back( - BuildOp(UnaryOperatorKind::UO_AddrOf, baseDiff.getExpr_dx())); } for (auto argDerivative : CallArgDx) { @@ -1689,7 +1709,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, pullbackCallArgs = DerivedCallArgs; if (pullback) - pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs(), + pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs() - + static_cast(skipFirstArg), pullback); // Try to find it in builtin derivatives @@ -1775,7 +1796,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, usingNumericalDiff = true; } } else if (pullbackFD) { - if (isa(CE)) { + if (baseDiff.getExpr()) { Expr* baseE = baseDiff.getExpr(); OverloadedDerivedFn = BuildCallExprToMemFn( baseE, pullbackFD->getName(), pullbackCallArgs, pullbackFD); @@ -1878,7 +1899,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // We cannot reuse the derivatives previously computed because // they might contain 'clad::pop(..)` expression. - if (isa(CE)) { + if (baseDiff.getExpr_dx()) { Expr* derivedBase = baseDiff.getExpr_dx(); // FIXME: We may need this if-block once we support pointers, and // passing pointers-by-reference if @@ -1906,7 +1927,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } else CallArgs.push_back(m_Sema.ActOnCXXNullPtrLiteral(noLoc).get()); } - if (isa(CE)) { + if (baseDiff.getExpr()) { Expr* baseE = baseDiff.getExpr(); call = BuildCallExprToMemFn(baseE, calleeFnForwPassFD->getName(), CallArgs, calleeFnForwPassFD); diff --git a/test/Gradient/Functors.C b/test/Gradient/Functors.C index 7a989836a..0fa4832e7 100644 --- a/test/Gradient/Functors.C +++ b/test/Gradient/Functors.C @@ -173,6 +173,12 @@ struct ExperimentNNS { } // namespace inner } // namespace outer +// A function calling operator() on a functor. +double CallFunctor(double i, double j) { + Experiment E(3, 5); + return E(i, j); +} + #define INIT(E) \ auto E##_grad = clad::gradient(&E); \ auto E##Ref_grad = clad::gradient(E); @@ -298,4 +304,32 @@ int main() { // CHECK-EXEC: 54.00 42.00 TEST_LAMBDA(lambdaWithCapture); // CHECK-EXEC: 54.00 42.00 // CHECK-EXEC: 54.00 42.00 + + // CHECK: void CallFunctor_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { + // CHECK-NEXT: Experiment _d_E({}); + // CHECK-NEXT: double _t0; + // CHECK-NEXT: double _t1; + // CHECK-NEXT: Experiment _t2; + // CHECK-NEXT: Experiment E(3, 5); + // CHECK-NEXT: _t0 = i; + // CHECK-NEXT: _t1 = j; + // CHECK-NEXT: _t2 = E; + // CHECK-NEXT: goto _label0; + // CHECK-NEXT: _label0: + // CHECK-NEXT: { + // CHECK-NEXT: double _grad0 = 0.; + // CHECK-NEXT: double _grad1 = 0.; + // CHECK-NEXT: _t2.operator_call_pullback(_t0, _t1, 1, &_d_E, &_grad0, &_grad1); + // CHECK-NEXT: double _r0 = _grad0; + // CHECK-NEXT: * _d_i += _r0; + // CHECK-NEXT: double _r1 = _grad1; + // CHECK-NEXT: * _d_j += _r1; + // CHECK-NEXT: } + // CHECK-NEXT: } + + // testing differentiating a function calling operator() on a functor + auto CallFunctor_grad = clad::gradient(CallFunctor); + double di = 0, dj = 0; + CallFunctor_grad.execute(7, 9, &di, &dj); + printf("%.2f %.2f\n", di, dj); // CHECK-EXEC: 27.00 21.00 }