Skip to content

Commit

Permalink
Remove Redundant Goto Statements
Browse files Browse the repository at this point in the history
Solves issue #526. We can only remove the 'goto' label when the parent of the parent of the return statement is null.
  • Loading branch information
ShounakDas101 committed Jul 26, 2023
1 parent 2d24f37 commit 40fdd2b
Show file tree
Hide file tree
Showing 35 changed files with 38 additions and 416 deletions.
5 changes: 4 additions & 1 deletion include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
#include "clang/AST/RecursiveASTVisitor.h"
#include "clang/AST/StmtVisitor.h"
#include "clang/Sema/Sema.h"

#include "clang/AST/ParentMapContext.h"
#include <array>
#include <memory>
#include <stack>
Expand Down Expand Up @@ -91,6 +91,9 @@ namespace clad {

// Function to Differentiate with Enzyme as Backend
void DifferentiateWithEnzyme();

// Whether Stmt is Return and not inside any block;
bool OnlyReturn = false;

public:
using direction = rmv::direction;
Expand Down
37 changes: 30 additions & 7 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
#include "clad/Differentiator/StmtClone.h"
#include "clad/Differentiator/ExternalRMVSource.h"
#include "clad/Differentiator/MultiplexExternalRMVSource.h"

#include "clang/AST/ParentMapContext.h"
#include "clang/AST/ASTContext.h"
#include "clang/AST/Expr.h"
#include "clang/AST/TemplateBase.h"
Expand Down Expand Up @@ -744,6 +744,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
beginBlock(direction::forward);
beginBlock(direction::reverse);
for (Stmt* S : CS->body()) {
std::string cppString="ReturnStmt";
const char* cString = S->getStmtClassName();
if(cppString.compare(cString) == 0 ){
auto parents = m_Context.getParents(*CS);
if (!parents.empty()){
const Stmt* parentStmt = parents[0].get<Stmt>();
if(parentStmt==nullptr)
OnlyReturn=true;
else
OnlyReturn=false;
}
}
if (m_ExternalSource)
m_ExternalSource->ActBeforeDifferentiatingStmtInVisitCompoundStmt();
StmtDiff SDiff = DifferentiateSingleStmt(S);
Expand Down Expand Up @@ -1153,14 +1165,22 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// If the original function returns at this point, some part of the reverse
// pass (corresponding to other branches that do not return here) must be
// skipped. We create a label in the reverse pass and jump to it via goto.
LabelDecl* LD = LabelDecl::Create(
m_Context, m_Sema.CurContext, noLoc, CreateUniqueIdentifier("_label"));
m_Sema.PushOnScopeChains(LD, m_DerivativeFnScope, true);
LabelDecl* LD = nullptr;
if (!OnlyReturn) {
LD = LabelDecl::Create(
m_Context, m_Sema.CurContext, noLoc, CreateUniqueIdentifier("_label"));
m_Sema.PushOnScopeChains(LD, m_DerivativeFnScope, true);
}
// Attach label to the last Stmt in the corresponding Reverse Stmt.
if (!Reverse)
Reverse = m_Sema.ActOnNullStmt(noLoc).get();
Stmt* LS = m_Sema.ActOnLabelStmt(noLoc, LD, noLoc, Reverse).get();
addToCurrentBlock(LS, direction::reverse);
if (!OnlyReturn) {
Stmt* LS = m_Sema.ActOnLabelStmt(noLoc, LD, noLoc, Reverse).get();
addToCurrentBlock(LS, direction::reverse);
}
else {
addToCurrentBlock(Reverse, direction::reverse);
}
for (Stmt* S : cast<CompoundStmt>(ReturnDiff.getStmt())->body())
addToCurrentBlock(S, direction::forward);

Expand All @@ -1175,7 +1195,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}

// Create goto to the label.
return m_Sema.ActOnGotoStmt(noLoc, noLoc, LD).get();
if (!OnlyReturn)
return m_Sema.ActOnGotoStmt(noLoc, noLoc, LD).get();

return nullptr;
}

