Skip to content

Commit

Permalink
fix issue on gpu, docs, refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
itikhono committed Nov 19, 2024
1 parent e7d7c5f commit 189153f
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 35 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,11 @@

namespace ov {
namespace pass {

/**
* @ingroup ov_transformation_common_api
* @brief TBA
* @brief MarkDequantizationAndDecompression is a set of transformation which mark
* Dequantization and Decompression patterns with the keep_const_precision, disable_const_folding and
* dequantization attributes. Also it calls ConstantFolding.
*/
class TRANSFORMATIONS_API MarkDequantizationAndDecompression : public ModelPass {
public:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,12 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
using namespace ov::pass;
REGISTER_PASS(manager, InitNodeInfo)
REGISTER_PASS(manager, EliminateConvert)
if (m_low_precision_enabled) {
manager.register_pass<ov::pass::MarkDequantizationAndDecompression>(
element::TypeVector{ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4});
}
if (!m_use_shapes) {
manager.register_pass<ov::pass::DisableShapeOfConstantFolding>();
}

// RemoveConcatZeroDimInput and RemoveMultiSubGraphOpDanglingParamsResults
// should be performed before first ConstantFolding call.
// should be performed before first !ConstantFolding! call.
// The passes can deteach graph branches where zero dimesion is calculated.
// Zero dimensions in shape causes creation empty tensors, which are incorrect during CF.
// In particular, if zero dim tensor is consumed in body of MultiSubGraphOp
Expand All @@ -147,6 +144,13 @@ bool ov::pass::MOCTransformations::run_on_model(const std::shared_ptr<ov::Model>
REGISTER_PASS(manager, RemoveConcatZeroDimInput)
REGISTER_PASS(manager, EliminateLoopInputsOutputs);
REGISTER_PASS(manager, Validate)

if (m_low_precision_enabled) {
// includes ConstantFolding call
manager.register_pass<ov::pass::MarkDequantizationAndDecompression>(
element::TypeVector{ov::element::i8, ov::element::u8, ov::element::i4, ov::element::u4});
}

// todo: ticket 96960
// the order EliminateDuplicateTIInputs and RemoveMultiSubGraphOpDanglingParamsResults is important
// it looks like we need to combine these transformations into one.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,30 @@ using namespace ov::pass::pattern;

/**
* @ingroup ov_transformation_common_api
* @brief TBA
*
* @brief MarkDequantization matches Dequantization subgraphs and marks Subtract and Multiply nodes
* with the dequantization attribute. Also if Convert nodes are part of the subgraph they might be marked
* with the disable_const_folding attribute.
*
* If Convert -> Reshape/Unsqueeze are part of the Dequantization subraph, Convert and Reshape/Unsqueeze
* nodes will be swapped to eliminate Reshape/Unsqueeze in the next ConstantFolding.
*
* Dequantization subgraph may have two forms: with and without Subtract.
* ZeroPoints and Scale might be present as subgraphs and include Convert ops.
*
* Input ZeroPoints
* │ │
* ▼ ▼
* Convert (opt) Reshape/Unsqueeze
* │ │
* ▼ ▼ Scale Input Scale
* Subtract │ │ │
* │ ▼ ▼ ▼
* │ (opt) Reshape/Unsqueeze Convert (opt) Reshape/Unsqueeze
* │ │ │ │
* ▼ ▼ ▼ ▼
* Multiply Multiply
*
*/
class TRANSFORMATIONS_API MarkDequantization : public ov::pass::MatcherPass {
public:
Expand All @@ -36,7 +59,25 @@ class TRANSFORMATIONS_API MarkDequantization : public ov::pass::MatcherPass {

/**
* @ingroup ov_transformation_common_api
* @brief TBA
*
* @brief KeepConstsPrecision matches Dequantization subgraphs and if Input/ZeroPoints/Scale are Constants
* they might be marked with keep_const_precision attribute.
*
* Dequantization subgraph may have two forms: with and without Subtract.
*
* Input
* │
* ▼
* Convert ZeroPoints
* │ │
* ▼ ▼ Input
* Subtract │
* │ ▼
* │ Scale Convert Scale
* │ │ │ │
* ▼ ▼ ▼ ▼
* Multiply Multiply
*
*/
class TRANSFORMATIONS_API KeepConstsPrecision : public ov::pass::MatcherPass {
public:
Expand Down Expand Up @@ -91,16 +132,7 @@ void swap_nodes(const PatternValueMap& pt_map,
MarkDequantization::MarkDequantization(const element::TypeVector& precisions,
const bool fold_subtract_const,
const bool fold_multiply_const) {
// Dequantization subgraph may have two forms: with and without Subtract
//
// Input Input
// | |
// Convert zero point OR Convert scale
// \ / \ /
// Subtract scale Multiply
// \ /
// Multiply
//
// data input:
auto input_pattern = any_input();
auto convert_pattern = wrap_type<v0::Convert>({input_pattern}, consumers_count(1));

Expand Down Expand Up @@ -134,8 +166,7 @@ MarkDequantization::MarkDequantization(const element::TypeVector& precisions,
NodeVector converts_to_mark = {convert_pattern};
NodeVector converts_to_unmark = {};

if (fold_subtract_const ||
(pt_map.count(subtract_pattern) && pt_map.at(zp_pattern).get_element_type() != input.get_element_type())) {
if (fold_subtract_const) {
converts_to_unmark.push_back(zp_convert_pattern);
} else {
converts_to_mark.push_back(zp_convert_pattern);
Expand Down Expand Up @@ -163,16 +194,7 @@ MarkDequantization::MarkDequantization(const element::TypeVector& precisions,
KeepConstsPrecision::KeepConstsPrecision(const element::TypeVector& precisions,
bool fold_subtract_const,
bool fold_multiply_const) {
// Dequantization subgraph may have two forms: with and without Subtract
//
// Input Input
// | |
// Convert zero point OR Convert scale
// \ / \ /
// Subtract scale Multiply
// \ /
// Multiply
//
// data input:
auto input_pattern = any_input();
auto convert_pattern = wrap_type<v0::Convert>({input_pattern}, consumers_count(1));

Expand All @@ -194,9 +216,10 @@ KeepConstsPrecision::KeepConstsPrecision(const element::TypeVector& precisions,
return false;
}

std::map<std::shared_ptr<Node>, bool> keep_const_precisions = {{input_pattern, false},
{zp_pattern, fold_subtract_const},
{scale_pattern, fold_multiply_const}};
using PatternNode = std::shared_ptr<Node>;
std::map<PatternNode, bool> keep_const_precisions = {{input_pattern, false},
{zp_pattern, fold_subtract_const},
{scale_pattern, fold_multiply_const}};
for (const auto& pattern_node : keep_const_precisions) {
if (pt_map.count(pattern_node.first)) {
auto node = pt_map.at(pattern_node.first).get_node_shared_ptr();
Expand All @@ -218,7 +241,12 @@ KeepConstsPrecision::KeepConstsPrecision(const element::TypeVector& precisions,
}

bool pass::MarkDequantizationAndDecompression::run_on_model(const std::shared_ptr<ov::Model>& m) {
ov::pass::Manager manager("MarkDequantizationAndDecompressionManager");
const auto& pass_config = get_pass_config();
auto callback = pass_config->get_callback<MarkDequantizationAndDecompression>();
pass_config->set_callback<MarkDequantization>(callback);
pass_config->set_callback<KeepConstsPrecision>(callback);

ov::pass::Manager manager(pass_config, "MarkDequantizationAndDecompressionManager");
manager.register_pass<DisableDecompressionConvertConstantFolding>();
manager.register_pass<MarkDequantization>(m_precisions, m_fold_subtract_const, m_fold_multiply_const);
manager.register_pass<ConstantFolding>();
Expand Down

0 comments on commit 189153f

Please sign in to comment.