Skip to content

Commit

Permalink
[GPU] Add callback function for MarkDequantizationSubgraph pass to sk…
Browse files Browse the repository at this point in the history
…ip non FC-related subgraphs
  • Loading branch information
sshlyapn committed Dec 8, 2023
1 parent a84d615 commit 2a756be
Showing 1 changed file with 31 additions and 0 deletions.
31 changes: 31 additions & 0 deletions src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,35 @@ static bool disable_reduce_decomposition(const std::shared_ptr<const ov::Node> n
}
return false;
}

static bool is_non_decompression_multiply(const std::shared_ptr<const ov::Node> node) {
auto get_single_consumer = [](const std::shared_ptr<const ov::Node> node) -> std::shared_ptr<ov::Node> {
const auto consumers = node->get_output_target_inputs(0);
if (consumers.size() != 1)
return nullptr;
return consumers.begin()->get_node()->shared_from_this();
};

auto consumer = get_single_consumer(node);
if (!consumer)
return true;

if (ov::is_type<ov::opset1::MatMul>(consumer)) {
return false;
} else if (ov::is_type<ov::opset1::Reshape>(consumer)) {
consumer = get_single_consumer(consumer);
if (consumer != nullptr && ov::is_type<ov::opset1::MatMul>(consumer)) {
return false;
}
}
if (consumer != nullptr && ov::is_type<ov::opset1::Convert>(consumer)) {
consumer = get_single_consumer(consumer);
if (consumer != nullptr && ov::is_type<ov::opset1::MatMul>(consumer)) {
return false;
}
}
return true;
}
} // namespace

namespace ov {
Expand Down Expand Up @@ -247,6 +276,8 @@ void TransformationsPipeline::apply(std::shared_ptr<ov::Model> func) {
});

manager.register_pass<ov::pass::MarkDequantizationSubgraph>(ov::element::TypeVector{ov::element::u8, ov::element::u4, ov::element::i4}, true);
// Ignore nodes that are not related to FullyConnected and allow ConstantFolding to be applied to them
pass_config->set_callback<ov::pass::MarkDequantizationSubgraph>(is_non_decompression_multiply);

const bool keep_precision_sensitive_in_fp32_1 = true;
const bool convert_input_output_precision = false;
Expand Down

0 comments on commit 2a756be

Please sign in to comment.