Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add TBR analysis in the reverse mode of Clad #616

Closed
wants to merge 10 commits into from
Closed
1 change: 1 addition & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ on:
pull_request:
branches:
- master
- tape-push
- coverity_scan

concurrency:
Expand Down
13 changes: 8 additions & 5 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ namespace clad {
std::string ComputeEffectiveFnName(const clang::FunctionDecl* FD);

/// Creates and returns a compound statement having statements as follows:
/// {`S`, all the statement of `initial` in sequence}
/// {`S`, all the statement of `initial` in sequence}
clang::CompoundStmt* PrependAndCreateCompoundStmt(clang::ASTContext& C,
clang::Stmt* initial,
clang::Stmt* S);
Expand All @@ -38,7 +38,7 @@ namespace clad {
clang::CompoundStmt* AppendAndCreateCompoundStmt(clang::ASTContext& C,
clang::Stmt* initial,
clang::Stmt* S);

/// Shorthand to issues a warning or error.
template <std::size_t N>
void EmitDiag(clang::Sema& semaRef,
Expand Down Expand Up @@ -126,8 +126,8 @@ namespace clad {
///
/// \param S
/// \param namespc
/// \param shouldExist If true, then asserts that the specified namespace
/// is found.
/// \param shouldExist If true, then asserts that the specified namespace
/// is found.
/// \param DC
clang::NamespaceDecl* LookupNSD(clang::Sema& S, llvm::StringRef namespc,
bool shouldExist,
Expand Down Expand Up @@ -234,7 +234,7 @@ namespace clad {

bool IsCladValueAndPushforwardType(clang::QualType T);

/// Returns a valid `SourceRange` to be used in places where clang
/// Returns a valid `SourceRange` to be used in places where clang
/// requires a valid `SourceRange`.
clang::SourceRange GetValidSRange(clang::Sema& semaRef);

Expand Down Expand Up @@ -313,6 +313,9 @@ namespace clad {
bool hasNonDifferentiableAttribute(const clang::Decl* D);

bool hasNonDifferentiableAttribute(const clang::Expr* E);
/// Finds all the possible expressions E could return a reference to.
/// For example, for 'x = y' it will return the DeclRefExpr* of x.
std::vector<clang::Expr*> GetInnermostReturnExpr(clang::Expr* E);
} // namespace utils
}

Expand Down
14 changes: 14 additions & 0 deletions include/clad/Differentiator/Compatibility.h
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,20 @@ static inline bool Expr_EvaluateAsInt(const Expr *E,
#endif
}

// Clang 12: bool Expr::EvaluateAsConstantExpr(EvalResult &Result,
// ConstExprUsage Usage, ASTContext &)
// => bool Expr::EvaluateAsConstantExpr(EvalResult &Result, ASTContext &)

static inline bool Expr_EvaluateAsConstantExpr(const Expr* E,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: unused function 'Expr_EvaluateAsConstantExpr' [clang-diagnostic-unused-function]

static inline bool Expr_EvaluateAsConstantExpr(const Expr* E,
                   ^

Expr::EvalResult& res,
const ASTContext& Ctx) {
#if CLANG_VERSION_MAJOR < 12
return E->EvaluateAsConstantExpr(res, Expr::EvaluateForCodeGen, Ctx);
#else
return E->EvaluateAsConstantExpr(res, Ctx);
#endif
}

// Compatibility helper function for creation IfStmt.
// Clang 8 and above use Create.
// Clang 12 and above use two extra params.
Expand Down
70 changes: 52 additions & 18 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ namespace clad {
class ReverseModeVisitor
: public clang::ConstStmtVisitor<ReverseModeVisitor, StmtDiff>,
public VisitorBase {

private:
// FIXME: We should remove friend-dependency of the plugin classes here.
// For this we will need to separate out AST related functions in
Expand All @@ -42,6 +42,7 @@ namespace clad {
/// the reverse mode we also accumulate Stmts for the reverse pass which
/// will be executed on return.
std::vector<Stmts> m_Reverse;
std::vector<Stmts> m_EssentialReverse;
/// Stack is used to pass the arguments (dfdx) to further nodes
/// in the Visit method.
std::stack<clang::Expr*> m_Stack;
Expand All @@ -55,10 +56,6 @@ namespace clad {
bool isInsideLoop = false;
/// Output variable of vector-valued function
std::string outputArrayStr;
/// Stores the pop index values for arrays in reverse mode.This is required
/// to maintain the correct statement order when the current block has
/// delayed emission i.e. assignment LHS.
Stmts m_PopIdxValues;
std::vector<Stmts> m_LoopBlock;
unsigned outputArrayCursor = 0;
unsigned numParams = 0;
Expand Down Expand Up @@ -137,15 +134,19 @@ namespace clad {
Stmts& getCurrentBlock(direction d = direction::forward) {
if (d == direction::forward)
return m_Blocks.back();
else
else if (d == direction::reverse)
return m_Reverse.back();
else
return m_EssentialReverse.back();
Comment on lines +137 to +140
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: do not use 'else' after 'return' [llvm-else-after-return]

Suggested change
else if (d == direction::reverse)
return m_Reverse.back();
else
return m_EssentialReverse.back();
if (d == direction::reverse)
return m_Reverse.back();
else
return m_EssentialReverse.back();

}
/// Create new block.
Stmts& beginBlock(direction d = direction::forward) {
if (d == direction::forward)
m_Blocks.push_back({});
else
else if (d == direction::reverse)
m_Reverse.push_back({});
else
m_EssentialReverse.push_back({});
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: use emplace_back instead of push_back [modernize-use-emplace]

Suggested change
m_EssentialReverse.push_back({});
m_EssentialReverse.emplace_back();

return getCurrentBlock(d);
}
/// Remove the block from the stack, wrap it in CompoundStmt and return it.
Expand All @@ -154,13 +155,28 @@ namespace clad {
auto CS = MakeCompoundStmt(getCurrentBlock(direction::forward));
m_Blocks.pop_back();
return CS;
} else {
} else if (d == direction::reverse) {
auto CS = MakeCompoundStmt(getCurrentBlock(direction::reverse));
std::reverse(CS->body_begin(), CS->body_end());
m_Reverse.pop_back();
return CS;
} else {
auto CS = MakeCompoundStmt(getCurrentBlock(d));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: 'auto CS' can be declared as 'auto *CS' [llvm-qualified-auto]

        auto CS = MakeCompoundStmt(getCurrentBlock(d));
        ^

this fix will not be applied because it overlaps with another fix

m_EssentialReverse.pop_back();
return CS;
}
Comment on lines +158 to 167
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: do not use 'else' after 'return' [llvm-else-after-return]

Suggested change
} else if (d == direction::reverse) {
auto CS = MakeCompoundStmt(getCurrentBlock(direction::reverse));
std::reverse(CS->body_begin(), CS->body_end());
m_Reverse.pop_back();
return CS;
} else {
auto CS = MakeCompoundStmt(getCurrentBlock(d));
m_EssentialReverse.pop_back();
return CS;
}
} if (d == direction::reverse) {
auto CS = MakeCompoundStmt(getCurrentBlock(direction::reverse));
std::reverse(CS->body_begin(), CS->body_end());
m_Reverse.pop_back();
return CS;
} else {
auto CS = MakeCompoundStmt(getCurrentBlock(d));
m_EssentialReverse.pop_back();
return CS;
}

}

Stmts EndBlockWithoutCreatingCS(direction d = direction::forward) {
auto blk = getCurrentBlock(d);
if (d == direction::forward)
m_Blocks.pop_back();
else if (d == direction::reverse)
m_Reverse.pop_back();
else
m_EssentialReverse.pop_back();
return blk;
}
/// Output a statement to the current block. If Stmt is null or is an unused
/// expression, it is not output and false is returned.
bool addToCurrentBlock(clang::Stmt* S, direction d = direction::forward) {
Expand Down Expand Up @@ -200,7 +216,15 @@ namespace clad {
// Name reverse temporaries as "_r" instead of "_t".
if ((d == direction::reverse) && (prefix == "_t"))
prefix = "_r";
return VisitorBase::StoreAndRef(E, Type, getCurrentBlock(d), prefix,
Stmts* blk = nullptr;
if (d == direction::essential_reverse) {
if (!m_EssentialReverse.empty())
blk = &getCurrentBlock(direction::essential_reverse);
else
blk = &getCurrentBlock(direction::reverse);
} else
blk = &getCurrentBlock(d);
return VisitorBase::StoreAndRef(E, Type, *blk, prefix,
forceDeclCreation, IS);
}

Expand Down Expand Up @@ -251,6 +275,12 @@ namespace clad {
StmtDiff Result;
bool isConstant;
bool isInsideLoop;
bool needsUpdate;
DelayedStoreResult(ReverseModeVisitor& pV, StmtDiff pResult,
bool pIsConstant, bool pIsInsideLoop,
bool pNeedsUpdate = false)
: V(pV), Result(pResult), isConstant(pIsConstant),
isInsideLoop(pIsInsideLoop), needsUpdate(pNeedsUpdate) {}
void Finalize(clang::Expr* New);
};

Expand Down Expand Up @@ -394,7 +424,7 @@ namespace clad {
clang::QualType xType);

/// Allows to easily create and manage a counter for counting the number of
/// executed iterations of a loop.
/// executed iterations of a loop.
///
/// It is required to save the number of executed iterations to use the
/// same number of iterations in the reverse pass.
Expand All @@ -413,11 +443,11 @@ namespace clad {
/// for counter; otherwise, returns nullptr.
clang::Expr* getPush() const { return m_Push; }

/// Returns `clad::pop(_t)` expression if clad tape is used for
/// Returns `clad::pop(_t)` expression if clad tape is used for
/// for counter; otherwise, returns nullptr.
clang::Expr* getPop() const { return m_Pop; }

/// Returns reference to the last object of the clad tape if clad tape
/// Returns reference to the last object of the clad tape if clad tape
/// is used as the counter; otherwise returns reference to the counter
/// variable.
clang::Expr* getRef() const { return m_Ref; }
Expand Down Expand Up @@ -459,11 +489,11 @@ namespace clad {

/// This class modifies forward and reverse blocks of the loop
/// body so that `break` and `continue` statements are correctly
/// handled. `break` and `continue` statements are handled by
/// handled. `break` and `continue` statements are handled by
/// enclosing entire reverse block loop body in a switch statement
/// and only executing the statements, with the help of case labels,
/// that were executed in the associated forward iteration. This is
/// determined by keeping track of which `break`/`continue` statement
/// that were executed in the associated forward iteration. This is
/// determined by keeping track of which `break`/`continue` statement
/// was hit in which iteration and that in turn helps to determine which
/// case label should be selected.
///
Expand Down Expand Up @@ -491,7 +521,7 @@ namespace clad {
/// \note `m_ControlFlowTape` is only initialized if the body contains
/// `continue` or `break` statement.
std::unique_ptr<CladTapeResult> m_ControlFlowTape;

/// Each `break` and `continue` statement is assigned a unique number,
/// starting from 1, that is used as the case label corresponding to that `break`/`continue`
/// statement. `m_CaseCounter` stores the value that was used for last
Expand Down Expand Up @@ -530,7 +560,7 @@ namespace clad {
/// control flow switch statement.
clang::CaseStmt* GetNextCFCaseStmt();

/// Builds and returns `clad::push(TapeRef, m_CurrentCounter)`
/// Builds and returns `clad::push(TapeRef, m_CurrentCounter)`
/// expression, where `TapeRef` and `m_CurrentCounter` are replaced
/// by their actual values respectively.
clang::Stmt* CreateCFTapePushExprToCurrentCase();
Expand All @@ -553,7 +583,9 @@ namespace clad {
void PopBreakContStmtHandler() {
m_BreakContStmtHandlers.pop_back();
}


std::map<clang::SourceLocation, bool> m_ToBeRecorded;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

warning: member variable 'm_ToBeRecorded' has public visibility [cppcoreguidelines-non-private-member-variables-in-classes]

    std::map<clang::SourceLocation, bool> m_ToBeRecorded;
                                          ^


/// Registers an external RMV source.
///
/// Multiple external RMV source can be registered by calling this function
Expand Down Expand Up @@ -581,6 +613,8 @@ namespace clad {

clang::QualType ComputeAdjointType(clang::QualType T);
clang::QualType ComputeParamType(clang::QualType T);

std::vector<clang::Expr*> GetInnermostReturnExpr(clang::Expr* E);
};
} // end namespace clad

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
namespace clad {
namespace rmv {
/// An enum to operate between forward and reverse passes.
enum direction : int { forward, reverse };
enum direction : int { forward, reverse, essential_reverse };
} // namespace rmv
} // namespace clad

Expand Down
Loading
Loading