Skip to content

Commit

Permalink
Generate overload function for vector mode to fix asserts
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Jul 27, 2023
1 parent 35dc011 commit b8eb829
Show file tree
Hide file tree
Showing 9 changed files with 282 additions and 75 deletions.
16 changes: 16 additions & 0 deletions include/clad/Differentiator/Array.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,22 @@ template <typename T> class array {
CUDA_HOST_DEVICE operator T*() const { return m_arr; }
}; // class array

// Function to instantiate a one-hot array of size n with 1 at index i.
// A one-hot vector is a vector with all elements set to 0 except for one
// element which is set to 1.
// For example, if n=4 and i=2, the returned array is: {0, 0, 1, 0}
template <typename T>
CUDA_HOST_DEVICE array<T> one_hot_vector(std::size_t n, std::size_t i) {
array<T> arr(n);
arr[i] = 1;
return arr;
}

// Function to instantiate a zero vector of size n
template <typename T> CUDA_HOST_DEVICE array<T> zero_vector(std::size_t n) {
return array<T>(n);
}

/// Overloaded operators for clad::array which return a new array.

/// Multiplies the number to every element in the array and returns a new
Expand Down
2 changes: 1 addition & 1 deletion include/clad/Differentiator/FunctionTraits.h
Original file line number Diff line number Diff line change
Expand Up @@ -735,7 +735,7 @@ namespace clad {

template <class ReturnType, class... Args>
struct ExtractDerivedFnTraitsVecForwMode<ReturnType (*)(Args...)> {
using type = void (*)(Args..., OutputVecParamType_t<Args, ReturnType>...);
using type = void (*)(Args..., OutputVecParamType_t<Args, void>...);
};

/// Specialization for free function pointer type
Expand Down
10 changes: 8 additions & 2 deletions include/clad/Differentiator/VectorForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor {
DerivativeAndOverload DeriveVectorMode(const clang::FunctionDecl* FD,
const DiffRequest& request);

/// Builds an overload for the gradient function that has derived params for
/// all the arguments of the requested function and it calls the original
/// gradient function internally
clang::FunctionDecl* CreateVectorModeOverload();

/// Builds and returns the sequence of derived function parameters for
// vectorized forward mode.
///
Expand All @@ -49,12 +54,13 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor {
///
/// For example: for index = 2 and size = 4, the returned expression
/// is: {0, 0, 1, 0}
clang::Expr* getOneHotInitExpr(size_t index, size_t size);
clang::Expr* getOneHotInitExpr(size_t index, size_t size,
clang::QualType type);

/// Get an expression used to initialize a zero vector of the given size.
///
/// For example: for size = 4, the returned expression is: {0, 0, 0, 0}
clang::Expr* getZeroInitListExpr(size_t size);
clang::Expr* getZeroInitListExpr(size_t size, clang::QualType type);

StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
// Decl is not Stmt, so it cannot be visited directly.
Expand Down
36 changes: 25 additions & 11 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ namespace clad {
clang::Expr* getExpr_dx() {
return llvm::cast_or_null<clang::Expr>(getStmt_dx());
}

void updateStmt(clang::Stmt* S) { data[1] = S; }
void updateStmtDx(clang::Stmt* S) { data[0] = S; }
// Stmt_dx goes first!
Expand Down Expand Up @@ -309,8 +309,10 @@ namespace clad {
/// declaration reference expressions. This function builds a declaration
/// reference given a declaration.
/// \param[in] D The declaration to build a DeclRefExpr for.
/// \param[in] SS The scope specifier for the declaration.
/// \returns the DeclRefExpr for the given declaration.
clang::DeclRefExpr* BuildDeclRef(clang::DeclaratorDecl* D);
clang::DeclRefExpr* BuildDeclRef(clang::DeclaratorDecl* D,
const clang::CXXScopeSpec* SS = nullptr);

/// Stores the result of an expression in a temporary variable (of the same
/// type as is the result of the expression) and returns a reference to it.
Expand Down Expand Up @@ -456,7 +458,21 @@ namespace clad {
clang::Expr*
BuildCallExprToFunction(clang::FunctionDecl* FD,
llvm::MutableArrayRef<clang::Expr*> argExprs,
bool useRefQualifiedThisObj = false);
bool useRefQualifiedThisObj = false,
const clang::CXXScopeSpec* SS = nullptr);

/// Build a call to templated free function inside the clad namespace.
///
/// \param[in] name name of the function
/// \param[in] argExprs function arguments expressions
/// \param[in] templateArgs template arguments
/// \param[in] loc location of the call
/// \returns Built call expression
clang::Expr* BuildCallExprToCladFunction(
llvm::StringRef name, llvm::MutableArrayRef<clang::Expr*> argExprs,
llvm::ArrayRef<clang::TemplateArgument> templateArgs,
clang::SourceLocation loc);

/// Find declaration of clad::array_ref templated type.
clang::TemplateDecl* GetCladArrayRefDecl();
/// Create clad::array_ref<T> type.
Expand Down Expand Up @@ -492,9 +508,8 @@ namespace clad {
///
/// \returns The derivative function call.
clang::Expr* GetSingleArgCentralDiffCall(
clang::Expr* targetFuncCall, clang::Expr* targetArg,
unsigned targetPos, unsigned numArgs,
llvm::SmallVectorImpl<clang::Expr*>& args);
clang::Expr* targetFuncCall, clang::Expr* targetArg, unsigned targetPos,
unsigned numArgs, llvm::SmallVectorImpl<clang::Expr*>& args);
/// A function to get the multi-argument "central_difference"
/// call expression for the given arguments.
///
Expand All @@ -508,17 +523,16 @@ namespace clad {
///
/// \returns The derivative function call.
clang::Expr* GetMultiArgCentralDiffCall(
clang::Expr* targetFuncCall, clang::QualType retType,
unsigned numArgs,
clang::Expr* targetFuncCall, clang::QualType retType, unsigned numArgs,
llvm::SmallVectorImpl<clang::Stmt*>& NumericalDiffMultiArg,
llvm::SmallVectorImpl<clang::Expr*>& args,
llvm::SmallVectorImpl<clang::Expr*>& outputArgs);
/// Emits diagnostic messages on differentiation (or lack thereof) for
/// Emits diagnostic messages on differentiation (or lack thereof) for
/// call expressions.
///
/// \param[in] \c funcName The name of the underlying function of the
/// \param[in] \c funcName The name of the underlying function of the
/// call expression.
/// \param[in] \c srcLoc Any associated source location information.
/// \param[in] \c srcLoc Any associated source location information.
/// \param[in] \c isDerived A flag to determine if differentiation of the
/// call expression was successful.
void CallExprDiffDiagnostics(llvm::StringRef funcName,
Expand Down
188 changes: 164 additions & 24 deletions lib/Differentiator/VectorForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "ConstantFolder.h"
#include "clad/Differentiator/CladUtils.h"

#include "clang/AST/TemplateName.h"
#include "clang/Sema/Lookup.h"
#include "llvm/Support/SaveAndRestore.h"

using namespace clang;
Expand Down Expand Up @@ -99,13 +101,13 @@ VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD,
m_IndependentVars[independentVarIndex] == m_Function->getParamDecl(i)) {
// This parameter is an independent variable.
// Create a one hot vector for the parameter.
dVectorParam =
getOneHotInitExpr(independentVarIndex, m_IndependentVars.size());
dVectorParam = getOneHotInitExpr(independentVarIndex,
m_IndependentVars.size(), dParamType);
++independentVarIndex;
} else {
// This parameter is not an independent variable.
// Initialize by all zeros.
dVectorParam = getZeroInitListExpr(m_IndependentVars.size());
dVectorParam = getZeroInitListExpr(m_IndependentVars.size(), dParamType);
}

// For each function arg to be differentiated, create a variable
Expand Down Expand Up @@ -139,28 +141,166 @@ VectorForwardModeVisitor::DeriveVectorMode(const FunctionDecl* FD,
m_Sema.PopDeclContext();
endScope(); // Function decl scope

return DerivativeAndOverload{vectorDiffFD, nullptr};
// Create the overload declaration for the derivative.
FunctionDecl* overloadFD = CreateVectorModeOverload();
return DerivativeAndOverload{vectorDiffFD, overloadFD};
}

clang::FunctionDecl* VectorForwardModeVisitor::CreateVectorModeOverload() {
auto vectorModeParams = m_Derivative->parameters();
auto vectorModeNameInfo = m_Derivative->getNameInfo();

// Calculate the total number of parameters that would be required for
// automatic differentiation in the derived function if all args are
// requested.
std::size_t totalDerivedParamsSize = m_Function->getNumParams() * 2;
std::size_t numDerivativeParams = m_Function->getNumParams();

// Generate the function type for the derivative.
llvm::SmallVector<clang::QualType, 8> paramTypes;
paramTypes.reserve(totalDerivedParamsSize);
for (auto PVD : m_Function->parameters()) {
paramTypes.push_back(PVD->getType());
}

// instantiate output parameter type as void*
QualType outputParamType = m_Context.getPointerType(m_Context.VoidTy);

// Push param types for derived params.
for (std::size_t i = 0; i < m_Function->getNumParams(); ++i)
paramTypes.push_back(outputParamType);

auto vectorModeFuncOverloadEPI =
dyn_cast<FunctionProtoType>(m_Function->getType())->getExtProtoInfo();
QualType vectorModeFuncOverloadType = m_Context.getFunctionType(
m_Context.VoidTy,
llvm::ArrayRef<QualType>(paramTypes.data(), paramTypes.size()),
vectorModeFuncOverloadEPI);

// Create the function declaration for the derivative.
DeclContext* DC = const_cast<DeclContext*>(m_Function->getDeclContext());
m_Sema.CurContext = DC;
DeclWithContext result =
m_Builder.cloneFunction(m_Function, *this, DC, m_Sema, m_Context, noLoc,
vectorModeNameInfo, vectorModeFuncOverloadType);
FunctionDecl* vectorModeOverloadFD = result.first;

// Function declaration scope
beginScope(Scope::FunctionPrototypeScope | Scope::FunctionDeclarationScope |
Scope::DeclScope);
m_Sema.PushFunctionScope();
m_Sema.PushDeclContext(getCurrentScope(), vectorModeOverloadFD);

llvm::SmallVector<ParmVarDecl*, 4> overloadParams;
overloadParams.reserve(totalDerivedParamsSize);

llvm::SmallVector<Expr*, 4> callArgs; // arguments to the call the requested
// vectormode function.
callArgs.reserve(vectorModeParams.size());

for (auto PVD : m_Function->parameters()) {
auto VD = utils::BuildParmVarDecl(
m_Sema, vectorModeOverloadFD, PVD->getIdentifier(), PVD->getType(),
PVD->getStorageClass(), /*defArg=*/nullptr, PVD->getTypeSourceInfo());
overloadParams.push_back(VD);
callArgs.push_back(BuildDeclRef(VD));
}

for (std::size_t i = 0; i < numDerivativeParams; ++i) {
ParmVarDecl* PVD;
std::size_t effectiveIndex = m_Function->getNumParams() + i;

if (effectiveIndex < vectorModeParams.size()) {
// This parameter represents an actual derivative parameter.
auto OriginalVD = vectorModeParams[effectiveIndex];
PVD = utils::BuildParmVarDecl(
m_Sema, vectorModeOverloadFD,
CreateUniqueIdentifier("_temp_" + OriginalVD->getNameAsString()),
outputParamType, OriginalVD->getStorageClass());
} else {
PVD = utils::BuildParmVarDecl(
m_Sema, vectorModeOverloadFD,
CreateUniqueIdentifier("_d_" + std::to_string(i)), outputParamType,
StorageClass::SC_None);
}
overloadParams.push_back(PVD);
}

for (auto PVD : overloadParams) {
if (PVD->getIdentifier())
m_Sema.PushOnScopeChains(PVD, getCurrentScope(),
/*AddToContext=*/false);
}

vectorModeOverloadFD->setParams(overloadParams);
vectorModeOverloadFD->setBody(/*B=*/nullptr);

// Create the body of the derivative.
beginScope(Scope::FnScope | Scope::DeclScope);
m_DerivativeFnScope = getCurrentScope();
beginBlock();

// Build derivatives to be used in the call to the actual derived function.
// These are initialised by effectively casting the derivative parameters of
// overloaded derived function to the correct type.
for (std::size_t i = m_Function->getNumParams(); i < vectorModeParams.size();
++i) {
auto overloadParam = overloadParams[i];
auto vectorModeParam = vectorModeParams[i];

// Create a cast expression to cast the derivative parameter to the correct
// type.
auto castExpr =
m_Sema
.BuildCXXNamedCast(
noLoc, tok::TokenKind::kw_static_cast,
m_Context.getTrivialTypeSourceInfo(vectorModeParam->getType()),
BuildDeclRef(overloadParam), noLoc, noLoc)
.get();
auto vectorModeVD =
BuildVarDecl(vectorModeParam->getType(),
vectorModeParam->getNameAsString(), castExpr);
callArgs.push_back(BuildDeclRef(vectorModeVD));
addToCurrentBlock(BuildDeclStmt(vectorModeVD));
}

Expr* callExpr = BuildCallExprToFunction(m_Derivative, callArgs,
/*UseRefQualifiedThisObj=*/true);
addToCurrentBlock(callExpr);
Stmt* vectorModeOverloadBody = endBlock();

vectorModeOverloadFD->setBody(vectorModeOverloadBody);

endScope(); // Function body scope
m_Sema.PopFunctionScopeInfo();
m_Sema.PopDeclContext();
endScope(); // Function decl scope

return vectorModeOverloadFD;
}

clang::Expr* VectorForwardModeVisitor::getOneHotInitExpr(size_t index,
size_t size) {
// define a vector of size `size` with all elements set to 0,
// except for the element at `index` which is set to 1.
auto zero =
ConstantFolder::synthesizeLiteral(m_Context.DoubleTy, m_Context, 0);
auto one =
ConstantFolder::synthesizeLiteral(m_Context.DoubleTy, m_Context, 1);
llvm::SmallVector<clang::Expr*, 8> oneHotInitList(size, zero);
oneHotInitList[index] = one;
return m_Sema.ActOnInitList(m_Function->getLocation(), llvm::MutableArrayRef<clang::Expr*>(oneHotInitList), m_Function->getLocation()).get();
size_t size,
clang::QualType type) {
// Build call expression for one_hot
llvm::SmallVector<Expr*, 2> args = {
ConstantFolder::synthesizeLiteral(m_Context.UnsignedLongTy, m_Context,
size),
ConstantFolder::synthesizeLiteral(m_Context.UnsignedLongTy, m_Context,
index)};
return BuildCallExprToCladFunction("one_hot_vector", args, {type},
m_Function->getLocation());
}

clang::Expr* VectorForwardModeVisitor::getZeroInitListExpr(size_t size) {
clang::Expr*
VectorForwardModeVisitor::getZeroInitListExpr(size_t size,
clang::QualType type) {
// define a vector of size `size` with all elements set to 0.
auto zero =
ConstantFolder::synthesizeLiteral(m_Context.DoubleTy, m_Context, 0);
llvm::SmallVector<clang::Expr*, 8> zeroInitList(size, zero);
return m_Sema.ActOnInitList(m_Function->getLocation(), llvm::MutableArrayRef<clang::Expr*>(zeroInitList), m_Function->getLocation()).get();
// Build call expression for zero_vector
llvm::SmallVector<Expr*, 2> args = {ConstantFolder::synthesizeLiteral(
m_Context.UnsignedLongTy, m_Context, size)};
return BuildCallExprToCladFunction("zero_vector", args, {type},
m_Function->getLocation());
}

llvm::SmallVector<clang::ParmVarDecl*, 8>
Expand Down Expand Up @@ -196,7 +336,7 @@ VectorForwardModeVisitor::BuildVectorModeParams(DiffParams& diffParams) {
m_Sema.PushOnScopeChains(dPVD, getCurrentScope(),
/*AddToContext=*/false);

m_ParamVariables[*it] = BuildOp(UO_Deref, BuildDeclRef(dPVD), m_Function->getLocation());
m_ParamVariables[*it] = BuildOp(UO_Deref, BuildDeclRef(dPVD), noLoc);
}
// insert the derivative parameters at the end of the parameter list.
params.insert(params.end(), paramDerivatives.begin(), paramDerivatives.end());
Expand Down Expand Up @@ -231,7 +371,7 @@ StmtDiff VectorForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) {
auto dParamValue =
m_Sema
.ActOnArraySubscriptExpr(getCurrentScope(), dVectorRef,
dVectorRef->getExprLoc(), indexExpr, m_Function->getLocation())
dVectorRef->getExprLoc(), indexExpr, noLoc)
.get();
// Create an assignment expression to assign the ith element of the
// return vector to the derivative of the ith parameter.
Expand All @@ -241,7 +381,7 @@ StmtDiff VectorForwardModeVisitor::VisitReturnStmt(const ReturnStmt* RS) {
}
// Add an empty return statement to the array of statements.
returnStmts.push_back(
m_Sema.ActOnReturnStmt(m_Function->getLocation(), nullptr, getCurrentScope()).get());
m_Sema.ActOnReturnStmt(noLoc, nullptr, getCurrentScope()).get());

// Create a return statement from the compound statement.
Stmt* returnStmt = MakeCompoundStmt(returnStmts);
Expand Down Expand Up @@ -276,13 +416,13 @@ VarDeclDiff VectorForwardModeVisitor::DifferentiateVarDecl(const VarDecl* VD) {
.ActOnCXXTypeConstructExpr(
OpaquePtr<QualType>::make(
GetCladArrayOfType(utils::GetValueType(VD->getType()))),
m_Function->getLocation(), args, m_Function->getLocation(), false)
noLoc, args, noLoc, false)
.get();

VarDecl* VDDerived =
BuildVarDecl(GetCladArrayOfType(utils::GetValueType(VD->getType())),
"_d_vector_" + VD->getNameAsString(), constructorCallExpr,
true, nullptr, VarDecl::InitializationStyle::CallInit);
false, nullptr, VarDecl::InitializationStyle::CallInit);

m_Variables.emplace(VDClone, BuildDeclRef(VDDerived));
return VarDeclDiff(VDClone, VDDerived);
Expand Down
Loading

0 comments on commit b8eb829

Please sign in to comment.