From 12885194955ad4620251ad7aa16c58220967222d Mon Sep 17 00:00:00 2001 From: Manuel Drehwald Date: Sun, 16 Jul 2023 07:53:22 -0400 Subject: [PATCH] simplify and fix blasDiffUse (#1332) --- .../tools/enzyme-tblgen/blasDiffUseUpdater.h | 25 +++---------------- 1 file changed, 4 insertions(+), 21 deletions(-) diff --git a/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h b/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h index 419ae2f321b4..ac6388956081 100644 --- a/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h +++ b/enzyme/tools/enzyme-tblgen/blasDiffUseUpdater.h @@ -49,20 +49,10 @@ void emit_BLASDiffUse(TGPattern &pattern, llvm::raw_ostream &os) { } emit_scalar_caching(pattern, os); - // we currently cache all vecs before we cache all matrices - // once fixed we can merge this calls for (size_t i = 0; i < nameVec.size(); i++) { auto ty = typeMap.lookup(i); - if (ty != ArgType::vincData) + if (ty != ArgType::vincData && ty != ArgType::mldData) continue; - assert(typeMap.lookup(i + 1) == ArgType::vincInc); - emit_mat_vec_caching(pattern, i, os); - } - for (size_t i = 0; i < nameVec.size(); i++) { - auto ty = typeMap.lookup(i); - if (ty != ArgType::mldData) - continue; - assert(typeMap.lookup(i + 1) == ArgType::mldLD); emit_mat_vec_caching(pattern, i, os); } @@ -70,16 +60,9 @@ void emit_BLASDiffUse(TGPattern &pattern, llvm::raw_ostream &os) { auto users = argUsers.lookup(argPos); auto name = nameVec[argPos]; size_t i = (lv23 ? argPos - 1 : argPos); - os << " if (val == arg_" << name << " && !cache_" << name << ") {\n"; - for (auto a : users) { - auto name = nameVec[a]; - // The following shows that I probably should change the tblgen - // logic and the Blas.td declaration - if (a == i) // a == i? argpos ? - continue; - os << " if (active_" << name << ") return true;\n"; - } - os << " }\n"; + os << " if (val == arg_" << name << " && need_" << name << " && !cache_" + << name << ")\n" + << " return true;\n"; } os << " return false;\n";