From 9ecf289b08cbd3dd43bc9527c4823a6340f61fef Mon Sep 17 00:00:00 2001 From: Vaibhav Thakkar Date: Wed, 26 Jul 2023 16:28:08 +0530 Subject: [PATCH] Generate overload function for vector mode to fix asserts --- include/clad/Differentiator/Array.h | 16 ++ include/clad/Differentiator/FunctionTraits.h | 160 ++++++++------- .../Differentiator/VectorForwardModeVisitor.h | 10 +- include/clad/Differentiator/VisitorBase.h | 36 +++- .../VectorForwardModeVisitor.cpp | 188 +++++++++++++++--- lib/Differentiator/VisitorBase.cpp | 38 +++- test/ForwardMode/VectorMode.C | 46 ++--- test/ForwardMode/VectorModeInterface.C | 13 +- test/Misc/CladArray.C | 8 + 9 files changed, 360 insertions(+), 155 deletions(-) diff --git a/include/clad/Differentiator/Array.h b/include/clad/Differentiator/Array.h index cd194d388..a60d5148d 100644 --- a/include/clad/Differentiator/Array.h +++ b/include/clad/Differentiator/Array.h @@ -246,6 +246,22 @@ template class array { CUDA_HOST_DEVICE operator T*() const { return m_arr; } }; // class array +// Function to instantiate a one-hot array of size n with 1 at index i. +// A one-hot vector is a vector with all elements set to 0 except for one +// element which is set to 1. +// For example, if n=4 and i=2, the returned array is: {0, 0, 1, 0} +template +CUDA_HOST_DEVICE array one_hot_vector(std::size_t n, std::size_t i) { + array arr(n); + arr[i] = 1; + return arr; +} + +// Function to instantiate a zero vector of size n +template CUDA_HOST_DEVICE array zero_vector(std::size_t n) { + return array(n); +} + /// Overloaded operators for clad::array which return a new array. /// Multiplies the number to every element in the array and returns a new diff --git a/include/clad/Differentiator/FunctionTraits.h b/include/clad/Differentiator/FunctionTraits.h index 00f43ab7e..ddc132645 100644 --- a/include/clad/Differentiator/FunctionTraits.h +++ b/include/clad/Differentiator/FunctionTraits.h @@ -37,7 +37,7 @@ namespace clad { /// Placeholder type for denoting no function type exists /// - /// This is used by `ExtractDerivedFnTraitsForwMode` and + /// This is used by `ExtractDerivedFnTraitsForwMode` and /// `ExtractDerivedFnTraits` type trait as value for member typedef /// `type` to denote no function type exists. class NoFunction {}; @@ -45,122 +45,120 @@ namespace clad { // Trait class to deduce return type of function(both member and non-member) at commpile time // Only function pointer types are supported by this trait class - template - struct return_type {}; - template - using return_type_t = typename return_type::type; + template struct return_type {}; + template using return_type_t = typename return_type::type; // specializations for non-member functions pointer types - template + template struct return_type { using type = ReturnType; }; - template + template struct return_type { using type = ReturnType; }; // specializations for member functions pointer types with no qualifiers - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; // specializations for member functions pointer type with only cv-qualifiers - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - // specializations for member functions pointer types with + // specializations for member functions pointer types with // reference qualifiers and with and without cv-qualifiers - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; - template - struct return_type { - using type = ReturnType; + template + struct return_type { + using type = ReturnType; }; template<> @@ -735,7 +733,7 @@ namespace clad { template struct ExtractDerivedFnTraitsVecForwMode { - using type = void (*)(Args..., OutputVecParamType_t...); + using type = void (*)(Args..., OutputVecParamType_t...); }; /// Specialization for free function pointer type diff --git a/include/clad/Differentiator/VectorForwardModeVisitor.h b/include/clad/Differentiator/VectorForwardModeVisitor.h index 4ff26db54..eb65cbed7 100644 --- a/include/clad/Differentiator/VectorForwardModeVisitor.h +++ b/include/clad/Differentiator/VectorForwardModeVisitor.h @@ -32,6 +32,11 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor { DerivativeAndOverload DeriveVectorMode(const clang::FunctionDecl* FD, const DiffRequest& request); + /// Builds an overload for the gradient function that has derived params for + /// all the arguments of the requested function and it calls the original + /// gradient function internally + clang::FunctionDecl* CreateVectorModeOverload(); + /// Builds and returns the sequence of derived function parameters for // vectorized forward mode. /// @@ -49,12 +54,13 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor { /// /// For example: for index = 2 and size = 4, the returned expression /// is: {0, 0, 1, 0} - clang::Expr* getOneHotInitExpr(size_t index, size_t size); + clang::Expr* getOneHotInitExpr(size_t index, size_t size, + clang::QualType type); /// Get an expression used to initialize a zero vector of the given size. /// /// For example: for size = 4, the returned expression is: {0, 0, 0, 0} - clang::Expr* getZeroInitListExpr(size_t size); + clang::Expr* getZeroInitListExpr(size_t size, clang::QualType type); StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override; // Decl is not Stmt, so it cannot be visited directly. diff --git a/include/clad/Differentiator/VisitorBase.h b/include/clad/Differentiator/VisitorBase.h index 15bac04af..e9118481f 100644 --- a/include/clad/Differentiator/VisitorBase.h +++ b/include/clad/Differentiator/VisitorBase.h @@ -49,7 +49,7 @@ namespace clad { clang::Expr* getExpr_dx() { return llvm::cast_or_null(getStmt_dx()); } - + void updateStmt(clang::Stmt* S) { data[1] = S; } void updateStmtDx(clang::Stmt* S) { data[0] = S; } // Stmt_dx goes first! @@ -309,8 +309,10 @@ namespace clad { /// declaration reference expressions. This function builds a declaration /// reference given a declaration. /// \param[in] D The declaration to build a DeclRefExpr for. + /// \param[in] SS The scope specifier for the declaration. /// \returns the DeclRefExpr for the given declaration. - clang::DeclRefExpr* BuildDeclRef(clang::DeclaratorDecl* D); + clang::DeclRefExpr* BuildDeclRef(clang::DeclaratorDecl* D, + const clang::CXXScopeSpec* SS = nullptr); /// Stores the result of an expression in a temporary variable (of the same /// type as is the result of the expression) and returns a reference to it. @@ -456,7 +458,21 @@ namespace clad { clang::Expr* BuildCallExprToFunction(clang::FunctionDecl* FD, llvm::MutableArrayRef argExprs, - bool useRefQualifiedThisObj = false); + bool useRefQualifiedThisObj = false, + const clang::CXXScopeSpec* SS = nullptr); + + /// Build a call to templated free function inside the clad namespace. + /// + /// \param[in] name name of the function + /// \param[in] argExprs function arguments expressions + /// \param[in] templateArgs template arguments + /// \param[in] loc location of the call + /// \returns Built call expression + clang::Expr* BuildCallExprToCladFunction( + llvm::StringRef name, llvm::MutableArrayRef argExprs, + llvm::ArrayRef templateArgs, + clang::SourceLocation loc); + /// Find declaration of clad::array_ref templated type. clang::TemplateDecl* GetCladArrayRefDecl(); /// Create clad::array_ref type. @@ -492,9 +508,8 @@ namespace clad { /// /// \returns The derivative function call. clang::Expr* GetSingleArgCentralDiffCall( - clang::Expr* targetFuncCall, clang::Expr* targetArg, - unsigned targetPos, unsigned numArgs, - llvm::SmallVectorImpl& args); + clang::Expr* targetFuncCall, clang::Expr* targetArg, unsigned targetPos, + unsigned numArgs, llvm::SmallVectorImpl& args); /// A function to get the multi-argument "central_difference" /// call expression for the given arguments. /// @@ -508,17 +523,16 @@ namespace clad { /// /// \returns The derivative function call. clang::Expr* GetMultiArgCentralDiffCall( - clang::Expr* targetFuncCall, clang::QualType retType, - unsigned numArgs, + clang::Expr* targetFuncCall, clang::QualType retType, unsigned numArgs, llvm::SmallVectorImpl& NumericalDiffMultiArg, llvm::SmallVectorImpl& args, llvm::SmallVectorImpl& outputArgs); - /// Emits diagnostic messages on differentiation (or lack thereof) for + /// Emits diagnostic messages on differentiation (or lack thereof) for /// call expressions. /// - /// \param[in] \c funcName The name of the underlying function of the + /// \param[in] \c funcName The name of the underlying function of the /// call expression. - /// \param[in] \c srcLoc Any associated source location information. + /// \param[in] \c srcLoc Any associated source location information. /// \param[in] \c isDerived A flag to determine if differentiation of the /// call expression was successful. void CallExprDiffDiagnostics(llvm::StringRef funcName, diff --git a/lib/Differentiator/VectorForwardModeVisitor.cpp b/lib/Differentiator/VectorForwardModeVisitor.cpp index 904c9647f..82069dc9c 100644 --- a/lib/Differentiator/VectorForwardModeVisitor.cpp +++ b/lib/Differentiator/VectorForwardModeVisitor.cpp @@ -3,6 +3,8 @@ #include "ConstantFolder.h" #include "clad/Differentiator/CladUtils.h" +#include "clang/AST/TemplateName.h" +#include "clang/Sema/Lookup.h" #include "llvm/Support/SaveAndRestore.h" using namespace clang; @@ -99,13 +101,13 @@ VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD, m_IndependentVars[independentVarIndex] == m_Function->getParamDecl(i)) { // This parameter is an independent variable. // Create a one hot vector for the parameter. - dVectorParam = - getOneHotInitExpr(independentVarIndex, m_IndependentVars.size()); + dVectorParam = getOneHotInitExpr(independentVarIndex, + m_IndependentVars.size(), dParamType); ++independentVarIndex; } else { // This parameter is not an independent variable. // Initialize by all zeros. - dVectorParam = getZeroInitListExpr(m_IndependentVars.size()); + dVectorParam = getZeroInitListExpr(m_IndependentVars.size(), dParamType); } // For each function arg to be differentiated, create a variable @@ -139,28 +141,166 @@ VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD, m_Sema.PopDeclContext(); endScope(); // Function decl scope - return DerivativeAndOverload{vectorDiffFD, nullptr}; + // Create the overload declaration for the derivative. + FunctionDecl* overloadFD = CreateVectorModeOverload(); + return DerivativeAndOverload{vectorDiffFD, overloadFD}; +} + +clang::FunctionDecl* VectorForwardModeVisitor::CreateVectorModeOverload() { + auto vectorModeParams = m_Derivative->parameters(); + auto vectorModeNameInfo = m_Derivative->getNameInfo(); + + // Calculate the total number of parameters that would be required for + // automatic differentiation in the derived function if all args are + // requested. + std::size_t totalDerivedParamsSize = m_Function->getNumParams() * 2; + std::size_t numDerivativeParams = m_Function->getNumParams(); + + // Generate the function type for the derivative. + llvm::SmallVector paramTypes; + paramTypes.reserve(totalDerivedParamsSize); + for (auto PVD : m_Function->parameters()) { + paramTypes.push_back(PVD->getType()); + } + + // instantiate output parameter type as void* + QualType outputParamType = m_Context.getPointerType(m_Context.VoidTy); + + // Push param types for derived params. + for (std::size_t i = 0; i < m_Function->getNumParams(); ++i) + paramTypes.push_back(outputParamType); + + auto vectorModeFuncOverloadEPI = + dyn_cast(m_Function->getType())->getExtProtoInfo(); + QualType vectorModeFuncOverloadType = m_Context.getFunctionType( + m_Context.VoidTy, + llvm::ArrayRef(paramTypes.data(), paramTypes.size()), + vectorModeFuncOverloadEPI); + + // Create the function declaration for the derivative. + DeclContext* DC = const_cast(m_Function->getDeclContext()); + m_Sema.CurContext = DC; + DeclWithContext result = + m_Builder.cloneFunction(m_Function, *this, DC, m_Sema, m_Context, noLoc, + vectorModeNameInfo, vectorModeFuncOverloadType); + FunctionDecl* vectorModeOverloadFD = result.first; + + // Function declaration scope + beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope | + Scope::DeclScope); + m_Sema.PushFunctionScope(); + m_Sema.PushDeclContext(getCurrentScope(), vectorModeOverloadFD); + + llvm::SmallVector overloadParams; + overloadParams.reserve(totalDerivedParamsSize); + + llvm::SmallVector callArgs; // arguments to the call the requested + // vectormode function. + callArgs.reserve(vectorModeParams.size()); + + for (auto PVD : m_Function->parameters()) { + auto VD = utils::BuildParmVarDecl( + m_Sema, vectorModeOverloadFD, PVD->getIdentifier(), PVD->getType(), + PVD->getStorageClass(), /*defArg=*/nullptr, PVD->getTypeSourceInfo()); + overloadParams.push_back(VD); + callArgs.push_back(BuildDeclRef(VD)); + } + + for (std::size_t i = 0; i < numDerivativeParams; ++i) { + ParmVarDecl* PVD; + std::size_t effectiveIndex = m_Function->getNumParams() + i; + + if (effectiveIndex < vectorModeParams.size()) { + // This parameter represents an actual derivative parameter. + auto OriginalVD = vectorModeParams[effectiveIndex]; + PVD = utils::BuildParmVarDecl( + m_Sema, vectorModeOverloadFD, + CreateUniqueIdentifier("_temp_" + OriginalVD->getNameAsString()), + outputParamType, OriginalVD->getStorageClass()); + } else { + PVD = utils::BuildParmVarDecl( + m_Sema, vectorModeOverloadFD, + CreateUniqueIdentifier("_d_" + std::to_string(i)), outputParamType, + StorageClass::SC_None); + } + overloadParams.push_back(PVD); + } + + for (auto PVD : overloadParams) { + if (PVD->getIdentifier()) + m_Sema.PushOnScopeChains(PVD, getCurrentScope(), + /*AddToContext=*/false); + } + + vectorModeOverloadFD->setParams(overloadParams); + vectorModeOverloadFD->setBody(/*B=*/nullptr); + + // Create the body of the derivative. + beginScope(Scope::FnScope | Scope::DeclScope); + m_DerivativeFnScope = getCurrentScope(); + beginBlock(); + + // Build derivatives to be used in the call to the actual derived function. + // These are initialised by effectively casting the derivative parameters of + // overloaded derived function to the correct type. + for (std::size_t i = m_Function->getNumParams(); i < vectorModeParams.size(); + ++i) { + auto overloadParam = overloadParams[i]; + auto vectorModeParam = vectorModeParams[i]; + + // Create a cast expression to cast the derivative parameter to the correct + // type. + auto castExpr = + m_Sema + .BuildCXXNamedCast( + noLoc, tok::TokenKind::kw_static_cast, + m_Context.getTrivialTypeSourceInfo(vectorModeParam->getType()), + BuildDeclRef(overloadParam), noLoc, noLoc) + .get(); + auto vectorModeVD = + BuildVarDecl(vectorModeParam->getType(), + vectorModeParam->getNameAsString(), castExpr); + callArgs.push_back(BuildDeclRef(vectorModeVD)); + addToCurrentBlock(BuildDeclStmt(vectorModeVD)); + } + + Expr* callExpr = BuildCallExprToFunction(m_Derivative, callArgs, + /*UseRefQualifiedThisObj=*/true); + addToCurrentBlock(callExpr); + Stmt* vectorModeOverloadBody = endBlock(); + + vectorModeOverloadFD->setBody(vectorModeOverloadBody); + + endScope(); // Function body scope + m_Sema.PopFunctionScopeInfo(); + m_Sema.PopDeclContext(); + endScope(); // Function decl scope + + return vectorModeOverloadFD; } clang::Expr* VectorForwardModeVisitor::getOneHotInitExpr(size_t index, - size_t size) { - // define a vector of size `size` with all elements set to 0, - // except for the element at `index` which is set to 1. - auto zero = - ConstantFolder::synthesizeLiteral(m_Context.DoubleTy, m_Context, 0); - auto one = - ConstantFolder::synthesizeLiteral(m_Context.DoubleTy, m_Context, 1); - llvm::SmallVector oneHotInitList(size, zero); - oneHotInitList[index] = one; - return m_Sema.ActOnInitList(m_Function->getLocation(), llvm::MutableArrayRef(oneHotInitList), m_Function->getLocation()).get(); + size_t size, + clang::QualType type) { + // Build call expression for one_hot + llvm::SmallVector args = { + ConstantFolder::synthesizeLiteral(m_Context.UnsignedLongTy, m_Context, + size), + ConstantFolder::synthesizeLiteral(m_Context.UnsignedLongTy, m_Context, + index)}; + return BuildCallExprToCladFunction("one_hot_vector", args, {type}, + m_Function->getLocation()); } -clang::Expr* VectorForwardModeVisitor::getZeroInitListExpr(size_t size) { +clang::Expr* +VectorForwardModeVisitor::getZeroInitListExpr(size_t size, + clang::QualType type) { // define a vector of size `size` with all elements set to 0. - auto zero = - ConstantFolder::synthesizeLiteral(m_Context.DoubleTy, m_Context, 0); - llvm::SmallVector zeroInitList(size, zero); - return m_Sema.ActOnInitList(m_Function->getLocation(), llvm::MutableArrayRef(zeroInitList), m_Function->getLocation()).get(); + // Build call expression for zero_vector + llvm::SmallVector args = {ConstantFolder::synthesizeLiteral( + m_Context.UnsignedLongTy, m_Context, size)}; + return BuildCallExprToCladFunction("zero_vector", args, {type}, + m_Function->getLocation()); } llvm::SmallVector @@ -196,7 +336,7 @@ VectorForwardModeVisitor::BuildVectorModeParams(DiffParams& diffParams) { m_Sema.PushOnScopeChains(dPVD, getCurrentScope(), /*AddToContext=*/false); - m_ParamVariables[*it] = BuildOp(UO_Deref, BuildDeclRef(dPVD), m_Function->getLocation()); + m_ParamVariables[*it] = BuildOp(UO_Deref, BuildDeclRef(dPVD), noLoc); } // insert the derivative parameters at the end of the parameter list. params.insert(params.end(), paramDerivatives.begin(), paramDerivatives.end()); @@ -231,7 +371,7 @@ StmtDiff VectorForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) { auto dParamValue = m_Sema .ActOnArraySubscriptExpr(getCurrentScope(), dVectorRef, - dVectorRef->getExprLoc(), indexExpr, m_Function->getLocation()) + dVectorRef->getExprLoc(), indexExpr, noLoc) .get(); // Create an assignment expression to assign the ith element of the // return vector to the derivative of the ith parameter. @@ -241,7 +381,7 @@ StmtDiff VectorForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) { } // Add an empty return statement to the array of statements. returnStmts.push_back( - m_Sema.ActOnReturnStmt(m_Function->getLocation(), nullptr, getCurrentScope()).get()); + m_Sema.ActOnReturnStmt(noLoc, nullptr, getCurrentScope()).get()); // Create a return statement from the compound statement. Stmt* returnStmt = MakeCompoundStmt(returnStmts); @@ -276,13 +416,13 @@ VarDeclDiff VectorForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) { .ActOnCXXTypeConstructExpr( OpaquePtr::make( GetCladArrayOfType(utils::GetValueType(VD->getType()))), - m_Function->getLocation(), args, m_Function->getLocation(), false) + noLoc, args, noLoc, false) .get(); VarDecl* VDDerived = BuildVarDecl(GetCladArrayOfType(utils::GetValueType(VD->getType())), "_d_vector_" + VD->getNameAsString(), constructorCallExpr, - true, nullptr, VarDecl::InitializationStyle::CallInit); + false, nullptr, VarDecl::InitializationStyle::CallInit); m_Variables.emplace(VDClone, BuildDeclRef(VDDerived)); return VarDeclDiff(VDClone, VDDerived); diff --git a/lib/Differentiator/VisitorBase.cpp b/lib/Differentiator/VisitorBase.cpp index b44db0aa7..84a142cf9 100644 --- a/lib/Differentiator/VisitorBase.cpp +++ b/lib/Differentiator/VisitorBase.cpp @@ -179,11 +179,12 @@ namespace clad { return new (m_Context) DeclStmt(DGR, noLoc, noLoc); } - DeclRefExpr* VisitorBase::BuildDeclRef(DeclaratorDecl* D) { + DeclRefExpr* VisitorBase::BuildDeclRef(DeclaratorDecl* D, + const CXXScopeSpec* SS) { QualType T = D->getType(); T = T.getNonReferenceType(); return cast(clad_compat::GetResult( - m_Sema.BuildDeclRefExpr(D, T, VK_LValue, noLoc))); + m_Sema.BuildDeclRefExpr(D, T, VK_LValue, noLoc, SS))); } IdentifierInfo* @@ -543,15 +544,14 @@ namespace clad { .get(); } - Expr* - VisitorBase::BuildCallExprToFunction(FunctionDecl* FD, - llvm::MutableArrayRef argExprs, - bool useRefQualifiedThisObj) { + Expr* VisitorBase::BuildCallExprToFunction( + FunctionDecl* FD, llvm::MutableArrayRef argExprs, + bool useRefQualifiedThisObj, const CXXScopeSpec* SS) { Expr* call = nullptr; if (auto derMethod = dyn_cast(FD)) { call = BuildCallExprToMemFn(derMethod, argExprs, useRefQualifiedThisObj); } else { - Expr* exprFunc = BuildDeclRef(FD); + Expr* exprFunc = BuildDeclRef(FD, SS); call = m_Sema .ActOnCallExpr( getCurrentScope(), @@ -564,6 +564,30 @@ namespace clad { return call; } + Expr* VisitorBase::BuildCallExprToCladFunction( + llvm::StringRef name, llvm::MutableArrayRef argExprs, + llvm::ArrayRef templateArgs, + SourceLocation loc) { + DeclarationName declName = &m_Context.Idents.get(name); + clang::LookupResult R(m_Sema, declName, noLoc, Sema::LookupOrdinaryName); + + // Find function declaration + NamespaceDecl* CladNS = GetCladNamespace(); + CXXScopeSpec CSS; + CSS.Extend(m_Context, CladNS, loc, loc); + m_Sema.LookupQualifiedName(R, CladNS, CSS); + + // Build the template specialization expression. + // TODO: currently this doesn't print func(args...) while dumping + // and only prints func(args...), we need to fix this. + FunctionTemplateDecl* FTD = + dyn_cast(R.getRepresentativeDecl()); + clang::TemplateArgumentList TL(TemplateArgumentList::OnStack, templateArgs); + FunctionDecl* FD = m_Sema.InstantiateFunctionDeclaration(FTD, &TL, loc); + + return BuildCallExprToFunction(FD, argExprs, false, &CSS); + } + TemplateDecl* VisitorBase::GetCladArrayRefDecl() { static TemplateDecl* Result = nullptr; if (!Result) diff --git a/test/ForwardMode/VectorMode.C b/test/ForwardMode/VectorMode.C index cd846983d..9dfef8e98 100644 --- a/test/ForwardMode/VectorMode.C +++ b/test/ForwardMode/VectorMode.C @@ -12,8 +12,8 @@ double f1(double x, double y) { void f1_dvec(double x, double y, double *_d_x, double *_d_y); // CHECK: void f1_dvec(double x, double y, double *_d_x, double *_d_y) { -// CHECK-NEXT: clad::array _d_vector_x = {1., 0.}; -// CHECK-NEXT: clad::array _d_vector_y = {0., 1.}; +// CHECK-NEXT: clad::array _d_vector_x = clad::one_hot_vector(2UL, 0UL); +// CHECK-NEXT: clad::array _d_vector_y = clad::one_hot_vector(2UL, 1UL); // CHECK-NEXT: double _t0 = x * y; // CHECK-NEXT: double _t1 = (x + y + 1); // CHECK-NEXT: { @@ -34,8 +34,8 @@ double f2(double x, double y) { void f2_dvec(double x, double y, double *_d_x, double *_d_y); // CHECK: void f2_dvec(double x, double y, double *_d_x, double *_d_y) { -// CHECK-NEXT: clad::array _d_vector_x = {1., 0.}; -// CHECK-NEXT: clad::array _d_vector_y = {0., 1.}; +// CHECK-NEXT: clad::array _d_vector_x = clad::one_hot_vector(2UL, 0UL); +// CHECK-NEXT: clad::array _d_vector_y = clad::one_hot_vector(2UL, 1UL); // CHECK-NEXT: clad::array _d_vector_temp1(clad::array(2UL, _d_vector_x * y + x * _d_vector_y)); // CHECK-NEXT: double temp1 = x * y; // CHECK-NEXT: clad::array _d_vector_temp2(clad::array(2UL, _d_vector_x + _d_vector_y + 0)); @@ -59,8 +59,8 @@ double f3(double x, double y) { void f3_dvec(double x, double y, double *_d_x, double *_d_y); // CHECK: void f3_dvec(double x, double y, double *_d_x, double *_d_y) { -// CHECK-NEXT: clad::array _d_vector_x = {1., 0.}; -// CHECK-NEXT: clad::array _d_vector_y = {0., 1.}; +// CHECK-NEXT: clad::array _d_vector_x = clad::one_hot_vector(2UL, 0UL); +// CHECK-NEXT: clad::array _d_vector_y = clad::one_hot_vector(2UL, 1UL); // CHECK-NEXT: if (y < 0) { // CHECK-NEXT: _d_vector_y = - _d_vector_y; // CHECK-NEXT: y = -y; @@ -89,8 +89,8 @@ double f4(double lower, double upper) { void f4_dvec(double lower, double upper, double *_d_lower, double *_d_upper); // CHECK: void f4_dvec(double lower, double upper, double *_d_lower, double *_d_upper) { -// CHECK-NEXT: clad::array _d_vector_lower = {1., 0.}; -// CHECK-NEXT: clad::array _d_vector_upper = {0., 1.}; +// CHECK-NEXT: clad::array _d_vector_lower = clad::one_hot_vector(2UL, 0UL); +// CHECK-NEXT: clad::array _d_vector_upper = clad::one_hot_vector(2UL, 1UL); // CHECK-NEXT: clad::array _d_vector_sum(clad::array(2UL, 0)); // CHECK-NEXT: double sum = 0; // CHECK-NEXT: clad::array _d_vector_num_points(clad::array(2UL, 0)); @@ -120,9 +120,9 @@ double f5(double x, double y, double z) { // all // CHECK: void f5_dvec(double x, double y, double z, double *_d_x, double *_d_y, double *_d_z) { -// CHECK-NEXT: clad::array _d_vector_x = {1., 0., 0.}; -// CHECK-NEXT: clad::array _d_vector_y = {0., 1., 0.}; -// CHECK-NEXT: clad::array _d_vector_z = {0., 0., 1.}; +// CHECK-NEXT: clad::array _d_vector_x = clad::one_hot_vector(3UL, 0UL); +// CHECK-NEXT: clad::array _d_vector_y = clad::one_hot_vector(3UL, 1UL); +// CHECK-NEXT: clad::array _d_vector_z = clad::one_hot_vector(3UL, 2UL); // CHECK-NEXT: { // CHECK-NEXT: clad::array _d_vector_return = 0. * x + 1. * _d_vector_x + 0. * y + 2. * _d_vector_y + 0. * z + 3. * _d_vector_z; // CHECK-NEXT: *_d_x = _d_vector_return[0]; @@ -133,9 +133,9 @@ double f5(double x, double y, double z) { // x, y // CHECK: void f5_dvec_0_1(double x, double y, double z, double *_d_x, double *_d_y) { -// CHECK-NEXT: clad::array _d_vector_x = {1., 0.}; -// CHECK-NEXT: clad::array _d_vector_y = {0., 1.}; -// CHECK-NEXT: clad::array _d_vector_z = {0., 0.}; +// CHECK-NEXT: clad::array _d_vector_x = clad::one_hot_vector(2UL, 0UL); +// CHECK-NEXT: clad::array _d_vector_y = clad::one_hot_vector(2UL, 1UL); +// CHECK-NEXT: clad::array _d_vector_z = clad::zero_vector(2UL); // CHECK-NEXT: { // CHECK-NEXT: clad::array _d_vector_return = 0. * x + 1. * _d_vector_x + 0. * y + 2. * _d_vector_y + 0. * z + 3. * _d_vector_z; // CHECK-NEXT: *_d_x = _d_vector_return[0]; @@ -145,9 +145,9 @@ double f5(double x, double y, double z) { // x, z // CHECK: void f5_dvec_0_2(double x, double y, double z, double *_d_x, double *_d_z) { -// CHECK-NEXT: clad::array _d_vector_x = {1., 0.}; -// CHECK-NEXT: clad::array _d_vector_y = {0., 0.}; -// CHECK-NEXT: clad::array _d_vector_z = {0., 1.}; +// CHECK-NEXT: clad::array _d_vector_x = clad::one_hot_vector(2UL, 0UL); +// CHECK-NEXT: clad::array _d_vector_y = clad::zero_vector(2UL); +// CHECK-NEXT: clad::array _d_vector_z = clad::one_hot_vector(2UL, 1UL); // CHECK-NEXT: { // CHECK-NEXT: clad::array _d_vector_return = 0. * x + 1. * _d_vector_x + 0. * y + 2. * _d_vector_y + 0. * z + 3. * _d_vector_z; // CHECK-NEXT: *_d_x = _d_vector_return[0]; @@ -157,9 +157,9 @@ double f5(double x, double y, double z) { // y, z // CHECK: void f5_dvec_1_2(double x, double y, double z, double *_d_y, double *_d_z) { -// CHECK-NEXT: clad::array _d_vector_x = {0., 0.}; -// CHECK-NEXT: clad::array _d_vector_y = {1., 0.}; -// CHECK-NEXT: clad::array _d_vector_z = {0., 1.}; +// CHECK-NEXT: clad::array _d_vector_x = clad::zero_vector(2UL); +// CHECK-NEXT: clad::array _d_vector_y = clad::one_hot_vector(2UL, 0UL); +// CHECK-NEXT: clad::array _d_vector_z = clad::one_hot_vector(2UL, 1UL); // CHECK-NEXT: { // CHECK-NEXT: clad::array _d_vector_return = 0. * x + 1. * _d_vector_x + 0. * y + 2. * _d_vector_y + 0. * z + 3. * _d_vector_z; // CHECK-NEXT: *_d_y = _d_vector_return[0]; @@ -169,9 +169,9 @@ double f5(double x, double y, double z) { // z // CHECK: void f5_dvec_2(double x, double y, double z, double *_d_z) { -// CHECK-NEXT: clad::array _d_vector_x = {0.}; -// CHECK-NEXT: clad::array _d_vector_y = {0.}; -// CHECK-NEXT: clad::array _d_vector_z = {1.}; +// CHECK-NEXT: clad::array _d_vector_x = clad::zero_vector(1UL); +// CHECK-NEXT: clad::array _d_vector_y = clad::zero_vector(1UL); +// CHECK-NEXT: clad::array _d_vector_z = clad::one_hot_vector(1UL, 0UL); // CHECK-NEXT: { // CHECK-NEXT: clad::array _d_vector_return = 0. * x + 1. * _d_vector_x + 0. * y + 2. * _d_vector_y + 0. * z + 3. * _d_vector_z; // CHECK-NEXT: *_d_z = _d_vector_return[0]; diff --git a/test/ForwardMode/VectorModeInterface.C b/test/ForwardMode/VectorModeInterface.C index b8710f37f..a550f0468 100644 --- a/test/ForwardMode/VectorModeInterface.C +++ b/test/ForwardMode/VectorModeInterface.C @@ -9,8 +9,8 @@ double f1(double x, double y) { } // CHECK: void f1_dvec(double x, double y, double *_d_x, double *_d_y) { -// CHECK-NEXT: clad::array _d_vector_x = {1., 0.}; -// CHECK-NEXT: clad::array _d_vector_y = {0., 1.}; +// CHECK-NEXT: clad::array _d_vector_x = clad::one_hot_vector(2UL, 0UL); +// CHECK-NEXT: clad::array _d_vector_y = clad::one_hot_vector(2UL, 1UL); // CHECK-NEXT: { // CHECK-NEXT: clad::array _d_vector_return = _d_vector_x * y + x * _d_vector_y; // CHECK-NEXT: *_d_x = _d_vector_return[0]; @@ -26,8 +26,8 @@ double f2(double x, double y) { void f2_dvec(double x, double y, double *_d_x, double *_d_y); // CHECK: void f2_dvec(double x, double y, double *_d_x, double *_d_y) { -// CHECK-NEXT: clad::array _d_vector_x = {1., 0.}; -// CHECK-NEXT: clad::array _d_vector_y = {0., 1.}; +// CHECK-NEXT: clad::array _d_vector_x = clad::one_hot_vector(2UL, 0UL); +// CHECK-NEXT: clad::array _d_vector_y = clad::one_hot_vector(2UL, 1UL); // CHECK-NEXT: { // CHECK-NEXT: clad::array _d_vector_return = _d_vector_x + _d_vector_y; // CHECK-NEXT: *_d_x = _d_vector_return[0]; @@ -45,8 +45,8 @@ double f_try_catch(double x, double y) } // CHECK: void f_try_catch_dvec(double x, double y, double *_d_x, double *_d_y) { -// CHECK-NEXT: clad::array _d_vector_x = {1., 0.}; -// CHECK-NEXT: clad::array _d_vector_y = {0., 1.}; +// CHECK-NEXT: clad::array _d_vector_x = clad::one_hot_vector(2UL, 0UL); +// CHECK-NEXT: clad::array _d_vector_y = clad::one_hot_vector(2UL, 1UL); // CHECK-NEXT: try { // CHECK-NEXT: return x; // CHECK-NEXT: } catch (int) { @@ -60,7 +60,6 @@ int main() { clad::differentiate(f_try_catch); clad::differentiate<2, clad::opts::vector_mode>(f_try_catch); // expected-error {{Only first order derivative is supported for now in vector forward mode}} clad::differentiate(f1); // expected-error {{Enzyme's vector mode is not yet supported}} - clad::gradient(f1, "x, y, z"); // expected-error {{Reverse vector mode is not yet supported.}} return 0; } diff --git a/test/Misc/CladArray.C b/test/Misc/CladArray.C index 0ec370747..86a760c5e 100644 --- a/test/Misc/CladArray.C +++ b/test/Misc/CladArray.C @@ -161,4 +161,12 @@ int main() { // CHECK-EXEC: 0 : 2.00 // CHECK-EXEC: 1 : 2.00 // CHECK-EXEC: 2 : 2.00 + + clad::array double_test_arr2 = clad::one_hot_vector (3, 1); + for (int i = 0; i < 3; i++) { + printf("%d : %.2f\n", i, double_test_arr2[i]); + } + // CHECK-EXEC: 0 : 0.00 + // CHECK-EXEC: 1 : 1.00 + // CHECK-EXEC: 2 : 0.00 }