From cec731f4f55578c13e7296ecb778a0e3ae0708a8 Mon Sep 17 00:00:00 2001 From: "petro.zarytskyi" Date: Wed, 15 Nov 2023 11:10:00 +0200 Subject: [PATCH] Add a flag for TBR analysis. --- include/clad/Differentiator/DiffPlanner.h | 2 ++ .../clad/Differentiator/ReverseModeVisitor.h | 1 + lib/Differentiator/ReverseModeVisitor.cpp | 24 +++++++++++++++---- tools/ClangPlugin.cpp | 6 +++++ tools/ClangPlugin.h | 8 +++++-- 5 files changed, 34 insertions(+), 7 deletions(-) diff --git a/include/clad/Differentiator/DiffPlanner.h b/include/clad/Differentiator/DiffPlanner.h index bd51548e6..2705cc447 100644 --- a/include/clad/Differentiator/DiffPlanner.h +++ b/include/clad/Differentiator/DiffPlanner.h @@ -43,6 +43,8 @@ namespace clad { bool CallUpdateRequired = false; /// A flag to enable/disable diag warnings/errors during differentiation. bool VerboseDiags = false; + /// A flag to enable TBR analysis during reverse-mode differentiation. + bool EnableTBRAnalysis = false; /// Puts the derived function and its code in the diff call void updateCall(clang::FunctionDecl* FD, clang::FunctionDecl* OverloadedFD, clang::Sema& SemaRef); diff --git a/include/clad/Differentiator/ReverseModeVisitor.h b/include/clad/Differentiator/ReverseModeVisitor.h index 7e285365e..e9d36d0a1 100644 --- a/include/clad/Differentiator/ReverseModeVisitor.h +++ b/include/clad/Differentiator/ReverseModeVisitor.h @@ -64,6 +64,7 @@ namespace clad { unsigned numParams = 0; bool isVectorValued = false; bool use_enzyme = false; + bool enableTBR = false; // FIXME: Should we make this an object instead of a pointer? // Downside of making it an object: We will need to include // 'MultiplexExternalRMVSource.h' file diff --git a/lib/Differentiator/ReverseModeVisitor.cpp b/lib/Differentiator/ReverseModeVisitor.cpp index c40c729d8..4b1f556ab 100644 --- a/lib/Differentiator/ReverseModeVisitor.cpp +++ b/lib/Differentiator/ReverseModeVisitor.cpp @@ -292,6 +292,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, outputArrayStr = m_Function->getParamDecl(lastArgN)->getNameAsString(); } + // Check if DiffRequest asks for TBR analysis to be enabled + if (request.EnableTBRAnalysis) + enableTBR = true; + // Check if DiffRequest asks for use of enzyme as backend if (request.use_enzyme) use_enzyme = true; @@ -454,9 +458,13 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, DerivativeAndOverload ReverseModeVisitor::DerivePullback(const clang::FunctionDecl* FD, const DiffRequest& request) { + if (request.EnableTBRAnalysis) + enableTBR = true; TBRAnalyzer analyzer(m_Context); - analyzer.Analyze(FD); - m_ToBeRecorded = analyzer.getResult(); + if (enableTBR) { + analyzer.Analyze(FD); + m_ToBeRecorded = analyzer.getResult(); + } // for (auto pair : m_ToBeRecorded) { // auto line = @@ -570,8 +578,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, void ReverseModeVisitor::DifferentiateWithClad() { TBRAnalyzer analyzer(m_Context); - analyzer.Analyze(m_Function); - m_ToBeRecorded = analyzer.getResult(); + if (enableTBR) { + analyzer.Analyze(m_Function); + m_ToBeRecorded = analyzer.getResult(); + } // for (auto pair : m_ToBeRecorded) { // auto line = @@ -1862,6 +1872,7 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, pullbackRequest.Mode = DiffMode::experimental_pullback; // Silence diag outputs in nested derivation process. pullbackRequest.VerboseDiags = false; + pullbackRequest.EnableTBRAnalysis = enableTBR; FunctionDecl* pullbackFD = plugin::ProcessDiffRequest(m_CladPlugin, pullbackRequest); // Clad failed to derive it. // FIXME: Add support for reference arguments to the numerical diff. If @@ -2830,7 +2841,10 @@ Expr* getArraySizeExpr(const ArrayType* AT, ASTContext& context, // We lack context to decide if this is useful to store or not. In the // current system that should have been decided by the parent expression. // FIXME: Here will be the entry point of the advanced activity analysis. - if (isa(B) || isa(B)) { + if (isa(B) || isa(B) || isa(B)) { + // If TBR analysis is off, assume E is useful to store. + if (!enableTBR) + return true; auto found = m_ToBeRecorded.find(B->getBeginLoc()); return found != m_ToBeRecorded.end(); } diff --git a/tools/ClangPlugin.cpp b/tools/ClangPlugin.cpp index b3c374218..70e39aa1b 100644 --- a/tools/ClangPlugin.cpp +++ b/tools/ClangPlugin.cpp @@ -143,6 +143,12 @@ namespace clad { if (requests.empty()) return true; + // FIXME: flags have to be set manually since DiffCollector's constructor + // does not have access to m_DO. + if (m_DO.EnableTBRAnalysis) + for (DiffRequest& request : requests) + request.EnableTBRAnalysis = true; + // FIXME: Remove the PerformPendingInstantiations altogether. We should // somehow make the relevant functions referenced. // Instantiate all pending for instantiations templates, because we will diff --git a/tools/ClangPlugin.h b/tools/ClangPlugin.h index 4d570019e..688b1b869 100644 --- a/tools/ClangPlugin.h +++ b/tools/ClangPlugin.h @@ -67,8 +67,9 @@ namespace clad { DifferentiationOptions() : DumpSourceFn(false), DumpSourceFnAST(false), DumpDerivedFn(false), DumpDerivedAST(false), GenerateSourceFile(false), - ValidateClangVersion(true), CustomEstimationModel(false), - PrintNumDiffErrorInfo(false), CustomModelName("") {} + ValidateClangVersion(true), EnableTBRAnalysis(false), + CustomEstimationModel(false), PrintNumDiffErrorInfo(false), + CustomModelName("") {} bool DumpSourceFn : 1; bool DumpSourceFnAST : 1; @@ -76,6 +77,7 @@ namespace clad { bool DumpDerivedAST : 1; bool GenerateSourceFile : 1; bool ValidateClangVersion : 1; + bool EnableTBRAnalysis : 1; bool CustomEstimationModel : 1; bool PrintNumDiffErrorInfo : 1; std::string CustomModelName; @@ -132,6 +134,8 @@ namespace clad { m_DO.GenerateSourceFile = true; } else if (args[i] == "-fno-validate-clang-version") { m_DO.ValidateClangVersion = false; + } else if (args[i] == "-enable-tbr") { + m_DO.EnableTBRAnalysis = true; } else if (args[i] == "-fcustom-estimation-model") { m_DO.CustomEstimationModel = true; if (++i == e) {