Skip to content

Commit

Permalink
wip
Browse files Browse the repository at this point in the history
  • Loading branch information
Manuel Drehwald committed Aug 11, 2023
1 parent f1d2291 commit 115df3f
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 26 deletions.
7 changes: 7 additions & 0 deletions enzyme/Enzyme/Clang/EnzymePassLoader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#include "llvm/Transforms/Scalar/GVN.h"

#include "../Enzyme.h"
#include "../OptBlas.h"
#include "../PreserveNVVM.h"

using namespace llvm;
Expand All @@ -57,6 +58,10 @@ static void loadNVVMPass(const PassManagerBuilder &Builder,
legacy::PassManagerBase &PM) {
PM.add(createPreserveNVVMPass(/*Begin=*/true));
}
static void loadBLASPass(const PassManagerBuilder &Builder,
legacy::PassManagerBase &PM) {
PM.add(createOptimizeBlasPass(/*Begin=*/true));
}

// These constructors add our pass to a list of global extensions.
static RegisterStandardPasses
Expand All @@ -66,6 +71,8 @@ static RegisterStandardPasses
static RegisterStandardPasses
clangtoolLoader_OEarly(PassManagerBuilder::EP_EarlyAsPossible,
loadNVVMPass);
static RegisterStandardPasses
clangtoolLoader_Ox(PassManagerBuilder::EP_VectorizerStart, loadBLASPass);

