diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index a161f1f58..c119ed4a6 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -401,6 +401,8 @@ namespace clad { StmtDiff VisitDoStmt(const clang::DoStmt* DS); StmtDiff VisitContinueStmt(const clang::ContinueStmt* CS); StmtDiff VisitBreakStmt(const clang::BreakStmt* BS); + StmtDiff + VisitCXXStdInitializerListExpr(const clang::CXXStdInitializerListExpr* ILE); StmtDiff VisitCXXThisExpr(const clang::CXXThisExpr* CTE); StmtDiff VisitCXXNewExpr(const clang::CXXNewExpr* CNE); StmtDiff VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE); diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 0a5776435..b9821fafc 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -760,6 +760,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, addToCurrentBlock(enzymeCall); } } + + StmtDiff ReverseModeVisitor::VisitCXXStdInitializerListExpr( + const clang::CXXStdInitializerListExpr* ILE) { + return Visit(ILE->getSubExpr(), dfdx()); + } + StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) { diag( DiagnosticsEngine::Warning, S->getBeginLoc(), @@ -1472,7 +1478,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // global. Ref-type declarations cannot be moved to the function global // scope because they can't be separated from their inits. if (DRE->getDecl()->getType()->isReferenceType() && - !VD->getType()->isReferenceType()) + VD->getType()->isPointerType()) clonedDRE = BuildOp(UO_Deref, clonedDRE); if (m_DiffReq.Mode == DiffMode::jacobian) { if (m_VectorOutput.size() <= outputArrayCursor) @@ -2714,11 +2720,37 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, VDCloneType = CloneType(VD->getType()); VDDerivedType = getNonConstType(VDCloneType, m_Context, m_Sema); } - bool isDerivativeOfRefType = VD->getType()->isReferenceType(); + + bool isRefType = VD->getType()->isLValueReferenceType(); VarDecl* VDDerived = nullptr; bool isPointerType = VD->getType()->isPointerType(); bool isInitializedByNewExpr = false; bool initializeDerivedVar = true; + std::string typeName; + if (const auto* RT = + utils::GetValueType(VD->getType())->getAs()) + typeName = RT->getDecl()->getNameAsString(); + if (typeName == "initializer_list") { + if (VD->getInit()) { + if (const auto* CXXILE = dyn_cast( + VD->getInit()->IgnoreImplicit())) { + if (const auto* ILE = dyn_cast( + CXXILE->getSubExpr()->IgnoreImplicit())) { + VDDerivedType = GetCladArrayOfType((*ILE->getInits())->getType()); + unsigned numInits = ILE->getNumInits(); + VDDerivedInit = ConstantFolder::synthesizeLiteral( + m_Context.getSizeType(), m_Context, numInits); + VDCloneType = VDDerivedType; + } + } else if (isRefType) { + initDiff = Visit(VD->getInit()); + VDDerivedInit = BuildOp(UO_AddrOf, initDiff.getExpr_dx()); + VDDerivedType = VDDerivedInit->getType(); + VDCloneType = VDDerivedType; + } + } + } + // Check if the variable is pointer type and initialized by new expression if (isPointerType && VD->getInit() && isa(VD->getInit())) isInitializedByNewExpr = true; @@ -2749,7 +2781,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // `VDDerivedType` is the corresponding non-reference type and the initial // value is set to 0. // Otherwise, for non-reference types, the initial value is set to 0. - VDDerivedInit = getZeroInit(VD->getType()); + if (!VDDerivedInit) + VDDerivedInit = getZeroInit(VD->getType()); // `specialThisDiffCase` is only required for correctly differentiating // the following code: @@ -2766,14 +2799,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, } } - if (isDerivativeOfRefType) { + if (isRefType) { initDiff = Visit(VD->getInit()); if (!initDiff.getForwSweepExpr_dx()) { VDDerivedType = ComputeAdjointType(VD->getType().getNonReferenceType()); - isDerivativeOfRefType = false; + isRefType = false; } - if (promoteToFnScope || !isDerivativeOfRefType) + if (promoteToFnScope || !isRefType) VDDerivedInit = getZeroInit(VDDerivedType); else VDDerivedInit = initDiff.getForwSweepExpr_dx(); @@ -2830,7 +2863,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // differentiated and should not be differentiated again. // If `VD` is a reference to a non-local variable then also there's no // need to call `Visit` since non-local variables are not differentiated. - if (!isDerivativeOfRefType && (!isPointerType || isInitializedByNewExpr)) { + if (!isRefType && (!isPointerType || isInitializedByNewExpr)) { Expr* derivedE = nullptr; if (!clad::utils::hasNonDifferentiableAttribute(VD)) { @@ -2878,7 +2911,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // FIXME: Add extra parantheses if derived variable pointer is pointing to a // class type object. - if (isDerivativeOfRefType && promoteToFnScope) { + if (isRefType && promoteToFnScope) { Expr* assignDerivativeE = BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE, BuildOp(UnaryOperatorKind::UO_AddrOf, @@ -2899,7 +2932,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // -> // double* ref; // ref = &x; - if (isDerivativeOfRefType && promoteToFnScope) + if (isRefType && promoteToFnScope) VDClone = BuildGlobalVarDecl( VDCloneType, VD->getNameAsString(), BuildOp(UnaryOperatorKind::UO_AddrOf, initDiff.getExpr()), diff --git a/test/Gradient/Loops.C b/test/Gradient/Loops.C index a0a940610..642b0c63a 100644 --- a/test/Gradient/Loops.C +++ b/test/Gradient/Loops.C @@ -2878,6 +2878,113 @@ double fn36(double x, double y) { //CHECK-NEXT: } //CHECK-NEXT: } +double fn37(double x, double y) { + auto&& range = {1., x, 2., y, 3.}; + double sum = 0; + for (auto elem : range) + sum += elem; + return sum; +} + +//CHECK: void fn37_grad(double x, double y, double *_d_x, double *_d_y) { +//CHECK-NEXT: unsigned {{int|long}} _t0; +//CHECK-NEXT: clad::array *_d___range1 = 0; +//CHECK-NEXT: clad::array *__range10 = {}; +//CHECK-NEXT: double *_d___begin1 = 0; +//CHECK-NEXT: const double *__begin10 = 0; +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: clad::tape _t2 = {}; +//CHECK-NEXT: clad::tape _t3 = {}; +//CHECK-NEXT: clad::array _d_range = {{5U|5UL}}; +//CHECK-NEXT: clad::array range = {1., x, 2., y, 3.}; +//CHECK-NEXT: double _d_sum = 0; +//CHECK-NEXT: double sum = 0; +//CHECK-NEXT: _t0 = {{0U|0UL}}; +//CHECK-NEXT: _d___range1 = &_d_range; +//CHECK-NEXT: _d___begin1 = std::begin(*_d___range1); +//CHECK-NEXT: __range10 = ⦥ +//CHECK-NEXT: __begin10 = std::begin(*__range10); +//CHECK-NEXT: const double *__end10 = std::end(*__range10); +//CHECK-NEXT: double _d_elem = 0; +//CHECK-NEXT: double elem = 0; +//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) { +//CHECK-NEXT: { +//CHECK-NEXT: _d_elem = *_d___begin1; +//CHECK-NEXT: elem = *__begin10; +//CHECK-NEXT: clad::push(_t2, elem); +//CHECK-NEXT: clad::push(_t3, _d_elem); +//CHECK-NEXT: } +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, sum); +//CHECK-NEXT: sum += elem; +//CHECK-NEXT: } +//CHECK-NEXT: _d_sum += 1; +//CHECK-NEXT: for (; _t0; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: { +//CHECK-NEXT: _d___begin1--; +//CHECK-NEXT: elem = clad::pop(_t2); +//CHECK-NEXT: _d_elem = clad::pop(_t3); +//CHECK-NEXT: } +//CHECK-NEXT: sum = clad::pop(_t1); +//CHECK-NEXT: double _r_d0 = _d_sum; +//CHECK-NEXT: _d_elem += _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: *_d___begin1 += _d_elem; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: *_d_x += _d_range[1]; +//CHECK-NEXT: *_d_y += _d_range[3]; +//CHECK-NEXT: } +//CHECK-NEXT: } + +double fn38(double x) { + double res = 0; + auto &&range = {1, 2, 3}; + for (auto i = range.begin(); i != range.end(); i++) { + res += x * (*i); + } + return res; +} + +//CHECK: void fn38_grad(double x, double *_d_x) { +//CHECK-NEXT: int *_d_i = 0; +//CHECK-NEXT: const int *i = 0; +//CHECK-NEXT: clad::tape _t1 = {}; +//CHECK-NEXT: double _d_res = 0; +//CHECK-NEXT: double res = 0; +//CHECK-NEXT: clad::array _d_range = {{3U|3UL}}; +//CHECK-NEXT: clad::array range = {1, 2, 3}; +//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}}; +//CHECK-NEXT: _d_i = std::begin(_d_range); +//CHECK-NEXT: for (i = std::begin(range); ; _d_i++ , i++) { +//CHECK-NEXT: { +//CHECK-NEXT: if (!(i != range.end())) +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: _t0++; +//CHECK-NEXT: clad::push(_t1, res); +//CHECK-NEXT: res += x * (*i); +//CHECK-NEXT: } +//CHECK-NEXT: _d_res += 1; +//CHECK-NEXT: for (;; _t0--) { +//CHECK-NEXT: { +//CHECK-NEXT: if (!_t0) +//CHECK-NEXT: break; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: i--; +//CHECK-NEXT: _d_i--; +//CHECK-NEXT: } +//CHECK-NEXT: { +//CHECK-NEXT: res = clad::pop(_t1); +//CHECK-NEXT: double _r_d0 = _d_res; +//CHECK-NEXT: *_d_x += _r_d0 * (*i); +//CHECK-NEXT: *_d_i += x * _r_d0; +//CHECK-NEXT: } +//CHECK-NEXT: } +//CHECK-NEXT: } + #define TEST(F, x) { \ result[0] = 0; \ auto F##grad = clad::gradient(F);\ @@ -2965,6 +3072,8 @@ int main() { TEST_2(fn34, 2, 2); // CHECK-EXEC: {64.00, 32.00} TEST_2(fn35, 1, 1); // CHECK-EXEC: {1.89, 0.00} TEST_2(fn36, 6, 3); // CHECK-EXEC: {1.00, 1.00} + TEST_2(fn37, 6, 3); // CHECK-EXEC: {1.00, 1.00} + TEST(fn38, 9); // CHECK-EXEC: {6.00} } //CHECK: void sq_pullback(double x, double _d_y, double *_d_x) {