Skip to content

Commit

Permalink
[NPUW] New compute patterns (#27618)
Browse files Browse the repository at this point in the history
  • Loading branch information
smirnov-alexey authored Nov 25, 2024
1 parent 0d918dc commit caef0ab
Show file tree
Hide file tree
Showing 4 changed files with 133 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@ namespace {
static const std::map<std::string, std::string> ISOL_PRESETS = {{"COMPUTE",
"P:DQMatMulGQu4/compute,P:DQMatMulCWu4/compute,"
"P:DQMatMulGQi4/compute,P:DQMatMulCWi4/compute,"
"P:DQMatMulConv/compute,"
"P:VocabMatMul/compute,"
"P:RMSNorm/compute"}};
"P:RMSNorm/compute,P:RMSNorm2/compute"}};
}

// For missing declaration warning
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -465,28 +465,20 @@ void Snapshot::earlyRegroup() {
break;
}
case PatternType::PATTERN: {
// FIXME: refactor as more patterns are supported
if (isolate.pattern == "RMSNorm") {
rewr.add_matcher<ov::npuw::patterns::compute::RMSNorm>(shared_from_this(), isolate.tag);
handle_patterns = true;
} else if (isolate.pattern == "DQMatMulCWu4") {
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulCWu4>(shared_from_this(), isolate.tag);
handle_patterns = true;
} else if (isolate.pattern == "DQMatMulGQu4") {
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulGQu4>(shared_from_this(), isolate.tag);
handle_patterns = true;
} else if (isolate.pattern == "DQMatMulCWi4") {
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulCWi4>(shared_from_this(), isolate.tag);
handle_patterns = true;
} else if (isolate.pattern == "DQMatMulGQi4") {
rewr.add_matcher<ov::npuw::patterns::compute::DQMatMulGQi4>(shared_from_this(), isolate.tag);
handle_patterns = true;
} else if (isolate.pattern == "VocabMatMul") {
rewr.add_matcher<ov::npuw::patterns::compute::VocabMatMul>(shared_from_this(), isolate.tag);
handle_patterns = true;
} else {
LOG_WARN("OPENVINO_NPUW_ISOLATE: unsupported pattern " << isolate.pattern << " is skipped!");
}
#define HNDL(p) \
if (isolate.pattern == #p) { \
rewr.add_matcher<ov::npuw::patterns::compute::p>(shared_from_this(), isolate.tag); \
handle_patterns = true; \
}
HNDL(RMSNorm);
HNDL(RMSNorm2);
HNDL(DQMatMulCWu4);
HNDL(DQMatMulGQu4);
HNDL(DQMatMulCWi4);
HNDL(DQMatMulGQi4);
HNDL(DQMatMulConv);
HNDL(VocabMatMul);
#undef HNDL
}
}
}
Expand Down Expand Up @@ -723,6 +715,20 @@ std::shared_ptr<Repeated> Snapshot::tryMergeTriangles(const std::vector<Group::G
return {};
}

if (prods.size() < m_ctx.keep_blocks) {
// In some cases (specifically mixed precision), after the MergeUniques() pass we could be left with,
// e.g., 10 repeated blocks with tag AAA and 3 repeated blocks with tag BBB.
// The TryMergeTriangles() pass checks that producer and consumer have different tags to be merged further.
// Let's say in our example the 10 AAA blocks are finalized and cannot be merged further due to the above check.
// However, we would proceed to merge the 3 BBB blocks with 3 AAA blocks since the tags are different.
// This would create a new tag CCC for the merged blocks, and the merge would continue until those 3 blocks
// consume a large number of legit AAA blocks.
// Later, in the CleanUpUniques() pass, those repeated blocks would be stripped of their repeated tag due to
// the same check as in this "if". This check was introduced to prevent such cases, where we would end up
// with a small number of huge blocks.
return {};
}

// In this special case we only assume
// our vector of N repeating consumer groups
// 1. has the same size
Expand Down Expand Up @@ -939,6 +945,20 @@ std::shared_ptr<Repeated> Snapshot::tryMergeRepeating(const std::vector<Group::G
}
}

