Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Dec 12, 2023
1 parent ec5ea41 commit f672ce0
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 4 deletions.
23 changes: 19 additions & 4 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,8 @@ namespace clad {
/// the reverse mode we also accumulate Stmts for the reverse pass which
/// will be executed on return.
std::vector<Stmts> m_Reverse;
/// Accumulates local variables for all visited blocks.
std::vector<std::vector<clang::VarDecl*>> m_Locals;
/// Stack is used to pass the arguments (dfdx) to further nodes
/// in the Visit method.
std::stack<clang::Expr*> m_Stack;
Expand Down Expand Up @@ -142,15 +144,23 @@ namespace clad {
}
/// Create new block.
Stmts& beginBlock(direction d = direction::forward) {
if (d == direction::forward)
if (d == direction::forward) {
m_Blocks.emplace_back();
else
m_Locals.emplace_back();
} else {
m_Reverse.emplace_back();
}
return getCurrentBlock(d);
}
/// Remove the block from the stack, wrap it in CompoundStmt and return it.
clang::CompoundStmt* endBlock(direction d = direction::forward) {
if (d == direction::forward) {
for (clang::VarDecl* VD : m_Locals.back()) {
clang::Stmt* declStmt = BuildDeclStmt(VD);
addToCurrentBlock(declStmt, direction::reverse);
}
m_Locals.pop_back();

auto* CS = MakeCompoundStmt(getCurrentBlock(direction::forward));
m_Blocks.pop_back();
return CS;
Expand All @@ -164,9 +174,14 @@ namespace clad {

Stmts EndBlockWithoutCreatingCS(direction d = direction::forward) {
auto blk = getCurrentBlock(d);
if (d == direction::forward)
if (d == direction::forward) {
m_Blocks.pop_back();
else
for (clang::VarDecl* VD : m_Locals.back()) {
clang::Stmt* declStmt = BuildDeclStmt(VD);
addToCurrentBlock(declStmt, direction::reverse);
}
m_Locals.pop_back();
} else
m_Reverse.pop_back();
return blk;
}
Expand Down
2 changes: 2 additions & 0 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2613,6 +2613,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}
m_Variables.emplace(VDClone, derivedVDE);
if (m_Blocks.size()!=1)
m_Locals.back().push_back(VDClone);

return VarDeclDiff(VDClone, VDDerived);
}
Expand Down

0 comments on commit f672ce0

Please sign in to comment.