From 65686c2f410b91d9748567ba486776ae5628f338 Mon Sep 17 00:00:00 2001 From: Alexey Smirnov Date: Wed, 20 Nov 2024 17:37:02 +0000 Subject: [PATCH] Address review comments --- .../npuw/partitioning/online/snapshot.cpp | 72 +++++++++---------- .../npuw/partitioning/patterns/compute.cpp | 16 ++--- 2 files changed, 41 insertions(+), 47 deletions(-) diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.cpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.cpp index b88a6d4546c113..f1ef604033481d 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/online/snapshot.cpp @@ -465,34 +465,20 @@ void Snapshot::earlyRegroup() { break; } case PatternType::PATTERN: { - // FIXME: refactor as more patterns are supported - if (isolate.pattern == "RMSNorm") { - rewr.add_matcher(shared_from_this(), isolate.tag); - handle_patterns = true; - } else if (isolate.pattern == "RMSNorm2") { - rewr.add_matcher(shared_from_this(), isolate.tag); - handle_patterns = true; - } else if (isolate.pattern == "DQMatMulCWu4") { - rewr.add_matcher(shared_from_this(), isolate.tag); - handle_patterns = true; - } else if (isolate.pattern == "DQMatMulGQu4") { - rewr.add_matcher(shared_from_this(), isolate.tag); - handle_patterns = true; - } else if (isolate.pattern == "DQMatMulCWi4") { - rewr.add_matcher(shared_from_this(), isolate.tag); - handle_patterns = true; - } else if (isolate.pattern == "DQMatMulGQi4") { - rewr.add_matcher(shared_from_this(), isolate.tag); - handle_patterns = true; - } else if (isolate.pattern == "DQMatMulConv") { - rewr.add_matcher(shared_from_this(), isolate.tag); - handle_patterns = true; - } else if (isolate.pattern == "VocabMatMul") { - rewr.add_matcher(shared_from_this(), isolate.tag); - handle_patterns = true; - } else { - LOG_WARN("OPENVINO_NPUW_ISOLATE: unsupported pattern " << isolate.pattern << " is skipped!"); - } +#define HNDL(p) \ + if (isolate.pattern == #p) { \ + rewr.add_matcher(shared_from_this(), isolate.tag); \ + handle_patterns = true; \ + } + HNDL(RMSNorm); + HNDL(RMSNorm2); + HNDL(DQMatMulCWu4); + HNDL(DQMatMulGQu4); + HNDL(DQMatMulCWi4); + HNDL(DQMatMulGQi4); + HNDL(DQMatMulConv); + HNDL(VocabMatMul); +#undef HNDL } } } @@ -730,11 +716,16 @@ std::shared_ptr Snapshot::tryMergeTriangles(const std::vector Snapshot::tryMergeRepeating(const std::vector& sn auto param2 = opp::any_input(); auto convert2 = opp::optional({param2->output(0)}); auto multiply = opp::wrap_type({convert, convert2}); - auto tr_input = opp::any_input(); - auto transpose_in = opp::wrap_type({tr_input, opp::any_input()}); + auto transpose_in = opp::wrap_type({opp::any_input(), opp::any_input()}); auto conv = opp::wrap_type({transpose_in, multiply}); auto transpose_out = opp::wrap_type({conv, opp::any_input()}); @@ -248,18 +247,18 @@ DQMatMulConv::DQMatMulConv(const std::shared_ptr& sn auto callback = [=](ov::pass::pattern::Matcher& m) { auto& node_to_output = m.get_pattern_value_map(); - auto matched_node_param = node_to_output.at(param).get_node_shared_ptr(); - auto matched_node_param2 = node_to_output.at(param2).get_node_shared_ptr(); + auto matched_node_param = node_to_output.at(param); + auto matched_node_param2 = node_to_output.at(param2); auto matched_node_transpose_in = node_to_output.at(transpose_in).get_node_shared_ptr(); auto matched_node_transpose_out = node_to_output.at(transpose_out).get_node_shared_ptr(); auto matched_node_multiply = node_to_output.at(multiply).get_node_shared_ptr(); auto matched_node_conv = node_to_output.at(conv).get_node_shared_ptr(); - if ((matched_node_param->get_element_type() == ov::element::i4 || - matched_node_param->get_element_type() == ov::element::i8) && - (matched_node_param2->get_element_type() == ov::element::f32 || - matched_node_param2->get_element_type() == ov::element::f16)) { + if ((matched_node_param.get_element_type() == ov::element::i4 || + matched_node_param.get_element_type() == ov::element::i8) && + (matched_node_param2.get_element_type() == ov::element::f32 || + matched_node_param2.get_element_type() == ov::element::f16)) { // Partitioning ignores Param/Const -> Convert nodes node_to_gptr->at(matched_node_transpose_in)->isolate(isol_tag); node_to_gptr->at(matched_node_transpose_out)->isolate(isol_tag); @@ -348,7 +347,6 @@ RMSNorm::RMSNorm(const std::shared_ptr& snapshot, co // Note: Use [=] to make sure the above objects stay alive in the callback auto callback = [=](ov::pass::pattern::Matcher& m) { - std::cout << "RMSNorm MATCHED!" << std::endl; auto& node_to_output = m.get_pattern_value_map(); auto matched_hadd = node_to_output.at(hadd).get_node_shared_ptr();