From 115df3f53a1f41fbdebaef72cb17e6b7152a8cde Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Fri, 11 Aug 2023 10:52:28 -0400 Subject: [PATCH] wip --- enzyme/Enzyme/Clang/EnzymePassLoader.cpp | 7 +++++++ enzyme/Enzyme/Enzyme.cpp | 26 ++++++++---------------- enzyme/Enzyme/OptBlas.cpp | 4 ++-- enzyme/Enzyme/OptBlas.h | 7 +++---- enzyme/test/Integration/BlasOpt/first.c | 4 ++-- 5 files changed, 22 insertions(+), 26 deletions(-) diff --git a/enzyme/Enzyme/Clang/EnzymePassLoader.cpp b/enzyme/Enzyme/Clang/EnzymePassLoader.cpp index b77608f06c81..c689f4af2ed1 100644 --- a/enzyme/Enzyme/Clang/EnzymePassLoader.cpp +++ b/enzyme/Enzyme/Clang/EnzymePassLoader.cpp @@ -34,6 +34,7 @@ #include "llvm/Transforms/Scalar/GVN.h" #include "../Enzyme.h" +#include "../OptBlas.h" #include "../PreserveNVVM.h" using namespace llvm; @@ -57,6 +58,10 @@ static void loadNVVMPass(const PassManagerBuilder &Builder, legacy::PassManagerBase &PM) { PM.add(createPreserveNVVMPass(/*Begin=*/true)); } +static void loadBLASPass(const PassManagerBuilder &Builder, + legacy::PassManagerBase &PM) { + PM.add(createOptimizeBlasPass(/*Begin=*/true)); +} // These constructors add our pass to a list of global extensions. static RegisterStandardPasses @@ -66,6 +71,8 @@ static RegisterStandardPasses static RegisterStandardPasses clangtoolLoader_OEarly(PassManagerBuilder::EP_EarlyAsPossible, loadNVVMPass); +static RegisterStandardPasses + clangtoolLoader_Ox(PassManagerBuilder::EP_VectorizerStart, loadBLASPass); static void loadLTOPass(const PassManagerBuilder &Builder, legacy::PassManagerBase &PM) { diff --git a/enzyme/Enzyme/Enzyme.cpp b/enzyme/Enzyme/Enzyme.cpp index 7e89eb77df3c..63e28b12963a 100644 --- a/enzyme/Enzyme/Enzyme.cpp +++ b/enzyme/Enzyme/Enzyme.cpp @@ -3140,7 +3140,7 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { auto loadPass = [prePass](ModulePassManager &MPM) #endif { - MPM.addPass(OptimizeBlasNewPM(/*Begin*/ false)); + // MPM.addPass(OptimizeBlasNewPM(/*Begin*/ false)); MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true)); #if LLVM_VERSION_MAJOR >= 12 @@ -3163,8 +3163,6 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { #endif MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizerPM))); MPM.addPass(EnzymeNewPM(/*PostOpt=*/true)); - // Manuel, new - MPM.addPass(OptimizeBlasNewPM(/*Begin*/ false)); MPM.addPass(PreserveNVVMNewPM(/*Begin*/ false)); #if LLVM_VERSION_MAJOR >= 16 OptimizerPM2.addPass(llvm::GVNPass()); @@ -3193,18 +3191,6 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { PB.registerPipelineStartEPCallback(loadPass); #endif -#if LLVM_VERSION_MAJOR >= 12 - auto optBLAS = [](ModulePassManager &MPM, OptimizationLevel) -#else - auto optBLAS = [](ModulePassManager &MPM) -#endif - { MPM.addPass(OptimizeBlasNewPM(/*Begin*/ true)); }; - - // We should register at vectorizer start for consistency, however, - // that requires a functionpass, and we have a modulepass. - // PB.registerVectorizerStartEPCallback(loadPass); - PB.registerPipelineStartEPCallback(optBLAS); - #if LLVM_VERSION_MAJOR >= 12 auto loadNVVM = [](ModulePassManager &MPM, OptimizationLevel) #else @@ -3217,7 +3203,6 @@ void augmentPassBuilder(llvm::PassBuilder &PB) { // PB.registerVectorizerStartEPCallback(loadPass); PB.registerPipelineStartEPCallback(loadNVVM); #if LLVM_VERSION_MAJOR >= 15 - PB.registerFullLinkTimeOptimizationEarlyEPCallback(optBLAS); PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadNVVM); auto preLTOPass = [](ModulePassManager &MPM, OptimizationLevel Level) { @@ -3467,10 +3452,12 @@ llvmGetPassPluginInfo() { #ifdef ENZYME_RUNPASS augmentPassBuilder(PB); #endif + llvm::errs() << "CCCCCC\n"; PB.registerPipelineParsingCallback( [](llvm::StringRef Name, llvm::ModulePassManager &MPM, llvm::ArrayRef) { if (Name == "blas-opt") { + llvm::errs() << "AAAA\n"; MPM.addPass(OptimizeBlasNewPM(/*Begin*/ true)); return true; } @@ -3486,7 +3473,9 @@ llvmGetPassPluginInfo() { MPM.addPass(TypeAnalysisPrinterNewPM()); return true; } - return false; + llvm::errs() << "BBBB\n"; + return true; + // return false; }); PB.registerPipelineParsingCallback( [](llvm::StringRef Name, llvm::FunctionPassManager &FPM, @@ -3495,7 +3484,8 @@ llvmGetPassPluginInfo() { FPM.addPass(ActivityAnalysisPrinterNewPM()); return true; } - return false; + return true; + // return false; }); }}; } diff --git a/enzyme/Enzyme/OptBlas.cpp b/enzyme/Enzyme/OptBlas.cpp index c1dcea74c424..d264fc8aaf61 100644 --- a/enzyme/Enzyme/OptBlas.cpp +++ b/enzyme/Enzyme/OptBlas.cpp @@ -131,8 +131,8 @@ bool optimizeFncsWithBlas(llvm::Module &M) { OptimizeBlasNewPM::Result OptimizeBlasNewPM::run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) { - llvm::errs() << "fooBar\n"; + llvm::errs() << "newPM opt-blas\n"; bool changed = optimizeFncsWithBlas(M); return changed ? PreservedAnalyses::none() : PreservedAnalyses::all(); } -llvm::AnalysisKey OptimizeBlasNewPM::Key; +// llvm::AnalysisKey OptimizeBlasNewPM::Key; diff --git a/enzyme/Enzyme/OptBlas.h b/enzyme/Enzyme/OptBlas.h index f354c22b0d5f..31d16b7a4f12 100644 --- a/enzyme/Enzyme/OptBlas.h +++ b/enzyme/Enzyme/OptBlas.h @@ -33,13 +33,12 @@ bool optimizeFncsWithBlas(llvm::Module &M); llvm::ModulePass *createOptimizeBlasPass(bool Begin); -class OptimizeBlasNewPM final - : public llvm::AnalysisInfoMixin { - friend struct llvm::AnalysisInfoMixin; +class OptimizeBlasNewPM final : public llvm::PassInfoMixin { + friend struct llvm::PassInfoMixin; private: bool Begin; - static llvm::AnalysisKey Key; + // static llvm::AnalysisKey Key; public: using Result = llvm::PreservedAnalyses; diff --git a/enzyme/test/Integration/BlasOpt/first.c b/enzyme/test/Integration/BlasOpt/first.c index 37f98b43bb35..259a43bfaa98 100644 --- a/enzyme/test/Integration/BlasOpt/first.c +++ b/enzyme/test/Integration/BlasOpt/first.c @@ -13,7 +13,7 @@ void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA const int lda, const double *B, const int ldb, const double beta, double *C, const int ldc); -void g(double *restrict x, double *restrict y, double *restrict v, double *restrict w, double *restrict C) { +void f(double *restrict x, double *restrict y, double *restrict v, double *restrict w, double *restrict C) { double A[] = {0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00}; @@ -38,7 +38,7 @@ int main() { double C[] = {0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00, 0.00}; - g(x,y,v,w,C); + f(x,y,v,w,C); for (int i = 0; i < 9; i++) printf("%f\n", C[i]); }