diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 97c78d519..488eaf6ba 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1565,6 +1565,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, return StmtDiff(Call, zero); } + std::string FDName = FD->getNameAsString(); + if (FDName == "begin" || FDName == "end") { + const Expr* arg = nullptr; + if (const auto* MCE = dyn_cast(CE)) + arg = MCE->getImplicitObjectArgument(); + else + arg = CE->getArg(0); + if (const auto* CXXCE = dyn_cast(arg)) + arg = CXXCE->getArg(0); + StmtDiff argDiff = Visit(arg); + llvm::SmallVector params{argDiff.getExpr()}; + llvm::SmallVector paramsDiff{argDiff.getExpr_dx()}; + Expr* call = GetFunctionCall(FDName, "std", params); + Expr* callDiff = GetFunctionCall(FDName, "std", paramsDiff); + return {call, callDiff}; + } + auto NArgs = FD->getNumParams(); // If the function has no args and is not a member function call then we // assume that it is not related to independent variables and does not