Skip to content

Commit

Permalink
Fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
parth-07 committed Apr 27, 2022
1 parent 8f6aa67 commit 74904d8
Show file tree
Hide file tree
Showing 8 changed files with 106 additions and 88 deletions.
2 changes: 1 addition & 1 deletion include/clad/Differentiator/DerivativeBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ namespace clad {
friend class ReverseModeVisitor;
friend class HessianModeVisitor;
friend class JacobianModeVisitor;
friend class TransformSourceFnVisitor;
friend class ReverseModeForwPassVisitor;

clang::Sema& m_Sema;
plugin::CladPlugin& m_CladPlugin;
Expand Down
2 changes: 1 addition & 1 deletion include/clad/Differentiator/DiffMode.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace clad {
experimental_pushforward,
experimental_pullback,
reverse,
reverse_source_fn,
reverse_mode_forward_pass,
hessian,
jacobian,
error_estimation
Expand Down
Original file line number Diff line number Diff line change
@@ -1,32 +1,31 @@
#ifndef CLAD_TRANSFORM_SOURCE_FN_VISITOR_H
#define CLAD_TRANSFORM_SOURCE_FN_VISITOR_H

#include "clad/Differentiator/ReverseModeVisitor.h"
#include "clad/Differentiator/ParseDiffArgsTypes.h"
#include "clad/Differentiator/ReverseModeVisitor.h"

#include "clang/AST/StmtVisitor.h"
#include "clang/Sema/Sema.h"

#include "llvm/ADT/SmallVector.h"

namespace clad {
class TransformSourceFnVisitor
: public ReverseModeVisitor {
class ReverseModeForwPassVisitor : public ReverseModeVisitor {
private:
Stmts m_Globals;

llvm::SmallVector<clang::QualType, 8>
ComputeParamTypes(const DiffParams& diffParams);
clang::QualType ComputeReturnType();
llvm::SmallVector<clang::ParmVarDecl*, 8>
BuildParams(DiffParams& diffParams);
clang::QualType GetParameterDerivativeType(clang::QualType yType, clang::QualType xType);
llvm::SmallVector<clang::ParmVarDecl*, 8> BuildParams(DiffParams& diffParams);
clang::QualType GetParameterDerivativeType(clang::QualType yType,
clang::QualType xType);

public:
TransformSourceFnVisitor(DerivativeBuilder& builder);
ReverseModeForwPassVisitor(DerivativeBuilder& builder);
OverloadedDeclWithContext Derive(const clang::FunctionDecl* FD,
const DiffRequest& request);

StmtDiff ProcessSingleStmt(const clang::Stmt* S);

StmtDiff VisitStmt(const clang::Stmt* S) override;
Expand Down
2 changes: 1 addition & 1 deletion lib/Differentiator/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ add_llvm_library(cladDifferentiator
HessianModeVisitor.cpp
JacobianModeVisitor.cpp
MultiplexExternalRMVSource.cpp
ReverseModeForwPassVisitor.cpp
ReverseModeVisitor.cpp
TransformSourceFnVisitor.cpp
ErrorEstimator.cpp
EstimationModel.cpp
StmtClone.cpp
Expand Down
6 changes: 3 additions & 3 deletions lib/Differentiator/DerivativeBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
#include "clad/Differentiator/HessianModeVisitor.h"
#include "clad/Differentiator/JacobianModeVisitor.h"
#include "clad/Differentiator/ReverseModeVisitor.h"
#include "clad/Differentiator/TransformSourceFnVisitor.h"
#include "clad/Differentiator/ReverseModeForwPassVisitor.h"

#include "clad/Differentiator/DiffPlanner.h"
#include "clad/Differentiator/StmtClone.h"
Expand Down Expand Up @@ -173,8 +173,8 @@ namespace clad {
} else if (request.Mode == DiffMode::experimental_pullback) {
ReverseModeVisitor V(*this);
result = V.DerivePullback(FD, request);
} else if (request.Mode == DiffMode::reverse_source_fn) {
TransformSourceFnVisitor V(*this);
} else if (request.Mode == DiffMode::reverse_mode_forward_pass) {
ReverseModeForwPassVisitor V(*this);
result = V.Derive(FD, request);
}
else if (request.Mode == DiffMode::hessian) {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#include "clad/Differentiator/TransformSourceFnVisitor.h"
#include "clad/Differentiator/ReverseModeForwPassVisitor.h"

#include "clad/Differentiator/CladUtils.h"
#include "clad/Differentiator/DiffPlanner.h"
Expand All @@ -12,23 +12,24 @@ using namespace clang;

namespace clad {

TransformSourceFnVisitor::TransformSourceFnVisitor(DerivativeBuilder& builder)
ReverseModeForwPassVisitor::ReverseModeForwPassVisitor(
DerivativeBuilder& builder)
: ReverseModeVisitor(builder) {}

OverloadedDeclWithContext
TransformSourceFnVisitor::Derive(const FunctionDecl* FD,
const DiffRequest& request) {
ReverseModeForwPassVisitor::Derive(const FunctionDecl* FD,
const DiffRequest& request) {
silenceDiags = !request.VerboseDiags;
m_Function = FD;

m_Mode = DiffMode::reverse_source_fn;
m_Mode = DiffMode::reverse_mode_forward_pass;

assert(m_Function && "Must not be null.");

DiffParams args{};
std::copy(FD->param_begin(), FD->param_end(), std::back_inserter(args));

auto fnName = m_Function->getNameAsString() + "_with_adjoint";
auto fnName = m_Function->getNameAsString() + "_forw";
auto fnDNI = utils::BuildDeclarationNameInfo(m_Sema, fnName);

auto paramTypes = ComputeParamTypes(args);
Expand Down Expand Up @@ -86,8 +87,9 @@ TransformSourceFnVisitor::Derive(const FunctionDecl* FD,

// FIXME: This function is copied from ReverseModeVisitor. Find a suitable place
// for it.
QualType TransformSourceFnVisitor::GetParameterDerivativeType(QualType yType,
QualType xType) {
QualType
ReverseModeForwPassVisitor::GetParameterDerivativeType(QualType yType,
QualType xType) {
assert(yType.getNonReferenceType()->isRealType() &&
"yType should be a builtin-numerical scalar type!!");
QualType xValueType = utils::GetValueType(xType);
Expand All @@ -101,7 +103,7 @@ QualType TransformSourceFnVisitor::GetParameterDerivativeType(QualType yType,
}

llvm::SmallVector<clang::QualType, 8>
TransformSourceFnVisitor::ComputeParamTypes(const DiffParams& diffParams) {
ReverseModeForwPassVisitor::ComputeParamTypes(const DiffParams& diffParams) {
llvm::SmallVector<clang::QualType, 8> paramTypes;
paramTypes.reserve(m_Function->getNumParams() * 2);
for (auto PVD : m_Function->parameters())
Expand Down Expand Up @@ -129,15 +131,15 @@ TransformSourceFnVisitor::ComputeParamTypes(const DiffParams& diffParams) {
return paramTypes;
}

clang::QualType TransformSourceFnVisitor::ComputeReturnType() {
clang::QualType ReverseModeForwPassVisitor::ComputeReturnType() {
auto valAndAdjointTempDecl = GetCladClassDecl("ValueAndAdjoint");
auto RT = m_Function->getReturnType();
auto T = GetCladClassOfType(valAndAdjointTempDecl, {RT, RT});
return T;
}

llvm::SmallVector<clang::ParmVarDecl*, 8>
TransformSourceFnVisitor::BuildParams(DiffParams& diffParams) {
ReverseModeForwPassVisitor::BuildParams(DiffParams& diffParams) {
llvm::SmallVector<clang::ParmVarDecl*, 8> params, paramDerivatives;
params.reserve(m_Function->getNumParams() + diffParams.size());
auto derivativeFnType = cast<FunctionProtoType>(m_Derivative->getType());
Expand Down Expand Up @@ -207,17 +209,17 @@ TransformSourceFnVisitor::BuildParams(DiffParams& diffParams) {
return params;
}

StmtDiff TransformSourceFnVisitor::ProcessSingleStmt(const clang::Stmt* S) {
StmtDiff ReverseModeForwPassVisitor::ProcessSingleStmt(const clang::Stmt* S) {
StmtDiff SDiff = Visit(S);
return {SDiff.getStmt()};
}

StmtDiff TransformSourceFnVisitor::VisitStmt(const clang::Stmt* S) {
StmtDiff ReverseModeForwPassVisitor::VisitStmt(const clang::Stmt* S) {
return {Clone(S)};
}

StmtDiff
TransformSourceFnVisitor::VisitCompoundStmt(const clang::CompoundStmt* CS) {
ReverseModeForwPassVisitor::VisitCompoundStmt(const clang::CompoundStmt* CS) {
beginScope(Scope::DeclScope);
beginBlock();
for (Stmt* S : CS->body()) {
Expand All @@ -229,7 +231,7 @@ TransformSourceFnVisitor::VisitCompoundStmt(const clang::CompoundStmt* CS) {
return {forward};
}

StmtDiff TransformSourceFnVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
StmtDiff ReverseModeForwPassVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
DeclRefExpr* clonedDRE = nullptr;
// Check if referenced Decl was "replaced" with another identifier inside
// the derivative
Expand Down Expand Up @@ -265,17 +267,17 @@ StmtDiff TransformSourceFnVisitor::VisitDeclRefExpr(const DeclRefExpr* DRE) {
}

StmtDiff
TransformSourceFnVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) {
ReverseModeForwPassVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) {
const Expr* value = RS->getRetValue();
auto returnDiff = Visit(value);
llvm::SmallVector<Expr*, 2> returnArgs = {returnDiff.getExpr(),
returnDiff.getExpr_dx()};
Expr* returnInitList = m_Sema.BuildInitList(noLoc, returnArgs, noLoc).get();
Expr* returnInitList = m_Sema.ActOnInitList(noLoc, returnArgs, noLoc).get();
Stmt* newRS = m_Sema.BuildReturnStmt(noLoc, returnInitList).get();
return {newRS};
}

// StmtDiff TransformSourceFnVisitor::VisitDeclStmt(const DeclStmt* DS) {
// StmtDiff ReverseModeForwPassVisitor::VisitDeclStmt(const DeclStmt* DS) {
// llvm::SmallVector<Decl*, 4> decls, derivedDecls;
// for (auto D : DS->decls()) {
// if (auto VD = dyn_cast<VarDecl>(D)) {
Expand All @@ -292,42 +294,48 @@ TransformSourceFnVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) {
// }
// }

// VarDeclDiff TransformSourceFnVisitor::DifferentiateVarDecl(const VarDecl* VD) {
// VarDeclDiff ReverseModeForwPassVisitor::DifferentiateVarDecl(const VarDecl*
// VD) {
// StmtDiff initDiff;
// Expr* VDDerivedInit = nullptr;
// auto VDDerivedType = VD->getType();
// bool isVDRefType = VD->getType()->isReferenceType();
// VarDecl* VDDerived = nullptr;

// if (auto VDCAT = dyn_cast<ConstantArrayType>(VD->getType())) {
// assert("Should not reach here!!!");
// // VDDerivedType =
// // GetCladArrayOfType(QualType(VDCAT->getPointeeOrArrayElementType(),
// // VDCAT->getIndexTypeCVRQualifiers()));
// // GetCladArrayOfType(QualType(VDCAT->getPointeeOrArrayElementType(),
// // VDCAT->getIndexTypeCVRQualifiers()));
// // VDDerivedInit = ConstantFolder::synthesizeLiteral(
// // m_Context.getSizeType(), m_Context, VDCAT->getSize().getZExtValue());
// // VDDerived = BuildVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(),
// // m_Context.getSizeType(), m_Context,
// VDCAT->getSize().getZExtValue());
// // VDDerived = BuildVarDecl(VDDerivedType, "_d_" +
// VD->getNameAsString(),
// // VDDerivedInit, false, nullptr,
// // clang::VarDecl::InitializationStyle::CallInit);
// // clang::VarDecl::InitializationStyle::CallInit);
// } else {
// // If VD is a reference to a local variable, then the initial value is set
// // If VD is a reference to a local variable, then the initial value is
// set
// // to the derived variable of the corresponding local variable.
// // If VD is a reference to a non-local variable (global variable, struct
// // If VD is a reference to a non-local variable (global variable,
// struct
// // member etc), then no derived variable is available, thus `VDDerived`
// // does not need to reference any variable, consequentially the
// // `VDDerivedType` is the corresponding non-reference type and the initial
// // `VDDerivedType` is the corresponding non-reference type and the
// initial
// // value is set to 0.
// // Otherwise, for non-reference types, the initial value is set to 0.
// VDDerivedInit = getZeroInit(VD->getType());

// // `specialThisDiffCase` is only required for correctly differentiating
// // the following code:
// // the following code:
// // ```
// // Class _d_this_obj;
// // Class* _d_this = &_d_this_obj;
// // ```
// // Computation of hessian requires this code to be correctly
// // differentiated.
// // differentiated.
// bool specialThisDiffCase = false;
// if (auto MD = dyn_cast<CXXMethodDecl>(m_Function)) {
// if (VDDerivedType->isPointerType() && MD->isInstance()) {
Expand Down Expand Up @@ -361,7 +369,8 @@ TransformSourceFnVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) {
// m_Context.getTrivialTypeSourceInfo(VDDerivedType),
// VD->getInitStyle());
// else
// VDDerived = BuildVarDecl(VDDerivedType, "_d_" + VD->getNameAsString(),
// VDDerived = BuildVarDecl(VDDerivedType, "_d_" +
// VD->getNameAsString(),
// VDDerivedInit);
// }

Expand All @@ -370,12 +379,14 @@ TransformSourceFnVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) {
// // If `VD` is a reference to a non-local variable then also there's no
// // need to call `Visit` since non-local variables are not differentiated.
// if (!isVDRefType) {
// initDiff = VD->getInit() ? Visit(VD->getInit(), BuildDeclRef(VDDerived))
// initDiff = VD->getInit() ? Visit(VD->getInit(),
// BuildDeclRef(VDDerived))
// : StmtDiff{};

// // If we are differentiating `VarDecl` corresponding to a local variable
// // If we are differentiating `VarDecl` corresponding to a local
// variable
// // inside a loop, then we need to reset it to 0 at each iteration.
// //
// //
// // for example, if defined inside a loop,
// // ```
// // double localVar = i;
Expand All @@ -385,7 +396,7 @@ TransformSourceFnVisitor::VisitReturnStmt(const clang::ReturnStmt* RS) {
// // {
// // *_d_i += _d_localVar;
// // _d_localVar = 0;
// // }
// // }
// if (isInsideLoop) {
// Stmt* assignToZero = BuildOp(BinaryOperatorKind::BO_Assign,
// BuildDeclRef(VDDerived),
Expand Down
19 changes: 10 additions & 9 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1704,7 +1704,7 @@ namespace clad {
if (FD->getReturnType()->isReferenceType()) {
DiffRequest transformReq;
transformReq.Function = FD;
transformReq.Mode = DiffMode::reverse_source_fn;
transformReq.Mode = DiffMode::reverse_mode_forward_pass;
transformReq.BaseFunctionName = FD->getNameAsString();
transformReq.VerboseDiags = true;
FunctionDecl* transformedSourcefn = plugin::ProcessDiffRequest(m_CladPlugin, transformReq);
Expand Down Expand Up @@ -2207,8 +2207,13 @@ namespace clad {
}

if (isVDRefType) {
QualType T = utils::GetValueType(VD->getType());
T.removeLocalConst();
VDDerivedType =
m_Context.getPointerType(VDDerivedType.getNonReferenceType());
m_Context.getPointerType(T);
initDiff = Visit(VD->getInit());
if (!initDiff.getExpr_dx())
VDDerivedType = VDDerivedType.getNonReferenceType();
VDDerivedInit = getZeroInit(VDDerivedType);
}

Expand All @@ -2221,15 +2226,11 @@ namespace clad {
// ```
// Computation of hessian requires this code to be correctly
// differentiated.
if (isVDRefType || specialThisDiffCase) {
if (specialThisDiffCase) {
VDDerivedType = getNonConstType(VDDerivedType, m_Context, m_Sema);
initDiff = Visit(VD->getInit());
// if (initDiff.getExpr_dx())
// VDDerivedInit = initDiff.getExpr_dx();
// else
// VDDerivedType = VDDerivedType.getNonReferenceType();
if (!initDiff.getExpr_dx())
VDDerivedType = VDDerivedType.getNonReferenceType();
if (initDiff.getExpr_dx())
VDDerivedInit = initDiff.getExpr_dx();
}
// Here separate behaviour for record and non-record types is only
// necessary to preserve the old tests.
Expand Down
Loading

0 comments on commit 74904d8

Please sign in to comment.