Skip to content

Commit

Permalink
Add support for std::initializer_list in the reverse mode
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Aug 7, 2024
1 parent 5bd119b commit e9e6ed5
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 9 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,6 +401,8 @@ namespace clad {
StmtDiff VisitDoStmt(const clang::DoStmt* DS);
StmtDiff VisitContinueStmt(const clang::ContinueStmt* CS);
StmtDiff VisitBreakStmt(const clang::BreakStmt* BS);
StmtDiff
VisitCXXStdInitializerListExpr(const clang::CXXStdInitializerListExpr* ILE);
StmtDiff VisitCXXThisExpr(const clang::CXXThisExpr* CTE);
StmtDiff VisitCXXNewExpr(const clang::CXXNewExpr* CNE);
StmtDiff VisitCXXDeleteExpr(const clang::CXXDeleteExpr* CDE);
Expand Down
51 changes: 42 additions & 9 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,12 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
addToCurrentBlock(enzymeCall);
}
}

StmtDiff ReverseModeVisitor::VisitCXXStdInitializerListExpr(
const clang::CXXStdInitializerListExpr* ILE) {
return Visit(ILE->getSubExpr(), dfdx());
}

StmtDiff ReverseModeVisitor::VisitStmt(const Stmt* S) {
diag(
DiagnosticsEngine::Warning, S->getBeginLoc(),
Expand Down Expand Up @@ -1472,7 +1478,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// global. Ref-type declarations cannot be moved to the function global
// scope because they can't be separated from their inits.
if (DRE->getDecl()->getType()->isReferenceType() &&
!VD->getType()->isReferenceType())
VD->getType()->isPointerType())
clonedDRE = BuildOp(UO_Deref, clonedDRE);
if (m_DiffReq.Mode == DiffMode::jacobian) {
if (m_VectorOutput.size() <= outputArrayCursor)
Expand Down Expand Up @@ -2706,11 +2712,37 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
VDCloneType = CloneType(VD->getType());
VDDerivedType = getNonConstType(VDCloneType, m_Context, m_Sema);
}
bool isDerivativeOfRefType = VD->getType()->isReferenceType();

bool isRefType = VD->getType()->isLValueReferenceType();
VarDecl* VDDerived = nullptr;
bool isPointerType = VD->getType()->isPointerType();
bool isInitializedByNewExpr = false;
bool initializeDerivedVar = true;
std::string typeName;
if (const auto* RT =
utils::GetValueType(VD->getType())->getAs<RecordType>())
typeName = RT->getDecl()->getNameAsString();
if (typeName == "initializer_list") {
if (VD->getInit()) {
if (const auto* CXXILE = dyn_cast<CXXStdInitializerListExpr>(
VD->getInit()->IgnoreImplicit())) {
if (const auto* ILE = dyn_cast<InitListExpr>(
CXXILE->getSubExpr()->IgnoreImplicit())) {
VDDerivedType = GetCladArrayOfType((*ILE->getInits())->getType());
unsigned numInits = ILE->getNumInits();
VDDerivedInit = ConstantFolder::synthesizeLiteral(
m_Context.getSizeType(), m_Context, numInits);
VDCloneType = VDDerivedType;
}
} else if (isRefType) {
initDiff = Visit(VD->getInit());
VDDerivedInit = BuildOp(UO_AddrOf, initDiff.getExpr_dx());
VDDerivedType = VDDerivedInit->getType();
VDCloneType = VDDerivedType;
}
}
}

// Check if the variable is pointer type and initialized by new expression
if (isPointerType && VD->getInit() && isa<CXXNewExpr>(VD->getInit()))
isInitializedByNewExpr = true;
Expand Down Expand Up @@ -2741,7 +2773,8 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// `VDDerivedType` is the corresponding non-reference type and the initial
// value is set to 0.
// Otherwise, for non-reference types, the initial value is set to 0.
VDDerivedInit = getZeroInit(VD->getType());
if (!VDDerivedInit)
VDDerivedInit = getZeroInit(VD->getType());

// `specialThisDiffCase` is only required for correctly differentiating
// the following code:
Expand All @@ -2758,14 +2791,14 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
}
}

if (isDerivativeOfRefType) {
if (isRefType) {
initDiff = Visit(VD->getInit());
if (!initDiff.getForwSweepExpr_dx()) {
VDDerivedType =
ComputeAdjointType(VD->getType().getNonReferenceType());
isDerivativeOfRefType = false;
isRefType = false;
}
if (promoteToFnScope || !isDerivativeOfRefType)
if (promoteToFnScope || !isRefType)
VDDerivedInit = getZeroInit(VDDerivedType);
else
VDDerivedInit = initDiff.getForwSweepExpr_dx();
Expand Down Expand Up @@ -2822,7 +2855,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// differentiated and should not be differentiated again.
// If `VD` is a reference to a non-local variable then also there's no
// need to call `Visit` since non-local variables are not differentiated.
if (!isDerivativeOfRefType && (!isPointerType || isInitializedByNewExpr)) {
if (!isRefType && (!isPointerType || isInitializedByNewExpr)) {
Expr* derivedE = nullptr;

if (!clad::utils::hasNonDifferentiableAttribute(VD)) {
Expand Down Expand Up @@ -2870,7 +2903,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// FIXME: Add extra parantheses if derived variable pointer is pointing to a
// class type object.
if (isDerivativeOfRefType && promoteToFnScope) {
if (isRefType && promoteToFnScope) {
Expr* assignDerivativeE =
BuildOp(BinaryOperatorKind::BO_Assign, derivedVDE,
BuildOp(UnaryOperatorKind::UO_AddrOf,
Expand All @@ -2891,7 +2924,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// ->
// double* ref;
// ref = &x;
if (isDerivativeOfRefType && promoteToFnScope)
if (isRefType && promoteToFnScope)
VDClone = BuildGlobalVarDecl(
VDCloneType, VD->getNameAsString(),
BuildOp(UnaryOperatorKind::UO_AddrOf, initDiff.getExpr()),
Expand Down

0 comments on commit e9e6ed5

Please sign in to comment.