Skip to content

Commit

Permalink
Address review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
smirnov-alexey committed Nov 20, 2024
1 parent 295a5b0 commit 65686c2
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 47 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -465,34 +465,20 @@ void Snapshot::earlyRegroup() {
break;
}
case PatternType::PATTERN: {
// FIXME: refactor as more patterns are supported
if (isolate.pattern == "RMSNorm") {
rewr.add_matcher<ov::npuw::patterns::compute::RMSNorm>(shared_from_this(), isolate.tag);
handle_patterns = true;
} else if (isolate.pattern == "RMSNorm2") {
rewr.add_matcher<ov::npuw::patterns::compute::RMSNorm2>(shared_from_this(), isolate.tag);
handle_patterns = true;
} else if (isolate.pattern == "DQMatMulCWu4") {
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulCWu4>(shared_from_this(), isolate.tag);
handle_patterns = true;
} else if (isolate.pattern == "DQMatMulGQu4") {
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulGQu4>(shared_from_this(), isolate.tag);
handle_patterns = true;
} else if (isolate.pattern == "DQMatMulCWi4") {
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulCWi4>(shared_from_this(), isolate.tag);
handle_patterns = true;
} else if (isolate.pattern == "DQMatMulGQi4") {
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulGQi4>(shared_from_this(), isolate.tag);
handle_patterns = true;
} else if (isolate.pattern == "DQMatMulConv") {
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulConv>(shared_from_this(), isolate.tag);
handle_patterns = true;
} else if (isolate.pattern == "VocabMatMul") {
rewr.add_matcher<ov::npuw::patterns::compute::VocabMatMul>(shared_from_this(), isolate.tag);
handle_patterns = true;
} else {
LOG_WARN("OPENVINO_NPUW_ISOLATE: unsupported pattern " << isolate.pattern << " is skipped!");
}
#define HNDL(p) \
if (isolate.pattern == #p) { \
rewr.add_matcher<ov::npuw::patterns::compute::p>(shared_from_this(), isolate.tag); \
handle_patterns = true; \
}
HNDL(RMSNorm);
HNDL(RMSNorm2);
HNDL(DQMatMulCWu4);
HNDL(DQMatMulGQu4);
HNDL(DQMatMulCWi4);
HNDL(DQMatMulGQi4);
HNDL(DQMatMulConv);
HNDL(VocabMatMul);
#undef HNDL
}
}
}
Expand Down Expand Up @@ -730,11 +716,16 @@ std::shared_ptr<Repeated> Snapshot::tryMergeTriangles(const std::vector<Group::G
}

if (prods.size() < m_ctx.keep_blocks) {
// In some cases (specifically mixed precision) we could have merged
// a small number of huge groups which consumed other legit repeated blocks
// due to having a different repeated tag due to unique weights precision combination
// from the rest of the model.
// This check was added to prevent that and shouldn't affect other models.
// In some cases (specifically mixed precision) during MergeUniques() pass we could be left with
// E.g. 10 repeated blocks with tag AAA and 2 repeated blocks with tag BBB
// TryMergeTriangles() pass checks that producer and consumer have a different tag to be merged further.
// Let's say in our example 10 AAA blocks are finalized and cannot be merged further due to above check.
// However we will proceed to merge 3 BBB blocks with 3 AAA blocks since the tags are different.
// This will create a new tag CCC for the merged blocks and the merge will continue until those 3 blocks
// consume a large amount of legit AAA blocks.
// Later in CleanUpUniques() pass those repeated blocks will be stripped off repeated tag due to the same check
// in this "if". To prevent such cases where we would end up with small number of huge blocks this check was
// introduced.
return {};
}

Expand Down Expand Up @@ -955,11 +946,16 @@ std::shared_ptr<Repeated> Snapshot::tryMergeRepeating(const std::vector<Group::G
}