if (prods.size() < m_ctx.keep_blocks) {
// In some cases (specifically mixed precision), after the MergeUniques() pass we could be left with,
// e.g., 10 repeated blocks with tag AAA and 3 repeated blocks with tag BBB.
// The TryMergeRepeating() pass checks that producer and consumer have different tags to be merged further.
// Let's say in our example the 10 AAA blocks are finalized and cannot be merged further due to the above check.
// However, we would proceed to merge the 3 BBB blocks with 3 AAA blocks since the tags are different.
// This would create a new tag CCC for the merged blocks, and the merge would continue until those 3 blocks
// consume a large number of legit AAA blocks.
// Later, in the CleanUpUniques() pass, those repeated blocks would be stripped of their repeated tag due to
// the same check as in this "if". This check was introduced to prevent such cases, where we would end up
// with a small number of huge blocks.
return {};
}

std::shared_ptr<Repeated> new_rep = std::make_shared<Repeated>();

for (size_t i = 0; i < conss.size(); ++i) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -226,6 +226,51 @@ DQMatMulCWi4::DQMatMulCWi4(const std::shared_ptr<ov::npuw::online::Snapshot>& sn
register_matcher(std::make_shared<opp::Matcher>(qmm, "TagDQMatMulCWi4"), std::move(callback));
}

// Pattern (as assembled by the matchers below):
//
//     (any, any) -> Transpose ---------------------------->
//     any -> Convert ------------> Multiply --------------> Convolution -> Transpose(., any) ->
//     any -> (optional Convert) ->
//
// On a match the callback additionally requires that the value feeding the mandatory Convert
// is i4 or i8 and that the value feeding the optional Convert is f32 or f16. Only then the
// groups of both Transposes, the Multiply and the Convolution are isolated under isol_tag.

DQMatMulConv::DQMatMulConv(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) {
    // Multiply branch: Convert(<integer source>) * [Convert](<floating-point source>)
    auto int_src = opp::any_input();
    auto int_src_cvt = opp::wrap_type<ov::op::v0::Convert>({int_src->output(0)});
    auto fp_src = opp::any_input();
    auto fp_src_cvt = opp::optional<ov::op::v0::Convert>({fp_src->output(0)});
    auto mul = opp::wrap_type<ov::op::v1::Multiply>({int_src_cvt, fp_src_cvt});

    // Transpose -> Convolution -> Transpose chain; the Multiply is the second Convolution input
    auto tr_before = opp::wrap_type<ov::op::v1::Transpose>({opp::any_input(), opp::any_input()});
    auto convolution = opp::wrap_type<ov::op::v1::Convolution>({tr_before, mul});
    auto tr_after = opp::wrap_type<ov::op::v1::Transpose>({convolution, opp::any_input()});

    auto groups = snapshot->getNodeToGroupMap();

    // Note: the by-value capture keeps the pattern nodes, the group map and the tag alive
    // for as long as the callback itself lives
    auto callback = [=](ov::pass::pattern::Matcher& m) {
        auto& matched = m.get_pattern_value_map();

        auto int_value = matched.at(int_src);
        auto fp_value = matched.at(fp_src);

        auto tr_before_node = matched.at(tr_before).get_node_shared_ptr();
        auto tr_after_node = matched.at(tr_after).get_node_shared_ptr();
        auto mul_node = matched.at(mul).get_node_shared_ptr();
        auto convolution_node = matched.at(convolution).get_node_shared_ptr();

        // Bail out early unless the integer source is i4/i8 ...
        const auto& int_type = int_value.get_element_type();
        const bool int_ok = int_type == ov::element::i4 || int_type == ov::element::i8;
        if (!int_ok) {
            return false;  // root hasn't changed
        }

        // ... and the floating-point source is f32/f16
        const auto& fp_type = fp_value.get_element_type();
        const bool fp_ok = fp_type == ov::element::f32 || fp_type == ov::element::f16;
        if (!fp_ok) {
            return false;  // root hasn't changed
        }

        // Partitioning ignores Param/Const -> Convert nodes, so only the remaining four are tagged
        const auto tag_group = [&](const auto& node) {
            groups->at(node)->isolate(isol_tag);
        };
        tag_group(tr_before_node);
        tag_group(tr_after_node);
        tag_group(mul_node);
        tag_group(convolution_node);

        return false;  // root hasn't changed
    };
    register_matcher(std::make_shared<opp::Matcher>(tr_after, "TagDQMatMulConv"), std::move(callback));
}

