Skip to content

Commit

Permalink
Implement forward-mode jacobians
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Oct 7, 2024
1 parent 8ed2707 commit 558c337
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 3 deletions.
22 changes: 22 additions & 0 deletions include/clad/Differentiator/JacobianModeVisitor.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
#ifndef CLAD_DIFFERENTIATOR_JACOBIANMODEVISITOR_H
#define CLAD_DIFFERENTIATOR_JACOBIANMODEVISITOR_H

#include "VectorPushForwardModeVisitor.h"

namespace clad {
class JacobianModeVisitor : public VectorPushForwardModeVisitor {

public:
JacobianModeVisitor(DerivativeBuilder& builder, const DiffRequest& request);

DiffMode GetPushForwardMode() override;

std::string GetPushForwardFunctionSuffix() override;

DerivativeAndOverload DeriveJacobian();

StmtDiff VisitReturnStmt(const clang::ReturnStmt* RS) override;
};
} // end namespace clad

#endif // CLAD_DIFFERENTIATOR_JACOBIANMODEVISITOR_H
6 changes: 5 additions & 1 deletion lib/Differentiator/BaseForwardModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,8 @@ DerivativeAndOverload BaseForwardModeVisitor::Derive() {
clang::QualType BaseForwardModeVisitor::ComputePushforwardFnReturnType() {
assert(m_DiffReq.Mode == GetPushForwardMode());
QualType originalFnRT = m_DiffReq->getReturnType();
if (m_DiffReq.Mode == DiffMode::jacobian)
return GetPushForwardDerivativeType(originalFnRT);
if (originalFnRT->isVoidType())
return m_Context.VoidTy;
TemplateDecl* valueAndPushforward =
Expand Down Expand Up @@ -1445,7 +1447,9 @@ BaseForwardModeVisitor::VisitBinaryOperator(const BinaryOperator* BinOp) {
derivedR = BuildParens(derivedR);
opDiff = BuildOp(opCode, derivedL, derivedR);
} else if (BinOp->isAssignmentOp()) {
if (Ldiff.getExpr_dx()->isModifiableLvalue(m_Context) != Expr::MLV_Valid) {
if ((Ldiff.getExpr_dx()->isModifiableLvalue(m_Context) !=
Expr::MLV_Valid) &&
!isCladArrayType(Ldiff.getExpr_dx()->getType())) {
diag(DiagnosticsEngine::Warning, BinOp->getEndLoc(),
"derivative of an assignment attempts to assign to unassignable "
"expr, assignment ignored");
Expand Down
1 change: 1 addition & 0 deletions lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ llvm_add_library(cladDifferentiator
DiffPlanner.cpp
ErrorEstimator.cpp
EstimationModel.cpp
JacobianModeVisitor.cpp
HessianModeVisitor.cpp
MultiplexExternalRMVSource.cpp
PushForwardModeVisitor.cpp
Expand Down
5 changes: 3 additions & 2 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/ErrorEstimator.h"
#include "clad/Differentiator/HessianModeVisitor.h"
#include "clad/Differentiator/JacobianModeVisitor.h"
#include "clad/Differentiator/PushForwardModeVisitor.h"
#include "clad/Differentiator/ReverseModeForwPassVisitor.h"
#include "clad/Differentiator/ReverseModeVisitor.h"
Expand Down Expand Up @@ -429,8 +430,8 @@ static void registerDerivative(FunctionDecl* derivedFD, Sema& semaRef) {
HessianModeVisitor H(*this, request);
result = H.Derive();
} else if (request.Mode == DiffMode::jacobian) {
ReverseModeVisitor R(*this, request);
result = R.Derive();
JacobianModeVisitor J(*this, request);
result = J.DeriveJacobian();
} else if (request.Mode == DiffMode::error_estimation) {
ReverseModeVisitor R(*this, request);
InitErrorEstimation(m_ErrorEstHandler, m_EstModel, *this, request);
Expand Down
43 changes: 43 additions & 0 deletions lib/Differentiator/JacobianModeVisitor.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
#include "clad/Differentiator/JacobianModeVisitor.h"

#include "ConstantFolder.h"
#include "clad/Differentiator/CladUtils.h"

#include "llvm/Support/SaveAndRestore.h"

using namespace clang;

namespace clad {
JacobianModeVisitor::JacobianModeVisitor(DerivativeBuilder& builder,
const DiffRequest& request)
: VectorPushForwardModeVisitor(builder, request) {}

DiffMode JacobianModeVisitor::GetPushForwardMode() {
return DiffMode::jacobian;
}

std::string JacobianModeVisitor::GetPushForwardFunctionSuffix() {
return "_jac";
}

DerivativeAndOverload JacobianModeVisitor::DeriveJacobian() {
return DerivePushforward(m_DiffReq.Function, m_DiffReq);
}

StmtDiff JacobianModeVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) {
// If there is no return value, we must not attempt to differentiate
if (!RS->getRetValue())
return nullptr;

StmtDiff retValDiff = Visit(RS->getRetValue());
// return StmtDiff(retValDiff.getExpr_dx());
// This can instantiate as part of the move or copy initialization and
// needs a fake source location.
SourceLocation fakeLoc = utils::GetValidSLoc(m_Sema);
Stmt* returnStmt =
m_Sema
.ActOnReturnStmt(fakeLoc, retValDiff.getExpr_dx(), getCurrentScope())
.get();
return StmtDiff(returnStmt);
}
} // end namespace clad

0 comments on commit 558c337

Please sign in to comment.