Skip to content

Commit

Permalink
simplified isSuitableChildForFC logic
Browse files Browse the repository at this point in the history
  • Loading branch information
alvoron committed Nov 25, 2024
1 parent b5f3487 commit ceff99e
Showing 1 changed file with 4 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -263,64 +263,20 @@ auto is_skipped_op(const std::shared_ptr<ov::Node>& op) -> bool {
ov::is_type<ov::op::v0::Result>(op);
}

bool isSuitableChildForFusingMatMul(const std::shared_ptr<const Node> &node, const bool canMatMulBeExecutedInI8,
NodeFusingType &updatedChainType, int& fusingAxis) {
// Firstly check for Bias and DQScales fusion
const bool is_bias = ov::is_type<ov::opset1::Add>(node);
const bool is_dq_scales = ov::is_type<ov::opset1::Multiply>(node) && canMatMulBeExecutedInI8;
if (is_bias || is_dq_scales) {
for (const auto &in : node->inputs()) {
const auto& parent_out = in.get_source_output();
const auto& parent = parent_out.get_node_shared_ptr();
const auto& parent_pshape = parent_out.get_partial_shape();
if (ov::is_type<ov::op::v0::MatMul>(parent) && parent_pshape.rank().is_static()) {
if (parent->get_output_target_inputs(0).size() > 1)
break;
const auto bias_port = 1 - in.get_index();
const auto bias_out = node->input_value(bias_port);
if ((bias_out.get_target_inputs().size() > 1) || !ov::op::util::is_on_constant_path(bias_out))
break;
const auto& bias_pshape = bias_out.get_partial_shape();
if (bias_pshape.is_dynamic())
break;
auto getNormalizedPShape = [](const ov::PartialShape &dims, size_t ndims) -> ov::PartialShape {
if (dims.size() >= ndims)
return dims;
ov::PartialShape pshape(std::vector<size_t>(ndims, 1));
std::copy(dims.rbegin(), dims.rend(), pshape.rbegin());
return pshape;
};
const auto bias_pshape_norm = getNormalizedPShape(bias_pshape, parent_pshape.size());
if (fusingAxis >= static_cast<int>(bias_pshape_norm.size()) || fusingAxis >= static_cast<int>(parent_pshape.size()) ||
bias_pshape_norm.size() != parent_pshape.size() || bias_pshape_norm.size() < 2)
break;
if (((bias_pshape_norm[fusingAxis] == parent_pshape[fusingAxis]) || (is_dq_scales && bias_pshape_norm[fusingAxis] == 1)) &&
(bias_pshape_norm[fusingAxis] == static_cast<int64_t>(shape_size(bias_pshape_norm.get_shape()))))
return true;
}
}
}

// FuseMatMulAndSimpleOperation or FuseFullyConnectedAndSimpleOperation
bool isSuitableChildForFC(const std::shared_ptr<const Node> &node, const bool canMatMulBeExecutedInI8, int& fusingAxis) {
// Invoke SupportsFusingWithConvolution_Simple directly instead of isSuitableChildForFusingSimple to
// eliminate getNumNonConstInputs() check
if (SupportsFusingWithConvolution_Simple(node, fusingAxis)) {
size_t num_non_const_inputs = 0;
size_t num_mm_inputs = 0;
for (const auto &parent_out : node->input_values()) {
// Count inputs coming from the MatMul branch separately, to avoid an endless `is_on_constant_path` check on that branch
if (one_of(GetNodeFusingType(parent_out.get_node_shared_ptr()), NodeFusingType::FusedWithFC, NodeFusingType::FusedWithFCI8))
num_mm_inputs++;
else if (!ov::op::util::is_on_constant_path(parent_out))
num_non_const_inputs++;
}
if (num_non_const_inputs + num_mm_inputs != 1)
return false;

updatedChainType = NodeFusingType::FusedWithMisc;
return true;
return (num_non_const_inputs + num_mm_inputs == 1);
}

return false;
}

Expand Down Expand Up @@ -378,10 +334,8 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr<ov::Model> &m) {
}
} else if (one_of(fusingChainType, NodeFusingType::FusedWithFC, NodeFusingType::FusedWithFCI8)) {
const bool isExecutedInINT8 = fusingChainType == NodeFusingType::FusedWithFCI8;
// Handle fusings for both MatMul and FullyConnected
NodeFusingType updatedChainType = fusingChainType;
if (isSuitableChildForFusingMatMul(node, isExecutedInINT8, updatedChainType, channelAxis))
PropagateIfHasOnlyChild(node, updatedChainType);
if (isSuitableChildForFC(node, isExecutedInINT8, channelAxis))
PropagateIfHasOnlyChild(node, fusingChainType);
}
}
}
Expand Down

0 comments on commit ceff99e

Please sign in to comment.