// This is a case for Raw (f16/f32) MatMul connected directly to the Result.
//
// The following combinations are covered:
Expand Down Expand Up @@ -327,6 +372,40 @@ RMSNorm::RMSNorm(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, co
register_matcher(std::make_shared<opp::Matcher>(multiply2, "TagRMSNorm"), std::move(callback));
}

// Pattern (as assembled by the matchers below):
//
//     (any, any) -> Add -+-> Power(., any) -> ReduceSum(., any) -> Sqrt -+
//                        |                                               v
//                        +------------------------------------------> Divide -> Multiply(any, .) ->
//
// Every node of a match (Add, Power, ReduceSum, Sqrt, Divide, Multiply) has its group
// isolated under isol_tag; no extra type checks are performed.
RMSNorm2::RMSNorm2(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag) {
    auto add_p = opp::wrap_type<ov::op::v1::Add>({opp::any_input(), opp::any_input()});
    auto pow_p = opp::wrap_type<ov::op::v1::Power>({add_p, opp::any_input()});
    auto sum_p = opp::wrap_type<ov::op::v1::ReduceSum>({pow_p, opp::any_input()});
    auto sqrt_p = opp::wrap_type<ov::op::v0::Sqrt>({sum_p});
    auto div_p = opp::wrap_type<ov::op::v1::Divide>({add_p, sqrt_p});
    auto mul_p = opp::wrap_type<ov::op::v1::Multiply>({opp::any_input(), div_p});

    auto groups = snapshot->getNodeToGroupMap();

    // Note: the by-value capture keeps the pattern nodes, the group map and the tag alive
    // for as long as the callback itself lives
    auto callback = [=](ov::pass::pattern::Matcher& m) {
        auto& matched = m.get_pattern_value_map();

        // Resolve every pattern node to the graph node it matched
        auto add_node = matched.at(add_p).get_node_shared_ptr();
        auto pow_node = matched.at(pow_p).get_node_shared_ptr();
        auto sum_node = matched.at(sum_p).get_node_shared_ptr();
        auto sqrt_node = matched.at(sqrt_p).get_node_shared_ptr();
        auto div_node = matched.at(div_p).get_node_shared_ptr();
        auto mul_node = matched.at(mul_p).get_node_shared_ptr();

        // Tag the owning group of each matched node, in pattern order
        const auto tag_group = [&](const auto& node) {
            groups->at(node)->isolate(isol_tag);
        };
        tag_group(add_node);
        tag_group(pow_node);
        tag_group(sum_node);
        tag_group(sqrt_node);
        tag_group(div_node);
        tag_group(mul_node);

        return false;  // root hasn't changed
    };
    register_matcher(std::make_shared<opp::Matcher>(mul_p, "TagRMSNorm2"), std::move(callback));
}

} // namespace compute
} // namespace patterns
} // namespace npuw
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,11 @@ class DQMatMulCWi4 : public ov::pass::MatcherPass {
DQMatMulCWi4(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
};

// Matcher pass for the "DQMatMulConv" compute pattern: a Transpose -> Convolution -> Transpose
// chain whose second Convolution input is a Multiply of a converted i4/i8 value and an
// (optionally converted) f32/f16 value. Groups of the matched nodes are isolated under a tag.
class DQMatMulConv : public ov::pass::MatcherPass {
public:
// snapshot - source of the node-to-group map used to find the group of each matched node
// isol_tag - tag passed to isolate() on the groups of the matched nodes
DQMatMulConv(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
};

class VocabMatMul : public ov::pass::MatcherPass {
public:
VocabMatMul(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
Expand All @@ -51,6 +56,11 @@ class RMSNorm : public ov::pass::MatcherPass {
RMSNorm(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
};

// Matcher pass for the "RMSNorm2" compute pattern: Add -> Power -> ReduceSum -> Sqrt, then
// Divide(Add, Sqrt) feeding a Multiply. Groups of all matched nodes are isolated under a tag.
class RMSNorm2 : public ov::pass::MatcherPass {
public:
// snapshot - source of the node-to-group map used to find the group of each matched node
// isol_tag - tag passed to isolate() on the groups of the matched nodes
RMSNorm2(const std::shared_ptr<ov::npuw::online::Snapshot>& snapshot, const std::string& isol_tag);
};

} // namespace compute
} // namespace patterns
} // namespace npuw
Expand Down

0 comments on commit caef0ab

Please sign in to comment.