Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Snippets][CPU] Added BrgemmCopyA support #26871

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 55 additions & 47 deletions src/plugins/intel_cpu/src/emitters/plugin/x64/utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,46 +5,52 @@
#include "utils.hpp"

#include "emitters/utils.hpp"
#include "utils/general_utils.h"

namespace ov {
namespace intel_cpu {

using namespace Xbyak;
using namespace dnnl::impl::cpu::x64;

EmitABIRegSpills::EmitABIRegSpills(jit_generator* h) : h(h), isa(get_isa()) {}
EmitABIRegSpills::EmitABIRegSpills(jit_generator* h, Type type) : h(h), isa(get_isa()), type(type) {
OPENVINO_ASSERT(one_of(type, Type::GPRS, Type::VECS, Type::ALL), "Incorrect type");
}

EmitABIRegSpills::~EmitABIRegSpills() {
OPENVINO_ASSERT(spill_status, "postamble or preamble is missed");
OPENVINO_ASSERT(rsp_status, "rsp_align or rsp_restore is missed");
}

void EmitABIRegSpills::preamble() {
// gprs
Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15,
h->rax, h->rbx, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp};
size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]);

h->sub(h->rsp, n_gprs_to_save * gpr_size);
for (size_t i = 0; i < n_gprs_to_save; ++i)
h->mov(h->ptr[h->rsp + i * gpr_size], gprs_to_save[i]);

if (isa == avx512_core) {
h->sub(h->rsp, k_mask_num * k_mask_size);
for (size_t i = 0; i < k_mask_num; ++i) {
h->kmovq(h->ptr[h->rsp + i * k_mask_size], Xbyak::Opmask(static_cast<int>(i)));
}
if (type & Type::GPRS) {
Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15,
h->rax, h->rbx, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp};
size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]);

h->sub(h->rsp, n_gprs_to_save * gpr_size);
for (size_t i = 0; i < n_gprs_to_save; ++i)
h->mov(h->ptr[h->rsp + i * gpr_size], gprs_to_save[i]);
}

