[CPU] enable MLP & QKV on non-PA case and minor fixes (openvinotoolkit#26103)

### Details:
- Enable the MLP & QKV optimization in all cases (the previous PR only
enabled it for the PageAttention case); this brings ~30% first-token
latency reduction without any regression in second-token latency.
- Relax the restriction on the K dimension size from an integer multiple
of 256 down to a multiple of 32, which lets more LLMs (such as QWen1.5)
benefit from this optimization (see the first sketch below).
- Reduce the total memory footprint by converting fp16 weights directly
into bf16 (previously the node required bf16 weights, which duplicated
the weight memory in bf16 format since most LLMs ship fp16 weights);
see the second sketch below.
- Allocating many small sub-weight tensors introduces too much page-fault
overhead; allocating one big weight tensor reduces it considerably
thanks to huge pages (see the third sketch below).
- Prepack the weight tensor in AMX B-tile format directly from the gate &
up tensors, without an intermediate combination step; this reduces
first-inference latency.
- Fix the scratch-buffer bugs from the previous PR; scratch buffers are
now truly shared among layers.
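
A minimal sketch of the relaxed K blocking (not code from this PR): K now only
has to be a multiple of 32 (one pair of AMX B-tiles), and the last cache block
along K may be smaller than 256; the block count uses a ceil-division, matching
the loop in `Work::setup` in the diff below. The value `K = 2720` is purely
illustrative.

```cpp
#include <algorithm>
#include <cstdio>

int main() {
    const int REG_BLK_K = 32;    // register blocking: one AMX B-tile pair along K
    const int CACHE_BLK_K = 256; // cache blocking along K
    const int K = 2720;          // multiple of 32 but not of 256 (illustrative)

    if (K % REG_BLK_K != 0)      // the only remaining restriction on K
        return 1;

    const int num_blk_K = (K + CACHE_BLK_K - 1) / CACHE_BLK_K;  // ceil-div
    std::printf("%d cache blocks along K\n", num_blk_K);
    for (int k = 0, ki = 0; k < K; ki++) {
        const int subK = std::min(CACHE_BLK_K, K - k);  // tail block may be < 256
        std::printf("  block %d: offset %d, size %d\n", ki, k, subK);
        k += subK;
    }
    return 0;
}
```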
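
A portable scalar sketch of the in-place fp16 → bf16 re-encoding that the JIT
`FP16ToBF16Kernel` in `mlp_kernel.cpp` performs with AVX-512
(`vcvtph2ps` + `vcvtne2ps2bf16`). It assumes the `_Float16` compiler extension
(GCC/Clang on x86-64) and rounds to nearest-even when dropping mantissa bits;
NaN payloads get no special treatment in this sketch.

```cpp
#include <cstddef>
#include <cstdint>
#include <cstring>

// Round an fp32 value to bf16 (round-to-nearest-even on the 16 dropped bits).
static inline uint16_t f32_to_bf16(float f) {
    uint32_t bits;
    std::memcpy(&bits, &f, sizeof(bits));
    const uint32_t rounding = 0x7FFFu + ((bits >> 16) & 1u);
    return static_cast<uint16_t>((bits + rounding) >> 16);
}

// Re-encode `count` fp16 values as bf16 in the same buffer (both are 2 bytes
// wide), so no second copy of the weights is ever materialized.
static void fp16_to_bf16_inplace(uint16_t* data, std::size_t count) {
    for (std::size_t i = 0; i < count; ++i) {
        _Float16 h;
        std::memcpy(&h, &data[i], sizeof(h));
        data[i] = f32_to_bf16(static_cast<float>(h));
    }
}
```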
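
A sketch of the single-big-allocation idea behind the new `WeightBuffer` helper
in `mlp_kernel.hpp`: sum the per-work weight sizes first, record offsets,
allocate once, and hand out sub-pointers. `WorkExtent` and `BigWeightBuffer`
are illustrative names, not from the PR; in the PR the sizes come from
`(n1 - n0) * (k1 - k0)` per `Work` and the storage is a `PlainTensor` of
`ov::bfloat16`.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

struct WorkExtent {   // illustrative stand-in for the per-Work N/K ranges
    int n0, n1, k0, k1;
};

struct BigWeightBuffer {
    std::vector<uint16_t> storage;     // bf16 elements, one big allocation
    std::vector<std::size_t> offsets;  // element offset of each work's slice

    void alloc(const std::vector<WorkExtent>& works) {
        std::size_t total = 0;
        for (const auto& w : works) {
            offsets.push_back(total);
            total += static_cast<std::size_t>(w.n1 - w.n0) * (w.k1 - w.k0);
        }
        // One large allocation touches far fewer freshly-mapped pages than many
        // small ones and is more likely to be backed by (transparent) huge pages.
        storage.resize(total);
    }
    uint16_t* get(std::size_t work_id) { return storage.data() + offsets[work_id]; }
};
```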

### Tickets:
 - *ticket-id*
usstq authored Aug 28, 2024
1 parent d78c565 commit f7435e4
Showing 8 changed files with 609 additions and 181 deletions.
68 changes: 62 additions & 6 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.cpp
@@ -168,11 +168,38 @@ void MKernel::tile_config_M(TileConfig& tile_cfg, int M) {
});
}

class FP16ToBF16Kernel : public dnnl::impl::cpu::x64::jit_generator {
public:
DECLARE_CPU_JIT_AUX_FUNCTIONS(FP16ToBF16Kernel)
FP16ToBF16Kernel() : jit_generator("FP16ToBF16Kernel") {
create_kernel();
}

void generate() override {
Xbyak::Label loop_begin;
Xbyak::Reg64 src = abi_param1;
for (int i = 0; i < 16; i++) {
vcvtph2ps(zmm0, ptr[src]);
vcvtph2ps(zmm1, ptr[src + 32]);
vcvtne2ps2bf16(zmm2, zmm1, zmm0);
vmovups(ptr[src], zmm2);
lea(src, ptr[src + 64]);
}

ret();
}
};

template <typename T>
void MKernel::repackB(ov::bfloat16* dst, T* src, int N_stride, int N, int K) {
if (N == 16 && K == 32 && std::is_same<T, ov::bfloat16>::value) {
static FP16ToBF16Kernel fp16_to_bf16;

if (N == 16 && K == 32 && (std::is_same<T, ov::bfloat16>::value || std::is_same<T, ov::float16>::value)) {
// SIMD optimized version
ov::Extensions::Cpu::XARCH::llm_mlp_transpose_epi32_16x16(dst, src, N_stride * sizeof(T));
if (std::is_same<T, ov::float16>::value) {
fp16_to_bf16(dst);
}
return;
}

@@ -197,17 +224,18 @@ void MKernel::repackB(ov::bfloat16* dst, T* src, int N_stride, int N, int K) {
}

template <typename T>
void MKernel::prepareB(PlainTensor& ret, T* p_weight, int stride, int N, int K) {
void MKernel::prepareB(PlainTensor& ret, ov::bfloat16* dst, T* p_weight, int stride, int N, int K) {
OPENVINO_ASSERT((N % 32) == 0);
OPENVINO_ASSERT((K % 32) == 0);
// weight matrix is in unit of [N/32, Kx32]
ret.resize<ov::bfloat16>({static_cast<size_t>(N / 32), static_cast<size_t>(K * 32)});
ret.resize<ov::bfloat16>({static_cast<size_t>(N / 32), static_cast<size_t>(K * 32)}, dst);

auto N_stride = stride / sizeof(T);
for (int n = 0, blkn = 0; n < N; n += 32, blkn++) {
for (int k = 0, blkk = 0; k < K; k += 32, blkk++) {
auto* dst_base = ret.ptr<ov::bfloat16>(blkn, 0);
for (int k = 0, blkk = 0; k < K; k += 32, blkk++, dst_base += 1024) {
// two adjacent 32x16 (512) block of weight: dst0 & dst1
auto* dst0 = ret.ptr<ov::bfloat16>(blkn, blkk * 1024);
auto* dst0 = dst_base;
auto* dst1 = dst0 + 16 * 32;
auto valid_k = (K - k) < 32 ? (K - k) : 32;

@@ -222,7 +250,35 @@ void MKernel::prepareB(PlainTensor& ret, T* p_weight, int stride, int N, int K)
}
}

template void MKernel::prepareB<ov::bfloat16>(PlainTensor& ret, ov::bfloat16* p_weight, int stride, int N, int K);
// interleave two weights into one, in units of 16 columns
template <typename T>
void MKernel::prepareB(PlainTensor& ret, ov::bfloat16* dst, T* p_weight1, T* p_weight2, int stride, int N, int K) {
OPENVINO_ASSERT((N % 32) == 0);
OPENVINO_ASSERT((K % 32) == 0);
// weight matrix is in unit of [N/32, Kx32]
ret.resize<ov::bfloat16>({static_cast<size_t>(N / 32), static_cast<size_t>(K * 32)}, dst);

auto N_stride = stride / sizeof(T);
auto N2 = N / 2;
for (int n = 0, blkn = 0; n < N2; n += 16, blkn++) {
for (int k = 0, blkk = 0; k < K; k += 32, blkk++) {
// two adjacent 32x16 (512) block of weight: dst0 & dst1
auto* dst0 = ret.ptr<ov::bfloat16>(blkn, blkk * 1024);
auto* dst1 = dst0 + 16 * 32;
auto valid_k = (K - k) < 32 ? (K - k) : 32;

auto* src0 = p_weight1 + n * N_stride + k;
auto valid_n0 = (N2 - n) < 16 ? (N2 - n) : 16;
repackB<T>(dst0, src0, N_stride, valid_n0, valid_k);

auto* src1 = p_weight2 + n * N_stride + k;
repackB<T>(dst1, src1, N_stride, valid_n0, valid_k);
}
}
}

template void MKernel::prepareB<ov::float16>(PlainTensor& ret, ov::bfloat16* dst, ov::float16* p_weight, int stride, int N, int K);
template void MKernel::prepareB<ov::float16>(PlainTensor& ret, ov::bfloat16* dst, ov::float16* p_weight1, ov::float16* p_weight2, int stride, int N, int K);

// run L2 cache blocking kernel with size:
// [BM, BK]*[BK, BN] => [BM, BN]
98 changes: 81 additions & 17 deletions src/plugins/intel_cpu/src/nodes/kernels/x64/mlp_kernel.hpp
@@ -9,6 +9,18 @@
#include "../scaled_attn/executor_pa_common.hpp"
#include "utils/plain_tensor.hpp"

// register blocking size for K dimension (1x2 AMX B-tiles)
#define REG_BLK_K_SIZE 32

// register blocking size for N dimension (1x2 AMX B-tiles)
#define REG_BLK_N_SIZE 32

// cache blocking size for K dimension
#define CACHE_BLK_K_SIZE 256

// cache blocking size for M dimension
#define CACHE_BLK_M_SIZE 256

namespace ov {
namespace intel_cpu {

@@ -65,7 +77,11 @@ class MKernel : public dnnl::impl::cpu::x64::jit_generator {

// weight is supposed to be of shape[N, K], stride in unit of bytes
template <typename T>
void prepareB(PlainTensor& ret, T* p_weight, int stride, int N, int K);
void prepareB(PlainTensor& ret, ov::bfloat16* dst, T* p_weight, int stride, int N, int K);

// interleave two weights into one, in units of 16 columns
template <typename T>
void prepareB(PlainTensor& ret, ov::bfloat16* dst, T* p_weight1, T* p_weight2, int stride, int N, int K);

// to save push/pop: do not use `abi_save_gpr_regs`
uint8_t* prefetch_next_A_addr;
@@ -111,7 +127,7 @@ struct Work {
int BN = 0;
int blk_K_size = 0;
int output_id;
ov::bfloat16* p_raw_weights;
ov::float16* p_raw_weights;
operator bool() {
return BN > 0;
}
@@ -124,14 +140,41 @@ struct Work {

// input : weight [N, K], setup repacks range of N [n_start, n_end)
template <typename T>
void setup(T* p_weight, int stride) {
void setup(ov::bfloat16* dst, T* p_weight, int stride) {
auto& mkernel = get_MKernel();
auto num_blk_K = (k1 - k0) / blk_K_size;
auto* pw = p_weight + n0 * stride / sizeof(T) + k0;
auto num_blk_K = (k1 - k0 + blk_K_size - 1) / blk_K_size;
auto* pw = p_weight + n0 * stride / sizeof(T);

// weight is divided along the K dimension into blocks of size blk_K_size, except the last block.
weights.resize(num_blk_K);
for (int k = 0; k < num_blk_K; k++) {
mkernel.prepareB(weights[k], pw + k * blk_K_size, stride, BN, blk_K_size);
for (int k = k0, ki = 0; k < k1;) {
auto subK = std::min(blk_K_size, k1 - k);
mkernel.prepareB(weights[ki], dst, pw + k, stride, BN, subK);
dst += BN*subK;
k += subK;
ki++;
}

for (int Mtails = 0; Mtails < 32; Mtails++) {
mkernel.tile_config_M(m_tcfg[Mtails], Mtails == 0 ? 32 : Mtails);
}
}

template <typename T>
void setup(ov::bfloat16* dst, T* p_weight1, T* p_weight2, int stride) {
auto& mkernel = get_MKernel();
auto num_blk_K = (k1 - k0 + blk_K_size - 1) / blk_K_size;
auto* pw1 = p_weight1 + (n0/2) * stride / sizeof(T);
auto* pw2 = p_weight2 + (n0/2) * stride / sizeof(T);

// weight is divided along the K dimension into blocks of size blk_K_size, except the last block.
weights.resize(num_blk_K);
for (int k = k0, ki = 0; k < k1;) {
auto subK = std::min(blk_K_size, k1 - k);
mkernel.prepareB(weights[ki], dst, pw1 + k, pw2 + k, stride, BN, subK);
dst += BN*subK;
k += subK;
ki++;
}

for (int Mtails = 0; Mtails < 32; Mtails++) {
@@ -142,24 +185,28 @@ struct Work {
ov::Extensions::Cpu::TileConfig m_tcfg[32];
AutoTileConfiger m_tile_configer;

size_t get_C_size(int M) {
PlainTensor m_C;

size_t set_C(int M, float * ext_buff) {
auto Mtails = M % 32;
auto Mbody = M - Mtails;
auto C_M = Mbody + (Mtails ? 32 : 0);
return C_M * BN;
m_C.resize<float>({static_cast<size_t>(C_M), static_cast<size_t>(BN)}, ext_buff);
return C_M * BN * sizeof(float);
}

void run(int M, uint8_t* pA, int strideA, PlainTensor& C) {
void run(int M, uint8_t* pA, int strideA) {
auto& mkernel = get_MKernel();

int num_blk_K = (k1 - k0) / blk_K_size;
int num_blk_K = weights.size();

auto Mtails = M % 32;
auto Mbody = M - Mtails;

auto C_M = Mbody + (Mtails ? 32 : 0);
C.resize<float>({static_cast<size_t>(C_M), static_cast<size_t>(BN)});
auto pC = reinterpret_cast<uint8_t*>(C.ptr_v());

auto C_stride_bytes = BN * sizeof(float);
OPENVINO_ASSERT(C_M * C_stride_bytes <= m_C.stride_bytes(0) * m_C.size(0));
auto pC = reinterpret_cast<uint8_t*>(m_C.ptr_v());

pA += k0 * sizeof(ov::bfloat16);
bool do_accumulation = false;
@@ -174,7 +221,7 @@ struct Work {
strideA,
blockB,
pC,
C.stride_bytes(0),
C_stride_bytes,
reinterpret_cast<uint8_t*>(blockB1.ptr_v()),
do_accumulation);
}
@@ -185,8 +232,8 @@ struct Work {
pA + ki * blk_K_size * sizeof(ov::bfloat16) + Mbody * strideA,
strideA,
blockB,
pC + Mbody * C.stride_bytes(0),
C.stride_bytes(0),
pC + Mbody * C_stride_bytes,
C_stride_bytes,
reinterpret_cast<uint8_t*>(blockB1.ptr_v()),
do_accumulation);
}
@@ -196,6 +243,23 @@ struct Work {
}
};

// allocating weight memory in one big chunk benefits from huge pages (with much less page-fault overhead)
struct WeightBuffer {
PlainTensor buffer;
std::vector<size_t> offsets;
void alloc(std::vector<Work>& works) {
size_t weight_cnt = 0;
for (auto& work : works) {
offsets.push_back(weight_cnt);
weight_cnt += (work.n1 - work.n0) * (work.k1 - work.k0);
}
buffer.resize<ov::bfloat16>({weight_cnt});
}
ov::bfloat16* get(int work_id) {
return buffer.ptr<ov::bfloat16>() + offsets[work_id];
}
};

// combine gate_proj & up_proj using activation algo, then convert to bf16
// ConvertFP32toBF16(act_fn(gate) * up)
class GateUpCombine : public dnnl::impl::cpu::x64::jit_generator {
