From 189153f47c9bdabf23ee2a7782ad5147b65e6e04 Mon Sep 17 00:00:00 2001 From: Tikhonov Ivan Date: Tue, 19 Nov 2024 14:56:31 +0400 Subject: [PATCH] fix issue on gpu, docs, refactoring --- .../mark_dequantization_subgraph.hpp | 5 +- .../moc_transformations.cpp | 14 ++-- .../mark_dequantization_subgraph.cpp | 84 ++++++++++++------- 3 files changed, 68 insertions(+), 35 deletions(-) diff --git a/src/common/transformations/include/transformations/low_precision/mark_dequantization_subgraph.hpp b/src/common/transformations/include/transformations/low_precision/mark_dequantization_subgraph.hpp index 7770647d736e67..c60d9ca5d3659c 100644 --- a/src/common/transformations/include/transformations/low_precision/mark_dequantization_subgraph.hpp +++ b/src/common/transformations/include/transformations/low_precision/mark_dequantization_subgraph.hpp @@ -11,10 +11,11 @@ namespace ov { namespace pass { - /** * @ingroup ov_transformation_common_api - * @brief TBA + * @brief MarkDequantizationAndDecompression is a set of transformation which mark + * Dequantization and Decompression patterns with the keep_const_precision, disable_const_folding and + * dequantization attributes. Also it calls ConstantFolding. */ class TRANSFORMATIONS_API MarkDequantizationAndDecompression : public ModelPass { public: diff --git a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp index 86e6604e3241cd..cfcd1a96fa577f 100644 --- a/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/moc_transformations.cpp @@ -130,15 +130,12 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr using namespace ov::pass; REGISTER_PASS(manager, InitNodeInfo) REGISTER_PASS(manager, EliminateConvert) - if (m_low_precision_enabled) { - manager.register_pass( - element::TypeVector{ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4}); - } if (!m_use_shapes) { manager.register_pass(); } + // RemoveConcatZeroDimInput and RemoveMultiSubGraphOpDanglingParamsResults - // should be performed before first ConstantFolding call. + // should be performed before first !ConstantFolding! call. // The passes can deteach graph branches where zero dimesion is calculated. // Zero dimensions in shape causes creation empty tensors, which are incorrect during CF. // In particular, if zero dim tensor is consumed in body of MultiSubGraphOp @@ -147,6 +144,13 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr REGISTER_PASS(manager, RemoveConcatZeroDimInput) REGISTER_PASS(manager, EliminateLoopInputsOutputs); REGISTER_PASS(manager, Validate) + + if (m_low_precision_enabled) { + // includes ConstantFolding call + manager.register_pass( + element::TypeVector{ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4}); + } + // todo: ticket 96960 // the order EliminateDuplicateTIInputs and RemoveMultiSubGraphOpDanglingParamsResults is important // it looks like we need to combine these transformations into one. diff --git a/src/common/transformations/src/transformations/low_precision/mark_dequantization_subgraph.cpp b/src/common/transformations/src/transformations/low_precision/mark_dequantization_subgraph.cpp index 34aaac6a5dcf41..3e742ff305c68c 100644 --- a/src/common/transformations/src/transformations/low_precision/mark_dequantization_subgraph.cpp +++ b/src/common/transformations/src/transformations/low_precision/mark_dequantization_subgraph.cpp @@ -24,7 +24,30 @@ using namespace ov::pass::pattern; /** * @ingroup ov_transformation_common_api - * @brief TBA + * + * @brief MarkDequantization matches Dequantization subgraphs and marks Subtract and Multiply nodes + * with the dequantization attribute. Also if Convert nodes are part of the subgraph they might be marked + * with the disable_const_folding attribute. + * + * If Convert -> Reshape/Unsqueeze are part of the Dequantization subraph, Convert and Reshape/Unsqueeze + * nodes will be swapped to eliminate Reshape/Unsqueeze in the next ConstantFolding. + * + * Dequantization subgraph may have two forms: with and without Subtract. + * ZeroPoints and Scale might be present as subgraphs and include Convert ops. + * + * Input ZeroPoints + * │ │ + * ▼ ▼ + * Convert (opt) Reshape/Unsqueeze + * │ │ + * ▼ ▼ Scale Input Scale + * Subtract │ │ │ + * │ ▼ ▼ ▼ + * │ (opt) Reshape/Unsqueeze Convert (opt) Reshape/Unsqueeze + * │ │ │ │ + * ▼ ▼ ▼ ▼ + * Multiply Multiply + * */ class TRANSFORMATIONS_API MarkDequantization : public ov::pass::MatcherPass { public: @@ -36,7 +59,25 @@ class TRANSFORMATIONS_API MarkDequantization : public ov::pass::MatcherPass { /** * @ingroup ov_transformation_common_api - * @brief TBA + * + * @brief KeepConstsPrecision matches Dequantization subgraphs and if Input/ZeroPoints/Scale are Constants + * they might be marked with keep_const_precision attribute. + * + * Dequantization subgraph may have two forms: with and without Subtract. + * + * Input + * │ + * ▼ + * Convert ZeroPoints + * │ │ + * ▼ ▼ Input + * Subtract │ + * │ ▼ + * │ Scale Convert Scale + * │ │ │ │ + * ▼ ▼ ▼ ▼ + * Multiply Multiply + * */ class TRANSFORMATIONS_API KeepConstsPrecision : public ov::pass::MatcherPass { public: @@ -91,16 +132,7 @@ void swap_nodes(const PatternValueMap& pt_map, MarkDequantization::MarkDequantization(const element::TypeVector& precisions, const bool fold_subtract_const, const bool fold_multiply_const) { - // Dequantization subgraph may have two forms: with and without Subtract - // - // Input Input - // | | - // Convert zero point OR Convert scale - // \ / \ / - // Subtract scale Multiply - // \ / - // Multiply - // + // data input: auto input_pattern = any_input(); auto convert_pattern = wrap_type({input_pattern}, consumers_count(1)); @@ -134,8 +166,7 @@ MarkDequantization::MarkDequantization(const element::TypeVector& precisions, NodeVector converts_to_mark = {convert_pattern}; NodeVector converts_to_unmark = {}; - if (fold_subtract_const || - (pt_map.count(subtract_pattern) && pt_map.at(zp_pattern).get_element_type() != input.get_element_type())) { + if (fold_subtract_const) { converts_to_unmark.push_back(zp_convert_pattern); } else { converts_to_mark.push_back(zp_convert_pattern); @@ -163,16 +194,7 @@ MarkDequantization::MarkDequantization(const element::TypeVector& precisions, KeepConstsPrecision::KeepConstsPrecision(const element::TypeVector& precisions, bool fold_subtract_const, bool fold_multiply_const) { - // Dequantization subgraph may have two forms: with and without Subtract - // - // Input Input - // | | - // Convert zero point OR Convert scale - // \ / \ / - // Subtract scale Multiply - // \ / - // Multiply - // + // data input: auto input_pattern = any_input(); auto convert_pattern = wrap_type({input_pattern}, consumers_count(1)); @@ -194,9 +216,10 @@ KeepConstsPrecision::KeepConstsPrecision(const element::TypeVector& precisions, return false; } - std::map, bool> keep_const_precisions = {{input_pattern, false}, - {zp_pattern, fold_subtract_const}, - {scale_pattern, fold_multiply_const}}; + using PatternNode = std::shared_ptr; + std::map keep_const_precisions = {{input_pattern, false}, + {zp_pattern, fold_subtract_const}, + {scale_pattern, fold_multiply_const}}; for (const auto& pattern_node : keep_const_precisions) { if (pt_map.count(pattern_node.first)) { auto node = pt_map.at(pattern_node.first).get_node_shared_ptr(); @@ -218,7 +241,12 @@ KeepConstsPrecision::KeepConstsPrecision(const element::TypeVector& precisions, } bool pass::MarkDequantizationAndDecompression::run_on_model(const std::shared_ptr& m) { - ov::pass::Manager manager("MarkDequantizationAndDecompressionManager"); + const auto& pass_config = get_pass_config(); + auto callback = pass_config->get_callback(); + pass_config->set_callback(callback); + pass_config->set_callback(callback); + + ov::pass::Manager manager(pass_config, "MarkDequantizationAndDecompressionManager"); manager.register_pass(); manager.register_pass(m_precisions, m_fold_subtract_const, m_fold_multiply_const); manager.register_pass();