Skip to content

Commit

Permalink
Add a special case to differentiate begin/end calls
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Aug 7, 2024
1 parent 79782b2 commit b66b769
Showing 1 changed file with 17 additions and 0 deletions.
17 changes: 17 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1565,6 +1565,23 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
return StmtDiff(Call, zero);
}

std::string FDName = FD->getNameAsString();
if (FDName == "begin" || FDName == "end") {
const Expr* arg = nullptr;
if (const auto* MCE = dyn_cast<CXXMemberCallExpr>(CE))
arg = MCE->getImplicitObjectArgument();
else
arg = CE->getArg(0);
if (const auto* CXXCE = dyn_cast<CXXConstructExpr>(arg))
arg = CXXCE->getArg(0);
StmtDiff argDiff = Visit(arg);
llvm::SmallVector<Expr*, 1> params{argDiff.getExpr()};
llvm::SmallVector<Expr*, 1> paramsDiff{argDiff.getExpr_dx()};
Expr* call = GetFunctionCall(FDName, "std", params);
Expr* callDiff = GetFunctionCall(FDName, "std", paramsDiff);
return {call, callDiff};
}

auto NArgs = FD->getNumParams();
// If the function has no args and is not a member function call then we
// assume that it is not related to independent variables and does not
Expand Down

0 comments on commit b66b769

Please sign in to comment.