Skip to content

Commit

Permalink
[Snippets][CPU] Set N_blk=24 for BRGEMM on AVX2 (openvinotoolkit#26319)
Browse files Browse the repository at this point in the history
### Details:
 - *Set `n_blk=24` on avx2 for `BrgemmCPU`*

### Tickets:
 - *151064*
  • Loading branch information
a-sidorova authored Sep 4, 2024
1 parent 9c41f10 commit a0fe89a
Show file tree
Hide file tree
Showing 7 changed files with 114 additions and 63 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,14 @@ class BrgemmBlockingBase {
* @param brgemm_expr Brgemm expression
* @return tuple in format (m_block, n_block, k_block)
*/
virtual std::tuple<size_t, size_t, size_t> get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr);
virtual std::tuple<size_t, size_t, size_t> get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) const;
/**
* @interface get_brgemm_dimensions
* @brief Extract current dimensions M,N,K of `brgemm_expr`
* @param brgemm_expr Brgemm expression
* @return tuple in format (M, N, K)
*/
static std::tuple<size_t, size_t, size_t> get_brgemm_dimensions(const ov::snippets::lowered::ExpressionPtr& brgemm_expr);
/**
* @interface mark_blocking_loops
* @brief Covers brgemm with blocking loops. Also should calculate optimal blocking parameters inside.
Expand Down Expand Up @@ -72,6 +79,10 @@ class BrgemmBlockingBase {
virtual SpecificIterationHandlers get_m_loop_handlers(size_t work_amount, size_t block_size) const;
virtual SpecificIterationHandlers get_n_loop_handlers(size_t work_amount, size_t block_size) const;
virtual SpecificIterationHandlers get_k_loop_handlers(size_t work_amount, size_t block_size) const;

virtual size_t get_default_m_blk(size_t m) const;
virtual size_t get_default_n_blk(size_t n) const;
virtual size_t get_default_k_blk(size_t k) const;
};

/**
Expand Down
47 changes: 27 additions & 20 deletions src/common/snippets/src/lowered/pass/brgemm_blocking.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,32 @@ SpecificIterationHandlers BrgemmBlockingBase::get_k_loop_handlers(size_t work_am
return get_default_blocking_loop_handlers(work_amount, block_size);
}

std::tuple<size_t, size_t, size_t> BrgemmBlockingBase::get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) {
size_t BrgemmBlockingBase::get_default_m_blk(size_t m) const {
    // Default block size along the M dimension.
    // `m` is unused by the base heuristic but kept so derived passes can override per-shape.
    static constexpr size_t default_m_block = 32;
    return default_m_block;
}
size_t BrgemmBlockingBase::get_default_n_blk(size_t n) const {
    // Default block size along the N dimension.
    // `n` is unused by the base heuristic but kept so derived passes can override per-shape/ISA.
    static constexpr size_t default_n_block = 64;
    return default_n_block;
}
size_t BrgemmBlockingBase::get_default_k_blk(size_t k) const {
    // Default block size along the K dimension:
    // a statically-known K larger than 1024 gets the bigger 1024 block,
    // otherwise (dynamic or moderate K) fall back to 512.
    if (!utils::is_dynamic_value(k) && k > 1024)
        return 1024;
    return 512;
}

std::tuple<size_t, size_t, size_t> BrgemmBlockingBase::get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) const {
    size_t m, n, k;
    std::tie(m, n, k) = get_brgemm_dimensions(brgemm_expr);

    // Ticket: 113745
    // TODO: extend block size selection heuristics
    // If the dimension is statically known and already fits in one default block,
    // mark it as a "full" dimension so no blocking loop is created for it.
    auto select_blk = [](const size_t dim, const size_t default_blk) {
        const bool fits_in_single_block = !snippets::utils::is_dynamic_value(dim) && dim <= default_blk;
        return fits_in_single_block ? get_full_dim_value() : default_blk;
    };

    const size_t m_block = select_blk(m, get_default_m_blk(m));
    const size_t n_block = select_blk(n, get_default_n_blk(n));
    const size_t k_block = select_blk(k, get_default_k_blk(k));
    return std::make_tuple(m_block, n_block, k_block);
}

std::tuple<size_t, size_t, size_t> BrgemmBlockingBase::get_brgemm_dimensions(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) {
OPENVINO_ASSERT(brgemm_expr, "Brgemm expression is nullptr!");
const auto& in_0_desc = brgemm_expr->get_input_port_descriptor(0);
const auto& in_1_desc = brgemm_expr->get_input_port_descriptor(1);
const auto& out_desc = brgemm_expr->get_output_port_descriptor(0);
Expand All @@ -106,25 +131,7 @@ std::tuple<size_t, size_t, size_t> BrgemmBlockingBase::get_blocking_params(const
const auto& k1 = *++in_1_planar_dims.rbegin();
size_t k = 0;
OPENVINO_ASSERT(utils::merge_dynamic_dim(k, k0, k1), "Brgemm input descriptors have incompatible K dimension value.");

// Ticket: 113745
// TODO: extend block size selection heuristics
auto get_block_size_m = [](const size_t M) -> size_t {
if (!snippets::utils::is_dynamic_value(M) && M <= 32)
return get_full_dim_value();
return 32;
};
auto get_block_size_n = [](const size_t N) -> size_t {
if (!snippets::utils::is_dynamic_value(N) && N <= 64)
return get_full_dim_value();
return 64;
};
auto get_block_size_k = [](const size_t K) -> size_t {
if (ov::snippets::utils::is_dynamic_value(K))
return 512;
return K > 1024 ? 1024 : K > 512 ? 512 : get_full_dim_value();
};
return std::make_tuple(get_block_size_m(m), get_block_size_n(n), get_block_size_k(k));
return std::make_tuple(m, n, k);
}

bool BrgemmBlockingBase::mark_blocking_loops(snippets::lowered::LinearIR& linear_ir,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,14 +56,21 @@ LinearIR::constExprIt BrgemmCPUBlocking::get_loop_begin_pos(LinearIR& linear_ir,
return loop_begin_it;
}

std::tuple<size_t, size_t, size_t> BrgemmCPUBlocking::get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) {
auto blocking_params = ov::snippets::lowered::pass::BrgemmBlockingBase::get_blocking_params(brgemm_expr);
size_t BrgemmCPUBlocking::get_default_n_blk(size_t n) const {
return dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) ? 64 : 24;
}

std::tuple<size_t, size_t, size_t> BrgemmCPUBlocking::get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) const {
const auto brgemm = ov::as_type_ptr<ov::intel_cpu::BrgemmCPU>(brgemm_expr->get_node());
OPENVINO_ASSERT(brgemm, "BrgemmCPU is expected!");

size_t m_blk, n_blk, k_blk;
std::tie(m_blk, n_blk, k_blk) = BrgemmBlockingBase::get_blocking_params(brgemm_expr);
if (with_repacking(brgemm->get_type())) {
std::get<1>(blocking_params) = get_full_dim_value();
std::get<2>(blocking_params) = get_full_dim_value();
n_blk = get_full_dim_value();
k_blk = get_full_dim_value();
}
return blocking_params;
return std::make_tuple(m_blk, n_blk, k_blk);
}

SpecificIterationHandlers BrgemmCPUBlocking::get_k_loop_handlers(size_t work_amount, size_t block_size) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,12 +46,14 @@ class BrgemmCPUBlocking : public ov::snippets::lowered::pass::BrgemmBlocking<Brg

snippets::lowered::SpecificIterationHandlers get_k_loop_handlers(size_t work_amount, size_t block_size) const override;

std::tuple<size_t, size_t, size_t> get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) override;
std::tuple<size_t, size_t, size_t> get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) const override;
bool mark_blocking_loops(snippets::lowered::LinearIR& linear_ir,
const snippets::lowered::LinearIR::constExprIt& brgemm_it,
size_t m_block,
size_t n_block,
size_t k_block) override;

size_t get_default_n_blk(size_t n) const override;
};

} // namespace pass
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,16 @@ std::shared_ptr<snippets::lowered::pass::PassBase> BrgemmTPPBlocking::SetBrgemmB
return !other || ov::is_type<SetBrgemmBeta>(other) ? std::make_shared<SetBrgemmBeta>() : nullptr;
}

std::tuple<size_t, size_t, size_t> BrgemmTPPBlocking::get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) const {
    size_t m, n, k;
    std::tie(m, n, k) = get_brgemm_dimensions(brgemm_expr);
    // Fix: the original condition checked `n` twice and never validated `k`,
    // so a dynamic K dimension could slip through the assert.
    OPENVINO_ASSERT(!is_dynamic_value(m) && !is_dynamic_value(n) && !is_dynamic_value(k), "BrgemmTPP doesn't support dynamic shapes");

    size_t m_blk, n_blk, k_blk;
    std::tie(m_blk, n_blk, k_blk) = BrgemmBlockingBase::get_blocking_params(brgemm_expr);

    // TPP kernels need concrete block sizes: project "full dimension" markers
    // returned by the base heuristics back to the real dimension values.
    auto get_projected_blk = [](const size_t dim, const size_t blk) { return ov::snippets::utils::is_full_dim_value(blk) ? dim : blk; };
    return std::make_tuple(get_projected_blk(m, m_blk), get_projected_blk(n, n_blk), get_projected_blk(k, k_blk));
}

ov::snippets::lowered::SpecificIterationHandlers BrgemmTPPBlocking::get_k_loop_handlers(size_t work_amount, size_t block_size) const {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ class BrgemmTPPBlocking : public ov::snippets::lowered::pass::BrgemmBlocking<ov:
};

private:
std::tuple<size_t, size_t, size_t> get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) override;
std::tuple<size_t, size_t, size_t> get_blocking_params(const ov::snippets::lowered::ExpressionPtr& brgemm_expr) const override;
ov::snippets::lowered::SpecificIterationHandlers get_k_loop_handlers(size_t work_amount, size_t block_size) const override;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/tpp/x64/op/brgemm.hpp"
#include "cpu/x64/cpu_isa_traits.hpp"

namespace ov {
namespace test {
Expand Down Expand Up @@ -119,11 +120,6 @@ void create_brgemm_with_copy_b_loop_infos(const LinearIRPtr& linear_ir,
}
} // namespace

static const size_t m_blk = 32;
static const size_t k_blk = 512;
static const size_t n_blk = 64;
static const size_t full_dim = ov::snippets::utils::get_full_dim_value();

class BrgemmBlockingTest : public LoweredPassTestsF {
public:
BrgemmBlockingTest() : LoweredPassTestsF() {
Expand All @@ -132,10 +128,19 @@ class BrgemmBlockingTest : public LoweredPassTestsF {
comparator.enable(LIRComparator::LIRCmpValues::PORT_CONNECTORS);
comparator.enable(LIRComparator::LIRCmpValues::LOOP_MANAGER);
}

protected:
size_t m_blk = 32;
size_t k_blk = 512;
size_t n_blk = 64;

static const size_t full_dim = ov::snippets::utils::get_full_dim_value();
};
class BrgemmCPUBlockingTest : public BrgemmBlockingTest {
public:
BrgemmCPUBlockingTest() : BrgemmBlockingTest() {}
BrgemmCPUBlockingTest() : BrgemmBlockingTest() {
n_blk = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) ? 64 : 24;
}

void SetUp() override {
pipeline.register_pass<ov::intel_cpu::pass::BrgemmCPUBlocking>();
Expand Down Expand Up @@ -171,10 +176,38 @@ TEST_F(BrgemmCPUBlockingTest, Floating) {
}
}

// Checks that for a statically-known K > 1024 the blocking pass selects k_blk = 1024
// (the enlarged default from get_default_k_blk) while M/N keep the fixture defaults.
TEST_F(BrgemmCPUBlockingTest, Floating_LargeK) {
    const ov::Dimension::value_type m = 384;
    const ov::Dimension::value_type n = 384;
    const ov::Dimension::value_type k = 2048;
    const ov::PartialShape input_shape_a{1, 16, m, k};
    const ov::PartialShape input_shape_b{1, 16, k, n};
    const auto precision = ov::element::f32;
    // K is static and exceeds 1024, so the expected K block is 1024 instead of the default 512.
    k_blk = 1024;

    // Actual IR: a plain standalone BrgemmCPU; the pass under test adds the blocking loops.
    {
        auto data_a = linear_ir->push_node<ov::opset10::Parameter>(precision, input_shape_a);
        auto data_b = linear_ir->push_node<ov::opset10::Parameter>(precision, input_shape_b);
        auto brgemm = linear_ir->push_node<BrgemmCPU>(data_a.second, data_b.second, BRGEMM_TYPE::STAND_ALONE);
        init_expr_descriptors(*brgemm.first, {});
        auto result = linear_ir->push_node<ov::opset10::Result>(brgemm.second);
    }
    // Reference IR: the same Brgemm with the expected blocked subtensors
    // {m_blk, k_blk} x {k_blk, n_blk} -> {m_blk, n_blk} and M/K/N loop infos attached.
    {
        auto data_a = linear_ir_ref->push_node<ov::opset10::Parameter>(precision, input_shape_a);
        auto data_b = linear_ir_ref->push_node<ov::opset10::Parameter>(precision, input_shape_b);
        auto brgemm = linear_ir_ref->push_node<BrgemmCPU>(data_a.second, data_b.second, BRGEMM_TYPE::STAND_ALONE);
        const auto& brgemm_expr = *brgemm.first;
        init_expr_descriptors(brgemm_expr, {{m_blk, k_blk}, {k_blk, n_blk}, {m_blk, n_blk}});
        create_brgemm_loop_infos(linear_ir_ref, brgemm_expr, m, m_blk, k, k_blk, n, n_blk);
        // Loop ids {2, 1, 0}: presumably outer-M / N / inner-K nesting — TODO confirm
        // against the loop order produced by create_brgemm_loop_infos.
        brgemm_expr->set_loop_ids({2, 1, 0});
        auto result = linear_ir_ref->push_node<ov::opset10::Result>(brgemm.second);
    }
}

TEST_F(BrgemmCPUBlockingTest, BlockingIsNotNeeded) {
const size_t m = 32;
const size_t k = 16;
const size_t n = 64;
const ov::Dimension::value_type m = 32;
const ov::Dimension::value_type k = 16;
const ov::Dimension::value_type n = dnnl::impl::cpu::x64::mayiuse(dnnl::impl::cpu::x64::avx512_core) ? 64 : 24;
const ov::PartialShape input_shape_a{1, 16, m, k};
const ov::PartialShape input_shape_b{1, 16, k, n};
const auto precision = ov::element::f32;
Expand All @@ -197,9 +230,9 @@ TEST_F(BrgemmCPUBlockingTest, BlockingIsNotNeeded) {
}

TEST_F(BrgemmCPUBlockingTest, WithDataRepacking) {
const size_t m = 384;
const size_t k = 1024;
const size_t n = 384;
const ov::Dimension::value_type m = 384;
const ov::Dimension::value_type k = 1024;
const ov::Dimension::value_type n = 384;
const ov::PartialShape input_shape_a{1, 16, m, k};
const ov::PartialShape input_shape_b{1, 16, k, n};
const auto precision_a = ov::element::u8;
Expand Down Expand Up @@ -232,9 +265,9 @@ TEST_F(BrgemmCPUBlockingTest, WithDataRepacking) {
}

TEST_F(BrgemmCPUBlockingTest, WithCompensations) {
const size_t m = 384;
const size_t k = 1024;
const size_t n = 384;
const ov::Dimension::value_type m = 384;
const ov::Dimension::value_type k = 1024;
const ov::Dimension::value_type n = 384;
const ov::PartialShape input_shape_a{1, 16, m, k};
const ov::PartialShape input_shape_b{1, 16, k, n};
const auto precision = ov::element::i8;
Expand Down Expand Up @@ -267,9 +300,9 @@ TEST_F(BrgemmCPUBlockingTest, WithCompensations) {
}

TEST_F(BrgemmCPUBlockingTest, AMX) {
const size_t m = 384;
const size_t k = 1024;
const size_t n = 384;
const ov::Dimension::value_type m = 384;
const ov::Dimension::value_type k = 1024;
const ov::Dimension::value_type n = 384;
const ov::PartialShape input_shape_a{1, 16, m, k};
const ov::PartialShape input_shape_b{1, 16, k, n};
const auto precision = ov::element::bf16;
Expand Down

0 comments on commit a0fe89a

Please sign in to comment.