Skip to content

Commit

Permalink
Add a flag for TBR analysis.
Browse files Browse the repository at this point in the history
  • Loading branch information
PetroZarytskyi committed Nov 15, 2023
1 parent 19820b4 commit cec731f
Show file tree
Hide file tree
Showing 5 changed files with 34 additions and 7 deletions.
2 changes: 2 additions & 0 deletions include/clad/Differentiator/DiffPlanner.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ namespace clad {
bool CallUpdateRequired = false;
/// A flag to enable/disable diag warnings/errors during differentiation.
bool VerboseDiags = false;
/// A flag to enable TBR analysis during reverse-mode differentiation.
bool EnableTBRAnalysis = false;
/// Puts the derived function and its code in the diff call
void updateCall(clang::FunctionDecl* FD, clang::FunctionDecl* OverloadedFD,
clang::Sema& SemaRef);
Expand Down
1 change: 1 addition & 0 deletions include/clad/Differentiator/ReverseModeVisitor.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@ namespace clad {
unsigned numParams = 0;
bool isVectorValued = false;
bool use_enzyme = false;
bool enableTBR = false;
// FIXME: Should we make this an object instead of a pointer?
// Downside of making it an object: We will need to include
// 'MultiplexExternalRMVSource.h' file
Expand Down
24 changes: 19 additions & 5 deletions lib/Differentiator/ReverseModeVisitor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -292,6 +292,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
outputArrayStr = m_Function->getParamDecl(lastArgN)->getNameAsString();
}

// Check if DiffRequest asks for TBR analysis to be enabled
if (request.EnableTBRAnalysis)
enableTBR = true;

// Check if DiffRequest asks for use of enzyme as backend
if (request.use_enzyme)
use_enzyme = true;
Expand Down Expand Up @@ -454,9 +458,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
DerivativeAndOverload
ReverseModeVisitor::DerivePullback(const clang::FunctionDecl* FD,
const DiffRequest& request) {
if (request.EnableTBRAnalysis)
enableTBR = true;
TBRAnalyzer analyzer(m_Context);
analyzer.Analyze(FD);
m_ToBeRecorded = analyzer.getResult();
if (enableTBR) {
analyzer.Analyze(FD);
m_ToBeRecorded = analyzer.getResult();
}

// for (auto pair : m_ToBeRecorded) {
// auto line =
Expand Down Expand Up @@ -570,8 +578,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,

void ReverseModeVisitor::DifferentiateWithClad() {
TBRAnalyzer analyzer(m_Context);
analyzer.Analyze(m_Function);
m_ToBeRecorded = analyzer.getResult();
if (enableTBR) {
analyzer.Analyze(m_Function);
m_ToBeRecorded = analyzer.getResult();
}

// for (auto pair : m_ToBeRecorded) {
// auto line =
Expand Down Expand Up @@ -1862,6 +1872,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
pullbackRequest.Mode = DiffMode::experimental_pullback;
// Silence diag outputs in nested derivation process.
pullbackRequest.VerboseDiags = false;
pullbackRequest.EnableTBRAnalysis = enableTBR;
FunctionDecl* pullbackFD = plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest);
// Clad failed to derive it.
// FIXME: Add support for reference arguments to the numerical diff. If
Expand Down Expand Up @@ -2830,7 +2841,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context,
// We lack context to decide if this is useful to store or not. In the
// current system that should have been decided by the parent expression.
// FIXME: Here will be the entry point of the advanced activity analysis.
if (isa<DeclRefExpr>(B) || isa<ArraySubscriptExpr>(B)) {
if (isa<DeclRefExpr>(B) || isa<ArraySubscriptExpr>(B) || isa<MemberExpr>(B)) {
// If TBR analysis is off, assume E is useful to store.
if (!enableTBR)
return true;
auto found = m_ToBeRecorded.find(B->getBeginLoc());
return found != m_ToBeRecorded.end();
}
Expand Down
6 changes: 6 additions & 0 deletions tools/ClangPlugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,12 @@ namespace clad {
if (requests.empty())
return true;

// FIXME: flags have to be set manually since DiffCollector's constructor
// does not have access to m_DO.
if (m_DO.EnableTBRAnalysis)
for (DiffRequest& request : requests)
request.EnableTBRAnalysis = true;

// FIXME: Remove the PerformPendingInstantiations altogether. We should
// somehow make the relevant functions referenced.
// Instantiate all pending for instantiations templates, because we will
Expand Down
8 changes: 6 additions & 2 deletions tools/ClangPlugin.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,15 +67,17 @@ namespace clad {
DifferentiationOptions()
: DumpSourceFn(false), DumpSourceFnAST(false), DumpDerivedFn(false),
DumpDerivedAST(false), GenerateSourceFile(false),
ValidateClangVersion(true), CustomEstimationModel(false),
PrintNumDiffErrorInfo(false), CustomModelName("") {}
ValidateClangVersion(true), EnableTBRAnalysis(false),
CustomEstimationModel(false), PrintNumDiffErrorInfo(false),
CustomModelName("") {}

bool DumpSourceFn : 1;
bool DumpSourceFnAST : 1;
bool DumpDerivedFn : 1;
bool DumpDerivedAST : 1;
bool GenerateSourceFile : 1;
bool ValidateClangVersion : 1;
bool EnableTBRAnalysis : 1;
bool CustomEstimationModel : 1;
bool PrintNumDiffErrorInfo : 1;
std::string CustomModelName;
Expand Down Expand Up @@ -132,6 +134,8 @@ namespace clad {
m_DO.GenerateSourceFile = true;
} else if (args[i] == "-fno-validate-clang-version") {
m_DO.ValidateClangVersion = false;
} else if (args[i] == "-enable-tbr") {
m_DO.EnableTBRAnalysis = true;
} else if (args[i] == "-fcustom-estimation-model") {
m_DO.CustomEstimationModel = true;
if (++i == e) {
Expand Down

0 comments on commit cec731f

Please sign in to comment.