From f075372c14277ead92cbf11fe8190dc942bf71e7 Mon Sep 17 00:00:00 2001
From: "xinhao.zheng"
Date: Fri, 16 Aug 2024 01:05:35 +0000
Subject: [PATCH 1/8] Integrate KleidiAI release v0.1.0 into MNN 2.9.3

The KleidiAI sources are placed in source/backend/cpu/arm/kleidiAI/kai. They
were downloaded from the Arm GitLab repository and remain unchanged; they may
later be removed from the tree and downloaded at build time instead.

MNNKleidiAI.cpp is the interface between MNN and KleidiAI. The functions of
class DenseConvInt8TiledExecutor in ConvInt8TiledExecutor.cpp are rewritten
to call the KleidiAI functions; a dedicated execution may be implemented
later.

The changes to GeometryConvUtils.cpp and ShapeTensorConvert.cpp make the
input and output of DenseConvInt8TiledExecutor NCHW rather than NC4HW4,
which avoids redundant pack/unpack and gives better performance.
---
 source/backend/cpu/CMakeLists.txt             |   7 +
 source/backend/cpu/CPUBackend.hpp             |   4 +
 .../backend/cpu/arm/kleidiAI/CMakeLists.txt   |  34 ++
 .../backend/cpu/arm/kleidiAI/MNNKleidiAI.cpp  | 215 ++++++++
 source/backend/cpu/arm/kleidiAI/MNNKleidiAI.h | 108 ++++
 .../backend/cpu/arm/kleidiAI/kai/kai_common.h | 108 ++++
 ...ai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c | 228 ++++++++
 ...ai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h | 125 +++++
 ...2_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c | 507 ++++++++++++++++++
 ...2_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h | 125 +++++
 .../pack/kai_lhs_quant_pack_qai8dxp_f32.c     | 179 +++++++
 .../pack/kai_lhs_quant_pack_qai8dxp_f32.h     |  77 +++
 .../kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c     | 171 ++++++
 .../kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h     | 111 ++++
 .../cpu/compute/ConvInt8TiledExecutor.cpp     |  47 ++
 source/core/TensorUtils.hpp                   |   6 +
 source/geometry/GeometryConvUtils.cpp         |  13 +
 source/shape/ShapeTensorConvert.cpp           |   3 +
 18 files changed, 2068 insertions(+)
 create mode 100644 source/backend/cpu/arm/kleidiAI/CMakeLists.txt
 create mode 100644 source/backend/cpu/arm/kleidiAI/MNNKleidiAI.cpp
 create mode 100644 source/backend/cpu/arm/kleidiAI/MNNKleidiAI.h
 create mode 100644 source/backend/cpu/arm/kleidiAI/kai/kai_common.h
 create mode 100644 source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c
 create mode 100644 source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h
 create mode 100644 source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c
 create mode 100644 source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h
 create mode 100644 source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c
 create mode 100644 source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h
 create mode 100644 source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c
 create mode 100644 source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h

diff --git a/source/backend/cpu/CMakeLists.txt b/source/backend/cpu/CMakeLists.txt
index 82287d69f..e8e465610 100644
--- a/source/backend/cpu/CMakeLists.txt
+++ b/source/backend/cpu/CMakeLists.txt
@@ -50,3 +50,10 @@ IF(MNN_ARM82)
     ENDIF()
 ENDIF()
 
+# Kleidi AI
+IF(MNN_KLEIDIAI)
+    add_definitions(-DMNN_KLEIDIAI_ENABLED=1)
+    include(${CMAKE_CURRENT_LIST_DIR}/arm/kleidiAI/CMakeLists.txt)
+    list(APPEND
MNN_TARGETS MNN_KleidiAI)
+    list(APPEND MNN_OBJECTS_TO_LINK $<TARGET_OBJECTS:MNN_KleidiAI>)
+ENDIF()
\ No newline at end of file
diff --git a/source/backend/cpu/CPUBackend.hpp b/source/backend/cpu/CPUBackend.hpp
index 00e39fc30..23dcf5c9d 100644
--- a/source/backend/cpu/CPUBackend.hpp
+++ b/source/backend/cpu/CPUBackend.hpp
@@ -17,6 +17,10 @@
 #include "core/BufferAllocator.hpp"
 #include "MNN_generated.h"
 
+#ifdef MNN_KLEIDIAI_ENABLED
+#include "arm/kleidiAI/MNNKleidiAI.h"
+#endif
+
 namespace MNN {
 class CPURuntime : public Runtime {
 public:
diff --git a/source/backend/cpu/arm/kleidiAI/CMakeLists.txt b/source/backend/cpu/arm/kleidiAI/CMakeLists.txt
new file mode 100644
index 000000000..556da784a
--- /dev/null
+++ b/source/backend/cpu/arm/kleidiAI/CMakeLists.txt
@@ -0,0 +1,34 @@
+list(APPEND MNN_KleidiAI_SOURCES ${CMAKE_CURRENT_LIST_DIR}/MNNKleidiAI.cpp)
+list(APPEND MNN_KleidiAI_HEADERS ${CMAKE_CURRENT_LIST_DIR}/MNNKleidiAI.h)
+
+include_directories(
+    ${CMAKE_CURRENT_LIST_DIR}/
+    ${CMAKE_CURRENT_LIST_DIR}/kai/ukernels/
+    ${CMAKE_CURRENT_LIST_DIR}/kai/ukernels/matmul/
+    ${CMAKE_CURRENT_LIST_DIR}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/
+    ${CMAKE_CURRENT_LIST_DIR}/kai/ukernels/matmul/pack/)
+
+list(APPEND MNN_KleidiAI_SOURCES ${CMAKE_CURRENT_LIST_DIR}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c)
+list(APPEND MNN_KleidiAI_SOURCES ${CMAKE_CURRENT_LIST_DIR}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c)
+list(APPEND MNN_KleidiAI_SOURCES ${CMAKE_CURRENT_LIST_DIR}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c)
+list(APPEND MNN_KleidiAI_SOURCES ${CMAKE_CURRENT_LIST_DIR}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c)
+
+
+add_library(
+    MNN_KleidiAI
+    SHARED
+    ${MNN_KleidiAI_SOURCES} ${MNN_KleidiAI_HEADERS}
+)
+
+# Enable ARMv8.6-A features
+target_compile_definitions(MNN_KleidiAI PRIVATE
+    __ARM_FEATURE_MATMUL_INT8
+    __ARM_FEATURE_BF16_VECTOR_ARITHMETIC
+    __ARM_FEATURE_BF16_SCALAR_ARITHMETIC
+    __ARM_BF16_FORMAT_ALTERNATIVE
+    __ARM_FEATURE_DOTPROD
+)
+
+target_compile_options(MNN_KleidiAI
+    PRIVATE -march=armv8.6-a
+)
\ No newline at end of file
diff --git a/source/backend/cpu/arm/kleidiAI/MNNKleidiAI.cpp b/source/backend/cpu/arm/kleidiAI/MNNKleidiAI.cpp
new file mode 100644
index 000000000..465a6e4ad
--- /dev/null
+++ b/source/backend/cpu/arm/kleidiAI/MNNKleidiAI.cpp
@@ -0,0 +1,215 @@
+#if defined(__aarch64__)
+
+#include "MNNKleidiAI.h"
+
+using namespace MNN;
+
+KleidiAI *KleidiAI::instance = NULL;
+
+inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) {
+    // Since we pack a float and int32 value at the end of the row,
+    // we must make sure that k is a multiple of 4 for memory alignment.
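+    // Worked example with MNN's defaults kr = 16 and sr = 2 (k = 70 is chosen
+    // only for illustration): kr * sr = 32 is already a multiple of 4, so
+    // kai_k_roundedup(70, 16, 2) = kai_roundup(70, 32) = 96.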
+ size_t kr_sr_roundedup4 = kai_roundup(kr * sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +static void packQsi4cxpQsi8cxs1s0(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, + const float* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params* params) { + KAI_ASSERT(num_groups == 1); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT((kr % sr) == 0); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->rhs_zero_point == 8); + KAI_ASSERT(params->lhs_zero_point == 1); + + const size_t rhs_zero_point = params->rhs_zero_point; + const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + const size_t k_internal = kai_k_roundedup(k, kr, sr); + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); + const size_t block_length_in_bytes = kr / sr; + const size_t k_interleaved_v = 16U; + const size_t rhs_stride = kai_roundup(k, 2); + + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; + + int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); + + // Initialize to zero the RHS reduction sums + memset(sums, 0, nr * sizeof(int32_t)); + + for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { + const size_t block_idx = dst_byte_idx / block_length_in_bytes; + const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; + const size_t super_block_idx = block_idx / nr; + const size_t nr_idx = block_idx % nr; + + const size_t k_adjustment = + ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; + const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; + const size_t k1_idx = k0_idx + k_interleaved_v; + const size_t n0_idx = dst_row_idx * nr + nr_idx; + + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + + const size_t src_addr_byte0 = k0_idx + n0_valid_idx * rhs_stride; + const size_t src_addr_byte1 = k1_idx + n0_valid_idx * rhs_stride; + + int8_t byte0 = 0; + int8_t byte1 = 0; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } + + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } + + sums[nr_idx] += (int32_t)byte0 + (int32_t)byte1; + + const uint8_t dst_qs0 = (byte0 + rhs_zero_point) | ((byte1 + rhs_zero_point) << 4); + + *dst_row = dst_qs0 ^ 0x88; + dst_row += sizeof(uint8_t); + } + + // Adjust the reduction sums + for (size_t i = 0; i < nr; ++i) { + sums[i] = sums[i] * 16; + dst_row += sizeof(int32_t); + } + + // Adjust the scales + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; + dst_row += sizeof(float); + } + + // Set the bias + if (bias == NULL) { + memset(dst_row, 0, nr * sizeof(float)); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + ((float*)dst_row)[i] = bias[src_row_idx]; + } + } + } +} + +void KleidiAI::packNCHWToNC4HW4(float* data, size_t rowNum, size_t rowSize) { + if(rowNum == 1) { + return; + } + + const 
size_t tmp_size = rowNum * rowSize * sizeof(float);
+    uint8_t *tmpBuffer = new uint8_t[tmp_size];
+    memcpy(tmpBuffer, data, tmp_size);
+
+    const float *src = (const float *)tmpBuffer;
+    float *dst = (float *)data;
+
+    size_t blockNum = rowSize / 4;
+    size_t blockSize = 4 * sizeof(float);
+
+    for(size_t blockIndex = 0; blockIndex < blockNum; blockIndex++) {
+        const float *rowSrc = src + blockIndex * 4;
+        for(size_t rowIndex = 0; rowIndex < rowNum; rowIndex++) {
+            memcpy(dst, rowSrc, blockSize);
+            dst += 4;
+            rowSrc += rowSize;
+        }
+    }
+
+    delete[] tmpBuffer;
+}
+
+void KleidiAI::packNC4HW4ToNCHW(float* data, size_t rowNum, size_t rowSize) {
+    if(rowNum == 1) {
+        return;
+    }
+
+    const size_t tmp_size = rowNum * rowSize * sizeof(float);
+    uint8_t *tmpBuffer = new uint8_t[tmp_size];
+    memcpy(tmpBuffer, data, tmp_size);
+
+    const float *src = (const float *)tmpBuffer;
+    float *dst = (float *)data;
+
+    size_t blockNum = rowSize / 4;
+    size_t blockSize = 4 * sizeof(float);
+
+    for(size_t blockIndex = 0; blockIndex < blockNum; blockIndex++) {
+        const float *rowSrc = src + blockIndex * 4 * rowNum;
+        float *block_dst = dst + blockIndex * 4;
+        for(size_t rowIndex = 0; rowIndex < rowNum; rowIndex++) {
+            memcpy(block_dst, rowSrc, blockSize);
+            block_dst += rowSize;
+            rowSrc += 4;
+        }
+    }
+
+    delete[] tmpBuffer;
+}
+
+//Lhs
+size_t KleidiAI::getLhsQuantedPackedSize(size_t m, size_t k) {
+    return kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(m, k, getMr(m), getKr(), getSr());
+}
+
+size_t KleidiAI::getLhsQuantedPackedOffset(size_t m, size_t mIdx, size_t k) {
+    return mIdx == 0 ? 0 : kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(mIdx, k, getMr(m), getKr(), getSr());
+}
+
+void KleidiAI::runLhsQuantPack(size_t m, size_t k, const void* lhs, void* lhsQuantedPacked) {
+    kai_run_lhs_quant_pack_qai8dxp_f32(m, k, getMr(m), getKr(), getSr(), 0, (const float *)lhs, k * sizeof(float), lhsQuantedPacked);
+}
+
+//Rhs
+size_t KleidiAI::getRhsPackedSize(size_t n, size_t k) {
+    return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(n, k, getNr(), getKr(), getSr());
+}
+
+size_t KleidiAI::getRhsPackedOffset(size_t nIdx, size_t k) {
+    return kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(nIdx, k, getNr(), getKr(), getSr());
+}
+
+void KleidiAI::runRhsPack(size_t n, size_t k, const void* rhs, const void* scale, const void *bias, void* rhsPacked, bool packedInt4) {
+    struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params params;
+    params.lhs_zero_point = 1;
+    params.rhs_zero_point = 8;
+    if(!packedInt4) {
+        packQsi4cxpQsi8cxs1s0(1, n, k, getNr(), getKr(), getSr(),
+                              (const uint8_t *)rhs,
+                              (const float *)bias, (const float *)scale,
+                              rhsPacked,
+                              0, &params);
+    } else {
+        kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(1, n, k, getNr(), getKr(), getSr(),
+                                                (const uint8_t *)rhs,
+                                                (const float *)bias, (const float *)scale,
+                                                rhsPacked,
+                                                0, &params);
+    }
+}
+
+//Matmul
+void KleidiAI::runMatmul(size_t m, size_t n, size_t k, const void* lhsPacked, const void* rhsPacked, size_t dst_stride, void* dst) {
+    if(m == 1) { //dotprod
+        kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(m, n, k,
+                                                                           (const void *)lhsPacked, (const void *)rhsPacked, (float *)dst,
+                                                                           dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
+    } else { //i8mm
+        kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(m, n, k,
+                                                                        (const void *)lhsPacked, (const void *)rhsPacked, (float *)dst,
+                                                                        dst_stride, sizeof(float), -FLT_MAX, FLT_MAX);
+    }
+}
+
+#endif // defined(__aarch64__)
\ No newline at end of file
diff --git a/source/backend/cpu/arm/kleidiAI/MNNKleidiAI.h
b/source/backend/cpu/arm/kleidiAI/MNNKleidiAI.h new file mode 100644 index 000000000..92ec40fdb --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/MNNKleidiAI.h @@ -0,0 +1,108 @@ +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include "kai_lhs_quant_pack_qai8dxp_f32.h" +#include "kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h" +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" + +#include "./kai/kai_common.h" + +namespace MNN { + class KleidiAI { + public: + static KleidiAI &getInstance(bool bAsymmetric) { + if(!instance) { + instance = new KleidiAI(bAsymmetric); + } + return *instance; + } + + static KleidiAI &getInstance() { + if(!instance) { + instance = new KleidiAI; + } + return *instance; + } + + ~KleidiAI() {} + + struct KAIInfo { + bool kaiEnable = false; + bool asymmetric = false; //Asymmetric quantized model. + bool dot = false; //CPU support sdot. + bool i8mm = false; //CPU support i8mm. + }; + + //Kai util + void packNCHWToNC4HW4(float* data, size_t rowNum, size_t rowSize); + void packNC4HW4ToNCHW(float* data, size_t rowNum, size_t rowSize); + + //Set info + void setEnable(bool enable) { mKAIInfo.kaiEnable = enable; } + void setModelAsymmetric(bool bAsymmetric) { mKAIInfo.asymmetric = bAsymmetric; } + + //Check + bool canAccelerate() { return (mKAIInfo.kaiEnable && mKAIInfo.dot && mKAIInfo.i8mm && !mKAIInfo.asymmetric); } + + //Get info + size_t getMr(size_t m = 1) { return (m == 1) ? mKaiMrDotprod : mKaiMrI8mm; } + size_t getNr() { return mKaiNr; } + size_t getKr() { return mKaiKr; } + size_t getSr() { return mKaiSr; } + size_t getMStep(size_t m = 1) { return (m == 1) ? 
mKaiMstepDotprod : mKaiMstepI8mm; } + size_t getNStep() { return mKaiNStep; } + size_t getVecNumPerThread(size_t totalVec, size_t totalThread, size_t minStep) { return kai_roundup(totalVec / totalThread, minStep); } + + //Lhs + size_t getLhsQuantedPackedSize(size_t m, size_t k); + size_t getLhsQuantedPackedOffset(size_t m, size_t mIdx, size_t k); + void runLhsQuantPack(size_t m, size_t k, const void* lhs, void* lhsQuantedPacked); + + //Rhs + size_t getRhsPackedSize(size_t n, size_t k); + size_t getRhsPackedOffset(size_t nIdx, size_t k); + void runRhsPack(size_t n, size_t k, const void* rhs, const void* scale, const void *bias, void* rhsPacked, bool packedInt4 = false); + + //Dst + size_t getDstOffset(size_t mIdx, size_t nIdx, size_t n) { return (nIdx * sizeof(float)) + mIdx * (n * sizeof(float)); } + + //Matmul + void runMatmul(size_t m, size_t n, size_t k, const void* lhsPacked, const void* rhsPacked, size_t dst_stride, void* dst); + + private: + KleidiAI(bool bAsymmetric = false) { + const MNNCPUInfo& gCPUInfo = *MNNGetCPUInfo(); + mKAIInfo.dot = gCPUInfo.dot; + mKAIInfo.i8mm = gCPUInfo.i8mm; + mKAIInfo.kaiEnable = true; + mKAIInfo.asymmetric = bAsymmetric; + } + + static KleidiAI *instance; + KAIInfo mKAIInfo; + + const size_t mKaiMstepDotprod = 1; + const size_t mKaiMstepI8mm = 8; + const size_t mKaiNStep = 4; + + const size_t mKaiMrDotprod = 1; + const size_t mKaiMrI8mm = 4; + const size_t mKaiNr = 4; + const size_t mKaiKr = 16; + const size_t mKaiSr = 2; + }; +} \ No newline at end of file diff --git a/source/backend/cpu/arm/kleidiAI/kai/kai_common.h b/source/backend/cpu/arm/kleidiAI/kai/kai_common.h new file mode 100644 index 000000000..e8765cf38 --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/kai/kai_common.h @@ -0,0 +1,108 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include +#include +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +// NOLINTBEGIN(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c) +// +// * cppcoreguidelines-avoid-do-while: do-while is necessary for macros. +// * cppcoreguidelines-pro-type-vararg: use of variadic arguments in fprintf is expected. +// * cert-err33-c: checking the output of fflush and fprintf is not necessary for error reporting. +#define KAI_ERROR(msg) \ + do { \ + fflush(stdout); \ + fprintf(stderr, "%s", msg); \ + exit(EXIT_FAILURE); \ + } while (0) + +#define KAI_ASSERT_MSG(cond, msg) \ + do { \ + if (!(cond)) { \ + KAI_ERROR(msg); \ + } \ + } while (0) + +// NOLINTEND(cppcoreguidelines-avoid-do-while,cppcoreguidelines-pro-type-vararg,cert-err33-c) + +#define KAI_ASSERT(cond) KAI_ASSERT_MSG(cond, #cond) + +#define KAI_ASSERT_IF_MSG(precond, cond, msg) KAI_ASSERT_MSG(!(precond) || (cond), msg) +#define KAI_ASSERT_IF(precond, cond) KAI_ASSERT_IF_MSG(precond, cond, #precond " |-> " #cond) + +#define KAI_ASSUME_MSG KAI_ASSERT_MSG +#define KAI_ASSUME KAI_ASSERT +#define KAI_ASSUME_IF_MSG KAI_ASSERT_IF_MSG +#define KAI_ASSUME_IF KAI_ASSERT_IF + +#define KAI_UNUSED(x) (void)(x) +#define KAI_MIN(a, b) (((a) < (b)) ? (a) : (b)) +#define KAI_MAX(a, b) (((a) > (b)) ? 
(a) : (b)) + +/// KleidiAI data types +/// Format: (reserved)|(num-bytes)|(type)|(variant-type) +enum kai_datatype { + Unknown = 0x0000, + F32 = 0x0411, + F16 = 0x0212, + Bf16 = 0x0213, + Int32 = 0x0421, + Int16 = 0x0222, + Int8 = 0x0124, + Uint32 = 0x0431, + Uint16 = 0x0232, + Uint8 = 0x0134, + Bool = 0x0441 +}; + +/// Gets number of bytes for a given data type +/// @param[in] dt KleidiAI data type +/// +/// @return the numbers of bytes for the data type +inline static size_t kai_num_bytes_datatype(enum kai_datatype dt) { + return (size_t)(dt >> 8); +} + +/// Converts a scalar f16 value to f32 +/// @param[in] f16 The f16 value +/// +/// @return the f32 value +inline static float kai_f16_to_f32(uint16_t f16) { +#if defined(__ARM_NEON) + __fp16 f32 = 0; + memcpy(&f32, &f16, sizeof(uint16_t)); + return (float)f32; +#endif +} + +/// Converts a scalar f32 value to f16 +/// @param[in] f32 The f32 value +/// +/// @return the f16 value +inline static uint16_t kai_f32_to_f16(float f32) { +#if defined(__ARM_NEON) + uint16_t f16 = 0; + __fp16 tmp = f32; + memcpy(&f16, &tmp, sizeof(uint16_t)); + return f16; +#endif +} + +inline static size_t kai_roundup(size_t a, size_t b) { + return ((a + b - 1) / b) * b; +} + +#ifdef __cplusplus +} +#endif diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c new file mode 100644 index 000000000..8f1479dcd --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c @@ -0,0 +1,228 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__ARM_FEATURE_DOTPROD) +#error "Dotprod extension required to compile this micro-kernel" +#else +#include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_m_step = 1; +static const size_t kai_n_step = 4; +static const size_t kai_mr = 1; +static const size_t kai_nr = 4; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); +static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_bias = sizeof(float); + +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + +inline static size_t kai_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +size_t 
kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( + size_t m, size_t n, size_t k, const void* restrict lhs_packed, const void* restrict rhs_packed, float* restrict dst, + size_t dst_stride_row, size_t dst_stride_col, float scalar_min, float scalar_max) { + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + const size_t kai_k0 = kai_kr * kai_sr; + + const size_t num_rows = m; + const size_t num_cols = n; + + const size_t lhs_packed_stride = kai_lhs_packed_stride(k); + const size_t k_internal = kai_k_roundedup(k); + + const int8x16_t nibble_mask = vdupq_n_s8(0xF0); + + const uint8_t* lhs_ptr_start = lhs_packed; + + for (size_t row_idx = 0; row_idx < num_rows; row_idx += kai_mr) { + const uint8_t* rhs_ptr = rhs_packed; + for (size_t col_idx = 0; col_idx < num_cols; col_idx += kai_nr) { + const uint8_t* lhs_ptr = lhs_ptr_start; + + // Main f32 accumulator + int32x4_t iacc0011 = vdupq_n_s32(0); + int32x4_t iacc2233 = vdupq_n_s32(0); + + for (size_t b = 0; b < k_internal; b += kai_k0) { + // Set up RHS + const int8x16_t rhs_raw_vec_0 = vld1q_s8((const int8_t*)(rhs_ptr + 0)); + const int8x16_t rhs_raw_vec_1 = vld1q_s8((const int8_t*)(rhs_ptr + 16)); + const int8x16_t rhs_raw_vec_2 = vld1q_s8((const int8_t*)(rhs_ptr + 32)); + const int8x16_t rhs_raw_vec_3 = vld1q_s8((const int8_t*)(rhs_ptr + 48)); + + // Low nibble + const int8x16_t rhs_vec_0_0 = vshlq_n_s8(rhs_raw_vec_0, 4); + const int8x16_t rhs_vec_1_0 = vshlq_n_s8(rhs_raw_vec_1, 4); + const int8x16_t rhs_vec_2_0 = vshlq_n_s8(rhs_raw_vec_2, 4); + const int8x16_t rhs_vec_3_0 = vshlq_n_s8(rhs_raw_vec_3, 4); + + // High nibble + const int8x16_t rhs_vec_0_1 = vandq_s8(rhs_raw_vec_0, nibble_mask); + const int8x16_t rhs_vec_1_1 = vandq_s8(rhs_raw_vec_1, nibble_mask); + const int8x16_t rhs_vec_2_1 = vandq_s8(rhs_raw_vec_2, nibble_mask); + const int8x16_t rhs_vec_3_1 = vandq_s8(rhs_raw_vec_3, nibble_mask); + + const int8x16_t lhs_vec_0 = vld1q_s8((const int8_t*)(lhs_ptr + 0)); + const int8x16_t 
lhs_vec_1 = vld1q_s8((const int8_t*)(lhs_ptr + 16)); + + lhs_ptr += 32; + rhs_ptr += 64; + + int8x16_t t; + + t = vcombine_s8(vget_low_s8(lhs_vec_0), vget_low_s8(lhs_vec_0)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_0_0, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_1_0, t); + t = vcombine_s8(vget_high_s8(lhs_vec_0), vget_high_s8(lhs_vec_0)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_2_0, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_3_0, t); + t = vcombine_s8(vget_low_s8(lhs_vec_1), vget_low_s8(lhs_vec_1)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_0_1, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_1_1, t); + t = vcombine_s8(vget_high_s8(lhs_vec_1), vget_high_s8(lhs_vec_1)); + iacc0011 = vdotq_s32(iacc0011, rhs_vec_2_1, t); + iacc2233 = vdotq_s32(iacc2233, rhs_vec_3_1, t); + } + + int32x4_t iacc = vpaddq_s32(iacc0011, iacc2233); + + // LHS offset + const int32x4_t lhs_offset = vld1q_dup_s32((const int32_t*)lhs_ptr); + lhs_ptr += sizeof(int32_t); + + // LHS scale + const float32x4_t lhs_scale = vld1q_dup_f32((const float*)lhs_ptr); + lhs_ptr += sizeof(float); + + // RHS sum values + const int32x4_t sum_n_s32 = vld1q_s32((const int32_t*)(rhs_ptr)); + rhs_ptr += sizeof(int32x4_t); + + // RHS scale + const float32x4_t rhs_scale = vld1q_f32((const float*)rhs_ptr); + rhs_ptr += sizeof(float32x4_t); + + // Load the bias + const float32x4_t bias0 = vld1q_f32((const float*)rhs_ptr); + rhs_ptr += sizeof(float32x4_t); + + // Add the reduction sum + iacc = vmlaq_s32(iacc, sum_n_s32, lhs_offset); + + float32x4_t main_acc = vmulq_f32(vcvtq_f32_s32(iacc), rhs_scale); + + main_acc = vmulq_f32(main_acc, lhs_scale); + + // Add the bias + main_acc = vaddq_f32(main_acc, bias0); + + // clamp (min-max) operation + const float32x4_t vmin_f32 = vdupq_n_f32(scalar_min); + const float32x4_t vmax_f32 = vdupq_n_f32(scalar_max); + + main_acc = vmaxq_f32(main_acc, vmin_f32); + main_acc = vminq_f32(main_acc, vmax_f32); + + if (col_idx + kai_nr <= n) { + vst1q_f32((float*)((uint8_t*)dst + col_idx * sizeof(float) + row_idx * dst_stride_row), main_acc); + } else { + size_t leftover = n % kai_nr; + *(float*)((uint8_t*)dst + (col_idx + 0) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc, 0); + if (leftover > 1) { + *(float*)((uint8_t*)dst + (col_idx + 1) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc, 1); + } + if (leftover > 2) { + *(float*)((uint8_t*)dst + (col_idx + 2) * sizeof(float) + row_idx * dst_stride_row) = + vgetq_lane_f32(main_acc, 2); + } + } + } + lhs_ptr_start += lhs_packed_stride; + } +} +#endif // Architectural feature check diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h new file mode 100644 index 000000000..b6d2cf87e --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h @@ -0,0 +1,125 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix +/// -# kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0 OR kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 
to pack the RHS matrix + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); + +/// Gets the mr value, which must be used to pack the LHS matrix with +/// the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); + +/// Gets the nr value, which must be used to pack the RHS matrix with +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); + +/// Gets the kr value, which must be used to pack the RHS matrix with +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); + +/// Gets the sr value, which must be used to pack the RHS matrix with +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m_idx, size_t k); + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t n_idx, size_t k); + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. +/// @param[in] n_idx Column index in the DST matrix. It must be multiple of 4. +/// @param[in] dst_stride The number of bytes in in each row of the DST matrix +/// +/// @return the destination offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( + size_t m_idx, size_t n_idx, size_t dst_stride); + +/// Gets the size in bytes for the destination matrix. 
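+/// The destination is a plain row-major f32 matrix with no tile padding: for
+/// illustration, m = 1 and n = 70 give 1 * 70 * sizeof(float) = 280 bytes.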
+/// +/// @param[in] m Number of rows in the destination (DST) matrix +/// @param[in] n Number of columns in the destination (DST) matrix +/// +/// @return the DST size in bytes +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(size_t m, size_t n); + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed +/// RHS matrix: Signed 4-bit quantized symmetric per-channel (qsu4cx) and packed. +/// Output tile: (rows x cols) = 1 x 4 +/// Accumulation performed in a single for loop: 64 +/// Instruction used: dotprod +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. +/// @param[in] k The number of channels. The common dimension of LHS & RHS. +/// @param[in] lhs_packed The LHS matrix packed. +/// When the activation are dynamically quantized, you can obtain this matrix +/// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs +/// both the dynamic quantization to 8-bit and activation packing in a single step. +/// @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref +/// kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 +/// @param[out] dst Result of the vector-by-matrix +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +#ifdef __cplusplus +} +#endif diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c new file mode 100644 index 000000000..50e260d89 --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c @@ -0,0 +1,507 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#if !defined(__ARM_FEATURE_MATMUL_INT8) +#error "I8mm extension required to compile this micro-kernel" +#else +#include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" + +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_m_step = 8; +static const size_t kai_n_step = 4; +static const size_t kai_mr = 4; +static const size_t kai_nr = 4; +static const size_t kai_kr = 16; +static const size_t kai_sr = 2; +static const size_t kai_num_bytes_multiplier_lhs = sizeof(float); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_offset_lhs = sizeof(int32_t); +static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_bias = sizeof(float); + +inline static size_t kai_k_roundedup(size_t k) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k 
is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kai_kr * kai_sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_mr * (k_internal * sizeof(int8_t) + kai_num_bytes_multiplier_lhs + kai_num_bytes_offset_lhs); +} + +inline static size_t kai_rhs_packed_stride(size_t k) { + const size_t k_internal = kai_k_roundedup(k); + + KAI_ASSERT((k_internal % 2) == 0); + + return kai_nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_m_step; +} + +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_n_step; +} + +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_mr; +} + +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_nr; +} + +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_kr; +} + +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void) { + return kai_sr; +} + +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t m_idx, size_t k) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + + return (m_idx / kai_m_step) * kai_lhs_packed_stride(k); +} + +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t n_idx, size_t k) { + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx / kai_n_step) * kai_rhs_packed_stride(k); +} + +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride) { + KAI_ASSERT((m_idx % kai_m_step) == 0); + KAI_ASSERT((n_idx % kai_n_step) == 0); + + return (n_idx * sizeof(float)) + m_idx * dst_stride; +} + +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t m, size_t n) { + return m * n * sizeof(float); +} + +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max) { + KAI_ASSERT(dst_stride_col == sizeof(float)); + + if (m == 0) { + return; + } + + const size_t k_internal = kai_k_roundedup(k); + + size_t num_blocks = k_internal / 32; + + float clamp_vals[2] = {scalar_min, scalar_max}; + + __asm__ __volatile__( + "mov x12, %x[m]\n" + "mov x11, #0x80\n" + "movi v11.16b, #0xf0\n" + "mov x20, #0x20\n" + "cmp x12, #0x8\n" + "madd x11, %x[num_blocks], x11, x20\n" + "blt 10f\n" + "1:" // Row loop + "mov x10, %x[rhs_packed]\n" + "mov x9, %x[n]\n" + "add x28, %x[dst], %x[dst_stride_row], LSL #3\n" + "2:" // Column loop + "mov x22, %x[lhs_packed]\n" + "movi v10.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "mov x21, %x[num_blocks]\n" + "movi v8.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "movi v6.4s, #0x0\n" + "movi v5.4s, #0x0\n" + "add x20, x22, x11\n" + "movi v4.4s, #0x0\n" + "movi v3.4s, #0x0\n" + "3:" // Sub block loop + "ldr q2, [x10, #0x0]\n" + "ldr q1, [x10, #0x10]\n" + "subs x21, x21, #0x1\n" + "ldr q20, [x22, #0x0]\n" + "ldr q19, [x22, #0x10]\n" + "ldr q18, [x20, #0x0]\n" + "ldr q0, [x20, #0x10]\n" + "ldr q31, [x10, #0x20]\n" + "ldr q30, [x10, #0x30]\n" + "shl v17.16b, v2.16b, #0x4\n" + "shl 
v16.16b, v1.16b, #0x4\n" + "ldr q29, [x22, #0x20]\n" + "ldr q28, [x22, #0x30]\n" + "and v2.16b, v2.16b, v11.16b\n" + "and v1.16b, v1.16b, v11.16b\n" + "ldr q27, [x20, #0x20]\n" + "ldr q26, [x20, #0x30]\n" + "add x10, x10, #0x40\n" + "ldr q25, [x22, #0x40]\n" + "ldr q24, [x22, #0x50]\n" + ".inst 0x4e91a68a // smmla v10.4s, v20.16b, v17.16b\n" + ".inst 0x4e90a689 // smmla v9.4s, v20.16b, v16.16b\n" + "ldr q23, [x20, #0x40]\n" + "ldr q22, [x20, #0x50]\n" + ".inst 0x4e91a668 // smmla v8.4s, v19.16b, v17.16b\n" + ".inst 0x4e90a667 // smmla v7.4s, v19.16b, v16.16b\n" + "ldr q21, [x22, #0x60]\n" + "ldr q20, [x22, #0x70]\n" + ".inst 0x4e91a646 // smmla v6.4s, v18.16b, v17.16b\n" + ".inst 0x4e90a645 // smmla v5.4s, v18.16b, v16.16b\n" + "ldr q19, [x20, #0x60]\n" + "ldr q18, [x20, #0x70]\n" + ".inst 0x4e91a404 // smmla v4.4s, v0.16b, v17.16b\n" + ".inst 0x4e90a403 // smmla v3.4s, v0.16b, v16.16b\n" + "shl v17.16b, v31.16b, #0x4\n" + "shl v16.16b, v30.16b, #0x4\n" + "add x22, x22, #0x80\n" + "add x20, x20, #0x80\n" + "and v31.16b, v31.16b, v11.16b\n" + "and v30.16b, v30.16b, v11.16b\n" + ".inst 0x4e91a7aa // smmla v10.4s, v29.16b, v17.16b\n" + ".inst 0x4e90a7a9 // smmla v9.4s, v29.16b, v16.16b\n" + ".inst 0x4e91a788 // smmla v8.4s, v28.16b, v17.16b\n" + ".inst 0x4e90a787 // smmla v7.4s, v28.16b, v16.16b\n" + ".inst 0x4e91a766 // smmla v6.4s, v27.16b, v17.16b\n" + ".inst 0x4e90a765 // smmla v5.4s, v27.16b, v16.16b\n" + ".inst 0x4e91a744 // smmla v4.4s, v26.16b, v17.16b\n" + ".inst 0x4e90a743 // smmla v3.4s, v26.16b, v16.16b\n" + ".inst 0x4e82a72a // smmla v10.4s, v25.16b, v2.16b\n" + ".inst 0x4e81a729 // smmla v9.4s, v25.16b, v1.16b\n" + ".inst 0x4e82a708 // smmla v8.4s, v24.16b, v2.16b\n" + ".inst 0x4e81a707 // smmla v7.4s, v24.16b, v1.16b\n" + ".inst 0x4e82a6e6 // smmla v6.4s, v23.16b, v2.16b\n" + ".inst 0x4e81a6e5 // smmla v5.4s, v23.16b, v1.16b\n" + ".inst 0x4e82a6c4 // smmla v4.4s, v22.16b, v2.16b\n" + ".inst 0x4e81a6c3 // smmla v3.4s, v22.16b, v1.16b\n" + ".inst 0x4e9fa6aa // smmla v10.4s, v21.16b, v31.16b\n" + ".inst 0x4e9ea6a9 // smmla v9.4s, v21.16b, v30.16b\n" + ".inst 0x4e9fa688 // smmla v8.4s, v20.16b, v31.16b\n" + ".inst 0x4e9ea687 // smmla v7.4s, v20.16b, v30.16b\n" + ".inst 0x4e9fa666 // smmla v6.4s, v19.16b, v31.16b\n" + ".inst 0x4e9ea665 // smmla v5.4s, v19.16b, v30.16b\n" + ".inst 0x4e9fa644 // smmla v4.4s, v18.16b, v31.16b\n" + ".inst 0x4e9ea643 // smmla v3.4s, v18.16b, v30.16b\n" + "bgt 3b\n" + "ldr q25, [x10, #0x0]\n" + "ld1 { v17.4s }, [x22]\n" + "uzp1 v23.2d, v10.2d, v9.2d\n" + "uzp2 v22.2d, v10.2d, v9.2d\n" + "ldr q24, [x10, #0x10]\n" + "uzp1 v21.2d, v8.2d, v7.2d\n" + "uzp2 v20.2d, v8.2d, v7.2d\n" + "add x22, x22, #0x10\n" + "ldr q16, [x22, #0x0]\n" + "add x10, x10, #0x20\n" + "mla v23.4s, v25.4s, v17.s[0]\n" + "mla v22.4s, v25.4s, v17.s[1]\n" + "mla v21.4s, v25.4s, v17.s[2]\n" + "mla v20.4s, v25.4s, v17.s[3]\n" + "fmul v19.4s, v24.4s, v16.s[0]\n" + "fmul v18.4s, v24.4s, v16.s[1]\n" + "fmul v17.4s, v24.4s, v16.s[2]\n" + "fmul v16.4s, v24.4s, v16.s[3]\n" + "scvtf v23.4s, v23.4s\n" + "scvtf v22.4s, v22.4s\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v20.4s, v20.4s\n" + "fmul v10.4s, v23.4s, v19.4s\n" + "fmul v9.4s, v22.4s, v18.4s\n" + "fmul v8.4s, v21.4s, v17.4s\n" + "fmul v7.4s, v20.4s, v16.4s\n" + "ld1 { v17.4s }, [x20]\n" + "uzp1 v23.2d, v6.2d, v5.2d\n" + "uzp2 v22.2d, v6.2d, v5.2d\n" + "add x20, x20, #0x10\n" + "ldr q16, [x20, #0x0]\n" + "uzp1 v21.2d, v4.2d, v3.2d\n" + "uzp2 v20.2d, v4.2d, v3.2d\n" + "mla v23.4s, v25.4s, v17.s[0]\n" + "mla v22.4s, v25.4s, v17.s[1]\n" + "mla v21.4s, 
v25.4s, v17.s[2]\n" + "mla v20.4s, v25.4s, v17.s[3]\n" + "fmul v19.4s, v24.4s, v16.s[0]\n" + "fmul v18.4s, v24.4s, v16.s[1]\n" + "fmul v17.4s, v24.4s, v16.s[2]\n" + "scvtf v23.4s, v23.4s\n" + "fmul v16.4s, v24.4s, v16.s[3]\n" + "scvtf v22.4s, v22.4s\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v20.4s, v20.4s\n" + "fmul v6.4s, v23.4s, v19.4s\n" + "fmul v5.4s, v22.4s, v18.4s\n" + "fmul v4.4s, v21.4s, v17.4s\n" + "fmul v3.4s, v20.4s, v16.4s\n" + "ldr q18, [x10, #0x0]\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x9, #0x4\n" + "ld1r { v16.4s }, [x20]\n" + "add x10, x10, #0x10\n" + "fadd v10.4s, v10.4s, v18.4s\n" + "fadd v9.4s, v9.4s, v18.4s\n" + "fadd v8.4s, v8.4s, v18.4s\n" + "fadd v7.4s, v7.4s, v18.4s\n" + "fadd v6.4s, v6.4s, v18.4s\n" + "fadd v5.4s, v5.4s, v18.4s\n" + "fadd v4.4s, v4.4s, v18.4s\n" + "fadd v3.4s, v3.4s, v18.4s\n" + "fmax v10.4s, v10.4s, v17.4s\n" + "fmax v9.4s, v9.4s, v17.4s\n" + "fmax v8.4s, v8.4s, v17.4s\n" + "fmax v7.4s, v7.4s, v17.4s\n" + "fmax v6.4s, v6.4s, v17.4s\n" + "fmax v5.4s, v5.4s, v17.4s\n" + "fmax v4.4s, v4.4s, v17.4s\n" + "fmax v3.4s, v3.4s, v17.4s\n" + "fmin v10.4s, v10.4s, v16.4s\n" + "fmin v9.4s, v9.4s, v16.4s\n" + "fmin v8.4s, v8.4s, v16.4s\n" + "fmin v7.4s, v7.4s, v16.4s\n" + "fmin v6.4s, v6.4s, v16.4s\n" + "fmin v5.4s, v5.4s, v16.4s\n" + "fmin v4.4s, v4.4s, v16.4s\n" + "fmin v3.4s, v3.4s, v16.4s\n" + "blt 6f\n" + "mov x20, %x[dst]\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q9, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q8, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q7, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q6, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q5, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q4, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "str q3, [x20, #0x0]\n" + "b 9f\n" + "6:" // Partial output + "mov x27, %x[dst]\n" + "add x26, x27, %x[dst_stride_row], LSL #2\n" + "add x25, x26, %x[dst_stride_row], LSL #1\n" + "add x24, x26, %x[dst_stride_row]\n" + "add x23, x25, %x[dst_stride_row]\n" + "add x22, x27, %x[dst_stride_row], LSL #1\n" + "add x21, x27, %x[dst_stride_row]\n" + "add x20, x22, %x[dst_stride_row]\n" + "tbz x9, #1, 7f\n" + "st1 { v3.d }[0], [x23], #0x8\n" + "st1 { v4.d }[0], [x25], #0x8\n" + "st1 { v5.d }[0], [x24], #0x8\n" + "st1 { v6.d }[0], [x26], #0x8\n" + "st1 { v7.d }[0], [x20], #0x8\n" + "st1 { v8.d }[0], [x22], #0x8\n" + "st1 { v9.d }[0], [x21], #0x8\n" + "st1 { v10.d }[0], [x27], #0x8\n" + "tbz x9, #0, 8f\n" + "st1 { v3.s }[2], [x23]\n" + "st1 { v4.s }[2], [x25]\n" + "st1 { v5.s }[2], [x24]\n" + "st1 { v6.s }[2], [x26]\n" + "st1 { v7.s }[2], [x20]\n" + "st1 { v8.s }[2], [x22]\n" + "st1 { v9.s }[2], [x21]\n" + "st1 { v10.s }[2], [x27]\n" + "b 8f\n" + "7:" // Output block 0: partial_1_0 + "st1 { v3.s }[0], [x23]\n" + "st1 { v4.s }[0], [x25]\n" + "st1 { v5.s }[0], [x24]\n" + "st1 { v6.s }[0], [x26]\n" + "st1 { v7.s }[0], [x20]\n" + "st1 { v8.s }[0], [x22]\n" + "st1 { v9.s }[0], [x21]\n" + "st1 { v10.s }[0], [x27]\n" + "8:" // Output block 0: Done + "9:" // Output stage exit + "subs x9, x9, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 2b\n" + "mov x20, #0x2\n" + "sub x12, x12, #0x8\n" + "cmp x12, #0x8\n" + "mov %x[dst], x28\n" + "madd %x[lhs_packed], x20, x11, %x[lhs_packed]\n" + "bge 1b\n" + "10:" // Row loop skip + "cbz x12, 19f\n" + "11:" // Row tail: Row loop + "mov x26, %x[rhs_packed]\n" + "mov x25, %x[n]\n" + "add x24, %x[dst], %x[dst_stride_row], LSL 
#2\n" + "12:" // Row tail: Column loop + "mov x22, %x[lhs_packed]\n" + "movi v10.4s, #0x0\n" + "movi v9.4s, #0x0\n" + "mov x20, %x[num_blocks]\n" + "movi v8.4s, #0x0\n" + "movi v7.4s, #0x0\n" + "13:" // Row tail: Sub block loop + "ldr q31, [x26, #0x0]\n" + "ldr q30, [x26, #0x10]\n" + "subs x20, x20, #0x1\n" + "ldr q29, [x22, #0x0]\n" + "ldr q28, [x22, #0x10]\n" + "ldr q27, [x26, #0x20]\n" + "ldr q26, [x26, #0x30]\n" + "add x26, x26, #0x40\n" + "ldr q25, [x22, #0x20]\n" + "ldr q24, [x22, #0x30]\n" + "shl v23.16b, v31.16b, #0x4\n" + "shl v22.16b, v30.16b, #0x4\n" + "ldr q21, [x22, #0x40]\n" + "ldr q20, [x22, #0x50]\n" + "and v31.16b, v31.16b, v11.16b\n" + "and v30.16b, v30.16b, v11.16b\n" + "ldr q19, [x22, #0x60]\n" + "ldr q18, [x22, #0x70]\n" + "shl v17.16b, v27.16b, #0x4\n" + "shl v16.16b, v26.16b, #0x4\n" + ".inst 0x4e97a7aa // smmla v10.4s, v29.16b, v23.16b\n" + ".inst 0x4e96a7a9 // smmla v9.4s, v29.16b, v22.16b\n" + "and v27.16b, v27.16b, v11.16b\n" + "add x22, x22, #0x80\n" + ".inst 0x4e97a788 // smmla v8.4s, v28.16b, v23.16b\n" + ".inst 0x4e96a787 // smmla v7.4s, v28.16b, v22.16b\n" + "and v26.16b, v26.16b, v11.16b\n" + ".inst 0x4e91a72a // smmla v10.4s, v25.16b, v17.16b\n" + ".inst 0x4e90a729 // smmla v9.4s, v25.16b, v16.16b\n" + ".inst 0x4e91a708 // smmla v8.4s, v24.16b, v17.16b\n" + ".inst 0x4e90a707 // smmla v7.4s, v24.16b, v16.16b\n" + ".inst 0x4e9fa6aa // smmla v10.4s, v21.16b, v31.16b\n" + ".inst 0x4e9ea6a9 // smmla v9.4s, v21.16b, v30.16b\n" + ".inst 0x4e9fa688 // smmla v8.4s, v20.16b, v31.16b\n" + ".inst 0x4e9ea687 // smmla v7.4s, v20.16b, v30.16b\n" + ".inst 0x4e9ba66a // smmla v10.4s, v19.16b, v27.16b\n" + ".inst 0x4e9aa669 // smmla v9.4s, v19.16b, v26.16b\n" + ".inst 0x4e9ba648 // smmla v8.4s, v18.16b, v27.16b\n" + ".inst 0x4e9aa647 // smmla v7.4s, v18.16b, v26.16b\n" + "bgt 13b\n" + "ldr q18, [x26, #0x0]\n" + "ld1 { v17.4s }, [x22]\n" + "uzp1 v24.2d, v10.2d, v9.2d\n" + "uzp2 v23.2d, v10.2d, v9.2d\n" + "ldr q22, [x26, #0x10]\n" + "uzp1 v21.2d, v8.2d, v7.2d\n" + "uzp2 v20.2d, v8.2d, v7.2d\n" + "add x22, x22, #0x10\n" + "ldr q16, [x22, #0x0]\n" + "add x26, x26, #0x20\n" + "mla v24.4s, v18.4s, v17.s[0]\n" + "mla v23.4s, v18.4s, v17.s[1]\n" + "mla v21.4s, v18.4s, v17.s[2]\n" + "mla v20.4s, v18.4s, v17.s[3]\n" + "fmul v19.4s, v22.4s, v16.s[0]\n" + "fmul v18.4s, v22.4s, v16.s[1]\n" + "fmul v17.4s, v22.4s, v16.s[2]\n" + "fmul v16.4s, v22.4s, v16.s[3]\n" + "scvtf v24.4s, v24.4s\n" + "scvtf v23.4s, v23.4s\n" + "scvtf v21.4s, v21.4s\n" + "scvtf v20.4s, v20.4s\n" + "fmul v10.4s, v24.4s, v19.4s\n" + "fmul v9.4s, v23.4s, v18.4s\n" + "fmul v8.4s, v21.4s, v17.4s\n" + "fmul v7.4s, v20.4s, v16.4s\n" + "ldr q18, [x26, #0x0]\n" + "ld1r { v17.4s }, [%x[clamp_vals]]\n" + "add x20, %x[clamp_vals], #0x4\n" + "cmp x25, #0x4\n" + "ld1r { v16.4s }, [x20]\n" + "add x26, x26, #0x10\n" + "fadd v10.4s, v10.4s, v18.4s\n" + "fadd v9.4s, v9.4s, v18.4s\n" + "fadd v8.4s, v8.4s, v18.4s\n" + "fadd v7.4s, v7.4s, v18.4s\n" + "fmax v10.4s, v10.4s, v17.4s\n" + "fmax v9.4s, v9.4s, v17.4s\n" + "fmax v8.4s, v8.4s, v17.4s\n" + "fmax v7.4s, v7.4s, v17.4s\n" + "fmin v10.4s, v10.4s, v16.4s\n" + "fmin v9.4s, v9.4s, v16.4s\n" + "fmin v8.4s, v8.4s, v16.4s\n" + "fmin v7.4s, v7.4s, v16.4s\n" + "blt 15f\n" + "mov x20, %x[dst]\n" + "cmp x12, #0x1\n" + "str q10, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 18f\n" + "cmp x12, #0x2\n" + "str q9, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 18f\n" + "cmp x12, #0x3\n" + "str q8, [x20, #0x0]\n" + "add x20, x20, %x[dst_stride_row]\n" + "ble 18f\n" + 
"str q7, [x20, #0x0]\n" + "b 18f\n" + "15:" // Row tail: Partial output + "mov x23, %x[dst]\n" + "cmp x12, #0x1\n" + "add x22, x23, %x[dst_stride_row]\n" + "csel x22, x22, x23, GT\n" + "cmp x12, #0x2\n" + "add x21, x23, %x[dst_stride_row], LSL #1\n" + "csel x21, x21, x22, GT\n" + "cmp x12, #0x3\n" + "add x20, x21, %x[dst_stride_row]\n" + "csel x20, x20, x21, GT\n" + "tbz x25, #1, 16f\n" + "st1 { v7.d }[0], [x20], #0x8\n" + "st1 { v8.d }[0], [x21], #0x8\n" + "st1 { v9.d }[0], [x22], #0x8\n" + "st1 { v10.d }[0], [x23], #0x8\n" + "tbz x25, #0, 17f\n" + "st1 { v7.s }[2], [x20]\n" + "st1 { v8.s }[2], [x21]\n" + "st1 { v9.s }[2], [x22]\n" + "st1 { v10.s }[2], [x23]\n" + "b 17f\n" + "16:" // Row tail: Output block 0: partial_1_0 + "st1 { v7.s }[0], [x20]\n" + "st1 { v8.s }[0], [x21]\n" + "st1 { v9.s }[0], [x22]\n" + "st1 { v10.s }[0], [x23]\n" + "17:" // Row tail: Output block 0: Done + "18:" // Row tail: Output stage exit + "subs x25, x25, #0x4\n" + "add %x[dst], %x[dst], #0x10\n" + "bgt 12b\n" + "subs x12, x12, #0x4\n" + "add %x[lhs_packed], %x[lhs_packed], x11\n" + "mov %x[dst], x24\n" + "bgt 11b\n" + "19:" // Row tail: Row loop skip + : [dst] "+&r"(dst), [lhs_packed] "+&r"(lhs_packed) + : [clamp_vals] "r"(clamp_vals), [dst_stride_row] "r"(dst_stride_row), [m] "r"(m), [n] "r"(n), + [num_blocks] "r"(num_blocks), [rhs_packed] "r"(rhs_packed) + : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v16", "v17", "v18", + "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "x9", "x10", "x11", + "x12", "x20", "x21", "x22", "x23", "x24", "x25", "x26", "x27", "x28"); +} +#endif // Architectural feature check diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h new file mode 100644 index 000000000..c5a4553fc --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h @@ -0,0 +1,125 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Micro-kernel dependencies +/// +/// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix +/// -# kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0 OR kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 to pack the RHS matrix + +/// -------------------------------------------------- + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @return the m step value +size_t kai_get_m_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); + +/// Gets the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. 
+/// +/// @return the n step +size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); + +/// Gets the mr value, which must be used to pack the LHS matrix with +/// the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel +/// +/// @return the mr value +size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); + +/// Function to get the nr value, which must be used to pack the RHS matrix with +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel +/// +/// @return the nr value +size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); + +/// Gets the kr value, which must be used to pack the RHS matrix with +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel +/// +/// @return the kr value +size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); + +/// Gets the sr value, which must be used to pack the RHS matrix with +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel +/// +/// @return the sr value +size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed 8-bit quantized asymmetric per-row (qai8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). It must be a multiple of 8 +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t m_idx, size_t k); + +/// Gets the offset in bytes for the packed RHS matrix, +/// which contains the packed 4-bit quantized symmetric per-channel (qsu4cx) values. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of 4. +/// @param[in] k The common dimension between the LHS and RHS matrix (K). +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t n_idx, size_t k); + +/// Gets the offset in bytes for the DST matrix +/// +/// @param[in] m_idx Row index in the DST matrix. It must be a multiple of 8. +/// @param[in] n_idx Column index in the DST matrix. It must be multiple of 4. +/// @param[in] dst_stride The number of bytes in in each row of the DST matrix +/// +/// @return the DST offset in bytes +size_t kai_get_dst_offset_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( + size_t m_idx, size_t n_idx, size_t dst_stride); + +/// Gets the size in bytes for the destination matrix. +/// +/// @param[in] m Number of rows in the destination (DST) matrix. +/// @param[in] n Number of columns in the destination (DST) matrix. +/// +/// @return the destination size in bytes +size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(size_t m, size_t n); + +/// Runs the matrix multiplication (matmul) micro-kernel followed by a clamp (min-max) operation. +/// +/// LHS matrix: Signed 8-bit quantized asymmetric per-row (qai8dx) and packed +/// RHS matrix: Signed 4-bit quantized symmetric per-channel (qsu4cx) and packed. +/// Output tile: (rows x cols) = 8 x 4 +/// Accumulation performed in a single for loop: 32 +/// Instruction used: i8mm +/// +/// @param[in] m The number of output rows written. +/// @param[in] n The number of output columns written. 
+/// @param[in] k The number of channels. The common dimension of LHS & RHS. +/// @param[in] lhs_packed The LHS matrix packed. +/// When the activation are dynamically quantized, you can obtain this matrix +/// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs +/// both the dynamic quantization to 8-bit and activation packing in a single step. +/// @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref +/// kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 +/// @param[out] dst Result of the vector-by-matrix +/// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. +/// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) +/// @param[in] scalar_min Min value used to clamp the final result. +/// @param[in] scalar_max Max value used to clamp the final result. +void kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( + size_t m, size_t n, size_t k, const void* lhs_packed, const void* rhs_packed, float* dst, size_t dst_stride_row, + size_t dst_stride_col, float scalar_min, float scalar_max); + +#ifdef __cplusplus +} +#endif diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c new file mode 100644 index 000000000..20deb99d1 --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c @@ -0,0 +1,179 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_lhs_quant_pack_qai8dxp_f32.h" + +#if defined(__aarch64__) +#include +#endif +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_num_bytes_per_multiplier = sizeof(float); +static const size_t kai_num_bytes_per_offset = sizeof(int32_t); + +inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for memory alignment. 
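+    // Note (added for clarity): each packed block then holds mr interleaved rows
+    // of k_internal int8 values, followed by mr int32 zero-point offsets and
+    // mr float scales; see kai_lhs_packed_stride below.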
+ size_t kr_sr_roundedup4 = kai_roundup(kr * sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +inline static size_t kai_lhs_packed_stride(size_t k, size_t mr, size_t kr, size_t sr) { + const size_t k_internal = kai_k_roundedup(k, kr, sr); + + KAI_ASSERT((k_internal % 2) == 0); + + return mr * (k_internal * sizeof(int8_t) + kai_num_bytes_per_multiplier + kai_num_bytes_per_offset); +} + +size_t kai_get_m_step_lhs_quant_pack_qai8dxp_f32(size_t mr) { + KAI_UNUSED(mr); + return 1; +} + +size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t lhs_stride) { + return m_idx * lhs_stride; +} + +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr) { + // It always points to the beginning of the row + return (m_idx / mr) * kai_lhs_packed_stride(k, mr, kr, sr); +} + +size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(size_t m, size_t k, size_t mr, size_t kr, size_t sr) { + const size_t num_rows = kai_roundup(m, mr) / mr; + + return num_rows * kai_lhs_packed_stride(k, mr, kr, sr); +} + +void kai_run_lhs_quant_pack_qai8dxp_f32( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* restrict lhs, + size_t lhs_stride, void* restrict lhs_packed) { + KAI_ASSERT((kr % sr) == 0); + + if (m == 0) { + return; + } + + const size_t num_rows = m; + + const float* src_ptr = lhs; + + const size_t dst_stride = kai_lhs_packed_stride(k, mr, kr, sr); + const size_t k_internal = kai_k_roundedup(k, kr, sr); + const int32_t k_block_len = (int32_t)(kr / sr); + + for (size_t row_idx = 0; row_idx < num_rows; ++row_idx) { + float max0 = -FLT_MAX; + float min0 = FLT_MAX; + + // Find min/max for each channel + int32_t k_idx = 0; + +#if defined(__aarch64__) + float32x4_t vmax0 = vdupq_n_f32(-FLT_MAX); + float32x4_t vmin0 = vdupq_n_f32(FLT_MAX); + + for (; k_idx <= ((int32_t)k - 8); k_idx += 8) { + const float32x4_t src0_0 = vld1q_f32(src_ptr + 0 + (size_t)k_idx); + const float32x4_t src0_1 = vld1q_f32(src_ptr + 4 + (size_t)k_idx); + + // Calculate the max + vmax0 = vmaxq_f32(src0_0, vmax0); + vmax0 = vmaxq_f32(vmax0, src0_1); + + // Calculate the min + vmin0 = vminq_f32(src0_0, vmin0); + vmin0 = vminq_f32(vmin0, src0_1); + } + // Get the max/min + max0 = vmaxvq_f32(vmax0); + min0 = vminvq_f32(vmin0); +#endif + for (; k_idx < (int32_t)k; ++k_idx) { + const float src0_0 = *(src_ptr + (size_t)k_idx); + max0 = KAI_MAX(src0_0, max0); + min0 = KAI_MIN(src0_0, min0); + } + + // Maximum/minimum int8 values + const float qmin = (float)INT8_MIN; + const float qmax = (float)INT8_MAX; + + const float rmin0 = KAI_MIN(0.0F, min0); + const float rmax0 = KAI_MAX(0.0F, max0); + + const float scale0 = rmin0 == rmax0 ? 1.F : (qmax - qmin) / (rmax0 - rmin0); + + // Reciprocal to quantize + const float recip_scale0 = scale0 ? 1.0F / scale0 : 0.0F; + + const float descaled_min0 = rmin0 * scale0; + const float descaled_max0 = rmax0 * scale0; + + const float zero_point_from_min_error0 = qmin + descaled_min0; + const float zero_point_from_max_error0 = qmax + descaled_max0; + + float zero_point0 = + zero_point_from_min_error0 + zero_point_from_max_error0 > 0 ? 
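+            // Anchor the zero-point at qmin when the summed endpoint errors are
+            // positive, otherwise anchor it at qmax.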
qmin - descaled_min0 : qmax - descaled_max0; + + zero_point0 = KAI_MAX(zero_point0, qmin); + zero_point0 = KAI_MIN(zero_point0, qmax); + + // Round to nearest integer + const int32_t nudged_zero_point0 = (int32_t)rintf(zero_point0); + + const size_t dst_x = ((row_idx + m_idx_start) % mr); + + uint8_t* dst_ptr = (uint8_t*)lhs_packed + dst_x * k_block_len * sizeof(int8_t); + + // Quantize the channels + k_idx = 0; + for (; k_idx < (int32_t)k_internal; k_idx += k_block_len) { + for (size_t k_block_idx = 0; k_block_idx < (size_t)k_block_len; ++k_block_idx) { + // Clamp at the last valid k-index + const size_t k_idx_start = KAI_MIN((size_t)k_idx + k_block_idx, k - 1); + + const float src0_0 = *(src_ptr + k_idx_start); + + // Scale the values + int32_t v0_s32 = (int32_t)(roundf(src0_0 * scale0)); + + v0_s32 = v0_s32 + nudged_zero_point0; + v0_s32 = KAI_MAX(v0_s32, INT8_MIN); + v0_s32 = KAI_MIN(v0_s32, INT8_MAX); + *((int8_t*)(dst_ptr)) = (int8_t)v0_s32; + dst_ptr += sizeof(int8_t); + } + dst_ptr += (mr - 1) * k_block_len * sizeof(int8_t); + } + + dst_ptr = (uint8_t*)lhs_packed + mr * (k_internal * sizeof(int8_t)); + + dst_ptr += dst_x * kai_num_bytes_per_offset; + + // LHS offset at the beginning of the row + *((int32_t*)(dst_ptr)) = -nudged_zero_point0; + + // Assuming the same sizeof() for kai_num_bytes_per_offset and kai_num_bytes_per_multiplier + KAI_ASSERT(kai_num_bytes_per_offset == kai_num_bytes_per_multiplier); + + dst_ptr += mr * kai_num_bytes_per_offset; + + // Store the scale quantization params + *((float*)(dst_ptr)) = recip_scale0; + + src_ptr += (lhs_stride / sizeof(float)); + + // Move to the next row if we have interleaved all Mr rows + if ((((row_idx + 1) + m_idx_start) % mr) == 0) { + lhs_packed = (void*)((uint8_t*)lhs_packed + dst_stride); + } + } +} diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h new file mode 100644 index 000000000..acba70cd6 --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.h @@ -0,0 +1,77 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +/// Gets the m step value. +/// The micro-kernel can process any M values. However, the starting M index to +/// be processed must be a multiple of m step. +/// +/// @param[in] mr The number of M rows to interleave on the same output row. +/// +/// @return the m step value +size_t kai_get_m_step_lhs_quant_pack_qai8dxp_f32(size_t mr); + +/// Gets the offset in bytes for the LHS matrix (not packed) +/// +/// This function should be called before passing the pointer to the LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). +/// @param[in] lhs_stride The number of bytes in in each row of the LHS matrix (not packed) +/// +/// @return the offset in bytes to the LHS matrix +size_t kai_get_lhs_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t lhs_stride); + +/// Gets the offset in bytes for the packed LHS matrix, +/// which contains the packed 8-bit quantized asymmetric per-row (qa8dx) values. +/// +/// This function should be called before passing the pointer to the packed LHS matrix to the micro-kernel. +/// +/// @param[in] m_idx Row index in the LHS matrix (not packed). 
+/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the offset in bytes to the packed LHS matrix +size_t kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(size_t m_idx, size_t k, size_t mr, size_t kr, size_t sr); + +/// Gets the size in bytes for the quantized and packed LHS matrix +/// +/// @param[in] m Total number of rows in the LHS matrix (not packed). +/// @param[in] k Total number of columns in the LHS matrix (not packed). +/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the packed LHS matrix size in bytes +size_t kai_get_lhs_packed_size_lhs_quant_pack_qai8dxp_f32(size_t m, size_t k, size_t mr, size_t kr, size_t sr); + +/// Run the micro-kernel to quantize and pack the LHS matrix. +/// +/// @param[in] m The number of output rows written. +/// @param[in] k The number of channels. The common dimension of LHS & RHS. It must be multiple of 8. +/// @param[in] mr The number of M rows to interleave on the same output row. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// However, kr must be multiple of sr. +/// @param[in] m_idx_start The starting M index. +/// @param[in] lhs LHS of the vector-by-matrix. +/// @param[in] lhs_stride Stride in bytes between two rows of LHS. +/// @param[out] lhs_packed The quantized and packed LHS matrix. 
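+///
+/// A minimal usage sketch (assuming mr, kr and sr come from the matching matmul
+/// micro-kernel getters and lhs is a dense row-major f32 matrix):
+///
+///   kai_run_lhs_quant_pack_qai8dxp_f32(m, k, mr, kr, sr, 0 /* m_idx_start */,
+///                                      lhs, k * sizeof(float), lhs_packed);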
+void kai_run_lhs_quant_pack_qai8dxp_f32( + size_t m, size_t k, size_t mr, size_t kr, size_t sr, size_t m_idx_start, const float* lhs, size_t lhs_stride, + void* lhs_packed); + +#ifdef __cplusplus +} +#endif diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c new file mode 100644 index 000000000..359471879 --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c @@ -0,0 +1,171 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#include "kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h" + +#include +#include +#include + +#include "kai/kai_common.h" + +static const size_t kai_num_bytes_sum_rhs = sizeof(int32_t); +static const size_t kai_num_bytes_multiplier_rhs = sizeof(float); +static const size_t kai_num_bytes_bias = sizeof(float); + +inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { + // Since we pack a float and int32 value at the end of the row, + // we must make sure that k is a multiple of 4 for alignment + size_t kr_sr_roundedup4 = kai_roundup(kr * sr, 4); + return kai_roundup(k, kr_sr_roundedup4); +} + +size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t nr) { + return nr; +} + +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride) { + return n_idx * rhs_stride; +} + +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr) { + const size_t k_internal = kai_k_roundedup(k, kr, sr); + + KAI_ASSERT((k_internal % 2) == 0); + + return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); +} + +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { + KAI_ASSERT((n_idx % nr) == 0); + + return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); +} + +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { + const size_t num_rows = kai_roundup(n, nr) / nr; + + return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); +} + +void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, + const float* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params* params) { + KAI_ASSERT(num_groups == 1); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT((kr % sr) == 0); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->rhs_zero_point == 8); + KAI_ASSERT(params->lhs_zero_point == 1); + + const size_t rhs_zero_point = params->rhs_zero_point; + const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + const size_t k_internal = kai_k_roundedup(k, kr, sr); + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); + const size_t block_length_in_bytes = kr / sr; + const size_t k_interleaved_v = 16U; + const size_t rhs_stride = kai_roundup(k, 2) / 2; + + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; 
++dst_row_idx) { + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; + + int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); + + // Initialize to zero the RHS reduction sums + memset(sums, 0, nr * sizeof(int32_t)); + + for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { + const size_t block_idx = dst_byte_idx / block_length_in_bytes; + const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; + const size_t super_block_idx = block_idx / nr; + const size_t nr_idx = block_idx % nr; + + const size_t k_adjustment = + ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; + const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; + const size_t k1_idx = k0_idx + k_interleaved_v; + const size_t n0_idx = dst_row_idx * nr + nr_idx; + + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + + const size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride; + const size_t src_addr_byte1 = (k1_idx / 2) + n0_valid_idx * rhs_stride; + + uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; + uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } + + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } + + // The following operations where we extract the values from the bytes + // can be also written in the following and less efficient manner: + /* + uint8_t src_x0_lo = 0; + uint8_t src_x0_hi = 0; + + if ((k0_idx % 2) == 0) { + src_x0_lo = (byte0 & 0x0F); + } else { + src_x0_lo = (byte0 >> 4); + } + + if ((k1_idx % 2) == 0) { + src_x0_hi = (byte1 & 0x0F); + } else { + src_x0_hi = (byte1 >> 4); + } + */ + const size_t shift_right_x0 = (k0_idx % 2) * 4; + const size_t shift_right_x1 = (k1_idx % 2) * 4; + + const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; + const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; + + sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point; + + const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + + *dst_row = dst_qs0 ^ 0x88; + dst_row += sizeof(uint8_t); + } + + // Adjust the reduction sums + for (size_t i = 0; i < nr; ++i) { + sums[i] = sums[i] * 16; + dst_row += sizeof(int32_t); + } + + // Adjust the scales + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; + dst_row += sizeof(float); + } + + // Set the bias + if (bias == NULL) { + memset(dst_row, 0, nr * kai_num_bytes_bias); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + ((float*)dst_row)[i] = bias[src_row_idx]; + } + } + } +} diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h new file mode 100644 index 000000000..4fc97ba96 --- /dev/null +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h @@ -0,0 +1,111 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// +#pragma once + +#include +#include + +#ifdef __cplusplus +extern "C" { +#endif + +struct 
kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params { + int8_t lhs_zero_point; + uint8_t rhs_zero_point; +}; + +/// Get the n step value. +/// The micro-kernel can process any N values. However, the starting N index to +/// be processed must be a multiple of n step. +/// +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// +/// @return the n step value +size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t nr); + +/// Gets the offset in bytes for the RHS matrix (not packed). +/// +/// @note The int4 values are stored in a N x K matrix. Two int4 values are stored in one byte. +/// The lower order part of the byte (low) holds the first nibble (K-index + 0). +/// The higher order of the byte holds the second nibble (K-index + 1). +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] rhs_stride The number of bytes in in each row of the RHS matrix (not packed) +/// +/// @return the offset in bytes to the RHS matrix (not packed) +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride); + +/// Get the row stride in bytes to the packed RHS matrix +/// +/// @param[in] k In the RHS matrix (not packed), K is the number of columns. +/// @param[in] nr The number of columns written by the matmul micro-kernel. +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the stride in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr); + +/// Gets the offset in bytes for the packed RHS matrix, which contains the packed 4-bit quantized symmetric per-channel +/// (qsu4cx) values. +/// +/// @param[in] n_idx Row index in the RHS matrix (not packed). It must be a multiple of n_step. +/// @param[in] k In the RHS matrix (not packed), K is the number of columns. +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the offset in bytes to the packed RHS matrix +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( + size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr); + +/// @brief Gets the size in bytes for the packed RHS matrix +/// +/// @param[in] n The number of rows in the RHS matrix (not packed) +/// @param[in] k The number of columns in the RHS matrix (not packed). +/// @param[in] nr The number of columns written by the matmul micro-kernel +/// @param[in] kr The number of columns loaded in the single inner most loop of the matmul micro-kernel. +/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. +/// +/// @return the packed RHS matrix size in bytes +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr); + +/// Run the micro-kernel to pack the RHS matrix. +/// +/// @note The int4 values are stored in a N x K matrix. Two int4 values are stored in one byte. +/// The lower order part of the byte (low) holds the first nibble (K-index + 0). +/// The higher order of the byte holds the second nibble (K-index + 1). +/// +/// @param[in] num_groups The number of groups. It must be 1. +/// @param[in] n The number of rows. 
+/// @param[in] k The common dimension between the LHS and RHS matrix (K). It must be an even value.
+/// @param[in] nr The number of N rows to interleave on the same output row.
+/// @param[in] kr The number of K values loaded in the single inner most loop of the matmul micro-kernel.
+/// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr.
+/// However, kr must be a multiple of sr.
+/// @param[in] rhs The RHS matrix containing the 4-bit values.
+/// Size in bytes is expected to be greater than or equal to (n * k) / 2, i.e. two int4 values per byte.
+/// @param[in] bias The biases.
+/// @param[in] scale The scale for each output channel.
+/// @param[out] rhs_packed The packed RHS matrix.
+/// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix.
+/// @param[in] params Parameters for the micro-kernel.
+void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(
+    size_t num_groups,  //
+    size_t n,           //
+    size_t k,           //
+    size_t nr,          //
+    size_t kr,          //
+    size_t sr,          //
+    const uint8_t* rhs,  //
+    const float* bias,   //
+    const float* scale,  //
+    void* rhs_packed,    //
+    size_t extra_bytes,  //
+    const struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params* params);
+
+#ifdef __cplusplus
+}
+#endif
diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp
index bcba4eedb..e7604fed9 100644
--- a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp
+++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp
@@ -660,6 +660,53 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu
     auto core = static_cast(backend())->int8Functions();
     auto gcore = static_cast(backend())->functions();
 
+#if MNN_KLEIDIAI_ENABLED
+    KleidiAI& kai = KleidiAI::getInstance();
+    if(mDynamicQuantExe) {
+        if(mResource->mDequantize.bits == 4 && kai.canAccelerate()) {
+            const size_t m = input->batch(); //lhs vector number.
+            const size_t n = output->channel(); //rhs vector number.
+            const size_t k = input->channel(); //vector size.
+
+            auto lhs = reinterpret_cast(input->host());
+            auto lhsPacked = mTempIm2ColBuffer->host();
+            auto rhsPacked = mResourceInt8->mWeightInt8->host();
+            auto dst = reinterpret_cast(output->host());
+
+#if !KAI_CONV_NCHW_IN_OUT
+            kai.packNC4HW4ToNCHW((float *)lhs, m, k);
+#endif
+            auto BatchDynamicQuant = [=]() {
+                KleidiAI& kai = KleidiAI::getInstance();
+                kai.runLhsQuantPack(m, k, lhs, lhsPacked);
+            };
+
+            BatchDynamicQuant();
+
+            int nPerThread = kai.getVecNumPerThread(n, static_cast(backend())->threadNumber(), kai.getNStep());
+            int threadNeed = n % nPerThread == 0 ? n / nPerThread : (n / nPerThread + 1);
+
+            auto ThreadFunction = [=](int tId) {
+                KleidiAI& kai = KleidiAI::getInstance();
+                auto threadRhsPacked = rhsPacked + kai.getRhsPackedOffset(tId * nPerThread, k);
+                auto threadDst = reinterpret_cast(dst) + kai.getDstOffset(0, tId * nPerThread, n);
+                int threadN = (tId == threadNeed - 1) ? (n - nPerThread * tId) : nPerThread; //The last thread's threadN may be less than nPerThread.
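+                //All threads read the same packed LHS; each writes a disjoint slice of dst columns.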
+ kai.runMatmul(m, threadN, k, lhsPacked, threadRhsPacked, n * sizeof(float), threadDst); + }; + + MNN_CONCURRENCY_BEGIN(tId, threadNeed) { + ThreadFunction((int)tId); + } + MNN_CONCURRENCY_END(); + +#if !KAI_CONV_NCHW_IN_OUT + kai.packNCHWToNC4HW4((float *)dst, m, n); +#endif + return NO_ERROR; + } + } +#endif + int UNIT__, SRC_UNIT, DST_XUNIT; core->MNNGetGemmUnit(&UNIT__, &SRC_UNIT, &DST_XUNIT); auto blitProc = core->MNNPackC4Int8ForMatMul_A; diff --git a/source/core/TensorUtils.hpp b/source/core/TensorUtils.hpp index 2237e065b..65d959f64 100644 --- a/source/core/TensorUtils.hpp +++ b/source/core/TensorUtils.hpp @@ -19,6 +19,12 @@ #undef CONSTANT #endif // CONSTANT +#ifdef MNN_KLEIDIAI_ENABLED +#define KAI_CONV_NCHW_IN_OUT 1 +#else +#define KAI_CONV_NCHW_IN_OUT 0 +#endif + namespace MNN { struct TensorArrayAttr { // array size is dynamic or not diff --git a/source/geometry/GeometryConvUtils.cpp b/source/geometry/GeometryConvUtils.cpp index dd3b53cfa..10f27f762 100644 --- a/source/geometry/GeometryConvUtils.cpp +++ b/source/geometry/GeometryConvUtils.cpp @@ -253,6 +253,19 @@ bool GeometryConvUtils::computeSingle(const Op* op, const std::vector& auto output = originOutput; auto inputDes = TensorUtils::getDescribe(newInputs[0]); auto format = inputDes->dimensionFormat; +#if KAI_CONV_NCHW_IN_OUT + { + std::shared_ptr cmd(new Command); + cmd->op = op; + cmd->inputs = std::move(newInputs); + cmd->outputs = std::move(newOutputs); + res.command.emplace_back(std::move(cmd)); + if (originOutput != output) { + ConvertUtils::compute(output, originOutput, res); + } + return true; + } +#endif if (MNN_DATA_FORMAT_NC4HW4 != format) { std::shared_ptr newInput(new Tensor(newInputs[0], Tensor::CAFFE_C4, false)); ConvertUtils::compute(newInputs[0], newInput.get(), res); diff --git a/source/shape/ShapeTensorConvert.cpp b/source/shape/ShapeTensorConvert.cpp index a3d2035e0..c577d1753 100644 --- a/source/shape/ShapeTensorConvert.cpp +++ b/source/shape/ShapeTensorConvert.cpp @@ -23,6 +23,9 @@ class TensorConvertSizeComputer : public SizeComputer { sourceFmt = MNN_DATA_FORMAT_NCHW; } auto destFmt = info->dest(); +#if KAI_CONV_NCHW_IN_OUT + destFmt = MNN_DATA_FORMAT_NCHW; +#endif TensorUtils::getDescribe(outputs[0])->dimensionFormat = destFmt; if (destFmt == MNN_DATA_FORMAT_NC4HW4) { destFmt = MNN_DATA_FORMAT_NCHW; From 95a6e4190aa812bc2cbc50e4588a41b3f80de34e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E7=8E=8B=E5=8F=AC=E5=BE=B7?= <8401806+wangzhaode@users.noreply.github.com> Date: Mon, 2 Sep 2024 17:35:36 +0800 Subject: [PATCH 2/8] Bugfix of thread workload. --- source/backend/cpu/compute/ConvInt8TiledExecutor.cpp | 2 ++ 1 file changed, 2 insertions(+) diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp index e7604fed9..a2179b739 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp @@ -555,6 +555,8 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input int tileLimit = 0; int outC = output->channel(); int outC4 = UP_DIV(outC, gcore->pack); + int totalWork = outC4; + int part = 1; if (threads < planeSize) { // Thread split by output nhw. tileLimit = ALIMIN(tileLimitByC, UP_DIV(planeSize, threads)); From 644f22f02802f0dfa7791abe0ae545cd241908e8 Mon Sep 17 00:00:00 2001 From: "xinhao.zheng" Date: Mon, 21 Oct 2024 14:32:47 +0800 Subject: [PATCH 3/8] Update mnn_kleidiai interface. 
--- CMakeLists.txt | 1 + source/backend/cpu/CPUBackend.hpp | 2 +- .../backend/cpu/arm/kleidiAI/CMakeLists.txt | 73 +++++++++++++------ .../{MNNKleidiAI.cpp => mnn_kleidiai.cpp} | 39 ++++++++-- .../{MNNKleidiAI.h => mnn_kleidiai.h} | 35 +++++---- .../cpu/compute/ConvInt8TiledExecutor.cpp | 52 ++++++++----- source/core/TensorUtils.hpp | 8 +- source/geometry/GeometryConvUtils.cpp | 22 +++--- source/shape/ShapeTensorConvert.cpp | 4 +- 9 files changed, 161 insertions(+), 75 deletions(-) rename source/backend/cpu/arm/kleidiAI/{MNNKleidiAI.cpp => mnn_kleidiai.cpp} (89%) rename source/backend/cpu/arm/kleidiAI/{MNNKleidiAI.h => mnn_kleidiai.h} (76%) diff --git a/CMakeLists.txt b/CMakeLists.txt index deb46775e..805616041 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -253,6 +253,7 @@ message(STATUS "\tOpenCL: ${MNN_OPENCL}") message(STATUS "\tOpenGL: ${MNN_OPENGL}") message(STATUS "\tVulkan: ${MNN_VULKAN}") message(STATUS "\tARM82: ${MNN_ARM82}") +message(STATUS "\tKleidiAI: ${MNN_KLEIDIAI}") message(STATUS "\toneDNN: ${MNN_ONEDNN}") message(STATUS "\tTensorRT: ${MNN_TENSORRT}") message(STATUS "\tCoreML: ${MNN_COREML}") diff --git a/source/backend/cpu/CPUBackend.hpp b/source/backend/cpu/CPUBackend.hpp index 23dcf5c9d..9c11bd12b 100644 --- a/source/backend/cpu/CPUBackend.hpp +++ b/source/backend/cpu/CPUBackend.hpp @@ -18,7 +18,7 @@ #include "MNN_generated.h" #ifdef MNN_KLEIDIAI_ENABLED -#include "arm/kleidiAI/MNNKleidiAI.h" +#include "arm/kleidiAI/mnn_kleidiai.h" #endif namespace MNN { diff --git a/source/backend/cpu/arm/kleidiAI/CMakeLists.txt b/source/backend/cpu/arm/kleidiAI/CMakeLists.txt index 556da784a..34cd7a3bd 100644 --- a/source/backend/cpu/arm/kleidiAI/CMakeLists.txt +++ b/source/backend/cpu/arm/kleidiAI/CMakeLists.txt @@ -1,18 +1,26 @@ -list(APPEND MNN_KleidiAI_SOURCES ${CMAKE_CURRENT_LIST_DIR}/MNNKleidiAI.cpp) -list(APPEND MNN_KleidiAI_HEADERS ${CMAKE_CURRENT_LIST_DIR}/MNNKleidiAI.h) +# +# SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +# +# SPDX-License-Identifier: Apache-2.0 +# -include_directories( - ${CMAKE_CURRENT_LIST_DIR}/ - ${CMAKE_CURRENT_LIST_DIR}/kai/ukernels/ - ${CMAKE_CURRENT_LIST_DIR}/kai/ukernels/matmul/ - ${CMAKE_CURRENT_LIST_DIR}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/ - ${CMAKE_CURRENT_LIST_DIR}/kai/ukernels/matmul/pack/) +project(MNN_KleidiAI + LANGUAGES C CXX ASM +) + +set(KLEIDIAI_MIN_CLANG_VERSION 11) +set(KLEIDIAI_MIN_GNU_VERSION 11) -list(APPEND MNN_KleidiAI_SOURCES ${CMAKE_CURRENT_LIST_DIR}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c) -list(APPEND MNN_KleidiAI_SOURCES ${CMAKE_CURRENT_LIST_DIR}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c) -list(APPEND MNN_KleidiAI_SOURCES ${CMAKE_CURRENT_LIST_DIR}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c) -list(APPEND MNN_KleidiAI_SOURCES ${CMAKE_CURRENT_LIST_DIR}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c) +if(CMAKE_C_COMPILER_ID STREQUAL "Clang" AND CMAKE_C_COMPILER_VERSION VERSION_LESS ${KLEIDIAI_MIN_CLANG_VERSION}) + message(WARNING "KleidiAI: Using non-supported Clang version. Expected ${KLEIDIAI_MIN_CLANG_VERSION} or newer, received ${CMAKE_C_COMPILER_VERSION}.") +endif() +if(CMAKE_C_COMPILER_ID STREQUAL "GNU" AND CMAKE_C_COMPILER_VERSION VERSION_LESS ${KLEIDIAI_MIN_GNU_VERSION}) + message(WARNING "KleidiAI: Using non-supported GCC version. 
Expected ${KLEIDIAI_MIN_GNU_VERSION} or newer, received ${CMAKE_C_COMPILER_VERSION}.") +endif() + +list(APPEND MNN_KleidiAI_SOURCES ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai.cpp) +list(APPEND MNN_KleidiAI_HEADERS ${CMAKE_CURRENT_LIST_DIR}/mnn_kleidiai.h) add_library( MNN_KleidiAI @@ -20,15 +28,36 @@ add_library( ${MNN_KleidiAI_SOURCES} ${MNN_KleidiAI_HEADERS} ) -# Enable ARMv8.6-A features -target_compile_definitions(MNN_KleidiAI PRIVATE - __ARM_FEATURE_MATMUL_INT8 - __ARM_FEATURE_BF16_VECTOR_ARITHMETIC - __ARM_FEATURE_BF16_SCALAR_ARITHMETIC - __ARM_BF16_FORMAT_ALTERNATIVE - __ARM_FEATURE_DOTPROD +set(KLEIDIAI_SRC ${CMAKE_CURRENT_LIST_DIR}) + +include_directories( + ${KLEIDIAI_SRC}/ + ${KLEIDIAI_SRC}/kai/ + ${KLEIDIAI_SRC}/kai/ukernels/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/ + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/) + +set(KLEIDIAI_FILES_SCALAR + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c +) + +set(KLEIDIAI_FILES_NEON_DOTPROD + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c ) -target_compile_options(MNN_KleidiAI - PRIVATE -march=armv8.6-a -) \ No newline at end of file +set(KLEIDIAI_FILES_NEON_I8MM + ${KLEIDIAI_SRC}/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c +) + +# Selectively enable architecture features. +target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_SCALAR}) +if((CMAKE_SYSTEM_PROCESSOR STREQUAL "aarch64" OR CMAKE_SYSTEM_PROCESSOR STREQUAL "arm64") AND NOT MSVC) + target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_NEON_DOTPROD}) + target_sources(MNN_KleidiAI PRIVATE ${KLEIDIAI_FILES_NEON_I8MM}) + + set_source_files_properties(${KLEIDIAI_FILES_SCALAR} PROPERTIES COMPILE_OPTIONS -march=armv8-a) + set_source_files_properties(${KLEIDIAI_FILES_NEON_DOTPROD} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+dotprod) + set_source_files_properties(${KLEIDIAI_FILES_NEON_I8MM} PROPERTIES COMPILE_OPTIONS -march=armv8.2-a+i8mm) +endif() \ No newline at end of file diff --git a/source/backend/cpu/arm/kleidiAI/MNNKleidiAI.cpp b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.cpp similarity index 89% rename from source/backend/cpu/arm/kleidiAI/MNNKleidiAI.cpp rename to source/backend/cpu/arm/kleidiAI/mnn_kleidiai.cpp index 465a6e4ad..d41f2e519 100644 --- a/source/backend/cpu/arm/kleidiAI/MNNKleidiAI.cpp +++ b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.cpp @@ -1,6 +1,12 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + #if defined(__aarch64__) -#include "MNNKleidiAI.h" +#include "mnn_kleidiai.h" using namespace MNN; @@ -13,7 +19,7 @@ inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { return kai_roundup(k, kr_sr_roundedup4); } -static void packQsi4cxpQsi8cxs1s0(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, +static void packQsu4cxs1s0Qsi8cxp(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, const float* scale, void* rhs_packed, size_t extra_bytes, const struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params* params) { KAI_ASSERT(num_groups == 1); @@ -130,6 +136,8 @@ void KleidiAI::packNCHWToNC4HW4(float* data, size_t rowNum, size_t rowSize) { 
rowSrc += rowSize; } } + + delete[] tmpBuffer; } void KleidiAI::packNC4HW4ToNCHW(float* data, size_t rowNum, size_t rowSize) { @@ -156,6 +164,23 @@ void KleidiAI::packNC4HW4ToNCHW(float* data, size_t rowNum, size_t rowSize) { rowSrc += 4; } } + + delete[] tmpBuffer; +} + +//Set info +void KleidiAI::setEnable(bool enable) { + mKaiInfo.kaiEnable = enable; + if(canAccelerate()) { + MNN_PRINT("\nKleidiAI is running!\n"); + } +} + +void KleidiAI::setModelAsymmetric(bool bAsymmetric) { + mKaiInfo.asymmetric = bAsymmetric; + if(canAccelerate()) { + MNN_PRINT("\nKleidiAI is running!\n"); + } } //Lhs @@ -167,8 +192,8 @@ size_t KleidiAI::getLhsQuantedPackedOffset(size_t m, size_t mIdx, size_t k) { return mIdx == 0 ? 0 : kai_get_lhs_packed_offset_lhs_quant_pack_qai8dxp_f32(mIdx, k, getMr(m), getKr(), getSr()); } -void KleidiAI::runLhsQuantPack(size_t m, size_t k, const void* lhs, void* lhsQuantedPacked) { - kai_run_lhs_quant_pack_qai8dxp_f32(m, k, getMr(m), getKr(), getSr(), 0, (const float *)lhs, k * sizeof(float), lhsQuantedPacked); +void KleidiAI::runLhsQuantPack(size_t m, size_t k, size_t mr, const void* lhs, void* lhsQuantedPacked) { + kai_run_lhs_quant_pack_qai8dxp_f32(m, k, mr, getKr(), getSr(), 0, (const float *)lhs, k * sizeof(float), lhsQuantedPacked); } //Rhs @@ -185,7 +210,7 @@ void KleidiAI::runRhsPack(size_t n, size_t k, const void* rhs, const void* scale params.lhs_zero_point = 1; params.rhs_zero_point = 8; if(!packedInt4) { - packQsi4cxpQsi8cxs1s0(1, n, k, getNr(), getKr(), getSr(), + packQsu4cxs1s0Qsi8cxp(1, n, k, getNr(), getKr(), getSr(), (const uint8_t *)rhs, (const float *)bias, (const float *)scale, rhsPacked, @@ -202,11 +227,11 @@ void KleidiAI::runRhsPack(size_t n, size_t k, const void* rhs, const void* scale //Matmul void KleidiAI::runMatmul(size_t m, size_t n, size_t k, const void* lhsPacked, const void* rhsPacked, size_t dst_stride, void* dst) { if(m == 1) { //dotprod - kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(m, n, k, + kai_run_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(m, n, k, (const void *)lhsPacked, (const void *)rhsPacked, (float *)dst, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); } else { //i8mm - kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(m, n, k, + kai_run_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(m, n, k, (const void *)lhsPacked, (const void *)rhsPacked, (float *)dst, dst_stride, sizeof(float), -FLT_MAX, FLT_MAX); } diff --git a/source/backend/cpu/arm/kleidiAI/MNNKleidiAI.h b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.h similarity index 76% rename from source/backend/cpu/arm/kleidiAI/MNNKleidiAI.h rename to source/backend/cpu/arm/kleidiAI/mnn_kleidiai.h index 92ec40fdb..815befa36 100644 --- a/source/backend/cpu/arm/kleidiAI/MNNKleidiAI.h +++ b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.h @@ -1,3 +1,9 @@ +// +// SPDX-FileCopyrightText: Copyright 2024 Arm Limited and/or its affiliates +// +// SPDX-License-Identifier: Apache-2.0 +// + #pragma once #include @@ -19,7 +25,7 @@ #include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" -#include "./kai/kai_common.h" +#include "kai_common.h" namespace MNN { class KleidiAI { @@ -40,23 +46,23 @@ namespace MNN { ~KleidiAI() {} - struct KAIInfo { + typedef struct KaiInfo { bool kaiEnable = false; bool asymmetric = false; //Asymmetric quantized model. bool dot = false; //CPU support sdot. bool i8mm = false; //CPU support i8mm. 
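         //canAccelerate() is true only when kaiEnable, dot and i8mm are all set and asymmetric is clear.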
- }; + } KaiInfo; //Kai util void packNCHWToNC4HW4(float* data, size_t rowNum, size_t rowSize); void packNC4HW4ToNCHW(float* data, size_t rowNum, size_t rowSize); //Set info - void setEnable(bool enable) { mKAIInfo.kaiEnable = enable; } - void setModelAsymmetric(bool bAsymmetric) { mKAIInfo.asymmetric = bAsymmetric; } + void setEnable(bool enable); + void setModelAsymmetric(bool bAsymmetric); //Check - bool canAccelerate() { return (mKAIInfo.kaiEnable && mKAIInfo.dot && mKAIInfo.i8mm && !mKAIInfo.asymmetric); } + bool canAccelerate() { return (mKaiInfo.kaiEnable && mKaiInfo.dot && mKaiInfo.i8mm && !mKaiInfo.asymmetric); } //Get info size_t getMr(size_t m = 1) { return (m == 1) ? mKaiMrDotprod : mKaiMrI8mm; } @@ -65,12 +71,12 @@ namespace MNN { size_t getSr() { return mKaiSr; } size_t getMStep(size_t m = 1) { return (m == 1) ? mKaiMstepDotprod : mKaiMstepI8mm; } size_t getNStep() { return mKaiNStep; } - size_t getVecNumPerThread(size_t totalVec, size_t totalThread, size_t minStep) { return kai_roundup(totalVec / totalThread, minStep); } + size_t getVecNumPerThread(size_t totalVec, size_t totalThread, size_t minStep) { return kai_roundup((totalVec + totalThread - 1) / totalThread, minStep); } //Lhs size_t getLhsQuantedPackedSize(size_t m, size_t k); size_t getLhsQuantedPackedOffset(size_t m, size_t mIdx, size_t k); - void runLhsQuantPack(size_t m, size_t k, const void* lhs, void* lhsQuantedPacked); + void runLhsQuantPack(size_t m, size_t k, size_t mr, const void* lhs, void* lhsQuantedPacked); //Rhs size_t getRhsPackedSize(size_t n, size_t k); @@ -86,14 +92,17 @@ namespace MNN { private: KleidiAI(bool bAsymmetric = false) { const MNNCPUInfo& gCPUInfo = *MNNGetCPUInfo(); - mKAIInfo.dot = gCPUInfo.dot; - mKAIInfo.i8mm = gCPUInfo.i8mm; - mKAIInfo.kaiEnable = true; - mKAIInfo.asymmetric = bAsymmetric; + mKaiInfo.dot = gCPUInfo.dot; + mKaiInfo.i8mm = gCPUInfo.i8mm; + mKaiInfo.kaiEnable = true; + mKaiInfo.asymmetric = bAsymmetric; + if(canAccelerate()) { + MNN_PRINT("\nKleidiAI is running!\n"); + } } static KleidiAI *instance; - KAIInfo mKAIInfo; + KaiInfo mKaiInfo; const size_t mKaiMstepDotprod = 1; const size_t mKaiMstepI8mm = 8; diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp index a2179b739..bbd18f2db 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp @@ -662,7 +662,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu auto core = static_cast(backend())->int8Functions(); auto gcore = static_cast(backend())->functions(); -#if MNN_KLEIDIAI_ENABLED +#ifdef MNN_KLEIDIAI_ENABLED KleidiAI& kai = KleidiAI::getInstance(); if(mDynamicQuantExe) { if(mResource->mDequantize.bits == 4 && kai.canAccelerate()) { @@ -670,30 +670,48 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu const size_t n = output->channel(); //rhs vector number. const size_t k = input->channel(); //vector size. 
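             //The input is consumed as a row-major (m x k) f32 matrix; the weights are the packed (n x k) int4 RHS.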
- auto lhs = reinterpret_cast(input->host()); + auto lhs = input->host(); auto lhsPacked = mTempIm2ColBuffer->host(); auto rhsPacked = mResourceInt8->mWeightInt8->host(); - auto dst = reinterpret_cast(output->host()); + auto dst = output->host(); + + int threadNum = static_cast(backend())->threadNumber(); + int threadNeed, vecPerThread; #if !KAI_CONV_NCHW_IN_OUT kai.packNC4HW4ToNCHW((float *)lhs, m, k); #endif - auto BatchDynamicQuant = [=]() { - KleidiAI& kai = KleidiAI::getInstance(); - kai.runLhsQuantPack(m, k, lhs, lhsPacked); - }; - BatchDynamicQuant(); + //Dynamic quant pack lhs. + if(m == 1) { + kai.runLhsQuantPack(1, k, 1, lhs, lhsPacked); + } else { + vecPerThread = kai.getVecNumPerThread(m, threadNum, kai.getMr(m)); + threadNeed = m % vecPerThread == 0 ? m / vecPerThread : (m / vecPerThread + 1); + size_t srcStride = vecPerThread * k * sizeof(float); + + auto BatchDynamicQuant = [=, &kai](int tId) { + auto threadSrc = lhs + tId * srcStride; + auto threadDst = lhsPacked + kai.getLhsQuantedPackedOffset(m, tId * vecPerThread, k); + int vecNum = (tId == threadNeed - 1) ? (m - vecPerThread * tId) : vecPerThread; //Last threadN may less than vecPerThread. + kai.runLhsQuantPack(vecNum, k, kai.getMr(m), threadSrc, threadDst); + }; + + MNN_CONCURRENCY_BEGIN(tId, threadNeed) { + BatchDynamicQuant((int)tId); + } + MNN_CONCURRENCY_END(); + } - int nPerThread = kai.getVecNumPerThread(n, static_cast(backend())->threadNumber(), kai.getNStep()); - int threadNeed = n % nPerThread == 0 ? n / nPerThread : (n / nPerThread + 1); + //Run matmul. + vecPerThread = kai.getVecNumPerThread(n, threadNum, kai.getNStep()); + threadNeed = n % vecPerThread == 0 ? n / vecPerThread : (n / vecPerThread + 1); - auto ThreadFunction = [=](int tId) { - KleidiAI& kai = KleidiAI::getInstance(); - auto threadRhsPacked = rhsPacked + kai.getRhsPackedOffset(tId * nPerThread, k); - auto threadDst = reinterpret_cast(dst) + kai.getDstOffset(0, tId * nPerThread, n); - int threadN = (tId == threadNeed - 1) ? (n - nPerThread * tId) : nPerThread; //Last threadN may less than nPerThread. - kai.runMatmul(m, threadN, k, lhsPacked, threadRhsPacked, n * sizeof(float), threadDst); + auto ThreadFunction = [=, &kai](int tId) { + auto threadRhsPacked = rhsPacked + kai.getRhsPackedOffset(tId * vecPerThread, k); + auto threadDst = dst + kai.getDstOffset(0, tId * vecPerThread, n); + int vecNum = (tId == threadNeed - 1) ? (n - vecPerThread * tId) : vecPerThread; //Last threadN may less than vecPerThread. + kai.runMatmul(m, vecNum, k, lhsPacked, threadRhsPacked, n * sizeof(float), threadDst); }; MNN_CONCURRENCY_BEGIN(tId, threadNeed) { @@ -1031,7 +1049,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu } else { MNN_CONCURRENCY_BEGIN(tId, threads) { int ocIndex = PackUnit * mDivides[tId]; - if (ocIndex < ocUp4) { + if (ocIndex < ocUp4){ ThreadFunction((int)tId, 0, mTileCount,1, ocIndex); } } diff --git a/source/core/TensorUtils.hpp b/source/core/TensorUtils.hpp index 65d959f64..442b3184a 100644 --- a/source/core/TensorUtils.hpp +++ b/source/core/TensorUtils.hpp @@ -20,9 +20,13 @@ #endif // CONSTANT #ifdef MNN_KLEIDIAI_ENABLED +#include "../backend/cpu/arm/kleidiAI/mnn_kleidiai.h" +/** + * Set DenseConvInt8TiledExecutor's input/output tensor format: + * KAI_CONV_NCHW_IN_OUT = 1: format will be NCHW, skip pack/unpack functions. + * KAI_CONV_NCHW_IN_OUT = 0: format will be NC4HW4, need pack/unpack functions to fit kleidiAI ukernel. 
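+ * This revision pins the macro to 1 whenever MNN_KLEIDIAI_ENABLED is defined;
+ * the NC4HW4 (= 0) path survives only in the guarded pack/unpack helpers.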
+ **/ #define KAI_CONV_NCHW_IN_OUT 1 -#else -#define KAI_CONV_NCHW_IN_OUT 0 #endif namespace MNN { diff --git a/source/geometry/GeometryConvUtils.cpp b/source/geometry/GeometryConvUtils.cpp index 10f27f762..21670bd24 100644 --- a/source/geometry/GeometryConvUtils.cpp +++ b/source/geometry/GeometryConvUtils.cpp @@ -247,25 +247,23 @@ std::shared_ptr GeometryConvUtils::im2Col(Tensor* im2Col, Tensor* input, return tempTensor; } bool GeometryConvUtils::computeSingle(const Op* op, const std::vector& inputs, const std::vector& outputs, GeometryComputer::Context& context, CommandBuffer& res) { - auto newOutputs = outputs; - auto newInputs = inputs; - auto originOutput = outputs[0]; - auto output = originOutput; - auto inputDes = TensorUtils::getDescribe(newInputs[0]); - auto format = inputDes->dimensionFormat; #if KAI_CONV_NCHW_IN_OUT - { + if(KleidiAI::getInstance().canAccelerate()) { std::shared_ptr cmd(new Command); cmd->op = op; - cmd->inputs = std::move(newInputs); - cmd->outputs = std::move(newOutputs); + cmd->inputs = std::move(inputs); + cmd->outputs = std::move(outputs); res.command.emplace_back(std::move(cmd)); - if (originOutput != output) { - ConvertUtils::compute(output, originOutput, res); - } return true; } #endif + auto newOutputs = outputs; + auto newInputs = inputs; + auto originOutput = outputs[0]; + auto output = originOutput; + auto inputDes = TensorUtils::getDescribe(newInputs[0]); + auto format = inputDes->dimensionFormat; + if (MNN_DATA_FORMAT_NC4HW4 != format) { std::shared_ptr newInput(new Tensor(newInputs[0], Tensor::CAFFE_C4, false)); ConvertUtils::compute(newInputs[0], newInput.get(), res); diff --git a/source/shape/ShapeTensorConvert.cpp b/source/shape/ShapeTensorConvert.cpp index c577d1753..899b9410b 100644 --- a/source/shape/ShapeTensorConvert.cpp +++ b/source/shape/ShapeTensorConvert.cpp @@ -24,7 +24,9 @@ class TensorConvertSizeComputer : public SizeComputer { } auto destFmt = info->dest(); #if KAI_CONV_NCHW_IN_OUT - destFmt = MNN_DATA_FORMAT_NCHW; + if(KleidiAI::getInstance().canAccelerate()) { + destFmt = MNN_DATA_FORMAT_NCHW; + } #endif TensorUtils::getDescribe(outputs[0])->dimensionFormat = destFmt; if (destFmt == MNN_DATA_FORMAT_NC4HW4) { From 6f5be724a9b1e2442c3733f81cc7f03c621c8d14 Mon Sep 17 00:00:00 2001 From: "xinhao.zheng" Date: Tue, 22 Oct 2024 14:37:28 +0800 Subject: [PATCH 4/8] Update MNN to latest version --- .../backend/cpu/arm/kleidiAI/CMakeLists.txt | 2 +- .../backend/cpu/arm/kleidiAI/kai/kai_common.h | 126 ++++++++++++--- ...ai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c | 1 + ...ai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h | 10 +- ...2_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c | 1 + ...2_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h | 10 +- ...c => kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c} | 118 +++++++++------ ...h => kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h} | 16 +- .../backend/cpu/arm/kleidiAI/mnn_kleidiai.cpp | 143 ++++++++++++++++-- .../backend/cpu/arm/kleidiAI/mnn_kleidiai.h | 2 +- .../cpu/compute/ConvInt8TiledExecutor.cpp | 139 +++++++++++------ 11 files changed, 427 insertions(+), 141 deletions(-) rename source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/{kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c => kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c} (56%) rename source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/{kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h => kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h} (88%) diff --git a/source/backend/cpu/arm/kleidiAI/CMakeLists.txt b/source/backend/cpu/arm/kleidiAI/CMakeLists.txt index 34cd7a3bd..f12e27c19 100644 --- 
a/source/backend/cpu/arm/kleidiAI/CMakeLists.txt +++ b/source/backend/cpu/arm/kleidiAI/CMakeLists.txt @@ -40,7 +40,7 @@ include_directories( set(KLEIDIAI_FILES_SCALAR ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_lhs_quant_pack_qai8dxp_f32.c - ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c + ${KLEIDIAI_SRC}/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c ) set(KLEIDIAI_FILES_NEON_DOTPROD diff --git a/source/backend/cpu/arm/kleidiAI/kai/kai_common.h b/source/backend/cpu/arm/kleidiAI/kai/kai_common.h index e8765cf38..9569e5468 100644 --- a/source/backend/cpu/arm/kleidiAI/kai/kai_common.h +++ b/source/backend/cpu/arm/kleidiAI/kai/kai_common.h @@ -6,10 +6,10 @@ #pragma once #include +#include #include #include #include -#include #ifdef __cplusplus extern "C" { @@ -20,11 +20,11 @@ extern "C" { // * cppcoreguidelines-avoid-do-while: do-while is necessary for macros. // * cppcoreguidelines-pro-type-vararg: use of variadic arguments in fprintf is expected. // * cert-err33-c: checking the output of fflush and fprintf is not necessary for error reporting. -#define KAI_ERROR(msg) \ - do { \ - fflush(stdout); \ - fprintf(stderr, "%s", msg); \ - exit(EXIT_FAILURE); \ +#define KAI_ERROR(msg) \ + do { \ + fflush(stdout); \ + fprintf(stderr, "%s:%d %s", __FILE__, __LINE__, msg); \ + exit(EXIT_FAILURE); \ } while (0) #define KAI_ASSERT_MSG(cond, msg) \ @@ -53,24 +53,24 @@ extern "C" { /// KleidiAI data types /// Format: (reserved)|(num-bytes)|(type)|(variant-type) enum kai_datatype { - Unknown = 0x0000, - F32 = 0x0411, - F16 = 0x0212, - Bf16 = 0x0213, - Int32 = 0x0421, - Int16 = 0x0222, - Int8 = 0x0124, - Uint32 = 0x0431, - Uint16 = 0x0232, - Uint8 = 0x0134, - Bool = 0x0441 + kai_dt_unknown = 0x0000, + kai_dt_f32 = 0x0411, + kai_dt_f16 = 0x0212, + kai_dt_bf16 = 0x0213, + kai_dt_int32 = 0x0421, + kai_dt_int16 = 0x0222, + kai_dt_int8 = 0x0124, + kai_dt_uint32 = 0x0431, + kai_dt_uint16 = 0x0232, + kai_dt_uint8 = 0x0134, + kai_dt_bool = 0x0441 }; /// Gets number of bytes for a given data type /// @param[in] dt KleidiAI data type /// /// @return the numbers of bytes for the data type -inline static size_t kai_num_bytes_datatype(enum kai_datatype dt) { +inline static size_t kai_get_datatype_size_in_bytes(enum kai_datatype dt) { return (size_t)(dt >> 8); } @@ -78,7 +78,7 @@ inline static size_t kai_num_bytes_datatype(enum kai_datatype dt) { /// @param[in] f16 The f16 value /// /// @return the f32 value -inline static float kai_f16_to_f32(uint16_t f16) { +inline static float kai_cast_f32_f16(uint16_t f16) { #if defined(__ARM_NEON) __fp16 f32 = 0; memcpy(&f32, &f16, sizeof(uint16_t)); @@ -86,11 +86,37 @@ inline static float kai_f16_to_f32(uint16_t f16) { #endif } +/// Converts a scalar bf16 value to f32 +/// @param[in] bf16 The f16 value +/// +/// @return the f32 value +inline static float kai_cast_f32_bf16(uint16_t bf16) { + const uint32_t i32 = (bf16 << 16); + float f32; + memcpy(&f32, &i32, sizeof(i32)); + return f32; +} + +/// Converts a f32 value to bf16 +/// @param[in] f32 The f32 value +/// +/// @return the bf16 value +inline static uint16_t kai_cast_bf16_f32(float f32) { + uint16_t bf16 = 0; +#ifdef __ARM_FEATURE_BF16 + __asm__ __volatile__("bfcvt %h[output], %s[input]" : [output] "=w"(bf16) : [input] "w"(f32)); +#else + const uint32_t* i32 = (uint32_t*)(&f32); + bf16 = (*i32 >> 16); +#endif + return bf16; +} + /// Converts a scalar f32 value to f16 /// @param[in] f32 The f32 value /// /// @return the f16 value -inline static uint16_t kai_f32_to_f16(float f32) { 
+inline static uint16_t kai_cast_f16_f32(float f32) { #if defined(__ARM_NEON) uint16_t f16 = 0; __fp16 tmp = f32; @@ -103,6 +129,66 @@ inline static size_t kai_roundup(size_t a, size_t b) { return ((a + b - 1) / b) * b; } +#ifdef __ARM_FEATURE_SVE + +/// Gets the SME vector length for 8-bit elements. +inline static uint64_t kai_get_sme_vector_length_u8(void) { + uint64_t res = 0; + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "cntb %0\n" + ".inst 0xd503467f // SMSTOP\n" + : "=r"(res) + : + : "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", + "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); + + return res; +} + +/// Gets the SME vector length for 16-bit elements. +inline static uint64_t kai_get_sme_vector_length_u16(void) { + uint64_t res = 0; + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "cnth %0\n" + ".inst 0xd503467f // SMSTOP\n" + : "=r"(res) + : + : "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", + "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); + + return res; +} + +/// Gets the SME vector length for 32-bit elements. +inline static uint64_t kai_get_sme_vector_length_u32(void) { + uint64_t res = 0; + + __asm__ __volatile__( + ".inst 0xd503477f // SMSTART ZA\n" + "cntw %0\n" + ".inst 0xd503467f // SMSTOP\n" + : "=r"(res) + : + : "z0", "z1", "z2", "z3", "z4", "z5", "z6", "z7", "z8", "z9", "z10", "z11", "z12", "z13", "z14", "z15", "z16", + "z17", "z18", "z19", "z20", "z21", "z22", "z23", "z24", "z25", "z26", "z27", "z28", "z29", "z30", "z31"); + + return res; +} + +#endif // __ARM_FEATURE_SVE + +/// Extends the sign bit of int 4-bit value (stored in int8_t variable) +/// @param[in] value The 4-bit int value +/// +/// @return the int8_t value with sign extended +inline static int8_t kai_ext_sign_i8_i4(int8_t value) { + return (value ^ 0x8) - 8; +} + #ifdef __cplusplus } #endif diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c index 8f1479dcd..cd24f7313 100644 --- a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.c @@ -9,6 +9,7 @@ #include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" #include +#include #include #include "kai/kai_common.h" diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h index b6d2cf87e..fefca19a9 100644 --- a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h @@ 
-14,7 +14,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0 OR kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 to pack the RHS matrix +/// -# kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 OR kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 to pack the RHS matrix /// -------------------------------------------------- @@ -39,19 +39,19 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod size_t kai_get_mr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); /// Gets the nr value, which must be used to pack the RHS matrix with -/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 micro-kernel /// /// @return the nr value size_t kai_get_nr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); /// Gets the kr value, which must be used to pack the RHS matrix with -/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 micro-kernel /// /// @return the kr value size_t kai_get_kr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); /// Gets the sr value, which must be used to pack the RHS matrix with -/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 micro-kernel /// /// @return the sr value size_t kai_get_sr_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod(void); @@ -110,7 +110,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotpr /// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs /// both the dynamic quantization to 8-bit and activation packing in a single step. /// @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref -/// kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 +/// kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 /// @param[out] dst Result of the vector-by-matrix /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. 
For now, it must be sizeof(float) diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c index 50e260d89..7e40839e6 100644 --- a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.c @@ -9,6 +9,7 @@ #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" #include +#include #include #include "kai/kai_common.h" diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h index c5a4553fc..04df4a825 100644 --- a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h @@ -14,7 +14,7 @@ extern "C" { /// Micro-kernel dependencies /// /// -# kai_lhs_quant_pack_qai8dxp_f32 to dynamically quantize and pack the LHS matrix -/// -# kai_rhs_pack_kxn_qsi4cxp_qsu4cxs1s0 OR kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 to pack the RHS matrix +/// -# kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0 OR kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 to pack the RHS matrix /// -------------------------------------------------- @@ -39,19 +39,19 @@ size_t kai_get_n_step_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(vo size_t kai_get_mr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); /// Function to get the nr value, which must be used to pack the RHS matrix with -/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 micro-kernel /// /// @return the nr value size_t kai_get_nr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); /// Gets the kr value, which must be used to pack the RHS matrix with -/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 micro-kernel /// /// @return the kr value size_t kai_get_kr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); /// Gets the sr value, which must be used to pack the RHS matrix with -/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 micro-kernel +/// the @ref kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 micro-kernel /// /// @return the sr value size_t kai_get_sr_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm(void); @@ -110,7 +110,7 @@ size_t kai_get_dst_size_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm( /// by calling the @ref kai_lhs_quant_pack_qai8dxp_f32 micro-kernel which performs /// both the dynamic quantization to 8-bit and activation packing in a single step. 
/// @param[in] rhs_packed The RHS matrix packed, which is obtained by calling @ref -/// kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0 +/// kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0 /// @param[out] dst Result of the vector-by-matrix /// @param[in] dst_stride_row Stride in bytes between two rows of the DST matrix. /// @param[in] dst_stride_col Stride in bytes between two columns of the DST matrix. For now, it must be sizeof(float) diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c similarity index 56% rename from source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c rename to source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c index 359471879..d3ec86067 100644 --- a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.c +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.c @@ -3,9 +3,10 @@ // // SPDX-License-Identifier: Apache-2.0 // -#include "kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h" +#include "kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" #include +#include #include #include @@ -22,15 +23,15 @@ inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { return kai_roundup(k, kr_sr_roundedup4); } -size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t nr) { +size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t nr) { return nr; } -size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride) { +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t n_idx, size_t rhs_stride) { return n_idx * rhs_stride; } -size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr) { +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr) { const size_t k_internal = kai_k_roundedup(k, kr, sr); KAI_ASSERT((k_internal % 2) == 0); @@ -38,23 +39,23 @@ size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_ return nr * ((k_internal / 2) + kai_num_bytes_multiplier_rhs + kai_num_bytes_sum_rhs + kai_num_bytes_bias); } -size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr) { KAI_ASSERT((n_idx % nr) == 0); - return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + return (n_idx / nr) * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); } -size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr) { const size_t num_rows = kai_roundup(n, nr) / nr; - return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + return num_rows * kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); } -void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( +void kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, const float* scale, void* rhs_packed, size_t extra_bytes, - const struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params* params) { + const struct 
kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params) { KAI_ASSERT(num_groups == 1); KAI_ASSERT(extra_bytes == 0); KAI_ASSERT((kr % sr) == 0); @@ -62,11 +63,11 @@ void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( KAI_ASSERT(scale != NULL); KAI_ASSERT(rhs_packed != NULL); KAI_ASSERT(params != NULL); - KAI_ASSERT(params->rhs_zero_point == 8); KAI_ASSERT(params->lhs_zero_point == 1); + KAI_ASSERT(params->rhs_zero_point == 0 || params->rhs_zero_point == 8); const size_t rhs_zero_point = params->rhs_zero_point; - const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); const size_t k_internal = kai_k_roundedup(k, kr, sr); const size_t dst_num_rows = kai_roundup(n, nr) / nr; const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); @@ -100,47 +101,78 @@ void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( const size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride; const size_t src_addr_byte1 = (k1_idx / 2) + n0_valid_idx * rhs_stride; - uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; - uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; + if (params->rhs_zero_point == 8) { + uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; + uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; - if (k0_idx < k) { - byte0 = rhs[src_addr_byte0]; - } - - if (k1_idx < k) { - byte1 = rhs[src_addr_byte1]; - } + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } - // The following operations where we extract the values from the bytes - // can be also written in the following and less efficient manner: - /* - uint8_t src_x0_lo = 0; - uint8_t src_x0_hi = 0; + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } - if ((k0_idx % 2) == 0) { - src_x0_lo = (byte0 & 0x0F); - } else { - src_x0_lo = (byte0 >> 4); + // The following operations where we extract the values from the bytes + // can be also written in the following and less efficient manner: + /* + uint8_t src_x0_lo = 0; + uint8_t src_x0_hi = 0; + + if ((k0_idx % 2) == 0) { + src_x0_lo = (byte0 & 0x0F); + } else { + src_x0_lo = (byte0 >> 4); + } + + if ((k1_idx % 2) == 0) { + src_x0_hi = (byte1 & 0x0F); + } else { + src_x0_hi = (byte1 >> 4); + } + */ + const size_t shift_right_x0 = (k0_idx % 2) * 4; + const size_t shift_right_x1 = (k1_idx % 2) * 4; + + const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; + const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; + + sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point; + + const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + + *dst_row = dst_qs0 ^ 0x88; + dst_row += sizeof(uint8_t); + } else { + int8_t byte0 = 0; + int8_t byte1 = 0; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; } - if ((k1_idx % 2) == 0) { - src_x0_hi = (byte1 & 0x0F); - } else { - src_x0_hi = (byte1 >> 4); + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; } - */ - const size_t shift_right_x0 = (k0_idx % 2) * 4; - const size_t shift_right_x1 = (k1_idx % 2) * 4; - const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; - const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; + // The logic behind the following operations where we extract the + // values from the bytes is same as unsigned + + const size_t shift_right_x0 = (k0_idx % 2) * 4; + const size_t shift_right_x1 = (k1_idx % 2) * 4; - sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point; + int8_t src_x0_lo = (byte0 
>> shift_right_x0) & 0x0F; + int8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; - const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + const int8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); - *dst_row = dst_qs0 ^ 0x88; - dst_row += sizeof(uint8_t); + *(int8_t*)dst_row = dst_qs0; + dst_row += sizeof(int8_t); + + src_x0_lo = kai_ext_sign_i8_i4(src_x0_lo); + src_x0_hi = kai_ext_sign_i8_i4(src_x0_hi); + sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi; + } } // Adjust the reduction sums diff --git a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h similarity index 88% rename from source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h rename to source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h index 4fc97ba96..dc7c1bd02 100644 --- a/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h +++ b/source/backend/cpu/arm/kleidiAI/kai/ukernels/matmul/pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h @@ -12,7 +12,7 @@ extern "C" { #endif -struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params { +struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params { int8_t lhs_zero_point; uint8_t rhs_zero_point; }; @@ -24,7 +24,7 @@ struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params { /// @param[in] nr The number of columns written by the matmul micro-kernel /// /// @return the n step value -size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t nr); +size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t nr); /// Gets the offset in bytes for the RHS matrix (not packed). /// @@ -36,7 +36,7 @@ size_t kai_get_n_step_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t nr); /// @param[in] rhs_stride The number of bytes in in each row of the RHS matrix (not packed) /// /// @return the offset in bytes to the RHS matrix (not packed) -size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t rhs_stride); +size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t n_idx, size_t rhs_stride); /// Get the row stride in bytes to the packed RHS matrix /// @@ -46,7 +46,7 @@ size_t kai_get_rhs_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n_idx, size_t r /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// /// @return the stride in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr); +size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t k, size_t nr, size_t kr, size_t sr); /// Gets the offset in bytes for the packed RHS matrix, which contains the packed 4-bit quantized symmetric per-channel /// (qsu4cx) values. @@ -58,7 +58,7 @@ size_t kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t k, size_ /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. /// /// @return the offset in bytes to the packed RHS matrix -size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( +size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( size_t n_idx, size_t k, size_t nr, size_t kr, size_t sr); /// @brief Gets the size in bytes for the packed RHS matrix @@ -70,7 +70,7 @@ size_t kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( /// @param[in] sr The number of kr splits. It can be 1 (no splits) up to kr. 
/// /// @return the packed RHS matrix size in bytes -size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr); +size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(size_t n, size_t k, size_t nr, size_t kr, size_t sr); /// Run the micro-kernel to pack the RHS matrix. /// @@ -92,7 +92,7 @@ size_t kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(size_t n, size_t /// @param[out] rhs_packed The packed RHS matrix. /// @param[in] extra_bytes Extra bytes to append to the end of each row of the packed RHS matrix. /// @param[in] params Parameters for the micro-kernel. -void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( +void kai_run_rhs_pack_nxk_qsi4cxp_qs4cxs1s0( size_t num_groups, // size_t n, // size_t k, // @@ -104,7 +104,7 @@ void kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0( const float* scale, // void* rhs_packed, // size_t extra_bytes, // - const struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params* params); + const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params); #ifdef __cplusplus } diff --git a/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.cpp b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.cpp index d41f2e519..dc1f9169f 100644 --- a/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.cpp +++ b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.cpp @@ -19,9 +19,128 @@ inline static size_t kai_k_roundedup(size_t k, size_t kr, size_t sr) { return kai_roundup(k, kr_sr_roundedup4); } -static void packQsu4cxs1s0Qsi8cxp(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, +static void packQsi4cxps16s0Qs4cxs0s1( + size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, + const float* scale, void* rhs_packed, size_t extra_bytes, + const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params) { + KAI_ASSERT(num_groups == 1); + KAI_ASSERT(extra_bytes == 0); + KAI_ASSERT((kr % sr) == 0); + KAI_ASSERT(rhs != NULL); + KAI_ASSERT(scale != NULL); + KAI_ASSERT(rhs_packed != NULL); + KAI_ASSERT(params != NULL); + KAI_ASSERT(params->rhs_zero_point == 8); + KAI_ASSERT(params->lhs_zero_point == 1); + + const size_t rhs_zero_point = params->rhs_zero_point; + const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); + const size_t k_internal = kai_k_roundedup(k, kr, sr); + const size_t dst_num_rows = kai_roundup(n, nr) / nr; + const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); + const size_t block_length_in_bytes = kr / sr; + const size_t k_interleaved_v = 16U; + const size_t rhs_stride = kai_roundup(k, 2) / 2; + + for (size_t dst_row_idx = 0; dst_row_idx < dst_num_rows; ++dst_row_idx) { + uint8_t* dst_row = (uint8_t*)rhs_packed + dst_row_idx * rhs_packed_stride; + + int32_t* sums = (int32_t*)(dst_row + nr * (k_internal / 2)); + + // Initialize to zero the RHS reduction sums + memset(sums, 0, nr * sizeof(int32_t)); + + for (size_t dst_byte_idx = 0; dst_byte_idx < dst_num_bytes_per_row; ++dst_byte_idx) { + const size_t block_idx = dst_byte_idx / block_length_in_bytes; + const size_t block_byte_idx = dst_byte_idx % block_length_in_bytes; + const size_t super_block_idx = block_idx / nr; + const size_t nr_idx = block_idx % nr; + + const size_t k_adjustment = + ((block_byte_idx + super_block_idx * block_length_in_bytes) / k_interleaved_v) * k_interleaved_v; + const size_t k0_idx = block_byte_idx + super_block_idx * block_length_in_bytes + k_adjustment; + const size_t k1_idx 
= k0_idx + k_interleaved_v; + const size_t n0_idx = dst_row_idx * nr + nr_idx; + + // Clamp the index to avoid out-of-bound reads + const size_t n0_valid_idx = KAI_MIN(n0_idx, n - 1); + + const size_t src_addr_byte0 = (k0_idx / 2) + n0_valid_idx * rhs_stride; + const size_t src_addr_byte1 = (k1_idx / 2) + n0_valid_idx * rhs_stride; + + uint8_t byte0 = rhs_zero_point | rhs_zero_point << 4; + uint8_t byte1 = rhs_zero_point | rhs_zero_point << 4; + + if (k0_idx < k) { + byte0 = rhs[src_addr_byte0]; + } + + if (k1_idx < k) { + byte1 = rhs[src_addr_byte1]; + } + + // The following operations where we extract the values from the bytes + // can be also written in the following and less efficient manner: + /* + uint8_t src_x0_lo = 0; + uint8_t src_x0_hi = 0; + + if ((k0_idx % 2) == 0) { + src_x0_lo = (byte0 & 0x0F); + } else { + src_x0_lo = (byte0 >> 4); + } + + if ((k1_idx % 2) == 0) { + src_x0_hi = (byte1 & 0x0F); + } else { + src_x0_hi = (byte1 >> 4); + } + */ + const size_t shift_right_x0 = ((k0_idx + 1) % 2) * 4; + const size_t shift_right_x1 = ((k1_idx + 1) % 2) * 4; + + const uint8_t src_x0_lo = (byte0 >> shift_right_x0) & 0x0F; + const uint8_t src_x0_hi = (byte1 >> shift_right_x1) & 0x0F; + + sums[nr_idx] += (int32_t)src_x0_lo + (int32_t)src_x0_hi - 2 * (int32_t)rhs_zero_point; + + const uint8_t dst_qs0 = src_x0_lo | (src_x0_hi << 4); + + *dst_row = dst_qs0 ^ 0x88; + dst_row += sizeof(uint8_t); + } + + // Adjust the reduction sums + for (size_t i = 0; i < nr; ++i) { + sums[i] = sums[i] * 16; + dst_row += sizeof(int32_t); + } + + // Adjust the scales + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + *((float*)(dst_row)) = scale[src_row_idx] * 0.0625F; + dst_row += sizeof(float); + } + + // Set the bias + if (bias == NULL) { + memset(dst_row, 0, nr * sizeof(float)); + } else { + for (size_t i = 0; i < nr; ++i) { + // Clamp the row index to avoid out-of-bound reads + const size_t src_row_idx = KAI_MIN(dst_row_idx * nr + i, n - 1); + ((float*)dst_row)[i] = bias[src_row_idx]; + } + } + } +} + +static void packQs4cxs16s0Qsi8cx(size_t num_groups, size_t n, size_t k, size_t nr, size_t kr, size_t sr, const uint8_t* rhs, const float* bias, const float* scale, void* rhs_packed, size_t extra_bytes, - const struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params* params) { + const struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params* params) { KAI_ASSERT(num_groups == 1); KAI_ASSERT(extra_bytes == 0); KAI_ASSERT((kr % sr) == 0); @@ -33,7 +152,7 @@ static void packQsu4cxs1s0Qsi8cxp(size_t num_groups, size_t n, size_t k, size_t KAI_ASSERT(params->lhs_zero_point == 1); const size_t rhs_zero_point = params->rhs_zero_point; - const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(k, nr, kr, sr); + const size_t rhs_packed_stride = kai_get_rhs_packed_stride_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(k, nr, kr, sr); const size_t k_internal = kai_k_roundedup(k, kr, sr); const size_t dst_num_rows = kai_roundup(n, nr) / nr; const size_t dst_num_bytes_per_row = nr * (kai_k_roundedup(k, kr, sr) / 2); @@ -198,29 +317,29 @@ void KleidiAI::runLhsQuantPack(size_t m, size_t k, size_t mr, const void* lhs, v //Rhs size_t KleidiAI::getRhsPackedSize(size_t n, size_t k) { - return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(n, k, getNr(), getKr(), getSr()); + return kai_get_rhs_packed_size_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(n, k, getNr(), getKr(), getSr()); } size_t 
KleidiAI::getRhsPackedOffset(size_t nIdx, size_t k) { - return kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(nIdx, k, getNr(), getKr(), getSr()); + return kai_get_rhs_packed_offset_rhs_pack_nxk_qsi4cxp_qs4cxs1s0(nIdx, k, getNr(), getKr(), getSr()); } void KleidiAI::runRhsPack(size_t n, size_t k, const void* rhs, const void* scale, const void *bias, void* rhsPacked, bool packedInt4) { - struct kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0_params params; + struct kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0_params params; params.lhs_zero_point = 1; params.rhs_zero_point = 8; if(!packedInt4) { - packQsu4cxs1s0Qsi8cxp(1, n, k, getNr(), getKr(), getSr(), + packQs4cxs16s0Qsi8cx(1, n, k, getNr(), getKr(), getSr(), (const uint8_t *)rhs, (const float *)bias, (const float *)scale, rhsPacked, 0, &params); } else { - kai_run_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0(1, n, k, getNr(), getKr(), getSr(), - (const uint8_t *)rhs, - (const float *)bias, (const float *)scale, - rhsPacked, - 0, &params); + packQsi4cxps16s0Qs4cxs0s1(1, n, k, getNr(), getKr(), getSr(), + (const uint8_t *)rhs, + (const float *)bias, (const float *)scale, + rhsPacked, + 0, &params); } } diff --git a/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.h b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.h index 815befa36..0c8b73059 100644 --- a/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.h +++ b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.h @@ -21,7 +21,7 @@ #include #include "kai_lhs_quant_pack_qai8dxp_f32.h" -#include "kai_rhs_pack_nxk_qsi4cxp_qsu4cxs1s0.h" +#include "kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0.h" #include "kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod.h" #include "kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm.h" diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp index bbd18f2db..ba9b9d343 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp @@ -268,6 +268,35 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O int oc = convOp->common()->outputCount(); int ic = convOp->common()->inputCount(); bool directReadInt4weight = (kernelCount == 1 && ROUND_UP(oc, UNIT) == oc && ROUND_UP(ic, SRC_UNIT) == ic); + +#ifdef MNN_KLEIDIAI_ENABLED + KleidiAI kai = KleidiAI::getInstance(quanCommon->asymmetric); + if(quanCommon->canUseInt4 && kai.canAccelerate()) { + int n = oc; + int k = ic; + int packedWeightSize = kai.getRhsPackedSize(n, k); + + //Alloc packed weight tensor. + mResourceInt8->mWeightInt8.reset(Tensor::createDevice({packedWeightSize})); + bool success = backend->onAcquireBuffer(mResourceInt8->mWeightInt8.get(), Backend::STATIC); + + if (!success) { + MNN_ERROR("Out of static memory!\n"); + return; + } + + //Run rhs pack.
+ kai.runRhsPack(n, k, (uint8_t*)quanCommon->weight.get(), + mResourceInt8->mOriginScale->host(), + mResourceInt8->mOriginBias->host(), + mResourceInt8->mWeightInt8->host(), + true); + + return; + } + +#endif + if (quanCommon->canUseInt4 && directReadInt4weight) { // int4 weight reorder mResourceInt8->mWeightAsymmetricQuant = true; @@ -506,6 +535,25 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input int UNIT, SRC_UNIT, DST_XUNIT; core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); +#ifdef MNN_KLEIDIAI_ENABLED + KleidiAI& kai = KleidiAI::getInstance(); + if(mResourceInt8->mDynamicQuant && mResourceInt8->mActBits == 4 && kai.canAccelerate()) { + int batch = inputs[0]->batch(); + int channel = inputs[0]->channel(); + + int packedSize = kai.getLhsQuantedPackedSize(batch, channel); + mTempIm2ColBuffer.reset(Tensor::createDevice({packedSize})); + bool success = backend()->onAcquireBuffer(mTempIm2ColBuffer.get(), Backend::DYNAMIC); + if (!success) { + MNN_ERROR("Out of dynamic memory!\n"); + return OUT_OF_MEMORY; + } + + backend()->onReleaseBuffer(mTempIm2ColBuffer.get(), Backend::DYNAMIC); + return NO_ERROR; + } +#endif + if (mResourceInt8->mDynamicQuant == false) { mMutableResource->updateInputOutputScale(TensorUtils::getQuantInfo(inputs[0]), TensorUtils::getQuantInfo(outputs[0])); CPUConvolution::onResize(inputs, outputs); @@ -664,66 +712,65 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu #ifdef MNN_KLEIDIAI_ENABLED KleidiAI& kai = KleidiAI::getInstance(); - if(mDynamicQuantExe) { - if(mResource->mDequantize.bits == 4 && kai.canAccelerate()) { - const size_t m = input->batch(); //lhs vector number. - const size_t n = output->channel(); //rhs vector number. - const size_t k = input->channel(); //vector size. + if(mResourceInt8->mDynamicQuant && mResourceInt8->mActBits == 4 && kai.canAccelerate()) { + const size_t m = input->batch(); //lhs vector number. + const size_t n = output->channel(); //rhs vector number. + const size_t k = input->channel(); //vector size. - auto lhs = input->host(); - auto lhsPacked = mTempIm2ColBuffer->host(); - auto rhsPacked = mResourceInt8->mWeightInt8->host(); - auto dst = output->host(); + auto lhs = input->host(); + auto lhsPacked = mTempIm2ColBuffer->host(); + auto rhsPacked = mResourceInt8->mWeightInt8->host(); + auto dst = output->host(); - int threadNum = static_cast(backend())->threadNumber(); - int threadNeed, vecPerThread; + int threadNum = static_cast(backend())->threadNumber(); + int threadNeed, vecPerThread; #if !KAI_CONV_NCHW_IN_OUT - kai.packNC4HW4ToNCHW((float *)lhs, m, k); + kai.packNC4HW4ToNCHW((float *)lhs, m, k); #endif - //Dynamic quant pack lhs. - if(m == 1) { - kai.runLhsQuantPack(1, k, 1, lhs, lhsPacked); - } else { - vecPerThread = kai.getVecNumPerThread(m, threadNum, kai.getMr(m)); - threadNeed = m % vecPerThread == 0 ? m / vecPerThread : (m / vecPerThread + 1); - size_t srcStride = vecPerThread * k * sizeof(float); - - auto BatchDynamicQuant = [=, &kai](int tId) { - auto threadSrc = lhs + tId * srcStride; - auto threadDst = lhsPacked + kai.getLhsQuantedPackedOffset(m, tId * vecPerThread, k); - int vecNum = (tId == threadNeed - 1) ? (m - vecPerThread * tId) : vecPerThread; //Last threadN may less than vecPerThread. - kai.runLhsQuantPack(vecNum, k, kai.getMr(m), threadSrc, threadDst); - }; - - MNN_CONCURRENCY_BEGIN(tId, threadNeed) { - BatchDynamicQuant((int)tId); - } - MNN_CONCURRENCY_END(); - } - - //Run matmul. 
- vecPerThread = kai.getVecNumPerThread(n, threadNum, kai.getNStep()); - threadNeed = n % vecPerThread == 0 ? n / vecPerThread : (n / vecPerThread + 1); - - auto ThreadFunction = [=, &kai](int tId) { - auto threadRhsPacked = rhsPacked + kai.getRhsPackedOffset(tId * vecPerThread, k); - auto threadDst = dst + kai.getDstOffset(0, tId * vecPerThread, n); - int vecNum = (tId == threadNeed - 1) ? (n - vecPerThread * tId) : vecPerThread; //Last threadN may less than vecPerThread. - kai.runMatmul(m, vecNum, k, lhsPacked, threadRhsPacked, n * sizeof(float), threadDst); + //Dynamic quant pack lhs. + if(m == 1) { + kai.runLhsQuantPack(1, k, 1, lhs, lhsPacked); + } else { + vecPerThread = kai.getVecNumPerThread(m, threadNum, kai.getMr(m)); + threadNeed = m % vecPerThread == 0 ? m / vecPerThread : (m / vecPerThread + 1); + size_t srcStride = vecPerThread * k * sizeof(float); + + auto BatchDynamicQuant = [=, &kai](int tId) { + auto threadSrc = lhs + tId * srcStride; + auto threadDst = lhsPacked + kai.getLhsQuantedPackedOffset(m, tId * vecPerThread, k); + int vecNum = (tId == threadNeed - 1) ? (m - vecPerThread * tId) : vecPerThread; //Last threadN may less than vecPerThread. + kai.runLhsQuantPack(vecNum, k, kai.getMr(m), threadSrc, threadDst); }; MNN_CONCURRENCY_BEGIN(tId, threadNeed) { - ThreadFunction((int)tId); + BatchDynamicQuant((int)tId); } MNN_CONCURRENCY_END(); + } + + //Run matmul. + vecPerThread = kai.getVecNumPerThread(n, threadNum, kai.getNStep()); + threadNeed = n % vecPerThread == 0 ? n / vecPerThread : (n / vecPerThread + 1); + + auto ThreadFunction = [=, &kai](int tId) { + auto threadRhsPacked = rhsPacked + kai.getRhsPackedOffset(tId * vecPerThread, k); + auto threadDst = dst + kai.getDstOffset(0, tId * vecPerThread, n); + int vecNum = (tId == threadNeed - 1) ? (n - vecPerThread * tId) : vecPerThread; //Last threadN may less than vecPerThread. + kai.runMatmul(m, vecNum, k, lhsPacked, threadRhsPacked, n * sizeof(float), threadDst); + }; + + MNN_CONCURRENCY_BEGIN(tId, threadNeed) { + ThreadFunction((int)tId); + } + MNN_CONCURRENCY_END(); #if !KAI_CONV_NCHW_IN_OUT - kai.packNCHWToNC4HW4((float *)dst, m, n); + kai.packNCHWToNC4HW4((float *)dst, m, n); #endif - return NO_ERROR; - } + + return NO_ERROR; } #endif From 39dadd08d2262466fbd8489bedd0e4067a29829c Mon Sep 17 00:00:00 2001 From: "xinhao.zheng" Date: Tue, 22 Oct 2024 15:11:21 +0800 Subject: [PATCH 5/8] Refine some code --- source/backend/cpu/compute/ConvInt8TiledExecutor.cpp | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp index ba9b9d343..47c761e3a 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp @@ -603,8 +603,6 @@ ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& input int tileLimit = 0; int outC = output->channel(); int outC4 = UP_DIV(outC, gcore->pack); - int totalWork = outC4; - int part = 1; if (threads < planeSize) { // Thread split by output nhw. 
tileLimit = ALIMIN(tileLimitByC, UP_DIV(planeSize, threads)); @@ -1096,7 +1094,7 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu } else { MNN_CONCURRENCY_BEGIN(tId, threads) { int ocIndex = PackUnit * mDivides[tId]; - if (ocIndex < ocUp4){ + if (ocIndex < ocUp4) { ThreadFunction((int)tId, 0, mTileCount,1, ocIndex); } } From a1cbdf11cec977b627fc282ea892f46b3b990460 Mon Sep 17 00:00:00 2001 From: "xinhao.zheng" Date: Thu, 24 Oct 2024 08:23:38 +0800 Subject: [PATCH 6/8] Refine CMakeLists.txt --- CMakeLists.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/CMakeLists.txt b/CMakeLists.txt index 805616041..9983eae10 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -207,6 +207,7 @@ option(MNN_OPENCL "Enable OpenCL" OFF) option(MNN_OPENGL "Enable OpenGL" OFF) option(MNN_VULKAN "Enable Vulkan" OFF) option(MNN_ARM82 "Enable ARMv8.2's FP16 Compute" ON) +option(MNN_KLEIDIAI "Enable KleidiAI" OFF) option(MNN_ONEDNN "Enable oneDNN" OFF) option(MNN_AVX512 "Enable AVX512" OFF) option(MNN_CUDA "Enable CUDA" OFF) From 28115248d7e170969b760eb589c010a04fed5ac8 Mon Sep 17 00:00:00 2001 From: "xinhao.zheng" Date: Thu, 24 Oct 2024 14:07:29 +0800 Subject: [PATCH 7/8] Refine rhs pack --- .../cpu/compute/ConvInt8TiledExecutor.cpp | 38 +++++++++---------- 1 file changed, 19 insertions(+), 19 deletions(-) diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp index 47c761e3a..bdaa08045 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp @@ -271,29 +271,29 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O #ifdef MNN_KLEIDIAI_ENABLED KleidiAI kai = KleidiAI::getInstance(quanCommon->asymmetric); - if(quanCommon->canUseInt4 && kai.canAccelerate()) { - int n = oc; - int k = ic; - int packedWeightSize = kai.getRhsPackedSize(n, k); + if(quanCommon->canUseInt4 && kai.canAccelerate()) { + int n = oc; + int k = ic; + int packedWeightSize = kai.getRhsPackedSize(n, k); - //Alloc packed weight tensor. - mResourceInt8->mWeightInt8.reset(Tensor::createDevice({packedWeightSize})); - bool success = backend->onAcquireBuffer(mResourceInt8->mWeightInt8.get(), Backend::STATIC); - - if (!success) { - MNN_ERROR("Out of static memory!\n"); - return; - } - - //Run rhs pack. - kai.runRhsPack(n, k, (uint8_t*)quanCommon->weight.get(), - mResourceInt8->mOriginScale->host(), - mResourceInt8->mOriginBias->host(), - mResourceInt8->mWeightInt8->host(), - true); + //Alloc packed weight tensor. + mResourceInt8->mWeightInt8.reset(Tensor::createDevice({packedWeightSize})); + bool success = backend->onAcquireBuffer(mResourceInt8->mWeightInt8.get(), Backend::STATIC); + if (!success) { + MNN_ERROR("Out of static memory!\n"); return; } + + //Run rhs pack. + kai.runRhsPack(n, k, (uint8_t*)quanCommon->weight.get(), + mResourceInt8->mOriginScale->host(), + mResourceInt8->mOriginBias->host(), + mResourceInt8->mWeightInt8->host(), + directReadInt4weight); + + return; + } #endif From 8f6a1234ae4be0040f3f9ccb8f552ca4dbeb5e91 Mon Sep 17 00:00:00 2001 From: yanxing Date: Mon, 28 Oct 2024 16:53:06 +0800 Subject: [PATCH 8/8] Add acthalf and blockwise conditions in canAccelerate.
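The integrated KleidiAI int4 micro-kernels only cover float32 activations with symmetric per-channel int4 weights, so the gate now also rejects half-precision activations and block-wise weight quantization. In sketch form, the check added below reads:

    bool canAccelerate() {
        return mKaiInfo.kaiEnable && mKaiInfo.dot && mKaiInfo.i8mm &&
               !mKaiInfo.asymmetric && !mKaiInfo.acthalf && !mKaiInfo.blockwise;
    }

The two new flags are derived in ConvInt8TiledExecutor: acthalf from gcore->bytes == 2 (fp16 activation), and blockwise from (biasSize * 2) != alphaSize, i.e. the packed scale buffer no longer holds exactly one scale and one bias per output channel, which implies block-wise scales.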
--- .../backend/cpu/arm/kleidiAI/mnn_kleidiai.h | 16 ++++++++--- .../cpu/compute/ConvInt8TiledExecutor.cpp | 28 +++++++++++-------- 2 files changed, 28 insertions(+), 16 deletions(-) diff --git a/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.h b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.h index 0c8b73059..38cdce230 100644 --- a/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.h +++ b/source/backend/cpu/arm/kleidiAI/mnn_kleidiai.h @@ -30,9 +30,9 @@ namespace MNN { class KleidiAI { public: - static KleidiAI &getInstance(bool bAsymmetric) { + static KleidiAI &getInstance(bool bAsymmetric, bool acthalf, bool blockwise) { if(!instance) { - instance = new KleidiAI(bAsymmetric); + instance = new KleidiAI(bAsymmetric, acthalf, blockwise); } return *instance; } @@ -49,6 +49,8 @@ namespace MNN { typedef struct KaiInfo { bool kaiEnable = false; bool asymmetric = false; //Asymmetric quantized model. + bool acthalf = false; // activation half precision. + bool blockwise = false; // weight quant using block wise. bool dot = false; //CPU support sdot. bool i8mm = false; //CPU support i8mm. } KaiInfo; @@ -62,7 +64,10 @@ namespace MNN { void setModelAsymmetric(bool bAsymmetric); //Check - bool canAccelerate() { return (mKaiInfo.kaiEnable && mKaiInfo.dot && mKaiInfo.i8mm && !mKaiInfo.asymmetric); } + bool canAccelerate() { + return (mKaiInfo.kaiEnable && mKaiInfo.dot && mKaiInfo.i8mm && + !mKaiInfo.asymmetric && !mKaiInfo.acthalf && !mKaiInfo.blockwise); + } //Get info size_t getMr(size_t m = 1) { return (m == 1) ? mKaiMrDotprod : mKaiMrI8mm; } @@ -90,12 +95,15 @@ namespace MNN { void runMatmul(size_t m, size_t n, size_t k, const void* lhsPacked, const void* rhsPacked, size_t dst_stride, void* dst); private: - KleidiAI(bool bAsymmetric = false) { + KleidiAI(bool bAsymmetric = false, bool acthalf = false, bool blockwise = false) { const MNNCPUInfo& gCPUInfo = *MNNGetCPUInfo(); mKaiInfo.dot = gCPUInfo.dot; mKaiInfo.i8mm = gCPUInfo.i8mm; mKaiInfo.kaiEnable = true; mKaiInfo.asymmetric = bAsymmetric; + mKaiInfo.acthalf = acthalf; + mKaiInfo.blockwise = blockwise; + if(canAccelerate()) { MNN_PRINT("\nKleidiAI is running!\n"); } diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp index bdaa08045..4788d88c3 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp @@ -83,7 +83,7 @@ void ConvInt8TiledExecutor::reorderWeight(Tensor* weight, const uint8_t* weightS for (int y = 0; y < ic; ++y) { const int yOutSide = y / SRC_UNIT; const int yInSide = y % SRC_UNIT; - + int blockId = (yOutSide + k * icDivU) / blockL; int blockInsideId = (yOutSide + k * icDivU) % blockL; @@ -268,9 +268,13 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O int oc = convOp->common()->outputCount(); int ic = convOp->common()->inputCount(); bool directReadInt4weight = (kernelCount == 1 && ROUND_UP(oc, UNIT) == oc && ROUND_UP(ic, SRC_UNIT) == ic); - + #ifdef MNN_KLEIDIAI_ENABLED - KleidiAI kai = KleidiAI::getInstance(quanCommon->asymmetric); + bool half_act = gcore->bytes == 2; + int biasSize = mResourceInt8->mOriginBias->size(); + int alphaSize = mResourceInt8->mOriginScale->size(); + bool blockwise = (biasSize * 2) != alphaSize; + KleidiAI kai = KleidiAI::getInstance(quanCommon->asymmetric, half_act, blockwise); if(quanCommon->canUseInt4 && kai.canAccelerate()) { int n = oc; int k = ic; @@ -294,9 +298,9 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, 
const O return; } - + #endif - + if (quanCommon->canUseInt4 && directReadInt4weight) { // int4 weight reorder mResourceInt8->mWeightAsymmetricQuant = true; @@ -305,7 +309,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O int lU = UP_DIV(ic, SRC_UNIT); int hP = UNIT; int lP = SRC_UNIT; - + // weight shape. std::vector shape; if (SRC_UNIT > pack) { @@ -337,7 +341,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O int blockkInsideId = j % blockL; for (int k = 0; k < cnt; ++k) { int dstIndx0 = (blockId * stride0 + i * stride1 + blockkInsideId * lP * hP) / 2 + (2 * k); - + int hpId0 = (2 * k + 1) / lP; int lpId0 = (2 * k) % lP; int hpId1 = (2 * (k + cnt) + 1) / lP; @@ -350,7 +354,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O int s3 = (srcPtr[srcIndx1] & 15); int d0 = s0 * 16 + s2; int d1 = s1 * 16 + s3; - + dstPtr[dstIndx0] = d0; dstPtr[dstIndx0 + 1] = d1; } @@ -358,7 +362,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O } } else { // std::shared_ptr srcWeight; - + if (quanCommon->canUseInt4) { mResourceInt8->mWeightAsymmetricQuant = true; auto srcPtr = reinterpret_cast(quanCommon->weight.get()); @@ -392,7 +396,7 @@ DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const O dst0[j] = d; } } - + // Update int4 weight to mWeightInt8. mResourceInt8->mWeightInt8 = weightLow; } else { @@ -434,7 +438,7 @@ static void _computeAlphaScale(Backend* backend, const Convolution2D* conv2d, st auto alphaPtr = scaleBias->host(); auto biasPtr = reinterpret_cast(reinterpret_cast(alphaPtr) + ocUp4 * core->bytes); ::memset(alphaPtr, 0, 2 * ocUp4 * core->bytes); - + // Load quant scale and bias weightOrigin = resourceInt8->mWeightInt8->host(); auto wZero = resourceInt8->mWeightQuantZero->host(); // has packed to outputUp4 @@ -454,7 +458,7 @@ static void _computeAlphaScale(Backend* backend, const Convolution2D* conv2d, st } } resourceInt8->mOriginScale = scaleBias; - + // Compute float weightKernelSum resourceInt8->mWeightKernelSum.reset(Tensor::createDevice({ocUp4 * 4})); success = backend->onAcquireBuffer(resourceInt8->mWeightKernelSum.get(), Backend::STATIC);
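Taken together, the series wires KleidiAI in as follows: pack the int4 weights once when the executor is built, then, per inference, dynamically quantize and pack the f32 activations and run the matmul micro-kernel. Below is a minimal single-threaded sketch against the mnn_kleidiai.h interface introduced above; the names m, n, k, weightInt4, scale, bias, lhsF32 and dstF32 are illustrative, and threading, buffer reuse and the fallback path are omitted:

    #include <cstddef>
    #include <cstdint>
    #include <vector>
    #include "mnn_kleidiai.h"

    // Sketch: f32 GEMM with a dynamically quantized int8 LHS and packed int4 RHS.
    void int4GemmSketch(size_t m, size_t n, size_t k,
                        const uint8_t* weightInt4, const float* scale, const float* bias,
                        const float* lhsF32, float* dstF32) {
        MNN::KleidiAI& kai = MNN::KleidiAI::getInstance(false /*asymmetric*/, false /*acthalf*/, false /*blockwise*/);
        if (!kai.canAccelerate()) {
            return; // caller falls back to the default MNN int8 path
        }

        // One-time: pack the int4 weights together with per-channel scale/bias
        // (done in the DenseConvInt8TiledExecutor constructor).
        std::vector<uint8_t> rhsPacked(kai.getRhsPackedSize(n, k));
        kai.runRhsPack(n, k, weightInt4, scale, bias, rhsPacked.data(), true /*packedInt4*/);

        // Per inference: dynamic 8-bit quantization plus packing of the f32 LHS
        // (done in onResize/onExecute), then the clamp_f32 matmul micro-kernel
        // writes f32 results with row stride n * sizeof(float).
        std::vector<uint8_t> lhsPacked(kai.getLhsQuantedPackedSize(m, k));
        kai.runLhsQuantPack(m, k, kai.getMr(m), lhsF32, lhsPacked.data());
        kai.runMatmul(m, n, k, lhsPacked.data(), rhsPacked.data(), n * sizeof(float), dstF32);
    }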