diff --git a/lib/Differentiator/TBRAnalyzer.cpp b/lib/Differentiator/TBRAnalyzer.cpp index ea9e4eb6a..72fbb463c 100644 --- a/lib/Differentiator/TBRAnalyzer.cpp +++ b/lib/Differentiator/TBRAnalyzer.cpp @@ -188,7 +188,10 @@ TBRAnalyzer::VarData* TBRAnalyzer::getExprVarData(const clang::Expr* E, return EData; } -TBRAnalyzer::VarData::VarData(const QualType QT) { +TBRAnalyzer::VarData::VarData(QualType QT, bool forceNonRefType) { + if (forceNonRefType && QT->isReferenceType()) + QT = QT->getPointeeType(); + if (QT->isReferenceType()) { m_Type = VarData::REF_TYPE; m_Val.m_RefData = nullptr; @@ -265,7 +268,7 @@ void TBRAnalyzer::copyVarToCurBlock(const clang::VarDecl* VD) { addVar(VD); } -void TBRAnalyzer::addVar(const clang::VarDecl* VD) { +void TBRAnalyzer::addVar(const clang::VarDecl* VD, bool forceNonRefType) { auto& curBranch = getCurBlockVarsData(); QualType varType; @@ -288,7 +291,7 @@ void TBRAnalyzer::addVar(const clang::VarDecl* VD) { return; } } - curBranch[VD] = VarData(varType); + curBranch[VD] = VarData(varType, forceNonRefType); } void TBRAnalyzer::markLocation(const clang::Expr* E) { @@ -342,7 +345,7 @@ void TBRAnalyzer::Analyze(const FunctionDecl* FD) { } auto paramsRef = FD->parameters(); for (std::size_t i = 0; i < FD->getNumParams(); ++i) - addVar(paramsRef[i]); + addVar(paramsRef[i], /*forceNonRefType=*/true); // Add the entry block to the queue. m_CFGQueue.insert(m_CurBlockID); diff --git a/lib/Differentiator/TBRAnalyzer.h b/lib/Differentiator/TBRAnalyzer.h index d1f772757..36489f3e9 100644 --- a/lib/Differentiator/TBRAnalyzer.h +++ b/lib/Differentiator/TBRAnalyzer.h @@ -98,7 +98,11 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor<TBRAnalyzer> { } /// Builds a VarData object (and its children) based on the provided type. - VarData(QualType QT); + /// If `forceNonRefType` is true, the constructed VarData will not be of + /// reference type (it will store TBR information itself without referring + /// to other VarData's). 
This is necessary for reference-type parameters, + /// when the referenced expressions are out of the function's scope. + VarData(QualType QT, bool forceNonRefType=false); /// Erases all children VarData's of this VarData. ~VarData() { @@ -250,7 +254,7 @@ class TBRAnalyzer : public clang::RecursiveASTVisitor<TBRAnalyzer> { //// Setters /// Creates VarData for a new VarDecl*. - void addVar(const clang::VarDecl* VD); + void addVar(const clang::VarDecl* VD, bool forceNonRefType = false); /// Makes a copy of the VarData corresponding to VD /// to the current block from the lowest predecessor /// where VD is present. diff --git a/test/Arrays/ArrayInputsReverseMode.C b/test/Arrays/ArrayInputsReverseMode.C index 35874ca1d..306f73257 100644 --- a/test/Arrays/ArrayInputsReverseMode.C +++ b/test/Arrays/ArrayInputsReverseMode.C @@ -428,6 +428,200 @@ double func7(double *params) { //CHECK-NEXT: } //CHECK-NEXT: } +double helper2(double i, double *arr, int n) { + return arr[0]*i; +} + +//CHECK: void helper2_pullback(double i, double *arr, int n, double _d_y, clad::array_ref<double> _d_i, clad::array_ref<double> _d_arr, clad::array_ref<int> _d_n) { +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: _label0: +//CHECK-NEXT: { +//CHECK-NEXT: double _r0 = _d_y * i; +//CHECK-NEXT: _d_arr[0] += _r0; +//CHECK-NEXT: double _r1 = arr[0] * _d_y; +//CHECK-NEXT: * _d_i += _r1; +//CHECK-NEXT: } +//CHECK-NEXT: } + +double func8(double i, double *arr, int n) { + double res = 0; + arr[0] = 1; + res = helper2(i, arr, n); + arr[0] = 5; + return res; +} + +//CHECK: void func8_grad(double i, double *arr, int n, clad::array_ref<double> _d_i, clad::array_ref<double> _d_arr, clad::array_ref<int> _d_n) { +//CHECK-NEXT: double _d_res = 0; +//CHECK-NEXT: double _t0; +//CHECK-NEXT: double _t1; +//CHECK-NEXT: double *_t2; +//CHECK-NEXT: double _t3; +//CHECK-NEXT: double res = 0; +//CHECK-NEXT: _t0 = arr[0]; +//CHECK-NEXT: arr[0] = 1; +//CHECK-NEXT: _t1 = res; +//CHECK-NEXT: _t2 = arr; +//CHECK-NEXT: res = helper2(i, arr, n); +//CHECK-NEXT: _t3 = arr[0]; +//CHECK-NEXT: 
arr[0] = 5; +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: _label0: +//CHECK-NEXT: _d_res += 1; +//CHECK-NEXT: { +//CHECK-NEXT: arr[0] = _t3; +//CHECK-NEXT: double _r_d2 = _d_arr[0]; +//CHECK-NEXT: _d_arr[0] -= _r_d2; +//CHECK-NEXT: _d_arr[0]; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: res = _t1; +//CHECK-NEXT: double _r_d1 = _d_res; +//CHECK-NEXT: arr = _t2; +//CHECK-NEXT: double _grad0 = 0.; +//CHECK-NEXT: int _grad2 = 0; +//CHECK-NEXT: helper2_pullback(i, _t2, n, _r_d1, &_grad0, _d_arr, &_grad2); +//CHECK-NEXT: double _r0 = _grad0; +//CHECK-NEXT: * _d_i += _r0; +//CHECK-NEXT: clad::array<double> _r1(_d_arr); +//CHECK-NEXT: int _r2 = _grad2; +//CHECK-NEXT: * _d_n += _r2; +//CHECK-NEXT: _d_res -= _r_d1; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: arr[0] = _t0; +//CHECK-NEXT: double _r_d0 = _d_arr[0]; +//CHECK-NEXT: _d_arr[0] -= _r_d0; +//CHECK-NEXT: _d_arr[0]; +//CHECK-NEXT: } +//CHECK-NEXT: } + +void modify(double& elem, double val) { + elem = val; +} + +//CHECK: void modify_pullback(double &elem, double val, clad::array_ref<double> _d_elem, clad::array_ref<double> _d_val) { +//CHECK-NEXT: double _t0; +//CHECK-NEXT: _t0 = elem; +//CHECK-NEXT: elem = val; +//CHECK-NEXT: { +//CHECK-NEXT: elem = _t0; +//CHECK-NEXT: double _r_d0 = * _d_elem; +//CHECK-NEXT: * _d_val += _r_d0; +//CHECK-NEXT: * _d_elem -= _r_d0; +//CHECK-NEXT: * _d_elem; +//CHECK-NEXT: } +//CHECK-NEXT: } + +double func9(double i, double j) { + double arr[5] = {}; + for (int idx = 0; idx < 5; ++idx) { + modify(arr[idx], i); + } + return arr[0] + arr[1] + arr[2] + arr[3] + arr[4]; +} + + +//CHECK: void func9_grad(double i, double j, clad::array_ref<double> _d_i, clad::array_ref<double> _d_j) { +//CHECK-NEXT: clad::array<double> _d_arr(5UL); +//CHECK-NEXT: unsigned long _t0; +//CHECK-NEXT: int _d_idx = 0; +//CHECK-NEXT: clad::tape<double> _t1 = {}; +//CHECK-NEXT: double arr[5] = {}; +//CHECK-NEXT: _t0 = 0; +//CHECK-NEXT: for (int idx = 0; idx < 5; ++idx) { +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, arr[idx]); +//CHECK-NEXT: modify(arr[idx], 
i); +//CHECK-NEXT: } +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: _label0: +//CHECK-NEXT: { +//CHECK-NEXT: _d_arr[0] += 1; +//CHECK-NEXT: _d_arr[1] += 1; +//CHECK-NEXT: _d_arr[2] += 1; +//CHECK-NEXT: _d_arr[3] += 1; +//CHECK-NEXT: _d_arr[4] += 1; +//CHECK-NEXT: } +//CHECK-NEXT: for (; _t0; _t0--) { +//CHECK-NEXT: --idx; +//CHECK-NEXT: { +//CHECK-NEXT: double _r1 = clad::pop(_t1); +//CHECK-NEXT: arr[idx] = _r1; +//CHECK-NEXT: double _grad1 = 0.; +//CHECK-NEXT: modify_pullback(_r1, i, &_d_arr[idx], &_grad1); +//CHECK-NEXT: double _r0 = _d_arr[idx]; +//CHECK-NEXT: double _r2 = _grad1; +//CHECK-NEXT: * _d_i += _r2; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } + +double sq(double& elem) { + elem = elem * elem; + return elem; +} + +//CHECK: void sq_pullback(double &elem, double _d_y, clad::array_ref<double> _d_elem) { +//CHECK-NEXT: double _t0; +//CHECK-NEXT: _t0 = elem; +//CHECK-NEXT: elem = elem * elem; +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: _label0: +//CHECK-NEXT: * _d_elem += _d_y; +//CHECK-NEXT: { +//CHECK-NEXT: elem = _t0; +//CHECK-NEXT: double _r_d0 = * _d_elem; +//CHECK-NEXT: double _r0 = _r_d0 * elem; +//CHECK-NEXT: * _d_elem += _r0; +//CHECK-NEXT: double _r1 = elem * _r_d0; +//CHECK-NEXT: * _d_elem += _r1; +//CHECK-NEXT: * _d_elem -= _r_d0; +//CHECK-NEXT: * _d_elem; +//CHECK-NEXT: } +//CHECK-NEXT: } + +double func10(double *arr, int n) { + double res = 0; + for (int i=0; i<n; ++i) { + res += sq(arr[i]); + } + return res; +} + +//CHECK: void func10_grad_0(double *arr, int n, clad::array_ref<double> _d_arr) { +//CHECK-NEXT: int _d_n = 0; +//CHECK-NEXT: double _d_res = 0; +//CHECK-NEXT: unsigned long _t0; +//CHECK-NEXT: int _d_i = 0; +//CHECK-NEXT: clad::tape<double> _t1 = {}; +//CHECK-NEXT: clad::tape<double> _t2 = {}; +//CHECK-NEXT: double res = 0; +//CHECK-NEXT: _t0 = 0; +//CHECK-NEXT: for (int i = 0; i < n; ++i) { +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, res); +//CHECK-NEXT: clad::push(_t2, arr[i]); +//CHECK-NEXT: res += sq(arr[i]); +//CHECK-NEXT: } +//CHECK-NEXT: goto _label0; +//CHECK-NEXT: _label0: +//CHECK-NEXT: _d_res += 1; +//CHECK-NEXT: for (; _t0; _t0--) { +//CHECK-NEXT: --i; 
+//CHECK-NEXT: { +//CHECK-NEXT: res = clad::pop(_t1); +//CHECK-NEXT: double _r_d0 = _d_res; +//CHECK-NEXT: _d_res += _r_d0; +//CHECK-NEXT: double _r1 = clad::pop(_t2); +//CHECK-NEXT: arr[i] = _r1; +//CHECK-NEXT: sq_pullback(_r1, _r_d0, &_d_arr[i]); +//CHECK-NEXT: double _r0 = _d_arr[i]; +//CHECK-NEXT: _d_res -= _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } + int main() { double arr[] = {1, 2, 3}; auto f_dx = clad::gradient(f); @@ -475,4 +669,25 @@ int main() { double dparams = 0.0; func7grad.execute(&params, &dparams); printf("Result = {%.2f}\n", dparams); // CHECK-EXEC: Result = {-0.25} + + auto func8grad = clad::gradient(func8); + double arr2[5] = {1, 2, 3, 4, 5}; + double _d_arr2[5] = {}; + clad::array_ref<double> _d_arr_ref2(_d_arr2, 5); + double d_i = 0, d_n = 0; + func8grad.execute(3, arr, 5, &d_i, _d_arr_ref2, &d_n); + printf("Result = {%.2f}\n", d_i); // CHECK-EXEC: Result = {1.00} + + auto func9grad = clad::gradient(func9); + double d_j; + d_i = d_j = 0; + func9grad.execute(3, 5, &d_i, &d_j); + printf("Result = {%.2f}\n", d_i); // CHECK-EXEC: Result = {5.00} + + auto func10grad = clad::gradient(func10, "arr"); + double arr3[5] = {1, 2, 3, 4, 5}; + double _d_arr3[5] = {}; + clad::array_ref<double> ref3(_d_arr3, 5); + func10grad.execute(arr3, 5, ref3); + printf("Result (arr) = {%.2f, %.2f, %.2f, %.2f, %.2f}\n", ref3[0], ref3[1], ref3[2], ref3[3], ref3[4]); // CHECK-EXEC: Result (arr) = {2.00, 4.00, 6.00, 8.00, 10.00} }