From 74904d8926b740e26b04237cfe6655b8ba5131c4 Mon Sep 17 00:00:00 2001 From: Parth Date: Wed, 27 Apr 2022 17:14:54 +0530 Subject: [PATCH] Fix tests --- .../clad/Differentiator/DerivativeBuilder.h | 2 +- include/clad/Differentiator/DiffMode.h | 2 +- ...Visitor.h => ReverseModeForwPassVisitor.h} | 15 ++-- lib/Differentiator/CMakeLists.txt | 2 +- lib/Differentiator/DerivativeBuilder.cpp | 6 +- ...tor.cpp => ReverseModeForwPassVisitor.cpp} | 81 +++++++++++-------- lib/Differentiator/ReverseModeVisitor.cpp | 19 ++--- test/Gradient/Assignments.C | 67 ++++++++------- 8 files changed, 106 insertions(+), 88 deletions(-) rename include/clad/Differentiator/{TransformSourceFnVisitor.h => ReverseModeForwPassVisitor.h} (79%) rename lib/Differentiator/{TransformSourceFnVisitor.cpp => ReverseModeForwPassVisitor.cpp} (87%) diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index 6e5c2f3ff..ca4ef6be9 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -70,7 +70,7 @@ namespace clad { friend class ReverseModeVisitor; friend class HessianModeVisitor; friend class JacobianModeVisitor; - friend class TransformSourceFnVisitor; + friend class ReverseModeForwPassVisitor; clang::Sema& m_Sema; plugin::CladPlugin& m_CladPlugin; diff --git a/include/clad/Differentiator/DiffMode.h b/include/clad/Differentiator/DiffMode.h index fe2b38071..e435bbb42 100644 --- a/include/clad/Differentiator/DiffMode.h +++ b/include/clad/Differentiator/DiffMode.h @@ -8,7 +8,7 @@ namespace clad { experimental_pushforward, experimental_pullback, reverse, - reverse_source_fn, + reverse_mode_forward_pass, hessian, jacobian, error_estimation diff --git a/include/clad/Differentiator/TransformSourceFnVisitor.h b/include/clad/Differentiator/ReverseModeForwPassVisitor.h similarity index 79% rename from include/clad/Differentiator/TransformSourceFnVisitor.h rename to include/clad/Differentiator/ReverseModeForwPassVisitor.h index e256ff883..d0bb1212a 100644 --- a/include/clad/Differentiator/TransformSourceFnVisitor.h +++ b/include/clad/Differentiator/ReverseModeForwPassVisitor.h @@ -1,8 +1,8 @@ #ifndef CLAD_TRANSFORM_SOURCE_FN_VISITOR_H #define CLAD_TRANSFORM_SOURCE_FN_VISITOR_H -#include "clad/Differentiator/ReverseModeVisitor.h" #include "clad/Differentiator/ParseDiffArgsTypes.h" +#include "clad/Differentiator/ReverseModeVisitor.h" #include "clang/AST/StmtVisitor.h" #include "clang/Sema/Sema.h" @@ -10,23 +10,22 @@ #include "llvm/ADT/SmallVector.h" namespace clad { -class TransformSourceFnVisitor - : public ReverseModeVisitor { +class ReverseModeForwPassVisitor : public ReverseModeVisitor { private: Stmts m_Globals; llvm::SmallVector ComputeParamTypes(const DiffParams& diffParams); clang::QualType ComputeReturnType(); - llvm::SmallVector - BuildParams(DiffParams& diffParams); - clang::QualType GetParameterDerivativeType(clang::QualType yType, clang::QualType xType); + llvm::SmallVector BuildParams(DiffParams& diffParams); + clang::QualType GetParameterDerivativeType(clang::QualType yType, + clang::QualType xType); public: - TransformSourceFnVisitor(DerivativeBuilder& builder); + ReverseModeForwPassVisitor(DerivativeBuilder& builder); OverloadedDeclWithContext Derive(const clang::FunctionDecl* FD, const DiffRequest& request); - + StmtDiff ProcessSingleStmt(const clang::Stmt* S); StmtDiff VisitStmt(const clang::Stmt* S) override; diff --git a/lib/Differentiator/CMakeLists.txt b/lib/Differentiator/CMakeLists.txt index 6a67f059a..530bdd61e 100644 --- a/lib/Differentiator/CMakeLists.txt +++ b/lib/Differentiator/CMakeLists.txt @@ -37,8 +37,8 @@ add_llvm_library(cladDifferentiator HessianModeVisitor.cpp JacobianModeVisitor.cpp MultiplexExternalRMVSource.cpp + ReverseModeForwPassVisitor.cpp ReverseModeVisitor.cpp - TransformSourceFnVisitor.cpp ErrorEstimator.cpp EstimationModel.cpp StmtClone.cpp diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index d04f4e806..344f65cf5 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -11,7 +11,7 @@ #include "clad/Differentiator/HessianModeVisitor.h" #include "clad/Differentiator/JacobianModeVisitor.h" #include "clad/Differentiator/ReverseModeVisitor.h" -#include "clad/Differentiator/TransformSourceFnVisitor.h" +#include "clad/Differentiator/ReverseModeForwPassVisitor.h" #include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/StmtClone.h" @@ -173,8 +173,8 @@ namespace clad { } else if (request.Mode == DiffMode::experimental_pullback) { ReverseModeVisitor V(*this); result = V.DerivePullback(FD, request); - } else if (request.Mode == DiffMode::reverse_source_fn) { - TransformSourceFnVisitor V(*this); + } else if (request.Mode == DiffMode::reverse_mode_forward_pass) { + ReverseModeForwPassVisitor V(*this); result = V.Derive(FD, request); } else if (request.Mode == DiffMode::hessian) { diff --git a/lib/Differentiator/TransformSourceFnVisitor.cpp b/lib/Differentiator/ReverseModeForwPassVisitor.cpp similarity index 87% rename from lib/Differentiator/TransformSourceFnVisitor.cpp rename to lib/Differentiator/ReverseModeForwPassVisitor.cpp index fed42efeb..86ebb3161 100644 --- a/lib/Differentiator/TransformSourceFnVisitor.cpp +++ b/lib/Differentiator/ReverseModeForwPassVisitor.cpp @@ -1,4 +1,4 @@ -#include "clad/Differentiator/TransformSourceFnVisitor.h" +#include "clad/Differentiator/ReverseModeForwPassVisitor.h" #include "clad/Differentiator/CladUtils.h" #include "clad/Differentiator/DiffPlanner.h" @@ -12,23 +12,24 @@ using namespace clang; namespace clad { -TransformSourceFnVisitor::TransformSourceFnVisitor(DerivativeBuilder& builder) +ReverseModeForwPassVisitor::ReverseModeForwPassVisitor( + DerivativeBuilder& builder) : ReverseModeVisitor(builder) {} OverloadedDeclWithContext -TransformSourceFnVisitor::Derive(const FunctionDecl* FD, - const DiffRequest& request) { +ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, + const DiffRequest& request) { silenceDiags = !request.VerboseDiags; m_Function = FD; - m_Mode = DiffMode::reverse_source_fn; + m_Mode = DiffMode::reverse_mode_forward_pass; assert(m_Function && "Must not be null."); DiffParams args{}; std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); - auto fnName = m_Function->getNameAsString() + "_with_adjoint"; + auto fnName = m_Function->getNameAsString() + "_forw"; auto fnDNI = utils::BuildDeclarationNameInfo(m_Sema, fnName); auto paramTypes = ComputeParamTypes(args); @@ -86,8 +87,9 @@ TransformSourceFnVisitor::Derive(const FunctionDecl* FD, // FIXME: This function is copied from ReverseModeVisitor. Find a suitable place // for it. -QualType TransformSourceFnVisitor::GetParameterDerivativeType(QualType yType, - QualType xType) { +QualType +ReverseModeForwPassVisitor::GetParameterDerivativeType(QualType yType, + QualType xType) { assert(yType.getNonReferenceType()->isRealType() && "yType should be a builtin-numerical scalar type!!"); QualType xValueType = utils::GetValueType(xType); @@ -101,7 +103,7 @@ QualType TransformSourceFnVisitor::GetParameterDerivativeType(QualType yType, } llvm::SmallVector -TransformSourceFnVisitor::ComputeParamTypes(const DiffParams& diffParams) { +ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) { llvm::SmallVector paramTypes; paramTypes.reserve(m_Function->getNumParams() * 2); for (auto PVD : m_Function->parameters()) @@ -129,7 +131,7 @@ TransformSourceFnVisitor::ComputeParamTypes(const DiffParams& diffParams) { return paramTypes; } -clang::QualType TransformSourceFnVisitor::ComputeReturnType() { +clang::QualType ReverseModeForwPassVisitor::ComputeReturnType() { auto valAndAdjointTempDecl = GetCladClassDecl("ValueAndAdjoint"); auto RT = m_Function->getReturnType(); auto T = GetCladClassOfType(valAndAdjointTempDecl, {RT, RT}); @@ -137,7 +139,7 @@ clang::QualType TransformSourceFnVisitor::ComputeReturnType() { } llvm::SmallVector -TransformSourceFnVisitor::BuildParams(DiffParams& diffParams) { +ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) { llvm::SmallVector params, paramDerivatives; params.reserve(m_Function->getNumParams() + diffParams.size()); auto derivativeFnType = cast(m_Derivative->getType()); @@ -207,17 +209,17 @@ TransformSourceFnVisitor::BuildParams(DiffParams& diffParams) { return params; } -StmtDiff TransformSourceFnVisitor::ProcessSingleStmt(const clang::Stmt* S) { +StmtDiff ReverseModeForwPassVisitor::ProcessSingleStmt(const clang::Stmt* S) { StmtDiff SDiff = Visit(S); return {SDiff.getStmt()}; } -StmtDiff TransformSourceFnVisitor::VisitStmt(const clang::Stmt* S) { +StmtDiff ReverseModeForwPassVisitor::VisitStmt(const clang::Stmt* S) { return {Clone(S)}; } StmtDiff -TransformSourceFnVisitor::VisitCompoundStmt(const clang::CompoundStmt* CS) { +ReverseModeForwPassVisitor::VisitCompoundStmt(const clang::CompoundStmt* CS) { beginScope(Scope::DeclScope); beginBlock(); for (Stmt* S : CS->body()) { @@ -229,7 +231,7 @@ TransformSourceFnVisitor::VisitCompoundStmt(const clang::CompoundStmt* CS) { return {forward}; } -StmtDiff TransformSourceFnVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { +StmtDiff ReverseModeForwPassVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { DeclRefExpr* clonedDRE = nullptr; // Check if referenced Decl was "replaced" with another identifier inside // the derivative @@ -265,17 +267,17 @@ StmtDiff TransformSourceFnVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { } StmtDiff -TransformSourceFnVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) { +ReverseModeForwPassVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) { const Expr* value = RS->getRetValue(); auto returnDiff = Visit(value); llvm::SmallVector returnArgs = {returnDiff.getExpr(), returnDiff.getExpr_dx()}; - Expr* returnInitList = m_Sema.BuildInitList(noLoc, returnArgs, noLoc).get(); + Expr* returnInitList = m_Sema.ActOnInitList(noLoc, returnArgs, noLoc).get(); Stmt* newRS = m_Sema.BuildReturnStmt(noLoc, returnInitList).get(); return {newRS}; } -// StmtDiff TransformSourceFnVisitor::VisitDeclStmt(const DeclStmt* DS) { +// StmtDiff ReverseModeForwPassVisitor::VisitDeclStmt(const DeclStmt* DS) { // llvm::SmallVector decls, derivedDecls; // for (auto D : DS->decls()) { // if (auto VD = dyn_cast(D)) { @@ -292,42 +294,48 @@ TransformSourceFnVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) { // } // } -// VarDeclDiff TransformSourceFnVisitor::DifferentiateVarDecl(const VarDecl* VD) { +// VarDeclDiff ReverseModeForwPassVisitor::DifferentiateVarDecl(const VarDecl* +// VD) { // StmtDiff initDiff; // Expr* VDDerivedInit = nullptr; // auto VDDerivedType = VD->getType(); // bool isVDRefType = VD->getType()->isReferenceType(); // VarDecl* VDDerived = nullptr; - + // if (auto VDCAT = dyn_cast(VD->getType())) { // assert("Should not reach here!!!"); // // VDDerivedType = -// // GetCladArrayOfType(QualType(VDCAT->getPointeeOrArrayElementType(), -// // VDCAT->getIndexTypeCVRQualifiers())); +// // GetCladArrayOfType(QualType(VDCAT->getPointeeOrArrayElementType(), +// // VDCAT->getIndexTypeCVRQualifiers())); // // VDDerivedInit = ConstantFolder::synthesizeLiteral( -// // m_Context.getSizeType(), m_Context, VDCAT->getSize().getZExtValue()); -// // VDDerived = BuildVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(), +// // m_Context.getSizeType(), m_Context, +// VDCAT->getSize().getZExtValue()); +// // VDDerived = BuildVarDecl(VDDerivedType, "_d_" + +// VD->getNameAsString(), // // VDDerivedInit, false, nullptr, -// // clang::VarDecl::InitializationStyle::CallInit); +// // clang::VarDecl::InitializationStyle::CallInit); // } else { -// // If VD is a reference to a local variable, then the initial value is set +// // If VD is a reference to a local variable, then the initial value is +// set // // to the derived variable of the corresponding local variable. -// // If VD is a reference to a non-local variable (global variable, struct +// // If VD is a reference to a non-local variable (global variable, +// struct // // member etc), then no derived variable is available, thus `VDDerived` // // does not need to reference any variable, consequentially the -// // `VDDerivedType` is the corresponding non-reference type and the initial +// // `VDDerivedType` is the corresponding non-reference type and the +// initial // // value is set to 0. // // Otherwise, for non-reference types, the initial value is set to 0. // VDDerivedInit = getZeroInit(VD->getType()); // // `specialThisDiffCase` is only required for correctly differentiating -// // the following code: +// // the following code: // // ``` // // Class _d_this_obj; // // Class* _d_this = &_d_this_obj; // // ``` // // Computation of hessian requires this code to be correctly -// // differentiated. +// // differentiated. // bool specialThisDiffCase = false; // if (auto MD = dyn_cast(m_Function)) { // if (VDDerivedType->isPointerType() && MD->isInstance()) { @@ -361,7 +369,8 @@ TransformSourceFnVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) { // m_Context.getTrivialTypeSourceInfo(VDDerivedType), // VD->getInitStyle()); // else -// VDDerived = BuildVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(), +// VDDerived = BuildVarDecl(VDDerivedType, "_d_" + +// VD->getNameAsString(), // VDDerivedInit); // } @@ -370,12 +379,14 @@ TransformSourceFnVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) { // // If `VD` is a reference to a non-local variable then also there's no // // need to call `Visit` since non-local variables are not differentiated. // if (!isVDRefType) { -// initDiff = VD->getInit() ? Visit(VD->getInit(), BuildDeclRef(VDDerived)) +// initDiff = VD->getInit() ? Visit(VD->getInit(), +// BuildDeclRef(VDDerived)) // : StmtDiff{}; -// // If we are differentiating `VarDecl` corresponding to a local variable +// // If we are differentiating `VarDecl` corresponding to a local +// variable // // inside a loop, then we need to reset it to 0 at each iteration. -// // +// // // // for example, if defined inside a loop, // // ``` // // double localVar = i; @@ -385,7 +396,7 @@ TransformSourceFnVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) { // // { // // *_d_i += _d_localVar; // // _d_localVar = 0; -// // } +// // } // if (isInsideLoop) { // Stmt* assignToZero = BuildOp(BinaryOperatorKind::BO_Assign, // BuildDeclRef(VDDerived), diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 0f4744451..58f77998a 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1704,7 +1704,7 @@ namespace clad { if (FD->getReturnType()->isReferenceType()) { DiffRequest transformReq; transformReq.Function = FD; - transformReq.Mode = DiffMode::reverse_source_fn; + transformReq.Mode = DiffMode::reverse_mode_forward_pass; transformReq.BaseFunctionName = FD->getNameAsString(); transformReq.VerboseDiags = true; FunctionDecl* transformedSourcefn = plugin::ProcessDiffRequest(m_CladPlugin, transformReq); @@ -2207,8 +2207,13 @@ namespace clad { } if (isVDRefType) { + QualType T = utils::GetValueType(VD->getType()); + T.removeLocalConst(); VDDerivedType = - m_Context.getPointerType(VDDerivedType.getNonReferenceType()); + m_Context.getPointerType(T); + initDiff = Visit(VD->getInit()); + if (!initDiff.getExpr_dx()) + VDDerivedType = VDDerivedType.getNonReferenceType(); VDDerivedInit = getZeroInit(VDDerivedType); } @@ -2221,15 +2226,11 @@ namespace clad { // ``` // Computation of hessian requires this code to be correctly // differentiated. - if (isVDRefType || specialThisDiffCase) { + if (specialThisDiffCase) { VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema); initDiff = Visit(VD->getInit()); - // if (initDiff.getExpr_dx()) - // VDDerivedInit = initDiff.getExpr_dx(); - // else - // VDDerivedType = VDDerivedType.getNonReferenceType(); - if (!initDiff.getExpr_dx()) - VDDerivedType = VDDerivedType.getNonReferenceType(); + if (initDiff.getExpr_dx()) + VDDerivedInit = initDiff.getExpr_dx(); } // Here separate behaviour for record and non-record types is only // necessary to preserve the old tests. diff --git a/test/Gradient/Assignments.C b/test/Gradient/Assignments.C index 0b1bc1fb2..c19cf3cb4 100644 --- a/test/Gradient/Assignments.C +++ b/test/Gradient/Assignments.C @@ -580,10 +580,11 @@ double f14(double i, double j) { } // CHECK: void f14_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { -// CHECK-NEXT: double &_d_a = * _d_i; +// CHECK-NEXT: double *_d_a = 0; // CHECK-NEXT: double _t0; // CHECK-NEXT: double _t1; // CHECK-NEXT: double _t2; +// CHECK-NEXT: _d_a = &* _d_i; // CHECK-NEXT: double &a = i; // CHECK-NEXT: _t0 = i; // CHECK-NEXT: a = 2 * _t0; @@ -596,24 +597,24 @@ double f14(double i, double j) { // CHECK-NEXT: _label0: // CHECK-NEXT: * _d_i += 1; // CHECK-NEXT: { -// CHECK-NEXT: double _r_d2 = _d_a; -// CHECK-NEXT: _d_a += _r_d2 * _t1; +// CHECK-NEXT: double _r_d2 = *_d_a; +// CHECK-NEXT: *_d_a += _r_d2 * _t1; // CHECK-NEXT: double _r2 = _t2 * _r_d2; // CHECK-NEXT: * _d_i += _r2; -// CHECK-NEXT: _d_a -= _r_d2; +// CHECK-NEXT: *_d_a -= _r_d2; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: double _r_d1 = _d_a; -// CHECK-NEXT: _d_a += _r_d1; +// CHECK-NEXT: double _r_d1 = *_d_a; +// CHECK-NEXT: *_d_a += _r_d1; // CHECK-NEXT: * _d_i += _r_d1; -// CHECK-NEXT: _d_a -= _r_d1; +// CHECK-NEXT: *_d_a -= _r_d1; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: double _r_d0 = _d_a; +// CHECK-NEXT: double _r_d0 = *_d_a; // CHECK-NEXT: double _r0 = _r_d0 * _t0; // CHECK-NEXT: double _r1 = 2 * _r_d0; // CHECK-NEXT: * _d_i += _r1; -// CHECK-NEXT: _d_a -= _r_d0; +// CHECK-NEXT: *_d_a -= _r_d0; // CHECK-NEXT: } // CHECK-NEXT: } @@ -633,9 +634,9 @@ double f15(double i, double j) { // CHECK-NEXT: double _t0; // CHECK-NEXT: double _t1; // CHECK-NEXT: double _d_b = 0; -// CHECK-NEXT: double &_d_a = _d_b; -// CHECK-NEXT: double &_d_c = * _d_i; -// CHECK-NEXT: double &_d_d = * _d_j; +// CHECK-NEXT: double *_d_a = 0; +// CHECK-NEXT: double *_d_c = 0; +// CHECK-NEXT: double *_d_d = 0; // CHECK-NEXT: double _t2; // CHECK-NEXT: double _t3; // CHECK-NEXT: double _t4; @@ -646,8 +647,11 @@ double f15(double i, double j) { // CHECK-NEXT: _t1 = i; // CHECK-NEXT: _t0 = j; // CHECK-NEXT: double b = _t1 * _t0; +// CHECK-NEXT: _d_a = &_d_b; // CHECK-NEXT: double &a = b; +// CHECK-NEXT: _d_c = &* _d_i; // CHECK-NEXT: double &c = i; +// CHECK-NEXT: _d_d = &* _d_j; // CHECK-NEXT: double &d = j; // CHECK-NEXT: _t3 = a; // CHECK-NEXT: _t2 = i; @@ -664,26 +668,26 @@ double f15(double i, double j) { // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { -// CHECK-NEXT: _d_a += 1; -// CHECK-NEXT: _d_c += 1; -// CHECK-NEXT: _d_d += 1; +// CHECK-NEXT: *_d_a += 1; +// CHECK-NEXT: *_d_c += 1; +// CHECK-NEXT: *_d_d += 1; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: double _r_d3 = _d_d; -// CHECK-NEXT: _d_d += _r_d3 * _t6; +// CHECK-NEXT: double _r_d3 = *_d_d; +// CHECK-NEXT: *_d_d += _r_d3 * _t6; // CHECK-NEXT: double _r7 = _t7 * _r_d3; // CHECK-NEXT: double _r8 = _r7 * _t8; // CHECK-NEXT: double _r9 = 3 * _r7; // CHECK-NEXT: * _d_j += _r9; -// CHECK-NEXT: _d_d -= _r_d3; +// CHECK-NEXT: *_d_d -= _r_d3; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: double _r_d2 = _d_c; -// CHECK-NEXT: _d_c += _r_d2; +// CHECK-NEXT: double _r_d2 = *_d_c; +// CHECK-NEXT: *_d_c += _r_d2; // CHECK-NEXT: double _r5 = _r_d2 * _t5; // CHECK-NEXT: double _r6 = 3 * _r_d2; // CHECK-NEXT: * _d_i += _r6; -// CHECK-NEXT: _d_c -= _r_d2; +// CHECK-NEXT: *_d_c -= _r_d2; // CHECK-NEXT: } // CHECK-NEXT: { // CHECK-NEXT: double _r_d1 = _d_b; @@ -694,11 +698,11 @@ double f15(double i, double j) { // CHECK-NEXT: _d_b -= _r_d1; // CHECK-NEXT: } // CHECK-NEXT: { -// CHECK-NEXT: double _r_d0 = _d_a; -// CHECK-NEXT: _d_a += _r_d0 * _t2; +// CHECK-NEXT: double _r_d0 = *_d_a; +// CHECK-NEXT: *_d_a += _r_d0 * _t2; // CHECK-NEXT: double _r2 = _t3 * _r_d0; // CHECK-NEXT: * _d_i += _r2; -// CHECK-NEXT: _d_a -= _r_d0; +// CHECK-NEXT: *_d_a -= _r_d0; // CHECK-NEXT: } // CHECK-NEXT: { // CHECK-NEXT: double _r0 = _d_b * _t0; @@ -717,14 +721,17 @@ double f16(double i, double j) { } // CHECK: void f16_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { -// CHECK-NEXT: double &_d_a = * _d_i; -// CHECK-NEXT: double &_d_b = _d_a; -// CHECK-NEXT: double &_d_c = _d_b; +// CHECK-NEXT: double *_d_a = 0; +// CHECK-NEXT: double *_d_b = 0; +// CHECK-NEXT: double *_d_c = 0; // CHECK-NEXT: double _t0; // CHECK-NEXT: double _t1; // CHECK-NEXT: double _t2; +// CHECK-NEXT: _d_a = &* _d_i; // CHECK-NEXT: double &a = i; +// CHECK-NEXT: _d_b = &*_d_a; // CHECK-NEXT: double &b = a; +// CHECK-NEXT: _d_c = &*_d_b; // CHECK-NEXT: double &c = b; // CHECK-NEXT: _t1 = c; // CHECK-NEXT: _t2 = j; @@ -735,13 +742,13 @@ double f16(double i, double j) { // CHECK-NEXT: _label0: // CHECK-NEXT: * _d_i += 1; // CHECK-NEXT: { -// CHECK-NEXT: double _r_d0 = _d_c; -// CHECK-NEXT: _d_c += _r_d0 * _t0; +// CHECK-NEXT: double _r_d0 = *_d_c; +// CHECK-NEXT: *_d_c += _r_d0 * _t0; // CHECK-NEXT: double _r0 = _t1 * _r_d0; // CHECK-NEXT: double _r1 = _r0 * _t2; // CHECK-NEXT: double _r2 = 4 * _r0; // CHECK-NEXT: * _d_j += _r2; -// CHECK-NEXT: _d_c -= _r_d0; +// CHECK-NEXT: *_d_c -= _r_d0; // CHECK-NEXT: } // CHECK-NEXT: }