From 329670f2266425a5bf80ce09dc9f960bce4d8460 Mon Sep 17 00:00:00 2001 From: Cengguang Zhang Date: Fri, 13 Dec 2024 00:41:25 +0800 Subject: [PATCH] [NPUW] Add optional convert when transpose weights for miniCPM. (#28012) ### Details: - Add optional convert after multiply for pattern match - Add optional logic when transpose weights for miniCPM. ### Tickets: - [EISW-131455](https://jira.devtools.intel.com/browse/EISW-131455) --- .../plugin/npuw/partitioning/patterns/opt.cpp | 24 +++++++++++++++---- 1 file changed, 20 insertions(+), 4 deletions(-) diff --git a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp index 41c978ec5ae542..5abe4b39fd44f2 100644 --- a/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp +++ b/src/plugins/intel_npu/src/plugin/npuw/partitioning/patterns/opt.cpp @@ -386,7 +386,8 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) { auto qcoeff = opp::wrap_type(); auto qcvtw = opp::wrap_type({qweight}); auto qmuls = opp::wrap_type({qcvtw, qcoeff}); - auto qreshp = opp::wrap_type({qmuls, opp::any_input()}); + auto qcvtm = opp::optional({qmuls->output(0)}); + auto qreshp = opp::wrap_type({qcvtm, opp::any_input()}); auto qcvtr = opp::optional({qreshp->output(0)}); auto qmmi = opp::any_input(); auto qmm = opp::wrap_type({qmmi, qcvtr}); @@ -398,6 +399,10 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) { auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr(); auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr(); auto matched_node_qmuls = node_to_output.at(qmuls).get_node_shared_ptr(); + std::shared_ptr matched_node_qcvtm = nullptr; + if (node_to_output.count(qcvtm)) { + matched_node_qcvtm = node_to_output.at(qcvtm).get_node_shared_ptr(); + } auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr(); auto matched_node_qreshp = node_to_output.at(qreshp).get_node_shared_ptr(); auto matched_out_mmi = node_to_output.at(qmmi); @@ -426,6 +431,9 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) { auto new_transpose_order_c = std::make_shared(ov::element::i32, ov::Shape{3}, new_transpose_order); auto new_transpose = std::make_shared(matched_node_qmuls, new_transpose_order_c); + if (matched_node_qcvtm) { + new_transpose = std::make_shared(matched_node_qcvtm, new_transpose_order_c); + } matched_node_qreshp->input(0).replace_source_output(new_transpose); matched_node_qreshp->validate_and_infer_types(); matched_matmul->validate_and_infer_types(); @@ -660,10 +668,11 @@ DQMatMulGQ2iP::DQMatMulGQ2iP(Context::Ref ctx) { auto qcoeff = opp::wrap_type(); auto qcvtw = opp::wrap_type({qweight}); auto qmuls = opp::wrap_type({qcvtw, qcoeff}); - auto qreshp = opp::wrap_type({qmuls, opp::any_input()}); - auto qcvtm = opp::optional({qreshp->output(0)}); + auto qcvtm = opp::optional({qmuls->output(0)}); + auto qreshp = opp::wrap_type({qcvtm, opp::any_input()}); + auto qcvtr = opp::optional({qreshp->output(0)}); auto qmmi = opp::any_input(); - auto qmm = opp::wrap_type({qmmi, qcvtm}); + auto qmm = opp::wrap_type({qmmi, qcvtr}); // Note: Use [=] to make sure the above objects stay alive in the callback auto callback = [=](ov::pass::pattern::Matcher& m) { @@ -672,6 +681,10 @@ DQMatMulGQ2iP::DQMatMulGQ2iP(Context::Ref ctx) { auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr(); auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr(); auto matched_node_qmuls = node_to_output.at(qmuls).get_node_shared_ptr(); + std::shared_ptr matched_node_qcvtm = nullptr; + if (node_to_output.count(qcvtm)) { + matched_node_qcvtm = node_to_output.at(qcvtm).get_node_shared_ptr(); + } auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr(); auto matched_node_qreshp = node_to_output.at(qreshp).get_node_shared_ptr(); auto matched_out_mmi = node_to_output.at(qmmi); @@ -703,6 +716,9 @@ DQMatMulGQ2iP::DQMatMulGQ2iP(Context::Ref ctx) { auto new_transpose_order_c = std::make_shared(ov::element::i32, ov::Shape{3}, new_transpose_order); auto new_transpose = std::make_shared(matched_node_qmuls, new_transpose_order_c); + if (matched_node_qcvtm) { + new_transpose = std::make_shared(matched_node_qcvtm, new_transpose_order_c); + } matched_node_qreshp->input(0).replace_source_output(new_transpose); matched_node_qreshp->validate_and_infer_types(); matched_matmul->validate_and_infer_types();