Skip to content

Commit

Permalink
Generate overload function for vector mode to fix asserts
Browse files Browse the repository at this point in the history
  • Loading branch information
vaithak committed Jul 27, 2023
1 parent 35dc011 commit 9ecf289
Show file tree
Hide file tree
Showing 9 changed files with 360 additions and 155 deletions.
16 changes: 16 additions & 0 deletions include/clad/Differentiator/Array.h
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,22 @@ template <typename T> class array {
CUDA_HOST_DEVICE operator T*() const { return m_arr; }
}; // class array

// Function to instantiate a one-hot array of size n with 1 at index i.
// A one-hot vector is a vector with all elements set to 0 except for one
// element which is set to 1.
// For example, if n=4 and i=2, the returned array is: {0, 0, 1, 0}
template <typename T>
CUDA_HOST_DEVICE array<T> one_hot_vector(std::size_t n, std::size_t i) {
array<T> arr(n);
arr[i] = 1;
return arr;
}

// Function to instantiate a zero vector of size n
template <typename T> CUDA_HOST_DEVICE array<T> zero_vector(std::size_t n) {
return array<T>(n);
}

/// Overloaded operators for clad::array which return a new array.

/// Multiplies the number to every element in the array and returns a new
Expand Down
160 changes: 79 additions & 81 deletions include/clad/Differentiator/FunctionTraits.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,130 +37,128 @@ namespace clad {

/// Placeholder type for denoting no function type exists
///
/// This is used by `ExtractDerivedFnTraitsForwMode` and
/// This is used by `ExtractDerivedFnTraitsForwMode` and
/// `ExtractDerivedFnTraits` type trait as value for member typedef
/// `type` to denote no function type exists.
class NoFunction {};


// Trait class to deduce return type of function(both member and non-member) at commpile time
// Only function pointer types are supported by this trait class
template <class F>
struct return_type {};
template <class F>
using return_type_t = typename return_type<F>::type;
template <class F> struct return_type {};
template <class F> using return_type_t = typename return_type<F>::type;

// specializations for non-member functions pointer types
template <class ReturnType, class... Args>
template <class ReturnType, class... Args>
struct return_type<ReturnType (*)(Args...)> {
using type = ReturnType;
};
template <class ReturnType, class... Args>
template <class ReturnType, class... Args>
struct return_type<ReturnType (*)(Args..., ...)> {
using type = ReturnType;
};

// specializations for member functions pointer types with no qualifiers
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...)> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...)> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...)> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...)> {
using type = ReturnType;
};

// specializations for member functions pointer type with only cv-qualifiers
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) const> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) const> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) const> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) const> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) volatile> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) volatile> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) volatile> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) volatile> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) const volatile> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) const volatile> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) const volatile> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) const volatile> {
using type = ReturnType;
};

