Skip to content

Commit

Permalink
Add conversion to void* in the execute function
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi authored and vgvassilev committed Feb 15, 2024
1 parent 5ce9105 commit d1595dc
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 164 deletions.
2 changes: 1 addition & 1 deletion docs/userDocs/source/user/CoreConcepts.rst
Original file line number Diff line number Diff line change
Expand Up @@ -266,7 +266,7 @@ partial derivative of the function with respect to every input, and as such
is used in Clad's reverse mode. The signature of the method is as follows::

template <typename F, std::size_t... Ints,
typename RetType = typename clad::return_type<F>::type,
typename RetType = typename clad::function_traits<F>::return_type,
typename... Args>
void central_difference(F f, clad::tape_impl<clad::array_ref<RetType>>& _grad, bool printErrors, Args&&... args) {
// Similar to the above method, here:
Expand Down
23 changes: 17 additions & 6 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -93,35 +93,41 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
/// return f(1.0, 2.0, 0, 0);
// for executing non-member functions
template <bool EnablePadding, class... Rest, class F, class... Args,
class... fArgTypes,
typename std::enable_if<EnablePadding, bool>::type = true>
CUDA_HOST_DEVICE return_type_t<F>
execute_with_default_args(list<Rest...>, F f, Args&&... args) {
execute_with_default_args(list<Rest...>, F f, list<fArgTypes...>,
Args&&... args) {
return f(static_cast<Args>(args)..., static_cast<Rest>(nullptr)...);
}

template <bool EnablePadding, class... Rest, class F, class... Args,
class... fArgTypes,
typename std::enable_if<!EnablePadding, bool>::type = true>
return_type_t<F> execute_with_default_args(list<Rest...>, F f,
list<fArgTypes...>,
Args&&... args) {
return f(static_cast<Args>(args)...);
}

// for executing member-functions
template <bool EnablePadding, class... Rest, class ReturnType, class C,
class Obj, class... Args,
class Obj, class... Args, class... fArgTypes,
typename std::enable_if<EnablePadding, bool>::type = true>
CUDA_HOST_DEVICE auto execute_with_default_args(list<Rest...>,
ReturnType C::*f, Obj&& obj,
list<fArgTypes...>,
Args&&... args)
-> return_type_t<decltype(f)> {
return (static_cast<Obj>(obj).*f)(static_cast<Args>(args)...,
return (static_cast<Obj>(obj).*f)((fArgTypes)(args)...,
static_cast<Rest>(nullptr)...);
}

template <bool EnablePadding, class... Rest, class ReturnType, class C,
class Obj, class... Args,
class Obj, class... Args, class... fArgTypes,
typename std::enable_if<!EnablePadding, bool>::type = true>
auto execute_with_default_args(list<Rest...>, ReturnType C::*f, Obj&& obj,
list<fArgTypes...>,
Args&&... args) -> return_type_t<decltype(f)> {
return (static_cast<Obj>(obj).*f)(static_cast<Args>(args)...);
}
Expand Down Expand Up @@ -238,7 +244,9 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
execute_helper(Fn f, Args&&... args) {
// `static_cast` is required here for perfect forwarding.
return execute_with_default_args<EnablePadding>(
DropArgs_t<sizeof...(Args), F>{}, f, static_cast<Args>(args)...);
DropArgs_t<sizeof...(Args), F>{}, f,
TakeNFirstArgs_t<sizeof...(Args), decltype(f)>{},
static_cast<Args>(args)...);
}

/// Helper functions for executing member derived functions.
Expand All @@ -256,7 +264,9 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
// `static_cast` is required here for perfect forwarding.
return execute_with_default_args<EnablePadding>(
DropArgs_t<sizeof...(Args), decltype(f)>{}, f,
static_cast<Obj>(obj), static_cast<Args>(args)...);
static_cast<Obj>(obj),
TakeNFirstArgs_t<sizeof...(Args), decltype(f)>{},
static_cast<Args>(args)...);
}
/// If user have not passed object explicitly, then this specialization
/// will be used and derived function will be called through the object
Expand All @@ -270,6 +280,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
// `static_cast` is required here for perfect forwarding.
return execute_with_default_args<EnablePadding>(
DropArgs_t<sizeof...(Args), decltype(f)>{}, f, *m_Functor,
TakeNFirstArgs_t<sizeof...(Args), decltype(f)>{},
static_cast<Args>(args)...);
}
};
Expand Down
Loading

0 comments on commit d1595dc

Please sign in to comment.