Ensure original hw limitations are applied for compression
EgorDuplensky committed Nov 25, 2024
1 parent 723877b commit d2e0493
Showing 4 changed files with 63 additions and 21 deletions.
@@ -5,6 +5,7 @@
 #pragma once
 
 #include "openvino/pass/matcher_pass.hpp"
+#include "ov_ops/fully_connected.hpp"
 #include "transformations_visibility.hpp"
 
 namespace ov {
@@ -17,9 +18,12 @@ class TRANSFORMATIONS_API ConvertFullyConnectedToFullyConnectedCompressed;
 
 class ov::pass::ConvertFullyConnectedToFullyConnectedCompressed : public ov::pass::MatcherPass {
 public:
+    using SupportsPredicate =
+        std::function<bool(const std::shared_ptr<ov::op::internal::FullyConnected>&, size_t, size_t, size_t)>;
+
     OPENVINO_RTTI("ConvertFullyConnectedToFullyConnectedCompressed", "0");
-    ConvertFullyConnectedToFullyConnectedCompressed(
-        const std::vector<ov::element::Type>& supported_compression_types,
-        std::function<bool(size_t, size_t, size_t)> supports_config = nullptr,
-        bool convert_u4zp_to_u8 = false);
+    ConvertFullyConnectedToFullyConnectedCompressed(const std::vector<ov::element::Type>& supported_activation_types,
+                                                    const std::vector<ov::element::Type>& supported_weights_types,
+                                                    SupportsPredicate supports_config = nullptr,
+                                                    bool convert_u4zp_to_u8 = false);
 };
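The constructor now takes separate activation and weights type lists, and the predicate receives the matched FullyConnected node along with the IC/OC/group sizes. A minimal registration sketch against the new signature — the function name, manager setup, and type lists are illustrative assumptions, not plugin defaults, and the pass header from this commit is assumed included:

#include <vector>
#include "openvino/pass/manager.hpp"

void register_fc_compression(ov::pass::Manager& manager) {
    // Activation precisions the pattern may match (this commit starts with f32 only).
    std::vector<ov::element::Type> activation_types{ov::element::f32};
    // Weight precisions eligible for compressed execution.
    std::vector<ov::element::Type> weights_types{ov::element::u8, ov::element::i8};

    // register_pass forwards its arguments to the pass constructor;
    // supports_config and convert_u4zp_to_u8 keep their defaults here.
    manager.register_pass<ov::pass::ConvertFullyConnectedToFullyConnectedCompressed>(
        activation_types,
        weights_types);
}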
@@ -24,8 +24,9 @@
 #include "transformations/utils/utils.hpp"
 
 ov::pass::ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyConnectedCompressed(
-    const std::vector<ov::element::Type>& supported_compression_types,
-    std::function<bool(size_t, size_t, size_t)> supports_config,
+    const std::vector<ov::element::Type>& supported_activation_types,
+    const std::vector<ov::element::Type>& supported_weights_types,
+    SupportsPredicate supports_config,
     bool convert_u4zp_to_u8) {
     using namespace ov::pass::pattern;
 
@@ -35,7 +36,8 @@ ov::pass::ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyConnectedCompressed(
         return in_ps.rank().is_static() && out_ps.rank().is_static() && in_ps.size() == 3 && out_ps.size() == 2;
     };
 
-    auto weights_m = wrap_type<ov::op::v0::Constant>(ov::pass::pattern::type_matches_any(supported_compression_types));
+    auto activation_m = any_input(ov::pass::pattern::type_matches_any(supported_activation_types));
+    auto weights_m = wrap_type<ov::op::v0::Constant>(ov::pass::pattern::type_matches_any(supported_weights_types));
     auto convert_m = wrap_type<ov::op::v0::Convert>({weights_m});
 
     auto sub_const_m = wrap_type<ov::op::v0::Constant>();
@@ -59,10 +61,9 @@ ov::pass::ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyConnectedCompressed(
     auto transpose_const_m = wrap_type<ov::op::v0::Constant>();
     auto transpose_m = wrap_type<ov::op::v1::Transpose>({transpose_input, transpose_const_m});
 
-    auto data_m = any_input();
     auto bias_m = any_input();
     auto weights_input_m = std::make_shared<ov::pass::pattern::op::Or>(ov::OutputVector{reshape_m, transpose_m, mul_m});
-    auto fully_connected_m = wrap_type<ov::op::internal::FullyConnected>({data_m, weights_input_m, bias_m});
+    auto fully_connected_m = wrap_type<ov::op::internal::FullyConnected>({activation_m, weights_input_m, bias_m});
 
     ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
         const auto& pattern_map = m.get_pattern_value_map();
@@ -89,7 +90,7 @@ ov::pass::ConvertFullyConnectedToFullyConnectedCompressed::ConvertFullyConnectedToFullyConnectedCompressed(
 
         const size_t G = grouped ? (has_transpose ? *(scale_shape.rbegin() + 2) : *(scale_shape.rbegin() + 1)) : 1;
 
-        if (supports_config && !supports_config(IC, OC, G))
+        if (supports_config && !supports_config(fc, IC, OC, G))
             return false;
 
         auto reshape_const_to_2d = [has_transpose, grouped](std::shared_ptr<ov::Node> node) {
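Threading the matched node into supports_config is what the callback change above enables: a backend predicate can now inspect the FullyConnected itself, not just its dimensions. A hypothetical node-aware predicate (not part of this commit), echoing the bf16 @todo in the CPU plugin hunk below:

#include <cstddef>
#include <memory>

// Hypothetical predicate: reject the compressed path for bf16 activations,
// otherwise apply plain shape rules. Assumes the FullyConnected op header
// from this commit is included.
bool supports_config(const std::shared_ptr<ov::op::internal::FullyConnected>& fc,
                     size_t IC, size_t OC, size_t G) {
    if (fc->get_input_element_type(0) == ov::element::bf16)
        return false;  // assumption: compression not validated for bf16 here
    return IC % G == 0 && IC / G >= 4 && OC != 1;
}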
@@ -20,13 +20,18 @@
 #include "common/pass/rnn_sequences_optimization.hpp"
 #include "transformations/common_optimizations/reshape_sequence_fusion.hpp"
 #include "transformations/defs.hpp"
+#include "config.h"
 
+#if defined(OPENVINO_ARCH_X86_64)
+#include "cpu/x64/cpu_isa_traits.hpp"
+#endif
+
 #include "itt.hpp"
 
 namespace ov {
 namespace intel_cpu {
 
-inline void ConvertToCPUSpecificOpset(std::shared_ptr<ov::Model> &model) {
+inline void ConvertToCPUSpecificOpset(std::shared_ptr<ov::Model> &model, const Config& config) {
     RUN_ON_FUNCTION_SCOPE(ConvertToCPUSpecificOpset);
 
     ov::pass::Manager manager("CPU:ConvertToCPUSpecificOpset");
@@ -42,7 +47,13 @@ inline void ConvertToCPUSpecificOpset(std::shared_ptr<ov::Model> &model, const Config& config) {
 
     CPU_REGISTER_PASS_COMMON(manager, FullyConnectedBiasFusion);
 
-    std::vector<ov::element::Type> supported_compression_types {
+    std::vector<ov::element::Type> supported_activation_types {
+        // @todo enable for bf16 as well
+        // after EnforceInferencePrecision is replaced with ConvertPrecision
+        ov::element::f32,
+    };
+
+    std::vector<ov::element::Type> supported_compressed_weights_types {
         ov::element::u8,
         ov::element::i8,
         ov::element::u4,
@@ -51,14 +62,40 @@ inline void ConvertToCPUSpecificOpset(std::shared_ptr<ov::Model> &model, const Config& config) {
         ov::element::f4e2m1,
     };
 
-    CPU_REGISTER_PASS_X64(manager, pass::ConvertFullyConnectedToFullyConnectedCompressed,
-                          supported_compression_types,
-                          [](size_t IC, size_t OC, size_t G) {
-                              if (IC % G != 0 || IC / G < 4 || OC == 1) {
-                                  return false;
-                              }
-                              return true;
-                          });
+#if defined(OPENVINO_ARCH_X86_64)
+    // @todo introduce something like CPU_REGISTER_PASS_X64_AVX2
+    const bool isDecompressionSupported = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2);
+    if (isDecompressionSupported) {
+        CPU_REGISTER_PASS_X64(
+            manager,
+            pass::ConvertFullyConnectedToFullyConnectedCompressed,
+            supported_activation_types,
+            supported_compressed_weights_types,
+            [&config](const std::shared_ptr<ov::op::internal::FullyConnected>& fc, size_t IC, size_t OC, size_t G) {
+                // @todo replace the 'inferencePrecision' check with 'fc->get_input_element_type(0) == ov::element::bf16'
+                // after the bf16 pipeline is moved to ConvertPrecision
+                if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx) &&
+                    config.inferencePrecision == ov::element::bf16) {
+                    // The OneDNN AMX IP implementation has limited shape support due to performance considerations.
+                    // As a current solution, the conditions below are copied from OneDNN to make sure the correct IP
+                    // implementation is used, since the fallback one doesn't support the weights decompression feature.
+                    size_t simdWidth = 16;
+                    size_t vnniFactor = 2;
+                    size_t maxSize = 512;
+                    auto amxRow = vnniFactor * simdWidth;
+
+                    if ((IC <= amxRow && OC <= amxRow) || (IC <= maxSize && OC <= maxSize && IC % amxRow != 0)) {
+                        return false;
+                    }
+                }
+
+                if (IC % G != 0 || IC / G < 4 || OC == 1) {
+                    return false;
+                }
+                return true;
+            });
+    }
+#endif  // OPENVINO_ARCH_X86_64
 
     CPU_REGISTER_PASS_X64(manager, pass::ConvertFCToFCQuantizedLegacy);
     if (std::getenv("EXTRA_DUMP")) {
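The AMX shape gate above is easiest to sanity-check with concrete numbers. The sketch below restates the copied OneDNN conditions as a standalone helper — the function name and the standalone form are illustrative only:

#include <cassert>
#include <cstddef>

// Restatement of the AMX shape gate from the callback above. Returns true
// when the compressed FC conversion should be skipped on an AMX bf16 pipeline.
bool amx_shape_unsupported(size_t IC, size_t OC) {
    const size_t simdWidth = 16;
    const size_t vnniFactor = 2;  // bf16 packs two values per 32-bit dword
    const size_t maxSize = 512;
    const size_t amxRow = vnniFactor * simdWidth;  // 32

    return (IC <= amxRow && OC <= amxRow) ||
           (IC <= maxSize && OC <= maxSize && IC % amxRow != 0);
}

int main() {
    assert(amx_shape_unsupported(32, 32));      // tiny FC: both dims fit in one AMX row
    assert(amx_shape_unsupported(100, 256));    // small shape, IC not a multiple of 32
    assert(!amx_shape_unsupported(64, 256));    // IC aligned to 32: conversion allowed
    assert(!amx_shape_unsupported(1024, 1024)); // large shapes: conversion allowed
}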
@@ -320,7 +320,7 @@ void Transformations::UpToLpt() {
 void Transformations::CpuSpecificOpSet(void) {
     CPU_DEBUG_CAP_TRANSFORMATION_SCOPE(this, Specific);
 
-    ConvertToCPUSpecificOpset(model);
+    ConvertToCPUSpecificOpset(model, config);
 }
 
 void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecisions) {
