Skip to content

Commit

Permalink
Improve the constness of GetInnermostReturnExpr
Browse files Browse the repository at this point in the history
  • Loading branch information
vgvassilev committed Nov 20, 2023
1 parent 2140789 commit 7703d74
Show file tree
Hide file tree
Showing 4 changed files with 39 additions and 33 deletions.
8 changes: 6 additions & 2 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -316,8 +316,12 @@ namespace clad {
bool hasNonDifferentiableAttribute(const clang::Decl* D);

bool hasNonDifferentiableAttribute(const clang::Expr* E);
/// FIXME: add documentation
std::vector<clang::Expr*> GetInnermostReturnExpr(const clang::Expr* E);

/// Collects every DeclRefExpr, MemberExpr, ArraySubscriptExpr in an
/// assignment operator or a ternary if operator. This is useful to when we
/// need to decide what needs to be stored on tape in reverse mode.
void GetInnermostReturnExpr(const clang::Expr* E,
llvm::SmallVectorImpl<clang::Expr*>& Exprs);
} // namespace utils
}

Expand Down
39 changes: 18 additions & 21 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -544,17 +544,18 @@ namespace clad {
return false;
}

std::vector<clang::Expr*> GetInnermostReturnExpr(const clang::Expr* E) {
struct Finder : public ConstStmtVisitor<Finder> {
std::vector<clang::Expr*> m_return_exprs;
void GetInnermostReturnExpr(const clang::Expr* E,
llvm::SmallVectorImpl<clang::Expr*>& Exprs) {
struct Finder : public StmtVisitor<Finder> {
llvm::SmallVectorImpl<clang::Expr*>& m_Exprs;

public:
std::vector<clang::Expr*> Find(const clang::Expr* E) {
Finder(clang::Expr* E, llvm::SmallVectorImpl<clang::Expr*>& Exprs)
: m_Exprs(Exprs) {
Visit(E);
return m_return_exprs;
}

void VisitBinaryOperator(const clang::BinaryOperator* BO) {
void VisitBinaryOperator(clang::BinaryOperator* BO) {
if (BO->isAssignmentOp() || BO->isCompoundAssignmentOp()) {
Visit(BO->getLHS());
} else if (BO->getOpcode() == clang::BO_Comma) {
Expand All @@ -564,41 +565,37 @@ namespace clad {
}
}

void VisitConditionalOperator(const clang::ConditionalOperator* CO) {
void VisitConditionalOperator(clang::ConditionalOperator* CO) {
// FIXME: in cases like (cond ? x : y) = 2; both x and y will be
// stored.
Visit(CO->getTrueExpr());
Visit(CO->getFalseExpr());
}

void VisitUnaryOperator(const clang::UnaryOperator* UnOp) {
void VisitUnaryOperator(clang::UnaryOperator* UnOp) {
auto opCode = UnOp->getOpcode();
if (opCode == clang::UO_PreInc || opCode == clang::UO_PreDec)
Visit(UnOp->getSubExpr());
}

void VisitDeclRefExpr(const clang::DeclRefExpr* DRE) {
m_return_exprs.push_back(const_cast<clang::DeclRefExpr*>(DRE));
void VisitDeclRefExpr(clang::DeclRefExpr* DRE) {
m_Exprs.push_back(DRE);
}

void VisitParenExpr(const clang::ParenExpr* PE) {
Visit(PE->getSubExpr());
}
void VisitParenExpr(clang::ParenExpr* PE) { Visit(PE->getSubExpr()); }

void VisitMemberExpr(const clang::MemberExpr* ME) {
m_return_exprs.push_back(const_cast<clang::MemberExpr*>(ME));
}
void VisitMemberExpr(clang::MemberExpr* ME) { m_Exprs.push_back(ME); }

void VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE) {
m_return_exprs.push_back(const_cast<clang::ArraySubscriptExpr*>(ASE));
void VisitArraySubscriptExpr(clang::ArraySubscriptExpr* ASE) {
m_Exprs.push_back(ASE);
}

void VisitImplicitCastExpr(const clang::ImplicitCastExpr* ICE) {
void VisitImplicitCastExpr(clang::ImplicitCastExpr* ICE) {
Visit(ICE->getSubExpr());
}
};
Finder finder;
return finder.Find(E);
// FIXME: Fix the constness on the callers of this function.
Finder finder(const_cast<clang::Expr*>(E), Exprs);
}

bool IsAutoOrAutoPtrType(QualType T) {
Expand Down
5 changes: 3 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2330,7 +2330,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Ldiff = Visit(L, dfdx());
Stmts essentialRevBlock = EndBlockWithoutCreatingCS(direction::essential_reverse);
auto* Lblock = endBlock(direction::reverse);
auto return_exprs = utils::GetInnermostReturnExpr(Ldiff.getExpr());
llvm::SmallVector<Expr*, 4> ExprsToStore;
utils::GetInnermostReturnExpr(Ldiff.getExpr(), ExprsToStore);
if (L->HasSideEffects(m_Context)) {
Expr* E = Ldiff.getExpr();
auto* storeE = StoreAndRef(E, m_Context.getLValueReferenceType(E->getType()));
Expand Down Expand Up @@ -2368,7 +2369,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Lblock_begin = std::next(Lblock_begin);
}

for (auto& E : return_exprs) {
for (auto& E : ExprsToStore) {
auto pushPop = StoreAndRestore(E);
addToCurrentBlock(pushPop.getExpr(), direction::forward);
addToCurrentBlock(pushPop.getExpr_dx(), direction::reverse);
Expand Down
20 changes: 12 additions & 8 deletions lib/Differentiator/TBRAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -599,10 +599,11 @@ void TBRAnalyzer::VisitDeclStmt(const DeclStmt* DS) {
auto& VDExpr = getCurBlockVarsData()[VD];
/// if the declared variable is ref type attach its VarData to the
/// VarData of the RHS variable.
auto returnExprs = utils::GetInnermostReturnExpr(init);
if (VDExpr.type == VarData::REF_TYPE && !returnExprs.empty())
llvm::SmallVector<Expr*, 4> ExprsToStore;
utils::GetInnermostReturnExpr(init, ExprsToStore);
if (VDExpr.type == VarData::REF_TYPE && !ExprsToStore.empty())
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-union-access)
VDExpr.val.m_RefData = returnExprs[0];
VDExpr.val.m_RefData = ExprsToStore[0];
}
}
}
Expand Down Expand Up @@ -695,8 +696,9 @@ void TBRAnalyzer::VisitBinaryOperator(const BinaryOperator* BinOp) {
Visit(R);
resetMode();
}
const auto returnExprs = utils::GetInnermostReturnExpr(L);
for (const auto* innerExpr : returnExprs) {
llvm::SmallVector<Expr*, 4> ExprsToStore;
utils::GetInnermostReturnExpr(L, ExprsToStore);
for (const auto* innerExpr : ExprsToStore) {
/// Mark corresponding SourceLocation as required/not required to be
/// stored for all expressions that could be used changed.
markLocation(innerExpr);
Expand Down Expand Up @@ -726,8 +728,9 @@ void TBRAnalyzer::VisitUnaryOperator(const clang::UnaryOperator* UnOp) {
// FIXME: this doesn't support all the possible references
/// Mark corresponding SourceLocation as required/not required to be
/// stored for all expressions that could be used in this operation.
const auto innerExprs = utils::GetInnermostReturnExpr(E);
for (const auto* innerExpr : innerExprs) {
llvm::SmallVector<Expr*, 4> ExprsToStore;
utils::GetInnermostReturnExpr(E, ExprsToStore);
for (const auto* innerExpr : ExprsToStore) {
/// Mark corresponding SourceLocation as required/not required to be
/// stored for all expressions that could be changed.
markLocation(innerExpr);
Expand All @@ -754,7 +757,8 @@ void TBRAnalyzer::VisitCallExpr(const clang::CallExpr* CE) {
resetMode();
const auto* B = arg->IgnoreParenImpCasts();
// FIXME: this supports only DeclRefExpr
const auto innerExpr = utils::GetInnermostReturnExpr(arg);
llvm::SmallVector<Expr*, 4> ExprsToStore;
utils::GetInnermostReturnExpr(arg, ExprsToStore);
if (passByRef) {
/// Mark SourceLocation as required to store for ref-type arguments.
if (isa<DeclRefExpr>(B) || isa<MemberExpr>(B)) {
Expand Down

0 comments on commit 7703d74

Please sign in to comment.