From e2b5e5f6b32a6baf37769dd34598096afc313963 Mon Sep 17 00:00:00 2001 From: Parth Date: Sun, 10 Apr 2022 21:45:09 +0530 Subject: [PATCH] Add support for diff of ref return types in rev mode --- .github/workflows/ci.yml | 20 +- VERSION | 2 +- docs/internalDocs/ReleaseNotes.md | 10 +- docs/userDocs/source/user/UsingVectorMode.rst | 4 + environment.yml | 2 +- .../clad/Differentiator/DerivativeBuilder.h | 2 +- include/clad/Differentiator/DiffMode.h | 1 + include/clad/Differentiator/Differentiator.h | 6 + .../ReverseModeForwPassVisitor.h | 38 ++ .../clad/Differentiator/ReverseModeVisitor.h | 10 +- lib/Differentiator/CMakeLists.txt | 1 + lib/Differentiator/DerivativeBuilder.cpp | 5 + .../ReverseModeForwPassVisitor.cpp | 420 ++++++++++++++++++ lib/Differentiator/ReverseModeVisitor.cpp | 105 ++++- requirements.txt | 2 +- test/Gradient/FunctionCalls.C | 88 ++++ test/Gradient/MemberFunctions.C | 37 ++ test/Gradient/UserDefinedTypes.C | 1 + test/Misc/CladArray.C | 4 +- 19 files changed, 717 insertions(+), 41 deletions(-) create mode 100644 include/clad/Differentiator/ReverseModeForwPassVisitor.h create mode 100644 lib/Differentiator/ReverseModeForwPassVisitor.cpp diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 31947982b..ef1cd6bb5 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -38,11 +38,13 @@ jobs: chmod +x git-clang-format - name: Run git-clang-format run: | + PR_BASE=$(git rev-list ${{ github.event.pull_request.head.sha }} ^${{ github.event.pull_request.base.sha }} | tail --lines 1 | xargs -I {} git rev-parse {}~1) + echo "running git clang-format against $PR_BASE commit" git \ -c color.ui=always \ -c diff.wsErrorHighlight=all \ -c color.diff.whitespace='red reverse' \ - clang-format-15 --diff --binary clang-format-15 origin/master -- demos/ include/ lib/ tools/ || \ + clang-format-15 --diff --binary clang-format-15 --commit $PR_BASE -- demos/ include/ lib/ tools/ || \ (echo "Please run the following git-clang-format locally to fix the formatting: \n git clang-format origin/master -- demos/ include/ lib/ tools/" && exit 1) build: @@ -82,12 +84,12 @@ jobs: os: macos-latest compiler: clang clang-runtime: '14' - + - name: osx-clang-runtime15 os: macos-latest compiler: clang clang-runtime: '15' - + - name: osx-clang-runtime16 os: macos-latest compiler: clang @@ -414,7 +416,7 @@ jobs: os: ubuntu-22.04 compiler: clang-15 clang-runtime: '14' - + - name: ubu22-clang15-runtime15 os: ubuntu-22.04 compiler: clang-15 @@ -614,15 +616,15 @@ jobs: echo "PATH_TO_LLVM_BUILD=$env:PATH_TO_LLVM_BUILD" >> $env:GITHUB_ENV - name: Setup CUDA 8 on Linux if: ${{ matrix.cuda == true }} - run: | - wget --no-verbose https://developer.nvidia.com/compute/cuda/8.0/Prod2/local_installers/cuda_8.0.61_375.26_linux-run + run: | + wget --no-verbose https://developer.nvidia.com/compute/cuda/8.0/Prod2/local_installers/cuda_8.0.61_375.26_linux-run wget --no-verbose https://developer.nvidia.com/compute/cuda/8.0/Prod2/patches/2/cuda_8.0.61.2_linux-run sh ./cuda_8.0.61_375.26_linux-run --tar mxvf sudo cp InstallUtils.pm /usr/lib/x86_64-linux-gnu/perl-base export $PERL5LIB - sudo sh cuda_8.0.61_375.26_linux-run --override --no-opengl-lib --silent --toolkit --kernel-source-path=/lib/modules/4.15.0-1113-azure/build - sudo sh cuda_8.0.61.2_linux-run --silent --accept-eula - export PATH=/usr/local/cuda-8.0/bin:${PATH} + sudo sh cuda_8.0.61_375.26_linux-run --override --no-opengl-lib --silent --toolkit --kernel-source-path=/lib/modules/4.15.0-1113-azure/build + sudo sh cuda_8.0.61.2_linux-run --silent --accept-eula + export PATH=/usr/local/cuda-8.0/bin:${PATH} export LD_LIBRARY_PATH=${LD_LIBRARY_PATH}:/usr/local/cuda-8.0/lib64 echo "LD_LIBRARY_PATH=$LD_LIBRARY_PATH" >> $GITHUB_ENV echo "PATH=$PATH" >> $GITHUB_ENV diff --git a/VERSION b/VERSION index 70f303689..e7ad5767c 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.2~dev +1.3~dev diff --git a/docs/internalDocs/ReleaseNotes.md b/docs/internalDocs/ReleaseNotes.md index fb0443720..4d2b1aa0a 100644 --- a/docs/internalDocs/ReleaseNotes.md +++ b/docs/internalDocs/ReleaseNotes.md @@ -2,7 +2,7 @@ Introduction ============ This document contains the release notes for the automatic differentiation -plugin for clang Clad, release 1.2. Clad is built on top of +plugin for clang Clad, release 1.3. Clad is built on top of [Clang](http://clang.llvm.org) and [LLVM](http://llvm.org>) compiler infrastructure. Here we describe the status of Clad in some detail, including major improvements from the previous release and new feature work. @@ -11,7 +11,7 @@ Note that if you are reading this file from a git checkout, this document applies to the *next* release, not the current one. -What's New in Clad 1.2? +What's New in Clad 1.3? ======================== Some of the major new features and improvements to Clad are listed here. Generic @@ -21,7 +21,7 @@ described first. External Dependencies --------------------- -* Clad now works with clang-5.0 to clang-15 +* Clad now works with clang-5.0 to clang-16 Forward Mode & Reverse Mode @@ -54,7 +54,7 @@ Fixed Bugs [XXX](https://github.com/vgvassilev/clad/issues/XXX) Special Kudos @@ -68,6 +68,6 @@ FirstName LastName (#commits) A B (N) diff --git a/docs/userDocs/source/user/UsingVectorMode.rst b/docs/userDocs/source/user/UsingVectorMode.rst index 57a70b250..215ea2120 100644 --- a/docs/userDocs/source/user/UsingVectorMode.rst +++ b/docs/userDocs/source/user/UsingVectorMode.rst @@ -1,6 +1,10 @@ Using Vector Mode for Differentiation ************************************** +.. note:: + This feature is still under development and may result in unexpected + behavior. Please report any issues you find. + For forward mode AD, the restriction is that the function can be only be differentiated with respect to a single input variable. However, in many cases, it is desirable to differentiate a function with respect to multiple input diff --git a/environment.yml b/environment.yml index ec74684cd..a22e24ccb 100644 --- a/environment.yml +++ b/environment.yml @@ -1,5 +1,5 @@ channels: - conda-forge dependencies: - - clad=1.0 + - clad=0.9 - xeus-cling diff --git a/include/clad/Differentiator/DerivativeBuilder.h b/include/clad/Differentiator/DerivativeBuilder.h index 39c31691c..7ef3aad1f 100644 --- a/include/clad/Differentiator/DerivativeBuilder.h +++ b/include/clad/Differentiator/DerivativeBuilder.h @@ -77,7 +77,7 @@ namespace clad { friend class ReverseModeVisitor; friend class HessianModeVisitor; friend class JacobianModeVisitor; - + friend class ReverseModeForwPassVisitor; clang::Sema& m_Sema; plugin::CladPlugin& m_CladPlugin; clang::ASTContext& m_Context; diff --git a/include/clad/Differentiator/DiffMode.h b/include/clad/Differentiator/DiffMode.h index a9c27a935..a03e77e49 100644 --- a/include/clad/Differentiator/DiffMode.h +++ b/include/clad/Differentiator/DiffMode.h @@ -11,6 +11,7 @@ enum class DiffMode { reverse, hessian, jacobian, + reverse_mode_forward_pass, error_estimation }; } diff --git a/include/clad/Differentiator/Differentiator.h b/include/clad/Differentiator/Differentiator.h index dca7e2f06..9d189c76f 100644 --- a/include/clad/Differentiator/Differentiator.h +++ b/include/clad/Differentiator/Differentiator.h @@ -20,6 +20,12 @@ #include namespace clad { + template + struct ValueAndAdjoint { + T value; + U adjoint; + }; + /// \returns the size of a c-style string CUDA_HOST_DEVICE unsigned int GetLength(const char* code) { unsigned int count; diff --git a/include/clad/Differentiator/ReverseModeForwPassVisitor.h b/include/clad/Differentiator/ReverseModeForwPassVisitor.h new file mode 100644 index 000000000..a00df1b50 --- /dev/null +++ b/include/clad/Differentiator/ReverseModeForwPassVisitor.h @@ -0,0 +1,38 @@ +#ifndef CLAD_TRANSFORM_SOURCE_FN_VISITOR_H +#define CLAD_TRANSFORM_SOURCE_FN_VISITOR_H + +#include "clad/Differentiator/ParseDiffArgsTypes.h" +#include "clad/Differentiator/ReverseModeVisitor.h" + +#include "clang/AST/StmtVisitor.h" +#include "clang/Sema/Sema.h" + +#include "llvm/ADT/SmallVector.h" + +namespace clad { +class ReverseModeForwPassVisitor : public ReverseModeVisitor { +private: + Stmts m_Globals; + + llvm::SmallVector + ComputeParamTypes(const DiffParams& diffParams); + clang::QualType ComputeReturnType(); + llvm::SmallVector BuildParams(DiffParams& diffParams); + clang::QualType GetParameterDerivativeType(clang::QualType yType, + clang::QualType xType); + +public: + ReverseModeForwPassVisitor(DerivativeBuilder& builder); + DerivativeAndOverload Derive(const clang::FunctionDecl* FD, + const DiffRequest& request); + + StmtDiff ProcessSingleStmt(const clang::Stmt* S); + + StmtDiff VisitStmt(const clang::Stmt* S) override; + StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS) override; + StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE) override; + StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override; +}; +} // namespace clad + +#endif \ No newline at end of file diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 3b668ff87..2b05c5bd5 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -31,7 +31,7 @@ namespace clad { : public clang::ConstStmtVisitor, public VisitorBase { - private: + protected: // FIXME: We should remove friend-dependency of the plugin classes here. // For this we will need to separate out AST related functions in // a separate namespace, as well as add getters/setters function of @@ -321,11 +321,11 @@ namespace clad { StmtDiff VisitArraySubscriptExpr(const clang::ArraySubscriptExpr* ASE); StmtDiff VisitBinaryOperator(const clang::BinaryOperator* BinOp); StmtDiff VisitCallExpr(const clang::CallExpr* CE); - StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS); + virtual StmtDiff VisitCompoundStmt(const clang::CompoundStmt* CS); StmtDiff VisitConditionalOperator(const clang::ConditionalOperator* CO); StmtDiff VisitCXXBoolLiteralExpr(const clang::CXXBoolLiteralExpr* BL); StmtDiff VisitCXXDefaultArgExpr(const clang::CXXDefaultArgExpr* DE); - StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE); + virtual StmtDiff VisitDeclRefExpr(const clang::DeclRefExpr* DRE); StmtDiff VisitDeclStmt(const clang::DeclStmt* DS); StmtDiff VisitFloatingLiteral(const clang::FloatingLiteral* FL); StmtDiff VisitForStmt(const clang::ForStmt* FS); @@ -335,8 +335,8 @@ namespace clad { StmtDiff VisitIntegerLiteral(const clang::IntegerLiteral* IL); StmtDiff VisitMemberExpr(const clang::MemberExpr* ME); StmtDiff VisitParenExpr(const clang::ParenExpr* PE); - StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS); - StmtDiff VisitStmt(const clang::Stmt* S); + virtual StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS); + virtual StmtDiff VisitStmt(const clang::Stmt* S); StmtDiff VisitUnaryOperator(const clang::UnaryOperator* UnOp); StmtDiff VisitExprWithCleanups(const clang::ExprWithCleanups* EWC); /// Decl is not Stmt, so it cannot be visited directly. diff --git a/lib/Differentiator/CMakeLists.txt b/lib/Differentiator/CMakeLists.txt index 61695990e..f5eddb2c6 100644 --- a/lib/Differentiator/CMakeLists.txt +++ b/lib/Differentiator/CMakeLists.txt @@ -31,6 +31,7 @@ add_llvm_library(cladDifferentiator HessianModeVisitor.cpp JacobianModeVisitor.cpp MultiplexExternalRMVSource.cpp + ReverseModeForwPassVisitor.cpp ReverseModeVisitor.cpp StmtClone.cpp VectorForwardModeVisitor.cpp diff --git a/lib/Differentiator/DerivativeBuilder.cpp b/lib/Differentiator/DerivativeBuilder.cpp index d5a15580b..a67f6ed44 100644 --- a/lib/Differentiator/DerivativeBuilder.cpp +++ b/lib/Differentiator/DerivativeBuilder.cpp @@ -22,6 +22,8 @@ #include "clad/Differentiator/HessianModeVisitor.h" #include "clad/Differentiator/JacobianModeVisitor.h" #include "clad/Differentiator/ReverseModeVisitor.h" +#include "clad/Differentiator/ReverseModeForwPassVisitor.h" +#include "clad/Differentiator/DiffPlanner.h" #include "clad/Differentiator/StmtClone.h" #include "clad/Differentiator/VectorForwardModeVisitor.h" @@ -230,6 +232,9 @@ namespace clad { result = V.DerivePullback(FD, request); if (!m_ErrorEstHandler.empty()) CleanupErrorEstimation(m_ErrorEstHandler, m_EstModel); + } else if (request.Mode == DiffMode::reverse_mode_forward_pass) { + ReverseModeForwPassVisitor V(*this); + result = V.Derive(FD, request); } else if (request.Mode == DiffMode::hessian) { HessianModeVisitor H(*this); result = H.Derive(FD, request); diff --git a/lib/Differentiator/ReverseModeForwPassVisitor.cpp b/lib/Differentiator/ReverseModeForwPassVisitor.cpp new file mode 100644 index 000000000..d3f9d8670 --- /dev/null +++ b/lib/Differentiator/ReverseModeForwPassVisitor.cpp @@ -0,0 +1,420 @@ +#include "clad/Differentiator/ReverseModeForwPassVisitor.h" + +#include "clad/Differentiator/CladUtils.h" +#include "clad/Differentiator/DiffPlanner.h" +#include "clad/Differentiator/ErrorEstimator.h" + +#include "llvm/Support/SaveAndRestore.h" + +#include + +using namespace clang; + +namespace clad { + +ReverseModeForwPassVisitor::ReverseModeForwPassVisitor( + DerivativeBuilder& builder) + : ReverseModeVisitor(builder) {} + +DerivativeAndOverload +ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD, + const DiffRequest& request) { + silenceDiags = !request.VerboseDiags; + m_Function = FD; + + m_Mode = DiffMode::reverse_mode_forward_pass; + + assert(m_Function && "Must not be null."); + + DiffParams args{}; + std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args)); + + auto fnName = m_Function->getNameAsString() + "_forw"; + auto fnDNI = utils::BuildDeclarationNameInfo(m_Sema, fnName); + + auto paramTypes = ComputeParamTypes(args); + auto returnType = ComputeReturnType(); + auto sourceFnType = dyn_cast(m_Function->getType()); + auto fnType = m_Context.getFunctionType(returnType, paramTypes, + sourceFnType->getExtProtoInfo()); + + llvm::SaveAndRestore saveContext(m_Sema.CurContext); + llvm::SaveAndRestore saveScope(m_CurScope); + m_Sema.CurContext = const_cast(m_Function->getDeclContext()); + + DeclWithContext fnBuildRes = + m_Builder.cloneFunction(m_Function, *this, m_Sema.CurContext, m_Sema, + m_Context, noLoc, fnDNI, fnType); + m_Derivative = fnBuildRes.first; + + beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | + Scope::DeclScope); + m_Sema.PushFunctionScope(); + m_Sema.PushDeclContext(getCurrentScope(), m_Derivative); + + auto params = BuildParams(args); + m_Derivative->setParams(params); + m_Derivative->setBody(nullptr); + + beginScope(Scope::FnScope | Scope::DeclScope); + m_DerivativeFnScope = getCurrentScope(); + + beginBlock(); + + StmtDiff bodyDiff = Visit(m_Function->getBody()); + Stmt* forward = bodyDiff.getStmt(); + + for (Stmt* S : ReverseModeVisitor::m_Globals) + addToCurrentBlock(S); + + if (auto CS = dyn_cast(forward)) + for (Stmt* S : CS->body()) + addToCurrentBlock(S); + + Stmt* fnBody = endBlock(); + // llvm::errs() << "Derive: dumping fnBody:\n"; + // fnBody->dumpColor(); + m_Derivative->setBody(fnBody); + endScope(); + m_Sema.PopFunctionScopeInfo(); + m_Sema.PopDeclContext(); + endScope(); + // llvm::errs() << "Derive: Dumping m_Derivative:\n"; + // m_Derivative->dumpColor(); + return DerivativeAndOverload{m_Derivative, nullptr}; +} + +// FIXME: This function is copied from ReverseModeVisitor. Find a suitable place +// for it. +QualType +ReverseModeForwPassVisitor::GetParameterDerivativeType(QualType yType, + QualType xType) { + assert(yType.getNonReferenceType()->isRealType() && + "yType should be a builtin-numerical scalar type!!"); + QualType xValueType = utils::GetValueType(xType); + // derivative variables should always be of non-const type. + xValueType.removeLocalConst(); + QualType nonRefXValueType = xValueType.getNonReferenceType(); + if (nonRefXValueType->isRealType()) + return GetCladArrayRefOfType(yType); + else + return GetCladArrayRefOfType(nonRefXValueType); +} + +llvm::SmallVector +ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) { + llvm::SmallVector paramTypes; + paramTypes.reserve(m_Function->getNumParams() * 2); + for (auto PVD : m_Function->parameters()) + paramTypes.push_back(PVD->getType()); + + QualType effectiveReturnType = + m_Function->getReturnType().getNonReferenceType(); + + if (auto MD = dyn_cast(m_Function)) { + const CXXRecordDecl* RD = MD->getParent(); + if (MD->isInstance() && !RD->isLambda()) { + QualType thisType = clad_compat::CXXMethodDecl_getThisType(m_Sema, MD); + paramTypes.push_back( + GetParameterDerivativeType(effectiveReturnType, thisType)); + } + } + + for (auto PVD : m_Function->parameters()) { + const auto *it = std::find(std::begin(diffParams), std::end(diffParams), PVD); + if (it != std::end(diffParams)) { + paramTypes.push_back( + GetParameterDerivativeType(effectiveReturnType, PVD->getType())); + } + } + return paramTypes; +} + +clang::QualType ReverseModeForwPassVisitor::ComputeReturnType() { + auto *valAndAdjointTempDecl = LookupTemplateDeclInCladNamespace("ValueAndAdjoint"); + auto RT = m_Function->getReturnType(); + auto T = InstantiateTemplate(valAndAdjointTempDecl, {RT, RT}); + return T; +} + +llvm::SmallVector +ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) { + llvm::SmallVector params; + llvm::SmallVector paramDerivatives; + params.reserve(m_Function->getNumParams() + diffParams.size()); + const auto *derivativeFnType = cast(m_Derivative->getType()); + + std::size_t dParamTypesIdx = m_Function->getNumParams(); + + if (auto MD = dyn_cast(m_Function)) { + const CXXRecordDecl* RD = MD->getParent(); + if (MD->isInstance() && !RD->isLambda()) { + auto thisDerivativePVD = utils::BuildParmVarDecl( + m_Sema, m_Derivative, CreateUniqueIdentifier("_d_this"), + derivativeFnType->getParamType(dParamTypesIdx)); + paramDerivatives.push_back(thisDerivativePVD); + + if (thisDerivativePVD->getIdentifier()) + m_Sema.PushOnScopeChains(thisDerivativePVD, getCurrentScope(), + /*AddToContext=*/false); + + Expr* deref = + BuildOp(UnaryOperatorKind::UO_Deref, BuildDeclRef(thisDerivativePVD)); + m_ThisExprDerivative = utils::BuildParenExpr(m_Sema, deref); + ++dParamTypesIdx; + } + } + for (auto PVD : m_Function->parameters()) { + // FIXME: Call expression may contain default arguments that we are now + // removing. This may cause issues. + auto newPVD = utils::BuildParmVarDecl( + m_Sema, m_Derivative, PVD->getIdentifier(), PVD->getType(), + PVD->getStorageClass(), /*DefArg=*/nullptr, PVD->getTypeSourceInfo()); + params.push_back(newPVD); + + if (newPVD->getIdentifier()) + m_Sema.PushOnScopeChains(newPVD, getCurrentScope(), + /*AddToContext=*/false); + + auto *it = std::find(std::begin(diffParams), std::end(diffParams), PVD); + if (it != std::end(diffParams)) { + *it = newPVD; + QualType dType = derivativeFnType->getParamType(dParamTypesIdx); + IdentifierInfo* dII = + CreateUniqueIdentifier("_d_" + PVD->getNameAsString()); + auto *dPVD = utils::BuildParmVarDecl(m_Sema, m_Derivative, dII, dType, + PVD->getStorageClass()); + paramDerivatives.push_back(dPVD); + ++dParamTypesIdx; + + if (dPVD->getIdentifier()) + m_Sema.PushOnScopeChains(dPVD, getCurrentScope(), + /*AddToContext=*/false); + + if (utils::isArrayOrPointerType(PVD->getType())) { + m_Variables[*it] = (Expr*)BuildDeclRef(dPVD); + } else { + QualType valueType = DetermineCladArrayValueType(dPVD->getType()); + m_Variables[*it] = + BuildOp(UO_Deref, BuildDeclRef(dPVD), m_Function->getLocation()); + // Add additional paranthesis if derivative is of record type + // because `*derivative.someField` will be incorrectly evaluated if + // the derived function is compiled standalone. + if (valueType->isRecordType()) + m_Variables[*it] = utils::BuildParenExpr(m_Sema, m_Variables[*it]); + } + } + } + params.insert(params.end(), paramDerivatives.begin(), paramDerivatives.end()); + return params; +} + +StmtDiff ReverseModeForwPassVisitor::ProcessSingleStmt(const clang::Stmt* S) { + StmtDiff SDiff = Visit(S); + return {SDiff.getStmt()}; +} + +StmtDiff ReverseModeForwPassVisitor::VisitStmt(const clang::Stmt* S) { + return {Clone(S)}; +} + +StmtDiff +ReverseModeForwPassVisitor::VisitCompoundStmt(const clang::CompoundStmt* CS) { + beginScope(Scope::DeclScope); + beginBlock(); + for (Stmt* S : CS->body()) { + StmtDiff SDiff = ProcessSingleStmt(S); + addToCurrentBlock(SDiff.getStmt()); + } + CompoundStmt* forward = endBlock(); + endScope(); + return {forward}; +} + +StmtDiff ReverseModeForwPassVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) { + DeclRefExpr* clonedDRE = nullptr; + // Check if referenced Decl was "replaced" with another identifier inside + // the derivative + if (const auto *VD = dyn_cast(DRE->getDecl())) { + auto it = m_DeclReplacements.find(VD); + if (it != std::end(m_DeclReplacements)) + clonedDRE = BuildDeclRef(it->second); + else + clonedDRE = cast(Clone(DRE)); + // If current context is different than the context of the original + // declaration (e.g. we are inside lambda), rebuild the DeclRefExpr + // with Sema::BuildDeclRefExpr. This is required in some cases, e.g. + // Sema::BuildDeclRefExpr is responsible for adding captured fields + // to the underlying struct of a lambda. + if (clonedDRE->getDecl()->getDeclContext() != m_Sema.CurContext) { + auto *referencedDecl = cast(clonedDRE->getDecl()); + clonedDRE = cast(BuildDeclRef(referencedDecl)); + } + } else + clonedDRE = cast(Clone(DRE)); + + if (auto *decl = dyn_cast(clonedDRE->getDecl())) { + // Check DeclRefExpr is a reference to an independent variable. + auto it = m_Variables.find(decl); + if (it == std::end(m_Variables)) { + // Is not an independent variable, ignored. + return StmtDiff(clonedDRE); + } + return StmtDiff(clonedDRE, it->second); + } + + return StmtDiff(clonedDRE); +} + +StmtDiff +ReverseModeForwPassVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) { + const Expr* value = RS->getRetValue(); + auto returnDiff = Visit(value); + llvm::SmallVector returnArgs = {returnDiff.getExpr(), + returnDiff.getExpr_dx()}; + Expr* returnInitList = m_Sema.ActOnInitList(noLoc, returnArgs, noLoc).get(); + Stmt* newRS = m_Sema.BuildReturnStmt(noLoc, returnInitList).get(); + return {newRS}; +} + +// StmtDiff ReverseModeForwPassVisitor::VisitDeclStmt(const DeclStmt* DS) { +// llvm::SmallVector decls, derivedDecls; +// for (auto D : DS->decls()) { +// if (auto VD = dyn_cast(D)) { +// VarDeclDiff VDDiff = DifferentiateVarDecl(VD); + +// if (VDDiff.getDecl()->getDeclName() != VD->getDeclName()) +// m_DeclReplacements[VD] = VDDiff.getDecl(); +// decls.push_back(VDDiff.getDecl()); +// derivedDecls.push_back(VDDiff.getDecl_dx()); +// } else { +// diag(DiagnosticsEngine::Warning, D->getEndLoc(), +// "Unsupported declaration"); +// } +// } +// } + +// VarDeclDiff ReverseModeForwPassVisitor::DifferentiateVarDecl(const VarDecl* +// VD) { +// StmtDiff initDiff; +// Expr* VDDerivedInit = nullptr; +// auto VDDerivedType = VD->getType(); +// bool isVDRefType = VD->getType()->isReferenceType(); +// VarDecl* VDDerived = nullptr; + +// if (auto VDCAT = dyn_cast(VD->getType())) { +// assert("Should not reach here!!!"); +// // VDDerivedType = +// // GetCladArrayOfType(QualType(VDCAT->getPointeeOrArrayElementType(), +// // VDCAT->getIndexTypeCVRQualifiers())); +// // VDDerivedInit = ConstantFolder::synthesizeLiteral( +// // m_Context.getSizeType(), m_Context, +// VDCAT->getSize().getZExtValue()); +// // VDDerived = BuildVarDecl(VDDerivedType, "_d_" + +// VD->getNameAsString(), +// // VDDerivedInit, false, nullptr, +// // clang::VarDecl::InitializationStyle::CallInit); +// } else { +// // If VD is a reference to a local variable, then the initial value is +// set +// // to the derived variable of the corresponding local variable. +// // If VD is a reference to a non-local variable (global variable, +// struct +// // member etc), then no derived variable is available, thus `VDDerived` +// // does not need to reference any variable, consequentially the +// // `VDDerivedType` is the corresponding non-reference type and the +// initial +// // value is set to 0. +// // Otherwise, for non-reference types, the initial value is set to 0. +// VDDerivedInit = getZeroInit(VD->getType()); + +// // `specialThisDiffCase` is only required for correctly differentiating +// // the following code: +// // ``` +// // Class _d_this_obj; +// // Class* _d_this = &_d_this_obj; +// // ``` +// // Computation of hessian requires this code to be correctly +// // differentiated. +// bool specialThisDiffCase = false; +// if (auto MD = dyn_cast(m_Function)) { +// if (VDDerivedType->isPointerType() && MD->isInstance()) { +// specialThisDiffCase = true; +// } +// } + +// // FIXME: Remove the special cases introduced by `specialThisDiffCase` +// // once reverse mode supports pointers. `specialThisDiffCase` is only +// // required for correctly differentiating the following code: +// // ``` +// // Class _d_this_obj; +// // Class* _d_this = &_d_this_obj; +// // ``` +// // Computation of hessian requires this code to be correctly +// // differentiated. +// if (isVDRefType || specialThisDiffCase) { +// VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema); +// initDiff = Visit(VD->getInit()); +// if (initDiff.getExpr_dx()) +// VDDerivedInit = initDiff.getExpr_dx(); +// else +// VDDerivedType = VDDerivedType.getNonReferenceType(); +// } +// // Here separate behaviour for record and non-record types is only +// // necessary to preserve the old tests. +// if (VDDerivedType->isRecordType()) +// VDDerived = +// BuildVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(), +// VDDerivedInit, VD->isDirectInit(), +// m_Context.getTrivialTypeSourceInfo(VDDerivedType), +// VD->getInitStyle()); +// else +// VDDerived = BuildVarDecl(VDDerivedType, "_d_" + +// VD->getNameAsString(), +// VDDerivedInit); +// } + +// // If `VD` is a reference to a local variable, then it is already +// // differentiated and should not be differentiated again. +// // If `VD` is a reference to a non-local variable then also there's no +// // need to call `Visit` since non-local variables are not differentiated. +// if (!isVDRefType) { +// initDiff = VD->getInit() ? Visit(VD->getInit(), +// BuildDeclRef(VDDerived)) +// : StmtDiff{}; + +// // If we are differentiating `VarDecl` corresponding to a local +// variable +// // inside a loop, then we need to reset it to 0 at each iteration. +// // +// // for example, if defined inside a loop, +// // ``` +// // double localVar = i; +// // ``` +// // this statement should get differentiated to, +// // ``` +// // { +// // *_d_i += _d_localVar; +// // _d_localVar = 0; +// // } +// if (isInsideLoop) { +// Stmt* assignToZero = BuildOp(BinaryOperatorKind::BO_Assign, +// BuildDeclRef(VDDerived), +// getZeroInit(VDDerivedType)); +// addToCurrentBlock(assignToZero, direction::reverse); +// } +// } +// VarDecl* VDClone = nullptr; +// // Here separate behaviour for record and non-record types is only +// // necessary to preserve the old tests. +// if (VD->getType()->isRecordType()) +// VDClone = BuildVarDecl(VD->getType(), VD->getNameAsString(), +// initDiff.getExpr(), VD->isDirectInit(), +// VD->getTypeSourceInfo(), VD->getInitStyle()); +// else +// VDClone = BuildVarDecl(VD->getType(), VD->getNameAsString(), +// initDiff.getExpr(), VD->isDirectInit()); +// m_Variables.emplace(VDClone, BuildDeclRef(VDDerived)); +// return VarDeclDiff(VDClone, VDDerived); +// } +} // namespace clad \ No newline at end of file diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index 4e9805b64..0d2a3e30f 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -1673,15 +1673,19 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, ArgDeclStmts.push_back(BuildDeclStmt(gradVarDecl)); idx++; } + Expr* pullback = dfdx(); + if ((pullback == nullptr) && FD->getReturnType()->isLValueReferenceType()) + pullback = getZeroInit(FD->getReturnType().getNonReferenceType()); + // FIXME: Remove this restriction. if (!FD->getReturnType()->isVoidType()) { - assert((dfdx() && !FD->getReturnType()->isVoidType()) && + assert((pullback && !FD->getReturnType()->isVoidType()) && "Call to function returning non-void type with no dfdx() is not " "supported!"); } if (FD->getReturnType()->isVoidType()) { - assert(dfdx() == nullptr && FD->getReturnType()->isVoidType() && + assert(pullback == nullptr && FD->getReturnType()->isVoidType() && "Call to function returning void type should not have any " "corresponding dfdx()."); } @@ -1691,9 +1695,9 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, DerivedCallOutputArgs.end()); pullbackCallArgs = DerivedCallArgs; - if (dfdx()) + if (pullback) pullbackCallArgs.insert(pullbackCallArgs.begin() + CE->getNumArgs(), - dfdx()); + pullback); // Try to find it in builtin derivatives std::string customPullback = FD->getNameAsString() + "_pullback"; @@ -1857,15 +1861,84 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, std::end(CallArgs), std::begin(CallArgs), [this](Expr* E) { return Clone(E); }); - // Recreate the original call expression. - Expr* call = m_Sema - .ActOnCallExpr(getCurrentScope(), - Clone(CE->getCallee()), - noLoc, - CallArgs, - noLoc) - .get(); - return StmtDiff(call); + + Expr* call = nullptr; + + if (FD->getReturnType()->isReferenceType()) { + DiffRequest calleeFnForwPassReq; + calleeFnForwPassReq.Function = FD; + calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass; + calleeFnForwPassReq.BaseFunctionName = FD->getNameAsString(); + calleeFnForwPassReq.VerboseDiags = true; + FunctionDecl* calleeFnForwPassFD = + plugin::ProcessDiffRequest(m_CladPlugin, calleeFnForwPassReq); + + assert(calleeFnForwPassFD && + "Clad failed to generate callee function forward pass function"); + + // FIXME: We are using the derivatives in forward pass here + // If `expr_dx()` is only meant to be used in reverse pass, + // (for example, `clad::pop(...)` expression and a corresponding + // `clad::push(...)` in the forward pass), then this can result in + // incorrect derivative or crash at runtime. Ideally, we should have + // a separate routine to use derivative in the forward pass. + + // We cannot reuse the derivatives previously computed because + // they might contain 'clad::pop(..)` expression. + if (isa(CE)) { + Expr* derivedBase = baseDiff.getExpr_dx(); + // FIXME: We may need this if-block once we support pointers, and passing pointers-by-reference + // if (isCladArrayType(derivedBase->getType())) + // CallArgs.push_back(derivedBase); + // else + CallArgs.push_back( + BuildOp(UnaryOperatorKind::UO_AddrOf, derivedBase, noLoc)); + } + + for (std::size_t i = 0, e = CE->getNumArgs(); i != e; ++i) { + const Expr* arg = CE->getArg(i); + const ParmVarDecl* PVD = FD->getParamDecl(i); + StmtDiff argDiff = Visit(arg); + if ((argDiff.getExpr_dx() != nullptr) && PVD->getType()->isReferenceType()) { + Expr* derivedArg = argDiff.getExpr_dx(); + // FIXME: We may need this if-block once we support pointers, and passing pointers-by-reference + // if (isCladArrayType(derivedArg->getType())) + // CallArgs.push_back(derivedArg); + // else + CallArgs.push_back( + BuildOp(UnaryOperatorKind::UO_AddrOf, derivedArg, noLoc)); + } else + CallArgs.push_back(m_Sema.ActOnCXXNullPtrLiteral(noLoc).get()); + } + if (isa(CE)) { + Expr* baseE = baseDiff.getExpr(); + call = BuildCallExprToMemFn( + baseE, calleeFnForwPassFD->getName(), CallArgs, calleeFnForwPassFD); + } else { + call = m_Sema + .ActOnCallExpr(getCurrentScope(), + BuildDeclRef(calleeFnForwPassFD), noLoc, + CallArgs, noLoc) + .get(); + } + auto *callRes = StoreAndRef(call); + auto *resValue = + utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value"); + auto *resAdjoint = + utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint"); + return StmtDiff(resValue, nullptr, resAdjoint); + } else { + // Recreate the original call expression. + call = m_Sema + .ActOnCallExpr(getCurrentScope(), + Clone(CE->getCallee()), + noLoc, + CallArgs, + noLoc) + .get(); + return StmtDiff(call); + } + return {}; } StmtDiff ReverseModeVisitor::VisitUnaryOperator(const UnaryOperator* UnOp) { @@ -2312,7 +2385,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (isDerivativeOfRefType) { initDiff = Visit(VD->getInit()); - if (!initDiff.getExpr_dx()) { + if (!initDiff.getForwSweepExpr_dx()) { VDDerivedType = ComputeAdjointType(VD->getType().getNonReferenceType()); isDerivativeOfRefType = false; @@ -3136,7 +3209,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // TODO: Add DiffMode::experimental_pullback support here as well. if (m_Mode == DiffMode::reverse || m_Mode == DiffMode::experimental_pullback) { - QualType effectiveReturnType = m_Function->getReturnType(); + QualType effectiveReturnType = m_Function->getReturnType().getNonReferenceType(); if (m_Mode == DiffMode::experimental_pullback) { // FIXME: Generally, we use the function's return type as the argument's // derivative type. We cannot follow this strategy for `void` function @@ -3150,7 +3223,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, if (effectiveReturnType->isVoidType()) effectiveReturnType = m_Context.DoubleTy; else - paramTypes.push_back(m_Function->getReturnType()); + paramTypes.push_back(effectiveReturnType); } if (auto MD = dyn_cast(m_Function)) { diff --git a/requirements.txt b/requirements.txt index be009e650..ed6c08e9c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,6 +1,6 @@ alabaster==0.7.13 Babel==2.12.1 -certifi==2023.5.7 +certifi==2023.7.22 charset-normalizer==3.1.0 docutils==0.20.1 idna==3.4 diff --git a/test/Gradient/FunctionCalls.C b/test/Gradient/FunctionCalls.C index 725731bc4..c6944c985 100644 --- a/test/Gradient/FunctionCalls.C +++ b/test/Gradient/FunctionCalls.C @@ -351,6 +351,92 @@ double fn6(double i=0, double j=0) { return i*j; } +double& identity(double& i) { + return i; +} + +double fn7(double i, double j) { + double& k = identity(i); + double& l = identity(j); + k += 7*j; + l += 9*i; + return i + j; +} + +// CHECK: void fn6_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _t0; +// CHECK-NEXT: double _t1; +// CHECK-NEXT: _t1 = i; +// CHECK-NEXT: _t0 = j; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: double _r0 = 1 * _t0; +// CHECK-NEXT: * _d_i += _r0; +// CHECK-NEXT: double _r1 = _t1 * 1; +// CHECK-NEXT: * _d_j += _r1; +// CHECK-NEXT: } +// CHECK-NEXT: } +// CHECK: void identity_pullback(double &i, double _d_y, clad::array_ref _d_i) { +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: * _d_i += _d_y; +// CHECK-NEXT: } +// CHECK: clad::ValueAndAdjoint identity_forw(double &i, clad::array_ref _d_i) { +// CHECK-NEXT: return {i, * _d_i}; +// CHECK-NEXT: } +// CHECK: void fn7_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j) { +// CHECK-NEXT: double _t0; +// CHECK-NEXT: double *_d_k = 0; +// CHECK-NEXT: double _t2; +// CHECK-NEXT: double *_d_l = 0; +// CHECK-NEXT: double _t4; +// CHECK-NEXT: double _t5; +// CHECK-NEXT: _t0 = i; +// CHECK-NEXT: clad::ValueAndAdjoint _t1 = identity_forw(i, &* _d_i); +// CHECK-NEXT: _d_k = &_t1.adjoint; +// CHECK-NEXT: double &k = _t1.value; +// CHECK-NEXT: _t2 = j; +// CHECK-NEXT: clad::ValueAndAdjoint _t3 = identity_forw(j, &* _d_j); +// CHECK-NEXT: _d_l = &_t3.adjoint; +// CHECK-NEXT: double &l = _t3.value; +// CHECK-NEXT: _t4 = j; +// CHECK-NEXT: k += 7 * _t4; +// CHECK-NEXT: _t5 = i; +// CHECK-NEXT: l += 9 * _t5; +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: * _d_i += 1; +// CHECK-NEXT: * _d_j += 1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: double _r_d1 = *_d_l; +// CHECK-NEXT: *_d_l += _r_d1; +// CHECK-NEXT: double _r4 = _r_d1 * _t5; +// CHECK-NEXT: double _r5 = 9 * _r_d1; +// CHECK-NEXT: * _d_i += _r5; +// CHECK-NEXT: *_d_l -= _r_d1; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: double _r_d0 = *_d_k; +// CHECK-NEXT: *_d_k += _r_d0; +// CHECK-NEXT: double _r2 = _r_d0 * _t4; +// CHECK-NEXT: double _r3 = 7 * _r_d0; +// CHECK-NEXT: * _d_j += _r3; +// CHECK-NEXT: *_d_k -= _r_d0; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: identity_pullback(_t2, 0, &* _d_j); +// CHECK-NEXT: double _r1 = * _d_j; +// CHECK-NEXT: } +// CHECK-NEXT: { +// CHECK-NEXT: identity_pullback(_t0, 0, &* _d_i); +// CHECK-NEXT: double _r0 = * _d_i; +// CHECK-NEXT: } +// CHECK-NEXT: } + + template void reset(T* arr, int n) { for (int i=0; i _d_i, clad::array_ref _d_j); void const_mem_fn_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j); void volatile_mem_fn_grad(double i, double j, clad::array_ref _d_i, clad::array_ref _d_j); @@ -756,6 +758,35 @@ double fn(double i,double j) { // CHECK-NEXT: } // CHECK-NEXT: } +double fn2(SimpleFunctions& sf, double i) { + return sf.ref_mem_fn(i); +} + +// CHECK: void ref_mem_fn_pullback(double i, double _d_y, clad::array_ref _d_this, clad::array_ref _d_i) { +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: (* _d_this).x += _d_y; +// CHECK-NEXT: } +// CHECK: clad::ValueAndAdjoint ref_mem_fn_forw(double i, clad::array_ref _d_this, clad::array_ref _d_i) { +// CHECK-NEXT: return {this->x, (* _d_this).x}; +// CHECK-NEXT: } +// CHECK: void fn2_grad(SimpleFunctions &sf, double i, clad::array_ref _d_sf, clad::array_ref _d_i) { +// CHECK-NEXT: double _t0; +// CHECK-NEXT: SimpleFunctions _t1; +// CHECK-NEXT: _t0 = i; +// CHECK-NEXT: _t1 = sf; +// CHECK-NEXT: clad::ValueAndAdjoint _t2 = _t1.ref_mem_fn_forw(_t0, &(* _d_sf), nullptr); +// CHECK-NEXT: goto _label0; +// CHECK-NEXT: _label0: +// CHECK-NEXT: { +// CHECK-NEXT: double _grad0 = 0.; +// CHECK-NEXT: _t1.ref_mem_fn_pullback(_t0, 1, &(* _d_sf), &_grad0); +// CHECK-NEXT: double _r0 = _grad0; +// CHECK-NEXT: * _d_i += _r0; +// CHECK-NEXT: } +// CHECK-NEXT: } + + int main() { auto d_mem_fn = clad::gradient(&SimpleFunctions::mem_fn); auto d_const_mem_fn = clad::gradient(&SimpleFunctions::const_mem_fn); @@ -790,6 +821,12 @@ int main() { printf("%.2f ",result[i]); //CHECK-EXEC: 40.00 16.00 } + SimpleFunctions sf(2, 3); + SimpleFunctions d_sf; + auto d_fn2 = clad::gradient(fn2); + d_fn2.execute(sf, 2, &d_sf, &result[0]); + printf("%.2f", result[0]); //CHECK-EXEC: 40.00 + auto d_const_volatile_lval_ref_mem_fn_i = clad::gradient(&SimpleFunctions::const_volatile_lval_ref_mem_fn, "i"); // CHECK: void const_volatile_lval_ref_mem_fn_grad_0(double i, double j, clad::array_ref _d_this, clad::array_ref _d_i) const volatile & { diff --git a/test/Gradient/UserDefinedTypes.C b/test/Gradient/UserDefinedTypes.C index d10b41857..4b1d0bda5 100644 --- a/test/Gradient/UserDefinedTypes.C +++ b/test/Gradient/UserDefinedTypes.C @@ -322,6 +322,7 @@ double fn5(const Tangent& t, double i) { // CHECK-NEXT: _t0 = i; // CHECK-NEXT: _t1 = i; // CHECK-NEXT: _t2 = t; +// CHECK-NEXT: double fn5_return = t.someMemFn2(_t0, _t1); // CHECK-NEXT: goto _label0; // CHECK-NEXT: _label0: // CHECK-NEXT: { diff --git a/test/Misc/CladArray.C b/test/Misc/CladArray.C index 2d82c3cc0..0ec370747 100644 --- a/test/Misc/CladArray.C +++ b/test/Misc/CladArray.C @@ -1,5 +1,5 @@ // RUN: %cladclang %s -I%S/../../include -oCladArray.out 2>&1 -// RUN: ./CladArray.out +// RUN: ./CladArray.out | FileCheck -check-prefix=CHECK-EXEC %s // CHECK-NOT: {{.*error|warning|note:.*}} #include "clad/Differentiator/Differentiator.h" @@ -28,7 +28,7 @@ int main() { for (int i = 0; i < 3; i++) { printf("%d : %d\n", i, test_arr[i]); } - //CHECK-EXEC: 0 : 2 + //CHECK-EXEC: 0 : 1 //CHECK-EXEC: 1 : 2 //CHECK-EXEC: 2 : 3