Skip to content

Commit

Permalink
Add support for custom _reverse_forw functions
Browse files Browse the repository at this point in the history
This commit adds support for custom (user-provided) `_forw` functions.
A `_forw` function, if available, is called in place of the actual
function.

For example, if the primal code contains:

```cpp
someFn(u, v, w);
```

and user has defined a custom `_reverse_forw` function for `someFn` as follows:

```cpp
namespace clad {
  namespace custom_derivatives {
    void someFn_reverse_forw(double u, double v, double w, double *d_u,
      double *d_v, double *dw) {
      // ...
      // ...
    }
  }
}
```

Then clad will generate the derivative function as follows:

```cpp
// forward-pass
clad::custom_derivatives::someFn_reverse_forw(u, v, w, d_u, d_v, d_w);
// ...

// reverse-pass; no change in reverse-pass
someFn_pullback(u, v, w, d_u, d_v, d_w);
// ...
```

But more importantly, why do we need such a functionality? Two reasons:

- Supporting reference/pointer return types in the reverse-mode. This
  has been discussed at great length here:
vgvassilev#425 (vgvassilev#425)

- Supporting types whose elements grows dynamically, such as
  `std::vector` and `std::map`. The issue is that we correctly
  need to update the size/property of the adjoint variable when a
  function call updates the size/property of the corresponding primal
  variable. For example: a call to `vec.push_back(...)` should update
  the size of `_d_vec` as well. However, the actual function call does
  not modify the adjoint variable in any way. Here comes `_forw` functions
  to the rescue. `_forw` functions makes it possible to adjust the adjoint
  variable size/properties along with executing the actual function call.

Please note that `_reverse_forw` function signature takes adjoint variables as
arguments and return `clad::ValueAndAdjoint<U, V>` to support the
reference/pointer return type.
  • Loading branch information
infinite-void-16 committed Aug 19, 2024
1 parent 3f5bfd0 commit 6ede83c
Show file tree
Hide file tree
Showing 9 changed files with 430 additions and 14 deletions.
5 changes: 5 additions & 0 deletions include/clad/Differentiator/BuiltinDerivatives.h
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,11 @@ template <typename T, typename U> struct ValueAndPushforward {
}
};

template <typename T, typename U> struct ValueAndAdjoint {
T value;
U adjoint;
};

/// It is used to identify constructor custom pushforwards. For
/// constructor custom pushforward functions, we cannot use the same
/// strategy which we use for custom pushforward for member
Expand Down
3 changes: 3 additions & 0 deletions include/clad/Differentiator/CladUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,9 @@ namespace clad {

bool IsMemoryFunction(const clang::FunctionDecl* FD);
bool IsMemoryDeallocationFunction(const clang::FunctionDecl* FD);

/// Returns true if QT is a non-const reference type.
bool isNonConstReferenceType(clang::QualType QT);
} // namespace utils
} // namespace clad

Expand Down
8 changes: 2 additions & 6 deletions include/clad/Differentiator/Differentiator.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,8 @@
#include <cstring>

namespace clad {
template <typename T, typename U> struct ValueAndAdjoint {
T value;
U adjoint;
};

/// \returns the size of a c-style string
/// \returns the size of a c-style string
inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {
const char* code_copy = code;
#ifdef __CUDACC__
Expand Down Expand Up @@ -507,7 +503,7 @@ inline CUDA_HOST_DEVICE unsigned int GetLength(const char* code) {

// Gradient Structure for Reverse Mode Enzyme
template <unsigned N> struct EnzymeGradient { double d_arr[N]; };
}
} // namespace clad
#endif // CLAD_DIFFERENTIATOR

// Enable clad after the header was included.
Expand Down
5 changes: 5 additions & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,11 @@ namespace clad {
// Function to Differentiate with Enzyme as Backend
void DifferentiateWithEnzyme();

/// Tries to find and build call to user-provided `_forw` function.
clang::Expr* BuildCallToCustomForwPassFn(
const clang::FunctionDecl* FD, llvm::ArrayRef<clang::Expr*> primalArgs,
llvm::ArrayRef<clang::Expr*> derivedArgs, clang::Expr* baseExpr);

public:
using direction = rmv::direction;
clang::Expr* dfdx() {
Expand Down
29 changes: 29 additions & 0 deletions include/clad/Differentiator/STLBuiltins.h
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,35 @@ void fill_pushforward(::std::array<T, N>* a, const T& u,
d_a->fill(d_u);
}

template <typename T, typename U>
void push_back_reverse_forw(::std::vector<T>* v, U val, ::std::vector<T>* d_v,
U* d_val) {
v->push_back(val);
d_v->push_back(0);
}

template <typename T, typename U>
void push_back_pullback(::std::vector<T>* v, U val, ::std::vector<T>* d_v,
U* d_val) {
*d_val += d_v->back();
d_v->pop_back();
}

template <typename T>
clad::ValueAndAdjoint<T&, T&> operator_subscript_reverse_forw(
::std::vector<T>* vec, typename ::std::vector<T>::size_type idx,
::std::vector<T>* d_vec, typename ::std::vector<T>::size_type* d_idx) {
return {(*vec)[idx], (*d_vec)[idx]};
}

template <typename T, typename P>
void operator_subscript_pullback(::std::vector<T>* vec,
typename ::std::vector<T>::size_type idx,
P d_y, ::std::vector<T>* d_vec,
typename ::std::vector<T>::size_type* d_idx) {
(*d_vec)[idx] += d_y;
}

} // namespace class_functions
} // namespace custom_derivatives
} // namespace clad
Expand Down
5 changes: 5 additions & 0 deletions lib/Differentiator/CladUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -705,5 +705,10 @@ namespace clad {
return FD->getNameAsString() == "free";
#endif
}

