Skip to content

Commit

Permalink
[NPUW] Add optional convert when transpose weights for miniCPM. (#28012)
Browse files Browse the repository at this point in the history
### Details:
 - Add optional convert after multiply for pattern match
 - Add optional logic when transpose weights for miniCPM.

### Tickets:
 - [EISW-131455](https://jira.devtools.intel.com/browse/EISW-131455)
  • Loading branch information
lalalapotter authored Dec 12, 2024
1 parent 8d6491b commit 329670f
Showing 1 changed file with 20 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -386,7 +386,8 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) {
auto qcoeff = opp::wrap_type<ov::op::v0::Parameter>();
auto qcvtw = opp::wrap_type<ov::op::v0::Convert>({qweight});
auto qmuls = opp::wrap_type<ov::op::v1::Multiply>({qcvtw, qcoeff});
auto qreshp = opp::wrap_type<ov::op::v1::Reshape>({qmuls, opp::any_input()});
auto qcvtm = opp::optional<ov::op::v0::Convert>({qmuls->output(0)});
auto qreshp = opp::wrap_type<ov::op::v1::Reshape>({qcvtm, opp::any_input()});
auto qcvtr = opp::optional<ov::op::v0::Convert>({qreshp->output(0)});
auto qmmi = opp::any_input();
auto qmm = opp::wrap_type<ov::op::v0::MatMul>({qmmi, qcvtr});
Expand All @@ -398,6 +399,10 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) {
auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
auto matched_node_qmuls = node_to_output.at(qmuls).get_node_shared_ptr();
std::shared_ptr<Node> matched_node_qcvtm = nullptr;
if (node_to_output.count(qcvtm)) {
matched_node_qcvtm = node_to_output.at(qcvtm).get_node_shared_ptr();
}
auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr();
auto matched_node_qreshp = node_to_output.at(qreshp).get_node_shared_ptr();
auto matched_out_mmi = node_to_output.at(qmmi);
Expand Down Expand Up @@ -426,6 +431,9 @@ DQMatMulGQ2i::DQMatMulGQ2i(Context::Ref ctx) {
auto new_transpose_order_c =
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, new_transpose_order);
auto new_transpose = std::make_shared<ov::op::v1::Transpose>(matched_node_qmuls, new_transpose_order_c);
if (matched_node_qcvtm) {
new_transpose = std::make_shared<ov::op::v1::Transpose>(matched_node_qcvtm, new_transpose_order_c);
}
matched_node_qreshp->input(0).replace_source_output(new_transpose);
matched_node_qreshp->validate_and_infer_types();
matched_matmul->validate_and_infer_types();
Expand Down Expand Up @@ -660,10 +668,11 @@ DQMatMulGQ2iP::DQMatMulGQ2iP(Context::Ref ctx) {
auto qcoeff = opp::wrap_type<ov::op::v0::Parameter>();
auto qcvtw = opp::wrap_type<ov::op::v0::Convert>({qweight});
auto qmuls = opp::wrap_type<ov::op::v1::Multiply>({qcvtw, qcoeff});
auto qreshp = opp::wrap_type<ov::op::v1::Reshape>({qmuls, opp::any_input()});
auto qcvtm = opp::optional<ov::op::v0::Convert>({qreshp->output(0)});
auto qcvtm = opp::optional<ov::op::v0::Convert>({qmuls->output(0)});
auto qreshp = opp::wrap_type<ov::op::v1::Reshape>({qcvtm, opp::any_input()});
auto qcvtr = opp::optional<ov::op::v0::Convert>({qreshp->output(0)});
auto qmmi = opp::any_input();
auto qmm = opp::wrap_type<ov::op::v0::MatMul>({qmmi, qcvtm});
auto qmm = opp::wrap_type<ov::op::v0::MatMul>({qmmi, qcvtr});

// Note: Use [=] to make sure the above objects stay alive in the callback
auto callback = [=](ov::pass::pattern::Matcher& m) {
Expand All @@ -672,6 +681,10 @@ DQMatMulGQ2iP::DQMatMulGQ2iP(Context::Ref ctx) {
auto matched_node_qweight = node_to_output.at(qweight).get_node_shared_ptr();
auto matched_node_qcoeff = node_to_output.at(qcoeff).get_node_shared_ptr();
auto matched_node_qmuls = node_to_output.at(qmuls).get_node_shared_ptr();
std::shared_ptr<Node> matched_node_qcvtm = nullptr;
if (node_to_output.count(qcvtm)) {
matched_node_qcvtm = node_to_output.at(qcvtm).get_node_shared_ptr();
}
auto matched_node_matmul = node_to_output.at(qmm).get_node_shared_ptr();
auto matched_node_qreshp = node_to_output.at(qreshp).get_node_shared_ptr();
auto matched_out_mmi = node_to_output.at(qmmi);
Expand Down Expand Up @@ -703,6 +716,9 @@ DQMatMulGQ2iP::DQMatMulGQ2iP(Context::Ref ctx) {
auto new_transpose_order_c =
std::make_shared<ov::op::v0::Constant>(ov::element::i32, ov::Shape{3}, new_transpose_order);
auto new_transpose = std::make_shared<ov::op::v1::Transpose>(matched_node_qmuls, new_transpose_order_c);
if (matched_node_qcvtm) {
new_transpose = std::make_shared<ov::op::v1::Transpose>(matched_node_qcvtm, new_transpose_order_c);
}
matched_node_qreshp->input(0).replace_source_output(new_transpose);
matched_node_qreshp->validate_and_infer_types();
matched_matmul->validate_and_infer_types();
Expand Down

0 comments on commit 329670f

Please sign in to comment.