StmtDiff ReverseModeVisitor::VisitParenExpr(const ParenExpr* PE) {
Expand Down
18 changes: 0 additions & 18 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,6 @@ double addArr(double *arr, int n) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: ret += arr[clad::push(_t1, i)];
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_ret += _d_y;
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: {
Expand All @@ -45,8 +43,6 @@ double f(double *arr) {
//CHECK: void f_grad(double *arr, clad::array_ref<double> _d_arr) {
//CHECK-NEXT: double *_t0;
//CHECK-NEXT: _t0 = arr;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: int _grad1 = 0;
//CHECK-NEXT: addArr_pullback(_t0, 3, 1, _d_arr, &_grad1);
Expand Down Expand Up @@ -82,8 +78,6 @@ float func(float* a, float* b) {
//CHECK-NEXT: _ref0 *= clad::push(_t3, b[clad::push(_t5, i)]);
//CHECK-NEXT: sum += a[clad::push(_t7, i)];
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: {
Expand Down Expand Up @@ -113,8 +107,6 @@ float helper(float x) {
// CHECK: void helper_pullback(float x, float _d_y, clad::array_ref<float> _d_x) {
// CHECK-NEXT: float _t0;
// CHECK-NEXT: _t0 = x;
// CHECK-NEXT: goto _label0;
// CHECK-NEXT: _label0:
// CHECK-NEXT: {
// CHECK-NEXT: float _r0 = _d_y * _t0;
// CHECK-NEXT: float _r1 = 2 * _d_y;
Expand All @@ -141,8 +133,6 @@ float func2(float* a) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: sum += helper(clad::push(_t3, a[clad::push(_t1, i)]));
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: float _r_d0 = _d_sum;
Expand Down Expand Up @@ -175,8 +165,6 @@ float func3(float* a, float* b) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: sum += (a[clad::push(_t1, i)] += b[clad::push(_t3, i)]);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: float _r_d0 = _d_sum;
Expand Down Expand Up @@ -221,8 +209,6 @@ double func4(double x) {
//CHECK-NEXT: clad::push(_t4, arr , 3UL);
//CHECK-NEXT: sum += addArr(arr, 3);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t3; _t3--) {
//CHECK-NEXT: {
Expand Down Expand Up @@ -286,8 +272,6 @@ double func5(int k) {
//CHECK-NEXT: clad::push(_t4, arr , n);
//CHECK-NEXT: sum += addArr(arr, clad::push(_t5, n));
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t3; _t3--) {
//CHECK-NEXT: {
Expand Down Expand Up @@ -339,8 +323,6 @@ double func6(double seed) {
//CHECK-NEXT: clad::push(_t3, arr , 3UL);
//CHECK-NEXT: sum += addArr(arr, 3);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: {
Expand Down
4 changes: 0 additions & 4 deletions test/Arrays/Arrays.C
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,6 @@ double const_dot_product(double x, double y, double z) {
//CHECK-NEXT: _t2 = consts[1];
//CHECK-NEXT: _t5 = vars[2];
//CHECK-NEXT: _t4 = consts[2];
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = 1 * _t0;
//CHECK-NEXT: _d_vars[0] += _r0;
Expand Down Expand Up @@ -193,8 +191,6 @@ double const_matmul_sum(double a, double b, double c, double d) {
//: _t15 = A[1][1];
//: _t14 = B[1][1];
//: double C[2][2] = {{[{][{]}}_t1 * _t0 + _t3 * _t2, _t5 * _t4 + _t7 * _t6}, {_t9 * _t8 + _t11 * _t10, _t13 * _t12 + _t15 * _t14}};
//: goto _label0;
//: _label0:
//: {
//: _d_C[0][0] += 1;
//: _d_C[0][1] += 1;
Expand Down
14 changes: 0 additions & 14 deletions test/ErrorEstimation/Assignments.C
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ float func(float x, float y) {
//CHECK-NEXT: x = x + y;
//CHECK-NEXT: _EERepl_x1 = x;
//CHECK-NEXT: y = x;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: * _d_y += 1;
//CHECK-NEXT: {
//CHECK-NEXT: float _r_d1 = * _d_y;
Expand Down Expand Up @@ -61,8 +59,6 @@ float func2(float x, int y) {
//CHECK-NEXT: _t2 = x;
//CHECK-NEXT: x = _t1 * _t0 + _t3 * _t2;
//CHECK-NEXT: _EERepl_x1 = x;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: * _d_x += 1;
//CHECK-NEXT: {
//CHECK-NEXT: float _r_d0 = * _d_x;
Expand All @@ -89,8 +85,6 @@ float func3(int x, int y) {

//CHECK: void func3_grad(int x, int y, clad::array_ref<int> _d_x, clad::array_ref<int> _d_y, double &_final_error) {
//CHECK-NEXT: x = y;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: * _d_y += 1;
//CHECK-NEXT: {
//CHECK-NEXT: int _r_d0 = * _d_x;
Expand All @@ -117,8 +111,6 @@ float func4(float x, float y) {
//CHECK-NEXT: _EERepl_z0 = z;
//CHECK-NEXT: x = z + y;
//CHECK-NEXT: _EERepl_x1 = x;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: * _d_x += 1;
//CHECK-NEXT: {
//CHECK-NEXT: float _r_d0 = * _d_x;
Expand Down Expand Up @@ -149,8 +141,6 @@ float func5(float x, float y) {
//CHECK-NEXT: int z = 56;
//CHECK-NEXT: x = z + y;
//CHECK-NEXT: _EERepl_x1 = x;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: * _d_x += 1;
//CHECK-NEXT: {
//CHECK-NEXT: float _r_d0 = * _d_x;
Expand All @@ -169,8 +159,6 @@ float func5(float x, float y) {
float func6(float x) { return x; }

//CHECK: void func6_grad(float x, clad::array_ref<float> _d_x, double &_final_error) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: * _d_x += 1;
//CHECK-NEXT: double _delta_x = 0;
//CHECK-NEXT: _delta_x += std::abs(* _d_x * x * {{.+}});
Expand All @@ -186,8 +174,6 @@ float func7(float x, float y) { return (x * y); }
//CHECK-NEXT: _t1 = x;
//CHECK-NEXT: _t0 = y;
//CHECK-NEXT: _ret_value0 = (_t1 * _t0);
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: float _r0 = 1 * _t0;
//CHECK-NEXT: * _d_x += _r0;
Expand Down
22 changes: 0 additions & 22 deletions test/ErrorEstimation/BasicOps.C
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ float func(float x, float y) {
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: float z = _t1 * _t0;
//CHECK-NEXT: _EERepl_z0 = z;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_z += 1;
//CHECK-NEXT: {
//CHECK-NEXT: float _r0 = _d_z * _t0;
Expand Down Expand Up @@ -95,8 +93,6 @@ float func2(float x, float y) {
//CHECK-NEXT: _t2 = x;
//CHECK-NEXT: float z = _t3 / _t2;
//CHECK-NEXT: _EERepl_z0 = z;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_z += 1;
//CHECK-NEXT: {
//CHECK-NEXT: float _r2 = _d_z / _t2;
Expand Down Expand Up @@ -164,8 +160,6 @@ float func3(float x, float y) {
//CHECK-NEXT: float t = _t5 * _t2;
//CHECK-NEXT: _EERepl_t0 = t;
//CHECK-NEXT: _EERepl_y1 = y;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_t += 1;
//CHECK-NEXT: {
//CHECK-NEXT: float _r2 = _d_t * _t2;
Expand Down Expand Up @@ -210,8 +204,6 @@ float func4(float x, float y) { return std::pow(x, y); }
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: _t1 = y;
//CHECK-NEXT: _ret_value0 = std::pow(_t0, _t1);
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: float _grad0 = 0.F;
//CHECK-NEXT: float _grad1 = 0.F;
Expand Down Expand Up @@ -248,8 +240,6 @@ float func5(float x, float y) {
//CHECK-NEXT: _t2 = y;
//CHECK-NEXT: _t1 = y;
//CHECK-NEXT: _ret_value0 = _t2 * _t1;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: float _r1 = 1 * _t1;
//CHECK-NEXT: * _d_y += _r1;
Expand Down Expand Up @@ -280,8 +270,6 @@ double helper(double x, double y) { return x * y; }
//CHECK-NEXT: _t1 = x;
//CHECK-NEXT: _t0 = y;
//CHECK-NEXT: _ret_value0 = _t1 * _t0;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = _d_y0 * _t0;
//CHECK-NEXT: * _d_x += _r0;
Expand Down Expand Up @@ -316,8 +304,6 @@ float func6(float x, float y) {
//CHECK-NEXT: _t4 = z;
//CHECK-NEXT: _t3 = z;
//CHECK-NEXT: _ret_value0 = _t4 * _t3;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: float _r2 = 1 * _t3;
//CHECK-NEXT: _d_z += _r2;
Expand Down Expand Up @@ -350,8 +336,6 @@ float func7(float x) {
//CHECK: void func7_grad(float x, clad::array_ref<float> _d_x, double &_final_error) {
//CHECK-NEXT: int _d_z = 0;
//CHECK-NEXT: int z = x;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: _d_z += 1;
//CHECK-NEXT: _d_z += 1;
Expand All @@ -372,8 +356,6 @@ double helper2(float& x) { return x * x; }
//CHECK-NEXT: _t1 = x;
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: _ret_value0 = _t1 * _t0;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = _d_y * _t0;
//CHECK-NEXT: * _d_x += _r0;
Expand All @@ -400,8 +382,6 @@ float func8(float x, float y) {
//CHECK-NEXT: _t0 = x;
//CHECK-NEXT: z = y + helper2(x);
//CHECK-NEXT: _EERepl_z1 = z;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_z += 1;
//CHECK-NEXT: {
//CHECK-NEXT: float _r_d0 = _d_z;
Expand Down Expand Up @@ -449,8 +429,6 @@ float func9(float x, float y) {
//CHECK-NEXT: _t5 = helper2(y);
//CHECK-NEXT: z += _t8 * _t5;
//CHECK-NEXT: _EERepl_z1 = z;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_z += 1;
//CHECK-NEXT: {
//CHECK-NEXT: float _r_d0 = _d_z;
Expand Down
6 changes: 0 additions & 6 deletions test/ErrorEstimation/ConditonalStatements.C
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ float func(float x, float y) {
//CHECK-NEXT: x = y;
//CHECK-NEXT: }
//CHECK-NEXT: _ret_value0 = x + y;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: * _d_x += 1;
//CHECK-NEXT: * _d_y += 1;
Expand Down Expand Up @@ -160,8 +158,6 @@ float func3(float x, float y) { return x > 30 ? x * y : x + y; }
//CHECK-NEXT: _t0 = y;
//CHECK-NEXT: }
//CHECK-NEXT: _ret_value0 = _cond0 ? _t1 * _t0 : x + y;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: if (_cond0) {
//CHECK-NEXT: float _r0 = 1 * _t0;
//CHECK-NEXT: * _d_x += _r0;
Expand Down Expand Up @@ -207,8 +203,6 @@ float func4(float x, float y) {
//CHECK-NEXT: _t3 = y;
//CHECK-NEXT: _t2 = x;
//CHECK-NEXT: _ret_value0 = _t3 / _t2;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: float _r1 = 1 / _t2;
//CHECK-NEXT: * _d_y += _r1;
Expand Down
Loading

0 comments on commit 40fdd2b

Please sign in to comment.