diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 17b392ada..b7819db56 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -3295,19 +3295,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } addToCurrentBlock(condDiff.getStmt_dx(), direction::reverse); bodyDiff = Visit(body); - if (auto* bodyCS = dyn_cast(bodyDiff.getStmt())) - for (Stmt* S : bodyCS->body()) - addToCurrentBlock(S, direction::forward); - else - addToCurrentBlock(bodyDiff.getStmt(), direction::forward); - if (bodyDiff.getStmt_dx()) { - if (auto* bodyDxCS = dyn_cast(bodyDiff.getStmt_dx())) - for (auto iter = bodyDxCS->body_rbegin(), e = bodyDxCS->body_rend(); - iter != e; ++iter) - addToCurrentBlock(*iter, direction::reverse); - else - addToCurrentBlock(bodyDiff.getStmt_dx(), direction::reverse); - } + for (Stmt* S : cast(bodyDiff.getStmt())->body()) + addToCurrentBlock(S, direction::forward); + auto* bodyDxCS = cast(bodyDiff.getStmt_dx()); + for (auto iter = bodyDxCS->body_rbegin(), e = bodyDxCS->body_rend(); + iter != e; ++iter) + addToCurrentBlock(*iter, direction::reverse); addToCurrentBlock(forLoopIncDiff, direction::reverse); bodyDiff.updateStmt(endBlock(direction::forward)); bodyDiff.updateStmtDx(unwrapIfSingleStmt(endBlock(direction::reverse)));