Skip to content

Commit

Permalink
Add support for std::initializer_list in the reverse mode. Fixes vgva…
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Aug 8, 2024
1 parent 618f9c7 commit 3e060b2
Show file tree
Hide file tree
Showing 3 changed files with 153 additions and 9 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,8 @@ namespace clad {
StmtDiff VisitDoStmt(const clang::DoStmt* DS);
StmtDiff VisitContinueStmt(const clang::ContinueStmt* CS);
StmtDiff VisitBreakStmt(const clang::BreakStmt* BS);
StmtDiff
VisitCXXStdInitializerListExpr(const clang::CXXStdInitializerListExpr* ILE);
StmtDiff VisitCXXThisExpr(const clang::CXXThisExpr* CTE);
StmtDiff VisitCXXNewExpr(const clang::CXXNewExpr* CNE);
StmtDiff VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE);
Expand Down
51 changes: 42 additions & 9 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(enzymeCall);
}
}

StmtDiff ReverseModeVisitor::VisitCXXStdInitializerListExpr(
const clang::CXXStdInitializerListExpr* ILE) {
return Visit(ILE->getSubExpr(), dfdx());
}

StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) {
diag(
DiagnosticsEngine::Warning, S->getBeginLoc(),
Expand Down Expand Up @@ -1472,7 +1478,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// global. Ref-type declarations cannot be moved to the function global
// scope because they can't be separated from their inits.
if (DRE->getDecl()->getType()->isReferenceType() &&
!VD->getType()->isReferenceType())
VD->getType()->isPointerType())
clonedDRE = BuildOp(UO_Deref, clonedDRE);
if (m_DiffReq.Mode == DiffMode::jacobian) {
if (m_VectorOutput.size() <= outputArrayCursor)
Expand Down Expand Up @@ -2706,11 +2712,37 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VDCloneType = CloneType(VD->getType());
VDDerivedType = getNonConstType(VDCloneType, m_Context, m_Sema);
}
bool isDerivativeOfRefType = VD->getType()->isReferenceType();

bool isRefType = VD->getType()->isLValueReferenceType();
VarDecl* VDDerived = nullptr;
bool isPointerType = VD->getType()->isPointerType();
bool isInitializedByNewExpr = false;
bool initializeDerivedVar = true;
std::string typeName;
if (const auto* RT =
utils::GetValueType(VD->getType())->getAs<RecordType>())
typeName = RT->getDecl()->getNameAsString();
if (typeName == "initializer_list") {
if (VD->getInit()) {
if (const auto* CXXILE = dyn_cast<CXXStdInitializerListExpr>(
VD->getInit()->IgnoreImplicit())) {
if (const auto* ILE = dyn_cast<InitListExpr>(
CXXILE->getSubExpr()->IgnoreImplicit())) {
VDDerivedType = GetCladArrayOfType((*ILE->getInits())->getType());
unsigned numInits = ILE->getNumInits();
VDDerivedInit = ConstantFolder::synthesizeLiteral(
m_Context.getSizeType(), m_Context, numInits);
VDCloneType = VDDerivedType;
}
} else if (isRefType) {
initDiff = Visit(VD->getInit());
VDDerivedInit = BuildOp(UO_AddrOf, initDiff.getExpr_dx());
VDDerivedType = VDDerivedInit->getType();
VDCloneType = VDDerivedType;
}
}
}

// Check if the variable is pointer type and initialized by new expression
if (isPointerType && VD->getInit() && isa<CXXNewExpr>(VD->getInit()))
isInitializedByNewExpr = true;
Expand Down Expand Up @@ -2741,7 +2773,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// `VDDerivedType` is the corresponding non-reference type and the initial
// value is set to 0.
// Otherwise, for non-reference types, the initial value is set to 0.
VDDerivedInit = getZeroInit(VD->getType());
if (!VDDerivedInit)
VDDerivedInit = getZeroInit(VD->getType());

// `specialThisDiffCase` is only required for correctly differentiating
// the following code:
Expand All @@ -2758,14 +2791,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}

