diff --git a/docs/articles_en/assets/snippets/lpt_intel_cpu_plugin.cpp b/docs/articles_en/assets/snippets/lpt_intel_cpu_plugin.cpp index d9e41bc77eec17..76e6d60b8e3e90 100644 --- a/docs/articles_en/assets/snippets/lpt_intel_cpu_plugin.cpp +++ b/docs/articles_en/assets/snippets/lpt_intel_cpu_plugin.cpp @@ -38,7 +38,7 @@ auto defaultPrecisions = useLpt ? ov::pass::low_precision::precision_set::get_int8_support() : std::vector{}; if (useLpt) { // disable constant folding on dequantization subgraphs so they can be processed by LPT - manager.register_pass(defaultPrecisions); + manager.register_pass(defaultPrecisions); } // OpenVINO common transformations happen here diff --git a/src/common/low_precision_transformations/tests/mark_dequantization_subgraph_transformation.cpp b/src/common/low_precision_transformations/tests/mark_dequantization_subgraph_transformation.cpp index f68b7ba43b7c9f..bf254cded24ed8 100644 --- a/src/common/low_precision_transformations/tests/mark_dequantization_subgraph_transformation.cpp +++ b/src/common/low_precision_transformations/tests/mark_dequantization_subgraph_transformation.cpp @@ -11,10 +11,51 @@ #include "transformations/rt_info/keep_const_precision.hpp" #include "common_test_utils/ov_test_utils.hpp" +#include "transformations/convert_precision.hpp" using namespace ov; -TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformation) { +TEST_F(TransformationTestsF, KeepConstPrecision) { + { + auto lp_const = std::make_shared(element::u4, Shape{27}, 1); + + const auto target_shape = std::make_shared(ov::element::i64, ov::Shape{3}, 3); + auto reshape = std::make_shared(lp_const, target_shape, false); + + auto second_convert = std::make_shared(reshape, element::f32); + auto zero_point = opset10::Constant::create(element::f32, Shape{}, {127}); + auto subtract = std::make_shared(second_convert, zero_point); + auto scale = opset10::Constant::create(element::f32, Shape{}, {0.2}); + auto multiply = std::make_shared(subtract, scale); + auto stub_op = std::make_shared(multiply); + model = std::make_shared(stub_op, ParameterVector{}); + } + + manager.register_pass(element::TypeVector{element::u4}); + manager.register_pass(); + manager.register_pass(element::TypeVector{element::u4}); + manager.register_pass(ov::element::u4, ov::element::u8, type_to_fuse_map{}, false, false); + + { + auto lp_const = std::make_shared(element::u4, Shape{3, 3, 3}, 1); + auto second_convert = std::make_shared(lp_const, element::f32); + auto zero_point = opset10::Constant::create(element::f32, Shape{}, {127}); + auto subtract = std::make_shared(second_convert, zero_point); + auto scale = opset10::Constant::create(element::f32, Shape{}, {0.2}); + auto multiply = std::make_shared(subtract, scale); + auto stub_op = std::make_shared(multiply); + model_ref = std::make_shared(stub_op, ParameterVector{}); + + mark_as_dequantization_node(subtract); + mark_as_dequantization_node(multiply); + enable_keep_const_precision(lp_const); + ov::pass::disable_constant_folding(second_convert); + } + comparator.enable(FunctionsComparator::CmpValues::CONST_VALUES); + comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS); +} + +TEST_F(TransformationTestsF, MarkDequantizationTransformation) { // Input graph: // // Parameter @@ -37,7 +78,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformation) { // \ / // Convolution // - // After MarkDequantizationSubgraph all Subtract and Multiply nodes from above graph + // After MarkDequantization all Subtract and Multiply nodes from above graph // are marked with 'DequantizationNode' attribute. // All 'Convert(DCF)' nodes from above graph are marked with 'DisableConstantFolding' attribute // Weights and zero points are marked with 'KeepConstPrecision' attribute @@ -82,7 +123,8 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformation) { model = std::make_shared(conv, ParameterVector{parameter}); } - manager.register_pass(element::TypeVector{element::u8, element::i8}); + manager.register_pass(element::TypeVector{element::u8, element::i8}); + manager.register_pass(element::TypeVector{element::u8, element::i8}); manager.register_pass(); { @@ -138,7 +180,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformation) { comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS); } -TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint) { +TEST_F(TransformationTestsF, MarkDequantizationTransformationNoZeroPoint) { // Input graph: // // Parameter @@ -158,7 +200,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint // \ / // Convolution // - // After MarkDequantizationSubgraph all Multiply nodes from above graph + // After MarkDequantization all Multiply nodes from above graph // are marked with 'DequantizationNode' attribute. // Also 'Convert(DCF)' node from above graph is marked with 'DisableConstantFolding' attribute // Weights node is marked with 'KeepConstPrecision' attribute @@ -197,7 +239,8 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint model = std::make_shared(conv, ParameterVector{parameter}); } - manager.register_pass(element::TypeVector{element::u8, element::i8}); + manager.register_pass(element::TypeVector{element::u8, element::i8}); + manager.register_pass(element::TypeVector{element::u8, element::i8}); manager.register_pass(); { @@ -242,7 +285,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS); } -TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPointFP16) { +TEST_F(TransformationTestsF, MarkDequantizationTransformationNoZeroPointFP16) { // Input graph: // // Parameter @@ -262,7 +305,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint // \ / // Convolution // - // After MarkDequantizationSubgraph all Multiply nodes from above graph + // After MarkDequantization all Multiply nodes from above graph // are marked with 'DequantizationNode' attribute. // Also 'Convert(DCF)' node from above graph is marked with 'DisableConstantFolding' attribute // Weights node is marked with 'KeepConstPrecision' attribute @@ -305,9 +348,8 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint model = std::make_shared(conv, ParameterVector{parameter}); } - manager.register_pass(element::TypeVector{element::u8, element::i8}); - manager.register_pass(); - manager.register_pass(); + manager.register_pass(element::TypeVector{element::u8, element::i8}); + manager.register_pass(element::TypeVector{element::u8, element::i8}); { auto parameter = std::make_shared(element::f32, Shape{1, 16, 14, 14}); @@ -355,7 +397,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNoZeroPoint comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS); } -TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNotConstantWeights) { +TEST_F(TransformationTestsF, MarkDequantizationTransformationNotConstantWeights) { // Input graph: // // Parameter @@ -378,7 +420,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNotConstant // \ / // Convolution // - // After MarkDequantizationSubgraph all Subtract and Multiply nodes from above graph + // After MarkDequantization all Subtract and Multiply nodes from above graph // are marked with 'DequantizationNode' attribute. // Also all 'Convert(DCF)' nodes from above graph are marked with 'DisableConstantFolding' attribute // Weights and zero point nodes are marked with 'KeepConstPrecision' attribute @@ -426,7 +468,8 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNotConstant model = std::make_shared(conv, ParameterVector{parameter}); } - manager.register_pass(element::TypeVector{element::u8, element::i8}); + manager.register_pass(element::TypeVector{element::u8, element::i8}); + manager.register_pass(element::TypeVector{element::u8, element::i8}); manager.register_pass(); { @@ -481,7 +524,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationNotConstant comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS); } -TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationFoldSubConst) { +TEST_F(TransformationTestsF, MarkDequantizationTransformationFoldSubConst) { // Input graph: After transformation: // // Constant Constant Constant @@ -495,7 +538,7 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationFoldSubCons // | / \ / // Multiply Multiply // - // After MarkDequantizationSubgraph all Subtract and Multiply nodes from above graph + // After MarkDequantization all Subtract and Multiply nodes from above graph // are marked with 'DequantizationNode' attribute. // Also all 'Convert(DCF)' node before weights is marked with 'DisableConstantFolding' attribute // but Convert before Dequantization Sub const isn't because fold_subtract_const is set to true @@ -512,7 +555,8 @@ TEST_F(TransformationTestsF, MarkDequantizationSubgraphTransformationFoldSubCons model = std::make_shared(ov::OutputVector{multiply}); } - manager.register_pass(element::TypeVector{element::u8}, true); + manager.register_pass(element::TypeVector{element::u8}, true); + manager.register_pass(element::TypeVector{element::u8}, true); manager.register_pass(); { diff --git a/src/common/transformations/include/transformations/low_precision/mark_dequantization_subgraph.hpp b/src/common/transformations/include/transformations/low_precision/mark_dequantization_subgraph.hpp index 8b9b9e573ba957..6cbd8d990ac73e 100644 --- a/src/common/transformations/include/transformations/low_precision/mark_dequantization_subgraph.hpp +++ b/src/common/transformations/include/transformations/low_precision/mark_dequantization_subgraph.hpp @@ -4,27 +4,77 @@ #pragma once +#include + #include "openvino/pass/matcher_pass.hpp" #include "transformations_visibility.hpp" namespace ov { namespace pass { +/** + * @ingroup ov_transformation_common_api + * + * @brief MarkDequantization matches Dequantization subgraphs and marks Subtract and Multiply nodes + * with the dequantization attribute. Also if Convert nodes are part of the subgraph they might be marked + * with the disable_const_folding attribute. + * + * If Convert -> Reshape/Unsqueeze are part of the Dequantization subraph, Convert and Reshape/Unsqueeze + * nodes will be swapped to eliminate Reshape/Unsqueeze in the next ConstantFolding. + * + * Dequantization subgraph may have two forms: with and without Subtract. + * ZeroPoints and Scale might be present as subgraphs and include Convert ops. + * + * Input ZeroPoints + * │ │ + * ▼ ▼ + * Convert (opt) Reshape/Unsqueeze + * │ │ + * ▼ ▼ Scale Input Scale + * Subtract │ │ │ + * │ ▼ ▼ ▼ + * │ (opt) Reshape/Unsqueeze Convert (opt) Reshape/Unsqueeze + * │ │ │ │ + * ▼ ▼ ▼ ▼ + * Multiply Multiply + * + */ +class TRANSFORMATIONS_API MarkDequantization : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("MarkDequantization", "0"); + explicit MarkDequantization(const element::TypeVector& precisions, + bool fold_subtract_const = false, + bool fold_multiply_const = true); +}; /** * @ingroup ov_transformation_common_api - * @brief MarkDequantizationSubgraph marks dequantization subgraph, that is: - * Convert->Subtract(optional)->Multiply - * in two ways: - * - first Convert is marked with DisableConstantFolding attribute, also if Subtract is present - * and its second input is a Convert - that Convert is marked with DisableConstantFolding as well, - * - Subtract and Multiply are marked with 'DequantizationNode' attribute + * + * @brief KeepConstsPrecision matches Dequantization subgraphs and if Input/ZeroPoints/Scale are Constants + * they might be marked with keep_const_precision attribute. + * + * Dequantization subgraph may have two forms: with and without Subtract. + * + * Input + * │ + * ▼ + * Convert ZeroPoints + * │ │ + * ▼ ▼ Input + * Subtract │ + * │ ▼ + * │ Scale Convert Scale + * │ │ │ │ + * ▼ ▼ ▼ ▼ + * Multiply Multiply + * */ -class TRANSFORMATIONS_API MarkDequantizationSubgraph : public MatcherPass { +class TRANSFORMATIONS_API KeepConstsPrecision : public ov::pass::MatcherPass { public: - OPENVINO_RTTI("MarkDequantizationSubgraph", "0"); - MarkDequantizationSubgraph(const element::TypeVector& precisions, - const bool fold_subtract_const = false, - const bool disable_fold_multiply_const = false); + OPENVINO_RTTI("KeepConstsPrecision", "0"); + explicit KeepConstsPrecision(const element::TypeVector& precisions, + bool fold_subtract_const = false, + bool fold_multiply_const = true); }; + } // namespace pass } // namespace ov diff --git a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp index 282fc69486b923..185ae84ec83642 100644 --- a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp @@ -130,7 +130,7 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr using namespace ov::pass; REGISTER_PASS(manager, InitNodeInfo) if (m_low_precision_enabled) { - manager.register_pass( + manager.register_pass( element::TypeVector{ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4}); } if (!m_use_shapes) { diff --git a/src/common/transformations/src/transformations/low_precision/mark_dequantization_subgraph.cpp b/src/common/transformations/src/transformations/low_precision/mark_dequantization_subgraph.cpp index 9fdb17804409a9..8132ef2e68e2f9 100644 --- a/src/common/transformations/src/transformations/low_precision/mark_dequantization_subgraph.cpp +++ b/src/common/transformations/src/transformations/low_precision/mark_dequantization_subgraph.cpp @@ -4,115 +4,182 @@ #include "transformations/low_precision/mark_dequantization_subgraph.hpp" +#include "itt.hpp" #include "openvino/op/multiply.hpp" +#include "openvino/op/reshape.hpp" #include "openvino/op/subtract.hpp" -#include "openvino/pass/pattern/op/or.hpp" +#include "openvino/op/unsqueeze.hpp" +#include "openvino/pass/manager.hpp" +#include "openvino/pass/pattern/op/optional.hpp" #include "openvino/pass/pattern/op/wrap_type.hpp" #include "transformations/rt_info/dequantization_node.hpp" #include "transformations/rt_info/disable_constant_folding.hpp" #include "transformations/rt_info/keep_const_precision.hpp" #include "transformations/utils/utils.hpp" -ov::pass::MarkDequantizationSubgraph::MarkDequantizationSubgraph(const element::TypeVector& precisions, - const bool fold_subtract_const, - const bool disable_fold_multiply_const) { - // Dequantization subgraph may have two forms: with and without Subtract - // - // Input Input - // | | - // Convert zero point OR Convert scale - // \ / \ / - // Subtract scale Multiply - // \ / - // Multiply - // - auto input_pattern = pattern::any_input(); - auto convert_pattern = pattern::wrap_type({input_pattern}, pattern::consumers_count(1)); - auto zero_point_pattern = pattern::any_input(); - auto subtract_pattern = pattern::wrap_type({convert_pattern, zero_point_pattern}); - auto multiply_pattern = pattern::wrap_type({subtract_pattern, pattern::any_input()}); - auto multiply_no_subtract_pattern = - pattern::wrap_type({convert_pattern, pattern::any_input()}); - auto root = std::make_shared(OutputVector{multiply_pattern, multiply_no_subtract_pattern}); - - ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](pattern::Matcher& m) -> bool { - const auto& pattern_map = m.get_pattern_value_map(); - auto convert = pattern_map.at(convert_pattern).get_node_shared_ptr(); - auto input = pattern_map.at(input_pattern); - const auto multiply = m.get_match_root(); +using namespace ov; +using namespace ov::op; +using namespace ov::pass::pattern; - if (transformation_callback(multiply)) { - return false; - } +namespace { - auto subtract_it = pattern_map.find(subtract_pattern); - if (subtract_it == pattern_map.end()) { - for (size_t i = 0; i < multiply->get_input_size(); i++) { - const auto node = ov::as_type_ptr(multiply->get_input_node_shared_ptr(i)); - if (node && std::find(precisions.begin(), precisions.end(), node->get_input_element_type(0)) != - precisions.end()) { - convert = node; - input = convert->input_value(0); - } +bool check_precision(const ov::element::Type_t type_to_check, const ov::element::TypeVector& precisions) { + return std::find(precisions.begin(), precisions.end(), type_to_check) != precisions.end(); +}; + +using RTInfoSetter = std::function& node)>; +void set_rt_info(const PatternValueMap& pt_map, + const RTInfoSetter& rt_info_setter, + const NodeVector& pattern_nodes, + const ov::element::TypeVector& precisions) { + for (const auto& pattern_node : pattern_nodes) { + if (pt_map.count(pattern_node)) { + auto node = pt_map.at(pattern_node).get_node_shared_ptr(); + + // we don't need to mark Converts with disable_cf attribute if the `from` type (input type) + // is not in the `precisions` list. + if (ov::as_type_ptr(node) && !check_precision(node->get_input_element_type(0), precisions)) { + continue; } + + rt_info_setter(node); } + } +}; + +bool swap_nodes(const PatternValueMap& pt_map, + const std::shared_ptr& first, + const std::shared_ptr& second) { + if (pt_map.count(first) && pt_map.count(second)) { + auto first_node = pt_map.at(first).get_node_shared_ptr(); + auto second_node = pt_map.at(second).get_node_shared_ptr(); - const auto& input_precision = input.get_element_type(); - // validation by Convert operation input precisions - if (std::find(precisions.begin(), precisions.end(), input_precision) == precisions.end()) { + auto target_inputs = second_node->output(0).get_target_inputs(); + second_node->input(0).replace_source_output(first_node->input_value(0)); + first_node->input(0).replace_source_output(second_node->output(0)); + for (const auto& in : target_inputs) { + in.replace_source_output(first_node->output(0)); + } + first_node->validate_and_infer_types(); + second_node->validate_and_infer_types(); + return true; + } + return false; +} + +} // namespace + +ov::pass::MarkDequantization::MarkDequantization(const element::TypeVector& precisions, + const bool fold_subtract_const, + const bool fold_multiply_const) { + MATCHER_SCOPE(MarkDequantization); + + // data input: + auto input_pattern = any_input(); + auto convert_pattern = wrap_type({input_pattern}, consumers_count(1)); + + // zero points: + auto zp_pattern = any_input(); + auto zp_convert_pattern = pattern::optional(zp_pattern); + auto zp_reshape_pattern = pattern::optional({zp_convert_pattern, any_input()}); + auto subtract_pattern = pattern::optional({convert_pattern, zp_reshape_pattern}); + + // scale: + auto scale_pattern = any_input(); + auto scale_convert_pattern = pattern::optional(scale_pattern); + auto scale_reshape_pattern = pattern::optional({scale_convert_pattern, any_input()}); + auto multiply_pattern = wrap_type({subtract_pattern, scale_reshape_pattern}); + + ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](Matcher& m) -> bool { + const auto& pt_map = m.get_pattern_value_map(); + auto convert = pt_map.at(convert_pattern); + auto input = pt_map.at(input_pattern); + const auto multiply = m.get_match_root(); + + if (!check_precision(input.get_element_type(), precisions) || transformation_callback(multiply)) { return false; } - if (ov::op::util::is_on_constant_path(input)) { - // disable ConstantFolding if dequantization subgraph is on constant data - ov::disable_constant_folding(convert); - // It is also necessary to avoid precision conversion for constant nodes with input_precision - auto keep_const_precision = [&](Node* node) { - if (auto constant = ov::as_type(node)) { - const auto& const_et = constant->get_element_type(); - if (std::find(precisions.begin(), precisions.end(), const_et) != precisions.end()) - ov::enable_keep_const_precision(convert->get_input_node_shared_ptr(0)); - } - }; - std::unordered_set visited; - ov::op::util::visit_constant_path(input.get_node(), visited, keep_const_precision); + // Multiply and Subtract have to be marked as dq + set_rt_info(pt_map, mark_as_dequantization_node, {subtract_pattern, multiply_pattern}, {/* not applicable */}); + + // Convert might be presented on scales, zp and data_input. + // Depending on the transformation arguments they have to be marked/unmarked with disable_cf rt_info. + NodeVector converts_to_mark = {convert_pattern}; + NodeVector converts_to_unmark = {}; + + if (fold_subtract_const) { + converts_to_unmark.push_back(zp_convert_pattern); + } else { + converts_to_mark.push_back(zp_convert_pattern); } - if (subtract_it != pattern_map.end()) { - // mark Subtract as dequantization node - ov::mark_as_dequantization_node(subtract_it->second.get_node_shared_ptr()); - auto zero_point = pattern_map.at(zero_point_pattern).get_node_shared_ptr(); - if (ov::is_type(zero_point) && - input_precision == zero_point->get_input_element_type(0) && - ov::is_type(zero_point->get_input_node_ptr(0))) { - if (!fold_subtract_const) { - // disable ConstantFolding also for Convert on zero_point - // so we don't have to constantfold it and then convert it back to - // low precision in LP transformations - ov::disable_constant_folding(zero_point); - ov::enable_keep_const_precision(zero_point->get_input_node_shared_ptr(0)); - } else { - ov::enable_constant_folding(zero_point); - ov::disable_keep_const_precision(zero_point->get_input_node_shared_ptr(0)); - } - } + if (fold_multiply_const) { + converts_to_unmark.push_back(scale_convert_pattern); + } else { + converts_to_mark.push_back(scale_convert_pattern); } - // mark Multiply as dequantization node - ov::mark_as_dequantization_node(multiply); - auto scale = multiply->get_input_node_shared_ptr(1); - if (ov::is_type(scale) && - ov::is_type(scale->get_input_node_ptr(0))) { - if (disable_fold_multiply_const) { - ov::disable_constant_folding(scale); - ov::unmark_as_decompression(scale); - ov::enable_keep_const_precision(scale->get_input_node_shared_ptr(0)); - } + set_rt_info(pt_map, disable_constant_folding, converts_to_mark, precisions); + set_rt_info(pt_map, enable_constant_folding, converts_to_unmark, precisions); + + // Move Reshape/Unsqueeze ops up to fold them in ConstantFolding. + auto changed = swap_nodes(pt_map, zp_convert_pattern, zp_reshape_pattern); + changed = swap_nodes(pt_map, scale_convert_pattern, scale_reshape_pattern) || changed; + return changed; + }; + + auto m = std::make_shared(multiply_pattern, "MarkDequantization"); + this->register_matcher(m, callback); +} + +ov::pass::KeepConstsPrecision::KeepConstsPrecision(const element::TypeVector& precisions, + bool fold_subtract_const, + bool fold_multiply_const) { + MATCHER_SCOPE(KeepConstsPrecision); + + // data input: + auto input_pattern = any_input(); + auto convert_pattern = wrap_type({input_pattern}, consumers_count(1)); + + // zero points: + auto zp_pattern = any_input(); + auto zp_convert_pattern = pattern::optional(zp_pattern); + auto subtract_pattern = pattern::optional({convert_pattern, zp_convert_pattern}); + + // scale: + auto scale_pattern = any_input(); + auto scale_convert_pattern = pattern::optional(scale_pattern); + auto multiply_pattern = wrap_type({subtract_pattern, scale_convert_pattern}); + + ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](Matcher& m) -> bool { + const auto& pt_map = m.get_pattern_value_map(); + const auto multiply = m.get_match_root(); + + if (transformation_callback(multiply)) { + return false; } + using PatternNode = std::shared_ptr; + std::map keep_const_precisions = {{input_pattern, false}, + {zp_pattern, fold_subtract_const}, + {scale_pattern, fold_multiply_const}}; + for (const auto& pattern_node : keep_const_precisions) { + if (pt_map.count(pattern_node.first)) { + auto node = pt_map.at(pattern_node.first).get_node_shared_ptr(); + const auto& precision = node->get_output_element_type(0); + if (ov::as_type_ptr(node) && check_precision(precision, precisions)) { + if (pattern_node.second) { + ov::disable_keep_const_precision(node); + } else { + ov::enable_keep_const_precision(node); + } + } + } + } return false; }; - auto m = std::make_shared(root, "MarkDequantizationSubgraph"); + auto m = std::make_shared(multiply_pattern, "KeepConstsPrecision"); this->register_matcher(m, callback); } diff --git a/src/common/transformations/tests/op_conversions/convert_subtract.cpp b/src/common/transformations/tests/op_conversions/convert_subtract.cpp index fb835d0cdb581e..1a1d6d8b5c83bb 100644 --- a/src/common/transformations/tests/op_conversions/convert_subtract.cpp +++ b/src/common/transformations/tests/op_conversions/convert_subtract.cpp @@ -77,7 +77,7 @@ TEST_F(TransformationTestsF, ConvertSubtractDequantizationSubgraph) { model = std::make_shared(mul, ParameterVector{data}); - manager.register_pass(element::TypeVector{element::u8}); + manager.register_pass(element::TypeVector{element::u8}); manager.register_pass(); } diff --git a/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp b/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp index 47fcc7af60bf61..e3d3f4f1235504 100644 --- a/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp +++ b/src/frontends/onnx/frontend/src/op/dequantize_linear.cpp @@ -231,6 +231,7 @@ ov::OutputVector dequantize_linear(const ov::frontend::onnx::Node& node) { src_x.get_shape()[0] % block_size == 0, "DequantizeLinear doesn't support case when first dimension of X cannot be divided by block_size"); + // For further broadcasting scales and zp - reshape input to a shape [x.shape[0]/block_size, block_size, x.shape[1]] ov::Output broadcastable_x = op::util::reshape( src_x, Shape{static_cast(src_x.get_shape()[0]) / block_size, block_size, src_x.get_shape()[1]}); @@ -240,16 +241,14 @@ ov::OutputVector dequantize_linear(const ov::frontend::onnx::Node& node) { const auto scale_type = scale.get_element_type(); if (inputs.size() > 2) { zp = inputs[2]; + zp = std::make_shared(zp, unsqueezed_axes); if (zp.get_element_type() != scale.get_element_type()) { zp = std::make_shared(zp, scale_type); - disable_constant_folding(zp.get_node_shared_ptr()); } - zp = std::make_shared(zp, unsqueezed_axes); } const auto& x = src_x.get_element_type() == scale_type ? broadcastable_x : std::make_shared(broadcastable_x, scale_type); - // For further broadcasting scales and zp - reshape input to a shape [x.shape[0]/block_size, block_size, x.shape[1]] // Adding additional dimension for broadcasting scale = std::make_shared(scale, unsqueezed_axes); diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 909f6b7531d421..d777feea0d2e69 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -343,8 +343,9 @@ void Transformations::PreLpt(const std::vector& defaultPrecis ov::element::i4, ov::element::nf4, ov::element::f4e2m1}; + CPU_REGISTER_PASS_X64(decompression_handling_manager, - ov::pass::MarkDequantizationSubgraph, + ov::pass::MarkDequantization, decompression_precisions, false, true); @@ -353,7 +354,7 @@ void Transformations::PreLpt(const std::vector& defaultPrecis [&](const_node_ptr& node) -> bool { return !is_decompression_multiply(node); }, - ov::pass::MarkDequantizationSubgraph); + ov::pass::MarkDequantization); CPU_SET_CALLBACK_COMMON( decompression_handling_manager, @@ -379,7 +380,7 @@ void Transformations::PreLpt(const std::vector& defaultPrecis ov::pass::Manager manager("Plugin:CPU"); manager.set_per_pass_validation(false); if (useLpt) - CPU_REGISTER_PASS_COMMON(manager, ov::pass::MarkDequantizationSubgraph, defaultPrecisions); + CPU_REGISTER_PASS_COMMON(manager, ov::pass::MarkDequantization, defaultPrecisions); auto get_convert_precisions = [&]() { precisions_map map = {{ov::element::i64, ov::element::i32}, @@ -434,6 +435,13 @@ void Transformations::PreLpt(const std::vector& defaultPrecis CPU_REGISTER_PASS_COMMON(manager, ov::pass::AUGRUCellFusion); CPU_REGISTER_PASS_COMMON(manager, ov::pass::CommonOptimizations); + CPU_REGISTER_PASS_X64(manager, ov::pass::KeepConstsPrecision, decompression_precisions, false, true); + CPU_SET_CALLBACK_X64( + manager, + [&](const_node_ptr& node) -> bool { + return !is_decompression_multiply(node); + }, + ov::pass::KeepConstsPrecision); CPU_REGISTER_PASS_COMMON(manager, ov::pass::WrapInterpolateIntoTransposes); CPU_REGISTER_PASS_COMMON(manager, ov::pass::TransposeSinking); CPU_REGISTER_PASS_COMMON(manager, ov::pass::ConvertSequenceToTensorIterator); diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 50eecf51b945b7..94d93277e57816 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -292,10 +292,10 @@ void TransformationsPipeline::apply(std::shared_ptr func) { auto is_model_quantized = ov::pass::low_precision::LowPrecision::isFunctionQuantized(func); enableInt8 = config.get_property(ov::intel_gpu::enable_lp_transformations) && is_model_quantized; - if (enableInt8) { - manager.register_pass( - std::vector{ ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4 }); - } + + manager.register_pass( + std::vector{ ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4 }, + !device_info.supports_immad); manager.register_pass(); manager.register_pass(); @@ -373,8 +373,9 @@ void TransformationsPipeline::apply(std::shared_ptr func) { // Disable subtract folding only for the dGPUs to meet the requirements of oneDNN: // it expects to have the same data type for weights and zero points (apply it only for u8 data type, since other compression // types are not supported by oneDNN) - manager.register_pass(supported_woq_types, !device_info.supports_immad); - pass_config->set_callback([&](const std::shared_ptr node) { + manager.register_pass(supported_woq_types, !device_info.supports_immad); + pass_config->set_callback([&](const std::shared_ptr node) { return !is_decompression_multiply(node, device_info.supports_immad); }); @@ -927,8 +928,8 @@ void TransformationsPipeline::apply(std::shared_ptr func) { manager.register_pass(fuse_mlp_swiglu); // ZP should not be folded for FC. But still, ZP should be folded for Gather. - // Therefore, run MarkDequantizationSubgraph again to fold ZP constant. - manager.register_pass(supported_woq_types, true); + // Therefore, run MarkDequantization again to fold ZP constant. + manager.register_pass(supported_woq_types, true); if (device_info.supports_immad) { if (disable_horizontal_fc_fusion) manager.register_pass();