diff --git a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp
index d540529ae04b3f..d993724e52ce3a 100644
--- a/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp
+++ b/src/plugins/intel_cpu/src/transformations/snippets/aarch64/pass/snippets_mark_skipped.cpp
@@ -263,64 +263,20 @@ auto is_skipped_op(const std::shared_ptr<ov::Node>& op) -> bool {
            ov::is_type<ov::op::v0::Result>(op);
 }
 
-bool isSuitableChildForFusingMatMul(const std::shared_ptr<const Node> &node, const bool canMatMulBeExecutedInI8,
-                                    NodeFusingType &updatedChainType, int& fusingAxis) {
-    // Firsly check for Bias and DQScales fusion
-    const bool is_bias = ov::is_type<ov::opset1::Add>(node);
-    const bool is_dq_scales = ov::is_type<ov::opset1::Multiply>(node) && canMatMulBeExecutedInI8;
-    if (is_bias || is_dq_scales) {
-        for (const auto &in : node->inputs()) {
-            const auto& parent_out = in.get_source_output();
-            const auto& parent = parent_out.get_node_shared_ptr();
-            const auto& parent_pshape = parent_out.get_partial_shape();
-            if (ov::is_type<ov::opset1::MatMul>(parent) && parent_pshape.rank().is_static()) {
-                if (parent->get_output_target_inputs(0).size() > 1)
-                    break;
-                const auto bias_port = 1 - in.get_index();
-                const auto bias_out = node->input_value(bias_port);
-                if ((bias_out.get_target_inputs().size() > 1) || !ov::op::util::is_on_constant_path(bias_out))
-                    break;
-                const auto& bias_pshape = bias_out.get_partial_shape();
-                if (bias_pshape.is_dynamic())
-                    break;
-                auto getNormalizedPShape = [](const ov::PartialShape &dims, size_t ndims) -> ov::PartialShape {
-                    if (dims.size() >= ndims)
-                        return dims;
-                    ov::PartialShape pshape(std::vector<ov::Dimension>(ndims, 1));
-                    std::copy(dims.rbegin(), dims.rend(), pshape.rbegin());
-                    return pshape;
-                };
-                const auto bias_pshape_norm = getNormalizedPShape(bias_pshape, parent_pshape.size());
-                if (fusingAxis >= static_cast<int>(bias_pshape_norm.size()) || fusingAxis >= static_cast<int>(parent_pshape.size()) ||
-                    bias_pshape_norm.size() != parent_pshape.size() || bias_pshape_norm.size() < 2)
-                    break;
-                if (((bias_pshape_norm[fusingAxis] == parent_pshape[fusingAxis]) || (is_dq_scales && bias_pshape_norm[fusingAxis] == 1)) &&
-                    (bias_pshape_norm[fusingAxis] == static_cast<int64_t>(shape_size(bias_pshape_norm.get_shape()))))
-                    return true;
-            }
-        }
-    }
-
-    // FuseMatMulAndSimpleOperation or FuseFullyConnectedAndSimpleOperation
+bool isSuitableChildForFC(const std::shared_ptr<const Node> &node, const bool canMatMulBeExecutedInI8, int& fusingAxis) {
     // Invoke SupportsFusingWithConvolution_Simple directly instead of isSuitableChildForFusingSimple to
     // eliminate getNumNonConstInputs() check
     if (SupportsFusingWithConvolution_Simple(node, fusingAxis)) {
         size_t num_non_const_inputs = 0;
         size_t num_mm_inputs = 0;
         for (const auto &parent_out : node->input_values()) {
-            // To avoid endless check `is_on_constant_path` for MatMul branch
             if (one_of(GetNodeFusingType(parent_out.get_node_shared_ptr()),
                        NodeFusingType::FusedWithFC, NodeFusingType::FusedWithFCI8))
                 num_mm_inputs++;
             else if (!ov::op::util::is_on_constant_path(parent_out))
                 num_non_const_inputs++;
         }
-        if (num_non_const_inputs + num_mm_inputs != 1)
-            return false;
-
-        updatedChainType = NodeFusingType::FusedWithMisc;
-        return true;
+        return (num_non_const_inputs + num_mm_inputs == 1);
     }
-
     return false;
 }
@@ -378,10 +334,8 @@ bool SnippetsMarkSkipped::run_on_model(const std::shared_ptr<ov::Model> &m) {
             }
         } else if (one_of(fusingChainType, NodeFusingType::FusedWithFC, NodeFusingType::FusedWithFCI8)) {
             const bool isExecutedInINT8 = fusingChainType == NodeFusingType::FusedWithFCI8;
-            // Handle fusings for both MatMul and FullyConnected
-            NodeFusingType updatedChainType = fusingChainType;
-            if (isSuitableChildForFusingMatMul(node, isExecutedInINT8, updatedChainType, channelAxis))
-                PropagateIfHasOnlyChild(node, updatedChainType);
+            if (isSuitableChildForFC(node, isExecutedInINT8, channelAxis))
+                PropagateIfHasOnlyChild(node, fusingChainType);
         }
     }
 }