bool isNonConstReferenceType(clang::QualType QT) {
return QT->isReferenceType() &&
!QT.getNonReferenceType().isConstQualified();
}
} // namespace utils
} // namespace clad
36 changes: 34 additions & 2 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1799,6 +1799,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

// Stores differentiation result of implicit `this` object, if any.
StmtDiff baseDiff;
Expr* baseExpr = nullptr;
// If it has more args or f_darg0 was not found, we look for its pullback
// function.
const auto* MD = dyn_cast<CXXMethodDecl>(FD);
Expand All @@ -1822,6 +1823,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
baseOriginalE = OCE->getArg(0);

baseDiff = Visit(baseOriginalE);
baseExpr = baseDiff.getExpr();
Expr* baseDiffStore = GlobalStoreAndRef(baseDiff.getExpr());
baseDiff.updateStmt(baseDiffStore);
Expr* baseDerivative = baseDiff.getExpr_dx();
Expand Down Expand Up @@ -2007,8 +2009,18 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
Expr* call = nullptr;

QualType returnType = FD->getReturnType();
if (returnType->isReferenceType() &&
!returnType.getNonReferenceType().isConstQualified()) {
if (Expr* customForwardPassCE = BuildCallToCustomForwPassFn(
FD, CallArgs, DerivedCallOutputArgs, baseExpr)) {
if (!utils::isNonConstReferenceType(returnType))
return StmtDiff{customForwardPassCE};
auto* callRes = StoreAndRef(customForwardPassCE);
auto* resValue =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "value");
auto* resAdjoint =
utils::BuildMemberExpr(m_Sema, getCurrentScope(), callRes, "adjoint");
return StmtDiff(resValue, nullptr, resAdjoint);
}
if (utils::isNonConstReferenceType(returnType)) {
DiffRequest calleeFnForwPassReq;
calleeFnForwPassReq.Function = FD;
calleeFnForwPassReq.Mode = DiffMode::reverse_mode_forward_pass;
Expand Down Expand Up @@ -4260,4 +4272,24 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
diffParams.end());
return params;
}

Expr* ReverseModeVisitor::BuildCallToCustomForwPassFn(
const FunctionDecl* FD, llvm::ArrayRef<Expr*> primalArgs,
llvm::ArrayRef<clang::Expr*> derivedArgs, Expr* baseExpr) {
std::string forwPassFnName =
clad::utils::ComputeEffectiveFnName(FD) + "_reverse_forw";
llvm::SmallVector<Expr*, 4> args;
if (baseExpr) {
baseExpr = BuildOp(UnaryOperatorKind::UO_AddrOf, baseExpr,
m_DiffReq->getLocation());
args.push_back(baseExpr);
}
args.append(primalArgs.begin(), primalArgs.end());
args.append(derivedArgs.begin(), derivedArgs.end());
Expr* customForwPassCE =
m_Builder.BuildCallToCustomDerivativeOrNumericalDiff(
forwPassFnName, args, getCurrentScope(),
const_cast<DeclContext*>(FD->getDeclContext()));
return customForwPassCE;
}
} // end namespace clad
8 changes: 2 additions & 6 deletions test/Gradient/FunctionCalls.C
Original file line number Diff line number Diff line change
Expand Up @@ -228,7 +228,7 @@ double& identity(double& i) {

namespace clad{
namespace custom_derivatives{
clad::ValueAndAdjoint<double &, double &> custom_identity_forw(double &i, double *d_i) {
clad::ValueAndAdjoint<double &, double &> custom_identity_reverse_forw(double &i, double *d_i) {
return {i, *d_i};
}
} // namespace custom_derivatives
Expand Down Expand Up @@ -260,10 +260,6 @@ double fn7(double i, double j) {

// CHECK: void custom_identity_pullback(double &i, double _d_y, double *_d_i);

// CHECK: clad::ValueAndAdjoint<double &, double &> custom_identity_forw(double &i, double *d_i) {
// CHECK-NEXT: return {i, *d_i};
// CHECK-NEXT: }

// CHECK: void fn7_grad(double i, double j, double *_d_i, double *_d_j) {
// CHECK-NEXT: double _t0 = i;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t1 = identity_forw(i, &*_d_i);
Expand All @@ -274,7 +270,7 @@ double fn7(double i, double j) {
// CHECK-NEXT: double &_d_l = _t3.adjoint;
// CHECK-NEXT: double &l = _t3.value;
// CHECK-NEXT: double _t4 = i;
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t5 = custom_identity_forw(i, &*_d_i);
// CHECK-NEXT: clad::ValueAndAdjoint<double &, double &> _t5 = {{.*}}custom_derivatives::custom_identity_reverse_forw(i, &*_d_i);
// CHECK-NEXT: double &_d_temp = _t5.adjoint;
// CHECK-NEXT: double &temp = _t5.value;
// CHECK-NEXT: double _t6 = k;
Expand Down
Loading

0 comments on commit 6ede83c

Please sign in to comment.