if (prods.size() < m_ctx.keep_blocks) {
// In some cases (specifically mixed precision) we could have merged
// a small number of huge groups which consumed other legit repeated blocks
// due to having a different repeated tag due to unique weights precision combination
// from the rest of the model.
// This check was added to prevent that and shouldn't affect other models.
// In some cases (specifically mixed precision) during MergeUniques() pass we could be left with
// E.g. 10 repeated blocks with tag AAA and 2 repeated blocks with tag BBB
// TryMergeRepeating() pass checks that producer and consumer have a different tag to be merged further.
// Let's say in our example 10 AAA blocks are finalized and cannot be merged further due to above check.
// However we will proceed to merge 3 BBB blocks with 3 AAA blocks since the tags are different.
// This will create a new tag CCC for the merged blocks and the merge will continue until those 3 blocks
// consume a large amount of legit AAA blocks.
// Later in CleanUpUniques() pass those repeated blocks will be stripped off repeated tag due to the same check
// in this "if". To prevent such cases where we would end up with small number of huge blocks this check was
// introduced.
return {};
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,7 @@ DQMatMulConv::DQMatMulConv(const std::shared_ptr<ov::npuw::online::Snapshot>& sn
auto param2 = opp::any_input();
auto convert2 = opp::optional<ov::op::v0::Convert>({param2->output(0)});
auto multiply = opp::wrap_type<ov::op::v1::Multiply>({convert, convert2});
auto tr_input = opp::any_input();
auto transpose_in = opp::wrap_type<ov::op::v1::Transpose>({tr_input, opp::any_input()});
auto transpose_in = opp::wrap_type<ov::op::v1::Transpose>({opp::any_input(), opp::any_input()});
auto conv = opp::wrap_type<ov::op::v1::Convolution>({transpose_in, multiply});
auto transpose_out = opp::wrap_type<ov::op::v1::Transpose>({conv, opp::any_input()});

Expand All @@ -248,18 +247,18 @@ DQMatMulConv::DQMatMulConv(const std::shared_ptr<ov::npuw::online::Snapshot>& sn
auto callback = [=](ov::pass::pattern::Matcher& m) {
auto& node_to_output = m.get_pattern_value_map();

auto matched_node_param = node_to_output.at(param).get_node_shared_ptr();
auto matched_node_param2 = node_to_output.at(param2).get_node_shared_ptr();
auto matched_node_param = node_to_output.at(param);
auto matched_node_param2 = node_to_output.at(param2);

auto matched_node_transpose_in = node_to_output.at(transpose_in).get_node_shared_ptr();
auto matched_node_transpose_out = node_to_output.at(transpose_out).get_node_shared_ptr();
auto matched_node_multiply = node_to_output.at(multiply).get_node_shared_ptr();
auto matched_node_conv = node_to_output.at(conv).get_node_shared_ptr();

if ((matched_node_param->get_element_type() == ov::element::i4 ||
matched_node_param->get_element_type() == ov::element::i8) &&
(matched_node_param2->get_element_type() == ov::element::f32 ||
matched_node_param2->get_element_type() == ov::element::f16)) {
if ((matched_node_param.get_element_type() == ov::element::i4 ||
matched_node_param.get_element_type() == ov::element::i8) &&
(matched_node_param2.get_element_type() == ov::element::f32 ||
matched_node_param2.get_element_type() == ov::element::f16)) {
// Partitioning ignores Param/Const -> Convert nodes
node_to_gptr->at(matched_node_transpose_in)->isolate(isol_tag);
node_to_gptr->at(matched_node_transpose_out)->isolate(isol_tag);
Expand Down Expand Up @@ -348,7 +347,6 @@ RMSNorm::RMSNorm(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, co

// Note: Use [=] to make sure the above objects stay alive in the callback
auto callback = [=](ov::pass::pattern::Matcher& m) {
std::cout << "RMSNorm MATCHED!" << std::endl;
auto& node_to_output = m.get_pattern_value_map();

auto matched_hadd = node_to_output.at(hadd).get_node_shared_ptr();
Expand Down

0 comments on commit 65686c2

Please sign in to comment.