h->sub(h->rsp, get_max_vecs_count() * get_vec_length());
for (size_t i = 0; i < get_max_vecs_count(); ++i) {
const auto addr = h->ptr[h->rsp + i * get_vec_length()];
if (isa == sse41) {
h->uni_vmovups(addr, Xmm(i));
} else if (isa == avx2) {
h->uni_vmovups(addr, Ymm(i));
} else {
h->uni_vmovups(addr, Zmm(i));
if (type & Type::VECS) {
if (isa == avx512_core) {
h->sub(h->rsp, k_mask_num * k_mask_size);
for (size_t i = 0; i < k_mask_num; ++i) {
h->kmovq(h->ptr[h->rsp + i * k_mask_size], Xbyak::Opmask(static_cast<int>(i)));
}
}

h->sub(h->rsp, get_max_vecs_count() * get_vec_length());
for (size_t i = 0; i < get_max_vecs_count(); ++i) {
const auto addr = h->ptr[h->rsp + i * get_vec_length()];
if (isa == sse41) {
h->uni_vmovups(addr, Xmm(i));
} else if (isa == avx2) {
h->uni_vmovups(addr, Ymm(i));
} else {
h->uni_vmovups(addr, Zmm(i));
}
}
}

Expand All @@ -53,34 +59,36 @@ void EmitABIRegSpills::preamble() {
}

void EmitABIRegSpills::postamble() {
// restore vector registers
for (int i = static_cast<int>(get_max_vecs_count()) - 1; i >= 0; --i) {
const auto addr = h->ptr[h->rsp + i * get_vec_length()];
if (isa == sse41) {
h->uni_vmovups(Xmm(i), addr);
} else if (isa == avx2) {
h->uni_vmovups(Ymm(i), addr);
} else {
h->uni_vmovups(Zmm(i), addr);
if (type & Type::VECS) {
for (int i = static_cast<int>(get_max_vecs_count()) - 1; i >= 0; --i) {
const auto addr = h->ptr[h->rsp + i * get_vec_length()];
if (isa == sse41) {
h->uni_vmovups(Xmm(i), addr);
} else if (isa == avx2) {
h->uni_vmovups(Ymm(i), addr);
} else {
h->uni_vmovups(Zmm(i), addr);
}
}
}
h->add(h->rsp, (get_max_vecs_count()) * get_vec_length());
h->add(h->rsp, (get_max_vecs_count()) * get_vec_length());

// restore k reg
if (isa == avx512_core) {
for (int i = k_mask_num - 1; i >= 0; --i) {
h->kmovq(Xbyak::Opmask(i), h->ptr[h->rsp + i * k_mask_size]);
if (isa == avx512_core) {
for (int i = k_mask_num - 1; i >= 0; --i) {
h->kmovq(Xbyak::Opmask(i), h->ptr[h->rsp + i * k_mask_size]);
}
h->add(h->rsp, k_mask_num * k_mask_size);
}
h->add(h->rsp, k_mask_num * k_mask_size);
}

// restore gpr registers
Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15,
h->rax, h->rbx, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp};
size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]);
for (int i = n_gprs_to_save - 1; i >= 0; --i)
h->mov(gprs_to_save[i], h->ptr[h->rsp + i * gpr_size]);
h->add(h->rsp, n_gprs_to_save * gpr_size);
if (type & Type::GPRS) {
// restore gpr registers
Xbyak::Operand gprs_to_save[] = {h->r8, h->r9, h->r10, h->r11, h->r12, h->r13, h->r14, h->r15,
h->rax, h->rbx, h->rcx, h->rdx, h->rdi, h->rsi, h->rbp};
size_t n_gprs_to_save = sizeof(gprs_to_save) / sizeof(gprs_to_save[0]);
for (int i = n_gprs_to_save - 1; i >= 0; --i)
h->mov(gprs_to_save[i], h->ptr[h->rsp + i * gpr_size]);
h->add(h->rsp, n_gprs_to_save * gpr_size);
}

// Update the status
spill_status = true;
Expand Down
9 changes: 8 additions & 1 deletion src/plugins/intel_cpu/src/emitters/plugin/x64/utils.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,13 @@ namespace intel_cpu {
// The class emit register spills for the possible call of external binary code
class EmitABIRegSpills {
public:
EmitABIRegSpills(dnnl::impl::cpu::x64::jit_generator* h);
enum Type {
GPRS = 1 << 1, // spill only general-purpose regisers
VECS = 1 << 2, // spill only vector regisers
ALL = GPRS | VECS, // default, spill vector and general-purpose registers
};

EmitABIRegSpills(dnnl::impl::cpu::x64::jit_generator* h, Type type = Type::ALL);
~EmitABIRegSpills();

// push (save) all registers on the stack
Expand All @@ -35,6 +41,7 @@ class EmitABIRegSpills {

dnnl::impl::cpu::x64::jit_generator* h {nullptr};
const dnnl::impl::cpu::x64::cpu_isa_t isa {dnnl::impl::cpu::x64::cpu_isa_t::isa_undef};
const Type type {Type::ALL};

static constexpr int k_mask_size = 8;
static constexpr int k_mask_num = 8;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ namespace intel_cpu {
#define SNIPPETS_MAX_DATA_PTR_COUNT 11
#endif

// Maximum count of Buffer offsets (clusters)
#define SNIPPETS_MAX_BUFFER_COUNT 16

#define GET_OFF(field) offsetof(jit_snippets_call_args, field)
#define GET_OFF_LOOP_ARGS(field) offsetof(jit_snippets_call_args::loop_args_t, field)

Expand All @@ -46,7 +49,7 @@ struct jit_snippets_call_args {
// for all non-static data members. So we can keep them public or friend all control-flow emitters
loop_args_t* loop_args = nullptr;
amx_tile_config_t amx_tile_config;
size_t buffer_offsets[SNIPPETS_MAX_DATA_PTR_COUNT] = {};
size_t buffer_offsets[SNIPPETS_MAX_BUFFER_COUNT] = {};
};

struct jit_snippets_call_args::loop_args_t {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "snippets/snippets_isa.hpp"
#include "emitters/snippets/cpu_runtime_configurator.hpp"

#include "emitters/snippets/x64/jit_brgemm_copy_a_emitter.hpp"
#include "emitters/snippets/x64/jit_brgemm_copy_b_emitter.hpp"
#include "emitters/snippets/x64/jit_brgemm_emitter.hpp"
#include "emitters/snippets/x64/jit_memory_emitters.hpp"
Expand All @@ -23,6 +24,7 @@
#include "transformations/snippets/common/op/load_convert.hpp"
#include "transformations/snippets/common/op/store_convert.hpp"
#include "transformations/snippets/common/op/fused_mul_add.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_a.hpp"
#include "transformations/snippets/x64/op/brgemm_copy_b.hpp"
#include "transformations/snippets/x64/op/brgemm_cpu.hpp"
#include "transformations/snippets/x64/op/perf_count_rdtsc.hpp"
Expand Down Expand Up @@ -243,6 +245,9 @@ intel_cpu::CPUTargetMachine::CPUTargetMachine(dnnl::impl::cpu::x64::cpu_isa_t ho
jitters[intel_cpu::BrgemmCPU::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_emitter,
configurator->get_kernel_executor_table(),
compiled_kernel_cache);
jitters[intel_cpu::BrgemmCopyA::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_copy_a_emitter,
configurator->get_kernel_executor_table(),
compiled_kernel_cache);
jitters[intel_cpu::BrgemmCopyB::get_type_info_static()] = CREATE_SNIPPETS_EMITTER(intel_cpu::jit_brgemm_copy_b_emitter,
configurator->get_kernel_executor_table(),
compiled_kernel_cache);
Expand Down Expand Up @@ -356,6 +361,7 @@ ov::snippets::RegType intel_cpu::CPUGenerator::get_specific_op_out_reg_type(cons
std::dynamic_pointer_cast<intel_cpu::tpp::modifier::TensorProcessingPrimitive>(op) ||
std::dynamic_pointer_cast<intel_cpu::tpp::op::Scalar>(op) ||
#endif
std::dynamic_pointer_cast<intel_cpu::BrgemmCopyA>(op)||
std::dynamic_pointer_cast<intel_cpu::BrgemmCopyB>(op))
return ov::snippets::RegType::gpr;
else if (
Expand All @@ -368,7 +374,8 @@ ov::snippets::RegType intel_cpu::CPUGenerator::get_specific_op_out_reg_type(cons

bool intel_cpu::CPUGenerator::uses_precompiled_kernel(const std::shared_ptr<snippets::Emitter>& e) const {
bool need = std::dynamic_pointer_cast<intel_cpu::jit_brgemm_emitter>(e) ||
std::dynamic_pointer_cast<intel_cpu::jit_brgemm_copy_b_emitter>(e);
std::dynamic_pointer_cast<intel_cpu::jit_brgemm_copy_b_emitter>(e) ||
std::dynamic_pointer_cast<intel_cpu::jit_brgemm_copy_a_emitter>(e);
#ifdef SNIPPETS_DEBUG_CAPS
const auto cpu_target_machine = std::dynamic_pointer_cast<intel_cpu::CPUTargetMachine>(target);
need = need || (cpu_target_machine && cpu_target_machine->debug_config.enable_segfault_detector) ||
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,89 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "jit_brgemm_copy_a_emitter.hpp"

#include "emitters/plugin/x64/utils.hpp"
#include "emitters/snippets/x64/utils.hpp"
#include "emitters/snippets/jit_snippets_call_args.hpp"

#include "snippets/utils/utils.hpp"

#include "transformations/snippets/x64/op/brgemm_copy_a.hpp"
#include "transformations/snippets/x64/op/brgemm_utils.hpp"


using namespace dnnl::impl::cpu::x64;
using namespace ov::intel_cpu::brgemm_utils;
using namespace ov::snippets::utils;

namespace ov {
namespace intel_cpu {

jit_brgemm_copy_a_emitter::jit_brgemm_copy_a_emitter(jit_generator* h, cpu_isa_t isa, const ov::snippets::lowered::ExpressionPtr& expr,
const snippets::KernelExecutorTablePtr& kernel_table,
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache)
: jit_emitter(h, isa) {
in_out_type_ = emitter_in_out_map::gpr_to_gpr;
const auto brgemm_repack = ov::as_type_ptr<ov::intel_cpu::BrgemmCopyA>(expr->get_node());
OV_CPU_JIT_EMITTER_ASSERT(brgemm_repack, "expects BrgemmCopyA node");

// Note: even if the BrgemmCopyA node is dynamic, the first shapeInfer and RuntimeConfigurator::update()
// are performed before the BrgemmCopyAKernelExecutor registration. So we have to trigger update() manually
// for both static and the 1st dynamic shapes.
OV_CPU_JIT_EMITTER_ASSERT(!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()),
"Jit emitter is called when the shapes are unknown");

const auto& brgemm_config = brgemm_repack->get_config();
BrgemmCopyAKernelConfig kernel_config(brgemm_repack->get_input_element_type(0), brgemm_config.isa());
m_kernel_executor = kernel_table->register_kernel<BrgemmCopyAKernelExecutor>(expr, compiled_kernel_cache, kernel_config);

m_memory_offsets = {brgemm_repack->get_offset_in(), brgemm_repack->get_offset_out()};
m_buffer_ids = {utils::get_buffer_cluster_id(expr->get_input_port(0)), utils::get_buffer_cluster_id(expr->get_output_port(0))};
}

void jit_brgemm_copy_a_emitter::validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const {
OV_CPU_JIT_EMITTER_ASSERT(in.size() == 1 && out.size() == 1, "expects 1 input and 1 output");
}

void jit_brgemm_copy_a_emitter::emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const {
validate_arguments(in, out);

std::vector<size_t> mem_ptrs_idxs{in[0], out[0]};

EmitABIRegSpills spill(h);
spill.preamble();

h->mov(h->rbp, reinterpret_cast<uint64_t>(BrgemmCopyAKernelExecutor::execute));
auto reserved_stack_size = sizeof(BrgemmCopyAKernel::call_args);
// Reserve memory on the stack
h->sub(h->rsp, reserved_stack_size);

const bool is_dynamic_case = std::any_of(m_memory_offsets.cbegin(), m_memory_offsets.cend(), ov::snippets::utils::is_dynamic_value<size_t>);
Xbyak::Reg64 aux_reg = is_dynamic_case ? ov::intel_cpu::utils::get_aux_gpr(mem_ptrs_idxs) : Xbyak::Reg64();

const std::vector<size_t> args_offsets {GET_OFF_BRGEMM_COPY_A_ARGS(src), GET_OFF_BRGEMM_COPY_A_ARGS(tr_src)};
const auto& mem_ptrs = ov::intel_cpu::utils::transform_idxs_to_regs(mem_ptrs_idxs);
for (size_t i = 0; i < mem_ptrs.size(); i++) {
if (ov::snippets::utils::is_dynamic_value(m_memory_offsets[i]))
utils::push_ptr_with_runtime_offset_on_stack(h, args_offsets[i], mem_ptrs[i], aux_reg,
GET_OFF(buffer_offsets) + m_buffer_ids[i] * sizeof(size_t));
else
utils::push_ptr_with_static_offset_on_stack(h, args_offsets[i], mem_ptrs[i], m_memory_offsets[i]);
}

h->mov(abi_param1, reinterpret_cast<uintptr_t>(m_kernel_executor.get()));
h->mov(abi_param2, h->rsp);

spill.rsp_align();
h->call(h->rbp);
spill.rsp_restore();

h->add(h->rsp, reserved_stack_size);

spill.postamble();
}

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "emitters/plugin/x64/jit_emitter.hpp"

#include "kernel_executors/brgemm_copy_a.hpp"


namespace ov {
namespace intel_cpu {

class jit_brgemm_copy_a_emitter : public jit_emitter {
public:
jit_brgemm_copy_a_emitter(dnnl::impl::cpu::x64::jit_generator* h, dnnl::impl::cpu::x64::cpu_isa_t isa,
const ov::snippets::lowered::ExpressionPtr& expr,
const snippets::KernelExecutorTablePtr& kernel_table,
const ov::intel_cpu::MultiCacheWeakPtr& compiled_kernel_cache);

size_t get_inputs_num() const override {return 1;}
static std::set<std::vector<element::Type>> get_supported_precisions(const std::shared_ptr<ov::Node>& node = nullptr) {
return {{element::i8}, {element::u8}, {element::bf16}};
}

private:
void validate_arguments(const std::vector<size_t> &in, const std::vector<size_t> &out) const override;
void emit_impl(const std::vector<size_t>& in, const std::vector<size_t>& out) const override;

std::vector<size_t> m_memory_offsets{};
std::vector<size_t> m_buffer_ids{};
std::shared_ptr<BrgemmCopyAKernelExecutor> m_kernel_executor {nullptr};

#ifdef SNIPPETS_DEBUG_CAPS
friend std::string init_info_jit_brgemm_copy_a_emitter(const jit_brgemm_copy_a_emitter *emitter);
#endif
};

} // namespace intel_cpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -48,18 +48,14 @@ jit_brgemm_copy_b_emitter::jit_brgemm_copy_b_emitter(jit_generator* h, cpu_isa_t
OV_CPU_JIT_EMITTER_ASSERT(!snippets::utils::is_dynamic_vdims(expr->get_input_port_descriptor(0)->get_shape()),
"Jit emitter is called when the shapes are unknown");

const auto& in_subtensor = get_projected_subtensor(expr->get_input_port(0));
const auto K_blk = *++in_subtensor.rbegin();

const auto& src_prc = brgemm_repack->get_src_element_type();
const auto& wei_prc = brgemm_repack->get_input_element_type(0);
const auto wei_N_blk = brgemm_utils::repacking::compute_inner_n_block(wei_prc);
const auto is_transposed = get_is_transposed(expr);
const auto brgemm_type = get_brgemm_type(src_prc, K_blk, is_transposed);
const auto primitive_isa = brgemm_utils::get_primitive_isa(src_prc, with_amx(brgemm_type));
m_with_comp = with_compensations(brgemm_type);
const auto& brgemm_config = brgemm_repack->get_config();
m_with_comp = brgemm_config.need_compensations();

BrgemmCopyBKernelConfig kernel_config(src_prc, wei_prc, primitive_isa, m_with_comp, is_transposed, wei_N_blk);
BrgemmCopyBKernelConfig kernel_config(src_prc, wei_prc, brgemm_config.isa(), m_with_comp, is_transposed, wei_N_blk);
m_kernel_executor = kernel_table->register_kernel<BrgemmCopyBKernelExecutor>(expr, compiled_kernel_cache, kernel_config);

m_memory_offsets = {brgemm_repack->get_offset_in(), brgemm_repack->get_offset_out()};
Expand Down
Loading
Loading