-
Notifications
You must be signed in to change notification settings - Fork 122
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add TBR analysis in the reverse mode of Clad #616
Changes from all commits
5471ac1
ff10460
5fa2e00
7d8b432
85e90e0
9376a5d
a2cad57
42bff68
f3e0a92
b54c0b8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -7,6 +7,7 @@ on: | |
pull_request: | ||
branches: | ||
- master | ||
- tape-push | ||
- coverity_scan | ||
|
||
concurrency: | ||
|
Original file line number | Diff line number | Diff line change | ||||||||||||||||||||||||||||||||||||||||
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
@@ -30,7 +30,7 @@ namespace clad { | |||||||||||||||||||||||||||||||||||||||||
class ReverseModeVisitor | ||||||||||||||||||||||||||||||||||||||||||
: public clang::ConstStmtVisitor<ReverseModeVisitor, StmtDiff>, | ||||||||||||||||||||||||||||||||||||||||||
public VisitorBase { | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
private: | ||||||||||||||||||||||||||||||||||||||||||
// FIXME: We should remove friend-dependency of the plugin classes here. | ||||||||||||||||||||||||||||||||||||||||||
// For this we will need to separate out AST related functions in | ||||||||||||||||||||||||||||||||||||||||||
|
@@ -42,6 +42,7 @@ namespace clad { | |||||||||||||||||||||||||||||||||||||||||
/// the reverse mode we also accumulate Stmts for the reverse pass which | ||||||||||||||||||||||||||||||||||||||||||
/// will be executed on return. | ||||||||||||||||||||||||||||||||||||||||||
std::vector<Stmts> m_Reverse; | ||||||||||||||||||||||||||||||||||||||||||
std::vector<Stmts> m_EssentialReverse; | ||||||||||||||||||||||||||||||||||||||||||
/// Stack is used to pass the arguments (dfdx) to further nodes | ||||||||||||||||||||||||||||||||||||||||||
/// in the Visit method. | ||||||||||||||||||||||||||||||||||||||||||
std::stack<clang::Expr*> m_Stack; | ||||||||||||||||||||||||||||||||||||||||||
|
@@ -55,10 +56,6 @@ namespace clad { | |||||||||||||||||||||||||||||||||||||||||
bool isInsideLoop = false; | ||||||||||||||||||||||||||||||||||||||||||
/// Output variable of vector-valued function | ||||||||||||||||||||||||||||||||||||||||||
std::string outputArrayStr; | ||||||||||||||||||||||||||||||||||||||||||
/// Stores the pop index values for arrays in reverse mode.This is required | ||||||||||||||||||||||||||||||||||||||||||
/// to maintain the correct statement order when the current block has | ||||||||||||||||||||||||||||||||||||||||||
/// delayed emission i.e. assignment LHS. | ||||||||||||||||||||||||||||||||||||||||||
Stmts m_PopIdxValues; | ||||||||||||||||||||||||||||||||||||||||||
std::vector<Stmts> m_LoopBlock; | ||||||||||||||||||||||||||||||||||||||||||
unsigned outputArrayCursor = 0; | ||||||||||||||||||||||||||||||||||||||||||
unsigned numParams = 0; | ||||||||||||||||||||||||||||||||||||||||||
|
@@ -137,15 +134,19 @@ namespace clad { | |||||||||||||||||||||||||||||||||||||||||
Stmts& getCurrentBlock(direction d = direction::forward) { | ||||||||||||||||||||||||||||||||||||||||||
if (d == direction::forward) | ||||||||||||||||||||||||||||||||||||||||||
return m_Blocks.back(); | ||||||||||||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||||||||||||
else if (d == direction::reverse) | ||||||||||||||||||||||||||||||||||||||||||
return m_Reverse.back(); | ||||||||||||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||||||||||||
return m_EssentialReverse.back(); | ||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+137
to
+140
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: do not use 'else' after 'return' [llvm-else-after-return]
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
/// Create new block. | ||||||||||||||||||||||||||||||||||||||||||
Stmts& beginBlock(direction d = direction::forward) { | ||||||||||||||||||||||||||||||||||||||||||
if (d == direction::forward) | ||||||||||||||||||||||||||||||||||||||||||
m_Blocks.push_back({}); | ||||||||||||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||||||||||||
else if (d == direction::reverse) | ||||||||||||||||||||||||||||||||||||||||||
m_Reverse.push_back({}); | ||||||||||||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||||||||||||
m_EssentialReverse.push_back({}); | ||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: use emplace_back instead of push_back [modernize-use-emplace]
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
return getCurrentBlock(d); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
/// Remove the block from the stack, wrap it in CompoundStmt and return it. | ||||||||||||||||||||||||||||||||||||||||||
|
@@ -154,13 +155,28 @@ namespace clad { | |||||||||||||||||||||||||||||||||||||||||
auto CS = MakeCompoundStmt(getCurrentBlock(direction::forward)); | ||||||||||||||||||||||||||||||||||||||||||
m_Blocks.pop_back(); | ||||||||||||||||||||||||||||||||||||||||||
return CS; | ||||||||||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||||||||||
} else if (d == direction::reverse) { | ||||||||||||||||||||||||||||||||||||||||||
auto CS = MakeCompoundStmt(getCurrentBlock(direction::reverse)); | ||||||||||||||||||||||||||||||||||||||||||
std::reverse(CS->body_begin(), CS->body_end()); | ||||||||||||||||||||||||||||||||||||||||||
m_Reverse.pop_back(); | ||||||||||||||||||||||||||||||||||||||||||
return CS; | ||||||||||||||||||||||||||||||||||||||||||
} else { | ||||||||||||||||||||||||||||||||||||||||||
auto CS = MakeCompoundStmt(getCurrentBlock(d)); | ||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: 'auto CS' can be declared as 'auto *CS' [llvm-qualified-auto] auto CS = MakeCompoundStmt(getCurrentBlock(d));
^ this fix will not be applied because it overlaps with another fix |
||||||||||||||||||||||||||||||||||||||||||
m_EssentialReverse.pop_back(); | ||||||||||||||||||||||||||||||||||||||||||
return CS; | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
Comment on lines
+158
to
167
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: do not use 'else' after 'return' [llvm-else-after-return]
Suggested change
|
||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
Stmts EndBlockWithoutCreatingCS(direction d = direction::forward) { | ||||||||||||||||||||||||||||||||||||||||||
auto blk = getCurrentBlock(d); | ||||||||||||||||||||||||||||||||||||||||||
if (d == direction::forward) | ||||||||||||||||||||||||||||||||||||||||||
m_Blocks.pop_back(); | ||||||||||||||||||||||||||||||||||||||||||
else if (d == direction::reverse) | ||||||||||||||||||||||||||||||||||||||||||
m_Reverse.pop_back(); | ||||||||||||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||||||||||||
m_EssentialReverse.pop_back(); | ||||||||||||||||||||||||||||||||||||||||||
return blk; | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
/// Output a statement to the current block. If Stmt is null or is an unused | ||||||||||||||||||||||||||||||||||||||||||
/// expression, it is not output and false is returned. | ||||||||||||||||||||||||||||||||||||||||||
bool addToCurrentBlock(clang::Stmt* S, direction d = direction::forward) { | ||||||||||||||||||||||||||||||||||||||||||
|
@@ -200,7 +216,15 @@ namespace clad { | |||||||||||||||||||||||||||||||||||||||||
// Name reverse temporaries as "_r" instead of "_t". | ||||||||||||||||||||||||||||||||||||||||||
if ((d == direction::reverse) && (prefix == "_t")) | ||||||||||||||||||||||||||||||||||||||||||
prefix = "_r"; | ||||||||||||||||||||||||||||||||||||||||||
return VisitorBase::StoreAndRef(E, Type, getCurrentBlock(d), prefix, | ||||||||||||||||||||||||||||||||||||||||||
Stmts* blk = nullptr; | ||||||||||||||||||||||||||||||||||||||||||
if (d == direction::essential_reverse) { | ||||||||||||||||||||||||||||||||||||||||||
if (!m_EssentialReverse.empty()) | ||||||||||||||||||||||||||||||||||||||||||
blk = &getCurrentBlock(direction::essential_reverse); | ||||||||||||||||||||||||||||||||||||||||||
else | ||||||||||||||||||||||||||||||||||||||||||
blk = &getCurrentBlock(direction::reverse); | ||||||||||||||||||||||||||||||||||||||||||
} else | ||||||||||||||||||||||||||||||||||||||||||
blk = &getCurrentBlock(d); | ||||||||||||||||||||||||||||||||||||||||||
return VisitorBase::StoreAndRef(E, Type, *blk, prefix, | ||||||||||||||||||||||||||||||||||||||||||
forceDeclCreation, IS); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
|
@@ -251,6 +275,12 @@ namespace clad { | |||||||||||||||||||||||||||||||||||||||||
StmtDiff Result; | ||||||||||||||||||||||||||||||||||||||||||
bool isConstant; | ||||||||||||||||||||||||||||||||||||||||||
bool isInsideLoop; | ||||||||||||||||||||||||||||||||||||||||||
bool needsUpdate; | ||||||||||||||||||||||||||||||||||||||||||
DelayedStoreResult(ReverseModeVisitor& pV, StmtDiff pResult, | ||||||||||||||||||||||||||||||||||||||||||
bool pIsConstant, bool pIsInsideLoop, | ||||||||||||||||||||||||||||||||||||||||||
bool pNeedsUpdate = false) | ||||||||||||||||||||||||||||||||||||||||||
: V(pV), Result(pResult), isConstant(pIsConstant), | ||||||||||||||||||||||||||||||||||||||||||
isInsideLoop(pIsInsideLoop), needsUpdate(pNeedsUpdate) {} | ||||||||||||||||||||||||||||||||||||||||||
void Finalize(clang::Expr* New); | ||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
|
@@ -394,7 +424,7 @@ namespace clad { | |||||||||||||||||||||||||||||||||||||||||
clang::QualType xType); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
/// Allows to easily create and manage a counter for counting the number of | ||||||||||||||||||||||||||||||||||||||||||
/// executed iterations of a loop. | ||||||||||||||||||||||||||||||||||||||||||
/// executed iterations of a loop. | ||||||||||||||||||||||||||||||||||||||||||
/// | ||||||||||||||||||||||||||||||||||||||||||
/// It is required to save the number of executed iterations to use the | ||||||||||||||||||||||||||||||||||||||||||
/// same number of iterations in the reverse pass. | ||||||||||||||||||||||||||||||||||||||||||
|
@@ -413,11 +443,11 @@ namespace clad { | |||||||||||||||||||||||||||||||||||||||||
/// for counter; otherwise, returns nullptr. | ||||||||||||||||||||||||||||||||||||||||||
clang::Expr* getPush() const { return m_Push; } | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
/// Returns `clad::pop(_t)` expression if clad tape is used for | ||||||||||||||||||||||||||||||||||||||||||
/// Returns `clad::pop(_t)` expression if clad tape is used for | ||||||||||||||||||||||||||||||||||||||||||
/// for counter; otherwise, returns nullptr. | ||||||||||||||||||||||||||||||||||||||||||
clang::Expr* getPop() const { return m_Pop; } | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
/// Returns reference to the last object of the clad tape if clad tape | ||||||||||||||||||||||||||||||||||||||||||
/// Returns reference to the last object of the clad tape if clad tape | ||||||||||||||||||||||||||||||||||||||||||
/// is used as the counter; otherwise returns reference to the counter | ||||||||||||||||||||||||||||||||||||||||||
/// variable. | ||||||||||||||||||||||||||||||||||||||||||
clang::Expr* getRef() const { return m_Ref; } | ||||||||||||||||||||||||||||||||||||||||||
|
@@ -459,11 +489,11 @@ namespace clad { | |||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
/// This class modifies forward and reverse blocks of the loop | ||||||||||||||||||||||||||||||||||||||||||
/// body so that `break` and `continue` statements are correctly | ||||||||||||||||||||||||||||||||||||||||||
/// handled. `break` and `continue` statements are handled by | ||||||||||||||||||||||||||||||||||||||||||
/// handled. `break` and `continue` statements are handled by | ||||||||||||||||||||||||||||||||||||||||||
/// enclosing entire reverse block loop body in a switch statement | ||||||||||||||||||||||||||||||||||||||||||
/// and only executing the statements, with the help of case labels, | ||||||||||||||||||||||||||||||||||||||||||
/// that were executed in the associated forward iteration. This is | ||||||||||||||||||||||||||||||||||||||||||
/// determined by keeping track of which `break`/`continue` statement | ||||||||||||||||||||||||||||||||||||||||||
/// that were executed in the associated forward iteration. This is | ||||||||||||||||||||||||||||||||||||||||||
/// determined by keeping track of which `break`/`continue` statement | ||||||||||||||||||||||||||||||||||||||||||
/// was hit in which iteration and that in turn helps to determine which | ||||||||||||||||||||||||||||||||||||||||||
/// case label should be selected. | ||||||||||||||||||||||||||||||||||||||||||
/// | ||||||||||||||||||||||||||||||||||||||||||
|
@@ -491,7 +521,7 @@ namespace clad { | |||||||||||||||||||||||||||||||||||||||||
/// \note `m_ControlFlowTape` is only initialized if the body contains | ||||||||||||||||||||||||||||||||||||||||||
/// `continue` or `break` statement. | ||||||||||||||||||||||||||||||||||||||||||
std::unique_ptr<CladTapeResult> m_ControlFlowTape; | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
/// Each `break` and `continue` statement is assigned a unique number, | ||||||||||||||||||||||||||||||||||||||||||
/// starting from 1, that is used as the case label corresponding to that `break`/`continue` | ||||||||||||||||||||||||||||||||||||||||||
/// statement. `m_CaseCounter` stores the value that was used for last | ||||||||||||||||||||||||||||||||||||||||||
|
@@ -530,7 +560,7 @@ namespace clad { | |||||||||||||||||||||||||||||||||||||||||
/// control flow switch statement. | ||||||||||||||||||||||||||||||||||||||||||
clang::CaseStmt* GetNextCFCaseStmt(); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
/// Builds and returns `clad::push(TapeRef, m_CurrentCounter)` | ||||||||||||||||||||||||||||||||||||||||||
/// Builds and returns `clad::push(TapeRef, m_CurrentCounter)` | ||||||||||||||||||||||||||||||||||||||||||
/// expression, where `TapeRef` and `m_CurrentCounter` are replaced | ||||||||||||||||||||||||||||||||||||||||||
/// by their actual values respectively. | ||||||||||||||||||||||||||||||||||||||||||
clang::Stmt* CreateCFTapePushExprToCurrentCase(); | ||||||||||||||||||||||||||||||||||||||||||
|
@@ -553,7 +583,9 @@ namespace clad { | |||||||||||||||||||||||||||||||||||||||||
void PopBreakContStmtHandler() { | ||||||||||||||||||||||||||||||||||||||||||
m_BreakContStmtHandlers.pop_back(); | ||||||||||||||||||||||||||||||||||||||||||
} | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
std::map<clang::SourceLocation, bool> m_ToBeRecorded; | ||||||||||||||||||||||||||||||||||||||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. warning: member variable 'm_ToBeRecorded' has public visibility [cppcoreguidelines-non-private-member-variables-in-classes] std::map<clang::SourceLocation, bool> m_ToBeRecorded;
^ |
||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
/// Registers an external RMV source. | ||||||||||||||||||||||||||||||||||||||||||
/// | ||||||||||||||||||||||||||||||||||||||||||
/// Multiple external RMV source can be registered by calling this function | ||||||||||||||||||||||||||||||||||||||||||
|
@@ -581,6 +613,8 @@ namespace clad { | |||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
clang::QualType ComputeAdjointType(clang::QualType T); | ||||||||||||||||||||||||||||||||||||||||||
clang::QualType ComputeParamType(clang::QualType T); | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
std::vector<clang::Expr*> GetInnermostReturnExpr(clang::Expr* E); | ||||||||||||||||||||||||||||||||||||||||||
}; | ||||||||||||||||||||||||||||||||||||||||||
} // end namespace clad | ||||||||||||||||||||||||||||||||||||||||||
|
||||||||||||||||||||||||||||||||||||||||||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
warning: unused function 'Expr_EvaluateAsConstantExpr' [clang-diagnostic-unused-function]