Skip to content

Commit

Permalink
remove the dq model pass, leave the separate matchers only
Browse files Browse the repository at this point in the history
  • Loading branch information
itikhono committed Nov 19, 2024
1 parent 48b5694 commit 06f1c22
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 134 deletions.
2 changes: 1 addition & 1 deletion docs/articles_en/assets/snippets/lpt_intel_cpu_plugin.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ auto defaultPrecisions =
useLpt ? ov::pass::low_precision::precision_set::get_int8_support() : std::vector<ov::element::Type>{};
if (useLpt) {
// disable constant folding on dequantization subgraphs so they can be processed by LPT
manager.register_pass<ov::pass::MarkDequantizationAndDecompression>(defaultPrecisions);
manager.register_pass<ov::pass::MarkDequantization>(defaultPrecisions);
}

// OpenVINO common transformations happen here
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,9 @@ TEST_F(TransformationTestsF, KeepConstPrecision) {
model = std::make_shared<Model>(stub_op, ParameterVector{});
}

manager.register_pass<pass::MarkDequantizationAndDecompression>(element::TypeVector{element::u4});
manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u4});
manager.register_pass<pass::ConstantFolding>();
manager.register_pass<pass::KeepConstsPrecision>(element::TypeVector{element::u4});
manager.register_pass<pass::ConvertPrecision>(ov::element::u4, ov::element::u8, type_to_fuse_map{}, false, false);

{
Expand All @@ -46,7 +48,7 @@ TEST_F(TransformationTestsF, KeepConstPrecision) {
}
}

TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformation) {
TEST_F(TransformationTestsF, MarkDequantizationTransformation) {
// Input graph:
//
// Parameter
Expand All @@ -69,7 +71,7 @@ TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformation) {
// \ /
// Convolution
//
// After MarkDequantizationAndDecompression all Subtract and Multiply nodes from above graph
// After MarkDequantization all Subtract and Multiply nodes from above graph
// are marked with 'DequantizationNode' attribute.
// All 'Convert(DCF)' nodes from above graph are marked with 'DisableConstantFolding' attribute
// Weights and zero points are marked with 'KeepConstPrecision' attribute
Expand Down Expand Up @@ -114,7 +116,8 @@ TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformation) {
model = std::make_shared<Model>(conv, ParameterVector{parameter});
}

manager.register_pass<pass::MarkDequantizationAndDecompression>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::KeepConstsPrecision>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::ConstantFolding>();

{
Expand Down Expand Up @@ -170,7 +173,7 @@ TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformation) {
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}

TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformationNoZeroPoint) {
TEST_F(TransformationTestsF, MarkDequantizationTransformationNoZeroPoint) {
// Input graph:
//
// Parameter
Expand All @@ -190,7 +193,7 @@ TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformationNoZ
// \ /
// Convolution
//
// After MarkDequantizationAndDecompression all Multiply nodes from above graph
// After MarkDequantization all Multiply nodes from above graph
// are marked with 'DequantizationNode' attribute.
// Also 'Convert(DCF)' node from above graph is marked with 'DisableConstantFolding' attribute
// Weights node is marked with 'KeepConstPrecision' attribute
Expand Down Expand Up @@ -229,7 +232,8 @@ TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformationNoZ
model = std::make_shared<Model>(conv, ParameterVector{parameter});
}

manager.register_pass<pass::MarkDequantizationAndDecompression>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::KeepConstsPrecision>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::ConstantFolding>();

{
Expand Down Expand Up @@ -274,7 +278,7 @@ TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformationNoZ
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}

TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformationNoZeroPointFP16) {
TEST_F(TransformationTestsF, MarkDequantizationTransformationNoZeroPointFP16) {
// Input graph:
//
// Parameter
Expand All @@ -294,7 +298,7 @@ TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformationNoZ
// \ /
// Convolution
//
// After MarkDequantizationAndDecompression all Multiply nodes from above graph
// After MarkDequantization all Multiply nodes from above graph
// are marked with 'DequantizationNode' attribute.
// Also 'Convert(DCF)' node from above graph is marked with 'DisableConstantFolding' attribute
// Weights node is marked with 'KeepConstPrecision' attribute
Expand Down Expand Up @@ -337,7 +341,8 @@ TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformationNoZ
model = std::make_shared<Model>(conv, ParameterVector{parameter});
}

manager.register_pass<pass::MarkDequantizationAndDecompression>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::KeepConstsPrecision>(element::TypeVector{element::u8, element::i8});

{
auto parameter = std::make_shared<opset10::Parameter>(element::f32, Shape{1, 16, 14, 14});
Expand Down Expand Up @@ -385,7 +390,7 @@ TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformationNoZ
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}

TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformationNotConstantWeights) {
TEST_F(TransformationTestsF, MarkDequantizationTransformationNotConstantWeights) {
// Input graph:
//
// Parameter
Expand All @@ -408,7 +413,7 @@ TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformationNot
// \ /
// Convolution
//
// After MarkDequantizationAndDecompression all Subtract and Multiply nodes from above graph
// After MarkDequantization all Subtract and Multiply nodes from above graph
// are marked with 'DequantizationNode' attribute.
// Also all 'Convert(DCF)' nodes from above graph are marked with 'DisableConstantFolding' attribute
// Weights and zero point nodes are marked with 'KeepConstPrecision' attribute
Expand Down Expand Up @@ -456,7 +461,8 @@ TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformationNot
model = std::make_shared<Model>(conv, ParameterVector{parameter});
}

manager.register_pass<pass::MarkDequantizationAndDecompression>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::KeepConstsPrecision>(element::TypeVector{element::u8, element::i8});
manager.register_pass<pass::ConstantFolding>();

{
Expand Down Expand Up @@ -511,7 +517,7 @@ TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformationNot
comparator.enable(FunctionsComparator::CmpValues::RUNTIME_KEYS);
}

TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformationFoldSubConst) {
TEST_F(TransformationTestsF, MarkDequantizationTransformationFoldSubConst) {
// Input graph: After transformation:
//
// Constant Constant Constant
Expand All @@ -525,7 +531,7 @@ TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformationFol
// | / \ /
// Multiply Multiply
//
// After MarkDequantizationAndDecompression all Subtract and Multiply nodes from above graph
// After MarkDequantization all Subtract and Multiply nodes from above graph
// are marked with 'DequantizationNode' attribute.
// Also all 'Convert(DCF)' node before weights is marked with 'DisableConstantFolding' attribute
// but Convert before Dequantization Sub const isn't because fold_subtract_const is set to true
Expand All @@ -542,7 +548,8 @@ TEST_F(TransformationTestsF, MarkDequantizationAndDecompressionTransformationFol
model = std::make_shared<ov::Model>(ov::OutputVector{multiply});
}

manager.register_pass<pass::MarkDequantizationAndDecompression>(element::TypeVector{element::u8}, true);
manager.register_pass<pass::MarkDequantization>(element::TypeVector{element::u8}, true);
manager.register_pass<pass::KeepConstsPrecision>(element::TypeVector{element::u8}, true);
manager.register_pass<pass::ConstantFolding>();

{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,26 +13,67 @@ namespace ov {
namespace pass {
/**
* @ingroup ov_transformation_common_api
* @brief MarkDequantizationAndDecompression is a set of transformation which mark
* Dequantization and Decompression patterns with the keep_const_precision, disable_const_folding and
* dequantization attributes. Also it calls ConstantFolding.
*
* @brief MarkDequantization matches Dequantization subgraphs and marks Subtract and Multiply nodes
* with the dequantization attribute. Also if Convert nodes are part of the subgraph they might be marked
* with the disable_const_folding attribute.
*
* If Convert -> Reshape/Unsqueeze are part of the Dequantization subraph, Convert and Reshape/Unsqueeze
* nodes will be swapped to eliminate Reshape/Unsqueeze in the next ConstantFolding.
*
* Dequantization subgraph may have two forms: with and without Subtract.
* ZeroPoints and Scale might be present as subgraphs and include Convert ops.
*
* Input ZeroPoints
* │ │
* ▼ ▼
* Convert (opt) Reshape/Unsqueeze
* │ │
* ▼ ▼ Scale Input Scale
* Subtract │ │ │
* │ ▼ ▼ ▼
* │ (opt) Reshape/Unsqueeze Convert (opt) Reshape/Unsqueeze
* │ │ │ │
* ▼ ▼ ▼ ▼
* Multiply Multiply
*
*/
class TRANSFORMATIONS_API MarkDequantizationAndDecompression : public ModelPass {
class TRANSFORMATIONS_API MarkDequantization : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("MarkDequantizationAndDecompression", "0");
explicit MarkDequantizationAndDecompression(element::TypeVector precisions,
const bool fold_subtract_const = false,
const bool fold_multiply_const = true)
: m_fold_subtract_const(fold_subtract_const),
m_fold_multiply_const(fold_multiply_const),
m_precisions(std::move(precisions)) {}

bool run_on_model(const std::shared_ptr<ov::Model>& m) override;
OPENVINO_RTTI("MarkDequantization", "0");
explicit MarkDequantization(const element::TypeVector& precisions,
bool fold_subtract_const = false,
bool fold_multiply_const = true);
};

private:
bool m_fold_subtract_const = false;
bool m_fold_multiply_const = true;
element::TypeVector m_precisions;
/**
* @ingroup ov_transformation_common_api
*
* @brief KeepConstsPrecision matches Dequantization subgraphs and if Input/ZeroPoints/Scale are Constants
* they might be marked with keep_const_precision attribute.
*
* Dequantization subgraph may have two forms: with and without Subtract.
*
* Input
* │
* ▼
* Convert ZeroPoints
* │ │
* ▼ ▼ Input
* Subtract │
* │ ▼
* │ Scale Convert Scale
* │ │ │ │
* ▼ ▼ ▼ ▼
* Multiply Multiply
*
*/
class TRANSFORMATIONS_API KeepConstsPrecision : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("KeepConstsPrecision", "0");
explicit KeepConstsPrecision(const element::TypeVector& precisions,
bool fold_subtract_const = false,
bool fold_multiply_const = true);
};

} // namespace pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ bool ov::pass::CommonOptimizations::run_on_model(const std::shared_ptr<ov::Model
manager.set_per_pass_validation(false);

using namespace ov::pass;
REGISTER_PASS(manager, DisableDecompressionConvertConstantFolding)
// MOCTransformations contain StridedSliceOptimization transformation,
// so we must call SliceToStridedSlice before MOCTransformations call
REGISTER_PASS(manager, SliceToStridedSlice, true)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,11 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
manager.register_pass<ov::pass::DisableShapeOfConstantFolding>();
}

if (m_low_precision_enabled) {
manager.register_pass<ov::pass::MarkDequantization>(
element::TypeVector{ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4});
}

// RemoveConcatZeroDimInput and RemoveMultiSubGraphOpDanglingParamsResults
// should be performed before first !ConstantFolding! call.
// The passes can deteach graph branches where zero dimesion is calculated.
Expand All @@ -145,12 +150,6 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
REGISTER_PASS(manager, EliminateLoopInputsOutputs);
REGISTER_PASS(manager, Validate)

if (m_low_precision_enabled) {
// includes ConstantFolding call
manager.register_pass<ov::pass::MarkDequantizationAndDecompression>(
element::TypeVector{ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4});
}

// todo: ticket 96960
// the order EliminateDuplicateTIInputs and RemoveMultiSubGraphOpDanglingParamsResults is important
// it looks like we need to combine these transformations into one.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,71 +22,6 @@ using namespace ov;
using namespace ov::op;
using namespace ov::pass::pattern;

/**
* @ingroup ov_transformation_common_api
*
* @brief MarkDequantization matches Dequantization subgraphs and marks Subtract and Multiply nodes
* with the dequantization attribute. Also if Convert nodes are part of the subgraph they might be marked
* with the disable_const_folding attribute.
*
* If Convert -> Reshape/Unsqueeze are part of the Dequantization subraph, Convert and Reshape/Unsqueeze
* nodes will be swapped to eliminate Reshape/Unsqueeze in the next ConstantFolding.
*
* Dequantization subgraph may have two forms: with and without Subtract.
* ZeroPoints and Scale might be present as subgraphs and include Convert ops.
*
* Input ZeroPoints
* │ │
* ▼ ▼
* Convert (opt) Reshape/Unsqueeze
* │ │
* ▼ ▼ Scale Input Scale
* Subtract │ │ │
* │ ▼ ▼ ▼
* │ (opt) Reshape/Unsqueeze Convert (opt) Reshape/Unsqueeze
* │ │ │ │
* ▼ ▼ ▼ ▼
* Multiply Multiply
*
*/
class TRANSFORMATIONS_API MarkDequantization : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("MarkDequantization", "0");
explicit MarkDequantization(const element::TypeVector& precisions,
bool fold_subtract_const,
bool fold_multiply_const);
};

/**
* @ingroup ov_transformation_common_api
*
* @brief KeepConstsPrecision matches Dequantization subgraphs and if Input/ZeroPoints/Scale are Constants
* they might be marked with keep_const_precision attribute.
*
* Dequantization subgraph may have two forms: with and without Subtract.
*
* Input
* │
* ▼
* Convert ZeroPoints
* │ │
* ▼ ▼ Input
* Subtract │
* │ ▼
* │ Scale Convert Scale
* │ │ │ │
* ▼ ▼ ▼ ▼
* Multiply Multiply
*
*/
class TRANSFORMATIONS_API KeepConstsPrecision : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("KeepConstsPrecision", "0");
explicit KeepConstsPrecision(const element::TypeVector& precisions,
bool fold_subtract_const,
bool fold_multiply_const);
};