static void loadLTOPass(const PassManagerBuilder &Builder,
legacy::PassManagerBase &PM) {
Expand Down
26 changes: 8 additions & 18 deletions enzyme/Enzyme/Enzyme.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3140,7 +3140,7 @@ void augmentPassBuilder(llvm::PassBuilder &PB) {
auto loadPass = [prePass](ModulePassManager &MPM)
#endif
{
MPM.addPass(OptimizeBlasNewPM(/*Begin*/ false));
// MPM.addPass(OptimizeBlasNewPM(/*Begin*/ false));
MPM.addPass(PreserveNVVMNewPM(/*Begin*/ true));

#if LLVM_VERSION_MAJOR >= 12
Expand All @@ -3163,8 +3163,6 @@ void augmentPassBuilder(llvm::PassBuilder &PB) {
#endif
MPM.addPass(createModuleToFunctionPassAdaptor(std::move(OptimizerPM)));
MPM.addPass(EnzymeNewPM(/*PostOpt=*/true));
// Manuel, new
MPM.addPass(OptimizeBlasNewPM(/*Begin*/ false));
MPM.addPass(PreserveNVVMNewPM(/*Begin*/ false));
#if LLVM_VERSION_MAJOR >= 16
OptimizerPM2.addPass(llvm::GVNPass());
Expand Down Expand Up @@ -3193,18 +3191,6 @@ void augmentPassBuilder(llvm::PassBuilder &PB) {
PB.registerPipelineStartEPCallback(loadPass);
#endif

#if LLVM_VERSION_MAJOR >= 12
auto optBLAS = [](ModulePassManager &MPM, OptimizationLevel)
#else
auto optBLAS = [](ModulePassManager &MPM)
#endif
{ MPM.addPass(OptimizeBlasNewPM(/*Begin*/ true)); };

// We should register at vectorizer start for consistency, however,
// that requires a functionpass, and we have a modulepass.
// PB.registerVectorizerStartEPCallback(loadPass);
PB.registerPipelineStartEPCallback(optBLAS);

#if LLVM_VERSION_MAJOR >= 12
auto loadNVVM = [](ModulePassManager &MPM, OptimizationLevel)
#else
Expand All @@ -3217,7 +3203,6 @@ void augmentPassBuilder(llvm::PassBuilder &PB) {
// PB.registerVectorizerStartEPCallback(loadPass);
PB.registerPipelineStartEPCallback(loadNVVM);
#if LLVM_VERSION_MAJOR >= 15
PB.registerFullLinkTimeOptimizationEarlyEPCallback(optBLAS);
PB.registerFullLinkTimeOptimizationEarlyEPCallback(loadNVVM);

auto preLTOPass = [](ModulePassManager &MPM, OptimizationLevel Level) {
Expand Down Expand Up @@ -3467,10 +3452,12 @@ llvmGetPassPluginInfo() {
#ifdef ENZYME_RUNPASS
augmentPassBuilder(PB);
#endif
llvm::errs() << "CCCCCC\n";
PB.registerPipelineParsingCallback(
[](llvm::StringRef Name, llvm::ModulePassManager &MPM,
llvm::ArrayRef<llvm::PassBuilder::PipelineElement>) {
if (Name == "blas-opt") {
llvm::errs() << "AAAA\n";
MPM.addPass(OptimizeBlasNewPM(/*Begin*/ true));
return true;
}
Expand All @@ -3486,7 +3473,9 @@ llvmGetPassPluginInfo() {
MPM.addPass(TypeAnalysisPrinterNewPM());
return true;
}
return false;
llvm::errs() << "BBBB\n";
return true;
// return false;
});
PB.registerPipelineParsingCallback(
[](llvm::StringRef Name, llvm::FunctionPassManager &FPM,
Expand All @@ -3495,7 +3484,8 @@ llvmGetPassPluginInfo() {
FPM.addPass(ActivityAnalysisPrinterNewPM());
return true;
}
return false;
return true;
// return false;
});
}};
}
4 changes: 2 additions & 2 deletions enzyme/Enzyme/OptBlas.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -131,8 +131,8 @@ bool optimizeFncsWithBlas(llvm::Module &M) {

OptimizeBlasNewPM::Result
OptimizeBlasNewPM::run(llvm::Module &M, llvm::ModuleAnalysisManager &MAM) {
llvm::errs() << "fooBar\n";
llvm::errs() << "newPM opt-blas\n";
bool changed = optimizeFncsWithBlas(M);
return changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
}
llvm::AnalysisKey OptimizeBlasNewPM::Key;
// llvm::AnalysisKey OptimizeBlasNewPM::Key;
7 changes: 3 additions & 4 deletions enzyme/Enzyme/OptBlas.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,12 @@ bool optimizeFncsWithBlas(llvm::Module &M);

llvm::ModulePass *createOptimizeBlasPass(bool Begin);

class OptimizeBlasNewPM final
: public llvm::AnalysisInfoMixin<OptimizeBlasNewPM> {
friend struct llvm::AnalysisInfoMixin<OptimizeBlasNewPM>;
class OptimizeBlasNewPM final : public llvm::PassInfoMixin<OptimizeBlasNewPM> {
friend struct llvm::PassInfoMixin<OptimizeBlasNewPM>;

private:
bool Begin;
static llvm::AnalysisKey Key;
// static llvm::AnalysisKey Key;

public:
using Result = llvm::PreservedAnalyses;
Expand Down
4 changes: 2 additions & 2 deletions enzyme/test/Integration/BlasOpt/first.c
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ void cblas_dgemm(const enum CBLAS_ORDER Order, const enum CBLAS_TRANSPOSE TransA
const int lda, const double *B, const int ldb,
const double beta, double *C, const int ldc);

void g(double *restrict x, double *restrict y, double *restrict v, double *restrict w, double *restrict C) {
void f(double *restrict x, double *restrict y, double *restrict v, double *restrict w, double *restrict C) {
double A[] = {0.00, 0.00, 0.00, 0.00,
0.00, 0.00, 0.00, 0.00,
0.00};
Expand All @@ -38,7 +38,7 @@ int main() {
double C[] = {0.00, 0.00, 0.00, 0.00,
0.00, 0.00, 0.00, 0.00,
0.00};
g(x,y,v,w,C);
f(x,y,v,w,C);
for (int i = 0; i < 9; i++)
printf("%f\n", C[i]);
}

0 comments on commit 115df3f

Please sign in to comment.