From 8f7b4aa770568a0dc26e3bc29f407518edc317d5 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Thu, 1 Aug 2024 17:26:14 +0300 Subject: [PATCH] Add support for std::initializer_list in the reverse mode. Fixes #830. --- .../clad/Differentiator/ReverseModeVisitor.h | 2 + lib/Differentiator/ReverseModeVisitor.cpp | 56 +++++++-- test/Gradient/Assignments.C | 44 +++++++ test/Gradient/Loops.C | 115 ++++++++++++++++++ 4 files changed, 208 insertions(+), 9 deletions(-) diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 8e57a4cb4..121b82ab8 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -401,6 +401,8 @@ namespace clad { StmtDiff VisitDoStmt(const clang::DoStmt* DS); StmtDiff VisitContinueStmt(const clang::ContinueStmt* CS); StmtDiff VisitBreakStmt(const clang::BreakStmt* BS); + StmtDiff + VisitCXXStdInitializerListExpr(const clang::CXXStdInitializerListExpr* ILE); StmtDiff VisitCXXThisExpr(const clang::CXXThisExpr* CTE); StmtDiff VisitCXXNewExpr(const clang::CXXNewExpr* CNE); StmtDiff VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 2fc337460..a7eb45946 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -760,6 +760,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, addToCurrentBlock(enzymeCall); } } + + StmtDiff ReverseModeVisitor::VisitCXXStdInitializerListExpr( + const clang::CXXStdInitializerListExpr* ILE) { + return Visit(ILE->getSubExpr(), dfdx()); + } + StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) { diag( DiagnosticsEngine::Warning, S->getBeginLoc(), @@ -1479,7 +1485,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // global. Ref-type declarations cannot be moved to the function global // scope because they can't be separated from their inits. if (DRE->getDecl()->getType()->isReferenceType() && - !VD->getType()->isReferenceType()) + VD->getType()->isPointerType()) clonedDRE = BuildOp(UO_Deref, clonedDRE); if (m_DiffReq.Mode == DiffMode::jacobian) { if (m_VectorOutput.size() <= outputArrayCursor) @@ -2717,11 +2723,42 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, VDCloneType = CloneType(VD->getType()); VDDerivedType = getNonConstType(VDCloneType, m_Context, m_Sema); } - bool isDerivativeOfRefType = VD->getType()->isReferenceType(); + + bool isRefType = VD->getType()->isLValueReferenceType(); VarDecl* VDDerived = nullptr; bool isPointerType = VD->getType()->isPointerType(); bool isInitializedByNewExpr = false; bool initializeDerivedVar = true; + + // We need to replace std::initializer_list with clad::array because the former + // is temporary by design and it's not possible to create modifiable adjoints. + if (m_Sema.isStdInitializerList(utils::GetValueType(VD->getType()), nullptr)) { + if (VD->getInit()) { + if (const auto* CXXILE = dyn_cast( + VD->getInit()->IgnoreImplicit())) { + if (const auto* ILE = dyn_cast( + CXXILE->getSubExpr()->IgnoreImplicit())) { + VDDerivedType = GetCladArrayOfType((*ILE->getInits())->getType()); + unsigned numInits = ILE->getNumInits(); + VDDerivedInit = ConstantFolder::synthesizeLiteral( + m_Context.getSizeType(), m_Context, numInits); + VDCloneType = VDDerivedType; + } + } else if (isRefType) { + initDiff = Visit(VD->getInit()); + if (promoteToFnScope) { + VDDerivedInit = BuildOp(UO_AddrOf, initDiff.getExpr_dx()); + VDDerivedType = VDDerivedInit->getType(); + } else { + VDDerivedInit = initDiff.getExpr_dx(); + VDDerivedType = + m_Context.getLValueReferenceType(VDDerivedInit->getType()); + } + VDCloneType = VDDerivedType; + } + } + } + // Check if the variable is pointer type and initialized by new expression if (isPointerType && VD->getInit() && isa(VD->getInit())) isInitializedByNewExpr = true; @@ -2752,7 +2789,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // `VDDerivedType` is the corresponding non-reference type and the initial // value is set to 0. // Otherwise, for non-reference types, the initial value is set to 0. - VDDerivedInit = getZeroInit(VD->getType()); + if (!VDDerivedInit) + VDDerivedInit = getZeroInit(VD->getType()); // `specialThisDiffCase` is only required for correctly differentiating // the following code: @@ -2769,14 +2807,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } - if (isDerivativeOfRefType) { + if (isRefType) { initDiff = Visit(VD->getInit()); if (!initDiff.getForwSweepExpr_dx()) { VDDerivedType = ComputeAdjointType(VD->getType().getNonReferenceType()); - isDerivativeOfRefType = false; + isRefType = false; } - if (promoteToFnScope || !isDerivativeOfRefType) + if (promoteToFnScope || !isRefType) VDDerivedInit = getZeroInit(VDDerivedType); else VDDerivedInit = initDiff.getForwSweepExpr_dx(); @@ -2833,7 +2871,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // differentiated and should not be differentiated again. // If `VD` is a reference to a non-local variable then also there's no // need to call `Visit` since non-local variables are not differentiated. - if (!isDerivativeOfRefType && (!isPointerType || isInitializedByNewExpr)) { + if (!isRefType && (!isPointerType || isInitializedByNewExpr)) { Expr* derivedE = nullptr; if (!clad::utils::hasNonDifferentiableAttribute(VD)) { @@ -2882,7 +2920,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // FIXME: Add extra parantheses if derived variable pointer is pointing to a // class type object. - if (isDerivativeOfRefType && promoteToFnScope) { + if (isRefType && promoteToFnScope) { Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE, BuildOp(UnaryOperatorKind::UO_AddrOf, @@ -2904,7 +2942,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // -> // double* ref; // ref = &x; - if (isDerivativeOfRefType && promoteToFnScope) + if (isRefType && promoteToFnScope) VDClone = BuildGlobalVarDecl( VDCloneType, VD->getNameAsString(), BuildOp(UnaryOperatorKind::UO_AddrOf, initDiff.getExpr()), diff --git a/test/Gradient/Assignments.C b/test/Gradient/Assignments.C index 77a3bcc9e..d4dc15600 100644 --- a/test/Gradient/Assignments.C +++ b/test/Gradient/Assignments.C @@ -777,6 +777,49 @@ double f22(double x, double y) { return t; } +double f23(double x, double y) { + auto&& list = {1., x+y}; + double res = 5; + if (x > y) { + auto& ref = list; + res = *(std::end(ref) - 1); + } + return res; +} + +//CHECK: void f23_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: bool _cond0; +//CHECK-NEXT: clad::array *_d_ref = 0; +//CHECK-NEXT: clad::array *ref = {}; +//CHECK-NEXT: double _t0; +//CHECK-NEXT: clad::array _d_list = {{2U|2UL}}; +//CHECK-NEXT: clad::array list = {1., x + y}; +//CHECK-NEXT: double _d_res = 0; +//CHECK-NEXT: double res = 5; +//CHECK-NEXT: { +//CHECK-NEXT: _cond0 = x > y; +//CHECK-NEXT: if (_cond0) { +//CHECK-NEXT: _d_ref = &_d_list; +//CHECK-NEXT: ref = &list; +//CHECK-NEXT: _t0 = res; +//CHECK-NEXT: res = *(std::end(*ref) - 1); +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: _d_res += 1; +//CHECK-NEXT: if (_cond0) { +//CHECK-NEXT: { +//CHECK-NEXT: res = _t0; +//CHECK-NEXT: double _r_d0 = _d_res; +//CHECK-NEXT: _d_res = 0; +//CHECK-NEXT: *(std::end(*_d_ref) - 1) += _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: *_d_x += _d_list[1]; +//CHECK-NEXT: *_d_y += _d_list[1]; +//CHECK-NEXT: } +//CHECK-NEXT: } + #define TEST(F, x, y) \ { \ result[0] = 0; \ @@ -841,4 +884,5 @@ int main() { TEST(f20, 1, 2); // CHECK-EXEC: {0.00, 3.00} TEST(f21, 6, 4); // CHECK-EXEC: {1.00, 0.00} TEST(f22, 6, 4); // CHECK-EXEC: {0.00, 0.00} + TEST(f23, 7, 5); // CHECK-EXEC: {1.00, 1.00} } diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index 2315008bf..e04422e0f 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -3020,6 +3020,119 @@ double fn37(double x, double y) { //CHECK-NEXT: } //CHECK-NEXT: } +double fn38(double x, double y) { + double sum = 0; + if (x > 0) { + auto&& range = {1., x, 2., y, 3.}; + for (auto elem : range) + sum += elem; + } + return sum; +} + +//CHECK: void fn38_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: bool _cond0; +//CHECK-NEXT: clad::array _d_range = {{5U|5UL}}; +//CHECK-NEXT: clad::array range = {}; +//CHECK-NEXT: unsigned {{int|long}} _t0; +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: clad::tape _t2 = {}; +//CHECK-NEXT: clad::tape _t3 = {}; +//CHECK-NEXT: double _d_sum = 0; +//CHECK-NEXT: double sum = 0; +//CHECK-NEXT: { +//CHECK-NEXT: _cond0 = x > 0; +//CHECK-NEXT: if (_cond0) { +//CHECK-NEXT: range = {1., x, 2., y, 3.}; +//CHECK-NEXT: _t0 = {{0U|0UL}}; +//CHECK-NEXT: clad::array &__range20 = range; +//CHECK-NEXT: clad::array &_d___range2 = _d_range; +//CHECK-NEXT: {{const double *\*|const_iterator }}__begin20 = std::begin(__range20); +//CHECK-NEXT: double *_d___begin2 = std::begin(_d___range2); +//CHECK-NEXT: {{const double *\*|const_iterator }}__end20 = std::end(__range20); +//CHECK-NEXT: double _d_elem = 0; +//CHECK-NEXT: double elem = 0; +//CHECK-NEXT: for (; __begin20 != __end20; ++__begin20 , ++_d___begin2) { +//CHECK-NEXT: { +//CHECK-NEXT: _d_elem = *_d___begin2; +//CHECK-NEXT: elem = *__begin20; +//CHECK-NEXT: clad::push(_t2, elem); +//CHECK-NEXT: clad::push(_t3, _d_elem); +//CHECK-NEXT: } +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, sum); +//CHECK-NEXT: sum += elem; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: _d_sum += 1; +//CHECK-NEXT: if (_cond0) { +//CHECK-NEXT: for (; _t0; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: { +//CHECK-NEXT: _d___begin2--; +//CHECK-NEXT: elem = clad::pop(_t2); +//CHECK-NEXT: _d_elem = clad::pop(_t3); +//CHECK-NEXT: } +//CHECK-NEXT: sum = clad::pop(_t1); +//CHECK-NEXT: double _r_d0 = _d_sum; +//CHECK-NEXT: _d_elem += _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: *_d___begin2 += _d_elem; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: *_d_x += _d_range[1]; +//CHECK-NEXT: *_d_y += _d_range[3]; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } + +double fn39(double x) { + double res = 0; + auto &&range = {1, 2, 3}; + for (auto i = range.begin(); i != range.end(); i++) { + res += x * (*i); + } + return res; +} + +//CHECK: void fn39_grad(double x, double *_d_x) { +//CHECK-NEXT: int *_d_i = 0; +//CHECK-NEXT: {{const int *\*|const_iterator }}i = 0; +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: double _d_res = 0; +//CHECK-NEXT: double res = 0; +//CHECK-NEXT: clad::array _d_range = {{3U|3UL}}; +//CHECK-NEXT: clad::array range = {1, 2, 3}; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: _d_i = std::begin(_d_range); +//CHECK-NEXT: for (i = std::begin(range); ; _d_i++ , i++) { +//CHECK-NEXT: { +//CHECK-NEXT: if (!(i != std::end(range))) +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, res); +//CHECK-NEXT: res += x * (*i); +//CHECK-NEXT: } +//CHECK-NEXT: _d_res += 1; +//CHECK-NEXT: for (;; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: if (!_t0) +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: i--; +//CHECK-NEXT: _d_i--; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: res = clad::pop(_t1); +//CHECK-NEXT: double _r_d0 = _d_res; +//CHECK-NEXT: *_d_x += _r_d0 * (*i); +//CHECK-NEXT: *_d_i += x * _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } #define TEST(F, x) { \ result[0] = 0; \ @@ -3109,6 +3222,8 @@ int main() { TEST_2(fn35, 2, 2); // CHECK-EXEC: {12.00, 4.00} TEST_2(fn36, 1, 1); // CHECK-EXEC: {1.75, 0.00} TEST_2(fn37, 1, 1); // CHECK-EXEC: {1.00, 1.00} + TEST_2(fn38, 6, 3); // CHECK-EXEC: {1.00, 1.00} + TEST(fn39, 9); // CHECK-EXEC: {6.00} } //CHECK: void sq_pullback(double x, double _d_y, double *_d_x) {