namespace {

bool check_precision(const ov::element::Type_t type_to_check, const ov::element::TypeVector& precisions) {
Expand Down Expand Up @@ -129,9 +64,9 @@ void swap_nodes(const PatternValueMap& pt_map,

} // namespace

MarkDequantization::MarkDequantization(const element::TypeVector& precisions,
const bool fold_subtract_const,
const bool fold_multiply_const) {
ov::pass::MarkDequantization::MarkDequantization(const element::TypeVector& precisions,
const bool fold_subtract_const,
const bool fold_multiply_const) {
// data input:
auto input_pattern = any_input();
auto convert_pattern = wrap_type<v0::Convert>({input_pattern}, consumers_count(1));
Expand Down Expand Up @@ -191,9 +126,9 @@ MarkDequantization::MarkDequantization(const element::TypeVector& precisions,
this->register_matcher(m, callback);
}

KeepConstsPrecision::KeepConstsPrecision(const element::TypeVector& precisions,
bool fold_subtract_const,
bool fold_multiply_const) {
ov::pass::KeepConstsPrecision::KeepConstsPrecision(const element::TypeVector& precisions,
bool fold_subtract_const,
bool fold_multiply_const) {
// data input:
auto input_pattern = any_input();
auto convert_pattern = wrap_type<v0::Convert>({input_pattern}, consumers_count(1));
Expand Down Expand Up @@ -239,17 +174,3 @@ KeepConstsPrecision::KeepConstsPrecision(const element::TypeVector& precisions,
auto m = std::make_shared<Matcher>(multiply_pattern, "KeepConstsPrecision");
this->register_matcher(m, callback);
}

bool pass::MarkDequantizationAndDecompression::run_on_model(const std::shared_ptr<ov::Model>& m) {
const auto& pass_config = get_pass_config();
auto callback = pass_config->get_callback<MarkDequantizationAndDecompression>();
pass_config->set_callback<MarkDequantization>(callback);
pass_config->set_callback<KeepConstsPrecision>(callback);

ov::pass::Manager manager(pass_config, "MarkDequantizationAndDecompressionManager");
manager.register_pass<DisableDecompressionConvertConstantFolding>();
manager.register_pass<MarkDequantization>(m_precisions, m_fold_subtract_const, m_fold_multiply_const);
manager.register_pass<ConstantFolding>();
manager.register_pass<KeepConstsPrecision>(m_precisions, m_fold_subtract_const, m_fold_multiply_const);
return manager.run_passes(m);
}
Loading

0 comments on commit 06f1c22

Please sign in to comment.