if (isDerivativeOfRefType) {
if (isRefType) {
initDiff = Visit(VD->getInit());
if (!initDiff.getForwSweepExpr_dx()) {
VDDerivedType =
ComputeAdjointType(VD->getType().getNonReferenceType());
isDerivativeOfRefType = false;
isRefType = false;
}
if (promoteToFnScope || !isDerivativeOfRefType)
if (promoteToFnScope || !isRefType)
VDDerivedInit = getZeroInit(VDDerivedType);
else
VDDerivedInit = initDiff.getForwSweepExpr_dx();
Expand Down Expand Up @@ -2822,7 +2855,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// differentiated and should not be differentiated again.
// If `VD` is a reference to a non-local variable then also there's no
// need to call `Visit` since non-local variables are not differentiated.
if (!isDerivativeOfRefType && (!isPointerType || isInitializedByNewExpr)) {
if (!isRefType && (!isPointerType || isInitializedByNewExpr)) {
Expr* derivedE = nullptr;

if (!clad::utils::hasNonDifferentiableAttribute(VD)) {
Expand Down Expand Up @@ -2870,7 +2903,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// FIXME: Add extra parantheses if derived variable pointer is pointing to a
// class type object.
if (isDerivativeOfRefType && promoteToFnScope) {
if (isRefType && promoteToFnScope) {
Expr* assignDerivativeE =
BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE,
BuildOp(UnaryOperatorKind::UO_AddrOf,
Expand All @@ -2891,7 +2924,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// ->
// double* ref;
// ref = &x;
if (isDerivativeOfRefType && promoteToFnScope)
if (isRefType && promoteToFnScope)
VDClone = BuildGlobalVarDecl(
VDCloneType, VD->getNameAsString(),
BuildOp(UnaryOperatorKind::UO_AddrOf, initDiff.getExpr()),
Expand Down
109 changes: 109 additions & 0 deletions test/Gradient/Loops.C
Original file line number Diff line number Diff line change
Expand Up @@ -2876,6 +2876,113 @@ double fn36(double x, double y) {
//CHECK-NEXT: }
//CHECK-NEXT: }

double fn37(double x, double y) {
auto&& range = {1., x, 2., y, 3.};
double sum = 0;
for (auto elem : range)
sum += elem;
return sum;
}

//CHECK: void fn37_grad(double x, double y, double *_d_x, double *_d_y) {
//CHECK-NEXT: unsigned {{int|long}} _t0;
//CHECK-NEXT: clad::array<double> *_d___range1 = 0;
//CHECK-NEXT: clad::array<double> *__range10 = {};
//CHECK-NEXT: double *_d___begin1 = 0;
//CHECK-NEXT: const double *__begin10 = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: clad::tape<double> _t2 = {};
//CHECK-NEXT: clad::tape<double> _t3 = {};
//CHECK-NEXT: clad::array<double> _d_range = {{5U|5UL}};
//CHECK-NEXT: clad::array<double> range = {1., x, 2., y, 3.};
//CHECK-NEXT: double _d_sum = 0;
//CHECK-NEXT: double sum = 0;
//CHECK-NEXT: _t0 = {{0U|0UL}};
//CHECK-NEXT: _d___range1 = &_d_range;
//CHECK-NEXT: _d___begin1 = std::begin(*_d___range1);
//CHECK-NEXT: __range10 = &range;
//CHECK-NEXT: __begin10 = std::begin(*__range10);
//CHECK-NEXT: const double *__end10 = std::end(*__range10);
//CHECK-NEXT: double _d_elem = 0;
//CHECK-NEXT: double elem = 0;
//CHECK-NEXT: for (; __begin10 != __end10; ++__begin10 , ++_d___begin1) {
//CHECK-NEXT: {
//CHECK-NEXT: _d_elem = *_d___begin1;
//CHECK-NEXT: elem = *__begin10;
//CHECK-NEXT: clad::push(_t2, elem);
//CHECK-NEXT: clad::push(_t3, _d_elem);
//CHECK-NEXT: }
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, sum);
//CHECK-NEXT: sum += elem;
//CHECK-NEXT: }
//CHECK-NEXT: _d_sum += 1;
//CHECK-NEXT: for (; _t0; _t0--) {
//CHECK-NEXT: {
//CHECK-NEXT: {
//CHECK-NEXT: _d___begin1--;
//CHECK-NEXT: elem = clad::pop(_t2);
//CHECK-NEXT: _d_elem = clad::pop(_t3);
//CHECK-NEXT: }
//CHECK-NEXT: sum = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_sum;
//CHECK-NEXT: _d_elem += _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: *_d___begin1 += _d_elem;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: *_d_x += _d_range[1];
//CHECK-NEXT: *_d_y += _d_range[3];
//CHECK-NEXT: }
//CHECK-NEXT: }

double fn38(double x) {
double res = 0;
auto &&range = {1, 2, 3};
for (auto i = range.begin(); i != range.end(); i++) {
res += x * (*i);
}
return res;
}

//CHECK: void fn38_grad(double x, double *_d_x) {
//CHECK-NEXT: int *_d_i = 0;
//CHECK-NEXT: const int *i = 0;
//CHECK-NEXT: clad::tape<double> _t1 = {};
//CHECK-NEXT: double _d_res = 0;
//CHECK-NEXT: double res = 0;
//CHECK-NEXT: clad::array<int> _d_range = {{3U|3UL}};
//CHECK-NEXT: clad::array<int> range = {1, 2, 3};
//CHECK-NEXT: unsigned {{int|long}} _t0 = {{0U|0UL}};
//CHECK-NEXT: _d_i = std::begin(_d_range);
//CHECK-NEXT: for (i = std::begin(range); ; _d_i++ , i++) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!(i != range.end()))
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: _t0++;
//CHECK-NEXT: clad::push(_t1, res);
//CHECK-NEXT: res += x * (*i);
//CHECK-NEXT: }
//CHECK-NEXT: _d_res += 1;
//CHECK-NEXT: for (;; _t0--) {
//CHECK-NEXT: {
//CHECK-NEXT: if (!_t0)
//CHECK-NEXT: break;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: i--;
//CHECK-NEXT: _d_i--;
//CHECK-NEXT: }
//CHECK-NEXT: {
//CHECK-NEXT: res = clad::pop(_t1);
//CHECK-NEXT: double _r_d0 = _d_res;
//CHECK-NEXT: *_d_x += _r_d0 * (*i);
//CHECK-NEXT: *_d_i += x * _r_d0;
//CHECK-NEXT: }
//CHECK-NEXT: }
//CHECK-NEXT: }

#define TEST(F, x) { \
result[0] = 0; \
auto F##grad = clad::gradient(F);\
Expand Down Expand Up @@ -2963,6 +3070,8 @@ int main() {
TEST_2(fn34, 5, 2); // CHECK-EXEC: {12.00, 7.00}
TEST_2(fn35, 1, 1); // CHECK-EXEC: {1.89, 0.00}
TEST_2(fn36, 6, 3); // CHECK-EXEC: {1.00, 1.00}
TEST_2(fn37, 6, 3); // CHECK-EXEC: {1.00, 1.00}
TEST(fn38, 9); // CHECK-EXEC: {6.00}
}

//CHECK: void sq_pullback(double x, double _d_y, double *_d_x) {
Expand Down

0 comments on commit 3e060b2

Please sign in to comment.