// specializations for member functions pointer types with
// specializations for member functions pointer types with
// reference qualifiers and with and without cv-qualifiers
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) &> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...)&> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) &> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...)&> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) const &> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) const&> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) const &> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) const&> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) volatile &> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) volatile&> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) volatile &> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) volatile&> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) const volatile &> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) const volatile&> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) const volatile &> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) const volatile&> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) &&> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) &&> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) &&> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) &&> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) const &&> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) const&&> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) const &&> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) const&&> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) volatile &&> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) volatile&&> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) volatile &&> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) volatile&&> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) const volatile &&> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args...) const volatile&&> {
using type = ReturnType;
};
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) const volatile &&> {
using type = ReturnType;
template <class ReturnType, class C, class... Args>
struct return_type<ReturnType (C::*)(Args..., ...) const volatile&&> {
using type = ReturnType;
};

template<>
Expand Down Expand Up @@ -735,7 +733,7 @@ namespace clad {

template <class ReturnType, class... Args>
struct ExtractDerivedFnTraitsVecForwMode<ReturnType (*)(Args...)> {
using type = void (*)(Args..., OutputVecParamType_t<Args, ReturnType>...);
using type = void (*)(Args..., OutputVecParamType_t<Args, void>...);
};

/// Specialization for free function pointer type
Expand Down
10 changes: 8 additions & 2 deletions include/clad/Differentiator/VectorForwardModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,11 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor {
DerivativeAndOverload DeriveVectorMode(const clang::FunctionDecl* FD,
const DiffRequest& request);

/// Builds an overload for the gradient function that has derived params for
/// all the arguments of the requested function and it calls the original
/// gradient function internally
clang::FunctionDecl* CreateVectorModeOverload();

/// Builds and returns the sequence of derived function parameters for
// vectorized forward mode.
///
Expand All @@ -49,12 +54,13 @@ class VectorForwardModeVisitor : public BaseForwardModeVisitor {
///
/// For example: for index = 2 and size = 4, the returned expression
/// is: {0, 0, 1, 0}
clang::Expr* getOneHotInitExpr(size_t index, size_t size);
clang::Expr* getOneHotInitExpr(size_t index, size_t size,
clang::QualType type);

/// Get an expression used to initialize a zero vector of the given size.
///
/// For example: for size = 4, the returned expression is: {0, 0, 0, 0}
clang::Expr* getZeroInitListExpr(size_t size);
clang::Expr* getZeroInitListExpr(size_t size, clang::QualType type);

StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
// Decl is not Stmt, so it cannot be visited directly.
Expand Down
36 changes: 25 additions & 11 deletions include/clad/Differentiator/VisitorBase.h
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ namespace clad {
clang::Expr* getExpr_dx() {
return llvm::cast_or_null<clang::Expr>(getStmt_dx());
}

void updateStmt(clang::Stmt* S) { data[1] = S; }
void updateStmtDx(clang::Stmt* S) { data[0] = S; }
// Stmt_dx goes first!
Expand Down Expand Up @@ -309,8 +309,10 @@ namespace clad {
/// declaration reference expressions. This function builds a declaration
/// reference given a declaration.
/// \param[in] D The declaration to build a DeclRefExpr for.
/// \param[in] SS The scope specifier for the declaration.
/// \returns the DeclRefExpr for the given declaration.
clang::DeclRefExpr* BuildDeclRef(clang::DeclaratorDecl* D);
clang::DeclRefExpr* BuildDeclRef(clang::DeclaratorDecl* D,
const clang::CXXScopeSpec* SS = nullptr);

/// Stores the result of an expression in a temporary variable (of the same
/// type as is the result of the expression) and returns a reference to it.
Expand Down Expand Up @@ -456,7 +458,21 @@ namespace clad {
clang::Expr*
BuildCallExprToFunction(clang::FunctionDecl* FD,
llvm::MutableArrayRef<clang::Expr*> argExprs,
bool useRefQualifiedThisObj = false);
bool useRefQualifiedThisObj = false,
const clang::CXXScopeSpec* SS = nullptr);

/// Build a call to templated free function inside the clad namespace.
///
/// \param[in] name name of the function
/// \param[in] argExprs function arguments expressions
/// \param[in] templateArgs template arguments
/// \param[in] loc location of the call
/// \returns Built call expression
clang::Expr* BuildCallExprToCladFunction(
llvm::StringRef name, llvm::MutableArrayRef<clang::Expr*> argExprs,
llvm::ArrayRef<clang::TemplateArgument> templateArgs,
clang::SourceLocation loc);

/// Find declaration of clad::array_ref templated type.
clang::TemplateDecl* GetCladArrayRefDecl();
/// Create clad::array_ref<T> type.
Expand Down Expand Up @@ -492,9 +508,8 @@ namespace clad {
///
/// \returns The derivative function call.
clang::Expr* GetSingleArgCentralDiffCall(
clang::Expr* targetFuncCall, clang::Expr* targetArg,
unsigned targetPos, unsigned numArgs,
llvm::SmallVectorImpl<clang::Expr*>& args);
clang::Expr* targetFuncCall, clang::Expr* targetArg, unsigned targetPos,
unsigned numArgs, llvm::SmallVectorImpl<clang::Expr*>& args);
/// A function to get the multi-argument "central_difference"
/// call expression for the given arguments.
///
Expand All @@ -508,17 +523,16 @@ namespace clad {
///
/// \returns The derivative function call.
clang::Expr* GetMultiArgCentralDiffCall(
clang::Expr* targetFuncCall, clang::QualType retType,
unsigned numArgs,
clang::Expr* targetFuncCall, clang::QualType retType, unsigned numArgs,
llvm::SmallVectorImpl<clang::Stmt*>& NumericalDiffMultiArg,
llvm::SmallVectorImpl<clang::Expr*>& args,
llvm::SmallVectorImpl<clang::Expr*>& outputArgs);
/// Emits diagnostic messages on differentiation (or lack thereof) for
/// Emits diagnostic messages on differentiation (or lack thereof) for
/// call expressions.
///
/// \param[in] \c funcName The name of the underlying function of the
/// \param[in] \c funcName The name of the underlying function of the
/// call expression.
/// \param[in] \c srcLoc Any associated source location information.
/// \param[in] \c srcLoc Any associated source location information.
/// \param[in] \c isDerived A flag to determine if differentiation of the
/// call expression was successful.
void CallExprDiffDiagnostics(llvm::StringRef funcName,
Expand Down
Loading

0 comments on commit 9ecf289

Please sign in to comment.