Skip to content

Commit

Permalink
NPUW: Adding a new dcoff pattern (openvinotoolkit#25938)
Browse files Browse the repository at this point in the history
### Details:
- Implemented a new pattern in continuation of the PR:
[PR:25827](openvinotoolkit#25827).

### Tickets:
 - *121052*

Co-authored-by: Dmitry Matveev <[email protected]>
  • Loading branch information
ujjayant-kadian and dmatveev authored Aug 7, 2024
1 parent cf6cb43 commit 6f98a27
Show file tree
Hide file tree
Showing 3 changed files with 91 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1624,6 +1624,9 @@ void Partitioner::decompressionCutOff(const std::string& func_name) {
// LLaMaGPTQ
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassReshape2>(dcoff_mode, dcoff_type, std::ref(params_to));

// Phi-3 4SymW16A/GPTQ
rewr.add_matcher<ov::npuw::patterns::SymmZP::DCOFFPassCWAI3>(dcoff_mode, dcoff_type, std::ref(params_to));

rewr.run_on_model(f._model);

ov::pass::Validate val;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,89 @@ DCOFFPassReshape2::DCOFFPassReshape2(DCOffMode dcoff_mode, ov::element::Type dco
register_matcher(std::make_shared<opp::Matcher>(reshpe, "TagDCOFFReshape2"), std::move(callback));
}

// Pattern: Phi-3 4SymW16A/GPTQ
//
//
// "tensor" "scale" > "tensor"
// Param:A Param:C > Param:A
// i4 f16|f32 > f16
// : : > :
// V : > V
// Convert : > Convert
// f16|f32 : > f32
// : : >
// V V >
// Multiply >
// f16|f32 >
// : >
// : >
// V >
// Convert

// Registers a matcher for the chain
//     Parameter(A) -> Convert -> Multiply(<converted A>, Parameter(C)) -> Convert
// rooted at the trailing Convert.
//
// On a match where A is i4 and C is f16 or f32:
//  - A's element type is overwritten with dcoff_type;
//  - in DCOffMode::CAST_SCALE mode (dcoff_type must then be f16), the pair C -> A is
//    stored in pref's `scales` map, the readers of the Multiply and of the inner
//    Convert are detached, and the trailing Convert is fed from A directly.
//
// The callback always returns false: the matched root itself is never replaced.
DCOFFPassCWAI3::DCOFFPassCWAI3(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref) {
    auto weight_pattern = opp::wrap_type<ov::op::v0::Parameter>();
    auto scale_pattern = opp::wrap_type<ov::op::v0::Parameter>();
    auto weight_cvt_pattern = opp::wrap_type<ov::op::v0::Convert>({weight_pattern});
    auto multiply_pattern = opp::wrap_type<ov::op::v1::Multiply>({weight_cvt_pattern, scale_pattern});
    auto root_cvt_pattern = opp::wrap_type<ov::op::v0::Convert>({multiply_pattern});

    auto callback = [=](ov::pass::pattern::Matcher& m) {
        auto& matches = m.get_pattern_value_map();
        auto weight_node = matches.at(weight_pattern).get_node_shared_ptr();
        auto scale_node = matches.at(scale_pattern).get_node_shared_ptr();

        NPUW_ASSERT(ov::op::util::is_parameter(weight_node));
        NPUW_ASSERT(ov::op::util::is_parameter(scale_node));

        auto weight = std::static_pointer_cast<ov::op::v0::Parameter>(weight_node);
        auto scale = std::static_pointer_cast<ov::op::v0::Parameter>(scale_node);

        // Only an i4 weight combined with an f16/f32 scale is handled here.
        const bool weight_is_i4 = weight->get_element_type() == ov::element::i4;
        const auto& scale_type = scale->get_element_type();
        const bool scale_is_float = scale_type == ov::element::f16 || scale_type == ov::element::f32;
        if (!weight_is_i4 || !scale_is_float) {
            return false;  // nothing touched, root node hasn't changed
        }

        LOG_DEBUG("Matched: " << weight << ", set element type to " << dcoff_type);
        weight->set_element_type(dcoff_type);

        if (dcoff_mode == DCOffMode::CAST_SCALE) {
            NPUW_ASSERT(dcoff_type == ov::element::f16);

            LOG_DEBUG("Matched: " << scale << " - parameter to remove...");
            LOG_BLOCK();

            // Additional rewiring for this mode:
            //  - the Multiply and the inner Convert are cut out of the graph;
            //  - the scale parameter is remembered so that it can be removed;
            //  - the trailing Convert ends up reading the weight parameter directly.

            // Remember which weight parameter this scale parameter belongs to
            pref.get().scales[scale] = weight;

            auto multiply_node = matches.at(multiply_pattern).get_node_shared_ptr();
            auto weight_cvt_node = matches.at(weight_cvt_pattern).get_node_shared_ptr();

            // Detaches every consumer from every output port of the given node
            auto detach_readers = [](std::shared_ptr<ov::Node> producer) {
                for (auto&& out_port : producer->outputs()) {
                    for (auto&& reader : out_port.get_target_inputs()) {
                        out_port.remove_target_input(reader);
                    }
                }
            };
            LOG_DEBUG("Dropping the connections...");
            detach_readers(multiply_node);
            detach_readers(weight_cvt_node);

            LOG_DEBUG("Reconnecting the Root...");
            auto root_cvt_node = matches.at(root_cvt_pattern).get_node_shared_ptr();
            root_cvt_node->input(0).replace_source_output(weight);
        }
        LOG_DEBUG("Done");
        return false;  // root node hasn't changed
    };

    register_matcher(std::make_shared<opp::Matcher>(root_cvt_pattern, "TagDCOFFPassCWAI3"), std::move(callback));
}

//------------------------------------------------------------------------------
// Pattern: 4SymW16A for CWAI
//
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,11 @@ class DCOFFPassReshape2 : public ov::pass::MatcherPass {
DCOFFPassReshape2(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref);
};

// Matcher pass for the "Phi-3 4SymW16A/GPTQ" decompression pattern:
//     Parameter (i4 weight) -> Convert -> Multiply (by a f16/f32 scale Parameter) -> Convert
//
// Constructor arguments:
//  - dcoff_mode: when equal to DCOffMode::CAST_SCALE, the scale parameter is recorded
//                and the Multiply + inner Convert are bypassed in the graph.
//  - dcoff_type: element type assigned to the matched weight parameter.
//  - pref:       reference wrapper to the shared parameter bookkeeping; its `scales`
//                map receives the scale-parameter -> weight-parameter mapping.
class DCOFFPassCWAI3 : public ov::pass::MatcherPass {
public:
DCOFFPassCWAI3(DCOffMode dcoff_mode, ov::element::Type dcoff_type, DCOFFParamRef pref);
};

class CWAI1 : public ov::pass::MatcherPass {
public:
using CPtr = std::shared_ptr<ov::op::v0::Constant>;
Expand Down

0 comments on commit 6f98a27

Please sign in to comment.