diff --git a/CMakeLists.txt b/CMakeLists.txt index abb66d6c8..b85353bb3 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -408,6 +408,8 @@ include(${CMAKE_CURRENT_LIST_DIR}/source/backend/cpu/CMakeLists.txt) SET(MNN_PUB_HDRS "") SET(MNN_EXPR_PUB_HDRS "") +set(MNN_EXTRA_HEADERS "") + list(APPEND MNN_PUB_HDRS "${CMAKE_CURRENT_SOURCE_DIR}/include/MNN/MNNDefine.h") list(APPEND MNN_PUB_HDRS "${CMAKE_CURRENT_SOURCE_DIR}/include/MNN/Interpreter.hpp") list(APPEND MNN_PUB_HDRS "${CMAKE_CURRENT_SOURCE_DIR}/include/MNN/HalideRuntime.h") @@ -430,6 +432,19 @@ list(APPEND MNN_EXPR_PUB_HDRS "${CMAKE_CURRENT_SOURCE_DIR}/include/MNN/expr/Neur list(APPEND MNN_EXPR_PUB_HDRS "${CMAKE_CURRENT_SOURCE_DIR}/include/MNN/expr/ExecutorScope.hpp") list(APPEND MNN_EXPR_PUB_HDRS "${CMAKE_CURRENT_SOURCE_DIR}/include/MNN/expr/Scope.hpp") +# Add Extra Header +IF(MNN_BUILD_OPENCV) + file(GLOB MNN_CV_HDRS ${CMAKE_CURRENT_SOURCE_DIR}/tools/cv/include/cv/*.hpp PARENT_SCOPE) + file(GLOB MNN_CV_IMGHDRS ${CMAKE_CURRENT_SOURCE_DIR}/tools/cv/include/cv/imgproc/*.hpp PARENT_SCOPE) + list(APPEND MNN_EXTRA_HEADERS ${MNN_CV_HDRS}) + list(APPEND MNN_EXTRA_HEADERS ${MNN_CV_IMGHDRS}) +ENDIF() +IF(MNN_BUILD_LLM) + file(GLOB MNN_LLM_HDRS ${CMAKE_CURRENT_SOURCE_DIR}/transformers/llm/engine/include/llm/*) + list(APPEND MNN_EXTRA_HEADERS ${CMAKE_CURRENT_SOURCE_DIR}/transformers/llm/engine/include/llm/llm.hpp) +ENDIF() + + set(MNN_DEPS "") set(MNN_EXTRA_DEPENDS "") @@ -659,11 +674,11 @@ IF(MNN_TENSORRT) ENDIF() IF(MNN_SEP_BUILD) - add_library(MNN SHARED ${CMAKE_CURRENT_LIST_DIR}/cmake/dummy.cpp ${MNN_OBJECTS_TO_LINK} ${MNN_PUB_HDRS} ${MNN_EXPR_PUB_HDRS}) + add_library(MNN SHARED ${CMAKE_CURRENT_LIST_DIR}/cmake/dummy.cpp ${MNN_OBJECTS_TO_LINK} ${MNN_PUB_HDRS} ${MNN_EXPR_PUB_HDRS} ${MNN_EXTRA_HEADERS}) target_link_libraries(MNN PUBLIC ${MNN_EXTRA_DEPENDS}) ELSE() IF(MNN_BUILD_SHARED_LIBS) - add_library(MNN SHARED ${CMAKE_CURRENT_LIST_DIR}/cmake/dummy.cpp ${MNN_OBJECTS_TO_LINK} ${MNN_PUB_HDRS} ${MNN_EXPR_PUB_HDRS}) + add_library(MNN SHARED ${CMAKE_CURRENT_LIST_DIR}/cmake/dummy.cpp ${MNN_OBJECTS_TO_LINK} ${MNN_PUB_HDRS} ${MNN_EXPR_PUB_HDRS} ${MNN_EXTRA_HEADERS}) if (WIN32) foreach(TARGET ${MNN_TARGETS}) target_compile_definitions(${TARGET} PRIVATE "-DBUILDING_MNN_DLL") @@ -673,7 +688,7 @@ ELSE() target_compile_definitions(MNN INTERFACE "-DUSING_MNN_DLL") endif() ELSE() - add_library(MNN STATIC ${CMAKE_CURRENT_LIST_DIR}/cmake/dummy.cpp ${MNN_OBJECTS_TO_LINK} ${MNN_PUB_HDRS} ${MNN_EXPR_PUB_HDRS}) + add_library(MNN STATIC ${CMAKE_CURRENT_LIST_DIR}/cmake/dummy.cpp ${MNN_OBJECTS_TO_LINK} ${MNN_PUB_HDRS} ${MNN_EXPR_PUB_HDRS} ${MNN_EXTRA_HEADERS}) ENDIF() target_link_libraries(MNN PUBLIC ${MNN_EXTRA_DEPENDS}) ENDIF() @@ -729,7 +744,6 @@ IF(MNN_BUILD_OPENCV AND NOT MNN_SEP_BUILD) ENDIF() target_sources(MNN PRIVATE $) ENDIF() - IF(MNN_BUILD_LLM) # add_definitions(-DMNN_BUILD_LLM) include(${CMAKE_CURRENT_LIST_DIR}/transformers/llm/engine/CMakeLists.txt) @@ -831,6 +845,27 @@ ELSE() ARCHIVE DESTINATION lib FRAMEWORK DESTINATION /Library/Frameworks/ ) + IF(MNN_BUILD_OPENCV) + if (NOT MNN_AAPL_FMWK) + INSTALL(FILES ${MNN_CV_HDRS} DESTINATION include/MNN/cv) + INSTALL(FILES ${MNN_CV_IMGHDRS} DESTINATION include/MNN/cv/imgproc) + endif() + FOREACH(HDR ${MNN_CV_HDRS}) + SET_SOURCE_FILES_PROPERTIES(${HDR} PROPERTIES MACOSX_PACKAGE_LOCATION Headers/cv/ ) + ENDFOREACH() + FOREACH(HDR ${MNN_CV_IMGHDRS}) + SET_SOURCE_FILES_PROPERTIES(${HDR} PROPERTIES MACOSX_PACKAGE_LOCATION Headers/cv/imgproc ) + ENDFOREACH() + ENDIF() + IF(MNN_BUILD_LLM) + if (NOT MNN_AAPL_FMWK) + 
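+      # install LLM public headers next to the core MNN headers; the Apple framework build instead packages them under Headers/llm via MACOSX_PACKAGE_LOCATION below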
INSTALL(FILES ${MNN_LLM_HDRS} DESTINATION include/MNN/llm) + endif() + FOREACH(HDR ${MNN_LLM_HDRS}) + SET_SOURCE_FILES_PROPERTIES(${HDR} PROPERTIES MACOSX_PACKAGE_LOCATION Headers/llm ) + ENDFOREACH() + ENDIF() + if (NOT MNN_AAPL_FMWK) INSTALL(FILES ${MNN_PUB_HDRS} DESTINATION include/MNN/) INSTALL(FILES ${MNN_EXPR_PUB_HDRS} DESTINATION include/MNN/expr/) diff --git a/source/backend/cpu/arm/arm32/bf16/MNNAxByClampBroadcastC4_BF16.S b/backupcode/cpubackend/arm/arm32/bf16/MNNAxByClampBroadcastC4_BF16.S similarity index 100% rename from source/backend/cpu/arm/arm32/bf16/MNNAxByClampBroadcastC4_BF16.S rename to backupcode/cpubackend/arm/arm32/bf16/MNNAxByClampBroadcastC4_BF16.S diff --git a/source/backend/cpu/arm/arm32/bf16/MNNConvRunForLineDepthwise_BF16.S b/backupcode/cpubackend/arm/arm32/bf16/MNNConvRunForLineDepthwise_BF16.S similarity index 100% rename from source/backend/cpu/arm/arm32/bf16/MNNConvRunForLineDepthwise_BF16.S rename to backupcode/cpubackend/arm/arm32/bf16/MNNConvRunForLineDepthwise_BF16.S diff --git a/source/backend/cpu/arm/arm32/bf16/MNNConvRunForUnitDepthWise_BF16.S b/backupcode/cpubackend/arm/arm32/bf16/MNNConvRunForUnitDepthWise_BF16.S similarity index 100% rename from source/backend/cpu/arm/arm32/bf16/MNNConvRunForUnitDepthWise_BF16.S rename to backupcode/cpubackend/arm/arm32/bf16/MNNConvRunForUnitDepthWise_BF16.S diff --git a/source/backend/cpu/arm/arm32/bf16/MNNGelu_BF16.S b/backupcode/cpubackend/arm/arm32/bf16/MNNGelu_BF16.S similarity index 100% rename from source/backend/cpu/arm/arm32/bf16/MNNGelu_BF16.S rename to backupcode/cpubackend/arm/arm32/bf16/MNNGelu_BF16.S diff --git a/backupcode/cpubackend/arm/arm32/bf16/MNNPackC4ForMatMul_A_BF16.S b/backupcode/cpubackend/arm/arm32/bf16/MNNPackC4ForMatMul_A_BF16.S new file mode 100644 index 000000000..663ffae68 --- /dev/null +++ b/backupcode/cpubackend/arm/arm32/bf16/MNNPackC4ForMatMul_A_BF16.S @@ -0,0 +1,208 @@ +// +// NEON_MNNPackC4ForMatMul_A_BF16.S +// MNN +// +// Created by MNN on 2021/02/21. 
+// Copyright © 2018-2021 Alibaba Group Holding Limited +// +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +asm_function NEON_MNNPackC4ForMatMul_A_BF16 +// treate float pointer as int16_t* +//void NEON_MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) +//Auto: r0: dest, r1:sourceGroup, r2: info, r3:el +push {r4-r8, r10, r11, lr} // avoid to touch platform-register r-9 +ldr r10, [r2, #0] // number +ldr r4, [r2, #4] // eReal +ldr r11, [r2, #8] // eDest +ldr r6, [r2, #12] // xOffset +// xOffset -> xOffset * 4 * sizeof(float) +// eReal -> eReal * 4 * sizeof(float) +// eDest -> eDest * sizeof(float) +mov r12, #2 // sizeof(int16_t) +mov lr, #8 // sizeof(int16_t) * 4 +mul r4, lr, r4 +mul r11, r12, r11 +mul r6, lr, r6 + +LoopNumber: +ldr r5, [r3, #4] // l +ldr r8, [r3, #8] // eOffset +ldr r7, [r3, #12] // lOffset + +push {r0, r1} +ldr r1, [r1, #0] + +// Compute dest ptr: r0 = r0 + eOffset * sizeof(float) + lOffset * eDest * sizeof(float) +; mov lr, #2 //sizeof(int16_t) +mul r7, r11, r7 +mul r8, r12, r8 +add r0, r0, r7 +add r0, r0, r8 + +ldr r2, [r3, #0] // e + +Body: +cmp r2, #12 +bne Right + cmp r5, #4 + blt LoopEL3 + LoopL4: + mov r2, r1 +.macro MAIN_TRANSPOSE + vld1.16 {d16}, [r1], r6 // load size: 4 * sizeof(int16_t) + vld1.16 {d19}, [r1], r6 + vld1.16 {d22}, [r1], r6 + vld1.16 {d25}, [r1], r6 + vld1.16 {d17}, [r1], r6 + vld1.16 {d20}, [r1], r6 + vld1.16 {d23}, [r1], r6 + vld1.16 {d26}, [r1], r6 + vld1.16 {d18}, [r1], r6 + vld1.16 {d21}, [r1], r6 + vld1.16 {d24}, [r1], r6 + vld1.16 {d27}, [r1], r6 + + // transpose each 4 16-bit elements in 2 d_n vectors, by transpose 16-bit and scale up transpose 32-bit. + vtrn.16 d16, d19 + vtrn.16 d22, d25 + // vswp d0[2-3], d2[0-1] + // vswp d1[2-3], d3[0-1] + // swap half of 64-bit is equal to transpose in 32-bit unit. 
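+    // vtrn.32 below swaps the 32-bit halves between each register pair, which together with
+    // the preceding vtrn.16 completes the 4x4 transpose of 16-bit elements.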
+ vtrn.32 d16, d22 + vtrn.32 d19, d25 + + vtrn.16 d17, d20 + vtrn.16 d23, d26 + vtrn.32 d17, d23 + vtrn.32 d20, d26 + + vtrn.16 d18, d21 + vtrn.16 d24, d27 + vtrn.32 d18, d24 + vtrn.32 d21, d27 + // after transpose from 12x4 to 4x12, memory layout is + // +-------+------+------+ + // | d16...|d17...|d18...| + // +-------+------+------+ + // | d19...|d20...|d21...| + // +-------+------+------+ + // | d22...|d23...|d24...| + // +-------+------+------+ + // | d25...|d26...|d27...| + // +-------+------+------+ +.endm + MAIN_TRANSPOSE + + vstm r0!, {d16, d17, d18, d19, d20, d21, d22, d23, d24, d25, d26, d27} // store at one time: 12 * 4 * sizeof(int16_t) + + add r1, r2, r4 + sub r5, r5, #4 + cmp r5, #4 + bge LoopL4 + + LoopEL3: + cmp r5, #3 + blt LoopEL2 + MAIN_TRANSPOSE + + vstm r0!, {d16, d17, d18, d19, d20, d21, d22, d23, d24} + + b LoopEEnd + + LoopEL2: + cmp r5, #2 + blt LoopEL1 + MAIN_TRANSPOSE + + vstm r0!, {d16, d17, d18, d19, d20, d21} + + b LoopEEnd + + LoopEL1: + cmp r5, #0 + beq LoopEEnd + MAIN_TRANSPOSE + + vstm r0!, {d16, d17, d18} + + LoopEEnd: + +b End + + +Right: + +LoopE1: + mov lr, r5 + mov r7, r1 + mov r8, r0 + cmp r5, #4 + blt LoopE1L3 + LoopE1L4: + vld1.16 {d0}, [r1], r4 + vst1.16 {d0[0]}, [r0], r11 + vst1.16 {d0[1]}, [r0], r11 + vst1.16 {d0[2]}, [r0], r11 + vst1.16 {d0[3]}, [r0], r11 + sub r5, r5, #4 + cmp r5, #4 + bge LoopE1L4 + + LoopE1L3: + cmp r5, #3 + blt LoopE1L2 + vld1.16 {d0}, [r1], r4 + vst1.16 {d0[0]}, [r0], r11 + vst1.16 {d0[1]}, [r0], r11 + vst1.16 {d0[2]}, [r0], r11 + + sub r5, r5, #3 + + LoopE1L2: + cmp r5, #2 + blt LoopE1L1 + vld1.16 {d0}, [r1], r4 + vst1.16 {d0[0]}, [r0], r11 + vst1.16 {d0[1]}, [r0], r11 + sub r5, r5, #2 + + LoopE1L1: + cmp r5, #1 + blt LoopE1End + vld1.16 {d0[0]}, [r1], r4 + vst1.16 {d0[0]}, [r0], r11 + + LoopE1End: + + subs r2, r2, #1 + add r0, r8, r12 // !!!! caution : sizeof(int16_t) + add r1, r7, r6 + mov r5, lr + bne LoopE1 + +End: + +pop {r0, r1} +subs r10, r10, #1 + +// x3 is (const int32_t* el), this array size of 4. as a result for next struct element, +// address added by 4 * sizeof(int32_t) +add r3, r3, #16 + +// x1 is (const int16_t** sourceGroup), even though data content is int16_t, +// the element in sourceGroup in 'int16_t*', as a result for next struct element, +// value added by sizeof(void*) +add r1, r1, #4 + +bne LoopNumber + +pop {r4-r8, r10, r11, pc} + +#endif +#endif diff --git a/backupcode/cpubackend/arm/arm32/bf16/MNNPackC4_BF16.S b/backupcode/cpubackend/arm/arm32/bf16/MNNPackC4_BF16.S new file mode 100644 index 000000000..70b9e61e4 --- /dev/null +++ b/backupcode/cpubackend/arm/arm32/bf16/MNNPackC4_BF16.S @@ -0,0 +1,202 @@ +// +// MNNPackC4_BF16.S +// MNN +// +// Created by MNN on 2021/02/26. +// Copyright © 2018-2021 Alibaba Group Holding Limited +// + + + + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" +.text +.align 5 + +// .macro transpose +// vtrn.16 d0, d1 +// vtrn.16 d2, d3 +// vswp d0[2-3], d1[2-3] // should swap high half of d-vector, the half is 32-bit. 
there is no instruction, we use vst4.16 instead +// vswp d2[2-3], d3[2-3] +// .endm + +asm_function MNNPackC4_BF16 +// treate float pointer as int16_t* +//void MNNPackC4_BF16(float* dst, const float* src, size_t area, size_t depth, int32_t* areaOffset) +//Auto load: +//r0:dst, r1:src, r2:area, r3:depth + +push {r4-r8, r10, lr} // avoid to touch platform-register r-9 + +ldr lr, [sp, #28] +ldr r10, [lr, #4] +ldr lr, [lr, #0] + +mul r4, r2, r3 +cmp r4, #0 +beq UpEnd + +//r4: srcDepthOffset:srcArea*sizeof(int16_t) +mov r4, #2 +mul r4, lr, r4 + +//r10 -> 4 * (dstArea * sizeof(int16_t) - area * sizeof(int16_t)) +mov r12, #8 +sub r10, r10, r2 +mul r10, r12, r10 + +//lr -> (srcArea * sizeof(int16_t) - area * sizeof(int16_t)) +mov r12, #2 +sub lr, lr, r2 +mul lr, r12, lr + +UpL4: +cmp r3, #3 +ble UpL3 + +UpL4Loop: +add r5, r1, r4 +add r6, r4, r5 +add r7, r4, r6 +mov r8, r2 +cmp r8, #3 +ble UpL4AreaRemain +UpL4AreaLoop: +vld1.16 {d0}, [r1]! // load 4 elements of 16-bit into 64bit vector register d0 +vld1.16 {d1}, [r5]! +vld1.16 {d2}, [r6]! +vld1.16 {d3}, [r7]! +// transpose // no suitable instruction to transpose int16_t type +vst4.16 {d0, d1, d2, d3}, [r0]! +sub r8, r8, #4 +cmp r8, #4 +bge UpL4AreaLoop + +UpL4AreaRemain: +cmp r8, #0 +beq UpL4AreaRemainEnd +UpL4AreaRemainLoop: +vld1.16 {d0[0]}, [r1]! +vld1.16 {d0[1]}, [r5]! +vld1.16 {d0[2]}, [r6]! +vld1.16 {d0[3]}, [r7]! + +vst1.16 {d0}, [r0]! + +subs r8, r8, #1 +bne UpL4AreaRemainLoop +UpL4AreaRemainEnd: +sub r3, r3, #4 +add r1, r7, lr +cmp r3, #4 +add r0, r10, r0 +bge UpL4Loop + +UpL3: +cmp r3, #2 +ble UpL2 +add r5, r1, r4 +add r6, r4, r5 +mov r8, r2 +cmp r8, #3 +ble UpL3AreaRemain +UpL3AreaLoop: +vld1.16 {d0}, [r1]! +vmov.i16 d3, #0 +vld1.16 {d1}, [r5]! +vld1.16 {d2}, [r6]! +// transpose // no suitable instruction to transpose int16_t type +vst4.16 {d0, d1, d2, d3}, [r0]! +sub r8, r8, #4 +cmp r8, #4 +bge UpL3AreaLoop + +cmp r8, #0 +beq UpL3AreaRemainEnd +UpL3AreaRemain: +vmov.i16 d0, #0 +vld1.16 {d0[0]}, [r1]! +vld1.16 {d0[1]}, [r5]! +vld1.16 {d0[2]}, [r6]! + +vst1.16 {d0}, [r0]! + +subs r8, r8, #1 +bne UpL3AreaRemain + +UpL3AreaRemainEnd: +sub r3, r3, #3 + + +UpL2: +cmp r3, #1 +ble UpL1 +add r5, r1, r4 +mov r8, r2 +cmp r8, #3 +ble UpL2AreaRemain +UpL2AreaLoop: +vld1.16 {d0}, [r1]! +vmov.i16 d3, #0 +vld1.16 {d1}, [r5]! +vmov.i16 d2, #0 +// transpose // no suitable instruction to transpose int16_t type +vst4.16 {d0, d1, d2, d3}, [r0]! +sub r8, r8, #4 +cmp r8, #4 +bge UpL2AreaLoop + +cmp r8, #0 +beq UpL2AreaRemainEnd +UpL2AreaRemain: +vmov.i16 d0, #0 +vld1.16 {d0[0]}, [r1]! +vld1.16 {d0[1]}, [r5]! + +vst1.16 {d0}, [r0]! + +subs r8, r8, #1 +bne UpL2AreaRemain + +UpL2AreaRemainEnd: +sub r3, r3, #2 + +UpL1: +cmp r3, #0 +beq UpEnd +mov r8, r2 +cmp r8, #3 +ble UpL1AreaRemain +UpL1AreaLoop: +vld1.16 {d0}, [r1]! +vmov.i16 d3, #0 +vmov.i16 d1, #0 +vmov.i16 d2, #0 +// transpose // no suitable instruction to transpose int16_t type +vst4.16 {d0, d1, d2, d3}, [r0]! +sub r8, r8, #4 +cmp r8, #4 +bge UpL1AreaLoop + +cmp r8, #0 +beq UpL1AreaRemainEnd +UpL1AreaRemain: +vmov.i16 d0, #0 +vld1.16 {d0[0]}, [r1]! + +vst1.16 {d0}, [r0]! 
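+// d0 was zeroed before the single-lane load, so each remaining element is stored as {value, 0, 0, 0} to pad the C4 layout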
+ +subs r8, r8, #1 +bne UpL1AreaRemain + +UpL1AreaRemainEnd: + +UpEnd: + +pop {r4-r8, r10, pc} + +#endif +#endif diff --git a/backupcode/cpubackend/arm/arm32/bf16/MNNPackedMatMulRemain_BF16.S b/backupcode/cpubackend/arm/arm32/bf16/MNNPackedMatMulRemain_BF16.S new file mode 100644 index 000000000..252f1956a --- /dev/null +++ b/backupcode/cpubackend/arm/arm32/bf16/MNNPackedMatMulRemain_BF16.S @@ -0,0 +1,154 @@ +// +// NEON_MNNPackedMatMulRemain_BF16.S +// MNN +// +// Created by MNN on 2021/02/24. +// Copyright © 2018-2021 Alibaba Group Holding Limited. +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +// 12 * 8 MatMul +asm_function NEON_MNNPackedMatMulRemain_BF16 +// treate float pointer as int16_t* +//void NEON_MNNPackedMatMulRemain_BF16(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias); +//Auto r0: C, r1:A, r2:B, r3:eSize, +//r4:parameter, r5: cache no usage, r6:postParameters, r7:bias + +push {r4-r8, r10, r11, lr} // avoid to touch platform-register r-9 +ldr r4, [sp, #32] +ldr r6, [sp, #36] +ldr r7, [sp, #40] +ldr r12, [r4, #0] +cmp r6, #0 +beq Start +vld1.32 {q3}, [r6] +vdup.f32 q12, d7[0] // min +vdup.f32 q13, d7[1] // max +Start: +cmp r3, #4 +blt L1 + +LoopE4: + ldr r5, [r4, #8] // h + add r5, r5, #3 + lsr r5, r5, #2 // r5 = UP_DIV(r5, 4) + mov lr, r0 + mov r11, r2 + push {r7} + LoopE4H: + mov r10, r1 + ldr r8, [r4, #4] // l + vmov.i32 q8, #0 + vmov.i32 q9, #0 + vmov.i32 q10, #0 + vmov.i32 q11, #0 + LoopE4L: + vld1.16 {d0}, [r10], r12 + vld1.16 {d2}, [r11]! // load 4 * sizeof(int16_t) + vshll.s16 q0, d0, #16 // shift left long of each int16_t as float32 + vshll.s16 q1, d2, #16 + vmla.f32 q8, q1, d0[0] + vmla.f32 q9, q1, d0[1] + vmla.f32 q10, q1, d1[0] + vmla.f32 q11, q1, d1[1] + subs r8, r8, #1 + bne LoopE4L + cmp r6, #0 + beq StoreE4 + vld1.16 {d28}, [r7]! // load 4 * sizeof(int16_t) + vshll.s16 q14, d28, #16 // shift left long of each int16_t as float32 + vmla.f32 q8, q14, d6[1] + vmla.f32 q9, q14, d6[1] + vmla.f32 q10, q14, d6[1] + vmla.f32 q11, q14, d6[1] + + PostTreatE4: + vmax.f32 q8, q8, q12 + vmax.f32 q9, q9, q12 + vmax.f32 q10, q10, q12 + vmax.f32 q11, q11, q12 + + vmin.f32 q8, q8, q13 + vmin.f32 q9, q9, q13 + vmin.f32 q10, q10, q13 + vmin.f32 q11, q11, q13 + + StoreE4: + ldr r8, [r4, #20] + add r11, r11, r8 + ldr r8, [r4, #12] + + vshrn.i32 d16, q8, #16 // shift right 16bit of each float32 as int16_t + vshrn.i32 d17, q9, #16 + vshrn.i32 d18, q10, #16 + vshrn.i32 d19, q11, #16 + vst1.16 {d16, d17}, [lr]! + vst1.16 {d18, d19}, [lr], r8 + sub lr, lr, #16 + subs r5, r5, #1 // move 4 colum along lP dim. lP = l / 4 + bne LoopE4H + sub r3, r3, #4 // move 4 colum along e dim. + add r0, r0, #32 // move address of 4 * 4 * sizeof(int16_t) + add r1, r1, #8 // move address of 4 * sizeof(int16_t) in src tile block + cmp r3, #4 + pop {r7} + bge LoopE4 + +L1: +cmp r3, #0 +beq End +LoopE1: + ldr r5, [r4, #8] // h + add r5, r5, #3 + lsr r5, r5, #2 + mov lr, r0 + mov r11, r2 + push {r7} + LoopE1H: + mov r10, r1 + ldr r8, [r4, #4] // l + vmov.i32 q15, #0 + LoopE1L: + vld1.16 {d0[0]}, [r10], r12 + vld1.16 {d2}, [r11]! // load 4 * sizeof(int16_t) + vshll.s16 q0, d0, #16 // shift left long of each int16_t as float32 + vshll.s16 q1, d2, #16 + + vmla.f32 q15, q1, d0[0] + subs r8, r8, #1 + bne LoopE1L + cmp r6, #0 + beq StoreE1 + vld1.16 {d28}, [r7]! 
// load 4 * sizeof(int16_t) + vshll.s16 q14, d28, #16 // shift left long of each int16_t as float32 + vmla.f32 q15, q14, d6[1] + + PostTreatE1: + vmax.f32 q15, q15, q12 + vmin.f32 q15, q15, q13 + + StoreE1: + ldr r8, [r4, #20] + add r11, r11, r8 + ldr r8, [r4, #12] + + vshrn.i32 d30, q15, #16 // shift right 16bit of each float32 as int16_t + vst1.16 {d30}, [lr], r8 + subs r5, r5, #1 + bne LoopE1H + subs r3, r3, #1 + add r0, r0, #8 // move address of 4 * sizeof(int16_t) + add r1, r1, #2 // move address of 1 * sizeof(int16_t) + pop {r7} + bne LoopE1 +End: +pop {r4-r8, r10, r11, pc} + +#endif +#endif diff --git a/backupcode/cpubackend/arm/arm32/bf16/MNNPackedMatMul_BF16.S b/backupcode/cpubackend/arm/arm32/bf16/MNNPackedMatMul_BF16.S new file mode 100644 index 000000000..3b9ab3d48 --- /dev/null +++ b/backupcode/cpubackend/arm/arm32/bf16/MNNPackedMatMul_BF16.S @@ -0,0 +1,211 @@ +// +// NEON_MNNPackedMatMul_BF16.S +// MNN +// +// Created by MNN on 2021/02/24. +// Copyright © 2018-2021 Alibaba Group Holding Limited. +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +// 12 * 8 MatMul +asm_function NEON_MNNPackedMatMul_BF16 +// treate float pointer as int16_t* +//void NEON_MNNPackedMatMul_BF16(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias); +// Auto: r0: C, r1:A, r2:B, r3:parameter +// Load from sp: r5: postParameters, r6:bias + +push {r4-r8, r10, r11, lr} // avoid to touch platform-register r-9 +ldr r5, [sp, #32] +ldr r6, [sp, #36] + +ldr r4, [r3, #8] // h +ldr r7, [r3, #4] // l +add r4, r4, #3 +ldr r8, [r3, #12]//cStride +ldr r3, [r3, #20]//bExtraStride +lsr r4, r4, #2 + +sub r8, r8, #96 // after segment "Store", total line stride is CStride, all vst. offset is 12 * 4 * size_t(int16_t) = 96byte + +vpush {q4-q7} +// q0, q1, q2: src +// q3: weight +// q4 - q15: dst + +LoopH: + subs r12, r7, #1 + mov r11, r1 + vld1.16 {d6}, [r2]! + vld1.16 {d0, d1}, [r11]! // load 2 * 4 * sizeof(int16_t) + vshll.s16 q3, d6, #16 // shift left long of each int16_t as float32 + vshll.s16 q1, d1, #16 // !! caution: must shll d1 before d0 + vshll.s16 q0, d0, #16 + + vmul.f32 q4, q3, d0[0] + vmul.f32 q5, q3, d0[1] + vmul.f32 q6, q3, d1[0] + vld1.16 {d4}, [r11]! // load 4 * sizeof(int16_t) + vshll.s16 q2, d4, #16 + vmul.f32 q7, q3, d1[1] + + vmul.f32 q8, q3, d2[0] + vmul.f32 q9, q3, d2[1] + vmul.f32 q10, q3, d3[0] + vmul.f32 q11, q3, d3[1] + + vmul.f32 q12, q3, d4[0] + vmul.f32 q13, q3, d4[1] + vmul.f32 q14, q3, d5[0] + vmul.f32 q15, q3, d5[1] + beq LoopLEnd + LoopL: + vld1.16 {d6}, [r2]! + vld1.16 {d0, d1}, [r11]! // load 2 * 4 * sizeof(int16_t) + vshll.s16 q3, d6, #16 // shift left long of each int16_t as float32 + vshll.s16 q1, d1, #16 // !! caution: must shll d1 before d0 + vshll.s16 q0, d0, #16 + + vmla.f32 q4, q3, d0[0] + vmla.f32 q5, q3, d0[1] + vmla.f32 q6, q3, d1[0] + vld1.16 {d4}, [r11]! + vshll.s16 q2, d4, #16 + + vmla.f32 q7, q3, d1[1] + + vmla.f32 q8, q3, d2[0] + vmla.f32 q9, q3, d2[1] + vmla.f32 q10, q3, d3[0] + vmla.f32 q11, q3, d3[1] + + vmla.f32 q12, q3, d4[0] + vmla.f32 q13, q3, d4[1] + vmla.f32 q14, q3, d5[0] + vmla.f32 q15, q3, d5[1] + + subs r12, r12, #1 + bne LoopL + LoopLEnd: + cmp r5, #0 + beq Store + vld1.32 {q0}, [r5] // parameter remains float + cmp r6, #0 + beq LoadOrigin + vld1.16 {d6}, [r6]! 
// load 4 * sizeof(int16_t) + vshll.s16 q3, d6, #16 // shift left long of each int16_t as int32_t + vmla.f32 q4, q3, d0[1] + vmla.f32 q5, q3, d0[1] + vmla.f32 q6, q3, d0[1] + vmla.f32 q7, q3, d0[1] + vmla.f32 q8, q3, d0[1] + vmla.f32 q9, q3, d0[1] + vmla.f32 q10, q3, d0[1] + vmla.f32 q11, q3, d0[1] + vmla.f32 q12, q3, d0[1] + vmla.f32 q13, q3, d0[1] + vmla.f32 q14, q3, d0[1] + vmla.f32 q15, q3, d0[1] + + b PostTreat + + LoadOrigin: + mov r11, r0 + vld1.16 {d2, d3}, [r11]! // load 2 * 4 * sizeof(int16_t) + vshll.s16 q2, d3, #16 // shift left long of each int16_t as int32_t + vshll.s16 q1, d2, #16 + vmla.f32 q4, q1, d0[1] + vmla.f32 q5, q2, d0[1] + + vld1.16 {d2, d3}, [r11]! // load 2 * 4 * sizeof(int16_t) + vshll.s16 q2, d3, #16 // shift left long of each int16_t as int32_t + vshll.s16 q1, d2, #16 + vmla.f32 q6, q1, d0[1] + vmla.f32 q7, q2, d0[1] + + vld1.16 {d2, d3}, [r11]! // load 2 * 4 * sizeof(int16_t) + vshll.s16 q2, d3, #16 // shift left long of each int16_t as int32_t + vshll.s16 q1, d2, #16 + vmla.f32 q8, q1, d0[1] + vmla.f32 q9, q2, d0[1] + + vld1.16 {d2, d3}, [r11]! // load 2 * 4 * sizeof(int16_t) + vshll.s16 q2, d3, #16 // shift left long of each int16_t as int32_t + vshll.s16 q1, d2, #16 + vmla.f32 q10, q1, d0[1] + vmla.f32 q11, q2, d0[1] + + vld1.16 {d2, d3}, [r11]! // load 2 * 4 * sizeof(int16_t) + vshll.s16 q2, d3, #16 // shift left long of each int16_t as int32_t + vshll.s16 q1, d2, #16 + vmla.f32 q12, q1, d0[1] + vmla.f32 q13, q2, d0[1] + + vld1.16 {d2, d3}, [r11]! // load 2 * 4 * sizeof(int16_t) + vshll.s16 q2, d3, #16 // shift left long of each int16_t as int32_t + vshll.s16 q1, d2, #16 + vmla.f32 q14, q1, d0[1] + vmla.f32 q15, q2, d0[1] + + PostTreat: + vdup.f32 q2, d1[0] // min + vdup.f32 q1, d1[1] // max + + vmax.f32 q4, q4, q2 + vmax.f32 q5, q5, q2 + vmax.f32 q6, q6, q2 + vmax.f32 q7, q7, q2 + vmax.f32 q8, q8, q2 + vmax.f32 q9, q9, q2 + vmax.f32 q10, q10, q2 + vmax.f32 q11, q11, q2 + vmax.f32 q12, q12, q2 + vmax.f32 q13, q13, q2 + vmax.f32 q14, q14, q2 + vmax.f32 q15, q15, q2 + + vmin.f32 q4, q4, q1 + vmin.f32 q5, q5, q1 + vmin.f32 q6, q6, q1 + vmin.f32 q7, q7, q1 + vmin.f32 q8, q8, q1 + vmin.f32 q9, q9, q1 + vmin.f32 q10, q10, q1 + vmin.f32 q11, q11, q1 + vmin.f32 q12, q12, q1 + vmin.f32 q13, q13, q1 + vmin.f32 q14, q14, q1 + vmin.f32 q15, q15, q1 + + Store: + vshrn.i32 d8, q4, #16 // !!caution: these instructions has relying, eg: d10 must be written after reading q5. 
shift right 16bit of each float32 as int16_t + vshrn.i32 d9, q5, #16 + vshrn.i32 d10, q6, #16 + vshrn.i32 d11, q7, #16 + vshrn.i32 d12, q8, #16 + vshrn.i32 d13, q9, #16 + vshrn.i32 d14, q10, #16 + vshrn.i32 d15, q11, #16 + vshrn.i32 d16, q12, #16 + vshrn.i32 d17, q13, #16 + vshrn.i32 d18, q14, #16 + vshrn.i32 d19, q15, #16 + + vstm r0!, {d8, d9, d10, d11, d12, d13, d14, d15, d16, d17, d18, d19} + + add r0, r0, r8 + add r2, r2, r3 + + subs r4, r4, #1 + bne LoopH + +vpop {q4-q7} +pop {r4-r8, r10, r11, pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm32/bf16/MNNReluWithSlopeChannelBF16.S b/backupcode/cpubackend/arm/arm32/bf16/MNNReluWithSlopeChannelBF16.S similarity index 100% rename from source/backend/cpu/arm/arm32/bf16/MNNReluWithSlopeChannelBF16.S rename to backupcode/cpubackend/arm/arm32/bf16/MNNReluWithSlopeChannelBF16.S diff --git a/source/backend/cpu/arm/arm32/bf16/MNNUnPackC4_BF16.S b/backupcode/cpubackend/arm/arm32/bf16/MNNUnPackC4_BF16.S similarity index 100% rename from source/backend/cpu/arm/arm32/bf16/MNNUnPackC4_BF16.S rename to backupcode/cpubackend/arm/arm32/bf16/MNNUnPackC4_BF16.S diff --git a/backupcode/cpubackend/arm/arm64/bf16/ARMV86_MNNPackedMatMulRemain_BF16.S b/backupcode/cpubackend/arm/arm64/bf16/ARMV86_MNNPackedMatMulRemain_BF16.S new file mode 100644 index 000000000..2acfe6930 --- /dev/null +++ b/backupcode/cpubackend/arm/arm64/bf16/ARMV86_MNNPackedMatMulRemain_BF16.S @@ -0,0 +1,566 @@ +// +// ARMV86_MNNPackedMatMulRemain_BF16.S +// MNN +// +// Created by MNN on 2022/10/09. +// Copyright © 2018-2021 Alibaba Group Holding Limited +// + +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +.macro SET_ZERO d0, d1, d2, d3 + movi \d0\().4s, #0 + movi \d1\().4s, #0 + movi \d2\().4s, #0 + movi \d3\().4s, #0 +.endm + +.macro Float32ToBf16 d0, d1, d2, d3 + shrn \d0\().4h, \d0\().4s, #16 + shrn \d1\().4h, \d1\().4s, #16 + shrn \d2\().4h, \d2\().4s, #16 + shrn \d3\().4h, \d3\().4s, #16 +.endm + +.macro FOURFMAX s, d0, d1, d2, d3 + fmax \d0\().4s, \d0\().4s, \s\().4s + fmax \d1\().4s, \d1\().4s, \s\().4s + fmax \d2\().4s, \d2\().4s, \s\().4s + fmax \d3\().4s, \d3\().4s, \s\().4s +.endm + +.macro FOURFMIN s, d0, d1, d2, d3 + fmin \d0\().4s, \d0\().4s, \s\().4s + fmin \d1\().4s, \d1\().4s, \s\().4s + fmin \d2\().4s, \d2\().4s, \s\().4s + fmin \d3\().4s, \d3\().4s, \s\().4s +.endm + +.macro SET_BIAS s, d0, d1, d2 + mov \d0\().16b, \s\().16b + mov \d1\().16b, \s\().16b + mov \d2\().16b, \s\().16b +.endm + +// 12 * 8 * 4 MatMul +asm_function ARMV86_MNNPackedMatMulRemain_BF16 +//void ARMV86_MNNPackedMatMulRemain_BF16(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias); +//Auto x0: C, x1:A, x2:B, x3:eSize, x4:parameter, x5:postParameters, x6:bias +sub sp, sp, #64 +str x19, [sp, #0] +str x20, [sp, #8] +str x21, [sp, #16] +str x22, [sp, #24] +ldr x11, [x4, #0] // aStride +ldr x9, [x4, #8] // l +ldr x10, [x4, #16] // h +lsl x11, x11, #2 // aStride * 4 +mov x22, #64 // B_stride = LP * HP = 4 * 8 * sizeof(int16_t) + +ldr x7, [x4, #24] // cStride +ldr x19, [x4, #40] // bExtraStride + +add x10, x10, #3 +lsr x10, x10, #2 +add x9, x9, #3 +lsr x9, x9, #2 + +cbz x5, Start +ld1 {v5.4s}, [x5] +dup v9.4s, v5.s[2] // Min Value +dup v10.4s, v5.s[3] // Max Value + +Start: + +E8: +cmp x3, #8 +blt E4 + +LoopE8: // e, TILE_BLOCK size is 8 + mov x20, x6 // bias + mov x8, x10 // updiv(h, 4) + mov x21, x0 // dest, C + mov x13, x2 // weight, B + + LH8: + cmp x8, #2 // h/4 > 2 + blt LH4 + sub x14, x7, #64 
// cStride - 64 + LoopH8x8: + mov x15, x1 // src, A + mov x12, x9 // l + cbz x5, NoBiasLH8 + ld1 {v0.4h, v1.4h}, [x20], #16 // 8 * sizeof(int16_t) + shll v0.4s, v0.4h, #16 + shll v1.4s, v1.4h, #16 + mov v2.16b, v0.16b + mov v3.16b, v1.16b + uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 + uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3 + uzp1 v24.2d, v1.2d, v3.2d // bias_0, bias_1, bias_0, bias_1 + uzp2 v25.2d, v1.2d, v3.2d // bias_2, bias_3, bias_2, bias_3 + SET_BIAS v16, v18, v20, v22 + SET_BIAS v17, v19, v21, v23 + SET_BIAS v24, v26, v28, v30 + SET_BIAS v25, v27, v29, v31 + b LoopL + NoBiasLH8: + SET_ZERO v16, v17, v18, v19 + SET_ZERO v20, v21, v22, v23 + SET_ZERO v24, v25, v26, v27 + SET_ZERO v28, v29, v30, v31 + LoopL: + // A [8, 4, bf16] : rn = 4 : v4 - v7 + // B [8, 4, bf16] : rn = 4 : v0 - v3 + // C [8, 8, fp32] : rn = 16 : v16 - v31 + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x15], x11 // A: 8 * 4 * sizeof(int16_t) + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x13], x22 // B: 8 * 4 * sizeof(int16_t) + .inst 0x6e40ec90 // bfmmla v16.4s, v4.8h, v0.8h + .inst 0x6e41ec91 // bfmmla v17.4s, v4.8h, v1.8h + .inst 0x6e40ecb2 // bfmmla v18.4s, v5.8h, v0.8h + .inst 0x6e41ecb3 // bfmmla v19.4s, v5.8h, v1.8h + .inst 0x6e40ecd4 // bfmmla v20.4s, v6.8h, v0.8h + .inst 0x6e41ecd5 // bfmmla v21.4s, v6.8h, v1.8h + .inst 0x6e40ecf6 // bfmmla v22.4s, v7.8h, v0.8h + .inst 0x6e41ecf7 // bfmmla v23.4s, v7.8h, v1.8h + .inst 0x6e42ec98 // bfmmla v24.4s, v4.8h, v2.8h + .inst 0x6e43ec99 // bfmmla v25.4s, v4.8h, v3.8h + .inst 0x6e42ecba // bfmmla v26.4s, v5.8h, v2.8h + .inst 0x6e43ecbb // bfmmla v27.4s, v5.8h, v3.8h + .inst 0x6e42ecdc // bfmmla v28.4s, v6.8h, v2.8h + .inst 0x6e43ecdd // bfmmla v29.4s, v6.8h, v3.8h + .inst 0x6e42ecfe // bfmmla v30.4s, v7.8h, v2.8h + .inst 0x6e43ecff // bfmmla v31.4s, v7.8h, v3.8h + subs x12, x12, #1 + bgt LoopL + LoopLEnd: + uzp1 v15.2d, v16.2d, v17.2d + uzp2 v16.2d, v16.2d, v17.2d + uzp1 v17.2d, v18.2d, v19.2d + uzp2 v18.2d, v18.2d, v19.2d + uzp1 v19.2d, v20.2d, v21.2d + uzp2 v20.2d, v20.2d, v21.2d + uzp1 v21.2d, v22.2d, v23.2d + uzp2 v22.2d, v22.2d, v23.2d + uzp1 v23.2d, v24.2d, v25.2d + uzp2 v24.2d, v24.2d, v25.2d + uzp1 v25.2d, v26.2d, v27.2d + uzp2 v26.2d, v26.2d, v27.2d + uzp1 v27.2d, v28.2d, v29.2d + uzp2 v28.2d, v28.2d, v29.2d + uzp1 v29.2d, v30.2d, v31.2d + uzp2 v30.2d, v30.2d, v31.2d + cbz x5, StoreLH8 + PostTreatLH8: + FOURFMAX v9, v15, v16, v17, v18 + FOURFMAX v9, v19, v20, v21, v22 + FOURFMAX v9, v23, v24, v25, v26 + FOURFMAX v9, v27, v28, v29, v30 + FOURFMIN v10, v15, v16, v17, v18 + FOURFMIN v10, v19, v20, v21, v22 + FOURFMIN v10, v23, v24, v25, v26 + FOURFMIN v10, v27, v28, v29, v30 + StoreLH8: + Float32ToBf16 v15, v16, v17, v18 + Float32ToBf16 v19, v20, v21, v22 + Float32ToBf16 v23, v24, v25, v26 + Float32ToBf16 v27, v28, v29, v30 + st1 {v15.4h, v16.4h, v17.4h, v18.4h}, [x0], #32 // 16 * sizeof(int16_t) + st1 {v19.4h, v20.4h, v21.4h, v22.4h}, [x0], #32 // 16 * sizeof(int16_t) + add x0, x0, x14 + st1 {v23.4h, v24.4h, v25.4h, v26.4h}, [x0], #32 // 16 * sizeof(int16_t) + st1 {v27.4h, v28.4h, v29.4h, v30.4h}, [x0], #32 // 16 * sizeof(int16_t) + add x0, x0, x14 + add x13, x13, x19 // weight stride + sub x8, x8, #2 + cmp x8, #2 + bge LoopH8x8 + LH4: + cbz x8, E8End + LoopHRemain: + mov x15, x1 + mov x12, x9 + cbz x5, NoBiasHRemain + ld1 {v0.4h}, [x20] + shll v0.4s, v0.4h, #16 + mov v2.16b, v0.16b + uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 + uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3 + SET_BIAS v16, v18, v20, v22 + 
SET_BIAS v17, v19, v21, v23 + b LoopLR + NoBiasHRemain: + SET_ZERO v16, v17, v18, v19 + SET_ZERO v20, v21, v22, v23 + LoopLR: + // A [8, 4, bf16] : rn = 4 : v4 - v7 + // B [4, 4, bf16] : rn = 2 : v0 - v1 + // C [8, 4, fp32] : rn = 8 : v16 - v23 + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x15], x11 // A: 8 * 4 * sizeof(int16_t) + ld1 {v0.8h, v1.8h}, [x13], x22 // B: 4 * 4 * sizeof(int16_t) + .inst 0x6e40ec90 // bfmmla v16.4s, v4.8h, v0.8h + .inst 0x6e41ec91 // bfmmla v17.4s, v4.8h, v1.8h + .inst 0x6e40ecb2 // bfmmla v18.4s, v5.8h, v0.8h + .inst 0x6e41ecb3 // bfmmla v19.4s, v5.8h, v1.8h + .inst 0x6e40ecd4 // bfmmla v20.4s, v6.8h, v0.8h + .inst 0x6e41ecd5 // bfmmla v21.4s, v6.8h, v1.8h + .inst 0x6e40ecf6 // bfmmla v22.4s, v7.8h, v0.8h + .inst 0x6e41ecf7 // bfmmla v23.4s, v7.8h, v1.8h + subs x12, x12, #1 + bne LoopLR + LoopLREnd: + uzp1 v15.2d, v16.2d, v17.2d + uzp2 v16.2d, v16.2d, v17.2d + uzp1 v17.2d, v18.2d, v19.2d + uzp2 v18.2d, v18.2d, v19.2d + uzp1 v19.2d, v20.2d, v21.2d + uzp2 v20.2d, v20.2d, v21.2d + uzp1 v21.2d, v22.2d, v23.2d + uzp2 v22.2d, v22.2d, v23.2d + cbz x5, StoreLH8x4 + PostTreatLH8x4: + FOURFMAX v9, v15, v16, v17, v18 + FOURFMAX v9, v19, v20, v21, v22 + FOURFMIN v10, v15, v16, v17, v18 + FOURFMIN v10, v19, v20, v21, v22 + StoreLH8x4: + Float32ToBf16 v15, v16, v17, v18 + Float32ToBf16 v19, v20, v21, v22 + st1 {v15.4h, v16.4h, v17.4h, v18.4h}, [x0], #32 // 16 * sizeof(int16_t) + st1 {v19.4h, v20.4h, v21.4h, v22.4h}, [x0], #32 // 16 * sizeof(int16_t) + E8End: + sub x3, x3, #8 + cmp x3, #8 + add x0, x21, #64 // move dest address of 8 * 4 * sizeof(int16_t) + add x1, x1, #64 // move A matrix address of 8 * 4 * sizeof(int16_t) + bge LoopE8 + +E4: +cmp x3, #4 +mov x20, x6 +blt E2 + +mov x8, x10 +mov x21, x0 +mov x13, x2 + +cmp x8, #2 +blt E4LH4 +E4LH8: + E4LoopH8: + mov x15, x1 + mov x12, x9 + cbz x5, NoBiasE4 + ld1 {v0.4h, v1.4h}, [x20], #16 // 8 * sizeof(int16_t) + shll v0.4s, v0.4h, #16 + shll v1.4s, v1.4h, #16 + mov v2.16b, v0.16b + mov v3.16b, v1.16b + uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 + uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3 + uzp1 v20.2d, v1.2d, v3.2d // bias_0, bias_1, bias_0, bias_1 + uzp2 v21.2d, v1.2d, v3.2d // bias_2, bias_3, bias_2, bias_3 + mov v18.16b, v16.16b + mov v19.16b, v17.16b + mov v22.16b, v20.16b + mov v23.16b, v21.16b + b E4LoopL + NoBiasE4: + SET_ZERO v16, v17, v18, v19 + SET_ZERO v20, v21, v22, v23 + E4LoopL: + // A [4, 4, bf16] : rn = 4 : v4 - v5 + // B [8, 4, bf16] : rn = 4 : v0 - v3 + // C [4, 8, fp32] : rn = 8 : v16 - v23 + ld1 {v4.8h, v5.8h}, [x15], x11 // A: 4 * 4 * sizeof(int16_t) + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x13], x22 // B: 8 * 4 * sizeof(int16_t) + .inst 0x6e40ec90 // bfmmla v16.4s, v4.8h, v0.8h + .inst 0x6e41ec91 // bfmmla v17.4s, v4.8h, v1.8h + .inst 0x6e40ecb2 // bfmmla v18.4s, v5.8h, v0.8h + .inst 0x6e41ecb3 // bfmmla v19.4s, v5.8h, v1.8h + .inst 0x6e42ec94 // bfmmla v20.4s, v4.8h, v2.8h + .inst 0x6e43ec95 // bfmmla v21.4s, v4.8h, v3.8h + .inst 0x6e42ecb6 // bfmmla v22.4s, v5.8h, v2.8h + .inst 0x6e43ecb7 // bfmmla v23.4s, v5.8h, v3.8h + subs x12, x12, #1 + bgt E4LoopL + E4LoopLEnd: + uzp1 v15.2d, v16.2d, v17.2d + uzp2 v16.2d, v16.2d, v17.2d + uzp1 v17.2d, v18.2d, v19.2d + uzp2 v18.2d, v18.2d, v19.2d + uzp1 v19.2d, v20.2d, v21.2d + uzp2 v20.2d, v20.2d, v21.2d + uzp1 v21.2d, v22.2d, v23.2d + uzp2 v22.2d, v22.2d, v23.2d + cbz x5, StoreLH4x8 + PostTreatLH4x8: + FOURFMAX v9, v15, v16, v17, v18 + FOURFMAX v9, v19, v20, v21, v22 + FOURFMIN v10, v15, v16, v17, v18 + FOURFMIN v10, v19, v20, v21, v22 + 
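+        // results are clamped to [min, max] (v9/v10 from postParameters); StoreLH4x8 narrows them back to bf16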
StoreLH4x8: + Float32ToBf16 v15, v16, v17, v18 + Float32ToBf16 v19, v20, v21, v22 + st1 {v15.4h, v16.4h, v17.4h, v18.4h}, [x0], x7 // 16 * sizeof(int16_t) + st1 {v19.4h, v20.4h, v21.4h, v22.4h}, [x0], x7 // 16 * sizeof(int16_t) + add x13, x13, x19 // weight stride + sub x8, x8, #2 + cmp x8, #2 + bge E4LoopH8 + E4LH4: + cbz x8, E4End + mov x15, x1 + mov x12, x9 + cbz x5, NoBiasE4R + ld1 {v0.4h}, [x20] + shll v0.4s, v0.4h, #16 + mov v2.16b, v0.16b + uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 + uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3 + mov v18.16b, v16.16b + mov v19.16b, v17.16b + b E4LoopLR + NoBiasE4R: + SET_ZERO v16, v17, v18, v19 + E4LoopLR: + // A [4, 4, bf16] : rn = 4 : v4 - v5 + // B [4, 4, bf16] : rn = 4 : v0 - v1 + // C [4, 4, fp32] : rn = 4 : v16 - v19 + ld1 {v4.8h, v5.8h}, [x15], x11 // A: 4 * 4 * sizeof(int16_t) + ld1 {v0.8h, v1.8h}, [x13], x22 // B: 4 * 4 * sizeof(int16_t) + .inst 0x6e40ec90 // bfmmla v16.4s, v4.8h, v0.8h + .inst 0x6e41ec91 // bfmmla v17.4s, v4.8h, v1.8h + .inst 0x6e40ecb2 // bfmmla v18.4s, v5.8h, v0.8h + .inst 0x6e41ecb3 // bfmmla v19.4s, v5.8h, v1.8h + subs x12, x12, #1 + bgt E4LoopLR + E4LoopLREnd: + uzp1 v15.2d, v16.2d, v17.2d + uzp2 v16.2d, v16.2d, v17.2d + uzp1 v17.2d, v18.2d, v19.2d + uzp2 v18.2d, v18.2d, v19.2d + cbz x5, StoreLH4x4 + PostTreatLH4x4: + FOURFMAX v9, v15, v16, v17, v18 + FOURFMIN v10, v19, v20, v21, v22 + StoreLH4x4: + Float32ToBf16 v15, v16, v17, v18 + st1 {v15.4h, v16.4h, v17.4h, v18.4h}, [x0] // 16 * sizeof(int16_t) + E4End: + sub x3, x3, #4 + add x0, x21, #32 // move dest address of 4 * 4 * sizeof(int16_t) + add x1, x1, #32 // move dest address of 4 * 4 * sizeof(int16_t) + +E2: +cmp x3, #2 +mov x20, x6 +blt E1 + +mov x8, x10 +mov x21, x0 +mov x13, x2 + +cmp x8, #2 +blt E2LH4 +E2LH8: + E2LoopH8: + mov x15, x1 + mov x12, x9 + cbz x5, NoBiasE2 + ld1 {v0.4h, v1.4h}, [x20], #16 + shll v0.4s, v0.4h, #16 + shll v1.4s, v1.4h, #16 + mov v2.16b, v0.16b + mov v3.16b, v1.16b + uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 + uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3 + uzp1 v18.2d, v1.2d, v3.2d // bias_0, bias_1, bias_0, bias_1 + uzp2 v19.2d, v1.2d, v3.2d // bias_2, bias_3, bias_2, bias_3 + b E2LoopL + NoBiasE2: + SET_ZERO v16, v17, v18, v19 + E2LoopL: + // A [2, 4, bf16] : rn = 1 : v4 + // B [8, 4, bf16] : rn = 2 : v0 - v3 + // C [2, 8, fp32] : rn = 4 : v16 - v19 + ld1 {v4.8h}, [x15], x11 // A: 2 * 4 * sizeof(int16_t) + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x13], x22 // B: 8 * 4 * sizeof(int16_t) + .inst 0x6e40ec90 // bfmmla v16.4s, v4.8h, v0.8h + .inst 0x6e41ec91 // bfmmla v17.4s, v4.8h, v1.8h + .inst 0x6e42ec92 // bfmmla v18.4s, v4.8h, v2.8h + .inst 0x6e43ec93 // bfmmla v19.4s, v4.8h, v3.8h + subs x12, x12, #1 + bgt E2LoopL + E2LoopLEnd: + uzp1 v15.2d, v16.2d, v17.2d + uzp2 v16.2d, v16.2d, v17.2d + uzp1 v17.2d, v18.2d, v19.2d + uzp2 v18.2d, v18.2d, v19.2d + cbz x5, StoreLH2x8 + PostTreatLH2x8: + FOURFMAX v9, v15, v16, v17, v18 + FOURFMIN v10, v15, v16, v17, v18 + StoreLH2x8: + Float32ToBf16 v15, v16, v17, v18 + st1 {v15.4h, v16.4h}, [x0], x7 // 8 * sizeof(int16_t) + st1 {v17.4h, v18.4h}, [x0], x7 // 8 * sizeof(int16_t) + add x13, x13, x19 // weight stride + sub x8, x8, #2 + cmp x8, #2 + bge E2LoopH8 + E2LH4: + cbz x8, E2End + mov x15, x1 + mov x12, x9 + cbz x5, NoBiasE2R + ld1 {v0.4h}, [x20] + shll v0.4s, v0.4h, #16 + mov v2.16b, v0.16b + uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 + uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3 + b E2LoopLR + NoBiasE2R: 
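+        // no bias: start both accumulators from zero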
+ movi v16.4s, #0 + movi v17.4s, #0 + E2LoopLR: + // A [2, 4, bf16] : rn = 1 : v4 + // B [4, 4, bf16] : rn = 2 : v0 - v1 + // C [2, 4, fp32] : rn = 2 : v16 - v17 + ld1 {v4.8h}, [x15], x11 // A: 2 * 4 * sizeof(int16_t) + ld1 {v0.8h, v1.8h}, [x13], x22 // B: 4 * 4 * sizeof(int16_t) + .inst 0x6e40ec90 // bfmmla v16.4s, v4.8h, v0.8h + .inst 0x6e41ec91 // bfmmla v17.4s, v4.8h, v1.8h + subs x12, x12, #1 + bgt E2LoopLR + E2LoopLREnd: + uzp1 v15.2d, v16.2d, v17.2d + uzp2 v16.2d, v16.2d, v17.2d + cbz x5, StoreLH2x4 + PostTreatLH2x4: + fmax v15.4s, v15.4s, v9.4s + fmax v16.4s, v16.4s, v9.4s + fmin v15.4s, v15.4s, v10.4s + fmin v16.4s, v16.4s, v10.4s + StoreLH2x4: + shrn v15.4h, v15.4s, #16 + shrn v16.4h, v16.4s, #16 + st1 {v15.4h, v16.4h}, [x0] // 8 * sizeof(int16_t) + E2End: + sub x3, x3, #2 + add x0, x21, #16 // move dest address of 2 * 4 * sizeof(int16_t) + add x1, x1, #16 // move dest address of 2 * 4 * sizeof(int16_t) + +E1: +cmp x3, #0 +beq End + +LoopE1: + mov x20, x6 + mov x8, x10 + mov x21, x0 + mov x13, x2 + + cmp x8, #2 + blt E1LH4 + + E1LH8: + E1LoopH8: + mov x15, x1 + mov x12, x9 + cbz x5, NoBiasE1 + ld1 {v0.4h, v1.4h}, [x20], #16 + shll v0.4s, v0.4h, #16 + shll v1.4s, v1.4h, #16 + mov v2.16b, v0.16b + mov v3.16b, v1.16b + uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 + uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3 + uzp1 v18.2d, v1.2d, v3.2d // bias_0, bias_1, bias_0, bias_1 + uzp2 v19.2d, v1.2d, v3.2d // bias_2, bias_3, bias_2, bias_3 + b E1LoopL + NoBiasE1: + SET_ZERO v16, v17, v18, v19 + E1LoopL: + // A [1, 4, bf16] : rn = 1 : v4 + // B [8, 4, bf16] : rn = 4 : v0 - v3 + // C [1, 8, fp32] : rn = 4 : v16 - v19 + ld1 {v4.4h}, [x15], x11 // A: 1 * 4 * sizeof(int16_t) + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x13], x22 // B: 8 * 4 * sizeof(int16_t) + .inst 0x6e40ec90 // bfmmla v16.4s, v4.8h, v0.8h + .inst 0x6e41ec91 // bfmmla v17.4s, v4.8h, v1.8h + .inst 0x6e42ec92 // bfmmla v18.4s, v4.8h, v2.8h + .inst 0x6e43ec93 // bfmmla v19.4s, v4.8h, v3.8h + subs x12, x12, #1 + bgt E1LoopL + E1LoopLEnd: + // v16-v19: [r0, r1, 0, 0] + uzp1 v15.2d, v16.2d, v17.2d + uzp1 v16.2d, v18.2d, v19.2d + cbz x5, StoreLH1x8 + PostTreatLH1x8: + fmax v15.4s, v15.4s, v9.4s + fmax v16.4s, v16.4s, v9.4s + fmin v15.4s, v15.4s, v10.4s + fmin v16.4s, v16.4s, v10.4s + StoreLH1x8: + shrn v15.4h, v15.4s, #16 + shrn v16.4h, v16.4s, #16 + st1 {v15.4h}, [x0], x7 + st1 {v16.4h}, [x0], x7 + add x13, x13, x19 + sub x8, x8, #2 + cmp x8, #2 + bge E1LoopH8 + + E1LH4: + cbz x8, E1End + mov x15, x1 + mov x12, x9 + cbz x5, NoBiasE1R + ld1 {v0.4h}, [x20] + shll v0.4s, v0.4h, #16 + mov v2.16b, v0.16b + uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 + uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3 + b E1LoopLR + NoBiasE1R: + movi v16.4s, #0 + movi v17.4s, #0 + E1LoopLR: + // A [1, 4, bf16] : rn = 1 : v4 + // B [4, 4, bf16] : rn = 2 : v0 - v1 + // C [1, 8, fp32] : rn = 4 : v16 - v17 + ld1 {v4.4h}, [x15], x11 // A: 1 * 4 * sizeof(int16_t) + ld1 {v0.8h, v1.8h}, [x13], x22 // B: 4 * 4 * sizeof(int16_t) + .inst 0x6e40ec90 // bfmmla v16.4s, v4.8h, v0.8h + .inst 0x6e41ec91 // bfmmla v17.4s, v4.8h, v1.8h + subs x12, x12, #1 + bgt E1LoopLR + E1LoopLREnd: + uzp1 v15.2d, v16.2d, v17.2d + cbz x5, StoreLH1x4 + PostTreatLH1x4: + fmax v15.4s, v15.4s, v9.4s + fmin v15.4s, v15.4s, v10.4s + StoreLH1x4: + shrn v15.4h, v15.4s, #16 + st1 {v15.4h}, [x0] + E1End: + subs x3, x3, #1 + add x0, x21, #8 + add x1, x1, #8 + bne LoopE1 +End: +ldr x19, [sp, #0] +ldr x20, [sp, #8] +ldr x21, [sp, #16] +ldr x22, [sp, #24] 
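+// callee-saved x19-x22 restored; release the 64-byte spill area reserved in the prologue and return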
+add sp, sp, #64 + +ret +#endif diff --git a/backupcode/cpubackend/arm/arm64/bf16/ARMV86_MNNPackedMatMul_BF16.S b/backupcode/cpubackend/arm/arm64/bf16/ARMV86_MNNPackedMatMul_BF16.S new file mode 100644 index 000000000..7d3282969 --- /dev/null +++ b/backupcode/cpubackend/arm/arm64/bf16/ARMV86_MNNPackedMatMul_BF16.S @@ -0,0 +1,286 @@ +// +// ARMV86_MNNPackedMatMul_BF16.S +// MNN +// +// Created by MNN on 2022/10/09. +// Copyright © 2018-2021 Alibaba Group Holding Limited +// +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +.macro SET_ZERO d0, d1, d2, d3 + movi \d0\().4s, #0 + movi \d1\().4s, #0 + movi \d2\().4s, #0 + movi \d3\().4s, #0 +.endm + +.macro Float32ToBf16 d0, d1, d2, d3 + shrn \d0\().4h, \d0\().4s, #16 + shrn \d1\().4h, \d1\().4s, #16 + shrn \d2\().4h, \d2\().4s, #16 + shrn \d3\().4h, \d3\().4s, #16 +.endm + +.macro FOURFMAX s, d0, d1, d2, d3 + fmax \d0\().4s, \d0\().4s, \s\().4s + fmax \d1\().4s, \d1\().4s, \s\().4s + fmax \d2\().4s, \d2\().4s, \s\().4s + fmax \d3\().4s, \d3\().4s, \s\().4s +.endm + +.macro FOURFMIN s, d0, d1, d2, d3 + fmin \d0\().4s, \d0\().4s, \s\().4s + fmin \d1\().4s, \d1\().4s, \s\().4s + fmin \d2\().4s, \d2\().4s, \s\().4s + fmin \d3\().4s, \d3\().4s, \s\().4s +.endm + +.macro SET_BIAS s, d0, d1, d2, d3 + mov \d0\().16b, \s\().16b + mov \d1\().16b, \s\().16b + mov \d2\().16b, \s\().16b + mov \d3\().16b, \s\().16b +.endm + +// 12 * 8 * 4 MatMul +asm_function ARMV86_MNNPackedMatMul_BF16 +//void ARMV86_MNNPackedMatMul_BF16(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias); +// x0: C, x1:A, x2:B, x3:parameter, x4: postParameters, x5:bias +stp d14, d15, [sp, #-80]! +stp d12, d13, [sp, #16] +stp d10, d11, [sp, #32] +stp d8, d9, [sp, #48] +stp x19, x21, [sp, #64] + +//ldr x8, [x3, #0] // deprecated +ldr x9, [x3, #8] // l +ldr x10, [x3, #16] // h +mov x11, #64 // B_stride = LP * HP = 4 * 8 * sizeof(int16_t) + +ldr x13, [x3, #24] // cStride +ldr x7, [x3, #40] // bExtraStride + +add x10, x10, #3 +lsr x10, x10, #2 +add x9, x9, #3 +lsr x9, x9, #2 + +cbz x4, Start +ld1 {v5.4s}, [x4] +mov w19, v5.s[2] // min value +mov w20, v5.s[3] // max value + +Start: + cmp x10, #2 + blt LH4 +LH8: + sub x14, x13, #96 // cStride - 96 +LoopH: + mov x15, x1 + mov x12, x9 + cbz x5, NoBiasH8 + ld1 {v0.4h, v1.4h}, [x5], #16 // 8 * sizeof(int16_t) + shll v0.4s, v0.4h, #16 + shll v1.4s, v1.4h, #16 + mov v2.16b, v0.16b + mov v3.16b, v1.16b + uzp1 v18.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 + uzp2 v19.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3 + uzp1 v30.2d, v1.2d, v3.2d // bias_0, bias_1, bias_0, bias_1 + uzp2 v31.2d, v1.2d, v3.2d // bias_2, bias_3, bias_2, bias_3 + SET_BIAS v18, v8, v10, v12, v14 + mov v16.16b, v18.16b + SET_BIAS v19, v9, v11, v13, v15 + mov v17.16b, v19.16b + SET_BIAS v30, v20, v22, v24, v26 + mov v28.16b, v30.16b + SET_BIAS v31, v21, v23, v25, v27 + mov v29.16b, v31.16b + b LoopL + NoBiasH8: + SET_ZERO v8, v9, v10, v11 + SET_ZERO v12, v13, v14, v15 + SET_ZERO v16, v17, v18, v19 + SET_ZERO v20, v21, v22, v23 + SET_ZERO v24, v25, v26, v27 + SET_ZERO v28, v29, v30, v31 + LoopL: + // A [12, 4, bf16] : rn = 6 : v2 - v7 + // B [ 8, 4, bf16] : rn = 2 : v0 - v1 + // C [12, 8, fp32] : rn = 24 : v8 - v31 + ld1 {v2.8h, v3.8h, v4.8h, v5.8h}, [x15], #64 // A: 8 * 4 * sizeof(int16_t) + ld1 {v6.8h, v7.8h}, [x15], #32 // A: 4 * 4 * sizeof(int16_t) + ld1 {v0.8h, v1.8h}, [x2], #32 // B: 4 * 4 * sizeof(int16_t) + .inst 0x6e40ec48 // bfmmla v8.4s, v2.8h, v0.8h + .inst 0x6e41ec49 // bfmmla 
v9.4s, v2.8h, v1.8h + .inst 0x6e40ec6a // bfmmla v10.4s, v3.8h, v0.8h + .inst 0x6e41ec6b // bfmmla v11.4s, v3.8h, v1.8h + .inst 0x6e40ec8c // bfmmla v12.4s, v4.8h, v0.8h + .inst 0x6e41ec8d // bfmmla v13.4s, v4.8h, v1.8h + .inst 0x6e40ecae // bfmmla v14.4s, v5.8h, v0.8h + .inst 0x6e41ecaf // bfmmla v15.4s, v5.8h, v1.8h + .inst 0x6e40ecd0 // bfmmla v16.4s, v6.8h, v0.8h + .inst 0x6e41ecd1 // bfmmla v17.4s, v6.8h, v1.8h + .inst 0x6e40ecf2 // bfmmla v18.4s, v7.8h, v0.8h + .inst 0x6e41ecf3 // bfmmla v19.4s, v7.8h, v1.8h + ld1 {v0.8h, v1.8h}, [x2], #32 // B: 4 * 4 * sizeof(int16_t) + .inst 0x6e40ec54 // bfmmla v20.4s, v2.8h, v0.8h + .inst 0x6e41ec55 // bfmmla v21.4s, v2.8h, v1.8h + .inst 0x6e40ec76 // bfmmla v22.4s, v3.8h, v0.8h + .inst 0x6e41ec77 // bfmmla v23.4s, v3.8h, v1.8h + .inst 0x6e40ec98 // bfmmla v24.4s, v4.8h, v0.8h + .inst 0x6e41ec99 // bfmmla v25.4s, v4.8h, v1.8h + .inst 0x6e40ecba // bfmmla v26.4s, v5.8h, v0.8h + .inst 0x6e41ecbb // bfmmla v27.4s, v5.8h, v1.8h + .inst 0x6e40ecdc // bfmmla v28.4s, v6.8h, v0.8h + .inst 0x6e41ecdd // bfmmla v29.4s, v6.8h, v1.8h + .inst 0x6e40ecfe // bfmmla v30.4s, v7.8h, v0.8h + .inst 0x6e41ecff // bfmmla v31.4s, v7.8h, v1.8h + subs x12, x12, #1 + bgt LoopL + LoopLEnd: + uzp1 v7.2d, v8.2d, v9.2d + uzp2 v8.2d, v8.2d, v9.2d + uzp1 v9.2d, v10.2d, v11.2d + uzp2 v10.2d, v10.2d, v11.2d + uzp1 v11.2d, v12.2d, v13.2d + uzp2 v12.2d, v12.2d, v13.2d + uzp1 v13.2d, v14.2d, v15.2d + uzp2 v14.2d, v14.2d, v15.2d + uzp1 v15.2d, v16.2d, v17.2d + uzp2 v16.2d, v16.2d, v17.2d + uzp1 v17.2d, v18.2d, v19.2d + uzp2 v18.2d, v18.2d, v19.2d + uzp1 v19.2d, v20.2d, v21.2d + uzp2 v20.2d, v20.2d, v21.2d + uzp1 v21.2d, v22.2d, v23.2d + uzp2 v22.2d, v22.2d, v23.2d + uzp1 v23.2d, v24.2d, v25.2d + uzp2 v24.2d, v24.2d, v25.2d + uzp1 v25.2d, v26.2d, v27.2d + uzp2 v26.2d, v26.2d, v27.2d + uzp1 v27.2d, v28.2d, v29.2d + uzp2 v28.2d, v28.2d, v29.2d + uzp1 v29.2d, v30.2d, v31.2d + uzp2 v30.2d, v30.2d, v31.2d + cbz x4, StoreLH8 + PostTreatLH8: + dup v5.4s, w19 + dup v6.4s, w20 + FOURFMAX v5, v7, v8, v9, v10 + FOURFMAX v5, v11, v12, v13, v14 + FOURFMAX v5, v15, v16, v17, v18 + FOURFMAX v5, v19, v20, v21, v22 + FOURFMAX v5, v23, v24, v25, v26 + FOURFMAX v5, v27, v28, v29, v30 + FOURFMIN v6, v7, v8, v9, v10 + FOURFMIN v6, v11, v12, v13, v14 + FOURFMIN v6, v15, v16, v17, v18 + FOURFMIN v6, v19, v20, v21, v22 + FOURFMIN v6, v23, v24, v25, v26 + FOURFMIN v6, v27, v28, v29, v30 + StoreLH8: + Float32ToBf16 v7, v8, v9, v10 + Float32ToBf16 v11, v12, v13, v14 + Float32ToBf16 v15, v16, v17, v18 + Float32ToBf16 v19, v20, v21, v22 + Float32ToBf16 v23, v24, v25, v26 + Float32ToBf16 v27, v28, v29, v30 + st1 {v7.4h, v8.4h, v9.4h, v10.4h}, [x0], #32 // 16 * sizeof(int16_t) + st1 {v11.4h, v12.4h, v13.4h, v14.4h}, [x0], #32 // 16 * sizeof(int16_t) + st1 {v15.4h, v16.4h, v17.4h, v18.4h}, [x0], #32 // 16 * sizeof(int16_t) + add x0, x0, x14 + st1 {v19.4h, v20.4h, v21.4h, v22.4h}, [x0], #32 // 16 * sizeof(int16_t) + st1 {v23.4h, v24.4h, v25.4h, v26.4h}, [x0], #32 // 16 * sizeof(int16_t) + st1 {v27.4h, v28.4h, v29.4h, v30.4h}, [x0], #32 // 16 * sizeof(int16_t) + add x0, x0, x14 + add x2, x2, x7 // weight stride + sub x10, x10, #2 + cmp x10, #2 + bge LoopH +LH4: +cbz x10, End +LoopHR: + mov x15, x1 + mov x12, x9 + cbz x5, NoBiasH4 + ld1 {v0.4h}, [x5], #8 // 8 * sizeof(int16_t) + shll v0.4s, v0.4h, #16 + mov v2.16b, v0.16b + uzp1 v18.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 + uzp2 v19.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3 + SET_BIAS v18, v8, v10, v12, v14 + mov v16.16b, v18.16b + SET_BIAS v19, v9, 
v11, v13, v15 + mov v17.16b, v19.16b + b LoopLR + NoBiasH4: + SET_ZERO v8, v9, v10, v11 + SET_ZERO v12, v13, v14, v15 + SET_ZERO v16, v17, v18, v19 + LoopLR: + // A [12, 4, bf16] : rn = 6 : v2 - v7 + // B [ 4, 4, bf16] : rn = 2 : v0 - v1 + // C [12, 4, fp32] : rn = 12 : v8 - v19 + ld1 {v2.8h, v3.8h, v4.8h, v5.8h}, [x15], #64 // A: 8 * 4 * sizeof(int16_t) + ld1 {v6.8h, v7.8h}, [x15], #32 // A: 4 * 4 * sizeof(int16_t) + ld1 {v0.8h, v1.8h}, [x2], x11 // B: 4 * 4 * sizeof(int16_t) + .inst 0x6e40ec48 // bfmmla v8.4s, v2.8h, v0.8h + .inst 0x6e41ec49 // bfmmla v9.4s, v2.8h, v1.8h + .inst 0x6e40ec6a // bfmmla v10.4s, v3.8h, v0.8h + .inst 0x6e41ec6b // bfmmla v11.4s, v3.8h, v1.8h + .inst 0x6e40ec8c // bfmmla v12.4s, v4.8h, v0.8h + .inst 0x6e41ec8d // bfmmla v13.4s, v4.8h, v1.8h + .inst 0x6e40ecae // bfmmla v14.4s, v5.8h, v0.8h + .inst 0x6e41ecaf // bfmmla v15.4s, v5.8h, v1.8h + .inst 0x6e40ecd0 // bfmmla v16.4s, v6.8h, v0.8h + .inst 0x6e41ecd1 // bfmmla v17.4s, v6.8h, v1.8h + .inst 0x6e40ecf2 // bfmmla v18.4s, v7.8h, v0.8h + .inst 0x6e41ecf3 // bfmmla v19.4s, v7.8h, v1.8h + subs x12, x12, #1 + bgt LoopLR + LoopLREnd: + add x2, x2, x7 // weight stride + uzp1 v7.2d, v8.2d, v9.2d + uzp2 v8.2d, v8.2d, v9.2d + uzp1 v9.2d, v10.2d, v11.2d + uzp2 v10.2d, v10.2d, v11.2d + uzp1 v11.2d, v12.2d, v13.2d + uzp2 v12.2d, v12.2d, v13.2d + uzp1 v13.2d, v14.2d, v15.2d + uzp2 v14.2d, v14.2d, v15.2d + uzp1 v15.2d, v16.2d, v17.2d + uzp2 v16.2d, v16.2d, v17.2d + uzp1 v17.2d, v18.2d, v19.2d + uzp2 v18.2d, v18.2d, v19.2d + cbz x4, StoreLH4 + PostTreatLH4: + dup v5.4s, w19 + dup v6.4s, w20 + FOURFMAX v5, v7, v8, v9, v10 + FOURFMAX v5, v11, v12, v13, v14 + FOURFMAX v5, v15, v16, v17, v18 + FOURFMIN v6, v7, v8, v9, v10 + FOURFMIN v6, v11, v12, v13, v14 + FOURFMIN v6, v15, v16, v17, v18 + StoreLH4: + Float32ToBf16 v7, v8, v9, v10 + Float32ToBf16 v11, v12, v13, v14 + Float32ToBf16 v15, v16, v17, v18 + st1 {v7.4h, v8.4h, v9.4h, v10.4h}, [x0], #32 // 16 * sizeof(int16_t) + st1 {v11.4h, v12.4h, v13.4h, v14.4h}, [x0], #32 // 16 * sizeof(int16_t) + st1 {v15.4h, v16.4h, v17.4h, v18.4h}, [x0], #32 // 16 * sizeof(int16_t) +End: +ldp x19, x21, [sp, #64] +ldp d8, d9, [sp, #48] +ldp d10, d11, [sp, #32] +ldp d12, d13, [sp, #16] +ldp d14, d15, [sp], #80 +ret + +#endif diff --git a/source/backend/cpu/arm/arm64/bf16/MNNAxByClampBroadcastC4_BF16.S b/backupcode/cpubackend/arm/arm64/bf16/MNNAxByClampBroadcastC4_BF16.S similarity index 100% rename from source/backend/cpu/arm/arm64/bf16/MNNAxByClampBroadcastC4_BF16.S rename to backupcode/cpubackend/arm/arm64/bf16/MNNAxByClampBroadcastC4_BF16.S diff --git a/source/backend/cpu/arm/arm64/bf16/MNNConvRunForLineDepthwise_BF16.S b/backupcode/cpubackend/arm/arm64/bf16/MNNConvRunForLineDepthwise_BF16.S similarity index 100% rename from source/backend/cpu/arm/arm64/bf16/MNNConvRunForLineDepthwise_BF16.S rename to backupcode/cpubackend/arm/arm64/bf16/MNNConvRunForLineDepthwise_BF16.S diff --git a/source/backend/cpu/arm/arm64/bf16/MNNConvRunForUnitDepthWise_BF16.S b/backupcode/cpubackend/arm/arm64/bf16/MNNConvRunForUnitDepthWise_BF16.S similarity index 100% rename from source/backend/cpu/arm/arm64/bf16/MNNConvRunForUnitDepthWise_BF16.S rename to backupcode/cpubackend/arm/arm64/bf16/MNNConvRunForUnitDepthWise_BF16.S diff --git a/source/backend/cpu/arm/arm64/bf16/MNNGelu_BF16.S b/backupcode/cpubackend/arm/arm64/bf16/MNNGelu_BF16.S similarity index 100% rename from source/backend/cpu/arm/arm64/bf16/MNNGelu_BF16.S rename to backupcode/cpubackend/arm/arm64/bf16/MNNGelu_BF16.S diff --git 
a/backupcode/cpubackend/arm/arm64/bf16/MNNPackC4ForMatMul_A_BF16.S b/backupcode/cpubackend/arm/arm64/bf16/MNNPackC4ForMatMul_A_BF16.S new file mode 100644 index 000000000..faa7d31a1 --- /dev/null +++ b/backupcode/cpubackend/arm/arm64/bf16/MNNPackC4ForMatMul_A_BF16.S @@ -0,0 +1,260 @@ + +// +// NEON_MNNPackC4ForMatMul_A_BF16.S +// MNN +// +// Created by MNN on 2021/02/26. +// Copyright © 2018-2021 Alibaba Group Holding Limited +// +#ifdef __aarch64__ +#include "MNNAsmGlobal.h" + +.macro transpose_4x4 x0, x1, x2, x3, x5, x6 // transpose 4x4 of sizeof(int16_t), only low half simd vector is valid. + trn1 \x5\().4h, \x0\().4h, \x1\().4h + trn2 \x1\().4h, \x0\().4h, \x1\().4h + trn1 \x6\().4h, \x2\().4h, \x3\().4h + trn2 \x3\().4h, \x2\().4h, \x3\().4h + trn1 \x0\().2s, \x5\().2s, \x6\().2s + trn2 \x2\().2s, \x5\().2s, \x6\().2s + trn1 \x6\().2s, \x1\().2s, \x3\().2s + trn2 \x3\().2s, \x1\().2s, \x3\().2s + mov \x1\().8b, \x6\().8b +.endm + +.text +.align 5 +asm_function NEON_MNNPackC4ForMatMul_A_BF16 +// treate float pointer as int16_t* +//void NEON_MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) +//Auto: x0: dest, x1:sourceGroup, x2: info, x3:el +ldr w10, [x2, #0] // number +mov x4, #0 +mov x11, #0 +mov x6, #0 +ldr w4, [x2, #4] // eReal +ldr w11, [x2, #8] // eDest +ldr w6, [x2, #12] // xOffset +// xOffset -> xOffset * 4 * sizeof(int16_t) +// eReal -> eReal * 4 * sizeof(int16_t) +// eDest -> eDest * sizeof(int16_t) +mov x12, #2 // sizeof(int16_t). kept as a const +mov x9, #8 +mul x4, x9, x4 +mul x11, x12, x11 +mul x6, x9, x6 + +LoopNumber: +mov x2, #0 +mov x5, #0 +mov x8, #0 +mov x7, #0 +ldr w5, [x3, #4] // l +ldr w8, [x3, #8] // eOffset +ldr w7, [x3, #12] // lOffset + +mov x13, x0 +mov x14, x1 +ldr x1, [x1, #0] + +// Compute dest ptr: x0 = x0 + eOffset * sizeof(int16_t) + lOffset * eDest * sizeof(int16_t) +mul x7, x11, x7 +mul x8, x12, x8 +add x0, x0, x7 +add x0, x0, x8 + +ldr w2, [x3, #0] // e + +Body: +cmp w2, #12 // original eDest +bne Right + cmp w5, #4 + blt LoopEL3 + LoopL4: + mov x2, x1 +.macro MAIN_TRANSPOSE + ld1 {v0.4h}, [x1], x6 // load size: 4 * sizeof(int16_t), jump one stride line as x6 + ld1 {v3.4h}, [x1], x6 + ld1 {v6.4h}, [x1], x6 + ld1 {v17.4h}, [x1], x6 + ld1 {v1.4h}, [x1], x6 + ld1 {v4.4h}, [x1], x6 + ld1 {v7.4h}, [x1], x6 + ld1 {v18.4h}, [x1], x6 + ld1 {v2.4h}, [x1], x6 + ld1 {v5.4h}, [x1], x6 + ld1 {v16.4h}, [x1], x6 + ld1 {v19.4h}, [x1], x6 + + transpose_4x4 v0, v3, v6, v17, v23, v24 + transpose_4x4 v1, v4, v7, v18, v25, v26 + transpose_4x4 v2, v5, v16, v19, v27, v28 +.endm + MAIN_TRANSPOSE + + stp d0, d1, [x0] // store size: 2 * 4 * sizeof(int16_t) + stp d2, d3, [x0, #(16 * 1)] + stp d4, d5, [x0, #(16 * 2)] + stp d6, d7, [x0, #(16 * 3)] + stp d16, d17, [x0, #(16 * 4)] + stp d18, d19, [x0, #(16 * 5)] + add x0, x0, #(16 * 6) + + // st1 {v0.4h}, [x0], #8 // store size: 4 * sizeof(int16_t) + // st1 {v1.4h}, [x0], #8 + // st1 {v2.4h}, [x0], #8 + // st1 {v3.4h}, [x0], #8 + // st1 {v4.4h}, [x0], #8 + // st1 {v5.4h}, [x0], #8 + // st1 {v6.4h}, [x0], #8 + // st1 {v7.4h}, [x0], #8 + // st1 {v16.4h}, [x0], #8 + // st1 {v17.4h}, [x0], #8 + // st1 {v18.4h}, [x0], #8 + // st1 {v19.4h}, [x0], #8 + + // st1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x0], #32 + // st1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x0], #32 + // st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0], #32 + + add x1, x2, x4 + sub x5, x5, #4 + cmp w5, #4 + bge LoopL4 + + LoopEL3: + cmp w5, #3 + blt LoopEL2 + MAIN_TRANSPOSE + + stp d0, d1, [x0] // store size: 2 * 4 * sizeof(int16_t) 
+ stp d2, d3, [x0, #(16 * 1)] + stp d4, d5, [x0, #(16 * 2)] + stp d6, d7, [x0, #(16 * 3)] + str d16, [x0, #(16 * 4)] + add x0, x0, #(16 * 4 + 8) + + // st1 {v0.4h}, [x0], #8 // store size: 4 * sizeof(int16_t) + // st1 {v1.4h}, [x0], #8 + // st1 {v2.4h}, [x0], #8 + // st1 {v3.4h}, [x0], #8 + // st1 {v4.4h}, [x0], #8 + // st1 {v5.4h}, [x0], #8 + // st1 {v6.4h}, [x0], #8 + // st1 {v7.4h}, [x0], #8 + // st1 {v16.4h}, [x0], #8 + + // st1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x0], #32 + // st1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x0], #32 + // st1 {v16.4h}, [x0], #8 + + b LoopEEnd + + LoopEL2: + cmp w5, #2 + blt LoopEL1 + MAIN_TRANSPOSE + stp d0, d1, [x0] // store size: 2 * 4 * sizeof(int16_t) + stp d2, d3, [x0, #(16 * 1)] + stp d4, d5, [x0, #(16 * 2)] + add x0, x0, #(16 * 3) + + // st1 {v0.4h}, [x0], #8 // store size: 4 * sizeof(int16_t) + // st1 {v1.4h}, [x0], #8 + // st1 {v2.4h}, [x0], #8 + // st1 {v3.4h}, [x0], #8 + // st1 {v4.4h}, [x0], #8 + // st1 {v5.4h}, [x0], #8 + + // st1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x0], #32 + // st1 {v4.4h, v5.4h}, [x0], #16 + + b LoopEEnd + + LoopEL1: + cmp w5, #1 + blt LoopEEnd + MAIN_TRANSPOSE + stp d0, d1, [x0] + str d2, [x0, #16] + add x0, x0, #(16 + 8) + + // st1 {v0.4h}, [x0], #8 // store size: 4 * sizeof(int16_t) + // st1 {v1.4h}, [x0], #8 + // st1 {v2.4h}, [x0], #8 + + // st1 {v0.4h, v1.4h, v2.4h}, [x0], #24 + + LoopEEnd: + +b End + + +Right: + +LoopE1: + mov w9, w5 + mov x7, x1 + mov x8, x0 + cmp w5, #4 + blt LoopE1L3 + LoopE1L4: + ld1 {v0.4h}, [x1], x4 + st1 {v0.h}[0], [x0], x11 + st1 {v0.h}[1], [x0], x11 + st1 {v0.h}[2], [x0], x11 + st1 {v0.h}[3], [x0], x11 + sub w5, w5, #4 + cmp w5, #4 + bge LoopE1L4 + + LoopE1L3: + cmp w5, #3 + blt LoopE1L2 + ld1 {v0.4h}, [x1], x4 + st1 {v0.h}[0], [x0], x11 + st1 {v0.h}[1], [x0], x11 + st1 {v0.h}[2], [x0], x11 + + sub w5, w5, #3 + + LoopE1L2: + cmp w5, #2 + blt LoopE1L1 + ld1 {v0.4h}, [x1], x4 + st1 {v0.h}[0], [x0], x11 + st1 {v0.h}[1], [x0], x11 + sub w5, w5, #2 + + LoopE1L1: + cmp w5, #1 + blt LoopE1End + ld1 {v0.h}[0], [x1], x4 + st1 {v0.h}[0], [x0], x11 + + LoopE1End: + + subs w2, w2, #1 + add x0, x8, x12 // !!!! caution : sizeof(int16_t) + add x1, x7, x6 + mov w5, w9 + bne LoopE1 + +End: + +mov x0, x13 +mov x1, x14 +subs w10, w10, #1 + +// x3 is (const int32_t* el), this array size of 4. as a result for next struct element, +// address added by 4 * sizeof(int32_t) +add x3, x3, #16 + +// x1 is (const int16_t** sourceGroup), even though data content is int16_t, +// the element in sourceGroup in 'int16_t*', as a result for next struct element, +// value added by sizeof(void*) +add x1, x1, #8 +bne LoopNumber + +ret + +#endif diff --git a/source/backend/cpu/arm/arm64/bf16/MNNPackC4_BF16.S b/backupcode/cpubackend/arm/arm64/bf16/MNNPackC4_BF16.S similarity index 100% rename from source/backend/cpu/arm/arm64/bf16/MNNPackC4_BF16.S rename to backupcode/cpubackend/arm/arm64/bf16/MNNPackC4_BF16.S diff --git a/backupcode/cpubackend/arm/arm64/bf16/MNNPackC8_BF16.S b/backupcode/cpubackend/arm/arm64/bf16/MNNPackC8_BF16.S new file mode 100644 index 000000000..87503e839 --- /dev/null +++ b/backupcode/cpubackend/arm/arm64/bf16/MNNPackC8_BF16.S @@ -0,0 +1,126 @@ +// +// MNNPackC8_BF16.S +// MNN +// +// Created by MNN on 2021/02/20. +// Copyright © 2018-2021 Alibaba Group Holding Limited. 
+// +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + + +.text +.align 5 +asm_function MNNPackC8_BF16 +// treate float pointer as int16_t* +//void MNNPackC8_BF16(float* dest, const float* source, size_t l, size_t h); +// h, l -> hC8, l, 8 +// Auto: x0:dest, x1:source, x2: l, x3: h +// x4: lC8, x5:hC8, x6: sourceStride, x7: destStride + +lsr x4, x2, #3 +lsr x5, x3, #3 +mov x12, #2 // sizeof(int16_t) +mov x13, #16 // 8 * sizeof(int16_t) +mul x6, x12, x2 +mul x7, x13, x2 +mov x12, #16 // 8 * sizeof(int16_t) +mul x15, x12, x2 + +.macro transpose_4x4 x0, x1, x2, x3, x5, x6 + trn1 \x5\().4s, \x0\().4s, \x1\().4s + trn2 \x1\().4s, \x0\().4s, \x1\().4s + trn1 \x6\().4s, \x2\().4s, \x3\().4s + trn2 \x3\().4s, \x2\().4s, \x3\().4s + trn1 \x0\().2d, \x5\().2d, \x6\().2d + trn2 \x2\().2d, \x5\().2d, \x6\().2d + trn1 \x6\().2d, \x1\().2d, \x3\().2d + trn2 \x3\().2d, \x1\().2d, \x3\().2d + mov \x1\().16b, \x6\().16b +.endm + +LoopH: +mov x8, x0 +mov x9, x1 +mov x12, x4 + +LoopL: +mov x10, x9 +ld1 {v16.4h, v17.4h}, [x9], x6 +ld1 {v18.4h, v19.4h}, [x9], x6 +ld1 {v20.4h, v21.4h}, [x9], x6 +ld1 {v22.4h, v23.4h}, [x9], x6 + +ld1 {v24.4h, v25.4h}, [x9], x6 +ld1 {v26.4h, v27.4h}, [x9], x6 +ld1 {v28.4h, v29.4h}, [x9], x6 +ld1 {v30.4h, v31.4h}, [x9], x6 + +shll v16.4s, v16.4h, #16 +shll v17.4s, v17.4h, #16 +shll v18.4s, v18.4h, #16 +shll v19.4s, v19.4h, #16 +shll v20.4s, v20.4h, #16 +shll v21.4s, v21.4h, #16 +shll v22.4s, v22.4h, #16 +shll v23.4s, v23.4h, #16 +shll v24.4s, v24.4h, #16 +shll v25.4s, v25.4h, #16 +shll v26.4s, v26.4h, #16 +shll v27.4s, v27.4h, #16 +shll v28.4s, v28.4h, #16 +shll v29.4s, v29.4h, #16 +shll v30.4s, v30.4h, #16 +shll v31.4s, v31.4h, #16 + + +transpose_4x4 v16, v18, v20, v22, v0, v1 +transpose_4x4 v17, v19, v21, v23, v2, v3 +transpose_4x4 v24, v26, v28, v30, v4, v5 +transpose_4x4 v25, v27, v29, v31, v6, v7 + + +shrn v16.4h, v16.4s, #16 +shrn v17.4h, v17.4s, #16 +shrn v18.4h, v18.4s, #16 +shrn v19.4h, v19.4s, #16 +shrn v20.4h, v20.4s, #16 +shrn v21.4h, v21.4s, #16 +shrn v22.4h, v22.4s, #16 +shrn v23.4h, v23.4s, #16 +shrn v24.4h, v24.4s, #16 +shrn v25.4h, v25.4s, #16 +shrn v26.4h, v26.4s, #16 +shrn v27.4h, v27.4s, #16 +shrn v28.4h, v28.4s, #16 +shrn v29.4h, v29.4s, #16 +shrn v30.4h, v30.4s, #16 +shrn v31.4h, v31.4s, #16 + + +stp d16, d24, [x8], #16 +stp d18, d26, [x8], #16 +stp d20, d28, [x8], #16 +stp d22, d30, [x8], #16 + +stp d17, d25, [x8], #16 +stp d19, d27, [x8], #16 +stp d21, d29, [x8], #16 +stp d23, d31, [x8], #16 + +add x9, x10, #16 // 8 * sizeof(int16_t) + +subs x12, x12, #1 +bne LoopL + + +subs x5, x5, #1 +add x0, x0, x7 +add x1, x1, x15 +bne LoopH + + +ret + +#endif diff --git a/backupcode/cpubackend/arm/arm64/bf16/MNNPackedMatMulRemain_BF16.S b/backupcode/cpubackend/arm/arm64/bf16/MNNPackedMatMulRemain_BF16.S new file mode 100644 index 000000000..a65140adc --- /dev/null +++ b/backupcode/cpubackend/arm/arm64/bf16/MNNPackedMatMulRemain_BF16.S @@ -0,0 +1,672 @@ +// +// MNNPackedMatMulRemain_BF16.S +// MNN +// +// Created by MNN on 2021/02/21. 
+// Copyright © 2018-2021 Alibaba Group Holding Limited +// + +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 +// 12 * 8 MatMul +asm_function NEON_MNNPackedMatMulRemain_BF16 +//void NEON_MNNPackedMatMulRemain_BF16(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias); +//Auto x0: C, x1:A, x2:B, x3:eSize, x4:parameter, x5:postParameters, x6:bias +sub sp, sp, #32 +str x19, [sp, #0] +str x20, [sp, #8] +str x21, [sp, #16] +ldr x11, [x4, #0] // aStride +ldr x9, [x4, #8] // l +ldr x10, [x4, #16] // h + +ldr x7, [x4, #24] // cStride +ldr x19, [x4, #40] // bExtraStride + +add x10, x10, #3 +lsr x10, x10, #2 + +cbz x5, Start +ld1 {v5.4s}, [x5] +dup v6.4s, v5.s[2] // Min Value +dup v7.4s, v5.s[3] // Max Value + +Start: + +E8: +cmp x3, #8 +blt E4 + +LoopE8: // e, TILE_BLOCK size is 8 + mov x20, x6 // bias + mov x8, x10 // updiv(h, 4) + mov x21, x0 // dest, C + mov x13, x2 // weight, B + + LH8: + cmp x8, #2 // h/4 > 2 + blt LH4 + // sub x14, x7, #32 // in "StoreLH8", total 2 lines stride is x14, first line is 4 * 4 * size_t(int16_t) = 32byte + LoopH8x8: + mov x15, x1 // src, A + subs x12, x9, #1 // l + ld1 {v3.4h, v4.4h}, [x13], #16 // 2 * 4 * sizeof(int16_t) + ld1 {v0.4h, v1.4h}, [x15], x11 + shll v3.4s, v3.4h, #16 + shll v4.4s, v4.4h, #16 + shll v0.4s, v0.4h, #16 + shll v1.4s, v1.4h, #16 + + fmul v16.4s, v3.4s, v0.s[0] + fmul v17.4s, v3.4s, v0.s[1] + fmul v18.4s, v3.4s, v0.s[2] + fmul v19.4s, v3.4s, v0.s[3] + + fmul v20.4s, v4.4s, v0.s[0] + fmul v21.4s, v4.4s, v0.s[1] + fmul v22.4s, v4.4s, v0.s[2] + fmul v23.4s, v4.4s, v0.s[3] + + fmul v24.4s, v3.4s, v1.s[0] + fmul v25.4s, v3.4s, v1.s[1] + fmul v26.4s, v3.4s, v1.s[2] + fmul v27.4s, v3.4s, v1.s[3] + + fmul v28.4s, v4.4s, v1.s[0] + fmul v29.4s, v4.4s, v1.s[1] + fmul v30.4s, v4.4s, v1.s[2] + fmul v31.4s, v4.4s, v1.s[3] + beq LoopLEnd + + LoopL: + ld1 {v3.4h, v4.4h}, [x13], #16 // 2 * 4 * sizeof(int16_t) + ld1 {v0.4h, v1.4h}, [x15], x11 + shll v3.4s, v3.4h, #16 + shll v4.4s, v4.4h, #16 + shll v0.4s, v0.4h, #16 + shll v1.4s, v1.4h, #16 + + fmla v16.4s, v3.4s, v0.s[0] + fmla v17.4s, v3.4s, v0.s[1] + fmla v18.4s, v3.4s, v0.s[2] + fmla v19.4s, v3.4s, v0.s[3] + + fmla v20.4s, v4.4s, v0.s[0] + fmla v21.4s, v4.4s, v0.s[1] + fmla v22.4s, v4.4s, v0.s[2] + fmla v23.4s, v4.4s, v0.s[3] + + fmla v24.4s, v3.4s, v1.s[0] + fmla v25.4s, v3.4s, v1.s[1] + fmla v26.4s, v3.4s, v1.s[2] + fmla v27.4s, v3.4s, v1.s[3] + + fmla v28.4s, v4.4s, v1.s[0] + fmla v29.4s, v4.4s, v1.s[1] + fmla v30.4s, v4.4s, v1.s[2] + fmla v31.4s, v4.4s, v1.s[3] + + subs x12, x12, #1 + bne LoopL + + LoopLEnd: + + add x13, x13, x19 + sub x8, x8, #2 + cmp x8, #2 + + cbz x5, StoreLH8 + AddBiasLH8: + ld1 {v0.4h, v1.4h}, [x20], #16 + shll v0.4s, v0.4h, #16 + shll v1.4s, v1.4h, #16 + + fmla v16.4s, v0.4s, v5.s[1] + fmla v17.4s, v0.4s, v5.s[1] + fmla v18.4s, v0.4s, v5.s[1] + fmla v19.4s, v0.4s, v5.s[1] + + fmla v20.4s, v1.4s, v5.s[1] + fmla v21.4s, v1.4s, v5.s[1] + fmla v22.4s, v1.4s, v5.s[1] + fmla v23.4s, v1.4s, v5.s[1] + + fmla v24.4s, v0.4s, v5.s[1] + fmla v25.4s, v0.4s, v5.s[1] + fmla v26.4s, v0.4s, v5.s[1] + fmla v27.4s, v0.4s, v5.s[1] + + fmla v28.4s, v1.4s, v5.s[1] + fmla v29.4s, v1.4s, v5.s[1] + fmla v30.4s, v1.4s, v5.s[1] + fmla v31.4s, v1.4s, v5.s[1] + + PostTreatLH8: + fmax v16.4s, v16.4s, v6.4s + fmax v17.4s, v17.4s, v6.4s + fmax v18.4s, v18.4s, v6.4s + fmax v19.4s, v19.4s, v6.4s + fmax v20.4s, v20.4s, v6.4s + fmax v21.4s, v21.4s, v6.4s + fmax v22.4s, v22.4s, v6.4s + fmax v23.4s, v23.4s, v6.4s + fmax 
v24.4s, v24.4s, v6.4s + fmax v25.4s, v25.4s, v6.4s + fmax v26.4s, v26.4s, v6.4s + fmax v27.4s, v27.4s, v6.4s + fmax v28.4s, v28.4s, v6.4s + fmax v29.4s, v29.4s, v6.4s + fmax v30.4s, v30.4s, v6.4s + fmax v31.4s, v31.4s, v6.4s + + fmin v16.4s, v16.4s, v7.4s + fmin v17.4s, v17.4s, v7.4s + fmin v18.4s, v18.4s, v7.4s + fmin v19.4s, v19.4s, v7.4s + fmin v20.4s, v20.4s, v7.4s + fmin v21.4s, v21.4s, v7.4s + fmin v22.4s, v22.4s, v7.4s + fmin v23.4s, v23.4s, v7.4s + fmin v24.4s, v24.4s, v7.4s + fmin v25.4s, v25.4s, v7.4s + fmin v26.4s, v26.4s, v7.4s + fmin v27.4s, v27.4s, v7.4s + fmin v28.4s, v28.4s, v7.4s + fmin v29.4s, v29.4s, v7.4s + fmin v30.4s, v30.4s, v7.4s + fmin v31.4s, v31.4s, v7.4s + + StoreLH8: + shrn v16.4h, v16.4s, #16 + shrn v17.4h, v17.4s, #16 + shrn v18.4h, v18.4s, #16 + shrn v19.4h, v19.4s, #16 + shrn v20.4h, v20.4s, #16 + shrn v21.4h, v21.4s, #16 + shrn v22.4h, v22.4s, #16 + shrn v23.4h, v23.4s, #16 + shrn v24.4h, v24.4s, #16 + shrn v25.4h, v25.4s, #16 + shrn v26.4h, v26.4s, #16 + shrn v27.4h, v27.4s, #16 + shrn v28.4h, v28.4s, #16 + shrn v29.4h, v29.4s, #16 + shrn v30.4h, v30.4s, #16 + shrn v31.4h, v31.4s, #16 + + stp d16, d17, [x0] + stp d18, d19, [x0, #(16 * 1)] + stp d24, d25, [x0, #(16 * 2)] + stp d26, d27, [x0, #(16 * 3)] + add x0, x0, x7 // stp donot support post-index offset in register + + stp d20, d21, [x0] + stp d22, d23, [x0, #(16 * 1)] + stp d28, d29, [x0, #(16 * 2)] + stp d30, d31, [x0, #(16 * 3)] + add x0, x0, x7 // stp donot support post-index offset in register + + // st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0], #32 // 4 * 4 * sizeof(int16_t) + // st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [x0], x14 + // st1 {v20.4h, v21.4h, v22.4h, v23.4h}, [x0], #32 + // st1 {v28.4h, v29.4h, v30.4h, v31.4h}, [x0], x14 + + bge LoopH8x8 + + LH4: + cbz x8, E8End + LoopHRemain: + mov x15, x1 + subs x12, x9, #1 + ld1 {v3.4h}, [x13] + ld1 {v0.4h}, [x15], #8 + shll v3.4s, v3.4h, #16 + shll v0.4s, v0.4h, #16 + + fmul v16.4s, v3.4s, v0.s[0] + fmul v17.4s, v3.4s, v0.s[1] + add x13, x13, #16 // weight + ld1 {v1.4h}, [x15] + shll v1.4s, v1.4h, #16 + + fmul v18.4s, v3.4s, v0.s[2] + sub x15, x15, #8 + fmul v19.4s, v3.4s, v0.s[3] + add x15, x15, x11 + fmul v20.4s, v3.4s, v1.s[0] + fmul v21.4s, v3.4s, v1.s[1] + fmul v22.4s, v3.4s, v1.s[2] + fmul v23.4s, v3.4s, v1.s[3] + beq LoopLREnd + + LoopLR: + ld1 {v3.4h}, [x13] + ld1 {v0.4h}, [x15], #8 + shll v3.4s, v3.4h, #16 + shll v0.4s, v0.4h, #16 + + fmla v16.4s, v3.4s, v0.s[0] + fmla v17.4s, v3.4s, v0.s[1] + add x13, x13, #16 // weight + ld1 {v1.4h}, [x15] + shll v1.4s, v1.4h, #16 + + fmla v18.4s, v3.4s, v0.s[2] + sub x15, x15, #8 + fmla v19.4s, v3.4s, v0.s[3] + add x15, x15, x11 + + fmla v20.4s, v3.4s, v1.s[0] + fmla v21.4s, v3.4s, v1.s[1] + fmla v22.4s, v3.4s, v1.s[2] + fmla v23.4s, v3.4s, v1.s[3] + + subs x12, x12, #1 + bne LoopLR + LoopLREnd: + + cbz x5, StoreLH8x4 + AddBiasLH8x4: + ld1 {v0.4h}, [x20] + shll v0.4s, v0.4h, #16 + + fmla v16.4s, v0.4s, v5.s[1] + fmla v17.4s, v0.4s, v5.s[1] + fmla v18.4s, v0.4s, v5.s[1] + fmla v19.4s, v0.4s, v5.s[1] + + fmla v20.4s, v0.4s, v5.s[1] + fmla v21.4s, v0.4s, v5.s[1] + fmla v22.4s, v0.4s, v5.s[1] + fmla v23.4s, v0.4s, v5.s[1] + + PostTreatLH8x4: + fmax v16.4s, v16.4s, v6.4s + fmax v17.4s, v17.4s, v6.4s + fmax v18.4s, v18.4s, v6.4s + fmax v19.4s, v19.4s, v6.4s + fmax v20.4s, v20.4s, v6.4s + fmax v21.4s, v21.4s, v6.4s + fmax v22.4s, v22.4s, v6.4s + fmax v23.4s, v23.4s, v6.4s + + fmin v16.4s, v16.4s, v7.4s + fmin v17.4s, v17.4s, v7.4s + fmin v18.4s, v18.4s, v7.4s + fmin v19.4s, v19.4s, v7.4s + fmin v20.4s, v20.4s, 
v7.4s + fmin v21.4s, v21.4s, v7.4s + fmin v22.4s, v22.4s, v7.4s + fmin v23.4s, v23.4s, v7.4s + + StoreLH8x4: + shrn v16.4h, v16.4s, #16 + shrn v17.4h, v17.4s, #16 + shrn v18.4h, v18.4s, #16 + shrn v19.4h, v19.4s, #16 + shrn v20.4h, v20.4s, #16 + shrn v21.4h, v21.4s, #16 + shrn v22.4h, v22.4s, #16 + shrn v23.4h, v23.4s, #16 + + stp d16, d17, [x0] + stp d18, d19, [x0, #(16 * 1)] + stp d20, d21, [x0, #(16 * 2)] + stp d22, d23, [x0, #(16 * 3)] + add x0, x0, #(16 * 4) + + // st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0], #32 + // st1 {v20.4h, v21.4h, v22.4h, v23.4h}, [x0], #32 + + E8End: + + sub x3, x3, #8 + cmp x3, #8 + add x0, x21, #64 // move dest address of 8 * 4 * sizeof(int16_t) + add x1, x1, #16 // move A matrix address of 8 * sizeof(int16_t) + bge LoopE8 + +E4: +cmp x3, #4 +mov x20, x6 +blt E1 + mov x8, x10 + mov x21, x0 + mov x13, x2 + + cmp x8, #2 + blt E4LH4 + + E4LH8: + E4LoopH8: + mov x15, x1 + subs x12, x9, #1 + ld1 {v3.4h, v4.4h}, [x13], #16 + ld1 {v0.4h}, [x15], x11 + shll v3.4s, v3.4h, #16 + shll v4.4s, v4.4h, #16 + shll v0.4s, v0.4h, #16 + + fmul v16.4s, v3.4s, v0.s[0] + fmul v17.4s, v3.4s, v0.s[1] + fmul v18.4s, v3.4s, v0.s[2] + fmul v19.4s, v3.4s, v0.s[3] + + fmul v20.4s, v4.4s, v0.s[0] + fmul v21.4s, v4.4s, v0.s[1] + fmul v22.4s, v4.4s, v0.s[2] + fmul v23.4s, v4.4s, v0.s[3] + + beq E4LoopLEnd + + subs x12, x12, #1 + ld1 {v3.4h, v4.4h}, [x13], #16 + ld1 {v0.4h}, [x15], x11 + shll v3.4s, v3.4h, #16 + shll v4.4s, v4.4h, #16 + shll v0.4s, v0.4h, #16 + + fmla v16.4s, v3.4s, v0.s[0] + fmla v17.4s, v3.4s, v0.s[1] + + beq E4LoopLComputeEnd + + E4LoopL: + fmla v18.4s, v3.4s, v0.s[2] + fmla v19.4s, v3.4s, v0.s[3] + + fmla v20.4s, v4.4s, v0.s[0] + fmla v21.4s, v4.4s, v0.s[1] + fmla v22.4s, v4.4s, v0.s[2] + fmla v23.4s, v4.4s, v0.s[3] + + ld1 {v3.4h, v4.4h}, [x13], #16 + ld1 {v0.4h}, [x15], x11 + shll v3.4s, v3.4h, #16 + shll v4.4s, v4.4h, #16 + shll v0.4s, v0.4h, #16 + + fmla v16.4s, v3.4s, v0.s[0] + fmla v17.4s, v3.4s, v0.s[1] + + subs x12, x12, #1 + bne E4LoopL + E4LoopLComputeEnd: + fmla v18.4s, v3.4s, v0.s[2] + fmla v19.4s, v3.4s, v0.s[3] + + fmla v20.4s, v4.4s, v0.s[0] + fmla v21.4s, v4.4s, v0.s[1] + fmla v22.4s, v4.4s, v0.s[2] + fmla v23.4s, v4.4s, v0.s[3] + + E4LoopLEnd: + add x13, x13, x19 + sub x8, x8, #2 + cmp x8, #2 + + cbz x5, StoreLH4x8 + + AddBiasLH4x8: + ld1 {v0.4h, v1.4h}, [x20], #16 + shll v0.4s, v0.4h, #16 + shll v1.4s, v1.4h, #16 + + fmla v16.4s, v0.4s, v5.s[1] + fmla v17.4s, v0.4s, v5.s[1] + fmla v18.4s, v0.4s, v5.s[1] + fmla v19.4s, v0.4s, v5.s[1] + + fmla v20.4s, v1.4s, v5.s[1] + fmla v21.4s, v1.4s, v5.s[1] + fmla v22.4s, v1.4s, v5.s[1] + fmla v23.4s, v1.4s, v5.s[1] + + PostTreatLH4x8: + fmax v16.4s, v16.4s, v6.4s + fmax v17.4s, v17.4s, v6.4s + fmax v18.4s, v18.4s, v6.4s + fmax v19.4s, v19.4s, v6.4s + fmax v20.4s, v20.4s, v6.4s + fmax v21.4s, v21.4s, v6.4s + fmax v22.4s, v22.4s, v6.4s + fmax v23.4s, v23.4s, v6.4s + + fmin v16.4s, v16.4s, v7.4s + fmin v17.4s, v17.4s, v7.4s + fmin v18.4s, v18.4s, v7.4s + fmin v19.4s, v19.4s, v7.4s + fmin v20.4s, v20.4s, v7.4s + fmin v21.4s, v21.4s, v7.4s + fmin v22.4s, v22.4s, v7.4s + fmin v23.4s, v23.4s, v7.4s + + StoreLH4x8: + shrn v16.4h, v16.4s, #16 + shrn v17.4h, v17.4s, #16 + shrn v18.4h, v18.4s, #16 + shrn v19.4h, v19.4s, #16 + shrn v20.4h, v20.4s, #16 + shrn v21.4h, v21.4s, #16 + shrn v22.4h, v22.4s, #16 + shrn v23.4h, v23.4s, #16 + + + stp d16, d17, [x0] + stp d18, d19, [x0, #16] + add x0, x0, x7 + stp d20, d21, [x0] + stp d22, d23, [x0, #16] + add x0, x0, x7 + + // st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0], x7 + // st1 
{v20.4h, v21.4h, v22.4h, v23.4h}, [x0], x7 + + bge E4LoopH8 + + E4LH4: + cbz x8, E4End + mov x15, x1 + subs x12, x9, #1 + ld1 {v3.4h}, [x13] + ld1 {v0.4h}, [x15], x11 + shll v3.4s, v3.4h, #16 + shll v0.4s, v0.4h, #16 + + fmul v16.4s, v3.4s, v0.s[0] + fmul v17.4s, v3.4s, v0.s[1] + fmul v18.4s, v3.4s, v0.s[2] + fmul v19.4s, v3.4s, v0.s[3] + add x13, x13, #16 // weight + + beq E4LoopLREnd + + E4LoopLR: + ld1 {v3.4h}, [x13] + ld1 {v0.4h}, [x15], x11 + shll v3.4s, v3.4h, #16 + shll v0.4s, v0.4h, #16 + + fmla v16.4s, v3.4s, v0.s[0] + fmla v17.4s, v3.4s, v0.s[1] + fmla v18.4s, v3.4s, v0.s[2] + fmla v19.4s, v3.4s, v0.s[3] + add x13, x13, #16 // weight + + subs x12, x12, #1 + bne E4LoopLR + E4LoopLREnd: + + cbz x5, StoreLH4x4 + AddBiasLH4x4: + ld1 {v0.4h}, [x20] + shll v0.4s, v0.4h, #16 + + fmla v16.4s, v0.4s, v5.s[1] + fmla v17.4s, v0.4s, v5.s[1] + fmla v18.4s, v0.4s, v5.s[1] + fmla v19.4s, v0.4s, v5.s[1] + + + PostTreatLH4x4: + fmax v16.4s, v16.4s, v6.4s + fmax v17.4s, v17.4s, v6.4s + fmax v18.4s, v18.4s, v6.4s + fmax v19.4s, v19.4s, v6.4s + + fmin v16.4s, v16.4s, v7.4s + fmin v17.4s, v17.4s, v7.4s + fmin v18.4s, v18.4s, v7.4s + fmin v19.4s, v19.4s, v7.4s + + StoreLH4x4: + + shrn v16.4h, v16.4s, #16 + shrn v17.4h, v17.4s, #16 + shrn v18.4h, v18.4s, #16 + shrn v19.4h, v19.4s, #16 + + stp d16, d17, [x0] + stp d18, d19, [x0, #16] + + // st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0] + + E4End: + + sub x3, x3, #4 + add x0, x21, #32 // move dest address of 4 * 4 * sizeof(int16_t) + add x1, x1, #8 // move dest address of 4 * sizeof(int16_t) + +E1: +cmp x3, #0 +beq End + +LoopE1: + mov x20, x6 + mov x8, x10 + mov x21, x0 + mov x13, x2 + + cmp x8, #2 + blt E1LH4 + + E1LH8: + E1LoopH8: + mov x15, x1 + subs x12, x9, #1 + ld1 {v3.4h, v4.4h}, [x13], #16 // + ld1 {v0.h}[0], [x15], x11 + shll v3.4s, v3.4h, #16 + shll v4.4s, v4.4h, #16 + shll v0.4s, v0.4h, #16 + + fmul v16.4s, v3.4s, v0.s[0] + fmul v20.4s, v4.4s, v0.s[0] + + beq E1LoopLEnd + + E1LoopL: + ld1 {v3.4h, v4.4h}, [x13], #16 // + ld1 {v0.h}[0], [x15], x11 + shll v3.4s, v3.4h, #16 + shll v4.4s, v4.4h, #16 + shll v0.4s, v0.4h, #16 + + fmla v16.4s, v3.4s, v0.s[0] + fmla v20.4s, v4.4s, v0.s[0] + + subs x12, x12, #1 + bne E1LoopL + + E1LoopLEnd: + + add x13, x13, x19 + sub x8, x8, #2 + cmp x8, #2 + + cbz x5, StoreLH1x8 + AddBiasLH1x8: + ld1 {v0.4h, v1.4h}, [x20], #16 + shll v1.4s, v1.4h, #16 + shll v0.4s, v0.4h, #16 + + fmla v16.4s, v0.4s, v5.s[1] + fmla v20.4s, v1.4s, v5.s[1] + + PostTreatLH1x8: + fmax v16.4s, v16.4s, v6.4s + fmax v20.4s, v20.4s, v6.4s + fmin v16.4s, v16.4s, v7.4s + fmin v20.4s, v20.4s, v7.4s + + StoreLH1x8: + shrn v16.4h, v16.4s, #16 + shrn v20.4h, v20.4s, #16 + st1 {v16.4h}, [x0], x7 + st1 {v20.4h}, [x0], x7 + + bge E1LoopH8 + + E1LH4: + cbz x8, E1End + mov x15, x1 + subs x12, x9, #1 + ld1 {v3.4h}, [x13] + ld1 {v0.h}[0], [x15], x11 + shll v3.4s, v3.4h, #16 + shll v0.4s, v0.4h, #16 + + fmul v16.4s, v3.4s, v0.s[0] + add x13, x13, #16 // weight + + beq E1LoopLREnd + + E1LoopLR: + ld1 {v3.4h}, [x13] + ld1 {v0.h}[0], [x15], x11 + shll v3.4s, v3.4h, #16 + shll v0.4s, v0.4h, #16 + + fmla v16.4s, v3.4s, v0.s[0] + add x13, x13, #16 // weight + + subs x12, x12, #1 + bne E1LoopLR + E1LoopLREnd: + + cbz x5, StoreLH1x4 + AddBiasLH1x4: + ld1 {v0.4h}, [x20] + shll v0.4s, v0.4h, #16 + + fmla v16.4s, v0.4s, v5.s[1] + + PostTreatLH1x4: + fmax v16.4s, v16.4s, v6.4s + fmin v16.4s, v16.4s, v7.4s + + StoreLH1x4: + shrn v16.4h, v16.4s, #16 + st1 {v16.4h}, [x0] + + E1End: + + subs x3, x3, #1 + add x0, x21, #8 + add x1, x1, #2 + bne LoopE1 + + +End: +ldr x19, [sp, 
#0] +ldr x20, [sp, #8] +ldr x21, [sp, #16] +add sp, sp, #32 + +ret + +#endif diff --git a/backupcode/cpubackend/arm/arm64/bf16/MNNPackedMatMul_BF16.S b/backupcode/cpubackend/arm/arm64/bf16/MNNPackedMatMul_BF16.S new file mode 100644 index 000000000..22c2c24ca --- /dev/null +++ b/backupcode/cpubackend/arm/arm64/bf16/MNNPackedMatMul_BF16.S @@ -0,0 +1,501 @@ +// +// MNNPackedMatMul_BF16.S +// MNN +// +// Created by MNN on 2021/02/21. +// Copyright © 2018-2021 Alibaba Group Holding Limited +// +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + + +.text +.align 5 +// 12 * 8 MatMul +asm_function NEON_MNNPackedMatMul_BF16 +//void NEON_MNNPackedMatMul_BF16(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias); +// x0: C, x1:A, x2:B, x3:parameter, x4: postParameters, x5:bias +stp d14, d15, [sp, #-64]! +stp d12, d13, [sp, #16] +stp d10, d11, [sp, #32] +stp d8, d9, [sp, #48] + +//ldr x8, [x3, #0] // deprecated +ldr x9, [x3, #8] // l +ldr x10, [x3, #16] // h + +ldr x13, [x3, #24] // cStride +ldr x7, [x3, #40] // bExtraStride + +// v0, v1, v2: A +// v3, v4: B +// v8 - v31: C +add x10, x10, #3 +lsr x10, x10, #2 + +cbz x4, Start +ld1 {v5.4s}, [x4] +dup v6.4s, v5.s[2] // Min Value +dup v7.4s, v5.s[3] // Max Value + +Start: + +cmp x10, #2 +blt LH4 + +LH8: +// sub x14, x13, #80 // in "StoreLH8", total 3 lines Cstride is x13, first 5 line stp is 5 * 8 * sizeof(int16_t) = 64byte + // stp should add at last +LoopH: + mov x15, x1 + subs x12, x9, #1 + ld1 {v3.4h, v4.4h}, [x2], #16 // 8 * sizeof(int16_t) + ld1 {v0.4h, v1.4h, v2.4h}, [x15], #24 // 12 * sizeof(int16_t) + + shll v3.4s, v3.4h, #16 + shll v4.4s, v4.4h, #16 + shll v0.4s, v0.4h, #16 + shll v1.4s, v1.4h, #16 + shll v2.4s, v2.4h, #16 + + fmul v8.4s, v3.4s, v0.s[0] + fmul v9.4s, v3.4s, v0.s[1] + fmul v10.4s, v3.4s, v0.s[2] + fmul v11.4s, v3.4s, v0.s[3] + fmul v12.4s, v3.4s, v1.s[0] + fmul v13.4s, v3.4s, v1.s[1] + fmul v14.4s, v3.4s, v1.s[2] + fmul v15.4s, v3.4s, v1.s[3] + fmul v16.4s, v3.4s, v2.s[0] + fmul v17.4s, v3.4s, v2.s[1] + fmul v18.4s, v3.4s, v2.s[2] + fmul v19.4s, v3.4s, v2.s[3] + + fmul v20.4s, v4.4s, v0.s[0] + fmul v21.4s, v4.4s, v0.s[1] + fmul v22.4s, v4.4s, v0.s[2] + fmul v23.4s, v4.4s, v0.s[3] + + fmul v24.4s, v4.4s, v1.s[0] + fmul v25.4s, v4.4s, v1.s[1] + fmul v26.4s, v4.4s, v1.s[2] + fmul v27.4s, v4.4s, v1.s[3] + + fmul v28.4s, v4.4s, v2.s[0] + fmul v29.4s, v4.4s, v2.s[1] + fmul v30.4s, v4.4s, v2.s[2] + fmul v31.4s, v4.4s, v2.s[3] + + beq LoopLEnd + + cmp x12, #2 + blt L1 + LoopL2: + ld1 {v3.4h, v4.4h}, [x2], #16 // 8 * sizeof(int16_t) + ld1 {v0.4h, v1.4h, v2.4h}, [x15], #24 // 12 * sizeof(int16_t) // * sizeof(int16_t) + + shll v3.4s, v3.4h, #16 + shll v4.4s, v4.4h, #16 + shll v0.4s, v0.4h, #16 + shll v1.4s, v1.4h, #16 + shll v2.4s, v2.4h, #16 + + fmla v8.4s, v3.4s, v0.s[0] + fmla v9.4s, v3.4s, v0.s[1] + fmla v10.4s, v3.4s, v0.s[2] + fmla v11.4s, v3.4s, v0.s[3] + fmla v12.4s, v3.4s, v1.s[0] + fmla v13.4s, v3.4s, v1.s[1] + fmla v14.4s, v3.4s, v1.s[2] + fmla v15.4s, v3.4s, v1.s[3] + fmla v16.4s, v3.4s, v2.s[0] + fmla v17.4s, v3.4s, v2.s[1] + fmla v18.4s, v3.4s, v2.s[2] + fmla v19.4s, v3.4s, v2.s[3] + + fmla v20.4s, v4.4s, v0.s[0] + fmla v21.4s, v4.4s, v0.s[1] + fmla v22.4s, v4.4s, v0.s[2] + fmla v23.4s, v4.4s, v0.s[3] + + fmla v24.4s, v4.4s, v1.s[0] + fmla v25.4s, v4.4s, v1.s[1] + fmla v26.4s, v4.4s, v1.s[2] + fmla v27.4s, v4.4s, v1.s[3] + + fmla v28.4s, v4.4s, v2.s[0] + fmla v29.4s, v4.4s, v2.s[1] + fmla v30.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + + ld1 {v3.4h, v4.4h}, 
[x2], #16 // 8 * sizeof(int16_t) + ld1 {v0.4h, v1.4h, v2.4h}, [x15], #24 // 12 * sizeof(int16_t) // * sizeof(int16_t) + + shll v3.4s, v3.4h, #16 + shll v4.4s, v4.4h, #16 + shll v0.4s, v0.4h, #16 + shll v1.4s, v1.4h, #16 + shll v2.4s, v2.4h, #16 + + fmla v8.4s, v3.4s, v0.s[0] + fmla v9.4s, v3.4s, v0.s[1] + fmla v10.4s, v3.4s, v0.s[2] + fmla v11.4s, v3.4s, v0.s[3] + fmla v12.4s, v3.4s, v1.s[0] + fmla v13.4s, v3.4s, v1.s[1] + fmla v14.4s, v3.4s, v1.s[2] + fmla v15.4s, v3.4s, v1.s[3] + fmla v16.4s, v3.4s, v2.s[0] + fmla v17.4s, v3.4s, v2.s[1] + fmla v18.4s, v3.4s, v2.s[2] + fmla v19.4s, v3.4s, v2.s[3] + + fmla v20.4s, v4.4s, v0.s[0] + fmla v21.4s, v4.4s, v0.s[1] + fmla v22.4s, v4.4s, v0.s[2] + fmla v23.4s, v4.4s, v0.s[3] + + fmla v24.4s, v4.4s, v1.s[0] + fmla v25.4s, v4.4s, v1.s[1] + fmla v26.4s, v4.4s, v1.s[2] + fmla v27.4s, v4.4s, v1.s[3] + + fmla v28.4s, v4.4s, v2.s[0] + fmla v29.4s, v4.4s, v2.s[1] + fmla v30.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + sub x12, x12, #2 + cmp x12, #2 + bge LoopL2 + + cbz x12, LoopLEnd + + L1: + ld1 {v3.4h, v4.4h}, [x2], #16 // 8 * sizeof(int16_t) + ld1 {v0.4h, v1.4h, v2.4h}, [x15], #24 // 12 * sizeof(int16_t) // * sizeof(int16_t) + + shll v3.4s, v3.4h, #16 + shll v4.4s, v4.4h, #16 + shll v0.4s, v0.4h, #16 + shll v1.4s, v1.4h, #16 + shll v2.4s, v2.4h, #16 + + fmla v8.4s, v3.4s, v0.s[0] + fmla v9.4s, v3.4s, v0.s[1] + fmla v10.4s, v3.4s, v0.s[2] + fmla v11.4s, v3.4s, v0.s[3] + fmla v12.4s, v3.4s, v1.s[0] + fmla v13.4s, v3.4s, v1.s[1] + fmla v14.4s, v3.4s, v1.s[2] + fmla v15.4s, v3.4s, v1.s[3] + fmla v16.4s, v3.4s, v2.s[0] + fmla v17.4s, v3.4s, v2.s[1] + fmla v18.4s, v3.4s, v2.s[2] + fmla v19.4s, v3.4s, v2.s[3] + + fmla v20.4s, v4.4s, v0.s[0] + fmla v21.4s, v4.4s, v0.s[1] + fmla v22.4s, v4.4s, v0.s[2] + fmla v23.4s, v4.4s, v0.s[3] + + fmla v24.4s, v4.4s, v1.s[0] + fmla v25.4s, v4.4s, v1.s[1] + fmla v26.4s, v4.4s, v1.s[2] + fmla v27.4s, v4.4s, v1.s[3] + + fmla v28.4s, v4.4s, v2.s[0] + fmla v29.4s, v4.4s, v2.s[1] + fmla v30.4s, v4.4s, v2.s[2] + fmla v31.4s, v4.4s, v2.s[3] + + LoopLEnd: + + add x2, x2, x7 // weight stride + sub x10, x10, #2 + cmp x10, #2 + + cbz x4, StoreLH8 + + AddBiasLH8: + ld1 {v0.4h, v1.4h}, [x5], #16 // 8 * sizeof(int16_t) + shll v0.4s, v0.4h, #16 + shll v1.4s, v1.4h, #16 + + fmla v8.4s, v0.4s, v5.s[1] + fmla v9.4s, v0.4s, v5.s[1] + fmla v10.4s, v0.4s, v5.s[1] + fmla v11.4s, v0.4s, v5.s[1] + + fmla v12.4s, v0.4s, v5.s[1] + fmla v13.4s, v0.4s, v5.s[1] + fmla v14.4s, v0.4s, v5.s[1] + fmla v15.4s, v0.4s, v5.s[1] + + fmla v16.4s, v0.4s, v5.s[1] + fmla v17.4s, v0.4s, v5.s[1] + fmla v18.4s, v0.4s, v5.s[1] + fmla v19.4s, v0.4s, v5.s[1] + + fmla v20.4s, v1.4s, v5.s[1] + fmla v21.4s, v1.4s, v5.s[1] + fmla v22.4s, v1.4s, v5.s[1] + fmla v23.4s, v1.4s, v5.s[1] + + fmla v24.4s, v1.4s, v5.s[1] + fmla v25.4s, v1.4s, v5.s[1] + fmla v26.4s, v1.4s, v5.s[1] + fmla v27.4s, v1.4s, v5.s[1] + + fmla v28.4s, v1.4s, v5.s[1] + fmla v29.4s, v1.4s, v5.s[1] + fmla v30.4s, v1.4s, v5.s[1] + fmla v31.4s, v1.4s, v5.s[1] + + PostTreatLH8: + fmax v8.4s, v8.4s, v6.4s + fmax v9.4s, v9.4s, v6.4s + fmax v10.4s, v10.4s, v6.4s + fmax v11.4s, v11.4s, v6.4s + fmax v12.4s, v12.4s, v6.4s + fmax v13.4s, v13.4s, v6.4s + fmax v14.4s, v14.4s, v6.4s + fmax v15.4s, v15.4s, v6.4s + fmax v16.4s, v16.4s, v6.4s + fmax v17.4s, v17.4s, v6.4s + fmax v18.4s, v18.4s, v6.4s + fmax v19.4s, v19.4s, v6.4s + fmax v20.4s, v20.4s, v6.4s + fmax v21.4s, v21.4s, v6.4s + fmax v22.4s, v22.4s, v6.4s + fmax v23.4s, v23.4s, v6.4s + fmax v24.4s, v24.4s, v6.4s + fmax v25.4s, v25.4s, v6.4s + fmax v26.4s, v26.4s, 
v6.4s + fmax v27.4s, v27.4s, v6.4s + fmax v28.4s, v28.4s, v6.4s + fmax v29.4s, v29.4s, v6.4s + fmax v30.4s, v30.4s, v6.4s + fmax v31.4s, v31.4s, v6.4s + + fmin v8.4s, v8.4s, v7.4s + fmin v9.4s, v9.4s, v7.4s + fmin v10.4s, v10.4s, v7.4s + fmin v11.4s, v11.4s, v7.4s + fmin v12.4s, v12.4s, v7.4s + fmin v13.4s, v13.4s, v7.4s + fmin v14.4s, v14.4s, v7.4s + fmin v15.4s, v15.4s, v7.4s + fmin v16.4s, v16.4s, v7.4s + fmin v17.4s, v17.4s, v7.4s + fmin v18.4s, v18.4s, v7.4s + fmin v19.4s, v19.4s, v7.4s + fmin v20.4s, v20.4s, v7.4s + fmin v21.4s, v21.4s, v7.4s + fmin v22.4s, v22.4s, v7.4s + fmin v23.4s, v23.4s, v7.4s + fmin v24.4s, v24.4s, v7.4s + fmin v25.4s, v25.4s, v7.4s + fmin v26.4s, v26.4s, v7.4s + fmin v27.4s, v27.4s, v7.4s + fmin v28.4s, v28.4s, v7.4s + fmin v29.4s, v29.4s, v7.4s + fmin v30.4s, v30.4s, v7.4s + fmin v31.4s, v31.4s, v7.4s + + StoreLH8: + + shrn v8.4h, v8.4s, #16 + shrn v9.4h, v9.4s, #16 + shrn v10.4h, v10.4s, #16 + shrn v11.4h, v11.4s, #16 + shrn v12.4h, v12.4s, #16 + shrn v13.4h, v13.4s, #16 + shrn v14.4h, v14.4s, #16 + shrn v15.4h, v15.4s, #16 + shrn v16.4h, v16.4s, #16 + shrn v17.4h, v17.4s, #16 + shrn v18.4h, v18.4s, #16 + shrn v19.4h, v19.4s, #16 + shrn v20.4h, v20.4s, #16 + shrn v21.4h, v21.4s, #16 + shrn v22.4h, v22.4s, #16 + shrn v23.4h, v23.4s, #16 + shrn v24.4h, v24.4s, #16 + shrn v25.4h, v25.4s, #16 + shrn v26.4h, v26.4s, #16 + shrn v27.4h, v27.4s, #16 + shrn v28.4h, v28.4s, #16 + shrn v29.4h, v29.4s, #16 + shrn v30.4h, v30.4s, #16 + shrn v31.4h, v31.4s, #16 + + stp d8, d9, [x0] + stp d10, d11, [x0, #(16 * 1)] // 2 * 4 * sizeof(int16_t) + stp d12, d13, [x0, #(16 * 2)] + stp d14, d15, [x0, #(16 * 3)] + stp d16, d17, [x0, #(16 * 4)] + stp d18, d19, [x0, #(16 * 5)] + add x0, x0, x13 // stp donot support post-index offset in register + stp d20, d21, [x0] + stp d22, d23, [x0, #(16 * 1)] + stp d24, d25, [x0, #(16 * 2)] + stp d26, d27, [x0, #(16 * 3)] + stp d28, d29, [x0, #(16 * 4)] + stp d30, d31, [x0, #(16 * 5)] + add x0, x0, x13 + + // st1 {v8.4h, v9.4h, v10.4h, v11.4h}, [x0], #32 // 16 * sizeof(int16_t) + // st1 {v12.4h, v13.4h, v14.4h, v15.4h}, [x0], #32 // 16 * sizeof(int16_t) + // st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0], x14 + // st1 {v20.4h, v21.4h, v22.4h, v23.4h}, [x0], #32 // 16 * sizeof(int16_t) + // st1 {v24.4h, v25.4h, v26.4h, v27.4h}, [x0], #32 // 16 * sizeof(int16_t) + // st1 {v28.4h, v29.4h, v30.4h, v31.4h}, [x0], x14 + + bge LoopH + +LH4: +cbz x10, End +LoopHRemain: + mov x15, x1 + subs x12, x9, #1 + ld1 {v3.4h}, [x2] + ld1 {v0.4h}, [x15], #8 + shll v3.4s, v3.4h, #16 + shll v0.4s, v0.4h, #16 + + fmul v8.4s, v3.4s, v0.s[0] + fmul v9.4s, v3.4s, v0.s[1] + add x2, x2, #16 // + ld1 {v1.4h}, [x15], #8 + shll v1.4s, v1.4h, #16 + + fmul v10.4s, v3.4s, v0.s[2] + fmul v11.4s, v3.4s, v0.s[3] + fmul v12.4s, v3.4s, v1.s[0] + + ld1 {v2.4h}, [x15], #8 + shll v2.4s, v2.4h, #16 + + fmul v13.4s, v3.4s, v1.s[1] + fmul v14.4s, v3.4s, v1.s[2] + fmul v15.4s, v3.4s, v1.s[3] + fmul v16.4s, v3.4s, v2.s[0] + fmul v17.4s, v3.4s, v2.s[1] + fmul v18.4s, v3.4s, v2.s[2] + fmul v19.4s, v3.4s, v2.s[3] + + beq LoopLREnd + + LoopLR: + ld1 {v3.4h}, [x2] + ld1 {v0.4h, v1.4h, v2.4h}, [x15], #24 // 12 * sizeof(int16_t) + shll v3.4s, v3.4h, #16 + shll v0.4s, v0.4h, #16 + shll v1.4s, v1.4h, #16 + shll v2.4s, v2.4h, #16 + + fmla v8.4s, v3.4s, v0.s[0] + fmla v9.4s, v3.4s, v0.s[1] + fmla v10.4s, v3.4s, v0.s[2] + fmla v11.4s, v3.4s, v0.s[3] + add x2, x2, #16 // + fmla v12.4s, v3.4s, v1.s[0] + fmla v13.4s, v3.4s, v1.s[1] + fmla v14.4s, v3.4s, v1.s[2] + fmla v15.4s, v3.4s, v1.s[3] + fmla v16.4s, 
v3.4s, v2.s[0] + fmla v17.4s, v3.4s, v2.s[1] + fmla v18.4s, v3.4s, v2.s[2] + fmla v19.4s, v3.4s, v2.s[3] + + subs x12, x12, #1 + bne LoopLR + LoopLREnd: + + cbz x4, StoreLH4 + AddBiasLH4: + ld1 {v0.4h}, [x5], #8 + shll v0.4s, v0.4h, #16 + + fmla v8.4s, v0.4s, v5.s[1] + fmla v9.4s, v0.4s, v5.s[1] + fmla v10.4s, v0.4s, v5.s[1] + fmla v11.4s, v0.4s, v5.s[1] + + fmla v12.4s, v0.4s, v5.s[1] + fmla v13.4s, v0.4s, v5.s[1] + fmla v14.4s, v0.4s, v5.s[1] + fmla v15.4s, v0.4s, v5.s[1] + + fmla v16.4s, v0.4s, v5.s[1] + fmla v17.4s, v0.4s, v5.s[1] + fmla v18.4s, v0.4s, v5.s[1] + fmla v19.4s, v0.4s, v5.s[1] + + PostTreatLH4: + fmax v8.4s, v8.4s, v6.4s + fmax v9.4s, v9.4s, v6.4s + fmax v10.4s, v10.4s, v6.4s + fmax v11.4s, v11.4s, v6.4s + fmax v12.4s, v12.4s, v6.4s + fmax v13.4s, v13.4s, v6.4s + fmax v14.4s, v14.4s, v6.4s + fmax v15.4s, v15.4s, v6.4s + fmax v16.4s, v16.4s, v6.4s + fmax v17.4s, v17.4s, v6.4s + fmax v18.4s, v18.4s, v6.4s + fmax v19.4s, v19.4s, v6.4s + + fmin v8.4s, v8.4s, v7.4s + fmin v9.4s, v9.4s, v7.4s + fmin v10.4s, v10.4s, v7.4s + fmin v11.4s, v11.4s, v7.4s + fmin v12.4s, v12.4s, v7.4s + fmin v13.4s, v13.4s, v7.4s + fmin v14.4s, v14.4s, v7.4s + fmin v15.4s, v15.4s, v7.4s + fmin v16.4s, v16.4s, v7.4s + fmin v17.4s, v17.4s, v7.4s + fmin v18.4s, v18.4s, v7.4s + fmin v19.4s, v19.4s, v7.4s + + StoreLH4: + + shrn v8.4h, v8.4s, #16 + shrn v9.4h, v9.4s, #16 + shrn v10.4h, v10.4s, #16 + shrn v11.4h, v11.4s, #16 + shrn v12.4h, v12.4s, #16 + shrn v13.4h, v13.4s, #16 + shrn v14.4h, v14.4s, #16 + shrn v15.4h, v15.4s, #16 + shrn v16.4h, v16.4s, #16 + shrn v17.4h, v17.4s, #16 + shrn v18.4h, v18.4s, #16 + shrn v19.4h, v19.4s, #16 + + stp d8, d9, [x0] + stp d10, d11, [x0, #(16 * 1)] + stp d12, d13, [x0, #(16 * 2)] + stp d14, d15, [x0, #(16 * 3)] + stp d16, d17, [x0, #(16 * 4)] + stp d18, d19, [x0, #(16 * 5)] + + // st1 {v8.4h, v9.4h, v10.4h, v11.4h}, [x0], #32 + // st1 {v12.4h, v13.4h, v14.4h, v15.4h}, [x0], #32 + // st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0] + + sub x10, x10, #1 + + +End: +ldp d8, d9, [sp, #48] +ldp d10, d11, [sp, #32] +ldp d12, d13, [sp, #16] +ldp d14, d15, [sp], #64 + + +ret + +#endif diff --git a/source/backend/cpu/arm/arm64/bf16/MNNReluWithSlopeChannelBF16.S b/backupcode/cpubackend/arm/arm64/bf16/MNNReluWithSlopeChannelBF16.S similarity index 100% rename from source/backend/cpu/arm/arm64/bf16/MNNReluWithSlopeChannelBF16.S rename to backupcode/cpubackend/arm/arm64/bf16/MNNReluWithSlopeChannelBF16.S diff --git a/source/backend/cpu/arm/arm64/bf16/MNNUnPackC4_BF16.S b/backupcode/cpubackend/arm/arm64/bf16/MNNUnPackC4_BF16.S similarity index 100% rename from source/backend/cpu/arm/arm64/bf16/MNNUnPackC4_BF16.S rename to backupcode/cpubackend/arm/arm64/bf16/MNNUnPackC4_BF16.S diff --git a/source/backend/cpu/bf16/BF16Backend.cpp b/backupcode/cpubackend/bf16/BF16Backend.cpp similarity index 100% rename from source/backend/cpu/bf16/BF16Backend.cpp rename to backupcode/cpubackend/bf16/BF16Backend.cpp diff --git a/source/backend/cpu/bf16/BF16Backend.hpp b/backupcode/cpubackend/bf16/BF16Backend.hpp similarity index 100% rename from source/backend/cpu/bf16/BF16Backend.hpp rename to backupcode/cpubackend/bf16/BF16Backend.hpp diff --git a/source/backend/cpu/bf16/BF16Binary.cpp b/backupcode/cpubackend/bf16/BF16Binary.cpp similarity index 100% rename from source/backend/cpu/bf16/BF16Binary.cpp rename to backupcode/cpubackend/bf16/BF16Binary.cpp diff --git a/source/backend/cpu/bf16/BF16Binary.hpp b/backupcode/cpubackend/bf16/BF16Binary.hpp similarity index 100% rename from 
source/backend/cpu/bf16/BF16Binary.hpp rename to backupcode/cpubackend/bf16/BF16Binary.hpp diff --git a/backupcode/cpubackend/bf16/BF16Functions.cpp b/backupcode/cpubackend/bf16/BF16Functions.cpp new file mode 100644 index 000000000..3f792a3ce --- /dev/null +++ b/backupcode/cpubackend/bf16/BF16Functions.cpp @@ -0,0 +1,918 @@ +#ifdef MNN_USE_SSE +#include "../x86_x64/sse/FunctionSummary.hpp" +#include "../x86_x64/avx/FunctionSummary.hpp" +#include "../x86_x64/avxfma/FunctionSummary.hpp" +#include "../x86_x64/avx512/FunctionSummary.hpp" +#include "../x86_x64/cpu_id.h" +#endif +#include "core/Macro.h" +#if defined(MNN_USE_NEON) +#include "../arm/FunctionSummary.hpp" +#endif + +#include "BF16Functions.hpp" +#include "WinogradOptFunctionHalf.hpp" +#include "../compute/CommonOptFunction.h" +#include "../CPUPool.hpp" +#include "../CPURuntime.hpp" +#include "VecHalf.hpp" +#include "math/Vec.hpp" +#include "BF16Binary.hpp" +#include "BF16Unary.hpp" +using BFVec4 = MNN::Math::VecHalf<4>; +using Vec4 = MNN::Math::Vec; +extern "C" { +void MNNReluWithSlopeChannelBF16(float* dstO, const float* srcO, const float* slopeO, size_t sizeQuad, size_t depthQuad); +} +namespace MNN { +// just for reference BF16 converting of c++ code, not for arm or sse. +inline int16_t MNNFP32ToBF16(float fp32Value) { + int32_t* s32Value = (int32_t*)(&fp32Value); + return (int16_t)((*s32Value) >> 16); +} +inline float MNNLowpToFp32(int16_t s16Value) { + int32_t s32Value = ((int32_t)s16Value) << 16; + float* fp32Value = (float*)(&s32Value); + return *fp32Value; +} + +static void _MNNFp32ToLowp(const float* src, int16_t* dst, size_t size) { + int sizeC4 = size / 4; + for (int i = 0; i < sizeC4; ++i) { + auto srcV = Vec4::load(src); + auto dstV = BFVec4(std::move(srcV.value)); + BFVec4::save(dst, dstV); + src+=4; + dst+=4; + } + int sizeRemain = size % 4; + if (sizeRemain > 0) { + float srcTemp[4]; + int64_t dstTemp[1]; + ::memcpy(srcTemp, src, sizeRemain * sizeof(float)); + auto srcV = Vec4::load(srcTemp); + auto dstV = BFVec4(std::move(srcV.value)); + BFVec4::save((int16_t*)dstTemp, dstV); + ::memcpy(dst, dstTemp, sizeRemain * sizeof(int16_t)); + } +} +static void _MNNLowpToFp32(const int16_t* src, float* dst, size_t size) { + int sizeC4 = size / 4; + for (int i = 0; i < sizeC4; ++i) { + auto srcV = BFVec4::load(src); + auto dstV = Vec4(std::move(srcV.value)); + Vec4::save(dst, dstV); + src+=4; + dst+=4; + } + int sizeRemain = size % 4; + if (sizeRemain > 0) { + int64_t srcTemp[2]; + float dstTemp[4]; + ::memcpy(srcTemp, src, sizeRemain * sizeof(int16_t)); + auto srcV = BFVec4::load((int16_t*)srcTemp); + auto dstV = Vec4(std::move(srcV.value)); + Vec4::save(dstTemp, dstV); + ::memcpy(dst, dstTemp, sizeRemain * sizeof(float)); + } +} +static void MNNConvRunForUnitDepthWiseBF16(float* dst, const float* src, const float* weight, size_t fw, size_t fh, + size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) { + int fx, fy; + BFVec4 dstValue(0.0f); + const int16_t* src_z = (const int16_t*)src; + const int16_t* weight_z = (const int16_t*)weight; + for (fy = 0; fy < fh; ++fy) { + const auto src_y = src_z + fy * dilateY_step; + const auto weight_y = weight_z + fy * weight_y_step; + for (fx = 0; fx < fw; ++fx) { + const auto weight_x = weight_y + 4 * fx; + const auto src_x = src_y + fx * dilateX_step; + dstValue = dstValue + BFVec4::load(src_x) * BFVec4::load(weight_x); + } + } + BFVec4::save((int16_t*)dst, dstValue); +} + +static void MNNConvRunForLineDepthwiseBF16(float* dstO, const float* srcO, const float* weightO, size_t 
width, size_t src_w_setup, + size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height, + size_t srcHStep, size_t dstHStep) { + int dx, fx, fy; + auto dst = (int16_t*)dstO; + auto src = (const int16_t*)srcO; + auto weight = (const int16_t*)weightO; + for (int y = 0; y < height; ++y) { + auto srcY = src + y * srcHStep; + auto dstY = dst + y * dstHStep; + for (dx = 0; dx < width; ++dx) { + auto dst_x = dstY + dx * 4; + BFVec4 dstValue(0.0f); + const auto src_z = srcY + src_w_setup * dx; + const auto weight_z = weight; + for (fy = 0; fy < fh; ++fy) { + const auto src_y = src_z + fy * dilateY_step; + const auto weight_y = weight_z + fy * fw * 4; + for (fx = 0; fx < fw; ++fx) { + const auto weight_x = weight_y + 4 * fx; + const auto src_x = src_y + fx * dilateX_step; + dstValue = dstValue + BFVec4::load(src_x) * BFVec4::load(weight_x); + } + } + BFVec4::save(dst_x, dstValue); + } + } +} +void MNNAxByClampBroadcastUnitBF16(float* CF, const float* AF, const float* BF, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters) { + auto C = (int16_t*)CF; + auto A = (const int16_t*)AF; + auto B = (const int16_t*)BF; + auto minF = BFVec4(parameters[2]); + auto maxF = BFVec4(parameters[3]); + auto beta = BFVec4(parameters[1]); + for (int y = 0; y < height; ++y) { + auto a = A + aStride * y; + auto b = B + 4 * y; + auto bv = BFVec4::load(b); + auto c = C + cStride * y; + for (int x = 0; x < width; ++x) { + auto av = BFVec4::load(a + 4 * x); + auto cv = av + bv * beta; + cv = BFVec4::min(cv, maxF); + cv = BFVec4::max(cv, minF); + BFVec4::save(c + 4 * x, cv); + } + } +} +#ifndef MNN_USE_NEON +void MNNReluWithSlopeChannelBF16(float* dstO, const float* srcO, const float* slopeO, size_t sizeQuad, size_t depthQuad) { + auto slope = (const int16_t*)slopeO; + auto dst = (int16_t*)dstO; + auto src = (const int16_t*)srcO; + auto zero = BFVec4(0.0f); + for (int j = 0; j < depthQuad; j++) { + auto slopeZ = BFVec4::load(slope + 4 * j); + auto srcZ = src + 4 * j * sizeQuad; + auto dstZ = dst + 4 * j * sizeQuad; + for (int i = 0; i < sizeQuad; i++) { + auto srcValue = BFVec4::load(srcZ + 4 * i); + std::array dstV; + for (int c = 0; c < 4; c++) { + if (srcValue[c] < 0) { + dstV[c] = srcValue[c] * slopeZ[c]; + } else { + dstV[c] = srcValue[c]; + } + } + auto dstValue = BFVec4(std::move(Vec4::load(dstV.data()).value)); + BFVec4::save(dstZ + 4 * i, dstValue); + } + } +} +#endif + +#if !defined(MNN_USE_SSE) && !defined(MNN_USE_NEON) +void MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) { + MNNPackC4ForMatMul_A(destOrigin, sourceGroup, info, el); + return; +} + +void MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t l, bool transpose) { + auto hP = h / 4; + auto hR = hP * 4; + if (hR != h) { + ::memset(dest, 0, UP_DIV(h, 4)*4*l*sizeof(int16_t)); + } + if (!transpose) { + for (int y=0; y 0) { + auto destY = dest + hP * 4 * l; + auto sourceY = source + hP * 4; + for (int x=0; x().max(); + float maxValue = std::numeric_limits().max(); + if (nullptr != postParameters) { + minValue = postParameters[2]; + maxValue = postParameters[3]; + alpha = postParameters[0]; + beta = postParameters[1]; + } + + for (int x = 0; x < eSize; ++x) { + auto dst = C + 4 * x; + auto src = + A + x; // input data is packed as tileCount x l x 16, is only one tiled block here, indexed as A[z * 16 + x] + for (int ry = 0; ry < h; ++ry) { + auto y = ry / 4; + auto yRemain = ry % 4; + auto bY = B + y * bStride; + auto dstY 
= dst + y * cStride; // convert NCHW to NC4HW4 ie 1·(y/4)·X·4 + int wdy = ry / 6; + int wdyRemain = ry % 6; + auto weight = + B + wdy * bStride + + wdyRemain; // weight is packed as (h/6) x l x 6, indexed as B[(ry / 6) * Bstride +z*6 + (ry % 6)] + float summer = 0.0f; + for (int z = 0; z < l; ++z) { + auto aZ = src + z * 16; + auto wZ = weight + z * 6; + summer += MNNLowpToFp32(wZ[0]) * MNNLowpToFp32(aZ[0]); + } + float originValue = MNNLowpToFp32(dstY[yRemain]); + if (nullptr != bias) { + originValue = MNNLowpToFp32(bias[ry]); + } + auto dstValue = originValue * beta + alpha * summer; + dstValue = std::min(dstValue, maxValue); + dstValue = std::max(dstValue, minValue); + dstY[yRemain] = MNNFP32ToBF16(dstValue); + } + } +} + +void MNNPackedMatMul_BF16(float* C, const float* A, const float* B, const size_t* parameter, float* cache, + const float* postParameters, const float* bias, const float* k, const float* b) { + return MNNPackedMatMulRemain_BF16(C, A, B, 16, parameter, cache, postParameters, bias, nullptr, nullptr); + // return _AVX_MNNPackedMatMulFMA(C, A, B, parameter, cache); +} + + +static void _MNNConvDwF23MulTransUnit(float **cacheLine, const float *weigth, float *dest, size_t ow); + +static void _MNNMultiAndDestTransformCommon23(float **cacheLine, const float *weigthF, float *destF, int cacheLineSize, int ow, const float* bias, const float* parameters) { + auto weigth = (const int16_t*)weigthF; + auto dest = (int16_t*)destF; + int unit = ow / 2; + auto biasF = BFVec4::load((const int16_t*)bias); + auto minV = BFVec4(parameters[2]); + auto maxV = BFVec4(parameters[3]); + MNN_ASSERT(cacheLineSize >= 1); + for (int x = 0; x < unit; ++x) { + auto offset = 4 * 4 * x; + int i = 0; + BFVec4 m0 = BFVec4::load(weigth + i * 16 + 4 * 0) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 0); + BFVec4 m1 = BFVec4::load(weigth + i * 16 + 4 * 1) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 1); + BFVec4 m2 = BFVec4::load(weigth + i * 16 + 4 * 2) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 2); + BFVec4 m3 = BFVec4::load(weigth + i * 16 + 4 * 3) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 3); + + for (i = 1; i < cacheLineSize; ++i) { + m0 = m0 + BFVec4::load(weigth + i * 16 + 4 * 0) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 0); + m1 = m1 + BFVec4::load(weigth + i * 16 + 4 * 1) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 1); + m2 = m2 + BFVec4::load(weigth + i * 16 + 4 * 2) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 2); + m3 = m3 + BFVec4::load(weigth + i * 16 + 4 * 3) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 3); + } + + auto o0 = m0 + m1 + m2 + biasF; + auto o1 = m1 - m2 + m3 + biasF; + o0 = BFVec4::min(o0, maxV); + o1 = BFVec4::min(o1, maxV); + o0 = BFVec4::max(o0, minV); + o1 = BFVec4::max(o1, minV); + BFVec4::save(dest + 8 * x + 0 * 4, o0); + BFVec4::save(dest + 8 * x + 1 * 4, o1); + } + if (unit * 2 < ow) { + auto offset = 4 * 4 * unit; + int i = 0; + BFVec4 m0 = BFVec4::load(weigth + i * 16 + 4 * 0) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 0); + BFVec4 m1 = BFVec4::load(weigth + i * 16 + 4 * 1) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 1); + BFVec4 m2 = BFVec4::load(weigth + i * 16 + 4 * 2) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 2); + + for (i = 1; i < cacheLineSize; ++i) { + m0 = m0 + BFVec4::load(weigth + i * 16 + 4 * 0) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 0); + m1 = m1 + BFVec4::load(weigth + i * 16 + 4 * 1) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 1); + 
m2 = m2 + BFVec4::load(weigth + i * 16 + 4 * 2) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 2); + } + + auto o0 = m0 + m1 + m2 + biasF; + o0 = BFVec4::min(o0, maxV); + o0 = BFVec4::max(o0, minV); + BFVec4::save(dest + 8 * unit + 0 * 4, o0); + } +} +static void _MNNConvDwF23SourceTransUnit(const int16_t *source, int16_t *dest, size_t unit); +static void _MNNSourceTransformCommonF23(const float *sourceF, float *destF, int unit, int iw, int pad, int su, int eu) { + auto source = (const int16_t*)sourceF; + auto dest = (int16_t*)destF; + for (int x = 0; x < su; ++x) { + auto dstX = dest + 4 * 4 * x; + auto sx = x * 2 - (int)pad; + auto ex = sx + 4; + + auto clampSx = std::max(sx, 0); + auto clampEx = std::min(ex, (int)iw); + + BFVec4 v[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + for (int i = clampSx; i < clampEx; ++i) { + v[i - sx] = BFVec4::load(source + 4 * i); + } + auto m0 = v[0] - v[2]; + auto m1 = v[1] + v[2]; + auto m2 = v[2] - v[1]; + auto m3 = v[3] - v[1]; + + BFVec4::save(dstX + 4 * 0, m0); + BFVec4::save(dstX + 4 * 1, m1); + BFVec4::save(dstX + 4 * 2, m2); + BFVec4::save(dstX + 4 * 3, m3); + } + _MNNConvDwF23SourceTransUnit(source + 4 * (su * 2 - pad), dest + 4 * 4 * su, eu - su); + + for (int x = eu; x < unit; ++x) { + auto dstX = dest + 4 * 4 * x; + auto sx = x * 2 - (int)pad; + auto ex = sx + 4; + + auto clampSx = std::max(sx, 0); + auto clampEx = std::min(ex, (int)iw); + + BFVec4 v[4] = {0.0f, 0.0f, 0.0f, 0.0f}; + for (int i = clampSx; i < clampEx; ++i) { + v[i - sx] = BFVec4::load(source + 4 * i); + } + auto m0 = v[0] - v[2]; + auto m1 = v[1] + v[2]; + auto m2 = v[2] - v[1]; + auto m3 = v[3] - v[1]; + + BFVec4::save(dstX + 4 * 0, m0); + BFVec4::save(dstX + 4 * 1, m1); + BFVec4::save(dstX + 4 * 2, m2); + BFVec4::save(dstX + 4 * 3, m3); + } +} + +static void _MNNConvDwF23MulTransUnit(float **cacheLine, const float *weigthF, float *destF, size_t ow, const float* bias, const float* parameters) { + int unit = ow / 2; + auto weigth = (const int16_t*)weigthF; + auto dest = (int16_t*)destF; + + auto w00 = BFVec4::load(weigth + 0 * 16 + 4 * 0); + auto w01 = BFVec4::load(weigth + 0 * 16 + 4 * 1); + auto w02 = BFVec4::load(weigth + 0 * 16 + 4 * 2); + auto w03 = BFVec4::load(weigth + 0 * 16 + 4 * 3); + auto w10 = BFVec4::load(weigth + 1 * 16 + 4 * 0); + auto w11 = BFVec4::load(weigth + 1 * 16 + 4 * 1); + auto w12 = BFVec4::load(weigth + 1 * 16 + 4 * 2); + auto w13 = BFVec4::load(weigth + 1 * 16 + 4 * 3); + auto w20 = BFVec4::load(weigth + 2 * 16 + 4 * 0); + auto w21 = BFVec4::load(weigth + 2 * 16 + 4 * 1); + auto w22 = BFVec4::load(weigth + 2 * 16 + 4 * 2); + auto w23 = BFVec4::load(weigth + 2 * 16 + 4 * 3); + + auto biasF = BFVec4::load((const int16_t*)bias); + auto minV = BFVec4(parameters[2]); + auto maxV = BFVec4(parameters[3]); + for (int x = 0; x < unit; ++x) { + auto offset = 4 * 4 * x; + int i = 0; + BFVec4 m0 = w00 * BFVec4::load((int16_t*)cacheLine[0] + offset + 4 * 0); + BFVec4 m1 = w01 * BFVec4::load((int16_t*)cacheLine[0] + offset + 4 * 1); + BFVec4 m2 = w02 * BFVec4::load((int16_t*)cacheLine[0] + offset + 4 * 2); + BFVec4 m3 = w03 * BFVec4::load((int16_t*)cacheLine[0] + offset + 4 * 3); + + m0 = m0 + w10 * BFVec4::load((int16_t*)cacheLine[1] + offset + 4 * 0); + m1 = m1 + w11 * BFVec4::load((int16_t*)cacheLine[1] + offset + 4 * 1); + m2 = m2 + w12 * BFVec4::load((int16_t*)cacheLine[1] + offset + 4 * 2); + m3 = m3 + w13 * BFVec4::load((int16_t*)cacheLine[1] + offset + 4 * 3); + + m0 = m0 + w20 * BFVec4::load((int16_t*)cacheLine[2] + offset + 4 * 0); + m1 = m1 + w21 * 
BFVec4::load((int16_t*)cacheLine[2] + offset + 4 * 1); + m2 = m2 + w22 * BFVec4::load((int16_t*)cacheLine[2] + offset + 4 * 2); + m3 = m3 + w23 * BFVec4::load((int16_t*)cacheLine[2] + offset + 4 * 3); + + auto o0 = m0 + m1 + m2 + biasF; + auto o1 = m1 - m2 + m3 + biasF; + o0 = BFVec4::min(o0, maxV); + o1 = BFVec4::min(o1, maxV); + o0 = BFVec4::max(o0, minV); + o1 = BFVec4::max(o1, minV); + BFVec4::save(dest + 8 * x + 0 * 4, o0); + BFVec4::save(dest + 8 * x + 1 * 4, o1); + } + if (unit * 2 < ow) { + auto offset = 4 * 4 * unit; + BFVec4 m0 = w00 * BFVec4::load((int16_t*)cacheLine[0] + offset + 4 * 0); + BFVec4 m1 = w01 * BFVec4::load((int16_t*)cacheLine[0] + offset + 4 * 1); + BFVec4 m2 = w02 * BFVec4::load((int16_t*)cacheLine[0] + offset + 4 * 2); + + m0 = m0 + w10 * BFVec4::load((int16_t*)cacheLine[1] + offset + 4 * 0); + m1 = m1 + w11 * BFVec4::load((int16_t*)cacheLine[1] + offset + 4 * 1); + m2 = m2 + w12 * BFVec4::load((int16_t*)cacheLine[1] + offset + 4 * 2); + + m0 = m0 + w20 * BFVec4::load((int16_t*)cacheLine[2] + offset + 4 * 0); + m1 = m1 + w21 * BFVec4::load((int16_t*)cacheLine[2] + offset + 4 * 1); + m2 = m2 + w22 * BFVec4::load((int16_t*)cacheLine[2] + offset + 4 * 2); + auto o0 = m0 + m1 + m2 + biasF; + o0 = BFVec4::min(o0, maxV); + o0 = BFVec4::max(o0, minV); + BFVec4::save(dest + 8 * unit + 0 * 4, o0); + } +} +static void _MNNConvDwF23SourceTransUnit(const int16_t *source, int16_t *dest, size_t unit) { + if (unit <= 0) { + return; + } + BFVec4 v0 = BFVec4::load(source + 4 * 0); + BFVec4 v1 = BFVec4::load(source + 4 * 1); + BFVec4 v2; + BFVec4 v3; + source += 8; + + for (int x = 0; x < unit; ++x) { + v2 = BFVec4::load(source + 0 * 4); + v3 = BFVec4::load(source + 1 * 4); + auto m0 = v0 - v2; + auto m1 = v1 + v2; + auto m2 = v2 - v1; + auto m3 = v3 - v1; + + BFVec4::save(dest + 4 * 0, m0); + BFVec4::save(dest + 4 * 1, m1); + BFVec4::save(dest + 4 * 2, m2); + BFVec4::save(dest + 4 * 3, m3); + + source += 8; + dest += 16; + + v0 = v2; + v1 = v3; + } +} + +static void _MNNMatrixSub(float* CF, const float* AF, const float* BF, size_t widthC4, size_t cStride, size_t aStride, + size_t bStride, size_t height) { + auto A = (int16_t*)AF; + auto B = (int16_t*)BF; + auto C = (int16_t*)CF; + for (int y = 0; y < height; ++y) { + auto a = A + aStride * y; + auto b = B + bStride * y; + auto c = C + cStride * y; + for (int x = 0; x < widthC4; ++x) { + BFVec4::save(c + 4 * x, BFVec4::load(a + 4 * x) - BFVec4::load(b + 4 * x)); + } + } +} +static void _MNNMatrixAdd(float* CF, const float* AF, const float* BF, size_t widthC4, size_t cStride, size_t aStride, + size_t bStride, size_t height) { + auto A = (int16_t*)AF; + auto B = (int16_t*)BF; + auto C = (int16_t*)CF; + for (int y = 0; y < height; ++y) { + auto a = A + aStride * y; + auto b = B + bStride * y; + auto c = C + cStride * y; + for (int x = 0; x < widthC4; ++x) { + BFVec4::save(c + 4 * x, BFVec4::load(a + 4 * x) + BFVec4::load(b + 4 * x)); + } + } +} + +static void _MNNStrassenMergeCFunction(float* c11F, float* c12F, float* c21F, float* c22F, float* xAddrF, size_t cStride, + size_t eSub, size_t hSub) { + auto c11 = (int16_t*)c11F; + auto c12 = (int16_t*)c12F; + auto c21 = (int16_t*)c21F; + auto c22 = (int16_t*)c22F; + auto xAddr = (int16_t*)xAddrF; + for (int y=0; y= height || w < 0 || w >= width) { + return -1; + } + } else { + // Clearly, CLAMP is the right way to go for GridSamplePaddingMode_BORDER + // For GridSamplePaddingMode_REFLECTION, since we have reflected the values into (-1, 1), + // the leftover reflections degrade to 
GridSamplePaddingMode_BORDER + h = h < 0 ? 0 : ( h > (height - 1) ? (height - 1) : h); + w = w < 0 ? 0 : ( w > (width - 1) ? (width - 1) : w); + } + return h * width * 4 + w * 4; +} + +void _MNNGridSampleInterp(float* output, const float* input, const float* cord, size_t inH, size_t inW, size_t outW, size_t channelCUnit, size_t inOffset, size_t outOffset, bool sampleMode, bool padMode) { + int16_t* outputPtr = (int16_t*)output; + const int16_t* inputPtr = (const int16_t*)input; + const int16_t* cordPtr = (const int16_t*)cord; + + for (auto ow = 0; ow < outW; ++ow) { + auto w = MNNLowpToFp32(cordPtr[2 * ow + 0]); + auto h = MNNLowpToFp32(cordPtr[2 * ow + 1]); + BFVec4 interp; + + if (sampleMode == true) { //sampleMode == SampleMode_NEAREST + int nh = ::floor(h + 0.5f); + int nw = ::floor(w + 0.5f); + size_t ns = _MNNGridSampleComputeOffset(nh, nw, inH, inW, padMode); + for (int k = 0; k < channelCUnit; ++k) { + interp = ns == -1 ? BFVec4(0.f) : BFVec4::load(inputPtr + k * inOffset + ns); + BFVec4::save(outputPtr + k * outOffset + 4 * ow, interp); + } + } else { //sampleMode == GridSampleMode_BILINEAR + int w0_h = ::floor(h); + int w0_w = ::floor(w); + int w1_h = ::ceil(h); + int w1_w = ::ceil(w); + auto oneV = BFVec4(1.0f); + + auto f0 = BFVec4((float)w1_w - w); + auto f1 = oneV - f0; + auto h0 = BFVec4((float)w1_h - h); + auto h1 = oneV - h0; + + size_t s00 = _MNNGridSampleComputeOffset(w0_h, w0_w, inH, inW, padMode); + size_t s01 = _MNNGridSampleComputeOffset(w0_h, w1_w, inH, inW, padMode); + size_t s10 = _MNNGridSampleComputeOffset(w1_h, w0_w, inH, inW, padMode); + size_t s11 = _MNNGridSampleComputeOffset(w1_h, w1_w, inH, inW, padMode); + + for (int k = 0; k < channelCUnit; ++k) { + BFVec4 i00 = s00 == -1 ? BFVec4(0.f) : BFVec4::load(inputPtr + k * inOffset + s00); + BFVec4 i01 = s01 == -1 ? BFVec4(0.f) : BFVec4::load(inputPtr + k * inOffset + s01); + BFVec4 i10 = s10 == -1 ? BFVec4(0.f) : BFVec4::load(inputPtr + k * inOffset + s10); + BFVec4 i11 = s11 == -1 ? 
BFVec4(0.f) : BFVec4::load(inputPtr + k * inOffset + s11); + + BFVec4 i0 = i00 * f0 + i01 * f1; + BFVec4 i1 = i10 * f0 + i11 * f1; + + interp = i0 * h0 + i1 * h1; + BFVec4::save(outputPtr + k * outOffset + 4 * ow, interp); + } + } + } +} + + +static void _MNNAddC4WithStride(const float* sourceF, float* destF, size_t srcStride, size_t dstStride, size_t count) { + auto source = (const int16_t*)sourceF; + auto dest = (int16_t*)destF; + for (int i = 0; i < count; ++i) { + auto s = source + i * srcStride; + auto d = dest + i * dstStride; + BFVec4::save(d, BFVec4::load(d) + BFVec4::load(s)); + } +} +static void _MNNDeconvRunForUnitDepthWise(const int16_t* dst, int16_t* src, const int16_t* weight, size_t fw, size_t fh, + size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) { + int fx, fy; + auto src_z = src; + auto weight_z = weight; + BFVec4 dstV = BFVec4::load(dst); + for (fy = 0; fy < fh; ++fy) { + auto src_y = src_z + fy * dilateY_step; + auto weight_y = weight_z + fy * weight_y_step; + for (fx = 0; fx < fw; ++fx) { + BFVec4 weight_x = BFVec4::load(weight_y + 4 * fx); + BFVec4 src_x = BFVec4::load(src_y + fx * dilateX_step); + BFVec4::save(src_y + fx * dilateX_step, src_x + weight_x * dstV); + } + } +} +static void _MNNDeconvRunForLineDepthwise(const int16_t* dst, int16_t* src, const int16_t* weight, size_t width, size_t src_w_setup, + size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step) { + int dx; + for (dx = 0; dx < width; ++dx) { + auto dst_x = dst + dx * 4; + auto src_dx = src + src_w_setup * dx; + _MNNDeconvRunForUnitDepthWise(dst_x, src_dx, weight, fw, fh, fw * 4, dilateX_step, dilateY_step); + } +} + +static void _MNNComputeMatMulForH_1_BF16(const float* AF, const float* BF, float* CF, const float* biasPtrF, const MatMulParam* param, size_t tId) { + auto A = (const int16_t*)AF; + auto B = (const int16_t*)BF; + auto C = (int16_t*)CF; + auto biasPtr = (const int16_t*)biasPtrF; + int e = param->e; + int l = param->l; + int numberThread = param->numberThread; + float biasValue = 0.0f; + auto bf = BF16Functions::get(); + if (nullptr != biasPtr) { + bf->MNNLowpToFp32(biasPtr, &biasValue, 1); + } + if (param->ATranspose) { + auto eC4 = e / 4; + auto eR = e % 4; + for (int y=tId; y 0) { + BFVec4 sumValue = BFVec4(biasValue); + auto srcY = A + eC4 * 4; + int16_t AR[4]; + for (int x=0; x 0) { + int16_t AR[4] = {0, 0, 0, 0}; + int16_t BR[4] = {0, 0, 0, 0}; + ::memcpy(AR, srcY + lC4 * 4, lR * sizeof(int16_t)); + ::memcpy(BR, B + 4 * lC4, lR * sizeof(int16_t)); + sumValue = sumValue + BFVec4::load(AR) * BFVec4::load(BR); + } + float sumSingle = sumValue[0] + sumValue[1] + sumValue[2] + sumValue[3]; + bf->MNNFp32ToLowp(&sumSingle, C + y, 1); + } +} + +static void _MNNComputeMatMulForE_1_BF16(const float* AF, const float* BF, float* CF, const float* biasPtrF, const MatMulParam* param, size_t tId) { + auto l = param->l; + auto h = param->h; + auto numberThread = param->numberThread; + auto lC4 = l / 4; + auto lR = l % 4; + auto A = (const int16_t*)AF; + auto B = (const int16_t*)BF; + auto C = (int16_t*)CF; + auto biasPtr = (const int16_t*)biasPtrF; + auto bf16 = BF16Functions::get(); + if (param->BTranspose) { + for (int y=tId; y 0) { + int16_t AR[4] = {0, 0, 0, 0}; + int16_t BR[4] = {0, 0, 0, 0}; + ::memcpy(AR, A + lC4 * 4, lR * sizeof(int16_t)); + ::memcpy(BR, by + 4 * lC4, lR * sizeof(int16_t)); + sumValue = sumValue + BFVec4::load(AR) * BFVec4::load(BR); + } + float sumRemain = sumValue[0] + sumValue[1] + sumValue[2] + sumValue[3]; + if (nullptr != biasPtr) { + sumRemain += 
BFVec4::broadcast(biasPtr[y])[0]; + } + bf16->MNNFp32ToLowp(&sumRemain, C + y, 1); + } + } else { + auto hC4 = h / 4; + auto hR = h % 4; + for (int y=tId; y 0) { + auto bs = B + 4 * hC4; + BFVec4 sumValue = BFVec4(0.0f); + if (biasPtr != nullptr) { + int16_t biasTemp[4]; + ::memcpy(biasTemp, biasPtr + 4 * hC4, hR * sizeof(int16_t)); + sumValue = BFVec4::load(biasTemp); + } + auto srcY = A + 4 * hC4 * l; + int16_t bTemp[4]; + for (int x=0; xMNNConvRunForLineDepthwise = MNNConvRunForLineDepthwiseBF16; + gInstance->MNNConvRunForUnitDepthWise = MNNConvRunForUnitDepthWiseBF16; + gInstance->MNNAxByClampBroadcastUnit = MNNAxByClampBroadcastUnitBF16; + gInstance->MNNFp32ToLowp = _MNNFp32ToLowp; + gInstance->MNNLowpToFp32 = _MNNLowpToFp32; + gInstance->bytes = 2; + gInstance->pack = 4; + gInstance->MNNPackCUnit = (decltype(gInstance->MNNPackCUnit))MNNPackC4Int16; + gInstance->MNNUnpackCUnit = (decltype(gInstance->MNNUnpackCUnit))MNNUnpackC4Int16; + gInstance->MNNUnpackCUnitTranspose = (decltype(gInstance->MNNUnpackCUnitTranspose))MNNPackTransposeInt16; + gInstance->MNNPackCUnitTranspose = (decltype(gInstance->MNNPackCUnitTranspose))MNNUnpackTransposeInt16; + gInstance->MNNConvDwF23MulTransUnit = _MNNConvDwF23MulTransUnit; + gInstance->MNNSourceTransformCommonF23 = _MNNSourceTransformCommonF23; + gInstance->MNNMultiAndDestTransformCommon23 = _MNNMultiAndDestTransformCommon23; + gInstance->MNNMatrixAdd = _MNNMatrixAdd; + gInstance->MNNMatrixSub = _MNNMatrixSub; + gInstance->MNNStrassenMergeCFunction = _MNNStrassenMergeCFunction; + gInstance->penalty = 10.0f; + gInstance->MNNScaleAndAddBias = _MNNScaleAndAddBias; + gInstance->MNNGridSampleComputeCord = _MNNGridSampleComputeCord; + gInstance->MNNGridSampleInterp = _MNNGridSampleInterp; + gInstance->MNNCopyC4WithStride = MNNCopyC4Int16WithStride; + gInstance->MNNAddC4WithStride = _MNNAddC4WithStride; + gInstance->chooseWinoSourceTransformPack = (decltype(gInstance->chooseWinoSourceTransformPack))(WinogradFunctionHalf::chooseWinoSourceTransformPack); + gInstance->chooseWinoSourceUnrollTransform = (decltype(gInstance->chooseWinoSourceUnrollTransform))(WinogradFunctionHalf::chooseSourceUnrollTransform); + gInstance->chooseWinoDestUnrollTransform = (decltype(gInstance->chooseWinoDestUnrollTransform))(WinogradFunctionHalf::chooseWinoDestUnrollTransform); + gInstance->MNNDeconvRunForLineDepthwise = (decltype(gInstance->MNNDeconvRunForLineDepthwise))_MNNDeconvRunForLineDepthwise; + gInstance->MNNDeconvRunForUnitDepthWise = (decltype(gInstance->MNNDeconvRunForUnitDepthWise))_MNNDeconvRunForUnitDepthWise; + gInstance->MNNSelectBinaryFunctionForFloat = BF16BinaryFloatSelect; + gInstance->MNNSelectUnaryFunctionForFloat = BF16UnaryFloatSelect; + gInstance->MNNReluWithSlopeChannel = MNNReluWithSlopeChannelBF16;// TODO: Optimize it + +#if !defined(MNN_USE_SSE) && !defined(MNN_USE_NEON) + gInstance->penalty = 1.5f; + gInstance->MNNPackForMatMul_B = MNNPackForMatMul_B_BF16; // common function MNNPackForMatMul_B_BF16 is needed even with out sse or arm neon. 
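+    // These pack/matmul entries are the plain C++ reference versions and are only
+    // compiled when neither MNN_USE_SSE nor MNN_USE_NEON is defined; the SSE/AVX and
+    // NEON branches further below install their own specialized kernels instead.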
+ gInstance->MNNPackC4ForMatMul_A = MNNPackC4ForMatMul_A_BF16;// + gInstance->MNNPackedMatMul = (decltype(gInstance->MNNPackedMatMul))MNNPackedMatMul_BF16; + gInstance->MNNPackedMatMulRemain = (decltype(gInstance->MNNPackedMatMulRemain))MNNPackedMatMulRemain_BF16; +#endif + gInstance->MNNComputeMatMulForH_1 = _MNNComputeMatMulForH_1_BF16; + gInstance->MNNComputeMatMulForE_1 = _MNNComputeMatMulForE_1_BF16; + gInstance->MNNPoolingAvg = (decltype(gInstance->MNNPoolingAvg))(poolingAvg); + gInstance->MNNPoolingMax = (decltype(gInstance->MNNPoolingMax))(poolingMax); + gInstance->MNNPoolingMaxWithRedice = (decltype(gInstance->MNNPoolingMaxWithRedice))(poolingMaxWithRedice); + +#if defined(MNN_USE_SSE) + gInstance->MNNPackForMatMul_B = _SSE_MNNPackForMatMul_B_BF16; + auto cpuFlags = libyuv::InitCpuFlags(); + if (!(cpuFlags & libyuv::kCpuHasF16C)) { + delete gInstance; + gInstance = nullptr; + return false; + } + if (cpuFlags & libyuv::kCpuHasAVX2) { + gInstance->MNNPackForMatMul_B = _AVX_MNNPackForMatMul_B_BF16; + gInstance->MNNGetMatMulPackMode = _AVX_MNNGetMatMulPackMode_BF16; + gInstance->MNNPackC4ForMatMul_A = _AVX_MNNPackC4ForMatMul_A_BF16; + gInstance->MNNPackedMatMul = _AVX_MNNPackedMatMulFMA_BF16; + gInstance->MNNPackedMatMulRemain = _AVX_MNNPackedMatMulRemainFMA_BF16; + return true; + } +#elif defined(MNN_USE_NEON) + gInstance->MNNPackForMatMul_B = NEON_MNNPackForMatMul_B_BF16; + gInstance->MNNGetMatMulPackMode = NEON_MNNGetMatMulPackMode_BF16; + gInstance->MNNPackC4ForMatMul_A = NEON_MNNPackC4ForMatMul_A_BF16; + gInstance->MNNPackedMatMul = NEON_MNNPackedMatMul_BF16; + gInstance->MNNPackedMatMulRemain = NEON_MNNPackedMatMulRemain_BF16; + gInstance->MNNConvRunForLineDepthwise = NEON_MNNConvRunForLineDepthwise_BF16; + gInstance->MNNConvRunForUnitDepthWise = NEON_MNNConvRunForUnitDepthWise_BF16; + gInstance->MNNAxByClampBroadcastUnit = NEON_MNNAxByClampBroadcastC4_BF16; +#ifdef __aarch64__ + cpuinfo_arm_isa gCPUInfo; + cpuinfo_arm_init(&gCPUInfo); + gInstance->supportFp16arith = gCPUInfo.fp16arith; + gInstance->supportSDot = gCPUInfo.dot; + gInstance->supportI8mm = gCPUInfo.i8mm; + if (gInstance->supportI8mm) { + gInstance->MNNPackForMatMul_B = ARMV86_MNNPackForMatMul_B_BF16; + gInstance->MNNPackC4ForMatMul_A = ARMV86_MNNPackC4ForMatMul_A_BF16; + gInstance->MNNGetMatMulPackMode = ARMV86_MNNGetMatMulPackMode_BF16; + gInstance->MNNPackedMatMul = ARMV86_MNNPackedMatMul_BF16; + gInstance->MNNPackedMatMulRemain = ARMV86_MNNPackedMatMulRemain_BF16; + } +#endif + return true; +#endif + // TODO: raw cpu version of bf16 + return true; +} + +CoreFunctions* BF16Functions::get() { + return gInstance; +} +}; diff --git a/backupcode/cpubackend/bf16/BF16Functions.hpp b/backupcode/cpubackend/bf16/BF16Functions.hpp new file mode 100644 index 000000000..e6b29a0f6 --- /dev/null +++ b/backupcode/cpubackend/bf16/BF16Functions.hpp @@ -0,0 +1,16 @@ +#ifndef BF16Functions_hpp +#define BF16Functions_hpp +#include +#include +#include +#include "core/Macro.h" +#include "../compute/CommonOptFunction.h" +namespace MNN { +class BF16Functions { +public: + static bool init(); + static CoreFunctions* get(); +}; +}; + +#endif diff --git a/source/backend/cpu/bf16/BF16Unary.cpp b/backupcode/cpubackend/bf16/BF16Unary.cpp similarity index 100% rename from source/backend/cpu/bf16/BF16Unary.cpp rename to backupcode/cpubackend/bf16/BF16Unary.cpp diff --git a/source/backend/cpu/bf16/BF16Unary.hpp b/backupcode/cpubackend/bf16/BF16Unary.hpp similarity index 100% rename from source/backend/cpu/bf16/BF16Unary.hpp rename to 
backupcode/cpubackend/bf16/BF16Unary.hpp diff --git a/backupcode/cpubackend/bf16/CMakeLists.txt b/backupcode/cpubackend/bf16/CMakeLists.txt new file mode 100644 index 000000000..b533bec6f --- /dev/null +++ b/backupcode/cpubackend/bf16/CMakeLists.txt @@ -0,0 +1,19 @@ + +file(GLOB MNN_BF16_SRCS "${CMAKE_CURRENT_LIST_DIR}/*") + +file(GLOB MNN_BF16_SRCS_ASM "${CMAKE_CURRENT_LIST_DIR}/asm/*") + +add_library( + MNN_BF16 + OBJECT + ${MNN_BF16_SRCS} + ) +target_compile_options(MNN_BF16 PRIVATE -DMNN_SUPPORT_BF16) +if(CMAKE_SYSTEM_PROCESSOR MATCHES "(x86_64)|(X86_64)|(x64)|(X64)|(amd64)|(AMD64)|(i686)") + if (MNN_USE_SSE) + target_compile_options(MNN_BF16 PRIVATE -DMNN_USE_SSE) + if (MNN_SSE_USE_FP16_INSTEAD) + target_compile_options(MNN_BF16 PRIVATE -DMNN_SSE_USE_FP16_INSTEAD -mf16c) + endif() + endif() +endif() diff --git a/backupcode/cpubackend/bf16/VecHalf.hpp b/backupcode/cpubackend/bf16/VecHalf.hpp new file mode 100644 index 000000000..d5fe3a69f --- /dev/null +++ b/backupcode/cpubackend/bf16/VecHalf.hpp @@ -0,0 +1,517 @@ +// +// VecHalf.hpp +// MNN +// +// Created by MNN on 2021/01/26. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifndef VecHalf_hpp +#define VecHalf_hpp +#include "core/Macro.h" +#include +#include +#include // supply std::max and std::min + +#ifdef MNN_USE_NEON +#include +#endif +#ifdef MNN_USE_SSE +#if defined(_MSC_VER) +#include +#else +#include +#endif +#endif + +namespace MNN { +namespace Math { + +template +struct VecHalf { + using VecType = VecHalf; + std::array value; + VecType operator+(const VecType& lr) const { + VecType dst; + for (int i = 0; i < N; ++i) { + dst.value[i] = value[i] + lr.value[i]; + } + return dst; + } + VecType operator-(const VecType& lr) const { + VecType dst; + for (int i = 0; i < N; ++i) { + dst.value[i] = value[i] - lr.value[i]; + } + return dst; + } + VecType operator*(const VecType& lr) const { + VecType dst; + for (int i = 0; i < N; ++i) { + dst.value[i] = value[i] * lr.value[i]; + } + return dst; + } + VecType operator+=(const VecType& lr) { + for (int i = 0; i < N; ++i) { + value[i] = value[i] + lr.value[i]; + } + return *this; + } + VecType operator-=(const VecType& lr) { + for (int i = 0; i < N; ++i) { + value[i] = value[i] - lr.value[i]; + } + return *this; + } + VecType operator*(float lr) const { + VecType dst; + for (int i = 0; i < N; ++i) { + dst.value[i] = value[i] * lr; + } + return dst; + } + + VecType& operator=(const VecType& lr) { + for (int i = 0; i < N; ++i) { + value[i] = lr.value[i]; + } + return *this; + } + VecType operator-() { + VecType dst; + for (int i = 0; i < N; ++i) { + dst.value[i] = -value[i]; + } + return dst; + } + VecHalf() { + } + VecHalf(const float v) { + for (int i = 0; i < N; ++i) { + value[i] = v; + } + } + + VecHalf(float v0, float v1, float v2, float v3) { + value[0] = v0; + value[1] = v1; + value[2] = v2; + value[3] = v3; + } + VecHalf(std::array&& v) { + value = std::move(v); + } + VecHalf(const VecType& lr) { + for (int i = 0; i < N; ++i) { + value[i] = lr.value[i]; + } + } + float operator[](size_t i) { + return value[i]; + } + static VecType broadcast(int16_t val) { + VecType v; + auto tempV = (int32_t*)v.value.data(); + for (int i = 0; i < N; ++i) { + tempV[i] = val << 16; + } + return v; + } + static VecType broadcast(int16_t* val) { + VecType v; + auto tempV = (int32_t*)v.value.data(); + tempV[0] = (*val) << 16; + for (int i = 1; i < N; ++i) { + tempV[i] = tempV[0]; + } + return v; + } + static VecType load(const int16_t* addr) { + VecType v; + auto tempV = 
(int32_t*)v.value.data(); + for (int i = 0; i < N; ++i) { + tempV[i] = addr[i] << 16; + } + return v; + } + static void save(int16_t* addr, const VecType& v) { + auto tempV = (int32_t*)v.value.data(); + for (int i = 0; i < N; ++i) { + addr[i] = tempV[i] >> 16; + } + } + static VecType max(const VecType& v1, const VecType& v2) { + VecType dst; + for (int i = 0; i < N; ++i) { + dst.value[i] = std::max(v1.value[i], v2.value[i]); + } + return dst; + } + static VecType min(const VecType& v1, const VecType& v2) { + VecType dst; + for (int i = 0; i < N; ++i) { + dst.value[i] = std::min(v1.value[i], v2.value[i]); + } + return dst; + } + static VecType fma(const VecType& v1, const VecType& v2, const VecType& v3) { + return v1 + v2 * v3; + } + static VecType fms(const VecType& v1, const VecType& v2, const VecType& v3) { + return v1 - v2 * v3; + } + static inline void transpose4(VecType& vec0, VecType& vec1, VecType& vec2, VecType& vec3) { + VecType source[4] = {vec0, vec1, vec2, vec3}; + for (int i = 0; i < N; ++i) { + vec0.value[i] = source[i % 4].value[i >> 2]; + vec1.value[i] = source[i % 4].value[(i + N)>> 2]; + vec2.value[i] = source[i % 4].value[(i + 2 * N)>> 2]; + vec3.value[i] = source[i % 4].value[(i + 3 * N)>> 2]; + } + } + + static inline void transpose12(int16_t* srcPtr, const size_t packCUnit) { + + MNN_ASSERT(false); + } +}; + +#if defined(MNN_USE_SSE) + +template<> +struct VecHalf<4> { + using VecType = VecHalf<4>; + __m128 value; + VecType operator+(const VecType& lr) const { + VecType dst = { _mm_add_ps(value, lr.value) }; + return dst; + } + VecType operator-(const VecType& lr) const { + VecType dst = { _mm_sub_ps(value, lr.value) }; + return dst; + } + VecType operator*(const VecType& lr) const { + VecType dst = { _mm_mul_ps(value, lr.value) }; + return dst; + } + VecType operator+=(const VecType& lr) { + value = _mm_add_ps(value, lr.value); + return *this; + } + VecType operator-=(const VecType& lr) { + value = _mm_sub_ps(value, lr.value); + return *this; + } + VecType operator*(float lr) const { + VecType dst = { _mm_mul_ps(value, _mm_set1_ps(lr)) }; + return dst; + } + + VecType& operator=(const VecType& lr) { + value = lr.value; + return *this; + } + VecType operator-() { + VecType dst; +#if defined(_MSC_VER) + dst.value = _mm_xor_ps(value, _mm_set1_ps(-0.f)); // Using unary operation to SSE vec is GCC extension. We can not do this directly in MSVC. +#else + dst.value = -value; +#endif + return dst; + } + VecHalf() { + } + VecHalf(const float v) { + value = _mm_set1_ps(v); + } + VecHalf(const float f0, const float f1, const float f2, const float f3) { + value = _mm_set_ps(f0, f1, f2, f3); + } + VecHalf(__m128& v) { + value = v; + } + VecHalf(__m128&& v) { + value = std::move(v); + } + VecHalf(const VecType& lr) { + value = lr.value; + } + VecHalf(VecType&& lr) { + value = std::move(lr.value); + } + float operator[](size_t i) { +#if defined(_MSC_VER) // X64 native only mandatory support SSE and SSE2 extension, and we can not find intrinsic function to extract element directly by index in SSE and SSE2 extension. 
+ float temp[4]; + _mm_storeu_ps(temp, value); + return temp[i]; +#else + return value[i]; +#endif + } + static VecType broadcast(int16_t val) { + auto temp = _mm_set1_epi16(val); +#ifndef MNN_SSE_USE_FP16_INSTEAD + auto zero = _mm_xor_si128(temp, temp); + auto res = _mm_castsi128_ps(_mm_unpacklo_epi16(zero, temp)); +#else + auto res = _mm_cvtph_ps(temp); +#endif + VecType v = { std::move(res) }; + return v; + } + static VecType broadcast(int16_t* val) { + return broadcast(*val); + } + static VecType load(const int16_t* addr) { + auto temp = _mm_loadl_epi64((__m128i*)addr); +#ifndef MNN_SSE_USE_FP16_INSTEAD + auto zero = _mm_xor_si128(temp, temp); + auto res = _mm_castsi128_ps(_mm_unpacklo_epi16(zero, temp)); +#else + auto res = _mm_cvtph_ps(temp); +#endif + VecType v = { std::move(res) }; + return v; + } + static void save(int16_t* addr, const VecType& v) { +#ifndef MNN_SSE_USE_FP16_INSTEAD + auto temp = _mm_castps_si128(v.value); + temp = _mm_srai_epi32(temp, 16); + temp = _mm_packs_epi32(temp, temp); +#else + static __m128 gMinValue = _mm_set1_ps(-32768); + static __m128 gMaxValue = _mm_set1_ps(32767); + auto t = _mm_max_ps(v.value, gMinValue); + t = _mm_min_ps(t, gMaxValue); + auto temp = _mm_cvtps_ph(t, 0x8); +#endif + _mm_storel_epi64((__m128i*)addr, temp); + } + static VecType max(const VecType& v1, const VecType& v2) { + VecType dst = { _mm_max_ps(v1.value, v2.value) }; + return dst; + } + static VecType min(const VecType& v1, const VecType& v2) { + VecType dst = { _mm_min_ps(v1.value, v2.value) }; + return dst; + } + static VecType fma(const VecType& v1, const VecType& v2, const VecType& v3) { + return v1 + v2 * v3; + } + static VecType fms(const VecType& v1, const VecType& v2, const VecType& v3) { + return v1 - v2 * v3; + } + static inline void transpose4(VecType& vec0, VecType& vec1, VecType& vec2, VecType& vec3) { + __m128 tmp3, tmp2, tmp1, tmp0; + tmp0 = _mm_unpacklo_ps((vec0.value), (vec1.value)); + tmp2 = _mm_unpacklo_ps((vec2.value), (vec3.value)); + tmp1 = _mm_unpackhi_ps((vec0.value), (vec1.value)); + tmp3 = _mm_unpackhi_ps((vec2.value), (vec3.value)); + vec0.value = _mm_movelh_ps(tmp0, tmp2); + vec1.value = _mm_movehl_ps(tmp2, tmp0); + vec2.value = _mm_movelh_ps(tmp1, tmp3); + vec3.value = _mm_movehl_ps(tmp3, tmp1); + } + + // x86 VecHalf transpose12 unused in any case + static inline void transpose12(int16_t* srcPtr, const size_t packCUnit) { + MNN_ASSERT(false); + } +}; +#endif + +#if defined(MNN_USE_NEON) + +template<> +struct VecHalf<4> { + using VecType = VecHalf<4>; + float32x4_t value; + VecType operator+(const VecType& lr) const { + VecType dst = { vaddq_f32(value, lr.value) }; + return dst; + } + VecType operator-(const VecType& lr) const { + VecType dst = { vsubq_f32(value, lr.value) }; + return dst; + } + VecType operator*(const VecType& lr) const { + VecType dst = { vmulq_f32(value, lr.value) }; + return dst; + } + VecType operator*(const float lr) const { + VecType dst = { vmulq_f32(value, vdupq_n_f32(lr)) }; + return dst; + } + VecType operator+=(const VecType& lr) { + value = vaddq_f32(value, lr.value); + return *this; + } + VecType operator-=(const VecType& lr) { + value = vsubq_f32(value, lr.value); + return *this; + } + + VecType& operator=(const VecType& lr) { + value = lr.value; + return *this; + } + VecType operator-() { + VecType dst = { vnegq_f32(value) }; + return dst; + } + VecHalf() { + } + VecHalf(const float v) { + value = vdupq_n_f32(v); + } + VecHalf(const float f0, const float f1, const float f2, const float f3) { + vsetq_lane_f32(f0, 
value, 0); + vsetq_lane_f32(f1, value, 1); + vsetq_lane_f32(f2, value, 2); + vsetq_lane_f32(f3, value, 3); + } + VecHalf(float32x4_t& v) { + value = v; + } + VecHalf(float32x4_t&& v) { + value = std::move(v); + } + VecHalf(const VecType& lr) { + value = lr.value; + } + VecHalf(VecType&& lr) { + value = std::move(lr.value); + } + float operator[](const int i) { + // vgetq_lane_f32(value, i) does NOT work, i must be const number such as 0, 2, + return value[i]; + } + static VecType broadcast(int16_t* valPtr) { + VecType dst = { vreinterpretq_f32_s32(vshll_n_s16(vld1_dup_s16(valPtr), 16)) }; + return dst; + } + static VecType broadcast(int16_t val) { + VecType dst = { vreinterpretq_f32_s32(vshll_n_s16(vdup_n_s16(val), 16)) }; + return dst; + } + static VecType load(const int16_t* addr) { + + // equivalent to this: + // int16x4_t vec4s16 = vld1_s16(addr); // load bf16 data as fixed point data of 16-bit. + // int32x4_t vec4s32 =vshll_n_s16(vec4s16, 16); // shift left 16bit as 32-bit data. + // float32x4_t vec4f32 = vreinterpretq_f32_s32(vec4s32);// treat 32-bit fix point result as float32 data + // VecType dest = { vec4f32 }; // construct a struct of VecType + + VecType dst = { vreinterpretq_f32_s32(vshll_n_s16(vld1_s16(addr), 16)) }; + return dst; + } + static void save(int16_t* addr, const VecType& v) { + vst1_s16(addr, vshrn_n_s32(vreinterpretq_s32_f32(v.value), 16)); + return; + } + static VecType max(const VecType& v1, const VecType& v2) { + VecType dst = { vmaxq_f32(v1.value, v2.value) }; + return dst; + } + static VecType min(const VecType& v1, const VecType& v2) { + VecType dst = { vminq_f32(v1.value, v2.value) }; + return dst; + } + static VecType fma(const VecType& v1, const VecType& v2, const VecType& v3) { + VecType dst = {vmlaq_f32(v1.value, v2.value, v3.value)}; + return dst; + } + static VecType fms(const VecType& v1, const VecType& v2, const VecType& v3) { + VecType dst = {vmlsq_f32(v1.value, v2.value, v3.value)}; + return dst; + } + static inline void transpose4(VecType& vec0, VecType& vec1, VecType& vec2, VecType& vec3) { +#ifdef __aarch64__ + auto m0 = vtrn1q_s32(reinterpret_cast(vec0.value), reinterpret_cast(vec1.value)); + auto m1 = vtrn2q_s32(reinterpret_cast(vec0.value), reinterpret_cast(vec1.value)); + auto m2 = vtrn1q_s32(reinterpret_cast(vec2.value), reinterpret_cast(vec3.value)); + auto m3 = vtrn2q_s32(reinterpret_cast(vec2.value), reinterpret_cast(vec3.value)); + vec0.value = reinterpret_cast(vtrn1q_s64(reinterpret_cast(m0), reinterpret_cast(m2))); + vec1.value = reinterpret_cast(vtrn1q_s64(reinterpret_cast(m1), reinterpret_cast(m3))); + vec2.value = reinterpret_cast(vtrn2q_s64(reinterpret_cast(m0), reinterpret_cast(m2))); + vec3.value = reinterpret_cast(vtrn2q_s64(reinterpret_cast(m1), reinterpret_cast(m3))); +#else + + auto m0m1 = vtrnq_s32(reinterpret_cast(vec0.value), reinterpret_cast(vec1.value)); + auto m2m3 = vtrnq_s32(reinterpret_cast(vec2.value), reinterpret_cast(vec3.value)); + vec0.value = reinterpret_cast(m0m1.val[0]); + vec1.value = reinterpret_cast(m0m1.val[1]); + vec2.value = reinterpret_cast(m2m3.val[0]); + vec3.value = reinterpret_cast(m2m3.val[1]); + vec0.value = reinterpret_cast(vsetq_lane_s64(vgetq_lane_s64(reinterpret_cast(m2m3.val[0]), 0), reinterpret_cast(vec0.value), 1)); + vec1.value = reinterpret_cast(vsetq_lane_s64(vgetq_lane_s64(reinterpret_cast(m2m3.val[1]), 0), reinterpret_cast(vec1.value), 1)); + vec2.value = reinterpret_cast(vsetq_lane_s64(vgetq_lane_s64(reinterpret_cast(m0m1.val[0]), 1), reinterpret_cast(vec2.value), 0)); + vec3.value 
= reinterpret_cast(vsetq_lane_s64(vgetq_lane_s64(reinterpret_cast(m0m1.val[1]), 1), reinterpret_cast(vec3.value), 0)); + /* + generated arm32 assembly code is almost the same as: + vtrn.32 d0, d2 + vtrn.32 d1, d3 + vtrn.32 d4, d6 + vtrn.32 d5, d7 + vswp d1, d4 + vswp d3, d6 + */ + +#endif + } + static inline void transpose4(int16x4_t& vec0, int16x4_t& vec1, int16x4_t& vec2, int16x4_t& vec3) { + auto trans0 = vtrn_s16(vec0, vec1); + auto m0 = trans0.val[0]; + auto m1 = trans0.val[1]; + auto trans1 = vtrn_s16(vec2, vec3); + auto m2 = trans1.val[0]; + auto m3 = trans1.val[1]; + auto trans2 = vtrn_s32(reinterpret_cast(m0), reinterpret_cast(m2)); + vec0 = reinterpret_cast(trans2.val[0]); + vec2 = reinterpret_cast(trans2.val[1]); + auto trans3 = vtrn_s32(reinterpret_cast(m1), reinterpret_cast(m3)); + vec1 = reinterpret_cast(trans3.val[0]); + vec3 = reinterpret_cast(trans3.val[1]); + + } + static inline void transpose12(int16_t* srcPtr, const size_t packCUnit) { + auto s0 = vld1_s16(srcPtr + 0 * packCUnit); + auto s3 = vld1_s16(srcPtr + 1 * packCUnit); + auto s6 = vld1_s16(srcPtr + 2 * packCUnit); + auto s9 = vld1_s16(srcPtr + 3 * packCUnit); + auto s1 = vld1_s16(srcPtr + 4 * packCUnit); + auto s4 = vld1_s16(srcPtr + 5 * packCUnit); + auto s7 = vld1_s16(srcPtr + 6 * packCUnit); + auto s10 = vld1_s16(srcPtr + 7 * packCUnit); + auto s2 = vld1_s16(srcPtr + 8 * packCUnit); + auto s5 = vld1_s16(srcPtr + 9 * packCUnit); + auto s8 = vld1_s16(srcPtr + 10 * packCUnit); + auto s11 = vld1_s16(srcPtr + 11 * packCUnit); + + transpose4(s0, s3, s6, s9); + transpose4(s1, s4, s7, s10); + transpose4(s2, s5, s8, s11); + + vst1_s16(srcPtr + 0 * packCUnit, s0); + vst1_s16(srcPtr + 1 * packCUnit, s1); + vst1_s16(srcPtr + 2 * packCUnit, s2); + vst1_s16(srcPtr + 3 * packCUnit, s3); + vst1_s16(srcPtr + 4 * packCUnit, s4); + vst1_s16(srcPtr + 5 * packCUnit, s5); + vst1_s16(srcPtr + 6 * packCUnit, s6); + vst1_s16(srcPtr + 7 * packCUnit, s7); + vst1_s16(srcPtr + 8 * packCUnit, s8); + vst1_s16(srcPtr + 9 * packCUnit, s9); + vst1_s16(srcPtr + 10 * packCUnit, s10); + vst1_s16(srcPtr + 11 * packCUnit, s11); + + } +}; +#endif + +} + +} +#endif diff --git a/source/backend/cpu/bf16/WinogradOptFunctionHalf.cpp b/backupcode/cpubackend/bf16/WinogradOptFunctionHalf.cpp similarity index 100% rename from source/backend/cpu/bf16/WinogradOptFunctionHalf.cpp rename to backupcode/cpubackend/bf16/WinogradOptFunctionHalf.cpp diff --git a/source/backend/cpu/bf16/WinogradOptFunctionHalf.hpp b/backupcode/cpubackend/bf16/WinogradOptFunctionHalf.hpp similarity index 100% rename from source/backend/cpu/bf16/WinogradOptFunctionHalf.hpp rename to backupcode/cpubackend/bf16/WinogradOptFunctionHalf.hpp diff --git a/source/backend/cpu/bf16/register.py b/backupcode/cpubackend/bf16/register.py similarity index 100% rename from source/backend/cpu/bf16/register.py rename to backupcode/cpubackend/bf16/register.py diff --git a/docs/contribute/op.md b/docs/contribute/op.md index 8591c6338..059a84d25 100644 --- a/docs/contribute/op.md +++ b/docs/contribute/op.md @@ -12,9 +12,10 @@ MNN 的算子转换与实现如下图, 3. 添加几何计算实现(可选,如果实现几何计算,无须后续在各后端添加算子实现) 4. 
添加各后端算子实现(可选,选择需要部分进行实现) -![image.png](https://cdn.nlark.com/yuque/0/2021/png/405896/1618994794052-575a79b9-d291-4d1b-a630-79dd705bc977.png#clientId=u1c902b2d-d8e6-4&from=paste&height=701&id=ue223d8c2&margin=%5Bobject%20Object%5D&name=image.png&originHeight=1402&originWidth=3394&originalType=binary&ratio=1&size=256977&status=done&style=none&taskId=u4663d0eb-adcf-435b-b540-f61d2617cd4&width=1697) +![image.png](pic1.png) ### 添加算子的流程 -![image.png](https://cdn.nlark.com/yuque/0/2021/png/405896/1618995111237-321c5ca8-ed99-4cfc-9d91-04deaa2e29eb.png#clientId=u1c902b2d-d8e6-4&from=paste&height=597&id=u518a1fda&margin=%5Bobject%20Object%5D&name=image.png&originHeight=1194&originWidth=2714&originalType=binary&ratio=1&size=222438&status=done&style=none&taskId=u9c8f2ef4-7bf3-4b18-9560-794c3344f01&width=1357) + +![image.png](pic2.png) 简单来说,优先转换,然后组合,然后几何计算,最后各后端实现。 ## 添加Schema描述 @@ -254,25 +255,31 @@ REGISTER_CPU_OP_CREATOR(CPUMyCustomOpCreator, OpType_MyCustomOp); ``` ### 添加Metal实现 -1. 添加Shader -在`source/backend/Metal`目录下添加`MetalMyCustomOp.metal`,并添加进Xcode工程。metal可以参考目录下已有实现。 -2. 实现类声明 -在`source/backend/Metal`目录下添加`MetalMyCustomOp.hpp`和`MetalMyCustomOp.cpp`,并添加进Xcode工程: +- 实现类声明 + +在`source/backend/metal`目录下添加`MetalMyCustomOp.hpp`和`MetalMyCustomOp.cpp` ```cpp class MetalMyCustomOp : public Execution { public: virtual ErrorCode onResize(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs) override; - virtual ErrorCode onExecute(const std::vector<Tensor *> &inputs, - const std::vector<Tensor *> &outputs) override; + virtual void onEncode(const std::vector<Tensor *> &inputs, const std::vector<Tensor *> &outputs, id<MTLComputeCommandEncoder> encoder) override; }; ``` -3. 实现`onResize`和`onExecute` -不同于CPU Tensor将数据存储在host指针中,Metal数据指针存放在`deviceId`中,deviceId上存储的是`id<MTLBuffer>`: +- 实现`onResize`和`onEncode` + +尽量将申请内存和计算group size 的操作放在 onResize 函数中。 + +onEncode 时,使用传入的 encoder 编排计算任务,不要自行创建 command buffer 或 encoder + +- 内存使用 + +不同于CPU Tensor将数据存储在host指针中,Metal数据指针存放在`deviceId`中,deviceId上存储的是`id<MTLBuffer>`,由于内存复用机制,各Tensor有可能共用同一块内存,以offset进行偏移: ```objectivec auto buffer = (__bridge id<MTLBuffer>)(void *)tensor->deviceId(); +auto offset = TensorUtils::getDescribe(tensor)->extra.offset; ``` Metal Op的特定参数等可以通过`id<MTLBuffer>`存储。buffer数据类型可以与tensor不同,buffer甚至可以混合多种数据类型,只需保证创建时指定了正确的长度即可。例如: @@ -297,20 +304,11 @@ auto buffer = [context newDeviceBuffer:2 * sizeof(int) + 2 * sizeof(__fp16) acce 一般而言,heap只会与**CPUTransparent**一起使用。_heap实际只在iOS 10+上有效,iOS 9-上会回退到device上。_ -使用Metal时,**如非特殊情况,禁止自行创建device和library**。加载library、编译function都是耗时行为,**MNNMetalContext**上做了必要的缓存优化。通过context执行Metal的示例如下: -```cpp -auto context = (__bridge MNNMetalContext *)backend->context(); -auto kernel = /* metal kernel name NSString */; -auto encoder = [context encoder]; -auto bandwidth = [context load:kernel encoder:encoder]; -/* encoder set buffer(s)/sampler(s) */ -[context dispatchEncoder:encoder - threads:{x, y, z} - maxThreadsPerGroup:maxThreadsPerThreadgroup]; // recommended way to dispatch -[encoder endEncoding]; -``` +Metal 内存布局与CPU-FP32-Neon一致,在 Tensor 的 dimensionFormat 为 NC4HW4 时,使用 C4NHW4的排布。否则按默认线性布局。 + + +- 注册实现类 -4. 注册实现类 ```cpp class MetalMyCustomOpCreator : public MetalBackend::Creator { public: @@ -322,7 +320,11 @@ public: REGISTER_METAL_OP_CREATOR(MetalMyCustomOpCreator, OpType_MyCustomOp); ``` -添加注册代码后,重新运行一下 CMake ,自动变更注册文件 +- 工程更新 + +进入 source/backend/metal 目录,执行 [ python3 MetalCodeGen.py . ] ,更新自注册文件 + +重新运行一下 CMake ,或者手动在Xcode工程中新加文件 ### 添加Vulkan实现 1.
添加Shader diff --git a/docs/contribute/pic1.png b/docs/contribute/pic1.png new file mode 100644 index 000000000..5cfd6fbc6 Binary files /dev/null and b/docs/contribute/pic1.png differ diff --git a/docs/contribute/pic2.png b/docs/contribute/pic2.png new file mode 100644 index 000000000..12363e5a0 Binary files /dev/null and b/docs/contribute/pic2.png differ diff --git a/docs/faq.md b/docs/faq.md index c25515abc..d21236df3 100644 --- a/docs/faq.md +++ b/docs/faq.md @@ -208,7 +208,7 @@ OpenCL / Vulkan 采用静态变量自注册的方式往 MNN 主库注册后端. ``` -### 部分模型用 MNNV2Basic 运行出现段错误 +### 部分模型用 MNNV2Basic 运行出现段错误,或报 Interpreter don't support case for shape compute need input content, please use module api instead - 模型不满足运行条件 - MNNV2Basic 使用 Interpreter + Session 方式运行,此类运行方式要求模型满足一定条件,否则无法运行模型或产生特别的 crash ,条件如下: diff --git a/docs/start/demo.md b/docs/start/demo.md index d76166811..376054f0b 100644 --- a/docs/start/demo.md +++ b/docs/start/demo.md @@ -6,7 +6,7 @@ 代码位置:`demo/exec/multiPose.cpp` 1. 下载原始的Tensorflow模型 [pose model](https://github.com/czy2014hust/posenet-python/raw/master/models/model-mobilenet_v1_075.pb) -2. 使用 [模型转换工具](../tools/convert.md) 转换为 MNN 模型 +2. 使用 [模型转换工具](../tools/convert.md) 转换为 MNN 模型,转换时加上参数 --keepInputFormat=0 【把输入由NHWC转换为NC4HW4布局】 3. 执行姿态检测 ```bash ./multiPose.out model.mnn input.png pose.png diff --git a/docs/tools/test.md b/docs/tools/test.md index c4981e2a3..532877f9e 100644 --- a/docs/tools/test.md +++ b/docs/tools/test.md @@ -64,7 +64,7 @@ Avg= 5.570600 ms, OpSum = 7.059200 ms min= 3.863000 ms, max= 11.596001 ms ## ModuleBasic.out ### 功能 -类似`MNNV2Basic.out`,对于带控制流模型,或者多输入多输出的模型,建议采用这个工具 +类似`MNNV2Basic.out`,对于带控制流模型,或者多输入多输出的模型,必须采用这个工具 ### 参数 `./ModuleBasic.out model dir [runMask forwardType runLoops numberThread precision_memory cacheFile]` - `model:str` 模型文件路径 @@ -73,7 +73,7 @@ Avg= 5.570600 ms, OpSum = 7.059200 ms min= 3.863000 ms, max= 11.596001 ms - `forwardType:int` 执行推理的计算设备,有效值为:0(CPU)、1(Metal)、2(CUDA)、3(OpenCL)、6(OpenGL),7(Vulkan) ,9 (TensorRT),可选,默认为`0` - `runLoops:int` 性能测试的循环次数,可选,默认为`0`即不做性能测试 - `numberThread:int` GPU的线程数,可选,默认为`1` -- `precision_memory:int` 测试精度与内存模式,precision_memory % 16 为精度,有效输入为:0(Normal), 1(High), 2(Low), 3(Low_BF16),可选,默认为`2` ; precision_memory / 16 为内存设置,默认为 0 (memory_normal) 。例如测试 memory 为 2(low) ,precision 为 1 (high) 时,设置 precision_memory = 9 (2 * 4 + 1) +- `precision_memory_power:int` 测试精度与内存模式,precision_memory_power % 4 为精度,有效输入为:0(Normal), 1(High), 2(Low), 3(Low_BF16),可选,默认为`0` ; (precision_memory_power / 4 % 4) 为内存设置,默认为 0 (memory_normal) ; (precision_memory_power / 16 % 4) 为功耗设置,默认为 0 (power_normal)。例如测试 memory 为 2(low) ,precision 为 1 (high) ,power 为 0(normal) 时,设置 precision_memory = 9 (2 * 4 + 1 + 0 * 16) ### 默认输出 @@ -82,6 +82,7 @@ Avg= 5.570600 ms, OpSum = 7.059200 ms min= 3.863000 ms, max= 11.596001 ms ### 测试文件夹生成 - 若有原始的tf模型/Onnx模型,可以使用testMNNFromTf.py / testMNNFromOnnx.py / testMNNFromTflite.py 等脚本生成 - 若只有mnn模型,可以用 tools/script/make_test_for_mnn.py 脚本生成测试文件夹,使用方式:mkdir testdir && pythhon3 make_test_for_mnn.py XXX.mnn testdir +- 为了方便模拟应用中的运行性能,可以通过修改测试文件夹下的 input.json ,增加 freq 项,以指定该模型运行的频率(每秒多少次) ### runMask 参数说明 - 1 : 输出推理中间结果,每个算子的输入存到(Input_{op_name}.txt),输出存为({op_name}.txt), 默认输出当前目录的output目录下(使用工具之前要自己建好output目录),不支持与 2 / 4 叠加 @@ -93,6 +94,7 @@ Avg= 5.570600 ms, OpSum = 7.059200 ms min= 3.863000 ms, max= 11.596001 ms - 64 : 创建模型后,clone 出一个新的模型运行,用于测试 clone 功能(主要用于多并发推理)的正确性 - 128 : 使用文件夹下面的 input.mnn 和 output.mnn 做为输入和对比输出,对于数据量较大的情况宜用此方案 - 512 : 开启使用Winograd算法计算卷积时的内存优化,开启后模型的运行时内存会降低,但可能导致性能损失。 +- 1024: 
使用动态量化推理时,对输入数据分batch量化以提高模型的推理精度 ### 示例 diff --git a/docs/transformers/diffusion.md b/docs/transformers/diffusion.md index ffe6eb9d0..32d790a26 100644 --- a/docs/transformers/diffusion.md +++ b/docs/transformers/diffusion.md @@ -2,44 +2,77 @@ ## 模型支持与下载 -[Download-runwayml/stable-diffusion-v1-5]: +1. runwayml/stable-diffusion-v1-5 +``` https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main -[Download-IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1]: +``` +2. chilloutmix +``` +https://modelscope.cn/models/wyj123456/chilloutmix +``` +3. IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1 +``` https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1/tree/main - +``` ## 模型转换 ### 将Huggingface的Stable Diffusion模型 转为onnx模型 +```sh +cd mnn_path/transformers/diffusion/ python export/onnx_export.py \ --model_path hf_sd_load_path \ --output_path onnx_save_path +``` +注意,上述脚本需要依赖torch/onnx/diffusers等库,可以安装conda环境: +``` +conda env create -f env.yaml +conda activate ldm +``` +在conda环境中执行模型转换脚本 ### 将onnx模型转为mnn模型 新建diffusion mnn模型文件夹,将转好的mnn文件放在该文件夹下。 +1. 实现encoder从onnx模型 -> mnn模型 +``` ./MNNConvert -f ONNX --modelFile onnx_save_path/text_encoder/model.onnx --MNNModel mnn_save_path/text_encoder.mnn --weightQuantBits 8 --bizCode biz +``` +2. 实现denoiser从onnx模型 -> mnn模型 +``` ./MNNConvert -f ONNX --modelFile onnx_save_path/unet/model.onnx --MNNModel mnn_save_path/unet.mnn --transformerFuse --weightQuantBits 8 --bizCode biz +``` +3. 实现decoder从onnx模型 -> mnn模型 +``` ./MNNConvert -f ONNX --modelFile onnx_save_path/vae_decoder/model.onnx --keepInputFormat --MNNModel mnn_save_path/vae_decoder.mnn --weightQuantBits 8 --bizCode biz - +``` ## 编译Diffusion Demo ### Linux/MAC/Windows上 +``` +cd mnn_path +mkdir build +cd build cmake .. -DMNN_BUILD_DIFFUSION=ON -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_OPENCL=ON -DMNN_SEP_BUILD=OFF -DMNN_SUPPORT_TRANSFORMER_FUSE=ON - +make -j32 +``` ### Android上 -cd project/android/build +``` +cd mnn_path/project/android/build ../build_64.sh -DMNN_BUILD_DIFFUSION=ON -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_OPENCL=ON -DMNN_SEP_BUILD=OFF -DMNN_SUPPORT_TRANSFORMER_FUSE=ON - +../updateTest.sh +``` ## 运行Diffusion Demo +``` ./diffusion_demo -其中,resource_path 就是mnn模型文件的路径,除了mnn文件,还需要 -(1)将MNN目录transformers/diffusion/scheduler/alphas.txt文件拷贝到该文件夹下。 -(2)针对stable-diffusion-v1-5模型需要将huggingfacetokenizer目录下merges.txt和vocab.json拷贝到该文件夹中。针对Taiyi-Stable-Diffusion模型需要将huggingfacetokenizer目录下vocab.txt拷贝到该文件夹中。 - -model_type是目前支持的两种diffusion模型的类别。如果是stable-diffusion-v1-5模型设为0,如果是Taiyi-Stable-Diffusion模型设为1。 - -output_image_name是生成图片的名字,默认图片位置在当前运行目录下。 - -input_text是文生图的prompt,如果是stable-diffusion-v1-5模型建议英文prompt,如果是Taiyi-Stable-Diffusion建议中文prompt。 +``` +其中,resource_path 就是mnn模型文件的路径,除了mnn文件,还需要: +1. 将MNN目录transformers/diffusion/scheduler/alphas.txt文件拷贝到该文件夹下。 +2. 针对stable-diffusion-v1-5模型需要将huggingfacetokenizer目录下merges.txt和vocab.json拷贝到该文件夹中。 +3. 针对Taiyi-Stable-Diffusion模型需要将huggingfacetokenizer目录下vocab.txt拷贝到该文件夹中。 +4. model_type是目前支持的两种diffusion模型的类别。如果是stable-diffusion-v1-5模型设为0,如果是Taiyi-Stable-Diffusion模型设为1。 +5. output_image_name是生成图片的名字,默认图片位置在当前运行目录下。 +6. 
input_text是文生图的prompt,如果是stable-diffusion-v1-5模型建议英文prompt,如果是Taiyi-Stable-Diffusion建议中文prompt。 运行指令例如: -./diffusion_demo mnn_save_path 0 demo.jpg "a cute cat" -./diffusion_demo mnn_save_path 1 demo.jpg "一只可爱的猫" - +``` +./diffusion_demo mnn_sd1.5_path 0 demo.jpg "a cute cat" +./diffusion_demo mnn_chilloutmix_path 0 demo.jpg "a pure girl" +./diffusion_demo mnn_taiyi_path 1 demo.jpg "一只可爱的猫" +``` diff --git a/docs/transformers/llm.md b/docs/transformers/llm.md index ea671993b..2358548c6 100644 --- a/docs/transformers/llm.md +++ b/docs/transformers/llm.md @@ -143,6 +143,12 @@ options: - visual_model: 当使用VL模型时,visual_model的实际路径为`base_dir + visual_model`,默认为`base_dir + 'visual.mnn'` - 推理配置 - max_new_tokens: 生成时最大token数,默认为`512` + - reuse_kv: 多轮对话时是否复用之前对话的`kv cache`,默认为`false` + - quant_kv: 存储`kv cache`时是否量化,可选为:`0, 1, 2, 3`,默认为`0`,含义如下: + - 0: key和value都不量化 + - 1: 使用非对称8bit量化存储key + - 2: 使用fp8格式寸处value + - 3: 使用非对称8bit量化存储key,使用fp8格式寸处value - 硬件配置 - backend_type: 推理使用硬件后端类型,默认为:`"cpu"` - thread_num: 推理使用硬件线程数,默认为:`4` diff --git a/express/Executor.cpp b/express/Executor.cpp index 0edb9d6ad..f6b85765c 100644 --- a/express/Executor.cpp +++ b/express/Executor.cpp @@ -17,7 +17,7 @@ #include "core/Backend.hpp" #include "RuntimeAttr.hpp" #include -#define DEFAULT_BACKUP_RUNTIME_KEY (std::make_pair(MNN_FORWARD_CPU, 1)) +#define DEFAULT_BACKUP_RUNTIME_KEY MNN_FORWARD_CPU #ifdef MNN_EXPR_ENABLE_PROFILER #define MNN_EXPRESS_ERROR_REPORT #endif @@ -41,12 +41,14 @@ void Executor::setGlobalExecutorConfig(MNNForwardType type, const BackendConfig& if(type == MNN_FORWARD_OPENCL || type == MNN_FORWARD_METAL) { info.numThread = 4; } - mAttr->firstType = std::make_pair(type, info.numThread); + mAttr->firstType = type; auto firstIter = mRuntimes.find(mAttr->firstType); if (firstIter == mRuntimes.end()) { info.user = (BackendConfig*)&config; std::shared_ptr bn(creator->onCreate(info)); mRuntimes[mAttr->firstType] = bn; + } else { + firstIter->second->onReset(numberThread, &config); } } else { auto creator = MNNGetExtraRuntimeCreator(type); @@ -58,7 +60,7 @@ void Executor::setGlobalExecutorConfig(MNNForwardType type, const BackendConfig& MNN_ASSERT(nullptr != creator); Backend::Info info; info.type = type; - mAttr->firstType = std::make_pair(type, numberThread); + mAttr->firstType = type; auto firstIter = mRuntimes.find(mAttr->firstType); if (firstIter == mRuntimes.end()) { info.mode = Backend::Info::DIRECT; @@ -66,6 +68,8 @@ void Executor::setGlobalExecutorConfig(MNNForwardType type, const BackendConfig& info.user = (BackendConfig*)&config; std::shared_ptr bn(creator->onCreate(info)); mRuntimes[mAttr->firstType] = bn; + } else { + firstIter->second->onReset(numberThread, &config); } } _refreshRuntime(); @@ -83,10 +87,10 @@ void Executor::gc(GCFlag flag) { } Executor::Executor(std::shared_ptr backend, MNNForwardType type, int numberThread) { - mRuntimes.insert(std::make_pair(std::make_pair(type, numberThread), backend)); + mRuntimes.insert(std::make_pair(type, backend)); mAttr.reset(new ExecutorAttr); - mAttr->firstType = std::make_pair(type, numberThread); - if (1 != numberThread || MNN_FORWARD_CPU != type) { + mAttr->firstType = type; + if (MNN_FORWARD_CPU != type) { // Create Backup Backend Backend::Info info; info.type = MNN_FORWARD_CPU; @@ -151,7 +155,9 @@ std::shared_ptr Executor::getGlobalExecutor() { info.type = MNN_FORWARD_CPU; info.numThread = 1; std::shared_ptr bn(creator->onCreate(info)); - bn->setAllocatorType(info.allocator); + RuntimeHint hint; + hint.memoryAllocatorType = 0;// Defer + 
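Together with the `PipelineModule.cpp` hunk later in this patch, this change replaces the per-knob `setAllocatorType` / `setWinogradMemoryLevel` calls with a single `setRuntimeHint(RuntimeHint)` call, and the executor now caches one runtime per forward type, re-tuning it through `onReset(numberThread, &config)` on reuse instead of keying the cache by `(type, numberThread)` pairs. The sketch below illustrates that pattern only; the struct and class are simplified stand-ins, and apart from `memoryAllocatorType` (set to 0, i.e. deferred allocation, in the hunk above) the field names are assumptions rather than MNN's real `RuntimeHint` definition.
```cpp
#include <map>
#include <memory>

// Simplified stand-ins for illustration; not MNN's actual RuntimeHint/Runtime types.
struct HintSketch {
    int memoryAllocatorType = 0;  // 0 = defer allocation, as configured for the global executor above
    int winogradMemoryLevel = 3;  // assumed knob, consolidated here instead of having its own setter
};

class RuntimeSketch {
public:
    void setHint(const HintSketch& hint) { mHint = hint; }
    void reset(int numberThread, const HintSketch& hint) { mThreads = numberThread; mHint = hint; }
private:
    HintSketch mHint;
    int mThreads = 1;
};

// Cache keyed by forward type only; thread count no longer participates in the key,
// mirroring the std::pair<MNNForwardType,int> -> MNNForwardType change in this file.
static std::shared_ptr<RuntimeSketch> getOrCreate(std::map<int, std::shared_ptr<RuntimeSketch>>& cache,
                                                  int forwardType, int numberThread, const HintSketch& hint) {
    auto it = cache.find(forwardType);
    if (it == cache.end()) {
        it = cache.emplace(forwardType, std::make_shared<RuntimeSketch>()).first;
        it->second->setHint(hint);             // a freshly created runtime receives the consolidated hints
    } else {
        it->second->reset(numberThread, hint); // an existing runtime is re-tuned, like onReset() above
    }
    return it->second;
}
```
Keying the cache by forward type alone is also what lets `DEFAULT_BACKUP_RUNTIME_KEY` shrink from a `(type, thread)` pair to plain `MNN_FORWARD_CPU` at the top of this file.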
bn->setRuntimeHint(hint); gExecutor = new std::shared_ptr(new Executor(bn, MNN_FORWARD_CPU, 1)); }); return *gExecutor; @@ -178,13 +184,13 @@ void Executor::_refreshRuntime() { mRuntimeInfo.second = mRuntimes[DEFAULT_BACKUP_RUNTIME_KEY]; auto firstIter = mRuntimes.find(getAttr()->firstType); if (firstIter != mRuntimes.end()) { - mRuntimeInfo.first.insert(std::make_pair(firstIter->first.first, firstIter->second)); + mRuntimeInfo.first.insert(std::make_pair(firstIter->first, firstIter->second)); } else { MNN_ASSERT(false); } for (auto& iter : mRuntimes) { - if (iter.first.first != getAttr()->firstType.first) { - mRuntimeInfo.first.insert(std::make_pair(iter.first.first, iter.second)); + if (iter.first != getAttr()->firstType) { + mRuntimeInfo.first.insert(std::make_pair(iter.first, iter.second)); } } } @@ -301,7 +307,7 @@ Executor::RuntimeManager* Executor::RuntimeManager::createRuntimeManager(const S } } compute.user = config.backendConfig; - auto iter = originRt.find(std::make_pair(compute.type, compute.numThread)); + auto iter = originRt.find(compute.type); if (iter == originRt.end()) { auto creator = MNNGetExtraRuntimeCreator(compute.type); if (nullptr == creator) { @@ -312,11 +318,13 @@ Executor::RuntimeManager* Executor::RuntimeManager::createRuntimeManager(const S MNN_ERROR("Can't create Runtime: %s\n", EnumNameForwardType((ForwardType)compute.type)); return nullptr; } - originRt.insert(std::make_pair(std::make_pair(compute.type, compute.numThread), std::shared_ptr(newBn))); + originRt.insert(std::make_pair(compute.type, std::shared_ptr(newBn))); + } else { + iter->second->onReset(compute.numThread, compute.user); } res->mInside->mRuntime.second = originRt[DEFAULT_BACKUP_RUNTIME_KEY]; - res->mInside->mRuntime.first.insert(std::make_pair(compute.type, originRt[std::make_pair(compute.type, compute.numThread)])); - res->mInside->mInfo = originRt[std::make_pair(compute.type, compute.numThread)]; + res->mInside->mRuntime.first.insert(std::make_pair(compute.type, originRt[compute.type])); + res->mInside->mInfo = originRt[compute.type]; res->mInside->mNumberThread = compute.numThread; if (nullptr != config.backendConfig) { res->mInside->mConfig = *config.backendConfig; @@ -586,10 +594,8 @@ void Executor::_makeCache(const std::vector& expr, bool forceCPU) { scheduleInfo.pipelineInfo[0].first.reportError = false; if (forceCPU) { scheduleInfo.pipelineInfo[0].first.info.type = MNN_FORWARD_CPU; - scheduleInfo.pipelineInfo[0].first.info.numThread = 1; } else { - scheduleInfo.pipelineInfo[0].first.info.type = current->getAttr()->firstType.first; - scheduleInfo.pipelineInfo[0].first.info.numThread = current->getAttr()->firstType.second; + scheduleInfo.pipelineInfo[0].first.info.type = current->getAttr()->firstType; } scheduleInfo.pipelineInfo[0].first.needComputeShape = false; scheduleInfo.pipelineInfo[0].first.needComputeGeometry = mLazyMode != LAZY_CONTENT; diff --git a/express/Expr.cpp b/express/Expr.cpp index aa664ad24..be8b01bfa 100644 --- a/express/Expr.cpp +++ b/express/Expr.cpp @@ -206,7 +206,7 @@ EXPRP Expr::create(std::shared_ptr extra, std::vector&& inp expr->mInputs = std::move(inputs); auto exe = ExecutorScope::Current(); expr->mInside->mReq = exe->getRequirement(expr.get()); - if (!(exe->getLazyMode() & Executor::LAZY_COMPUTE_ONCE)) { + if ((!(exe->getLazyMode() & Executor::LAZY_COMPUTE_ONCE)) && exe->lazyEval) { _addLinkForInputs(expr); } return expr; @@ -1228,21 +1228,8 @@ void Variable::save(const std::vector& vars, NetT* dest) { auto des = TensorUtils::getDescribe(tensor); auto 
describe = std::unique_ptr(new MNN::TensorDescribeT); describe->index = varIndexInfo[expr] + v; - describe->blob = std::unique_ptr(new MNN::BlobT); describe->name = dest->tensorName[subindex]; - auto& blob = describe->blob; - blob->dataFormat = des->dimensionFormat; - if (tensor->getType() == halide_type_of()) { - blob->dataType = DataType_DT_FLOAT; - } else { - SET_TYPE(INT8, int8)} - SET_TYPE(UINT8, uint8)} - SET_TYPE(INT32, int32)} - SET_TYPE(INT64, int64)} - } - for (int d = 0; d < tensor->dimensions();d++) { - describe->blob->dims.push_back(tensor->buffer().dim[d].extent); - } + auto tensorDes = TensorUtils::getDescribe(tensor); if (nullptr != tensorDes->quantAttr) { describe->quantInfo.reset(new TensorQuantInfoT); @@ -1252,6 +1239,20 @@ void Variable::save(const std::vector& vars, NetT* dest) { describe->quantInfo->scale = tensorDes->quantAttr->scale; } if (staticModel) { + describe->blob = std::unique_ptr(new MNN::BlobT); + auto& blob = describe->blob; + blob->dataFormat = des->dimensionFormat; + if (tensor->getType() == halide_type_of()) { + blob->dataType = DataType_DT_FLOAT; + } else { + SET_TYPE(INT8, int8)} + SET_TYPE(UINT8, uint8)} + SET_TYPE(INT32, int32)} + SET_TYPE(INT64, int64)} + } + for (int d = 0; d < tensor->dimensions();d++) { + describe->blob->dims.push_back(tensor->buffer().dim[d].extent); + } for (auto& reg : des->regions) { auto regionT = std::unique_ptr(new MNN::RegionT); regionT->src = std::unique_ptr(new MNN::ViewT); diff --git a/express/RuntimeAttr.hpp b/express/RuntimeAttr.hpp index 3272cde95..21fd54fa0 100644 --- a/express/RuntimeAttr.hpp +++ b/express/RuntimeAttr.hpp @@ -24,7 +24,7 @@ struct RuntimeAttr { }; struct ExecutorAttr { std::shared_ptr constantBackend; - std::pair firstType; + MNNForwardType firstType; std::string externalFile; }; }; diff --git a/express/module/Module.cpp b/express/module/Module.cpp index 00b0a63bc..82172bfbd 100644 --- a/express/module/Module.cpp +++ b/express/module/Module.cpp @@ -32,8 +32,8 @@ static MNN::Express::Executor::RuntimeManager* _createDefaultRuntimeManager(cons sche_config.backendConfig = config->backend->config; } else { auto exe = ExecutorScope::Current(); - sche_config.type = exe->getAttr()->firstType.first; - sche_config.mode = exe->getAttr()->firstType.second; + sche_config.type = exe->getAttr()->firstType; + sche_config.numThread = 1; } return Executor::RuntimeManager::createRuntimeManager(sche_config); } @@ -165,7 +165,7 @@ class NetModule : public Module { setType("Net"); #ifdef MNN_INTERNAL_ENABLED if (nullptr != net) { - mLogInfo = getBasicLoggingData(); + mLogInfo = logBasicInfo(); std::string uuid = std::string(net->mnn_uuid() ? 
net->mnn_uuid()->c_str() : ""); mLogInfo.emplace("UUID", uuid); mLogInfo.emplace("ModelVersion", info->version); @@ -208,8 +208,8 @@ class NetModule : public Module { auto mModule = mChildren[0]; #ifdef MNN_INTERNAL_ENABLED - auto glo = ExecutorScope::Current(); Timer _time; + auto glo = ExecutorScope::Current(); glo->getDebugTools()->flops = 0.0f; #endif auto outputs = mModule->onForward(inputs); @@ -235,8 +235,10 @@ class NetModule : public Module { metrics.emplace("Memory", std::to_string(memory)); } logAsync(metrics); + MNN_PRINT("Cost time with log: %f\n", (float)_time.durationInUs() / 1000.0f); } while(false); #endif + mModule->clearCache(); return outputs; } diff --git a/express/module/PipelineModule.cpp b/express/module/PipelineModule.cpp index fc2551687..932ae6daa 100644 --- a/express/module/PipelineModule.cpp +++ b/express/module/PipelineModule.cpp @@ -634,11 +634,9 @@ Module* PipelineModule::load(const std::vector& inputs, const std:: modRuntime.compute.type = modRuntime.rt.first.begin()->first; modRuntime.compute.numThread = 1; // set allocator type - modRuntime.rt.first.begin()->second->setAllocatorType(rtMgr->getInside()->modes.memoryAllocatorType); - modRuntime.rt.second->setAllocatorType(rtMgr->getInside()->modes.memoryAllocatorType); + modRuntime.rt.first.begin()->second->setRuntimeHint(rtMgr->getInside()->modes.runtimeHint); // set winograd memory type - modRuntime.rt.first.begin()->second->setWinogradMemoryLevel(rtMgr->getInside()->modes.winogradMemoryUsed); - modRuntime.rt.second->setWinogradMemoryLevel(rtMgr->getInside()->modes.winogradMemoryUsed); + modRuntime.rt.second->setRuntimeHint(rtMgr->getInside()->modes.runtimeHint); } auto& rt = modRuntime.rt; auto firstRt = rt.first[modRuntime.compute.type]; diff --git a/include/MNN/Interpreter.hpp b/include/MNN/Interpreter.hpp index 16344a52b..6debbe3f0 100644 --- a/include/MNN/Interpreter.hpp +++ b/include/MNN/Interpreter.hpp @@ -206,6 +206,20 @@ class MNN_PUBLIC Interpreter { // Geometry Compute option, default is 0xFFFF GEOMETRY_COMPUTE_MASK = 4, + + // 0: Close dynamic quant; 1: per batch quant; 2: per tensor quant + DYNAMIC_QUANT_OPTIONS = 5, + + // For Mobile CPU with big-litter core, set decrease rate to let MNN divide task differential by CPU's performance + // 0-100, 50 means litter core has 50% capacity of large core + // Default is 50 + CPU_LITTLECORE_DECREASE_RATE = 6, + + // 0: Do not quantize kvcache, just store float + // 1: Only quantize key cache, use int8 asymmetric quantization + // 2: Only quantize value cache, use fp8 quantization + // 3: quantize both key and value cache as described above + KVCACHE_QUANT_OPTIONS = 7, }; enum GeometryComputeMask { diff --git a/include/MNN/MNNDefine.h b/include/MNN/MNNDefine.h index b6d6645db..ab84cd8f8 100644 --- a/include/MNN/MNNDefine.h +++ b/include/MNN/MNNDefine.h @@ -69,6 +69,6 @@ MNN_ERROR("Check failed: %s ==> %s\n", #success, #log); \ #define STR(x) STR_IMP(x) #define MNN_VERSION_MAJOR 2 #define MNN_VERSION_MINOR 9 -#define MNN_VERSION_PATCH 2 +#define MNN_VERSION_PATCH 3 #define MNN_VERSION STR(MNN_VERSION_MAJOR) "." STR(MNN_VERSION_MINOR) "." 
STR(MNN_VERSION_PATCH) #endif /* MNNDefine_h */ diff --git a/include/MNN/expr/Executor.hpp b/include/MNN/expr/Executor.hpp index 3ca0d9e19..3871827c9 100644 --- a/include/MNN/expr/Executor.hpp +++ b/include/MNN/expr/Executor.hpp @@ -136,7 +136,7 @@ class MNN_PUBLIC Executor { void _makeCache(const std::vector& outputs, bool forceCPU); // TODO: Remove mRuntimes, only use mRuntimeInfo - std::map, std::shared_ptr> mRuntimes; + std::map> mRuntimes; RuntimeInfo mRuntimeInfo; std::shared_ptr mDebug; std::map> mSubGraph; diff --git a/project/android/build_64.sh b/project/android/build_64.sh index 8e2039c73..34b18057e 100755 --- a/project/android/build_64.sh +++ b/project/android/build_64.sh @@ -7,7 +7,6 @@ cmake ../../../ \ -DMNN_USE_LOGCAT=false \ -DMNN_BUILD_BENCHMARK=ON \ -DMNN_USE_SSE=OFF \ --DMNN_SUPPORT_BF16=OFF \ -DMNN_BUILD_TEST=ON \ -DANDROID_NATIVE_API_LEVEL=android-21 \ -DMNN_BUILD_FOR_ANDROID_COMMAND=true \ diff --git a/project/ios/MNN.xcodeproj/project.pbxproj b/project/ios/MNN.xcodeproj/project.pbxproj index f31638c61..009adba67 100644 --- a/project/ios/MNN.xcodeproj/project.pbxproj +++ b/project/ios/MNN.xcodeproj/project.pbxproj @@ -268,8 +268,6 @@ 4A224A1427D0C56E000A9260 /* ConvolutionWinogradBridge.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 4A224A1027D0C56E000A9260 /* ConvolutionWinogradBridge.hpp */; }; 4A224A1527D0C56E000A9260 /* ConvolutionWinogradImpl.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 4A224A1127D0C56E000A9260 /* ConvolutionWinogradImpl.hpp */; }; 4A224A1627D0C56E000A9260 /* ConvolutionWinogradBridge.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4A224A1227D0C56E000A9260 /* ConvolutionWinogradBridge.cpp */; }; - 4A5BEC6026AAB3B30032F6BD /* CommonCompute.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 4A5BEC5E26AAB3B20032F6BD /* CommonCompute.hpp */; }; - 4A5BEC6126AAB3B30032F6BD /* MemoryFormater.h in Headers */ = {isa = PBXBuildFile; fileRef = 4A5BEC5F26AAB3B20032F6BD /* MemoryFormater.h */; }; 4A5BEC6426AAB4B30032F6BD /* ModuleTest.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4A5BEC6326AAB4B30032F6BD /* ModuleTest.cpp */; }; 4AF4FB24269ED235005BA97B /* SparseConvInt8TiledExecutor.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 4AF4FB20269ED234005BA97B /* SparseConvInt8TiledExecutor.cpp */; }; 4AF4FB26269ED235005BA97B /* SparseConvInt8TiledExecutor.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 4AF4FB22269ED234005BA97B /* SparseConvInt8TiledExecutor.hpp */; }; @@ -732,8 +730,6 @@ 950B28FA2A0C9AC20002F454 /* CPUScaleInt8.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 950B28F92A0C9AC20002F454 /* CPUScaleInt8.cpp */; }; 950B28FE2A0C9B310002F454 /* MNNScaleAndAddBiasInt8.S in Sources */ = {isa = PBXBuildFile; fileRef = 950B28FD2A0C9B310002F454 /* MNNScaleAndAddBiasInt8.S */; }; 950B29002A0C9B4D0002F454 /* MNNScaleAndAddBiasInt8.S in Sources */ = {isa = PBXBuildFile; fileRef = 950B28FF2A0C9B4D0002F454 /* MNNScaleAndAddBiasInt8.S */; }; - 952298AF2B4D38CB0043978B /* ConvolutionHybrid.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 952298AD2B4D38CB0043978B /* ConvolutionHybrid.cpp */; }; - 952298B02B4D38CB0043978B /* ConvolutionHybrid.hpp in Headers */ = {isa = PBXBuildFile; fileRef = 952298AE2B4D38CB0043978B /* ConvolutionHybrid.hpp */; }; 952298B22B4D39050043978B /* MetalLoop.mm in Sources */ = {isa = PBXBuildFile; fileRef = 952298B12B4D39050043978B /* MetalLoop.mm */; }; 952298B42B4D39260043978B /* MetalArgMax.mm in Sources */ = {isa = PBXBuildFile; fileRef = 952298B32B4D39250043978B /* MetalArgMax.mm */; }; 
952298B72B4D4CC80043978B /* CoreMLLayerNorm.cpp in Sources */ = {isa = PBXBuildFile; fileRef = 952298B52B4D4CC80043978B /* CoreMLLayerNorm.cpp */; }; @@ -807,8 +803,6 @@ CEE9B95B2A3AA4D4006438F2 /* MNNBilinearLineC8.S in Sources */ = {isa = PBXBuildFile; fileRef = CEE9B9572A3AA4D4006438F2 /* MNNBilinearLineC8.S */; }; CEE9B95C2A3AA4D4006438F2 /* MNNBilinearSampleC8.S in Sources */ = {isa = PBXBuildFile; fileRef = CEE9B9582A3AA4D4006438F2 /* MNNBilinearSampleC8.S */; }; CEE9B95D2A3AA4D4006438F2 /* MNNCubicSampleC16.S in Sources */ = {isa = PBXBuildFile; fileRef = CEE9B9592A3AA4D4006438F2 /* MNNCubicSampleC16.S */; }; - CEE9B9602A3AA4EF006438F2 /* CPUSoftMaxInt8.hpp in Headers */ = {isa = PBXBuildFile; fileRef = CEE9B95E2A3AA4EF006438F2 /* CPUSoftMaxInt8.hpp */; }; - CEE9B9612A3AA4EF006438F2 /* CPUSoftMaxInt8.cpp in Sources */ = {isa = PBXBuildFile; fileRef = CEE9B95F2A3AA4EF006438F2 /* CPUSoftMaxInt8.cpp */; }; EB45C774244D7C4F00E28F44 /* MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S in Sources */ = {isa = PBXBuildFile; fileRef = EB45C773244D7C4F00E28F44 /* MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S */; }; EB45C776244D7C6600E28F44 /* MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S in Sources */ = {isa = PBXBuildFile; fileRef = EB45C775244D7C6600E28F44 /* MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S */; }; EB8D2ABE246A4975009948D1 /* Arm82OpRegister.cpp in Sources */ = {isa = PBXBuildFile; fileRef = EB8D2ABD246A4975009948D1 /* Arm82OpRegister.cpp */; }; @@ -1098,8 +1092,6 @@ 4A224A1027D0C56E000A9260 /* ConvolutionWinogradBridge.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = ConvolutionWinogradBridge.hpp; sourceTree = ""; }; 4A224A1127D0C56E000A9260 /* ConvolutionWinogradImpl.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = ConvolutionWinogradImpl.hpp; sourceTree = ""; }; 4A224A1227D0C56E000A9260 /* ConvolutionWinogradBridge.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ConvolutionWinogradBridge.cpp; sourceTree = ""; }; - 4A5BEC5E26AAB3B20032F6BD /* CommonCompute.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = CommonCompute.hpp; sourceTree = ""; }; - 4A5BEC5F26AAB3B20032F6BD /* MemoryFormater.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = MemoryFormater.h; sourceTree = ""; }; 4A5BEC6326AAB4B30032F6BD /* ModuleTest.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ModuleTest.cpp; sourceTree = ""; }; 4AF4FB20269ED234005BA97B /* SparseConvInt8TiledExecutor.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = SparseConvInt8TiledExecutor.cpp; sourceTree = ""; }; 4AF4FB22269ED234005BA97B /* SparseConvInt8TiledExecutor.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = SparseConvInt8TiledExecutor.hpp; sourceTree = ""; }; @@ -1572,8 +1564,6 @@ 950B28FB2A0C9AD30002F454 /* CPUScaleInt8.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = CPUScaleInt8.hpp; sourceTree = ""; }; 950B28FD2A0C9B310002F454 /* MNNScaleAndAddBiasInt8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNScaleAndAddBiasInt8.S; sourceTree = ""; }; 950B28FF2A0C9B4D0002F454 /* MNNScaleAndAddBiasInt8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; 
path = MNNScaleAndAddBiasInt8.S; sourceTree = ""; }; - 952298AD2B4D38CB0043978B /* ConvolutionHybrid.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = ConvolutionHybrid.cpp; path = compute/ConvolutionHybrid.cpp; sourceTree = ""; }; - 952298AE2B4D38CB0043978B /* ConvolutionHybrid.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; name = ConvolutionHybrid.hpp; path = compute/ConvolutionHybrid.hpp; sourceTree = ""; }; 952298B12B4D39050043978B /* MetalLoop.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalLoop.mm; sourceTree = ""; }; 952298B32B4D39250043978B /* MetalArgMax.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = MetalArgMax.mm; sourceTree = ""; }; 952298B52B4D4CC80043978B /* CoreMLLayerNorm.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CoreMLLayerNorm.cpp; sourceTree = ""; }; @@ -1607,7 +1597,6 @@ C4D4823C27BA2BB40021C2B9 /* CPUDet.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CPUDet.cpp; sourceTree = ""; }; C4D4823D27BA2BB40021C2B9 /* CPUDet.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = CPUDet.hpp; sourceTree = ""; }; C4D4824227BA67DE0021C2B9 /* GeometryDet.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = GeometryDet.cpp; sourceTree = ""; }; - C4DBB34F27041F9C00ADB16E /* WinogradInt8Helper.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = WinogradInt8Helper.hpp; sourceTree = ""; }; C4EF5FB22657A9E70094235C /* ConvInt8TiledExecutor.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = ConvInt8TiledExecutor.cpp; sourceTree = ""; }; C4EF5FB32657A9E70094235C /* ConvInt8TiledExecutor.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = ConvInt8TiledExecutor.hpp; sourceTree = ""; }; C4F906AF276886040026B847 /* GeometryTopK.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = GeometryTopK.cpp; sourceTree = ""; }; @@ -1650,8 +1639,6 @@ CEE9B9572A3AA4D4006438F2 /* MNNBilinearLineC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearLineC8.S; sourceTree = ""; }; CEE9B9582A3AA4D4006438F2 /* MNNBilinearSampleC8.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNBilinearSampleC8.S; sourceTree = ""; }; CEE9B9592A3AA4D4006438F2 /* MNNCubicSampleC16.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNCubicSampleC16.S; sourceTree = ""; }; - CEE9B95E2A3AA4EF006438F2 /* CPUSoftMaxInt8.hpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.h; path = CPUSoftMaxInt8.hpp; sourceTree = ""; }; - CEE9B95F2A3AA4EF006438F2 /* CPUSoftMaxInt8.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; path = CPUSoftMaxInt8.cpp; sourceTree = ""; }; EB45C773244D7C4F00E28F44 /* MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S; sourceTree = ""; }; EB45C775244D7C6600E28F44 /* MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S */ = 
{isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.asm; path = MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S; sourceTree = ""; }; EB8D2ABD246A4975009948D1 /* Arm82OpRegister.cpp */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.cpp; name = Arm82OpRegister.cpp; path = ../arm82/Arm82OpRegister.cpp; sourceTree = ""; }; @@ -1866,10 +1853,8 @@ 488873A8215B639D0079B12E /* source */ = { isa = PBXGroup; children = ( - CE482EF5288536DA007CD935 /* internal */, 4DF87C482887D3560003E2D4 /* calib3d */, 4D4CF4612760946500A36D9F /* imgproc */, - 4A5BEC6226AAB3D70032F6BD /* common */, 4D9A931B26255BDA00F9B43C /* coreml */, 6A131E3C2582331C002EC3D6 /* plugin */, 489D7A152550FDC800AD896A /* metal */, @@ -1934,12 +1919,8 @@ CEE4566A2BC0E23D00F062C1 /* CPUExternalConst.cpp */, 95278CE62B9F0999009E9B29 /* CPUDynamicQuant.cpp */, 95278CE52B9F0999009E9B29 /* CPUDynamicQuant.hpp */, - 952298AD2B4D38CB0043978B /* ConvolutionHybrid.cpp */, - 952298AE2B4D38CB0043978B /* ConvolutionHybrid.hpp */, CE8049A92B31C65B009B422C /* CPULayerNorm.hpp */, 958375342A496E5C007C0A3E /* MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S */, - CEE9B95F2A3AA4EF006438F2 /* CPUSoftMaxInt8.cpp */, - CEE9B95E2A3AA4EF006438F2 /* CPUSoftMaxInt8.hpp */, CE9AFED428E54E3300566949 /* CPUInterp3D.cpp */, CE9AFED528E54E3300566949 /* CPUInterp3D.hpp */, 4DCF538B2892B16300B5B393 /* CPUHistogram.cpp */, @@ -2236,16 +2217,6 @@ path = ../../../test/speed; sourceTree = ""; }; - 4A5BEC6226AAB3D70032F6BD /* common */ = { - isa = PBXGroup; - children = ( - C4DBB34F27041F9C00ADB16E /* WinogradInt8Helper.hpp */, - 4A5BEC5F26AAB3B20032F6BD /* MemoryFormater.h */, - 4A5BEC5E26AAB3B20032F6BD /* CommonCompute.hpp */, - ); - path = common; - sourceTree = ""; - }; 4D4CF4612760946500A36D9F /* imgproc */ = { isa = PBXGroup; children = ( @@ -2913,19 +2884,16 @@ CEA82BDC2A15F8AD002CBC95 /* IdstConvolutionInt8.hpp in Headers */, 4DE4E82C275E307B0016A916 /* cv in Headers */, 1F501F842397BA5B004E8721 /* ImageProcess.hpp in Headers */, - CECF8C5D299CACFD00D3875B /* Log.hpp in Headers */, 1F501F822397BA5B004E8721 /* Interpreter.hpp in Headers */, C4F906B327688C3A0026B847 /* NMSModule.hpp in Headers */, 1F501F882397BA5B004E8721 /* Tensor.hpp in Headers */, 1F501F872397BA5B004E8721 /* Matrix.h in Headers */, CE8049AC2B31C65B009B422C /* CPULayerNorm.hpp in Headers */, - CECF8C5A299CACFD00D3875B /* WorkerThread.hpp in Headers */, 48C84B85250F711700EE7666 /* IfModule.hpp in Headers */, 4D9A937326255BDA00F9B43C /* CoreMLUnary.hpp in Headers */, 48C84B98250F71E900EE7666 /* CPUSoftmax.hpp in Headers */, 4882C8B8241A22B800DAC168 /* OpCommonUtils.hpp in Headers */, 48608B54250632EC00CB1D71 /* GeometryComputer.hpp in Headers */, - CECF8C7A299CAD9400D3875B /* sha1.h in Headers */, 4894C6EC27016F7200D8BE79 /* CPUResizeCache.hpp in Headers */, 92FF04A623AA0BFB00AC97F6 /* FileLoader.hpp in Headers */, 48F34733273A7C8400C45394 /* ImageProcessFunction.hpp in Headers */, @@ -2935,12 +2903,10 @@ 482BFBCF28351BA1009210E4 /* AllShader.hpp in Headers */, 4896D36A25FE2A3D00717702 /* Arm82Unary.hpp in Headers */, 1F501F862397BA5B004E8721 /* Rect.h in Headers */, - CEE9B9602A3AA4EF006438F2 /* CPUSoftMaxInt8.hpp in Headers */, 1F501F8B2397BA5B004E8721 /* MNNSharedContext.h in Headers */, 48925F352744AC0700919B37 /* CPUROIAlign.hpp in Headers */, 92FF029623AA0B5A00AC97F6 /* CPUCast.hpp in Headers */, 4D9A937826255BDA00F9B43C /* CoreMLBinary.hpp in Headers */, - CECF8C85299CAD9400D3875B /* log_util.h in Headers */, 4D6D7FD52656896600F80814 /* 
DenseConvolutionTiledExecutor.hpp in Headers */, 4D9A936626255BDA00F9B43C /* CoreMLExecutor.h in Headers */, 92FF027A23AA0B5A00AC97F6 /* CPUPool.hpp in Headers */, @@ -2949,7 +2915,6 @@ 1F501F802397BA5B004E8721 /* MNNDefine.h in Headers */, 19D0FE76285C66F200B74B1A /* MetalLayerNorm.hpp in Headers */, 489D7A682550FDC800AD896A /* MetalReduction.hpp in Headers */, - CECF8C86299CAD9400D3875B /* sds.h in Headers */, 1F501F7F2397BA5B004E8721 /* HalideRuntime.h in Headers */, 92FF029E23AA0B5A00AC97F6 /* CPUDeconvolutionDepthwise.hpp in Headers */, 4D9A935B26255BDA00F9B43C /* NeuralNetwork.pb-c.h in Headers */, @@ -2970,10 +2935,8 @@ 481C2DEE25FE2CD6001ED6DF /* Arm82Functions.hpp in Headers */, 4894C6EA27016F7200D8BE79 /* UnaryUtils.hpp in Headers */, EBD4842A2485FF650083CE95 /* Arm82Interp.hpp in Headers */, - CECF8C81299CAD9400D3875B /* log_util_imp.h in Headers */, 92FF037623AA0B5A00AC97F6 /* CPUBinary.hpp in Headers */, 4D9A935826255BDA00F9B43C /* FeatureTypes.pb-c.h in Headers */, - CECF8C7C299CAD9400D3875B /* hmac-sha.h in Headers */, 48608B53250632EC00CB1D71 /* GeometryComputerUtils.hpp in Headers */, 950B28F529F629A90002F454 /* CPUBinaryInt8.hpp in Headers */, 489D7A732550FDC800AD896A /* MetalBackend.hpp in Headers */, @@ -2996,7 +2959,6 @@ 4DF87C522887D3F20003E2D4 /* CPUSvd.hpp in Headers */, 48747D4B245D9D24000B9709 /* RuntimeFactory.hpp in Headers */, 92FF03B323AA0B5A00AC97F6 /* ConvolutionDepthwise3x3.hpp in Headers */, - CECF8C77299CAD9400D3875B /* log_builder.h in Headers */, 4D9A937226255BDA00F9B43C /* CoreMLConvolution.hpp in Headers */, 92FF038B23AA0B5A00AC97F6 /* CPUUnravelIndex.hpp in Headers */, 4AF4FB26269ED235005BA97B /* SparseConvInt8TiledExecutor.hpp in Headers */, @@ -3008,7 +2970,6 @@ 92FF028C23AA0B5A00AC97F6 /* CPUReduction.hpp in Headers */, 92FF03B923AA0B5A00AC97F6 /* ConvOpt.h in Headers */, 92FF04AB23AA0BFB00AC97F6 /* Pipeline.hpp in Headers */, - 952298B02B4D38CB0043978B /* ConvolutionHybrid.hpp in Headers */, 489D7A6E2550FDC800AD896A /* MetalROIPooling.hpp in Headers */, 4882C8B9241A22B800DAC168 /* ConvolutionCommon.hpp in Headers */, 92FF03AE23AA0B5A00AC97F6 /* ConvolutionIntFactory.hpp in Headers */, @@ -3035,7 +2996,6 @@ 92FF03CA23AA0B5A00AC97F6 /* CPUConvolutionDepthwise.hpp in Headers */, 92FF04A923AA0BFB00AC97F6 /* Schedule.hpp in Headers */, 489D7A9F2550FDC900AD896A /* MetalConvolutionCommon.hpp in Headers */, - CECF8C80299CAD9400D3875B /* lz4.h in Headers */, 92FF028623AA0B5A00AC97F6 /* CPUDeconvolution.hpp in Headers */, 489D7A722550FDC800AD896A /* MetalReLU6.hpp in Headers */, 92FF04B523AA0BFB00AC97F6 /* TensorUtils.hpp in Headers */, @@ -3056,7 +3016,6 @@ 4A224A1427D0C56E000A9260 /* ConvolutionWinogradBridge.hpp in Headers */, 4D9A935926255BDA00F9B43C /* DataStructures.pb-c.h in Headers */, 489D7A972550FDC900AD896A /* MetalConvolutionDepthwise.hpp in Headers */, - 4A5BEC6126AAB3B30032F6BD /* MemoryFormater.h in Headers */, 489D7AB42550FDC900AD896A /* MetalBinary.hpp in Headers */, 92FF04AF23AA0BFB00AC97F6 /* Macro.h in Headers */, 4D9A936C26255BDA00F9B43C /* CoreMLRaster.hpp in Headers */, @@ -3088,24 +3047,20 @@ 92FF03A623AA0B5A00AC97F6 /* ConvolutionTiledExecutor.hpp in Headers */, 92FF036523AA0B5A00AC97F6 /* CPUResize.hpp in Headers */, 92FF04B423AA0BFB00AC97F6 /* MNNMemoryUtils.h in Headers */, - CECF8C88299CAD9400D3875B /* log_api.h in Headers */, 4A224A0D27D0C2D9000A9260 /* ConvolutionPackWinograd.hpp in Headers */, 4A224A0E27D0C2D9000A9260 /* ConvolutionPackFreeWinograd.hpp in Headers */, 4D9A937426255BDA00F9B43C /* CoreMLReduction.hpp in Headers 
*/, 48C84B8B250F711700EE7666 /* PipelineModule.hpp in Headers */, F41497D7278D8A21004A363A /* RuntimeAttr.hpp in Headers */, - CECF8C5B299CACFD00D3875B /* LogHelper.hpp in Headers */, 92FF04C123AA0BFB00AC97F6 /* Backend.hpp in Headers */, 482BFBCD28351BA1009210E4 /* ShaderMap.hpp in Headers */, 489D7A812550FDC900AD896A /* MetalPooling.hpp in Headers */, - CECF8C7F299CAD9400D3875B /* md5.h in Headers */, 92FF02A623AA0B5A00AC97F6 /* CPUQuantizedMaxPool.hpp in Headers */, 92FF028023AA0B5A00AC97F6 /* CPUFloatToInt8.hpp in Headers */, 92FF028723AA0B5A00AC97F6 /* CPUFixedPoint.hpp in Headers */, C43C8227251894F400A0FF84 /* Vec.hpp in Headers */, 4819FB1D24C138DF0050BD09 /* GeometryConvUtils.hpp in Headers */, 489D7A952550FDC900AD896A /* MetalMatMul.hpp in Headers */, - CECF8C83299CAD9400D3875B /* log_define.h in Headers */, C48CAE2628900C4A00271A6D /* ConvInt8Winograd.hpp in Headers */, 48F34730273A7C7300C45394 /* CPUImageProcess.hpp in Headers */, 489D7A702550FDC800AD896A /* MetalRaster.hpp in Headers */, @@ -3120,7 +3075,6 @@ 92FF038C23AA0B5A00AC97F6 /* CPUEltwise.hpp in Headers */, 92FF028823AA0B5A00AC97F6 /* CPUDequantize.hpp in Headers */, 481C2DF125FE2CD6001ED6DF /* Arm82OptFunc.hpp in Headers */, - 4A5BEC6026AAB3B30032F6BD /* CommonCompute.hpp in Headers */, C43C8225251894F400A0FF84 /* WingoradGenerater.hpp in Headers */, ); runOnlyForDeploymentPostprocessing = 0; @@ -3290,7 +3244,6 @@ 48FA474623AA127B00172C3B /* NeuralNetWorkOp.cpp in Sources */, 4D9A936E26255BDA00F9B43C /* CoreMLArgMax.cpp in Sources */, 92FF02F423AA0B5A00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S in Sources */, - CEE9B9612A3AA4EF006438F2 /* CPUSoftMaxInt8.cpp in Sources */, 482BFBCE28351BA1009210E4 /* ShaderMap.cpp in Sources */, 92FF038623AA0B5A00AC97F6 /* CPULinSpace.cpp in Sources */, 4819FB2D24C1396A0050BD09 /* GeometryConv2D.cpp in Sources */, @@ -3328,7 +3281,6 @@ 489D7A8A2550FDC900AD896A /* MetalConvolutionDepthwise.mm in Sources */, 48123003269EA83400EB7ABA /* ShapeUnique.cpp in Sources */, 92FF037D23AA0B5A00AC97F6 /* CPURelu.cpp in Sources */, - CECF8C5E299CACFD00D3875B /* WorkerThread.cpp in Sources */, 489D7A842550FDC900AD896A /* MetalBinary.mm in Sources */, 48747D6B245D9E33000B9709 /* GeometryFill.cpp in Sources */, 4819FB1F24C138DF0050BD09 /* GeometryConvUtils.cpp in Sources */, @@ -3428,7 +3380,6 @@ 48F34734273A7C8400C45394 /* ImageProcessFunction.cpp in Sources */, 6A131E4025823349002EC3D6 /* PluginKernel.cpp in Sources */, 48958781268EBA6F00EA01A7 /* CPUSegmentMean.cpp in Sources */, - CECF8C7B299CAD9400D3875B /* sha1.c in Sources */, 4D9A937026255BDA00F9B43C /* CoreMLUnary.cpp in Sources */, 92FF04A823AA0BFB00AC97F6 /* AutoTime.cpp in Sources */, 92FF04AE23AA0BFB00AC97F6 /* Backend.cpp in Sources */, @@ -3483,7 +3434,6 @@ 92FF03CE23AA0B5A00AC97F6 /* CPUOPRegister.cpp in Sources */, 92FF02B323AA0B5A00AC97F6 /* CPUInstanceNorm.cpp in Sources */, 4819FB2C24C1396A0050BD09 /* GeometryPoolGrad.cpp in Sources */, - CECF8C7E299CAD9400D3875B /* log_builder.cpp in Sources */, 92FF042223AA0B7100AC97F6 /* ShapeConcat.cpp in Sources */, 4D6D7FD12656891400F80814 /* MNNPackedSparseMatMulEpx4.S in Sources */, 4D5662CC299B76ED0031C1A1 /* MNNMaxPoolInt8.S in Sources */, @@ -3491,7 +3441,6 @@ 4844603D2726558B00F7EABA /* MNNConvWinoSourceTransformUnit6x6FP16.S in Sources */, 92FF044A23AA0B7100AC97F6 /* ShapeConvolution.cpp in Sources */, 11A01A0D258785FB00745FA7 /* MNNVectorTop1Int32.S in Sources */, - 952298AF2B4D38CB0043978B /* ConvolutionHybrid.cpp in Sources */, 92FF026A23AA0B5A00AC97F6 /* 
CPUNonMaxSuppressionV2.cpp in Sources */, 92FF045123AA0B7100AC97F6 /* ShapeArgMax.cpp in Sources */, 48F9E54E2493A0A800E46522 /* MNNPackC4ForMatMul_A.S in Sources */, @@ -3563,7 +3512,6 @@ 4D759B2C25FF89EE0037B0B6 /* GeometryShape.cpp in Sources */, 11A01A07258785EA00745FA7 /* MNNVectorTop1Float.S in Sources */, 48747D6E245D9E33000B9709 /* GeometrySlice.cpp in Sources */, - CECF8C7D299CAD9400D3875B /* md5.c in Sources */, 92FF041923AA0B7100AC97F6 /* ShapeQuantizedMaxPool.cpp in Sources */, 92FF038A23AA0B5A00AC97F6 /* CPURange.cpp in Sources */, CE125CC92A52BF6B003698C9 /* MNNBilinearLineC8.S in Sources */, @@ -3621,10 +3569,8 @@ 92FF042E23AA0B7100AC97F6 /* ShapeProposal.cpp in Sources */, 92FF025923AA0B5A00AC97F6 /* CPUPoolInt8.cpp in Sources */, 92FF045B23AA0B7100AC97F6 /* ShapeShape.cpp in Sources */, - CECF8C87299CAD9400D3875B /* sds.c in Sources */, 9560EAD62BDE426A00C8D0B6 /* GeometryLayernorm.cpp in Sources */, 4D6D7FD72656896D00F80814 /* SparseConvolutionTiledExecutor.cpp in Sources */, - CECF8C82299CAD9400D3875B /* log_api.cpp in Sources */, 92FF03A823AA0B5A00AC97F6 /* WinogradOptFunction.cpp in Sources */, 950B28E229F627E00002F454 /* MNNBinarySubInt8.S in Sources */, 950B28F029F627F70002F454 /* MNNBinarySubInt8.S in Sources */, @@ -3634,7 +3580,6 @@ 4D9A936026255BDA00F9B43C /* Model.pb-c.c in Sources */, CE9AFED628E54E3300566949 /* CPUInterp3D.cpp in Sources */, C4F906B427688C3A0026B847 /* NMSModule.cpp in Sources */, - CECF8C64299CAD8400D3875B /* LogHelper.mm in Sources */, 48FA474523AA127B00172C3B /* Executor.cpp in Sources */, 92FF02EA23AA0B5A00AC97F6 /* MNNGemmInt8AddBiasScale_16x4_Unit.S in Sources */, 48A8A61A21D101DE00C2B9A7 /* Matrix_CV.cpp in Sources */, @@ -3660,7 +3605,6 @@ 92FF027F23AA0B5A00AC97F6 /* CPUDeconvolutionDepthwise.cpp in Sources */, EBECA3A724643D5D0062C7A3 /* MNNQuantizeFP16_UNIT4.S in Sources */, 92FF04A423AA0BFB00AC97F6 /* Interpreter.cpp in Sources */, - CECF8C5C299CACFD00D3875B /* Log.cpp in Sources */, 92FF045623AA0B7100AC97F6 /* ShapeReshape.cpp in Sources */, 92FF032523AA0B5A00AC97F6 /* MNNConvDwF23SourceTransUnit.S in Sources */, 92FF044423AA0B7100AC97F6 /* ShapeLSTM.cpp in Sources */, @@ -3697,7 +3641,6 @@ 92FF02B623AA0B5A00AC97F6 /* CPUUnary.cpp in Sources */, 92FF032723AA0B5A00AC97F6 /* MNNDeconvRunForUnitDepthWise.S in Sources */, CE7DC00028E2DE6B00797689 /* ShapeConvTranspose3D.cpp in Sources */, - CECF8C78299CAD9400D3875B /* log_util_imp.cpp in Sources */, 92FF02CA23AA0B5A00AC97F6 /* MNNUnPackC4.S in Sources */, 952298B22B4D39050043978B /* MetalLoop.mm in Sources */, 48925F372744AC2A00919B37 /* ShapeROIAlign.cpp in Sources */, @@ -3723,13 +3666,11 @@ 92FF02FF23AA0B5A00AC97F6 /* MNNFloat2Int8.S in Sources */, 4D9A937926255BDA00F9B43C /* CoreMLRaster.cpp in Sources */, 48417FF224D13BF50056D9A7 /* GeometrySelect.cpp in Sources */, - CECF8C84299CAD9400D3875B /* lz4.c in Sources */, 489D7A7E2550FDC900AD896A /* MNNMetalContext.mm in Sources */, 92FF033423AA0B5A00AC97F6 /* MNNUInt8ToInt16WithOffsetC4Common.S in Sources */, 92FF036B23AA0B5A00AC97F6 /* CPUResize.cpp in Sources */, 92FF02C723AA0B5A00AC97F6 /* MNNCopyC4WithStride.S in Sources */, 92FF030923AA0B5A00AC97F6 /* MNNNV21ToBGRUnit.S in Sources */, - CECF8C79299CAD9400D3875B /* hmac-sha.cpp in Sources */, 92FF032623AA0B5A00AC97F6 /* MNNWinogradMatrixProductLeft.S in Sources */, 92FF04C023AA0BFB00AC97F6 /* Tensor.cpp in Sources */, CEE9B95B2A3AA4D4006438F2 /* MNNBilinearLineC8.S in Sources */, @@ -4127,7 +4068,7 @@ CODE_SIGN_STYLE = Automatic; DEAD_CODE_STRIPPING = YES; DEFINES_MODULE = YES; - 
DEVELOPMENT_TEAM = Q48UX93J22; + DEVELOPMENT_TEAM = 6G7464HHUS; DYLIB_COMPATIBILITY_VERSION = 1; DYLIB_CURRENT_VERSION = 1; DYLIB_INSTALL_NAME_BASE = "@rpath"; @@ -4202,7 +4143,7 @@ IPHONEOS_DEPLOYMENT_TARGET = 9.0; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; - PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcde3; + PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcde3vj; PRODUCT_NAME = "$(TARGET_NAME)"; TARGETED_DEVICE_FAMILY = "1,2"; }; @@ -4214,7 +4155,7 @@ ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon; ASSETCATALOG_COMPILER_LAUNCHIMAGE_NAME = LaunchImage; CODE_SIGN_STYLE = Automatic; - DEVELOPMENT_TEAM = Q48UX93J22; + DEVELOPMENT_TEAM = 6G7464HHUS; GCC_ENABLE_CPP_EXCEPTIONS = NO; GCC_ENABLE_CPP_RTTI = NO; HEADER_SEARCH_PATHS = ( @@ -4229,7 +4170,7 @@ IPHONEOS_DEPLOYMENT_TARGET = 9.0; LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks"; OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)"; - PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcde3; + PRODUCT_BUNDLE_IDENTIFIER = com.taobao.mnn.abcde3vj; PRODUCT_NAME = "$(TARGET_NAME)"; TARGETED_DEVICE_FAMILY = "1,2"; }; @@ -4245,7 +4186,7 @@ CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; - DEVELOPMENT_TEAM = Q48UX93J22; + DEVELOPMENT_TEAM = 6G7464HHUS; GENERATE_INFOPLIST_FILE = YES; INFOPLIST_FILE = demo/Info.plist; INFOPLIST_KEY_NSCameraUsageDescription = "use camera to capture photo for demo"; @@ -4278,7 +4219,7 @@ CLANG_WARN_QUOTED_INCLUDE_IN_FRAMEWORK_HEADER = YES; CODE_SIGN_STYLE = Automatic; CURRENT_PROJECT_VERSION = 1; - DEVELOPMENT_TEAM = Q48UX93J22; + DEVELOPMENT_TEAM = 6G7464HHUS; GENERATE_INFOPLIST_FILE = YES; INFOPLIST_FILE = demo/Info.plist; INFOPLIST_KEY_NSCameraUsageDescription = "use camera to capture photo for demo"; @@ -4343,4 +4284,3 @@ }; rootObject = 0F1465AE1FA18D1000F9860A /* Project object */; } - diff --git a/project/ios/Playground/AppDelegate.mm b/project/ios/Playground/AppDelegate.mm index f01ffb6ef..d073b12a8 100644 --- a/project/ios/Playground/AppDelegate.mm +++ b/project/ios/Playground/AppDelegate.mm @@ -12,35 +12,33 @@ #include #import #import "benchmark.h" - +#define TEST_WORKMODE 0 @implementation AppDelegate - (BOOL)application:(UIApplication *)application didFinishLaunchingWithOptions:(NSDictionary *)launchOptions { -//#define UNITTEST -//#ifdef UNITTEST -// // unittest -// { -// MNN::BackendConfig config; -// // If want to test metal, change MNN_FORWARD_CPU to MNN_FORWARD_METAL -// MNN::Express::Executor::getGlobalExecutor()->setGlobalExecutorConfig(MNN_FORWARD_CPU, config, 1); -// int precisionInTestUtil = -// getTestPrecision(MNN_FORWARD_CPU, config.precision, MNN::Express::Executor::getGlobalExecutor()->getCurrentRuntimeStatus(MNN::STATUS_SUPPORT_FP16)); -// MNNTestSuite::runAll(precisionInTestUtil); -// } -//#endif -//#ifdef BENCHMARK -// // benchmark -// { -// auto bundle = CFBundleGetMainBundle(); -// auto url = CFBundleCopyBundleURL(bundle); -// auto string = CFURLCopyFileSystemPath(url, kCFURLPOSIXPathStyle); -// CFRelease(url); -// auto cstring = CFStringGetCStringPtr(string, kCFStringEncodingUTF8); -// auto res = std::string(cstring) + "/models"; -// CFRelease(string); -// iosBenchAll(res.c_str()); -// } -//#endif +#if TEST_WORKMODE==0 + // unittest + { + MNN::BackendConfig config; + // If want to test metal, change MNN_FORWARD_CPU to MNN_FORWARD_METAL + MNN::Express::Executor::getGlobalExecutor()->setGlobalExecutorConfig(MNN_FORWARD_CPU, config, 1); + 
MNNTestSuite::runAll(2); + } +#endif +#if TEST_WORKMODE==1 + // benchmark + { + auto bundle = CFBundleGetMainBundle(); + auto url = CFBundleCopyBundleURL(bundle); + auto string = CFURLCopyFileSystemPath(url, kCFURLPOSIXPathStyle); + CFRelease(url); + auto cstring = CFStringGetCStringPtr(string, kCFStringEncodingUTF8); + auto res = std::string(cstring) + "/models"; + CFRelease(string); + iosBenchAll(res.c_str()); + } +#endif +#if TEST_WORKMODE==2 auto bundle = CFBundleGetMainBundle(); auto url = CFBundleCopyBundleURL(bundle); auto string = CFURLCopyFileSystemPath(url, kCFURLPOSIXPathStyle); @@ -48,11 +46,10 @@ - (BOOL)application:(UIApplication *)application didFinishLaunchingWithOptions:( auto cstring = CFStringGetCStringPtr(string, kCFStringEncodingUTF8); auto res = std::string(cstring) + "/models/mobilenet_v2_auth.mnn"; - MNN::Interpreter* interpreter = MNN::Interpreter::createFromFile(res.c_str()); MNN::ScheduleConfig config; interpreter->createSession(config); - +#endif return YES; } diff --git a/pymnn/pip_package/pyproject.toml b/pymnn/pip_package/pyproject.toml index 25fc9d331..c178a4ebc 100644 --- a/pymnn/pip_package/pyproject.toml +++ b/pymnn/pip_package/pyproject.toml @@ -16,7 +16,7 @@ test-skip = [ ] test-requires = [ "opencv-python==4.6.0.66", - "numpy==1.13.3", + "numpy", "torch" ] test-command = [ diff --git a/pymnn/src/llm.h b/pymnn/src/llm.h index 8e9fffcfd..3ade7a17f 100644 --- a/pymnn/src/llm.h +++ b/pymnn/src/llm.h @@ -1,8 +1,8 @@ -#include "llm.hpp" +#include "llm/llm.hpp" typedef struct { PyObject_HEAD - Llm* llm; + MNN::Transformer::Llm* llm; } LLM; static PyObject* PyMNNLLM_new(struct _typeobject *type, PyObject *args, PyObject *kwds) { @@ -38,7 +38,7 @@ static PyObject* PyMNNLLM_response(LLM *self, PyObject *args) { if (!PyArg_ParseTuple(args, "s|p", &query, &stream)) { Py_RETURN_NONE; } - LlmStreamBuffer buffer(nullptr); + MNN::Transformer::LlmStreamBuffer buffer(nullptr); std::ostream null_os(&buffer); auto res = self->llm->response(query, stream ? 
&std::cout : &null_os); return string2Object(res); @@ -104,10 +104,10 @@ static PyObject* PyMNNLLM_create(PyObject *self, PyObject *args) { if (!llm) { return NULL; } - llm->llm = Llm::createLLM(path); + llm->llm = MNN::Transformer::Llm::createLLM(path); return (PyObject*)llm; } static PyMethodDef PyMNNLLM_static_methods[] = { {"create", PyMNNLLM_create, METH_VARARGS} -}; \ No newline at end of file +}; diff --git a/source/backend/arm82/Arm82Functions.cpp b/source/backend/arm82/Arm82Functions.cpp index 435d369d4..19038ec94 100644 --- a/source/backend/arm82/Arm82Functions.cpp +++ b/source/backend/arm82/Arm82Functions.cpp @@ -45,14 +45,12 @@ void MNNAbsMaxFP16(const float* source, float* absmax, size_t src_depth_quad, si void MNNQuantScaleFP16(float* sum, float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch); void MNNDynamicQuantFP16(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack); void MNNQuantSumFP16(float* sum, const float* dequant_scale, size_t thread, size_t batch); -#if defined(__aarch64__) -void MNNGemmHybridInt8FP16_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param); -void MNNGemmHybridInt4FP16_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param); -void MNNGemmHybridInt4FP16_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param); -void MNNGemmHybridInt8FP16_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param); #endif +#if defined(__aarch64__) +void CountMinMaxValue_FP16(float* source, float* minVal, float* maxVal, size_t sizeQuad); +void MNNSumByAxisLForMatmul_A_ARM86(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams); +void MNNSumByAxisLForMatmul_A_ARM82(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams); #endif - void MNNConvDwF23MulTransUnitFP16(FLOAT16 **cacheLine, const FLOAT16 *weight, FLOAT16 *dest, size_t ow); void MNNConvDwF23SourceTransUnitFP16(const FLOAT16 *source, FLOAT16 *dest, size_t unit); @@ -82,6 +80,32 @@ static void MNNMatrixSubFP16(FLOAT16* C, const FLOAT16* A, const FLOAT16* B, siz } } } +#if defined(__aarch64__) +static void ARM82CountMinMaxValue(float* source, float* minVal, float* maxVal, size_t size) { + if (size % 8 == 0) { + CountMinMaxValue_FP16(source, minVal, maxVal, size / 8); + } else { + auto remain = size - 8 * (size / 8); + auto max_ = ((__fp16*)source)[0]; + auto min_ = max_; + if (size >= 8) { + CountMinMaxValue_FP16(source, minVal, maxVal, size / 8); + max_ = ((__fp16*)maxVal)[0]; + min_ = ((__fp16*)minVal)[0]; + } + if (remain > 0) { + int16_t tmp[8] = {0}; + auto srcRemain = reinterpret_cast(source) + 8 * (size / 8) * 2; + ::memcpy(tmp, srcRemain, remain * 2); + CountMinMaxValue_FP16((float*)tmp, (float*)tmp, (float*)((uint8_t*)tmp + 2), 1); + max_ = ALIMAX(((__fp16*)tmp)[1], max_); + min_ = ALIMIN(((__fp16*)tmp)[0], min_); + } + reinterpret_cast<__fp16*>(minVal)[0] = min_; + reinterpret_cast<__fp16*>(maxVal)[0] = max_; + } +} +#endif static void Arm82MNNPackForMatMul_B(float* destC, const float* sourceC, size_t h, size_t l, bool transpose) { auto dest = (int16_t*)destC; 
@@ -686,6 +710,9 @@ bool Arm82Functions::init() { FUNC_PTR_ASSIGN(gInstance->MNNMatrixSub, MNNMatrixSubFP16); FUNC_PTR_ASSIGN(gInstance->MNNMatrixAdd, MNNMatrixAddFP16); FUNC_PTR_ASSIGN(gInstance->MNNStrassenMergeCFunction, ARM82StrassenMerge); +#ifdef MNN_LOW_MEMORY + FUNC_PTR_ASSIGN(gInstance->MNNDynamicUpdateConvBiasScale, origin->MNNDynamicUpdateConvBiasScale); +#endif gInstance->penalty = 2.0f; FUNC_PTR_ASSIGN(gInstance->MNNScaleAndAddBias, MNNScaleAndAddBiasFP16); FUNC_PTR_ASSIGN(gInstance->MNNGridSampleComputeCord, MNNGridSampleComputeCordFP16); @@ -702,28 +729,30 @@ bool Arm82Functions::init() { // MatMul FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul, MNNPackedMatMulFP16); FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain, MNNPackedMatMulRemainFP16); +#if defined(__aarch64__) #ifdef MNN_LOW_MEMORY + // Weight Dequant Gemm Kernels FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul_int4, MNNPackedMatMulFP16_int4); FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain_int4, MNNPackedMatMulRemainFP16_int4); FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMul_int8, MNNPackedMatMulFP16_int8); FUNC_PTR_ASSIGN(gInstance->MNNPackedMatMulRemain_int8, MNNPackedMatMulRemainFP16_int8); + // Dynamic Qaunt Helper Functions FUNC_PTR_ASSIGN(gInstance->MNNAbsMax, MNNAbsMaxFP16); FUNC_PTR_ASSIGN(gInstance->MNNQuantScale, MNNQuantScaleFP16); FUNC_PTR_ASSIGN(gInstance->MNNDynamicQuant, MNNDynamicQuantFP16); FUNC_PTR_ASSIGN(gInstance->MNNQuantSum, MNNQuantSumFP16); + FUNC_PTR_ASSIGN(gInstance->MNNCountMaxMinValue, ARM82CountMinMaxValue); + // Dynamic Quant Gemm Kernels. gInstance->supportFp16arith = origin->supportFp16arith; gInstance->supportSDot = origin->supportSDot; gInstance->supportI8mm = origin->supportI8mm; - #if defined(__aarch64__) +#endif if (gInstance->supportSDot) { - gInstance->MNNGemmHybridInt8 = MNNGemmHybridInt8FP16_sdot; - gInstance->MNNGemmHybridInt4 = MNNGemmHybridInt4FP16_sdot; + FUNC_PTR_ASSIGN(gInstance->MNNSumByAxisLForMatmul_A, MNNSumByAxisLForMatmul_A_ARM82); } if (gInstance->supportI8mm) { - gInstance->MNNGemmHybridInt8 = MNNGemmHybridInt8FP16_smmla; - gInstance->MNNGemmHybridInt4 = MNNGemmHybridInt4FP16_smmla; + FUNC_PTR_ASSIGN(gInstance->MNNSumByAxisLForMatmul_A, MNNSumByAxisLForMatmul_A_ARM86); } - #endif #endif FUNC_PTR_ASSIGN(gInstance->MNNPackC4ForMatMul_A, Arm82MNNPackForMatMul_A); FUNC_PTR_ASSIGN(gInstance->MNNGetMatMulPackMode, Arm82MNNGetMatMulPackMode); diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNCountMinMax_ARM82.S b/source/backend/arm82/asm/arm64/low_memory/MNNCountMinMax_ARM82.S new file mode 100644 index 000000000..680e6f2ac --- /dev/null +++ b/source/backend/arm82/asm/arm64/low_memory/MNNCountMinMax_ARM82.S @@ -0,0 +1,278 @@ +// +// MNNAbsMaxFP16.S +// MNN +// +// Created by MNN on 2023/10/31. 
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" +.text +.align 5 + +.macro MaxMin_4 s0, s1, s2, s3, z0, z1, z2, z3 // z0,z1:max z2,z3:min + fmax \z0\().8h, \s0\().8h, \s1\().8h + fmax \z1\().8h, \s2\().8h, \s3\().8h + fmin \z2\().8h, \s0\().8h, \s1\().8h + fmin \z3\().8h, \s2\().8h, \s3\().8h + + fmax \z0\().8h, \z0\().8h, \z1\().8h + fmin \z2\().8h, \z2\().8h, \z3\().8h +.endm + +.macro Max_6 s0, s1, s2, s3, s4, s5, z0 + fmax \s0\().8h, \s0\().8h, \s4\().8h + fmax \s1\().8h, \s1\().8h, \s5\().8h + fmax \s2\().8h, \s2\().8h, \s3\().8h + + fmax \s0\().8h, \s0\().8h, \s1\().8h + fmax \z0\().8h, \z0\().8h, \s2\().8h + + fmax \z0\().8h, \z0\().8h, \s0\().8h +.endm + +.macro Min_6 s0, s1, s2, s3, s4, s5, z0 + fmin \s0\().8h, \s0\().8h, \s4\().8h + fmin \s1\().8h, \s1\().8h, \s5\().8h + fmin \s2\().8h, \s2\().8h, \s3\().8h + + fmin \s0\().8h, \s0\().8h, \s1\().8h + fmin \z0\().8h, \z0\().8h, \s2\().8h + + fmin \z0\().8h, \z0\().8h, \s0\().8h +.endm + +.macro Max_5 s0, s1, s2, s3, s4, z0 + fmax \s0\().8h, \s0\().8h, \s3\().8h + fmax \s1\().8h, \s1\().8h, \s4\().8h + fmax \z0\().8h, \s2\().8h, \z0\().8h + + fmax \s0\().8h, \s0\().8h, \s1\().8h + fmax \z0\().8h, \z0\().8h, \s0\().8h + +.endm + +.macro Min_5 s0, s1, s2, s3, s4, z0 + fmin \s0\().8h, \s0\().8h, \s3\().8h + fmin \s1\().8h, \s1\().8h, \s4\().8h + fmin \z0\().8h, \s2\().8h, \z0\().8h + + fmin \s0\().8h, \s0\().8h, \s1\().8h + fmin \z0\().8h, \z0\().8h, \s0\().8h +.endm + +.macro Max_4 s0, s1, s2, s3, z0 + fmax \s0\().8h, \s0\().8h, \s2\().8h + fmax \s1\().8h, \s1\().8h, \s3\().8h + fmax \z0\().8h, \s0\().8h, \z0\().8h + fmax \z0\().8h, \z0\().8h, \s1\().8h + +.endm + +.macro Min_4 s0, s1, s2, s3, z0 + fmin \s0\().8h, \s0\().8h, \s2\().8h + fmin \s1\().8h, \s1\().8h, \s3\().8h + fmin \z0\().8h, \s0\().8h, \z0\().8h + fmin \z0\().8h, \z0\().8h, \s1\().8h +.endm + +.macro Max_3 s0, s1, s2, z0 + fmax \s0\().8h, \s0\().8h, \s2\().8h + fmax \z0\().8h, \s1\().8h, \z0\().8h + fmax \z0\().8h, \s0\().8h, \z0\().8h + +.endm + +.macro Min_3 s0, s1, s2, z0 + fmin \s0\().8h, \s0\().8h, \s2\().8h + fmin \z0\().8h, \s1\().8h, \z0\().8h + fmin \z0\().8h, \s0\().8h, \z0\().8h +.endm + +.macro Reduce_Max_Min s0, s1 + // 8->4 + fmaxp \s0\().8h, \s0\().8h, \s0\().8h + fminp \s1\().8h, \s1\().8h, \s1\().8h + // 4->2 + fmaxp \s0\().8h, \s0\().8h, \s0\().8h + fminp \s1\().8h, \s1\().8h, \s1\().8h + // 2->1 + fmaxp \s0\().8h, \s0\().8h, \s0\().8h + fminp \s1\().8h, \s1\().8h, \s1\().8h +.endm + + +//void CountMinMaxValue_FP16(float* source, float* minVal, float* maxVal, size_t sizeQuad) +asm_function CountMinMaxValue_FP16 + +// x0: source, x1:minVal, x2:maxVal, x3:size +stp d14, d15, [sp, #(-16 * 4)]! 
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +Start: +ld1 {v31.8h}, [x0], #16 +sub x3, x3, #1 +mov v30.8h, v31.8h // v30:min v31:max + + +TILE_24: +cmp x3, #24 +blt TILE_20 + +ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 +ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 +ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], #64 +ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x0], #64 +ld1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0], #64 +ld1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], #64 + +MaxMin_4 v0, v1, v2, v3, v24, v25, v26, v27 // v24:max, v26:min +MaxMin_4 v4, v5, v6, v7, v28, v29, v0, v1 // v28:max, v0:min +MaxMin_4 v8, v9, v10, v11, v2, v3, v25, v27 // v2:max, v25:min +MaxMin_4 v12, v13, v14, v15, v4, v5, v6, v7 // v4:max, v6:min +MaxMin_4 v16, v17, v18, v19, v1, v3, v10, v27 // v1:max, v10:min +MaxMin_4 v20, v21, v22, v23, v12, v13, v14, v15 // v12:max, v14:min + +Max_6 v1, v2, v4, v12, v24, v28, v31 +Min_6 v0, v6, v10, v14, v26, v25, v30 + +sub x3, x3, #24 +b TILE_24 + +TILE_20: +cmp x3, #20 +blt TILE_16 + +ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 +ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 +ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], #64 +ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x0], #64 +ld1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0], #64 + +MaxMin_4 v0, v1, v2, v3, v24, v25, v26, v27 // v24:max, v26:min +MaxMin_4 v4, v5, v6, v7, v20, v21, v22, v23 // v20:max, v22:min +MaxMin_4 v8, v9, v10, v11, v0, v1, v2, v3 // v0:max, v2:min +MaxMin_4 v12, v13, v14, v15, v4, v5, v6, v7 // v4:max, v6:min +MaxMin_4 v16, v17, v18, v19, v25, v27, v21, v23 // v25:max, v21:min + +Max_5 v0, v4, v20, v25, v24, v31 +Min_5 v2, v6, v21, v22, v26, v30 + +sub x3, x3, #20 +b TILE_20 + +TILE_16: +cmp x3, #16 +blt TILE_12 + +ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 +ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 +ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], #64 +ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x0], #64 + +MaxMin_4 v0, v1, v2, v3, v24, v25, v26, v27 // v24:max, v26:min +MaxMin_4 v4, v5, v6, v7, v20, v21, v22, v23 // v20:max, v22:min +MaxMin_4 v8, v9, v10, v11, v16, v17, v18, v19 // v16:max, v18:min +MaxMin_4 v12, v13, v14, v15, v0, v1, v2, v3 // v0:max, v2:min + +Max_4 v0, v16, v20, v24, v31 +Min_4 v2, v18, v22, v26, v30 + +sub x3, x3, #16 +b TILE_16 + +TILE_12: +cmp x3, #12 +blt TILE_8 + +ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 +ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 +ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], #64 + +MaxMin_4 v0, v1, v2, v3, v24, v25, v26, v27 // v24:max, v26:min +MaxMin_4 v4, v5, v6, v7, v20, v21, v22, v23 // v20:max, v22:min +MaxMin_4 v8, v9, v10, v11, v16, v17, v18, v19 // v16:max, v18:min + +Max_3 v16, v20, v24, v31 +Min_3 v18, v22, v26, v30 + +sub x3, x3, #12 +b TILE_12 + +TILE_8: +cmp x3, #8 +blt TILE_4 + +ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 +ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 + +MaxMin_4 v0, v1, v2, v3, v24, v25, v26, v27 // v24:max, v26:min +MaxMin_4 v4, v5, v6, v7, v20, v21, v22, v23 // v20:max, v22:min + +fmax v24.8h, v24.8h, v20.8h +fmin v26.8h, v26.8h, v22.8h +fmax v31.8h, v31.8h, v24.8h +fmin v30.8h, v30.8h, v26.8h + +sub x3, x3, #8 +b TILE_8 + +TILE_4: +cmp x3, #4 +blt TILE_2 + +ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 + +MaxMin_4 v0, v1, v2, v3, v24, v25, v26, v27 // v24:max, v26:min + +fmax v31.8h, v31.8h, v24.8h +fmin v30.8h, v30.8h, v26.8h + +sub x3, x3, #4 +b TILE_4 + +TILE_2: +cmp x3, #2 +blt TILE_1 + +ld1 {v0.8h, v1.8h}, [x0], #32 + +fmax v2.8h, v0.8h, v1.8h +fmin v3.8h, v0.8h, v1.8h + +fmax v31.8h, v31.8h, v2.8h +fmin v30.8h, 
v30.8h, v3.8h + +sub x3, x3, #2 +b TILE_2 + +TILE_1: +cmp x3, #1 +blt End + +ld1 {v0.8h}, [x0], #16 + +fmax v31.8h, v31.8h, v0.8h +fmin v30.8h, v30.8h, v0.8h + +sub x3, x3, #1 +b TILE_1 + +End: +Reduce_Max_Min v31, v30 +//fcvtl v30.4s, v30.4h +//fcvtl v31.4s, v31.4h +st1 {v30.h}[0], [x1] +st1 {v31.h}[1], [x2] + +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret + +#endif diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuanInput_ARM82.S b/source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuanInput_ARM82.S new file mode 100644 index 000000000..22919922f --- /dev/null +++ b/source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuanInput_ARM82.S @@ -0,0 +1,268 @@ +// +// DynamicQuanInput_ARM82.S +// MNN +// +// Created by MNN on 2019/01/22. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +.macro SCALE_TO_FLOAT_8 s0, s1, s2, s3, s4, s5, s6, s7, z0 + fmul \s0\().8h, \s0\().8h, \z0\().8h + fmul \s1\().8h, \s1\().8h, \z0\().8h + fmul \s2\().8h, \s2\().8h, \z0\().8h + fmul \s3\().8h, \s3\().8h, \z0\().8h + fmul \s4\().8h, \s4\().8h, \z0\().8h + fmul \s5\().8h, \s5\().8h, \z0\().8h + fmul \s6\().8h, \s6\().8h, \z0\().8h + fmul \s7\().8h, \s7\().8h, \z0\().8h +.endm + +.macro SCALE_TO_FLOAT_4 s0, s1, s2, s3, z0 + fmul \s0\().8h, \s0\().8h, \z0\().8h + fmul \s1\().8h, \s1\().8h, \z0\().8h + fmul \s2\().8h, \s2\().8h, \z0\().8h + fmul \s3\().8h, \s3\().8h, \z0\().8h +.endm + +.macro ADD_ZEROPOINT_8 s0, s1, s2, s3, s4, s5, s6, s7, z0 + fadd \s0\().8h, \s0\().8h, \z0\().8h + fadd \s1\().8h, \s1\().8h, \z0\().8h + fadd \s2\().8h, \s2\().8h, \z0\().8h + fadd \s3\().8h, \s3\().8h, \z0\().8h + fadd \s4\().8h, \s4\().8h, \z0\().8h + fadd \s5\().8h, \s5\().8h, \z0\().8h + fadd \s6\().8h, \s6\().8h, \z0\().8h + fadd \s7\().8h, \s7\().8h, \z0\().8h +.endm + +.macro ADD_ZEROPOINT_4 s0, s1, s2, s3, z0 + fadd \s0\().8h, \s0\().8h, \z0\().8h + fadd \s1\().8h, \s1\().8h, \z0\().8h + fadd \s2\().8h, \s2\().8h, \z0\().8h + fadd \s3\().8h, \s3\().8h, \z0\().8h +.endm + +.macro FLOAT_TO_INT_8 s0, s1, s2, s3, s4, s5, s6, s7 + fcvtas \s0\().8h, \s0\().8h + fcvtas \s1\().8h, \s1\().8h + fcvtas \s2\().8h, \s2\().8h + fcvtas \s3\().8h, \s3\().8h + fcvtas \s4\().8h, \s4\().8h + fcvtas \s5\().8h, \s5\().8h + fcvtas \s6\().8h, \s6\().8h + fcvtas \s7\().8h, \s7\().8h +.endm + +.macro FLOAT_TO_INT_4 s0, s1, s2, s3 + fcvtas \s0\().8h, \s0\().8h + fcvtas \s1\().8h, \s1\().8h + fcvtas \s2\().8h, \s2\().8h + fcvtas \s3\().8h, \s3\().8h +.endm + +.macro INT16_TO_INT8_8 s0, s1, s2, s3, s4, s5, s6, s7, d0, d1, d2, d3 + sqxtn \d0\().8b, \s0\().8h + sqxtn2 \d0\().16b, \s1\().8h + sqxtn \d1\().8b, \s2\().8h + sqxtn2 \d1\().16b, \s3\().8h + sqxtn \d2\().8b, \s4\().8h + sqxtn2 \d2\().16b, \s5\().8h + sqxtn \d3\().8b, \s6\().8h + sqxtn2 \d3\().16b, \s7\().8h +.endm + +.macro INT16_TO_INT8_4 s0, s1, s2, s3, d0, d1 + sqxtn \d0\().8b, \s0\().8h + sqxtn2 \d0\().16b, \s1\().8h + sqxtn \d1\().8b, \s2\().8h + sqxtn2 \d1\().16b, \s3\().8h +.endm + + +/* +Note: Only used in dynamic quant,so do not need compare min max! + */ +asm_function DynamicQuanInput_ARM82 +//void DynamicQuanInput_ARM82(const float* src, int8_t* dst, size_t sizeQuad, float* scale, size_t aMin, size_t aMax, size_t zeroPoint); +//x0:src, x1:dst, x2:sizeQuad, x3:scale, x4:aMin, x5:aMax, x6:zeroPoint +stp d14, d15, [sp, #-64]! 
+stp d12, d13, [sp, #16] +stp d10, d11, [sp, #32] +stp d8, d9, [sp, #48] + +ld1 {v29.s}[0], [x3] // Load scale +// copy zero point +dup v30.4s, w6 +fcvtn v31.4h, v29.4s +scvtf v30.4s, v30.4s + +dup v31.8h, v31.h[0] +fcvtn v30.4h, v30.4s +dup v30.8h, v30.h[0] + +FL28: +cmp x2, #28 +blt FL24 + +FLLoop28: +ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 +ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 +ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], #64 +ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x0], #64 +ld1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0], #64 +ld1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], #64 +ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x0], #64 + +SCALE_TO_FLOAT_8 v0, v1, v2, v3, v4, v5, v6, v7, v31 +SCALE_TO_FLOAT_8 v8, v9, v10, v11, v12, v13, v14, v15, v31 +SCALE_TO_FLOAT_8 v16, v17, v18, v19, v20, v21, v22, v23, v31 +SCALE_TO_FLOAT_4 v24, v25, v26, v27, v31 +sub x2, x2, #28 +ADD_ZEROPOINT_8 v0, v1, v2, v3, v4, v5, v6, v7, v30 +ADD_ZEROPOINT_8 v8, v9, v10, v11, v12, v13, v14, v15, v30 +ADD_ZEROPOINT_8 v16, v17, v18, v19, v20, v21, v22, v23, v30 +ADD_ZEROPOINT_4 v24, v25, v26, v27, v30 + +FLOAT_TO_INT_8 v0, v1, v2, v3, v4, v5, v6, v7 +FLOAT_TO_INT_8 v8, v9, v10, v11, v12, v13, v14, v15 +FLOAT_TO_INT_8 v16, v17, v18, v19, v20, v21, v22, v23 +FLOAT_TO_INT_4 v24, v25, v26, v27 +cmp x2, #28 +INT16_TO_INT8_8 v0, v1, v2, v3, v4, v5, v6, v7, v28, v29, v0, v1 +INT16_TO_INT8_8 v8, v9, v10, v11, v12, v13, v14, v15, v2, v3, v4, v5 +st1 {v28.16b, v29.16b}, [x1], #32 +INT16_TO_INT8_8 v16, v17, v18, v19, v20, v21, v22, v23, v6, v7, v8, v9 +st1 {v0.16b, v1.16b}, [x1], #32 +INT16_TO_INT8_4 v24, v25, v26, v27, v10, v11 + +st1 {v2.16b, v3.16b, v4.16b, v5.16b}, [x1], #64 +st1 {v6.16b, v7.16b, v8.16b, v9.16b}, [x1], #64 +st1 {v10.16b, v11.16b}, [x1], #32 + +bge FLLoop28 + +FL24: +cmp x2, #24 +blt FL16 + +FLLoop24: +ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 +ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 +ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], #64 +ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x0], #64 +ld1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0], #64 +ld1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], #64 + +SCALE_TO_FLOAT_8 v0, v1, v2, v3, v4, v5, v6, v7, v31 +SCALE_TO_FLOAT_8 v8, v9, v10, v11, v12, v13, v14, v15, v31 +SCALE_TO_FLOAT_8 v16, v17, v18, v19, v20, v21, v22, v23, v31 +sub x2, x2, #24 +ADD_ZEROPOINT_8 v0, v1, v2, v3, v4, v5, v6, v7, v30 +ADD_ZEROPOINT_8 v8, v9, v10, v11, v12, v13, v14, v15, v30 +ADD_ZEROPOINT_8 v16, v17, v18, v19, v20, v21, v22, v23, v30 + +FLOAT_TO_INT_8 v0, v1, v2, v3, v4, v5, v6, v7 +FLOAT_TO_INT_8 v8, v9, v10, v11, v12, v13, v14, v15 +FLOAT_TO_INT_8 v16, v17, v18, v19, v20, v21, v22, v23 +cmp x2, #24 +INT16_TO_INT8_8 v0, v1, v2, v3, v4, v5, v6, v7, v24, v25, v26, v27 +INT16_TO_INT8_8 v8, v9, v10, v11, v12, v13, v14, v15, v0, v1, v2, v3 +INT16_TO_INT8_8 v16, v17, v18, v19, v20, v21, v22, v23, v4, v5, v6, v7 + +st1 {v24.16b, v25.16b, v26.16b, v27.16b}, [x1], #64 +st1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64 +st1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64 + +bge FLLoop24 + +FL16: +cmp x2, #16 +blt FL8 + +FLLoop16: +ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 +ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 +ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], #64 +ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x0], #64 + +SCALE_TO_FLOAT_8 v0, v1, v2, v3, v4, v5, v6, v7, v31 +SCALE_TO_FLOAT_8 v8, v9, v10, v11, v12, v13, v14, v15, v31 +sub x2, x2, #16 +ADD_ZEROPOINT_8 v0, v1, v2, v3, v4, v5, v6, v7, v30 +ADD_ZEROPOINT_8 v8, v9, v10, v11, v12, v13, v14, v15, v30 + +FLOAT_TO_INT_8 v0, v1, v2, v3, v4, v5, v6, v7 +FLOAT_TO_INT_8 
v8, v9, v10, v11, v12, v13, v14, v15 +cmp x2, #16 +INT16_TO_INT8_8 v0, v1, v2, v3, v4, v5, v6, v7, v24, v25, v26, v27 +INT16_TO_INT8_8 v8, v9, v10, v11, v12, v13, v14, v15, v0, v1, v2, v3 + +st1 {v24.16b, v25.16b, v26.16b, v27.16b}, [x1], #64 +st1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64 + +bge FLLoop16 + +FL8: +cmp x2, #8 +blt FL4 + +FLLoop8: +ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 +ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 +sub x2, x2, #8 +SCALE_TO_FLOAT_8 v0, v1, v2, v3, v4, v5, v6, v7, v31 +ADD_ZEROPOINT_8 v0, v1, v2, v3, v4, v5, v6, v7, v30 +cmp x2, #8 +FLOAT_TO_INT_8 v0, v1, v2, v3, v4, v5, v6, v7 +INT16_TO_INT8_8 v0, v1, v2, v3, v4, v5, v6, v7, v24, v25, v26, v27 +st1 {v24.16b, v25.16b, v26.16b, v27.16b}, [x1], #64 + +bge FLLoop8 + +FL4: +cmp x2, #4 +blt FL1 + +FLLoop4: +ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 +sub x2, x2, #4 +SCALE_TO_FLOAT_4 v0, v1, v2, v3, v31 +ADD_ZEROPOINT_4 v0, v1, v2, v3, v30 +cmp x2, #4 +FLOAT_TO_INT_4 v0, v1, v2, v3 +INT16_TO_INT8_4 v0, v1, v2, v3, v24, v25 +st1 {v24.16b, v25.16b}, [x1], #32 + +bge FLLoop4 + +FL1: +cmp x2, #0 +beq FLEnd + +FLLoop1: +ld1 {v0.8h}, [x0], #16 +fmul v0.8h, v0.8h, v31.8h +fadd v0.8h, v0.8h, v30.8h + +fcvtas v0.8h, v0.8h +sqxtn v0.8b, v0.8h + +st1 {v0.d}[0], [x1], #8 + +subs x2, x2, #1 +bne FLLoop1 + +FLEnd: +ldp d8, d9, [sp, #48] +ldp d10, d11, [sp, #32] +ldp d12, d13, [sp, #16] +ldp d14, d15, [sp], #64 +ret +#endif diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuantAndReorder_ARM82.S b/source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuantAndReorder_ARM82.S new file mode 100644 index 000000000..44e3568f1 --- /dev/null +++ b/source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuantAndReorder_ARM82.S @@ -0,0 +1,433 @@ +// +// DynamicQuanInput_ARM82.S +// MNN +// +// Created by MNN on 2019/01/22. 
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +.macro SCALE_TO_FLOAT_8 s0, s1, s2, s3, s4, s5, s6, s7, z0 + fmul \s0\().8h, \s0\().8h, \z0\().8h + fmul \s1\().8h, \s1\().8h, \z0\().8h + fmul \s2\().8h, \s2\().8h, \z0\().8h + fmul \s3\().8h, \s3\().8h, \z0\().8h + fmul \s4\().8h, \s4\().8h, \z0\().8h + fmul \s5\().8h, \s5\().8h, \z0\().8h + fmul \s6\().8h, \s6\().8h, \z0\().8h + fmul \s7\().8h, \s7\().8h, \z0\().8h +.endm + +.macro SCALE_TO_FLOAT_4 s0, s1, s2, s3, z0 + fmul \s0\().8h, \s0\().8h, \z0\().8h + fmul \s1\().8h, \s1\().8h, \z0\().8h + fmul \s2\().8h, \s2\().8h, \z0\().8h + fmul \s3\().8h, \s3\().8h, \z0\().8h +.endm + +.macro ADD_ZEROPOINT_8 s0, s1, s2, s3, s4, s5, s6, s7, z0 + fadd \s0\().8h, \s0\().8h, \z0\().8h + fadd \s1\().8h, \s1\().8h, \z0\().8h + fadd \s2\().8h, \s2\().8h, \z0\().8h + fadd \s3\().8h, \s3\().8h, \z0\().8h + fadd \s4\().8h, \s4\().8h, \z0\().8h + fadd \s5\().8h, \s5\().8h, \z0\().8h + fadd \s6\().8h, \s6\().8h, \z0\().8h + fadd \s7\().8h, \s7\().8h, \z0\().8h +.endm + +.macro ADD_ZEROPOINT_4 s0, s1, s2, s3, z0 + fadd \s0\().8h, \s0\().8h, \z0\().8h + fadd \s1\().8h, \s1\().8h, \z0\().8h + fadd \s2\().8h, \s2\().8h, \z0\().8h + fadd \s3\().8h, \s3\().8h, \z0\().8h +.endm + +.macro FLOAT_TO_INT_8 s0, s1, s2, s3, s4, s5, s6, s7 + fcvtas \s0\().8h, \s0\().8h + fcvtas \s1\().8h, \s1\().8h + fcvtas \s2\().8h, \s2\().8h + fcvtas \s3\().8h, \s3\().8h + fcvtas \s4\().8h, \s4\().8h + fcvtas \s5\().8h, \s5\().8h + fcvtas \s6\().8h, \s6\().8h + fcvtas \s7\().8h, \s7\().8h +.endm + +.macro FLOAT_TO_INT_4 s0, s1, s2, s3 + fcvtas \s0\().8h, \s0\().8h + fcvtas \s1\().8h, \s1\().8h + fcvtas \s2\().8h, \s2\().8h + fcvtas \s3\().8h, \s3\().8h +.endm + +.macro INT16_TO_INT8_8 s0, s1, s2, s3, s4, s5, s6, s7, d0, d1, d2, d3 + sqxtn \d0\().8b, \s0\().8h + sqxtn2 \d0\().16b, \s1\().8h + sqxtn \d1\().8b, \s2\().8h + sqxtn2 \d1\().16b, \s3\().8h + sqxtn \d2\().8b, \s4\().8h + sqxtn2 \d2\().16b, \s5\().8h + sqxtn \d3\().8b, \s6\().8h + sqxtn2 \d3\().16b, \s7\().8h +.endm + +.macro INT16_TO_INT8_4 s0, s1, s2, s3, d0, d1 + sqxtn \d0\().8b, \s0\().8h + sqxtn2 \d0\().16b, \s1\().8h + sqxtn \d1\().8b, \s2\().8h + sqxtn2 \d1\().16b, \s3\().8h +.endm + + +/* +Note: Only used in dynamic quant,so do not need compare min max! +1. Quant Float16 to Int8; +2. Pack data from C8 to C4 for Im2Col fixed unit=4 + */ +asm_function DynamicQuanInputAndReorder_ARM82 +//void DynamicQuanInputAndReorder_ARM82(const float* src, int8_t* dst, size_t planeSize, float* scale, size_t aMin, size_t aMax, size_t zeroPoint, size_t ocQuad, size_t offset); +//x0:src, x1:dst, x2:planeSize, x3:scale, x4:aMin, x5:aMax, x6:zeroPoint, x7:ocQuad, x8:offset +ldr x8, [sp, #0] // plane*4 +stp d14, d15, [sp, #-64]! 
+stp d12, d13, [sp, #16] +stp d10, d11, [sp, #32] +stp d8, d9, [sp, #48] + +ld1 {v29.s}[0], [x3] // Load scale +// copy zero point +dup v30.4s, w6 +fcvtn v31.4h, v29.4s +scvtf v30.4s, v30.4s + +add x13, x8, x8 + +dup v31.8h, v31.h[0] +fcvtn v30.4h, v30.4s +dup v30.8h, v30.h[0] + +mov x9, x1 // first N*4 +add x10, x1, x8 // seconde N*4 +mov x14, x2 // Reserve planeSize + +Outter_Channel_Loop: +cmp x7, #1 +blt End + +mov x11, x9 // flag address +mov x12, x10 + +FL28: // N loop +cmp x2, #28 +blt FL20 + +FLLoop28: // N=28 + +ChannleLoop_28: +ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 +ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 +ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], #64 +ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x0], #64 +ld1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0], #64 +ld1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], #64 +ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x0], #64 + +SCALE_TO_FLOAT_8 v0, v1, v2, v3, v4, v5, v6, v7, v31 +SCALE_TO_FLOAT_8 v8, v9, v10, v11, v12, v13, v14, v15, v31 +SCALE_TO_FLOAT_8 v16, v17, v18, v19, v20, v21, v22, v23, v31 +SCALE_TO_FLOAT_4 v24, v25, v26, v27, v31 +sub x2, x2, #28 +ADD_ZEROPOINT_8 v0, v1, v2, v3, v4, v5, v6, v7, v30 +ADD_ZEROPOINT_8 v8, v9, v10, v11, v12, v13, v14, v15, v30 +ADD_ZEROPOINT_8 v16, v17, v18, v19, v20, v21, v22, v23, v30 +ADD_ZEROPOINT_4 v24, v25, v26, v27, v30 + +FLOAT_TO_INT_8 v0, v1, v2, v3, v4, v5, v6, v7 +FLOAT_TO_INT_8 v8, v9, v10, v11, v12, v13, v14, v15 +FLOAT_TO_INT_8 v16, v17, v18, v19, v20, v21, v22, v23 +FLOAT_TO_INT_4 v24, v25, v26, v27 +cmp x2, #28 +INT16_TO_INT8_8 v0, v1, v2, v3, v4, v5, v6, v7, v28, v29, v0, v1 +INT16_TO_INT8_8 v8, v9, v10, v11, v12, v13, v14, v15, v2, v3, v4, v5 +//st1 {v28.16b, v29.16b}, [x1], #32 +INT16_TO_INT8_8 v16, v17, v18, v19, v20, v21, v22, v23, v6, v7, v8, v9 +//st1 {v0.16b, v1.16b}, [x1], #32 +INT16_TO_INT8_4 v24, v25, v26, v27, v10, v11 + +// Reorder c8->c4, 0,..27 means plane index +uzp1 v12.4s, v28.4s, v29.4s // 0 0 1 1 x 2 2 3 3 -> 0 1 2 3 +uzp1 v13.4s, v0.4s, v1.4s // 4 4 5 5 x 6 6 7 7 -> 4 5 6 7 +uzp1 v14.4s, v2.4s, v3.4s // 8 8 9 9 x 10 10 11 11 -> 8 9 10 11 +uzp1 v15.4s, v4.4s, v5.4s // 12 12 13 13 x 14 14 15 15 -> 12 13 14 15 +uzp1 v16.4s, v6.4s, v7.4s // 16 16 17 17 x 18 18 19 19 -> 16 17 18 19 +uzp1 v17.4s, v8.4s, v9.4s // 20 20 21 21 x 22 22 23 23 -> 20 21 22 23 +uzp1 v18.4s, v10.4s, v11.4s // 24 24 25 25 x 26 26 27 27 -> 24 25 26 27 +uzp2 v19.4s, v28.4s, v29.4s +uzp2 v20.4s, v0.4s, v1.4s +uzp2 v21.4s, v2.4s, v3.4s +uzp2 v22.4s, v4.4s, v5.4s +uzp2 v23.4s, v6.4s, v7.4s +uzp2 v24.4s, v8.4s, v9.4s +uzp2 v25.4s, v10.4s, v11.4s + +st1 {v12.16b, v13.16b, v14.16b, v15.16b}, [x11], #64 +st1 {v16.16b, v17.16b, v18.16b}, [x11], #48 +st1 {v19.16b, v20.16b, v21.16b, v22.16b}, [x12], #64 +st1 {v23.16b, v24.16b, v25.16b}, [x12], #48 + +bge FLLoop28 + +FL24: +cmp x2, #24 +blt FL20 + +FLLoop24: +ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 +ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 +ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], #64 +ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x0], #64 +ld1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0], #64 +ld1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], #64 + +SCALE_TO_FLOAT_8 v0, v1, v2, v3, v4, v5, v6, v7, v31 +SCALE_TO_FLOAT_8 v8, v9, v10, v11, v12, v13, v14, v15, v31 +SCALE_TO_FLOAT_8 v16, v17, v18, v19, v20, v21, v22, v23, v31 +sub x2, x2, #24 +ADD_ZEROPOINT_8 v0, v1, v2, v3, v4, v5, v6, v7, v30 +ADD_ZEROPOINT_8 v8, v9, v10, v11, v12, v13, v14, v15, v30 +ADD_ZEROPOINT_8 v16, v17, v18, v19, v20, v21, v22, v23, v30 + +FLOAT_TO_INT_8 v0, v1, v2, v3, v4, v5, v6, v7 +FLOAT_TO_INT_8 v8, v9, 
v10, v11, v12, v13, v14, v15 +FLOAT_TO_INT_8 v16, v17, v18, v19, v20, v21, v22, v23 +cmp x2, #24 +INT16_TO_INT8_8 v0, v1, v2, v3, v4, v5, v6, v7, v24, v25, v26, v27 +INT16_TO_INT8_8 v8, v9, v10, v11, v12, v13, v14, v15, v28, v29, v0, v1 +INT16_TO_INT8_8 v16, v17, v18, v19, v20, v21, v22, v23, v2, v3, v4, v5 + +// Reorder c8->c4 +uzp1 v6.4s, v24.4s, v25.4s // 0 0 1 1 x 2 2 3 3 -> 0 1 2 3 +uzp1 v7.4s, v26.4s, v27.4s // 4 4 5 5 x 6 6 7 7 -> 4 5 6 7 +uzp1 v8.4s, v28.4s, v29.4s // 8 8 9 9 x 10 10 11 11 -> 8 9 10 11 +uzp1 v9.4s, v0.4s, v1.4s // 12 12 13 13 x 14 14 15 15 -> 12 13 14 15 +uzp1 v10.4s, v2.4s, v3.4s // 16 16 17 17 x 18 18 19 19 -> 16 17 18 19 +uzp1 v11.4s, v4.4s, v5.4s // 20 20 21 21 x 22 22 23 23 -> 20 21 22 23 +uzp2 v12.4s, v24.4s, v25.4s +uzp2 v13.4s, v26.4s, v27.4s +uzp2 v14.4s, v28.4s, v29.4s +uzp2 v15.4s, v0.4s, v1.4s +uzp2 v16.4s, v2.4s, v3.4s +uzp2 v17.4s, v4.4s, v5.4s + +st1 {v6.16b, v7.16b, v8.16b, v9.16b}, [x11], #64 +st1 {v10.16b, v11.16b}, [x11], #32 +st1 {v12.16b, v13.16b, v14.16b, v15.16b}, [x12], #64 +st1 {v16.16b, v17.16b}, [x12], #32 + +bge FLLoop24 + +FL20: +cmp x2, #20 +blt FL12 + +FLLoop20: +ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 +ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 +ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], #64 +ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x0], #64 +ld1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0], #64 + +SCALE_TO_FLOAT_8 v0, v1, v2, v3, v4, v5, v6, v7, v31 +SCALE_TO_FLOAT_8 v8, v9, v10, v11, v12, v13, v14, v15, v31 +SCALE_TO_FLOAT_4 v16, v17, v18, v19, v31 +sub x2, x2, #20 +ADD_ZEROPOINT_8 v0, v1, v2, v3, v4, v5, v6, v7, v30 +ADD_ZEROPOINT_8 v8, v9, v10, v11, v12, v13, v14, v15, v30 +ADD_ZEROPOINT_4 v16, v17, v18, v19, v30 + +FLOAT_TO_INT_8 v0, v1, v2, v3, v4, v5, v6, v7 +FLOAT_TO_INT_8 v8, v9, v10, v11, v12, v13, v14, v15 +FLOAT_TO_INT_4 v16, v17, v18, v19 +cmp x2, #20 +INT16_TO_INT8_8 v0, v1, v2, v3, v4, v5, v6, v7, v24, v25, v26, v27 +INT16_TO_INT8_8 v8, v9, v10, v11, v12, v13, v14, v15, v21, v22, v23, v28 +INT16_TO_INT8_4 v16, v17, v18, v19, v29, v20 + +// Reorder c8->c4 +uzp1 v0.4s, v24.4s, v25.4s // 0 0 1 1 x 2 2 3 3 -> 0 1 2 3 +uzp1 v1.4s, v26.4s, v27.4s // 4 4 5 5 x 6 6 7 7 -> 4 5 6 7 +uzp1 v2.4s, v21.4s, v22.4s // 8 8 9 9 x 10 10 11 11 -> 8 9 10 11 +uzp1 v3.4s, v23.4s, v28.4s // 12 12 13 13 x 14 14 15 15 -> 12 13 14 15 +uzp1 v4.4s, v29.4s, v20.4s // 16 16 17 17 x 18 18 19 19 -> 16 17 18 19 +uzp2 v5.4s, v24.4s, v25.4s +uzp2 v6.4s, v26.4s, v27.4s +uzp2 v7.4s, v21.4s, v22.4s +uzp2 v8.4s, v23.4s, v28.4s +uzp2 v9.4s, v29.4s, v20.4s + +st1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x11], #64 +st1 {v4.16b}, [x11], #16 +st1 {v5.16b, v6.16b, v7.16b, v8.16b}, [x12], #64 +st1 {v9.16b}, [x12], #16 + +bge FLLoop20 + +FL16: +cmp x2, #16 +blt FL12 + +FLLoop16: +ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 +ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 +ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], #64 +ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x0], #64 + +SCALE_TO_FLOAT_8 v0, v1, v2, v3, v4, v5, v6, v7, v31 +SCALE_TO_FLOAT_8 v8, v9, v10, v11, v12, v13, v14, v15, v31 +sub x2, x2, #16 +ADD_ZEROPOINT_8 v0, v1, v2, v3, v4, v5, v6, v7, v30 +ADD_ZEROPOINT_8 v8, v9, v10, v11, v12, v13, v14, v15, v30 + +FLOAT_TO_INT_8 v0, v1, v2, v3, v4, v5, v6, v7 +FLOAT_TO_INT_8 v8, v9, v10, v11, v12, v13, v14, v15 +cmp x2, #16 +INT16_TO_INT8_8 v0, v1, v2, v3, v4, v5, v6, v7, v24, v25, v26, v27 +INT16_TO_INT8_8 v8, v9, v10, v11, v12, v13, v14, v15, v20, v21, v22, v23 + +// Reorder c8->c4 +uzp1 v16.4s, v24.4s, v25.4s // 0 0 1 1 x 2 2 3 3 -> 0 1 2 3 +uzp1 v17.4s, v26.4s, v27.4s // 4 4 5 5 x 
6 6 7 7 -> 4 5 6 7 +uzp1 v18.4s, v20.4s, v21.4s // 8 8 9 9 x 10 10 11 11 -> 8 9 10 11 +uzp1 v19.4s, v22.4s, v23.4s // 12 12 13 13 x 14 14 15 15 -> 12 13 14 15 + +uzp2 v0.4s, v24.4s, v25.4s +uzp2 v1.4s, v26.4s, v27.4s +uzp2 v2.4s, v20.4s, v21.4s +uzp2 v3.4s, v22.4s, v23.4s + +st1 {v16.16b, v17.16b, v18.16b, v19.16b}, [x11], #64 +st1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x12], #64 + +bge FLLoop16 + +FL12: +cmp x2, #12 +blt FL8 + +FLLoop12: +ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 +ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 +ld1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], #64 + +SCALE_TO_FLOAT_8 v0, v1, v2, v3, v4, v5, v6, v7, v31 +SCALE_TO_FLOAT_4 v8, v9, v10, v11, v31 +sub x2, x2, #12 +ADD_ZEROPOINT_8 v0, v1, v2, v3, v4, v5, v6, v7, v30 +ADD_ZEROPOINT_4 v8, v9, v10, v11, v30 + +FLOAT_TO_INT_8 v0, v1, v2, v3, v4, v5, v6, v7 +FLOAT_TO_INT_4 v8, v9, v10, v11 +cmp x2, #12 +INT16_TO_INT8_8 v0, v1, v2, v3, v4, v5, v6, v7, v24, v25, v26, v27 +INT16_TO_INT8_4 v8, v9, v10, v11, v20, v21 + +// Reorder c8->c4 +uzp1 v12.4s, v24.4s, v25.4s // 0 0 1 1 x 2 2 3 3 -> 0 1 2 3 +uzp2 v16.4s, v24.4s, v25.4s +uzp1 v13.4s, v26.4s, v27.4s // 4 4 5 5 x 6 6 7 7 -> 4 5 6 7 +uzp2 v17.4s, v26.4s, v27.4s +uzp1 v14.4s, v20.4s, v21.4s // 8 8 9 9 x 10 10 11 11 -> 8 9 10 11 +uzp2 v18.4s, v20.4s, v21.4s + +st1 {v12.16b, v13.16b, v14.16b}, [x11], #48 +st1 {v16.16b, v17.16b, v18.16b}, [x12], #48 + +bge FLLoop12 + +FL8: +cmp x2, #8 +blt FL4 + +FLLoop8: +ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 +ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 +sub x2, x2, #8 +SCALE_TO_FLOAT_8 v0, v1, v2, v3, v4, v5, v6, v7, v31 +ADD_ZEROPOINT_8 v0, v1, v2, v3, v4, v5, v6, v7, v30 +cmp x2, #8 +FLOAT_TO_INT_8 v0, v1, v2, v3, v4, v5, v6, v7 +INT16_TO_INT8_8 v0, v1, v2, v3, v4, v5, v6, v7, v24, v25, v26, v27 + +// Reorder c8->c4 +uzp1 v12.4s, v24.4s, v25.4s // 0 0 1 1 x 2 2 3 3 -> 0 1 2 3 +uzp2 v19.4s, v24.4s, v25.4s +uzp1 v13.4s, v26.4s, v27.4s // 4 4 5 5 x 6 6 7 7 -> 4 5 6 7 +uzp2 v20.4s, v26.4s, v27.4s + +st1 {v12.16b, v13.16b}, [x11], #32 +st1 {v19.16b, v20.16b}, [x12], #32 + +bge FLLoop8 + +FL4: +cmp x2, #4 +blt FL1 + +FLLoop4: +ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 +sub x2, x2, #4 +SCALE_TO_FLOAT_4 v0, v1, v2, v3, v31 +ADD_ZEROPOINT_4 v0, v1, v2, v3, v30 +cmp x2, #4 +FLOAT_TO_INT_4 v0, v1, v2, v3 +INT16_TO_INT8_4 v0, v1, v2, v3, v24, v25 + +// Reorder c8->c4 +uzp1 v12.4s, v24.4s, v25.4s // 0 0 1 1 x 2 2 3 3 -> 0 1 2 3 +uzp2 v19.4s, v24.4s, v25.4s + +st1 {v12.16b}, [x11], #16 +st1 {v19.16b}, [x12], #16 +//st1 {v24.16b, v25.16b}, [x1], #32 + +bge FLLoop4 + +FL1: +cmp x2, #0 +ble FLEnd + +FLLoop1: +ld1 {v0.8h}, [x0], #16 +fmul v0.8h, v0.8h, v31.8h +fadd v0.8h, v0.8h, v30.8h +sub x2, x2, #1 + +fcvtas v0.8h, v0.8h +sqxtn v0.8b, v0.8h + +cmp x2, #1 +st1 {v0.s}[0], [x11], #4 +st1 {v0.s}[1], [x12], #4 + +bge FLLoop1 + +FLEnd: +sub x7, x7, #1 +add x9, x9, x13 +add x10, x10, x13 +mov x2, x14 +b Outter_Channel_Loop + +End: +ldp d8, d9, [sp, #48] +ldp d10, d11, [sp, #32] +ldp d12, d13, [sp, #16] +ldp d14, d15, [sp], #64 +ret +#endif diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuantFP16.S b/source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuantFP16.S index 689194c9e..01455850d 100644 --- a/source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuantFP16.S +++ b/source/backend/arm82/asm/arm64/low_memory/MNNDynamicQuantFP16.S @@ -19,39 +19,39 @@ fcvtas \z3\().8h, \z3\().8h .endm -//void MNNDynamicQuantFP16(const float* src, int8_t* dst, const float* scale, float* sum, size_t src_depth_quad, size_t realSize) +//void 
MNNDynamicQuantFP16(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack) asm_function MNNDynamicQuantFP16 -// x0: src, x1:dst, x2:scale, x3:sum, x4:src_depth_quad, x5:realSize +// Feature: quant and reorder C8->C4 + +// x0: src, x1:dst, x2:scale, x3:src_depth_quad, x4:realSize stp d14, d15, [sp, #(-16 * 4)]! stp d12, d13, [sp, #(16 * 1)] stp d10, d11, [sp, #(16 * 2)] stp d8, d9, [sp, #(16 * 3)] Start: -lsl x6, x5, #3 // dst_step = batch * unit * sizeof(int8_t) = batch * 8 = batch << 3 -lsl x7, x6, #1 // src_step = dst_step * 2 (float16_t) = dst_step << 1 - -movi v29.16b, #1 +lsl x6, x4, #3 // dst_step = batch * (2*unit) * sizeof(int8_t) = batch * 8 = batch << 3 +lsl x7, x4, #4 // src_step = batch * pack * sizeof(float16) = batch * 8 * 2 = batch << 4 +lsl x8, x4, #2 // 4 * plane +add x11, x1, x8 // second N*4 TILE_12: -cmp x5, #12 +cmp x4, #12 blt TILE_10 mov x9, x0 // src mov x10, x1 // dst -mov x12, x4 // src_depth_quad +mov x15, x11 // second dst +mov x12, x3 // src_depth_quad sub x13, x7, #128 // src_step - 64 -sub x14, x6, #64 // dst_step - 64 - -// quant_scale: v12, v13 -ld1 {v12.8h}, [x2], #16 -ld1 {v13.d}[0], [x2], #8 -movi v23.4s, #0 -movi v24.4s, #0 -movi v25.4s, #0 -movi v26.4s, #0 -movi v27.4s, #0 -movi v28.4s, #0 + +// quant_scale: v12, v13, v14 +// ld1 {v12.8h}, [x2], #16 +// ld1 {v13.d}[0], [x2], #8 +ld1 {v12.4s, v13.4s, v14.4s}, [x2], #48 +fcvtn v12.4h, v12.4s +fcvtn2 v12.8h, v13.4s +fcvtn v13.4h, v14.4s LoopSz_12: ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x9], #64 @@ -91,47 +91,47 @@ sqxtn2 v4.16b, v9.8h sqxtn v5.8b, v10.8h sqxtn2 v5.16b, v11.8h -.inst 0x4e9d9417 // sdot v23.4s, v0.16b, v29.16b -.inst 0x4e9d9438 // sdot v24.4s, v1.16b, v29.16b -.inst 0x4e9d9459 // sdot v25.4s, v2.16b, v29.16b -.inst 0x4e9d947a // sdot v26.4s, v3.16b, v29.16b -.inst 0x4e9d949b // sdot v27.4s, v4.16b, v29.16b -.inst 0x4e9d94bc // sdot v28.4s, v5.16b, v29.16b +uzp1 v6.4s, v0.4s, v1.4s +uzp1 v7.4s, v2.4s, v3.4s +uzp1 v8.4s, v4.4s, v5.4s +uzp2 v9.4s, v0.4s, v1.4s +uzp2 v10.4s, v2.4s, v3.4s +uzp2 v11.4s, v4.4s, v5.4s + +st1 {v6.16b, v7.16b, v8.16b}, [x10], x6 +st1 {v9.16b, v10.16b, v11.16b}, [x15], x6 -st1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x10], #64 -st1 {v4.16b, v5.16b}, [x10], x14 +//st1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x10], #64 +//st1 {v4.16b, v5.16b}, [x10], x14 subs x12, x12, #1 bne LoopSz_12 -addp v12.4s, v23.4s, v24.4s -addp v13.4s, v25.4s, v26.4s -addp v14.4s, v27.4s, v28.4s -st1 {v12.4s, v13.4s, v14.4s}, [x3], #48 - Tile12End: -sub x5, x5, #12 // batch -= 12 +sub x4, x4, #12 // batch -= 12 add x0, x0, #192 // src += 12 * 8 * sizeof(float16_t) -add x1, x1, #96 // dst += 12 * 8 * sizeof(int8_t) +add x1, x1, #48 // dst += 12 * 4 * sizeof(int8_t) +add x11, x11, #48 b TILE_12 TILE_10: -cmp x5, #10 +cmp x4, #10 blt TILE_8 mov x9, x0 // src mov x10, x1 // dst -mov x12, x4 // src_depth_quad -sub x13, x7, #128 // src_step - 64 -sub x14, x6, #64 // dst_step - 64 +mov x15, x11 // second dst +mov x12, x3 // src_depth_quad +sub x13, x7, #128 // src_step - 128 +sub x14, x6, #32 // dst_step - 32 // quant_scale: v10, v11 -ld1 {v10.8h}, [x2], #16 -ld1 {v11.s}[0], [x2], #4 -movi v24.4s, #0 -movi v25.4s, #0 -movi v26.4s, #0 -movi v27.4s, #0 -movi v28.4s, #0 +//ld1 {v10.8h}, [x2], #16 +//ld1 {v11.s}[0], [x2], #4 +ld1 {v12.4s, v13.4s}, [x2], #32 +ld1 {v14.d}[0], [x2], #8 +fcvtn v10.4h, v12.4s +fcvtn2 v10.8h, v13.4s +fcvtn v11.4h, v14.4s LoopSz_10: ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x9], #64 @@ -168,45 +168,45 @@ sqxtn2 v3.16b, v7.8h sqxtn v4.8b, v8.8h sqxtn2 
v4.16b, v9.8h -.inst 0x4e9d9418 // sdot v24.4s, v0.16b, v29.16b -.inst 0x4e9d9439 // sdot v25.4s, v1.16b, v29.16b -.inst 0x4e9d945a // sdot v26.4s, v2.16b, v29.16b -.inst 0x4e9d947b // sdot v27.4s, v3.16b, v29.16b -.inst 0x4e9d949c // sdot v28.4s, v4.16b, v29.16b +uzp1 v6.4s, v0.4s, v1.4s // 0 1 2 3 +uzp1 v7.4s, v2.4s, v3.4s // 4 5 6 7 +uzp1 v8.4s, v4.4s, v4.4s // 8 9 8 9 +uzp2 v12.4s, v0.4s, v1.4s +uzp2 v13.4s, v2.4s, v3.4s +uzp2 v14.4s, v4.4s, v4.4s +st1 {v6.16b, v7.16b}, [x10], #32 +st1 {v8.d}[0], [x10], x14 +st1 {v12.16b, v13.16b}, [x15], #32 +st1 {v14.d}[0], [x15], x14 -st1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x10], #64 -st1 {v4.16b}, [x10], x14 +// st1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x10], #64 +// st1 {v4.16b}, [x10], x14 subs x12, x12, #1 bne LoopSz_10 -addp v13.4s, v24.4s, v25.4s -addp v14.4s, v26.4s, v27.4s -addp v15.4s, v28.4s, v28.4s -st1 {v13.4s, v14.4s}, [x3], #32 -st1 {v15.d}[0], [x3], #8 - Tile10End: -sub x5, x5, #10 // batch -= 10 +sub x4, x4, #10 // batch -= 10 add x0, x0, #160 // src += 10 * 8 * sizeof(float16_t) -add x1, x1, #80 // dst += 10 * 8 * sizeof(int8_t) +add x1, x1, #40 // dst += 10 * 4 * sizeof(int8_t) +add x11, x11, #40 b TILE_10 TILE_8: -cmp x5, #8 +cmp x4, #8 blt TILE_1 sub x8, x7, #64 // src_step - 64 mov x9, x0 // src mov x10, x1 // dst -mov x12, x4 // src_depth_quad +mov x15, x11 // second dst +mov x12, x3 // src_depth_quad // quant_scale: v8 -ld1 {v8.8h}, [x2], #16 -movi v25.4s, #0 -movi v26.4s, #0 -movi v27.4s, #0 -movi v28.4s, #0 +//ld1 {v8.8h}, [x2], #16 +ld1 {v12.4s, v13.4s}, [x2], #32 +fcvtn v8.4h, v12.4s +fcvtn2 v8.8h, v13.4s LoopSz_8: ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x9], #64 @@ -236,37 +236,36 @@ sqxtn2 v11.16b, v5.8h sqxtn v12.8b, v6.8h sqxtn2 v12.16b, v7.8h -.inst 0x4e9d9539 // sdot v25.4s, v9.16b, v29.16b -.inst 0x4e9d955a // sdot v26.4s, v10.16b, v29.16b -.inst 0x4e9d957b // sdot v27.4s, v11.16b, v29.16b -.inst 0x4e9d959c // sdot v28.4s, v12.16b, v29.16b - -st1 {v9.16b, v10.16b, v11.16b, v12.16b}, [x10], x6 +uzp1 v6.4s, v9.4s, v10.4s // 0 1 2 3 first +uzp1 v7.4s, v11.4s, v12.4s // 4 5 6 7 +uzp2 v14.4s, v9.4s, v10.4s // 0 1 2 3 second +uzp2 v15.4s, v11.4s, v12.4s // 4 5 6 7 +st1 {v6.16b, v7.16b}, [x10], x6 +st1 {v14.16b, v15.16b}, [x15], x6 +//st1 {v9.16b, v10.16b, v11.16b, v12.16b}, [x10], x6 subs x12, x12, #1 bne LoopSz_8 -addp v14.4s, v25.4s, v26.4s -addp v15.4s, v27.4s, v28.4s -st1 {v14.4s, v15.4s}, [x3], #32 - Tile8End: -sub x5, x5, #8 // batch -= 8 +sub x4, x4, #8 // batch -= 8 add x0, x0, #128 // src += 8 * 8 * sizeof(float16_t) -add x1, x1, #64 // dst += 8 * 8 * sizeof(int8_t) +add x1, x1, #32 // dst += 8 * 4 * sizeof(int8_t) +add x11, x11, #32 b TILE_8 TILE_4: -cmp x5, #4 +cmp x4, #4 blt TILE_2 mov x9, x0 // src mov x10, x1 // dst -mov x12, x4 // src_depth_quad +mov x15, x11 // second dst +mov x12, x3 // src_depth_quad // quant_scale: v8 -ld1 {v8.d}[0], [x2], #8 -movi v27.4s, #0 -movi v28.4s, #0 +//ld1 {v8.d}[0], [x2], #8 +ld1 {v12.4s}, [x2], #16 +fcvtn v8.4h, v12.4s LoopSz_4: ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x9], x7 @@ -286,34 +285,35 @@ sqxtn2 v4.16b, v1.8h sqxtn v5.8b, v2.8h sqxtn2 v5.16b, v3.8h -.inst 0x4e9d949b // sdot v27.4s, v4.16b, v29.16b -.inst 0x4e9d94bc // sdot v28.4s, v5.16b, v29.16b - -st1 {v4.16b, v5.16b}, [x10], x6 +uzp1 v6.4s, v4.4s, v5.4s // 0 1 2 3 first +uzp2 v14.4s, v4.4s, v5.4s // 0 1 2 3 second +st1 {v6.16b}, [x10], x6 +st1 {v14.16b}, [x15], x6 +//st1 {v4.16b, v5.16b}, [x10], x6 subs x12, x12, #1 bne LoopSz_4 -addp v26.4s, v27.4s, v28.4s -st1 {v26.4s}, [x3], #16 - Tile4End: -sub x5, x5, #4 // batch -= 4 
+sub x4, x4, #4 // batch -= 4 add x0, x0, #64 // src += 4 * 8 * sizeof(float16_t) -add x1, x1, #32 // dst += 4 * 8 * sizeof(int8_t) +add x1, x1, #16 // dst += 4 * 4 * sizeof(int8_t) +add x11, x11, #16 b TILE_4 TILE_2: -cmp x5, #2 +cmp x4, #2 blt TILE_1 mov x9, x0 // src mov x10, x1 // dst -mov x12, x4 // src_depth_quad +mov x15, x11 // second dst +mov x12, x3 // src_depth_quad // quant_scale: v8 -ld1 {v8.s}[0], [x2], #4 -movi v28.4s, #0 +//ld1 {v8.s}[0], [x2], #4 +ld1 {v12.d}[0], [x2], #8 +fcvtn v8.4h, v12.4s LoopSz_2: ld1 {v0.8h, v1.8h}, [x9], x7 @@ -329,33 +329,34 @@ fcvtas v1.8h, v1.8h // y = (int8_t)x sqxtn v2.8b, v0.8h sqxtn2 v2.16b, v1.8h -.inst 0x4e9d945c // sdot v28.4s, v2.16b, v29.16b -st1 {v2.16b}, [x10], x6 +st1 {v2.d}[0], [x10], x6 +st1 {v2.d}[1], [x15], x6 +//st1 {v2.16b}, [x10], x6 subs x12, x12, #1 bne LoopSz_2 -addp v27.4s, v28.4s, v28.4s -st1 {v27.d}[0], [x3], #8 - Tile2End: -sub x5, x5, #2 // batch -= 2 +sub x4, x4, #2 // batch -= 2 add x0, x0, #32 // src += 2 * 8 * sizeof(float16_t) -add x1, x1, #16 // dst += 2 * 8 * sizeof(int8_t) +add x1, x1, #8 // dst += 2 * 4 * sizeof(int8_t) +add x11, x11, #8 b TILE_2 TILE_1: -cmp x5, #1 +cmp x4, #1 blt End mov x9, x0 // src mov x10, x1 // dst -mov x12, x4 // src_depth_quad +mov x15, x11 // second dst +mov x12, x3 // src_depth_quad // quant_scale: v8 -ld1 {v8.h}[0], [x2], #2 -movi v28.4s, #0 +//ld1 {v8.h}[0], [x2], #2 +ld1 {v12.s}[0], [x2], #4 +fcvtn v8.4h, v12.4s LoopSz_1: ld1 {v0.8h}, [x9], x7 @@ -366,20 +367,18 @@ fmul v0.8h, v0.8h, v8.h[0] fcvtas v0.8h, v0.8h // y = (int8_t)x sqxtn v0.8b, v0.8h -.inst 0x4e9d941c // sdot v28.4s, v0.16b, v29.16b -st1 {v0.8b}, [x10], x6 +st1 {v0.s}[0], [x10], x6 +st1 {v0.s}[1], [x15], x6 subs x12, x12, #1 bne LoopSz_1 -addp v27.4s, v28.4s, v28.4s -st1 {v27.s}[0], [x3], #4 - Tile1End: -sub x5, x5, #1 // batch -= 1 +sub x4, x4, #1 // batch -= 1 add x0, x0, #16 // src += 1 * 8 * sizeof(float16_t) -add x1, x1, #8 // dst += 1 * 8 * sizeof(int8_t) +add x1, x1, #4 // dst += 1 * 4 * sizeof(int8_t) +add x11, x11, #4 b TILE_1 @@ -390,4 +389,4 @@ ldp d12, d13, [sp, #(16 * 1)] ldp d14, d15, [sp], #(16 * 4) ret -#endif \ No newline at end of file +#endif diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt4FP16_sdot.S b/source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt4FP16_sdot.S deleted file mode 100644 index 9620c93eb..000000000 --- a/source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt4FP16_sdot.S +++ /dev/null @@ -1,314 +0,0 @@ -// -// MNNGemmHybridInt4FP16_sdot.S -// MNN -// -// Created by MNN on 2023/11/09. 
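Context for the MNNDynamicQuantFP16 hunk that ends above: the reworked path takes one float32 quant scale per batch element (narrowed to FP16 in-kernel), drops the old per-batch sum output that the sdot-with-ones accumulation produced, and on store splits each C8 pack into two C4 half-planes (channels 0-3 at dst, channels 4-7 at dst + 4 * realSize within each depth quad). The following is a minimal C sketch of that arithmetic, not the shipped kernel: the _ref name, the __fp16 source type, the dropped pack argument, and the [depth_quad][half][realSize][4] output indexing are inferred from the register comments in the diff and should be treated as assumptions; the real kernel also does the multiply in FP16 rather than float.

#include <math.h>
#include <stddef.h>
#include <stdint.h>

static void MNNDynamicQuantFP16_ref(const __fp16* src, int8_t* dst, const float* scale,
                                    size_t src_depth_quad, size_t realSize) {
    for (size_t z = 0; z < src_depth_quad; ++z) {
        // per depth quad: realSize*4 bytes of channels 0-3, then realSize*4 bytes of channels 4-7
        int8_t* dstLo = dst + z * realSize * 8;
        int8_t* dstHi = dstLo + realSize * 4;
        for (size_t i = 0; i < realSize; ++i) {
            for (int c = 0; c < 8; ++c) {
                float v = (float)src[(z * realSize + i) * 8 + c] * scale[i];
                int   q = (int)roundf(v);            // fcvtas: round to nearest, ties away from zero
                if (q > 127)  q = 127;               // sqxtn: saturate to int8
                if (q < -128) q = -128;
                ((c < 4) ? dstLo : dstHi)[i * 4 + (c & 3)] = (int8_t)q;
            }
        }
    }
}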
-// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef __aarch64__ - -#include "MNNAsmGlobal.h" - -.text -.align 5 - -.macro Int32ToFloat z0, z1, z2, z3 - scvtf \z0\().4s, \z0\().4s - scvtf \z1\().4s, \z1\().4s - scvtf \z2\().4s, \z2\().4s - scvtf \z3\().4s, \z3\().4s -.endm - -.macro MulScale d0, d1, d2, d3, s, idx0, idx1, alpha0, alpha1 - fmul \d0\().4s, \d0\().4s, \s\().s[\idx0] - fmul \d1\().4s, \d1\().4s, \s\().s[\idx0] - fmul \d2\().4s, \d2\().4s, \s\().s[\idx1] - fmul \d3\().4s, \d3\().4s, \s\().s[\idx1] - fmul \d0\().4s, \d0\().4s, \alpha0\().4s - fmul \d1\().4s, \d1\().4s, \alpha1\().4s - fmul \d2\().4s, \d2\().4s, \alpha0\().4s - fmul \d3\().4s, \d3\().4s, \alpha1\().4s -.endm - -.macro Float32ToHalf s0, s1, s2, s3, d0, d1 - fcvtn \d0\().4h, \s0\().4s - fcvtn2 \d0\().8h, \s1\().4s - fcvtn \d1\().4h, \s2\().4s - fcvtn2 \d1\().8h, \s3\().4s -.endm - -.macro Dequant c0, z0, b0, s0, idx - fmla \c0\().8h, \z0\().8h, \s0\().h[\idx] - fadd \c0\().8h, \c0\().8h, \b0\().8h -.endm - -asm_function MNNGemmHybridInt4FP16_sdot - -//struct QuanPostTreatParameters { -// const float* scale; -// const int32_t* bias; -// int32_t maxValue; -// int32_t minValue; -// int32_t useInt8; -//}; - -//void MNNGemmHybridInt4_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param); - - -// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param -// load from param: x7: alpha*, x8: zero*, x9: bias*, x10: sums*, x11: scales* -stp d14, d15, [sp, #(-16 * 9)]! -stp d12, d13, [sp, #(16 * 1)] -stp d10, d11, [sp, #(16 * 2)] -stp d8, d9, [sp, #(16 * 3)] -stp x21, x22, [sp, #(16 * 4)] -stp x19, x20, [sp, #(16 * 5)] -stp x23, x24, [sp, #(16 * 6)] -stp x25, x26, [sp, #(16 * 7)] -stp x27, x28, [sp, #(16 * 8)] - -ldr x8, [x7, #0] -ldr x9, [x7, #8] -ldr x10, [x7, #16] -ldr x11, [x7, #24] -ldr x12, [x7, #32] -ldr x14, [x7, #40] -Start: -lsl x13, x3, #5 // x13 = src_depth_quad * UNIT * UNIT_SRC / 2(int4) = src_depth_quad * 32 = src_depth_quad << 5 -ld1 {v6.16b, v7.16b}, [x14] -// mask -movi v14.16b, #15 -TILE_4: - cmp x6, #4 - blt TILE_1 - lsr x15, x4, #1 // src_step = dst_step / 2 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_4: - // dequant info for batch - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - dup v20.4s, wzr - dup v21.4s, wzr - dup v22.4s, wzr - dup v23.4s, wzr - dup v24.4s, wzr - dup v25.4s, wzr - dup v26.4s, wzr - dup v27.4s, wzr - dup v28.4s, wzr - dup v29.4s, wzr - dup v30.4s, wzr - dup v31.4s, wzr -LoopSz_TILE_4: - // src : 2 x [2 x 8] : v4-5 - // weight : 4 x [2 x 8] : v0-3 - // dst : 2 x 4 x [4] : v16-23 - ld1 {v8.16b, v9.16b}, [x25], #32 // weight - ld1 {v4.16b, v5.16b}, [x24], x15 // src - // int4 to int8: v0, v1, v2, v3 - ushr v0.16b, v8.16b, #4 - and v1.16b, v8.16b, v14.16b - ushr v2.16b, v9.16b, #4 - and v3.16b, v9.16b, v14.16b - - mov v10.d[0], v4.d[1] - mov v10.d[1], v4.d[0] - mov v11.d[1], v5.d[0] - mov v11.d[0], v5.d[1] - .inst 0x4e809490 // sdot v16.4s, v4.16b, v0.16b // (0,0)x2 (1,1)x2 - .inst 0x4e809558 // sdot v24.4s, v10.16b, v0.16b // (1,0)x2 (0,1)x2 - .inst 0x4e819491 // sdot v17.4s, v4.16b, v1.16b // (0,2) (1,3) - .inst 0x4e819559 // sdot v25.4s, v10.16b, 
v1.16b // (1,2) (0,3) - .inst 0x4e829492 // sdot v18.4s, v4.16b, v2.16b - .inst 0x4e82955a // sdot v26.4s, v10.16b, v2.16b - .inst 0x4e839493 // sdot v19.4s, v4.16b, v3.16b - .inst 0x4e83955b // sdot v27.4s, v10.16b, v3.16b - .inst 0x4e8094b4 // sdot v20.4s, v5.16b, v0.16b - .inst 0x4e80957c // sdot v28.4s, v11.16b, v0.16b - .inst 0x4e8194b5 // sdot v21.4s, v5.16b, v1.16b - .inst 0x4e81957d // sdot v29.4s, v11.16b, v1.16b - .inst 0x4e8294b6 // sdot v22.4s, v5.16b, v2.16b - .inst 0x4e82957e // sdot v30.4s, v11.16b, v2.16b - .inst 0x4e8394b7 // sdot v23.4s, v5.16b, v3.16b - .inst 0x4e83957f // sdot v31.4s, v11.16b, v3.16b - - subs x26, x26, #1 - bne LoopSz_TILE_4 - - addp v16.4s, v16.4s, v24.4s // (batch,oc)(0,0)(1,1)(1,0)(0,1) - addp v17.4s, v17.4s, v25.4s // (0,2)(1,3)(1,2)(0,3) - addp v18.4s, v18.4s, v26.4s // (0,4)(1,5)(1,4)(0,5) - addp v19.4s, v19.4s, v27.4s // (0,6)(1,7)(1,6)(0,7) - addp v20.4s, v20.4s, v28.4s - addp v21.4s, v21.4s, v29.4s - addp v22.4s, v22.4s, v30.4s - addp v23.4s, v23.4s, v31.4s - tbl v24.16b, {v16.16b, v17.16b}, v6.16b // batch=0,oc=0-3 - tbl v25.16b, {v16.16b, v17.16b}, v7.16b // batch=1,oc=0-3 - tbl v26.16b, {v18.16b, v19.16b}, v6.16b // batch=0,oc=4-7 - tbl v27.16b, {v18.16b, v19.16b}, v7.16b // batch=1,oc=4-7 - tbl v28.16b, {v20.16b, v21.16b}, v6.16b - tbl v29.16b, {v20.16b, v21.16b}, v7.16b - tbl v30.16b, {v22.16b, v23.16b}, v6.16b - tbl v31.16b, {v22.16b, v23.16b}, v7.16b - -LoopSzEnd_TILE_4: - add x7, x7, x13 - sub x27, x27, #1 - Int32ToFloat v24, v25, v26, v27 - Int32ToFloat v28, v29, v30, v31 - // using float scale dequant for precison - ld1 {v1.d}[0], [x23] // scales 4 batch - ld1 {v2.8h}, [x19], #16 // alpha - - fcvtl v3.4s, v2.4h // oc:0-3 - fcvtl2 v4.4s, v2.8h // oc:4-7 - fcvtl v5.4s, v1.4h // scales: 4 batch - - MulScale v24, v26, v25, v27, v5, 0, 1, v3, v4 - MulScale v28, v30, v29, v31, v5, 2, 3, v3, v4 - Float32ToHalf v24, v26, v25, v27, v10, v11 - Float32ToHalf v28, v30, v29, v31, v12, v13 -Tile4Dequant: - ld1 {v1.8h}, [x20], #16 // zero - ld1 {v2.8h}, [x21], #16 // bias - ld1 {v3.d}[0], [x22] // sums - // sum + (zero * sumx) + bias - Dequant v10, v1, v2, v3, 0 - Dequant v11, v1, v2, v3, 1 - Dequant v12, v1, v2, v3, 2 - Dequant v13, v1, v2, v3, 3 - st1 {v10.8h, v11.8h, v12.8h, v13.8h}, [x28], x4 - cmp x27, #1 - bge LoopDz_TILE_4 -Tile4End: - sub x6, x6, #4 // bach -= 4 - add x0, x0, #64 // dst += 4 * 8 * sizeof(float16_t) - add x1, x1, #32 // src += 4 * 8 * sizeof(int8_t) - add x11, x11, #8 // sum += 4 * sizeof(float16_t) - add x12, x12, #8 // scale += 4 * sizeof(float16_t) - b TILE_4 - -TILE_1: - cmp x6, #1 - blt End - lsr x15, x4, #1 // src_step = dst_step / 2 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_1: - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - movi v24.4s, #0 - movi v25.4s, #0 - movi v26.4s, #0 - movi v27.4s, #0 - movi v10.4s, #0 - movi v11.4s, #0 - movi v12.4s, #0 - movi v13.4s, #0 -LoopSz_TILE_1: - // src : 1 x [1 x 8] : v4 - // weight : 4 x [2 x 8] : v0-3 - // dst : 1 x 4 x [2] : v16-v19 - ld1 {v8.16b, v9.16b}, [x25], #32 // weight - ld1 {v4.8b}, [x24], x15 // src - // int4 to int8: v0, v1, v2, v3 - ushr v0.16b, v8.16b, #4 - and v1.16b, v8.16b, v14.16b - ushr v2.16b, v9.16b, #4 - and v3.16b, v9.16b, v14.16b - - mov v29.d[0], v4.d[1] - mov v29.d[1], v4.d[0] - - .inst 0x4e809498 // sdot v24.4s, v4.16b, v0.16b // (0,0)x2 
(1,1)x2 - .inst 0x4e8097b9 // sdot v25.4s, v29.16b, v0.16b // (1,0)x2 (0,1)x2 - .inst 0x4e81949a // sdot v26.4s, v4.16b, v1.16b // (0,2)x2 (1,3)x2 - .inst 0x4e8197bb // sdot v27.4s, v29.16b, v1.16b // (1,2)x2 (0,3)x2 - .inst 0x4e82948a // sdot v10.4s, v4.16b, v2.16b // (0,4)x2 (1,5)x2 - .inst 0x4e8297ab // sdot v11.4s, v29.16b, v2.16b // (1,4)x2 (0,5)x2 - .inst 0x4e83948c // sdot v12.4s, v4.16b, v3.16b // (0,6)x2 (1,7)x2 - .inst 0x4e8397ad // sdot v13.4s, v29.16b, v3.16b // (1,6)x2 (0,7)x2 - - subs x26, x26, #1 - bne LoopSz_TILE_1 - addp v16.4s, v24.4s, v25.4s - addp v17.4s, v26.4s, v27.4s - addp v18.4s, v10.4s, v11.4s - addp v19.4s, v12.4s, v13.4s - -LoopSzEnd_TILE_1: - add x7, x7, x13 - sub x27, x27, #1 - tbl v24.16b, {v16.16b, v17.16b}, v6.16b - tbl v20.16b, {v18.16b, v19.16b}, v6.16b - - scvtf v24.4s, v24.4s - scvtf v20.4s, v20.4s - // using float scale dequant for precison - ld1 {v4.h}[0], [x23] // scales - ld1 {v0.8h}, [x19], #16 // alpha - fcvtl v5.4s, v4.4h - fcvtl v22.4s, v0.4h - fcvtl2 v21.4s, v0.8h - fmul v24.4s, v24.4s, v5.s[0] - fmul v20.4s, v20.4s, v5.s[0] - fmul v24.4s, v24.4s, v22.4s - fmul v20.4s, v20.4s, v21.4s - fcvtn v17.4h, v24.4s - fcvtn2 v17.8h, v20.4s -Tile1Dequant: - ld1 {v1.8h}, [x20], #16 // zero - ld1 {v2.8h}, [x21], #16 // bias - ld1 {v3.h}[0], [x22] // sums - // sum + (zero * sumx) + bias - fadd v2.8h, v2.8h, v17.8h - fmla v2.8h, v1.8h, v3.h[0] - st1 {v2.8h}, [x28], x4 - cmp x27, #1 - bge LoopDz_TILE_1 -Tile1End: - sub x6, x6, #1 // batch -= 1 - add x0, x0, #16 // dst += 1 * 8 * sizeof(float16_t) - add x1, x1, #8 // dst += 1 * 8 * sizeof(int8_t) - add x11, x11, #2 // sum += 1 * sizeof(float16_t) - add x12, x12, #2 // scale += 1 * sizeof(float16_t) - b TILE_1 - -End: -ldp x27, x28, [sp, #(16 * 8)] -ldp x25, x26, [sp, #(16 * 7)] -ldp x23, x24, [sp, #(16 * 6)] -ldp x19, x20, [sp, #(16 * 5)] -ldp x21, x22, [sp, #(16 * 4)] -ldp d8, d9, [sp, #(16 * 3)] -ldp d10, d11, [sp, #(16 * 2)] -ldp d12, d13, [sp, #(16 * 1)] -ldp d14, d15, [sp], #(16 * 9) -ret - -#endif \ No newline at end of file diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt4FP16_smmla.S b/source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt4FP16_smmla.S deleted file mode 100644 index 62a3053ee..000000000 --- a/source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt4FP16_smmla.S +++ /dev/null @@ -1,506 +0,0 @@ -// -// MNNGemmHybridInt4_smmla.S -// MNN -// -// Created by MNN on 2023/10/30. 
-// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef __aarch64__ - -#include "MNNAsmGlobal.h" - -.text -.align 5 - -.macro Int32ToFloat z0, z1, z2, z3 - scvtf \z0\().4s, \z0\().4s - scvtf \z1\().4s, \z1\().4s - scvtf \z2\().4s, \z2\().4s - scvtf \z3\().4s, \z3\().4s -.endm - -.macro Float32ToHalf s0, s1, s2, s3, d0, d1 - fcvtn \d0\().4h, \s0\().4s - fcvtn2 \d0\().8h, \s1\().4s - fcvtn \d1\().4h, \s2\().4s - fcvtn2 \d1\().8h, \s3\().4s -.endm - -.macro MulScale d0, d1, d2, d3, s, idx0, idx1, alpha0, alpha1 - fmul \d0\().4s, \d0\().4s, \s\().s[\idx0] - fmul \d1\().4s, \d1\().4s, \s\().s[\idx0] - fmul \d2\().4s, \d2\().4s, \s\().s[\idx1] - fmul \d3\().4s, \d3\().4s, \s\().s[\idx1] - fmul \d0\().4s, \d0\().4s, \alpha0\().4s - fmul \d1\().4s, \d1\().4s, \alpha1\().4s - fmul \d2\().4s, \d2\().4s, \alpha0\().4s - fmul \d3\().4s, \d3\().4s, \alpha1\().4s -.endm - -.macro MulScale_New d0, d1, d2, d3, s, a1, a2, a3, a4 - fmul \d0\().4s, \d0\().4s, \s\().4s - fmul \d1\().4s, \d1\().4s, \s\().4s - fmul \d2\().4s, \d2\().4s, \s\().4s - fmul \d3\().4s, \d3\().4s, \s\().4s - fmul \d0\().4s, \d0\().4s, \a1\().4s - fmul \d1\().4s, \d1\().4s, \a2\().4s - fmul \d2\().4s, \d2\().4s, \a3\().4s - fmul \d3\().4s, \d3\().4s, \a4\().4s -.endm - -.macro Dequant c0, z0, b0, s0, idx - fmla \c0\().8h, \z0\().8h, \s0\().h[\idx] - fadd \c0\().8h, \c0\().8h, \b0\().8h -.endm - -asm_function MNNGemmHybridInt4FP16_smmla - -//struct QuanPostTreatParameters { -// const float* scale; -// const int32_t* bias; -// int32_t maxValue; -// int32_t minValue; -// int32_t useInt8; -//}; - -//void MNNGemmHybridInt4_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param); - - -// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param -// load from param: x7: alpha*, x8: zero*, x9: bias*, x10: sums*, x11: scales* -stp d14, d15, [sp, #(-16 * 9)]! 
-stp d12, d13, [sp, #(16 * 1)] -stp d10, d11, [sp, #(16 * 2)] -stp d8, d9, [sp, #(16 * 3)] -stp x21, x22, [sp, #(16 * 4)] -stp x19, x20, [sp, #(16 * 5)] -stp x23, x24, [sp, #(16 * 6)] -stp x25, x26, [sp, #(16 * 7)] -stp x27, x28, [sp, #(16 * 8)] - -ldr x8, [x7, #0] -ldr x9, [x7, #8] -ldr x10, [x7, #16] -ldr x11, [x7, #24] -ldr x12, [x7, #32] - -Start: -lsl x13, x3, #5 // x13 = src_depth_quad * UNIT * UNIT_SRC / 2(int4) = src_depth_quad * 32 = src_depth_quad << 5 -// mask -movi v10.16b, #15 -// offset -movi v11.16b, #8 -TILE_8: - cmp x6, #8 - blt TILE_4 - //mov x14, x4 // dst_step - lsr x15, x4, #1 // src_step = dst_step / 2 - sub x14, x4, #64 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_8: - // dequant info for batch - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - dup v20.4s, wzr - dup v21.4s, wzr - dup v22.4s, wzr - dup v23.4s, wzr - dup v24.4s, wzr - dup v25.4s, wzr - dup v26.4s, wzr - dup v27.4s, wzr - dup v28.4s, wzr - dup v29.4s, wzr - dup v30.4s, wzr - dup v31.4s, wzr - - ld1 {v14.8h}, [x23] // scales - ld1 {v15.8h}, [x19], #16 // alpha - -LoopSz_TILE_8: - // src : 2 x [2 x 8] : v4-5 - // weight : 4 x [2 x 8] : v0-3 - // dst : 2 x 4 x [4] : v16-31 - ld1 {v8.16b, v9.16b}, [x25], #32 // weight - ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x24], x15 // src - // int4 to int8: v0, v1, v2, v3 - ushr v0.16b, v8.16b, #4 - and v1.16b, v8.16b, v10.16b - ushr v2.16b, v9.16b, #4 - and v3.16b, v9.16b, v10.16b - - .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b // batch=0,1, oc=0,1 - .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b // batch=0,1, oc=2,3 - .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b // batch=0,1, oc=4,5 - .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b // batch=0,1, oc=6,7 - .inst 0x4e80a4b4 // smmla v20.4s, v5.16b, v0.16b // batch=2,3, oc=0,1 - .inst 0x4e81a4b5 // smmla v21.4s, v5.16b, v1.16b // batch=2,3, oc=2,3 - .inst 0x4e82a4b6 // smmla v22.4s, v5.16b, v2.16b // batch=2,3, oc=4,5 - .inst 0x4e83a4b7 // smmla v23.4s, v5.16b, v3.16b // batch=2,3, oc=6,7 - - .inst 0x4e80a4d8 // smmla v24.4s, v6.16b, v0.16b // batch=4,5, oc=0,1 - .inst 0x4e81a4d9 // smmla v25.4s, v6.16b, v1.16b // batch=4,5, oc=2,3 - .inst 0x4e82a4da // smmla v26.4s, v6.16b, v2.16b // batch=4,5, oc=4,5 - .inst 0x4e83a4db // smmla v27.4s, v6.16b, v3.16b // batch=4,5, oc=6,7 - .inst 0x4e80a4fc // smmla v28.4s, v7.16b, v0.16b // batch=6,7, oc=0,1 - .inst 0x4e81a4fd // smmla v29.4s, v7.16b, v1.16b // batch=6,7, oc=2,3 - .inst 0x4e82a4fe // smmla v30.4s, v7.16b, v2.16b // batch=6,7, oc=4,5 - .inst 0x4e83a4ff // smmla v31.4s, v7.16b, v3.16b // batch=6,7, oc=6,7 - subs x26, x26, #1 - bne LoopSz_TILE_8 - -LoopSzEnd_TILE_8: - add x7, x7, x13 - fcvtl v8.4s, v15.4h // oc:0-3 - fcvtl2 v9.4s, v15.8h // oc:4-7 - fcvtl v12.4s, v14.4h // scales: batch 0,1,2,3 - fcvtl2 v13.4s, v14.8h // scales: batch 4,5,6,7 - sub x27, x27, #1 - Int32ToFloat v16, v17, v18, v19 - Int32ToFloat v20, v21, v22, v23 - Int32ToFloat v24, v25, v26, v27 - Int32ToFloat v28, v29, v30, v31 - - zip1 v0.4s, v12.4s, v12.4s // scales: batch 0,0,1,1 - zip2 v1.4s, v12.4s, v12.4s // scales: batch 2,2,3,3 - zip1 v2.4s, v13.4s, v13.4s // scales: batch 4,4,5,5 - zip2 v3.4s, v13.4s, v13.4s // scales: batch 6,6,7,7 - trn1 v4.2d, v8.2d, v8.2d // alpha: oc 0,1,0,1 - trn2 v5.2d, v8.2d, 
v8.2d // alpha: oc 2,3,2,3 - trn1 v6.2d, v9.2d, v9.2d // alpha: oc 4,5,4,5 - trn2 v7.2d, v9.2d, v9.2d // alpha: oc 6,7,6,7 - - MulScale_New v16, v17, v18, v19, v0, v4, v5, v6, v7 - MulScale_New v20, v21, v22, v23, v1, v4, v5, v6, v7 - MulScale_New v24, v25, v26, v27, v2, v4, v5, v6, v7 - MulScale_New v28, v29, v30, v31, v3, v4, v5, v6, v7 - Float32ToHalf v16, v17, v18, v19, v0, v1 // (batch,oc) v12:(0,0)(0,1)(1,0)(1,1)(0,2)(0,3)(1,3)(1,2) - Float32ToHalf v20, v21, v22, v23, v12, v13 // batch=2,3 v14:(2,0)(2,1)(3,0)(3,1)(2,2)(2,3)(3,3)(3,2) - Float32ToHalf v24, v25, v26, v27, v14, v15 // batch=4,5 - Float32ToHalf v28, v29, v30, v31, v8, v9 // batch=6,7 - - uzp1 v4.4s, v0.4s, v1.4s - uzp2 v5.4s, v0.4s, v1.4s - uzp1 v6.4s, v12.4s, v13.4s - uzp2 v7.4s, v12.4s, v13.4s - uzp1 v0.4s, v14.4s, v15.4s - uzp2 v1.4s, v14.4s, v15.4s - uzp1 v2.4s, v8.4s, v9.4s - uzp2 v3.4s, v8.4s, v9.4s -Tile8Dequant: - ld1 {v16.8h}, [x20], #16 // zero - ld1 {v17.8h}, [x21], #16 // bias - ld1 {v12.8h}, [x22] // sums - // sum + (zero * sumx) + bias - Dequant v4, v16, v17, v12, 0 - Dequant v5, v16, v17, v12, 1 - Dequant v6, v16, v17, v12, 2 - Dequant v7, v16, v17, v12, 3 - - Dequant v0, v16, v17, v12, 4 - Dequant v1, v16, v17, v12, 5 - Dequant v2, v16, v17, v12, 6 - Dequant v3, v16, v17, v12, 7 - st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x28], #64 - st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_8 -Tile8End: - sub x6, x6, #8 // bach -= 8 - add x0, x0, #128 // dst += 8 * 8 * sizeof(float16_t) - add x1, x1, #64 // src += 8 * 8 * sizeof(int8_t) - add x11, x11, #16 // sum += 8 * sizeof(float16_t) - add x12, x12, #16 // scale += 8 * sizeof(float16_t) - b TILE_8 - -TILE_4: - cmp x6, #4 - blt TILE_2 - mov x14, x4 // dst_step - lsr x15, x4, #1 // src_step = dst_step / 2 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_4: - // dequant info for batch - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - dup v20.4s, wzr - dup v21.4s, wzr - dup v22.4s, wzr - dup v23.4s, wzr - ld1 {v14.d}[0], [x23] // scales - ld1 {v15.8h}, [x19], #16 // alpha -LoopSz_TILE_4: - // src : 2 x [2 x 8] : v4-5 - // weight : 4 x [2 x 8] : v0-3 - // dst : 2 x 4 x [4] : v16-23 - ld1 {v8.16b, v9.16b}, [x25], #32 // weight - ld1 {v4.16b, v5.16b}, [x24], x15 // src - // int4 to int8: v0, v1, v2, v3 - ushr v0.16b, v8.16b, #4 - and v1.16b, v8.16b, v10.16b - ushr v2.16b, v9.16b, #4 - and v3.16b, v9.16b, v10.16b - - .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b - .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b - .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b - .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b - .inst 0x4e80a4b4 // smmla v20.4s, v5.16b, v0.16b - .inst 0x4e81a4b5 // smmla v21.4s, v5.16b, v1.16b - .inst 0x4e82a4b6 // smmla v22.4s, v5.16b, v2.16b - .inst 0x4e83a4b7 // smmla v23.4s, v5.16b, v3.16b - subs x26, x26, #1 - bne LoopSz_TILE_4 - -LoopSzEnd_TILE_4: - add x7, x7, x13 - fcvtl v8.4s, v15.4h // oc:0-3 - fcvtl2 v9.4s, v15.8h // oc:4-7 - fcvtl v12.4s, v14.4h // scales: batch 0,1,2,3 - - sub x27, x27, #1 - Int32ToFloat v16, v17, v18, v19 - Int32ToFloat v20, v21, v22, v23 - - zip1 v0.4s, v12.4s, v12.4s // scales: batch 0,0,1,1 - zip2 v1.4s, v12.4s, v12.4s // scales: batch 2,2,3,3 - trn1 v4.2d, v8.2d, v8.2d // alpha: oc 0,1,0,1 - trn2 v5.2d, v8.2d, 
v8.2d // alpha: oc 2,3,2,3 - trn1 v6.2d, v9.2d, v9.2d // alpha: oc 4,5,4,5 - trn2 v7.2d, v9.2d, v9.2d // alpha: oc 6,7,6,7 - - MulScale_New v16, v17, v18, v19, v0, v4, v5, v6, v7 - MulScale_New v20, v21, v22, v23, v1, v4, v5, v6, v7 - Float32ToHalf v16, v17, v18, v19, v0, v1 // (batch,oc) v12:(0,0)(0,1)(1,0)(1,1)(0,2)(0,3)(1,3)(1,2) - Float32ToHalf v20, v21, v22, v23, v12, v13 // batch=2,3 v14:(2,0)(2,1)(3,0)(3,1)(2,2)(2,3)(3,3)(3,2) - - uzp1 v4.4s, v0.4s, v1.4s - uzp2 v5.4s, v0.4s, v1.4s - uzp1 v6.4s, v12.4s, v13.4s - uzp2 v7.4s, v12.4s, v13.4s -Tile4Dequant: - ld1 {v16.8h}, [x20], #16 // zero - ld1 {v17.8h}, [x21], #16 // bias - ld1 {v12.d}[0], [x22] // sums - // sum + (zero * sumx) + bias - Dequant v4, v16, v17, v12, 0 - Dequant v5, v16, v17, v12, 1 - Dequant v6, v16, v17, v12, 2 - Dequant v7, v16, v17, v12, 3 - st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_4 -Tile4End: - sub x6, x6, #4 // bach -= 4 - add x0, x0, #64 // dst += 4 * 8 * sizeof(float16_t) - add x1, x1, #32 // src += 4 * 8 * sizeof(int8_t) - add x11, x11, #8 // sum += 4 * sizeof(float16_t) - add x12, x12, #8 // scale += 4 * sizeof(float16_t) - b TILE_4 - -TILE_2: - cmp x6, #2 - blt TILE_1 - mov x14, x4 // dst_step - lsr x15, x4, #1 // src_step = dst_step / 2 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_2: - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - ld1 {v14.s}[0], [x23] // scales - ld1 {v15.8h}, [x19], #16 // alpha -LoopSz_TILE_2: - // src : 1 x [2 x 8] : v4 - // weight : 4 x [2 x 8] : v0-3 - // dst : 1 x 4 x [4] : v16-19 - ld1 {v8.16b, v9.16b}, [x25], #32 // weight - ld1 {v4.16b}, [x24], x15 // src - // int4 to int8: v0, v1, v2, v3 - ushr v0.16b, v8.16b, #4 - and v1.16b, v8.16b, v10.16b - ushr v2.16b, v9.16b, #4 - and v3.16b, v9.16b, v10.16b - - .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b - .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b - .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b - .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b - subs x26, x26, #1 - bne LoopSz_TILE_2 - -LoopSzEnd_TILE_2: - add x7, x7, x13 - fcvtl v8.4s, v15.4h // oc:0-3 - fcvtl2 v9.4s, v15.8h // oc:4-7 - fcvtl v12.4s, v14.4h // scales: batch 0,1 - sub x27, x27, #1 - Int32ToFloat v16, v17, v18, v19 - zip1 v0.4s, v12.4s, v12.4s // scales: batch 0,0,1,1 - trn1 v4.2d, v8.2d, v8.2d // alpha: oc 0,1,0,1 - trn2 v5.2d, v8.2d, v8.2d // alpha: oc 2,3,2,3 - trn1 v6.2d, v9.2d, v9.2d // alpha: oc 4,5,4,5 - trn2 v7.2d, v9.2d, v9.2d // alpha: oc 6,7,6,7 - MulScale_New v16, v17, v18, v19, v0, v4, v5, v6, v7 - Float32ToHalf v16, v17, v18, v19, v0, v1 // (batch,oc) v12:(0,0)(0,1)(1,0)(1,1)(0,2)(0,3)(1,3)(1,2) - - uzp1 v4.4s, v0.4s, v1.4s - uzp2 v5.4s, v0.4s, v1.4s - -Tile2Dequant: - ld1 {v16.8h}, [x20], #16 // zero - ld1 {v17.8h}, [x21], #16 // bias - ld1 {v12.s}[0], [x22] // sums - // sum + (zero * sumx) + bias - Dequant v4, v16, v17, v12, 0 - Dequant v5, v16, v17, v12, 1 - st1 {v4.8h, v5.8h}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_2 -Tile2End: - sub x6, x6, #2 // batch -= 2 - add x0, x0, #32 // dst += 2 * 8 * sizeof(float16_t) - add x1, x1, #16 // dst += 2 * 8 * sizeof(int8_t) - add x11, x11, #4 // sum += 2 * sizeof(float16_t) - add x12, x12, #4 // scale += 2 * sizeof(float16_t) - b TILE_2 - - -TILE_1: - cmp x6, #1 - blt End - mov 
x14, x4 // dst_step - lsr x15, x4, #1 // src_step = dst_step / 2 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_1: - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - ld1 {v14.d}[0], [x23] // scales - ld1 {v15.8h}, [x19], #16 // alpha -LoopSz_TILE_1: - // src : 1 x [1 x 8] : v4 - // weight : 4 x [2 x 8] : v0-3 - // dst : 1 x 4 x [2] : v16-v19 - prfm pldl1keep, [x25, #64] // 预取下一次权重数据 - prfm pldl1keep, [x24, x15] // 预取下一次源数据 - ld1 {v8.16b, v9.16b}, [x25], #32 // weight - ld1 {v4.8b}, [x24], x15 // src - // int4 to int8: v0, v1, v2, v3 - ushr v0.16b, v8.16b, #4 - and v1.16b, v8.16b, v10.16b - ushr v2.16b, v9.16b, #4 - and v3.16b, v9.16b, v10.16b - - .inst 0x4e84a410 // smmla v16.4s, v0.16b, v4.16b - .inst 0x4e84a431 // smmla v17.4s, v1.16b, v4.16b - .inst 0x4e84a452 // smmla v18.4s, v2.16b, v4.16b - .inst 0x4e84a473 // smmla v19.4s, v3.16b, v4.16b - subs x26, x26, #1 - bne LoopSz_TILE_1 - -LoopSzEnd_TILE_1: - add x7, x7, x13 - sub x27, x27, #1 - uzp1 v20.4s, v16.4s, v17.4s - uzp1 v21.4s, v18.4s, v19.4s - scvtf v20.4s, v20.4s - scvtf v21.4s, v21.4s - // using float scale dequant for precison - fcvtl v28.4s, v15.4h // oc:0-3 - fcvtl2 v29.4s, v15.8h // oc:4-7 - fcvtl v12.4s, v14.4h // scales: batch 0 - - fmul v20.4s, v20.4s, v12.s[0] - fmul v21.4s, v21.4s, v12.s[0] - fmul v20.4s, v20.4s, v28.4s - fmul v21.4s, v21.4s, v29.4s - fcvtn v17.4h, v20.4s - fcvtn2 v17.8h, v21.4s -Tile1Dequant: - - ld1 {v1.8h}, [x20], #16 // zero - ld1 {v2.8h}, [x21], #16 // bias - ld1 {v3.h}[0], [x22] // sums - // alpha * sum + (zero * sumx) + bias - fadd v2.8h, v2.8h, v17.8h - fmla v2.8h, v1.8h, v3.h[0] - st1 {v2.8h}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_1 -Tile1End: - sub x6, x6, #1 // batch -= 1 - add x0, x0, #16 // dst += 1 * 8 * sizeof(float16_t) - add x1, x1, #8 // dst += 1 * 8 * sizeof(int8_t) - add x11, x11, #2 // sum += 1 * sizeof(float16_t) - add x12, x12, #2 // scale += 1 * sizeof(float16_t) - b TILE_1 - -End: -ldp x27, x28, [sp, #(16 * 8)] -ldp x25, x26, [sp, #(16 * 7)] -ldp x23, x24, [sp, #(16 * 6)] -ldp x19, x20, [sp, #(16 * 5)] -ldp x21, x22, [sp, #(16 * 4)] -ldp d8, d9, [sp, #(16 * 3)] -ldp d10, d11, [sp, #(16 * 2)] -ldp d12, d13, [sp, #(16 * 1)] -ldp d14, d15, [sp], #(16 * 9) -ret - -#endif diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt8FP16_sdot.S b/source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt8FP16_sdot.S deleted file mode 100644 index d675b79e8..000000000 --- a/source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt8FP16_sdot.S +++ /dev/null @@ -1,303 +0,0 @@ -// -// MNNGemmHybridInt8_sdot.S -// MNN -// -// Created by MNN on 2023/11/09. 
-// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef __aarch64__ - -#include "MNNAsmGlobal.h" - -.text -.align 5 - -.macro Int32ToFloat z0, z1, z2, z3 - scvtf \z0\().4s, \z0\().4s - scvtf \z1\().4s, \z1\().4s - scvtf \z2\().4s, \z2\().4s - scvtf \z3\().4s, \z3\().4s -.endm - -.macro MulScale d0, d1, d2, d3, s, idx0, idx1, alpha0, alpha1 - fmul \d0\().4s, \d0\().4s, \s\().s[\idx0] - fmul \d1\().4s, \d1\().4s, \s\().s[\idx0] - fmul \d2\().4s, \d2\().4s, \s\().s[\idx1] - fmul \d3\().4s, \d3\().4s, \s\().s[\idx1] - fmul \d0\().4s, \d0\().4s, \alpha0\().4s - fmul \d1\().4s, \d1\().4s, \alpha1\().4s - fmul \d2\().4s, \d2\().4s, \alpha0\().4s - fmul \d3\().4s, \d3\().4s, \alpha1\().4s -.endm - -.macro Float32ToHalf s0, s1, s2, s3, d0, d1 - fcvtn \d0\().4h, \s0\().4s - fcvtn2 \d0\().8h, \s1\().4s - fcvtn \d1\().4h, \s2\().4s - fcvtn2 \d1\().8h, \s3\().4s -.endm - -.macro Dequant c0, z0, b0, s0, idx - fmla \c0\().8h, \z0\().8h, \s0\().h[\idx] - fadd \c0\().8h, \c0\().8h, \b0\().8h -.endm - -asm_function MNNGemmHybridInt8FP16_sdot - -//struct QuanPostTreatParameters { -// const float* scale; -// const int32_t* bias; -// int32_t maxValue; -// int32_t minValue; -// int32_t useInt8; -//}; - -//void MNNGemmHybridInt8_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param); - - -// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param -// load from param: x7: alpha*, x8: zero*, x9: bias*, x10: sums*, x11: scales* -stp d14, d15, [sp, #(-16 * 9)]! -stp d12, d13, [sp, #(16 * 1)] -stp d10, d11, [sp, #(16 * 2)] -stp d8, d9, [sp, #(16 * 3)] -stp x21, x22, [sp, #(16 * 4)] -stp x19, x20, [sp, #(16 * 5)] -stp x23, x24, [sp, #(16 * 6)] -stp x25, x26, [sp, #(16 * 7)] -stp x27, x28, [sp, #(16 * 8)] - -ldr x8, [x7, #0] -ldr x9, [x7, #8] -ldr x10, [x7, #16] -ldr x11, [x7, #24] -ldr x12, [x7, #32] -ldr x14, [x7, #40] -Start: -lsl x13, x3, #6 // x13 = src_depth_quad * UNIT * UNIT_SRC / 1(int8) = src_depth_quad * 64 = src_depth_quad << 6 -ld1 {v6.16b, v7.16b}, [x14] -TILE_4: - cmp x6, #4 - blt TILE_1 - lsr x15, x4, #1 // src_step = dst_step / 2 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_4: - // dequant info for batch - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - dup v20.4s, wzr - dup v21.4s, wzr - dup v22.4s, wzr - dup v23.4s, wzr - dup v24.4s, wzr - dup v25.4s, wzr - dup v26.4s, wzr - dup v27.4s, wzr - dup v28.4s, wzr - dup v29.4s, wzr - dup v30.4s, wzr - dup v31.4s, wzr -LoopSz_TILE_4: - // v0: oc=0,1 - // v1: oc=2,3 - // v2: oc=4,5 - // v3: oc=6,7 - ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight, oc=0-7 - // v4:n=0,1, v5:n=2,3 - // v10:n=1,0, v11:n=3,2 - ld1 {v4.16b, v5.16b}, [x24], x15 // src batch=0,1,2,3 - mov v10.d[0], v4.d[1] // v10:n=1,0 - mov v10.d[1], v4.d[0] - mov v11.d[1], v5.d[0] - mov v11.d[0], v5.d[1] - .inst 0x4e809490 // sdot v16.4s, v4.16b, v0.16b // (0,0)x2 (1,1)x2 - .inst 0x4e809558 // sdot v24.4s, v10.16b, v0.16b // (1,0)x2 (0,1)x2 - .inst 0x4e819491 // sdot v17.4s, v4.16b, v1.16b // (0,2) (1,3) - .inst 0x4e819559 // sdot v25.4s, v10.16b, v1.16b // (1,2) (0,3) - .inst 0x4e829492 // sdot v18.4s, v4.16b, v2.16b - .inst 0x4e82955a // sdot v26.4s, 
v10.16b, v2.16b - .inst 0x4e839493 // sdot v19.4s, v4.16b, v3.16b - .inst 0x4e83955b // sdot v27.4s, v10.16b, v3.16b - .inst 0x4e8094b4 // sdot v20.4s, v5.16b, v0.16b - .inst 0x4e80957c // sdot v28.4s, v11.16b, v0.16b - .inst 0x4e8194b5 // sdot v21.4s, v5.16b, v1.16b - .inst 0x4e81957d // sdot v29.4s, v11.16b, v1.16b - .inst 0x4e8294b6 // sdot v22.4s, v5.16b, v2.16b - .inst 0x4e82957e // sdot v30.4s, v11.16b, v2.16b - .inst 0x4e8394b7 // sdot v23.4s, v5.16b, v3.16b - .inst 0x4e83957f // sdot v31.4s, v11.16b, v3.16b - - subs x26, x26, #1 - bne LoopSz_TILE_4 - - addp v16.4s, v16.4s, v24.4s // (batch,oc)(0,0)(1,1)(1,0)(0,1) - addp v17.4s, v17.4s, v25.4s // (0,2)(1,3)(1,2)(0,3) - addp v18.4s, v18.4s, v26.4s // (0,4)(1,5)(1,4)(0,5) - addp v19.4s, v19.4s, v27.4s // (0,6)(1,7)(1,6)(0,7) - addp v20.4s, v20.4s, v28.4s - addp v21.4s, v21.4s, v29.4s - addp v22.4s, v22.4s, v30.4s - addp v23.4s, v23.4s, v31.4s - tbl v24.16b, {v16.16b, v17.16b}, v6.16b // batch=0,oc=0-3 - tbl v25.16b, {v16.16b, v17.16b}, v7.16b // batch=1,oc=0-3 - tbl v26.16b, {v18.16b, v19.16b}, v6.16b // batch=0,oc=4-7 - tbl v27.16b, {v18.16b, v19.16b}, v7.16b // batch=1,oc=4-7 - tbl v28.16b, {v20.16b, v21.16b}, v6.16b - tbl v29.16b, {v20.16b, v21.16b}, v7.16b - tbl v30.16b, {v22.16b, v23.16b}, v6.16b - tbl v31.16b, {v22.16b, v23.16b}, v7.16b - -LoopSzEnd_TILE_4: - add x7, x7, x13 - sub x27, x27, #1 - Int32ToFloat v24, v25, v26, v27 - Int32ToFloat v28, v29, v30, v31 - // using float scale dequant for precison - ld1 {v1.d}[0], [x23] // scales 4 batch - ld1 {v2.8h}, [x19], #16 // alpha - - fcvtl v3.4s, v2.4h // oc:0-3 - fcvtl2 v4.4s, v2.8h // oc:4-7 - fcvtl v5.4s, v1.4h // scales: 4 batch - - MulScale v24, v26, v25, v27, v5, 0, 1, v3, v4 - MulScale v28, v30, v29, v31, v5, 2, 3, v3, v4 - Float32ToHalf v24, v26, v25, v27, v12, v13 - Float32ToHalf v28, v30, v29, v31, v14, v15 -Tile4Dequant: - ld1 {v1.8h}, [x20], #16 // zero - ld1 {v2.8h}, [x21], #16 // bias - ld1 {v3.d}[0], [x22] // sums - // sum + (zero * sumx) + bias - Dequant v12, v1, v2, v3, 0 - Dequant v13, v1, v2, v3, 1 - Dequant v14, v1, v2, v3, 2 - Dequant v15, v1, v2, v3, 3 - st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x28], x4 - cmp x27, #1 - bge LoopDz_TILE_4 -Tile4End: - sub x6, x6, #4 // bach -= 4 - add x0, x0, #64 // dst += 4 * 8 * sizeof(float16_t) - add x1, x1, #32 // src += 4 * 8 * sizeof(int8_t) - add x11, x11, #8 // sum += 4 * sizeof(float16_t) - add x12, x12, #8 // scale += 4 * sizeof(float16_t) - b TILE_4 - -TILE_1: - cmp x6, #1 - blt End - lsr x15, x4, #1 // src_step = dst_step / 2 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_1: - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - movi v24.4s, #0 - movi v25.4s, #0 - movi v26.4s, #0 - movi v27.4s, #0 - movi v10.4s, #0 - movi v11.4s, #0 - movi v12.4s, #0 - movi v13.4s, #0 -LoopSz_TILE_1: - // src : 1 x [1 x 8] : v4 - // weight : 4 x [2 x 8] : v0-3 - // dst : 1 x 4 x [2] : v16-v19 - ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight - ld1 {v4.8b}, [x24], x15 // src - mov v29.d[0], v4.d[1] - mov v29.d[1], v4.d[0] - - .inst 0x4e809498 // sdot v24.4s, v4.16b, v0.16b // (0,0)x2 (1,1)x2 - .inst 0x4e8097b9 // sdot v25.4s, v29.16b, v0.16b // (1,0)x2 (0,1)x2 - .inst 0x4e81949a // sdot v26.4s, v4.16b, v1.16b // (0,2)x2 (1,3)x2 - .inst 0x4e8197bb // sdot v27.4s, v29.16b, v1.16b // (1,2)x2 (0,3)x2 - .inst 0x4e82948a // 
sdot v10.4s, v4.16b, v2.16b // (0,4)x2 (1,5)x2 - .inst 0x4e8297ab // sdot v11.4s, v29.16b, v2.16b // (1,4)x2 (0,5)x2 - .inst 0x4e83948c // sdot v12.4s, v4.16b, v3.16b // (0,6)x2 (1,7)x2 - .inst 0x4e8397ad // sdot v13.4s, v29.16b, v3.16b // (1,6)x2 (0,7)x2 - - subs x26, x26, #1 - bne LoopSz_TILE_1 - addp v16.4s, v24.4s, v25.4s - addp v17.4s, v26.4s, v27.4s - addp v18.4s, v10.4s, v11.4s - addp v19.4s, v12.4s, v13.4s - -LoopSzEnd_TILE_1: - add x7, x7, x13 - sub x27, x27, #1 - tbl v15.16b, {v16.16b, v17.16b}, v6.16b - tbl v20.16b, {v18.16b, v19.16b}, v6.16b - - scvtf v15.4s, v15.4s - scvtf v20.4s, v20.4s - // using float scale dequant for precison - ld1 {v4.h}[0], [x23] // scales - ld1 {v0.8h}, [x19], #16 // alpha - fcvtl v5.4s, v4.4h - fcvtl v22.4s, v0.4h - fcvtl2 v21.4s, v0.8h - fmul v15.4s, v15.4s, v5.s[0] - fmul v20.4s, v20.4s, v5.s[0] - fmul v15.4s, v15.4s, v22.4s - fmul v20.4s, v20.4s, v21.4s - fcvtn v17.4h, v15.4s - fcvtn2 v17.8h, v20.4s -Tile1Dequant: - ld1 {v1.8h}, [x20], #16 // zero - ld1 {v2.8h}, [x21], #16 // bias - ld1 {v3.h}[0], [x22] // sums - // sum + (zero * sumx) + bias - fadd v2.8h, v2.8h, v17.8h - fmla v2.8h, v1.8h, v3.h[0] - st1 {v2.8h}, [x28], x4 - cmp x27, #1 - bge LoopDz_TILE_1 -Tile1End: - sub x6, x6, #1 // batch -= 1 - add x0, x0, #16 // dst += 1 * 8 * sizeof(float16_t) - add x1, x1, #8 // dst += 1 * 8 * sizeof(int8_t) - add x11, x11, #2 // sum += 1 * sizeof(float16_t) - add x12, x12, #2 // scale += 1 * sizeof(float16_t) - b TILE_1 - -End: -ldp x27, x28, [sp, #(16 * 8)] -ldp x25, x26, [sp, #(16 * 7)] -ldp x23, x24, [sp, #(16 * 6)] -ldp x19, x20, [sp, #(16 * 5)] -ldp x21, x22, [sp, #(16 * 4)] -ldp d8, d9, [sp, #(16 * 3)] -ldp d10, d11, [sp, #(16 * 2)] -ldp d12, d13, [sp, #(16 * 1)] -ldp d14, d15, [sp], #(16 * 9) -ret - -#endif diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt8FP16_smmla.S b/source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt8FP16_smmla.S deleted file mode 100644 index 5f339725c..000000000 --- a/source/backend/arm82/asm/arm64/low_memory/MNNGemmHybridInt8FP16_smmla.S +++ /dev/null @@ -1,566 +0,0 @@ -// -// MNNGemmHybridInt8_smmla.S -// MNN -// -// Created by MNN on 2023/11/09. 
-// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef __aarch64__ - -#include "MNNAsmGlobal.h" - -.text -.align 5 - -.macro Int32ToFloat z0, z1, z2, z3 - scvtf \z0\().4s, \z0\().4s - scvtf \z1\().4s, \z1\().4s - scvtf \z2\().4s, \z2\().4s - scvtf \z3\().4s, \z3\().4s -.endm - -.macro MulScale d0, d1, d2, d3, s, idx0, idx1, alpha0, alpha1 - fmul \d0\().4s, \d0\().4s, \s\().s[\idx0] - fmul \d1\().4s, \d1\().4s, \s\().s[\idx0] - fmul \d2\().4s, \d2\().4s, \s\().s[\idx1] - fmul \d3\().4s, \d3\().4s, \s\().s[\idx1] - fmul \d0\().4s, \d0\().4s, \alpha0\().4s - fmul \d1\().4s, \d1\().4s, \alpha1\().4s - fmul \d2\().4s, \d2\().4s, \alpha0\().4s - fmul \d3\().4s, \d3\().4s, \alpha1\().4s -.endm - -.macro Float32ToHalf s0, s1, s2, s3, d0, d1 - fcvtn \d0\().4h, \s0\().4s - fcvtn2 \d0\().8h, \s1\().4s - fcvtn \d1\().4h, \s2\().4s - fcvtn2 \d1\().8h, \s3\().4s -.endm - -.macro Dequant c0, z0, b0, s0, idx - fmla \c0\().8h, \z0\().8h, \s0\().h[\idx] - fadd \c0\().8h, \c0\().8h, \b0\().8h -.endm - -asm_function MNNGemmHybridInt8FP16_smmla - -//struct QuanPostTreatParameters { -// const float* scale; -// const int32_t* bias; -// int32_t maxValue; -// int32_t minValue; -// int32_t useInt8; -//}; - -//void MNNGemmHybridInt8_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param); - - -// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param -// load from param: x7: alpha*, x8: zero*, x9: bias*, x10: sums*, x11: scales* -stp d14, d15, [sp, #(-16 * 9)]! -stp d12, d13, [sp, #(16 * 1)] -stp d10, d11, [sp, #(16 * 2)] -stp d8, d9, [sp, #(16 * 3)] -stp x21, x22, [sp, #(16 * 4)] -stp x19, x20, [sp, #(16 * 5)] -stp x23, x24, [sp, #(16 * 6)] -stp x25, x26, [sp, #(16 * 7)] -stp x27, x28, [sp, #(16 * 8)] - -ldr x8, [x7, #0] -ldr x9, [x7, #8] -ldr x10, [x7, #16] -ldr x11, [x7, #24] -ldr x12, [x7, #32] - -Start: -lsl x13, x3, #6 // x13 = src_depth_quad * UNIT * UNIT_SRC / 1(int8) = src_depth_quad * 64 = src_depth_quad << 6 -cmp x6, #1 -beq TILE_EQ_1 - -TILE_8: - cmp x6, #8 - blt TILE_4 - //mov x14, x4 // dst_step - lsr x15, x4, #1 // src_step = dst_step / 2 - sub x14, x4, #64 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_8: - // dequant info for batch - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - dup v20.4s, wzr - dup v21.4s, wzr - dup v22.4s, wzr - dup v23.4s, wzr - dup v24.4s, wzr - dup v25.4s, wzr - dup v26.4s, wzr - dup v27.4s, wzr - dup v28.4s, wzr - dup v29.4s, wzr - dup v30.4s, wzr - dup v31.4s, wzr -LoopSz_TILE_8: - // src : 2 x [2 x 8] : v4-5 - // weight : 4 x [2 x 8] : v0-3 - // dst : 2 x 4 x [4] : v16-23 - ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight - ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x24], x15 // src - .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b // batch=0,1, oc=0,1 - .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b // batch=0,1, oc=2,3 - .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b // batch=0,1, oc=4,5 - .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b // batch=0,1, oc=6,7 - .inst 0x4e80a4b4 // smmla v20.4s, v5.16b, v0.16b // batch=2,3, oc=0,1 - .inst 0x4e81a4b5 // smmla v21.4s, v5.16b, v1.16b // batch=2,3, oc=2,3 - .inst 0x4e82a4b6 // smmla 
v22.4s, v5.16b, v2.16b // batch=2,3, oc=4,5 - .inst 0x4e83a4b7 // smmla v23.4s, v5.16b, v3.16b // batch=2,3, oc=6,7 - - .inst 0x4e80a4d8 // smmla v24.4s, v6.16b, v0.16b // batch=4,5, oc=0,1 - .inst 0x4e81a4d9 // smmla v25.4s, v6.16b, v1.16b // batch=4,5, oc=2,3 - .inst 0x4e82a4da // smmla v26.4s, v6.16b, v2.16b // batch=4,5, oc=4,5 - .inst 0x4e83a4db // smmla v27.4s, v6.16b, v3.16b // batch=4,5, oc=6,7 - .inst 0x4e80a4fc // smmla v28.4s, v7.16b, v0.16b // batch=6,7, oc=0,1 - .inst 0x4e81a4fd // smmla v29.4s, v7.16b, v1.16b // batch=6,7, oc=2,3 - .inst 0x4e82a4fe // smmla v30.4s, v7.16b, v2.16b // batch=6,7, oc=4,5 - .inst 0x4e83a4ff // smmla v31.4s, v7.16b, v3.16b // batch=6,7, oc=6,7 - subs x26, x26, #1 - bne LoopSz_TILE_8 - -LoopSzEnd_TILE_8: - add x7, x7, x13 - sub x27, x27, #1 - Int32ToFloat v16, v17, v18, v19 - Int32ToFloat v20, v21, v22, v23 - Int32ToFloat v24, v25, v26, v27 - Int32ToFloat v28, v29, v30, v31 - // using float scale dequant for precison - trn1 v8.2d, v16.2d, v17.2d // batch=0,oc:0-3 - trn1 v9.2d, v18.2d, v19.2d // batch=0,oc:4-7 - trn2 v10.2d, v16.2d, v17.2d // batch=1,oc:0-3 - trn2 v11.2d, v18.2d, v19.2d // batch=1,oc:4-7 - trn1 v12.2d, v20.2d, v21.2d // batch=2,oc:0-3 - trn1 v13.2d, v22.2d, v23.2d // batch=2,oc:4-7 - trn2 v14.2d, v20.2d, v21.2d // batch=3,oc:0-3 - trn2 v15.2d, v22.2d, v23.2d // batch=3,oc:4-7 - - trn1 v0.2d, v24.2d, v25.2d // batch=4,oc:0-3 - trn1 v1.2d, v26.2d, v27.2d // batch=4,oc:4-7 - trn2 v2.2d, v24.2d, v25.2d // batch=5,oc:0-3 - trn2 v3.2d, v26.2d, v27.2d // batch=5,oc:4-7 - trn1 v4.2d, v28.2d, v29.2d // batch=6,oc:0-3 - trn1 v5.2d, v30.2d, v31.2d // batch=6,oc:4-7 - trn2 v6.2d, v28.2d, v29.2d // batch=7,oc:0-3 - trn2 v7.2d, v30.2d, v31.2d // batch=7,oc:4-7 - - ld1 {v16.8h}, [x23] // scales - ld1 {v17.8h}, [x19], #16 // alpha - fcvtl v18.4s, v17.4h // oc:0-3 - fcvtl2 v19.4s, v17.8h // oc:4-7 - fcvtl v28.4s, v16.4h // scales: batch 0,1,2,3 - fcvtl2 v29.4s, v16.8h // scales: batch 4,5,6,7 - - MulScale v8, v9, v10, v11, v28, 0, 1, v18, v19 - MulScale v12, v13, v14, v15, v28, 2, 3, v18, v19 - Float32ToHalf v8, v9, v10, v11, v20, v21 // batch=0,1 - Float32ToHalf v12, v13, v14, v15, v22, v23 // batch=2,3 - - MulScale v0, v1, v2, v3, v29, 0, 1, v18, v19 - MulScale v4, v5, v6, v7, v29, 2, 3, v18, v19 - Float32ToHalf v0, v1, v2, v3, v24, v25 // batch=4,5 - Float32ToHalf v4, v5, v6, v7, v26, v27 // batch=6,7 - -Tile8Dequant: - ld1 {v1.8h}, [x20], #16 // zero - ld1 {v2.8h}, [x21], #16 // bias - ld1 {v3.8h}, [x22] // sums - // sum + (zero * sumx) + bias - Dequant v20, v1, v2, v3, 0 - Dequant v21, v1, v2, v3, 1 - Dequant v22, v1, v2, v3, 2 - Dequant v23, v1, v2, v3, 3 - - Dequant v24, v1, v2, v3, 4 - Dequant v25, v1, v2, v3, 5 - Dequant v26, v1, v2, v3, 6 - Dequant v27, v1, v2, v3, 7 - st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x28], #64 - st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_8 -Tile8End: - sub x6, x6, #8 // bach -= 8 - add x0, x0, #128 // dst += 8 * 8 * sizeof(float16_t) - add x1, x1, #64 // src += 8 * 8 * sizeof(int8_t) - add x11, x11, #16 // sum += 8 * sizeof(float16_t) - add x12, x12, #16 // scale += 8 * sizeof(float16_t) - b TILE_8 - -TILE_4: - cmp x6, #4 - blt TILE_2 - mov x14, x4 // dst_step - lsr x15, x4, #1 // src_step = dst_step / 2 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_4: - // dequant info for batch - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov 
x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - dup v20.4s, wzr - dup v21.4s, wzr - dup v22.4s, wzr - dup v23.4s, wzr -LoopSz_TILE_4: - // src : 2 x [2 x 8] : v4-5 - // weight : 4 x [2 x 8] : v0-3 - // dst : 2 x 4 x [4] : v16-23 - ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight - ld1 {v4.16b, v5.16b}, [x24], x15 // src - .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b // batch=0,1, oc=0,1 - .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b // batch=0,1, oc=2,3 - .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b // batch=0,1, oc=4,5 - .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b // batch=0,1, oc=6,7 - .inst 0x4e80a4b4 // smmla v20.4s, v5.16b, v0.16b // batch=2,3, oc=0,1 - .inst 0x4e81a4b5 // smmla v21.4s, v5.16b, v1.16b // batch=2,3, oc=2,3 - .inst 0x4e82a4b6 // smmla v22.4s, v5.16b, v2.16b // batch=2,3, oc=4,5 - .inst 0x4e83a4b7 // smmla v23.4s, v5.16b, v3.16b // batch=2,3, oc=6,7 - subs x26, x26, #1 - bne LoopSz_TILE_4 - -LoopSzEnd_TILE_4: - add x7, x7, x13 - sub x27, x27, #1 - Int32ToFloat v16, v17, v18, v19 - Int32ToFloat v20, v21, v22, v23 - // using float scale dequant for precison - ld1 {v4.d}[0], [x23] // scales - ld1 {v31.8h}, [x19], #16 // alpha - fcvtl v29.4s, v31.4h // oc:0-3 - fcvtl2 v30.4s, v31.8h // oc:4-7 - trn1 v24.2d, v16.2d, v17.2d // batch=0,oc:0-3 - trn1 v25.2d, v18.2d, v19.2d // batch=0,oc:4-7 - trn2 v26.2d, v16.2d, v17.2d // batch=1,oc:0-3 - trn2 v27.2d, v18.2d, v19.2d // batch=1,oc:4-7 - trn1 v28.2d, v20.2d, v21.2d // batch=2,oc:0-3 - trn1 v6.2d, v22.2d, v23.2d // batch=2,oc:4-7 - trn2 v7.2d, v20.2d, v21.2d // batch=3,oc:0-3 - trn2 v8.2d, v22.2d, v23.2d // batch=3,oc:4-7 - - fcvtl v5.4s, v4.4h // scales: 4 batch - - MulScale v24, v25, v26, v27, v5, 0, 1, v29, v30 - MulScale v28, v6, v7, v8, v5, 2, 3, v29, v30 - Float32ToHalf v24, v25, v26, v27, v12, v13 - Float32ToHalf v28, v6, v7, v8, v14, v15 -Tile4Dequant: - ld1 {v1.8h}, [x20], #16 // zero - ld1 {v2.8h}, [x21], #16 // bias - ld1 {v3.d}[0], [x22] // sums - // sum + (zero * sumx) + bias - Dequant v12, v1, v2, v3, 0 - Dequant v13, v1, v2, v3, 1 - Dequant v14, v1, v2, v3, 2 - Dequant v15, v1, v2, v3, 3 - st1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_4 -Tile4End: - sub x6, x6, #4 // bach -= 4 - add x0, x0, #64 // dst += 4 * 8 * sizeof(float16_t) - add x1, x1, #32 // src += 4 * 8 * sizeof(int8_t) - add x11, x11, #8 // sum += 4 * sizeof(float16_t) - add x12, x12, #8 // scale += 4 * sizeof(float16_t) - b TILE_4 - -TILE_2: - cmp x6, #2 - blt TILE_1 - mov x14, x4 // dst_step - lsr x15, x4, #1 // src_step = dst_step / 2 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_2: - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr -LoopSz_TILE_2: - // src : 1 x [2 x 8] : v4 - // weight : 4 x [2 x 8] : v0-3 - // dst : 1 x 4 x [4] : v16-19 - ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight - ld1 {v4.16b}, [x24], x15 // src - .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b - .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b - .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b - .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b - subs x26, x26, #1 - bne LoopSz_TILE_2 - -LoopSzEnd_TILE_2: - add x7, x7, x13 - sub x27, x27, 
#1 - uzp1 v13.2d, v16.2d, v17.2d - uzp1 v14.2d, v18.2d, v19.2d - uzp2 v15.2d, v16.2d, v17.2d - uzp2 v16.2d, v18.2d, v19.2d - Int32ToFloat v13, v14, v15, v16 - // using float scale dequant for precison - ld1 {v4.s}[0], [x23] // scales - ld1 {v0.8h}, [x19], #16 // alpha - fcvtl v5.4s, v4.4h - fcvtl v20.4s, v0.4h - fcvtl2 v21.4s, v0.8h - MulScale v13, v14, v15, v16, v5, 0, 1, v20, v21 - fcvtn v11.4h, v13.4s - fcvtn2 v11.8h, v14.4s - fcvtn v12.4h, v15.4s - fcvtn2 v12.8h, v16.4s -Tile2Dequant: - //ld1 {v0.8h}, [x19], #16 // alpha - ld1 {v1.8h}, [x20], #16 // zero - ld1 {v2.8h}, [x21], #16 // bias - ld1 {v3.s}[0], [x22] // sums - // alpha * sum + (zero * sumx) + bias - Dequant v11, v1, v2, v3, 0 - Dequant v12, v1, v2, v3, 1 - st1 {v11.8h, v12.8h}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_2 -Tile2End: - sub x6, x6, #2 // batch -= 2 - add x0, x0, #32 // dst += 2 * 8 * sizeof(float16_t) - add x1, x1, #16 // dst += 2 * 8 * sizeof(int8_t) - add x11, x11, #4 // sum += 2 * sizeof(float16_t) - add x12, x12, #4 // scale += 2 * sizeof(float16_t) - b TILE_2 - - -TILE_1: - cmp x6, #1 - blt End - mov x14, x4 // dst_step - lsr x15, x4, #1 // src_step = dst_step / 2 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_1: - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - ld1 {v29.8h}, [x20], #16 // zero - ld1 {v30.8h}, [x21], #16 // bias - ld1 {v8.h}[0], [x22] // sums - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - fmla v30.8h, v29.8h, v8.h[0] // bias + zero * sum - -LoopSz_TILE_1: - // src : 1 x [1 x 8] : v4 - // weight : 4 x [2 x 8] : v0-3 - // dst : 1 x 4 x [2] : v16-v19 - ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight - ld1 {v4.8b}, [x24], x15 // src - .inst 0x4e84a410 // smmla v16.4s, v0.16b, v4.16b - .inst 0x4e84a431 // smmla v17.4s, v1.16b, v4.16b - .inst 0x4e84a452 // smmla v18.4s, v2.16b, v4.16b - .inst 0x4e84a473 // smmla v19.4s, v3.16b, v4.16b - - subs x26, x26, #1 - bne LoopSz_TILE_1 - -LoopSzEnd_TILE_1: - add x7, x7, x13 - sub x27, x27, #1 - uzp1 v22.4s, v16.4s, v17.4s - uzp1 v23.4s, v18.4s, v19.4s - scvtf v22.4s, v22.4s - scvtf v23.4s, v23.4s - // using float scale dequant for precison - ld1 {v4.h}[0], [x23] // scales - ld1 {v0.8h}, [x19], #16 // alpha - fcvtl v5.4s, v4.4h - fcvtl v20.4s, v0.4h - fcvtl2 v21.4s, v0.8h - - fmul v22.4s, v22.4s, v5.s[0] - fmul v23.4s, v23.4s, v5.s[0] - fmul v22.4s, v22.4s, v20.4s - fmul v23.4s, v23.4s, v21.4s - fcvtn v17.4h, v22.4s - fcvtn2 v17.8h, v23.4s -Tile1Dequant: - // sum + (zero * sumx) + bias - fadd v30.8h, v30.8h, v17.8h - st1 {v30.8h}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_1 -Tile1End: - sub x6, x6, #1 // batch -= 1 - add x0, x0, #16 // dst += 1 * 8 * sizeof(float16_t) - add x1, x1, #8 // dst += 1 * 8 * sizeof(int8_t) - add x11, x11, #2 // sum += 1 * sizeof(float16_t) - add x12, x12, #2 // scale += 1 * sizeof(float16_t) - b TILE_1 -b End -TILE_EQ_1: - - mov x14, x4 // dst_step - lsr x15, x4, #1 // src_step = dst_step / 2 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz: - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - ld1 {v29.8h}, [x20], #16 // zero - ld1 {v30.8h}, [x21], #16 // bias - ld1 {v8.h}[0], [x22] // 
sums - // init - dup v14.4s, wzr - dup v15.4s, wzr - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - dup v20.4s, wzr - dup v21.4s, wzr - fmla v30.8h, v29.8h, v8.h[0] // bias + zero * sum - - -L2: -cmp x26, #2 -blt L1 -LoopSz_2: - ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight - ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x25], #64 - ld1 {v8.16b}, [x24], #16 // src - sub x26, x26, #2 - - .inst 0x4e80a50e // smmla v14.4s, v8.16b, v0.16b // (N=0,OC=0) (N=0,OC=1) () () - .inst 0x4e81a50f // smmla v15.4s, v8.16b, v1.16b // (N=0,OC=2) (N=0,OC=3) () () - .inst 0x4e82a510 // smmla v16.4s, v8.16b, v2.16b // (N=0,OC=4) (N=0,OC=5) () () - .inst 0x4e83a511 // smmla v17.4s, v8.16b, v3.16b // (N=0,OC=6) (N=0,OC=7) () () - .inst 0x4e84a512 // smmla v18.4s, v8.16b, v4.16b - .inst 0x4e85a513 // smmla v19.4s, v8.16b, v5.16b - .inst 0x4e86a514 // smmla v20.4s, v8.16b, v6.16b - .inst 0x4e87a515 // smmla v21.4s, v8.16b, v7.16b - cmp x26, #2 - bge LoopSz_2 -L1: -cmp x26, #1 -blt LoopSzEnd -LoopSz_1: - // src : 1 x [1 x 8] : v4 - // weight : 4 x [2 x 8] : v0-3 - ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight - ld1 {v4.8b}, [x24], x15 // src - .inst 0x4e80a48e // smmla v14.4s, v4.16b, v0.16b - .inst 0x4e81a48f // smmla v15.4s, v4.16b, v1.16b - .inst 0x4e82a490 // smmla v16.4s, v4.16b, v2.16b - .inst 0x4e83a491 // smmla v17.4s, v4.16b, v3.16b - - subs x26, x26, #1 - bne LoopSz_1 - -LoopSzEnd: - add x7, x7, x13 - sub x27, x27, #1 - - trn1 v26.2d, v14.2d, v15.2d - trn1 v27.2d, v16.2d, v17.2d - trn2 v28.2d, v18.2d, v19.2d - trn2 v29.2d, v20.2d, v21.2d - add v26.4s, v26.4s, v28.4s - add v27.4s, v27.4s, v29.4s - scvtf v26.4s, v26.4s - scvtf v27.4s, v27.4s - // using float scale dequant for precison - ld1 {v4.h}[0], [x23] // scales - ld1 {v0.8h}, [x19], #16 // alpha - fcvtl v5.4s, v4.4h - fcvtl v20.4s, v0.4h - fcvtl2 v21.4s, v0.8h - - fmul v26.4s, v26.4s, v5.s[0] - fmul v27.4s, v27.4s, v5.s[0] - fmul v26.4s, v26.4s, v20.4s - fmul v27.4s, v27.4s, v21.4s - fcvtn v17.4h, v26.4s - fcvtn2 v17.8h, v27.4s -Int8ToFP16: - // sum + (zero * sumx) + bias - fadd v30.8h, v30.8h, v17.8h - st1 {v30.8h}, [x28], x14 - cmp x27, #1 - bge LoopDz - -End: -ldp x27, x28, [sp, #(16 * 8)] -ldp x25, x26, [sp, #(16 * 7)] -ldp x23, x24, [sp, #(16 * 6)] -ldp x19, x20, [sp, #(16 * 5)] -ldp x21, x22, [sp, #(16 * 4)] -ldp d8, d9, [sp, #(16 * 3)] -ldp d10, d11, [sp, #(16 * 2)] -ldp d12, d13, [sp, #(16 * 1)] -ldp d14, d15, [sp], #(16 * 9) -ret - -#endif \ No newline at end of file diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16.S b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16.S new file mode 100644 index 000000000..143ec060a --- /dev/null +++ b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16.S @@ -0,0 +1,665 @@ +// +// MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16.S +// MNN +// +// Created by MNN on 2019/12/17. 
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#if defined(__aarch64__) +#include "MNNAsmGlobal.h" + +.text +.align 5 + +.macro ADD_BIAS_FLOAT d0, d1, d2, d3, z0 + fadd \d0\().4s, \d0\().4s, \z0\().4s + fadd \d1\().4s, \d1\().4s, \z0\().4s + fadd \d2\().4s, \d2\().4s, \z0\().4s + fadd \d3\().4s, \d3\().4s, \z0\().4s +.endm + +.macro SET_BIAS d0, d1, d2, d3 + movi \d0\().16b, #0 + movi \d1\().16b, #0 + movi \d2\().16b, #0 + movi \d3\().16b, #0 +.endm +.macro Int32ToFloat z0, z1, z2, z3 + scvtf \z0\().4s, \z0\().4s + scvtf \z1\().4s, \z1\().4s + scvtf \z2\().4s, \z2\().4s + scvtf \z3\().4s, \z3\().4s +.endm +.macro MUL_SCALE s, d0, d1, d2, d3 + fmul \d0\().4s, \d0\().4s, \s\().4s + fmul \d1\().4s, \d1\().4s, \s\().4s + fmul \d2\().4s, \d2\().4s, \s\().4s + fmul \d3\().4s, \d3\().4s, \s\().4s +.endm +.macro MLA_WEIGHTZERO d0, s0, s1, idx // idx for xKernelSum + fmla \d0\().4s, \s1\().4s, \s0\().s[\idx] +.endm +.macro MUL_EXTRA_SCALE s, d0, d1, d2, d3 + fmul \d0\().4s, \d0\().4s, \s\().s[0] + fmul \d1\().4s, \d1\().4s, \s\().s[1] + fmul \d2\().4s, \d2\().4s, \s\().s[2] + fmul \d3\().4s, \d3\().4s, \s\().s[3] +.endm +.macro ReLU_FP16 s0, s1, s2, s3, z0, z1 // z0:min z1:max + fmin \s0\().8h, \s0\().8h, \z1\().8h + fmin \s1\().8h, \s1\().8h, \z1\().8h + fmin \s2\().8h, \s2\().8h, \z1\().8h + fmin \s3\().8h, \s3\().8h, \z1\().8h + fmax \s0\().8h, \s0\().8h, \z0\().8h + fmax \s1\().8h, \s1\().8h, \z0\().8h + fmax \s2\().8h, \s2\().8h, \z0\().8h + fmax \s3\().8h, \s3\().8h, \z0\().8h +.endm + +.macro Float32ToHalf s0, s1, s2, s3, d0, d1 + fcvtn \d0\().4h, \s0\().4s + fcvtn2 \d0\().8h, \s1\().4s + fcvtn \d1\().4h, \s2\().4s + fcvtn2 \d1\().8h, \s3\().4s +.endm + +asm_function MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16 +/* +struct QuanPostTreatParameters { + const float* scale; + const float* biasFloat; + int32_t maxValue; + int32_t minValue; + int32_t useInt8 = 1; // Save result as int8_t dataType; otherwise float32. + float roundValuePos = 0.5f; + float roundValueNeg = -0.5f; + float* srcKernelSum; + float* weightQuanBias; + float* fp32minmax; + ssize_t blockNum = 1; + const int32_t* bias; + const float* extraScale = nullptr; +}; +*/ + +//void MNNGemmInt8AddBiasScale_ARMV82_Unit(int8_t* dst, const int8_t* src, +// const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, +// const QuanPostTreatParameters* parameters, size_t realDstCount); + +//Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step +//x5:dst_depth_quad, x6: parameters, x7: realDstCount + +//Load from x6: x8: scale, x9: bias, x25: xKernelSum, x26: weightQuantBias, x23: fp32minmax, x27: blockNum +ldr x8, [x6, #0] +ldr x9, [x6, #8] +//ldr w12, [x6, #16] + +stp d14, d15, [sp, #(-16 * 10)]! 
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] +stp x21, x22, [sp, #(16 * 4)] +stp x19, x20, [sp, #(16 * 5)] +stp x27, x28, [sp, #(16 * 6)] +stp x25, x26, [sp, #(16 * 7)] +stp x23, x24, [sp, #(16 * 8)] + +ldr x25, [x6, #40] // xKernelSum +ldr x26, [x6, #48] // weightQuantBias +ldr x23, [x6, #56] // fp32minmax +ldr x27, [x6, #64] // blockNum + +//add x24, x23, #4 + +mov x21, #16 // sizeof(float16_t) * PACK +mul x27, x27, x3 +Start: +lsl x15, x27, #4 // x15 = src_depth_quad * UNIT * SRC_UNIT +mov x22, #48 // src_steps +add x24, x15, x15 +ldr x27, [x6, #80] // extra scale +TILE_12: + cmp x7, #12 + blt TILE_8 + +L8LoopDz_TILE_12: + // ld1 {v0.4s, v1.4s}, [x9], #32 // bias + mov x11, x1 + mov x13, x3 + // Init 0 + SET_BIAS v8, v9, v10, v11 + SET_BIAS v12, v13, v14, v15 + SET_BIAS v16, v17, v18, v19 + SET_BIAS v20, v21, v22, v23 + SET_BIAS v24, v25, v26, v27 + SET_BIAS v28, v29, v30, v31 + + mov x28, x2 + L8LoopSz_TILE_12: + ld1 {v3.16b}, [x2], x15 // weight + ld1 {v0.16b, v1.16b, v2.16b}, [x11], #48 // src + .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] + .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] + .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] + ld1 {v4.16b}, [x2], #16 + .inst 0x4f81e06c // sdot v12.4s, v3.16b, v1.4b[0] + .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] + .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] + .inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3] + .inst 0x4f82e070 // sdot v16.4s, v3.16b, v2.4b[0] + .inst 0x4fa2e071 // sdot v17.4s, v3.16b, v2.4b[1] + .inst 0x4f82e872 // sdot v18.4s, v3.16b, v2.4b[2] + .inst 0x4fa2e873 // sdot v19.4s, v3.16b, v2.4b[3] + .inst 0x4f80e094 // sdot v20.4s, v4.16b, v0.4b[0] + .inst 0x4fa0e095 // sdot v21.4s, v4.16b, v0.4b[1] + .inst 0x4f80e896 // sdot v22.4s, v4.16b, v0.4b[2] + .inst 0x4fa0e897 // sdot v23.4s, v4.16b, v0.4b[3] + sub x2, x2, x15 + .inst 0x4f81e098 // sdot v24.4s, v4.16b, v1.4b[0] + .inst 0x4fa1e099 // sdot v25.4s, v4.16b, v1.4b[1] + .inst 0x4f81e89a // sdot v26.4s, v4.16b, v1.4b[2] + .inst 0x4fa1e89b // sdot v27.4s, v4.16b, v1.4b[3] + subs x13, x13, #1 + .inst 0x4f82e09c // sdot v28.4s, v4.16b, v2.4b[0] + .inst 0x4fa2e09d // sdot v29.4s, v4.16b, v2.4b[1] + .inst 0x4f82e89e // sdot v30.4s, v4.16b, v2.4b[2] + .inst 0x4fa2e89f // sdot v31.4s, v4.16b, v2.4b[3] + bne L8LoopSz_TILE_12 + + L8LoopSzEnd_TILE_12: + //add x2, x2, x15 + //add x24, x15, x15 + add x2, x28, x24 + sub x5, x5, #1 + + L8Tile12Quan: + ld1 {v0.4s, v1.4s}, [x8], #32 // scale + ld1 {v2.4s, v3.4s, v4.4s}, [x25] // x kernel sum + ld1 {v5.4s, v6.4s}, [x26], #32 // weight quan zeropoint + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 + Int32ToFloat v16, v17, v18, v19 + Int32ToFloat v20, v21, v22, v23 + Int32ToFloat v24, v25, v26, v27 + Int32ToFloat v28, v29, v30, v31 + + MUL_SCALE v0, v8, v9, v10, v11 + MUL_SCALE v0, v12, v13, v14, v15 + MUL_SCALE v0, v16, v17, v18, v19 + MUL_SCALE v1, v20, v21, v22, v23 + MUL_SCALE v1, v24, v25, v26, v27 + MUL_SCALE v1, v28, v29, v30, v31 + + cbz x27, TILE12_L8_MLA_TERM + ld1 {v0.4s, v1.4s}, [x27], #32 + ld1 {v7.4s}, [x27] + MUL_EXTRA_SCALE v0, v8, v9, v10, v11 + MUL_EXTRA_SCALE v1, v12, v13, v14, v15 + MUL_EXTRA_SCALE v7, v16, v17, v18, v19 + MUL_EXTRA_SCALE v0, v20, v21, v22, v23 + MUL_EXTRA_SCALE v1, v24, v25, v26, v27 + MUL_EXTRA_SCALE v7, v28, v29, v30, v31 + sub x27, x27, #32 + + TILE12_L8_MLA_TERM: + MLA_WEIGHTZERO v8, v2, v5, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v5, 1 // tile:1, oc:0-3 + 
MLA_WEIGHTZERO v10, v2, v5, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v2, v5, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v3, v5, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v13, v3, v5, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v14, v3, v5, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v15, v3, v5, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v16, v4, v5, 0 // tile:8, oc:0-3 + MLA_WEIGHTZERO v17, v4, v5, 1 // tile:9, oc:0-3 + MLA_WEIGHTZERO v18, v4, v5, 2 // tile:10, oc:0-3 + MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3 + + //ld1r {v0.4s}, [x23] // f32 min + //ld1r {v1.4s}, [x24] // f32 max + MLA_WEIGHTZERO v20, v2, v6, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v21, v2, v6, 1 // tile:1, oc:4-7 + MLA_WEIGHTZERO v22, v2, v6, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v23, v2, v6, 3 // tile:3, oc:4-7 + MLA_WEIGHTZERO v24, v3, v6, 0 // tile:4, oc:4-7 + MLA_WEIGHTZERO v25, v3, v6, 1 // tile:5, oc:4-7 + MLA_WEIGHTZERO v26, v3, v6, 2 // tile:6, oc:4-7 + MLA_WEIGHTZERO v27, v3, v6, 3 // tile:7, oc:4-7 + MLA_WEIGHTZERO v28, v4, v6, 0 // tile:8, oc:4-7 + MLA_WEIGHTZERO v29, v4, v6, 1 // tile:9, oc:4-7 + MLA_WEIGHTZERO v30, v4, v6, 2 // tile:10, oc:4-7 + MLA_WEIGHTZERO v31, v4, v6, 3 // tile:11, oc:4-7 + sub x4, x4, #128 + + cbz x9, TILE12_ADD_DSTV + TILE12_ADD_BIAS: + ld1 {v0.4s, v1.4s}, [x9], #32 + ADD_BIAS_FLOAT v8, v9, v10, v11, v0 + ADD_BIAS_FLOAT v12, v13, v14, v15, v0 + ADD_BIAS_FLOAT v16, v17, v18, v19, v0 + ADD_BIAS_FLOAT v20, v21, v22, v23, v1 + ADD_BIAS_FLOAT v24, v25, v26, v27, v1 + ADD_BIAS_FLOAT v28, v29, v30, v31, v1 + + Float32ToHalf v8, v20, v9, v21, v0, v1 + Float32ToHalf v10, v22, v11, v23, v2, v3 + Float32ToHalf v12, v24, v13, v25, v4, v5 + Float32ToHalf v14, v26, v15, v27, v6, v7 + Float32ToHalf v16, v28, v17, v29, v8, v9 + Float32ToHalf v18, v30, v19, v31, v10, v11 + b TILE12_POST + + TILE12_ADD_DSTV: + Float32ToHalf v8, v20, v9, v21, v0, v1 + Float32ToHalf v10, v22, v11, v23, v2, v3 + Float32ToHalf v12, v24, v13, v25, v4, v5 + Float32ToHalf v14, v26, v15, v27, v6, v7 + Float32ToHalf v16, v28, v17, v29, v8, v9 + Float32ToHalf v18, v30, v19, v31, v10, v11 + ld1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x0], #64 + ld1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0] + fadd v0.8h, v0.8h, v20.8h + fadd v1.8h, v1.8h, v21.8h + fadd v2.8h, v2.8h, v22.8h + fadd v3.8h, v3.8h, v23.8h + fadd v4.8h, v4.8h, v12.8h + fadd v5.8h, v5.8h, v13.8h + fadd v6.8h, v6.8h, v14.8h + fadd v7.8h, v7.8h, v15.8h + fadd v8.8h, v8.8h, v16.8h + fadd v9.8h, v9.8h, v17.8h + fadd v10.8h, v10.8h, v18.8h + fadd v11.8h, v11.8h, v19.8h + sub x0, x0, #128 + + TILE12_POST: + cbz x23, TILE12_STORE + ld1r {v24.8h}, [x23], #2 // f32 min + ld1r {v25.8h}, [x23] // f32 max + + ReLU_FP16 v0, v1, v2, v3, v24, v25 + ReLU_FP16 v4, v5, v6, v7, v24, v25 + ReLU_FP16 v8, v9, v10, v11, v24, v25 + sub x23, x23, #2 + + TILE12_STORE: + + st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 + st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], x4 + add x4, x4, #128 + L8Tile12LoopCheck: + cmp x5, #1 + bge L8LoopDz_TILE_12 + blt End + +TILE_8: + //ld1r {v26.4s}, [x23] // f32 min + //ld1r {v27.4s}, [x24] // f32 max + cmp x7, #8 + blt TILE_4 + mov x10, x0 + mov x12, x2 + mov x14, x5 + mov x19, x8 // scale + mov x20, x9 // bias + mov x6, x26 // weightQuantBias + +L8LoopDz_TILE_8: + //ld1 {v0.4s, v1.4s}, [x20], #32 // bias + mov x11, x1 + mov x13, x3 + + SET_BIAS v8, v9, v10, v11 + SET_BIAS v12, v13, v14, v15 + SET_BIAS v16, v17, v18, v19 + SET_BIAS v20, v21, v22, v23 + mov x28, x12 + L8LoopSz_TILE_8: + ld1 {v3.16b}, 
[x12], x15 // weight + ld1 {v0.16b, v1.16b}, [x11], x22 // src + .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] + .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] + .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] + ld1 {v4.16b}, [x12], #16 + .inst 0x4f81e06c // sdot v12.4s, v3.16b, v1.4b[0] + .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] + .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] + .inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3] + sub x12, x12, x15 + .inst 0x4f80e090 // sdot v16.4s, v4.16b, v0.4b[0] + .inst 0x4fa0e091 // sdot v17.4s, v4.16b, v0.4b[1] + .inst 0x4f80e892 // sdot v18.4s, v4.16b, v0.4b[2] + .inst 0x4fa0e893 // sdot v19.4s, v4.16b, v0.4b[3] + subs x13, x13, #1 + .inst 0x4f81e094 // sdot v20.4s, v4.16b, v1.4b[0] + .inst 0x4fa1e095 // sdot v21.4s, v4.16b, v1.4b[1] + .inst 0x4f81e896 // sdot v22.4s, v4.16b, v1.4b[2] + .inst 0x4fa1e897 // sdot v23.4s, v4.16b, v1.4b[3] + bne L8LoopSz_TILE_8 + + L8LoopSzEnd_TILE_8: + //add x12, x12, x15 + //add x24, x15, x15 + add x12, x28, x24 + sub x14, x14, #1 + + L8Tile8Quan: + ld1 {v0.4s, v1.4s}, [x19], #32 // scale + ld1 {v2.4s, v3.4s}, [x25] // x kernel sum + ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 + Int32ToFloat v16, v17, v18, v19 + Int32ToFloat v20, v21, v22, v23 + MUL_SCALE v0, v8, v9, v10, v11 + MUL_SCALE v0, v12, v13, v14, v15 + MUL_SCALE v1, v16, v17, v18, v19 + MUL_SCALE v1, v20, v21, v22, v23 + + cbz x27, TILE8_L8_MLA_TERM + ld1 {v4.4s, v5.4s}, [x27] + MUL_EXTRA_SCALE v4, v8, v9, v10, v11 + MUL_EXTRA_SCALE v5, v12, v13, v14, v15 + MUL_EXTRA_SCALE v4, v16, v17, v18, v19 + MUL_EXTRA_SCALE v5, v20, v21, v22, v23 + + TILE8_L8_MLA_TERM: + MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v2, v24, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v3, v24, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v13, v3, v24, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v14, v3, v24, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v15, v3, v24, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v16, v2, v25, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v17, v2, v25, 1 // tile:1, oc:4-7 + MLA_WEIGHTZERO v18, v2, v25, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v19, v2, v25, 3 // tile:3, oc:4-7 + MLA_WEIGHTZERO v20, v3, v25, 0 // tile:4, oc:4-7 + MLA_WEIGHTZERO v21, v3, v25, 1 // tile:5, oc:4-7 + MLA_WEIGHTZERO v22, v3, v25, 2 // tile:6, oc:4-7 + MLA_WEIGHTZERO v23, v3, v25, 3 // tile:7, oc:4-7 + + sub x4, x4, #64 + + cbz x9, TILE8_ADD_DSTV + TILE8_ADD_BIAS: + ld1 {v0.4s, v1.4s}, [x20], #32 + ADD_BIAS_FLOAT v8, v9, v10, v11, v0 + ADD_BIAS_FLOAT v12, v13, v14, v15, v0 + ADD_BIAS_FLOAT v16, v17, v18, v19, v1 + ADD_BIAS_FLOAT v20, v21, v22, v23, v1 + + Float32ToHalf v8, v16, v9, v17, v0, v1 + Float32ToHalf v10, v18, v11, v19, v2, v3 + Float32ToHalf v12, v20, v13, v21, v4, v5 + Float32ToHalf v14, v22, v15, v23, v6, v7 + b TILE8_POST + + TILE8_ADD_DSTV: + Float32ToHalf v8, v16, v9, v17, v0, v1 + Float32ToHalf v10, v18, v11, v19, v2, v3 + Float32ToHalf v12, v20, v13, v21, v4, v5 + Float32ToHalf v14, v22, v15, v23, v6, v7 + ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x10], #64 + ld1 {v28.8h, v29.8h, v30.8h, v31.8h}, [x10] + fadd v0.8h, v0.8h, v24.8h + fadd v1.8h, v1.8h, v25.8h + fadd v2.8h, v2.8h, v26.8h + fadd v3.8h, v3.8h, v27.8h + fadd v4.8h, v4.8h, v28.8h + fadd v5.8h, v5.8h, v29.8h + fadd v6.8h, v6.8h, v30.8h + fadd v7.8h, v7.8h, v31.8h + sub x10, x10, #64 
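+    // If x9 (bias) was null, the block above accumulated the new fp16 results onto
+    // the values already stored in dst. TILE8_POST then applies the optional
+    // min/max clamp from x23 (fp32minmax) before TILE8_STORE writes v0-v7 back.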
+ + TILE8_POST: + cbz x23, TILE8_STORE + ld1r {v24.8h}, [x23], #2 // f16 min + ld1r {v25.8h}, [x23] // f16 max + ReLU_FP16 v0, v1, v2, v3, v24, v25 + ReLU_FP16 v4, v5, v6, v7, v24, v25 + sub x23, x23, #2 + + TILE8_STORE: + st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], x4 + + //st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], #64 + //st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4 + //st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64 + //st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], x4 + add x4, x4, #64 + + L8Tile8LoopCheck: + cmp x14, #1 + bge L8LoopDz_TILE_8 + +cbz x27, Tile8End +add x27, x27, #32 +Tile8End: + sub x7, x7, #8 + add x0, x0, x21, LSL #3 + add x1, x1, #32 + add x25, x25, #32 + +TILE_4: + cmp x7, #4 + blt TILE_1 + mov x10, x0 + mov x12, x2 + mov x14, x5 + mov x19, x8 + mov x20, x9 + mov x6, x26 // weightQuantBias +L8LoopDz_TILE_4: + //ld1 {v0.4s, v1.4s}, [x20], #32 // bias + mov x11, x1 + mov x13, x3 + + SET_BIAS v8, v9, v10, v11 + SET_BIAS v12, v13, v14, v15 + + mov x28, x12 + L8LoopSz_TILE_4: + ld1 {v3.16b}, [x12], x15 // weight + ld1 {v0.16b}, [x11], x22 // src + ld1 {v4.16b}, [x12], #16 // weight + .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] + .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] + .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] + subs x13, x13, #1 + sub x12, x12, x15 + .inst 0x4f80e08c // sdot v12.4s, v4.16b, v0.4b[0] + .inst 0x4fa0e08d // sdot v13.4s, v4.16b, v0.4b[1] + .inst 0x4f80e88e // sdot v14.4s, v4.16b, v0.4b[2] + .inst 0x4fa0e88f // sdot v15.4s, v4.16b, v0.4b[3] + bne L8LoopSz_TILE_4 + + L8LoopSzEnd_TILE_4: + //add x12, x12, x15 + //add x24, x15, x15 + add x12, x28, x24 + sub x14, x14, #1 + + L8Tile4Quan: + ld1 {v0.4s, v1.4s}, [x19], #32 // scale + ld1 {v2.4s}, [x25] // x kernel sum + ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 + MUL_SCALE v0, v8, v9, v10, v11 + MUL_SCALE v1, v12, v13, v14, v15 + + cbz x27, TILE4_L8_MLA_TERM + ld1 {v4.4s}, [x27] + MUL_EXTRA_SCALE v4, v8, v9, v10, v11 + MUL_EXTRA_SCALE v4, v12, v13, v14, v15 + + TILE4_L8_MLA_TERM: + MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v2, v24, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v2, v25, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v13, v2, v25, 1 // tile:1, oc:4-7 + MLA_WEIGHTZERO v14, v2, v25, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v15, v2, v25, 3 // tile:3, oc:4-7 + + cbz x9, TILE4_ADD_DSTV + TILE4_ADD_BIAS: + ld1 {v0.4s, v1.4s}, [x20], #32 + ADD_BIAS_FLOAT v8, v9, v10, v11, v0 + ADD_BIAS_FLOAT v12, v13, v14, v15, v1 + Float32ToHalf v8, v12, v9, v13, v0, v1 + Float32ToHalf v10, v14, v11, v15, v2, v3 + b TILE4_POST + + TILE4_ADD_DSTV: + Float32ToHalf v8, v12, v9, v13, v0, v1 + Float32ToHalf v10, v14, v11, v15, v2, v3 + ld1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x10] + fadd v0.8h, v0.8h, v20.8h + fadd v1.8h, v1.8h, v21.8h + fadd v2.8h, v2.8h, v22.8h + fadd v3.8h, v3.8h, v23.8h + + TILE4_POST: + cbz x23, TILE4_STORE + ld1r {v24.8h}, [x23], #2 // f16 min + ld1r {v25.8h}, [x23] // f16 max + sub x23, x23, #2 + ReLU_FP16 v0, v1, v2, v3, v24, v25 + //st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], x4 + //st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4 + + + TILE4_STORE: + st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], x4 + + L8Tile4LoopCheck: + cmp x14, #1 + bge L8LoopDz_TILE_4 +cbz x27, Tile4End +add x27, x27, #16 
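+// x27 (extra scale) advances by 4 floats for the 4 tiles just handled; Tile4End
+// likewise steps dst (x0), src (x1) and the kernel-sum pointer (x25) past them.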
+Tile4End: + sub x7, x7, #4 + add x0, x0, x21, LSL #2 + add x1, x1, #16 + add x25, x25, #16 + +TILE_1: + cbz x7, End + mov x10, x0 + mov x12, x2 + mov x14, x5 + mov x19, x8 + mov x20, x9 + mov x6, x26 // weightQuantBias + +L8LoopDz_TILE_1: + //ld1 {v0.4s, v1.4s}, [x20], #32 // bias + mov x11, x1 + mov x13, x3 + + movi v8.16b, #0 + movi v9.16b, #0 + //mov v8.16b, v0.16b + //mov v9.16b, v1.16b + mov x28, x12 + L8LoopSz_TILE_1: + ld1 {v3.16b}, [x12], x15 // weight + ld1 {v0.s}[0], [x11], x22 // src + ld1 {v4.16b}, [x12], #16 // weight + .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + subs x13, x13, #1 + sub x12, x12, x15 + .inst 0x4f80e089 // sdot v9.4s, v4.16b, v0.4b[0] + bne L8LoopSz_TILE_1 + + L8LoopSzEnd_TILE_1: + //add x12, x12, x15 + //add x24, x15, x15 + add x12, x28, x24 + sub x14, x14, #1 + + L8Tile1Quan: + ld1 {v0.4s, v1.4s}, [x19], #32 // scale + ld1 {v2.s}[0], [x25] // x kernel sum + ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint + scvtf v8.4s, v8.4s + scvtf v9.4s, v9.4s + fmul v8.4s, v8.4s, v0.4s + fmul v9.4s, v9.4s, v1.4s + + cbz x27, TILE1_L8_MLA_TERM + ld1 {v4.s}[0], [x27] + fmul v8.4s, v8.4s, v4.s[0] + fmul v9.4s, v9.4s, v4.s[0] + + TILE1_L8_MLA_TERM: + MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v25, 0 // tile:0, oc:4-7 + + cbz x9, TILE1_ADD_DSTV + TILE1_ADD_BIAS: + ld1 {v20.4s, v21.4s}, [x20], #32 + fadd v8.4s, v8.4s, v20.4s + fadd v9.4s, v9.4s, v21.4s + fcvtn v0.4h, v8.4s + fcvtn2 v0.8h, v9.4s + b TILE1_POST + + TILE1_ADD_DSTV: + fcvtn v0.4h, v8.4s + fcvtn2 v0.8h, v9.4s + ld1 {v3.8h}, [x10] + fadd v0.8h, v0.8h, v3.8h + + TILE1_POST: + cbz x23, TILE1_STORE + ld1r {v24.8h}, [x23], #2 // f32 min + ld1r {v25.8h}, [x23] // f32 max + sub x23, x23, #2 + fmax v0.8h, v24.8h, v0.8h + fmin v0.8h, v25.8h, v0.8h + // st1 {v8.4s}, [x10], x4 + // st1 {v9.4s}, [x10], x4 + + //fcvtn v0.4h, v8.4s + //fcvtn2 v0.8h, v9.4s + TILE1_STORE: + st1 {v0.8h}, [x10], x4 + + L8Tile1LoopCheck: + cmp x14, #1 + bge L8LoopDz_TILE_1 +cbz x27, Tile1End +add x27, x27, #4 +Tile1End: + sub x7, x7, #1 + add x0, x0, x21 + add x1, x1, #4 + add x25, x25, #4 + b TILE_1 + +End: +ldp x23, x24, [sp, #(16 * 8)] +ldp x25, x26, [sp, #(16 * 7)] +ldp x27, x28, [sp, #(16 * 6)] +ldp x19, x20, [sp, #(16 * 5)] +ldp x21, x22, [sp, #(16 * 4)] +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 10) +ret + +#endif // __aarch64__ diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16.S b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16.S new file mode 100644 index 000000000..5d92ae056 --- /dev/null +++ b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16.S @@ -0,0 +1,690 @@ +// +// MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16.S +// MNN +// +// Created by MNN on 2019/12/17. 
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#if defined(__aarch64__) +#include "MNNAsmGlobal.h" + +.text +.align 5 + +.macro ADD_BIAS_FLOAT d0, d1, d2, d3, z0 + fadd \d0\().4s, \d0\().4s, \z0\().4s + fadd \d1\().4s, \d1\().4s, \z0\().4s + fadd \d2\().4s, \d2\().4s, \z0\().4s + fadd \d3\().4s, \d3\().4s, \z0\().4s +.endm + +.macro SET_BIAS d0, d1, d2, d3 + movi \d0\().16b, #0 + movi \d1\().16b, #0 + movi \d2\().16b, #0 + movi \d3\().16b, #0 +.endm +.macro Int32ToFloat z0, z1, z2, z3 + scvtf \z0\().4s, \z0\().4s + scvtf \z1\().4s, \z1\().4s + scvtf \z2\().4s, \z2\().4s + scvtf \z3\().4s, \z3\().4s +.endm +.macro MUL_SCALE s, d0, d1, d2, d3 + fmul \d0\().4s, \d0\().4s, \s\().4s + fmul \d1\().4s, \d1\().4s, \s\().4s + fmul \d2\().4s, \d2\().4s, \s\().4s + fmul \d3\().4s, \d3\().4s, \s\().4s +.endm +.macro MLA_WEIGHTZERO d0, s0, s1, idx // idx for xKernelSum + fmla \d0\().4s, \s1\().4s, \s0\().s[\idx] +.endm +.macro MUL_EXTRA_SCALE s, d0, d1, d2, d3 + fmul \d0\().4s, \d0\().4s, \s\().s[0] + fmul \d1\().4s, \d1\().4s, \s\().s[1] + fmul \d2\().4s, \d2\().4s, \s\().s[2] + fmul \d3\().4s, \d3\().4s, \s\().s[3] +.endm +.macro ReLU_FP16 s0, s1, s2, s3, z0, z1 // z0:min z1:max + fmin \s0\().8h, \s0\().8h, \z1\().8h + fmin \s1\().8h, \s1\().8h, \z1\().8h + fmin \s2\().8h, \s2\().8h, \z1\().8h + fmin \s3\().8h, \s3\().8h, \z1\().8h + fmax \s0\().8h, \s0\().8h, \z0\().8h + fmax \s1\().8h, \s1\().8h, \z0\().8h + fmax \s2\().8h, \s2\().8h, \z0\().8h + fmax \s3\().8h, \s3\().8h, \z0\().8h +.endm + +.macro Float32ToHalf s0, s1, s2, s3, d0, d1 + fcvtn \d0\().4h, \s0\().4s + fcvtn2 \d0\().8h, \s1\().4s + fcvtn \d1\().4h, \s2\().4s + fcvtn2 \d1\().8h, \s3\().4s +.endm + +asm_function MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16 +/* +struct QuanPostTreatParameters { + const float* scale; + const float* biasFloat; + int32_t maxValue; + int32_t minValue; + int32_t useInt8 = 1; // Save result as int8_t dataType; otherwise float32. + float roundValuePos = 0.5f; + float roundValueNeg = -0.5f; + float* srcKernelSum; + float* weightQuanBias; + float* fp32minmax; + ssize_t blockNum = 1; + const int32_t* bias; + const float* extraScale = nullptr; +}; +*/ + +//void MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16(int8_t* dst, const int8_t* src, +// const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, +// const QuanPostTreatParameters* parameters, size_t realDstCount); + +//Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step +//x5:dst_depth_quad, x6: parameters, x7: realDstCount + +//Load from x6: x8: scale, x9: bias, x25: xKernelSum, x26: weightQuantBias, x23: fp32minmax, x27: blockNum +ldr x8, [x6, #0] +ldr x9, [x6, #8] +//ldr w12, [x6, #16] + +stp d14, d15, [sp, #(-16 * 10)]! 
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] +stp x21, x22, [sp, #(16 * 4)] +stp x19, x20, [sp, #(16 * 5)] +stp x27, x28, [sp, #(16 * 6)] +stp x25, x26, [sp, #(16 * 7)] +stp x23, x24, [sp, #(16 * 8)] + +//ldr w27, [x6, #20] +ldr x25, [x6, #40] // xKernelSum +ldr x26, [x6, #48] // weightQuantBias +ldr x23, [x6, #56] // fp32minmax +ldr x27, [x6, #64] // blockNum + +mov x21, #16 // sizeof(float16_t) * PACK +mul x27, x27, x3 +Start: +lsl x15, x27, #3 // x15 = src_depth_quad * UNIT * SRC_UNIT * sizeof(int4_t) +mov x22, #48 // src_steps +add x24, x15, x15 +ldr x27, [x6, #80] // extra scale +TILE_12: + cmp x7, #12 + blt TILE_8 + +L8LoopDz_TILE_12: + // ld1 {v0.4s, v1.4s}, [x9], #32 // bias + mov x11, x1 + mov x13, x3 + movi v7.16b, #15 + + // Init 0 + SET_BIAS v8, v9, v10, v11 + SET_BIAS v12, v13, v14, v15 + SET_BIAS v16, v17, v18, v19 + SET_BIAS v20, v21, v22, v23 + SET_BIAS v24, v25, v26, v27 + SET_BIAS v28, v29, v30, v31 + + mov x28, x2 + L8LoopSz_TILE_12: + ld1 {v3.d}[0], [x2], x15 // weight + ld1 {v4.d}[0], [x2], #8 + ld1 {v0.16b, v1.16b, v2.16b}, [x11], #48 // src + // int4->int8 + ushr v5.16b, v3.16b, #4 + and v6.16b, v3.16b, v7.16b + zip1 v3.16b, v5.16b, v6.16b + + .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] + .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] + .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] + + .inst 0x4f81e06c // sdot v12.4s, v3.16b, v1.4b[0] + .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] + .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] + .inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3] + // int4->int8 + ushr v5.16b, v4.16b, #4 + and v6.16b, v4.16b, v7.16b + zip1 v4.16b, v5.16b, v6.16b + + .inst 0x4f82e070 // sdot v16.4s, v3.16b, v2.4b[0] + .inst 0x4fa2e071 // sdot v17.4s, v3.16b, v2.4b[1] + .inst 0x4f82e872 // sdot v18.4s, v3.16b, v2.4b[2] + .inst 0x4fa2e873 // sdot v19.4s, v3.16b, v2.4b[3] + .inst 0x4f80e094 // sdot v20.4s, v4.16b, v0.4b[0] + .inst 0x4fa0e095 // sdot v21.4s, v4.16b, v0.4b[1] + .inst 0x4f80e896 // sdot v22.4s, v4.16b, v0.4b[2] + .inst 0x4fa0e897 // sdot v23.4s, v4.16b, v0.4b[3] + sub x2, x2, x15 + .inst 0x4f81e098 // sdot v24.4s, v4.16b, v1.4b[0] + .inst 0x4fa1e099 // sdot v25.4s, v4.16b, v1.4b[1] + .inst 0x4f81e89a // sdot v26.4s, v4.16b, v1.4b[2] + .inst 0x4fa1e89b // sdot v27.4s, v4.16b, v1.4b[3] + subs x13, x13, #1 + .inst 0x4f82e09c // sdot v28.4s, v4.16b, v2.4b[0] + .inst 0x4fa2e09d // sdot v29.4s, v4.16b, v2.4b[1] + .inst 0x4f82e89e // sdot v30.4s, v4.16b, v2.4b[2] + .inst 0x4fa2e89f // sdot v31.4s, v4.16b, v2.4b[3] + bne L8LoopSz_TILE_12 + + L8LoopSzEnd_TILE_12: + add x2, x28, x24 + sub x5, x5, #1 + + L8Tile12Quan: + ld1 {v0.4s, v1.4s}, [x8], #32 // scale + ld1 {v2.4s, v3.4s, v4.4s}, [x25] // x kernel sum + ld1 {v5.4s, v6.4s}, [x26], #32 // weight quan zeropoint + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 + Int32ToFloat v16, v17, v18, v19 + Int32ToFloat v20, v21, v22, v23 + Int32ToFloat v24, v25, v26, v27 + Int32ToFloat v28, v29, v30, v31 + + MUL_SCALE v0, v8, v9, v10, v11 + MUL_SCALE v0, v12, v13, v14, v15 + MUL_SCALE v0, v16, v17, v18, v19 + MUL_SCALE v1, v20, v21, v22, v23 + MUL_SCALE v1, v24, v25, v26, v27 + MUL_SCALE v1, v28, v29, v30, v31 + + cbz x27, TILE12_L8_MLA_TERM + ld1 {v0.4s, v1.4s}, [x27], #32 + ld1 {v7.4s}, [x27] + MUL_EXTRA_SCALE v0, v8, v9, v10, v11 + MUL_EXTRA_SCALE v1, v12, v13, v14, v15 + MUL_EXTRA_SCALE v7, v16, v17, v18, v19 + MUL_EXTRA_SCALE v0, v20, v21, v22, v23 + MUL_EXTRA_SCALE v1, 
v24, v25, v26, v27 + MUL_EXTRA_SCALE v7, v28, v29, v30, v31 + sub x27, x27, #32 + + TILE12_L8_MLA_TERM: + MLA_WEIGHTZERO v8, v2, v5, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v5, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v2, v5, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v2, v5, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v3, v5, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v13, v3, v5, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v14, v3, v5, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v15, v3, v5, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v16, v4, v5, 0 // tile:8, oc:0-3 + MLA_WEIGHTZERO v17, v4, v5, 1 // tile:9, oc:0-3 + MLA_WEIGHTZERO v18, v4, v5, 2 // tile:10, oc:0-3 + MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3 + + //ld1r {v0.4s}, [x23] // f32 min + //ld1r {v1.4s}, [x24] // f32 max + MLA_WEIGHTZERO v20, v2, v6, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v21, v2, v6, 1 // tile:1, oc:4-7 + MLA_WEIGHTZERO v22, v2, v6, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v23, v2, v6, 3 // tile:3, oc:4-7 + MLA_WEIGHTZERO v24, v3, v6, 0 // tile:4, oc:4-7 + MLA_WEIGHTZERO v25, v3, v6, 1 // tile:5, oc:4-7 + MLA_WEIGHTZERO v26, v3, v6, 2 // tile:6, oc:4-7 + MLA_WEIGHTZERO v27, v3, v6, 3 // tile:7, oc:4-7 + MLA_WEIGHTZERO v28, v4, v6, 0 // tile:8, oc:4-7 + MLA_WEIGHTZERO v29, v4, v6, 1 // tile:9, oc:4-7 + MLA_WEIGHTZERO v30, v4, v6, 2 // tile:10, oc:4-7 + MLA_WEIGHTZERO v31, v4, v6, 3 // tile:11, oc:4-7 + sub x4, x4, #128 + + cbz x9, TILE12_ADD_DSTV + TILE12_ADD_BIAS: + ld1 {v0.4s, v1.4s}, [x9], #32 + ADD_BIAS_FLOAT v8, v9, v10, v11, v0 + ADD_BIAS_FLOAT v12, v13, v14, v15, v0 + ADD_BIAS_FLOAT v16, v17, v18, v19, v0 + ADD_BIAS_FLOAT v20, v21, v22, v23, v1 + ADD_BIAS_FLOAT v24, v25, v26, v27, v1 + ADD_BIAS_FLOAT v28, v29, v30, v31, v1 + + Float32ToHalf v8, v20, v9, v21, v0, v1 + Float32ToHalf v10, v22, v11, v23, v2, v3 + Float32ToHalf v12, v24, v13, v25, v4, v5 + Float32ToHalf v14, v26, v15, v27, v6, v7 + Float32ToHalf v16, v28, v17, v29, v8, v9 + Float32ToHalf v18, v30, v19, v31, v10, v11 + b TILE12_POST + + TILE12_ADD_DSTV: + Float32ToHalf v8, v20, v9, v21, v0, v1 + Float32ToHalf v10, v22, v11, v23, v2, v3 + Float32ToHalf v12, v24, v13, v25, v4, v5 + Float32ToHalf v14, v26, v15, v27, v6, v7 + Float32ToHalf v16, v28, v17, v29, v8, v9 + Float32ToHalf v18, v30, v19, v31, v10, v11 + ld1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], #64 + ld1 {v12.8h, v13.8h, v14.8h, v15.8h}, [x0], #64 + ld1 {v16.8h, v17.8h, v18.8h, v19.8h}, [x0] + fadd v0.8h, v0.8h, v20.8h + fadd v1.8h, v1.8h, v21.8h + fadd v2.8h, v2.8h, v22.8h + fadd v3.8h, v3.8h, v23.8h + fadd v4.8h, v4.8h, v12.8h + fadd v5.8h, v5.8h, v13.8h + fadd v6.8h, v6.8h, v14.8h + fadd v7.8h, v7.8h, v15.8h + fadd v8.8h, v8.8h, v16.8h + fadd v9.8h, v9.8h, v17.8h + fadd v10.8h, v10.8h, v18.8h + fadd v11.8h, v11.8h, v19.8h + sub x0, x0, #128 + + TILE12_POST: + cbz x23, TILE12_STORE + ld1r {v24.8h}, [x23], #2 // f32 min + ld1r {v25.8h}, [x23] // f32 max + + ReLU_FP16 v0, v1, v2, v3, v24, v25 + ReLU_FP16 v4, v5, v6, v7, v24, v25 + ReLU_FP16 v8, v9, v10, v11, v24, v25 + sub x23, x23, #2 + + TILE12_STORE: + + st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x0], #64 + st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x0], #64 + st1 {v8.8h, v9.8h, v10.8h, v11.8h}, [x0], x4 + add x4, x4, #128 + L8Tile12LoopCheck: + cmp x5, #1 + bge L8LoopDz_TILE_12 + blt End + +TILE_8: + cmp x7, #8 + blt TILE_4 + mov x10, x0 + mov x12, x2 + mov x14, x5 + mov x19, x8 // scale + mov x20, x9 // bias + mov x6, x26 // weightQuantBias + +L8LoopDz_TILE_8: + mov x11, x1 + mov x13, x3 + movi v7.16b, #15 + + SET_BIAS v8, v9, v10, v11 + SET_BIAS v12, v13, v14, v15 + 
SET_BIAS v16, v17, v18, v19 + SET_BIAS v20, v21, v22, v23 + mov x28, x12 + L8LoopSz_TILE_8: + ld1 {v3.d}[0], [x12], x15 // weight + ld1 {v4.d}[0], [x12], #8 + ld1 {v0.16b, v1.16b}, [x11], x22 // src + // int4->int8 + ushr v5.16b, v3.16b, #4 + and v6.16b, v3.16b, v7.16b + zip1 v3.16b, v5.16b, v6.16b + + .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] + .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] + .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] + // int4->int8 + ushr v5.16b, v4.16b, #4 + and v6.16b, v4.16b, v7.16b + zip1 v4.16b, v5.16b, v6.16b + .inst 0x4f81e06c // sdot v12.4s, v3.16b, v1.4b[0] + .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] + .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] + .inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3] + sub x12, x12, x15 + .inst 0x4f80e090 // sdot v16.4s, v4.16b, v0.4b[0] + .inst 0x4fa0e091 // sdot v17.4s, v4.16b, v0.4b[1] + .inst 0x4f80e892 // sdot v18.4s, v4.16b, v0.4b[2] + .inst 0x4fa0e893 // sdot v19.4s, v4.16b, v0.4b[3] + subs x13, x13, #1 + .inst 0x4f81e094 // sdot v20.4s, v4.16b, v1.4b[0] + .inst 0x4fa1e095 // sdot v21.4s, v4.16b, v1.4b[1] + .inst 0x4f81e896 // sdot v22.4s, v4.16b, v1.4b[2] + .inst 0x4fa1e897 // sdot v23.4s, v4.16b, v1.4b[3] + bne L8LoopSz_TILE_8 + + L8LoopSzEnd_TILE_8: + add x12, x28, x24 + sub x14, x14, #1 + + L8Tile8Quan: + ld1 {v0.4s, v1.4s}, [x19], #32 // scale + ld1 {v2.4s, v3.4s}, [x25] // x kernel sum + ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 + Int32ToFloat v16, v17, v18, v19 + Int32ToFloat v20, v21, v22, v23 + MUL_SCALE v0, v8, v9, v10, v11 + MUL_SCALE v0, v12, v13, v14, v15 + MUL_SCALE v1, v16, v17, v18, v19 + MUL_SCALE v1, v20, v21, v22, v23 + + cbz x27, TILE8_L8_MLA_TERM + ld1 {v4.4s, v5.4s}, [x27] + MUL_EXTRA_SCALE v4, v8, v9, v10, v11 + MUL_EXTRA_SCALE v5, v12, v13, v14, v15 + MUL_EXTRA_SCALE v4, v16, v17, v18, v19 + MUL_EXTRA_SCALE v5, v20, v21, v22, v23 + + TILE8_L8_MLA_TERM: + MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v2, v24, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v3, v24, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v13, v3, v24, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v14, v3, v24, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v15, v3, v24, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v16, v2, v25, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v17, v2, v25, 1 // tile:1, oc:4-7 + MLA_WEIGHTZERO v18, v2, v25, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v19, v2, v25, 3 // tile:3, oc:4-7 + MLA_WEIGHTZERO v20, v3, v25, 0 // tile:4, oc:4-7 + MLA_WEIGHTZERO v21, v3, v25, 1 // tile:5, oc:4-7 + MLA_WEIGHTZERO v22, v3, v25, 2 // tile:6, oc:4-7 + MLA_WEIGHTZERO v23, v3, v25, 3 // tile:7, oc:4-7 + + sub x4, x4, #64 + + cbz x9, TILE8_ADD_DSTV + TILE8_ADD_BIAS: + ld1 {v0.4s, v1.4s}, [x20], #32 + ADD_BIAS_FLOAT v8, v9, v10, v11, v0 + ADD_BIAS_FLOAT v12, v13, v14, v15, v0 + ADD_BIAS_FLOAT v16, v17, v18, v19, v1 + ADD_BIAS_FLOAT v20, v21, v22, v23, v1 + + Float32ToHalf v8, v16, v9, v17, v0, v1 + Float32ToHalf v10, v18, v11, v19, v2, v3 + Float32ToHalf v12, v20, v13, v21, v4, v5 + Float32ToHalf v14, v22, v15, v23, v6, v7 + b TILE8_POST + + TILE8_ADD_DSTV: + Float32ToHalf v8, v16, v9, v17, v0, v1 + Float32ToHalf v10, v18, v11, v19, v2, v3 + Float32ToHalf v12, v20, v13, v21, v4, v5 + Float32ToHalf v14, v22, v15, v23, v6, v7 + ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x10], #64 + ld1 {v28.8h, 
v29.8h, v30.8h, v31.8h}, [x10] + fadd v0.8h, v0.8h, v24.8h + fadd v1.8h, v1.8h, v25.8h + fadd v2.8h, v2.8h, v26.8h + fadd v3.8h, v3.8h, v27.8h + fadd v4.8h, v4.8h, v28.8h + fadd v5.8h, v5.8h, v29.8h + fadd v6.8h, v6.8h, v30.8h + fadd v7.8h, v7.8h, v31.8h + sub x10, x10, #64 + + TILE8_POST: + cbz x23, TILE8_STORE + ld1r {v24.8h}, [x23], #2 // f16 min + ld1r {v25.8h}, [x23] // f16 max + ReLU_FP16 v0, v1, v2, v3, v24, v25 + ReLU_FP16 v4, v5, v6, v7, v24, v25 + sub x23, x23, #2 + + TILE8_STORE: + st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + st1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], x4 + + //st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], #64 + //st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4 + //st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64 + //st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], x4 + add x4, x4, #64 + + L8Tile8LoopCheck: + cmp x14, #1 + bge L8LoopDz_TILE_8 + +cbz x27, Tile8End +add x27, x27, #32 +Tile8End: + sub x7, x7, #8 + add x0, x0, x21, LSL #3 + add x1, x1, #32 + add x25, x25, #32 + +TILE_4: + movi v7.16b, #15 + cmp x7, #4 + blt TILE_1 + mov x10, x0 + mov x12, x2 + mov x14, x5 + mov x19, x8 + mov x20, x9 + mov x6, x26 // weightQuantBias +L8LoopDz_TILE_4: + mov x11, x1 + mov x13, x3 + + SET_BIAS v8, v9, v10, v11 + SET_BIAS v12, v13, v14, v15 + + mov x28, x12 + L8LoopSz_TILE_4: + ld1 {v3.d}[0], [x12], x15 // weight + ld1 {v0.16b}, [x11], x22 // src + ld1 {v4.d}[0], [x12], #8 // weight + // int4->int8 + ushr v5.16b, v3.16b, #4 + and v6.16b, v3.16b, v7.16b + zip1 v3.16b, v5.16b, v6.16b + + .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] + .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] + .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] + // int4->int8 + ushr v5.16b, v4.16b, #4 + and v6.16b, v4.16b, v7.16b + zip1 v4.16b, v5.16b, v6.16b + subs x13, x13, #1 + sub x12, x12, x15 + .inst 0x4f80e08c // sdot v12.4s, v4.16b, v0.4b[0] + .inst 0x4fa0e08d // sdot v13.4s, v4.16b, v0.4b[1] + .inst 0x4f80e88e // sdot v14.4s, v4.16b, v0.4b[2] + .inst 0x4fa0e88f // sdot v15.4s, v4.16b, v0.4b[3] + bne L8LoopSz_TILE_4 + + L8LoopSzEnd_TILE_4: + add x12, x28, x24 + sub x14, x14, #1 + + L8Tile4Quan: + ld1 {v0.4s, v1.4s}, [x19], #32 // scale + ld1 {v2.4s}, [x25] // x kernel sum + ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 + MUL_SCALE v0, v8, v9, v10, v11 + MUL_SCALE v1, v12, v13, v14, v15 + + cbz x27, TILE4_L8_MLA_TERM + ld1 {v4.4s}, [x27] + MUL_EXTRA_SCALE v4, v8, v9, v10, v11 + MUL_EXTRA_SCALE v4, v12, v13, v14, v15 + + TILE4_L8_MLA_TERM: + MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v2, v24, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v2, v25, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v13, v2, v25, 1 // tile:1, oc:4-7 + MLA_WEIGHTZERO v14, v2, v25, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v15, v2, v25, 3 // tile:3, oc:4-7 + + cbz x9, TILE4_ADD_DSTV + TILE4_ADD_BIAS: + ld1 {v0.4s, v1.4s}, [x20], #32 + ADD_BIAS_FLOAT v8, v9, v10, v11, v0 + ADD_BIAS_FLOAT v12, v13, v14, v15, v1 + Float32ToHalf v8, v12, v9, v13, v0, v1 + Float32ToHalf v10, v14, v11, v15, v2, v3 + b TILE4_POST + + TILE4_ADD_DSTV: + Float32ToHalf v8, v12, v9, v13, v0, v1 + Float32ToHalf v10, v14, v11, v15, v2, v3 + ld1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x10] + fadd v0.8h, v0.8h, v20.8h + fadd v1.8h, v1.8h, v21.8h + fadd v2.8h, v2.8h, v22.8h + fadd v3.8h, v3.8h, v23.8h + + TILE4_POST: + cbz x23, 
TILE4_STORE + ld1r {v24.8h}, [x23], #2 // f16 min + ld1r {v25.8h}, [x23] // f16 max + sub x23, x23, #2 + ReLU_FP16 v0, v1, v2, v3, v24, v25 + //st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], x4 + //st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4 + + + TILE4_STORE: + st1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], x4 + + L8Tile4LoopCheck: + cmp x14, #1 + bge L8LoopDz_TILE_4 +cbz x27, Tile4End +add x27, x27, #16 +Tile4End: + sub x7, x7, #4 + add x0, x0, x21, LSL #2 + add x1, x1, #16 + add x25, x25, #16 + +TILE_1: + // Already execute: [movi v7.16b, #15] in TILE_4 + cbz x7, End + mov x10, x0 + mov x12, x2 + mov x14, x5 + mov x19, x8 + mov x20, x9 + mov x6, x26 // weightQuantBias + +L8LoopDz_TILE_1: + mov x11, x1 + mov x13, x3 + + movi v8.16b, #0 + movi v9.16b, #0 + mov x28, x12 + L8LoopSz_TILE_1: + ld1 {v3.d}[0], [x12], x15 // weight + ld1 {v0.s}[0], [x11], x22 // src + ld1 {v4.d}[0], [x12], #8 // weight + // int4->int8 + ushr v5.16b, v3.16b, #4 + and v6.16b, v3.16b, v7.16b + zip1 v3.16b, v5.16b, v6.16b + + .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + subs x13, x13, #1 + // int4->int8 + ushr v5.16b, v4.16b, #4 + and v6.16b, v4.16b, v7.16b + zip1 v4.16b, v5.16b, v6.16b + sub x12, x12, x15 + + .inst 0x4f80e089 // sdot v9.4s, v4.16b, v0.4b[0] + bne L8LoopSz_TILE_1 + + L8LoopSzEnd_TILE_1: + add x12, x28, x24 + sub x14, x14, #1 + + L8Tile1Quan: + ld1 {v0.4s, v1.4s}, [x19], #32 // scale + ld1 {v2.s}[0], [x25] // x kernel sum + ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint + scvtf v8.4s, v8.4s + scvtf v9.4s, v9.4s + fmul v8.4s, v8.4s, v0.4s + fmul v9.4s, v9.4s, v1.4s + + cbz x27, TILE1_L8_MLA_TERM + ld1 {v4.s}[0], [x27] + fmul v8.4s, v8.4s, v4.s[0] + fmul v9.4s, v9.4s, v4.s[0] + + TILE1_L8_MLA_TERM: + MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v25, 0 // tile:0, oc:4-7 + + cbz x9, TILE1_ADD_DSTV + TILE1_ADD_BIAS: + ld1 {v20.4s, v21.4s}, [x20], #32 + fadd v8.4s, v8.4s, v20.4s + fadd v9.4s, v9.4s, v21.4s + fcvtn v0.4h, v8.4s + fcvtn2 v0.8h, v9.4s + b TILE1_POST + + TILE1_ADD_DSTV: + fcvtn v0.4h, v8.4s + fcvtn2 v0.8h, v9.4s + ld1 {v3.8h}, [x10] + fadd v0.8h, v0.8h, v3.8h + + TILE1_POST: + cbz x23, TILE1_STORE + ld1r {v24.8h}, [x23], #2 // f16 min + ld1r {v25.8h}, [x23] // f16 max + sub x23, x23, #2 + fmax v0.8h, v24.8h, v0.8h + fmin v0.8h, v25.8h, v0.8h + // st1 {v8.4s}, [x10], x4 + // st1 {v9.4s}, [x10], x4 + TILE1_STORE: + st1 {v0.8h}, [x10], x4 + + L8Tile1LoopCheck: + cmp x14, #1 + bge L8LoopDz_TILE_1 +cbz x27, Tile1End +add x27, x27, #4 +Tile1End: + sub x7, x7, #1 + add x0, x0, x21 + add x1, x1, #4 + add x25, x25, #4 + b TILE_1 + +End: +ldp x23, x24, [sp, #(16 * 8)] +ldp x25, x26, [sp, #(16 * 7)] +ldp x27, x28, [sp, #(16 * 6)] +ldp x19, x20, [sp, #(16 * 5)] +ldp x21, x22, [sp, #(16 * 4)] +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 10) +ret + +#endif // __aarch64__ diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV86_Unit_FP16.S b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV86_Unit_FP16.S new file mode 100644 index 000000000..76c79b42e --- /dev/null +++ b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV86_Unit_FP16.S @@ -0,0 +1,855 @@ +// +// MNNGemmInt8AddBiasScale_ARMV86_Unit_FP16.S +// MNN +// +// Created by MNN on 2022/09/26. 
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#if defined(__aarch64__) +#include "MNNAsmGlobal.h" + +.text +.align 5 + +.macro SET_0_5 d0, d1, d2, d3, d4 + movi \d0\().16b, #0 + movi \d1\().16b, #0 + movi \d2\().16b, #0 + movi \d3\().16b, #0 + movi \d4\().16b, #0 +.endm + +.macro SET_0_4 d0, d1, d2, d3 + movi \d0\().16b, #0 + movi \d1\().16b, #0 + movi \d2\().16b, #0 + movi \d3\().16b, #0 +.endm + +.macro SET_0_2 d0, d1 + movi \d0\().16b, #0 + movi \d1\().16b, #0 +.endm + +.macro ADD_BIAS_FLOAT d0, d1, d2, d3, z0 + fadd \d0\().4s, \d0\().4s, \z0\().4s + fadd \d1\().4s, \d1\().4s, \z0\().4s + fadd \d2\().4s, \d2\().4s, \z0\().4s + fadd \d3\().4s, \d3\().4s, \z0\().4s +.endm + +.macro ReLU_FP16 s0, s1, s2, s3, z0, z1 // z0:min z1:max + fmin \s0\().8h, \s0\().8h, \z1\().8h + fmin \s1\().8h, \s1\().8h, \z1\().8h + fmin \s2\().8h, \s2\().8h, \z1\().8h + fmin \s3\().8h, \s3\().8h, \z1\().8h + fmax \s0\().8h, \s0\().8h, \z0\().8h + fmax \s1\().8h, \s1\().8h, \z0\().8h + fmax \s2\().8h, \s2\().8h, \z0\().8h + fmax \s3\().8h, \s3\().8h, \z0\().8h +.endm + +.macro ReLU_FP16_2 s0, s1, z0, z1 // z0:min z1:max + fmin \s0\().8h, \s0\().8h, \z1\().8h + fmin \s1\().8h, \s1\().8h, \z1\().8h + fmax \s0\().8h, \s0\().8h, \z0\().8h + fmax \s1\().8h, \s1\().8h, \z0\().8h +.endm + +.macro SET_BIAS s, d0, d1, d2, d3, d4, idx + dup \d0\().2d, \s\().d[\idx] + dup \d1\().2d, \s\().d[\idx] + dup \d2\().2d, \s\().d[\idx] + dup \d3\().2d, \s\().d[\idx] + dup \d4\().2d, \s\().d[\idx] +.endm +.macro SET_BIAS_4 s, d0, d1, d2, d3, idx + dup \d0\().2d, \s\().d[\idx] + dup \d1\().2d, \s\().d[\idx] + dup \d2\().2d, \s\().d[\idx] + dup \d3\().2d, \s\().d[\idx] +.endm +.macro SET_BIAS_2 s, d0, d1, idx + dup \d0\().2d, \s\().d[\idx] + dup \d1\().2d, \s\().d[\idx] +.endm +.macro Int32ToFloat z0, z1, z2, z3 + scvtf \z0\().4s, \z0\().4s + scvtf \z1\().4s, \z1\().4s + scvtf \z2\().4s, \z2\().4s + scvtf \z3\().4s, \z3\().4s +.endm +.macro MUL_SCALE s, d0, d1, d2, d3 + fmul \d0\().4s, \d0\().4s, \s\().4s + fmul \d1\().4s, \d1\().4s, \s\().4s + fmul \d2\().4s, \d2\().4s, \s\().4s + fmul \d3\().4s, \d3\().4s, \s\().4s +.endm +.macro MUL_EXTRA_SCALE s, d0, d1, d2, d3 + fmul \d0\().4s, \d0\().4s, \s\().s[0] + fmul \d1\().4s, \d1\().4s, \s\().s[1] + fmul \d2\().4s, \d2\().4s, \s\().s[2] + fmul \d3\().4s, \d3\().4s, \s\().s[3] +.endm +.macro MLA_WEIGHTZERO d0, s0, s1, idx // idx for xKernelSum + fmla \d0\().4s, \s1\().4s, \s0\().s[\idx] +.endm +.macro Float32ToHalf s0, s1, s2, s3, d0, d1 + fcvtn \d0\().4h, \s0\().4s + fcvtn2 \d0\().8h, \s1\().4s + fcvtn \d1\().4h, \s2\().4s + fcvtn2 \d1\().8h, \s3\().4s +.endm + +asm_function MNNGemmInt8AddBiasScale_ARMV86_Unit_FP16 + +/* +struct QuanPostTreatParameters { + const float* scale; + const float* biasFloat; + int32_t maxValue; + int32_t minValue; + int32_t useInt8 = 1; // Save result as int8_t dataType; otherwise float32. 
+ float roundValuePos = 0.5f; + float roundValueNeg = -0.5f; + float* srcKernelSum; + float* weightQuanBias; + float* fp32minmax; + ssize_t blockNum = 1; + const int32_t* bias; + const float* extraScale = nullptr; +}; +*/ + +//void MNNGemmInt8AddBiasScale_ARMV86_Unit(int8_t* dst, const int8_t* src, +// const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, +// const QuanPostTreatParameters* parameters, size_t realDstCount); + +//Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step +//x5:dst_depth_quad, x6: parameters, x7: realDstCount + +//Load from x7: x8: scale, x9: biasFloat, x27: srcKernelSum, x28: weightQuanBias, x14: fp32minmax +/* For FP16 +UNIT = 8; +SRC_UNIT = 8; +DST_XUNIT = 10; + */ +ldr x8, [x6, #0] +ldr x9, [x6, #8] + +stp d14, d15, [sp, #(-16 * 10)]! +stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] +stp x21, x22, [sp, #(16 * 4)] +stp x19, x20, [sp, #(16 * 5)] +stp x23, x24, [sp, #(16 * 6)] +stp x25, x26, [sp, #(16 * 7)] +stp x27, x28, [sp, #(16 * 8)] +// ldr w23, [x6, #24] +ldr x27, [x6, #40] // srcKernelSum +ldr x28, [x6, #48] // weightQuanBias +ldr x23, [x6, #64] // blockNum +ldr x14, [x6, #56] // fp32minmax + +mul x23, x23, x3 // UP_DIV(ic*ky*kx, SRC_UNIT) = blockNum * src_depth_quad_per_block +mov x22, #80 // GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT = 10 * 8 = 80 +mov x21, #16 // sizeof(float16_t) * UNIT + +Start: +lsl x15, x23, #6 // x15 = src_depth_quad * UNIT * UNIT_SRC * sizeof(int8_t) = src_depth_quad * 64 = src_depth_quad << 6 +ldr x23, [x6, #80] // extra scale +TILE_10: + cmp x7, #10 + blt TILE_8 +sub x4, x4, #128 +LoopDz_TILE_10: + //ld1 {v0.4s, v1.4s}, [x9], #32 // bias + mov x11, x1 // src + mov x12, x2 // weight + mov x13, x3 // src_depth_quad + mov x10, x0 // tag dst address + + SET_0_5 v12, v16, v20, v24, v28 // oc:0,1,0,1 + SET_0_5 v13, v17, v21, v25, v29 // oc:2,3,2,3 + SET_0_5 v14, v18, v22, v26, v30 // oc:4,5,4,5 + SET_0_5 v15, v19, v23, v27, v31 // oc:6,7,6,7 + +LoopSz_TILE_10: + ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x12], #64 // weight + ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9 + ld1 {v7.16b}, [x11], #16 + subs x13, x13, #1 + .inst 0x4e88a46c // smmla v12.4s, v3.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a46d // smmla v13.4s, v3.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + .inst 0x4e8aa46e // smmla v14.4s, v3.16b, v10.16b // tile0-oc4, tile0-oc5, tile1-oc4, tile1-oc5 + .inst 0x4e8ba46f // smmla v15.4s, v3.16b, v11.16b // tile0-oc6, tile0-oc7, tile1-oc6, tile1-oc7 + + .inst 0x4e88a490 // smmla v16.4s, v4.16b, v8.16b // tile2-oc0, tile2-oc1, tile3-oc0, tile3-oc1 + .inst 0x4e89a491 // smmla v17.4s, v4.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 + .inst 0x4e8aa492 // smmla v18.4s, v4.16b, v10.16b // tile2-oc4, tile2-oc5, tile3-oc4, tile3-oc5 + .inst 0x4e8ba493 // smmla v19.4s, v4.16b, v11.16b // tile2-oc6, tile2-oc7, tile3-oc6, tile3-oc7 + + .inst 0x4e88a4b4 // smmla v20.4s, v5.16b, v8.16b // tile4-oc0, tile4-oc1, tile5-oc0, tile5-oc1 + .inst 0x4e89a4b5 // smmla v21.4s, v5.16b, v9.16b // tile4-oc2, tile4-oc3, tile5-oc2, tile5-oc3 + .inst 0x4e8aa4b6 // smmla v22.4s, v5.16b, v10.16b // tile4-oc4, tile4-oc5, tile5-oc4, tile5-oc5 + .inst 0x4e8ba4b7 // smmla v23.4s, v5.16b, v11.16b // tile4-oc6, tile4-oc7, tile5-oc6, tile5-oc7 + + .inst 0x4e88a4d8 // smmla v24.4s, v6.16b, v8.16b // tile6-oc0, tile6-oc1, tile7-oc0, tile7-oc1 + .inst 0x4e89a4d9 // smmla v25.4s, v6.16b, v9.16b // tile6-oc2, tile6-oc3, 
tile7-oc2, tile7-oc3 + .inst 0x4e8aa4da // smmla v26.4s, v6.16b, v10.16b // tile6-oc4, tile6-oc5, tile7-oc4, tile7-oc5 + .inst 0x4e8ba4db // smmla v27.4s, v6.16b, v11.16b // tile6-oc6, tile6-oc7, tile7-oc6, tile7-oc7 + + .inst 0x4e88a4fc // smmla v28.4s, v7.16b, v8.16b // tile8-oc0, tile8-oc1, tile9-oc0, tile9-oc1 + .inst 0x4e89a4fd // smmla v29.4s, v7.16b, v9.16b // tile8-oc2, tile8-oc3, tile9-oc2, tile9-oc3 + .inst 0x4e8aa4fe // smmla v30.4s, v7.16b, v10.16b // tile8-oc4, tile8-oc5, tile9-oc4, tile9-oc5 + .inst 0x4e8ba4ff // smmla v31.4s, v7.16b, v11.16b // tile8-oc6, tile8-oc7, tile9-oc6, tile9-oc7 + bne LoopSz_TILE_10 +LoopSzEnd_TILE_10: + add x2, x2, x15 // weight += dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) * sizeof(int8_t); + sub x5, x5, #1 // dz-- + // transpose + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v2.2d, v14.2d, v15.2d // E0: oc:4-7 + uzp2 v3.2d, v14.2d, v15.2d // E1: oc:4-7 + + uzp1 v4.2d, v16.2d, v17.2d + uzp2 v5.2d, v16.2d, v17.2d + uzp1 v6.2d, v18.2d, v19.2d + uzp2 v7.2d, v18.2d, v19.2d + + uzp1 v8.2d, v20.2d, v21.2d + uzp2 v9.2d, v20.2d, v21.2d + uzp1 v10.2d, v22.2d, v23.2d + uzp2 v11.2d, v22.2d, v23.2d + + uzp1 v12.2d, v24.2d, v25.2d + uzp2 v13.2d, v24.2d, v25.2d + uzp1 v14.2d, v26.2d, v27.2d + uzp2 v15.2d, v26.2d, v27.2d + + uzp1 v16.2d, v28.2d, v29.2d + uzp2 v17.2d, v28.2d, v29.2d + uzp1 v18.2d, v30.2d, v31.2d + uzp2 v19.2d, v30.2d, v31.2d + Int32ToFloat v0, v1, v2, v3 + Int32ToFloat v4, v5, v6, v7 + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 + Int32ToFloat v16, v17, v18, v19 + +Tile10Quan: + ld1 {v20.4s, v21.4s}, [x8], #32 // scale + ld1 {v22.4s, v23.4s}, [x27], #32 // x kernel sum + ld1 {v24.d}[0], [x27] + sub x27, x27, #32 + ld1 {v25.4s, v26.4s}, [x28], #32 // weight quan zeropoint + //ld1r {v27.4s}, [x6], #4 // f32 min + //ld1r {v28.4s}, [x6] // f32 max + //sub x6, x6, #4 + MUL_SCALE v20, v0, v1, v4, v5 + MUL_SCALE v21, v2, v3, v6, v7 + MUL_SCALE v20, v8, v9, v12, v13 + MUL_SCALE v21, v10, v11, v14, v15 + fmul v16.4s, v16.4s, v20.4s + fmul v17.4s, v17.4s, v20.4s + fmul v18.4s, v18.4s, v21.4s + fmul v19.4s, v19.4s, v21.4s + + cbz x23, TILE10_MLA + ld1 {v27.4s, v28.4s}, [x23], #32 + ld1 {v29.d}[0], [x23] + MUL_EXTRA_SCALE v27, v0, v1, v4, v5 + MUL_EXTRA_SCALE v28, v8, v9, v12, v13 + MUL_EXTRA_SCALE v27, v2, v3, v6, v7 + MUL_EXTRA_SCALE v28, v10, v11, v14, v15 + fmul v16.4s, v16.4s, v29.s[0] + fmul v17.4s, v17.4s, v29.s[1] + fmul v18.4s, v18.4s, v29.s[0] + fmul v19.4s, v19.4s, v29.s[1] + sub x23, x23, #32 + + TILE10_MLA: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v2, v22, v26, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v3, v22, v26, 1 // tile:1, oc:4-7 + + MLA_WEIGHTZERO v4, v22, v25, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v5, v22, v25, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v6, v22, v26, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v7, v22, v26, 3 // tile:3, oc:4-7 + + MLA_WEIGHTZERO v8, v23, v25, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v9, v23, v25, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v10, v23, v26, 0 // tile:4, oc:4-7 + MLA_WEIGHTZERO v11, v23, v26, 1 // tile:5, oc:4-7 + + MLA_WEIGHTZERO v12, v23, v25, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v13, v23, v25, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v14, v23, v26, 2 // tile:6, oc:4-7 + MLA_WEIGHTZERO v15, v23, v26, 3 // tile:7, oc:4-7 + + MLA_WEIGHTZERO v16, v24, v25, 0 // tile:8, oc:0-3 + MLA_WEIGHTZERO v17, v24, v25, 1 // tile:9, oc:0-3 + MLA_WEIGHTZERO v18, v24, v26, 0 // 
tile:8, oc:4-7 + MLA_WEIGHTZERO v19, v24, v26, 1 // tile:9, oc:4-7 + + + cbz x9, TILE10_ADD_DSTV + TILE10_ADD_BIAS: + ld1 {v20.4s, v21.4s}, [x9], #32 // bias + ADD_BIAS_FLOAT v0, v1, v4, v5, v20 + ADD_BIAS_FLOAT v2, v3, v6, v7, v21 + ADD_BIAS_FLOAT v8, v9, v12, v13, v20 + ADD_BIAS_FLOAT v10, v11, v14, v15, v21 + fadd v16.4s, v16.4s, v20.4s + fadd v17.4s, v17.4s, v20.4s + fadd v18.4s, v18.4s, v21.4s + fadd v19.4s, v19.4s, v21.4s + + // float32->float16 + Float32ToHalf v0, v2, v1, v3, v20, v21 + Float32ToHalf v4, v6, v5, v7, v22, v23 + Float32ToHalf v8, v10, v9, v11, v24, v25 + Float32ToHalf v12, v14, v13, v15, v26, v27 + Float32ToHalf v16, v18, v17, v19, v30, v31 + b TILE10_POST // to Relu post + + TILE10_ADD_DSTV: + // float32->float16 + Float32ToHalf v0, v2, v1, v3, v20, v21 + Float32ToHalf v4, v6, v5, v7, v22, v23 + Float32ToHalf v8, v10, v9, v11, v24, v25 + Float32ToHalf v12, v14, v13, v15, v26, v27 + Float32ToHalf v16, v18, v17, v19, v30, v31 + + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + ld1 {v8.8h, v9.8h}, [x10] + + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v1.8h + fadd v22.8h, v22.8h, v2.8h + fadd v23.8h, v23.8h, v3.8h + fadd v24.8h, v24.8h, v4.8h + fadd v25.8h, v25.8h, v5.8h + fadd v26.8h, v26.8h, v6.8h + fadd v27.8h, v27.8h, v7.8h + fadd v30.8h, v30.8h, v8.8h + fadd v31.8h, v31.8h, v9.8h + + TILE10_POST: + cbz x14, TILE10_STORE + ld1r {v29.8h}, [x14], #2 // f32 min + ld1r {v28.8h}, [x14] // f32 max + sub x14, x14, #2 + + ReLU_FP16 v20, v21, v22, v23, v29, v28 + ReLU_FP16 v24, v25, v26, v27, v29, v28 + ReLU_FP16_2 v30, v31, v29, v28 + + TILE10_STORE: + + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], #64 + st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x0], #64 + st1 {v30.8h, v31.8h}, [x0], x4 + +Tile10LoopCheck: + cmp x5, #1 + bge LoopDz_TILE_10 + b End + +TILE_8: + cmp x7, #8 + blt TILE_4 + mov x24, x5 // dst_depth_quad + mov x26, x0 // dst + mov x25, x2 // weight + mov x19, x8 // scale + mov x20, x9 // bias + mov x6, x28 // weightQuanBias + sub x4, x4, #64 // For Tile8, revert it when Tile8 end +LoopDz_TILE_8: + //ld1 {v0.4s, v1.4s}, [x20], #32 // bias + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + mov x10, x26 // tag dst + + SET_0_4 v12, v16, v20, v24 // oc:0,1,0,1 + SET_0_4 v13, v17, v21, v25 // oc:2,3,2,3 + SET_0_4 v14, v18, v22, v26 // oc:4,5,4,5 + SET_0_4 v15, v19, v23, v27 // oc:6,7,6,7 +LoopSz_TILE_8: + ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x12], #64 // weight + ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], x22 // src: E0-E7 + subs x13, x13, #1 + .inst 0x4e88a46c // smmla v12.4s, v3.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a46d // smmla v13.4s, v3.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + .inst 0x4e8aa46e // smmla v14.4s, v3.16b, v10.16b // tile0-oc4, tile0-oc5, tile1-oc4, tile1-oc5 + .inst 0x4e8ba46f // smmla v15.4s, v3.16b, v11.16b // tile0-oc6, tile0-oc7, tile1-oc6, tile1-oc7 + + .inst 0x4e88a490 // smmla v16.4s, v4.16b, v8.16b // tile2-oc0, tile2-oc1, tile3-oc0, tile3-oc1 + .inst 0x4e89a491 // smmla v17.4s, v4.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 + .inst 0x4e8aa492 // smmla v18.4s, v4.16b, v10.16b // tile2-oc4, tile2-oc5, tile3-oc4, tile3-oc5 + .inst 0x4e8ba493 // smmla v19.4s, v4.16b, v11.16b // tile2-oc6, tile2-oc7, tile3-oc6, tile3-oc7 + + .inst 0x4e88a4b4 // smmla v20.4s, v5.16b, v8.16b // tile4-oc0, tile4-oc1, tile5-oc0, tile5-oc1 + .inst 0x4e89a4b5 // smmla v21.4s, v5.16b, v9.16b // tile4-oc2, tile4-oc3, 
tile5-oc2, tile5-oc3 + .inst 0x4e8aa4b6 // smmla v22.4s, v5.16b, v10.16b // tile4-oc4, tile4-oc5, tile5-oc4, tile5-oc5 + .inst 0x4e8ba4b7 // smmla v23.4s, v5.16b, v11.16b // tile4-oc6, tile4-oc7, tile5-oc6, tile5-oc7 + + .inst 0x4e88a4d8 // smmla v24.4s, v6.16b, v8.16b // tile6-oc0, tile6-oc1, tile7-oc0, tile7-oc1 + .inst 0x4e89a4d9 // smmla v25.4s, v6.16b, v9.16b // tile6-oc2, tile6-oc3, tile7-oc2, tile7-oc3 + .inst 0x4e8aa4da // smmla v26.4s, v6.16b, v10.16b // tile6-oc4, tile6-oc5, tile7-oc4, tile7-oc5 + .inst 0x4e8ba4db // smmla v27.4s, v6.16b, v11.16b // tile6-oc6, tile6-oc7, tile7-oc6, tile7-oc7 + + bne LoopSz_TILE_8 +LoopSzEnd_TILE_8: + add x25, x25, x15 + sub x24, x24, #1 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v2.2d, v14.2d, v15.2d // E0: oc:4-7 + uzp2 v3.2d, v14.2d, v15.2d // E1: oc:4-7 + + uzp1 v4.2d, v16.2d, v17.2d + uzp2 v5.2d, v16.2d, v17.2d + uzp1 v6.2d, v18.2d, v19.2d + uzp2 v7.2d, v18.2d, v19.2d + + uzp1 v8.2d, v20.2d, v21.2d + uzp2 v9.2d, v20.2d, v21.2d + uzp1 v10.2d, v22.2d, v23.2d + uzp2 v11.2d, v22.2d, v23.2d + + uzp1 v12.2d, v24.2d, v25.2d + uzp2 v13.2d, v24.2d, v25.2d + uzp1 v14.2d, v26.2d, v27.2d + uzp2 v15.2d, v26.2d, v27.2d + Int32ToFloat v0, v1, v2, v3 + Int32ToFloat v4, v5, v6, v7 + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 + +Tile8Quan: + ld1 {v20.4s, v21.4s}, [x19], #32 // scale + ld1 {v22.4s, v23.4s}, [x27] // x kernel sum + ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint + MUL_SCALE v20, v0, v1, v4, v5 + MUL_SCALE v21, v2, v3, v6, v7 + MUL_SCALE v20, v8, v9, v12, v13 + MUL_SCALE v21, v10, v11, v14, v15 + + cbz x23, TILE8_MLA + ld1 {v27.4s, v28.4s}, [x23] + MUL_EXTRA_SCALE v27, v0, v1, v4, v5 + MUL_EXTRA_SCALE v28, v8, v9, v12, v13 + MUL_EXTRA_SCALE v27, v2, v3, v6, v7 + MUL_EXTRA_SCALE v28, v10, v11, v14, v15 + + TILE8_MLA: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v2, v22, v26, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v3, v22, v26, 1 // tile:1, oc:4-7 + + MLA_WEIGHTZERO v4, v22, v25, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v5, v22, v25, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v6, v22, v26, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v7, v22, v26, 3 // tile:3, oc:4-7 + + MLA_WEIGHTZERO v8, v23, v25, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v9, v23, v25, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v10, v23, v26, 0 // tile:4, oc:4-7 + MLA_WEIGHTZERO v11, v23, v26, 1 // tile:5, oc:4-7 + + MLA_WEIGHTZERO v12, v23, v25, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v13, v23, v25, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v14, v23, v26, 2 // tile:6, oc:4-7 + MLA_WEIGHTZERO v15, v23, v26, 3 // tile:7, oc:4-7 + + cbz x9, TILE8_ADD_DSTV + TILE8_ADD_BIAS: + ld1 {v16.4s, v17.4s}, [x20], #32 // bias + ADD_BIAS_FLOAT v0, v1, v4, v5, v16 + ADD_BIAS_FLOAT v2, v3, v6, v7, v17 + ADD_BIAS_FLOAT v8, v9, v12, v13, v16 + ADD_BIAS_FLOAT v10, v11, v14, v15, v17 + // float32->float16 + Float32ToHalf v0, v2, v1, v3, v20, v21 + Float32ToHalf v4, v6, v5, v7, v22, v23 + Float32ToHalf v8, v10, v9, v11, v24, v25 + Float32ToHalf v12, v14, v13, v15, v26, v27 + b TILE8_POST + + TILE8_ADD_DSTV: + // float32->float16 + Float32ToHalf v0, v2, v1, v3, v20, v21 + Float32ToHalf v4, v6, v5, v7, v22, v23 + Float32ToHalf v8, v10, v9, v11, v24, v25 + Float32ToHalf v12, v14, v13, v15, v26, v27 + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10] + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v1.8h + fadd v22.8h, v22.8h, v2.8h + fadd v23.8h, 
v23.8h, v3.8h + fadd v24.8h, v24.8h, v4.8h + fadd v25.8h, v25.8h, v5.8h + fadd v26.8h, v26.8h, v6.8h + fadd v27.8h, v27.8h, v7.8h + + TILE8_POST: + cbz x14, TILE8_STORE + ld1r {v29.8h}, [x14], #2 // f32 min + ld1r {v28.8h}, [x14] // f32 max + sub x14, x14, #2 + ReLU_FP16 v20, v21, v22, v23, v29, v28 + ReLU_FP16 v24, v25, v26, v27, v29, v28 + + TILE8_STORE: + + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x26], #64 + st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x26], x4 + +Tile8LoopCheck: + cmp x24, #1 + bge LoopDz_TILE_8 +cbz x23, Tile8End +add x23, x23, #32 +Tile8End: + sub x7, x7, #8 + add x0, x0, x21, LSL #3 + add x1, x1, #64 + add x27, x27, #32 + add x4, x4, #64 // Revert it + +TILE_4: + cmp x7, #4 + blt TILE_2 + mov x24, x5 // dst_depth_quad + mov x26, x0 // dst + mov x25, x2 // weight + mov x19, x8 // scale + mov x20, x9 // bias + mov x6, x28 // weightQuanBias +LoopDz_TILE_4: + //ld1 {v0.4s, v1.4s}, [x20], #32 // bias + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + mov x10, x26 // tag dst + + SET_0_2 v12, v16 // oc:0,1,0,1 + SET_0_2 v13, v17 // oc:2,3,2,3 + SET_0_2 v14, v18 // oc:4,5,4,5 + SET_0_2 v15, v19 // oc:6,7,6,7 +LoopSz_TILE_4: + ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x12], #64 // weight + ld1 {v4.16b, v5.16b}, [x11], x22 // src + subs x13, x13, #1 + .inst 0x4e88a48c // smmla v12.4s, v4.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a48d // smmla v13.4s, v4.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + .inst 0x4e8aa48e // smmla v14.4s, v4.16b, v10.16b // tile0-oc4, tile0-oc5, tile1-oc4, tile1-oc5 + .inst 0x4e8ba48f // smmla v15.4s, v4.16b, v11.16b // tile0-oc6, tile0-oc7, tile1-oc6, tile1-oc7 + + .inst 0x4e88a4b0 // smmla v16.4s, v5.16b, v8.16b // tile2-oc0, tile2-oc1, tile3-oc0, tile3-oc1 + .inst 0x4e89a4b1 // smmla v17.4s, v5.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 + .inst 0x4e8aa4b2 // smmla v18.4s, v5.16b, v10.16b // tile2-oc4, tile2-oc5, tile3-oc4, tile3-oc5 + .inst 0x4e8ba4b3 // smmla v19.4s, v5.16b, v11.16b // tile2-oc6, tile2-oc7, tile3-oc6, tile3-oc7 + + bne LoopSz_TILE_4 +LoopSzEnd_TILE_4: + add x25, x25, x15 + sub x24, x24, #1 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v2.2d, v14.2d, v15.2d // E0: oc:4-7 + uzp2 v3.2d, v14.2d, v15.2d // E1: oc:4-7 + + uzp1 v4.2d, v16.2d, v17.2d + uzp2 v5.2d, v16.2d, v17.2d + uzp1 v6.2d, v18.2d, v19.2d + uzp2 v7.2d, v18.2d, v19.2d + Int32ToFloat v0, v1, v2, v3 + Int32ToFloat v4, v5, v6, v7 + +Tile4Quan: + ld1 {v20.4s, v21.4s}, [x19], #32 // scale + ld1 {v22.4s}, [x27] // x kernel sum + ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint + MUL_SCALE v20, v0, v1, v4, v5 + MUL_SCALE v21, v2, v3, v6, v7 + + cbz x23, TILE4_MLA + ld1 {v27.4s}, [x23] + MUL_EXTRA_SCALE v27, v0, v1, v4, v5 + MUL_EXTRA_SCALE v27, v2, v3, v6, v7 + + TILE4_MLA: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v2, v22, v26, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v3, v22, v26, 1 // tile:1, oc:4-7 + + MLA_WEIGHTZERO v4, v22, v25, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v5, v22, v25, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v6, v22, v26, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v7, v22, v26, 3 // tile:3, oc:4-7 + + cbz x9, TILE4_ADD_DSTV + TILE4_ADD_BIAS: + ld1 {v16.4s, v17.4s}, [x20], #32 // bias + ADD_BIAS_FLOAT v0, v1, v4, v5, v16 + ADD_BIAS_FLOAT v2, v3, v6, v7, v17 + // float32->float16 + Float32ToHalf v0, v2, v1, v3, v20, v21 + Float32ToHalf v4, v6, v5, v7, v22, v23 + b 
TILE4_POST + + TILE4_ADD_DSTV: + // float32->float16 + Float32ToHalf v0, v2, v1, v3, v20, v21 + Float32ToHalf v4, v6, v5, v7, v22, v23 + ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x10] + fadd v20.8h, v20.8h, v24.8h + fadd v21.8h, v21.8h, v25.8h + fadd v22.8h, v22.8h, v26.8h + fadd v23.8h, v23.8h, v27.8h + + TILE4_POST: + cbz x14, TILE4_STORE + ld1r {v29.8h}, [x14], #2 // f32 min + ld1r {v28.8h}, [x14] // f32 max + sub x14, x14, #2 + ReLU_FP16 v20, v21, v22, v23, v29, v28 + + TILE4_STORE: + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x26], x4 + +Tile4LoopCheck: + cmp x24, #1 + bge LoopDz_TILE_4 +cbz x23, Tile4End +add x23, x23, #16 +Tile4End: + sub x7, x7, #4 + add x0, x0, x21, LSL #2 + add x1, x1, #32 + add x27, x27, #16 + //b TILE_4 + +TILE_2: + cmp x7, #2 + blt TILE_1 + mov x24, x5 // dst_depth_quad + mov x26, x0 // dst + mov x25, x2 // weight + mov x19, x8 // scale + mov x20, x9 // bias + mov x6, x28 // weightQuanBias +LoopDz_TILE_2: + //ld1 {v0.4s, v1.4s}, [x20], #32 // bias + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + mov x10, x26 // tag dst + + // v12 oc:0,1,0,1 + // v13 oc:2,3,2,3 + // v14 oc:4,5,4,5 + // v15 oc:6,7,6,7 + SET_0_4 v12, v13, v14, v15 +LoopSz_TILE_2: + ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x12], #64 + ld1 {v4.16b}, [x11], x22 // src + .inst 0x4e88a48c // smmla v12.4s, v4.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a48d // smmla v13.4s, v4.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + .inst 0x4e8aa48e // smmla v14.4s, v4.16b, v10.16b // tile0-oc4, tile0-oc5, tile1-oc4, tile1-oc5 + .inst 0x4e8ba48f // smmla v15.4s, v4.16b, v11.16b // tile0-oc6, tile0-oc7, tile1-oc6, tile1-oc7 + subs x13, x13, #1 + bne LoopSz_TILE_2 +LoopSzEnd_TILE_2: + add x25, x25, x15 + sub x24, x24, #1 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v2.2d, v14.2d, v15.2d // E0: oc:4-7 + uzp2 v3.2d, v14.2d, v15.2d // E1: oc:4-7 + Int32ToFloat v0, v1, v2, v3 + +Tile2Quan: + ld1 {v20.4s, v21.4s}, [x19], #32 // scale + ld1 {v22.d}[0], [x27] // x kernel sum + ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint + fmul v0.4s, v0.4s, v20.4s + fmul v1.4s, v1.4s, v20.4s + fmul v2.4s, v2.4s, v21.4s + fmul v3.4s, v3.4s, v21.4s + + cbz x23, TILE2_MLA + ld1 {v27.d}[0], [x23] + fmul v0.4s, v0.4s, v27.s[0] + fmul v1.4s, v1.4s, v27.s[1] + fmul v2.4s, v2.4s, v27.s[0] + fmul v3.4s, v3.4s, v27.s[1] + + TILE2_MLA: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v2, v22, v26, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v3, v22, v26, 1 // tile:1, oc:4-7 + + cbz x9, TILE2_ADD_DSTV + TILE2_ADD_BIAS: + ld1 {v16.4s, v17.4s}, [x20], #32 // bias + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v16.4s + fadd v2.4s, v2.4s, v17.4s + fadd v3.4s, v3.4s, v17.4s + // float32->float16 + Float32ToHalf v0, v2, v1, v3, v20, v21 + b TILE2_POST + + TILE2_ADD_DSTV: + Float32ToHalf v0, v2, v1, v3, v20, v21 + ld1 {v24.8h, v25.8h}, [x10] + fadd v20.8h, v20.8h, v24.8h + fadd v21.8h, v21.8h, v25.8h + + TILE2_POST: + cbz x14, TILE2_STORE + ld1r {v29.8h}, [x14], #2 // f32 min + ld1r {v28.8h}, [x14] // f32 max + sub x14, x14, #2 + fmax v20.8h, v20.8h, v29.8h + fmax v21.8h, v21.8h, v29.8h + fmin v20.8h, v20.8h, v28.8h + fmin v21.8h, v21.8h, v28.8h + + TILE2_STORE: + st1 {v20.8h, v21.8h}, [x26], x4 + +Tile2LoopCheck: + cmp x24, #1 + bge LoopDz_TILE_2 +cbz x23, Tile2End +add x23, x23, #8 +Tile2End: + sub x7, x7, #2 + add x0, x0, x21, LSL #1 + add x1, x1, #16 + add x27, 
x27, #8 + +TILE_1: + cmp x7, #1 + blt End + mov x24, x5 // dst_depth_quad + mov x26, x0 // dst + mov x25, x2 // weight + mov x19, x8 // scale + mov x20, x9 // bias + mov x6, x28 // weightQuanBias +LoopDz_TILE_1: + //ld1 {v7.4s, v8.4s}, [x20], #32 // bias + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + mov x10, x26 + + //dup v16.2d, v7.d[0] // oc:0,1,0,1 + //dup v17.2d, v7.d[1] // oc:2,3,2,3 + //dup v18.2d, v8.d[0] // oc:4,5,4,5 + //dup v19.2d, v8.d[1] // oc:6,7,6,7 + movi v16.4s, #0 // oc:0,1,0,1 + movi v17.4s, #0 // oc:2,3,2,3 + movi v18.4s, #0 // oc:4,5,4,5 + movi v19.4s, #0 // oc:6,7,6,7 + + //movi v22.4s, #0 // oc:0,1,0,1 + //movi v23.4s, #0 // oc:2,3,2,3 + //movi v24.4s, #0 // oc:4,5,4,5 + //movi v25.4s, #0 // oc:6,7,6,7 + +LoopSz_TILE_1: + // src : 1 x [1 x 8] : v2 + // weight : 2 x [2 x 8] : v0-1 + // dst : 1 x 2 x [2] : v30-v31 + ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x12], #64 // weight + ld1 {v2.8b}, [x11], x22 // src + subs x13, x13, #1 + .inst 0x4e88a450 // smmla v16.4s, v2.16b, v8.16b + .inst 0x4e89a451 // smmla v17.4s, v2.16b, v9.16b + .inst 0x4e8aa452 // smmla v18.4s, v2.16b, v10.16b + .inst 0x4e8ba453 // smmla v19.4s, v2.16b, v11.16b + + bne LoopSz_TILE_1 +LoopSzEnd_TILE_1: + add x25, x25, x15 + sub x24, x24, #1 + uzp1 v27.2d, v16.2d, v17.2d + uzp1 v26.2d, v18.2d, v19.2d + scvtf v27.4s, v27.4s + scvtf v26.4s, v26.4s + +Tile1Quan: + ld1 {v0.4s, v1.4s}, [x19], #32 // scale + ld1 {v6.s}[0], [x27] // x kernel sum + ld1 {v8.4s, v9.4s}, [x6], #32 // weight quan zeropoint + fmul v27.4s, v27.4s, v0.4s + fmul v26.4s, v26.4s, v1.4s + + cbz x23, TILE1_MLA + ld1 {v4.s}[0], [x23] + fmul v27.4s, v27.4s, v4.s[0] + fmul v26.4s, v26.4s, v4.s[0] + TILE1_MLA: + MLA_WEIGHTZERO v27, v6, v8, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v26, v6, v9, 0 // tile:0, oc:4-7 + + cbz x9, TILE1_ADD_DSTV + TILE1_ADD_BIAS: + ld1 {v16.4s, v17.4s}, [x20], #32 // bias + fadd v27.4s, v27.4s, v16.4s + fadd v26.4s, v26.4s, v17.4s + fcvtn v0.4h, v27.4s + fcvtn2 v0.8h, v26.4s + b TILE1_POST + + TILE1_ADD_DSTV: + fcvtn v0.4h, v27.4s + fcvtn2 v0.8h, v26.4s + ld1 {v24.8h}, [x10] + fadd v0.8h, v0.8h, v24.8h + + TILE1_POST: + cbz x14, TILE1_STORE + ld1r {v29.8h}, [x14], #2 // f32 min + ld1r {v28.8h}, [x14] // f32 max + sub x14, x14, #2 + fmax v0.8h, v0.8h, v29.8h + fmin v0.8h, v0.8h, v28.8h + TILE1_STORE: + st1 {v0.8h}, [x26], x4 + +Tile1LoopEnd: + cmp x24, #1 + bge LoopDz_TILE_1 + +End: +ldp x27, x28, [sp, #(16 * 8)] +ldp x25, x26, [sp, #(16 * 7)] +ldp x23, x24, [sp, #(16 * 6)] +ldp x19, x20, [sp, #(16 * 5)] +ldp x21, x22, [sp, #(16 * 4)] +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 10) +ret + +#endif // __aarch64__ diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16.S b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16.S new file mode 100644 index 000000000..7022af3a1 --- /dev/null +++ b/source/backend/arm82/asm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16.S @@ -0,0 +1,875 @@ +// +// MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16.S +// MNN +// +// Created by MNN on 2022/09/26. 
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#if defined(__aarch64__) +#include "MNNAsmGlobal.h" + +.text +.align 5 + +.macro SET_0_5 d0, d1, d2, d3, d4 + movi \d0\().16b, #0 + movi \d1\().16b, #0 + movi \d2\().16b, #0 + movi \d3\().16b, #0 + movi \d4\().16b, #0 +.endm + +.macro SET_0_4 d0, d1, d2, d3 + movi \d0\().16b, #0 + movi \d1\().16b, #0 + movi \d2\().16b, #0 + movi \d3\().16b, #0 +.endm + +.macro SET_0_2 d0, d1 + movi \d0\().16b, #0 + movi \d1\().16b, #0 +.endm + +.macro ADD_BIAS_FLOAT d0, d1, d2, d3, z0 + fadd \d0\().4s, \d0\().4s, \z0\().4s + fadd \d1\().4s, \d1\().4s, \z0\().4s + fadd \d2\().4s, \d2\().4s, \z0\().4s + fadd \d3\().4s, \d3\().4s, \z0\().4s +.endm + +.macro ReLU_FP16 s0, s1, s2, s3, z0, z1 // z0:min z1:max + fmin \s0\().8h, \s0\().8h, \z1\().8h + fmin \s1\().8h, \s1\().8h, \z1\().8h + fmin \s2\().8h, \s2\().8h, \z1\().8h + fmin \s3\().8h, \s3\().8h, \z1\().8h + fmax \s0\().8h, \s0\().8h, \z0\().8h + fmax \s1\().8h, \s1\().8h, \z0\().8h + fmax \s2\().8h, \s2\().8h, \z0\().8h + fmax \s3\().8h, \s3\().8h, \z0\().8h +.endm + +.macro ReLU_FP16_2 s0, s1, z0, z1 // z0:min z1:max + fmin \s0\().8h, \s0\().8h, \z1\().8h + fmin \s1\().8h, \s1\().8h, \z1\().8h + fmax \s0\().8h, \s0\().8h, \z0\().8h + fmax \s1\().8h, \s1\().8h, \z0\().8h +.endm +.macro Int32ToFloat z0, z1, z2, z3 + scvtf \z0\().4s, \z0\().4s + scvtf \z1\().4s, \z1\().4s + scvtf \z2\().4s, \z2\().4s + scvtf \z3\().4s, \z3\().4s +.endm +.macro MUL_SCALE s, d0, d1, d2, d3 + fmul \d0\().4s, \d0\().4s, \s\().4s + fmul \d1\().4s, \d1\().4s, \s\().4s + fmul \d2\().4s, \d2\().4s, \s\().4s + fmul \d3\().4s, \d3\().4s, \s\().4s +.endm +.macro MUL_EXTRA_SCALE s, d0, d1, d2, d3 + fmul \d0\().4s, \d0\().4s, \s\().s[0] + fmul \d1\().4s, \d1\().4s, \s\().s[1] + fmul \d2\().4s, \d2\().4s, \s\().s[2] + fmul \d3\().4s, \d3\().4s, \s\().s[3] +.endm +.macro MLA_WEIGHTZERO d0, s0, s1, idx // idx for xKernelSum + fmla \d0\().4s, \s1\().4s, \s0\().s[\idx] +.endm +.macro Float32ToHalf s0, s1, s2, s3, d0, d1 + fcvtn \d0\().4h, \s0\().4s + fcvtn2 \d0\().8h, \s1\().4s + fcvtn \d1\().4h, \s2\().4s + fcvtn2 \d1\().8h, \s3\().4s +.endm + +asm_function MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16 +/* +struct QuanPostTreatParameters { + const float* scale; + const float* biasFloat; + int32_t maxValue; + int32_t minValue; + int32_t useInt8 = 1; // Save result as int8_t dataType; otherwise float32. + float roundValuePos = 0.5f; + float roundValueNeg = -0.5f; + float* srcKernelSum; + float* weightQuanBias; + float* fp32minmax; + ssize_t blockNum = 1; + const int32_t* bias; + const float* extraScale = nullptr; +}; +*/ +//void MNNGemmInt8AddBiasScale_ARMV86_w4_Unit(int8_t* dst, const int8_t* src, +// const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, +// const QuanPostTreatParameters* parameters, size_t realDstCount); + +//Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step +//x5:dst_depth_quad, x6: parameters, x7: realDstCount + +//Load from x7: x8: scale, x9: biasFloat, x27: srcKernelSum, x28: weightQuanBias, x14: fp32minmax +/* For FP16 +UNIT = 8; +SRC_UNIT = 8; +DST_XUNIT = 10; + */ +ldr x8, [x6, #0] +ldr x9, [x6, #8] + +stp d14, d15, [sp, #(-16 * 10)]! 
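+// Callee-saved registers d8-d15 and x19-x28 are spilled into the 160-byte (16 * 10) frame reserved by the pre-indexed store above.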
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] +stp x21, x22, [sp, #(16 * 4)] +stp x19, x20, [sp, #(16 * 5)] +stp x23, x24, [sp, #(16 * 6)] +stp x25, x26, [sp, #(16 * 7)] +stp x27, x28, [sp, #(16 * 8)] +// ldr w23, [x6, #24] +ldr x27, [x6, #40] // srcKernelSum +ldr x28, [x6, #48] // weightQuanBias +ldr x23, [x6, #64] // blockNum +ldr x14, [x6, #56] // fp32minmax + +mul x23, x23, x3 // UP_DIV(ic*ky*kx, SRC_UNIT) = blockNum * src_depth_quad_per_block +mov x22, #80 // GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT = 10 * 8 = 80 +mov x21, #16 // sizeof(float16_t) * UNIT + +Start: +lsl x15, x23, #5 // x15 = src_depth_quad * UNIT * UNIT_SRC * sizeof(int4_t) = src_depth_quad * 8 * 8 * 0.5 = src_depth_quad << 5 +ldr x23, [x6, #80] // extra scale +TILE_10: + cmp x7, #10 + blt TILE_8 +sub x4, x4, #128 // For Tile10 +LoopDz_TILE_10: + //ld1 {v0.4s, v1.4s}, [x9], #32 // bias + mov x11, x1 // src + mov x12, x2 // weight + mov x13, x3 // src_depth_quad + mov x10, x0 // tag dst address + + SET_0_5 v12, v16, v20, v24, v28 // oc:0,1,0,1 + SET_0_5 v13, v17, v21, v25, v29 // oc:2,3,2,3 + SET_0_5 v14, v18, v22, v26, v30 // oc:4,5,4,5 + SET_0_5 v15, v19, v23, v27, v31 // oc:6,7,6,7 + +LoopSz_TILE_10: + ld1 {v0.16b, v1.16b}, [x12], #32 // weight + movi v2.16b, #15 + ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9 + ld1 {v7.16b}, [x11], #16 + // int4->int8 + + ushr v8.16b, v0.16b, #4 // oc:0-1 + ushr v9.16b, v1.16b, #4 // oc:2-3 + and v10.16b, v0.16b, v2.16b // oc:4-5 + and v11.16b, v1.16b, v2.16b // oc:6-7 + + subs x13, x13, #1 + .inst 0x4e88a46c // smmla v12.4s, v3.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a46d // smmla v13.4s, v3.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + .inst 0x4e8aa46e // smmla v14.4s, v3.16b, v10.16b // tile0-oc4, tile0-oc5, tile1-oc4, tile1-oc5 + .inst 0x4e8ba46f // smmla v15.4s, v3.16b, v11.16b // tile0-oc6, tile0-oc7, tile1-oc6, tile1-oc7 + + .inst 0x4e88a490 // smmla v16.4s, v4.16b, v8.16b // tile2-oc0, tile2-oc1, tile3-oc0, tile3-oc1 + .inst 0x4e89a491 // smmla v17.4s, v4.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 + .inst 0x4e8aa492 // smmla v18.4s, v4.16b, v10.16b // tile2-oc4, tile2-oc5, tile3-oc4, tile3-oc5 + .inst 0x4e8ba493 // smmla v19.4s, v4.16b, v11.16b // tile2-oc6, tile2-oc7, tile3-oc6, tile3-oc7 + + .inst 0x4e88a4b4 // smmla v20.4s, v5.16b, v8.16b // tile4-oc0, tile4-oc1, tile5-oc0, tile5-oc1 + .inst 0x4e89a4b5 // smmla v21.4s, v5.16b, v9.16b // tile4-oc2, tile4-oc3, tile5-oc2, tile5-oc3 + .inst 0x4e8aa4b6 // smmla v22.4s, v5.16b, v10.16b // tile4-oc4, tile4-oc5, tile5-oc4, tile5-oc5 + .inst 0x4e8ba4b7 // smmla v23.4s, v5.16b, v11.16b // tile4-oc6, tile4-oc7, tile5-oc6, tile5-oc7 + + .inst 0x4e88a4d8 // smmla v24.4s, v6.16b, v8.16b // tile6-oc0, tile6-oc1, tile7-oc0, tile7-oc1 + .inst 0x4e89a4d9 // smmla v25.4s, v6.16b, v9.16b // tile6-oc2, tile6-oc3, tile7-oc2, tile7-oc3 + .inst 0x4e8aa4da // smmla v26.4s, v6.16b, v10.16b // tile6-oc4, tile6-oc5, tile7-oc4, tile7-oc5 + .inst 0x4e8ba4db // smmla v27.4s, v6.16b, v11.16b // tile6-oc6, tile6-oc7, tile7-oc6, tile7-oc7 + + .inst 0x4e88a4fc // smmla v28.4s, v7.16b, v8.16b // tile8-oc0, tile8-oc1, tile9-oc0, tile9-oc1 + .inst 0x4e89a4fd // smmla v29.4s, v7.16b, v9.16b // tile8-oc2, tile8-oc3, tile9-oc2, tile9-oc3 + .inst 0x4e8aa4fe // smmla v30.4s, v7.16b, v10.16b // tile8-oc4, tile8-oc5, tile9-oc4, tile9-oc5 + .inst 0x4e8ba4ff // smmla v31.4s, v7.16b, v11.16b // tile8-oc6, tile8-oc7, tile9-oc6, tile9-oc7 + bne 
LoopSz_TILE_10 +LoopSzEnd_TILE_10: + add x2, x2, x15 // weight += dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT * 0.5); + sub x5, x5, #1 // dz-- + // transpose + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v2.2d, v14.2d, v15.2d // E0: oc:4-7 + uzp2 v3.2d, v14.2d, v15.2d // E1: oc:4-7 + + uzp1 v4.2d, v16.2d, v17.2d + uzp2 v5.2d, v16.2d, v17.2d + uzp1 v6.2d, v18.2d, v19.2d + uzp2 v7.2d, v18.2d, v19.2d + + uzp1 v8.2d, v20.2d, v21.2d + uzp2 v9.2d, v20.2d, v21.2d + uzp1 v10.2d, v22.2d, v23.2d + uzp2 v11.2d, v22.2d, v23.2d + + uzp1 v12.2d, v24.2d, v25.2d + uzp2 v13.2d, v24.2d, v25.2d + uzp1 v14.2d, v26.2d, v27.2d + uzp2 v15.2d, v26.2d, v27.2d + + uzp1 v16.2d, v28.2d, v29.2d + uzp2 v17.2d, v28.2d, v29.2d + uzp1 v18.2d, v30.2d, v31.2d + uzp2 v19.2d, v30.2d, v31.2d + Int32ToFloat v0, v1, v2, v3 + Int32ToFloat v4, v5, v6, v7 + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 + Int32ToFloat v16, v17, v18, v19 + +Tile10Quan: + ld1 {v20.4s, v21.4s}, [x8], #32 // scale + ld1 {v22.4s, v23.4s}, [x27], #32 // x kernel sum + ld1 {v24.d}[0], [x27] + sub x27, x27, #32 + ld1 {v25.4s, v26.4s}, [x28], #32 // weight quan zeropoint + MUL_SCALE v20, v0, v1, v4, v5 + MUL_SCALE v21, v2, v3, v6, v7 + MUL_SCALE v20, v8, v9, v12, v13 + MUL_SCALE v21, v10, v11, v14, v15 + fmul v16.4s, v16.4s, v20.4s + fmul v17.4s, v17.4s, v20.4s + fmul v18.4s, v18.4s, v21.4s + fmul v19.4s, v19.4s, v21.4s + + cbz x23, TILE10_MLA + ld1 {v27.4s, v28.4s}, [x23], #32 + ld1 {v29.d}[0], [x23] + MUL_EXTRA_SCALE v27, v0, v1, v4, v5 + MUL_EXTRA_SCALE v28, v8, v9, v12, v13 + MUL_EXTRA_SCALE v27, v2, v3, v6, v7 + MUL_EXTRA_SCALE v28, v10, v11, v14, v15 + fmul v16.4s, v16.4s, v29.s[0] + fmul v17.4s, v17.4s, v29.s[1] + fmul v18.4s, v18.4s, v29.s[0] + fmul v19.4s, v19.4s, v29.s[1] + sub x23, x23, #32 + + TILE10_MLA: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v2, v22, v26, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v3, v22, v26, 1 // tile:1, oc:4-7 + + MLA_WEIGHTZERO v4, v22, v25, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v5, v22, v25, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v6, v22, v26, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v7, v22, v26, 3 // tile:3, oc:4-7 + + MLA_WEIGHTZERO v8, v23, v25, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v9, v23, v25, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v10, v23, v26, 0 // tile:4, oc:4-7 + MLA_WEIGHTZERO v11, v23, v26, 1 // tile:5, oc:4-7 + + MLA_WEIGHTZERO v12, v23, v25, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v13, v23, v25, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v14, v23, v26, 2 // tile:6, oc:4-7 + MLA_WEIGHTZERO v15, v23, v26, 3 // tile:7, oc:4-7 + + MLA_WEIGHTZERO v16, v24, v25, 0 // tile:8, oc:0-3 + MLA_WEIGHTZERO v17, v24, v25, 1 // tile:9, oc:0-3 + MLA_WEIGHTZERO v18, v24, v26, 0 // tile:8, oc:4-7 + MLA_WEIGHTZERO v19, v24, v26, 1 // tile:9, oc:4-7 + + + cbz x9, TILE10_ADD_DSTV + TILE10_ADD_BIAS: + ld1 {v20.4s, v21.4s}, [x9], #32 // bias + ADD_BIAS_FLOAT v0, v1, v4, v5, v20 + ADD_BIAS_FLOAT v2, v3, v6, v7, v21 + ADD_BIAS_FLOAT v8, v9, v12, v13, v20 + ADD_BIAS_FLOAT v10, v11, v14, v15, v21 + fadd v16.4s, v16.4s, v20.4s + fadd v17.4s, v17.4s, v20.4s + fadd v18.4s, v18.4s, v21.4s + fadd v19.4s, v19.4s, v21.4s + + // float32->float16 + Float32ToHalf v0, v2, v1, v3, v20, v21 + Float32ToHalf v4, v6, v5, v7, v22, v23 + Float32ToHalf v8, v10, v9, v11, v24, v25 + Float32ToHalf v12, v14, v13, v15, v26, v27 + Float32ToHalf v16, v18, v17, v19, v30, v31 + b TILE10_POST // to Relu post + + TILE10_ADD_DSTV: 
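+    // Bias pointer (x9) is null: convert the fp32 accumulators to fp16 and add the partial results already stored at dst (x10), instead of adding bias.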
+ // float32->float16 + Float32ToHalf v0, v2, v1, v3, v20, v21 + Float32ToHalf v4, v6, v5, v7, v22, v23 + Float32ToHalf v8, v10, v9, v11, v24, v25 + Float32ToHalf v12, v14, v13, v15, v26, v27 + Float32ToHalf v16, v18, v17, v19, v30, v31 + + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10], #64 + ld1 {v8.8h, v9.8h}, [x10] + + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v1.8h + fadd v22.8h, v22.8h, v2.8h + fadd v23.8h, v23.8h, v3.8h + fadd v24.8h, v24.8h, v4.8h + fadd v25.8h, v25.8h, v5.8h + fadd v26.8h, v26.8h, v6.8h + fadd v27.8h, v27.8h, v7.8h + fadd v30.8h, v30.8h, v8.8h + fadd v31.8h, v31.8h, v9.8h + + TILE10_POST: + cbz x14, TILE10_STORE + ld1r {v29.8h}, [x14], #2 // f32 min + ld1r {v28.8h}, [x14] // f32 max + sub x14, x14, #2 + + ReLU_FP16 v20, v21, v22, v23, v29, v28 + ReLU_FP16 v24, v25, v26, v27, v29, v28 + ReLU_FP16_2 v30, v31, v29, v28 + + TILE10_STORE: + + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x0], #64 + st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x0], #64 + st1 {v30.8h, v31.8h}, [x0], x4 + +Tile10LoopCheck: + cmp x5, #1 + bge LoopDz_TILE_10 + b End + +TILE_8: + //ld1r {v28.4s}, [x6], #4 // f32 min + //ld1r {v29.4s}, [x6] // f32 max + movi v30.16b, #15 + cmp x7, #8 + blt TILE_4 + sub x4, x4, #64 // just for Tile8, revert it when Tile8end + mov x24, x5 // dst_depth_quad + mov x26, x0 // dst + mov x25, x2 // weight + mov x19, x8 // scale + mov x20, x9 // bias + mov x6, x28 // weightQuanBias +LoopDz_TILE_8: + //ld1 {v0.4s, v1.4s}, [x20], #32 // bias + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + mov x10, x26 // tag dst + + SET_0_4 v12, v16, v20, v24 // oc:0,1,0,1 + SET_0_4 v13, v17, v21, v25 // oc:2,3,2,3 + SET_0_4 v14, v18, v22, v26 // oc:4,5,4,5 + SET_0_4 v15, v19, v23, v27 // oc:6,7,6,7 +LoopSz_TILE_8: + ld1 {v0.16b, v1.16b}, [x12], #32 // weight + //movi v2.16b, #15 + ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], x22 // src: E0-E7 + + // int4->int8 + subs x13, x13, #1 + ushr v8.16b, v0.16b, #4 // oc:0-1 + ushr v9.16b, v1.16b, #4 // oc:2-3 + and v10.16b, v0.16b, v30.16b // oc:4-5 + and v11.16b, v1.16b, v30.16b // oc:6-7 + + .inst 0x4e88a46c // smmla v12.4s, v3.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a46d // smmla v13.4s, v3.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + .inst 0x4e8aa46e // smmla v14.4s, v3.16b, v10.16b // tile0-oc4, tile0-oc5, tile1-oc4, tile1-oc5 + .inst 0x4e8ba46f // smmla v15.4s, v3.16b, v11.16b // tile0-oc6, tile0-oc7, tile1-oc6, tile1-oc7 + + .inst 0x4e88a490 // smmla v16.4s, v4.16b, v8.16b // tile2-oc0, tile2-oc1, tile3-oc0, tile3-oc1 + .inst 0x4e89a491 // smmla v17.4s, v4.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 + .inst 0x4e8aa492 // smmla v18.4s, v4.16b, v10.16b // tile2-oc4, tile2-oc5, tile3-oc4, tile3-oc5 + .inst 0x4e8ba493 // smmla v19.4s, v4.16b, v11.16b // tile2-oc6, tile2-oc7, tile3-oc6, tile3-oc7 + + .inst 0x4e88a4b4 // smmla v20.4s, v5.16b, v8.16b // tile4-oc0, tile4-oc1, tile5-oc0, tile5-oc1 + .inst 0x4e89a4b5 // smmla v21.4s, v5.16b, v9.16b // tile4-oc2, tile4-oc3, tile5-oc2, tile5-oc3 + .inst 0x4e8aa4b6 // smmla v22.4s, v5.16b, v10.16b // tile4-oc4, tile4-oc5, tile5-oc4, tile5-oc5 + .inst 0x4e8ba4b7 // smmla v23.4s, v5.16b, v11.16b // tile4-oc6, tile4-oc7, tile5-oc6, tile5-oc7 + + .inst 0x4e88a4d8 // smmla v24.4s, v6.16b, v8.16b // tile6-oc0, tile6-oc1, tile7-oc0, tile7-oc1 + .inst 0x4e89a4d9 // smmla v25.4s, v6.16b, v9.16b // tile6-oc2, tile6-oc3, tile7-oc2, tile7-oc3 + .inst 0x4e8aa4da // smmla v26.4s, v6.16b, 
v10.16b // tile6-oc4, tile6-oc5, tile7-oc4, tile7-oc5 + .inst 0x4e8ba4db // smmla v27.4s, v6.16b, v11.16b // tile6-oc6, tile6-oc7, tile7-oc6, tile7-oc7 + + bne LoopSz_TILE_8 +LoopSzEnd_TILE_8: + add x25, x25, x15 + sub x24, x24, #1 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v2.2d, v14.2d, v15.2d // E0: oc:4-7 + uzp2 v3.2d, v14.2d, v15.2d // E1: oc:4-7 + + uzp1 v4.2d, v16.2d, v17.2d + uzp2 v5.2d, v16.2d, v17.2d + uzp1 v6.2d, v18.2d, v19.2d + uzp2 v7.2d, v18.2d, v19.2d + + uzp1 v8.2d, v20.2d, v21.2d + uzp2 v9.2d, v20.2d, v21.2d + uzp1 v10.2d, v22.2d, v23.2d + uzp2 v11.2d, v22.2d, v23.2d + + uzp1 v12.2d, v24.2d, v25.2d + uzp2 v13.2d, v24.2d, v25.2d + uzp1 v14.2d, v26.2d, v27.2d + uzp2 v15.2d, v26.2d, v27.2d + Int32ToFloat v0, v1, v2, v3 + Int32ToFloat v4, v5, v6, v7 + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 + +Tile8Quan: + ld1 {v20.4s, v21.4s}, [x19], #32 // scale + ld1 {v22.4s, v23.4s}, [x27] // x kernel sum + ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint + MUL_SCALE v20, v0, v1, v4, v5 + MUL_SCALE v21, v2, v3, v6, v7 + MUL_SCALE v20, v8, v9, v12, v13 + MUL_SCALE v21, v10, v11, v14, v15 + + cbz x23, TILE8_MLA + ld1 {v27.4s, v28.4s}, [x23] + MUL_EXTRA_SCALE v27, v0, v1, v4, v5 + MUL_EXTRA_SCALE v28, v8, v9, v12, v13 + MUL_EXTRA_SCALE v27, v2, v3, v6, v7 + MUL_EXTRA_SCALE v28, v10, v11, v14, v15 + + TILE8_MLA: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v2, v22, v26, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v3, v22, v26, 1 // tile:1, oc:4-7 + + MLA_WEIGHTZERO v4, v22, v25, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v5, v22, v25, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v6, v22, v26, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v7, v22, v26, 3 // tile:3, oc:4-7 + + MLA_WEIGHTZERO v8, v23, v25, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v9, v23, v25, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v10, v23, v26, 0 // tile:4, oc:4-7 + MLA_WEIGHTZERO v11, v23, v26, 1 // tile:5, oc:4-7 + + MLA_WEIGHTZERO v12, v23, v25, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v13, v23, v25, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v14, v23, v26, 2 // tile:6, oc:4-7 + MLA_WEIGHTZERO v15, v23, v26, 3 // tile:7, oc:4-7 + + cbz x9, TILE8_ADD_DSTV + TILE8_ADD_BIAS: + ld1 {v16.4s, v17.4s}, [x20], #32 // bias + ADD_BIAS_FLOAT v0, v1, v4, v5, v16 + ADD_BIAS_FLOAT v2, v3, v6, v7, v17 + ADD_BIAS_FLOAT v8, v9, v12, v13, v16 + ADD_BIAS_FLOAT v10, v11, v14, v15, v17 + // float32->float16 + Float32ToHalf v0, v2, v1, v3, v20, v21 + Float32ToHalf v4, v6, v5, v7, v22, v23 + Float32ToHalf v8, v10, v9, v11, v24, v25 + Float32ToHalf v12, v14, v13, v15, v26, v27 + b TILE8_POST + + TILE8_ADD_DSTV: + // float32->float16 + Float32ToHalf v0, v2, v1, v3, v20, v21 + Float32ToHalf v4, v6, v5, v7, v22, v23 + Float32ToHalf v8, v10, v9, v11, v24, v25 + Float32ToHalf v12, v14, v13, v15, v26, v27 + ld1 {v0.8h, v1.8h, v2.8h, v3.8h}, [x10], #64 + ld1 {v4.8h, v5.8h, v6.8h, v7.8h}, [x10] + fadd v20.8h, v20.8h, v0.8h + fadd v21.8h, v21.8h, v1.8h + fadd v22.8h, v22.8h, v2.8h + fadd v23.8h, v23.8h, v3.8h + fadd v24.8h, v24.8h, v4.8h + fadd v25.8h, v25.8h, v5.8h + fadd v26.8h, v26.8h, v6.8h + fadd v27.8h, v27.8h, v7.8h + + TILE8_POST: + cbz x14, TILE8_STORE + ld1r {v29.8h}, [x14], #2 // f32 min + ld1r {v28.8h}, [x14] // f32 max + sub x14, x14, #2 + ReLU_FP16 v20, v21, v22, v23, v29, v28 + ReLU_FP16 v24, v25, v26, v27, v29, v28 + + TILE8_STORE: + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x26], #64 + st1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x26], x4 + 
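+// x24 counts the remaining dst_depth_quad blocks for the 8-tile path; loop back while any are left.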
+Tile8LoopCheck: + cmp x24, #1 + bge LoopDz_TILE_8 +cbz x23, Tile8End +add x23, x23, #32 +Tile8End: + sub x7, x7, #8 + add x0, x0, x21, LSL #3 + add x1, x1, #64 + add x27, x27, #32 + add x4, x4, #64 // Revert x4 for following tiles + +TILE_4: + cmp x7, #4 + blt TILE_2 + mov x24, x5 // dst_depth_quad + mov x26, x0 // dst + mov x25, x2 // weight + mov x19, x8 // scale + mov x20, x9 // bias + mov x6, x28 // weightQuanBias +LoopDz_TILE_4: + //ld1 {v0.4s, v1.4s}, [x20], #32 // bias + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + mov x10, x26 // tag dst + + SET_0_2 v12, v16 // oc:0,1,0,1 + SET_0_2 v13, v17 // oc:2,3,2,3 + SET_0_2 v14, v18 // oc:4,5,4,5 + SET_0_2 v15, v19 // oc:6,7,6,7 +LoopSz_TILE_4: + ld1 {v2.16b, v3.16b}, [x12], #32 // weight + ld1 {v4.16b, v5.16b}, [x11], x22 // src + // int4->int8 + ushr v8.16b, v2.16b, #4 + ushr v9.16b, v3.16b, #4 + and v10.16b, v2.16b, v30.16b + and v11.16b, v3.16b, v30.16b + subs x13, x13, #1 + .inst 0x4e88a48c // smmla v12.4s, v4.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a48d // smmla v13.4s, v4.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + .inst 0x4e8aa48e // smmla v14.4s, v4.16b, v10.16b // tile0-oc4, tile0-oc5, tile1-oc4, tile1-oc5 + .inst 0x4e8ba48f // smmla v15.4s, v4.16b, v11.16b // tile0-oc6, tile0-oc7, tile1-oc6, tile1-oc7 + + .inst 0x4e88a4b0 // smmla v16.4s, v5.16b, v8.16b // tile2-oc0, tile2-oc1, tile3-oc0, tile3-oc1 + .inst 0x4e89a4b1 // smmla v17.4s, v5.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 + .inst 0x4e8aa4b2 // smmla v18.4s, v5.16b, v10.16b // tile2-oc4, tile2-oc5, tile3-oc4, tile3-oc5 + .inst 0x4e8ba4b3 // smmla v19.4s, v5.16b, v11.16b // tile2-oc6, tile2-oc7, tile3-oc6, tile3-oc7 + + bne LoopSz_TILE_4 +LoopSzEnd_TILE_4: + add x25, x25, x15 + sub x24, x24, #1 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v2.2d, v14.2d, v15.2d // E0: oc:4-7 + uzp2 v3.2d, v14.2d, v15.2d // E1: oc:4-7 + + uzp1 v4.2d, v16.2d, v17.2d + uzp2 v5.2d, v16.2d, v17.2d + uzp1 v6.2d, v18.2d, v19.2d + uzp2 v7.2d, v18.2d, v19.2d + Int32ToFloat v0, v1, v2, v3 + Int32ToFloat v4, v5, v6, v7 + +Tile4Quan: + ld1 {v20.4s, v21.4s}, [x19], #32 // scale + ld1 {v22.4s}, [x27] // x kernel sum + ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint + MUL_SCALE v20, v0, v1, v4, v5 + MUL_SCALE v21, v2, v3, v6, v7 + + cbz x23, TILE4_MLA + ld1 {v27.4s}, [x23] + MUL_EXTRA_SCALE v27, v0, v1, v4, v5 + MUL_EXTRA_SCALE v27, v2, v3, v6, v7 + + TILE4_MLA: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v2, v22, v26, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v3, v22, v26, 1 // tile:1, oc:4-7 + + MLA_WEIGHTZERO v4, v22, v25, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v5, v22, v25, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v6, v22, v26, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v7, v22, v26, 3 // tile:3, oc:4-7 + + cbz x9, TILE4_ADD_DSTV + TILE4_ADD_BIAS: + ld1 {v16.4s, v17.4s}, [x20], #32 // bias + ADD_BIAS_FLOAT v0, v1, v4, v5, v16 + ADD_BIAS_FLOAT v2, v3, v6, v7, v17 + // float32->float16 + Float32ToHalf v0, v2, v1, v3, v20, v21 + Float32ToHalf v4, v6, v5, v7, v22, v23 + b TILE4_POST + + TILE4_ADD_DSTV: + // float32->float16 + Float32ToHalf v0, v2, v1, v3, v20, v21 + Float32ToHalf v4, v6, v5, v7, v22, v23 + ld1 {v24.8h, v25.8h, v26.8h, v27.8h}, [x10] + fadd v20.8h, v20.8h, v24.8h + fadd v21.8h, v21.8h, v25.8h + fadd v22.8h, v22.8h, v26.8h + fadd v23.8h, v23.8h, v27.8h + + TILE4_POST: + cbz x14, 
TILE4_STORE + ld1r {v29.8h}, [x14], #2 // f32 min + ld1r {v28.8h}, [x14] // f32 max + sub x14, x14, #2 + ReLU_FP16 v20, v21, v22, v23, v29, v28 + + TILE4_STORE: + st1 {v20.8h, v21.8h, v22.8h, v23.8h}, [x26], x4 + +Tile4LoopCheck: + cmp x24, #1 + bge LoopDz_TILE_4 +cbz x23, Tile4End +add x23, x23, #16 +Tile4End: + sub x7, x7, #4 + add x0, x0, x21, LSL #2 + add x1, x1, #32 + add x27, x27, #16 + //b TILE_4 + +TILE_2: + cmp x7, #2 + blt TILE_1 + mov x24, x5 // dst_depth_quad + mov x26, x0 // dst + mov x25, x2 // weight + mov x19, x8 // scale + mov x20, x9 // bias + mov x6, x28 // weightQuanBias +LoopDz_TILE_2: + //ld1 {v0.4s, v1.4s}, [x20], #32 // bias + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + mov x10, x26 // tag dst + + // v12 oc:0,1,0,1 + // v13 oc:2,3,2,3 + // v14 oc:4,5,4,5 + // v15 oc:6,7,6,7 + SET_0_4 v12, v13, v14, v15 +LoopSz_TILE_2: + ld1 {v2.16b, v3.16b}, [x12], #32 // weight + ld1 {v4.16b}, [x11], x22 // src + // int4->int8 + ushr v8.16b, v2.16b, #4 + ushr v9.16b, v3.16b, #4 + and v10.16b, v2.16b, v30.16b + and v11.16b, v3.16b, v30.16b + + .inst 0x4e88a48c // smmla v12.4s, v4.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a48d // smmla v13.4s, v4.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + .inst 0x4e8aa48e // smmla v14.4s, v4.16b, v10.16b // tile0-oc4, tile0-oc5, tile1-oc4, tile1-oc5 + .inst 0x4e8ba48f // smmla v15.4s, v4.16b, v11.16b // tile0-oc6, tile0-oc7, tile1-oc6, tile1-oc7 + subs x13, x13, #1 + bne LoopSz_TILE_2 +LoopSzEnd_TILE_2: + add x25, x25, x15 + sub x24, x24, #1 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v2.2d, v14.2d, v15.2d // E0: oc:4-7 + uzp2 v3.2d, v14.2d, v15.2d // E1: oc:4-7 + Int32ToFloat v0, v1, v2, v3 + +Tile2Quan: + ld1 {v20.4s, v21.4s}, [x19], #32 // scale + ld1 {v22.d}[0], [x27] // x kernel sum + ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint + fmul v0.4s, v0.4s, v20.4s + fmul v1.4s, v1.4s, v20.4s + fmul v2.4s, v2.4s, v21.4s + fmul v3.4s, v3.4s, v21.4s + + cbz x23, TILE2_MLA + ld1 {v27.d}[0], [x23] + fmul v0.4s, v0.4s, v27.s[0] + fmul v1.4s, v1.4s, v27.s[1] + fmul v2.4s, v2.4s, v27.s[0] + fmul v3.4s, v3.4s, v27.s[1] + + TILE2_MLA: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v2, v22, v26, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v3, v22, v26, 1 // tile:1, oc:4-7 + + cbz x9, TILE2_ADD_DSTV + TILE2_ADD_BIAS: + ld1 {v16.4s, v17.4s}, [x20], #32 // bias + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v16.4s + fadd v2.4s, v2.4s, v17.4s + fadd v3.4s, v3.4s, v17.4s + // float32->float16 + Float32ToHalf v0, v2, v1, v3, v20, v21 + b TILE2_POST + + TILE2_ADD_DSTV: + Float32ToHalf v0, v2, v1, v3, v20, v21 + ld1 {v24.8h, v25.8h}, [x10] + fadd v20.8h, v20.8h, v24.8h + fadd v21.8h, v21.8h, v25.8h + + TILE2_POST: + cbz x14, TILE2_STORE + ld1r {v29.8h}, [x14], #2 // f32 min + ld1r {v28.8h}, [x14] // f32 max + sub x14, x14, #2 + fmax v20.8h, v20.8h, v29.8h + fmax v21.8h, v21.8h, v29.8h + fmin v20.8h, v20.8h, v28.8h + fmin v21.8h, v21.8h, v28.8h + + TILE2_STORE: + st1 {v20.8h, v21.8h}, [x26], x4 + +Tile2LoopCheck: + cmp x24, #1 + bge LoopDz_TILE_2 +cbz x23, Tile2End +add x23, x23, #8 +Tile2End: + sub x7, x7, #2 + add x0, x0, x21, LSL #1 + add x1, x1, #16 + add x27, x27, #8 + //b TILE_2 + +TILE_1: + + cmp x7, #1 + blt End + mov x24, x5 // dst_depth_quad + mov x26, x0 // dst + mov x25, x2 // weight + mov x19, x8 // scale + mov x20, x9 // bias + mov x6, x28 // 
weightQuanBias +LoopDz_TILE_1: + //ld1 {v7.4s, v8.4s}, [x20], #32 // bias + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + mov x10, x26 + + //dup v16.2d, v7.d[0] // oc:0,1,0,1 + //dup v17.2d, v7.d[1] // oc:2,3,2,3 + //dup v18.2d, v8.d[0] // oc:4,5,4,5 + //dup v19.2d, v8.d[1] // oc:6,7,6,7 + movi v16.4s, #0 // oc:0,1,0,1 + movi v17.4s, #0 // oc:2,3,2,3 + movi v18.4s, #0 // oc:4,5,4,5 + movi v19.4s, #0 // oc:6,7,6,7 + + //movi v22.4s, #0 // oc:0,1,0,1 + //movi v23.4s, #0 // oc:2,3,2,3 + //movi v24.4s, #0 // oc:4,5,4,5 + //movi v25.4s, #0 // oc:6,7,6,7 + +LoopSz1_TILE_1: + // src : 1 x [1 x 8] : v2 + // weight : 2 x [2 x 8] : v0-1 + // dst : 1 x 2 x [2] : v30-v31 + ld1 {v13.16b, v14.16b}, [x12], #32 // weight + ld1 {v2.8b}, [x11], x22 // src + // int4->int8 + ushr v0.16b, v13.16b, #4 + and v3.16b, v13.16b, v30.16b + ushr v1.16b, v14.16b, #4 + and v4.16b, v14.16b, v30.16b + + .inst 0x4e80a450 // smmla v16.4s, v2.16b, v0.16b + .inst 0x4e81a451 // smmla v17.4s, v2.16b, v1.16b + .inst 0x4e83a452 // smmla v18.4s, v2.16b, v3.16b + .inst 0x4e84a453 // smmla v19.4s, v2.16b, v4.16b + subs x13, x13, #1 + bne LoopSz1_TILE_1 + + LoopSz_TILE_1_ADD: + //add v16.4s, v16.4s, v22.4s + //add v17.4s, v17.4s, v23.4s + //add v18.4s, v18.4s, v24.4s + //add v19.4s, v19.4s, v25.4s + +LoopSzEnd_TILE_1: + add x25, x25, x15 + sub x24, x24, #1 + uzp1 v27.2d, v16.2d, v17.2d + uzp1 v26.2d, v18.2d, v19.2d + scvtf v27.4s, v27.4s + scvtf v26.4s, v26.4s + +Tile1Quan: + ld1 {v0.4s, v1.4s}, [x19], #32 // scale + ld1 {v6.s}[0], [x27] // x kernel sum + ld1 {v8.4s, v9.4s}, [x6], #32 // weight quan zeropoint + fmul v27.4s, v27.4s, v0.4s + fmul v26.4s, v26.4s, v1.4s + + cbz x23, TILE1_MLA + ld1 {v4.s}[0], [x23] + fmul v27.4s, v27.4s, v4.s[0] + fmul v26.4s, v26.4s, v4.s[0] + TILE1_MLA: + MLA_WEIGHTZERO v27, v6, v8, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v26, v6, v9, 0 // tile:0, oc:4-7 + + cbz x9, TILE1_ADD_DSTV + TILE1_ADD_BIAS: + ld1 {v16.4s, v17.4s}, [x20], #32 // bias + fadd v27.4s, v27.4s, v16.4s + fadd v26.4s, v26.4s, v17.4s + fcvtn v0.4h, v27.4s + fcvtn2 v0.8h, v26.4s + b TILE1_POST + + TILE1_ADD_DSTV: + fcvtn v0.4h, v27.4s + fcvtn2 v0.8h, v26.4s + ld1 {v24.8h}, [x10] + fadd v0.8h, v0.8h, v24.8h + + TILE1_POST: + cbz x14, TILE1_STORE + ld1r {v29.8h}, [x14], #2 // f32 min + ld1r {v28.8h}, [x14] // f32 max + sub x14, x14, #2 + fmax v0.8h, v0.8h, v29.8h + fmin v0.8h, v0.8h, v28.8h + TILE1_STORE: + st1 {v0.8h}, [x26], x4 + +Tile1LoopEnd: + cmp x24, #1 + bge LoopDz_TILE_1 + +End: +ldp x27, x28, [sp, #(16 * 8)] +ldp x25, x26, [sp, #(16 * 7)] +ldp x23, x24, [sp, #(16 * 6)] +ldp x19, x20, [sp, #(16 * 5)] +ldp x21, x22, [sp, #(16 * 4)] +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 10) +ret + +#endif // __aarch64__ diff --git a/source/backend/arm82/asm/arm64/low_memory/MNNQuantScaleFP16.S b/source/backend/arm82/asm/arm64/low_memory/MNNQuantScaleFP16.S index b1e4d8ad0..3c2358402 100644 --- a/source/backend/arm82/asm/arm64/low_memory/MNNQuantScaleFP16.S +++ b/source/backend/arm82/asm/arm64/low_memory/MNNQuantScaleFP16.S @@ -29,10 +29,10 @@ stp d10, d11, [sp, #(16 * 2)] stp d8, d9, [sp, #(16 * 3)] Start: -mov w8, #1123942400 // 127.0 -dup v0.4s, w8 -fcvtn v31.4h, v0.4s -fcvtn2 v31.8h, v0.4s +movi v31.4s, #127 +scvtf v31.4s, v31.4s +//fcvtn v31.4h, v0.4s +//fcvtn2 v31.8h, v0.4s lsl x9, x4, #1 // src_step = batch * sizeof(float16_t) TILE_12: @@ -64,14 +64,32 @@ sub x4, x4, #12 add x0, x0, #24 // quant_scale = 127 / absmax // dequant_scale = 
absmax / 127 -fdiv v4.8h, v31.8h, v0.8h -fdiv v5.8h, v31.8h, v1.8h -fdiv v6.8h, v0.8h, v31.8h -fdiv v7.8h, v1.8h, v31.8h -st1 {v4.8h}, [x1], #16 -st1 {v5.d}[0], [x1], #8 -st1 {v6.8h}, [x2], #16 -st1 {v7.d}[0], [x2], #8 + +// float16->float32 +fcvtl v4.4s, v0.4h +fcvtl2 v5.4s, v0.8h +fcvtl v6.4s, v1.4h + +fdiv v8.4s, v31.4s, v4.4s +fdiv v9.4s, v31.4s, v5.4s +fdiv v10.4s, v31.4s, v6.4s + +fdiv v12.4s, v4.4s, v31.4s +fdiv v13.4s, v5.4s, v31.4s +fdiv v14.4s, v6.4s, v31.4s + +st1 {v8.4s, v9.4s, v10.4s}, [x1], #48 +st1 {v12.4s, v13.4s, v14.4s}, [x2], #48 + +//fdiv v4.8h, v31.8h, v0.8h +//fdiv v5.8h, v31.8h, v1.8h +//fdiv v6.8h, v0.8h, v31.8h +//fdiv v7.8h, v1.8h, v31.8h + +//st1 {v4.8h}, [x1], #16 +//st1 {v5.d}[0], [x1], #8 +//st1 {v6.8h}, [x2], #16 +//st1 {v7.d}[0], [x2], #8 b TILE_12 TILE_10: @@ -103,14 +121,33 @@ sub x4, x4, #10 add x0, x0, #20 // quant_scale = 127 / absmax // dequant_scale = absmax / 127 -fdiv v4.8h, v31.8h, v0.8h -fdiv v5.8h, v31.8h, v1.8h -fdiv v6.8h, v0.8h, v31.8h -fdiv v7.8h, v1.8h, v31.8h -st1 {v4.8h}, [x1], #16 -st1 {v5.s}[0], [x1], #4 -st1 {v6.8h}, [x2], #16 -st1 {v7.s}[0], [x2], #4 + +// float16->float32 +fcvtl v4.4s, v0.4h +fcvtl2 v5.4s, v0.8h +fcvtl v6.4s, v1.4h + +fdiv v8.4s, v31.4s, v4.4s +fdiv v9.4s, v31.4s, v5.4s +fdiv v10.4s, v31.4s, v6.4s + +fdiv v12.4s, v4.4s, v31.4s +fdiv v13.4s, v5.4s, v31.4s +fdiv v14.4s, v6.4s, v31.4s + +st1 {v8.4s, v9.4s}, [x1], #32 +st1 {v10.d}[0], [x1], #8 +st1 {v12.4s, v13.4s}, [x2], #32 +st1 {v14.d}[0], [x2], #8 + +// fdiv v4.8h, v31.8h, v0.8h +// fdiv v5.8h, v31.8h, v1.8h +// fdiv v6.8h, v0.8h, v31.8h +// fdiv v7.8h, v1.8h, v31.8h +// st1 {v4.8h}, [x1], #16 +// st1 {v5.s}[0], [x1], #4 +// st1 {v6.8h}, [x2], #16 +// st1 {v7.s}[0], [x2], #4 b TILE_10 @@ -139,10 +176,23 @@ sub x4, x4, #8 add x0, x0, #16 // quant_scale = 127 / absmax // dequant_scale = absmax / 127 -fdiv v2.8h, v31.8h, v0.8h -fdiv v3.8h, v0.8h, v31.8h -st1 {v2.8h}, [x1], #16 -st1 {v3.8h}, [x2], #16 +// float16->float32 +fcvtl v4.4s, v0.4h +fcvtl2 v5.4s, v0.8h + +fdiv v8.4s, v31.4s, v4.4s +fdiv v9.4s, v31.4s, v5.4s + +fdiv v12.4s, v4.4s, v31.4s +fdiv v13.4s, v5.4s, v31.4s + +st1 {v8.4s, v9.4s}, [x1], #32 +st1 {v12.4s, v13.4s}, [x2], #32 + +// fdiv v2.8h, v31.8h, v0.8h +// fdiv v3.8h, v0.8h, v31.8h +// st1 {v2.8h}, [x1], #16 +// st1 {v3.8h}, [x2], #16 b TILE_8 @@ -171,10 +221,18 @@ sub x4, x4, #1 add x0, x0, #2 // quant_scale = 127 / absmax // dequant_scale = absmax / 127 -fdiv h2, h31, h0 -fdiv h3, h0, h31 -st1 {v2.h}[0], [x1], #2 -st1 {v3.h}[0], [x2], #2 +fcvtl v4.4s, v0.4h + +fdiv v8.4s, v31.4s, v4.4s +fdiv v12.4s, v4.4s, v31.4s + +st1 {v8.s}[0], [x1], #4 +st1 {v12.s}[0], [x2], #4 + +// fdiv h2, h31, h0 +// fdiv h3, h0, h31 +// st1 {v2.h}[0], [x1], #2 +// st1 {v3.h}[0], [x2], #2 b TILE_1 diff --git a/source/backend/cpu/CMakeLists.txt b/source/backend/cpu/CMakeLists.txt index 22aeb1ef4..41426c66c 100644 --- a/source/backend/cpu/CMakeLists.txt +++ b/source/backend/cpu/CMakeLists.txt @@ -14,7 +14,6 @@ if (MNN_SUPPORT_BF16) endif() list(APPEND MNN_OBJECTS_TO_LINK $) list(APPEND MNN_TARGETS MNNCPU) -option(MNN_SSE_USE_FP16_INSTEAD "Use fp16 instead of bf16 for x86op" OFF) if(MNN_USE_SPARSE_COMPUTE) diff --git a/source/backend/cpu/CPUAttention.cpp b/source/backend/cpu/CPUAttention.cpp index a71472d1f..a420f2d0d 100644 --- a/source/backend/cpu/CPUAttention.cpp +++ b/source/backend/cpu/CPUAttention.cpp @@ -27,198 +27,366 @@ // reduce the value of 'query' to 'query * FP16_QSCALE', avoid fp16 overflow #define FP16_QSCALE 0.5 +#define FP8_E5M2 namespace MNN { +#if defined 
FP8_E5M2 // E5M2 : [S E E E E E M M] +typedef uint8_t fp8_t; +static inline fp8_t fp16_to_fp8(FLOAT16_T x) { + return *((fp8_t *)(&x) + 1); +} +static FLOAT16_T fp8_to_fp16(fp8_t x) { + uint16_t rawData = 0; + rawData |= (uint16_t)x << 8; + return *((FLOAT16_T *)(&rawData)); +} +static inline fp8_t float_to_fp8(float x) { + uint32_t rawData = *((uint32_t *)(&x)); + int sign = (rawData >> 31) & 1U; + int exp = (int)((rawData >> 23) & 0x0ffU) - 127; + if (exp < -16) + exp = -16; + if (exp > 15) + exp = 15; + exp += 16; // exp [-16, 15] ==> [0, 31] + int mant = (rawData >> 21) & 3U; + return (sign << 7) | (exp << 2) | mant; +} +static inline float fp8_to_float(fp8_t x) { + uint32_t sign = (x >> 7) & 1U; + uint32_t exp = (int)((x >> 2) & 0x1fU) - 16 + 127; + uint32_t mant = (x & 3U) << 21; + uint32_t rawData = (sign << 31) | (exp << 23) | mant; + return *((float *)(&rawData)); +} +#elif defined FP8_E4M3 // E4M3: [S E E E E M M M] +typedef uint8_t fp8_t; +static inline fp8_t fp16_to_fp8(FLOAT16_T x) { + uint16_t rawData = *((uint16_t *)(&x)); + int sign = (rawData >> 15) & 1U; + int exp = (int)((rawData >> 10) & 0x1fU) - 15; + if (exp < -8) + exp = -8; + if (exp > 7) + exp = 7; + exp += 8; // exp [-8, 7] ==> [0, 15] + int mant = (rawData >> 7) & 7U; + return (sign << 7) | (exp << 3) | mant; +} +static FLOAT16_T fp8_to_fp16(fp8_t x) { + uint32_t sign = (x >> 7) & 1U; + uint32_t exp = (int)((x >> 3) & 0x0fU) - 8 + 15; + uint32_t mant = (x & 7U) << 7; + uint16_t rawData = (sign << 15) | (exp << 10) | mant; + return *((FLOAT16_T *)(&rawData)); +} +static inline fp8_t float_to_fp8(float x) { + uint32_t rawData = *((uint32_t *)(&x)); + int sign = (rawData >> 31) & 1U; + int exp = (int)((rawData >> 23) & 0x0ffU) - 127; + if (exp < -8) + exp = -8; + if (exp > 7) + exp = 7; + exp += 8; // exp [-8, 7] ==> [0, 15] + int mant = (rawData >> 20) & 7U; + return (sign << 7) | (exp << 3) | mant; +} +static inline float fp8_to_float(fp8_t x) { + uint32_t sign = (x >> 7) & 1U; + uint32_t exp = (int)((x >> 3) & 0x0fU) - 8 + 127; + uint32_t mant = (x & 7U) << 20; + uint32_t rawData = (sign << 31) | (exp<< 23) | mant; + return *((float *)(&rawData)); +} +#else +// Do not support fp8 +#endif // fp8 format definition + +static int nearestInt(float x) { + return x < 0 ? 
-nearestInt(-x) : (int)(x + 0.5f); +} + template -static void prefill_pack(Tensor* query, Tensor* key, Tensor* value, char* query_ptr, char* key_ptr, char* value_ptr, - int mMaxLength, int mNumHead, int mKvNumHead, int mHeadDim, int mValueH, - int eP, int hP, int query_e, int key_h, int seq_len, int h, int kv_h, float q_scale) { - auto query_src = query->host(); - auto key_src = key->host(); - auto value_src = value->host(); - auto query_dst = reinterpret_cast(query_ptr); - auto key_dst = reinterpret_cast(key_ptr); - auto value_dst = reinterpret_cast(value_ptr); - // transpose query: [seq_len, num_head, head_dim] -> numhead, [seq_len/eP, head_dim, eP] - for (int i = 0; i < query_e; i++) { +static void pack_query(Tensor* query, char* pack_q, int mNumHead, int mHeadDim, int eP, int seq_len, int h, float q_scale) { + T * query_src = query->host(); + T * query_dst = reinterpret_cast(pack_q); + for (int i = 0; i < seq_len; i++) { + int out_index = i / eP; + int in_index = i % eP; for (int j = 0; j < mHeadDim; j++) { - for (int k = 0; k < eP; k++) { - int s = i * eP + k; - if (s < seq_len) { - query_dst[i * mHeadDim * eP + j * eP + k] = query_src[s * mNumHead * mHeadDim + h * mHeadDim + j] * q_scale; - } + query_dst[out_index * mHeadDim * eP + j * eP + in_index] = query_src[i * mNumHead * mHeadDim + h * mHeadDim + j] * q_scale; + } + } +} + +template +static void pack_key(Tensor* key, char* pack_key, int mPastLength, int seq_len, int mKvNumHead, int mHeadDim, int hP, int kv_h, char* scale, char* zero_point, bool quant) { + if (quant) { // Quantize the keys + auto key_src = key->host(); + auto key_dst = reinterpret_cast(pack_key); + auto scale_dst = reinterpret_cast(scale); + auto zeroPoint_dst = reinterpret_cast(zero_point); + for (int i = 0; i < seq_len; i++) { + float minKey = key_src[i * mKvNumHead * mHeadDim + kv_h * mHeadDim + 0]; + float maxKey = key_src[i * mKvNumHead * mHeadDim + kv_h * mHeadDim + 0]; + for (int j = 1; j < mHeadDim; j++) { + auto key = key_src[i * mKvNumHead * mHeadDim + kv_h * mHeadDim + j]; + minKey = ALIMIN(minKey, key); + maxKey = ALIMAX(maxKey, key); + } + int out_index = (mPastLength + i) / hP; + int in_index = (mPastLength + i) % hP; + scale_dst[out_index * hP + in_index] = (maxKey - minKey) / 255.0f; + zeroPoint_dst[out_index * hP + in_index] = 128.0f * (maxKey - minKey) / 255.0f + minKey; + for (int j = 0; j < mHeadDim; j++) { + key_dst[out_index * mHeadDim * hP + j * hP + in_index] = nearestInt((key_src[i * mKvNumHead * mHeadDim + kv_h * mHeadDim + j] - minKey) / (maxKey - minKey) * 255 - 128); } } } - // transpose key: [seq_len, num_head, head_dim] -> numhead, [seq_len/hP, head_dim, hP] - for (int i = 0; i < key_h; i++) { - for (int j = 0; j < mHeadDim; j++) { + else { // Do not quantize the keys + auto key_src = key->host(); + auto key_dst = reinterpret_cast(pack_key); + for (int i = 0; i < seq_len; i++) { + int out_index = (mPastLength + i) / hP; + int in_index = (mPastLength + i) % hP; + for (int j = 0; j < mHeadDim; j++) { + key_dst[out_index * mHeadDim * hP + j * hP + in_index] = key_src[i * mKvNumHead * mHeadDim + kv_h * mHeadDim + j]; + } + } + } +} + + + +template +static void pack_value(Tensor* value, char* pack_value, int mMaxLength, int mPastLength, int seq_len, int mKvNumHead, int mHeadDim, int hP, int kv_h, bool quant) { + if (quant) { // Quantize the values to fp8 + T * value_src = value->host(); + fp8_t * value_dst = reinterpret_cast(pack_value); + for (int i = 0; i < seq_len; i++) { + for (int j = 0; j < mHeadDim; j++) { + int out_index = j / 
hP; + int in_index = j % hP; + auto origin = value_src[i * mKvNumHead * mHeadDim + kv_h * mHeadDim + j]; + if (sizeof(T) == 2) + value_dst[out_index * mMaxLength * hP + (mPastLength + i) * hP + in_index] = fp16_to_fp8(origin); + else + value_dst[out_index * mMaxLength * hP + (mPastLength + i) * hP + in_index] = float_to_fp8(origin); + } + } + } + else { // Do not quantize the values + T * value_src = value->host(); + T * value_dst = reinterpret_cast(pack_value); + for (int i = 0; i < seq_len; i++) { + for (int j = 0; j < mHeadDim; j++) { + int out_index = j / hP; + int in_index = j % hP; + value_dst[out_index * mMaxLength * hP + (mPastLength + i) * hP + in_index] = value_src[i * mKvNumHead * mHeadDim + kv_h * mHeadDim + j]; + } + } + } +} + +void dequant_value_float(char * dst, char * src, int mHeadDim, int kv_seq_len, int hP, int mMaxLength) { + fp8_t * qv = (fp8_t *)src; + float * dqv = (float *)dst; + for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { + for (int j = 0; j < kv_seq_len; j++) { for (int k = 0; k < hP; k++) { - int s = i * hP + k; - if (s < seq_len) { - key_dst[i * mHeadDim * hP + j * hP + k] = key_src[s * mKvNumHead * mHeadDim + kv_h * mHeadDim + j]; - } + dqv[i * kv_seq_len * hP + j * hP + k] = fp8_to_float(qv[i * mMaxLength * hP + j * hP + k]); } } } - // transpose value: [seq_len, num_head, head_dim] -> numhead, [head_dim/hP, seq_len, hP] - for (int i = 0; i < mValueH; i++) { - for (int j = 0; j < seq_len; j++) { +} + +void dequant_value_fp16(char * dst, char * src, int mHeadDim, int kv_seq_len, int hP, int mMaxLength) { + fp8_t * qv = (fp8_t *)src; + FLOAT16_T * dqv = (FLOAT16_T *)dst; + for (int i = 0; i < UP_DIV(mHeadDim, hP); i++) { + for (int j = 0; j < kv_seq_len; j++) { for (int k = 0; k < hP; k++) { - int hd = i * hP + k; - if (hd < mHeadDim) { - value_dst[i * mMaxLength * hP + j * hP + k] = value_src[j * mKvNumHead * mHeadDim + kv_h * mHeadDim + hd]; - } + dqv[i * kv_seq_len * hP + j * hP + k] = fp8_to_fp16(qv[i * mMaxLength * hP + j * hP + k]); } } } } template -static void decode_pack(Tensor* query, Tensor* key, Tensor* value, char* query_ptr, char* key_ptr, char* value_ptr, - int mMaxLength, int mPastLength, int mHeadDim, int mValueH, int eP, int hP, int h, int kv_h, float q_scale) { - auto query_src = query->host(); - auto key_src = key->host(); - auto value_src = value->host(); - auto query_dst = reinterpret_cast(query_ptr); - auto key_dst = reinterpret_cast(key_ptr); - auto value_dst = reinterpret_cast(value_ptr); - for (int i = 0; i < mHeadDim; i++) { - query_dst[i * eP] = query_src[h * mHeadDim + i] * q_scale; - } - // transpose key: [1, num_head, head_dim] -> numhead, [kv_seq_len/hP, head_dim, hP] - int outside_offset = UP_DIV(mPastLength, hP); - int inside_offset = mPastLength % hP; - for (int i = 0; i < mHeadDim; i++) { - key_dst[(outside_offset - (inside_offset != 0)) * mHeadDim * hP + i * hP + inside_offset] = key_src[kv_h * mHeadDim + i]; - } - // transpose value: [1, num_head, head_dim] -> numhead, [head_dim/hP, kv_seq_len, hP] - for (int i = 0; i < mValueH; i++) { - for (int j = 0; j < hP; j++) { - value_dst[i * mMaxLength * hP + mPastLength * hP + j] = value_src[kv_h * mHeadDim + i * hP + j]; +static void unpack_QK(float * unpack_qk_dst, char * pack_qk_src, int seq_len, int kv_seq_len, int unit) { + float * dst = unpack_qk_dst; + T * src = (T *)(pack_qk_src); + // [kv_seq_len/unit, seq_len, unit] -> [seq_len, kv_seq_len] + for (int i = 0; i < seq_len; i++) { + for (int j = 0; j < kv_seq_len; j++) { + int out_index = j / unit; + int in_index = j % 
unit; + dst[i * kv_seq_len + j] = src[out_index * seq_len * unit + i * unit + in_index]; } } } template -static void prefill_unpack(char* pack_qkv, char* unpack_qkv, int mNumHead, int mHeadDim, int unit, int seq_len) { - auto src_ptr = reinterpret_cast(pack_qkv); - auto dst_ptr = reinterpret_cast(unpack_qkv); +static void pack_QK(char * pack_qk_dst, float * qk_src, int seq_len, int kv_seq_len, int eP) { + T * dst = reinterpret_cast(pack_qk_dst); + float * src = reinterpret_cast(qk_src); + // [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len, eP] for (int i = 0; i < seq_len; i++) { - for (int j = 0; j < mHeadDim; j++) { - int a = j / unit; - int b = j % unit; - dst_ptr[i * mNumHead * mHeadDim + j] = src_ptr[a * seq_len * unit + i * unit + b]; + int out_index = i / eP; + int in_index = i % eP; + for (int j = 0; j < kv_seq_len; j++) { + dst[out_index * kv_seq_len * eP + j * eP + in_index] = src[i * kv_seq_len + j]; } } } template -static void prefill_softmax(int* mask_ptr, float* mask_qk, float* softmax_qk, char* unpack_qk, char* pack_qk, - float mScale, int eP, int query_e, int seq_len, float min_val, bool float_mask) { - T* qk_src = reinterpret_cast(unpack_qk); - T* qk_dst = reinterpret_cast(pack_qk); - if (float_mask) { - T* fpmask_ptr = reinterpret_cast(mask_ptr); +static void mask_QK(float * unpack_qk, int seq_len, int kv_seq_len, float mScale, float min_val, int * mask_ptr, bool float_mask) { + if (seq_len == 1) { + for (int i = 0; i < kv_seq_len; i++) { + unpack_qk[i] = unpack_qk[i] * mScale; + } + } else if (float_mask) { // float mask - for (int i = 0; i < seq_len * seq_len; i++) { - mask_qk[i] = qk_src[i] * mScale + fpmask_ptr[i]; + T* fpmask_ptr = reinterpret_cast(mask_ptr); + for (int i = 0; i < seq_len * kv_seq_len; i++) { + unpack_qk[i] = unpack_qk[i] * mScale + fpmask_ptr[i]; } } else { // int mask - for (int i = 0; i < seq_len * seq_len; i++) { + for (int i = 0; i < seq_len * kv_seq_len; i++) { if (mask_ptr[i]) { - mask_qk[i] = qk_src[i] * mScale; + unpack_qk[i] = unpack_qk[i] * mScale; } else { - mask_qk[i] = min_val; + unpack_qk[i] = min_val; } } } - for (int i = 0; i < seq_len; i++) { - MNNSoftmax(softmax_qk + i * seq_len, mask_qk + i * seq_len, seq_len); - } - for (int i = 0; i < query_e; i++) { - for (int j = 0; j < seq_len; j++) { - for (int k = 0; k < eP; k++) { - int s = i * eP + k; - if (s < seq_len) { - qk_dst[i * seq_len * eP + j * eP + k] = softmax_qk[s * seq_len + j]; - } - } - } +} + +static void softmax_QK(float* softmax_qk_addr, float* unpack_qk_addr, int seq_len, int kv_seq_len) { + for (int i = 0; i < seq_len; i++) { // softmax each row + MNNSoftmax(softmax_qk_addr + i * kv_seq_len, unpack_qk_addr + i * kv_seq_len, kv_seq_len); } } template -static void decode_softmax(float* mask_qk, float* softmax_qk, char* unpack_qk, char* pack_qk, - float mScale, int eP, int kv_seq_len) { - T* qk_src = reinterpret_cast(unpack_qk); - T* qk_dst = reinterpret_cast(pack_qk); - for (int i = 0; i < kv_seq_len; i++) { - mask_qk[i] = qk_src[i] * mScale; - } - // softmax - MNNSoftmax(softmax_qk, mask_qk, kv_seq_len); - // pack qk - for (int i = 0; i < kv_seq_len; i++) { - qk_dst[i * eP] = softmax_qk[i]; +static void unpack_QKV(char* pack_qkv, char* unpack_qkv, int mNumHead, int mHeadDim, int unit, int seq_len) { + auto src_ptr = reinterpret_cast(pack_qkv); + auto dst_ptr = reinterpret_cast(unpack_qkv); + for (int i = 0; i < seq_len; i++) { + for (int j = 0; j < mHeadDim; j++) { + int a = j / unit; + int b = j % unit; + dst_ptr[i * mNumHead * mHeadDim + j] = src_ptr[a * seq_len * 
unit + i * unit + b]; + } } } -void CPUAttention::allocKVCache() { +void CPUAttention::allocKVCache(int kv_seq_len, bool quantKey, bool quantValue) { if (!mKVCache) { return; } - mResource->mMaxLength = ROUND_UP(mResource->mPastLength, mResource->mExpandChunk); - // past_key: [1, numhead, headdim, maxlen] -> numhead, [headdim, maxlen] -> pack_b -> numhead, [maxlen/hP, head_dim, hP] - mResource->mPastKey.reset(Tensor::createDevice({mResource->mKvNumHead, UP_DIV(mResource->mMaxLength, hP), mResource->mHeadDim, hP})); - // past_value: [1, numhead, maxlen, headdim] -> numhead, [maxlen, headdim] -> pack_b -> numhead, [head_dim/hP, max_len, hP] - mResource->mPastValue.reset(Tensor::createDevice({mResource->mKvNumHead, mResource->mValueH, mResource->mMaxLength, hP})); - backend()->onAcquireBuffer(mResource->mPastKey.get(), Backend::STATIC); - backend()->onAcquireBuffer(mResource->mPastValue.get(), Backend::STATIC); + mResource->mMaxLength = kv_seq_len + mResource->mExpandChunk; + if (quantKey) { + mResource->mPastKey.reset(Tensor::createDevice({mResource->mKvNumHead, UP_DIV(mResource->mMaxLength, hP), mResource->mHeadDim, hP})); + mResource->mDequantKeyScale.reset(Tensor::createDevice({mResource->mKvNumHead, UP_DIV(mResource->mMaxLength, hP), 1, hP})); + mResource->mDequantKeyZeroPoint.reset(Tensor::createDevice({mResource->mKvNumHead, UP_DIV(mResource->mMaxLength, hP), 1, hP})); + backend()->onAcquireBuffer(mResource->mPastKey.get(), Backend::STATIC); + backend()->onAcquireBuffer(mResource->mDequantKeyScale.get(), Backend::STATIC); + backend()->onAcquireBuffer(mResource->mDequantKeyZeroPoint.get(), Backend::STATIC); + } else { + mResource->mPastKey.reset(Tensor::createDevice({mResource->mKvNumHead, UP_DIV(mResource->mMaxLength, hP), mResource->mHeadDim, hP})); + backend()->onAcquireBuffer(mResource->mPastKey.get(), Backend::STATIC); + } + if (quantValue) { + mResource->mPastValue.reset(Tensor::createDevice({mResource->mKvNumHead, UP_DIV(mResource->mHeadDim, hP), mResource->mMaxLength, hP})); + backend()->onAcquireBuffer(mResource->mPastValue.get(), Backend::STATIC); + } else { + mResource->mPastValue.reset(Tensor::createDevice({mResource->mKvNumHead, UP_DIV(mResource->mHeadDim, hP), mResource->mMaxLength, hP})); + backend()->onAcquireBuffer(mResource->mPastValue.get(), Backend::STATIC); + } } -void CPUAttention::reallocKVCache() { - if (!mKVCache || mResource->mPastLength < mResource->mMaxLength) { +void CPUAttention::reallocKVCache(int kv_seq_len, bool quantKey, bool quantValue) { + if (!mKVCache || kv_seq_len <= mResource->mMaxLength) { return; } - mResource->mMaxLength = mResource->mPastLength + mResource->mExpandChunk; - // past_key: [1, numhead, headdim, maxlen] -> numhead, [headdim, maxlen] -> pack_b -> numhead, [maxlen/hP, head_dim, hP] - auto new_key = Tensor::createDevice({mResource->mKvNumHead, UP_DIV(mResource->mMaxLength, hP), mResource->mHeadDim, hP}); - // past_value: [1, numhead, maxlen, headdim] -> numhead, [maxlen, headdim] -> pack_b -> numhead, [head_dim/hP, max_len, hP] - auto new_value = Tensor::createDevice({mResource->mKvNumHead, mResource->mValueH, mResource->mMaxLength, hP}); - backend()->onAcquireBuffer(new_key, Backend::STATIC); - backend()->onAcquireBuffer(new_value, Backend::STATIC); - // copy - for (int h = 0; h < mResource->mKvNumHead; h++) { - ::memset(new_key->host() + h * UP_DIV(mResource->mMaxLength, hP) * mResource->mHeadDim * hP * bytes, 0, UP_DIV(mResource->mMaxLength, hP) * mResource->mHeadDim * hP * bytes); - ::memset(new_value->host() + h * 
mResource->mValueH * mResource->mMaxLength * hP * bytes, 0, mResource->mValueH * mResource->mMaxLength * hP * bytes); - ::memcpy(new_key->host() + h * UP_DIV(mResource->mMaxLength, hP) * mResource->mHeadDim * hP * bytes, - mResource->mPastKey->host() + h * UP_DIV(mResource->mPastLength, hP) * mResource->mHeadDim * hP * bytes, - UP_DIV(mResource->mPastLength, hP) * mResource->mHeadDim * hP * bytes); - for (int i = 0; i < mResource->mValueH; i++) { - ::memcpy(new_value->host() + (h * mResource->mValueH + i) * mResource->mMaxLength * hP * bytes, - mResource->mPastValue->host() + (h * mResource->mValueH + i) * mResource->mPastLength * hP * bytes, - mResource->mPastLength * hP * bytes); + int oldMaxLength = mResource->mMaxLength; + mResource->mMaxLength = kv_seq_len + mResource->mExpandChunk; + if (quantKey) { + auto new_key = Tensor::createDevice({mResource->mKvNumHead, UP_DIV(mResource->mMaxLength, hP), mResource->mHeadDim, hP}); + auto new_scale = Tensor::createDevice({mResource->mKvNumHead, UP_DIV(mResource->mMaxLength, hP), 1, hP}); + auto new_zeroPoint = Tensor::createDevice({mResource->mKvNumHead, UP_DIV(mResource->mMaxLength, hP), 1, hP}); + backend()->onAcquireBuffer(new_key, Backend::STATIC); + backend()->onAcquireBuffer(new_scale, Backend::STATIC); + backend()->onAcquireBuffer(new_zeroPoint, Backend::STATIC); + for (int h = 0; h < mResource->mKvNumHead; h++) { + memcpy(new_key->host() + h * UP_DIV(mResource->mMaxLength, hP) * mResource->mHeadDim * hP, + mResource->mPastKey->host() + h * UP_DIV(oldMaxLength, hP) * mResource->mHeadDim * hP, + UP_DIV(oldMaxLength, hP) * mResource->mHeadDim * hP); + memcpy(new_scale->host() + h * UP_DIV(mResource->mMaxLength, hP) * hP * bytes, + mResource->mDequantKeyScale->host() + h * UP_DIV(oldMaxLength, hP) * hP * bytes, + UP_DIV(oldMaxLength, hP) * hP * bytes); + memcpy(new_zeroPoint->host() + h * UP_DIV(mResource->mMaxLength, hP) * hP * bytes, + mResource->mDequantKeyZeroPoint->host() + h * UP_DIV(oldMaxLength, hP) * hP * bytes, + UP_DIV(oldMaxLength, hP) * hP * bytes); } + mResource->mPastKey.reset(new_key); + mResource->mDequantKeyScale.reset(new_scale); + mResource->mDequantKeyZeroPoint.reset(new_zeroPoint); + } + else { + auto new_key = Tensor::createDevice({mResource->mKvNumHead, UP_DIV(mResource->mMaxLength, hP), mResource->mHeadDim, hP}); + backend()->onAcquireBuffer(new_key, Backend::STATIC); + for (int h = 0; h < mResource->mKvNumHead; h++) { + memcpy(new_key->host() + h * UP_DIV(mResource->mMaxLength, hP) * mResource->mHeadDim * hP * bytes, + mResource->mPastKey->host() + h * UP_DIV(oldMaxLength, hP) * mResource->mHeadDim * hP * bytes, + UP_DIV(oldMaxLength, hP) * mResource->mHeadDim * hP * bytes); + } + mResource->mPastKey.reset(new_key); + } + if (quantValue) { + auto new_value = Tensor::createDevice({mResource->mKvNumHead, UP_DIV(mResource->mHeadDim, hP), mResource->mMaxLength, hP}); + backend()->onAcquireBuffer(new_value, Backend::STATIC); + for (int h = 0; h < mResource->mKvNumHead; h++) { + for (int i = 0; i < UP_DIV(mResource->mHeadDim, hP); i++) { + memcpy(new_value->host() + (h * UP_DIV(mResource->mHeadDim, hP) + i) * mResource->mMaxLength * hP, + mResource->mPastValue->host() + (h * UP_DIV(mResource->mHeadDim, hP) + i) * oldMaxLength * hP, + oldMaxLength * hP); + } + } + mResource->mPastValue.reset(new_value); + } + else { + auto new_value = Tensor::createDevice({mResource->mKvNumHead, UP_DIV(mResource->mHeadDim, hP), mResource->mMaxLength, hP}); + backend()->onAcquireBuffer(new_value, Backend::STATIC); + for (int h = 0; h < 
mResource->mKvNumHead; h++) { + for (int i = 0; i < UP_DIV(mResource->mHeadDim, hP); i++) { + memcpy(new_value->host() + (h * UP_DIV(mResource->mHeadDim, hP) + i) * mResource->mMaxLength * hP * bytes, + mResource->mPastValue->host() + (h * UP_DIV(mResource->mHeadDim, hP) + i) * oldMaxLength * hP * bytes, + oldMaxLength * hP * bytes); + } + } + mResource->mPastValue.reset(new_value); } - mResource->mPastKey.reset(new_key); - mResource->mPastValue.reset(new_value); } ErrorCode CPUAttention::onResize(const std::vector& inputs, const std::vector& outputs) { auto core = static_cast(backend())->functions(); core->MNNGetMatMulPackMode(&eP, &lP, &hP); - unit = core->pack; + mThreadNum = ((CPUBackend *)backend())->threadNumber(); + unit = core->pack; bytes = core->bytes; auto query = inputs[0]; - auto shape = query->shape(); - int seq_len = shape[1]; - mThreadNum = ((CPUBackend *)backend())->threadNumber(); - mResource->mHeadDim = shape[3]; - int query_e = UP_DIV(seq_len, eP); - mPackQ.reset(Tensor::createDevice({mThreadNum, query_e, mResource->mHeadDim, eP})); + auto key = inputs[1]; + int seq_len = query->shape()[1]; + mResource->mNumHead = query->shape()[2]; + mResource->mHeadDim = query->shape()[3]; + mResource->mKvNumHead = key->shape()[2]; + mPackQ.reset(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP), mResource->mHeadDim, eP})); mPackQKV.reset(Tensor::createDevice({mThreadNum, UP_DIV(mResource->mHeadDim, unit), seq_len, unit})); backend()->onAcquireBuffer(mPackQ.get(), Backend::DYNAMIC); backend()->onAcquireBuffer(mPackQKV.get(), Backend::DYNAMIC); @@ -229,193 +397,240 @@ ErrorCode CPUAttention::onResize(const std::vector& inputs, const std:: ErrorCode CPUAttention::onExecute(const std::vector& inputs, const std::vector& outputs) { auto core = static_cast(backend())->functions(); - auto matmulUnit = core->MNNPackedMatMul; - auto matmulRemain = core->MNNPackedMatMulRemain; auto query = inputs[0]; - auto key = inputs[1]; + auto key = inputs[1]; auto value = inputs[2]; auto mask = inputs[3]; + auto mask_shape = mask->shape(); bool float_mask = (mask->getType() == halide_type_of()); - auto shape = query->shape(); - int seq_len = shape[1]; - mThreadNum = ((CPUBackend *)backend())->threadNumber(); - mIsDecode = seq_len == 1; - mResource->mNumHead = shape[2]; - mResource->mKvNumHead = key->shape()[2]; + int mask_seqlen = mask_shape[2]; + int mask_kvlen = mask_shape[3]; + int seq_len = query->shape()[1]; + MNN_ASSERT(seq_len == mask_seqlen); + mIsPrefill = (seq_len > 1); + // isPrefill and mask is Square Matrix, is FirstPrefill + mIsFirstPrefill = mIsPrefill && (mask_kvlen == mask_seqlen); + int tileCount = UP_DIV(mResource->mNumHead, mThreadNum); int group_size = mResource->mNumHead / mResource->mKvNumHead; - mResource->mHeadDim = shape[3]; - mResource->mScale = 1.0 / sqrt(mResource->mHeadDim); + + // 0: do not quant kv + // 1: only quant k + // 2: only quant v + // 3: quant kv + int quantKV = static_cast(backend())->getRuntime()->hint().kvcacheQuantOption; + bool quantKey = (quantKV & 1) == 1; + bool quantValue = ((quantKV >> 1) & 1) == 1; + // reduce the value of 'query' to avoid fp16 overflow + float mScale = 1.0 / sqrt(mResource->mHeadDim); float q_scale = 1.0; if (bytes == 2) { q_scale = FP16_QSCALE; - mResource->mScale /= q_scale; + mScale /= q_scale; } - mResource->mValueH = UP_DIV(mResource->mHeadDim, hP); - int query_e = UP_DIV(seq_len, eP); - int key_h = UP_DIV(seq_len, hP); - int tileCount = UP_DIV(mResource->mNumHead, mThreadNum); - std::shared_ptr mTempQK; - if (mIsDecode) { - 
reallocKVCache(); - mTempQK.reset(Tensor::createDevice({mThreadNum, eP + 2, mResource->mPastLength + 1})); - } else { - mResource->mPastLength = seq_len; - allocKVCache(); - mTempQK.reset(Tensor::createDevice({mThreadNum, 4, seq_len, seq_len})); + if (mIsPrefill) { + // Only reset the kvcache in the first prefill, but keep the kvcache in subsequent prefill + if (mIsFirstPrefill) { + mResource->mPastLength = 0; + allocKVCache(seq_len, quantKey, quantValue); + } else { + reallocKVCache(mResource->mPastLength + seq_len, quantKey, quantValue); + } + } else { // Decode + reallocKVCache(mResource->mPastLength + 1, quantKey, quantValue); } - backend()->onAcquireBuffer(mTempQK.get(), Backend::STATIC); + int kv_seq_len = mResource->mPastLength + seq_len; - std::function mPrefill = [=](int tId){ - auto pack_q = mPackQ->host() + tId * query_e * mResource->mHeadDim * eP * bytes; - auto pack_qk = mTempQK->host() + tId * 4 * seq_len * seq_len * bytes; - auto unpack_qk = pack_qk + seq_len * seq_len * 2 * bytes; - auto mask_qk = reinterpret_cast(pack_qk); - auto softmax_qk = reinterpret_cast(unpack_qk); - auto pack_qkv = mPackQKV->host() + tId * UP_DIV(mResource->mHeadDim, unit) * seq_len * unit * bytes; + // Temporary tensors for intermediate results + std::shared_ptr packQK(Tensor::createDevice({mThreadNum, UP_DIV(kv_seq_len, unit), seq_len, unit})); + std::shared_ptr unpackQK(Tensor::createDevice({mThreadNum, seq_len, kv_seq_len})); + std::shared_ptr softmaxQK(Tensor::createDevice({mThreadNum, seq_len, kv_seq_len})); + std::shared_ptr newPackQK(Tensor::createDevice({mThreadNum, UP_DIV(seq_len, eP), kv_seq_len, eP})); + std::shared_ptr dequantV(Tensor::createDevice({mThreadNum, UP_DIV(mResource->mHeadDim, hP), kv_seq_len, hP})); + backend()->onAcquireBuffer(packQK.get(), Backend::STATIC); + backend()->onAcquireBuffer(unpackQK.get(), Backend::STATIC); + backend()->onAcquireBuffer(softmaxQK.get(), Backend::STATIC); + backend()->onAcquireBuffer(newPackQK.get(), Backend::STATIC); + if (quantValue) { + backend()->onAcquireBuffer(dequantV.get(), Backend::STATIC); + } - int head_index = tId * tileCount; + std::function mCompute = [=](int tId) { + auto pack_q = mPackQ->host() + tId * UP_DIV(seq_len, eP) * mResource->mHeadDim * eP * bytes; + auto pack_qk = packQK->host() + tId * UP_DIV(kv_seq_len, unit) * seq_len * unit * bytes; + auto unpack_qk = unpackQK->host() + tId * seq_len * kv_seq_len; + auto softmax_qk = softmaxQK->host() + tId * seq_len * kv_seq_len; + auto new_pack_qk = newPackQK->host() + tId * UP_DIV(seq_len, eP) * kv_seq_len * eP * bytes; + auto pack_qkv = mPackQKV->host() + tId * UP_DIV(mResource->mHeadDim, unit) * seq_len * unit * bytes; + int head_index = tId * tileCount; for (int h = head_index; h < head_index + tileCount && h < mResource->mNumHead; h++) { + int kv_h = h / group_size; + char * key_dst = nullptr; + char * key_scale_dst = nullptr; + char * key_zero_point_dst = nullptr; + char * value_dst = nullptr; + if (quantKey) { + key_dst = mResource->mPastKey->host() + kv_h * UP_DIV(mResource->mMaxLength, hP) * mResource->mHeadDim * hP; + key_scale_dst = mResource->mDequantKeyScale->host() + kv_h * UP_DIV(mResource->mMaxLength, hP) * 1 * hP * bytes; + key_zero_point_dst = mResource->mDequantKeyZeroPoint->host() + kv_h * UP_DIV(mResource->mMaxLength, hP) * 1 * hP * bytes; + } else { + key_dst = mResource->mPastKey->host() + kv_h * UP_DIV(mResource->mMaxLength, hP) * mResource->mHeadDim * hP * bytes; + } + if (quantValue) { + value_dst = mResource->mPastValue->host() + kv_h * 
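// Sketch of the kv-cache policy decoded above: bit 0 of the kvcacheQuantOption hint
// enables key quantization and bit 1 enables value quantization, while the first
// prefill (square mask, mask_kvlen == seq_len) resets the cache and every later call
// only grows it. Standalone logic with illustrative names.
struct KvCachePolicy {
    bool quantKey;
    bool quantValue;
    bool resetCache;   // true only for the first prefill
};

KvCachePolicy decideKvCachePolicy(int kvcacheQuantOption, int seqLen, int maskKvLen) {
    KvCachePolicy p;
    p.quantKey   = (kvcacheQuantOption & 1) == 1;
    p.quantValue = ((kvcacheQuantOption >> 1) & 1) == 1;
    bool isPrefill = seqLen > 1;
    p.resetCache = isPrefill && (maskKvLen == seqLen);
    return p;
}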
UP_DIV(mResource->mHeadDim, hP) * mResource->mMaxLength * hP; + } else { + value_dst = mResource->mPastValue->host() + kv_h * UP_DIV(mResource->mHeadDim, hP) * mResource->mMaxLength * hP * bytes; + } // pack for matmul - int kv_h = h / group_size; - auto key_dst = mResource->mPastKey->host() + kv_h * UP_DIV(mResource->mMaxLength, hP) * mResource->mHeadDim * hP * bytes; - auto value_dst = mResource->mPastValue->host() + kv_h * mResource->mValueH * mResource->mMaxLength * hP * bytes; if (bytes == 2) { - prefill_pack(query, key, value, pack_q, key_dst, value_dst, mResource->mMaxLength, mResource->mNumHead, mResource->mKvNumHead, mResource->mHeadDim, mResource->mValueH, eP, hP, query_e, key_h, seq_len, h, kv_h, q_scale); + pack_query(query, pack_q, mResource->mNumHead, mResource->mHeadDim, eP, seq_len, h, q_scale); + pack_key(key, key_dst, mResource->mPastLength, seq_len, mResource->mKvNumHead, mResource->mHeadDim, hP, kv_h, key_scale_dst, key_zero_point_dst, quantKey); + pack_value(value, value_dst, mResource->mMaxLength, mResource->mPastLength, seq_len, mResource->mKvNumHead, mResource->mHeadDim, hP, kv_h, quantValue); } else { - prefill_pack(query, key, value, pack_q, key_dst, value_dst, mResource->mMaxLength, mResource->mNumHead, mResource->mKvNumHead, mResource->mHeadDim, mResource->mValueH, eP, hP, query_e, key_h, seq_len, h, kv_h, q_scale); + pack_query(query, pack_q, mResource->mNumHead, mResource->mHeadDim, eP, seq_len, h, q_scale); + pack_key(key, key_dst, mResource->mPastLength, seq_len, mResource->mKvNumHead, mResource->mHeadDim, hP, kv_h, key_scale_dst, key_zero_point_dst, quantKey); + pack_value(value, value_dst, mResource->mMaxLength, mResource->mPastLength, seq_len, mResource->mKvNumHead, mResource->mHeadDim, hP, kv_h, quantValue); } // query @ key int loop_e = seq_len / eP; int remain = seq_len % eP; for (int i = 0 ; i < loop_e; i++) { - size_t shapeParameters[6]; + size_t shapeParameters[7]; size_t* parameters = shapeParameters; - parameters[0] = eP * bytes; - parameters[1] = mResource->mHeadDim; - parameters[2] = seq_len; - parameters[3] = seq_len * unit * bytes; - parameters[4] = 0; - parameters[5] = 0; - matmulUnit((float*)(pack_qk + (i * eP * unit) * bytes), (float*)(pack_q + (i * mResource->mHeadDim * eP) * bytes), (float*)key_dst, parameters, nullptr, nullptr, nullptr, nullptr); + parameters[0] = eP * bytes; + parameters[1] = mResource->mHeadDim; + parameters[2] = kv_seq_len; + parameters[3] = seq_len * unit * bytes; + parameters[4] = 0; + parameters[5] = 0; + parameters[6] = 0; + if (quantKey) { + core->MNNPackedMatMul_int8( + (float*)(pack_qk + (i * eP * unit) * bytes), + (float*)(pack_q + (i * mResource->mHeadDim * eP) * bytes), + (float*)key_dst, + parameters, nullptr, nullptr, + (float*)key_scale_dst, (float*)key_zero_point_dst + ); + } else { + core->MNNPackedMatMul( + (float*)(pack_qk + (i * eP * unit) * bytes), + (float*)(pack_q + (i * mResource->mHeadDim * eP) * bytes), + (float*)key_dst, + parameters, nullptr, nullptr, + nullptr, nullptr + ); + } } { - size_t shapeParameters[6]; + size_t shapeParameters[7]; size_t* parameters = shapeParameters; - parameters[0] = eP * bytes; - parameters[1] = mResource->mHeadDim; - parameters[2] = seq_len; - parameters[3] = seq_len * unit * bytes; - parameters[4] = 0; - parameters[5] = 0; - matmulRemain((float*)(pack_qk + (loop_e * eP * unit) * bytes), (float*)(pack_q + (loop_e * mResource->mHeadDim * eP) * bytes), (float*)key_dst, remain, parameters, nullptr, nullptr, nullptr, nullptr); + parameters[0] = eP * bytes; + 
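// Sketch of the tiling pattern around the packed GEMM calls above: full tiles of eP
// query rows go through the regular kernel and the tail (seq_len % eP) goes through the
// *Remain variant once. The callbacks below are placeholders, not the
// MNNPackedMatMul/MNNPackedMatMulRemain signatures.
#include <functional>

void runTiledGemm(int seqLen, int eP,
                  const std::function<void(int rowStart)>& fullTile,
                  const std::function<void(int rowStart, int rows)>& remainTile) {
    int loops  = seqLen / eP;
    int remain = seqLen % eP;
    for (int i = 0; i < loops; ++i) {
        fullTile(i * eP);                    // exactly eP rows per call
    }
    if (remain > 0) {
        remainTile(loops * eP, remain);      // one call for the leftover rows
    }
}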
parameters[1] = mResource->mHeadDim; + parameters[2] = kv_seq_len; + parameters[3] = seq_len * unit * bytes; + parameters[4] = 0; + parameters[5] = 0; + parameters[6] = 0; + if (quantKey) { + core->MNNPackedMatMulRemain_int8( + (float*)(pack_qk + (loop_e * eP * unit) * bytes), + (float*)(pack_q + (loop_e * mResource->mHeadDim * eP) * bytes), + (float*)key_dst, + remain, parameters, nullptr, nullptr, + (float*)key_scale_dst, (float*)key_zero_point_dst + ); + } else { + core->MNNPackedMatMulRemain( + (float*)(pack_qk + (loop_e * eP * unit) * bytes), + (float*)(pack_q + (loop_e * mResource->mHeadDim * eP) * bytes), + (float*)key_dst, + remain, parameters, nullptr, nullptr, + nullptr, nullptr + ); + } } - int area_offset[2] {seq_len, 0}; - core->MNNUnpackCUnitTranspose((float*)unpack_qk, (float*)pack_qk, seq_len, seq_len, area_offset); - // div scale and mask - auto mask_ptr = mask->host(); - if (bytes == 2) { - prefill_softmax(mask_ptr, mask_qk, softmax_qk, unpack_qk, pack_qk, mResource->mScale, eP, query_e, seq_len, -65504.0, float_mask); + if(bytes == 2) { + // unpack qk: [kv_seq_len/unit, seq_len, unit] -> [seq_len, kv_seq_len] + unpack_QK(unpack_qk, pack_qk, seq_len, kv_seq_len, unit); + mask_QK(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits::lowest(), mask->host(), float_mask); + softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len); + // pack qk for qk @ v : [seq_len, kv_seq_len] -> [seq_len/eP, kv_seq_len, eP] + pack_QK(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP); } else { - prefill_softmax(mask_ptr, mask_qk, softmax_qk, unpack_qk, pack_qk, mResource->mScale, eP, query_e, seq_len, std::numeric_limits::lowest(), float_mask); + unpack_QK(unpack_qk, pack_qk, seq_len, kv_seq_len, unit); + mask_QK(unpack_qk, seq_len, kv_seq_len, mScale, std::numeric_limits::lowest(), mask->host(), float_mask); + softmax_QK(softmax_qk, unpack_qk, seq_len, kv_seq_len); + pack_QK(new_pack_qk, softmax_qk, seq_len, kv_seq_len, eP); + } + // Dequantize values from fp8 to float + if (quantValue) { + char * qv = value_dst; + char * dqv = dequantV->host() + tId * UP_DIV(mResource->mHeadDim, hP) * kv_seq_len * hP * bytes; + if (bytes == 2) { + dequant_value_fp16(dqv, qv, mResource->mHeadDim, kv_seq_len, hP, mResource->mMaxLength); + } else { + dequant_value_float(dqv, qv, mResource->mHeadDim, kv_seq_len, hP, mResource->mMaxLength); + } + value_dst = dqv; } // qk @ v for (int i = 0 ; i < loop_e; i++) { size_t shapeParameters[6]; size_t* parameters = shapeParameters; parameters[0] = eP * bytes; - parameters[1] = seq_len; + parameters[1] = kv_seq_len; parameters[2] = mResource->mHeadDim; parameters[3] = seq_len * unit * bytes; parameters[4] = 0; - parameters[5] = (mResource->mMaxLength - seq_len) * hP * bytes; - matmulUnit((float*)(pack_qkv + (i * eP * unit) * bytes), (float*)(pack_qk + (i * seq_len * eP) * bytes), (float*)value_dst, parameters, nullptr, nullptr, nullptr, nullptr); + parameters[5] = quantValue ? 
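// Scalar reference for the unpack -> mask -> softmax -> repack stage above, shown for
// one query row with an additive float mask; the integer-mask path in the patch instead
// writes a very negative value at masked positions before the softmax. Plain C++, not
// the vectorized MNN routines.
#include <algorithm>
#include <cmath>

void maskedSoftmaxRow(float* row, const float* mask, int kvLen, float scale) {
    float maxVal = -INFINITY;
    for (int i = 0; i < kvLen; ++i) {
        row[i] = row[i] * scale + mask[i];   // mask[i] is 0 or a large negative number
        maxVal = std::max(maxVal, row[i]);
    }
    float sum = 0.f;
    for (int i = 0; i < kvLen; ++i) {
        row[i] = std::exp(row[i] - maxVal);  // subtract the max for numerical stability
        sum += row[i];
    }
    for (int i = 0; i < kvLen; ++i) {
        row[i] /= sum;
    }
}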
0 : (mResource->mMaxLength - kv_seq_len) * hP * bytes; + core->MNNPackedMatMul( + (float*)(pack_qkv + (i * eP * unit) * bytes), + (float*)(new_pack_qk + (i * kv_seq_len * eP) * bytes), + (float*)value_dst, parameters, + nullptr, nullptr, nullptr, nullptr + ); } { size_t shapeParameters[6]; size_t* parameters = shapeParameters; parameters[0] = eP * bytes; - parameters[1] = seq_len; + parameters[1] = kv_seq_len; parameters[2] = mResource->mHeadDim; parameters[3] = seq_len * unit * bytes; parameters[4] = 0; - parameters[5] = (mResource->mMaxLength - seq_len) * hP * bytes; - matmulRemain((float*)(pack_qkv + (loop_e * eP * unit) * bytes), (float*)(pack_qk + (loop_e * seq_len * eP) * bytes), (float*)value_dst, remain, parameters, nullptr, nullptr, nullptr, nullptr); + parameters[5] = quantValue ? 0 : (mResource->mMaxLength - kv_seq_len) * hP * bytes; + core->MNNPackedMatMulRemain( + (float*)(pack_qkv + (loop_e * eP * unit) * bytes), + (float*)(new_pack_qk + (loop_e * kv_seq_len * eP) * bytes), + (float*)value_dst, remain, parameters, + nullptr, nullptr, nullptr, nullptr + ); } - // transpose: [head_dim/unit, seq_len, unit] -> [seq_len, num_head, head_dim] + // unpack: [head_dim/unit, seq_len, unit] -> [seq_len, num_head, head_dim] auto dst_ptr = outputs[0]->host() + h * mResource->mHeadDim * bytes; if (bytes == 2) { - prefill_unpack(pack_qkv, dst_ptr, mResource->mNumHead, mResource->mHeadDim, unit, seq_len); + unpack_QKV(pack_qkv, dst_ptr, mResource->mNumHead, mResource->mHeadDim, unit, seq_len); } else { - prefill_unpack(pack_qkv, dst_ptr, mResource->mNumHead, mResource->mHeadDim, unit, seq_len); + unpack_QKV(pack_qkv, dst_ptr, mResource->mNumHead, mResource->mHeadDim, unit, seq_len); } } }; - std::function mDecode = [=](int tId) { - int kv_seq_len = mResource->mPastLength + 1; - auto pack_q = mPackQ->host() + tId * mResource->mHeadDim * eP * bytes; - auto pack_qk = mTempQK->host() + tId * (eP + 2) * kv_seq_len * bytes; - auto unpack_qk = pack_qk + kv_seq_len * eP * bytes; - auto mask_qk = reinterpret_cast(pack_qk); - auto softmax_qk = reinterpret_cast(unpack_qk); - auto pack_qkv = mPackQKV->host() + tId * UP_DIV(mResource->mHeadDim, unit) * unit * bytes; - - int head_index = tId * tileCount; - for (int h = head_index; h < head_index + tileCount && h < mResource->mNumHead; h++) { - int kv_h = h / group_size; - auto key_dst = mResource->mPastKey->host() + kv_h * UP_DIV(mResource->mMaxLength, hP) * mResource->mHeadDim * hP * bytes; - auto value_dst = mResource->mPastValue->host() + kv_h * mResource->mValueH * mResource->mMaxLength * hP * bytes; - // pack for matmul - if (bytes == 2) { - decode_pack(query, key, value, pack_q, key_dst, value_dst, mResource->mMaxLength, mResource->mPastLength, mResource->mHeadDim, mResource->mValueH, eP, hP, h, kv_h, q_scale); - } else { - decode_pack(query, key, value, pack_q, key_dst, value_dst, mResource->mMaxLength, mResource->mPastLength, mResource->mHeadDim, mResource->mValueH, eP, hP, h, kv_h, q_scale); - } - // query @ key: [1, head_dim] @ [head_dim, kv_seq_len] -> [1, kv_seq_len] - size_t shapeParameters[6]; - size_t* parameters = shapeParameters; - parameters[0] = eP * bytes; - parameters[1] = mResource->mHeadDim; - parameters[2] = kv_seq_len; - parameters[3] = seq_len * unit * bytes; - parameters[4] = 0; - parameters[5] = 0; - matmulRemain((float*)pack_qk, (float*)pack_q, (float*)key_dst, seq_len, parameters, nullptr, nullptr, nullptr, nullptr); - int area_offset[2] {seq_len, 0}; - core->MNNUnpackCUnitTranspose((float*)unpack_qk, (float*)pack_qk, seq_len, 
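// Sketch of the final unpack noted above: per head, the packed result is stored as
// [head_dim/unit, seq_len, unit] and is scattered into an output laid out as
// [seq_len, num_head, head_dim]. Plain index remapping; for brevity it assumes head_dim
// is a multiple of unit, while the real kernel also handles the tail lanes.
void unpackQKVRef(const float* packed, float* dst, int headIndex,
                  int numHead, int headDim, int unit, int seqLen) {
    for (int d = 0; d < headDim; ++d) {
        int block = d / unit;
        int lane  = d % unit;
        for (int s = 0; s < seqLen; ++s) {
            dst[(s * numHead + headIndex) * headDim + d] =
                packed[(block * seqLen + s) * unit + lane];
        }
    }
}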
kv_seq_len, area_offset); - if (bytes == 2) { - decode_softmax(mask_qk, softmax_qk, unpack_qk, pack_qk, mResource->mScale, eP, kv_seq_len); - } else { - decode_softmax(mask_qk, softmax_qk, unpack_qk, pack_qk, mResource->mScale, eP, kv_seq_len); - } - // qk @ v: [1, kv_seq_len] @ [kv_seq_len, head_dim] -> [1, head_dim] - { - size_t shapeParameters[6]; - size_t* parameters = shapeParameters; - parameters[0] = eP * bytes; - parameters[1] = kv_seq_len; - parameters[2] = mResource->mHeadDim; - parameters[3] = 1 * unit * bytes; - parameters[5] = (mResource->mMaxLength - kv_seq_len) * hP * bytes; - matmulRemain((float*)pack_qkv, (float*)pack_qk, (float*)value_dst, 1, parameters, nullptr, nullptr, nullptr, nullptr); - } - // transpose: [head_dim/unit, 1, unit] -> [1, num_head, head_dim] - auto dst_ptr = outputs[0]->host() + h * mResource->mHeadDim * bytes; - core->MNNUnpackCUnitTranspose((float*)dst_ptr, (float*)pack_qkv, 1, mResource->mHeadDim, area_offset); - } - }; - - std::function mFunction = mIsDecode ? mDecode : mPrefill; MNN_CONCURRENCY_BEGIN(tId, mThreadNum) { - mFunction((int)tId); + mCompute((int)tId); } MNN_CONCURRENCY_END(); - if(mIsDecode) { - mResource->mPastLength++; + + mResource->mPastLength += seq_len; + backend()->onReleaseBuffer(packQK.get(), Backend::STATIC); + backend()->onReleaseBuffer(unpackQK.get(), Backend::STATIC); + backend()->onReleaseBuffer(softmaxQK.get(), Backend::STATIC); + backend()->onReleaseBuffer(newPackQK.get(), Backend::STATIC); + if (quantValue){ + backend()->onReleaseBuffer(dequantV.get(), Backend::STATIC); } - backend()->onReleaseBuffer(mTempQK.get(), Backend::STATIC); return NO_ERROR; } @@ -447,4 +662,4 @@ REGISTER_CPU_OP_CREATOR_TRANSFORMER(CPUAttentionCreator, OpType_Attention); } // namespace MNN -#endif +#endif \ No newline at end of file diff --git a/source/backend/cpu/CPUAttention.hpp b/source/backend/cpu/CPUAttention.hpp index bc48de6b4..abf351249 100644 --- a/source/backend/cpu/CPUAttention.hpp +++ b/source/backend/cpu/CPUAttention.hpp @@ -25,17 +25,19 @@ class CPUAttention : public Execution { virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; struct Resource { - std::shared_ptr mPastKey; - std::shared_ptr mPastValue; - float mScale; - const int mExpandChunk = 64; + std::shared_ptr mPastKey; // numhead, [maxlen/eP, headdim, eP] + std::shared_ptr mPastValue; // numhead, [headdim/eP, maxlen, eP] + std::shared_ptr mDequantKeyScale; // numhead, [maxlen/eP, 1, eP] + std::shared_ptr mDequantKeyZeroPoint; // numhead, [maxlen/eP, 1, eP] int mPastLength = 0, mMaxLength = 0; - int mNumHead = 0, mKvNumHead = 0, mHeadDim = 0, mValueH = 0; + const int mExpandChunk = 64; + int mNumHead = 0, mKvNumHead = 0, mHeadDim = 0; }; private: - void allocKVCache(); - void reallocKVCache(); - bool mIsDecode = false; + void allocKVCache(int kv_seq_len, bool quantK, bool quantV); + void reallocKVCache(int kv_seq_len, bool quantK, bool quantV); + bool mIsPrefill = true; + bool mIsFirstPrefill = true; bool mKVCache; int mThreadNum = 1; std::shared_ptr mResource; diff --git a/source/backend/cpu/CPUBackend.cpp b/source/backend/cpu/CPUBackend.cpp index 66c349c37..5f1a75eab 100644 --- a/source/backend/cpu/CPUBackend.cpp +++ b/source/backend/cpu/CPUBackend.cpp @@ -30,7 +30,6 @@ #define MAX_THREAD_NUMBER 32 #define LARGE_MEMORY 1024 * 1024 * 500 #ifdef MNN_SUPPORT_BF16 -#include "bf16/BF16Backend.hpp" #include "bf16/BF16Functions.hpp" #endif @@ -48,56 +47,183 @@ 
ErrorCode CastWrapExecution::onExecute(const std::vector& inputs, const CPUCastCreator::cast(inputs[0], outputs[0], cpuBackend, convertType); return NO_ERROR; } - -CPURuntime::CPURuntime(const Backend::Info& info) { - mStaticAllocator.reset(new EagerBufferAllocator(BufferAllocator::Allocator::createDefault())); - mThreadNumber = info.numThread; - mThreadNumber = std::max(1, mThreadNumber); - mThreadNumber = std::min(mThreadNumber, MAX_THREAD_NUMBER); - mPower = BackendConfig::Power_Normal; - mMemory = BackendConfig::Memory_Normal; - mPrecision = BackendConfig::Precision_Normal; - mFlops = MNNGetCPUFlops(mThreadNumber); - if (info.user != nullptr) { - mPrecision = info.user->precision; - mPower = info.user->power; - mMemory = info.user->memory; - mFlags = info.user->flags; +void CPURuntime::computeDivideSizes(int size, int* dst) const { + if (mGroupWithComputeRate.size() <= 1) { + // Avg divide + int length = UP_DIV(size, mThreadNumber); + int cur = length; + for (int i=0; igroups.size() == 0) { + return; + } + std::vector> lockCPUIndexes(mThreadNumber); switch (mPower) { case BackendConfig::Power_Low: - MNNSetCPUThreadsMode(MNN_CPU_MODE_LITTLE); + for (int v=0; vgroups[0].ids.data(), cpuInfo->groups[0].ids.size()); + } break; case BackendConfig::Power_High: - MNNSetCPUThreadsMode(MNN_CPU_MODE_POWER_FRI); + { + int selectCPUSize = 0; + int groupIndex = cpuInfo->groups.size() - 1; + while (selectCPUSize < mThreadNumber && groupIndex >= 0) { + auto& group = cpuInfo->groups[groupIndex]; + int size = ALIMIN(group.ids.size(), mThreadNumber - selectCPUSize); + for (int v=0; v result(threadsNumber, 0); +#pragma omp parallel for + for (int i = 0; i < threadsNumber; ++i) { + result[i] = MNNSetSchedAffinity(lockCPUIndexes[i].first, lockCPUIndexes[i].second); + } #endif #ifdef MNN_USE_THREAD_POOL - mThreadNumber = ThreadPool::init(mThreadNumber); + ThreadPool::active(mThreadNumber); + ThreadPool::enqueue(std::make_pair([&](int i) { + MNNSetSchedAffinity(lockCPUIndexes[i].first, lockCPUIndexes[i].second); + return 0; + }, mThreadNumber), mTaskIndex, mThreadNumber); + ThreadPool::deactive(mThreadNumber); +#endif +} + +void CPURuntime::_resetGroupCompute() const { + if (mPastDecreaseHint == hint().cpuDecreaseRate) { + return; + } + mGroupWithComputeRate.clear(); + if (mThreadNumber <= 1 || mPower == BackendConfig::Power_Low) { + return; + } + mPastDecreaseHint = hint().cpuDecreaseRate; + auto cpuInfo = MNNGetCPUInfo(); + if (cpuInfo->groups.size() < 2) { + return; + } + float decreaseRate = (float)(hint().cpuDecreaseRate) / 100.0f; + int validCpuSize = (int)(cpuInfo->groups[cpuInfo->groups.size()-1].ids.size()); + int groupIndex = (int)cpuInfo->groups.size()-2; + float maxFreq = (float)cpuInfo->groups[cpuInfo->groups.size()-1].maxFreq; + validCpuSize = ALIMIN(validCpuSize, mThreadNumber); + float totalComputeRate = 1.0f * validCpuSize; + mGroupWithComputeRate.emplace_back(std::make_pair(totalComputeRate, validCpuSize)); + float currentRate = 1.0f; + while (validCpuSize < mThreadNumber && groupIndex >= 0) { + auto& group = cpuInfo->groups[groupIndex]; + int selectSize = ALIMIN(mThreadNumber - validCpuSize, (int)group.ids.size()); + validCpuSize += group.ids.size(); + currentRate *= decreaseRate; + totalComputeRate += currentRate * selectSize; + mGroupWithComputeRate.emplace_back(std::make_pair(currentRate * selectSize, selectSize)); + } + for (auto& g : mGroupWithComputeRate) { + g.first = g.first / totalComputeRate; + } +} + +void CPURuntime::_resetThreadPool() { + mThreadNumber = std::max(1, 
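// Standalone sketch of proportional work division as set up above: when the CPU groups
// run at different speeds, each group receives a share of the work proportional to its
// normalized compute rate and splits that share evenly among its threads; dst holds the
// cumulative end index for every thread. This mirrors the idea, not the exact MNN code.
#include <utility>
#include <vector>

void divideByComputeRate(int size,
                         const std::vector<std::pair<float, int>>& groups, // {rateShare, threadCount}
                         std::vector<int>& dst) {
    dst.clear();
    int assigned = 0;
    for (size_t g = 0; g < groups.size(); ++g) {
        int groupWork = (g + 1 == groups.size())
                            ? (size - assigned)                     // last group takes the rest
                            : (int)(size * groups[g].first + 0.5f);
        if (groupWork < 0) {
            groupWork = 0;
        }
        int threads = groups[g].second;
        for (int t = 0; t < threads; ++t) {
            assigned += groupWork / threads + (t < groupWork % threads ? 1 : 0);
            dst.push_back(assigned);                                // cumulative end offset
        }
    }
}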
mThreadNumber); + mThreadNumber = std::min(mThreadNumber, MAX_THREAD_NUMBER); +#ifdef MNN_USE_THREAD_POOL + ThreadPool::releaseWorkIndex(mTaskIndex); + auto cpuInfo = MNNGetCPUInfo(); + if (mThreadNumber > 1) { + int systemThreadNumber = (int)cpuInfo->cpuNumber; + if (systemThreadNumber == 0) { + systemThreadNumber = mThreadNumber; + } + mThreadNumber = ALIMIN(ThreadPool::init(systemThreadNumber), mThreadNumber); + } + mGroupWithComputeRate.clear(); if (mThreadNumber > 1) { mTaskIndex = ThreadPool::acquireWorkIndex(); + if (-1 == mTaskIndex) { + MNN_ERROR("The ThreadPool has been used to MNN_THREAD_POOL_MAX_TASKS, can't use thread pool\n"); + mThreadNumber = 1; + } } else { mTaskIndex = -1; } - if (mTaskIndex >= 0 && mPower == BackendConfig::Power_High) { - ThreadPool::active(); - } #endif + // Reset tid to rebind cpu if necessary + mCurrentTID = 0; +} +void CPURuntime::onReset(int numberThread, const BackendConfig* config) { + if (config != nullptr) { + mPrecision = config->precision; + mPower = config->power; + mMemory = config->memory; + mFlags = config->flags; + } + mThreadNumber = numberThread; + _resetThreadPool(); + // Mask Group Compute reset + mPastDecreaseHint = -1; +} + +CPURuntime::CPURuntime(const Backend::Info& info) { + mStaticAllocator.reset(new EagerBufferAllocator(BufferAllocator::Allocator::createDefault())); + mThreadNumber = info.numThread; + mPower = BackendConfig::Power_Normal; + mMemory = BackendConfig::Memory_Normal; + mPrecision = BackendConfig::Precision_Normal; + if (info.user != nullptr) { + mPrecision = info.user->precision; + mPower = info.user->power; + mMemory = info.user->memory; + mFlags = info.user->flags; + } + _resetThreadPool(); #ifdef LOG_VERBOSE MNN_PRINT("create CPURuntime:%p\n", this); #endif } CPURuntime:: ~ CPURuntime() { #ifdef MNN_USE_THREAD_POOL - if (mTaskIndex >= 0 && mPower == BackendConfig::Power_High) { - ThreadPool::deactive(); - } ThreadPool::releaseWorkIndex(mTaskIndex); #endif } @@ -106,13 +232,7 @@ float CPURuntime::onGetMemoryInMB() { return staticMemoryInMB; } bool CPURuntime::onCheckInfo(Backend::Info& info) const { -#ifdef MNN_USE_THREAD_POOL - int threadNumber = mThreadNumber; - if (mTaskIndex < 0) { - threadNumber = 1; - } - info.numThread = threadNumber; -#endif + info.numThread = mThreadNumber; return true; } @@ -120,6 +240,7 @@ Backend* CPURuntime::onCreate(const BackendConfig* config) const { auto precision = mPrecision; auto memory = mMemory; size_t flags = mFlags; + _resetGroupCompute(); if (nullptr != config) { precision = config->precision; flags = config->flags; @@ -137,7 +258,9 @@ Backend* CPURuntime::onCreate(const BackendConfig* config) const { #endif #ifdef MNN_SUPPORT_BF16 if (precision == BackendConfig::Precision_Low_BF16 && BF16Functions::get()) { - return new BF16Backend(this); + auto res = new CPUBackend(this, precision, memory, MNN_FORWARD_CPU_EXTENSION, 0); + res->mCoreFunctions = BF16Functions::get(); + return res; } #endif if (flags == MNN_CPU_USE_DEFAULT_BACKEND) { @@ -178,8 +301,9 @@ void CPURuntime::onGabageCollect(int level) { void CPURuntime::onConcurrencyBegin() const { #ifdef MNN_USE_THREAD_POOL - if (mTaskIndex >= 0 && mPower != BackendConfig::Power_High) { - ThreadPool::active(); + if (mTaskIndex >= 0) { + ThreadPool::active(mThreadNumber); + mThreadOpen = true; } #else #ifdef _OPENMP @@ -187,12 +311,14 @@ void CPURuntime::onConcurrencyBegin() const { omp_set_num_threads(mThreadNumber); #endif #endif + _bindCPUCore(); } void CPURuntime::onConcurrencyEnd() const { #ifdef MNN_USE_THREAD_POOL - if 
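// Sketch of the design change visible above: the dedicated BF16Backend is gone, and the
// same CPU backend object is reused with only its table of core kernels swapped when
// low-precision BF16 is requested. The table type below is an illustrative stand-in for
// MNN's CoreFunctions, not the real structure.
struct KernelTableSketch {
    int bytesPerElement;
    // ... function pointers for pack, gemm and activation kernels would live here ...
};

const KernelTableSketch* pickCoreFunctions(bool lowPrecisionBF16,
                                           const KernelTableSketch* fp32Table,
                                           const KernelTableSketch* bf16Table) {
    // One backend type either way; only the kernel table it dispatches through differs.
    return lowPrecisionBF16 ? bf16Table : fp32Table;
}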
(mTaskIndex >= 0 && mPower != BackendConfig::Power_High) { - ThreadPool::deactive(); + if (mTaskIndex >= 0) { + ThreadPool::deactive(mThreadNumber); + mThreadOpen = false; } #endif } @@ -219,7 +345,7 @@ CPUBackend::CPUBackend(const CPURuntime* runtime, BackendConfig::PrecisionMode p mMemory = memory; mRuntime = const_cast(runtime); std::shared_ptr defaultAlloc(BufferAllocator::Allocator::createRecurse(runtime->mStaticAllocator.get())); - if (mRuntime->getAllocatorType() == Runtime::Allocator_Defer) { + if (mRuntime->hint().memoryAllocatorType == Runtime::Allocator_Defer) { mDynamicAllocator.reset(new DeferBufferAllocator(defaultAlloc)); } else { mDynamicAllocator.reset(new EagerBufferAllocator(defaultAlloc)); @@ -256,7 +382,7 @@ bool CPUBackend::onSelectDynamicAllocator(int index, int maxIndex) { return false; } if (maxIndex == 2 && mDynamicAllocatorBackup.get() == nullptr) { - if (mRuntime->getAllocatorType() == Runtime::Allocator_Defer) { + if (mRuntime->hint().memoryAllocatorType == Runtime::Allocator_Defer) { mDynamicAllocatorBackup.reset(new DeferBufferAllocator(BufferAllocator::Allocator::createRecurse(mStaticAllocator.get()))); } else { mDynamicAllocatorBackup.reset(new EagerBufferAllocator(BufferAllocator::Allocator::createRecurse(mStaticAllocator.get()))); diff --git a/source/backend/cpu/CPUBackend.hpp b/source/backend/cpu/CPUBackend.hpp index 7793b696c..1ac8721de 100644 --- a/source/backend/cpu/CPUBackend.hpp +++ b/source/backend/cpu/CPUBackend.hpp @@ -11,6 +11,7 @@ #include #include +#include #include "core/Backend.hpp" #include "core/Execution.hpp" #include "core/BufferAllocator.hpp" @@ -24,6 +25,7 @@ class CPURuntime : public Runtime { virtual ~ CPURuntime(); int onGetRuntimeStatus(RuntimeStatus statusEnum) const override; virtual Backend* onCreate(const BackendConfig* config) const override; + virtual void onReset(int numberThread, const BackendConfig* config) override; virtual void onGabageCollect(int level) override; virtual float onGetMemoryInMB() override; virtual CompilerType onGetCompilerType() const override { @@ -33,20 +35,35 @@ class CPURuntime : public Runtime { void onConcurrencyEnd() const; virtual bool onCheckInfo(Backend::Info& info) const override; + // dividedSize's length should be larger than threadNumber + void computeDivideSizes(int size, int* dst) const; + +#ifdef MNN_USE_THREAD_POOL + inline bool multiThreadValid() const { + return mThreadOpen; + } +#endif private: + void _bindCPUCore() const; + void _resetThreadPool(); std::shared_ptr mStaticAllocator; int mThreadNumber; - mutable int mTaskIndex; +#ifdef MNN_USE_THREAD_POOL + mutable int mTaskIndex = -1; + mutable bool mThreadOpen = false; +#endif + void _resetGroupCompute() const; + mutable std::vector> mGroupWithComputeRate; + mutable int mPastDecreaseHint = -1; BackendConfig::MemoryMode mMemory; BackendConfig::PowerMode mPower; BackendConfig::PrecisionMode mPrecision; // Backend features // CPU features - float mFlops = 0.0f; static Backend*(*gExtraCreate)(const Runtime* runtime); size_t mFlags = 0; - int mAllocator = 0; + mutable int mCurrentTID = 0; }; struct CoreFunctions; struct CoreInt8Functions; @@ -114,9 +131,14 @@ class CPUBackend : public Backend { static bool addCreator(OpType t, Creator* c); - int threadNumber() const { + inline int threadNumber() const { return mRuntime->mThreadNumber; } +#ifdef MNN_USE_THREAD_POOL + inline bool threadOpen() const { + return mRuntime->mThreadOpen; + } +#endif BufferAllocator* getBufferAllocator(bool defer_allocator = true) const { return 
mCurrentDynamicAllocator; @@ -140,12 +162,13 @@ class CPUBackend : public Backend { static void initCreatorMap(); static int getBytes(const Backend* backend, const Tensor* output); static DataType getDataType(const Tensor* tensor); + friend class CPURuntime; protected: MemObj* allocBuffer(size_t size, Tensor* dest, StorageType storageType); - const CoreFunctions* mCoreFunctions; - const CoreInt8Functions* mInt8CoreFunctions; + CoreFunctions* mCoreFunctions; + CoreInt8Functions* mInt8CoreFunctions; private: std::shared_ptr mStaticAllocator; std::shared_ptr mDynamicAllocator; diff --git a/source/backend/cpu/CPUConvolution.cpp b/source/backend/cpu/CPUConvolution.cpp index 511623299..9c42008d9 100644 --- a/source/backend/cpu/CPUConvolution.cpp +++ b/source/backend/cpu/CPUConvolution.cpp @@ -49,6 +49,13 @@ bool CPUConvolution::Resource::copyBiasAlign(const float* bias, int outputCount) return true; } CPUConvolution::MutableResourceInt8::MutableResourceInt8(std::shared_ptr res, Backend* backend) : mResource(res) { + auto outputChannleUp4 = res->mOriginBias->length(0); + mBiasFloat.reset(Tensor::createDevice({outputChannleUp4})); + mValid = backend->onAcquireBuffer(mBiasFloat.get(), Backend::STATIC); + if (!mValid) { + MNN_ERROR("mBiasFloat buffer allocated error!\n"); + return; + } if (res->mUseConvQuan) { mBiasInt32 = res->mOriginBias; mScaleFloat = res->mOriginScale; @@ -59,11 +66,21 @@ CPUConvolution::MutableResourceInt8::MutableResourceInt8(std::shared_ptrmOutputZeroPoint; mClampMax = res->mClampMax; mClampMin = res->mClampMin; + // bias int32 -> bias float + auto int32BiasPtr = res->mOriginBias->host(); + auto floatBiasPtr = mBiasFloat->host(); + auto weightScale = res->mOriginScale->host(); + for (int i = 0; i < outputChannleUp4; ++i) { + if (mInputScale && mOutputScale) { // symmetric quan + floatBiasPtr[i] = int32BiasPtr[i] * weightScale[i] * mInputScale / mOutputScale; + } else { + floatBiasPtr[i] = int32BiasPtr[i] * weightScale[i]; + } + } return; } - auto outputChannleUp4 = res->mOriginBias->length(0); mBiasInt32.reset(Tensor::createDevice({outputChannleUp4})); - mScaleFloat.reset(Tensor::createDevice({outputChannleUp4})); + mScaleFloat.reset(Tensor::createDevice({outputChannleUp4})); mValid = backend->onAcquireBuffer(mBiasInt32.get(), Backend::STATIC); if (mValid) { mValid = backend->onAcquireBuffer(mScaleFloat.get(), Backend::STATIC); @@ -82,73 +99,102 @@ void CPUConvolution::MutableResourceInt8::updateInputOutputScale(std::vectormInputScale; + mOutputScale = mResource->mOutputScale; + mInputZeroPoint = mResource->mInputZeroPoint; + mOutputZeroPoint = mResource->mOutputZeroPoint; +// if (mInputScale == inputScale && mOutputScale == outputScale) { +// return; +// } + if (inputScale != 0 && outputScale != 0) { + mInputScale = inputScale; + mOutputScale = outputScale; + mInputZeroPoint = int8_t(inputZeroPoint); + mOutputZeroPoint = int8_t(outputZeroPoint); } - if (mInputScale == inputScale && mOutputScale == outputScale) { + if (mInputScale == 0 || mOutputScale == 0) { return; } - mInputScale = inputScale; - mOutputScale = outputScale; - mInputZeroPoint = int8_t(inputZeroPoint); - mOutputZeroPoint = int8_t(outputZeroPoint); + int size = mResource->mOutputCount; const int kernelNum = static_cast(mResource->mInt8WeightKernelSum.size()); auto biasData = mResource->mOriginBias->host(); auto alphaData = mResource->mOriginScale->host(); - auto alphaScale = inputScale / outputScale; + auto alphaScale = mInputScale / mOutputScale; auto scale = mScaleFloat->host(); auto bias = 
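// Sketch of the bias conversion introduced above: a conv bias stored as int32 in the
// quantized domain becomes a float bias by multiplying with the per-channel weight
// scale, and additionally by inputScale / outputScale when both tensor scales are known
// (the symmetric-quantization case). Standalone arithmetic mirroring the patch.
void biasInt32ToFloat(const int* biasI32, const float* weightScale,
                      float inputScale, float outputScale,
                      float* biasF32, int channels) {
    bool haveTensorScales = (inputScale != 0.f && outputScale != 0.f);
    for (int i = 0; i < channels; ++i) {
        float b = (float)biasI32[i] * weightScale[i];
        if (haveTensorScales) {
            b = b * inputScale / outputScale;
        }
        biasF32[i] = b;
    }
}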
mBiasInt32->host(); + auto biasfloat = mBiasFloat->host(); #ifdef MNN_USE_SSE - inputZeroPoint += 128.0f; + float offset = 128.f; +#else + float offset = 0.f; #endif for (int i = 0; i < kernelNum; i++) { auto alphaValue = alphaData[i]; if (fabs(alphaValue) < 1e-6) { alphaValue = 1e-6; } - scale[i] = alphaValue * alphaScale; + scale[i] = alphaValue * alphaScale; // input_scale*weight_scale/output_scale // compute outputZeroPointFused in asymmetric quant - int outputZeroPointFused = static_cast(outputZeroPoint / scale[i]); - bias[i] = static_cast(biasData[i] / (inputScale * alphaValue)) - mResource->mInt8WeightKernelSum[i] * inputZeroPoint + outputZeroPointFused; + int outputZeroPointFused = static_cast(mOutputZeroPoint / scale[i]); + bias[i] = static_cast(biasData[i] / (mInputScale * alphaValue)) - mResource->mInt8WeightKernelSum[i] * (mInputZeroPoint + offset) + outputZeroPointFused; + // biasfloat[i] = biasData[i] / mOutputScale - mResource->mInt8WeightKernelSum[i] * (mInputZeroPoint + offset) * scale[i] + mOutputZeroPoint; + biasfloat[i] = bias[i] * scale[i]; } } std::shared_ptr CPUConvolution::makeResourceInt8(Backend* backend, const MNN::Convolution2D *convParam, int pack) { auto core = static_cast(backend)->functions(); // TODO: use different pack from float int UNIT = pack; - + std::shared_ptr resource(new ResourceInt8); // TODO: ConvInt8Winograd need in/out scale, which isn't exist in quantinfo when model construct by V3 API const auto convCommon = convParam->common(); const auto group = convParam->common()->group(); const auto outputCount = convCommon->outputCount(); const auto outputChannleUp4 = UP_DIV(outputCount, UNIT) * UNIT; - - resource->mOriginBias.reset(Tensor::createDevice({outputChannleUp4})); - resource->mOriginScale.reset(Tensor::createDevice({outputChannleUp4})); + + int quanCount = outputChannleUp4; + if (convParam->quanParameter() && convParam->quanParameter()->alpha()) { + quanCount = convParam->quanParameter()->alpha()->size(); + quanCount = ROUND_UP(quanCount, UNIT); + } + resource->mOriginBias.reset(Tensor::createDevice({quanCount})); + resource->mOriginScale.reset(Tensor::createDevice({quanCount * core->bytes})); + resource->mWeightQuantZero.reset(Tensor::createDevice({quanCount})); auto allocRes = backend->onAcquireBuffer(resource->mOriginBias.get(), Backend::STATIC); allocRes &= backend->onAcquireBuffer(resource->mOriginScale.get(), Backend::STATIC); + allocRes &= backend->onAcquireBuffer(resource->mWeightQuantZero.get(), Backend::STATIC); if (!allocRes) { return nullptr; } auto biasPtr = resource->mOriginBias->host(); - memset(biasPtr, 0, outputChannleUp4 * sizeof(int32_t)); + memset(biasPtr, 0, quanCount * sizeof(int32_t)); auto scalePtr = resource->mOriginScale->host(); - memset(scalePtr, 0, outputChannleUp4 * sizeof(float)); + memset(scalePtr, 0, quanCount * sizeof(uint8_t) * core->bytes); + auto betaPtr = resource->mWeightQuantZero->host(); + memset(betaPtr, 0, quanCount * sizeof(int32_t)); - resource->mActBits = convParam->symmetricQuan()->nbits(); + resource->mActBits = 8; + if (convParam->symmetricQuan()) { + resource->mActBits = convParam->symmetricQuan()->nbits(); + } const int8_t* weightSrc = nullptr; int weightSize = 0; std::shared_ptr quanCommon; resource->mOutputCount = outputCount; - if (!ConvolutionCommon::getConvInt8Parameters(convParam, quanCommon, backend, weightSrc, weightSize, scalePtr, biasPtr)) { + if (!ConvolutionCommon::getConvInt8Parameters(convParam, quanCommon, backend, weightSrc, weightSize, scalePtr, biasPtr, betaPtr)) { return 
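// Worked form of the per-channel requantization terms computed above, with illustrative
// scalar types (the patch works on whole tensors):
//   scale   = weightScale * inputScale / outputScale
//   biasI32 = origBias / (inputScale * weightScale)
//             - kernelSum * (inputZero + offset) + outputZero / scale
//   biasF32 = biasI32 * scale
// offset is 128 on the SSE path and 0 elsewhere, as in the hunk.
#include <cmath>

void fuseRequantTerms(float weightScale, float inputScale, float outputScale,
                      float origBias, int kernelSum,
                      float inputZero, float outputZero, float offset,
                      float* scaleOut, int* biasI32Out, float* biasF32Out) {
    if (std::fabs(weightScale) < 1e-6f) {
        weightScale = 1e-6f;                 // same guard as the patch
    }
    float scale = weightScale * inputScale / outputScale;
    int zeroFused = (int)(outputZero / scale);
    int bias = (int)(origBias / (inputScale * weightScale))
             - (int)(kernelSum * (inputZero + offset)) + zeroFused;
    *scaleOut   = scale;
    *biasI32Out = bias;
    *biasF32Out = bias * scale;
}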
nullptr; } if (convParam->bias() && convParam->quanParameter()->alpha()) { resource->mUseConvQuan = false; } + if (quanCommon.get()) { + resource->mWeightAsymmetricQuant = quanCommon->asymmetric; + } + resource->mWeightInt8.reset(Tensor::createDevice({weightSize})); allocRes = backend->onAcquireBuffer(resource->mWeightInt8.get(), Backend::STATIC); if (!allocRes) { @@ -156,12 +202,16 @@ std::shared_ptr CPUConvolution::makeResourceInt8(B } const int kernelNum = outputCount; const int kernelSize = weightSize / kernelNum; - resource->mInt8WeightKernelSum.resize(kernelNum); + resource->mInt8WeightKernelSum.resize(outputChannleUp4); + bool checkWeightQuantZero = false; for (int i = 0; i < kernelNum; i++) { int temp = 0; int offset = i * kernelSize; + if (static_cast(betaPtr[i]) != 0) { + checkWeightQuantZero = true; + } for (int j = 0; j < kernelSize; j++) { - temp += int(weightSrc[offset + j]); + temp += (static_cast(weightSrc[offset + j]) - betaPtr[i]); } resource->mInt8WeightKernelSum[i] = temp; #ifdef MNN_USE_SSE @@ -170,10 +220,19 @@ std::shared_ptr CPUConvolution::makeResourceInt8(B } #endif } - resource->mInputZeroPoint = convParam->symmetricQuan()->zeroPoint(); - resource->mOutputZeroPoint = convParam->symmetricQuan()->outputZeroPoint(); - resource->mClampMin = convParam->symmetricQuan()->clampMin(); - resource->mClampMax = convParam->symmetricQuan()->clampMax(); + if (false == checkWeightQuantZero) { // All weight quant bias is 0, do not need to compute related term in gemm kernel. + resource->mWeightAsymmetricQuant = false; + } + resource->mInputZeroPoint = 0; + resource->mOutputZeroPoint = 0; + resource->mClampMin = -128; + resource->mClampMax = 127; + if (convParam->symmetricQuan()) { + resource->mInputZeroPoint = convParam->symmetricQuan()->zeroPoint(); + resource->mOutputZeroPoint = convParam->symmetricQuan()->outputZeroPoint(); + resource->mClampMin = convParam->symmetricQuan()->clampMin(); + resource->mClampMax = convParam->symmetricQuan()->clampMax(); + } if (convParam->quanParameter() != nullptr) { resource->mInputScale = convParam->quanParameter()->scaleIn(); resource->mOutputScale = convParam->quanParameter()->scaleOut(); @@ -181,9 +240,113 @@ std::shared_ptr CPUConvolution::makeResourceInt8(B auto weightDst = resource->mWeightInt8->host(); memcpy(weightDst, weightSrc, resource->mWeightInt8->size()); resource->mRelu = convCommon->relu() || convCommon->relu6(); + if (convParam->symmetricQuan() && convParam->symmetricQuan()->outputDataType() == MNN::DataType_DT_FLOAT) { + resource->mOutputZeroPoint = 0; + resource->mOutputScale = 1.0f; + } return resource; } +void CPUConvolution::makeResource(Backend* backend, std::shared_ptr resource, const Convolution2D* conv2d, std::shared_ptr resourceInt8) { + /* Used to compute weight quant scale and bias and weightKernelSum of type float. 
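// Sketch of the weight kernel-sum step above: with asymmetrically quantized weights the
// per-channel sum is taken over (w - weightZeroPoint), and if every channel's zero
// point turns out to be 0 the asymmetric handling can be dropped again
// (mWeightAsymmetricQuant = false). Standalone version with illustrative names.
#include <cstdint>
#include <vector>

bool kernelSumsMinusZero(const std::int8_t* weights, const int* weightZero,
                         int kernelNum, int kernelSize, std::vector<int>& sums) {
    bool anyNonZero = false;
    sums.assign(kernelNum, 0);
    for (int i = 0; i < kernelNum; ++i) {
        if (weightZero[i] != 0) {
            anyNonZero = true;
        }
        const std::int8_t* w = weights + i * kernelSize;
        int s = 0;
        for (int j = 0; j < kernelSize; ++j) {
            s += (int)w[j] - weightZero[i];
        }
        sums[i] = s;
    }
    return anyNonZero;   // false means the weights are effectively symmetric
}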
*/ + bool quanBuffer = (conv2d->quanParameter() != nullptr && conv2d->quanParameter()->buffer() != nullptr); + MNN_ASSERT(quanBuffer || resourceInt8); + resource->backend = backend; + auto core = static_cast(backend)->functions(); + // common parameters + int outputCount = conv2d->common()->outputCount(); + int LSize = conv2d->common()->inputCount() * conv2d->common()->kernelX() * conv2d->common()->kernelY(); + int ocUp4 = ROUND_UP(outputCount, core->pack); + int8_t* weightOrigin; + + // Save weight quant scale and bias: wf=scale*wi+bias + resource->mDequantize.mScaleBias.reset(Tensor::createDevice({2 * ocUp4 * core->bytes})); + auto success = resource->backend->onAcquireBuffer(resource->mDequantize.mScaleBias.get(), Backend::STATIC); + if (!success) { + MNN_ERROR("Alloc denquant scaleBias memory error\n"); + return; + } + auto alphaPtr = resource->mDequantize.mScaleBias->host(); + auto biasPtr = reinterpret_cast(reinterpret_cast(alphaPtr) + ocUp4 * core->bytes); + ::memset(alphaPtr, 0, 2 * ocUp4 * core->bytes); + + std::shared_ptr quantCommon; + // Load quant scale and bias + if (quanBuffer) { + quantCommon = ConvolutionCommon::load(conv2d, backend, false, true); + weightOrigin = quantCommon->weight.get(); // weight before reorder + + int h = quantCommon->alpha.size(); + if (core->bytes == 2) { + if (quantCommon->asymmetric) { + std::unique_ptr tmp(new int16_t[h]); + core->MNNFp32ToLowp(quantCommon->alpha.get(), tmp.get(), h); + for (int i=0; i< h/2; ++i) { + reinterpret_cast(alphaPtr)[i] = tmp[2 * i + 1]; + reinterpret_cast(biasPtr)[i] = tmp[2 * i]; + } + } else { + core->MNNFp32ToLowp(quantCommon->alpha.get(), reinterpret_cast(alphaPtr), h); + } + } else { + if (quantCommon->asymmetric) { + h = h / 2; + for (int i=0; ialpha.get()[2 * i + 1]; + biasPtr[i] = quantCommon->alpha.get()[2 * i]; + } + } else { + for (int i=0; ialpha.get()[i]; + biasPtr[i] = 0.f; + } + } + } + } else { + weightOrigin = resourceInt8->mWeightInt8->host(); + auto wZero = resourceInt8->mWeightQuantZero->host(); // has packed to outputUp4 + auto wScale = resourceInt8->mOriginScale->host(); + int h = ocUp4; + if (core->bytes == 2) { + std::unique_ptr tmp(new int16_t[h]); + core->MNNFp32ToLowp(wScale, tmp.get(), h); + for (int i=0; i< h; ++i) { + reinterpret_cast(alphaPtr)[i] = tmp[i]; + reinterpret_cast(biasPtr)[i] = (-1.f) * wZero[i] * tmp[i]; + } + } else { + for (int i=0; i< h; ++i) { + alphaPtr[i] = wScale[i]; + biasPtr[i] = (-1.f) * wZero[i] * wScale[i]; + } + } + } + + // Compute float weightKernelSum + resource->mWeightKernelSum.reset(Tensor::createDevice({ocUp4 * 4})); + success = resource->backend->onAcquireBuffer(resource->mWeightKernelSum.get(), Backend::STATIC); + if (!success) { + MNN_ERROR("Alloc denquant mWeightKernelSum memory error\n"); + return; + } + auto weightKernelSum = resource->mWeightKernelSum->host(); + for (int i = 0; i < outputCount; ++i) { + int sum = 0; + for (int j = 0; j < LSize; ++j) { + sum = sum + static_cast(weightOrigin[j + i * LSize]); + } + if(core->bytes == 2) { + auto scale = reinterpret_cast(alphaPtr)[i]; + auto bias = reinterpret_cast(biasPtr)[i]; + weightKernelSum[i] = static_cast(sum) * scale + LSize * bias; + } else { + auto scale = alphaPtr[i]; + auto bias = biasPtr[i]; + weightKernelSum[i] = static_cast(sum) * scale + LSize * bias; + } + } +} + CPUConvolution::CPUConvolution(const Convolution2DCommon *convOp, Backend *b) : MNN::Execution(b), mCommon(convOp) { // Do nothing } @@ -245,7 +408,7 @@ class CPUConvInt8Creator : public CPUBackend::Creator { if 
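// Sketch of the float kernel sum computed above: with the affine dequantization
// wf = scale * wi + bias, summing a channel's float weights over L elements reduces to
// scale * sum(wi) + L * bias, so only the integer sum plus the per-channel (scale, bias)
// pair is needed. Standalone arithmetic mirroring the patch.
#include <cstdint>

float floatKernelSum(const std::int8_t* wRow, int L, float scale, float bias) {
    long long s = 0;
    for (int j = 0; j < L; ++j) {
        s += wRow[j];
    }
    return (float)s * scale + (float)L * bias;
}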
(ConvInt8Winograd::mustUse(convOp)) { return new ConvInt8Winograd(backend, convOp, res); } - return new DenseConvInt8TiledExecutor(backend, convOp, res); + return new DenseConvInt8TiledExecutor(backend, convOp, res, false); } }; diff --git a/source/backend/cpu/CPUConvolution.hpp b/source/backend/cpu/CPUConvolution.hpp index d79c3ee73..d241007d6 100644 --- a/source/backend/cpu/CPUConvolution.hpp +++ b/source/backend/cpu/CPUConvolution.hpp @@ -45,6 +45,7 @@ class CPUConvolution : public Execution { std::shared_ptr mScaleBias; }; struct Resource { + std::shared_ptr mWeightKernelSum; std::shared_ptr mWeight; std::shared_ptr mBias; ResourceDequantizeInfo mDequantize; @@ -54,18 +55,21 @@ class CPUConvolution : public Execution { int lU; int lP; int hP; + std::vector mReluThreshold; }; struct ResourceInt8 { std::vector mInt8WeightKernelSum; std::shared_ptr mWeightInt8; std::shared_ptr mOriginBias; std::shared_ptr mOriginScale; + std::shared_ptr mWeightQuantZero; // relu or relu6 bool mRelu; int mActBits; int mOutputCount; bool mUseConvQuan = true; + bool mWeightAsymmetricQuant = true; #ifdef MNN_USE_SSE std::vector offsets; #endif @@ -89,10 +93,12 @@ class CPUConvolution : public Execution { int8_t mClampMax; std::shared_ptr mBiasInt32; std::shared_ptr mScaleFloat; + std::shared_ptr mBiasFloat; int32_t mShiftBits = 14; bool mValid; }; static std::shared_ptr makeResourceInt8(Backend *backend, const MNN::Convolution2D *convOp, int pack=4); + static void makeResource(Backend* backend, std::shared_ptr resource, const Convolution2D* conv2d, std::shared_ptr resourceInt8 = nullptr); CPUConvolution(const Convolution2DCommon *convOp, Backend *b); virtual ~CPUConvolution() = default; virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; diff --git a/source/backend/cpu/CPUConvolutionDepthwise.cpp b/source/backend/cpu/CPUConvolutionDepthwise.cpp index a3b129fa2..03767edfa 100644 --- a/source/backend/cpu/CPUConvolutionDepthwise.cpp +++ b/source/backend/cpu/CPUConvolutionDepthwise.cpp @@ -187,7 +187,8 @@ ErrorCode CPUConvolutionDepthwise::BasicFloatExecution::onResize(const std::vect auto postData = getPostParameters(); auto batch = inputs[0]->batch(); int total = batch * dst_depth_quad; - int numberThread = std::min(((CPUBackend*)backend())->threadNumber(), total); + int numberThread = ((CPUBackend*)backend())->threadNumber(); + auto rt = static_cast(backend()->getRuntime()); auto runBasic = [=](uint8_t* dst_z, const uint8_t* src_z, const uint8_t* weight_dz, int L, int T, int R, int B) { for (int dy = T; dy < B; ++dy) { auto dst_y = dst_z + dy * dst_y_step * bytes; @@ -207,10 +208,13 @@ ErrorCode CPUConvolutionDepthwise::BasicFloatExecution::onResize(const std::vect } } }; + std::vector divides(numberThread+1); + divides[0] = 0; + rt->computeDivideSizes(total, divides.data()+1); mExecutor = [=](const uint8_t* srcOrigin, uint8_t* dstOrigin, int tId) { auto biasP = inputs[2]->host(); auto weightP = inputs[1]->host(); - for (int index = tId; index < total; index += numberThread) { + for (int index = divides[tId]; index < divides[tId+1]; ++index) { int dz = index / batch; auto dst_z = dstOrigin + dst_z_step * index * bytes; const auto src_z = srcOrigin + src_z_step * index * bytes; diff --git a/source/backend/cpu/CPUDeconvolution.cpp b/source/backend/cpu/CPUDeconvolution.cpp index 9968e6198..0a1e6f813 100644 --- a/source/backend/cpu/CPUDeconvolution.cpp +++ b/source/backend/cpu/CPUDeconvolution.cpp @@ -87,8 +87,10 @@ static void _transformWeight(const uint8_t* tempWeight, 
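// Sketch of the work-partitioning change applied above: instead of striding indices by
// the thread count (index = tId, tId + N, ...), each thread walks a contiguous range
// [divides[tId], divides[tId+1]) so that computeDivideSizes can hand larger ranges to
// faster core groups. Standalone illustration using a plain even split as a stand-in.
#include <vector>

void runPartitioned(int total, int numberThread, void (*workItem)(int index)) {
    std::vector<int> divides(numberThread + 1, 0);
    for (int t = 1; t <= numberThread; ++t) {
        divides[t] = (int)((long long)total * t / numberThread);  // even split stand-in
    }
    for (int tId = 0; tId < numberThread; ++tId) {                // parallel in the real code
        for (int index = divides[tId]; index < divides[tId + 1]; ++index) {
            workItem(index);
        }
    }
}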
uint8_t* dest, int outpu static void _reorderWeightInt8(Backend* bn, const Convolution2DCommon* common, const int8_t* srcPtr, std::shared_ptr& weight) { auto core = static_cast(bn)->int8Functions(); + auto gcore = static_cast(bn)->functions(); int UNIT, SRC_UNIT, DST_XUNIT; core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); + UNIT = gcore->pack; int oc = common->outputCount(), ic = common->inputCount(), kernelCount = common->kernelX() * common->kernelY(); std::vector shape = {UP_DIV(oc, UNIT), UP_DIV(ic, SRC_UNIT) * kernelCount, UNIT, SRC_UNIT}; @@ -167,11 +169,13 @@ CPUDeconvolution::CPUDeconvolution(const Tensor* input, const Op* convOp, Backen std::vector _bias(outputChannleUp4, 0); std::vector _scale(outputChannleUp4, 0); + std::vector _beta(outputChannleUp4, 0); auto biasPtr = _bias.data(); auto scalePtr = _scale.data(); + auto betaPtr = _beta.data(); if (ModeInt8) { - ConvolutionCommon::getConvInt8Parameters(conv2d, quanCommon, backend, quanWeightInt8, tempWeightSize, scalePtr, biasPtr); + ConvolutionCommon::getConvInt8Parameters(conv2d, quanCommon, backend, quanWeightInt8, tempWeightSize, scalePtr, biasPtr, betaPtr); } else { ConvolutionCommon::getConvParameters(&quanCommon, backend, conv2d, &tempWeight, &tempWeightSize); } diff --git a/source/backend/cpu/CPUDeconvolution.hpp b/source/backend/cpu/CPUDeconvolution.hpp index 750fc4816..c9e0427f0 100644 --- a/source/backend/cpu/CPUDeconvolution.hpp +++ b/source/backend/cpu/CPUDeconvolution.hpp @@ -78,7 +78,7 @@ class CPUDeconvolutionOrigin : public CPUDeconvolutionBasic { } } #else - if(conv2d->symmetricQuan()->method() == QuantizeAlgo_OVERFLOW_AWARE){ + if(conv2d->symmetricQuan() && conv2d->symmetricQuan()->method() == QuantizeAlgo_OVERFLOW_AWARE){ gemmKernel = core->Int8GemmKernelFast; } #endif diff --git a/source/backend/cpu/CPUMatMul.cpp b/source/backend/cpu/CPUMatMul.cpp index 1ad8ff4aa..069a77965 100644 --- a/source/backend/cpu/CPUMatMul.cpp +++ b/source/backend/cpu/CPUMatMul.cpp @@ -89,8 +89,12 @@ ErrorCode CPUMatMul::onResize(const std::vector& inputs, const std::vec core->MNNGetMatMulPackMode(&eP, &lP, &hP); int numberThread = mSupportMultiThread ? ((CPUBackend*)backend())->threadNumber() : 1; auto bufferAlloc = static_cast(backend())->getBufferAllocator(); - auto ATPtrAlloc = bufferAlloc->alloc(eP * l * core->bytes * numberThread); - auto BTPtrAlloc = bufferAlloc->alloc(UP_DIV(h, hP) * UP_DIV(l, lP) * lP * hP * core->bytes); + auto ATPtrAlloc = bufferAlloc->alloc(eP * UP_DIV(l, lP) * lP * core->bytes * numberThread); + int matmulBytes = core->bytes; + if (core->matmulBytes != 0) { + matmulBytes = core->matmulBytes; + } + auto BTPtrAlloc = bufferAlloc->alloc(UP_DIV(h, hP) * UP_DIV(l, lP) * lP * hP * matmulBytes); auto CTPtrAlloc = bufferAlloc->alloc(UP_DIV(h, core->pack) * eP * core->pack * core->bytes * numberThread); if (ATPtrAlloc.invalid() || BTPtrAlloc.invalid() || CTPtrAlloc.invalid()) { return OUT_OF_MEMORY; @@ -180,10 +184,11 @@ void CPUMatMul::execute(const float* APtr, const float* BPtr, float* CPtr, const if (nullptr == biasPtr) { postPtr = nullptr; } + auto lAlign = UP_DIV(mL, lP) * lP; int tileCount = UP_DIV(mE, eP); int numberThread = mSupportMultiThread ? 
((CPUBackend*)backend())->threadNumber() : 1; MNN_CONCURRENCY_BEGIN(tId, numberThread) { - auto TA = mTempA.ptr() + tId * eP * mL * core->bytes; + auto TA = mTempA.ptr() + tId * eP * lAlign * core->bytes; auto TB = mTempB.ptr(); auto hC4 = UP_DIV(mH, core->pack); auto TC = mTempC.ptr() + tId * eP * hC4 * core->pack * core->bytes; @@ -199,27 +204,78 @@ void CPUMatMul::execute(const float* APtr, const float* BPtr, float* CPtr, const int xEnd = ALIMIN(xStart + eP, mE); int xC = xEnd - xStart; if (mTransposeA) { - for (int y=0; ybytes, (uint8_t*)APtr + (y * mE + xStart) * core->bytes, core->bytes * xC); + // l, e -> l/lp, xC|eP, lp + if (lP > 1) { + // TODO: Speed up it + if (mL % lP != 0) { + ::memset(TA, 0, eP * lAlign * core->bytes); + } + if (core->bytes == 4) { + auto D = (int32_t*)TA; + auto S = (int32_t*)APtr; + for (int y=0; ybytes == 2); + auto D = (int16_t*)TA; + auto S = (int16_t*)APtr; + for (int y=0; ybytes, (uint8_t*)APtr + (y * mE + xStart) * core->bytes, core->bytes * xC); + } } } else { - // e, l -> l, eP - int dims[] = { - xC, - mL, - mL, - eP - }; - if (core->bytes == 2) { - auto S = (const int16_t*)APtr + xStart * mL; - auto D = (int16_t*)TA; - MNNTranspose16Bit(D, S, dims); - } else if (core->bytes == 4) { - auto S = (const int32_t*)APtr + xStart * mL; - auto D = (int32_t*)TA; - MNNTranspose32Bit(D, S, dims); + if (lP > 1) { + // e, l -> l/lp, 1, xC|eP, lp + int lC = mL / lP; + int lR = mL % lP; + for (int yy=0; yybytes, (uint8_t*)APtr + ((x+xStart)*mL+yy*lP)*core->bytes, lP * core->bytes); + } + } + if (lR > 0) { + int yy = lC; + for (int x=0; xbytes, 0, lP * core->bytes); + ::memcpy(TA + (yy * eP * lP + x * lP) * core->bytes, (uint8_t*)APtr + ((x+xStart)*mL+yy*lP)*core->bytes, xC * core->bytes); + } + } + } else { + // e, l -> l, eP + int dims[] = { + xC, + mL, + mL, + eP + }; + if (core->bytes == 2) { + auto S = (const int16_t*)APtr + xStart * mL; + auto D = (int16_t*)TA; + MNNTranspose16Bit(D, S, dims); + } else if (core->bytes == 4) { + auto S = (const int32_t*)APtr + xStart * mL; + auto D = (int32_t*)TA; + MNNTranspose32Bit(D, S, dims); + } } } + if (core->matmulBytes != 0) { + core->MNNFp32ToLowp((const float*)TA, (int16_t*)TA, eP * lAlign); + } if (xC == eP) { core->MNNPackedMatMul((float*)TC, (float*)TA, (float*)TB, parameters, postPtr, biasPtr, nullptr, nullptr); } else { diff --git a/source/backend/cpu/CPURelu.cpp b/source/backend/cpu/CPURelu.cpp index fb5aaa335..073556464 100644 --- a/source/backend/cpu/CPURelu.cpp +++ b/source/backend/cpu/CPURelu.cpp @@ -179,10 +179,6 @@ ErrorCode CPUPRelu::onResize(const std::vector& inputs, const std::vect ssize_t outputZero = static_cast(TensorUtils::getDescribe(outputs[0])->quantAttr->zero); ssize_t maxValue = static_cast(TensorUtils::getDescribe(inputs[0])->quantAttr->max); ssize_t minValue = static_cast(TensorUtils::getDescribe(inputs[0])->quantAttr->min); - float inputScales[1] = {inputScale}; - float outputScales[1] = {outputScale}; - ssize_t inputZeros[1] = {inputZero}; - ssize_t outputZeros[1] = {outputZero}; mQuanScalesInput.resize(1); mQuanScalesOutput.resize(1); mQuanZerosInput.resize(1); @@ -210,13 +206,14 @@ ErrorCode CPUPRelu::onExecute(const std::vector& inputs, const std::vec auto coreInt8 = static_cast(backend())->int8Functions(); const int channel = ib.dim[1].extent; const int batch = ib.dim[0].extent; - const int depthQuad = UP_DIV(channel, core->pack); + int pack = 4; + int depthQuad = UP_DIV(channel, core->pack); const uint8_t* srcO = (const uint8_t*)ib.host; uint8_t* dstO = (uint8_t*)ob.host; auto 
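// Sketch of the lP-blocked packing added above: A is rearranged into [l/lP, eP, lP]
// blocks and, when l is not a multiple of lP, the last depth block is zero-padded so the
// packed GEMM can always read full lP lanes. Standalone float version for one eP tile,
// assuming row-major A of shape [e, l]; the patch also handles the transposed layout and
// the low-precision storage path.
#include <cstring>
#include <vector>

#define UP_DIV(x, y) (((x) + (y) - 1) / (y))

void packATile(const float* A, int e, int l, int eP, int lP, std::vector<float>& out) {
    int lBlocks = UP_DIV(l, lP);
    out.assign((size_t)lBlocks * eP * lP, 0.f);           // zero fill covers the padding
    for (int x = 0; x < e && x < eP; ++x) {
        for (int yb = 0; yb < lBlocks; ++yb) {
            int copy = ((yb + 1) * lP <= l) ? lP : (l - yb * lP);
            std::memcpy(&out[((size_t)yb * eP + x) * lP],
                        &A[(size_t)x * l + yb * lP],
                        copy * sizeof(float));
        }
    }
}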
totalCount = batch * depthQuad; auto numberThread = ((CPUBackend*)backend())->threadNumber(); if (mUseInt8) { - + depthQuad = UP_DIV(channel, pack); MNN_CONCURRENCY_BEGIN(tId, numberThread) { QuanPrePostParameters params; params.maxValue = static_cast(TensorUtils::getDescribe(inputs[0])->quantAttr->max); @@ -227,7 +224,7 @@ ErrorCode CPUPRelu::onExecute(const std::vector& inputs, const std::vec params.outputZeroPoint = mQuanZerosOutput.data(); for (int b=tId; bMNNReluWithSlopeChannelInt8((int8_t*)(dstO + sizeQuad * core->pack * b), (const int8_t*)(srcO + sizeQuad * core->pack * b), (const float*)(mSlope.host() + core->bytes * core->pack * c), sizeQuad, 1, ¶ms); + coreInt8->MNNReluWithSlopeChannelInt8((int8_t*)(dstO + sizeQuad * pack * b), (const int8_t*)(srcO + sizeQuad * pack * b), (const float*)(mSlope.host() + core->bytes * pack * c), sizeQuad, 1, ¶ms); } } MNN_CONCURRENCY_END(); diff --git a/source/backend/cpu/CPURuntime.cpp b/source/backend/cpu/CPURuntime.cpp index 0cf9a336d..98b7d04f5 100644 --- a/source/backend/cpu/CPURuntime.cpp +++ b/source/backend/cpu/CPURuntime.cpp @@ -11,19 +11,33 @@ https://github.com/Tencent/ncnn/blob/master/src/cpu.cpp https://github.com/pytorch/cpuinfo */ -#ifdef __ANDROID__ +#ifdef __linux__ #include #include #include +#include +#include +#include +#include +#include +#include + +#define CPUINFO_ARM_LINUX_FEATURE_FPHP UINT32_C(0x00000200) +#define CPUINFO_ARM_LINUX_FEATURE_ASIMDHP UINT32_C(0x00000400) +#define CPUINFO_ARM_LINUX_FEATURE_ASIMDDP UINT32_C(0x00100000) +// ref: https://cs.android.com/android/platform/superproject/+/master:bionic/libc/kernel/uapi/asm-arm64/asm/hwcap.h;drc=04da58f5b3bc40dbbafb4f8422aa2991479d9e1e;l=70 +#define CPUINFO_ARM_LINUX_FEATURE_I8MM UINT32_C(0x00002000) +#define CPUINFO_ARM_LINUX_FEATURE_SVE UINT32_C(0x00400000) +#define CPUINFO_ARM_LINUX_FEATURE_SVE2 UINT32_C(0x00000002) #endif -#include "core/Macro.h" +#include +#include +#include "core/Macro.h" #ifdef __ANDROID__ -#include -#include #include -#endif // __ANDROID__ +#endif #if __APPLE__ #include "TargetConditionals.h" @@ -37,30 +51,68 @@ #endif // TARGET_OS_IPHONE #endif // __APPLE__ -#ifdef _OPENMP -#include -#endif // _OPENMP - #include #include #include #include #include #include "backend/cpu/CPURuntime.hpp" +#include "core/FileLoader.hpp" -#if defined (__linux__) && defined (__aarch64__) -#include +#define BUFFER_SIZE 1024 -#define CPUINFO_ARM_LINUX_FEATURE_FPHP UINT32_C(0x00000200) -#define CPUINFO_ARM_LINUX_FEATURE_ASIMDHP UINT32_C(0x00000400) -#define CPUINFO_ARM_LINUX_FEATURE_ASIMDDP UINT32_C(0x00100000) -#define CPUINFO_ARM_LINUX_FEATURE_I8MM UINT32_C(0x00002000) -#define CPUINFO_ARM_LINUX_FEATURE_SVE UINT32_C(0x00400000) -#define CPUINFO_ARM_LINUX_FEATURE_SVE2 UINT32_C(0x00000002) +int MNNGetCurrentPid() { +#if defined (__linux__) +#ifdef __GLIBC__ + pid_t pid = syscall(SYS_gettid); +#else +#ifdef PI3 + pid_t pid = getpid(); +#else + pid_t pid = gettid(); +#endif +#endif + return pid; +#else + return 0; +#endif +} +int MNNSetSchedAffinity(const int* cpuIDs, int size) { +#if defined (__linux__) +#ifndef CPU_SETSIZE +#define CPU_SETSIZE 1024 +#endif +#define __NCPUBITS (8 * sizeof(unsigned long)) + typedef struct { + unsigned long __bits[CPU_SETSIZE / __NCPUBITS]; + } cpu_set_t; -#endif /* __linux__ && __aarch64__ */ +#ifndef CPU_SET +#define CPU_SET(cpu, cpusetp) ((cpusetp)->__bits[(cpu) / __NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS))) +#endif +#ifndef CPU_ZERO +#define CPU_ZERO(cpusetp) memset((cpusetp), 0, sizeof(cpu_set_t)) +#endif + // set affinity for 
thread + pid_t pid = MNNGetCurrentPid(); + cpu_set_t mask; + CPU_ZERO(&mask); + for (int i = 0; i < size; i++) { + CPU_SET(cpuIDs[i], &mask); + } -#ifdef __ANDROID__ + int syscallret = syscall(__NR_sched_setaffinity, pid, sizeof(mask), &mask); + if (syscallret) { + MNN_PRINT("syscall error %d\n", syscallret); + return -1; + } +#endif + return 0; +} + +// cpuinfo +// Reference from: https://github.com/pytorch/cpuinfo +#if defined(ENABLE_ARMV82) && defined(__arm__) /* As per include/sys/system_properties.h in Android NDK */ #define CPUINFO_HARDWARE_VALUE_MAX 64 @@ -154,231 +206,6 @@ struct cpuinfo_arm_chipset { char suffix[8]; }; -#define BUFFER_SIZE 1024 - -static uint32_t getNumberOfCPU() { - FILE* fp = fopen("/proc/cpuinfo", "rb"); - if (!fp) { - return 1; - } - uint32_t number = 0; - char buffer[BUFFER_SIZE]; - while (!feof(fp)) { - char* str = fgets(buffer, BUFFER_SIZE, fp); - if (!str) { - break; - } - if (memcmp(buffer, "processor", 9) == 0) { - number++; - } - } - fclose(fp); - if (number < 1) { - number = 1; - } - return number; -} - -static int getCPUMaxFreqKHz(int cpuID) { - char path[256]; - sprintf(path, "/sys/devices/system/cpu/cpufreq/stats/cpu%d/time_in_state", cpuID); - FILE* fp = fopen(path, "rb"); - if (!fp) { - sprintf(path, "/sys/devices/system/cpu/cpu%d/cpufreq/stats/time_in_state", cpuID); - fp = fopen(path, "rb"); - if (!fp) { - sprintf(path, "/sys/devices/system/cpu/cpu%d/cpufreq/cpuinfo_max_freq", cpuID); - fp = fopen(path, "rb"); - if (!fp) { - return -1; - } - int maxfrequency = -1; - fscanf(fp, "%d", &maxfrequency); - fclose(fp); - return maxfrequency; - } - } - int maxfrequency = 0; - while (!feof(fp)) { - int frequency = 0; - int history = fscanf(fp, "%d %*d", &frequency); - if (history != 1) { - break; - } - if (frequency > maxfrequency) { - maxfrequency = frequency; - } - } - fclose(fp); - return maxfrequency; -} - -static int sortCPUIDByMaxFrequency(std::vector& cpuIDs, int* littleClusterOffset) { - const int cpuNumbers = cpuIDs.size(); - *littleClusterOffset = 0; - if (cpuNumbers == 0) { - return 0; - } - std::vector cpusFrequency; - cpusFrequency.resize(cpuNumbers); - for (int i = 0; i < cpuNumbers; ++i) { - int frequency = getCPUMaxFreqKHz(i); - cpuIDs[i] = i; - cpusFrequency[i] = frequency; - // MNN_PRINT("cpu fre: %d, %d\n", i, frequency); - } - for (int i = 0; i < cpuNumbers; ++i) { - for (int j = i + 1; j < cpuNumbers; ++j) { - if (cpusFrequency[i] < cpusFrequency[j]) { - // id - int temp = cpuIDs[i]; - cpuIDs[i] = cpuIDs[j]; - cpuIDs[j] = temp; - // frequency - temp = cpusFrequency[i]; - cpusFrequency[i] = cpusFrequency[j]; - cpusFrequency[j] = temp; - } - } - } - int midMaxFrequency = (cpusFrequency.front() + cpusFrequency.back()) / 2; - if (midMaxFrequency == cpusFrequency.back()) { - return 0; - } - for (int i = 0; i < cpuNumbers; ++i) { - if (cpusFrequency[i] < midMaxFrequency) { - *littleClusterOffset = i; - break; - } - } - return 0; -} - -static int setSchedAffinity(const std::vector& cpuIDs) { -#define CPU_SETSIZE 1024 -#define __NCPUBITS (8 * sizeof(unsigned long)) - typedef struct { - unsigned long __bits[CPU_SETSIZE / __NCPUBITS]; - } cpu_set_t; - -#define CPU_SET(cpu, cpusetp) ((cpusetp)->__bits[(cpu) / __NCPUBITS] |= (1UL << ((cpu) % __NCPUBITS))) - -#define CPU_ZERO(cpusetp) memset((cpusetp), 0, sizeof(cpu_set_t)) - - // set affinity for thread -#ifdef __GLIBC__ - pid_t pid = syscall(SYS_gettid); -#else -#ifdef PI3 - pid_t pid = getpid(); -#else - pid_t pid = gettid(); -#endif -#endif - cpu_set_t mask; - CPU_ZERO(&mask); - for (int i = 
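// Usage sketch for the affinity helper defined above: pin the calling thread to an
// explicit list of core ids (for example, the big cluster collected by _bindCPUCore).
// The prototype matches the definition in this patch; the forward declaration and the
// core ids below are only for the sketch.
int MNNSetSchedAffinity(const int* cpuIDs, int size);    // defined in CPURuntime.cpp above

int pinThreadToBigCores() {
    const int bigCores[] = {4, 5, 6, 7};                 // illustrative core ids
    // Returns 0 on success; on Linux a failed sched_setaffinity syscall yields -1, and
    // non-Linux builds compile the helper down to "return 0".
    return MNNSetSchedAffinity(bigCores, 4);
}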
0; i < (int)cpuIDs.size(); i++) { - CPU_SET(cpuIDs[i], &mask); - } - - int syscallret = syscall(__NR_sched_setaffinity, pid, sizeof(mask), &mask); - if (syscallret) { - MNN_PRINT("syscall error %d\n", syscallret); - return -1; - } - - return 0; -} - -#endif // arch - -int MNNSetCPUThreadsMode(MNNCPUThreadsMode mode) { -#ifdef __ANDROID__ - auto numberOfCPUs = getNumberOfCPU(); - if (mode == MNN_CPU_MODE_DEFAULT) { - return 0; - } - static std::vector sortedCPUIDs; - static int littleClusterOffset = 0; - if (sortedCPUIDs.empty()) { - sortedCPUIDs.resize(numberOfCPUs); - for (int i = 0; i < numberOfCPUs; ++i) { - sortedCPUIDs[i] = i; - } - sortCPUIDByMaxFrequency(sortedCPUIDs, &littleClusterOffset); - } - - if (littleClusterOffset == 0 && mode != MNN_CPU_MODE_POWER_FRI) { - MNN_PRINT("This CPU Arch Do NOT support for setting cpu thread mode\n"); - } - std::vector cpuAttachIDs; - switch (mode) { - case MNN_CPU_MODE_POWER_FRI: - cpuAttachIDs = sortedCPUIDs; - break; - case MNN_CPU_MODE_LITTLE: - cpuAttachIDs = std::vector(sortedCPUIDs.begin() + littleClusterOffset, sortedCPUIDs.end()); - break; - case MNN_CPU_MODE_BIG: - cpuAttachIDs = std::vector(sortedCPUIDs.begin(), sortedCPUIDs.begin() + littleClusterOffset); - break; - default: - cpuAttachIDs = sortedCPUIDs; - break; - } - -#ifdef _OPENMP - const int threadsNumber = cpuAttachIDs.size(); - omp_set_num_threads(threadsNumber); - std::vector result(threadsNumber, 0); -#pragma omp parallel for - for (int i = 0; i < threadsNumber; ++i) { - result[i] = setSchedAffinity(cpuAttachIDs); - } - for (int i = 0; i < threadsNumber; ++i) { - if (result[i] != 0) { - return -1; - } - } -#else - int res = setSchedAffinity(cpuAttachIDs); - if (res != 0) { - return -1; - } -#endif // _OPENMP - return 0; -#elif __IOS__ - return -1; -#else - return -1; -#endif // arch -} -float MNNGetCPUFlops(uint32_t number) { - float flops = 2048.0f; -#ifdef __ANDROID__ - auto numberOfCPUs = getNumberOfCPU(); - if (0 == numberOfCPUs) { - return flops; - } - std::vector freqs; - freqs.resize(numberOfCPUs); - for (int i = 0; i < numberOfCPUs; ++i) { - freqs[i] = getCPUMaxFreqKHz(i); - } - std::sort(freqs.rbegin(), freqs.rend()); - number = std::min(number, numberOfCPUs); - flops = 0.0f; - for (uint32_t i = 0; i < number; ++i) { - flops += (float)freqs[i] / 1024.0f; - } -#endif - return flops; -} - -// cpuinfo -// Reference from: https://github.com/pytorch/cpuinfo -#ifdef __ANDROID__ - #define CPUINFO_ARM_MIDR_IMPLEMENTER_MASK UINT32_C(0xFF000000) #define CPUINFO_ARM_MIDR_VARIANT_MASK UINT32_C(0x00F00000) #define CPUINFO_ARM_MIDR_ARCHITECTURE_MASK UINT32_C(0x000F0000) @@ -400,19 +227,6 @@ float MNNGetCPUFlops(uint32_t number) { #define CPUINFO_ARM_MIDR_PART_OFFSET 4 #define CPUINFO_ARM_MIDR_REVISION_OFFSET 0 -#ifdef __aarch64__ -#define CPUINFO_ARM_LINUX_FEATURE_FPHP UINT32_C(0x00000200) -#define CPUINFO_ARM_LINUX_FEATURE_ASIMDHP UINT32_C(0x00000400) -#define CPUINFO_ARM_LINUX_FEATURE_ASIMDDP UINT32_C(0x00100000) -// ref: https://cs.android.com/android/platform/superproject/+/master:bionic/libc/kernel/uapi/asm-arm64/asm/hwcap.h;drc=04da58f5b3bc40dbbafb4f8422aa2991479d9e1e;l=70 -#define CPUINFO_ARM_LINUX_FEATURE_I8MM UINT32_C(0x00002000) -#define CPUINFO_ARM_LINUX_FEATURE_SVE UINT32_C(0x00400000) -#define CPUINFO_ARM_LINUX_FEATURE_SVE2 UINT32_C(0x00000002) -#else -#define CPUINFO_ARM_LINUX_FEATURE_HALF UINT32_C(0x00000002) -#define CPUINFO_ARM_LINUX_FEATURE_NEON UINT32_C(0x00001000) -#endif - struct cpuinfo_arm_linux_processor { uint32_t architecture_version; // Main ID Register 
value @@ -1308,39 +1122,18 @@ struct cpuinfo_arm_chipset cpuinfo_arm_android_decode_chipset(const struct cpuin // MNN_PRINT("chipset vendor, series, model is: %d, %d, %d\n", chipset.vendor, chipset.series, chipset.model); return chipset; } - -#endif // __ANDROID__ - -#if defined(__APPLE__) && defined(__aarch64__) - -static uint32_t get_sys_info_by_name(const char* type_specifier) { - size_t size = 0; - uint32_t result = 0; - if (sysctlbyname(type_specifier, NULL, &size, NULL, 0) != 0) { - MNN_PRINT("sysctlbyname(\"%s\") failed\n", type_specifier); - } else if (size == sizeof(uint32_t)) { - sysctlbyname(type_specifier, &result, &size, NULL, 0); - MNN_PRINT("%s: %u , size = %lu\n", type_specifier, result, size); - } else { - MNN_PRINT("sysctl does not support non-integer lookup for (\"%s\")\n", type_specifier); - } - return result; -} - -#endif // iOS - -void cpuinfo_arm_init(struct cpuinfo_arm_isa* cpuinfo_isa) { - memset(cpuinfo_isa, 0, sizeof(struct cpuinfo_arm_isa)); - - // android -#ifdef __ANDROID__ +static void _getInfoARMv7(MNNCPUInfo* cpuinfo_isa) { + // Get White List And Black List struct cpuinfo_arm_linux_processor* arm_linux_processors = NULL; - const uint32_t processors_count = getNumberOfCPU(); + if (0 == cpuinfo_isa->groups.size()) { + return; + } + const uint32_t processors_count = cpuinfo_isa->allCpuIdsSorted.size(); char proc_cpuinfo_hardware[CPUINFO_HARDWARE_VALUE_MAX] = {0}; arm_linux_processors = static_cast( - calloc(processors_count, sizeof(struct cpuinfo_arm_linux_processor))); + malloc(processors_count * sizeof(struct cpuinfo_arm_linux_processor))); if (arm_linux_processors == NULL) { MNN_PRINT("failed to allocate %zu bytes for descriptions of %u ARM logical processors\n", processors_count * sizeof(struct cpuinfo_arm_linux_processor), processors_count); @@ -1349,6 +1142,7 @@ void cpuinfo_arm_init(struct cpuinfo_arm_isa* cpuinfo_isa) { if (!cpuinfo_arm_linux_parse_proc_cpuinfo(proc_cpuinfo_hardware, processors_count, arm_linux_processors)) { MNN_PRINT("failed to parse processor information from /proc/cpuinfo\n"); + free(arm_linux_processors); return; } @@ -1369,54 +1163,17 @@ void cpuinfo_arm_init(struct cpuinfo_arm_isa* cpuinfo_isa) { } } } - - uint32_t isa_features = 0; -#ifdef __aarch64__ - isa_features = (uint32_t)getauxval(AT_HWCAP); -#endif - struct cpuinfo_android_properties android_properties; cpuinfo_arm_android_parse_properties(&android_properties); const struct cpuinfo_arm_chipset chipset = cpuinfo_arm_android_decode_chipset(&android_properties, valid_processors, 0); - - switch (last_midr & (CPUINFO_ARM_MIDR_IMPLEMENTER_MASK | CPUINFO_ARM_MIDR_PART_MASK)) { - case UINT32_C(0x51008040): /* Kryo 485 Gold (Cortex-A76) */ - cpuinfo_isa->dot = true; - break; - default: -#ifdef __aarch64__ - if (isa_features & CPUINFO_ARM_LINUX_FEATURE_ASIMDDP) { - cpuinfo_isa->dot = true; - } -#endif - // TODO, whitelist, ex: hisilicon_kirin 980... 
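Editor's note on the whitelists in this hunk: they key on `midr & (CPUINFO_ARM_MIDR_IMPLEMENTER_MASK | CPUINFO_ARM_MIDR_PART_MASK)`. The standalone sketch below — not part of the patch — shows how a raw MIDR value taken from one of the case labels decomposes with the masks defined above; the part mask (0x0000FFF0) and the implementer offset (24) follow the standard MIDR_EL1 layout and are assumed here where the excerpt truncates them.

    #include <cstdint>
    #include <cstdio>

    int main() {
        const uint32_t midr = UINT32_C(0x4100D0B0);                        // listed above as Cortex-A76
        const uint32_t implementer = (midr & UINT32_C(0xFF000000)) >> 24;  // 0x41 = Arm Ltd.
        const uint32_t part        = (midr & UINT32_C(0x0000FFF0)) >> 4;   // 0xD0B = Cortex-A76
        std::printf("implementer=0x%02X part=0x%03X\n", (unsigned)implementer, (unsigned)part);
        // Masking with IMPLEMENTER | PART drops the variant and revision fields,
        // so a single case label covers every revision of a given core.
        return 0;
    }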
- break; - } -#ifdef __aarch64__ - const uint32_t fp16arith_mask = CPUINFO_ARM_LINUX_FEATURE_FPHP | CPUINFO_ARM_LINUX_FEATURE_ASIMDHP; - if ((isa_features & fp16arith_mask) == fp16arith_mask) { - if (chipset.series == cpuinfo_arm_chipset_series_samsung_exynos && chipset.model == 9810) { - cpuinfo_isa->fp16arith = false; - } else { - cpuinfo_isa->fp16arith = true; - } - } - if (isa_features & CPUINFO_ARM_LINUX_FEATURE_I8MM) { - cpuinfo_isa->i8mm = true; - } - /* - if (isa_features & CPUINFO_ARM_LINUX_FEATURE_SVE2) { - // MNN_PRINT("Support SVE2\n"); - } - */ -#else // pytorch/cpuinfo: src/arm/linux/aarch32-isa.c uint32_t architecture_version = 0; if (processors_count > 0) { architecture_version = arm_linux_processors[0].architecture_version; } if (architecture_version >= 8) { + FUNC_PRINT_ALL((last_midr & (CPUINFO_ARM_MIDR_IMPLEMENTER_MASK | CPUINFO_ARM_MIDR_PART_MASK)), 0x); /* * NEON FP16 compute extension and VQRDMLAH/VQRDMLSH instructions are not indicated in /proc/cpuinfo. * Use a MIDR-based heuristic to whitelist processors known to support it: @@ -1437,6 +1194,8 @@ void cpuinfo_arm_init(struct cpuinfo_arm_isa* cpuinfo_isa) { case UINT32_C(0x4100D050): /* Cortex-A55 */ case UINT32_C(0x4100D060): /* Cortex-A65 */ case UINT32_C(0x4100D0B0): /* Cortex-A76 */ + case UINT32_C(0x4100d440): /* 888 */ + case UINT32_C(0x4100d480): /* 8gen1 */ case UINT32_C(0x4100D0C0): /* Neoverse N1 */ case UINT32_C(0x4100D0D0): /* Cortex-A77 */ case UINT32_C(0x4100D0E0): /* Cortex-A76AE */ @@ -1459,6 +1218,8 @@ void cpuinfo_arm_init(struct cpuinfo_arm_isa* cpuinfo_isa) { case UINT32_C(0x4100D0B0): /* Cortex-A76 */ case UINT32_C(0x4100D0D0): /* Cortex-A77 */ case UINT32_C(0x4100D0E0): /* Cortex-A76AE */ + case UINT32_C(0x4100d440): /* 888 */ + case UINT32_C(0x4100d480): /* 8gen1 */ case UINT32_C(0x4800D400): /* Cortex-A76 (HiSilicon) */ case UINT32_C(0x51008040): /* Kryo 485 Gold (Cortex-A76) */ case UINT32_C(0x51008050): /* Kryo 485 Silver (Cortex-A55) */ @@ -1474,106 +1235,210 @@ void cpuinfo_arm_init(struct cpuinfo_arm_isa* cpuinfo_isa) { break; } } -#endif + // Whitelist + switch (last_midr & (CPUINFO_ARM_MIDR_IMPLEMENTER_MASK | CPUINFO_ARM_MIDR_PART_MASK)) { + case UINT32_C(0x51008040): /* Kryo 485 Gold (Cortex-A76) */ + cpuinfo_isa->dot = true; + break; + default: + // TODO, whitelist, ex: hisilicon_kirin 980... 
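Editor's note on what the `dot` flag gates: kernels that consume it dispatch to SDOT-based int8 paths at runtime. A rough illustration only — not MNN code — assuming a hypothetical `dot16` helper and that the guarded branch is compiled with `-march=armv8.2-a+dotprod`:

    #include <arm_neon.h>
    #include <cstdint>

    // Dot product of two 16-element int8 vectors, with a scalar fallback.
    int32_t dot16(const int8_t* a, const int8_t* b, bool hasDot) {
    #if defined(__ARM_FEATURE_DOTPROD)
        if (hasDot) {
            int32x4_t acc = vdotq_s32(vdupq_n_s32(0), vld1q_s8(a), vld1q_s8(b));
            int32x2_t s   = vadd_s32(vget_low_s32(acc), vget_high_s32(acc));
            s = vpadd_s32(s, s);                 // horizontal add of the four lanes
            return vget_lane_s32(s, 0);
        }
    #endif
        int32_t acc = 0;
        for (int i = 0; i < 16; ++i) acc += int32_t(a[i]) * int32_t(b[i]);
        return acc;
    }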
+ break; + } + // Blacklist + if (chipset.series == cpuinfo_arm_chipset_series_samsung_exynos && chipset.model == 9810) { + // Spectial machine, disable fp16 + cpuinfo_isa->fp16arith = false; + } if (arm_linux_processors) { free(arm_linux_processors); } - -#endif // #ifdef __ANDROID__ - - // iOS -#if defined(__IOS__) && defined(__aarch64__) - -// A11 -#ifndef CPUFAMILY_ARM_MONSOON_MISTRAL -#define CPUFAMILY_ARM_MONSOON_MISTRAL 0xe81e7ef6 -#endif -// A12 -#ifndef CPUFAMILY_ARM_VORTEX_TEMPEST -#define CPUFAMILY_ARM_VORTEX_TEMPEST 0x07d34b9f -#endif -// A13 -#ifndef CPUFAMILY_ARM_LIGHTNING_THUNDER -#define CPUFAMILY_ARM_LIGHTNING_THUNDER 0x462504d2 -#endif -// A14 -#ifndef CPUFAMILY_ARM_FIRESTORM_ICESTORM -#define CPUFAMILY_ARM_FIRESTORM_ICESTORM 0x1b588bb3 -#endif -// A15 -#ifndef CPUFAMILY_ARM_AVALANCHE_BLIZZARD -#define CPUFAMILY_ARM_AVALANCHE_BLIZZARD 0xda33d83d -#endif -// A16 -#ifndef CPUFAMILY_ARM_EVEREST_SAWTOOTH -#define CPUFAMILY_ARM_EVEREST_SAWTOOTH 0x8765edea -#endif -// A17 Pro -#ifndef CPUFAMILY_ARM_PCORE_ECORE_COLL -#define CPUFAMILY_ARM_PCORE_ECORE_COLL 0x2876f5b5 +} #endif - const uint32_t cpu_family = get_sys_info_by_name("hw.cpufamily"); - // const uint32_t cpu_type = get_sys_info_by_name("hw.cputype"); - // const uint32_t cpu_subtype = get_sys_info_by_name("hw.cpusubtype"); - - cpuinfo_isa->fp16arith = cpu_family == CPUFAMILY_ARM_MONSOON_MISTRAL || - cpu_family == CPUFAMILY_ARM_VORTEX_TEMPEST || - cpu_family == CPUFAMILY_ARM_LIGHTNING_THUNDER || - cpu_family == CPUFAMILY_ARM_FIRESTORM_ICESTORM || - cpu_family == CPUFAMILY_ARM_AVALANCHE_BLIZZARD || - cpu_family == CPUFAMILY_ARM_EVEREST_SAWTOOTH || - cpu_family == CPUFAMILY_ARM_PCORE_ECORE_COLL; - - cpuinfo_isa->dot = cpu_family == CPUFAMILY_ARM_LIGHTNING_THUNDER || - cpu_family == CPUFAMILY_ARM_FIRESTORM_ICESTORM || - cpu_family == CPUFAMILY_ARM_AVALANCHE_BLIZZARD || - cpu_family == CPUFAMILY_ARM_EVEREST_SAWTOOTH || - cpu_family == CPUFAMILY_ARM_PCORE_ECORE_COLL; - - cpuinfo_isa->i8mm = cpu_family == CPUFAMILY_ARM_EVEREST_SAWTOOTH || - cpu_family == CPUFAMILY_ARM_PCORE_ECORE_COLL; -#endif // iOS - -// arm64-osx -#if defined(__APPLE__) && defined(__aarch64__) && !defined(__IOS__) -// Apple M1 -#ifndef CPUFAMILY_AARCH64_FIRESTORM_ICESTORM -#define CPUFAMILY_AARCH64_FIRESTORM_ICESTORM 0x1b588bb3 -#endif -// Apple M2 -#ifndef CPUFAMILY_AARCH64_AVALANCHE_BLIZZARD -#define CPUFAMILY_AARCH64_AVALANCHE_BLIZZARD 0xda33d83d -#endif - const uint32_t cpu_family = get_sys_info_by_name("hw.cpufamily"); - cpuinfo_isa->fp16arith = cpu_family == CPUFAMILY_AARCH64_FIRESTORM_ICESTORM || - cpu_family == CPUFAMILY_AARCH64_AVALANCHE_BLIZZARD; - cpuinfo_isa->dot = cpu_family == CPUFAMILY_AARCH64_FIRESTORM_ICESTORM || - cpu_family == CPUFAMILY_AARCH64_AVALANCHE_BLIZZARD; +#if defined(__APPLE__) && defined(__aarch64__) +static bool have_feature(const char* feature) { + // For more information on sysctlbyname(), see: + // https://developer.apple.com/documentation/kernel/1387446-sysctlbyname/determining_instruction_set_characteristics + int64_t feature_present = 0; + size_t size = sizeof(feature_present); + if (sysctlbyname(feature, &feature_present, &size, NULL, 0) != 0) { + return false; + } + return feature_present; +} +static void _getInfoApple(MNNCPUInfo* cpuinfo_isa) { + /**Ref from + https://developer.apple.com/documentation/kernel/1387446-sysctlbyname/determining_instruction_set_characteristics + */ + if (have_feature("hw.optional.arm.FEAT_FP16")) { + cpuinfo_isa->fp16arith = true; + } + if (have_feature("hw.optional.arm.FEAT_DotProd")) { + 
cpuinfo_isa->dot = true; + } + if (have_feature("hw.optional.arm.FEAT_I8MM")) { + cpuinfo_isa->i8mm = true; + } +} #endif -#ifndef __ANDROID__ -#if defined (__linux__) && defined (__aarch64__) - +#if defined(__linux__) && defined(__aarch64__) +static void _getInfoAux(MNNCPUInfo* cpuinfo_isa) { + // Use AUX to get info for linux-aarch64 uint32_t isa_features = 0; isa_features = (uint32_t)getauxval(AT_HWCAP); + if (isa_features & CPUINFO_ARM_LINUX_FEATURE_ASIMDDP) { + cpuinfo_isa->dot = true; + } + const uint32_t fp16arith_mask = CPUINFO_ARM_LINUX_FEATURE_FPHP | CPUINFO_ARM_LINUX_FEATURE_ASIMDHP; + if ((isa_features & fp16arith_mask) == fp16arith_mask) { + cpuinfo_isa->fp16arith = true; + } + if (isa_features & CPUINFO_ARM_LINUX_FEATURE_I8MM) { + cpuinfo_isa->i8mm = true; + } + isa_features = (uint32_t)getauxval(AT_HWCAP2); + if (isa_features & CPUINFO_ARM_LINUX_FEATURE_SVE2) { + cpuinfo_isa->sve2 = true; + } +} +#endif - - if (isa_features & CPUINFO_ARM_LINUX_FEATURE_ASIMDDP) { - cpuinfo_isa->dot = true; +static bool _readAll(const std::string& fileName, MNN::AutoStorage& buffer) { + MNN::FileLoader l(fileName.c_str()); + if (false == l.read()) { + return false; + } + return l.merge(buffer); +} +static std::vector _readNumber(const char* data, int length) { + int current = -1; + std::vector res; + for (int i=0; i '9') { + if (current >=0 ) { + res.emplace_back(current); + current = -1; + } + continue; } - - const uint32_t fp16arith_mask = CPUINFO_ARM_LINUX_FEATURE_FPHP | CPUINFO_ARM_LINUX_FEATURE_ASIMDHP; - if ((isa_features & fp16arith_mask) == fp16arith_mask) { - cpuinfo_isa->fp16arith = true; + if (current >= 0) { + current = current*10 + (c - '0'); + } else { + current = c - '0'; } + } + if (current >=0 ) { + res.emplace_back(current); + current = -1; + } + return res; +} +static MNNCPUInfo* gCPUInfo = nullptr; +static void _fillInfo(MNNCPUInfo* cpuInfo); +const MNNCPUInfo* MNNGetCPUInfo() { + if (nullptr != gCPUInfo) { + return gCPUInfo; + } + gCPUInfo = new MNNCPUInfo; + _fillInfo(gCPUInfo); + return gCPUInfo; +} - if (isa_features & CPUINFO_ARM_LINUX_FEATURE_I8MM) { - cpuinfo_isa->i8mm = true; +static void _fillInfo(MNNCPUInfo* cpuinfo_isa) { + cpuinfo_isa->dot = false; + cpuinfo_isa->fp16arith = false; + cpuinfo_isa->i8mm = false; + cpuinfo_isa->sve2 = false; + // android + /**Get CPU Info*/ +#ifdef __linux__ + do { + DIR* root; + std::string dir = "/sys/devices/system/cpu/cpufreq"; + if ((root = opendir(dir.c_str())) == NULL) { + break; + } + CPUGroup group; + struct dirent* ent; + while ((ent = readdir(root)) != NULL) { + if (ent->d_name[0] != '.') { + std::string policyName = dir + "/" + ent->d_name; + std::string cpus = policyName + "/affected_cpus"; + { + MNN::AutoStorage buffer; + if (false == _readAll(cpus, buffer)) { + continue; + } + group.ids = _readNumber((const char*)buffer.get(), buffer.size()); + } + std::string minfreq = policyName + "/cpuinfo_min_freq"; + { + MNN::AutoStorage buffer; + if (_readAll(minfreq, buffer)) { + auto freq = _readNumber((const char*)buffer.get(), buffer.size()); + if (freq.size() > 0) { + group.minFreq = freq[0]; + } + } + } + std::string maxfreq = policyName + "/cpuinfo_max_freq"; + { + MNN::AutoStorage buffer; + if (_readAll(maxfreq, buffer)) { + auto freq = _readNumber((const char*)buffer.get(), buffer.size()); + if (freq.size() > 0) { + group.maxFreq = freq[0]; + } + } + } + cpuinfo_isa->groups.emplace_back(group); + } } + closedir(root); + std::sort(cpuinfo_isa->groups.begin(), cpuinfo_isa->groups.end(), [](const CPUGroup& left, const 
CPUGroup& right) { + return left.maxFreq < right.maxFreq; + }); + // Merge group if needed + if (cpuinfo_isa->groups.size() >= 2 && cpuinfo_isa->groups[0].maxFreq == cpuinfo_isa->groups[1].maxFreq) { + auto backupGroups = std::move(cpuinfo_isa->groups); + CPUGroup&& current = std::move(backupGroups[0]); + for (int v=1; vgroups.emplace_back(current); + current = std::move(backupGroups[v]); + } else { + current.ids.insert(current.ids.end(), backupGroups[v].ids.begin(), backupGroups[v].ids.end()); + } + } + cpuinfo_isa->groups.emplace_back(current); + } + cpuinfo_isa->cpuNumber = 0; + for (auto& group : cpuinfo_isa->groups) { + cpuinfo_isa->cpuNumber += group.ids.size(); + std::string message = "CPU Group: ["; + for (int v=0; vdot, cpuinfo_isa->fp16arith, cpuinfo_isa->i8mm); + MNN_PRINT("The device supports: i8sdot:%d, fp16:%d, i8mm: %d, sve2: %d\n", cpuinfo_isa->dot, cpuinfo_isa->fp16arith, cpuinfo_isa->i8mm, cpuinfo_isa->sve2); + return; } diff --git a/source/backend/cpu/CPURuntime.hpp b/source/backend/cpu/CPURuntime.hpp index 4376553c7..7155e023b 100644 --- a/source/backend/cpu/CPURuntime.hpp +++ b/source/backend/cpu/CPURuntime.hpp @@ -9,30 +9,24 @@ #define CPURuntime_hpp #include +#include #include "core/Macro.h" -struct cpuinfo_arm_isa { +struct CPUGroup { + uint32_t minFreq; + uint32_t maxFreq; + std::vector ids; +}; +struct MNNCPUInfo { bool fp16arith; bool dot; bool i8mm; + bool sve2; + std::vector groups; + int cpuNumber = 0; }; -/* - CPU thread mode, only effective on HMP(Heterogeneous Multi-Processing)arch CPUs - that have ARM big.LITTLE technology and on Android - */ -typedef enum { - /* Compliance with Operating System Scheduling */ - MNN_CPU_MODE_DEFAULT = 0, - /* Bind threads to CPU IDs according to CPU frequency, but this mode is power-friendly */ - MNN_CPU_MODE_POWER_FRI = 1, - /* Bind threads to little CPUs */ - MNN_CPU_MODE_LITTLE = 2, - /* Bind threads to big CPUs */ - MNN_CPU_MODE_BIG = 3 -} MNNCPUThreadsMode; -int MNNSetCPUThreadsMode(MNNCPUThreadsMode mode); - -float MNNGetCPUFlops(uint32_t number); -void cpuinfo_arm_init(struct cpuinfo_arm_isa* cpuinfo_isa); +int MNNSetSchedAffinity(const int* cpuIDs, int size); +int MNNGetCurrentPid(); +const MNNCPUInfo* MNNGetCPUInfo(); #endif /* CPUInfo_hpp */ diff --git a/source/backend/cpu/CPUSoftMaxInt8.cpp b/source/backend/cpu/CPUSoftMaxInt8.cpp deleted file mode 100644 index 1630ae52c..000000000 --- a/source/backend/cpu/CPUSoftMaxInt8.cpp +++ /dev/null @@ -1,317 +0,0 @@ -// -// CPUSoftMaxInt8.cpp -// MNNCPU -// -// Created by jbyang on 2023/4/22. -// - -#include "CPUSoftMaxInt8.hpp" -#include "backend/cpu/CPUBackend.hpp" -#include "backend/cpu/CPUFixedPoint.hpp" -#include "backend/cpu/CPUQuantizationUtils.hpp" -#include "core/Macro.h" -#include "core/TensorUtils.hpp" -#include "core/Concurrency.h" -#include "CPUTensorConvert.hpp" - -namespace MNN { - -CPUSoftmaxInt8::CPUSoftmaxInt8(Backend* backend, int axis) : Execution(backend), mAxis(axis), mStorage(2), mTempOutput(2), mNeedUnpackC4(false) { - // do nothing. 
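Editor's note: the reworked CPURuntime.hpp above replaces the removed MNNSetCPUThreadsMode/MNNGetCPUFlops pair with MNNGetCPUInfo(), CPUGroup and MNNSetSchedAffinity(). A rough usage sketch follows — illustrative only, not how CPUBackend actually wires it up; the vectors' element types are truncated in this excerpt and assumed to be CPUGroup and int.

    #include "backend/cpu/CPURuntime.hpp"

    void bindToBigCores() {
        const MNNCPUInfo* info = MNNGetCPUInfo();
        if (nullptr == info || info->groups.empty()) {
            return;
        }
        // _fillInfo sorts groups by maxFreq in ascending order, so the last
        // entry is the highest-frequency ("big") cluster on big.LITTLE SoCs.
        const CPUGroup& big = info->groups.back();
        MNNSetSchedAffinity(big.ids.data(), (int)big.ids.size());
    }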
-} - -const int kScaledDiffIntegerBits = 5; -const int kAccumulationIntegerBits = 12; - -ErrorCode CPUSoftmaxInt8::onResize(const std::vector& inputs, const std::vector& outputs) { - auto input = inputs[0]; - auto output = outputs[0]; - auto inputQuant = TensorUtils::getQuantInfo(input); - float beta = 1.0; - float scale = inputQuant[0]; - PreprocessSoftmaxScaling(beta, scale, kScaledDiffIntegerBits, &mInputMultiplier, &mInputLeftShift); - mDiffMin = -1.0 * CalculateInputRadius(kScaledDiffIntegerBits, mInputLeftShift); - - const auto layout = TensorUtils::getDescribe(input)->dimensionFormat; - mNeedUnpackC4 = layout == MNN_DATA_FORMAT_NC4HW4; - const int dimensions = input->buffer().dimensions; - - int axis = mAxis; - if (axis < 0) { - axis += input->dimensions(); - } - mInside = 1; mOutside = 1; - for (int i = 0; i < axis; ++i) { - mOutside *= input->length(i); - } - mTargetAxis = input->length(axis); - for (int i = axis + 1; i < dimensions; ++i) { - mInside *= input->length(i); - } - - mStorage.buffer().dim[0].extent = input->length(0); - mStorage.buffer().dim[1].extent = input->stride(0); - TensorUtils::getDescribe(&mStorage)->dimensionFormat = MNN_DATA_FORMAT_NHWC; - mStorage.buffer().dimensions = 2; - mStorage.buffer().type = input->getType(); - backend()->onAcquireBuffer(&mStorage, Backend::DYNAMIC); - backend()->onReleaseBuffer(&mStorage, Backend::DYNAMIC); - - if (mNeedUnpackC4) { - mTempOutput.buffer().dim[0].extent = output->length(0); - mTempOutput.buffer().dim[1].extent = output->stride(0); - TensorUtils::getDescribe(&mTempOutput)->dimensionFormat = MNN_DATA_FORMAT_NHWC; - mTempOutput.buffer().dimensions = 2; - mTempOutput.buffer().type = input->getType(); - backend()->onAcquireBuffer(&mTempOutput, Backend::DYNAMIC); - backend()->onReleaseBuffer(&mTempOutput, Backend::DYNAMIC); - } - - return NO_ERROR; -} - -void CPUSoftmaxInt8::QuantizedSoftmax(const uint8_t* inputData, int outerSize, int targetAxis, - int32_t inputBetaMultiplier, int32_t inputBetaLeftShift, - uint8_t* outputData, int threadNum) { - using FixedPointScaledDiff = FixedPoint; - using FixedPointAccum = FixedPoint; - using FixedPoint0 = FixedPoint; - - const int depth = targetAxis; -#ifdef MNN_USE_SSE - int32_t zeroPoint = 128; - int32_t minValue = 0; - int32_t maxValue = 255; - const uint8_t* src_ = inputData; - uint8_t* dst_ = outputData; -#else - int32_t zeroPoint = 0; - int32_t minValue = -128; - int32_t maxValue = 127; - const int8_t* src_ = (int8_t*)inputData; - int8_t* dst_ = (int8_t*)outputData; -#endif - MNN_CONCURRENCY_BEGIN(tId, threadNum) { - auto inputDataPtr = src_ + tId * depth; - uint8_t* outputDataPtr = (uint8_t*)dst_ + tId * depth; - for (int b = (int)tId; b < outerSize; b += threadNum, inputDataPtr += depth * threadNum, outputDataPtr += depth * threadNum) { - // Determine the largest entry in the current row - int8_t maxInRow = -128; - { - int c = 0; -#ifdef MNN_USE_NEON - int8x16_t max16_0 = vdupq_n_s8(0); - int8x16_t max16_1 = vdupq_n_s8(0); - for (; c <= depth - 32; c += 32) { - max16_0 = vmaxq_s8(max16_0, vld1q_s8(inputDataPtr + c + 0)); - max16_1 = vmaxq_s8(max16_1, vld1q_s8(inputDataPtr + c + 16)); - } - int8x16_t max16 = vmaxq_s8(max16_0, max16_1); - if (c <= depth - 16) { - max16 = vmaxq_s8(max16, vld1q_s8(inputDataPtr + c)); - c += 16; - } - int8x8_t max8 = vmax_s8(vget_low_s8(max16), vget_high_s8(max16)); - if (c <= depth - 8) { - max8 = vmax_s8(max8, vld1_s8(inputDataPtr + c)); - c += 8; - } - int8x8_t max4 = vmax_s8(max8, vext_s8(max8, max8, 4)); - int8x8_t max2 = vmax_s8(max4, 
vext_s8(max4, max4, 2)); - int8x8_t max1 = vpmax_s8(max2, max2); - maxInRow = vget_lane_s8(max1, 0); -#endif - for (; c < depth; ++c) { - maxInRow = std::max(maxInRow, static_cast(inputDataPtr[c] - zeroPoint)); - } - } - -#ifdef MNN_USE_NEON - using FixedPointAccumInt32x4 = FixedPoint; - using FixedPointScaledDiffInt32x4 = FixedPoint; - using FixedPoint0Int32x4 = FixedPoint; - FixedPoint0Int32x4 input_beta_multiplier_f0 = FixedPoint0Int32x4::FromScalarRaw(inputBetaMultiplier); - int16x8_t max_in_row_s16 = vdupq_n_s16(maxInRow); -#endif - - FixedPointAccum sumOfExps = FixedPointAccum::Zero(); - { - int c = 0; -#ifdef MNN_USE_NEON - int32x4_t diff_min_s32 = vdupq_n_s32(mDiffMin); - FixedPointAccumInt32x4 sum_of_exps_0 = FixedPointAccumInt32x4::Zero(); - FixedPointAccumInt32x4 sum_of_exps_1 = FixedPointAccumInt32x4::Zero(); - FixedPointAccumInt32x4 zeros = FixedPointAccumInt32x4::Zero(); - for (; c <= depth - 8; c += 8) { - int16x8_t input_s16 = vmovl_s8(vld1_s8(inputDataPtr + c)); - int16x8_t input_diff_s16 = - vsubq_s16(input_s16, max_in_row_s16); - int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16)); - int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16)); - int32x4_t mask_0 = - MaskIfGreaterThanOrEqual(input_diff_s32_0, diff_min_s32); - int32x4_t mask_1 = - MaskIfGreaterThanOrEqual(input_diff_s32_1, diff_min_s32); - FixedPointScaledDiffInt32x4 scaled_diff_0 = - input_beta_multiplier_f0 * - FixedPointScaledDiffInt32x4::FromRaw( - ShiftLeft(input_diff_s32_0, inputBetaLeftShift)); - FixedPointScaledDiffInt32x4 scaled_diff_1 = - input_beta_multiplier_f0 * - FixedPointScaledDiffInt32x4::FromRaw( - ShiftLeft(input_diff_s32_1, inputBetaLeftShift)); - FixedPointAccumInt32x4 exps_0 = - Rescale( - exp_on_negative_values(scaled_diff_0)); - FixedPointAccumInt32x4 exps_1 = - Rescale( - exp_on_negative_values(scaled_diff_1)); - FixedPointAccumInt32x4 masked_exps_0 = - SelectUsingMask(mask_0, exps_0, zeros); - FixedPointAccumInt32x4 masked_exps_1 = - SelectUsingMask(mask_1, exps_1, zeros); - sum_of_exps_0 = sum_of_exps_0 + masked_exps_0; - sum_of_exps_1 = sum_of_exps_1 + masked_exps_1; - } - int32x4_t sum_of_exps_reduced_4 = (sum_of_exps_0 + sum_of_exps_1).raw(); - int32x2_t sum_of_exps_reduced_2 = - vadd_s32(vget_low_s32(sum_of_exps_reduced_4), - vget_high_s32(sum_of_exps_reduced_4)); - int32x2_t sum_of_exps_reduced_1 = - vpadd_s32(sum_of_exps_reduced_2, sum_of_exps_reduced_2); - sumOfExps = - FixedPointAccum::FromRaw(vget_lane_s32(sum_of_exps_reduced_1, 0)); -#endif - for (; c < depth; ++c) { - int32_t inputDiff = (inputDataPtr[c] - zeroPoint) - maxInRow; - if (inputDiff >= mDiffMin) { - const int32_t inputDiffRescaled = - MultiplyByQuantizedMultiplierGreaterThanOne(inputDiff, inputBetaMultiplier, inputBetaLeftShift); - const FixedPointScaledDiff scaledDiffF8 = FixedPointScaledDiff::FromRaw(inputDiffRescaled); - sumOfExps = sumOfExps + Rescale(exp_on_negative_values(scaledDiffF8)); - } - } - } - - int fixedSumOfExps = sumOfExps.raw(); - #if defined(_MSC_VER) - int headroomPlusOne; - { - unsigned long leading_zero = 0; - if (_BitScanReverse(&leading_zero, static_cast(fixedSumOfExps))) { - headroomPlusOne = 31 - leading_zero; - } else { - headroomPlusOne = 31; - } - } - #else - int headroomPlusOne = __builtin_clz(static_cast(fixedSumOfExps)); - #endif - - int numBitsOverUnit = kAccumulationIntegerBits - headroomPlusOne; - - if (numBitsOverUnit + 31 - 8 > 31) { - numBitsOverUnit = 8; - } - int32_t shiftedSumMinusOne = static_cast((static_cast(fixedSumOfExps) << 
headroomPlusOne) - - (static_cast(1) << 31)); - FixedPoint0 shiftedScale = one_over_one_plus_x_for_x_in_0_1(FixedPoint0::FromRaw(shiftedSumMinusOne)); - - { - int c = 0; -#ifdef MNN_USE_NEON - int16x8_t diff_min_s16 = vdupq_n_s16(mDiffMin); - for (; c <= depth - 8; c += 8) { - int16x8_t input_s16 = vmovl_s8(vld1_s8(inputDataPtr + c)); - int16x8_t input_diff_s16 = - vsubq_s16(input_s16, max_in_row_s16); - int32x4_t input_diff_s32_0 = vmovl_s16(vget_low_s16(input_diff_s16)); - int32x4_t input_diff_s32_1 = vmovl_s16(vget_high_s16(input_diff_s16)); - uint8x8_t mask = vmovn_u16(vcgeq_s16(input_diff_s16, diff_min_s16)); - FixedPointScaledDiffInt32x4 scaled_diff_0 = - input_beta_multiplier_f0 * - FixedPointScaledDiffInt32x4::FromRaw( - ShiftLeft(input_diff_s32_0, inputBetaLeftShift)); - FixedPointScaledDiffInt32x4 scaled_diff_1 = - input_beta_multiplier_f0 * - FixedPointScaledDiffInt32x4::FromRaw( - ShiftLeft(input_diff_s32_1, inputBetaLeftShift)); - FixedPoint0Int32x4 exp_0 = exp_on_negative_values(scaled_diff_0); - FixedPoint0Int32x4 exp_1 = exp_on_negative_values(scaled_diff_1); - int32x4_t output_s32_0 = RoundingDivideByPOT( - vqrdmulhq_n_s32(exp_0.raw(), shiftedScale.raw()), - numBitsOverUnit + 31 - 8); - int32x4_t output_s32_1 = RoundingDivideByPOT( - vqrdmulhq_n_s32(exp_1.raw(), shiftedScale.raw()), - numBitsOverUnit + 31 - 8); - int16x8_t output_s16 = - vcombine_s16(vqmovn_s32(output_s32_0), vqmovn_s32(output_s32_1)); - uint8x8_t output_s8 = vqmovun_s16(output_s16); - uint8x8_t masked_output = vbsl_u8(mask, output_s8, vdup_n_u8(0)); - vst1_u8(outputDataPtr + c, masked_output); - } -#endif - for (; c < depth; ++c) { - int32_t inputDiff = (inputDataPtr[c] - zeroPoint) - maxInRow; - if (inputDiff >= mDiffMin) { - const int inputDiffRescaled = - MultiplyByQuantizedMultiplierGreaterThanOne(inputDiff, inputBetaMultiplier, inputBetaLeftShift); - const FixedPointScaledDiff scaledDiffF8 = FixedPointScaledDiff::FromRaw(inputDiffRescaled); - FixedPoint0 expIn0 = exp_on_negative_values(scaledDiffF8); - - int unsatOutput = RoundingDivideByPOT((shiftedScale * expIn0).raw(), numBitsOverUnit + 31 - 8) + zeroPoint; - outputDataPtr[c] = std::max(std::min(unsatOutput, maxValue), minValue); - - } - else { - outputDataPtr[c] = zeroPoint; - } - } - } - } - } - MNN_CONCURRENCY_END(); -} - -ErrorCode CPUSoftmaxInt8::onExecute(const std::vector& inputs, - const std::vector& outputs) { - MNN_ASSERT(1 == inputs.size()); - MNN_ASSERT(1 == outputs.size()); - - Tensor* input = inputs[0]; - Tensor* output = outputs[0]; - uint8_t* inputData = input->host(); - uint8_t* outputData = output->host(); - - auto batch = input->batch(); - auto dimentions = input->dimensions(); - int areaInput = 1; - for (int i = 2; i < dimentions; ++i) { - areaInput *= input->length(i); - } - int threadNum = ((CPUBackend *)backend())->threadNumber(); - - uint8_t* tempInputData = mStorage.host(); - auto functions = ((CPUBackend*)backend())->functions(); - if (mNeedUnpackC4) { - uint8_t* tempOutputData = mTempOutput.host(); - CPUTensorConverter::convert(inputData, outputData, MNN_DATA_FORMAT_NC4HW4, MNN_DATA_FORMAT_NCHW, batch, areaInput, input->channel(), 1, functions); - CPUTensorConverter::convert(outputData, tempInputData, MNN_DATA_FORMAT_NCHW, MNN_DATA_FORMAT_NHWC, mOutside, mInside, mTargetAxis, 1, functions); - QuantizedSoftmax(tempInputData, mInside * mOutside, mTargetAxis, mInputMultiplier, mInputLeftShift, tempOutputData, threadNum); - CPUTensorConverter::convert(tempOutputData, tempInputData, MNN_DATA_FORMAT_NHWC, MNN_DATA_FORMAT_NCHW, 
mOutside, mInside, mTargetAxis, 1, functions); - CPUTensorConverter::convert(tempInputData, outputData, MNN_DATA_FORMAT_NCHW, MNN_DATA_FORMAT_NC4HW4, batch, areaInput, input->channel(), 1, functions); - } else { - CPUTensorConverter::convert(inputData, outputData, MNN_DATA_FORMAT_NCHW, MNN_DATA_FORMAT_NHWC, mOutside, mInside, mTargetAxis, 1, functions); - QuantizedSoftmax(outputData, mInside * mOutside, mTargetAxis, mInputMultiplier, mInputLeftShift, tempInputData, threadNum); - CPUTensorConverter::convert(tempInputData, outputData, MNN_DATA_FORMAT_NHWC, MNN_DATA_FORMAT_NCHW, mOutside, mInside, mTargetAxis, 1, functions); - } - - return NO_ERROR; -} - -Execution* CPUSoftmaxInt8::create(const MNN::Op *op, Backend *backend) { - auto axis = op->main_as_Axis()->axis(); - return new CPUSoftmaxInt8(backend, axis); -} - -} diff --git a/source/backend/cpu/CPUSoftMaxInt8.hpp b/source/backend/cpu/CPUSoftMaxInt8.hpp deleted file mode 100644 index a1f8e4da4..000000000 --- a/source/backend/cpu/CPUSoftMaxInt8.hpp +++ /dev/null @@ -1,39 +0,0 @@ -// -// CPUSoftMaxInt8.hpp -// MNNCPU -// -// Created by MNN on 2023/4/22. -// - -#ifndef CPUSoftMaxInt8_hpp -#define CPUSoftMaxInt8_hpp -#include "core/Execution.hpp" -#include -namespace MNN { - -class CPUSoftmaxInt8 : public Execution { -public: - CPUSoftmaxInt8(Backend *backend, int axis); - virtual ~CPUSoftmaxInt8() = default; - virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; - virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; - static Execution* create(const MNN::Op *op, Backend *backend); - - void QuantizedSoftmax(const uint8_t *inputData, int outerSize, int targetAxis, int32_t inputBetaMultiplier, - int32_t inputBetaLeftShift, uint8_t *output_data, int threadNum); - -private: - int32_t mInputMultiplier; - int mInputLeftShift; - int mDiffMin; - int mAxis; - int mInside; - int mOutside; - int mTargetAxis; - Tensor mStorage; - Tensor mTempOutput; - bool mNeedUnpackC4; -}; - -} -#endif /* CPUSoftMaxInt8_hpp */ diff --git a/source/backend/cpu/CPUSoftmax.cpp b/source/backend/cpu/CPUSoftmax.cpp index c8cfecede..d4811899a 100644 --- a/source/backend/cpu/CPUSoftmax.cpp +++ b/source/backend/cpu/CPUSoftmax.cpp @@ -8,13 +8,13 @@ #include #include "backend/cpu/CPUSoftmax.hpp" -#include "backend/cpu/CPUSoftMaxInt8.hpp" #include "backend/cpu/CPUBackend.hpp" #include "backend/cpu/compute/CommonOptFunction.h" #include "core/Concurrency.h" #include "core/Macro.h" #include "core/TensorUtils.hpp" #include "CPUTensorConvert.hpp" +#include "CPUCast.hpp" namespace MNN { static void ___MNNSoftmax(float* dest, const float* source, size_t size, MNNBinaryExecute mulfunction) { @@ -71,19 +71,39 @@ int CPUSoftmax::_softmaxCommon(const uint8_t *srcData, uint8_t *dstData) { addFunction = fp32Core->MNNSelectBinaryFunctionForFloat(BinaryOpOperation_ADD); recFunction = fp32Core->MNNSelectUnaryFunctionForFloat(UnaryOpOperation_RECIPROCAL, 1);//Use high precision MNN_CONCURRENCY_BEGIN(tId, threadNumber) { - auto tempInput = (float*)(mTmpInput.ptr() + tId * outsideStride * sizeof(float)); - auto tempOutput = (float*)(mTmpOutput.ptr() + tId * outsideStride * sizeof(float)); + float* tempOutput = nullptr; + float* tempInput = nullptr; + if (mTmpInput.ptr()) { + tempInput = (float*)(mTmpInput.ptr() + tId * outsideStride * sizeof(float)); + } + + if (mTmpOutput.ptr()) { + tempOutput = (float*)(mTmpOutput.ptr() + tId * outsideStride * sizeof(float)); + } + for (int o=tId; oscale, mInQuantAttr->zero, 
mInQuantAttr->min, mInQuantAttr->max, cpuBn); + ::memcpy(tempOutput, tempInput, mInside * 4); + for (int z = 1; z < mChannel; ++z) { + maxFunction(tempOutput, tempOutput, tempInput + z * mInside, mInside, -1); + } + } else { + ::memcpy(tempInput, srcO, mInside * mLowOrInt8); + for (int z = 1; z < mChannel; ++z) { + maxFunction(tempInput, tempInput, srcO + z * mInside * mLowOrInt8, mInside, -1); + } } // Sub Max for (int z=0; zbytes != 4) { + if (mLowOrInt8 != 4) { workSrc = tempInput; workDst = tempOutput; - core->MNNLowpToFp32((int16_t*)(dstO), workSrc, outsideStride); + if (mLowOrInt8 == 2) { + core->MNNLowpToFp32((int16_t*)(dstO), workSrc, outsideStride); + } } // Use Fp32 to compute Begin MNNExp(workDst, workSrc, exprOffset, outsideStride); @@ -113,8 +135,12 @@ int CPUSoftmax::_softmaxCommon(const uint8_t *srcData, uint8_t *dstData) { mulFunction(workDst + z * mInside, workDst + z * mInside, tempInput, mInside, -1); } // Use Fp32 Compute end - if (core->bytes != 4) { + if (mLowOrInt8 == 2) { core->MNNFp32ToLowp(workDst, (int16_t*)(dstO), outsideStride); + } else if (mLowOrInt8 == 1) { + CPUCastCreator::cast(workDst, dstO, CPUCastCreator::FlOAT_TO_INT8, outsideStride, mOutQuantAttr->scale, mOutQuantAttr->zero, mOutQuantAttr->min, mOutQuantAttr->max, cpuBn); + } else { + // do nothing. } } }; @@ -122,19 +148,29 @@ int CPUSoftmax::_softmaxCommon(const uint8_t *srcData, uint8_t *dstData) { return 0; } MNN_CONCURRENCY_BEGIN(tId, threadNumber) { - auto tempInput = (float*)(mTmpInput.ptr() + tId * outsideStride * sizeof(float)); - auto tempOutput = (float*)(mTmpOutput.ptr() + tId * outsideStride * sizeof(float)); + float* tempInput; + float* tempOutput; + if (mTmpInput.ptr()) { + tempInput = (float*)(mTmpInput.ptr() + tId * outsideStride * sizeof(float)); + } + if (mTmpOutput.ptr()) { + tempOutput = (float*)(mTmpOutput.ptr() + tId * outsideStride * sizeof(float)); + } for (int o=tId; oMNNLowpToFp32((int16_t*)(srcO), tempInput, outsideStride); workDst = tempOutput; workSrc = tempInput; + } else if (mLowOrInt8 == 1) { + CPUCastCreator::cast(srcO, tempInput, CPUCastCreator::INT8_TO_FlOAT, outsideStride, mInQuantAttr->scale, mInQuantAttr->zero, mInQuantAttr->min, mInQuantAttr->max, cpuBn); + workDst = tempOutput; + workSrc = tempInput; } } else { int dims[] = { @@ -143,12 +179,17 @@ int CPUSoftmax::_softmaxCommon(const uint8_t *srcData, uint8_t *dstData) { mInside, mChannel }; - if (bytes != 4) { + if (mLowOrInt8 == 2) { MNN_ASSERT(bytes == 2); MNNTranspose16Bit((int16_t*)tempOutput, (int16_t*)(srcO), dims); core->MNNLowpToFp32((int16_t*)tempOutput, tempInput, outsideStride); workDst = tempOutput; workSrc = tempInput; + } else if (mLowOrInt8 == 1) { + CPUCastCreator::cast(srcO, tempOutput, CPUCastCreator::INT8_TO_FlOAT, outsideStride, mInQuantAttr->scale, mInQuantAttr->zero, mInQuantAttr->min, mInQuantAttr->max, cpuBn); + MNNTranspose32Bit((int32_t*)tempInput, (int32_t*)tempOutput, dims); + workDst = tempOutput; + workSrc = tempInput; } else { // Use output to cache transpoe result MNNTranspose32Bit((int32_t*)dstO, (int32_t*)(srcO), dims); @@ -166,8 +207,10 @@ int CPUSoftmax::_softmaxCommon(const uint8_t *srcData, uint8_t *dstData) { } // PostTreat if (1 == mInside) { - if (bytes != 4) { + if (mLowOrInt8 == 2) { core->MNNFp32ToLowp(tempOutput, (int16_t*)(dstO), outsideStride); + } else if (mLowOrInt8 == 1) { + CPUCastCreator::cast(tempOutput, dstO, CPUCastCreator::FlOAT_TO_INT8, outsideStride, mOutQuantAttr->scale, mOutQuantAttr->zero, mOutQuantAttr->min, mOutQuantAttr->max, cpuBn); } } else { 
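Editor's note: conceptually, the new mLowOrInt8 == 1 branches dequantize the int8 input to fp32, reuse the existing fp32 softmax, then requantize the result. A generic sketch of the affine int8/float conversion such a cast performs — hypothetical helpers; the exact rounding and clamping inside CPUCastCreator::cast may differ:

    #include <algorithm>
    #include <cmath>
    #include <cstdint>

    static inline float dequant(int8_t q, float scale, float zero) {
        return scale * (float(q) - zero);          // int8 -> float
    }
    static inline int8_t requant(float x, float scale, float zero, float minV, float maxV) {
        float q = std::round(x / scale + zero);    // float -> int8
        q = std::min(std::max(q, minV), maxV);     // clamp to the tensor's quant range
        return int8_t(q);
    }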
int dims[] = { @@ -176,10 +219,13 @@ int CPUSoftmax::_softmaxCommon(const uint8_t *srcData, uint8_t *dstData) { mChannel, mInside }; - if (bytes != 4) { - MNN_ASSERT(bytes == 2); + if (mLowOrInt8 == 2) { + MNN_ASSERT(bytes == 2); core->MNNFp32ToLowp((float*)tempOutput, (int16_t*)tempInput, outsideStride); MNNTranspose16Bit((int16_t*)dstO, (int16_t*)(tempInput), dims); + } else if (mLowOrInt8 == 1) { + MNNTranspose32Bit((int32_t*)tempInput, (int32_t*)tempOutput, dims); + CPUCastCreator::cast(tempInput, dstO, CPUCastCreator::FlOAT_TO_INT8, outsideStride, mOutQuantAttr->scale, mOutQuantAttr->zero, mOutQuantAttr->min, mOutQuantAttr->max, cpuBn); } else { MNNTranspose32Bit((int32_t*)dstO, (int32_t*)(tempInput), dims); } @@ -227,14 +273,24 @@ ErrorCode CPUSoftmax::onResize(const std::vector &inputs, const std::v mInside = inside; mOutside = outside; mChannel = channel; + + mLowOrInt8 = 4; + if (static_cast(backend())->functions()->bytes != 4) { + mLowOrInt8 = 2; + } + if (CPUBackend::getDataType(inputs[0]) == DataType_DT_INT8 || inputs[0]->getType().bytes() == 1) { + mLowOrInt8 = 1; + } + mInQuantAttr = TensorUtils::getDescribe(inputs[0])->quantAttr; + mOutQuantAttr = TensorUtils::getDescribe(outputs[0])->quantAttr; auto cpuBn = static_cast(backend()); - if (inside != 1 || cpuBn->functions()->bytes != 4) { // not run _softmax1, we need maxValue Tensor and sumValue Tensor. + if (inside != 1 || mLowOrInt8 != 4) { // not run _softmax1, we need maxValue Tensor and sumValue Tensor. int threadNum = cpuBn->threadNumber(); auto buf = cpuBn->getBufferAllocator(); threadNum = ALIMIN(threadNum, outside); mTmpInput = buf->alloc(threadNum * inside * channel * sizeof(float)); - if (cpuBn->functions()->bytes != 4) { + if (mLowOrInt8 != 4) { mTmpOutput = buf->alloc(threadNum * inside * channel * sizeof(float)); buf->free(mTmpOutput); } @@ -274,9 +330,9 @@ ErrorCode CPUSoftmax::onExecute(const std::vector &inputs, const std:: return NO_ERROR; } auto functions = static_cast(backend())->functions(); - CPUTensorConverter::convert(inputDataPtr, outputDataPtr, MNN_DATA_FORMAT_NC4HW4, MNN_DATA_FORMAT_NCHW, batch, areaInput, inputTensor->channel(), functions->bytes, functions); + CPUTensorConverter::convert(inputDataPtr, outputDataPtr, MNN_DATA_FORMAT_NC4HW4, MNN_DATA_FORMAT_NCHW, batch, areaInput, inputTensor->channel(), mLowOrInt8, functions); _softmaxCommon((uint8_t*)outputDataPtr, (uint8_t*)tempData); - CPUTensorConverter::convert(tempData, outputDataPtr, MNN_DATA_FORMAT_NCHW, MNN_DATA_FORMAT_NC4HW4, batch, areaInput, inputTensor->channel(), functions->bytes, functions); + CPUTensorConverter::convert(tempData, outputDataPtr, MNN_DATA_FORMAT_NCHW, MNN_DATA_FORMAT_NC4HW4, batch, areaInput, inputTensor->channel(), mLowOrInt8, functions); return NO_ERROR; } @@ -293,11 +349,8 @@ class CPUSoftmaxCreator : public CPUBackend::Creator { public: virtual Execution *onCreate(const std::vector &inputs, const std::vector &outputs, const MNN::Op *op, Backend *backend) const override { - if (CPUBackend::getDataType(inputs[0]) == DataType_DT_INT8 || inputs[0]->getType().bytes() == 1) { - return CPUSoftmaxInt8::create(op, backend); - } else { - return CPUSoftmax::create(op, backend); - } + return CPUSoftmax::create(op, backend); + } }; diff --git a/source/backend/cpu/CPUSoftmax.hpp b/source/backend/cpu/CPUSoftmax.hpp index c76cd7554..ec6f25102 100644 --- a/source/backend/cpu/CPUSoftmax.hpp +++ b/source/backend/cpu/CPUSoftmax.hpp @@ -11,6 +11,7 @@ #include "core/Execution.hpp" #include "core/BufferAllocator.hpp" +#include 
"core/TensorUtils.hpp" namespace MNN { class CPUSoftmax : public Execution { public: @@ -32,6 +33,11 @@ class CPUSoftmax : public Execution { int mInside; int mOutside; int mChannel; + + std::shared_ptr mInQuantAttr; + std::shared_ptr mOutQuantAttr; + + int mLowOrInt8; }; } // namespace MNN diff --git a/source/backend/cpu/ThreadPool.cpp b/source/backend/cpu/ThreadPool.cpp index 75020fdd7..4b489151b 100644 --- a/source/backend/cpu/ThreadPool.cpp +++ b/source/backend/cpu/ThreadPool.cpp @@ -10,15 +10,6 @@ #include #include -//#define MNN_THREAD_LOCK_CPU - -#ifdef MNN_THREAD_LOCK_CPU -#include -#include -#include -#include -#endif - #define MNN_THREAD_POOL_MAX_TASKS 2 namespace MNN { ThreadPool* ThreadPool::gInstance = nullptr; @@ -45,115 +36,13 @@ void ThreadPool::destroy() { gInstance = nullptr; } } -#ifdef MNN_THREAD_LOCK_CPU -static int getNumberOfCPU() { - FILE* fp = fopen("/proc/cpuinfo", "rb"); - if (!fp) { - return 1; - } - int number = 0; - char buffer[1024]; - while (!feof(fp)) { - char* str = fgets(buffer, 1024, fp); - if (!str) { - break; - } - if (memcmp(buffer, "processor", 9) == 0) { - number++; - } - } - fclose(fp); - if (number < 1) { - number = 1; - } - return number; -} - -static int getCPUMaxFreqKHz(int cpuID) { - char path[256]; - sprintf(path, "/sys/devices/system/cpu/cpufreq/stats/cpu%d/time_in_state", cpuID); - FILE* fp = fopen(path, "rb"); - if (!fp) { - sprintf(path, "/sys/devices/system/cpu/cpu%d/cpufreq/stats/time_in_state", cpuID); - fp = fopen(path, "rb"); - if (!fp) { - sprintf(path, "/sys/devices/system/cpu/cpu%d/cpufreq/cpuinfo_max_freq", cpuID); - fp = fopen(path, "rb"); - if (!fp) { - return -1; - } - int maxfrequency = -1; - fscanf(fp, "%d", &maxfrequency); - fclose(fp); - return maxfrequency; - } - } - int maxfrequency = 0; - while (!feof(fp)) { - int frequency = 0; - int history = fscanf(fp, "%d %*d", &frequency); - if (history != 1) { - break; - } - if (frequency > maxfrequency) { - maxfrequency = frequency; - } - } - fclose(fp); - return maxfrequency; -} - -static std::vector sortCPUIDByMaxFrequency(int maxNumbers) { - const int cpuNumbers = getNumberOfCPU(); - if (cpuNumbers == 0) { - return {}; - } - std::vector cpuIDs; - std::vector> cpusFrequency; - cpusFrequency.resize(cpuNumbers); - for (int i = 0; i < cpuNumbers; ++i) { - int frequency = getCPUMaxFreqKHz(i); - cpusFrequency[i].first = frequency; - cpusFrequency[i].second = i; - } - maxNumbers = std::min(maxNumbers, cpuNumbers); - std::sort(cpusFrequency.rbegin(), cpusFrequency.rend()); - cpuIDs.resize(maxNumbers); - for (int i = 0; i < maxNumbers; ++i) { - cpuIDs[i] = cpusFrequency[i].second; - } - // FUNC_PRINT(cpusFrequency[0].first); - return cpuIDs; -} - -static int setSchedAffinity(const std::vector& cpuIDs) { -#define __NCPUBITS (8 * sizeof(unsigned long)) - typedef struct { - unsigned long __bits[CPU_SETSIZE / __NCPUBITS]; - } cpu_set_t; - - // set affinity for thread - - pid_t pid = gettid(); - cpu_set_t mask; - CPU_ZERO(&mask); - for (int i = 1; i < (int)cpuIDs.size(); i++) { - CPU_SET(cpuIDs[i], &mask); - } - - int syscallret = syscall(__NR_sched_setaffinity, pid, sizeof(mask), &mask); - if (syscallret) { - MNN_PRINT("syscall error %d\n", syscallret); - return -1; - } - - return 0; -} -#endif // arch ThreadPool::ThreadPool(int numberThread) { mNumberThread = numberThread; - mActiveCount = 0; + mActiveCount.resize(numberThread); + for (int i=0; i sortedCPUIDs = sortCPUIDByMaxFrequency(numberThread); -#endif for (int i = 1; i < mNumberThread; ++i) { int threadIndex = i; -#ifdef 
MNN_THREAD_LOCK_CPU - mWorkers.emplace_back([this, sortedCPUIDs, threadIndex]() { -#else mWorkers.emplace_back([this, threadIndex]() { -#endif -#ifdef MNN_THREAD_LOCK_CPU - int res = setSchedAffinity(sortedCPUIDs); -#endif while (!mStop) { - while (mActiveCount > 0) { + while (*mActiveCount[threadIndex] > 0) { for (int i = 0; i < MNN_THREAD_POOL_MAX_TASKS; ++i) { if (*mTasks[i].second[threadIndex]) { mTasks[i].first.first(threadIndex); @@ -186,7 +65,7 @@ ThreadPool::ThreadPool(int numberThread) { std::this_thread::yield(); } std::unique_lock _l(mQueueMutex); - mCondition.wait(_l, [this] { return mStop || mActiveCount > 0; }); + mCondition.wait(_l, [this, threadIndex] { return mStop || *mActiveCount[threadIndex] > 0; }); } }); } @@ -206,6 +85,9 @@ ThreadPool::~ThreadPool() { delete c; } } + for (int i=0; imTaskAvailable[index] = true; } -void ThreadPool::active() { +void ThreadPool::active(int threadNumber) { if (nullptr == gInstance) { return; } { std::lock_guard _l(gInstance->mQueueMutex); - gInstance->mActiveCount++; + for (int i=0; imActiveCount[i])++; + } } gInstance->mCondition.notify_all(); } -void ThreadPool::deactive() { +void ThreadPool::deactive(int threadNumber) { if (nullptr == gInstance) { return; } - gInstance->mActiveCount--; + for (int i=0; imActiveCount[i])--; + } } -void ThreadPool::enqueue(TASK&& task, int index) { +void ThreadPool::enqueue(TASK&& task, int index, int threadNumber) { if (1 >= task.second || 0 > index) { for (int i = 0; i < task.second; ++i) { task.first(i); @@ -257,25 +143,24 @@ void ThreadPool::enqueue(TASK&& task, int index) { return; } MNN_ASSERT(nullptr != gInstance); - gInstance->enqueueInternal(std::move(task), index); + gInstance->enqueueInternal(std::move(task), index, threadNumber); } -void ThreadPool::enqueueInternal(TASK&& task, int index) { - if (mActiveCount == 0) { +void ThreadPool::enqueueInternal(TASK&& task, int index, int threadNumber) { + if (threadNumber <= 1) { for (int i = 0; i < task.second; ++i) { task.first(i); } return; } int workSize = task.second; - if (workSize > mNumberThread) { + if (workSize > threadNumber) { mTasks[index].first = std::make_pair( - [workSize, &task, this](int tId) { - for (int v = tId; v < workSize; v += mNumberThread) { + [workSize, &task, threadNumber, this](int tId) { + for (int v = tId; v < workSize; v += threadNumber) { task.first(v); } - }, - mNumberThread); - workSize = mNumberThread; + },threadNumber); + workSize = threadNumber; } else { mTasks[index].first = std::move(task); } diff --git a/source/backend/cpu/ThreadPool.hpp b/source/backend/cpu/ThreadPool.hpp index c491338b5..f93ee9cf8 100644 --- a/source/backend/cpu/ThreadPool.hpp +++ b/source/backend/cpu/ThreadPool.hpp @@ -25,10 +25,10 @@ class MNN_PUBLIC ThreadPool { int number() const { return mNumberThread; } - static void enqueue(TASK&& task, int index); + static void enqueue(TASK&& task, int index, int threadNumber); - static void active(); - static void deactive(); + static void active(int threadNumber); + static void deactive(int threadNumber); static int acquireWorkIndex(); static void releaseWorkIndex(int index); @@ -37,7 +37,7 @@ class MNN_PUBLIC ThreadPool { static void destroy(); private: - void enqueueInternal(TASK&& task, int index); + void enqueueInternal(TASK&& task, int index, int threadNumber); static ThreadPool* gInstance; ThreadPool(int number = 0); @@ -52,7 +52,7 @@ class MNN_PUBLIC ThreadPool { std::mutex mQueueMutex; int mNumberThread = 0; - std::atomic_int mActiveCount = {0}; + std::vector mActiveCount; }; } // namespace 
MNN #endif diff --git a/source/backend/cpu/arm/CommonNeonBF16.cpp b/source/backend/cpu/arm/CommonNeonBF16.cpp deleted file mode 100644 index abb1bb1be..000000000 --- a/source/backend/cpu/arm/CommonNeonBF16.cpp +++ /dev/null @@ -1,187 +0,0 @@ - - -#if defined(MNN_SUPPORT_BF16) // CmakeList.txt does not work for ios, this file has to be self-filted, MNN.podspec doesnot filter this. - -#include "core/Macro.h" -#include "../compute/CommonOptFunction.h" -#include "./FunctionSummary.hpp" - -// todo: search for proper value for bf16 -void NEON_MNNGetMatMulPackMode_BF16(int* eP, int* lP, int* hP) { - *eP = 12; - *lP = 1; -#ifdef __aarch64__ - *hP = 8; -#else - *hP = 4; -#endif -} - -#ifdef __aarch64__ -#define EP 12 -#define HP 8 -#define LP 4 -void ARMV86_MNNGetMatMulPackMode_BF16(int* eP, int* lP, int* hP) { - *eP = EP; - *hP = HP; - *lP = LP; -} -void ARMV86_MNNPackForMatMul_B_BF16(float* destF, const float* sourceF, size_t h, size_t l, bool transpose) { - // [l, h] -> [h/hp, l/lp, hp, lp] - auto dest = (int16_t*)destF; - auto source = (const int16_t*)sourceF; - auto lCP = UP_DIV(l, LP); - auto hCP = UP_DIV(h, HP); - int sYstride = 1; - int sXstride = h; - if (transpose) { - sYstride = l; - sXstride = 1; - } - ::memset(dest, 0, lCP * hCP * sizeof(int16_t) * HP * LP); - for (int y = 0; y < h; ++y) { - int yC = y / HP; - int yR = y % HP; - for (int x = 0; x < l; ++x) { - int xC = x / LP; - int xR = x % LP; - dest[xR + yR * LP + xC * HP * LP + yC * HP * LP * lCP] = source[sXstride * x + sYstride * y]; - } - } -} -void ARMV86_MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) { - // [l/4, e, 4] -> [l/4, ep, 4] - int number = info[0]; - int eReal = info[1]; - int eDest = info[2]; - int offset = info[3]; - if (1 == number) { - int l = el[1]; - if (l % 8 != 0) { - auto lAigin = UP_DIV(l, LP) * LP; - ::memset(destOrigin, 0, eDest * lAigin * sizeof(int16_t)); - } - } - - for (int n=0; n [l/4, ep, 4] - for (int x = 0; x < lDiv; ++x) { - auto destX = (int64_t*)(dest + x * eDest * 4); - auto srcX = (int64_t*)(source + x * eReal * 4); - for (int y = 0; y < e; ++y) { - destX[y] = srcX[y * offset]; - } - } - continue; - } - for (int x = 0; x < l; ++x) { - auto dl = lOR + x; - auto dlC = dl / LP; - auto dlR = dl % LP; - auto xC = x / LP; - auto xR = x % LP; - auto destX = dest + dlC * eDest * LP + dlR; - auto srcX = source + xC * eReal * LP + xR; - for (int y = 0; y < e; ++y) { - destX[y * 4] = srcX[y * 4 * offset]; - } - } - } -} -#undef EP -#undef HP -#undef LP -void NEON_MNNPackForMatMul_B_BF16(float* destFloat, const float* sourceFloat, size_t h, size_t l, bool transpose) { - auto hP = (int)h / 8; - auto hR = (int)hP * 8; - int16_t* dest = (int16_t*)destFloat; - int16_t* source = (int16_t*)sourceFloat; - if (hR != h) { - ::memset(dest, 0, UP_DIV(h, 8) * 8 * l * sizeof(int16_t)); - } - if (!transpose) { - for (int y = 0; y < hP; ++y) { - auto destY = dest + y * 8 * l; - auto sourceY = source + y * 8; - for (int x = 0; x < l; ++x) { - ::memcpy(destY + 8 * x, sourceY + x * h, 8 * sizeof(int16_t)); - } - } - auto hRemain = h - hR; - if (hRemain > 0) { - auto destY = dest + hP * 8 * l; - auto sourceY = source + hP * 8; - for (int x = 0; x < l; ++x) { - ::memcpy(destY + 8 * x, sourceY + x * h, hRemain * sizeof(int16_t)); - } - } - return; - } - int lC8 = (int)l / 8; - auto lR = lC8 * 8; - if (hP > 0 && lC8 > 0) { - MNNPackC8_BF16(destFloat, sourceFloat, l, h); - } - for (int y = hR; y < h; ++y) { - auto yR = y % 8; - auto yC = hP; - for (int x = 0; x < 
l; ++x) { - dest[x * 8 + yR + yC * 8 * l] = source[x + y * l]; - } - } - for (int y = 0; y < hR; ++y) { - auto yR = y % 8; - auto yC = y / 8; - for (int x = lR; x < l; ++x) { - dest[x * 8 + yR + yC * 8 * l] = source[x + y * l]; - } - } -} - -#else -void NEON_MNNPackForMatMul_B_BF16(float* destFloat, const float* sourceFloat, size_t h, size_t l, bool transpose) { - int16_t* dest = (int16_t*)destFloat; - int16_t* source = (int16_t*)sourceFloat; - if (!transpose) { - auto hP = h / 4; - auto hR = hP * 4; - if (hR != h) { - ::memset(dest, 0, UP_DIV(h, 4) * 4 * l * sizeof(int16_t)); - } - for (int y = 0; y < hP; ++y) { - auto destY = dest + y * 4 * l; - auto sourceY = source + y * 4; - for (int x = 0; x < l; ++x) { - ::memcpy(destY + 4 * x, sourceY + x * h, 4 * sizeof(int16_t)); - } - } - auto hRemain = h - hR; - if (hRemain > 0) { - auto destY = dest + hP * 4 * l; - auto sourceY = source + hP * 4; - for (int x = 0; x < l; ++x) { - ::memcpy(destY + 4 * x, sourceY + x * h, hRemain * sizeof(int16_t)); - } - } - return; - } - int offset[2] = { - (int)l, - (int)l, - }; - MNNPackC4_BF16(destFloat, sourceFloat, l, h, offset); -} -#endif // __aarch64__ -#endif // MNN_SUPPORT_BF16 - diff --git a/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_Unit.S b/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_Unit.S index cdcd1226f..72ff71423 100644 --- a/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_Unit.S +++ b/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_Unit.S @@ -15,14 +15,24 @@ .align 5 asm_function MNNGemmInt8AddBiasScale_16x4_Unit - -//struct QuanPostTreatParameters { -// const float* scale; -// const int32_t* bias; -// int32_t maxValue; -// int32_t minValue; -// int32_t useInt8; -//}; +/* +struct QuanPostTreatParameters { + const float* scale; + const float* biasFloat; + int32_t maxValue; + int32_t minValue; + int32_t useInt8 = 1; // Save result as int8_t dataType; otherwise float32. + float roundValuePos = 0.5f; + float roundValueNeg = -0.5f; + float* srcKernelSum; + float* weightQuanBias; + float* fp32minmax; + ssize_t blockNum = 1; + const int32_t* bias; + const float* extraScale = nullptr; + const float* extraBias = nullptr; +}; +*/ //void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, // size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t real) { @@ -42,23 +52,25 @@ ldr lr, [r6, #4] vpush {q4-q7} -// Branch1: input is int8_t, output is float32, DO NOT USE "scale". -// Branch2: input is int8_t, output is float32. USE "scale", DO NOT USE "minValue" and "maxValue". -// Branch3: input is int8_t, output is int8_t. USE "scale", "minValue" and "maxValue". - ldr r7, [r6, #16] // r7: useInt8 -cmp r7, #1 -beq InitBranch3 - -InitBranch2: -mov r7, #-0x80 // Branch2 do not use "minValue", so set r7 as a flag to decide the branch. 
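Editor's note: the stack setup above hard-codes byte offsets into QuanPostTreatParameters (e.g. `ldr r12, [r6, #28]` for srcKernelSum). With 4-byte pointers on 32-bit ARM and the field order given in the struct comment, those offsets line up as below. This is a compile-time sanity sketch for the armv7 target only, not patch code (ssize_t is 4 bytes on that target):

    #include <cstddef>
    #include <cstdint>
    #include <sys/types.h>   // ssize_t

    struct QuanPostTreatParameters {
        const float* scale;                 // [r6, #0]
        const float* biasFloat;             // [r6, #4]
        int32_t maxValue;                   // [r6, #8]
        int32_t minValue;                   // [r6, #12]
        int32_t useInt8 = 1;                // [r6, #16]
        float roundValuePos = 0.5f;         // [r6, #20]
        float roundValueNeg = -0.5f;        // [r6, #24]
        float* srcKernelSum;                // [r6, #28]
        float* weightQuanBias;              // [r6, #32]
        float* fp32minmax;                  // [r6, #36]
        ssize_t blockNum = 1;               // [r6, #40]
        const int32_t* bias;                // [r6, #44]
        const float* extraScale = nullptr;  // [r6, #48]
        const float* extraBias = nullptr;   // [r6, #52]
    };
    static_assert(offsetof(QuanPostTreatParameters, srcKernelSum) == 28,
                  "32-bit layout assumed by MNNGemmInt8AddBiasScale_16x4_Unit");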
-b Start -InitBranch3: -ldr r7, [r6, #8] -ldr r6, [r6, #12] -b Start +ldr r12, [r6, #28] // srcKernelSum +str r12, [sp, #4] +ldr r12, [r6, #32] // weightBias +str r12, [sp, #8] +ldr r12, [r6, #36] // f32minmax +str r12, [sp, #12] +ldr r12, [r6, #8] // int8 max +str r12, [sp, #16] +ldr r12, [r6, #12] // int8 min +str r12, [sp, #20] +ldr r12, [r6, #40] // blockNum +mul r12, r12, r3 // src_depth_quad=src_depth_quad*blockNum +lsl r12, r12, #6 // weight_stride = src_depth_quad*LP*HP +str r12, [sp, #24] +ldr r12, [r6, #48] // extraScale +str r12, [sp, #28] Start: cmp r10, #2 @@ -66,7 +78,7 @@ blt L1LoopDz L2LoopDz: mov r10, r1 - + str r2, [sp, #32] // store weight ptr subs r12, r3, #1 // first four output vld1.8 {q2}, [r1]! @@ -143,8 +155,7 @@ L2LoopDz: L2LoopSzEnd: L2Quan: - vld1.s32 {q4}, [lr]! - vld1.f32 {q5}, [r8]! + vld1.f32 {q5}, [r8]! // scale vpadd.s32 d16, d16, d17 vpadd.s32 d20, d20, d21 @@ -157,31 +168,85 @@ L2LoopDz: vpadd.s32 d30, d30, d31 // q8,q9 - vdup.32 q2, r6 - vdup.32 q3, r7 + vpadd.s32 d16, d16, d18 vpadd.s32 d17, d20, d22 vpadd.s32 d18, d24, d26 vpadd.s32 d19, d28, d30 - vaddq.s32 q0, q8, q4 - vaddq.s32 q1, q9, q4 - - vcvt.f32.s32 q0, q0 - vcvt.f32.s32 q1, q1 + // vaddq.s32 q0, q8, q4 // add bias + // vaddq.s32 q1, q9, q4 - vmov.f32 q10, #0.5 - vmov.f32 q11, #-0.5 + vcvt.f32.s32 q0, q8 + vcvt.f32.s32 q1, q9 - vmulq.f32 q0, q0, q5 + vmulq.f32 q0, q0, q5 // mul scale vmulq.f32 q1, q1, q5 - cmp r7, #-0x80 + // extra scale if has + ldr r6, [sp, #28] + cmp r6, #0 + beq L2_MLA + vld1.f32 {d10[0]}, [r6]! // tile0 + vld1.f32 {d10[1]}, [r6] // tile1 + vmulq.f32 q0, q0, d10[0] + vmulq.f32 q1, q1, d10[1] + + L2_MLA: + ldr r6, [sp, #4] // srcKernelSum + vld1.f32 {d12[0]}, [r6]! // tile 0 + vld1.f32 {d12[1]}, [r6] // tile 1 + ldr r6, [sp, #8] // weightBias + vld1.f32 {q7}, [r6]! + str r6, [sp, #8] // update next 4 weightBias + + vmla.f32 q0, q7, d12[0] + vmla.f32 q1, q7, d12[1] + + cmp r7, #0 bne L2QuanUseInt8 + + L2_ADD_BIAS: + cmp lr, #0 + beq L2_ADD_DSTV + vld1.f32 {q4}, [lr]! // bias + vadd.f32 q0, q0, q4 // bias + vadd.f32 q1, q1, q4 + b L2_POST + + L2_ADD_DSTV: + vld1.f32 {q4, q5}, [r0] + vadd.f32 q0, q0, q4 + vadd.f32 q1, q1, q5 + + L2_POST: + ldr r6, [sp, #12] // fp32 minmax + cmp r6, #0 + beq L2_STORE + vld1.f32 {d20[0]}, [r6]! + vld1.f32 {d22[0]}, [r6] + vdup.f32 q10, d20[0] + vdup.f32 q11, d22[0] + vmax.f32 q0, q0, q10 + vmax.f32 q1, q1, q10 + vmin.f32 q0, q0, q11 + vmin.f32 q1, q1, q11 + + L2_STORE: vst1.f32 {q0, q1}, [r0], r4 b L2LoopCheck L2QuanUseInt8: + vld1.f32 {q4}, [lr]! // bias + vadd.f32 q0, q0, q4 // bias + vadd.f32 q1, q1, q4 + + vmov.f32 q10, #0.5 + vmov.f32 q11, #-0.5 + ldr r6, [sp, #16] + vdup.32 q3, r6 // max + ldr r6, [sp, #20] + vdup.32 q2, r6 // min vcgt.f32 q12, q0, #0 vcgt.f32 q13, q1, #0 vbsl.f32 q12, q10, q11 @@ -201,17 +266,20 @@ L2LoopDz: vqmovn.s16 d6, q2 - vst1.s8 d6, [r0], r4 + vst1.s8 {d6}, [r0], r4 L2LoopCheck: subs r5, r5, #1 mov r1, r10 + ldr r2, [sp, #32] // origin weight ptr + ldr r6, [sp, #24] // weight stride + add r2, r2, r6 // next oc4 weight ptr bne L2LoopDz b End L1LoopDz: mov r10, r1 - + str r2, [sp, #32] // store weight ptr subs r12, r3, #1 // first four output vld1.8 {q2}, [r1]! @@ -259,35 +327,74 @@ L1LoopDz: L1LoopSzEnd: L1Quan: - vld1.s32 {q4}, [lr]! - vld1.f32 {q5}, [r8]! + //vld1.f32 {q4}, [lr]! // bias + vld1.f32 {q5}, [r8]! 
// scale vpadd.s32 d16, d16, d17 vpadd.s32 d20, d20, d21 vpadd.s32 d18, d18, d19 vpadd.s32 d22, d22, d23 - // q8,q9 - vdup.32 q2, r6 - vdup.32 q3, r7 + // q8 vpadd.s32 d16, d16, d18 vpadd.s32 d17, d20, d22 - vaddq.s32 q0, q8, q4 - - vcvt.f32.s32 q0, q0 - - vmov.f32 q10, #0.5 - vmov.f32 q11, #-0.5 - + // vaddq.s32 q0, q8, q4 + vcvt.f32.s32 q0, q8 vmulq.f32 q0, q0, q5 - - cmp r7, #-0x80 + // extra scale if has + ldr r6, [sp, #28] + cmp r6, #0 + beq L1_MLA + vld1.f32 {d10[0]}, [r6] // tile0 + vmulq.f32 q0, q0, d10[0] + + L1_MLA: + ldr r6, [sp, #4] // srcKernelSum + vld1.f32 {d12[0]}, [r6] // tile 0 + ldr r6, [sp, #8] // weightBias + vld1.f32 {q7}, [r6]! + str r6, [sp, #8] // update next 4 weightBias + vmla.f32 q0, q7, d12[0] + //vadd.f32 q0, q0, q4 + + cmp r7, #0 bne L1QuanUseInt8 + + cmp lr, #0 + beq L1_ADD_DSTV + vld1.f32 {q4}, [lr]! // bias + vadd.f32 q0, q0, q4 + b L1_POST + + L1_ADD_DSTV: + vld1.f32 {q4}, [r0] + vadd.f32 q0, q0, q4 + + L1_POST: + ldr r6, [sp, #12] // fp32 minmax + cmp r6, #0 + beq L1_STORE + + vld1.f32 {d20[0]}, [r6]! + vld1.f32 {d22[0]}, [r6] + vdup.f32 q10, d20[0] + vdup.f32 q11, d22[0] + vmax.f32 q0, q0, q10 + vmin.f32 q0, q0, q11 + L1_STORE: vst1.f32 {q0}, [r0], r4 b L1LoopCheck L1QuanUseInt8: + vld1.f32 {q4}, [lr]! // bias + vadd.f32 q0, q0, q4 + vmov.f32 q10, #0.5 + vmov.f32 q11, #-0.5 + ldr r6, [sp, #16] + vdup.32 q3, r6 // max + ldr r6, [sp, #20] + vdup.32 q2, r6 // min vcgt.f32 q12, q0, #0 vbsl.f32 q12, q10, q11 vbsl.f32 q13, q10, q11 @@ -301,10 +408,13 @@ L1LoopDz: vqmovn.s16 d6, q2 - vst1.s32 d6[0], [r0], r4 + vst1.s32 {d6[0]}, [r0], r4 L1LoopCheck: subs r5, r5, #1 mov r1, r10 + ldr r2, [sp, #32] // origin weight ptr + ldr r6, [sp, #24] // weight stride + add r2, r2, r6 // next oc4 weight ptr bne L1LoopDz End: diff --git a/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S b/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S index a77529575..25c9e5359 100644 --- a/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S +++ b/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S @@ -43,22 +43,18 @@ ldr lr, [r6, #4] vpush {q4-q7} -// Branch1: input is int8_t, output is float32, DO NOT USE "scale". -// Branch2: input is int8_t, output is float32. USE "scale", DO NOT USE "minValue" and "maxValue". -// Branch3: input is int8_t, output is int8_t. USE "scale", "minValue" and "maxValue". - -ldr r7, [r6, #16] // r7: useInt8 -cmp r7, #1 -beq InitBranch3 - -InitBranch2: -mov r7, #-0x80 // Branch2 do not use "minValue", so set r7 as a flag to decide the branch. -b Start - -InitBranch3: -ldr r7, [r6, #8] -ldr r6, [r6, #12] -b Start +// Only int8 output use this kernel. + +ldr r12, [r6, #28] // srcKernelSum +str r12, [sp, #4] +ldr r12, [r6, #32] // weightBias +str r12, [sp, #8] +ldr r12, [r6, #36] // f32minmax +str r12, [sp, #12] +ldr r12, [r6, #8] // int8 max +str r12, [sp, #16] +ldr r12, [r6, #12] // int8 min +str r12, [sp, #20] Start: cmp r10, #2 @@ -132,10 +128,11 @@ L2LoopDz: vpaddl.s16 q7, q15 L2Quan: - vld1.s32 {q14}, [lr]! + vld1.f32 {q14}, [lr]! // bias + vld1.f32 {q15}, [r8]! // scale vpadd.s32 d20, d0, d1 vpadd.s32 d21, d2, d3 - vld1.f32 {q15}, [r8]! 
+ vpadd.s32 d22, d4, d5 vpadd.s32 d23, d6, d7 vpadd.s32 d24, d8, d9 @@ -149,24 +146,35 @@ L2LoopDz: vpadd.s32 d18, d24, d25 vpadd.s32 d19, d26, d27 - vaddq.s32 q0, q8, q14 - vaddq.s32 q1, q9, q14 + //vaddq.s32 q0, q8, q14 // add bias + //vaddq.s32 q1, q9, q14 - vcvt.f32.s32 q0, q0 - vcvt.f32.s32 q1, q1 - vmulq.f32 q0, q0, q15 + vcvt.f32.s32 q0, q8 + vcvt.f32.s32 q1, q9 + vmulq.f32 q0, q0, q15 // mul scale vmulq.f32 q1, q1, q15 - cmp r7, #-0x80 - bne L2QuanUseInt8 - vst1.f32 {q0, q1}, [r0], r4 - b L2LoopCheck + ldr r6, [sp, #4] // srcKernelSum + vld1.f32 {d12[0]}, [r6]! // tile 0 + vld1.f32 {d12[1]}, [r6] // tile 1 + ldr r6, [sp, #8] // weightBias + vld1.f32 {q7}, [r6]! + str r6, [sp, #8] // update next 4 weightBias + + vmla.f32 q0, q7, d12[0] // add srcKernelSum x weightBias + vmla.f32 q1, q7, d12[1] + + vadd.f32 q0, q0, q14 // add bias + vadd.f32 q1, q1, q14 + L2QuanUseInt8: vmov.f32 q10, #0.5 vmov.f32 q11, #-0.5 - vdup.32 q2, r6 - vdup.32 q3, r7 + ldr r6, [sp, #16] + vdup.32 q2, r6 // max + ldr r6, [sp, #20] + vdup.32 q3, r6 // min vcgt.f32 q12, q0, #0 vcgt.f32 q13, q1, #0 @@ -177,10 +185,10 @@ L2LoopDz: vcvt.s32.f32 q0, q0 vcvt.s32.f32 q1, q1 - vmax.s32 q0, q2, q0 - vmax.s32 q1, q2, q1 - vmin.s32 q0, q3, q0 - vmin.s32 q1, q3, q1 + vmin.s32 q0, q2, q0 + vmin.s32 q1, q2, q1 + vmax.s32 q0, q3, q0 + vmax.s32 q1, q3, q1 vqmovn.s32 d4, q0 vqmovn.s32 d5, q1 @@ -242,7 +250,7 @@ L1LoopDz: vpaddl.s16 q3, q11 L1Quan: - vld1.s32 {q14}, [lr]! + vld1.f32 {q14}, [lr]! vpadd.s32 d20, d0, d1 vpadd.s32 d21, d2, d3 vld1.f32 {q15}, [r8]! @@ -253,21 +261,26 @@ L1LoopDz: vpadd.s32 d16, d20, d21 vpadd.s32 d17, d22, d23 - vaddq.s32 q0, q8, q14 + //vaddq.s32 q0, q8, q14 - vcvt.f32.s32 q0, q0 - vdup.32 q2, r6 - vdup.32 q3, r7 + vcvt.f32.s32 q0, q8 vmulq.f32 q0, q0, q15 - cmp r7, #-0x80 - bne L1QuanUseInt8 - vst1.f32 {q0, q1}, [r0], r4 - b L1LoopCheck + + ldr r6, [sp, #4] // srcKernelSum + vld1.f32 {d12[0]}, [r6] // tile 0 + ldr r6, [sp, #8] // weightBias + vld1.f32 {q7}, [r6]! + str r6, [sp, #8] // update next 4 weightBias + vmla.f32 q0, q7, d12[0] + vadd.f32 q0, q0, q14 // add bias L1QuanUseInt8: vmov.f32 q10, #0.5 vmov.f32 q11, #-0.5 - + ldr r6, [sp, #16] + vdup.32 q3, r6 // max + ldr r6, [sp, #20] + vdup.32 q2, r6 // min vcgt.f32 q12, q0, #0 vbsl.f32 q12, q10, q11 vbsl.f32 q13, q10, q11 diff --git a/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_w4_Unit.S b/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_w4_Unit.S new file mode 100644 index 000000000..6368937de --- /dev/null +++ b/source/backend/cpu/arm/arm32/MNNGemmInt8AddBiasScale_16x4_w4_Unit.S @@ -0,0 +1,392 @@ +// +// MNNGemmInt8AddBiasScale_16x4_w4_Unit.S +// MNN +// +// Created by MNN on 2019/06/11. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __arm__ +#ifndef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +asm_function MNNGemmInt8AddBiasScale_16x4_w4_Unit +/* +struct QuanPostTreatParameters { + const float* scale; + const float* biasFloat; + int32_t maxValue; + int32_t minValue; + int32_t useInt8 = 1; // Save result as int8_t dataType; otherwise float32. 
+ float roundValuePos = 0.5f; + float roundValueNeg = -0.5f; + float* srcKernelSum; + float* weightQuanBias; + float* fp32minmax; + ssize_t blockNum = 1; + const int32_t* bias; + const float* extraScale = nullptr; + const float* extraBias = nullptr; +}; +*/ + +//void MNNGemmInt8AddBiasScale_16x4_w4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, +// size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t real) { + +//Auto: r0: dst*, r1: src*, r2:weight*, r3: src_depth_quad +// Load from sp: r4: dst_step, r5: dst_depth_quad, r6: post, r10: real +// Load from post: r8: scale, lr: bias, r7: maxValue, r6: minValue + +push {r4-r8, r10, lr} // avoid to touch platform-register r-9 + +ldr r4, [sp, #28] +ldr r5, [sp, #32] +ldr r6, [sp, #36] +ldr r10, [sp, #40] +ldr r8, [r6, #0] +ldr lr, [r6, #4] + +vpush {q4-q7} + +// Branch1: input is int8_t, output is float32, DO NOT USE "scale". +// Branch2: input is int8_t, output is float32. USE "scale", DO NOT USE "minValue" and "maxValue". +// Branch3: input is int8_t, output is int8_t. USE "scale", "minValue" and "maxValue". + + +ldr r7, [r6, #16] // r7: useInt8 + +ldr r12, [r6, #28] // srcKernelSum +str r12, [sp, #4] +ldr r12, [r6, #32] // weightBias +str r12, [sp, #8] +ldr r12, [r6, #36] // f32minmax +str r12, [sp, #12] +ldr r12, [r6, #40] // blockNum +mul r12, r12, r3 // src_depth_quad=src_depth_quad*blockNum +lsl r12, r12, #6 // weight_stride = src_depth_quad*LP*HP +str r12, [sp, #16] +ldr r12, [r6, #48] // extraScale +str r12, [sp, #20] + +Start: +cmp r10, #2 +blt L1LoopDz + +L2LoopDz: + mov r10, r1 + str r2, [sp, #24] // store weight ptr + subs r12, r3, #1 + // first four output + vld1.8 {q2}, [r1]! + vld1.8 {q4}, [r2]! // weight, d8,d9,d10,d11 + // int4->int8 + vmov.i8 q5, #15 + vand.i8 q5, q5, q4 + vshr.u8 q4, q4, #4 + vzip.8 q4, q5 + + vmull.s8 q0, d4, d8 + vmull.s8 q1, d4, d10 + vmlal.s8 q0, d5, d9 + vmlal.s8 q1, d5, d11 + vpaddl.s16 q8, q0 + vpaddl.s16 q9, q1 + vld1.8 {q6}, [r2]! // weight,d12,d13,d14,d15 + // int4->int8 + vmov.i8 q7, #15 + vand.i8 q7, q7, q6 + vshr.u8 q6, q6, #4 + vzip.8 q6, q7 + + vmull.s8 q0, d4, d12 + vmull.s8 q1, d4, d14 + vmlal.s8 q0, d5, d13 + vmlal.s8 q1, d5, d15 + vpaddl.s16 q10, q0 + vld1.8 {q3}, [r1]! + vpaddl.s16 q11, q1 + // second four output + vmull.s8 q0, d6, d8 + vmull.s8 q1, d6, d10 + vmlal.s8 q0, d7, d9 + vmlal.s8 q1, d7, d11 + vpaddl.s16 q12, q0 + vpaddl.s16 q13, q1 + + vmull.s8 q0, d6, d12 + vmull.s8 q1, d6, d14 + vmlal.s8 q0, d7, d13 + vmlal.s8 q1, d7, d15 + vpaddl.s16 q14, q0 + vpaddl.s16 q15, q1 + + beq L2LoopSzEnd + + L2LoopSz: + // first four output + vld1.8 {q2}, [r1]! + vld1.8 {q4}, [r2]! + // int4->int8 + vmov.i8 q5, #15 + vand.i8 q5, q5, q4 + vshr.u8 q4, q4, #4 + vzip.8 q4, q5 + vmull.s8 q0, d4, d8 + vmull.s8 q1, d4, d10 + vmlal.s8 q0, d5, d9 + vmlal.s8 q1, d5, d11 + vld1.8 {q6}, [r2]! + // int4->int8 + vmov.i8 q7, #15 + vand.i8 q7, q7, q6 + vshr.u8 q6, q6, #4 + vzip.8 q6, q7 + vpadal.s16 q8, q0 + vpadal.s16 q9, q1 + + vmull.s8 q0, d4, d12 + vmull.s8 q1, d4, d14 + vmlal.s8 q0, d5, d13 + vmlal.s8 q1, d5, d15 + vld1.8 {q3}, [r1]! 
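[Editor's note] In this w4 kernel each weight byte carries two 4-bit values; the vmov.i8 #15 / vand / vshr.u8 #4 / vzip.8 sequence expands them to int8 lanes (high nibble first, if I am reading the vzip interleave correctly) before the usual smull/smlal accumulation. A scalar sketch of that unpacking, assuming the nibbles are stored unsigned with the zero point corrected later via srcKernelSum * weightQuanBias:

    #include <cstdint>
    // Scalar equivalent of the vand / vshr.u8 / vzip.8 int4 -> int8 unpack used in this kernel.
    // Assumption: two unsigned 4-bit weights per byte, high nibble emitted first.
    static inline void unpackInt4Pair(uint8_t packed, int8_t out[2]) {
        out[0] = static_cast<int8_t>(packed >> 4);   // high nibble (vshr.u8 #4)
        out[1] = static_cast<int8_t>(packed & 0x0F); // low nibble  (vand with #15)
    }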
+ vpadal.s16 q10, q0 + vpadal.s16 q11, q1 + // second four output + vmull.s8 q0, d6, d8 + vmull.s8 q1, d6, d10 + vmlal.s8 q0, d7, d9 + vmlal.s8 q1, d7, d11 + vpadal.s16 q12, q0 + vpadal.s16 q13, q1 + + vmull.s8 q0, d6, d12 + vmull.s8 q1, d6, d14 + vmlal.s8 q0, d7, d13 + vmlal.s8 q1, d7, d15 + vpadal.s16 q14, q0 + vpadal.s16 q15, q1 + + subs r12, r12, #1 + bne L2LoopSz + + L2LoopSzEnd: + + L2Quan: + vld1.f32 {q5}, [r8]! // scale + + vpadd.s32 d16, d16, d17 + vpadd.s32 d20, d20, d21 + vpadd.s32 d18, d18, d19 + vpadd.s32 d22, d22, d23 + + vpadd.s32 d24, d24, d25 + vpadd.s32 d28, d28, d29 + vpadd.s32 d26, d26, d27 + vpadd.s32 d30, d30, d31 + + // q8,q9 + + vpadd.s32 d16, d16, d18 + vpadd.s32 d17, d20, d22 + vpadd.s32 d18, d24, d26 + vpadd.s32 d19, d28, d30 + + // vaddq.s32 q0, q8, q4 // add bias + // vaddq.s32 q1, q9, q4 + + vcvt.f32.s32 q0, q0 + vcvt.f32.s32 q1, q1 + + vmulq.f32 q0, q0, q5 // mul scale + vmulq.f32 q1, q1, q5 + + // extra scale if has + ldr r6, [sp, #20] + cmp r6, #0 + beq L2_MLA + vld1.f32 {d10[0]}, [r6]! // tile0 + vld1.f32 {d10[1]}, [r6] // tile1 + vmulq.f32 q0, q0, d10[0] + vmulq.f32 q1, q1, d10[1] + + L2_MLA: + ldr r6, [sp, #4] // srcKernelSum + vld1.f32 {d12[0]}, [r6]! // tile 0 + vld1.f32 {d12[1]}, [r6] // tile 1 + ldr r6, [sp, #8] // weightBias + vld1.f32 {q7}, [r6]! + str r6, [sp, #8] // update next 4 weightBias + + vmla.f32 q0, q7, d12[0] + vmla.f32 q1, q7, d12[1] + + L2_POST: + ldr r6, [sp, #12] // fp32 minmax + cmp r6, #0 + beq L2_STORE + vld1.f32 {d20[0]}, [r6]! + vld1.f32 {d22[0]}, [r6] + vdup.f32 q10, d20[0] + vdup.f32 q11, d22[0] + vmax.f32 q0, q0, q10 + vmax.f32 q1, q1, q10 + vmin.f32 q0, q0, q11 + vmin.f32 q1, q1, q11 + + L2_STORE: + vst1.f32 {q0, q1}, [r0], r4 + +L2LoopCheck: + subs r5, r5, #1 + mov r1, r10 + ldr r2, [sp, #24] // origin weight ptr + ldr r6, [sp, #16] // weight stride + add r2, r2, r6 // next oc4 weight ptr + bne L2LoopDz + +b End + +L1LoopDz: + mov r10, r1 + str r2, [sp, #24] // store weight ptr + subs r12, r3, #1 + // first four output + vld1.8 {q2}, [r1]! + vld1.8 {q4}, [r2]! + // int4->int8 + vmov.i8 q5, #15 + vand.i8 q5, q5, q4 + vshr.u8 q4, q4, #4 + vzip.8 q4, q5 + + vmull.s8 q0, d4, d8 + vmull.s8 q1, d4, d10 + vmlal.s8 q0, d5, d9 + vmlal.s8 q1, d5, d11 + vpaddl.s16 q8, q0 + vpaddl.s16 q9, q1 + vld1.8 {q6}, [r2]! + // int4->int8 + vmov.i8 q7, #15 + vand.i8 q7, q7, q6 + vshr.u8 q6, q6, #4 + vzip.8 q6, q7 + + vmull.s8 q0, d4, d12 + vmull.s8 q1, d4, d14 + vmlal.s8 q0, d5, d13 + vmlal.s8 q1, d5, d15 + vpaddl.s16 q10, q0 + add r1, r1, #16 + vpaddl.s16 q11, q1 + + beq L1LoopSzEnd + + L1LoopSz: + // first four output + vld1.8 {q2}, [r1]! + vld1.8 {q4}, [r2]! + // int4->int8 + vmov.i8 q5, #15 + vand.i8 q5, q5, q4 + vshr.u8 q4, q4, #4 + vzip.8 q4, q5 + vmull.s8 q0, d4, d8 + vmull.s8 q1, d4, d10 + vmlal.s8 q0, d5, d9 + vmlal.s8 q1, d5, d11 + vld1.8 {q6}, [r2]! + // int4->int8 + vmov.i8 q7, #15 + vand.i8 q7, q7, q6 + vshr.u8 q6, q6, #4 + vzip.8 q6, q7 + vpadal.s16 q8, q0 + vpadal.s16 q9, q1 + + vmull.s8 q0, d4, d12 + vmull.s8 q1, d4, d14 + vmlal.s8 q0, d5, d13 + vmlal.s8 q1, d5, d15 + add r1, r1, #16 + vpadal.s16 q10, q0 + vpadal.s16 q11, q1 + + subs r12, r12, #1 + bne L1LoopSz + + L1LoopSzEnd: + L1Quan: + //vld1.f32 {q4}, [lr]! // bias + vld1.f32 {q5}, [r8]! 
// scale + + vpadd.s32 d16, d16, d17 + vpadd.s32 d20, d20, d21 + vpadd.s32 d18, d18, d19 + vpadd.s32 d22, d22, d23 + + // q8 + vpadd.s32 d16, d16, d18 + vpadd.s32 d17, d20, d22 + + // vaddq.s32 q0, q8, q4 + vcvt.f32.s32 q0, q0 + vmulq.f32 q0, q0, q5 + // extra scale if has + ldr r6, [sp, #20] + cmp r6, #0 + beq L1_MLA + vld1.f32 {d10[0]}, [r6] // tile0 + vmulq.f32 q0, q0, d10[0] + + L1_MLA: + ldr r6, [sp, #4] // srcKernelSum + vld1.f32 {d12[0]}, [r6] // tile 0 + ldr r6, [sp, #8] // weightBias + vld1.f32 {q7}, [r6]! + str r6, [sp, #8] // update next 4 weightBias + vmla.f32 q0, q7, d12[0] + //vadd.f32 q0, q0, q4 + + cmp lr, #0 + beq L1_ADD_DSTV + vld1.f32 {q4}, [lr]! // bias + vadd.f32 q0, q0, q4 + b L1_POST + + L1_ADD_DSTV: + vld1.f32 {q4}, [r0] + vadd.f32 q0, q0, q4 + + L1_POST: + ldr r6, [sp, #12] // fp32 minmax + cmp r6, #0 + beq L1_STORE + + vld1.f32 {d20[0]}, [r6]! + vld1.f32 {d22[0]}, [r6] + vdup.f32 q10, d20[0] + vdup.f32 q11, d22[0] + vmax.f32 q0, q0, q10 + vmin.f32 q0, q0, q11 + L1_STORE: + vst1.f32 {q0}, [r0], r4 + +L1LoopCheck: + subs r5, r5, #1 + mov r1, r10 + ldr r2, [sp, #24] // origin weight ptr + ldr r6, [sp, #16] // weight stride + add r2, r2, r6 // next oc4 weight ptr + bne L1LoopDz + +End: +vpop {q4-q7} +pop {r4-r8, r10, pc} + +#endif +#endif diff --git a/source/backend/cpu/arm/arm32/MNNLineDepthWiseInt8AddBiasScaleUnit.S b/source/backend/cpu/arm/arm32/MNNLineDepthWiseInt8AddBiasScaleUnit.S index e905a3703..55460f637 100644 --- a/source/backend/cpu/arm/arm32/MNNLineDepthWiseInt8AddBiasScaleUnit.S +++ b/source/backend/cpu/arm/arm32/MNNLineDepthWiseInt8AddBiasScaleUnit.S @@ -54,7 +54,7 @@ ldr r11, [r3, #8] vdup.i8 d23, r11 ldr r11, [r3, #12] vdup.i8 d22, r11 -ldr r3, [r3, #4] +ldr r3, [r3, #44] // bias mul r10, r6, r8 sub lr, lr, r10 diff --git a/source/backend/cpu/arm/arm32/MNNPackedSparseQuantMatMulEpx1.S b/source/backend/cpu/arm/arm32/MNNPackedSparseQuantMatMulEpx1.S index 9297eb959..f9e220cb5 100644 --- a/source/backend/cpu/arm/arm32/MNNPackedSparseQuantMatMulEpx1.S +++ b/source/backend/cpu/arm/arm32/MNNPackedSparseQuantMatMulEpx1.S @@ -7,6 +7,24 @@ // // +/* +struct QuanPostTreatParameters { + const float* scale; + const float* biasFloat; + int32_t maxValue; + int32_t minValue; + int32_t useInt8 = 1; // Save result as int8_t dataType; otherwise float32. 
+ float roundValuePos = 0.5f; + float roundValueNeg = -0.5f; + float* srcKernelSum; + float* weightQuanBias; + float* fp32minmax; + ssize_t blockNum = 1; + const int32_t* bias; + +}; + */ + #ifdef __arm__ #ifndef __aarch64__ @@ -58,7 +76,7 @@ loop_e8: ldr r5, [sp, #(push_registers_bytes + 4)] ldr r6, [sp, #(push_registers_bytes + 8)] ldr r7, [r4] - ldr r8, [r4, #4] + ldr r8, [r4, #44] push {r0-r2, r10} ldr lr, [r6], #4 // dataOffset add r1, r1, lr @@ -135,7 +153,7 @@ loop_e4: ldr r5, [sp, #(push_registers_bytes + 4)] ldr r6, [sp, #(push_registers_bytes + 8)] ldr r7, [r4] - ldr r8, [r4, #4] + ldr r8, [r4, #44] push {r0-r2, r10} ldr lr, [r6], #4 // dataOffset add r1, r1, lr @@ -196,7 +214,7 @@ loop_e2: ldr r5, [sp, #(push_registers_bytes + 4)] ldr r6, [sp, #(push_registers_bytes + 8)] ldr r7, [r4] - ldr r8, [r4, #4] + ldr r8, [r4, #44] push {r0-r2, r10} ldr lr, [r6], #4 // dataOffset add r1, r1, lr @@ -255,7 +273,7 @@ loop_e1: ldr r5, [sp, #(push_registers_bytes + 4)] ldr r6, [sp, #(push_registers_bytes + 8)] ldr r7, [r4] - ldr r8, [r4, #4] + ldr r8, [r4, #44] push {r0-r2, r10} ldr lr, [r6], #4 // dataOffset diff --git a/source/backend/cpu/arm/arm32/MNNPackedSparseQuantMatMulEpx4.S b/source/backend/cpu/arm/arm32/MNNPackedSparseQuantMatMulEpx4.S index c6a13b39f..01ce74082 100644 --- a/source/backend/cpu/arm/arm32/MNNPackedSparseQuantMatMulEpx4.S +++ b/source/backend/cpu/arm/arm32/MNNPackedSparseQuantMatMulEpx4.S @@ -96,7 +96,7 @@ loop_e8: ldr r5, [sp, #(push_registers_bytes + 4)] ldr r6, [sp, #(push_registers_bytes + 8)] ldr r7, [r4] - ldr r8, [r4, #4] + ldr r8, [r4, #44] push {r0-r2, r10} ldr r10, [r3, #20] // cStride ldr lr, [r6], #4 // dataOffset @@ -175,7 +175,7 @@ loop_e4: ldr r5, [sp, #(push_registers_bytes + 4)] ldr r6, [sp, #(push_registers_bytes + 8)] ldr r7, [r4] - ldr r8, [r4, #4] + ldr r8, [r4, #44] push {r0-r2, r10} ldr r10, [r3, #20] // cStride ldr lr, [r6], #4 // dataOffset @@ -233,7 +233,7 @@ loop_e2: ldr r5, [sp, #(push_registers_bytes + 4)] ldr r6, [sp, #(push_registers_bytes + 8)] ldr r7, [r4] - ldr r8, [r4, #4] + ldr r8, [r4, #44] push {r0-r2, r10} ldr r10, [r3, #20] // cStride ldr lr, [r6], #4 // dataOffset @@ -294,7 +294,7 @@ loop_e1: ldr r5, [sp, #(push_registers_bytes + 4)] ldr r6, [sp, #(push_registers_bytes + 8)] ldr r7, [r4] - ldr r8, [r4, #4] + ldr r8, [r4, #44] push {r0-r2, r10} ldr r10, [r3, #20] // cStride ldr lr, [r6], #4 // dataOffset diff --git a/source/backend/cpu/arm/arm32/bf16/MNNPackC4ForMatMul_A_BF16.S b/source/backend/cpu/arm/arm32/bf16/MNNPackC4ForMatMul_A_BF16.S index 663ffae68..54744568e 100644 --- a/source/backend/cpu/arm/arm32/bf16/MNNPackC4ForMatMul_A_BF16.S +++ b/source/backend/cpu/arm/arm32/bf16/MNNPackC4ForMatMul_A_BF16.S @@ -25,7 +25,7 @@ ldr r6, [r2, #12] // xOffset // eReal -> eReal * 4 * sizeof(float) // eDest -> eDest * sizeof(float) mov r12, #2 // sizeof(int16_t) -mov lr, #8 // sizeof(int16_t) * 4 +mov lr, #16 // sizeof(float) * 4 mul r4, lr, r4 mul r11, r12, r11 mul r6, lr, r6 @@ -39,7 +39,7 @@ push {r0, r1} ldr r1, [r1, #0] // Compute dest ptr: r0 = r0 + eOffset * sizeof(float) + lOffset * eDest * sizeof(float) -; mov lr, #2 //sizeof(int16_t) + mul r7, r11, r7 mul r8, r12, r8 add r0, r0, r7 @@ -55,18 +55,36 @@ bne Right LoopL4: mov r2, r1 .macro MAIN_TRANSPOSE - vld1.16 {d16}, [r1], r6 // load size: 4 * sizeof(int16_t) - vld1.16 {d19}, [r1], r6 - vld1.16 {d22}, [r1], r6 - vld1.16 {d25}, [r1], r6 - vld1.16 {d17}, [r1], r6 - vld1.16 {d20}, [r1], r6 - vld1.16 {d23}, [r1], r6 - vld1.16 {d26}, [r1], r6 - vld1.16 {d18}, [r1], r6 - vld1.16 
{d21}, [r1], r6 - vld1.16 {d24}, [r1], r6 - vld1.16 {d27}, [r1], r6 + + vld1.32 {q0}, [r1], r6 // load size: 4 * sizeof(float) + vld1.32 {q1}, [r1], r6 + vld1.32 {q2}, [r1], r6 + vld1.32 {q3}, [r1], r6 + + vshrn.i32 d16, q0, #16 + vshrn.i32 d19, q1, #16 + vshrn.i32 d22, q2, #16 + vshrn.i32 d25, q3, #16 + + vld1.32 {q0}, [r1], r6 // load size: 4 * sizeof(float) + vld1.32 {q1}, [r1], r6 + vld1.32 {q2}, [r1], r6 + vld1.32 {q3}, [r1], r6 + + vshrn.i32 d17, q0, #16 + vshrn.i32 d20, q1, #16 + vshrn.i32 d23, q2, #16 + vshrn.i32 d26, q3, #16 + + vld1.32 {q0}, [r1], r6 // load size: 4 * sizeof(float) + vld1.32 {q1}, [r1], r6 + vld1.32 {q2}, [r1], r6 + vld1.32 {q3}, [r1], r6 + + vshrn.i32 d18, q0, #16 + vshrn.i32 d21, q1, #16 + vshrn.i32 d24, q2, #16 + vshrn.i32 d27, q3, #16 // transpose each 4 16-bit elements in 2 d_n vectors, by transpose 16-bit and scale up transpose 32-bit. vtrn.16 d16, d19 @@ -145,7 +163,9 @@ LoopE1: cmp r5, #4 blt LoopE1L3 LoopE1L4: - vld1.16 {d0}, [r1], r4 + vld1.32 {q0}, [r1], r4 + vshrn.i32 d0, q0, #16 + vst1.16 {d0[0]}, [r0], r11 vst1.16 {d0[1]}, [r0], r11 vst1.16 {d0[2]}, [r0], r11 @@ -157,7 +177,9 @@ LoopE1: LoopE1L3: cmp r5, #3 blt LoopE1L2 - vld1.16 {d0}, [r1], r4 + vld1.32 {q0}, [r1], r4 + vshrn.i32 d0, q0, #16 + vst1.16 {d0[0]}, [r0], r11 vst1.16 {d0[1]}, [r0], r11 vst1.16 {d0[2]}, [r0], r11 @@ -167,7 +189,9 @@ LoopE1: LoopE1L2: cmp r5, #2 blt LoopE1L1 - vld1.16 {d0}, [r1], r4 + vld1.32 {q0}, [r1], r4 + vshrn.i32 d0, q0, #16 + vst1.16 {d0[0]}, [r0], r11 vst1.16 {d0[1]}, [r0], r11 sub r5, r5, #2 @@ -175,7 +199,8 @@ LoopE1: LoopE1L1: cmp r5, #1 blt LoopE1End - vld1.16 {d0[0]}, [r1], r4 + vld1.32 {d0}, [r1], r4 + vshrn.i32 d0, q0, #16 vst1.16 {d0[0]}, [r0], r11 LoopE1End: diff --git a/source/backend/cpu/arm/arm32/bf16/MNNPackC4_BF16.S b/source/backend/cpu/arm/arm32/bf16/MNNPackC4_BF16.S index 70b9e61e4..844f32d48 100644 --- a/source/backend/cpu/arm/arm32/bf16/MNNPackC4_BF16.S +++ b/source/backend/cpu/arm/arm32/bf16/MNNPackC4_BF16.S @@ -39,8 +39,8 @@ mul r4, r2, r3 cmp r4, #0 beq UpEnd -//r4: srcDepthOffset:srcArea*sizeof(int16_t) -mov r4, #2 +//r4: srcDepthOffset:srcArea*sizeof(float) +mov r4, #4 mul r4, lr, r4 //r10 -> 4 * (dstArea * sizeof(int16_t) - area * sizeof(int16_t)) @@ -48,8 +48,8 @@ mov r12, #8 sub r10, r10, r2 mul r10, r12, r10 -//lr -> (srcArea * sizeof(int16_t) - area * sizeof(int16_t)) -mov r12, #2 +//lr -> (srcArea * sizeof(float) - area * sizeof(float)) +mov r12, #4 sub lr, lr, r2 mul lr, r12, lr @@ -65,10 +65,15 @@ mov r8, r2 cmp r8, #3 ble UpL4AreaRemain UpL4AreaLoop: -vld1.16 {d0}, [r1]! // load 4 elements of 16-bit into 64bit vector register d0 -vld1.16 {d1}, [r5]! -vld1.16 {d2}, [r6]! -vld1.16 {d3}, [r7]! +vld1.32 {q0}, [r1]! // load 4 elements of 16-bit into 64bit vector register d0 +vld1.32 {q1}, [r5]! +vld1.32 {q2}, [r6]! +vld1.32 {q3}, [r7]! +vshrn.i32 d0, q0, #16 +vshrn.i32 d1, q1, #16 +vshrn.i32 d2, q2, #16 +vshrn.i32 d3, q3, #16 + // transpose // no suitable instruction to transpose int16_t type vst4.16 {d0, d1, d2, d3}, [r0]! sub r8, r8, #4 @@ -79,10 +84,11 @@ UpL4AreaRemain: cmp r8, #0 beq UpL4AreaRemainEnd UpL4AreaRemainLoop: -vld1.16 {d0[0]}, [r1]! -vld1.16 {d0[1]}, [r5]! -vld1.16 {d0[2]}, [r6]! -vld1.16 {d0[3]}, [r7]! +vld1.32 {d0[0]}, [r1]! +vld1.32 {d0[1]}, [r5]! +vld1.32 {d1[0]}, [r6]! +vld1.32 {d1[1]}, [r7]! +vshrn.i32 d0, q0, #16 vst1.16 {d0}, [r0]! @@ -104,10 +110,14 @@ mov r8, r2 cmp r8, #3 ble UpL3AreaRemain UpL3AreaLoop: -vld1.16 {d0}, [r1]! +vld1.32 {q0}, [r1]! +vld1.32 {q1}, [r5]! +vld1.32 {q2}, [r6]! 
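[Editor's note] The bf16 pack routines in these hunks now take float32 input and narrow it on the fly: vshrn.i32 #16 keeps the upper 16 bits of each 32-bit lane, which is the plain truncating float32 -> bfloat16 conversion (no rounding). A minimal scalar equivalent for reference:

    #include <cstdint>
    #include <cstring>
    // Truncating float32 -> bfloat16; per lane, this is what vshrn.i32 #16 does in these routines.
    static inline uint16_t fp32ToBf16Truncate(float x) {
        uint32_t bits;
        std::memcpy(&bits, &x, sizeof(bits));     // reinterpret the float's bit pattern
        return static_cast<uint16_t>(bits >> 16); // keep sign, exponent, top 7 mantissa bits
    }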
+vshrn.i32 d0, q0, #16 +vshrn.i32 d1, q1, #16 +vshrn.i32 d2, q2, #16 vmov.i16 d3, #0 -vld1.16 {d1}, [r5]! -vld1.16 {d2}, [r6]! + // transpose // no suitable instruction to transpose int16_t type vst4.16 {d0, d1, d2, d3}, [r0]! sub r8, r8, #4 @@ -117,10 +127,11 @@ bge UpL3AreaLoop cmp r8, #0 beq UpL3AreaRemainEnd UpL3AreaRemain: -vmov.i16 d0, #0 -vld1.16 {d0[0]}, [r1]! -vld1.16 {d0[1]}, [r5]! -vld1.16 {d0[2]}, [r6]! +vmov.i32 q0, #0 +vld1.32 {d0[0]}, [r1]! +vld1.32 {d0[1]}, [r5]! +vld1.32 {d1[0]}, [r6]! +vshrn.i32 d0, q0, #16 vst1.16 {d0}, [r0]! @@ -139,11 +150,13 @@ mov r8, r2 cmp r8, #3 ble UpL2AreaRemain UpL2AreaLoop: -vld1.16 {d0}, [r1]! +vld1.32 {q0}, [r1]! +vld1.32 {q1}, [r5]! +vshrn.i32 d0, q0, #16 +vshrn.i32 d1, q1, #16 + vmov.i16 d3, #0 -vld1.16 {d1}, [r5]! vmov.i16 d2, #0 -// transpose // no suitable instruction to transpose int16_t type vst4.16 {d0, d1, d2, d3}, [r0]! sub r8, r8, #4 cmp r8, #4 @@ -152,9 +165,11 @@ bge UpL2AreaLoop cmp r8, #0 beq UpL2AreaRemainEnd UpL2AreaRemain: -vmov.i16 d0, #0 -vld1.16 {d0[0]}, [r1]! -vld1.16 {d0[1]}, [r5]! +vmov.i32 q0, #0 +vld1.32 {d0[0]}, [r1]! +vld1.32 {d0[1]}, [r5]! + +vshrn.i32 d0, q0, #16 vst1.16 {d0}, [r0]! @@ -171,7 +186,8 @@ mov r8, r2 cmp r8, #3 ble UpL1AreaRemain UpL1AreaLoop: -vld1.16 {d0}, [r1]! +vld1.32 {q0}, [r1]! +vshrn.i32 d0, q0, #16 vmov.i16 d3, #0 vmov.i16 d1, #0 vmov.i16 d2, #0 @@ -184,8 +200,9 @@ bge UpL1AreaLoop cmp r8, #0 beq UpL1AreaRemainEnd UpL1AreaRemain: -vmov.i16 d0, #0 -vld1.16 {d0[0]}, [r1]! +vmov.i16 q0, #0 +vld1.32 {d0[0]}, [r1]! +vshrn.i32 d0, q0, #16 vst1.16 {d0}, [r0]! diff --git a/source/backend/cpu/arm/arm32/bf16/MNNPackedMatMulRemain_BF16.S b/source/backend/cpu/arm/arm32/bf16/MNNPackedMatMulRemain_BF16.S index 252f1956a..ea64bd0fd 100644 --- a/source/backend/cpu/arm/arm32/bf16/MNNPackedMatMulRemain_BF16.S +++ b/source/backend/cpu/arm/arm32/bf16/MNNPackedMatMulRemain_BF16.S @@ -25,6 +25,8 @@ ldr r4, [sp, #32] ldr r6, [sp, #36] ldr r7, [sp, #40] ldr r12, [r4, #0] +// aStride is compute as float, divided 2 to as bf16 +lsr r12, r12, #1 cmp r6, #0 beq Start vld1.32 {q3}, [r6] @@ -61,8 +63,8 @@ LoopE4: bne LoopE4L cmp r6, #0 beq StoreE4 - vld1.16 {d28}, [r7]! // load 4 * sizeof(int16_t) - vshll.s16 q14, d28, #16 // shift left long of each int16_t as float32 + vld1.32 {q14}, [r7]! // load 4 * sizeof(float) + vmla.f32 q8, q14, d6[1] vmla.f32 q9, q14, d6[1] vmla.f32 q10, q14, d6[1] @@ -81,20 +83,17 @@ LoopE4: StoreE4: ldr r8, [r4, #20] + lsr r8, r8, #1 // bExtraStride is compute as float, divide to bf16 add r11, r11, r8 ldr r8, [r4, #12] - vshrn.i32 d16, q8, #16 // shift right 16bit of each float32 as int16_t - vshrn.i32 d17, q9, #16 - vshrn.i32 d18, q10, #16 - vshrn.i32 d19, q11, #16 - vst1.16 {d16, d17}, [lr]! - vst1.16 {d18, d19}, [lr], r8 - sub lr, lr, #16 + vst1.32 {q8, q9}, [lr]! + vst1.32 {q10, q11}, [lr], r8 + sub lr, lr, #32 // revert to next C4 begin subs r5, r5, #1 // move 4 colum along lP dim. lP = l / 4 bne LoopE4H sub r3, r3, #4 // move 4 colum along e dim. - add r0, r0, #32 // move address of 4 * 4 * sizeof(int16_t) + add r0, r0, #64 // move address of 4 * 4 * sizeof(float) add r1, r1, #8 // move address of 4 * sizeof(int16_t) in src tile block cmp r3, #4 pop {r7} @@ -125,8 +124,7 @@ LoopE1: bne LoopE1L cmp r6, #0 beq StoreE1 - vld1.16 {d28}, [r7]! // load 4 * sizeof(int16_t) - vshll.s16 q14, d28, #16 // shift left long of each int16_t as float32 + vld1.32 {q14}, [r7]! 
// load 4 * sizeof(float) vmla.f32 q15, q14, d6[1] PostTreatE1: @@ -135,15 +133,15 @@ LoopE1: StoreE1: ldr r8, [r4, #20] + lsr r8, r8, #1 add r11, r11, r8 ldr r8, [r4, #12] - vshrn.i32 d30, q15, #16 // shift right 16bit of each float32 as int16_t - vst1.16 {d30}, [lr], r8 + vst1.16 {q15}, [lr], r8 subs r5, r5, #1 bne LoopE1H subs r3, r3, #1 - add r0, r0, #8 // move address of 4 * sizeof(int16_t) + add r0, r0, #16 // move address of 4 * sizeof(float) add r1, r1, #2 // move address of 1 * sizeof(int16_t) pop {r7} bne LoopE1 diff --git a/source/backend/cpu/arm/arm32/bf16/MNNPackedMatMul_BF16.S b/source/backend/cpu/arm/arm32/bf16/MNNPackedMatMul_BF16.S index 3b9ab3d48..22719baf4 100644 --- a/source/backend/cpu/arm/arm32/bf16/MNNPackedMatMul_BF16.S +++ b/source/backend/cpu/arm/arm32/bf16/MNNPackedMatMul_BF16.S @@ -30,8 +30,9 @@ add r4, r4, #3 ldr r8, [r3, #12]//cStride ldr r3, [r3, #20]//bExtraStride lsr r4, r4, #2 +lsr r3, r3, #1 //bExtraStride is compute as fp32, turn to bf16 -sub r8, r8, #96 // after segment "Store", total line stride is CStride, all vst. offset is 12 * 4 * size_t(int16_t) = 96byte +sub r8, r8, #192 // after segment "Store", total line stride is CStride, all vst. offset is 12 * 4 * size_t(float) = 192byte vpush {q4-q7} // q0, q1, q2: src @@ -95,10 +96,8 @@ LoopH: cmp r5, #0 beq Store vld1.32 {q0}, [r5] // parameter remains float - cmp r6, #0 - beq LoadOrigin - vld1.16 {d6}, [r6]! // load 4 * sizeof(int16_t) - vshll.s16 q3, d6, #16 // shift left long of each int16_t as int32_t + vld1.32 {q3}, [r6]! // load 4 * sizeof(float) + vmla.f32 q4, q3, d0[1] vmla.f32 q5, q3, d0[1] vmla.f32 q6, q3, d0[1] @@ -114,44 +113,6 @@ LoopH: b PostTreat - LoadOrigin: - mov r11, r0 - vld1.16 {d2, d3}, [r11]! // load 2 * 4 * sizeof(int16_t) - vshll.s16 q2, d3, #16 // shift left long of each int16_t as int32_t - vshll.s16 q1, d2, #16 - vmla.f32 q4, q1, d0[1] - vmla.f32 q5, q2, d0[1] - - vld1.16 {d2, d3}, [r11]! // load 2 * 4 * sizeof(int16_t) - vshll.s16 q2, d3, #16 // shift left long of each int16_t as int32_t - vshll.s16 q1, d2, #16 - vmla.f32 q6, q1, d0[1] - vmla.f32 q7, q2, d0[1] - - vld1.16 {d2, d3}, [r11]! // load 2 * 4 * sizeof(int16_t) - vshll.s16 q2, d3, #16 // shift left long of each int16_t as int32_t - vshll.s16 q1, d2, #16 - vmla.f32 q8, q1, d0[1] - vmla.f32 q9, q2, d0[1] - - vld1.16 {d2, d3}, [r11]! // load 2 * 4 * sizeof(int16_t) - vshll.s16 q2, d3, #16 // shift left long of each int16_t as int32_t - vshll.s16 q1, d2, #16 - vmla.f32 q10, q1, d0[1] - vmla.f32 q11, q2, d0[1] - - vld1.16 {d2, d3}, [r11]! // load 2 * 4 * sizeof(int16_t) - vshll.s16 q2, d3, #16 // shift left long of each int16_t as int32_t - vshll.s16 q1, d2, #16 - vmla.f32 q12, q1, d0[1] - vmla.f32 q13, q2, d0[1] - - vld1.16 {d2, d3}, [r11]! // load 2 * 4 * sizeof(int16_t) - vshll.s16 q2, d3, #16 // shift left long of each int16_t as int32_t - vshll.s16 q1, d2, #16 - vmla.f32 q14, q1, d0[1] - vmla.f32 q15, q2, d0[1] - PostTreat: vdup.f32 q2, d1[0] // min vdup.f32 q1, d1[1] // max @@ -183,20 +144,13 @@ LoopH: vmin.f32 q15, q15, q1 Store: - vshrn.i32 d8, q4, #16 // !!caution: these instructions has relying, eg: d10 must be written after reading q5. 
shift right 16bit of each float32 as int16_t - vshrn.i32 d9, q5, #16 - vshrn.i32 d10, q6, #16 - vshrn.i32 d11, q7, #16 - vshrn.i32 d12, q8, #16 - vshrn.i32 d13, q9, #16 - vshrn.i32 d14, q10, #16 - vshrn.i32 d15, q11, #16 - vshrn.i32 d16, q12, #16 - vshrn.i32 d17, q13, #16 - vshrn.i32 d18, q14, #16 - vshrn.i32 d19, q15, #16 - - vstm r0!, {d8, d9, d10, d11, d12, d13, d14, d15, d16, d17, d18, d19} + + vst1.32 {q4, q5}, [r0]! + vst1.32 {q6, q7}, [r0]! + vst1.32 {q8, q9}, [r0]! + vst1.32 {q10, q11}, [r0]! + vst1.32 {q12, q13}, [r0]! + vst1.32 {q14, q15}, [r0]! add r0, r0, r8 add r2, r2, r3 diff --git a/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_16x4_Unit.S b/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_16x4_Unit.S index 5621142ff..d31d57ad7 100644 --- a/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_16x4_Unit.S +++ b/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_16x4_Unit.S @@ -13,15 +13,81 @@ .text .align 5 +.macro MLA_WEIGHTZERO d0, s0, s1, idx // idx for xKernelSum + fmla \d0\().4s, \s1\().4s, \s0\().s[\idx] +.endm +.macro ReLU_FP32_4 s0, s1, s2, s3, z0, z1 // z0:min z1:max + fmin \s0\().4s, \s0\().4s, \z1\().4s + fmin \s1\().4s, \s1\().4s, \z1\().4s + fmin \s2\().4s, \s2\().4s, \z1\().4s + fmin \s3\().4s, \s3\().4s, \z1\().4s + fmax \s0\().4s, \s0\().4s, \z0\().4s + fmax \s1\().4s, \s1\().4s, \z0\().4s + fmax \s2\().4s, \s2\().4s, \z0\().4s + fmax \s3\().4s, \s3\().4s, \z0\().4s +.endm +.macro ReLU_FP32_3 s0, s1, s2, z0, z1 // z0:min z1:max + fmin \s0\().4s, \s0\().4s, \z1\().4s + fmin \s1\().4s, \s1\().4s, \z1\().4s + fmin \s2\().4s, \s2\().4s, \z1\().4s + fmax \s0\().4s, \s0\().4s, \z0\().4s + fmax \s1\().4s, \s1\().4s, \z0\().4s + fmax \s2\().4s, \s2\().4s, \z0\().4s +.endm +.macro ReLU_FP32_2 s0, s1, z0, z1 // z0:min z1:max + fmin \s0\().4s, \s0\().4s, \z1\().4s + fmin \s1\().4s, \s1\().4s, \z1\().4s + fmax \s0\().4s, \s0\().4s, \z0\().4s + fmax \s1\().4s, \s1\().4s, \z0\().4s +.endm +.macro ReLU_FP32_1 s0, z0, z1 // z0:min z1:max + fmin \s0\().4s, \s0\().4s, \z1\().4s + fmax \s0\().4s, \s0\().4s, \z0\().4s +.endm +.macro MUL_SCALE4 s, d0, d1, d2, d3 + fmul \d0\().4s, \d0\().4s, \s\().4s + fmul \d1\().4s, \d1\().4s, \s\().4s + fmul \d2\().4s, \d2\().4s, \s\().4s + fmul \d3\().4s, \d3\().4s, \s\().4s +.endm +.macro MUL_SCALE3 s, d0, d1, d2 + fmul \d0\().4s, \d0\().4s, \s\().4s + fmul \d1\().4s, \d1\().4s, \s\().4s + fmul \d2\().4s, \d2\().4s, \s\().4s +.endm +.macro MUL_SCALE2 s, d0, d1 + fmul \d0\().4s, \d0\().4s, \s\().4s + fmul \d1\().4s, \d1\().4s, \s\().4s +.endm +.macro MUL_SCALE1 s, d0 + fmul \d0\().4s, \d0\().4s, \s\().4s +.endm +.macro MUL_EXTRA_SCALE s, d0, d1, d2, d3 + fmul \d0\().4s, \d0\().4s, \s\().s[0] + fmul \d1\().4s, \d1\().4s, \s\().s[1] + fmul \d2\().4s, \d2\().4s, \s\().s[2] + fmul \d3\().4s, \d3\().4s, \s\().s[3] +.endm + asm_function MNNGemmInt8AddBiasScale_16x4_Unit -//struct QuanPostTreatParameters { -// const float* scale; -// const int32_t* bias; -// int32_t maxValue; -// int32_t minValue; -// int32_t useInt8; -//}; +/* +struct QuanPostTreatParameters { + const float* scale; + const float* biasFloat; + int32_t maxValue; + int32_t minValue; + int32_t useInt8 = 1; // Save result as int8_t dataType; otherwise float32. 
+ float roundValuePos = 0.5f; + float roundValueNeg = -0.5f; + float* srcKernelSum; + float* weightQuanBias; + float* fp32minmax; + ssize_t blockNum; + const int32_t* bias; + float* extraScale; +}; +*/ //void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, // size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realSize) { @@ -30,7 +96,7 @@ asm_function MNNGemmInt8AddBiasScale_16x4_Unit // x5: dst_depth_quad, x6: post, x7: realSize //Load from post: -// x7: scale, x10: bias, w11: maxValue, w6: minValue, w13: UseInt8 +// x7: scale, x10: bias, w11: maxValue, w6: minValue, w13: UseInt8, x14: srcKernelSum, x12: weightQuantBias mov x8, x7 mov x15, x6 ldr x7, [x15, #0] @@ -38,11 +104,23 @@ ldr x10, [x15, #8] ldr w11, [x15, #16] ldr w6, [x15, #20] ldr w13, [x15, #24] - -stp d14, d15, [sp, #-64]! -stp d12, d13, [sp, #16] -stp d10, d11, [sp, #32] -stp d8, d9, [sp, #48] +ldr x14, [x15, #40] // srcKernelSum +ldr x12, [x15, #48] // weightQuantBias + +stp d14, d15, [sp, #(-16 * 8)]! +stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] +stp x19, x20, [sp, #(16 * 4)] +stp x21, x22, [sp, #(16 * 5)] +stp x23, x24, [sp, #(16 * 6)] + +ldr x19, [x15, #56] // fp32 min max +ldr x21, [x15, #64] // blockNum +ldr x23, [x15, #80] // extraScale +mul x21, x21, x3 // blockNum * src_depth_quad_perblock +lsl x21, x21, #6 // src_depth_quad* SRC_UNIT * UNIT * sizeof(int8_t) +add x20, x19, #4 Start: cmp x8, #3 @@ -56,9 +134,10 @@ beq L1Dz cmp w13, #1 bne L4LoopDz -//sub x4, x4, #8 // post->scale != nullptr && post->useInt8 == 1. + L4LoopDz: mov x8, x1 + mov x22, x2 ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x2], #64 ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64 @@ -183,7 +262,6 @@ L4LoopDz: ComputeSum: - ld1 {v0.4s}, [x10], #16 addp v4.4s, v16.4s, v17.4s addp v5.4s, v18.4s, v19.4s addp v6.4s, v20.4s, v21.4s @@ -199,35 +277,69 @@ L4LoopDz: addp v15.4s, v10.4s, v11.4s L4Quan: - ld1 {v1.4s}, [x7], #16 - add v16.4s, v12.4s, v0.4s - add v17.4s, v13.4s, v0.4s - add v18.4s, v14.4s, v0.4s - add v19.4s, v15.4s, v0.4s + ld1 {v1.4s}, [x7], #16 // scalefuse + ld1 {v20.4s}, [x14] // srcKernelSum + ld1 {v21.4s}, [x12], #16 // weightQuanZero - dup v31.16b, w6 // Min - dup v30.16b, w11 // Max + scvtf v4.4s, v12.4s + scvtf v5.4s, v13.4s + scvtf v6.4s, v14.4s + scvtf v7.4s, v15.4s - scvtf v4.4s, v16.4s - scvtf v5.4s, v17.4s - scvtf v6.4s, v18.4s - scvtf v7.4s, v19.4s + cbz x23, TILE4_MUL_OHE_SCALE + ld1 {v2.4s}, [x23] + MUL_EXTRA_SCALE v2, v4, v5, v6, v7 + + TILE4_MUL_OHE_SCALE: + MUL_SCALE4 v1, v4, v5, v6, v7 + + MLA_WEIGHTZERO v4, v20, v21, 0 + MLA_WEIGHTZERO v5, v20, v21, 1 + MLA_WEIGHTZERO v6, v20, v21, 2 + MLA_WEIGHTZERO v7, v20, v21, 3 - fmul v12.4s, v4.4s, v1.4s - fmul v13.4s, v5.4s, v1.4s - fmul v14.4s, v6.4s, v1.4s - fmul v15.4s, v7.4s, v1.4s cmp w13, #1 beq L4QuantUseInt8 - st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], x4 + + L4_Add_BIAS: + cbz x10, L4_ADD_DSTV + ld1 {v0.4s}, [x10], #16 + fadd v4.4s, v4.4s, v0.4s + fadd v5.4s, v5.4s, v0.4s + fadd v6.4s, v6.4s, v0.4s + fadd v7.4s, v7.4s, v0.4s + b L4_POST + + L4_ADD_DSTV: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0] + fadd v4.4s, v4.4s, v8.4s + fadd v5.4s, v5.4s, v9.4s + fadd v6.4s, v6.4s, v10.4s + fadd v7.4s, v7.4s, v11.4s + + L4_POST: + cbz x19, L4_STORE + ld1r {v26.4s}, [x19] // f32 min + ld1r {v27.4s}, [x20] // f32 max + ReLU_FP32_4 v4, v5, v6, v7, v26, v27 + + L4_STORE: + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], x4 b L4LoopCheck L4QuantUseInt8: + ld1 
{v0.4s}, [x10], #16 + fadd v4.4s, v4.4s, v0.4s + fadd v5.4s, v5.4s, v0.4s + fadd v6.4s, v6.4s, v0.4s + fadd v7.4s, v7.4s, v0.4s - fcvtas v8.4s, v12.4s - fcvtas v9.4s, v13.4s - fcvtas v10.4s, v14.4s - fcvtas v11.4s, v15.4s + dup v31.16b, w6 // Min + dup v30.16b, w11 // Max + fcvtas v8.4s, v4.4s + fcvtas v9.4s, v5.4s + fcvtas v10.4s, v6.4s + fcvtas v11.4s, v7.4s sqxtn v0.4h, v8.4s sqxtn2 v0.8h, v9.4s @@ -243,6 +355,7 @@ L4LoopDz: L4LoopCheck: subs x5, x5, #1 mov x1, x8 + add x2, x22, x21 bne L4LoopDz b End @@ -253,10 +366,11 @@ bne L3LoopDz sub x4, x4, #8 L3LoopDz: mov x8, x1 + mov x22, x2 ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x2], #64 ld1 {v4.16b, v5.16b, v6.16b}, [x1], #48 add x1, x1, #16 - + smull v8.8h, v0.8b, v4.8b smull v9.8h, v1.8b, v4.8b smull v10.8h, v2.8b, v4.8b @@ -347,10 +461,6 @@ L3LoopDz: smlal2 v9.8h, v1.16b, v6.16b smlal2 v10.8h, v2.16b, v6.16b smlal2 v11.8h, v3.16b, v6.16b - smlal2 v8.8h, v0.16b, v6.16b - smlal2 v9.8h, v1.16b, v6.16b - smlal2 v10.8h, v2.16b, v6.16b - smlal2 v11.8h, v3.16b, v6.16b sadalp v24.4s, v8.8h sadalp v25.4s, v9.8h @@ -360,7 +470,6 @@ L3LoopDz: bne L3LoopSz L3ComputeSum: - ld1 {v0.4s}, [x10], #16 addp v4.4s, v16.4s, v17.4s addp v5.4s, v18.4s, v19.4s addp v6.4s, v20.4s, v21.4s @@ -374,29 +483,65 @@ L3LoopDz: L3Quan: ld1 {v1.4s}, [x7], #16 - add v16.4s, v12.4s, v0.4s - add v17.4s, v13.4s, v0.4s - add v18.4s, v14.4s, v0.4s - - dup v31.16b, w6 // Min - dup v30.16b, w11 // Max - - scvtf v4.4s, v16.4s - scvtf v5.4s, v17.4s - scvtf v6.4s, v18.4s + ld1 {v20.d}[0], [x14], #8 // srcKernelSum + ld1 {v20.s}[2], [x14] + ld1 {v21.4s}, [x12], #16 // weightQuanZero + + scvtf v4.4s, v12.4s + scvtf v5.4s, v13.4s + scvtf v6.4s, v14.4s + MUL_SCALE3 v1, v4, v5, v6 + + cbz x23, TILE3_MUL_OHE_SCALE + ld1 {v2.d}[0], [x23], #8 + ld1 {v2.s}[2], [x23] + fmul v4.4s, v4.4s, v2.s[0] + fmul v5.4s, v5.4s, v2.s[1] + fmul v6.4s, v6.4s, v2.s[2] + sub x23, x23, #8 + + TILE3_MUL_OHE_SCALE: + sub x14, x14, #8 + MLA_WEIGHTZERO v4, v20, v21, 0 + MLA_WEIGHTZERO v5, v20, v21, 1 + MLA_WEIGHTZERO v6, v20, v21, 2 - fmul v12.4s, v4.4s, v1.4s - fmul v13.4s, v5.4s, v1.4s - fmul v14.4s, v6.4s, v1.4s cmp w13, #1 beq L3QuantUseInt8 - st1 {v12.4s, v13.4s, v14.4s}, [x0], x4 + + L3_ADD_BIAS: + cbz x10, L3_ADD_DSTV + ld1 {v0.4s}, [x10], #16 + fadd v4.4s, v4.4s, v0.4s + fadd v5.4s, v5.4s, v0.4s + fadd v6.4s, v6.4s, v0.4s + b L3_POST + + L3_ADD_DSTV: + ld1 {v0.4s, v1.4s, v2.4s}, [x0] + fadd v4.4s, v4.4s, v0.4s + fadd v5.4s, v5.4s, v1.4s + fadd v6.4s, v6.4s, v2.4s + + L3_POST: + cbz x19, L3_STORE + ld1r {v26.4s}, [x19] // f32 min + ld1r {v27.4s}, [x20] // f32 max + ReLU_FP32_3 v4, v5, v6, v26, v27 + L3_STORE: + st1 {v4.4s, v5.4s, v6.4s}, [x0], x4 b L3LoopCheck L3QuantUseInt8: - fcvtas v8.4s, v12.4s - fcvtas v9.4s, v13.4s - fcvtas v10.4s, v14.4s + ld1 {v0.4s}, [x10], #16 + fadd v4.4s, v4.4s, v0.4s + fadd v5.4s, v5.4s, v0.4s + fadd v6.4s, v6.4s, v0.4s + dup v31.16b, w6 // Min + dup v30.16b, w11 // Max + fcvtas v8.4s, v4.4s + fcvtas v9.4s, v5.4s + fcvtas v10.4s, v6.4s sqxtn v0.4h, v8.4s sqxtn2 v0.8h, v9.4s @@ -417,6 +562,7 @@ L3LoopDz: L3LoopCheck: subs x5, x5, #1 mov x1, x8 + add x2, x22, x21 bne L3LoopDz b End @@ -424,6 +570,7 @@ b End L2Dz: L2LoopDz: mov x8, x1 + mov x22, x2 ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x2], #64 ld1 {v4.16b, v5.16b}, [x1], #32 @@ -496,7 +643,6 @@ L2LoopDz: L2ComputeSum: - ld1 {v0.4s}, [x10], #16 addp v4.4s, v16.4s, v17.4s addp v5.4s, v18.4s, v19.4s addp v6.4s, v20.4s, v21.4s @@ -507,25 +653,55 @@ L2LoopDz: L2Quan: ld1 {v1.4s}, [x7], #16 - add v16.4s, v12.4s, v0.4s - add v17.4s, 
v13.4s, v0.4s + ld1 {v20.d}[0], [x14] // srcKernelSum + ld1 {v21.4s}, [x12], #16 // weightQuanZero - dup v31.8b, w6 // Min - dup v30.8b, w11 // Max + scvtf v4.4s, v12.4s + scvtf v5.4s, v13.4s + MUL_SCALE2 v1, v4, v5 - scvtf v4.4s, v16.4s - scvtf v5.4s, v17.4s + cbz x23, TILE2_MUL_OHE_SCALE + ld1 {v2.d}[0], [x23] + fmul v4.4s, v4.4s, v2.s[0] + fmul v5.4s, v5.4s, v2.s[1] + + TILE2_MUL_OHE_SCALE: + MLA_WEIGHTZERO v4, v20, v21, 0 + MLA_WEIGHTZERO v5, v20, v21, 1 - fmul v12.4s, v4.4s, v1.4s - fmul v13.4s, v5.4s, v1.4s cmp w13, #1 beq L2QuantUseInt8 - st1 {v12.4s, v13.4s}, [x0], x4 + + L2_ADD_BIAS: + cbz x10, L2_ADD_DSTV + ld1 {v0.4s}, [x10], #16 + fadd v4.4s, v4.4s, v0.4s + fadd v5.4s, v5.4s, v0.4s + b L2_POST + + L2_ADD_DSTV: + ld1 {v0.4s, v1.4s}, [x0] + fadd v4.4s, v4.4s, v0.4s + fadd v5.4s, v5.4s, v1.4s + + L2_POST: + cbz x19, L2_STORE + ld1r {v26.4s}, [x19] // f32 min + ld1r {v27.4s}, [x20] // f32 max + ReLU_FP32_2 v4, v5, v26, v27 + + L2_STORE: + st1 {v4.4s, v5.4s}, [x0], x4 b L2LoopCheck L2QuantUseInt8: - fcvtas v8.4s, v12.4s - fcvtas v9.4s, v13.4s + ld1 {v0.4s}, [x10], #16 + fadd v4.4s, v4.4s, v0.4s + fadd v5.4s, v5.4s, v0.4s + dup v31.8b, w6 // Min + dup v30.8b, w11 // Max + fcvtas v8.4s, v4.4s + fcvtas v9.4s, v5.4s sqxtn v0.4h, v8.4s sqxtn2 v0.8h, v9.4s @@ -540,6 +716,7 @@ L2LoopDz: L2LoopCheck: subs x5, x5, #1 mov x1, x8 + add x2, x22, x21 bne L2LoopDz b End @@ -547,6 +724,7 @@ b End L1Dz: L1LoopDz: mov x8, x1 + mov x22, x2 ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x2], #64 dup v16.4s, wzr dup v17.4s, wzr @@ -599,7 +777,7 @@ L1LoopDz: sadalp v18.4s, v10.8h sadalp v19.4s, v11.8h - ld1 {v0.4s}, [x10], #16 + //ld1 {v0.4s}, [x10], #16 addp v4.4s, v16.4s, v17.4s addp v5.4s, v18.4s, v19.4s @@ -607,22 +785,49 @@ L1LoopDz: L1Quan: ld1 {v1.4s}, [x7], #16 - add v16.4s, v12.4s, v0.4s + ld1 {v20.s}[0], [x14] // srcKernelSum + ld1 {v21.4s}, [x12], #16 // weightQuanZero - dup v31.4s, w6 // Min - dup v30.4s, w11 // Max + scvtf v4.4s, v12.4s + MUL_SCALE1 v1, v4 - scvtf v4.4s, v16.4s + cbz x23, TILE1_MUL_OHE_SCALE + ld1 {v2.s}[0], [x23] + fmul v4.4s, v4.4s, v2.s[0] + + TILE1_MUL_OHE_SCALE: + MLA_WEIGHTZERO v4, v20, v21, 0 - fmul v12.4s, v4.4s, v1.4s cmp w13, #1 beq L1QuantUseInt8 - st1 {v12.4s}, [x0], x4 + + L1_ADD_BIAS: + cbz x10, L1_ADD_DSTV + ld1 {v0.4s}, [x10], #16 + fadd v4.4s, v4.4s, v0.4s + b L1_POST + + L1_ADD_DSTV: + ld1 {v0.4s}, [x0] + fadd v4.4s, v4.4s, v0.4s + + L1_POST: + cbz x19, L1_STORE + ld1r {v26.4s}, [x19] // f32 min + ld1r {v27.4s}, [x20] // f32 max + ReLU_FP32_1 v4, v26, v27 + + L1_STORE: + st1 {v4.4s}, [x0], x4 b L1LoopCheck L1QuantUseInt8: + ld1 {v0.4s}, [x10], #16 + fadd v4.4s, v4.4s, v0.4s + dup v31.4s, w6 // Min + dup v30.4s, w11 // Max - fcvtas v8.4s, v12.4s + fcvtas v8.4s, v4.4s smin v8.4s, v30.4s, v8.4s @@ -635,13 +840,17 @@ L1LoopDz: L1LoopCheck: subs x5, x5, #1 mov x1, x8 + add x2, x22, x21 bne L1LoopDz End: -ldp d8, d9, [sp, #48] -ldp d10, d11, [sp, #32] -ldp d12, d13, [sp, #16] -ldp d14, d15, [sp], #64 +ldp x23, x24, [sp, #(16 * 6)] +ldp x21, x22, [sp, #(16 * 5)] +ldp x19, x20, [sp, #(16 * 4)] +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 8) ret #endif diff --git a/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S b/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S index a6e0142d1..16b2837b7 100644 --- a/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S +++ b/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_16x4_Unit_FAST.S @@ -13,14 
+13,54 @@ .text .align 5 +.macro MLA_WEIGHTZERO d0, s0, s1, idx // idx for xKernelSum + fmla \d0\().4s, \s1\().4s, \s0\().s[\idx] +.endm +.macro ReLU_FP32 s0, s1, s2, s3, z0, z1 // z0:min z1:max + fmin \s0\().4s, \s0\().4s, \z1\().4s + fmin \s1\().4s, \s1\().4s, \z1\().4s + fmin \s2\().4s, \s2\().4s, \z1\().4s + fmin \s3\().4s, \s3\().4s, \z1\().4s + fmax \s0\().4s, \s0\().4s, \z0\().4s + fmax \s1\().4s, \s1\().4s, \z0\().4s + fmax \s2\().4s, \s2\().4s, \z0\().4s + fmax \s3\().4s, \s3\().4s, \z0\().4s +.endm +.macro ReLU_FP32_3 s0, s1, s2, z0, z1 // z0:min z1:max + fmin \s0\().4s, \s0\().4s, \z1\().4s + fmin \s1\().4s, \s1\().4s, \z1\().4s + fmin \s2\().4s, \s2\().4s, \z1\().4s + fmax \s0\().4s, \s0\().4s, \z0\().4s + fmax \s1\().4s, \s1\().4s, \z0\().4s + fmax \s2\().4s, \s2\().4s, \z0\().4s +.endm +.macro ReLU_FP32_2 s0, s1, z0, z1 // z0:min z1:max + fmin \s0\().4s, \s0\().4s, \z1\().4s + fmin \s1\().4s, \s1\().4s, \z1\().4s + fmax \s0\().4s, \s0\().4s, \z0\().4s + fmax \s1\().4s, \s1\().4s, \z0\().4s +.endm +.macro ReLU_FP32_1 s0, z0, z1 // z0:min z1:max + fmin \s0\().4s, \s0\().4s, \z1\().4s + fmax \s0\().4s, \s0\().4s, \z0\().4s +.endm + asm_function MNNGemmInt8AddBiasScale_16x4_Unit_FAST -//struct QuanPostTreatParameters { -// const float* scale; -// const int32_t* bias; -// int32_t maxValue; -// int32_t minValue; -//}; +/* +struct QuanPostTreatParameters { + const float* scale; + const float* bias; + int32_t maxValue; + int32_t minValue; + int32_t useInt8 = 1; // Save result as int8_t dataType; otherwise float32. + float roundValuePos = 0.5f; + float roundValueNeg = -0.5f; + float* srcKernelSum; + float* weightQuanBias; + float* fp32minmax; +}; +*/ //void MNNGemmInt8AddBiasScale_16x4_Unit_FAST(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, // size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t remain) { @@ -30,6 +70,7 @@ asm_function MNNGemmInt8AddBiasScale_16x4_Unit_FAST //Load from post: // x7: scale, x10: bias, w11: maxValue, w13: minValue, w12: useInt8 +// x19: srcKernelSum, x20: weightQuanBias mov x8, x7 ldr x7, [x6, #0] ldr x10, [x6, #8] @@ -37,10 +78,14 @@ ldr w11, [x6, #16] ldr w13, [x6, #20] ldr w12, [x6, #24] -stp d14, d15, [sp, #-64]! -stp d12, d13, [sp, #16] -stp d10, d11, [sp, #32] -stp d8, d9, [sp, #48] +stp d14, d15, [sp, #(-16 * 6)]! 
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] +stp x21, x22, [sp, #(16 * 4)] +stp x19, x20, [sp, #(16 * 5)] +ldr x19, [x6, #40] +ldr x20, [x6, #48] cmp x8, #3 beq L3Dz @@ -183,33 +228,47 @@ L4LoopDz: addp v14.4s, v20.4s, v21.4s addp v15.4s, v22.4s, v23.4s - add v16.4s, v12.4s, v0.4s - add v17.4s, v13.4s, v0.4s + //add v16.4s, v12.4s, v0.4s + //add v17.4s, v13.4s, v0.4s + //add v18.4s, v14.4s, v0.4s + //add v19.4s, v15.4s, v0.4s L4Quan: - ld1 {v1.4s}, [x7], #16 - add v18.4s, v14.4s, v0.4s - add v19.4s, v15.4s, v0.4s - - scvtf v4.4s, v16.4s - scvtf v5.4s, v17.4s - scvtf v6.4s, v18.4s - scvtf v7.4s, v19.4s + ld1 {v1.4s}, [x7], #16 // scale + ld1 {v2.4s}, [x19] // x kernel sum + ld1 {v24.4s}, [x20], #16 // weight quan zeropoint - dup v31.4s, w13 // Min - dup v30.4s, w11 // Max + TILE4_INT2FLOAT: + scvtf v4.4s, v12.4s + scvtf v5.4s, v13.4s + scvtf v6.4s, v14.4s + scvtf v7.4s, v15.4s fmul v12.4s, v4.4s, v1.4s fmul v13.4s, v5.4s, v1.4s fmul v14.4s, v6.4s, v1.4s fmul v15.4s, v7.4s, v1.4s + MLA_WEIGHTZERO v12, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v13, v2, v24, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v14, v2, v24, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v15, v2, v24, 3 // tile:3, oc:0-3 + + + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s + fadd v15.4s, v15.4s, v0.4s + cmp w12, #1 beq L4QuantUseInt8 + ReLU_FP32 v12, v13, v14, v15, v26, v27 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], x4 b L4LoopCheck L4QuantUseInt8: + dup v31.4s, w13 // Min + dup v30.4s, w11 // Max fcvtas v8.4s, v12.4s fcvtas v9.4s, v13.4s fcvtas v10.4s, v14.4s @@ -243,6 +302,7 @@ L4LoopCheck: b End L3Dz: +add x3, x19, #8 cmp w12, #1 bne L3LoopDz sub x4, x4, #8 @@ -346,35 +406,43 @@ L3LoopDz: addp v19.4s, v9.4s, v8.4s addp v20.4s, v7.4s, v6.4s addp v21.4s, v5.4s, v4.4s + addp v12.4s, v16.4s, v17.4s addp v13.4s, v18.4s, v19.4s - ld1 {v0.4s}, [x10], #16 addp v14.4s, v20.4s, v21.4s - - add v16.4s, v12.4s, v0.4s - add v17.4s, v13.4s, v0.4s + ld1 {v0.4s}, [x10], #16 L3Quan: ld1 {v1.4s}, [x7], #16 - add v18.4s, v14.4s, v0.4s + ld1 {v2.d}[0], [x19] // x kernel sum + ld1 {v2.s}[2], [x6] + ld1 {v24.4s}, [x20], #16 // weight quan zeropoint - scvtf v4.4s, v16.4s - scvtf v5.4s, v17.4s - scvtf v6.4s, v18.4s - - dup v31.4s, w13 // Min - dup v30.4s, w11 // Max + TILE3_INT2FLOAT: + scvtf v4.4s, v12.4s + scvtf v5.4s, v13.4s + scvtf v6.4s, v14.4s fmul v12.4s, v4.4s, v1.4s fmul v13.4s, v5.4s, v1.4s fmul v14.4s, v6.4s, v1.4s + MLA_WEIGHTZERO v12, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v13, v2, v24, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v14, v2, v24, 2 // tile:2, oc:0-3 + + + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v0.4s + fadd v14.4s, v14.4s, v0.4s cmp w12, #1 beq L3QuantUseInt8 + ReLU_FP32_3 v12, v13, v14, v26, v27 st1 {v12.4s, v13.4s, v14.4s}, [x0], x4 b L3LoopCheck L3QuantUseInt8: - + dup v31.4s, w13 // Min + dup v30.4s, w11 // Max fcvtas v8.4s, v12.4s fcvtas v9.4s, v13.4s fcvtas v10.4s, v14.4s @@ -480,29 +548,33 @@ L2LoopDz: addp v19.4s, v9.4s, v8.4s addp v12.4s, v16.4s, v17.4s addp v13.4s, v18.4s, v19.4s - ld1 {v0.4s}, [x10], #16 - - add v16.4s, v12.4s, v0.4s - add v17.4s, v13.4s, v0.4s L2Quan: ld1 {v1.4s}, [x7], #16 + ld1 {v2.d}[0], [x19] // x kernel sum + ld1 {v24.4s}, [x20], #16 // weight quan zeropoint + ld1 {v0.4s}, [x10], #16 - scvtf v4.4s, v16.4s - scvtf v5.4s, v17.4s - - dup v31.4s, w13 // Min - dup v30.4s, w11 // Max + TILE2_INT2FLOAT: + scvtf v4.4s, v12.4s + scvtf v5.4s, v13.4s fmul v12.4s, v4.4s, v1.4s fmul v13.4s, v5.4s, v1.4s + 
MLA_WEIGHTZERO v12, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v13, v2, v24, 1 // tile:1, oc:0-3 + fadd v12.4s, v12.4s, v0.4s + fadd v13.4s, v13.4s, v0.4s + cmp w12, #1 beq L2QuantUseInt8 + ReLU_FP32_2 v12, v13, v26, v27 st1 {v12.4s, v13.4s}, [x0], x4 b L2LoopCheck L2QuantUseInt8: - + dup v31.4s, w13 // Min + dup v30.4s, w11 // Max fcvtas v8.4s, v12.4s fcvtas v9.4s, v13.4s @@ -580,25 +652,27 @@ L1LoopDz: addp v12.4s, v16.4s, v17.4s ld1 {v0.4s}, [x10], #16 - add v16.4s, v12.4s, v0.4s - L1Quan: ld1 {v1.4s}, [x7], #16 + ld1 {v2.s}[0], [x19] // x kernel sum + ld1 {v24.4s}, [x20], #16 // weight quan zeropoint - scvtf v4.4s, v16.4s - - dup v31.4s, w13 // Min - dup v30.4s, w11 // Max - + TILE1_INT2FLOAT: + scvtf v4.4s, v12.4s fmul v12.4s, v4.4s, v1.4s + MLA_WEIGHTZERO v12, v2, v24, 0 // tile:0, oc:0-3 + fadd v12.4s, v12.4s, v0.4s + cmp w12, #1 beq L1QuantUseInt8 + ReLU_FP32_1 v12, v26, v27 st1 {v12.4s}, [x0], x4 b L1LoopCheck L1QuantUseInt8: - + dup v31.4s, w13 // Min + dup v30.4s, w11 // Max fcvtas v8.4s, v12.4s smin v8.4s, v30.4s, v8.4s @@ -615,10 +689,12 @@ L1LoopCheck: bne L1LoopDz End: +ldp x19, x20, [sp, #80] +ldp x21, x22, [sp, #64] ldp d8, d9, [sp, #48] ldp d10, d11, [sp, #32] ldp d12, d13, [sp, #16] -ldp d14, d15, [sp], #64 +ldp d14, d15, [sp], #96 ret #endif diff --git a/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_ARMV82_Unit.S b/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_ARMV82_Unit.S index 943e5655f..d1fdd68bd 100644 --- a/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_ARMV82_Unit.S +++ b/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_ARMV82_Unit.S @@ -12,11 +12,25 @@ .text .align 5 -.macro SET_BIAS s, d0, d1, d2, d3 - mov \d0\().16b, \s\().16b - mov \d1\().16b, \s\().16b - mov \d2\().16b, \s\().16b - mov \d3\().16b, \s\().16b +.macro ADD_BIAS_FLOAT d0, d1, d2, d3, z0 + fadd \d0\().4s, \d0\().4s, \z0\().4s + fadd \d1\().4s, \d1\().4s, \z0\().4s + fadd \d2\().4s, \d2\().4s, \z0\().4s + fadd \d3\().4s, \d3\().4s, \z0\().4s +.endm + +.macro ADD_FLOAT d0, d1, d2, d3, s0, s1, s2, s3 + fadd \d0\().4s, \d0\().4s, \s0\().4s + fadd \d1\().4s, \d1\().4s, \s1\().4s + fadd \d2\().4s, \d2\().4s, \s2\().4s + fadd \d3\().4s, \d3\().4s, \s3\().4s +.endm + +.macro SET_BIAS d0, d1, d2, d3 + movi \d0\().16b, #0 + movi \d1\().16b, #0 + movi \d2\().16b, #0 + movi \d3\().16b, #0 .endm .macro Int32ToFloat z0, z1, z2, z3 scvtf \z0\().4s, \z0\().4s @@ -30,6 +44,12 @@ fmul \d2\().4s, \d2\().4s, \s\().4s fmul \d3\().4s, \d3\().4s, \s\().4s .endm +.macro MUL_EXTRA_SCALE s, d0, d1, d2, d3 + fmul \d0\().4s, \d0\().4s, \s\().s[0] + fmul \d1\().4s, \d1\().4s, \s\().s[1] + fmul \d2\().4s, \d2\().4s, \s\().s[2] + fmul \d3\().4s, \d3\().4s, \s\().s[3] +.endm .macro FloatToInt32 z0, z1, z2, z3 fcvtas \z0\().4s, \z0\().4s fcvtas \z1\().4s, \z1\().4s @@ -50,15 +70,38 @@ Int16ToInt8_ONE \s0, \s1, \d0 Int16ToInt8_ONE \s2, \s3, \d1 .endm +.macro MLA_WEIGHTZERO d0, s0, s1, idx // idx for xKernelSum + fmla \d0\().4s, \s1\().4s, \s0\().s[\idx] +.endm +.macro ReLU_FP32 s0, s1, s2, s3, z0, z1 // z0:min z1:max + fmin \s0\().4s, \s0\().4s, \z1\().4s + fmin \s1\().4s, \s1\().4s, \z1\().4s + fmin \s2\().4s, \s2\().4s, \z1\().4s + fmin \s3\().4s, \s3\().4s, \z1\().4s + fmax \s0\().4s, \s0\().4s, \z0\().4s + fmax \s1\().4s, \s1\().4s, \z0\().4s + fmax \s2\().4s, \s2\().4s, \z0\().4s + fmax \s3\().4s, \s3\().4s, \z0\().4s +.endm asm_function MNNGemmInt8AddBiasScale_ARMV82_Unit - -//struct QuanPostTreatParameters { -// const float* scale; -// const int32_t* bias; -// int32_t maxValue; -// int32_t 
minValue; -//}; +/* +struct QuanPostTreatParameters { + const float* scale; + const float* biasFloat; + int32_t maxValue; + int32_t minValue; + int32_t useInt8 = 1; // Save result as int8_t dataType; otherwise float32. + float roundValuePos = 0.5f; + float roundValueNeg = -0.5f; + float* srcKernelSum; + float* weightQuanBias; + float* fp32minmax; + ssize_t blockNum; + const int32_t* bias; + float* extraScale; +}; +*/ //void MNNGemmInt8AddBiasScale_ARMV82_Unit(int8_t* dst, const int8_t* src, // const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, @@ -67,29 +110,37 @@ asm_function MNNGemmInt8AddBiasScale_ARMV82_Unit //Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step //x5:dst_depth_quad, x6: parameters, x7: realDstCount -//Load from x7: x8: scale, x9: bias, w12: maxValue, w13: minValue, w28: useInt8 +//Load from x6: x8: scale, x9: bias, w28: useInt8, x25: xKernelSum, x26: weightQuantBias, x23: fp32minmax +// x24: extraScale ldr x8, [x6, #0] ldr x9, [x6, #8] -ldr w12, [x6, #16] -ldr w13, [x6, #20] -stp d14, d15, [sp, #(-16 * 7)]! +stp d14, d15, [sp, #(-16 * 9)]! stp d12, d13, [sp, #(16 * 1)] stp d10, d11, [sp, #(16 * 2)] stp d8, d9, [sp, #(16 * 3)] stp x21, x22, [sp, #(16 * 4)] stp x19, x20, [sp, #(16 * 5)] stp x27, x28, [sp, #(16 * 6)] +stp x25, x26, [sp, #(16 * 7)] +stp x23, x24, [sp, #(16 * 8)] + +ldr x27, [x6, #64] // blockNum +mul x27, x27, x3 // blockNum * src_depth_quad_perblock +lsl x15, x27, #4 // x15 = src_depth_quad * UNIT * SRC_UNIT + ldr w28, [x6, #24] // useInt8 +ldr x25, [x6, #40] // xKernelSum +ldr x26, [x6, #48] // weightQuantBias +ldr x24, [x6, #80] // extraScale +add x23, x6, #16 // int8 max ptr mov x21, #4 // sizeof(int8_t) * UNIT cbnz w28, Start mov x21, #16 // sizeof(float) * UNIT +ldr x23, [x6, #56] // fp32minmax Start: -lsl x15, x3, #4 // x15 = src_depth_quad * UNIT * SRC_UNIT mov x22, #48 // src_steps -dup v7.16b, w12 // max -dup v6.16b, w13 // min TILE_12: cmp x7, #12 @@ -97,16 +148,18 @@ TILE_12: cmp x5, #2 blt L4LoopDz_TILE_12 L8LoopDz_TILE_12: - ld1 {v0.4s, v1.4s}, [x9], #32 // bias + //ld1 {v0.4s, v1.4s}, [x9], #32 // bias mov x11, x1 mov x13, x3 + mov x20, x0 // tag dst address + mov x27, x2 - SET_BIAS v0, v8, v9, v10, v11 - SET_BIAS v0, v12, v13, v14, v15 - SET_BIAS v0, v16, v17, v18, v19 - SET_BIAS v1, v20, v21, v22, v23 - SET_BIAS v1, v24, v25, v26, v27 - SET_BIAS v1, v28, v29, v30, v31 + SET_BIAS v8, v9, v10, v11 + SET_BIAS v12, v13, v14, v15 + SET_BIAS v16, v17, v18, v19 + SET_BIAS v20, v21, v22, v23 + SET_BIAS v24, v25, v26, v27 + SET_BIAS v28, v29, v30, v31 L8LoopSz_TILE_12: ld1 {v3.16b}, [x2], x15 // weight @@ -141,26 +194,108 @@ L8LoopDz_TILE_12: bne L8LoopSz_TILE_12 L8LoopSzEnd_TILE_12: - add x2, x2, x15 + // add x2, x2, x15 + add x2, x27, x15, LSL #1 sub x5, x5, #2 L8Tile12Quan: ld1 {v0.4s, v1.4s}, [x8], #32 // scale + ld1 {v2.4s, v3.4s, v4.4s}, [x25] // x kernel sum + ld1 {v5.4s, v6.4s}, [x26], #32 // weight quan zeropoint Int32ToFloat v8, v9, v10, v11 Int32ToFloat v12, v13, v14, v15 Int32ToFloat v16, v17, v18, v19 Int32ToFloat v20, v21, v22, v23 Int32ToFloat v24, v25, v26, v27 Int32ToFloat v28, v29, v30, v31 + MUL_SCALE v0, v8, v9, v10, v11 MUL_SCALE v0, v12, v13, v14, v15 MUL_SCALE v0, v16, v17, v18, v19 MUL_SCALE v1, v20, v21, v22, v23 MUL_SCALE v1, v24, v25, v26, v27 MUL_SCALE v1, v28, v29, v30, v31 + + cbz x24, TILE12_L8_MLA + ld1 {v0.4s, v1.4s}, [x24], #32 + ld1 {v7.4s}, [x24] + MUL_EXTRA_SCALE v0, v8, v9, v10, v11 + MUL_EXTRA_SCALE v1, v12, v13, v14, v15 + MUL_EXTRA_SCALE v7, v16, v17, v18, v19 + 
MUL_EXTRA_SCALE v0, v20, v21, v22, v23 + MUL_EXTRA_SCALE v1, v24, v25, v26, v27 + MUL_EXTRA_SCALE v7, v28, v29, v30, v31 + sub x24, x24, #32 + + TILE12_L8_MLA: + MLA_WEIGHTZERO v8, v2, v5, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v5, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v2, v5, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v2, v5, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v3, v5, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v13, v3, v5, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v14, v3, v5, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v15, v3, v5, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v16, v4, v5, 0 // tile:8, oc:0-3 + MLA_WEIGHTZERO v17, v4, v5, 1 // tile:9, oc:0-3 + MLA_WEIGHTZERO v18, v4, v5, 2 // tile:10, oc:0-3 + MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3 + + MLA_WEIGHTZERO v20, v2, v6, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v21, v2, v6, 1 // tile:1, oc:4-7 + MLA_WEIGHTZERO v22, v2, v6, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v23, v2, v6, 3 // tile:3, oc:4-7 + MLA_WEIGHTZERO v24, v3, v6, 0 // tile:4, oc:4-7 + MLA_WEIGHTZERO v25, v3, v6, 1 // tile:5, oc:4-7 + MLA_WEIGHTZERO v26, v3, v6, 2 // tile:6, oc:4-7 + MLA_WEIGHTZERO v27, v3, v6, 3 // tile:7, oc:4-7 + MLA_WEIGHTZERO v28, v4, v6, 0 // tile:8, oc:4-7 + MLA_WEIGHTZERO v29, v4, v6, 1 // tile:9, oc:4-7 + MLA_WEIGHTZERO v30, v4, v6, 2 // tile:10, oc:4-7 + MLA_WEIGHTZERO v31, v4, v6, 3 // tile:11, oc:4-7 + cmp w28, #1 beq L8Tile12QuanUseInt8 sub x4, x4, #128 + + cbz x9, TILE12_ADD_DSTV + TILE12_ADD_BIAS: + ld1 {v0.4s, v1.4s}, [x9], #32 + ADD_BIAS_FLOAT v8, v9, v10, v11, v0 + ADD_BIAS_FLOAT v12, v13, v14, v15, v0 + ADD_BIAS_FLOAT v16, v17, v18, v19, v0 + ADD_BIAS_FLOAT v20, v21, v22, v23, v1 + ADD_BIAS_FLOAT v24, v25, v26, v27, v1 + ADD_BIAS_FLOAT v28, v29, v30, v31, v1 + b TILE12_POST + + TILE12_ADD_DSTV: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x20], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20], #64 + ADD_FLOAT v8, v9, v10, v11, v0, v1, v2, v3 + ADD_FLOAT v12, v13, v14, v15, v4, v5, v6, v7 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x20], x4 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20], #64 + ADD_FLOAT v16, v17, v18, v19, v0, v1, v2, v3 + ADD_FLOAT v20, v21, v22, v23, v4, v5, v6, v7 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x20], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20] + ADD_FLOAT v24, v25, v26, v27, v0, v1, v2, v3 + ADD_FLOAT v28, v29, v30, v31, v4, v5, v6, v7 + + TILE12_POST: + cbz x23, TILE12_STORE + ld1r {v0.4s}, [x23], #4 // f32 min + ld1r {v1.4s}, [x23] // f32 max + ReLU_FP32 v8, v9, v10, v11, v0, v1 + ReLU_FP32 v12, v13, v14, v15, v0, v1 + ReLU_FP32 v16, v17, v18, v19, v0, v1 + ReLU_FP32 v20, v21, v22, v23, v0, v1 + ReLU_FP32 v24, v25, v26, v27, v0, v1 + ReLU_FP32 v28, v29, v30, v31, v0, v1 + sub x23, x23, #4 + + TILE12_STORE: st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x4 @@ -171,6 +306,19 @@ L8LoopDz_TILE_12: b L8Tile12LoopCheck L8Tile12QuanUseInt8: + ld1r {v7.4s}, [x23], #4 // int8 max + ld1r {v6.4s}, [x23] // int8 min + ld1 {v0.4s, v1.4s}, [x9], #32 + dup v7.16b, v7.b[0] + dup v6.16b, v6.b[0] + ADD_BIAS_FLOAT v8, v9, v10, v11, v0 + ADD_BIAS_FLOAT v12, v13, v14, v15, v0 + ADD_BIAS_FLOAT v16, v17, v18, v19, v0 + ADD_BIAS_FLOAT v20, v21, v22, v23, v1 + ADD_BIAS_FLOAT v24, v25, v26, v27, v1 + ADD_BIAS_FLOAT v28, v29, v30, v31, v1 + + sub x23, x23, #4 FloatToInt32 v8, v9, v10, v11 FloatToInt32 v12, v13, v14, v15 FloatToInt32 v16, v17, v18, v19 @@ -207,11 +355,9 @@ L8LoopDz_TILE_12: blt End L4LoopDz_TILE_12: - ld1 {v0.4s}, [x9] // bias - - SET_BIAS v0, v8, v9, v10, 
v11 - SET_BIAS v0, v12, v13, v14, v15 - SET_BIAS v0, v16, v17, v18, v19 + SET_BIAS v8, v9, v10, v11 + SET_BIAS v12, v13, v14, v15 + SET_BIAS v16, v17, v18, v19 L4LoopSz_TILE_12: ld1 {v3.16b}, [x2], #16 // weight @@ -235,15 +381,66 @@ L4LoopDz_TILE_12: L4Tile12Quan: ld1 {v0.4s}, [x8] // scale + ld1 {v2.4s, v3.4s, v4.4s}, [x25]// x kernel sum + ld1 {v5.4s}, [x26], #16 // weight quan zeropoint Int32ToFloat v8, v9, v10, v11 Int32ToFloat v12, v13, v14, v15 Int32ToFloat v16, v17, v18, v19 MUL_SCALE v0, v8, v9, v10, v11 MUL_SCALE v0, v12, v13, v14, v15 MUL_SCALE v0, v16, v17, v18, v19 + + cbz x24, TILE12_L4_MLA + ld1 {v0.4s, v1.4s}, [x24], #32 + ld1 {v7.4s}, [x24] + MUL_EXTRA_SCALE v0, v8, v9, v10, v11 + MUL_EXTRA_SCALE v1, v12, v13, v14, v15 + MUL_EXTRA_SCALE v7, v16, v17, v18, v19 + sub x24, x24, #32 + + TILE12_L4_MLA: + MLA_WEIGHTZERO v8, v2, v5, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v5, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v2, v5, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v2, v5, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v3, v5, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v13, v3, v5, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v14, v3, v5, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v15, v3, v5, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v16, v4, v5, 0 // tile:8, oc:0-3 + MLA_WEIGHTZERO v17, v4, v5, 1 // tile:9, oc:0-3 + MLA_WEIGHTZERO v18, v4, v5, 2 // tile:10, oc:0-3 + MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3 cmp w28, #1 beq L4Tile12QuanUseInt8 sub x4, x4, #128 + + TILE12_L4_ADD_BIAS: + cbz x9, TILE12_L4_ADD_DSTV + ld1 {v0.4s}, [x9] // bias + ADD_BIAS_FLOAT v8, v9, v10, v11, v0 + ADD_BIAS_FLOAT v12, v13, v14, v15, v0 + ADD_BIAS_FLOAT v16, v17, v18, v19, v0 + b TILE12_L4_POST + + TILE12_L4_ADD_DSTV: + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0] + sub x0, x0, #128 + ADD_FLOAT v8, v9, v10, v11, v20, v21, v22, v23 + ADD_FLOAT v12, v13, v14, v15, v24, v25, v26, v27 + ADD_FLOAT v16, v17, v18, v19, v28, v29, v30, v31 + + TILE12_L4_POST: + cbz x23, TILE12_L4_STORE + ld1r {v6.4s}, [x23], #4 // f32 min + ld1r {v7.4s}, [x23] // f32 max + ReLU_FP32 v8, v9, v10, v11, v6, v7 + ReLU_FP32 v12, v13, v14, v15, v6, v7 + ReLU_FP32 v16, v17, v18, v19, v6, v7 + sub x23, x23, #4 + TILE12_L4_STORE: st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x4 @@ -251,6 +448,15 @@ L4LoopDz_TILE_12: b End L4Tile12QuanUseInt8: + ld1r {v7.4s}, [x23], #4 // int8 max + ld1r {v6.4s}, [x23] // int8 min + ld1 {v0.4s}, [x9] // bias + dup v7.16b, v7.b[0] + dup v6.16b, v6.b[0] + ADD_BIAS_FLOAT v8, v9, v10, v11, v0 + ADD_BIAS_FLOAT v12, v13, v14, v15, v0 + ADD_BIAS_FLOAT v16, v17, v18, v19, v0 + sub x23, x23, #4 FloatToInt32 v8, v9, v10, v11 FloatToInt32 v12, v13, v14, v15 FloatToInt32 v16, v17, v18, v19 @@ -276,17 +482,19 @@ TILE_8: mov x14, x5 mov x19, x8 // scale mov x20, x9 // bias + mov x6, x26 // weightQuantBias cmp x5, #2 blt L4LoopDz_TILE_8 L8LoopDz_TILE_8: - ld1 {v0.4s, v1.4s}, [x20], #32 // bias + //ld1 {v0.4s, v1.4s}, [x20], #32 // bias mov x11, x1 mov x13, x3 + mov x27, x12 - SET_BIAS v0, v8, v9, v10, v11 - SET_BIAS v0, v12, v13, v14, v15 - SET_BIAS v1, v16, v17, v18, v19 - SET_BIAS v1, v20, v21, v22, v23 + SET_BIAS v8, v9, v10, v11 + SET_BIAS v12, v13, v14, v15 + SET_BIAS v16, v17, v18, v19 + SET_BIAS v20, v21, v22, v23 L8LoopSz_TILE_8: ld1 {v3.16b}, [x12], x15 // weight @@ -313,11 +521,14 @@ L8LoopDz_TILE_8: bne L8LoopSz_TILE_8 
L8LoopSzEnd_TILE_8: - add x12, x12, x15 + //add x12, x12, x15 + add x12, x27, x15, LSL #1 sub x14, x14, #2 L8Tile8Quan: ld1 {v0.4s, v1.4s}, [x19], #32 // scale + ld1 {v2.4s, v3.4s}, [x25] // x kernel sum + ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint Int32ToFloat v8, v9, v10, v11 Int32ToFloat v12, v13, v14, v15 Int32ToFloat v16, v17, v18, v19 @@ -326,9 +537,68 @@ L8LoopDz_TILE_8: MUL_SCALE v0, v12, v13, v14, v15 MUL_SCALE v1, v16, v17, v18, v19 MUL_SCALE v1, v20, v21, v22, v23 + + cbz x24, TILE8_L8_MLA + ld1 {v0.4s, v1.4s}, [x24] + MUL_EXTRA_SCALE v0, v8, v9, v10, v11 + MUL_EXTRA_SCALE v1, v12, v13, v14, v15 + MUL_EXTRA_SCALE v0, v16, v17, v18, v19 + MUL_EXTRA_SCALE v1, v20, v21, v22, v23 + + TILE8_L8_MLA: + MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v2, v24, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v3, v24, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v13, v3, v24, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v14, v3, v24, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v15, v3, v24, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v16, v2, v25, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v17, v2, v25, 1 // tile:1, oc:4-7 + MLA_WEIGHTZERO v18, v2, v25, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v19, v2, v25, 3 // tile:3, oc:4-7 + MLA_WEIGHTZERO v20, v3, v25, 0 // tile:4, oc:4-7 + MLA_WEIGHTZERO v21, v3, v25, 1 // tile:5, oc:4-7 + MLA_WEIGHTZERO v22, v3, v25, 2 // tile:6, oc:4-7 + MLA_WEIGHTZERO v23, v3, v25, 3 // tile:7, oc:4-7 + cmp w28, #1 beq L8Tile8QuanUseInt8 sub x4, x4, #64 + + cbz x9, TILE8_ADD_DSTV + TILE8_ADD_BIAS: + ld1 {v0.4s, v1.4s}, [x20], #32 + ADD_BIAS_FLOAT v8, v9, v10, v11, v0 + ADD_BIAS_FLOAT v12, v13, v14, v15, v0 + ADD_BIAS_FLOAT v16, v17, v18, v19, v1 + ADD_BIAS_FLOAT v20, v21, v22, v23, v1 + b TILE8_POST + + TILE8_ADD_DSTV: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], x4 + ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10] + ADD_FLOAT v8, v9, v10, v11, v0, v1, v2, v3 + ADD_FLOAT v12, v13, v14, v15, v4, v5, v6, v7 + ADD_FLOAT v16, v17, v18, v19, v24, v25, v26, v27 + ADD_FLOAT v20, v21, v22, v23, v28, v29, v30, v31 + sub x10, x10, #128 + sub x10, x10, x4 + + TILE8_POST: + cbz x23, TILE8_STORE + ld1r {v0.4s}, [x23], #4 // f32 min + ld1r {v1.4s}, [x23] // f32 max + ReLU_FP32 v8, v9, v10, v11, v0, v1 + ReLU_FP32 v12, v13, v14, v15, v0, v1 + ReLU_FP32 v16, v17, v18, v19, v0, v1 + ReLU_FP32 v20, v21, v22, v23, v0, v1 + sub x23, x23, #4 + + TILE8_STORE: st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], #64 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4 st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64 @@ -337,6 +607,16 @@ L8LoopDz_TILE_8: b L8Tile8LoopCheck L8Tile8QuanUseInt8: + ld1r {v7.4s}, [x23], #4 // int8 max + ld1r {v6.4s}, [x23] // int8 min + ld1 {v0.4s, v1.4s}, [x20], #32 + dup v7.16b, v7.b[0] + dup v6.16b, v6.b[0] + ADD_BIAS_FLOAT v8, v9, v10, v11, v0 + ADD_BIAS_FLOAT v12, v13, v14, v15, v0 + ADD_BIAS_FLOAT v16, v17, v18, v19, v1 + ADD_BIAS_FLOAT v20, v21, v22, v23, v1 + sub x23, x23, #4 FloatToInt32 v8, v9, v10, v11 FloatToInt32 v12, v13, v14, v15 FloatToInt32 v16, v17, v18, v19 @@ -364,12 +644,12 @@ L8LoopDz_TILE_8: cbz x14, Tile8End L4LoopDz_TILE_8: - ld1 {v0.4s}, [x20], #16 // bias + //ld1 {v0.4s}, [x20], #16 // bias mov x11, x1 mov x13, x3 - SET_BIAS v0, v8, v9, v10, v11 - SET_BIAS v0, v12, v13, v14, v15 + SET_BIAS v8, v9, v10, v11 + SET_BIAS v12, v13, v14, v15 L4LoopSz_TILE_8: ld1 {v3.16b}, 
[x12], #16 // weight @@ -388,20 +668,69 @@ L4LoopDz_TILE_8: L4LoopSzEnd_TILE_8: L4Tile8Quan: - ld1 {v0.4s, v1.4s}, [x19], #32 // scale + ld1 {v0.4s}, [x19], #16 // scale + ld1 {v2.4s, v3.4s}, [x25] // x kernel sum + ld1 {v24.4s}, [x6], #16 // weight quan zeropoint Int32ToFloat v8, v9, v10, v11 Int32ToFloat v12, v13, v14, v15 MUL_SCALE v0, v8, v9, v10, v11 MUL_SCALE v0, v12, v13, v14, v15 + + cbz x24, TILE8_L4_MLA + ld1 {v0.4s, v1.4s}, [x24] + MUL_EXTRA_SCALE v0, v8, v9, v10, v11 + MUL_EXTRA_SCALE v1, v12, v13, v14, v15 + + TILE8_L4_MLA: + MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v2, v24, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v3, v24, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v13, v3, v24, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v14, v3, v24, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v15, v3, v24, 3 // tile:7, oc:0-3 cmp w28, #1 beq L4Tile8QuanUseInt8 sub x4, x4, #64 + + cbz x9, TILE8_L4_ADD_DSTV + TILE8_L4_ADD_BIAS: + ld1 {v4.4s}, [x20], #16 + ADD_BIAS_FLOAT v8, v9, v10, v11, v4 + ADD_BIAS_FLOAT v12, v13, v14, v15, v4 + b TILE8_L4_POST + + TILE8_L4_ADD_DSTV: + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10] + sub x10, x10, #64 + ADD_FLOAT v8, v9, v10, v11, v4, v5, v6, v7 + ADD_FLOAT v12, v13, v14, v15, v16, v17, v18, v19 + + TILE8_L4_POST: + cbz x23, TILE8_L4_STORE + ld1r {v0.4s}, [x23], #4 // f32 min + ld1r {v1.4s}, [x23] // f32 max + ReLU_FP32 v8, v9, v10, v11, v0, v1 + ReLU_FP32 v12, v13, v14, v15, v0, v1 + sub x23, x23, #4 + + TILE8_L4_STORE: st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], #64 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4 add x4, x4, #64 b Tile8End L4Tile8QuanUseInt8: + ld1r {v7.4s}, [x23], #4 // int8 max + ld1r {v6.4s}, [x23] // int8 min + ld1 {v4.4s}, [x20], #16 + dup v7.16b, v7.b[0] + dup v6.16b, v6.b[0] + ADD_BIAS_FLOAT v8, v9, v10, v11, v4 + ADD_BIAS_FLOAT v12, v13, v14, v15, v4 + sub x23, x23, #4 FloatToInt32 v8, v9, v10, v11 FloatToInt32 v12, v13, v14, v15 Int32ToInt16 v8, v9, v10, v11, v0, v1 @@ -412,11 +741,15 @@ L4LoopDz_TILE_8: smin v16.16b, v7.16b, v16.16b smin v17.16b, v7.16b, v17.16b st1 {v16.16b, v17.16b}, [x10], x4 - Tile8End: +cbz x24, Tile8_End_Offset +add x24, x24, #32 + +Tile8_End_Offset: sub x7, x7, #8 add x0, x0, x21, LSL #3 add x1, x1, #32 + add x25, x25, #32 TILE_4: cmp x7, #4 @@ -426,15 +759,17 @@ TILE_4: mov x14, x5 mov x19, x8 mov x20, x9 + mov x6, x26 // weightQuantBias cmp x5, #2 blt L4LoopDz_TILE_4 L8LoopDz_TILE_4: - ld1 {v0.4s, v1.4s}, [x20], #32 // bias + //ld1 {v0.4s, v1.4s}, [x20], #32 // bias mov x11, x1 mov x13, x3 + mov x27, x12 - SET_BIAS v0, v8, v9, v10, v11 - SET_BIAS v1, v12, v13, v14, v15 + SET_BIAS v8, v9, v10, v11 + SET_BIAS v12, v13, v14, v15 L8LoopSz_TILE_4: ld1 {v3.16b}, [x12], x15 // weight @@ -453,22 +788,73 @@ L8LoopDz_TILE_4: bne L8LoopSz_TILE_4 L8LoopSzEnd_TILE_4: - add x12, x12, x15 + //add x12, x12, x15 + add x12, x27, x15, LSL #1 sub x14, x14, #2 L8Tile4Quan: ld1 {v0.4s, v1.4s}, [x19], #32 // scale + ld1 {v2.4s}, [x25] // x kernel sum + ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint Int32ToFloat v8, v9, v10, v11 Int32ToFloat v12, v13, v14, v15 MUL_SCALE v0, v8, v9, v10, v11 MUL_SCALE v1, v12, v13, v14, v15 + + cbz x24, TILE4_L8_MLA + ld1 {v0.4s}, [x24] + MUL_EXTRA_SCALE v0, v8, v9, v10, v11 + MUL_EXTRA_SCALE v0, v12, v13, v14, v15 + + TILE4_L8_MLA: + MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3 
+ MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v2, v24, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v2, v25, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v13, v2, v25, 1 // tile:1, oc:4-7 + MLA_WEIGHTZERO v14, v2, v25, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v15, v2, v25, 3 // tile:3, oc:4-7 + cmp w28, #1 beq L8Tile4QuanUseInt8 + + cbz x9, TILE4_ADD_DSTV + TILE4_ADD_BIAS: + ld1 {v4.4s, v5.4s}, [x20], #32 + ADD_BIAS_FLOAT v8, v9, v10, v11, v4 + ADD_BIAS_FLOAT v12, v13, v14, v15, v5 + b TILE4_POST + + TILE4_ADD_DSTV: + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], x4 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10] + sub x10, x10, x4 + ADD_FLOAT v8, v9, v10, v11, v4, v5, v6, v7 + ADD_FLOAT v12, v13, v14, v15, v16, v17, v18, v19 + + TILE4_POST: + cbz x23, TILE4_STORE + ld1r {v26.4s}, [x23], #4 // f32 min + ld1r {v27.4s}, [x23] // f32 max + ReLU_FP32 v8, v9, v10, v11, v26, v27 + ReLU_FP32 v12, v13, v14, v15, v26, v27 + sub x23, x23, #4 + + TILE4_STORE: st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], x4 st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4 b L8Tile4LoopCheck L8Tile4QuanUseInt8: + ld1r {v7.4s}, [x23], #4 // int8 max + ld1r {v6.4s}, [x23] // int8 min + ld1 {v4.4s, v5.4s}, [x20], #32 + dup v7.16b, v7.b[0] + dup v6.16b, v6.b[0] + ADD_BIAS_FLOAT v8, v9, v10, v11, v4 + ADD_BIAS_FLOAT v12, v13, v14, v15, v5 + sub x23, x23, #4 FloatToInt32 v8, v9, v10, v11 FloatToInt32 v12, v13, v14, v15 Int32ToInt16 v8, v9, v10, v11, v0, v1 @@ -487,10 +873,10 @@ L8LoopDz_TILE_4: cbz x14, Tile4End L4LoopDz_TILE_4: - ld1 {v0.4s}, [x20], #16 // bias + //ld1 {v0.4s}, [x20], #16 // bias mov x11, x1 mov x13, x3 - SET_BIAS v0, v8, v9, v10, v11 + SET_BIAS v8, v9, v10, v11 L4LoopSz_TILE_4: ld1 {v3.16b}, [x12], #16 // weight @@ -506,25 +892,68 @@ L4LoopDz_TILE_4: L4Tile4Quan: ld1 {v0.4s}, [x19], #16 // scale + ld1 {v2.4s}, [x25] // x kernel sum + ld1 {v24.4s}, [x6], #16 // weight quan zeropoint Int32ToFloat v8, v9, v10, v11 MUL_SCALE v0, v8, v9, v10, v11 + + cbz x24, TILE4_L4_MLA + ld1 {v0.4s}, [x24] + MUL_EXTRA_SCALE v0, v8, v9, v10, v11 + + TILE4_L4_MLA: + MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v2, v24, 3 // tile:3, oc:0-3 + cmp w28, #1 beq L4Tile4QuanUseInt8 + + cbz x9, TILE4_L4_ADD_DSTV + TILE4_L4_ADD_BIAS: + ld1 {v3.4s}, [x20], #16 + ADD_BIAS_FLOAT v8, v9, v10, v11, v3 + b TILE4_L4_POST + + TILE4_L4_ADD_DSTV: + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10] + ADD_FLOAT v8, v9, v10, v11, v12, v13, v14, v15 + + TILE4_L4_POST: + cbz x23, TILE4_L4_STORE + ld1r {v26.4s}, [x23], #4 // f32 min + ld1r {v27.4s}, [x23] // f32 max + ReLU_FP32 v8, v9, v10, v11, v26, v27 + sub x23, x23, #4 + + TILE4_L4_STORE: st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], x4 b Tile4End L4Tile4QuanUseInt8: + ld1r {v7.4s}, [x23], #4 // int8 max + ld1r {v6.4s}, [x23] // int8 min + ld1 {v3.4s}, [x20], #16 + dup v7.16b, v7.b[0] + dup v6.16b, v6.b[0] + ADD_BIAS_FLOAT v8, v9, v10, v11, v3 + sub x23, x23, #4 FloatToInt32 v8, v9, v10, v11 Int32ToInt16 v8, v9, v10, v11, v0, v1 Int16ToInt8_ONE v0, v1, v16 smax v16.16b, v6.16b, v16.16b smin v16.16b, v7.16b, v16.16b st1 {v16.16b}, [x10], x4 - Tile4End: +cbz x24, Tile4_End_Offset +add x24, x24, #16 + +Tile4_End_Offset: sub x7, x7, #4 add x0, x0, x21, LSL #2 add x1, x1, #16 + add x25, x25, #16 TILE_1: cbz x7, End @@ -533,14 +962,17 @@ TILE_1: mov x14, x5 mov x19, x8 mov x20, x9 + mov x6, x26 // weightQuantBias cmp x5, #2 blt L4LoopDz_TILE_1 L8LoopDz_TILE_1: - ld1 {v0.4s, v1.4s}, 
[x20], #32 // bias + //ld1 {v0.4s, v1.4s}, [x20], #32 // bias mov x11, x1 mov x13, x3 - mov v8.16b, v0.16b - mov v9.16b, v1.16b + mov x27, x12 + + movi v8.16b, #0 + movi v9.16b, #0 L8LoopSz_TILE_1: ld1 {v3.16b}, [x12], x15 // weight ld1 {v0.s}[0], [x11], x22 // src @@ -552,22 +984,68 @@ L8LoopDz_TILE_1: bne L8LoopSz_TILE_1 L8LoopSzEnd_TILE_1: - add x12, x12, x15 + add x12, x27, x15, LSL #1 sub x14, x14, #2 L8Tile1Quan: ld1 {v0.4s, v1.4s}, [x19], #32 // scale + ld1 {v2.s}[0], [x25] // x kernel sum + ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint scvtf v8.4s, v8.4s scvtf v9.4s, v9.4s fmul v8.4s, v8.4s, v0.4s fmul v9.4s, v9.4s, v1.4s + + cbz x24, TILE1_L8_MLA + ld1 {v0.s}[0], [x24] + fmul v8.4s, v8.4s, v0.s[0] + fmul v9.4s, v9.4s, v0.s[0] + + TILE1_L8_MLA: + MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v25, 0 // tile:0, oc:4-7 + cmp w28, #1 beq L8Tile1QuanUseInt8 + + cbz x9, TILE1_ADD_DSTV + TILE1_ADD_BIAS: + ld1 {v10.4s, v11.4s}, [x20], #32 + fadd v8.4s, v8.4s, v10.4s + fadd v9.4s, v9.4s, v11.4s + b TILE1_POST + + TILE1_ADD_DSTV: + ld1 {v10.4s}, [x10], x4 + ld1 {v11.4s}, [x10] + sub x10, x10, x4 + fadd v8.4s, v8.4s, v10.4s + fadd v9.4s, v9.4s, v11.4s + + TILE1_POST: + cbz x23, TILE1_STORE + ld1r {v26.4s}, [x23], #4 // f32 min + ld1r {v27.4s}, [x23] // f32 max + sub x23, x23, #4 + fmin v8.4s, v8.4s, v27.4s + fmin v9.4s, v9.4s, v27.4s + fmax v8.4s, v8.4s, v26.4s + fmax v9.4s, v9.4s, v26.4s + + TILE1_STORE: st1 {v8.4s}, [x10], x4 st1 {v9.4s}, [x10], x4 b L8Tile1LoopCheck L8Tile1QuanUseInt8: + ld1r {v7.4s}, [x23], #4 // int8 max + ld1r {v6.4s}, [x23] // int8 min + ld1 {v10.4s, v11.4s}, [x20], #32 + dup v7.16b, v7.b[0] + dup v6.16b, v6.b[0] + fadd v8.4s, v8.4s, v10.4s + fadd v9.4s, v9.4s, v11.4s + sub x23, x23, #4 fcvtas v8.4s, v8.4s fcvtas v9.4s, v9.4s sqxtn v0.4h, v8.4s @@ -584,10 +1062,10 @@ L8LoopDz_TILE_1: cbz x14, Tile1End L4LoopDz_TILE_1: - ld1 {v0.4s}, [x20], #16 // bias + //ld1 {v0.4s}, [x20], #16 // bias mov x11, x1 mov x13, x3 - mov v8.16b, v0.16b + movi v8.16b, #0 L4LoopSz_TILE_1: ld1 {v3.16b}, [x12], #16 // weight ld1 {v0.s}[0], [x11], x22 // src @@ -599,14 +1077,49 @@ L4LoopDz_TILE_1: L4Tile1Quan: ld1 {v0.4s}, [x19], #16 // scale + ld1 {v2.s}[0], [x25] // x kernel sum + ld1 {v24.4s}, [x6], #16 // weight quan zeropoint scvtf v8.4s, v8.4s fmul v8.4s, v8.4s, v0.4s + + cbz x24, TILE1_L4_MLA + ld1 {v0.s}[0], [x24] + fmul v8.4s, v8.4s, v0.s[0] + + TILE1_L4_MLA: + MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 cmp w28, #1 beq L4Tile1QuanUseInt8 + + cbz x9, TILE1_L4_ADD_DSTV + TILE1_L4_ADD_BIAS: + ld1 {v4.4s}, [x20], #16 + fadd v8.4s, v8.4s, v4.4s + b TILE1_L4_POST + + TILE1_L4_ADD_DSTV: + ld1 {v4.4s}, [x10] + fadd v8.4s, v8.4s, v4.4s + + TILE1_L4_POST: + cbz x23, TILE1_L4_STORE + ld1r {v26.4s}, [x23], #4 // f32 min + ld1r {v27.4s}, [x23] // f32 max + sub x23, x23, #4 + fmax v8.4s, v8.4s, v26.4s + fmin v8.4s, v8.4s, v27.4s + TILE1_L4_STORE: st1 {v8.4s}, [x10], x4 b Tile1End L4Tile1QuanUseInt8: + ld1r {v7.4s}, [x23], #4 // int8 max + ld1r {v6.4s}, [x23] // int8 min + ld1 {v4.4s}, [x20], #16 + fadd v8.4s, v8.4s, v4.4s + sub x23, x23, #4 + dup v7.16b, v7.b[0] + dup v6.16b, v6.b[0] fcvtas v8.4s, v8.4s sqxtn v0.4h, v8.4s sqxtn v16.8b, v0.8h @@ -615,19 +1128,26 @@ L4LoopDz_TILE_1: st1 {v16.s}[0], [x10], x4 Tile1End: +cbz x24, Tile1_End_Offset +add x24, x24, #4 + +Tile1_End_Offset: sub x7, x7, #1 add x0, x0, x21 add x1, x1, #4 + add x25, x25, #4 b TILE_1 End: +ldp x23, x24, [sp, #(16 * 8)] +ldp x25, x26, [sp, #(16 * 7)] ldp x27, x28, [sp, #(16 * 6)] ldp x19, 
x20, [sp, #(16 * 5)] ldp x21, x22, [sp, #(16 * 4)] ldp d8, d9, [sp, #(16 * 3)] ldp d10, d11, [sp, #(16 * 2)] ldp d12, d13, [sp, #(16 * 1)] -ldp d14, d15, [sp], #(16 * 7) +ldp d14, d15, [sp], #(16 * 9) ret #endif // __aarch64__ diff --git a/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_ARMV86_Unit.S b/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_ARMV86_Unit.S index a6d4af6a3..eda852364 100644 --- a/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_ARMV86_Unit.S +++ b/source/backend/cpu/arm/arm64/MNNGemmInt8AddBiasScale_ARMV86_Unit.S @@ -12,11 +12,30 @@ .text .align 5 -.macro SET_BIAS s, d0, d1, d2, d3 - mov \d0\().16b, \s\().16b - mov \d1\().16b, \s\().16b - mov \d2\().16b, \s\().16b - mov \d3\().16b, \s\().16b +.macro SET_0_5 d0, d1, d2, d3, d4 + movi \d0\().16b, #0 + movi \d1\().16b, #0 + movi \d2\().16b, #0 + movi \d3\().16b, #0 + movi \d4\().16b, #0 +.endm +.macro SET_0_4 d0, d1, d2, d3 + movi \d0\().16b, #0 + movi \d1\().16b, #0 + movi \d2\().16b, #0 + movi \d3\().16b, #0 +.endm +.macro ADD_BIAS_FLOAT d0, d1, d2, d3, z0 + fadd \d0\().4s, \d0\().4s, \z0\().4s + fadd \d1\().4s, \d1\().4s, \z0\().4s + fadd \d2\().4s, \d2\().4s, \z0\().4s + fadd \d3\().4s, \d3\().4s, \z0\().4s +.endm +.macro ADD_FLOAT d0, d1, d2, d3, s0, s1, s2, s3 + fadd \d0\().4s, \d0\().4s, \s0\().4s + fadd \d1\().4s, \d1\().4s, \s1\().4s + fadd \d2\().4s, \d2\().4s, \s2\().4s + fadd \d3\().4s, \d3\().4s, \s3\().4s .endm .macro Int32ToFloat z0, z1, z2, z3 scvtf \z0\().4s, \z0\().4s @@ -30,6 +49,12 @@ fmul \d2\().4s, \d2\().4s, \s\().4s fmul \d3\().4s, \d3\().4s, \s\().4s .endm +.macro MUL_EXTRA_SCALE s, d0, d1, d2, d3 + fmul \d0\().4s, \d0\().4s, \s\().s[0] + fmul \d1\().4s, \d1\().4s, \s\().s[1] + fmul \d2\().4s, \d2\().4s, \s\().s[2] + fmul \d3\().4s, \d3\().4s, \s\().s[3] +.endm .macro FloatToInt32 z0, z1, z2, z3 fcvtas \z0\().4s, \z0\().4s fcvtas \z1\().4s, \z1\().4s @@ -50,16 +75,43 @@ Int16ToInt8_ONE \s0, \s1, \d0 Int16ToInt8_ONE \s2, \s3, \d1 .endm +.macro MLA_WEIGHTZERO d0, s0, s1, idx // idx for xKernelSum + fmla \d0\().4s, \s1\().4s, \s0\().s[\idx] +.endm +.macro ReLU_FP32 s0, s1, s2, s3, z0, z1 // z0:min z1:max + fmin \s0\().4s, \s0\().4s, \z1\().4s + fmin \s1\().4s, \s1\().4s, \z1\().4s + fmin \s2\().4s, \s2\().4s, \z1\().4s + fmin \s3\().4s, \s3\().4s, \z1\().4s + fmax \s0\().4s, \s0\().4s, \z0\().4s + fmax \s1\().4s, \s1\().4s, \z0\().4s + fmax \s2\().4s, \s2\().4s, \z0\().4s + fmax \s3\().4s, \s3\().4s, \z0\().4s +.endm +.macro ReLU_FP32_2 s0, s1, z0, z1 // z0:min z1:max + fmin \s0\().4s, \s0\().4s, \z1\().4s + fmin \s1\().4s, \s1\().4s, \z1\().4s + fmax \s0\().4s, \s0\().4s, \z0\().4s + fmax \s1\().4s, \s1\().4s, \z0\().4s +.endm asm_function MNNGemmInt8AddBiasScale_ARMV86_Unit - -//struct QuanPostTreatParameters { -// const float* scale; -// const int32_t* bias; -// int32_t maxValue; -// int32_t minValue; -//}; - +/* +struct QuanPostTreatParameters { + const float* scale; + const float* biasFloat; + int32_t maxValue; + int32_t minValue; + int32_t useInt8 = 1; // Save result as int8_t dataType; otherwise float32. 
+ float roundValuePos = 0.5f; + float roundValueNeg = -0.5f; + float* srcKernelSum; + float* weightQuanBias; + float* fp32minmax; + ssize_t blockNum; + float* extraScale; +}; +*/ //void MNNGemmInt8AddBiasScale_ARMV86_Unit(int8_t* dst, const int8_t* src, // const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, // const QuanPostTreatParameters* parameters, size_t realDstCount); @@ -67,13 +119,13 @@ asm_function MNNGemmInt8AddBiasScale_ARMV86_Unit //Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step //x5:dst_depth_quad, x6: parameters, x7: realDstCount -//Load from x7: x8: scale, x9: bias, w12: maxValue, w13: minValue, w23: useInt8 +//Load from x6: x8: scale, x9: bias, w23: useInt8, x27: srcKernelSum, x28: weightQuanBias, +// EP=10,LP=8,HP=8 + ldr x8, [x6, #0] ldr x9, [x6, #8] -ldr w10, [x6, #16] -ldr w14, [x6, #20] -stp d14, d15, [sp, #(-16 * 8)]! +stp d14, d15, [sp, #(-16 * 10)]! stp d12, d13, [sp, #(16 * 1)] stp d10, d11, [sp, #(16 * 2)] stp d8, d9, [sp, #(16 * 3)] @@ -81,341 +133,744 @@ stp x21, x22, [sp, #(16 * 4)] stp x19, x20, [sp, #(16 * 5)] stp x23, x24, [sp, #(16 * 6)] stp x25, x26, [sp, #(16 * 7)] +stp x27, x28, [sp, #(16 * 8)] ldr w23, [x6, #24] +ldr x27, [x6, #40] // srcKernelSum +ldr x28, [x6, #48] // weightQuanBias -mov x21, #4 // sizeof(int8_t) * UNIT -mov x22, #160 // GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT = 20 * 8 = 160 +ldr x22, [x6, #64] // blockNum +mul x22, x22, x3 // UP_DIV(ic*ky*kx, SRC_UNIT) = blockNum * src_depth_quad_per_block +lsl x15, x22, #6 // x15 = src_depth_quad * UNIT * UNIT_SRC = src_depth_quad * 64 = src_depth_quad << 6 + +ldr x10, [x6, #80] // extra scale +mov x21, #4 // sizeof(int8_t) * pack +add x14, x6, #16 // int8 max ptr cbnz w23, Start -mov x21, #16 // sizeof(float) * UNIT +mov x21, #16 // sizeof(float) * pack +ldr x14, [x6, #56] // float32 maxmin ptr Start: -lsl x15, x3, #5 // x15 = src_depth_quad * UNIT * UNIT_SRC = src_depth_quad * 32 = src_depth_quad << 5 +mov x22, #80 // GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT = 10 * 8 = 80 + +TILE_10: + cmp x7, #10 + blt TILE_8 + sub x4, x4, #32 // For int8 output, x4-64 + cbnz w23, TILE10_DZ + sub x4, x4, #96 // For float32 output, x4-32-96=x4-128 -TILE_20: - cmp x7, #20 - blt TILE_16 -LoopDz_TILE_20: - ld1 {v0.4s}, [x9], #16 // bias +TILE10_DZ: +cmp x5, #2 +blt LoopDz4_TILE_10 + +LoopDz8_TILE_10: mov x11, x1 // src mov x12, x2 // weight mov x13, x3 // src_depth_quad - mov v1.16b, v0.16b - uzp1 v12.2d, v0.2d, v1.2d // bias_0, bias_1, bias_0, bias_1 - uzp2 v13.2d, v0.2d, v1.2d // bias_2, bias_3, bias_2, bias_3 - mov v14.16b, v12.16b - mov v15.16b, v13.16b - SET_BIAS v14, v16, v18, v20, v22 - SET_BIAS v14, v24, v26, v28, v30 - SET_BIAS v15, v17, v19, v21, v23 - SET_BIAS v15, v25, v27, v29, v31 -LoopSz_TILE_20: - // src : 10 x [2 x 8] : v2-11 - // weight : 2 x [2 x 8] : v0-1 - // dst : 10 x 2 x [4] : v12-v31 - ld1 {v0.16b, v1.16b}, [x12], #32 // weight - ld1 {v2.16b, v3.16b, v4.16b, v5.16b}, [x11], #64 // src - .inst 0x4e80a44c // smmla v12.4s, v2.16b, v0.16b - .inst 0x4e81a44d // smmla v13.4s, v2.16b, v1.16b - .inst 0x4e80a46e // smmla v14.4s, v3.16b, v0.16b - .inst 0x4e81a46f // smmla v15.4s, v3.16b, v1.16b - ld1 {v6.16b, v7.16b, v8.16b, v9.16b}, [x11], #64 - .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b - .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b - .inst 0x4e80a4b2 // smmla v18.4s, v5.16b, v0.16b - .inst 0x4e81a4b3 // smmla v19.4s, v5.16b, v1.16b - ld1 {v10.16b, v11.16b}, [x11], #32 - .inst 0x4e80a4d4 // smmla v20.4s, v6.16b, v0.16b - .inst 0x4e81a4d5 // 
smmla v21.4s, v6.16b, v1.16b - .inst 0x4e80a4f6 // smmla v22.4s, v7.16b, v0.16b - .inst 0x4e81a4f7 // smmla v23.4s, v7.16b, v1.16b - .inst 0x4e80a518 // smmla v24.4s, v8.16b, v0.16b - .inst 0x4e81a519 // smmla v25.4s, v8.16b, v1.16b - .inst 0x4e80a53a // smmla v26.4s, v9.16b, v0.16b - .inst 0x4e81a53b // smmla v27.4s, v9.16b, v1.16b - .inst 0x4e80a55c // smmla v28.4s, v10.16b, v0.16b - .inst 0x4e81a55d // smmla v29.4s, v10.16b, v1.16b + + SET_0_5 v12, v16, v20, v24, v28 // oc:0,1,0,1 + SET_0_5 v13, v17, v21, v25, v29 // oc:2,3,2,3 + SET_0_5 v14, v18, v22, v26, v30 // oc:4,5,4,5 + SET_0_5 v15, v19, v23, v27, v31 // oc:6,7,6,7 + +LoopSz_TILE_10: + ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x12], #64 // weight + ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9 + ld1 {v7.16b}, [x11], #16 subs x13, x13, #1 - .inst 0x4e80a57e // smmla v30.4s, v11.16b, v0.16b - .inst 0x4e81a57f // smmla v31.4s, v11.16b, v1.16b - bne LoopSz_TILE_20 -LoopSzEnd_TILE_20: + .inst 0x4e88a46c // smmla v12.4s, v3.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a46d // smmla v13.4s, v3.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + .inst 0x4e8aa46e // smmla v14.4s, v3.16b, v10.16b // tile0-oc4, tile0-oc5, tile1-oc4, tile1-oc5 + .inst 0x4e8ba46f // smmla v15.4s, v3.16b, v11.16b // tile0-oc6, tile0-oc7, tile1-oc6, tile1-oc7 + + .inst 0x4e88a490 // smmla v16.4s, v4.16b, v8.16b // tile2-oc0, tile2-oc1, tile3-oc0, tile3-oc1 + .inst 0x4e89a491 // smmla v17.4s, v4.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 + .inst 0x4e8aa492 // smmla v18.4s, v4.16b, v10.16b // tile2-oc4, tile2-oc5, tile3-oc4, tile3-oc5 + .inst 0x4e8ba493 // smmla v19.4s, v4.16b, v11.16b // tile2-oc6, tile2-oc7, tile3-oc6, tile3-oc7 + + .inst 0x4e88a4b4 // smmla v20.4s, v5.16b, v8.16b // tile4-oc0, tile4-oc1, tile5-oc0, tile5-oc1 + .inst 0x4e89a4b5 // smmla v21.4s, v5.16b, v9.16b // tile4-oc2, tile4-oc3, tile5-oc2, tile5-oc3 + .inst 0x4e8aa4b6 // smmla v22.4s, v5.16b, v10.16b // tile4-oc4, tile4-oc5, tile5-oc4, tile5-oc5 + .inst 0x4e8ba4b7 // smmla v23.4s, v5.16b, v11.16b // tile4-oc6, tile4-oc7, tile5-oc6, tile5-oc7 + + .inst 0x4e88a4d8 // smmla v24.4s, v6.16b, v8.16b // tile6-oc0, tile6-oc1, tile7-oc0, tile7-oc1 + .inst 0x4e89a4d9 // smmla v25.4s, v6.16b, v9.16b // tile6-oc2, tile6-oc3, tile7-oc2, tile7-oc3 + .inst 0x4e8aa4da // smmla v26.4s, v6.16b, v10.16b // tile6-oc4, tile6-oc5, tile7-oc4, tile7-oc5 + .inst 0x4e8ba4db // smmla v27.4s, v6.16b, v11.16b // tile6-oc6, tile6-oc7, tile7-oc6, tile7-oc7 + + .inst 0x4e88a4fc // smmla v28.4s, v7.16b, v8.16b // tile8-oc0, tile8-oc1, tile9-oc0, tile9-oc1 + .inst 0x4e89a4fd // smmla v29.4s, v7.16b, v9.16b // tile8-oc2, tile8-oc3, tile9-oc2, tile9-oc3 + .inst 0x4e8aa4fe // smmla v30.4s, v7.16b, v10.16b // tile8-oc4, tile8-oc5, tile9-oc4, tile9-oc5 + .inst 0x4e8ba4ff // smmla v31.4s, v7.16b, v11.16b // tile8-oc6, tile8-oc7, tile9-oc6, tile9-oc7 + bne LoopSz_TILE_10 +LoopSzEnd_TILE_10: add x2, x2, x15 // weight += dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); - sub x5, x5, #1 // dz-- + sub x5, x5, #2 // dz-2 // transpose - uzp1 v11.2d, v12.2d, v13.2d - uzp2 v12.2d, v12.2d, v13.2d - uzp1 v13.2d, v14.2d, v15.2d - uzp2 v14.2d, v14.2d, v15.2d - uzp1 v15.2d, v16.2d, v17.2d - uzp2 v16.2d, v16.2d, v17.2d - uzp1 v17.2d, v18.2d, v19.2d - uzp2 v18.2d, v18.2d, v19.2d - uzp1 v19.2d, v20.2d, v21.2d - uzp2 v20.2d, v20.2d, v21.2d - uzp1 v21.2d, v22.2d, v23.2d - uzp2 v22.2d, v22.2d, v23.2d - uzp1 v23.2d, v24.2d, v25.2d - uzp2 v24.2d, v24.2d, v25.2d - uzp1 
v25.2d, v26.2d, v27.2d - uzp2 v26.2d, v26.2d, v27.2d - uzp1 v27.2d, v28.2d, v29.2d - uzp2 v28.2d, v28.2d, v29.2d - uzp1 v29.2d, v30.2d, v31.2d - uzp2 v30.2d, v30.2d, v31.2d - Int32ToFloat v11, v12, v13, v14 - Int32ToFloat v15, v16, v17, v18 - Int32ToFloat v19, v20, v21, v22 - Int32ToFloat v23, v24, v25, v26 - Int32ToFloat v27, v28, v29, v30 - -Tile20Quan: - ld1 {v0.4s}, [x8], #16 // scale - MUL_SCALE v0, v11, v12, v13, v14 - MUL_SCALE v0, v15, v16, v17, v18 - MUL_SCALE v0, v19, v20, v21, v22 - MUL_SCALE v0, v23, v24, v25, v26 - MUL_SCALE v0, v27, v28, v29, v30 - cmp w23, #1 - beq Tile20QuanUseInt8 - sub x4, x4, #256 - st1 {v11.4s, v12.4s, v13.4s, v14.4s}, [x0], #64 - st1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x0], #64 - st1 {v19.4s, v20.4s, v21.4s, v22.4s}, [x0], #64 - st1 {v23.4s, v24.4s, v25.4s, v26.4s}, [x0], #64 - st1 {v27.4s, v28.4s, v29.4s, v30.4s}, [x0], x4 - add x4, x4, #256 - b Tile20LoopCheck - - Tile20QuanUseInt8: - FloatToInt32 v11, v12, v13, v14 - FloatToInt32 v15, v16, v17, v18 - FloatToInt32 v19, v20, v21, v22 - FloatToInt32 v23, v24, v25, v26 - FloatToInt32 v27, v28, v29, v30 - Int32ToInt16 v11, v12, v13, v14, v0, v1 - Int32ToInt16 v15, v16, v17, v18, v2, v3 - Int32ToInt16 v19, v20, v21, v22, v4, v5 - Int32ToInt16 v23, v24, v25, v26, v6, v7 - Int32ToInt16 v27, v28, v29, v30, v8, v9 - Int16ToInt8 v0, v1, v2, v3, v16, v17 - Int16ToInt8 v4, v5, v6, v7, v18, v19 - Int16ToInt8_ONE v8, v9, v20 - dup v11.16b, w10 // max - dup v10.16b, w14 // min - smax v16.16b, v10.16b, v16.16b - smax v17.16b, v10.16b, v17.16b - smax v18.16b, v10.16b, v18.16b - smax v19.16b, v10.16b, v19.16b - smax v20.16b, v10.16b, v20.16b - smin v16.16b, v11.16b, v16.16b - smin v17.16b, v11.16b, v17.16b - smin v18.16b, v11.16b, v18.16b - smin v19.16b, v11.16b, v19.16b - smin v20.16b, v11.16b, v20.16b - sub x4, x4, #64 - st1 {v16.16b, v17.16b, v18.16b, v19.16b}, [x0], #64 - st1 {v20.16b}, [x0], x4 // dst += dz * dst_step; - add x4, x4, #64 -Tile20LoopCheck: - cmp x5, #1 - bge LoopDz_TILE_20 - b End + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v2.2d, v16.2d, v17.2d + uzp2 v3.2d, v16.2d, v17.2d + uzp1 v4.2d, v20.2d, v21.2d + uzp2 v5.2d, v20.2d, v21.2d + uzp1 v6.2d, v24.2d, v25.2d + uzp2 v7.2d, v24.2d, v25.2d + uzp1 v8.2d, v28.2d, v29.2d + uzp2 v9.2d, v28.2d, v29.2d -TILE_16: - dup v11.16b, w10 // max - dup v10.16b, w14 // min - sub x10, x22, #64 - cmp x7, #16 - blt TILE_8 - mov x24, x5 // dst_depth_quad - mov x26, x0 // dst - mov x25, x2 // weight - mov x19, x8 // scale - mov x20, x9 // bias -LoopDz_TILE_16: // while (dz = dst_depth_quad) - ld1 {v0.4s}, [x20], #16 // bias + uzp1 v10.2d, v14.2d, v15.2d // E0: oc:4-7 + uzp2 v11.2d, v14.2d, v15.2d // E1: oc:4-7 + uzp1 v12.2d, v18.2d, v19.2d + uzp2 v13.2d, v18.2d, v19.2d + uzp1 v14.2d, v22.2d, v23.2d + uzp2 v15.2d, v22.2d, v23.2d + uzp1 v16.2d, v26.2d, v27.2d + uzp2 v17.2d, v26.2d, v27.2d + uzp1 v18.2d, v30.2d, v31.2d + uzp2 v19.2d, v30.2d, v31.2d + + Int32ToFloat v0, v1, v2, v3 + Int32ToFloat v4, v5, v6, v7 + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 + Int32ToFloat v16, v17, v18, v19 + +Tile10Quan: + ld1 {v20.4s, v21.4s}, [x8], #32 // scale + ld1 {v22.4s, v23.4s}, [x27], #32 // x kernel sum + ld1 {v24.d}[0], [x27] + ld1 {v25.4s, v26.4s}, [x28], #32 // weight quan zeropoint + sub x27, x27, #32 + MUL_SCALE v20, v0, v1, v2, v3 + MUL_SCALE v20, v4, v5, v6, v7 + MUL_SCALE v21, v10, v11, v12, v13 + MUL_SCALE v21, v14, v15, v16, v17 + fmul v8.4s, v8.4s, v20.4s + fmul v9.4s, v9.4s, v20.4s + fmul v18.4s, v18.4s, 
v21.4s + fmul v19.4s, v19.4s, v21.4s + + cbz x10, TILE10_MLA + ld1 {v27.4s, v28.4s}, [x10], #32 + ld1 {v29.d}[0], [x10] + MUL_EXTRA_SCALE v27, v0, v1, v2, v3 + MUL_EXTRA_SCALE v28, v4, v5, v6, v7 + MUL_EXTRA_SCALE v27, v10, v11, v12, v13 + MUL_EXTRA_SCALE v28, v14, v15, v16, v17 + fmul v8.4s, v8.4s, v29.s[0] + fmul v9.4s, v9.4s, v29.s[1] + fmul v18.4s, v18.4s, v29.s[0] + fmul v19.4s, v19.4s, v29.s[1] + sub x10, x10, #32 + + TILE10_MLA: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v22, v26, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v11, v22, v26, 1 // tile:1, oc:4-7 + + MLA_WEIGHTZERO v2, v22, v25, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v3, v22, v25, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v22, v26, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v13, v22, v26, 3 // tile:3, oc:4-7 + + MLA_WEIGHTZERO v4, v23, v25, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v5, v23, v25, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v14, v23, v26, 0 // tile:4, oc:4-7 + MLA_WEIGHTZERO v15, v23, v26, 1 // tile:5, oc:4-7 + + MLA_WEIGHTZERO v6, v23, v25, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v7, v23, v25, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v16, v23, v26, 2 // tile:6, oc:4-7 + MLA_WEIGHTZERO v17, v23, v26, 3 // tile:7, oc:4-7 + + MLA_WEIGHTZERO v8, v24, v25, 0 // tile:8, oc:0-3 + MLA_WEIGHTZERO v9, v24, v25, 1 // tile:9, oc:0-3 + MLA_WEIGHTZERO v18, v24, v26, 0 // tile:8, oc:4-7 + MLA_WEIGHTZERO v19, v24, v26, 1 // tile:9, oc:4-7 + + cbnz w23, Tile10QuanUseInt8 + + TILE10_ADD_BIAS: + cbz x9, TILE10_ADD_DSTV + ld1 {v20.4s, v21.4s}, [x9], #32 // bias + ADD_BIAS_FLOAT v0, v1, v2, v3, v20 + ADD_BIAS_FLOAT v4, v5, v6, v7, v20 + ADD_BIAS_FLOAT v10, v11, v12, v13, v21 + ADD_BIAS_FLOAT v14, v15, v16, v17, v21 + fadd v8.4s, v8.4s, v20.4s + fadd v9.4s, v9.4s, v20.4s + fadd v18.4s, v18.4s, v21.4s + fadd v19.4s, v19.4s, v21.4s + b TILE10_POST + + TILE10_ADD_DSTV: + // first batch10 + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 + ld1 {v28.4s, v29.4s}, [x0], x4 + ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23 + ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27 + fadd v8.4s, v8.4s, v28.4s + fadd v9.4s, v9.4s, v29.4s + + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 + ld1 {v28.4s, v29.4s}, [x0] + ADD_FLOAT v10, v11, v12, v13, v20, v21, v22, v23 + ADD_FLOAT v14, v15, v16, v17, v24, v25, v26, v27 + fadd v18.4s, v18.4s, v28.4s + fadd v19.4s, v19.4s, v29.4s + + sub x0, x0, #256 + sub x0, x0, x4 + + TILE10_POST: + cbz x14, TILE10_STORE + ld1r {v30.4s}, [x14], #4 // f32 min + ld1r {v31.4s}, [x14] // f32 max + ReLU_FP32 v0, v1, v2, v3, v30, v31 + ReLU_FP32 v4, v5, v6, v7, v30, v31 + ReLU_FP32 v8, v9, v10, v11, v30, v31 + ReLU_FP32 v12, v13, v14, v15, v30, v31 + ReLU_FP32 v16, v17, v18, v19, v30, v31 + sub x14, x14, #4 + + TILE10_STORE: + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 + st1 {v8.4s, v9.4s}, [x0], x4 + st1 {v10.4s, v11.4s, v12.4s, v13.4s}, [x0], #64 + st1 {v14.4s, v15.4s, v16.4s, v17.4s}, [x0], #64 + st1 {v18.4s, v19.4s}, [x0], x4 + b Tile10LoopCheck + + Tile10QuanUseInt8: + ld1 {v20.4s, v21.4s}, [x9], #32 // bias + ld1r {v31.4s}, [x14], #4 // int8 max + ld1r {v30.4s}, [x14] // int8 min + ADD_BIAS_FLOAT v0, v1, v2, v3, v20 + ADD_BIAS_FLOAT v4, v5, v6, v7, v20 + ADD_BIAS_FLOAT v10, v11, v12, v13, v21 + ADD_BIAS_FLOAT v14, v15, v16, v17, v21 + fadd v8.4s, v8.4s, v20.4s + fadd v9.4s, v9.4s, v20.4s + fadd v18.4s, v18.4s, v21.4s + fadd 
v19.4s, v19.4s, v21.4s + + sub x14, x14, #4 + dup v31.16b, v31.b[0] + dup v30.16b, v30.b[0] + + FloatToInt32 v0, v1, v2, v3 + FloatToInt32 v4, v5, v6, v7 + FloatToInt32 v10, v11, v12, v13 + FloatToInt32 v14, v15, v16, v17 + FloatToInt32 v8, v9, v18, v19 + + Int32ToInt16 v0, v1, v2, v3, v20, v21 + Int32ToInt16 v4, v5, v6, v7, v22, v23 + sqxtn v24.4h, v8.4s + sqxtn2 v24.8h, v9.4s + Int32ToInt16 v10, v11, v12, v13, v25, v26 + Int32ToInt16 v14, v15, v16, v17, v27, v28 + sqxtn v29.4h, v18.4s + sqxtn2 v29.8h, v19.4s + + Int16ToInt8 v20, v21, v22, v23, v0, v1 + sqxtn v2.8b, v24.8h + Int16ToInt8 v25, v26, v27, v28, v3, v4 + sqxtn v5.8b, v29.8h + + smax v0.16b, v30.16b, v0.16b + smax v1.16b, v30.16b, v1.16b + smax v2.8b, v30.8b, v2.8b + smax v3.16b, v30.16b, v3.16b + smax v4.16b, v30.16b, v4.16b + smax v5.8b, v30.8b, v5.8b + + smin v0.16b, v31.16b, v0.16b + smin v1.16b, v31.16b, v1.16b + smin v2.8b, v31.8b, v2.8b + smin v3.16b, v31.16b, v3.16b + smin v4.16b, v31.16b, v4.16b + smin v5.8b, v31.8b, v5.8b + + st1 {v0.16b, v1.16b}, [x0], #32 + st1 {v2.8b}, [x0], x4 + st1 {v3.16b, v4.16b}, [x0], #32 + st1 {v5.8b}, [x0], x4 + +Tile10LoopCheck: + cmp x5, #2 + bge LoopDz8_TILE_10 + cbz x5, End + +LoopDz4_TILE_10: mov x11, x1 // src - mov x12, x25 // weight + mov x12, x2 // weight mov x13, x3 // src_depth_quad - mov v1.16b, v0.16b - uzp1 v2.2d, v0.2d, v1.2d // bias_0, bias_1, bias_0, bias_1 - uzp2 v3.2d, v0.2d, v1.2d // bias_2, bias_3, bias_2, bias_3 - SET_BIAS v2, v16, v18, v20, v22 - SET_BIAS v2, v24, v26, v28, v30 - SET_BIAS v3, v17, v19, v21, v23 - SET_BIAS v3, v25, v27, v29, v31 -LoopSz_TILE_16: - // src : 8 x [2 x 8] : v2-9 - // weight : 2 x [2 x 8] : v0-1 - // dst : 8 x 2 x [4] : v16-v31 - ld1 {v0.16b, v1.16b}, [x12], #32 // weight - ld1 {v2.16b, v3.16b, v4.16b, v5.16b}, [x11], #64 // src - .inst 0x4e80a450 // smmla v16.4s, v2.16b, v0.16b - .inst 0x4e81a451 // smmla v17.4s, v2.16b, v1.16b - .inst 0x4e80a472 // smmla v18.4s, v3.16b, v0.16b - .inst 0x4e81a473 // smmla v19.4s, v3.16b, v1.16b - ld1 {v6.16b, v7.16b, v8.16b, v9.16b}, [x11], x10 - .inst 0x4e80a494 // smmla v20.4s, v4.16b, v0.16b - .inst 0x4e81a495 // smmla v21.4s, v4.16b, v1.16b - .inst 0x4e80a4b6 // smmla v22.4s, v5.16b, v0.16b - .inst 0x4e81a4b7 // smmla v23.4s, v5.16b, v1.16b - .inst 0x4e80a4d8 // smmla v24.4s, v6.16b, v0.16b - .inst 0x4e81a4d9 // smmla v25.4s, v6.16b, v1.16b - .inst 0x4e80a4fa // smmla v26.4s, v7.16b, v0.16b - .inst 0x4e81a4fb // smmla v27.4s, v7.16b, v1.16b + + SET_0_5 v12, v13, v16, v17, v20 + SET_0_5 v21, v24, v25, v28, v29 + +LoopSz4_TILE_10: + ld1 {v8.16b, v9.16b}, [x12] // weight + ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9 + ld1 {v7.16b}, [x11], #16 subs x13, x13, #1 - .inst 0x4e80a51c // smmla v28.4s, v8.16b, v0.16b - .inst 0x4e81a51d // smmla v29.4s, v8.16b, v1.16b - .inst 0x4e80a53e // smmla v30.4s, v9.16b, v0.16b - .inst 0x4e81a53f // smmla v31.4s, v9.16b, v1.16b - bne LoopSz_TILE_16 -LoopSzEnd_TILE_16: - add x25, x25, x15 // weight += dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); - sub x24, x24, #1 // dz-- + add x12, x12, #64 // x12+lp*hp + .inst 0x4e88a46c // smmla v12.4s, v3.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a46d // smmla v13.4s, v3.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + + .inst 0x4e88a490 // smmla v16.4s, v4.16b, v8.16b // tile2-oc0, tile2-oc1, tile3-oc0, tile3-oc1 + .inst 0x4e89a491 // smmla v17.4s, v4.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 + + .inst 0x4e88a4b4 // smmla v20.4s, v5.16b, v8.16b 
// tile4-oc0, tile4-oc1, tile5-oc0, tile5-oc1 + .inst 0x4e89a4b5 // smmla v21.4s, v5.16b, v9.16b // tile4-oc2, tile4-oc3, tile5-oc2, tile5-oc3 + + .inst 0x4e88a4d8 // smmla v24.4s, v6.16b, v8.16b // tile6-oc0, tile6-oc1, tile7-oc0, tile7-oc1 + .inst 0x4e89a4d9 // smmla v25.4s, v6.16b, v9.16b // tile6-oc2, tile6-oc3, tile7-oc2, tile7-oc3 + + .inst 0x4e88a4fc // smmla v28.4s, v7.16b, v8.16b // tile8-oc0, tile8-oc1, tile9-oc0, tile9-oc1 + .inst 0x4e89a4fd // smmla v29.4s, v7.16b, v9.16b // tile8-oc2, tile8-oc3, tile9-oc2, tile9-oc3 + bne LoopSz4_TILE_10 +LoopSz4End_TILE_10: // transpose - uzp1 v15.2d, v16.2d, v17.2d - uzp2 v16.2d, v16.2d, v17.2d - uzp1 v17.2d, v18.2d, v19.2d - uzp2 v18.2d, v18.2d, v19.2d - uzp1 v19.2d, v20.2d, v21.2d - uzp2 v20.2d, v20.2d, v21.2d - uzp1 v21.2d, v22.2d, v23.2d - uzp2 v22.2d, v22.2d, v23.2d - uzp1 v23.2d, v24.2d, v25.2d - uzp2 v24.2d, v24.2d, v25.2d - uzp1 v25.2d, v26.2d, v27.2d - uzp2 v26.2d, v26.2d, v27.2d - uzp1 v27.2d, v28.2d, v29.2d - uzp2 v28.2d, v28.2d, v29.2d - uzp1 v29.2d, v30.2d, v31.2d - uzp2 v30.2d, v30.2d, v31.2d - Int32ToFloat v15, v16, v17, v18 - Int32ToFloat v19, v20, v21, v22 - Int32ToFloat v23, v24, v25, v26 - Int32ToFloat v27, v28, v29, v30 - -Tile16Quan: - ld1 {v0.4s}, [x19], #16 // scale - MUL_SCALE v0, v15, v16, v17, v18 - MUL_SCALE v0, v19, v20, v21, v22 - MUL_SCALE v0, v23, v24, v25, v26 - MUL_SCALE v0, v27, v28, v29, v30 - cmp w23, #1 - beq Tile16QuanUseInt8 - sub x4, x4, #192 - st1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x26], #64 - st1 {v19.4s, v20.4s, v21.4s, v22.4s}, [x26], #64 - st1 {v23.4s, v24.4s, v25.4s, v26.4s}, [x26], #64 - st1 {v27.4s, v28.4s, v29.4s, v30.4s}, [x26], x4 - add x4, x4, #192 - b Tile16LoopCheck - - Tile16QuanUseInt8: - FloatToInt32 v15, v16, v17, v18 - FloatToInt32 v19, v20, v21, v22 - FloatToInt32 v23, v24, v25, v26 - FloatToInt32 v27, v28, v29, v30 - Int32ToInt16 v15, v16, v17, v18, v0, v1 - Int32ToInt16 v19, v20, v21, v22, v2, v3 - Int32ToInt16 v23, v24, v25, v26, v4, v5 - Int32ToInt16 v27, v28, v29, v30, v6, v7 - Int16ToInt8 v0, v1, v2, v3, v16, v17 - Int16ToInt8 v4, v5, v6, v7, v18, v19 - smax v16.16b, v10.16b, v16.16b - smax v17.16b, v10.16b, v17.16b - smax v18.16b, v10.16b, v18.16b - smax v19.16b, v10.16b, v19.16b - smin v16.16b, v11.16b, v16.16b - smin v17.16b, v11.16b, v17.16b - smin v18.16b, v11.16b, v18.16b - smin v19.16b, v11.16b, v19.16b - st1 {v16.16b, v17.16b, v18.16b, v19.16b}, [x26], x4 // dst += dz * dst_step; -Tile16LoopCheck: - cmp x24, #1 - bge LoopDz_TILE_16 -Tile16End: - sub x7, x7, #16 - add x0, x0, x21, LSL #4 - add x1, x1, #128 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v2.2d, v16.2d, v17.2d + uzp2 v3.2d, v16.2d, v17.2d + uzp1 v4.2d, v20.2d, v21.2d + uzp2 v5.2d, v20.2d, v21.2d + uzp1 v6.2d, v24.2d, v25.2d + uzp2 v7.2d, v24.2d, v25.2d + uzp1 v8.2d, v28.2d, v29.2d + uzp2 v9.2d, v28.2d, v29.2d + + Int32ToFloat v0, v1, v2, v3 + Int32ToFloat v4, v5, v6, v7 + scvtf v8.4s, v8.4s + scvtf v9.4s, v9.4s + +Tile10Quan_L4: + ld1 {v20.4s}, [x8] // scale + ld1 {v22.4s, v23.4s}, [x27], #32 // x kernel sum + ld1 {v24.d}[0], [x27] + ld1 {v25.4s}, [x28] // weight quan zeropoint + MUL_SCALE v20, v0, v1, v2, v3 + MUL_SCALE v20, v4, v5, v6, v7 + fmul v8.4s, v8.4s, v20.4s + fmul v9.4s, v9.4s, v20.4s + + cbz x10, TILE10_MLA_L4 + ld1 {v27.4s, v28.4s}, [x10], #32 + ld1 {v29.d}[0], [x10] + MUL_EXTRA_SCALE v27, v0, v1, v2, v3 + MUL_EXTRA_SCALE v28, v4, v5, v6, v7 + fmul v8.4s, v8.4s, v29.s[0] + fmul v9.4s, v9.4s, v29.s[1] + + TILE10_MLA_L4: + MLA_WEIGHTZERO v0, v22, 
v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v2, v22, v25, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v3, v22, v25, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v4, v23, v25, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v5, v23, v25, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v6, v23, v25, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v7, v23, v25, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v8, v24, v25, 0 // tile:8, oc:0-3 + MLA_WEIGHTZERO v9, v24, v25, 1 // tile:9, oc:0-3 + //sub x4, x4, #128 + + cbnz w23, Tile10QuanUseInt8_L4 + + TILE10_ADD_BIAS_L4: + cbz x9, TILE10_ADD_DSTV_L4 + ld1 {v20.4s}, [x9] // bias + ADD_BIAS_FLOAT v0, v1, v2, v3, v20 + ADD_BIAS_FLOAT v4, v5, v6, v7, v20 + fadd v8.4s, v8.4s, v20.4s + fadd v9.4s, v9.4s, v20.4s + b TILE10_POST_L4 + + TILE10_ADD_DSTV_L4: + // first batch10 + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 + ld1 {v28.4s, v29.4s}, [x0] + ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23 + ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27 + fadd v8.4s, v8.4s, v28.4s + fadd v9.4s, v9.4s, v29.4s + + sub x0, x0, #128 + + TILE10_POST_L4: + cbz x14, TILE10_STORE_L4 + ld1r {v30.4s}, [x14], #4 // f32 min + ld1r {v31.4s}, [x14] // f32 max + ReLU_FP32 v0, v1, v2, v3, v30, v31 + ReLU_FP32 v4, v5, v6, v7, v30, v31 + fmax v8.4s, v8.4s, v30.4s + fmax v9.4s, v9.4s, v30.4s + fmin v8.4s, v8.4s, v31.4s + fmin v9.4s, v9.4s, v31.4s + sub x14, x14, #4 + + TILE10_STORE_L4: + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 + st1 {v8.4s, v9.4s}, [x0], x4 + b End + + Tile10QuanUseInt8_L4: + ld1 {v20.4s}, [x9] // bias + ld1r {v31.4s}, [x14], #4 // int8 max + ld1r {v30.4s}, [x14] // int8 min + ADD_BIAS_FLOAT v0, v1, v2, v3, v20 + ADD_BIAS_FLOAT v4, v5, v6, v7, v20 + fadd v8.4s, v8.4s, v20.4s + fadd v9.4s, v9.4s, v20.4s + + sub x14, x14, #4 + dup v31.16b, v31.b[0] + dup v30.16b, v30.b[0] + + FloatToInt32 v0, v1, v2, v3 + FloatToInt32 v4, v5, v6, v7 + fcvtas v8.4s, v8.4s + fcvtas v9.4s, v9.4s + + Int32ToInt16 v0, v1, v2, v3, v16, v17 + Int32ToInt16 v4, v5, v6, v7, v18, v19 + sqxtn v24.4h, v8.4s + sqxtn2 v24.8h, v9.4s + + Int16ToInt8 v16, v17, v18, v19, v21, v22 + sqxtn v23.8b, v24.8h + + smax v21.16b, v30.16b, v21.16b + smax v22.16b, v30.16b, v22.16b + smax v23.8b, v30.8b, v23.8b + + smin v21.16b, v31.16b, v21.16b + smin v22.16b, v31.16b, v22.16b + smin v23.8b, v31.8b, v23.8b + + st1 {v21.16b, v22.16b}, [x0], #32 + st1 {v23.8b}, [x0], x4 + b End TILE_8: + // post parameters initilize + cbnz w23, INT8_POST_INIT + cbz x14, TILE_Remain + ld1r {v30.4s}, [x14], #4 // f32 min + ld1r {v31.4s}, [x14] // f32 max + b TILE_Remain + + INT8_POST_INIT: + ld1r {v31.4s}, [x14], #4 // int8 max + ld1r {v30.4s}, [x14] // int8 min + dup v31.16b, v31.b[0] + dup v30.16b, v30.b[0] + + TILE_Remain: cmp x7, #8 blt TILE_4 + cbnz w23, TILE8_START + sub x4, x4, #64 // For float32 output, add #64 when tile8 end. 
+ + TILE8_START: mov x24, x5 // dst_depth_quad mov x26, x0 // dst mov x25, x2 // weight mov x19, x8 // scale mov x20, x9 // bias + mov x6, x28 // weightQuanBias +cmp x5, #2 +blt LoopDz4_TILE_8 LoopDz_TILE_8: - ld1 {v0.4s}, [x20], #16 // bias mov x11, x1 // src mov x12, x25 // weight mov x13, x3 // src_depth_quad - mov v1.16b, v0.16b - uzp1 v2.2d, v0.2d, v1.2d // bias_0, bias_1, bias_0, bias_1 - uzp2 v3.2d, v0.2d, v1.2d // bias_2, bias_3, bias_2, bias_3 - SET_BIAS v2, v24, v26, v28, v30 - SET_BIAS v3, v25, v27, v29, v31 + SET_0_4 v12, v16, v20, v24 + SET_0_4 v13, v17, v21, v25 + SET_0_4 v14, v18, v22, v26 + SET_0_4 v15, v19, v23, v27 LoopSz_TILE_8: - // src : 4 x [2 x 8] : v2-5 - // weight : 2 x [2 x 8] : v0-1 - // dst : 4 x 2 x [4] : v24-v31 - ld1 {v0.16b, v1.16b}, [x12], #32 // weight - ld1 {v2.16b, v3.16b, v4.16b, v5.16b}, [x11], x22 // src - .inst 0x4e80a458 // smmla v24.4s, v2.16b, v0.16b - .inst 0x4e81a459 // smmla v25.4s, v2.16b, v1.16b - .inst 0x4e80a47a // smmla v26.4s, v3.16b, v0.16b - .inst 0x4e81a47b // smmla v27.4s, v3.16b, v1.16b - .inst 0x4e80a49c // smmla v28.4s, v4.16b, v0.16b - .inst 0x4e81a49d // smmla v29.4s, v4.16b, v1.16b - .inst 0x4e80a4be // smmla v30.4s, v5.16b, v0.16b - .inst 0x4e81a4bf // smmla v31.4s, v5.16b, v1.16b + ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x12], #64 // weight + ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], x22 // src: E0-E7 subs x13, x13, #1 + .inst 0x4e88a46c // smmla v12.4s, v3.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a46d // smmla v13.4s, v3.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + .inst 0x4e8aa46e // smmla v14.4s, v3.16b, v10.16b // tile0-oc4, tile0-oc5, tile1-oc4, tile1-oc5 + .inst 0x4e8ba46f // smmla v15.4s, v3.16b, v11.16b // tile0-oc6, tile0-oc7, tile1-oc6, tile1-oc7 + + .inst 0x4e88a490 // smmla v16.4s, v4.16b, v8.16b // tile2-oc0, tile2-oc1, tile3-oc0, tile3-oc1 + .inst 0x4e89a491 // smmla v17.4s, v4.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 + .inst 0x4e8aa492 // smmla v18.4s, v4.16b, v10.16b // tile2-oc4, tile2-oc5, tile3-oc4, tile3-oc5 + .inst 0x4e8ba493 // smmla v19.4s, v4.16b, v11.16b // tile2-oc6, tile2-oc7, tile3-oc6, tile3-oc7 + + .inst 0x4e88a4b4 // smmla v20.4s, v5.16b, v8.16b // tile4-oc0, tile4-oc1, tile5-oc0, tile5-oc1 + .inst 0x4e89a4b5 // smmla v21.4s, v5.16b, v9.16b // tile4-oc2, tile4-oc3, tile5-oc2, tile5-oc3 + .inst 0x4e8aa4b6 // smmla v22.4s, v5.16b, v10.16b // tile4-oc4, tile4-oc5, tile5-oc4, tile5-oc5 + .inst 0x4e8ba4b7 // smmla v23.4s, v5.16b, v11.16b // tile4-oc6, tile4-oc7, tile5-oc6, tile5-oc7 + + .inst 0x4e88a4d8 // smmla v24.4s, v6.16b, v8.16b // tile6-oc0, tile6-oc1, tile7-oc0, tile7-oc1 + .inst 0x4e89a4d9 // smmla v25.4s, v6.16b, v9.16b // tile6-oc2, tile6-oc3, tile7-oc2, tile7-oc3 + .inst 0x4e8aa4da // smmla v26.4s, v6.16b, v10.16b // tile6-oc4, tile6-oc5, tile7-oc4, tile7-oc5 + .inst 0x4e8ba4db // smmla v27.4s, v6.16b, v11.16b // tile6-oc6, tile6-oc7, tile7-oc6, tile7-oc7 bne LoopSz_TILE_8 + LoopSzEnd_TILE_8: add x25, x25, x15 - sub x24, x24, #1 - uzp1 v23.2d, v24.2d, v25.2d - uzp2 v24.2d, v24.2d, v25.2d - uzp1 v25.2d, v26.2d, v27.2d - uzp2 v26.2d, v26.2d, v27.2d - uzp1 v27.2d, v28.2d, v29.2d - uzp2 v28.2d, v28.2d, v29.2d - uzp1 v29.2d, v30.2d, v31.2d - uzp2 v30.2d, v30.2d, v31.2d - Int32ToFloat v23, v24, v25, v26 - Int32ToFloat v27, v28, v29, v30 + sub x24, x24, #2 // dz-2 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v8.2d, v14.2d, v15.2d // E0: oc:4-7 + uzp2 v9.2d, v14.2d, v15.2d // E1: 
oc:4-7 + + uzp1 v2.2d, v16.2d, v17.2d // E2: oc:0-3 + uzp2 v3.2d, v16.2d, v17.2d // E3: oc:0-3 + uzp1 v10.2d, v18.2d, v19.2d // E2: oc:4-7 + uzp2 v11.2d, v18.2d, v19.2d // E3: oc:4-7 + + uzp1 v4.2d, v20.2d, v21.2d // E4: oc:0-3 + uzp2 v5.2d, v20.2d, v21.2d // E5: oc:0-3 + uzp1 v12.2d, v22.2d, v23.2d // E4: oc:4-7 + uzp2 v13.2d, v22.2d, v23.2d // E5: oc:4-7 + + uzp1 v6.2d, v24.2d, v25.2d // E6: oc:0-3 + uzp2 v7.2d, v24.2d, v25.2d // E7: oc:0-3 + uzp1 v14.2d, v26.2d, v27.2d // E6: oc:4-7 + uzp2 v15.2d, v26.2d, v27.2d // E7: oc:4-7 + Int32ToFloat v0, v1, v2, v3 + Int32ToFloat v4, v5, v6, v7 + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 Tile8Quan: - ld1 {v0.4s}, [x19], #16 // scale - MUL_SCALE v0, v23, v24, v25, v26 - MUL_SCALE v0, v27, v28, v29, v30 - cmp w23, #1 - beq Tile8QuanUseInt8 - sub x4, x4, #64 - st1 {v23.4s, v24.4s, v25.4s, v26.4s}, [x26], #64 - st1 {v27.4s, v28.4s, v29.4s, v30.4s}, [x26], x4 - add x4, x4, #64 + ld1 {v20.4s, v21.4s}, [x19], #32 // scale + ld1 {v22.4s, v23.4s}, [x27] // x kernel sum + ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint + MUL_SCALE v20, v0, v1, v2, v3 + MUL_SCALE v20, v4, v5, v6, v7 + MUL_SCALE v21, v8, v9, v10, v11 + MUL_SCALE v21, v12, v13, v14, v15 + + cbz x10, TILE8_MLA + ld1 {v27.4s, v28.4s}, [x10] + MUL_EXTRA_SCALE v27, v0, v1, v2, v3 + MUL_EXTRA_SCALE v28, v4, v5, v6, v7 + MUL_EXTRA_SCALE v27, v8, v9, v10, v11 + MUL_EXTRA_SCALE v28, v12, v13, v14, v15 + + TILE8_MLA: + MLA_WEIGHTZERO v0, v22, v25, 0 + MLA_WEIGHTZERO v1, v22, v25, 1 + MLA_WEIGHTZERO v2, v22, v25, 2 + MLA_WEIGHTZERO v3, v22, v25, 3 + MLA_WEIGHTZERO v4, v23, v25, 0 + MLA_WEIGHTZERO v5, v23, v25, 1 + MLA_WEIGHTZERO v6, v23, v25, 2 + MLA_WEIGHTZERO v7, v23, v25, 3 + + MLA_WEIGHTZERO v8, v22, v26, 0 + MLA_WEIGHTZERO v9, v22, v26, 1 + MLA_WEIGHTZERO v10, v22, v26, 2 + MLA_WEIGHTZERO v11, v22, v26, 3 + MLA_WEIGHTZERO v12, v23, v26, 0 + MLA_WEIGHTZERO v13, v23, v26, 1 + MLA_WEIGHTZERO v14, v23, v26, 2 + MLA_WEIGHTZERO v15, v23, v26, 3 + + cbnz w23, Tile8QuanUseInt8 + + cbz x9, TILE8_ADD_DSTV + TILE8_ADD_BIAS: + ld1 {v16.4s, v17.4s}, [x20], #32 + ADD_BIAS_FLOAT v0, v1, v2, v3, v16 + ADD_BIAS_FLOAT v4, v5, v6, v7, v16 + ADD_BIAS_FLOAT v8, v9, v10, v11, v17 + ADD_BIAS_FLOAT v12, v13, v14, v15, v17 + b TILE8_POST + + TILE8_ADD_DSTV: + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x26], #64 + ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x26], x4 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x26], #64 + ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23 + ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27 + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x26] + ADD_FLOAT v8, v9, v10, v11, v16, v17, v18, v19 + ADD_FLOAT v12, v13, v14, v15, v20, v21, v22, v23 + sub x26, x26, x4 + sub x26, x26, #128 + + TILE8_POST: + cbz x14, TILE8_STORE + ReLU_FP32 v0, v1, v2, v3, v30, v31 + ReLU_FP32 v4, v5, v6, v7, v30, v31 + ReLU_FP32 v8, v9, v10, v11, v30, v31 + ReLU_FP32 v12, v13, v14, v15, v30, v31 + + TILE8_STORE: + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x26], #64 + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x26], x4 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x26], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x26], x4 b Tile8LoopCheck Tile8QuanUseInt8: - FloatToInt32 v23, v24, v25, v26 - FloatToInt32 v27, v28, v29, v30 - Int32ToInt16 v23, v24, v25, v26, v4, v5 - Int32ToInt16 v27, v28, v29, v30, v6, v7 - Int16ToInt8 v4, v5, v6, v7, v18, v19 - smax v18.16b, v10.16b, v18.16b - smax v19.16b, v10.16b, v19.16b - smin v18.16b, v11.16b, v18.16b - smin v19.16b, v11.16b, v19.16b + ld1 {v16.4s, v17.4s}, [x20], #32 + ADD_BIAS_FLOAT v0, v1, v2, 
v3, v16 + ADD_BIAS_FLOAT v4, v5, v6, v7, v16 + ADD_BIAS_FLOAT v8, v9, v10, v11, v17 + ADD_BIAS_FLOAT v12, v13, v14, v15, v17 + + FloatToInt32 v0, v1, v2, v3 + FloatToInt32 v4, v5, v6, v7 + FloatToInt32 v8, v9, v10, v11 + FloatToInt32 v12, v13, v14, v15 + + Int32ToInt16 v0, v1, v2, v3, v20, v21 + Int32ToInt16 v4, v5, v6, v7, v22, v23 + Int32ToInt16 v8, v9, v10, v11, v24, v25 + Int32ToInt16 v12, v13, v14, v15, v26, v27 + + Int16ToInt8 v20, v21, v22, v23, v28, v29 + Int16ToInt8 v24, v25, v26, v27, v18, v19 + smax v28.16b, v30.16b, v28.16b + smax v29.16b, v30.16b, v29.16b + smax v18.16b, v30.16b, v18.16b + smax v19.16b, v30.16b, v19.16b + smin v28.16b, v31.16b, v28.16b + smin v29.16b, v31.16b, v29.16b + smin v18.16b, v31.16b, v18.16b + smin v19.16b, v31.16b, v19.16b + st1 {v28.16b, v29.16b}, [x26], x4 st1 {v18.16b, v19.16b}, [x26], x4 // dst += dz * dst_step Tile8LoopCheck: - cmp x24, #1 + cmp x24, #2 bge LoopDz_TILE_8 + cbz x24, Tile8Check + +LoopDz4_TILE_8: + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + SET_0_4 v12, v13, v16, v17 + SET_0_4 v20, v21, v24, v25 +LoopSz4_TILE_8: + ld1 {v8.16b, v9.16b}, [x12] // weight + ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], x22 // src: E0-E7 + subs x13, x13, #1 + add x12, x12, #64 + .inst 0x4e88a46c // smmla v12.4s, v3.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a46d // smmla v13.4s, v3.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + + .inst 0x4e88a490 // smmla v16.4s, v4.16b, v8.16b // tile2-oc0, tile2-oc1, tile3-oc0, tile3-oc1 + .inst 0x4e89a491 // smmla v17.4s, v4.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 + + .inst 0x4e88a4b4 // smmla v20.4s, v5.16b, v8.16b // tile4-oc0, tile4-oc1, tile5-oc0, tile5-oc1 + .inst 0x4e89a4b5 // smmla v21.4s, v5.16b, v9.16b // tile4-oc2, tile4-oc3, tile5-oc2, tile5-oc3 + + .inst 0x4e88a4d8 // smmla v24.4s, v6.16b, v8.16b // tile6-oc0, tile6-oc1, tile7-oc0, tile7-oc1 + .inst 0x4e89a4d9 // smmla v25.4s, v6.16b, v9.16b // tile6-oc2, tile6-oc3, tile7-oc2, tile7-oc3 + bne LoopSz4_TILE_8 + +LoopSz4End_TILE_8: + add x25, x25, x15 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v2.2d, v16.2d, v17.2d // E2: oc:0-3 + uzp2 v3.2d, v16.2d, v17.2d // E3: oc:0-3 + uzp1 v4.2d, v20.2d, v21.2d // E4: oc:0-3 + uzp2 v5.2d, v20.2d, v21.2d // E5: oc:0-3 + uzp1 v6.2d, v24.2d, v25.2d // E6: oc:0-3 + uzp2 v7.2d, v24.2d, v25.2d // E7: oc:0-3 + Int32ToFloat v0, v1, v2, v3 + Int32ToFloat v4, v5, v6, v7 + +Tile8Quan_L4: + ld1 {v20.4s}, [x19] // scale + ld1 {v22.4s, v23.4s}, [x27] // x kernel sum + ld1 {v25.4s}, [x6] // weight quan zeropoint + MUL_SCALE v20, v0, v1, v2, v3 + MUL_SCALE v20, v4, v5, v6, v7 + + cbz x10, TILE8_MLA_L4 + ld1 {v27.4s, v28.4s}, [x10] + MUL_EXTRA_SCALE v27, v0, v1, v2, v3 + MUL_EXTRA_SCALE v28, v4, v5, v6, v7 + + TILE8_MLA_L4: + MLA_WEIGHTZERO v0, v22, v25, 0 + MLA_WEIGHTZERO v1, v22, v25, 1 + MLA_WEIGHTZERO v2, v22, v25, 2 + MLA_WEIGHTZERO v3, v22, v25, 3 + MLA_WEIGHTZERO v4, v23, v25, 0 + MLA_WEIGHTZERO v5, v23, v25, 1 + MLA_WEIGHTZERO v6, v23, v25, 2 + MLA_WEIGHTZERO v7, v23, v25, 3 + + cbnz w23, Tile8QuanUseInt8_L4 + + cbz x9, TILE8_ADD_DSTV_L4 + TILE8_ADD_BIAS_L4: + ld1 {v16.4s}, [x20] + ADD_BIAS_FLOAT v0, v1, v2, v3, v16 + ADD_BIAS_FLOAT v4, v5, v6, v7, v16 + b TILE8_POST_L4 + + TILE8_ADD_DSTV_L4: + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x26], #64 + ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x26] + ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23 + ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27 + sub 
x26, x26, #64 + + TILE8_POST_L4: + cbz x14, TILE8_STORE_L4 + ReLU_FP32 v0, v1, v2, v3, v30, v31 + ReLU_FP32 v4, v5, v6, v7, v30, v31 + + TILE8_STORE_L4: + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x26], #64 + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x26], x4 + b Tile8Check + + Tile8QuanUseInt8_L4: + ld1 {v16.4s}, [x20] + ADD_BIAS_FLOAT v0, v1, v2, v3, v16 + ADD_BIAS_FLOAT v4, v5, v6, v7, v16 + + FloatToInt32 v0, v1, v2, v3 + FloatToInt32 v4, v5, v6, v7 + + Int32ToInt16 v0, v1, v2, v3, v20, v21 + Int32ToInt16 v4, v5, v6, v7, v22, v23 + + Int16ToInt8 v20, v21, v22, v23, v16, v17 + smax v16.16b, v30.16b, v16.16b + smax v17.16b, v30.16b, v17.16b + smin v16.16b, v31.16b, v16.16b + smin v17.16b, v31.16b, v17.16b + st1 {v16.16b, v17.16b}, [x26], x4 + +Tile8Check: +cbz x10, Tile8End +add x10, x10, #32 + Tile8End: sub x7, x7, #8 add x0, x0, x21, LSL #3 add x1, x1, #64 + add x27, x27, #32 + add x4, x4, #64 // Revert x4 for following tile. TILE_4: cmp x7, #4 @@ -425,59 +880,193 @@ TILE_4: mov x25, x2 // weight mov x19, x8 // scale mov x20, x9 // bias + mov x6, x28 // weightQuanBias +cmp x5, #2 +blt LoopDz4_TILE_4 LoopDz_TILE_4: - ld1 {v0.4s}, [x20], #16 // bias mov x11, x1 // src mov x12, x25 // weight mov x13, x3 // src_depth_quad - mov v1.16b, v0.16b - uzp1 v28.2d, v0.2d, v1.2d // bias_0, bias_1, bias_0, bias_1 - uzp2 v29.2d, v0.2d, v1.2d // bias_2, bias_3, bias_2, bias_3 - mov v30.16b, v28.16b - mov v31.16b, v29.16b + SET_0_4 v12, v13, v14, v15 + SET_0_4 v16, v17, v18, v19 + LoopSz_TILE_4: - // src : 2 x [2 x 8] : v2-3 - // weight : 2 x [2 x 8] : v0-1 - // dst : 2 x 2 x [4] : v28-v31 - ld1 {v0.16b, v1.16b}, [x12], #32 // weight - ld1 {v2.16b, v3.16b}, [x11], x22 // src - .inst 0x4e80a45c // smmla v28.4s, v2.16b, v0.16b - .inst 0x4e81a45d // smmla v29.4s, v2.16b, v1.16b - .inst 0x4e80a47e // smmla v30.4s, v3.16b, v0.16b - .inst 0x4e81a47f // smmla v31.4s, v3.16b, v1.16b + ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x12], #64 // weight + ld1 {v4.16b, v5.16b}, [x11], x22 // src subs x13, x13, #1 + .inst 0x4e88a48c // smmla v12.4s, v4.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a48d // smmla v13.4s, v4.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + .inst 0x4e8aa48e // smmla v14.4s, v4.16b, v10.16b // tile0-oc4, tile0-oc5, tile1-oc4, tile1-oc5 + .inst 0x4e8ba48f // smmla v15.4s, v4.16b, v11.16b // tile0-oc6, tile0-oc7, tile1-oc6, tile1-oc7 + + .inst 0x4e88a4b0 // smmla v16.4s, v5.16b, v8.16b // tile2-oc0, tile2-oc1, tile3-oc0, tile3-oc1 + .inst 0x4e89a4b1 // smmla v17.4s, v5.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 + .inst 0x4e8aa4b2 // smmla v18.4s, v5.16b, v10.16b // tile2-oc4, tile2-oc5, tile3-oc4, tile3-oc5 + .inst 0x4e8ba4b3 // smmla v19.4s, v5.16b, v11.16b // tile2-oc6, tile2-oc7, tile3-oc6, tile3-oc7 bne LoopSz_TILE_4 LoopSzEnd_TILE_4: add x25, x25, x15 - sub x24, x24, #1 - uzp1 v27.2d, v28.2d, v29.2d - uzp2 v28.2d, v28.2d, v29.2d - uzp1 v29.2d, v30.2d, v31.2d - uzp2 v30.2d, v30.2d, v31.2d - Int32ToFloat v27, v28, v29, v30 + sub x24, x24, #2 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v4.2d, v14.2d, v15.2d // E0: oc:4-7 + uzp2 v5.2d, v14.2d, v15.2d // E1: oc:4-7 + + uzp1 v2.2d, v16.2d, v17.2d + uzp2 v3.2d, v16.2d, v17.2d + uzp1 v6.2d, v18.2d, v19.2d + uzp2 v7.2d, v18.2d, v19.2d + Int32ToFloat v0, v1, v2, v3 + Int32ToFloat v4, v5, v6, v7 Tile4Quan: - ld1 {v0.4s}, [x19], #16 // scale - MUL_SCALE v0, v27, v28, v29, v30 - cmp w23, #1 - beq Tile4QuanUseInt8 - st1 {v27.4s, v28.4s, v29.4s, v30.4s}, [x26], x4 + 
ld1 {v20.4s, v21.4s}, [x19], #32 // scale + ld1 {v22.4s}, [x27] // x kernel sum + ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint + MUL_SCALE v20, v0, v1, v2, v3 + MUL_SCALE v21, v4, v5, v6, v7 + + cbz x10, TILE4_MLA + ld1 {v27.4s}, [x10] + MUL_EXTRA_SCALE v27, v0, v1, v2, v3 + MUL_EXTRA_SCALE v27, v4, v5, v6, v7 + + TILE4_MLA: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v2, v22, v25, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v3, v22, v25, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v4, v22, v26, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v5, v22, v26, 1 // tile:1, oc:4-7 + MLA_WEIGHTZERO v6, v22, v26, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v7, v22, v26, 3 // tile:3, oc:4-7 + + cbnz w23, Tile4QuanUseInt8 + + TILE4_ADD_BIAS: + cbz x9, TILE4_ADD_DSTV + ld1 {v16.4s, v17.4s}, [x20], #32 // bias + ADD_BIAS_FLOAT v0, v1, v2, v3, v16 + ADD_BIAS_FLOAT v4, v5, v6, v7, v17 + b TILE4_POST + + TILE4_ADD_DSTV: + ld1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x26], x4 + ld1 {v19.4s, v20.4s, v21.4s, v22.4s}, [x26] + ADD_FLOAT v0, v1, v2, v3, v15, v16, v17, v18 + ADD_FLOAT v4, v5, v6, v7, v19, v20, v21, v22 + sub x26, x26, x4 + + TILE4_POST: + cbz x14, TILE4_STORE + ReLU_FP32 v0, v1, v2, v3, v30, v31 + ReLU_FP32 v4, v5, v6, v7, v30, v31 + + TILE4_STORE: + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x26], x4 + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x26], x4 b Tile4LoopCheck Tile4QuanUseInt8: - FloatToInt32 v27, v28, v29, v30 - Int32ToInt16 v27, v28, v29, v30, v6, v7 - Int16ToInt8_ONE v6, v7, v19 - smax v19.16b, v10.16b, v19.16b - smin v19.16b, v11.16b, v19.16b + ld1 {v16.4s, v17.4s}, [x20], #32 // bias + ADD_BIAS_FLOAT v0, v1, v2, v3, v16 + ADD_BIAS_FLOAT v4, v5, v6, v7, v17 + FloatToInt32 v0, v1, v2, v3 + FloatToInt32 v4, v5, v6, v7 + Int32ToInt16 v0, v1, v2, v3, v8, v9 + Int32ToInt16 v4, v5, v6, v7, v10, v11 + Int16ToInt8_ONE v8, v9, v19 + Int16ToInt8_ONE v10, v11, v20 + smax v19.16b, v30.16b, v19.16b + smin v19.16b, v31.16b, v19.16b + smax v20.16b, v30.16b, v20.16b + smin v20.16b, v31.16b, v20.16b st1 {v19.16b}, [x26], x4 // dst += dz * dst_step + st1 {v20.16b}, [x26], x4 Tile4LoopCheck: - cmp x24, #1 + cmp x24, #2 bge LoopDz_TILE_4 + cbz x24, Tile4Check + +LoopDz4_TILE_4: + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + SET_0_4 v12, v13, v16, v17 +LoopSz4_TILE_4: + ld1 {v8.16b, v9.16b}, [x12] // weight + ld1 {v4.16b, v5.16b}, [x11], x22 // src + subs x13, x13, #1 + add x12, x12, #64 + .inst 0x4e88a48c // smmla v12.4s, v4.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a48d // smmla v13.4s, v4.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + + .inst 0x4e88a4b0 // smmla v16.4s, v5.16b, v8.16b // tile2-oc0, tile2-oc1, tile3-oc0, tile3-oc1 + .inst 0x4e89a4b1 // smmla v17.4s, v5.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 + bne LoopSz4_TILE_4 +LoopSz4End_TILE_4: + add x25, x25, x15 + sub x24, x24, #1 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v2.2d, v16.2d, v17.2d + uzp2 v3.2d, v16.2d, v17.2d + Int32ToFloat v0, v1, v2, v3 + +Tile4Quan_L4: + ld1 {v20.4s}, [x19] // scale + ld1 {v22.4s}, [x27] // x kernel sum + ld1 {v25.4s}, [x6] // weight quan zeropoint + MUL_SCALE v20, v0, v1, v2, v3 + + cbz x10, TILE4_MLA_L4 + ld1 {v27.4s}, [x10] + MUL_EXTRA_SCALE v27, v0, v1, v2, v3 + + TILE4_MLA_L4: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v2, v22, v25, 2 // tile:2, oc:0-3 + 
MLA_WEIGHTZERO v3, v22, v25, 3 // tile:3, oc:0-3 + + cbnz w23, Tile4QuanUseInt8_L4 + + TILE4_ADD_BIAS_L4: + cbz x9, TILE4_ADD_DSTV_L4 + ld1 {v16.4s}, [x20] // bias + ADD_BIAS_FLOAT v0, v1, v2, v3, v16 + b TILE4_POST_L4 + + TILE4_ADD_DSTV_L4: + ld1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x26] + ADD_FLOAT v0, v1, v2, v3, v15, v16, v17, v18 + + TILE4_POST_L4: + cbz x14, TILE4_STORE_L4 + ReLU_FP32 v0, v1, v2, v3, v30, v31 + + TILE4_STORE_L4: + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x26], x4 + b Tile4Check + + Tile4QuanUseInt8_L4: + ld1 {v16.4s}, [x20] // bias + ADD_BIAS_FLOAT v0, v1, v2, v3, v16 + FloatToInt32 v0, v1, v2, v3 + Int32ToInt16 v0, v1, v2, v3, v8, v9 + Int16ToInt8_ONE v8, v9, v19 + smax v19.16b, v30.16b, v19.16b + smin v19.16b, v31.16b, v19.16b + st1 {v19.16b}, [x26], x4 // dst += dz * dst_step + +Tile4Check: +cbz x10, Tile4End +add x10, x10, #16 Tile4End: sub x7, x7, #4 add x0, x0, x21, LSL #2 add x1, x1, #32 + add x27, x27, #16 TILE_2: cmp x7, #2 @@ -487,57 +1076,189 @@ TILE_2: mov x25, x2 // weight mov x19, x8 // scale mov x20, x9 // bias + mov x6, x28 // weightQuanBias +cmp x5, #2 +blt LoopDz4_TILE_2 LoopDz_TILE_2: - ld1 {v0.4s}, [x20], #16 // bias mov x11, x1 // src mov x12, x25 // weight mov x13, x3 // src_depth_quad - mov v1.16b, v0.16b - uzp1 v30.2d, v0.2d, v1.2d // bias_0, bias_1, bias_0, bias_1 - uzp2 v31.2d, v0.2d, v1.2d // bias_2, bias_3, bias_2, bias_3 + SET_0_4 v12, v13, v14, v15 LoopSz_TILE_2: - // src : 1 x [2 x 8] : v2 - // weight : 2 x [2 x 8] : v0-1 - // dst : 1 x 2 x [4] : v30-v31 - ld1 {v0.16b, v1.16b}, [x12], #32 // weight - ld1 {v2.16b}, [x11], x22 // src - .inst 0x4e80a45e // smmla v30.4s, v2.16b, v0.16b - .inst 0x4e81a45f // smmla v31.4s, v2.16b, v1.16b + ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x12], #64 + ld1 {v4.16b}, [x11], x22 // src + .inst 0x4e88a48c // smmla v12.4s, v4.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a48d // smmla v13.4s, v4.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + .inst 0x4e8aa48e // smmla v14.4s, v4.16b, v10.16b // tile0-oc4, tile0-oc5, tile1-oc4, tile1-oc5 + .inst 0x4e8ba48f // smmla v15.4s, v4.16b, v11.16b // tile0-oc6, tile0-oc7, tile1-oc6, tile1-oc7 subs x13, x13, #1 bne LoopSz_TILE_2 LoopSzEnd_TILE_2: add x25, x25, x15 - sub x24, x24, #1 - uzp1 v29.2d, v30.2d, v31.2d - uzp2 v30.2d, v30.2d, v31.2d - scvtf v29.4s, v29.4s - scvtf v30.4s, v30.4s + sub x24, x24, #2 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v2.2d, v14.2d, v15.2d // E0: oc:4-7 + uzp2 v3.2d, v14.2d, v15.2d // E1: oc:4-7 + Int32ToFloat v0, v1, v2, v3 Tile2Quan: - ld1 {v0.4s}, [x19], #16 // scale - fmul v29.4s, v29.4s, v0.4s - fmul v30.4s, v30.4s, v0.4s - cmp w23, #1 - beq Tile2QuanUseInt8 - st1 {v29.4s, v30.4s}, [x26], x4 + ld1 {v20.4s, v21.4s}, [x19], #32 // scale + ld1 {v22.d}[0], [x27] // x kernel sum + ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint + fmul v0.4s, v0.4s, v20.4s + fmul v1.4s, v1.4s, v20.4s + fmul v2.4s, v2.4s, v21.4s + fmul v3.4s, v3.4s, v21.4s + + cbz x10, TILE2_MLA + ld1 {v27.d}[0], [x10] + fmul v0.4s, v0.4s, v27.s[0] + fmul v1.4s, v1.4s, v27.s[1] + fmul v2.4s, v2.4s, v27.s[0] + fmul v3.4s, v3.4s, v27.s[1] + + TILE2_MLA: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v2, v22, v26, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v3, v22, v26, 1 // tile:1, oc:4-7 + + cbnz w23, Tile2QuanUseInt8 + + TILE2_ADD_BIAS: + cbz x9, TILE2_ADD_DSTV + ld1 {v16.4s, v17.4s}, [x20], #32 // bias + fadd v0.4s, 
v0.4s, v16.4s + fadd v1.4s, v1.4s, v16.4s + fadd v2.4s, v2.4s, v17.4s + fadd v3.4s, v3.4s, v17.4s + b TILE2_POST + + TILE2_ADD_DSTV: + ld1 {v18.4s, v19.4s}, [x26], x4 + ld1 {v20.4s, v21.4s}, [x26] + fadd v0.4s, v0.4s, v18.4s + fadd v1.4s, v1.4s, v19.4s + fadd v2.4s, v2.4s, v20.4s + fadd v3.4s, v3.4s, v21.4s + sub x26, x26, x4 + + TILE2_POST: + cbz x14, TILE2_STORE + ReLU_FP32 v0, v1, v2, v3, v30, v31 + TILE2_STORE: + st1 {v0.4s, v1.4s}, [x26], x4 + st1 {v2.4s, v3.4s}, [x26], x4 b Tile2LoopCheck + Tile2QuanUseInt8: - fcvtas v29.4s, v29.4s - fcvtas v30.4s, v30.4s - sqxtn v6.4h, v29.4s - sqxtn2 v6.8h, v30.4s + ld1 {v16.4s, v17.4s}, [x20], #32 // bias + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v16.4s + fadd v2.4s, v2.4s, v17.4s + fadd v3.4s, v3.4s, v17.4s + fcvtas v0.4s, v0.4s + fcvtas v1.4s, v1.4s + fcvtas v2.4s, v2.4s + fcvtas v3.4s, v3.4s + sqxtn v6.4h, v0.4s + sqxtn2 v6.8h, v1.4s + sqxtn v7.4h, v2.4s + sqxtn2 v7.8h, v3.4s sqxtn v19.8b, v6.8h - smax v19.16b, v10.16b, v19.16b - smin v19.16b, v11.16b, v19.16b + sqxtn v20.8b, v7.8h + smax v19.8b, v30.8b, v19.8b + smin v19.8b, v31.8b, v19.8b + smax v20.8b, v30.8b, v20.8b + smin v20.8b, v31.8b, v20.8b st1 {v19.8b}, [x26], x4 // dst += dz * dst_step + st1 {v20.8b}, [x26], x4 Tile2LoopCheck: - cmp x24, #1 + cmp x24, #2 bge LoopDz_TILE_2 + cbz x24, Tile2Check +LoopDz4_TILE_2: + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + movi v12.4s, #0 + movi v13.4s, #0 +LoopSz4_TILE_2: + ld1 {v8.16b, v9.16b}, [x12] + ld1 {v4.16b}, [x11], x22 // src + + .inst 0x4e88a48c // smmla v12.4s, v4.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a48d // smmla v13.4s, v4.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + subs x13, x13, #1 + add x12, x12, #64 + bne LoopSz4_TILE_2 +LoopSz4End_TILE_2: + add x25, x25, x15 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + scvtf v0.4s, v0.4s + scvtf v1.4s, v1.4s + +Tile2Quan_L4: + ld1 {v20.4s}, [x19] + ld1 {v22.d}[0], [x27] // x kernel sum + ld1 {v25.4s}, [x6] // weight quan zeropoint + fmul v0.4s, v0.4s, v20.4s + fmul v1.4s, v1.4s, v20.4s + + cbz x10, TILE2_MLA_L4 + ld1 {v27.d}[0], [x10] + fmul v0.4s, v0.4s, v27.s[0] + fmul v1.4s, v1.4s, v27.s[1] + + TILE2_MLA_L4: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + + cbnz w23, Tile2QuanUseInt8_L4 + + TILE2_ADD_BIAS_L4: + cbz x9, TILE2_ADD_DSTV_L4 + ld1 {v16.4s}, [x20] // bias + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v16.4s + b TILE2_POST_L4 + + TILE2_ADD_DSTV_L4: + ld1 {v18.4s, v19.4s}, [x26] + fadd v0.4s, v0.4s, v18.4s + fadd v1.4s, v1.4s, v19.4s + + TILE2_POST_L4: + cbz x14, TILE2_STORE_L4 + ReLU_FP32_2 v0, v1, v30, v31 + TILE2_STORE_L4: + st1 {v0.4s, v1.4s}, [x26], x4 + b Tile2Check + + Tile2QuanUseInt8_L4: + ld1 {v16.4s}, [x20] // bias + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v16.4s + fcvtas v0.4s, v0.4s + fcvtas v1.4s, v1.4s + sqxtn v6.4h, v0.4s + sqxtn2 v6.8h, v1.4s + sqxtn v19.8b, v6.8h + smax v19.8b, v30.8b, v19.8b + smin v19.8b, v31.8b, v19.8b + st1 {v19.8b}, [x26], x4 // dst += dz * dst_step + +Tile2Check: +cbz x10, Tile2End +add x10, x10, #8 Tile2End: sub x7, x7, #2 add x0, x0, x21, LSL #1 add x1, x1, #16 + add x27, x27, #8 TILE_1: cmp x7, #1 @@ -547,55 +1268,168 @@ TILE_1: mov x25, x2 // weight mov x19, x8 // scale mov x20, x9 // bias + mov x6, x28 // weightQuanBias +cmp x5, #2 +blt LoopDz4_TILE_1 LoopDz_TILE_1: - ld1 {v0.4s}, [x20], #16 // bias - mov x11, x1 // src - mov x12, x25 // weight 
- mov x13, x3 // src_depth_quad - mov v1.16b, v0.16b - uzp1 v30.2d, v0.2d, v1.2d // bias_0, bias_1, bias_0, bias_1 - uzp2 v31.2d, v0.2d, v1.2d // bias_2, bias_3, bias_2, bias_3 + //ld1 {v0.4s}, [x20], #16 // bias + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + + movi v16.4s, #0 + movi v17.4s, #0 + movi v18.4s, #0 + movi v19.4s, #0 LoopSz_TILE_1: - // src : 1 x [1 x 8] : v2 - // weight : 2 x [2 x 8] : v0-1 - // dst : 1 x 2 x [2] : v30-v31 - ld1 {v0.16b, v1.16b}, [x12], #32 // weight + ld1 {v8.16b, v9.16b, v10.16b, v11.16b}, [x12], #64 // weight ld1 {v2.8b}, [x11], x22 // src - .inst 0x4e80a45e // smmla v30.4s, v2.16b, v0.16b - .inst 0x4e81a45f // smmla v31.4s, v2.16b, v1.16b subs x13, x13, #1 + + .inst 0x4e88a450 // smmla v16.4s, v2.16b, v8.16b + .inst 0x4e89a451 // smmla v17.4s, v2.16b, v9.16b + .inst 0x4e8aa452 // smmla v18.4s, v2.16b, v10.16b + .inst 0x4e8ba453 // smmla v19.4s, v2.16b, v11.16b bne LoopSz_TILE_1 LoopSzEnd_TILE_1: add x25, x25, x15 - sub x24, x24, #1 - uzp1 v29.2d, v30.2d, v31.2d - uzp2 v30.2d, v30.2d, v31.2d - scvtf v29.4s, v29.4s - scvtf v30.4s, v30.4s + sub x24, x24, #2 + uzp1 v27.2d, v16.2d, v17.2d + uzp1 v26.2d, v18.2d, v19.2d + scvtf v27.4s, v27.4s + scvtf v26.4s, v26.4s Tile1Quan: - ld1 {v0.4s}, [x19], #16 // scale - fmul v29.4s, v29.4s, v0.4s - fmul v30.4s, v30.4s, v0.4s - cmp w23, #1 - beq Tile1QuanUseInt8 - st1 {v29.4s, v30.4s}, [x26], x4 + ld1 {v0.4s, v1.4s}, [x19], #32 // scale + ld1 {v6.s}[0], [x27] // x kernel sum + ld1 {v8.4s, v9.4s}, [x6], #32 // weight quan zeropoint + fmul v27.4s, v27.4s, v0.4s + fmul v26.4s, v26.4s, v1.4s + + cbz x10, TILE1_MLA + ld1 {v10.s}[0], [x10] + fmul v27.4s, v27.4s, v10.s[0] + fmul v26.4s, v26.4s, v10.s[0] + + TILE1_MLA: + MLA_WEIGHTZERO v27, v6, v8, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v26, v6, v9, 0 // tile:0, oc:4-7 + + cbnz w23, Tile1QuanUseInt8 + + TILE1_ADD_BIAS: + cbz x9, TILE1_ADD_DSTV + ld1 {v16.4s, v17.4s}, [x20], #32 // bias + fadd v27.4s, v27.4s, v16.4s + fadd v26.4s, v26.4s, v17.4s + b TILE1_POST + + TILE1_ADD_DSTV: + ld1 {v16.4s}, [x26], x4 + ld1 {v17.4s}, [x26] + fadd v27.4s, v27.4s, v16.4s + fadd v26.4s, v26.4s, v17.4s + sub x26, x26, x4 + + TILE1_POST: + cbz x14, TILE1_STORE + fmin v27.4s, v27.4s, v31.4s + fmax v27.4s, v27.4s, v30.4s + fmin v26.4s, v26.4s, v31.4s + fmax v26.4s, v26.4s, v30.4s + + TILE1_STORE: + st1 {v27.4s}, [x26], x4 + st1 {v26.4s}, [x26], x4 b Tile1LoopEnd + Tile1QuanUseInt8: - fcvtas v29.4s, v29.4s - fcvtas v30.4s, v30.4s - sqxtn v6.4h, v29.4s - sqxtn2 v6.8h, v30.4s - sqxtn v19.8b, v6.8h - smax v19.16b, v10.16b, v19.16b - smin v19.16b, v11.16b, v19.16b - st1 {v19.s}[0], [x26], x4 // dst += dz * dst_step + ld1 {v16.4s, v17.4s}, [x20], #32 // bias + fadd v27.4s, v27.4s, v16.4s + fadd v26.4s, v26.4s, v17.4s + fcvtas v27.4s, v27.4s + fcvtas v26.4s, v26.4s + sqxtn v6.4h, v27.4s + sqxtn v7.4h, v26.4s + sqxtn v6.8b, v6.8h + sqxtn v7.8b, v7.8h + smax v6.16b, v30.16b, v6.16b + smin v6.16b, v31.16b, v6.16b + smax v7.16b, v30.16b, v7.16b + smin v7.16b, v31.16b, v7.16b + st1 {v6.s}[0], [x26], x4 // dst += dz * dst_step + st1 {v7.s}[0], [x26], x4 Tile1LoopEnd: - cmp x24, #1 + cmp x24, #2 bge LoopDz_TILE_1 + cbz x24, End + +LoopDz4_TILE_1: + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + + movi v16.4s, #0 + movi v17.4s, #0 +LoopSz4_TILE_1: + ld1 {v8.16b, v9.16b}, [x12] // weight + ld1 {v2.8b}, [x11], x22 // src + subs x13, x13, #1 + add x12, x12, #64 + .inst 0x4e88a450 // smmla v16.4s, v2.16b, v8.16b + .inst 0x4e89a451 // smmla v17.4s, v2.16b, 
v9.16b + bne LoopSz4_TILE_1 +LoopSz4End_TILE_1: + add x25, x25, x15 + uzp1 v27.2d, v16.2d, v17.2d + scvtf v27.4s, v27.4s + +Tile1Quan_L4: + ld1 {v0.4s}, [x19] // scale + ld1 {v6.s}[0], [x27] // x kernel sum + ld1 {v8.4s}, [x6] // weight quan zeropoint + fmul v27.4s, v27.4s, v0.4s + cbz x10, TILE1_MLA_L4 + ld1 {v10.s}[0], [x10] + fmul v27.4s, v27.4s, v10.s[0] + + TILE1_MLA_L4: + MLA_WEIGHTZERO v27, v6, v8, 0 // tile:0, oc:0-3 + + cbnz w23, Tile1QuanUseInt8_L4 + + TILE1_ADD_BIAS_L4: + cbz x9, TILE1_ADD_DSTV_L4 + ld1 {v16.4s}, [x20] // bias + fadd v27.4s, v27.4s, v16.4s + b TILE1_POST_L4 + + TILE1_ADD_DSTV_L4: + ld1 {v16.4s}, [x26] + fadd v27.4s, v27.4s, v16.4s + + TILE1_POST_L4: + cbz x14, TILE1_STORE_L4 + fmin v27.4s, v27.4s, v31.4s + fmax v27.4s, v27.4s, v30.4s + + TILE1_STORE_L4: + st1 {v27.4s}, [x26], x4 + b End + + Tile1QuanUseInt8_L4: + ld1 {v16.4s}, [x20] // bias + fadd v27.4s, v27.4s, v16.4s + fcvtas v27.4s, v27.4s + sqxtn v6.4h, v27.4s + sqxtn v6.8b, v6.8h + smax v6.8b, v30.8b, v6.8b + smin v6.8b, v31.8b, v6.8b + st1 {v6.s}[0], [x26], x4 // dst += dz * dst_step End: +ldp x27, x28, [sp, #(16 * 8)] ldp x25, x26, [sp, #(16 * 7)] ldp x23, x24, [sp, #(16 * 6)] ldp x19, x20, [sp, #(16 * 5)] @@ -603,7 +1437,7 @@ ldp x21, x22, [sp, #(16 * 4)] ldp d8, d9, [sp, #(16 * 3)] ldp d10, d11, [sp, #(16 * 2)] ldp d12, d13, [sp, #(16 * 1)] -ldp d14, d15, [sp], #(16 * 8) +ldp d14, d15, [sp], #(16 * 10) ret #endif // __aarch64__ diff --git a/source/backend/cpu/arm/arm64/MNNLineDepthWiseInt8AddBiasScaleUnit.S b/source/backend/cpu/arm/arm64/MNNLineDepthWiseInt8AddBiasScaleUnit.S index 7b7f141bd..44c7766c5 100644 --- a/source/backend/cpu/arm/arm64/MNNLineDepthWiseInt8AddBiasScaleUnit.S +++ b/source/backend/cpu/arm/arm64/MNNLineDepthWiseInt8AddBiasScaleUnit.S @@ -38,7 +38,7 @@ asm_function MNNLineDepthWiseInt8AddBiasScaleUnit ldr x8, [sp, #0] ldr x9, [sp, #8] -str d14, [sp, #(-16 * 9)]! +str d14, [sp, #(-16 * 10)]! stp d12, d13, [sp, #(16 * 1)] stp d10, d11, [sp, #(16 * 2)] stp d8, d9, [sp, #(16 * 3)] @@ -53,7 +53,7 @@ ldr w11, [x3, #16] dup v30.16b, w11 // max ldr w11, [x3, #20] dup v31.16b, w11 // min -ldr x3, [x3, #8] +ldr x3, [x3, #72] mul x10, x6, x8 sub x9, x9, x10 @@ -711,7 +711,7 @@ ldp x27, x28, [sp, #(16 * 4)] ldp d8, d9, [sp, #(16 * 3)] ldp d10, d11, [sp, #(16 * 2)] ldp d12, d13, [sp, #(16 * 1)] -ldr d14, [sp], #(16 * 9) +ldr d14, [sp], #(16 * 10) ret #endif diff --git a/source/backend/cpu/arm/arm64/MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S b/source/backend/cpu/arm/arm64/MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S index 362b863ca..27a59d20d 100644 --- a/source/backend/cpu/arm/arm64/MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S +++ b/source/backend/cpu/arm/arm64/MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3.S @@ -19,15 +19,22 @@ asm_function MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3 // size_t dilateY_step, int8_t* idx) { // kernelx=3, kernely=3,dilatex=1,dilatey=1 - -//struct QuanPostTreatParameters { -// const float* scale; -// const int32_t* bias; -// int32_t maxValue; -// int32_t minValue; -// float roundValuePos = 0.5f; -// float roundValueNeg = -0.5f; -//}; +/* +struct QuanPostTreatParameters { + const float* scale; + const float* biasFloat; + int32_t maxValue; + int32_t minValue; + int32_t useInt8 = 1; // Save result as int8_t dataType; otherwise float32. 
+ float roundValuePos = 0.5f; + float roundValueNeg = -0.5f; + float* srcKernelSum; + float* weightQuanBias; + float* fp32minmax; + ssize_t blockNum = 1; + const int32_t* bias; + +};*/ // Auto Load: // x0: dst*, x1: src*, x2: weight*, x3: parameters* @@ -50,7 +57,7 @@ stp x19, x20, [sp, #(16 * 5)] ldr x19, [x3, #0] // scale ldr w20, [x3, #16] // max ldr w15, [x3, #20] // min -ldr x3, [x3, #8] // bias +ldr x3, [x3, #72] // bias ld1 {v24.16b, v25.16b, v26.16b, v27.16b}, [x14] ld1 {v0.16b, v1.16b}, [x2], #32 // v0,v1:weight ld1 {v12.s}[0], [x2] // weight: k:8 @@ -684,4 +691,4 @@ ldp d12, d13, [sp, #(16 * 1)] ldp d14, d15, [sp], #(16 * 6) ret -#endif \ No newline at end of file +#endif diff --git a/source/backend/cpu/arm/arm64/MNNPackC4Int8ForMatMulA_ARM82.S b/source/backend/cpu/arm/arm64/MNNPackC4Int8ForMatMulA_ARM82.S new file mode 100644 index 000000000..db81f8a03 --- /dev/null +++ b/source/backend/cpu/arm/arm64/MNNPackC4Int8ForMatMulA_ARM82.S @@ -0,0 +1,202 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +.macro SET_0 s0, s1, s2, s3 + movi \s0\().4s, #0 + movi \s1\().4s, #0 + movi \s2\().4s, #0 + movi \s3\().4s, #0 +.endm + +/* +struct SumByAxisParams { + ssize_t kernelCountUnitDouble; + ssize_t col_buffer_unit_size; + ssize_t DST_XUNIT; + ssize_t SRC_UNIT; + ssize_t blockNum; + ssize_t oneScale; +}; + */ + +asm_function MNNSumByAxisLForMatmul_A_ARM82 +// MNNSumByAxisLForMatmul_A_ARM82(float_t* dest, int8_t* source, float* dequantScale, ssize_t realDstCount, +// ssize_t kernelCountUnitDouble, ssize_t col_buffer_unit_size, ssize_t EP, ssize_t LP, ssize_t blockNum, ssize_t oneScale); +// x0: dest, x1: source, x2: dequantScale, x3: realDstCount, x4: sumParams +// x4: kernelCountUnitDouble, x5: col_buffer_unit_size +// Load from sp: x8: blockNum + +ldr x8, [x4, #32] // blockNum +ldr x5, [x4, #40] // oneScale +ldr x4, [x4, #0] // kernelCountUnitDouble + +//ldr x8, [sp, #0] // blockNum + +stp d14, d15, [sp, #(-16 * 4)]! 
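A note on the parameter block used throughout this patch: several kernels above and below move their bias load from `[x, #8]` to `[x, #72]`, while scale / max / min / srcKernelSum / weightQuanBias keep addressing fixed byte offsets of QuanPostTreatParameters (these fields feed the matmul post-processing: roughly acc * scale[oc] * blockScale[tile], a srcKernelSum[tile] * weightQuanBias[oc] correction, then bias[oc] and clamp or requantize). The sketch below is a hypothetical standalone reconstruction of that layout under ordinary LP64 alignment, meant only to check the offsets the assembly hard-codes; it is not the canonical MNN header.

    // C++ sanity sketch (assumed layout, not the MNN source of truth)
    #include <cstddef>
    #include <cstdint>
    #include <sys/types.h>   // ssize_t (POSIX)

    struct QuanPostTreatParameters {
        const float* scale;            // #0  -> "ldr ..., [x, #0]   // scale"
        const float* biasFloat;        // #8
        int32_t maxValue;              // #16 -> "ldr w.., [x, #16]  // max"
        int32_t minValue;              // #20 -> "ldr w.., [x, #20]  // min"
        int32_t useInt8 = 1;           // #24
        float roundValuePos = 0.5f;    // #28
        float roundValueNeg = -0.5f;   // #32
        float* srcKernelSum;           // #40 (4 bytes of padding before it)
        float* weightQuanBias;         // #48
        float* fp32minmax;             // #56
        ssize_t blockNum = 1;          // #64
        const int32_t* bias;           // #72 -> "ldr ..., [x, #72]  // bias"
    };

    static_assert(offsetof(QuanPostTreatParameters, minValue) == 20, "asm loads min at #20");
    static_assert(offsetof(QuanPostTreatParameters, bias) == 72, "asm loads bias at #72");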
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +movi v31.16b, #1 +ld1r {v30.4s}, [x2] // Dequant scale +mov x6, #48 // EP*LP +sdiv x4, x4, x8 // src_depth_quad per block + +TILE_12: +cmp x3, #12 +blt Remain + +mov x9, x8 // blockNum +cbnz x5, TILE12_BLOCK_NUM +ld1 {v13.4s, v14.4s, v15.4s}, [x2], #48 // batch quant scale + +TILE12_BLOCK_NUM: +mov x15, x4 // kernelCountUnitDouble + +movi v10.4s, #0 +movi v11.4s, #0 +movi v12.4s, #0 + +TILE12_BLOCK_INNER: + +ld1 {v0.16b, v1.16b, v2.16b}, [x1], #48 // E: 0,1,2,3,...,11 +subs x15, x15, #1 + +.inst 0x4e8097ea // sdot v10.4s, v31.16b, v0.16b // sum LP axis for E0, E1, E2, E3 +.inst 0x4e8197eb // sdot v11.4s, v31.16b, v1.16b +.inst 0x4e8297ec // sdot v12.4s, v31.16b, v2.16b + +bne TILE12_BLOCK_INNER + +TILE12_BLOCK_INNER_END: +subs x9, x9, #1 // blockNum-- + +scvtf v10.4s, v10.4s +scvtf v11.4s, v11.4s +scvtf v12.4s, v12.4s + +cbnz x5, TILE12_MUL_ONE_SCALE +fmul v10.4s, v10.4s, v13.4s +fmul v11.4s, v11.4s, v14.4s +fmul v12.4s, v12.4s, v15.4s +b TILE12_STORE + +TILE12_MUL_ONE_SCALE: +fmul v10.4s, v10.4s, v30.4s +fmul v11.4s, v11.4s, v30.4s +fmul v12.4s, v12.4s, v30.4s + +TILE12_STORE: +st1 {v10.4s, v11.4s, v12.4s}, [x0], #48 +bne TILE12_BLOCK_NUM + +TILE12_END: +subs x3, x3, #12 // realDstCount-=12 +bne TILE_12 + + +Remain: // remain realDstCount < EP +cbz x3, End +/* x11: Remain dstCount step for each block */ +lsl x11, x3, #2 + +TILE_2: // realDstCount >= 1 +cmp x3, #2 +blt TILE_1 + +mov x7, x1 +mov x9, x8 // blockNum +mov x10, x0 // tag dst address + +cbnz x5, TILE2_BLOCK_NUM +ld1 {v13.d}[0], [x2], #8 // batch quant scale + +TILE2_BLOCK_NUM: +mov x15, x4 // kernelCountUnitDouble +movi v10.4s, #0 + +TILE2_BLOCK_INNER: +ld1 {v0.d}[0], [x7] // E: 0,1 +add x7, x7, x6 +subs x15, x15, #1 +.inst 0x4e8097ea // sdot v10.4s, v31.16b, v0.16b // sum LP axis for E0 +bne TILE2_BLOCK_INNER + +TILE2_BLOCK_INNER_ENd: +scvtf v10.4s, v10.4s + +cbnz x5, TILE2_MUL_ONE_SCALE +fmul v10.4s, v10.4s, v13.4s +b TILE2_STORE + +TILE2_MUL_ONE_SCALE: +fmul v10.4s, v10.4s, v30.4s + +TILE2_STORE: +subs x9, x9, #1 // blockNum-- +st1 {v10.d}[0], [x10], x11 +bne TILE2_BLOCK_NUM + +TILE2_END: +sub x3, x3, #2 // realDstCount-=2 +add x1, x1, #8 // LP * 2 +add x0, x0, #8 // finish remain 2 +b TILE_2 + + +TILE_1: // realDstCount >= 1 +cmp x3, #1 +blt End + +mov x7, x1 +mov x9, x8 // blockNum +mov x10, x0 + +cbnz x5, TILE1_BLOCK_NUM +ld1 {v13.s}[0], [x2], #4 // batch quant scale + +TILE1_BLOCK_NUM: +mov x15, x4 // kernelCountUnitDouble +movi v10.4s, #0 + +TILE1_BLOCK_INNER: +ld1 {v0.s}[0], [x7] // E: 0 +subs x15, x15, #1 +add x7, x7, x6 +.inst 0x4e8097ea // sdot v10.4s, v31.16b, v0.16b // sum LP axis for E0 + +bne TILE1_BLOCK_INNER + +TILE1_BLOCK_INNER_END: +scvtf v10.4s, v10.4s + +cbnz x5, TILE1_MUL_ONE_SCALE +fmul v10.4s, v10.4s, v13.4s +b TILE1_STORE + +TILE1_MUL_ONE_SCALE: +fmul v10.4s, v10.4s, v30.4s + +TILE1_STORE: +subs x9, x9, #1 // blockNum-- +st1 {v10.s}[0], [x10], x11 +bne TILE1_BLOCK_NUM + +TILE1_END: +sub x3, x3, #1 // realDstCount-=1 +add x1, x1, #4 // LP * 1 +add x0, x0, #4 // finish remain 1 + +b TILE_1 + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif \ No newline at end of file diff --git a/source/backend/cpu/arm/arm64/MNNPackC4Int8ForMatMulA_ARM86.S b/source/backend/cpu/arm/arm64/MNNPackC4Int8ForMatMulA_ARM86.S new file mode 100644 index 000000000..803166f17 --- /dev/null +++ 
b/source/backend/cpu/arm/arm64/MNNPackC4Int8ForMatMulA_ARM86.S @@ -0,0 +1,318 @@ +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +.macro SET_0 s0, s1, s2, s3 + movi \s0\().4s, #0 + movi \s1\().4s, #0 + movi \s2\().4s, #0 + movi \s3\().4s, #0 +.endm + +/* +struct SumByAxisParams { + ssize_t kernelCountUnitDouble; + ssize_t col_buffer_unit_size; + ssize_t DST_XUNIT; + ssize_t SRC_UNIT; + ssize_t blockNum; + ssize_t oneScale; +}; + */ + +asm_function MNNSumByAxisLForMatmul_A_ARM86 +// MNNSumByAxisLForMatmul_A_ARM86(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams); +// x0: dest, x1: source, x2: dequantScale, x3: realDstCount, x4: sumParams +// Load from sp: x6: blockNum + +ldr x6, [x4, #32] // blockNum +ldr x12, [x4, #40] // oneScale +ldr x5, [x4, #0] // kernelCountUnitDouble + +stp d14, d15, [sp, #(-16 * 4)]! +stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] + +movi v31.16b, #1 +ld1r {v30.4s}, [x2] // dequant scale +mov x8, #80 // EP*LP +sdiv x5, x5, x6 // src_depth_quad_per_block + +START: +lsl x11, x3, #2 + +cmp x3, #1 +beq TILE_1 + +TILE_10: // realDstCount >= EP(10) +cmp x3, #10 +blt Remain +mov x9, x6 // blockNum + +cbnz x12, TILE10_BLOCK_NUM +ld1 {v5.4s, v6.4s}, [x2], #32 +ld1 {v7.d}[0], [x2] +sub x2, x2, #32 + +TILE10_BLOCK_NUM: +cbz x9, TILE10_END + +mov x15, x5 // kernelCountUnitDouble of a block +SET_0 v10, v11, v12, v13 +movi v14.4s, #0 + +TILE10_BLOCK_SRC_QUAD: + +//Loop_EPxLP: // EP*LP=10*8 +ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x1], #64 // E: 0,1,...,7 +ld1 {v4.16b}, [x1], #16 // E: 8,9 +subs x15, x15, #1 + +.inst 0x4e80a7ea // smmla v10.4s, v31.16b, v0.16b // sum LP axis for E0 and E1 +.inst 0x4e81a7eb // smmla v11.4s, v31.16b, v1.16b +.inst 0x4e82a7ec // smmla v12.4s, v31.16b, v2.16b +.inst 0x4e83a7ed // smmla v13.4s, v31.16b, v3.16b +.inst 0x4e84a7ee // smmla v14.4s, v31.16b, v4.16b + +bne TILE10_BLOCK_SRC_QUAD + +TILE10_PER_BLOCK_END: +sub x9, x9, #1 // blockNum-- + +trn1 v20.2d, v10.2d, v11.2d +trn1 v21.2d, v12.2d, v13.2d + +scvtf v20.4s, v20.4s +scvtf v21.4s, v21.4s +scvtf v14.4s, v14.4s + +cbnz x12, TILE10_ONE_SCALE +fmul v20.4s, v20.4s, v5.4s +fmul v21.4s, v21.4s, v6.4s +fmul v14.4s, v14.4s, v7.4s +b TILE10_STORE + +TILE10_ONE_SCALE: +fmul v20.4s, v20.4s, v30.4s +fmul v21.4s, v21.4s, v30.4s +fmul v14.4s, v14.4s, v30.4s + +TILE10_STORE: +st1 {v20.4s, v21.4s}, [x0], #32 +st1 {v14.d}[0], [x0], #8 +b TILE10_BLOCK_NUM // Finish one block + +TILE10_END: +sub x3, x3, #10 // realDstCount-=10 +b TILE_10 + + +Remain: // remain realDstCount < EP +cbz x3, End + +lsl x11, x3, #2 +/* For remain dstCount, each E's block step is x11. 
*/ +TILE_8: // realDstCount >= 8 +cmp x3, #8 +blt TILE_4 + +mov x7, x1 // tag begin src address for Remain8 +mov x10, x0 // tag begin dst address for Remain8 +mov x9, x6 // blockNum + +cbnz x12, TILE8_BLOCK_NUM +ld1 {v5.4s, v6.4s}, [x2], #32 + +TILE8_BLOCK_NUM: +cbz x9, TILE8_END +mov x15, x5 // kernelCountUnitDouble + +SET_0 v10, v11, v12, v13 + +TILE8_BLOCK_SRC_QUAD: + +ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x7] // E: 0,1,...,7 +subs x15, x15, #1 +add x7, x7, x8 // x7=x7+EP*LP +.inst 0x4e80a7ea // smmla v10.4s, v31.16b, v0.16b // sum LP axis for E0 and E1 +.inst 0x4e81a7eb // smmla v11.4s, v31.16b, v1.16b +.inst 0x4e82a7ec // smmla v12.4s, v31.16b, v2.16b +.inst 0x4e83a7ed // smmla v13.4s, v31.16b, v3.16b + +bne TILE8_BLOCK_SRC_QUAD + +TILE8_PER_BLOCK_END: +sub x9, x9, #1 // blockNum-- + +trn1 v20.2d, v10.2d, v11.2d +trn1 v21.2d, v12.2d, v13.2d + +scvtf v20.4s, v20.4s +scvtf v21.4s, v21.4s + +cbnz x12, TILE8_ONE_SCALE +fmul v20.4s, v20.4s, v5.4s +fmul v21.4s, v21.4s, v6.4s +b TILE8_STORE + +TILE8_ONE_SCALE: +fmul v20.4s, v20.4s, v30.4s +fmul v21.4s, v21.4s, v30.4s + +TILE8_STORE: +st1 {v20.4s, v21.4s}, [x10], x11 // Go to next block for this 8 remain. +b TILE8_BLOCK_NUM + +TILE8_END: +add x0, x0, #32 // finish 8 dstCount * sizeof(float) +sub x3, x3, #8 // realDstCount-=8 +add x1, x1, #64 // LP*8 + + +TILE_4: // realDstCount >= 4 +cmp x3, #4 +blt TILE_2 + +mov x7, x1 // tag begin src address for Remain4 +mov x10, x0 // tag begin dst address for Remain4 +mov x9, x6 // blockNum + +cbnz x12, TILE4_BLOCK_NUM +ld1 {v5.4s}, [x2], #16 + +TILE4_BLOCK_NUM: +cbz x9, TILE4_END +mov x15, x5 // kernelCountUnitDouble +movi v10.4s, #0 +movi v11.4s, #0 + +TILE4_BLOCK_SRC_QUAD: + +ld1 {v0.16b, v1.16b}, [x7] // E: 0,1,2,3 +subs x15, x15, #1 +add x7, x7, x8 +.inst 0x4e80a7ea // smmla v10.4s, v31.16b, v0.16b // sum LP axis for E0 and E1 +.inst 0x4e81a7eb // smmla v11.4s, v31.16b, v1.16b + +bne TILE4_BLOCK_SRC_QUAD + +TILE4_PER_BLOCK_END: +sub x9, x9, #1 // blockNum-- + +trn1 v20.2d, v10.2d, v11.2d +scvtf v20.4s, v20.4s + +cbnz x12, TILE4_ONE_SCALE +fmul v20.4s, v20.4s, v5.4s +b TILE4_STORE +TILE4_ONE_SCALE: +fmul v20.4s, v20.4s, v30.4s +TILE4_STORE: +st1 {v20.4s}, [x10], x11 +b TILE4_BLOCK_NUM + +TILE4_END: +add x0, x0, #16 // finish 4 dstCount * sizeof(float) +sub x3, x3, #4 // realDstCount-=4 +add x1, x1, #32 // LP*4 + +TILE_2: // realDstCount >= 2 +cmp x3, #2 +blt TILE_1 + +mov x7, x1 // tag begin src address for Remain8 +mov x10, x0 // tag begin dst address for Remain8 +mov x9, x6 // blockNum + +cbnz x12, TILE2_BLOCK_NUM +ld1 {v5.d}[0], [x2], #8 +TILE2_BLOCK_NUM: +cbz x9, TILE2_END +mov x15, x5 // kernelCountUnitDouble + +movi v10.4s, #0 + +TILE2_BLOCK_SRC_QUAD: + +ld1 {v0.16b}, [x7] // E: 0,1 +subs x15, x15, #1 +add x7, x7, x8 + +.inst 0x4e80a7ea // smmla v10.4s, v31.16b, v0.16b // sum LP axis for E0 and E1 + +bne TILE2_BLOCK_SRC_QUAD + +TILE2_PER_BLOCK_END: +sub x9, x9, #1 // blockNum-- + +scvtf v10.4s, v10.4s +cbnz x12, TILE2_ONE_SCALE +fmul v10.4s, v10.4s, v5.4s +b TILE2_STORE +TILE2_ONE_SCALE: +fmul v10.4s, v10.4s, v30.4s +TILE2_STORE: +st1 {v10.d}[0], [x10], x11 +b TILE2_BLOCK_NUM + +TILE2_END: +add x0, x0, #8 // finish 2 dstCount: 2 * sizeof(float32) +sub x3, x3, #2 // realDstCount-=2 +add x1, x1, #16 // LP * 2 * sizeof(int8_t) + +TILE_1: // realDstCount >= 1 +cmp x3, #1 +blt End + +mov x7, x1 // tag begin src address for Remain4 +mov x10, x0 // tag begin dst address for Remain4 +mov x9, x6 // blockNum + +cbnz x12, TILE1_BLOCK_NUM +ld1 {v5.s}[0], [x2], #4 + +TILE1_BLOCK_NUM: +cbz x9, TILE1_END 
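Both new sum kernels (MNNSumByAxisLForMatmul_A_ARM82 above and the ARM86 variant these TILE_* loops belong to) reduce the packed int8 source along the L axis by multiplying it against a register of ones — sdot for the LP=4 packing, smmla for LP=8 — then convert to float and apply either the single shared dequant scale (oneScale != 0) or a per-element batch scale. A rough scalar reference of the intended result for one tile of elements is sketched below; the indexing and names are assumptions read off the assembly, not MNN's C fallback.

    // C++ reference sketch for one tile (tileCount <= EP); the kernels walk EP-sized tiles.
    #include <cstddef>
    #include <cstdint>

    void SumByAxisL_RefTile(float* dest, const int8_t* source, const float* dequantScale,
                            size_t tileCount, size_t kernelCountUnitDouble,
                            size_t EP, size_t LP, size_t blockNum, bool oneScale) {
        const size_t stepPerBlock = kernelCountUnitDouble / blockNum;  // the "sdiv ..., blockNum" above
        for (size_t b = 0; b < blockNum; ++b) {
            for (size_t e = 0; e < tileCount; ++e) {
                int32_t acc = 0;
                for (size_t k = 0; k < stepPerBlock; ++k) {
                    const int8_t* p = source + ((b * stepPerBlock + k) * EP + e) * LP;
                    for (size_t l = 0; l < LP; ++l) {
                        acc += p[l];  // what sdot/smmla against a ones vector accumulates
                    }
                }
                const float s = oneScale ? dequantScale[0] : dequantScale[e];
                dest[b * tileCount + e] = static_cast<float>(acc) * s;
            }
        }
    }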
+mov x15, x5 // kernelCountUnitDouble +movi v10.4s, #0 + +TILE1_BLOCK_SRC_QUAD: + +ld1 {v0.d}[0], [x7] // E: 0 +subs x15, x15, #1 +add x7, x7, x8 +.inst 0x4e80a7ea // smmla v10.4s, v31.16b, v0.16b // sum LP axis for E0 + +bne TILE1_BLOCK_SRC_QUAD + +TILE1_PER_BLOCK_END: +sub x9, x9, #1 // blockNum-- + +scvtf v10.4s, v10.4s + +cbnz x12, TILE1_ONE_SCALE +fmul v10.4s, v10.4s, v5.4s +b TILE1_STORE + +TILE1_ONE_SCALE: +fmul v10.4s, v10.4s, v30.4s +TILE1_STORE: +st1 {v10.s}[0], [x10], x11 +b TILE1_BLOCK_NUM + +TILE1_END: +sub x3, x3, #1 // realDstCount-=1 +add x1, x1, #8 // LP * 1 * sizeof(int8_t) +add x0, x0, #4 // 1 * sizeof(float) + +End: +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 4) +ret +#endif \ No newline at end of file diff --git a/source/backend/cpu/arm/arm64/MNNPackedSparseQuantMatMulEpx1.S b/source/backend/cpu/arm/arm64/MNNPackedSparseQuantMatMulEpx1.S index 12c11436a..119cb6c90 100644 --- a/source/backend/cpu/arm/arm64/MNNPackedSparseQuantMatMulEpx1.S +++ b/source/backend/cpu/arm/arm64/MNNPackedSparseQuantMatMulEpx1.S @@ -38,7 +38,10 @@ ldp x3, x9, [x3] // x3: eSize, x9: eP mov x8, x6 // x8: dataOffsetMap mov x7, x5 // x7: NNZMap -ldp x24, x6, [x4], #16 // x5: scale , x6: bias +ldr x24, [x4, #0] +ldr x6, [x4, #72] +add x4, x4, #16 +//ldp x24, x6, [x4], #16 // x5: scale , x6: bias lsr x14, x11, #2 lsl x14, x14, #2 // x14: (h / 4) * 4 ld2r {v13.4s, v14.4s}, [x4] // first two elements of x4 are pointers, 'max, min ' locate at [2], [3] diff --git a/source/backend/cpu/arm/arm64/MNNPackedSparseQuantMatMulEpx4.S b/source/backend/cpu/arm/arm64/MNNPackedSparseQuantMatMulEpx4.S index 5b506cd55..d99a3cfb2 100644 --- a/source/backend/cpu/arm/arm64/MNNPackedSparseQuantMatMulEpx4.S +++ b/source/backend/cpu/arm/arm64/MNNPackedSparseQuantMatMulEpx4.S @@ -38,7 +38,10 @@ ldp x3, x9, [x3] // x3: eSize, x9: eP mov x8, x6 // x8: dataOffsetMap mov x7, x5 // x7: NNZMap -ldp x24, x6, [x4], #16 // x5: scale , x6: bias +ldr x24, [x4] +ldr x6, [x4, #72] +add x4, x4, #16 +//ldp x24, x6, [x4], #16 // x5: scale , x6: bias lsr x14, x11, #2 lsl x14, x14, #2 // x14: (h / 4) * 4 ld2r {v13.4s, v14.4s}, [x4] // first two elements of x4 are pointers, 'max, min ' locate at [2], [3] diff --git a/source/backend/cpu/arm/arm64/bf16/ARMV86_MNNPackedMatMulRemain_BF16.S b/source/backend/cpu/arm/arm64/bf16/ARMV86_MNNPackedMatMulRemain_BF16.S index 2acfe6930..a5de45f88 100644 --- a/source/backend/cpu/arm/arm64/bf16/ARMV86_MNNPackedMatMulRemain_BF16.S +++ b/source/backend/cpu/arm/arm64/bf16/ARMV86_MNNPackedMatMulRemain_BF16.S @@ -19,13 +19,6 @@ movi \d3\().4s, #0 .endm -.macro Float32ToBf16 d0, d1, d2, d3 - shrn \d0\().4h, \d0\().4s, #16 - shrn \d1\().4h, \d1\().4s, #16 - shrn \d2\().4h, \d2\().4s, #16 - shrn \d3\().4h, \d3\().4s, #16 -.endm - .macro FOURFMAX s, d0, d1, d2, d3 fmax \d0\().4s, \d0\().4s, \s\().4s fmax \d1\().4s, \d1\().4s, \s\().4s @@ -50,12 +43,15 @@ asm_function ARMV86_MNNPackedMatMulRemain_BF16 //void ARMV86_MNNPackedMatMulRemain_BF16(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias); //Auto x0: C, x1:A, x2:B, x3:eSize, x4:parameter, x5:postParameters, x6:bias -sub sp, sp, #64 +sub sp, sp, #96 str x19, [sp, #0] str x20, [sp, #8] str x21, [sp, #16] str x22, [sp, #24] +stp d9, d10, [sp, #32] +str d15, [sp, #64] ldr x11, [x4, #0] // aStride +lsr x11, x11, #1 // aStride->bf16 stride ldr x9, [x4, #8] // l ldr x10, [x4, #16] // h lsl x11, x11, #2 // aStride * 4 @@ 
-63,6 +59,7 @@ mov x22, #64 // B_stride = LP * HP = 4 * 8 * sizeof(int16_t) ldr x7, [x4, #24] // cStride ldr x19, [x4, #40] // bExtraStride +lsr x19, x19, #1 // bExtraStride->bf16 stride add x10, x10, #3 lsr x10, x10, #2 @@ -89,14 +86,12 @@ LoopE8: // e, TILE_BLOCK size is 8 LH8: cmp x8, #2 // h/4 > 2 blt LH4 - sub x14, x7, #64 // cStride - 64 + sub x14, x7, #128 // cStride - 8 * 4 * sizeof(float) LoopH8x8: mov x15, x1 // src, A mov x12, x9 // l cbz x5, NoBiasLH8 - ld1 {v0.4h, v1.4h}, [x20], #16 // 8 * sizeof(int16_t) - shll v0.4s, v0.4h, #16 - shll v1.4s, v1.4h, #16 + ld1 {v0.4s, v1.4s}, [x20], #32 // 8 * sizeof(float) mov v2.16b, v0.16b mov v3.16b, v1.16b uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 @@ -165,15 +160,11 @@ LoopE8: // e, TILE_BLOCK size is 8 FOURFMIN v10, v23, v24, v25, v26 FOURFMIN v10, v27, v28, v29, v30 StoreLH8: - Float32ToBf16 v15, v16, v17, v18 - Float32ToBf16 v19, v20, v21, v22 - Float32ToBf16 v23, v24, v25, v26 - Float32ToBf16 v27, v28, v29, v30 - st1 {v15.4h, v16.4h, v17.4h, v18.4h}, [x0], #32 // 16 * sizeof(int16_t) - st1 {v19.4h, v20.4h, v21.4h, v22.4h}, [x0], #32 // 16 * sizeof(int16_t) + st1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x0], #64 // 16 * sizeof(float) + st1 {v19.4s, v20.4s, v21.4s, v22.4s}, [x0], #64 // 16 * sizeof(float) add x0, x0, x14 - st1 {v23.4h, v24.4h, v25.4h, v26.4h}, [x0], #32 // 16 * sizeof(int16_t) - st1 {v27.4h, v28.4h, v29.4h, v30.4h}, [x0], #32 // 16 * sizeof(int16_t) + st1 {v23.4s, v24.4s, v25.4s, v26.4s}, [x0], #64 // 16 * sizeof(float) + st1 {v27.4s, v28.4s, v29.4s, v30.4s}, [x0], #64 // 16 * sizeof(float) add x0, x0, x14 add x13, x13, x19 // weight stride sub x8, x8, #2 @@ -185,8 +176,7 @@ LoopE8: // e, TILE_BLOCK size is 8 mov x15, x1 mov x12, x9 cbz x5, NoBiasHRemain - ld1 {v0.4h}, [x20] - shll v0.4s, v0.4h, #16 + ld1 {v0.4s}, [x20] mov v2.16b, v0.16b uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3 @@ -228,14 +218,12 @@ LoopE8: // e, TILE_BLOCK size is 8 FOURFMIN v10, v15, v16, v17, v18 FOURFMIN v10, v19, v20, v21, v22 StoreLH8x4: - Float32ToBf16 v15, v16, v17, v18 - Float32ToBf16 v19, v20, v21, v22 - st1 {v15.4h, v16.4h, v17.4h, v18.4h}, [x0], #32 // 16 * sizeof(int16_t) - st1 {v19.4h, v20.4h, v21.4h, v22.4h}, [x0], #32 // 16 * sizeof(int16_t) + st1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x0], #64 // 16 * sizeof(int16_t) + st1 {v19.4s, v20.4s, v21.4s, v22.4s}, [x0], #64 // 16 * sizeof(int16_t) E8End: sub x3, x3, #8 cmp x3, #8 - add x0, x21, #64 // move dest address of 8 * 4 * sizeof(int16_t) + add x0, x21, #128 // move dest address of 8 * 4 * sizeof(float) add x1, x1, #64 // move A matrix address of 8 * 4 * sizeof(int16_t) bge LoopE8 @@ -255,9 +243,7 @@ E4LH8: mov x15, x1 mov x12, x9 cbz x5, NoBiasE4 - ld1 {v0.4h, v1.4h}, [x20], #16 // 8 * sizeof(int16_t) - shll v0.4s, v0.4h, #16 - shll v1.4s, v1.4h, #16 + ld1 {v0.4s, v1.4s}, [x20], #32 // 8 * sizeof(float) mov v2.16b, v0.16b mov v3.16b, v1.16b uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 @@ -304,10 +290,8 @@ E4LH8: FOURFMIN v10, v15, v16, v17, v18 FOURFMIN v10, v19, v20, v21, v22 StoreLH4x8: - Float32ToBf16 v15, v16, v17, v18 - Float32ToBf16 v19, v20, v21, v22 - st1 {v15.4h, v16.4h, v17.4h, v18.4h}, [x0], x7 // 16 * sizeof(int16_t) - st1 {v19.4h, v20.4h, v21.4h, v22.4h}, [x0], x7 // 16 * sizeof(int16_t) + st1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x0], x7 + st1 {v19.4s, v20.4s, v21.4s, v22.4s}, [x0], x7 add x13, x13, x19 // weight stride sub x8, x8, #2 cmp x8, #2 @@ -317,8 +301,7 @@ E4LH8: 
mov x15, x1 mov x12, x9 cbz x5, NoBiasE4R - ld1 {v0.4h}, [x20] - shll v0.4s, v0.4h, #16 + ld1 {v0.4s}, [x20] mov v2.16b, v0.16b uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3 @@ -347,13 +330,12 @@ E4LH8: cbz x5, StoreLH4x4 PostTreatLH4x4: FOURFMAX v9, v15, v16, v17, v18 - FOURFMIN v10, v19, v20, v21, v22 + FOURFMIN v10, v15, v16, v17, v18 StoreLH4x4: - Float32ToBf16 v15, v16, v17, v18 - st1 {v15.4h, v16.4h, v17.4h, v18.4h}, [x0] // 16 * sizeof(int16_t) + st1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x0] E4End: sub x3, x3, #4 - add x0, x21, #32 // move dest address of 4 * 4 * sizeof(int16_t) + add x0, x21, #64 // move dest address of 4 * 4 * sizeof(float) add x1, x1, #32 // move dest address of 4 * 4 * sizeof(int16_t) E2: @@ -372,9 +354,7 @@ E2LH8: mov x15, x1 mov x12, x9 cbz x5, NoBiasE2 - ld1 {v0.4h, v1.4h}, [x20], #16 - shll v0.4s, v0.4h, #16 - shll v1.4s, v1.4h, #16 + ld1 {v0.4s, v1.4s}, [x20], #32 mov v2.16b, v0.16b mov v3.16b, v1.16b uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 @@ -406,9 +386,8 @@ E2LH8: FOURFMAX v9, v15, v16, v17, v18 FOURFMIN v10, v15, v16, v17, v18 StoreLH2x8: - Float32ToBf16 v15, v16, v17, v18 - st1 {v15.4h, v16.4h}, [x0], x7 // 8 * sizeof(int16_t) - st1 {v17.4h, v18.4h}, [x0], x7 // 8 * sizeof(int16_t) + st1 {v15.4s, v16.4s}, [x0], x7 // 8 * sizeof(int16_t) + st1 {v17.4s, v18.4s}, [x0], x7 // 8 * sizeof(int16_t) add x13, x13, x19 // weight stride sub x8, x8, #2 cmp x8, #2 @@ -418,8 +397,7 @@ E2LH8: mov x15, x1 mov x12, x9 cbz x5, NoBiasE2R - ld1 {v0.4h}, [x20] - shll v0.4s, v0.4h, #16 + ld1 {v0.4s}, [x20] mov v2.16b, v0.16b uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3 @@ -447,12 +425,10 @@ E2LH8: fmin v15.4s, v15.4s, v10.4s fmin v16.4s, v16.4s, v10.4s StoreLH2x4: - shrn v15.4h, v15.4s, #16 - shrn v16.4h, v16.4s, #16 - st1 {v15.4h, v16.4h}, [x0] // 8 * sizeof(int16_t) + st1 {v15.4s, v16.4s}, [x0] E2End: sub x3, x3, #2 - add x0, x21, #16 // move dest address of 2 * 4 * sizeof(int16_t) + add x0, x21, #32 // move dest address of 2 * 4 * sizeof(float) add x1, x1, #16 // move dest address of 2 * 4 * sizeof(int16_t) E1: @@ -473,9 +449,7 @@ LoopE1: mov x15, x1 mov x12, x9 cbz x5, NoBiasE1 - ld1 {v0.4h, v1.4h}, [x20], #16 - shll v0.4s, v0.4h, #16 - shll v1.4s, v1.4h, #16 + ld1 {v0.4s, v1.4s}, [x20], #32 mov v2.16b, v0.16b mov v3.16b, v1.16b uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 @@ -508,10 +482,8 @@ LoopE1: fmin v15.4s, v15.4s, v10.4s fmin v16.4s, v16.4s, v10.4s StoreLH1x8: - shrn v15.4h, v15.4s, #16 - shrn v16.4h, v16.4s, #16 - st1 {v15.4h}, [x0], x7 - st1 {v16.4h}, [x0], x7 + st1 {v15.4s}, [x0], x7 + st1 {v16.4s}, [x0], x7 add x13, x13, x19 sub x8, x8, #2 cmp x8, #2 @@ -522,8 +494,7 @@ LoopE1: mov x15, x1 mov x12, x9 cbz x5, NoBiasE1R - ld1 {v0.4h}, [x20] - shll v0.4s, v0.4h, #16 + ld1 {v0.4s}, [x20] mov v2.16b, v0.16b uzp1 v16.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 uzp2 v17.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3 @@ -548,19 +519,20 @@ LoopE1: fmax v15.4s, v15.4s, v9.4s fmin v15.4s, v15.4s, v10.4s StoreLH1x4: - shrn v15.4h, v15.4s, #16 - st1 {v15.4h}, [x0] + st1 {v15.4s}, [x0] E1End: subs x3, x3, #1 - add x0, x21, #8 + add x0, x21, #16 // 4 * sizeof(float) add x1, x1, #8 bne LoopE1 End: +ldr d15, [sp, #64] +ldp d9, d10, [sp, #32] ldr x19, [sp, #0] ldr x20, [sp, #8] ldr x21, [sp, #16] ldr x22, [sp, #24] -add sp, sp, #64 +add sp, sp, #96 ret #endif diff --git 
a/source/backend/cpu/arm/arm64/bf16/ARMV86_MNNPackedMatMul_BF16.S b/source/backend/cpu/arm/arm64/bf16/ARMV86_MNNPackedMatMul_BF16.S index 7d3282969..567e34b56 100644 --- a/source/backend/cpu/arm/arm64/bf16/ARMV86_MNNPackedMatMul_BF16.S +++ b/source/backend/cpu/arm/arm64/bf16/ARMV86_MNNPackedMatMul_BF16.S @@ -19,13 +19,6 @@ movi \d3\().4s, #0 .endm -.macro Float32ToBf16 d0, d1, d2, d3 - shrn \d0\().4h, \d0\().4s, #16 - shrn \d1\().4h, \d1\().4s, #16 - shrn \d2\().4h, \d2\().4s, #16 - shrn \d3\().4h, \d3\().4s, #16 -.endm - .macro FOURFMAX s, d0, d1, d2, d3 fmax \d0\().4s, \d0\().4s, \s\().4s fmax \d1\().4s, \d1\().4s, \s\().4s @@ -51,11 +44,11 @@ asm_function ARMV86_MNNPackedMatMul_BF16 //void ARMV86_MNNPackedMatMul_BF16(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias); // x0: C, x1:A, x2:B, x3:parameter, x4: postParameters, x5:bias -stp d14, d15, [sp, #-80]! +stp d14, d15, [sp, #-128]! stp d12, d13, [sp, #16] stp d10, d11, [sp, #32] stp d8, d9, [sp, #48] -stp x19, x21, [sp, #64] +stp x19, x20, [sp, #64] //ldr x8, [x3, #0] // deprecated ldr x9, [x3, #8] // l @@ -64,6 +57,7 @@ mov x11, #64 // B_stride = LP * HP = 4 * 8 * sizeof(int16_t) ldr x13, [x3, #24] // cStride ldr x7, [x3, #40] // bExtraStride +lsr x7, x7, #1 // bExtraStride -> bf16 stride add x10, x10, #3 lsr x10, x10, #2 @@ -79,14 +73,13 @@ Start: cmp x10, #2 blt LH4 LH8: - sub x14, x13, #96 // cStride - 96 + sub x14, x13, #192 // cStride - 12 * 4 * sizeof(float) LoopH: mov x15, x1 mov x12, x9 cbz x5, NoBiasH8 - ld1 {v0.4h, v1.4h}, [x5], #16 // 8 * sizeof(int16_t) - shll v0.4s, v0.4h, #16 - shll v1.4s, v1.4h, #16 + ld1 {v0.4s, v1.4s}, [x5], #32 // 8 * sizeof(float) + mov v2.16b, v0.16b mov v3.16b, v1.16b uzp1 v18.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 @@ -185,19 +178,14 @@ LoopH: FOURFMIN v6, v23, v24, v25, v26 FOURFMIN v6, v27, v28, v29, v30 StoreLH8: - Float32ToBf16 v7, v8, v9, v10 - Float32ToBf16 v11, v12, v13, v14 - Float32ToBf16 v15, v16, v17, v18 - Float32ToBf16 v19, v20, v21, v22 - Float32ToBf16 v23, v24, v25, v26 - Float32ToBf16 v27, v28, v29, v30 - st1 {v7.4h, v8.4h, v9.4h, v10.4h}, [x0], #32 // 16 * sizeof(int16_t) - st1 {v11.4h, v12.4h, v13.4h, v14.4h}, [x0], #32 // 16 * sizeof(int16_t) - st1 {v15.4h, v16.4h, v17.4h, v18.4h}, [x0], #32 // 16 * sizeof(int16_t) + + st1 {v7.4s, v8.4s, v9.4s, v10.4s}, [x0], #64 // 16 * sizeof(int16_t) + st1 {v11.4s, v12.4s, v13.4s, v14.4s}, [x0], #64 // 16 * sizeof(int16_t) + st1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x0], #64 // 16 * sizeof(int16_t) add x0, x0, x14 - st1 {v19.4h, v20.4h, v21.4h, v22.4h}, [x0], #32 // 16 * sizeof(int16_t) - st1 {v23.4h, v24.4h, v25.4h, v26.4h}, [x0], #32 // 16 * sizeof(int16_t) - st1 {v27.4h, v28.4h, v29.4h, v30.4h}, [x0], #32 // 16 * sizeof(int16_t) + st1 {v19.4s, v20.4s, v21.4s, v22.4s}, [x0], #64 // 16 * sizeof(int16_t) + st1 {v23.4s, v24.4s, v25.4s, v26.4s}, [x0], #64 // 16 * sizeof(int16_t) + st1 {v27.4s, v28.4s, v29.4s, v30.4s}, [x0], #64 // 16 * sizeof(int16_t) add x0, x0, x14 add x2, x2, x7 // weight stride sub x10, x10, #2 @@ -209,8 +197,7 @@ LoopHR: mov x15, x1 mov x12, x9 cbz x5, NoBiasH4 - ld1 {v0.4h}, [x5], #8 // 8 * sizeof(int16_t) - shll v0.4s, v0.4h, #16 + ld1 {v0.4s}, [x5], #16 // 4 * sizeof(float) mov v2.16b, v0.16b uzp1 v18.2d, v0.2d, v2.2d // bias_0, bias_1, bias_0, bias_1 uzp2 v19.2d, v0.2d, v2.2d // bias_2, bias_3, bias_2, bias_3 @@ -269,18 +256,16 @@ LoopHR: FOURFMIN v6, v11, v12, v13, v14 FOURFMIN v6, v15, v16, v17, v18 StoreLH4: - Float32ToBf16 v7, v8, v9, v10 - 
Float32ToBf16 v11, v12, v13, v14 - Float32ToBf16 v15, v16, v17, v18 - st1 {v7.4h, v8.4h, v9.4h, v10.4h}, [x0], #32 // 16 * sizeof(int16_t) - st1 {v11.4h, v12.4h, v13.4h, v14.4h}, [x0], #32 // 16 * sizeof(int16_t) - st1 {v15.4h, v16.4h, v17.4h, v18.4h}, [x0], #32 // 16 * sizeof(int16_t) + st1 {v7.4s, v8.4s, v9.4s, v10.4s}, [x0], #64 // 16 * sizeof(int16_t) + st1 {v11.4s, v12.4s, v13.4s, v14.4s}, [x0], #64 // 16 * sizeof(int16_t) + st1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x0], #64 // 16 * sizeof(int16_t) + End: -ldp x19, x21, [sp, #64] +ldp x19, x20, [sp, #64] ldp d8, d9, [sp, #48] ldp d10, d11, [sp, #32] ldp d12, d13, [sp, #16] -ldp d14, d15, [sp], #80 +ldp d14, d15, [sp], #128 ret #endif diff --git a/source/backend/cpu/arm/arm64/bf16/MNNPackC4ForMatMul_A_BF16.S b/source/backend/cpu/arm/arm64/bf16/MNNPackC4ForMatMul_A_BF16.S index faa7d31a1..4fc8d85eb 100644 --- a/source/backend/cpu/arm/arm64/bf16/MNNPackC4ForMatMul_A_BF16.S +++ b/source/backend/cpu/arm/arm64/bf16/MNNPackC4ForMatMul_A_BF16.S @@ -34,14 +34,16 @@ mov x6, #0 ldr w4, [x2, #4] // eReal ldr w11, [x2, #8] // eDest ldr w6, [x2, #12] // xOffset -// xOffset -> xOffset * 4 * sizeof(int16_t) +// xOffset -> xOffset * 4 * sizeof(float) // eReal -> eReal * 4 * sizeof(int16_t) // eDest -> eDest * sizeof(int16_t) mov x12, #2 // sizeof(int16_t). kept as a const mov x9, #8 -mul x4, x9, x4 +mov x15, #16 // sizeof(float) +mul x4, x15, x4 mul x11, x12, x11 -mul x6, x9, x6 + +mul x6, x15, x6 LoopNumber: mov x2, #0 @@ -72,18 +74,35 @@ bne Right LoopL4: mov x2, x1 .macro MAIN_TRANSPOSE - ld1 {v0.4h}, [x1], x6 // load size: 4 * sizeof(int16_t), jump one stride line as x6 - ld1 {v3.4h}, [x1], x6 - ld1 {v6.4h}, [x1], x6 - ld1 {v17.4h}, [x1], x6 - ld1 {v1.4h}, [x1], x6 - ld1 {v4.4h}, [x1], x6 - ld1 {v7.4h}, [x1], x6 - ld1 {v18.4h}, [x1], x6 - ld1 {v2.4h}, [x1], x6 - ld1 {v5.4h}, [x1], x6 - ld1 {v16.4h}, [x1], x6 - ld1 {v19.4h}, [x1], x6 + ld1 {v0.4s}, [x1], x6 // load size: 4 * sizeof(int16_t), jump one stride line as x6 + ld1 {v3.4s}, [x1], x6 + ld1 {v6.4s}, [x1], x6 + ld1 {v17.4s}, [x1], x6 + + ld1 {v1.4s}, [x1], x6 + ld1 {v4.4s}, [x1], x6 + ld1 {v7.4s}, [x1], x6 + ld1 {v18.4s}, [x1], x6 + + ld1 {v2.4s}, [x1], x6 + ld1 {v5.4s}, [x1], x6 + ld1 {v16.4s}, [x1], x6 + ld1 {v19.4s}, [x1], x6 + + shrn v0.4h, v0.4s, #16 + shrn v3.4h, v3.4s, #16 + shrn v6.4h, v6.4s, #16 + shrn v17.4h, v17.4s, #16 + + shrn v1.4h, v1.4s, #16 + shrn v4.4h, v4.4s, #16 + shrn v7.4h, v7.4s, #16 + shrn v18.4h, v18.4s, #16 + + shrn v2.4h, v2.4s, #16 + shrn v5.4h, v5.4s, #16 + shrn v16.4h, v16.4s, #16 + shrn v19.4h, v19.4s, #16 transpose_4x4 v0, v3, v6, v17, v23, v24 transpose_4x4 v1, v4, v7, v18, v25, v26 @@ -99,23 +118,6 @@ bne Right stp d18, d19, [x0, #(16 * 5)] add x0, x0, #(16 * 6) - // st1 {v0.4h}, [x0], #8 // store size: 4 * sizeof(int16_t) - // st1 {v1.4h}, [x0], #8 - // st1 {v2.4h}, [x0], #8 - // st1 {v3.4h}, [x0], #8 - // st1 {v4.4h}, [x0], #8 - // st1 {v5.4h}, [x0], #8 - // st1 {v6.4h}, [x0], #8 - // st1 {v7.4h}, [x0], #8 - // st1 {v16.4h}, [x0], #8 - // st1 {v17.4h}, [x0], #8 - // st1 {v18.4h}, [x0], #8 - // st1 {v19.4h}, [x0], #8 - - // st1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x0], #32 - // st1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x0], #32 - // st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0], #32 - add x1, x2, x4 sub x5, x5, #4 cmp w5, #4 @@ -133,20 +135,6 @@ bne Right str d16, [x0, #(16 * 4)] add x0, x0, #(16 * 4 + 8) - // st1 {v0.4h}, [x0], #8 // store size: 4 * sizeof(int16_t) - // st1 {v1.4h}, [x0], #8 - // st1 {v2.4h}, [x0], #8 - // st1 {v3.4h}, [x0], #8 - // st1 {v4.4h}, [x0], #8 - // 
st1 {v5.4h}, [x0], #8 - // st1 {v6.4h}, [x0], #8 - // st1 {v7.4h}, [x0], #8 - // st1 {v16.4h}, [x0], #8 - - // st1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x0], #32 - // st1 {v4.4h, v5.4h, v6.4h, v7.4h}, [x0], #32 - // st1 {v16.4h}, [x0], #8 - b LoopEEnd LoopEL2: @@ -158,16 +146,6 @@ bne Right stp d4, d5, [x0, #(16 * 2)] add x0, x0, #(16 * 3) - // st1 {v0.4h}, [x0], #8 // store size: 4 * sizeof(int16_t) - // st1 {v1.4h}, [x0], #8 - // st1 {v2.4h}, [x0], #8 - // st1 {v3.4h}, [x0], #8 - // st1 {v4.4h}, [x0], #8 - // st1 {v5.4h}, [x0], #8 - - // st1 {v0.4h, v1.4h, v2.4h, v3.4h}, [x0], #32 - // st1 {v4.4h, v5.4h}, [x0], #16 - b LoopEEnd LoopEL1: @@ -178,12 +156,6 @@ bne Right str d2, [x0, #16] add x0, x0, #(16 + 8) - // st1 {v0.4h}, [x0], #8 // store size: 4 * sizeof(int16_t) - // st1 {v1.4h}, [x0], #8 - // st1 {v2.4h}, [x0], #8 - - // st1 {v0.4h, v1.4h, v2.4h}, [x0], #24 - LoopEEnd: b End @@ -198,7 +170,8 @@ LoopE1: cmp w5, #4 blt LoopE1L3 LoopE1L4: - ld1 {v0.4h}, [x1], x4 + ld1 {v0.4s}, [x1], x4 + shrn v0.4h, v0.4s, #16 st1 {v0.h}[0], [x0], x11 st1 {v0.h}[1], [x0], x11 st1 {v0.h}[2], [x0], x11 @@ -210,7 +183,8 @@ LoopE1: LoopE1L3: cmp w5, #3 blt LoopE1L2 - ld1 {v0.4h}, [x1], x4 + ld1 {v0.4s}, [x1], x4 + shrn v0.4h, v0.4s, #16 st1 {v0.h}[0], [x0], x11 st1 {v0.h}[1], [x0], x11 st1 {v0.h}[2], [x0], x11 @@ -220,7 +194,8 @@ LoopE1: LoopE1L2: cmp w5, #2 blt LoopE1L1 - ld1 {v0.4h}, [x1], x4 + ld1 {v0.4s}, [x1], x4 + shrn v0.4h, v0.4s, #16 st1 {v0.h}[0], [x0], x11 st1 {v0.h}[1], [x0], x11 sub w5, w5, #2 @@ -228,7 +203,8 @@ LoopE1: LoopE1L1: cmp w5, #1 blt LoopE1End - ld1 {v0.h}[0], [x1], x4 + ld1 {v0.s}[0], [x1], x4 + shrn v0.4h, v0.4s, #16 st1 {v0.h}[0], [x0], x11 LoopE1End: diff --git a/source/backend/cpu/arm/arm64/bf16/MNNPackC8_BF16.S b/source/backend/cpu/arm/arm64/bf16/MNNPackC8_BF16.S index 87503e839..7157ce44b 100644 --- a/source/backend/cpu/arm/arm64/bf16/MNNPackC8_BF16.S +++ b/source/backend/cpu/arm/arm64/bf16/MNNPackC8_BF16.S @@ -23,9 +23,10 @@ lsr x4, x2, #3 lsr x5, x3, #3 mov x12, #2 // sizeof(int16_t) mov x13, #16 // 8 * sizeof(int16_t) -mul x6, x12, x2 +mov x15, #4 +mul x6, x15, x2 mul x7, x13, x2 -mov x12, #16 // 8 * sizeof(int16_t) +mov x12, #32 // 8 * sizeof(float) mul x15, x12, x2 .macro transpose_4x4 x0, x1, x2, x3, x5, x6 @@ -47,32 +48,15 @@ mov x12, x4 LoopL: mov x10, x9 -ld1 {v16.4h, v17.4h}, [x9], x6 -ld1 {v18.4h, v19.4h}, [x9], x6 -ld1 {v20.4h, v21.4h}, [x9], x6 -ld1 {v22.4h, v23.4h}, [x9], x6 - -ld1 {v24.4h, v25.4h}, [x9], x6 -ld1 {v26.4h, v27.4h}, [x9], x6 -ld1 {v28.4h, v29.4h}, [x9], x6 -ld1 {v30.4h, v31.4h}, [x9], x6 - -shll v16.4s, v16.4h, #16 -shll v17.4s, v17.4h, #16 -shll v18.4s, v18.4h, #16 -shll v19.4s, v19.4h, #16 -shll v20.4s, v20.4h, #16 -shll v21.4s, v21.4h, #16 -shll v22.4s, v22.4h, #16 -shll v23.4s, v23.4h, #16 -shll v24.4s, v24.4h, #16 -shll v25.4s, v25.4h, #16 -shll v26.4s, v26.4h, #16 -shll v27.4s, v27.4h, #16 -shll v28.4s, v28.4h, #16 -shll v29.4s, v29.4h, #16 -shll v30.4s, v30.4h, #16 -shll v31.4s, v31.4h, #16 +ld1 {v16.4s, v17.4s}, [x9], x6 +ld1 {v18.4s, v19.4s}, [x9], x6 +ld1 {v20.4s, v21.4s}, [x9], x6 +ld1 {v22.4s, v23.4s}, [x9], x6 + +ld1 {v24.4s, v25.4s}, [x9], x6 +ld1 {v26.4s, v27.4s}, [x9], x6 +ld1 {v28.4s, v29.4s}, [x9], x6 +ld1 {v30.4s, v31.4s}, [x9], x6 transpose_4x4 v16, v18, v20, v22, v0, v1 @@ -109,7 +93,7 @@ stp d19, d27, [x8], #16 stp d21, d29, [x8], #16 stp d23, d31, [x8], #16 -add x9, x10, #16 // 8 * sizeof(int16_t) +add x9, x10, #32 // 8 * sizeof(float) subs x12, x12, #1 bne LoopL diff --git 
a/source/backend/cpu/arm/arm64/bf16/MNNPackedMatMulRemain_BF16.S b/source/backend/cpu/arm/arm64/bf16/MNNPackedMatMulRemain_BF16.S index a65140adc..64232bd9f 100644 --- a/source/backend/cpu/arm/arm64/bf16/MNNPackedMatMulRemain_BF16.S +++ b/source/backend/cpu/arm/arm64/bf16/MNNPackedMatMulRemain_BF16.S @@ -24,11 +24,14 @@ ldr x11, [x4, #0] // aStride ldr x9, [x4, #8] // l ldr x10, [x4, #16] // h +lsr x11, x11, #1 // aStride = aStride / 2 (fp32 -> bf16) + ldr x7, [x4, #24] // cStride ldr x19, [x4, #40] // bExtraStride add x10, x10, #3 lsr x10, x10, #2 +lsr x19, x19, #1 // bExtraStride = bExtraStride / 2 cbz x5, Start ld1 {v5.4s}, [x5] @@ -121,9 +124,7 @@ LoopE8: // e, TILE_BLOCK size is 8 cbz x5, StoreLH8 AddBiasLH8: - ld1 {v0.4h, v1.4h}, [x20], #16 - shll v0.4s, v0.4h, #16 - shll v1.4s, v1.4h, #16 + ld1 {v0.4s, v1.4s}, [x20], #32 fmla v16.4s, v0.4s, v5.s[1] fmla v17.4s, v0.4s, v5.s[1] @@ -181,33 +182,17 @@ LoopE8: // e, TILE_BLOCK size is 8 fmin v31.4s, v31.4s, v7.4s StoreLH8: - shrn v16.4h, v16.4s, #16 - shrn v17.4h, v17.4s, #16 - shrn v18.4h, v18.4s, #16 - shrn v19.4h, v19.4s, #16 - shrn v20.4h, v20.4s, #16 - shrn v21.4h, v21.4s, #16 - shrn v22.4h, v22.4s, #16 - shrn v23.4h, v23.4s, #16 - shrn v24.4h, v24.4s, #16 - shrn v25.4h, v25.4s, #16 - shrn v26.4h, v26.4s, #16 - shrn v27.4h, v27.4s, #16 - shrn v28.4h, v28.4s, #16 - shrn v29.4h, v29.4s, #16 - shrn v30.4h, v30.4s, #16 - shrn v31.4h, v31.4s, #16 - - stp d16, d17, [x0] - stp d18, d19, [x0, #(16 * 1)] - stp d24, d25, [x0, #(16 * 2)] - stp d26, d27, [x0, #(16 * 3)] + + stp q16, q17, [x0] + stp q18, q19, [x0, #(32 * 1)] + stp q24, q25, [x0, #(32 * 2)] + stp q26, q27, [x0, #(32 * 3)] add x0, x0, x7 // stp donot support post-index offset in register - stp d20, d21, [x0] - stp d22, d23, [x0, #(16 * 1)] - stp d28, d29, [x0, #(16 * 2)] - stp d30, d31, [x0, #(16 * 3)] + stp q20, q21, [x0] + stp q22, q23, [x0, #(32 * 1)] + stp q28, q29, [x0, #(32 * 2)] + stp q30, q31, [x0, #(32 * 3)] add x0, x0, x7 // stp donot support post-index offset in register // st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0], #32 // 4 * 4 * sizeof(int16_t) @@ -271,8 +256,7 @@ LoopE8: // e, TILE_BLOCK size is 8 cbz x5, StoreLH8x4 AddBiasLH8x4: - ld1 {v0.4h}, [x20] - shll v0.4s, v0.4h, #16 + ld1 {v0.4s}, [x20] fmla v16.4s, v0.4s, v5.s[1] fmla v17.4s, v0.4s, v5.s[1] @@ -304,20 +288,12 @@ LoopE8: // e, TILE_BLOCK size is 8 fmin v23.4s, v23.4s, v7.4s StoreLH8x4: - shrn v16.4h, v16.4s, #16 - shrn v17.4h, v17.4s, #16 - shrn v18.4h, v18.4s, #16 - shrn v19.4h, v19.4s, #16 - shrn v20.4h, v20.4s, #16 - shrn v21.4h, v21.4s, #16 - shrn v22.4h, v22.4s, #16 - shrn v23.4h, v23.4s, #16 - - stp d16, d17, [x0] - stp d18, d19, [x0, #(16 * 1)] - stp d20, d21, [x0, #(16 * 2)] - stp d22, d23, [x0, #(16 * 3)] - add x0, x0, #(16 * 4) + + stp q16, q17, [x0] + stp q18, q19, [x0, #(32 * 1)] + stp q20, q21, [x0, #(32 * 2)] + stp q22, q23, [x0, #(32 * 3)] + add x0, x0, #(32 * 4) // st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0], #32 // st1 {v20.4h, v21.4h, v22.4h, v23.4h}, [x0], #32 @@ -326,7 +302,7 @@ LoopE8: // e, TILE_BLOCK size is 8 sub x3, x3, #8 cmp x3, #8 - add x0, x21, #64 // move dest address of 8 * 4 * sizeof(int16_t) + add x0, x21, #128 // move dest address of 8 * 4 * sizeof(float) add x1, x1, #16 // move A matrix address of 8 * sizeof(int16_t) bge LoopE8 @@ -412,9 +388,7 @@ blt E1 cbz x5, StoreLH4x8 AddBiasLH4x8: - ld1 {v0.4h, v1.4h}, [x20], #16 - shll v0.4s, v0.4h, #16 - shll v1.4s, v1.4h, #16 + ld1 {v0.4s, v1.4s}, [x20], #32 fmla v16.4s, v0.4s, v5.s[1] fmla v17.4s, v0.4s, v5.s[1] @@ -446,21 +420,12 
@@ blt E1 fmin v23.4s, v23.4s, v7.4s StoreLH4x8: - shrn v16.4h, v16.4s, #16 - shrn v17.4h, v17.4s, #16 - shrn v18.4h, v18.4s, #16 - shrn v19.4h, v19.4s, #16 - shrn v20.4h, v20.4s, #16 - shrn v21.4h, v21.4s, #16 - shrn v22.4h, v22.4s, #16 - shrn v23.4h, v23.4s, #16 - - - stp d16, d17, [x0] - stp d18, d19, [x0, #16] + + stp q16, q17, [x0] + stp q18, q19, [x0, #32] add x0, x0, x7 - stp d20, d21, [x0] - stp d22, d23, [x0, #16] + stp q20, q21, [x0] + stp q22, q23, [x0, #32] add x0, x0, x7 // st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0], x7 @@ -503,8 +468,7 @@ blt E1 cbz x5, StoreLH4x4 AddBiasLH4x4: - ld1 {v0.4h}, [x20] - shll v0.4s, v0.4h, #16 + ld1 {v0.4s}, [x20] fmla v16.4s, v0.4s, v5.s[1] fmla v17.4s, v0.4s, v5.s[1] @@ -525,20 +489,15 @@ blt E1 StoreLH4x4: - shrn v16.4h, v16.4s, #16 - shrn v17.4h, v17.4s, #16 - shrn v18.4h, v18.4s, #16 - shrn v19.4h, v19.4s, #16 - - stp d16, d17, [x0] - stp d18, d19, [x0, #16] + stp q16, q17, [x0] + stp q18, q19, [x0, #32] // st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0] E4End: sub x3, x3, #4 - add x0, x21, #32 // move dest address of 4 * 4 * sizeof(int16_t) + add x0, x21, #64 // move dest address of 4 * 4 * sizeof(float) add x1, x1, #8 // move dest address of 4 * sizeof(int16_t) E1: @@ -590,9 +549,7 @@ LoopE1: cbz x5, StoreLH1x8 AddBiasLH1x8: - ld1 {v0.4h, v1.4h}, [x20], #16 - shll v1.4s, v1.4h, #16 - shll v0.4s, v0.4h, #16 + ld1 {v0.4s, v1.4s}, [x20], #32 fmla v16.4s, v0.4s, v5.s[1] fmla v20.4s, v1.4s, v5.s[1] @@ -604,10 +561,8 @@ LoopE1: fmin v20.4s, v20.4s, v7.4s StoreLH1x8: - shrn v16.4h, v16.4s, #16 - shrn v20.4h, v20.4s, #16 - st1 {v16.4h}, [x0], x7 - st1 {v20.4h}, [x0], x7 + st1 {v16.4s}, [x0], x7 + st1 {v20.4s}, [x0], x7 bge E1LoopH8 @@ -640,8 +595,7 @@ LoopE1: cbz x5, StoreLH1x4 AddBiasLH1x4: - ld1 {v0.4h}, [x20] - shll v0.4s, v0.4h, #16 + ld1 {v0.4s}, [x20] fmla v16.4s, v0.4s, v5.s[1] @@ -650,13 +604,12 @@ LoopE1: fmin v16.4s, v16.4s, v7.4s StoreLH1x4: - shrn v16.4h, v16.4s, #16 - st1 {v16.4h}, [x0] + st1 {v16.4s}, [x0] E1End: subs x3, x3, #1 - add x0, x21, #8 + add x0, x21, #16 add x1, x1, #2 bne LoopE1 diff --git a/source/backend/cpu/arm/arm64/bf16/MNNPackedMatMul_BF16.S b/source/backend/cpu/arm/arm64/bf16/MNNPackedMatMul_BF16.S index 22c2c24ca..28991753c 100644 --- a/source/backend/cpu/arm/arm64/bf16/MNNPackedMatMul_BF16.S +++ b/source/backend/cpu/arm/arm64/bf16/MNNPackedMatMul_BF16.S @@ -27,6 +27,7 @@ ldr x10, [x3, #16] // h ldr x13, [x3, #24] // cStride ldr x7, [x3, #40] // bExtraStride +lsr x7, x7, #1 // v0, v1, v2: A // v3, v4: B @@ -218,9 +219,7 @@ LoopH: cbz x4, StoreLH8 AddBiasLH8: - ld1 {v0.4h, v1.4h}, [x5], #16 // 8 * sizeof(int16_t) - shll v0.4s, v0.4h, #16 - shll v1.4s, v1.4h, #16 + ld1 {v0.4s, v1.4s}, [x5], #32 // 8 * sizeof(int16_t) fmla v8.4s, v0.4s, v5.s[1] fmla v9.4s, v0.4s, v5.s[1] @@ -305,44 +304,19 @@ LoopH: StoreLH8: - shrn v8.4h, v8.4s, #16 - shrn v9.4h, v9.4s, #16 - shrn v10.4h, v10.4s, #16 - shrn v11.4h, v11.4s, #16 - shrn v12.4h, v12.4s, #16 - shrn v13.4h, v13.4s, #16 - shrn v14.4h, v14.4s, #16 - shrn v15.4h, v15.4s, #16 - shrn v16.4h, v16.4s, #16 - shrn v17.4h, v17.4s, #16 - shrn v18.4h, v18.4s, #16 - shrn v19.4h, v19.4s, #16 - shrn v20.4h, v20.4s, #16 - shrn v21.4h, v21.4s, #16 - shrn v22.4h, v22.4s, #16 - shrn v23.4h, v23.4s, #16 - shrn v24.4h, v24.4s, #16 - shrn v25.4h, v25.4s, #16 - shrn v26.4h, v26.4s, #16 - shrn v27.4h, v27.4s, #16 - shrn v28.4h, v28.4s, #16 - shrn v29.4h, v29.4s, #16 - shrn v30.4h, v30.4s, #16 - shrn v31.4h, v31.4s, #16 - - stp d8, d9, [x0] - stp d10, d11, [x0, #(16 * 1)] // 2 * 4 * sizeof(int16_t) - 
stp d12, d13, [x0, #(16 * 2)] - stp d14, d15, [x0, #(16 * 3)] - stp d16, d17, [x0, #(16 * 4)] - stp d18, d19, [x0, #(16 * 5)] + stp q8, q9, [x0] + stp q10, q11, [x0, #(32 * 1)] // 2 * 4 * sizeof(int16_t) + stp q12, q13, [x0, #(32 * 2)] + stp q14, q15, [x0, #(32 * 3)] + stp q16, q17, [x0, #(32 * 4)] + stp q18, q19, [x0, #(32 * 5)] add x0, x0, x13 // stp donot support post-index offset in register - stp d20, d21, [x0] - stp d22, d23, [x0, #(16 * 1)] - stp d24, d25, [x0, #(16 * 2)] - stp d26, d27, [x0, #(16 * 3)] - stp d28, d29, [x0, #(16 * 4)] - stp d30, d31, [x0, #(16 * 5)] + stp q20, q21, [x0] + stp q22, q23, [x0, #(32 * 1)] + stp q24, q25, [x0, #(32 * 2)] + stp q26, q27, [x0, #(32 * 3)] + stp q28, q29, [x0, #(32 * 4)] + stp q30, q31, [x0, #(32 * 5)] add x0, x0, x13 // st1 {v8.4h, v9.4h, v10.4h, v11.4h}, [x0], #32 // 16 * sizeof(int16_t) @@ -415,8 +389,7 @@ LoopHRemain: cbz x4, StoreLH4 AddBiasLH4: - ld1 {v0.4h}, [x5], #8 - shll v0.4s, v0.4h, #16 + ld1 {v0.4s}, [x5], #16 fmla v8.4s, v0.4s, v5.s[1] fmla v9.4s, v0.4s, v5.s[1] @@ -462,29 +435,9 @@ LoopHRemain: StoreLH4: - shrn v8.4h, v8.4s, #16 - shrn v9.4h, v9.4s, #16 - shrn v10.4h, v10.4s, #16 - shrn v11.4h, v11.4s, #16 - shrn v12.4h, v12.4s, #16 - shrn v13.4h, v13.4s, #16 - shrn v14.4h, v14.4s, #16 - shrn v15.4h, v15.4s, #16 - shrn v16.4h, v16.4s, #16 - shrn v17.4h, v17.4s, #16 - shrn v18.4h, v18.4s, #16 - shrn v19.4h, v19.4s, #16 - - stp d8, d9, [x0] - stp d10, d11, [x0, #(16 * 1)] - stp d12, d13, [x0, #(16 * 2)] - stp d14, d15, [x0, #(16 * 3)] - stp d16, d17, [x0, #(16 * 4)] - stp d18, d19, [x0, #(16 * 5)] - - // st1 {v8.4h, v9.4h, v10.4h, v11.4h}, [x0], #32 - // st1 {v12.4h, v13.4h, v14.4h, v15.4h}, [x0], #32 - // st1 {v16.4h, v17.4h, v18.4h, v19.4h}, [x0] + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0] sub x10, x10, #1 diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNDynamicQuantFP32.S b/source/backend/cpu/arm/arm64/low_memory/MNNDynamicQuantFP32.S index d1b02673b..f8971acaa 100644 --- a/source/backend/cpu/arm/arm64/low_memory/MNNDynamicQuantFP32.S +++ b/source/backend/cpu/arm/arm64/low_memory/MNNDynamicQuantFP32.S @@ -37,34 +37,32 @@ add \d0\().4s, \d0\().4s, \d2\().4s .endm -//void MNNDynamicQuantFP32(const float* src, int8_t* dst, const float* scale, float* sum, size_t src_depth_quad, size_t realSize) +//void MNNDynamicQuantFP32(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack) asm_function MNNDynamicQuantFP32 -// x0: src, x1:dst, x2:scale, x3: sum, x4:src_depth_quad, x5:realSize +// x0: src, x1:dst, x2:scale, x3:src_depth_quad, x4:realSize stp d14, d15, [sp, #(-16 * 4)]! 
stp d12, d13, [sp, #(16 * 1)] stp d10, d11, [sp, #(16 * 2)] stp d8, d9, [sp, #(16 * 3)] Start: -lsl x6, x5, #2 // dst_step = batch * unit * sizeof(int8_t) = batch * 4 = batch << 2 +lsl x6, x4, #2 // dst_step = batch * unit * sizeof(int8_t) = batch * 4 = batch << 2 lsl x7, x6, #2 // src_step = dst_step * 4 (sizeof(float32_t)) = dst_step << 2 TILE_8: -cmp x5, #8 +cmp x4, #8 blt TILE_4 sub x8, x7, #64 // src_step - 64 mov x9, x0 // src mov x10, x1 // dst //mov x11, x2 // scale -mov x12, x4 // src_depth_quad +mov x12, x3 // src_depth_quad // quant_scale: v8, 8(batch)*sizeof(float32_t) ld1 {v8.4s, v9.4s}, [x2], #32 // int8 sum -movi v10.4s, #0 -movi v11.4s, #0 LoopSz_8: ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x9], #64 ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x9], x8 @@ -99,36 +97,23 @@ sqxtn v17.8b, v14.8h sqxtn2 v17.16b, v15.8h st1 {v16.16b, v17.16b}, [x10], x6 -// sum -//Transpose v0, v1, v2, v3, v14, v15, v16, v17 -//Add_4x4 v0, v1, v2, v3 -addp v18.4s, v0.4s, v1.4s -addp v19.4s, v2.4s, v3.4s -addp v20.4s, v4.4s, v5.4s -addp v21.4s, v6.4s, v7.4s -addp v22.4s, v18.4s, v19.4s -addp v23.4s, v20.4s, v21.4s - -add v10.4s, v22.4s, v10.4s -add v11.4s, v23.4s, v11.4s subs x12, x12, #1 bne LoopSz_8 Tile8End: -sub x5, x5, #8 // batch -= 8 +sub x4, x4, #8 // batch -= 8 add x0, x0, #128 // src += 8 * 4 * sizeof(float32_t) add x1, x1, #32 // dst += 8 * 4 * sizeof(int8_t) -st1 {v10.4s, v11.4s}, [x3], #32 b TILE_8 TILE_4: -cmp x5, #4 +cmp x4, #4 blt TILE_1 mov x9, x0 // src mov x10, x1 // dst //mov x11, x2 // scale -mov x12, x4 // src_depth_quad +mov x12, x3 // src_depth_quad // quant_scale: v8, 4(batch)*sizeof(float32_t) ld1 {v8.4s}, [x2], #16 @@ -158,28 +143,23 @@ sqxtn v6.8b, v4.8h sqxtn2 v6.16b, v5.8h st1 {v6.16b}, [x10], x6 -// sum -Transpose v0, v1, v2, v3, v14, v15, v16, v17 -Add_4x4 v0, v1, v2, v3 -add v10.4s, v0.4s, v10.4s subs x12, x12, #1 bne LoopSz_4 Tile4End: -sub x5, x5, #4 // batch -= 4 +sub x4, x4, #4 // batch -= 4 add x0, x0, #64 // src += 4 * 4 * sizeof(float32_t) add x1, x1, #16 // dst += 4 * 4 * sizeof(int8_t) //add x2, x2, #16 // scale += 4 * sizeof(float32_t) -st1 {v10.4s}, [x3], #16 b TILE_4 TILE_1: -cmp x5, #1 +cmp x4, #1 blt End mov x9, x0 // src mov x10, x1 // dst -mov x12, x4 // src_depth_quad +mov x12, x3 // src_depth_quad // quant_scale: v8 ld1 {v8.s}[0], [x2], #4 @@ -192,26 +172,17 @@ fmul v0.4s, v0.4s, v8.s[0] // int16_t x = round(x) fcvtas v0.4s, v0.4s -dup v1.4s, v0.s[1] -dup v2.4s, v0.s[2] -dup v3.4s, v0.s[3] - // y = (int8_t)x sqxtn v7.4h, v0.4s sqxtn v7.8b, v7.8h -// sum - -Add_4x4 v0, v1, v2, v3 -add v4.4s, v0.4s, v4.4s st1 {v7.s}[0], [x10], x6 subs x12, x12, #1 bne LoopSz_1 -st1 {v4.s}[0], [x3], #4 Tile1End: -subs x5, x5, #1 // batch -= 1 +subs x4, x4, #1 // batch -= 1 add x0, x0, #16 // src += 1 * 4 * sizeof(float32_t) add x1, x1, #4 // dst += 1 * 4 * sizeof(int8_t) //add x2, x2, #4 // scale += 1 * sizeof(float32_t) @@ -224,4 +195,4 @@ ldp d12, d13, [sp, #(16 * 1)] ldp d14, d15, [sp], #(16 * 4) ret -#endif \ No newline at end of file +#endif diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNDynamicUpdateConvBiasScale.S b/source/backend/cpu/arm/arm64/low_memory/MNNDynamicUpdateConvBiasScale.S new file mode 100644 index 000000000..b0fa8194d --- /dev/null +++ b/source/backend/cpu/arm/arm64/low_memory/MNNDynamicUpdateConvBiasScale.S @@ -0,0 +1,229 @@ +// +// MNNDynamicUpdateConvBiasScale.S +// MNN +// +// Created by MNN on 2019/01/22. 
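In the MNNDynamicQuantFP32 hunks above, the per-tile int32 sum outputs (and the x3/sum argument) are dropped, leaving only the quantization itself: fcvtas rounds to nearest with ties away from zero and sqxtn/sqxtn2 saturate to int8. A per-element C++ sketch of the path that remains (a reference model, not the packed layout the kernel uses):

#include <algorithm>
#include <cmath>
#include <cstdint>

// y = saturate_int8(round_to_nearest_ties_away(x * scale))
static inline int8_t dynamic_quant_one(float x, float scale) {
    long r = std::lround(x * scale);                     // fcvtas: nearest, ties away from zero
    r = std::max(-128L, std::min(127L, r));              // sqxtn: saturate to the int8 range
    return (int8_t)r;
}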
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +.macro Round z0, z1, z2, z3 + fcvtzs \z0\().4s, \z0\().4s + fcvtzs \z1\().4s, \z1\().4s + fcvtzs \z2\().4s, \z2\().4s + fcvtzs \z3\().4s, \z3\().4s +.endm + +.macro MUL_CONSTANT s0, s1, s2, s3, z0 + fmul \s0\().4s, \s0\().4s, \z0\().4s + fmul \s1\().4s, \s1\().4s, \z0\().4s + fmul \s2\().4s, \s2\().4s, \z0\().4s + fmul \s3\().4s, \s3\().4s, \z0\().4s +.endm + +.macro DIV4 s0, s1, s2, s3, z0, z1, z2, z3 + fdiv \s0\().4s, \s0\().4s, \z0\().4s + fdiv \s1\().4s, \s1\().4s, \z1\().4s + fdiv \s2\().4s, \s2\().4s, \z2\().4s + fdiv \s3\().4s, \s3\().4s, \z3\().4s +.endm + +.macro SUB4 s0, s1, s2, s3, z0, z1, z2, z3 + fsub \s0\().4s, \s0\().4s, \z0\().4s + fsub \s1\().4s, \s1\().4s, \z1\().4s + fsub \s2\().4s, \s2\().4s, \z2\().4s + fsub \s3\().4s, \s3\().4s, \z3\().4s +.endm + +.macro Float32ToHalf s0, s1, s2, s3, d0, d1 + fcvtn \d0\().4h, \s0\().4s + fcvtn2 \d0\().8h, \s1\().4s + fcvtn \d1\().4h, \s2\().4s + fcvtn2 \d1\().8h, \s3\().4s +.endm + +/* +Note: Only used in dynamic quant,so do not need compare min max! + */ +asm_function MNNDynamicUpdateConvBiasScale +//MNNDynamicUpdateConvBiasScale(biasFloat.data(), scaleFloat.data(), biasfp32, weightDequantScale, +//inputScale, weightKernelSum, inputZero, UP_DIV(output->channel(), 4), alphaSize) +//x0:biasFloat, x1:scaleFloat, x2:biasfp32, x3:weightDequantScale, x4:inputScale, x5:weightKernelSum, x6:inputZero, x7:ocQuad +//Load from sp: x9: scaleSize + +ldr x9, [sp, #0] +stp d14, d15, [sp, #-64]! +stp d12, d13, [sp, #16] +stp d10, d11, [sp, #32] +stp d8, d9, [sp, #48] + +ld1r {v31.4s}, [x4] // input dequant scale +ld1r {v30.4s}, [x6] // input dequant zero:fp32 zero + +lsr x9, x9, #2 +// fuse scale + +SCALE_L24: +cmp x9, #24 +blt SCALE_L16 + +ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3], #64 +ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x3], #64 +ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x3], #64 +ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x3], #64 +ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x3], #64 +ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x3], #64 +MUL_CONSTANT v0, v1, v2, v3, v31 // w_scale * x_scale +MUL_CONSTANT v4, v5, v6, v7, v31 +MUL_CONSTANT v8, v9, v10, v11, v31 +MUL_CONSTANT v12, v13, v14, v15, v31 +MUL_CONSTANT v16, v17, v18, v19, v31 +MUL_CONSTANT v20, v21, v22, v23, v31 +sub x9, x9, #24 +st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 +st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 +st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64 +st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64 +st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x1], #64 +st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x1], #64 +b SCALE_L24 + +SCALE_L16: +cmp x9, #16 +blt SCALE_L8 + +ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3], #64 +ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x3], #64 +ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x3], #64 +ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x3], #64 +MUL_CONSTANT v0, v1, v2, v3, v31 // w_scale * x_scale +MUL_CONSTANT v4, v5, v6, v7, v31 +MUL_CONSTANT v8, v9, v10, v11, v31 +MUL_CONSTANT v12, v13, v14, v15, v31 +sub x9, x9, #16 +st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 +st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 +st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x1], #64 +st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x1], #64 +b SCALE_L16 + +SCALE_L8: +cmp x9, #8 +blt SCALE_L4 + +ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3], #64 +ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x3], #64 +MUL_CONSTANT v0, v1, v2, v3, v31 // w_scale * x_scale +MUL_CONSTANT v4, v5, v6, v7, v31 +sub x9, x9, #8 +st1 {v0.4s, v1.4s, v2.4s, v3.4s}, 
[x1], #64 +st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x1], #64 +b SCALE_L8 + +SCALE_L4: +cmp x9, #4 +blt SCALE_L1 + +ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x3], #64 +MUL_CONSTANT v0, v1, v2, v3, v31 // w_scale * x_scale +sub x9, x9, #4 +st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x1], #64 +b SCALE_L4 + +SCALE_L1: +cmp x9, #1 +blt BIAS_L8 + +ld1 {v0.4s}, [x3], #16 +fmul v0.4s, v0.4s, v31.4s +sub x9, x9, #1 +st1 {v0.4s}, [x1], #16 +b SCALE_L1 + +// Bias: +BIAS_L16: +cmp x7, #16 +blt BIAS_L8 + +ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], #64 // oldbias +ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x2], #64 +ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x2], #64 +ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x2], #64 +ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64 // weightKernelSum +ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x5], #64 +ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x5], #64 + +sub x7, x7, #16 + +MUL_CONSTANT v16, v17, v18, v19, v30 // w_sum * x_zero +MUL_CONSTANT v20, v21, v22, v23, v30 // w_sum * x_zero +MUL_CONSTANT v24, v25, v26, v27, v30 // w_sum * x_zero + +SUB4 v0, v1, v2, v3, v16, v17, v18, v19 +SUB4 v4, v5, v6, v7, v20, v21, v22, v23 +ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64 +SUB4 v8, v9, v10, v11, v24, v25, v26, v27 +MUL_CONSTANT v16, v17, v18, v19, v30 // w_sum * x_zero +SUB4 v12, v13, v14, v15, v16, v17, v18, v19 + +st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 // bias float +st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 +st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 +st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 +b BIAS_L16 + +BIAS_L8: +cmp x7, #8 +blt BIAS_L4 + +ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], #64 // oldbias +ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x2], #64 +ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x5], #64 // weightKernelSum +ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x5], #64 +sub x7, x7, #8 + +MUL_CONSTANT v16, v17, v18, v19, v30 // w_sum * x_zero +MUL_CONSTANT v20, v21, v22, v23, v30 // w_sum * x_zero +SUB4 v0, v1, v2, v3, v16, v17, v18, v19 +SUB4 v4, v5, v6, v7, v20, v21, v22, v23 +st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 // bias float +st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 +b BIAS_L8 + +BIAS_L4: +cmp x7, #4 +blt BIAS_L1 + +ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x2], #64 // oldbias +ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x5], #64 // weightKernelSum +sub x7, x7, #4 + +MUL_CONSTANT v8, v9, v10, v11, v30 // w_sum * x_zero +SUB4 v0, v1, v2, v3, v8, v9, v10, v11 +st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 +b BIAS_L4 + +BIAS_L1: +cmp x7, #1 +blt End +ld1 {v0.4s}, [x2], #16 // oldbias +ld1 {v4.4s}, [x5], #16 // weightKernelSum +sub x7, x7, #1 +fmul v4.4s, v4.4s, v30.4s // w_sum * x_zero +fsub v0.4s, v0.4s, v4.4s // oldbias - w_sum * x_zero +st1 {v0.4s}, [x0], #16 +b BIAS_L1 + +End: +ldp d8, d9, [sp, #48] +ldp d10, d11, [sp, #32] +ldp d12, d13, [sp, #16] +ldp d14, d15, [sp], #64 +ret +#endif diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt4FP32.S b/source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt4FP32.S deleted file mode 100644 index 83548cfd9..000000000 --- a/source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt4FP32.S +++ /dev/null @@ -1,308 +0,0 @@ -// -// MNNGemmHybridInt4_sdot.S -// MNN -// -// Created by MNN on 2023/11/09. 
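Going by its comments, the new MNNDynamicUpdateConvBiasScale kernel above folds the dynamic input quantization parameters into the convolution's scale and bias: fusedScale = w_scale * x_scale and fusedBias = bias - w_kernel_sum * x_zero. A scalar C++ reference under that reading (the kernel itself works four floats at a time; parameter names here are illustrative):

#include <cstddef>

void dynamicUpdateConvBiasScaleRef(float* biasOut, float* scaleOut,
                                   const float* biasFp32, const float* weightScale,
                                   const float* inputScale,      // single value, read once
                                   const float* weightKernelSum,
                                   const float* inputZero,       // single value, read once
                                   size_t ocUp4, size_t scaleSize) {
    for (size_t i = 0; i < scaleSize; ++i) {
        scaleOut[i] = weightScale[i] * inputScale[0];             // w_scale * x_scale
    }
    for (size_t o = 0; o < ocUp4; ++o) {
        biasOut[o] = biasFp32[o] - weightKernelSum[o] * inputZero[0]; // bias - w_sum * x_zero
    }
}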
-// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef __aarch64__ - -#include "MNNAsmGlobal.h" - -.text -.align 5 - -.macro Int32ToFloat z0, z1, z2, z3 - scvtf \z0\().4s, \z0\().4s - scvtf \z1\().4s, \z1\().4s - scvtf \z2\().4s, \z2\().4s - scvtf \z3\().4s, \z3\().4s -.endm - -.macro MulScale d0, d1, d2, d3, s - fmul \d0\().4s, \d0\().4s, \s\().s[0] - fmul \d1\().4s, \d1\().4s, \s\().s[1] - fmul \d2\().4s, \d2\().4s, \s\().s[2] - fmul \d3\().4s, \d3\().4s, \s\().s[3] -.endm - -.macro Dequant c0, a0, z0, b0, s0, idx - fmul \c0\().4s, \c0\().4s, \a0\().4s - fmla \c0\().4s, \z0\().4s, \s0\().s[\idx] - fadd \c0\().4s, \c0\().4s, \b0\().4s -.endm - -asm_function MNNGemmHybridInt4FP32 - -//struct QuanPostTreatParameters { -// const float* scale; -// const int32_t* bias; -// int32_t maxValue; -// int32_t minValue; -// int32_t useInt8; -//}; - -//void MNNGemmHybridInt4FP32(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param); - - -// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param -// load from param: x7: alpha*, x8: zero*, x9: bias*, x10: sums*, x11: scales* -stp d14, d15, [sp, #(-16 * 9)]! -stp d12, d13, [sp, #(16 * 1)] -stp d10, d11, [sp, #(16 * 2)] -stp d8, d9, [sp, #(16 * 3)] -stp x21, x22, [sp, #(16 * 4)] -stp x19, x20, [sp, #(16 * 5)] -stp x23, x24, [sp, #(16 * 6)] -stp x25, x26, [sp, #(16 * 7)] -stp x27, x28, [sp, #(16 * 8)] - -ldr x8, [x7, #0] -ldr x9, [x7, #8] -ldr x10, [x7, #16] -ldr x11, [x7, #24] -ldr x12, [x7, #32] - -Start: -lsl x13, x3, #3 // x13 = src_depth_quad * UNIT * UNIT_SRC / 2(int4) = src_depth_quad * 8 = src_depth_quad << 3 - -TILE_4: - cmp x6, #4 - blt TILE_1 - mov x14, x4 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_4: - // dequant info for batch - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - // batch=0,oc=0-3 - movi v10.4s, #0 //ic=0-3 - movi v11.4s, #0 - movi v12.4s, #0 - movi v13.4s, #0 - // batch=1,oc=0-3 - movi v16.4s, #0 - movi v17.4s, #0 - movi v18.4s, #0 - movi v19.4s, #0 - // batch=2,oc=0-3 - movi v20.4s, #0 - movi v21.4s, #0 - movi v22.4s, #0 - movi v23.4s, #0 - // batch=3,oc=0-3 - movi v24.4s, #0 - movi v25.4s, #0 - movi v26.4s, #0 - movi v27.4s, #0 - // mask - movi v14.16b, #15 - // offset - movi v15.16b, #8 -LoopSz_TILE_4: - // src : 4(batch) x [1 x 4] : v4 - // weight : 4(oc) x [1 x 4] : v0 - // dst : 4 x 4 x [1] : v16-v19 - ld1 {v0.8b}, [x25], #8 // weight - ld1 {v4.16b}, [x24], x15 // src - // int4->int8 - ushr v8.16b, v0.16b, #4 - and v9.16b, v0.16b, v14.16b - zip1 v0.16b, v8.16b, v9.16b - - Unit_TILE_4: - sxtl v5.8h, v4.8b // src batch=0,1 - sxtl2 v6.8h, v4.16b // batch=2,3 - sxtl v1.8h, v0.8b // weight oc=0,1 - sxtl2 v2.8h, v0.16b // oc=2,3 - dup v28.2d, v1.d[0] // oc=0,0 - dup v29.2d, v1.d[1] // oc=1,1 - dup v30.2d, v2.d[0] // oc=2,2 - dup v31.2d, v2.d[1] // oc=3,3 - // batch=0 - smlal v10.4s, v5.4h, v28.4h - smlal v11.4s, v5.4h, v29.4h - smlal v12.4s, v5.4h, v30.4h - smlal v13.4s, v5.4h, v31.4h - // batch=1 - smlal2 v16.4s, v5.8h, v28.8h - smlal2 v17.4s, v5.8h, v29.8h - smlal2 v18.4s, v5.8h, v30.8h - smlal2 v19.4s, v5.8h, v31.8h - // batch=2 - smlal v20.4s, v6.4h, v28.4h - smlal v21.4s, v6.4h, v29.4h - smlal v22.4s, v6.4h, 
v30.4h - smlal v23.4s, v6.4h, v31.4h - // batch=3 - smlal2 v24.4s, v6.8h, v28.8h - smlal2 v25.4s, v6.8h, v29.8h - smlal2 v26.4s, v6.8h, v30.8h - smlal2 v27.4s, v6.8h, v31.8h - // .inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] // batch0 - // .inst 0x4fa4e011 // sdot v17.4s, v0.16b, v4.4b[1] // batch1 - // .inst 0x4f84e812 // sdot v18.4s, v0.16b, v4.4b[2] // batch2 - // .inst 0x4fa4e813 // sdot v19.4s, v0.16b, v4.4b[3] // batch3 - - subs x26, x26, #1 - bne LoopSz_TILE_4 - -LoopSzEnd_TILE_4: - // add 4 ic - addp v10.4s, v10.4s, v11.4s - addp v12.4s, v12.4s, v13.4s - addp v16.4s, v16.4s, v17.4s - addp v18.4s, v18.4s, v19.4s - addp v20.4s, v20.4s, v21.4s - addp v22.4s, v22.4s, v23.4s - addp v24.4s, v24.4s, v25.4s - addp v26.4s, v26.4s, v27.4s - - addp v10.4s, v10.4s, v12.4s // batch=0,oc=0-3 - addp v11.4s, v16.4s, v18.4s - addp v12.4s, v20.4s, v22.4s - addp v13.4s, v24.4s, v26.4s - - add x7, x7, x13 - sub x27, x27, #1 - Int32ToFloat v10, v11, v12, v13 - // Int32ToFloat v20, v21, v22, v23 - // using float scale dequant for precison - ld1 {v5.4s}, [x23] // scales, 4 batch,so 4 scale - - MulScale v10, v11, v12, v13, v5 - -Tile4Dequant: - ld1 {v0.4s}, [x19], #16 // alpha - ld1 {v1.4s}, [x20], #16 // zero - ld1 {v2.4s}, [x21], #16 // bias - ld1 {v3.4s}, [x22] // sums - // alpha * sum + (zero * sums) + bias - Dequant v10, v0, v1, v2, v3, 0 - Dequant v11, v0, v1, v2, v3, 1 - Dequant v12, v0, v1, v2, v3, 2 - Dequant v13, v0, v1, v2, v3, 3 - st1 {v10.4s, v11.4s, v12.4s, v13.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_4 -Tile4End: - sub x6, x6, #4 // bach -= 4 - add x0, x0, #64 // dst += 4 * 4 * sizeof(float32_t) - add x1, x1, #16 // src += 4 * 4 * sizeof(int8_t) - add x11, x11, #16 // sum += 4 * sizeof(float32_t) - add x12, x12, #16 // scale += 4 * sizeof(float32_t) - b TILE_4 - -TILE_1: - cmp x6, #1 - blt End - mov x14, x4 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4, sizeof(float32_t)/4=sizeof(int8_t) - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_1: - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - // batch=0,oc=0-3 - movi v10.4s, #0 //ic=0-3 - movi v11.4s, #0 - movi v12.4s, #0 - movi v13.4s, #0 - // mask - movi v14.16b, #15 - // offset - movi v15.16b, #8 -LoopSz_TILE_1: - // src : 1(batch) x [1 x 4] : v4 - // weight : 4(oc) x [1 x 4] : v0 - // dst : 1 x 4 x [1] : v16 - ld1 {v0.8b}, [x25], #8 // weight pack*pack*0.5 - ld1 {v4.s}[0], [x24], x15 // src - // int4->int8 - ushr v8.16b, v0.16b, #4 - and v9.16b, v0.16b, v14.16b - zip1 v0.16b, v8.16b, v9.16b - - Unit_TILE_1: - sxtl v5.8h, v4.8b // src batch=0 - sxtl v1.8h, v0.8b // weight oc=0,1 - sxtl2 v2.8h, v0.16b // oc=2,3 - dup v28.2d, v1.d[0] // oc=0,0 - dup v29.2d, v1.d[1] // oc=1,1 - dup v30.2d, v2.d[0] // oc=2,2 - dup v31.2d, v2.d[1] // oc=3,3 - // batch=0 - smlal v10.4s, v5.4h, v28.4h - smlal v11.4s, v5.4h, v29.4h - smlal v12.4s, v5.4h, v30.4h - smlal v13.4s, v5.4h, v31.4h - - //.inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] - - subs x26, x26, #1 - bne LoopSz_TILE_1 - -LoopSzEnd_TILE_1: - // add 4 ic - addp v10.4s, v10.4s, v11.4s - addp v12.4s, v12.4s, v13.4s - addp v16.4s, v10.4s, v12.4s - add x7, x7, x13 - sub x27, x27, #1 - scvtf v16.4s, v16.4s - // using float scale dequant for precison - ld1 {v4.s}[0], [x23] // scales - fmul v16.4s, v16.4s, v4.s[0] -Tile1Dequant: - ld1 {v0.4s}, [x19], #16 // alpha - ld1 
{v1.4s}, [x20], #16 // zero - ld1 {v2.4s}, [x21], #16 // bias - ld1 {v3.s}[0], [x22] // sums - // alpha * sum + (zero * sumx) + bias - fmla v2.4s, v0.4s, v16.4s - fmla v2.4s, v1.4s, v3.s[0] - st1 {v2.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_1 -Tile1End: - subs x6, x6, #1 // batch -= 1 - add x0, x0, #16 // dst += 1 * 4 * sizeof(float32_t) - add x1, x1, #4 // src += 1 * 4 * sizeof(int8_t) - add x11, x11, #4 // sum += 1 * sizeof(float32_t) - add x12, x12, #4 // scale += 1 * sizeof(float32_t) - bne TILE_1 - -End: -ldp x27, x28, [sp, #(16 * 8)] -ldp x25, x26, [sp, #(16 * 7)] -ldp x23, x24, [sp, #(16 * 6)] -ldp x19, x20, [sp, #(16 * 5)] -ldp x21, x22, [sp, #(16 * 4)] -ldp d8, d9, [sp, #(16 * 3)] -ldp d10, d11, [sp, #(16 * 2)] -ldp d12, d13, [sp, #(16 * 1)] -ldp d14, d15, [sp], #(16 * 9) -ret - -#endif \ No newline at end of file diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt4FP32_sdot.S b/source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt4FP32_sdot.S deleted file mode 100644 index 11bf247a5..000000000 --- a/source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt4FP32_sdot.S +++ /dev/null @@ -1,413 +0,0 @@ -// -// MNNGemmHybridInt4_sdot.S -// MNN -// -// Created by MNN on 2023/11/09. -// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef __aarch64__ - -#include "MNNAsmGlobal.h" - -.text -.align 5 - -.macro Int32ToFloat z0, z1, z2, z3 - scvtf \z0\().4s, \z0\().4s - scvtf \z1\().4s, \z1\().4s - scvtf \z2\().4s, \z2\().4s - scvtf \z3\().4s, \z3\().4s -.endm - -.macro MulScale d0, d1, d2, d3, s - fmul \d0\().4s, \d0\().4s, \s\().s[0] - fmul \d1\().4s, \d1\().4s, \s\().s[1] - fmul \d2\().4s, \d2\().4s, \s\().s[2] - fmul \d3\().4s, \d3\().4s, \s\().s[3] -.endm - -.macro Dequant c0, a0, z0, b0, s0, idx - fmul \c0\().4s, \c0\().4s, \a0\().4s - fmla \c0\().4s, \z0\().4s, \s0\().s[\idx] - fadd \c0\().4s, \c0\().4s, \b0\().4s -.endm - -asm_function MNNGemmHybridInt4FP32_sdot - -//struct QuanPostTreatParameters { -// const float* scale; -// const int32_t* bias; -// int32_t maxValue; -// int32_t minValue; -// int32_t useInt8; -//}; - -//void MNNGemmHybridInt4FP32_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param); - - -// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param -// load from param: x7: alpha*, x8: zero*, x9: bias*, x10: sums*, x11: scales* -stp d14, d15, [sp, #(-16 * 9)]! 
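The MNNGemmHybridInt4FP32* kernels being removed share two steps that their macros spell out: weights come two per byte and are unpacked with ushr #4 (high nibble) plus and 0x0F (low nibble), and every int32 accumulator is dequantized as acc * inputScale * alpha + zero * sums + bias (the MulScale and Dequant macros). A scalar C++ sketch of both steps, with illustrative names:

#include <cstdint>

// ushr/and pair: split one packed byte into two 4-bit weights (kept unsigned here,
// exactly as the kernels do before any zero-point correction).
static inline void unpack_int4(uint8_t packed, int32_t* hi, int32_t* lo) {
    *hi = (packed >> 4) & 0x0F;
    *lo =  packed       & 0x0F;
}

// MulScale + Dequant for one accumulator:
// out = acc * inputScale[batch] * alpha[oc] + zero[oc] * sums[batch] + bias[oc]
static inline float dequant_one(int32_t acc, float inputScale, float alpha,
                                float zero, float inputSum, float bias) {
    return (float)acc * inputScale * alpha + zero * inputSum + bias;
}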
-stp d12, d13, [sp, #(16 * 1)] -stp d10, d11, [sp, #(16 * 2)] -stp d8, d9, [sp, #(16 * 3)] -stp x21, x22, [sp, #(16 * 4)] -stp x19, x20, [sp, #(16 * 5)] -stp x23, x24, [sp, #(16 * 6)] -stp x25, x26, [sp, #(16 * 7)] -stp x27, x28, [sp, #(16 * 8)] - -ldr x8, [x7, #0] -ldr x9, [x7, #8] -ldr x10, [x7, #16] -ldr x11, [x7, #24] -ldr x12, [x7, #32] - -Start: -lsl x13, x3, #3 // x13 = src_depth_quad * UNIT * UNIT_SRC / 2(int8) = src_depth_quad * 8 = src_depth_quad << 3 - -TILE_12: - cmp x6, #12 - blt TILE_8 - sub x14, x4, #128 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_12: - // dequant info for batch - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - movi v16.4s, #0 - movi v17.4s, #0 - movi v18.4s, #0 - movi v19.4s, #0 - movi v20.4s, #0 - movi v21.4s, #0 - movi v22.4s, #0 - movi v23.4s, #0 - movi v24.4s, #0 - movi v25.4s, #0 - movi v26.4s, #0 - movi v27.4s, #0 - - // mask - movi v14.16b, #15 -LoopSz_TILE_12: - // src : 4(batch) x [1 x 4] : v4 - // weight : 4(oc) x [1 x 4] : v0 - // dst : 4 x 4 x [1] : v16-v19 - ld1 {v0.8b}, [x25], #8 // weight - ld1 {v4.16b, v5.16b, v6.16b}, [x24], x15 // src - // int4->int8 - ushr v8.16b, v0.16b, #4 - and v9.16b, v0.16b, v14.16b - zip1 v0.16b, v8.16b, v9.16b - .inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] // batch0 - .inst 0x4fa4e011 // sdot v17.4s, v0.16b, v4.4b[1] // batch1 - .inst 0x4f84e812 // sdot v18.4s, v0.16b, v4.4b[2] // batch2 - .inst 0x4fa4e813 // sdot v19.4s, v0.16b, v4.4b[3] // batch3 - .inst 0x4f85e014 // sdot v20.4s, v0.16b, v5.4b[0] // batch4 - .inst 0x4fa5e015 // sdot v21.4s, v0.16b, v5.4b[1] // batch5 - .inst 0x4f85e816 // sdot v22.4s, v0.16b, v5.4b[2] // batch6 - .inst 0x4fa5e817 // sdot v23.4s, v0.16b, v5.4b[3] // batch7 - .inst 0x4f86e018 // sdot v24.4s, v0.16b, v6.4b[0] // batch8 - .inst 0x4fa6e019 // sdot v25.4s, v0.16b, v6.4b[1] // batch9 - .inst 0x4f86e81a // sdot v26.4s, v0.16b, v6.4b[2] // batch10 - .inst 0x4fa6e81b // sdot v27.4s, v0.16b, v6.4b[3] // batch11 - subs x26, x26, #1 - bne LoopSz_TILE_12 - -LoopSzEnd_TILE_12: - add x7, x7, x13 - sub x27, x27, #1 - Int32ToFloat v16, v17, v18, v19 - Int32ToFloat v20, v21, v22, v23 - Int32ToFloat v24, v25, v26, v27 - // using float scale dequant for precison - ld1 {v5.4s, v6.4s, v7.4s}, [x23] // scales, 12 batch,so 12 scale - - MulScale v16, v17, v18, v19, v5 - MulScale v20, v21, v22, v23, v6 - MulScale v24, v25, v26, v27, v7 - -Tile12Dequant: - ld1 {v0.4s}, [x19], #16 // alpha - ld1 {v1.4s}, [x20], #16 // zero - ld1 {v2.4s}, [x21], #16 // bias - ld1 {v3.4s, v4.4s, v5.4s}, [x22] // sums - // alpha * sum + (zero * sums) + bias - Dequant v16, v0, v1, v2, v3, 0 - Dequant v17, v0, v1, v2, v3, 1 - Dequant v18, v0, v1, v2, v3, 2 - Dequant v19, v0, v1, v2, v3, 3 - Dequant v20, v0, v1, v2, v4, 0 - Dequant v21, v0, v1, v2, v4, 1 - Dequant v22, v0, v1, v2, v4, 2 - Dequant v23, v0, v1, v2, v4, 3 - Dequant v24, v0, v1, v2, v5, 0 - Dequant v25, v0, v1, v2, v5, 1 - Dequant v26, v0, v1, v2, v5, 2 - Dequant v27, v0, v1, v2, v5, 3 - st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x28], #64 - st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x28], #64 - st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_12 -Tile12End: - sub x6, x6, #12 // bach -= 12 - add x0, x0, #192 // dst += 12 * 4 * sizeof(float32_t) - add x1, x1, #48 // src 
+= 12 * 4 * sizeof(int8_t) - add x11, x11, #48 // sum += 12 * sizeof(float32_t) - add x12, x12, #48 // scale += 12 * sizeof(float32_t) - b TILE_12 - -TILE_8: - cmp x6, #8 - blt TILE_4 - sub x14, x4, #64 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_8: - // dequant info for batch - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - movi v16.4s, #0 - movi v17.4s, #0 - movi v18.4s, #0 - movi v19.4s, #0 - movi v20.4s, #0 - movi v21.4s, #0 - movi v22.4s, #0 - movi v23.4s, #0 - - // mask - movi v14.16b, #15 -LoopSz_TILE_8: - // src : 4(batch) x [1 x 4] : v4 - // weight : 4(oc) x [1 x 4] : v0 - // dst : 4 x 4 x [1] : v16-v19 - ld1 {v0.8b}, [x25], #8 // weight - ld1 {v4.16b, v5.16b}, [x24], x15 // src - // int4->int8 - ushr v8.16b, v0.16b, #4 - and v9.16b, v0.16b, v14.16b - zip1 v0.16b, v8.16b, v9.16b - .inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] // batch0 - .inst 0x4fa4e011 // sdot v17.4s, v0.16b, v4.4b[1] // batch1 - .inst 0x4f84e812 // sdot v18.4s, v0.16b, v4.4b[2] // batch2 - .inst 0x4fa4e813 // sdot v19.4s, v0.16b, v4.4b[3] // batch3 - .inst 0x4f85e014 // sdot v20.4s, v0.16b, v5.4b[0] // batch4 - .inst 0x4fa5e015 // sdot v21.4s, v0.16b, v5.4b[1] // batch5 - .inst 0x4f85e816 // sdot v22.4s, v0.16b, v5.4b[2] // batch6 - .inst 0x4fa5e817 // sdot v23.4s, v0.16b, v5.4b[3] // batch7 - subs x26, x26, #1 - bne LoopSz_TILE_8 - -LoopSzEnd_TILE_8: - add x7, x7, x13 - sub x27, x27, #1 - Int32ToFloat v16, v17, v18, v19 - Int32ToFloat v20, v21, v22, v23 - // using float scale dequant for precison - ld1 {v5.4s, v6.4s}, [x23] // scales, 8 batch,so 8 scale - - MulScale v16, v17, v18, v19, v5 - MulScale v20, v21, v22, v23, v6 - -Tile8Dequant: - ld1 {v0.4s}, [x19], #16 // alpha - ld1 {v1.4s}, [x20], #16 // zero - ld1 {v2.4s}, [x21], #16 // bias - ld1 {v3.4s, v4.4s}, [x22] // sums - // alpha * sum + (zero * sums) + bias - Dequant v16, v0, v1, v2, v3, 0 - Dequant v17, v0, v1, v2, v3, 1 - Dequant v18, v0, v1, v2, v3, 2 - Dequant v19, v0, v1, v2, v3, 3 - Dequant v20, v0, v1, v2, v4, 0 - Dequant v21, v0, v1, v2, v4, 1 - Dequant v22, v0, v1, v2, v4, 2 - Dequant v23, v0, v1, v2, v4, 3 - st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x28], #64 - st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_8 -Tile8End: - sub x6, x6, #8 // bach -= 4 - add x0, x0, #128 // dst += 8 * 4 * sizeof(float32_t) - add x1, x1, #32 // src += 8 * 4 * sizeof(int8_t) - add x11, x11, #32 // sum += 8 * sizeof(float32_t) - add x12, x12, #32 // scale += 8 * sizeof(float32_t) - b TILE_8 - -TILE_4: - cmp x6, #4 - blt TILE_1 - mov x14, x4 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_4: - // dequant info for batch - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - // mask - movi v14.16b, #15 -LoopSz_TILE_4: - // src : 4(batch) x [1 x 4] : v4 - // weight : 4(oc) x [1 x 4] : v0 - // dst : 4 x 4 x [1] : v16-v19 - ld1 {v0.8b}, [x25], #8 // weight - ld1 {v4.16b}, [x24], x15 // src - // int4->int8 - ushr v8.16b, v0.16b, #4 - 
and v9.16b, v0.16b, v14.16b - zip1 v0.16b, v8.16b, v9.16b - .inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] // batch0 - .inst 0x4fa4e011 // sdot v17.4s, v0.16b, v4.4b[1] // batch1 - .inst 0x4f84e812 // sdot v18.4s, v0.16b, v4.4b[2] // batch2 - .inst 0x4fa4e813 // sdot v19.4s, v0.16b, v4.4b[3] // batch3 - subs x26, x26, #1 - bne LoopSz_TILE_4 - -LoopSzEnd_TILE_4: - add x7, x7, x13 - sub x27, x27, #1 - Int32ToFloat v16, v17, v18, v19 - // Int32ToFloat v20, v21, v22, v23 - // using float scale dequant for precison - ld1 {v5.4s}, [x23] // scales, 4 batch,so 4 scale - - MulScale v16, v17, v18, v19, v5 - -Tile4Dequant: - ld1 {v0.4s}, [x19], #16 // alpha - ld1 {v1.4s}, [x20], #16 // zero - ld1 {v2.4s}, [x21], #16 // bias - ld1 {v3.4s}, [x22] // sums - // alpha * sum + (zero * sums) + bias - Dequant v16, v0, v1, v2, v3, 0 - Dequant v17, v0, v1, v2, v3, 1 - Dequant v18, v0, v1, v2, v3, 2 - Dequant v19, v0, v1, v2, v3, 3 - st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_4 -Tile4End: - sub x6, x6, #4 // bach -= 4 - add x0, x0, #64 // dst += 4 * 4 * sizeof(float32_t) - add x1, x1, #16 // src += 4 * 4 * sizeof(int8_t) - add x11, x11, #16 // sum += 4 * sizeof(float32_t) - add x12, x12, #16 // scale += 4 * sizeof(float32_t) - b TILE_4 - -TILE_1: - cmp x6, #1 - blt End - mov x14, x4 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4, sizeof(float32_t)/4=sizeof(int8_t) - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_1: - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - // mask - movi v14.16b, #15 -LoopSz_TILE_1: - // src : 1(batch) x [1 x 4] : v4 - // weight : 4(oc) x [1 x 4] : v0 - // dst : 1 x 4 x [1] : v16 - ld1 {v0.8b}, [x25], #8 // weight pack*pack*0.5 - ld1 {v4.s}[0], [x24], x15 // src - // int4->int8 - ushr v8.16b, v0.16b, #4 - and v9.16b, v0.16b, v14.16b - zip1 v0.16b, v8.16b, v9.16b - - .inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] - - subs x26, x26, #1 - bne LoopSz_TILE_1 - -LoopSzEnd_TILE_1: - add x7, x7, x13 - sub x27, x27, #1 - scvtf v16.4s, v16.4s - // using float scale dequant for precison - ld1 {v4.s}[0], [x23] // scales - fmul v16.4s, v16.4s, v4.s[0] -Tile1Dequant: - ld1 {v0.4s}, [x19], #16 // alpha - ld1 {v1.4s}, [x20], #16 // zero - ld1 {v2.4s}, [x21], #16 // bias - ld1 {v3.s}[0], [x22] // sums - // alpha * sum + (zero * sumx) + bias - fmla v2.4s, v0.4s, v16.4s - fmla v2.4s, v1.4s, v3.s[0] - st1 {v2.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_1 -Tile1End: - subs x6, x6, #1 // batch -= 1 - add x0, x0, #16 // dst += 1 * 4 * sizeof(float32_t) - add x1, x1, #4 // src += 1 * 4 * sizeof(int8_t) - add x11, x11, #4 // sum += 1 * sizeof(float32_t) - add x12, x12, #4 // scale += 1 * sizeof(float32_t) - bne TILE_1 - -End: -ldp x27, x28, [sp, #(16 * 8)] -ldp x25, x26, [sp, #(16 * 7)] -ldp x23, x24, [sp, #(16 * 6)] -ldp x19, x20, [sp, #(16 * 5)] -ldp x21, x22, [sp, #(16 * 4)] -ldp d8, d9, [sp, #(16 * 3)] -ldp d10, d11, [sp, #(16 * 2)] -ldp d12, d13, [sp, #(16 * 1)] -ldp d14, d15, [sp], #(16 * 9) -ret - -#endif \ No newline at end of file diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt4FP32_smmla.S b/source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt4FP32_smmla.S deleted file mode 100644 index aa0b5a383..000000000 --- a/source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt4FP32_smmla.S +++ 
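The _sdot variant deleted above leans on the Armv8.2 sdot instruction (written as raw .inst words, presumably so older assemblers still accept the file): each sdot multiplies four int8 pairs per lane and accumulates the int32 dot product, replacing the sxtl/smlal/addp chain of the generic kernel. Per output lane it is equivalent to this scalar C++:

#include <cstdint>

// One lane of "sdot vD.4s, vW.16b, vX.4b[b]":
// acc += dot(w[lane*4 .. lane*4+3], x[b*4 .. b*4+3])
static inline int32_t sdot_lane(int32_t acc, const int8_t w[4], const int8_t x[4]) {
    for (int k = 0; k < 4; ++k) {
        acc += (int32_t)w[k] * (int32_t)x[k];
    }
    return acc;
}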
/dev/null @@ -1,476 +0,0 @@ -// -// MNNGemmHybridInt4FP32_smmla.S -// MNN -// -// Created by MNN on 2023/11/09. -// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef __aarch64__ - -#include "MNNAsmGlobal.h" - -.text -.align 5 - -.macro Int32ToFloat z0, z1, z2, z3 - scvtf \z0\().4s, \z0\().4s - scvtf \z1\().4s, \z1\().4s - scvtf \z2\().4s, \z2\().4s - scvtf \z3\().4s, \z3\().4s -.endm - -.macro MulScale d0, d1, d2, d3, s, idx0, idx1 - fmul \d0\().4s, \d0\().4s, \s\().s[\idx0] - fmul \d1\().4s, \d1\().4s, \s\().s[\idx0] - fmul \d2\().4s, \d2\().4s, \s\().s[\idx1] - fmul \d3\().4s, \d3\().4s, \s\().s[\idx1] -.endm - -.macro Dequant c0, a0, z0, b0, s0, idx - fmul \c0\().4s, \c0\().4s, \a0\().4s - fmla \c0\().4s, \z0\().4s, \s0\().s[\idx] - fadd \c0\().4s, \c0\().4s, \b0\().4s -.endm - -asm_function MNNGemmHybridInt4FP32_smmla - -//struct QuanPostTreatParameters { -// const float* scale; -// const int32_t* bias; -// int32_t maxValue; -// int32_t minValue; -// int32_t useInt8; -//}; - -//void MNNGemmHybridInt4FP32_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param); - - -// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param -// load from param: x7: alpha*, x8: zero*, x9: bias*, x10: sums*, x11: scales* -stp d14, d15, [sp, #(-16 * 9)]! -stp d12, d13, [sp, #(16 * 1)] -stp d10, d11, [sp, #(16 * 2)] -stp d8, d9, [sp, #(16 * 3)] -stp x21, x22, [sp, #(16 * 4)] -stp x19, x20, [sp, #(16 * 5)] -stp x23, x24, [sp, #(16 * 6)] -stp x25, x26, [sp, #(16 * 7)] -stp x27, x28, [sp, #(16 * 8)] - -ldr x8, [x7, #0] -ldr x9, [x7, #8] -ldr x10, [x7, #16] -ldr x11, [x7, #24] -ldr x12, [x7, #32] - -Start: -lsl x13, x3, #5 // x13 = src_depth_quad * UNIT * UNIT_SRC / 2(int4) = src_depth_quad * 32 = src_depth_quad << 5 - -TILE_8: - cmp x6, #8 - blt TILE_4 - sub x14, x4, #192 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_8: - // dequant info for batch - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - dup v20.4s, wzr - dup v21.4s, wzr - dup v22.4s, wzr - dup v23.4s, wzr - dup v24.4s, wzr - dup v25.4s, wzr - dup v26.4s, wzr - dup v27.4s, wzr - dup v28.4s, wzr - dup v29.4s, wzr - dup v30.4s, wzr - dup v31.4s, wzr - - // mask - movi v10.16b, #15 -LoopSz_TILE_8: - // src : 2 x [2 x 8] : v4-5 - // weight : 4 x [2 x 8] : v0-3 - // dst : 2 x 4 x [4] : v16-23 - //ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight - ld1 {v8.16b, v9.16b}, [x25], #32 // weight - ld1 {v12.16b, v13.16b, v14.16b, v15.16b}, [x24], x15 // src - // int4 to int8: v0, v1, v2, v3 - ushr v0.16b, v8.16b, #4 - and v1.16b, v8.16b, v10.16b - ushr v2.16b, v9.16b, #4 - and v3.16b, v9.16b, v10.16b - - .inst 0x4e80a590 // smmla v16.4s, v12.16b, v0.16b - .inst 0x4e81a591 // smmla v17.4s, v12.16b, v1.16b - .inst 0x4e82a592 // smmla v18.4s, v12.16b, v2.16b - .inst 0x4e83a593 // smmla v19.4s, v12.16b, v3.16b - .inst 0x4e80a5b4 // smmla v20.4s, v13.16b, v0.16b - .inst 0x4e81a5b5 // smmla v21.4s, v13.16b, v1.16b - .inst 0x4e82a5b6 // smmla v22.4s, v13.16b, v2.16b - .inst 0x4e83a5b7 // smmla v23.4s, v13.16b, v3.16b - .inst 0x4e80a5d8 // smmla v24.4s, v14.16b, v0.16b - 
.inst 0x4e81a5d9 // smmla v25.4s, v14.16b, v1.16b - .inst 0x4e82a5da // smmla v26.4s, v14.16b, v2.16b - .inst 0x4e83a5db // smmla v27.4s, v14.16b, v3.16b - .inst 0x4e80a5fc // smmla v28.4s, v15.16b, v0.16b - .inst 0x4e81a5fd // smmla v29.4s, v15.16b, v1.16b - .inst 0x4e82a5fe // smmla v30.4s, v15.16b, v2.16b - .inst 0x4e83a5ff // smmla v31.4s, v15.16b, v3.16b - subs x26, x26, #1 - bne LoopSz_TILE_8 - -LoopSzEnd_TILE_8: - add x7, x7, x13 - sub x27, x27, #1 - - trn1 v0.2d, v16.2d, v17.2d // batch:0 oc:0-3 - trn1 v1.2d, v18.2d, v19.2d // batch:0 oc:4-7 - trn2 v2.2d, v16.2d, v17.2d // batch:1 oc:0-3 - trn2 v3.2d, v18.2d, v19.2d // batch:1 oc:4-7 - trn1 v4.2d, v20.2d, v21.2d // batch:2 oc:0-3 - trn1 v5.2d, v22.2d, v23.2d // batch:2 oc:4-7 - trn2 v6.2d, v20.2d, v21.2d // batch:3 oc:0-3 - trn2 v7.2d, v22.2d, v23.2d // batch:3 oc:4-7 - - trn1 v8.2d, v24.2d, v25.2d // batch:0 oc:0-3 - trn1 v9.2d, v26.2d, v27.2d // batch:0 oc:4-7 - trn2 v10.2d, v24.2d, v25.2d // batch:1 oc:0-3 - trn2 v11.2d, v26.2d, v27.2d // batch:1 oc:4-7 - trn1 v12.2d, v28.2d, v29.2d // batch:2 oc:0-3 - trn1 v13.2d, v30.2d, v31.2d // batch:2 oc:4-7 - trn2 v14.2d, v28.2d, v29.2d // batch:3 oc:0-3 - trn2 v15.2d, v30.2d, v31.2d // batch:3 oc:4-7 - - Int32ToFloat v0, v1, v2, v3 - Int32ToFloat v4, v5, v6, v7 - Int32ToFloat v8, v9, v10, v11 - Int32ToFloat v12, v13, v14, v15 - // using float scale dequant for precison - ld1 {v16.4s, v17.4s}, [x23] // scales - MulScale v0, v1, v2, v3, v16, 0, 1 - MulScale v4, v5, v6, v7, v16, 2, 3 - MulScale v8, v9, v10, v11, v17, 0, 1 - MulScale v12, v13, v14, v15, v17, 2, 3 -Tile8Dequant: - ld1 {v18.4s, v19.4s}, [x19], #32 // alpha - ld1 {v20.4s, v21.4s}, [x20], #32 // zero - ld1 {v22.4s, v23.4s}, [x21], #32 // bias - ld1 {v24.4s, v25.4s}, [x22] // sums - // alpha * cusum + (zero * sums) + bias - Dequant v0, v18, v20, v22, v24, 0 // Batch0 - Dequant v1, v19, v21, v23, v24, 0 - Dequant v2, v18, v20, v22, v24, 1 // Batch1 - Dequant v3, v19, v21, v23, v24, 1 - Dequant v4, v18, v20, v22, v24, 2 // Batch2 - Dequant v5, v19, v21, v23, v24, 2 - Dequant v6, v18, v20, v22, v24, 3 // Batch3 - Dequant v7, v19, v21, v23, v24, 3 - Dequant v8, v18, v20, v22, v25, 0 // Batch4 - Dequant v9, v19, v21, v23, v25, 0 - Dequant v10, v18, v20, v22, v25, 1 // Batch5 - Dequant v11, v19, v21, v23, v25, 1 - Dequant v12, v18, v20, v22, v25, 2 // Batch6 - Dequant v13, v19, v21, v23, v25, 2 - Dequant v14, v18, v20, v22, v25, 3 // Batch7 - Dequant v15, v19, v21, v23, v25, 3 - st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x28], #64 - st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x28], #64 - st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x28], #64 - st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_8 -Tile8End: - sub x6, x6, #8 // bach -= 8 - add x0, x0, #256 // dst += 8 * 8 * sizeof(float32_t) - add x1, x1, #64 // src += 8 * 8 * sizeof(int8_t) - add x11, x11, #32 // sum += 8 * sizeof(float32_t) - add x12, x12, #32 // scale += 8 * sizeof(float32_t) - b TILE_8 - -TILE_4: - cmp x6, #4 - blt TILE_2 - mov x14, x4 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4 - sub x14, x14, #64 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_4: - // dequant info for batch - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - dup v20.4s, wzr - dup v21.4s, wzr - dup 
v22.4s, wzr - dup v23.4s, wzr - // mask - movi v10.16b, #15 -LoopSz_TILE_4: - // src : 2 x [2 x 8] : v4-5 - // weight : 4 x [2 x 8] : v0-3 - // dst : 2 x 4 x [4] : v16-23 - //ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight - ld1 {v8.16b, v9.16b}, [x25], #32 // weight - // int4 to int8: v0, v1, v2, v3 - ushr v0.16b, v8.16b, #4 - and v1.16b, v8.16b, v10.16b - ushr v2.16b, v9.16b, #4 - and v3.16b, v9.16b, v10.16b - ld1 {v4.16b, v5.16b}, [x24], x15 // src - .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b - .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b - .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b - .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b - .inst 0x4e80a4b4 // smmla v20.4s, v5.16b, v0.16b - .inst 0x4e81a4b5 // smmla v21.4s, v5.16b, v1.16b - .inst 0x4e82a4b6 // smmla v22.4s, v5.16b, v2.16b - .inst 0x4e83a4b7 // smmla v23.4s, v5.16b, v3.16b - subs x26, x26, #1 - bne LoopSz_TILE_4 - -LoopSzEnd_TILE_4: - add x7, x7, x13 - sub x27, x27, #1 - - trn1 v24.2d, v16.2d, v17.2d // batch:0 oc:0-3 - trn1 v25.2d, v18.2d, v19.2d // batch:0 oc:4-7 - trn2 v26.2d, v16.2d, v17.2d // batch:1 oc:0-3 - trn2 v27.2d, v18.2d, v19.2d // batch:1 oc:4-7 - trn1 v28.2d, v20.2d, v21.2d // batch:2 oc:0-3 - trn1 v29.2d, v22.2d, v23.2d // batch:2 oc:4-7 - trn2 v30.2d, v20.2d, v21.2d // batch:3 oc:0-3 - trn2 v31.2d, v22.2d, v23.2d // batch:3 oc:4-7 - Int32ToFloat v24, v25, v26, v27 - Int32ToFloat v28, v29, v30, v31 - // using float scale dequant for precison - ld1 {v5.4s}, [x23] // scales - MulScale v24, v25, v26, v27, v5, 0, 1 - MulScale v28, v29, v30, v31, v5, 2, 3 -Tile4Dequant: - ld1 {v0.4s, v1.4s}, [x19], #32 // alpha - ld1 {v2.4s, v3.4s}, [x20], #32 // zero - ld1 {v8.4s, v9.4s}, [x21], #32 // bias - ld1 {v6.4s}, [x22] // sums - // alpha * cusum + (zero * sums) + bias - Dequant v24, v0, v2, v8, v6, 0 // Batch0 - Dequant v25, v1, v3, v9, v6, 0 - Dequant v26, v0, v2, v8, v6, 1 // Batch1 - Dequant v27, v1, v3, v9, v6, 1 - Dequant v28, v0, v2, v8, v6, 2 // Batch2 - Dequant v29, v1, v3, v9, v6, 2 - Dequant v30, v0, v2, v8, v6, 3 // Batch3 - Dequant v31, v1, v3, v9, v6, 3 - st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x28], #64 - st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_4 -Tile4End: - sub x6, x6, #4 // bach -= 4 - add x0, x0, #128 // dst += 4 * 8 * sizeof(float32_t) - add x1, x1, #32 // src += 4 * 8 * sizeof(int8_t) - add x11, x11, #16 // sum += 4 * sizeof(float32_t) - add x12, x12, #16 // scale += 4 * sizeof(float32_t) - b TILE_4 - -TILE_2: - cmp x6, #2 - blt TILE_1 - mov x14, x4 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_2: - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - // mask - movi v14.16b, #15 -LoopSz_TILE_2: - // src : 1 x [2 x 8] : v4 - // weight : 4 x [2 x 8] : v0-3 - // dst : 1 x 4 x [4] : v16-19 - //ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight - ld1 {v8.16b, v9.16b}, [x25], #32 // weight - // int4 to int8: v0, v1, v2, v3 - ushr v0.16b, v8.16b, #4 - and v1.16b, v8.16b, v14.16b - ushr v2.16b, v9.16b, #4 - and v3.16b, v9.16b, v14.16b - ld1 {v4.16b}, [x24], x15 // src - .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b - .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b - .inst 0x4e82a492 // smmla v18.4s, 
v4.16b, v2.16b - .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b - subs x26, x26, #1 - bne LoopSz_TILE_2 - -LoopSzEnd_TILE_2: - add x7, x7, x13 - sub x27, x27, #1 - trn1 v20.2d, v16.2d, v17.2d - trn1 v21.2d, v18.2d, v19.2d - trn2 v22.2d, v16.2d, v17.2d - trn2 v23.2d, v18.2d, v19.2d - Int32ToFloat v20, v21, v22, v23 - // using float scale dequant for precison - ld1 {v5.d}[0], [x23] // scales - fmul v20.4s, v20.4s, v5.s[0] - fmul v21.4s, v21.4s, v5.s[0] - fmul v22.4s, v22.4s, v5.s[1] - fmul v23.4s, v23.4s, v5.s[1] -Tile2Dequant: - ld1 {v0.4s, v1.4s}, [x19], #32 // alpha - ld1 {v2.4s, v3.4s}, [x20], #32 // zero - ld1 {v8.4s, v9.4s}, [x21], #32 // bias - ld1 {v10.d}[0], [x22] // sums - // alpha * sum + (zero * sumx) + bias - Dequant v20, v0, v2, v8, v10, 0 - Dequant v21, v1, v3, v9, v10, 0 - Dequant v22, v0, v2, v8, v10, 1 - Dequant v23, v1, v3, v9, v10, 1 - st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_2 -Tile2End: - sub x6, x6, #2 // batch -= 2 - add x0, x0, #64 // dst += 2 * 8 * sizeof(float32_t) - add x1, x1, #16 // dst += 2 * 8 * sizeof(int8_t) - add x11, x11, #8 // sum += 2 * sizeof(float32_t) - add x12, x12, #8 // scale += 2 * sizeof(float32_t) - b TILE_2 - -TILE_1: - cmp x6, #1 - blt End - mov x14, x4 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4, sizeof(float32_t)/4=sizeof(int8_t) - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_1: - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - // mask - movi v14.16b, #15 - -LoopSz_TILE_1: - // src : 1 x [1 x 8] : v4 - // weight : 4 x [2 x 8] : v0-3 - // dst : 1 x 4 x [2] : v16-v19 - //ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight - ld1 {v8.16b, v9.16b}, [x25], #32 // weight - // int4 to int8: v0, v1, v2, v3 - ushr v0.16b, v8.16b, #4 - and v1.16b, v8.16b, v14.16b - ushr v2.16b, v9.16b, #4 - and v3.16b, v9.16b, v14.16b - ld1 {v4.8b}, [x24], x15 // src - .inst 0x4e84a410 // smmla v16.4s, v0.16b, v4.16b - .inst 0x4e84a431 // smmla v17.4s, v1.16b, v4.16b - .inst 0x4e84a452 // smmla v18.4s, v2.16b, v4.16b - .inst 0x4e84a473 // smmla v19.4s, v3.16b, v4.16b - - subs x26, x26, #1 - bne LoopSz_TILE_1 - -LoopSzEnd_TILE_1: - add x7, x7, x13 - sub x27, x27, #1 - uzp1 v20.4s, v16.4s, v17.4s - uzp1 v21.4s, v18.4s, v19.4s - scvtf v20.4s, v20.4s - scvtf v21.4s, v21.4s - // using float scale dequant for precison - ld1 {v4.s}[0], [x23] // scales - fmul v20.4s, v20.4s, v4.s[0] - fmul v21.4s, v21.4s, v4.s[0] -Tile1Dequant: - ld1 {v0.4s, v1.4s}, [x19], #32 // alpha - ld1 {v2.4s, v3.4s}, [x20], #32 // zero - ld1 {v12.4s, v13.4s}, [x21], #32 // bias - ld1 {v6.s}[0], [x22] // sums - // alpha * sum + (zero * sumx) + bias - fmla v12.4s, v20.4s, v0.4s - fmla v13.4s, v21.4s, v1.4s - fmla v12.4s, v2.4s, v6.s[0] - fmla v13.4s, v3.4s, v6.s[0] - st1 {v12.4s, v13.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_1 -Tile1End: - sub x6, x6, #1 // batch -= 1 - add x0, x0, #32 // dst += 1 * 8 * sizeof(float32_t) - add x1, x1, #8 // dst += 1 * 8 * sizeof(int8_t) - add x11, x11, #4 // sum += 1 * sizeof(float32_t) - add x12, x12, #4 // scale += 1 * sizeof(float32_t) - b TILE_1 - -End: -ldp x27, x28, [sp, #(16 * 8)] -ldp x25, x26, [sp, #(16 * 7)] -ldp x23, x24, [sp, #(16 * 6)] -ldp x19, x20, [sp, #(16 * 5)] -ldp x21, x22, [sp, #(16 * 4)] -ldp d8, d9, 
[sp, #(16 * 3)] -ldp d10, d11, [sp, #(16 * 2)] -ldp d12, d13, [sp, #(16 * 1)] -ldp d14, d15, [sp], #(16 * 9) -ret - -#endif \ No newline at end of file diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt8FP32.S b/source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt8FP32.S deleted file mode 100644 index 418638fce..000000000 --- a/source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt8FP32.S +++ /dev/null @@ -1,293 +0,0 @@ -// -// MNNGemmHybridInt4_sdot.S -// MNN -// -// Created by MNN on 2023/11/09. -// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef __aarch64__ - -#include "MNNAsmGlobal.h" - -.text -.align 5 - -.macro Int32ToFloat z0, z1, z2, z3 - scvtf \z0\().4s, \z0\().4s - scvtf \z1\().4s, \z1\().4s - scvtf \z2\().4s, \z2\().4s - scvtf \z3\().4s, \z3\().4s -.endm - -.macro MulScale d0, d1, d2, d3, s - fmul \d0\().4s, \d0\().4s, \s\().s[0] - fmul \d1\().4s, \d1\().4s, \s\().s[1] - fmul \d2\().4s, \d2\().4s, \s\().s[2] - fmul \d3\().4s, \d3\().4s, \s\().s[3] -.endm - -.macro Dequant c0, a0, z0, b0, s0, idx - fmul \c0\().4s, \c0\().4s, \a0\().4s - fmla \c0\().4s, \z0\().4s, \s0\().s[\idx] - fadd \c0\().4s, \c0\().4s, \b0\().4s -.endm - -asm_function MNNGemmHybridInt8FP32 - -//struct QuanPostTreatParameters { -// const float* scale; -// const int32_t* bias; -// int32_t maxValue; -// int32_t minValue; -// int32_t useInt8; -//}; - -//void MNNGemmHybridInt8FP32(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param); - - -// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param -// load from param: x7: alpha*, x8: zero*, x9: bias*, x10: sums*, x11: scales* -stp d14, d15, [sp, #(-16 * 9)]! 
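The _smmla variant deleted just above uses the int8 matrix multiply instruction instead: each smmla treats its two 16-byte operands as 2x8 int8 matrices and accumulates their 2x2 int32 product into one vector register, which is why the trn1/trn2 blocks are needed afterwards to regroup results by batch. One smmla accumulation in scalar C++:

#include <cstdint>

// smmla vD.4s, vA.16b, vB.16b:
// D, viewed as a row-major 2x2 int32 block, += A (2x8 int8) * transpose(B) (B is also 2x8 int8)
static void smmla_ref(int32_t d[4], const int8_t a[16], const int8_t b[16]) {
    for (int i = 0; i < 2; ++i) {
        for (int j = 0; j < 2; ++j) {
            int32_t acc = d[i * 2 + j];
            for (int k = 0; k < 8; ++k) {
                acc += (int32_t)a[i * 8 + k] * (int32_t)b[j * 8 + k];
            }
            d[i * 2 + j] = acc;
        }
    }
}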
-stp d12, d13, [sp, #(16 * 1)] -stp d10, d11, [sp, #(16 * 2)] -stp d8, d9, [sp, #(16 * 3)] -stp x21, x22, [sp, #(16 * 4)] -stp x19, x20, [sp, #(16 * 5)] -stp x23, x24, [sp, #(16 * 6)] -stp x25, x26, [sp, #(16 * 7)] -stp x27, x28, [sp, #(16 * 8)] - -ldr x8, [x7, #0] -ldr x9, [x7, #8] -ldr x10, [x7, #16] -ldr x11, [x7, #24] -ldr x12, [x7, #32] - -Start: -lsl x13, x3, #4 // x13 = src_depth_quad * UNIT * UNIT_SRC / 1(int8) = src_depth_quad * 16 = src_depth_quad << 4 - -TILE_4: - cmp x6, #4 - blt TILE_1 - mov x14, x4 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_4: - // dequant info for batch - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - // batch=0,oc=0-3 - movi v10.4s, #0 //ic=0-3 - movi v11.4s, #0 - movi v12.4s, #0 - movi v13.4s, #0 - // batch=1,oc=0-3 - movi v16.4s, #0 - movi v17.4s, #0 - movi v18.4s, #0 - movi v19.4s, #0 - // batch=2,oc=0-3 - movi v20.4s, #0 - movi v21.4s, #0 - movi v22.4s, #0 - movi v23.4s, #0 - // batch=3,oc=0-3 - movi v24.4s, #0 - movi v25.4s, #0 - movi v26.4s, #0 - movi v27.4s, #0 - -LoopSz_TILE_4: - // src : 4(batch) x [1 x 4] : v4 - // weight : 4(oc) x [1 x 4] : v0 - ld1 {v0.16b}, [x25], #16 // weight - ld1 {v4.16b}, [x24], x15 // src - - Unit_TILE_4: - sxtl v5.8h, v4.8b // src batch=0,1 - sxtl2 v6.8h, v4.16b // batch=2,3 - sxtl v1.8h, v0.8b // weight oc=0,1 - sxtl2 v2.8h, v0.16b // oc=2,3 - dup v28.2d, v1.d[0] // oc=0,0 - dup v29.2d, v1.d[1] // oc=1,1 - dup v30.2d, v2.d[0] // oc=2,2 - dup v31.2d, v2.d[1] // oc=3,3 - // batch=0 - smlal v10.4s, v5.4h, v28.4h - smlal v11.4s, v5.4h, v29.4h - smlal v12.4s, v5.4h, v30.4h - smlal v13.4s, v5.4h, v31.4h - // batch=1 - smlal2 v16.4s, v5.8h, v28.8h - smlal2 v17.4s, v5.8h, v29.8h - smlal2 v18.4s, v5.8h, v30.8h - smlal2 v19.4s, v5.8h, v31.8h - // batch=2 - smlal v20.4s, v6.4h, v28.4h - smlal v21.4s, v6.4h, v29.4h - smlal v22.4s, v6.4h, v30.4h - smlal v23.4s, v6.4h, v31.4h - // batch=3 - smlal2 v24.4s, v6.8h, v28.8h - smlal2 v25.4s, v6.8h, v29.8h - smlal2 v26.4s, v6.8h, v30.8h - smlal2 v27.4s, v6.8h, v31.8h - // .inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] // batch0 - // .inst 0x4fa4e011 // sdot v17.4s, v0.16b, v4.4b[1] // batch1 - // .inst 0x4f84e812 // sdot v18.4s, v0.16b, v4.4b[2] // batch2 - // .inst 0x4fa4e813 // sdot v19.4s, v0.16b, v4.4b[3] // batch3 - - subs x26, x26, #1 - bne LoopSz_TILE_4 - -LoopSzEnd_TILE_4: - // add 4 ic - addp v10.4s, v10.4s, v11.4s - addp v12.4s, v12.4s, v13.4s - addp v16.4s, v16.4s, v17.4s - addp v18.4s, v18.4s, v19.4s - addp v20.4s, v20.4s, v21.4s - addp v22.4s, v22.4s, v23.4s - addp v24.4s, v24.4s, v25.4s - addp v26.4s, v26.4s, v27.4s - - addp v10.4s, v10.4s, v12.4s // batch=0,oc=0-3 - addp v11.4s, v16.4s, v18.4s - addp v12.4s, v20.4s, v22.4s - addp v13.4s, v24.4s, v26.4s - - add x7, x7, x13 - sub x27, x27, #1 - Int32ToFloat v10, v11, v12, v13 - // Int32ToFloat v20, v21, v22, v23 - // using float scale dequant for precison - ld1 {v5.4s}, [x23] // scales, 4 batch,so 4 scale - - MulScale v10, v11, v12, v13, v5 - -Tile4Dequant: - ld1 {v0.4s}, [x19], #16 // alpha - ld1 {v1.4s}, [x20], #16 // zero - ld1 {v2.4s}, [x21], #16 // bias - ld1 {v3.4s}, [x22] // sums - // alpha * sum + (zero * sums) + bias - Dequant v10, v0, v1, v2, v3, 0 - Dequant v11, v0, v1, v2, v3, 1 - Dequant v12, v0, v1, v2, v3, 2 - Dequant v13, v0, v1, v2, v3, 
3 - st1 {v10.4s, v11.4s, v12.4s, v13.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_4 -Tile4End: - sub x6, x6, #4 // bach -= 4 - add x0, x0, #64 // dst += 4 * 4 * sizeof(float32_t) - add x1, x1, #16 // src += 4 * 4 * sizeof(int8_t) - add x11, x11, #16 // sum += 4 * sizeof(float32_t) - add x12, x12, #16 // scale += 4 * sizeof(float32_t) - b TILE_4 - -TILE_1: - cmp x6, #1 - blt End - mov x14, x4 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4, sizeof(float32_t)/4=sizeof(int8_t) - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_1: - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - // batch=0,oc=0-3 - movi v10.4s, #0 //ic=0-3 - movi v11.4s, #0 - movi v12.4s, #0 - movi v13.4s, #0 - -LoopSz_TILE_1: - // src : 1(batch) x [1 x 4] : v4 - // weight : 4(oc) x [1 x 4] : v0 - // dst : 1 x 4 x [1] : v16 - ld1 {v0.16b}, [x25], #16 // weight pack*pack - ld1 {v4.s}[0], [x24], x15 // src - - Unit_TILE_1: - sxtl v5.8h, v4.8b // src batch=0 - sxtl v1.8h, v0.8b // weight oc=0,1 - sxtl2 v2.8h, v0.16b // oc=2,3 - dup v28.2d, v1.d[0] // oc=0,0 - dup v29.2d, v1.d[1] // oc=1,1 - dup v30.2d, v2.d[0] // oc=2,2 - dup v31.2d, v2.d[1] // oc=3,3 - // batch=0 - smlal v10.4s, v5.4h, v28.4h - smlal v11.4s, v5.4h, v29.4h - smlal v12.4s, v5.4h, v30.4h - smlal v13.4s, v5.4h, v31.4h - - //.inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] - - subs x26, x26, #1 - bne LoopSz_TILE_1 - -LoopSzEnd_TILE_1: - // add 4 ic - addp v10.4s, v10.4s, v11.4s - addp v12.4s, v12.4s, v13.4s - addp v16.4s, v10.4s, v12.4s - add x7, x7, x13 - sub x27, x27, #1 - scvtf v16.4s, v16.4s - // using float scale dequant for precison - ld1 {v4.s}[0], [x23] // scales - fmul v16.4s, v16.4s, v4.s[0] -Tile1Dequant: - ld1 {v0.4s}, [x19], #16 // alpha - ld1 {v1.4s}, [x20], #16 // zero - ld1 {v2.4s}, [x21], #16 // bias - ld1 {v3.s}[0], [x22] // sums - // alpha * sum + (zero * sumx) + bias - fmla v2.4s, v0.4s, v16.4s - fmla v2.4s, v1.4s, v3.s[0] - st1 {v2.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_1 -Tile1End: - subs x6, x6, #1 // batch -= 1 - add x0, x0, #16 // dst += 1 * 4 * sizeof(float32_t) - add x1, x1, #4 // src += 1 * 4 * sizeof(int8_t) - add x11, x11, #4 // sum += 1 * sizeof(float32_t) - add x12, x12, #4 // scale += 1 * sizeof(float32_t) - bne TILE_1 - -End: -ldp x27, x28, [sp, #(16 * 8)] -ldp x25, x26, [sp, #(16 * 7)] -ldp x23, x24, [sp, #(16 * 6)] -ldp x19, x20, [sp, #(16 * 5)] -ldp x21, x22, [sp, #(16 * 4)] -ldp d8, d9, [sp, #(16 * 3)] -ldp d10, d11, [sp, #(16 * 2)] -ldp d12, d13, [sp, #(16 * 1)] -ldp d14, d15, [sp], #(16 * 9) -ret - -#endif \ No newline at end of file diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt8FP32_sdot.S b/source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt8FP32_sdot.S deleted file mode 100644 index dd14f71d5..000000000 --- a/source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt8FP32_sdot.S +++ /dev/null @@ -1,396 +0,0 @@ -// -// MNNGemmHybridInt8_smmla.S -// MNN -// -// Created by MNN on 2023/11/09. 
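The plain MNNGemmHybridInt8FP32 kernel removed above is the fallback for cores without dot-product support: it widens int8 to int16 with sxtl, multiplies and accumulates with smlal/smlal2 into one vector per (batch, output-channel) pair, and finally collapses the four per-input-channel partial sums with chains of addp. A scalar view of that reduction for a single output (the indexing is illustrative; the real kernel's packing interleaves four output channels per 16-byte weight block):

#include <cstdint>

static int32_t hybrid_int8_one(const int8_t* src, const int8_t* weight, int srcDepthQuad) {
    int32_t lanes[4] = {0, 0, 0, 0};                     // what one accumulator register holds
    for (int z = 0; z < srcDepthQuad; ++z) {
        for (int k = 0; k < 4; ++k) {                    // 4 input channels per block
            lanes[k] += (int32_t)src[z * 4 + k] * (int32_t)weight[z * 4 + k];
        }
    }
    return lanes[0] + lanes[1] + lanes[2] + lanes[3];    // the addp chain
}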
-// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef __aarch64__ - -#include "MNNAsmGlobal.h" - -.text -.align 5 - -.macro Int32ToFloat z0, z1, z2, z3 - scvtf \z0\().4s, \z0\().4s - scvtf \z1\().4s, \z1\().4s - scvtf \z2\().4s, \z2\().4s - scvtf \z3\().4s, \z3\().4s -.endm - -.macro MulScale d0, d1, d2, d3, s - fmul \d0\().4s, \d0\().4s, \s\().s[0] - fmul \d1\().4s, \d1\().4s, \s\().s[1] - fmul \d2\().4s, \d2\().4s, \s\().s[2] - fmul \d3\().4s, \d3\().4s, \s\().s[3] -.endm - -.macro Dequant c0, a0, z0, b0, s0, idx - fmul \c0\().4s, \c0\().4s, \a0\().4s - fmla \c0\().4s, \z0\().4s, \s0\().s[\idx] - fadd \c0\().4s, \c0\().4s, \b0\().4s -.endm - -asm_function MNNGemmHybridInt8FP32_sdot - -//struct QuanPostTreatParameters { -// const float* scale; -// const int32_t* bias; -// int32_t maxValue; -// int32_t minValue; -// int32_t useInt8; -//}; - -//void MNNGemmHybridInt8FP32_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param); - - -// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param -// load from param: x7: alpha*, x8: zero*, x9: bias*, x10: sums*, x11: scales* -stp d14, d15, [sp, #(-16 * 9)]! -stp d12, d13, [sp, #(16 * 1)] -stp d10, d11, [sp, #(16 * 2)] -stp d8, d9, [sp, #(16 * 3)] -stp x21, x22, [sp, #(16 * 4)] -stp x19, x20, [sp, #(16 * 5)] -stp x23, x24, [sp, #(16 * 6)] -stp x25, x26, [sp, #(16 * 7)] -stp x27, x28, [sp, #(16 * 8)] - -ldr x8, [x7, #0] -ldr x9, [x7, #8] -ldr x10, [x7, #16] -ldr x11, [x7, #24] -ldr x12, [x7, #32] - -Start: -lsl x13, x3, #4 // x13 = src_depth_quad * UNIT * UNIT_SRC / 1(int8) = src_depth_quad * 16 = src_depth_quad << 4 - -TILE_12: - cmp x6, #12 - blt TILE_8 - sub x14, x4, #128 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_12: - // dequant info for batch - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - movi v16.4s, #0 - movi v17.4s, #0 - movi v18.4s, #0 - movi v19.4s, #0 - movi v20.4s, #0 - movi v21.4s, #0 - movi v22.4s, #0 - movi v23.4s, #0 - movi v24.4s, #0 - movi v25.4s, #0 - movi v26.4s, #0 - movi v27.4s, #0 - -LoopSz_TILE_12: - // src : 4(batch) x [1 x 4] : v4 - // weight : 4(oc) x [1 x 4] : v0 - // dst : 4 x 4 x [1] : v16-v19 - ld1 {v0.16b}, [x25], #16 // weight - ld1 {v4.16b, v5.16b, v6.16b}, [x24], x15 // src - - .inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] // batch0 - .inst 0x4fa4e011 // sdot v17.4s, v0.16b, v4.4b[1] // batch1 - .inst 0x4f84e812 // sdot v18.4s, v0.16b, v4.4b[2] // batch2 - .inst 0x4fa4e813 // sdot v19.4s, v0.16b, v4.4b[3] // batch3 - .inst 0x4f85e014 // sdot v20.4s, v0.16b, v5.4b[0] // batch4 - .inst 0x4fa5e015 // sdot v21.4s, v0.16b, v5.4b[1] // batch5 - .inst 0x4f85e816 // sdot v22.4s, v0.16b, v5.4b[2] // batch6 - .inst 0x4fa5e817 // sdot v23.4s, v0.16b, v5.4b[3] // batch7 - .inst 0x4f86e018 // sdot v24.4s, v0.16b, v6.4b[0] // batch8 - .inst 0x4fa6e019 // sdot v25.4s, v0.16b, v6.4b[1] // batch9 - .inst 0x4f86e81a // sdot v26.4s, v0.16b, v6.4b[2] // batch10 - .inst 0x4fa6e81b // sdot v27.4s, v0.16b, v6.4b[3] // batch11 - subs x26, x26, #1 - bne LoopSz_TILE_12 - -LoopSzEnd_TILE_12: - add x7, x7, x13 - sub x27, x27, #1 - Int32ToFloat v16, v17, v18, v19 - Int32ToFloat v20, v21, v22, v23 - Int32ToFloat v24, 
v25, v26, v27 - // using float scale dequant for precison - ld1 {v5.4s, v6.4s, v7.4s}, [x23] // scales, 12 batch,so 12 scale - - MulScale v16, v17, v18, v19, v5 - MulScale v20, v21, v22, v23, v6 - MulScale v24, v25, v26, v27, v7 - -Tile12Dequant: - ld1 {v0.4s}, [x19], #16 // alpha - ld1 {v1.4s}, [x20], #16 // zero - ld1 {v2.4s}, [x21], #16 // bias - ld1 {v3.4s, v4.4s, v5.4s}, [x22] // sums - // alpha * sum + (zero * sums) + bias - Dequant v16, v0, v1, v2, v3, 0 - Dequant v17, v0, v1, v2, v3, 1 - Dequant v18, v0, v1, v2, v3, 2 - Dequant v19, v0, v1, v2, v3, 3 - Dequant v20, v0, v1, v2, v4, 0 - Dequant v21, v0, v1, v2, v4, 1 - Dequant v22, v0, v1, v2, v4, 2 - Dequant v23, v0, v1, v2, v4, 3 - Dequant v24, v0, v1, v2, v5, 0 - Dequant v25, v0, v1, v2, v5, 1 - Dequant v26, v0, v1, v2, v5, 2 - Dequant v27, v0, v1, v2, v5, 3 - st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x28], #64 - st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x28], #64 - st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_12 -Tile12End: - sub x6, x6, #12 // bach -= 12 - add x0, x0, #192 // dst += 12 * 4 * sizeof(float32_t) - add x1, x1, #48 // src += 12 * 4 * sizeof(int8_t) - add x11, x11, #48 // sum += 12 * sizeof(float32_t) - add x12, x12, #48 // scale += 12 * sizeof(float32_t) - b TILE_12 - -TILE_8: - cmp x6, #8 - blt TILE_4 - sub x14, x4, #64 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_8: - // dequant info for batch - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - movi v16.4s, #0 - movi v17.4s, #0 - movi v18.4s, #0 - movi v19.4s, #0 - movi v20.4s, #0 - movi v21.4s, #0 - movi v22.4s, #0 - movi v23.4s, #0 - - // mask - movi v14.16b, #15 - // offset - movi v15.16b, #8 -LoopSz_TILE_8: - // src : 4(batch) x [1 x 4] : v4 - // weight : 4(oc) x [1 x 4] : v0 - // dst : 4 x 4 x [1] : v16-v19 - ld1 {v0.16b}, [x25], #16 // weight - ld1 {v4.16b, v5.16b}, [x24], x15 // src - - .inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] // batch0 - .inst 0x4fa4e011 // sdot v17.4s, v0.16b, v4.4b[1] // batch1 - .inst 0x4f84e812 // sdot v18.4s, v0.16b, v4.4b[2] // batch2 - .inst 0x4fa4e813 // sdot v19.4s, v0.16b, v4.4b[3] // batch3 - .inst 0x4f85e014 // sdot v20.4s, v0.16b, v5.4b[0] // batch4 - .inst 0x4fa5e015 // sdot v21.4s, v0.16b, v5.4b[1] // batch5 - .inst 0x4f85e816 // sdot v22.4s, v0.16b, v5.4b[2] // batch6 - .inst 0x4fa5e817 // sdot v23.4s, v0.16b, v5.4b[3] // batch7 - subs x26, x26, #1 - bne LoopSz_TILE_8 - -LoopSzEnd_TILE_8: - add x7, x7, x13 - sub x27, x27, #1 - Int32ToFloat v16, v17, v18, v19 - Int32ToFloat v20, v21, v22, v23 - // using float scale dequant for precison - ld1 {v5.4s, v6.4s}, [x23] // scales, 8 batch,so 8 scale - - MulScale v16, v17, v18, v19, v5 - MulScale v20, v21, v22, v23, v6 - -Tile8Dequant: - ld1 {v0.4s}, [x19], #16 // alpha - ld1 {v1.4s}, [x20], #16 // zero - ld1 {v2.4s}, [x21], #16 // bias - ld1 {v3.4s, v4.4s}, [x22] // sums - // alpha * sum + (zero * sums) + bias - Dequant v16, v0, v1, v2, v3, 0 - Dequant v17, v0, v1, v2, v3, 1 - Dequant v18, v0, v1, v2, v3, 2 - Dequant v19, v0, v1, v2, v3, 3 - Dequant v20, v0, v1, v2, v4, 0 - Dequant v21, v0, v1, v2, v4, 1 - Dequant v22, v0, v1, v2, v4, 2 - Dequant v23, v0, v1, v2, v4, 3 - st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x28], #64 - st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x28], x14 - cmp 
x27, #1 - bge LoopDz_TILE_8 -Tile8End: - sub x6, x6, #8 // bach -= 4 - add x0, x0, #128 // dst += 8 * 4 * sizeof(float32_t) - add x1, x1, #32 // src += 8 * 4 * sizeof(int8_t) - add x11, x11, #32 // sum += 8 * sizeof(float32_t) - add x12, x12, #32 // scale += 8 * sizeof(float32_t) - b TILE_8 - -TILE_4: - cmp x6, #4 - blt TILE_1 - mov x14, x4 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_4: - // dequant info for batch - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - -LoopSz_TILE_4: - // src : 4(batch) x [1 x 4] : v4 - // weight : 4(oc) x [1 x 4] : v0 - // dst : 4 x 4 x [1] : v16-v19 - ld1 {v0.16b}, [x25], #16 // weight - ld1 {v4.16b}, [x24], x15 // src - .inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] // batch0 - .inst 0x4fa4e011 // sdot v17.4s, v0.16b, v4.4b[1] // batch1 - .inst 0x4f84e812 // sdot v18.4s, v0.16b, v4.4b[2] // batch2 - .inst 0x4fa4e813 // sdot v19.4s, v0.16b, v4.4b[3] // batch3 - subs x26, x26, #1 - bne LoopSz_TILE_4 - -LoopSzEnd_TILE_4: - add x7, x7, x13 - sub x27, x27, #1 - Int32ToFloat v16, v17, v18, v19 - // Int32ToFloat v20, v21, v22, v23 - // using float scale dequant for precison - ld1 {v5.4s}, [x23] // scales, 4 batch,so 4 scale - - MulScale v16, v17, v18, v19, v5 - -Tile4Dequant: - ld1 {v0.4s}, [x19], #16 // alpha - ld1 {v1.4s}, [x20], #16 // zero - ld1 {v2.4s}, [x21], #16 // bias - ld1 {v3.4s}, [x22] // sums - // alpha * sum + (zero * sums) + bias - Dequant v16, v0, v1, v2, v3, 0 - Dequant v17, v0, v1, v2, v3, 1 - Dequant v18, v0, v1, v2, v3, 2 - Dequant v19, v0, v1, v2, v3, 3 - st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_4 -Tile4End: - sub x6, x6, #4 // bach -= 4 - add x0, x0, #64 // dst += 4 * 4 * sizeof(float32_t) - add x1, x1, #16 // src += 4 * 4 * sizeof(int8_t) - add x11, x11, #16 // sum += 4 * sizeof(float32_t) - add x12, x12, #16 // scale += 4 * sizeof(float32_t) - b TILE_4 - -TILE_1: - cmp x6, #1 - blt End - mov x14, x4 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4, sizeof(float32_t)/4=sizeof(int8_t) - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_1: - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - -LoopSz_TILE_1: - // src : 1(batch) x [1 x 4] : v4 - // weight : 4(oc) x [1 x 4] : v0 - // dst : 1 x 4 x [1] : v16 - ld1 {v0.16b}, [x25], #16 // weight - ld1 {v4.s}[0], [x24], x15 // src - .inst 0x4f84e010 // sdot v16.4s, v0.16b, v4.4b[0] - - subs x26, x26, #1 - bne LoopSz_TILE_1 - -LoopSzEnd_TILE_1: - add x7, x7, x13 - sub x27, x27, #1 - scvtf v16.4s, v16.4s - // using float scale dequant for precison - ld1 {v4.s}[0], [x23] // scales - fmul v16.4s, v16.4s, v4.s[0] -Tile1Dequant: - ld1 {v0.4s}, [x19], #16 // alpha - ld1 {v1.4s}, [x20], #16 // zero - ld1 {v2.4s}, [x21], #16 // bias - ld1 {v3.s}[0], [x22] // sums - // alpha * sum + (zero * sumx) + bias - fmla v2.4s, v0.4s, v16.4s - fmla v2.4s, v1.4s, v3.s[0] - st1 {v2.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_1 -Tile1End: - sub x6, x6, #1 // batch -= 1 - add x0, x0, #16 // dst += 1 * 4 * 
sizeof(float32_t) - add x1, x1, #4 // src += 1 * 4 * sizeof(int8_t) - add x11, x11, #4 // sum += 1 * sizeof(float32_t) - add x12, x12, #4 // scale += 1 * sizeof(float32_t) - b TILE_1 - -End: -ldp x27, x28, [sp, #(16 * 8)] -ldp x25, x26, [sp, #(16 * 7)] -ldp x23, x24, [sp, #(16 * 6)] -ldp x19, x20, [sp, #(16 * 5)] -ldp x21, x22, [sp, #(16 * 4)] -ldp d8, d9, [sp, #(16 * 3)] -ldp d10, d11, [sp, #(16 * 2)] -ldp d12, d13, [sp, #(16 * 1)] -ldp d14, d15, [sp], #(16 * 9) -ret - -#endif \ No newline at end of file diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt8FP32_smmla.S b/source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt8FP32_smmla.S deleted file mode 100644 index a4c915853..000000000 --- a/source/backend/cpu/arm/arm64/low_memory/MNNGemmHybridInt8FP32_smmla.S +++ /dev/null @@ -1,445 +0,0 @@ -// -// MNNGemmHybridInt8FP32_smmla.S -// MNN -// -// Created by MNN on 2023/11/09. -// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifdef __aarch64__ - -#include "MNNAsmGlobal.h" - -.text -.align 5 - -.macro Int32ToFloat z0, z1, z2, z3 - scvtf \z0\().4s, \z0\().4s - scvtf \z1\().4s, \z1\().4s - scvtf \z2\().4s, \z2\().4s - scvtf \z3\().4s, \z3\().4s -.endm - -.macro MulScale d0, d1, d2, d3, s, idx0, idx1 - fmul \d0\().4s, \d0\().4s, \s\().s[\idx0] - fmul \d1\().4s, \d1\().4s, \s\().s[\idx0] - fmul \d2\().4s, \d2\().4s, \s\().s[\idx1] - fmul \d3\().4s, \d3\().4s, \s\().s[\idx1] -.endm - -.macro Dequant c0, a0, z0, b0, s0, idx - fmul \c0\().4s, \c0\().4s, \a0\().4s - fmla \c0\().4s, \z0\().4s, \s0\().s[\idx] - fadd \c0\().4s, \c0\().4s, \b0\().4s -.endm - -asm_function MNNGemmHybridInt8FP32_smmla - -//struct QuanPostTreatParameters { -// const float* scale; -// const int32_t* bias; -// int32_t maxValue; -// int32_t minValue; -// int32_t useInt8; -//}; - -//void MNNGemmHybridInt8FP32_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, float** param); - - -// Auto: x0: C*, x1: A*, x2:B*, x3: src_depth_quad, x4: dst_step, x5: dst_depth_quad, x6: realSize, x7: param -// load from param: x7: alpha*, x8: zero*, x9: bias*, x10: sums*, x11: scales* -stp d14, d15, [sp, #(-16 * 9)]! 
-stp d12, d13, [sp, #(16 * 1)] -stp d10, d11, [sp, #(16 * 2)] -stp d8, d9, [sp, #(16 * 3)] -stp x21, x22, [sp, #(16 * 4)] -stp x19, x20, [sp, #(16 * 5)] -stp x23, x24, [sp, #(16 * 6)] -stp x25, x26, [sp, #(16 * 7)] -stp x27, x28, [sp, #(16 * 8)] - -ldr x8, [x7, #0] -ldr x9, [x7, #8] -ldr x10, [x7, #16] -ldr x11, [x7, #24] -ldr x12, [x7, #32] - -Start: -lsl x13, x3, #6 // x13 = src_depth_quad * UNIT * UNIT_SRC / 1(int8) = src_depth_quad * 64 = src_depth_quad << 6 - -TILE_8: - cmp x6, #8 - blt TILE_4 - sub x14, x4, #192 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_8: - // dequant info for batch - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - dup v20.4s, wzr - dup v21.4s, wzr - dup v22.4s, wzr - dup v23.4s, wzr - dup v24.4s, wzr - dup v25.4s, wzr - dup v26.4s, wzr - dup v27.4s, wzr - dup v28.4s, wzr - dup v29.4s, wzr - dup v30.4s, wzr - dup v31.4s, wzr - -LoopSz_TILE_8: - // src : 2 x [2 x 8] : v4-5 - // weight : 4 x [2 x 8] : v0-3 - // dst : 2 x 4 x [4] : v16-23 - ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight - ld1 {v12.16b, v13.16b, v14.16b, v15.16b}, [x24], x15 // src - - .inst 0x4e80a590 // smmla v16.4s, v12.16b, v0.16b - .inst 0x4e81a591 // smmla v17.4s, v12.16b, v1.16b - .inst 0x4e82a592 // smmla v18.4s, v12.16b, v2.16b - .inst 0x4e83a593 // smmla v19.4s, v12.16b, v3.16b - .inst 0x4e80a5b4 // smmla v20.4s, v13.16b, v0.16b - .inst 0x4e81a5b5 // smmla v21.4s, v13.16b, v1.16b - .inst 0x4e82a5b6 // smmla v22.4s, v13.16b, v2.16b - .inst 0x4e83a5b7 // smmla v23.4s, v13.16b, v3.16b - .inst 0x4e80a5d8 // smmla v24.4s, v14.16b, v0.16b - .inst 0x4e81a5d9 // smmla v25.4s, v14.16b, v1.16b - .inst 0x4e82a5da // smmla v26.4s, v14.16b, v2.16b - .inst 0x4e83a5db // smmla v27.4s, v14.16b, v3.16b - .inst 0x4e80a5fc // smmla v28.4s, v15.16b, v0.16b - .inst 0x4e81a5fd // smmla v29.4s, v15.16b, v1.16b - .inst 0x4e82a5fe // smmla v30.4s, v15.16b, v2.16b - .inst 0x4e83a5ff // smmla v31.4s, v15.16b, v3.16b - subs x26, x26, #1 - bne LoopSz_TILE_8 - -LoopSzEnd_TILE_8: - add x7, x7, x13 - sub x27, x27, #1 - - trn1 v0.2d, v16.2d, v17.2d // batch:0 oc:0-3 - trn1 v1.2d, v18.2d, v19.2d // batch:0 oc:4-7 - trn2 v2.2d, v16.2d, v17.2d // batch:1 oc:0-3 - trn2 v3.2d, v18.2d, v19.2d // batch:1 oc:4-7 - trn1 v4.2d, v20.2d, v21.2d // batch:2 oc:0-3 - trn1 v5.2d, v22.2d, v23.2d // batch:2 oc:4-7 - trn2 v6.2d, v20.2d, v21.2d // batch:3 oc:0-3 - trn2 v7.2d, v22.2d, v23.2d // batch:3 oc:4-7 - - trn1 v8.2d, v24.2d, v25.2d // batch:0 oc:0-3 - trn1 v9.2d, v26.2d, v27.2d // batch:0 oc:4-7 - trn2 v10.2d, v24.2d, v25.2d // batch:1 oc:0-3 - trn2 v11.2d, v26.2d, v27.2d // batch:1 oc:4-7 - trn1 v12.2d, v28.2d, v29.2d // batch:2 oc:0-3 - trn1 v13.2d, v30.2d, v31.2d // batch:2 oc:4-7 - trn2 v14.2d, v28.2d, v29.2d // batch:3 oc:0-3 - trn2 v15.2d, v30.2d, v31.2d // batch:3 oc:4-7 - - Int32ToFloat v0, v1, v2, v3 - Int32ToFloat v4, v5, v6, v7 - Int32ToFloat v8, v9, v10, v11 - Int32ToFloat v12, v13, v14, v15 - // using float scale dequant for precison - ld1 {v16.4s, v17.4s}, [x23] // scales - MulScale v0, v1, v2, v3, v16, 0, 1 - MulScale v4, v5, v6, v7, v16, 2, 3 - MulScale v8, v9, v10, v11, v17, 0, 1 - MulScale v12, v13, v14, v15, v17, 2, 3 -Tile8Dequant: - ld1 {v18.4s, v19.4s}, [x19], 
#32 // alpha - ld1 {v20.4s, v21.4s}, [x20], #32 // zero - ld1 {v22.4s, v23.4s}, [x21], #32 // bias - ld1 {v24.4s, v25.4s}, [x22] // sums - // alpha * cusum + (zero * sums) + bias - Dequant v0, v18, v20, v22, v24, 0 // Batch0 - Dequant v1, v19, v21, v23, v24, 0 - Dequant v2, v18, v20, v22, v24, 1 // Batch1 - Dequant v3, v19, v21, v23, v24, 1 - Dequant v4, v18, v20, v22, v24, 2 // Batch2 - Dequant v5, v19, v21, v23, v24, 2 - Dequant v6, v18, v20, v22, v24, 3 // Batch3 - Dequant v7, v19, v21, v23, v24, 3 - Dequant v8, v18, v20, v22, v25, 0 // Batch4 - Dequant v9, v19, v21, v23, v25, 0 - Dequant v10, v18, v20, v22, v25, 1 // Batch5 - Dequant v11, v19, v21, v23, v25, 1 - Dequant v12, v18, v20, v22, v25, 2 // Batch6 - Dequant v13, v19, v21, v23, v25, 2 - Dequant v14, v18, v20, v22, v25, 3 // Batch7 - Dequant v15, v19, v21, v23, v25, 3 - st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x28], #64 - st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x28], #64 - st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x28], #64 - st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_8 -Tile8End: - sub x6, x6, #8 // bach -= 8 - add x0, x0, #256 // dst += 8 * 8 * sizeof(float32_t) - add x1, x1, #64 // src += 8 * 8 * sizeof(int8_t) - add x11, x11, #32 // sum += 8 * sizeof(float32_t) - add x12, x12, #32 // scale += 8 * sizeof(float32_t) - b TILE_8 - -TILE_4: - cmp x6, #4 - blt TILE_2 - mov x14, x4 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4 - sub x14, x14, #64 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_4: - // dequant info for batch - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - dup v20.4s, wzr - dup v21.4s, wzr - dup v22.4s, wzr - dup v23.4s, wzr -LoopSz_TILE_4: - // src : 2 x [2 x 8] : v4-5 - // weight : 4 x [2 x 8] : v0-3 - // dst : 2 x 4 x [4] : v16-23 - ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight - ld1 {v4.16b, v5.16b}, [x24], x15 // src - .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b - .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b - .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b - .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b - .inst 0x4e80a4b4 // smmla v20.4s, v5.16b, v0.16b - .inst 0x4e81a4b5 // smmla v21.4s, v5.16b, v1.16b - .inst 0x4e82a4b6 // smmla v22.4s, v5.16b, v2.16b - .inst 0x4e83a4b7 // smmla v23.4s, v5.16b, v3.16b - subs x26, x26, #1 - bne LoopSz_TILE_4 - -LoopSzEnd_TILE_4: - add x7, x7, x13 - sub x27, x27, #1 - - trn1 v24.2d, v16.2d, v17.2d // batch:0 oc:0-3 - trn1 v25.2d, v18.2d, v19.2d // batch:0 oc:4-7 - trn2 v26.2d, v16.2d, v17.2d // batch:1 oc:0-3 - trn2 v27.2d, v18.2d, v19.2d // batch:1 oc:4-7 - trn1 v28.2d, v20.2d, v21.2d // batch:2 oc:0-3 - trn1 v29.2d, v22.2d, v23.2d // batch:2 oc:4-7 - trn2 v30.2d, v20.2d, v21.2d // batch:3 oc:0-3 - trn2 v31.2d, v22.2d, v23.2d // batch:3 oc:4-7 - Int32ToFloat v24, v25, v26, v27 - Int32ToFloat v28, v29, v30, v31 - // using float scale dequant for precison - ld1 {v5.4s}, [x23] // scales - MulScale v24, v25, v26, v27, v5, 0, 1 - MulScale v28, v29, v30, v31, v5, 2, 3 -Tile4Dequant: - ld1 {v0.4s, v1.4s}, [x19], #32 // alpha - ld1 {v2.4s, v3.4s}, [x20], #32 // zero - ld1 {v8.4s, v9.4s}, [x21], #32 // bias - ld1 {v6.4s}, [x22] // sums - // alpha * cusum + (zero * sums) + bias - Dequant v24, v0, v2, v8, v6, 0 // Batch0 - Dequant v25, v1, v3, v9, 
v6, 0 - Dequant v26, v0, v2, v8, v6, 1 // Batch1 - Dequant v27, v1, v3, v9, v6, 1 - Dequant v28, v0, v2, v8, v6, 2 // Batch2 - Dequant v29, v1, v3, v9, v6, 2 - Dequant v30, v0, v2, v8, v6, 3 // Batch3 - Dequant v31, v1, v3, v9, v6, 3 - st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x28], #64 - st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_4 -Tile4End: - sub x6, x6, #4 // bach -= 4 - add x0, x0, #128 // dst += 4 * 8 * sizeof(float32_t) - add x1, x1, #32 // src += 4 * 8 * sizeof(int8_t) - add x11, x11, #16 // sum += 4 * sizeof(float32_t) - add x12, x12, #16 // scale += 4 * sizeof(float32_t) - b TILE_4 - -TILE_2: - cmp x6, #2 - blt TILE_1 - mov x14, x4 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4 - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_2: - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr -LoopSz_TILE_2: - // src : 1 x [2 x 8] : v4 - // weight : 4 x [2 x 8] : v0-3 - // dst : 1 x 4 x [4] : v16-19 - ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight - ld1 {v4.16b}, [x24], x15 // src - .inst 0x4e80a490 // smmla v16.4s, v4.16b, v0.16b - .inst 0x4e81a491 // smmla v17.4s, v4.16b, v1.16b - .inst 0x4e82a492 // smmla v18.4s, v4.16b, v2.16b - .inst 0x4e83a493 // smmla v19.4s, v4.16b, v3.16b - subs x26, x26, #1 - bne LoopSz_TILE_2 - -LoopSzEnd_TILE_2: - add x7, x7, x13 - sub x27, x27, #1 - trn1 v20.2d, v16.2d, v17.2d - trn1 v21.2d, v18.2d, v19.2d - trn2 v22.2d, v16.2d, v17.2d - trn2 v23.2d, v18.2d, v19.2d - Int32ToFloat v20, v21, v22, v23 - // using float scale dequant for precison - ld1 {v5.d}[0], [x23] // scales - fmul v20.4s, v20.4s, v5.s[0] - fmul v21.4s, v21.4s, v5.s[0] - fmul v22.4s, v22.4s, v5.s[1] - fmul v23.4s, v23.4s, v5.s[1] -Tile2Dequant: - ld1 {v0.4s, v1.4s}, [x19], #32 // alpha - ld1 {v2.4s, v3.4s}, [x20], #32 // zero - ld1 {v8.4s, v9.4s}, [x21], #32 // bias - ld1 {v10.d}[0], [x22] // sums - // alpha * sum + (zero * sumx) + bias - Dequant v20, v0, v2, v8, v10, 0 - Dequant v21, v1, v3, v9, v10, 0 - Dequant v22, v0, v2, v8, v10, 1 - Dequant v23, v1, v3, v9, v10, 1 - st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_2 -Tile2End: - sub x6, x6, #2 // batch -= 2 - add x0, x0, #64 // dst += 2 * 8 * sizeof(float32_t) - add x1, x1, #16 // dst += 2 * 8 * sizeof(int8_t) - add x11, x11, #8 // sum += 2 * sizeof(float32_t) - add x12, x12, #8 // scale += 2 * sizeof(float32_t) - b TILE_2 - -TILE_1: - - cmp x6, #1 - blt End - mov x14, x4 // dst_step - lsr x15, x4, #2 // src_step = dst_step / 4, sizeof(float32_t)/4=sizeof(int8_t) - mov x27, x5 // dst_depth_quad - mov x28, x0 // dst - mov x7, x2 // weight - // dequant info - mov x19, x8 // alpha - mov x20, x9 // zero - mov x21, x10 // bias -LoopDz_TILE_1: - mov x22, x11 // sums - mov x23, x12 // scales - mov x24, x1 // src - mov x25, x7 // weight - mov x26, x3 // src_depth_quad - // init - dup v16.4s, wzr - dup v17.4s, wzr - dup v18.4s, wzr - dup v19.4s, wzr - -LoopSz_TILE_1: - // src : 1 x [1 x 8] : v4 - // weight : 4 x [2 x 8] : v0-3 - // dst : 1 x 4 x [2] : v16-v19 - ld1 {v0.16b, v1.16b, v2.16b, v3.16b}, [x25], #64 // weight - ld1 {v4.8b}, [x24], x15 // src - .inst 0x4e84a410 // smmla v16.4s, v0.16b, v4.16b - .inst 0x4e84a431 // smmla v17.4s, v1.16b, v4.16b - .inst 0x4e84a452 // smmla v18.4s, 
v2.16b, v4.16b - .inst 0x4e84a473 // smmla v19.4s, v3.16b, v4.16b - - subs x26, x26, #1 - bne LoopSz_TILE_1 - -LoopSzEnd_TILE_1: - add x7, x7, x13 - sub x27, x27, #1 - uzp1 v20.4s, v16.4s, v17.4s - uzp1 v21.4s, v18.4s, v19.4s - scvtf v20.4s, v20.4s - scvtf v21.4s, v21.4s - // using float scale dequant for precison - ld1 {v4.s}[0], [x23] // scales - fmul v20.4s, v20.4s, v4.s[0] - fmul v21.4s, v21.4s, v4.s[0] -Tile1Dequant: - ld1 {v0.4s, v1.4s}, [x19], #32 // alpha - ld1 {v2.4s, v3.4s}, [x20], #32 // zero - ld1 {v10.4s, v11.4s}, [x21], #32 // bias - ld1 {v8.s}[0], [x22] // sums - // alpha * sum + (zero * sumx) + bias - fmla v10.4s, v20.4s, v0.4s - fmla v11.4s, v21.4s, v1.4s - fmla v10.4s, v2.4s, v8.s[0] - fmla v11.4s, v3.4s, v8.s[0] - st1 {v10.4s, v11.4s}, [x28], x14 - cmp x27, #1 - bge LoopDz_TILE_1 -Tile1End: - sub x6, x6, #1 // batch -= 1 - add x0, x0, #32 // dst += 1 * 8 * sizeof(float32_t) - add x1, x1, #8 // src += 1 * 8 * sizeof(int8_t) - add x11, x11, #4 // sum += 1 * sizeof(float32_t) - add x12, x12, #4 // scale += 1 * sizeof(float32_t) - b TILE_1 - -End: -ldp x27, x28, [sp, #(16 * 8)] -ldp x25, x26, [sp, #(16 * 7)] -ldp x23, x24, [sp, #(16 * 6)] -ldp x19, x20, [sp, #(16 * 5)] -ldp x21, x22, [sp, #(16 * 4)] -ldp d8, d9, [sp, #(16 * 3)] -ldp d10, d11, [sp, #(16 * 2)] -ldp d12, d13, [sp, #(16 * 1)] -ldp d14, d15, [sp], #(16 * 9) -ret - -#endif \ No newline at end of file diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_16x4_w4_Unit.S b/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_16x4_w4_Unit.S new file mode 100644 index 000000000..fa8258b66 --- /dev/null +++ b/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_16x4_w4_Unit.S @@ -0,0 +1,830 @@ +// +// MNNGemmInt8AddBiasScale_16x4_w4_Unit.S +// MNN +// +// Created by MNN on 2019/06/11. 
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#ifdef __aarch64__ + +#include "MNNAsmGlobal.h" + +.text +.align 5 + +.macro MLA_WEIGHTZERO d0, s0, s1, idx // idx for xKernelSum + fmla \d0\().4s, \s1\().4s, \s0\().s[\idx] +.endm +.macro ReLU_FP32_4 s0, s1, s2, s3, z0, z1 // z0:min z1:max + fmin \s0\().4s, \s0\().4s, \z1\().4s + fmin \s1\().4s, \s1\().4s, \z1\().4s + fmin \s2\().4s, \s2\().4s, \z1\().4s + fmin \s3\().4s, \s3\().4s, \z1\().4s + fmax \s0\().4s, \s0\().4s, \z0\().4s + fmax \s1\().4s, \s1\().4s, \z0\().4s + fmax \s2\().4s, \s2\().4s, \z0\().4s + fmax \s3\().4s, \s3\().4s, \z0\().4s +.endm +.macro ReLU_FP32_3 s0, s1, s2, z0, z1 // z0:min z1:max + fmin \s0\().4s, \s0\().4s, \z1\().4s + fmin \s1\().4s, \s1\().4s, \z1\().4s + fmin \s2\().4s, \s2\().4s, \z1\().4s + fmax \s0\().4s, \s0\().4s, \z0\().4s + fmax \s1\().4s, \s1\().4s, \z0\().4s + fmax \s2\().4s, \s2\().4s, \z0\().4s +.endm +.macro ReLU_FP32_2 s0, s1, z0, z1 // z0:min z1:max + fmin \s0\().4s, \s0\().4s, \z1\().4s + fmin \s1\().4s, \s1\().4s, \z1\().4s + fmax \s0\().4s, \s0\().4s, \z0\().4s + fmax \s1\().4s, \s1\().4s, \z0\().4s +.endm +.macro ReLU_FP32_1 s0, z0, z1 // z0:min z1:max + fmin \s0\().4s, \s0\().4s, \z1\().4s + fmax \s0\().4s, \s0\().4s, \z0\().4s +.endm +.macro MUL_SCALE4 s, d0, d1, d2, d3 + fmul \d0\().4s, \d0\().4s, \s\().4s + fmul \d1\().4s, \d1\().4s, \s\().4s + fmul \d2\().4s, \d2\().4s, \s\().4s + fmul \d3\().4s, \d3\().4s, \s\().4s +.endm +.macro MUL_SCALE3 s, d0, d1, d2 + fmul \d0\().4s, \d0\().4s, \s\().4s + fmul \d1\().4s, \d1\().4s, \s\().4s + fmul \d2\().4s, \d2\().4s, \s\().4s +.endm +.macro MUL_SCALE2 s, d0, d1 + fmul \d0\().4s, \d0\().4s, \s\().4s + fmul \d1\().4s, \d1\().4s, \s\().4s +.endm +.macro MUL_SCALE1 s, d0 + fmul \d0\().4s, \d0\().4s, \s\().4s +.endm +.macro MUL_EXTRA_SCALE s, d0, d1, d2, d3 + fmul \d0\().4s, \d0\().4s, \s\().s[0] + fmul \d1\().4s, \d1\().4s, \s\().s[1] + fmul \d2\().4s, \d2\().4s, \s\().s[2] + fmul \d3\().4s, \d3\().4s, \s\().s[3] +.endm + +asm_function MNNGemmInt8AddBiasScale_16x4_w4_Unit + +/* +struct QuanPostTreatParameters { + const float* scale; + const float* biasFloat; + int32_t maxValue; + int32_t minValue; + int32_t useInt8 = 1; // Save result as int8_t dataType; otherwise float32. + float roundValuePos = 0.5f; + float roundValueNeg = -0.5f; + float* srcKernelSum; + float* weightQuanBias; + float* fp32minmax; + ssize_t blockNum; +}; +*/ +//void MNNGemmInt8AddBiasScale_16x4_w4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, +// size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realSize) { + +//Auto: x0: dst*, x1: src*, x2:weight*, x3: src_depth_quad, x4: dst_step, +// x5: dst_depth_quad, x6: post, x7: realSize + +//Load from post: +// x7: scale, x10: bias, w11: maxValue, w6: minValue, w13: UseInt8, x14: srcKernelSum, x12: weightQuantBias +mov x8, x7 +mov x15, x6 +ldr x7, [x15, #0] +ldr x10, [x15, #8] +ldr w11, [x15, #16] +ldr w6, [x15, #20] +ldr w13, [x15, #24] +ldr x14, [x15, #40] // srcKernelSum +ldr x12, [x15, #48] // weightQuantBias + +stp d14, d15, [sp, #(-16 * 8)]! 
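+// Spill callee-saved registers (d8-d15, x19-x24) required by the AAPCS64; they are reused as scratch/state registers in the kernel body below.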
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] +stp x19, x20, [sp, #(16 * 4)] +stp x21, x22, [sp, #(16 * 5)] +stp x23, x24, [sp, #(16 * 6)] + +ldr x19, [x15, #56] // fp32 min max +ldr x21, [x15, #64] // blockNum +ldr x23, [x15, #80] // extraScale +mul x21, x21, x3 // blockNum * src_depth_quad_perblock +lsl x21, x21, #5 // src_depth_quad* SRC_UNIT * UNIT * sizeof(int4_t) +add x20, x19, #4 + +Start: +cmp x8, #3 +beq L3Dz + +cmp x8, #2 +beq L2Dz + +cmp x8, #1 +beq L1Dz + +//cmp w13, #1 +//bne L4LoopDz +//sub x4, x4, #8 // post->scale != nullptr && post->useInt8 == 1. +L4LoopDz: + mov x8, x1 + mov x22, x2 + ld1 {v0.16b, v1.16b}, [x2], #32 // weight + ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64 // src + // int4->int8 + movi v8.16b, #15 + ushr v10.16b, v0.16b, #4 + and v11.16b, v0.16b, v8.16b + ushr v12.16b, v1.16b, #4 + and v13.16b, v1.16b, v8.16b + zip1 v0.16b, v10.16b, v11.16b + zip2 v1.16b, v10.16b, v11.16b + zip1 v2.16b, v12.16b, v13.16b + zip2 v3.16b, v12.16b, v13.16b + + smull v8.8h, v0.8b, v4.8b + smull v9.8h, v1.8b, v4.8b + smull v10.8h, v2.8b, v4.8b + smull v11.8h, v3.8b, v4.8b + smull v12.8h, v0.8b, v5.8b + smull v13.8h, v1.8b, v5.8b + smull v14.8h, v2.8b, v5.8b + smull v15.8h, v3.8b, v5.8b + + smlal2 v8.8h, v0.16b, v4.16b + smlal2 v9.8h, v1.16b, v4.16b + smlal2 v10.8h, v2.16b, v4.16b + smlal2 v11.8h, v3.16b, v4.16b + smlal2 v12.8h, v0.16b, v5.16b + smlal2 v13.8h, v1.16b, v5.16b + smlal2 v14.8h, v2.16b, v5.16b + smlal2 v15.8h, v3.16b, v5.16b + + L4Initialize: + saddlp v16.4s, v8.8h + saddlp v17.4s, v9.8h + saddlp v18.4s, v10.8h + saddlp v19.4s, v11.8h + saddlp v20.4s, v12.8h + saddlp v21.4s, v13.8h + saddlp v22.4s, v14.8h + saddlp v23.4s, v15.8h + + smull v8.8h, v0.8b, v6.8b + smull v9.8h, v1.8b, v6.8b + smull v10.8h, v2.8b, v6.8b + smull v11.8h, v3.8b, v6.8b + smull v12.8h, v0.8b, v7.8b + smull v13.8h, v1.8b, v7.8b + smull v14.8h, v2.8b, v7.8b + smull v15.8h, v3.8b, v7.8b + subs x9, x3, #1 + smlal2 v8.8h, v0.16b, v6.16b + smlal2 v9.8h, v1.16b, v6.16b + smlal2 v10.8h, v2.16b, v6.16b + smlal2 v11.8h, v3.16b, v6.16b + smlal2 v12.8h, v0.16b, v7.16b + smlal2 v13.8h, v1.16b, v7.16b + smlal2 v14.8h, v2.16b, v7.16b + smlal2 v15.8h, v3.16b, v7.16b + + saddlp v24.4s, v8.8h + saddlp v25.4s, v9.8h + saddlp v26.4s, v10.8h + saddlp v27.4s, v11.8h + saddlp v28.4s, v12.8h + saddlp v29.4s, v13.8h + saddlp v30.4s, v14.8h + saddlp v31.4s, v15.8h + L4InitializeEnd: + beq ComputeSum + + L4LoopSz: + ld1 {v4.16b, v5.16b, v6.16b, v7.16b}, [x1], #64 + ld1 {v0.16b, v1.16b}, [x2], #32 + // int4->int8 + movi v8.16b, #15 + ushr v10.16b, v0.16b, #4 + and v11.16b, v0.16b, v8.16b + ushr v12.16b, v1.16b, #4 + and v13.16b, v1.16b, v8.16b + zip1 v0.16b, v10.16b, v11.16b + zip2 v1.16b, v10.16b, v11.16b + zip1 v2.16b, v12.16b, v13.16b + zip2 v3.16b, v12.16b, v13.16b + + smull v8.8h, v0.8b, v4.8b + smull v9.8h, v1.8b, v4.8b + smull v10.8h, v2.8b, v4.8b + smull v11.8h, v3.8b, v4.8b + smull v12.8h, v0.8b, v5.8b + smull v13.8h, v1.8b, v5.8b + smull v14.8h, v2.8b, v5.8b + smull v15.8h, v3.8b, v5.8b + + smlal2 v8.8h, v0.16b, v4.16b + smlal2 v9.8h, v1.16b, v4.16b + smlal2 v10.8h, v2.16b, v4.16b + smlal2 v11.8h, v3.16b, v4.16b + smlal2 v12.8h, v0.16b, v5.16b + smlal2 v13.8h, v1.16b, v5.16b + smlal2 v14.8h, v2.16b, v5.16b + smlal2 v15.8h, v3.16b, v5.16b + + sadalp v16.4s, v8.8h + sadalp v17.4s, v9.8h + sadalp v18.4s, v10.8h + sadalp v19.4s, v11.8h + sadalp v20.4s, v12.8h + sadalp v21.4s, v13.8h + sadalp v22.4s, v14.8h + sadalp v23.4s, v15.8h + + smull v8.8h, v0.8b, v6.8b + smull 
v9.8h, v1.8b, v6.8b + smull v10.8h, v2.8b, v6.8b + smull v11.8h, v3.8b, v6.8b + smull v12.8h, v0.8b, v7.8b + smull v13.8h, v1.8b, v7.8b + smull v14.8h, v2.8b, v7.8b + smull v15.8h, v3.8b, v7.8b + + subs x9, x9, #1 + + smlal2 v8.8h, v0.16b, v6.16b + smlal2 v9.8h, v1.16b, v6.16b + smlal2 v10.8h, v2.16b, v6.16b + smlal2 v11.8h, v3.16b, v6.16b + smlal2 v12.8h, v0.16b, v7.16b + smlal2 v13.8h, v1.16b, v7.16b + smlal2 v14.8h, v2.16b, v7.16b + smlal2 v15.8h, v3.16b, v7.16b + + sadalp v24.4s, v8.8h + sadalp v25.4s, v9.8h + sadalp v26.4s, v10.8h + sadalp v27.4s, v11.8h + sadalp v28.4s, v12.8h + sadalp v29.4s, v13.8h + sadalp v30.4s, v14.8h + sadalp v31.4s, v15.8h + + bne L4LoopSz + + ComputeSum: + + addp v4.4s, v16.4s, v17.4s + addp v5.4s, v18.4s, v19.4s + addp v6.4s, v20.4s, v21.4s + addp v7.4s, v22.4s, v23.4s + addp v8.4s, v24.4s, v25.4s + addp v9.4s, v26.4s, v27.4s + addp v10.4s, v28.4s, v29.4s + addp v11.4s, v30.4s, v31.4s + + addp v12.4s, v4.4s, v5.4s + addp v13.4s, v6.4s, v7.4s + addp v14.4s, v8.4s, v9.4s + addp v15.4s, v10.4s, v11.4s + + L4Quan: + ld1 {v1.4s}, [x7], #16 // scalefuse + ld1 {v20.4s}, [x14] // srcKernelSum + ld1 {v21.4s}, [x12], #16 // weightQuanZero + + scvtf v4.4s, v12.4s + scvtf v5.4s, v13.4s + scvtf v6.4s, v14.4s + scvtf v7.4s, v15.4s + + cbz x23, TILE4_MUL_OHE_SCALE + ld1 {v2.4s}, [x23] + MUL_EXTRA_SCALE v2, v4, v5, v6, v7 + + TILE4_MUL_OHE_SCALE: + MUL_SCALE4 v1, v4, v5, v6, v7 + + MLA_WEIGHTZERO v4, v20, v21, 0 + MLA_WEIGHTZERO v5, v20, v21, 1 + MLA_WEIGHTZERO v6, v20, v21, 2 + MLA_WEIGHTZERO v7, v20, v21, 3 + + L4_Add_BIAS: + cbz x10, L4_ADD_DSTV + ld1 {v0.4s}, [x10], #16 + fadd v4.4s, v4.4s, v0.4s + fadd v5.4s, v5.4s, v0.4s + fadd v6.4s, v6.4s, v0.4s + fadd v7.4s, v7.4s, v0.4s + b L4_POST + + L4_ADD_DSTV: + ld1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0] + fadd v4.4s, v4.4s, v8.4s + fadd v5.4s, v5.4s, v9.4s + fadd v6.4s, v6.4s, v10.4s + fadd v7.4s, v7.4s, v11.4s + + L4_POST: + cbz x19, L4_STORE + ld1r {v26.4s}, [x19] // f32 min + ld1r {v27.4s}, [x20] // f32 max + ReLU_FP32_4 v4, v5, v6, v7, v26, v27 + + L4_STORE: + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], x4 + +L4LoopCheck: + subs x5, x5, #1 + mov x1, x8 + add x2, x22, x21 + bne L4LoopDz + +b End + +L3Dz: +cmp w13, #1 +bne L3LoopDz +sub x4, x4, #8 +L3LoopDz: + mov x8, x1 + mov x22, x2 + ld1 {v0.16b, v1.16b}, [x2], #32 + ld1 {v4.16b, v5.16b, v6.16b}, [x1], #48 + add x1, x1, #16 + // int4->int8 + movi v8.16b, #15 + ushr v10.16b, v0.16b, #4 + and v11.16b, v0.16b, v8.16b + ushr v12.16b, v1.16b, #4 + and v13.16b, v1.16b, v8.16b + zip1 v0.16b, v10.16b, v11.16b + zip2 v1.16b, v10.16b, v11.16b + zip1 v2.16b, v12.16b, v13.16b + zip2 v3.16b, v12.16b, v13.16b + + smull v8.8h, v0.8b, v4.8b + smull v9.8h, v1.8b, v4.8b + smull v10.8h, v2.8b, v4.8b + smull v11.8h, v3.8b, v4.8b + smull v12.8h, v0.8b, v5.8b + smull v13.8h, v1.8b, v5.8b + smull v14.8h, v2.8b, v5.8b + smull v15.8h, v3.8b, v5.8b + + smlal2 v8.8h, v0.16b, v4.16b + smlal2 v9.8h, v1.16b, v4.16b + smlal2 v10.8h, v2.16b, v4.16b + smlal2 v11.8h, v3.16b, v4.16b + smlal2 v12.8h, v0.16b, v5.16b + smlal2 v13.8h, v1.16b, v5.16b + smlal2 v14.8h, v2.16b, v5.16b + smlal2 v15.8h, v3.16b, v5.16b + + L3Initialize: + saddlp v16.4s, v8.8h + saddlp v17.4s, v9.8h + saddlp v18.4s, v10.8h + saddlp v19.4s, v11.8h + saddlp v20.4s, v12.8h + saddlp v21.4s, v13.8h + saddlp v22.4s, v14.8h + saddlp v23.4s, v15.8h + + smull v8.8h, v0.8b, v6.8b + smull v9.8h, v1.8b, v6.8b + smull v10.8h, v2.8b, v6.8b + smull v11.8h, v3.8b, v6.8b + + subs x9, x3, #1 + + smlal2 v8.8h, v0.16b, v6.16b + smlal2 v9.8h, v1.16b, v6.16b 
+ smlal2 v10.8h, v2.16b, v6.16b + smlal2 v11.8h, v3.16b, v6.16b + + saddlp v24.4s, v8.8h + saddlp v25.4s, v9.8h + saddlp v26.4s, v10.8h + saddlp v27.4s, v11.8h + L3InitializeEnd: + beq L3ComputeSum + + L3LoopSz: + ld1 {v4.16b, v5.16b, v6.16b}, [x1], #48 + ld1 {v0.16b, v1.16b}, [x2], #32 + // int4->int8 + movi v8.16b, #15 + ushr v10.16b, v0.16b, #4 + and v11.16b, v0.16b, v8.16b + ushr v12.16b, v1.16b, #4 + and v13.16b, v1.16b, v8.16b + zip1 v0.16b, v10.16b, v11.16b + zip2 v1.16b, v10.16b, v11.16b + zip1 v2.16b, v12.16b, v13.16b + zip2 v3.16b, v12.16b, v13.16b + + smull v8.8h, v0.8b, v4.8b + smull v9.8h, v1.8b, v4.8b + smull v10.8h, v2.8b, v4.8b + smull v11.8h, v3.8b, v4.8b + smull v12.8h, v0.8b, v5.8b + smull v13.8h, v1.8b, v5.8b + smull v14.8h, v2.8b, v5.8b + smull v15.8h, v3.8b, v5.8b + + smlal2 v8.8h, v0.16b, v4.16b + smlal2 v9.8h, v1.16b, v4.16b + smlal2 v10.8h, v2.16b, v4.16b + smlal2 v11.8h, v3.16b, v4.16b + smlal2 v12.8h, v0.16b, v5.16b + smlal2 v13.8h, v1.16b, v5.16b + smlal2 v14.8h, v2.16b, v5.16b + smlal2 v15.8h, v3.16b, v5.16b + + sadalp v16.4s, v8.8h + sadalp v17.4s, v9.8h + sadalp v18.4s, v10.8h + sadalp v19.4s, v11.8h + sadalp v20.4s, v12.8h + sadalp v21.4s, v13.8h + sadalp v22.4s, v14.8h + sadalp v23.4s, v15.8h + + smull v8.8h, v0.8b, v6.8b + smull v9.8h, v1.8b, v6.8b + smull v10.8h, v2.8b, v6.8b + smull v11.8h, v3.8b, v6.8b + + subs x9, x9, #1 + add x1, x1, #16 + + smlal2 v8.8h, v0.16b, v6.16b + smlal2 v9.8h, v1.16b, v6.16b + smlal2 v10.8h, v2.16b, v6.16b + smlal2 v11.8h, v3.16b, v6.16b + + sadalp v24.4s, v8.8h + sadalp v25.4s, v9.8h + sadalp v26.4s, v10.8h + sadalp v27.4s, v11.8h + + bne L3LoopSz + + L3ComputeSum: + addp v4.4s, v16.4s, v17.4s + addp v5.4s, v18.4s, v19.4s + addp v6.4s, v20.4s, v21.4s + addp v7.4s, v22.4s, v23.4s + addp v8.4s, v24.4s, v25.4s + addp v9.4s, v26.4s, v27.4s + + addp v12.4s, v4.4s, v5.4s + addp v13.4s, v6.4s, v7.4s + addp v14.4s, v8.4s, v9.4s + + L3Quan: + ld1 {v1.4s}, [x7], #16 + ld1 {v20.d}[0], [x14], #8 // srcKernelSum + ld1 {v20.s}[2], [x14] + ld1 {v21.4s}, [x12], #16 // weightQuanZero + + scvtf v4.4s, v12.4s + scvtf v5.4s, v13.4s + scvtf v6.4s, v14.4s + MUL_SCALE3 v1, v4, v5, v6 + + cbz x23, TILE3_MUL_OHE_SCALE + ld1 {v2.d}[0], [x23], #8 + ld1 {v2.s}[2], [x23] + fmul v4.4s, v4.4s, v2.s[0] + fmul v5.4s, v5.4s, v2.s[1] + fmul v6.4s, v6.4s, v2.s[2] + sub x23, x23, #8 + + TILE3_MUL_OHE_SCALE: + sub x14, x14, #8 + MLA_WEIGHTZERO v4, v20, v21, 0 + MLA_WEIGHTZERO v5, v20, v21, 1 + MLA_WEIGHTZERO v6, v20, v21, 2 + + L3_ADD_BIAS: + cbz x10, L3_ADD_DSTV + ld1 {v0.4s}, [x10], #16 + fadd v4.4s, v4.4s, v0.4s + fadd v5.4s, v5.4s, v0.4s + fadd v6.4s, v6.4s, v0.4s + b L3_POST + + L3_ADD_DSTV: + ld1 {v0.4s, v1.4s, v2.4s}, [x0] + fadd v4.4s, v4.4s, v0.4s + fadd v5.4s, v5.4s, v1.4s + fadd v6.4s, v6.4s, v2.4s + + L3_POST: + cbz x19, L3_STORE + ld1r {v26.4s}, [x19] // f32 min + ld1r {v27.4s}, [x20] // f32 max + ReLU_FP32_3 v4, v5, v6, v26, v27 + L3_STORE: + st1 {v4.4s, v5.4s, v6.4s}, [x0], x4 + +L3LoopCheck: + subs x5, x5, #1 + mov x1, x8 + add x2, x22, x21 + bne L3LoopDz + +b End + +L2Dz: +L2LoopDz: + mov x8, x1 + mov x22, x2 + ld1 {v0.16b, v1.16b}, [x2], #32 + ld1 {v4.16b, v5.16b}, [x1], #32 + // int4->int8 + movi v8.16b, #15 + ushr v10.16b, v0.16b, #4 + and v11.16b, v0.16b, v8.16b + ushr v12.16b, v1.16b, #4 + and v13.16b, v1.16b, v8.16b + zip1 v0.16b, v10.16b, v11.16b + zip2 v1.16b, v10.16b, v11.16b + zip1 v2.16b, v12.16b, v13.16b + zip2 v3.16b, v12.16b, v13.16b + + + smull v8.8h, v0.8b, v4.8b + smull v9.8h, v1.8b, v4.8b + smull v10.8h, v2.8b, v4.8b + smull 
v11.8h, v3.8b, v4.8b + smull v12.8h, v0.8b, v5.8b + smull v13.8h, v1.8b, v5.8b + smull v14.8h, v2.8b, v5.8b + smull v15.8h, v3.8b, v5.8b + add x1, x1, #32 + smlal2 v8.8h, v0.16b, v4.16b + smlal2 v9.8h, v1.16b, v4.16b + smlal2 v10.8h, v2.16b, v4.16b + smlal2 v11.8h, v3.16b, v4.16b + smlal2 v12.8h, v0.16b, v5.16b + smlal2 v13.8h, v1.16b, v5.16b + smlal2 v14.8h, v2.16b, v5.16b + smlal2 v15.8h, v3.16b, v5.16b + + L2Initialize: + saddlp v16.4s, v8.8h + saddlp v17.4s, v9.8h + saddlp v18.4s, v10.8h + saddlp v19.4s, v11.8h + saddlp v20.4s, v12.8h + saddlp v21.4s, v13.8h + saddlp v22.4s, v14.8h + saddlp v23.4s, v15.8h + subs x9, x3, #1 + L2InitializeEnd: + beq L2ComputeSum + + L2LoopSz: + ld1 {v4.16b, v5.16b}, [x1], #32 + ld1 {v0.16b, v1.16b}, [x2], #32 + // int4->int8 + movi v8.16b, #15 + ushr v10.16b, v0.16b, #4 + and v11.16b, v0.16b, v8.16b + ushr v12.16b, v1.16b, #4 + and v13.16b, v1.16b, v8.16b + zip1 v0.16b, v10.16b, v11.16b + zip2 v1.16b, v10.16b, v11.16b + zip1 v2.16b, v12.16b, v13.16b + zip2 v3.16b, v12.16b, v13.16b + + smull v8.8h, v0.8b, v4.8b + smull v9.8h, v1.8b, v4.8b + smull v10.8h, v2.8b, v4.8b + smull v11.8h, v3.8b, v4.8b + smull v12.8h, v0.8b, v5.8b + smull v13.8h, v1.8b, v5.8b + smull v14.8h, v2.8b, v5.8b + smull v15.8h, v3.8b, v5.8b + + smlal2 v8.8h, v0.16b, v4.16b + smlal2 v9.8h, v1.16b, v4.16b + smlal2 v10.8h, v2.16b, v4.16b + smlal2 v11.8h, v3.16b, v4.16b + add x1, x1, #32 + subs x9, x9, #1 + smlal2 v12.8h, v0.16b, v5.16b + smlal2 v13.8h, v1.16b, v5.16b + smlal2 v14.8h, v2.16b, v5.16b + smlal2 v15.8h, v3.16b, v5.16b + + sadalp v16.4s, v8.8h + sadalp v17.4s, v9.8h + sadalp v18.4s, v10.8h + sadalp v19.4s, v11.8h + sadalp v20.4s, v12.8h + sadalp v21.4s, v13.8h + sadalp v22.4s, v14.8h + sadalp v23.4s, v15.8h + + bne L2LoopSz + + L2ComputeSum: + + addp v4.4s, v16.4s, v17.4s + addp v5.4s, v18.4s, v19.4s + addp v6.4s, v20.4s, v21.4s + addp v7.4s, v22.4s, v23.4s + + addp v12.4s, v4.4s, v5.4s + addp v13.4s, v6.4s, v7.4s + + L2Quan: + ld1 {v1.4s}, [x7], #16 + ld1 {v20.d}[0], [x14] // srcKernelSum + ld1 {v21.4s}, [x12], #16 // weightQuanZero + + scvtf v4.4s, v12.4s + scvtf v5.4s, v13.4s + MUL_SCALE2 v1, v4, v5 + + cbz x23, TILE2_MUL_OHE_SCALE + ld1 {v2.d}[0], [x23] + fmul v4.4s, v4.4s, v2.s[0] + fmul v5.4s, v5.4s, v2.s[1] + + TILE2_MUL_OHE_SCALE: + MLA_WEIGHTZERO v4, v20, v21, 0 + MLA_WEIGHTZERO v5, v20, v21, 1 + + L2_ADD_BIAS: + cbz x10, L2_ADD_DSTV + ld1 {v0.4s}, [x10], #16 + fadd v4.4s, v4.4s, v0.4s + fadd v5.4s, v5.4s, v0.4s + b L2_POST + + L2_ADD_DSTV: + ld1 {v0.4s, v1.4s}, [x0] + fadd v4.4s, v4.4s, v0.4s + fadd v5.4s, v5.4s, v1.4s + + L2_POST: + cbz x19, L2_STORE + ld1r {v26.4s}, [x19] // f32 min + ld1r {v27.4s}, [x20] // f32 max + ReLU_FP32_2 v4, v5, v26, v27 + + L2_STORE: + st1 {v4.4s, v5.4s}, [x0], x4 + +L2LoopCheck: + subs x5, x5, #1 + mov x1, x8 + add x2, x22, x21 + bne L2LoopDz + +b End + +L1Dz: +L1LoopDz: + mov x8, x1 + mov x22, x2 + ld1 {v0.16b, v1.16b}, [x2], #32 + // int4->int8 + movi v8.16b, #15 + ushr v10.16b, v0.16b, #4 + and v11.16b, v0.16b, v8.16b + ushr v12.16b, v1.16b, #4 + and v13.16b, v1.16b, v8.16b + zip1 v0.16b, v10.16b, v11.16b + zip2 v1.16b, v10.16b, v11.16b + zip1 v2.16b, v12.16b, v13.16b + zip2 v3.16b, v12.16b, v13.16b + dup v16.4s, wzr + dup v17.4s, wzr + ld1 {v4.16b}, [x1], #16 + add x1, x1, #48 + + smull v8.8h, v0.8b, v4.8b + dup v18.4s, wzr + smull v9.8h, v1.8b, v4.8b + dup v19.4s, wzr + smull v10.8h, v2.8b, v4.8b + smull v11.8h, v3.8b, v4.8b + subs x9, x3, #1 + smlal2 v8.8h, v0.16b, v4.16b + smlal2 v9.8h, v1.16b, v4.16b + smlal2 v10.8h, v2.16b, v4.16b 
+ smlal2 v11.8h, v3.16b, v4.16b + beq L1LoopSzEnd + + L1LoopSz: + sadalp v16.4s, v8.8h + ld1 {v4.16b}, [x1], #16 + sadalp v17.4s, v9.8h + sadalp v18.4s, v10.8h + sadalp v19.4s, v11.8h + sadalp v20.4s, v12.8h + sadalp v21.4s, v13.8h + sadalp v22.4s, v14.8h + sadalp v23.4s, v15.8h + + ld1 {v0.16b, v1.16b}, [x2], #32 + add x1, x1, #48 + // int4->int8 + movi v8.16b, #15 + ushr v10.16b, v0.16b, #4 + and v11.16b, v0.16b, v8.16b + ushr v12.16b, v1.16b, #4 + and v13.16b, v1.16b, v8.16b + zip1 v0.16b, v10.16b, v11.16b + zip2 v1.16b, v10.16b, v11.16b + zip1 v2.16b, v12.16b, v13.16b + zip2 v3.16b, v12.16b, v13.16b + + smull v8.8h, v0.8b, v4.8b + smull v9.8h, v1.8b, v4.8b + smull v10.8h, v2.8b, v4.8b + smull v11.8h, v3.8b, v4.8b + + smlal2 v8.8h, v0.16b, v4.16b + smlal2 v9.8h, v1.16b, v4.16b + smlal2 v10.8h, v2.16b, v4.16b + smlal2 v11.8h, v3.16b, v4.16b + + subs x9, x9, #1 + bne L1LoopSz + + L1LoopSzEnd: + sadalp v16.4s, v8.8h + sadalp v17.4s, v9.8h + sadalp v18.4s, v10.8h + sadalp v19.4s, v11.8h + + //ld1 {v0.4s}, [x10], #16 + addp v4.4s, v16.4s, v17.4s + addp v5.4s, v18.4s, v19.4s + + addp v12.4s, v4.4s, v5.4s + + L1Quan: + ld1 {v1.4s}, [x7], #16 + ld1 {v20.s}[0], [x14] // srcKernelSum + ld1 {v21.4s}, [x12], #16 // weightQuanZero + + scvtf v4.4s, v12.4s + MUL_SCALE1 v1, v4 + + cbz x23, TILE1_MUL_OHE_SCALE + ld1 {v2.s}[0], [x23] + fmul v4.4s, v4.4s, v2.s[0] + + TILE1_MUL_OHE_SCALE: + MLA_WEIGHTZERO v4, v20, v21, 0 + + L1_ADD_BIAS: + cbz x10, L1_ADD_DSTV + ld1 {v0.4s}, [x10], #16 + fadd v4.4s, v4.4s, v0.4s + b L1_POST + + L1_ADD_DSTV: + ld1 {v0.4s}, [x0] + fadd v4.4s, v4.4s, v0.4s + + L1_POST: + cbz x19, L1_STORE + ld1r {v26.4s}, [x19] // f32 min + ld1r {v27.4s}, [x20] // f32 max + ReLU_FP32_1 v4, v26, v27 + + L1_STORE: + st1 {v4.4s}, [x0], x4 + +L1LoopCheck: + subs x5, x5, #1 + mov x1, x8 + add x2, x22, x21 + bne L1LoopDz + +End: +ldp x23, x24, [sp, #(16 * 6)] +ldp x21, x22, [sp, #(16 * 5)] +ldp x19, x20, [sp, #(16 * 4)] +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 8) +ret + +#endif diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit.S b/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit.S new file mode 100644 index 000000000..fa9bc1f43 --- /dev/null +++ b/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV82_w4_Unit.S @@ -0,0 +1,999 @@ +// +// MNNGemmInt8AddBiasScale_ARMV82_w4_Unit.S +// MNN +// +// Created by MNN on 2019/12/17. 
+// Copyright © 2018, Alibaba Group Holding Limited +// + +#if defined(__aarch64__) +#include "MNNAsmGlobal.h" + +.text +.align 5 + +.macro ADD_BIAS_FLOAT d0, d1, d2, d3, z0 + fadd \d0\().4s, \d0\().4s, \z0\().4s + fadd \d1\().4s, \d1\().4s, \z0\().4s + fadd \d2\().4s, \d2\().4s, \z0\().4s + fadd \d3\().4s, \d3\().4s, \z0\().4s +.endm + +.macro ADD_FLOAT d0, d1, d2, d3, s0, s1, s2, s3 + fadd \d0\().4s, \d0\().4s, \s0\().4s + fadd \d1\().4s, \d1\().4s, \s1\().4s + fadd \d2\().4s, \d2\().4s, \s2\().4s + fadd \d3\().4s, \d3\().4s, \s3\().4s +.endm + +.macro SET_BIAS d0, d1, d2, d3 + movi \d0\().16b, #0 + movi \d1\().16b, #0 + movi \d2\().16b, #0 + movi \d3\().16b, #0 +.endm +.macro Int32ToFloat z0, z1, z2, z3 + scvtf \z0\().4s, \z0\().4s + scvtf \z1\().4s, \z1\().4s + scvtf \z2\().4s, \z2\().4s + scvtf \z3\().4s, \z3\().4s +.endm +.macro MUL_SCALE s, d0, d1, d2, d3 + fmul \d0\().4s, \d0\().4s, \s\().4s + fmul \d1\().4s, \d1\().4s, \s\().4s + fmul \d2\().4s, \d2\().4s, \s\().4s + fmul \d3\().4s, \d3\().4s, \s\().4s +.endm +.macro MUL_EXTRA_SCALE s, d0, d1, d2, d3 + fmul \d0\().4s, \d0\().4s, \s\().s[0] + fmul \d1\().4s, \d1\().4s, \s\().s[1] + fmul \d2\().4s, \d2\().4s, \s\().s[2] + fmul \d3\().4s, \d3\().4s, \s\().s[3] +.endm +.macro FloatToInt32 z0, z1, z2, z3 + fcvtas \z0\().4s, \z0\().4s + fcvtas \z1\().4s, \z1\().4s + fcvtas \z2\().4s, \z2\().4s + fcvtas \z3\().4s, \z3\().4s +.endm +.macro Int32ToInt16 s0, s1, s2, s3, d0, d1 + sqxtn \d0\().4h, \s0\().4s + sqxtn2 \d0\().8h, \s1\().4s + sqxtn \d1\().4h, \s2\().4s + sqxtn2 \d1\().8h, \s3\().4s +.endm +.macro Int16ToInt8_ONE s0, s1, d0 + sqxtn \d0\().8b, \s0\().8h + sqxtn2 \d0\().16b, \s1\().8h +.endm +.macro Int16ToInt8 s0, s1, s2, s3, d0, d1 + Int16ToInt8_ONE \s0, \s1, \d0 + Int16ToInt8_ONE \s2, \s3, \d1 +.endm +.macro MLA_WEIGHTZERO d0, s0, s1, idx // idx for xKernelSum + fmla \d0\().4s, \s1\().4s, \s0\().s[\idx] +.endm +.macro ReLU_FP32 s0, s1, s2, s3, z0, z1 // z0:min z1:max + fmin \s0\().4s, \s0\().4s, \z1\().4s + fmin \s1\().4s, \s1\().4s, \z1\().4s + fmin \s2\().4s, \s2\().4s, \z1\().4s + fmin \s3\().4s, \s3\().4s, \z1\().4s + fmax \s0\().4s, \s0\().4s, \z0\().4s + fmax \s1\().4s, \s1\().4s, \z0\().4s + fmax \s2\().4s, \s2\().4s, \z0\().4s + fmax \s3\().4s, \s3\().4s, \z0\().4s +.endm + +asm_function MNNGemmInt8AddBiasScale_ARMV82_w4_Unit +/* +struct QuanPostTreatParameters { + const float* scale; + const float* biasFloat; + int32_t maxValue; + int32_t minValue; + int32_t useInt8 = 1; // Save result as int8_t dataType; otherwise float32. + float roundValuePos = 0.5f; + float roundValueNeg = -0.5f; + float* srcKernelSum; + float* weightQuanBias; + float* fp32minmax; + ssize_t blockNum; + const int32_t* bias; + +}; +*/ + +//void MNNGemmInt8AddBiasScale_ARMV82_w4_Unit(int8_t* dst, const int8_t* src, +// const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, +// const QuanPostTreatParameters* parameters, size_t realDstCount); + +//Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step +//x5:dst_depth_quad, x6: parameters, x7: realDstCount + +//Load from x6: x8: scale, x9: bias, x25: xKernelSum, x26: weightQuantBias, x23: fp32minmax +ldr x8, [x6, #0] +ldr x9, [x6, #8] + +stp d14, d15, [sp, #(-16 * 10)]! 
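+// Spill callee-saved registers (d8-d15, x19-x28) before loading the remaining QuanPostTreatParameters fields from x6.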
+stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] +stp x21, x22, [sp, #(16 * 4)] +stp x19, x20, [sp, #(16 * 5)] +stp x27, x28, [sp, #(16 * 6)] +stp x25, x26, [sp, #(16 * 7)] +stp x23, x24, [sp, #(16 * 8)] + +ldr x27, [x6, #64] // blockNum +mul x27, x27, x3 // blockNum * src_depth_quad_perblock +lsl x15, x27, #3 // x15 = src_depth_quad * UNIT * SRC_UNIT * sizeof(int4_t) + +ldr x25, [x6, #40] // xKernelSum +ldr x26, [x6, #48] // weightQuantBias +ldr x24, [x6, #80] // extraScale + +mov x21, #16 // sizeof(float) * UNIT +ldr x23, [x6, #56] // fp32minmax +Start: +mov x22, #48 // src_steps + +TILE_12: + cmp x7, #12 + blt TILE_8 + cmp x5, #2 + blt L4LoopDz_TILE_12 +L8LoopDz_TILE_12: + //ld1 {v0.4s, v1.4s}, [x9], #32 // bias + mov x11, x1 + mov x13, x3 + mov x20, x0 // tag dst address + mov x27, x2 + movi v7.16b, #15 + + SET_BIAS v8, v9, v10, v11 + SET_BIAS v12, v13, v14, v15 + SET_BIAS v16, v17, v18, v19 + SET_BIAS v20, v21, v22, v23 + SET_BIAS v24, v25, v26, v27 + SET_BIAS v28, v29, v30, v31 + + L8LoopSz_TILE_12: + ld1 {v3.d}[0], [x2], x15 // weight + ld1 {v4.d}[0], [x2], #8 + ld1 {v0.16b, v1.16b, v2.16b}, [x11], #48 // src + // int4->int8 + ushr v5.16b, v3.16b, #4 + and v6.16b, v3.16b, v7.16b + zip1 v3.16b, v5.16b, v6.16b + + .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] + .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] + .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] + + .inst 0x4f81e06c // sdot v12.4s, v3.16b, v1.4b[0] + .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] + .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] + .inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3] + // int4->int8 + ushr v5.16b, v4.16b, #4 + and v6.16b, v4.16b, v7.16b + zip1 v4.16b, v5.16b, v6.16b + + .inst 0x4f82e070 // sdot v16.4s, v3.16b, v2.4b[0] + .inst 0x4fa2e071 // sdot v17.4s, v3.16b, v2.4b[1] + .inst 0x4f82e872 // sdot v18.4s, v3.16b, v2.4b[2] + .inst 0x4fa2e873 // sdot v19.4s, v3.16b, v2.4b[3] + + .inst 0x4f80e094 // sdot v20.4s, v4.16b, v0.4b[0] + .inst 0x4fa0e095 // sdot v21.4s, v4.16b, v0.4b[1] + .inst 0x4f80e896 // sdot v22.4s, v4.16b, v0.4b[2] + .inst 0x4fa0e897 // sdot v23.4s, v4.16b, v0.4b[3] + sub x2, x2, x15 + .inst 0x4f81e098 // sdot v24.4s, v4.16b, v1.4b[0] + .inst 0x4fa1e099 // sdot v25.4s, v4.16b, v1.4b[1] + .inst 0x4f81e89a // sdot v26.4s, v4.16b, v1.4b[2] + .inst 0x4fa1e89b // sdot v27.4s, v4.16b, v1.4b[3] + subs x13, x13, #1 + .inst 0x4f82e09c // sdot v28.4s, v4.16b, v2.4b[0] + .inst 0x4fa2e09d // sdot v29.4s, v4.16b, v2.4b[1] + .inst 0x4f82e89e // sdot v30.4s, v4.16b, v2.4b[2] + .inst 0x4fa2e89f // sdot v31.4s, v4.16b, v2.4b[3] + bne L8LoopSz_TILE_12 + + L8LoopSzEnd_TILE_12: + // add x2, x2, x15 + add x2, x27, x15, LSL #1 + sub x5, x5, #2 + + L8Tile12Quan: + ld1 {v0.4s, v1.4s}, [x8], #32 // scale + ld1 {v2.4s, v3.4s, v4.4s}, [x25] // x kernel sum + ld1 {v5.4s, v6.4s}, [x26], #32 // weight quan zeropoint + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 + Int32ToFloat v16, v17, v18, v19 + Int32ToFloat v20, v21, v22, v23 + Int32ToFloat v24, v25, v26, v27 + Int32ToFloat v28, v29, v30, v31 + + MUL_SCALE v0, v8, v9, v10, v11 + MUL_SCALE v0, v12, v13, v14, v15 + MUL_SCALE v0, v16, v17, v18, v19 + MUL_SCALE v1, v20, v21, v22, v23 + MUL_SCALE v1, v24, v25, v26, v27 + MUL_SCALE v1, v28, v29, v30, v31 + + cbz x24, TILE12_L8_MLA + ld1 {v0.4s, v1.4s}, [x24], #32 + ld1 {v7.4s}, [x24] + MUL_EXTRA_SCALE v0, v8, v9, v10, v11 + MUL_EXTRA_SCALE v1, v12, v13, v14, v15 + MUL_EXTRA_SCALE v7, v16, 
v17, v18, v19 + MUL_EXTRA_SCALE v0, v20, v21, v22, v23 + MUL_EXTRA_SCALE v1, v24, v25, v26, v27 + MUL_EXTRA_SCALE v7, v28, v29, v30, v31 + sub x24, x24, #32 + + TILE12_L8_MLA: + MLA_WEIGHTZERO v8, v2, v5, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v5, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v2, v5, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v2, v5, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v3, v5, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v13, v3, v5, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v14, v3, v5, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v15, v3, v5, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v16, v4, v5, 0 // tile:8, oc:0-3 + MLA_WEIGHTZERO v17, v4, v5, 1 // tile:9, oc:0-3 + MLA_WEIGHTZERO v18, v4, v5, 2 // tile:10, oc:0-3 + MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3 + + MLA_WEIGHTZERO v20, v2, v6, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v21, v2, v6, 1 // tile:1, oc:4-7 + MLA_WEIGHTZERO v22, v2, v6, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v23, v2, v6, 3 // tile:3, oc:4-7 + MLA_WEIGHTZERO v24, v3, v6, 0 // tile:4, oc:4-7 + MLA_WEIGHTZERO v25, v3, v6, 1 // tile:5, oc:4-7 + MLA_WEIGHTZERO v26, v3, v6, 2 // tile:6, oc:4-7 + MLA_WEIGHTZERO v27, v3, v6, 3 // tile:7, oc:4-7 + MLA_WEIGHTZERO v28, v4, v6, 0 // tile:8, oc:4-7 + MLA_WEIGHTZERO v29, v4, v6, 1 // tile:9, oc:4-7 + MLA_WEIGHTZERO v30, v4, v6, 2 // tile:10, oc:4-7 + MLA_WEIGHTZERO v31, v4, v6, 3 // tile:11, oc:4-7 + + sub x4, x4, #128 + + cbz x9, TILE12_ADD_DSTV + TILE12_ADD_BIAS: + ld1 {v0.4s, v1.4s}, [x9], #32 + ADD_BIAS_FLOAT v8, v9, v10, v11, v0 + ADD_BIAS_FLOAT v12, v13, v14, v15, v0 + ADD_BIAS_FLOAT v16, v17, v18, v19, v0 + ADD_BIAS_FLOAT v20, v21, v22, v23, v1 + ADD_BIAS_FLOAT v24, v25, v26, v27, v1 + ADD_BIAS_FLOAT v28, v29, v30, v31, v1 + b TILE12_POST + + TILE12_ADD_DSTV: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x20], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20], #64 + ADD_FLOAT v8, v9, v10, v11, v0, v1, v2, v3 + ADD_FLOAT v12, v13, v14, v15, v4, v5, v6, v7 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x20], x4 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20], #64 + ADD_FLOAT v16, v17, v18, v19, v0, v1, v2, v3 + ADD_FLOAT v20, v21, v22, v23, v4, v5, v6, v7 + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x20], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x20] + ADD_FLOAT v24, v25, v26, v27, v0, v1, v2, v3 + ADD_FLOAT v28, v29, v30, v31, v4, v5, v6, v7 + + TILE12_POST: + cbz x23, TILE12_STORE + ld1r {v0.4s}, [x23], #4 // f32 min + ld1r {v1.4s}, [x23] // f32 max + ReLU_FP32 v8, v9, v10, v11, v0, v1 + ReLU_FP32 v12, v13, v14, v15, v0, v1 + ReLU_FP32 v16, v17, v18, v19, v0, v1 + ReLU_FP32 v20, v21, v22, v23, v0, v1 + ReLU_FP32 v24, v25, v26, v27, v0, v1 + ReLU_FP32 v28, v29, v30, v31, v0, v1 + sub x23, x23, #4 + + TILE12_STORE: + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x4 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + st1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 + st1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0], x4 + add x4, x4, #128 + + L8Tile12LoopCheck: + cmp x5, #1 + bgt L8LoopDz_TILE_12 + blt End + +L4LoopDz_TILE_12: + SET_BIAS v8, v9, v10, v11 + SET_BIAS v12, v13, v14, v15 + SET_BIAS v16, v17, v18, v19 + movi v7.16b, #15 + + L4LoopSz_TILE_12: + ld1 {v3.d}[0], [x2], #8 // weight + ld1 {v0.16b, v1.16b, v2.16b}, [x1], #48 // src + // int4->int8 + ushr v5.16b, v3.16b, #4 + and v6.16b, v3.16b, v7.16b + zip1 v3.16b, v5.16b, v6.16b + + .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] + .inst 0x4f80e86a // sdot v10.4s, 
v3.16b, v0.4b[2] + .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] + .inst 0x4f81e06c // sdot v12.4s, v3.16b, v1.4b[0] + .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] + .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] + .inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3] + subs x3, x3, #1 + .inst 0x4f82e070 // sdot v16.4s, v3.16b, v2.4b[0] + .inst 0x4fa2e071 // sdot v17.4s, v3.16b, v2.4b[1] + .inst 0x4f82e872 // sdot v18.4s, v3.16b, v2.4b[2] + .inst 0x4fa2e873 // sdot v19.4s, v3.16b, v2.4b[3] + bne L4LoopSz_TILE_12 + + L4LoopSzEnd_TILE_12: + + L4Tile12Quan: + ld1 {v0.4s}, [x8] // scale + ld1 {v2.4s, v3.4s, v4.4s}, [x25]// x kernel sum + ld1 {v5.4s}, [x26], #16 // weight quan zeropoint + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 + Int32ToFloat v16, v17, v18, v19 + MUL_SCALE v0, v8, v9, v10, v11 + MUL_SCALE v0, v12, v13, v14, v15 + MUL_SCALE v0, v16, v17, v18, v19 + + cbz x24, TILE12_L4_MLA + ld1 {v0.4s, v1.4s}, [x24], #32 + ld1 {v7.4s}, [x24] + MUL_EXTRA_SCALE v0, v8, v9, v10, v11 + MUL_EXTRA_SCALE v1, v12, v13, v14, v15 + MUL_EXTRA_SCALE v7, v16, v17, v18, v19 + sub x24, x24, #32 + + TILE12_L4_MLA: + MLA_WEIGHTZERO v8, v2, v5, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v5, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v2, v5, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v2, v5, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v3, v5, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v13, v3, v5, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v14, v3, v5, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v15, v3, v5, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v16, v4, v5, 0 // tile:8, oc:0-3 + MLA_WEIGHTZERO v17, v4, v5, 1 // tile:9, oc:0-3 + MLA_WEIGHTZERO v18, v4, v5, 2 // tile:10, oc:0-3 + MLA_WEIGHTZERO v19, v4, v5, 3 // tile:11, oc:0-3 + + sub x4, x4, #128 + + TILE12_L4_ADD_BIAS: + cbz x9, TILE12_L4_ADD_DSTV + ld1 {v0.4s}, [x9] // bias + ADD_BIAS_FLOAT v8, v9, v10, v11, v0 + ADD_BIAS_FLOAT v12, v13, v14, v15, v0 + ADD_BIAS_FLOAT v16, v17, v18, v19, v0 + b TILE12_L4_POST + + TILE12_L4_ADD_DSTV: + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x0] + sub x0, x0, #128 + ADD_FLOAT v8, v9, v10, v11, v20, v21, v22, v23 + ADD_FLOAT v12, v13, v14, v15, v24, v25, v26, v27 + ADD_FLOAT v16, v17, v18, v19, v28, v29, v30, v31 + + TILE12_L4_POST: + cbz x23, TILE12_L4_STORE + ld1r {v6.4s}, [x23], #4 // f32 min + ld1r {v7.4s}, [x23] // f32 max + ReLU_FP32 v8, v9, v10, v11, v6, v7 + ReLU_FP32 v12, v13, v14, v15, v6, v7 + ReLU_FP32 v16, v17, v18, v19, v6, v7 + sub x23, x23, #4 + TILE12_L4_STORE: + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x0], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x0], #64 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x0], x4 + add x4, x4, #128 + b End + +TILE_8: + cmp x7, #8 + blt TILE_4 + mov x10, x0 + mov x12, x2 + mov x14, x5 + mov x19, x8 // scale + mov x20, x9 // bias + mov x6, x26 // weightQuantBias + cmp x5, #2 + blt L4LoopDz_TILE_8 +L8LoopDz_TILE_8: + //ld1 {v0.4s, v1.4s}, [x20], #32 // bias + mov x11, x1 + mov x13, x3 + mov x27, x12 + movi v7.16b, #15 + + SET_BIAS v8, v9, v10, v11 + SET_BIAS v12, v13, v14, v15 + SET_BIAS v16, v17, v18, v19 + SET_BIAS v20, v21, v22, v23 + + L8LoopSz_TILE_8: + ld1 {v3.d}[0], [x12], x15 // weight + ld1 {v4.d}[0], [x12], #8 + ld1 {v0.16b, v1.16b}, [x11], x22 // src + // int4->int8 + ushr v5.16b, v3.16b, #4 + and v6.16b, v3.16b, v7.16b + zip1 v3.16b, v5.16b, v6.16b + + .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] + .inst 0x4f80e86a // 
sdot v10.4s, v3.16b, v0.4b[2] + .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] + // int4->int8 + ushr v5.16b, v4.16b, #4 + and v6.16b, v4.16b, v7.16b + zip1 v4.16b, v5.16b, v6.16b + .inst 0x4f81e06c // sdot v12.4s, v3.16b, v1.4b[0] + .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] + .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] + .inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3] + sub x12, x12, x15 + .inst 0x4f80e090 // sdot v16.4s, v4.16b, v0.4b[0] + .inst 0x4fa0e091 // sdot v17.4s, v4.16b, v0.4b[1] + .inst 0x4f80e892 // sdot v18.4s, v4.16b, v0.4b[2] + .inst 0x4fa0e893 // sdot v19.4s, v4.16b, v0.4b[3] + subs x13, x13, #1 + .inst 0x4f81e094 // sdot v20.4s, v4.16b, v1.4b[0] + .inst 0x4fa1e095 // sdot v21.4s, v4.16b, v1.4b[1] + .inst 0x4f81e896 // sdot v22.4s, v4.16b, v1.4b[2] + .inst 0x4fa1e897 // sdot v23.4s, v4.16b, v1.4b[3] + bne L8LoopSz_TILE_8 + + L8LoopSzEnd_TILE_8: + //add x12, x12, x15 + add x12, x27, x15, LSL #1 + sub x14, x14, #2 + + L8Tile8Quan: + ld1 {v0.4s, v1.4s}, [x19], #32 // scale + ld1 {v2.4s, v3.4s}, [x25] // x kernel sum + ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 + Int32ToFloat v16, v17, v18, v19 + Int32ToFloat v20, v21, v22, v23 + MUL_SCALE v0, v8, v9, v10, v11 + MUL_SCALE v0, v12, v13, v14, v15 + MUL_SCALE v1, v16, v17, v18, v19 + MUL_SCALE v1, v20, v21, v22, v23 + + cbz x24, TILE8_L8_MLA + ld1 {v0.4s, v1.4s}, [x24] + MUL_EXTRA_SCALE v0, v8, v9, v10, v11 + MUL_EXTRA_SCALE v1, v12, v13, v14, v15 + MUL_EXTRA_SCALE v0, v16, v17, v18, v19 + MUL_EXTRA_SCALE v1, v20, v21, v22, v23 + + TILE8_L8_MLA: + MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v2, v24, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v3, v24, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v13, v3, v24, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v14, v3, v24, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v15, v3, v24, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v16, v2, v25, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v17, v2, v25, 1 // tile:1, oc:4-7 + MLA_WEIGHTZERO v18, v2, v25, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v19, v2, v25, 3 // tile:3, oc:4-7 + MLA_WEIGHTZERO v20, v3, v25, 0 // tile:4, oc:4-7 + MLA_WEIGHTZERO v21, v3, v25, 1 // tile:5, oc:4-7 + MLA_WEIGHTZERO v22, v3, v25, 2 // tile:6, oc:4-7 + MLA_WEIGHTZERO v23, v3, v25, 3 // tile:7, oc:4-7 + + sub x4, x4, #64 + + cbz x9, TILE8_ADD_DSTV + TILE8_ADD_BIAS: + ld1 {v0.4s, v1.4s}, [x20], #32 + ADD_BIAS_FLOAT v8, v9, v10, v11, v0 + ADD_BIAS_FLOAT v12, v13, v14, v15, v0 + ADD_BIAS_FLOAT v16, v17, v18, v19, v1 + ADD_BIAS_FLOAT v20, v21, v22, v23, v1 + b TILE8_POST + + TILE8_ADD_DSTV: + ld1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x10], #64 + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], x4 + ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x10], #64 + ld1 {v28.4s, v29.4s, v30.4s, v31.4s}, [x10] + ADD_FLOAT v8, v9, v10, v11, v0, v1, v2, v3 + ADD_FLOAT v12, v13, v14, v15, v4, v5, v6, v7 + ADD_FLOAT v16, v17, v18, v19, v24, v25, v26, v27 + ADD_FLOAT v20, v21, v22, v23, v28, v29, v30, v31 + sub x10, x10, #128 + sub x10, x10, x4 + + TILE8_POST: + cbz x23, TILE8_STORE + ld1r {v0.4s}, [x23], #4 // f32 min + ld1r {v1.4s}, [x23] // f32 max + ReLU_FP32 v8, v9, v10, v11, v0, v1 + ReLU_FP32 v12, v13, v14, v15, v0, v1 + ReLU_FP32 v16, v17, v18, v19, v0, v1 + ReLU_FP32 v20, v21, v22, v23, v0, v1 + sub x23, x23, #4 + + TILE8_STORE: + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], 
x4 + st1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10], #64 + st1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x10], x4 + add x4, x4, #64 + + L8Tile8LoopCheck: + cmp x14, #1 + bgt L8LoopDz_TILE_8 + cbz x14, Tile8End + +L4LoopDz_TILE_8: + //ld1 {v0.4s}, [x20], #16 // bias + mov x11, x1 + mov x13, x3 + movi v7.16b, #15 + + SET_BIAS v8, v9, v10, v11 + SET_BIAS v12, v13, v14, v15 + + L4LoopSz_TILE_8: + ld1 {v3.d}[0], [x12], #8 // weight + ld1 {v0.16b, v1.16b}, [x11], x22 // src + // int4->int8 + ushr v5.16b, v3.16b, #4 + and v6.16b, v3.16b, v7.16b + zip1 v3.16b, v5.16b, v6.16b + + .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] + .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] + .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] + subs x13, x13, #1 + .inst 0x4f81e06c // sdot v12.4s, v3.16b, v1.4b[0] + .inst 0x4fa1e06d // sdot v13.4s, v3.16b, v1.4b[1] + .inst 0x4f81e86e // sdot v14.4s, v3.16b, v1.4b[2] + .inst 0x4fa1e86f // sdot v15.4s, v3.16b, v1.4b[3] + bne L4LoopSz_TILE_8 + + L4LoopSzEnd_TILE_8: + + L4Tile8Quan: + ld1 {v0.4s}, [x19], #16 // scale + ld1 {v2.4s, v3.4s}, [x25] // x kernel sum + ld1 {v24.4s}, [x6], #16 // weight quan zeropoint + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 + MUL_SCALE v0, v8, v9, v10, v11 + MUL_SCALE v0, v12, v13, v14, v15 + + cbz x24, TILE8_L4_MLA + ld1 {v0.4s, v1.4s}, [x24] + MUL_EXTRA_SCALE v0, v8, v9, v10, v11 + MUL_EXTRA_SCALE v1, v12, v13, v14, v15 + + TILE8_L4_MLA: + MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v2, v24, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v3, v24, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v13, v3, v24, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v14, v3, v24, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v15, v3, v24, 3 // tile:7, oc:0-3 + + sub x4, x4, #64 + + cbz x9, TILE8_L4_ADD_DSTV + TILE8_L4_ADD_BIAS: + ld1 {v4.4s}, [x20], #16 + ADD_BIAS_FLOAT v8, v9, v10, v11, v4 + ADD_BIAS_FLOAT v12, v13, v14, v15, v4 + b TILE8_L4_POST + + TILE8_L4_ADD_DSTV: + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], #64 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10] + sub x10, x10, #64 + ADD_FLOAT v8, v9, v10, v11, v4, v5, v6, v7 + ADD_FLOAT v12, v13, v14, v15, v16, v17, v18, v19 + + TILE8_L4_POST: + cbz x23, TILE8_L4_STORE + ld1r {v0.4s}, [x23], #4 // f32 min + ld1r {v1.4s}, [x23] // f32 max + ReLU_FP32 v8, v9, v10, v11, v0, v1 + ReLU_FP32 v12, v13, v14, v15, v0, v1 + sub x23, x23, #4 + + TILE8_L4_STORE: + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4 + add x4, x4, #64 + +Tile8End: +cbz x24, Tile8_End_Offset +add x24, x24, #32 + +Tile8_End_Offset: + sub x7, x7, #8 + add x0, x0, x21, LSL #3 + add x1, x1, #32 + add x25, x25, #32 + +TILE_4: + cmp x7, #4 + blt TILE_1 + mov x10, x0 + mov x12, x2 + mov x14, x5 + mov x19, x8 + mov x20, x9 + mov x6, x26 // weightQuantBias + cmp x5, #2 + blt L4LoopDz_TILE_4 +L8LoopDz_TILE_4: + //ld1 {v0.4s, v1.4s}, [x20], #32 // bias + mov x11, x1 + mov x13, x3 + mov x27, x12 + movi v7.16b, #15 + + SET_BIAS v8, v9, v10, v11 + SET_BIAS v12, v13, v14, v15 + + L8LoopSz_TILE_4: + ld1 {v3.d}[0], [x12], x15 // weight + ld1 {v0.16b}, [x11], x22 // src + ld1 {v4.d}[0], [x12], #8 // weight + // int4->int8 + ushr v5.16b, v3.16b, #4 + and v6.16b, v3.16b, v7.16b + zip1 v3.16b, v5.16b, v6.16b + + .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] + .inst 0x4f80e86a // sdot v10.4s, v3.16b, 
v0.4b[2] + .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] + // int4->int8 + ushr v5.16b, v4.16b, #4 + and v6.16b, v4.16b, v7.16b + zip1 v4.16b, v5.16b, v6.16b + subs x13, x13, #1 + sub x12, x12, x15 + .inst 0x4f80e08c // sdot v12.4s, v4.16b, v0.4b[0] + .inst 0x4fa0e08d // sdot v13.4s, v4.16b, v0.4b[1] + .inst 0x4f80e88e // sdot v14.4s, v4.16b, v0.4b[2] + .inst 0x4fa0e88f // sdot v15.4s, v4.16b, v0.4b[3] + bne L8LoopSz_TILE_4 + + L8LoopSzEnd_TILE_4: + //add x12, x12, x15 + add x12, x27, x15, LSL #1 + sub x14, x14, #2 + + L8Tile4Quan: + ld1 {v0.4s, v1.4s}, [x19], #32 // scale + ld1 {v2.4s}, [x25] // x kernel sum + ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 + MUL_SCALE v0, v8, v9, v10, v11 + MUL_SCALE v1, v12, v13, v14, v15 + + cbz x24, TILE4_L8_MLA + ld1 {v0.4s}, [x24] + MUL_EXTRA_SCALE v0, v8, v9, v10, v11 + MUL_EXTRA_SCALE v0, v12, v13, v14, v15 + + TILE4_L8_MLA: + MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v2, v24, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v2, v25, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v13, v2, v25, 1 // tile:1, oc:4-7 + MLA_WEIGHTZERO v14, v2, v25, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v15, v2, v25, 3 // tile:3, oc:4-7 + + cbz x9, TILE4_ADD_DSTV + TILE4_ADD_BIAS: + ld1 {v4.4s, v5.4s}, [x20], #32 + ADD_BIAS_FLOAT v8, v9, v10, v11, v4 + ADD_BIAS_FLOAT v12, v13, v14, v15, v5 + b TILE4_POST + + TILE4_ADD_DSTV: + ld1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x10], x4 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x10] + sub x10, x10, x4 + ADD_FLOAT v8, v9, v10, v11, v4, v5, v6, v7 + ADD_FLOAT v12, v13, v14, v15, v16, v17, v18, v19 + + TILE4_POST: + cbz x23, TILE4_STORE + ld1r {v26.4s}, [x23], #4 // f32 min + ld1r {v27.4s}, [x23] // f32 max + ReLU_FP32 v8, v9, v10, v11, v26, v27 + ReLU_FP32 v12, v13, v14, v15, v26, v27 + sub x23, x23, #4 + + TILE4_STORE: + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], x4 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10], x4 + + L8Tile4LoopCheck: + cmp x14, #1 + bgt L8LoopDz_TILE_4 + cbz x14, Tile4End + +L4LoopDz_TILE_4: + //ld1 {v0.4s}, [x20], #16 // bias + mov x11, x1 + mov x13, x3 + movi v7.16b, #15 + SET_BIAS v8, v9, v10, v11 + + L4LoopSz_TILE_4: + ld1 {v3.d}[0], [x12], #8 // weight + ld1 {v0.16b}, [x11], x22 // src + // int4->int8 + ushr v5.16b, v3.16b, #4 + and v6.16b, v3.16b, v7.16b + zip1 v3.16b, v5.16b, v6.16b + subs x13, x13, #1 + .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + .inst 0x4fa0e069 // sdot v9.4s, v3.16b, v0.4b[1] + .inst 0x4f80e86a // sdot v10.4s, v3.16b, v0.4b[2] + .inst 0x4fa0e86b // sdot v11.4s, v3.16b, v0.4b[3] + bne L4LoopSz_TILE_4 + + L4LoopSzEnd_TILE_4: + + L4Tile4Quan: + ld1 {v0.4s}, [x19], #16 // scale + ld1 {v2.4s}, [x25] // x kernel sum + ld1 {v24.4s}, [x6], #16 // weight quan zeropoint + Int32ToFloat v8, v9, v10, v11 + MUL_SCALE v0, v8, v9, v10, v11 + + cbz x24, TILE4_L4_MLA + ld1 {v0.4s}, [x24] + MUL_EXTRA_SCALE v0, v8, v9, v10, v11 + + TILE4_L4_MLA: + MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v24, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v2, v24, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v11, v2, v24, 3 // tile:3, oc:0-3 + + cbz x9, TILE4_L4_ADD_DSTV + TILE4_L4_ADD_BIAS: + ld1 {v3.4s}, [x20], #16 + ADD_BIAS_FLOAT v8, v9, v10, v11, v3 + b TILE4_L4_POST + + TILE4_L4_ADD_DSTV: + ld1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x10] + ADD_FLOAT v8, v9, v10, v11, v12, v13, v14, v15 + + TILE4_L4_POST: + cbz x23, 
TILE4_L4_STORE + ld1r {v26.4s}, [x23], #4 // f32 min + ld1r {v27.4s}, [x23] // f32 max + ReLU_FP32 v8, v9, v10, v11, v26, v27 + sub x23, x23, #4 + + TILE4_L4_STORE: + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x10], x4 + +Tile4End: +cbz x24, Tile4_End_Offset +add x24, x24, #16 + +Tile4_End_Offset: + sub x7, x7, #4 + add x0, x0, x21, LSL #2 + add x1, x1, #16 + add x25, x25, #16 + +TILE_1: + cbz x7, End + movi v7.16b, #15 + mov x10, x0 + mov x12, x2 + mov x14, x5 + mov x19, x8 + mov x20, x9 + mov x6, x26 // weightQuantBias + cmp x5, #2 + blt L4LoopDz_TILE_1 +L8LoopDz_TILE_1: + //ld1 {v0.4s, v1.4s}, [x20], #32 // bias + mov x11, x1 + mov x13, x3 + mov x27, x12 + + movi v8.16b, #0 + movi v9.16b, #0 + L8LoopSz_TILE_1: + ld1 {v3.d}[0], [x12], x15 // weight + ld1 {v0.s}[0], [x11], x22 // src + ld1 {v4.d}[0], [x12], #8 // weight + // int4->int8 + ushr v5.16b, v3.16b, #4 + and v6.16b, v3.16b, v7.16b + zip1 v3.16b, v5.16b, v6.16b + + .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + subs x13, x13, #1 + // int4->int8 + ushr v5.16b, v4.16b, #4 + and v6.16b, v4.16b, v7.16b + zip1 v4.16b, v5.16b, v6.16b + sub x12, x12, x15 + + .inst 0x4f80e089 // sdot v9.4s, v4.16b, v0.4b[0] + bne L8LoopSz_TILE_1 + + L8LoopSzEnd_TILE_1: + add x12, x27, x15, LSL #1 + sub x14, x14, #2 + + L8Tile1Quan: + ld1 {v0.4s, v1.4s}, [x19], #32 // scale + ld1 {v2.s}[0], [x25] // x kernel sum + ld1 {v24.4s, v25.4s}, [x6], #32 // weight quan zeropoint + scvtf v8.4s, v8.4s + scvtf v9.4s, v9.4s + fmul v8.4s, v8.4s, v0.4s + fmul v9.4s, v9.4s, v1.4s + + cbz x24, TILE1_L8_MLA + ld1 {v0.s}[0], [x24] + fmul v8.4s, v8.4s, v0.s[0] + fmul v9.4s, v9.4s, v0.s[0] + + TILE1_L8_MLA: + MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v9, v2, v25, 0 // tile:0, oc:4-7 + + cbz x9, TILE1_ADD_DSTV + TILE1_ADD_BIAS: + ld1 {v10.4s, v11.4s}, [x20], #32 + fadd v8.4s, v8.4s, v10.4s + fadd v9.4s, v9.4s, v11.4s + b TILE1_POST + + TILE1_ADD_DSTV: + ld1 {v10.4s}, [x10], x4 + ld1 {v11.4s}, [x10] + sub x10, x10, x4 + fadd v8.4s, v8.4s, v10.4s + fadd v9.4s, v9.4s, v11.4s + + TILE1_POST: + cbz x23, TILE1_STORE + ld1r {v26.4s}, [x23], #4 // f32 min + ld1r {v27.4s}, [x23] // f32 max + sub x23, x23, #4 + fmin v8.4s, v8.4s, v27.4s + fmin v9.4s, v9.4s, v27.4s + fmax v8.4s, v8.4s, v26.4s + fmax v9.4s, v9.4s, v26.4s + + TILE1_STORE: + st1 {v8.4s}, [x10], x4 + st1 {v9.4s}, [x10], x4 + + L8Tile1LoopCheck: + cmp x14, #1 + bgt L8LoopDz_TILE_1 + cbz x14, Tile1End + +L4LoopDz_TILE_1: + //ld1 {v0.4s}, [x20], #16 // bias + mov x11, x1 + mov x13, x3 + movi v8.16b, #0 + L4LoopSz_TILE_1: + ld1 {v3.d}[0], [x12], #8 // weight + ld1 {v0.s}[0], [x11], x22 // src + // int4->int8 + ushr v5.16b, v3.16b, #4 + and v6.16b, v3.16b, v7.16b + zip1 v3.16b, v5.16b, v6.16b + subs x13, x13, #1 + + .inst 0x4f80e068 // sdot v8.4s, v3.16b, v0.4b[0] + bne L4LoopSz_TILE_1 + + L4LoopSzEnd_TILE_1: + + L4Tile1Quan: + ld1 {v0.4s}, [x19], #16 // scale + ld1 {v2.s}[0], [x25] // x kernel sum + ld1 {v24.4s}, [x6], #16 // weight quan zeropoint + scvtf v8.4s, v8.4s + fmul v8.4s, v8.4s, v0.4s + + cbz x24, TILE1_L4_MLA + ld1 {v0.s}[0], [x24] + fmul v8.4s, v8.4s, v0.s[0] + + TILE1_L4_MLA: + MLA_WEIGHTZERO v8, v2, v24, 0 // tile:0, oc:0-3 + + cbz x9, TILE1_L4_ADD_DSTV + TILE1_L4_ADD_BIAS: + ld1 {v4.4s}, [x20], #16 + fadd v8.4s, v8.4s, v4.4s + b TILE1_L4_POST + + TILE1_L4_ADD_DSTV: + ld1 {v4.4s}, [x10] + fadd v8.4s, v8.4s, v4.4s + + TILE1_L4_POST: + cbz x23, TILE1_L4_STORE + ld1r {v26.4s}, [x23], #4 // f32 min + ld1r {v27.4s}, [x23] // f32 max + sub x23, x23, #4 + fmax v8.4s, v8.4s, v26.4s + fmin v8.4s, v8.4s, 
v27.4s + TILE1_L4_STORE: + st1 {v8.4s}, [x10], x4 + +Tile1End: +cbz x24, Tile1_End_Offset +add x24, x24, #4 + +Tile1_End_Offset: + sub x7, x7, #1 + add x0, x0, x21 + add x1, x1, #4 + add x25, x25, #4 + b TILE_1 + +End: +ldp x23, x24, [sp, #(16 * 8)] +ldp x25, x26, [sp, #(16 * 7)] +ldp x27, x28, [sp, #(16 * 6)] +ldp x19, x20, [sp, #(16 * 5)] +ldp x21, x22, [sp, #(16 * 4)] +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 10) +ret + +#endif // __aarch64__ diff --git a/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV86_w4_Unit.S b/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV86_w4_Unit.S new file mode 100644 index 000000000..b4cc330c2 --- /dev/null +++ b/source/backend/cpu/arm/arm64/low_memory/MNNGemmInt8AddBiasScale_ARMV86_w4_Unit.S @@ -0,0 +1,1205 @@ +// +// MNNGemmInt8AddBiasScale_ARMV86_w4_Unit.S +// MNN +// +// Created by MNN on 2022/09/26. +// Copyright © 2018, Alibaba Group Holding Limited +// + +#if defined(__aarch64__) +#include "MNNAsmGlobal.h" + +.text +.align 5 + +.macro SET_0_5 d0, d1, d2, d3, d4 + movi \d0\().16b, #0 + movi \d1\().16b, #0 + movi \d2\().16b, #0 + movi \d3\().16b, #0 + movi \d4\().16b, #0 +.endm +.macro SET_0_4 d0, d1, d2, d3 + movi \d0\().16b, #0 + movi \d1\().16b, #0 + movi \d2\().16b, #0 + movi \d3\().16b, #0 +.endm +.macro ADD_BIAS_FLOAT d0, d1, d2, d3, z0 + fadd \d0\().4s, \d0\().4s, \z0\().4s + fadd \d1\().4s, \d1\().4s, \z0\().4s + fadd \d2\().4s, \d2\().4s, \z0\().4s + fadd \d3\().4s, \d3\().4s, \z0\().4s +.endm +.macro ADD_FLOAT d0, d1, d2, d3, s0, s1, s2, s3 + fadd \d0\().4s, \d0\().4s, \s0\().4s + fadd \d1\().4s, \d1\().4s, \s1\().4s + fadd \d2\().4s, \d2\().4s, \s2\().4s + fadd \d3\().4s, \d3\().4s, \s3\().4s +.endm +.macro Int32ToFloat z0, z1, z2, z3 + scvtf \z0\().4s, \z0\().4s + scvtf \z1\().4s, \z1\().4s + scvtf \z2\().4s, \z2\().4s + scvtf \z3\().4s, \z3\().4s +.endm +.macro MUL_SCALE s, d0, d1, d2, d3 + fmul \d0\().4s, \d0\().4s, \s\().4s + fmul \d1\().4s, \d1\().4s, \s\().4s + fmul \d2\().4s, \d2\().4s, \s\().4s + fmul \d3\().4s, \d3\().4s, \s\().4s +.endm +.macro MUL_EXTRA_SCALE s, d0, d1, d2, d3 + fmul \d0\().4s, \d0\().4s, \s\().s[0] + fmul \d1\().4s, \d1\().4s, \s\().s[1] + fmul \d2\().4s, \d2\().4s, \s\().s[2] + fmul \d3\().4s, \d3\().4s, \s\().s[3] +.endm +.macro MLA_WEIGHTZERO d0, s0, s1, idx // idx for xKernelSum + fmla \d0\().4s, \s1\().4s, \s0\().s[\idx] +.endm +.macro ReLU_FP32 s0, s1, s2, s3, z0, z1 // z0:min z1:max + fmin \s0\().4s, \s0\().4s, \z1\().4s + fmin \s1\().4s, \s1\().4s, \z1\().4s + fmin \s2\().4s, \s2\().4s, \z1\().4s + fmin \s3\().4s, \s3\().4s, \z1\().4s + fmax \s0\().4s, \s0\().4s, \z0\().4s + fmax \s1\().4s, \s1\().4s, \z0\().4s + fmax \s2\().4s, \s2\().4s, \z0\().4s + fmax \s3\().4s, \s3\().4s, \z0\().4s +.endm +.macro ReLU_FP32_2 s0, s1, z0, z1 // z0:min z1:max + fmin \s0\().4s, \s0\().4s, \z1\().4s + fmin \s1\().4s, \s1\().4s, \z1\().4s + fmax \s0\().4s, \s0\().4s, \z0\().4s + fmax \s1\().4s, \s1\().4s, \z0\().4s +.endm + +asm_function MNNGemmInt8AddBiasScale_ARMV86_w4_Unit +/* +struct QuanPostTreatParameters { + const float* scale; + const float* biasFloat; + int32_t maxValue; + int32_t minValue; + int32_t useInt8 = 1; // Save result as int8_t dataType; otherwise float32. 
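The w4 kernels in this file expand two packed int4 weights from every byte with the ushr/and/zip1 triple before each sdot. A minimal scalar sketch of that unpack, assuming the high nibble maps to the even-indexed weight and the low nibble to the odd-indexed one (as the zip1 interleave suggests); the helper name is illustrative, not an MNN identifier:

#include <cstdint>

// Hedged sketch of the int4 -> int8 expansion done by
//   ushr v5.16b, vW.16b, #4       // high nibbles
//   and  v6.16b, vW.16b, v7.16b   // low nibbles (v7 = 0x0f)
//   zip1 vW.16b, v5.16b, v6.16b   // interleave back into weight order
// Values stay in [0, 15]; the asymmetric part is handled later by the
// MLA_WEIGHTZERO (kernel-sum x weight-zero-point) correction.
static void unpackInt4Weights(const uint8_t* packed, int8_t* out, int bytes) {
    for (int i = 0; i < bytes; ++i) {
        out[2 * i + 0] = static_cast<int8_t>(packed[i] >> 4);   // high nibble
        out[2 * i + 1] = static_cast<int8_t>(packed[i] & 0x0F); // low nibble
    }
}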
+ float roundValuePos = 0.5f; + float roundValueNeg = -0.5f; + float* srcKernelSum; + float* weightQuanBias; + float* fp32minmax; + ssize_t blockNum; + float* extraScale; + +}; +*/ +//void MNNGemmInt8AddBiasScale_ARMV86_w4_Unit(int8_t* dst, const int8_t* src, +// const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, +// const QuanPostTreatParameters* parameters, size_t realDstCount); + +//Auto: x0:dst, x1:src, x2:weight, x3:src_depth_quad, x4:dst_step +//x5:dst_depth_quad, x6: parameters, x7: realDstCount + +//Load from x6: x8: scale, x9: bias, x27: srcKernelSum, x28: weightQuanBias, +ldr x8, [x6, #0] +ldr x9, [x6, #8] + +stp d14, d15, [sp, #(-16 * 10)]! +stp d12, d13, [sp, #(16 * 1)] +stp d10, d11, [sp, #(16 * 2)] +stp d8, d9, [sp, #(16 * 3)] +stp x21, x22, [sp, #(16 * 4)] +stp x19, x20, [sp, #(16 * 5)] +stp x23, x24, [sp, #(16 * 6)] +stp x25, x26, [sp, #(16 * 7)] +stp x27, x28, [sp, #(16 * 8)] +ldr x27, [x6, #40] // srcKernelSum +ldr x28, [x6, #48] // weightQuanBias + +ldr x22, [x6, #64] // blockNum +mul x22, x22, x3 // UP_DIV(ic*ky*kx, SRC_UNIT) = blockNum * src_depth_quad_per_block +lsl x15, x22, #5 // x15 = src_depth_quad * UNIT * UNIT_SRC = src_depth_quad * 64 * (sizeof(int4)) = src_depth_quad << 4 + +mov x21, #16 // sizeof(float) * pack +ldr x14, [x6, #56] // float32 maxmin ptr +ldr x23, [x6, #80] // extra scale + +Start: +mov x22, #80 // GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT = 10 * 8 = 80 + +TILE_10: + cmp x7, #10 + blt TILE_8 + sub x4, x4, #128 // For float32 output, x4-128 +cmp x5, #2 +blt LoopDz4_TILE_10 + +LoopDz8_TILE_10: + mov x11, x1 // src + mov x12, x2 // weight + mov x13, x3 // src_depth_quad + + SET_0_5 v12, v16, v20, v24, v28 // oc:0,1,0,1 + SET_0_5 v13, v17, v21, v25, v29 // oc:2,3,2,3 + SET_0_5 v14, v18, v22, v26, v30 // oc:4,5,4,5 + SET_0_5 v15, v19, v23, v27, v31 // oc:6,7,6,7 + +LoopSz_TILE_10: + ld1 {v0.16b, v1.16b}, [x12], #32 // weight + movi v2.16b, #15 + ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9 + ld1 {v7.16b}, [x11], #16 + // int4->int8 + + ushr v8.16b, v0.16b, #4 // oc:0-1 + ushr v9.16b, v1.16b, #4 // oc:2-3 + and v10.16b, v0.16b, v2.16b // oc:4-5 + and v11.16b, v1.16b, v2.16b // oc:6-7 + + subs x13, x13, #1 + .inst 0x4e88a46c // smmla v12.4s, v3.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a46d // smmla v13.4s, v3.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + .inst 0x4e8aa46e // smmla v14.4s, v3.16b, v10.16b // tile0-oc4, tile0-oc5, tile1-oc4, tile1-oc5 + .inst 0x4e8ba46f // smmla v15.4s, v3.16b, v11.16b // tile0-oc6, tile0-oc7, tile1-oc6, tile1-oc7 + + .inst 0x4e88a490 // smmla v16.4s, v4.16b, v8.16b // tile2-oc0, tile2-oc1, tile3-oc0, tile3-oc1 + .inst 0x4e89a491 // smmla v17.4s, v4.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 + .inst 0x4e8aa492 // smmla v18.4s, v4.16b, v10.16b // tile2-oc4, tile2-oc5, tile3-oc4, tile3-oc5 + .inst 0x4e8ba493 // smmla v19.4s, v4.16b, v11.16b // tile2-oc6, tile2-oc7, tile3-oc6, tile3-oc7 + + .inst 0x4e88a4b4 // smmla v20.4s, v5.16b, v8.16b // tile4-oc0, tile4-oc1, tile5-oc0, tile5-oc1 + .inst 0x4e89a4b5 // smmla v21.4s, v5.16b, v9.16b // tile4-oc2, tile4-oc3, tile5-oc2, tile5-oc3 + .inst 0x4e8aa4b6 // smmla v22.4s, v5.16b, v10.16b // tile4-oc4, tile4-oc5, tile5-oc4, tile5-oc5 + .inst 0x4e8ba4b7 // smmla v23.4s, v5.16b, v11.16b // tile4-oc6, tile4-oc7, tile5-oc6, tile5-oc7 + + .inst 0x4e88a4d8 // smmla v24.4s, v6.16b, v8.16b // tile6-oc0, tile6-oc1, tile7-oc0, tile7-oc1 + .inst 0x4e89a4d9 // smmla v25.4s, v6.16b, 
v9.16b // tile6-oc2, tile6-oc3, tile7-oc2, tile7-oc3 + .inst 0x4e8aa4da // smmla v26.4s, v6.16b, v10.16b // tile6-oc4, tile6-oc5, tile7-oc4, tile7-oc5 + .inst 0x4e8ba4db // smmla v27.4s, v6.16b, v11.16b // tile6-oc6, tile6-oc7, tile7-oc6, tile7-oc7 + + .inst 0x4e88a4fc // smmla v28.4s, v7.16b, v8.16b // tile8-oc0, tile8-oc1, tile9-oc0, tile9-oc1 + .inst 0x4e89a4fd // smmla v29.4s, v7.16b, v9.16b // tile8-oc2, tile8-oc3, tile9-oc2, tile9-oc3 + .inst 0x4e8aa4fe // smmla v30.4s, v7.16b, v10.16b // tile8-oc4, tile8-oc5, tile9-oc4, tile9-oc5 + .inst 0x4e8ba4ff // smmla v31.4s, v7.16b, v11.16b // tile8-oc6, tile8-oc7, tile9-oc6, tile9-oc7 + bne LoopSz_TILE_10 +LoopSzEnd_TILE_10: + add x2, x2, x15 // weight += dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT * 0.5); + sub x5, x5, #2 // dz-2 + // transpose + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v2.2d, v16.2d, v17.2d + uzp2 v3.2d, v16.2d, v17.2d + uzp1 v4.2d, v20.2d, v21.2d + uzp2 v5.2d, v20.2d, v21.2d + uzp1 v6.2d, v24.2d, v25.2d + uzp2 v7.2d, v24.2d, v25.2d + uzp1 v8.2d, v28.2d, v29.2d + uzp2 v9.2d, v28.2d, v29.2d + + uzp1 v10.2d, v14.2d, v15.2d // E0: oc:4-7 + uzp2 v11.2d, v14.2d, v15.2d // E1: oc:4-7 + uzp1 v12.2d, v18.2d, v19.2d + uzp2 v13.2d, v18.2d, v19.2d + uzp1 v14.2d, v22.2d, v23.2d + uzp2 v15.2d, v22.2d, v23.2d + uzp1 v16.2d, v26.2d, v27.2d + uzp2 v17.2d, v26.2d, v27.2d + uzp1 v18.2d, v30.2d, v31.2d + uzp2 v19.2d, v30.2d, v31.2d + + Int32ToFloat v0, v1, v2, v3 + Int32ToFloat v4, v5, v6, v7 + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 + Int32ToFloat v16, v17, v18, v19 + +Tile10Quan: + ld1 {v20.4s, v21.4s}, [x8], #32 // scale + ld1 {v22.4s, v23.4s}, [x27], #32 // x kernel sum + ld1 {v24.d}[0], [x27] + ld1 {v25.4s, v26.4s}, [x28], #32 // weight quan zeropoint + sub x27, x27, #32 + MUL_SCALE v20, v0, v1, v2, v3 + MUL_SCALE v20, v4, v5, v6, v7 + MUL_SCALE v21, v10, v11, v12, v13 + MUL_SCALE v21, v14, v15, v16, v17 + fmul v8.4s, v8.4s, v20.4s + fmul v9.4s, v9.4s, v20.4s + fmul v18.4s, v18.4s, v21.4s + fmul v19.4s, v19.4s, v21.4s + + cbz x23, TILE10_MLA + ld1 {v27.4s, v28.4s}, [x23], #32 + ld1 {v29.d}[0], [x23] + MUL_EXTRA_SCALE v27, v0, v1, v2, v3 + MUL_EXTRA_SCALE v28, v4, v5, v6, v7 + MUL_EXTRA_SCALE v27, v10, v11, v12, v13 + MUL_EXTRA_SCALE v28, v14, v15, v16, v17 + fmul v8.4s, v8.4s, v29.s[0] + fmul v9.4s, v9.4s, v29.s[1] + fmul v18.4s, v18.4s, v29.s[0] + fmul v19.4s, v19.4s, v29.s[1] + sub x23, x23, #32 + + TILE10_MLA: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v10, v22, v26, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v11, v22, v26, 1 // tile:1, oc:4-7 + + MLA_WEIGHTZERO v2, v22, v25, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v3, v22, v25, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v12, v22, v26, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v13, v22, v26, 3 // tile:3, oc:4-7 + + MLA_WEIGHTZERO v4, v23, v25, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v5, v23, v25, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v14, v23, v26, 0 // tile:4, oc:4-7 + MLA_WEIGHTZERO v15, v23, v26, 1 // tile:5, oc:4-7 + + MLA_WEIGHTZERO v6, v23, v25, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v7, v23, v25, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v16, v23, v26, 2 // tile:6, oc:4-7 + MLA_WEIGHTZERO v17, v23, v26, 3 // tile:7, oc:4-7 + + MLA_WEIGHTZERO v8, v24, v25, 0 // tile:8, oc:0-3 + MLA_WEIGHTZERO v9, v24, v25, 1 // tile:9, oc:0-3 + MLA_WEIGHTZERO v18, v24, v26, 0 // tile:8, oc:4-7 + MLA_WEIGHTZERO v19, v24, v26, 1 // tile:9, oc:4-7 + + 
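The Tile10Quan sequence above is the dequantization epilogue: the int32 accumulators are converted to float, scaled per output channel, optionally multiplied by the per-tile block scale (extraScale), and corrected by the per-tile kernel sum times the per-channel weight zero point; the bias (or the previously stored destination, when the bias pointer is null) is added next and the result is clamped against fp32minmax. A scalar sketch of that arithmetic for one (tile, oc) element, with illustrative names rather than the exact MNN identifiers:

#include <algorithm>
#include <cstdint>

// Hedged model of Int32ToFloat + MUL_SCALE + MUL_EXTRA_SCALE + MLA_WEIGHTZERO
// + ADD_BIAS_FLOAT/ADD_FLOAT + ReLU_FP32 for a single output element.
static float dequantOne(int32_t acc, float scale, float extraScale,
                        float xKernelSum, float weightZero,
                        float biasOrPrevDst, float fp32min, float fp32max) {
    float v = static_cast<float>(acc) * scale;      // Int32ToFloat + MUL_SCALE
    v *= extraScale;                                // MUL_EXTRA_SCALE (treat as 1.0f when the pointer is null)
    v += xKernelSum * weightZero;                   // MLA_WEIGHTZERO zero-point correction
    v += biasOrPrevDst;                             // TILE*_ADD_BIAS, or the existing dst value
    return std::min(std::max(v, fp32min), fp32max); // ReLU_FP32 clamp when fp32minmax is set
}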
TILE10_ADD_BIAS: + cbz x9, TILE10_ADD_DSTV + ld1 {v20.4s, v21.4s}, [x9], #32 // bias + ADD_BIAS_FLOAT v0, v1, v2, v3, v20 + ADD_BIAS_FLOAT v4, v5, v6, v7, v20 + ADD_BIAS_FLOAT v10, v11, v12, v13, v21 + ADD_BIAS_FLOAT v14, v15, v16, v17, v21 + fadd v8.4s, v8.4s, v20.4s + fadd v9.4s, v9.4s, v20.4s + fadd v18.4s, v18.4s, v21.4s + fadd v19.4s, v19.4s, v21.4s + b TILE10_POST + + TILE10_ADD_DSTV: + // first batch10 + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 + ld1 {v28.4s, v29.4s}, [x0], x4 + ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23 + ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27 + fadd v8.4s, v8.4s, v28.4s + fadd v9.4s, v9.4s, v29.4s + + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 + ld1 {v28.4s, v29.4s}, [x0] + ADD_FLOAT v10, v11, v12, v13, v20, v21, v22, v23 + ADD_FLOAT v14, v15, v16, v17, v24, v25, v26, v27 + fadd v18.4s, v18.4s, v28.4s + fadd v19.4s, v19.4s, v29.4s + + sub x0, x0, #256 + sub x0, x0, x4 + + TILE10_POST: + cbz x14, TILE10_STORE + ld1r {v30.4s}, [x14], #4 // f32 min + ld1r {v31.4s}, [x14] // f32 max + ReLU_FP32 v0, v1, v2, v3, v30, v31 + ReLU_FP32 v4, v5, v6, v7, v30, v31 + ReLU_FP32 v8, v9, v10, v11, v30, v31 + ReLU_FP32 v12, v13, v14, v15, v30, v31 + ReLU_FP32 v16, v17, v18, v19, v30, v31 + sub x14, x14, #4 + + TILE10_STORE: + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 + st1 {v8.4s, v9.4s}, [x0], x4 + st1 {v10.4s, v11.4s, v12.4s, v13.4s}, [x0], #64 + st1 {v14.4s, v15.4s, v16.4s, v17.4s}, [x0], #64 + st1 {v18.4s, v19.4s}, [x0], x4 + +Tile10LoopCheck: + cmp x5, #2 + bge LoopDz8_TILE_10 + cbz x5, End + +LoopDz4_TILE_10: + mov x11, x1 // src + mov x12, x2 // weight + mov x13, x3 // src_depth_quad + + SET_0_5 v12, v13, v16, v17, v20 + SET_0_5 v21, v24, v25, v28, v29 + +LoopSz4_TILE_10: + ld1 {v0.16b, v1.16b}, [x12], #32 // weight + ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], #64 // src: E0-E9 + ld1 {v7.16b}, [x11], #16 + subs x13, x13, #1 + + // int4->int8 + ushr v8.16b, v0.16b, #4 // oc:0-1 + ushr v9.16b, v1.16b, #4 // oc:2-3 + + .inst 0x4e88a46c // smmla v12.4s, v3.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a46d // smmla v13.4s, v3.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + + .inst 0x4e88a490 // smmla v16.4s, v4.16b, v8.16b // tile2-oc0, tile2-oc1, tile3-oc0, tile3-oc1 + .inst 0x4e89a491 // smmla v17.4s, v4.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 + + .inst 0x4e88a4b4 // smmla v20.4s, v5.16b, v8.16b // tile4-oc0, tile4-oc1, tile5-oc0, tile5-oc1 + .inst 0x4e89a4b5 // smmla v21.4s, v5.16b, v9.16b // tile4-oc2, tile4-oc3, tile5-oc2, tile5-oc3 + + .inst 0x4e88a4d8 // smmla v24.4s, v6.16b, v8.16b // tile6-oc0, tile6-oc1, tile7-oc0, tile7-oc1 + .inst 0x4e89a4d9 // smmla v25.4s, v6.16b, v9.16b // tile6-oc2, tile6-oc3, tile7-oc2, tile7-oc3 + + .inst 0x4e88a4fc // smmla v28.4s, v7.16b, v8.16b // tile8-oc0, tile8-oc1, tile9-oc0, tile9-oc1 + .inst 0x4e89a4fd // smmla v29.4s, v7.16b, v9.16b // tile8-oc2, tile8-oc3, tile9-oc2, tile9-oc3 + bne LoopSz4_TILE_10 +LoopSz4End_TILE_10: + // transpose + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v2.2d, v16.2d, v17.2d + uzp2 v3.2d, v16.2d, v17.2d + uzp1 v4.2d, v20.2d, v21.2d + uzp2 v5.2d, v20.2d, v21.2d + uzp1 v6.2d, v24.2d, v25.2d + uzp2 v7.2d, v24.2d, v25.2d + uzp1 v8.2d, v28.2d, v29.2d + uzp2 v9.2d, v28.2d, v29.2d + + Int32ToFloat v0, v1, v2, v3 + Int32ToFloat v4, v5, v6, v7 + 
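The smmla comments above (tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1, and so on) describe the 2x2 int32 block each instruction accumulates, which is why the epilogue needs the uzp1/uzp2 pass before dequantization. A scalar model of one such step under that layout; the function name is illustrative:

#include <cstdint>

// Hedged model of one `smmla vd.4s, vsrc.16b, vwgt.16b` as used in the loops
// above: vsrc packs two tiles x 8 int8 activations, vwgt packs two output
// channels x 8 (unpacked int4) weights, and the four int32 lanes of vd
// accumulate [t0*oc0, t0*oc1, t1*oc0, t1*oc1].
static void smmlaStep(int32_t acc[4], const int8_t src[16], const int8_t wgt[16]) {
    for (int t = 0; t < 2; ++t) {            // two tiles (rows of src)
        for (int oc = 0; oc < 2; ++oc) {     // two output channels (rows of wgt)
            int32_t sum = 0;
            for (int k = 0; k < 8; ++k) {
                sum += static_cast<int32_t>(src[t * 8 + k]) *
                       static_cast<int32_t>(wgt[oc * 8 + k]);
            }
            acc[t * 2 + oc] += sum;
        }
    }
}

With that layout, uzp1 over the .2d views gathers the low 64-bit halves (tile 0, four output channels) and uzp2 the high halves (tile 1), which matches the E0/E1 comments on the transpose block above.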
scvtf v8.4s, v8.4s + scvtf v9.4s, v9.4s + +Tile10Quan_L4: + ld1 {v20.4s}, [x8] // scale + ld1 {v22.4s, v23.4s}, [x27], #32 // x kernel sum + ld1 {v24.d}[0], [x27] + ld1 {v25.4s}, [x28] // weight quan zeropoint + MUL_SCALE v20, v0, v1, v2, v3 + MUL_SCALE v20, v4, v5, v6, v7 + fmul v8.4s, v8.4s, v20.4s + fmul v9.4s, v9.4s, v20.4s + + cbz x23, TILE10_MLA_L4 + ld1 {v27.4s, v28.4s}, [x23], #32 + ld1 {v29.d}[0], [x23] + MUL_EXTRA_SCALE v27, v0, v1, v2, v3 + MUL_EXTRA_SCALE v28, v4, v5, v6, v7 + fmul v8.4s, v8.4s, v29.s[0] + fmul v9.4s, v9.4s, v29.s[1] + + TILE10_MLA_L4: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v2, v22, v25, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v3, v22, v25, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v4, v23, v25, 0 // tile:4, oc:0-3 + MLA_WEIGHTZERO v5, v23, v25, 1 // tile:5, oc:0-3 + MLA_WEIGHTZERO v6, v23, v25, 2 // tile:6, oc:0-3 + MLA_WEIGHTZERO v7, v23, v25, 3 // tile:7, oc:0-3 + MLA_WEIGHTZERO v8, v24, v25, 0 // tile:8, oc:0-3 + MLA_WEIGHTZERO v9, v24, v25, 1 // tile:9, oc:0-3 + //sub x4, x4, #128 + + TILE10_ADD_BIAS_L4: + cbz x9, TILE10_ADD_DSTV_L4 + ld1 {v20.4s}, [x9] // bias + ADD_BIAS_FLOAT v0, v1, v2, v3, v20 + ADD_BIAS_FLOAT v4, v5, v6, v7, v20 + fadd v8.4s, v8.4s, v20.4s + fadd v9.4s, v9.4s, v20.4s + b TILE10_POST_L4 + + TILE10_ADD_DSTV_L4: + // first batch10 + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x0], #64 + ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x0], #64 + ld1 {v28.4s, v29.4s}, [x0] + ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23 + ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27 + fadd v8.4s, v8.4s, v28.4s + fadd v9.4s, v9.4s, v29.4s + + sub x0, x0, #128 + + TILE10_POST_L4: + cbz x14, TILE10_STORE_L4 + ld1r {v30.4s}, [x14], #4 // f32 min + ld1r {v31.4s}, [x14] // f32 max + ReLU_FP32 v0, v1, v2, v3, v30, v31 + ReLU_FP32 v4, v5, v6, v7, v30, v31 + fmax v8.4s, v8.4s, v30.4s + fmax v9.4s, v9.4s, v30.4s + fmin v8.4s, v8.4s, v31.4s + fmin v9.4s, v9.4s, v31.4s + sub x14, x14, #4 + + TILE10_STORE_L4: + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x0], #64 + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x0], #64 + st1 {v8.4s, v9.4s}, [x0], x4 + b End + +TILE_8: + // post parameters initilize + cbz x14, TILE_Remain + ld1r {v30.4s}, [x14], #4 // f32 min + ld1r {v31.4s}, [x14] // f32 max + + TILE_Remain: + movi v28.16b, #15 + cmp x7, #8 + blt TILE_4 + sub x4, x4, #64 // For float32 output + + TILE8_START: + mov x24, x5 // dst_depth_quad + mov x26, x0 // dst + mov x25, x2 // weight + mov x19, x8 // scale + mov x20, x9 // bias + mov x6, x28 // weightQuanBias + +cmp x5, #2 +blt LoopDz4_TILE_8 +LoopDz_TILE_8: + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + SET_0_4 v12, v16, v20, v24 + SET_0_4 v13, v17, v21, v25 + SET_0_4 v14, v18, v22, v26 + SET_0_4 v15, v19, v23, v27 +LoopSz_TILE_8: + ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], x22 // src: E0-E7 + + // int4->int8 + ld1 {v0.16b, v1.16b}, [x12], #32 // weight + ushr v8.16b, v0.16b, #4 // oc:0-1 + ushr v9.16b, v1.16b, #4 // oc:2-3 + and v10.16b, v0.16b, v28.16b // oc:4-5 + and v11.16b, v1.16b, v28.16b // oc:6-7 + + subs x13, x13, #1 + .inst 0x4e88a46c // smmla v12.4s, v3.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a46d // smmla v13.4s, v3.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + .inst 0x4e8aa46e // smmla v14.4s, v3.16b, v10.16b // tile0-oc4, tile0-oc5, tile1-oc4, tile1-oc5 + .inst 0x4e8ba46f // smmla v15.4s, v3.16b, v11.16b // tile0-oc6, tile0-oc7, tile1-oc6, tile1-oc7 + + .inst 0x4e88a490 // smmla v16.4s, 
v4.16b, v8.16b // tile2-oc0, tile2-oc1, tile3-oc0, tile3-oc1 + .inst 0x4e89a491 // smmla v17.4s, v4.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 + .inst 0x4e8aa492 // smmla v18.4s, v4.16b, v10.16b // tile2-oc4, tile2-oc5, tile3-oc4, tile3-oc5 + .inst 0x4e8ba493 // smmla v19.4s, v4.16b, v11.16b // tile2-oc6, tile2-oc7, tile3-oc6, tile3-oc7 + + .inst 0x4e88a4b4 // smmla v20.4s, v5.16b, v8.16b // tile4-oc0, tile4-oc1, tile5-oc0, tile5-oc1 + .inst 0x4e89a4b5 // smmla v21.4s, v5.16b, v9.16b // tile4-oc2, tile4-oc3, tile5-oc2, tile5-oc3 + .inst 0x4e8aa4b6 // smmla v22.4s, v5.16b, v10.16b // tile4-oc4, tile4-oc5, tile5-oc4, tile5-oc5 + .inst 0x4e8ba4b7 // smmla v23.4s, v5.16b, v11.16b // tile4-oc6, tile4-oc7, tile5-oc6, tile5-oc7 + + .inst 0x4e88a4d8 // smmla v24.4s, v6.16b, v8.16b // tile6-oc0, tile6-oc1, tile7-oc0, tile7-oc1 + .inst 0x4e89a4d9 // smmla v25.4s, v6.16b, v9.16b // tile6-oc2, tile6-oc3, tile7-oc2, tile7-oc3 + .inst 0x4e8aa4da // smmla v26.4s, v6.16b, v10.16b // tile6-oc4, tile6-oc5, tile7-oc4, tile7-oc5 + .inst 0x4e8ba4db // smmla v27.4s, v6.16b, v11.16b // tile6-oc6, tile6-oc7, tile7-oc6, tile7-oc7 + bne LoopSz_TILE_8 + +LoopSzEnd_TILE_8: + add x25, x25, x15 + sub x24, x24, #2 // dz-2 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v8.2d, v14.2d, v15.2d // E0: oc:4-7 + uzp2 v9.2d, v14.2d, v15.2d // E1: oc:4-7 + + uzp1 v2.2d, v16.2d, v17.2d // E2: oc:0-3 + uzp2 v3.2d, v16.2d, v17.2d // E3: oc:0-3 + uzp1 v10.2d, v18.2d, v19.2d // E2: oc:4-7 + uzp2 v11.2d, v18.2d, v19.2d // E3: oc:4-7 + + uzp1 v4.2d, v20.2d, v21.2d // E4: oc:0-3 + uzp2 v5.2d, v20.2d, v21.2d // E5: oc:0-3 + uzp1 v12.2d, v22.2d, v23.2d // E4: oc:4-7 + uzp2 v13.2d, v22.2d, v23.2d // E5: oc:4-7 + + uzp1 v6.2d, v24.2d, v25.2d // E6: oc:0-3 + uzp2 v7.2d, v24.2d, v25.2d // E7: oc:0-3 + uzp1 v14.2d, v26.2d, v27.2d // E6: oc:4-7 + uzp2 v15.2d, v26.2d, v27.2d // E7: oc:4-7 + Int32ToFloat v0, v1, v2, v3 + Int32ToFloat v4, v5, v6, v7 + Int32ToFloat v8, v9, v10, v11 + Int32ToFloat v12, v13, v14, v15 + +Tile8Quan: + ld1 {v20.4s, v21.4s}, [x19], #32 // scale + ld1 {v22.4s, v23.4s}, [x27] // x kernel sum + ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint + MUL_SCALE v20, v0, v1, v2, v3 + MUL_SCALE v20, v4, v5, v6, v7 + MUL_SCALE v21, v8, v9, v10, v11 + MUL_SCALE v21, v12, v13, v14, v15 + + cbz x23, TILE8_MLA + ld1 {v18.4s, v19.4s}, [x23] + MUL_EXTRA_SCALE v18, v0, v1, v2, v3 + MUL_EXTRA_SCALE v19, v4, v5, v6, v7 + MUL_EXTRA_SCALE v18, v8, v9, v10, v11 + MUL_EXTRA_SCALE v19, v12, v13, v14, v15 + + TILE8_MLA: + MLA_WEIGHTZERO v0, v22, v25, 0 + MLA_WEIGHTZERO v1, v22, v25, 1 + MLA_WEIGHTZERO v2, v22, v25, 2 + MLA_WEIGHTZERO v3, v22, v25, 3 + MLA_WEIGHTZERO v4, v23, v25, 0 + MLA_WEIGHTZERO v5, v23, v25, 1 + MLA_WEIGHTZERO v6, v23, v25, 2 + MLA_WEIGHTZERO v7, v23, v25, 3 + + MLA_WEIGHTZERO v8, v22, v26, 0 + MLA_WEIGHTZERO v9, v22, v26, 1 + MLA_WEIGHTZERO v10, v22, v26, 2 + MLA_WEIGHTZERO v11, v22, v26, 3 + MLA_WEIGHTZERO v12, v23, v26, 0 + MLA_WEIGHTZERO v13, v23, v26, 1 + MLA_WEIGHTZERO v14, v23, v26, 2 + MLA_WEIGHTZERO v15, v23, v26, 3 + + cbz x9, TILE8_ADD_DSTV + TILE8_ADD_BIAS: + ld1 {v16.4s, v17.4s}, [x20], #32 + ADD_BIAS_FLOAT v0, v1, v2, v3, v16 + ADD_BIAS_FLOAT v4, v5, v6, v7, v16 + ADD_BIAS_FLOAT v8, v9, v10, v11, v17 + ADD_BIAS_FLOAT v12, v13, v14, v15, v17 + b TILE8_POST + + TILE8_ADD_DSTV: + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x26], #64 + ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x26], x4 + ld1 {v16.4s, v17.4s, v18.4s, v19.4s}, [x26], #64 + ADD_FLOAT v0, v1, 
v2, v3, v20, v21, v22, v23 + ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27 + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x26] + ADD_FLOAT v8, v9, v10, v11, v16, v17, v18, v19 + ADD_FLOAT v12, v13, v14, v15, v20, v21, v22, v23 + sub x26, x26, x4 + sub x26, x26, #128 + + TILE8_POST: + cbz x14, TILE8_STORE + ReLU_FP32 v0, v1, v2, v3, v30, v31 + ReLU_FP32 v4, v5, v6, v7, v30, v31 + ReLU_FP32 v8, v9, v10, v11, v30, v31 + ReLU_FP32 v12, v13, v14, v15, v30, v31 + + TILE8_STORE: + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x26], #64 + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x26], x4 + st1 {v8.4s, v9.4s, v10.4s, v11.4s}, [x26], #64 + st1 {v12.4s, v13.4s, v14.4s, v15.4s}, [x26], x4 + b Tile8LoopCheck + +Tile8LoopCheck: + cmp x24, #2 + bge LoopDz_TILE_8 + cbz x24, Tile8Check + +LoopDz4_TILE_8: + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + SET_0_4 v12, v13, v16, v17 + SET_0_4 v20, v21, v24, v25 +LoopSz4_TILE_8: + ld1 {v3.16b, v4.16b, v5.16b, v6.16b}, [x11], x22 // src: E0-E7 + subs x13, x13, #1 + // int4->int8 + ld1 {v0.16b, v1.16b}, [x12], #32 // weight + ushr v8.16b, v0.16b, #4 // oc:0-1 + ushr v9.16b, v1.16b, #4 // oc:2-3 + + .inst 0x4e88a46c // smmla v12.4s, v3.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a46d // smmla v13.4s, v3.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + + .inst 0x4e88a490 // smmla v16.4s, v4.16b, v8.16b // tile2-oc0, tile2-oc1, tile3-oc0, tile3-oc1 + .inst 0x4e89a491 // smmla v17.4s, v4.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 + + .inst 0x4e88a4b4 // smmla v20.4s, v5.16b, v8.16b // tile4-oc0, tile4-oc1, tile5-oc0, tile5-oc1 + .inst 0x4e89a4b5 // smmla v21.4s, v5.16b, v9.16b // tile4-oc2, tile4-oc3, tile5-oc2, tile5-oc3 + + .inst 0x4e88a4d8 // smmla v24.4s, v6.16b, v8.16b // tile6-oc0, tile6-oc1, tile7-oc0, tile7-oc1 + .inst 0x4e89a4d9 // smmla v25.4s, v6.16b, v9.16b // tile6-oc2, tile6-oc3, tile7-oc2, tile7-oc3 + bne LoopSz4_TILE_8 + +LoopSz4End_TILE_8: + add x25, x25, x15 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v2.2d, v16.2d, v17.2d // E2: oc:0-3 + uzp2 v3.2d, v16.2d, v17.2d // E3: oc:0-3 + uzp1 v4.2d, v20.2d, v21.2d // E4: oc:0-3 + uzp2 v5.2d, v20.2d, v21.2d // E5: oc:0-3 + uzp1 v6.2d, v24.2d, v25.2d // E6: oc:0-3 + uzp2 v7.2d, v24.2d, v25.2d // E7: oc:0-3 + Int32ToFloat v0, v1, v2, v3 + Int32ToFloat v4, v5, v6, v7 + +Tile8Quan_L4: + ld1 {v20.4s}, [x19] // scale + ld1 {v22.4s, v23.4s}, [x27] // x kernel sum + ld1 {v25.4s}, [x6] // weight quan zeropoint + MUL_SCALE v20, v0, v1, v2, v3 + MUL_SCALE v20, v4, v5, v6, v7 + + cbz x23, TILE8_MLA_L4 + ld1 {v18.4s, v19.4s}, [x23] + MUL_EXTRA_SCALE v18, v0, v1, v2, v3 + MUL_EXTRA_SCALE v19, v4, v5, v6, v7 + + TILE8_MLA_L4: + MLA_WEIGHTZERO v0, v22, v25, 0 + MLA_WEIGHTZERO v1, v22, v25, 1 + MLA_WEIGHTZERO v2, v22, v25, 2 + MLA_WEIGHTZERO v3, v22, v25, 3 + MLA_WEIGHTZERO v4, v23, v25, 0 + MLA_WEIGHTZERO v5, v23, v25, 1 + MLA_WEIGHTZERO v6, v23, v25, 2 + MLA_WEIGHTZERO v7, v23, v25, 3 + + cbz x9, TILE8_ADD_DSTV_L4 + TILE8_ADD_BIAS_L4: + ld1 {v16.4s}, [x20] + ADD_BIAS_FLOAT v0, v1, v2, v3, v16 + ADD_BIAS_FLOAT v4, v5, v6, v7, v16 + b TILE8_POST_L4 + + TILE8_ADD_DSTV_L4: + ld1 {v20.4s, v21.4s, v22.4s, v23.4s}, [x26], #64 + ld1 {v24.4s, v25.4s, v26.4s, v27.4s}, [x26] + ADD_FLOAT v0, v1, v2, v3, v20, v21, v22, v23 + ADD_FLOAT v4, v5, v6, v7, v24, v25, v26, v27 + sub x26, x26, #64 + + TILE8_POST_L4: + cbz x14, TILE8_STORE_L4 + ReLU_FP32 v0, v1, v2, v3, v30, v31 + ReLU_FP32 v4, v5, v6, v7, v30, v31 + + TILE8_STORE_L4: + 
st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x26], #64 + st1 {v4.4s, v5.4s, v6.4s, v7.4s}, [x26], x4 + b Tile8Check + +Tile8Check: +cbz x23, Tile8End +add x23, x23, #32 + +Tile8End: + sub x7, x7, #8 + add x0, x0, x21, LSL #3 + add x1, x1, #64 + add x27, x27, #32 + add x4, x4, #64 // Revert x4 for following tile. + +TILE_4: + cmp x7, #4 + blt TILE_2 + mov x24, x5 // dst_depth_quad + mov x26, x0 // dst + mov x25, x2 // weight + mov x19, x8 // scale + mov x20, x9 // bias + mov x6, x28 // weightQuanBias + +cmp x5, #2 +blt LoopDz4_TILE_4 +LoopDz_TILE_4: + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + SET_0_4 v12, v13, v14, v15 + SET_0_4 v16, v17, v18, v19 + +LoopSz_TILE_4: + // int4->int8 + ld1 {v0.16b, v1.16b}, [x12], #32 // weight + ushr v8.16b, v0.16b, #4 // oc:0-1 + ushr v9.16b, v1.16b, #4 // oc:2-3 + and v10.16b, v0.16b, v28.16b // oc:4-5 + and v11.16b, v1.16b, v28.16b // oc:6-7 + + ld1 {v4.16b, v5.16b}, [x11], x22 // src + subs x13, x13, #1 + .inst 0x4e88a48c // smmla v12.4s, v4.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a48d // smmla v13.4s, v4.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + .inst 0x4e8aa48e // smmla v14.4s, v4.16b, v10.16b // tile0-oc4, tile0-oc5, tile1-oc4, tile1-oc5 + .inst 0x4e8ba48f // smmla v15.4s, v4.16b, v11.16b // tile0-oc6, tile0-oc7, tile1-oc6, tile1-oc7 + + .inst 0x4e88a4b0 // smmla v16.4s, v5.16b, v8.16b // tile2-oc0, tile2-oc1, tile3-oc0, tile3-oc1 + .inst 0x4e89a4b1 // smmla v17.4s, v5.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 + .inst 0x4e8aa4b2 // smmla v18.4s, v5.16b, v10.16b // tile2-oc4, tile2-oc5, tile3-oc4, tile3-oc5 + .inst 0x4e8ba4b3 // smmla v19.4s, v5.16b, v11.16b // tile2-oc6, tile2-oc7, tile3-oc6, tile3-oc7 + bne LoopSz_TILE_4 +LoopSzEnd_TILE_4: + add x25, x25, x15 + sub x24, x24, #2 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v4.2d, v14.2d, v15.2d // E0: oc:4-7 + uzp2 v5.2d, v14.2d, v15.2d // E1: oc:4-7 + + uzp1 v2.2d, v16.2d, v17.2d + uzp2 v3.2d, v16.2d, v17.2d + uzp1 v6.2d, v18.2d, v19.2d + uzp2 v7.2d, v18.2d, v19.2d + Int32ToFloat v0, v1, v2, v3 + Int32ToFloat v4, v5, v6, v7 + +Tile4Quan: + ld1 {v20.4s, v21.4s}, [x19], #32 // scale + ld1 {v22.4s}, [x27] // x kernel sum + ld1 {v25.4s, v26.4s}, [x6], #32 // weight quan zeropoint + MUL_SCALE v20, v0, v1, v2, v3 + MUL_SCALE v21, v4, v5, v6, v7 + + cbz x23, TILE4_MLA + ld1 {v27.4s}, [x23] + MUL_EXTRA_SCALE v27, v0, v1, v2, v3 + MUL_EXTRA_SCALE v27, v4, v5, v6, v7 + + TILE4_MLA: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v2, v22, v25, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v3, v22, v25, 3 // tile:3, oc:0-3 + MLA_WEIGHTZERO v4, v22, v26, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v5, v22, v26, 1 // tile:1, oc:4-7 + MLA_WEIGHTZERO v6, v22, v26, 2 // tile:2, oc:4-7 + MLA_WEIGHTZERO v7, v22, v26, 3 // tile:3, oc:4-7 + + TILE4_ADD_BIAS: + cbz x9, TILE4_ADD_DSTV + ld1 {v16.4s, v17.4s}, [x20], #32 // bias + ADD_BIAS_FLOAT v0, v1, v2, v3, v16 + ADD_BIAS_FLOAT v4, v5, v6, v7, v17 + b TILE4_POST + + TILE4_ADD_DSTV: + ld1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x26], x4 + ld1 {v19.4s, v20.4s, v21.4s, v22.4s}, [x26] + ADD_FLOAT v0, v1, v2, v3, v15, v16, v17, v18 + ADD_FLOAT v4, v5, v6, v7, v19, v20, v21, v22 + sub x26, x26, x4 + + TILE4_POST: + cbz x14, TILE4_STORE + ReLU_FP32 v0, v1, v2, v3, v30, v31 + ReLU_FP32 v4, v5, v6, v7, v30, v31 + + TILE4_STORE: + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x26], x4 + st1 {v4.4s, v5.4s, 
v6.4s, v7.4s}, [x26], x4 + b Tile4LoopCheck + +Tile4LoopCheck: + cmp x24, #2 + bge LoopDz_TILE_4 + cbz x24, Tile4Check + +LoopDz4_TILE_4: + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + SET_0_4 v12, v13, v16, v17 +LoopSz4_TILE_4: + ld1 {v4.16b, v5.16b}, [x11], x22 // src + // int4->int8 + ld1 {v0.16b, v1.16b}, [x12], #32 // weight + ushr v8.16b, v0.16b, #4 // oc:0-1 + ushr v9.16b, v1.16b, #4 // oc:2-3 + + subs x13, x13, #1 + .inst 0x4e88a48c // smmla v12.4s, v4.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a48d // smmla v13.4s, v4.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + + .inst 0x4e88a4b0 // smmla v16.4s, v5.16b, v8.16b // tile2-oc0, tile2-oc1, tile3-oc0, tile3-oc1 + .inst 0x4e89a4b1 // smmla v17.4s, v5.16b, v9.16b // tile2-oc2, tile2-oc3, tile3-oc2, tile3-oc3 + bne LoopSz4_TILE_4 +LoopSz4End_TILE_4: + add x25, x25, x15 + sub x24, x24, #1 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v2.2d, v16.2d, v17.2d + uzp2 v3.2d, v16.2d, v17.2d + Int32ToFloat v0, v1, v2, v3 + +Tile4Quan_L4: + ld1 {v20.4s}, [x19] // scale + ld1 {v22.4s}, [x27] // x kernel sum + ld1 {v25.4s}, [x6] // weight quan zeropoint + MUL_SCALE v20, v0, v1, v2, v3 + + cbz x23, TILE4_MLA_L4 + ld1 {v27.4s}, [x23] + MUL_EXTRA_SCALE v27, v0, v1, v2, v3 + + TILE4_MLA_L4: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v2, v22, v25, 2 // tile:2, oc:0-3 + MLA_WEIGHTZERO v3, v22, v25, 3 // tile:3, oc:0-3 + + TILE4_ADD_BIAS_L4: + cbz x9, TILE4_ADD_DSTV_L4 + ld1 {v16.4s}, [x20] // bias + ADD_BIAS_FLOAT v0, v1, v2, v3, v16 + b TILE4_POST_L4 + + TILE4_ADD_DSTV_L4: + ld1 {v15.4s, v16.4s, v17.4s, v18.4s}, [x26] + ADD_FLOAT v0, v1, v2, v3, v15, v16, v17, v18 + + TILE4_POST_L4: + cbz x14, TILE4_STORE_L4 + ReLU_FP32 v0, v1, v2, v3, v30, v31 + + TILE4_STORE_L4: + st1 {v0.4s, v1.4s, v2.4s, v3.4s}, [x26], x4 + b Tile4Check + +Tile4Check: +cbz x23, Tile4End +add x23, x23, #16 +Tile4End: + sub x7, x7, #4 + add x0, x0, x21, LSL #2 + add x1, x1, #32 + add x27, x27, #16 + +TILE_2: + cmp x7, #2 + blt TILE_1 + mov x24, x5 // dst_depth_quad + mov x26, x0 // dst + mov x25, x2 // weight + mov x19, x8 // scale + mov x20, x9 // bias + mov x6, x28 // weightQuanBias + +cmp x5, #2 +blt LoopDz4_TILE_2 +LoopDz_TILE_2: + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + SET_0_4 v12, v13, v14, v15 +LoopSz_TILE_2: + // int4->int8 + ld1 {v0.16b, v1.16b}, [x12], #32 // weight + ushr v8.16b, v0.16b, #4 // oc:0-1 + ushr v9.16b, v1.16b, #4 // oc:2-3 + and v10.16b, v0.16b, v28.16b // oc:4-5 + and v11.16b, v1.16b, v28.16b // oc:6-7 + + ld1 {v4.16b}, [x11], x22 // src + .inst 0x4e88a48c // smmla v12.4s, v4.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a48d // smmla v13.4s, v4.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + .inst 0x4e8aa48e // smmla v14.4s, v4.16b, v10.16b // tile0-oc4, tile0-oc5, tile1-oc4, tile1-oc5 + .inst 0x4e8ba48f // smmla v15.4s, v4.16b, v11.16b // tile0-oc6, tile0-oc7, tile1-oc6, tile1-oc7 + subs x13, x13, #1 + bne LoopSz_TILE_2 +LoopSzEnd_TILE_2: + add x25, x25, x15 + sub x24, x24, #2 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + uzp1 v2.2d, v14.2d, v15.2d // E0: oc:4-7 + uzp2 v3.2d, v14.2d, v15.2d // E1: oc:4-7 + Int32ToFloat v0, v1, v2, v3 + +Tile2Quan: + ld1 {v20.4s, v21.4s}, [x19], #32 // scale + ld1 {v22.d}[0], [x27] // x kernel sum + ld1 {v25.4s, 
v26.4s}, [x6], #32 // weight quan zeropoint + fmul v0.4s, v0.4s, v20.4s + fmul v1.4s, v1.4s, v20.4s + fmul v2.4s, v2.4s, v21.4s + fmul v3.4s, v3.4s, v21.4s + + cbz x23, TILE2_MLA + ld1 {v27.d}[0], [x23] + fmul v0.4s, v0.4s, v27.s[0] + fmul v1.4s, v1.4s, v27.s[1] + fmul v2.4s, v2.4s, v27.s[0] + fmul v3.4s, v3.4s, v27.s[1] + + TILE2_MLA: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + MLA_WEIGHTZERO v2, v22, v26, 0 // tile:0, oc:4-7 + MLA_WEIGHTZERO v3, v22, v26, 1 // tile:1, oc:4-7 + + TILE2_ADD_BIAS: + cbz x9, TILE2_ADD_DSTV + ld1 {v16.4s, v17.4s}, [x20], #32 // bias + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v16.4s + fadd v2.4s, v2.4s, v17.4s + fadd v3.4s, v3.4s, v17.4s + b TILE2_POST + + TILE2_ADD_DSTV: + ld1 {v18.4s, v19.4s}, [x26], x4 + ld1 {v20.4s, v21.4s}, [x26] + fadd v0.4s, v0.4s, v18.4s + fadd v1.4s, v1.4s, v19.4s + fadd v2.4s, v2.4s, v20.4s + fadd v3.4s, v3.4s, v21.4s + sub x26, x26, x4 + + TILE2_POST: + cbz x14, TILE2_STORE + ReLU_FP32 v0, v1, v2, v3, v30, v31 + TILE2_STORE: + st1 {v0.4s, v1.4s}, [x26], x4 + st1 {v2.4s, v3.4s}, [x26], x4 + b Tile2LoopCheck + +Tile2LoopCheck: + cmp x24, #2 + bge LoopDz_TILE_2 + cbz x24, Tile2Check +LoopDz4_TILE_2: + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + movi v12.4s, #0 + movi v13.4s, #0 +LoopSz4_TILE_2: + // int4->int8 + ld1 {v0.16b, v1.16b}, [x12], #32 // weight + ushr v8.16b, v0.16b, #4 // oc:0-1 + ushr v9.16b, v1.16b, #4 // oc:2-3 + ld1 {v4.16b}, [x11], x22 // src + + .inst 0x4e88a48c // smmla v12.4s, v4.16b, v8.16b // tile0-oc0, tile0-oc1, tile1-oc0, tile1-oc1 + .inst 0x4e89a48d // smmla v13.4s, v4.16b, v9.16b // tile0-oc2, tile0-oc3, tile1-oc2, tile1-oc3 + subs x13, x13, #1 + bne LoopSz4_TILE_2 +LoopSz4End_TILE_2: + add x25, x25, x15 + uzp1 v0.2d, v12.2d, v13.2d // E0: oc:0-3 + uzp2 v1.2d, v12.2d, v13.2d // E1: oc:0-3 + scvtf v0.4s, v0.4s + scvtf v1.4s, v1.4s + +Tile2Quan_L4: + ld1 {v20.4s}, [x19] + ld1 {v22.d}[0], [x27] // x kernel sum + ld1 {v25.4s}, [x6] // weight quan zeropoint + fmul v0.4s, v0.4s, v20.4s + fmul v1.4s, v1.4s, v20.4s + + cbz x23, TILE2_MLA_L4 + ld1 {v27.d}[0], [x23] + fmul v0.4s, v0.4s, v27.s[0] + fmul v1.4s, v1.4s, v27.s[1] + + TILE2_MLA_L4: + MLA_WEIGHTZERO v0, v22, v25, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v1, v22, v25, 1 // tile:1, oc:0-3 + + TILE2_ADD_BIAS_L4: + cbz x9, TILE2_ADD_DSTV_L4 + ld1 {v16.4s}, [x20] // bias + fadd v0.4s, v0.4s, v16.4s + fadd v1.4s, v1.4s, v16.4s + b TILE2_POST_L4 + + TILE2_ADD_DSTV_L4: + ld1 {v18.4s, v19.4s}, [x26] + fadd v0.4s, v0.4s, v18.4s + fadd v1.4s, v1.4s, v19.4s + + TILE2_POST_L4: + cbz x14, TILE2_STORE_L4 + ReLU_FP32_2 v0, v1, v30, v31 + TILE2_STORE_L4: + st1 {v0.4s, v1.4s}, [x26], x4 + b Tile2Check + +Tile2Check: +cbz x23, Tile2End +add x23, x23, #8 +Tile2End: + sub x7, x7, #2 + add x0, x0, x21, LSL #1 + add x1, x1, #16 + add x27, x27, #8 + +TILE_1: + cmp x7, #1 + blt End + mov x24, x5 // dst_depth_quad + mov x26, x0 // dst + mov x25, x2 // weight + mov x19, x8 // scale + mov x20, x9 // bias + mov x6, x28 // weightQuanBias +cmp x5, #2 +blt LoopDz4_TILE_1 +LoopDz_TILE_1: + //ld1 {v0.4s}, [x20], #16 // bias + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + + movi v16.4s, #0 + movi v17.4s, #0 + movi v18.4s, #0 + movi v19.4s, #0 +LoopSz_TILE_1: + ld1 {v2.8b}, [x11], x22 // src + // int4->int8 + ld1 {v0.16b, v1.16b}, [x12], #32 // weight + ushr v8.16b, v0.16b, #4 // oc:0-1 + ushr v9.16b, v1.16b, #4 // oc:2-3 + and v10.16b, v0.16b, v28.16b // oc:4-5 + and 
v11.16b, v1.16b, v28.16b // oc:6-7 + subs x13, x13, #1 + + .inst 0x4e88a450 // smmla v16.4s, v2.16b, v8.16b + .inst 0x4e89a451 // smmla v17.4s, v2.16b, v9.16b + .inst 0x4e8aa452 // smmla v18.4s, v2.16b, v10.16b + .inst 0x4e8ba453 // smmla v19.4s, v2.16b, v11.16b + bne LoopSz_TILE_1 +LoopSzEnd_TILE_1: + add x25, x25, x15 + sub x24, x24, #2 + uzp1 v27.2d, v16.2d, v17.2d + uzp1 v26.2d, v18.2d, v19.2d + scvtf v27.4s, v27.4s + scvtf v26.4s, v26.4s + +Tile1Quan: + ld1 {v0.4s, v1.4s}, [x19], #32 // scale + ld1 {v6.s}[0], [x27] // x kernel sum + ld1 {v8.4s, v9.4s}, [x6], #32 // weight quan zeropoint + fmul v27.4s, v27.4s, v0.4s + fmul v26.4s, v26.4s, v1.4s + + cbz x23, TILE1_MLA + ld1 {v10.s}[0], [x23] + fmul v27.4s, v27.4s, v10.s[0] + fmul v26.4s, v26.4s, v10.s[0] + + TILE1_MLA: + MLA_WEIGHTZERO v27, v6, v8, 0 // tile:0, oc:0-3 + MLA_WEIGHTZERO v26, v6, v9, 0 // tile:0, oc:4-7 + + TILE1_ADD_BIAS: + cbz x9, TILE1_ADD_DSTV + ld1 {v16.4s, v17.4s}, [x20], #32 // bias + fadd v27.4s, v27.4s, v16.4s + fadd v26.4s, v26.4s, v17.4s + b TILE1_POST + + TILE1_ADD_DSTV: + ld1 {v16.4s}, [x26], x4 + ld1 {v17.4s}, [x26] + fadd v27.4s, v27.4s, v16.4s + fadd v26.4s, v26.4s, v17.4s + sub x26, x26, x4 + + TILE1_POST: + cbz x14, TILE1_STORE + fmin v27.4s, v27.4s, v31.4s + fmax v27.4s, v27.4s, v30.4s + fmin v26.4s, v26.4s, v31.4s + fmax v26.4s, v26.4s, v30.4s + + TILE1_STORE: + st1 {v27.4s}, [x26], x4 + st1 {v26.4s}, [x26], x4 + b Tile1LoopEnd + +Tile1LoopEnd: + cmp x24, #2 + bge LoopDz_TILE_1 + cbz x24, End + +LoopDz4_TILE_1: + mov x11, x1 // src + mov x12, x25 // weight + mov x13, x3 // src_depth_quad + + movi v16.4s, #0 + movi v17.4s, #0 +LoopSz4_TILE_1: + ld1 {v2.8b}, [x11], x22 // src + // int4->int8 + ld1 {v0.16b, v1.16b}, [x12], #32 // weight + ushr v8.16b, v0.16b, #4 // oc:0-1 + ushr v9.16b, v1.16b, #4 // oc:2-3 + + subs x13, x13, #1 + .inst 0x4e88a450 // smmla v16.4s, v2.16b, v8.16b + .inst 0x4e89a451 // smmla v17.4s, v2.16b, v9.16b + bne LoopSz4_TILE_1 +LoopSz4End_TILE_1: + add x25, x25, x15 + uzp1 v27.2d, v16.2d, v17.2d + scvtf v27.4s, v27.4s + +Tile1Quan_L4: + ld1 {v0.4s}, [x19] // scale + ld1 {v6.s}[0], [x27] // x kernel sum + ld1 {v8.4s}, [x6] // weight quan zeropoint + fmul v27.4s, v27.4s, v0.4s + cbz x23, TILE1_MLA_L4 + ld1 {v10.s}[0], [x23] + fmul v27.4s, v27.4s, v10.s[0] + + TILE1_MLA_L4: + MLA_WEIGHTZERO v27, v6, v8, 0 // tile:0, oc:0-3 + + TILE1_ADD_BIAS_L4: + cbz x9, TILE1_ADD_DSTV_L4 + ld1 {v16.4s}, [x20] // bias + fadd v27.4s, v27.4s, v16.4s + b TILE1_POST_L4 + + TILE1_ADD_DSTV_L4: + ld1 {v16.4s}, [x26] + fadd v27.4s, v27.4s, v16.4s + + TILE1_POST_L4: + cbz x14, TILE1_STORE_L4 + fmin v27.4s, v27.4s, v31.4s + fmax v27.4s, v27.4s, v30.4s + + TILE1_STORE_L4: + st1 {v27.4s}, [x26], x4 + b End + +End: +ldp x27, x28, [sp, #(16 * 8)] +ldp x25, x26, [sp, #(16 * 7)] +ldp x23, x24, [sp, #(16 * 6)] +ldp x19, x20, [sp, #(16 * 5)] +ldp x21, x22, [sp, #(16 * 4)] +ldp d8, d9, [sp, #(16 * 3)] +ldp d10, d11, [sp, #(16 * 2)] +ldp d12, d13, [sp, #(16 * 1)] +ldp d14, d15, [sp], #(16 * 10) +ret + +#endif // __aarch64__ diff --git a/source/backend/cpu/bf16/BF16Functions.cpp b/source/backend/cpu/bf16/BF16Functions.cpp index 3f792a3ce..852cd791b 100644 --- a/source/backend/cpu/bf16/BF16Functions.cpp +++ b/source/backend/cpu/bf16/BF16Functions.cpp @@ -3,7 +3,6 @@ #include "../x86_x64/avx/FunctionSummary.hpp" #include "../x86_x64/avxfma/FunctionSummary.hpp" #include "../x86_x64/avx512/FunctionSummary.hpp" -#include "../x86_x64/cpu_id.h" #endif #include "core/Macro.h" #if defined(MNN_USE_NEON) @@ -11,20 +10,17 @@ #endif 
#include "BF16Functions.hpp" -#include "WinogradOptFunctionHalf.hpp" #include "../compute/CommonOptFunction.h" -#include "../CPUPool.hpp" #include "../CPURuntime.hpp" #include "VecHalf.hpp" #include "math/Vec.hpp" -#include "BF16Binary.hpp" -#include "BF16Unary.hpp" using BFVec4 = MNN::Math::VecHalf<4>; using Vec4 = MNN::Math::Vec; -extern "C" { -void MNNReluWithSlopeChannelBF16(float* dstO, const float* srcO, const float* slopeO, size_t sizeQuad, size_t depthQuad); -} namespace MNN { +// The Function Will be Called in init +void registerBF16Backend() { + BF16Functions::init(); +} // just for reference BF16 converting of c++ code, not for arm or sse. inline int16_t MNNFP32ToBF16(float fp32Value) { int32_t* s32Value = (int32_t*)(&fp32Value); @@ -76,825 +72,277 @@ static void _MNNLowpToFp32(const int16_t* src, float* dst, size_t size) { ::memcpy(dst, dstTemp, sizeRemain * sizeof(float)); } } -static void MNNConvRunForUnitDepthWiseBF16(float* dst, const float* src, const float* weight, size_t fw, size_t fh, - size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) { - int fx, fy; - BFVec4 dstValue(0.0f); - const int16_t* src_z = (const int16_t*)src; - const int16_t* weight_z = (const int16_t*)weight; - for (fy = 0; fy < fh; ++fy) { - const auto src_y = src_z + fy * dilateY_step; - const auto weight_y = weight_z + fy * weight_y_step; - for (fx = 0; fx < fw; ++fx) { - const auto weight_x = weight_y + 4 * fx; - const auto src_x = src_y + fx * dilateX_step; - dstValue = dstValue + BFVec4::load(src_x) * BFVec4::load(weight_x); - } - } - BFVec4::save((int16_t*)dst, dstValue); + +#if defined(MNN_USE_NEON) +// todo: search for proper value for bf16 +void NEON_MNNGetMatMulPackMode_BF16(int* eP, int* lP, int* hP) { + *eP = 12; + *lP = 1; +#ifdef __aarch64__ + *hP = 8; +#else + *hP = 4; +#endif } -static void MNNConvRunForLineDepthwiseBF16(float* dstO, const float* srcO, const float* weightO, size_t width, size_t src_w_setup, - size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, size_t height, - size_t srcHStep, size_t dstHStep) { - int dx, fx, fy; - auto dst = (int16_t*)dstO; - auto src = (const int16_t*)srcO; - auto weight = (const int16_t*)weightO; - for (int y = 0; y < height; ++y) { - auto srcY = src + y * srcHStep; - auto dstY = dst + y * dstHStep; - for (dx = 0; dx < width; ++dx) { - auto dst_x = dstY + dx * 4; - BFVec4 dstValue(0.0f); - const auto src_z = srcY + src_w_setup * dx; - const auto weight_z = weight; - for (fy = 0; fy < fh; ++fy) { - const auto src_y = src_z + fy * dilateY_step; - const auto weight_y = weight_z + fy * fw * 4; - for (fx = 0; fx < fw; ++fx) { - const auto weight_x = weight_y + 4 * fx; - const auto src_x = src_y + fx * dilateX_step; - dstValue = dstValue + BFVec4::load(src_x) * BFVec4::load(weight_x); - } - } - BFVec4::save(dst_x, dstValue); +#ifdef __aarch64__ +#define EP 12 +#define HP 8 +#define LP 4 +void ARMV86_MNNGetMatMulPackMode_BF16(int* eP, int* lP, int* hP) { + *eP = EP; + *hP = HP; + *lP = LP; +} +void ARMV86_MNNPackForMatMul_B_BF16(float* destF, const float* sourceF, size_t h, size_t l, bool transpose) { + // [l, h] -> [h/hp, l/lp, hp, lp] + auto dest = (int16_t*)destF; + auto source = (const int32_t*)sourceF; + auto lCP = UP_DIV(l, LP); + auto hCP = UP_DIV(h, HP); + int sYstride = 1; + int sXstride = h; + if (transpose) { + sYstride = l; + sXstride = 1; + } + ::memset(dest, 0, lCP * hCP * sizeof(int16_t) * HP * LP); + for (int y = 0; y < h; ++y) { + int yC = y / HP; + int yR = y % HP; + for (int x = 0; x < l; ++x) { + int xC = x / LP; + 
int xR = x % LP; + dest[xR + yR * LP + xC * HP * LP + yC * HP * LP * lCP] = source[sXstride * x + sYstride * y] >> 16; } } } -void MNNAxByClampBroadcastUnitBF16(float* CF, const float* AF, const float* BF, size_t width, size_t cStride, size_t aStride, size_t height, const float* parameters) { - auto C = (int16_t*)CF; - auto A = (const int16_t*)AF; - auto B = (const int16_t*)BF; - auto minF = BFVec4(parameters[2]); - auto maxF = BFVec4(parameters[3]); - auto beta = BFVec4(parameters[1]); - for (int y = 0; y < height; ++y) { - auto a = A + aStride * y; - auto b = B + 4 * y; - auto bv = BFVec4::load(b); - auto c = C + cStride * y; - for (int x = 0; x < width; ++x) { - auto av = BFVec4::load(a + 4 * x); - auto cv = av + bv * beta; - cv = BFVec4::min(cv, maxF); - cv = BFVec4::max(cv, minF); - BFVec4::save(c + 4 * x, cv); +void ARMV86_MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) { + // [l/4, e, 4] -> [l/4, ep, 4] + int number = info[0]; + int eReal = info[1]; + int eDest = info[2]; + int offset = info[3]; + if (1 == number) { + int l = el[1]; + if (l % 8 != 0) { + auto lAigin = UP_DIV(l, LP) * LP; + ::memset(destOrigin, 0, eDest * lAigin * sizeof(int16_t)); } } -} -#ifndef MNN_USE_NEON -void MNNReluWithSlopeChannelBF16(float* dstO, const float* srcO, const float* slopeO, size_t sizeQuad, size_t depthQuad) { - auto slope = (const int16_t*)slopeO; - auto dst = (int16_t*)dstO; - auto src = (const int16_t*)srcO; - auto zero = BFVec4(0.0f); - for (int j = 0; j < depthQuad; j++) { - auto slopeZ = BFVec4::load(slope + 4 * j); - auto srcZ = src + 4 * j * sizeQuad; - auto dstZ = dst + 4 * j * sizeQuad; - for (int i = 0; i < sizeQuad; i++) { - auto srcValue = BFVec4::load(srcZ + 4 * i); - std::array dstV; - for (int c = 0; c < 4; c++) { - if (srcValue[c] < 0) { - dstV[c] = srcValue[c] * slopeZ[c]; - } else { - dstV[c] = srcValue[c]; + + for (int n=0; n [l/4, ep, 4] + for (int x = 0; x < lDiv; ++x) { + auto destX = dest + x * eDest * 4; + auto srcX = source + x * eReal * 4; + for (int y = 0; y < e; ++y) { + auto srcV = Vec4::load(srcX + y * offset * 4); + auto dstV = BFVec4(std::move(srcV.value)); + BFVec4::save((int16_t*)(destX + 4*y), dstV); } } - auto dstValue = BFVec4(std::move(Vec4::load(dstV.data()).value)); - BFVec4::save(dstZ + 4 * i, dstValue); + continue; + } + for (int x = 0; x < l; ++x) { + auto dl = lOR + x; + auto dlC = dl / LP; + auto dlR = dl % LP; + auto xC = x / LP; + auto xR = x % LP; + auto destX = dest + dlC * eDest * LP + dlR; + auto srcX = sourceInt + xC * eReal * LP + xR; + for (int y = 0; y < e; ++y) { + destX[y * 4] = srcX[y * 4 * offset] >> 16; + } } } } -#endif - -#if !defined(MNN_USE_SSE) && !defined(MNN_USE_NEON) -void MNNPackC4ForMatMul_A_BF16(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) { - MNNPackC4ForMatMul_A(destOrigin, sourceGroup, info, el); - return; -} - -void MNNPackForMatMul_B_BF16(float* dest, const float* source, size_t h, size_t l, bool transpose) { - auto hP = h / 4; - auto hR = hP * 4; - if (hR != h) { - ::memset(dest, 0, UP_DIV(h, 4)*4*l*sizeof(int16_t)); - } +#undef EP +#undef HP +#undef LP +void NEON_MNNPackForMatMul_B_BF16(float* destFloat, const float* sourceFloat, size_t h, size_t l, bool transpose) { + auto hP = (int)h / 8; + auto hR = (int)hP * 8; + int16_t* dest = (int16_t*)destFloat; + const float* source = sourceFloat; if (!transpose) { - for (int y=0; y 0) { - auto destY = dest + hP * 4 * l; - auto sourceY = source + hP * 4; - for (int x=0; 
x().max(); - float maxValue = std::numeric_limits().max(); - if (nullptr != postParameters) { - minValue = postParameters[2]; - maxValue = postParameters[3]; - alpha = postParameters[0]; - beta = postParameters[1]; - } - - for (int x = 0; x < eSize; ++x) { - auto dst = C + 4 * x; - auto src = - A + x; // input data is packed as tileCount x l x 16, is only one tiled block here, indexed as A[z * 16 + x] - for (int ry = 0; ry < h; ++ry) { - auto y = ry / 4; - auto yRemain = ry % 4; - auto bY = B + y * bStride; - auto dstY = dst + y * cStride; // convert NCHW to NC4HW4 ie 1·(y/4)·X·4 - int wdy = ry / 6; - int wdyRemain = ry % 6; - auto weight = - B + wdy * bStride + - wdyRemain; // weight is packed as (h/6) x l x 6, indexed as B[(ry / 6) * Bstride +z*6 + (ry % 6)] - float summer = 0.0f; - for (int z = 0; z < l; ++z) { - auto aZ = src + z * 16; - auto wZ = weight + z * 6; - summer += MNNLowpToFp32(wZ[0]) * MNNLowpToFp32(aZ[0]); + auto destY = dest + hP * 8 * l; + auto sourceY = source + hP * 8; + float sTmp[8]; + ::memset(sTmp, 0, sizeof(sTmp)); + for (int x = 0; x < l; ++x) { + ::memcpy(sTmp, sourceY + x * h, hRemain * sizeof(float)); + auto s0 = Vec4::load(sTmp + 0); + auto s1 = Vec4::load(sTmp + 4); + auto d0 = BFVec4(std::move(s0.value)); + auto d1 = BFVec4(std::move(s1.value)); + BFVec4::save(destY + 8 * x + 0, d0); + BFVec4::save(destY + 8 * x + 4, d1); } - float originValue = MNNLowpToFp32(dstY[yRemain]); - if (nullptr != bias) { - originValue = MNNLowpToFp32(bias[ry]); - } - auto dstValue = originValue * beta + alpha * summer; - dstValue = std::min(dstValue, maxValue); - dstValue = std::max(dstValue, minValue); - dstY[yRemain] = MNNFP32ToBF16(dstValue); - } - } -} - -void MNNPackedMatMul_BF16(float* C, const float* A, const float* B, const size_t* parameter, float* cache, - const float* postParameters, const float* bias, const float* k, const float* b) { - return MNNPackedMatMulRemain_BF16(C, A, B, 16, parameter, cache, postParameters, bias, nullptr, nullptr); - // return _AVX_MNNPackedMatMulFMA(C, A, B, parameter, cache); -} - - -static void _MNNConvDwF23MulTransUnit(float **cacheLine, const float *weigth, float *dest, size_t ow); - -static void _MNNMultiAndDestTransformCommon23(float **cacheLine, const float *weigthF, float *destF, int cacheLineSize, int ow, const float* bias, const float* parameters) { - auto weigth = (const int16_t*)weigthF; - auto dest = (int16_t*)destF; - int unit = ow / 2; - auto biasF = BFVec4::load((const int16_t*)bias); - auto minV = BFVec4(parameters[2]); - auto maxV = BFVec4(parameters[3]); - MNN_ASSERT(cacheLineSize >= 1); - for (int x = 0; x < unit; ++x) { - auto offset = 4 * 4 * x; - int i = 0; - BFVec4 m0 = BFVec4::load(weigth + i * 16 + 4 * 0) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 0); - BFVec4 m1 = BFVec4::load(weigth + i * 16 + 4 * 1) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 1); - BFVec4 m2 = BFVec4::load(weigth + i * 16 + 4 * 2) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 2); - BFVec4 m3 = BFVec4::load(weigth + i * 16 + 4 * 3) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 3); - - for (i = 1; i < cacheLineSize; ++i) { - m0 = m0 + BFVec4::load(weigth + i * 16 + 4 * 0) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 0); - m1 = m1 + BFVec4::load(weigth + i * 16 + 4 * 1) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 1); - m2 = m2 + BFVec4::load(weigth + i * 16 + 4 * 2) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 2); - m3 = m3 + BFVec4::load(weigth + i * 16 + 4 * 3) * 
BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 3); - } - - auto o0 = m0 + m1 + m2 + biasF; - auto o1 = m1 - m2 + m3 + biasF; - o0 = BFVec4::min(o0, maxV); - o1 = BFVec4::min(o1, maxV); - o0 = BFVec4::max(o0, minV); - o1 = BFVec4::max(o1, minV); - BFVec4::save(dest + 8 * x + 0 * 4, o0); - BFVec4::save(dest + 8 * x + 1 * 4, o1); - } - if (unit * 2 < ow) { - auto offset = 4 * 4 * unit; - int i = 0; - BFVec4 m0 = BFVec4::load(weigth + i * 16 + 4 * 0) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 0); - BFVec4 m1 = BFVec4::load(weigth + i * 16 + 4 * 1) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 1); - BFVec4 m2 = BFVec4::load(weigth + i * 16 + 4 * 2) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 2); - - for (i = 1; i < cacheLineSize; ++i) { - m0 = m0 + BFVec4::load(weigth + i * 16 + 4 * 0) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 0); - m1 = m1 + BFVec4::load(weigth + i * 16 + 4 * 1) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 1); - m2 = m2 + BFVec4::load(weigth + i * 16 + 4 * 2) * BFVec4::load((int16_t*)cacheLine[i] + offset + 4 * 2); - } - - auto o0 = m0 + m1 + m2 + biasF; - o0 = BFVec4::min(o0, maxV); - o0 = BFVec4::max(o0, minV); - BFVec4::save(dest + 8 * unit + 0 * 4, o0); - } -} -static void _MNNConvDwF23SourceTransUnit(const int16_t *source, int16_t *dest, size_t unit); -static void _MNNSourceTransformCommonF23(const float *sourceF, float *destF, int unit, int iw, int pad, int su, int eu) { - auto source = (const int16_t*)sourceF; - auto dest = (int16_t*)destF; - for (int x = 0; x < su; ++x) { - auto dstX = dest + 4 * 4 * x; - auto sx = x * 2 - (int)pad; - auto ex = sx + 4; - - auto clampSx = std::max(sx, 0); - auto clampEx = std::min(ex, (int)iw); - - BFVec4 v[4] = {0.0f, 0.0f, 0.0f, 0.0f}; - for (int i = clampSx; i < clampEx; ++i) { - v[i - sx] = BFVec4::load(source + 4 * i); } - auto m0 = v[0] - v[2]; - auto m1 = v[1] + v[2]; - auto m2 = v[2] - v[1]; - auto m3 = v[3] - v[1]; - - BFVec4::save(dstX + 4 * 0, m0); - BFVec4::save(dstX + 4 * 1, m1); - BFVec4::save(dstX + 4 * 2, m2); - BFVec4::save(dstX + 4 * 3, m3); - } - _MNNConvDwF23SourceTransUnit(source + 4 * (su * 2 - pad), dest + 4 * 4 * su, eu - su); - - for (int x = eu; x < unit; ++x) { - auto dstX = dest + 4 * 4 * x; - auto sx = x * 2 - (int)pad; - auto ex = sx + 4; - - auto clampSx = std::max(sx, 0); - auto clampEx = std::min(ex, (int)iw); - - BFVec4 v[4] = {0.0f, 0.0f, 0.0f, 0.0f}; - for (int i = clampSx; i < clampEx; ++i) { - v[i - sx] = BFVec4::load(source + 4 * i); - } - auto m0 = v[0] - v[2]; - auto m1 = v[1] + v[2]; - auto m2 = v[2] - v[1]; - auto m3 = v[3] - v[1]; - - BFVec4::save(dstX + 4 * 0, m0); - BFVec4::save(dstX + 4 * 1, m1); - BFVec4::save(dstX + 4 * 2, m2); - BFVec4::save(dstX + 4 * 3, m3); - } -} - -static void _MNNConvDwF23MulTransUnit(float **cacheLine, const float *weigthF, float *destF, size_t ow, const float* bias, const float* parameters) { - int unit = ow / 2; - auto weigth = (const int16_t*)weigthF; - auto dest = (int16_t*)destF; - - auto w00 = BFVec4::load(weigth + 0 * 16 + 4 * 0); - auto w01 = BFVec4::load(weigth + 0 * 16 + 4 * 1); - auto w02 = BFVec4::load(weigth + 0 * 16 + 4 * 2); - auto w03 = BFVec4::load(weigth + 0 * 16 + 4 * 3); - auto w10 = BFVec4::load(weigth + 1 * 16 + 4 * 0); - auto w11 = BFVec4::load(weigth + 1 * 16 + 4 * 1); - auto w12 = BFVec4::load(weigth + 1 * 16 + 4 * 2); - auto w13 = BFVec4::load(weigth + 1 * 16 + 4 * 3); - auto w20 = BFVec4::load(weigth + 2 * 16 + 4 * 0); - auto w21 = BFVec4::load(weigth + 2 * 16 + 4 * 1); - auto w22 = 
BFVec4::load(weigth + 2 * 16 + 4 * 2); - auto w23 = BFVec4::load(weigth + 2 * 16 + 4 * 3); - - auto biasF = BFVec4::load((const int16_t*)bias); - auto minV = BFVec4(parameters[2]); - auto maxV = BFVec4(parameters[3]); - for (int x = 0; x < unit; ++x) { - auto offset = 4 * 4 * x; - int i = 0; - BFVec4 m0 = w00 * BFVec4::load((int16_t*)cacheLine[0] + offset + 4 * 0); - BFVec4 m1 = w01 * BFVec4::load((int16_t*)cacheLine[0] + offset + 4 * 1); - BFVec4 m2 = w02 * BFVec4::load((int16_t*)cacheLine[0] + offset + 4 * 2); - BFVec4 m3 = w03 * BFVec4::load((int16_t*)cacheLine[0] + offset + 4 * 3); - - m0 = m0 + w10 * BFVec4::load((int16_t*)cacheLine[1] + offset + 4 * 0); - m1 = m1 + w11 * BFVec4::load((int16_t*)cacheLine[1] + offset + 4 * 1); - m2 = m2 + w12 * BFVec4::load((int16_t*)cacheLine[1] + offset + 4 * 2); - m3 = m3 + w13 * BFVec4::load((int16_t*)cacheLine[1] + offset + 4 * 3); - - m0 = m0 + w20 * BFVec4::load((int16_t*)cacheLine[2] + offset + 4 * 0); - m1 = m1 + w21 * BFVec4::load((int16_t*)cacheLine[2] + offset + 4 * 1); - m2 = m2 + w22 * BFVec4::load((int16_t*)cacheLine[2] + offset + 4 * 2); - m3 = m3 + w23 * BFVec4::load((int16_t*)cacheLine[2] + offset + 4 * 3); - - auto o0 = m0 + m1 + m2 + biasF; - auto o1 = m1 - m2 + m3 + biasF; - o0 = BFVec4::min(o0, maxV); - o1 = BFVec4::min(o1, maxV); - o0 = BFVec4::max(o0, minV); - o1 = BFVec4::max(o1, minV); - BFVec4::save(dest + 8 * x + 0 * 4, o0); - BFVec4::save(dest + 8 * x + 1 * 4, o1); - } - if (unit * 2 < ow) { - auto offset = 4 * 4 * unit; - BFVec4 m0 = w00 * BFVec4::load((int16_t*)cacheLine[0] + offset + 4 * 0); - BFVec4 m1 = w01 * BFVec4::load((int16_t*)cacheLine[0] + offset + 4 * 1); - BFVec4 m2 = w02 * BFVec4::load((int16_t*)cacheLine[0] + offset + 4 * 2); - - m0 = m0 + w10 * BFVec4::load((int16_t*)cacheLine[1] + offset + 4 * 0); - m1 = m1 + w11 * BFVec4::load((int16_t*)cacheLine[1] + offset + 4 * 1); - m2 = m2 + w12 * BFVec4::load((int16_t*)cacheLine[1] + offset + 4 * 2); - - m0 = m0 + w20 * BFVec4::load((int16_t*)cacheLine[2] + offset + 4 * 0); - m1 = m1 + w21 * BFVec4::load((int16_t*)cacheLine[2] + offset + 4 * 1); - m2 = m2 + w22 * BFVec4::load((int16_t*)cacheLine[2] + offset + 4 * 2); - auto o0 = m0 + m1 + m2 + biasF; - o0 = BFVec4::min(o0, maxV); - o0 = BFVec4::max(o0, minV); - BFVec4::save(dest + 8 * unit + 0 * 4, o0); - } -} -static void _MNNConvDwF23SourceTransUnit(const int16_t *source, int16_t *dest, size_t unit) { - if (unit <= 0) { return; } - BFVec4 v0 = BFVec4::load(source + 4 * 0); - BFVec4 v1 = BFVec4::load(source + 4 * 1); - BFVec4 v2; - BFVec4 v3; - source += 8; - - for (int x = 0; x < unit; ++x) { - v2 = BFVec4::load(source + 0 * 4); - v3 = BFVec4::load(source + 1 * 4); - auto m0 = v0 - v2; - auto m1 = v1 + v2; - auto m2 = v2 - v1; - auto m3 = v3 - v1; - - BFVec4::save(dest + 4 * 0, m0); - BFVec4::save(dest + 4 * 1, m1); - BFVec4::save(dest + 4 * 2, m2); - BFVec4::save(dest + 4 * 3, m3); - - source += 8; - dest += 16; - - v0 = v2; - v1 = v3; - } -} - -static void _MNNMatrixSub(float* CF, const float* AF, const float* BF, size_t widthC4, size_t cStride, size_t aStride, - size_t bStride, size_t height) { - auto A = (int16_t*)AF; - auto B = (int16_t*)BF; - auto C = (int16_t*)CF; - for (int y = 0; y < height; ++y) { - auto a = A + aStride * y; - auto b = B + bStride * y; - auto c = C + cStride * y; - for (int x = 0; x < widthC4; ++x) { - BFVec4::save(c + 4 * x, BFVec4::load(a + 4 * x) - BFVec4::load(b + 4 * x)); - } - } -} -static void _MNNMatrixAdd(float* CF, const float* AF, const float* BF, size_t widthC4, size_t 
cStride, size_t aStride, - size_t bStride, size_t height) { - auto A = (int16_t*)AF; - auto B = (int16_t*)BF; - auto C = (int16_t*)CF; - for (int y = 0; y < height; ++y) { - auto a = A + aStride * y; - auto b = B + bStride * y; - auto c = C + cStride * y; - for (int x = 0; x < widthC4; ++x) { - BFVec4::save(c + 4 * x, BFVec4::load(a + 4 * x) + BFVec4::load(b + 4 * x)); - } - } -} - -static void _MNNStrassenMergeCFunction(float* c11F, float* c12F, float* c21F, float* c22F, float* xAddrF, size_t cStride, - size_t eSub, size_t hSub) { - auto c11 = (int16_t*)c11F; - auto c12 = (int16_t*)c12F; - auto c21 = (int16_t*)c21F; - auto c22 = (int16_t*)c22F; - auto xAddr = (int16_t*)xAddrF; - for (int y=0; y> 16; } } -} - -size_t _MNNGridSampleComputeOffset(int h, int w, int height, int width, bool padMode) { - if (padMode == true) { //padMode == BorderMode_ZEROS - if (h < 0 || h >= height || w < 0 || w >= width) { - return -1; + return; +#endif + int lC8 = (int)l / 8; + auto lR = lC8 * 8; + if (hP > 0 && lC8 > 0) { + MNNPackC8_BF16(destFloat, sourceFloat, l, h); + } + for (int y = hR; y < h; ++y) { + auto yR = y % 8; + auto yC = hP; + for (int x = 0; x < l; ++x) { + dest[x * 8 + yR + yC * 8 * l] = sourceInt32[x + y * l] >> 16; } - } else { - // Clearly, CLAMP is the right way to go for GridSamplePaddingMode_BORDER - // For GridSamplePaddingMode_REFLECTION, since we have reflected the values into (-1, 1), - // the leftover reflections degrade to GridSamplePaddingMode_BORDER - h = h < 0 ? 0 : ( h > (height - 1) ? (height - 1) : h); - w = w < 0 ? 0 : ( w > (width - 1) ? (width - 1) : w); } - return h * width * 4 + w * 4; -} - -void _MNNGridSampleInterp(float* output, const float* input, const float* cord, size_t inH, size_t inW, size_t outW, size_t channelCUnit, size_t inOffset, size_t outOffset, bool sampleMode, bool padMode) { - int16_t* outputPtr = (int16_t*)output; - const int16_t* inputPtr = (const int16_t*)input; - const int16_t* cordPtr = (const int16_t*)cord; - - for (auto ow = 0; ow < outW; ++ow) { - auto w = MNNLowpToFp32(cordPtr[2 * ow + 0]); - auto h = MNNLowpToFp32(cordPtr[2 * ow + 1]); - BFVec4 interp; - - if (sampleMode == true) { //sampleMode == SampleMode_NEAREST - int nh = ::floor(h + 0.5f); - int nw = ::floor(w + 0.5f); - size_t ns = _MNNGridSampleComputeOffset(nh, nw, inH, inW, padMode); - for (int k = 0; k < channelCUnit; ++k) { - interp = ns == -1 ? BFVec4(0.f) : BFVec4::load(inputPtr + k * inOffset + ns); - BFVec4::save(outputPtr + k * outOffset + 4 * ow, interp); - } - } else { //sampleMode == GridSampleMode_BILINEAR - int w0_h = ::floor(h); - int w0_w = ::floor(w); - int w1_h = ::ceil(h); - int w1_w = ::ceil(w); - auto oneV = BFVec4(1.0f); - - auto f0 = BFVec4((float)w1_w - w); - auto f1 = oneV - f0; - auto h0 = BFVec4((float)w1_h - h); - auto h1 = oneV - h0; - - size_t s00 = _MNNGridSampleComputeOffset(w0_h, w0_w, inH, inW, padMode); - size_t s01 = _MNNGridSampleComputeOffset(w0_h, w1_w, inH, inW, padMode); - size_t s10 = _MNNGridSampleComputeOffset(w1_h, w0_w, inH, inW, padMode); - size_t s11 = _MNNGridSampleComputeOffset(w1_h, w1_w, inH, inW, padMode); - - for (int k = 0; k < channelCUnit; ++k) { - BFVec4 i00 = s00 == -1 ? BFVec4(0.f) : BFVec4::load(inputPtr + k * inOffset + s00); - BFVec4 i01 = s01 == -1 ? BFVec4(0.f) : BFVec4::load(inputPtr + k * inOffset + s01); - BFVec4 i10 = s10 == -1 ? BFVec4(0.f) : BFVec4::load(inputPtr + k * inOffset + s10); - BFVec4 i11 = s11 == -1 ? 
BFVec4(0.f) : BFVec4::load(inputPtr + k * inOffset + s11); - - BFVec4 i0 = i00 * f0 + i01 * f1; - BFVec4 i1 = i10 * f0 + i11 * f1; - - interp = i0 * h0 + i1 * h1; - BFVec4::save(outputPtr + k * outOffset + 4 * ow, interp); - } + for (int y = 0; y < hR; ++y) { + auto yR = y % 8; + auto yC = y / 8; + for (int x = lR; x < l; ++x) { + dest[x * 8 + yR + yC * 8 * l] = sourceInt32[x + y * l] >> 16; } } } - -static void _MNNAddC4WithStride(const float* sourceF, float* destF, size_t srcStride, size_t dstStride, size_t count) { - auto source = (const int16_t*)sourceF; - auto dest = (int16_t*)destF; - for (int i = 0; i < count; ++i) { - auto s = source + i * srcStride; - auto d = dest + i * dstStride; - BFVec4::save(d, BFVec4::load(d) + BFVec4::load(s)); - } -} -static void _MNNDeconvRunForUnitDepthWise(const int16_t* dst, int16_t* src, const int16_t* weight, size_t fw, size_t fh, - size_t weight_y_step, size_t dilateX_step, size_t dilateY_step) { - int fx, fy; - auto src_z = src; - auto weight_z = weight; - BFVec4 dstV = BFVec4::load(dst); - for (fy = 0; fy < fh; ++fy) { - auto src_y = src_z + fy * dilateY_step; - auto weight_y = weight_z + fy * weight_y_step; - for (fx = 0; fx < fw; ++fx) { - BFVec4 weight_x = BFVec4::load(weight_y + 4 * fx); - BFVec4 src_x = BFVec4::load(src_y + fx * dilateX_step); - BFVec4::save(src_y + fx * dilateX_step, src_x + weight_x * dstV); +#else +void NEON_MNNPackForMatMul_B_BF16(float* destFloat, const float* sourceFloat, size_t h, size_t l, bool transpose) { + int16_t* dest = (int16_t*)destFloat; + const float* source = sourceFloat; + if (!transpose) { + auto hP = h / 4; + auto hR = hP * 4; + if (hR != h) { + ::memset(dest, 0, UP_DIV(h, 4) * 4 * l * sizeof(int16_t)); } - } -} -static void _MNNDeconvRunForLineDepthwise(const int16_t* dst, int16_t* src, const int16_t* weight, size_t width, size_t src_w_setup, - size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step) { - int dx; - for (dx = 0; dx < width; ++dx) { - auto dst_x = dst + dx * 4; - auto src_dx = src + src_w_setup * dx; - _MNNDeconvRunForUnitDepthWise(dst_x, src_dx, weight, fw, fh, fw * 4, dilateX_step, dilateY_step); - } -} - -static void _MNNComputeMatMulForH_1_BF16(const float* AF, const float* BF, float* CF, const float* biasPtrF, const MatMulParam* param, size_t tId) { - auto A = (const int16_t*)AF; - auto B = (const int16_t*)BF; - auto C = (int16_t*)CF; - auto biasPtr = (const int16_t*)biasPtrF; - int e = param->e; - int l = param->l; - int numberThread = param->numberThread; - float biasValue = 0.0f; - auto bf = BF16Functions::get(); - if (nullptr != biasPtr) { - bf->MNNLowpToFp32(biasPtr, &biasValue, 1); - } - if (param->ATranspose) { - auto eC4 = e / 4; - auto eR = e % 4; - for (int y=tId; y 0) { - BFVec4 sumValue = BFVec4(biasValue); - auto srcY = A + eC4 * 4; - int16_t AR[4]; - for (int x=0; x 0) { + auto destY = dest + hP * 4 * l; + auto sourceY = source + hP * 4; + for (int x = 0; x < l; ++x) { + auto s0 = Vec4::load(sourceY + x * h + 0); + auto d0 = BFVec4(std::move(s0.value)); + BFVec4::save(destY + 4 * x + 0, d0); } - int16_t CR[4]; - BFVec4::save(CR, sumValue); - ::memcpy(C + 4 * eC4, CR, eR * sizeof(int16_t)); } return; } - auto lC4 = l / 4; - auto lR = l % 4; - for (int y=tId; y 0) { - int16_t AR[4] = {0, 0, 0, 0}; - int16_t BR[4] = {0, 0, 0, 0}; - ::memcpy(AR, srcY + lC4 * 4, lR * sizeof(int16_t)); - ::memcpy(BR, B + 4 * lC4, lR * sizeof(int16_t)); - sumValue = sumValue + BFVec4::load(AR) * BFVec4::load(BR); +#if 0 + auto sourceInt32 = (const int32_t*)source; + // Origin C++ code + 
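// --- Editor's note (illustrative sketch, not part of the upstream patch) -----------------
// The ">> 16" stores in these BF16 pack routines rely on bfloat16 being exactly the upper
// half of an IEEE-754 binary32 value. A minimal scalar pair showing that truncating
// conversion and its inverse (names are hypothetical; MNN's own entry points are
// MNNFp32ToLowp / MNNLowpToFp32 and may round rather than truncate):
#include <cstdint>
#include <cstring>
static inline int16_t SketchFloatToBF16(float v) {
    uint32_t bits;
    std::memcpy(&bits, &v, sizeof(bits));        // reinterpret the float's bit pattern
    return (int16_t)(bits >> 16);                // keep sign, exponent, top 7 mantissa bits
}
static inline float SketchBF16ToFloat(int16_t h) {
    uint32_t bits = (uint32_t)(uint16_t)h << 16; // low 16 mantissa bits become zero
    float v;
    std::memcpy(&v, &bits, sizeof(v));
    return v;
}
// ------------------------------------------------------------------------------------------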
::memset(dest, 0, UP_DIV(h, 4) * 4 * l * sizeof(int16_t)); + + for (int y = 0; y < h; ++y) { + auto yR = y % 4; + auto yC = y / 4; + for (int x = 0; x < l; ++x) { + dest[x * 4 + yR + yC * 4 * l] = sourceInt32[x + y * l] >> 16; } - float sumSingle = sumValue[0] + sumValue[1] + sumValue[2] + sumValue[3]; - bf->MNNFp32ToLowp(&sumSingle, C + y, 1); } + return; +#endif + int offset[2] = { + (int)l, + (int)l, + }; + MNNPackC4_BF16(destFloat, sourceFloat, l, h, offset); } +#endif // __aarch64__ +#endif -static void _MNNComputeMatMulForE_1_BF16(const float* AF, const float* BF, float* CF, const float* biasPtrF, const MatMulParam* param, size_t tId) { - auto l = param->l; - auto h = param->h; - auto numberThread = param->numberThread; - auto lC4 = l / 4; - auto lR = l % 4; - auto A = (const int16_t*)AF; - auto B = (const int16_t*)BF; - auto C = (int16_t*)CF; - auto biasPtr = (const int16_t*)biasPtrF; - auto bf16 = BF16Functions::get(); - if (param->BTranspose) { - for (int y=tId; y 0) { - int16_t AR[4] = {0, 0, 0, 0}; - int16_t BR[4] = {0, 0, 0, 0}; - ::memcpy(AR, A + lC4 * 4, lR * sizeof(int16_t)); - ::memcpy(BR, by + 4 * lC4, lR * sizeof(int16_t)); - sumValue = sumValue + BFVec4::load(AR) * BFVec4::load(BR); - } - float sumRemain = sumValue[0] + sumValue[1] + sumValue[2] + sumValue[3]; - if (nullptr != biasPtr) { - sumRemain += BFVec4::broadcast(biasPtr[y])[0]; - } - bf16->MNNFp32ToLowp(&sumRemain, C + y, 1); - } - } else { - auto hC4 = h / 4; - auto hR = h % 4; - for (int y=tId; y> 16; } - BFVec4::save(C + 4 * y, sumValue); - } - if (tId == 0 && hR > 0) { - auto bs = B + 4 * hC4; - BFVec4 sumValue = BFVec4(0.0f); - if (biasPtr != nullptr) { - int16_t biasTemp[4]; - ::memcpy(biasTemp, biasPtr + 4 * hC4, hR * sizeof(int16_t)); - sumValue = BFVec4::load(biasTemp); - } - auto srcY = A + 4 * hC4 * l; - int16_t bTemp[4]; - for (int x=0; xMNNConvRunForLineDepthwise = MNNConvRunForLineDepthwiseBF16; - gInstance->MNNConvRunForUnitDepthWise = MNNConvRunForUnitDepthWiseBF16; - gInstance->MNNAxByClampBroadcastUnit = MNNAxByClampBroadcastUnitBF16; + *gInstance = *MNNGetCoreFunctions(); gInstance->MNNFp32ToLowp = _MNNFp32ToLowp; gInstance->MNNLowpToFp32 = _MNNLowpToFp32; - gInstance->bytes = 2; - gInstance->pack = 4; - gInstance->MNNPackCUnit = (decltype(gInstance->MNNPackCUnit))MNNPackC4Int16; - gInstance->MNNUnpackCUnit = (decltype(gInstance->MNNUnpackCUnit))MNNUnpackC4Int16; - gInstance->MNNUnpackCUnitTranspose = (decltype(gInstance->MNNUnpackCUnitTranspose))MNNPackTransposeInt16; - gInstance->MNNPackCUnitTranspose = (decltype(gInstance->MNNPackCUnitTranspose))MNNUnpackTransposeInt16; - gInstance->MNNConvDwF23MulTransUnit = _MNNConvDwF23MulTransUnit; - gInstance->MNNSourceTransformCommonF23 = _MNNSourceTransformCommonF23; - gInstance->MNNMultiAndDestTransformCommon23 = _MNNMultiAndDestTransformCommon23; - gInstance->MNNMatrixAdd = _MNNMatrixAdd; - gInstance->MNNMatrixSub = _MNNMatrixSub; - gInstance->MNNStrassenMergeCFunction = _MNNStrassenMergeCFunction; - gInstance->penalty = 10.0f; - gInstance->MNNScaleAndAddBias = _MNNScaleAndAddBias; - gInstance->MNNGridSampleComputeCord = _MNNGridSampleComputeCord; - gInstance->MNNGridSampleInterp = _MNNGridSampleInterp; - gInstance->MNNCopyC4WithStride = MNNCopyC4Int16WithStride; - gInstance->MNNAddC4WithStride = _MNNAddC4WithStride; - gInstance->chooseWinoSourceTransformPack = (decltype(gInstance->chooseWinoSourceTransformPack))(WinogradFunctionHalf::chooseWinoSourceTransformPack); - gInstance->chooseWinoSourceUnrollTransform = 
(decltype(gInstance->chooseWinoSourceUnrollTransform))(WinogradFunctionHalf::chooseSourceUnrollTransform); - gInstance->chooseWinoDestUnrollTransform = (decltype(gInstance->chooseWinoDestUnrollTransform))(WinogradFunctionHalf::chooseWinoDestUnrollTransform); - gInstance->MNNDeconvRunForLineDepthwise = (decltype(gInstance->MNNDeconvRunForLineDepthwise))_MNNDeconvRunForLineDepthwise; - gInstance->MNNDeconvRunForUnitDepthWise = (decltype(gInstance->MNNDeconvRunForUnitDepthWise))_MNNDeconvRunForUnitDepthWise; - gInstance->MNNSelectBinaryFunctionForFloat = BF16BinaryFloatSelect; - gInstance->MNNSelectUnaryFunctionForFloat = BF16UnaryFloatSelect; - gInstance->MNNReluWithSlopeChannel = MNNReluWithSlopeChannelBF16;// TODO: Optimize it - -#if !defined(MNN_USE_SSE) && !defined(MNN_USE_NEON) - gInstance->penalty = 1.5f; - gInstance->MNNPackForMatMul_B = MNNPackForMatMul_B_BF16; // common function MNNPackForMatMul_B_BF16 is needed even with out sse or arm neon. - gInstance->MNNPackC4ForMatMul_A = MNNPackC4ForMatMul_A_BF16;// - gInstance->MNNPackedMatMul = (decltype(gInstance->MNNPackedMatMul))MNNPackedMatMul_BF16; - gInstance->MNNPackedMatMulRemain = (decltype(gInstance->MNNPackedMatMulRemain))MNNPackedMatMulRemain_BF16; -#endif - gInstance->MNNComputeMatMulForH_1 = _MNNComputeMatMulForH_1_BF16; - gInstance->MNNComputeMatMulForE_1 = _MNNComputeMatMulForE_1_BF16; - gInstance->MNNPoolingAvg = (decltype(gInstance->MNNPoolingAvg))(poolingAvg); - gInstance->MNNPoolingMax = (decltype(gInstance->MNNPoolingMax))(poolingMax); - gInstance->MNNPoolingMaxWithRedice = (decltype(gInstance->MNNPoolingMaxWithRedice))(poolingMaxWithRedice); + gInstance->matmulBytes = 2; -#if defined(MNN_USE_SSE) - gInstance->MNNPackForMatMul_B = _SSE_MNNPackForMatMul_B_BF16; - auto cpuFlags = libyuv::InitCpuFlags(); - if (!(cpuFlags & libyuv::kCpuHasF16C)) { - delete gInstance; - gInstance = nullptr; - return false; - } - if (cpuFlags & libyuv::kCpuHasAVX2) { - gInstance->MNNPackForMatMul_B = _AVX_MNNPackForMatMul_B_BF16; - gInstance->MNNGetMatMulPackMode = _AVX_MNNGetMatMulPackMode_BF16; - gInstance->MNNPackC4ForMatMul_A = _AVX_MNNPackC4ForMatMul_A_BF16; - gInstance->MNNPackedMatMul = _AVX_MNNPackedMatMulFMA_BF16; - gInstance->MNNPackedMatMulRemain = _AVX_MNNPackedMatMulRemainFMA_BF16; - return true; - } -#elif defined(MNN_USE_NEON) gInstance->MNNPackForMatMul_B = NEON_MNNPackForMatMul_B_BF16; gInstance->MNNGetMatMulPackMode = NEON_MNNGetMatMulPackMode_BF16; gInstance->MNNPackC4ForMatMul_A = NEON_MNNPackC4ForMatMul_A_BF16; gInstance->MNNPackedMatMul = NEON_MNNPackedMatMul_BF16; gInstance->MNNPackedMatMulRemain = NEON_MNNPackedMatMulRemain_BF16; - gInstance->MNNConvRunForLineDepthwise = NEON_MNNConvRunForLineDepthwise_BF16; - gInstance->MNNConvRunForUnitDepthWise = NEON_MNNConvRunForUnitDepthWise_BF16; - gInstance->MNNAxByClampBroadcastUnit = NEON_MNNAxByClampBroadcastC4_BF16; #ifdef __aarch64__ - cpuinfo_arm_isa gCPUInfo; - cpuinfo_arm_init(&gCPUInfo); + const MNNCPUInfo& gCPUInfo = *MNNGetCPUInfo(); gInstance->supportFp16arith = gCPUInfo.fp16arith; gInstance->supportSDot = gCPUInfo.dot; gInstance->supportI8mm = gCPUInfo.i8mm; @@ -906,10 +354,11 @@ bool BF16Functions::init() { gInstance->MNNPackedMatMulRemain = ARMV86_MNNPackedMatMulRemain_BF16; } #endif - return true; -#endif + gInstance->MNNPackedMatMul_int4 = nullptr; + gInstance->MNNPackedMatMul_int8 = nullptr; // TODO: raw cpu version of bf16 return true; +#endif } CoreFunctions* BF16Functions::get() { diff --git a/source/backend/cpu/bf16/CMakeLists.txt 
b/source/backend/cpu/bf16/CMakeLists.txt index b533bec6f..7dc34a113 100644 --- a/source/backend/cpu/bf16/CMakeLists.txt +++ b/source/backend/cpu/bf16/CMakeLists.txt @@ -9,11 +9,3 @@ add_library( ${MNN_BF16_SRCS} ) target_compile_options(MNN_BF16 PRIVATE -DMNN_SUPPORT_BF16) -if(CMAKE_SYSTEM_PROCESSOR MATCHES "(x86_64)|(X86_64)|(x64)|(X64)|(amd64)|(AMD64)|(i686)") - if (MNN_USE_SSE) - target_compile_options(MNN_BF16 PRIVATE -DMNN_USE_SSE) - if (MNN_SSE_USE_FP16_INSTEAD) - target_compile_options(MNN_BF16 PRIVATE -DMNN_SSE_USE_FP16_INSTEAD -mf16c) - endif() - endif() -endif() diff --git a/source/backend/cpu/compute/CommonOptFunction.cpp b/source/backend/cpu/compute/CommonOptFunction.cpp index f9ce9567c..897f10b40 100644 --- a/source/backend/cpu/compute/CommonOptFunction.cpp +++ b/source/backend/cpu/compute/CommonOptFunction.cpp @@ -35,16 +35,279 @@ void MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count) { } #endif -#if defined(__aarch64__) #ifdef MNN_LOW_MEMORY -extern "C" { -void MNNGemmHybridInt4FP32_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param); -void MNNGemmHybridInt8FP32_smmla(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param); -void MNNGemmHybridInt4FP32_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param); -void MNNGemmHybridInt8FP32_sdot(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param); +#ifndef __aarch64__ +static void _MNNPackedMatMulRemain_int4(float* C, const float* A, const float* fB, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, int aStride, const float* k, const float* b) { + auto B = reinterpret_cast(fB); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(float); + auto hRemain = parameter[4]; + float weightBytes = 0.5; // sizeof(int4_t) + auto bExtraStride = static_cast(parameter[5] / weightBytes); + auto bStride = bExtraStride + 4 * l; + auto hC4 = UP_DIV(h, 4); + float minValue = -std::numeric_limits().max(); + float maxValue = std::numeric_limits().max(); + if (nullptr != postParameters) { + minValue = postParameters[2]; + maxValue = postParameters[3]; + } + int blockId = parameter[6]; + + for (int x=0; x 0) { + summer[0] = dstY[0]; + summer[1] = dstY[1]; + summer[2] = dstY[2]; + summer[3] = dstY[3]; + } + if (nullptr != bias && nullptr != postParameters) { + for (int v=0; v<4; ++v) { + summer[v] += bias[4 * y + v]; + } + } + for (int z=0; z(fB); + auto h = parameter[2]; + auto l = parameter[1]; + auto cStride = parameter[3] / sizeof(float); + auto hRemain = parameter[4]; + float weightBytes = 1; // sizeof(int8_t) + auto bExtraStride = static_cast(parameter[5] / weightBytes); + auto bStride = bExtraStride + 4 * l; + auto hC4 = UP_DIV(h, 4); + float minValue = -std::numeric_limits().max(); + float maxValue = std::numeric_limits().max(); + if (nullptr != postParameters) { + minValue = postParameters[2]; + maxValue = postParameters[3]; + } + int blockId = parameter[6]; + + for (int x=0; x 0) { + summer[0] = dstY[0]; + summer[1] = dstY[1]; + summer[2] = dstY[2]; + summer[3] = dstY[3]; + } + if (nullptr != bias && nullptr != postParameters) { + for (int v=0; v<4; ++v) { + summer[v] += bias[4 * y + v]; 
+ } + } + for (int z=0; z=0 + for (int c = 0; c < src_depth_quad; ++c) { + auto src = source + c * srcStep + i * pack; + for (int k = 0; k < pack; ++k) { + absmaxVal = std::max(absmaxVal, std::abs(src[k])); + } + } + absmax[i] = absmaxVal; + } +} +void MNNQuantScaleFP32(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch) { + for (int i = 0; i < batch; ++i) { + auto absmaxPtr = absmax + i; + float absVal = 0.f; + for (int t = 0; t < thread; ++t) { + absVal = std::max(absVal, absmaxPtr[t * batch]); + } + quant_scale[i] = 127.0f / absVal; + dequant_scale[i] = absVal / 127.0f; + } } +void MNNQuantSumFP32(float* sum, const float* dequant_scale, size_t thread, size_t batch) { + for (int i = 0; i < batch; ++i) { + auto sumPtr = reinterpret_cast(sum) + i; + int sumVal = 0.f; + for (int t = 0; t < thread; ++t) { + sumVal += sumPtr[t * batch]; + } + sum[i] = sumVal * dequant_scale[i]; + } +} +void MNNDynamicQuantFP32(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack) { +#ifdef MNN_USE_SSE + uint8_t* dstPtr = reinterpret_cast(dst); + int offset = 128; +#else + int8_t* dstPtr = dst; + int offset = 0; #endif + for (int i = 0; i < realSize; ++i) { + auto scaleVal = scale[i]; + for (int c = 0; c < src_depth_quad; ++c) { + auto srcZ = src + c * pack * realSize + i * pack; + auto dstZ = dstPtr + c * pack * realSize + i * pack; + for (int k = 0; k < pack; ++k) { + int val = (int)roundf(srcZ[k] * scaleVal); + dstZ[k] = val + offset; + } + } + } +} + +void MNNDynamicUpdateConvBiasScale(float* newbias, float* newscale, float* oldbias, float* weightScale, float* inputScale, float* weightKernelSum, float* inputZero, size_t ocQuad, size_t scaleSize) { + int ocUp4 = 4 * ocQuad; + int pack = 4; + int blockNum = scaleSize / ocUp4; + for (int i = 0; i < ocUp4; ++i) { + newbias[i] = oldbias[i] - weightKernelSum[i] * inputZero[0]; + } + for (int k = 0; k < blockNum; ++k) { + for (int i = 0; i < ocUp4; ++i) { + newscale[i + k * ocUp4] = weightScale[i + k * ocUp4] * inputScale[0]; + } + } +} + +#endif // not __aarch64__ +#endif // LOW_MEMORY + + +static void MNNSumByAxisLForMatmul_A(float* dest, int8_t* source, const float* scale, ssize_t realDstCount, SumByAxisParams sumParams) { +#ifdef MNN_USE_SSE + uint8_t* srcInt8 = reinterpret_cast(source); +#else + int8_t* srcInt8 = source; #endif + auto scalePtr = scale; + auto kernelCountUnitDouble = sumParams.kernelCountUnitDouble; + auto blockNum = sumParams.blockNum; + auto EP = sumParams.DST_XUNIT; + auto LP = sumParams.SRC_UNIT; + auto blockSizeQuad = kernelCountUnitDouble / blockNum; + auto col_buffer_unit_size = sumParams.col_buffer_unit_size; + auto oneScale = sumParams.oneScale; + do { + int step = ALIMIN(EP, realDstCount); + + for (int k = 0; k < blockNum; ++k) { + // const auto src_x = srcInt8 + w * LP; + const auto src_x = srcInt8 + k * (EP * LP * blockSizeQuad); + for (int w = 0; w < step; ++w) { + float dequantScale = scale[0]; + if (oneScale == 0) { + dequantScale = scalePtr[w]; + } + int sumint32 = 0; + const auto src_y = src_x + w * LP; + for (int j = 0; j < blockSizeQuad; ++j) { + const auto src_z = src_y + j * (EP * LP); + for (int i = 0; i < LP; ++i) { + sumint32 += src_z[i]; + } + } + dest[w + k * step] = dequantScale * static_cast(sumint32); + } + } + scalePtr += step; + + dest += (step * blockNum); + realDstCount -= step; + srcInt8 += col_buffer_unit_size; + } while(realDstCount > 0); +} template void MNNPackC4Common(T* dst, const T* src, size_t area, size_t depth, int* 
areaOffset) { @@ -461,11 +724,6 @@ void MNNCountMaxMinValue(float* source, float* minVal, float* maxVal, size_t siz } *minVal = min_; *maxVal = max_; - // float range = max_ - min_; - // MNN_ASSERT(range != 0); - // *quantScale = 255.0f / range; - // *dequantScale = range / 255.0f; - // *zeroPoint = std::min(255.f, std::max(roundf(-(min_ * 255.f) / range), 0.f)) - 128.0f; } #ifndef MNN_USE_NEON @@ -579,312 +837,6 @@ void MNNPackedMatMulRemain(float* C, const float* A, const float* B, size_t eSiz _MNNPackedMatMulRemain(C, A, B, eSize, parameter, postParameters, bias, aStride); } -#ifdef MNN_LOW_MEMORY -static void _MNNPackedMatMulRemain_int4(float* C, const float* A, const float* fB, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, int aStride, const float* k, const float* b) { - auto B = reinterpret_cast(fB); - auto h = parameter[2]; - auto l = parameter[1]; - auto cStride = parameter[3] / sizeof(float); - auto hRemain = parameter[4]; - float weightBytes = 0.5; // sizeof(int4_t) - auto bExtraStride = static_cast(parameter[5] / weightBytes); - auto bStride = bExtraStride + 4 * l; - auto hC4 = UP_DIV(h, 4); - float minValue = -std::numeric_limits().max(); - float maxValue = std::numeric_limits().max(); - if (nullptr != postParameters) { - minValue = postParameters[2]; - maxValue = postParameters[3]; - } - int blockId = parameter[6]; - - for (int x=0; x 0) { - summer[0] = dstY[0]; - summer[1] = dstY[1]; - summer[2] = dstY[2]; - summer[3] = dstY[3]; - } - if (nullptr != bias && nullptr != postParameters) { - for (int v=0; v<4; ++v) { - summer[v] += bias[4 * y + v]; - } - } - for (int z=0; z(fB); - auto h = parameter[2]; - auto l = parameter[1]; - auto cStride = parameter[3] / sizeof(float); - auto hRemain = parameter[4]; - float weightBytes = 1; // sizeof(int8_t) - auto bExtraStride = static_cast(parameter[5] / weightBytes); - auto bStride = bExtraStride + 4 * l; - auto hC4 = UP_DIV(h, 4); - float minValue = -std::numeric_limits().max(); - float maxValue = std::numeric_limits().max(); - if (nullptr != postParameters) { - minValue = postParameters[2]; - maxValue = postParameters[3]; - } - int blockId = parameter[6]; - - for (int x=0; x 0) { - summer[0] = dstY[0]; - summer[1] = dstY[1]; - summer[2] = dstY[2]; - summer[3] = dstY[3]; - } - if (nullptr != bias && nullptr != postParameters) { - for (int v=0; v<4; ++v) { - summer[v] += bias[4 * y + v]; - } - } - for (int z=0; z=0 - for (int c = 0; c < src_depth_quad; ++c) { - auto src = source + c * srcStep + i * pack; - for (int k = 0; k < pack; ++k) { - absmaxVal = std::max(absmaxVal, std::abs(src[k])); - } - } - absmax[i] = absmaxVal; - } -} -void MNNQuantScaleFP32(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch) { - for (int i = 0; i < batch; ++i) { - auto absmaxPtr = absmax + i; - float absVal = 0.f; - for (int t = 0; t < thread; ++t) { - absVal = std::max(absVal, absmaxPtr[t * batch]); - } - quant_scale[i] = 127.0f / absVal; - dequant_scale[i] = absVal / 127.0f; - } -} -void MNNQuantSumFP32(float* sum, const float* dequant_scale, size_t thread, size_t batch) { - for (int i = 0; i < batch; ++i) { - auto sumPtr = reinterpret_cast(sum) + i; - int sumVal = 0.f; - for (int t = 0; t < thread; ++t) { - sumVal += sumPtr[t * batch]; - } - sum[i] = sumVal * dequant_scale[i]; - } -} -void MNNDynamicQuantFP32(const float* src, int8_t* dst, const float* scale, float* sum, size_t src_depth_quad, size_t realSize, int pack) { -#ifdef MNN_USE_SSE - uint8_t* dstPtr = 
reinterpret_cast(dst); -#else - int8_t* dstPtr = dst; -#endif - for (int i = 0; i < realSize; ++i) { - auto scaleVal = scale[i]; - int acc = 0; - for (int c = 0; c < src_depth_quad; ++c) { - auto srcZ = src + c * pack * realSize + i * pack; - auto dstZ = dstPtr + c * pack * realSize + i * pack; - for (int k = 0; k < pack; ++k) { - int val = (int)roundf(srcZ[k] * scaleVal); - acc += val; - dstZ[k] = val; - } - } - ((int32_t*)sum)[i] = acc; - } -} -void MNNGemmHybridInt8FP32(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param) { - // C:(oc/4,N,4) A:(ic/4,N,4) B:(oc/4,ic/4,4,4) - int pack = 4; - size_t weight_step = src_depth_quad * pack * pack; - const float* alpha_ptr = param[0]; - const float* zero_ptr = param[1]; - const float* bias_ptr = param[2]; - const float* sums_ptr = param[3]; - const float* scale_ptr = param[4]; - for (int ci = 0; ci < dst_depth_quad; ++ci) { - float* dstZ = C + ci * pack * realSize; - const int8_t* weight = B + ci * weight_step; - auto alpha = alpha_ptr + ci * pack; - auto zero = zero_ptr + ci * pack; - auto bias = bias_ptr + ci * pack; - //const float* sums = param[2]; - for (int j = 0; j < realSize; ++j) { - const float* sums = sums_ptr + j; - const float* scale = scale_ptr + j; - float* dstX = dstZ + j * pack; - std::vector tmp(pack); - // int8_t* weightPtr = B + weight_step; - const int8_t* srcBatch = A + j * pack; - for (int k = 0; k < src_depth_quad; ++k) { - const int8_t* srcZ = srcBatch + k * pack * realSize; - const int8_t* weightZ = weight + k * pack * pack; - for (int cn = 0; cn < pack; ++cn) { // pack for oc - const auto weightj = weightZ + cn * pack; - for (int ck = 0; ck < pack; ++ck) { // pack for ic - tmp[cn] += (int32_t)srcZ[ck] * (int32_t)weightj[ck]; - } - } - } - - // int32->float - for (int cn = 0; cn < pack; ++cn) { - float val = (float)tmp[cn] * scale[0]; - val = bias[cn] + val * alpha[cn] + zero[cn] * sums[0]; - dstX[cn] = val; - } - } - } -} -void MNNGemmHybridInt4FP32(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param) { - // C:(oc/4,N,4) A:(ic/4,N,4) B:(oc/4,ic/4,4,4) - int pack = 4; - size_t weight_step = src_depth_quad * pack * pack * 0.5; - size_t weight_stride = pack * pack / 2; - const float* alpha_ptr = param[0]; - const float* zero_ptr = param[1]; - const float* bias_ptr = param[2]; - const float* sums_ptr = param[3]; - const float* scale_ptr = param[4]; - for (int ci = 0; ci < dst_depth_quad; ++ci) { - float* dstZ = C + ci * pack * realSize; - const int8_t* weight = B + ci * weight_step; - auto alpha = alpha_ptr + ci * pack; - auto zero = zero_ptr + ci * pack; - auto bias = bias_ptr + ci * pack; - //const float* sums = param[2]; - for (int j = 0; j < realSize; ++j) { - const float* sums = sums_ptr + j; - const float* scale = scale_ptr + j; - float* dstX = dstZ + j * pack; - int tmp[4] = {0, 0, 0, 0}; - // int8_t* weightPtr = B + weight_step; - const int8_t* srcBatch = A + j * pack; - for (int k = 0; k < src_depth_quad; ++k) { - const int8_t* srcZ = srcBatch + k * pack * realSize; - const uint8_t* weightZ = (uint8_t*)weight + k * weight_stride; - int32_t tmpw[16]; - uint32_t c = 0xf; - for (int kk = 0; kk < 8; ++kk) { - tmpw[2 * kk] = (weightZ[kk]>>4) - 8; - tmpw[2 * kk + 1] = (weightZ[kk] & c) - 8; - } - for (int cn = 0; cn < pack; ++cn) { // pack for oc - const auto weightj = tmpw + cn * pack; - for (int ck = 0; ck < pack; ++ck) { // pack for ic 
- tmp[cn] += (int32_t)srcZ[ck] * (int32_t)weightj[ck]; - } - } - } - - // int32->float - for (int cn = 0; cn < pack; ++cn) { - float val = (float)tmp[cn] * scale[0]; - val = bias[cn] + val * alpha[cn] + zero[cn] * sums[0]; - dstX[cn] = val; - } - } - } -} -#endif - void MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el) { int number = info[0]; int eReal = info[1]; @@ -3298,16 +3250,6 @@ void MNNCoreFunctionInit() { gCoreFunction->MNNPackForMatMul_B = MNNPackForMatMul_B; gCoreFunction->MNNPackedMatMul = MNNPackedMatMul; gCoreFunction->MNNPackedMatMulRemain = MNNPackedMatMulRemain; -#ifdef MNN_LOW_MEMORY - gCoreFunction->MNNPackedMatMul_int4 = MNNPackedMatMul_int4; - gCoreFunction->MNNPackedMatMulRemain_int4 = MNNPackedMatMulRemain_int4; - gCoreFunction->MNNPackedMatMul_int8 = MNNPackedMatMul_int8; - gCoreFunction->MNNPackedMatMulRemain_int8 = MNNPackedMatMulRemain_int8; - gCoreFunction->MNNAbsMax = MNNAbsMaxFP32; - gCoreFunction->MNNDynamicQuant = MNNDynamicQuantFP32; - gCoreFunction->MNNQuantScale = MNNQuantScaleFP32; - gCoreFunction->MNNQuantSum = MNNQuantSumFP32; -#endif gCoreFunction->MNNCountMaxMinValue = MNNCountMaxMinValue; gCoreFunction->MNNGetSparseMatMulPackMode = MNNGetSparseMatMulPackMode; gCoreFunction->MNNAdjustOptimalSparseKernel = _MNNAdjustOptimalSparseKernel; @@ -3315,7 +3257,6 @@ void MNNCoreFunctionInit() { gCoreFunction->MNNComputeMatMulForE_1 = MNNComputeMatMulForE_1; gCoreFunction->MNNComputeMatMulForH_1 = MNNComputeMatMulForH_1; - // Lowp gCoreFunction->MNNFp32ToLowp = nullptr; gCoreFunction->MNNLowpToFp32 = nullptr; @@ -3394,24 +3335,24 @@ void MNNCoreFunctionInit() { gCoreFunction->MNNAccumulateSequenceNumber = MNNAccumulateSequenceNumber; - cpuinfo_arm_isa gCPUInfo; - cpuinfo_arm_init(&gCPUInfo); + const MNNCPUInfo& gCPUInfo = *MNNGetCPUInfo(); gCoreFunction->supportFp16arith = gCPUInfo.fp16arith; gCoreFunction->supportSDot = gCPUInfo.dot; gCoreFunction->supportI8mm = gCPUInfo.i8mm; + gCoreFunction->MNNSumByAxisLForMatmul_A = MNNSumByAxisLForMatmul_A; #ifdef MNN_LOW_MEMORY - gCoreFunction->MNNGemmHybridInt8 = MNNGemmHybridInt8FP32; - gCoreFunction->MNNGemmHybridInt4 = MNNGemmHybridInt4FP32; -#if defined(__aarch64__) - if (gCoreFunction->supportSDot) { - gCoreFunction->MNNGemmHybridInt8 = MNNGemmHybridInt8FP32_sdot; - gCoreFunction->MNNGemmHybridInt4 = MNNGemmHybridInt4FP32_sdot; - } - if (gCoreFunction->supportI8mm) { - gCoreFunction->MNNGemmHybridInt8 = MNNGemmHybridInt8FP32_smmla; - gCoreFunction->MNNGemmHybridInt4 = MNNGemmHybridInt4FP32_smmla; - } -#endif + // Weight Dequant Gemm Kernels + gCoreFunction->MNNPackedMatMul_int4 = MNNPackedMatMul_int4; + gCoreFunction->MNNPackedMatMulRemain_int4 = MNNPackedMatMulRemain_int4; + gCoreFunction->MNNPackedMatMul_int8 = MNNPackedMatMul_int8; + gCoreFunction->MNNPackedMatMulRemain_int8 = MNNPackedMatMulRemain_int8; + // Dynamic Quant Helper Functions + gCoreFunction->MNNAbsMax = MNNAbsMaxFP32; + gCoreFunction->MNNDynamicQuant = MNNDynamicQuantFP32; + gCoreFunction->MNNQuantScale = MNNQuantScaleFP32; + gCoreFunction->MNNQuantSum = MNNQuantSumFP32; + // Dynamic Quan Bias + gCoreFunction->MNNDynamicUpdateConvBiasScale = MNNDynamicUpdateConvBiasScale; #endif MNNCoreInt8FunctionInit(); MNNFunctionInit(); diff --git a/source/backend/cpu/compute/CommonOptFunction.h b/source/backend/cpu/compute/CommonOptFunction.h index 9058c1353..bbfdce0fa 100644 --- a/source/backend/cpu/compute/CommonOptFunction.h +++ b/source/backend/cpu/compute/CommonOptFunction.h @@ -126,9 +126,9 @@ 
void MNNPackedMatMul_int8(float* C, const float* A, const float* B, const size_t void MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack); void MNNQuantScaleFP32(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch); -void MNNDynamicQuantFP32(const float* src, int8_t* dst, const float* scale, float* sum, size_t src_depth_quad, size_t realSize, int pack); +void MNNDynamicQuantFP32(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack); void MNNQuantSumFP32(float* sum, const float* dequant_scale, size_t thread, size_t batch); - +void MNNDynamicUpdateConvBiasScale(float* newbias, float* newscale, float* oldbias, float* weightScale, float* inputScale, float* weightKernelSum, float* inputZero, size_t ocQuad, size_t scaleSize); void MNNPackForSparseMatMul_B(float* dest, unsigned int* NNZMap, int* dataOffsetMap, int sparseBlockOC, const float* source, size_t h, size_t l, const int eP, bool transpose); struct SparseMatMulParas @@ -169,8 +169,15 @@ void MNNSourceTransformCommonF23(const float *source, float *dest, int unit, int void MNNConvDwF23MulTransUnit(float **cacheLine, const float *weigth, float *dest, size_t ow, const float* bias, const float* postParameter); void MNNMultiAndDestTransformCommon23(float **cacheLine, const float *weigth, float *dest, int cacheLineSize, int ow); void MNNInt8ToInt16(int16_t* dest, const int8_t* source, size_t count); -void MNNGemmHybridInt4FP32(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param); -void MNNGemmHybridInt8FP32(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param); + +struct SumByAxisParams { + ssize_t kernelCountUnitDouble; + ssize_t col_buffer_unit_size; + ssize_t DST_XUNIT; + ssize_t SRC_UNIT; + ssize_t blockNum; + ssize_t oneScale; +}; } typedef void(*MNNBinaryExecute)(void* outputRaw, const void* inputRaw0, const void* inputRaw1, int elementSize, int broadcastIndex); @@ -195,20 +202,18 @@ struct CoreFunctions { // parameters: e, l, h, CStride, AStride, BStride void(*MNNPackedMatMul)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); void(*MNNPackedMatMulRemain)(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); - void(*MNNPackedMatMul_int4)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); - void(*MNNPackedMatMulRemain_int4)(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); - void(*MNNAbsMax)(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack); - void(*MNNQuantScale)(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch); - void(*MNNDynamicQuant)(const float* src, int8_t* dst, const float* scale, float* sum, size_t src_depth_quad, size_t realSize, int pack); - 
void(*MNNQuantSum)(float* sum, const float* dequant_scale, size_t thread, size_t batch); - void(*MNNGemmHybridInt4)(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param); - void(*MNNGemmHybridInt8)(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param); - void(*MNNPackedMatMul_int8)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); - void(*MNNPackedMatMulRemain_int8)(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); + void(*MNNPackedMatMul_int4)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) = nullptr; + void(*MNNPackedMatMulRemain_int4)(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) = nullptr; + void(*MNNAbsMax)(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) = nullptr; + void(*MNNQuantScale)(float* absmax, float* quant_scale, float* dequant_scale, size_t thread, size_t batch) = nullptr; + void(*MNNDynamicQuant)(const float* src, int8_t* dst, const float* scale, size_t src_depth_quad, size_t realSize, int pack) = nullptr; + void(*MNNQuantSum)(float* sum, const float* dequant_scale, size_t thread, size_t batch) = nullptr; + void(*MNNPackedMatMul_int8)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) = nullptr; + void(*MNNPackedMatMulRemain_int8)(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b) = nullptr; void(*MNNComputeMatMulForH_1)(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId); void(*MNNComputeMatMulForE_1)(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId); void(*MNNCountMaxMinValue)(float* source, float* minVal, float* maxVal, size_t size); - + void(*MNNDynamicUpdateConvBiasScale)(float* newbias, float* newscale, float* oldbias, float* weightScale, float* inputScale, float* weightKernelSum, float* inputZero, size_t ocQuad, size_t scaleSize); typedef void(*MNNPackedMatMulKernel)(float* C, const float* A, const float* B, const size_t* parameter, const float* postParameters, const float* bias); @@ -228,6 +233,7 @@ struct CoreFunctions { void(*MNNFp32ToLowp)(const float* src, int16_t* dst, size_t size); void(*MNNLowpToFp32)(const int16_t* src, float* dst, size_t size); int bytes; // Byte for float + int matmulBytes = 0; // Special bytes for dense matmul, C = A*B, A, B is matmulBytes, C is bytes. 
If 0, means the same as bytes /**NC4HW4's Functions*/ int pack; @@ -330,6 +336,7 @@ struct CoreFunctions { void(*MNN2BitcopyFast)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds); void(*MNN1BitcopyFast)(uint8_t* dstO, const uint8_t* srcO, int size, int stride, int ds); void(*MNNAccumulateSequenceNumber)(float* dst, const float* src, int size); + void(*MNNSumByAxisLForMatmul_A)(float* dest, int8_t* source, const float* dequantScale, ssize_t realDstCount, SumByAxisParams sumParams); }; void MNNCoreFunctionInit(); CoreFunctions* MNNGetCoreFunctions(); diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp index 7abf443d2..756f24aee 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.cpp @@ -13,12 +13,12 @@ #include #include "backend/cpu/CPUBackend.hpp" -#include "backend/cpu/compute/CommonOptFunction.h" #include "core/Concurrency.h" #include "core/TensorUtils.hpp" + namespace MNN { -ConvInt8TiledExecutor::ConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* convOp, std::shared_ptr res): CPUConvolution(convOp, backend), mResource(res), mMutableResource(res, backend) { +ConvInt8TiledExecutor::ConvInt8TiledExecutor(Backend* backend, const Convolution2D* convOp, std::shared_ptr res): CPUConvolution(convOp->common(), backend), mResourceInt8(res), mMutableResource(res, backend) { mValid = mMutableResource.mValid; } @@ -85,9 +85,10 @@ void ConvInt8TiledExecutor::reorderWeight(Tensor* weight, const uint8_t* weightS } static bool _reorderWeightInside(Backend* bn, const Convolution2DCommon* common, - const std::shared_ptr& weightOrigin, - std::shared_ptr& weight) { + const std::shared_ptr& weightOrigin, + std::shared_ptr& weight) { auto core = static_cast(bn)->int8Functions(); + auto gcore = static_cast(bn)->functions(); int UNIT, SRC_UNIT, DST_XUNIT; core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); // reorder weight, [oc, ic, k^2] => [oc/unit, ((ic/unit)*k^2)/(src_unit/unit), unit(oc), (src_unit/unit), unit(ic)] @@ -111,30 +112,195 @@ static bool _reorderWeightInside(Backend* bn, const Convolution2DCommon* common, return true; } -DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const Convolution2D* convOp, std::shared_ptr res) : ConvInt8TiledExecutor(backend, convOp->common(), res) { - std::shared_ptr weightOrigin = mResource->mWeightInt8; - mValid = _reorderWeightInside(backend, convOp->common(), weightOrigin, mResource->mWeightInt8); +static void Getfp32Info (std::shared_ptr resource, std::shared_ptr weightOrigin, const Convolution2D* conv2d, std::shared_ptr quantCommon) { + // common parameters + int outputCount = conv2d->common()->outputCount(); + auto core = static_cast(resource->backend)->functions(); + int LSize = conv2d->common()->inputCount() * conv2d->common()->kernelX() * conv2d->common()->kernelY(); + int ocUp4 = ROUND_UP(outputCount, core->pack); + + int dequantCnt = quantCommon->alpha.size(); + if (quantCommon->asymmetric) { + dequantCnt /= 2; + } + int blockNum = dequantCnt / outputCount; + int scaleSize = blockNum * ocUp4; // pack size. 
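// --- Editor's note (illustrative sketch, not part of the upstream patch) -----------------
// Getfp32Info fills two arrays of size blockNum * ocUp4 (scales first, then biases) so that
// every stored weight can be recovered as  w_float = alpha * w_stored + bias, with alpha and
// bias per block and per output channel; for int4 weights the -8 origin offset is folded
// into bias further below. A scalar reference of that lookup, assuming exactly the
// [blockNum][ocUp4] layout used here (the helper name is hypothetical):
static inline float SketchDequantWeight(const float* alphaPtr, const float* biasPtr,
                                        int block, int oc, int ocUp4, int storedWeight) {
    float alpha = alphaPtr[block * ocUp4 + oc]; // per-block, per-channel scale
    float bias  = biasPtr[block * ocUp4 + oc];  // already includes originOffset * alpha for int4
    return alpha * (float)storedWeight + bias;
}
// ------------------------------------------------------------------------------------------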
+ int blockSize = LSize / blockNum; + int originOffset = 0; + if (quantCommon->canUseInt4) { + originOffset = -8; + } + + // Save weight quant scale and bias: wf=scale*wi+bias + int bytes = 4; + resource->mDequantize.mScaleBias.reset(Tensor::createDevice({2 * scaleSize * bytes})); + auto success = resource->backend->onAcquireBuffer(resource->mDequantize.mScaleBias.get(), Backend::STATIC); + if (!success) { + MNN_ERROR("Alloc denquant scaleBias memory error\n"); + return; + } + auto alphaPtr = resource->mDequantize.mScaleBias->host(); + auto biasPtr = reinterpret_cast(reinterpret_cast(alphaPtr) + scaleSize * bytes); + ::memset(alphaPtr, 1, scaleSize * bytes); + ::memset(biasPtr, 0, scaleSize * bytes); + auto quanInfoPtr = quantCommon->alpha.get(); + int h = quantCommon->alpha.size(); + if (quantCommon->asymmetric) { + for (int i = 0; i < blockNum; ++i) { + auto dstAlpha = alphaPtr + i * ocUp4; + auto dstBias = biasPtr + i * ocUp4; + for (int j = 0; j < outputCount; ++j) { + int scaleIndex = j * blockNum + i; + dstAlpha[j] = quanInfoPtr[2 * scaleIndex + 1]; + dstBias[j] = quanInfoPtr[2 * scaleIndex] + (float)originOffset * dstAlpha[j]; + } + } + + } else { + for (int i = 0; i < blockNum; ++i) { + auto dstAlpha = alphaPtr + i * ocUp4; + auto dstBias = biasPtr + i * ocUp4; + for (int j = 0; j < outputCount; ++j) { + int scaleIndex = j * blockNum + i; + dstAlpha[j] = quanInfoPtr[scaleIndex]; + dstBias[j] = (float)originOffset * dstAlpha[j]; + } + } + } + // Save float weight kernel sum + resource->mWeightKernelSum.reset(Tensor::createDevice({bytes * ocUp4})); + success = resource->backend->onAcquireBuffer(resource->mWeightKernelSum.get(), Backend::STATIC); + if (!success) { + MNN_ERROR("Alloc denquant mWeightKernelSum memory error\n"); + return; + } + auto weightKernelSum = resource->mWeightKernelSum->host(); + auto realWeightData = weightOrigin->host(); + ::memset(weightKernelSum, 0, resource->mWeightKernelSum->size()); + for (int j = 0; j < outputCount; ++j) { + float sum = 0.f; + for (int k = 0; k < blockNum; ++k) { + int scaleIndex = k + j * blockNum; + float scale = 0; + float bias = 0; + if (quantCommon->asymmetric) { + scale = quanInfoPtr[2 * scaleIndex + 1]; + bias = quanInfoPtr[2 * scaleIndex]; + } else { + scale = quanInfoPtr[scaleIndex]; + bias = 0; + } + int tmp = 0; + for (int i = 0; i < blockSize; ++i) { + int l_index = k * blockSize + i; + tmp += (int)realWeightData[j * blockNum * blockSize + l_index]; + } + sum += (tmp * scale + blockSize * bias); + } + weightKernelSum[j] = sum; + } +} + +DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const Convolution2D* convOp, std::shared_ptr res, bool dynamicQuantExe) : ConvInt8TiledExecutor(backend, convOp, res) { + std::shared_ptr weightOrigin = mResourceInt8->mWeightInt8; + std::shared_ptr quanCommon ; + mDynamicQuantExe = dynamicQuantExe; + if (dynamicQuantExe) { + MNN_ASSERT(convOp->quanParameter() != nullptr && convOp->quanParameter()->buffer() != nullptr); + quanCommon = ConvolutionCommon::load(convOp, backend, false, true); + // fp32 weightKernelSum + mResource.reset(new CPUConvolution::Resource); + mResource->backend = backend; + Getfp32Info(mResource, weightOrigin, convOp, quanCommon); // Call this before reorder weight. 
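// --- Editor's note (sketch of the algebra, hedged; not part of the upstream patch) -------
// mWeightKernelSum computed above stores, per output channel, the sum of the dequantized
// weights over the whole reduction axis. With an asymmetric input quantization
//     x_float = s_x * x_q + z_x,
// the dot product splits as
//     sum_l x_float * w_float = s_x * sum_l x_q * w_float + z_x * sum_l w_float,
// so the integer GEMM only has to produce sum_l x_q * w, while the z_x * kernelSum term is
// folded into the per-channel bias (the sign depends on how the zero point is stored; see
// MNNDynamicUpdateConvBiasScale earlier in this patch). A self-contained scalar mirror of
// that bias update, with a hypothetical name:
static inline void SketchFoldInputZeroIntoBias(float* newBias, const float* oldBias,
                                               const float* weightKernelSum,
                                               float inputZero, int ocUp4) {
    for (int oc = 0; oc < ocUp4; ++oc) {
        newBias[oc] = oldBias[oc] - weightKernelSum[oc] * inputZero; // fold zero-point term
    }
}
// ------------------------------------------------------------------------------------------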
+ } + + mValid = _reorderWeightInside(backend, convOp->common(), weightOrigin, mResourceInt8->mWeightInt8); if(!mValid) { return; } - // choose int8 gemm kernel auto core = static_cast(backend)->int8Functions(); - mGemmKernel = core->Int8GemmKernel; + auto gcore = static_cast(backend)->functions(); + // offline quant + if (false == dynamicQuantExe) { + mGemmKernel = core->Int8GemmKernel; #ifdef MNN_USE_SSE - int actBits = convOp->symmetricQuan()->nbits(); - if (actBits <= 7) { - mGemmKernel = core->Int8GemmKernelFast; - } + int actBits = convOp->symmetricQuan()->nbits(); + if (actBits <= 7) { + mGemmKernel = core->Int8GemmKernelFast; + } #else - if(convOp->symmetricQuan()->method() == QuantizeAlgo_OVERFLOW_AWARE){ - mGemmKernel = core->Int8GemmKernelFast; - } + if(convOp->symmetricQuan()->method() == QuantizeAlgo_OVERFLOW_AWARE){ + mGemmKernel = core->Int8GemmKernelFast; + } #endif -} + mResource.reset(new CPUConvolution::Resource); + CPUConvolution::makeResource(backend, mResource, convOp, mResourceInt8); + return; + } -DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* common, const DenseConvInt8TiledExecutor& exe) - : ConvInt8TiledExecutor(backend, common, exe.mResource), mGemmKernel(exe.mGemmKernel) { + // dynamic quant + int UNIT, SRC_UNIT, DST_XUNIT; + core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); + bool needPermuteInt4weight = ((UNIT == 8 && SRC_UNIT == 8 && DST_XUNIT ==10) || (UNIT == 4 && SRC_UNIT == 8 && DST_XUNIT ==20) || (UNIT == 64 && SRC_UNIT == 4 && DST_XUNIT ==4)); + mResource->mDequantize.bits = 8; + if (quanCommon->canUseInt4) { + mResourceInt8->mWeightAsymmetricQuant = true; + auto weightLength = mResourceInt8->mWeightInt8->size(); + MNN_ASSERT(weightLength % 2 == 0); + mResource->mDequantize.bits = 4; + std::shared_ptr weightLow(Tensor::createDevice( mResourceInt8->mWeightInt8->shape())); + auto res = mResource->backend->onAcquireBuffer(weightLow.get(), Backend::STATIC); + if (!res) { + MNN_ERROR("int4 weight acquire buffer error\n"); + return ; + } + auto srcPtr = mResourceInt8->mWeightInt8->host(); + auto dstPtr = weightLow->host(); + // Pack two int4-weight to one int8-weight. + if (false == needPermuteInt4weight) { + weightLength = UP_DIV(weightLength, 2); + for (int i=0; i < weightLength; ++i) { + int s0 = srcPtr[2 * i + 0]; + int s1 = srcPtr[2 * i + 1]; + int d = (s0 + 8) * 16 + (s1 + 8); + dstPtr[i] = d; + } + } else { + int permuteUnit = UNIT * SRC_UNIT; + int halfPermuteStride = static_cast(permuteUnit / 2); + for (int i = 0; i < weightLength / permuteUnit; ++i) { + auto src0 = srcPtr + i * permuteUnit; + auto dst0 = dstPtr + i * halfPermuteStride; + for (int j = 0; j < halfPermuteStride; ++j) { + int s0 = src0[j]; + int s1 = src0[j + halfPermuteStride]; + int d = (s0 + 8) * 16 + (s1 + 8); + dst0[j] = d; + } + } + } + // Update int4 weight to mWeightInt8. 
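// --- Editor's note (illustrative helpers, not part of the upstream patch) ----------------
// The packing below stores two signed int4 weights s0, s1 in [-8, 7] into one byte as
// (s0 + 8) * 16 + (s1 + 8): unsigned nibbles, high nibble first. A matching scalar
// pack/unpack pair (hypothetical names); the unpack is what the int4 GEMM kernels undo:
#include <cstdint>
static inline uint8_t SketchPackInt4Pair(int s0, int s1) {
    // assumes s0 and s1 are already clamped to [-8, 7]
    return (uint8_t)(((s0 + 8) << 4) | (s1 + 8));
}
static inline void SketchUnpackInt4Pair(uint8_t packed, int* s0, int* s1) {
    *s0 = (int)(packed >> 4) - 8;    // high nibble
    *s1 = (int)(packed & 0x0F) - 8;  // low nibble
}
// ------------------------------------------------------------------------------------------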
+ mResourceInt8->mWeightInt8 = weightLow; + } + // Relu/Relu6 post parameters + auto postPtr = getPostParameters(); + mResource->mReluThreshold.resize(2); + mResource->mReluThreshold[0] = postPtr[2]; + mResource->mReluThreshold[1] = postPtr[3]; + if (gcore->bytes == 2) { + gcore->MNNFp32ToLowp(mResource->mReluThreshold.data(), reinterpret_cast(mResource->mReluThreshold.data()), 2); + } + if (mCommon->relu()) { + mResource->mReluThreshold[0] = 0.f; + } + if (mCommon->relu6()) { + mResource->mReluThreshold[0] = 0.f; + mResource->mReluThreshold[1] = 6.f; + } +} +DenseConvInt8TiledExecutor::DenseConvInt8TiledExecutor(Backend* backend, const Convolution2D* convOp, bool dynamicQuantExe, const DenseConvInt8TiledExecutor& exe) + : ConvInt8TiledExecutor(backend, convOp, exe.mResourceInt8), mGemmKernel(exe.mGemmKernel), mResource(exe.mResource), mDynamicQuantExe(dynamicQuantExe) { } DenseConvInt8TiledExecutor::~DenseConvInt8TiledExecutor() { @@ -145,7 +311,7 @@ bool DenseConvInt8TiledExecutor::onClone(Backend* bn, const Op* op, Execution** if (nullptr == dst) { return true; } - auto exe = new DenseConvInt8TiledExecutor(bn, op->main_as_Convolution2D()->common(), *this); + auto exe = new DenseConvInt8TiledExecutor(bn, op->main_as_Convolution2D(), mDynamicQuantExe, *this); if (!exe->valid()) { return false; } @@ -159,42 +325,153 @@ void DenseConvInt8TiledExecutor::getPackParameter(int* Unit, int* srcUnit, int* ErrorCode DenseConvInt8TiledExecutor::onResize(const std::vector& inputs, const std::vector& outputs) { - // Timer kernelTimer; - ConvInt8TiledExecutor::onResize(inputs, outputs); - auto output = outputs[0]; + mUseBatchQuan = (static_cast(backend())->getRuntime()->hint().dynamicQuantOption == 1); + mUseBatchQuan &= mCommon->kernelY() == 1 && mCommon->kernelX() == 1 + && outputs[0]->width() == inputs[0]->width() && outputs[0]->height() == inputs[0]->height() + && mCommon->strideX() == 1 && mCommon->strideY() == 1 && mCommon->padX() == 0 && mCommon->padY() == 0 + && outputs[0]->height() == 1 && outputs[0]->width() == 1; + mUseBatchQuan &= mDynamicQuantExe; + mUseBatchQuan &= (inputs[0]->batch() > 1); auto core = static_cast(backend())->int8Functions(); - + auto gcore =static_cast(backend())->functions(); int UNIT, SRC_UNIT, DST_XUNIT; - getPackParameter(&UNIT, &SRC_UNIT, &DST_XUNIT, core); - const int threads = std::max(static_cast(backend())->threadNumber(), 1); + core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); + + if (mDynamicQuantExe == false) { + mMutableResource.updateInputOutputScale(TensorUtils::getQuantInfo(inputs[0]), TensorUtils::getQuantInfo(outputs[0])); + CPUConvolution::onResize(inputs, outputs); + ConvolutionTiledExecutor::setIm2ColParameter(mIm2ColParamter, mCommon, inputs[0], outputs[0], mPadX, mPadY, gcore, core); + mBlockNum = 1; + } else { // Dynamic Quant kernels + CPUConvolution::onResize(inputs, outputs); + // Gemm Kernel + mGemmKernel = core->Int8GemmKernel; + if (mResource->mDequantize.bits == 4) { + mGemmKernel = core->Int8GemmKernel_W4; + } + mQuantFunc = core->MNNFloat2Int8; + if (gcore->bytes == 2 && gcore->pack == 8) { + mGemmKernel = core->MNNGemmInt8AddBiasScale_Unit_FP16; + if (mResource->mDequantize.bits == 4) { + mGemmKernel = core->MNNGemmInt8AddBiasScale_w4_Unit_FP16; + } + mQuantFunc = core->DynamicQuanInput_ARM82; + mQuantAndReorderFunc = core->DynamicQuanInputAndReorder_ARM82; + + } + // A axisSum kernel + mSumByAxisLFunc = gcore->MNNSumByAxisLForMatmul_A; + if (gcore->bytes == 2 && gcore->pack == 8) { // use fp16 + 
ConvolutionTiledExecutor::setIm2ColParameter(mIm2ColParamter, mCommon, inputs[0], outputs[0], mPadX, mPadY, gcore, core, 4); + } else { + ConvolutionTiledExecutor::setIm2ColParameter(mIm2ColParamter, mCommon, inputs[0], outputs[0], mPadX, mPadY, gcore, core); + } + int ocUp4 = ROUND_UP(outputs[0]->channel(), gcore->pack); + int alphaSize = mResource->mDequantize.mScaleBias->size() / (4 * 2); + mBlockNum = alphaSize / ocUp4; + } + + // input scale buffer + int batch = inputs[0]->batch(); +// mTempIm2ColBuffer.reset(Tensor::createDevice({mThreadNums, DST_XUNIT * mIm2ColCount * mResourceInt8->mWeightInt8->length(1) * SRC_UNIT})); + mInputDeqScales.reset(Tensor::createDevice({batch * 4})); + bool success = backend()->onAcquireBuffer(mInputDeqScales.get(), Backend::DYNAMIC); + + // Im2col info + auto output = outputs[0]; + const int threads = static_cast(backend())->threadNumber(); auto planeSize = output->width() * output->height() * output->batch(); - auto planeSizeInThread = UP_DIV(planeSize, threads); const int L2Size = 2048; const int tileLimitByC = UP_DIV(L2Size, mIm2ColParamter.kernelCountUnit * SRC_UNIT); - int tileLimit = ALIMIN(tileLimitByC, planeSizeInThread); + int tileLimit = 0; + int outC = output->channel(); + int outC4 = UP_DIV(outC, gcore->pack); + + if (threads < planeSize) { // Thread split by output nhw. + tileLimit = ALIMIN(tileLimitByC, UP_DIV(planeSize, threads)); + mSplitByOc = false; + } else { + tileLimit = ALIMIN(tileLimitByC, planeSize); + auto ocPerThread = UP_DIV(outC4, threads); + auto threadNeed = UP_DIV(outC4, ocPerThread); + if (UNIT > gcore->pack) { // AVX512:UNIT=64,pack=16 + MNN_ASSERT(UNIT % gcore->pack == 0); + int ocDivUnit = UP_DIV(outC4 * gcore->pack, UNIT); + ocPerThread = UP_DIV(ocDivUnit, threads); + threadNeed = UP_DIV(ocDivUnit, ocPerThread); + } + mThreadNums = ALIMIN(threads, threadNeed); + mSplitByOc = true; + + mDivides.resize(threads+1); + mDivides[0] = 0; + static_cast(backend()->getRuntime())->computeDivideSizes(outC4, mDivides.data() + 1); + } mIm2ColCount = UP_DIV(tileLimit, DST_XUNIT); auto DynamicDestUnit = DST_XUNIT * mIm2ColCount; mTileCount = UP_DIV(planeSize, DynamicDestUnit); - mThreadNums = std::min(threads, mTileCount); - - auto input = inputs[0]; - // set im2col tensor info - mTempIm2ColBuffer.reset(Tensor::createDevice({mThreadNums, DST_XUNIT * mIm2ColCount * mResource->mWeightInt8->length(1) * SRC_UNIT})); - bool success = backend()->onAcquireBuffer(mTempIm2ColBuffer.get(), Backend::DYNAMIC); - if (!success) { - return OUT_OF_MEMORY; + + if (threads < planeSize) { + mThreadNums = ALIMIN(threads, mTileCount); + mDivides.resize(threads+1); + mDivides[0] = 0; + static_cast(backend()->getRuntime())->computeDivideSizes(mTileCount, mDivides.data() + 1); } + int ocUp4 = ROUND_UP(outC, gcore->pack); + int alphaSize = mResource->mDequantize.mScaleBias->size() / (4 * 2); + auto bufferAlloc = static_cast(backend())->getBufferAllocator(); auto blitInfoSize = ConvolutionTiledExecutor::computeBlitInfoSize(DST_XUNIT * mIm2ColCount, mIm2ColParamter.ow, mIm2ColParamter.kernelX * mIm2ColParamter.kernelY, mThreadNums); + mBlitInfoStride = blitInfoSize.second; mBlitInfo = bufferAlloc->alloc(blitInfoSize.first); - if (mBlitInfo.invalid()) { + mTempIm2ColBuffer.reset(Tensor::createDevice({mThreadNums, DST_XUNIT * mIm2ColCount * mResourceInt8->mWeightInt8->length(1) * SRC_UNIT})); + mTempSrcSum.resize(mThreadNums * mBlockNum * DST_XUNIT * mIm2ColCount * 4); // Use 4 bytes to save kernel sum. 
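The resize path above replaces the old fixed "index += threadNumber" striding with precomputed ranges in mDivides, filled by the runtime's computeDivideSizes. A minimal stand-in under the assumption of a uniform split (the real runtime may weight the ranges, for example by measured core performance) could look like the following; evenDivide is a hypothetical name.

#include <vector>

// Split `total` work items into `threads` contiguous ranges and store the
// cumulative end offsets, so thread t later processes [divides[t], divides[t+1]).
static void evenDivide(int total, int threads, std::vector<int>& divides) {
    if (threads < 1) {
        threads = 1;
    }
    divides.assign(threads + 1, 0);
    const int base = total / threads;
    const int rem  = total % threads;
    for (int t = 0; t < threads; ++t) {
        divides[t + 1] = divides[t] + base + (t < rem ? 1 : 0);
    }
}

// Usage mirroring the pattern above (the same pattern is reused below in
// Convolution1x1Strassen, ConvolutionDepthwise3x3 and ConvolutionPackWinograd):
//   evenDivide(mTileCount, threadNumber, mDivides);
//   for (int i = mDivides[tId]; i < mDivides[tId + 1]; ++i) { /* handle tile i */ }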
+ + success &= backend()->onAcquireBuffer(mTempIm2ColBuffer.get(), Backend::DYNAMIC); + if (!success || mBlitInfo.invalid()) { return OUT_OF_MEMORY; } - bufferAlloc->free(mBlitInfo); - mBlitInfoStride = blitInfoSize.second; + if (false == mDynamicQuantExe) { + bufferAlloc->free(mBlitInfo); + backend()->onReleaseBuffer(mInputDeqScales.get(), Backend::DYNAMIC); + backend()->onReleaseBuffer(mTempIm2ColBuffer.get(), Backend::DYNAMIC); + return NO_ERROR; + } + + int inC = inputs[0]->channel(); + // set im2col tensor info + mQuantInput.reset((Tensor::createDevice({batch, mIm2ColParamter.ih, mIm2ColParamter.iw, ROUND_UP(inC, gcore->pack)}))); + // set dynamic quant buffer + mTempMaxMinValueBuffer.reset(Tensor::createDevice({mThreadNums, 2 * gcore->bytes})); + // set compute buffer + mDynamicBias.reset(Tensor::createDevice({ocUp4 * 4})); + mScaleFuse.reset(Tensor::createDevice({alphaSize * 4})); + + success &= backend()->onAcquireBuffer(mQuantInput.get(), Backend::DYNAMIC); + success &= backend()->onAcquireBuffer(mDynamicBias.get(), Backend::DYNAMIC); + success &= backend()->onAcquireBuffer(mTempMaxMinValueBuffer.get(), Backend::DYNAMIC); + success &= backend()->onAcquireBuffer(mScaleFuse.get(), Backend::DYNAMIC); + + if (mUseBatchQuan) { + int infobytes = 4; // use float32 to save dequant scale and quant scale. + int size = mThreadNums * batch * gcore->bytes + 2 * batch * infobytes; + mBatchQuantInfo.reset(Tensor::createDevice({size})); + success &= backend()->onAcquireBuffer(mBatchQuantInfo.get(), Backend::DYNAMIC); + } + if (!success) { + return OUT_OF_MEMORY; + } + bufferAlloc->free(mBlitInfo); + backend()->onReleaseBuffer(mInputDeqScales.get(), Backend::DYNAMIC); backend()->onReleaseBuffer(mTempIm2ColBuffer.get(), Backend::DYNAMIC); - // MNN_PRINT("dense conv2d int8 resize: cost time: %llu us\n", kernelTimer.durationInUs()); + backend()->onReleaseBuffer(mQuantInput.get(), Backend::DYNAMIC); + backend()->onReleaseBuffer(mDynamicBias.get(), Backend::DYNAMIC); + backend()->onReleaseBuffer(mTempMaxMinValueBuffer.get(), Backend::DYNAMIC); + backend()->onReleaseBuffer(mScaleFuse.get(), Backend::DYNAMIC); + if (mUseBatchQuan) { + backend()->onReleaseBuffer(mBatchQuantInfo.get(), Backend::DYNAMIC); + } return NO_ERROR; } @@ -203,85 +480,323 @@ ErrorCode DenseConvInt8TiledExecutor::onExecute(const std::vector& inpu const auto input = inputs[0]; auto output = outputs[0]; auto core = static_cast(backend())->int8Functions(); + auto gcore = static_cast(backend())->functions(); int UNIT__, SRC_UNIT, DST_XUNIT; core->MNNGetGemmUnit(&UNIT__, &SRC_UNIT, &DST_XUNIT); auto blitProc = core->MNNPackC4Int8ForMatMul_A; - const int plane = output->batch() * mIm2ColParamter.oh * mIm2ColParamter.ow; - int PackUnit = static_cast(backend())->functions()->pack; - const int dstZStep = plane * PackUnit; - - const int batch = input->batch(); - const int ocDiv4 = UP_DIV(output->channel(), PackUnit); + if ( mDynamicQuantExe && gcore->bytes == 2 && core->MNNPackC4Int8ForMatMul_A_ARM86FP16) { + blitProc = core->MNNPackC4Int8ForMatMul_A_ARM86FP16; + } + const int plane = output->batch() * mIm2ColParamter.oh * mIm2ColParamter.ow; + const int batch = input->batch(); + const int PackUnit = gcore->pack; + const int dstZStep = plane * PackUnit; + const int ocDiv4 = UP_DIV(output->channel(), PackUnit); + const int ocUp4 = ROUND_UP(output->channel(), PackUnit); const auto kernelCountUnitDouble = mIm2ColParamter.kernelCountUnit; - //auto remain = outputPlaneLen % GEMM_INT8_DST_XUNIT; - //FUNC_PRINT(remain); + const auto 
col_buffer_unit_size = kernelCountUnitDouble * DST_XUNIT * SRC_UNIT * sizeof(int8_t); + const auto col_buffer_size = col_buffer_unit_size * mIm2ColCount; + const int dstBytes = static_cast(backend())->getBytes(backend(), output); + const int alphaSize = mResource->mDequantize.mScaleBias->size() / (4 * 2); + const int blockL = kernelCountUnitDouble / mBlockNum; // source depthQuad for each block. + float weightBytes = 1.f; + int weight_step_Y = weightBytes * (UNIT__ * SRC_UNIT); + int src_step_Y = DST_XUNIT * SRC_UNIT; + + auto inputDataPtr = input->host(); + auto im2colPtr = mTempIm2ColBuffer->host(); + const auto weightDataPtr = mResourceInt8->mWeightInt8->host(); + auto srcKernelSumPtr = mTempSrcSum.data(); + auto weightDequantBias = mResource->mDequantize.mScaleBias->host() + alphaSize * 4; + + auto outputDataPtr = output->host(); + auto biasPtr = mMutableResource.mBiasFloat->host(); + auto scalePtr = mMutableResource.mScaleFloat->host(); + + auto inputZeroPoint = mMutableResource.mInputZeroPoint; + auto inputScalePtr = mInputDeqScales->host(); + (reinterpret_cast(inputScalePtr))[0] = mMutableResource.mInputScale; + + auto SingleDynamicQuant = [&] () { + const auto floatptr = input->host(); + auto int8ptr = mQuantInput->host(); + auto inputsize = static_cast(backend())->getTensorSize(inputs[0]); + float quantscale = 0.f; + float dequantscale = 0.f; + int zeropoint = 0; + + /* Count max and min value to compute input scale and zeropoint */ + auto maxMinValPtr = mTempMaxMinValueBuffer->host(); + int threadNeed = mThreadNums; + auto inputSizeCount = UP_DIV(inputsize, mThreadNums); + if (inputSizeCount < 9) { + threadNeed = 1; + inputSizeCount = inputsize; + } else { + threadNeed = ALIMIN(UP_DIV(inputsize, inputSizeCount), mThreadNums); + inputSizeCount = UP_DIV(inputsize, threadNeed); + } + auto findMaxMinValueFunction = [&](int tId) { + auto perThreadWorkCount = ALIMIN(inputSizeCount, inputsize - tId * inputSizeCount); + auto minValPtrTid = reinterpret_cast(maxMinValPtr + tId * mTempMaxMinValueBuffer->stride(0)); + auto maxValPtrTid = reinterpret_cast(maxMinValPtr + tId * mTempMaxMinValueBuffer->stride(0) + gcore->bytes); + auto inputDataPtrTid = reinterpret_cast(reinterpret_cast(floatptr) + tId * inputSizeCount * gcore->bytes); + gcore->MNNCountMaxMinValue(inputDataPtrTid, minValPtrTid, maxValPtrTid, perThreadWorkCount); + }; + MNN_CONCURRENCY_BEGIN(tId, threadNeed) { + findMaxMinValueFunction((int)tId); + } + MNN_CONCURRENCY_END(); + if (threadNeed > 1) { + gcore->MNNCountMaxMinValue(reinterpret_cast(maxMinValPtr),reinterpret_cast(maxMinValPtr), reinterpret_cast(maxMinValPtr + gcore->bytes), 2 * mThreadNums); + } + float maxVal = 0; + float minVal = 0; + if (gcore->bytes == 4) { + maxVal = (reinterpret_cast(maxMinValPtr))[1]; + minVal = (reinterpret_cast(maxMinValPtr))[0]; + } + if (gcore->bytes == 2) { + std::vector _mVal(2); + gcore->MNNLowpToFp32(reinterpret_cast(maxMinValPtr), _mVal.data(), 2); + maxVal = _mVal[1]; + minVal = _mVal[0]; + } - const auto inputDataPtr = input->host(); - const auto weightDataPtr = mResource->mWeightInt8->host(); + /* Dynamic quant */ + float range = maxVal - minVal; + quantscale = 255.0f / range; + dequantscale = range / 255.0f; + zeropoint = static_cast(roundf(-minVal * 255.f / range) - 128.0f); + std::vectorqsVec(PackUnit, quantscale); + auto sizeDiv = UP_DIV(inputsize, PackUnit); + int inputPlane = input->batch() * mIm2ColParamter.iw * mIm2ColParamter.ih; + if (gcore->bytes == 2 && gcore->pack == 8 && inputPlane > 1) { // C8->C4 + 
mQuantAndReorderFunc(floatptr, int8ptr, inputPlane, qsVec.data(), -128, 127, (ssize_t)zeropoint, UP_DIV(input->channel(), PackUnit), 4 * inputPlane); + } else { + mQuantFunc(floatptr, int8ptr, sizeDiv, qsVec.data(), -128, 127, (ssize_t)zeropoint); + } - auto im2colPtr = mTempIm2ColBuffer->host(); - auto outputDataPtr = output->host(); - QuanPostTreatParameters quanParam; - quanParam.bias = mMutableResource.mBiasInt32->host(); - quanParam.scale = mMutableResource.mScaleFloat->host(); - quanParam.maxValue = mMutableResource.mClampMax; - if (mResource->mRelu) { - quanParam.minValue = mMutableResource.mOutputZeroPoint; + /* bias float */ + #ifdef MNN_USE_SSE + int offset = 128; + #else + int offset = 0; + #endif + auto biasfp32 = mMutableResource.mResource->mOriginBias->host(); + auto weightDequantScale = mResource->mDequantize.mScaleBias->host(); + float zerofp32 = (zeropoint + offset) * dequantscale; + + gcore->MNNDynamicUpdateConvBiasScale(mDynamicBias->host(), mScaleFuse->host(), biasfp32, weightDequantScale, &dequantscale, mResource->mWeightKernelSum->host(), &zerofp32, UP_DIV(output->channel(), 4), alphaSize); + // Move step for A and B for each block computing + + inputZeroPoint = zeropoint; + (reinterpret_cast(inputScalePtr))[0] = dequantscale; + biasPtr = mDynamicBias->host(); + scalePtr = mScaleFuse->host(); + inputDataPtr = int8ptr; + }; + + auto BatchDynamicQuant = [&]() { + // Allocate input max/sum/dequant/quant buffer + auto infobytes = 4; + auto dequantPtr = mBatchQuantInfo->host(); + auto quantPtr = dequantPtr + batch * infobytes; + auto maxPtr = mBatchQuantInfo->host() + 2 * batch * infobytes; + + // compute sum and absmax + int icDiv4 = UP_DIV(input->channel(), PackUnit); + int threadwork = UP_DIV(icDiv4, mThreadNums); + int threadNeed = UP_DIV(icDiv4, threadwork); + int threadTmp = ALIMIN(mThreadNums, threadNeed); + threadwork = UP_DIV(icDiv4, threadTmp); + MNN_CONCURRENCY_BEGIN(tId, threadTmp) { + int workCount = threadwork; + if (tId == threadTmp - 1) { + workCount = icDiv4 - tId * threadwork; + } + int icIndex = tId * threadwork; + auto inputData = reinterpret_cast(input->host() + icIndex * batch * PackUnit * gcore->bytes); + auto batchMax = reinterpret_cast(maxPtr + tId * batch * gcore->bytes); + gcore->MNNAbsMax(inputData, batchMax, workCount, batch, PackUnit); + } + MNN_CONCURRENCY_END(); + + // Compute quant scale + gcore->MNNQuantScale((float*)maxPtr, (float*)quantPtr, (float*)dequantPtr, threadTmp, batch); + + // quant + MNN_CONCURRENCY_BEGIN(tId, threadTmp) { + int workCount = threadwork; + if (tId == threadTmp - 1) { + workCount = icDiv4 - tId * threadwork; + } + auto icIndex = tId * threadwork; + auto inputData = reinterpret_cast(input->host() + icIndex * batch * PackUnit * gcore->bytes); + auto int8ptr = mQuantInput->host() + icIndex * batch * PackUnit; + auto scale_ptr = reinterpret_cast(quantPtr); + gcore->MNNDynamicQuant(inputData, int8ptr, scale_ptr, workCount, batch, PackUnit); + } + MNN_CONCURRENCY_END(); + + inputZeroPoint = 0; + inputScalePtr = (uint8_t*)dequantPtr; + inputDataPtr = mQuantInput->host(); + biasPtr = mMutableResource.mResource->mOriginBias->host(); + scalePtr = mResource->mDequantize.mScaleBias->host(); + }; + ssize_t oneScale = 1; + if (mUseBatchQuan) { + BatchDynamicQuant(); + oneScale = 0; + } else if (mDynamicQuantExe) { + SingleDynamicQuant(); } else { - quanParam.minValue = mMutableResource.mClampMin; + // offline quant. 
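A scalar sketch of the arithmetic in SingleDynamicQuant above: the observed range [minVal, maxVal] is mapped asymmetrically onto [-128, 127], giving a quant scale of 255/range, a dequant scale of range/255 and a zero point of round(-minVal*255/range) - 128. The real path hands the inner loop to SIMD kernels (mQuantFunc / mQuantAndReorderFunc); this version only shows the math, and the names are illustrative.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>

struct DynamicQuantParams {
    float quantScale;   // float -> int8
    float dequantScale; // int32 accumulators back to float
    int   zeroPoint;
};

// Assumes maxVal > minVal; the production code works on the reduced
// min/max produced by MNNCountMaxMinValue.
static DynamicQuantParams computeQuantParams(float minVal, float maxVal) {
    const float range = maxVal - minVal;
    DynamicQuantParams p;
    p.quantScale   = 255.0f / range;
    p.dequantScale = range / 255.0f;
    p.zeroPoint    = static_cast<int>(std::roundf(-minVal * 255.0f / range) - 128.0f);
    return p;
}

static void quantizeToInt8(const float* src, int8_t* dst, size_t count,
                           const DynamicQuantParams& p) {
    for (size_t i = 0; i < count; ++i) {
        int v = static_cast<int>(std::roundf(src[i] * p.quantScale)) + p.zeroPoint;
        dst[i] = static_cast<int8_t>(std::min(127, std::max(-128, v)));
    }
}

BatchDynamicQuant differs in that it derives one symmetric scale per batch element from the per-element absolute maximum (MNNAbsMax / MNNQuantScale), so the zero point stays 0 and the per-element dequant scales reach the GEMM through mBatchQuantInfo.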
} - int dstBytes = static_cast(backend())->getBytes(backend(), output); - if (dstBytes != 1) { - quanParam.useInt8 = 0; + + if (mResource->mDequantize.bits == 4) { + weightBytes = 0.5; + weight_step_Y *= 0.5; } - //MNN_PRINT("max: %d, min: %d\n", quanParam.maxValue, quanParam.minValue); - const int col_buffer_unit_size = mIm2ColParamter.kernelCountUnit * DST_XUNIT * SRC_UNIT * sizeof(int8_t); - auto col_buffer_size = col_buffer_unit_size * mIm2ColCount; - auto threadFunction = [&](int tId) { + + SumByAxisParams sumParams; + sumParams.oneScale = oneScale; + sumParams.SRC_UNIT = SRC_UNIT; + sumParams.blockNum = mBlockNum; + sumParams.DST_XUNIT = DST_XUNIT; + sumParams.col_buffer_unit_size = col_buffer_unit_size; + sumParams.kernelCountUnitDouble = kernelCountUnitDouble; + + auto ThreadFunction = [&](int tId, int eStartIndex, int eEndIndex, int estep, int ocIndex) { + auto ocDivThread = ocDiv4; + if (mSplitByOc) { // Thread split by OC + ocDivThread = ALIMIN(mDivides[tId + 1] - mDivides[tId], ocDiv4 - mDivides[tId]); + } + float* reluPtr = mResource->mReluThreshold.data(); + uint8_t* extraScale = nullptr; // input scale for batch dynamic quant. + QuanPostTreatParameters quanParam; + quanParam.blockNum = mBlockNum; + if (mUseBatchQuan) { + extraScale = inputScalePtr; + } +#ifdef MNN_USE_SSE + quanParam.extraBias = mResource->mWeightKernelSum->host() + ocIndex; +#endif + if (dstBytes != 1) { + quanParam.useInt8 = 0; + quanParam.fp32minmax = reluPtr; + } else { + quanParam.maxValue = mMutableResource.mClampMax; + if (mResourceInt8->mRelu) { + quanParam.minValue = mMutableResource.mOutputZeroPoint; + } else { + quanParam.minValue = mMutableResource.mClampMin; + } + } + auto outputTid = outputDataPtr + ocIndex * plane * dstBytes; + const auto biasFloatTid = reinterpret_cast(biasPtr + ocIndex * 4); + const auto scaleFloatTid = reinterpret_cast(scalePtr + ocIndex * 4); + const auto weightDequanBiasTid = reinterpret_cast(weightDequantBias + ocIndex * 4); + const auto weightPtrTid = weightDataPtr + static_cast(ocIndex * kernelCountUnitDouble * SRC_UNIT * weightBytes); + if (mBlockNum == 1) { + quanParam.biasFloat = biasFloatTid; + quanParam.scale = scaleFloatTid; + quanParam.weightQuanBias = weightDequanBiasTid; + } + auto colAddr = im2colPtr + tId * mTempIm2ColBuffer->stride(0); auto srcPtr = (int8_t const **)(mBlitInfo.ptr() + tId * mBlitInfoStride.first); auto el = (int32_t *)(srcPtr + mBlitInfoStride.second); + auto xKernelSumPtrTid = reinterpret_cast(srcKernelSumPtr + tId * mBlockNum * DST_XUNIT * mIm2ColCount * 4); - int32_t info[4]; + int32_t info[6]; info[1] = mIm2ColParamter.iw * mIm2ColParamter.ih * batch; - info[2] = col_buffer_unit_size; + info[2] = static_cast(col_buffer_unit_size); info[3] = mIm2ColParamter.strideX; - for (int tIndex = tId; tIndex < mTileCount; tIndex += mThreadNums) { + info[5] = kernelCountUnitDouble; + for (int tIndex = eStartIndex; tIndex < eEndIndex; tIndex += estep) { const int xIndexStart = tIndex * DST_XUNIT * mIm2ColCount; int realDstCount = ALIMIN(plane - xIndexStart, DST_XUNIT * mIm2ColCount); - + auto ptrExtraScale = extraScale != nullptr ? (extraScale + xIndexStart * 4) : nullptr; + auto ptrInputscale = mUseBatchQuan == true ? 
(inputScalePtr + xIndexStart * 4) : inputScalePtr; // im2col auto res = ConvolutionTiledExecutor::turnIm2ColToBlitInfo((const float**)srcPtr, el, xIndexStart, realDstCount, mIm2ColParamter, (const uint8_t*)inputDataPtr, 1); int number = res.first; bool needZero = res.second; if (needZero) { #ifdef MNN_USE_SSE - ::memset(colAddr, mMutableResource.mInputZeroPoint + 128, col_buffer_size); + ::memset(colAddr, inputZeroPoint + 128, col_buffer_size); #else - ::memset(colAddr, mMutableResource.mInputZeroPoint, col_buffer_size); + ::memset(colAddr, inputZeroPoint, col_buffer_size); #endif } info[0] = number; + info[4] = realDstCount; if (number > 0) { blitProc(colAddr, srcPtr, info, el); } - auto outputInTilePtr = outputDataPtr + xIndexStart * PackUnit * dstBytes; + if (mResourceInt8->mWeightAsymmetricQuant) { + mSumByAxisLFunc(xKernelSumPtrTid, colAddr, (float*)ptrInputscale, realDstCount, sumParams); + } + auto outputInTilePtr = outputTid + xIndexStart * PackUnit * dstBytes; auto colAddrTemp = colAddr; - do { - int step = ALIMIN(DST_XUNIT, realDstCount); - mGemmKernel(outputInTilePtr, colAddrTemp, weightDataPtr, kernelCountUnitDouble, dstZStep * dstBytes, ocDiv4, &quanParam, step); - realDstCount-=step; - outputInTilePtr += DST_XUNIT * PackUnit * dstBytes; - colAddrTemp += col_buffer_unit_size; - } while(realDstCount > 0); + auto ptrX = xKernelSumPtrTid; + if (mBlockNum == 1) { + do { + int step = ALIMIN(DST_XUNIT, realDstCount); + quanParam.srcKernelSum = ptrX; + quanParam.extraScale = extraScale != nullptr ? (float*)ptrExtraScale : nullptr; + mGemmKernel(outputInTilePtr, colAddrTemp, weightPtrTid, kernelCountUnitDouble, dstZStep * dstBytes, ocDivThread, &quanParam, step); + ptrX += step; + realDstCount-=step; + outputInTilePtr += DST_XUNIT * PackUnit * dstBytes; + colAddrTemp += col_buffer_unit_size; + ptrExtraScale = extraScale != nullptr ? (ptrExtraScale + step * 4) : nullptr; + } while(realDstCount > 0); + } else { // Now offline quant do not run into. + do { + int step = ALIMIN(DST_XUNIT, realDstCount); + quanParam.extraScale = extraScale != nullptr ? (float*)ptrExtraScale : nullptr; + for (int k = 0; k < mBlockNum; ++k) { + quanParam.biasFloat = nullptr; + quanParam.fp32minmax = nullptr; + if (k == 0) { + quanParam.biasFloat = (float*)biasFloatTid; + } + if (k == mBlockNum - 1) { + quanParam.fp32minmax = reluPtr; + } + quanParam.srcKernelSum = ptrX + k * step; + quanParam.weightQuanBias = weightDequanBiasTid + k * ocUp4; + quanParam.scale = (float*)(scaleFloatTid + k * ocUp4); + + mGemmKernel(outputInTilePtr, colAddrTemp + k * blockL * src_step_Y, weightPtrTid + k * blockL * weight_step_Y, blockL, dstZStep * dstBytes, ocDivThread, &quanParam, step); + } + ptrX += (step * mBlockNum); + realDstCount-=step; + outputInTilePtr += DST_XUNIT * PackUnit * dstBytes; + colAddrTemp += col_buffer_unit_size; + ptrExtraScale = extraScale != nullptr ? 
(ptrExtraScale + step * 4) : nullptr; + } while(realDstCount > 0); + } } }; - MNN_CONCURRENCY_BEGIN(tId, mThreadNums) { - threadFunction((int)tId); + + if (!mSplitByOc) { + MNN_CONCURRENCY_BEGIN(tId, mThreadNums) { + ThreadFunction((int)tId, mDivides[tId], mDivides[tId + 1], 1, 0); + } + MNN_CONCURRENCY_END(); + } else { + MNN_CONCURRENCY_BEGIN(tId, mThreadNums) { + int ocIndex = PackUnit * mDivides[tId]; + ThreadFunction((int)tId, 0, mTileCount,1, ocIndex); + } + MNN_CONCURRENCY_END(); } - MNN_CONCURRENCY_END(); - // MNN_PRINT("dense conv2d int8 execute: cost time: %llu us\n", kernelTimer.durationInUs()); + return NO_ERROR; } diff --git a/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp b/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp index 685e0088b..ec2d78393 100644 --- a/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp +++ b/source/backend/cpu/compute/ConvInt8TiledExecutor.hpp @@ -11,13 +11,14 @@ #include "backend/cpu/CPUConvolution.hpp" #include "Int8FunctionsOpt.h" +#include "CommonOptFunction.h" namespace MNN { class ConvInt8TiledExecutor : public CPUConvolution { public: // given weight+bias+scale, do post process - ConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* convOp, std::shared_ptr res); + ConvInt8TiledExecutor(Backend* backend, const Convolution2D* convOp, std::shared_ptr res); virtual ~ConvInt8TiledExecutor(); virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; @@ -29,7 +30,8 @@ class ConvInt8TiledExecutor : public CPUConvolution { int mTileCount; int mThreadNums; std::shared_ptr mTempIm2ColBuffer; - std::shared_ptr mResource; + std::shared_ptr mResourceInt8; + // std::shared_ptr mResource; CPUConvolution::MutableResourceInt8 mMutableResource; MemChunk mBlitInfo; std::pair mBlitInfoStride; @@ -48,16 +50,35 @@ class ConvInt8TiledExecutor : public CPUConvolution { class DenseConvInt8TiledExecutor : public ConvInt8TiledExecutor { public: // given weight+bias+scale, do post process - DenseConvInt8TiledExecutor(Backend* backend, const Convolution2D* convOp, std::shared_ptr res); + DenseConvInt8TiledExecutor(Backend* backend, const Convolution2D* convOp, std::shared_ptr res, bool dynamicQuantExe); virtual ~DenseConvInt8TiledExecutor(); virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; void getPackParameter(int* Unit, int* SrcUnit, int* DestUnit, const CoreInt8Functions* core) override; private: - DenseConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* common, const DenseConvInt8TiledExecutor& exe); + DenseConvInt8TiledExecutor(Backend* backend, const Convolution2D* common, bool dynamicQuantExe, const DenseConvInt8TiledExecutor& exe); decltype(CoreInt8Functions::Int8GemmKernel) mGemmKernel; + std::function mQuantFunc; + std::function mQuantAndReorderFunc = nullptr; + std::function mSumByAxisLFunc; + std::shared_ptr mQuantInput; + std::shared_ptr mDynamicBias; + std::shared_ptr mScaleFuse; + std::shared_ptr mBatchQuantInfo; + std::shared_ptr mInputDeqScales; + std::shared_ptr mTempMaxMinValueBuffer; + std::shared_ptr mResource; + std::vector mTempSrcSum; + std::vector mDivides; + + int mThreadNums; + int mBlockNum; + int mOcPerThread; + bool mDynamicQuantExe; + bool mSplitByOc; + bool mUseBatchQuan; }; } // 
namespace MNN diff --git a/source/backend/cpu/compute/ConvInt8Winograd.cpp b/source/backend/cpu/compute/ConvInt8Winograd.cpp index 180320e58..2d0a4b5f2 100644 --- a/source/backend/cpu/compute/ConvInt8Winograd.cpp +++ b/source/backend/cpu/compute/ConvInt8Winograd.cpp @@ -31,7 +31,7 @@ std::shared_ptr ConvInt8Winograd::makeWinoResour std::shared_ptr weight, offsets, scales, inputScales; weight.reset(Tensor::createDevice({alpha2, oc4, ic4, UNIT, SRC_UNIT})); - offsets.reset(Tensor::createDevice({alpha2, oc4, UNIT})); + offsets.reset(Tensor::createDevice({alpha2, oc4, UNIT})); scales.reset(Tensor::createDevice({alpha2, oc4 * UNIT})); inputScales.reset(Tensor::createDevice({alpha2, UNIT})); @@ -47,7 +47,7 @@ std::shared_ptr ConvInt8Winograd::makeWinoResour return nullptr; } ::memset(weight->host(), 0, weight->size()); - ::memset(offsets->host(), 0, offsets->size()); + ::memset(offsets->host(), 0, offsets->size()); ::memset(scales->host(), 0, scales->size()); auto inputScaleData = (const float*)attr; attr += alpha2; auto inputPointData = (const int32_t*)attr; attr += alpha2; @@ -80,7 +80,9 @@ std::shared_ptr ConvInt8Winograd::makeWinoResour for (int a = 0; a < alpha2; ++a) { for (int oz = 0; oz < oc; ++oz) { - int oz4 = oz / UNIT, ozRemain = oz % UNIT, offset = 0; + int oz4 = oz / UNIT, ozRemain = oz % UNIT; + int offset_int32 = 0; + float offset = 0.f; float scale = weightScaleData[a * oc + oz]; for (int sz = 0; sz < ic; ++sz) { int sz4 = sz / SRC_UNIT, szRemain = sz % SRC_UNIT; @@ -95,7 +97,7 @@ std::shared_ptr ConvInt8Winograd::makeWinoResour offset += quanData * (-128); #endif } - offsets->host()[a * oc4 * UNIT + oz] = offset; + offsets->host()[a * oc4 * UNIT + oz] = offset * scale * inputScaleData[a]; scales->host()[a * oc4 * UNIT + oz] = scale * inputScaleData[a]; } } @@ -178,8 +180,10 @@ ErrorCode ConvInt8Winograd::onResize(const std::vector &inputs, const } auto core = static_cast(backend())->int8Functions(); + auto gcore = static_cast(backend())->functions(); int UNIT, SRC_UNIT, DST_XUNIT; core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); + UNIT = gcore->pack; auto input = mInputFloat.get(), output = outputs[0]; int batch = input->batch(), ic = input->channel(), oc = output->channel(); @@ -219,6 +223,7 @@ static void mergeAddBiasScaleQuantize(const std::vector& inputs, Tensor auto coreInt8 = cpuBn->int8Functions(); int UNIT, SRC_UNIT, DST_XUNIT; coreInt8->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); + UNIT = core->pack; int countC4 = UP_DIV(output->channel(), UNIT), plane = output->height() * output->width() * output->batch(); auto mergeFloat = inputs[0]->host(); @@ -226,7 +231,7 @@ static void mergeAddBiasScaleQuantize(const std::vector& inputs, Tensor core->MNNMatrixAdd(mergeFloat, mergeFloat, inputs[i]->host(), plane * countC4, 0, 0, 0, 1); } std::vector fakeScale(countC4 * UNIT, 1); - core->MNNScaleAndAddBias(mergeFloat, mergeFloat, (const float*)quanParam->bias, fakeScale.data(), plane, countC4); + core->MNNScaleAndAddBias(mergeFloat, mergeFloat, quanParam->biasFloat, fakeScale.data(), plane, countC4); coreInt8->MNNFloat2Int8(mergeFloat, output->host(), plane * countC4, quanParam->scale, quanParam->minValue, quanParam->maxValue, zeroPoint); } @@ -274,8 +279,10 @@ static void _reorderCommon(float* dst, const float* src, size_t area, size_t dep ErrorCode ConvInt8Winograd::onExecute(const std::vector &inputs, const std::vector &outputs) { auto bn = static_cast(backend()); auto core = bn->int8Functions(); + auto gcore = bn->functions(); int UNIT, SRC_UNIT, DST_XUNIT; 
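The makeWinoResource change above stores the per-channel offset as a float that is already multiplied by weightScale * inputScale, so it can be consumed through quanParam.biasFloat rather than an int32 bias. A rough sketch of where that offset comes from, assuming the SSE case where the int8 input is shifted by +128; the function name and signature are illustrative.

#include <cstdint>

// The shifted input contributes sum_over_ic(weight) * shift to every int32
// accumulator of the GEMM; pre-multiplying by the weight scale and the tile's
// input scale turns that correction into a float bias added after dequantization.
static float foldWinogradOffset(const int8_t* quantWeightRow, int ic,
                                int inputShift,      // e.g. -128 with MNN_USE_SSE
                                float weightScale, float inputScale) {
    int acc = 0;
    for (int sz = 0; sz < ic; ++sz) {
        acc += static_cast<int>(quantWeightRow[sz]) * inputShift;
    }
    return static_cast<float>(acc) * weightScale * inputScale;
}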
core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); + UNIT = gcore->pack; // scale, zero, min, max auto inputQuant = TensorUtils::getQuantInfo(inputs[0]); auto outputQuant = TensorUtils::getQuantInfo(outputs[0]); @@ -308,7 +315,7 @@ ErrorCode ConvInt8Winograd::onExecute(const std::vector &inputs, const scale.assign(UNIT, 1.0 / outputQuant[0]); quanParam.scale = scale.data(); // For winograd Int8, will not treat origin bias to int32, use float directly - quanParam.bias = mResource->mOriginBias->host(); + quanParam.biasFloat = mResource->mOriginBias->host(); quanParam.maxValue = outputQuant[3]; if (mResource->mRelu) { quanParam.minValue = outputQuant[1]; @@ -322,9 +329,11 @@ ErrorCode ConvInt8Winograd::onExecute(const std::vector &inputs, const ConvInt8Winograd::WinoExecution::WinoExecution(std::shared_ptr res, int kernelY, int kernelX, int unitY, int unitX, int outputCount, int inputCount) : Execution(res->backend), mWinoResource(res), mUnitY(unitY), mUnitX(unitX), mKernelY(kernelY), mKernelX(kernelX) { auto core = static_cast(res->backend)->int8Functions(); + auto gcore = static_cast(res->backend)->functions(); int UNIT, SRC_UNIT, DST_XUNIT; core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); + UNIT = gcore->pack; int threadNumber = ((CPUBackend *)backend())->threadNumber(); int alphaY = mUnitY + mKernelY - 1, alphaX = mUnitX + mKernelX - 1, alpha2 = alphaY * alphaX; @@ -364,6 +373,7 @@ ErrorCode ConvInt8Winograd::WinoExecution::onExecute(const std::vector bool conv1d = (alphaY == 1 || alphaX == 1); int UNIT, SRC_UNIT, DST_XUNIT; coreInt8->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); + UNIT = core->pack; auto gemmFunc = coreInt8->Int8GemmKernel; CoreFunctions::WinoUnrollTransFunc srcTransXFunc = nullptr, srcTransYFunc = nullptr; @@ -477,6 +487,10 @@ ErrorCode ConvInt8Winograd::WinoExecution::onExecute(const std::vector auto dstOrigin = output->host(); auto weight = mWinoResource->weight->host(); + std::vector xkernelSum(DST_XUNIT, 0); + std::vector wKernelSum(dc_4 * UNIT, 0); + std::vector reluThred = {-std::numeric_limits().max(), std::numeric_limits().max()}; + auto tFunction = [&](int tId) { auto _srcOrigin = mTempInputBuffer->host() + tId * mTempInputBuffer->stride(0); auto _dstOrigin = mTempOutputBuffer->host() + tId * mTempOutputBuffer->stride(0); @@ -507,9 +521,13 @@ ErrorCode ConvInt8Winograd::WinoExecution::onExecute(const std::vector auto _dstFloatPtr = _dstOrigin + i * dc_4 * xC * UNIT; auto _weightInt8Ptr = weight + i * mWinoResource->weight->stride(0); QuanPostTreatParameters quanParam; - quanParam.bias = mWinoResource->offsets->host() + i * mWinoResource->offsets->stride(0); + quanParam.biasFloat = (mWinoResource->offsets->host() + i * mWinoResource->offsets->stride(0)); quanParam.useInt8 = 0; + quanParam.srcKernelSum = xkernelSum.data(); + quanParam.weightQuanBias = wKernelSum.data(); + quanParam.fp32minmax = reluThred.data(); quanParam.scale = mWinoResource->scales->host() + i * dc_4 * UNIT; + quanParam.extraScale = nullptr; gemmFunc((int8_t*)_dstFloatPtr, _srcInt8Ptr, _weightInt8Ptr, mTempInputBuffer->length(2), xC * UNIT * sizeof(float), dc_4, &quanParam, xC); } #ifndef MNN_WINO_TRANFORM_TEST_CLOSE diff --git a/source/backend/cpu/compute/Convolution1x1Strassen.cpp b/source/backend/cpu/compute/Convolution1x1Strassen.cpp index deeec58e4..3ed5c0c6e 100644 --- a/source/backend/cpu/compute/Convolution1x1Strassen.cpp +++ b/source/backend/cpu/compute/Convolution1x1Strassen.cpp @@ -101,7 +101,7 @@ ErrorCode Convolution1x1Strassen::onResize(const std::vector &inputs, auto 
CONVOLUTION_TILED_NUMBER = ePack; auto input = inputs[0]; auto output = outputs[0]; - int numberThread = ((CPUBackend *)backend())->threadNumber(); + const int numberThread = ((CPUBackend *)backend())->threadNumber(); auto ic = input->channel(); auto oc = output->channel(); auto icC4 = UP_DIV(ic, core->pack); @@ -133,13 +133,15 @@ ErrorCode Convolution1x1Strassen::onResize(const std::vector &inputs, } #endif mWeightBytes = static_cast(dequantBits) / 8.0f; + auto rt = static_cast(backend()->getRuntime()); if (matrixSizeE > CONVOLUTION_TILED_NUMBER * 8 * numberThread && matrixSizeE > ocC4) { - // Divide in plane, in this case the divide equal numberThread - int divideStep = UP_DIV(matrixSizeE, numberThread); + std::vector divides(numberThread+1); + divides[0] = 0; + rt->computeDivideSizes(matrixSizeE, divides.data()+1); mUnits.resize(numberThread); for (int i = 0; i < numberThread; ++i) { - int planeStart = i * divideStep; - int planeEnd = std::min(planeStart + divideStep, matrixSizeE); + int planeStart = divides[i]; + int planeEnd = divides[i+1]; int planeSize = planeEnd - planeStart; Unit &unit = mUnits[i]; if (planeSize <= 0) { @@ -173,15 +175,17 @@ ErrorCode Convolution1x1Strassen::onResize(const std::vector &inputs, hDiv = hPack / core->pack; } auto ocDiv = UP_DIV(ocC4, hDiv); - numberThread = std::min(numberThread, ocDiv); - int divideStep = (ocDiv / numberThread) * hDiv; + std::vector divides(numberThread+1); + divides[0] = 0; + rt->computeDivideSizes(ocDiv, divides.data()+1); mUnits.resize(numberThread); for (int i = 0; i < numberThread; ++i) { - int ocStart = i * divideStep; - int ocSize = divideStep; - if (i == numberThread - 1) { - ocSize = ocC4 - i * divideStep; + int ocStart = divides[i] * hDiv; + int ocEnd = divides[i+1] * hDiv; + if (ocEnd >= ocC4) { + ocEnd = ocC4; } + int ocSize = ocEnd - ocStart; Unit &unit = mUnits[i]; if (ocSize <= 0) { unit.mValid = false; diff --git a/source/backend/cpu/compute/ConvolutionDepthwise3x3.cpp b/source/backend/cpu/compute/ConvolutionDepthwise3x3.cpp index d70a8f1d1..46fc68048 100644 --- a/source/backend/cpu/compute/ConvolutionDepthwise3x3.cpp +++ b/source/backend/cpu/compute/ConvolutionDepthwise3x3.cpp @@ -96,7 +96,7 @@ bool ConvolutionDepthwise3x3::onClone(Backend* bn, const Op* op, Execution** dst ErrorCode ConvolutionDepthwise3x3::onResize(const std::vector &inputs, const std::vector &outputs) { CPUConvolution::onResize(inputs, outputs); - int numberThread = ((CPUBackend *)backend())->threadNumber(); + const int numberThread = ((CPUBackend *)backend())->threadNumber(); auto output = outputs[0]; auto owUnit = UP_DIV(output->width(), 2); auto core = static_cast(backend())->functions(); @@ -113,6 +113,15 @@ ErrorCode ConvolutionDepthwise3x3::onResize(const std::vector &inputs, mPostParameters = getPostParameters(); // auto rate = (float)(mSourceEndX-mSourceStartX) / (float)owUnit; // FUNC_PRINT_ALL(rate, f); + + int channelC4 = UP_DIV(inputs[0]->channel(), core->pack); + int batch = inputs[0]->batch(); + auto total = channelC4 * batch; + + mDivides.resize(numberThread+1); + mDivides[0] = 0; + static_cast(backend()->getRuntime())->computeDivideSizes(total, mDivides.data() + 1); + return NO_ERROR; } @@ -141,12 +150,11 @@ ErrorCode ConvolutionDepthwise3x3::onExecute(const std::vector &inputs int threadNumber = ((CPUBackend *)backend())->threadNumber(); auto maxKernelH = std::min(mPadY + ih, 3); - auto total = channelC4 * batch; auto inputOrigin = input->host(); auto outputOrigin = output->host(); MNN_CONCURRENCY_BEGIN(tId, threadNumber) { auto 
cacheLineStart = mCacheLine->host() + tId * mCacheLine->stride(0); - for (int index = (int)tId; index < total; index += threadNumber) { + for (int index = mDivides[tId]; index < mDivides[tId+1]; ++index) { int z = index / batch; auto biasPtr = (const float*)(mResource->mBias->host() + core->bytes * core->pack * z); auto inputZ = inputOrigin + core->pack * index * iw * ih * core->bytes; diff --git a/source/backend/cpu/compute/ConvolutionDepthwise3x3.hpp b/source/backend/cpu/compute/ConvolutionDepthwise3x3.hpp index 319021bb3..4ff4d4ef0 100644 --- a/source/backend/cpu/compute/ConvolutionDepthwise3x3.hpp +++ b/source/backend/cpu/compute/ConvolutionDepthwise3x3.hpp @@ -30,6 +30,7 @@ class ConvolutionDepthwise3x3 : public CPUConvolution { int mSourceStartX = 0; int mSourceEndX = 0; std::vector mPostParameters; + std::vector mDivides; }; } // namespace MNN diff --git a/source/backend/cpu/compute/ConvolutionFloatFactory.cpp b/source/backend/cpu/compute/ConvolutionFloatFactory.cpp index 5e85a184c..40a444696 100644 --- a/source/backend/cpu/compute/ConvolutionFloatFactory.cpp +++ b/source/backend/cpu/compute/ConvolutionFloatFactory.cpp @@ -15,13 +15,13 @@ #include "backend/cpu/compute/ConvolutionWinogradBridge.hpp" #include "backend/cpu/compute/DenseConvolutionTiledExecutor.hpp" -#include "backend/cpu/compute/ConvolutionHybrid.hpp" #ifdef MNN_USE_SPARSE_COMPUTE #include "backend/cpu/compute/SparseConvolutionTiledExecutor.hpp" #endif #include "core/Macro.h" #include "core/OpCommonUtils.hpp" #include "backend/cpu/OneDNNConvolution.hpp" +#include "backend/cpu/compute/ConvInt8TiledExecutor.hpp" namespace MNN { @@ -35,7 +35,7 @@ static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend #ifdef MNN_USE_SPARSE_COMPUTE if (conv2d->sparseParameter() && nullptr != weightQuantInfo.get()) { - if (supportSparse) { + if (supportSparse && weightQuantInfo->quan->index() != nullptr) { return new SparseConvolutionTiledExecutor(common, backend, weightQuantInfo->quan, conv2d->sparseParameter(), bias, biasSize); } @@ -46,13 +46,15 @@ static Execution* _createUnit(const Tensor* input, const Tensor* output, Backend && common->strideX() == 1 && common->strideY() == 1; if (lowMemory && nullptr != weightQuantInfo.get() && originWeightSize == 0) { - if (cpuBackend->memoryMode() == BackendConfig::Memory_Low && fastWay) { - return new ConvolutionHybrid(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo); + if (cpuBackend->memoryMode() == BackendConfig::Memory_Low) { + auto core = static_cast(backend)->functions(); + auto resourceInt8 = CPUConvolution::makeResourceInt8(backend, conv2d, core->pack); + return new DenseConvInt8TiledExecutor(backend, conv2d, resourceInt8, true); } else { return new DenseConvolutionTiledExecutor(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo); } } - if (fastWay) { + if (fastWay && cpuBackend->functions()->matmulBytes == 0) { return new Convolution1x1Strassen(common, backend, originWeight, originWeightSize, bias, biasSize, weightQuantInfo); } if (originWeightSize == 0) { @@ -78,7 +80,7 @@ Execution* ConvolutionFloatFactory::create(const std::vector& inputs, c return new ConvolutionTiledExecutorMultiInput(conv2d->common(), backend); } #ifdef MNN_LOW_MEMORY - bool lowMemory = static_cast(backend)->memoryMode() != BackendConfig::Memory_High; + bool lowMemory = static_cast(backend)->memoryMode() != BackendConfig::Memory_High && static_cast(backend)->functions()->MNNPackedMatMul_int8 != nullptr; #else bool lowMemory = 
false; #endif diff --git a/source/backend/cpu/compute/ConvolutionHybrid.cpp b/source/backend/cpu/compute/ConvolutionHybrid.cpp deleted file mode 100644 index bf4f24c31..000000000 --- a/source/backend/cpu/compute/ConvolutionHybrid.cpp +++ /dev/null @@ -1,401 +0,0 @@ -// -// ConvolutionHybrid.cpp -// MNN -// -// Created by MNN on 2023/10/26. -// Copyright © 2018, Alibaba Group Holding Limited -// - -#include "ConvolutionHybrid.hpp" -#include -#include "core/BufferAllocator.hpp" -#include "backend/cpu/CPUBackend.hpp" -#include "core/Concurrency.h" -#include "ConvOpt.h" -#include "core/Macro.h" -#include "CommonOptFunction.h" -#include "core/TensorUtils.hpp" -#include -#include "backend/cpu/compute/DenseConvolutionTiledExecutor.hpp" - -namespace MNN { - -bool ConvolutionHybrid::initQuantizeResource(std::shared_ptr int8Info, std::shared_ptr resource, int hU, int hP, int lU, int lP, int outputCount, int srcChannel, int kernelSize, int bytes) { - int weightLength = hU * lU * hP * lP; - resource->mWeight.reset(Tensor::createDevice( - {weightLength})); - auto res = resource->backend->onAcquireBuffer(resource->mWeight.get(), Backend::STATIC); - if (!res) { - return false; - } - resource->mDequantize.bits = 8; - resource->hU = hU; - resource->lU = lU; - resource->hP = hP; - resource->lP = lP; - - // Save scale bias - resource->mDequantize.mScaleBias.reset(MNN::Tensor::createDevice({hU * hP * 2})); - res = resource->backend->onAcquireBuffer(resource->mDequantize.mScaleBias.get(), Backend::STATIC); - if (!res) { - return false; - } - auto alphaPtr = resource->mDequantize.mScaleBias->host(); - auto biasPtr = reinterpret_cast(reinterpret_cast(alphaPtr) + hU * hP * bytes); - ::memset(alphaPtr, 0, 2 * hU * hP * bytes); - int h = int8Info->alpha.size(); - if (int8Info->canUseInt4 && int8Info->asymmetric) { - // int4 to uint4, -8 offset merge to bias - for (int i = 0; i < h/2; ++i) { - int8Info->alpha.get()[2 * i] -= 8 * int8Info->alpha.get()[2 * i + 1]; - } - } - if (bytes == 2) { - auto core = static_cast(resource->backend)->functions(); - if (int8Info->asymmetric) { - std::unique_ptr tmp(new int16_t[h]); - core->MNNFp32ToLowp(int8Info->alpha.get(), tmp.get(), h); - for (int i=0; i< h/2; ++i) { - reinterpret_cast(alphaPtr)[i] = tmp[2 * i + 1]; - reinterpret_cast(biasPtr)[i] = tmp[2 * i]; - } - } else { - core->MNNFp32ToLowp(int8Info->alpha.get(), reinterpret_cast(alphaPtr), h); - if (int8Info->canUseInt4) { - for (int i = 0; i < h; ++i) { - int8Info->alpha.get()[i] *= -8.0; - } - core->MNNFp32ToLowp(int8Info->alpha.get(), reinterpret_cast(biasPtr), h); - } - } - } else { - if (int8Info->asymmetric) { - h = h / 2; - for (int i=0; ialpha.get()[2 * i + 1]; - biasPtr[i] = int8Info->alpha.get()[2 * i]; - } - } else { - for (int i=0; ialpha.get()[i]; - if (int8Info->canUseInt4) { - biasPtr[i] = -8.0 * int8Info->alpha.get()[i]; - } else { - biasPtr[i] = 0.f; - } - } - } - } - std::vector data(weightLength, 0); - auto srcWInt8 = int8Info->weight.get(); - if (hP * hU != outputCount || lP * lU != srcChannel) { - int packedic = lU * lP; - for (int i = 0; i < outputCount; ++i) { - for (int j = 0; j < srcChannel; ++j) { - int destIdx = i * packedic + j; - int srcIdx = i * srcChannel + j; - data[destIdx] = srcWInt8[srcIdx]; - } - } - srcWInt8 = data.data(); - } - if (int8Info->canUseInt4) { - MNN_ASSERT(weightLength % 2 == 0); - weightLength = UP_DIV(weightLength, 2); - resource->mDequantize.bits = 4; - - auto srcPtr = int8Info->weight.get(); - auto dstPtr = resource->mWeight->host(); - // oc, ic -> oc/hP, ic/lP, hP, lP 
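The comment above names the blocked weight layout these kernels use: the plain [oc][ic] matrix is retiled into [oc/hP][ic/lP][hP][lP] so each GEMM micro-kernel reads one contiguous hP x lP panel. A standalone sketch of that index mapping, assuming oc and ic divide evenly (the code here zero-pads the ragged edges first); the function name is illustrative.

#include <cstddef>
#include <cstdint>
#include <vector>

// Retile a row-major [oc][ic] int8 weight matrix into [hU][lU][hP][lP]
// with hU = oc / hP and lU = ic / lP.
static std::vector<int8_t> retileWeight(const int8_t* src, int oc, int ic, int hP, int lP) {
    const int hU = oc / hP;
    const int lU = ic / lP;
    std::vector<int8_t> dst(static_cast<size_t>(oc) * ic);
    for (int i = 0; i < hU; ++i) {         // output-channel tile
        for (int j = 0; j < lU; ++j) {     // input-channel tile
            for (int k = 0; k < hP; ++k) { // row inside the panel
                for (int l = 0; l < lP; ++l) {
                    dst[((i * lU + j) * hP + k) * lP + l] = src[(i * hP + k) * ic + (j * lP + l)];
                }
            }
        }
    }
    return dst;
}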
- if (hP == 8 && lP == 8) { - for (int i = 0; i < hU; i++) { - for (int j = 0; j < lU; j++) { - for (int k = 0; k < 2; k++) { - for (int n = 0; n < 16; n++) { - int hp_idx = n / 8; - int lp_idx = n % 8; - int s0 = srcWInt8[(i * hP + k * 4 + hp_idx) * lP *lU + (j * lP + lp_idx)]; - int s1 = srcWInt8[(i * hP + k * 4 + hp_idx + 2) * lP * lU + (j * lP + lp_idx)]; - int d = (s0 + 8) * 16 + (s1 + 8); - dstPtr[(i * lU * lP * hP + j * hP * lP + k * 32) / 2 + n] = (uint8_t)d; - } - } - } - } - } else { - for (int i = 0; i < hU; i++) { - for (int j = 0; j < lU; j++) { - for (int k = 0; k < hP; k++) { - for (int l = 0; l < lP; l+=2) { - int s0 = srcWInt8[(i * hP + k) * lP * lU + (j * lP + l)]; - int s1 = srcWInt8[(i * hP + k) * lP * lU + (j * lP + l + 1)]; - int d = (s0 + 8) * 16 + (s1 + 8); - dstPtr[(i * lU * lP * hP + j * hP * lP + k * lP + l) / 2] = d; - } - } - } - } - } - } else { - // Reorder weight for int8 - auto dstWInt8 = resource->mWeight->host(); - // oc, ic -> oc/hP, ic/lP, hP, lP - for (int i = 0; i < hU; i++) { - for (int j = 0; j < lU; j++) { - for (int k = 0; k < hP; k++) { - for (int l = 0; l < lP; l++) { - dstWInt8[i * lU * lP * hP + j * hP * lP + k * lP + l] = srcWInt8[(i * hP + k) * lP * lU + (j * lP + l)]; - } - } - } - } - } - return true; -} - -ConvolutionHybrid::ConvolutionHybrid(const Convolution2DCommon *common, Backend *b, const float *originWeight, - size_t originWeightSize, const float *bias, size_t biasSize, std::shared_ptr quantInfo) - : CPUConvolution(common, b) { - mResource.reset(new CPUConvolution::Resource); - mResource->backend = b; - if (!mResource->copyBiasAlign(bias, (int)biasSize)) { - MNN_ERROR("Not Enough Memory\n"); - mValid = false; - return; - } - MNN_ASSERT(nullptr != quantInfo.get()); - originWeightSize = quantInfo->weight.size(); - auto outputCount = (int)biasSize; - int inputCount = (int)originWeightSize / (int)biasSize * common->kernelX() * common->kernelY(); - auto core = static_cast(b)->functions(); - auto int8_core = static_cast(backend())->int8Functions(); - int unit = core->pack; - int ePack, lPack, hPack; - core->MNNGetMatMulPackMode(&ePack, &lPack, &hPack); - // printf("ePack, lPack, hPack = %d, %d, %d\n", ePack, lPack, hPack); - // printf("UNIT, SRC_UNIT, DST_XUNIT = %d, %d, %d\n", UNIT, SRC_UNIT, DST_XUNIT); - hPack = unit; - lPack = unit; - // [oc, ic] => [oc/unit, ic/src_unit, unit, src_unit] - if (unit == 4 && core->supportI8mm) { // Low Memory: use fp32 and smmla. 
- hPack = 8; - lPack = 8; - } - auto hU = UP_DIV(outputCount, hPack); - auto lU = UP_DIV(inputCount, lPack); - ConvolutionHybrid::initQuantizeResource(quantInfo, mResource, hU, hPack, lU, lPack, outputCount, (int)originWeightSize / (int)biasSize, common->kernelX() * common->kernelY(), core->bytes); -} - -ConvolutionHybrid::ConvolutionHybrid(std::shared_ptr resource, const Convolution2DCommon *common, Backend* b) : CPUConvolution(common, b) { - mResource = resource; -} - -ConvolutionHybrid::~ConvolutionHybrid() { - // Do nothing -} - -bool ConvolutionHybrid::onClone(Backend* bn, const Op* op, Execution** dst) { - if (!mValid) { - return false; - } - if (nullptr == dst) { - return true; - } - *dst = new ConvolutionHybrid(mResource, op->main_as_Convolution2D()->common(), bn); - return true; -} - -ErrorCode ConvolutionHybrid::allocTensor(Tensor* tensor, size_t size) { - tensor->buffer().type = halide_type_of(); - tensor->buffer().dimensions = 1; - tensor->buffer().dim[0].extent = size; - bool success = backend()->onAcquireBuffer(tensor, Backend::DYNAMIC); - if (!success) { - return OUT_OF_MEMORY; - } - return NO_ERROR; -} - -ErrorCode ConvolutionHybrid::allocDynamicQuantInfo(int thread, int batch, int ic, int oc, int bytes) { - // absmax: thread * batch * bytes - // sum: thread * batch * sizeof(int) - // dequant_scale: batch * bytes - // quant_scale: batch * bytes - allocTensor(&mQuantInfo.quant_info, (thread + 2) * batch * bytes + thread * batch * sizeof(int)); - if (ANeedToPack8) { - int ic8 = UP_DIV(ic, 8) * 8; - int oc8 = UP_DIV(oc, 8) * 8; - mInputTemp.reset(Tensor::createDevice({batch, 1, 1, ic8})); - mOutputTemp.reset(Tensor::createDevice({batch, 1, 1, oc8})); - bool allocSucc = backend()->onAcquireBuffer(mInputTemp.get(), Backend::DYNAMIC); - allocSucc = allocSucc && backend()->onAcquireBuffer(mOutputTemp.get(), Backend::DYNAMIC); - if (!allocSucc) { - return OUT_OF_MEMORY; - } - allocTensor(&mQuantInfo.quant_buffer, batch * ic8); - backend()->onReleaseBuffer(mInputTemp.get(), Backend::DYNAMIC); - backend()->onReleaseBuffer(mOutputTemp.get(), Backend::DYNAMIC); - } else { - allocTensor(&mQuantInfo.quant_buffer, batch * ic); - } - backend()->onReleaseBuffer(&mQuantInfo.quant_info, Backend::DYNAMIC); - backend()->onReleaseBuffer(&mQuantInfo.quant_buffer, Backend::DYNAMIC); - return NO_ERROR; -} - -ErrorCode ConvolutionHybrid::onResize(const std::vector &inputs, const std::vector &outputs) { - CPUConvolution::onResize(inputs, outputs); - auto input = inputs[0]; - auto output = outputs[0]; - auto core = static_cast(backend())->functions(); - auto int8_core = static_cast(backend())->int8Functions(); - auto inputPtr = input->host(); - auto outputPtr = output->host(); - auto weightPtr = mResource->mWeight->host(); - auto biasPtr = mResource->mBias->host(); - auto batch = output->batch() * output->height() * output->width(); - int ic = input->channel(); - int oc = output->channel(); - int bytes = core->bytes; - int unit = core->pack; - int eP, lP, hP; - core->MNNGetMatMulPackMode(&eP, &lP, &hP); - int UNIT, SRC_UNIT, DST_XUNIT; - int8_core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DST_XUNIT); - hP = unit; - lP = unit; - int tileC = std::max(unit, hP); - LowMemoryGemmFuncWithInt8Weight gemmKernel; - gemmKernel = core->MNNGemmHybridInt8; - float weightBytes = 1; - if (mResource->mDequantize.bits == 4) { - weightBytes = 0.5; - gemmKernel = core->MNNGemmHybridInt4; - } - - const uint8_t* dequantAlpha = mResource->mDequantize.mScaleBias->host();; - const uint8_t* dequantBias = dequantAlpha + mResource->hU 
* mResource->hP * bytes;; - int threadNumber = ((CPUBackend *)backend())->threadNumber(); - auto oC4 = UP_DIV(oc, tileC); - int iC4 = UP_DIV(ic, unit); - if (iC4 < threadNumber || oC4 < threadNumber) { - threadNumber = std::min(oC4, iC4); - } - int tileCount = UP_DIV(oC4, threadNumber); - int iTileCount = UP_DIV(iC4, threadNumber); - if (unit == 4 && core->supportI8mm) { // Low Memory: use fp32 and smmla. - ANeedToPack8 = true; - } - int8_t order[32] = {0, 1, 2, 3, 12, 13, 14, 15, 16, 17, 18, 19, 28, 29, 30, 31, 8, 9, 10, 11, 4, 5, 6, 7, 24, 25, 26, 27, 20, 21, 22, 23}; - allocDynamicQuantInfo(threadNumber, batch, ic, oc, bytes); - mDynamicQuant = [=]() { - auto maxPtr = mQuantInfo.quant_info.host(); - auto sumPtr = maxPtr + threadNumber * batch * bytes; - auto dequantPtr = sumPtr + threadNumber * batch * sizeof(int); - auto quantPtr = dequantPtr + batch * bytes; - // compute sum and absmax - MNN_CONCURRENCY_BEGIN(tId, threadNumber) { - int workCount = iTileCount; - if (tId == threadNumber - 1) { - workCount = iC4 - tId * iTileCount; - } - int icIndex = tId * iTileCount; - auto input_ptr = reinterpret_cast(input->host() + icIndex * batch * unit * bytes); - auto max_ptr = reinterpret_cast(maxPtr + tId * batch * bytes); - core->MNNAbsMax(input_ptr, max_ptr, workCount, batch, unit); - } - MNN_CONCURRENCY_END(); - // compute scale - core->MNNQuantScale((float*)maxPtr, (float*)quantPtr, (float*)dequantPtr, threadNumber, batch); - // quant - MNN_CONCURRENCY_BEGIN(tId, threadNumber) { - int workCount = iTileCount; - if (tId == threadNumber - 1) { - workCount = iC4 - tId * iTileCount; - } - int icIndex = tId * iTileCount; - auto input_ptr = reinterpret_cast(input->host() + icIndex * batch * unit * bytes); - auto quant_ptr = mQuantInfo.quant_buffer.host() + icIndex * batch * unit; - auto scale_ptr = reinterpret_cast(quantPtr); - auto sum_ptr = reinterpret_cast(sumPtr + tId * batch * sizeof(int)); - core->MNNDynamicQuant(input_ptr, quant_ptr, scale_ptr, sum_ptr, workCount, batch, unit); - } - MNN_CONCURRENCY_END(); - // compute quant sum - core->MNNQuantSum((float*)sumPtr, (float*)dequantPtr, threadNumber, batch); - }; - mFunction.first = threadNumber; - mFunction.second = [=](int tId){ - int workCount = tileCount; - if (tId == threadNumber - 1) { - workCount = oC4 - tId * tileCount; - } - int unit_ = unit; - int tileCount_ = tileCount; - if (ANeedToPack8) { - int oC8 = UP_DIV(oc, 8); - tileCount_ = UP_DIV(oC8, threadNumber); - workCount = tileCount_; - if (tId == threadNumber - 1) { - workCount = oC8 - tId * tileCount_; - } - unit_ = 8; - } - - int ocIndex = tId * tileCount_ * unit_; - const float* finput_ptr = input->host(); - const int8_t* input_ptr = mQuantInfo.quant_buffer.host(); - const int8_t* input_ptr_tmp = mQuantInfo.quant_buffer.host(); - auto weight_ptr = mResource->mWeight->host() + static_cast(ocIndex * ic * weightBytes); - auto output_ptr = reinterpret_cast(outputs[0]->host() + ocIndex * batch * bytes); - if (ANeedToPack8 && batch > 1) { - input_ptr = mInputTemp->host(); - output_ptr = reinterpret_cast(mOutputTemp->host() + ocIndex * batch * bytes); - } - auto bias_ptr = reinterpret_cast(mResource->mBias->host() + ocIndex * bytes); - auto alpha_ptr = reinterpret_cast(dequantAlpha + ocIndex * bytes); - auto zero_ptr = reinterpret_cast(dequantBias + ocIndex * bytes); - const uint8_t* max_ptr = mQuantInfo.quant_info.host(); - const float* sums_ptr = reinterpret_cast(max_ptr + threadNumber * batch * bytes); - const float* scale_ptr = reinterpret_cast(max_ptr + threadNumber * batch * 
(bytes + sizeof(int))); - size_t dst_depth_quad = workCount; - size_t src_depth_quad = UP_DIV(ic, unit_); - size_t dst_step = batch * unit_ * bytes; - size_t realSize = batch; - const float* param[6]; - param[0] = alpha_ptr; - param[1] = zero_ptr; - param[2] = bias_ptr; - param[3] = sums_ptr; - param[4] = scale_ptr; - param[5] = (float*)order; - gemmKernel(output_ptr, input_ptr, weight_ptr, src_depth_quad, dst_step, dst_depth_quad, realSize, param); - }; - return NO_ERROR; -} - -ErrorCode ConvolutionHybrid::onExecute(const std::vector &inputs, const std::vector &outputs) { - mDynamicQuant(); - if (ANeedToPack8 && inputs[0]->batch() > 1) { - auto core = static_cast(backend())->functions(); - auto plane_in = inputs[0]->width() * inputs[0]->height() * inputs[0]->batch(); - auto plane_out = outputs[0]->width() * outputs[0]->height() * outputs[0]->batch(); - auto depth = UP_DIV(inputs[0]->channel(), core->pack); - auto output_depth = UP_DIV(outputs[0]->channel(), core->pack); - int areaOffset[2] = {plane_out, plane_out}; - MNNPackInt8C2Origin(mInputTemp.get()->host(), mQuantInfo.quant_buffer.host(), plane_in, depth, plane_in); - MNN_CONCURRENCY_BEGIN(tId, mFunction.first) { - mFunction.second((int)tId); - } - MNN_CONCURRENCY_END(); - MNNUnpackC2Float(outputs[0]->host(), mOutputTemp.get()->host(), plane_out, output_depth, areaOffset, core->pack); - return NO_ERROR; - } - - MNN_CONCURRENCY_BEGIN(tId, mFunction.first) { - mFunction.second((int)tId); - } - MNN_CONCURRENCY_END(); - return NO_ERROR; -} -} // namespace MNN diff --git a/source/backend/cpu/compute/ConvolutionHybrid.hpp b/source/backend/cpu/compute/ConvolutionHybrid.hpp deleted file mode 100644 index df260b21a..000000000 --- a/source/backend/cpu/compute/ConvolutionHybrid.hpp +++ /dev/null @@ -1,48 +0,0 @@ -// -// ConvolutionHybrid.hpp -// MNN -// -// Created by MNN on 2023/10/26. 
-// Copyright © 2018, Alibaba Group Holding Limited -// - -#ifndef ConvolutionHybrid_hpp -#define ConvolutionHybrid_hpp - -#include -#include "backend/cpu/CPUConvolution.hpp" - -typedef void(*LowMemoryGemmFuncWithInt8Weight)(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param); -namespace MNN { -class ConvolutionHybrid : public CPUConvolution { -public: - ConvolutionHybrid(const Convolution2DCommon *common, Backend *b, const float *originWeight, - size_t originWeightSize, const float *bias, size_t biasSize, std::shared_ptr); - ConvolutionHybrid(std::shared_ptr resource, const Convolution2DCommon *common, Backend* b); - static bool initQuantizeResource(std::shared_ptr int8Info, std::shared_ptr resource, int hU, int hP, int lU, int lP, int outputCount, int srcChannel, int kernelSize, int bytes); - - virtual ~ConvolutionHybrid(); - - virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; - - virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; - virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; -private: - ErrorCode allocTensor(Tensor* tensor, size_t size); - ErrorCode allocDynamicQuantInfo(int thread, int batch, int ic, int oc, int bytes); -private: - struct DynamicQuantInfo { - Tensor quant_info; - Tensor quant_buffer; - }; - std::shared_ptr mResource; - std::function mDynamicQuant; - std::pair> mFunction; - DynamicQuantInfo mQuantInfo; - bool ANeedToPack8 = false; - std::shared_ptr mInputTemp; - std::shared_ptr mOutputTemp; -}; -} // namespace MNN - -#endif /* ConvolutionHybrid_hpp */ diff --git a/source/backend/cpu/compute/ConvolutionPackWinograd.cpp b/source/backend/cpu/compute/ConvolutionPackWinograd.cpp index 83b17c77d..7d3d3a553 100644 --- a/source/backend/cpu/compute/ConvolutionPackWinograd.cpp +++ b/source/backend/cpu/compute/ConvolutionPackWinograd.cpp @@ -32,6 +32,10 @@ ConvolutionPackWinograd::ConvolutionPackWinograd(const Convolution2DCommon *conv int unit = config.unit; auto core = static_cast(backend())->functions(); int pack = core->pack, bytes = core->bytes; + int weightBytes = bytes; + if (0!=core->matmulBytes) { + weightBytes = core->matmulBytes; + } mResource.reset(new Resource); mResource->backend = b; @@ -83,14 +87,14 @@ ConvolutionPackWinograd::ConvolutionPackWinograd(const Convolution2DCommon *conv auto tempWeight = generator.allocTransformWeight(sourceWeight.get(), lPack, hPack, true); auto shape = tempWeight->shape(); - shape.push_back(bytes); + shape.push_back(weightBytes); mResource->mWeight.reset(Tensor::createDevice(shape)); mValid = backend()->onAcquireBuffer(mResource->mWeight.get(), Backend::STATIC); if (!mValid) { return; } generator.transformWeight(tempWeight.get(), sourceWeight.get(), true); - if (bytes != 4) { + if (weightBytes != 4) { core->MNNFp32ToLowp(tempWeight->host(), mResource->mWeight->host(), tempWeight->elementSize()); } else { ::memcpy(mResource->mWeight->host(), tempWeight->host(), tempWeight->size()); @@ -143,7 +147,11 @@ WinogradConfig ConvolutionPackWinograd::bestWinogradUnit(const Convolution2DComm auto core = static_cast(b)->functions(); - auto winogradMemoryLevel = static_cast(b)->getRuntime()->getWinogradMemoryLevel(); + auto winogradMemoryLevel = static_cast(b)->getRuntime()->hint().winogradMemoryUsed; + int multiBytes = static_cast(b)->functions()->bytes; + if (static_cast(b)->functions()->matmulBytes != 0) { + multiBytes = 
static_cast(b)->functions()->matmulBytes; + } int ow = outputTensor->width(); int oh = outputTensor->height(); int oc = outputTensor->channel(); @@ -164,6 +172,9 @@ WinogradConfig ConvolutionPackWinograd::bestWinogradUnit(const Convolution2DComm float maxRate = 0.0f; float originCost = (float)ow * oh * (2.0 * ic) * oc * kernelSize * kernelSize; // macs, with bias std::set supportSu{4, 6, 8}; + if (multiBytes < 4) { + supportSu = {4, 6}; + } CoreFunctions::WinoUnrollDestTransFunc destTransform[CONVOLUTION_WINOGRAD_MAX_UNIT + 1]; for (int u = CONVOLUTION_WINOGRAD_MIN_UNIT; u <= maxUnit; ++u) { auto sui = u + kernelSize - 1; @@ -245,20 +256,11 @@ ErrorCode ConvolutionPackWinograd::onResize(const std::vector &inputs, auto totalCount = wUnit * hUnit * batch; // MNN_PRINT("ow=%d, oh=%d\n", ow, oh); - int threadNumber = std::max(((CPUBackend *)backend())->threadNumber(), 1); - int tileCount = UP_DIV(totalCount, ePack); - int eRemain = totalCount % ePack; - threadNumber = std::min(threadNumber, tileCount); - std::vector parameters(6); - parameters[0] = eRemain * bytes; - parameters[1] = input->channel(); - parameters[2] = output->channel(); - parameters[3] = ePack * pack * bytes; - parameters[4] = 0; - parameters[5] = 0; - - std::vector parametersRemain = parameters; - parametersRemain[3] = eRemain * pack * bytes; + int threadNumber = ((CPUBackend*)(backend()))->threadNumber(); + + std::vector divides(threadNumber+1); + static_cast( static_cast(backend())->getRuntime())->computeDivideSizes(totalCount, divides.data()+1); + divides[0] = 0; auto midBuffer0Bytes = srcUnit2 * pack * bytes; bool allow_x86_bf16_winograd = true; #ifdef MNN_USE_SSE @@ -269,6 +271,24 @@ ErrorCode ConvolutionPackWinograd::onResize(const std::vector &inputs, auto bias = mResource->mBias->host(); mMainFunction.first = threadNumber; mMainFunction.second = [=](int tId, const uint8_t* inputOrigin, uint8_t* dstOrigin) { + int tSta = divides[tId]; + int tFin = divides[tId+1]; + if (tSta >= tFin) { + return; + } + int eRemain = (tFin-tSta) % ePack; + std::vector parameters(6); + parameters[1] = input->channel(); + parameters[2] = output->channel(); + parameters[4] = 0; + parameters[5] = 0; + parameters[0] = eRemain * bytes; + parameters[3] = ePack * pack * bytes; + + std::vector parametersRemain = parameters; + parametersRemain[0] = eRemain * bytes; + parametersRemain[3] = eRemain * pack * bytes; + auto srcOrigin = inputOrigin; auto _srcOrigin = mTempBuffer->host() + tId * mTempBuffer->stride(0); auto gemmBuffer = (mGemmMidBuffer->host() + tId * mGemmMidBuffer->stride(0)); @@ -276,12 +296,11 @@ ErrorCode ConvolutionPackWinograd::onResize(const std::vector &inputs, auto midBufferStride1 = mTransformMidBuffer->stride(1); auto weightStride = mResource->mWeight->stride(0); auto midBuffer1 = midBuffer0 + midBuffer0Bytes; - for (int tIndex = (int)tId; tIndex < tileCount; tIndex += threadNumber) { - int xIndex = (int)tIndex * ePack; - int xReamin = totalCount - xIndex; + for (int xIndex = tSta; xIndex < tFin; xIndex+=ePack) { + int xReamin = tFin - xIndex; int xC = xReamin > ePack ? 
ePack : xReamin; - const bool fuseTransformPack = (xC * FULSE_THRESHHOLD_DENOMINATOR >= FULSE_THRESHHOLD_NUMERATOR * ePack) && allow_x86_bf16_winograd && nullptr != mSourceTransformPack; + const bool fuseTransformPack = (xC * FULSE_THRESHHOLD_DENOMINATOR >= FULSE_THRESHHOLD_NUMERATOR * ePack) && allow_x86_bf16_winograd && nullptr != mSourceTransformPack && core->matmulBytes == 0; /*Source Transform Begin*/ #ifndef MNN_WINO_TRANFORM_TEST_CLOSE { @@ -519,11 +538,16 @@ ErrorCode ConvolutionPackWinograd::onResize(const std::vector &inputs, /*Dest Transform And Post Treat End*/ } }; + std::vector postDivides(threadNumber+1); + static_cast( static_cast(backend())->getRuntime())->computeDivideSizes(dc_4, postDivides.data()+1); + postDivides[0] = 0; mPostFunction.first = threadNumber; mPostFunction.second = [=](int tId, uint8_t* outputOrigin) { auto dstOrigin = outputOrigin; - for (int dy=(int)tId; dy < dc_4; dy += threadNumber) { + int tSta = postDivides[tId]; + int tFin = postDivides[tId+1]; + for (int dy=tSta; dy < tFin; ++dy) { auto dataFloatPtr = (float*)(dstOrigin + ow * oh * batch * dy * pack * bytes); auto biasFloatPtr = (const float*)(bias + pack * dy * bytes); core->MNNAxByClampBroadcastUnit(dataFloatPtr, dataFloatPtr, biasFloatPtr, ow * oh * batch, 0, 0, 1, mPostParameters.data()); diff --git a/source/backend/cpu/compute/ConvolutionTiledExecutor.cpp b/source/backend/cpu/compute/ConvolutionTiledExecutor.cpp index ff3a0afa1..e2e0f16bc 100644 --- a/source/backend/cpu/compute/ConvolutionTiledExecutor.cpp +++ b/source/backend/cpu/compute/ConvolutionTiledExecutor.cpp @@ -91,9 +91,12 @@ std::pair> ConvolutionTiledExecutor::computeBl return std::make_pair(total, std::make_pair(stride, kernelSize * maxLine)); } -void ConvolutionTiledExecutor:: setIm2ColParameter(ConvolutionCommon::Im2ColParameter& dstIm2ColParamter, const Convolution2DCommon* convCommon, Tensor* input, Tensor* output, int padX, int padY, const CoreFunctions* floatCore, const CoreInt8Functions* int8Core) { +void ConvolutionTiledExecutor:: setIm2ColParameter(ConvolutionCommon::Im2ColParameter& dstIm2ColParamter, const Convolution2DCommon* convCommon, Tensor* input, Tensor* output, int padX, int padY, const CoreFunctions* floatCore, const CoreInt8Functions* int8Core, int pack) { // FIXME: Set int8 and float's pack as diff - int pack = floatCore->pack; + if (pack == 0) { + pack = floatCore->pack; + } + const auto kernelCount = convCommon->kernelX() * convCommon->kernelY(); dstIm2ColParamter.dilateX = convCommon->dilateX(); @@ -119,7 +122,12 @@ void ConvolutionTiledExecutor:: setIm2ColParameter(ConvolutionCommon::Im2ColPara int UNIT, SRC_UNIT, DynamicDestUnit; auto core = int8Core; core->MNNGetGemmUnit(&UNIT, &SRC_UNIT, &DynamicDestUnit); - if (SRC_UNIT > pack) { + if (floatCore->bytes == 2 && DynamicDestUnit == 20) { + UNIT = 8; + SRC_UNIT= 8; + DynamicDestUnit = 10; + } + if (SRC_UNIT > UNIT) { const auto srcCountUnit = UP_DIV(input->channel(), pack); dstIm2ColParamter.kernelCountUnit = UP_DIV(srcCountUnit * kernelCount, SRC_UNIT / pack); dstIm2ColParamter.ic = dstIm2ColParamter.icDiv4 * pack; diff --git a/source/backend/cpu/compute/ConvolutionTiledExecutor.hpp b/source/backend/cpu/compute/ConvolutionTiledExecutor.hpp index 1d83fa3b3..3fc0076bd 100644 --- a/source/backend/cpu/compute/ConvolutionTiledExecutor.hpp +++ b/source/backend/cpu/compute/ConvolutionTiledExecutor.hpp @@ -46,7 +46,7 @@ class ConvolutionTiledExecutor : public Execution { virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; void 
initWeight(const float *source, float* cache, int depth, int outputCount, int kernelSize, const CoreFunctions* function); static std::pair turnIm2ColToBlitInfo(float const ** srcPtr, int32_t* el, int start, int xC, const ConvolutionCommon::Im2ColParameter& im2Col, const uint8_t* srcOrigin, int bytes); - static void setIm2ColParameter(ConvolutionCommon::Im2ColParameter& dstIm2ColParamter, const Convolution2DCommon* convCommon, Tensor* input, Tensor* output, int padX, int padY, const CoreFunctions* floatCore, const CoreInt8Functions* int8Core); + static void setIm2ColParameter(ConvolutionCommon::Im2ColParameter& dstIm2ColParamter, const Convolution2DCommon* convCommon, Tensor* input, Tensor* output, int padX, int padY, const CoreFunctions* floatCore, const CoreInt8Functions* int8Core, int pack = 0); // Total / Stride static std::pair> computeBlitInfoSize(int eP, int ow, int kernelSize, int threadNumber); diff --git a/source/backend/cpu/compute/ConvolutionWinogradImpl.cpp b/source/backend/cpu/compute/ConvolutionWinogradImpl.cpp index 0d597b07b..5ecb180f4 100644 --- a/source/backend/cpu/compute/ConvolutionWinogradImpl.cpp +++ b/source/backend/cpu/compute/ConvolutionWinogradImpl.cpp @@ -49,17 +49,4 @@ bool ConvolutionWinogradImpl::canUseWinograd(const Convolution2DCommon *common) return true; } -ErrorCode ConvolutionWinogradImpl::onExecute(const std::vector &inputs, const std::vector &outputs) { - return NO_ERROR; -} - -ErrorCode ConvolutionWinogradImpl::onResize(const std::vector &inputs, const std::vector &outputs) { - return NO_ERROR; -} - -bool ConvolutionWinogradImpl::onClone(Backend* bn, const Op* op, Execution** dst) { - return false; -} - - } // namespace MNN diff --git a/source/backend/cpu/compute/ConvolutionWinogradImpl.hpp b/source/backend/cpu/compute/ConvolutionWinogradImpl.hpp index 377cea037..c7ba42f1a 100644 --- a/source/backend/cpu/compute/ConvolutionWinogradImpl.hpp +++ b/source/backend/cpu/compute/ConvolutionWinogradImpl.hpp @@ -44,12 +44,9 @@ class ConvolutionWinogradImpl : public CPUConvolution { public: ConvolutionWinogradImpl(const Convolution2DCommon *convOp, Backend *b); virtual ~ConvolutionWinogradImpl(); - virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; - virtual ErrorCode onResize(const std::vector &inputs, const std::vector &outputs) override; static bool canUseWinograd(const Convolution2DCommon *convOp); static WinogradConfig bestWinogradUnit(const Convolution2DCommon *convOp, const Tensor *input, const Tensor *output, int threadnumber, Backend* b, const PerfConfig& denseConfig); - virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; protected: ConvolutionWinogradImpl(std::shared_ptr resource, const Convolution2DCommon *convOp, Backend* b) : CPUConvolution(convOp, b) { mResource = resource; diff --git a/source/backend/cpu/compute/DenseConvolutionTiledExecutor.cpp b/source/backend/cpu/compute/DenseConvolutionTiledExecutor.cpp index f844d7d5b..61dfb445a 100644 --- a/source/backend/cpu/compute/DenseConvolutionTiledExecutor.cpp +++ b/source/backend/cpu/compute/DenseConvolutionTiledExecutor.cpp @@ -193,6 +193,9 @@ DenseConvolutionTiledExecutor::DenseConvolutionTiledExecutor(const Convolution2D return; } } else { + if (core->matmulBytes != 0) { + bytes = core->matmulBytes; + } mResource->mWeight.reset(Tensor::createDevice( {hU * lU * hP * lP * bytes})); mValid = mValid && backend()->onAcquireBuffer(mResource->mWeight.get(), Backend::STATIC); @@ -330,7 +333,6 @@ void 
DenseConvolutionTiledImpl::getPackParameter(int* eP, int* lP, int* hP, cons return; } -// #define PROFILE_DETAIL PerfConfig DenseConvolutionTiledImpl::bestTileConvolutionConfig(const Convolution2DCommon *common, const Tensor *inputTensor, const Tensor *outputTensor, int threadNumber, Backend* b) { @@ -413,29 +415,11 @@ PerfConfig DenseConvolutionTiledImpl::bestTileConvolutionConfig(const Convolutio innerAcc += inner[i]; } PerfConfig thisConfig(false, eP, eP, 0, -1); - thisConfig.isParallelInner = outerAcc > innerAcc; + thisConfig.isParallelInner = outerAcc > innerAcc && 0 == core->matmulBytes; thisConfig.instructionCosts = outerAcc > innerAcc ? innerAcc : outerAcc; if (thisConfig.instructionCosts < denseConfig.instructionCosts) { denseConfig = thisConfig; -#ifdef PROFILE_DETAIL - MNN_PRINT("\nouterFlops:"); - formatMatrix(outerFlops, {sizeof(outerFlops) / sizeof(float)}); - MNN_PRINT("\ninnerFlops:"); - formatMatrix(innerFlops, {sizeof(innerFlops) / sizeof(float)}); - MNN_PRINT("\nouterBandwidth:"); - formatMatrix(outerBandwidth, {sizeof(outerBandwidth) / sizeof(float)}); - MNN_PRINT("\ninnerBandwidth:"); - formatMatrix(innerBandwidth, {sizeof(innerBandwidth) / sizeof(float)}); - - MNN_PRINT("\nouter:"); - formatMatrix(outer, {sizeof(outer) / sizeof(float)}); - MNN_PRINT("\ninner:"); - formatMatrix(inner, {sizeof(inner) / sizeof(float)}); - - MNN_PRINT("\ndense im2col mParallelInner:%d, ePack:%d, outerAcc:%.1f, innerAcc:%.1f, totalCount:%d, tileCount:%d, outerCoefficient:%.2f, innerCoefficient:%.2f, tailCost:%.2f, lastTail:%.2f, allowed thread:%d, omp thread:\n\n", - denseConfig.isParallelInner, eP, outerAcc, innerAcc, plane, tileCount, outerCoefficient, innerCoefficient, tailCost, lastTail, threadNumber); -#endif } } @@ -455,12 +439,15 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs int bytes = core->bytes; float weightBytes = bytes; int unit = core->pack; + int matmulBytes = bytes; + if (core->matmulBytes != 0) { + matmulBytes = core->matmulBytes; + } auto packA = core->MNNPackC4ForMatMul_A; int eP, lP, hP; getPackParameter(&eP, &lP, &hP, core); auto matmulUnit = core->MNNPackedMatMul; auto matmulRemain = core->MNNPackedMatMulRemain; - auto weightType = weight->getType(); const uint8_t* dequantAlpha = nullptr; const uint8_t* dequantBias = nullptr; auto ic = input->channel(); @@ -503,13 +490,11 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs mTempBufferTranspose.buffer().type = halide_type_of(); mTempBufferTranspose.buffer().dimensions = 2; mTempBufferTranspose.buffer().dim[0].extent = threadNumber; - mTempBufferTranspose.buffer().dim[1].extent = UP_DIV(L, lP) * lP * eP * bytes; + mTempBufferTranspose.buffer().dim[1].extent = UP_DIV(L, lP) * lP * eP * matmulBytes; TensorUtils::setLinearLayout(&mTempBufferTranspose); auto plane = mIm2ColParameters.ow * mIm2ColParameters.oh * batch; int tileCount = UP_DIV(plane, eP); mConvPerfconfig = bestTileConvolutionConfig(mCommon, input, output, threadNumber, backend()); - - auto threadNumberFirst = mConvPerfconfig.isParallelInner ? 
threadNumber : std::min(threadNumber, tileCount); bool success = backend()->onAcquireBuffer(&mTempBufferTranspose, Backend::DYNAMIC); if (!success) { return OUT_OF_MEMORY; @@ -525,15 +510,14 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs bufferAlloc->free(tempPtr); auto postParameters = getPostParameters(); - mFunction.first = threadNumberFirst; + mFunction.first = threadNumber; if (mConvPerfconfig.isParallelInner) { - + auto rt = static_cast(backend()->getRuntime()); + std::vector ocC4ParralSize(threadNumber + 1); + ocC4ParralSize[0] = 0; + rt->computeDivideSizes(oC4, ocC4ParralSize.data()+1); mFunction.second = [=](int placeholder) { -#ifdef PROFILE_DETAIL - MNN_PRINT("dense conv: n:%d, ic:%d, oc:%d, kh:%d, kw:%d, plane:%d, threadNumberFirst:%d, tileCount:%d, ePack:%d, pack::%d, bytes:%d\n", - batch, ic, outputChannel, kernel_width, kernel_height, plane, threadNumberFirst, tileCount, eP, unit, bytes); -#endif const float* biasPtr = bias ? bias->host() : nullptr; auto gemmBuffer = mTempBufferTranspose.host() + mTempBufferTranspose.stride(0) * 0; auto srcPtr = (float const **)(tempPtr.ptr() + 0 * kernelSize * maxLine * (4 * sizeof(int32_t) + sizeof(float *))); @@ -556,16 +540,10 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs parameters[5] = weightStride; // Only used when block quant parameters[6] = 0; -#ifdef PROFILE_DETAIL - std::vector durationMul(threadNumberFirst, 0); - std::vector packATime(threadNumberFirst, 0); - std::vector indexTime(threadNumberFirst, 0); - Timer timer[threadNumberFirst]; - std::vector macs(threadNumberFirst, 0); -#endif - auto dstOrigin = output->host(); auto srcOrigin = input->host(); + std::vector im2colParallelSize(threadNumber + 1); + im2colParallelSize[0] = 0; for (int x = 0; x < tileCount; x += 1) { int start = (int)x * eP; @@ -578,17 +556,15 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs if (needZero || lP != 1) { ::memset(gemmBuffer, 0, mTempBufferTranspose.stride(0)); } - -#ifdef PROFILE_DETAIL - indexTime[0] += timer[0].durationInUs(); - timer[0].reset(); -#endif - info[0] = 1; int hw4Stride = info[1] * unit * bytes; - MNN_CONCURRENCY_BEGIN(tId, threadNumberFirst) { + rt->computeDivideSizes(number * icC4, im2colParallelSize.data() + 1); + im2colParallelSize[0] = 0; + MNN_CONCURRENCY_BEGIN(tId, threadNumber) { int threadEL[4]; - for(int tic_inumber = tId; tic_inumber < number * icC4; tic_inumber+=threadNumberFirst) { + int ticSta = im2colParallelSize[tId]; + int ticEnd = im2colParallelSize[tId+1]; + for(int tic_inumber = ticSta; tic_inumber < ticEnd; tic_inumber++) { int inumber = tic_inumber / icC4; int t_ic = tic_inumber % icC4; memcpy(threadEL, el + 4 * inumber, 4 * sizeof(int)); @@ -600,16 +576,11 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs } MNN_CONCURRENCY_END(); -#ifdef PROFILE_DETAIL - packATime[0] += timer[0].durationInUs(); - timer[0].reset(); -#endif - if (xC == eP) { - MNN_CONCURRENCY_BEGIN(tId, threadNumberFirst) { + MNN_CONCURRENCY_BEGIN(tId, threadNumber) { size_t paraParameters[PARAMETERSIZE]; memcpy(paraParameters, parameters, PARAMETERSIZE * sizeof(size_t)); - for (int t_oc = tId; t_oc < oC4; t_oc += threadNumberFirst) { + for (int t_oc = ocC4ParralSize[tId]; t_oc < ocC4ParralSize[tId+1]; ++t_oc) { int ocIndex = t_oc * tileC; auto _dstFloatPtr = reinterpret_cast(dstOrigin + (ocIndex / unit * plane + start) * unit * bytes); auto _weightFloatPtr = reinterpret_cast(weightPtr + int((ocIndex / hP * LRoundup * hP) * 
weightBytes)); @@ -637,10 +608,10 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs } MNN_CONCURRENCY_END(); } else { - MNN_CONCURRENCY_BEGIN(tId, threadNumberFirst) { + MNN_CONCURRENCY_BEGIN(tId, threadNumber) { size_t paraParameters[PARAMETERSIZE]; memcpy(paraParameters, parameters, PARAMETERSIZE * sizeof(size_t)); - for (int t_oc = tId; t_oc < oC4; t_oc += threadNumberFirst) { + for (int t_oc = ocC4ParralSize[tId]; t_oc < ocC4ParralSize[tId+1]; ++t_oc) { int ocIndex = t_oc * tileC; auto _dstFloatPtr = reinterpret_cast(dstOrigin + (ocIndex / unit * plane + start) * unit * bytes); auto _weightFloatPtr = reinterpret_cast(weightPtr + int((ocIndex / hP * LRoundup * hP) * weightBytes)); @@ -669,32 +640,16 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs MNN_CONCURRENCY_END(); } -#ifdef PROFILE_DETAIL - macs[0] += 2.0 * xC * L * oC4 * unit / threadNumberFirst; - durationMul[0] += timer[0].durationInUs(); - timer[0].reset(); -#endif - } - -#ifdef PROFILE_DETAIL - double gflops = macs[0] / 1000.0 / durationMul[0]; - MNN_PRINT("dense conv mParallelInner:%d, inside measure: indexTime:%lu us, packATime:%lu us, durationMul:%lu us, total:%lu us, %.3f GFLOPS\n", - mConvPerfconfig.isParallelInner, indexTime[0], packATime[0], durationMul[0], indexTime[0] + packATime[0] + durationMul[0], gflops); - -#endif - }; } else { - mFunction.second = [=](int tId) { + std::vector divides(threadNumber + 1); + divides[0] = 0; -#ifdef PROFILE_DETAIL - if (tId == 0) { - MNN_PRINT("dense conv: n:%d, ic:%d, oc:%d, kh:%d, kw:%d, plane:%d, tileCount:%d, ePack:%d, pack::%d, bytes:%d\n", - batch, ic, outputChannel, kernel_width, kernel_height, plane, tileCount, eP, unit, bytes); - } -#endif + static_cast(static_cast(backend())->getRuntime())->computeDivideSizes(tileCount, divides.data() + 1); + + mFunction.second = [=](int tId) { const float* biasPtr = bias ? bias->host() : nullptr; auto gemmBuffer = mTempBufferTranspose.host() + mTempBufferTranspose.stride(0) * tId; auto srcPtr = (float const **)(tempPtr.ptr() + tId * kernelSize * maxLine * (4 * sizeof(int32_t) + sizeof(float *))); @@ -713,17 +668,11 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs parameters[5] = weightStride; // Only used when block quant parameters[6] = 0; -#ifdef PROFILE_DETAIL - std::vector durationMul(threadNumberFirst, 0); - std::vector packATime(threadNumberFirst, 0); - std::vector indexTime(threadNumberFirst, 0); - Timer timer[threadNumberFirst]; - std::vector macs(threadNumberFirst, 0); -#endif - auto dstOrigin = output->host(); auto srcOrigin = input->host(); - for (int x = (int)tId; x < tileCount; x += threadNumberFirst) { + int tEnd = divides[tId+1]; + int tStart = divides[tId]; + for (int x = (int)tStart; x < tEnd; ++x) { int start = (int)x * eP; int remain = plane - start; int xC = remain > eP ? 
eP : remain; @@ -735,18 +684,10 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs ::memset(gemmBuffer, 0, mTempBufferTranspose.stride(0)); } -#ifdef PROFILE_DETAIL - indexTime[tId] += timer[tId].durationInUs(); - timer[tId].reset(); -#endif if (number > 0) { packA((float *)gemmBuffer, srcPtr, info, el); } -#ifdef PROFILE_DETAIL - packATime[tId] += timer[tId].durationInUs(); - timer[tId].reset(); -#endif int finishedL = 0; int wquantStride = 0; int8_t* _weightPtr = reinterpret_cast(weightPtr); @@ -780,20 +721,7 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs } // matmulRemain(_dstFloatPtr, (float*)gemmBuffer, (float*)weightPtr, xC, parameters, postParameters.data(), biasPtr, k, b); } - -#ifdef PROFILE_DETAIL - macs[tId] += 2.0 * xC * L * oC4 * unit; // bias - durationMul[tId] += timer[tId].durationInUs(); - timer[tId].reset(); -#endif } - -#ifdef PROFILE_DETAIL - double gflops = macs[tId] / 1000.0 / durationMul[tId]; - MNN_PRINT("dense conv mParallelInner:%d, inside measure: indexTime:%lu us, packATime:%lu us, durationMul:%lu us, total:%lu us, %.3f GFLOPS\n", - mConvPerfconfig.isParallelInner, indexTime[tId], packATime[tId], durationMul[tId], indexTime[tId] + packATime[tId] + durationMul[tId], gflops); - -#endif }; } return NO_ERROR; @@ -801,10 +729,6 @@ ErrorCode DenseConvolutionTiledImpl::onResize(const std::vector& inputs ErrorCode DenseConvolutionTiledImpl::onExecute(const std::vector& inputs, const std::vector& outputs) { -#ifdef PROFILE_DETAIL - Timer outsideTimer; - outsideTimer.reset(); -#endif if (mConvPerfconfig.isParallelInner) { mFunction.second(0); } else { @@ -814,12 +738,8 @@ ErrorCode DenseConvolutionTiledImpl::onExecute(const std::vector& input MNN_CONCURRENCY_END(); } -#ifdef PROFILE_DETAIL - MNN_PRINT("dense conv. 
mParallelInner:%d, outside measure: total cost %lu us\n", mConvPerfconfig.isParallelInner, outsideTimer.durationInUs()); -#endif return NO_ERROR; } -#undef PROFILE_DETAIL } // namespace MNN diff --git a/source/backend/cpu/compute/DenseConvolutionTiledExecutor.hpp b/source/backend/cpu/compute/DenseConvolutionTiledExecutor.hpp index 2ce01634f..f618b127f 100644 --- a/source/backend/cpu/compute/DenseConvolutionTiledExecutor.hpp +++ b/source/backend/cpu/compute/DenseConvolutionTiledExecutor.hpp @@ -31,7 +31,6 @@ class DenseConvolutionTiledImpl : public ConvolutionTiledImpl { static PerfConfig bestTileConvolutionConfig(const Convolution2DCommon *common, const Tensor *inputTensor, const Tensor *outputTensor, int threadNumber, Backend* b); protected: - }; class DenseConvolutionTiledExecutor : public ConvolutionTiledExecutor { public: diff --git a/source/backend/cpu/compute/GemmInt8Executor.cpp b/source/backend/cpu/compute/GemmInt8Executor.cpp index 0c7e9a7ff..00e501e5d 100644 --- a/source/backend/cpu/compute/GemmInt8Executor.cpp +++ b/source/backend/cpu/compute/GemmInt8Executor.cpp @@ -15,7 +15,9 @@ namespace MNN { GemmInt8Executor::GemmInt8Executor(Backend* bn, std::shared_ptr resource, const Convolution2D *conv2D, decltype(CoreInt8Functions::Int8GemmKernel) gemmKernel, std::vector bias): - CPUConvolution(conv2D->common(), bn), mResource(resource), mMutableResource(resource, bn), mGemmKernel(gemmKernel), mQuantBias(bias){ + CPUConvolution(conv2D->common(), bn), mResourceInt8(resource), mMutableResource(resource, bn), mGemmKernel(gemmKernel), mQuantBias(bias){ + mResource.reset(new Resource); + CPUConvolution::makeResource(bn, mResource, conv2D, mResourceInt8); } GemmInt8Executor::~GemmInt8Executor() { @@ -43,23 +45,32 @@ ErrorCode GemmInt8Executor::onResize(const std::vector &inputs, const auto pack = gcore->pack; auto scaleSrc = mMutableResource.mScaleFloat->host(); + int realWeightQuantScaleSize = mResource->mDequantize.mScaleBias->size() / 2; + auto weightBiasSrc = reinterpret_cast(mResource->mDequantize.mScaleBias->host() + realWeightQuantScaleSize); auto ocDivUp = UP_DIV(output->channel(), pack) * pack; mKernelY = mCommon->kernelY(); mKernelX = mCommon->kernelX(); int kernelCount = mKernelX * mKernelY; std::vector scaleData(ocDivUp); + mKernelSum.resize(ocDivUp, 0); ::memset(scaleData.data(), 0.f, ocDivUp * sizeof(float)); auto l = mMutableResource.mScaleFloat->length(0); auto lU = UP_DIV(l, pack); for (int divC = 0; divC < lU; ++divC) { auto srcX = scaleSrc + divC * pack; + auto wbias = weightBiasSrc + divC * pack; for (int k = 0; k < kernelCount; ++k) { int indexK = divC * kernelCount * pack + k * pack; for (int j = 0; j < pack; ++j) { scaleData[indexK + j] = srcX[j]; + mKernelSum[indexK + j] = wbias[j]; } } } + float* biasFloat = reinterpret_cast(mQuantBias.data()); + for (int i = 0; i < mQuantBias.size(); ++i) { + biasFloat[i] = mQuantBias[i] * scaleData[i]; + } mScaleData = scaleData; const auto IC4 = UP_DIV(input->channel(), pack); ConvolutionTiledExecutor::setIm2ColParameter(mIm2ColParamter, mCommon, input, output, 0, 0, static_cast(backend())->functions(), core); @@ -71,7 +82,7 @@ ErrorCode GemmInt8Executor::onResize(const std::vector &inputs, const mIm2ColParamter.padX = 0; mIm2ColParamter.padY = 0; mIm2ColParamter.kernelCountUnit = UP_DIV(input->channel(), SRC_UNIT); - if (SRC_UNIT > pack) { + if (SRC_UNIT > UNIT___) { const auto srcCountUnit = UP_DIV(input->channel(), pack); mIm2ColParamter.ic = mIm2ColParamter.icDiv4 * pack; } else { @@ -131,22 +142,39 @@ ErrorCode 
GemmInt8Executor::onExecute(const std::vector &inputs, const QuanPostTreatParameters quanParam; quanParam.scale = mScaleData.data(); quanParam.maxValue = mMutableResource.mClampMax; - if (mResource->mRelu) { + if (mResourceInt8->mRelu) { quanParam.minValue = mMutableResource.mOutputZeroPoint; } else { quanParam.minValue = mMutableResource.mClampMin; } + auto postParameters = getPostParameters(); + std::vector fp32minmax = {postParameters[2], postParameters[3]}; + quanParam.fp32minmax = fp32minmax.data(); quanParam.useInt8 = 0; // Save result as float data type. - quanParam.bias = mQuantBias.data(); + quanParam.biasFloat = reinterpret_cast(mQuantBias.data()); + quanParam.weightQuanBias = mKernelSum.data(); + quanParam.extraScale = nullptr; + float dequantScale = mMutableResource.mResource->mInputScale; + + SumByAxisParams sumParams; + sumParams.DST_XUNIT = DST_XUNIT; + sumParams.SRC_UNIT = SRC_UNIT; + sumParams.blockNum = 1; + sumParams.kernelCountUnitDouble = mIm2ColParamter.kernelCountUnit; + sumParams.oneScale = 1; + sumParams.col_buffer_unit_size = mInputCol->stride(0); auto threadFunction = [&](int tId) { auto colAddr = im2colPtr + tId * mInputCol->stride(0); auto col_buffer_size = mInputCol->stride(0); - int32_t info[4]; + int32_t info[6]; info[1] = mIm2ColParamter.iw * mIm2ColParamter.ih * batch; info[2] = DST_XUNIT; info[3] = mIm2ColParamter.strideX; + info[5] = mIm2ColParamter.kernelCountUnit; + float paramsf[1]; + paramsf[0] = dequantScale; auto srcPtr = (int8_t const **)(mBlitInfo.ptr() + tId * mBlitInfoStride.first); auto el = (int32_t *)(srcPtr + mBlitInfoStride.second); @@ -165,9 +193,15 @@ ErrorCode GemmInt8Executor::onExecute(const std::vector &inputs, const #endif } info[0] = number; + info[4] = realDstCount; + std::vector xKernelSum(realDstCount); if (number > 0) { blitProc(colAddr, srcPtr, info, el); } + if (mResourceInt8->mWeightAsymmetricQuant) { + gcore->MNNSumByAxisLForMatmul_A(xKernelSum.data(), colAddr, &dequantScale, realDstCount, sumParams); + } + quanParam.srcKernelSum = xKernelSum.data(); auto outputInTilePtr = outputDataPtr + xIndexStart * PackUnit; mGemmKernel((int8_t*)outputInTilePtr, colAddr, weightDataPtr, src_depth_quad, dstZStep * sizeof(float), ocDiv4, &quanParam, realDstCount); } diff --git a/source/backend/cpu/compute/GemmInt8Executor.hpp b/source/backend/cpu/compute/GemmInt8Executor.hpp index 668d56308..0c1345f03 100644 --- a/source/backend/cpu/compute/GemmInt8Executor.hpp +++ b/source/backend/cpu/compute/GemmInt8Executor.hpp @@ -26,13 +26,15 @@ class GemmInt8Executor : public CPUConvolution { int mKernelY; std::shared_ptr mInputCol; std::vector mScaleData; + std::vector mKernelSum; std::vector mQuantBias; - std::shared_ptr mResource; + std::shared_ptr mResourceInt8; ConvolutionCommon::Im2ColParameter mIm2ColParamter; CPUConvolution::MutableResourceInt8 mMutableResource; decltype(CoreInt8Functions::Int8GemmKernel) mGemmKernel; MemChunk mBlitInfo; std::pair mBlitInfoStride; + std::shared_ptr mResource; }; } // namespace MNN #endif /* DeconvInt8Executor_hpp */ diff --git a/source/backend/cpu/compute/IdstConvolutionInt8.cpp b/source/backend/cpu/compute/IdstConvolutionInt8.cpp index 140d8dd21..bec8d7109 100644 --- a/source/backend/cpu/compute/IdstConvolutionInt8.cpp +++ b/source/backend/cpu/compute/IdstConvolutionInt8.cpp @@ -72,15 +72,18 @@ IdstConvolutionInt8::IdstConvolutionInt8(const Convolution2DCommon* convOp, Back shape = {UP_DIV(outputCount, UNIT), UP_DIV(srcCount, SRC_UNIT) * kernelCount, UNIT, SRC_UNIT}; } 
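[Editor's note] The GemmInt8Executor changes above switch the int8 GEMM post-treatment from an int32 bias to a float bias (biasFloat, pre-multiplied by the output scale in onResize), and add a per-tile sum of the quantized activations (srcKernelSum, gathered by MNNSumByAxisLForMatmul_A) that is multiplied by a per-output-channel weight zero-point term (weightQuanBias), followed by an fp32minmax clamp. The scalar sketch below, with made-up values and illustrative names, shows why those terms reproduce a dot product against asymmetrically quantized weights; it is stand-alone example code, not MNN source:

// Minimal sketch (not MNN code) of the post-treatment math used by the reworked
// int8 GEMM kernels. All constants here are invented for illustration.
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    // One output lane, asymmetrically quantized weights: w_float = scale * (q - zero).
    const float scale = 0.02f, zero = 3.0f, bias = 0.5f;
    const float fp32min = 0.0f, fp32max = 6.0f;           // relu6-style clamp
    std::vector<int8_t> act = {7, -2, 5, 1};              // already-quantized activations
    std::vector<int8_t> wq  = {4, 9, -6, 2};              // quantized weights

    // Reference: dequantize the weights and accumulate in float.
    float ref = bias;
    for (size_t k = 0; k < act.size(); ++k) {
        ref += float(act[k]) * scale * (float(wq[k]) - zero);
    }

    // Kernel-style: integer dot product plus an activation-sum correction.
    int32_t dot = 0, srcSum = 0;
    for (size_t k = 0; k < act.size(); ++k) {
        dot    += int32_t(act[k]) * int32_t(wq[k]);
        srcSum += int32_t(act[k]);                        // what the per-tile kernel sum gathers
    }
    float weightQuanBias = -zero * scale;                 // per-output-channel constant
    float value = float(dot) * scale + float(srcSum) * weightQuanBias + bias;

    value = std::min(std::max(value, fp32min), fp32max);  // post->fp32minmax clamp
    ref   = std::min(std::max(ref,   fp32min), fp32max);
    std::printf("kernel-style %.6f vs reference %.6f\n", value, ref);
    return 0;
}

In the reference MNNGemmInt8AddBiasScale_16x4_Unit further down in Int8FunctionsOpt.cpp, dstTemp[j], scale_dz[j], post->srcKernelSum[w], and post->weightQuanBias[j] play the roles of dot, scale, srcSum, and weightQuanBias in this sketch.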
mWeight.reset(Tensor::createDevice(shape)); - mFakeBias.reset(Tensor::createDevice({(int)ROUND_UP(biasSize, PackUnit)})); + mFakeBias.reset(Tensor::createDevice({(int)ROUND_UP(biasSize, PackUnit)})); + mFakeWeightBias.reset(Tensor::createDevice({(int)ROUND_UP(biasSize, PackUnit)})); mValid = b->onAcquireBuffer(mWeight.get(), Backend::STATIC); mValid &= b->onAcquireBuffer(mFakeBias.get(), Backend::STATIC); + mValid &= b->onAcquireBuffer(mFakeWeightBias.get(), Backend::STATIC); if (!mValid) { MNN_ERROR("Memory not enough\n"); return; } ConvInt8TiledExecutor::reorderWeight(mWeight.get(), (uint8_t*)common->weight.get(), SRC_UNIT, UNIT, srcCount, outputCount, kernelCount); - ::memset(mFakeBias->host(), 0, mFakeBias->size()); + ::memset(mFakeBias->host(), 0, mFakeBias->size()); + ::memset(mFakeWeightBias->host(), 0, mFakeWeightBias->size()); #ifdef MNN_USE_SSE for (int oz = 0; oz < outputCount; ++oz) { auto srcZ = common->weight.get() + oz * kernelCount * srcCount; @@ -88,7 +91,7 @@ IdstConvolutionInt8::IdstConvolutionInt8(const Convolution2DCommon* convOp, Back for (int i = 0; i < kernelCount * srcCount; ++i) { offset += srcZ[i] * (-128); } - mFakeBias->host()[oz] = offset; + mFakeBias->host()[oz] = static_cast(offset) * 1.f; } #endif } @@ -149,7 +152,7 @@ ErrorCode IdstConvolutionInt8::onExecute(const std::vector& inputs, con int UNIT__, SRC_UNIT, DST_XUNIT; coreInt->MNNGetGemmUnit(&UNIT__, &SRC_UNIT, &DST_XUNIT); int PackUnit = static_cast(backend())->functions()->pack; - + auto gemmKernel = coreInt->Int8GemmKernel; // AUTOTIME; @@ -176,9 +179,14 @@ ErrorCode IdstConvolutionInt8::onExecute(const std::vector& inputs, con std::vector fakeScale(ocC4 * PackUnit, 1.0f); QuanPostTreatParameters quanParam; - quanParam.bias = mFakeBias->host(); + quanParam.biasFloat = mFakeBias->host(); quanParam.scale = fakeScale.data(); quanParam.useInt8 = 0; + float fp32minmax[2] = {-std::numeric_limits().max(), std::numeric_limits().max()}; + quanParam.fp32minmax = fp32minmax; + quanParam.weightQuanBias = mFakeWeightBias->host(); + std::vector fakeSrcKernleSum(DST_XUNIT, 0.f); + quanParam.srcKernelSum = fakeSrcKernleSum.data(); // MNN_PRINT("%s, %d, %d, %d,%d->%d,%d\n", layer->layer.layerId, layer->kernelSize[0], layer->kernelSize[1], // input->d1, input->d2, output->d1, output->d2); diff --git a/source/backend/cpu/compute/IdstConvolutionInt8.hpp b/source/backend/cpu/compute/IdstConvolutionInt8.hpp index 1a188c077..c66332512 100644 --- a/source/backend/cpu/compute/IdstConvolutionInt8.hpp +++ b/source/backend/cpu/compute/IdstConvolutionInt8.hpp @@ -40,6 +40,7 @@ class IdstConvolutionInt8 : public CPUConvolution { std::vector mPostParameters; // mFakeBias used by GemmKernel std::shared_ptr mFakeBias; + std::shared_ptr mFakeWeightBias; MemChunk mBlitInfo; std::pair mBlitInfoStride; }; diff --git a/source/backend/cpu/compute/Int8FunctionsOpt.cpp b/source/backend/cpu/compute/Int8FunctionsOpt.cpp index 2d046de25..50fad7e6a 100644 --- a/source/backend/cpu/compute/Int8FunctionsOpt.cpp +++ b/source/backend/cpu/compute/Int8FunctionsOpt.cpp @@ -22,6 +22,8 @@ void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int const QuanPostTreatParameters* post, size_t realCount); void MNNGemmInt8AddBiasScale_16x4_Unit_FAST(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount); +void MNNGemmInt8AddBiasScale_16x4_w4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, 
size_t dst_step, size_t dst_depth_quad, + const QuanPostTreatParameters* post, size_t realCount); void MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder=nullptr); void MNNMaxPoolInt8(int8_t* dst, int8_t* src, size_t outputWidth, size_t inputWidth, size_t kernelx, size_t kernely, size_t stridesx); @@ -35,6 +37,31 @@ void MNNGemmInt8AddBiasScale_ARMV86_Unit(int8_t* dst, const int8_t* src, const i const QuanPostTreatParameters* post, size_t realDstCount); void MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder=nullptr); +#if defined(MNN_LOW_MEMORY) +// int4 weight gemmInt8 kernel +void MNNGemmInt8AddBiasScale_ARMV82_w4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, + const QuanPostTreatParameters* post, size_t realDstCount); +void MNNGemmInt8AddBiasScale_ARMV86_w4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, + const QuanPostTreatParameters* post, size_t realDstCount); +void MNNGemmInt8AddBiasScale_16x4_w4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, + const QuanPostTreatParameters* post, size_t realDstCount); +// Tools to dynamic-quant fp16-input data. +#ifdef MNN_USE_ARMV82 +void DynamicQuanInput_ARM82(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, + ssize_t maxValue, ssize_t zeroPoint); +// int8 weight gemmInt8 kernel to return fp16-output data. +void MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, + const QuanPostTreatParameters* post, size_t realDstCount); +void MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, + const QuanPostTreatParameters* post, size_t realDstCount); +void MNNGemmInt8AddBiasScale_ARMV86_Unit_FP16(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, + const QuanPostTreatParameters* post, size_t realDstCount); +void MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, + const QuanPostTreatParameters* post, size_t realDstCount); +void DynamicQuanInputAndReorder_ARM82(const float* src, int8_t* dst, size_t planeSize, const float* scale, ssize_t aMin, + ssize_t aMax, ssize_t zeroPoint, size_t ocQuad, size_t offset); +#endif +#endif #endif // __aarch64__ } #endif // MNN_USE_NEON @@ -1386,11 +1413,28 @@ static int8_t MNNInt32ToInt8(int data, int bias, float scale, float maxValue, fl static void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) { const int bytes = ((post->useInt8 == 1) ? 
1 : 4); + float fp32min = 0, fp32max = 0; +// if (0 == post->useInt8) { +// fp32min = (post->fp32minmax)[0]; +// fp32max = (post->fp32minmax)[1]; +// } + auto blockNum = post->blockNum; + int weight_step_Z = (src_depth_quad * blockNum) * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); + int weight_step_Y = (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); + const auto srcSumPtr = post->srcKernelSum; + if (0 == post->useInt8 && post->fp32minmax) { + fp32min = (post->fp32minmax)[0]; + fp32max = (post->fp32minmax)[1]; + } + + float* biasPtr = (float*)post->biasFloat; + for (int dz = 0; dz < dst_depth_quad; ++dz) { - const auto weight_dz = weight + dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); - const auto bias_dz = post->bias + dz * GEMM_INT8_UNIT; + const auto weight_dz = weight + weight_step_Z * dz; + const auto bias_dz = biasPtr + dz * GEMM_INT8_UNIT; + const auto weight_zero = post->weightQuanBias + (dz * GEMM_INT8_UNIT); const float* scale_dz = nullptr; - scale_dz = post->scale + dz * GEMM_INT8_UNIT; + scale_dz = post->scale + (dz * GEMM_INT8_UNIT); auto dst_z = dst + dz * dst_step; for (int w = 0; w < realCount; ++w) { const auto src_x = src + w * GEMM_INT8_SRC_UNIT; @@ -1398,7 +1442,7 @@ static void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, co int32_t dstTemp[4] = {0, 0, 0, 0}; for (int sz = 0; sz < src_depth_quad; ++sz) { - const auto weight_sz = weight_dz + (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT) * sz; + const auto weight_sz = weight_dz + weight_step_Y * sz; const auto src_z = src_x + sz * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT; for (int j = 0; j < GEMM_INT8_UNIT; ++j) { @@ -1410,34 +1454,125 @@ static void MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, co } for (int j = 0; j < GEMM_INT8_UNIT; ++j) { - if (!post->scale) { - ((float*)dst_x)[j] = (float)(dstTemp[j] + bias_dz[j]); - } else if (post->useInt8 == 1) { - dst_x[j] = MNNInt32ToInt8(dstTemp[j], bias_dz[j], scale_dz[j], post->maxValue, post->minValue); - } else { - float value = (float)(dstTemp[j] + bias_dz[j]) * scale_dz[j]; + float value = dstTemp[j] * scale_dz[j] + srcSumPtr[w] * weight_zero[j]; + if (post->extraScale) { + value = dstTemp[j] * scale_dz[j] * post->extraScale[w] + srcSumPtr[w] * weight_zero[j]; + } + if (post->useInt8 == 0) { + if (biasPtr) { + value += bias_dz[j]; + } else { + float dstv = ((float*)dst_x)[j]; + value += dstv; + } + if (post->fp32minmax) { + value = std::min(std::max(fp32min, value), fp32max); + } ((float*)dst_x)[j] = value; + } else { + value += bias_dz[j]; + value = ALIMAX(value, post->minValue); + value = ALIMIN(value, post->maxValue); + dst_x[j] = static_cast(roundf(value)); + } + } + } + } +} + +static void MNNGemmInt8AddBiasScale_16x4_w4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount) { + uint32_t c = 0xf; + const int bytes = 4; + float fp32min = 0, fp32max = 0; + int weight_step_Z = 0.5 * (post->blockNum * src_depth_quad) * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); + int weight_step_Y = 0.5 * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); + MNN_ASSERT(post->useInt8==0); + if (post->fp32minmax) { + fp32min = (post->fp32minmax)[0]; + fp32max = (post->fp32minmax)[1]; + } + + float* biasPtr = (float*)post->biasFloat; + int blockNum = post->blockNum; + + const auto srcSumPtr = post->srcKernelSum; + for (int dz = 0; dz < dst_depth_quad; ++dz) { + const auto weight_dz = weight + weight_step_Z * dz; + const auto bias_dz = biasPtr + dz * 
GEMM_INT8_UNIT; + const auto weight_zero = post->weightQuanBias + (dz * GEMM_INT8_UNIT); + const float* scale_dz = nullptr; + scale_dz = post->scale + (dz * GEMM_INT8_UNIT); + auto dst_z = dst + dz * dst_step; + for (int w = 0; w < realCount; ++w) { + const auto src_x = src + w * GEMM_INT8_SRC_UNIT; + auto dst_x = dst_z + w * GEMM_INT8_UNIT * bytes; + int32_t dstTemp[4] = {0, 0, 0, 0}; + + for (int sz = 0; sz < src_depth_quad; ++sz) { + const auto weight_sz = (uint8_t*)weight_dz + weight_step_Y * sz; + const auto src_z = src_x + sz * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT; + + int w8[64]; // 64=GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT + for (int k = 0; k < 32; ++k) { + w8[2 * k] = (weight_sz[k]>>4); + w8[2 * k + 1] = (weight_sz[k] & c); + } + + for (int j = 0; j < GEMM_INT8_UNIT; ++j) { + const auto weight_j = w8 + j * GEMM_INT8_SRC_UNIT; + for (int i = 0; i < GEMM_INT8_SRC_UNIT; ++i) { + dstTemp[j] += (int32_t)src_z[i] * (int32_t)weight_j[i]; + } } } + + for (int j = 0; j < GEMM_INT8_UNIT; ++j) { + float value = dstTemp[j] * scale_dz[j] + srcSumPtr[w] * weight_zero[j]; + if (post->extraScale) { + value = dstTemp[j] * scale_dz[j] * post->extraScale[w] + srcSumPtr[w] * weight_zero[j]; + } + + if (biasPtr) { + value += bias_dz[j]; + } else { + float dstv = ((float*)dst_x)[j]; + value += dstv; + } + if (post->fp32minmax) { + value = std::min(std::max(fp32min, value), fp32max); + } + ((float*)dst_x)[j] = value; + } } } } static void MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const float* slope, size_t planeNumber, size_t depthQuad, QuanPrePostParameters *params) { +#ifdef MNN_USE_SSE +float offset = 128.f; +uint8_t* srcPtr = (uint8_t*)src; +uint8_t* dstPtr = (uint8_t*)dst; +#else +float offset = 0.f; +const int8_t* srcPtr = src; +int8_t* dstPtr = dst; +#endif float mulVal = 0.f; float inputScale = params->inputScale[0]; float outputScale = params->outputScale[0]; - int32_t inputZero = static_cast(params->inputZeroPoint[0]); - int32_t outputZero = static_cast(params->outputZeroPoint[0]); + float inputZero = static_cast(params->inputZeroPoint[0]) + offset; + float outputZero = static_cast(params->outputZeroPoint[0]) + offset; + int32_t minval = params->minValue + offset; + int32_t maxval = params->maxValue + offset; for (int j = 0;j < depthQuad; ++j) { const float* slopeZ = slope + 4 * j; - const int8_t* srcZ = src + 4 * j * planeNumber; - int8_t* dstZ = dst + 4 * j * planeNumber; + const auto srcZ = srcPtr + 4 * j * planeNumber; + auto dstZ = dstPtr + 4 * j * planeNumber; for (int i = 0; i < planeNumber; ++i) { for (int c = 0; c < 4; ++c) { - if (srcZ[4 * i + c] < 0) { + if ((float)srcZ[4 * i + c] < inputZero) { mulVal = (srcZ[4 * i + c] - inputZero) * slopeZ[c]; - dstZ[4 * i + c] = ALIMIN(ALIMAX(static_cast(roundf(mulVal)) + outputZero, params->minValue), params->maxValue); + dstZ[4 * i + c] = ALIMIN(ALIMAX(static_cast(roundf(mulVal)) + outputZero, minval), maxval); } else { dstZ[4 * i + c] = srcZ[4 * i + c]; } @@ -1974,9 +2109,9 @@ static void MNNGetGemmUnitSdot(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) { } static void MNNGetGemmUnitI8mm(int* UNIT, int* SRC_UNIT, int* DST_XUNIT) { - *UNIT = 4; + *UNIT = 8; *SRC_UNIT = 8; - *DST_XUNIT = 20; + *DST_XUNIT = 10; } template @@ -2055,6 +2190,9 @@ void MNNCoreInt8FunctionInit() { gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_16x4_Unit; gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_16x4_Unit_FAST; gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnit; +#ifdef MNN_LOW_MEMORY + gCoreFunc->Int8GemmKernel_W4 = 
MNNGemmInt8AddBiasScale_16x4_w4_Unit; +#endif // Im2Col gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A; @@ -2088,15 +2226,31 @@ void MNNCoreInt8FunctionInit() { gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A_L4<12, 4>; // ConvDepthwise gCoreFunc->ConvDepthwise3x3LineInt8_ARM82 = MNNLineDepthWiseInt8AddBiasScale_ARMV82_Unit3X3; - +#if defined(MNN_LOW_MEMORY) + #ifdef MNN_USE_ARMV82 + gCoreFunc->DynamicQuanInput_ARM82 = DynamicQuanInput_ARM82; + gCoreFunc->MNNGemmInt8AddBiasScale_Unit_FP16 = MNNGemmInt8AddBiasScale_ARMV82_Unit_FP16; + gCoreFunc->MNNGemmInt8AddBiasScale_w4_Unit_FP16 = MNNGemmInt8AddBiasScale_ARMV82_w4_Unit_FP16; + gCoreFunc->DynamicQuanInputAndReorder_ARM82 = DynamicQuanInputAndReorder_ARM82; + #endif + gCoreFunc->Int8GemmKernel_W4 = MNNGemmInt8AddBiasScale_ARMV82_w4_Unit; +#endif } if (core->supportI8mm) { // MatMul gCoreFunc->Int8GemmKernel = MNNGemmInt8AddBiasScale_ARMV86_Unit; gCoreFunc->Int8GemmKernelFast = MNNGemmInt8AddBiasScale_ARMV86_Unit; gCoreFunc->MNNGetGemmUnit = MNNGetGemmUnitI8mm; +#if defined(MNN_LOW_MEMORY) + gCoreFunc->Int8GemmKernel_W4 = MNNGemmInt8AddBiasScale_ARMV86_w4_Unit; + #ifdef MNN_USE_ARMV82 + gCoreFunc->MNNGemmInt8AddBiasScale_Unit_FP16 = MNNGemmInt8AddBiasScale_ARMV86_Unit_FP16; + gCoreFunc->MNNGemmInt8AddBiasScale_w4_Unit_FP16 = MNNGemmInt8AddBiasScale_ARMV86_w4_Unit_FP16; + #endif +#endif // Im2Col - gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A<20, 8, 4>; + gCoreFunc->MNNPackC4Int8ForMatMul_A = _ArmBasicMNNPackC4ForMatMul_A<10, 8, 8>; + gCoreFunc->MNNPackC4Int8ForMatMul_A_ARM86FP16 = _ArmBasicMNNPackC4ForMatMul_A<10, 8, 8>; } #endif MNNInt8FunctionInit(); diff --git a/source/backend/cpu/compute/Int8FunctionsOpt.h b/source/backend/cpu/compute/Int8FunctionsOpt.h index eea714090..da974619c 100644 --- a/source/backend/cpu/compute/Int8FunctionsOpt.h +++ b/source/backend/cpu/compute/Int8FunctionsOpt.h @@ -38,13 +38,19 @@ extern "C" { struct QuanPostTreatParameters { const float* scale; - const int32_t* bias; + const float* biasFloat; int32_t maxValue; int32_t minValue; int32_t useInt8 = 1; // Save result as int8_t dataType; otherwise float32. 
float roundValuePos = 0.5f; float roundValueNeg = -0.5f; - + float* srcKernelSum; + float* weightQuanBias; + float* fp32minmax; + ssize_t blockNum = 1; + const int32_t* bias; + const float* extraScale = nullptr; + const float* extraBias = nullptr; }; struct QuanPrePostParameters{ float* inputScale; @@ -78,7 +84,13 @@ struct CoreInt8Functions { void(*Int8GemmKernelFast)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realCount); void(*MNNGetGemmUnit)(int* UNIT, int* SRC_UNIT, int* DST_XUNIT); void(*MNNPackC4Int8ForMatMul_A)(int8_t* destOrigin, int8_t const** sourceGroup, const int32_t* info, const int32_t* el); - + void(*MNNPackC4Int8ForMatMul_A_ARM86FP16)(int8_t* destOrigin, int8_t const** sourceGroup, const int32_t* info, const int32_t* el) = nullptr; + void(*MNNGemmInt8AddBiasScale_Unit_FP16)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, + const QuanPostTreatParameters* post, size_t realDstCount); + void(*MNNGemmInt8AddBiasScale_w4_Unit_FP16)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, + const QuanPostTreatParameters* post, size_t realDstCount); + void(*Int8GemmKernel_W4)(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, + const QuanPostTreatParameters* post, size_t realDstCount); // sparse void(*MNNGetSparseQuantMatMulPackMode)(int* eP, int *lP, int* hP); void(*MNNPackForSparseQuantMatMul_B)(int8_t* dest, unsigned int* NNZMap, int* dataOffsetMap, int sparseBlockOC, const int8_t* source, size_t h, size_t kernelCount, size_t icCount, const int eP); @@ -90,9 +102,9 @@ struct CoreInt8Functions { size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder); void(*ConvDepthwise3x3LineInt8_ARM82)(int8_t* dst, const int8_t* src, const int8_t* weight, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder) = nullptr; - - void(*MNNFloat2Int8)(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, - ssize_t maxValue, ssize_t zeroPoint); + void(*DynamicQuanInput_ARM82)(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, ssize_t maxValue, ssize_t zeroPoint) = nullptr; + void (*DynamicQuanInputAndReorder_ARM82)(const float* src, int8_t* dst, size_t planeSize, const float* scale, ssize_t aMin, ssize_t aMax, ssize_t zeroPoint, size_t ocQuad, size_t offset) = nullptr; + void(*MNNFloat2Int8)(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minValue, ssize_t maxValue, ssize_t zeroPoint); void(*MNNInt8ScaleToFloat)(float* dst, const int8_t* src, const float* scale, size_t size, ssize_t zeroPoint); void(*MNNScaleAndAddBias)(float* dst, const float* src, const float* bias, const float* alpha, size_t planeNumber, size_t biasNumber); diff --git a/source/backend/cpu/compute/SparseConvInt8TiledExecutor.cpp b/source/backend/cpu/compute/SparseConvInt8TiledExecutor.cpp index 395fa3745..5c8fc0dca 100644 --- a/source/backend/cpu/compute/SparseConvInt8TiledExecutor.cpp +++ b/source/backend/cpu/compute/SparseConvInt8TiledExecutor.cpp @@ -64,12 +64,12 @@ bool SparseConvInt8TiledExecutor::reorderWeight(Backend* b, const Convolution2DC return true; } 
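[Editor's note] The new Int8GemmKernel_W4 entries and the reference MNNGemmInt8AddBiasScale_16x4_w4_Unit above read weights packed two per byte, which is why their weight strides carry a factor of 0.5: the even-indexed weight sits in the high nibble and the odd-indexed one in the low nibble (w8[2k] = b >> 4, w8[2k+1] = b & 0xf). The stand-alone sketch below round-trips that layout; sign and zero-point handling are deliberately left out and the helper names are illustrative, not MNN APIs:

// Sketch of the 4-bit weight packing the *_w4 kernels assume: two weights per byte,
// even index in the high nibble, odd index in the low nibble.
#include <cassert>
#include <cstdint>
#include <vector>

std::vector<uint8_t> packW4(const std::vector<uint8_t>& w) {   // each w[i] in [0, 15]
    assert(w.size() % 2 == 0);
    std::vector<uint8_t> packed(w.size() / 2);
    for (size_t i = 0; i < packed.size(); ++i) {
        packed[i] = uint8_t((w[2 * i] << 4) | (w[2 * i + 1] & 0x0f));
    }
    return packed;
}

void unpackW4(const uint8_t* packed, int* out, size_t packedSize) {
    const uint8_t mask = 0x0f;
    for (size_t k = 0; k < packedSize; ++k) {                  // mirrors w8[2k] = (b >> 4), w8[2k+1] = (b & 0xf)
        out[2 * k]     = packed[k] >> 4;
        out[2 * k + 1] = packed[k] & mask;
    }
}

int main() {
    std::vector<uint8_t> w = {1, 15, 7, 0, 9, 3, 12, 8};
    auto packed = packW4(w);                                   // halves the weight footprint
    std::vector<int> back(w.size());
    unpackW4(packed.data(), back.data(), packed.size());
    for (size_t i = 0; i < w.size(); ++i) {
        assert(back[i] == int(w[i]));
    }
    return 0;
}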
-SparseConvInt8TiledExecutor::SparseConvInt8TiledExecutor(Backend* backend, const Convolution2D* convOp, std::shared_ptr res) : ConvInt8TiledExecutor(backend, convOp->common(), res) { +SparseConvInt8TiledExecutor::SparseConvInt8TiledExecutor(Backend* backend, const Convolution2D* convOp, std::shared_ptr res) : ConvInt8TiledExecutor(backend, convOp, res) { std::shared_ptr weightOrigin; - weightOrigin.swap(mResource->mWeightInt8); + weightOrigin.swap(mResourceInt8->mWeightInt8); const SparseCommon* sparseCommon = convOp->sparseParameter(); - mValid = reorderWeight(backend, convOp->common(), weightOrigin, mResource->mWeightInt8, sparseCommon); + mValid = reorderWeight(backend, convOp->common(), weightOrigin, mResourceInt8->mWeightInt8, sparseCommon); if(!mValid) { return; } @@ -81,9 +81,9 @@ SparseConvInt8TiledExecutor::SparseConvInt8TiledExecutor(Backend* backend, const } -SparseConvInt8TiledExecutor::SparseConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* common, +SparseConvInt8TiledExecutor::SparseConvInt8TiledExecutor(Backend* backend, const Convolution2D* convOp, const SparseConvInt8TiledExecutor& exe) - : ConvInt8TiledExecutor(backend, common, exe.mResource), + : ConvInt8TiledExecutor(backend, convOp, exe.mResourceInt8), mNNZMap(exe.mNNZMap), mDataOffsetMap(exe.mDataOffsetMap), mSparseBlockOC(exe.mSparseBlockOC), @@ -98,7 +98,7 @@ bool SparseConvInt8TiledExecutor::onClone(Backend* bn, const Op* op, Execution** if (nullptr == dst) { return true; } - auto exe = new SparseConvInt8TiledExecutor(bn, op->main_as_Convolution2D()->common(), *this); + auto exe = new SparseConvInt8TiledExecutor(bn, op->main_as_Convolution2D(), *this); if (!exe->valid()) { return false; } @@ -170,7 +170,7 @@ ErrorCode SparseConvInt8TiledExecutor::onExecute(const std::vector& inp const int ocDivPack = UP_DIV(output->channel(), PackUnit); const auto inputDataPtr = input->host(); - const auto weightDataPtr = mResource->mWeightInt8->host(); + const auto weightDataPtr = mResourceInt8->mWeightInt8->host(); const auto NNZMapPtr = mNNZMap->host(); const auto dataOffsetPtr = mDataOffsetMap->host(); auto im2colPtr = mTempIm2ColBuffer->host(); @@ -179,7 +179,7 @@ ErrorCode SparseConvInt8TiledExecutor::onExecute(const std::vector& inp quanParam.bias = mMutableResource.mBiasInt32->host(); quanParam.scale = mMutableResource.mScaleFloat->host(); quanParam.maxValue = mMutableResource.mClampMax; - if (mResource->mRelu) { + if (mResourceInt8->mRelu) { quanParam.minValue = mMutableResource.mOutputZeroPoint; } else { quanParam.minValue = mMutableResource.mClampMin; diff --git a/source/backend/cpu/compute/SparseConvInt8TiledExecutor.hpp b/source/backend/cpu/compute/SparseConvInt8TiledExecutor.hpp index b3982fab7..9bcb7ee61 100644 --- a/source/backend/cpu/compute/SparseConvInt8TiledExecutor.hpp +++ b/source/backend/cpu/compute/SparseConvInt8TiledExecutor.hpp @@ -50,7 +50,7 @@ class SparseConvInt8TiledExecutor : public ConvInt8TiledExecutor { } private: - SparseConvInt8TiledExecutor(Backend* backend, const Convolution2DCommon* common, const SparseConvInt8TiledExecutor& exe); + SparseConvInt8TiledExecutor(Backend* backend, const Convolution2D* convOp, const SparseConvInt8TiledExecutor& exe); SparseQuantMatMulParam mSparseQuantParam; decltype(CoreInt8Functions::MNNPackedSparseQuantMatMulEpx1) mSparseQuantMatMulKernel; diff --git a/source/backend/cpu/compute/SparseConvolutionTiledExecutor.cpp b/source/backend/cpu/compute/SparseConvolutionTiledExecutor.cpp index ab39bfd9b..06b4c1b11 100644 --- 
a/source/backend/cpu/compute/SparseConvolutionTiledExecutor.cpp +++ b/source/backend/cpu/compute/SparseConvolutionTiledExecutor.cpp @@ -273,6 +273,10 @@ ErrorCode SparseConvolutionTiledImpl::onResize(const std::vector& input int bytes = core->bytes; int unit = core->pack; auto packA = core->MNNPackC4ForMatMul_A; + if (core->matmulBytes != 0) { + // Use origin packC4 + packA = MNNGetCoreFunctions()->MNNPackC4ForMatMul_A; + } int eP, lP, hP; getPackParameter(&eP, &lP, &hP, core); auto weightPtr = weight->host(); diff --git a/source/backend/cpu/x86_x64/AVX2Functions.cpp b/source/backend/cpu/x86_x64/AVX2Functions.cpp index 0f1db20d4..e48d00981 100644 --- a/source/backend/cpu/x86_x64/AVX2Functions.cpp +++ b/source/backend/cpu/x86_x64/AVX2Functions.cpp @@ -44,10 +44,7 @@ bool AVX2Functions::init(int cpuFlags) { coreFunction->MNNPackedMatMulRemain_int4 = _AVX_MNNPackedMatMulRemain_int4; coreFunction->MNNPackedMatMul_int8 = _AVX_MNNPackedMatMul_int8; coreFunction->MNNPackedMatMulRemain_int8 = _AVX_MNNPackedMatMulRemain_int8; - coreFunction->MNNGemmHybridInt4 = _AVX_MNNGemmHybridInt4; - coreFunction->MNNGemmHybridInt8 = _AVX_MNNGemmHybridInt8; coreFunction->MNNAbsMax = _AVX_MNNAbsMaxFP32; - coreFunction->MNNDynamicQuant = _AVX_MNNDynamicQuantFP32; #endif coreFunction->MNNPackC4ForMatMul_A = _AVX_MNNPackC4ForMatMul_A; coreFunction->MNNPackForMatMul_B = _AVX_MNNPackForMatMul_B; diff --git a/source/backend/cpu/x86_x64/FunctionDispatcher.cpp b/source/backend/cpu/x86_x64/FunctionDispatcher.cpp index f3dc97bdc..ca87c0464 100644 --- a/source/backend/cpu/x86_x64/FunctionDispatcher.cpp +++ b/source/backend/cpu/x86_x64/FunctionDispatcher.cpp @@ -55,10 +55,7 @@ void MNNFunctionInit() { coreFunction->MNNPackedMatMulRemain_int4 = _SSE_MNNPackedMatMulRemain_int4; coreFunction->MNNPackedMatMul_int8 = _SSE_MNNPackedMatMul_int8; coreFunction->MNNPackedMatMulRemain_int8 = _SSE_MNNPackedMatMulRemain_int8; - coreFunction->MNNGemmHybridInt4 = _SSE_MNNGemmHybridInt4; - coreFunction->MNNGemmHybridInt8 = _SSE_MNNGemmHybridInt8; coreFunction->MNNAbsMax = _SSE_MNNAbsMaxFP32; - coreFunction->MNNDynamicQuant = _SSE_MNNDynamicQuantFP32; #endif coreFunction->MNNPackC4ForMatMul_A = _SSE_MNNPackC4ForMatMul_A; coreFunction->MNNPackForMatMul_B = _SSE_MNNPackForMatMul_B; @@ -137,6 +134,9 @@ void MNNInt8FunctionInit() { core->Int8GemmKernel = _SSE_MNNGemmInt8AddBiasScale_16x4_Unit; core->Int8GemmKernelFast = _SSE_MNNGemmInt8AddBiasScale_16x4_Unit; core->ConvDepthwiseLineInt8 = _SSE_MNNLineDepthWiseInt8AddBiasScaleUnit; +#ifdef MNN_LOW_MEMORY + core->Int8GemmKernel_W4 = _SSE_MNNGemmInt8AddBiasScale_16x4_w4; +#endif } } diff --git a/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp b/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp index e6056c907..214010c6f 100644 --- a/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp +++ b/source/backend/cpu/x86_x64/avx/FunctionSummary.hpp @@ -46,12 +46,7 @@ void _AVX_MNNPackedMatMul_int8(float* C, const float* A, const float* B, const s const float* postParameters, const float* bias, const float* k, const float* b); void _AVX_MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); -void _AVX_MNNGemmHybridInt4(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, - size_t dst_depth_quad, size_t realSize, const float** param); -void _AVX_MNNGemmHybridInt8(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t 
dst_step, - size_t dst_depth_quad, size_t realSize, const float** param); void _AVX_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack); -void _AVX_MNNDynamicQuantFP32(const float* src, int8_t* dst, const float* scale, float* sum, size_t src_depth_quad, size_t realSize, int pack); #endif void _AVX_MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el); diff --git a/source/backend/cpu/x86_x64/avx/GemmAVX2.cpp b/source/backend/cpu/x86_x64/avx/GemmAVX2.cpp index 2214e1688..d19863b14 100644 --- a/source/backend/cpu/x86_x64/avx/GemmAVX2.cpp +++ b/source/backend/cpu/x86_x64/avx/GemmAVX2.cpp @@ -71,146 +71,6 @@ static __m128i _load_int4_to_int8(const uint8_t* src) { return int8_tx16; } -void _AVX_MNNGemmHybridInt4(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, - size_t dst_depth_quad, size_t realSize, const float** param) { - int pack = 8; - size_t weight_step = src_depth_quad * pack * pack * 0.5; - size_t weight_stride = pack * pack * 0.5; - const float* alpha_ptr = param[0]; - const float* zero_ptr = param[1]; - const float* bias_ptr = param[2]; - const float* sums_ptr = param[3]; - const float* scale_ptr = param[4]; - auto one_int16 = _mm256_set1_epi16(1); - auto offset_int8 = _mm256_set1_epi8(128); - for (int ci = 0; ci < dst_depth_quad; ++ci) { - float* dstZ = C + ci * pack * realSize; - const int8_t* weight = B + ci * weight_step; - auto alpha = alpha_ptr + ci * pack; - auto zero = zero_ptr + ci * pack; - auto bias = bias_ptr + ci * pack; - __m256 alphaValue = _mm256_loadu_ps(alpha); - for (int j = 0; j < realSize; ++j) { - const float* sums = sums_ptr + j; - const float* scale = scale_ptr + j; - float* dstX = dstZ + j * pack; - __m256 scaleValue = _mm256_set1_ps(scale[0]); - auto sum_val = _mm256_set1_ps(sums[0]); - __m256 biasValue = _mm256_add_ps(_mm256_loadu_ps(bias), _mm256_mul_ps(_mm256_loadu_ps(zero), sum_val)); - const int8_t* srcBatch = A + j * pack; - auto oc0123_int16 = _mm256_set1_epi16(0); - auto oc4567_int16 = _mm256_set1_epi16(0); - auto oc0123_int32 = _mm256_set1_epi32(0); - auto oc4567_int32 = _mm256_set1_epi32(0); - const __m256i mask = _mm256_set1_epi8(0xf); - // auto extra = _mm256_set1_epi32(0); - for (int k = 0; k < src_depth_quad; ++k) { - auto srcZ = srcBatch + k * pack * realSize; - const uint8_t* weightZ = (uint8_t*)weight + k * weight_stride; - auto s0 = _mm256_castpd_si256(_mm256_broadcast_sd((double*)srcZ)); - auto wi4 = _mm256_castps_si256(_mm256_loadu_ps((const float*)weightZ)); - auto w0_ = _mm256_and_si256(mask, _mm256_srli_epi16(wi4, 4)); - auto w1_ = _mm256_and_si256(mask, wi4); - auto w0 = _mm256_permute2x128_si256(w0_, w1_, 0x20); - auto w1 = _mm256_permute2x128_si256(w0_, w1_, 0x31); - oc0123_int16 = _mm256_maddubs_epi16(w0, s0); // int16_t sum - oc4567_int16 = _mm256_maddubs_epi16(w1, s0); // int16_t sum - oc0123_int32 = _mm256_add_epi32(_mm256_madd_epi16(oc0123_int16, one_int16), oc0123_int32); - oc4567_int32 = _mm256_add_epi32(_mm256_madd_epi16(oc4567_int16, one_int16), oc4567_int32); - } - - auto oc0426_int32 = _mm256_unpacklo_epi32(oc0123_int32, oc4567_int32); - auto oc1537_int32 = _mm256_unpackhi_epi32(oc0123_int32, oc4567_int32); - auto tmp0 = _mm256_unpacklo_epi32(oc0426_int32, oc1537_int32); // 01452367 - auto tmp1 = _mm256_unpackhi_epi32(oc0426_int32, oc1537_int32); // 01452367 - auto tmp2 = _mm256_add_epi32(tmp0, tmp1); // 01452367 - auto oc0145 = _mm256_extractf128_si256(tmp2, 0); - auto oc2367 = 
_mm256_extractf128_si256(tmp2, 1); - auto oc0123 = _mm_unpacklo_epi64(oc0145, oc2367); - auto oc4567 = _mm_unpackhi_epi64(oc0145, oc2367); - - auto sum8 = _mm256_set_m128i(oc4567, oc0123); - - __m256 f0 = _mm256_cvtepi32_ps(sum8); - __m256 fs = _mm256_mul_ps(_mm256_mul_ps(f0, scaleValue), alphaValue); - fs = _mm256_add_ps(biasValue, fs); - _mm256_storeu_ps(dstX, fs); - - } - } -} -void _AVX_MNNGemmHybridInt8(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, - size_t dst_depth_quad, size_t realSize, const float** param) { - int pack = 8; - size_t weight_step = src_depth_quad * pack * pack; - const float* alpha_ptr = param[0]; - const float* zero_ptr = param[1]; - const float* bias_ptr = param[2]; - const float* sums_ptr = param[3]; - const float* scale_ptr = param[4]; - for (int ci = 0; ci < dst_depth_quad; ++ci) { - float* dstZ = C + ci * pack * realSize; - const int8_t* weight = B + ci * weight_step; - auto alpha = alpha_ptr + ci * pack; - auto zero = zero_ptr + ci * pack; - auto bias = bias_ptr + ci * pack; - __m256 alphaValue = _mm256_load_ps(alpha); - for (int j = 0; j < realSize; ++j) { - const float* sums = sums_ptr + j; - const float* scale = scale_ptr + j; - float* dstX = dstZ + j * pack; - __m256 scaleValue = _mm256_set1_ps(scale[0]); - __m256 biasValue = _mm256_add_ps(_mm256_load_ps(bias), _mm256_mul_ps(_mm256_load_ps(zero), _mm256_set1_ps(sums[0]))); - const int8_t* srcBatch = A + j * pack; - auto oc0_and_1 = _mm256_set1_epi32(0); - auto oc2_and_3 = _mm256_set1_epi32(0); - auto oc4_and_5 = _mm256_set1_epi32(0); - auto oc6_and_7 = _mm256_set1_epi32(0); - for (int k = 0; k < src_depth_quad; ++k) { - const int8_t* srcZ = srcBatch + k * pack * realSize; - const int8_t* weightZ = weight + k * pack * pack; - auto w0 = _mm_loadu_si128((__m128i const*)weightZ); // w0-1 - auto w1 = _mm_loadu_si128((__m128i const*)(weightZ + 16)); - auto w2 = _mm_loadu_si128((__m128i const*)(weightZ + 16 * 2)); - auto w3 = _mm_loadu_si128((__m128i const*)(weightZ + 16 * 3)); - auto w0_16= _mm256_cvtepi8_epi16(w0); //16xint16_t - auto w1_16= _mm256_cvtepi8_epi16(w1); - auto w2_16= _mm256_cvtepi8_epi16(w2); - auto w3_16= _mm256_cvtepi8_epi16(w3); - auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)srcZ + 0)); - auto s1 = _mm_castps_si128(_mm_broadcast_ss((float*)srcZ + 1)); - auto s0_16 = _mm256_cvtepi8_epi16(s0); - auto s1_16 = _mm256_cvtepi8_epi16(s1); - auto S_int16 = _mm256_unpacklo_epi64(s0_16, s1_16); - oc0_and_1 = _mm256_add_epi32(oc0_and_1, _mm256_madd_epi16(S_int16, w0_16)); - oc2_and_3 = _mm256_add_epi32(oc2_and_3, _mm256_madd_epi16(S_int16, w1_16)); - oc4_and_5 = _mm256_add_epi32(oc4_and_5, _mm256_madd_epi16(S_int16, w2_16)); - oc6_and_7 = _mm256_add_epi32(oc6_and_7, _mm256_madd_epi16(S_int16, w3_16)); - } - auto oc_02021313_lo = _mm256_unpacklo_epi32(oc0_and_1, oc2_and_3); - auto oc_02021313_hi = _mm256_unpackhi_epi32(oc0_and_1, oc2_and_3); - auto oc_46465757_lo = _mm256_unpacklo_epi32(oc4_and_5, oc6_and_7); - auto oc_46465757_hi = _mm256_unpackhi_epi32(oc4_and_5, oc6_and_7); - auto oc_02021313 = _mm256_add_epi32(oc_02021313_lo, oc_02021313_hi); - auto oc_46465757 = _mm256_add_epi32(oc_46465757_lo, oc_46465757_hi); - auto oc_04261537_lo = _mm256_unpacklo_epi32(oc_02021313, oc_46465757); - auto oc_04261537_hi = _mm256_unpackhi_epi32(oc_02021313, oc_46465757); - auto oc_04261537 = _mm256_add_epi32(oc_04261537_lo, oc_04261537_hi); - auto oc_0426 = _mm256_extractf128_si256(oc_04261537, 0); - auto oc_1537 = _mm256_extractf128_si256(oc_04261537, 1); - auto oc_0145 = 
_mm_unpacklo_epi32(oc_0426, oc_1537); - auto oc_2367 = _mm_unpackhi_epi32(oc_0426, oc_1537); - auto oc_0123 = _mm_unpacklo_epi64(oc_0145, oc_2367); - auto oc_4567 = _mm_unpackhi_epi64(oc_0145, oc_2367); - auto sum8 = _mm256_set_m128i(oc_4567, oc_0123); - __m256 f0 = _mm256_cvtepi32_ps(sum8); - __m256 fs = _mm256_mul_ps(_mm256_mul_ps(f0, scaleValue), alphaValue); - fs = _mm256_add_ps(biasValue, fs); - _mm256_storeu_ps(dstX, fs); - } - } -} - void _AVX_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) { // source: (ic/8, N, 8) auto srcStep = pack * realSize; @@ -236,40 +96,6 @@ void _AVX_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_qua } } -void _AVX_MNNDynamicQuantFP32(const float* src, int8_t* dst, const float* scale, float* sum, size_t src_depth_quad, size_t realSize, int pack) { - // AVX: pack=8 - __m256 zero = _mm256_setzero_ps(); - __m256 plus = _mm256_set1_ps(0.5f); - __m256 minus = _mm256_set1_ps(-0.5f); - auto offset = _mm256_set1_epi32(128); - uint8_t* dstPtr = reinterpret_cast(dst); - float temp[8]; - for (int i = 0; i < realSize; ++i) { - __m256 scaleVal = _mm256_set1_ps(scale[i]); - __m256 acc = _mm256_setzero_ps(); - for (int c = 0; c < src_depth_quad; ++c) { - auto srcZ = src + c * pack * realSize + i * pack; - auto dstZ = dstPtr + c * pack * realSize + i * pack; - __m256 f0 = _mm256_loadu_ps(srcZ); - __m256 m0 = _mm256_mul_ps(f0, scaleVal); - __m256 mask = _mm256_cmp_ps(m0, zero, 1); - __m256 d0 = _mm256_blendv_ps(plus, minus, mask); - d0 = _mm256_add_ps(d0, m0); - __m256 round0 = _mm256_round_ps(d0, 3); - auto d0_epi32 = _mm256_cvtps_epi32(round0); // int32x8 - auto d0_epi16 = _mm256_packs_epi32(d0_epi32, _mm256_castps_si256(_mm256_permute2f128_ps(_mm256_castsi256_ps(d0_epi32), _mm256_castsi256_ps(d0_epi32), 1))); - // d0_epi32 = _mm256_packs_epi32(d0_epi32, d0_epi32); // int16x8 - d0_epi32 = _mm256_packs_epi16(d0_epi16, d0_epi16); // int8x8 - auto D0 = _mm_castsi128_ps(_mm256_extracti128_si256(d0_epi32, 0)); - _mm_storeu_ps(temp, D0); - ::memcpy(dstZ, temp, pack * sizeof(int8_t)); - acc = _mm256_add_ps(acc, round0); - } - _mm256_storeu_ps(temp, acc); - int sumVal = static_cast(temp[0] + temp[1] + temp[2] + temp[3] + temp[4] + temp[5] + temp[6] + temp[7]); - ((int32_t*)sum)[i] = sumVal; - } -} #endif void _AVX_MNNComputeMatMulForE_1(const float* A, const float* B, float* C, const float* biasPtr, const MatMulParam* param, size_t tId) { diff --git a/source/backend/cpu/x86_x64/avx/GemmCommon.cpp b/source/backend/cpu/x86_x64/avx/GemmCommon.cpp index 0753e7f8d..ed944f4c6 100644 --- a/source/backend/cpu/x86_x64/avx/GemmCommon.cpp +++ b/source/backend/cpu/x86_x64/avx/GemmCommon.cpp @@ -420,11 +420,11 @@ void _AVX_MNNPackedSparseMatMul(float* C, const float* A, const float* B, unsign void _AVX_MNNComputeScaleZeroScalar(float* source, float* min, float* max, size_t size) { int pack = 8; - int sizeDiv8 = UP_DIV(size, pack); - __m256 minVal = _mm256_loadu_ps(source); + int sizeDiv8 = size / pack; + __m256 minVal = _mm256_set1_ps(source[0]); __m256 maxVal = minVal; float maxArr[8], minArr[8]; - for (int i = 1; i < sizeDiv8; ++i) { + for (int i = 0; i < sizeDiv8; ++i) { auto src0 = source + pack * i; __m256 vecA = _mm256_loadu_ps(src0); __m256 maskMax = _mm256_cmp_ps(vecA, maxVal, 14); @@ -432,7 +432,6 @@ void _AVX_MNNComputeScaleZeroScalar(float* source, float* min, float* max, size_ maxVal = _mm256_blendv_ps(maxVal, vecA, maskMax); minVal = _mm256_blendv_ps(minVal, vecA, maskMin); } - _mm256_storeu_ps(maxArr, 
maxVal); _mm256_storeu_ps(minArr, minVal); float max_ = maxArr[0], min_ = minArr[0]; @@ -444,12 +443,11 @@ void _AVX_MNNComputeScaleZeroScalar(float* source, float* min, float* max, size_ min_ = minArr[k]; } } + for (int i = pack * sizeDiv8; i < size; ++i) { + max_ = std::max(max_, source[i]); + min_ = std::min(min_, source[i]); + } min[0] = min_; max[0] = max_; - // float range = max_ - min_; - // MNN_ASSERT(range != 0); - // *quantScale = 255.0f / range; - // *dequantScale = range / 255.0f; - // *zeroPoint = std::min(255.f, std::max(roundf(-(min_ * 255.f) / range), 0.f)) - 128.f; } diff --git a/source/backend/cpu/x86_x64/avx/GemmInt8.cpp b/source/backend/cpu/x86_x64/avx/GemmInt8.cpp index 18c4422d1..1a6b60746 100644 --- a/source/backend/cpu/x86_x64/avx/GemmInt8.cpp +++ b/source/backend/cpu/x86_x64/avx/GemmInt8.cpp @@ -51,6 +51,450 @@ auto d##i = _mm_add_epi32(d##i##0, d##i##1); #define COMPUTE(u, v)\ D##u##v = _mm256_add_epi32(D##u##v, _mm256_madd_epi16(W##u, S##v)); +#define LOAD_INT4_TO_INT8 \ +auto w_int4 = _mm_loadu_si128((__m128i const*)weight_sz);\ +auto w_int4_high = _mm_and_si128(mask, _mm_srli_epi16(w_int4, 4));\ +auto w_int4_low = _mm_and_si128(mask, w_int4);\ +auto w_0 = _mm_unpacklo_epi8(w_int4_high, w_int4_low);\ +auto w_1 = _mm_unpackhi_epi8(w_int4_high, w_int4_low); + +void _AVX_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) { + MNN_ASSERT(post->useInt8==0); + const auto dst_step_tmp = dst_step / sizeof(int8_t); + auto zero128 = _mm256_set1_ps(0.0f); + auto minValue = _mm256_set1_ps(post->minValue); + auto maxValue = _mm256_set1_ps(post->maxValue); + auto offset = _mm256_set1_epi32(128); + __m256 fp32min, fp32max; + if (post->fp32minmax) { + fp32min = _mm256_set1_ps((post->fp32minmax)[0]); + fp32max = _mm256_set1_ps((post->fp32minmax)[1]); + } + int blockNum = post->blockNum; + const float* biasPtr = nullptr; + if (post->biasFloat) { + biasPtr = post->biasFloat; + } + + int weight_step_Z = 0.5 * blockNum * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); + int weight_step_Y = 0.5 * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); + const __m128i mask = _mm_set1_epi8(0xf); + + auto srcKernelSumPtr = post->srcKernelSum; + __m256 kernelSum0 = _mm256_setzero_ps(); + __m256 kernelSum1 = _mm256_setzero_ps(); + __m256 kernelSum2 = _mm256_setzero_ps(); + __m256 kernelSum3 = _mm256_setzero_ps(); + if (GEMMINT8_AVX2_E == realDst) { + kernelSum0 = _mm256_set1_ps(post->srcKernelSum[0]); + kernelSum1 = _mm256_set1_ps(post->srcKernelSum[1]); + kernelSum2 = _mm256_set1_ps(post->srcKernelSum[2]); + kernelSum3 = _mm256_set1_ps(post->srcKernelSum[3]); + } else { + kernelSum0 = _mm256_set1_ps(post->srcKernelSum[0]); + if (realDst > 1) { + kernelSum1 = _mm256_set1_ps(post->srcKernelSum[1]); + } + if (realDst > 2) { + kernelSum2 = _mm256_set1_ps(post->srcKernelSum[2]); + } + } + auto f128 = _mm256_set1_ps(128.f); + __m256 extrascale0 = _mm256_setzero_ps(); + __m256 extrascale1 = _mm256_setzero_ps(); + __m256 extrascale2 = _mm256_setzero_ps(); + __m256 extrascale3 = _mm256_setzero_ps(); + if (post->extraScale) { + if (GEMMINT8_AVX2_E == realDst) { + extrascale0 = _mm256_set1_ps(post->extraScale[0]); + extrascale1 = _mm256_set1_ps(post->extraScale[1]); + extrascale2 = _mm256_set1_ps(post->extraScale[2]); + extrascale3 = _mm256_set1_ps(post->extraScale[3]); + } else { + extrascale0 = _mm256_set1_ps(post->extraScale[0]); + if (realDst > 1) { + 
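
The W4 kernel introduced in this hunk packs two 4-bit weights per byte, which is why weight_step_Z and weight_step_Y above carry the 0.5 factor relative to the int8 layout. LOAD_INT4_TO_INT8 splits every byte of the 16-byte load into its high and low nibble and interleaves them back into weight order before the usual int16 widening. A scalar sketch of the same unpack, under the assumption that the nibbles stay in [0, 15] and the weight zero point is compensated later through post->weightQuanBias (unpack_int4_block is an illustrative name, not part of the patch):

static inline void unpack_int4_block(const uint8_t packed[16], int8_t out[32]) {
    for (int i = 0; i < 16; ++i) {
        out[2 * i + 0] = (int8_t)((packed[i] >> 4) & 0x0F); // high nibble first
        out[2 * i + 1] = (int8_t)( packed[i]       & 0x0F); // then low nibble
    }
}

This matches the _mm_unpacklo_epi8/_mm_unpackhi_epi8 pair in the macro: w_0 holds the first eight (high, low) pairs, w_1 the remaining eight.
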
extrascale1 = _mm256_set1_ps(post->extraScale[1]); + } + if (realDst > 2) { + extrascale2 = _mm256_set1_ps(post->extraScale[2]); + } + } + } + //printf("e=%d, sz=%d, dz=%d\n", realDst, src_depth_quad, dst_depth_quad); + if (GEMMINT8_AVX2_E == realDst) { + for (int dz = 0; dz < dst_depth_quad; ++dz) { + const auto weight_dz = weight + dz * weight_step_Z; + const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; + const float* scale_dz = post->scale + dz * AVX2_PACKINT8; + auto dst_z = dst + dz * dst_step_tmp; + const auto src_x = src; + auto dst_x = dst_z; + __m256i D00 = _mm256_set1_epi32(0); + __m256i D01 = _mm256_set1_epi32(0); + __m256i D02 = _mm256_set1_epi32(0); + __m256i D03 = _mm256_set1_epi32(0); + __m256i D10 = _mm256_set1_epi32(0); + __m256i D11 = _mm256_set1_epi32(0); + __m256i D12 = _mm256_set1_epi32(0); + __m256i D13 = _mm256_set1_epi32(0); + + for (int sz = 0; sz < src_depth_quad; ++sz) { + const auto weight_sz = weight_dz + sz * weight_step_Y; + const auto src_z = src_x + sz * GEMMINT8_AVX2_L * GEMMINT8_AVX2_E; + LOAD_INT4_TO_INT8; + auto W0 = _mm256_cvtepi8_epi16(w_0); + auto W1 = _mm256_cvtepi8_epi16(w_1); + + auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0)); + auto s1 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 1)); + auto s2 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 2)); + auto s3 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 3)); + auto S0 = _mm256_cvtepu8_epi16(s0); + auto S1 = _mm256_cvtepu8_epi16(s1); + auto S2 = _mm256_cvtepu8_epi16(s2); + auto S3 = _mm256_cvtepu8_epi16(s3); + + COMPUTE(0, 0); + COMPUTE(1, 0); + COMPUTE(0, 1); + COMPUTE(1, 1); + COMPUTE(0, 2); + COMPUTE(1, 2); + COMPUTE(0, 3); + COMPUTE(1, 3); + } + auto D0 = NORMAL_HADD(D00, D10); + auto D1 = NORMAL_HADD(D01, D11); + auto D2 = NORMAL_HADD(D02, D12); + auto D3 = NORMAL_HADD(D03, D13); + auto scaleValue = _mm256_loadu_ps(scale_dz); + auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz); + + auto f0 = _mm256_cvtepi32_ps(D0); + auto f1 = _mm256_cvtepi32_ps(D1); + auto f2 = _mm256_cvtepi32_ps(D2); + auto f3 = _mm256_cvtepi32_ps(D3); + // x_kernelSum x w_quantZero + auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimemsion first + auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // ..second + auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // .. 
third + auto xy0_3 = _mm256_mul_ps(kernelSum3, weightBiasValue); // ..fourth + f0 = _mm256_mul_ps(f0, scaleValue); + f1 = _mm256_mul_ps(f1, scaleValue); + f2 = _mm256_mul_ps(f2, scaleValue); + f3 = _mm256_mul_ps(f3, scaleValue); + if (post->extraScale) { + f0 = _mm256_mul_ps(f0, extrascale0); + f1 = _mm256_mul_ps(f1, extrascale1); + f2 = _mm256_mul_ps(f2, extrascale2); + f3 = _mm256_mul_ps(f3, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extraB = post->extraBias + dz * AVX2_PACKINT8; + auto extrabias = _mm256_loadu_ps(extraB); + extrabias = _mm256_mul_ps(f128, extrabias); + auto extrabias0 = _mm256_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm256_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm256_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm256_mul_ps(extrabias, extrascale3); + f0 = _mm256_sub_ps(f0, extrabias0); + f1 = _mm256_sub_ps(f1, extrabias1); + f2 = _mm256_sub_ps(f2, extrabias2); + f3 = _mm256_sub_ps(f3, extrabias3); + } + } + f0 = _mm256_add_ps(f0, xy0_0); + f1 = _mm256_add_ps(f1, xy0_1); + f2 = _mm256_add_ps(f2, xy0_2); + f3 = _mm256_add_ps(f3, xy0_3); + + if (nullptr != biasPtr) { + const auto bias_dz = biasPtr + dz * AVX2_PACKINT8; + auto biasValue = _mm256_loadu_ps(bias_dz); + f0 = _mm256_add_ps(f0, biasValue); + f1 = _mm256_add_ps(f1, biasValue); + f2 = _mm256_add_ps(f2, biasValue); + f3 = _mm256_add_ps(f3, biasValue); + } else { + auto dstv0 = _mm256_loadu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8); + auto dstv1 = _mm256_loadu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8); + auto dstv2 = _mm256_loadu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8); + auto dstv3 = _mm256_loadu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8); + f0 = _mm256_add_ps(f0, dstv0); + f1 = _mm256_add_ps(f1, dstv1); + f2 = _mm256_add_ps(f2, dstv2); + f3 = _mm256_add_ps(f3, dstv3); + } + if (post->fp32minmax) { + f0 = _mm256_min_ps(f0, fp32max); + f1 = _mm256_min_ps(f1, fp32max); + f2 = _mm256_min_ps(f2, fp32max); + f3 = _mm256_min_ps(f3, fp32max); + f0 = _mm256_max_ps(f0, fp32min); + f1 = _mm256_max_ps(f1, fp32min); + f2 = _mm256_max_ps(f2, fp32min); + f3 = _mm256_max_ps(f3, fp32min); + } + _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0); + _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1); + _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2); + _mm256_storeu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8, f3); + + } + return; + } + if (3 == realDst) { + for (int dz = 0; dz < dst_depth_quad; ++dz) { + const auto weight_dz = weight + dz * weight_step_Z; + const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; + const float* scale_dz = post->scale + dz * AVX2_PACKINT8; + auto dst_z = dst + dz * dst_step_tmp; + const auto src_x = src; + auto dst_x = dst_z; + __m256i D00 = _mm256_set1_epi32(0); + __m256i D01 = _mm256_set1_epi32(0); + __m256i D02 = _mm256_set1_epi32(0); + + __m256i D10 = _mm256_set1_epi32(0); + __m256i D11 = _mm256_set1_epi32(0); + __m256i D12 = _mm256_set1_epi32(0); + + for (int sz = 0; sz < src_depth_quad; ++sz) { + const auto weight_sz = weight_dz + sz * weight_step_Y; + const auto src_z = src_x + sz * GEMMINT8_AVX2_L * GEMMINT8_AVX2_E; + LOAD_INT4_TO_INT8; + + auto W0 = _mm256_cvtepi8_epi16(w_0); + auto W1 = _mm256_cvtepi8_epi16(w_1); + + auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0)); + auto s1 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 1)); + auto s2 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 2)); + auto S0 = _mm256_cvtepu8_epi16(s0); + auto S1 = _mm256_cvtepu8_epi16(s1); + auto S2 = 
_mm256_cvtepu8_epi16(s2); + + COMPUTE(0, 0); + COMPUTE(1, 0); + COMPUTE(0, 1); + COMPUTE(1, 1); + COMPUTE(0, 2); + COMPUTE(1, 2); + } + auto D0 = NORMAL_HADD(D00, D10); + auto D1 = NORMAL_HADD(D01, D11); + auto D2 = NORMAL_HADD(D02, D12); + auto scaleValue = _mm256_loadu_ps(scale_dz); + auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz); + + auto f0 = _mm256_cvtepi32_ps(D0); + auto f1 = _mm256_cvtepi32_ps(D1); + auto f2 = _mm256_cvtepi32_ps(D2); + // x_kernelSum x w_quantZero + auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimemsion first + auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // ..second + auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // .. third + f0 = _mm256_mul_ps(f0, scaleValue); + f1 = _mm256_mul_ps(f1, scaleValue); + f2 = _mm256_mul_ps(f2, scaleValue); + if (post->extraScale) { + f0 = _mm256_mul_ps(f0, extrascale0); + f1 = _mm256_mul_ps(f1, extrascale1); + f2 = _mm256_mul_ps(f2, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extraB = post->extraBias + dz * AVX2_PACKINT8; + auto extrabias = _mm256_loadu_ps(extraB); + extrabias = _mm256_mul_ps(f128, extrabias); + auto extrabias0 = _mm256_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm256_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm256_mul_ps(extrabias, extrascale2); + f0 = _mm256_sub_ps(f0, extrabias0); + f1 = _mm256_sub_ps(f1, extrabias1); + f2 = _mm256_sub_ps(f2, extrabias2); + } + } + f0 = _mm256_add_ps(f0, xy0_0); + f1 = _mm256_add_ps(f1, xy0_1); + f2 = _mm256_add_ps(f2, xy0_2); + + if (nullptr != biasPtr) { + const auto bias_dz = biasPtr + dz * AVX2_PACKINT8; + auto biasValue = _mm256_loadu_ps(bias_dz); + f0 = _mm256_add_ps(f0, biasValue); + f1 = _mm256_add_ps(f1, biasValue); + f2 = _mm256_add_ps(f2, biasValue); + } else { + auto dstv0 = _mm256_loadu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8); + auto dstv1 = _mm256_loadu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8); + auto dstv2 = _mm256_loadu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8); + f0 = _mm256_add_ps(f0, dstv0); + f1 = _mm256_add_ps(f1, dstv1); + f2 = _mm256_add_ps(f2, dstv2); + } + if (post->fp32minmax) { + f0 = _mm256_min_ps(f0, fp32max); + f1 = _mm256_min_ps(f1, fp32max); + f2 = _mm256_min_ps(f2, fp32max); + f0 = _mm256_max_ps(f0, fp32min); + f1 = _mm256_max_ps(f1, fp32min); + f2 = _mm256_max_ps(f2, fp32min); + } + _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0); + _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1); + _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2); + + } + return; + } + if (2 == realDst) { + for (int dz = 0; dz < dst_depth_quad; ++dz) { + const auto weight_dz = weight + dz * weight_step_Z; + const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; + const float* scale_dz = post->scale + dz * AVX2_PACKINT8; + auto dst_z = dst + dz * dst_step_tmp; + const auto src_x = src; + auto dst_x = dst_z; + __m256i D00 = _mm256_set1_epi32(0); + __m256i D01 = _mm256_set1_epi32(0); + + __m256i D10 = _mm256_set1_epi32(0); + __m256i D11 = _mm256_set1_epi32(0); + + for (int sz = 0; sz < src_depth_quad; ++sz) { + const auto weight_sz = weight_dz + sz * weight_step_Y; + const auto src_z = src_x + sz * GEMMINT8_AVX2_L * GEMMINT8_AVX2_E; + LOAD_INT4_TO_INT8; + auto W0 = _mm256_cvtepi8_epi16(w_0); + auto W1 = _mm256_cvtepi8_epi16(w_1); + + auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0)); + auto s1 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 1)); + auto S0 = _mm256_cvtepu8_epi16(s0); + auto S1 = _mm256_cvtepu8_epi16(s1); + + 
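
The COMPUTE(u, v) macro used next is defined at the top of this hunk as D##u##v = _mm256_add_epi32(D##u##v, _mm256_madd_epi16(W##u, S##v)): each 256-bit madd produces eight int32 lanes, every lane being the sum of two adjacent int16 products, and that result is folded into the running accumulator. A scalar model of one such step (compute_step is an illustrative name; the mapping of lanes to output channels is fixed by the weight packing and is not spelled out here):

static inline void compute_step(int32_t D[8], const int16_t W[16], const int16_t S[16]) {
    for (int j = 0; j < 8; ++j) {
        D[j] += (int32_t)W[2 * j] * S[2 * j] + (int32_t)W[2 * j + 1] * S[2 * j + 1];
    }
}
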
COMPUTE(0, 0); + COMPUTE(1, 0); + COMPUTE(0, 1); + COMPUTE(1, 1); + } + auto D0 = NORMAL_HADD(D00, D10); + auto D1 = NORMAL_HADD(D01, D11); + auto scaleValue = _mm256_loadu_ps(scale_dz); + auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz); + + auto f0 = _mm256_cvtepi32_ps(D0); + auto f1 = _mm256_cvtepi32_ps(D1); + // x_kernelSum x w_quantZero + auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimemsion first + auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // ..second + f0 = _mm256_mul_ps(f0, scaleValue); + f1 = _mm256_mul_ps(f1, scaleValue); + if (post->extraScale) { + f0 = _mm256_mul_ps(f0, extrascale0); + f1 = _mm256_mul_ps(f1, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extraB = post->extraBias + dz * AVX2_PACKINT8; + auto extrabias = _mm256_loadu_ps(extraB); + extrabias = _mm256_mul_ps(f128, extrabias); + auto extrabias0 = _mm256_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm256_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm256_mul_ps(extrabias, extrascale2); + f0 = _mm256_sub_ps(f0, extrabias0); + f1 = _mm256_sub_ps(f1, extrabias1); + } + } + f0 = _mm256_add_ps(f0, xy0_0); + f1 = _mm256_add_ps(f1, xy0_1); + + if (nullptr != biasPtr) { + const auto bias_dz = biasPtr + dz * AVX2_PACKINT8; + auto biasValue = _mm256_loadu_ps(bias_dz); + f0 = _mm256_add_ps(f0, biasValue); + f1 = _mm256_add_ps(f1, biasValue); + } else { + auto dstv0 = _mm256_loadu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8); + auto dstv1 = _mm256_loadu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8); + f0 = _mm256_add_ps(f0, dstv0); + f1 = _mm256_add_ps(f1, dstv1); + } + if (post->fp32minmax) { + f0 = _mm256_min_ps(f0, fp32max); + f1 = _mm256_min_ps(f1, fp32max); + f0 = _mm256_max_ps(f0, fp32min); + f1 = _mm256_max_ps(f1, fp32min); + } + _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0); + _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1); + + } + return; + } + if (1 == realDst) { + for (int dz = 0; dz < dst_depth_quad; ++dz) { + const auto weight_dz = weight + dz * weight_step_Z; + const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; + const float* scale_dz = post->scale + dz * AVX2_PACKINT8; + auto dst_z = dst + dz * dst_step_tmp; + const auto src_x = src; + auto dst_x = dst_z; + __m256i D00 = _mm256_set1_epi32(0); + __m256i D10 = _mm256_set1_epi32(0); + + for (int sz = 0; sz < src_depth_quad; ++sz) { + const auto weight_sz = weight_dz + sz * weight_step_Y; + const auto src_z = src_x + sz * GEMMINT8_AVX2_L * GEMMINT8_AVX2_E; + LOAD_INT4_TO_INT8; + auto W0 = _mm256_cvtepi8_epi16(w_0); + auto W1 = _mm256_cvtepi8_epi16(w_1); + + auto s0 = _mm_castps_si128(_mm_broadcast_ss((float*)src_z + 0)); + auto S0 = _mm256_cvtepu8_epi16(s0); + + COMPUTE(0, 0); + COMPUTE(1, 0); + } + auto D0 = NORMAL_HADD(D00, D10); + auto scaleValue = _mm256_loadu_ps(scale_dz); + auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz); + + auto f0 = _mm256_cvtepi32_ps(D0); + // x_kernelSum x w_quantZero + auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimemsion first + f0 = _mm256_mul_ps(f0, scaleValue); + if (post->extraScale) { + f0 = _mm256_mul_ps(f0, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extraB = post->extraBias + dz * AVX2_PACKINT8; + auto extrabias = _mm256_loadu_ps(extraB); + extrabias = _mm256_mul_ps(f128, extrabias); + auto extrabias0 = _mm256_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm256_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm256_mul_ps(extrabias, extrascale2); + f0 = 
_mm256_sub_ps(f0, extrabias0); + } + } + f0 = _mm256_add_ps(f0, xy0_0); + + if (nullptr != biasPtr) { + const auto bias_dz = biasPtr + dz * AVX2_PACKINT8; + auto biasValue = _mm256_loadu_ps(bias_dz); + f0 = _mm256_add_ps(f0, biasValue); + } else { + auto dstv = _mm256_loadu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8); + f0 = _mm256_add_ps(f0, dstv); + } + if (post->fp32minmax) { + f0 = _mm256_min_ps(f0, fp32max); + f0 = _mm256_max_ps(f0, fp32min); + } + + _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0); + + } + return; + } + +} + void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) { const auto dst_step_tmp = dst_step / sizeof(int8_t); auto zero128 = _mm256_set1_ps(0.0f); @@ -59,11 +503,61 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons auto plus = _mm256_set1_ps(0.5f); auto minus = _mm256_set1_ps(-0.5f); auto offset = _mm256_set1_epi32(128); + __m256 fp32min, fp32max; + if (0 == post->useInt8 && post->fp32minmax) { + fp32min = _mm256_set1_ps((post->fp32minmax)[0]); + fp32max = _mm256_set1_ps((post->fp32minmax)[1]); + } + int blockNum = post->blockNum; + const float* biasPtr = nullptr; + if (post->biasFloat) { + biasPtr = post->biasFloat; + } + auto srcKernelSumPtr = post->srcKernelSum; + __m256 kernelSum0 = _mm256_setzero_ps(); + __m256 kernelSum1 = _mm256_setzero_ps(); + __m256 kernelSum2 = _mm256_setzero_ps(); + __m256 kernelSum3 = _mm256_setzero_ps(); + if (GEMMINT8_AVX2_E == realDst) { + kernelSum0 = _mm256_set1_ps(post->srcKernelSum[0]); + kernelSum1 = _mm256_set1_ps(post->srcKernelSum[1]); + kernelSum2 = _mm256_set1_ps(post->srcKernelSum[2]); + kernelSum3 = _mm256_set1_ps(post->srcKernelSum[3]); + } else { + kernelSum0 = _mm256_set1_ps(post->srcKernelSum[0]); + if (realDst > 1) { + kernelSum1 = _mm256_set1_ps(post->srcKernelSum[1]); + } + if (realDst > 2) { + kernelSum2 = _mm256_set1_ps(post->srcKernelSum[2]); + } + } + auto f128 = _mm256_set1_ps(128.f); + __m256 extrascale0 = _mm256_setzero_ps(); + __m256 extrascale1 = _mm256_setzero_ps(); + __m256 extrascale2 = _mm256_setzero_ps(); + __m256 extrascale3 = _mm256_setzero_ps(); + if (post->extraScale) { + if (GEMMINT8_AVX2_E == realDst) { + extrascale0 = _mm256_set1_ps(post->extraScale[0]); + extrascale1 = _mm256_set1_ps(post->extraScale[1]); + extrascale2 = _mm256_set1_ps(post->extraScale[2]); + extrascale3 = _mm256_set1_ps(post->extraScale[3]); + } else { + extrascale0 = _mm256_set1_ps(post->extraScale[0]); + if (realDst > 1) { + extrascale1 = _mm256_set1_ps(post->extraScale[1]); + } + if (realDst > 2) { + extrascale2 = _mm256_set1_ps(post->extraScale[2]); + } + } + } //printf("e=%d, sz=%d, dz=%d\n", realDst, src_depth_quad, dst_depth_quad); if (GEMMINT8_AVX2_E == realDst) { for (int dz = 0; dz < dst_depth_quad; ++dz) { - const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); - const auto bias_dz = post->bias + dz * AVX2_PACKINT8; + const auto weight_dz = weight + dz * blockNum * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); + const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const float* scale_dz = post->scale + dz * AVX2_PACKINT8; auto dst_z = dst + dz * dst_step_tmp; const auto src_x = src; @@ -107,40 +601,92 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons auto D1 = NORMAL_HADD(D01, D11); auto D2 = NORMAL_HADD(D02, D12); auto D3 = 
NORMAL_HADD(D03, D13); - - auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz)); - D0 = _mm256_add_epi32(D0, biasValue0); - D1 = _mm256_add_epi32(D1, biasValue0); - D2 = _mm256_add_epi32(D2, biasValue0); - D3 = _mm256_add_epi32(D3, biasValue0); - auto scaleValue = _mm256_loadu_ps(scale_dz); + auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz); + auto f0 = _mm256_cvtepi32_ps(D0); auto f1 = _mm256_cvtepi32_ps(D1); auto f2 = _mm256_cvtepi32_ps(D2); auto f3 = _mm256_cvtepi32_ps(D3); + // x_kernelSum x w_quantZero + auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimemsion first + auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // ..second + auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // .. third + auto xy0_3 = _mm256_mul_ps(kernelSum3, weightBiasValue); // ..fourth f0 = _mm256_mul_ps(f0, scaleValue); f1 = _mm256_mul_ps(f1, scaleValue); f2 = _mm256_mul_ps(f2, scaleValue); f3 = _mm256_mul_ps(f3, scaleValue); - if (post->useInt8 == 0) { - _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0); - _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1); - _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2); - _mm256_storeu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8, f3); - } else { + if (post->extraScale) { + f0 = _mm256_mul_ps(f0, extrascale0); + f1 = _mm256_mul_ps(f1, extrascale1); + f2 = _mm256_mul_ps(f2, extrascale2); + f3 = _mm256_mul_ps(f3, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extraB = post->extraBias + dz * AVX2_PACKINT8; + auto extrabias = _mm256_loadu_ps(extraB); + extrabias = _mm256_mul_ps(f128, extrabias); + auto extrabias0 = _mm256_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm256_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm256_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm256_mul_ps(extrabias, extrascale3); + f0 = _mm256_sub_ps(f0, extrabias0); + f1 = _mm256_sub_ps(f1, extrabias1); + f2 = _mm256_sub_ps(f2, extrabias2); + f3 = _mm256_sub_ps(f3, extrabias3); + } + } + f0 = _mm256_add_ps(f0, xy0_0); + f1 = _mm256_add_ps(f1, xy0_1); + f2 = _mm256_add_ps(f2, xy0_2); + f3 = _mm256_add_ps(f3, xy0_3); + if (nullptr != biasPtr) { + const auto bias_dz = biasPtr + dz * AVX2_PACKINT8; + auto biasValue = _mm256_loadu_ps(bias_dz); + f0 = _mm256_add_ps(f0, biasValue); + f1 = _mm256_add_ps(f1, biasValue); + f2 = _mm256_add_ps(f2, biasValue); + f3 = _mm256_add_ps(f3, biasValue); + } + if (post->useInt8 == 1) { POSTTREAT(0); POSTTREAT(1); POSTTREAT(2); POSTTREAT(3); + } else { + if (nullptr == biasPtr) { + auto dstv0 = _mm256_loadu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8); + auto dstv1 = _mm256_loadu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8); + auto dstv2 = _mm256_loadu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8); + auto dstv3 = _mm256_loadu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8); + + f0 = _mm256_add_ps(f0, dstv0); + f1 = _mm256_add_ps(f1, dstv1); + f2 = _mm256_add_ps(f2, dstv2); + f3 = _mm256_add_ps(f3, dstv3); + } + if (post->fp32minmax) { + f0 = _mm256_min_ps(f0, fp32max); + f1 = _mm256_min_ps(f1, fp32max); + f2 = _mm256_min_ps(f2, fp32max); + f3 = _mm256_min_ps(f3, fp32max); + f0 = _mm256_max_ps(f0, fp32min); + f1 = _mm256_max_ps(f1, fp32min); + f2 = _mm256_max_ps(f2, fp32min); + f3 = _mm256_max_ps(f3, fp32min); + } + _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0); + _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1); + _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2); + _mm256_storeu_ps(((float*)dst_x) + 3 * AVX2_PACKINT8, f3); } } return; } if (3 == realDst) { 
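
For the float output path (useInt8 == 0), the 4-column branch above reduces per element to the scalar computation below; the same ordering (dequant scale, optional per-column extraScale and 128*extraBias correction, kernelSum*weightQuanBias compensation, bias or accumulation, fp32 clamp) is repeated in the 3/2/1-column branches that follow. Names here are illustrative; oc is taken as the absolute output channel (dz * AVX2_PACKINT8 + lane) and x as the column index. When useInt8 == 1 the kernel still falls back to the existing POSTTREAT requantization instead.

static inline float mnn_gemm_post_one(int32_t acc, int oc, int x,
                                      const QuanPostTreatParameters* post,
                                      float previousDst /* current dst value, used when biasFloat is null */) {
    const float* biasPtr = post->biasFloat;
    float f = (float)acc * post->scale[oc];
    if (post->extraScale) {                                   // per-column (batch) quant scale
        f *= post->extraScale[x];
        if (post->extraBias && biasPtr) {
            f -= 128.f * post->extraBias[oc] * post->extraScale[x];
        }
    }
    f += post->srcKernelSum[x] * post->weightQuanBias[oc];    // x_kernelSum * w_quantZero
    if (biasPtr) {
        f += biasPtr[oc];
    } else {
        f += previousDst;                                     // accumulate onto existing output
    }
    if (post->fp32minmax) {
        f = f > post->fp32minmax[1] ? post->fp32minmax[1] : f;
        f = f < post->fp32minmax[0] ? post->fp32minmax[0] : f;
    }
    return f;
}
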
for (int dz = 0; dz < dst_depth_quad; ++dz) { - const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); - const auto bias_dz = post->bias + dz * AVX2_PACKINT8; + const auto weight_dz = weight + dz * blockNum * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); + const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const float* scale_dz = post->scale + dz * AVX2_PACKINT8; auto dst_z = dst + dz * dst_step_tmp; const auto src_x = src; @@ -178,35 +724,77 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons auto D0 = NORMAL_HADD(D00, D10); auto D1 = NORMAL_HADD(D01, D11); auto D2 = NORMAL_HADD(D02, D12); - - auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz)); - D0 = _mm256_add_epi32(D0, biasValue0); - D1 = _mm256_add_epi32(D1, biasValue0); - D2 = _mm256_add_epi32(D2, biasValue0); - auto scaleValue = _mm256_loadu_ps(scale_dz); + auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz); + auto f0 = _mm256_cvtepi32_ps(D0); auto f1 = _mm256_cvtepi32_ps(D1); auto f2 = _mm256_cvtepi32_ps(D2); + // x_kernelSum x w_quantZero + auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimemsion first + auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // ..second + auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // .. third f0 = _mm256_mul_ps(f0, scaleValue); f1 = _mm256_mul_ps(f1, scaleValue); f2 = _mm256_mul_ps(f2, scaleValue); - if (post->useInt8 == 0) { - _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0); - _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1); - _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2); - } else { + if (post->extraScale) { + f0 = _mm256_mul_ps(f0, extrascale0); + f1 = _mm256_mul_ps(f1, extrascale1); + f2 = _mm256_mul_ps(f2, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extraB = post->extraBias + dz * AVX2_PACKINT8; + auto extrabias = _mm256_loadu_ps(extraB); + extrabias = _mm256_mul_ps(f128, extrabias); + auto extrabias0 = _mm256_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm256_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm256_mul_ps(extrabias, extrascale2); + f0 = _mm256_sub_ps(f0, extrabias0); + f1 = _mm256_sub_ps(f1, extrabias1); + f2 = _mm256_sub_ps(f2, extrabias2); + } + } + f0 = _mm256_add_ps(f0, xy0_0); + f1 = _mm256_add_ps(f1, xy0_1); + f2 = _mm256_add_ps(f2, xy0_2); + if (nullptr != biasPtr) { + const auto bias_dz = biasPtr + dz * AVX2_PACKINT8; + auto biasValue = _mm256_loadu_ps(bias_dz); + f0 = _mm256_add_ps(f0, biasValue); + f1 = _mm256_add_ps(f1, biasValue); + f2 = _mm256_add_ps(f2, biasValue); + } + if (post->useInt8 == 1) { POSTTREAT(0); POSTTREAT(1); POSTTREAT(2); + } else { + if (nullptr == biasPtr) { + auto dstv0 = _mm256_loadu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8); + auto dstv1 = _mm256_loadu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8); + auto dstv2 = _mm256_loadu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8); + f0 = _mm256_add_ps(f0, dstv0); + f1 = _mm256_add_ps(f1, dstv1); + f2 = _mm256_add_ps(f2, dstv2); + } + if (post->fp32minmax) { + f0 = _mm256_min_ps(f0, fp32max); + f1 = _mm256_min_ps(f1, fp32max); + f2 = _mm256_min_ps(f2, fp32max); + f0 = _mm256_max_ps(f0, fp32min); + f1 = _mm256_max_ps(f1, fp32min); + f2 = _mm256_max_ps(f2, fp32min); + } + _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0); + _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1); + _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2); } } return; } if (2 == realDst) { for (int dz = 0; dz < 
dst_depth_quad; ++dz) { - const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); - const auto bias_dz = post->bias + dz * AVX2_PACKINT8; + const auto weight_dz = weight + dz * blockNum * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); + const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const float* scale_dz = post->scale + dz * AVX2_PACKINT8; auto dst_z = dst + dz * dst_step_tmp; const auto src_x = src; @@ -237,30 +825,64 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons } auto D0 = NORMAL_HADD(D00, D10); auto D1 = NORMAL_HADD(D01, D11); - - auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz)); - D0 = _mm256_add_epi32(D0, biasValue0); - D1 = _mm256_add_epi32(D1, biasValue0); - auto scaleValue = _mm256_loadu_ps(scale_dz); + auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz); + auto f0 = _mm256_cvtepi32_ps(D0); auto f1 = _mm256_cvtepi32_ps(D1); + // x_kernelSum x w_quantZero + auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimemsion first + auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // ..second f0 = _mm256_mul_ps(f0, scaleValue); f1 = _mm256_mul_ps(f1, scaleValue); - if (post->useInt8 == 0) { - _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0); - _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1); - } else { + if (post->extraScale) { + f0 = _mm256_mul_ps(f0, extrascale0); + f1 = _mm256_mul_ps(f1, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extraB = post->extraBias + dz * AVX2_PACKINT8; + auto extrabias = _mm256_loadu_ps(extraB); + extrabias = _mm256_mul_ps(f128, extrabias); + auto extrabias0 = _mm256_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm256_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm256_mul_ps(extrabias, extrascale2); + f0 = _mm256_sub_ps(f0, extrabias0); + f1 = _mm256_sub_ps(f1, extrabias1); + } + } + f0 = _mm256_add_ps(f0, xy0_0); + f1 = _mm256_add_ps(f1, xy0_1); + if (nullptr != biasPtr) { + const auto bias_dz = biasPtr + dz * AVX2_PACKINT8; + auto biasValue = _mm256_loadu_ps(bias_dz); + f0 = _mm256_add_ps(f0, biasValue); + f1 = _mm256_add_ps(f1, biasValue); + } + if (post->useInt8 == 1) { POSTTREAT(0); POSTTREAT(1); + } else { + if (nullptr == biasPtr) { + auto dstv0 = _mm256_loadu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8); + auto dstv1 = _mm256_loadu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8); + f0 = _mm256_add_ps(f0, dstv0); + f1 = _mm256_add_ps(f1, dstv1); + } + if (post->fp32minmax) { + f0 = _mm256_min_ps(f0, fp32max); + f1 = _mm256_min_ps(f1, fp32max); + f0 = _mm256_max_ps(f0, fp32min); + f1 = _mm256_max_ps(f1, fp32min); + } + _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0); + _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1); } } return; } if (1 == realDst) { for (int dz = 0; dz < dst_depth_quad; ++dz) { - const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); - const auto bias_dz = post->bias + dz * AVX2_PACKINT8; + const auto weight_dz = weight + dz * blockNum * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); + const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const float* scale_dz = post->scale + dz * AVX2_PACKINT8; auto dst_z = dst + dz * dst_step_tmp; const auto src_x = src; @@ -283,17 +905,43 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons COMPUTE(1, 0); } auto D0 = NORMAL_HADD(D00, D10); - - auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz)); - D0 = 
_mm256_add_epi32(D0, biasValue0); - auto scaleValue = _mm256_loadu_ps(scale_dz); + auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz); + auto f0 = _mm256_cvtepi32_ps(D0); + // x_kernelSum x w_quantZero + auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimemsion first f0 = _mm256_mul_ps(f0, scaleValue); - if (post->useInt8 == 0) { - _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0); - } else { + if (post->extraScale) { + f0 = _mm256_mul_ps(f0, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extraB = post->extraBias + dz * AVX2_PACKINT8; + auto extrabias = _mm256_loadu_ps(extraB); + extrabias = _mm256_mul_ps(f128, extrabias); + auto extrabias0 = _mm256_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm256_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm256_mul_ps(extrabias, extrascale2); + f0 = _mm256_sub_ps(f0, extrabias0); + } + } + f0 = _mm256_add_ps(f0, xy0_0); + if (nullptr != biasPtr) { + const auto bias_dz = biasPtr + dz * AVX2_PACKINT8; + auto biasValue = _mm256_loadu_ps(bias_dz); + f0 = _mm256_add_ps(f0, biasValue); + } + if (post->useInt8 == 1) { POSTTREAT(0); + } else { + if (nullptr == biasPtr) { + auto dstv0 = _mm256_loadu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8); + f0 = _mm256_add_ps(f0, dstv0); + } + if (post->fp32minmax) { + f0 = _mm256_min_ps(f0, fp32max); + f0 = _mm256_max_ps(f0, fp32min); + } + _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0); } } return; @@ -309,11 +957,36 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src, auto minus = _mm256_set1_ps(-0.5f); auto oneValue = _mm256_set1_epi16(1); auto offset = _mm256_set1_epi32(128); + __m256 fp32min, fp32max; + if (0 == post->useInt8) { + fp32min = _mm256_set1_ps((post->fp32minmax)[0]); + fp32max = _mm256_set1_ps((post->fp32minmax)[1]); + } + auto srcKernelSumPtr = post->srcKernelSum; + __m256 kernelSum0 = _mm256_setzero_ps(); + __m256 kernelSum1 = _mm256_setzero_ps(); + __m256 kernelSum2 = _mm256_setzero_ps(); + __m256 kernelSum3 = _mm256_setzero_ps(); + if (GEMMINT8_AVX2_E == realDst) { + kernelSum0 = _mm256_set1_ps(post->srcKernelSum[0]); + kernelSum1 = _mm256_set1_ps(post->srcKernelSum[1]); + kernelSum2 = _mm256_set1_ps(post->srcKernelSum[2]); + kernelSum3 = _mm256_set1_ps(post->srcKernelSum[3]); + } else { + kernelSum0 = _mm256_set1_ps(post->srcKernelSum[0]); + if (realDst > 1) { + kernelSum1 = _mm256_set1_ps(post->srcKernelSum[1]); + } + if (realDst > 2) { + kernelSum2 = _mm256_set1_ps(post->srcKernelSum[2]); + } + } //printf("e=%d, sz=%d, dz=%d\n", realDst, src_depth_quad, dst_depth_quad); if (GEMMINT8_AVX2_E == realDst) { for (int dz = 0; dz < dst_depth_quad; ++dz) { const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); - const auto bias_dz = post->bias + dz * AVX2_PACKINT8; + const auto bias_dz = post->biasFloat + dz * AVX2_PACKINT8; + const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const float* scale_dz = post->scale + dz * AVX2_PACKINT8; auto dst_z = dst + dz * dst_step_tmp; const auto src_x = src; @@ -344,22 +1017,45 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src, auto D2 = D02; auto D3 = D03; - auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz)); - D0 = _mm256_add_epi32(D0, biasValue0); - D1 = _mm256_add_epi32(D1, biasValue0); - D2 = _mm256_add_epi32(D2, biasValue0); - D3 = _mm256_add_epi32(D3, biasValue0); + // auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz)); + auto weightBiasValue = 
_mm256_loadu_ps((float*)weightBias_dz); + // D0 = _mm256_add_epi32(D0, biasValue0); + // D1 = _mm256_add_epi32(D1, biasValue0); + // D2 = _mm256_add_epi32(D2, biasValue0); + // D3 = _mm256_add_epi32(D3, biasValue0); auto scaleValue = _mm256_loadu_ps(scale_dz); auto f0 = _mm256_cvtepi32_ps(D0); auto f1 = _mm256_cvtepi32_ps(D1); auto f2 = _mm256_cvtepi32_ps(D2); auto f3 = _mm256_cvtepi32_ps(D3); + // x_kernelSum x w_quantZero + auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimemsion first + auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // ..second + auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // .. third + auto xy0_3 = _mm256_mul_ps(kernelSum3, weightBiasValue); // ..fourth f0 = _mm256_mul_ps(f0, scaleValue); f1 = _mm256_mul_ps(f1, scaleValue); f2 = _mm256_mul_ps(f2, scaleValue); f3 = _mm256_mul_ps(f3, scaleValue); + f0 = _mm256_add_ps(f0, xy0_0); + f1 = _mm256_add_ps(f1, xy0_1); + f2 = _mm256_add_ps(f2, xy0_2); + f3 = _mm256_add_ps(f3, xy0_3); + auto biasValue = _mm256_loadu_ps(bias_dz); + f0 = _mm256_add_ps(f0, biasValue); + f1 = _mm256_add_ps(f1, biasValue); + f2 = _mm256_add_ps(f2, biasValue); + f3 = _mm256_add_ps(f3, biasValue); if (post->useInt8 == 0) { + f0 = _mm256_min_ps(f0, fp32max); + f1 = _mm256_min_ps(f1, fp32max); + f2 = _mm256_min_ps(f2, fp32max); + f3 = _mm256_min_ps(f3, fp32max); + f0 = _mm256_max_ps(f0, fp32min); + f1 = _mm256_max_ps(f1, fp32min); + f2 = _mm256_max_ps(f2, fp32min); + f3 = _mm256_max_ps(f3, fp32min); _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0); _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1); _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2); @@ -376,7 +1072,8 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src, if (3 == realDst) { for (int dz = 0; dz < dst_depth_quad; ++dz) { const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); - const auto bias_dz = post->bias + dz * AVX2_PACKINT8; + const auto bias_dz = post->biasFloat + dz * AVX2_PACKINT8; + const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const float* scale_dz = post->scale + dz * AVX2_PACKINT8; auto dst_z = dst + dz * dst_step_tmp; const auto src_x = src; @@ -402,19 +1099,38 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src, auto D1 = D01; auto D2 = D02; - auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz)); - D0 = _mm256_add_epi32(D0, biasValue0); - D1 = _mm256_add_epi32(D1, biasValue0); - D2 = _mm256_add_epi32(D2, biasValue0); + // auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz)); + auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz); + // D0 = _mm256_add_epi32(D0, biasValue0); + // D1 = _mm256_add_epi32(D1, biasValue0); + // D2 = _mm256_add_epi32(D2, biasValue0); auto scaleValue = _mm256_loadu_ps(scale_dz); + auto f0 = _mm256_cvtepi32_ps(D0); auto f1 = _mm256_cvtepi32_ps(D1); auto f2 = _mm256_cvtepi32_ps(D2); + // x_kernelSum x w_quantZero + auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimemsion first + auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // ..second + auto xy0_2 = _mm256_mul_ps(kernelSum2, weightBiasValue); // .. 
third f0 = _mm256_mul_ps(f0, scaleValue); f1 = _mm256_mul_ps(f1, scaleValue); f2 = _mm256_mul_ps(f2, scaleValue); + f0 = _mm256_add_ps(f0, xy0_0); + f1 = _mm256_add_ps(f1, xy0_1); + f2 = _mm256_add_ps(f2, xy0_2); + auto biasValue = _mm256_loadu_ps(bias_dz); + f0 = _mm256_add_ps(f0, biasValue); + f1 = _mm256_add_ps(f1, biasValue); + f2 = _mm256_add_ps(f2, biasValue); if (post->useInt8 == 0) { + f0 = _mm256_min_ps(f0, fp32max); + f1 = _mm256_min_ps(f1, fp32max); + f2 = _mm256_min_ps(f2, fp32max); + f0 = _mm256_max_ps(f0, fp32min); + f1 = _mm256_max_ps(f1, fp32min); + f2 = _mm256_max_ps(f2, fp32min); _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0); _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1); _mm256_storeu_ps(((float*)dst_x) + 2 * AVX2_PACKINT8, f2); @@ -429,7 +1145,8 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src, if (2 == realDst) { for (int dz = 0; dz < dst_depth_quad; ++dz) { const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); - const auto bias_dz = post->bias + dz * AVX2_PACKINT8; + const auto bias_dz = post->biasFloat + dz * AVX2_PACKINT8; + const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const float* scale_dz = post->scale + dz * AVX2_PACKINT8; auto dst_z = dst + dz * dst_step_tmp; const auto src_x = src; @@ -451,16 +1168,26 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src, auto D0 = D00; auto D1 = D01; - auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz)); - D0 = _mm256_add_epi32(D0, biasValue0); - D1 = _mm256_add_epi32(D1, biasValue0); + auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz); auto scaleValue = _mm256_loadu_ps(scale_dz); auto f0 = _mm256_cvtepi32_ps(D0); auto f1 = _mm256_cvtepi32_ps(D1); + // x_kernelSum x w_quantZero + auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimemsion first + auto xy0_1 = _mm256_mul_ps(kernelSum1, weightBiasValue); // ..second f0 = _mm256_mul_ps(f0, scaleValue); f1 = _mm256_mul_ps(f1, scaleValue); + f0 = _mm256_add_ps(f0, xy0_0); + f1 = _mm256_add_ps(f1, xy0_1); + auto biasValue = _mm256_loadu_ps(bias_dz); + f0 = _mm256_add_ps(f0, biasValue); + f1 = _mm256_add_ps(f1, biasValue); if (post->useInt8 == 0) { + f0 = _mm256_min_ps(f0, fp32max); + f1 = _mm256_min_ps(f1, fp32max); + f0 = _mm256_max_ps(f0, fp32min); + f1 = _mm256_max_ps(f1, fp32min); _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0); _mm256_storeu_ps(((float*)dst_x) + 1 * AVX2_PACKINT8, f1); } else { @@ -473,7 +1200,8 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src, if (1 == realDst) { for (int dz = 0; dz < dst_depth_quad; ++dz) { const auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX2_L * GEMMINT8_AVX2_H); - const auto bias_dz = post->bias + dz * AVX2_PACKINT8; + const auto bias_dz = post->biasFloat + dz * AVX2_PACKINT8; + const auto weightBias_dz = post->weightQuanBias + dz * AVX2_PACKINT8; const float* scale_dz = post->scale + dz * AVX2_PACKINT8; auto dst_z = dst + dz * dst_step_tmp; const auto src_x = src; @@ -490,14 +1218,19 @@ void _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast(int8_t* dst, const int8_t* src, D00 = _mm256_add_epi32(D00, _mm256_madd_epi16(_mm256_maddubs_epi16(s0, w0), oneValue)); } auto D0 = D00; - - auto biasValue0 = _mm256_loadu_si256((__m256i*)(bias_dz)); - D0 = _mm256_add_epi32(D0, biasValue0); + auto weightBiasValue = _mm256_loadu_ps((float*)weightBias_dz); auto scaleValue = _mm256_loadu_ps(scale_dz); auto f0 = 
_mm256_cvtepi32_ps(D0); + // x_kernelSum x w_quantZero + auto xy0_0 = _mm256_mul_ps(kernelSum0, weightBiasValue); // x dimemsion first f0 = _mm256_mul_ps(f0, scaleValue); + f0 = _mm256_add_ps(f0, xy0_0); + auto biasValue = _mm256_loadu_ps(bias_dz); + f0 = _mm256_add_ps(f0, biasValue); if (post->useInt8 == 0) { + f0 = _mm256_min_ps(f0, fp32max); + f0 = _mm256_max_ps(f0, fp32min); _mm256_storeu_ps(((float*)dst_x) + 0 * AVX2_PACKINT8, f0); } else { POSTTREAT(0); @@ -747,6 +1480,9 @@ void _AVX_MNNInt8FunctionInit(void* functions) { gAVX2CoreInt8Functions->Int8GemmKernelFast = _AVX_MNNGemmInt8AddBiasScale_16x4_Unit_Fast; gAVX2CoreInt8Functions->MNNGetGemmUnit = _AVX2_MNNGetGemmUnit; gAVX2CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _AVXMNNPackC4ForMatMul_A; +#ifdef MNN_LOW_MEMORY + gAVX2CoreInt8Functions->Int8GemmKernel_W4 = _AVX_MNNGemmInt8AddBiasScale_16x4_w4; +#endif // Int8 <-> Float gAVX2CoreInt8Functions->MNNFloat2Int8 = _AVX_MNNFloat2Int8; diff --git a/source/backend/cpu/x86_x64/avx512/GemmInt8.cpp b/source/backend/cpu/x86_x64/avx512/GemmInt8.cpp index 1dc73cbab..6eb8a5379 100644 --- a/source/backend/cpu/x86_x64/avx512/GemmInt8.cpp +++ b/source/backend/cpu/x86_x64/avx512/GemmInt8.cpp @@ -14,10 +14,12 @@ #ifdef MNN_AVX512_VNNI extern void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst); extern void _AVX512_MNNLineDepthWiseInt8AddBiasScaleUnit_VNNI(int8_t* dstO, const int8_t* srcO, const int8_t* weightO, const QuanPostTreatParameters* parameters, size_t width, size_t src_w_step, size_t fw, size_t fh, size_t dilateX_step, size_t dilateY_step, int8_t* idxOrder=nullptr); +extern void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst); #endif // Define in GemmInt8_4_4_64.cpp extern void _AVX512_NO_VNNI_4_4_64(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst); +extern void _AVX512_NO_VNNI_4_4_64_w4(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst); // Define in GemmInt8_4_4_64_7bit.cpp extern void _AVX512_NO_VNNI_4_4_64_7bit(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst); @@ -123,7 +125,6 @@ static void _AVX512BasicMNNPackC4ForMatMul_A(int8_t* destOrigin, int8_t const** } } } - } @@ -201,32 +202,53 @@ void _AVX512_MNNLineDepthWiseInt8AddBiasScaleUnit(int8_t* dstO, const int8_t* sr } } void _AVX512_MNNFloat2Int8(const float* src, int8_t* dst, size_t sizeQuad, const float* scalep, ssize_t minV, ssize_t maxV, ssize_t zeroPoint) { - auto zero = _mm512_setzero_ps(); - auto minValue = _mm512_set1_ps(minV); - auto maxValue = _mm512_set1_ps(maxV); - auto zeroPointValue = _mm512_set1_ps(zeroPoint); - auto offset = _mm512_set1_ps(128.f); - auto plus = _mm512_set1_ps(0.5f); - auto minus = _mm512_set1_ps(-0.5f); - auto scaleValue0 = _mm512_loadu_ps(scalep); + auto zero = _mm256_set1_epi32(0); + auto minValue = _mm256_set1_ps(minV); + auto maxValue = _mm256_set1_ps(maxV); + auto zeroPointValue = 
_mm256_set1_ps(zeroPoint); + auto offset = _mm256_set1_epi32(128); + auto plus = _mm256_set1_ps(0.5f); + auto minus = _mm256_set1_ps(-0.5f); + auto scaleValue0 = _mm256_loadu_ps(scalep); + auto scaleValue1 = _mm256_loadu_ps(scalep + 8); for (int i = 0; i < sizeQuad; ++i) { - auto f0 = _mm512_loadu_ps(src + PACK_UNIT * i); - f0 = _mm512_mul_ps(f0, scaleValue0); - f0 = _mm512_add_ps(f0, zeroPointValue); - f0 = _mm512_min_ps(f0, maxValue); - f0 = _mm512_max_ps(f0, minValue); - auto m0 = _mm512_cmp_ps_mask(f0, zero, 1); - auto r0 = _mm512_mask_blend_ps(m0, plus, minus); - f0 = _mm512_add_ps(f0, r0); - __m512 round0 = _mm512_roundscale_ps(f0, 3); - round0 = _mm512_add_ps(round0, offset); - auto i0_int32 = _mm512_cvtps_epi32(round0); - auto i0_int16 = _mm512_cvtsepi32_epi16(i0_int32); - auto h0_int16 = _mm256_extracti128_si256(i0_int16, 0); - auto h1_int16 = _mm256_extracti128_si256(i0_int16, 1); - h0_int16 = _mm_packus_epi16(h0_int16, h1_int16); - _mm_storeu_si128((__m128i*)(dst + i * PACK_UNIT), h0_int16); + auto f0 = _mm256_loadu_ps(src + PACK_UNIT * i); + auto f1 = _mm256_loadu_ps(src + PACK_UNIT * i + 8); + f0 = _mm256_mul_ps(f0, scaleValue0); + f1 = _mm256_mul_ps(f1, scaleValue1); + f0 = _mm256_add_ps(f0, zeroPointValue); + f1 = _mm256_add_ps(f1, zeroPointValue); + f0 = _mm256_min_ps(f0, maxValue); + f1 = _mm256_min_ps(f1, maxValue); + f0 = _mm256_max_ps(f0, minValue); + f1 = _mm256_max_ps(f1, minValue); + auto m0 = _mm256_cmp_ps(f0, _mm256_castsi256_ps(zero), 1); + auto m1 = _mm256_cmp_ps(f1, _mm256_castsi256_ps(zero), 1); + m0 = _mm256_blendv_ps(plus, minus, m0); + m1 = _mm256_blendv_ps(plus, minus, m1); + f0 = _mm256_add_ps(f0, m0); + f1 = _mm256_add_ps(f1, m1); + // 3: _MM_FROUND_TO_ZERO + auto d0 = _mm256_cvtps_epi32(_mm256_round_ps(f0, 3)); + auto d1 = _mm256_cvtps_epi32(_mm256_round_ps(f1, 3)); + d0 = _mm256_add_epi32(d0, offset); + d1 = _mm256_add_epi32(d1, offset); + d0 = _mm256_packs_epi32(d0, _mm256_setzero_si256()); + d1 = _mm256_packs_epi32(d1, _mm256_setzero_si256()); + d0 = _mm256_permute4x64_epi64(d0, 0xD8); + d1 = _mm256_permute4x64_epi64(d1, 0xD8); +#if defined(_MSC_VER) + __m256i x = static_cast<__m256i>(_mm256_packus_epi16(d0, _mm256_setzero_si256())); + __m256i y = static_cast<__m256i>(_mm256_packus_epi16(d1, _mm256_setzero_si256())); + *((int64_t*)dst + 2 * i + 0) = x.m256i_i64[0]; + *((int64_t*)dst + 2 * i + 1) = y.m256i_i64[0]; +#else + __v4di x = static_cast<__v4di>(_mm256_packus_epi16(d0, _mm256_setzero_si256())); + __v4di y = static_cast<__v4di>(_mm256_packus_epi16(d1, _mm256_setzero_si256())); + *((int64_t*)dst + 2 * i + 0) = x[0]; + *((int64_t*)dst + 2 * i + 1) = y[0]; +#endif } } @@ -296,17 +318,22 @@ void _AVX512_MNNInt8FunctionInit(void* functions, bool supportVNNI) { if (supportVNNI) { gAVX2CoreInt8Functions->Int8GemmKernel = _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI; gAVX2CoreInt8Functions->Int8GemmKernelFast = _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI; + gAVX2CoreInt8Functions->Int8GemmKernel_W4 = _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI; // conv depthwise gAVX2CoreInt8Functions->ConvDepthwiseLineInt8 = _AVX512_MNNLineDepthWiseInt8AddBiasScaleUnit_VNNI; // MatMul gAVX2CoreInt8Functions->MNNGetGemmUnit = _AVX512_MNNGetGemmUnit_VNNI; // Im2Col gAVX2CoreInt8Functions->MNNPackC4Int8ForMatMul_A = _AVX512BasicMNNPackC4ForMatMul_A; + + + } else #endif { gAVX2CoreInt8Functions->Int8GemmKernel = _AVX512_NO_VNNI_4_4_64; gAVX2CoreInt8Functions->Int8GemmKernelFast = _AVX512_NO_VNNI_4_4_64_7bit; + gAVX2CoreInt8Functions->Int8GemmKernel_W4 = 
_AVX512_NO_VNNI_4_4_64_w4; // conv depthwise gAVX2CoreInt8Functions->ConvDepthwiseLineInt8 = _AVX512_MNNLineDepthWiseInt8AddBiasScaleUnit; // MatMul diff --git a/source/backend/cpu/x86_x64/avx512/GemmInt8_4_4_64_NOVNNI.cpp b/source/backend/cpu/x86_x64/avx512/GemmInt8_4_4_64_NOVNNI.cpp index 0df2809d6..7273eab05 100644 --- a/source/backend/cpu/x86_x64/avx512/GemmInt8_4_4_64_NOVNNI.cpp +++ b/source/backend/cpu/x86_x64/avx512/GemmInt8_4_4_64_NOVNNI.cpp @@ -16,4 +16,5 @@ static inline __m512i mnn_mm512_dpbusds_epi32_replace(__m512i dst, __m512i src, } #define MATMULCOREFUNC_NAME _AVX512_NO_VNNI_4_4_64 +#define MATMULCOREFUNC_NAME_W4 _AVX512_NO_VNNI_4_4_64_w4 #include "Matmul_4_4_64.inl" \ No newline at end of file diff --git a/source/backend/cpu/x86_x64/avx512/GemmInt8_VNNI.cpp b/source/backend/cpu/x86_x64/avx512/GemmInt8_VNNI.cpp index f97480b68..31335e2cf 100644 --- a/source/backend/cpu/x86_x64/avx512/GemmInt8_VNNI.cpp +++ b/source/backend/cpu/x86_x64/avx512/GemmInt8_VNNI.cpp @@ -13,10 +13,13 @@ #define GEMMINT8_AVX512_H GEMMINT8_AVX512_H_VNNI #define _MM256_SET_M128I(__H, __L) _mm256_insertf128_si256(_mm256_castsi128_si256(__L), __H, 1) // for compile compatiable #define AVX512_BROADCAST_INT32(src) _mm512_castps_si512(_mm512_broadcastss_ps(_mm_load_ss(src))) + +#define DEQUANT_VALUE(N) \ + auto f##N = _mm512_cvtepi32_ps(D##N);\ + f##N = _mm512_mul_ps(f##N, scaleValue); + #define SCALE_BIAS_VEC(N) \ - auto d##N = _mm512_add_epi32(D##N, biasValue);\ - auto f##N = _mm512_cvtepi32_ps(d##N);\ - f##N = _mm512_mul_ps(f##N, scaleValue); + f##N = _mm512_add_ps(f##N, biasValue); #define POSTTREAT(N, O) \ f##N = _mm512_min_ps(f##N, maxValue);\ @@ -24,13 +27,76 @@ auto m##N = _mm512_cmp_ps_mask(f##N, zero512, 1);\ auto b##N = _mm512_mask_blend_ps(m##N, plus, minus);\ f##N = _mm512_add_ps(f##N, b##N);\ - d##N = _mm512_cvtps_epi32(_mm512_roundscale_ps(f##N, 3));\ + auto d##N = _mm512_cvtps_epi32(_mm512_roundscale_ps(f##N, 3));\ auto hd##N = _mm512_cvtsepi32_epi16(d##N); hd##N = _mm256_add_epi16(hd##N, offset);\ auto h0##N = _mm256_extracti128_si256(hd##N, 0);\ auto h1##N = _mm256_extracti128_si256(hd##N, 1);\ h0##N = _mm_packus_epi16(h0##N, h1##N);\ _mm_storeu_si128((__m128i*)dst_x + O, h0##N); +#define POST_TREAT_FLOAT(N,M,K,V) \ + f##N = _mm512_min_ps(f##N, fp32max);\ + f##N = _mm512_max_ps(f##N, fp32min);\ + f##M = _mm512_min_ps(f##M, fp32max);\ + f##M = _mm512_max_ps(f##M, fp32min);\ + f##K = _mm512_min_ps(f##K, fp32max);\ + f##K = _mm512_max_ps(f##K, fp32min);\ + f##V = _mm512_min_ps(f##V, fp32max);\ + f##V = _mm512_max_ps(f##V, fp32min); + +#define SRCKERNELSUM_MUL_WEIGHTQUANBIAS \ + xy0_0 = _mm512_mul_ps(kernelSum0, weightBiasValue);\ + xy0_1 = _mm512_mul_ps(kernelSum1, weightBiasValue);\ + xy0_2 = _mm512_mul_ps(kernelSum2, weightBiasValue);\ + xy0_3 = _mm512_mul_ps(kernelSum3, weightBiasValue); + +#define PLUS_TERM(N,M,K,V) \ + f##N = _mm512_add_ps(f##N, xy0_0);\ + f##M = _mm512_add_ps(f##M, xy0_1);\ + f##K = _mm512_add_ps(f##K, xy0_2);\ + f##V = _mm512_add_ps(f##V, xy0_3); + +#define POST_TREAT_FLOAT_3(N,M,K) \ + f##N = _mm512_min_ps(f##N, fp32max);\ + f##N = _mm512_max_ps(f##N, fp32min);\ + f##M = _mm512_min_ps(f##M, fp32max);\ + f##M = _mm512_max_ps(f##M, fp32min);\ + f##K = _mm512_min_ps(f##K, fp32max);\ + f##K = _mm512_max_ps(f##K, fp32min); + +#define SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3 \ + xy0_0 = _mm512_mul_ps(kernelSum0, weightBiasValue);\ + xy0_1 = _mm512_mul_ps(kernelSum1, weightBiasValue);\ + xy0_2 = _mm512_mul_ps(kernelSum2, weightBiasValue); + +#define PLUS_TERM_3(N,M,K) \ + f##N = 
_mm512_add_ps(f##N, xy0_0);\ + f##M = _mm512_add_ps(f##M, xy0_1);\ + f##K = _mm512_add_ps(f##K, xy0_2); + +#define POST_TREAT_FLOAT_2(N,M) \ + f##N = _mm512_min_ps(f##N, fp32max);\ + f##N = _mm512_max_ps(f##N, fp32min);\ + f##M = _mm512_min_ps(f##M, fp32max);\ + f##M = _mm512_max_ps(f##M, fp32min); + +#define SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2 \ + xy0_0 = _mm512_mul_ps(kernelSum0, weightBiasValue);\ + xy0_1 = _mm512_mul_ps(kernelSum1, weightBiasValue); + +#define PLUS_TERM_2(N,M) \ + f##N = _mm512_add_ps(f##N, xy0_0);\ + f##M = _mm512_add_ps(f##M, xy0_1); + +#define POST_TREAT_FLOAT_1(N) \ + f##N = _mm512_min_ps(f##N, fp32max);\ + f##N = _mm512_max_ps(f##N, fp32min); + +#define SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1 \ + xy0_0 = _mm512_mul_ps(kernelSum0, weightBiasValue); + +#define PLUS_TERM_1(N) \ + f##N = _mm512_add_ps(f##N, xy0_0); // GemmInt8 with VNNI void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) { @@ -44,10 +110,69 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s int dzUnit = GEMMINT8_AVX512_H / PACK_UNIT; int dzU = dst_depth_quad / dzUnit; int dzR = dst_depth_quad % dzUnit; + __m512 fp32min, fp32max; + if (0 == post->useInt8 && post->fp32minmax) { + fp32min = _mm512_set1_ps((post->fp32minmax)[0]); + fp32max = _mm512_set1_ps((post->fp32minmax)[1]); + } + auto blockNum = post->blockNum; + const float* biasPtr = nullptr; + const float* bias_dz = nullptr; + const float* extraB_dz = nullptr; + if (post->biasFloat) { + biasPtr = post->biasFloat; + } + auto srcKernelSumPtr = post->srcKernelSum; + __m512 kernelSum0 = _mm512_setzero_ps(); + __m512 kernelSum1 = _mm512_setzero_ps(); + __m512 kernelSum2 = _mm512_setzero_ps(); + __m512 kernelSum3 = _mm512_setzero_ps(); + if (GEMMINT8_AVX512_E == realDst) { + kernelSum0 = _mm512_set1_ps(post->srcKernelSum[0]); + kernelSum1 = _mm512_set1_ps(post->srcKernelSum[1]); + kernelSum2 = _mm512_set1_ps(post->srcKernelSum[2]); + kernelSum3 = _mm512_set1_ps(post->srcKernelSum[3]); + } else { + kernelSum0 = _mm512_set1_ps(post->srcKernelSum[0]); + if (realDst > 1) { + kernelSum1 = _mm512_set1_ps(post->srcKernelSum[1]); + } + if (realDst > 2) { + kernelSum2 = _mm512_set1_ps(post->srcKernelSum[2]); + } + } + auto f128 = _mm512_set1_ps(128.f); + __m512 extrascale0 = _mm512_setzero_ps(); + __m512 extrascale1 = _mm512_setzero_ps(); + __m512 extrascale2 = _mm512_setzero_ps(); + __m512 extrascale3 = _mm512_setzero_ps(); + if (post->extraScale) { + if (GEMMINT8_AVX512_E == realDst) { + extrascale0 = _mm512_set1_ps(post->extraScale[0]); + extrascale1 = _mm512_set1_ps(post->extraScale[1]); + extrascale2 = _mm512_set1_ps(post->extraScale[2]); + extrascale3 = _mm512_set1_ps(post->extraScale[3]); + } else { + extrascale0 = _mm512_set1_ps(post->extraScale[0]); + if (realDst > 1) { + extrascale1 = _mm512_set1_ps(post->extraScale[1]); + } + if (realDst > 2) { + extrascale2 = _mm512_set1_ps(post->extraScale[2]); + } + } + } + int weightZStride = blockNum * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H); if (realDst == GEMMINT8_AVX512_E) { for (int dz = 0; dz < dzU; ++dz) { - auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H); - auto bias_dz = (int32_t*)post->bias + dz * PACK_UNIT * dzUnit; + auto weight_dz = weight + dz * weightZStride; + if (biasPtr) { + bias_dz = biasPtr + dz * PACK_UNIT * dzUnit; + } + if (post->extraBias) { 
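                // extraB_dz is consumed below as f -= (128.f * extrabias) * extrascale,
                // a per-output-channel term scaled by the per-column extraScale; this
                // appears to cancel the +128 offset added to dynamically quantized
                // activations (compare the offset/packus path in _AVX512_MNNFloat2Int8).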
+ extraB_dz = post->extraBias + dz * PACK_UNIT * dzUnit; + } + const auto weightBias_dz = post->weightQuanBias + dz * PACK_UNIT * dzUnit; float* scale_dz = (float*)post->scale + dz * PACK_UNIT * dzUnit; auto dst_z = dst + dz * dst_step_tmp * dzUnit; const auto src_x = src; @@ -77,9 +202,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s const auto weight_sz = weight_dz + (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H) * sz; const auto src_z = (const float*)(src_x + sz * GEMMINT8_AVX512_E * GEMMINT8_AVX512_L); auto w0 = _mm512_loadu_si512(weight_sz); - auto w1 = _mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_E); - auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_E); - auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_E); + auto w1 = _mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_L); + auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_L); + auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_L); auto s0 = AVX512_BROADCAST_INT32(src_z + 0); auto s1 = AVX512_BROADCAST_INT32(src_z + 1); @@ -106,37 +231,185 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s D14 = _mm512_dpbusds_epi32(D14, s2, w3); D15 = _mm512_dpbusds_epi32(D15, s3, w3); } - - auto biasValue = _mm512_loadu_si512(bias_dz); auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1, xy0_2, xy0_3; + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + DEQUANT_VALUE(2); + DEQUANT_VALUE(3); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + f2 = _mm512_mul_ps(f2, extrascale2); + f3 = _mm512_mul_ps(f3, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + f2 = _mm512_sub_ps(f2, extrabias2); + f3 = _mm512_sub_ps(f3, extrabias3); + } + } - SCALE_BIAS_VEC(0); - SCALE_BIAS_VEC(1); - SCALE_BIAS_VEC(2); - SCALE_BIAS_VEC(3); + PLUS_TERM(0,1,2,3); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + SCALE_BIAS_VEC(2); + SCALE_BIAS_VEC(3); + } - biasValue = _mm512_loadu_si512(bias_dz + 1 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT); - SCALE_BIAS_VEC(4); - SCALE_BIAS_VEC(5); - SCALE_BIAS_VEC(6); - SCALE_BIAS_VEC(7); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(4); + DEQUANT_VALUE(5); + DEQUANT_VALUE(6); + DEQUANT_VALUE(7); + + if (post->extraScale) { // Batch quant + f4 = _mm512_mul_ps(f4, extrascale0); + f5 = _mm512_mul_ps(f5, extrascale1); + f6 = _mm512_mul_ps(f6, extrascale2); + f7 = _mm512_mul_ps(f7, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 1 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, 
extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f4 = _mm512_sub_ps(f4, extrabias0); + f5 = _mm512_sub_ps(f5, extrabias1); + f6 = _mm512_sub_ps(f6, extrabias2); + f7 = _mm512_sub_ps(f7, extrabias3); + } + } + + PLUS_TERM(4,5,6,7); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT); + SCALE_BIAS_VEC(4); + SCALE_BIAS_VEC(5); + SCALE_BIAS_VEC(6); + SCALE_BIAS_VEC(7); + } - biasValue = _mm512_loadu_si512(bias_dz + 2 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 2 * PACK_UNIT); - SCALE_BIAS_VEC(8); - SCALE_BIAS_VEC(9); - SCALE_BIAS_VEC(10); - SCALE_BIAS_VEC(11); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(8); + DEQUANT_VALUE(9); + DEQUANT_VALUE(10); + DEQUANT_VALUE(11); + + if (post->extraScale) { // Batch quant + f8 = _mm512_mul_ps(f8, extrascale0); + f9 = _mm512_mul_ps(f9, extrascale1); + f10 = _mm512_mul_ps(f10, extrascale2); + f11 = _mm512_mul_ps(f11, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 2 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f8 = _mm512_sub_ps(f8, extrabias0); + f9 = _mm512_sub_ps(f9, extrabias1); + f10 = _mm512_sub_ps(f10, extrabias2); + f11 = _mm512_sub_ps(f11, extrabias3); + } + } + + PLUS_TERM(8,9,10,11); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT); + SCALE_BIAS_VEC(8); + SCALE_BIAS_VEC(9); + SCALE_BIAS_VEC(10); + SCALE_BIAS_VEC(11); + } - biasValue = _mm512_loadu_si512(bias_dz + 3 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT); - SCALE_BIAS_VEC(12); - SCALE_BIAS_VEC(13); - SCALE_BIAS_VEC(14); - SCALE_BIAS_VEC(15); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(12); + DEQUANT_VALUE(13); + DEQUANT_VALUE(14); + DEQUANT_VALUE(15); + + if (post->extraScale) { // Batch quant + f12 = _mm512_mul_ps(f12, extrascale0); + f13 = _mm512_mul_ps(f13, extrascale1); + f14 = _mm512_mul_ps(f14, extrascale2); + f15 = _mm512_mul_ps(f15, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 3 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f12 = _mm512_sub_ps(f12, extrabias0); + f13 = _mm512_sub_ps(f13, extrabias1); + f14 = _mm512_sub_ps(f14, extrabias2); + f15 = _mm512_sub_ps(f15, extrabias3); + } + } + + PLUS_TERM(12,13,14,15); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT); + SCALE_BIAS_VEC(12); + SCALE_BIAS_VEC(13); + SCALE_BIAS_VEC(14); + SCALE_BIAS_VEC(15); + } if (post->useInt8 == 0) { + if (biasPtr == nullptr) { + auto destTmp = dst_x; + f0 = _mm512_add_ps(_mm512_loadu_ps((float*)destTmp), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16), f1); + f2 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 2), f2); + f3 = 
_mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 3), f3); + destTmp += dst_step_tmp; + f4 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 0), f4); + f5 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 1), f5); + f6 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 2), f6); + f7 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 3), f7); + destTmp += dst_step_tmp; + f8 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 0), f8); + f9 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 1), f9); + f10 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 2), f10); + f11 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 3), f11); + destTmp += dst_step_tmp; + f12 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 0), f12); + f13 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 1), f13); + f14 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 2), f14); + f15 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 3), f15); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT(0,1,2,3); + POST_TREAT_FLOAT(4,5,6,7); + POST_TREAT_FLOAT(8,9,10,11); + POST_TREAT_FLOAT(12,13,14,15); + } + _mm512_storeu_ps(((float*)dst_x), f0); _mm512_storeu_ps(((float*)dst_x) + 16, f1); _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f2); @@ -181,9 +454,15 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s POSTTREAT(15, 3); } } - auto weight_dz = weight + dzU * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H); - auto bias_dz = (int32_t*)post->bias + dzU * PACK_UNIT * dzUnit; + auto weight_dz = weight + dzU * weightZStride; + if (biasPtr) { + bias_dz = biasPtr + dzU * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dzU * PACK_UNIT * dzUnit; + } float* scale_dz = (float*)post->scale + dzU * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dzU * PACK_UNIT * dzUnit; auto dst_z = dst + dzU * dst_step_tmp * dzUnit; const auto src_x = src; @@ -210,15 +489,54 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s D3 = _mm512_dpbusds_epi32(D3, s3, w0); } - auto biasValue = _mm512_loadu_si512(bias_dz); auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1, xy0_2, xy0_3; + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + DEQUANT_VALUE(2); + DEQUANT_VALUE(3); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + f2 = _mm512_mul_ps(f2, extrascale2); + f3 = _mm512_mul_ps(f3, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + f2 = _mm512_sub_ps(f2, extrabias2); + f3 = _mm512_sub_ps(f3, extrabias3); + } + } - SCALE_BIAS_VEC(0); - SCALE_BIAS_VEC(1); - SCALE_BIAS_VEC(2); - SCALE_BIAS_VEC(3); + PLUS_TERM(0,1,2,3); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + SCALE_BIAS_VEC(2); + SCALE_BIAS_VEC(3); + } if (post->useInt8 == 0) { + if (nullptr == biasPtr) { + f0 = 
_mm512_add_ps(_mm512_loadu_ps((float*)dst_x), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16), f1); + f2 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16 * 2), f2); + f3 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16 * 3), f3); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT(0,1,2,3); + } _mm512_storeu_ps(((float*)dst_x), f0); _mm512_storeu_ps(((float*)dst_x) + 16, f1); _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f2); @@ -231,17 +549,28 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s } dst_x += dst_step_tmp; scale_dz += PACK_UNIT; - bias_dz += PACK_UNIT; - weight_dz += PACK_UNIT * GEMMINT8_AVX512_E; + if (biasPtr) { + bias_dz += PACK_UNIT; + } + if (post->extraBias) { + extraB_dz += PACK_UNIT; + } + weight_dz += PACK_UNIT * GEMMINT8_AVX512_L; } return; } // e = 3 if (realDst == 3) { for (int dz = 0; dz < dzU; ++dz) { - auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H); - auto bias_dz = (int32_t*)post->bias + dz * PACK_UNIT * dzUnit; + auto weight_dz = weight + dz * weightZStride; + if (biasPtr) { + bias_dz = biasPtr + dz * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dz * PACK_UNIT * dzUnit; + } float* scale_dz = (float*)post->scale + dz * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dz * PACK_UNIT * dzUnit; auto dst_z = dst + dz * dst_step_tmp * dzUnit; const auto src_x = src; auto dst_x = dst_z; @@ -266,9 +595,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s const auto weight_sz = weight_dz + (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H) * sz; const auto src_z = (const float*)(src_x + sz * GEMMINT8_AVX512_E * GEMMINT8_AVX512_L); auto w0 = _mm512_loadu_si512(weight_sz); - auto w1 = _mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_E); - auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_E); - auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_E); + auto w1 = _mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_L); + auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_L); + auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_L); auto s0 = AVX512_BROADCAST_INT32(src_z + 0); auto s1 = AVX512_BROADCAST_INT32(src_z + 1); @@ -291,32 +620,160 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s D14 = _mm512_dpbusds_epi32(D14, s2, w3); } - auto biasValue = _mm512_loadu_si512(bias_dz); auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1, xy0_2; + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + DEQUANT_VALUE(2); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + f2 = _mm512_mul_ps(f2, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + f2 = _mm512_sub_ps(f2, extrabias2); + } + } - SCALE_BIAS_VEC(0); - SCALE_BIAS_VEC(1); - SCALE_BIAS_VEC(2); + PLUS_TERM_3(0,1,2); + if (nullptr != biasPtr) { + auto biasValue = 
_mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + SCALE_BIAS_VEC(2); + } - biasValue = _mm512_loadu_si512(bias_dz + 1 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT); - SCALE_BIAS_VEC(4); - SCALE_BIAS_VEC(5); - SCALE_BIAS_VEC(6); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3; + DEQUANT_VALUE(4); + DEQUANT_VALUE(5); + DEQUANT_VALUE(6); + + if (post->extraScale) { // Batch quant + f4 = _mm512_mul_ps(f4, extrascale0); + f5 = _mm512_mul_ps(f5, extrascale1); + f6 = _mm512_mul_ps(f6, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 1 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f4 = _mm512_sub_ps(f4, extrabias0); + f5 = _mm512_sub_ps(f5, extrabias1); + f6 = _mm512_sub_ps(f6, extrabias2); + } + } + + PLUS_TERM_3(4,5,6); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT); + SCALE_BIAS_VEC(4); + SCALE_BIAS_VEC(5); + SCALE_BIAS_VEC(6); + } - biasValue = _mm512_loadu_si512(bias_dz + 2 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 2 * PACK_UNIT); - SCALE_BIAS_VEC(8); - SCALE_BIAS_VEC(9); - SCALE_BIAS_VEC(10); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3; + DEQUANT_VALUE(8); + DEQUANT_VALUE(9); + DEQUANT_VALUE(10); + + if (post->extraScale) { // Batch quant + f8 = _mm512_mul_ps(f8, extrascale0); + f9 = _mm512_mul_ps(f9, extrascale1); + f10 = _mm512_mul_ps(f10, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 2 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f8 = _mm512_sub_ps(f8, extrabias0); + f9 = _mm512_sub_ps(f9, extrabias1); + f10 = _mm512_sub_ps(f10, extrabias2); + } + } + + PLUS_TERM_3(8,9,10); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT); + SCALE_BIAS_VEC(8); + SCALE_BIAS_VEC(9); + SCALE_BIAS_VEC(10); + } - biasValue = _mm512_loadu_si512(bias_dz + 3 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT); - SCALE_BIAS_VEC(12); - SCALE_BIAS_VEC(13); - SCALE_BIAS_VEC(14); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3; + DEQUANT_VALUE(12); + DEQUANT_VALUE(13); + DEQUANT_VALUE(14); + + if (post->extraScale) { // Batch quant + f12 = _mm512_mul_ps(f12, extrascale0); + f13 = _mm512_mul_ps(f13, extrascale1); + f14 = _mm512_mul_ps(f14, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 3 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f12 = _mm512_sub_ps(f12, extrabias0); + f13 = _mm512_sub_ps(f13, extrabias1); + f14 = _mm512_sub_ps(f14, extrabias2); + } + } + + PLUS_TERM_3(12,13,14); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT); + SCALE_BIAS_VEC(12); + 
SCALE_BIAS_VEC(13); + SCALE_BIAS_VEC(14); + } if (post->useInt8 == 0) { + if (biasPtr == nullptr) { + auto dstTmp = dst_x; + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp)), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16), f1); + f2 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 2), f2); + dstTmp += dst_step_tmp; + f4 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f4); + f5 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f5); + f6 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 2), f6); + dstTmp += dst_step_tmp; + f8 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f8); + f9 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f9); + f10 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 2), f10); + dstTmp += dst_step_tmp; + f12 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f12); + f13 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f13); + f14 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 2), f14); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_3(0,1,2); + POST_TREAT_FLOAT_3(4,5,6); + POST_TREAT_FLOAT_3(8,9,10); + POST_TREAT_FLOAT_3(12,13,14); + } _mm512_storeu_ps(((float*)dst_x), f0); _mm512_storeu_ps(((float*)dst_x) + 16, f1); _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f2); @@ -353,9 +810,15 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s POSTTREAT(14, 2); } } - auto weight_dz = weight + dzU * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H); - auto bias_dz = (int32_t*)post->bias + dzU * PACK_UNIT * dzUnit; + auto weight_dz = weight + dzU * weightZStride; + if (biasPtr) { + bias_dz = post->biasFloat + dzU * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dzU * PACK_UNIT * dzUnit; + } float* scale_dz = (float*)post->scale + dzU * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dzU * PACK_UNIT * dzUnit; auto dst_z = dst + dzU * dst_step_tmp * dzUnit; const auto src_x = src; @@ -379,14 +842,49 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s D2 = _mm512_dpbusds_epi32(D2, s2, w0); } - auto biasValue = _mm512_loadu_si512(bias_dz); + auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1, xy0_2; + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + DEQUANT_VALUE(2); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + f2 = _mm512_mul_ps(f2, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + f2 = _mm512_sub_ps(f2, extrabias2); + } + } - SCALE_BIAS_VEC(0); - SCALE_BIAS_VEC(1); - SCALE_BIAS_VEC(2); - + PLUS_TERM_3(0,1,2); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + SCALE_BIAS_VEC(2); + } + if (post->useInt8 == 0) { + if (biasPtr == nullptr) { + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x)), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16), f1); + f2 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16 * 2), f2); + } 
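+                    // With biasFloat == nullptr the tile above is accumulated on top of the values
+                    // already stored in dst rather than overwriting them; presumably this is how partial
+                    // results from successive quant blocks (blockNum > 1) are summed, with the bias and
+                    // the fp32 min/max clamp below applied only on the call that provides them.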
+ if (post->fp32minmax) { + POST_TREAT_FLOAT_3(0,1,2); + } _mm512_storeu_ps(((float*)dst_x), f0); _mm512_storeu_ps(((float*)dst_x) + 16, f1); _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f2); @@ -397,17 +895,28 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s } dst_x += dst_step_tmp; scale_dz += PACK_UNIT; - bias_dz += PACK_UNIT; - weight_dz += PACK_UNIT * GEMMINT8_AVX512_E; + if (biasPtr) { + bias_dz += PACK_UNIT; + } + if (post->extraBias) { + extraB_dz += PACK_UNIT; + } + weight_dz += PACK_UNIT * GEMMINT8_AVX512_L; } return; } // e = 2 if (realDst == 2) { for (int dz = 0; dz < dzU; ++dz) { - auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H); - auto bias_dz = (int32_t*)post->bias + dz * PACK_UNIT * dzUnit; + auto weight_dz = weight + dz * weightZStride; + if (biasPtr) { + bias_dz = post->biasFloat + dz * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dz * PACK_UNIT * dzUnit; + } float* scale_dz = (float*)post->scale + dz * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dz * PACK_UNIT * dzUnit; auto dst_z = dst + dz * dst_step_tmp * dzUnit; const auto src_x = src; auto dst_x = dst_z; @@ -428,9 +937,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s const auto weight_sz = weight_dz + (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H) * sz; const auto src_z = (const float*)(src_x + sz * GEMMINT8_AVX512_E * GEMMINT8_AVX512_L); auto w0 = _mm512_loadu_si512(weight_sz); - auto w1 = _mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_E); - auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_E); - auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_E); + auto w1 = _mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_L); + auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_L); + auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_L); auto s0 = AVX512_BROADCAST_INT32(src_z + 0); auto s1 = AVX512_BROADCAST_INT32(src_z + 1); @@ -448,28 +957,135 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s D13 = _mm512_dpbusds_epi32(D13, s1, w3); } - auto biasValue = _mm512_loadu_si512(bias_dz); auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1; + + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + } + } - SCALE_BIAS_VEC(0); - SCALE_BIAS_VEC(1); + PLUS_TERM_2(0,1); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + } - biasValue = _mm512_loadu_si512(bias_dz + 1 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT); - SCALE_BIAS_VEC(4); - SCALE_BIAS_VEC(5); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(4); + DEQUANT_VALUE(5); + + if (post->extraScale) { // Batch quant + f4 = 
_mm512_mul_ps(f4, extrascale0); + f5 = _mm512_mul_ps(f5, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 1 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f4 = _mm512_sub_ps(f4, extrabias0); + f5 = _mm512_sub_ps(f5, extrabias1); + } + } + + PLUS_TERM_2(4,5); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT); + SCALE_BIAS_VEC(4); + SCALE_BIAS_VEC(5); + } - biasValue = _mm512_loadu_si512(bias_dz + 2 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 2 * PACK_UNIT); - SCALE_BIAS_VEC(8); - SCALE_BIAS_VEC(9); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT); + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(8); + DEQUANT_VALUE(9); + + if (post->extraScale) { // Batch quant + f8 = _mm512_mul_ps(f8, extrascale0); + f9 = _mm512_mul_ps(f9, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 2 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f8 = _mm512_sub_ps(f8, extrabias0); + f9 = _mm512_sub_ps(f9, extrabias1); + } + } + + PLUS_TERM_2(8,9); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT); + SCALE_BIAS_VEC(8); + SCALE_BIAS_VEC(9); + } - biasValue = _mm512_loadu_si512(bias_dz + 3 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT); - SCALE_BIAS_VEC(12); - SCALE_BIAS_VEC(13); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT); + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(12); + DEQUANT_VALUE(13); + + if (post->extraScale) { // Batch quant + f12 = _mm512_mul_ps(f12, extrascale0); + f13 = _mm512_mul_ps(f13, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 3 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f12 = _mm512_sub_ps(f12, extrabias0); + f13 = _mm512_sub_ps(f13, extrabias1); + } + } + + PLUS_TERM_2(12,13); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT); + SCALE_BIAS_VEC(12); + SCALE_BIAS_VEC(13); + } if (post->useInt8 == 0) { + if (nullptr == biasPtr) { + auto dstTmp = dst_x; + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp)), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16), f1); + dstTmp += dst_step_tmp; + f4 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f4); + f5 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f5); + dstTmp += dst_step_tmp; + f8 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f8); + f9 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f9); + dstTmp += dst_step_tmp; + f12 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f12); + f13 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f13); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_2(0,1); + POST_TREAT_FLOAT_2(4,5); + POST_TREAT_FLOAT_2(8,9); + POST_TREAT_FLOAT_2(12,13); + } _mm512_storeu_ps(((float*)dst_x), f0); _mm512_storeu_ps(((float*)dst_x) + 16, f1); dst_x += dst_step_tmp; @@ -498,9 +1114,15 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s POSTTREAT(13, 1); } } - auto weight_dz = weight + 
dzU * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H); - auto bias_dz = (int32_t*)post->bias + dzU * PACK_UNIT * dzUnit; + auto weight_dz = weight + dzU * weightZStride; + if (biasPtr) { + bias_dz = post->biasFloat + dzU * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dzU * PACK_UNIT * dzUnit; + } float* scale_dz = (float*)post->scale + dzU * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dzU * PACK_UNIT * dzUnit; auto dst_z = dst + dzU * dst_step_tmp * dzUnit; const auto src_x = src; @@ -521,13 +1143,40 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s D1 = _mm512_dpbusds_epi32(D1, s1, w0); } - auto biasValue = _mm512_loadu_si512(bias_dz); auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1; + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + } + } - SCALE_BIAS_VEC(0); - SCALE_BIAS_VEC(1); + PLUS_TERM_2(0,1); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + } if (post->useInt8 == 0) { + if (nullptr == biasPtr) { + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x)), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16), f1); + } + POST_TREAT_FLOAT_2(0,1); _mm512_storeu_ps(((float*)dst_x), f0); _mm512_storeu_ps(((float*)dst_x) + 16, f1); } else { @@ -536,16 +1185,27 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s } dst_x += dst_step_tmp; scale_dz += PACK_UNIT; - bias_dz += PACK_UNIT; - weight_dz += PACK_UNIT * GEMMINT8_AVX512_E; + if (biasPtr) { + bias_dz += PACK_UNIT; + } + if (post->extraBias) { + extraB_dz += PACK_UNIT; + } + weight_dz += PACK_UNIT * GEMMINT8_AVX512_L; } return; } if (realDst == 1) { for (int dz = 0; dz < dzU; ++dz) { - auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H); - auto bias_dz = (int32_t*)post->bias + dz * PACK_UNIT * dzUnit; + auto weight_dz = weight + dz * weightZStride; + if (biasPtr) { + bias_dz = post->biasFloat + dz * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dz * PACK_UNIT * dzUnit; + } float* scale_dz = (float*)post->scale + dz * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dz * PACK_UNIT * dzUnit; auto dst_z = dst + dz * dst_step_tmp * dzUnit; const auto src_x = src; auto dst_x = dst_z; @@ -561,9 +1221,9 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s const auto weight_sz = weight_dz + (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H) * sz; const auto src_z = (const float*)(src_x + sz * GEMMINT8_AVX512_E * GEMMINT8_AVX512_L); auto w0 = _mm512_loadu_si512(weight_sz); - auto w1 = _mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_E); - auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_E); - auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_E); + auto w1 = _mm512_loadu_si512(weight_sz + 1 * 
PACK_UNIT * GEMMINT8_AVX512_L); + auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_L); + auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_L); auto s0 = AVX512_BROADCAST_INT32(src_z + 0); @@ -576,24 +1236,113 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s D12 = _mm512_dpbusds_epi32(D12, s0, w3); } - auto biasValue = _mm512_loadu_si512(bias_dz); auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0; + + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1; + DEQUANT_VALUE(0); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f0 = _mm512_sub_ps(f0, extrabias0); + } + } - SCALE_BIAS_VEC(0); + PLUS_TERM_1(0); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + } - biasValue = _mm512_loadu_si512(bias_dz + 1 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT); - SCALE_BIAS_VEC(4); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1; + DEQUANT_VALUE(4); + + if (post->extraScale) { // Batch quant + f4 = _mm512_mul_ps(f4, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 1 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f4 = _mm512_sub_ps(f4, extrabias0); + } + } + + PLUS_TERM_1(4); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT); + SCALE_BIAS_VEC(4); + } - biasValue = _mm512_loadu_si512(bias_dz + 2 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 2 * PACK_UNIT); - SCALE_BIAS_VEC(8); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1; + DEQUANT_VALUE(8); + + if (post->extraScale) { // Batch quant + f8 = _mm512_mul_ps(f8, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 2 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f8 = _mm512_sub_ps(f8, extrabias0); + } + } + + PLUS_TERM_1(8); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT); + SCALE_BIAS_VEC(8); + } - biasValue = _mm512_loadu_si512(bias_dz + 3 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT); - SCALE_BIAS_VEC(12); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1; + DEQUANT_VALUE(12); + + if (post->extraScale) { // Batch quant + f12 = _mm512_mul_ps(f12, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 3 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f12 = _mm512_sub_ps(f12, extrabias0); + } + } + + PLUS_TERM_1(12); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT); + SCALE_BIAS_VEC(12); + } if (post->useInt8 == 0) { + if (nullptr == biasPtr) { + auto dstTemp = dst_x; + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTemp)), f0); + dstTemp += dst_step_tmp; + f4 = 
_mm512_add_ps(_mm512_loadu_ps(((float*)dstTemp) + 16 * 0), f4); + dstTemp += dst_step_tmp; + f8 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTemp) + 16 * 0), f8); + dstTemp += dst_step_tmp; + f12 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTemp) + 16 * 0), f12); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_1(0); + POST_TREAT_FLOAT_1(4); + POST_TREAT_FLOAT_1(8); + POST_TREAT_FLOAT_1(12); + } _mm512_storeu_ps(((float*)dst_x), f0); dst_x += dst_step_tmp; _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f4); @@ -614,9 +1363,15 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s POSTTREAT(12, 0); } } - auto weight_dz = weight + dzU * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H); - auto bias_dz = (int32_t*)post->bias + dzU * PACK_UNIT * dzUnit; + auto weight_dz = weight + dzU * weightZStride; + if (biasPtr) { + bias_dz = post->biasFloat + dzU * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dzU * PACK_UNIT * dzUnit; + } float* scale_dz = (float*)post->scale + dzU * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dzU * PACK_UNIT * dzUnit; auto dst_z = dst + dzU * dst_step_tmp * dzUnit; const auto src_x = src; @@ -634,20 +1389,1324 @@ void _AVX512_MNNGemmInt8AddBiasScale_16x4_Unit_VNNI(int8_t* dst, const int8_t* s D0 = _mm512_dpbusds_epi32(D0, s0, w0); } - auto biasValue = _mm512_loadu_si512(bias_dz); auto scaleValue = _mm512_loadu_ps(scale_dz); - SCALE_BIAS_VEC(0); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0; + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1; + DEQUANT_VALUE(0); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f0 = _mm512_sub_ps(f0, extrabias0); + } + } + + PLUS_TERM_1(0); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + } if (post->useInt8 == 0) { + if (nullptr == biasPtr) { + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x)), f0); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_1(0); + } _mm512_storeu_ps(((float*)dst_x), f0); } else { POSTTREAT(0, 0); } dst_x += dst_step_tmp; scale_dz += PACK_UNIT; - bias_dz += PACK_UNIT; - weight_dz += PACK_UNIT * GEMMINT8_AVX512_E; + if (biasPtr) { + bias_dz += PACK_UNIT; + } + if (post->extraBias) { + extraB_dz += PACK_UNIT; + } + weight_dz += PACK_UNIT * GEMMINT8_AVX512_L; + } + return; + } +} + +// GemmInt8 with VNNI int4-weight fp32-output +void _AVX512_MNNGemmInt8AddBiasScale_16x4_w4_Unit_VNNI(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) { + MNN_ASSERT(post->useInt8 == 0); + const auto dst_step_tmp = dst_step / sizeof(int8_t); + auto zero512 = _mm512_set1_ps(0.0f); + int dzUnit = GEMMINT8_AVX512_H / PACK_UNIT; + int dzU = dst_depth_quad / dzUnit; + int dzR = dst_depth_quad % dzUnit; + const __m512i mask = _mm512_set1_epi8(0xf); + __m512 fp32min, fp32max; + if (post->fp32minmax) { + fp32min = _mm512_set1_ps((post->fp32minmax)[0]); + fp32max = _mm512_set1_ps((post->fp32minmax)[1]); + } + auto blockNum = post->blockNum; + const float* biasPtr = nullptr; + const float* bias_dz = nullptr; + const float* extraB_dz = nullptr; + if (post->biasFloat) { + biasPtr = post->biasFloat; + } + + auto srcKernelSumPtr = 
post->srcKernelSum; + __m512 kernelSum0 = _mm512_setzero_ps(); + __m512 kernelSum1 = _mm512_setzero_ps(); + __m512 kernelSum2 = _mm512_setzero_ps(); + __m512 kernelSum3 = _mm512_setzero_ps(); + if (GEMMINT8_AVX512_E == realDst) { + kernelSum0 = _mm512_set1_ps(post->srcKernelSum[0]); + kernelSum1 = _mm512_set1_ps(post->srcKernelSum[1]); + kernelSum2 = _mm512_set1_ps(post->srcKernelSum[2]); + kernelSum3 = _mm512_set1_ps(post->srcKernelSum[3]); + } else { + kernelSum0 = _mm512_set1_ps(post->srcKernelSum[0]); + if (realDst > 1) { + kernelSum1 = _mm512_set1_ps(post->srcKernelSum[1]); + } + if (realDst > 2) { + kernelSum2 = _mm512_set1_ps(post->srcKernelSum[2]); + } + } + auto f128 = _mm512_set1_ps(128.f); + __m512 extrascale0 = _mm512_setzero_ps(); + __m512 extrascale1 = _mm512_setzero_ps(); + __m512 extrascale2 = _mm512_setzero_ps(); + __m512 extrascale3 = _mm512_setzero_ps(); + if (post->extraScale) { + if (GEMMINT8_AVX512_E == realDst) { + extrascale0 = _mm512_set1_ps(post->extraScale[0]); + extrascale1 = _mm512_set1_ps(post->extraScale[1]); + extrascale2 = _mm512_set1_ps(post->extraScale[2]); + extrascale3 = _mm512_set1_ps(post->extraScale[3]); + } else { + extrascale0 = _mm512_set1_ps(post->extraScale[0]); + if (realDst > 1) { + extrascale1 = _mm512_set1_ps(post->extraScale[1]); + } + if (realDst > 2) { + extrascale2 = _mm512_set1_ps(post->extraScale[2]); + } + } + } + int weight_step_Z = static_cast(blockNum * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H) / 2); // sizeof(int4_t) + int weight_step_Y = static_cast(GEMMINT8_AVX512_L * GEMMINT8_AVX512_H / 2); // sizeof(int4_t) + + if (realDst == GEMMINT8_AVX512_E) { + for (int dz = 0; dz < dzU; ++dz) { + auto weight_dz = weight + dz * weight_step_Z; + if (post->biasFloat) { + bias_dz = biasPtr + dz * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dz * PACK_UNIT * dzUnit; + } + const auto weightBias_dz = post->weightQuanBias + dz * PACK_UNIT * dzUnit; + float* scale_dz = (float*)post->scale + dz * PACK_UNIT * dzUnit; + auto dst_z = dst + dz * dst_step_tmp * dzUnit; + const auto src_x = src; + auto dst_x = dst_z; + __m512i D0 = _mm512_set1_epi32(0); + __m512i D1 = _mm512_set1_epi32(0); + __m512i D2 = _mm512_set1_epi32(0); + __m512i D3 = _mm512_set1_epi32(0); + + __m512i D4 = _mm512_set1_epi32(0); + __m512i D5 = _mm512_set1_epi32(0); + __m512i D6 = _mm512_set1_epi32(0); + __m512i D7 = _mm512_set1_epi32(0); + + __m512i D8 = _mm512_set1_epi32(0); + __m512i D9 = _mm512_set1_epi32(0); + __m512i D10 = _mm512_set1_epi32(0); + __m512i D11 = _mm512_set1_epi32(0); + + __m512i D12 = _mm512_set1_epi32(0); + __m512i D13 = _mm512_set1_epi32(0); + __m512i D14 = _mm512_set1_epi32(0); + __m512i D15 = _mm512_set1_epi32(0); + + + for (int sz = 0; sz < src_depth_quad; ++sz) { + const auto weight_sz = weight_dz + weight_step_Y * sz; + const auto src_z = (const float*)(src_x + sz * GEMMINT8_AVX512_E * GEMMINT8_AVX512_L); + + // int4->int8: total count=4*64(GEMMINT8_AVX512_L * GEMMINT8_AVX512_H) + // Load 4*64 int4 weight + auto w0_int4_64 = _mm512_loadu_si512(weight_sz); // 128xint4_t=64 byte + auto w1_int4_64 = _mm512_loadu_si512(weight_sz + 64); // 128xint4_t + // 256xint4_t->256xint8_t + auto w0 = _mm512_and_si512(mask, _mm512_srli_epi16(w0_int4_64, 4)); // 64xint8_t + auto w2 = _mm512_and_si512(mask, w0_int4_64); // 64xint8_t + auto w1 = _mm512_and_si512(mask, _mm512_srli_epi16(w1_int4_64, 4)); + auto w3 = _mm512_and_si512(mask, w1_int4_64); + + auto s0 = AVX512_BROADCAST_INT32(src_z + 0); + auto s1 = 
AVX512_BROADCAST_INT32(src_z + 1); + auto s2 = AVX512_BROADCAST_INT32(src_z + 2); + auto s3 = AVX512_BROADCAST_INT32(src_z + 3); + + D0 = _mm512_dpbusds_epi32(D0, s0, w0); + D1 = _mm512_dpbusds_epi32(D1, s1, w0); + D2 = _mm512_dpbusds_epi32(D2, s2, w0); + D3 = _mm512_dpbusds_epi32(D3, s3, w0); + + D4 = _mm512_dpbusds_epi32(D4, s0, w1); + D5 = _mm512_dpbusds_epi32(D5, s1, w1); + D6 = _mm512_dpbusds_epi32(D6, s2, w1); + D7 = _mm512_dpbusds_epi32(D7, s3, w1); + + D8 = _mm512_dpbusds_epi32(D8, s0, w2); + D9 = _mm512_dpbusds_epi32(D9, s1, w2); + D10 = _mm512_dpbusds_epi32(D10, s2, w2); + D11 = _mm512_dpbusds_epi32(D11, s3, w2); + + D12 = _mm512_dpbusds_epi32(D12, s0, w3); + D13 = _mm512_dpbusds_epi32(D13, s1, w3); + D14 = _mm512_dpbusds_epi32(D14, s2, w3); + D15 = _mm512_dpbusds_epi32(D15, s3, w3); + } + auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1, xy0_2, xy0_3; + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + DEQUANT_VALUE(2); + DEQUANT_VALUE(3); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + f2 = _mm512_mul_ps(f2, extrascale2); + f3 = _mm512_mul_ps(f3, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + f2 = _mm512_sub_ps(f2, extrabias2); + f3 = _mm512_sub_ps(f3, extrabias3); + } + } + + PLUS_TERM(0,1,2,3); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + SCALE_BIAS_VEC(2); + SCALE_BIAS_VEC(3); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(4); + DEQUANT_VALUE(5); + DEQUANT_VALUE(6); + DEQUANT_VALUE(7); + + if (post->extraScale) { // Batch quant + f4 = _mm512_mul_ps(f4, extrascale0); + f5 = _mm512_mul_ps(f5, extrascale1); + f6 = _mm512_mul_ps(f6, extrascale2); + f7 = _mm512_mul_ps(f7, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 1 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f4 = _mm512_sub_ps(f4, extrabias0); + f5 = _mm512_sub_ps(f5, extrabias1); + f6 = _mm512_sub_ps(f6, extrabias2); + f7 = _mm512_sub_ps(f7, extrabias3); + } + } + + PLUS_TERM(4,5,6,7); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT); + SCALE_BIAS_VEC(4); + SCALE_BIAS_VEC(5); + SCALE_BIAS_VEC(6); + SCALE_BIAS_VEC(7); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 2 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(8); + DEQUANT_VALUE(9); + DEQUANT_VALUE(10); + DEQUANT_VALUE(11); + + if (post->extraScale) { // Batch 
quant + f8 = _mm512_mul_ps(f8, extrascale0); + f9 = _mm512_mul_ps(f9, extrascale1); + f10 = _mm512_mul_ps(f10, extrascale2); + f11 = _mm512_mul_ps(f11, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 2 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f8 = _mm512_sub_ps(f8, extrabias0); + f9 = _mm512_sub_ps(f9, extrabias1); + f10 = _mm512_sub_ps(f10, extrabias2); + f11 = _mm512_sub_ps(f11, extrabias3); + } + } + + PLUS_TERM(8,9,10,11); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT); + SCALE_BIAS_VEC(8); + SCALE_BIAS_VEC(9); + SCALE_BIAS_VEC(10); + SCALE_BIAS_VEC(11); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(12); + DEQUANT_VALUE(13); + DEQUANT_VALUE(14); + DEQUANT_VALUE(15); + + if (post->extraScale) { // Batch quant + f12 = _mm512_mul_ps(f12, extrascale0); + f13 = _mm512_mul_ps(f13, extrascale1); + f14 = _mm512_mul_ps(f14, extrascale2); + f15 = _mm512_mul_ps(f15, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 3 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f12 = _mm512_sub_ps(f12, extrabias0); + f13 = _mm512_sub_ps(f13, extrabias1); + f14 = _mm512_sub_ps(f14, extrabias2); + f15 = _mm512_sub_ps(f15, extrabias3); + } + } + + PLUS_TERM(12,13,14,15); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT); + SCALE_BIAS_VEC(12); + SCALE_BIAS_VEC(13); + SCALE_BIAS_VEC(14); + SCALE_BIAS_VEC(15); + } + if (biasPtr == nullptr) { + auto destTmp = dst_x; + f0 = _mm512_add_ps(_mm512_loadu_ps((float*)destTmp), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16), f1); + f2 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 2), f2); + f3 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 3), f3); + destTmp += dst_step_tmp; + f4 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 0), f4); + f5 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 1), f5); + f6 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 2), f6); + f7 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 3), f7); + destTmp += dst_step_tmp; + f8 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 0), f8); + f9 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 1), f9); + f10 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 2), f10); + f11 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 3), f11); + destTmp += dst_step_tmp; + f12 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 0), f12); + f13 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 1), f13); + f14 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 2), f14); + f15 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 3), f15); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT(0,1,2,3); + POST_TREAT_FLOAT(4,5,6,7); + 
POST_TREAT_FLOAT(8,9,10,11); + POST_TREAT_FLOAT(12,13,14,15); + } + + _mm512_storeu_ps(((float*)dst_x), f0); + _mm512_storeu_ps(((float*)dst_x) + 16, f1); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f2); + _mm512_storeu_ps(((float*)dst_x) + 16 * 3, f3); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f4); + _mm512_storeu_ps(((float*)dst_x) + 16 * 1, f5); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f6); + _mm512_storeu_ps(((float*)dst_x) + 16 * 3, f7); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f8); + _mm512_storeu_ps(((float*)dst_x) + 16 * 1, f9); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f10); + _mm512_storeu_ps(((float*)dst_x) + 16 * 3, f11); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f12); + _mm512_storeu_ps(((float*)dst_x) + 16 * 1, f13); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f14); + _mm512_storeu_ps(((float*)dst_x) + 16 * 3, f15); + + } + auto weight_dz = weight + dzU * weight_step_Z; + if (biasPtr) { + bias_dz = biasPtr + dzU * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dzU * PACK_UNIT * dzUnit; + } + float* scale_dz = (float*)post->scale + dzU * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dzU * PACK_UNIT * dzUnit; + + auto dst_z = dst + dzU * dst_step_tmp * dzUnit; + const auto src_x = src; + auto dst_x = dst_z; + for (int i=0; i256xint8_t + auto w0 = _mm512_and_si512(mask, _mm512_srli_epi16(w0_int4_64, 4)); // 64xint8_t + + auto s0 = AVX512_BROADCAST_INT32(src_z + 0); + auto s1 = AVX512_BROADCAST_INT32(src_z + 1); + auto s2 = AVX512_BROADCAST_INT32(src_z + 2); + auto s3 = AVX512_BROADCAST_INT32(src_z + 3); + + D0 = _mm512_dpbusds_epi32(D0, s0, w0); + D1 = _mm512_dpbusds_epi32(D1, s1, w0); + D2 = _mm512_dpbusds_epi32(D2, s2, w0); + D3 = _mm512_dpbusds_epi32(D3, s3, w0); + } + + auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1, xy0_2, xy0_3; + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + DEQUANT_VALUE(2); + DEQUANT_VALUE(3); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + f2 = _mm512_mul_ps(f2, extrascale2); + f3 = _mm512_mul_ps(f3, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + f2 = _mm512_sub_ps(f2, extrabias2); + f3 = _mm512_sub_ps(f3, extrabias3); + } + } + + PLUS_TERM(0,1,2,3); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + SCALE_BIAS_VEC(2); + SCALE_BIAS_VEC(3); + } + + if (nullptr == biasPtr) { + f0 = _mm512_add_ps(_mm512_loadu_ps((float*)dst_x), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16), f1); + f2 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16 * 2), f2); + f3 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16 * 3), f3); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT(0,1,2,3); + } + _mm512_storeu_ps(((float*)dst_x), f0); + _mm512_storeu_ps(((float*)dst_x) + 16, f1); + _mm512_storeu_ps(((float*)dst_x) + 16 
* 2, f2); + _mm512_storeu_ps(((float*)dst_x) + 16 * 3, f3); + + dst_x += dst_step_tmp; + scale_dz += PACK_UNIT; + if (biasPtr) { + bias_dz += PACK_UNIT; + } + if (post->extraBias) { + extraB_dz += PACK_UNIT; + } + weight_dz += PACK_UNIT * GEMMINT8_AVX512_L; + } + return; + } + // e = 3 + if (realDst == 3) { + for (int dz = 0; dz < dzU; ++dz) { + auto weight_dz = weight + dz * weight_step_Z; + if (biasPtr) { + bias_dz = biasPtr + dz * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dz * PACK_UNIT * dzUnit; + } + float* scale_dz = (float*)post->scale + dz * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dz * PACK_UNIT * dzUnit; + auto dst_z = dst + dz * dst_step_tmp * dzUnit; + const auto src_x = src; + auto dst_x = dst_z; + __m512i D0 = _mm512_set1_epi32(0); + __m512i D1 = _mm512_set1_epi32(0); + __m512i D2 = _mm512_set1_epi32(0); + + __m512i D4 = _mm512_set1_epi32(0); + __m512i D5 = _mm512_set1_epi32(0); + __m512i D6 = _mm512_set1_epi32(0); + + __m512i D8 = _mm512_set1_epi32(0); + __m512i D9 = _mm512_set1_epi32(0); + __m512i D10 = _mm512_set1_epi32(0); + + __m512i D12 = _mm512_set1_epi32(0); + __m512i D13 = _mm512_set1_epi32(0); + __m512i D14 = _mm512_set1_epi32(0); + + + for (int sz = 0; sz < src_depth_quad; ++sz) { + const auto weight_sz = weight_dz + weight_step_Y * sz; + const auto src_z = (const float*)(src_x + sz * GEMMINT8_AVX512_E * GEMMINT8_AVX512_L); + // int4->int8: total count=4*64(GEMMINT8_AVX512_L * GEMMINT8_AVX512_H) + // Load 4*64 int4 weight + auto w0_int4_64 = _mm512_loadu_si512(weight_sz); // 128xint4_t=64 byte + auto w1_int4_64 = _mm512_loadu_si512(weight_sz + 64); // 128xint4_t + // 256xint4_t->256xint8_t + auto w0 = _mm512_and_si512(mask, _mm512_srli_epi16(w0_int4_64, 4)); // 64xint8_t + auto w2 = _mm512_and_si512(mask, w0_int4_64); // 64xint8_t + auto w1 = _mm512_and_si512(mask, _mm512_srli_epi16(w1_int4_64, 4)); + auto w3 = _mm512_and_si512(mask, w1_int4_64); + + auto s0 = AVX512_BROADCAST_INT32(src_z + 0); + auto s1 = AVX512_BROADCAST_INT32(src_z + 1); + auto s2 = AVX512_BROADCAST_INT32(src_z + 2); + + D0 = _mm512_dpbusds_epi32(D0, s0, w0); + D1 = _mm512_dpbusds_epi32(D1, s1, w0); + D2 = _mm512_dpbusds_epi32(D2, s2, w0); + + D4 = _mm512_dpbusds_epi32(D4, s0, w1); + D5 = _mm512_dpbusds_epi32(D5, s1, w1); + D6 = _mm512_dpbusds_epi32(D6, s2, w1); + + D8 = _mm512_dpbusds_epi32(D8, s0, w2); + D9 = _mm512_dpbusds_epi32(D9, s1, w2); + D10 = _mm512_dpbusds_epi32(D10, s2, w2); + + D12 = _mm512_dpbusds_epi32(D12, s0, w3); + D13 = _mm512_dpbusds_epi32(D13, s1, w3); + D14 = _mm512_dpbusds_epi32(D14, s2, w3); + } + + auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1, xy0_2; + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + DEQUANT_VALUE(2); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + f2 = _mm512_mul_ps(f2, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + f2 = _mm512_sub_ps(f2, extrabias2); + } + } + + PLUS_TERM_3(0,1,2); + if (nullptr != biasPtr) { + auto 
biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + SCALE_BIAS_VEC(2); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3; + DEQUANT_VALUE(4); + DEQUANT_VALUE(5); + DEQUANT_VALUE(6); + + if (post->extraScale) { // Batch quant + f4 = _mm512_mul_ps(f4, extrascale0); + f5 = _mm512_mul_ps(f5, extrascale1); + f6 = _mm512_mul_ps(f6, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 1 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f4 = _mm512_sub_ps(f4, extrabias0); + f5 = _mm512_sub_ps(f5, extrabias1); + f6 = _mm512_sub_ps(f6, extrabias2); + } + } + + PLUS_TERM_3(4,5,6); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT); + SCALE_BIAS_VEC(4); + SCALE_BIAS_VEC(5); + SCALE_BIAS_VEC(6); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 2 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3; + DEQUANT_VALUE(8); + DEQUANT_VALUE(9); + DEQUANT_VALUE(10); + + if (post->extraScale) { // Batch quant + f8 = _mm512_mul_ps(f8, extrascale0); + f9 = _mm512_mul_ps(f9, extrascale1); + f10 = _mm512_mul_ps(f10, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 2 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f8 = _mm512_sub_ps(f8, extrabias0); + f9 = _mm512_sub_ps(f9, extrabias1); + f10 = _mm512_sub_ps(f10, extrabias2); + } + } + + PLUS_TERM_3(8,9,10); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT); + SCALE_BIAS_VEC(8); + SCALE_BIAS_VEC(9); + SCALE_BIAS_VEC(10); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3; + DEQUANT_VALUE(12); + DEQUANT_VALUE(13); + DEQUANT_VALUE(14); + + if (post->extraScale) { // Batch quant + f12 = _mm512_mul_ps(f12, extrascale0); + f13 = _mm512_mul_ps(f13, extrascale1); + f14 = _mm512_mul_ps(f14, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 3 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f12 = _mm512_sub_ps(f12, extrabias0); + f13 = _mm512_sub_ps(f13, extrabias1); + f14 = _mm512_sub_ps(f14, extrabias2); + } + } + + PLUS_TERM_3(12,13,14); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT); + SCALE_BIAS_VEC(12); + SCALE_BIAS_VEC(13); + SCALE_BIAS_VEC(14); + } + + if (biasPtr == nullptr) { + auto dstTmp = dst_x; + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp)), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16), f1); + f2 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 2), f2); + dstTmp += dst_step_tmp; + f4 = 
_mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f4); + f5 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f5); + f6 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 2), f6); + dstTmp += dst_step_tmp; + f8 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f8); + f9 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f9); + f10 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 2), f10); + dstTmp += dst_step_tmp; + f12 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f12); + f13 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f13); + f14 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 2), f14); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_3(0,1,2); + POST_TREAT_FLOAT_3(4,5,6); + POST_TREAT_FLOAT_3(8,9,10); + POST_TREAT_FLOAT_3(12,13,14); + } + _mm512_storeu_ps(((float*)dst_x), f0); + _mm512_storeu_ps(((float*)dst_x) + 16, f1); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f2); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f4); + _mm512_storeu_ps(((float*)dst_x) + 16 * 1, f5); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f6); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f8); + _mm512_storeu_ps(((float*)dst_x) + 16 * 1, f9); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f10); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f12); + _mm512_storeu_ps(((float*)dst_x) + 16 * 1, f13); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f14); + + } + auto weight_dz = weight + dzU * weight_step_Z; + if (biasPtr) { + bias_dz = post->biasFloat + dzU * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dzU * PACK_UNIT * dzUnit; + } + float* scale_dz = (float*)post->scale + dzU * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dzU * PACK_UNIT * dzUnit; + + auto dst_z = dst + dzU * dst_step_tmp * dzUnit; + const auto src_x = src; + auto dst_x = dst_z; + for (int i=0; iextraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + f2 = _mm512_mul_ps(f2, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + f2 = _mm512_sub_ps(f2, extrabias2); + } + } + + PLUS_TERM_3(0,1,2); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + SCALE_BIAS_VEC(2); + } + + if (biasPtr == nullptr) { + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x)), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16), f1); + f2 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16 * 2), f2); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_3(0,1,2); + } + _mm512_storeu_ps(((float*)dst_x), f0); + _mm512_storeu_ps(((float*)dst_x) + 16, f1); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f2); + + dst_x += dst_step_tmp; + scale_dz += PACK_UNIT; + if (biasPtr) { + bias_dz += PACK_UNIT; + } + if (post->extraBias) { + extraB_dz += PACK_UNIT; + } + weight_dz += PACK_UNIT * GEMMINT8_AVX512_L; + } + return; + } + // e = 2 + if (realDst == 2) { + for (int dz = 0; dz < dzU; ++dz) { + auto weight_dz = weight + dz * weight_step_Z; + if (biasPtr) { + bias_dz = post->biasFloat + dz * PACK_UNIT * dzUnit; + 
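(Editorial note, not part of the patch: per output lane, the macro chain used throughout this kernel — DEQUANT_VALUE, SRCKERNELSUM_MUL_WEIGHTQUANBIAS_*, PLUS_TERM_*, SCALE_BIAS_VEC, plus the optional extraScale/extraBias and fp32minmax handling — reduces to the scalar computation sketched below. The helper name and parameter layout are hypothetical; they only mirror the vector code for readability.)

#include <algorithm>
#include <cstdint>

// Scalar mirror of one output lane of the AVX-512 post-processing (sketch only).
static float dequantOneLane(int32_t acc,             // int32 dot product from the dpbusds loop
                            float scale,             // post->scale, per output channel
                            float kernelSum,         // post->srcKernelSum, per source column
                            float weightQuanBias,    // post->weightQuanBias, per output channel
                            const float* bias,       // post->biasFloat (null on later blocks)
                            float dstPartial,        // value already stored in dst when bias is null
                            const float* extraScale, // post->extraScale (null if unused)
                            const float* extraBias,  // post->extraBias  (null if unused)
                            const float* fp32minmax) // post->fp32minmax (null if unused)
{
    float f = float(acc) * scale;                        // DEQUANT_VALUE
    if (extraScale) {                                    // batch quant
        f *= *extraScale;
        if (extraBias && bias) {
            f -= 128.f * (*extraBias) * (*extraScale);   // mirrors the f128 * extraB_dz * extrascale term above
        }
    }
    f += kernelSum * weightQuanBias;                     // SRCKERNELSUM_MUL_WEIGHTQUANBIAS + PLUS_TERM
    f += bias ? *bias : dstPartial;                      // SCALE_BIAS_VEC, or accumulate into dst
    if (fp32minmax) {
        f = std::min(std::max(f, fp32minmax[0]), fp32minmax[1]);  // POST_TREAT_FLOAT
    }
    return f;
}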
} + if (post->extraBias) { + extraB_dz = post->extraBias + dz * PACK_UNIT * dzUnit; + } + float* scale_dz = (float*)post->scale + dz * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dz * PACK_UNIT * dzUnit; + auto dst_z = dst + dz * dst_step_tmp * dzUnit; + const auto src_x = src; + auto dst_x = dst_z; + __m512i D0 = _mm512_set1_epi32(0); + __m512i D1 = _mm512_set1_epi32(0); + + __m512i D4 = _mm512_set1_epi32(0); + __m512i D5 = _mm512_set1_epi32(0); + + __m512i D8 = _mm512_set1_epi32(0); + __m512i D9 = _mm512_set1_epi32(0); + + __m512i D12 = _mm512_set1_epi32(0); + __m512i D13 = _mm512_set1_epi32(0); + + + for (int sz = 0; sz < src_depth_quad; ++sz) { + const auto weight_sz = weight_dz + weight_step_Y * sz; + const auto src_z = (const float*)(src_x + sz * GEMMINT8_AVX512_E * GEMMINT8_AVX512_L); + // int4->int8: total count=4*64(GEMMINT8_AVX512_L * GEMMINT8_AVX512_H) + // Load 4*64 int4 weight + auto w0_int4_64 = _mm512_loadu_si512(weight_sz); // 128xint4_t=64 byte + auto w1_int4_64 = _mm512_loadu_si512(weight_sz + 64); // 128xint4_t + // 256xint4_t->256xint8_t + auto w0 = _mm512_and_si512(mask, _mm512_srli_epi16(w0_int4_64, 4)); // 64xint8_t + auto w2 = _mm512_and_si512(mask, w0_int4_64); // 64xint8_t + auto w1 = _mm512_and_si512(mask, _mm512_srli_epi16(w1_int4_64, 4)); + auto w3 = _mm512_and_si512(mask, w1_int4_64); + + auto s0 = AVX512_BROADCAST_INT32(src_z + 0); + auto s1 = AVX512_BROADCAST_INT32(src_z + 1); + + D0 = _mm512_dpbusds_epi32(D0, s0, w0); + D1 = _mm512_dpbusds_epi32(D1, s1, w0); + + D4 = _mm512_dpbusds_epi32(D4, s0, w1); + D5 = _mm512_dpbusds_epi32(D5, s1, w1); + + D8 = _mm512_dpbusds_epi32(D8, s0, w2); + D9 = _mm512_dpbusds_epi32(D9, s1, w2); + + D12 = _mm512_dpbusds_epi32(D12, s0, w3); + D13 = _mm512_dpbusds_epi32(D13, s1, w3); + } + + auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1; + + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + } + } + + PLUS_TERM_2(0,1); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(4); + DEQUANT_VALUE(5); + + if (post->extraScale) { // Batch quant + f4 = _mm512_mul_ps(f4, extrascale0); + f5 = _mm512_mul_ps(f5, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 1 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f4 = _mm512_sub_ps(f4, extrabias0); + f5 = _mm512_sub_ps(f5, extrabias1); + } + } + + PLUS_TERM_2(4,5); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT); + SCALE_BIAS_VEC(4); + SCALE_BIAS_VEC(5); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 2 * 
PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT); + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(8); + DEQUANT_VALUE(9); + + if (post->extraScale) { // Batch quant + f8 = _mm512_mul_ps(f8, extrascale0); + f9 = _mm512_mul_ps(f9, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 2 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f8 = _mm512_sub_ps(f8, extrabias0); + f9 = _mm512_sub_ps(f9, extrabias1); + } + } + + PLUS_TERM_2(8,9); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT); + SCALE_BIAS_VEC(8); + SCALE_BIAS_VEC(9); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT); + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(12); + DEQUANT_VALUE(13); + + if (post->extraScale) { // Batch quant + f12 = _mm512_mul_ps(f12, extrascale0); + f13 = _mm512_mul_ps(f13, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 3 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f12 = _mm512_sub_ps(f12, extrabias0); + f13 = _mm512_sub_ps(f13, extrabias1); + } + } + + PLUS_TERM_2(12,13); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT); + SCALE_BIAS_VEC(12); + SCALE_BIAS_VEC(13); + } + + if (nullptr == biasPtr) { + auto dstTmp = dst_x; + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp)), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16), f1); + dstTmp += dst_step_tmp; + f4 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f4); + f5 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f5); + dstTmp += dst_step_tmp; + f8 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f8); + f9 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f9); + dstTmp += dst_step_tmp; + f12 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f12); + f13 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f13); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_2(0,1); + POST_TREAT_FLOAT_2(4,5); + POST_TREAT_FLOAT_2(8,9); + POST_TREAT_FLOAT_2(12,13); + } + _mm512_storeu_ps(((float*)dst_x), f0); + _mm512_storeu_ps(((float*)dst_x) + 16, f1); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f4); + _mm512_storeu_ps(((float*)dst_x) + 16 * 1, f5); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f8); + _mm512_storeu_ps(((float*)dst_x) + 16 * 1, f9); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f12); + _mm512_storeu_ps(((float*)dst_x) + 16 * 1, f13); + + } + auto weight_dz = weight + dzU * weight_step_Z; + if (biasPtr) { + bias_dz = post->biasFloat + dzU * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dzU * PACK_UNIT * dzUnit; + } + float* scale_dz = (float*)post->scale + dzU * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dzU * PACK_UNIT * dzUnit; + + auto dst_z = dst + dzU * dst_step_tmp * dzUnit; + const auto src_x = src; + auto dst_x = dst_z; + for (int i=0; iextraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + 
auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + } + } + + PLUS_TERM_2(0,1); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + } + + if (nullptr == biasPtr) { + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x)), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16), f1); + } + POST_TREAT_FLOAT_2(0,1); + _mm512_storeu_ps(((float*)dst_x), f0); + _mm512_storeu_ps(((float*)dst_x) + 16, f1); + + dst_x += dst_step_tmp; + scale_dz += PACK_UNIT; + if (biasPtr) { + bias_dz += PACK_UNIT; + } + if (post->extraBias) { + extraB_dz += PACK_UNIT; + } + weight_dz += PACK_UNIT * GEMMINT8_AVX512_L; + } + return; + } + if (realDst == 1) { + for (int dz = 0; dz < dzU; ++dz) { + auto weight_dz = weight + dz * weight_step_Z; + if (biasPtr) { + bias_dz = post->biasFloat + dz * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dz * PACK_UNIT * dzUnit; + } + float* scale_dz = (float*)post->scale + dz * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dz * PACK_UNIT * dzUnit; + auto dst_z = dst + dz * dst_step_tmp * dzUnit; + const auto src_x = src; + auto dst_x = dst_z; + __m512i D0 = _mm512_set1_epi32(0); + + __m512i D4 = _mm512_set1_epi32(0); + + __m512i D8 = _mm512_set1_epi32(0); + + __m512i D12 = _mm512_set1_epi32(0); + + for (int sz = 0; sz < src_depth_quad; ++sz) { + const auto weight_sz = weight_dz + weight_step_Y * sz; + const auto src_z = (const float*)(src_x + sz * GEMMINT8_AVX512_E * GEMMINT8_AVX512_L); + // int4->int8: total count=4*64(GEMMINT8_AVX512_L * GEMMINT8_AVX512_H) + // Load 4*64 int4 weight + auto w0_int4_64 = _mm512_loadu_si512(weight_sz); // 128xint4_t=64 byte + auto w1_int4_64 = _mm512_loadu_si512(weight_sz + 64); // 128xint4_t + // 256xint4_t->256xint8_t + auto w0 = _mm512_and_si512(mask, _mm512_srli_epi16(w0_int4_64, 4)); // 64xint8_t + auto w2 = _mm512_and_si512(mask, w0_int4_64); // 64xint8_t + auto w1 = _mm512_and_si512(mask, _mm512_srli_epi16(w1_int4_64, 4)); + auto w3 = _mm512_and_si512(mask, w1_int4_64); + + auto s0 = AVX512_BROADCAST_INT32(src_z + 0); + + D0 = _mm512_dpbusds_epi32(D0, s0, w0); + + D4 = _mm512_dpbusds_epi32(D4, s0, w1); + + D8 = _mm512_dpbusds_epi32(D8, s0, w2); + + D12 = _mm512_dpbusds_epi32(D12, s0, w3); + } + + auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0; + + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1; + DEQUANT_VALUE(0); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f0 = _mm512_sub_ps(f0, extrabias0); + } + } + + PLUS_TERM_1(0); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1; + DEQUANT_VALUE(4); + + if (post->extraScale) { // Batch quant + f4 = _mm512_mul_ps(f4, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = 
_mm512_loadu_ps(extraB_dz + 1 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f4 = _mm512_sub_ps(f4, extrabias0); + } + } + + PLUS_TERM_1(4); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT); + SCALE_BIAS_VEC(4); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 2 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1; + DEQUANT_VALUE(8); + + if (post->extraScale) { // Batch quant + f8 = _mm512_mul_ps(f8, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 2 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f8 = _mm512_sub_ps(f8, extrabias0); + } + } + + PLUS_TERM_1(8); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT); + SCALE_BIAS_VEC(8); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1; + DEQUANT_VALUE(12); + + if (post->extraScale) { // Batch quant + f12 = _mm512_mul_ps(f12, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 3 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f12 = _mm512_sub_ps(f12, extrabias0); + } + } + + PLUS_TERM_1(12); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT); + SCALE_BIAS_VEC(12); + } + + if (nullptr == biasPtr) { + auto dstTemp = dst_x; + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTemp)), f0); + dstTemp += dst_step_tmp; + f4 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTemp) + 16 * 0), f4); + dstTemp += dst_step_tmp; + f8 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTemp) + 16 * 0), f8); + dstTemp += dst_step_tmp; + f12 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTemp) + 16 * 0), f12); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_1(0); + POST_TREAT_FLOAT_1(4); + POST_TREAT_FLOAT_1(8); + POST_TREAT_FLOAT_1(12); + } + _mm512_storeu_ps(((float*)dst_x), f0); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f4); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f8); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f12); + + } + auto weight_dz = weight + dzU * weight_step_Z; + if (biasPtr) { + bias_dz = post->biasFloat + dzU * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dzU * PACK_UNIT * dzUnit; + } + float* scale_dz = (float*)post->scale + dzU * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dzU * PACK_UNIT * dzUnit; + + auto dst_z = dst + dzU * dst_step_tmp * dzUnit; + const auto src_x = src; + auto dst_x = dst_z; + for (int i=0; iextraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f0 = _mm512_sub_ps(f0, extrabias0); + } + } + + PLUS_TERM_1(0); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + } + + if (nullptr == biasPtr) { + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x)), f0); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_1(0); + } 
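(Editorial note, not part of the patch: the `int4->int8` comments above all rely on one shift-and-mask pattern to split packed 4-bit weights into bytes. A minimal, self-contained sketch of that pattern is below; the helper name is illustrative, but the intrinsics are the same ones the kernel uses and require AVX-512F/BW.)

#include <immintrin.h>
#include <cstdint>

// Split 64 packed bytes (128 unsigned int4 weights) into two byte vectors: one holding
// the high nibbles, one the low nibbles, exactly as the kernel does with
// _mm512_srli_epi16 plus a 0x0f byte mask (the shift leaks bits across bytes, the mask removes them).
static inline void unpackInt4x128(const int8_t* packed, __m512i* hi, __m512i* lo) {
    const __m512i mask = _mm512_set1_epi8(0x0f);
    __m512i w = _mm512_loadu_si512((const void*)packed);    // 64 bytes = 128 x int4
    *hi = _mm512_and_si512(mask, _mm512_srli_epi16(w, 4));   // high nibble of each byte
    *lo = _mm512_and_si512(mask, w);                         // low nibble of each byte
}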
+ _mm512_storeu_ps(((float*)dst_x), f0); + dst_x += dst_step_tmp; + scale_dz += PACK_UNIT; + if (biasPtr) { + bias_dz += PACK_UNIT; + } + if (post->extraBias) { + extraB_dz += PACK_UNIT; + } + weight_dz += PACK_UNIT * GEMMINT8_AVX512_L; } return; } diff --git a/source/backend/cpu/x86_x64/avx512/Matmul_4_4_64.inl b/source/backend/cpu/x86_x64/avx512/Matmul_4_4_64.inl index aae677b09..5addec946 100644 --- a/source/backend/cpu/x86_x64/avx512/Matmul_4_4_64.inl +++ b/source/backend/cpu/x86_x64/avx512/Matmul_4_4_64.inl @@ -1,10 +1,13 @@ #define GEMMINT8_AVX512_H GEMMINT8_AVX512_H_NOVNNI #define AVX512_BROADCAST_INT32(src) _mm512_castps_si512(_mm512_broadcastss_ps(_mm_load_ss(src))) + +#define DEQUANT_VALUE(N) \ + auto f##N = _mm512_cvtepi32_ps(D##N);\ + f##N = _mm512_mul_ps(f##N, scaleValue); + #define SCALE_BIAS_VEC(N) \ - auto d##N = _mm512_add_epi32(D##N, biasValue);\ - auto f##N = _mm512_cvtepi32_ps(d##N);\ - f##N = _mm512_mul_ps(f##N, scaleValue); + f##N = _mm512_add_ps(f##N, biasValue); #define POSTTREAT(N, O) \ f##N = _mm512_min_ps(f##N, maxValue);\ @@ -12,13 +15,77 @@ auto m##N = _mm512_cmp_ps_mask(f##N, zero512, 1);\ auto b##N = _mm512_mask_blend_ps(m##N, plus, minus);\ f##N = _mm512_add_ps(f##N, b##N);\ - d##N = _mm512_cvtps_epi32(_mm512_roundscale_ps(f##N, 3));\ + auto d##N = _mm512_cvtps_epi32(_mm512_roundscale_ps(f##N, 3));\ auto hd##N = _mm512_cvtsepi32_epi16(d##N); hd##N = _mm256_add_epi16(hd##N, offset);\ auto h0##N = _mm256_extracti128_si256(hd##N, 0);\ auto h1##N = _mm256_extracti128_si256(hd##N, 1);\ h0##N = _mm_packus_epi16(h0##N, h1##N);\ _mm_storeu_si128((__m128i*)dst_x + O, h0##N); +#define POST_TREAT_FLOAT(N,M,K,V) \ + f##N = _mm512_min_ps(f##N, fp32max);\ + f##N = _mm512_max_ps(f##N, fp32min);\ + f##M = _mm512_min_ps(f##M, fp32max);\ + f##M = _mm512_max_ps(f##M, fp32min);\ + f##K = _mm512_min_ps(f##K, fp32max);\ + f##K = _mm512_max_ps(f##K, fp32min);\ + f##V = _mm512_min_ps(f##V, fp32max);\ + f##V = _mm512_max_ps(f##V, fp32min); + +#define SRCKERNELSUM_MUL_WEIGHTQUANBIAS \ + xy0_0 = _mm512_mul_ps(kernelSum0, weightBiasValue);\ + xy0_1 = _mm512_mul_ps(kernelSum1, weightBiasValue);\ + xy0_2 = _mm512_mul_ps(kernelSum2, weightBiasValue);\ + xy0_3 = _mm512_mul_ps(kernelSum3, weightBiasValue); + +#define PLUS_TERM(N,M,K,V) \ + f##N = _mm512_add_ps(f##N, xy0_0);\ + f##M = _mm512_add_ps(f##M, xy0_1);\ + f##K = _mm512_add_ps(f##K, xy0_2);\ + f##V = _mm512_add_ps(f##V, xy0_3); + +#define POST_TREAT_FLOAT_3(N,M,K) \ + f##N = _mm512_min_ps(f##N, fp32max);\ + f##N = _mm512_max_ps(f##N, fp32min);\ + f##M = _mm512_min_ps(f##M, fp32max);\ + f##M = _mm512_max_ps(f##M, fp32min);\ + f##K = _mm512_min_ps(f##K, fp32max);\ + f##K = _mm512_max_ps(f##K, fp32min); + +#define SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3 \ + xy0_0 = _mm512_mul_ps(kernelSum0, weightBiasValue);\ + xy0_1 = _mm512_mul_ps(kernelSum1, weightBiasValue);\ + xy0_2 = _mm512_mul_ps(kernelSum2, weightBiasValue); + +#define PLUS_TERM_3(N,M,K) \ + f##N = _mm512_add_ps(f##N, xy0_0);\ + f##M = _mm512_add_ps(f##M, xy0_1);\ + f##K = _mm512_add_ps(f##K, xy0_2); + +#define POST_TREAT_FLOAT_2(N,M) \ + f##N = _mm512_min_ps(f##N, fp32max);\ + f##N = _mm512_max_ps(f##N, fp32min);\ + f##M = _mm512_min_ps(f##M, fp32max);\ + f##M = _mm512_max_ps(f##M, fp32min); + +#define SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2 \ + xy0_0 = _mm512_mul_ps(kernelSum0, weightBiasValue);\ + xy0_1 = _mm512_mul_ps(kernelSum1, weightBiasValue); + +#define PLUS_TERM_2(N,M) \ + f##N = _mm512_add_ps(f##N, xy0_0);\ + f##M = _mm512_add_ps(f##M, xy0_1); + +#define POST_TREAT_FLOAT_1(N) \ 
+ f##N = _mm512_min_ps(f##N, fp32max);\ + f##N = _mm512_max_ps(f##N, fp32min); + +#define SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1 \ + xy0_0 = _mm512_mul_ps(kernelSum0, weightBiasValue); + +#define PLUS_TERM_1(N) \ + f##N = _mm512_add_ps(f##N, xy0_0); + // GemmInt8 with NO VNNI void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) { @@ -33,10 +100,71 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s int dzU = dst_depth_quad / dzUnit; int dzR = dst_depth_quad % dzUnit; auto one = _mm512_set1_epi16(1); + __m512 fp32min, fp32max; + if (0 == post->useInt8 && post->fp32minmax) { + fp32min = _mm512_set1_ps((post->fp32minmax)[0]); + fp32max = _mm512_set1_ps((post->fp32minmax)[1]); + } + auto blockNum = post->blockNum; + const float* biasPtr = nullptr; + const float* bias_dz = nullptr; + const float* extraB_dz = nullptr; + if (post->biasFloat) { + biasPtr = post->biasFloat; + } + + int weightZStride = blockNum * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H); + + auto srcKernelSumPtr = post->srcKernelSum; + __m512 kernelSum0 = _mm512_setzero_ps(); + __m512 kernelSum1 = _mm512_setzero_ps(); + __m512 kernelSum2 = _mm512_setzero_ps(); + __m512 kernelSum3 = _mm512_setzero_ps(); + if (GEMMINT8_AVX512_E == realDst) { + kernelSum0 = _mm512_set1_ps(post->srcKernelSum[0]); + kernelSum1 = _mm512_set1_ps(post->srcKernelSum[1]); + kernelSum2 = _mm512_set1_ps(post->srcKernelSum[2]); + kernelSum3 = _mm512_set1_ps(post->srcKernelSum[3]); + } else { + kernelSum0 = _mm512_set1_ps(post->srcKernelSum[0]); + if (realDst > 1) { + kernelSum1 = _mm512_set1_ps(post->srcKernelSum[1]); + } + if (realDst > 2) { + kernelSum2 = _mm512_set1_ps(post->srcKernelSum[2]); + } + } + auto f128 = _mm512_set1_ps(128.f); + __m512 extrascale0 = _mm512_setzero_ps(); + __m512 extrascale1 = _mm512_setzero_ps(); + __m512 extrascale2 = _mm512_setzero_ps(); + __m512 extrascale3 = _mm512_setzero_ps(); + if (post->extraScale) { + if (GEMMINT8_AVX512_E == realDst) { + extrascale0 = _mm512_set1_ps(post->extraScale[0]); + extrascale1 = _mm512_set1_ps(post->extraScale[1]); + extrascale2 = _mm512_set1_ps(post->extraScale[2]); + extrascale3 = _mm512_set1_ps(post->extraScale[3]); + } else { + extrascale0 = _mm512_set1_ps(post->extraScale[0]); + if (realDst > 1) { + extrascale1 = _mm512_set1_ps(post->extraScale[1]); + } + if (realDst > 2) { + extrascale2 = _mm512_set1_ps(post->extraScale[2]); + } + } + } if (realDst == GEMMINT8_AVX512_E) { for (int dz = 0; dz < dzU; ++dz) { - auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H); - auto bias_dz = (int32_t*)post->bias + dz * PACK_UNIT * dzUnit; + auto weight_dz = weight + dz * weightZStride; + if (post->biasFloat) { + bias_dz = biasPtr + dz * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dz * PACK_UNIT * dzUnit; + } + const auto weightBias_dz = post->weightQuanBias + dz * PACK_UNIT * dzUnit; float* scale_dz = (float*)post->scale + dz * PACK_UNIT * dzUnit; auto dst_z = dst + dz * dst_step_tmp * dzUnit; const auto src_x = src; @@ -66,9 +194,9 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s const auto weight_sz = weight_dz + (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H) * sz; const auto src_z = (const float*)(src_x + sz * GEMMINT8_AVX512_E * GEMMINT8_AVX512_L); auto w0 = _mm512_loadu_si512(weight_sz); - auto w1 = 
_mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_E); - auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_E); - auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_E); + auto w1 = _mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_L); + auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_L); + auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_L); auto s0 = AVX512_BROADCAST_INT32(src_z + 0); auto s1 = AVX512_BROADCAST_INT32(src_z + 1); @@ -95,37 +223,185 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s D14 = mnn_mm512_dpbusds_epi32(D14, s2, w3); D15 = mnn_mm512_dpbusds_epi32(D15, s3, w3); } - - auto biasValue = _mm512_loadu_si512(bias_dz); auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1, xy0_2, xy0_3; + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + DEQUANT_VALUE(2); + DEQUANT_VALUE(3); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + f2 = _mm512_mul_ps(f2, extrascale2); + f3 = _mm512_mul_ps(f3, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + f2 = _mm512_sub_ps(f2, extrabias2); + f3 = _mm512_sub_ps(f3, extrabias3); + } + } - SCALE_BIAS_VEC(0); - SCALE_BIAS_VEC(1); - SCALE_BIAS_VEC(2); - SCALE_BIAS_VEC(3); + PLUS_TERM(0,1,2,3); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + SCALE_BIAS_VEC(2); + SCALE_BIAS_VEC(3); + } - biasValue = _mm512_loadu_si512(bias_dz + 1 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT); - SCALE_BIAS_VEC(4); - SCALE_BIAS_VEC(5); - SCALE_BIAS_VEC(6); - SCALE_BIAS_VEC(7); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(4); + DEQUANT_VALUE(5); + DEQUANT_VALUE(6); + DEQUANT_VALUE(7); + + if (post->extraScale) { // Batch quant + f4 = _mm512_mul_ps(f4, extrascale0); + f5 = _mm512_mul_ps(f5, extrascale1); + f6 = _mm512_mul_ps(f6, extrascale2); + f7 = _mm512_mul_ps(f7, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 1 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f4 = _mm512_sub_ps(f4, extrabias0); + f5 = _mm512_sub_ps(f5, extrabias1); + f6 = _mm512_sub_ps(f6, extrabias2); + f7 = _mm512_sub_ps(f7, extrabias3); + } + } + + PLUS_TERM(4,5,6,7); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT); + SCALE_BIAS_VEC(4); + SCALE_BIAS_VEC(5); + SCALE_BIAS_VEC(6); + SCALE_BIAS_VEC(7); + } - biasValue = _mm512_loadu_si512(bias_dz + 2 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 2 * 
PACK_UNIT); - SCALE_BIAS_VEC(8); - SCALE_BIAS_VEC(9); - SCALE_BIAS_VEC(10); - SCALE_BIAS_VEC(11); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(8); + DEQUANT_VALUE(9); + DEQUANT_VALUE(10); + DEQUANT_VALUE(11); + + if (post->extraScale) { // Batch quant + f8 = _mm512_mul_ps(f8, extrascale0); + f9 = _mm512_mul_ps(f9, extrascale1); + f10 = _mm512_mul_ps(f10, extrascale2); + f11 = _mm512_mul_ps(f11, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 2 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f8 = _mm512_sub_ps(f8, extrabias0); + f9 = _mm512_sub_ps(f9, extrabias1); + f10 = _mm512_sub_ps(f10, extrabias2); + f11 = _mm512_sub_ps(f11, extrabias3); + } + } + + PLUS_TERM(8,9,10,11); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT); + SCALE_BIAS_VEC(8); + SCALE_BIAS_VEC(9); + SCALE_BIAS_VEC(10); + SCALE_BIAS_VEC(11); + } - biasValue = _mm512_loadu_si512(bias_dz + 3 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT); - SCALE_BIAS_VEC(12); - SCALE_BIAS_VEC(13); - SCALE_BIAS_VEC(14); - SCALE_BIAS_VEC(15); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(12); + DEQUANT_VALUE(13); + DEQUANT_VALUE(14); + DEQUANT_VALUE(15); + + if (post->extraScale) { // Batch quant + f12 = _mm512_mul_ps(f12, extrascale0); + f13 = _mm512_mul_ps(f13, extrascale1); + f14 = _mm512_mul_ps(f14, extrascale2); + f15 = _mm512_mul_ps(f15, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 3 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f12 = _mm512_sub_ps(f12, extrabias0); + f13 = _mm512_sub_ps(f13, extrabias1); + f14 = _mm512_sub_ps(f14, extrabias2); + f15 = _mm512_sub_ps(f15, extrabias3); + } + } + + PLUS_TERM(12,13,14,15); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT); + SCALE_BIAS_VEC(12); + SCALE_BIAS_VEC(13); + SCALE_BIAS_VEC(14); + SCALE_BIAS_VEC(15); + } if (post->useInt8 == 0) { + if (biasPtr == nullptr) { + auto destTmp = dst_x; + f0 = _mm512_add_ps(_mm512_loadu_ps((float*)destTmp), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16), f1); + f2 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 2), f2); + f3 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 3), f3); + destTmp += dst_step_tmp; + f4 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 0), f4); + f5 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 1), f5); + f6 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 2), f6); + f7 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 3), f7); + destTmp += dst_step_tmp; + f8 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 0), f8); + f9 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 1), f9); + f10 = 
_mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 2), f10); + f11 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 3), f11); + destTmp += dst_step_tmp; + f12 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 0), f12); + f13 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 1), f13); + f14 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 2), f14); + f15 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 3), f15); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT(0,1,2,3); + POST_TREAT_FLOAT(4,5,6,7); + POST_TREAT_FLOAT(8,9,10,11); + POST_TREAT_FLOAT(12,13,14,15); + } + _mm512_storeu_ps(((float*)dst_x), f0); _mm512_storeu_ps(((float*)dst_x) + 16, f1); _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f2); @@ -170,9 +446,15 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s POSTTREAT(15, 3); } } - auto weight_dz = weight + dzU * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H); - auto bias_dz = (int32_t*)post->bias + dzU * PACK_UNIT * dzUnit; + auto weight_dz = weight + dzU * weightZStride; + if (biasPtr) { + bias_dz = biasPtr + dzU * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dzU * PACK_UNIT * dzUnit; + } float* scale_dz = (float*)post->scale + dzU * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dzU * PACK_UNIT * dzUnit; auto dst_z = dst + dzU * dst_step_tmp * dzUnit; const auto src_x = src; @@ -199,15 +481,54 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s D3 = mnn_mm512_dpbusds_epi32(D3, s3, w0); } - auto biasValue = _mm512_loadu_si512(bias_dz); auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1, xy0_2, xy0_3; + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + DEQUANT_VALUE(2); + DEQUANT_VALUE(3); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + f2 = _mm512_mul_ps(f2, extrascale2); + f3 = _mm512_mul_ps(f3, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + f2 = _mm512_sub_ps(f2, extrabias2); + f3 = _mm512_sub_ps(f3, extrabias3); + } + } - SCALE_BIAS_VEC(0); - SCALE_BIAS_VEC(1); - SCALE_BIAS_VEC(2); - SCALE_BIAS_VEC(3); + PLUS_TERM(0,1,2,3); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + SCALE_BIAS_VEC(2); + SCALE_BIAS_VEC(3); + } if (post->useInt8 == 0) { + if (nullptr == biasPtr) { + f0 = _mm512_add_ps(_mm512_loadu_ps((float*)dst_x), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16), f1); + f2 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16 * 2), f2); + f3 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16 * 3), f3); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT(0,1,2,3); + } _mm512_storeu_ps(((float*)dst_x), f0); _mm512_storeu_ps(((float*)dst_x) + 16, f1); _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f2); @@ -220,17 +541,28 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s } dst_x += 
dst_step_tmp; scale_dz += PACK_UNIT; - bias_dz += PACK_UNIT; - weight_dz += PACK_UNIT * GEMMINT8_AVX512_E; + if (biasPtr) { + bias_dz += PACK_UNIT; + } + if (post->extraBias) { + extraB_dz += PACK_UNIT; + } + weight_dz += PACK_UNIT * GEMMINT8_AVX512_L; } return; } // e = 3 if (realDst == 3) { for (int dz = 0; dz < dzU; ++dz) { - auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H); - auto bias_dz = (int32_t*)post->bias + dz * PACK_UNIT * dzUnit; + auto weight_dz = weight + dz * weightZStride; + if (biasPtr) { + bias_dz = biasPtr + dz * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dz * PACK_UNIT * dzUnit; + } float* scale_dz = (float*)post->scale + dz * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dz * PACK_UNIT * dzUnit; auto dst_z = dst + dz * dst_step_tmp * dzUnit; const auto src_x = src; auto dst_x = dst_z; @@ -255,9 +587,9 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s const auto weight_sz = weight_dz + (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H) * sz; const auto src_z = (const float*)(src_x + sz * GEMMINT8_AVX512_E * GEMMINT8_AVX512_L); auto w0 = _mm512_loadu_si512(weight_sz); - auto w1 = _mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_E); - auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_E); - auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_E); + auto w1 = _mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_L); + auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_L); + auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_L); auto s0 = AVX512_BROADCAST_INT32(src_z + 0); auto s1 = AVX512_BROADCAST_INT32(src_z + 1); @@ -280,32 +612,160 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s D14 = mnn_mm512_dpbusds_epi32(D14, s2, w3); } - auto biasValue = _mm512_loadu_si512(bias_dz); auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1, xy0_2; + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + DEQUANT_VALUE(2); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + f2 = _mm512_mul_ps(f2, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + f2 = _mm512_sub_ps(f2, extrabias2); + } + } - SCALE_BIAS_VEC(0); - SCALE_BIAS_VEC(1); - SCALE_BIAS_VEC(2); + PLUS_TERM_3(0,1,2); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + SCALE_BIAS_VEC(2); + } - biasValue = _mm512_loadu_si512(bias_dz + 1 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT); - SCALE_BIAS_VEC(4); - SCALE_BIAS_VEC(5); - SCALE_BIAS_VEC(6); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3; + DEQUANT_VALUE(4); + DEQUANT_VALUE(5); + DEQUANT_VALUE(6); + + if (post->extraScale) { // Batch quant + f4 = _mm512_mul_ps(f4, extrascale0); + f5 = 
_mm512_mul_ps(f5, extrascale1); + f6 = _mm512_mul_ps(f6, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 1 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f4 = _mm512_sub_ps(f4, extrabias0); + f5 = _mm512_sub_ps(f5, extrabias1); + f6 = _mm512_sub_ps(f6, extrabias2); + } + } + + PLUS_TERM_3(4,5,6); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT); + SCALE_BIAS_VEC(4); + SCALE_BIAS_VEC(5); + SCALE_BIAS_VEC(6); + } - biasValue = _mm512_loadu_si512(bias_dz + 2 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 2 * PACK_UNIT); - SCALE_BIAS_VEC(8); - SCALE_BIAS_VEC(9); - SCALE_BIAS_VEC(10); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3; + DEQUANT_VALUE(8); + DEQUANT_VALUE(9); + DEQUANT_VALUE(10); + + if (post->extraScale) { // Batch quant + f8 = _mm512_mul_ps(f8, extrascale0); + f9 = _mm512_mul_ps(f9, extrascale1); + f10 = _mm512_mul_ps(f10, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 2 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f8 = _mm512_sub_ps(f8, extrabias0); + f9 = _mm512_sub_ps(f9, extrabias1); + f10 = _mm512_sub_ps(f10, extrabias2); + } + } + + PLUS_TERM_3(8,9,10); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT); + SCALE_BIAS_VEC(8); + SCALE_BIAS_VEC(9); + SCALE_BIAS_VEC(10); + } - biasValue = _mm512_loadu_si512(bias_dz + 3 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT); - SCALE_BIAS_VEC(12); - SCALE_BIAS_VEC(13); - SCALE_BIAS_VEC(14); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3; + DEQUANT_VALUE(12); + DEQUANT_VALUE(13); + DEQUANT_VALUE(14); + + if (post->extraScale) { // Batch quant + f12 = _mm512_mul_ps(f12, extrascale0); + f13 = _mm512_mul_ps(f13, extrascale1); + f14 = _mm512_mul_ps(f14, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 3 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f12 = _mm512_sub_ps(f12, extrabias0); + f13 = _mm512_sub_ps(f13, extrabias1); + f14 = _mm512_sub_ps(f14, extrabias2); + } + } + + PLUS_TERM_3(12,13,14); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT); + SCALE_BIAS_VEC(12); + SCALE_BIAS_VEC(13); + SCALE_BIAS_VEC(14); + } if (post->useInt8 == 0) { + if (biasPtr == nullptr) { + auto dstTmp = dst_x; + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp)), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16), f1); + f2 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 2), f2); + dstTmp += dst_step_tmp; + f4 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f4); + f5 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f5); + f6 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 2), f6); + 
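(Editorial note, not part of the patch: the `biasPtr == nullptr` branches in this function reload the partial result from dst before storing. This supports block-wise quantization: the caller is assumed to invoke the kernel once per block (post->blockNum), presumably supplying the bias on the first pass and requesting the fp32 clamp only on the last. A scalar sketch of that assumed calling pattern, with illustrative names:)

#include <algorithm>

// Hypothetical driver-level view of block accumulation into dst (sketch only).
static float accumulateBlocks(const float* blockPartial, int blockNum,
                              float bias, float fp32min, float fp32max) {
    float dst = bias + blockPartial[0];       // first pass: biasPtr != nullptr
    for (int b = 1; b < blockNum; ++b) {
        dst += blockPartial[b];               // later passes: biasPtr == nullptr, dst is reloaded and accumulated
    }
    return std::min(std::max(dst, fp32min), fp32max);   // clamp applied on the final value
}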
dstTmp += dst_step_tmp; + f8 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f8); + f9 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f9); + f10 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 2), f10); + dstTmp += dst_step_tmp; + f12 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f12); + f13 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f13); + f14 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 2), f14); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_3(0,1,2); + POST_TREAT_FLOAT_3(4,5,6); + POST_TREAT_FLOAT_3(8,9,10); + POST_TREAT_FLOAT_3(12,13,14); + } _mm512_storeu_ps(((float*)dst_x), f0); _mm512_storeu_ps(((float*)dst_x) + 16, f1); _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f2); @@ -342,9 +802,15 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s POSTTREAT(14, 2); } } - auto weight_dz = weight + dzU * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H); - auto bias_dz = (int32_t*)post->bias + dzU * PACK_UNIT * dzUnit; + auto weight_dz = weight + dzU * weightZStride; + if (biasPtr) { + bias_dz = post->biasFloat + dzU * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dzU * PACK_UNIT * dzUnit; + } float* scale_dz = (float*)post->scale + dzU * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dzU * PACK_UNIT * dzUnit; auto dst_z = dst + dzU * dst_step_tmp * dzUnit; const auto src_x = src; @@ -368,14 +834,49 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s D2 = mnn_mm512_dpbusds_epi32(D2, s2, w0); } - auto biasValue = _mm512_loadu_si512(bias_dz); + auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1, xy0_2; + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + DEQUANT_VALUE(2); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + f2 = _mm512_mul_ps(f2, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + f2 = _mm512_sub_ps(f2, extrabias2); + } + } - SCALE_BIAS_VEC(0); - SCALE_BIAS_VEC(1); - SCALE_BIAS_VEC(2); - + PLUS_TERM_3(0,1,2); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + SCALE_BIAS_VEC(2); + } + if (post->useInt8 == 0) { + if (biasPtr == nullptr) { + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x)), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16), f1); + f2 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16 * 2), f2); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_3(0,1,2); + } _mm512_storeu_ps(((float*)dst_x), f0); _mm512_storeu_ps(((float*)dst_x) + 16, f1); _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f2); @@ -386,17 +887,28 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s } dst_x += dst_step_tmp; scale_dz += PACK_UNIT; - bias_dz += PACK_UNIT; - weight_dz += PACK_UNIT * GEMMINT8_AVX512_E; + if (biasPtr) { + bias_dz += PACK_UNIT; + } + if (post->extraBias) { + extraB_dz += PACK_UNIT; + } + weight_dz += PACK_UNIT * 
GEMMINT8_AVX512_L; } return; } // e = 2 if (realDst == 2) { for (int dz = 0; dz < dzU; ++dz) { - auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H); - auto bias_dz = (int32_t*)post->bias + dz * PACK_UNIT * dzUnit; + auto weight_dz = weight + dz * weightZStride; + if (biasPtr) { + bias_dz = post->biasFloat + dz * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dz * PACK_UNIT * dzUnit; + } float* scale_dz = (float*)post->scale + dz * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dz * PACK_UNIT * dzUnit; auto dst_z = dst + dz * dst_step_tmp * dzUnit; const auto src_x = src; auto dst_x = dst_z; @@ -417,9 +929,9 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s const auto weight_sz = weight_dz + (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H) * sz; const auto src_z = (const float*)(src_x + sz * GEMMINT8_AVX512_E * GEMMINT8_AVX512_L); auto w0 = _mm512_loadu_si512(weight_sz); - auto w1 = _mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_E); - auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_E); - auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_E); + auto w1 = _mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_L); + auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_L); + auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_L); auto s0 = AVX512_BROADCAST_INT32(src_z + 0); auto s1 = AVX512_BROADCAST_INT32(src_z + 1); @@ -437,28 +949,135 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s D13 = mnn_mm512_dpbusds_epi32(D13, s1, w3); } - auto biasValue = _mm512_loadu_si512(bias_dz); auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1; + + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + } + } - SCALE_BIAS_VEC(0); - SCALE_BIAS_VEC(1); + PLUS_TERM_2(0,1); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + } - biasValue = _mm512_loadu_si512(bias_dz + 1 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT); - SCALE_BIAS_VEC(4); - SCALE_BIAS_VEC(5); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(4); + DEQUANT_VALUE(5); + + if (post->extraScale) { // Batch quant + f4 = _mm512_mul_ps(f4, extrascale0); + f5 = _mm512_mul_ps(f5, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 1 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f4 = _mm512_sub_ps(f4, extrabias0); + f5 = _mm512_sub_ps(f5, extrabias1); + } + } + + PLUS_TERM_2(4,5); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT); + 
SCALE_BIAS_VEC(4); + SCALE_BIAS_VEC(5); + } - biasValue = _mm512_loadu_si512(bias_dz + 2 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 2 * PACK_UNIT); - SCALE_BIAS_VEC(8); - SCALE_BIAS_VEC(9); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT); + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(8); + DEQUANT_VALUE(9); + + if (post->extraScale) { // Batch quant + f8 = _mm512_mul_ps(f8, extrascale0); + f9 = _mm512_mul_ps(f9, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 2 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f8 = _mm512_sub_ps(f8, extrabias0); + f9 = _mm512_sub_ps(f9, extrabias1); + } + } + + PLUS_TERM_2(8,9); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT); + SCALE_BIAS_VEC(8); + SCALE_BIAS_VEC(9); + } - biasValue = _mm512_loadu_si512(bias_dz + 3 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT); - SCALE_BIAS_VEC(12); - SCALE_BIAS_VEC(13); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT); + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(12); + DEQUANT_VALUE(13); + + if (post->extraScale) { // Batch quant + f12 = _mm512_mul_ps(f12, extrascale0); + f13 = _mm512_mul_ps(f13, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 3 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f12 = _mm512_sub_ps(f12, extrabias0); + f13 = _mm512_sub_ps(f13, extrabias1); + } + } + + PLUS_TERM_2(12,13); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT); + SCALE_BIAS_VEC(12); + SCALE_BIAS_VEC(13); + } if (post->useInt8 == 0) { + if (nullptr == biasPtr) { + auto dstTmp = dst_x; + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp)), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16), f1); + dstTmp += dst_step_tmp; + f4 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f4); + f5 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f5); + dstTmp += dst_step_tmp; + f8 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f8); + f9 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f9); + dstTmp += dst_step_tmp; + f12 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f12); + f13 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f13); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_2(0,1); + POST_TREAT_FLOAT_2(4,5); + POST_TREAT_FLOAT_2(8,9); + POST_TREAT_FLOAT_2(12,13); + } _mm512_storeu_ps(((float*)dst_x), f0); _mm512_storeu_ps(((float*)dst_x) + 16, f1); dst_x += dst_step_tmp; @@ -487,9 +1106,15 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s POSTTREAT(13, 1); } } - auto weight_dz = weight + dzU * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H); - auto bias_dz = (int32_t*)post->bias + dzU * PACK_UNIT * dzUnit; + auto weight_dz = weight + dzU * weightZStride; + if (biasPtr) { + bias_dz = post->biasFloat + dzU * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dzU * PACK_UNIT * dzUnit; + } float* scale_dz = (float*)post->scale + dzU * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dzU * PACK_UNIT * dzUnit; auto dst_z = dst + dzU * dst_step_tmp * 
dzUnit; const auto src_x = src; @@ -510,13 +1135,40 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s D1 = mnn_mm512_dpbusds_epi32(D1, s1, w0); } - auto biasValue = _mm512_loadu_si512(bias_dz); auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1; + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + } + } - SCALE_BIAS_VEC(0); - SCALE_BIAS_VEC(1); + PLUS_TERM_2(0,1); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + } if (post->useInt8 == 0) { + if (nullptr == biasPtr) { + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x)), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16), f1); + } + POST_TREAT_FLOAT_2(0,1); _mm512_storeu_ps(((float*)dst_x), f0); _mm512_storeu_ps(((float*)dst_x) + 16, f1); } else { @@ -525,16 +1177,27 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s } dst_x += dst_step_tmp; scale_dz += PACK_UNIT; - bias_dz += PACK_UNIT; - weight_dz += PACK_UNIT * GEMMINT8_AVX512_E; + if (biasPtr) { + bias_dz += PACK_UNIT; + } + if (post->extraBias) { + extraB_dz += PACK_UNIT; + } + weight_dz += PACK_UNIT * GEMMINT8_AVX512_L; } return; } if (realDst == 1) { for (int dz = 0; dz < dzU; ++dz) { - auto weight_dz = weight + dz * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H); - auto bias_dz = (int32_t*)post->bias + dz * PACK_UNIT * dzUnit; + auto weight_dz = weight + dz * weightZStride; + if (biasPtr) { + bias_dz = post->biasFloat + dz * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dz * PACK_UNIT * dzUnit; + } float* scale_dz = (float*)post->scale + dz * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dz * PACK_UNIT * dzUnit; auto dst_z = dst + dz * dst_step_tmp * dzUnit; const auto src_x = src; auto dst_x = dst_z; @@ -550,9 +1213,9 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s const auto weight_sz = weight_dz + (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H) * sz; const auto src_z = (const float*)(src_x + sz * GEMMINT8_AVX512_E * GEMMINT8_AVX512_L); auto w0 = _mm512_loadu_si512(weight_sz); - auto w1 = _mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_E); - auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_E); - auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_E); + auto w1 = _mm512_loadu_si512(weight_sz + 1 * PACK_UNIT * GEMMINT8_AVX512_L); + auto w2 = _mm512_loadu_si512(weight_sz + 2 * PACK_UNIT * GEMMINT8_AVX512_L); + auto w3 = _mm512_loadu_si512(weight_sz + 3 * PACK_UNIT * GEMMINT8_AVX512_L); auto s0 = AVX512_BROADCAST_INT32(src_z + 0); @@ -565,24 +1228,113 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s D12 = mnn_mm512_dpbusds_epi32(D12, s0, w3); } - auto biasValue = _mm512_loadu_si512(bias_dz); auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = 
_mm512_loadu_ps(weightBias_dz); + __m512 xy0_0; + + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1; + DEQUANT_VALUE(0); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f0 = _mm512_sub_ps(f0, extrabias0); + } + } - SCALE_BIAS_VEC(0); + PLUS_TERM_1(0); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + } - biasValue = _mm512_loadu_si512(bias_dz + 1 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT); - SCALE_BIAS_VEC(4); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1; + DEQUANT_VALUE(4); + + if (post->extraScale) { // Batch quant + f4 = _mm512_mul_ps(f4, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 1 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f4 = _mm512_sub_ps(f4, extrabias0); + } + } + + PLUS_TERM_1(4); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT); + SCALE_BIAS_VEC(4); + } - biasValue = _mm512_loadu_si512(bias_dz + 2 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 2 * PACK_UNIT); - SCALE_BIAS_VEC(8); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1; + DEQUANT_VALUE(8); + + if (post->extraScale) { // Batch quant + f8 = _mm512_mul_ps(f8, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 2 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f8 = _mm512_sub_ps(f8, extrabias0); + } + } + + PLUS_TERM_1(8); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT); + SCALE_BIAS_VEC(8); + } - biasValue = _mm512_loadu_si512(bias_dz + 3 * PACK_UNIT); scaleValue = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT); - SCALE_BIAS_VEC(12); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1; + DEQUANT_VALUE(12); + + if (post->extraScale) { // Batch quant + f12 = _mm512_mul_ps(f12, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 3 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f12 = _mm512_sub_ps(f12, extrabias0); + } + } + + PLUS_TERM_1(12); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT); + SCALE_BIAS_VEC(12); + } if (post->useInt8 == 0) { + if (nullptr == biasPtr) { + auto dstTemp = dst_x; + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTemp)), f0); + dstTemp += dst_step_tmp; + f4 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTemp) + 16 * 0), f4); + dstTemp += dst_step_tmp; + f8 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTemp) + 16 * 0), f8); + dstTemp += dst_step_tmp; + f12 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTemp) + 16 * 0), f12); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_1(0); + POST_TREAT_FLOAT_1(4); + POST_TREAT_FLOAT_1(8); + POST_TREAT_FLOAT_1(12); + } _mm512_storeu_ps(((float*)dst_x), f0); dst_x += dst_step_tmp; _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f4); 
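(Editorial note, not part of the patch: this is the "NO VNNI" build of the kernel, so every mnn_mm512_dpbusds_epi32 above is an emulation of the AVX512-VNNI instruction, defined elsewhere in this backend presumably along the lines sketched below — which is also why the function declares `auto one = _mm512_set1_epi16(1)`. The helper name here is illustrative, and the saturating behaviour of the real dpbusds is only approximated.)

#include <immintrin.h>

// u8 x s8 pairwise multiply-add to 16-bit, widen pairs to 32-bit via madd with ones,
// then accumulate. _mm512_maddubs_epi16 saturates its 16-bit sums, so this matches the
// VNNI instruction only while the intermediate pair sums stay in int16 range.
static inline __m512i dpbusds_no_vnni(__m512i acc, __m512i u8, __m512i s8) {
    const __m512i one = _mm512_set1_epi16(1);
    __m512i p16 = _mm512_maddubs_epi16(u8, s8);
    __m512i p32 = _mm512_madd_epi16(p16, one);
    return _mm512_add_epi32(acc, p32);
}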
@@ -603,9 +1355,15 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s POSTTREAT(12, 0); } } - auto weight_dz = weight + dzU * src_depth_quad * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H); - auto bias_dz = (int32_t*)post->bias + dzU * PACK_UNIT * dzUnit; + auto weight_dz = weight + dzU * weightZStride; + if (biasPtr) { + bias_dz = post->biasFloat + dzU * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dzU * PACK_UNIT * dzUnit; + } float* scale_dz = (float*)post->scale + dzU * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dzU * PACK_UNIT * dzUnit; auto dst_z = dst + dzU * dst_step_tmp * dzUnit; const auto src_x = src; @@ -623,20 +1381,1325 @@ void MATMULCOREFUNC_NAME(int8_t* dst, const int8_t* src, const int8_t* weight, s D0 = mnn_mm512_dpbusds_epi32(D0, s0, w0); } - auto biasValue = _mm512_loadu_si512(bias_dz); auto scaleValue = _mm512_loadu_ps(scale_dz); - SCALE_BIAS_VEC(0); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0; + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1; + DEQUANT_VALUE(0); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f0 = _mm512_sub_ps(f0, extrabias0); + } + } + + PLUS_TERM_1(0); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + } if (post->useInt8 == 0) { + if (nullptr == biasPtr) { + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x)), f0); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_1(0); + } _mm512_storeu_ps(((float*)dst_x), f0); } else { POSTTREAT(0, 0); } dst_x += dst_step_tmp; scale_dz += PACK_UNIT; - bias_dz += PACK_UNIT; - weight_dz += PACK_UNIT * GEMMINT8_AVX512_E; + if (biasPtr) { + bias_dz += PACK_UNIT; + } + if (post->extraBias) { + extraB_dz += PACK_UNIT; + } + weight_dz += PACK_UNIT * GEMMINT8_AVX512_L; + } + return; + } +} + +void MATMULCOREFUNC_NAME_W4(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) { + MNN_ASSERT(post->useInt8==0); + const auto dst_step_tmp = dst_step / sizeof(int8_t); + auto zero512 = _mm512_set1_ps(0.0f); + auto offset = _mm256_set1_epi16(128); + int dzUnit = GEMMINT8_AVX512_H / PACK_UNIT; + int dzU = dst_depth_quad / dzUnit; + int dzR = dst_depth_quad % dzUnit; + auto one = _mm512_set1_epi16(1); + __m512 fp32min, fp32max; + if (0 == post->useInt8 && post->fp32minmax) { + fp32min = _mm512_set1_ps((post->fp32minmax)[0]); + fp32max = _mm512_set1_ps((post->fp32minmax)[1]); + } + auto blockNum = post->blockNum; + const float* biasPtr = nullptr; + const float* bias_dz = nullptr; + const float* extraB_dz = nullptr; + if (post->biasFloat) { + biasPtr = post->biasFloat; + } + + auto srcKernelSumPtr = post->srcKernelSum; + __m512 kernelSum0 = _mm512_setzero_ps(); + __m512 kernelSum1 = _mm512_setzero_ps(); + __m512 kernelSum2 = _mm512_setzero_ps(); + __m512 kernelSum3 = _mm512_setzero_ps(); + + int weight_step_Z = static_cast(src_depth_quad * blockNum * (GEMMINT8_AVX512_L * GEMMINT8_AVX512_H) / 2); + int weight_step_Y = static_cast(GEMMINT8_AVX512_L * GEMMINT8_AVX512_H / 2); + const __m512i mask = _mm512_set1_epi8(0xf); + if (GEMMINT8_AVX512_E == realDst) { + kernelSum0 = _mm512_set1_ps(post->srcKernelSum[0]); + kernelSum1 = 
_mm512_set1_ps(post->srcKernelSum[1]); + kernelSum2 = _mm512_set1_ps(post->srcKernelSum[2]); + kernelSum3 = _mm512_set1_ps(post->srcKernelSum[3]); + } else { + kernelSum0 = _mm512_set1_ps(post->srcKernelSum[0]); + if (realDst > 1) { + kernelSum1 = _mm512_set1_ps(post->srcKernelSum[1]); + } + if (realDst > 2) { + kernelSum2 = _mm512_set1_ps(post->srcKernelSum[2]); + } + } + auto f128 = _mm512_set1_ps(128.f); + __m512 extrascale0 = _mm512_setzero_ps(); + __m512 extrascale1 = _mm512_setzero_ps(); + __m512 extrascale2 = _mm512_setzero_ps(); + __m512 extrascale3 = _mm512_setzero_ps(); + if (post->extraScale) { + if (GEMMINT8_AVX512_E == realDst) { + extrascale0 = _mm512_set1_ps(post->extraScale[0]); + extrascale1 = _mm512_set1_ps(post->extraScale[1]); + extrascale2 = _mm512_set1_ps(post->extraScale[2]); + extrascale3 = _mm512_set1_ps(post->extraScale[3]); + } else { + extrascale0 = _mm512_set1_ps(post->extraScale[0]); + if (realDst > 1) { + extrascale1 = _mm512_set1_ps(post->extraScale[1]); + } + if (realDst > 2) { + extrascale2 = _mm512_set1_ps(post->extraScale[2]); + } + } + } + + if (realDst == GEMMINT8_AVX512_E) { + for (int dz = 0; dz < dzU; ++dz) { + auto weight_dz = weight + dz * weight_step_Z; + if (post->biasFloat) { + bias_dz = biasPtr + dz * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dz * PACK_UNIT * dzUnit; + } + const auto weightBias_dz = post->weightQuanBias + dz * PACK_UNIT * dzUnit; + float* scale_dz = (float*)post->scale + dz * PACK_UNIT * dzUnit; + auto dst_z = dst + dz * dst_step_tmp * dzUnit; + const auto src_x = src; + auto dst_x = dst_z; + __m512i D0 = _mm512_set1_epi32(0); + __m512i D1 = _mm512_set1_epi32(0); + __m512i D2 = _mm512_set1_epi32(0); + __m512i D3 = _mm512_set1_epi32(0); + + __m512i D4 = _mm512_set1_epi32(0); + __m512i D5 = _mm512_set1_epi32(0); + __m512i D6 = _mm512_set1_epi32(0); + __m512i D7 = _mm512_set1_epi32(0); + + __m512i D8 = _mm512_set1_epi32(0); + __m512i D9 = _mm512_set1_epi32(0); + __m512i D10 = _mm512_set1_epi32(0); + __m512i D11 = _mm512_set1_epi32(0); + + __m512i D12 = _mm512_set1_epi32(0); + __m512i D13 = _mm512_set1_epi32(0); + __m512i D14 = _mm512_set1_epi32(0); + __m512i D15 = _mm512_set1_epi32(0); + + + for (int sz = 0; sz < src_depth_quad; ++sz) { + const auto weight_sz = weight_dz + weight_step_Y * sz; + const auto src_z = (const float*)(src_x + sz * GEMMINT8_AVX512_E * GEMMINT8_AVX512_L); + // int4->int8: total count=4*64(GEMMINT8_AVX512_L * GEMMINT8_AVX512_H) + // Load 4*64 int4 weight + auto w0_int4_64 = _mm512_loadu_si512(weight_sz); // 128xint4_t=64 byte + auto w1_int4_64 = _mm512_loadu_si512(weight_sz + 64); // 128xint4_t + // 256xint4_t->256xint8_t + auto w0 = _mm512_and_si512(mask, _mm512_srli_epi16(w0_int4_64, 4)); // 64xint8_t + auto w2 = _mm512_and_si512(mask, w0_int4_64); // 64xint8_t + auto w1 = _mm512_and_si512(mask, _mm512_srli_epi16(w1_int4_64, 4)); + auto w3 = _mm512_and_si512(mask, w1_int4_64); + + auto s0 = AVX512_BROADCAST_INT32(src_z + 0); + auto s1 = AVX512_BROADCAST_INT32(src_z + 1); + auto s2 = AVX512_BROADCAST_INT32(src_z + 2); + auto s3 = AVX512_BROADCAST_INT32(src_z + 3); + + D0 = mnn_mm512_dpbusds_epi32(D0, s0, w0); + D1 = mnn_mm512_dpbusds_epi32(D1, s1, w0); + D2 = mnn_mm512_dpbusds_epi32(D2, s2, w0); + D3 = mnn_mm512_dpbusds_epi32(D3, s3, w0); + + D4 = mnn_mm512_dpbusds_epi32(D4, s0, w1); + D5 = mnn_mm512_dpbusds_epi32(D5, s1, w1); + D6 = mnn_mm512_dpbusds_epi32(D6, s2, w1); + D7 = mnn_mm512_dpbusds_epi32(D7, s3, w1); + + D8 = mnn_mm512_dpbusds_epi32(D8, s0, w2); + D9 = 
mnn_mm512_dpbusds_epi32(D9, s1, w2); + D10 = mnn_mm512_dpbusds_epi32(D10, s2, w2); + D11 = mnn_mm512_dpbusds_epi32(D11, s3, w2); + + D12 = mnn_mm512_dpbusds_epi32(D12, s0, w3); + D13 = mnn_mm512_dpbusds_epi32(D13, s1, w3); + D14 = mnn_mm512_dpbusds_epi32(D14, s2, w3); + D15 = mnn_mm512_dpbusds_epi32(D15, s3, w3); + } + auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1, xy0_2, xy0_3; + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + DEQUANT_VALUE(2); + DEQUANT_VALUE(3); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + f2 = _mm512_mul_ps(f2, extrascale2); + f3 = _mm512_mul_ps(f3, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + f2 = _mm512_sub_ps(f2, extrabias2); + f3 = _mm512_sub_ps(f3, extrabias3); + } + } + + PLUS_TERM(0,1,2,3); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + SCALE_BIAS_VEC(2); + SCALE_BIAS_VEC(3); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(4); + DEQUANT_VALUE(5); + DEQUANT_VALUE(6); + DEQUANT_VALUE(7); + + if (post->extraScale) { // Batch quant + f4 = _mm512_mul_ps(f4, extrascale0); + f5 = _mm512_mul_ps(f5, extrascale1); + f6 = _mm512_mul_ps(f6, extrascale2); + f7 = _mm512_mul_ps(f7, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 1 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f4 = _mm512_sub_ps(f4, extrabias0); + f5 = _mm512_sub_ps(f5, extrabias1); + f6 = _mm512_sub_ps(f6, extrabias2); + f7 = _mm512_sub_ps(f7, extrabias3); + } + } + + PLUS_TERM(4,5,6,7); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT); + SCALE_BIAS_VEC(4); + SCALE_BIAS_VEC(5); + SCALE_BIAS_VEC(6); + SCALE_BIAS_VEC(7); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 2 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(8); + DEQUANT_VALUE(9); + DEQUANT_VALUE(10); + DEQUANT_VALUE(11); + + if (post->extraScale) { // Batch quant + f8 = _mm512_mul_ps(f8, extrascale0); + f9 = _mm512_mul_ps(f9, extrascale1); + f10 = _mm512_mul_ps(f10, extrascale2); + f11 = _mm512_mul_ps(f11, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 2 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, 
extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f8 = _mm512_sub_ps(f8, extrabias0); + f9 = _mm512_sub_ps(f9, extrabias1); + f10 = _mm512_sub_ps(f10, extrabias2); + f11 = _mm512_sub_ps(f11, extrabias3); + } + } + + PLUS_TERM(8,9,10,11); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT); + SCALE_BIAS_VEC(8); + SCALE_BIAS_VEC(9); + SCALE_BIAS_VEC(10); + SCALE_BIAS_VEC(11); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS; + DEQUANT_VALUE(12); + DEQUANT_VALUE(13); + DEQUANT_VALUE(14); + DEQUANT_VALUE(15); + + if (post->extraScale) { // Batch quant + f12 = _mm512_mul_ps(f12, extrascale0); + f13 = _mm512_mul_ps(f13, extrascale1); + f14 = _mm512_mul_ps(f14, extrascale2); + f15 = _mm512_mul_ps(f15, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 3 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f12 = _mm512_sub_ps(f12, extrabias0); + f13 = _mm512_sub_ps(f13, extrabias1); + f14 = _mm512_sub_ps(f14, extrabias2); + f15 = _mm512_sub_ps(f15, extrabias3); + } + } + + PLUS_TERM(12,13,14,15); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT); + SCALE_BIAS_VEC(12); + SCALE_BIAS_VEC(13); + SCALE_BIAS_VEC(14); + SCALE_BIAS_VEC(15); + } + if (biasPtr == nullptr) { + auto destTmp = dst_x; + f0 = _mm512_add_ps(_mm512_loadu_ps((float*)destTmp), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16), f1); + f2 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 2), f2); + f3 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 3), f3); + destTmp += dst_step_tmp; + f4 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 0), f4); + f5 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 1), f5); + f6 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 2), f6); + f7 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 3), f7); + destTmp += dst_step_tmp; + f8 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 0), f8); + f9 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 1), f9); + f10 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 2), f10); + f11 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 3), f11); + destTmp += dst_step_tmp; + f12 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 0), f12); + f13 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 1), f13); + f14 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 2), f14); + f15 = _mm512_add_ps(_mm512_loadu_ps(((float*)destTmp) + 16 * 3), f15); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT(0,1,2,3); + POST_TREAT_FLOAT(4,5,6,7); + POST_TREAT_FLOAT(8,9,10,11); + POST_TREAT_FLOAT(12,13,14,15); + } + + _mm512_storeu_ps(((float*)dst_x), f0); + _mm512_storeu_ps(((float*)dst_x) + 16, f1); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f2); + _mm512_storeu_ps(((float*)dst_x) + 16 * 3, f3); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f4); + _mm512_storeu_ps(((float*)dst_x) + 16 * 1, f5); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f6); + _mm512_storeu_ps(((float*)dst_x) + 16 * 3, f7); + dst_x += 
dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f8); + _mm512_storeu_ps(((float*)dst_x) + 16 * 1, f9); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f10); + _mm512_storeu_ps(((float*)dst_x) + 16 * 3, f11); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f12); + _mm512_storeu_ps(((float*)dst_x) + 16 * 1, f13); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f14); + _mm512_storeu_ps(((float*)dst_x) + 16 * 3, f15); + + } + auto weight_dz = weight + dzU * weight_step_Z; + if (biasPtr) { + bias_dz = biasPtr + dzU * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dzU * PACK_UNIT * dzUnit; + } + float* scale_dz = (float*)post->scale + dzU * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dzU * PACK_UNIT * dzUnit; + + auto dst_z = dst + dzU * dst_step_tmp * dzUnit; + const auto src_x = src; + auto dst_x = dst_z; + for (int i=0; iextraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + f2 = _mm512_mul_ps(f2, extrascale2); + f3 = _mm512_mul_ps(f3, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm512_mul_ps(extrabias, extrascale3); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + f2 = _mm512_sub_ps(f2, extrabias2); + f3 = _mm512_sub_ps(f3, extrabias3); + } + } + + PLUS_TERM(0,1,2,3); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + SCALE_BIAS_VEC(2); + SCALE_BIAS_VEC(3); + } + + if (nullptr == biasPtr) { + f0 = _mm512_add_ps(_mm512_loadu_ps((float*)dst_x), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16), f1); + f2 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16 * 2), f2); + f3 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16 * 3), f3); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT(0,1,2,3); + } + _mm512_storeu_ps(((float*)dst_x), f0); + _mm512_storeu_ps(((float*)dst_x) + 16, f1); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f2); + _mm512_storeu_ps(((float*)dst_x) + 16 * 3, f3); + + dst_x += dst_step_tmp; + scale_dz += PACK_UNIT; + if (biasPtr) { + bias_dz += PACK_UNIT; + } + if (post->extraBias) { + extraB_dz += PACK_UNIT; + } + weight_dz += PACK_UNIT * GEMMINT8_AVX512_L; + } + return; + } + // e = 3 + if (realDst == 3) { + for (int dz = 0; dz < dzU; ++dz) { + auto weight_dz = weight + dz * weight_step_Z; + if (biasPtr) { + bias_dz = biasPtr + dz * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dz * PACK_UNIT * dzUnit; + } + float* scale_dz = (float*)post->scale + dz * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dz * PACK_UNIT * dzUnit; + auto dst_z = dst + dz * dst_step_tmp * dzUnit; + const auto src_x = src; + auto dst_x = dst_z; + __m512i D0 = _mm512_set1_epi32(0); + __m512i D1 = _mm512_set1_epi32(0); + __m512i D2 = _mm512_set1_epi32(0); + + __m512i D4 = _mm512_set1_epi32(0); + __m512i D5 = _mm512_set1_epi32(0); + __m512i D6 = _mm512_set1_epi32(0); + + __m512i D8 = _mm512_set1_epi32(0); + __m512i D9 = _mm512_set1_epi32(0); + __m512i D10 = _mm512_set1_epi32(0); + + __m512i D12 = _mm512_set1_epi32(0); + __m512i D13 = _mm512_set1_epi32(0); + __m512i D14 = _mm512_set1_epi32(0); + + + 
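The weight loop that follows (like the realDst == 4 loop above) unpacks each 64-byte load of packed int4 weights with a shift-and-mask pair. A minimal, standalone version of that trick for a single load:

#include <immintrin.h>
#include <cstdint>

// One 64-byte load carries 128 int4 values; split them into high and low
// nibbles, each widened to one unsigned byte (0..15). How the 0..15 range is
// mapped back to signed weights is handled in the dequant stage (assumption:
// via weightQuanBias / kernelSum), not here. Requires AVX-512BW.
static inline void unpackInt4x128(const int8_t* packed, __m512i* hi, __m512i* lo) {
    const __m512i mask = _mm512_set1_epi8(0xf);
    __m512i w = _mm512_loadu_si512(packed);                  // 64 bytes = 128 x int4
    *hi = _mm512_and_si512(mask, _mm512_srli_epi16(w, 4));   // upper nibble of each byte
    *lo = _mm512_and_si512(mask, w);                         // lower nibble of each byte
}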
for (int sz = 0; sz < src_depth_quad; ++sz) { + const auto weight_sz = weight_dz + weight_step_Y * sz; + const auto src_z = (const float*)(src_x + sz * GEMMINT8_AVX512_E * GEMMINT8_AVX512_L); + // int4->int8: total count=4*64(GEMMINT8_AVX512_L * GEMMINT8_AVX512_H) + // Load 4*64 int4 weight + auto w0_int4_64 = _mm512_loadu_si512(weight_sz); // 128xint4_t=64 byte + auto w1_int4_64 = _mm512_loadu_si512(weight_sz + 64); // 128xint4_t + // 256xint4_t->256xint8_t + auto w0 = _mm512_and_si512(mask, _mm512_srli_epi16(w0_int4_64, 4)); // 64xint8_t + auto w2 = _mm512_and_si512(mask, w0_int4_64); // 64xint8_t + auto w1 = _mm512_and_si512(mask, _mm512_srli_epi16(w1_int4_64, 4)); + auto w3 = _mm512_and_si512(mask, w1_int4_64); + + auto s0 = AVX512_BROADCAST_INT32(src_z + 0); + auto s1 = AVX512_BROADCAST_INT32(src_z + 1); + auto s2 = AVX512_BROADCAST_INT32(src_z + 2); + + D0 = mnn_mm512_dpbusds_epi32(D0, s0, w0); + D1 = mnn_mm512_dpbusds_epi32(D1, s1, w0); + D2 = mnn_mm512_dpbusds_epi32(D2, s2, w0); + + D4 = mnn_mm512_dpbusds_epi32(D4, s0, w1); + D5 = mnn_mm512_dpbusds_epi32(D5, s1, w1); + D6 = mnn_mm512_dpbusds_epi32(D6, s2, w1); + + D8 = mnn_mm512_dpbusds_epi32(D8, s0, w2); + D9 = mnn_mm512_dpbusds_epi32(D9, s1, w2); + D10 = mnn_mm512_dpbusds_epi32(D10, s2, w2); + + D12 = mnn_mm512_dpbusds_epi32(D12, s0, w3); + D13 = mnn_mm512_dpbusds_epi32(D13, s1, w3); + D14 = mnn_mm512_dpbusds_epi32(D14, s2, w3); + } + + auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1, xy0_2; + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + DEQUANT_VALUE(2); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + f2 = _mm512_mul_ps(f2, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + f2 = _mm512_sub_ps(f2, extrabias2); + } + } + + PLUS_TERM_3(0,1,2); + if (nullptr != biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + SCALE_BIAS_VEC(2); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3; + DEQUANT_VALUE(4); + DEQUANT_VALUE(5); + DEQUANT_VALUE(6); + + if (post->extraScale) { // Batch quant + f4 = _mm512_mul_ps(f4, extrascale0); + f5 = _mm512_mul_ps(f5, extrascale1); + f6 = _mm512_mul_ps(f6, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 1 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f4 = _mm512_sub_ps(f4, extrabias0); + f5 = _mm512_sub_ps(f5, extrabias1); + f6 = _mm512_sub_ps(f6, extrabias2); + } + } + + PLUS_TERM_3(4,5,6); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT); + SCALE_BIAS_VEC(4); + SCALE_BIAS_VEC(5); + SCALE_BIAS_VEC(6); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 2 * PACK_UNIT); + 
weightBiasValue = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3; + DEQUANT_VALUE(8); + DEQUANT_VALUE(9); + DEQUANT_VALUE(10); + + if (post->extraScale) { // Batch quant + f8 = _mm512_mul_ps(f8, extrascale0); + f9 = _mm512_mul_ps(f9, extrascale1); + f10 = _mm512_mul_ps(f10, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 2 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f8 = _mm512_sub_ps(f8, extrabias0); + f9 = _mm512_sub_ps(f9, extrabias1); + f10 = _mm512_sub_ps(f10, extrabias2); + } + } + + PLUS_TERM_3(8,9,10); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT); + SCALE_BIAS_VEC(8); + SCALE_BIAS_VEC(9); + SCALE_BIAS_VEC(10); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_3; + DEQUANT_VALUE(12); + DEQUANT_VALUE(13); + DEQUANT_VALUE(14); + + if (post->extraScale) { // Batch quant + f12 = _mm512_mul_ps(f12, extrascale0); + f13 = _mm512_mul_ps(f13, extrascale1); + f14 = _mm512_mul_ps(f14, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 3 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f12 = _mm512_sub_ps(f12, extrabias0); + f13 = _mm512_sub_ps(f13, extrabias1); + f14 = _mm512_sub_ps(f14, extrabias2); + } + } + + PLUS_TERM_3(12,13,14); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT); + SCALE_BIAS_VEC(12); + SCALE_BIAS_VEC(13); + SCALE_BIAS_VEC(14); + } + + if (biasPtr == nullptr) { + auto dstTmp = dst_x; + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp)), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16), f1); + f2 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 2), f2); + dstTmp += dst_step_tmp; + f4 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f4); + f5 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f5); + f6 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 2), f6); + dstTmp += dst_step_tmp; + f8 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f8); + f9 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f9); + f10 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 2), f10); + dstTmp += dst_step_tmp; + f12 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f12); + f13 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f13); + f14 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 2), f14); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_3(0,1,2); + POST_TREAT_FLOAT_3(4,5,6); + POST_TREAT_FLOAT_3(8,9,10); + POST_TREAT_FLOAT_3(12,13,14); + } + _mm512_storeu_ps(((float*)dst_x), f0); + _mm512_storeu_ps(((float*)dst_x) + 16, f1); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f2); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f4); + _mm512_storeu_ps(((float*)dst_x) + 16 * 1, f5); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f6); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f8); + 
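Two details in this tail are easy to miss: when biasPtr is null the freshly dequantized tile is added onto what is already in dst (so repeated calls can accumulate partial results, e.g. across quant blocks), and POST_TREAT_FLOAT_* is applied only when post->fp32minmax is set. The macro body is not in this hunk; judging from the explicit _mm_min_ps/_mm_max_ps sequence in the SSE kernel later in this patch, it is presumably a clamp of this shape:

#include <immintrin.h>

// Assumed shape of POST_TREAT_FLOAT_*: clamp each result vector to
// [fp32minmax[0], fp32minmax[1]] (relu / relu6-style bounds).
static inline __m512 clampFp32(__m512 v, const float* fp32minmax) {
    __m512 lo = _mm512_set1_ps(fp32minmax[0]);
    __m512 hi = _mm512_set1_ps(fp32minmax[1]);
    return _mm512_max_ps(_mm512_min_ps(v, hi), lo);
}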
_mm512_storeu_ps(((float*)dst_x) + 16 * 1, f9); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f10); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f12); + _mm512_storeu_ps(((float*)dst_x) + 16 * 1, f13); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f14); + + } + auto weight_dz = weight + dzU * weight_step_Z; + if (biasPtr) { + bias_dz = post->biasFloat + dzU * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dzU * PACK_UNIT * dzUnit; + } + float* scale_dz = (float*)post->scale + dzU * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dzU * PACK_UNIT * dzUnit; + + auto dst_z = dst + dzU * dst_step_tmp * dzUnit; + const auto src_x = src; + auto dst_x = dst_z; + for (int i=0; iextraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + f2 = _mm512_mul_ps(f2, extrascale2); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm512_mul_ps(extrabias, extrascale2); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + f2 = _mm512_sub_ps(f2, extrabias2); + } + } + + PLUS_TERM_3(0,1,2); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + SCALE_BIAS_VEC(2); + } + + if (biasPtr == nullptr) { + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x)), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16), f1); + f2 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16 * 2), f2); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_3(0,1,2); + } + _mm512_storeu_ps(((float*)dst_x), f0); + _mm512_storeu_ps(((float*)dst_x) + 16, f1); + _mm512_storeu_ps(((float*)dst_x) + 16 * 2, f2); + + dst_x += dst_step_tmp; + scale_dz += PACK_UNIT; + if (biasPtr) { + bias_dz += PACK_UNIT; + } + if (post->extraBias) { + extraB_dz += PACK_UNIT; + } + weight_dz += PACK_UNIT * GEMMINT8_AVX512_L; + } + return; + } + // e = 2 + if (realDst == 2) { + for (int dz = 0; dz < dzU; ++dz) { + auto weight_dz = weight + dz * weight_step_Z; + if (biasPtr) { + bias_dz = post->biasFloat + dz * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dz * PACK_UNIT * dzUnit; + } + float* scale_dz = (float*)post->scale + dz * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dz * PACK_UNIT * dzUnit; + auto dst_z = dst + dz * dst_step_tmp * dzUnit; + const auto src_x = src; + auto dst_x = dst_z; + __m512i D0 = _mm512_set1_epi32(0); + __m512i D1 = _mm512_set1_epi32(0); + + __m512i D4 = _mm512_set1_epi32(0); + __m512i D5 = _mm512_set1_epi32(0); + + __m512i D8 = _mm512_set1_epi32(0); + __m512i D9 = _mm512_set1_epi32(0); + + __m512i D12 = _mm512_set1_epi32(0); + __m512i D13 = _mm512_set1_epi32(0); + + + for (int sz = 0; sz < src_depth_quad; ++sz) { + const auto weight_sz = weight_dz + weight_step_Y * sz; + const auto src_z = (const float*)(src_x + sz * GEMMINT8_AVX512_E * GEMMINT8_AVX512_L); + // int4->int8: total count=4*64(GEMMINT8_AVX512_L * GEMMINT8_AVX512_H) + // Load 4*64 int4 weight + auto w0_int4_64 = _mm512_loadu_si512(weight_sz); // 128xint4_t=64 byte + auto w1_int4_64 = _mm512_loadu_si512(weight_sz + 64); // 128xint4_t + // 256xint4_t->256xint8_t + auto w0 = _mm512_and_si512(mask, _mm512_srli_epi16(w0_int4_64, 4)); // 64xint8_t + auto w2 = _mm512_and_si512(mask, 
w0_int4_64); // 64xint8_t + auto w1 = _mm512_and_si512(mask, _mm512_srli_epi16(w1_int4_64, 4)); + auto w3 = _mm512_and_si512(mask, w1_int4_64); + + auto s0 = AVX512_BROADCAST_INT32(src_z + 0); + auto s1 = AVX512_BROADCAST_INT32(src_z + 1); + + D0 = mnn_mm512_dpbusds_epi32(D0, s0, w0); + D1 = mnn_mm512_dpbusds_epi32(D1, s1, w0); + + D4 = mnn_mm512_dpbusds_epi32(D4, s0, w1); + D5 = mnn_mm512_dpbusds_epi32(D5, s1, w1); + + D8 = mnn_mm512_dpbusds_epi32(D8, s0, w2); + D9 = mnn_mm512_dpbusds_epi32(D9, s1, w2); + + D12 = mnn_mm512_dpbusds_epi32(D12, s0, w3); + D13 = mnn_mm512_dpbusds_epi32(D13, s1, w3); + } + + auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1; + + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + } + } + + PLUS_TERM_2(0,1); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(4); + DEQUANT_VALUE(5); + + if (post->extraScale) { // Batch quant + f4 = _mm512_mul_ps(f4, extrascale0); + f5 = _mm512_mul_ps(f5, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 1 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f4 = _mm512_sub_ps(f4, extrabias0); + f5 = _mm512_sub_ps(f5, extrabias1); + } + } + + PLUS_TERM_2(4,5); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT); + SCALE_BIAS_VEC(4); + SCALE_BIAS_VEC(5); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 2 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT); + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(8); + DEQUANT_VALUE(9); + + if (post->extraScale) { // Batch quant + f8 = _mm512_mul_ps(f8, extrascale0); + f9 = _mm512_mul_ps(f9, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 2 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f8 = _mm512_sub_ps(f8, extrabias0); + f9 = _mm512_sub_ps(f9, extrabias1); + } + } + + PLUS_TERM_2(8,9); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT); + SCALE_BIAS_VEC(8); + SCALE_BIAS_VEC(9); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT); + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(12); + DEQUANT_VALUE(13); + + if (post->extraScale) { // Batch quant + f12 = _mm512_mul_ps(f12, extrascale0); + f13 = _mm512_mul_ps(f13, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz 
+ 3 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f12 = _mm512_sub_ps(f12, extrabias0); + f13 = _mm512_sub_ps(f13, extrabias1); + } + } + + PLUS_TERM_2(12,13); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT); + SCALE_BIAS_VEC(12); + SCALE_BIAS_VEC(13); + } + + if (nullptr == biasPtr) { + auto dstTmp = dst_x; + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp)), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16), f1); + dstTmp += dst_step_tmp; + f4 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f4); + f5 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f5); + dstTmp += dst_step_tmp; + f8 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f8); + f9 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f9); + dstTmp += dst_step_tmp; + f12 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 0), f12); + f13 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTmp) + 16 * 1), f13); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_2(0,1); + POST_TREAT_FLOAT_2(4,5); + POST_TREAT_FLOAT_2(8,9); + POST_TREAT_FLOAT_2(12,13); + } + _mm512_storeu_ps(((float*)dst_x), f0); + _mm512_storeu_ps(((float*)dst_x) + 16, f1); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f4); + _mm512_storeu_ps(((float*)dst_x) + 16 * 1, f5); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f8); + _mm512_storeu_ps(((float*)dst_x) + 16 * 1, f9); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f12); + _mm512_storeu_ps(((float*)dst_x) + 16 * 1, f13); + + } + auto weight_dz = weight + dzU * weight_step_Z; + if (biasPtr) { + bias_dz = post->biasFloat + dzU * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dzU * PACK_UNIT * dzUnit; + } + float* scale_dz = (float*)post->scale + dzU * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dzU * PACK_UNIT * dzUnit; + + auto dst_z = dst + dzU * dst_step_tmp * dzUnit; + const auto src_x = src; + auto dst_x = dst_z; + for (int i=0; i256xint8_t + auto w0 = _mm512_and_si512(mask, _mm512_srli_epi16(w0_int4_64, 4)); // 64xint8_t + + auto s0 = AVX512_BROADCAST_INT32(src_z + 0); + auto s1 = AVX512_BROADCAST_INT32(src_z + 1); + + D0 = mnn_mm512_dpbusds_epi32(D0, s0, w0); + D1 = mnn_mm512_dpbusds_epi32(D1, s1, w0); + } + + auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0, xy0_1; + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_2; + DEQUANT_VALUE(0); + DEQUANT_VALUE(1); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + f1 = _mm512_mul_ps(f1, extrascale1); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm512_mul_ps(extrabias, extrascale1); + f0 = _mm512_sub_ps(f0, extrabias0); + f1 = _mm512_sub_ps(f1, extrabias1); + } + } + + PLUS_TERM_2(0,1); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + SCALE_BIAS_VEC(1); + } + + if (nullptr == biasPtr) { + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x)), f0); + f1 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x) + 16), f1); + } + POST_TREAT_FLOAT_2(0,1); + _mm512_storeu_ps(((float*)dst_x), f0); + 
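All of these inner loops funnel through mnn_mm512_dpbusds_epi32. Assuming the wrapper maps to _mm512_dpbusds_epi32 on AVX-512 VNNI targets (its body is not in this hunk), each of the 16 int32 lanes accumulates a dot product of 4 unsigned activation bytes with 4 signed weight bytes; a scalar reference for one lane:

#include <cstdint>

// One lane of VPDPBUSDS: acc += sum over 4 byte pairs of u8 * s8.
// The real instruction saturates the accumulation; omitted here for brevity.
static int32_t dpbusdsLaneRef(int32_t acc, const uint8_t a[4], const int8_t b[4]) {
    for (int k = 0; k < 4; ++k) {
        acc += int32_t(a[k]) * int32_t(b[k]);
    }
    return acc;
}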
_mm512_storeu_ps(((float*)dst_x) + 16, f1); + + dst_x += dst_step_tmp; + scale_dz += PACK_UNIT; + if (biasPtr) { + bias_dz += PACK_UNIT; + } + if (post->extraBias) { + extraB_dz += PACK_UNIT; + } + weight_dz += PACK_UNIT * GEMMINT8_AVX512_L; + } + return; + } + if (realDst == 1) { + for (int dz = 0; dz < dzU; ++dz) { + auto weight_dz = weight + dz * weight_step_Z; + if (biasPtr) { + bias_dz = post->biasFloat + dz * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dz * PACK_UNIT * dzUnit; + } + float* scale_dz = (float*)post->scale + dz * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dz * PACK_UNIT * dzUnit; + auto dst_z = dst + dz * dst_step_tmp * dzUnit; + const auto src_x = src; + auto dst_x = dst_z; + __m512i D0 = _mm512_set1_epi32(0); + + __m512i D4 = _mm512_set1_epi32(0); + + __m512i D8 = _mm512_set1_epi32(0); + + __m512i D12 = _mm512_set1_epi32(0); + + for (int sz = 0; sz < src_depth_quad; ++sz) { + const auto weight_sz = weight_dz + weight_step_Y * sz; + const auto src_z = (const float*)(src_x + sz * GEMMINT8_AVX512_E * GEMMINT8_AVX512_L); + // int4->int8: total count=4*64(GEMMINT8_AVX512_L * GEMMINT8_AVX512_H) + // Load 4*64 int4 weight + auto w0_int4_64 = _mm512_loadu_si512(weight_sz); // 128xint4_t=64 byte + auto w1_int4_64 = _mm512_loadu_si512(weight_sz + 64); // 128xint4_t + // 256xint4_t->256xint8_t + auto w0 = _mm512_and_si512(mask, _mm512_srli_epi16(w0_int4_64, 4)); // 64xint8_t + auto w2 = _mm512_and_si512(mask, w0_int4_64); // 64xint8_t + auto w1 = _mm512_and_si512(mask, _mm512_srli_epi16(w1_int4_64, 4)); + auto w3 = _mm512_and_si512(mask, w1_int4_64); + + auto s0 = AVX512_BROADCAST_INT32(src_z + 0); + + D0 = mnn_mm512_dpbusds_epi32(D0, s0, w0); + + D4 = mnn_mm512_dpbusds_epi32(D4, s0, w1); + + D8 = mnn_mm512_dpbusds_epi32(D8, s0, w2); + + D12 = mnn_mm512_dpbusds_epi32(D12, s0, w3); + } + + auto scaleValue = _mm512_loadu_ps(scale_dz); + auto weightBiasValue = _mm512_loadu_ps(weightBias_dz); + __m512 xy0_0; + + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1; + DEQUANT_VALUE(0); + + if (post->extraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f0 = _mm512_sub_ps(f0, extrabias0); + } + } + + PLUS_TERM_1(0); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 1 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 1 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1; + DEQUANT_VALUE(4); + + if (post->extraScale) { // Batch quant + f4 = _mm512_mul_ps(f4, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 1 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f4 = _mm512_sub_ps(f4, extrabias0); + } + } + + PLUS_TERM_1(4); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 1 * PACK_UNIT); + SCALE_BIAS_VEC(4); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 2 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 2 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1; + DEQUANT_VALUE(8); + + if (post->extraScale) { // Batch quant + f8 = _mm512_mul_ps(f8, extrascale0); + if (post->extraBias && nullptr != 
biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 2 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f8 = _mm512_sub_ps(f8, extrabias0); + } + } + + PLUS_TERM_1(8); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 2 * PACK_UNIT); + SCALE_BIAS_VEC(8); + } + + scaleValue = _mm512_loadu_ps(scale_dz + 3 * PACK_UNIT); + weightBiasValue = _mm512_loadu_ps(weightBias_dz + 3 * PACK_UNIT); + // x_kernelSum x w_quantZero + SRCKERNELSUM_MUL_WEIGHTQUANBIAS_1; + DEQUANT_VALUE(12); + + if (post->extraScale) { // Batch quant + f12 = _mm512_mul_ps(f12, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz + 3 * PACK_UNIT); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f12 = _mm512_sub_ps(f12, extrabias0); + } + } + + PLUS_TERM_1(12); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz + 3 * PACK_UNIT); + SCALE_BIAS_VEC(12); + } + + if (nullptr == biasPtr) { + auto dstTemp = dst_x; + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTemp)), f0); + dstTemp += dst_step_tmp; + f4 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTemp) + 16 * 0), f4); + dstTemp += dst_step_tmp; + f8 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTemp) + 16 * 0), f8); + dstTemp += dst_step_tmp; + f12 = _mm512_add_ps(_mm512_loadu_ps(((float*)dstTemp) + 16 * 0), f12); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_1(0); + POST_TREAT_FLOAT_1(4); + POST_TREAT_FLOAT_1(8); + POST_TREAT_FLOAT_1(12); + } + _mm512_storeu_ps(((float*)dst_x), f0); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f4); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f8); + dst_x += dst_step_tmp; + _mm512_storeu_ps(((float*)dst_x) + 16 * 0, f12); + + } + auto weight_dz = weight + dzU * weight_step_Z; + if (biasPtr) { + bias_dz = post->biasFloat + dzU * PACK_UNIT * dzUnit; + } + if (post->extraBias) { + extraB_dz = post->extraBias + dzU * PACK_UNIT * dzUnit; + } + float* scale_dz = (float*)post->scale + dzU * PACK_UNIT * dzUnit; + const auto weightBias_dz = post->weightQuanBias + dzU * PACK_UNIT * dzUnit; + + auto dst_z = dst + dzU * dst_step_tmp * dzUnit; + const auto src_x = src; + auto dst_x = dst_z; + for (int i=0; iextraScale) { // Batch quant + f0 = _mm512_mul_ps(f0, extrascale0); + if (post->extraBias && nullptr != biasPtr) { + auto extrabias = _mm512_loadu_ps(extraB_dz); + extrabias = _mm512_mul_ps(f128, extrabias); + auto extrabias0 = _mm512_mul_ps(extrabias, extrascale0); + f0 = _mm512_sub_ps(f0, extrabias0); + } + } + + PLUS_TERM_1(0); + if (biasPtr) { + auto biasValue = _mm512_loadu_ps(bias_dz); + SCALE_BIAS_VEC(0); + } + + if (nullptr == biasPtr) { + f0 = _mm512_add_ps(_mm512_loadu_ps(((float*)dst_x)), f0); + } + if (post->fp32minmax) { + POST_TREAT_FLOAT_1(0); + } + _mm512_storeu_ps(((float*)dst_x), f0); + dst_x += dst_step_tmp; + scale_dz += PACK_UNIT; + if (biasPtr) { + bias_dz += PACK_UNIT; + } + if (post->extraBias) { + extraB_dz += PACK_UNIT; + } + weight_dz += PACK_UNIT * GEMMINT8_AVX512_L; } return; } diff --git a/source/backend/cpu/x86_x64/avx512/PackedFunction.cpp b/source/backend/cpu/x86_x64/avx512/PackedFunction.cpp index 6dbc438d7..047c3dc7a 100644 --- a/source/backend/cpu/x86_x64/avx512/PackedFunction.cpp +++ b/source/backend/cpu/x86_x64/avx512/PackedFunction.cpp @@ -39,11 +39,11 @@ void _AVX512_MNNAddC4WithStride(const float* source, float* dest, size_t srcStri void 
_AVX512_MNNComputeScaleZeroScalar(float* source, float* min, float* max, size_t size) { int pack = 16; - int sizeDiv16 = UP_DIV(size, pack); - __m512 minVal = _mm512_loadu_ps(source); + int sizeDiv16 = size / pack; + __m512 minVal = _mm512_set1_ps(source[0]); __m512 maxVal = minVal; float maxArr[16], minArr[16]; - for (int i = 1; i < sizeDiv16; ++i) { + for (int i = 0; i < sizeDiv16; ++i) { auto src0 = source + pack * i; __m512 vecA = _mm512_loadu_ps(src0); auto maskMax = _mm512_cmp_ps_mask(vecA, maxVal, 14); @@ -62,14 +62,27 @@ void _AVX512_MNNComputeScaleZeroScalar(float* source, float* min, float* max, si min_ = minArr[k]; } } + for (int i = pack * sizeDiv16; i < size; ++i) { + min_ = ALIMIN(min_, source[i]); + max_ = ALIMAX(max_, source[i]); + } min[0] = min_; max[0] = max_; - // float range = max_ - min_; - // MNN_ASSERT(range != 0); - // *quantScale = 255.0f / range; - // *dequantScale = range / 255.0f; - // *zeroPoint = std::min(255.f, std::max(roundf(-(min_ * 255.f) / range), 0.f)) - 128.f; +} +void _AVX512_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) { + // source: (ic/4, N, 4) + auto srcStep = pack * realSize; + for (int i = 0; i < realSize; ++i) { + float absmaxVal = 0.f; // absmaxVal>=0 + for (int c = 0; c < src_depth_quad; ++c) { + auto src = source + c * srcStep + i * pack; + for (int k = 0; k < pack; ++k) { + absmaxVal = std::max(absmaxVal, std::abs(src[k])); + } + } + absmax[i] = absmaxVal; + } } void _AVX512_MNNReluWithSlopeChannel(float* dst, const float* src, const float* slope, size_t sizeQuad, size_t depthQuad) { @@ -737,6 +750,7 @@ void _AVX512_ExtraInit(void* functions) { coreFunction->MNNMatrixAdd = _AVX512_MNNMatrixAdd; coreFunction->MNNMatrixSub = _AVX512_MNNMatrixSub; coreFunction->MNNCountMaxMinValue = _AVX512_MNNComputeScaleZeroScalar; + coreFunction->MNNAbsMax = _AVX512_MNNAbsMaxFP32; coreFunction->MNNConvRunForUnitDepthWise = _AVX512_MNNConvRunForUnitDepthWise; coreFunction->MNNConvRunForLineDepthwise = _AVX512_MNNConvRunForLineDepthwise; diff --git a/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp b/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp index d0a2fed31..4f1525087 100644 --- a/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp +++ b/source/backend/cpu/x86_x64/sse/FunctionSummary.hpp @@ -59,12 +59,9 @@ void _SSE_MNNPackedMatMul_int8(float* C, const float* A, const float* B, const s const float* postParameters, const float* bias, const float* k, const float* b); void _SSE_MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, size_t eSize, const size_t* parameter, const float* postParameters, const float* bias, const float* k, const float* b); -void _SSE_MNNGemmHybridInt4(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, - size_t dst_depth_quad, size_t realSize, const float** param); -void _SSE_MNNGemmHybridInt8(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, - size_t dst_depth_quad, size_t realSize, const float** param); void _SSE_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack); -void _SSE_MNNDynamicQuantFP32(const float* src, int8_t* dst, const float* scale, float* sum, size_t src_depth_quad, size_t realSize, int pack); +void _SSE_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, + size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst); #endif void 
_SSE_MNNPackC4ForMatMul_A(float* destOrigin, float const** sourceGroup, const int32_t* info, const int32_t* el); void _SSE_MNNConvRunForLineDepthwise(float* dst, const float* src, const float* weight, size_t width, size_t src_w_setup, diff --git a/source/backend/cpu/x86_x64/sse/GemmCommon.cpp b/source/backend/cpu/x86_x64/sse/GemmCommon.cpp index 53e21f0c7..40a372601 100644 --- a/source/backend/cpu/x86_x64/sse/GemmCommon.cpp +++ b/source/backend/cpu/x86_x64/sse/GemmCommon.cpp @@ -185,11 +185,11 @@ void _SSE_MNNPackedSparseMatMul(float* C, const float* A, const float* B, unsign void _SSE_MNNComputeScaleZeroScalar(float* source, float* min, float* max, size_t size) { int pack = 4; - int sizeDiv4 = UP_DIV(size, pack); - __m128 minVal = _mm_loadu_ps(source); + int sizeDiv4 = size / pack; + __m128 minVal = _mm_set1_ps(source[0]); __m128 maxVal = minVal; float maxArr[4], minArr[4]; - for (int i = 1; i < sizeDiv4; ++i) { + for (int i = 0; i < sizeDiv4; ++i) { auto src0 = source + pack * i; __m128 vecA = _mm_loadu_ps(src0); __m128 maskMax = _mm_cmpgt_ps(maxVal, vecA); @@ -200,7 +200,7 @@ void _SSE_MNNComputeScaleZeroScalar(float* source, float* min, float* max, size_ _mm_storeu_ps(maxArr, maxVal); _mm_storeu_ps(minArr, minVal); float max_ = maxArr[0], min_ = minArr[0]; - for (int k = 1; k < 4; ++k) { + for (int k = 1; k < pack; ++k) { if (max_ < maxArr[k]) { max_ = maxArr[k]; } @@ -208,13 +208,11 @@ void _SSE_MNNComputeScaleZeroScalar(float* source, float* min, float* max, size_ min_ = minArr[k]; } } + for (int i = pack * sizeDiv4; i < size; ++i) { + max_ = std::max(max_, source[i]); + min_ = std::min(min_, source[i]); + } min[0] = min_; max[0] = max_; - // float range = max_ - min_; - // MNN_ASSERT(range != 0); - // *quantScale = 255.0f / range; - // *dequantScale = range / 255.0f; - // *zeroPoint = std::min(255.f, std::max(roundf(-(min_ * 255.f) / range), 0.f)) - 128.0f; - } diff --git a/source/backend/cpu/x86_x64/sse/GemmFunction.hpp b/source/backend/cpu/x86_x64/sse/GemmFunction.hpp index e0272c184..89841f389 100644 --- a/source/backend/cpu/x86_x64/sse/GemmFunction.hpp +++ b/source/backend/cpu/x86_x64/sse/GemmFunction.hpp @@ -729,162 +729,4 @@ static void _SSE_MNNPackednMatMulRemainCommon_int8(float* C, const float* A, con } } } -// int4 -> int8 -static inline __m128i _load_int4_to_int8(const uint8_t* src) { - uint8_t c = 0xf; - int32_t data[4]; - int8_t temp[16]; - for (int i = 0; i < 8; ++i) { - temp[2 * i] = (src[i] >> 4); - temp[2 * i +1] = (src[i] & c); - } - auto int8_tx16 = _mm_loadu_si128((const __m128i*)temp); - return int8_tx16; -} -static void _SSE_MNNGemmHybrid_int4(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param) { - // C:(oc/4,N,4) A:(ic/4,N,4) B:(oc/4,ic/4,4,4) - int pack = 4; - __m128i zero_128i = _mm_set1_epi32(0); - size_t weight_step = src_depth_quad * pack * pack * 0.5; - size_t weight_stride = pack * pack * 0.5; - const float* alpha_ptr = param[0]; - const float* zero_ptr = param[1]; - const float* bias_ptr = param[2]; - const float* sums_ptr = param[3]; - const float* scale_ptr = param[4]; - std::vector tmpsrc(16, 0); - - for (int ci = 0; ci < dst_depth_quad; ++ci) { - float* dstZ = C + ci * pack * realSize; - const int8_t* weight = B + ci * weight_step; - auto alpha = alpha_ptr + ci * pack; - auto zero = zero_ptr + ci * pack; - auto bias = bias_ptr + ci * pack; - __m128 alphaValue = _mm_load_ps(alpha); - //const float* sums = param[2]; - for (int j = 0; j < realSize; ++j) { - 
const float* sums = sums_ptr + j; - const float* scale = scale_ptr + j; - float* dstX = dstZ + j * pack; - __m128i sum4 = _mm_set1_epi32(0); - __m128 scaleValue = _mm_set1_ps(scale[0]); - __m128 biasValue = _mm_add_ps(_mm_load_ps(bias), _mm_mul_ps(_mm_load_ps(zero), _mm_set1_ps(sums[0]))); - const int8_t* srcBatch = A + j * pack; - for (int k = 0; k < src_depth_quad; ++k) { - const int8_t* srcZ = srcBatch + k * pack * realSize; - const uint8_t* weightZ = (uint8_t*)weight + k * weight_stride; - auto w0 = _load_int4_to_int8(weightZ); - - ::memcpy(tmpsrc.data(), srcZ, 4 * sizeof(int8_t)); - auto s0 = _mm_loadu_si128((const __m128i*)tmpsrc.data()); - // src,weight: int8->int16 - auto s0_16 = _mm_srai_epi16(_mm_unpacklo_epi8(zero_128i, s0), 8); - auto w0_16 = _mm_srai_epi16(_mm_unpacklo_epi8(zero_128i, w0), 8); - auto w1_16 = _mm_srai_epi16(_mm_unpackhi_epi8(zero_128i, w0), 8); - auto w2_16 = _mm_unpackhi_epi64(w0_16, zero_128i); - auto w3_16 = _mm_unpackhi_epi64(w1_16, zero_128i); - - auto oc0 = _mm_madd_epi16(s0_16, w0_16); - auto oc1 = _mm_madd_epi16(s0_16, w2_16); - auto oc2 = _mm_madd_epi16(s0_16, w1_16); - auto oc3 = _mm_madd_epi16(s0_16, w3_16); - - auto d0 = _mm_unpacklo_epi32(oc0, oc1); - auto d1 = _mm_unpackhi_epi32(oc0, oc1); - auto d2 = _mm_unpacklo_epi32(oc2, oc3); - auto d3 = _mm_unpackhi_epi32(oc2, oc3); - - auto e0 = _mm_unpacklo_epi64(d0, d2); - auto e1 = _mm_unpackhi_epi64(d0, d2); - auto e2 = _mm_unpacklo_epi64(d1, d3); - auto e3 = _mm_unpackhi_epi64(d1, d3); - - e0 = _mm_add_epi32(e0, e1); - e2 = _mm_add_epi32(e2, e3); - e0 = _mm_add_epi32(e0, e2); - - sum4 = _mm_add_epi32(e0, sum4); - - } - __m128 f0 = _mm_cvtepi32_ps(sum4); - __m128 fs = _mm_mul_ps(_mm_mul_ps(f0, scaleValue), alphaValue); - fs = _mm_add_ps(biasValue, fs); - _mm_storeu_ps(dstX, fs); - } - } -} -static void _SSE_MNNGemmHybrid_int8(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param) { - // C:(oc/4,N,4) A:(ic/4,N,4) B:(oc/4,ic/4,4,4) - int pack = 4; - __m128i zero_128i = _mm_set1_epi32(0); - size_t weight_step = src_depth_quad * pack * pack; - size_t weight_stride = pack * pack; - const float* alpha_ptr = param[0]; - const float* zero_ptr = param[1]; - const float* bias_ptr = param[2]; - const float* sums_ptr = param[3]; - const float* scale_ptr = param[4]; - std::vector tmpsrc(16, 0); - - for (int ci = 0; ci < dst_depth_quad; ++ci) { - float* dstZ = C + ci * pack * realSize; - const int8_t* weight = B + ci * weight_step; - auto alpha = alpha_ptr + ci * pack; - auto zero = zero_ptr + ci * pack; - auto bias = bias_ptr + ci * pack; - __m128 alphaValue = _mm_load_ps(alpha); - //const float* sums = param[2]; - for (int j = 0; j < realSize; ++j) { - const float* sums = sums_ptr + j; - const float* scale = scale_ptr + j; - float* dstX = dstZ + j * pack; - - __m128i sum4 = _mm_set1_epi32(0); - __m128 scaleValue = _mm_set1_ps(scale[0]); - __m128 biasValue = _mm_add_ps(_mm_load_ps(bias), _mm_mul_ps(_mm_load_ps(zero), _mm_set1_ps(sums[0]))); - const int8_t* srcBatch = A + j * pack; - for (int k = 0; k < src_depth_quad; ++k) { - const int8_t* srcZ = srcBatch + k * pack * realSize; - const int8_t* weightZ = weight + k * weight_stride; - auto w0 = _mm_loadu_si128((__m128i*)(weightZ)); // 16xint8_t weight - - ::memcpy(tmpsrc.data(), srcZ, 4 * sizeof(int8_t)); - auto s0 = _mm_loadu_si128((const __m128i*)tmpsrc.data()); - // src,weight: int8->int16 -// auto s0_16 = _mm_unpacklo_epi8(s0, zero_128i); - auto s0_16 = 
_mm_srai_epi16(_mm_unpacklo_epi8(zero_128i, s0), 8); - auto w0_16 = _mm_srai_epi16(_mm_unpacklo_epi8(zero_128i, w0), 8); - auto w1_16 = _mm_srai_epi16(_mm_unpackhi_epi8(zero_128i, w0), 8); - auto w2_16 = _mm_unpackhi_epi64(w0_16, zero_128i); - auto w3_16 = _mm_unpackhi_epi64(w1_16, zero_128i); - - auto oc0 = _mm_madd_epi16(s0_16, w0_16); - auto oc1 = _mm_madd_epi16(s0_16, w2_16); - auto oc2 = _mm_madd_epi16(s0_16, w1_16); - auto oc3 = _mm_madd_epi16(s0_16, w3_16); - - auto d0 = _mm_unpacklo_epi32(oc0, oc1); - auto d1 = _mm_unpackhi_epi32(oc0, oc1); - auto d2 = _mm_unpacklo_epi32(oc2, oc3); - auto d3 = _mm_unpackhi_epi32(oc2, oc3); - - auto e0 = _mm_unpacklo_epi64(d0, d2); - auto e1 = _mm_unpackhi_epi64(d0, d2); - auto e2 = _mm_unpacklo_epi64(d1, d3); - auto e3 = _mm_unpackhi_epi64(d1, d3); - - e0 = _mm_add_epi32(e0, e1); - e2 = _mm_add_epi32(e2, e3); - e0 = _mm_add_epi32(e0, e2); - - sum4 = _mm_add_epi32(e0, sum4); - - } - __m128 f0 = _mm_cvtepi32_ps(sum4); - __m128 fs = _mm_mul_ps(_mm_mul_ps(f0, scaleValue), alphaValue); - fs = _mm_add_ps(biasValue, fs); - _mm_storeu_ps(dstX, fs); - } - } -} #endif diff --git a/source/backend/cpu/x86_x64/sse/GemmInt8.cpp b/source/backend/cpu/x86_x64/sse/GemmInt8.cpp index 2afb7144e..77702c2d4 100644 --- a/source/backend/cpu/x86_x64/sse/GemmInt8.cpp +++ b/source/backend/cpu/x86_x64/sse/GemmInt8.cpp @@ -22,11 +22,61 @@ void _SSE_MNNGemmInt8AddBiasScale_16x4_Unit(int8_t* dst, const int8_t* src, cons __m128 maxValue = _mm_set1_ps(post->maxValue); __m128 plus = _mm_set1_ps(0.5f); __m128 minus = _mm_set1_ps(-0.5f); + __m128 fp32min, fp32max; + if (0 == post->useInt8 && post->fp32minmax) { + fp32min = _mm_set1_ps((post->fp32minmax)[0]); + fp32max = _mm_set1_ps((post->fp32minmax)[1]); + } auto oneValue = _mm_set1_epi16(1); auto offset = _mm_set1_epi32(128); + auto f128 = _mm_set1_ps(128.f); + auto srcKernelSumPtr = post->srcKernelSum; + __m128 kernelSum0 = _mm_setzero_ps(); + __m128 kernelSum1 = _mm_setzero_ps(); + __m128 kernelSum2 = _mm_setzero_ps(); + __m128 kernelSum3 = _mm_setzero_ps(); + if (GEMM_INT8_DST_XUNIT == realDst) { + kernelSum0 = _mm_load_ps1(post->srcKernelSum); + kernelSum1 = _mm_load_ps1(post->srcKernelSum + 1); + kernelSum2 = _mm_load_ps1(post->srcKernelSum + 2); + kernelSum3 = _mm_load_ps1(post->srcKernelSum + 3); + } else { + kernelSum0 = _mm_load_ps1(post->srcKernelSum); + if (realDst > 1) { + kernelSum1 = _mm_load_ps1(post->srcKernelSum + 1); + } + if (realDst > 2) { + kernelSum2 = _mm_load_ps1(post->srcKernelSum + 2); + } + } + __m128 extrascale0 = _mm_setzero_ps(); + __m128 extrascale1 = _mm_setzero_ps(); + __m128 extrascale2 = _mm_setzero_ps(); + __m128 extrascale3 = _mm_setzero_ps(); + if (post->extraScale) { + if (GEMM_INT8_DST_XUNIT == realDst) { + extrascale0 = _mm_load_ps1(post->extraScale); + extrascale1 = _mm_load_ps1(post->extraScale + 1); + extrascale2 = _mm_load_ps1(post->extraScale + 2); + extrascale3 = _mm_load_ps1(post->extraScale + 3); + } else { + extrascale0 = _mm_load_ps1(post->extraScale); + if (realDst > 1) { + extrascale1 = _mm_load_ps1(post->extraScale + 1); + } + if (realDst > 2) { + extrascale2 = _mm_load_ps1(post->extraScale + 2); + } + } + } + const float* biasPtr = nullptr; + if (post->biasFloat) { + biasPtr = post->biasFloat; + } + auto blockNum = post->blockNum; for (int dz = 0; dz < dst_depth_quad; ++dz) { - const auto weight_dz = weight + dz * src_depth_quad * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); - const auto bias_dz = post->bias + dz * GEMM_INT8_UNIT; + const auto weight_dz = weight + dz * 
(src_depth_quad * blockNum) * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); + const auto weightBias_dz = post->weightQuanBias + dz * GEMM_INT8_UNIT; const float* scale_dz = nullptr; scale_dz = post->scale + dz * GEMM_INT8_UNIT; auto dst_z = dst + dz * dst_step_tmp; @@ -128,22 +178,60 @@ auto d##i##j = _mm_add_epi32(_mm_madd_epi16(S##i##j##0, W##i##j##0), _mm_madd_ep E0 = _mm_hadd_epi32(E0, E1); E1 = _mm_hadd_epi32(E2, E3); d3 = _mm_hadd_epi32(E0, E1); - - auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz)); auto scaleValue = _mm_loadu_ps(scale_dz); - d0 = _mm_add_epi32(d0, biasValue); - d1 = _mm_add_epi32(d1, biasValue); - d2 = _mm_add_epi32(d2, biasValue); - d3 = _mm_add_epi32(d3, biasValue); + // auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz)); + // d0 = _mm_add_epi32(d0, biasValue); + // d1 = _mm_add_epi32(d1, biasValue); + // d2 = _mm_add_epi32(d2, biasValue); + // d3 = _mm_add_epi32(d3, biasValue); + //auto biasValue = _mm_loadu_ps((float*)(bias_dz)); + auto weightBiasValue = _mm_loadu_ps((float*)weightBias_dz); __m128 f0 = _mm_cvtepi32_ps(d0); __m128 f1 = _mm_cvtepi32_ps(d1); __m128 f2 = _mm_cvtepi32_ps(d2); __m128 f3 = _mm_cvtepi32_ps(d3); + // x_kernelSum x w_quantZero + auto xy0_0 = _mm_mul_ps(kernelSum0, weightBiasValue); // x dimemsion first + auto xy0_1 = _mm_mul_ps(kernelSum1, weightBiasValue); // ..second + auto xy0_2 = _mm_mul_ps(kernelSum2, weightBiasValue); // .. third + auto xy0_3 = _mm_mul_ps(kernelSum3, weightBiasValue); // ..fourth f0 = _mm_mul_ps(f0, scaleValue); f1 = _mm_mul_ps(f1, scaleValue); f2 = _mm_mul_ps(f2, scaleValue); f3 = _mm_mul_ps(f3, scaleValue); + if (post->extraScale) { + f0 = _mm_mul_ps(f0, extrascale0); + f1 = _mm_mul_ps(f1, extrascale1); + f2 = _mm_mul_ps(f2, extrascale2); + f3 = _mm_mul_ps(f3, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extraB = post->extraBias + dz * GEMM_INT8_UNIT; + auto extrabias = _mm_loadu_ps(extraB); + extrabias = _mm_mul_ps(f128, extrabias); + auto extrabias0 = _mm_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm_mul_ps(extrabias, extrascale3); + f0 = _mm_sub_ps(f0, extrabias0); + f1 = _mm_sub_ps(f1, extrabias1); + f2 = _mm_sub_ps(f2, extrabias2); + f3 = _mm_sub_ps(f3, extrabias3); + } + } + f0 = _mm_add_ps(f0, xy0_0); + f1 = _mm_add_ps(f1, xy0_1); + f2 = _mm_add_ps(f2, xy0_2); + f3 = _mm_add_ps(f3, xy0_3); + if (nullptr != biasPtr) { + const auto bias_dz = biasPtr + dz * GEMM_INT8_UNIT; + auto biasValue = _mm_loadu_ps(bias_dz); + f0 = _mm_add_ps(f0, biasValue); + f1 = _mm_add_ps(f1, biasValue); + f2 = _mm_add_ps(f2, biasValue); + f3 = _mm_add_ps(f3, biasValue); + } if (post->useInt8 == 1) { + // for Relu Int8 activation f0 = _mm_min_ps(f0, maxValue); f1 = _mm_min_ps(f1, maxValue); f2 = _mm_min_ps(f2, maxValue); @@ -188,7 +276,24 @@ auto d##i##j = _mm_add_epi32(_mm_madd_epi16(S##i##j##0, W##i##j##0), _mm_madd_ep } } } else { // Store float values directly. + // for Relu float activation. 
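The change above replaces the old integer bias add (_mm_add_epi32 with post->bias) with a float-domain epilogue: the raw accumulator is scaled, a per-column kernelSum x weightQuanBias term is added, and the float bias (when present) comes last. That a per-column kernel sum is enough to fold a weight zero-point out of the integer dot product is plain algebra; a self-contained scalar check (the assumption that MNN folds the per-channel zero-times-scale factor into weightQuanBias is mine, the exact packing is defined elsewhere):

#include <cassert>
#include <cstdint>

int main() {
    // sum_k a_k * (w_k - z) == sum_k a_k * w_k - z * sum_k a_k
    const int K = 8;
    int32_t a[K] = {1, 2, 3, 4, 5, 6, 7, 8};    // activations (already offset to unsigned)
    int32_t w[K] = {9, 3, 0, 15, 7, 1, 12, 5};  // stored (offset) weights
    int32_t z   = 8;                            // weight zero-point

    int32_t ref = 0, acc = 0, kernelSum = 0;
    for (int k = 0; k < K; ++k) {
        ref       += a[k] * (w[k] - z);         // true signed dot product
        acc       += a[k] * w[k];               // what the u8 x s8 kernel computes
        kernelSum += a[k];
    }
    assert(ref == acc - z * kernelSum);         // correction via the kernel sum
    return 0;
}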
__m128 f[4] = {f0, f1, f2, f3}; + if (nullptr == biasPtr) { + for (int j = 0; j < realDst; ++j) { + auto dstv = _mm_loadu_ps(((float*)dst_x) + j * 4); + f[j] = _mm_add_ps(dstv, f[j]); + } + } + if (post->fp32minmax) { + f[0] = _mm_min_ps(f[0], fp32max); + f[1] = _mm_min_ps(f[1], fp32max); + f[2] = _mm_min_ps(f[2], fp32max); + f[3] = _mm_min_ps(f[3], fp32max); + f[0] = _mm_max_ps(f[0], fp32min); + f[1] = _mm_max_ps(f[1], fp32min); + f[2] = _mm_max_ps(f[2], fp32min); + f[3] = _mm_max_ps(f[3], fp32min); + } for (int j = 0; j < realDst; ++j) { _mm_storeu_ps(((float*)dst_x) + j * 4, f[j]); } @@ -196,6 +301,260 @@ auto d##i##j = _mm_add_epi32(_mm_madd_epi16(S##i##j##0, W##i##j##0), _mm_madd_ep } } +static inline void _load_int4_to_int8(const uint8_t* src, int8_t* dst) { + uint8_t c = 0xf; + for (int i = 0; i < 32; ++i) { + dst[2 * i] = (src[i] >> 4); + dst[2 * i +1] = (src[i] & c); + } +} +void _SSE_MNNGemmInt8AddBiasScale_16x4_w4(int8_t* dst, const int8_t* src, const int8_t* weight, size_t src_depth_quad, size_t dst_step, + size_t dst_depth_quad, const QuanPostTreatParameters* post, size_t realDst) { + MNN_ASSERT(post->useInt8 == 0); + const auto dst_step_tmp = dst_step / sizeof(int8_t); + __m128i zero = _mm_set1_epi32(0); + __m128 minValue = _mm_set1_ps(post->minValue); + __m128 maxValue = _mm_set1_ps(post->maxValue); + __m128 fp32min, fp32max; + if (post->fp32minmax) { + fp32min = _mm_set1_ps((post->fp32minmax)[0]); + fp32max = _mm_set1_ps((post->fp32minmax)[1]); + } + const float* biasPtr = nullptr; + if (post->biasFloat) { + biasPtr = post->biasFloat; + } + int blockNum = post->blockNum; + int weight_step_Z = 0.5 * (src_depth_quad * blockNum) * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); + int weight_step_Y = 0.5 * (GEMM_INT8_UNIT * GEMM_INT8_SRC_UNIT); + + auto oneValue = _mm_set1_epi16(1); + auto offset = _mm_set1_epi32(128); + auto srcKernelSumPtr = post->srcKernelSum; + __m128 kernelSum0 = _mm_setzero_ps(); + __m128 kernelSum1 = _mm_setzero_ps(); + __m128 kernelSum2 = _mm_setzero_ps(); + __m128 kernelSum3 = _mm_setzero_ps(); + if (GEMM_INT8_DST_XUNIT == realDst) { + kernelSum0 = _mm_load_ps1(post->srcKernelSum); + kernelSum1 = _mm_load_ps1(post->srcKernelSum + 1); + kernelSum2 = _mm_load_ps1(post->srcKernelSum + 2); + kernelSum3 = _mm_load_ps1(post->srcKernelSum + 3); + } else { + kernelSum0 = _mm_load_ps1(post->srcKernelSum); + if (realDst > 1) { + kernelSum1 = _mm_load_ps1(post->srcKernelSum + 1); + } + if (realDst > 2) { + kernelSum2 = _mm_load_ps1(post->srcKernelSum + 2); + } + } + auto f128 = _mm_set1_ps(128.f); + __m128 extrascale0 = _mm_setzero_ps(); + __m128 extrascale1 = _mm_setzero_ps(); + __m128 extrascale2 = _mm_setzero_ps(); + __m128 extrascale3 = _mm_setzero_ps(); + if (post->extraScale) { + if (GEMM_INT8_DST_XUNIT == realDst) { + extrascale0 = _mm_load_ps1(post->extraScale); + extrascale1 = _mm_load_ps1(post->extraScale + 1); + extrascale2 = _mm_load_ps1(post->extraScale + 2); + extrascale3 = _mm_load_ps1(post->extraScale + 3); + } else { + extrascale0 = _mm_load_ps1(post->extraScale); + if (realDst > 1) { + extrascale1 = _mm_load_ps1(post->extraScale + 1); + } + if (realDst > 2) { + extrascale2 = _mm_load_ps1(post->extraScale + 2); + } + } + } + for (int dz = 0; dz < dst_depth_quad; ++dz) { + const auto weight_dz = weight + dz * weight_step_Z; + const auto weightBias_dz = post->weightQuanBias + dz * GEMM_INT8_UNIT; + const float* scale_dz = nullptr; + scale_dz = post->scale + dz * GEMM_INT8_UNIT; + auto dst_z = dst + dz * dst_step_tmp; + const auto src_x = src; + auto dst_x = 
dst_z; + __m128i d0 = _mm_set1_epi32(0); + __m128i d1 = _mm_set1_epi32(0); + __m128i d2 = _mm_set1_epi32(0); + __m128i d3 = _mm_set1_epi32(0); + + __m128i e0 = _mm_set1_epi32(0); + __m128i e1 = _mm_set1_epi32(0); + __m128i e2 = _mm_set1_epi32(0); + __m128i e3 = _mm_set1_epi32(0); + + __m128i D0 = _mm_set1_epi32(0); + __m128i D1 = _mm_set1_epi32(0); + __m128i D2 = _mm_set1_epi32(0); + __m128i D3 = _mm_set1_epi32(0); + + __m128i E0 = _mm_set1_epi32(0); + __m128i E1 = _mm_set1_epi32(0); + __m128i E2 = _mm_set1_epi32(0); + __m128i E3 = _mm_set1_epi32(0); + + for (int sz = 0; sz < src_depth_quad; ++sz) { + const auto weight_sz = weight_dz + weight_step_Y * sz; + const auto src_z = src_x + sz * GEMM_INT8_DST_XUNIT * GEMM_INT8_SRC_UNIT; + + int8_t tmp_w[64]; + _load_int4_to_int8((uint8_t*)weight_sz, tmp_w); + + auto w0 = _mm_loadu_si128((__m128i*)(tmp_w + GEMM_INT8_SRC_UNIT * 0)); + auto w1 = _mm_loadu_si128((__m128i*)(tmp_w + GEMM_INT8_SRC_UNIT * 1)); + auto w2 = _mm_loadu_si128((__m128i*)(tmp_w + GEMM_INT8_SRC_UNIT * 2)); + auto w3 = _mm_loadu_si128((__m128i*)(tmp_w + GEMM_INT8_SRC_UNIT * 3)); + + auto s0 = _mm_loadu_si128((__m128i*)(src_z + GEMM_INT8_SRC_UNIT * 0)); + auto s1 = _mm_loadu_si128((__m128i*)(src_z + GEMM_INT8_SRC_UNIT * 1)); + auto s2 = _mm_loadu_si128((__m128i*)(src_z + GEMM_INT8_SRC_UNIT * 2)); + auto s3 = _mm_loadu_si128((__m128i*)(src_z + GEMM_INT8_SRC_UNIT * 3)); + + +//#define COMPUTE(i, j)\ +//auto d##i##j = _mm_maddubs_epi16(s##i, w##j);\ +//d##i##j = _mm_madd_epi16(d##i##j, oneValue);\ + +#define COMPUTE(i, j)\ +auto W##i##j##0 = _mm_srai_epi16(_mm_unpacklo_epi8(zero, w##j), 8);\ +auto W##i##j##1 = _mm_srai_epi16(_mm_unpackhi_epi8(zero, w##j), 8);\ +auto S##i##j##0 = _mm_unpacklo_epi8(s##i, zero);\ +auto S##i##j##1 = _mm_unpackhi_epi8(s##i, zero);\ +auto d##i##j = _mm_add_epi32(_mm_madd_epi16(S##i##j##0, W##i##j##0), _mm_madd_epi16(S##i##j##1, W##i##j##1));\ + + COMPUTE(0, 0); + COMPUTE(0, 1); + COMPUTE(0, 2); + COMPUTE(0, 3); + COMPUTE(1, 0); + COMPUTE(1, 1); + COMPUTE(1, 2); + COMPUTE(1, 3); + COMPUTE(2, 0); + COMPUTE(2, 1); + COMPUTE(2, 2); + COMPUTE(2, 3); + COMPUTE(3, 0); + COMPUTE(3, 1); + COMPUTE(3, 2); + COMPUTE(3, 3); + + d0 = _mm_add_epi32(d0, d00); + d1 = _mm_add_epi32(d1, d01); + d2 = _mm_add_epi32(d2, d02); + d3 = _mm_add_epi32(d3, d03); + + e0 = _mm_add_epi32(e0, d10); + e1 = _mm_add_epi32(e1, d11); + e2 = _mm_add_epi32(e2, d12); + e3 = _mm_add_epi32(e3, d13); + + D0 = _mm_add_epi32(D0, d20); + D1 = _mm_add_epi32(D1, d21); + D2 = _mm_add_epi32(D2, d22); + D3 = _mm_add_epi32(D3, d23); + + E0 = _mm_add_epi32(E0, d30); + E1 = _mm_add_epi32(E1, d31); + E2 = _mm_add_epi32(E2, d32); + E3 = _mm_add_epi32(E3, d33); + } + d0 = _mm_hadd_epi32(d0, d1); + d1 = _mm_hadd_epi32(d2, d3); + d0 = _mm_hadd_epi32(d0, d1); + + e0 = _mm_hadd_epi32(e0, e1); + e1 = _mm_hadd_epi32(e2, e3); + d1 = _mm_hadd_epi32(e0, e1); + + D0 = _mm_hadd_epi32(D0, D1); + D1 = _mm_hadd_epi32(D2, D3); + d2 = _mm_hadd_epi32(D0, D1); + + E0 = _mm_hadd_epi32(E0, E1); + E1 = _mm_hadd_epi32(E2, E3); + d3 = _mm_hadd_epi32(E0, E1); + auto scaleValue = _mm_loadu_ps(scale_dz); + // auto biasValue = _mm_loadu_si128((__m128i*)(bias_dz)); + // d0 = _mm_add_epi32(d0, biasValue); + // d1 = _mm_add_epi32(d1, biasValue); + // d2 = _mm_add_epi32(d2, biasValue); + // d3 = _mm_add_epi32(d3, biasValue); + //auto biasValue = _mm_loadu_ps((float*)(bias_dz)); + auto weightBiasValue = _mm_loadu_ps((float*)weightBias_dz); + __m128 f0 = _mm_cvtepi32_ps(d0); + __m128 f1 = _mm_cvtepi32_ps(d1); + __m128 f2 = _mm_cvtepi32_ps(d2); 
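// [Editor's sketch, not part of the patch] `_load_int4_to_int8` above expands 32 bytes of
// packed int4 weights into 64 int8 values (high nibble first, then low nibble) so the
// existing int8 madd path can be reused unchanged. A standalone scalar version of the same
// idea, with hypothetical names:
#include <cstdint>
static inline void unpackInt4ToInt8(const uint8_t* packed, int8_t* out, int byteCount) {
    for (int i = 0; i < byteCount; ++i) {
        out[2 * i]     = (int8_t)(packed[i] >> 4);   // high nibble -> even slot
        out[2 * i + 1] = (int8_t)(packed[i] & 0x0F); // low nibble  -> odd slot
    }
}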
+ __m128 f3 = _mm_cvtepi32_ps(d3); + // x_kernelSum x w_quantZero + auto xy0_0 = _mm_mul_ps(kernelSum0, weightBiasValue); // x dimemsion first + auto xy0_1 = _mm_mul_ps(kernelSum1, weightBiasValue); // ..second + auto xy0_2 = _mm_mul_ps(kernelSum2, weightBiasValue); // .. third + auto xy0_3 = _mm_mul_ps(kernelSum3, weightBiasValue); // ..fourth + f0 = _mm_mul_ps(f0, scaleValue); + f1 = _mm_mul_ps(f1, scaleValue); + f2 = _mm_mul_ps(f2, scaleValue); + f3 = _mm_mul_ps(f3, scaleValue); + if (post->extraScale) { + f0 = _mm_mul_ps(f0, extrascale0); + f1 = _mm_mul_ps(f1, extrascale1); + f2 = _mm_mul_ps(f2, extrascale2); + f3 = _mm_mul_ps(f3, extrascale3); + if (post->extraBias && nullptr != biasPtr) { + auto extraB = post->extraBias + dz * GEMM_INT8_UNIT; + auto extrabias = _mm_loadu_ps(extraB); + extrabias = _mm_mul_ps(f128, extrabias); + auto extrabias0 = _mm_mul_ps(extrabias, extrascale0); + auto extrabias1 = _mm_mul_ps(extrabias, extrascale1); + auto extrabias2 = _mm_mul_ps(extrabias, extrascale2); + auto extrabias3 = _mm_mul_ps(extrabias, extrascale3); + f0 = _mm_sub_ps(f0, extrabias0); + f1 = _mm_sub_ps(f1, extrabias1); + f2 = _mm_sub_ps(f2, extrabias2); + f3 = _mm_sub_ps(f3, extrabias3); + } + } + f0 = _mm_add_ps(f0, xy0_0); + f1 = _mm_add_ps(f1, xy0_1); + f2 = _mm_add_ps(f2, xy0_2); + f3 = _mm_add_ps(f3, xy0_3); + + if (nullptr != biasPtr) { + const auto bias_dz = biasPtr + dz * GEMM_INT8_UNIT; + auto biasValue = _mm_loadu_ps(bias_dz); + f0 = _mm_add_ps(f0, biasValue); + f1 = _mm_add_ps(f1, biasValue); + f2 = _mm_add_ps(f2, biasValue); + f3 = _mm_add_ps(f3, biasValue); + } + __m128 f[4] = {f0, f1, f2, f3}; + if (nullptr == biasPtr) { + for (int j = 0; j < realDst; ++j) { + auto dstv = _mm_loadu_ps(((float*)dst_x) + j * 4); + f[j] = _mm_add_ps(dstv, f[j]); + } + } + if (post->fp32minmax) { + f[0] = _mm_min_ps(f[0], fp32max); + f[1] = _mm_min_ps(f[1], fp32max); + f[2] = _mm_min_ps(f[2], fp32max); + f[3] = _mm_min_ps(f[3], fp32max); + f[0] = _mm_max_ps(f[0], fp32min); + f[1] = _mm_max_ps(f[1], fp32min); + f[2] = _mm_max_ps(f[2], fp32min); + f[3] = _mm_max_ps(f[3], fp32min); + } + for (int j = 0; j < realDst; ++j) { + _mm_storeu_ps(((float*)dst_x) + j * 4, f[j]); + } + } +} + void _SSE_MNNInt8ToInt16(int16_t* dest, const int8_t* sourceO, size_t count) { int countC16 = count / 16; int countR = count % 16; diff --git a/source/backend/cpu/x86_x64/sse/GemmSSE.cpp b/source/backend/cpu/x86_x64/sse/GemmSSE.cpp index 7d5699e96..8e5a32896 100644 --- a/source/backend/cpu/x86_x64/sse/GemmSSE.cpp +++ b/source/backend/cpu/x86_x64/sse/GemmSSE.cpp @@ -67,12 +67,6 @@ void _SSE_MNNPackedMatMulRemain_int8(float* C, const float* A, const float* B, s } } -void _SSE_MNNGemmHybridInt4(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param) { - _SSE_MNNGemmHybrid_int4(C, A, B, src_depth_quad, dst_step, dst_depth_quad, realSize, param); -} -void _SSE_MNNGemmHybridInt8(float* C, const int8_t* A, const int8_t* B, size_t src_depth_quad, size_t dst_step, size_t dst_depth_quad, size_t realSize, const float** param) { - _SSE_MNNGemmHybrid_int8(C, A, B, src_depth_quad, dst_step, dst_depth_quad, realSize, param); -} // Dynamic quant void _SSE_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_quad, size_t realSize, int pack) { // source: (ic/4, N, 4) @@ -99,36 +93,4 @@ void _SSE_MNNAbsMaxFP32(const float* source, float* absmax, size_t src_depth_qua absmax[i] = absmaxVal; } } - -void _SSE_MNNDynamicQuantFP32(const float* 
src, int8_t* dst, const float* scale, float* sum, size_t src_depth_quad, size_t realSize, int pack) { - // SSE: pack=4 - __m128 zero = _mm_setzero_ps(); - __m128 plus = _mm_set1_ps(0.5f); - __m128 minus = _mm_set1_ps(-0.5f); - auto offset = _mm_set1_epi32(128); - uint8_t* dstPtr = reinterpret_cast(dst); - float temp[4]; - for (int i = 0; i < realSize; ++i) { - __m128 scaleVal = _mm_load_ps1(scale + i); - __m128 acc = _mm_setzero_ps(); - for (int c = 0; c < src_depth_quad; ++c) { - auto srcZ = src + c * pack * realSize + i * pack; - auto dstZ = dstPtr + c * pack * realSize + i * pack; - __m128 f0 = _mm_loadu_ps(srcZ); - __m128 m0 = _mm_mul_ps(f0, scaleVal); - __m128 mask = _mm_cmplt_ps(m0, zero); - __m128 d0 = _mm_blendv_ps(plus, minus, mask); - d0 = _mm_add_ps(d0, m0); - __m128 round0 = _mm_round_ps(d0, 3); - auto d0_epi32 = _mm_cvtps_epi32(round0); - d0_epi32 = _mm_packs_epi32(d0_epi32, d0_epi32); - d0_epi32 = _mm_packs_epi16(d0_epi32, d0_epi32); - *((int*)dstZ) = _mm_cvtsi128_si32(d0_epi32); - acc = _mm_add_ps(acc, round0); - } - _mm_storeu_ps(temp, acc); - int sumVal = static_cast(temp[0] + temp[1] + temp[2] + temp[3]); - ((int32_t*)sum)[i] = sumVal; - } -} #endif diff --git a/source/backend/cpu/x86_x64/sse/MathFunctions.cpp b/source/backend/cpu/x86_x64/sse/MathFunctions.cpp index f5c66ce5f..b9e857006 100644 --- a/source/backend/cpu/x86_x64/sse/MathFunctions.cpp +++ b/source/backend/cpu/x86_x64/sse/MathFunctions.cpp @@ -328,7 +328,7 @@ void _SSE_MNNReluWithSlopeChannelInt8(int8_t* dst, const int8_t* src, const floa d0 = _mm_add_epi32(d0, offset); d0 = _mm_packs_epi32(d0, d0); d0 = _mm_packus_epi16(d0, d0); - *((int*)dst + i) = _mm_cvtsi128_si32(d0); + *((int*)dstZ + i) = _mm_cvtsi128_si32(d0); } } } diff --git a/source/backend/metal/MetalBackend.hpp b/source/backend/metal/MetalBackend.hpp index 589dd5fff..e01913a38 100644 --- a/source/backend/metal/MetalBackend.hpp +++ b/source/backend/metal/MetalBackend.hpp @@ -138,7 +138,7 @@ class MetalBackend : public Backend { */ static void addCreator(OpType type, Creator *creator); static void setTensor(const MNN::Tensor* tensor, id encoder, int index); - static std::pair, int> getBuffer(MNN::Tensor* tensor); + static std::pair, int> getBuffer(const MNN::Tensor* tensor); size_t getTensorSizeInBytes(const Tensor* tensor) const; virtual bool onSelectDynamicAllocator(int index, int maxIndex) override; id getHostBuffer(size_t size) const; @@ -207,6 +207,12 @@ class MetalBackend : public Backend { bool useFp16InsteadFp32() const { return mUseFloatAsFp16; } + struct CopyPipeline { + id pipeline; + id shape; + MTLSize localSize; + MTLSize groupSize; + }; private: MetalRuntimeAllocator::MetalBufferAlloc mEmptyMem; id getCommandBufferForBufferCopy() const; @@ -234,6 +240,8 @@ class MetalBackend : public Backend { std::shared_ptr mStaticBufferPool; private: + CopyPipeline _makeCopyInfo(const Tensor *src, const Tensor *dst, id shape, int castType) const; + mutable id mHostBuffer = nullptr; // hostmask: 0: no host, 1: src is host, 2: dst is host void onCopyDeviceToDevice(const Tensor *src, const Tensor *dst, id encoder, id shape, int hostmask = 0) const; diff --git a/source/backend/metal/MetalBackend.mm b/source/backend/metal/MetalBackend.mm index 57d800910..6f73629bb 100644 --- a/source/backend/metal/MetalBackend.mm +++ b/source/backend/metal/MetalBackend.mm @@ -36,7 +36,7 @@ static void _MetalApplyTensor(uint8_t* host, size_t offset, Tensor* t) { des->extra.offset = offset; } static BufferAllocator* _createBufferAllocator(const Runtime* runtime, 
BufferAllocator* origin, bool secondResize) { - if (runtime->getAllocatorType() == Runtime::Allocator_Defer && secondResize) { + if (runtime->hint().memoryAllocatorType == Runtime::Allocator_Defer && secondResize) { return new DeferBufferAllocator(BufferAllocator::Allocator::createRecurse(origin), 1024, _MetalApplyTensor); } return new EagerBufferAllocator(BufferAllocator::Allocator::createRecurse(origin), 1024); @@ -315,6 +315,9 @@ MemChunk chunk() override { } id MetalBackend::getHostBuffer(size_t size) const { + if (size < METAL_CONST_BUFFER_LIMIT) { + size = METAL_CONST_BUFFER_LIMIT; + } // reuse if (nullptr != mHostBuffer && mHostBuffer.length >= size) { return mHostBuffer; @@ -568,17 +571,15 @@ kernel void main0(const device IType *in [[buffer(0)]], device OType *out [[buff } return res; } -void MetalBackend::onCopyDeviceToDevice(const Tensor *src, const Tensor *dst, - id encoder, id shape, int castType) const { - auto ctx = (__bridge MNNMetalContext *)context(); - auto standalone = encoder == nil; - encoder = encoder ?: [getCommandBufferForBufferCopy() computeCommandEncoder]; +MetalBackend::CopyPipeline MetalBackend::_makeCopyInfo(const Tensor *src, const Tensor *dst, id shape, int castType) const { + auto ctx = (__bridge MNNMetalContext *)context(); + MetalBackend::CopyPipeline res; auto sfmt = TensorUtils::getDescribe(src)->dimensionFormat; auto dfmt = TensorUtils::getDescribe(dst)->dimensionFormat; if (shape == nil) { shape = getConstBuffer(8 * sizeof(int)); } - // copy + res.shape = shape; if (sfmt == dfmt || src->dimensions() <= 1) { auto srcType = _getType(src->getType(), MNN_DATA_FORMAT_NC4HW4, mUseFloatAsFp16 && castType != 1); auto dstType = _getType(dst->getType(), MNN_DATA_FORMAT_NC4HW4, mUseFloatAsFp16 && castType != 2); @@ -589,6 +590,7 @@ kernel void main0(const device IType *in [[buffer(0)]], device OType *out [[buff srcType, dstType }; + ((uint32_t*)[shape contents])[0] = size; id pipeline = mRuntime->findPipeline(keys); if (nil == pipeline) { MTLCompileOptions *option = [[MTLCompileOptions alloc] init]; @@ -599,16 +601,14 @@ kernel void main0(const device IType *in [[buffer(0)]], device OType *out [[buff pipeline = makeComputePipelineWithSourceOption(gCopy, "main0", option); mRuntime->insertPipeline(keys, pipeline); } - [encoder setComputePipelineState:pipeline]; - ((uint32_t*)[shape contents])[0] = size; - setTensor(src, encoder, 0); - setTensor(dst, encoder, 1); - [encoder setBuffer:shape offset:0 atIndex:2]; - [encoder dispatchThreadgroups:MTLSizeMake(UP_DIV(size, 256), 1, 1) threadsPerThreadgroup:MTLSizeMake(256, 1, 1)]; - } - else if (sfmt == MNN_DATA_FORMAT_NC4HW4 || dfmt == MNN_DATA_FORMAT_NC4HW4) { - auto srcType = _getType(src->getType(), sfmt, mUseFloatAsFp16 && castType != 1); - auto dstType = _getType(dst->getType(), dfmt, mUseFloatAsFp16 && castType != 2); + res.groupSize = MTLSizeMake(UP_DIV(size, 256), 1, 1); + res.localSize = MTLSizeMake(256, 1, 1); + res.pipeline = pipeline; + return res; + } + auto srcType = _getType(src->getType(), sfmt, mUseFloatAsFp16 && castType != 1); + auto dstType = _getType(dst->getType(), dfmt, mUseFloatAsFp16 && castType != 2); + if (sfmt == MNN_DATA_FORMAT_NC4HW4 || dfmt == MNN_DATA_FORMAT_NC4HW4) { auto normalTensor = dst; if (dfmt == MNN_DATA_FORMAT_NC4HW4) { normalTensor = src; @@ -635,52 +635,62 @@ kernel void main0(const device IType *in [[buffer(0)]], device OType *out [[buff pipeline = makeComputePipelineWithSourceOption(gNC4HW4Convert, "main0", option); mRuntime->insertPipeline(keys, pipeline); } - [encoder 
setComputePipelineState:pipeline]; + res.pipeline = pipeline; auto size = getTensorShape(shape, normalTensor); - MetalBackend::setTensor(src, encoder, 0); - MetalBackend::setTensor(dst, encoder, 1); - [encoder setBuffer:shape offset:0 atIndex:2]; auto gl = [ctx computeBestGroupAndLocal:pipeline threads:size]; - [encoder dispatchThreadgroups:gl.first threadsPerThreadgroup:gl.second]; + res.groupSize = gl.first; + res.localSize = gl.second; + return res; + } + // NCHW <-> NHWC + std::vector keys = { + "transpose", + srcType, + dstType + }; + id pipeline = mRuntime->findPipeline(keys); + if (nil == pipeline) { + MTLCompileOptions *option = [[MTLCompileOptions alloc] init]; + auto dic = [NSMutableDictionary dictionaryWithCapacity:0]; + [dic setValue:@(keys[1].c_str()) forKey:@"IType"]; + [dic setValue:@(keys[2].c_str()) forKey:@"OType"]; + option.preprocessorMacros = dic; + pipeline = makeComputePipelineWithSourceOption(gTranspose, "main0", option); + mRuntime->insertPipeline(keys, pipeline); + } + res.pipeline = pipeline; + int n, c, plane; + _getNCPlane(dst, plane, c, n); + auto shapePtr = (uint32_t*)shape.contents; + shapePtr[0] = n; + shapePtr[3] = 1; + if (MNN_DATA_FORMAT_NHWC == dfmt) { + shapePtr[1] = plane; + shapePtr[2] = c; } else { - // NCHW <-> NHWC - auto srcType = _getType(src->getType(), sfmt, mUseFloatAsFp16 && castType != 1); - auto dstType = _getType(dst->getType(), dfmt, mUseFloatAsFp16 && castType != 2); - std::vector keys = { - "transpose", - srcType, - dstType - }; - id pipeline = mRuntime->findPipeline(keys); - if (nil == pipeline) { - MTLCompileOptions *option = [[MTLCompileOptions alloc] init]; - auto dic = [NSMutableDictionary dictionaryWithCapacity:0]; - [dic setValue:@(keys[1].c_str()) forKey:@"IType"]; - [dic setValue:@(keys[2].c_str()) forKey:@"OType"]; - option.preprocessorMacros = dic; - pipeline = makeComputePipelineWithSourceOption(gTranspose, "main0", option); - mRuntime->insertPipeline(keys, pipeline); - } - [encoder setComputePipelineState:pipeline]; - int n, c, plane; - _getNCPlane(dst, plane, c, n); - auto shapePtr = (uint32_t*)shape.contents; - shapePtr[0] = n; - shapePtr[3] = 1; - if (MNN_DATA_FORMAT_NHWC == dfmt) { - shapePtr[1] = plane; - shapePtr[2] = c; - } else { - shapePtr[1] = c; - shapePtr[2] = plane; - } - auto size = plane * n * c; - setTensor(src, encoder, 0); - setTensor(dst, encoder, 1); - [encoder setBuffer:shape offset:0 atIndex:2]; - [encoder dispatchThreadgroups:MTLSizeMake(UP_DIV(size, 256), 1, 1) threadsPerThreadgroup:MTLSizeMake(256, 1, 1)]; + shapePtr[1] = c; + shapePtr[2] = plane; } + auto size = plane * n * c; + res.localSize = MTLSizeMake(256, 1, 1); + res.groupSize = MTLSizeMake(UP_DIV(size, 256), 1, 1); + return res; +} +static void _execute(id encoder, const MetalBackend::CopyPipeline& info, std::pair, int> src, std::pair, int> dst) { + [encoder setComputePipelineState:info.pipeline]; + [encoder setBuffer:src.first offset:src.second atIndex:0]; + [encoder setBuffer:dst.first offset:dst.second atIndex:1]; + [encoder setBuffer:info.shape offset:0 atIndex:2]; + [encoder dispatchThreadgroups:info.groupSize threadsPerThreadgroup:info.localSize]; +} +void MetalBackend::onCopyDeviceToDevice(const Tensor *src, const Tensor *dst, + id encoder, id shape, int castType) const { + auto ctx = (__bridge MNNMetalContext *)context(); + auto info = _makeCopyInfo(src, dst, shape, castType); + auto standalone = encoder == nil; + encoder = encoder ?: [getCommandBufferForBufferCopy() computeCommandEncoder]; + _execute(encoder, info, 
MetalBackend::getBuffer(src), MetalBackend::getBuffer(dst)); if (standalone) { [encoder endEncoding]; MNN_PRINT_ENCODER(ctx, encoder); @@ -724,22 +734,21 @@ kernel void main0(const device IType *in [[buffer(0)]], device OType *out [[buff if (needConvert) { auto tDst = const_cast(dst); auto tmpBuffer = getHostBuffer(dst->usize()); - MetalRuntimeAllocator::MetalBufferAlloc tmp(tmpBuffer); - TensorUtils::getDescribe(tDst)->extra.offset = 0; - tDst->buffer().device = (uint64_t)(&tmp); - onCopyDeviceToDevice(src, dst, nullptr, nullptr, 2); - tDst->buffer().device = 0; - devicePtr = (uint8_t*)tmpBuffer.contents; + auto info = _makeCopyInfo(src, dst, shape, 2); + auto standalone = encoder == nil; + encoder = encoder ?: [getCommandBufferForBufferCopy() computeCommandEncoder]; + _execute(encoder, info, MetalBackend::getBuffer(src), std::make_pair(tmpBuffer, 0)); + if (standalone) { + [encoder endEncoding]; + } commit(); + devicePtr = (uint8_t*)tmpBuffer.contents; } wait(); ::memcpy(dst->host(), devicePtr, dst->usize()); return; } if (src->buffer().host && !dst->buffer().host) { - auto device = (id)((MetalRuntimeAllocator::MetalBufferAlloc *)dst->deviceId())->getBuffer(); - auto devicePtr = (uint8_t*)device.contents + TensorUtils::getDescribe(dst)->extra.offset; - // For command queue from user, need user to make sure last frame's gpu work is ready bool needWait = !mRuntime->userSync(); if (needWait) { @@ -749,13 +758,17 @@ kernel void main0(const device IType *in [[buffer(0)]], device OType *out [[buff if (needConvert) { auto tmpBuffer = getHostBuffer(srcSize); ::memcpy(tmpBuffer.contents, src->host(), srcSize); - MetalRuntimeAllocator::MetalBufferAlloc tmp(tmpBuffer); - auto tSrc = const_cast(src); - TensorUtils::getDescribe(tSrc)->extra.offset = 0; - tSrc->buffer().device = (uint64_t)(&tmp); - onCopyDeviceToDevice(tSrc, dst, nullptr, nullptr, 1); - tSrc->buffer().device = 0; + auto info = _makeCopyInfo(src, dst, shape, 1); + auto standalone = encoder == nil; + encoder = encoder ?: [getCommandBufferForBufferCopy() computeCommandEncoder]; + _execute(encoder, info, std::make_pair(tmpBuffer, 0), MetalBackend::getBuffer(dst)); + if (standalone) { + [encoder endEncoding]; + } + commit(); } else { + auto device = (id)((MetalRuntimeAllocator::MetalBufferAlloc *)dst->deviceId())->getBuffer(); + auto devicePtr = (uint8_t*)device.contents + TensorUtils::getDescribe(dst)->extra.offset; ::memcpy(devicePtr, src->host(), srcSize); } return; @@ -797,7 +810,7 @@ kernel void main0(const device IType *in [[buffer(0)]], device OType *out [[buff void MetalBackend::setTensor(const MNN::Tensor* tensor, id encoder, int index) { [encoder setBuffer:((MetalRuntimeAllocator::MetalBufferAlloc *)tensor->deviceId())->getBuffer() offset:TensorUtils::getDescribe(tensor)->extra.offset atIndex:index]; } -std::pair, int> MetalBackend::getBuffer(MNN::Tensor* tensor) { +std::pair, int> MetalBackend::getBuffer(const MNN::Tensor* tensor) { return std::make_pair(((MetalRuntimeAllocator::MetalBufferAlloc *)tensor->deviceId())->getBuffer(), TensorUtils::getDescribe(tensor)->extra.offset); } diff --git a/source/backend/opencl/core/OpenCLBackend.cpp b/source/backend/opencl/core/OpenCLBackend.cpp index 74f765912..d9bc65fab 100644 --- a/source/backend/opencl/core/OpenCLBackend.cpp +++ b/source/backend/opencl/core/OpenCLBackend.cpp @@ -338,7 +338,7 @@ Backend::MemObj* OpenCLBackend::onAcquire(const Tensor* nativeTensor, StorageTyp size = N * alignC * W * H; size = size + hR * W * 4 + wR * 4; } else { - size = nativeTensor->elementSize(); + 
size = N * H * W * C; size = ROUND_UP(size, 4); } if (mOpenCLRuntime->isSupportedIntelSubgroup()) { diff --git a/source/backend/opencl/core/OpenCLGemmTune.cpp b/source/backend/opencl/core/OpenCLGemmTune.cpp index 870bf7530..00cd3ed98 100644 --- a/source/backend/opencl/core/OpenCLGemmTune.cpp +++ b/source/backend/opencl/core/OpenCLGemmTune.cpp @@ -34,17 +34,10 @@ static void generateCombinations(const std::vector> &candi static bool isCandidateValid(uint32_t kwg, uint32_t kwi, uint32_t mwg, uint32_t mdimc, uint32_t vwm, uint32_t nwg, uint32_t ndimc, uint32_t vwn, uint32_t mdima, uint32_t ndimb, uint32_t sa, uint32_t sb, OpenCLRuntime *runtime, const std::vector& gemmSize) { // problem size align - if(gemmSize[0] % mwg != 0 || gemmSize[1] % nwg != 0 || gemmSize[2] % kwg != 0) { - return false; - } - // mwg nwg only for M N equal to 16 - if((gemmSize[0] > 16 && mwg == 16) || (gemmSize[1] > 16 && nwg == 16)) { - return false; - } - // params align - if(kwg % kwi != 0) { + if(gemmSize[0] % mwg != 0 || gemmSize[1] % nwg != 0) { return false; } + if(mwg % (mdimc * vwm) != 0 || mwg % (mdima * vwm) != 0) { return false; } @@ -53,9 +46,19 @@ static bool isCandidateValid(uint32_t kwg, uint32_t kwi, uint32_t mwg, uint32_t } uint32_t kdima = (mdimc * ndimc) / mdima; uint32_t kdimb = (mdimc * ndimc) / ndimb; - if(kwg % kdima != 0 || kwg % kdimb != 0) { - return false; + if(sa == 1 || sb == 1) { + // params align + if(kwg % kwi != 0) { + return false; + } + if(kwg % kdima != 0 || kwg % kdimb != 0) { + return false; + } + if(gemmSize[2] % kwg != 0) { + return false; + } } + if(mdimc != mdima || ndimc != ndimb) { return false; } @@ -63,6 +66,11 @@ static bool isCandidateValid(uint32_t kwg, uint32_t kwi, uint32_t mwg, uint32_t return false; } + // no local memory no need tune kwg + if(sa == 0 && sb == 0 && kwg == 32) { + return false; + } + // local memory limit uint32_t local_mem_size = 0; if(sa) { @@ -89,27 +97,31 @@ static bool isCandidateValid(uint32_t kwg, uint32_t kwi, uint32_t mwg, uint32_t if(mdimc != mdima || ndimc != ndimb) { return false; } + + bool totalLarge = 1.0 * gemmSize[0] / 1024 * gemmSize[1] / 1024 * gemmSize[2] / 1024 >= 0.5; bool dimLarge = gemmSize[0] > 128 && gemmSize[1] > 128 && gemmSize[2] > 128; - if(totalLarge && dimLarge) { - if(mwg * nwg < 128 * 64) { - return false; - } - if(mdimc * ndimc < 16 * 8) { - return false; - } - if(vwm * vwn < 4 * 4) { - return false; - } - } else { - if(mwg * nwg > 128 * 64) { - return false; - } - if(mdimc * ndimc > 16 * 8) { - return false; - } - if(vwm * vwn > 4 * 4) { - return false; + if(gemmSize[4] == 1) { + if(totalLarge && dimLarge) { + if(mwg * nwg < 128 * 64) { + return false; + } + if(mdimc * ndimc < 16 * 8) { + return false; + } + if(vwm * vwn < 4 * 4) { + return false; + } + } else { + if(mwg * nwg > 128 * 64) { + return false; + } + if(mdimc * ndimc > 16 * 8) { + return false; + } + if(vwm * vwn > 4 * 4) { + return false; + } } } @@ -118,14 +130,13 @@ static bool isCandidateValid(uint32_t kwg, uint32_t kwi, uint32_t mwg, uint32_t std::vector getGemmParams(const std::vector &gemmSize, const std::vector tensorMemory, OpenCLRuntime *runtime) { - MNN_ASSERT(gemmSize.size() == 5); // M, N, K, Layout, B + MNN_ASSERT(gemmSize.size() == 6); // M, N, K, Layout, Batch, Bias MNN_ASSERT(gemmSize[0] % 16 == 0); MNN_ASSERT(gemmSize[1] % 16 == 0); - MNN_ASSERT(gemmSize[2] % 16 == 0); + MNN_ASSERT(gemmSize[2] % 4 == 0); - MNN_ASSERT(tensorMemory.size() == 3); + MNN_ASSERT((gemmSize[5] == 0 && tensorMemory.size() == 3) || (gemmSize[5] == 1 && 
tensorMemory.size() == 4)); auto& tunedGemmParams = runtime->tunedGemmParamsMap(); - std::vector info(gemmSize); uint32_t isFp16 = runtime->isSupportedFP16(); @@ -153,123 +164,111 @@ std::vector getGemmParams(const std::vector &gemmSize, const if(gemmSize[0] >= 256 && gemmSize[1] >= 256 && gemmSize[2] >= 256) { if(multiNum > 8.0) { if(maxDivsorM >= 128 && maxDivsorN >= 64) { - return {0, 1, 16, 2, 16, 16, 128, 8, 8, 64, 0, 0, 0, 1, 8, 8}; + return {16, 2, 16, 16, 128, 8, 8, 64, 0, 0, 0, 1, 8, 8}; } } if(maxDivsorM >= 64 && maxDivsorN >= 64) { - return {0, 1, 16, 2, 8, 8, 64, 8, 8, 64, 0, 0, 0, 1, 8, 8}; + return {16, 2, 8, 8, 64, 8, 8, 64, 0, 0, 0, 1, 8, 8}; } } } else {// BatchGemm if(maxDivsorM >= 64 && maxDivsorN >= 128) { - return {0, 1, 16, 2, 16, 16, 64, 8, 8, 128, 0, 0, 1, 0, 2, 8}; + return {16, 2, 16, 16, 64, 8, 8, 128, 0, 0, 1, 0, 2, 8}; } else if(maxDivsorM >= 64 && maxDivsorN >= 64) { - return {0, 1, 16, 2, 8, 8, 64, 8, 8, 64, 0, 0, 1, 0, 4, 4}; + return {16, 2, 8, 8, 64, 8, 8, 64, 0, 0, 1, 0, 4, 4}; } } - return {0, 1, 16, 2, 4, 4, 16, 4, 4, 16, 0, 0, 1, 0, 2, 2}; + return {16, 2, 4, 4, 16, 4, 4, 16, 0, 0, 1, 0, 2, 2}; } std::vector> totalCombinations; // save total candidate combinations - std::vector params_prefer = {0, 1, 16, 2, 4, 4, 16, 4, 4, 16, 0, 0, 1, 0, 2, 2}; + std::vector params_prefer = {16, 2, 4, 4, 16, 4, 4, 16, 0, 0, 1, 0, 2, 2}; totalCombinations.emplace_back(params_prefer); uint32_t min_cost = UINT_MAX; - if(runtime->getCLTuneLevel() >= Normal) { - // set candidates - totalCombinations.push_back({0, 1, 32, 2, 16, 16, 64 , 8 , 8 , 128, 0, 0, 1, 1, 2, 8});//12 - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 0, 0, 8, 8});//11 - totalCombinations.push_back({0, 1, 32, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 0, 8, 8});//4 - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 1, 2, 8});//2 - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 128, 8 , 8 , 128, 0, 0, 0, 1, 8, 8}); - totalCombinations.push_back({0, 1, 16, 2, 8 , 8 , 16 , 8 , 8 , 128, 0, 0, 0, 0, 2, 8}); - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 64 , 8 , 8 , 32 , 0, 0, 0, 0, 4, 4}); - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 128, 8 , 8 , 32 , 0, 0, 1, 0, 4, 4});//1 - totalCombinations.push_back({0, 1, 16, 2, 8, 8 , 32 , 8 , 8 , 128, 0, 0, 1, 0, 2, 8});//2 - - if(runtime->getCLTuneLevel() == Normal) { - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 64 , 8 , 8 , 128, 0, 0, 1, 1, 2, 8});//10 - totalCombinations.push_back({0, 1, 32, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 0, 0, 8, 8});//6 - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 0, 1, 8, 8});//6 - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 0, 2, 8});//4 - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 1, 8, 8});//4 - totalCombinations.push_back({0, 1, 32, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 0, 1, 8, 8});//4 - totalCombinations.push_back({0, 1, 32, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 0, 2, 8});//3 - totalCombinations.push_back({0, 1, 16, 2, 8, 8 , 64 , 8 , 8 , 64 , 0, 0, 1, 0, 2, 8});//1 - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 1, 4, 4});//1 - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 0, 8, 8});//2 - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 64 , 8 , 8 , 128, 0, 0, 1, 0, 2, 8});//3 - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 128, 8 , 8 , 32 , 0, 0, 1, 1, 4, 4});//1 - totalCombinations.push_back({0, 1, 32, 
2, 16, 16, 128, 8 , 8 , 32 , 0, 0, 1, 1, 4, 4});//1 - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 128, 16, 16, 128, 0, 0, 0, 0, 8, 8});//1 - totalCombinations.push_back({0, 1, 32, 2, 16, 16, 128, 16, 16, 128, 0, 0, 0, 0, 8, 8});//2 - totalCombinations.push_back({0, 1, 32, 2, 16, 16, 128, 16, 16, 128, 0, 0, 0, 1, 8, 8});//2 - totalCombinations.push_back({0, 1, 32, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 1, 8, 8});//2 - totalCombinations.push_back({0, 1, 32, 2, 16, 16, 128, 16, 16, 128, 0, 0, 1, 0, 8, 8});//1 - totalCombinations.push_back({0, 1, 32, 2, 8 , 8 , 16 , 8 , 8 , 128, 0, 0, 1, 0, 2, 8});//1 - totalCombinations.push_back({0, 1, 32, 2, 8 , 8 , 16 , 8 , 8 , 128, 0, 0, 1, 1, 2, 8});//1 - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 64 , 8 , 8 , 32 , 0, 0, 0, 1, 4, 4});//1 - totalCombinations.push_back({0, 1, 32, 2, 16, 16, 64 , 8 , 8 , 32 , 0, 0, 0, 1, 4, 4}); - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 64 , 8 , 8 , 32 , 0, 0, 1, 0, 4, 4}); - totalCombinations.push_back({0, 1, 16, 2, 8 , 8 , 16 , 8 , 8 , 128, 0, 0, 1, 0, 2, 8}); - totalCombinations.push_back({0, 1, 32, 2, 8 , 8 , 16 , 8 , 8 , 128, 0, 0, 0, 0, 2, 8}); - totalCombinations.push_back({0, 1, 32, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 1, 2, 8}); - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 0, 4, 8}); - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 128, 8 , 8 , 128, 0, 0, 0, 0, 8, 8}); - totalCombinations.push_back({0, 1, 32, 2, 16, 16, 128, 8 , 8 , 128, 0, 0, 0, 1, 8, 8}); - totalCombinations.push_back({0, 1, 32, 2, 16, 16, 128, 8 , 8 , 128, 0, 0, 0, 0, 8, 8}); - totalCombinations.push_back({0, 1, 32, 2, 16, 16, 128, 8 , 8 , 128, 0, 0, 1, 1, 8, 8}); - totalCombinations.push_back({0, 1, 16, 2, 16, 16, 128, 16, 16, 128, 0, 0, 1, 0, 8, 8}); + if(runtime->getCLTuneLevel() >= Wide) { + // set candidates= + totalCombinations.push_back({16, 2, 16, 16, 64 , 8 , 8 , 128, 0, 0, 0, 0, 4, 8});//12 + totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 0, 0, 8, 8});//11 .. 
+ totalCombinations.push_back({16, 2, 16, 16, 128, 16, 16, 128, 0, 0, 0, 0, 8, 8});//1 + totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 32 , 0, 0, 0, 1, 8, 4});//1 + totalCombinations.push_back({16, 2, 8 , 8 , 16 , 8 , 8 , 64, 0, 0, 0, 0, 2, 8}); + totalCombinations.push_back({16, 2, 16, 16, 64 , 8 , 8 , 128, 0, 0, 0, 1, 4, 8});//10 + + totalCombinations.push_back({16, 2, 16, 16, 64 , 8 , 8 , 32 , 0, 0, 0, 0, 4, 4}); + totalCombinations.push_back({16, 2, 8, 8 , 32 , 8 , 8 , 128, 0, 0, 1, 0, 2, 8});//2 + totalCombinations.push_back({16, 2, 16, 16, 64 , 8 , 8 , 128, 0, 0, 1, 1, 2, 8});//12 + totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 1, 2, 8});//2 + totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 128, 0, 0, 0, 0, 8, 8}); + totalCombinations.push_back({16, 2, 8 , 8 , 16 , 8 , 8 , 128, 0, 0, 0, 0, 2, 8}); + + if(runtime->getCLTuneLevel() < Fast) { + totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 0, 8, 8});//4 + totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 0, 1, 8, 8});//6 + totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 1, 8, 8});//4 + + totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 0, 2, 8});//3 + totalCombinations.push_back({16, 2, 8, 8 , 64 , 8 , 8 , 64 , 0, 0, 1, 0, 2, 8});//1 + totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 1, 4, 4});//1 + totalCombinations.push_back({16, 2, 16, 16, 64 , 8 , 8 , 128, 0, 0, 1, 0, 2, 8});//3 + + totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 32 , 0, 0, 0, 0, 4, 4});//1 + totalCombinations.push_back({16, 2, 16, 16, 128, 16, 16, 128, 0, 0, 0, 1, 8, 8});//2 + totalCombinations.push_back({16, 2, 16, 16, 128, 16, 16, 128, 0, 0, 1, 0, 8, 8});//1 + totalCombinations.push_back({16, 2, 8 , 8 , 16 , 8 , 8 , 128, 0, 0, 1, 0, 2, 8});//1 + totalCombinations.push_back({16, 2, 8 , 8 , 16 , 8 , 8 , 128, 0, 0, 1, 1, 2, 8});//1 + + totalCombinations.push_back({16, 2, 16, 16, 64 , 8 , 8 , 32 , 0, 0, 0, 1, 4, 4});//1 + totalCombinations.push_back({16, 2, 16, 16, 64 , 8 , 8 , 32 , 0, 0, 1, 0, 4, 4}); + totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 64 , 0, 0, 1, 0, 4, 8}); + totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 128, 0, 0, 0, 1, 8, 8}); + totalCombinations.push_back({16, 2, 16, 16, 128, 8 , 8 , 128, 0, 0, 1, 1, 8, 8}); + + totalCombinations.push_back({16, 2, 8, 8, 32, 8, 8, 32, 0, 0, 1, 0, 2, 4}); + totalCombinations.push_back({16, 2, 8, 8, 16, 8, 8, 32, 0, 0, 1, 1, 2, 4}); } } else { // get all combinations std::vector> candidates = { - {0}, // GEMMK - {16, 32, 64, 128}, // MWG - {16, 32, 64, 128}, // NWG {16, 32}, // KWG - {8, 16}, // MDIMC - {8, 16}, // NDIMC + {2}, // KWI {8, 16}, // MDIMA + {8, 16}, // MDIMC + {16, 32, 64, 128}, // MWG {8, 16}, // NDIMB - {2}, // KWI - {2, 4, 8}, // VWM - {2, 4, 8}, // VWN - {0, 1}, // STRM - {0, 1}, // STRN + {8, 16}, // NDIMC + {16, 32, 64, 128}, // NWG {0}, // SA {0}, // SB - {1} // KREG + {0, 1}, // STRM + {0, 1}, // STRN + {2, 4, 8}, // VWM + {2, 4, 8} // VWN }; std::vector currentCombination(candidates.size()); generateCombinations(candidates, currentCombination, totalCombinations, 0); } - for(int i = 0; i < totalCombinations.size(); i++) { - uint32_t gemmk = totalCombinations[i][0]; - uint32_t kreg = totalCombinations[i][1]; - uint32_t kwg = totalCombinations[i][2]; - uint32_t kwi = totalCombinations[i][3]; - uint32_t mdima = totalCombinations[i][4]; - uint32_t mdimc = totalCombinations[i][5]; - uint32_t mwg = 
totalCombinations[i][6]; - uint32_t ndimb = totalCombinations[i][7]; - uint32_t ndimc = totalCombinations[i][8]; - uint32_t nwg = totalCombinations[i][9]; - uint32_t sa = totalCombinations[i][10]; - uint32_t sb = totalCombinations[i][11]; - uint32_t strm = totalCombinations[i][12]; - uint32_t strn = totalCombinations[i][13]; - uint32_t vwm = totalCombinations[i][14]; - uint32_t vwn = totalCombinations[i][15]; + uint32_t kwg = totalCombinations[i][0]; + uint32_t kwi = totalCombinations[i][1]; + uint32_t mdima = totalCombinations[i][2]; + uint32_t mdimc = totalCombinations[i][3]; + uint32_t mwg = totalCombinations[i][4]; + uint32_t ndimb = totalCombinations[i][5]; + uint32_t ndimc = totalCombinations[i][6]; + uint32_t nwg = totalCombinations[i][7]; + uint32_t sa = totalCombinations[i][8]; + uint32_t sb = totalCombinations[i][9]; + uint32_t strm = totalCombinations[i][10]; + uint32_t strn = totalCombinations[i][11]; + uint32_t vwm = totalCombinations[i][12]; + uint32_t vwn = totalCombinations[i][13]; if(isCandidateValid(kwg, kwi, mwg, mdimc, vwm, nwg, ndimc, vwn, mdima, ndimb, sa, sb, runtime, gemmSize)) { std::set buildOptions; buildOptions.clear(); - buildOptions.emplace("-DGEMMK=" + std::to_string(gemmk)); - buildOptions.emplace("-DKREG=" + std::to_string(kreg)); buildOptions.emplace("-DKWG=" + std::to_string(kwg)); buildOptions.emplace("-DKWI=" + std::to_string(kwi)); buildOptions.emplace("-DMDIMA=" + std::to_string(mdima)); @@ -293,12 +292,10 @@ std::vector getGemmParams(const std::vector &gemmSize, const buildOptions.emplace(" -DRELAX_WORKGROUP_SIZE=1"); } - if(runtime->isSupportedFP16()){ - buildOptions.emplace(" -DPRECISION=16"); - } else { - buildOptions.emplace(" -DPRECISION=32"); + if(gemmSize[5] == 1) { + buildOptions.emplace(" -DBIAS"); } - + int localM = mdimc; int localN = ndimc; @@ -328,9 +325,6 @@ std::vector getGemmParams(const std::vector &gemmSize, const float beta = 0.0f; // A: [n, l, e] // B: [n, l, h] - int offset_a = 0; - int offset_b = 0; - int offset_c = 0; cl::Event event; int idx = 0; @@ -341,15 +335,20 @@ std::vector getGemmParams(const std::vector &gemmSize, const ret |= kernel->get().setArg(idx++, alpha); ret |= kernel->get().setArg(idx++, beta); if(gemmSize[4] > 1) { + int batch_offset_a = gemmSize[0] * gemmSize[2]; + int batch_offset_b = gemmSize[1] * gemmSize[2]; + int batch_offset_c = gemmSize[0] * gemmSize[1]; + ret |= kernel->get().setArg(idx++, tensorMemory[0]); - ret |= kernel->get().setArg(idx++, gemmSize[0]); - ret |= kernel->get().setArg(idx++, gemmSize[2]); + ret |= kernel->get().setArg(idx++, batch_offset_a); ret |= kernel->get().setArg(idx++, tensorMemory[1]); - ret |= kernel->get().setArg(idx++, gemmSize[1]); - ret |= kernel->get().setArg(idx++, gemmSize[2]); + ret |= kernel->get().setArg(idx++, batch_offset_b); + if(gemmSize[5] == 1) { + ret |= kernel->get().setArg(idx++, tensorMemory[3]); + ret |= kernel->get().setArg(idx++, gemmSize[1]); + } ret |= kernel->get().setArg(idx++, tensorMemory[2]); - ret |= kernel->get().setArg(idx++, gemmSize[0]); - ret |= kernel->get().setArg(idx++, gemmSize[1]); + ret |= kernel->get().setArg(idx++, batch_offset_c); MNN_CHECK_CL_SUCCESS(ret, "setArg getGemmParams XgemmBatchhed Kernel"); @@ -360,8 +359,15 @@ std::vector getGemmParams(const std::vector &gemmSize, const continue; } } else { + int offset_a = 0; + int offset_b = 0; + int offset_c = 0; + ret |= kernel->get().setArg(idx++, tensorMemory[0]); ret |= kernel->get().setArg(idx++, tensorMemory[1]); + if(gemmSize[5] == 1) { + ret |= kernel->get().setArg(idx++, 
tensorMemory[3]); + } ret |= kernel->get().setArg(idx++, tensorMemory[2]); ret |= kernel->get().setArg(idx++, offset_a); ret |= kernel->get().setArg(idx++, offset_b); @@ -381,27 +387,26 @@ std::vector getGemmParams(const std::vector &gemmSize, const int cost_time = (int)runtime->getCostTime(&event); if(cost_time < min_cost) { min_cost = cost_time; - params_prefer[0] = gemmk; - params_prefer[1] = kreg; - params_prefer[2] = kwg; - params_prefer[3] = kwi; - params_prefer[4] = mdima; - params_prefer[5] = mdimc; - params_prefer[6] = mwg; - params_prefer[7] = ndimb; - params_prefer[8] = ndimc; - params_prefer[9] = nwg; - params_prefer[10] = sa; - params_prefer[11] = sb; - params_prefer[12] = strm; - params_prefer[13] = strn; - params_prefer[14] = vwm; - params_prefer[15] = vwn; - -// for(auto &iter : params_prefer) { -// MNN_PRINT("%d ", iter); -// } -// MNN_PRINT(": %d us, shape:%d %d %d batch:%d, flops:%f GFLOPS\n", min_cost, gemmSize[0], gemmSize[1], gemmSize[2], gemmSize[4], 2.0 / 1000.0 * gemmSize[0] * gemmSize[1] * gemmSize[2] * gemmSize[4] / min_cost); + params_prefer[0] = kwg; + params_prefer[1] = kwi; + params_prefer[2] = mdima; + params_prefer[3] = mdimc; + params_prefer[4] = mwg; + params_prefer[5] = ndimb; + params_prefer[6] = ndimc; + params_prefer[7] = nwg; + params_prefer[8] = sa; + params_prefer[9] = sb; + params_prefer[10] = strm; + params_prefer[11] = strn; + params_prefer[12] = vwm; + params_prefer[13] = vwn; + #ifdef TIME_TUNE_LOG + for(auto &iter : params_prefer) { + MNN_PRINT("%d ", iter); + } + MNN_PRINT(": %d us, shape:%d %d %d batch:%d, layout:%d bias:%d, flops:%f GFLOPS\n", min_cost, gemmSize[0], gemmSize[1], gemmSize[2], gemmSize[4], gemmSize[3], gemmSize[5], 2.0 / 1000.0 * gemmSize[0] * gemmSize[1] * gemmSize[2] * gemmSize[4] / min_cost); + #endif } } } diff --git a/source/backend/opencl/core/OpenCLOPRegister.cpp b/source/backend/opencl/core/OpenCLOPRegister.cpp index 9f244d651..3b0eeb5f6 100644 --- a/source/backend/opencl/core/OpenCLOPRegister.cpp +++ b/source/backend/opencl/core/OpenCLOPRegister.cpp @@ -2,6 +2,7 @@ #ifndef MNN_OPENCL_SEP_BUILD namespace MNN { namespace OpenCL { +#ifndef MNN_OPENCL_BUFFER_CLOSED extern void ___OpenCLInterp3DBufCreator__OpType_Interp3D__BUFFER__(); extern void ___OpenCLReductionBufCreator__OpType_Reduction__BUFFER__(); extern void ___OpenCLArgMaxBufCreator__OpType_ArgMax__BUFFER__(); @@ -29,6 +30,7 @@ extern void ___OpenCLUnaryBufCreator__OpType_Sigmoid__BUFFER__(); extern void ___OpenCLUnaryBufCreator__OpType_TanH__BUFFER__(); extern void ___OpenCLGridSampleBufCreator__OpType_GridSample__BUFFER__(); extern void ___OpenCLScaleBufCreator__OpType_Scale__BUFFER__(); +#endif extern void ___OpenCLDepthwiseConvolutionCreator__OpType_ConvolutionDepthwise__IMAGE__(); extern void ___OpenCLMatMulCreator__OpType_MatMul__IMAGE__(); extern void ___OpenCLUnaryCreator__OpType_UnaryOp__IMAGE__(); @@ -60,12 +62,13 @@ extern void ___OpenCLInterpCreator__OpType_Interp__IMAGE__(); extern void ___OpenCLGridSampleCreator__OpType_GridSample__IMAGE__(); #ifdef MNN_SUPPORT_TRANSFORMER_FUSE -extern void ___OpenCLAttentionBufCreator__OpType_Attention__BUFFER__(); extern void ___OpenCLSelfAttentionBufCreator__OpType_FmhaV2__BUFFER__(); -extern void ___OpenCLGroupNormBufCreator__OpType_GroupNorm__BUFFER__(); extern void ___OpenCLSplitGeluBufCreator__OpType_SplitGeLU__BUFFER__(); +extern void ___OpenCLGroupNormBufCreator__OpType_GroupNorm__BUFFER__(); +extern void ___OpenCLAttentionBufCreator__OpType_Attention__BUFFER__(); #endif void registerOpenCLOps() { 
+#ifndef MNN_OPENCL_BUFFER_CLOSED ___OpenCLInterp3DBufCreator__OpType_Interp3D__BUFFER__(); ___OpenCLReductionBufCreator__OpType_Reduction__BUFFER__(); ___OpenCLArgMaxBufCreator__OpType_ArgMax__BUFFER__(); @@ -93,6 +96,7 @@ ___OpenCLUnaryBufCreator__OpType_Sigmoid__BUFFER__(); ___OpenCLUnaryBufCreator__OpType_TanH__BUFFER__(); ___OpenCLGridSampleBufCreator__OpType_GridSample__BUFFER__(); ___OpenCLScaleBufCreator__OpType_Scale__BUFFER__(); +#endif ___OpenCLDepthwiseConvolutionCreator__OpType_ConvolutionDepthwise__IMAGE__(); ___OpenCLMatMulCreator__OpType_MatMul__IMAGE__(); ___OpenCLUnaryCreator__OpType_UnaryOp__IMAGE__(); @@ -122,11 +126,12 @@ ___OpenCLInterp3DCreator__OpType_Interp3D__IMAGE__(); ___OpenCLCastCreator__OpType_Cast__IMAGE__(); ___OpenCLInterpCreator__OpType_Interp__IMAGE__(); ___OpenCLGridSampleCreator__OpType_GridSample__IMAGE__(); + #ifdef MNN_SUPPORT_TRANSFORMER_FUSE -___OpenCLAttentionBufCreator__OpType_Attention__BUFFER__(); ___OpenCLSelfAttentionBufCreator__OpType_FmhaV2__BUFFER__(); -___OpenCLGroupNormBufCreator__OpType_GroupNorm__BUFFER__(); ___OpenCLSplitGeluBufCreator__OpType_SplitGeLU__BUFFER__(); +___OpenCLGroupNormBufCreator__OpType_GroupNorm__BUFFER__(); +___OpenCLAttentionBufCreator__OpType_Attention__BUFFER__(); #endif } } diff --git a/source/backend/opencl/core/OpenCLRunningUtils.cpp b/source/backend/opencl/core/OpenCLRunningUtils.cpp index 2ad64aa69..3898224d3 100644 --- a/source/backend/opencl/core/OpenCLRunningUtils.cpp +++ b/source/backend/opencl/core/OpenCLRunningUtils.cpp @@ -194,9 +194,9 @@ std::pair, uint32_t> localWS3DDefault(const std::vectorgetCLTuneLevel() == Fast) { while(lws[2] <= gws[2] && lws[2] <= 8) { lws[1] = 1; - while(lws[1] <= gws[1] && lws[1] <= 8) { + while(lws[1] <= gws[1] && lws[1] <= 16) { lws[0] = 1; - while(lws[0] <= gws[0] && lws[0] <= 8) { + while(lws[0] <= gws[0] && lws[0] <= 16) { if(lws[0] <= maxWorkItemSizes[0] && lws[1] <= maxWorkItemSizes[1] && lws[2] <= maxWorkItemSizes[2] && lws[0]*lws[1]*lws[2] <= std::min(maxWorkGroupSize, static_cast(64)) && lws[0]*lws[1]*lws[2] >= 16) { cl::Event event; std::vector internalGlobalWS(3, 1); diff --git a/source/backend/opencl/core/runtime/OpenCLRuntime.cpp b/source/backend/opencl/core/runtime/OpenCLRuntime.cpp index 1dde1dcbc..7d6c4e5de 100644 --- a/source/backend/opencl/core/runtime/OpenCLRuntime.cpp +++ b/source/backend/opencl/core/runtime/OpenCLRuntime.cpp @@ -951,12 +951,12 @@ bool OpenCLRuntime::setCache(std::pair cache) { MNN_ERROR("Error tunning gemm info\n"); return false; } - MNN_ASSERT(tun->gemmSize()->size() == 6); + MNN_ASSERT(tun->gemmSize()->size() == 7); std::vector info(tun->gemmSize()->size()); for (int v=0; vgemmSize()->data()[v]; } - MNN_ASSERT(tun->paramInfo()->size() == 16); + MNN_ASSERT(tun->paramInfo()->size() == 14); std::vector params(tun->paramInfo()->size()); for (int v=0; vparamInfo()->data()[v]; diff --git a/source/backend/opencl/execution/buffer/AttentionBufExecution.cpp b/source/backend/opencl/execution/buffer/AttentionBufExecution.cpp index 2382a7081..2ca359ecf 100644 --- a/source/backend/opencl/execution/buffer/AttentionBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/AttentionBufExecution.cpp @@ -6,7 +6,6 @@ // Copyright © 2018, Alibaba Group Holding Limited // -#ifndef MNN_OPENCL_BUFFER_CLOSED #ifdef MNN_SUPPORT_TRANSFORMER_FUSE #include "backend/opencl/execution/buffer/AttentionBufExecution.hpp" @@ -421,10 +420,8 @@ class AttentionBufCreator : public OpenCLBackend::Creator { return new AttentionBufExecution(op, backend, 
param->kv_cache()); } }; -REGISTER_OPENCL_OP_CREATOR(AttentionBufCreator, OpType_Attention, BUFFER); +REGISTER_OPENCL_OP_CREATOR_TRANSFORMER(AttentionBufCreator, OpType_Attention, BUFFER); } // namespace OpenCL } // namespace MNN #endif/* MNN_SUPPORT_TRANSFORMER_FUSE */ -#endif/* MNN_OPENCL_BUFFER_CLOSED */ - diff --git a/source/backend/opencl/execution/buffer/AttentionBufExecution.hpp b/source/backend/opencl/execution/buffer/AttentionBufExecution.hpp index 9de3796d9..cb33dc05d 100644 --- a/source/backend/opencl/execution/buffer/AttentionBufExecution.hpp +++ b/source/backend/opencl/execution/buffer/AttentionBufExecution.hpp @@ -6,7 +6,6 @@ // Copyright © 2018, Alibaba Group Holding Limited // -#ifndef MNN_OPENCL_BUFFER_CLOSED #ifdef MNN_SUPPORT_TRANSFORMER_FUSE #ifndef AttentionBufExecution_hpp @@ -83,4 +82,3 @@ class AttentionBufExecution : public CommonExecution { } // namespace MNN #endif /* AttentionBufExecution_hpp */ #endif/* MNN_SUPPORT_TRANSFORMER_FUSE */ -#endif /* MNN_OPENCL_BUFFER_CLOSED */ diff --git a/source/backend/opencl/execution/buffer/ConvBufExecution.cpp b/source/backend/opencl/execution/buffer/ConvBufExecution.cpp index d0fdc0cda..8eb739f28 100644 --- a/source/backend/opencl/execution/buffer/ConvBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/ConvBufExecution.cpp @@ -25,7 +25,7 @@ ConvBufCommonExecution::ConvBufCommonExecution(Backend *backend) { ConvBufCommonExecution::ConvBufCommonExecution(const Convolution2D *conv2dParams, Backend *backend) { auto openclBackend = (OpenCLBackend *)backend; int biasSize = conv2dParams->common()->outputCount(); - int buffer_size = ROUND_UP(biasSize, 16);//pack to 16 + int buffer_size = ROUND_UP(biasSize, 32);//pack to packN if(openclBackend->getOpenCLRuntime()->isSupportedFP16()) { buffer_size *= sizeof(half_float::half); } else { @@ -33,7 +33,7 @@ ConvBufCommonExecution::ConvBufCommonExecution(const Convolution2D *conv2dParams } mResource.reset(new ConvBufResource); - mResource->mBias.reset(Tensor::createDevice({1, 1, 1, ROUND_UP(biasSize, 16)})); + mResource->mBias.reset(Tensor::createDevice({1, 1, 1, ROUND_UP(biasSize, 32)})); backend->onAcquireBuffer(mResource->mBias.get(), Backend::STATIC); cl::Buffer &biasBuffer = openCLBuffer(mResource->mBias.get()); @@ -62,38 +62,6 @@ ConvBufCommonExecution::~ConvBufCommonExecution() { // Do nothing } -void ConvBufExecution::setConv1x1WeightBuffer(int packCout, int packCin, const float* filterDataPtr) { - cl_int res; - std::shared_ptr filterBuffer(Tensor::createDevice({ROUND_UP(mResource->mOutputChannel, 8)/*Cout pack set to max 8*/, ROUND_UP(mResource->mInputChannel, packCin), mResource->mKernelWidth, mResource->mKernelHeight})); - - int buffer_size = filterBuffer->elementSize(); - if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()) { - buffer_size *= sizeof(half_float::half); - } else { - buffer_size *= sizeof(float); - } - mResource->mKernelBuffer.reset(new cl::Buffer(mOpenCLBackend->getOpenCLRuntime()->context(), CL_MEM_READ_WRITE | CL_MEM_ALLOC_HOST_PTR, buffer_size)); - auto kernelBufferPtr = mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueMapBuffer(*(mResource->mKernelBuffer.get()), true, CL_MAP_WRITE, 0, buffer_size, nullptr, nullptr, &res); - if(kernelBufferPtr != nullptr && res == CL_SUCCESS){ - ::memset(kernelBufferPtr, 0, buffer_size); - for(int o = 0; o < mResource->mOutputChannel; o++){ - for(int i = 0 ; i < mResource->mInputChannel; i++){ - int bufferIdx = (o/packCout) * ROUND_UP(mResource->mInputChannel, packCin)*packCout + 
(i/packCin)*packCin*packCout + (o%packCout)*packCin + (i%packCin);//(Co/packCout, Ci/packCin, packCout, packCin) - int filterIdx = o*mResource->mInputChannel + i; - if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()){ - ((half_float::half*)kernelBufferPtr)[bufferIdx] = (half_float::half)(filterDataPtr[filterIdx]); - }else{ - ((float*)kernelBufferPtr)[bufferIdx] = (float)(filterDataPtr[filterIdx]); - } - } - } - }else{ - MNN_ERROR("Map error ptrCL == nullptr \n"); - MNN_ASSERT(false); - } - mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(*(mResource->mKernelBuffer.get()), kernelBufferPtr); -} - void ConvBufExecution::_generateFilterConvertRegion(Tensor* virtualFilter, Tensor* originBuffer) const { auto filterDes = TensorUtils::getDescribe(virtualFilter); filterDes->regions.clear(); @@ -152,29 +120,19 @@ ConvBufExecution::ConvBufExecution(const std::vector &inputs, const st //select opt conv method bool isConv1x1 = (mResource->mKernelHeight == mResource->mKernelWidth && mResource->mKernelHeight == 1 && mPaddings[0] == 0 && mPaddings[1] == 0 && mResource->mStrides[0] == 1 && mResource->mStrides[1] == 1); - bool useConvGemm = isConv1x1 && inputs[0]->width() == 1 && inputs[0]->height() == 1; + + mResource->mConv1x1Opt = isConv1x1; + mResource->mConv1x1C8Opt = mResource->mConv1x1Opt && mResource->mOutputChannel >= 16; + bool useConvGemm = isConv1x1 && mResource->mInputChannel > 32 && mResource->mOutputChannel > 64; if (useConvGemm) { - // Enough computation - bool isTotalLarge = (inputs[0]->batch() * 1.0 / 512 * mResource->mInputChannel / 512 * mResource->mOutputChannel / 512 > 1.0); - bool isEachDimLarge = (inputs[0]->batch() > 256 && mResource->mInputChannel > 128 && mResource->mOutputChannel > 256); - if(isTotalLarge && isEachDimLarge) { - mResource->mConvGemmOptLevel = 2; - } else if(isTotalLarge && inputs[0]->batch() % 64 == 0 && mResource->mInputChannel % 8 == 0 && mResource->mOutputChannel % 64 == 0) { - mResource->mConvGemmOptLevel = 1; - } + mResource->mConvGemmOptLevel = 2; } - mResource->mConv1x1Opt = isConv1x1 && inputs[0]->width() >= 4; } - if (mResource->mConvGemmOptLevel > 0) { + if (mResource->mConv1x1Opt) { // Tile Match with mConvGemmOptLevel == 2 - int tileK = 32; + int tileK = 4; int tileN = 32; - if(mResource->mConvGemmOptLevel == 1) { - tileK = 8; - tileN = 64; - } - int buffer_size = ROUND_UP(mResource->mOutputChannel, tileN) * ROUND_UP(mResource->mInputChannel, tileK); mResource->mFilter.reset( Tensor::createDevice({buffer_size})); @@ -211,14 +169,6 @@ ConvBufExecution::ConvBufExecution(const std::vector &inputs, const st } mOpenCLBackend->getOpenCLRuntime()->commandQueue().enqueueUnmapMemObject(filterBuffer, ptrCL); - } else if (mResource->mConv1x1Opt) { - //At first, set packCout equal to 4 - if(mResource->mOutputChannel >= 16){ - setConv1x1WeightBuffer(8, 4, mFilterDataPtr); - mResource->mConv1x1C8Opt = true; - }else{ - setConv1x1WeightBuffer(4, 4, mFilterDataPtr); - } } else { mResource->mFilter.reset( Tensor::createDevice({ROUND_UP(mResource->mOutputChannel, 4) * ROUND_UP(mResource->mInputChannel, 4) * mResource->mKernelWidth * mResource->mKernelHeight})); @@ -325,31 +275,63 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const std::string info = std::to_string(inputChannels) + "_" + std::to_string(outChannel) + "_" + std::to_string(mResource->mKernelHeight) + "_" + std::to_string(mResource->mKernelWidth) + "_" + std::to_string(mResource->mStrides[0]) + "_" + std::to_string(mResource->mStrides[1]) + "_" + 
std::to_string(mResource->mDilations[0]) + "_" + std::to_string(mResource->mDilations[1]); + if (mResource->mConvGemmOptLevel > 0) { + int area = height * width; + int M = outputShape.at(0) * area; + int N = outputShape.at(3); + int K = inputShape.at(3); + + bool isAlign = (K % 8 == 0 && area == 1 && N % 64 == 0 && M % 64 == 0); + bool isLimitSize = (M * 1.0 / 512 * N / 512 * K / 512 <= 1.0) && (1.0 * M * K / N / N >= 16.0); + if(isAlign && isLimitSize) { + mResource->mConvGemmOptLevel = 1; + } else if(M < 128 || 1.0 * M / 512 * N / 512 * K / 256 < 1.0) { + mResource->mConvGemmOptLevel = 0; + } + } + if (mResource->mConvGemmOptLevel == 2) { // set large tile - int tileM = 32; + int tileM = 16; int tileN = 32; - int tileK = 32; + int tileK = 4; - int M = outputShape.at(0); + int area = height * width; + int M = outputShape.at(0) * area; int N = outputShape.at(3); int K = inputShape.at(3); + int alignM = ROUND_UP(M, tileM); int alignN = ROUND_UP(N, tileN); int alignK = ROUND_UP(K, tileK); // ReArrange input mConvGemmInpTensor.reset(Tensor::createDevice({alignK * alignM})); - mConvGemmOutTensor.reset(Tensor::createDevice({alignN * alignM})); mOpenCLBackend->onAcquireBuffer(mConvGemmInpTensor.get(), Backend::DYNAMIC); - mOpenCLBackend->onAcquireBuffer(mConvGemmOutTensor.get(), Backend::DYNAMIC); - mOpenCLBackend->onReleaseBuffer(mConvGemmOutTensor.get(), Backend::DYNAMIC); + if(N != alignN || M != alignM || area != 1) { + mNeedOutTempTensor = true; + mConvGemmOutTensor.reset(Tensor::createDevice({alignN * alignM})); + mOpenCLBackend->onAcquireBuffer(mConvGemmOutTensor.get(), Backend::DYNAMIC); + } mOpenCLBackend->onReleaseBuffer(mConvGemmInpTensor.get(), Backend::DYNAMIC); + if(mNeedOutTempTensor) { + mOpenCLBackend->onReleaseBuffer(mConvGemmOutTensor.get(), Backend::DYNAMIC); + } + { std::set buildOptions; + + int m_pack = 1; + if(area == 1) { + m_pack = 4; + buildOptions.emplace("-DAREA_EQUAL_1"); + } else if(outputShape.at(0) == 1) { + m_pack = 4; + buildOptions.emplace("-DBATCH_EQUAL_1"); + } mPreKernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_buf", "transpose_pad", buildOptions); uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mPreKernel)); - mPreGlobalWorkSize = {static_cast(alignM/4), static_cast(alignK/4)}; + mPreGlobalWorkSize = {static_cast(alignM/m_pack), static_cast(alignK/4)}; int offset = 0; int idx = 0; @@ -360,6 +342,7 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const ret |= mPreKernel->get().setArg(idx++, static_cast(alignK)); ret |= mPreKernel->get().setArg(idx++, static_cast(M)); ret |= mPreKernel->get().setArg(idx++, static_cast(K)); + ret |= mPreKernel->get().setArg(idx++, static_cast(area)); ret |= mPreKernel->get().setArg(idx++, openCLBuffer(input)); ret |= mPreKernel->get().setArg(idx++, openCLBuffer(mConvGemmInpTensor.get())); MNN_CHECK_CL_SUCCESS(ret, "setArg mConvgemmOptLevel==2 PreKernel"); @@ -371,13 +354,24 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const } std::set buildOptions; + uint32_t hasBias = 0; + if(!mNeedOutTempTensor) { + hasBias = 1; + buildOptions = mResource->mBuildOptions; + buildOptions.emplace("-DBIAS"); + } uint32_t layout = 4; uint32_t batch = 1; - auto param = getGemmParams({(uint32_t)alignM, (uint32_t)alignN, (uint32_t)alignK, layout, batch}, {openCLBuffer(mConvGemmInpTensor.get()), openCLBuffer(mResource->mFilter.get()), openCLBuffer(mConvGemmOutTensor.get())}, mOpenCLBackend->getOpenCLRuntime()); + + cl::Buffer outBuffer = 
mNeedOutTempTensor ? openCLBuffer(mConvGemmOutTensor.get()) : openCLBuffer(output); + std::vector param; + if(mNeedOutTempTensor) { + param = getGemmParams({(uint32_t)alignM, (uint32_t)alignN, (uint32_t)alignK, layout, batch, hasBias}, {openCLBuffer(mConvGemmInpTensor.get()), openCLBuffer(mResource->mFilter.get()), openCLBuffer(mConvGemmOutTensor.get())}, mOpenCLBackend->getOpenCLRuntime()); + } else { + param = getGemmParams({(uint32_t)alignM, (uint32_t)alignN, (uint32_t)alignK, layout, batch, hasBias}, {openCLBuffer(mConvGemmInpTensor.get()), openCLBuffer(mResource->mFilter.get()), openCLBuffer(output), openCLBuffer(mResource->mBias.get())}, mOpenCLBackend->getOpenCLRuntime()); + } - int GEMMK=param[0], KREG=param[1], KWG=param[2], KWI=param[3], MDIMA=param[4], MDIMC=param[5], MWG=param[6], NDIMB=param[7], NDIMC=param[8], NWG=param[9], SA=param[10], SB=param[11], STRM=param[12], STRN=param[13], VWM=param[14], VWN=param[15]; - buildOptions.emplace("-DGEMMK=" + std::to_string(GEMMK)); - buildOptions.emplace("-DKREG=" + std::to_string(KREG)); + int KWG=param[0], KWI=param[1], MDIMA=param[2], MDIMC=param[3], MWG=param[4], NDIMB=param[5], NDIMC=param[6], NWG=param[7], SA=param[8], SB=param[9], STRM=param[10], STRN=param[11], VWM=param[12], VWN=param[13]; buildOptions.emplace("-DKWG=" + std::to_string(KWG)); buildOptions.emplace("-DKWI=" + std::to_string(KWI)); buildOptions.emplace("-DMDIMA=" + std::to_string(MDIMA)); @@ -395,7 +389,7 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const if(layout >= 4) { buildOptions.emplace("-DOUTPUTMN"); } - + tileM = MWG; tileN = NWG; int localM = MDIMC; @@ -405,12 +399,7 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const buildOptions.emplace("-DUSE_CL_MAD=1"); buildOptions.emplace("-DRELAX_WORKGROUP_SIZE=1"); } - if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()){ - buildOptions.emplace(" -DPRECISION=16"); - } else { - buildOptions.emplace(" -DPRECISION=32"); - } - + mKernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("matmul_params_buf", "Xgemm", buildOptions); int out_per_thread_m = tileM / localM; @@ -431,8 +420,12 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const ret |= mKernel->get().setArg(idx++, beta); ret |= mKernel->get().setArg(idx++, openCLBuffer(mConvGemmInpTensor.get())); ret |= mKernel->get().setArg(idx++, openCLBuffer(mResource->mFilter.get())); -// ret |= mKernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); - ret |= mKernel->get().setArg(idx++, openCLBuffer(mConvGemmOutTensor.get())); + if(mNeedOutTempTensor) { + ret |= mKernel->get().setArg(idx++, openCLBuffer(mConvGemmOutTensor.get())); + } else { + ret |= mKernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); + ret |= mKernel->get().setArg(idx++, openCLBuffer(output)); + } ret |= mKernel->get().setArg(idx++, offset); ret |= mKernel->get().setArg(idx++, offset); ret |= mKernel->get().setArg(idx++, offset); @@ -442,9 +435,12 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const mGlobalWorkSize[0] = ROUND_UP(mGlobalWorkSize[0], std::max((uint32_t)1, mLocalWorkSize[0])); mGlobalWorkSize[1] = ROUND_UP(mGlobalWorkSize[1], std::max((uint32_t)1, mLocalWorkSize[1])); - { - std::set buildOptions; - mPostKernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_buf", "add_bias", buildOptions); + if(mNeedOutTempTensor) { + std::set buildOptions = mResource->mBuildOptions; + if(area == 1) { + buildOptions.emplace("-DAREA_EQUAL_1"); + } + mPostKernel = 
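The tuned parameter vector shrinks from 16 to 14 entries now that GEMMK/KREG are gone; each entry just becomes one -D define, plus -DOUTPUTMN for layout >= 4 and -DBIAS when the bias is fused directly into Xgemm. A sketch of that expansion; the helper name is illustrative:

// Illustrative sketch (not part of the patch).
#include <set>
#include <string>
#include <vector>

std::set<std::string> gemmBuildOptions(const std::vector<int>& param, bool fuseBias, bool outputMN) {
    static const char* names[14] = {"KWG", "KWI", "MDIMA", "MDIMC", "MWG", "NDIMB",
                                    "NDIMC", "NWG", "SA", "SB", "STRM", "STRN", "VWM", "VWN"};
    std::set<std::string> opts;
    for (int i = 0; i < 14 && i < (int)param.size(); ++i) {
        opts.insert("-D" + std::string(names[i]) + "=" + std::to_string(param[i]));
    }
    if (outputMN) opts.insert("-DOUTPUTMN");  // layout >= 4 in the hunk above
    if (fuseBias) opts.insert("-DBIAS");      // bias added inside Xgemm, no post kernel needed
    return opts;
}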
mOpenCLBackend->getOpenCLRuntime()->buildKernel("gemm_buf", "transpose_bias", buildOptions); uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mPostKernel)); mPostGlobalWorkSize = {static_cast(M), static_cast(UP_DIV(N, 16))}; @@ -458,12 +454,13 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const ret |= mPostKernel->get().setArg(idx++, static_cast(alignN)); ret |= mPostKernel->get().setArg(idx++, static_cast(M)); ret |= mPostKernel->get().setArg(idx++, static_cast(N)); + ret |= mPostKernel->get().setArg(idx++, static_cast(area)); ret |= mPostKernel->get().setArg(idx++, openCLBuffer(mConvGemmOutTensor.get())); ret |= mPostKernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); ret |= mPostKernel->get().setArg(idx++, openCLBuffer(output)); MNN_CHECK_CL_SUCCESS(ret, "setArg mConvgemmOptLevel==2 PostKernel"); - mPostLocalWorkSize = localWS2DDefault(mPostGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "add_bias", mPostKernel).first; + mPostLocalWorkSize = localWS2DDefault(mPostGlobalWorkSize, maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "transpose_bias", mPostKernel).first; mOpenCLBackend->recordKernel2d(mPostKernel, mPostGlobalWorkSize, mPostLocalWorkSize); mPostGlobalWorkSize[0] = ROUND_UP(mPostGlobalWorkSize[0], std::max((uint32_t)1, mPostLocalWorkSize[0])); mPostGlobalWorkSize[1] = ROUND_UP(mPostGlobalWorkSize[1], std::max((uint32_t)1, mPostLocalWorkSize[1])); @@ -513,6 +510,7 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const MNN_CHECK_CL_SUCCESS(ret, "setArg Conv1x1Buf mConvgemmOptLevel==1 Kernel Select"); } else if (mResource->mConv1x1Opt) { + int tileN = 32; // {"conv_2d_1x1_c4h1w4", "conv_2d_1x1_c4h1w2", "conv_2d_1x1_c4h1w1", "conv_2d_1x1_c8h1w4"}; const int total_kernel = 3; std::string kernelName[total_kernel] = {"conv_2d_1x1_c4h1w4", "conv_2d_1x1_c4h1w2", "conv_2d_1x1_c4h1w1"}; @@ -554,13 +552,15 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const ret |= kernel[knl_idx]->get().setArg(idx++, globalWorkSize[knl_idx][1]); ret |= kernel[knl_idx]->get().setArg(idx++, UP_DIV(width, itemW[knl_idx])); ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(input)); - ret |= kernel[knl_idx]->get().setArg(idx++, *mResource->mKernelBuffer.get()); + ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->mFilter.get())); ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); ret |= kernel[knl_idx]->get().setArg(idx++, openCLBuffer(output)); ret |= kernel[knl_idx]->get().setArg(idx++, static_cast(inputChannelBlocks)); ret |= kernel[knl_idx]->get().setArg(idx++, height); ret |= kernel[knl_idx]->get().setArg(idx++, width); ret |= kernel[knl_idx]->get().setArg(idx++, UP_DIV(outChannel, 4)); + ret |= kernel[knl_idx]->get().setArg(idx++, ROUND_UP(outChannel, tileN)); + MNN_CHECK_CL_SUCCESS(ret, "setArg Conv1x1Buf Kernel Select"); std::pair, int> retTune; @@ -591,13 +591,14 @@ ErrorCode ConvBufExecution::onResize(const std::vector &inputs, const ret |= mKernel->get().setArg(idx++, mGlobalWorkSize[1]); ret |= mKernel->get().setArg(idx++, UP_DIV(width, itemW[min_index])); ret |= mKernel->get().setArg(idx++, openCLBuffer(input)); - ret |= mKernel->get().setArg(idx++, *mResource->mKernelBuffer.get()); + ret |= mKernel->get().setArg(idx++, openCLBuffer(mResource->mFilter.get())); ret |= mKernel->get().setArg(idx++, openCLBuffer(mResource->mBias.get())); ret |= mKernel->get().setArg(idx++, openCLBuffer(output)); ret |= 
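The post kernel above launches {M, UP_DIV(N, 16)} work-items and, like the other kernels touched in this patch, snaps the tuned global size up to a multiple of the local size. The arithmetic, as plain C++ with illustrative helper names:

// Illustrative sketch (not part of the patch).
#include <algorithm>
#include <cstdint>
#include <vector>

static inline uint32_t upDiv(uint32_t a, uint32_t b)   { return (a + b - 1) / b; }
static inline uint32_t roundUp(uint32_t a, uint32_t b) { return upDiv(a, b) * b; }

void snapGlobalToLocal(std::vector<uint32_t>& global, const std::vector<uint32_t>& local) {
    for (size_t i = 0; i < global.size() && i < local.size(); ++i) {
        global[i] = roundUp(global[i], std::max<uint32_t>(1, local[i]));
    }
}
// e.g. transpose_bias starts from {M, upDiv(N, 16)} and both dimensions are then
// snapped to the local size returned by localWS2DDefault.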
mKernel->get().setArg(idx++, static_cast(inputChannelBlocks)); ret |= mKernel->get().setArg(idx++, height); ret |= mKernel->get().setArg(idx++, width); ret |= mKernel->get().setArg(idx++, UP_DIV(outChannel, 4)); + ret |= mKernel->get().setArg(idx++, ROUND_UP(outChannel, tileN)); MNN_CHECK_CL_SUCCESS(ret, "setArg Conv1x1Buf"); //printf("conv1x1 %d, %d %d, %d %d, %d %d\n", min_index, mGlobalWorkSize[0], mGlobalWorkSize[1], mLocalWorkSize[0], mLocalWorkSize[1], outChannel, width); @@ -749,8 +750,9 @@ ErrorCode ConvBufExecution::onExecute(const std::vector &inputs, const std::string kw = std::to_string(mResource->mKernelWidth); std::string total = std::to_string(1.0 / 1000000 * inputs[0]->batch() * inputs[0]->channel() * outputs[0]->channel() * outputs[0]->height() * outputs[0]->width() * mResource->mKernelHeight * mResource->mKernelWidth); if (mResource->mConvGemmOptLevel > 0) { + std::string m = std::to_string(outputs[0]->width() * outputs[0]->height() * inputs[0]->batch()); name += "-gemm"; - name += std::to_string(mResource->mConvGemmOptLevel) + "-m" + b + "n" + co + "k" + ci; + name += std::to_string(mResource->mConvGemmOptLevel) + "-m" + m + "n" + co + "k" + ci; } else if (mResource->mConv1x1Opt) { name += "-conv1x1"; name += "-b" + b + "ci" + ci + "hi" + hi + "wi" + wi + "co" + co; diff --git a/source/backend/opencl/execution/buffer/ConvBufExecution.hpp b/source/backend/opencl/execution/buffer/ConvBufExecution.hpp index a1a523ca5..e5abe2a53 100644 --- a/source/backend/opencl/execution/buffer/ConvBufExecution.hpp +++ b/source/backend/opencl/execution/buffer/ConvBufExecution.hpp @@ -65,7 +65,6 @@ class ConvBufExecution : public ConvBufCommonExecution, public CommonExecution { virtual ErrorCode onExecute(const std::vector &inputs, const std::vector &outputs) override; virtual bool onClone(Backend* bn, const Op* op, Execution** dst) override; - void setConv1x1WeightBuffer(int packCout, int packCin, const float* filterDataPtr); private: void _generateFilterConvertRegion(Tensor *virtualFilter, Tensor *originBuffer) const; @@ -75,6 +74,7 @@ class ConvBufExecution : public ConvBufCommonExecution, public CommonExecution { std::shared_ptr mKernel; std::shared_ptr mConvGemmInpTensor; std::shared_ptr mConvGemmOutTensor; + bool mNeedOutTempTensor = false; std::shared_ptr mPreKernel = nullptr; std::vector mPreGlobalWorkSize{1, 1, 1}; std::vector mPreLocalWorkSize{1, 1, 1, 1}; diff --git a/source/backend/opencl/execution/buffer/ConvBufWinograd.cpp b/source/backend/opencl/execution/buffer/ConvBufWinograd.cpp index 52c0d39c2..17ca12a3e 100644 --- a/source/backend/opencl/execution/buffer/ConvBufWinograd.cpp +++ b/source/backend/opencl/execution/buffer/ConvBufWinograd.cpp @@ -205,7 +205,7 @@ ConvBufWinograd::ConvBufWinograd(const MNN::Op* op, Backend* backend) : CommonEx int kernelSize = kx; int alpha = unit + kernelSize - 1; - int tileK = 16; + int tileK = 4; int tileN = 32; std::shared_ptr tmpFilterTensor; @@ -460,7 +460,7 @@ ErrorCode ConvBufWinograd::onEncode(const std::vector& inputs, const st { int tileM = 16; int tileN = 32; - int tileK = 16; + int tileK = 4; mSource.reset(Tensor::createDevice( std::vector{alpha * alpha * ROUND_UP(input->channel(), tileK) * ROUND_UP(wUnit * hUnit, tileM)})); mDest.reset(Tensor::createDevice( @@ -541,11 +541,9 @@ ErrorCode ConvBufWinograd::onEncode(const std::vector& inputs, const st std::set buildOptions; uint32_t layout = 4; - auto param = getGemmParams({(uint32_t)e_pack, (uint32_t)h_pack, (uint32_t)l_pack, layout, (uint32_t)loop}, {openCLBuffer(mSource.get()), 
openCLBuffer(mResource->mWeight.get()), openCLBuffer(mDest.get())}, mOpenCLBackend->getOpenCLRuntime()); + auto param = getGemmParams({(uint32_t)e_pack, (uint32_t)h_pack, (uint32_t)l_pack, layout, (uint32_t)loop, (uint32_t)0}, {openCLBuffer(mSource.get()), openCLBuffer(mResource->mWeight.get()), openCLBuffer(mDest.get())}, mOpenCLBackend->getOpenCLRuntime()); - int GEMMK=param[0], KREG=param[1], KWG=param[2], KWI=param[3], MDIMA=param[4], MDIMC=param[5], MWG=param[6], NDIMB=param[7], NDIMC=param[8], NWG=param[9], SA=param[10], SB=param[11], STRM=param[12], STRN=param[13], VWM=param[14], VWN=param[15]; - buildOptions.emplace("-DGEMMK=" + std::to_string(GEMMK)); - buildOptions.emplace("-DKREG=" + std::to_string(KREG)); + int KWG=param[0], KWI=param[1], MDIMA=param[2], MDIMC=param[3], MWG=param[4], NDIMB=param[5], NDIMC=param[6], NWG=param[7], SA=param[8], SB=param[9], STRM=param[10], STRN=param[11], VWM=param[12], VWN=param[13]; buildOptions.emplace("-DKWG=" + std::to_string(KWG)); buildOptions.emplace("-DKWI=" + std::to_string(KWI)); buildOptions.emplace("-DMDIMA=" + std::to_string(MDIMA)); @@ -573,12 +571,6 @@ ErrorCode ConvBufWinograd::onEncode(const std::vector& inputs, const st buildOptions.emplace("-DUSE_CL_MAD=1"); buildOptions.emplace("-DRELAX_WORKGROUP_SIZE=1"); } - - if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()){ - buildOptions.emplace(" -DPRECISION=16"); - } else { - buildOptions.emplace(" -DPRECISION=32"); - } mUnits[b * 3 + 1].kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("matmul_params_buf", "XgemmBatched", buildOptions); @@ -590,6 +582,9 @@ ErrorCode ConvBufWinograd::onEncode(const std::vector& inputs, const st float alpha = 1.0f; float beta = 0.0f; + int batch_offset_a = e_pack * l_pack; + int batch_offset_b = h_pack * l_pack; + int batch_offset_c = e_pack * h_pack; int idx = 0; cl_int ret = CL_SUCCESS; @@ -599,14 +594,11 @@ ErrorCode ConvBufWinograd::onEncode(const std::vector& inputs, const st ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, alpha); ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, beta); ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, openCLBuffer(mSource.get())); - ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, e_pack); - ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, l_pack); + ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, batch_offset_a); ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, openCLBuffer(mResource->mWeight.get())); - ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, h_pack); - ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, l_pack); + ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, batch_offset_b); ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, openCLBuffer(mDest.get())); - ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, e_pack); - ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, h_pack); + ret |= mUnits[b * 3 + 1].kernel->get().setArg(idx++, batch_offset_c); MNN_CHECK_CL_SUCCESS(ret, "setArg Winograd batchmatmul Kernel"); mOpenCLBackend->recordKernel3d(mUnits[b * 3 + 1].kernel, mGWS_M[b], mLWS_M[b]); diff --git a/source/backend/opencl/execution/buffer/GroupNormBufExecution.cpp b/source/backend/opencl/execution/buffer/GroupNormBufExecution.cpp index e2aaf0194..03589bc6b 100644 --- a/source/backend/opencl/execution/buffer/GroupNormBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/GroupNormBufExecution.cpp @@ -6,7 +6,6 @@ // Created by MNN on 2024/06/24. 
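XgemmBatched now takes a single element offset per matrix instead of a (rows, cols) pair; with e_pack/h_pack/l_pack being the padded M/N/K, those offsets are just the per-batch matrix sizes. A sketch; the names are illustrative:

// Illustrative sketch (not part of the patch).
struct BatchOffsets { int a, b, c; };

// A is MxK, B is NxK, C is MxN for every batch; all three are stored densely,
// so batch i starts i * (elements per matrix) into the buffer.
BatchOffsets batchOffsets(int m_pack, int n_pack, int k_pack) {
    return { m_pack * k_pack,    // batch_offset_a
             n_pack * k_pack,    // batch_offset_b
             m_pack * n_pack };  // batch_offset_c
}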
// Copyright © 2018, Alibaba Group Holding Limited // -#ifndef MNN_OPENCL_BUFFER_CLOSED #ifdef MNN_SUPPORT_TRANSFORMER_FUSE #include "backend/opencl/execution/buffer/GroupNormBufExecution.hpp" @@ -260,8 +259,7 @@ class GroupNormBufCreator : public OpenCLBackend::Creator { } }; -REGISTER_OPENCL_OP_CREATOR(GroupNormBufCreator, OpType_GroupNorm, BUFFER); +REGISTER_OPENCL_OP_CREATOR_TRANSFORMER(GroupNormBufCreator, OpType_GroupNorm, BUFFER); } // namespace OpenCL } // namespace MNN #endif/* MNN_SUPPORT_TRANSFORMER_FUSE */ -#endif/* MNN_OPENCL_BUFFER_CLOSED */ diff --git a/source/backend/opencl/execution/buffer/GroupNormBufExecution.hpp b/source/backend/opencl/execution/buffer/GroupNormBufExecution.hpp index 2f1ce8313..bf569f983 100644 --- a/source/backend/opencl/execution/buffer/GroupNormBufExecution.hpp +++ b/source/backend/opencl/execution/buffer/GroupNormBufExecution.hpp @@ -5,7 +5,6 @@ // Created by MNN on 2024/06/24. // Copyright © 2018, Alibaba Group Holding Limited // -#ifndef MNN_OPENCL_BUFFER_CLOSED #ifdef MNN_SUPPORT_TRANSFORMER_FUSE #ifndef GroupNormBufExecution_hpp @@ -42,4 +41,3 @@ class GroupNormBufExecution : public CommonExecution { } // namespace MNN #endif /* GroupNormBufExecution_hpp */ #endif/* MNN_SUPPORT_TRANSFORMER_FUSE */ -#endif diff --git a/source/backend/opencl/execution/buffer/LoopBufExecution.cpp b/source/backend/opencl/execution/buffer/LoopBufExecution.cpp index 9935db7df..bf8dfc463 100644 --- a/source/backend/opencl/execution/buffer/LoopBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/LoopBufExecution.cpp @@ -16,7 +16,7 @@ namespace OpenCL { static void _TileOrPackTensor(Tensor *input, Tensor *output, std::shared_ptr& kernelW, cl::NDRange &globalWorkSize, cl::NDRange &localWorkSize, const int Width, const int Height, const int Channel, const int Batch, OpenCLBackend *bn, const std::string& KernelName, std::set buildOptions, - const int WidthPad, const int HeightPad, const int ChannelPad) { + const int WidthPad, const int HeightPad, const int ChannelPad, OpenCLRuntime* runtime) { bool fastTileTranspose = false; if (TensorUtils::getDescribe(output)->dimensionFormat == MNN::MNN_DATA_FORMAT_NHWC || TensorUtils::getDescribe(input)->dimensionFormat == MNN::MNN_DATA_FORMAT_NHWC){ buildOptions.emplace("-DMNN_NHWC"); @@ -30,21 +30,28 @@ static void _TileOrPackTensor(Tensor *input, Tensor *output, std::shared_ptrisSupportedFP16()) { + local_mem_size = 2; + } + if(buildOptions.find("-DDIMENSION_4") != buildOptions.end()) { + local_mem_size *= (64 * 64 * 4); + if(local_mem_size <= runtime->getMaxLocalMem()) { + if((WidthPad & 63) == 0) { + tileW = 64; + } + if((HeightPad & 63) == 0) { + tileH = 64; + } + } + runKernelName = "tile_trans_4d_buf"; // match with tileW tileH tileW/localW tileH/localH buildOptions.emplace("-DWGSW=" + std::to_string(tileW)); @@ -52,6 +59,15 @@ static void _TileOrPackTensor(Tensor *input, Tensor *output, std::shared_ptronAcquireBuffer(mTmpTensors[i].get(), Backend::DYNAMIC); - _TileOrPackTensor(input, mTmpTensors[i].get(), unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height, Channel, Batch, mOpenCLBackend, "tile_buf", buildOptions, WidthPad, HeightPad, ChannelPad); + _TileOrPackTensor(input, mTmpTensors[i].get(), unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height, Channel, Batch, mOpenCLBackend, "tile_buf", buildOptions, WidthPad, HeightPad, ChannelPad, runTime); mUnits.emplace_back(unit); } @@ -499,7 +515,7 @@ ErrorCode LoopBatchMatMulBufExecution::onEncode(const std::vector &inp // MNN_PRINT("input%d 
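The new OpenCLRuntime* argument lets _TileOrPackTensor check whether the 64x64 shared-memory tiles of tile_trans_4d_buf fit into local memory before enlarging the tiles. A sketch of that budget check; the 4-byte fp32 element size is an assumption, since only the fp16 override to 2 bytes is visible in the hunk above:

// Illustrative sketch (not part of the patch).
#include <cstdint>

bool fitsLocalMem(bool fp16, uint64_t maxLocalMemBytes) {
    const uint64_t elemBytes = fp16 ? 2 : 4;             // assumed fp32 element size
    return elemBytes * 64 * 64 * 4 <= maxLocalMemBytes;  // same budget as local_mem_size above
}
// When this fits, tileW and tileH are promoted from 32 to 64 independently,
// and only if WidthPad / HeightPad are themselves multiples of 64.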
offset, %d %d %d %d\n", i, Batch, Channel, Height, Width); Unit unit; - _TileOrPackTensor(input, mOffsetTensors.back().get(), unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height, Channel, Batch, mOpenCLBackend, "tile_buf", mBuildOptions, Width, Height, Channel); + _TileOrPackTensor(input, mOffsetTensors.back().get(), unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height, Channel, Batch, mOpenCLBackend, "tile_buf", mBuildOptions, Width, Height, Channel, runTime); mUnits.emplace_back(unit); } } @@ -520,11 +536,9 @@ ErrorCode LoopBatchMatMulBufExecution::onEncode(const std::vector &inp std::set buildOptions; uint32_t layout = 0; - auto param = getGemmParams({(uint32_t)e_pack, (uint32_t)h_pack, (uint32_t)l_pack, layout, (uint32_t)n}, {openCLBuffer(mTmpTensors[1].get()), openCLBuffer(mTmpTensors[2].get()), openCLBuffer(mTmpTensors[0].get())}, mOpenCLBackend->getOpenCLRuntime()); + auto param = getGemmParams({(uint32_t)e_pack, (uint32_t)h_pack, (uint32_t)l_pack, layout, (uint32_t)n, (uint32_t)0}, {openCLBuffer(mTmpTensors[1].get()), openCLBuffer(mTmpTensors[2].get()), openCLBuffer(mTmpTensors[0].get())}, mOpenCLBackend->getOpenCLRuntime()); - int GEMMK=param[0], KREG=param[1], KWG=param[2], KWI=param[3], MDIMA=param[4], MDIMC=param[5], MWG=param[6], NDIMB=param[7], NDIMC=param[8], NWG=param[9], SA=param[10], SB=param[11], STRM=param[12], STRN=param[13], VWM=param[14], VWN=param[15]; - buildOptions.emplace("-DGEMMK=" + std::to_string(GEMMK)); - buildOptions.emplace("-DKREG=" + std::to_string(KREG)); + int KWG=param[0], KWI=param[1], MDIMA=param[2], MDIMC=param[3], MWG=param[4], NDIMB=param[5], NDIMC=param[6], NWG=param[7], SA=param[8], SB=param[9], STRM=param[10], STRN=param[11], VWM=param[12], VWN=param[13]; buildOptions.emplace("-DKWG=" + std::to_string(KWG)); buildOptions.emplace("-DKWI=" + std::to_string(KWI)); buildOptions.emplace("-DMDIMA=" + std::to_string(MDIMA)); @@ -552,12 +566,6 @@ ErrorCode LoopBatchMatMulBufExecution::onEncode(const std::vector &inp buildOptions.emplace("-DUSE_CL_MAD=1"); buildOptions.emplace("-DRELAX_WORKGROUP_SIZE=1"); } - - if(mOpenCLBackend->getOpenCLRuntime()->isSupportedFP16()){ - buildOptions.emplace(" -DPRECISION=16"); - } else { - buildOptions.emplace(" -DPRECISION=32"); - } Unit unit; unit.kernel = mOpenCLBackend->getOpenCLRuntime()->buildKernel("matmul_params_buf", "XgemmBatched", buildOptions); @@ -570,7 +578,9 @@ ErrorCode LoopBatchMatMulBufExecution::onEncode(const std::vector &inp float alpha = 1.0; float beta = 0.0f; - + int batch_offset_a = e_pack * l_pack; + int batch_offset_b = h_pack * l_pack; + int batch_offset_c = e_pack * h_pack; int idx = 0; cl_int ret = CL_SUCCESS; ret |= unit.kernel->get().setArg(idx++, static_cast(e_pack)); @@ -579,14 +589,11 @@ ErrorCode LoopBatchMatMulBufExecution::onEncode(const std::vector &inp ret |= unit.kernel->get().setArg(idx++, alpha); ret |= unit.kernel->get().setArg(idx++, beta); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mTmpTensors[1].get())); - ret |= unit.kernel->get().setArg(idx++, e_pack); - ret |= unit.kernel->get().setArg(idx++, l_pack); + ret |= unit.kernel->get().setArg(idx++, batch_offset_a); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mTmpTensors[2].get())); - ret |= unit.kernel->get().setArg(idx++, h_pack); - ret |= unit.kernel->get().setArg(idx++, l_pack); + ret |= unit.kernel->get().setArg(idx++, batch_offset_b); ret |= unit.kernel->get().setArg(idx++, openCLBuffer(mTmpTensors[0].get())); - ret |= unit.kernel->get().setArg(idx++, e_pack); - ret |= 
unit.kernel->get().setArg(idx++, h_pack); + ret |= unit.kernel->get().setArg(idx++, batch_offset_c); MNN_CHECK_CL_SUCCESS(ret, "setArg LoopBuf GemmTile Kernel"); unit.globalWorkSize = {globalWorkSize[0], globalWorkSize[1], globalWorkSize[2]}; @@ -686,7 +693,7 @@ ErrorCode LoopBatchMatMulBufExecution::onEncode(const std::vector &inp HeightPad = std::get<1>(shape); ChannelPad = std::get<2>(shape); } - _TileOrPackTensor(mTmpTensors[0].get(), output, unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height, Channel, Batch, mOpenCLBackend, "pack_buf", buildOptions, WidthPad, HeightPad, ChannelPad); + _TileOrPackTensor(mTmpTensors[0].get(), output, unit.kernel, unit.globalWorkSize, unit.localWorkSize, Width, Height, Channel, Batch, mOpenCLBackend, "pack_buf", buildOptions, WidthPad, HeightPad, ChannelPad, runTime); mUnits.emplace_back(unit); } diff --git a/source/backend/opencl/execution/buffer/SelfAttentionBufExecution.cpp b/source/backend/opencl/execution/buffer/SelfAttentionBufExecution.cpp index dc21fcb68..bc7aba4ef 100644 --- a/source/backend/opencl/execution/buffer/SelfAttentionBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/SelfAttentionBufExecution.cpp @@ -5,7 +5,6 @@ // Created by MNN on 2024/06/03. // Copyright © 2018, Alibaba Group Holding Limited // -#ifndef MNN_OPENCL_BUFFER_CLOSED #ifdef MNN_SUPPORT_TRANSFORMER_FUSE #include @@ -41,7 +40,7 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vectorshape(); int tile_mn = 32; - int tile_k = 16; // for gemm alignment + int tile_k = 4; // for gemm alignment int batch = shape[0]; int seq_len = shape[1]; mHeadDim = shape[2] / mNumHead / 3; @@ -53,7 +52,7 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vector 1024) { - mQseqSplitNum = (seq_len >= 4096) ? 8 : ((seq_len < 2048) ? 2 : 4); + mQseqSplitNum = (seq_len >= 4096 && seq_len % 64 == 0) ? 8 : ((seq_len < 2048) ? 
2 : 4); } int buffer_size = batch * mNumHead * ROUND_UP(mHeadDim, tile_k) * ROUND_UP(seq_len, tile_mn); int buffer_qk_size = batch * mNumHead * ROUND_UP(seq_len, tile_mn) * ROUND_UP(seq_len, tile_mn) / mQseqSplitNum; @@ -154,7 +153,9 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vectorget().setArg(index++, seq_idx); MNN_CHECK_CL_SUCCESS(ret, "setArg split_transpose_qkv"); mLocalWorkSizeSplit[seq_idx] = localWS3DDefault(mGlobalWorkSizeSplit[seq_idx], maxWorkGroupSize, runtime, "split_transpose_qkv", mKernel_split[seq_idx]).first; - + mGlobalWorkSizeSplit[seq_idx][0] = ROUND_UP(mGlobalWorkSizeSplit[seq_idx][0], std::max((uint32_t)1, mLocalWorkSizeSplit[seq_idx][0])); + mGlobalWorkSizeSplit[seq_idx][1] = ROUND_UP(mGlobalWorkSizeSplit[seq_idx][1], std::max((uint32_t)1, mLocalWorkSizeSplit[seq_idx][1])); + mGlobalWorkSizeSplit[seq_idx][2] = ROUND_UP(mGlobalWorkSizeSplit[seq_idx][2], std::max((uint32_t)1, mLocalWorkSizeSplit[seq_idx][2])); mOpenCLBackend->recordKernel3d(mKernel_split[seq_idx], mGlobalWorkSizeSplit[seq_idx], mLocalWorkSizeSplit[seq_idx]); } @@ -171,11 +172,9 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vector buildOptions; uint32_t layout = 4; - auto param = getGemmParams({(uint32_t)e_pack, (uint32_t)h_pack, (uint32_t)l_pack, layout, (uint32_t)loop}, {openCLBuffer(mTempQ.get()), openCLBuffer(mTempK.get()), openCLBuffer(mTempQK.get())}, mOpenCLBackend->getOpenCLRuntime()); + auto param = getGemmParams({(uint32_t)e_pack, (uint32_t)h_pack, (uint32_t)l_pack, layout, (uint32_t)loop, (uint32_t)0}, {openCLBuffer(mTempQ.get()), openCLBuffer(mTempK.get()), openCLBuffer(mTempQK.get())}, mOpenCLBackend->getOpenCLRuntime()); - int GEMMK=param[0], KREG=param[1], KWG=param[2], KWI=param[3], MDIMA=param[4], MDIMC=param[5], MWG=param[6], NDIMB=param[7], NDIMC=param[8], NWG=param[9], SA=param[10], SB=param[11], STRM=param[12], STRN=param[13], VWM=param[14], VWN=param[15]; - buildOptions.emplace("-DGEMMK=" + std::to_string(GEMMK)); - buildOptions.emplace("-DKREG=" + std::to_string(KREG)); + int KWG=param[0], KWI=param[1], MDIMA=param[2], MDIMC=param[3], MWG=param[4], NDIMB=param[5], NDIMC=param[6], NWG=param[7], SA=param[8], SB=param[9], STRM=param[10], STRN=param[11], VWM=param[12], VWN=param[13]; buildOptions.emplace("-DKWG=" + std::to_string(KWG)); buildOptions.emplace("-DKWI=" + std::to_string(KWI)); buildOptions.emplace("-DMDIMA=" + std::to_string(MDIMA)); @@ -203,13 +202,7 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vectorgetOpenCLRuntime()->isSupportedFP16()){ - buildOptions.emplace(" -DPRECISION=16"); - } else { - buildOptions.emplace(" -DPRECISION=32"); - } - + buildOptions.emplace("-DONLY_HAVE_ALPHA"); mKernel_qk[seq_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("matmul_params_buf", "XgemmBatched", buildOptions); int out_per_thread_m = tileM / localM; @@ -220,7 +213,9 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vectorget().setArg(idx++, static_cast(e_pack)); @@ -229,14 +224,11 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vectorget().setArg(idx++, alpha); ret |= mKernel_qk[seq_idx]->get().setArg(idx++, beta); ret |= mKernel_qk[seq_idx]->get().setArg(idx++, openCLBuffer(mTempQ.get())); - ret |= mKernel_qk[seq_idx]->get().setArg(idx++, e_pack); - ret |= mKernel_qk[seq_idx]->get().setArg(idx++, l_pack); + ret |= mKernel_qk[seq_idx]->get().setArg(idx++, batch_offset_a); ret |= mKernel_qk[seq_idx]->get().setArg(idx++, openCLBuffer(mTempK.get())); 
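The attention path now splits the query sequence more conservatively (the 8-way split additionally requires seq_len % 64 == 0) and pads its scratch buffers to tile_mn = 32 and tile_k = 4. A sketch of the split choice and the resulting element counts; the default split of 1 for seq_len <= 1024 is an assumption:

// Illustrative sketch (not part of the patch).
int chooseQseqSplit(int seqLen) {
    if (seqLen <= 1024) return 1;                      // assumed default
    if (seqLen >= 4096 && seqLen % 64 == 0) return 8;  // only when evenly divisible
    return (seqLen < 2048) ? 2 : 4;
}
// Scratch element counts, matching the hunk above:
//   qkv buffer = batch * numHead * ROUND_UP(headDim, 4)  * ROUND_UP(seqLen, 32)
//   qk  buffer = batch * numHead * ROUND_UP(seqLen, 32) * ROUND_UP(seqLen, 32) / chooseQseqSplit(seqLen)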
- ret |= mKernel_qk[seq_idx]->get().setArg(idx++, h_pack); - ret |= mKernel_qk[seq_idx]->get().setArg(idx++, l_pack); + ret |= mKernel_qk[seq_idx]->get().setArg(idx++, batch_offset_b); ret |= mKernel_qk[seq_idx]->get().setArg(idx++, openCLBuffer(mTempQK.get())); - ret |= mKernel_qk[seq_idx]->get().setArg(idx++, e_pack); - ret |= mKernel_qk[seq_idx]->get().setArg(idx++, h_pack); + ret |= mKernel_qk[seq_idx]->get().setArg(idx++, batch_offset_c); MNN_CHECK_CL_SUCCESS(ret, "setArg Self-Attention batchmatmul qk Kernel"); mOpenCLBackend->recordKernel3d(mKernel_qk[seq_idx], mGlobalWorkSizeQk[seq_idx], mLocalWorkSizeQk[seq_idx]); @@ -259,7 +251,7 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vector buildOption; buildOption.emplace("-DSOFTMAX_LOCAL_SIZE=" + std::to_string(localSize)); - // buildOption.emplace("-DOUTPUT_TRANSPOSE"); +// buildOption.emplace("-DOUTPUT_TRANSPOSE"); mKernel_softmax[seq_idx] = runtime->buildKernel("self_attention_buf", "softmax_inside", buildOption, inputs[0], outputs[0]); mGlobalWorkSizeSoftMax[seq_idx] = {static_cast(localSize), static_cast(mSoftmaxShape[1]), static_cast(mSoftmaxShape[0])}; @@ -278,31 +270,16 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vector(localSize), 1, 1}; mOpenCLBackend->recordKernel3d(mKernel_softmax[seq_idx], mGlobalWorkSizeSoftMax[seq_idx], mLocalWorkSizeSoftMax[seq_idx]); } - { - unsigned int tileW = 32; - unsigned int tileH = 32; int loop = batch * mNumHead; int transDimW = ROUND_UP(seq_len, tile_mn) / mQseqSplitNum; int transDimH = ROUND_UP(seq_len, tile_mn); - if((transDimW & 63) == 0 && (transDimH & 63) == 0) { - tileW = 64; - tileH = 64; - } - unsigned int localW = 8; - unsigned int localH = 8; - std::set buildOptions; - buildOptions.emplace("-DWGSW=" + std::to_string(tileW)); - buildOptions.emplace("-DWGSH=" + std::to_string(tileH)); - buildOptions.emplace("-DTSW=" + std::to_string(tileW/localW)); - buildOptions.emplace("-DTSH=" + std::to_string(tileH/localH)); + std::set buildOptions; mKernel_trans[seq_idx] = runtime->buildKernel("self_attention_buf", "trans_3d_buf", buildOptions, inputs[0], outputs[0]); - - int w_per_thread = tileW / localW; - int h_per_thread = tileH / localH; - mGlobalWorkSizeTrans[seq_idx] = {(uint32_t)transDimW/w_per_thread, (uint32_t)transDimH/h_per_thread, (uint32_t)(loop)}; - mLocalWorkSizeTrans[seq_idx] = {localW, localH, 1}; + uint32_t maxWorkGroupSize = static_cast(mOpenCLBackend->getOpenCLRuntime()->getMaxWorkGroupSize(mKernel_trans[seq_idx])); + + mGlobalWorkSizeTrans[seq_idx] = {(uint32_t)transDimW/8, (uint32_t)transDimH/8, (uint32_t)(loop)}; uint32_t index = 0; cl_int ret = CL_SUCCESS; @@ -312,11 +289,11 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vectorget().setArg(index++, transDimW); ret |= mKernel_trans[seq_idx]->get().setArg(index++, transDimH); MNN_CHECK_CL_SUCCESS(ret, "setArg Self-Attention transpose"); + mLocalWorkSizeTrans[seq_idx] = localWS3DDefault(mGlobalWorkSizeTrans[seq_idx], maxWorkGroupSize, mOpenCLBackend->getOpenCLRuntime(), "trans_3d_buf", mKernel_trans[seq_idx]).first; mOpenCLBackend->recordKernel3d(mKernel_trans[seq_idx], mGlobalWorkSizeTrans[seq_idx], mLocalWorkSizeTrans[seq_idx]); } - // qk * value { // Sotmax: [Batch * mNumHead, ROUND_UP(seqLen, tile), ROUND_UP(seqLen, tile)] -> [B, K, M] @@ -340,11 +317,9 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vector A:[M, K] B:[N, K] C:[M, N] */ uint32_t layout = 0; - auto param = getGemmParams({(uint32_t)e_pack, 
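trans_3d_buf drops the fixed WGSW/WGSH work-group macros: each work-item now covers an 8x8 patch and the local size comes from localWS3DDefault. The launch geometry, as plain C++:

// Illustrative sketch (not part of the patch).
#include <array>
#include <cstdint>

std::array<uint32_t, 3> transposeGlobalSize(int seqLen, int splitNum, int batch, int numHead) {
    const int tile = 32;
    const int transDimW = ((seqLen + tile - 1) / tile) * tile / splitNum;
    const int transDimH = ((seqLen + tile - 1) / tile) * tile;
    return { static_cast<uint32_t>(transDimW / 8),   // 8 elements per work-item in W
             static_cast<uint32_t>(transDimH / 8),   // 8 elements per work-item in H
             static_cast<uint32_t>(batch * numHead) };
}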
(uint32_t)h_pack, (uint32_t)l_pack, layout, (uint32_t)loop}, {openCLBuffer(mTempTrans.get()), openCLBuffer(mTempV.get()), openCLBuffer(mTempQKV.get())}, mOpenCLBackend->getOpenCLRuntime()); - - int GEMMK=param[0], KREG=param[1], KWG=param[2], KWI=param[3], MDIMA=param[4], MDIMC=param[5], MWG=param[6], NDIMB=param[7], NDIMC=param[8], NWG=param[9], SA=param[10], SB=param[11], STRM=param[12], STRN=param[13], VWM=param[14], VWN=param[15]; - buildOptions.emplace("-DGEMMK=" + std::to_string(GEMMK)); - buildOptions.emplace("-DKREG=" + std::to_string(KREG)); + auto param = getGemmParams({(uint32_t)e_pack, (uint32_t)h_pack, (uint32_t)l_pack, layout, (uint32_t)loop, (uint32_t)0}, {openCLBuffer(mTempTrans.get()), openCLBuffer(mTempV.get()), openCLBuffer(mTempQKV.get())}, mOpenCLBackend->getOpenCLRuntime()); + + int KWG=param[0], KWI=param[1], MDIMA=param[2], MDIMC=param[3], MWG=param[4], NDIMB=param[5], NDIMC=param[6], NWG=param[7], SA=param[8], SB=param[9], STRM=param[10], STRN=param[11], VWM=param[12], VWN=param[13]; buildOptions.emplace("-DKWG=" + std::to_string(KWG)); buildOptions.emplace("-DKWI=" + std::to_string(KWI)); buildOptions.emplace("-DMDIMA=" + std::to_string(MDIMA)); @@ -372,13 +347,7 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vectorgetOpenCLRuntime()->isSupportedFP16()){ - buildOptions.emplace(" -DPRECISION=16"); - } else { - buildOptions.emplace(" -DPRECISION=32"); - } - + mKernel_qkv[seq_idx] = mOpenCLBackend->getOpenCLRuntime()->buildKernel("matmul_params_buf", "XgemmBatched", buildOptions); int out_per_thread_m = tileM / localM; @@ -389,7 +358,9 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vectorget().setArg(idx++, static_cast(e_pack)); @@ -398,14 +369,11 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vectorget().setArg(idx++, alpha); ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, beta); ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, openCLBuffer(mTempTrans.get())); - ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, e_pack); - ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, l_pack); + ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, batch_offset_a); ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, openCLBuffer(mTempV.get())); - ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, h_pack); - ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, l_pack); + ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, batch_offset_b); ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, openCLBuffer(mTempQKV.get())); - ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, e_pack); - ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, h_pack); + ret |= mKernel_qkv[seq_idx]->get().setArg(idx++, batch_offset_c); MNN_CHECK_CL_SUCCESS(ret, "setArg Self-Attention batchmatmul qkv Kernel"); mOpenCLBackend->recordKernel3d(mKernel_qkv[seq_idx], mGlobalWorkSizeQkv[seq_idx], mLocalWorkSizeQkv[seq_idx]); } @@ -438,6 +406,10 @@ ErrorCode SelfAttentionBufImpl::onResize(Backend *backend, const std::vectorget().setArg(index++, seq_idx); mLocalWorkSizeClip[seq_idx] = localWS3DDefault(mGlobalWorkSizeClip[seq_idx], maxWorkGroupSize, runtime, "clip_transpose_qkv", mKernel_clip[seq_idx]).first; + mGlobalWorkSizeClip[seq_idx][0] = ROUND_UP(mGlobalWorkSizeClip[seq_idx][0], std::max((uint32_t)1, mLocalWorkSizeClip[seq_idx][0])); + mGlobalWorkSizeClip[seq_idx][1] = ROUND_UP(mGlobalWorkSizeClip[seq_idx][1], std::max((uint32_t)1, mLocalWorkSizeClip[seq_idx][1])); + mGlobalWorkSizeClip[seq_idx][2] = ROUND_UP(mGlobalWorkSizeClip[seq_idx][2], 
std::max((uint32_t)1, mLocalWorkSizeClip[seq_idx][2])); + MNN_CHECK_CL_SUCCESS(ret, "setArg clip_transpose_qkv"); mOpenCLBackend->recordKernel3d(mKernel_clip[seq_idx], mGlobalWorkSizeClip[seq_idx], mLocalWorkSizeClip[seq_idx]); } @@ -578,9 +550,8 @@ class SelfAttentionBufCreator : public OpenCLBackend::Creator { return new SelfAttentionBufExecution(op, backend); } }; -REGISTER_OPENCL_OP_CREATOR(SelfAttentionBufCreator, OpType_FmhaV2, BUFFER); +REGISTER_OPENCL_OP_CREATOR_TRANSFORMER(SelfAttentionBufCreator, OpType_FmhaV2, BUFFER); } // namespace OpenCL } // namespace MNN #endif/* MNN_SUPPORT_TRANSFORMER_FUSE */ -#endif/* MNN_OPENCL_BUFFER_CLOSED */ diff --git a/source/backend/opencl/execution/buffer/SelfAttentionBufExecution.hpp b/source/backend/opencl/execution/buffer/SelfAttentionBufExecution.hpp index 31786f666..447be8f10 100644 --- a/source/backend/opencl/execution/buffer/SelfAttentionBufExecution.hpp +++ b/source/backend/opencl/execution/buffer/SelfAttentionBufExecution.hpp @@ -5,7 +5,6 @@ // Created by MNN on 2024/06/03. // Copyright © 2018, Alibaba Group Holding Limited // -#ifndef MNN_OPENCL_BUFFER_CLOSED #ifdef MNN_SUPPORT_TRANSFORMER_FUSE #ifndef SelfAttentionBufExecution_hpp @@ -75,5 +74,3 @@ class SelfAttentionBufExecution : public CommonExecution { } // namespace MNN #endif /* SelfAttentionBufExecution_hpp */ #endif/* MNN_SUPPORT_TRANSFORMER_FUSE */ - -#endif /* MNN_OPENCL_BUFFER_CLOSED */ diff --git a/source/backend/opencl/execution/buffer/SplitGeluBufExecution.cpp b/source/backend/opencl/execution/buffer/SplitGeluBufExecution.cpp index 171ee58a1..0baee6428 100644 --- a/source/backend/opencl/execution/buffer/SplitGeluBufExecution.cpp +++ b/source/backend/opencl/execution/buffer/SplitGeluBufExecution.cpp @@ -6,7 +6,6 @@ // Created by MNN on 2024/06/26. // Copyright © 2018, Alibaba Group Holding Limited // -#ifndef MNN_OPENCL_BUFFER_CLOSED #ifdef MNN_SUPPORT_TRANSFORMER_FUSE #include "backend/opencl/execution/buffer/SplitGeluBufExecution.hpp" @@ -96,8 +95,7 @@ class SplitGeluBufCreator : public OpenCLBackend::Creator { } }; -REGISTER_OPENCL_OP_CREATOR(SplitGeluBufCreator, OpType_SplitGeLU, BUFFER); +REGISTER_OPENCL_OP_CREATOR_TRANSFORMER(SplitGeluBufCreator, OpType_SplitGeLU, BUFFER); } // namespace OpenCL } // namespace MNN #endif/* MNN_SUPPORT_TRANSFORMER_FUSE */ -#endif/* MNN_OPENCL_BUFFER_CLOSED */ diff --git a/source/backend/opencl/execution/buffer/SplitGeluBufExecution.hpp b/source/backend/opencl/execution/buffer/SplitGeluBufExecution.hpp index 9ecd65b8b..0f0e6bd60 100644 --- a/source/backend/opencl/execution/buffer/SplitGeluBufExecution.hpp +++ b/source/backend/opencl/execution/buffer/SplitGeluBufExecution.hpp @@ -5,7 +5,6 @@ // Created by MNN on 2024/06/26. 
// Copyright © 2018, Alibaba Group Holding Limited // -#ifndef MNN_OPENCL_BUFFER_CLOSED #ifdef MNN_SUPPORT_TRANSFORMER_FUSE #ifndef SplitGeluBufExecution_hpp @@ -36,4 +35,3 @@ class SplitGeluBufExecution : public CommonExecution { } // namespace MNN #endif /* SplitGeluBufExecution_hpp */ #endif/* MNN_SUPPORT_TRANSFORMER_FUSE */ -#endif diff --git a/source/backend/opencl/execution/cl/conv_2d_buf.cl b/source/backend/opencl/execution/cl/conv_2d_buf.cl index b6d6dd475..9aed2e670 100644 --- a/source/backend/opencl/execution/cl/conv_2d_buf.cl +++ b/source/backend/opencl/execution/cl/conv_2d_buf.cl @@ -9,8 +9,6 @@ return; \ } - - __kernel void conv_2d_1x1_c4h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, __global const FLOAT *input, @@ -20,7 +18,8 @@ void conv_2d_1x1_c4h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, __private const int in_c_block, __private const int out_h, __private const int out_w, - __private const int out_c_block) { + __private const int out_c_block, + __private const int out_c_pack) { const int out_c_w_idx = get_global_id(0); //c/4 w const int out_b_h_idx = get_global_id(1); //b h @@ -40,43 +39,45 @@ void conv_2d_1x1_c4h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, const int intput_width_idx0 = out_w4_idx; - int offset = mul24(out_c_idx, in_c_block) << 2; + + int offset = out_c_idx*4; int inp_offset = (((out_b_idx*in_c_block)*out_h + out_h_idx)* out_w + intput_width_idx0) << 2; const int inp_add = out_h*out_w*4; for (ushort in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) { + + int offset = mad24(in_channel_block_idx*4, out_c_pack, out_c_idx*4); COMPUTE_FLOAT4 in0 = CONVERT_COMPUTE_FLOAT4(vload4(0, input+inp_offset)); COMPUTE_FLOAT4 in1 = CONVERT_COMPUTE_FLOAT4(vload4(1, input+inp_offset)); COMPUTE_FLOAT4 in2 = CONVERT_COMPUTE_FLOAT4(vload4(2, input+inp_offset)); COMPUTE_FLOAT4 in3 = CONVERT_COMPUTE_FLOAT4(vload4(3, input+inp_offset)); - - COMPUTE_FLOAT4 weights0 = CONVERT_COMPUTE_FLOAT4(vload4(offset, kernel_ptr)); - COMPUTE_FLOAT4 weights1 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 1, kernel_ptr)); - COMPUTE_FLOAT4 weights2 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 2, kernel_ptr)); - COMPUTE_FLOAT4 weights3 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 3, kernel_ptr)); - - out0.x += dot(weights0, in0); - out0.y += dot(weights1, in0); - out0.z += dot(weights2, in0); - out0.w += dot(weights3, in0); - - out1.x += dot(weights0, in1); - out1.y += dot(weights1, in1); - out1.z += dot(weights2, in1); - out1.w += dot(weights3, in1); - - out2.x += dot(weights0, in2); - out2.y += dot(weights1, in2); - out2.z += dot(weights2, in2); - out2.w += dot(weights3, in2); - - out3.x += dot(weights0, in3); - out3.y += dot(weights1, in3); - out3.z += dot(weights2, in3); - out3.w += dot(weights3, in3); + COMPUTE_FLOAT4 weights0 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset)); + COMPUTE_FLOAT4 weights1 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack)); + COMPUTE_FLOAT4 weights2 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack + out_c_pack)); + COMPUTE_FLOAT4 weights3 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack + out_c_pack + out_c_pack)); + + out0 = mad(in0.x, weights0, out0); + out0 = mad(in0.y, weights1, out0); + out0 = mad(in0.z, weights2, out0); + out0 = mad(in0.w, weights3, out0); + + out1 = mad(in1.x, weights0, out1); + out1 = mad(in1.y, weights1, out1); + out1 = mad(in1.z, weights2, out1); + out1 = mad(in1.w, weights3, out1); + + out2 = mad(in2.x, weights0, 
out2); + out2 = mad(in2.y, weights1, out2); + out2 = mad(in2.z, weights2, out2); + out2 = mad(in2.w, weights3, out2); - offset += 4; + out3 = mad(in3.x, weights0, out3); + out3 = mad(in3.y, weights1, out3); + out3 = mad(in3.z, weights2, out3); + out3 = mad(in3.w, weights3, out3); + + offset += 4 * out_c_pack; inp_offset += inp_add; } @@ -122,7 +123,8 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, __private const int in_c_block, __private const int out_h, __private const int out_w, - __private const int out_c_block) { + __private const int out_c_block, + __private const int out_c_pack) { const int out_c_w_idx = get_global_id(0); //c/8 w/4 const int out_b_h_idx = get_global_id(1); //b h @@ -146,10 +148,10 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, COMPUTE_FLOAT4 out7 = out4; const int intput_width_idx0 = out_w4_idx; - + for (int in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) { - int offset = mad24(out_c_idx, in_c_block, in_channel_block_idx)*8; + int offset = mad24(in_channel_block_idx*4, out_c_pack, out_c_idx*8); const int inp_offset = (((out_b_idx*in_c_block + in_channel_block_idx)*out_h + out_h_idx)* out_w + intput_width_idx0)*4; @@ -157,55 +159,55 @@ void conv_2d_1x1_c8h1w4(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, COMPUTE_FLOAT4 in1 = CONVERT_COMPUTE_FLOAT4(vload4(1, input+inp_offset)); COMPUTE_FLOAT4 in2 = CONVERT_COMPUTE_FLOAT4(vload4(2, input+inp_offset)); COMPUTE_FLOAT4 in3 = CONVERT_COMPUTE_FLOAT4(vload4(3, input+inp_offset)); - - COMPUTE_FLOAT4 weights0 = CONVERT_COMPUTE_FLOAT4(vload4(offset, kernel_ptr)); - COMPUTE_FLOAT4 weights1 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 1, kernel_ptr)); - COMPUTE_FLOAT4 weights2 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 2, kernel_ptr)); - COMPUTE_FLOAT4 weights3 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 3, kernel_ptr)); - COMPUTE_FLOAT4 weights4 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 4, kernel_ptr)); - COMPUTE_FLOAT4 weights5 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 5, kernel_ptr)); - COMPUTE_FLOAT4 weights6 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 6, kernel_ptr)); - COMPUTE_FLOAT4 weights7 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 7, kernel_ptr)); - out0.x += dot(weights0, in0); - out0.y += dot(weights1, in0); - out0.z += dot(weights2, in0); - out0.w += dot(weights3, in0); - - out1.x += dot(weights0, in1); - out1.y += dot(weights1, in1); - out1.z += dot(weights2, in1); - out1.w += dot(weights3, in1); - - out2.x += dot(weights0, in2); - out2.y += dot(weights1, in2); - out2.z += dot(weights2, in2); - out2.w += dot(weights3, in2); - - out3.x += dot(weights0, in3); - out3.y += dot(weights1, in3); - out3.z += dot(weights2, in3); - out3.w += dot(weights3, in3); + COMPUTE_FLOAT4 weights0 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset)); + COMPUTE_FLOAT4 weights1 = CONVERT_COMPUTE_FLOAT4(vload4(1, kernel_ptr + offset)); + COMPUTE_FLOAT4 weights2 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack)); + COMPUTE_FLOAT4 weights3 = CONVERT_COMPUTE_FLOAT4(vload4(1, kernel_ptr + offset + out_c_pack)); + COMPUTE_FLOAT4 weights4 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack + out_c_pack)); + COMPUTE_FLOAT4 weights5 = CONVERT_COMPUTE_FLOAT4(vload4(1, kernel_ptr + offset + out_c_pack + out_c_pack)); + COMPUTE_FLOAT4 weights6 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack + out_c_pack + out_c_pack)); + COMPUTE_FLOAT4 weights7 = CONVERT_COMPUTE_FLOAT4(vload4(1, kernel_ptr + offset + 
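The conv_2d_1x1_* kernels above switch from dot-products over the per-output-channel packed buffer to mad accumulation over the GEMM-style filter layout: ROUND_UP(Cin, 4) rows of out_c_pack = ROUND_UP(Cout, 32) output channels, so the weight for (input channel ci, output channel co) sits at ci * out_c_pack + co. A scalar host-side reference for one output pixel; bias-initialized accumulators are assumed from the unchanged parts of the kernels:

// Illustrative sketch (not part of the patch).
#include <vector>

std::vector<float> conv1x1Reference(const std::vector<float>& inputPixel,  // [cin] for one pixel
                                    const std::vector<float>& filter,      // [roundUp(cin,4) * outCPack]
                                    const std::vector<float>& bias,        // [cout]
                                    int cin, int cout, int outCPack) {
    std::vector<float> out(cout);
    for (int co = 0; co < cout; ++co) {
        float acc = bias[co];
        for (int ci = 0; ci < cin; ++ci) {
            acc += inputPixel[ci] * filter[ci * outCPack + co];  // mad(in[ci], weight, acc)
        }
        out[co] = acc;
    }
    return out;
}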
out_c_pack + out_c_pack + out_c_pack)); + + out0 = mad(in0.x, weights0, out0); + out0 = mad(in0.y, weights2, out0); + out0 = mad(in0.z, weights4, out0); + out0 = mad(in0.w, weights6, out0); + + out1 = mad(in1.x, weights0, out1); + out1 = mad(in1.y, weights2, out1); + out1 = mad(in1.z, weights4, out1); + out1 = mad(in1.w, weights6, out1); + + out2 = mad(in2.x, weights0, out2); + out2 = mad(in2.y, weights2, out2); + out2 = mad(in2.z, weights4, out2); + out2 = mad(in2.w, weights6, out2); - out4.x += dot(weights4, in0); - out4.y += dot(weights5, in0); - out4.z += dot(weights6, in0); - out4.w += dot(weights7, in0); - - out5.x += dot(weights4, in1); - out5.y += dot(weights5, in1); - out5.z += dot(weights6, in1); - out5.w += dot(weights7, in1); - - out6.x += dot(weights4, in2); - out6.y += dot(weights5, in2); - out6.z += dot(weights6, in2); - out6.w += dot(weights7, in2); - - out7.x += dot(weights4, in3); - out7.y += dot(weights5, in3); - out7.z += dot(weights6, in3); - out7.w += dot(weights7, in3); + out3 = mad(in3.x, weights0, out3); + out3 = mad(in3.y, weights2, out3); + out3 = mad(in3.z, weights4, out3); + out3 = mad(in3.w, weights6, out3); + + out4 = mad(in0.x, weights1, out4); + out4 = mad(in0.y, weights3, out4); + out4 = mad(in0.z, weights5, out4); + out4 = mad(in0.w, weights7, out4); + + out5 = mad(in1.x, weights1, out5); + out5 = mad(in1.y, weights3, out5); + out5 = mad(in1.z, weights5, out5); + out5 = mad(in1.w, weights7, out5); + + out6 = mad(in2.x, weights1, out6); + out6 = mad(in2.y, weights3, out6); + out6 = mad(in2.z, weights5, out6); + out6 = mad(in2.w, weights7, out6); + + out7 = mad(in3.x, weights1, out7); + out7 = mad(in3.y, weights3, out7); + out7 = mad(in3.z, weights5, out7); + out7 = mad(in3.w, weights7, out7); } #ifdef RELU @@ -285,7 +287,8 @@ void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, __private const int in_c_block, __private const int out_h, __private const int out_w, - __private const int out_c_block) { // oc / 4 + __private const int out_c_block, + __private const int out_c_pack) { const int out_c_w_idx = get_global_id(0); //c/8 w/4 const int out_b_h_idx = get_global_id(1); //b h @@ -305,44 +308,42 @@ void conv_2d_1x1_c8h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, COMPUTE_FLOAT4 out5 = out4; const int intput_width_idx0 = out_w2_idx; - for (int in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) { - int offset = mad24(out_c_idx, in_c_block, in_channel_block_idx)*8; + int offset = mad24(in_channel_block_idx*4, out_c_pack, out_c_idx*8); const int inp_offset = (((out_b_idx*in_c_block + in_channel_block_idx)*out_h + out_h_idx)* out_w + intput_width_idx0)*4; COMPUTE_FLOAT4 in0 = CONVERT_COMPUTE_FLOAT4(vload4(0, input+inp_offset)); COMPUTE_FLOAT4 in1 = CONVERT_COMPUTE_FLOAT4(vload4(1, input+inp_offset)); - - COMPUTE_FLOAT4 weights0 = CONVERT_COMPUTE_FLOAT4(vload4(offset, kernel_ptr)); - COMPUTE_FLOAT4 weights1 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 1, kernel_ptr)); - COMPUTE_FLOAT4 weights2 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 2, kernel_ptr)); - COMPUTE_FLOAT4 weights3 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 3, kernel_ptr)); - COMPUTE_FLOAT4 weights4 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 4, kernel_ptr)); - COMPUTE_FLOAT4 weights5 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 5, kernel_ptr)); - COMPUTE_FLOAT4 weights6 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 6, kernel_ptr)); - COMPUTE_FLOAT4 weights7 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 7, kernel_ptr)); + COMPUTE_FLOAT4 weights0 = 
CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset)); + COMPUTE_FLOAT4 weights1 = CONVERT_COMPUTE_FLOAT4(vload4(1, kernel_ptr + offset)); + COMPUTE_FLOAT4 weights2 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack)); + COMPUTE_FLOAT4 weights3 = CONVERT_COMPUTE_FLOAT4(vload4(1, kernel_ptr + offset + out_c_pack)); + COMPUTE_FLOAT4 weights4 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack + out_c_pack)); + COMPUTE_FLOAT4 weights5 = CONVERT_COMPUTE_FLOAT4(vload4(1, kernel_ptr + offset + out_c_pack + out_c_pack)); + COMPUTE_FLOAT4 weights6 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack + out_c_pack + out_c_pack)); + COMPUTE_FLOAT4 weights7 = CONVERT_COMPUTE_FLOAT4(vload4(1, kernel_ptr + offset + out_c_pack + out_c_pack + out_c_pack)); + + out0 = mad(in0.x, weights0, out0); + out0 = mad(in0.y, weights2, out0); + out0 = mad(in0.z, weights4, out0); + out0 = mad(in0.w, weights6, out0); - out0.x += dot(weights0, in0); - out0.y += dot(weights1, in0); - out0.z += dot(weights2, in0); - out0.w += dot(weights3, in0); - - out1.x += dot(weights0, in1); - out1.y += dot(weights1, in1); - out1.z += dot(weights2, in1); - out1.w += dot(weights3, in1); + out1 = mad(in1.x, weights0, out1); + out1 = mad(in1.y, weights2, out1); + out1 = mad(in1.z, weights4, out1); + out1 = mad(in1.w, weights6, out1); - out4.x += dot(weights4, in0); - out4.y += dot(weights5, in0); - out4.z += dot(weights6, in0); - out4.w += dot(weights7, in0); - - out5.x += dot(weights4, in1); - out5.y += dot(weights5, in1); - out5.z += dot(weights6, in1); - out5.w += dot(weights7, in1); + out4 = mad(in0.x, weights1, out4); + out4 = mad(in0.y, weights3, out4); + out4 = mad(in0.z, weights5, out4); + out4 = mad(in0.w, weights7, out4); + + out5 = mad(in1.x, weights1, out5); + out5 = mad(in1.y, weights3, out5); + out5 = mad(in1.z, weights5, out5); + out5 = mad(in1.w, weights7, out5); } #ifdef RELU @@ -404,7 +405,8 @@ void conv_2d_1x1_c4h1w1(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, __private const int in_c_block, __private const int out_h, __private const int out_w, - __private const int out_c_block) { + __private const int out_c_block, + __private const int out_c_pack) { const int out_c_w_idx = get_global_id(0); //c/4 w const int out_b_h_idx = get_global_id(1); //b h @@ -421,21 +423,20 @@ void conv_2d_1x1_c4h1w1(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, for (int in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) { - int offset = mad24(out_c_idx, in_c_block, in_channel_block_idx)*4; + int offset = mad24(in_channel_block_idx*4, out_c_pack, out_c_idx*4); const int inp_offset = (((out_b_idx*in_c_block + in_channel_block_idx)*out_h + out_h_idx)* out_w + intput_width_idx0)*4; COMPUTE_FLOAT4 in0 = CONVERT_COMPUTE_FLOAT4(vload4(0, input+inp_offset)); - - COMPUTE_FLOAT4 weights0 = CONVERT_COMPUTE_FLOAT4(vload4(offset, kernel_ptr)); - COMPUTE_FLOAT4 weights1 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 1, kernel_ptr)); - COMPUTE_FLOAT4 weights2 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 2, kernel_ptr)); - COMPUTE_FLOAT4 weights3 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 3, kernel_ptr)); - - out0.x += dot(weights0, in0); - out0.y += dot(weights1, in0); - out0.z += dot(weights2, in0); - out0.w += dot(weights3, in0); + COMPUTE_FLOAT4 weights0 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset)); + COMPUTE_FLOAT4 weights1 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack)); + COMPUTE_FLOAT4 weights2 = CONVERT_COMPUTE_FLOAT4(vload4(0, 
kernel_ptr + offset + out_c_pack + out_c_pack)); + COMPUTE_FLOAT4 weights3 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack + out_c_pack + out_c_pack)); + + out0 = mad(in0.x, weights0, out0); + out0 = mad(in0.y, weights1, out0); + out0 = mad(in0.z, weights2, out0); + out0 = mad(in0.w, weights3, out0); } #ifdef RELU @@ -461,7 +462,8 @@ void conv_2d_1x1_c4h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, __private const int in_c_block, __private const int out_h, __private const int out_w, - __private const int out_c_block) { + __private const int out_c_block, + __private const int out_c_pack) { const int out_c_w_idx = get_global_id(0); //c/4 w const int out_b_h_idx = get_global_id(1); //b h @@ -482,27 +484,27 @@ void conv_2d_1x1_c4h1w2(GLOBAL_SIZE_2_DIMS __private const int out_w_blocks, for (int in_channel_block_idx = 0; in_channel_block_idx < in_c_block; ++in_channel_block_idx) { - int offset = mad24(out_c_idx, in_c_block, in_channel_block_idx)*4; + int offset = mad24(in_channel_block_idx*4, out_c_pack, out_c_idx*4); const int inp_offset = (((out_b_idx*in_c_block + in_channel_block_idx)*out_h + out_h_idx)* out_w + intput_width_idx0)*4; COMPUTE_FLOAT4 in0 = CONVERT_COMPUTE_FLOAT4(vload4(0, input+inp_offset)); COMPUTE_FLOAT4 in1 = CONVERT_COMPUTE_FLOAT4(vload4(1, input+inp_offset)); + + COMPUTE_FLOAT4 weights0 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset)); + COMPUTE_FLOAT4 weights1 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack)); + COMPUTE_FLOAT4 weights2 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack + out_c_pack)); + COMPUTE_FLOAT4 weights3 = CONVERT_COMPUTE_FLOAT4(vload4(0, kernel_ptr + offset + out_c_pack + out_c_pack + out_c_pack)); + + out0 = mad(in0.x, weights0, out0); + out0 = mad(in0.y, weights1, out0); + out0 = mad(in0.z, weights2, out0); + out0 = mad(in0.w, weights3, out0); - COMPUTE_FLOAT4 weights0 = CONVERT_COMPUTE_FLOAT4(vload4(offset, kernel_ptr)); - COMPUTE_FLOAT4 weights1 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 1, kernel_ptr)); - COMPUTE_FLOAT4 weights2 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 2, kernel_ptr)); - COMPUTE_FLOAT4 weights3 = CONVERT_COMPUTE_FLOAT4(vload4(offset + 3, kernel_ptr)); - - out0.x += dot(weights0, in0); - out0.y += dot(weights1, in0); - out0.z += dot(weights2, in0); - out0.w += dot(weights3, in0); - - out1.x += dot(weights0, in1); - out1.y += dot(weights1, in1); - out1.z += dot(weights2, in1); - out1.w += dot(weights3, in1); + out1 = mad(in1.x, weights0, out1); + out1 = mad(in1.y, weights1, out1); + out1 = mad(in1.z, weights2, out1); + out1 = mad(in1.w, weights3, out1); } #ifdef RELU diff --git a/source/backend/opencl/execution/cl/gemm_buf.cl b/source/backend/opencl/execution/cl/gemm_buf.cl index cd4c7fc94..903b62252 100644 --- a/source/backend/opencl/execution/cl/gemm_buf.cl +++ b/source/backend/opencl/execution/cl/gemm_buf.cl @@ -121,15 +121,17 @@ __kernel void gemm_buf2(GLOBAL_SIZE_DIM2 vstore4(CONVERT_FLOAT4(o1.scdef), 1, output+out_offset+12*width); } -// [M, K/4, 4] -> [alignK, alignM] +// [B, K/4, area, 4] -> [alignK, alignM] (M = B * area) __kernel void transpose_pad(GLOBAL_SIZE_DIM2 const int alignM, const int alignK, const int M, const int K, + const int area, __global const FLOAT* input, __global FLOAT* output ) { +#ifdef AREA_EQUAL_1 const int idx_m4 = get_global_id(0); // idx M const int idx_k4 = get_global_id(1); // idx K UNIFORM_BOUNDRY_CHECK(idx_m4, idx_k4); @@ -149,22 +151,75 @@ __kernel void transpose_pad(GLOBAL_SIZE_DIM2 vstore4((FLOAT4)(m0k4.y, m1k4.y, 
m2k4.y, m3k4.y), 0, output + out_offset_base + alignM); vstore4((FLOAT4)(m0k4.z, m1k4.z, m2k4.z, m3k4.z), 0, output + out_offset_base + alignM + alignM); vstore4((FLOAT4)(m0k4.w, m1k4.w, m2k4.w, m3k4.w), 0, output + out_offset_base + alignM + alignM + alignM); +#elif defined BATCH_EQUAL_1 + + const int idx_m4 = get_global_id(0); // idx M + const int idx_k4 = get_global_id(1); // idx K + UNIFORM_BOUNDRY_CHECK(idx_m4, idx_k4); + + const int idx_m = idx_m4 << 2; + const int idx_k = idx_k4 << 2; + const int K_4 = (K + 3) >> 2; + const int in_offset_base = (idx_k4 * area + idx_m) * 4; + const int out_offset_base = idx_k * alignM + idx_m; + + FLOAT4 m0k4 = (idx_k4 >= K_4 || idx_m + 0 >= M) ? (FLOAT4)0 : vload4(0, input + in_offset_base); + FLOAT4 m1k4 = (idx_k4 >= K_4 || idx_m + 1 >= M) ? (FLOAT4)0 : vload4(0, input + in_offset_base + 4); + FLOAT4 m2k4 = (idx_k4 >= K_4 || idx_m + 2 >= M) ? (FLOAT4)0 : vload4(0, input + in_offset_base + 8); + FLOAT4 m3k4 = (idx_k4 >= K_4 || idx_m + 3 >= M) ? (FLOAT4)0 : vload4(0, input + in_offset_base + 12); + + vstore4((FLOAT4)(m0k4.x, m1k4.x, m2k4.x, m3k4.x), 0, output + out_offset_base); + vstore4((FLOAT4)(m0k4.y, m1k4.y, m2k4.y, m3k4.y), 0, output + out_offset_base + alignM); + vstore4((FLOAT4)(m0k4.z, m1k4.z, m2k4.z, m3k4.z), 0, output + out_offset_base + alignM + alignM); + vstore4((FLOAT4)(m0k4.w, m1k4.w, m2k4.w, m3k4.w), 0, output + out_offset_base + alignM + alignM + alignM); + +#else + + const int idx_m = get_global_id(0); // idx M + const int idx_k4 = get_global_id(1); // idx K + UNIFORM_BOUNDRY_CHECK(idx_m, idx_k4); + + const int K_4 = (K + 3) >> 2; + const int idx_k = idx_k4 << 2; + const int out_offset_base = idx_k * alignM + idx_m; + + if(idx_k4 >= K_4 || idx_m >= M) { + output[out_offset_base] = (FLOAT)0; + output[out_offset_base + alignM] = (FLOAT)0; + output[out_offset_base + alignM + alignM] = (FLOAT)0; + output[out_offset_base + alignM + alignM + alignM] = (FLOAT)0; + return; + } + const int idx_b = idx_m / area; + const int idx_area = idx_m % area; + + const int in_offset_base = ((idx_b * K_4 + idx_k4) * area + idx_area) * 4; + FLOAT4 data = vload4(0, input + in_offset_base); + + output[out_offset_base] = data.x; + output[out_offset_base + alignM] = data.y; + output[out_offset_base + alignM + alignM] = data.z; + output[out_offset_base + alignM + alignM + alignM] = data.w; +#endif } -// [alignM, alignN] -> [M, N/4, 4] -__kernel void add_bias(GLOBAL_SIZE_DIM2 +// [alignM, alignN] -> [B, N/4, area, 4] (M = B * area) +__kernel void transpose_bias(GLOBAL_SIZE_DIM2 const int alignM, const int alignN, const int M, const int N, + const int area, __global const FLOAT* input0, __global const FLOAT* input1, __global FLOAT* output ) { +#ifdef AREA_EQUAL_1 const int idx_m = get_global_id(0); // idx M const int idx_n_16 = get_global_id(1); // idx N UNIFORM_BOUNDRY_CHECK(idx_m, idx_n_16); + const int N_4 = (N + 3) >> 2; const int N_16 = (N + 15) >> 4; const int N_left = N & 15; bool canVec16 = (N_left == 0 || (N_left != 0 && idx_n_16 < N_16 - 1)); @@ -172,31 +227,120 @@ __kernel void add_bias(GLOBAL_SIZE_DIM2 FLOAT16 res0 = vload16(0, input0 + idx_m * alignN + (idx_n_16 << 4)); FLOAT16 res1 = vload16(0, input1 + (idx_n_16 << 4)); FLOAT16 res = res0 + res1; - vstore16(res, 0, output + ((idx_m * N_16 + idx_n_16) << 4)); + #ifdef RELU + res = fmax(res, (FLOAT16)0); + #endif + #ifdef RELU6 + res = clamp(res, (FLOAT16)0, (FLOAT16)6); + #endif + vstore16(res, 0, output + ((idx_m * N_4 + (idx_n_16 << 2)) << 2)); } else { - const int N_4 = (N + 3) >> 2; FLOAT4 
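For the new general path of transpose_pad (area != 1 and batch != 1), the index math maps the [B, K/4, area, 4] input to a zero-padded [alignK, alignM] matrix with m = b * area + a. A host-side reference of that mapping:

// Illustrative sketch (not part of the patch).
#include <vector>

std::vector<float> transposePadRef(const std::vector<float>& in,
                                   int B, int K, int area, int alignM, int alignK) {
    const int K4 = (K + 3) / 4;
    const int M  = B * area;
    std::vector<float> out(static_cast<size_t>(alignK) * alignM, 0.0f);  // padding stays zero
    for (int m = 0; m < M; ++m) {
        const int b = m / area, a = m % area;
        for (int k = 0; k < K; ++k) {
            const int k4 = k / 4, kr = k % 4;
            out[static_cast<size_t>(k) * alignM + m] =
                in[(static_cast<size_t>(b) * K4 + k4) * area * 4 + a * 4 + kr];
        }
    }
    return out;
}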
res0 = vload4(0, input0 + idx_m * alignN + (idx_n_16 << 4)); FLOAT4 res1 = vload4(0, input1 + (idx_n_16 << 4)); FLOAT4 res = res0 + res1; - vstore4(res, 0, output + ((idx_m * N_16 + idx_n_16) << 4)); + #ifdef RELU + res = fmax(res, (FLOAT4)0); + #endif + #ifdef RELU6 + res = clamp(res, (FLOAT4)0, (FLOAT4)6); + #endif + vstore4(res, 0, output + ((idx_m * N_4 + (idx_n_16 << 2)) << 2)); if(idx_n_16 * 4 + 1 >= N_4) return; res0 = vload4(0, input0 + idx_m * alignN + (idx_n_16 << 4) + 4); res1 = vload4(0, input1 + (idx_n_16 << 4) + 4); res = res0 + res1; - vstore4(res, 0, output + ((idx_m * N_16 + idx_n_16) << 4) + 4); + #ifdef RELU + res = fmax(res, (FLOAT4)0); + #endif + #ifdef RELU6 + res = clamp(res, (FLOAT4)0, (FLOAT4)6); + #endif + vstore4(res, 0, output + ((idx_m * N_4 + (idx_n_16 << 2)) << 2) + 4); if(idx_n_16 * 4 + 2 >= N_4) return; res0 = vload4(0, input0 + idx_m * alignN + (idx_n_16 << 4) + 8); res1 = vload4(0, input1 + (idx_n_16 << 4) + 8); res = res0 + res1; - vstore4(res, 0, output + ((idx_m * N_16 + idx_n_16) << 4) + 8); + #ifdef RELU + res = fmax(res, (FLOAT4)0); + #endif + #ifdef RELU6 + res = clamp(res, (FLOAT4)0, (FLOAT4)6); + #endif + vstore4(res, 0, output + ((idx_m * N_4 + (idx_n_16 << 2)) << 2) + 8); if(idx_n_16 * 4 + 3 >= N_4) return; res0 = vload4(0, input0 + idx_m * alignN + (idx_n_16 << 4) + 12); res1 = vload4(0, input1 + (idx_n_16 << 4) + 12); res = res0 + res1; - vstore4(res, 0, output + ((idx_m * N_16 + idx_n_16) << 4) + 12); + #ifdef RELU + res = fmax(res, (FLOAT4)0); + #endif + #ifdef RELU6 + res = clamp(res, (FLOAT4)0, (FLOAT4)6); + #endif + vstore4(res, 0, output + ((idx_m * N_4 + (idx_n_16 << 2)) << 2) + 12); } +#else + const int idx_m = get_global_id(0); // idx M + const int idx_n_16 = get_global_id(1); // idx N + UNIFORM_BOUNDRY_CHECK(idx_m, idx_n_16); + + const int N_4 = (N + 3) >> 2; + + const int idx_b = idx_m / area; + const int idx_area = idx_m % area; + + const int inp_base_offset = idx_m * alignN + (idx_n_16 << 4); + const int out_base_offset = ((idx_b * N_4 + idx_n_16 * 4) * area + idx_area) * 4; + + FLOAT4 res0 = vload4(0, input0 + inp_base_offset); + FLOAT4 res1 = vload4(0, input1 + (idx_n_16 << 4)); + FLOAT4 res = res0 + res1; + #ifdef RELU + res = fmax(res, (FLOAT4)0); + #endif + #ifdef RELU6 + res = clamp(res, (FLOAT4)0, (FLOAT4)6); + #endif + vstore4(res, 0, output + out_base_offset); + + if(idx_n_16 * 4 + 1 >= N_4) return; + res0 = vload4(0, input0 + inp_base_offset + 4); + res1 = vload4(0, input1 + (idx_n_16 << 4) + 4); + res = res0 + res1; + #ifdef RELU + res = fmax(res, (FLOAT4)0); + #endif + #ifdef RELU6 + res = clamp(res, (FLOAT4)0, (FLOAT4)6); + #endif + vstore4(res, 0, output + out_base_offset + area * 4); + + if(idx_n_16 * 4 + 2 >= N_4) return; + res0 = vload4(0, input0 + inp_base_offset + 8); + res1 = vload4(0, input1 + (idx_n_16 << 4) + 8); + res = res0 + res1; + #ifdef RELU + res = fmax(res, (FLOAT4)0); + #endif + #ifdef RELU6 + res = clamp(res, (FLOAT4)0, (FLOAT4)6); + #endif + vstore4(res, 0, output + out_base_offset + area * 8); + + if(idx_n_16 * 4 + 3 >= N_4) return; + res0 = vload4(0, input0 + inp_base_offset + 12); + res1 = vload4(0, input1 + (idx_n_16 << 4) + 12); + res = res0 + res1; + #ifdef RELU + res = fmax(res, (FLOAT4)0); + #endif + #ifdef RELU6 + res = clamp(res, (FLOAT4)0, (FLOAT4)6); + #endif + vstore4(res, 0, output + out_base_offset + area * 12); +#endif } diff --git a/source/backend/opencl/execution/cl/matmul_params_buf.cl b/source/backend/opencl/execution/cl/matmul_params_buf.cl index 601dc50e6..ce1895d44 100644 
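Note on the gemm_buf.cl hunk above: transpose_pad/transpose_bias now take the extra area argument and are compiled in three variants — AREA_EQUAL_1 keeps the previous 4x4 vectorized path, BATCH_EQUAL_1 only re-bases the source offset, and the generic branch decomposes the flattened row index into batch and spatial coordinates. A minimal sketch of the generic source-offset computation (an illustrative helper, not a function in the patch):

    int transpose_pad_src_offset(int idx_m, int idx_k4, int K, int area) {
        const int K_4      = (K + 3) >> 2;   // number of packed K/4 blocks
        const int idx_b    = idx_m / area;   // batch index (M = B * area)
        const int idx_area = idx_m % area;   // position inside one feature map
        // same expression as in_offset_base in the generic #else branch above
        return ((idx_b * K_4 + idx_k4) * area + idx_area) * 4;
    }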
--- a/source/backend/opencl/execution/cl/matmul_params_buf.cl +++ b/source/backend/opencl/execution/cl/matmul_params_buf.cl @@ -5,11 +5,6 @@ // ================================================================================================= #define USE_INLINE_KEYWORD 1 -// Parameters set by the tuner or by the database. Here they are given a basic default value in case -// this kernel file is used outside of the CLBlast library. -#ifndef GEMMK - #define GEMMK 0 // Kernel to choose: 0 regular, 1 with 2D register tiling -#endif #ifndef MWG #define MWG 8 // Tile-size in dimension M (e.g. 64, 128) #endif @@ -17,7 +12,7 @@ #define NWG 8 // Tile-size in dimension N (e.g. 64, 128) #endif #ifndef KWG - #define KWG 8 // Tile-size in dimension K (e.g. 8, 16) + #define KWG 16 // Tile-size in dimension K (e.g. 8, 16) #endif #ifndef MDIMC #define MDIMC 8 // Threads per workgroup in M-dimension (e.g. 8, 16, 32) @@ -32,7 +27,7 @@ #define NDIMB 8 // Re-shaped tile dimension of matrix B: KDIMB * NDIMB (kernel 0 only) #endif #ifndef KWI - #define KWI 1 // Unroll factor of the KWG loop (smaller or equal than KWG) + #define KWI 2 // Unroll factor of the KWG loop (smaller or equal than KWG) #endif #ifndef VWM #define VWM 1 // Vector width of matrices A and C @@ -52,9 +47,6 @@ #ifndef SB #define SB 0 // Use local/shared memory to cache matrix B (1) or not (0) (kernel 0 only) #endif -#ifndef KREG - #define KREG 1 // Amount of register tiling in second dimension, multiple of VWN (kernel 1 only) -#endif // Helper parameters based on the above tuning parameters #define MWI (MWG/MDIMC) // Work per work-item (M-dimension) @@ -74,39 +66,6 @@ #define GLOBAL_MEM_FENCE 0 // Global synchronisation barrier for potential better performance #endif -// Half-precision -#if PRECISION == 16 - typedef half real; - typedef half2 real2; - typedef half4 real4; - typedef half8 real8; - typedef half16 real16; - #define ZERO 0 - #define ONE 1 - #define SMALLEST -1.0e14 - -// Single-precision -#elif PRECISION == 32 - typedef float real; - typedef float2 real2; - typedef float4 real4; - typedef float8 real8; - typedef float16 real16; - #define ZERO 0.0f - #define ONE 1.0f - #define SMALLEST -1.0e37f -#endif - -// Converts a 'real argument' value to a 'real' value as passed to the kernel. Normally there is no -// conversion, but half-precision is not supported as kernel argument so it is converted from float. 
-#if PRECISION == 16 - typedef float real_arg; - #define GetRealArg(x) (half)x -#else - typedef real real_arg; - #define GetRealArg(x) x -#endif - // Pointers to local memory objects (using a define because CUDA doesn't need them) #ifndef LOCAL_PTR #define LOCAL_PTR __local @@ -124,115 +83,18 @@ #define RELAX_WORKGROUP_SIZE 0 #endif +#define ZERO (FLOAT)0.0f // Sets a variable to zero -#if PRECISION == 3232 || PRECISION == 6464 - #define SetToZero(a) a.x = ZERO; a.y = ZERO -#else - #define SetToZero(a) a = ZERO -#endif - -// Sets a variable to zero (only the imaginary part) -#if PRECISION == 3232 || PRECISION == 6464 - #define ImagToZero(a) a.y = ZERO -#else - #define ImagToZero(a) -#endif - -// Sets a variable to one -#if PRECISION == 3232 || PRECISION == 6464 - #define SetToOne(a) a.x = ONE; a.y = ZERO -#else - #define SetToOne(a) a = ONE -#endif - -// Determines whether a variable is zero -#if PRECISION == 3232 || PRECISION == 6464 - #define IsZero(a) ((a.x == ZERO) && (a.y == ZERO)) -#else - #define IsZero(a) (a == ZERO) -#endif - -// The absolute value (component-wise) -#if PRECISION == 3232 || PRECISION == 6464 - #define AbsoluteValue(value) value.x = fabs(value.x); value.y = fabs(value.y) -#else - #define AbsoluteValue(value) value = fabs(value) -#endif - -// Negation (component-wise) -#if PRECISION == 3232 || PRECISION == 6464 - #define Negate(value) value.x = -(value.x); value.y = -(value.y) -#else - #define Negate(value) value = -(value) -#endif - -// Adds two complex variables -#if PRECISION == 3232 || PRECISION == 6464 - #define Add(c,a,b) c.x = a.x + b.x; c.y = a.y + b.y -#else - #define Add(c,a,b) c = a + b -#endif - -// Subtracts two complex variables -#if PRECISION == 3232 || PRECISION == 6464 - #define Subtract(c,a,b) c.x = a.x - b.x; c.y = a.y - b.y -#else - #define Subtract(c,a,b) c = a - b -#endif - -// Multiply two complex variables (used in the defines below) -#if PRECISION == 3232 || PRECISION == 6464 - #define MulReal(a,b) a.x*b.x - a.y*b.y - #define MulImag(a,b) a.x*b.y + a.y*b.x -#endif - -// The scalar multiply function -#if PRECISION == 3232 || PRECISION == 6464 - #define Multiply(c,a,b) c.x = MulReal(a,b); c.y = MulImag(a,b) -#else - #define Multiply(c,a,b) c = a * b -#endif - -// The scalar multiply-add function -#if PRECISION == 3232 || PRECISION == 6464 - #define MultiplyAdd(c,a,b) c.x += MulReal(a,b); c.y += MulImag(a,b) -#else - #if USE_CL_MAD == 1 - #define MultiplyAdd(c,a,b) c = mad(a, b, c) - #else - #define MultiplyAdd(c,a,b) c += a * b - #endif -#endif - -// The scalar multiply-subtract function -#if PRECISION == 3232 || PRECISION == 6464 - #define MultiplySubtract(c,a,b) c.x -= MulReal(a,b); c.y -= MulImag(a,b) -#else - #define MultiplySubtract(c,a,b) c -= a * b -#endif - -// The scalar division function: full division -#if PRECISION == 3232 || PRECISION == 6464 - #define DivideFull(c,a,b) singlereal num_x = (a.x * b.x) + (a.y * b.y); singlereal num_y = (a.y * b.x) - (a.x * b.y); singlereal denom = (b.x * b.x) + (b.y * b.y); c.x = num_x / denom; c.y = num_y / denom -#else - #define DivideFull(c,a,b) c = a / b -#endif - -// The scalar AXPBY function -#if PRECISION == 3232 || PRECISION == 6464 - #define AXPBY(e,a,b,c,d) e.x = MulReal(a,b) + MulReal(c,d); e.y = MulImag(a,b) + MulImag(c,d) -#else - #define AXPBY(e,a,b,c,d) e = a*b + c*d -#endif - -// The complex conjugate operation for complex transforms -#if PRECISION == 3232 || PRECISION == 6464 - #define COMPLEX_CONJUGATE(value) value.x = value.x; value.y = -value.y +#define SetToZero(a) a = ZERO 
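// Note (annotation, not part of the patch): with the CLBlast PRECISION switch removed, only the
// real-valued helpers remain. The complex (PRECISION == 3232/6464) macro variants are dropped as
// dead code, and ZERO/SetToZero/IsZero/Multiply/MultiplyAdd/AXPBY now act directly on the FLOAT
// type supplied by the backend's build options, with MultiplyAdd lowering to mad() when
// USE_CL_MAD is set.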
+#define IsZero(a) (a == ZERO) +#define Multiply(c,a,b) c = a * b +#if USE_CL_MAD == 1 +#define MultiplyAdd(c,a,b) c = mad(a, b, c) #else - #define COMPLEX_CONJUGATE(value) +#define MultiplyAdd(c,a,b) c += a * b #endif -// ================================================================================================= +#define AXPBY(e,a,b,c,d) e = a*b + c*d // Force inlining functions or not: some compilers don't support the inline keyword #ifdef USE_INLINE_KEYWORD @@ -241,39 +103,42 @@ #define INLINE_FUNC #endif -// ================================================================================================= - INLINE_FUNC int GetGroupID1() { return get_group_id(1); } - INLINE_FUNC int GetGroupID0() { return get_group_id(0); } +INLINE_FUNC int GetGroupID1() { return get_group_id(1); } +INLINE_FUNC int GetGroupID0() { return get_group_id(0); } // ================================================================================================= // End of the C++11 raw string literal +typedef float real_arg; +#define GetRealArg(x) (FLOAT)x +typedef FLOAT real; + // Data-widths in dimension M #if VWM == 1 - typedef real realM; + typedef FLOAT realM; #elif VWM == 2 - typedef real2 realM; + typedef FLOAT2 realM; #elif VWM == 4 - typedef real4 realM; + typedef FLOAT4 realM; #elif VWM == 8 - typedef real8 realM; + typedef FLOAT8 realM; #elif VWM == 16 - typedef real16 realM; + typedef FLOAT16 realM; #endif // Data-widths in dimension N #if VWN == 1 - typedef real realN; + typedef FLOAT realN; #elif VWN == 2 - typedef real2 realN; + typedef FLOAT2 realN; #elif VWN == 4 - typedef real4 realN; + typedef FLOAT4 realN; #elif VWN == 8 - typedef real8 realN; + typedef FLOAT8 realN; #elif VWN == 16 - typedef real16 realN; + typedef FLOAT16 realN; #endif // ================================================================================================= @@ -430,26 +295,89 @@ INLINE_FUNC void GlobalToLocalB(const __global realN* restrict bgm, LOCAL_PTR re // Caches global off-chip memory directly into per-thread private memory (registers). This function // is specific for caching the A input matrix. 
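In the new SA == 0 && SB == 0 fast path added below, the work-item's base index into A and B is hoisted out of the K loop: GlobalIndexA()/GlobalIndexB() run once, and GlobalToPrivateOptA()/GlobalToPrivateOptB() only add the cheap idk-dependent term per iteration. A schematic of how the pieces fit together (simplified sketch; the real XgemmBody blocks the K loop by 4 and orders the loops differently for OUTPUTMN):

    const int baseIndexA = GlobalIndexA();   // per-work-item base, computed once
    const int baseIndexB = GlobalIndexB();
    for (int idk = 0; idk < kSizeK; idk += 1) {
        realN bpm[NWI/VWN];
        for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
            bpm[_ni] = GlobalToPrivateOptB(bgm, baseIndexB, _ni, kSizeN, idk);
        }
        for (int _mi = 0; _mi < MWI/VWM; _mi += 1) {
            const realM aval = GlobalToPrivateOptA(agm, baseIndexA, _mi, kSizeM, idk);
            for (int _ni = 0; _ni < NWI/VWN; _ni += 1) {
                // accumulate cpn[(_mi*VWM + lane)*(NWI/VWN) + _ni] with MultiplyAddVectorN
            }
        }
    }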
-#if SA == 0 && GEMMK == 0 +#if SA == 0 +INLINE_FUNC int GlobalIndexA() { + // Computes the indices based on strided/non-strided access + #if STRM == 0 + // [MWG/MWI, MWI/VWM, VWM] + int mg = get_local_id(0)*(MWI/VWM); + #elif STRM == 1 + // [MWI/VWM, MWG/MWI, VWM] + int mg = get_local_id(0); + #endif + + // Computes the indices for the global memory + // [kSizeM/MWG, (MWG/VWM), VWM] + int idm = mg + GetGroupID0() * (MWG/VWM); + return idm; +} + +INLINE_FUNC realM GlobalToPrivateOptA(const __global realM* restrict agm, const int base, const int _mi, + const int kSizeM, const int idk) { + // Computes the indices based on strided/non-strided access + #if STRM == 0 + // [MWG/MWI, MWI/VWM, VWM] + int idm = base + _mi; + #elif STRM == 1 + // [MWI/VWM, MWG/MWI, VWM] + int idm = base + _mi*MDIMC; + #endif + + // Loads the data from global memory (not transposed) and stores into registers + // [kSizeK, kSizeM/VWM, VWM] + return agm[idk*(kSizeM/VWM) + idm]; +} + INLINE_FUNC realM GlobalToPrivateA(const __global realM* restrict agm, const int _mi, - const int kSizeM, const int idk, const int kwg) { + const int kSizeM, const int idk) { // Computes the indices based on strided/non-strided access #if STRM == 0 + // [MWG/MWI, MWI/VWM, VWM] int mg = _mi + get_local_id(0)*(MWI/VWM); #elif STRM == 1 + // [MWI/VWM, MWG/MWI, VWM] int mg = get_local_id(0) + _mi*MDIMC; #endif // Computes the indices for the global memory + // [kSizeM/MWG, (MWG/VWM), VWM] int idm = mg + GetGroupID0() * (MWG/VWM); // Loads the data from global memory (not transposed) and stores into registers + // [kSizeK, kSizeM/VWM, VWM] return agm[idk*(kSizeM/VWM) + idm]; } + #endif // Same as above, but now for the B input matrix -#if SB == 0 && GEMMK == 0 +#if SB == 0 +INLINE_FUNC int GlobalIndexB() { + // Computes the indices based on strided/non-strided access + #if STRN == 0 + int ng = get_local_id(1)*(NWI/VWN); + #elif STRN == 1 + int ng = get_local_id(1); + #endif + + // Computes the indices for the global memory + int idn = ng + GetGroupID1() * (NWG/VWN); + return idn; +} + +INLINE_FUNC realN GlobalToPrivateOptB(const __global realN* restrict bgm, const int base, const int _ni, + const int kSizeN, const int idk) { + // Computes the indices based on strided/non-strided access + #if STRN == 0 + int idn = base + _ni; + #elif STRN == 1 + int idn = base + _ni*NDIMC; + #endif + + // Loads the data from global memory (transposed) and stores into registers + return bgm[idk*(kSizeN/VWN) + idn]; +} + INLINE_FUNC realN GlobalToPrivateB(const __global realN* restrict bgm, const int _ni, const int kSizeN, const int idk) { // Computes the indices based on strided/non-strided access @@ -494,12 +422,14 @@ INLINE_FUNC realN LocalToPrivateB(LOCAL_PTR realN* blm, const int _ni, const int } #endif - - // The vectorised multiply-add function INLINE_FUNC realM MultiplyAddVector(realM cvec, const realM avec, const real bval) { #if USE_VECTOR_MAD == 1 + #if USE_CL_MAD == 1 + cvec = mad(avec, (realM)bval, cvec); + #else cvec += avec * bval; + #endif #else #if VWM == 1 MultiplyAdd(cvec, avec, bval); @@ -545,7 +475,11 @@ INLINE_FUNC realM MultiplyAddVector(realM cvec, const realM avec, const real bva // The vectorised multiply-add function INLINE_FUNC realN MultiplyAddVectorN(realN cvec, const real avec, const realN bval) { #if USE_VECTOR_MAD == 1 + #if USE_CL_MAD == 1 + cvec = mad((realN)avec, bval, cvec); + #else cvec += avec * bval; + #endif #else #if VWN == 1 MultiplyAdd(cvec, avec, bval); @@ -592,28 +526,51 @@ INLINE_FUNC realN MultiplyAddVectorN(realN 
cvec, const real avec, const realN bv // Merges the results in Cpm with the global array in Cgm. This also performs the multiplication // with the constants: Cgm = alpha*A*B + beta*Cgm = alpha*Cpm + beta*Cgm -// layout : [N, M] -INLINE_FUNC void StoreResultsM(__global realM* cgm, realM c_value, const int _mi, const int _ni, - const int kSizeM, const real alpha, const real beta) { + +typedef struct { + int index[2]; +} INT2; + +INLINE_FUNC INT2 StoreIndexM() { + INT2 res; #if STRM == 0 - int mg = _mi + get_local_id(0)*(MWI/VWM); + int mg = get_local_id(0)*(MWI/VWM); #elif STRM == 1 - int mg = get_local_id(0) + _mi*MDIMC; + int mg = get_local_id(0); #endif #if STRN == 0 - int ng = _ni + get_local_id(1)*NWI; + int ng = get_local_id(1)*NWI; #elif STRN == 1 - int ng = _ni%VWN + get_local_id(1)*VWN + (_ni/VWN)*VWN*NDIMC; + int ng = get_local_id(1)*VWN; #endif int idm = mg + GetGroupID0() * (MWG/VWM); int idn = ng + GetGroupID1() * NWG; + res.index[0] = idm; + res.index[1] = idn; + return res; +} + +// layout : [N, M] +INLINE_FUNC void StoreResultsM(__global realM* cgm, realM c_value, const INT2 baseOffset, const int _mi, const int _ni, + const int kSizeM, const real alpha, const real beta) { + #if STRM == 0 + int idm = _mi + baseOffset.index[0]; + #elif STRM == 1 + int idm = baseOffset.index[0] + _mi*MDIMC; + #endif + #if STRN == 0 + int idn = _ni + baseOffset.index[1]; + #elif STRN == 1 + int idn = _ni%VWN + baseOffset.index[1] + (_ni/VWN)*VWN*NDIMC; + #endif + int index = idn*(kSizeM/VWM) + idm; - realM result; - realM xval = c_value; + realM result = c_value; // The final multiplication with alpha (in case beta == 0) - if (IsZero(beta)) { + #ifdef ONLY_HAVE_ALPHA + realM xval = c_value; #if VWM == 1 Multiply(result, alpha, xval); #elif VWM == 2 @@ -651,10 +608,11 @@ INLINE_FUNC void StoreResultsM(__global realM* cgm, realM c_value, const int _mi Multiply(result.sE, alpha, xval.sE); Multiply(result.sF, alpha, xval.sF); #endif - } + #endif // The final multiplication with alpha and the addition with beta*C - else { + #ifdef HAVE_ALPHA_BETA + realM xval = c_value; realM yval = cgm[index]; #if VWM == 1 AXPBY(result, alpha, xval, beta, yval); @@ -693,39 +651,56 @@ INLINE_FUNC void StoreResultsM(__global realM* cgm, realM c_value, const int _mi AXPBY(result.sE, alpha, xval.sE, beta, yval.sE); AXPBY(result.sF, alpha, xval.sF, beta, yval.sF); #endif - } + #endif cgm[index] = result; } - +INLINE_FUNC INT2 StoreIndexN() { + INT2 res; + #if STRM == 0 + int mg = get_local_id(0)*MWI; + #elif STRM == 1 + int mg = get_local_id(0)*VWM; + #endif + #if STRN == 0 + int ng = get_local_id(1)*(NWI/VWN); + #elif STRN == 1 + int ng = get_local_id(1); + #endif + int idm = mg + GetGroupID0() * MWG; + int idn = ng + GetGroupID1() * (NWG/VWN); + + res.index[0] = idm; + res.index[1] = idn; + return res; +} // layout : [M, N] INLINE_FUNC void StoreResultsN(__global realN* cgn, realN c_value, + const INT2 baseOffset, #ifdef BIAS - __global realN* egn, + realN* epm, #endif const int _mi, const int _ni, const int kSizeN, const real alpha, const real beta) { - - + #if STRM == 0 - int mg = _mi + get_local_id(0)*MWI; + int idm = _mi + baseOffset.index[0]; #elif STRM == 1 - int mg = _mi%VWM + get_local_id(0)*VWM + (_mi/VWM)*VWM*MDIMC; + int idm = _mi%VWM + baseOffset.index[0] + (_mi/VWM)*VWM*MDIMC; #endif #if STRN == 0 - int ng = _ni + get_local_id(1)*(NWI/VWN); + int idn = _ni + baseOffset.index[1]; #elif STRN == 1 - int ng = get_local_id(1) + _ni*NDIMC; + int idn = baseOffset.index[1] + _ni*NDIMC; #endif - int idm = mg + 
GetGroupID0() * MWG; - int idn = ng + GetGroupID1() * (NWG/VWN); + int index = idm * (kSizeN/VWN) + idn; - realN result = 0; - realN xval = c_value; + realN result = c_value; // The final multiplication with alpha (in case beta == 0) - if (IsZero(beta)) { + #ifdef ONLY_HAVE_ALPHA + realN xval = c_value; #if VWN == 1 Multiply(result, alpha, xval); #elif VWN == 2 @@ -763,10 +738,11 @@ INLINE_FUNC void StoreResultsN(__global realN* cgn, realN c_value, Multiply(result.sE, alpha, xval.sE); Multiply(result.sF, alpha, xval.sF); #endif - } + #endif // The final multiplication with alpha and the addition with beta*C - else { + #ifdef HAVE_ALPHA_BETA + realN xval = c_value; realN yval = cgn[index]; #if VWN == 1 AXPBY(result, alpha, xval, beta, yval); @@ -805,48 +781,78 @@ INLINE_FUNC void StoreResultsN(__global realN* cgn, realN c_value, AXPBY(result.sE, alpha, xval.sE, beta, yval.sE); AXPBY(result.sF, alpha, xval.sF, beta, yval.sF); #endif - } + #endif #ifdef BIAS - realN xval = egn[idn]; + realN eval = epm[_ni]; #if VWN == 1 - result += xval; + result += eval; + #ifdef RELU + result = fmax(result, (FLOAT)0); + #endif + #ifdef RELU6 + result = clamp(result, (FLOAT)0, (FLOAT)6); + #endif #elif VWN == 2 - result.x += xval.x; - result.y += xval.y; + result.x += eval.x; + result.y += eval.y; + #ifdef RELU + result = fmax(result, (FLOAT2)0); + #endif + #ifdef RELU6 + result = clamp(result, (FLOAT2)0, (FLOAT2)6); + #endif #elif VWN == 4 - result.x += xval.x; - result.y += xval.y; - result.z += xval.z; - result.w += xval.w; + result.x += eval.x; + result.y += eval.y; + result.z += eval.z; + result.w += eval.w; + #ifdef RELU + result = fmax(result, (FLOAT4)0); + #endif + #ifdef RELU6 + result = clamp(result, (FLOAT4)0, (FLOAT4)6); + #endif #elif VWN == 8 - result.s0 += xval.s0; - result.s1 += xval.s1; - result.s2 += xval.s2; - result.s3 += xval.s3; - result.s4 += xval.s4; - result.s5 += xval.s5; - result.s6 += xval.s6; - result.s7 += xval.s7; + result.s0 += eval.s0; + result.s1 += eval.s1; + result.s2 += eval.s2; + result.s3 += eval.s3; + result.s4 += eval.s4; + result.s5 += eval.s5; + result.s6 += eval.s6; + result.s7 += eval.s7; + #ifdef RELU + result = fmax(result, (FLOAT8)0); + #endif + #ifdef RELU6 + result = clamp(result, (FLOAT8)0, (FLOAT8)6); + #endif #elif VWN == 16 - result.s0 += xval.s0; - result.s1 += xval.s1; - result.s2 += xval.s2; - result.s3 += xval.s3; - result.s4 += xval.s4; - result.s5 += xval.s5; - result.s6 += xval.s6; - result.s7 += xval.s7; - result.s8 += xval.s8; - result.s9 += xval.s9; - result.sA += xval.sA; - result.sB += xval.sB; - result.sC += xval.sC; - result.sD += xval.sD; - result.sE += xval.sE; - result.sF += xval.sF; + result.s0 += eval.s0; + result.s1 += eval.s1; + result.s2 += eval.s2; + result.s3 += eval.s3; + result.s4 += eval.s4; + result.s5 += eval.s5; + result.s6 += eval.s6; + result.s7 += eval.s7; + result.s8 += eval.s8; + result.s9 += eval.s9; + result.sA += eval.sA; + result.sB += eval.sB; + result.sC += eval.sC; + result.sD += eval.sD; + result.sE += eval.sE; + result.sF += eval.sF; + #ifdef RELU + result = fmax(result, (FLOAT16)0); + #endif + #ifdef RELU6 + result = clamp(result, (FLOAT16)0, (FLOAT16)6); + #endif #endif #endif @@ -869,13 +875,6 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, , LOCAL_PTR realN* blm #endif ) { - - // Allocates workitem-private memory (registers) - #pragma promote_to_registers - realM apm[MWI/VWM]; // MWI * 1 - #pragma promote_to_registers - realN bpm[NWI/VWN]; // 1 * NWI - #ifdef OUTPUTMN 
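// Accumulator tile layout for OUTPUTMN: cpn holds the MWI x NWI per-work-item result as realN
// vectors and is addressed below as cpn[(_mi*VWM + lane)*(NWI/VWN) + _ni], i.e. an
// [MWI/VWM, VWM, NWI/VWN, VWN] register block laid out in the output's [M, N] order.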
#pragma promote_to_registers realN cpn[MWI*(NWI/VWN)]; // MWI * NWI @@ -909,64 +908,182 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, #endif // Loops over all workgroup tiles - for (int kwg = 0; kwg < kSizeK; kwg += KWG * KREG) { - - // Loads data: off-chip --> local (matrix A) - #if SA == 1 - GlobalToLocalA(agm, alm, kSizeM, tid, kwg); - #endif - // Loads data: off-chip --> local (matrix B) - #if SB == 1 - GlobalToLocalB(bgm, blm, kSizeN, tid, kwg); - #endif - #if SA == 1 || SB == 1 - barrier(CLK_LOCAL_MEM_FENCE); - #endif - - // Loops over all workitem tiles, unrolled by a factor KWI - for (int pwi = 0; pwi < KWG * KREG; pwi += KWI * KREG) { - #pragma unroll - for (int _pit = 0; _pit < KWI*KREG; _pit += KREG) { - #if SA == 0 || SB == 0 - int idk = kwg + pwi + _pit; + #if SA == 1 || SB == 1 + // Allocates workitem-private memory (registers) + #pragma promote_to_registers + realM apm[MWI/VWM]; // MWI * 1 + #pragma promote_to_registers + realN bpm[NWI/VWN]; // 1 * NWI + + for (int kwg = 0; kwg < kSizeK; kwg += KWG) { + // Loads data: off-chip --> local (matrix A) + #if SA == 1 + GlobalToLocalA(agm, alm, kSizeM, tid, kwg); #endif - #if SA == 1 || SB == 1 - int kg = pwi + _pit; + // Loads data: off-chip --> local (matrix B) + #if SB == 1 + GlobalToLocalB(bgm, blm, kSizeN, tid, kwg); #endif + barrier(CLK_LOCAL_MEM_FENCE); + + // Loops over all workitem tiles, unrolled by a factor KWI + for (int pwi = 0; pwi < KWG; pwi += KWI) { + #pragma unroll + for (int _pit = 0; _pit < KWI; _pit += 1) { + #if SA == 0 || SB == 0 + int idk = kwg + pwi + _pit; + #endif + int kg = pwi + _pit; + + // Loads matrix A (kernel 0) or matrix B (kernel 1) + #pragma unroll + for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { + // Loads data: local --> private (matrix A) + #if SA == 1 + apm[_mi] = LocalToPrivateA(alm, _mi, kg); + // Loads data: off-chip --> private (matrix A) + #elif SA == 0 + apm[_mi] = GlobalToPrivateA(agm, _mi, kSizeM, idk); + #endif + } - // Loads matrix A (kernel 0) or matrix B (kernel 1) - #pragma unroll - for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { - // Loads data: local --> private (matrix A) - #if GEMMK == 0 && SA == 1 - apm[_mi] = LocalToPrivateA(alm, _mi, kg); - // Loads data: off-chip --> private (matrix A) - #elif GEMMK == 0 && SA == 0 - apm[_mi] = GlobalToPrivateA(agm, _mi, kSizeM, idk, kwg); - #endif - } + // Loads matrix B (kernel 0) or matrix A (kernel 1) - // Loads matrix B (kernel 0) or matrix A (kernel 1) + #pragma unroll + for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { + // Loads data: local --> private (matrix B) + #if SB == 1 + bpm[_ni] = LocalToPrivateB(blm, _ni, kg); + // Loads data: off-chip --> private (matrix B) + #else + bpm[_ni] = GlobalToPrivateB(bgm, _ni, kSizeN, idk); + #endif + } - #pragma unroll - for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { - // Loads data: local --> private (matrix B) - #if SB == 1 - bpm[_ni] = LocalToPrivateB(blm, _ni, kg); - // Loads data: off-chip --> private (matrix B) - #else - bpm[_ni] = GlobalToPrivateB(bgm, _ni, kSizeN, idk); - #endif + // Performs the accumulation (Cpm += Apm * Bpm) + + #ifdef OUTPUTMN + #pragma unroll + for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { + #pragma unroll + for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { + const realM aval = apm[_mi]; + #if VWM == 1 + // [MWI/VWM, VWM, NWI/VWN, VWN] + cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni], aval, bpm[_ni]); + #elif VWM == 2 + cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 
0)*(NWI/VWN) + _ni], aval.x, bpm[_ni]); + cpn[(_mi*VWM + 1)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 1)*(NWI/VWN) + _ni], aval.y, bpm[_ni]); + #elif VWM == 4 + cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni], aval.x, bpm[_ni]); + cpn[(_mi*VWM + 1)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 1)*(NWI/VWN) + _ni], aval.y, bpm[_ni]); + cpn[(_mi*VWM + 2)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 2)*(NWI/VWN) + _ni], aval.z, bpm[_ni]); + cpn[(_mi*VWM + 3)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 3)*(NWI/VWN) + _ni], aval.w, bpm[_ni]); + #elif VWM == 8 + cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni], aval.s0, bpm[_ni]); + cpn[(_mi*VWM + 1)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 1)*(NWI/VWN) + _ni], aval.s1, bpm[_ni]); + cpn[(_mi*VWM + 2)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 2)*(NWI/VWN) + _ni], aval.s2, bpm[_ni]); + cpn[(_mi*VWM + 3)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 3)*(NWI/VWN) + _ni], aval.s3, bpm[_ni]); + cpn[(_mi*VWM + 4)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 4)*(NWI/VWN) + _ni], aval.s4, bpm[_ni]); + cpn[(_mi*VWM + 5)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 5)*(NWI/VWN) + _ni], aval.s5, bpm[_ni]); + cpn[(_mi*VWM + 6)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 6)*(NWI/VWN) + _ni], aval.s6, bpm[_ni]); + cpn[(_mi*VWM + 7)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 7)*(NWI/VWN) + _ni], aval.s7, bpm[_ni]); + #elif VWM == 16 + cpn[(_mi*VWM + 0 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 0 )*(NWI/VWN) + _ni], aval.s0, bpm[_ni]); + cpn[(_mi*VWM + 1 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 1 )*(NWI/VWN) + _ni], aval.s1, bpm[_ni]); + cpn[(_mi*VWM + 2 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 2 )*(NWI/VWN) + _ni], aval.s2, bpm[_ni]); + cpn[(_mi*VWM + 3 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 3 )*(NWI/VWN) + _ni], aval.s3, bpm[_ni]); + cpn[(_mi*VWM + 4 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 4 )*(NWI/VWN) + _ni], aval.s4, bpm[_ni]); + cpn[(_mi*VWM + 5 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 5 )*(NWI/VWN) + _ni], aval.s5, bpm[_ni]); + cpn[(_mi*VWM + 6 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 6 )*(NWI/VWN) + _ni], aval.s6, bpm[_ni]); + cpn[(_mi*VWM + 7 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 7 )*(NWI/VWN) + _ni], aval.s7, bpm[_ni]); + cpn[(_mi*VWM + 8 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 8 )*(NWI/VWN) + _ni], aval.s8, bpm[_ni]); + cpn[(_mi*VWM + 9 )*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 9 )*(NWI/VWN) + _ni], aval.s9, bpm[_ni]); + cpn[(_mi*VWM + 10)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 10)*(NWI/VWN) + _ni], aval.sA, bpm[_ni]); + cpn[(_mi*VWM + 11)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 11)*(NWI/VWN) + _ni], aval.sB, bpm[_ni]); + cpn[(_mi*VWM + 12)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 12)*(NWI/VWN) + _ni], aval.sC, bpm[_ni]); + cpn[(_mi*VWM + 13)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 13)*(NWI/VWN) + _ni], aval.sD, bpm[_ni]); + cpn[(_mi*VWM + 14)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 14)*(NWI/VWN) + _ni], aval.sE, bpm[_ni]); + cpn[(_mi*VWM + 15)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 15)*(NWI/VWN) + _ni], aval.sF, bpm[_ni]); + #endif + } + } + #else + #pragma unroll + for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { + #pragma unroll + for (int _mi = 0; _mi 
< MWI/VWM; _mi += 1) { + const realM aval = apm[_mi]; + #if VWN == 1 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]); + #elif VWN == 2 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); + #elif VWN == 4 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); + cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z); + cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w); + #elif VWN == 8 + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1); + cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2); + cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3); + cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4); + cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5); + cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6); + cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7); + #elif VWN == 16 + cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0); + cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1); + cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2); + cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3); + cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4); + cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5); + cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6); + cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7); + cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8); + cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9); + cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA); + cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB); + cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC); + cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD); + cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE); + 
cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF); + #endif + } + } + #endif + } } + barrier(CLK_LOCAL_MEM_FENCE); + } + #else + // Allocates workitem-private memory (registers) - // Performs the accumulation (Cpm += Apm * Bpm) + int baseIndexA = GlobalIndexA(); + int baseIndexB = GlobalIndexB(); + #pragma unroll + for (int _kj = 0; _kj < kSizeK; _kj += 4) { #ifdef OUTPUTMN + #pragma promote_to_registers + realN bpm[NWI/VWN]; // 1 * NWI + + #pragma unroll + for(int _ki = 0; _ki < 4; _ki += 1) { + int idk = _kj + _ki; + #pragma unroll + for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { + // Loads data: off-chip --> private (matrix B) + bpm[_ni] = GlobalToPrivateOptB(bgm, baseIndexB, _ni, kSizeN, idk); + } + #pragma unroll for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { + const realM aval = GlobalToPrivateOptA(agm, baseIndexA, _mi, kSizeM, idk); #pragma unroll for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { - const realM aval = apm[_mi]; #if VWM == 1 // [MWI/VWM, VWM, NWI/VWN, VWN] cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni] = MultiplyAddVectorN(cpn[(_mi*VWM + 0)*(NWI/VWN) + _ni], aval, bpm[_ni]); @@ -1007,77 +1124,104 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, #endif } } + } #else + + #pragma promote_to_registers + realM apm[MWI/VWM]; // MWI * 1 + #pragma unroll + for(int _ki = 0; _ki < 4; _ki += 1) { + int idk = _kj + _ki; + #pragma unroll + for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { + // Loads data: off-chip --> private (matrix B) + apm[_mi] = GlobalToPrivateOptA(agm, baseIndexA, _mi, kSizeM, idk); + } #pragma unroll for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { + const realN bval = GlobalToPrivateOptB(bgm, baseIndexB, _ni, kSizeN, idk); + #pragma unroll for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { const realM aval = apm[_mi]; #if VWN == 1 - cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni]); + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bval); #elif VWN == 2 - cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); - cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bval.x); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bval.y); #elif VWN == 4 - cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].x); - cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].y); - cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].z); - cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].w); + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bval.x); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bval.y); + cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bval.z); + cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bval.w); #elif VWN == 8 - cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bpm[_ni].s0); - 
cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bpm[_ni].s1); - cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bpm[_ni].s2); - cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bpm[_ni].s3); - cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bpm[_ni].s4); - cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bpm[_ni].s5); - cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bpm[_ni].s6); - cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bpm[_ni].s7); + cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0)*(MWI/VWM) + _mi], aval, bval.s0); + cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1)*(MWI/VWM) + _mi], aval, bval.s1); + cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2)*(MWI/VWM) + _mi], aval, bval.s2); + cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3)*(MWI/VWM) + _mi], aval, bval.s3); + cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4)*(MWI/VWM) + _mi], aval, bval.s4); + cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5)*(MWI/VWM) + _mi], aval, bval.s5); + cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6)*(MWI/VWM) + _mi], aval, bval.s6); + cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7)*(MWI/VWM) + _mi], aval, bval.s7); #elif VWN == 16 - cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bpm[_ni].s0); - cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bpm[_ni].s1); - cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bpm[_ni].s2); - cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bpm[_ni].s3); - cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bpm[_ni].s4); - cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bpm[_ni].s5); - cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bpm[_ni].s6); - cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bpm[_ni].s7); - cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bpm[_ni].s8); - cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bpm[_ni].s9); - cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bpm[_ni].sA); - cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bpm[_ni].sB); - cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bpm[_ni].sC); - cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bpm[_ni].sD); - cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bpm[_ni].sE); - cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bpm[_ni].sF); + cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi] 
= MultiplyAddVector(cpm[(_ni*VWN + 0 )*(MWI/VWM) + _mi], aval, bval.s0); + cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 1 )*(MWI/VWM) + _mi], aval, bval.s1); + cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 2 )*(MWI/VWM) + _mi], aval, bval.s2); + cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 3 )*(MWI/VWM) + _mi], aval, bval.s3); + cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 4 )*(MWI/VWM) + _mi], aval, bval.s4); + cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 5 )*(MWI/VWM) + _mi], aval, bval.s5); + cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 6 )*(MWI/VWM) + _mi], aval, bval.s6); + cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 7 )*(MWI/VWM) + _mi], aval, bval.s7); + cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 8 )*(MWI/VWM) + _mi], aval, bval.s8); + cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 9 )*(MWI/VWM) + _mi], aval, bval.s9); + cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 10)*(MWI/VWM) + _mi], aval, bval.sA); + cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 11)*(MWI/VWM) + _mi], aval, bval.sB); + cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 12)*(MWI/VWM) + _mi], aval, bval.sC); + cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 13)*(MWI/VWM) + _mi], aval, bval.sD); + cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 14)*(MWI/VWM) + _mi], aval, bval.sE); + cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi] = MultiplyAddVector(cpm[(_ni*VWN + 15)*(MWI/VWM) + _mi], aval, bval.sF); #endif } } + } #endif } - } - #if SA == 1 || SB == 1 - barrier(CLK_LOCAL_MEM_FENCE); - #endif - } + #endif + #if GLOBAL_MEM_FENCE == 1 barrier(CLK_GLOBAL_MEM_FENCE); #endif #ifdef OUTPUTMN + INT2 baseOffset = StoreIndexN(); + #ifdef BIAS + #pragma promote_to_registers + realN epm[NWI/VWN]; // MWI * 1 + for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { + #if STRN == 0 + int idn = _ni + baseOffset.index[1]; + #elif STRN == 1 + int idn = baseOffset.index[1] + _ni*NDIMC; + #endif + epm[_ni] = egm[idn]; + } + #endif #pragma unroll for (int _mi = 0; _mi < MWI; _mi += 1) { #pragma unroll for (int _ni = 0; _ni < NWI/VWN; _ni += 1) { StoreResultsN((__global realN* )cgm, cpn[_mi * (NWI/VWN) + _ni], + baseOffset, #ifdef BIAS - egm, + (realN*)epm, #endif _mi, _ni, kSizeN, alpha, beta); } } #else - + INT2 baseOffset = StoreIndexM(); + // Stores an MWG * NWG tile of results and performs the multiplication with alpha and beta const int cld = kSizeM; @@ -1085,7 +1229,7 @@ INLINE_FUNC void XgemmBody(const int kSizeM, const int kSizeN, const int kSizeK, for (int _ni = 0; _ni < NWI; _ni += 1) { #pragma unroll for (int _mi = 0; _mi < MWI/VWM; _mi += 1) { - StoreResultsM(cgm, cpm[_ni * (MWI/VWM) + _mi], _mi, _ni, cld, alpha, beta); + StoreResultsM(cgm, cpm[_ni * (MWI/VWM) + _mi], baseOffset, _mi, _ni, cld, alpha, beta); } } #endif @@ -1160,20 +1304,28 @@ void Xgemm(const int kSizeM, const int kSizeN, const int kSizeK, void XgemmBatched(const int kSizeM, const int kSizeN, const int kSizeK, const real_arg arg_alpha, const real_arg arg_beta, - const __global realM* restrict agm, const int a_one, const int a_two, - const __global realN* restrict bgm, const int b_one, const int b_two, - __global realM* cgm, const int c_one, const int c_two) { + const __global realM* restrict agm, const int batch_offset_a, + const 
__global realN* restrict bgm, const int batch_offset_b, + #ifdef BIAS + const __global realN* restrict egm, const int batch_offset_e, + #endif + __global realM* cgm, const int batch_offset_c) { const int batch = get_group_id(2); const real alpha = GetRealArg(arg_alpha); const real beta = GetRealArg(arg_beta); // Sets the offsets - const int a_offset = batch * a_one * a_two; - const int b_offset = batch * b_one * b_two; - const int c_offset = batch * c_one * c_two; + const int a_offset = batch * batch_offset_a; + const int b_offset = batch * batch_offset_b; + const int c_offset = batch * batch_offset_c; const __global realM* restrict agm_ = &agm[a_offset / VWM]; const __global realN* restrict bgm_ = &bgm[b_offset / VWN]; __global realM* restrict cgm_ = &cgm[c_offset / VWM]; + + #ifdef BIAS + const int e_offset = batch * batch_offset_e; + const __global realN* restrict egm_ = &egm[e_offset / VWN]; + #endif // Allocates workgroup-private memory (local memory) #if SA == 1 @@ -1185,12 +1337,28 @@ void XgemmBatched(const int kSizeM, const int kSizeN, const int kSizeK, // Computes the matrix-multiplication and stores the result in global memory #if SA == 1 && SB == 1 - XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, alm, blm); + XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, + #ifdef BIAS + egm_, + #endif + cgm_, alpha, beta, alm, blm); #elif SA == 1 - XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, alm); + XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, + #ifdef BIAS + egm_, + #endif + cgm_, alpha, beta, alm); #elif SB == 1 - XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta, blm); + XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, + #ifdef BIAS + egm_, + #endif + cgm_, alpha, beta, blm); #else - XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, cgm_, alpha, beta); + XgemmBody(kSizeM, kSizeN, kSizeK, agm_, bgm_, + #ifdef BIAS + egm_, + #endif + cgm_, alpha, beta); #endif } diff --git a/source/backend/opencl/execution/cl/opencl_program.cc b/source/backend/opencl/execution/cl/opencl_program.cc index b72288012..f98f47417 100644 --- a/source/backend/opencl/execution/cl/opencl_program.cc +++ b/source/backend/opencl/execution/cl/opencl_program.cc @@ -1931,7 +1931,8 @@ const char* self_attention_buf = " }\n" " sum[lid]=maxValue;\n" " barrier(CLK_LOCAL_MEM_FENCE);\n" -" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i /= 2){\n" +" #pragma unroll\n" +" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i >>= 1){\n" " if (lid0; i /= 2){\n" +" #pragma unroll\n" +" for(int i=SOFTMAX_LOCAL_SIZE/2; i>0; i >>= 1){\n" " if (lid [N Y X]\n" "__kernel void trans_3d_buf(__global const FLOAT* input,\n" " __global FLOAT* output,\n" @@ -1984,50 +1974,28 @@ const char* self_attention_buf = ") {\n" " int b=get_global_id(2);\n" " \n" -" const int lidw=get_local_id(0);\n" -" const int lidh=get_local_id(1);\n" -" // group id\n" -" const int w=get_group_id(0)*WGSW;\n" -" const int h=get_group_id(1)*WGSH;\n" -" int iw=lidw;\n" -" int jh=lidh;\n" -" \n" -" __local FLOAT4 localData[WGSW][WGSH/4];//w64h64\n" -" \n" -" #pragma unroll\n" -" for(int i=0; i [alignK,alignM]\n" +"// [B,K/4,area,4] -> [alignK,alignM] (M=B*area)\n" "__kernel void transpose_pad(GLOBAL_SIZE_DIM2\n" " const int alignM,\n" " const int alignK,\n" " const int M,\n" " const int K,\n" +" const int area,\n" " __global const FLOAT* input,\n" " __global FLOAT* output\n" " ) {\n" +"#ifdef AREA_EQUAL_1\n" " const int idx_m4=get_global_id(0); // idx M\n" " const int idx_k4=get_global_id(1); // idx K\n" " UNIFORM_BOUNDRY_CHECK(idx_m4,idx_k4);\n" @@ 
-12066,20 +12030,67 @@ const char* gemm_buf = " vstore4((FLOAT4)(m0k4.y,m1k4.y,m2k4.y,m3k4.y),0,output+out_offset_base+alignM);\n" " vstore4((FLOAT4)(m0k4.z,m1k4.z,m2k4.z,m3k4.z),0,output+out_offset_base+alignM+alignM);\n" " vstore4((FLOAT4)(m0k4.w,m1k4.w,m2k4.w,m3k4.w),0,output+out_offset_base+alignM+alignM+alignM);\n" +"#elif defined BATCH_EQUAL_1\n" +" const int idx_m4=get_global_id(0); // idx M\n" +" const int idx_k4=get_global_id(1); // idx K\n" +" UNIFORM_BOUNDRY_CHECK(idx_m4,idx_k4);\n" +" const int idx_m=idx_m4 << 2;\n" +" const int idx_k=idx_k4 << 2;\n" +" const int K_4=(K+3) >> 2;\n" +" const int in_offset_base=(idx_k4*area+idx_m)*4;\n" +" const int out_offset_base=idx_k*alignM+idx_m;\n" +" FLOAT4 m0k4=(idx_k4 >= K_4 || idx_m+0 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base);\n" +" FLOAT4 m1k4=(idx_k4 >= K_4 || idx_m+1 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base+4);\n" +" FLOAT4 m2k4=(idx_k4 >= K_4 || idx_m+2 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base+8);\n" +" FLOAT4 m3k4=(idx_k4 >= K_4 || idx_m+3 >= M) ? (FLOAT4)0 : vload4(0,input+in_offset_base+12);\n" +" vstore4((FLOAT4)(m0k4.x,m1k4.x,m2k4.x,m3k4.x),0,output+out_offset_base);\n" +" vstore4((FLOAT4)(m0k4.y,m1k4.y,m2k4.y,m3k4.y),0,output+out_offset_base+alignM);\n" +" vstore4((FLOAT4)(m0k4.z,m1k4.z,m2k4.z,m3k4.z),0,output+out_offset_base+alignM+alignM);\n" +" vstore4((FLOAT4)(m0k4.w,m1k4.w,m2k4.w,m3k4.w),0,output+out_offset_base+alignM+alignM+alignM);\n" +"#else\n" +" const int idx_m=get_global_id(0); // idx M\n" +" const int idx_k4=get_global_id(1); // idx K\n" +" UNIFORM_BOUNDRY_CHECK(idx_m,idx_k4);\n" +" \n" +" const int K_4=(K+3) >> 2;\n" +" const int idx_k=idx_k4 << 2;\n" +" const int out_offset_base=idx_k*alignM+idx_m;\n" +" \n" +" if(idx_k4 >= K_4 || idx_m >= M) {\n" +" output[out_offset_base]=(FLOAT)0;\n" +" output[out_offset_base+alignM]=(FLOAT)0;\n" +" output[out_offset_base+alignM+alignM]=(FLOAT)0;\n" +" output[out_offset_base+alignM+alignM+alignM]=(FLOAT)0;\n" +" return;\n" +" }\n" +" const int idx_b=idx_m/area;\n" +" const int idx_area=idx_m % area;\n" +" \n" +" const int in_offset_base=((idx_b*K_4+idx_k4)*area+idx_area)*4;\n" +" FLOAT4 data=vload4(0,input+in_offset_base);\n" +" \n" +" output[out_offset_base]=data.x;\n" +" output[out_offset_base+alignM]=data.y;\n" +" output[out_offset_base+alignM+alignM]=data.z;\n" +" output[out_offset_base+alignM+alignM+alignM]=data.w;\n" +"#endif\n" "}\n" -"// [alignM,alignN] -> [M,N/4,4]\n" -"__kernel void add_bias(GLOBAL_SIZE_DIM2\n" +"// [alignM,alignN] -> [B,N/4,area,4] (M=B*area)\n" +"__kernel void transpose_bias(GLOBAL_SIZE_DIM2\n" " const int alignM,\n" " const int alignN,\n" " const int M,\n" " const int N,\n" +" const int area,\n" " __global const FLOAT* input0,\n" " __global const FLOAT* input1,\n" " __global FLOAT* output\n" " ) {\n" +"#ifdef AREA_EQUAL_1\n" " const int idx_m=get_global_id(0); // idx M\n" " const int idx_n_16=get_global_id(1); // idx N\n" " UNIFORM_BOUNDRY_CHECK(idx_m,idx_n_16);\n" +" const int N_4=(N+3) >> 2;\n" " const int N_16=(N+15) >> 4;\n" " const int N_left=N & 15;\n" " bool canVec16=(N_left == 0 || (N_left != 0 && idx_n_16> 2;\n" " FLOAT4 res0=vload4(0,input0+idx_m*alignN+(idx_n_16 << 4));\n" " FLOAT4 res1=vload4(0,input1+(idx_n_16 << 4));\n" " FLOAT4 res=res0+res1;\n" -" vstore4(res,0,output+((idx_m*N_16+idx_n_16) << 4));\n" +" #ifdef RELU\n" +" res=fmax(res,(FLOAT4)0);\n" +" #endif\n" +" #ifdef RELU6\n" +" res=clamp(res,(FLOAT4)0,(FLOAT4)6);\n" +" #endif\n" +" vstore4(res,0,output+((idx_m*N_4+(idx_n_16 << 2)) << 2));\n" " \n" 
" if(idx_n_16*4+1 >= N_4) return;\n" " res0=vload4(0,input0+idx_m*alignN+(idx_n_16 << 4)+4);\n" " res1=vload4(0,input1+(idx_n_16 << 4)+4);\n" " res=res0+res1;\n" -" vstore4(res,0,output+((idx_m*N_16+idx_n_16) << 4)+4);\n" +" #ifdef RELU\n" +" res=fmax(res,(FLOAT4)0);\n" +" #endif\n" +" #ifdef RELU6\n" +" res=clamp(res,(FLOAT4)0,(FLOAT4)6);\n" +" #endif\n" +" vstore4(res,0,output+((idx_m*N_4+(idx_n_16 << 2)) << 2)+4);\n" " \n" " if(idx_n_16*4+2 >= N_4) return;\n" " res0=vload4(0,input0+idx_m*alignN+(idx_n_16 << 4)+8);\n" " res1=vload4(0,input1+(idx_n_16 << 4)+8);\n" " res=res0+res1;\n" -" vstore4(res,0,output+((idx_m*N_16+idx_n_16) << 4)+8);\n" +" #ifdef RELU\n" +" res=fmax(res,(FLOAT4)0);\n" +" #endif\n" +" #ifdef RELU6\n" +" res=clamp(res,(FLOAT4)0,(FLOAT4)6);\n" +" #endif\n" +" vstore4(res,0,output+((idx_m*N_4+(idx_n_16 << 2)) << 2)+8);\n" " \n" " if(idx_n_16*4+3 >= N_4) return;\n" " res0=vload4(0,input0+idx_m*alignN+(idx_n_16 << 4)+12);\n" " res1=vload4(0,input1+(idx_n_16 << 4)+12);\n" " res=res0+res1;\n" -" vstore4(res,0,output+((idx_m*N_16+idx_n_16) << 4)+12);\n" +" #ifdef RELU\n" +" res=fmax(res,(FLOAT4)0);\n" +" #endif\n" +" #ifdef RELU6\n" +" res=clamp(res,(FLOAT4)0,(FLOAT4)6);\n" +" #endif\n" +" vstore4(res,0,output+((idx_m*N_4+(idx_n_16 << 2)) << 2)+12);\n" " }\n" +"#else\n" +" const int idx_m=get_global_id(0); // idx M\n" +" const int idx_n_16=get_global_id(1); // idx N\n" +" UNIFORM_BOUNDRY_CHECK(idx_m,idx_n_16);\n" +" \n" +" const int N_4=(N+3) >> 2;\n" +" const int idx_b=idx_m/area;\n" +" const int idx_area=idx_m % area;\n" +" \n" +" const int inp_base_offset=idx_m*alignN+(idx_n_16 << 4);\n" +" const int out_base_offset=((idx_b*N_4+idx_n_16*4)*area+idx_area)*4;\n" +" \n" +" FLOAT4 res0=vload4(0,input0+inp_base_offset);\n" +" FLOAT4 res1=vload4(0,input1+(idx_n_16 << 4));\n" +" FLOAT4 res=res0+res1;\n" +" #ifdef RELU\n" +" res=fmax(res,(FLOAT4)0);\n" +" #endif\n" +" #ifdef RELU6\n" +" res=clamp(res,(FLOAT4)0,(FLOAT4)6);\n" +" #endif\n" +" vstore4(res,0,output+out_base_offset);\n" +" \n" +" if(idx_n_16*4+1 >= N_4) return;\n" +" res0=vload4(0,input0+inp_base_offset+4);\n" +" res1=vload4(0,input1+(idx_n_16 << 4)+4);\n" +" res=res0+res1;\n" +" #ifdef RELU\n" +" res=fmax(res,(FLOAT4)0);\n" +" #endif\n" +" #ifdef RELU6\n" +" res=clamp(res,(FLOAT4)0,(FLOAT4)6);\n" +" #endif\n" +" vstore4(res,0,output+out_base_offset+area*4);\n" +" \n" +" if(idx_n_16*4+2 >= N_4) return;\n" +" res0=vload4(0,input0+inp_base_offset+8);\n" +" res1=vload4(0,input1+(idx_n_16 << 4)+8);\n" +" res=res0+res1;\n" +" #ifdef RELU\n" +" res=fmax(res,(FLOAT4)0);\n" +" #endif\n" +" #ifdef RELU6\n" +" res=clamp(res,(FLOAT4)0,(FLOAT4)6);\n" +" #endif\n" +" vstore4(res,0,output+out_base_offset+area*8);\n" +" \n" +" if(idx_n_16*4+3 >= N_4) return;\n" +" res0=vload4(0,input0+inp_base_offset+12);\n" +" res1=vload4(0,input1+(idx_n_16 << 4)+12);\n" +" res=res0+res1;\n" +" #ifdef RELU\n" +" res=fmax(res,(FLOAT4)0);\n" +" #endif\n" +" #ifdef RELU6\n" +" res=clamp(res,(FLOAT4)0,(FLOAT4)6);\n" +" #endif\n" +" vstore4(res,0,output+out_base_offset+area*12);\n" +"#endif\n" "}\n" ; #endif @@ -14956,7 +15055,8 @@ const char* conv_2d_buf = " __private const int in_c_block,\n" " __private const int out_h,\n" " __private const int out_w,\n" -" __private const int out_c_block) {\n" +" __private const int out_c_block,\n" +" __private const int out_c_pack) {\n" " const int out_c_w_idx=get_global_id(0); //c/4 w\n" " const int out_b_h_idx=get_global_id(1); //b h\n" " DEAL_NON_UNIFORM_DIM2(out_c_w_idx,out_b_h_idx);\n" @@ -14971,37 +15071,42 @@ 
const char* conv_2d_buf = " COMPUTE_FLOAT4 out3=out0;\n" " const int intput_width_idx0=out_w4_idx;\n" " \n" -" int offset=mul24(out_c_idx,in_c_block) << 2;\n" +" int offset=out_c_idx*4;\n" " int inp_offset=(((out_b_idx*in_c_block)*out_h+out_h_idx)* out_w+intput_width_idx0) << 2;\n" " \n" " const int inp_add=out_h*out_w*4;\n" " for (ushort in_channel_block_idx=0; in_channel_block_idx local (matrix A)\n" " #if SA == 1\n" " GlobalToLocalA(agm,alm,kSizeM,tid,kwg);\n" @@ -20101,28 +20226,24 @@ const char* matmul_params_buf = " #if SB == 1\n" " GlobalToLocalB(bgm,blm,kSizeN,tid,kwg);\n" " #endif\n" -" #if SA == 1 || SB == 1\n" " barrier(CLK_LOCAL_MEM_FENCE);\n" -" #endif\n" " // Loops over all workitem tiles,unrolled by a factor KWI\n" -" for (int pwi=0; pwi private (matrix A)\n" -" #if GEMMK == 0 && SA == 1\n" +" #if SA == 1\n" " apm[_mi]=LocalToPrivateA(alm,_mi,kg);\n" " // Loads data: off-chip --> private (matrix A)\n" -" #elif GEMMK == 0 && SA == 0\n" -" apm[_mi]=GlobalToPrivateA(agm,_mi,kSizeM,idk,kwg);\n" +" #elif SA == 0\n" +" apm[_mi]=GlobalToPrivateA(agm,_mi,kSizeM,idk);\n" " #endif\n" " }\n" " // Loads matrix B (kernel 0) or matrix A (kernel 1)\n" @@ -20231,28 +20352,166 @@ const char* matmul_params_buf = " #endif\n" " }\n" " }\n" -" #if SA == 1 || SB == 1\n" " barrier(CLK_LOCAL_MEM_FENCE);\n" +" }\n" +" #else\n" +" // Allocates workitem-private memory (registers)\n" +" int baseIndexA=GlobalIndexA();\n" +" int baseIndexB=GlobalIndexB();\n" +" #pragma unroll\n" +" for (int _kj=0; _kj private (matrix B)\n" +" bpm[_ni]=GlobalToPrivateOptB(bgm,baseIndexB,_ni,kSizeN,idk);\n" +" }\n" +" #pragma unroll\n" +" for (int _mi=0; _mi private (matrix B)\n" +" apm[_mi]=GlobalToPrivateOptA(agm,baseIndexA,_mi,kSizeM,idk);\n" +" }\n" +" #pragma unroll\n" +" for (int _ni=0; _ni 0; i /= 2){ + #pragma unroll + for(int i = SOFTMAX_LOCAL_SIZE/2; i > 0; i >>= 1){ if (lid < i) sum[lid] = fmax(sum[lid], sum[lid + i]); barrier(CLK_LOCAL_MEM_FENCE); @@ -225,7 +226,8 @@ __kernel void softmax_inside(GLOBAL_SIZE_3_DIMS } sum[lid] = sumValue; barrier(CLK_LOCAL_MEM_FENCE); - for(int i = SOFTMAX_LOCAL_SIZE/2; i > 0; i /= 2){ + #pragma unroll + for(int i = SOFTMAX_LOCAL_SIZE/2; i > 0; i >>= 1){ if (lid < i) sum[lid] = sum[lid] + sum[lid + i]; barrier(CLK_LOCAL_MEM_FENCE); @@ -245,18 +247,7 @@ __kernel void softmax_inside(GLOBAL_SIZE_3_DIMS #endif } } -#ifndef WGSW - #define WGSW 64 // work-group handle size W dimension -#endif -#ifndef WGSH - #define WGSH 64 // work-group handle size H dimension -#endif -#ifndef TSW - #define TSW 8 // thread handle size W dimension -#endif -#ifndef TSH - #define TSH 8 // thread handle size H dimension -#endif + // [N X Y4 4] -> [N Y X] __kernel void trans_3d_buf(__global const FLOAT* input, __global FLOAT* output, @@ -266,56 +257,31 @@ __kernel void trans_3d_buf(__global const FLOAT* input, ) { int b = get_global_id(2); - const int lidw = get_local_id(0); - const int lidh = get_local_id(1); - // group id - const int w = get_group_id(0) * WGSW; - const int h = get_group_id(1) * WGSH; - - int iw = lidw; - int jh = lidh; - - __local FLOAT4 localData[WGSW][WGSH/4];//w64h64 - - #pragma unroll - for(int i = 0; i < TSW; i++) { - int offset_w = i * WGSW / TSW + iw; - #pragma unroll - for(int j = 0; j < TSH / 4; j++) { - int offset_h = j * WGSH / TSH + jh; - // [TSW, WGSW / TSW] [TSH / 4, WGSH / TSH, 4] - localData[offset_w][offset_h] = vload4(0, input + ((b * width + (w+offset_w)) * height/4 + (h/4+offset_h)) * 4); - } - } - - barrier(CLK_LOCAL_MEM_FENCE); + const int w = get_global_id(0) << 3; 
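// The rewritten trans_3d_buf drops the __local staging tile and its barriers: each work-item now
// transposes an 8x8 block held in registers, reading eight FLOAT8 vectors along the height axis
// (one per consecutive width position) and writing them back transposed along the width axis
// with vstore8.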
+ const int h = get_global_id(1) << 3; - // H offset: [WGSH / TSH, TSH / 4, 4] - // W offset: [WGSW / TSW, TSW / 4, 4] - int oh_base = jh * TSH / 4; - int ow_base = iw * TSW / 4; + const int inp_offset = (b * width + w) * height + h; + const int out_offset = (b * height + h) * width + w; - //#pragma unroll - for(int j = 0; j < TSH / 4; j++) { - int oh = oh_base + j; - - //#pragma unroll - for(int i = 0; i < TSW / 4; i++) { - int ow = ow_base + i; - - FLOAT4 value_0 = (localData[4*ow][oh]); - FLOAT4 value_1 = (localData[4*ow+1][oh]); - FLOAT4 value_2 = (localData[4*ow+2][oh]); - FLOAT4 value_3 = (localData[4*ow+3][oh]); - vstore4((FLOAT4){value_0.x, value_1.x, value_2.x, value_3.x}, 0, output + ((b * height + h + 4*oh+0) * width + w + 4 * ow)); - vstore4((FLOAT4){value_0.y, value_1.y, value_2.y, value_3.y}, 0, output + ((b * height + h + 4*oh+1) * width + w + 4 * ow)); - vstore4((FLOAT4){value_0.z, value_1.z, value_2.z, value_3.z}, 0, output + ((b * height + h + 4*oh+2) * width + w + 4 * ow)); - vstore4((FLOAT4){value_0.w, value_1.w, value_2.w, value_3.w}, 0, output + ((b * height + h + 4*oh+3) * width + w + 4 * ow)); - } - } + FLOAT8 value_0 = vload8(0, input+inp_offset); + FLOAT8 value_1 = vload8(0, input+inp_offset + height); + FLOAT8 value_2 = vload8(0, input+inp_offset + height + height); + FLOAT8 value_3 = vload8(0, input+inp_offset + height + height + height); + FLOAT8 value_4 = vload8(0, input+inp_offset + (height << 2)); + FLOAT8 value_5 = vload8(0, input+inp_offset + height * 5); + FLOAT8 value_6 = vload8(0, input+inp_offset + height * 6); + FLOAT8 value_7 = vload8(0, input+inp_offset + height * 7); + + vstore8((FLOAT8){value_0.s0, value_1.s0, value_2.s0, value_3.s0, value_4.s0, value_5.s0, value_6.s0, value_7.s0}, 0, output + out_offset); + vstore8((FLOAT8){value_0.s1, value_1.s1, value_2.s1, value_3.s1, value_4.s1, value_5.s1, value_6.s1, value_7.s1}, 0, output + out_offset + width); + vstore8((FLOAT8){value_0.s2, value_1.s2, value_2.s2, value_3.s2, value_4.s2, value_5.s2, value_6.s2, value_7.s2}, 0, output + out_offset + width + width); + vstore8((FLOAT8){value_0.s3, value_1.s3, value_2.s3, value_3.s3, value_4.s3, value_5.s3, value_6.s3, value_7.s3}, 0, output + out_offset + width + width + width); + vstore8((FLOAT8){value_0.s4, value_1.s4, value_2.s4, value_3.s4, value_4.s4, value_5.s4, value_6.s4, value_7.s4}, 0, output + out_offset + (width << 2)); + vstore8((FLOAT8){value_0.s5, value_1.s5, value_2.s5, value_3.s5, value_4.s5, value_5.s5, value_6.s5, value_7.s5}, 0, output + out_offset + width * 5); + vstore8((FLOAT8){value_0.s6, value_1.s6, value_2.s6, value_3.s6, value_4.s6, value_5.s6, value_6.s6, value_7.s6}, 0, output + out_offset + width * 6); + vstore8((FLOAT8){value_0.s7, value_1.s7, value_2.s7, value_3.s7, value_4.s7, value_5.s7, value_6.s7, value_7.s7}, 0, output + out_offset + width * 7); } - __kernel void clip_transpose_qkv(GLOBAL_SIZE_3_DIMS __global const FLOAT *input, // [Batch * mNumHead, ROUND_UP(mHeadDim, tile), ROUND_UP(seqLen, tile)] __global FLOAT *output, // [Batch, seqLen/4, mNumHead * mHeadDim, 4] diff --git a/source/backend/opencl/execution/cl/winogradTransform_buf.cl b/source/backend/opencl/execution/cl/winogradTransform_buf.cl index efb799643..0caf484b0 100644 --- a/source/backend/opencl/execution/cl/winogradTransform_buf.cl +++ b/source/backend/opencl/execution/cl/winogradTransform_buf.cl @@ -111,9 +111,6 @@ __kernel void winoTransSrcBuf2_3_1(GLOBAL_SIZE_DIM2 int batchIndex = pos.y / srcChannelC4; int srcZ = pos.y % srcChannelC4; int dstYOrigin = 
unitWidth * unitHeight_idx + unitWidth_idx; - int dstHeight = (unitWidth * unitHeight + 3) / 4; - int dstY = dstYOrigin / 4; - int dstX = dstYOrigin % 4 + 4 * dstXOrigin; batchIndex = batchOffset; { @@ -405,10 +402,7 @@ __kernel void winoTransDstBuf2_3_1(GLOBAL_SIZE_DIM2 int unitWidth_idx = pos.x % unitWidth; int unitHeight_idx = pos.x / unitWidth; int2 realPos = (int2)(unitWidth_idx, unitHeight_idx); - int srcWidth = (unitWidth * unitHeight + 3) / 4; int dstXOrigin = unitWidth * unitHeight_idx + unitWidth_idx; - int dstX = dstXOrigin / 4; - int dstY = 4 * pos.y + dstXOrigin % 4; int oz = pos.y % dstChannelC4; FLOAT4 bias = vload4(0, uBias+oz*4); diff --git a/source/core/Backend.hpp b/source/core/Backend.hpp index 111f09877..2605047ec 100644 --- a/source/core/Backend.hpp +++ b/source/core/Backend.hpp @@ -25,6 +25,21 @@ class Execution; class Runtime; class Backend; +struct RuntimeHint { + // 0: Defer, 1: Eager + int memoryAllocatorType = 0; + int winogradMemoryUsed = 3; + + // 0-100, 50 means litter core has 50% capacity of large core + int cpuDecreaseRate = 50; + int dynamicQuantOption = 0; + + // 0: Do not quantize kvcache, just store float + // 1: Only quantize key cache, use int8 asymmetric quantization + // 2: Only quantize value cache, use fp8 quantization + // 3: quantize both key and value cache as described above + int kvcacheQuantOption = 0; +}; /** abstract backend */ class Backend : public NonCopyable { @@ -48,11 +63,6 @@ class Backend : public NonCopyable { INDIRECT = 1 }; Mode mode = DIRECT; - enum Allocator { - DEFER = 0, - EAGER = 1 - }; - Allocator allocator = DEFER; }; /** backend buffer storage type */ @@ -232,21 +242,11 @@ class Runtime : public NonCopyable { Allocator_Defer = 0, Allocator_Eager = 1, }; - - void setWinogradMemoryLevel(int level) { - mWinogradMemoryLevel = level; - } - - int getWinogradMemoryLevel() const { - return mWinogradMemoryLevel; - } - - void setAllocatorType(int type) { - mAllocatorType = static_cast(type); + void setRuntimeHint(const RuntimeHint& hint) { + mHint = hint; } - - AllocatorType getAllocatorType() const { - return mAllocatorType; + const RuntimeHint& hint() const { + return mHint; } virtual CompilerType onGetCompilerType() const { @@ -260,6 +260,13 @@ class Runtime : public NonCopyable { */ virtual Backend* onCreate(const BackendConfig* config = nullptr) const = 0; + /** + @brief reset runtime + */ + virtual void onReset(int numberThread, const BackendConfig* config) { + // Do nothing + } + /** @brief clear unuseful resource @param level clear level: 0 - 100, bigger mean clear more, smaller mean cache more @@ -319,8 +326,7 @@ class Runtime : public NonCopyable { MNN_PUBLIC void waitAsyncWork(); private: std::future mFuture; - AllocatorType mAllocatorType = Allocator_Eager; - int mWinogradMemoryLevel = 3; + RuntimeHint mHint; }; /** abstract Runtime register */ diff --git a/source/core/Concurrency.h b/source/core/Concurrency.h index a3c06622f..08887eb9c 100644 --- a/source/core/Concurrency.h +++ b/source/core/Concurrency.h @@ -26,7 +26,7 @@ } \ ; \ auto cpuBn = (CPUBackend*)backend(); \ - MNN::ThreadPool::enqueue(std::move(task), cpuBn->taskIndex()); \ + MNN::ThreadPool::enqueue(std::move(task), cpuBn->taskIndex(), cpuBn->threadOpen() ? 
cpuBn->threadNumber() : 1); \ } #else diff --git a/source/core/ConvolutionCommon.cpp b/source/core/ConvolutionCommon.cpp index 6a333f0fa..2418bd211 100644 --- a/source/core/ConvolutionCommon.cpp +++ b/source/core/ConvolutionCommon.cpp @@ -9,6 +9,7 @@ #include "ConvolutionCommon.hpp" #include #include "backend/cpu/compute/CommonOptFunction.h" +#include "backend/cpu/CPUBackend.hpp" #include "half.hpp" #include "core/OpCommonUtils.hpp" #include "core/IDSTDecoder.hpp" @@ -187,16 +188,18 @@ void ConvolutionCommon::getConvParameters(std::shared_ptr *quanCommo } bool ConvolutionCommon::getConvInt8Parameters(const MNN::Convolution2D* conv2d, std::shared_ptr& quanCommon, Backend* backend, - const int8_t*& weight, int& weightSize, float*& scale, int32_t*& bias) { + const int8_t*& weight, int& weightSize, float*& scale, int32_t*& bias, int32_t*& weightQuantZeroPoint) { int outputCount = conv2d->common()->outputCount(); weightSize = 0; + auto core = static_cast(backend)->functions(); // fix xcode UndefinedBehaviorSanitizer - if (conv2d->symmetricQuan()->weight() != nullptr) { + if (conv2d->symmetricQuan() && conv2d->symmetricQuan()->weight() != nullptr) { weight = conv2d->symmetricQuan()->weight()->data(); weightSize = conv2d->symmetricQuan()->weight()->size(); } - if (conv2d->quanParameter() && conv2d->quanParameter()->buffer()) { + if (conv2d->quanParameter() && conv2d->quanParameter()->buffer()) { // int8 weight quanCommon = ConvolutionCommon::load(conv2d, backend, false, true); + MNN_ASSERT(quanCommon != nullptr); weight = quanCommon->weight.get(); weightSize = quanCommon->weight.size(); } @@ -204,16 +207,47 @@ bool ConvolutionCommon::getConvInt8Parameters(const MNN::Convolution2D* conv2d, MNN_ERROR("ConvolutionCommon::getConvInt8Parameters: No weight data!"); return false; } - if (conv2d->symmetricQuan()->bias() && conv2d->symmetricQuan()->scale()) { + bool weightAsy = false; + if (quanCommon && quanCommon->asymmetric) { + weightAsy = true; + } + if (conv2d->symmetricQuan() && conv2d->symmetricQuan()->bias() && conv2d->symmetricQuan()->scale()) { // Compability for old model MNN_ASSERT(conv2d->symmetricQuan()->bias()->size() == outputCount && conv2d->symmetricQuan()->scale()->size() == outputCount); ::memcpy(bias, conv2d->symmetricQuan()->bias()->data(), outputCount * sizeof(int32_t)); ::memcpy(scale, conv2d->symmetricQuan()->scale()->data(), outputCount * sizeof(float)); return true; } - if (conv2d->bias() && conv2d->quanParameter()->alpha()) { + if (conv2d->bias()) { ::memcpy(bias, conv2d->bias()->data(), outputCount * sizeof(float)); - ::memcpy(scale, conv2d->quanParameter()->alpha()->data(), outputCount * sizeof(float)); + } + if (conv2d->quanParameter() && conv2d->quanParameter()->alpha()) { + auto alphaAndBeta = conv2d->quanParameter()->alpha()->data(); + int quantCount = conv2d->quanParameter()->alpha()->size(); + if (false == weightAsy) { // symmetric quant + if (core->bytes == 2) { + core->MNNFp32ToLowp(quanCommon->alpha.get(), reinterpret_cast(scale), quantCount); + } else { + ::memcpy(scale, conv2d->quanParameter()->alpha()->data(), quantCount * core->bytes); + } + } else if (true == weightAsy) { // asymmetric + // int ocx2 = 2 * outputCount; + int scaleSize = quantCount / 2; + float clampMin = conv2d->quanParameter()->aMin() == 0 ? 
-128 : conv2d->quanParameter()->aMin(); + if (core->bytes == 2) { + std::unique_ptr tmp(new int16_t[quantCount]); + core->MNNFp32ToLowp(alphaAndBeta, tmp.get(), quantCount); + for (int i = 0; i < scaleSize; ++i) { + weightQuantZeroPoint[i] = static_cast(roundf((-1) * tmp[2 * i] / tmp[2 * i + 1]) + clampMin); + reinterpret_cast(scale)[i] = tmp[2 * i + 1]; + } + } else { + for (int i = 0; i < scaleSize; ++i) { + weightQuantZeroPoint[i] = static_cast(roundf((-1) * alphaAndBeta[2 * i] / alphaAndBeta[2 * i + 1]) + clampMin); + scale[i] = alphaAndBeta[2 * i + 1]; + } + } + } return true; } MNN_ERROR("ConvolutionCommon::getConvInt8Parameters: No bias & scale data!"); diff --git a/source/core/ConvolutionCommon.hpp b/source/core/ConvolutionCommon.hpp index a61daa38f..28e3acf83 100644 --- a/source/core/ConvolutionCommon.hpp +++ b/source/core/ConvolutionCommon.hpp @@ -27,7 +27,7 @@ class MNN_PUBLIC ConvolutionCommon : public Execution { static std::shared_ptr load(const Convolution2D* conv, Backend* backend = nullptr, bool forceFloat = false, bool forceInt8 = false); static void getConvParameters(std::shared_ptr *quanCommon, Backend* backend, const MNN::Convolution2D *conv2d, const float** originWeight, int* originWeightSize); static bool getConvInt8Parameters(const MNN::Convolution2D* conv2d, std::shared_ptr& quanCommon, Backend* backend, - const int8_t*& weight, int& weightSize, float*& scale, int32_t*& bias); + const int8_t*& weight, int& weightSize, float*& scale, int32_t*& bias, int32_t*& weightQuantZero); // Return padX, padY static std::pair convolutionPad(const Tensor* input, const Tensor* output, diff --git a/source/core/FileLoader.cpp b/source/core/FileLoader.cpp index 021dcc5de..1b183ea5e 100644 --- a/source/core/FileLoader.cpp +++ b/source/core/FileLoader.cpp @@ -11,7 +11,7 @@ #include "Windows.h" #endif namespace MNN { -static FILE* _OpenFile(const char* file) { +static FILE* _OpenFile(const char* file, bool read) { #if defined(_MSC_VER) wchar_t wFilename[1024]; if (0 == MultiByteToWideChar(CP_ACP, 0, file, -1, wFilename, sizeof(wFilename))) { @@ -19,16 +19,31 @@ static FILE* _OpenFile(const char* file) { } #if _MSC_VER >= 1400 FILE* mFile = nullptr; - if (0 != _wfopen_s(&mFile, wFilename, L"rb")) { - return nullptr; + if (read) { + if (0 != _wfopen_s(&mFile, wFilename, L"rb")) { + return nullptr; + } + } else { + if (0 != _wfopen_s(&mFile, wFilename, L"wb")) { + return nullptr; + } } return mFile; #else - return _wfopen(wFilename, L"rb"); + if (read) { + return _wfopen(wFilename, L"rb"); + } else { + return _wfopen(wFilename, L"wb"); + } #endif #else - return fopen(file, "rb"); + if (read) { + return fopen(file, "rb"); + } else { + return fopen(file, "wb"); + } #endif + return nullptr; } FileLoader::FileLoader(const char* file, bool init) { if (nullptr == file) { @@ -86,7 +101,7 @@ bool FileLoader::read() { } bool FileLoader::write(const char* filePath, std::pair cacheInfo) { - FILE* f = fopen(filePath, "wb"); + FILE* f = _OpenFile(filePath, false); if (nullptr == f) { MNN_ERROR("Open %s error\n", filePath); return false; @@ -132,7 +147,7 @@ void FileLoader::_init() { } mInited = true; if (!mFilePath.empty()) { - mFile = _OpenFile(mFilePath.c_str()); + mFile = _OpenFile(mFilePath.c_str(), true); } if (nullptr == mFile) { MNN_ERROR("Can't open file:%s\n", mFilePath.c_str()); diff --git a/source/core/IDSTDecoder.hpp b/source/core/IDSTDecoder.hpp index 05f61ca77..679e92fcc 100644 --- a/source/core/IDSTDecoder.hpp +++ b/source/core/IDSTDecoder.hpp @@ -11,7 +11,6 @@ #include #include 
-#include #include "MNN_generated.h" #include "core/ConvolutionCommon.hpp" diff --git a/source/core/Interpreter.cpp b/source/core/Interpreter.cpp index 5078d1493..127bd6e52 100644 --- a/source/core/Interpreter.cpp +++ b/source/core/Interpreter.cpp @@ -221,7 +221,7 @@ Interpreter::Interpreter(Content* net) { mNet->bizCode = std::string(mNet->net->bizCode() ? mNet->net->bizCode()->c_str() : ""); mNet->uuid = std::string(mNet->net->mnn_uuid() ? mNet->net->mnn_uuid()->c_str() : ""); #ifdef MNN_INTERNAL_ENABLED - mNet->basicLogginData = getBasicLoggingData(); + mNet->basicLogginData = logBasicInfo(); mNet->basicLogginData.emplace("ModelVersion", getModelVersion()); #endif } @@ -238,8 +238,6 @@ Interpreter::~Interpreter() { Session* Interpreter::createMultiPathSession(const std::vector& configs) { RuntimeInfo runtime = createRuntime(configs); - runtime.second->setAllocatorType(mNet->modes.memoryAllocatorType); - runtime.second->setWinogradMemoryLevel(mNet->modes.winogradMemoryUsed); if (runtime.first.empty()) { MNN_ERROR("Runtime not valid for create session\n"); return nullptr; @@ -248,6 +246,11 @@ Session* Interpreter::createMultiPathSession(const std::vector& } Session* Interpreter::createMultiPathSession(const std::vector& configs, const RuntimeInfo& runtime) { + for (auto& iter : runtime.first) { + iter.second->setRuntimeHint(mNet->modes.runtimeHint); + } + runtime.second->setRuntimeHint(mNet->modes.runtimeHint); + if (nullptr == mNet->buffer.get()) { MNN_ERROR("The model buffer has been released. Can't create session\n"); return nullptr; @@ -267,6 +270,10 @@ Session* Interpreter::createMultiPathSession(const std::vector& if (!success) { return nullptr; } + if (info.needInputContentForShape) { + MNN_ERROR("Interpreter don't support case for shape compute need input content, please use module api instead\n"); + return nullptr; + } RuntimeInfo rt = runtime; bool valid = false; if (mNet->cacheBuffer.get() != nullptr) { diff --git a/source/core/OpCommonUtils.cpp b/source/core/OpCommonUtils.cpp index fcf0dd32e..f5e385605 100644 --- a/source/core/OpCommonUtils.cpp +++ b/source/core/OpCommonUtils.cpp @@ -11,7 +11,6 @@ #include "MNN_generated.h" #include "Macro.h" #include -#include namespace MNN { Tensor::DimensionType OpCommonUtils::convertDimType(MNN_DATA_FORMAT dimensionFormat) { diff --git a/source/core/Session.cpp b/source/core/Session.cpp index 9ab6b460c..9b27d5e1f 100644 --- a/source/core/Session.cpp +++ b/source/core/Session.cpp @@ -67,10 +67,13 @@ void Session::ModeGroup::setHint(Interpreter::HintMode mode, int hint) { maxTuningNumber = hint; break; case Interpreter::MEM_ALLOCATOR_TYPE: - memoryAllocatorType = hint; + runtimeHint.memoryAllocatorType = hint; break; case Interpreter::WINOGRAD_MEMORY_LEVEL: - winogradMemoryUsed = hint; + runtimeHint.winogradMemoryUsed = hint; + break; + case Interpreter::CPU_LITTLECORE_DECREASE_RATE: + runtimeHint.cpuDecreaseRate = hint; break; case Interpreter::GEOMETRY_COMPUTE_MASK: geometryMask = hint; @@ -78,6 +81,12 @@ void Session::ModeGroup::setHint(Interpreter::HintMode mode, int hint) { case Interpreter::STRICT_CHECK_MODEL: checkNetBuffer = hint > 0; break; + case Interpreter::DYNAMIC_QUANT_OPTIONS: + runtimeHint.dynamicQuantOption = hint; + break; + case Interpreter::KVCACHE_QUANT_OPTIONS: + runtimeHint.kvcacheQuantOption = hint; + break; default: break; } diff --git a/source/core/Session.hpp b/source/core/Session.hpp index 7b3ac7caf..c753a6c51 100644 --- a/source/core/Session.hpp +++ b/source/core/Session.hpp @@ -33,11 +33,10 @@ class MNN_PUBLIC 
Session { Interpreter::SessionMode resizeMode = Interpreter::Session_Resize_Direct; Interpreter::SessionMode memoryUsageMode = Interpreter::Session_Memory_Collect; Interpreter::SessionMode codegenMode = Interpreter::Session_Codegen_Disable; - int memoryAllocatorType = 0; int maxTuningNumber = MNN_DEFAULT_TUNING_NUMBER; - int winogradMemoryUsed = 3; int geometryMask = 0xFFFF; bool checkNetBuffer = true; + RuntimeHint runtimeHint; void setHint(Interpreter::HintMode hint, int magic); void setMode(Interpreter::SessionMode mode); }; diff --git a/source/geometry/GeometryOPRegister.cpp b/source/geometry/GeometryOPRegister.cpp index 91a743a1a..11982a07b 100644 --- a/source/geometry/GeometryOPRegister.cpp +++ b/source/geometry/GeometryOPRegister.cpp @@ -9,6 +9,7 @@ extern void ___GeometryReshape___create__(); extern void ___GeometryReduce___create__(); extern void ___GeometryInnerProduct___create__(); extern void ___GeometryTopK___create__(); +extern void ___GeometryLayerNorm___create__(); extern void ___GeometryDepthToSpace___create__(); extern void ___GeometryBroadcastTo___create__(); extern void ___GeometryConvert___create__(); @@ -40,7 +41,6 @@ extern void ___GeometrySlice___create__(); extern void ___GeometryConcat___create__(); extern void ___GeometryUnary___create__(); extern void ___GeometryBinary___create__(); -extern void ___GeometryLayerNorm___create__(); void registerGeometryOps() { ___GeometryShape___create__(); @@ -51,6 +51,7 @@ ___GeometryReshape___create__(); ___GeometryReduce___create__(); ___GeometryInnerProduct___create__(); ___GeometryTopK___create__(); +___GeometryLayerNorm___create__(); ___GeometryDepthToSpace___create__(); ___GeometryBroadcastTo___create__(); ___GeometryConvert___create__(); @@ -82,6 +83,5 @@ ___GeometrySlice___create__(); ___GeometryConcat___create__(); ___GeometryUnary___create__(); ___GeometryBinary___create__(); -___GeometryLayerNorm___create__(); } } diff --git a/test.sh b/test.sh index 9d7c1b2d7..ef4edd95b 100755 --- a/test.sh +++ b/test.sh @@ -547,14 +547,22 @@ android_model_test() { models=`ls ~/AliNNModel/TestResource/` for model in $models do - adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./testModel.out ../AliNNModel/TestResource/$model/temp.bin ../AliNNModel/TestResource/$model/input_0.txt ../AliNNModel/TestResource/$model/output.txt 0 0.002" + if [ $model == 'mobilenetv1quan' ]; then + adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./testModel.out ../AliNNModel/TestResource/$model/temp.bin ../AliNNModel/TestResource/$model/input_0.txt ../AliNNModel/TestResource/$model/output.txt 0 0.1" + else + adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./testModel.out ../AliNNModel/TestResource/$model/temp.bin ../AliNNModel/TestResource/$model/input_0.txt ../AliNNModel/TestResource/$model/output.txt 0 0.002" + fi if [ $? 
-ne 0 ]; then fail_num=$[$fail_num+1] else pass_num=$[$pass_num+1] fi if [ "$OPENCL_CHANGE" ]; then - adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./testModel.out ../AliNNModel/TestResource/$model/temp.bin ../AliNNModel/TestResource/$model/input_0.txt ../AliNNModel/TestResource/$model/output.txt 3 0.002 1" + if [ $model == 'mobilenetv1quan' ]; then + adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./testModel.out ../AliNNModel/TestResource/$model/temp.bin ../AliNNModel/TestResource/$model/input_0.txt ../AliNNModel/TestResource/$model/output.txt 3 0.1 1" + else + adb shell "cd /data/local/tmp/MNN&&export LD_LIBRARY_PATH=.&&./testModel.out ../AliNNModel/TestResource/$model/temp.bin ../AliNNModel/TestResource/$model/input_0.txt ../AliNNModel/TestResource/$model/output.txt 3 0.002 1" + fi if [ $? -ne 0 ]; then fail_cl_num=$[$fail_cl_num+1] else diff --git a/test/TestUtils.h b/test/TestUtils.h index 94fa9f4d5..6a5dd2c20 100644 --- a/test/TestUtils.h +++ b/test/TestUtils.h @@ -47,7 +47,7 @@ bool checkVector(const T* result, const T* rightData, int size, T threshold){ MNN_ASSERT(size >= 0); for(int i = 0; i < size; ++i){ if(fabs(result[i] - rightData[i]) > threshold){ - std::cout << i << " error, right: " << rightData[i] << ", compute: " << result[i] << std::endl; + std::cout << "No." << i << " error, right: " << rightData[i] << ", compute: " << result[i] << std::endl; return false; } } diff --git a/test/core/ThreadPoolTest.cpp b/test/core/ThreadPoolTest.cpp index 51dfaf8f1..a0103cfc5 100644 --- a/test/core/ThreadPoolTest.cpp +++ b/test/core/ThreadPoolTest.cpp @@ -20,17 +20,17 @@ class ThreadPoolTest : public MNNTestCase { std::vector threads; for (int i = 0; i < 10; ++i) { threads.emplace_back([i]() { - MNN::ThreadPool::init(10 - i); + int number = MNN::ThreadPool::init(10 - i); // initializer auto workIndex = ThreadPool::acquireWorkIndex(); FUNC_PRINT(workIndex); - ThreadPool::active(); + ThreadPool::active(number); auto func = [](int index) { FUNC_PRINT(index); std::this_thread::yield(); }; - ThreadPool::enqueue(std::make_pair(std::move(func), 10), workIndex); - ThreadPool::deactive(); + ThreadPool::enqueue(std::make_pair(std::move(func), 10), workIndex, number); + ThreadPool::deactive(number); ThreadPool::releaseWorkIndex(workIndex); }); } diff --git a/test/expr/ModuleTest.cpp b/test/expr/ModuleTest.cpp index d2f4fc19e..84fb16d11 100644 --- a/test/expr/ModuleTest.cpp +++ b/test/expr/ModuleTest.cpp @@ -851,7 +851,7 @@ class MemeoryUsageTest : public MNNTestCase { BackendConfig bnConfig; bnConfig.precision = (MNN::BackendConfig::PrecisionMode)precision; config.numThread = 1; - config.type = ExecutorScope::Current()->getAttr()->firstType.first; + config.type = ExecutorScope::Current()->getAttr()->firstType; config.backendConfig = &bnConfig; auto s1 = net->createSession(config); float memory = 0.0f; @@ -947,7 +947,7 @@ class ConstMemoryReplaceTest : public MNNTestCase { std::shared_ptr net(Interpreter::createFromBuffer((void*)bufferOutput, sizeOutput), Interpreter::destroy); ScheduleConfig config; config.numThread = 4; - config.type = ExecutorScope::Current()->getAttr()->firstType.first; + config.type = ExecutorScope::Current()->getAttr()->firstType; auto s1 = net->createSession(config); int resizeCode; net->getSessionInfo(s1, Interpreter::RESIZE_STATUS, &resizeCode); @@ -984,7 +984,7 @@ class MutlThreadConstReplaceTest : public MNNTestCase { BackendConfig bnConfig; bnConfig.precision = (MNN::BackendConfig::PrecisionMode)precision; config.numThread = 1; - config.type 
= ExecutorScope::Current()->getAttr()->firstType.first; + config.type = ExecutorScope::Current()->getAttr()->firstType; config.backendConfig = &bnConfig; std::vector threads; diff --git a/test/op/ConvInt8Test.cpp b/test/op/ConvInt8Test.cpp index 31f716046..428a37d72 100644 --- a/test/op/ConvInt8Test.cpp +++ b/test/op/ConvInt8Test.cpp @@ -257,7 +257,7 @@ class ConvInt8TestCommon : public MNNTestCase { // Because of round implement in ARM / X86 / PC may cause 1 / 0 / -1 diff, don't care about this error auto error = (int32_t)targetValue - (int32_t)computeResult; if (error * error > 1) { - MNN_PRINT("%d x %d, ConvInt8 result %d Error: %d -> %d\n", ow, oh, i, targetValue, computeResult); + MNN_PRINT("ic=%d, oc=%d, ow=%d, oh=%d, ConvInt8 result No.%d Error: right=%d, error=%d\n", channel[0], channel[1], ow, oh, i, targetValue, computeResult); #ifdef DEBUG x->writeMap(); auto ptr = y->readMap(); @@ -293,7 +293,7 @@ class ConvInt8Im2colGemmTest : public ConvInt8TestCommon { std::vector> kernels = { {4, 2}, {1, 5}, {7, 1} }; - int iw = 24; int ih = 17; + int iw = 14; int ih = 11; std::vector titles = {"4x2", "1x5", "7x1"}; for (int sx=1; sx<2; ++sx) { for (int sy=1; sy<2; ++sy) { @@ -309,6 +309,7 @@ class ConvInt8Im2colGemmTest : public ConvInt8TestCommon { auto res = testKernel(inputShape, kernels[i], channel, pad, strides, dilate, 8, false, 1, 2, MNN::SparseAlgo_RANDOM, 1, false); if (!res) { MNN_ERROR("Error for test kernel %s for convint8 215, 204 (im2col + gemm)\n", titles[i].c_str()); + MNN_ERROR("overflow=false, bit=8, batch=2, Conv info: sx=%d, sy=%d, dx=%d, dy=%d, px=%d, py=%d, ic=%d, oc=%d\n", sx, sy, dx, dy, px, py, ic, oc); return false; } } @@ -316,6 +317,7 @@ class ConvInt8Im2colGemmTest : public ConvInt8TestCommon { auto res = testKernel(inputShape, kernels[i], channel, pad, strides, dilate, 3, true, 1, 3, MNN::SparseAlgo_RANDOM, 1, false); if (!res) { MNN_ERROR("Error for test kernel %s for convint8 215, 204 (im2col + gemm + overflow aware)\n", titles[i].c_str()); + MNN_ERROR("overflow=true,bit=3, batch=3, Conv info: sx=%d, sy=%d, dx=%d, dy=%d, px=%d, py=%d, ic=%d, oc=%d\n", sx, sy, dx, dy, px, py, ic, oc); return false; } } @@ -323,6 +325,7 @@ class ConvInt8Im2colGemmTest : public ConvInt8TestCommon { auto res = testKernel(inputShape, kernels[i], channel, pad, strides, dilate, 8, false, 1, 5, MNN::SparseAlgo_RANDOM, 1, false); if (!res) { MNN_ERROR("Error for test kernel %s for convint8 215, 201 (im2col + gemm)\n", titles[i].c_str()); + MNN_ERROR("overflow=false,bit=8, batch=5, Conv info: sx=%d, sy=%d, dx=%d, dy=%d, px=%d, py=%d, ic=%d, oc=%d\n", sx, sy, dx, dy, px, py, ic, oc); return false; } } @@ -330,6 +333,7 @@ class ConvInt8Im2colGemmTest : public ConvInt8TestCommon { auto res = testKernel(inputShape, kernels[i], channel, pad, strides, dilate, 3, true, 1, 2, MNN::SparseAlgo_RANDOM, 1, false); if (!res) { MNN_ERROR("Error for test kernel %s for convint8 215, 201 (im2col + gemm + overflow aware)\n", titles[i].c_str()); + MNN_ERROR("overflow=true,bit=3, batch=2, Conv info: sx=%d, sy=%d, dx=%d, dy=%d, px=%d, py=%d, ic=%d, oc=%d\n", sx, sy, dx, dy, px, py, ic, oc); return false; } } @@ -414,22 +418,22 @@ class SparseConvInt8Im2colGemmTest : public ConvInt8TestCommon { for (int i = 0; i < kernels.size(); ++i) { auto res = testKernel(inputShape, kernels[i], channel, pad, strides, dilate, 3, true, 1, 3, SparseList[is].first, SparseList[is].second, false); if (!res) { - MNN_ERROR("Error for test kernel %s for convint8 215, 204 (im2col + gemm + overflow aware)\n", titles[i].c_str()); 
+ MNN_ERROR("Error for test kernel %s for convint8 (im2col + gemm + overflow aware)\n", titles[i].c_str()); return false; } } - inputShape = {215, 201}; + inputShape = {123, 65}; for (int i = 0; i < kernels.size(); ++i) { auto res = testKernel(inputShape, kernels[i], channel, pad, strides, dilate, 8, false, 1, 5, SparseList[is].first, SparseList[is].second, false); if (!res) { - MNN_ERROR("Error for test kernel %s for convint8 215, 201 (im2col + gemm)\n", titles[i].c_str()); + MNN_ERROR("Error for test kernel %s for convint8 (im2col + gemm)\n", titles[i].c_str()); return false; } } for (int i = 0; i < kernels.size(); ++i) { auto res = testKernel(inputShape, kernels[i], channel, pad, strides, dilate, 3, true, 1, 2, SparseList[is].first, SparseList[is].second, false); if (!res) { - MNN_ERROR("Error for test kernel %s for convint8 215, 201 (im2col + gemm + overflow aware)\n", titles[i].c_str()); + MNN_ERROR("Error for test kernel %s for convint8 (im2col + gemm + overflow aware)\n", titles[i].c_str()); return false; } } @@ -567,7 +571,7 @@ class ConvInt8WinogradTestCommon : public MNNTestCase { return false; } if (!checkVector(yPtr, yTargetPtr, yInfo->size, 1)) { - MNN_ERROR("[ConvInt8WinogradTestCommon] result error for batchSize = %d\n", batchSize); + MNN_ERROR("[ConvInt8WinogradTestCommon] result error for batchSize = %d, oc=%d, oh=%d, ow=%d\n", batchSize, yInfo->dim[1], yInfo->dim[2], yInfo->dim[3]); return false; } if (speed) { @@ -593,7 +597,7 @@ class ConvInt8WinogradTestCommon : public MNNTestCase { class ConvInt8WinogradTest : public ConvInt8WinogradTestCommon { virtual bool run(int precision) { - INTS pad = {1, 1}, inputShape = {128, 128}; // {w, h} + INTS pad = {1, 1}, inputShape = {47, 39}; // {w, h} INTS channel = {32, 32}; // {ci, co} std::vector> kernels = { diff --git a/test/op/ConvolutionTest.cpp b/test/op/ConvolutionTest.cpp index 836ace993..6c127d5a8 100644 --- a/test/op/ConvolutionTest.cpp +++ b/test/op/ConvolutionTest.cpp @@ -665,7 +665,7 @@ class ConvolutionInt8CommonTest : public ConvolutionCommonTest { MNN_PRINT("precision:%d, expect:\t expect2:\t real:\t\n", precision); for (int i = 0; i < toutputData.size(); ++i) { - MNN_PRINT("%f\t, %f\t, %f\n", toutputData[i],outputDataSeparateBias[i], outputPtr[i]); + MNN_PRINT("%f\t, %f\n", toutputData[i], outputPtr[i]); } MNN_ERROR("%s(%s) test failed for %d bits, async=%d , relu: %d, relu6: %d!\n", test_op_name.c_str(), device_name.c_str(), nbit, async, activation.first, activation.second); return false; diff --git a/test/op/DeconvolutionTest.cpp b/test/op/DeconvolutionTest.cpp index 87912379d..c4e46af32 100644 --- a/test/op/DeconvolutionTest.cpp +++ b/test/op/DeconvolutionTest.cpp @@ -119,7 +119,7 @@ class DeconvolutionCommonTestInt8 : public MNNTestCase { auto outputPtr = y->readMap(); float errorScale = precision <= MNN::BackendConfig::Precision_High ? 
1 : 20; if (!checkVectorByRelativeError(outputPtr, rightOutData.data(), rightOutData.size(), 0.005 * errorScale)) { - MNN_ERROR("%s(%s) test failed!\n", test_op_name.c_str(), device_name.c_str()); + MNN_ERROR("%s(%s) test failed: batch=%d, oc=%d, oh=%d, ow=%d!\n", test_op_name.c_str(), device_name.c_str(), y->getInfo()->dim[0], y->getInfo()->dim[1], y->getInfo()->dim[2], y->getInfo()->dim[3]); return false; } return true; @@ -441,6 +441,75 @@ class DeconvolutionInt8Test : public DeconvolutionCommonTestInt8 { return false; } } + MNN_PRINT("begin testcase 3\n"); + { + std::vector data_a = {// channel 0 + 1.0, 2.0, 4.0, 5.0, + // channel 1 + 1.1, 2.1, 4.1, 5.1, + // channel 2 + 1.2, 2.2, 4.2, 5.2}; + + std::vector weight = {//IOHW + // input channel0 + + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + + // input channel1 + + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + + // input channel2 + + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + 2, 2, 2, 2, 2, 2, 2, 2, 2, + 1, 1, 1, 1, 1, 1, 1, 1, 1, + }; + std::vector bias(9, 0); + std::vector data_c = {3.3, 3.3, 9.6, 6.3, 3.3, 3.3, 9.6, 6.3, 15.6, 15.6, 37.2, + 21.6, 12.3, 12.3, 27.6, 15.3, + + 6.6, 6.6, 19.2, 12.6, 6.6, 6.6, 19.2, 12.6, 31.2, 31.2, 74.4, + 43.2, 24.6, 24.6, 55.2, 30.6}; + int ic = 3, oc = 9; + int kw = 3, kh = 3, ih = 2, iw = 2; + int stride = 2, dilation = 1; + int group = 1, batch = 1; + int pad_w = 0, pad_h = 0; + + std::vector scale = {1., 1.}; + std::vector zeroPoints = {0, 0}; + std::vector quantScales = {0.0416, 0.6112}; + + bool succ = DeconvolutionCommonTestInt8::test("CPU", "Deconv", data_a, weight, bias, data_c, + batch, ic, oc, ih, iw, PadMode_SAME, pad_h, pad_w, kh, kw, + stride, dilation, group, precision, scale, zeroPoints, quantScales); + if (!succ) { + return false; + } + } return true; } }; diff --git a/test/op/PReLUTest.cpp b/test/op/PReLUTest.cpp index d73050317..f6a3d1365 100644 --- a/test/op/PReLUTest.cpp +++ b/test/op/PReLUTest.cpp @@ -40,11 +40,19 @@ class PreluTestInt8 : public MNNTestCase { public: virtual ~PreluTestInt8() = default; virtual bool run(int precision) { - auto input = _Input({1, 4, 4, 2}, NCHW); + auto input = _Input({1, 12, 4, 2}, NCHW); input->setName("input_tensor"); // set input data input->writeScaleMap(0.03567, 1.0); const float inpudata[] = {-1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, + 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, + -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, + 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, + -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, + 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, + -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, + 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, + -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, -1.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0}; @@ -52,12 +60,21 @@ class PreluTestInt8 : public MNNTestCase { memcpy(inputPtr, inpudata, 4 * sizeof(float)); input->unMap(); input = 
_Convert(input, NC4HW4); - auto output = _PRelu(input, {3.0, 1.5, 1.5, 1.5}); + auto output = _PRelu(input, {3.0, 1.5, 1.5, 1.5, 3.0, 1.5, 1.5, 1.5, 3.0, 1.5, 1.5, 1.5}); output = _Convert(output, NCHW); const std::vector expectedOutput = {-3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, -4.5, -4.5, -4.5, -4.5, -4.5, -4.5, -4.5, -4.5, - 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0}; + 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, + -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, + 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, + -4.5, -4.5, -4.5, -4.5, -4.5, -4.5, -4.5, -4.5, + 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, + -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, -3.0, + 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, + -4.5, -4.5, -4.5, -4.5, -4.5, -4.5, -4.5, -4.5, + 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0, 4.0 + }; output->writeScaleMap(0.03567, 1.0); auto gotOutput = output->readMap(); if (!checkVector(gotOutput, expectedOutput.data(), 4, 0.05)) { diff --git a/test/op/SoftmaxTest.cpp b/test/op/SoftmaxTest.cpp index 8a3578591..ebbed4224 100644 --- a/test/op/SoftmaxTest.cpp +++ b/test/op/SoftmaxTest.cpp @@ -282,53 +282,6 @@ class SoftmaxInt8Test: public MNNTestCase { } } } - - // testcase 2 - { - auto input = _Input({2, 5}, NCHW); - input->setName("input_tensor"); - // set input data - const float inpudata[] = {1.0, 2.0, 3.0, 4.0, 5.0, -1.0, -2.0, -3.0, -4.0, -5.0}; - const float quantScales[] = {1.0, 0.00784}; - const float zeroPoints[] = {1., 2.}; - input->writeScaleMap(quantScales[0], zeroPoints[0]); - auto inputPtr = input->writeMap(); - memcpy(inputPtr, inpudata, 10 * sizeof(float)); - input->unMap(); - auto output = _Softmax(input); - const std::vector expectedOrder = {0, 1, 2, 3, 4, 9, 8, 7, 6, 5}; - const std::vector expectedOutput = {0.0117, 0.0317, 0.0861, 0.2341, 0.6364, 0.6364, 0.2341, 0.0861, 0.0317, 0.0117}; - output->writeScaleMap(quantScales[1], zeroPoints[1]); - auto gotOutput = output->readMap(); - bool result = checkProbAndOrder((float*)gotOutput, expectedOutput.data(), expectedOrder.data(), 10, {2, 5}, 1); - if (!result) { - MNN_PRINT("SoftmaxInt8 case2 failed!\n"); - return false; - } - } - // testcase 3 - { - auto input = _Input({2, 2}, NCHW); - input->setName("input_tensor"); - // set input data - const float inpudata[] = {-1.0, -2.0, 3.0, 4.0}; - const float quantScales[] = {1.0, 0.00784}; - const float zeroPoints[] = {1., 2.}; - input->writeScaleMap(quantScales[0], zeroPoints[0]); - auto inputPtr = input->writeMap(); - memcpy(inputPtr, inpudata, 4 * sizeof(float)); - input->unMap(); - auto output = _Softmax(input); - const std::vector expectedOrder = {1, 2, 0, 3}; - const std::vector expectedOutput = {0.7310586, 0.26894143, 0.26894143, 0.7310586}; - output->writeScaleMap(quantScales[1], zeroPoints[1]); - auto gotOutput = output->readMap(); - bool result = checkProbAndOrder((float*)gotOutput, expectedOutput.data(), expectedOrder.data(), 4, {2, 2}, 1); - if (!result) { - MNN_PRINT("SoftmaxInt8 case3 failed!\n"); - return false; - } - } return true; } }; diff --git a/test/speed/HybridConvSpeedTest.cpp b/test/speed/HybridConvSpeedTest.cpp index 548596dd3..2354c4c58 100644 --- a/test/speed/HybridConvSpeedTest.cpp +++ b/test/speed/HybridConvSpeedTest.cpp @@ -63,22 +63,21 @@ class HybridConvSpeedTestCommon : public MNNTestCase { #else #define FLOAT_T float #endif + y = _Convert(y, NCHW); + yfp32 = _Convert(yfp32, NCHW); auto yPtr = y->readMap(); auto tgPtr = yfp32->readMap(); auto elesize = batch * oc * oh * ow; - float limit = 0.02f; - if (nbit < 8) { - limit = 
0.1f; - } + float limit = 0.1f; for (int i = 0; i < elesize; ++i) { float targetValue = tgPtr[i], computeResult = yPtr[i]; float diff = targetValue - computeResult; float ratio = fabsf(diff) / fmax(targetValue, computeResult); if (targetValue != 0 && computeResult != 0 && ratio > limit) { - MNN_PRINT("HybridConv result Error: %f -> %f\n", targetValue, computeResult); + MNN_PRINT("%d result Error ratio=%f: right=%f, error=%f\n", i, ratio, targetValue, computeResult); return false; } else if ((targetValue == 0 || computeResult == 0) && fabsf(diff) > limit) { - MNN_PRINT("HybridConv result Error: %f -> %f\n", targetValue, computeResult); + MNN_PRINT("%d result Error ratio=%f: right=%f, error=%f\n", i, ratio, targetValue, computeResult); return false; } } @@ -103,9 +102,9 @@ class HybridConvSpeedInt8Test : public HybridConvSpeedTestCommon { public: virtual bool run(int precision) { INTS strides = {1, 1}, dilate = {1, 1}, pad = {0, 0}, inputShape = {1, 1}; // {w, h} - INTS channel0 = {2048, 512}; // {ci, co} + INTS channel0 = {2048, 512}; // {ic, co} INTS channel1 = {1496, 256}; - int batch[2] = {1, 13}; + int batch[2] = {23, 13}; std::vector kernels = {1, 1}; std::vector weightBits = {8, 4}; bool lowmemory = true; @@ -114,14 +113,14 @@ class HybridConvSpeedInt8Test : public HybridConvSpeedTestCommon { for (int n = 0; n < 2; ++n) { auto res = testKernel("Low memory HybridConv test:", inputShape, kernels, channel0, pad, strides, dilate, batch[n], bits, precision, true); if (!res) { - MNN_ERROR("Error: low memory hybridConv when n=%d, ci=%d, c0=%d\n", batch[n], channel0[0], channel0[1]); + MNN_ERROR("Error: low memory hybridConv when n=%d, ic=%d, oc=%d\n", batch[n], channel0[0], channel0[1]); return false; } } for (int n = 0; n < 2; ++n) { auto res = testKernel("Low memory HybridConv test:", inputShape, kernels, channel1, pad, strides, dilate, batch[n], bits, precision, true); if (!res) { - MNN_ERROR("Error: low memory hybridConv when n=%d, ci=%d, c0=%d\n", batch[n], channel1[0], channel1[1]); + MNN_ERROR("Error: low memory hybridConv when n=%d, ic=%d, oc=%d\n", batch[n], channel1[0], channel1[1]); return false; } } @@ -133,26 +132,22 @@ class HybridConvSpeedInt8Test : public HybridConvSpeedTestCommon { class HybridConvInt8Test : public HybridConvSpeedTestCommon { public: virtual bool run(int precision) { - INTS channel0 = {2048, 512}; // {ci, co} - INTS channel1 = {1496, 256}; + std::vector< std::vector> channels = {{7, 9}, {2048, 6144}, {1, 10}, {20, 153}, {9, 18}}; INTS strides = {1, 1}, dilate = {1, 1}, pad = {0, 0}, inputShape = {1, 1}; // {w, h} - int batch[2] = {1, 13}; + int testBatchCount = 5; + // std::vector batch(testBatchCount); + std::vector batch = {1, 23, 1479, 38, 29}; std::vector kernels = {1, 1}; std::vector weightBits = {8}; bool lowmemory = true; for (auto& bits : weightBits) { - for (int n = 0; n < 2; ++n) { - auto res = testKernel("Low memory HybridConv test:", inputShape, kernels, channel0, pad, strides, dilate, batch[n], bits, precision); - if (!res) { - MNN_ERROR("Error: low memory hybridConv when n=%d, ci=%d, c0=%d\n", batch[n], channel0[0], channel0[1]); - return false; - } - } - for (int n = 0; n < 2; ++n) { - auto res = testKernel("Low memory HybridConv test:", inputShape, kernels, channel1, pad, strides, dilate, batch[n], bits, precision); - if (!res) { - MNN_ERROR("Error: low memory hybridConv when n=%d, ci=%d, c0=%d\n", batch[n], channel1[0], channel1[1]); - return false; + for (int i = 0; i < channels.size(); ++i) { + for (int n = 0; n < batch.size(); ++n) { + 
auto res = testKernel("Low memory HybridConv test:", inputShape, kernels, channels[i], pad, strides, dilate, batch[n], bits, precision); + if (!res) { + MNN_ERROR("Error: low memory hybridConv when n=%d, ic=%d, oc=%d\n", batch[n], channels[i][0], channels[i][1]); + return false; + } } } } @@ -163,8 +158,7 @@ class HybridConvInt8Test : public HybridConvSpeedTestCommon { class DenseConvInt8Test : public HybridConvSpeedTestCommon { public: virtual bool run(int precision) { - INTS channel0 = {256, 256}; // {ci, co} - INTS channel1 = {1496, 256}; + std::vector< std::vector> channels = {{4, 256}, {2048, 256}, {1, 8}, {7, 9}}; INTS strides = {1, 1}, dilate = {1, 3}, pad = {0, 3}, inputShape = {1, 2640}; // {w, h} int batch[2] = {1, 13}; std::vector kernels = {1, 3}; @@ -173,17 +167,12 @@ class DenseConvInt8Test : public HybridConvSpeedTestCommon { int n = 0; for (auto& bits : weightBits) { for (int n = 0; n < 2; ++n) { - auto res = testKernel("Low memory HybridConv test:", inputShape, kernels, channel0, pad, strides, dilate, batch[n], bits, precision); - if (!res) { - MNN_ERROR("Error: low memory hybridConv when n=%d, ci=%d, c0=%d\n", batch[n], channel0[0], channel0[1]); - return false; - } - } - for (int n = 0; n < 2; ++n) { - auto res = testKernel("Low memory HybridConv test:", inputShape, kernels, channel1, pad, strides, dilate, batch[n], bits, precision); - if (!res) { - MNN_ERROR("Error: low memory hybridConv when n=%d, ci=%d, c0=%d\n", batch[n], channel1[0], channel1[1]); - return false; + for (int i = 0; i < channels.size(); ++i) { + auto res = testKernel("Low memory ConvInt8 with 1x3 kernel test:", inputShape, kernels, channels[i], pad, strides, dilate, batch[n], bits, precision); + if (!res) { + MNN_ERROR("Error: low memory ConvInt8 with 1x3 kernel when n=%d, ic=%d, oc=%d\n", batch[n], channels[i][0], channels[i][1]); + return false; + } } } } diff --git a/tools/converter/source/common/cli.cpp b/tools/converter/source/common/cli.cpp index ffe8c8ae9..89e83ab26 100644 --- a/tools/converter/source/common/cli.cpp +++ b/tools/converter/source/common/cli.cpp @@ -40,7 +40,7 @@ #include "core/MemoryFormater.h" namespace MNN { - +using namespace MNN::Express; static std::string _getDataType(const halide_type_t& type) { switch (type.code) { case halide_type_float: @@ -153,7 +153,7 @@ bool Cli::initializeMNNConvertArgs(modelConfig &modelPath, int argc, char **argv ) ( "keepInputFormat", - "keep input dimension format or not, default: false", + "keep input dimension format or not, default: true", cxxopts::value() ) ( @@ -492,6 +492,151 @@ bool Cli::initializeMNNConvertArgs(modelConfig &modelPath, int argc, char **argv return true; } +typedef VARP (*unaryProc)(VARP input); +static unaryProc selectUnaryProc(int type) { + switch (type) { + case UnaryOpOperation_ABS: + return MNN::Express::_Abs; + case UnaryOpOperation_SQUARE: + return MNN::Express::_Square; + case UnaryOpOperation_NEG: + return MNN::Express::_Negative; + case UnaryOpOperation_RSQRT: + return MNN::Express::_Rsqrt; + case UnaryOpOperation_EXP: + return MNN::Express::_Exp; + case UnaryOpOperation_COS: + return MNN::Express::_Cos; + case UnaryOpOperation_SIN: + return MNN::Express::_Sin; + case UnaryOpOperation_SIGMOID: + return MNN::Express::_Sigmoid; + case UnaryOpOperation_TANH: + return MNN::Express::_Tanh; + case UnaryOpOperation_TAN: + return MNN::Express::_Tan; + case UnaryOpOperation_ATAN: + return MNN::Express::_Atan; + case UnaryOpOperation_SQRT: + return MNN::Express::_Sqrt; + case UnaryOpOperation_RECIPROCAL: + return 
MNN::Express::_Reciprocal; + case UnaryOpOperation_LOG1P: + return MNN::Express::_Log1p; + case UnaryOpOperation_LOG: + return MNN::Express::_Log; + case UnaryOpOperation_ACOSH: + return MNN::Express::_Acosh; + case UnaryOpOperation_SINH: + return MNN::Express::_Sinh; + case UnaryOpOperation_ASINH: + return MNN::Express::_Asinh; + case UnaryOpOperation_ATANH: + return MNN::Express::_Atanh; + case UnaryOpOperation_SIGN: + return MNN::Express::_Sign; + case UnaryOpOperation_COSH: + return MNN::Express::_Cosh; + case UnaryOpOperation_ERF: + return MNN::Express::_Erf; + case UnaryOpOperation_ERFC: + return MNN::Express::_Erfc; + case UnaryOpOperation_ERFINV: + return MNN::Express::_Erfinv; + case UnaryOpOperation_EXPM1: + return MNN::Express::_Expm1; + case UnaryOpOperation_ASIN: + return MNN::Express::_Asin; + case UnaryOpOperation_ACOS: + return MNN::Express::_Acos; + case UnaryOpOperation_HARDSWISH: + return MNN::Express::_Hardswish; + case UnaryOpOperation_GELU: + return MNN::Express::_Gelu; + default: + MNN_ASSERT(false); + break; + } + return nullptr; +} +static void computeUnaryBuffer(MNN::NetT* net) { + for (auto iter = net->oplists.begin(); iter != net->oplists.end(); ++iter) { + auto op = iter->get(); + auto opType = op->type; + std::map describes; + for (auto& des : net->extraTensorDescribe) { + describes.insert(std::make_pair(des->index, des.get())); + } + if (opType == MNN::OpType_Sigmoid || opType == MNN::OpType_TanH) { + op->type = OpType_UnaryOp; + op->main.value = new UnaryOpT; + op->main.type = OpParameter_UnaryOp; + op->main.AsUnaryOp()->opType = UnaryOpOperation_SIGMOID; + if (opType == MNN::OpType_TanH) { + op->main.AsUnaryOp()->opType = UnaryOpOperation_TANH; + } + opType = op->type; + } + if (opType == MNN::OpType_UnaryOp) { + auto type = op->main.AsUnaryOp()->opType; + if (type == UnaryOpOperation_ABS || type == UnaryOpOperation_NEG || type == UnaryOpOperation_SIGN) { + continue; + } + op->main.AsUnaryOp()->tableInt8.resize(255); + auto unaryParam = op->main.AsUnaryOp()->tableInt8.data(); + + auto outputId = op->outputIndexes[0]; + if (describes.find(outputId) == describes.end()) { + continue; + } + auto unaryDes = describes.find(outputId)->second; + float outScale = unaryDes->quantInfo->scale; + float outZero = unaryDes->quantInfo->zero; + auto inputId = op->inputIndexes[0]; + if (describes.find(inputId) == describes.end()) { + auto iter = describes.find(outputId); + + } + unaryDes = describes.find(inputId)->second; + float inpScale = unaryDes->quantInfo->scale; + float inpZero = unaryDes->quantInfo->zero; + + // Read input data. + std::vector dataInput; + float fx = 0.f; + auto input = _Input({255}, NCHW, halide_type_of()); + input->setName("input_tensor"); + auto ptr_in = input->template writeMap(); + for (int i = -127; i <= 127; ++i) { + fx = (i - inpZero) * inpScale; + dataInput.push_back(fx); + ptr_in[i + 127] = fx; + } + input->unMap(); + // Compute output data. + VARP output; + auto func = selectUnaryProc(type); + if (nullptr == func) { + MNN_ERROR("Don't support quantizing UnaryOP: %s to Int8\n", op->name.c_str()); + } + output = func(input); + auto gotOutput = output->template readMap(); + // Write output data. 
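+ // Requantize each of the 255 sampled outputs with the output scale/zero-point and clamp to [-127, 127] to fill the int8 lookup table (tableInt8).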
+ int val; + for (int i = 0; i < 255; ++i) { + val = (int)roundf(gotOutput[i] / outScale) + outZero; + if (val > 127) { + val = 127; + } + if (val < -127) { + val = -127; + } + unaryParam[i] = val; + } + } + } +} + bool Cli::convertModel(modelConfig& modelPath) { if (modelPath.dumpInfo) { dumpModelInfo(modelPath.modelFile.c_str()); @@ -555,6 +700,11 @@ bool Cli::convertModel(modelConfig& modelPath) { if (modelPath.model != modelConfig::MNN || modelPath.optimizeLevel >= 2) { std::cout << "Start to Optimize the MNN Net..." << std::endl; std::unique_ptr newNet = optimizeNet(netT, modelPath.forTraining, modelPath); + if (newNet->extraTensorDescribe.size()>0) { + MNN_PRINT("MNN net has tensor quant info\n"); + computeUnaryBuffer(newNet.get()); + } + error = writeFb(newNet, modelPath.MNNModel, modelPath); } else { error = writeFb(netT, modelPath.MNNModel, modelPath); diff --git a/tools/converter/source/optimizer/PostConverter.cpp b/tools/converter/source/optimizer/PostConverter.cpp index 97bd8ec39..26535ebe4 100644 --- a/tools/converter/source/optimizer/PostConverter.cpp +++ b/tools/converter/source/optimizer/PostConverter.cpp @@ -274,7 +274,11 @@ std::unique_ptr optimizeNetImpl(std::unique_ptr& originNet // Remove Invalid Cast "RemoveInvalidCast" }; - auto tensorDescribe = std::move(originNet->extraTensorDescribe); + std::vector> tensorDescribe; + if (originNet->extraTensorDescribe.size() > 0) { + tensorDescribe = std::move(originNet->extraTensorDescribe); + } + std::unique_ptr newNet; newNet = std::move(RunExtraPass(originNet, inputs)); RunNetPass(midOptPass, newNet); @@ -344,7 +348,9 @@ std::unique_ptr optimizeNetImpl(std::unique_ptr& originNet newNet = std::move(RunMergePass(newNet, inputs, PASS_PRIORITY_LOW)); newNet = std::move(RunMergePass(newNet, inputs, PASS_PRIORITY_FINAL)); - newNet->extraTensorDescribe = std::move(tensorDescribe); + if (tensorDescribe.size() > 0) { + newNet->extraTensorDescribe = std::move(tensorDescribe); + } RunNetPass({"ReIndexTensor"}, newNet); RunNetPass({"ReIndexOnnxIfAlias"}, newNet); diff --git a/tools/converter/source/optimizer/Program.cpp b/tools/converter/source/optimizer/Program.cpp index 28be3f37e..461a403ff 100644 --- a/tools/converter/source/optimizer/Program.cpp +++ b/tools/converter/source/optimizer/Program.cpp @@ -20,11 +20,11 @@ namespace MNN { namespace Express { void Program::createUnit(std::map& varMap, std::vector& inputIndexes, const std::vector>& oplists, MNN::OpT* op, const MNN::NetT* net, std::set& invalidSet, std::set& extraInputIndexes) { - createUnit(varMap, inputIndexes, oplists, op, net->tensorName, invalidSet, extraInputIndexes); + createUnit(varMap, inputIndexes, oplists, op, net->tensorName, invalidSet, extraInputIndexes, net); } void Program::createUnit(std::map& varMap, std::vector& inputIndexes, const std::vector>& oplists, - MNN::OpT* op, const std::vector& tensorName, std::set& invalidSet, std::set& extraInputIndexes, const MNN::NetT* net) { + MNN::OpT* op, const std::vector& tensorName, std::set& invalidSet, std::set& extraInputIndexes, const MNN::NetT* net, std::map TensorDescribeName) { if (invalidSet.find(op) != invalidSet.end()) { return; } @@ -46,7 +46,7 @@ void Program::createUnit(std::map& varMap, std::vector& inputInd for (int j = 0; j < oplists.size(); ++j) { for (auto outputIndex : oplists[j]->outputIndexes) { if (outputIndex == input) { - createUnit(varMap, inputIndexes, oplists, oplists[j].get(), tensorName, invalidSet, extraInputIndexes, net); + createUnit(varMap, inputIndexes, oplists, oplists[j].get(), 
tensorName, invalidSet, extraInputIndexes, net, TensorDescribeName); } } } @@ -69,10 +69,11 @@ void Program::createUnit(std::map& varMap, std::vector& inputInd } auto newVar = Variable::create(expr, j); newVar->setName(tensorName[outputIndexes[j]]); - if (op->type != OpType_ConvertTensor && nullptr != net && !net->extraTensorDescribe.empty()) { + if (nullptr != net && !net->extraTensorDescribe.empty()) { auto& extraDescribes = net->extraTensorDescribe; - int idx = outputIndexes[j]; - if (idx < extraDescribes.size() && nullptr != extraDescribes[idx] && nullptr != extraDescribes[idx]->quantInfo) { +// int idx = outputIndexes[j]; + if (TensorDescribeName.find(op->name) != TensorDescribeName.end()) { + int idx = TensorDescribeName[op->name]; float scale = extraDescribes[idx]->quantInfo->scale; float zero = extraDescribes[idx]->quantInfo->zero; newVar->writeScaleMap(scale, zero); @@ -112,9 +113,15 @@ std::shared_ptr Program::create(const std::vector> std::map varMap; std::vector inputIndexes; std::set extraInputIndexes; + std::map TensorDescribeName; + if (net && net->extraTensorDescribe.size() > 0) { + for (int i = 0; i < net->extraTensorDescribe.size(); ++i) { + TensorDescribeName.insert(std::make_pair(net->extraTensorDescribe[i]->name, i)); + } + } for (int index = 0; index < oplists.size(); ++index) { std::set invalidSet; - createUnit(varMap, inputIndexes, oplists, oplists[index].get(), tensorName, invalidSet, extraInputIndexes, net); + createUnit(varMap, inputIndexes, oplists, oplists[index].get(), tensorName, invalidSet, extraInputIndexes, net, TensorDescribeName); } std::map outputs; for (auto& iter : varMap) { diff --git a/tools/converter/source/optimizer/Program.hpp b/tools/converter/source/optimizer/Program.hpp index 6f27c6ab0..4a7cefe9d 100644 --- a/tools/converter/source/optimizer/Program.hpp +++ b/tools/converter/source/optimizer/Program.hpp @@ -36,7 +36,7 @@ class Program { void save(MNN::NetT* net); private: static std::shared_ptr create(const std::vector>& oplists, const std::vector& tensorName, const std::vector& outputName, bool supportExtra, bool saveAllVars, const MNN::NetT* net=nullptr); - static void createUnit(std::map& varMap, std::vector& inputIndexes, const std::vector>& oplists, MNN::OpT* op, const std::vector& tensorName, std::set& invalidSet, std::set& extraInputIndexes, const MNN::NetT* net=nullptr); + static void createUnit(std::map& varMap, std::vector& inputIndexes, const std::vector>& oplists, MNN::OpT* op, const std::vector& tensorName, std::set& invalidSet, std::set& extraInputIndexes, const MNN::NetT* net=nullptr, std::map TensorDescribeName = {}); Program() { } std::map mVars; diff --git a/tools/converter/source/optimizer/TemplateMerge.cpp b/tools/converter/source/optimizer/TemplateMerge.cpp index 4cdfeb537..d187dda18 100644 --- a/tools/converter/source/optimizer/TemplateMerge.cpp +++ b/tools/converter/source/optimizer/TemplateMerge.cpp @@ -156,10 +156,6 @@ bool TemplateMerge::onExecute(const std::vector& outputs, PassPriority pri } else { invalidVARP.insert(var); } - if (var->get() && var->get()->type() == 19) { - auto updateInputs = updateInputVarOfExpr(var); - updateVars.insert(updateInputs.begin(), updateInputs.end()); - } } } MNN::Express::ExecutorScope::Current()->gc(); diff --git a/tools/converter/source/optimizer/merge/ConvDeQuantizeLinearFuseToConvInt8.cpp b/tools/converter/source/optimizer/merge/ConvDeQuantizeLinearFuseToConvInt8.cpp index 849063dff..83d45af22 100644 --- 
a/tools/converter/source/optimizer/merge/ConvDeQuantizeLinearFuseToConvInt8.cpp +++ b/tools/converter/source/optimizer/merge/ConvDeQuantizeLinearFuseToConvInt8.cpp @@ -24,6 +24,17 @@ static VARP _ReshapeF(VARP x, VARP shape, MNN::MNN_DATA_FORMAT format) { reshape->main.AsReshape()->dimType = format; return (Variable::create(Expr::create(reshape.get(), {x, shape}))); } + +static VARP _ConvertF(VARP input, MNN::MNN_DATA_FORMAT format) { + std::unique_ptr convert(new OpT); + convert->type = OpType_ConvertTensor; + convert->main.type = OpParameter_TensorConvertInfo; + convert->main.value = new TensorConvertInfoT; + convert->main.AsTensorConvertInfo()->source = MNN_DATA_FORMAT_NC4HW4; + convert->main.AsTensorConvertInfo()->dest = format; + return (Variable::create(Expr::create(convert.get(), {input}))); +} + static bool matchConvInt8ToOther(EXPRP expr, int i) { // convint8->quant->cast->dequant->other // check op type not convint8. if (nullptr == expr->get()) { @@ -60,9 +71,16 @@ static bool matchConvInt8ToOther(EXPRP expr, int i) { // convint8->quant->cast-> VARP conv_var = quan_expr->inputs().at(0); EXPRP conv_expr = conv_var->expr().first; - if (!conv_expr->get() || (conv_expr->get()->type() != OpType_ConvInt8 && conv_expr->get()->type() != OpType_DepthwiseConvInt8)) { + if (!conv_expr->get() || (conv_expr->get()->type() != OpType_ConvInt8 && conv_expr->get()->type() != OpType_DepthwiseConvInt8 && conv_expr->get()->type() != OpType_ReLU && conv_expr->get()->type() != OpType_ReLU6)) { return false; } + if (conv_expr->get()->type() == OpType_ReLU || conv_expr->get()->type() == OpType_ReLU6) { + conv_var = conv_expr->inputs().at(0); + conv_expr = conv_var->expr().first; + if (!conv_expr->get() || (conv_expr->get()->type() != OpType_ConvInt8 && conv_expr->get()->type() != OpType_DepthwiseConvInt8)) { + return false; + } + } return true; } static VARP transformConvInt8ToOther(EXPRP expr, int i) { // convint8->quant->cast->dequant->other => convInt8(float output)->other @@ -75,37 +93,74 @@ static VARP transformConvInt8ToOther(EXPRP expr, int i) { // convint8->quant->ca auto conv_var = quan_expr->inputs().at(0); auto conv_expr = conv_var->expr().first; auto convInt8Input = conv_expr->inputs().at(0); + bool hasRelu = false, hasRelu6 = false; + if (conv_expr->get()->type() == OpType_ReLU || conv_expr->get()->type() == OpType_ReLU6) { + hasRelu = conv_expr->get()->type() == OpType_ReLU ? true : false; + hasRelu6 = conv_expr->get()->type() == OpType_ReLU6 ? 
true : false; + conv_expr = convInt8Input->expr().first; + convInt8Input = conv_expr->inputs().at(0); + } // change old convInt8 to return a float value, which is input to expr; std::unique_ptr newConvInt8(new MNN::Convolution2DT); std::unique_ptr oldConvOp(conv_expr->get()->UnPack()); auto oldConvParams = oldConvOp->main.AsConvolution2D(); + + float output_zero = oldConvParams->symmetricQuan->outputZeroPoint; + float output_scale = oldConvParams->quanParameter->scaleOut; + float input_scale = oldConvParams->quanParameter->scaleIn; + float input_zero = oldConvParams->symmetricQuan->zeroPoint; + newConvInt8->common.reset(new MNN::Convolution2DCommonT); newConvInt8->common = std::move(oldConvParams->common); + newConvInt8->common->relu = hasRelu; + newConvInt8->common->relu6 = hasRelu6; newConvInt8->symmetricQuan.reset(new QuantizedFloatParamT); newConvInt8->symmetricQuan = std::move(oldConvParams->symmetricQuan); - newConvInt8->symmetricQuan->outputDataType = MNN::DataType_DT_FLOAT; - // newConvInt8->bias = std::move(oldConvParams->bias); - // newConvInt8->quanParameter = std::move(oldConvParams->quanParameter); - - //Update newConvInt8 scale - float outputScale = quan_expr->inputs().at(2)->readMap()[0]; - int oc = static_cast(newConvInt8->symmetricQuan->scale.size()); - float* ptr = newConvInt8->symmetricQuan->scale.data(); - for (int i = 0; i < oc; ++i) { - ptr[i] = ptr[i] * outputScale; - } + //newConvInt8->symmetricQuan->outputDataType = MNN::DataType_DT_FLOAT; + newConvInt8->quanParameter.reset(new IDSTQuanT); + newConvInt8->bias = std::move(oldConvParams->bias); + newConvInt8->quanParameter = std::move(oldConvParams->quanParameter); std::unique_ptr conv_op(new OpT); conv_op->name = conv_expr->name(); - conv_op->type = oldConvOp->type; + conv_op->type = OpType_ConvInt8; conv_op->main.type = OpParameter_Convolution2D; conv_op->main.value = newConvInt8.release(); + convInt8Input->writeScaleMap(input_scale, input_zero); auto newconv_expr = Expr::create(conv_op.get(), {convInt8Input}); newconv_expr->setName(conv_expr->name()); auto newconv_var = Variable::create(newconv_expr); newconv_var->setName(conv_expr->outputName(0)); + newconv_var->writeScaleMap(output_scale, output_zero); + if (conv_expr->inputs().size() == 5) { // Process matmul output + auto config = Global::Get(); + auto format = MNN::MNN_DATA_FORMAT_NCHW; + if (config->model == modelConfig::TFLITE || config->model == modelConfig::TENSORFLOW) { + format = MNN_DATA_FORMAT_NHWC; + } + // expr->inputs = {input, concat, needSqueezeA, needSqueezeB, transposeA} + auto concat_var = conv_expr->inputs().at(1); + bool needSqueezeA = conv_expr->inputs().at(2)->readMap()[0] > 0.f; + bool needSqueezeB = conv_expr->inputs().at(3)->readMap()[0] > 0.f; + + auto output = _ConvertF(newconv_var, format); + output->writeScaleMap(output_scale, output_zero); + VARP reshapeVar = _ReshapeF(output, concat_var, format); + reshapeVar->writeScaleMap(output_scale, output_zero); + if (needSqueezeA) { + reshapeVar = _Squeeze(reshapeVar, {0}); + reshapeVar->writeScaleMap(output_scale, output_zero); + } + if (needSqueezeB) { + reshapeVar = _Squeeze(reshapeVar, {1}); + reshapeVar->writeScaleMap(output_scale, output_zero); + } + reshapeVar->setName(expr->outputName(0) + "__matmul_cvt_convInt8_reshape"); + Expr::replace(conv_expr, reshapeVar->expr().first); + return reshapeVar; + } Expr::replace(conv_expr, newconv_expr); return newconv_var; @@ -162,9 +217,12 @@ static VARP transformOtherToOther (EXPRP expr, int i) { // ohter->quant->cast->d auto cast_expr = 
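When the rewritten conv actually came from a converted MatMul, its expression carries five inputs {input, target shape, needSqueezeA, needSqueezeB, transposeA}, and the pass restores the original layout by reshaping to the recorded shape and squeezing the padded axes. The sketch below is only a shape-level view of that restore step (plain int vectors instead of _ReshapeF/_Squeeze on variables), to make the axis bookkeeping concrete:

```cpp
#include <iostream>
#include <vector>

// Shape-level view of the matmul-output restore: reshape to the recorded target
// shape, then squeeze the axes that were only padding for the 1x1 convolution.
std::vector<int> restoreMatMulShape(std::vector<int> shape, bool needSqueezeA, bool needSqueezeB) {
    // `shape` plays the role of the concatenated target shape fed to _ReshapeF.
    if (needSqueezeA) {
        shape.erase(shape.begin());          // _Squeeze(reshapeVar, {0})
    }
    if (needSqueezeB) {
        shape.erase(shape.begin() + 1);      // _Squeeze(reshapeVar, {1})
    }
    return shape;
}

int main() {
    // e.g. a padded [1, 1, 64, 128] result whose leading batch axis was artificial.
    auto shape = restoreMatMulShape({1, 1, 64, 128}, /*needSqueezeA=*/true, /*needSqueezeB=*/false);
    for (int d : shape) {
        std::cout << d << " ";               // prints: 1 64 128
    }
    std::cout << "\n";
    return 0;
}
```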
cast_var->expr().first; auto quan_var = cast_expr->inputs().at(0); auto quan_expr = quan_var->expr().first; - auto other_var = quan_expr->inputs().at(0); + auto input_var = quan_expr->inputs().at(0); - return other_var; + float scale = quan_expr->inputs().at(2)->readMap()[0]; + float zero = quan_expr->inputs().at(3)->readMap()[0]; + input_var->writeScaleMap(scale, zero); + return input_var; } static VARP buildInputForMatmulInt8 (VARP input, VARP transposeA, VARP SqueezeA, int num_input) { auto transposeAType = transposeA->expr().first; @@ -197,6 +255,41 @@ static VARP buildInputForMatmulInt8 (VARP input, VARP transposeA, VARP SqueezeA, return newInput; } +static EXPRP buildNewConvExpr(EXPRP oldConvExpr, VARP convInput, std::vector updateInfo = {}) { + std::unique_ptr newConvInt8(new MNN::Convolution2DT); + std::unique_ptr oldConvOp(oldConvExpr->get()->UnPack()); + auto oldConvParams = oldConvOp->main.AsConvolution2D(); + newConvInt8->common.reset(new MNN::Convolution2DCommonT); + newConvInt8->common = std::move(oldConvParams->common); + newConvInt8->symmetricQuan.reset(new QuantizedFloatParamT); + newConvInt8->symmetricQuan = std::move(oldConvParams->symmetricQuan); + newConvInt8->quanParameter.reset(new IDSTQuanT); + newConvInt8->quanParameter = std::move(oldConvParams->quanParameter); + newConvInt8->bias = std::move(oldConvParams->bias); + + if (updateInfo.size() > 0) { + newConvInt8->common->relu = updateInfo[0] ? true : false; + } + if (updateInfo.size() > 1) { + newConvInt8->common->relu6 = updateInfo[1] ? true : false; + } + if (updateInfo.size() > 2) { + newConvInt8->symmetricQuan->outputDataType = updateInfo[2] ? DataType_DT_FLOAT : DataType_DT_INT8; + } + float input_scale = newConvInt8->quanParameter->scaleIn; + float input_zero = newConvInt8->symmetricQuan->zeroPoint; + convInput->writeScaleMap(input_scale, input_zero); + + std::unique_ptr conv_op(new OpT); + conv_op->name = oldConvExpr->name(); + conv_op->type = oldConvOp->type; + conv_op->main.type = OpParameter_Convolution2D; + conv_op->main.value = newConvInt8.release(); + + auto new_conv_expr = Expr::create(conv_op.get(), {convInput}); + return new_conv_expr; +} + static auto gRegister = []() { // convInt8->(relu)->quant->cast->dequant->convInt8 auto matchConvInt8ToConvInt8 = [](EXPRP expr) { // check convInt8 @@ -259,33 +352,98 @@ static auto gRegister = []() { // convInt8->(relu)->quant->cast->dequant->convIn auto quan_var = cast_expr->inputs().at(0); auto quan_expr = quan_var->expr().first; auto convInt8Input = quan_expr->inputs().at(0); - if (expr->inputs().size() == 3) { + /* conv params*/ + std::unique_ptr newConvInt8(new MNN::Convolution2DT); + std::unique_ptr oldConvOp(expr->get()->UnPack()); + auto oldConvParams = oldConvOp->main.AsConvolution2D(); + float input_scale = oldConvParams->quanParameter->scaleIn; + float input_zero = oldConvParams->symmetricQuan->zeroPoint; + /* check */ + auto conv_var = quan_expr->inputs().at(0); + conv_var->writeScaleMap(input_scale, input_zero); + EXPRP conv_expr = conv_var->expr().first; + VARP first_conv_input_var = conv_expr->inputs().at(0); + if (conv_expr->get()->type() == OpType_PReLU || conv_expr->get()->type() == OpType_ReLU || conv_expr->get()->type() == OpType_ReLU6) { + auto relu_expr = conv_expr; + bool relu_ = relu_expr->get()->type() == OpType_ReLU ? true: false; + bool relu6_ = relu_expr->get()->type() == OpType_ReLU6 ? 
true: false; + VARP conv_var_0 = relu_expr->inputs().at(0); + conv_expr = conv_var_0->expr().first; + first_conv_input_var = conv_expr->inputs().at(0); + auto newFirstConvExpr = buildNewConvExpr(conv_expr, first_conv_input_var, {relu_, relu6_}); // write scale for first_conv_input_var + Expr::replace(conv_expr, newFirstConvExpr); + convInt8Input = Variable::create(conv_expr); + conv_var = convInt8Input; + conv_var->writeScaleMap(input_scale, input_zero); + } else { + auto newFirstConvExpr = buildNewConvExpr(conv_expr, first_conv_input_var); // Just write scale for first_conv_input_var, do not update conv info. + Expr::replace(conv_expr, newFirstConvExpr); + convInt8Input = Variable::create(conv_expr); + conv_var = convInt8Input; + conv_var->writeScaleMap(input_scale, input_zero); + } + if (conv_expr->inputs().size() == 5) { + // Process matmul output + auto config = Global::Get(); + auto format = MNN::MNN_DATA_FORMAT_NCHW; + if (config->model == modelConfig::TFLITE || config->model == modelConfig::TENSORFLOW) { + format = MNN_DATA_FORMAT_NHWC; + } + // expr->inputs = {input, concat, needSqueezeA, needSqueezeB, transposeA} + auto concat_var = conv_expr->inputs().at(1); + bool needSqueezeA = conv_expr->inputs().at(2)->readMap()[0] > 0.f; + bool needSqueezeB = conv_expr->inputs().at(3)->readMap()[0] > 0.f; + + auto output = _ConvertF(conv_var, format); + output->writeScaleMap(input_scale, input_zero); + + VARP reshapeVar = _ReshapeF(output, concat_var, format); + reshapeVar->writeScaleMap(input_scale, input_zero); + if (needSqueezeA) { + reshapeVar = _Squeeze(reshapeVar, {0}); + } + if (needSqueezeB) { + reshapeVar = _Squeeze(reshapeVar, {1}); + } + reshapeVar->setName(conv_expr->outputName(0) + "__matmul_cvt_convInt8_reshape"); + Expr::replace(conv_expr, reshapeVar->expr().first); + convInt8Input = reshapeVar; + convInt8Input->writeScaleMap(input_scale, input_zero); + } + + if (expr->inputs().size() == 5) { auto matmulop = expr->get(); auto count_input = matmulop->main_as_Convolution2D()->common()->inputCount(); - convInt8Input = buildInputForMatmulInt8(convInt8Input, expr->inputs().at(1), expr->inputs().at(2), count_input); + convInt8Input = buildInputForMatmulInt8(convInt8Input, expr->inputs().at(4), expr->inputs().at(2), count_input); + convInt8Input->writeScaleMap(input_scale, input_zero); } - std::unique_ptr newConvInt8(new MNN::Convolution2DT); - std::unique_ptr oldConvOp(expr->get()->UnPack()); - auto oldConvParams = oldConvOp->main.AsConvolution2D(); + newConvInt8->common.reset(new MNN::Convolution2DCommonT); newConvInt8->common = std::move(oldConvParams->common); newConvInt8->symmetricQuan.reset(new QuantizedFloatParamT); newConvInt8->symmetricQuan = std::move(oldConvParams->symmetricQuan); - // newConvInt8->bias = std::move(oldConvParams->bias); - // newConvInt8->quanParameter = std::move(oldConvParams->quanParameter); + newConvInt8->quanParameter.reset(new IDSTQuanT); + newConvInt8->quanParameter = std::move(oldConvParams->quanParameter); + newConvInt8->bias = std::move(oldConvParams->bias); + float scaleout = newConvInt8->quanParameter->scaleOut; + float zeroout = newConvInt8->symmetricQuan->outputZeroPoint; std::unique_ptr conv_op(new OpT); conv_op->name = expr->name(); conv_op->type = oldConvOp->type; conv_op->main.type = OpParameter_Convolution2D; conv_op->main.value = newConvInt8.release(); + - auto conv_expr = Expr::create(conv_op.get(), {convInt8Input}); - conv_expr->setName(expr->name()); -// auto conv_var = Variable::create(conv_expr); -// 
conv_var->setName(expr->outputName(0)); - Expr::replace(expr, conv_expr); + auto new_conv_expr = Expr::create(conv_op.get(), {convInt8Input}); + if (expr->inputs().size() == 5) { + new_conv_expr = Expr::create(conv_op.get(), {convInt8Input, expr->inputs()[1], expr->inputs()[2], expr->inputs()[3], expr->inputs()[4]}); + } + new_conv_expr->setName(expr->name()); + auto new_conv_var = Variable::create(new_conv_expr); + new_conv_var->writeScaleMap(scaleout, zeroout); + Expr::replace(expr, new_conv_expr); return true; }; @@ -341,31 +499,46 @@ static auto gRegister = []() { // convInt8->(relu)->quant->cast->dequant->convIn auto cast_expr = cast_var->expr().first; auto quan_var = cast_expr->inputs().at(0); auto quan_expr = quan_var->expr().first; - auto convInt8Input = quan_expr->inputs().at(1); - if (expr->inputs().size() == 3) { // The convInt8 comes from matmul. + auto convInt8Input = quan_expr->inputs().at(0); + auto other_var = convInt8Input; + if (expr->inputs().size() == 5) { + // [input,concat,squeezeA,squeezeB,transposeA] auto matmulop = expr->get(); auto count_input = matmulop->main_as_Convolution2D()->common()->inputCount(); - auto matmulInput = expr->inputs().at(0); - convInt8Input = buildInputForMatmulInt8(convInt8Input, expr->inputs().at(1), expr->inputs().at(2), count_input); + convInt8Input = buildInputForMatmulInt8(convInt8Input, expr->inputs().at(4), expr->inputs().at(2), count_input); + convInt8Input->setName(expr->name() + "__matmul_converted_input"); } std::unique_ptr newConvInt8(new MNN::Convolution2DT); std::unique_ptr oldConvOp(expr->get()->UnPack()); auto oldConvParams = oldConvOp->main.AsConvolution2D(); + float input_scale = oldConvParams->quanParameter->scaleIn; + float output_scale = oldConvParams->quanParameter->scaleOut; + float input_zero = static_cast(oldConvParams->symmetricQuan->zeroPoint); + float output_zero = static_cast(oldConvParams->symmetricQuan->outputZeroPoint); + newConvInt8->common.reset(new MNN::Convolution2DCommonT); newConvInt8->common = std::move(oldConvParams->common); newConvInt8->symmetricQuan.reset(new QuantizedFloatParamT); newConvInt8->symmetricQuan = std::move(oldConvParams->symmetricQuan); - // newConvInt8->bias = std::move(oldConvParams->bias); - // newConvInt8->quanParameter = std::move(oldConvParams->quanParameter); + newConvInt8->bias = std::move(oldConvParams->bias); + newConvInt8->quanParameter.reset(new IDSTQuanT); + newConvInt8->quanParameter = std::move(oldConvParams->quanParameter); std::unique_ptr conv_op(new OpT); conv_op->name = expr->name(); conv_op->type = oldConvOp->type; conv_op->main.type = OpParameter_Convolution2D; conv_op->main.value = newConvInt8.release(); - + + other_var->writeScaleMap(input_scale, input_zero); + convInt8Input->writeScaleMap(input_scale, input_zero); auto conv_expr = Expr::create(conv_op.get(), {convInt8Input}); + if (expr->inputs().size() == 5) { + conv_expr = Expr::create(conv_op.get(), {convInt8Input, expr->inputs()[1], expr->inputs()[2], expr->inputs()[3], expr->inputs()[4]}); + } + auto conv_var = Variable::create(conv_expr); + conv_var->writeScaleMap(output_scale, output_zero); conv_expr->setName(expr->name()); Expr::replace(expr, conv_expr); return true; @@ -389,7 +562,7 @@ static auto gRegister = []() { // convInt8->(relu)->quant->cast->dequant->convIn } return true; }; - auto transformXToOther = [](EXPRP expr) { // ohter->quant->cast->dequant->other => other->other + auto transformXToOther = [](EXPRP expr) { // X->quant->cast->dequant->output_other => X->output_other int input_size = 
static_cast(expr->inputs().size()); std::vector new_inputs(input_size); for (int i = 0; i < input_size; ++i) { @@ -473,23 +646,6 @@ static auto gRegister = []() { // convInt8->(relu)->quant->cast->dequant->convIn auto X_expr = X_var->expr().first; bool convInt8End = X_expr->get()->type() == OpType_ConvInt8; - bool hasReshape = X_expr->get()->type() == OpType_Reshape; - if (X_expr->get()->type() == OpType_Reshape) { - auto convert_var = X_expr->inputs().at(0); - auto convert_expr = convert_var->expr().first; - if (convert_expr->get() && convert_expr->get()->type() == OpType_ConvertTensor) { - auto convint8_var = convert_expr->inputs().at(0); - auto convint8_expr = convint8_var->expr().first; - if (convint8_expr->get() && convint8_expr->get()->type() == OpType_ConvInt8) { - convInt8End = true; - X_expr = std::move(convint8_expr); - } - } - if (convert_expr->get() && convert_expr->get()->type() == OpType_ConvInt8) { - convInt8End = true; - X_expr = std::move(convert_expr); - } - } if (convInt8End) { auto convInt8Input = X_expr->inputs().at(0); @@ -500,17 +656,13 @@ static auto gRegister = []() { // convInt8->(relu)->quant->cast->dequant->convIn newConvInt8->common = std::move(oldConvParams->common); newConvInt8->symmetricQuan.reset(new QuantizedFloatParamT); newConvInt8->symmetricQuan = std::move(oldConvParams->symmetricQuan); - newConvInt8->symmetricQuan->outputDataType = DataType_DT_FLOAT; // If convInt8 is the last op, float value is the torch-fx model's output. - // newConvInt8->bias = std::move(oldConvParams->bias); - // newConvInt8->quanParameter = std::move(oldConvParams->quanParameter); + newConvInt8->quanParameter.reset(new IDSTQuanT); + //newConvInt8->symmetricQuan->outputDataType = DataType_DT_FLOAT; // If convInt8 is the last op, float value is the torch-fx model's output. + newConvInt8->bias = std::move(oldConvParams->bias); + newConvInt8->quanParameter = std::move(oldConvParams->quanParameter); - //Update convInt8 scale. 
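transformXToOther (and transformOtherToOther above it) splice an X->quant->cast->dequant->Y round trip down to X->Y, but now also record the dequant's scale and zero point on the surviving variable via writeScaleMap so later passes still see the quant params. A toy list-based version of that splice-and-remember idea, with plain structs instead of Express expressions:

```cpp
#include <iostream>
#include <map>
#include <string>
#include <vector>

// Toy op list; the real pass rewrites Express expressions, this only shows the
// "splice out the round trip but remember its scale/zero" idea.
struct Op {
    std::string name;
    std::string type;
    float scale;
    float zero;
};

int main() {
    std::vector<Op> chain = {
        {"x",    "Conv",        0.0f,  0.0f},
        {"q",    "FloatToInt8", 0.02f, -1.0f},   // quantize op carries the scale/zero inputs
        {"cast", "Cast",        0.0f,  0.0f},
        {"dq",   "Int8ToFloat", 0.02f, -1.0f},
        {"y",    "ReLU",        0.0f,  0.0f},
    };

    // Record the quant params on the op that will feed "y" directly (writeScaleMap analog) ...
    std::map<std::string, std::pair<float, float>> scaleMap;
    scaleMap["x"] = {chain[1].scale, chain[1].zero};

    // ... then drop the quant/cast/dequant round trip.
    chain.erase(chain.begin() + 1, chain.begin() + 4);

    for (const auto& op : chain) {
        std::cout << op.name << " ";                       // prints: x y
    }
    std::cout << "\nx scale=" << scaleMap["x"].first << " zero=" << scaleMap["x"].second << "\n";
    return 0;
}
```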
- float outputScale = quan_expr->inputs().at(2)->readMap()[0]; - int oc = static_cast(newConvInt8->symmetricQuan->scale.size()); - float* ptr = newConvInt8->symmetricQuan->scale.data(); - for (int i = 0; i < oc; ++i) { - ptr[i] = ptr[i] * outputScale; - } + float output_scale = newConvInt8->quanParameter->scaleOut; + float output_zero = newConvInt8->symmetricQuan->outputZeroPoint; std::unique_ptr conv_op(new OpT); conv_op->name = X_expr->name(); @@ -519,23 +671,51 @@ static auto gRegister = []() { // convInt8->(relu)->quant->cast->dequant->convIn conv_op->main.value = newConvInt8.release(); auto conv_expr = Expr::create(conv_op.get(), {convInt8Input}); + auto conv_var = Variable::create(conv_expr); + conv_var->writeScaleMap(output_scale, output_zero); + if (X_expr->inputs().size() == 5) { + // Process matmul output + auto config = Global::Get(); + auto format = MNN::MNN_DATA_FORMAT_NCHW; + if (config->model == modelConfig::TFLITE || config->model == modelConfig::TENSORFLOW) { + format = MNN_DATA_FORMAT_NHWC; + } + + conv_var->setName(X_expr->outputName(0)); +// newconv_var->setName(conv_expr->outputName(0)); + // expr->inputs = {input, concat, needSqueezeA, needSqueezeB, transposeA} + auto concat_var = X_expr->inputs().at(1); + bool needSqueezeA = X_expr->inputs().at(2)->readMap()[0] > 0.f; + bool needSqueezeB = X_expr->inputs().at(3)->readMap()[0] > 0.f; + + auto output = _ConvertF(conv_var, format); + output->writeScaleMap(output_scale, output_zero); + VARP reshapeVar = _ReshapeF(output, concat_var, format); + reshapeVar->writeScaleMap(output_scale, output_zero); + if (needSqueezeA) { + reshapeVar = _Squeeze(reshapeVar, {0}); + reshapeVar->writeScaleMap(output_scale, output_zero); + } + if (needSqueezeB) { + reshapeVar = _Squeeze(reshapeVar, {1}); + reshapeVar->writeScaleMap(output_scale, output_zero); + } + reshapeVar->setName(expr->name()); + Expr::replace(expr, reshapeVar->expr().first); + return true; + } conv_expr->setName(expr->name()); - - if (hasReshape) { - conv_expr->setName(X_expr->name()); - std::unique_ptr reshapeOp(X_var->expr().first->get()->UnPack()); - auto new_reshape_expr = Expr::create(reshapeOp.get(), X_var->expr().first->inputs()); - new_reshape_expr->setName(expr->name()); - Expr::replace(expr, new_reshape_expr); - } - Expr::replace(X_expr, conv_expr); + Expr::replace(expr, conv_expr); return true; } - + float output_scale = quan_expr->inputs().at(2)->readMap()[0]; + float output_zero = quan_expr->inputs().at(3)->readMap()[0]; // directly return the op output. 
std::unique_ptr oldOtherOp(X_expr->get()->UnPack()); auto newop_expr = Expr::create(oldOtherOp.get(), X_expr->inputs()); newop_expr->setName(expr->name()); + auto newop_var = Variable::create(newop_expr); + newop_var->writeScaleMap(output_scale, output_zero); Expr::replace(expr, newop_expr); return true; }; diff --git a/tools/converter/source/optimizer/merge/ConvertMatMulToConv2D.cpp b/tools/converter/source/optimizer/merge/ConvertMatMulToConv2D.cpp index 0330a7342..6138f14a0 100644 --- a/tools/converter/source/optimizer/merge/ConvertMatMulToConv2D.cpp +++ b/tools/converter/source/optimizer/merge/ConvertMatMulToConv2D.cpp @@ -479,17 +479,19 @@ ConvertMatMulToConv2D::ConvertMatMulToConv2D() { } } auto matmulInput = matmul_expr->inputs().at(0); - auto inputScale = matmul_expr->inputs().at(2); - auto inputZero = matmul_expr->inputs().at(3); + auto inputScale = matmul_expr->inputs().at(2); + auto inputZero = matmul_expr->inputs().at(3); auto weightScale = matmul_expr->inputs().at(4); + auto weightZero = matmul_expr->inputs().at(5); auto outputScale = matmul_expr->inputs().at(6); - auto outputZero = matmul_expr->inputs().at(7); + auto outputZero = matmul_expr->inputs().at(7); - float input_zero = inputZero->readMap()[0]; - float input_scale = inputScale->readMap()[0]; + float input_zero = inputZero->readMap()[0]; + float input_scale = inputScale->readMap()[0]; const float* weight_scale = weightScale->readMap(); - float output_scale = outputScale->readMap()[0]; - uint8_t output_zero = outputZero->readMap()[0]; + const float* weight_zero = weightZero->readMap(); + float output_scale = outputScale->readMap()[0]; + int output_zero = static_cast(outputZero->readMap()[0]); // Convint8 std::unique_ptr dense(new MNN::Convolution2DT); dense->common.reset(new MNN::Convolution2DCommonT); @@ -502,42 +504,29 @@ ConvertMatMulToConv2D::ConvertMatMulToConv2D() { dense->symmetricQuan->clampMax = 127; dense->symmetricQuan->zeroPoint = static_cast(input_zero); dense->symmetricQuan->outputZeroPoint = static_cast(output_zero); - // weight and bias - auto weight_ptr = weight->readMap(); - dense->symmetricQuan->weight.resize(weightInfo->size); - memcpy(dense->symmetricQuan->weight.data(), weight_ptr, weightInfo->size * sizeof(int8_t)); - std::vector weightKenelSum(numberOutput); - int kernelSize = weightInfo->size / numberOutput; - for (int i = 0; i < numberOutput; i++) { - int temp = 0; - int offset = i * kernelSize; - for (int j = 0; j < kernelSize; j++) { - temp += int(weight_ptr[offset + j]); - } - weightKenelSum[i] = temp; - } - - - dense->symmetricQuan->bias.resize(numberOutput, 0); - // compute conv scale=input_scale * weight_scale / output_scale - std::vector conv_scale(numberOutput); - for (int k = 0; k < numberOutput; ++k) { - if (output_scale != 0) { - conv_scale[k] = input_scale * weight_scale[k] / output_scale; - } else { - conv_scale[k] = 0.f; - } + // quantParameter + dense->quanParameter.reset(new IDSTQuanT); + dense->quanParameter->scaleIn = input_scale; + dense->quanParameter->scaleOut = output_scale; + dense->quanParameter->type = 4; + dense->quanParameter->aMin = -128; + dense->quanParameter->readType = numberOutput; + dense->quanParameter->quantScale = 1.0f; + dense->quanParameter->buffer.resize(weightInfo->size); + ::memcpy(dense->quanParameter->buffer.data(), weight->readMap(), weightInfo->size * sizeof(int8_t)); + dense->bias.resize(numberOutput, 0); + // quan alpha + dense->quanParameter->alpha.resize(2 * numberOutput); + for (int i = 0; i < numberOutput; ++i) { + 
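The MatMulInt8ToConvInt8 fold reads its quantization parameters positionally from the matmul expression's inputs: index 2/3 are the input scale and zero, 4/5 the per-channel weight scale and zero, 6/7 the output scale and zero, and 8 the optional bias. A small unpacking sketch with float vectors standing in for Express VARPs, just to make that index convention explicit:

```cpp
#include <cstddef>
#include <iostream>
#include <vector>

// Quant info unpacked from the converted matmul's inputs, following the index
// convention this pass reads: 2 = input scale, 3 = input zero, 4 = weight scale,
// 5 = weight zero (both per output channel), 6 = output scale, 7 = output zero,
// 8 = optional bias (present when the expression has 9 inputs).
struct MatMulInt8Quant {
    float inputScale;
    float inputZero;
    std::vector<float> weightScale;
    std::vector<float> weightZero;
    float outputScale;
    int outputZero;
    std::vector<float> bias;
};

MatMulInt8Quant unpack(const std::vector<std::vector<float>>& inputs) {
    MatMulInt8Quant q;
    q.inputScale  = inputs[2][0];
    q.inputZero   = inputs[3][0];
    q.weightScale = inputs[4];
    q.weightZero  = inputs[5];
    q.outputScale = inputs[6][0];
    q.outputZero  = static_cast<int>(inputs[7][0]);
    if (inputs.size() > 8) {
        q.bias = inputs[8];
    }
    return q;
}

int main() {
    std::vector<std::vector<float>> inputs(9);   // slots 0/1 would be the data and weight tensors
    inputs[2] = {0.05f};        inputs[3] = {1.0f};
    inputs[4] = {0.01f, 0.02f}; inputs[5] = {0.0f, -3.0f};
    inputs[6] = {0.1f};         inputs[7] = {2.0f};
    inputs[8] = {10.0f, -4.0f};

    auto q = unpack(inputs);
    std::cout << "channels=" << q.weightScale.size() << " outputZero=" << q.outputZero << "\n";
    return 0;
}
```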
dense->quanParameter->alpha[2 * i] = (-1)*(weight_zero[i] + 128) * weight_scale[i]; + dense->quanParameter->alpha[2 * i + 1] = weight_scale[i]; } + if (matmul_expr->inputs().size() == 9) { bias_var = matmul_expr->inputs().at(8); - auto bias_ptr = bias_var->readMap(); - auto biasInt32 = dense->symmetricQuan->bias.data(); - for (int cnt = 0; cnt < numberOutput; ++cnt) { - biasInt32[cnt] = bias_ptr[cnt] - weightKenelSum[cnt] * static_cast(input_zero) + static_cast(static_cast(output_zero) / conv_scale[cnt]); - } -// memcpy(dense->symmetricQuan->bias.data(), bias_ptr, sizeof(int32_t) * numberOutput); + auto bias_ptr = bias_var->readMap(); + memcpy(dense->bias.data(), bias_ptr, sizeof(int32_t) * numberOutput); } - dense->symmetricQuan->scale = std::move(conv_scale); // Third, build convint8 op std::unique_ptr dense_op(new OpT); @@ -554,42 +543,27 @@ ConvertMatMulToConv2D::ConvertMatMulToConv2D() { VARP inputRemain = _StridedSlice(inputShape, _Unsqueeze(_Scalar(0), {0}), _Unsqueeze(rank - _Scalar(2), {0}), _Unsqueeze(_Scalar(1), {0}), 0, 0, 0, 0, 0); if (transposeA) { inputE = _Slice(inputShape, _Unsqueeze(rank - _Scalar(1), {0}), _Unsqueeze(_Scalar(1), {0})); - if (format == MNN_DATA_FORMAT_NHWC) { - input = _ReshapeF(input, _Concat({_Unsqueeze(_Scalar(-1), {0}), inputE, _Unsqueeze(_Scalar(1), {0}), inputL}, 0), format); - } else { - input = _ReshapeF(input, _Concat({_Unsqueeze(_Scalar(-1), {0}), inputL, inputE, _Unsqueeze(_Scalar(1), {0})}, 0), format); - } } else { inputE = _Slice(inputShape, _Unsqueeze(rank - _Scalar(2), {0}), _Unsqueeze(_Scalar(1), {0})); - if (format == MNN_DATA_FORMAT_NHWC) { - input = _ReshapeF(input, _Concat({_Unsqueeze(_Scalar(-1), {0}), _Unsqueeze(_Scalar(1), {0}), _Unsqueeze(_Scalar(1), {0}), inputL}, 0), format); - } else { - input = _ReshapeF(input, _Concat({_Unsqueeze(_Scalar(-1), {0}), inputL, _Unsqueeze(_Scalar(1), {0}), _Unsqueeze(_Scalar(1), {0})}, 0), format); - } } if (config->externalFile && weightInfo->size >= config->externalTreshold) { RemoveAndStoreParam(dense_op, config->externalFile, config->externalOffset); } - float ta = 0, sa = 0; + float ta = 0, sa = 0, sqzb = 0; if (transposeA) { ta = 1.0f; } if (needSqueezeA) { sa = 1.0f; } - EXPRP dense_expr = Expr::create(dense_op.get(), {matmul_input, _Const(ta), _Const(sa)}, 1); - VARP output = Variable::create(dense_expr); - output->setName(matmul_expr->outputName(0) + "__matmul_converted"); - output = _ConvertF(output, format); - VARP reshapeVar = _ReshapeF(output, _Concat({inputRemain, inputE, outputH}, 0), format); - if (needSqueezeA) { - reshapeVar = _Squeeze(reshapeVar, {0}); - } if (needSqueezeB) { - reshapeVar = _Squeeze(reshapeVar, {1}); + sqzb = 1.0f; } - reshapeVar->setName(matmul_expr->outputName(0) + "__matmul_cvt_convInt8"); - Expr::replace(matmul_expr, reshapeVar->expr().first); + EXPRP dense_expr = Expr::create(dense_op.get(), {matmul_input, _Concat({inputRemain, inputE, outputH}, 0), _Const(sa), _Const(sqzb), _Const(ta)}, 1); + VARP output = Variable::create(dense_expr); + // output->setName(matmul_expr->outputName(0)); + dense_expr->setName(matmul_expr->outputName(0) + "__matmul_converted"); + Expr::replace(matmul_expr, dense_expr); return true; }; TemplateMerge::getInstance("Merge").insertTemplateV2("MatMulInt8ToConvInt8", fold, PASS_PRIORITY_HIGH); diff --git a/tools/converter/source/optimizer/onnxextra/OnnxConvolutionMerge.cpp b/tools/converter/source/optimizer/onnxextra/OnnxConvolutionMerge.cpp index 18ba327ac..b122d2fb7 100644 --- 
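The new IDSTQuanT path stores per-channel (bias, scale) alpha pairs with alpha[2i] = -(weight_zero[i] + 128) * weight_scale[i] and alpha[2i+1] = weight_scale[i], alongside aMin = -128. One way to read that encoding (my interpretation, not spelled out in the patch) is the usual asymmetric dequant w_float = alpha_bias + (q - aMin) * alpha_scale, which algebraically equals (q - weight_zero) * weight_scale; the check below verifies that identity over the whole int8 range:

```cpp
#include <cassert>
#include <cmath>
#include <iostream>
#include <vector>

int main() {
    const int aMin = -128;
    std::vector<float> weightScale = {0.02f, 0.05f};
    std::vector<float> weightZero  = {3.0f, -7.0f};

    // alpha[2*i] = -(zero + 128) * scale, alpha[2*i + 1] = scale, as written above.
    std::vector<float> alpha(2 * weightScale.size());
    for (size_t i = 0; i < weightScale.size(); ++i) {
        alpha[2 * i]     = -(weightZero[i] + 128.0f) * weightScale[i];
        alpha[2 * i + 1] = weightScale[i];
    }

    // With aMin = -128, "bias + (q - aMin) * scale" reduces to "(q - zero) * scale"
    // for every int8 value; check channel 0 exhaustively.
    for (int q = -128; q <= 127; ++q) {
        float viaAlpha  = alpha[0] + (q - aMin) * alpha[1];
        float reference = (q - weightZero[0]) * weightScale[0];
        assert(std::fabs(viaAlpha - reference) < 1e-4f);
    }
    std::cout << "alpha pair encoding matches (q - zero) * scale on channel 0\n";
    return 0;
}
```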
a/tools/converter/source/optimizer/onnxextra/OnnxConvolutionMerge.cpp +++ b/tools/converter/source/optimizer/onnxextra/OnnxConvolutionMerge.cpp @@ -308,12 +308,14 @@ class OnnxConvolutionTransform : public OnnxExtraManager::Transform { } auto outputScaleVar = outputExpr->inputs()[1]; float outputScale = outputScaleVar->readMap()[0]; - if (hasRelu) { - outputScale = 1.0f; - } int8_t outputZero = 0; if (outputExpr->inputs().size() > 2) { - outputZero = static_cast(outputExpr->inputs()[2]->readMap()[0]); + if (outputExpr->inputs()[2]->getInfo()->type.code == halide_type_uint) { + outputZero = static_cast(outputExpr->inputs()[2]->readMap()[0] - 128); + } else { + outputZero = static_cast(outputExpr->inputs()[2]->readMap()[0]); + } + } // Get weight quant info. float inputClampMin = -128; @@ -337,13 +339,17 @@ class OnnxConvolutionTransform : public OnnxExtraManager::Transform { weightKenelSum[i] = temp; } std::vector biasInt32(common->outputCount, 0); + convParam->quanParameter.reset(new IDSTQuanT); + convParam->quanParameter->aMin = -128; + convParam->quanParameter->aMax = co; + convParam->quanParameter->readType = co; + convParam->quanParameter->type = 4; + convParam->quanParameter->buffer.resize(weightSize); + ::memcpy(convParam->quanParameter->buffer.data(), pw, weightSize * sizeof(int8_t)); + convParam->quanParameter->quantScale = 1.0f; + convParam->quanParameter->scaleOut = outputScale; convParam->symmetricQuan.reset(new QuantizedFloatParamT); - convParam->symmetricQuan->weight.resize(weightSize); - ::memcpy(convParam->symmetricQuan->weight.data(), pw, weightSize * sizeof(int8_t)); convParam->symmetricQuan->nbits = 8; - if (hasRelu) { - convParam->symmetricQuan->outputDataType = DataType_DT_FLOAT; - } // Get input quant info. auto inputExpr = inputs[0]->expr().first; @@ -352,32 +358,30 @@ class OnnxConvolutionTransform : public OnnxExtraManager::Transform { auto inputZeroVar = inputExpr->inputs()[3]; float inputScale = inputScaleVar->readMap()[0]; int8_t inputZero = static_cast(inputZeroVar->readMap()[0]); + + convParam->quanParameter->scaleIn = inputScale; + convParam->quanParameter->alpha.resize(2 * co); // Compute convInt8 scale=(inputScale * weightScale)/outputScale std::vector scale(co); auto weightScale = weightexpr->inputs().at(2); auto ptrscale = weightScale->readMap(); + auto weightZero = weightexpr->inputs().at(3); + auto ptrzero = weightZero->readMap(); for (int cnt = 0; cnt < co; ++cnt) { - if (outputScale != 0){ - scale[cnt] = ptrscale[cnt] * inputScale / outputScale; - } else { - scale[cnt] = 0.f; - } + convParam->quanParameter->alpha[2 * cnt + 1] = ptrscale[cnt]; + convParam->quanParameter->alpha[2 * cnt] = (-1)*(ptrzero[cnt] + 128) * ptrscale[cnt]; } + convParam->bias.resize(co); if (inputSize > 2) { auto biasExpr = inputs[2]->expr().first; - auto biasInt32Var = biasExpr->inputs()[0]; - auto ptr = biasInt32Var->readMap(); - if (!ptr) { + auto biasfp32Var = biasExpr->inputs()[1]; + if (biasfp32Var->readMap() == nullptr) { MNN_ERROR("Convolution bias should be constant\n"); return nullptr; } - for (int cnt = 0; cnt < co; ++cnt) { - biasInt32[cnt] = ptr[cnt] - weightKenelSum[cnt] * static_cast(inputZero) + static_cast(static_cast(outputZero) / scale[cnt]); - } + ::memcpy(convParam->bias.data(), biasfp32Var->readMap(), co * sizeof(float)); } - convParam->symmetricQuan->bias = std::move(biasInt32); - convParam->symmetricQuan->scale = std::move(scale); convParam->symmetricQuan->clampMax = 127; convParam->symmetricQuan->clampMin = -128; convParam->symmetricQuan->zeroPoint = 
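Both this file and ConvertMatMulToConv2D.cpp drop the old pre-folded int32 bias in favour of storing the float bias plus scaleIn/scaleOut in the IDSTQuanT block. For reference, the folding the removed path computed per output channel looked like the sketch below (with a divide-by-zero guard added that the original omitted):

```cpp
#include <cstdint>
#include <iostream>
#include <vector>

// The int32 bias pre-folding that the removed path computed per output channel:
//   conv_scale[c] = input_scale * weight_scale[c] / output_scale
//   bias_i32[c]   = bias[c] - weight_kernel_sum[c] * input_zero + output_zero / conv_scale[c]
// The replacement path instead keeps the float bias and the scaleIn/scaleOut pair
// in IDSTQuanT and leaves requantization to the backend.
std::vector<int32_t> foldBias(const std::vector<int32_t>& bias,
                              const std::vector<int32_t>& weightKernelSum,
                              const std::vector<float>& weightScale,
                              float inputScale, int inputZero,
                              float outputScale, int outputZero) {
    std::vector<int32_t> out(bias.size());
    for (size_t c = 0; c < bias.size(); ++c) {
        float convScale = outputScale != 0.0f ? inputScale * weightScale[c] / outputScale : 0.0f;
        float zeroTerm  = convScale != 0.0f ? outputZero / convScale : 0.0f;
        out[c] = bias[c] - weightKernelSum[c] * inputZero + static_cast<int32_t>(zeroTerm);
    }
    return out;
}

int main() {
    auto folded = foldBias({100, -50}, {40, 12}, {0.02f, 0.05f}, 0.1f, 2, 0.25f, -1);
    for (auto v : folded) {
        std::cout << v << " ";
    }
    std::cout << "\n";
    return 0;
}
```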
std::move(inputZero); diff --git a/tools/converter/source/optimizer/onnxextra/OnnxDeQuantizeLinear.cpp b/tools/converter/source/optimizer/onnxextra/OnnxDeQuantizeLinear.cpp index b3c241104..6927e246d 100644 --- a/tools/converter/source/optimizer/onnxextra/OnnxDeQuantizeLinear.cpp +++ b/tools/converter/source/optimizer/onnxextra/OnnxDeQuantizeLinear.cpp @@ -30,28 +30,55 @@ class OnnxDequantizeLinearTransform : public OnnxExtraManager::Transform { MNN_ERROR("QuantizeLinear should provide scale and input\n"); return nullptr; } - VARP zeropoint = nullptr; + + uint8_t dataType = halide_type_int; + VARP zeropoint = _Const(0.f); if (inputs.size() > 2) { - zeropoint = inputs[2]; + if (inputs[2]->getInfo() == nullptr) { + MNN_ERROR("DequantizeLinear layer inputs.size>2, but zeroPoint is not const\n"); + } + MNN_ASSERT(inputs[2]->getInfo() != nullptr); + auto zeroDim = inputs[2]->getInfo()->dim; + dataType = inputs[2]->getInfo()->type.code; + std::vector fp32Zero(inputs[2]->getInfo()->size); + if (dataType == halide_type_int) { + const int8_t* zeroPtr = inputs[2]->readMap(); + for (int j = 0; j < fp32Zero.size(); ++j) { + fp32Zero[j] = static_cast(zeroPtr[j]); + } + zeropoint = _Const(fp32Zero.data(), zeroDim, inputs[2]->getInfo()->order, halide_type_of()); + } else { + const uint8_t* zeroPtr = inputs[2]->readMap(); + for (int j = 0; j < fp32Zero.size(); ++j) { + fp32Zero[j] = static_cast(zeroPtr[j]) - 128.f; + } + zeropoint = _Const(fp32Zero.data(), zeroDim, inputs[2]->getInfo()->order, halide_type_of()); + } + zeropoint = _Cast(inputs[2]); } std::vector inputDim = {}; if (input->getInfo()) { inputDim = input->getInfo()->dim; + dataType = input->getInfo()->type.code; } - if (!scale->getInfo()->dim.empty()) { - zeropoint = _Unsqueeze(zeropoint, {1,2,3}); - scale = _Unsqueeze(scale, {1, 2, 3}); - } else { - scale = _Reshape(scale, {1}); - zeropoint = _Reshape(zeropoint, {1}); + auto offset = _Const(0.f); + if (dataType == halide_type_uint) { + offset = _Const(128.f); } + // if (!scale->getInfo()->dim.empty()) { + // zeropoint = _Unsqueeze(zeropoint, {1,2,3}); + // scale = _Unsqueeze(scale, {1, 2, 3}); + // } else { + // scale = _Reshape(scale, {1}); + // zeropoint = _Reshape(zeropoint, {1}); + // } auto _shape = _Const(inputDim.data(), {static_cast(inputDim.size())}, NHWC, halide_type_of()); - auto output = (_Cast(input) - _Cast(zeropoint)) * scale; + auto output = (_Cast(input) - zeropoint) * scale; std::unique_ptr iden(new MNN::OpT); iden->type = OpType_Int8ToFloat; - auto newExpr = MNN::Express::Expr::create(iden.get(), {input, output, scale, _Cast(zeropoint), _shape}, 5); + auto newExpr = MNN::Express::Expr::create(iden.get(), {input, output, scale, zeropoint - offset, _shape}, 5); newExpr->setName(expr->name()); return newExpr; } diff --git a/tools/converter/source/optimizer/onnxextra/OnnxGemm.cpp b/tools/converter/source/optimizer/onnxextra/OnnxGemm.cpp index 70490e32f..f5a96cb27 100644 --- a/tools/converter/source/optimizer/onnxextra/OnnxGemm.cpp +++ b/tools/converter/source/optimizer/onnxextra/OnnxGemm.cpp @@ -70,12 +70,19 @@ class OnnxGemmTransform : public OnnxExtraManager::Transform { // output quant info auto outputExpr = expr->outputs().front().lock(); auto outputScaleVar = outputExpr->inputs()[1]; - auto outputZero = outputExpr->inputs()[2]; + auto outputZero = _Const(0.f); + if (outputExpr->inputs().size() > 2 && outputExpr->inputs()[2]->getInfo()) { + if (outputExpr->inputs()[2]->getInfo()->type.code == halide_type_int) { + outputZero = _Cast(outputExpr->inputs()[2]); + } else { + 
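For uint8 DequantizeLinear inputs the pass now shifts both the data and the zero point into int8 range by subtracting 128 (the `offset` constant) instead of carrying the uint8 zero point through. The dequantized values are unchanged because the two shifts cancel; the small check below exercises that identity for every uint8 value:

```cpp
#include <cassert>
#include <iostream>

// ONNX uint8 DequantizeLinear: x_f = (x_u8 - zero_u8) * scale. Shifting both the
// data and the zero point by -128 (into int8 range) leaves x_f untouched because
// the shifts cancel inside the subtraction.
int main() {
    const float scale = 0.04f;
    const int zeroU8 = 140;                       // uint8 zero point from the model
    for (int x = 0; x <= 255; ++x) {
        float onnxValue = (x - zeroU8) * scale;   // uint8 convention
        int xI8    = x - 128;                     // data as the int8 kernels see it
        int zeroI8 = zeroU8 - 128;                // "zeropoint - offset" in the hunk
        float mnnValue = (xI8 - zeroI8) * scale;  // int8 convention
        assert(onnxValue == mnnValue);
    }
    std::cout << "uint8 -> int8 zero-point shift preserves dequantized values\n";
    return 0;
}
```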
outputZero = _Cast(outputExpr->inputs()[2]) - _Const(128.f); + } + } Z = _MatMul_Int8(X, y_int8, transA, transB, x_scale, x_zero, y_scale, y_zero, outputScaleVar, outputZero); if (inputs.size() > 2) { auto bias_expr = inputs[2]->expr().first; - auto bias_int32 = bias_expr->inputs().at(0); + auto bias_int32 = bias_expr->inputs().at(1); Z = _MatMul_Int8(X, y_int8, transA, transB, x_scale, x_zero, y_scale, y_zero, outputScaleVar, outputZero, bias_int32); } Z->setName(expr->name()); diff --git a/tools/converter/source/optimizer/onnxextra/OnnxQuantizeLinear.cpp b/tools/converter/source/optimizer/onnxextra/OnnxQuantizeLinear.cpp index 5f9fbe7ce..c94cfee75 100644 --- a/tools/converter/source/optimizer/onnxextra/OnnxQuantizeLinear.cpp +++ b/tools/converter/source/optimizer/onnxextra/OnnxQuantizeLinear.cpp @@ -31,13 +31,19 @@ class OnnxQuantizeLinearTransform : public OnnxExtraManager::Transform { MNN_ERROR("QuantizeLinear should provide scale and input\n"); return nullptr; } + uint8_t dataType = halide_type_int; VARP zeropoint = _Const(0.f); + auto offset = _Const(0.f); if (inputs.size() > 2) { zeropoint = _Cast(inputs[2]); + dataType = inputs[2]->getInfo()->type.code; + } + if (dataType == halide_type_uint) { + offset = _Const(128.f); } auto scaleReq = _Reciprocal(scale); // auto output = _Cast(_Round(_Relu6(_Round(input * scaleReq) + zeropoint, -128.0f, 127.0f))); - auto output = _FloatToInt8(input, scaleReq, -128, 127, static_cast(zeropoint->readMap()[0])); + auto output = _FloatToInt8(input, scaleReq, -128, 127, static_cast(zeropoint->readMap()[0] - offset->readMap()[0])); std::unique_ptr iden(new MNN::OpT); iden->type = OpType_FloatToInt8; std::vector inputDim = {}; @@ -46,7 +52,7 @@ class OnnxQuantizeLinearTransform : public OnnxExtraManager::Transform { inputDim = input->getInfo()->dim; } auto _shape = _Const(inputDim.data(), {static_cast(inputDim.size())}, NHWC, halide_type_of()); - auto newExpr = MNN::Express::Expr::create(iden.get(), {input, output, scale, zeropoint, _shape}, 5); + auto newExpr = MNN::Express::Expr::create(iden.get(), {input, output, scale, zeropoint - offset, _shape}, 5); newExpr->setName(expr->name()); return newExpr; } diff --git a/tools/converter/source/optimizer/postconvert/TransformGroupConvolution.cpp b/tools/converter/source/optimizer/postconvert/TransformGroupConvolution.cpp index 1c5498e26..95e8d6389 100644 --- a/tools/converter/source/optimizer/postconvert/TransformGroupConvolution.cpp +++ b/tools/converter/source/optimizer/postconvert/TransformGroupConvolution.cpp @@ -182,7 +182,7 @@ class TransformGroupConvolution : public PostConverter { auto& common = conv2D->common; const int srcCount = common->inputCount; const bool depthwiseLike = srcCount % common->group != 0 || common->outputCount % common->group != 0; - if (common->group == 1 || op->inputIndexes.size() > 1 || depthwiseLike) { + if (common->group == 1 || depthwiseLike) { iter++; continue; } @@ -212,7 +212,7 @@ class TransformGroupConvolution : public PostConverter { MNN::OpT* sliceOp = new MNN::OpT; sliceOp->type = MNN::OpType_Slice; sliceOp->name = op->name + "_____slice"; - sliceOp->inputIndexes = op->inputIndexes; + sliceOp->inputIndexes = {op->inputIndexes[0]}; sliceOp->outputIndexes = newConvolutionInputIndex; auto sliceT = new MNN::SliceT; sliceOp->main.type = MNN::OpParameter_Slice; @@ -224,38 +224,121 @@ class TransformGroupConvolution : public PostConverter { newOp.push_back(sliceOp); } - int partWeightSize = conv2D->weight.size() / common->group; - int partBiasSize = conv2D->bias.size() / 
common->group; + if(op->inputIndexes.size() > 1){ + std::vector newConvolutionWeightInputIndex; + std::vector newConvolutionBiasInputIndex; + // splice weight + { + for (int i = 0; i < common->group; ++i) { + std::ostringstream newTensorNameOs; + newTensorNameOs << op->name << "___input___weight___" << i; + newConvolutionWeightInputIndex.push_back(mNet->tensorName.size()); + mNet->tensorName.push_back(newTensorNameOs.str()); + } - // Create Sub Convolution - flatbuffers::FlatBufferBuilder tmpBuilder; - tmpBuilder.Finish(Convolution2DCommon::Pack(tmpBuilder, common.get())); - auto originCommon = flatbuffers::GetRoot(tmpBuilder.GetBufferPointer()); - for (int i = 0; i < common->group; ++i) { - std::ostringstream opNameOs; - auto newConvOp = new MNN::OpT; - opNameOs << op->name << "__group__" << i; - newConvOp->type = op->type; - newConvOp->name = opNameOs.str(); - newConvOp->main.type = MNN::OpParameter_Convolution2D; - newConvOp->inputIndexes.push_back(newConvolutionInputIndex[i]); - newConvOp->outputIndexes.push_back(newConvolutionOutputIndex[i]); + // Create slice op for weight + { + MNN::OpT* sliceOp = new MNN::OpT; + sliceOp->type = MNN::OpType_Slice; + sliceOp->name = op->name + "_____weight_____slice"; + sliceOp->inputIndexes = {op->inputIndexes[1]}; + sliceOp->outputIndexes = newConvolutionWeightInputIndex; + auto sliceT = new MNN::SliceT; + sliceOp->main.type = MNN::OpParameter_Slice; + sliceOp->main.value = sliceT; + sliceT->axis = 0; + for (int i = 0; i < common->group - 1; ++i) { + sliceT->slicePoints.push_back(common->outputCount / (common->group) * (i + 1)); + } + newOp.push_back(sliceOp); + } + } + // slice bias + if(op->inputIndexes.size() == 3){ + for (int i = 0; i < common->group; ++i) { + std::ostringstream newTensorNameOs; + newTensorNameOs << op->name << "___input___bias___" << i; + newConvolutionBiasInputIndex.push_back(mNet->tensorName.size()); + mNet->tensorName.push_back(newTensorNameOs.str()); + } - auto newConvolutionT = new MNN::Convolution2DT; - newConvOp->main.value = newConvolutionT; - newConvolutionT->common = std::unique_ptr(originCommon->UnPack()); - newConvolutionT->common->group = 1; - newConvolutionT->common->outputCount = common->outputCount / common->group; - newConvolutionT->common->inputCount = common->inputCount / common->group; - int startWeight = partWeightSize * i; - int startBias = partBiasSize * i; - for (int v = 0; v < partWeightSize; ++v) { - newConvolutionT->weight.push_back(conv2D->weight[startWeight + v]); + // Create slice op for bias + { + MNN::OpT* sliceOp = new MNN::OpT; + sliceOp->type = MNN::OpType_Slice; + sliceOp->name = op->name + "_____bias_____slice"; + sliceOp->inputIndexes = {op->inputIndexes[2]}; + sliceOp->outputIndexes = newConvolutionBiasInputIndex; + auto sliceT = new MNN::SliceT; + sliceOp->main.type = MNN::OpParameter_Slice; + sliceOp->main.value = sliceT; + sliceT->axis = 0; + for (int i = 0; i < common->group - 1; ++i) { + sliceT->slicePoints.push_back(common->outputCount / (common->group) * (i + 1)); + } + newOp.push_back(sliceOp); + } } - for (int v = 0; v < partBiasSize; ++v) { - newConvolutionT->bias.push_back(conv2D->bias[startBias + v]); + // Create Sub Convolution + flatbuffers::FlatBufferBuilder tmpBuilder; + tmpBuilder.Finish(Convolution2DCommon::Pack(tmpBuilder, common.get())); + auto originCommon = flatbuffers::GetRoot(tmpBuilder.GetBufferPointer()); + for (int i = 0; i < common->group; ++i) { + std::ostringstream opNameOs; + auto newConvOp = new MNN::OpT; + opNameOs << op->name << "__group__" << i; + 
newConvOp->type = op->type; + newConvOp->name = opNameOs.str(); + newConvOp->main.type = MNN::OpParameter_Convolution2D; + newConvOp->inputIndexes.push_back(newConvolutionInputIndex[i]); + newConvOp->inputIndexes.push_back(newConvolutionWeightInputIndex[i]); + if(op->inputIndexes.size() == 3){ + newConvOp->inputIndexes.push_back(newConvolutionBiasInputIndex[i]); + } + newConvOp->outputIndexes.push_back(newConvolutionOutputIndex[i]); + + auto newConvolutionT = new MNN::Convolution2DT; + newConvOp->main.value = newConvolutionT; + newConvolutionT->common = std::unique_ptr(originCommon->UnPack()); + newConvolutionT->common->group = 1; + newConvolutionT->common->outputCount = common->outputCount / common->group; + newConvolutionT->common->inputCount = common->inputCount / common->group; + newOp.push_back(newConvOp); + } + }else{ + int partWeightSize = conv2D->weight.size() / common->group; + int partBiasSize = conv2D->bias.size() / common->group; + + // Create Sub Convolution + flatbuffers::FlatBufferBuilder tmpBuilder; + tmpBuilder.Finish(Convolution2DCommon::Pack(tmpBuilder, common.get())); + auto originCommon = flatbuffers::GetRoot(tmpBuilder.GetBufferPointer()); + for (int i = 0; i < common->group; ++i) { + std::ostringstream opNameOs; + auto newConvOp = new MNN::OpT; + opNameOs << op->name << "__group__" << i; + newConvOp->type = op->type; + newConvOp->name = opNameOs.str(); + newConvOp->main.type = MNN::OpParameter_Convolution2D; + newConvOp->inputIndexes.push_back(newConvolutionInputIndex[i]); + newConvOp->outputIndexes.push_back(newConvolutionOutputIndex[i]); + + auto newConvolutionT = new MNN::Convolution2DT; + newConvOp->main.value = newConvolutionT; + newConvolutionT->common = std::unique_ptr(originCommon->UnPack()); + newConvolutionT->common->group = 1; + newConvolutionT->common->outputCount = common->outputCount / common->group; + newConvolutionT->common->inputCount = common->inputCount / common->group; + int startWeight = partWeightSize * i; + int startBias = partBiasSize * i; + for (int v = 0; v < partWeightSize; ++v) { + newConvolutionT->weight.push_back(conv2D->weight[startWeight + v]); + } + for (int v = 0; v < partBiasSize; ++v) { + newConvolutionT->bias.push_back(conv2D->bias[startBias + v]); + } + newOp.push_back(newConvOp); } - newOp.push_back(newConvOp); } // Set this op be Concat Op diff --git a/tools/cpp/ModuleBasic.cpp b/tools/cpp/ModuleBasic.cpp index c9b4e93ad..04954e7a4 100644 --- a/tools/cpp/ModuleBasic.cpp +++ b/tools/cpp/ModuleBasic.cpp @@ -15,6 +15,9 @@ #include "rapidjson/document.h" #include "core/MemoryFormater.h" #include +#include +#include +#include #include "ExprDebug.hpp" using namespace MNN::Express; @@ -127,6 +130,9 @@ int main(int argc, char *argv[]) { } checkOutput = outputs.size() > 0; } + // Call Time / Per Second + float freq = 0.0f; + int cpuDecreaseRate = -1; if (inputNames.empty()) { rapidjson::Document document; std::ostringstream jsonNameOs; @@ -176,6 +182,12 @@ int main(int argc, char *argv[]) { if (document.HasMember("repeat")) { repeatNumber = document["repeat"].GetInt(); } + if (document.HasMember("freq")) { + freq = document["freq"].GetFloat(); + } + if (document.HasMember("cpu_decrease_rate")) { + cpuDecreaseRate = document["cpu_decrease_rate"].GetInt(); + } } auto type = MNN_FORWARD_CPU; if (argc > 4) { @@ -189,12 +201,14 @@ int main(int argc, char *argv[]) { modeNum = ::atoi(argv[6]); } + int power = BackendConfig::Power_Normal; int precision = BackendConfig::Precision_Normal; int memory = BackendConfig::Memory_Normal; if (argc > 7) { 
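TransformGroupConvolution now also handles grouped convolutions whose weight (and optionally bias) arrive as extra inputs: it emits Slice ops that cut those tensors along axis 0 at multiples of outputCount / group and wires each slice into a group=1 sub-convolution. The slice-point arithmetic, in isolation:

```cpp
#include <iostream>
#include <vector>

// Slice points the grouped-conv rewrite emits: for `group` sub-convolutions the
// weight (and optional bias) tensors are split along axis 0 at multiples of
// outputCount / group, and each sub-conv is rewritten with group = 1.
std::vector<int> slicePoints(int channels, int group) {
    std::vector<int> points;
    for (int i = 0; i < group - 1; ++i) {
        points.push_back(channels / group * (i + 1));
    }
    return points;
}

int main() {
    const int outputCount = 12, inputCount = 6, group = 3;
    auto wp = slicePoints(outputCount, group);        // weight/bias split: {4, 8}
    std::cout << "weight slice points: ";
    for (int p : wp) {
        std::cout << p << " ";
    }
    std::cout << "\nper-group conv: group=1"
              << " outputCount=" << outputCount / group
              << " inputCount=" << inputCount / group << "\n";
    return 0;
}
```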
int mask = atoi(argv[7]); precision = mask % 4; memory = (mask / 4) % 4; + power = (mask / 16) % 4; } const char* cacheFileName = ".tempcache"; if (argc > 8) { @@ -202,6 +216,7 @@ int main(int argc, char *argv[]) { } FUNC_PRINT(precision); FUNC_PRINT(memory); + FUNC_PRINT(power); FUNC_PRINT_ALL(cacheFileName, s); // create session MNN::ScheduleConfig config; @@ -212,7 +227,7 @@ int main(int argc, char *argv[]) { config.backupType = type; BackendConfig backendConfig; // config.path.outputs.push_back("ResizeBilinear_2"); - // backendConfig.power = BackendConfig::Power_High; + backendConfig.power = (BackendConfig::PowerMode)power; backendConfig.precision = static_cast(precision); backendConfig.memory = static_cast(memory); config.backendConfig = &backendConfig; @@ -224,6 +239,9 @@ int main(int argc, char *argv[]) { mConfig.shapeMutable = shapeMutable; std::shared_ptr rtmgr(Executor::RuntimeManager::createRuntimeManager(config)); rtmgr->setCache(cacheFileName); + if (cpuDecreaseRate > 0 && cpuDecreaseRate <= 100) { + rtmgr->setHint(Interpreter::CPU_LITTLECORE_DECREASE_RATE, cpuDecreaseRate); + } if (runMask & 1) { // Need dump tensor, open debug rtmgr->setMode(Interpreter::Session_Debug); @@ -256,6 +274,9 @@ int main(int argc, char *argv[]) { if (runMask & 512) { rtmgr->setHint(Interpreter::WINOGRAD_MEMORY_LEVEL, 0); } + if (runMask & 1024) { + rtmgr->setHint(Interpreter::DYNAMIC_QUANT_OPTIONS, 1); + } std::shared_ptr net; { AUTOTIME; @@ -402,6 +423,12 @@ int main(int argc, char *argv[]) { ((MNN::Tensor*)o->getTensor())->wait(MNN::Tensor::MAP_TENSOR_READ, true); } times[i] = _l.durationInUs() / 1000.0f; + if (freq > 0.0f) { + float remainMs = (1000.0f / freq) - times[i]; + if (remainMs > 0.0f) { + std::this_thread::sleep_for(std::chrono::milliseconds((int)remainMs)); + } + } } if (nullptr != gTimeTraceInfo) { float opSummer = 0.0f; diff --git a/tools/cpp/checkFile.cpp b/tools/cpp/checkFile.cpp index ec7de403f..afa6c2a10 100644 --- a/tools/cpp/checkFile.cpp +++ b/tools/cpp/checkFile.cpp @@ -11,6 +11,7 @@ #include #include #include +#include using namespace std; @@ -25,7 +26,8 @@ int main(int argc, char* argv[]) { const char* file2 = argv[2]; float tolerance = 0.001; if (argc > 3) { - tolerance = atof(argv[3]); + std::istringstream ss(argv[3]); + ss >> tolerance; } // open file diff --git a/tools/cpp/getPerformance.cpp b/tools/cpp/getPerformance.cpp index d56c3c0bb..ff3b5dfc4 100644 --- a/tools/cpp/getPerformance.cpp +++ b/tools/cpp/getPerformance.cpp @@ -10,6 +10,8 @@ #include #include #include +#include +#include #include #include #include "core/Macro.h" @@ -201,10 +203,38 @@ void cpuFLOPSPerformance() { MNN_PRINT("CPU float gflops : %f\n", gflops); } +static void _testMemcpy() { + int size = 1024 * 1024; + int loop = 10000; + std::vector threads; + MNN::Timer _t; + for (int i=0; i<2; ++i) { + threads.emplace_back(std::thread([size, loop]() { + std::vector tmp0(size); + std::vector tmp1(size); + auto t0 = tmp0.data(); + auto t1 = tmp1.data(); + for (int i=0; i=0: transformerNamse.append(funcName) + elif end == '__IMAGE__': + opNamesImage.append(funcName) else: - opNames.append(funcName) + opNamesBuffer.append(funcName) + bufferFileNames = os.listdir(openclBufferDir) print(bufferFileNames) collectFile(bufferFileNames, openclBufferDir) @@ -194,7 +198,11 @@ def collectFile(fileNames, dirname): f.write('#ifndef MNN_OPENCL_SEP_BUILD\n') f.write('namespace MNN {\n') f.write('namespace OpenCL {\n') - for l in opNames: + f.write('#ifndef ' + 'MNN_OPENCL_BUFFER_CLOSED' + '\n') + for l in opNamesBuffer: 
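ModuleBasic gains an optional "freq" field in the json config: when set, each timed iteration sleeps away whatever is left of its 1000 / freq millisecond budget, capping the call rate. A self-contained version of that loop, with a sleep standing in for the actual forward pass:

```cpp
#include <chrono>
#include <iostream>
#include <thread>

// Rate-limited benchmark loop: with "freq" set to the target calls per second,
// each iteration sleeps for the remainder of its 1000 / freq millisecond budget
// after the (simulated) inference finishes.
int main() {
    const float freq = 20.0f;                     // 20 calls per second -> 50 ms budget
    for (int i = 0; i < 5; ++i) {
        auto begin = std::chrono::steady_clock::now();

        std::this_thread::sleep_for(std::chrono::milliseconds(12));  // stand-in for one forward pass

        auto end = std::chrono::steady_clock::now();
        float elapsedMs = std::chrono::duration<float, std::milli>(end - begin).count();
        float remainMs = (1000.0f / freq) - elapsedMs;
        if (remainMs > 0.0f) {
            std::this_thread::sleep_for(std::chrono::milliseconds((int)remainMs));
        }
        std::cout << "iteration " << i << ": ran " << elapsedMs << " ms\n";
    }
    return 0;
}
```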
+ f.write("extern void " + l + '();\n') + f.write('#endif\n') + for l in opNamesImage: f.write("extern void " + l + '();\n') f.write('\n') f.write('#ifdef ' + 'MNN_SUPPORT_TRANSFORMER_FUSE' + '\n') @@ -202,8 +210,13 @@ def collectFile(fileNames, dirname): f.write("extern void " + l + '();\n') f.write('#endif\n') f.write('void registerOpenCLOps() {\n') - for l in opNames: - f.write(l+'();\n') + f.write('#ifndef ' + 'MNN_OPENCL_BUFFER_CLOSED' + '\n') + for l in opNamesBuffer: + f.write(l + '();\n') + f.write('#endif\n') + for l in opNamesImage: + f.write(l + '();\n') + f.write('\n') f.write('#ifdef ' + 'MNN_SUPPORT_TRANSFORMER_FUSE' + '\n') for l in transformerNamse: f.write(l+'();\n') diff --git a/tools/script/testMNNFromOnnx.py b/tools/script/testMNNFromOnnx.py index 01cf38847..8b032c86a 100644 --- a/tools/script/testMNNFromOnnx.py +++ b/tools/script/testMNNFromOnnx.py @@ -122,7 +122,7 @@ def __run_mnn(self): if not os.path.exists(mnnconvert_name): print("./MNNConvert not exist in this path. Use pymnn instead of C++ to test") mnnconvert_name = 'mnnconvert' - convert = mnnconvert_name + ' -f ONNX --bizCode MNN --modelFile onnx/test.onnx --MNNModel convert_cache.mnn --keepInputFormat --testdir onnx' + convert = mnnconvert_name + ' -f ONNX --bizCode MNN --modelFile onnx/test.onnx --MNNModel convert_cache.mnn --keepInputFormat=1 --testdir onnx' result = os.popen(convert).read() print(result) return result diff --git a/tools/script/testMNNFromTf.py b/tools/script/testMNNFromTf.py index e267681d7..c30a22116 100644 --- a/tools/script/testMNNFromTf.py +++ b/tools/script/testMNNFromTf.py @@ -32,7 +32,7 @@ def __run_mnn(self): if not os.path.exists(mnnconvert_name): print("./MNNConvert not exist in this path. Use pymnn instead of C++ to test") mnnconvert_name = 'mnnconvert' - convert = mnnconvert_name + ' -f TF --bizCode MNN --modelFile tf/test.pb --MNNModel convert_cache.mnn --keepInputFormat --testdir tf' + convert = mnnconvert_name + ' -f TF --bizCode MNN --modelFile tf/test.pb --MNNModel convert_cache.mnn --keepInputFormat=1 --testdir tf' result = os.popen(convert).read() print(result) return result diff --git a/tools/script/testMNNFromTflite.py b/tools/script/testMNNFromTflite.py index bfa24c65d..dd9ffef71 100644 --- a/tools/script/testMNNFromTflite.py +++ b/tools/script/testMNNFromTflite.py @@ -54,7 +54,7 @@ def __run_mnn(self): if not os.path.exists(mnnconvert_name): print("./MNNConvert not exist in this path. Use pymnn instead of C++ to test") mnnconvert_name = 'mnnconvert' - convert = mnnconvert_name + ' -f TFLITE --bizCode MNN --modelFile tflite/test.tflite --MNNModel convert_cache.mnn --keepInputFormat --testdir tflite' + convert = mnnconvert_name + ' -f TFLITE --bizCode MNN --modelFile tflite/test.tflite --MNNModel convert_cache.mnn --keepInputFormat=1 --testdir tflite' result = os.popen(convert).read() print(result) return result diff --git a/tools/script/testMNNFromTorch.py b/tools/script/testMNNFromTorch.py index 385598dce..173db53c2 100644 --- a/tools/script/testMNNFromTorch.py +++ b/tools/script/testMNNFromTorch.py @@ -31,7 +31,7 @@ def __run_mnn(self): if not os.path.exists(mnnconvert_name): print("./MNNConvert not exist in this path. 
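The opencl_codegen.py changes split the collected kernels into buffer and image groups so the generated registration file can exclude the buffer path under MNN_OPENCL_BUFFER_CLOSED. The snippet below sketches the shape of the file the updated script writes; the kernel registration function names are made up, the real ones come from the .cl source file names:

```cpp
// Shape of the generated registration file (illustrative function names only).
#ifndef MNN_OPENCL_SEP_BUILD
namespace MNN {
namespace OpenCL {
#ifndef MNN_OPENCL_BUFFER_CLOSED
extern void ConvBuf();            // buffer kernels: registered only when the buffer path is built
extern void PoolBuf();
#endif
extern void Conv();               // image kernels stay unconditional
extern void Pool();

#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
extern void AttentionBuf();
#endif
void registerOpenCLOps() {
#ifndef MNN_OPENCL_BUFFER_CLOSED
    ConvBuf();
    PoolBuf();
#endif
    Conv();
    Pool();

#ifdef MNN_SUPPORT_TRANSFORMER_FUSE
    AttentionBuf();
#endif
}
} // namespace OpenCL
} // namespace MNN
#endif
```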
Use pymnn instead of C++ to test") mnnconvert_name = 'mnnconvert' - convert = mnnconvert_name + ' -f TORCH --bizCode MNN --modelFile torch/test.pt --MNNModel convert_cache.mnn --keepInputFormat --testdir torch' + convert = mnnconvert_name + ' -f TORCH --bizCode MNN --modelFile torch/test.pt --MNNModel convert_cache.mnn --keepInputFormat=1 --testdir torch' result = os.popen(convert).read() print(result) return result diff --git a/tools/train/source/demo/demoMain.cpp b/tools/train/source/demo/demoMain.cpp index 1417e5b6b..30c844e75 100644 --- a/tools/train/source/demo/demoMain.cpp +++ b/tools/train/source/demo/demoMain.cpp @@ -10,7 +10,7 @@ #include "DemoUnit.hpp" #include int main(int argc, const char* argv[]) { - ExecutorScope::Current()->setLazyComputeMode(MNN::Express::Executor::LAZY_CONTENT); +// ExecutorScope::Current()->setLazyComputeMode(MNN::Express::Executor::LAZY_CONTENT); if (argc < 2) { MNN_ERROR("Usage: ./runTrainDemo.out CASENAME [ARGS]\n"); auto& list = DemoUnitSet::get()->list(); diff --git a/tools/train/source/nn/NN.cpp b/tools/train/source/nn/NN.cpp index 7594e7197..a49c6afaf 100644 --- a/tools/train/source/nn/NN.cpp +++ b/tools/train/source/nn/NN.cpp @@ -6,6 +6,7 @@ // Copyright © 2018, Alibaba Group Holding Limited // +#include #include "NN.hpp" #include "Distributions.hpp" #include "module/PipelineModule.hpp" @@ -397,6 +398,11 @@ Module* NN::Linear(int l, int t, bool hasBias, std::shared_ptr weig } auto weight = weightInit->createConstVar({t, l}, NCHW); weight.fix(VARP::TRAINABLE); + // Save lazy mode + auto lazyEval = ExecutorScope::Current()->lazyEval; + auto lazyMode = ExecutorScope::Current()->getLazyMode(); + ExecutorScope::Current()->lazyEval = true; + ExecutorScope::Current()->setLazyComputeMode(Executor::LAZY_FULL); auto input = _Input({l}, NCHW); auto output = _MatMul(input, weight, false, true); if (!hasBias) { @@ -407,6 +413,10 @@ Module* NN::Linear(int l, int t, bool hasBias, std::shared_ptr weig output = _Add(output, bias); auto module = NN::extract({input}, {output}, true); module->setType("Linear"); + // Revert lazy mode + ExecutorScope::Current()->lazyEval = lazyEval; + ExecutorScope::Current()->setLazyComputeMode(lazyMode); + return module; } diff --git a/transformers/diffusion/README.md b/transformers/diffusion/README.md deleted file mode 100644 index 42d247ca8..000000000 --- a/transformers/diffusion/README.md +++ /dev/null @@ -1,45 +0,0 @@ -# Diffusion使用方法 - -## 模型支持与下载 - -[Download-runwayml/stable-diffusion-v1-5]: -https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main -[Download-IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1]: -https://huggingface.co/IDEA-CCNL/Taiyi-Stable-Diffusion-1B-Chinese-v0.1/tree/main - -## 模型转换 -### 将Huggingface的Stable Diffusion模型 转为onnx模型 -python export/onnx_export.py \ - --model_path hf_sd_load_path \ - --output_path onnx_save_path - -### 将onnx模型转为mnn模型 -新建diffusion mnn模型文件夹,将转好的mnn文件放在该文件夹下。 -./MNNConvert -f ONNX --modelFile onnx_save_path/text_encoder/model.onnx --MNNModel mnn_save_path/text_encoder.mnn --weightQuantBits 8 --bizCode biz -./MNNConvert -f ONNX --modelFile onnx_save_path/unet/model.onnx --MNNModel mnn_save_path/unet.mnn --transformerFuse --weightQuantBits 8 --bizCode biz -./MNNConvert -f ONNX --modelFile onnx_save_path/vae_decoder/model.onnx --keepInputFormat --MNNModel mnn_save_path/vae_decoder.mnn --weightQuantBits 8 --bizCode biz - -## 编译Diffusion Demo -### Linux/MAC/Windows上 -cmake .. 
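NN::Linear now saves the executor's lazyEval flag and lazy mode, forces full lazy evaluation while it wires up the MatMul/Add graph, and restores the caller's settings afterwards. The sketch below expresses the same save-and-restore idea as an RAII guard over a stand-in global state; the real code calls ExecutorScope::Current() directly and restores manually:

```cpp
#include <iostream>

// Stand-in for the executor's process-wide evaluation settings.
struct EvalState {
    bool lazyEval;
    int  lazyMode;    // stand-in for Executor::LAZY_COMPUTE / LAZY_FULL
};

EvalState gState = {false, 0};

// RAII guard: switch to the build-time settings now, restore the caller's on exit.
struct LazyModeGuard {
    EvalState saved;
    LazyModeGuard(bool lazyEval, int lazyMode) : saved(gState) {
        gState.lazyEval = lazyEval;
        gState.lazyMode = lazyMode;
    }
    ~LazyModeGuard() { gState = saved; }
};

int main() {
    {
        LazyModeGuard guard(/*lazyEval=*/true, /*lazyMode=*/1);   // LAZY_FULL analog
        std::cout << "building graph: lazyEval=" << gState.lazyEval
                  << " mode=" << gState.lazyMode << "\n";
    }
    std::cout << "after build:    lazyEval=" << gState.lazyEval
              << " mode=" << gState.lazyMode << "\n";
    return 0;
}
```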
-DMNN_BUILD_DIFFUSION=ON -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_OPENCL=ON -DMNN_SEP_BUILD=OFF -DMNN_SUPPORT_TRANSFORMER_FUSE=ON - -### Android上 -cd project/android/build -../build_64.sh -DMNN_BUILD_DIFFUSION=ON -DMNN_BUILD_OPENCV=ON -DMNN_IMGCODECS=ON -DMNN_OPENCL=ON -DMNN_SEP_BUILD=OFF -DMNN_SUPPORT_TRANSFORMER_FUSE=ON - -## 运行Diffusion Demo -./diffusion_demo -其中,resource_path 就是mnn模型文件的路径,除了mnn文件,还需要 -(1)将MNN目录transformers/diffusion/scheduler/alphas.txt文件拷贝到该文件夹下。 -(2)针对stable-diffusion-v1-5模型需要将huggingfacetokenizer目录下merges.txt和vocab.json拷贝到该文件夹中。针对Taiyi-Stable-Diffusion模型需要将huggingfacetokenizer目录下vocab.txt拷贝到该文件夹中。 - -model_type是目前支持的两种diffusion模型的类别。如果是stable-diffusion-v1-5模型设为0,如果是Taiyi-Stable-Diffusion模型设为1。 - -output_image_name是生成图片的名字,默认图片位置在当前运行目录下。 - -input_text是文生图的prompt,如果是stable-diffusion-v1-5模型建议英文prompt,如果是Taiyi-Stable-Diffusion建议中文prompt。 - -运行指令例如: -./diffusion_demo mnn_save_path 0 demo.jpg "a cute cat" -./diffusion_demo mnn_save_path 1 demo.jpg "一只可爱的猫" - diff --git a/transformers/diffusion/env.yaml b/transformers/diffusion/env.yaml new file mode 100644 index 000000000..de4ae3699 --- /dev/null +++ b/transformers/diffusion/env.yaml @@ -0,0 +1,10 @@ +name: ldm +channels: + - pytorch + - defaults +dependencies: + - pytorch + - numpy + - diffusers + - onnx + - transformers diff --git a/transformers/diffusion/main.cpp b/transformers/diffusion/main.cpp index 4cbb03d1e..946175e34 100644 --- a/transformers/diffusion/main.cpp +++ b/transformers/diffusion/main.cpp @@ -3,7 +3,7 @@ int main(int argc, const char* argv[]) { if (argc < 3) { - printf("Usage: ./diffusion_demo \n"); + MNN_PRINT("Usage: ./diffusion_demo \n"); return 0; } @@ -19,16 +19,16 @@ int main(int argc, const char* argv[]) { } } - printf("model resource path: %s\n", resource_path); + MNN_PRINT("model resource path: %s\n", resource_path); if(model_type == diffusion::STABLE_DIFFUSION_1_5) { - printf("model resourc is stable diffusion 1.5\n"); + MNN_PRINT("model type is stable diffusion 1.5\n"); } else if (model_type == diffusion::STABLE_DIFFUSION_TAIYI_CHINESE) { - printf("model resourc is stable diffusion taiyi chinese version\n"); + MNN_PRINT("model type is stable diffusion taiyi chinese version\n"); } else { - printf("model type: %d not supported, please check\n", (int)model_type); + MNN_PRINT("model type: %d not supported, please check\n", (int)model_type); } - printf("output img_name: %s\n", img_name); - printf("input texts: %s\n", input_text.c_str()); + MNN_PRINT("output img_name: %s\n", img_name); + MNN_PRINT("input texts: %s\n", input_text.c_str()); diffusion::Pipeline pipeline(resource_path, model_type); diff --git a/transformers/diffusion/pipeline.cpp b/transformers/diffusion/pipeline.cpp index 32d794684..194ebebd2 100644 --- a/transformers/diffusion/pipeline.cpp +++ b/transformers/diffusion/pipeline.cpp @@ -43,10 +43,10 @@ static inline int64_t getTime() { void display_progress(int cur, int total){ putchar('\r'); - printf("["); + MNN_PRINT("["); for (int i = 0; i < cur; i++) putchar('#'); for (int i = 0; i < total - cur; i++) putchar('-'); - printf("]"); + MNN_PRINT("]"); fprintf(stdout, " [%3d%%]", cur * 100 / total); if (cur == total) putchar('\n'); fflush(stdout); @@ -113,8 +113,8 @@ bool Pipeline::load_modules(std::string modelPath) { mTimestepVar = _Input({1}, NCHW, halide_type_of()); mSampleVar = _Concat({mLatentVar, mLatentVar}, 0); - printf("Model loading and initilizing...\n"); - printf("First time initilizing may cost a few seconds to create cachefile, please wait ...\n"); + 
MNN_PRINT("Model loading and initilizing...\n"); + MNN_PRINT("First time initilizing may cost a few seconds to create cachefile, please wait ...\n"); VARP text_embeddings; mModules.resize(3); @@ -170,9 +170,9 @@ VARP Pipeline::text_encoder(const std::vector& ids) { #ifdef MNN_DUMP_DATA auto xx = output->readMap(); for(int i=0; i<10; i+=2) { - printf("%f %f ", xx[i], xx[i+mMaxTextLen*768]); + MNN_PRINT("%f %f ", xx[i], xx[i+mMaxTextLen*768]); } - printf("\n\n"); + MNN_PRINT("\n\n"); #endif return output; } @@ -276,12 +276,12 @@ VARP Pipeline::unet(VARP text_embeddings) { auto zz = text_embeddings->readMap(); for(int i=0; i<6; i+=2) { - printf("%f %f %f ", xx[i], yy[i], zz[i]); + MNN_PRINT("%f %f %f ", xx[i], yy[i], zz[i]); } for(int i=0; i<6; i+=2) { - printf("%f %f %f ", xx[16384+i], yy[16384+i], zz[mMaxTextLen*768+i]); + MNN_PRINT("%f %f %f ", xx[16384+i], yy[16384+i], zz[mMaxTextLen*768+i]); } - printf("\n\n"); + MNN_PRINT("\n\n"); #endif } mLatentVar.fix(VARP::CONSTANT); @@ -289,9 +289,9 @@ VARP Pipeline::unet(VARP text_embeddings) { #ifdef MNN_DUMP_DATA auto xx = mLatentVar->readMap(); for(int i=0; i<10; i+=2) { - printf("%f ", xx[i]); + MNN_PRINT("%f ", xx[i]); } - printf("\n\n"); + MNN_PRINT("\n\n"); #endif return mLatentVar; } @@ -307,9 +307,9 @@ VARP Pipeline::vae_decoder(VARP latent) { #ifdef MNN_DUMP_DATA auto xx = output->readMap(); for(int i=0; i<320; i+=32) { - printf("%f ", xx[i]); + MNN_PRINT("%f ", xx[i]); } - printf("\n\n"); + MNN_PRINT("\n\n"); #endif auto image = output; @@ -340,7 +340,7 @@ bool Pipeline::run(const std::string& sentence, const std::string& img_name) { auto image = vae_decoder(latent); bool res = imwrite(img_name, image); if (res) { - printf("SUCCESS! write to %s\n", img_name.c_str()); + MNN_PRINT("SUCCESS! write to %s\n", img_name.c_str()); } return true; } diff --git a/transformers/llm/engine/include/llm/llm.hpp b/transformers/llm/engine/include/llm/llm.hpp new file mode 100644 index 000000000..4e1f445df --- /dev/null +++ b/transformers/llm/engine/include/llm/llm.hpp @@ -0,0 +1,121 @@ +// +// llm.hpp +// +// Created by MNN on 2023/08/25. 
+// ZhaodeWang +// + +#ifndef LLM_hpp +#define LLM_hpp + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +namespace MNN { +namespace Transformer { +class Tokenizer; +class Pipeline; +class LlmConfig; + +// Llm start +// llm stream buffer with callback +class MNN_PUBLIC LlmStreamBuffer : public std::streambuf { +public: + using CallBack = std::function;; + LlmStreamBuffer(CallBack callback) : callback_(callback) {} + +protected: + virtual std::streamsize xsputn(const char* s, std::streamsize n) override { + if (callback_) { + callback_(s, n); + } + return n; + } + +private: + CallBack callback_ = nullptr; +}; +class MNN_PUBLIC Llm { + using PromptItem = std::pair; // +public: + Llm(std::shared_ptr config) : config_(config) {} + virtual ~Llm(); + static Llm* createLLM(const std::string& config_path); + void chat(); + void reset(); + void trace(bool start); + virtual void load(); + MNN::Express::VARP forward(const std::vector& input_ids); + int sample(MNN::Express::VARP logits, const std::vector& pre_ids); + std::string apply_prompt_template(const std::string& user_content) const; + std::string apply_chat_template(const std::vector& chat_prompts) const; + std::string response(const std::string& user_content, std::ostream* os = &std::cout, const char* end_with = nullptr); + std::string response(const std::vector& chat_prompts, std::ostream* os = &std::cout, const char* end_with = nullptr); + void generate_init(); + std::string generate(const std::vector& input_ids, std::ostream* os, const char* end_with); + std::vector generate(const std::vector& input_ids, int max_new_tokens = -1); + void print_speed(); + // config function + std::string dump_config(); + bool set_config(const std::string& content); + friend class Pipeline; +public: + // forward info + int prompt_len_ = 0; + int gen_seq_len_ = 0; + int all_seq_len_ = 0; + std::vector history_ids_; + // time + int64_t prefill_us_ = 0; + int64_t decode_us_ = 0; + bool is_single_ = true; +protected: + std::shared_ptr config_; + std::shared_ptr tokenizer_; + std::vector key_value_shape_ = {}; + std::vector past_key_values_; + MNN::Express::VARP inputs_embeds_, attention_mask_, position_ids_; + std::shared_ptr runtime_manager_; + std::vector> modules_; + std::vector> decode_modules_; + std::vector> prefill_modules_; + void init_runtime(); + std::string decode(int id); + bool is_stop(int token_id); + virtual std::vector tokenizer(const std::string& query); + virtual MNN::Express::VARP embedding(const std::vector& input_ids); + virtual MNN::Express::VARP gen_attention_mask(int seq_len); + virtual MNN::Express::VARP gen_position_ids(int seq_len); +}; + +// Embedding start +class Embedding : public Llm { +public: + Embedding(std::shared_ptr config); + static Embedding* createEmbedding(const std::string& config_path); + static float dist(MNN::Express::VARP var0, MNN::Express::VARP var1); + virtual void load() override; + MNN::Express::VARP embedding(const std::string& txt); + int dim() const; +private: + virtual std::vector tokenizer(const std::string& query) override; + virtual MNN::Express::VARP gen_attention_mask(int seq_len) override; + virtual MNN::Express::VARP gen_position_ids(int seq_len) override; +}; +// Embedding end +} +} + +#endif // LLM_hpp diff --git a/transformers/llm/engine/llm_demo.cpp b/transformers/llm/engine/llm_demo.cpp index 4edc7731d..416154f84 100644 --- a/transformers/llm/engine/llm_demo.cpp +++ b/transformers/llm/engine/llm_demo.cpp @@ -5,13 
+5,13 @@ // ZhaodeWang // -#include "llm.hpp" +#include "llm/llm.hpp" #define MNN_OPEN_TIME_TRACE #include #include #include #include - +using namespace MNN::Transformer; static void trace_prepare(Llm* llm) { MNN_PRINT("Prepare for resize opt Begin\n"); std::vector prompts = { diff --git a/transformers/llm/engine/src/llm.cpp b/transformers/llm/engine/src/llm.cpp index 3838d8538..4ed60d9c2 100644 --- a/transformers/llm/engine/src/llm.cpp +++ b/transformers/llm/engine/src/llm.cpp @@ -15,8 +15,9 @@ #include #include #include "cpp/ExprDebug.hpp" -#include "llm.hpp" +#include "llm/llm.hpp" #include "tokenizer.hpp" +#include "llmconfig.hpp" // 0: no debug, 1: test op time, 2: print tensor info #define DEBUG_MODE 0 @@ -24,6 +25,29 @@ #include "httplib.h" #include #endif +using namespace MNN::Express; +namespace MNN { +namespace Transformer { + +class Lvlm : public Llm { +public: + Lvlm(std::shared_ptr config) : Llm(config) { + img_size_ = config->llm_config_.value("img_size", img_size_); + imgpad_len_ = config->llm_config_.value("imgpad_len", imgpad_len_); + img_start_ = config->llm_config_.value("img_start", img_start_); + img_end_ = config->llm_config_.value("img_end", img_end_); + img_pad_ = config->llm_config_.value("img_pad", img_pad_); + } + ~Lvlm() { visual_module_.reset(); } + virtual void load() override; +private: + int img_size_ = 448, imgpad_len_ = 256, img_start_ = 151857, img_end_ = 151858, img_pad_ = 151859; + std::shared_ptr visual_module_; + MNN::Express::VARP visual_embedding(const std::vector& input_ids); + std::vector url_encode(const std::string& url); + virtual std::vector tokenizer(const std::string& query) override; + virtual MNN::Express::VARP embedding(const std::vector& input_ids) override; +}; // Llm start Llm* Llm::createLLM(const std::string& config_path) { @@ -48,6 +72,14 @@ static MNNForwardType backend_type_convert(const std::string& type_str) { return MNN_FORWARD_AUTO; } +std::string Llm::dump_config() { + return config_->config_.dump(); +} + +bool Llm::set_config(const std::string& content) { + return config_->config_.merge(content.c_str()); +} + void Llm::init_runtime() { ScheduleConfig config; BackendConfig cpuBackendConfig; @@ -64,6 +96,9 @@ void Llm::init_runtime() { runtime_manager_.reset(Executor::RuntimeManager::createRuntimeManager(config)); runtime_manager_->setHint(MNN::Interpreter::MEM_ALLOCATOR_TYPE, 0); + runtime_manager_->setHint(MNN::Interpreter::DYNAMIC_QUANT_OPTIONS, 1); // 1: per batch quant, 2: per tensor quant + runtime_manager_->setHint(MNN::Interpreter::KVCACHE_QUANT_OPTIONS, config_->quant_kv()); // 0: no quant, 1: quant key, 2: quant value, 3: quant kv + #if DEBUG_MODE==1 runtime_manager_->setMode(MNN::Interpreter::Session_Debug); _initTimeTrace(); @@ -104,8 +139,8 @@ void Llm::load() { MNN_PRINT("load %s ... ", model_path.c_str()); runtime_manager_->setExternalFile(config_->llm_weight()); modules_[0].reset(Module::load( - {"input_ids", "attention_mask", "position_ids", "past_key_values"}, - {"logits", "presents"}, model_path.c_str(), runtime_manager_, &module_config)); + {"input_ids", "attention_mask", "position_ids", "past_key_values"}, + {"logits", "presents"}, model_path.c_str(), runtime_manager_, &module_config)); MNN_PRINT("Done!\n"); } else { // load split models @@ -117,8 +152,8 @@ void Llm::load() { std::string model_path = config_->block_model(i); MNN_PRINT("load %s ... 
", model_path.c_str()); modules_[i].reset(Module::load( - {"inputs_embeds", "attention_mask", "position_ids", "past_key_values"}, - {"hidden_states", "presents"}, model_path.c_str(), runtime_manager_, &module_config)); + {"inputs_embeds", "attention_mask", "position_ids", "past_key_values"}, + {"hidden_states", "presents"}, model_path.c_str(), runtime_manager_, &module_config)); MNN_PRINT("Done!\n"); } } @@ -261,10 +296,14 @@ void Llm::chat() { } } +void Llm::reset() { + history_ids_.clear(); + all_seq_len_ = 0; +} + void Llm::generate_init() { // init status gen_seq_len_ = 0; - all_seq_len_ = 0; prefill_us_ = 0; decode_us_ = 0; past_key_values_.clear(); @@ -275,6 +314,10 @@ void Llm::generate_init() { past_key_values_.push_back(_Input(key_value_shape_, NCHW)); } } + if (!config_->reuse_kv()) { + all_seq_len_ = 0; + history_ids_.clear(); + } } std::vector Llm::generate(const std::vector& input_ids, int max_new_tokens) { @@ -306,15 +349,14 @@ std::vector Llm::generate(const std::vector& input_ids, int max_new_to std::string Llm::generate(const std::vector& input_ids, std::ostream* os, const char* end_with) { prompt_len_ = static_cast(input_ids.size()); - std::vector all_ids = input_ids; + history_ids_.insert(history_ids_.end(), input_ids.begin(), input_ids.end()); // push to history_ids_ auto st = std::chrono::system_clock::now(); modules_ = prefill_modules_; auto logits = forward(input_ids); if (nullptr == logits.get()) { return ""; } - int token = sample(logits, all_ids); - all_ids.push_back(token); + int token = sample(logits, history_ids_); auto et = std::chrono::system_clock::now(); modules_ = decode_modules_; std::string output_str = decode(token); @@ -322,6 +364,7 @@ std::string Llm::generate(const std::vector& input_ids, std::ostream* os, c *os << output_str << std::flush; while (gen_seq_len_ < config_->max_new_tokens()) { st = std::chrono::system_clock::now(); + history_ids_.push_back(token); logits = forward({token}); if (nullptr == logits.get()) { return ""; @@ -329,14 +372,13 @@ std::string Llm::generate(const std::vector& input_ids, std::ostream* os, c if (logits->getInfo()->size == 0) { return ""; } - token = sample(logits, all_ids); + token = sample(logits, history_ids_); et = std::chrono::system_clock::now(); decode_us_ += std::chrono::duration_cast(et - st).count(); if (is_stop(token)) { *os << end_with << std::flush; break; } - all_ids.push_back(token); auto word = decode(token); *os << word << std::flush; output_str += word; @@ -356,7 +398,11 @@ std::vector Llm::tokenizer(const std::string& user_content) { std::string Llm::response(const std::string& user_content, std::ostream* os, const char* end_with) { generate_init(); if (!end_with) { end_with = "\n"; } - auto input_ids = tokenizer(user_content); + auto prompt = apply_prompt_template(user_content); + if (config_->reuse_kv() && all_seq_len_ > 0) { + prompt = "<|im_end|>\n" + prompt; + } + auto input_ids = tokenizer_->encode(prompt); return generate(input_ids, os, end_with); } @@ -365,7 +411,12 @@ std::string Llm::response(const std::vector& chat_prompts, std::ostr generate_init(); if (!end_with) { end_with = "\n"; } auto prompt = apply_chat_template(chat_prompts); + if (config_->reuse_kv() && all_seq_len_ > 0) { + prompt = "<|im_end|>\n" + prompt; + } + std::cout << "# prompt : " << prompt << std::endl; auto input_ids = tokenizer_->encode(prompt); + printf("input_ids (%lu): ", input_ids.size()); for (auto id : input_ids) printf("%d, ", id); printf("\n"); return generate(input_ids, os, end_with); } @@ -462,29 +513,34 @@ 
std::string Llm::decode(int id) { } VARP Llm::gen_attention_mask(int seq_len) { + int kv_seq_len = all_seq_len_ + seq_len; + if (seq_len == 1) { + kv_seq_len = seq_len; + } if (config_->attention_mask() == "float") { if (needNewVar(attention_mask_, 2, seq_len)) { - attention_mask_ = _Input({1, 1, seq_len, seq_len}, NCHW, halide_type_of()); + attention_mask_ = _Input({1, 1, seq_len, kv_seq_len}, NCHW, halide_type_of()); } else { return attention_mask_; } auto ptr = attention_mask_->writeMap(); for (int i = 0; i < seq_len; i++) { - for (int j = 0; j < seq_len; j++) { - ptr[seq_len * i + j] = (j > i) * std::numeric_limits::lowest(); + for (int j = 0; j < kv_seq_len; j++) { + int row = i + all_seq_len_; + ptr[kv_seq_len * i + j] = (j > row) * std::numeric_limits::lowest(); } } return attention_mask_; } else { if (needNewVar(attention_mask_, 2, seq_len)) { - attention_mask_ = _Input({1, 1, seq_len, seq_len}, NCHW, halide_type_of()); + attention_mask_ = _Input({1, 1, seq_len, kv_seq_len}, NCHW, halide_type_of()); } else { return attention_mask_; } auto ptr = attention_mask_->writeMap(); if (config_->attention_mask() == "glm") { // chatglm - for (int i = 0; i < seq_len * seq_len; i++) { + for (int i = 0; i < seq_len * kv_seq_len; i++) { ptr[i] = 0; } if (seq_len > 1) { @@ -495,8 +551,9 @@ VARP Llm::gen_attention_mask(int seq_len) { } else { bool is_glm2 = config_->attention_mask() == "glm2"; for (int i = 0; i < seq_len; i++) { - for (int j = 0; j < seq_len; j++) { - ptr[seq_len * i + j] = is_glm2 ? j > i : j <= i; + for (int j = 0; j < kv_seq_len; j++) { + int row = i + all_seq_len_; + ptr[kv_seq_len * i + j] = is_glm2 ? j > row : j <= row; } } } @@ -533,7 +590,7 @@ VARP Llm::gen_position_ids(int seq_len) { ptr[0] = is_glm2 ? gen_seq_len_ : all_seq_len_; } else { for (int i = 0; i < seq_len; i++) { - ptr[i] = i; + ptr[i] = i + all_seq_len_; } } return position_ids_; @@ -671,6 +728,10 @@ Embedding* Embedding::createEmbedding(const std::string& config_path) { return embedding; } +Embedding::Embedding(std::shared_ptr config) : Llm(config) {} + +int Embedding::dim() const { return config_->hidden_size(); } + void Embedding::load() { init_runtime(); printf("load tokenizer\n"); @@ -686,8 +747,8 @@ void Embedding::load() { MNN_PRINT("load %s ... ", model_path.c_str()); modules_.resize(1); modules_[0].reset(Module::load( - {"input_ids", "attention_mask", "position_ids"}, - {"sentence_embeddings"}, model_path.c_str(), runtime_manager_, &module_config)); + {"input_ids", "attention_mask", "position_ids"}, + {"sentence_embeddings"}, model_path.c_str(), runtime_manager_, &module_config)); MNN_PRINT("Done!\n"); } @@ -730,3 +791,5 @@ VARP Embedding::gen_position_ids(int seq_len) { return position_ids; } // Embedding end +} +}
diff --git a/transformers/llm/engine/include/llm.hpp b/transformers/llm/engine/src/llmconfig.hpp similarity index 58% rename from transformers/llm/engine/include/llm.hpp rename to transformers/llm/engine/src/llmconfig.hpp index 71c3f00fd..71ef7291f 100644 --- a/transformers/llm/engine/include/llm.hpp +++ b/transformers/llm/engine/src/llmconfig.hpp @@ -1,58 +1,20 @@ // -// llm.hpp +// llmconfig.hpp // -// Created by MNN on 2023/08/25. +// Created by MNN on 2024/07/19.
// ZhaodeWang // -#ifndef LLM_hpp -#define LLM_hpp - -#include -#include -#include -#include -#include -#include -#include -#include -#include - -#include -#include -#include -#include -#include "tokenizer.hpp" #include "rapidjson/document.h" +#include +#include -using namespace MNN; -using namespace Express; -using namespace rapidjson; -class Tokenizer; -class Pipeline; - -// Llm start -// llm stream buffer with callback -class LlmStreamBuffer : public std::streambuf { -public: - using CallBack = std::function;; - LlmStreamBuffer(CallBack callback) : callback_(callback) {} - -protected: - virtual std::streamsize xsputn(const char* s, std::streamsize n) override { - if (callback_) { - callback_(s, n); - } - return n; - } - -private: - CallBack callback_ = nullptr; -}; +namespace MNN { +namespace Transformer { static inline bool has_suffix(const std::string& str, const std::string& suffix) { return str.size() >= suffix.size() && - str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0; + str.compare(str.size() - suffix.size(), suffix.size(), suffix) == 0; } static inline std::string base_dir(const std::string& path) { @@ -73,26 +35,69 @@ static inline std::string file_name(const std::string& path) { } } +bool merge_json(rapidjson::Value& destination, const rapidjson::Value& source, + rapidjson::Document::AllocatorType& allocator) { + if (!source.IsObject() || !destination.IsObject()) { + return false; + } + + for (auto it = source.MemberBegin(); it != source.MemberEnd(); ++it) { + const char* key = it->name.GetString(); + if (destination.HasMember(key)) { + if (destination[key].IsObject() && it->value.IsObject()) { + // Recursively merge the two JSON objects + merge_json(destination[key], it->value, allocator); + } else { + // Overwrite the value in the destination + destination[key].CopyFrom(it->value, allocator); + } + } else { + // Add the value to the destination + rapidjson::Value newKey(key, allocator); + rapidjson::Value newValue; + newValue.CopyFrom(it->value, allocator); + destination.AddMember(newKey, newValue, allocator); + } + } + return true; +} + class rapid_json_wrapper { public: - Document document; + rapidjson::Document document; rapid_json_wrapper() {} - rapid_json_wrapper(Document doc) : document(std::move(doc)) {} + rapid_json_wrapper(rapidjson::Document doc) : document(std::move(doc)) {} static rapid_json_wrapper parse(const std::ifstream& ifile) { std::ostringstream ostr; ostr << ifile.rdbuf(); - Document document; + rapidjson::Document document; document.Parse(ostr.str().c_str()); rapid_json_wrapper json_wrapper(std::move(document)); return json_wrapper; } static rapid_json_wrapper parse(const char* str) { - Document document; + rapidjson::Document document; document.Parse(str); rapid_json_wrapper json_wrapper(std::move(document)); return json_wrapper; } - + bool merge(const char* str) { + rapidjson::Document input_doc; + input_doc.Parse(str); + if (input_doc.HasParseError()) { + return false; + } + // merge + rapidjson::Document::AllocatorType& allocator = document.GetAllocator(); + return merge_json(document, input_doc, allocator); + } + std::string dump() { + rapidjson::StringBuffer buffer; + rapidjson::Writer writer(buffer); + document.Accept(writer); + return buffer.GetString(); + } + // read value int value(const char* key, const int& default_value) const { if (document.HasMember(key)) { const auto& value = document[key]; @@ -218,6 +223,14 @@ class LlmConfig { int max_new_tokens() const { return config_.value("max_new_tokens", 512); } + + bool 
reuse_kv() const { + return config_.value("reuse_kv", false); + } + + int quant_kv() const { + return config_.value("quant_kv", 0); + } // generate config end > // < backend config start @@ -272,90 +285,5 @@ class LlmConfig { } // llm model config end > }; - -class MNN_PUBLIC Llm { - using PromptItem = std::pair; // -public: - Llm(std::shared_ptr config) : config_(config) {} - virtual ~Llm(); - static Llm* createLLM(const std::string& config_path); - void chat(); - void trace(bool start); - virtual void load(); - VARP forward(const std::vector& input_ids); - int sample(VARP logits, const std::vector& pre_ids); - std::string apply_prompt_template(const std::string& user_content) const; - std::string apply_chat_template(const std::vector& chat_prompts) const; - std::string response(const std::string& user_content, std::ostream* os = &std::cout, const char* end_with = nullptr); - std::string response(const std::vector& chat_prompts, std::ostream* os = &std::cout, const char* end_with = nullptr); - void generate_init(); - std::string generate(const std::vector& input_ids, std::ostream* os, const char* end_with); - std::vector generate(const std::vector& input_ids, int max_new_tokens = -1); - void print_speed(); - friend class Pipeline; -public: - // forward info - int prompt_len_ = 0; - int gen_seq_len_ = 0; - int all_seq_len_ = 0; - // time - int64_t prefill_us_ = 0; - int64_t decode_us_ = 0; - bool is_single_ = true; - std::shared_ptr config_; - std::unique_ptr tokenizer_; -protected: - std::vector key_value_shape_ = {}; - std::vector past_key_values_; - VARP inputs_embeds_, attention_mask_, position_ids_; - std::shared_ptr runtime_manager_; - std::vector> modules_; - std::vector> decode_modules_; - std::vector> prefill_modules_; - void init_runtime(); - std::string decode(int id); - bool is_stop(int token_id); - virtual std::vector tokenizer(const std::string& query); - virtual VARP embedding(const std::vector& input_ids); - virtual VARP gen_attention_mask(int seq_len); - virtual VARP gen_position_ids(int seq_len); -}; - -class Lvlm : public Llm { -public: - Lvlm(std::shared_ptr config) : Llm(config) { - img_size_ = config->llm_config_.value("img_size", img_size_); - imgpad_len_ = config->llm_config_.value("imgpad_len", imgpad_len_); - img_start_ = config->llm_config_.value("img_start", img_start_); - img_end_ = config->llm_config_.value("img_end", img_end_); - img_pad_ = config->llm_config_.value("img_pad", img_pad_); - } - ~Lvlm() { visual_module_.reset(); } - virtual void load() override; -private: - int img_size_ = 448, imgpad_len_ = 256, img_start_ = 151857, img_end_ = 151858, img_pad_ = 151859; - std::shared_ptr visual_module_; - VARP visual_embedding(const std::vector& input_ids); - std::vector url_encode(const std::string& url); - virtual std::vector tokenizer(const std::string& query) override; - virtual VARP embedding(const std::vector& input_ids) override; -}; -// Llm end - -// Embedding start -class Embedding : public Llm { -public: - Embedding(std::shared_ptr config) : Llm(config) {} - static Embedding* createEmbedding(const std::string& config_path); - static float dist(VARP var0, VARP var1); - virtual void load() override; - VARP embedding(const std::string& txt); - int dim() { return config_->hidden_size(); } -private: - virtual std::vector tokenizer(const std::string& query) override; - virtual VARP gen_attention_mask(int seq_len) override; - virtual VARP gen_position_ids(int seq_len) override; -}; -// Embedding end - -#endif // LLM_hpp +} // Transformer +} // MNN \ No 
newline at end of file diff --git a/transformers/llm/engine/src/tokenizer.cpp b/transformers/llm/engine/src/tokenizer.cpp index f0350adc9..6330d8885 100644 --- a/transformers/llm/engine/src/tokenizer.cpp +++ b/transformers/llm/engine/src/tokenizer.cpp @@ -15,19 +15,21 @@ #include #include #include +namespace MNN { +namespace Transformer { // base64 static const std::string base64_chars = - "ABCDEFGHIJKLMNOPQRSTUVWXYZ" - "abcdefghijklmnopqrstuvwxyz" - "0123456789+/"; +"ABCDEFGHIJKLMNOPQRSTUVWXYZ" +"abcdefghijklmnopqrstuvwxyz" +"0123456789+/"; static inline bool is_base64(unsigned char c) { return (isalnum(c) || (c == '+') || (c == '/')); } static inline size_t one_char_len(const char *src) { - return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4]; + return "\1\1\1\1\1\1\1\1\1\1\1\1\2\2\3\4"[(*src & 0xFF) >> 4]; } static std::string base64_decode(const std::string& str) { @@ -343,13 +345,13 @@ Sentencepiece::EncodeResult Sentencepiece::bpe_encode(string_view_ normalized, f if (skip_merge()) continue; // Replaces symbols with `top` rule. symbols[top->left].piece = string_view_( - symbols[top->left].piece.data(), - symbols[top->left].piece.size() + symbols[top->right].piece.size()); + symbols[top->left].piece.data(), + symbols[top->left].piece.size() + symbols[top->right].piece.size()); // Updates prev/next pointers. symbols[top->left].next = symbols[top->right].next; if (symbols[top->right].next >= 0) { - symbols[symbols[top->right].next].prev = top->left; + symbols[symbols[top->right].next].prev = top->left; } symbols[top->right].piece = string_view_(""); @@ -585,11 +587,11 @@ std::string wstring_to_utf8(const std::wstring& str) { void byte_encode_token(const std::string& token, const std::unordered_map& b2u, std::wstring* result) { - result->resize(0); - for (char c : token) { - wchar_t wc = b2u.at(uint8_t(c)); - result->push_back(wc); - } + result->resize(0); + for (char c : token) { + wchar_t wc = b2u.at(uint8_t(c)); + result->push_back(wc); + } } bool HuggingfaceTokenizer::load_vocab(std::ifstream& tok_file) { @@ -611,10 +613,10 @@ bool HuggingfaceTokenizer::load_vocab(std::ifstream& tok_file) { std::getline(tok_file, line); int d = line.find(" "); bpe_ranks_.insert({{utf8_to_wstring(line.substr(0, d)), - utf8_to_wstring(line.substr(d + 1))}, i}); + utf8_to_wstring(line.substr(d + 1))}, i}); } // bytes_to_unicode - auto _insert_range = [=](int start, int end) { + auto _insert_range = [=](int start, int end) { for (int c = start; c <= end; c++) { b2u_.insert({uint8_t(c), wchar_t(c)}); } @@ -654,13 +656,13 @@ void HuggingfaceTokenizer::bpe(const std::wstring& token, const BPERanks& bpe_ra std::set merged; // records indices in pairs that were merged. auto _left = [](int i, std::set& merged) { for (int j = i - 1; j >= -1; j--) { - if (merged.find(j) == merged.end()) return j; + if (merged.find(j) == merged.end()) return j; } return -1; }; auto _right = [](int i, int cap, std::set& merged) { for (int j = i + 1; j < cap; j++) { - if (merged.find(j) == merged.end()) return j; + if (merged.find(j) == merged.end()) return j; } return cap; }; @@ -673,15 +675,15 @@ void HuggingfaceTokenizer::bpe(const std::wstring& token, const BPERanks& bpe_ra int to_merge = -1; // indices into pairs. for (int i = 0; i < pairs.size(); ++i) { - if (merged.find(i) == merged.end()) { // pair i is not merged. - auto iter = bpe_ranks.find(pairs[i]); - int score = iter != bpe_ranks.end() ? 
iter->second : INT_MAX; - if (score < min_score) { - min_score = score; - to_merge = i; + if (merged.find(i) == merged.end()) { // pair i is not merged. + auto iter = bpe_ranks.find(pairs[i]); + int score = iter != bpe_ranks.end() ? iter->second : INT_MAX; + if (score < min_score) { + min_score = score; + to_merge = i; + } } } - } if (to_merge == -1) break; @@ -747,3 +749,5 @@ std::string HuggingfaceTokenizer::decode(int id) { } return r; } +} +} diff --git a/transformers/llm/engine/include/tokenizer.hpp b/transformers/llm/engine/src/tokenizer.hpp similarity index 98% rename from transformers/llm/engine/include/tokenizer.hpp rename to transformers/llm/engine/src/tokenizer.hpp index e30cd980e..77ceeda5d 100644 --- a/transformers/llm/engine/include/tokenizer.hpp +++ b/transformers/llm/engine/src/tokenizer.hpp @@ -15,8 +15,6 @@ #include // #include #include - -// std::string_view impl in c++11 start class string_view_ { public: string_view_() : data_(nullptr), size_(0) {} @@ -46,6 +44,7 @@ class string_view_ { const char* data_; std::size_t size_ = 0; }; +// std::string_view impl in c++11 end namespace std { template<> @@ -60,7 +59,9 @@ namespace std { } }; } -// std::string_view impl in c++11 end +namespace MNN { +namespace Transformer { +// std::string_view impl in c++11 start class Tokenizer { public: @@ -183,5 +184,7 @@ using BPERanks = std::unordered_map, int, std::unordered_map encoder_; std::vector decoder_; }; +}; +}; -#endif // TOKENIZER_hpp \ No newline at end of file +#endif // TOKENIZER_hpp