Commit: f16_mha_on_avx512_core_amx_f16_target
chenhu-wang committed Nov 12, 2024
1 parent 339a091 commit e1f3fc6
Showing 12 changed files with 100 additions and 16 deletions.
14 changes: 14 additions & 0 deletions src/common/low_precision_transformations/src/low_precision.cpp
@@ -248,7 +248,21 @@ bool ov::pass::low_precision::LowPrecision::run_on_model(const std::shared_ptr<o
ADD_MATCHER(common, ConvolutionTransformation, params)
ADD_MATCHER(common, ConvolutionBackpropDataTransformation, params)
ADD_MATCHER(common, DepthToSpaceTransformation, params)

// manager.run_passes(f);
// ov::pass::Manager mgr;
// std::string xml = "FakeQuantizeDecompositionTransformation_before.xml";
// std::string bin = "FakeQuantizeDecompositionTransformation_before.bin";
// mgr.register_pass<ov::pass::Serialize>(xml, bin);
// mgr.run_passes(f);
ADD_MATCHER(common, FakeQuantizeDecompositionTransformation, params)
// manager.run_passes(f);
// ov::pass::Manager mgr1;
// std::string xml1 = "FakeQuantizeDecompositionTransformation_after.xml";
// std::string bin1 = "FakeQuantizeDecompositionTransformation_after.bin";
// mgr1.register_pass<ov::pass::Serialize>(xml1, bin1);
// mgr1.run_passes(f);

ADD_MATCHER(common, FakeQuantizeTransformation, params)
ADD_MATCHER(common, InterpolateTransformation, params)
ADD_MATCHER(common, GroupConvolutionTransformation, params)
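The commented-out blocks above are leftover debugging hooks. They follow a common OpenVINO debugging pattern: run ov::pass::Serialize before and after a transformation to dump the model as IR and inspect the transformation's effect. A cleaned-up sketch of that pattern (the helper name and file paths are placeholders, not part of this commit):

#include <openvino/core/model.hpp>
#include <openvino/pass/manager.hpp>
#include <openvino/pass/serialize.hpp>

#include <memory>
#include <string>

// Dump a model as OpenVINO IR; Serialize is a no-op transformation that
// writes the given xml/bin pair as a side effect.
static void dump_model(const std::shared_ptr<ov::Model>& model, const std::string& tag) {
    ov::pass::Manager mgr;
    mgr.register_pass<ov::pass::Serialize>(tag + ".xml", tag + ".bin");
    mgr.run_passes(model);
}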
3 changes: 2 additions & 1 deletion src/common/snippets/src/op/brgemm.cpp
@@ -87,7 +87,8 @@ ov::element::Type Brgemm::get_output_type(const ov::element::Type& in_type0, con
const bool is_f32 = utils::everyone_is(element::f32, in_type0, in_type1);
const bool is_int8 = utils::one_of(in_type0, element::i8, element::u8) && in_type1 == element::i8;
const bool is_bf16 = utils::everyone_is(element::bf16, in_type0, in_type1);
- if (is_f32 || is_bf16) {
+ const bool is_f16 = utils::everyone_is(element::f16, in_type0, in_type1);
+ if (is_f32 || is_bf16 || is_f16) {
return element::f32;
} else if (is_int8) {
return element::i32;
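For reference, a minimal standalone sketch of the output-type dispatch this hunk implements; the simplified ElemType enum stands in for ov::element::Type, and the boolean checks mirror utils::everyone_is/one_of:

#include <stdexcept>

enum class ElemType { f32, bf16, f16, i8, u8, i32 };

// f32, bf16 and f16 GEMMs all accumulate into f32; int8 accumulates into i32.
ElemType brgemm_output_type(ElemType in0, ElemType in1) {
    const bool is_f32  = (in0 == ElemType::f32  && in1 == ElemType::f32);
    const bool is_bf16 = (in0 == ElemType::bf16 && in1 == ElemType::bf16);
    const bool is_f16  = (in0 == ElemType::f16  && in1 == ElemType::f16);
    const bool is_int8 = (in0 == ElemType::i8 || in0 == ElemType::u8) && in1 == ElemType::i8;
    if (is_f32 || is_bf16 || is_f16)
        return ElemType::f32;
    if (is_int8)
        return ElemType::i32;
    throw std::invalid_argument("unsupported Brgemm input type combination");
}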
@@ -21,7 +21,7 @@ class jit_brgemm_copy_b_emitter : public jit_emitter {

size_t get_inputs_num() const override {return 1;}
static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr) {
- return {{element::i8}, {element::bf16}, {element::f32}};
+ return {{element::i8}, {element::bf16}, {element::f16}, {element::f32}};
}

private:
@@ -67,6 +67,11 @@ std::set<std::vector<element::Type>> jit_brgemm_emitter::get_supported_precision
return {{element::i8, element::i8, element::u8},
{element::u8, element::i8, element::u8},
{element::bf16, element::bf16, element::u8}};
} else if (brgemm->get_type() == BRGEMM_TYPE::WITH_AMX_F16) {
return {{element::i8, element::i8, element::u8},
{element::u8, element::i8, element::u8},
{element::bf16, element::bf16, element::u8},
{element::f16, element::f16, element::u8}};
}
OV_CPU_JIT_EMITTER_THROW("got BrgemmCPU node with unsupported type");
}
19 changes: 16 additions & 3 deletions src/plugins/intel_cpu/src/nodes/subgraph.cpp
@@ -340,6 +340,13 @@ Subgraph::Subgraph(const std::shared_ptr<ov::Node>& op, const GraphContext::CPtr
#endif
const auto& tmp_snippet = ov::as_type_ptr<snippets::op::Subgraph>(op);
OPENVINO_ASSERT(tmp_snippet, "Attempt to create Subgraph node from an invalid op type");
// if (tmp_snippet->get_friendly_name() == "DequantizeLinear_172_original") {
// ov::pass::Manager mgr;
// std::string xml = "DequantizeLinear_172_original.xml";
// std::string bin = "DequantizeLinear_172_original.bin";
// mgr.register_pass<ov::pass::Serialize>(xml, bin);
// mgr.run_passes(tmp_snippet->body_ptr());
// }
subgraph_attrs->snippet = tmp_snippet->clone();
subgraph_attrs->bodyHash = getBodyHash(tmp_snippet);

@@ -449,11 +456,16 @@ void Subgraph::initSupportedPrimitiveDescriptors() {
config.inConfs.resize(inputShapes.size());
for (size_t i = 0; i < inputShapes.size(); i++) {
const auto originalInputPrecision = getOriginalInputPrecisionAtPort(i);
- const auto precision = ((originalInputPrecision == ov::element::f32) &&
+ auto precision = ((originalInputPrecision == ov::element::f32) &&
context->getConfig().inferencePrecision == ov::element::bf16 &&
subgraph_attrs->snippet->has_domain_sensitive_ops()) ?
static_cast<ov::element::Type>(ov::element::bf16) :
originalInputPrecision;
precision = ((originalInputPrecision == ov::element::f32) &&
context->getConfig().inferencePrecision == ov::element::f16 &&
subgraph_attrs->snippet->has_domain_sensitive_ops()) ?
static_cast<ov::element::Type>(ov::element::f16) :
precision;
if (supportedPrecisions.count(precision) == 0)
OPENVINO_THROW("Subgraph node with name `", getName(), "` doesn't support ", precision, " precision.");

@@ -638,13 +650,14 @@ Subgraph::DataFlowPasses Subgraph::getDataFlowPasses() {
SNIPPETS_REGISTER_PASS_ABSOLUTE_COMMON(Place::PipelineStart, ConvertToSwishCPU);
SNIPPETS_REGISTER_PASS_RELATIVE_COMMON(Place::After, ov::snippets::pass::Canonicalization,
ov::snippets::pass::AnalyzeBroadcastableInputs, broadcastable_inputs);
- if (context->getConfig().inferencePrecision == ov::element::bf16 && subgraph_attrs->snippet->has_domain_sensitive_ops()) {
+ if ((context->getConfig().inferencePrecision == ov::element::bf16 || context->getConfig().inferencePrecision == ov::element::f16)
+     && subgraph_attrs->snippet->has_domain_sensitive_ops()) {
// enforce BF16/FP16 precision on supported operations
// MatMul has to be decomposed into Brgemm operations before the enforcement
// Note: MatMul decomposition will run again later in case the enforcement did not happen
SNIPPETS_REGISTER_PASS_ABSOLUTE_X86_64(Place::PipelineStart, ov::snippets::pass::MatMulToBrgemm);
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::After, ov::snippets::pass::MatMulToBrgemm,
- pass::EnforcePrecision, element::f32, element::bf16);
+ pass::EnforcePrecision, element::f32, context->getConfig().inferencePrecision);
}
SNIPPETS_REGISTER_PASS_RELATIVE_X86_64(Place::Before, ov::snippets::pass::PropagatePrecision,
ov::intel_cpu::pass::BrgemmToBrgemmCPU);
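Taken together, the two subgraph.cpp hunks reduce to one rule: f32 inputs of domain-sensitive subgraphs are downcast to whichever low precision the config requests. A hedged condensation (the Precision enum and function are illustrative stand-ins, not plugin API):

#include <optional>

enum class Precision { f32, bf16, f16 };

// Returns the precision to enforce, or nullopt to keep the original one.
std::optional<Precision> enforced_precision(Precision original,
                                            Precision inference,
                                            bool has_domain_sensitive_ops) {
    // Only f32 inputs of domain-sensitive subgraphs are downcast, and only
    // when the inference precision asks for bf16 or f16.
    if (original == Precision::f32 && has_domain_sensitive_ops &&
        (inference == Precision::bf16 || inference == Precision::f16))
        return inference;
    return std::nullopt;
}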
@@ -91,7 +91,7 @@ void BrgemmCopyB::validate_and_infer_types() {
}

void BrgemmCopyB::validate_element_type(const ov::element::Type& element_type) {
- OPENVINO_ASSERT(one_of(element_type, element::f32, element::bf16, element::i8),
+ OPENVINO_ASSERT(one_of(element_type, element::f32, element::bf16, element::f16, element::i8),
"BrgemmCopyB doesn't support element type" + element_type.get_type_name());
}

@@ -131,7 +131,8 @@ std::shared_ptr<Node> BrgemmCPU::clone_with_new_inputs(const OutputVector& new_a
}

std::shared_ptr<BrgemmCopyB> BrgemmCPU::get_brgemm_copy() const {
- OPENVINO_ASSERT(one_of(m_type, BRGEMM_TYPE::REPACKING_ONLY, BRGEMM_TYPE::WITH_COMPENSATIONS, BRGEMM_TYPE::WITH_AMX), "Brgemm doesn't need BrgemmCopyB");
+ OPENVINO_ASSERT(one_of(m_type, BRGEMM_TYPE::REPACKING_ONLY, BRGEMM_TYPE::WITH_COMPENSATIONS, BRGEMM_TYPE::WITH_AMX,
+                 BRGEMM_TYPE::WITH_AMX_F16), "Brgemm doesn't need BrgemmCopyB");
auto b_input_node = get_input_node_shared_ptr(1);
if (const auto brgemm_copy_b = ov::as_type_ptr<BrgemmCopyB>(b_input_node)) {
return brgemm_copy_b;
@@ -27,7 +27,10 @@ cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx) {

// Note: AMX might not be used even if it's supported by the hardware; check the BrgemmToBrgemmCPU pass for details
if (is_with_amx) {
- SUPPORT_ONE(avx512_core_amx, "Unsupported hardware configuration: amx is supported only on avx512 platforms")
+ if (dt_in0 == ov::element::f16)
+     SUPPORT_ONE(avx512_core_amx_fp16, "Unsupported hardware configuration: amx fp16 is supported only on avx512_core_amx_fp16 platforms")
+ else
+     SUPPORT_ONE(avx512_core_amx, "Unsupported hardware configuration: amx is supported only on avx512 platforms")
} else if (dt_in0 == ov::element::bf16) {
SUPPORT_ONE(avx512_core_bf16, "Unsupported hardware configuration: bf16 is supported only on avx512 platforms")
} else if (one_of(dt_in0, ov::element::u8, ov::element::i8)) {
@@ -53,6 +56,10 @@ BRGEMM_TYPE get_brgemm_type(const ov::element::Type& element_type_a, const Dimen
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx) &&
K_dim.is_static() && K_dim.get_length() % brgemmVNNIFactor == 0)
return BRGEMM_TYPE::WITH_AMX;
if (element_type_a == ov::element::f16 &&
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) &&
K_dim.is_static() && K_dim.get_length() % brgemmVNNIFactor == 0)
return BRGEMM_TYPE::WITH_AMX_F16;
// Note: this condition reproduces logic from the OneDNN Brgemm implementation. This is needed to align with the
// backend requirements. More details in onednn/src/cpu/x64/brgemm/brgemm_utils.cpp
if (element_type_a == ov::element::i8)
@@ -79,6 +86,7 @@ size_t compute_inner_n_block(const ov::element::Type& precision) {
switch (precision) {
case element::i8: return 64;
case element::bf16: return 32;
case element::f16: return 32;
case element::f32: return 16;
default: OPENVINO_THROW("BrgemmCopyB doesn't support precision ", precision);
}
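A reduced sketch of the float paths of get_brgemm_type() after this change, assuming brgemmVNNIFactor == 2 for bf16/f16 (two elements per VNNI lane); the booleans stand in for the dnnl::impl::cpu::x64::mayiuse() queries:

enum class BrgemmKind { with_amx, with_amx_f16, other };

BrgemmKind pick_brgemm_kind(bool is_bf16, bool is_f16,
                            bool has_amx, bool has_amx_fp16,
                            bool k_static, long K, long vnni_factor = 2) {
    const bool k_ok = k_static && (K % vnni_factor == 0);
    if (is_bf16 && has_amx && k_ok)
        return BrgemmKind::with_amx;      // bf16|bf16 on an AMX system
    if (is_f16 && has_amx_fp16 && k_ok)
        return BrgemmKind::with_amx_f16;  // f16|f16 on an AMX fp16 system
    return BrgemmKind::other;             // compensations/repacking/stand-alone paths
}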
@@ -18,7 +18,8 @@ enum class BRGEMM_TYPE {
STAND_ALONE, // No extra requirements, used for f32|f32
WITH_AMX, // i8|i8 or bf16|bf16 on AMX system - needs BrgemmCopyB and scratchpad
WITH_COMPENSATIONS, // i8|i8 (non-AMX system) - needs BrgemmCopyB for data repacking and compensations
- REPACKING_ONLY      // u8|i8 or bf16|bf16 (non-AMX system) - needs BrgemmCopyB on second input for data repacking
+ REPACKING_ONLY,     // u8|i8 or bf16|bf16 (non-AMX system) - needs BrgemmCopyB on second input for data repacking
+ WITH_AMX_F16        // f16|f16 on AMX system - needs BrgemmCopyB and scratchpad
};

dnnl::impl::cpu::x64::cpu_isa_t get_primitive_isa(const ov::element::Type& dt_in0, bool is_with_amx);
@@ -27,7 +28,7 @@ BRGEMM_TYPE get_brgemm_type(const element::Type& element_type_a, const Dimension

inline bool stand_alone(BRGEMM_TYPE type) { return type == BRGEMM_TYPE::STAND_ALONE; }

- inline bool with_amx(BRGEMM_TYPE type) { return type == BRGEMM_TYPE::WITH_AMX; }
+ inline bool with_amx(BRGEMM_TYPE type) { return type == BRGEMM_TYPE::WITH_AMX || type == BRGEMM_TYPE::WITH_AMX_F16; }

inline bool with_compensations(BRGEMM_TYPE type) { return type == BRGEMM_TYPE::WITH_COMPENSATIONS; }

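A self-contained check of the updated predicate (the enum is copied from the header above): both AMX flavors now count as "with AMX", so scratchpad and BrgemmCopyB handling downstream is shared between bf16 and f16:

#include <cassert>

enum class BRGEMM_TYPE { STAND_ALONE, WITH_AMX, WITH_COMPENSATIONS, REPACKING_ONLY, WITH_AMX_F16 };

inline bool with_amx(BRGEMM_TYPE type) {
    return type == BRGEMM_TYPE::WITH_AMX || type == BRGEMM_TYPE::WITH_AMX_F16;
}

int main() {
    assert(with_amx(BRGEMM_TYPE::WITH_AMX));       // unchanged
    assert(with_amx(BRGEMM_TYPE::WITH_AMX_F16));   // new f16 flavor included
    assert(!with_amx(BRGEMM_TYPE::REPACKING_ONLY));
    return 0;
}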
@@ -67,8 +67,10 @@ pass::BrgemmToBrgemmCPU::BrgemmToBrgemmCPU() {
const auto& layout_c = brgemm_out_desc->get_layout();

const auto element_type_a = brgemm->get_input_element_type(0);
std::cout << "element_type_a is:" << element_type_a << std::endl;
const bool transpose_b = !layout_b.empty() && layout_b.back() != layout_b.size() - 1;
const auto brgemm_type = brgemm_utils::get_brgemm_type(element_type_a, K, transpose_b);
std::cout << "brgemm_type is:" << static_cast<int>(brgemm_type) << std::endl;
const auto offset_a = brgemm->get_offset_a();
const auto offset_b = brgemm->get_offset_b();
const auto offset_c = brgemm->get_offset_c();
@@ -817,8 +817,25 @@ void Transformations::PostLpt() {
return node->get_rt_info().count("UNROLL_TI") == 0;
},
ov::pass::UnrollTensorIterator);

// postLPTPassManager.run_passes(model);
// ov::pass::Manager mgr1;
// std::string xml1 = "MoveEltwiseUpThroughDataMov_b.xml";
// std::string bin1 = "MoveEltwiseUpThroughDataMov_b.bin";
// mgr1.register_pass<ov::pass::Serialize>(xml1, bin1);
// mgr1.run_passes(model);

CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::MoveEltwiseUpThroughDataMov);

// postLPTPassManager.run_passes(model);
// ov::pass::Manager mgr;
// std::string xml = "MoveEltwiseUpThroughDataMov_a.xml";
// std::string bin = "MoveEltwiseUpThroughDataMov_a.bin";
// mgr.register_pass<ov::pass::Serialize>(xml, bin);
// mgr.run_passes(model);

CPU_DISABLE_PASS_COMMON(postLPTPassManager, ov::pass::MoveEltwiseUpThroughDataMovPerChannel);

CPU_SET_CALLBACK_COMMON(postLPTPassManager,
[](const std::shared_ptr<const ov::Node>& node) -> bool {
if (!ov::is_type<const ov::op::v0::FakeQuantize>(node) && node->get_output_element_type(0) != node->get_input_element_type(0))
@@ -831,6 +848,7 @@ void Transformations::PostLpt() {
return false;
},
ov::pass::MoveEltwiseUpThroughDataMovScalar);

CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::Validate);

CPU_REGISTER_PASS_COMMON(postLPTPassManager, ov::pass::ConstantFolding);
@@ -959,6 +977,7 @@ void Transformations::MainSnippets(void) {

ov::pass::Manager snippetsManager("CPU:Snippets");
snippetsManager.set_per_pass_validation(false);
// If a callback is needed for better performance, enable SnippetsMarkSkipped and disable TokenizeFCSnippets.
if (!ignoreCallback) {
#if defined(OPENVINO_ARCH_ARM64)
CPU_REGISTER_PASS_ARM(snippetsManager, SnippetsMarkSkipped);
@@ -979,7 +998,9 @@
(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2) &&
one_of(config.inferencePrecision, ov::element::f32, element::undefined)) ||
(dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) &&
- one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined));
+ one_of(config.inferencePrecision, ov::element::bf16, ov::element::f32, element::undefined)) ||
+ (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16) &&
+ one_of(config.inferencePrecision, ov::element::f16));
#endif
if (!isMHASupported) {
CPU_DISABLE_PASS_COMMON(snippetsManager, snippets::pass::TokenizeMHASnippets);
@@ -995,12 +1016,11 @@
const auto in_type1 = matmul->get_input_element_type(1);
const auto is_fp32 = (in_type0 == ov::element::f32 && in_type1 == ov::element::f32 &&
one_of(config.inferencePrecision, element::f32, element::undefined));
- const auto is_fp16 = (in_type0 == ov::element::f16 || in_type1 == ov::element::f16);
+ const auto is_fp16 = (in_type0 == ov::element::f16 || in_type1 == ov::element::f16) ||
+                      ((in_type0 == element::f32 && in_type1 == ov::element::f32 && config.inferencePrecision == ov::element::f16));
const auto is_bf16 = (in_type0 == ov::element::bf16 && in_type1 == ov::element::bf16) ||
((in_type0 == element::f32 && in_type1 == ov::element::f32 && config.inferencePrecision == ov::element::bf16));
const auto is_int8 = in_type0 == ov::element::i8;
- if (is_fp16)
-     return false;
if (is_fp32)
return true;
// Only FP32 dynamic MHA is supported
@@ -1015,14 +1035,16 @@
if (dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx)) {
const auto& b_shape = matmul->get_input_partial_shape(1);
const auto K = matmul->get_transpose_b() ? *b_shape.rbegin() : *++b_shape.rbegin();
- if (is_bf16) return K.is_static() && (K.get_length() % 2 == 0);
+ if (is_bf16 || is_fp16) return K.is_static() && (K.get_length() % 2 == 0);
if (is_int8) return K.is_static();
}
if (is_int8)
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_vnni) ||
dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx2_vnni);
if (is_bf16)
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_bf16);
if (is_fp16)
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core_amx_fp16);
return true;
};
auto is_unsupported_parallel_work_amount = [&](const std::shared_ptr<const ov::Node>& n, const ov::PartialShape& shape) {
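A rough condensation of the fp16 decisions in the tokenization callback above; the booleans stand in for the dnnl::impl::cpu::x64::mayiuse() queries on avx512_core_amx and avx512_core_amx_fp16, and K is the (transpose-aware) reduction dimension of the MatMul:

// fp16 MHA is tokenized on AMX systems only when K is static and divisible
// by the VNNI factor (2), mirroring the bf16 rule; otherwise it falls back
// to the avx512_core_amx_fp16 capability check.
bool fp16_mha_tokenizable(bool has_amx, bool has_amx_fp16, bool k_static, long K) {
    if (has_amx)
        return k_static && (K % 2 == 0);
    return has_amx_fp16;
}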
@@ -189,7 +189,7 @@ class MHATest : public testing::WithParamInterface<MHATuple>, virtual public Sub
for (size_t i = 0; i < funcInputs.size(); ++i) {
const auto& funcInput = funcInputs[i];
ov::Tensor tensor;
- if (funcInput.get_element_type() == ov::element::bf16) {
+ if (funcInput.get_element_type() == ov::element::bf16 || funcInput.get_element_type() == ov::element::f16) {
ov::test::utils::InputGenerateData in_data;
in_data.start_from = -1;
in_data.range = 2;
@@ -232,6 +235,9 @@ class MHATest : public testing::WithParamInterface<MHATuple>, virtual public Sub
configuration.insert({ov::hint::inference_precision(ov::element::bf16)});
}

if (inputPrecisions[0] == ElementType::f16)
configuration.insert({ov::hint::inference_precision(ov::element::f16)});

// Snippets MHA tokenization has limitations to avoid performance degradations. These limitations depend on the
// target machine. Just for testing, we disable these limitations to allow Snippets to tokenize the pattern on all
// machines for validation.
@@ -308,6 +311,20 @@ INSTANTIATE_TEST_SUITE_P(
::testing::Values(ov::test::utils::DEVICE_CPU)),
MHATest::getTestCaseName);

INSTANTIATE_TEST_SUITE_P(
smoke_MHA_FP16,
MHATest,
::testing::Combine(
::testing::ValuesIn(static_shapes_to_test_representation(inputShapes)),
::testing::Values(
std::vector<ElementType>{ElementType::f16, ElementType::f16, ElementType::f16, ElementType::f16}),
::testing::ValuesIn(matMulIn0Precisions),
::testing::ValuesIn(patternTypes),
::testing::Values(ExpectedNodes{{"Subgraph", 1},
{"Transpose", 1}}), // Plugin disables tokenization of Transpose on output
::testing::Values(ov::test::utils::DEVICE_CPU)),
MHATest::getTestCaseName);
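
For context, the f16 path exercised by this suite is reached from user code through the inference-precision hint, just as the test sets it via configuration. A minimal usage sketch (the model path is a placeholder):

#include <openvino/openvino.hpp>

int main() {
    ov::Core core;
    auto model = core.read_model("model.xml");
    // Request f16 inference precision; on avx512_core_amx_fp16 hardware the
    // CPU plugin can then tokenize MHA into an f16 AMX Brgemm subgraph.
    auto compiled = core.compile_model(model, "CPU",
                                       ov::hint::inference_precision(ov::element::f16));
    return 0;
}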

} // namespace

static std::shared_ptr<ov::Model> initMHAQuantSubgraph0(std::vector<ov::PartialShape>& inputDynamicShapes,
