Skip to content

Commit

Permalink
new tests and a ref bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Dec 1, 2023
1 parent bb5e26d commit 51c1c4d
Show file tree
Hide file tree
Showing 3 changed files with 228 additions and 6 deletions.
11 changes: 7 additions & 4 deletions lib/Differentiator/TBRAnalyzer.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -188,7 +188,10 @@ TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E,
return EData;
}

TBRAnalyzer::VarData::VarData(const QualType QT) {
TBRAnalyzer::VarData::VarData(QualType QT, bool forceNonRefType) {
if (forceNonRefType && QT->isReferenceType())
QT = QT->getPointeeType();

if (QT->isReferenceType()) {
m_Type = VarData::REF_TYPE;
m_Val.m_RefData = nullptr;
Expand Down Expand Up @@ -265,7 +268,7 @@ void TBRAnalyzer::copyVarToCurBlock(const clang::VarDecl* VD) {
addVar(VD);
}

void TBRAnalyzer::addVar(const clang::VarDecl* VD) {
void TBRAnalyzer::addVar(const clang::VarDecl* VD, bool forceNonRefType) {
auto& curBranch = getCurBlockVarsData();

QualType varType;
Expand All @@ -288,7 +291,7 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD) {
return;
}
}
curBranch[VD] = VarData(varType);
curBranch[VD] = VarData(varType, forceNonRefType);
}

void TBRAnalyzer::markLocation(const clang::Expr* E) {
Expand Down Expand Up @@ -342,7 +345,7 @@ void TBRAnalyzer::Analyze(const FunctionDecl* FD) {
}
auto paramsRef = FD->parameters();
for (std::size_t i = 0; i < FD->getNumParams(); ++i)
addVar(paramsRef[i]);
addVar(paramsRef[i], /*forceNonRefType=*/true);
// Add the entry block to the queue.
m_CFGQueue.insert(m_CurBlockID);

Expand Down
8 changes: 6 additions & 2 deletions lib/Differentiator/TBRAnalyzer.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,11 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor<TBRAnalyzer> {
}

/// Builds a VarData object (and its children) based on the provided type.
VarData(QualType QT);
/// If `forceNonRefType` is true, the constructed VarData will not be of
/// reference type (it will store TBR information itself without referring
/// to other VarData's). This is necessary for reference-type parameters,
/// when the referenced expressions are out of the function's scope.
VarData(QualType QT, bool forceNonRefType=false);

/// Erases all children VarData's of this VarData.
~VarData() {
Expand Down Expand Up @@ -250,7 +254,7 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor<TBRAnalyzer> {

//// Setters
/// Creates VarData for a new VarDecl*.
void addVar(const clang::VarDecl* VD);
void addVar(const clang::VarDecl* VD, bool forceNonRefType = false);
/// Makes a copy of the VarData corresponding to VD
/// to the current block from the lowest predecessor
/// where VD is present.
Expand Down
215 changes: 215 additions & 0 deletions test/Arrays/ArrayInputsReverseMode.C
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,200 @@ double func7(double *params) {
//CHECK-NEXT: }
//CHECK-NEXT: }

double helper2(double i, double *arr, int n) {
return arr[0]*i;
}

//CHECK: void helper2_pullback(double i, double *arr, int n, double _d_y, clad::array_ref<double> _d_i, clad::array_ref<double> _d_arr, clad::array_ref<int> _d_n) {
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: double _r0 = _d_y * i;
//CHECK-NEXT: _d_arr[0] += _r0;
//CHECK-NEXT: double _r1 = arr[0] * _d_y;
//CHECK-NEXT: * _d_i += _r1;
//CHECK-NEXT: }
//CHECK-NEXT: }

double func8(double i, double *arr, int n) {
double res = 0;
arr[0] = 1;
res = helper2(i, arr, n);
arr[0] = 5;
return res;
}

//CHECK: void func8_grad(double i, double *arr, int n, clad::array_ref<double> _d_i, clad::array_ref<double> _d_arr, clad::array_ref<int> _d_n) {
//CHECK-NEXT: double _d_res = 0;
//CHECK-NEXT: double _t0;
//CHECK-NEXT: double _t1;
//CHECK-NEXT: double *_t2;
//CHECK-NEXT: double _t3;
//CHECK-NEXT: double res = 0;
//CHECK-NEXT: _t0 = arr[0];
//CHECK-NEXT: arr[0] = 1;
//CHECK-NEXT: _t1 = res;
//CHECK-NEXT: _t2 = arr;
//CHECK-NEXT: res = helper2(i, arr, n);
//CHECK-NEXT: _t3 = arr[0];
//CHECK-NEXT: arr[0] = 5;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_res += 1;
//CHECK-NEXT: {
//CHECK-NEXT: arr[0] = _t3;
//CHECK-NEXT: double _r_d2 = _d_arr[0];
//CHECK-NEXT: _d_arr[0] -= _r_d2;
//CHECK-NEXT: _d_arr[0];
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: res = _t1;
//CHECK-NEXT: double _r_d1 = _d_res;
//CHECK-NEXT: arr = _t2;
//CHECK-NEXT: double _grad0 = 0.;
//CHECK-NEXT: int _grad2 = 0;
//CHECK-NEXT: helper2_pullback(i, _t2, n, _r_d1, &_grad0, _d_arr, &_grad2);
//CHECK-NEXT: double _r0 = _grad0;
//CHECK-NEXT: * _d_i += _r0;
//CHECK-NEXT: clad::array<double> _r1(_d_arr);
//CHECK-NEXT: int _r2 = _grad2;
//CHECK-NEXT: * _d_n += _r2;
//CHECK-NEXT: _d_res -= _r_d1;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: arr[0] = _t0;
//CHECK-NEXT: double _r_d0 = _d_arr[0];
//CHECK-NEXT: _d_arr[0] -= _r_d0;
//CHECK-NEXT: _d_arr[0];
//CHECK-NEXT: }
//CHECK-NEXT: }

void modify(double& elem, double val) {
elem = val;
}

//CHECK: void modify_pullback(double &elem, double val, clad::array_ref<double> _d_elem, clad::array_ref<double> _d_val) {
//CHECK-NEXT: double _t0;
//CHECK-NEXT: _t0 = elem;
//CHECK-NEXT: elem = val;
//CHECK-NEXT: {
//CHECK-NEXT: elem = _t0;
//CHECK-NEXT: double _r_d0 = * _d_elem;
//CHECK-NEXT: * _d_val += _r_d0;
//CHECK-NEXT: * _d_elem -= _r_d0;
//CHECK-NEXT: * _d_elem;
//CHECK-NEXT: }
//CHECK-NEXT: }

double func9(double i, double j) {
double arr[5] = {};
for (int idx = 0; idx < 5; ++idx) {
modify(arr[idx], i);
}
return arr[0] + arr[1] + arr[2] + arr[3] + arr[4];
}


//CHECK: void func9_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) {
//CHECK-NEXT: clad::array<double> _d_arr(5UL);
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_idx = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double arr[5] = {};
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int idx = 0; idx < 5; ++idx) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, arr[idx]);
//CHECK-NEXT: modify(arr[idx], i);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: {
//CHECK-NEXT: _d_arr[0] += 1;
//CHECK-NEXT: _d_arr[1] += 1;
//CHECK-NEXT: _d_arr[2] += 1;
//CHECK-NEXT: _d_arr[3] += 1;
//CHECK-NEXT: _d_arr[4] += 1;
//CHECK-NEXT: }
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: --idx;
//CHECK-NEXT: {
//CHECK-NEXT: double _r1 = clad::pop(_t1);
//CHECK-NEXT: arr[idx] = _r1;
//CHECK-NEXT: double _grad1 = 0.;
//CHECK-NEXT: modify_pullback(_r1, i, &_d_arr[idx], &_grad1);
//CHECK-NEXT: double _r0 = _d_arr[idx];
//CHECK-NEXT: double _r2 = _grad1;
//CHECK-NEXT: * _d_i += _r2;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }

double sq(double& elem) {
elem = elem * elem;
return elem;
}

//CHECK: void sq_pullback(double &elem, double _d_y, clad::array_ref<double> _d_elem) {
//CHECK-NEXT: double _t0;
//CHECK-NEXT: _t0 = elem;
//CHECK-NEXT: elem = elem * elem;
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: * _d_elem += _d_y;
//CHECK-NEXT: {
//CHECK-NEXT: elem = _t0;
//CHECK-NEXT: double _r_d0 = * _d_elem;
//CHECK-NEXT: double _r0 = _r_d0 * elem;
//CHECK-NEXT: * _d_elem += _r0;
//CHECK-NEXT: double _r1 = elem * _r_d0;
//CHECK-NEXT: * _d_elem += _r1;
//CHECK-NEXT: * _d_elem -= _r_d0;
//CHECK-NEXT: * _d_elem;
//CHECK-NEXT: }
//CHECK-NEXT: }

double func10(double *arr, int n) {
double res = 0;
for (int i=0; i<n; ++i) {
res += sq(arr[i]);
}
return res;
}

//CHECK: void func10_grad_0(double *arr, int n, clad::array_ref<double> _d_arr) {
//CHECK-NEXT: int _d_n = 0;
//CHECK-NEXT: double _d_res = 0;
//CHECK-NEXT: unsigned long _t0;
//CHECK-NEXT: int _d_i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: double res = 0;
//CHECK-NEXT: _t0 = 0;
//CHECK-NEXT: for (int i = 0; i < n; ++i) {
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, res);
//CHECK-NEXT: clad::push(_t2, arr[i]);
//CHECK-NEXT: res += sq(arr[i]);
//CHECK-NEXT: }
//CHECK-NEXT: goto _label0;
//CHECK-NEXT: _label0:
//CHECK-NEXT: _d_res += 1;
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: --i;
//CHECK-NEXT: {
//CHECK-NEXT: res = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_res;
//CHECK-NEXT: _d_res += _r_d0;
//CHECK-NEXT: double _r1 = clad::pop(_t2);
//CHECK-NEXT: arr[i] = _r1;
//CHECK-NEXT: sq_pullback(_r1, _r_d0, &_d_arr[i]);
//CHECK-NEXT: double _r0 = _d_arr[i];
//CHECK-NEXT: _d_res -= _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }

int main() {
double arr[] = {1, 2, 3};
auto f_dx = clad::gradient(f);
Expand Down Expand Up @@ -475,4 +669,25 @@ int main() {
double dparams = 0.0;
func7grad.execute(&params, &dparams);
printf("Result = {%.2f}\n", dparams); // CHECK-EXEC: Result = {-0.25}

auto func8grad = clad::gradient(func8);
double arr2[5] = {1, 2, 3, 4, 5};
double _d_arr2[5] = {};
clad::array_ref<double> _d_arr_ref2(_d_arr2, 5);
double d_i = 0, d_n = 0;
func8grad.execute(3, arr, 5, &d_i, _d_arr_ref2, &d_n);
printf("Result = {%.2f}\n", d_i); // CHECK-EXEC: Result = {1.00}

auto func9grad = clad::gradient(func9);
double d_j;
d_i = d_j = 0;
func9grad.execute(3, 5, &d_i, &d_j);
printf("Result = {%.2f}\n", d_i); // CHECK-EXEC: Result = {5.00}

auto func10grad = clad::gradient(func10, "arr");
double arr3[5] = {1, 2, 3, 4, 5};
double _d_arr3[5] = {};
clad::array_ref<double> ref3(_d_arr3, 5);
func10grad.execute(arr3, 5, ref3);
printf("Result (arr) = {%.2f, %.2f, %.2f, %.2f, %.2f}\n", ref3[0], ref3[1], ref3[2], ref3[3], ref3[4]); // CHECK-EXEC: Result (arr) = {2.00, 4.00, 6.00, 8.00, 10.00}
}

0 comments on commit 51c1c4d

Please sign in to comment.