From 49223665e3baa8509c0d51c545b96aaf01f0025d Mon Sep 17 00:00:00 2001
From: Pedro Gonnet
Date: Tue, 7 Jan 2025 22:48:11 -0800
Subject: [PATCH] Use the new `pthreadpool_parallelize_[23]d_tile_2d_dynamic`
 strategies in the GEMM-based ops.

PiperOrigin-RevId: 713162910
---
 CMakeLists.txt | 108 +-
 bench/average-pooling.cc | 2 +-
 bench/batch-matrix-multiply.cc | 24 +-
 bench/bf16-gemm.cc | 2 +-
 bench/binary.cc | 55 +-
 bench/convolution.cc | 2 +-
 bench/deconvolution.cc | 2 +-
 bench/f16-conv-hwc2chw.cc | 2 +-
 bench/f16-dwconv.cc | 2 +-
 bench/f16-dwconv2d-chw.cc | 2 +-
 bench/f16-f32acc-gemm.cc | 2 +-
 bench/f16-f32acc-igemm.cc | 2 +-
 bench/f16-f32acc-rdsum.cc | 2 +-
 bench/f16-f32acc-rsum.cc | 2 +-
 bench/f16-gemm-minmax.cc | 2 +-
 bench/f16-gemm.cc | 2 +-
 bench/f16-igemm.cc | 2 +-
 bench/f16-raddstoreexpminusmax.cc | 2 +-
 bench/f16-rmax.cc | 2 +-
 bench/f16-rmin.cc | 2 +-
 bench/f16-rminmax.cc | 2 +-
 bench/f16-rsum.cc | 2 +-
 bench/f16-spmm.cc | 2 +-
 bench/f16-vcmul.cc | 2 +-
 bench/f32-bgemm.cc | 2 +-
 bench/f32-conv-hwc.cc | 2 +-
 bench/f32-conv-hwc2chw.cc | 2 +-
 bench/f32-dwconv.cc | 2 +-
 bench/f32-dwconv2d-chw.cc | 2 +-
 bench/f32-gemm-goi-minmax.cc | 2 +-
 bench/f32-gemm-minmax.cc | 2 +-
 bench/f32-gemm.cc | 2 +-
 bench/f32-igemm.cc | 2 +-
 bench/f32-im2col-gemm.cc | 2 +-
 bench/f32-qc4w-gemm.cc | 2 +-
 bench/f32-qc8w-gemm.cc | 2 +-
 bench/f32-raddexpminusmax.cc | 2 +-
 bench/f32-raddextexp.cc | 2 +-
 bench/f32-raddstoreexpminusmax.cc | 2 +-
 bench/f32-rdsum.cc | 2 +-
 bench/f32-rmax.cc | 2 +-
 bench/f32-rmin.cc | 2 +-
 bench/f32-rminmax.cc | 2 +-
 bench/f32-rsum.cc | 2 +-
 bench/f32-softmax.cc | 331 +++---
 bench/f32-spmm.cc | 2 +-
 bench/f32-vcmul.cc | 2 +-
 bench/f32-vscaleexpminusmax.cc | 2 +-
 bench/f32-vscaleextexp.cc | 2 +-
 bench/fully-connected.cc | 2 +-
 bench/max-pooling.cc | 2 +-
 bench/models/benchmark.cc | 107 +-
 bench/prelu.cc | 2 +-
 bench/qd8-f16-qb4w-gemm.cc | 2 +-
 bench/qd8-f16-qc4w-gemm.cc | 2 +-
 bench/qd8-f16-qc8w-gemm.cc | 2 +-
 bench/qd8-f32-qb4w-gemm.cc | 2 +-
 bench/qd8-f32-qc4w-gemm.cc | 2 +-
 bench/qd8-f32-qc8w-gemm.cc | 2 +-
 bench/qp8-f32-qb4w-gemm.cc | 2 +-
 bench/qp8-f32-qc4w-gemm.cc | 2 +-
 bench/qp8-f32-qc8w-gemm.cc | 2 +-
 bench/qs8-dwconv.cc | 2 +-
 bench/qs8-gemm.cc | 2 +-
 bench/qs8-packw.cc | 2 +-
 bench/qs8-qc4w-packw.cc | 2 +-
 bench/qs8-qc8w-gemm-fp32.cc | 2 +-
 bench/qs8-rdsum.cc | 2 +-
 bench/qs8-rsum.cc | 2 +-
 bench/qu8-gemm-fp32.cc | 2 +-
 bench/qu8-gemm-rndnu.cc | 2 +-
 bench/qu8-gemm.cc | 2 +-
 bench/qu8-rdsum.cc | 2 +-
 bench/qu8-rsum.cc | 2 +-
 bench/scaled-dot-product-attention.cc | 2 +-
 bench/softmax.cc | 2 +-
 bench/unary.cc | 61 +-
 bench/utils.cc | 70 +-
 bench/utils.h | 28 +
 bench/vbinary.cc | 2 +-
 bench/vunary.cc | 2 +-
 bench/x16-packw.cc | 2 +-
 bench/x32-packw.cc | 2 +-
 bench/x8-lut.cc | 2 +-
 bench/x8-packq.cc | 2 +-
 bench/x8-packw.cc | 2 +-
 bench/xN-transposec.cc | 2 +-
 bench/xx-transposev.cc | 2 +-
 gemm_compiler/aarch64_template.py | 12 +-
 gemm_compiler/base_architecture.py | 4 +-
 gemm_compiler/generate.py | 4 +-
 gemm_compiler/neonfma_template.py | 6 +-
 gemm_compiler/x64_template.py | 2 +-
 ...xpminusmax-neonfp16arith-rr2-p2-u32-acc2.c | 2 +-
 ...xpminusmax-neonfp16arith-rr2-p2-u32-acc4.c | 2 +-
 ...toreexpminusmax-neonfp16arith-rr2-p2-u32.c | 2 +-
 ...xpminusmax-neonfp16arith-rr2-p2-u40-acc2.c | 2 +-
 ...xpminusmax-neonfp16arith-rr2-p2-u40-acc5.c | 2 +-
 ...toreexpminusmax-neonfp16arith-rr2-p2-u40.c | 2 +-
 ...xpminusmax-neonfp16arith-rr2-p2-u48-acc2.c | 2 +-
 ...xpminusmax-neonfp16arith-rr2-p2-u48-acc3.c | 2 +-
 ...toreexpminusmax-neonfp16arith-rr2-p2-u48.c | 2 +-
 ...xpminusmax-neonfp16arith-rr2-p2-u64-acc2.c | 2 +-
 ...xpminusmax-neonfp16arith-rr2-p2-u64-acc4.c | 2 +-
 ...toreexpminusmax-neonfp16arith-rr2-p2-u64.c | 2 +-
 ...xpminusmax-neonfp16arith-rr2-p2-u72-acc3.c | 2 +-
 ...toreexpminusmax-neonfp16arith-rr2-p2-u72.c | 2 +-
 ...xpminusmax-neonfp16arith-rr2-p2-u80-acc2.c | 2 +-
 ...xpminusmax-neonfp16arith-rr2-p2-u80-acc5.c | 2 +-
 ...toreexpminusmax-neonfp16arith-rr2-p2-u80.c | 2 +-
 ...xpminusmax-neonfp16arith-rr2-p2-u96-acc2.c | 2 +-
 ...xpminusmax-neonfp16arith-rr2-p2-u96-acc3.c | 2 +-
 ...xpminusmax-neonfp16arith-rr2-p2-u96-acc6.c | 2 +-
 ...toreexpminusmax-neonfp16arith-rr2-p2-u96.c | 2 +-
 .../neonfp16arith-rr2-p2.c.in | 2 +-
 ...10x16-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...10x32-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...11x16-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...11x32-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...emm-1x16-minmax-asm-aarch64-neonfma-ld32.S | 10 +-
 ...-1x16-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...-1x32-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...-1x64-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...gemm-1x8-minmax-asm-aarch64-neonfma-ld32.S | 10 +-
 ...emm-2x16-minmax-asm-aarch64-neonfma-ld32.S | 13 +-
 ...-2x16-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...-2x32-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...-2x64-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...gemm-2x8-minmax-asm-aarch64-neonfma-ld32.S | 13 +-
 ...emm-3x16-minmax-asm-aarch64-neonfma-ld32.S | 16 +-
 ...-3x16-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...-3x32-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...-3x64-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...gemm-3x8-minmax-asm-aarch64-neonfma-ld32.S | 16 +-
 ...emm-4x16-minmax-asm-aarch64-neonfma-ld32.S | 19 +-
 ...-4x16-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...-4x32-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...-4x64-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...gemm-4x8-minmax-asm-aarch64-neonfma-ld32.S | 19 +-
 ...emm-5x16-minmax-asm-aarch64-neonfma-ld32.S | 22 +-
 ...-5x16-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...-5x32-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...-5x64-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...gemm-5x8-minmax-asm-aarch64-neonfma-ld32.S | 22 +-
 ...-6x16-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...-6x32-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...-7x16-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...-7x32-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...-8x16-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...-8x32-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...-9x16-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 ...-9x32-minmax-asm-amd64-avx512f-broadcast.S | 2 +-
 src/operator-run.c | 216 ++--
 src/operators/batch-matrix-multiply-nc.c | 34 +-
 src/operators/convolution-nhwc.c | 37 +-
 src/operators/dynamic-fully-connected-nc.c | 30 +-
 src/operators/fully-connected-nc.c | 28 +-
 ...w-gemm-10x16-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...w-gemm-10x32-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...w-gemm-11x16-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...w-gemm-11x32-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...emm-1x16-minmax-asm-aarch64-neondot-ld32.S | 10 +-
 ...8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...gemm-1x8-minmax-asm-aarch64-neondot-ld32.S | 10 +-
 ...emm-2x16-minmax-asm-aarch64-neondot-ld32.S | 13 +-
 ...8w-gemm-2x16-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...8w-gemm-2x32-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...8w-gemm-2x64-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...gemm-2x8-minmax-asm-aarch64-neondot-ld32.S | 13 +-
 ...emm-3x16-minmax-asm-aarch64-neondot-ld32.S | 16 +-
 ...8w-gemm-3x16-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...8w-gemm-3x32-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...8w-gemm-3x64-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...gemm-3x8-minmax-asm-aarch64-neondot-ld32.S | 16 +-
 ...emm-4x16-minmax-asm-aarch64-neondot-ld32.S | 19 +-
 ...8w-gemm-4x16-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...8w-gemm-4x32-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...8w-gemm-4x64-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...gemm-4x8-minmax-asm-aarch64-neondot-ld32.S | 19 +-
 ...8w-gemm-5x16-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...8w-gemm-5x32-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...8w-gemm-6x16-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...8w-gemm-6x32-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...8w-gemm-7x16-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...8w-gemm-7x32-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...8w-gemm-8x16-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...8w-gemm-8x32-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...8w-gemm-9x16-minmax-asm-amd64-avx512vnni.S | 2 +-
 ...8w-gemm-9x32-minmax-asm-amd64-avx512vnni.S | 2 +-
 src/subgraph.c | 18 +-
 src/xnnpack/compute.h | 968 ++++++++----------
 test/BUILD.bazel | 2 +-
 test/fully-connected.cc | 2 +-
 test/reduce-nd.cc | 2 +-
 test/workspace.cc | 1 +
 tools/generate-gemm-test.py | 2 +-
 tools/generate-rdsum-benchmark.py | 2 +-
 tools/generate-spmm-test.py | 2 +-
 201 files changed, 1456 insertions(+), 1288 deletions(-)

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 16bf851db1e..f0b3410aecc 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -1266,18 +1266,31 @@ IF(XNNPACK_BUILD_TESTS)
   ENDIF()
 
   # ---[ Launch heavy tests first.
+  # Tests added to this list will be automatically removed from other lists.
   SET(SHARDED_TESTS
     fully-connected-nc
     avgpool-minmax
-    maxpool-minmax
-    f32-vclamp
-    f32-vlrelu
-    f32-rdsum
-    f32-velu
-    f32-argmaxpool
-    s8-vclamp
-    u8-vclamp
-  )
+    maxpool-minmax)
+  IF(XNNPACK_TARGET_PROCESSOR MATCHES "^riscv")
+    LIST(APPEND SHARDED_TESTS
+      f16-qs8-vcvt
+      f16-qu8-vcvt
+      f32-argmaxpool
+      f32-qs8-vcvt
+      f32-qu8-vcvt
+      f32-rdsum
+      f32-vclamp
+      f32-velu
+      f32-vlrelu
+      qs8-f32-vcvt
+      qs8-vcvt
+      qs8-vlrelu
+      qu8-f32-vcvt
+      qu8-vcvt
+      qu8-vlrelu
+      s8-vclamp
+      u8-vclamp)
+  ENDIF()
   FOREACH(TEST ${SHARDED_TESTS})
     ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc)
     TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE include src test)
@@ -1297,6 +1310,7 @@ IF(XNNPACK_BUILD_TESTS)
 
   IF(XNNPACK_BUILD_LIBRARY)
     # ---[ Launch heavy tests first.
+    # Tests added to this list will be automatically removed from other lists.
SET(LIBRARY_SHARDED_TESTS batch-matrix-multiply-nc batch-matrix-multiply @@ -1342,6 +1356,7 @@ IF(XNNPACK_BUILD_TESTS) runtime subgraph-nchw workspace) + LIST(REMOVE_ITEM LIBRARY_SUBGRAPH_OPTIMIZATION_TESTS ${LIBRARY_SHARDED_TESTS}) FOREACH(TEST ${LIBRARY_SUBGRAPH_OPTIMIZATION_TESTS}) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE src test) @@ -1386,6 +1401,7 @@ IF(XNNPACK_BUILD_TESTS) transpose-reshape unary unpooling-2d) + LIST(REMOVE_ITEM LIBRARY_SUBGRAPH_UNIT_TESTS ${LIBRARY_SHARDED_TESTS}) FOREACH(TEST ${LIBRARY_SUBGRAPH_UNIT_TESTS}) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE src test) @@ -1407,6 +1423,7 @@ IF(XNNPACK_BUILD_TESTS) convolution-2d deconvolution-2d depthwise-convolution-2d) + LIST(REMOVE_ITEM LIBRARY_SUBGRAPH_CONVOLUTION_UNIT_TESTS ${LIBRARY_SHARDED_TESTS}) FOREACH(TEST ${LIBRARY_SUBGRAPH_CONVOLUTION_UNIT_TESTS}) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE src test) @@ -1440,21 +1457,23 @@ IF(XNNPACK_BUILD_TESTS) f16-conv-hwc2chw f16-f32acc-rdsum f16-f32acc-rsum - f16-ibilinear-chw f16-ibilinear + f16-ibilinear-chw f16-raddstoreexpminusmax f16-rmax f16-rsum f16-spmm-minmax f16-vcmul f16-vmulcaddc-minmax + f32-argmaxpool f32-conv-hwc f32-conv-hwc2chw - f32-ibilinear-chw f32-ibilinear + f32-ibilinear-chw f32-raddexpminusmax f32-raddextexp f32-raddstoreexpminusmax + f32-rdsum f32-rmax f32-rmin f32-rminmax @@ -1466,11 +1485,12 @@ IF(XNNPACK_BUILD_TESTS) f32-vscaleextexp indirection packing + qs8-packw + qs8-qc4w-packw qs8-rdsum-minmax-fp32 - qu8-rdsum qs8-rsum + qu8-rdsum qu8-rsum - qs8-vlrelu qu8-vlrelu s8-ibilinear u8-ibilinear @@ -1483,11 +1503,10 @@ IF(XNNPACK_BUILD_TESTS) x32-unpool x8-lut x8-packw - qs8-packw - qs8-qc4w-packw xN-transpose xx-fill xx-pad) + LIST(REMOVE_ITEM MICROKERNEL_UNIT_TESTS ${SHARDED_TESTS}) FOREACH(TEST ${MICROKERNEL_UNIT_TESTS}) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE include src test) @@ -1514,21 +1533,22 @@ IF(XNNPACK_BUILD_TESTS) f16-dwconv-minmax-multipass f16-dwconv-minmax-unipass f16-dwconv2d-chw - f32-dwconv-multipass f32-dwconv-minmax-multipass - f32-dwconv-unipass f32-dwconv-minmax-unipass + f32-dwconv-multipass + f32-dwconv-unipass f32-dwconv2d-chw - qs8-qc8w-dwconv-minmax-multipass-fp32 - qs8-qc8w-dwconv-minmax-unipass-fp32 qs8-dwconv-minmax-multipass-fp32 qs8-dwconv-minmax-multipass-rndnu qs8-dwconv-minmax-unipass-fp32 qs8-dwconv-minmax-unipass-rndnu + qs8-qc8w-dwconv-minmax-multipass-fp32 + qs8-qc8w-dwconv-minmax-unipass-fp32 qu8-dwconv-minmax-multipass-fp32 qu8-dwconv-minmax-multipass-rndnu qu8-dwconv-minmax-unipass-fp32 qu8-dwconv-minmax-unipass-rndnu) + LIST(REMOVE_ITEM MICROKERNEL_DWCONV_UNIT_TESTS ${SHARDED_TESTS}) FOREACH(TEST ${MICROKERNEL_DWCONV_UNIT_TESTS}) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE include src test) @@ -1550,38 +1570,39 @@ IF(XNNPACK_BUILD_TESTS) SET(MICROKERNEL_GEMM_UNIT_TESTS bf16-gemm-minmax f16-f32acc-gemm-minmax - f16-gemm-minmax f16-f32acc-igemm-minmax + f16-gemm-minmax f16-igemm-minmax - qd8-f16-qc8w-gemm-minmax f32-gemm - f32-gemm-relu - f32-gemm-minmax f32-gemm-goi-minmax - f32-qc8w-gemm - f32-qc8w-gemm-relu - f32-qc8w-gemm-minmax - f32-qc4w-gemm-minmax + f32-gemm-minmax + f32-gemm-relu f32-gemminc-minmax f32-igemm - f32-igemm-relu f32-igemm-minmax + f32-igemm-relu f32-ppmm-minmax - qd8-f32-qc8w-gemm-minmax + f32-qc4w-gemm-minmax + f32-qc8w-gemm + 
f32-qc8w-gemm-minmax + f32-qc8w-gemm-relu qd8-f16-qb4w-gemm-minmax qd8-f16-qc4w-gemm-minmax + qd8-f16-qc8w-gemm-minmax qd8-f32-qb4w-gemm-minmax qd8-f32-qc4w-gemm-minmax + qd8-f32-qc8w-gemm-minmax qd8-f32-qc8w-igemm-minmax + qp8-f32-qb4w-gemm-minmax qp8-f32-qc4w-gemm-minmax qp8-f32-qc8w-gemm-minmax - qp8-f32-qb4w-gemm-minmax qs8-qc8w-gemm-minmax-fp32 qs8-qc8w-igemm-minmax-fp32 qu8-gemm-minmax-fp32 qu8-gemm-minmax-rndnu qu8-igemm-minmax-fp32 qu8-igemm-minmax-rndnu) + LIST(REMOVE_ITEM MICROKERNEL_GEMM_UNIT_TESTS ${SHARDED_TESTS}) FOREACH(TEST ${MICROKERNEL_GEMM_UNIT_TESTS}) FILE(GLOB TEST_SOURCES "test/${TEST}*.cc") IF(TEST_SOURCES) @@ -1605,6 +1626,7 @@ IF(XNNPACK_BUILD_TESTS) SET(MICROKERNEL_PACKQ_UNIT_TESTS x8-packq) + LIST(REMOVE_ITEM MICROKERNEL_PACKQ_UNIT_TESTS ${SHARDED_TESTS}) FOREACH(TEST ${MICROKERNEL_PACKQ_UNIT_TESTS}) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE include src test) @@ -1675,6 +1697,7 @@ IF(XNNPACK_BUILD_TESTS) qu8-vmul-minmax-rndnu qu8-vmulc-minmax-fp32 qu8-vmulc-minmax-rndnu) + LIST(REMOVE_ITEM MICROKERNEL_VBINARY_UNIT_TESTS ${SHARDED_TESTS}) FOREACH(TEST ${MICROKERNEL_VBINARY_UNIT_TESTS}) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE include src test) @@ -1696,6 +1719,7 @@ IF(XNNPACK_BUILD_TESTS) SET(MICROKERNEL_VCVT_TESTS f16-f32-vcvt f16-qs8-vcvt + f16-qu8-vcvt f32-f16-vcvt f32-qs8-vcvt f32-qu8-vcvt @@ -1704,6 +1728,7 @@ IF(XNNPACK_BUILD_TESTS) qs8-vcvt qu8-f32-vcvt qu8-vcvt) + LIST(REMOVE_ITEM MICROKERNEL_VCVT_TESTS ${SHARDED_TESTS}) FOREACH(TEST ${MICROKERNEL_VCVT_TESTS}) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE include src test) @@ -1728,31 +1753,38 @@ IF(XNNPACK_BUILD_TESTS) f16-vhswish f16-vlrelu f16-vneg + f16-vrndd f16-vrndne - f16-vrndz f16-vrndu - f16-vrndd + f16-vrndz f16-vrsqrt f16-vsigmoid f16-vsqr f16-vsqrt f16-vtanh f32-vabs - f32-vhswish - f32-vgelu + f32-vclamp + f32-velu f32-vexp + f32-vgelu + f32-vhswish f32-vlog + f32-vlrelu f32-vneg f32-vrelu + f32-vrndd f32-vrndne - f32-vrndz f32-vrndu - f32-vrndd + f32-vrndz + f32-vrsqrt f32-vsigmoid f32-vsqr f32-vsqrt - f32-vrsqrt - f32-vtanh) + f32-vtanh + qs8-vlrelu + s8-vclamp + u8-vclamp) + LIST(REMOVE_ITEM MICROKERNEL_VUNARY_TESTS ${SHARDED_TESTS}) FOREACH(TEST ${MICROKERNEL_VUNARY_TESTS}) ADD_EXECUTABLE(${TEST}-test test/${TEST}.cc) TARGET_INCLUDE_DIRECTORIES(${TEST}-test PRIVATE include src test) diff --git a/bench/average-pooling.cc b/bench/average-pooling.cc index 6c3deca8616..457d03f9721 100644 --- a/bench/average-pooling.cc +++ b/bench/average-pooling.cc @@ -434,5 +434,5 @@ BENCHMARK_CAPTURE(xnnpack_average_pooling_qu8, shufflenet_v1_g4, "ShuffleNet v1 BENCHMARK_CAPTURE(xnnpack_average_pooling_qu8, shufflenet_v1_g8, "ShuffleNet v1 (8 groups)")->Apply(ShuffleNetV1G8)->UseRealTime(); #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/batch-matrix-multiply.cc b/bench/batch-matrix-multiply.cc index 877225f9bb6..3e69e020c39 100644 --- a/bench/batch-matrix-multiply.cc +++ b/bench/batch-matrix-multiply.cc @@ -31,10 +31,6 @@ #include "tensorflow/lite/version.h" #endif // BENCHMARK_TENSORFLOW_LITE -namespace { -static const size_t kMinIterations = 10; -} // namespace - void xnnpack_batch_matrix_multiply_f32(benchmark::State& state, const char* net) { const size_t batch_size = state.range(0); @@ -99,8 +95,9 @@ void xnnpack_batch_matrix_multiply_f32(benchmark::State& state, return; } - while 
(state.KeepRunningBatch(kMinIterations)) { - for (int iter = 0; iter < kMinIterations; iter++) { + int num_iters = FLAGS_benchmark_min_iters; + while (state.KeepRunningBatch(num_iters)) { + for (int iter = 0; iter < num_iters; iter++) { benchmark::utils::WipePthreadpoolL2Caches(state, threadpool); status = xnn_run_operator(op, threadpool); @@ -109,6 +106,7 @@ void xnnpack_batch_matrix_multiply_f32(benchmark::State& state, return; } } + num_iters = 1; } status = xnn_delete_operator(op); @@ -207,8 +205,9 @@ void xnnpack_batch_matrix_multiply_qd8_f32_qc8w(benchmark::State& state, return; } - while (state.KeepRunningBatch(kMinIterations)) { - for (int iter = 0; iter < kMinIterations; iter++) { + int num_iters = FLAGS_benchmark_min_iters; + while (state.KeepRunningBatch(num_iters)) { + for (int iter = 0; iter < num_iters; iter++) { benchmark::utils::WipePthreadpoolL2Caches(state, threadpool); status = xnn_run_operator(op, threadpool); @@ -218,6 +217,7 @@ void xnnpack_batch_matrix_multiply_qd8_f32_qc8w(benchmark::State& state, return; } } + num_iters = 1; } status = xnn_delete_operator(op); @@ -353,13 +353,15 @@ void tflite_batch_matrix_multiply_f32(benchmark::State& state, interpreter->typed_tensor(1) + batch_size * k * n, std::ref(f32rng)); - while (state.KeepRunningBatch(kMinIterations)) { - for (int iter = 0; iter < kMinIterations; iter++) { + int num_iters = FLAGS_benchmark_min_iters; + while (state.KeepRunningBatch(num_iters)) { + for (int iter = 0; iter < num_iters; iter++) { if (interpreter->Invoke() != kTfLiteOk) { state.SkipWithError("failed to invoke TFLite interpreter"); return; } } + num_iters = 1; } const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency(); @@ -376,5 +378,5 @@ void tflite_batch_matrix_multiply_f32(benchmark::State& state, #endif // BENCHMARK_TENSORFLOW_LITE #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/bf16-gemm.cc b/bench/bf16-gemm.cc index 01ef4a8f475..df5c7b2fc5f 100644 --- a/bench/bf16-gemm.cc +++ b/bench/bf16-gemm.cc @@ -241,5 +241,5 @@ static void bf16_gemm(benchmark::State& state, #endif // XNN_ARCH_ARM || XNN_ARCH_ARM64 #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/binary.cc b/bench/binary.cc index 6b77b1bf116..1f71c577ce1 100644 --- a/bench/binary.cc +++ b/bench/binary.cc @@ -11,24 +11,58 @@ #include #include #include +#include #include #include "utils.h" #include "xnnpack.h" -#include "xnnpack/datatype.h" #include "xnnpack/buffer.h" +#include "xnnpack/common.h" +#include "xnnpack/datatype.h" #include "xnnpack/math.h" #include #ifdef BENCHMARK_TENSORFLOW_LITE -#include "flatbuffers/include/flatbuffers/flatbuffers.h" +#include "flatbuffers/include/flatbuffers/buffer.h" +#include "flatbuffers/include/flatbuffers/flatbuffer_builder.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/kernels/test_util.h" -#include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" #endif // BENCHMARK_TENSORFLOW_LITE +#ifdef BENCHMARK_TENSORFLOW_LITE +namespace tflite { +// Maps the native C++ types to the corresponding TFLite tensor type enum +// values. 
+template +struct TensorTypeFor; + +#define TFLITE_TENSOR_TYPE_ASSOC(CPP_TYPE, TENSORTYPE_VALUE) \ + template <> \ + struct TensorTypeFor { \ + static constexpr TensorType value = TENSORTYPE_VALUE; \ + }; + +TFLITE_TENSOR_TYPE_ASSOC(bool, TensorType_BOOL); +TFLITE_TENSOR_TYPE_ASSOC(int8_t, TensorType_INT8); +TFLITE_TENSOR_TYPE_ASSOC(int16_t, TensorType_INT16); +TFLITE_TENSOR_TYPE_ASSOC(int32_t, TensorType_INT32); +TFLITE_TENSOR_TYPE_ASSOC(int64_t, TensorType_INT64); +TFLITE_TENSOR_TYPE_ASSOC(uint8_t, TensorType_UINT8); +TFLITE_TENSOR_TYPE_ASSOC(uint16_t, TensorType_UINT16); +TFLITE_TENSOR_TYPE_ASSOC(uint32_t, TensorType_UINT32); +TFLITE_TENSOR_TYPE_ASSOC(uint64_t, TensorType_UINT64); +TFLITE_TENSOR_TYPE_ASSOC(TfLiteFloat16, TensorType_FLOAT16); +TFLITE_TENSOR_TYPE_ASSOC(TfLiteBFloat16, TensorType_BFLOAT16); +TFLITE_TENSOR_TYPE_ASSOC(float, TensorType_FLOAT32); +TFLITE_TENSOR_TYPE_ASSOC(double, TensorType_FLOAT64); +TFLITE_TENSOR_TYPE_ASSOC(std::string, TensorType_STRING); + +#undef TFLITE_TENSOR_TYPE_ASSOC +}; // namespace tflite +#endif // BENCHMARK_TENSORFLOW_LITE + void init_params(xnn_binary_operator op_type, xnn_datatype datatype, xnn_binary_params& params, xnn_quantization_params& input_quantization, @@ -280,13 +314,14 @@ static void benchmark_tflite_binary_operator( state.counters["cpufreq"] = cpu_frequency; } - state.counters["elements"] = benchmark::Counter( - uint64_t(state.iterations()) * batch_size, benchmark::Counter::kIsRate); + state.counters["elements"] = + benchmark::Counter(static_cast(state.iterations()) * batch_size, + benchmark::Counter::kIsRate); const size_t bytes_per_iteration = 2 * batch_size * sizeof(float); - state.counters["bytes"] = - benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, - benchmark::Counter::kIsRate); + state.counters["bytes"] = benchmark::Counter( + static_cast(state.iterations()) * bytes_per_iteration, + benchmark::Counter::kIsRate); interpreter.reset(); } @@ -393,5 +428,5 @@ BENCHMARK_OP_INTEGRAL(shift_right_logical); BENCHMARK_OP_INTEGRAL(shift_right_arithmetic); #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/convolution.cc b/bench/convolution.cc index f6670b91351..ae39d15a83d 100644 --- a/bench/convolution.cc +++ b/bench/convolution.cc @@ -1832,5 +1832,5 @@ BENCHMARK_CAPTURE(xnnpack_convolution_qu8, srcnn955, "SRCNN (9-5-5)")->Apply(SRC #endif // BENCHMARK_TENSORFLOW_LITE #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/deconvolution.cc b/bench/deconvolution.cc index 3fb970c3149..3eed1320fc2 100644 --- a/bench/deconvolution.cc +++ b/bench/deconvolution.cc @@ -569,5 +569,5 @@ BENCHMARK_CAPTURE(xnnpack_deconvolution_qu8, espnet, "ESPNet") #endif // BENCHMARK_TENSORFLOW_LITE #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f16-conv-hwc2chw.cc b/bench/f16-conv-hwc2chw.cc index 693871262bf..d5aaa6156d3 100644 --- a/bench/f16-conv-hwc2chw.cc +++ b/bench/f16-conv-hwc2chw.cc @@ -124,5 +124,5 @@ static void f16_conv_hwc2chw(benchmark::State& state, #endif // XNN_ENABLE_ARM_FP16_VECTOR && (XNN_ARCH_ARM || XNN_ARCH_ARM64) #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f16-dwconv.cc b/bench/f16-dwconv.cc index 9cdc270a97c..02a3cdaaf84 100644 --- a/bench/f16-dwconv.cc +++ b/bench/f16-dwconv.cc @@ -777,5 +777,5 @@ static void f16_dwconv(benchmark::State& state, #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 #ifndef 
XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f16-dwconv2d-chw.cc b/bench/f16-dwconv2d-chw.cc index b674c559dc6..2c8c7030681 100644 --- a/bench/f16-dwconv2d-chw.cc +++ b/bench/f16-dwconv2d-chw.cc @@ -486,5 +486,5 @@ static void f16_dwconv2d_chw(benchmark::State& state, #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f16-f32acc-gemm.cc b/bench/f16-f32acc-gemm.cc index a8c34f2062a..a78c0cf44f3 100644 --- a/bench/f16-f32acc-gemm.cc +++ b/bench/f16-f32acc-gemm.cc @@ -157,5 +157,5 @@ static void f16_gemm(benchmark::State& state, #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f16-f32acc-igemm.cc b/bench/f16-f32acc-igemm.cc index e9e14335ba2..58c0da07b74 100644 --- a/bench/f16-f32acc-igemm.cc +++ b/bench/f16-f32acc-igemm.cc @@ -201,5 +201,5 @@ static void f16_igemm(benchmark::State& state, #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f16-f32acc-rdsum.cc b/bench/f16-f32acc-rdsum.cc index 367e9a7e766..9bdca806643 100644 --- a/bench/f16-f32acc-rdsum.cc +++ b/bench/f16-f32acc-rdsum.cc @@ -130,5 +130,5 @@ #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f16-f32acc-rsum.cc b/bench/f16-f32acc-rsum.cc index f087668d615..8651b720aa5 100644 --- a/bench/f16-f32acc-rsum.cc +++ b/bench/f16-f32acc-rsum.cc @@ -190,5 +190,5 @@ #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f16-gemm-minmax.cc b/bench/f16-gemm-minmax.cc index c1834ba6117..7037c4d357b 100644 --- a/bench/f16-gemm-minmax.cc +++ b/bench/f16-gemm-minmax.cc @@ -494,5 +494,5 @@ #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f16-gemm.cc b/bench/f16-gemm.cc index d4191f2c6d0..8def4f6d277 100644 --- a/bench/f16-gemm.cc +++ b/bench/f16-gemm.cc @@ -360,5 +360,5 @@ static void f16_gemm(benchmark::State& state, #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f16-igemm.cc b/bench/f16-igemm.cc index aca68bd2937..6371ec6c024 100644 --- a/bench/f16-igemm.cc +++ b/bench/f16-igemm.cc @@ -364,5 +364,5 @@ static void f16_igemm(benchmark::State& state, #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f16-raddstoreexpminusmax.cc b/bench/f16-raddstoreexpminusmax.cc index 6d899bfb305..bd22f94602c 100644 --- a/bench/f16-raddstoreexpminusmax.cc +++ b/bench/f16-raddstoreexpminusmax.cc @@ -392,5 +392,5 @@ static void f16_raddstoreexpminusmax( #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f16-rmax.cc b/bench/f16-rmax.cc index 8bf625d5785..14b1148db52 100644 --- a/bench/f16-rmax.cc +++ b/bench/f16-rmax.cc @@ -192,5 +192,5 @@ BENCHMARK_CAPTURE(f16_rmax, scalar_u4_acc4, ->UseRealTime(); #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f16-rmin.cc b/bench/f16-rmin.cc index c2890f7fb97..e0f0fdeaf9b 100644 --- a/bench/f16-rmin.cc +++ b/bench/f16-rmin.cc @@ -183,5 +183,5 @@ BENCHMARK_CAPTURE(f16_rmin, scalar_u4_acc4, ->UseRealTime(); #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); 
+XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f16-rminmax.cc b/bench/f16-rminmax.cc index d12a8225e6f..da3799cad03 100644 --- a/bench/f16-rminmax.cc +++ b/bench/f16-rminmax.cc @@ -183,5 +183,5 @@ BENCHMARK_CAPTURE(f16_rminmax, scalar_u4_acc4, ->UseRealTime(); #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f16-rsum.cc b/bench/f16-rsum.cc index 89fcafcf996..d3f68dc0ce4 100644 --- a/bench/f16-rsum.cc +++ b/bench/f16-rsum.cc @@ -120,5 +120,5 @@ #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f16-spmm.cc b/bench/f16-spmm.cc index 3e7719b5c91..c74d5c12fc8 100644 --- a/bench/f16-spmm.cc +++ b/bench/f16-spmm.cc @@ -232,5 +232,5 @@ static void f16_spmm(benchmark::State& state, #endif // XNN_ENABLE_ARM_FP16_VECTOR && (XNN_ARCH_ARM || XNN_ARCH_ARM64) #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f16-vcmul.cc b/bench/f16-vcmul.cc index 26ef075c606..24478c77427 100644 --- a/bench/f16-vcmul.cc +++ b/bench/f16-vcmul.cc @@ -77,5 +77,5 @@ static void f16_vcmul(benchmark::State& state, uint64_t arch_flags, #undef XNN_UKERNEL_WITH_PARAMS #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-bgemm.cc b/bench/f32-bgemm.cc index f078bca6aae..9988ddc9f88 100644 --- a/bench/f32-bgemm.cc +++ b/bench/f32-bgemm.cc @@ -1649,5 +1649,5 @@ BENCHMARK_BGEMM(ruy_st) #endif // BENCHMARK_RUY #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-conv-hwc.cc b/bench/f32-conv-hwc.cc index 86aad770c86..bbe3e3d7df6 100644 --- a/bench/f32-conv-hwc.cc +++ b/bench/f32-conv-hwc.cc @@ -191,5 +191,5 @@ static void f32_conv_hwc(benchmark::State& state, #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-conv-hwc2chw.cc b/bench/f32-conv-hwc2chw.cc index ddeabbf47df..3741e6a0a14 100644 --- a/bench/f32-conv-hwc2chw.cc +++ b/bench/f32-conv-hwc2chw.cc @@ -171,5 +171,5 @@ BENCHMARK_DCONV(f32_conv_hwc2chw_3x3s2p1c3x4__scalar_1x1); #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-dwconv.cc b/bench/f32-dwconv.cc index fa21809f074..3835da4b21c 100644 --- a/bench/f32-dwconv.cc +++ b/bench/f32-dwconv.cc @@ -1754,5 +1754,5 @@ BENCHMARK_DWCONV(f32_dwconv_8f8m9l1c1s1r__scalar) BENCHMARK_DWCONV(f32_dwconv_8f8m9l1c1s1r__scalar_acc2) #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-dwconv2d-chw.cc b/bench/f32-dwconv2d-chw.cc index 598efe15c3d..03abdde3be8 100644 --- a/bench/f32-dwconv2d-chw.cc +++ b/bench/f32-dwconv2d-chw.cc @@ -2606,5 +2606,5 @@ BENCHMARK_DWCONV(dwconv2d_chw_5x5s2p2__scalar_2x1_acc3) BENCHMARK_DWCONV(dwconv2d_chw_5x5s2p2__scalar_3x1_acc2) #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-gemm-goi-minmax.cc b/bench/f32-gemm-goi-minmax.cc index acfa1d80780..df9d9e78c8a 100644 --- a/bench/f32-gemm-goi-minmax.cc +++ b/bench/f32-gemm-goi-minmax.cc @@ -53,5 +53,5 @@ #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-gemm-minmax.cc b/bench/f32-gemm-minmax.cc index f10a4f901f5..3e3a9fddf5b 100644 --- a/bench/f32-gemm-minmax.cc +++ b/bench/f32-gemm-minmax.cc @@ -3435,5 +3435,5 @@ static void f32_gemm_minmax_ukernel_4x4__scalar(benchmark::State& state, const c 
BENCHMARK_GEMM(f32_gemm_minmax_ukernel_4x4__scalar) #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-gemm.cc b/bench/f32-gemm.cc index ba5d3be4803..a8b3190a815 100644 --- a/bench/f32-gemm.cc +++ b/bench/f32-gemm.cc @@ -2269,5 +2269,5 @@ BENCHMARK_GEMM(ruy_st) #endif // BENCHMARK_RUY #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-igemm.cc b/bench/f32-igemm.cc index 8a67ac74811..eed379cb311 100644 --- a/bench/f32-igemm.cc +++ b/bench/f32-igemm.cc @@ -1263,5 +1263,5 @@ BENCHMARK_CONV(f32_igemm_4x4__scalar) #endif // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-im2col-gemm.cc b/bench/f32-im2col-gemm.cc index 554f7ffbde6..08455b77c1d 100644 --- a/bench/f32-im2col-gemm.cc +++ b/bench/f32-im2col-gemm.cc @@ -163,5 +163,5 @@ BENCHMARK_CONV(f32_gemm_4x4__scalar) #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-qc4w-gemm.cc b/bench/f32-qc4w-gemm.cc index cb302351efd..90331778260 100644 --- a/bench/f32-qc4w-gemm.cc +++ b/bench/f32-qc4w-gemm.cc @@ -726,5 +726,5 @@ BENCHMARK_GEMM(f32_qc4w_gemm_2x4__scalar) BENCHMARK_GEMM(f32_qc4w_gemm_4x4__scalar) #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-qc8w-gemm.cc b/bench/f32-qc8w-gemm.cc index 76486933043..4b3b9f3bd37 100644 --- a/bench/f32-qc8w-gemm.cc +++ b/bench/f32-qc8w-gemm.cc @@ -1485,5 +1485,5 @@ BENCHMARK_GEMM(f32_qc8w_gemm_2x4__scalar) BENCHMARK_GEMM(f32_qc8w_gemm_4x4__scalar) #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-raddexpminusmax.cc b/bench/f32-raddexpminusmax.cc index 331e0bc9e40..fb3f83280f9 100644 --- a/bench/f32-raddexpminusmax.cc +++ b/bench/f32-raddexpminusmax.cc @@ -214,5 +214,5 @@ static void CharacteristicArguments(benchmark::internal::Benchmark* b) { #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-raddextexp.cc b/bench/f32-raddextexp.cc index 6e06a0c8e5f..3f87e1a75dc 100644 --- a/bench/f32-raddextexp.cc +++ b/bench/f32-raddextexp.cc @@ -158,5 +158,5 @@ static void CharacteristicArguments(benchmark::internal::Benchmark* b) { #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-raddstoreexpminusmax.cc b/bench/f32-raddstoreexpminusmax.cc index 1c78f9b6732..12c791fcb71 100644 --- a/bench/f32-raddstoreexpminusmax.cc +++ b/bench/f32-raddstoreexpminusmax.cc @@ -494,5 +494,5 @@ BENCHMARK_CAPTURE(f32_raddstoreexpminusmax, scalar_rr2_p5_u4_acc4, ->UseRealTime(); #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-rdsum.cc b/bench/f32-rdsum.cc index 208eed62293..fbc4fcc0a65 100644 --- a/bench/f32-rdsum.cc +++ b/bench/f32-rdsum.cc @@ -200,5 +200,5 @@ BENCHMARK_CAPTURE(f32_rdsum, scalar_c4, #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-rmax.cc b/bench/f32-rmax.cc index f54c2ac8743..1ecd90006ce 100644 --- a/bench/f32-rmax.cc +++ b/bench/f32-rmax.cc @@ -297,5 +297,5 @@ BENCHMARK_CAPTURE(f32_rmax, scalar_u4_acc4, ->UseRealTime(); #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-rmin.cc 
b/bench/f32-rmin.cc index 9a6f5b2048e..a3a34379029 100644 --- a/bench/f32-rmin.cc +++ b/bench/f32-rmin.cc @@ -297,5 +297,5 @@ BENCHMARK_CAPTURE(f32_rmin, scalar_u4_acc4, ->UseRealTime(); #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-rminmax.cc b/bench/f32-rminmax.cc index 4d178165c87..b238db63fc4 100644 --- a/bench/f32-rminmax.cc +++ b/bench/f32-rminmax.cc @@ -298,5 +298,5 @@ BENCHMARK_CAPTURE(f32_rminmax, scalar_u4_acc4, ->UseRealTime(); #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-rsum.cc b/bench/f32-rsum.cc index 02c904dcac3..22a473c39d0 100644 --- a/bench/f32-rsum.cc +++ b/bench/f32-rsum.cc @@ -350,5 +350,5 @@ BENCHMARK_CAPTURE(f32_rsum, scalar_u4_acc4, ->UseRealTime(); #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-softmax.cc b/bench/f32-softmax.cc index b5f822b7193..4a8b50fdbcc 100644 --- a/bench/f32-softmax.cc +++ b/bench/f32-softmax.cc @@ -1,30 +1,27 @@ #include -#include #include -#include +#include +#include #include #include -#include -#include -#ifdef BENCHMARK_INTEL_DNNL -#include -#endif // BENCHMARK_INTEL_DNNL #include "utils.h" - -#include "xnnpack.h" +#include "xnnpack/buffer.h" #include "xnnpack/common.h" #include "xnnpack/microfnptr.h" -#include "xnnpack/microparams-init.h" +#include "xnnpack/microparams.h" #include "xnnpack/raddexpminusmax.h" #include "xnnpack/raddextexp.h" #include "xnnpack/raddstoreexpminusmax.h" -#include "xnnpack/vbinary.h" #include "xnnpack/reduce.h" +#include "xnnpack/vbinary.h" #include "xnnpack/vscaleexpminusmax.h" #include "xnnpack/vscaleextexp.h" -#include "xnnpack/buffer.h" +#include +#ifdef BENCHMARK_INTEL_DNNL +#include +#endif // BENCHMARK_INTEL_DNNL #ifdef BENCHMARK_INTEL_DNNL @@ -33,25 +30,26 @@ #define DNNL_MEMORY_DESC_INIT dnnl_memory_desc_init_by_tag #define DNNL_MEMORY_CREATE(mem, mem_desc, engine, handle) \ dnnl_memory_create(&mem, &mem_desc, engine, handle) -#else // DNNL_VERSION_MAJOR == 3 +#else // DNNL_VERSION_MAJOR == 3 #define DNNL_MEMORY_DESC_INIT dnnl_memory_desc_create_with_tag #define DNNL_MEMORY_CREATE(mem, mem_desc, engine, handle) \ dnnl_memory_create(&mem, mem_desc, engine, handle) #endif // DNNL_VERSION_MAJOR == 2 -static void DNNLSoftArgMax( - benchmark::State& state) -{ +static void DNNLSoftArgMax(benchmark::State& state) { const size_t elements = state.range(0); const size_t cache_line_size_max = 128; - const size_t packed_elements = benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float)); + const size_t packed_elements = + benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float)); std::random_device random_device; auto rng = std::mt19937(random_device()); - auto f32rng = std::bind(std::uniform_real_distribution(-1000.0f, 1000.0f), std::ref(rng)); + auto f32rng = std::bind( + std::uniform_real_distribution(-1000.0f, 1000.0f), std::ref(rng)); - const size_t num_buffers = 1 + - benchmark::utils::DivideRoundUp(benchmark::utils::GetMaxCacheSize(), packed_elements * sizeof(float)); + const size_t num_buffers = 1 + benchmark::utils::DivideRoundUp( + benchmark::utils::GetMaxCacheSize(), + packed_elements * sizeof(float)); xnnpack::Buffer x(elements); xnnpack::Buffer y(packed_elements * num_buffers); @@ -63,38 +61,34 @@ static void DNNLSoftArgMax( return; } - dnnl_dim_t input_output_shape[1] = { static_cast(elements) }; + dnnl_dim_t input_output_shape[1] = {static_cast(elements)}; - dnnl_memory_desc_t 
memory_descriptor = { 0 }; - if (DNNL_MEMORY_DESC_INIT( - &memory_descriptor, 1, input_output_shape, dnnl_f32, dnnl_x) != dnnl_success) - { + dnnl_memory_desc_t memory_descriptor = {0}; + if (DNNL_MEMORY_DESC_INIT(&memory_descriptor, 1, input_output_shape, dnnl_f32, + dnnl_x) != dnnl_success) { state.SkipWithError("failed to create input memory descriptor"); return; } dnnl_memory_t input_memory = nullptr; - if (DNNL_MEMORY_CREATE( - input_memory, memory_descriptor, engine, x.data()) != dnnl_success) - { + if (DNNL_MEMORY_CREATE(input_memory, memory_descriptor, engine, x.data()) != + dnnl_success) { state.SkipWithError("failed to create input memory"); return; } dnnl_memory_t output_memory = nullptr; - if (DNNL_MEMORY_CREATE( - output_memory, memory_descriptor, engine, y.data()) != dnnl_success) - { + if (DNNL_MEMORY_CREATE(output_memory, memory_descriptor, engine, y.data()) != + dnnl_success) { state.SkipWithError("failed to create output memory"); return; } #if DNNL_VERSION_MAJOR == 2 dnnl_softmax_desc_t softmax_forward_descriptor = {}; - if (dnnl_softmax_forward_desc_init( - &softmax_forward_descriptor, dnnl_forward_inference, - &memory_descriptor, 0) != dnnl_success) - { + if (dnnl_softmax_forward_desc_init(&softmax_forward_descriptor, + dnnl_forward_inference, &memory_descriptor, + 0) != dnnl_success) { state.SkipWithError("failed to create SoftMax forward descriptor"); return; } @@ -102,36 +96,36 @@ static void DNNLSoftArgMax( dnnl_primitive_desc_t softmax_primitive_descriptor = nullptr; #if DNNL_VERSION_MAJOR == 2 - if (dnnl_primitive_desc_create( - &softmax_primitive_descriptor, &softmax_forward_descriptor, - nullptr /* primitive attributes */, engine, nullptr /* hint */) != dnnl_success) + if (dnnl_primitive_desc_create(&softmax_primitive_descriptor, + &softmax_forward_descriptor, + nullptr /* primitive attributes */, engine, + nullptr /* hint */) != dnnl_success) { #else // DNNL_VERSION_MAJOR == 3 if (dnnl_softmax_forward_primitive_desc_create( - &softmax_primitive_descriptor, engine, dnnl_forward_inference, - dnnl_softmax_accurate, /*src_desc=*/ memory_descriptor, - /*dst_dsc=*/ memory_descriptor, /*softmax_axis=*/ 0, - /*attr=*/ nullptr) != dnnl_success) + &softmax_primitive_descriptor, engine, dnnl_forward_inference, + dnnl_softmax_accurate, /*src_desc=*/memory_descriptor, + /*dst_dsc=*/memory_descriptor, /*softmax_axis=*/0, + /*attr=*/nullptr) != dnnl_success) { #endif // DNNL_VERSION_MAJOR == 2 - { state.SkipWithError("failed to create SoftMax primitive descriptor"); return; } dnnl_primitive_t softmax_primitive = nullptr; - if (dnnl_primitive_create( - &softmax_primitive, softmax_primitive_descriptor) != dnnl_success) - { + if (dnnl_primitive_create(&softmax_primitive, softmax_primitive_descriptor) != + dnnl_success) { state.SkipWithError("failed to create SoftMax primitive"); return; } dnnl_exec_arg_t softmax_args[2] = { - {DNNL_ARG_SRC, input_memory}, - {DNNL_ARG_DST, output_memory}, + {DNNL_ARG_SRC, input_memory}, + {DNNL_ARG_DST, output_memory}, }; dnnl_stream_t stream = nullptr; - if (dnnl_stream_create(&stream, engine, dnnl_stream_default_flags) != dnnl_success) { + if (dnnl_stream_create(&stream, engine, dnnl_stream_default_flags) != + dnnl_success) { state.SkipWithError("failed to create stream"); return; } @@ -144,16 +138,15 @@ static void DNNLSoftArgMax( } const auto start = std::chrono::high_resolution_clock::now(); - if (dnnl_primitive_execute( - softmax_primitive, stream, 2, softmax_args) != dnnl_success) - { + if (dnnl_primitive_execute(softmax_primitive, 
stream, 2, softmax_args) != + dnnl_success) { state.SkipWithError("failed to execute SoftMax"); return; } const auto end = std::chrono::high_resolution_clock::now(); const auto elapsed_seconds = - std::chrono::duration_cast>(end - start); + std::chrono::duration_cast>(end - start); state.SetIterationTime(elapsed_seconds.count()); } @@ -162,7 +155,8 @@ static void DNNLSoftArgMax( return; } - if (dnnl_primitive_desc_destroy(softmax_primitive_descriptor) != dnnl_success) { + if (dnnl_primitive_desc_destroy(softmax_primitive_descriptor) != + dnnl_success) { state.SkipWithError("failed to destroy SoftMax primitive descriptor"); return; } @@ -182,6 +176,11 @@ static void DNNLSoftArgMax( return; } + if (dnnl_memory_desc_destroy(memory_descriptor) != dnnl_success) { + state.SkipWithError("failed to destroy memory descriptor"); + return; + } + if (dnnl_engine_destroy(engine) != dnnl_success) { state.SkipWithError("failed to destroy engine"); return; @@ -193,37 +192,40 @@ static void DNNLSoftArgMax( } const size_t elements_per_iteration = elements; - state.counters["elements"] = - benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate); + state.counters["elements"] = benchmark::Counter( + static_cast(state.iterations()) * elements_per_iteration, + benchmark::Counter::kIsRate); const size_t bytes_per_iteration = 2 * elements * sizeof(float); - state.counters["bytes"] = - benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate); + state.counters["bytes"] = benchmark::Counter( + static_cast(state.iterations()) * bytes_per_iteration, + benchmark::Counter::kIsRate); } #endif // BENCHMARK_INTEL_DNNL static void ThreePassSoftMaxWithRecomputing( - benchmark::State& state, - xnn_f32_rmax_ukernel_fn rmax, - xnn_init_f32_default_params_fn init_rmax_params, - xnn_f32_raddexpminusmax_ukernel_fn raddexpminusmax, - xnn_f32_vscaleexpminusmax_ukernel_fn vscaleexpminusmax, - benchmark::utils::IsaCheckFunction isa_check = nullptr) -{ + benchmark::State& state, xnn_f32_rmax_ukernel_fn rmax, + xnn_init_f32_default_params_fn init_rmax_params, + xnn_f32_raddexpminusmax_ukernel_fn raddexpminusmax, + xnn_f32_vscaleexpminusmax_ukernel_fn vscaleexpminusmax, + benchmark::utils::IsaCheckFunction isa_check = nullptr) { if (isa_check != nullptr && !isa_check(state)) { return; } const size_t elements = state.range(0); const size_t cache_line_size_max = 128; - const size_t packed_elements = benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float)); + const size_t packed_elements = + benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float)); std::random_device random_device; auto rng = std::mt19937(random_device()); - auto f32rng = std::bind(std::uniform_real_distribution(-1000.0f, 1000.0f), std::ref(rng)); + auto f32rng = std::bind( + std::uniform_real_distribution(-1000.0f, 1000.0f), std::ref(rng)); - const size_t num_buffers = 1 + - benchmark::utils::DivideRoundUp(benchmark::utils::GetMaxCacheSize(), packed_elements * sizeof(float)); + const size_t num_buffers = 1 + benchmark::utils::DivideRoundUp( + benchmark::utils::GetMaxCacheSize(), + packed_elements * sizeof(float)); xnnpack::Buffer x(elements); xnnpack::Buffer y(packed_elements * num_buffers); @@ -248,11 +250,13 @@ static void ThreePassSoftMaxWithRecomputing( rmax(elements * sizeof(float), x.data(), &x_max, &rmax_params); float y_sum; raddexpminusmax(elements * sizeof(float), x.data(), &y_sum, x_max); - vscaleexpminusmax(elements * sizeof(float), x.data(), 
y.data() + packed_elements * buffer_index, x_max, 1.0f / y_sum); + vscaleexpminusmax(elements * sizeof(float), x.data(), + y.data() + packed_elements * buffer_index, x_max, + 1.0f / y_sum); const auto end = std::chrono::high_resolution_clock::now(); const auto elapsed_seconds = - std::chrono::duration_cast>(end - start); + std::chrono::duration_cast>(end - start); state.SetIterationTime(elapsed_seconds.count()); } @@ -262,37 +266,40 @@ static void ThreePassSoftMaxWithRecomputing( } const size_t elements_per_iteration = elements; - state.counters["elements"] = - benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate); + state.counters["elements"] = benchmark::Counter( + static_cast(state.iterations()) * elements_per_iteration, + benchmark::Counter::kIsRate); const size_t bytes_per_iteration = 2 * elements * sizeof(float); - state.counters["bytes"] = - benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate); + state.counters["bytes"] = benchmark::Counter( + static_cast(state.iterations()) * bytes_per_iteration, + benchmark::Counter::kIsRate); } static void ThreePassSoftMaxWithReloading( - benchmark::State& state, - xnn_f32_rmax_ukernel_fn rmax, - xnn_init_f32_default_params_fn init_rmax_params, - xnn_f32_raddstoreexpminusmax_ukernel_fn raddstoreexpminusmax, - xnn_init_f32_expminus_params_fn init_expminus_params, - xnn_f32_vbinary_ukernel_fn vmulc, - benchmark::utils::IsaCheckFunction isa_check = nullptr) -{ + benchmark::State& state, xnn_f32_rmax_ukernel_fn rmax, + xnn_init_f32_default_params_fn init_rmax_params, + xnn_f32_raddstoreexpminusmax_ukernel_fn raddstoreexpminusmax, + xnn_init_f32_expminus_params_fn init_expminus_params, + xnn_f32_vbinary_ukernel_fn vmulc, + benchmark::utils::IsaCheckFunction isa_check = nullptr) { if (isa_check != nullptr && !isa_check(state)) { return; } const size_t elements = state.range(0); const size_t cache_line_size_max = 128; - const size_t packed_elements = benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float)); + const size_t packed_elements = + benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float)); std::random_device random_device; auto rng = std::mt19937(random_device()); - auto f32rng = std::bind(std::uniform_real_distribution(-1000.0f, 1000.0f), std::ref(rng)); + auto f32rng = std::bind( + std::uniform_real_distribution(-1000.0f, 1000.0f), std::ref(rng)); - const size_t num_buffers = 1 + - benchmark::utils::DivideRoundUp(benchmark::utils::GetMaxCacheSize(), packed_elements * sizeof(float)); + const size_t num_buffers = 1 + benchmark::utils::DivideRoundUp( + benchmark::utils::GetMaxCacheSize(), + packed_elements * sizeof(float)); xnnpack::Buffer x(elements); xnnpack::Buffer y(packed_elements * num_buffers); @@ -320,13 +327,16 @@ static void ThreePassSoftMaxWithReloading( float x_max; rmax(elements * sizeof(float), x.data(), &x_max, &rmax_params); float y_sum; - raddstoreexpminusmax(elements * sizeof(float), x.data(), &x_max, y.data() + packed_elements * buffer_index, &y_sum, &expminus_params); + raddstoreexpminusmax(elements * sizeof(float), x.data(), &x_max, + y.data() + packed_elements * buffer_index, &y_sum, + &expminus_params); const float inv_y_sum = 1.0f / y_sum; - vmulc(elements * sizeof(float), y.data() + packed_elements * buffer_index, &inv_y_sum, y.data() + packed_elements * buffer_index, nullptr); + vmulc(elements * sizeof(float), y.data() + packed_elements * buffer_index, + &inv_y_sum, y.data() + packed_elements * 
buffer_index, nullptr); const auto end = std::chrono::high_resolution_clock::now(); const auto elapsed_seconds = - std::chrono::duration_cast>(end - start); + std::chrono::duration_cast>(end - start); state.SetIterationTime(elapsed_seconds.count()); } @@ -336,34 +346,37 @@ static void ThreePassSoftMaxWithReloading( } const size_t elements_per_iteration = elements; - state.counters["elements"] = - benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate); + state.counters["elements"] = benchmark::Counter( + static_cast(state.iterations()) * elements_per_iteration, + benchmark::Counter::kIsRate); const size_t bytes_per_iteration = 2 * elements * sizeof(float); - state.counters["bytes"] = - benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, benchmark::Counter::kIsRate); + state.counters["bytes"] = benchmark::Counter( + static_cast(state.iterations()) * bytes_per_iteration, + benchmark::Counter::kIsRate); } static void TwoPassSoftMax( - benchmark::State& state, - xnn_f32_raddextexp_ukernel_fn raddextexp, - xnn_f32_vscaleextexp_ukernel_fn vscaleextexp, - benchmark::utils::IsaCheckFunction isa_check = nullptr) -{ + benchmark::State& state, xnn_f32_raddextexp_ukernel_fn raddextexp, + xnn_f32_vscaleextexp_ukernel_fn vscaleextexp, + benchmark::utils::IsaCheckFunction isa_check = nullptr) { if (isa_check != nullptr && !isa_check(state)) { return; } const size_t elements = state.range(0); const size_t cache_line_size_max = 128; - const size_t packed_elements = benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float)); + const size_t packed_elements = + benchmark::utils::RoundUp(elements, cache_line_size_max / sizeof(float)); std::random_device random_device; auto rng = std::mt19937(random_device()); - auto f32rng = std::bind(std::uniform_real_distribution(-1000.0f, 1000.0f), std::ref(rng)); + auto f32rng = std::bind( + std::uniform_real_distribution(-1000.0f, 1000.0f), std::ref(rng)); - const size_t num_buffers = 1 + - benchmark::utils::DivideRoundUp(benchmark::utils::GetMaxCacheSize(), packed_elements * sizeof(float)); + const size_t num_buffers = 1 + benchmark::utils::DivideRoundUp( + benchmark::utils::GetMaxCacheSize(), + packed_elements * sizeof(float)); xnnpack::Buffer x(elements); xnnpack::Buffer y(packed_elements * num_buffers); @@ -381,11 +394,13 @@ static void TwoPassSoftMax( const auto start = std::chrono::high_resolution_clock::now(); float scale[2]; raddextexp(elements * sizeof(float), x.data(), scale); - vscaleextexp(elements * sizeof(float), x.data(), y.data() + packed_elements * buffer_index, 1.0f / scale[0], -scale[1]); + vscaleextexp(elements * sizeof(float), x.data(), + y.data() + packed_elements * buffer_index, 1.0f / scale[0], + -scale[1]); const auto end = std::chrono::high_resolution_clock::now(); const auto elapsed_seconds = - std::chrono::duration_cast>(end - start); + std::chrono::duration_cast>(end - start); state.SetIterationTime(elapsed_seconds.count()); } @@ -395,19 +410,21 @@ static void TwoPassSoftMax( } const size_t elements_per_iteration = elements; - state.counters["elements"] = - benchmark::Counter(uint64_t(state.iterations()) * elements_per_iteration, benchmark::Counter::kIsRate); + state.counters["elements"] = benchmark::Counter( + static_cast(state.iterations()) * elements_per_iteration, + benchmark::Counter::kIsRate); const size_t bytes_per_iteration = 2 * elements * sizeof(float); - state.counters["bytes"] = - benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, 
benchmark::Counter::kIsRate); + state.counters["bytes"] = benchmark::Counter( + static_cast(state.iterations()) * bytes_per_iteration, + benchmark::Counter::kIsRate); } static void CharacteristicArguments(benchmark::internal::Benchmark* b) { // Size Iterations Parameters used by Stable Diffusion - b->Arg( 128); // 1 - b->Arg( 154); // 421 - b->Arg( 512); // 20 + b->Arg(128); // 1 + b->Arg(154); // 421 + b->Arg(512); // 20 b->Arg(2048); // 80 b->Arg(8192); // 320 for (int32_t n = 10000; n <= 1000000; n *= 10) { @@ -416,59 +433,69 @@ static void CharacteristicArguments(benchmark::internal::Benchmark* b) { } #ifdef BENCHMARK_INTEL_DNNL - BENCHMARK(DNNLSoftArgMax)->Apply(CharacteristicArguments)->UseManualTime(); +BENCHMARK(DNNLSoftArgMax)->Apply(CharacteristicArguments)->UseManualTime(); #endif #if XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) - BENCHMARK_CAPTURE(TwoPassSoftMax, avx512f_p5_scalef, - xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u144_acc3, - xnn_f32_vscaleextexp_ukernel__avx512f_p5_scalef_u16, - benchmark::utils::CheckAVX512F)->Apply(CharacteristicArguments)->UseManualTime(); - BENCHMARK_CAPTURE(ThreePassSoftMaxWithRecomputing, avx512f_p5_scalef, - xnn_f32_rmax_ukernel__avx512f_u64_acc4, - (xnn_init_f32_default_params_fn) nullptr, - xnn_f32_raddexpminusmax_ukernel__avx512f_p5_scalef_u128_acc4, - xnn_f32_vscaleexpminusmax_ukernel__avx512f_p5_scalef_u16, - benchmark::utils::CheckAVX512F)->Apply(CharacteristicArguments)->UseManualTime(); - BENCHMARK_CAPTURE(ThreePassSoftMaxWithReloading, avx512f_p5_scalef, +BENCHMARK_CAPTURE(TwoPassSoftMax, avx512f_p5_scalef, + xnn_f32_raddextexp_ukernel__avx512f_p5_scalef_u144_acc3, + xnn_f32_vscaleextexp_ukernel__avx512f_p5_scalef_u16, + benchmark::utils::CheckAVX512F) + ->Apply(CharacteristicArguments) + ->UseManualTime(); +BENCHMARK_CAPTURE(ThreePassSoftMaxWithRecomputing, avx512f_p5_scalef, + xnn_f32_rmax_ukernel__avx512f_u64_acc4, + (xnn_init_f32_default_params_fn) nullptr, + xnn_f32_raddexpminusmax_ukernel__avx512f_p5_scalef_u128_acc4, + xnn_f32_vscaleexpminusmax_ukernel__avx512f_p5_scalef_u16, + benchmark::utils::CheckAVX512F) + ->Apply(CharacteristicArguments) + ->UseManualTime(); +BENCHMARK_CAPTURE( + ThreePassSoftMaxWithReloading, avx512f_p5_scalef, xnn_f32_rmax_ukernel__avx512f_u64_acc4, (xnn_init_f32_default_params_fn) nullptr, xnn_f32_raddstoreexpminusmax_ukernel__avx512f_rr1_p5_scalef_u64_acc2, - nullptr, - xnn_f32_vmulc_ukernel__avx512f_u32, - benchmark::utils::CheckAVX512F)->Apply(CharacteristicArguments)->UseManualTime(); + nullptr, xnn_f32_vmulc_ukernel__avx512f_u32, benchmark::utils::CheckAVX512F) + ->Apply(CharacteristicArguments) + ->UseManualTime(); #endif // XNN_ENABLE_AVX512F && (XNN_ARCH_X86 || XNN_ARCH_X86_64) #if XNN_ARCH_X86 || XNN_ARCH_X86_64 - BENCHMARK_CAPTURE(TwoPassSoftMax, avx2_p5, - xnn_f32_raddextexp_ukernel__avx2_p5_u96, - xnn_f32_vscaleextexp_ukernel__avx2_p5_u32, - benchmark::utils::CheckAVX2)->Apply(CharacteristicArguments)->UseManualTime(); - BENCHMARK_CAPTURE(ThreePassSoftMaxWithRecomputing, avx2_p5, - xnn_f32_rmax_ukernel__avx_u32_acc4, - (xnn_init_f32_default_params_fn) nullptr, - xnn_f32_raddexpminusmax_ukernel__avx2_p5_u96, - xnn_f32_vscaleexpminusmax_ukernel__avx2_p5_u24, - benchmark::utils::CheckAVX2)->Apply(CharacteristicArguments)->UseManualTime(); - BENCHMARK_CAPTURE(ThreePassSoftMaxWithReloading, avx2_p5, - xnn_f32_rmax_ukernel__avx_u32_acc4, - (xnn_init_f32_default_params_fn) nullptr, - xnn_f32_raddstoreexpminusmax_ukernel__avx2_rr1_p5_u32_acc2, - nullptr, - 
xnn_f32_vmulc_ukernel__avx_u16, - benchmark::utils::CheckAVX2)->Apply(CharacteristicArguments)->UseManualTime(); +BENCHMARK_CAPTURE(TwoPassSoftMax, avx2_p5, + xnn_f32_raddextexp_ukernel__avx2_p5_u96, + xnn_f32_vscaleextexp_ukernel__avx2_p5_u32, + benchmark::utils::CheckAVX2) + ->Apply(CharacteristicArguments) + ->UseManualTime(); +BENCHMARK_CAPTURE(ThreePassSoftMaxWithRecomputing, avx2_p5, + xnn_f32_rmax_ukernel__avx_u32_acc4, + (xnn_init_f32_default_params_fn) nullptr, + xnn_f32_raddexpminusmax_ukernel__avx2_p5_u96, + xnn_f32_vscaleexpminusmax_ukernel__avx2_p5_u24, + benchmark::utils::CheckAVX2) + ->Apply(CharacteristicArguments) + ->UseManualTime(); +BENCHMARK_CAPTURE(ThreePassSoftMaxWithReloading, avx2_p5, + xnn_f32_rmax_ukernel__avx_u32_acc4, + (xnn_init_f32_default_params_fn) nullptr, + xnn_f32_raddstoreexpminusmax_ukernel__avx2_rr1_p5_u32_acc2, + nullptr, xnn_f32_vmulc_ukernel__avx_u16, + benchmark::utils::CheckAVX2) + ->Apply(CharacteristicArguments) + ->UseManualTime(); #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 #if XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV - BENCHMARK_CAPTURE(ThreePassSoftMaxWithReloading, rvv_p6_rmax_m8_exp_m4_vmulc_m8, - xnn_f32_rmax_ukernel__rvv_u8v, - (xnn_init_f32_default_params_fn) nullptr, - xnn_f32_raddstoreexpminusmax_ukernel__rvv_rr2_p6_u4v, - nullptr, - xnn_f32_vmulc_ukernel__rvv_u8v, - benchmark::utils::CheckRVV)->Apply(CharacteristicArguments)->UseManualTime(); +BENCHMARK_CAPTURE(ThreePassSoftMaxWithReloading, rvv_p6_rmax_m8_exp_m4_vmulc_m8, + xnn_f32_rmax_ukernel__rvv_u8v, + (xnn_init_f32_default_params_fn) nullptr, + xnn_f32_raddstoreexpminusmax_ukernel__rvv_rr2_p6_u4v, nullptr, + xnn_f32_vmulc_ukernel__rvv_u8v, benchmark::utils::CheckRVV) + ->Apply(CharacteristicArguments) + ->UseManualTime(); #endif // XNN_ENABLE_RISCV_VECTOR && XNN_ARCH_RISCV #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-spmm.cc b/bench/f32-spmm.cc index f802e9d3ad2..7598d326be0 100644 --- a/bench/f32-spmm.cc +++ b/bench/f32-spmm.cc @@ -1762,5 +1762,5 @@ static void f32_spmm_minmax_ukernel_8x4__scalar(benchmark::State& state, const c BENCHMARK_SPMM(f32_spmm_minmax_ukernel_8x4__scalar) #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-vcmul.cc b/bench/f32-vcmul.cc index 37983b27e40..959b223faa8 100644 --- a/bench/f32-vcmul.cc +++ b/bench/f32-vcmul.cc @@ -75,5 +75,5 @@ static void f32_vcmul(benchmark::State& state, uint64_t arch_flags, #undef XNN_UKERNEL_WITH_PARAMS #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-vscaleexpminusmax.cc b/bench/f32-vscaleexpminusmax.cc index 8fcb0a61acb..b670a5aa6c1 100644 --- a/bench/f32-vscaleexpminusmax.cc +++ b/bench/f32-vscaleexpminusmax.cc @@ -127,5 +127,5 @@ static void CharacteristicArguments(benchmark::internal::Benchmark* b) { #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/f32-vscaleextexp.cc b/bench/f32-vscaleextexp.cc index 5633f66a3d5..06982b9e062 100644 --- a/bench/f32-vscaleextexp.cc +++ b/bench/f32-vscaleextexp.cc @@ -117,5 +117,5 @@ static void CharacteristicArguments(benchmark::internal::Benchmark* b) { #endif // XNN_ARCH_X86 || XNN_ARCH_X86_64 #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/fully-connected.cc b/bench/fully-connected.cc index 9e6c67dd1ac..0e69fbdf5e0 100644 --- a/bench/fully-connected.cc +++ 
b/bench/fully-connected.cc @@ -227,5 +227,5 @@ void xnnpack_dynamic_fully_connected_f32(benchmark::State& state, const char* ne } #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/max-pooling.cc b/bench/max-pooling.cc index 9e3ee9117c4..2e7deedc9e8 100644 --- a/bench/max-pooling.cc +++ b/bench/max-pooling.cc @@ -249,5 +249,5 @@ BENCHMARK_CAPTURE(max_pooling_u8, squeezenet_v11, "SqueezeNet v1.1")->Apply(Sque BENCHMARK_CAPTURE(max_pooling_u8, vgg, "VGG")->Apply(VGG); #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/models/benchmark.cc b/bench/models/benchmark.cc index c423a21ee69..858d621fa5d 100644 --- a/bench/models/benchmark.cc +++ b/bench/models/benchmark.cc @@ -5,14 +5,12 @@ #include -#include #include #include #include #include #include #include -#include #include #include @@ -23,9 +21,6 @@ #include "xnnpack/subgraph.h" #include "pthreadpool.h" -int FLAGS_num_threads = 1; -uint32_t FLAGS_xnn_runtime_flags = 0; - struct ModelRuntime { std::unique_ptr model; pthreadpool_t threadpool = nullptr; @@ -116,12 +111,17 @@ static void BenchmarkInvoke(benchmark::State& state, return; } - for (auto _ : state) { - benchmark::utils::WipePthreadpoolL2Caches(state, model_runtime.threadpool); - if (!model_runtime.Invoke()) { - state.SkipWithError("failed to invoke runtime"); - return; + int num_iters = FLAGS_benchmark_min_iters; + while (state.KeepRunningBatch(num_iters)) { + for (int iter = 0; iter < num_iters; iter++) { + benchmark::utils::WipePthreadpoolL2Caches(state, + model_runtime.threadpool); + if (!model_runtime.Invoke()) { + state.SkipWithError("failed to invoke runtime"); + return; + } } + num_iters = 1; } const uint64_t cpu_frequency = benchmark::utils::GetCurrentCpuFrequency(); @@ -132,8 +132,9 @@ static void BenchmarkInvoke(benchmark::State& state, static void FP32Attention(benchmark::State& state) { BenchmarkInvoke(state, [&state]() { - return models::FP32Attention(state.range(0), state.range(1), state.range(2), - state.range(3), state.range(4)); + return models::FP32Attention(FLAGS_batch_size, state.range(0), + state.range(1), state.range(2), + state.range(3)); }); } @@ -141,9 +142,9 @@ static void FP16Attention(benchmark::State& state) { BenchmarkInvoke( state, [&state]() { - return models::FP32Attention(state.range(0), state.range(1), - state.range(2), state.range(3), - state.range(4)); + return models::FP32Attention(FLAGS_batch_size, state.range(0), + state.range(1), state.range(2), + state.range(3)); }, XNN_FLAG_FORCE_FP16_INFERENCE); } @@ -186,13 +187,11 @@ static void FP16MobileNetV3Small(benchmark::State& state) { static void QD8Attention(benchmark::State& state) { models::QD8AttentionWeights weights; - BenchmarkInvoke( - state, - [&state, &weights]() { - return models::QD8Attention(state.range(0), state.range(1), - state.range(2), state.range(3), - state.range(4), weights); - }); + BenchmarkInvoke(state, [&state, &weights]() { + return models::QD8Attention(FLAGS_batch_size, state.range(0), + state.range(1), state.range(2), state.range(3), + weights); + }); } static void QS8MobileNetV2(benchmark::State& state) { @@ -200,16 +199,16 @@ static void QS8MobileNetV2(benchmark::State& state) { } static void AttentionArguments(benchmark::internal::Benchmark* b) { - b->ArgNames({"B", "T", "H", "N", "S"}); - b->Args({1, 16, 25, 24, 4}); - b->Args({1, 1536, 128, 12, 18}); - b->Args({1, 1024, 256, 4, 46}); - b->Args({1, 1792, 256, 8, 36}); - b->Args({1, 1536, 256, 6, 22}); - b->Args({1, 
2048, 256, 8, 18}); - b->Args({1, 3072, 256, 16, 28}); - b->Args({1, 2304, 256, 8, 26}); - b->Args({1, 2048, 64, 32, 24}); + b->ArgNames({"T", "H", "N", "S"}); + b->Args({16, 25, 24, 4}); + b->Args({1536, 128, 12, 18}); + b->Args({1024, 256, 4, 46}); + b->Args({1792, 256, 8, 36}); + b->Args({1536, 256, 6, 22}); + b->Args({2048, 256, 8, 18}); + b->Args({3072, 256, 16, 28}); + b->Args({2304, 256, 8, 26}); + b->Args({2048, 64, 32, 24}); } BENCHMARK(FP32Attention) @@ -239,46 +238,4 @@ BENCHMARK(QD8Attention) BENCHMARK(QS8MobileNetV2)->Unit(benchmark::kMicrosecond)->UseRealTime(); -int ProcessArgs(int& argc, char**& argv) { - for (int i = 1; i < argc;) { - if (strncmp(argv[i], "--num_threads=", 14) == 0) { - FLAGS_num_threads = atoi(argv[i] + 14); - if (FLAGS_num_threads <= 0) { - std::cerr << "Invalid --num_threads: " << FLAGS_num_threads << "\n"; - return 1; - } - std::copy(argv + i + 1, argv + argc, argv + i); - argc -= 1; - } else if (strncmp(argv[i], "--xnn_runtime_flags=", 20) == 0) { - const char* v = argv[i] + 20; - if (strlen(v) > 2 && strncmp(v, "0x", 2) == 0) { - FLAGS_xnn_runtime_flags = strtoul(v + 2, nullptr, 16); - } else { - FLAGS_xnn_runtime_flags = strtoul(v, nullptr, 10); - } - std::copy(argv + i + 1, argv + argc, argv + i); - argc -= 1; - } else { - ++i; - } - } - return 0; -} - -#ifdef BENCHMARK_ARGS_BOTTLENECK -// We are provided with a main that will call this function -extern "C" { -int BenchmarkArgBottleneck(int& argc, char**& argv) { - return ProcessArgs(argc, argv); -} -} -#else -int main(int argc, char** argv) { - ::benchmark::Initialize(&argc, argv); - int status = ProcessArgs(argc, argv); - if (status != 0) return status; - if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return 1; - ::benchmark::RunSpecifiedBenchmarks(); -} -#endif - +XNN_BENCHMARK_MAIN(); diff --git a/bench/prelu.cc b/bench/prelu.cc index 8da165cb695..6ed33855eb4 100644 --- a/bench/prelu.cc +++ b/bench/prelu.cc @@ -256,5 +256,5 @@ BENCHMARK_CAPTURE(xnnpack_prelu_f32, imagenet, "ImageNet 224x224")->Apply(ImageN #endif // BENCHMARK_TENSORFLOW_LITE #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qd8-f16-qb4w-gemm.cc b/bench/qd8-f16-qb4w-gemm.cc index ee0ece7a993..a8e4838bdc2 100644 --- a/bench/qd8-f16-qb4w-gemm.cc +++ b/bench/qd8-f16-qb4w-gemm.cc @@ -659,5 +659,5 @@ BENCHMARK_GEMM_BL(qd8_f16_qb4w_gemm_minmax_ukernel_4x4__scalar) #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qd8-f16-qc4w-gemm.cc b/bench/qd8-f16-qc4w-gemm.cc index 916d129ee52..0b0b75e8227 100644 --- a/bench/qd8-f16-qc4w-gemm.cc +++ b/bench/qd8-f16-qc4w-gemm.cc @@ -1606,5 +1606,5 @@ #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qd8-f16-qc8w-gemm.cc b/bench/qd8-f16-qc8w-gemm.cc index c46f1602263..0c322383d48 100644 --- a/bench/qd8-f16-qc8w-gemm.cc +++ b/bench/qd8-f16-qc8w-gemm.cc @@ -988,5 +988,5 @@ #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qd8-f32-qb4w-gemm.cc b/bench/qd8-f32-qb4w-gemm.cc index 175d9b816c3..9f6fa43fe4b 100644 --- a/bench/qd8-f32-qb4w-gemm.cc +++ b/bench/qd8-f32-qb4w-gemm.cc @@ -1281,5 +1281,5 @@ static void qd8_f32_qb4w_gemm_minmax_ukernel_4x4__scalar(benchmark::State& state BENCHMARK_GEMM_BL(qd8_f32_qb4w_gemm_minmax_ukernel_4x4__scalar) #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qd8-f32-qc4w-gemm.cc b/bench/qd8-f32-qc4w-gemm.cc 
index 9066e099c1c..b329316f8e0 100644 --- a/bench/qd8-f32-qc4w-gemm.cc +++ b/bench/qd8-f32-qc4w-gemm.cc @@ -3673,5 +3673,5 @@ static void qd8_f32_qc4w_gemm_minmax_ukernel_4x4__scalar(benchmark::State& state BENCHMARK_GEMM(qd8_f32_qc4w_gemm_minmax_ukernel_4x4__scalar) #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qd8-f32-qc8w-gemm.cc b/bench/qd8-f32-qc8w-gemm.cc index f6864d05cfe..5cc50b755e5 100644 --- a/bench/qd8-f32-qc8w-gemm.cc +++ b/bench/qd8-f32-qc8w-gemm.cc @@ -3576,5 +3576,5 @@ static void qd8_f32_qc8w_gemm_minmax_ukernel_4x4__scalar(benchmark::State& state BENCHMARK_GEMM(qd8_f32_qc8w_gemm_minmax_ukernel_4x4__scalar) #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qp8-f32-qb4w-gemm.cc b/bench/qp8-f32-qb4w-gemm.cc index 71c64196e4e..b449a2390d6 100644 --- a/bench/qp8-f32-qb4w-gemm.cc +++ b/bench/qp8-f32-qb4w-gemm.cc @@ -95,5 +95,5 @@ #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qp8-f32-qc4w-gemm.cc b/bench/qp8-f32-qc4w-gemm.cc index cd5af5412eb..ecc56d85fe6 100644 --- a/bench/qp8-f32-qc4w-gemm.cc +++ b/bench/qp8-f32-qc4w-gemm.cc @@ -108,5 +108,5 @@ #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qp8-f32-qc8w-gemm.cc b/bench/qp8-f32-qc8w-gemm.cc index 1d970a7635d..aff2ec7fce2 100644 --- a/bench/qp8-f32-qc8w-gemm.cc +++ b/bench/qp8-f32-qc8w-gemm.cc @@ -82,5 +82,5 @@ #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qs8-dwconv.cc b/bench/qs8-dwconv.cc index 3d862732f7d..dad2481c0cb 100644 --- a/bench/qs8-dwconv.cc +++ b/bench/qs8-dwconv.cc @@ -1724,5 +1724,5 @@ BENCHMARK_DWCONV(qs8_dwconv_9p4c__scalar_lrintf); #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qs8-gemm.cc b/bench/qs8-gemm.cc index 50e436e06a3..3e2fa7d6ad8 100644 --- a/bench/qs8-gemm.cc +++ b/bench/qs8-gemm.cc @@ -122,5 +122,5 @@ BENCHMARK_GEMM(ruy_st) #endif // BENCHMARK_RUY #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qs8-packw.cc b/bench/qs8-packw.cc index e4d8ee97cbb..5e546965ab8 100644 --- a/bench/qs8-packw.cc +++ b/bench/qs8-packw.cc @@ -38,6 +38,6 @@ BENCHMARK_CAPTURE_BGEMM(qs8_gio_packw, ukernel##_, ukernel, arch_flags, nr, kr, #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qs8-qc4w-packw.cc b/bench/qs8-qc4w-packw.cc index d2b90eb754b..8c8806f917a 100644 --- a/bench/qs8-qc4w-packw.cc +++ b/bench/qs8-qc4w-packw.cc @@ -28,6 +28,6 @@ BENCHMARK_CAPTURE_BGEMM(qs8_qc4w_packw, ukernel##_, ukernel, arch_flags, nr, kr, #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qs8-qc8w-gemm-fp32.cc b/bench/qs8-qc8w-gemm-fp32.cc index 795f8393620..7b32de8b225 100644 --- a/bench/qs8-qc8w-gemm-fp32.cc +++ b/bench/qs8-qc8w-gemm-fp32.cc @@ -4649,5 +4649,5 @@ static void qs8_qc8w_gemm_minmax_fp32_ukernel_4x4__scalar_lrintf(benchmark::Stat BENCHMARK_GEMM(qs8_qc8w_gemm_minmax_fp32_ukernel_4x4__scalar_lrintf) #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qs8-rdsum.cc b/bench/qs8-rdsum.cc index 0a136758c59..4fbfd6cd235 100644 --- a/bench/qs8-rdsum.cc +++ b/bench/qs8-rdsum.cc @@ -153,5 +153,5 @@ BENCHMARK_CAPTURE(qs8_rdsum, scalar_c4, #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); 
+XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qs8-rsum.cc b/bench/qs8-rsum.cc index 2651d3625ef..5d7d47f7eef 100644 --- a/bench/qs8-rsum.cc +++ b/bench/qs8-rsum.cc @@ -470,5 +470,5 @@ BENCHMARK_CAPTURE(qs8_rsum, scalar_u4, #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qu8-gemm-fp32.cc b/bench/qu8-gemm-fp32.cc index 67f1ce5f5c2..d8fc6c8b287 100644 --- a/bench/qu8-gemm-fp32.cc +++ b/bench/qu8-gemm-fp32.cc @@ -1636,5 +1636,5 @@ static void qu8_gemm_minmax_fp32_ukernel_4x4__scalar_lrintf(benchmark::State& st BENCHMARK_GEMM(qu8_gemm_minmax_fp32_ukernel_4x4__scalar_lrintf) #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qu8-gemm-rndnu.cc b/bench/qu8-gemm-rndnu.cc index 1f2dab3d8a9..e965dc87d87 100644 --- a/bench/qu8-gemm-rndnu.cc +++ b/bench/qu8-gemm-rndnu.cc @@ -403,5 +403,5 @@ static void qu8_gemm_minmax_rndnu_ukernel_4x4__scalar(benchmark::State& state, c BENCHMARK_GEMM(qu8_gemm_minmax_rndnu_ukernel_4x4__scalar) #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qu8-gemm.cc b/bench/qu8-gemm.cc index 2ecea2ff721..bf5e64f894e 100644 --- a/bench/qu8-gemm.cc +++ b/bench/qu8-gemm.cc @@ -1328,5 +1328,5 @@ BENCHMARK_GEMM(gemmlowp_st) #endif // BENCHMARK_GEMMLOWP #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qu8-rdsum.cc b/bench/qu8-rdsum.cc index 747117bcc11..bd4e9306273 100644 --- a/bench/qu8-rdsum.cc +++ b/bench/qu8-rdsum.cc @@ -113,5 +113,5 @@ BENCHMARK_CAPTURE(qu8_rdsum, scalar_c4, #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/qu8-rsum.cc b/bench/qu8-rsum.cc index 54982589776..bc2d53e93d9 100644 --- a/bench/qu8-rsum.cc +++ b/bench/qu8-rsum.cc @@ -190,5 +190,5 @@ BENCHMARK_CAPTURE(qu8_rsum, scalar_u4, #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/scaled-dot-product-attention.cc b/bench/scaled-dot-product-attention.cc index 225d166bdc5..fbdb83915fb 100644 --- a/bench/scaled-dot-product-attention.cc +++ b/bench/scaled-dot-product-attention.cc @@ -401,5 +401,5 @@ BENCHMARK_CAPTURE(xnnpack_multihead_scaled_dot_product_attention_cap_tanh_f32, b BENCHMARK_CAPTURE(xnnpack_multihead_scaled_batch_matrix_multiply_cap_tanh_f32, bert, "BERT")->Apply(Bert)->UseRealTime(); #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/softmax.cc b/bench/softmax.cc index 3dc57217889..774e5b1c554 100644 --- a/bench/softmax.cc +++ b/bench/softmax.cc @@ -400,5 +400,5 @@ BENCHMARK(xnnpack_softmax_qu8)->Apply(CharacteristicArguments)->UseRealTime(); #endif // BENCHMARK_TENSORFLOW_LITE #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/unary.cc b/bench/unary.cc index a8536340a00..3144d14e8c9 100644 --- a/bench/unary.cc +++ b/bench/unary.cc @@ -11,24 +11,58 @@ #include #include #include +#include #include #include "utils.h" #include "xnnpack.h" -#include "xnnpack/datatype.h" #include "xnnpack/buffer.h" +#include "xnnpack/common.h" +#include "xnnpack/datatype.h" #include "xnnpack/math.h" #include #ifdef BENCHMARK_TENSORFLOW_LITE -#include "flatbuffers/include/flatbuffers/flatbuffers.h" +#include "flatbuffers/include/flatbuffers/buffer.h" +#include "flatbuffers/include/flatbuffers/flatbuffer_builder.h" +#include "tensorflow/lite/core/interpreter_builder.h" #include "tensorflow/lite/interpreter.h" #include
"tensorflow/lite/kernels/register.h" -#include "tensorflow/lite/kernels/test_util.h" -#include "tensorflow/lite/model.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/version.h" #endif // BENCHMARK_TENSORFLOW_LITE +#ifdef BENCHMARK_TENSORFLOW_LITE +namespace tflite { +// Maps the native C++ types to the corresponding TFLite tensor type enum +// values. +template <typename T> +struct TensorTypeFor; + +#define TFLITE_TENSOR_TYPE_ASSOC(CPP_TYPE, TENSORTYPE_VALUE) \ + template <> \ + struct TensorTypeFor<CPP_TYPE> { \ + static constexpr TensorType value = TENSORTYPE_VALUE; \ + }; + +TFLITE_TENSOR_TYPE_ASSOC(bool, TensorType_BOOL); +TFLITE_TENSOR_TYPE_ASSOC(int8_t, TensorType_INT8); +TFLITE_TENSOR_TYPE_ASSOC(int16_t, TensorType_INT16); +TFLITE_TENSOR_TYPE_ASSOC(int32_t, TensorType_INT32); +TFLITE_TENSOR_TYPE_ASSOC(int64_t, TensorType_INT64); +TFLITE_TENSOR_TYPE_ASSOC(uint8_t, TensorType_UINT8); +TFLITE_TENSOR_TYPE_ASSOC(uint16_t, TensorType_UINT16); +TFLITE_TENSOR_TYPE_ASSOC(uint32_t, TensorType_UINT32); +TFLITE_TENSOR_TYPE_ASSOC(uint64_t, TensorType_UINT64); +TFLITE_TENSOR_TYPE_ASSOC(TfLiteFloat16, TensorType_FLOAT16); +TFLITE_TENSOR_TYPE_ASSOC(TfLiteBFloat16, TensorType_BFLOAT16); +TFLITE_TENSOR_TYPE_ASSOC(float, TensorType_FLOAT32); +TFLITE_TENSOR_TYPE_ASSOC(double, TensorType_FLOAT64); +TFLITE_TENSOR_TYPE_ASSOC(std::string, TensorType_STRING); + +#undef TFLITE_TENSOR_TYPE_ASSOC +}; // namespace tflite +#endif // BENCHMARK_TENSORFLOW_LITE + void init_params(xnn_unary_operator op, xnn_datatype in_type, xnn_datatype out_type, xnn_unary_params& params, xnn_quantization_params& input_quantization, @@ -343,13 +377,14 @@ static void benchmark_tflite_unary_operator( state.counters["cpufreq"] = cpu_frequency; } - state.counters["elements"] = benchmark::Counter( - uint64_t(state.iterations()) * batch_size, benchmark::Counter::kIsRate); + state.counters["elements"] = + benchmark::Counter(static_cast<uint64_t>(state.iterations()) * batch_size, + benchmark::Counter::kIsRate); const size_t bytes_per_iteration = 2 * batch_size * sizeof(float); - state.counters["bytes"] = - benchmark::Counter(uint64_t(state.iterations()) * bytes_per_iteration, - benchmark::Counter::kIsRate); + state.counters["bytes"] = benchmark::Counter( + static_cast<uint64_t>(state.iterations()) * bytes_per_iteration, + benchmark::Counter::kIsRate); interpreter.reset(); } @@ -486,9 +521,9 @@ BENCHMARK_OP(cube_root); BENCHMARK_OP(cosine); BENCHMARK_OP(sine); // Missing in TFLite? -//BENCHMARK_OP(count_leading_zeros); -//BENCHMARK_OP(bitwise_not); -//BENCHMARK_OP(popcount); +// BENCHMARK_OP(count_leading_zeros); +// BENCHMARK_OP(bitwise_not); +// BENCHMARK_OP(popcount); BENCHMARK_OP(sign); BENCHMARK_CONVERT(qs8_qs8, xnnpack::quantized<int8_t>, @@ -526,5 +561,5 @@ BENCHMARK_CONVERT(f32_bf16, float, xnn_bfloat16); // BENCHMARK_CONVERT(f32_f32, float, float); #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/utils.cc b/bench/utils.cc index 61e4f862c9d..a5dd6f6bbc5 100644 --- a/bench/utils.cc +++ b/bench/utils.cc @@ -3,33 +3,40 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree.
+#include "utils.h" + +#include #include #include #include #include +#include #include -#include "xnnpack/common.h" -#include -#include "pthreadpool.h" - #ifdef __linux__ - #include +#include #endif #if defined(__ANDROID__) || defined(_WIN32) || defined(__CYGWIN__) - #include +#include #endif #if defined(__SSE__) || defined(__x86_64__) - #include +#include #endif #if XNN_ENABLE_CPUINFO - #include +#include #endif // XNN_ENABLE_CPUINFO +#include "xnnpack/common.h" #include "xnnpack/hardware-config.h" +#include +#include "pthreadpool.h" -#include "utils.h" +// Common flags for all benchmarks. +int FLAGS_num_threads = 1; +int FLAGS_batch_size = 1; +uint32_t FLAGS_xnn_runtime_flags = 0; +uint32_t FLAGS_benchmark_min_iters = 1; namespace benchmark { namespace utils { @@ -88,6 +95,49 @@ void PthreadpoolClearL2Cache(void* context, size_t id) { }; // namespace +int ProcessArgs(int& argc, char**& argv) { + for (int i = 1; i < argc;) { + if (strncmp(argv[i], "--num_threads=", 14) == 0) { + FLAGS_num_threads = atoi(argv[i] + 14); + if (FLAGS_num_threads <= 0) { + std::cerr << "Invalid --num_threads: " << FLAGS_num_threads << "\n"; + return 1; + } + std::copy(argv + i + 1, argv + argc, argv + i); + argc -= 1; + } else if (strncmp(argv[i], "--batch_size=", 13) == 0) { + FLAGS_batch_size = atoi(argv[i] + 13); + if (FLAGS_batch_size <= 0) { + std::cerr << "Invalid --batch_size: " << FLAGS_batch_size << "\n"; + return 1; + } + std::copy(argv + i + 1, argv + argc, argv + i); + argc -= 1; + } else if (strncmp(argv[i], "--xnn_runtime_flags=", 20) == 0) { + const char* v = argv[i] + 20; + if (strlen(v) > 2 && strncmp(v, "0x", 2) == 0) { + FLAGS_xnn_runtime_flags = strtoul(v + 2, nullptr, 16); + } else { + FLAGS_xnn_runtime_flags = strtoul(v, nullptr, 10); + } + std::copy(argv + i + 1, argv + argc, argv + i); + argc -= 1; + } else if (strncmp(argv[i], "--benchmark_min_iters=", 22) == 0) { + FLAGS_benchmark_min_iters = atoi(argv[i] + 22); + if (FLAGS_benchmark_min_iters <= 0) { + std::cerr << "Invalid --benchmark_min_iters: " << FLAGS_benchmark_min_iters << "\n"; + return 1; + } + std::copy(argv + i + 1, argv + argc, argv + i); + argc -= 1; + } else { + ++i; + } + } + // InitGoogle(...); + return 0; +} + uint32_t PrefetchToL1(const void* ptr, size_t size) { uint32_t step = 16; #if XNN_ENABLE_CPUINFO @@ -154,7 +204,7 @@ void DisableDenormals() { #endif } -// Return clockrate in Hz +// Return clock rate in Hz. uint64_t GetCurrentCpuFrequency() { #ifdef __linux__ int freq = 0; diff --git a/bench/utils.h b/bench/utils.h index 168aeecd828..cdf676734bf 100644 --- a/bench/utils.h +++ b/bench/utils.h @@ -8,14 +8,42 @@ #include #include #include +#include #include "xnnpack/common.h" #include #include "pthreadpool.h" +#ifdef BENCHMARK_ARGS_BOTTLENECK +#define XNN_BENCHMARK_MAIN() \ + extern "C" { \ + int BenchmarkArgBottleneck(int& argc, char**& argv) { \ + return benchmark::utils::ProcessArgs(argc, argv); \ + } \ + } +#else +#define XNN_BENCHMARK_MAIN() \ + int main(int argc, char** argv) { \ + ::benchmark::Initialize(&argc, argv); \ + int status = benchmark::utils::ProcessArgs(argc, argv); \ + if (status != 0) return status; \ + if (::benchmark::ReportUnrecognizedArguments(argc, argv)) return 1; \ + ::benchmark::RunSpecifiedBenchmarks(); \ + } \ + int main(int, char**) +#endif // BENCHMARK_ARGS_BOTTLENECK + +// Common flags for all benchmarks. 
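+// These are parsed by benchmark::utils::ProcessArgs(), which strips them from argv before Google Benchmark reports unrecognized arguments.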
+extern int FLAGS_num_threads; +extern int FLAGS_batch_size; +extern uint32_t FLAGS_xnn_runtime_flags; +extern uint32_t FLAGS_benchmark_min_iters; + namespace benchmark { namespace utils { +int ProcessArgs(int& argc, char**& argv); + uint32_t WipeCache(); uint32_t PrefetchToL1(const void* ptr, size_t size); diff --git a/bench/vbinary.cc b/bench/vbinary.cc index 06731153fdb..17d29b72ec4 100644 --- a/bench/vbinary.cc +++ b/bench/vbinary.cc @@ -212,5 +212,5 @@ static void vbinary(benchmark::State& state, uint64_t arch_flags, #undef XNN_UKERNEL_WITH_PARAMS #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/vunary.cc b/bench/vunary.cc index ba679c30586..f17ae84fca2 100644 --- a/bench/vunary.cc +++ b/bench/vunary.cc @@ -297,5 +297,5 @@ void vlrelu(benchmark::State& state, uint64_t arch_flags, #undef XNN_CVT_UKERNEL_WITH_PARAMS #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/x16-packw.cc b/bench/x16-packw.cc index a2b699bc3fc..c4506fa5196 100644 --- a/bench/x16-packw.cc +++ b/bench/x16-packw.cc @@ -27,6 +27,6 @@ BENCHMARK_CAPTURE_BGEMM(x16_packw, ukernel##_, ukernel, arch_flags, nr, kr, sr); #undef XNN_UKERNEL #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/x32-packw.cc b/bench/x32-packw.cc index 46c387a0fc4..fec2117202b 100644 --- a/bench/x32-packw.cc +++ b/bench/x32-packw.cc @@ -38,6 +38,6 @@ BENCHMARK_CAPTURE_BGEMM(x32_gio_packw, ukernel##_, ukernel, arch_flags, nr, kr, #undef XNN_UKERNEL #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/x8-lut.cc b/bench/x8-lut.cc index 5cd9bafadc1..e15f04857b0 100644 --- a/bench/x8-lut.cc +++ b/bench/x8-lut.cc @@ -235,5 +235,5 @@ BENCHMARK_CAPTURE(x8_lut, scalar_u16, ->UseRealTime(); #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/x8-packq.cc b/bench/x8-packq.cc index b69755239b3..3ce920b33f5 100644 --- a/bench/x8-packq.cc +++ b/bench/x8-packq.cc @@ -36,6 +36,6 @@ BENCHMARK_CAPTURE_BGEMM(x8_packq, ukernel##_mr4_kr4_, ukernel, arch_flags, /*mr= #undef XNN_UKERNEL #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/x8-packw.cc b/bench/x8-packw.cc index 013a197d50b..0fe894e98eb 100644 --- a/bench/x8-packw.cc +++ b/bench/x8-packw.cc @@ -36,6 +36,6 @@ BENCHMARK_CAPTURE_BGEMM(x8_gio_packw, ukernel##_, ukernel, arch_flags, nr, kr, s #undef XNN_UKERNEL #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/xN-transposec.cc b/bench/xN-transposec.cc index a642dc0b38f..3be27f70042 100644 --- a/bench/xN-transposec.cc +++ b/bench/xN-transposec.cc @@ -73,5 +73,5 @@ static void BenchmarkKernelSize(benchmark::internal::Benchmark* b) #undef XNN_TRANSPOSE_UKERNEL #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/bench/xx-transposev.cc b/bench/xx-transposev.cc index 6f5407abb1e..9209ab14511 100644 --- a/bench/xx-transposev.cc +++ b/bench/xx-transposev.cc @@ -62,5 +62,5 @@ BENCHMARK_CAPTURE(transpose, 1x1_scalar_memcpy, xnn_xx_transposev_ukernel__1x1_s ->Apply(BenchmarkKernelSize)->UseRealTime(); #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif diff --git a/gemm_compiler/aarch64_template.py b/gemm_compiler/aarch64_template.py index 2ef7ed28cd4..67cac7e5fcd 100644 --- a/gemm_compiler/aarch64_template.py +++ b/gemm_compiler/aarch64_template.py @@ 
-276,17 +276,15 @@ def clamp_inputs_and_outputs( def increment_ptr(self, ptr, step): return f'add {ptr}, {ptr}, {step}\n' - def zero_gp_register(self, reg): - return f'eor {reg}, {reg}, {reg}\n' + def initialize_k_register(self, reg): + kc_register = self.kc_register() + return f'mov {reg}, {kc_register}\n' def cmp_k_and_jump_if_less(self, label): kc_register = self.kc_register() k_register = self.k_register() - return """add {k_register}, {k_register}, 4 - cmp {kc_register}, {k_register} - bne {label}\n""".format( - label=label, k_register=k_register, kc_register=kc_register - ) + return f"""subs {k_register}, {k_register}, 4 + bne {label}\n""" def epilogue(self, M, N, isa): restore_stack = """ diff --git a/gemm_compiler/base_architecture.py b/gemm_compiler/base_architecture.py index 3caccac4c8a..9133ef58eb5 100644 --- a/gemm_compiler/base_architecture.py +++ b/gemm_compiler/base_architecture.py @@ -143,8 +143,8 @@ def increment_ptr(self, ptr, step): raise NotImplementedError @abstractmethod - def zero_gp_register(self, reg): - """Zero the given general purpose register.""" + def initialize_k_register(self, reg): + """Initialize the given general purpose register for inner loop control.""" raise NotImplementedError @abstractmethod diff --git a/gemm_compiler/generate.py b/gemm_compiler/generate.py index 097121c0034..d5a99898f99 100644 --- a/gemm_compiler/generate.py +++ b/gemm_compiler/generate.py @@ -36,8 +36,8 @@ def generate_gemm_microkernel( # the outer loop label asm_string += '\nouter_loop:\n' - asm_string += '# Zero k counter.\n' - asm_string += isa.zero_gp_register(k_register) + asm_string += '# Initialize k counter.\n' + asm_string += isa.initialize_k_register(k_register) # Read a registers from the stack if required asm_string += isa.read_a_registers(M=M) diff --git a/gemm_compiler/neonfma_template.py b/gemm_compiler/neonfma_template.py index 1e88d9320fe..55e28e830fd 100644 --- a/gemm_compiler/neonfma_template.py +++ b/gemm_compiler/neonfma_template.py @@ -60,7 +60,7 @@ def w_registers(self): def input_asm(self): in_asm = { 'loop': [ - 'ldr d{AM}, [{AM_ptr}, {a_offset}]\n', + 'ldr s{AM}, [{AM_ptr}], 4\n', ] } return in_asm @@ -149,6 +149,10 @@ def store( ACC_1=accumulators[M * 3 + mr], c_reg=cm_registers[mr], ) + for mr in range(0, M): + AM_PTR = self.am_registers()[mr] + kc_register = self.kc_register() + asm_string += f'sub {AM_PTR}, {AM_PTR}, {kc_register}\n' CHECK = """ sub {nc}, {nc}, {n_step} b.ne outer_loop diff --git a/gemm_compiler/x64_template.py b/gemm_compiler/x64_template.py index ac6886eab05..637cd565beb 100644 --- a/gemm_compiler/x64_template.py +++ b/gemm_compiler/x64_template.py @@ -272,7 +272,7 @@ def read_a_registers(self, M): def increment_ptr(self, ptr, step): return f'add {ptr}, {step}\n' - def zero_gp_register(self, reg): + def initialize_k_register(self, reg): return f'mov {reg}, 0\n' def cmp_k_and_jump_if_less(self, label): diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u32-acc2.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u32-acc2.c index 0510794c3a6..34db8d5ddde 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u32-acc2.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u32-acc2.c @@ -188,5 +188,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u32_acc2( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum,
vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u32-acc4.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u32-acc4.c index 230bdf9b04f..7fe5fbbbcb0 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u32-acc4.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u32-acc4.c @@ -192,5 +192,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u32_acc4( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u32.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u32.c index 07ea8cc93da..a0eb34e24a9 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u32.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u32.c @@ -186,5 +186,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u32( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u40-acc2.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u40-acc2.c index afd7e5b93f5..e511a940688 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u40-acc2.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u40-acc2.c @@ -202,5 +202,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u40_acc2( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u40-acc5.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u40-acc5.c index 9e66b7ffee2..15022ee55c5 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u40-acc5.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u40-acc5.c @@ -208,5 +208,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u40_acc5( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u40.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u40.c index cced3be28b3..166ea58ffef 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u40.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u40.c @@ -200,5 +200,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u40( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); 
- vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u48-acc2.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u48-acc2.c index 433c23cc0b7..639a63740ef 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u48-acc2.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u48-acc2.c @@ -216,5 +216,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u48_acc2( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u48-acc3.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u48-acc3.c index d26a6f834b1..34223cc4d7e 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u48-acc3.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u48-acc3.c @@ -218,5 +218,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u48_acc3( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u48.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u48.c index c49c63a4603..113d55b402a 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u48.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u48.c @@ -214,5 +214,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u48( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u64-acc2.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u64-acc2.c index cc7f01f60f3..f08d61ed90c 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u64-acc2.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u64-acc2.c @@ -244,5 +244,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u64_acc2( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u64-acc4.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u64-acc4.c index e235ece9d8b..1124286859c 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u64-acc4.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u64-acc4.c @@ -248,5 +248,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u64_acc4( } vacc_lo = vpadd_f16(vacc_lo, 
vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u64.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u64.c index 118f2969e56..2e18d7d35c6 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u64.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u64.c @@ -242,5 +242,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u64( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u72-acc3.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u72-acc3.c index 3658c324212..88fd94b7a9a 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u72-acc3.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u72-acc3.c @@ -260,5 +260,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u72_acc3( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u72.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u72.c index 3519b816642..829c8001716 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u72.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u72.c @@ -256,5 +256,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u72( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u80-acc2.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u80-acc2.c index f29ab2e10be..31a4c383db6 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u80-acc2.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u80-acc2.c @@ -272,5 +272,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u80_acc2( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u80-acc5.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u80-acc5.c index 95959accdc8..ae29289806d 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u80-acc5.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u80-acc5.c @@ -278,5 +278,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u80_acc5( } vacc_lo 
= vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u80.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u80.c index cea563af865..b1cae46662f 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u80.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u80.c @@ -270,5 +270,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u80( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u96-acc2.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u96-acc2.c index 494e2fe95eb..a96eae30ffd 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u96-acc2.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u96-acc2.c @@ -300,5 +300,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u96_acc2( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u96-acc3.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u96-acc3.c index 5eb5bcc6e40..513c538c3c3 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u96-acc3.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u96-acc3.c @@ -302,5 +302,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u96_acc3( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u96-acc6.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u96-acc6.c index ac002e864f1..41139c1aac1 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u96-acc6.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u96-acc6.c @@ -308,5 +308,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u96_acc6( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u96.c b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u96.c index 40cbcdf8afa..bb6ba0f0dc8 100644 --- a/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u96.c +++ b/src/f16-raddstoreexpminusmax/gen/f16-raddstoreexpminusmax-neonfp16arith-rr2-p2-u96.c @@ -298,5 +298,5 @@ void 
xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u96( } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f16-raddstoreexpminusmax/neonfp16arith-rr2-p2.c.in b/src/f16-raddstoreexpminusmax/neonfp16arith-rr2-p2.c.in index 0126b780000..9a118e7560e 100644 --- a/src/f16-raddstoreexpminusmax/neonfp16arith-rr2-p2.c.in +++ b/src/f16-raddstoreexpminusmax/neonfp16arith-rr2-p2.c.in @@ -168,5 +168,5 @@ void xnn_f16_raddstoreexpminusmax_ukernel__neonfp16arith_rr2_p2_u${BATCH_TILE}${ } vacc_lo = vpadd_f16(vacc_lo, vacc_lo); vacc_lo = vpadd_f16(vacc_lo, vacc_lo); - vst1_lane_u16(sum, vreinterpret_u16_f16(vacc_lo), 0); + vst1_lane_u16((uint16_t*) sum, vreinterpret_u16_f16(vacc_lo), 0); } diff --git a/src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S index f09b6b989a6..b2da35107d2 100644 --- a/src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-10x16-minmax-asm-amd64-avx512f-broadcast.S @@ -140,7 +140,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_10x16__asm_amd64_avx512f_broadcast mov [rsp - 280], r13 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S index 13f7f2fc41c..efd156a9ab4 100644 --- a/src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-10x32-minmax-asm-amd64-avx512f-broadcast.S @@ -140,7 +140,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_10x32__asm_amd64_avx512f_broadcast mov [rsp - 280], r13 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S index 09b78ddbf9e..c8c4c288e71 100644 --- a/src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-11x16-minmax-asm-amd64-avx512f-broadcast.S @@ -152,7 +152,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_11x16__asm_amd64_avx512f_broadcast mov [rsp - 296], r10 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S index 576ef75275f..8e090a0b182 100644 --- a/src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-11x32-minmax-asm-amd64-avx512f-broadcast.S @@ -152,7 +152,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_11x32__asm_amd64_avx512f_broadcast mov [rsp - 296], r10 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. 
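# (The a pointers were spilled to the stack in the prologue; with 11 rows of A there are likely too few free amd64 GPRs to keep them all resident.)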
mov rsi, [rsp - 128] diff --git a/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld32.S index a20251cc641..c6827b2b669 100644 --- a/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld32.S +++ b/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-aarch64-neonfma-ld32.S @@ -20,23 +20,22 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x16__asm_aarch64_neonfma_lane ld2r {v0.4s, v1.4s}, [x13] outer_loop: - # Zero k counter. - eor x20, x20, x20 + # Initialize k counter. + mov x20, x2 # Initialize accumulators with the biases. ldp q11, q12, [x5, 0] ldp q13, q14, [x5, 32] add x5, x5, 64 inner_loop: - ldr d2, [x3, x20] + ldr s2, [x3], 4 ldp q7, q8, [x5], 32 ldp q9, q10, [x5], 32 fmla v11.4s, v7.4s, v2.s[0] fmla v12.4s, v8.4s, v2.s[0] fmla v13.4s, v9.4s, v2.s[0] fmla v14.4s, v10.4s, v2.s[0] - add x20, x20, 4 - cmp x2, x20 + subs x20, x20, 4 bne inner_loop # Min/max clamping.. fmin v11.4s, v1.4s, v11.4s @@ -53,6 +52,7 @@ inner_loop: b.lo tail_8 stp q11, q12, [x6], 32 stp q13, q14, [x6], 32 + sub x3, x3, x2 sub x1, x1, 16 b.ne outer_loop diff --git a/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S index 1dfd045b822..de9feca9f79 100644 --- a/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-1x16-minmax-asm-amd64-avx512f-broadcast.S @@ -28,7 +28,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x16__asm_amd64_avx512f_broadcast mov r11, [rsp + 64] outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. vmovaps zmm7, [r9 + 0] diff --git a/src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S index 85b0958e4d3..8bf1ff26b55 100644 --- a/src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-1x32-minmax-asm-amd64-avx512f-broadcast.S @@ -28,7 +28,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x32__asm_amd64_avx512f_broadcast mov r11, [rsp + 64] outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. vmovaps zmm7, [r9 + 0] diff --git a/src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S index 9b3b2acd90a..f0e2e17c3d8 100644 --- a/src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-1x64-minmax-asm-amd64-avx512f-broadcast.S @@ -28,7 +28,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x64__asm_amd64_avx512f_broadcast mov r11, [rsp + 64] outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. vmovaps zmm7, [r9 + 0] diff --git a/src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld32.S index 3175bb6408d..d6aec57123e 100644 --- a/src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld32.S +++ b/src/f32-gemm/gen/f32-gemm-1x8-minmax-asm-aarch64-neonfma-ld32.S @@ -20,19 +20,18 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_1x8__asm_aarch64_neonfma_lane ld2r {v0.4s, v1.4s}, [x13] outer_loop: - # Zero k counter. - eor x20, x20, x20 + # Initialize k counter. + mov x20, x2 # Initialize accumulators with the biases. 
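# (x20 now counts the remaining K bytes down to zero, and x3 is rewound by kc after the stores, replacing the old indexed-offset addressing.)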
ldp q11, q12, [x5, 0] add x5, x5, 32 inner_loop: - ldr d2, [x3, x20] + ldr s2, [x3], 4 ldp q7, q8, [x5], 32 fmla v11.4s, v7.4s, v2.s[0] fmla v12.4s, v8.4s, v2.s[0] - add x20, x20, 4 - cmp x2, x20 + subs x20, x20, 4 bne inner_loop # Min/max clamping.. fmin v11.4s, v1.4s, v11.4s @@ -44,6 +43,7 @@ inner_loop: cmp x1, 8 b.lo tail_4 stp q11, q12, [x6], 32 + sub x3, x3, x2 sub x1, x1, 8 b.ne outer_loop diff --git a/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld32.S index 6c361d46933..dd3cb73756e 100644 --- a/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld32.S +++ b/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-aarch64-neonfma-ld32.S @@ -27,8 +27,8 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x16__asm_aarch64_neonfma_lane csel x13, x6, x13, LO outer_loop: - # Zero k counter. - eor x20, x20, x20 + # Initialize k counter. + mov x20, x2 # Initialize accumulators with the biases. ldp q11, q13, [x5, 0] ldp q15, q17, [x5, 32] @@ -39,8 +39,8 @@ outer_loop: add x5, x5, 64 inner_loop: - ldr d2, [x3, x20] - ldr d3, [x9, x20] + ldr s2, [x3], 4 + ldr s3, [x9], 4 ldp q7, q8, [x5], 32 ldp q9, q10, [x5], 32 fmla v11.4s, v7.4s, v2.s[0] @@ -51,8 +51,7 @@ inner_loop: fmla v16.4s, v9.4s, v3.s[0] fmla v17.4s, v10.4s, v2.s[0] fmla v18.4s, v10.4s, v3.s[0] - add x20, x20, 4 - cmp x2, x20 + subs x20, x20, 4 bne inner_loop # Min/max clamping.. fmin v11.4s, v1.4s, v11.4s @@ -79,6 +78,8 @@ inner_loop: stp q15, q17, [x6], 32 stp q12, q14, [x13], 32 stp q16, q18, [x13], 32 + sub x3, x3, x2 + sub x9, x9, x2 sub x1, x1, 16 b.ne outer_loop diff --git a/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S index f429e9631c6..36236780321 100644 --- a/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-2x16-minmax-asm-amd64-avx512f-broadcast.S @@ -37,7 +37,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x16__asm_amd64_avx512f_broadcast cmovle r13, r10 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. vmovaps zmm7, [r9 + 0] diff --git a/src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S index 212db9528b5..f80ba7448b1 100644 --- a/src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-2x32-minmax-asm-amd64-avx512f-broadcast.S @@ -37,7 +37,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x32__asm_amd64_avx512f_broadcast cmovle r13, r10 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. vmovaps zmm7, [r9 + 0] diff --git a/src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S index d79b807966c..105cf33322e 100644 --- a/src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-2x64-minmax-asm-amd64-avx512f-broadcast.S @@ -37,7 +37,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x64__asm_amd64_avx512f_broadcast cmovle r13, r10 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. 
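# (r9 points at the packed weights; in XNNPACK's packed layout the biases precede the weight columns, so they are loaded first.)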
vmovaps zmm7, [r9 + 0] diff --git a/src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld32.S index 1749ed7217d..bba3de14cf7 100644 --- a/src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld32.S +++ b/src/f32-gemm/gen/f32-gemm-2x8-minmax-asm-aarch64-neonfma-ld32.S @@ -27,8 +27,8 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_2x8__asm_aarch64_neonfma_lane csel x13, x6, x13, LO outer_loop: - # Zero k counter. - eor x20, x20, x20 + # Initialize k counter. + mov x20, x2 # Initialize accumulators with the biases. ldp q11, q13, [x5, 0] mov v12.16b, v11.16b @@ -36,15 +36,14 @@ outer_loop: add x5, x5, 32 inner_loop: - ldr d2, [x3, x20] - ldr d3, [x9, x20] + ldr s2, [x3], 4 + ldr s3, [x9], 4 ldp q7, q8, [x5], 32 fmla v11.4s, v7.4s, v2.s[0] fmla v12.4s, v7.4s, v3.s[0] fmla v13.4s, v8.4s, v2.s[0] fmla v14.4s, v8.4s, v3.s[0] - add x20, x20, 4 - cmp x2, x20 + subs x20, x20, 4 bne inner_loop # Min/max clamping.. fmin v11.4s, v1.4s, v11.4s @@ -61,6 +60,8 @@ inner_loop: b.lo tail_4 stp q11, q13, [x6], 32 stp q12, q14, [x13], 32 + sub x3, x3, x2 + sub x9, x9, x2 sub x1, x1, 8 b.ne outer_loop diff --git a/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld32.S index a963c5fbab5..9b2af6912f4 100644 --- a/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld32.S +++ b/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-aarch64-neonfma-ld32.S @@ -31,8 +31,8 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x16__asm_aarch64_neonfma_lane csel x14, x13, x14, LS outer_loop: - # Zero k counter. - eor x20, x20, x20 + # Initialize k counter. + mov x20, x2 # Initialize accumulators with the biases. ldp q11, q14, [x5, 0] ldp q17, q20, [x5, 32] @@ -47,9 +47,9 @@ outer_loop: add x5, x5, 64 inner_loop: - ldr d2, [x3, x20] - ldr d3, [x9, x20] - ldr d4, [x10, x20] + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 ldp q7, q8, [x5], 32 ldp q9, q10, [x5], 32 fmla v11.4s, v7.4s, v2.s[0] @@ -64,8 +64,7 @@ inner_loop: fmla v20.4s, v10.4s, v2.s[0] fmla v21.4s, v10.4s, v3.s[0] fmla v22.4s, v10.4s, v4.s[0] - add x20, x20, 4 - cmp x2, x20 + subs x20, x20, 4 bne inner_loop # Min/max clamping.. fmin v11.4s, v1.4s, v11.4s @@ -102,6 +101,9 @@ inner_loop: stp q18, q21, [x13], 32 stp q13, q16, [x14], 32 stp q19, q22, [x14], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 sub x1, x1, 16 b.ne outer_loop diff --git a/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S index dc65d2c61af..53ca4d7700e 100644 --- a/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-3x16-minmax-asm-amd64-avx512f-broadcast.S @@ -46,7 +46,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x16__asm_amd64_avx512f_broadcast cmovle rbx, r13 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. vmovaps zmm7, [r9 + 0] diff --git a/src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S index 92db444e85a..5510c79e6df 100644 --- a/src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-3x32-minmax-asm-amd64-avx512f-broadcast.S @@ -46,7 +46,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x32__asm_amd64_avx512f_broadcast cmovle rbx, r13 outer_loop: - # Zero k counter. + # Initialize k counter. 
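# (The amd64 kernels keep the count-up scheme, running r11 from 0 to kc; only the aarch64 kernels above switch to counting down.)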
mov r11, 0 # Initialize accumulators with the biases. vmovaps zmm7, [r9 + 0] diff --git a/src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S index ca88fdb4806..bd1641c2879 100644 --- a/src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-3x64-minmax-asm-amd64-avx512f-broadcast.S @@ -46,7 +46,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x64__asm_amd64_avx512f_broadcast cmovle rbx, r13 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. vmovaps zmm7, [r9 + 0] diff --git a/src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld32.S index 195fd2bcbfe..26bebc208e7 100644 --- a/src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld32.S +++ b/src/f32-gemm/gen/f32-gemm-3x8-minmax-asm-aarch64-neonfma-ld32.S @@ -31,8 +31,8 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_3x8__asm_aarch64_neonfma_lane csel x14, x13, x14, LS outer_loop: - # Zero k counter. - eor x20, x20, x20 + # Initialize k counter. + mov x20, x2 # Initialize accumulators with the biases. ldp q11, q14, [x5, 0] mov v12.16b, v11.16b @@ -42,9 +42,9 @@ outer_loop: add x5, x5, 32 inner_loop: - ldr d2, [x3, x20] - ldr d3, [x9, x20] - ldr d4, [x10, x20] + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 ldp q7, q8, [x5], 32 fmla v11.4s, v7.4s, v2.s[0] fmla v12.4s, v7.4s, v3.s[0] @@ -52,8 +52,7 @@ inner_loop: fmla v14.4s, v8.4s, v2.s[0] fmla v15.4s, v8.4s, v3.s[0] fmla v16.4s, v8.4s, v4.s[0] - add x20, x20, 4 - cmp x2, x20 + subs x20, x20, 4 bne inner_loop # Min/max clamping.. fmin v11.4s, v1.4s, v11.4s @@ -75,6 +74,9 @@ inner_loop: stp q11, q14, [x6], 32 stp q12, q15, [x13], 32 stp q13, q16, [x14], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 sub x1, x1, 8 b.ne outer_loop diff --git a/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld32.S index f91ba2bdde1..fb40896a6e4 100644 --- a/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld32.S +++ b/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-aarch64-neonfma-ld32.S @@ -37,8 +37,8 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x16__asm_aarch64_neonfma_lane csel x15, x14, x15, LO outer_loop: - # Zero k counter. - eor x20, x20, x20 + # Initialize k counter. + mov x20, x2 # Initialize accumulators with the biases. ldp q11, q15, [x5, 0] ldp q19, q23, [x5, 32] @@ -57,10 +57,10 @@ outer_loop: add x5, x5, 64 inner_loop: - ldr d2, [x3, x20] - ldr d3, [x9, x20] - ldr d4, [x10, x20] - ldr d5, [x11, x20] + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 + ldr s5, [x11], 4 ldp q7, q8, [x5], 32 ldp q9, q10, [x5], 32 fmla v11.4s, v7.4s, v2.s[0] @@ -79,8 +79,7 @@ inner_loop: fmla v24.4s, v10.4s, v3.s[0] fmla v25.4s, v10.4s, v4.s[0] fmla v26.4s, v10.4s, v5.s[0] - add x20, x20, 4 - cmp x2, x20 + subs x20, x20, 4 bne inner_loop # Min/max clamping.. 
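# (v0 holds the broadcast output min and v1 the output max, loaded with ld2r at entry.)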
fmin v11.4s, v1.4s, v11.4s @@ -127,6 +126,10 @@ inner_loop: stp q21, q25, [x14], 32 stp q14, q18, [x15], 32 stp q22, q26, [x15], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 + sub x11, x11, x2 sub x1, x1, 16 b.ne outer_loop diff --git a/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S index 71d442f0da7..d0938f5ae74 100644 --- a/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-4x16-minmax-asm-amd64-avx512f-broadcast.S @@ -55,7 +55,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x16__asm_amd64_avx512f_broadcast cmovle rbp, rbx outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. vmovaps zmm7, [r9 + 0] diff --git a/src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S index 7f117407c4a..75764d9495a 100644 --- a/src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-4x32-minmax-asm-amd64-avx512f-broadcast.S @@ -55,7 +55,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x32__asm_amd64_avx512f_broadcast cmovle rbp, rbx outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. vmovaps zmm7, [r9 + 0] diff --git a/src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S index e862b8a36de..4492f3fabef 100644 --- a/src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-4x64-minmax-asm-amd64-avx512f-broadcast.S @@ -55,7 +55,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x64__asm_amd64_avx512f_broadcast cmovle rbp, rbx outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. vmovaps zmm7, [r9 + 0] diff --git a/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld32.S index 31d3ebf78e9..70f9d83b2f5 100644 --- a/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld32.S +++ b/src/f32-gemm/gen/f32-gemm-4x8-minmax-asm-aarch64-neonfma-ld32.S @@ -37,8 +37,8 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x8__asm_aarch64_neonfma_lane csel x15, x14, x15, LO outer_loop: - # Zero k counter. - eor x20, x20, x20 + # Initialize k counter. + mov x20, x2 # Initialize accumulators with the biases. ldp q11, q15, [x5, 0] mov v12.16b, v11.16b @@ -50,10 +50,10 @@ outer_loop: add x5, x5, 32 inner_loop: - ldr d2, [x3, x20] - ldr d3, [x9, x20] - ldr d4, [x10, x20] - ldr d5, [x11, x20] + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 + ldr s5, [x11], 4 ldp q7, q8, [x5], 32 fmla v11.4s, v7.4s, v2.s[0] fmla v12.4s, v7.4s, v3.s[0] @@ -63,8 +63,7 @@ inner_loop: fmla v16.4s, v8.4s, v3.s[0] fmla v17.4s, v8.4s, v4.s[0] fmla v18.4s, v8.4s, v5.s[0] - add x20, x20, 4 - cmp x2, x20 + subs x20, x20, 4 bne inner_loop # Min/max clamping.. 
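Note on the aarch64 hunks above: they all apply the same inner-loop restructuring. The k counter now starts at `k` (`mov x20, x2`) and counts down with `subs`, whose flag update feeds `bne` directly, eliminating the separate `cmp`; the per-row loads switch from offset addressing (`ldr d2, [x3, x20]`) to 4-byte post-indexed loads (`ldr s2, [x3], 4`), which suffices because only lane 0 of the vector is consumed by the `fmla`; and since the A pointers now advance through the loop, each is rewound by `k` afterwards (`sub x3, x3, x2`, etc.) so the same rows can be replayed for the next block of output columns. The amd64 kernels keep counting up from zero; there only the comment text changes from "Zero" to "Initialize" for consistency. A rough C analogue of the aarch64 transformation, as a sketch with a single scalar accumulator rather than the real vectorized kernel (assumes `k_bytes` is a nonzero multiple of `sizeof(float)`, matching the kernels' assumption that k > 0):

    #include <stddef.h>

    /* Count-down inner loop with post-incremented loads, followed by a
       pointer rewind so the same rows of A can be reused for the next
       block of output columns. */
    static float dot_then_rewind(const float** a_ptr, const float* w,
                                 size_t k_bytes) {
      const float* a = *a_ptr;
      float acc = 0.0f;
      size_t k_remaining = k_bytes;          /* mov x20, x2             */
      do {
        acc += (*a++) * (*w++);              /* ldr s2, [x3], 4 + fmla  */
        k_remaining -= sizeof(float);        /* subs x20, x20, 4        */
      } while (k_remaining != 0);            /* bne inner_loop          */
      *a_ptr = a - k_bytes / sizeof(float);  /* sub x3, x3, x2 (rewind) */
      return acc;
    }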
fmin v11.4s, v1.4s, v11.4s @@ -91,6 +90,10 @@ inner_loop: stp q12, q16, [x13], 32 stp q13, q17, [x14], 32 stp q14, q18, [x15], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 + sub x11, x11, x2 sub x1, x1, 8 b.ne outer_loop diff --git a/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld32.S index 8f5d12c8e6c..ffd980a6448 100644 --- a/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld32.S +++ b/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-aarch64-neonfma-ld32.S @@ -41,8 +41,8 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x16__asm_aarch64_neonfma_lane csel x19, x15, x19, LS outer_loop: - # Zero k counter. - eor x20, x20, x20 + # Initialize k counter. + mov x20, x2 # Initialize accumulators with the biases. ldp q11, q16, [x5, 0] ldp q21, q26, [x5, 32] @@ -65,11 +65,11 @@ outer_loop: add x5, x5, 64 inner_loop: - ldr d2, [x3, x20] - ldr d3, [x9, x20] - ldr d4, [x10, x20] - ldr d5, [x11, x20] - ldr d6, [x12, x20] + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 + ldr s5, [x11], 4 + ldr s6, [x12], 4 ldp q7, q8, [x5], 32 ldp q9, q10, [x5], 32 fmla v11.4s, v7.4s, v2.s[0] @@ -92,8 +92,7 @@ inner_loop: fmla v28.4s, v10.4s, v4.s[0] fmla v29.4s, v10.4s, v5.s[0] fmla v30.4s, v10.4s, v6.s[0] - add x20, x20, 4 - cmp x2, x20 + subs x20, x20, 4 bne inner_loop # Min/max clamping.. fmin v11.4s, v1.4s, v11.4s @@ -150,6 +149,11 @@ inner_loop: stp q24, q29, [x15], 32 stp q15, q20, [x19], 32 stp q25, q30, [x19], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 + sub x11, x11, x2 + sub x12, x12, x2 sub x1, x1, 16 b.ne outer_loop diff --git a/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S index b6743cce467..ea0997bc825 100644 --- a/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-5x16-minmax-asm-amd64-avx512f-broadcast.S @@ -64,7 +64,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x16__asm_amd64_avx512f_broadcast cmovle r8, rbp outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. vmovaps zmm7, [r9 + 0] diff --git a/src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S index c0273cc2f80..7d76f7794f5 100644 --- a/src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-5x32-minmax-asm-amd64-avx512f-broadcast.S @@ -64,7 +64,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x32__asm_amd64_avx512f_broadcast cmovle r8, rbp outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. vmovaps zmm7, [r9 + 0] diff --git a/src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S index 8a2a511a247..8adfd1c1ca5 100644 --- a/src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-5x64-minmax-asm-amd64-avx512f-broadcast.S @@ -64,7 +64,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x64__asm_amd64_avx512f_broadcast cmovle r8, rbp outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with the biases. 
vmovaps zmm7, [r9 + 0] diff --git a/src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld32.S b/src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld32.S index 621a6a6c740..1205fdbff75 100644 --- a/src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld32.S +++ b/src/f32-gemm/gen/f32-gemm-5x8-minmax-asm-aarch64-neonfma-ld32.S @@ -41,8 +41,8 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_5x8__asm_aarch64_neonfma_lane csel x19, x15, x19, LS outer_loop: - # Zero k counter. - eor x20, x20, x20 + # Initialize k counter. + mov x20, x2 # Initialize accumulators with the biases. ldp q11, q16, [x5, 0] mov v12.16b, v11.16b @@ -56,11 +56,11 @@ outer_loop: add x5, x5, 32 inner_loop: - ldr d2, [x3, x20] - ldr d3, [x9, x20] - ldr d4, [x10, x20] - ldr d5, [x11, x20] - ldr d6, [x12, x20] + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 + ldr s5, [x11], 4 + ldr s6, [x12], 4 ldp q7, q8, [x5], 32 fmla v11.4s, v7.4s, v2.s[0] fmla v12.4s, v7.4s, v3.s[0] @@ -72,8 +72,7 @@ inner_loop: fmla v18.4s, v8.4s, v4.s[0] fmla v19.4s, v8.4s, v5.s[0] fmla v20.4s, v8.4s, v6.s[0] - add x20, x20, 4 - cmp x2, x20 + subs x20, x20, 4 bne inner_loop # Min/max clamping.. fmin v11.4s, v1.4s, v11.4s @@ -105,6 +104,11 @@ inner_loop: stp q13, q18, [x14], 32 stp q14, q19, [x15], 32 stp q15, q20, [x19], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 + sub x11, x11, x2 + sub x12, x12, x2 sub x1, x1, 8 b.ne outer_loop diff --git a/src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S index 48c1de312f3..d44a42ccb3a 100644 --- a/src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-6x16-minmax-asm-amd64-avx512f-broadcast.S @@ -92,7 +92,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_6x16__asm_amd64_avx512f_broadcast mov [rsp - 216], r13 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S index 0a12f3e02af..9a8a091a14c 100644 --- a/src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-6x32-minmax-asm-amd64-avx512f-broadcast.S @@ -92,7 +92,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_6x32__asm_amd64_avx512f_broadcast mov [rsp - 216], r13 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S index 7d83d0394bf..361c5d1ab3a 100644 --- a/src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-7x16-minmax-asm-amd64-avx512f-broadcast.S @@ -104,7 +104,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_7x16__asm_amd64_avx512f_broadcast mov [rsp - 232], r10 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. 
mov rsi, [rsp - 128] diff --git a/src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S index 57f6b2cd431..9354a90d8af 100644 --- a/src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-7x32-minmax-asm-amd64-avx512f-broadcast.S @@ -104,7 +104,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_7x32__asm_amd64_avx512f_broadcast mov [rsp - 232], r10 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S index 70355ca7cb7..7233c24350f 100644 --- a/src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-8x16-minmax-asm-amd64-avx512f-broadcast.S @@ -116,7 +116,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_8x16__asm_amd64_avx512f_broadcast mov [rsp - 248], r13 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S index 4a6a4097146..e4633f2fcfe 100644 --- a/src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-8x32-minmax-asm-amd64-avx512f-broadcast.S @@ -116,7 +116,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_8x32__asm_amd64_avx512f_broadcast mov [rsp - 248], r13 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S index b37845822da..c239863a9c7 100644 --- a/src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-9x16-minmax-asm-amd64-avx512f-broadcast.S @@ -128,7 +128,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_9x16__asm_amd64_avx512f_broadcast mov [rsp - 264], r10 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S b/src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S index beda0925e7e..c65ab67655f 100644 --- a/src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S +++ b/src/f32-gemm/gen/f32-gemm-9x32-minmax-asm-amd64-avx512f-broadcast.S @@ -128,7 +128,7 @@ BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_9x32__asm_amd64_avx512f_broadcast mov [rsp - 264], r10 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. 
mov rsi, [rsp - 128] diff --git a/src/operator-run.c b/src/operator-run.c index afcf6ebbfc2..0a3f060b7a9 100644 --- a/src/operator-run.c +++ b/src/operator-run.c @@ -20,6 +20,7 @@ #include "xnnpack/math.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microkernel-type.h" +#include "xnnpack/microparams-init.h" #include "xnnpack/microparams.h" #include "xnnpack/operator-type.h" #include "xnnpack/operator.h" @@ -404,48 +405,55 @@ void xnn_compute_hmp_grouped_gemm( group_index_b = (index % context->batch_dims_b[k]) + context->batch_dims_b[k] * group_index_b; } - if (context->quantization_params != NULL) { - // If the effective `mr_block_size` is smaller than the kernel's `mr`, - // create a padded copy of the dynamic quantization params. - const struct xnn_qd8_quantization_params* quantization_params = - &context->quantization_params[group_index_a * context->gq_stride + - mr_block_start]; - struct xnn_qd8_quantization_params padded_quantization_params[XNN_MAX_MR]; - if (mr_block_size < context->mr) { - memcpy(padded_quantization_params, quantization_params, - mr_block_size * sizeof(struct xnn_qd8_quantization_params)); - for (size_t i = mr_block_size; i < context->mr; i++) { - padded_quantization_params[i] = - padded_quantization_params[mr_block_size - 1]; - } - quantization_params = padded_quantization_params; - }; - - context->dq_ukernel.function[uarch_index]( - mr_block_size, nr_block_size, k_scaled, - (const void*)((uintptr_t)context->a + mr_block_start * a_stride + - group_index_a * context->ga_stride), - a_stride, - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride + - group_index_b * context->gw_stride), - (void*)((uintptr_t)context->c + mr_block_start * cm_stride + - (nr_block_start << context->log2_csize) + - group_index_c * context->gc_stride), - cm_stride, context->cn_stride, &context->params, quantization_params); - } else { - context->ukernel.function[uarch_index]( - mr_block_size, nr_block_size, k_scaled, - (const void*)((uintptr_t)context->a + mr_block_start * a_stride + - group_index_a * context->ga_stride), - a_stride, - (const void*)((uintptr_t)context->packed_w + - nr_block_start * context->w_stride + - group_index_b * context->gw_stride), - (void*)((uintptr_t)context->c + mr_block_start * cm_stride + - (nr_block_start << context->log2_csize) + - group_index_c * context->gc_stride), - cm_stride, context->cn_stride, &context->params); + while (mr_block_size > 0) { + const size_t mr_step = min(mr_block_size, context->mr); + + if (context->quantization_params != NULL) { + // If the effective `mr_block_size` is smaller than the kernel's `mr`, + // create a padded copy of the dynamic quantization params. 
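The operator-run.c changes beginning here share one pattern: the new dynamic pthreadpool strategies may hand a compute callback a row range spanning several multiples of the kernel's `mr`, so each callback now consumes its tile in `mr`-sized steps, and the dynamically quantized variants pad the per-row quantization params whenever the final step is shorter than `mr`. A condensed sketch of both ideas together, with invented stand-in types (`qparams`, `dq_ukernel_fn`, `MAX_MR`) in place of the real XNNPACK ones and assuming `mr <= MAX_MR`:

    #include <stddef.h>
    #include <stdint.h>

    #define MAX_MR 32  /* stand-in for XNN_MAX_MR */

    struct qparams { int32_t zero_point; float inv_scale; };  /* stand-in */
    typedef void (*dq_ukernel_fn)(size_t mr, const struct qparams* qp);

    /* Consume a dynamic tile in mr-sized steps; when the last step is
       short, replicate the final valid row of quantization params so the
       microkernel can safely read mr entries. */
    static void run_dynamic_tile(dq_ukernel_fn ukernel, size_t mr,
                                 size_t mr_block_start, size_t mr_block_size,
                                 const struct qparams* quantization_params) {
      while (mr_block_size > 0) {
        const size_t mr_step = mr_block_size < mr ? mr_block_size : mr;
        const struct qparams* qp = &quantization_params[mr_block_start];
        struct qparams padded[MAX_MR];
        if (mr_step < mr) {
          for (size_t i = 0; i < mr_step; i++) padded[i] = qp[i];
          for (size_t i = mr_step; i < mr; i++) padded[i] = padded[mr_step - 1];
          qp = padded;
        }
        ukernel(mr_step, qp);
        mr_block_size -= mr_step;
        mr_block_start += mr_step;
      }
    }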
+ const struct xnn_qd8_quantization_params* quantization_params = + &context->quantization_params[group_index_a * context->gq_stride + + mr_block_start]; + struct xnn_qd8_quantization_params padded_quantization_params[XNN_MAX_MR]; + if (mr_step < context->mr) { + for (size_t i = 0; i < mr_step; i++) { + padded_quantization_params[i] = quantization_params[i]; + } + for (size_t i = mr_step; i < context->mr; i++) { + padded_quantization_params[i] = + padded_quantization_params[mr_step - 1]; + } + quantization_params = padded_quantization_params; + }; + + context->dq_ukernel.function[uarch_index]( + mr_step, nr_block_size, k_scaled, + (const void*)((uintptr_t)context->a + mr_block_start * a_stride + + group_index_a * context->ga_stride), + a_stride, + (const void*)((uintptr_t)context->packed_w + + nr_block_start * context->w_stride + + group_index_b * context->gw_stride), + (void*)((uintptr_t)context->c + mr_block_start * cm_stride + + (nr_block_start << context->log2_csize) + + group_index_c * context->gc_stride), + cm_stride, context->cn_stride, &context->params, quantization_params); + } else { + context->ukernel.function[uarch_index]( + mr_step, nr_block_size, k_scaled, + (const void*)((uintptr_t)context->a + mr_block_start * a_stride + + group_index_a * context->ga_stride), + a_stride, + (const void*)((uintptr_t)context->packed_w + + nr_block_start * context->w_stride + + group_index_b * context->gw_stride), + (void*)((uintptr_t)context->c + mr_block_start * cm_stride + + (nr_block_start << context->log2_csize) + + group_index_c * context->gc_stride), + cm_stride, context->cn_stride, &context->params); + } + mr_block_size -= mr_step; + mr_block_start += mr_step; } } @@ -468,17 +476,21 @@ void xnn_compute_gemm( const size_t a_stride = context->a_stride; const size_t cm_stride = context->cm_stride; - context->ukernel.function[XNN_UARCH_DEFAULT]( - mr_block_size, - nr_block_size, - context->k_scaled, - (const void*) ((uintptr_t) context->a + mr_block_start * a_stride), - a_stride, - (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride), - (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)), - cm_stride, - context->cn_stride, - context->fused_params); + while (mr_block_size > 0) { + const size_t mr_step = min(mr_block_size, context->mr); + + context->ukernel.function[XNN_UARCH_DEFAULT]( + mr_step, nr_block_size, context->k_scaled, + (const void*)((uintptr_t)context->a + mr_block_start * a_stride), + a_stride, + (const void*)((uintptr_t)context->packed_w + + nr_block_start * context->w_stride), + (void*)((uintptr_t)context->c + mr_block_start * cm_stride + + (nr_block_start << context->log2_csize)), + cm_stride, context->cn_stride, context->fused_params); + mr_block_size -= mr_step; + mr_block_start += mr_step; + } } void xnn_compute_dqgemm( @@ -491,26 +503,29 @@ void xnn_compute_dqgemm( const size_t a_stride = context->a_stride; const size_t cm_stride = context->cm_stride; - context->dq_ukernel.function[XNN_UARCH_DEFAULT]( - mr_block_size, - nr_block_size, - context->k_scaled, - (const void*) ((uintptr_t) context->a + mr_block_start * a_stride), - a_stride, - (const void*) ((uintptr_t) context->packed_w + nr_block_start * context->w_stride), - (void*) ((uintptr_t) context->c + mr_block_start * cm_stride + (nr_block_start << context->log2_csize)), - cm_stride, - context->cn_stride, - context->fused_params, - (const void*) ((uintptr_t) &context->quantization_params[mr_block_start])); + while (mr_block_size > 0) { + 
const size_t mr_step = min(mr_block_size, context->mr); + + context->dq_ukernel.function[XNN_UARCH_DEFAULT]( + mr_step, nr_block_size, context->k_scaled, + (const void*)((uintptr_t)context->a + mr_block_start * a_stride), + a_stride, + (const void*)((uintptr_t)context->packed_w + + nr_block_start * context->w_stride), + (void*)((uintptr_t)context->c + mr_block_start * cm_stride + + (nr_block_start << context->log2_csize)), + cm_stride, context->cn_stride, context->fused_params, + (const void*)((uintptr_t)&context + ->quantization_params[mr_block_start])); + mr_block_size -= mr_step; + mr_block_start += mr_step; + } } void xnn_compute_hmp_grouped_qp8gemm( const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], uint32_t uarch_index, size_t group_index, size_t mr_block_start, size_t nr_block_start, size_t mr_block_size, size_t nr_block_size) { - const size_t a_offset = xnn_x8_packq_f32qp8_packed_offset( - mr_block_start, context->k_scaled, context->mr, context->kr, context->sr); const size_t cm_stride = context->cm_stride; const size_t num_batch_dims = context->num_batch_dims; @@ -530,18 +545,27 @@ void xnn_compute_hmp_grouped_qp8gemm( context->batch_dims_b[k] * group_index_b; } - context->qp8_ukernel.function[uarch_index]( - mr_block_size, nr_block_size, context->k_scaled, - (const void*)((uintptr_t)context->a + group_index_a * context->ga_stride + - a_offset), - (const void*)((uintptr_t)context->packed_w + - group_index_b * context->gw_stride + - nr_block_start * context->w_stride), - (void*)((uintptr_t)context->c + group_index_c * context->gc_stride + - mr_block_start * cm_stride + - (nr_block_start << context->log2_csize)), - cm_stride, - /*dst_stride_col=*/sizeof(float), context->fused_params); + while (mr_block_size > 0) { + const size_t mr_step = min(mr_block_size, context->mr); + const size_t a_offset = xnn_x8_packq_f32qp8_packed_offset( + mr_block_start, context->k_scaled, context->mr, context->kr, + context->sr); + + context->qp8_ukernel.function[uarch_index]( + mr_step, nr_block_size, context->k_scaled, + (const void*)((uintptr_t)context->a + + group_index_a * context->ga_stride + a_offset), + (const void*)((uintptr_t)context->packed_w + + group_index_b * context->gw_stride + + nr_block_start * context->w_stride), + (void*)((uintptr_t)context->c + group_index_c * context->gc_stride + + mr_block_start * cm_stride + + (nr_block_start << context->log2_csize)), + cm_stride, + /*dst_stride_col=*/sizeof(float), context->fused_params); + mr_block_size -= mr_step; + mr_block_start += mr_step; + } } void xnn_compute_grouped_qp8gemm( @@ -3075,6 +3099,27 @@ enum xnn_status xnn_run_operator_with_index( op->compute[i].tile[0], op->compute[i].tile[1], flags); break; + case xnn_parallelization_type_2d_tile_1d_dynamic: + assert(op->compute[i].range[0] != 0); + assert(op->compute[i].range[1] != 0); + assert(op->compute[i].tile[0] != 0); + pthreadpool_parallelize_2d_tile_1d_dynamic( + threadpool, op->compute[i].task_2d_tile_1d_dynamic, + (void*)((uintptr_t)&op->context + op->compute[i].context_offset), + op->compute[i].range[0], op->compute[i].range[1], + op->compute[i].tile[0], flags); + break; + case xnn_parallelization_type_2d_tile_2d_dynamic: + assert(op->compute[i].range[0] != 0); + assert(op->compute[i].range[1] != 0); + assert(op->compute[i].tile[0] != 0); + assert(op->compute[i].tile[1] != 0); + pthreadpool_parallelize_2d_tile_2d_dynamic( + threadpool, op->compute[i].task_2d_tile_2d_dynamic, + (void*)((uintptr_t)&op->context + op->compute[i].context_offset), + 
op->compute[i].range[0], op->compute[i].range[1], + op->compute[i].tile[0], op->compute[i].tile[1], flags); + break; case xnn_parallelization_type_3d: assert(op->compute[i].range[0] != 0); assert(op->compute[i].range[1] != 0); @@ -3126,6 +3171,19 @@ enum xnn_status xnn_run_operator_with_index( op->compute[i].tile[0], op->compute[i].tile[1], flags); break; + case xnn_parallelization_type_3d_tile_2d_dynamic: + assert(op->compute[i].range[0] != 0); + assert(op->compute[i].range[1] != 0); + assert(op->compute[i].range[2] != 0); + assert(op->compute[i].tile[0] != 0); + assert(op->compute[i].tile[1] != 0); + pthreadpool_parallelize_3d_tile_2d_dynamic( + threadpool, op->compute[i].task_3d_tile_2d_dynamic, + (void*)((uintptr_t)&op->context + op->compute[i].context_offset), + op->compute[i].range[0], op->compute[i].range[1], + op->compute[i].range[2], op->compute[i].tile[0], + op->compute[i].tile[1], flags); + break; case xnn_parallelization_type_4d: assert(op->compute[i].range[0] != 0); assert(op->compute[i].range[1] != 0); diff --git a/src/operators/batch-matrix-multiply-nc.c b/src/operators/batch-matrix-multiply-nc.c index c19825035a6..919137d5ad5 100644 --- a/src/operators/batch-matrix-multiply-nc.c +++ b/src/operators/batch-matrix-multiply-nc.c @@ -591,9 +591,9 @@ static enum xnn_status reshape_batch_matrix_multiply_nc( .gc_stride = input_b_batch_stride, }; batch_matrix_multiply_op->compute[0].type = - xnn_parallelization_type_2d_tile_1d; - batch_matrix_multiply_op->compute[0].task_2d_tile_1d = - (pthreadpool_task_2d_tile_1d_t)xnn_compute_batched_packw_gemm_goi; + xnn_parallelization_type_2d_tile_1d_dynamic; + batch_matrix_multiply_op->compute[0].task_2d_tile_1d_dynamic = + (pthreadpool_task_2d_tile_1d_dynamic_t)xnn_compute_batched_packw_gemm_goi; batch_matrix_multiply_op->compute[0].context_offset = offsetof(struct xnn_operator, context.gemm.packw_gemm_goi) - offsetof(struct xnn_operator, context); @@ -622,9 +622,9 @@ static enum xnn_status reshape_batch_matrix_multiply_nc( }; batch_matrix_multiply_op->compute[0].type = - xnn_parallelization_type_2d_tile_1d; - batch_matrix_multiply_op->compute[0].task_2d_tile_1d = - (pthreadpool_task_2d_tile_1d_t)xnn_compute_batched_packw_gemm_gio; + xnn_parallelization_type_2d_tile_1d_dynamic; + batch_matrix_multiply_op->compute[0].task_2d_tile_1d_dynamic = + (pthreadpool_task_2d_tile_1d_dynamic_t)xnn_compute_batched_packw_gemm_gio; batch_matrix_multiply_op->compute[0].context_offset = offsetof(struct xnn_operator, context.gemm.packw_gemm_gio) - offsetof(struct xnn_operator, context); @@ -698,23 +698,23 @@ static enum xnn_status reshape_batch_matrix_multiply_nc( (pthreadpool_task_3d_tile_2d_with_id_t)xnn_compute_hmp_grouped_gemm; } } else { - gemm_compute->type = xnn_parallelization_type_3d_tile_2d; + gemm_compute->type = xnn_parallelization_type_3d_tile_2d_dynamic; if (is_qp8_ukernel) { - gemm_compute->task_3d_tile_2d = - (pthreadpool_task_3d_tile_2d_t)xnn_compute_grouped_qp8gemm; + gemm_compute->task_3d_tile_2d_dynamic = + (pthreadpool_task_3d_tile_2d_dynamic_t)xnn_compute_grouped_qp8gemm; } else { - gemm_compute->task_3d_tile_2d = - (pthreadpool_task_3d_tile_2d_t)xnn_compute_grouped_gemm; + gemm_compute->task_3d_tile_2d_dynamic = + (pthreadpool_task_3d_tile_2d_dynamic_t)xnn_compute_grouped_gemm; } } #else - gemm_compute->type = xnn_parallelization_type_3d_tile_2d; + gemm_compute->type = xnn_parallelization_type_3d_tile_2d_dynamic; if (is_qp8_ukernel) { - gemm_compute->task_3d_tile_2d = - (pthreadpool_task_3d_tile_2d_t)xnn_compute_grouped_qp8gemm; + 
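The new dispatcher cases above wire these callbacks to the `pthreadpool_parallelize_*_dynamic` entry points named in the commit title. A minimal standalone usage sketch of the 2-D variant follows, assuming the callback signature implied by the casts in this patch (context, row start, column start, row extent, column extent) and assuming that the runtime-chosen tile extents can arrive as multiples of the requested tile sizes, which is why the compute functions re-split on `mr`:

    #include <pthreadpool.h>
    #include <stdio.h>

    /* Hypothetical tile task: handles rows [i, i + tile_i) and columns
       [j, j + tile_j) of a 1024x512 problem. */
    static void process_tile(void* context, size_t i, size_t j,
                             size_t tile_i, size_t tile_j) {
      (void)context;
      printf("tile at (%zu, %zu), extent %zu x %zu\n", i, j, tile_i, tile_j);
    }

    int main(void) {
      pthreadpool_t threadpool = pthreadpool_create(/*threads_count=*/0);
      pthreadpool_parallelize_2d_tile_2d_dynamic(
          threadpool, process_tile, /*context=*/NULL,
          /*range_i=*/1024, /*range_j=*/512,
          /*tile_i=*/8, /*tile_j=*/64,
          /*flags=*/0);
      pthreadpool_destroy(threadpool);
      return 0;
    }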
gemm_compute->task_3d_tile_2d_dynamic = + (pthreadpool_task_3d_tile_2d_dynamic_t)xnn_compute_grouped_qp8gemm; } else { - gemm_compute->task_3d_tile_2d = - (pthreadpool_task_3d_tile_2d_t)xnn_compute_grouped_gemm; + gemm_compute->task_3d_tile_2d_dynamic = + (pthreadpool_task_3d_tile_2d_dynamic_t)xnn_compute_grouped_gemm; } #endif gemm_compute->range[0] = batch_size_c; @@ -724,7 +724,7 @@ static enum xnn_status reshape_batch_matrix_multiply_nc( gemm_compute->tile[1] = nc; batch_matrix_multiply_op->state = xnn_run_state_needs_setup; - return xnn_status_success; + return xnn_status_success; } enum xnn_status xnn_reshape_batch_matrix_multiply_nc_f16( diff --git a/src/operators/convolution-nhwc.c b/src/operators/convolution-nhwc.c index bfc0be24eb7..a23a4fcf0cf 100644 --- a/src/operators/convolution-nhwc.c +++ b/src/operators/convolution-nhwc.c @@ -2070,6 +2070,7 @@ static enum xnn_status reshape_gemm( struct xnn_hmp_gemm_ukernel gemm_ukernel = gemm_cases[mr - 1]; convolution_op->context.gemm.gemm.gemm = (struct gemm_context){ + .mr = mr, .k_scaled = group_input_channels << log2_input_element_size, .a_stride = convolution_op->input_pixel_stride << log2_input_element_size, .ga_stride = group_input_channels << log2_input_element_size, @@ -2099,25 +2100,27 @@ static enum xnn_status reshape_gemm( convolution_op->compute[0].type = xnn_parallelization_type_2d_tile_2d_with_uarch; convolution_op->compute[0].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_gemm; } else { - convolution_op->compute[0].type = xnn_parallelization_type_2d_tile_2d; - convolution_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm; + convolution_op->compute[0].type = xnn_parallelization_type_2d_tile_2d_dynamic; + convolution_op->compute[0].task_2d_tile_2d_dynamic = + (pthreadpool_task_2d_tile_2d_dynamic_t)xnn_compute_gemm; } #else - convolution_op->compute[0].type = xnn_parallelization_type_2d_tile_2d; - convolution_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm; - #endif - convolution_op->compute[0].range[0] = batch_output_size; - convolution_op->compute[0].range[1] = group_output_channels; - convolution_op->compute[0].tile[0] = mr; - convolution_op->compute[0].tile[1] = nc; + convolution_op->compute[0].type = xnn_parallelization_type_2d_tile_2d_dynamic; + convolution_op->compute[0].task_2d_tile_2d_dynamic = + (pthreadpool_task_2d_tile_2d_dynamic_t)xnn_compute_gemm; +#endif + convolution_op->compute[0].range[0] = batch_output_size; + convolution_op->compute[0].range[1] = group_output_channels; + convolution_op->compute[0].tile[0] = mr; + convolution_op->compute[0].tile[1] = nc; } else { #if XNN_MAX_UARCH_TYPES > 1 if (xnn_is_hmp_gemm_ukernel(gemm_ukernel)) { convolution_op->compute[0].type = xnn_parallelization_type_3d_tile_2d_with_uarch; convolution_op->compute[0].task_3d_tile_2d_with_id = (pthreadpool_task_3d_tile_2d_with_id_t) xnn_compute_hmp_grouped_gemm; } else { - convolution_op->compute[0].type = xnn_parallelization_type_3d_tile_2d; - convolution_op->compute[0].task_3d_tile_2d = (pthreadpool_task_3d_tile_2d_t) xnn_compute_grouped_gemm; + convolution_op->compute[0].type = xnn_parallelization_type_3d_tile_2d_dynamic; + convolution_op->compute[0].task_3d_tile_2d_dynamic = (pthreadpool_task_3d_tile_2d_dynamic_t) xnn_compute_grouped_gemm; } #else convolution_op->compute[0].type = xnn_parallelization_type_3d_tile_2d; @@ -2250,7 +2253,7 @@ static enum xnn_status reshape_igemm( const size_t w_stride = extra_weights_elements_size + 
(round_up_po2(group_input_channels, convolution_op->ukernel.igemm.kr * convolution_op->ukernel.igemm.sr) * kernel_size << log2_filter_element_size); const size_t group_output_channels = convolution_op->group_output_channels; - convolution_op->context.igemm.igemm = (struct igemm_context) { + convolution_op->context.igemm.igemm = (struct igemm_context){ .ks = kernel_size, .ks_scaled = kernel_size * mr * sizeof(void*), .kc = group_input_channels << log2_input_element_size, @@ -2258,13 +2261,17 @@ static enum xnn_status reshape_igemm( .indirect_a = convolution_op->indirection_buffer, .zero = convolution_op->zero_buffer, .packed_w = packed_weights(convolution_op), - .cm_stride = convolution_op->output_pixel_stride << log2_output_element_size, + .cm_stride = convolution_op->output_pixel_stride + << log2_output_element_size, .cn_stride = nr << log2_output_element_size, .ga_stride = group_input_channels << log2_input_element_size, .gw_stride = w_stride * round_up(group_output_channels, nr), .gc_stride = group_output_channels << log2_output_element_size, - .ba_stride = input_height * input_width * convolution_op->input_pixel_stride << log2_input_element_size, - .bc_stride = output_size * convolution_op->output_pixel_stride << log2_output_element_size, + .ba_stride = + input_height * input_width * convolution_op->input_pixel_stride + << log2_input_element_size, + .bc_stride = output_size * convolution_op->output_pixel_stride + << log2_output_element_size, .log2_csize = log2_output_element_size, .ukernel = igemm_ukernel, }; diff --git a/src/operators/dynamic-fully-connected-nc.c b/src/operators/dynamic-fully-connected-nc.c index cf5be914310..c339fb7d17a 100644 --- a/src/operators/dynamic-fully-connected-nc.c +++ b/src/operators/dynamic-fully-connected-nc.c @@ -376,13 +376,15 @@ static enum xnn_status reshape_dynamic_fully_connected_nc( } dynamic_fully_connected_op->context.gemm.gemm.gemm = (struct gemm_context){ - .k_scaled = input_channels << log2_input_element_size, - .w_stride = bias_element_size + (round_up_po2(input_channels, kr * sr) << log2_input_element_size), - .a_stride = input_stride << log2_input_element_size, - .cm_stride = output_stride << log2_output_element_size, - .cn_stride = nr << log2_output_element_size, - .log2_csize = log2_output_element_size, - .ukernel = gemm_ukernel, + .k_scaled = input_channels << log2_input_element_size, + .w_stride = bias_element_size + (round_up_po2(input_channels, kr * sr) + << log2_input_element_size), + .a_stride = input_stride << log2_input_element_size, + .cm_stride = output_stride << log2_output_element_size, + .cn_stride = nr << log2_output_element_size, + .log2_csize = log2_output_element_size, + .ukernel = gemm_ukernel, + .mr = mr, }; memcpy(&dynamic_fully_connected_op->context.gemm.gemm.gemm.params, params, params_size); dynamic_fully_connected_op->context.gemm.gemm.gemm.fused_params = &dynamic_fully_connected_op->context.gemm.gemm.gemm.params; @@ -400,13 +402,17 @@ static enum xnn_status reshape_dynamic_fully_connected_nc( dynamic_fully_connected_op->compute[1].type = xnn_parallelization_type_2d_tile_2d_with_uarch; dynamic_fully_connected_op->compute[1].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_gemm; } else { - dynamic_fully_connected_op->compute[1].type = xnn_parallelization_type_2d_tile_2d; - dynamic_fully_connected_op->compute[1].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm; + dynamic_fully_connected_op->compute[1].type = + xnn_parallelization_type_2d_tile_2d_dynamic; + 
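Note that both reshape paths above now record the microkernel's `mr` in the GEMM context (convolution's `reshape_gemm` via `.mr = mr`, and `reshape_dynamic_fully_connected_nc` likewise), while `tile[0]` continues to request `mr`-sized row tiles: the compute callbacks need the stored value to clamp whatever tile extent the dynamic scheduler actually delivers. A trimmed illustration; the real `struct gemm_context` has many more fields:

    #include <stddef.h>

    /* Trimmed stand-in for struct gemm_context: the field this patch adds
       plus one neighbor, for illustration only. */
    struct gemm_context_sketch {
      size_t mr;        /* kernel's max rows per call; newly recorded */
      size_t k_scaled;  /* bytes of K per row of A */
      /* ... strides, packed weights, params ... */
    };

    /* Each step of the tile is clamped to the stored mr. */
    static size_t next_step(const struct gemm_context_sketch* ctx,
                            size_t rows_left) {
      return rows_left < ctx->mr ? rows_left : ctx->mr;
    }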
dynamic_fully_connected_op->compute[1].task_2d_tile_2d_dynamic = + (pthreadpool_task_2d_tile_2d_dynamic_t)xnn_compute_gemm; } #else - dynamic_fully_connected_op->compute[1].type = xnn_parallelization_type_2d_tile_2d; - dynamic_fully_connected_op->compute[1].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm; - #endif + dynamic_fully_connected_op->compute[1].type = + xnn_parallelization_type_2d_tile_2d_dynamic; + dynamic_fully_connected_op->compute[1].task_2d_tile_2d_dynamic = + (pthreadpool_task_2d_tile_2d_dynamic_t)xnn_compute_gemm; +#endif dynamic_fully_connected_op->compute[1].range[0] = batch_size; dynamic_fully_connected_op->compute[1].range[1] = output_channels; dynamic_fully_connected_op->compute[1].tile[0] = mr; diff --git a/src/operators/fully-connected-nc.c b/src/operators/fully-connected-nc.c index 03e2a75f7be..90b750a4900 100644 --- a/src/operators/fully-connected-nc.c +++ b/src/operators/fully-connected-nc.c @@ -2285,25 +2285,31 @@ static enum xnn_status reshape_fully_connected_nc( fully_connected_op->compute[0].task_2d_tile_2d_with_id = (pthreadpool_task_2d_tile_2d_with_id_t) xnn_compute_hmp_gemm; } } else { - fully_connected_op->compute[0].type = xnn_parallelization_type_2d_tile_2d; + fully_connected_op->compute[0].type = + xnn_parallelization_type_2d_tile_2d_dynamic; if (dynamic_quantization) { - fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_dqgemm; + fully_connected_op->compute[0].task_2d_tile_2d_dynamic = + (pthreadpool_task_2d_tile_2d_dynamic_t)xnn_compute_dqgemm; } else if (is_qp8_ukernel) { - fully_connected_op->compute[0].task_2d_tile_2d = - (pthreadpool_task_2d_tile_2d_t)xnn_compute_qp8gemm; + fully_connected_op->compute[0].task_2d_tile_2d_dynamic = + (pthreadpool_task_2d_tile_2d_dynamic_t)xnn_compute_qp8gemm; } else { - fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm; + fully_connected_op->compute[0].task_2d_tile_2d_dynamic = + (pthreadpool_task_2d_tile_2d_dynamic_t)xnn_compute_gemm; } } #else - fully_connected_op->compute[0].type = xnn_parallelization_type_2d_tile_2d; - if (dynamic_quantization) { - fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_dqgemm; + fully_connected_op->compute[0].type = + xnn_parallelization_type_2d_tile_2d_dynamic; + if (dynamic_quantization) { + fully_connected_op->compute[0].task_2d_tile_2d_dynamic = + (pthreadpool_task_2d_tile_2d_dynamic_t)xnn_compute_dqgemm; } else if (is_qp8_ukernel) { - fully_connected_op->compute[0].task_2d_tile_2d = - (pthreadpool_task_2d_tile_2d_t)xnn_compute_qp8gemm; + fully_connected_op->compute[0].task_2d_tile_2d_dynamic = + (pthreadpool_task_2d_tile_2d_dynamic_t)xnn_compute_qp8gemm; } else { - fully_connected_op->compute[0].task_2d_tile_2d = (pthreadpool_task_2d_tile_2d_t) xnn_compute_gemm; + fully_connected_op->compute[0].task_2d_tile_2d_dynamic = + (pthreadpool_task_2d_tile_2d_dynamic_t)xnn_compute_gemm; } #endif fully_connected_op->compute[0].range[0] = batch_size; diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16-minmax-asm-amd64-avx512vnni.S index b0225fc5f4f..21ca0f84253 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x16-minmax-asm-amd64-avx512vnni.S @@ -177,7 +177,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x16c4__asm_amd64_avx512vnn vmovups 
zmmword ptr [rsp + 1040], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x32-minmax-asm-amd64-avx512vnni.S index f0cd38a4acc..a6ae6619c07 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x32-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-10x32-minmax-asm-amd64-avx512vnni.S @@ -177,7 +177,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_10x32c4__asm_amd64_avx512vnn vmovups zmmword ptr [rsp + 1040], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x16-minmax-asm-amd64-avx512vnni.S index 638c6dc5671..4d66ebfc7d2 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x16-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x16-minmax-asm-amd64-avx512vnni.S @@ -192,7 +192,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x16c4__asm_amd64_avx512vnn vmovups zmmword ptr [rsp + 1104], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x32-minmax-asm-amd64-avx512vnni.S index aad8d973fe7..a8bf41586bc 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x32-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-11x32-minmax-asm-amd64-avx512vnni.S @@ -192,7 +192,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_11x32c4__asm_amd64_avx512vnn vmovups zmmword ptr [rsp + 1104], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-aarch64-neondot-ld32.S index 41b1fcb9ef1..85a7c8a52be 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-aarch64-neondot-ld32.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-aarch64-neondot-ld32.S @@ -25,8 +25,8 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_aarch64_neondot_ outer_loop: - # Zero k counter. - eor x20, x20, x20 + # Initialize k counter. + mov x20, x2 # Initialize accumulators with k_sum * input zero point. ldr q10, [x24] ldp q2, q3, [x5, 0] @@ -38,15 +38,14 @@ outer_loop: add x5, x5, 64 inner_loop: - ldr d2, [x3, x20] + ldr s2, [x3], 4 ldp q6, q7, [x5], 32 ldp q8, q9, [x5], 32 sdot v12.4s, v6.16b, v2.4b[0] sdot v13.4s, v7.16b, v2.4b[0] sdot v14.4s, v8.16b, v2.4b[0] sdot v15.4s, v9.16b, v2.4b[0] - add x20, x20, 4 - cmp x2, x20 + subs x20, x20, 4 bne inner_loop # Convert from int32 to float. 
@@ -92,6 +91,7 @@ inner_loop: b.lo tail_8 stp q12, q13, [x6], 32 stp q14, q15, [x6], 32 + sub x3, x3, x2 sub x1, x1, 16 b.ne outer_loop diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S index 6d8dc3d5e2d..ac491a5dde1 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x16-minmax-asm-amd64-avx512vnni.S @@ -38,7 +38,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x16c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 464], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with k_sum * input zero point. vmovaps zmm6, [r9 + 0] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S index 40f098232ed..0bef4bbf883 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x32-minmax-asm-amd64-avx512vnni.S @@ -38,7 +38,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x32c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 464], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with k_sum * input zero point. vmovaps zmm6, [r9 + 0] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S index 864746630bd..93737fb8207 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x64-minmax-asm-amd64-avx512vnni.S @@ -38,7 +38,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x64c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 464], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with k_sum * input zero point. vmovaps zmm6, [r9 + 0] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8-minmax-asm-aarch64-neondot-ld32.S index 91aea6ddc51..f74e5b72625 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8-minmax-asm-aarch64-neondot-ld32.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-1x8-minmax-asm-aarch64-neondot-ld32.S @@ -25,8 +25,8 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_1x8c4__asm_aarch64_neondot_l outer_loop: - # Zero k counter. - eor x20, x20, x20 + # Initialize k counter. + mov x20, x2 # Initialize accumulators with k_sum * input zero point. ldr q10, [x24] ldp q2, q3, [x5, 0] @@ -35,12 +35,11 @@ outer_loop: add x5, x5, 32 inner_loop: - ldr d2, [x3, x20] + ldr s2, [x3], 4 ldp q6, q7, [x5], 32 sdot v12.4s, v6.16b, v2.4b[0] sdot v13.4s, v7.16b, v2.4b[0] - add x20, x20, 4 - cmp x2, x20 + subs x20, x20, 4 bne inner_loop # Convert from int32 to float. 
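In the qd8 kernels the count-down rewrite is the same, but the 4-byte `ldr s2, [x3], 4` now carries four consecutive int8 activations, which `sdot ... v2.4b[0]` dot-multiplies against four int8 weights per 32-bit output lane. In scalar C, one such lane update looks like the following sketch of the instruction's arithmetic (not the kernel itself):

    #include <stdint.h>

    /* One lane of `sdot vD.4s, vA.16b, vB.4b[0]`: a 4-way signed 8-bit
       dot product accumulated into a 32-bit lane. */
    static int32_t sdot_lane(int32_t acc, const int8_t a[4],
                             const int8_t b[4]) {
      for (int i = 0; i < 4; i++) {
        acc += (int32_t)a[i] * (int32_t)b[i];
      }
      return acc;
    }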
@@ -71,6 +70,7 @@ inner_loop: cmp x1, 8 b.lo tail_4 stp q12, q13, [x6], 32 + sub x3, x3, x2 sub x1, x1, 8 b.ne outer_loop diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-aarch64-neondot-ld32.S index 68d29de372d..466642b8f27 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-aarch64-neondot-ld32.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-aarch64-neondot-ld32.S @@ -32,8 +32,8 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_aarch64_neondot_ csel x13, x6, x13, LO outer_loop: - # Zero k counter. - eor x20, x20, x20 + # Initialize k counter. + mov x20, x2 # Initialize accumulators with k_sum * input zero point. ldp q10, q11, [x24] ldp q2, q3, [x5, 0] @@ -49,8 +49,8 @@ outer_loop: add x5, x5, 64 inner_loop: - ldr d2, [x3, x20] - ldr d3, [x9, x20] + ldr s2, [x3], 4 + ldr s3, [x9], 4 ldp q6, q7, [x5], 32 ldp q8, q9, [x5], 32 sdot v12.4s, v6.16b, v2.4b[0] @@ -61,8 +61,7 @@ inner_loop: sdot v17.4s, v8.16b, v3.4b[0] sdot v18.4s, v9.16b, v2.4b[0] sdot v19.4s, v9.16b, v3.4b[0] - add x20, x20, 4 - cmp x2, x20 + subs x20, x20, 4 bne inner_loop # Convert from int32 to float. @@ -134,6 +133,8 @@ inner_loop: stp q16, q18, [x6], 32 stp q13, q15, [x13], 32 stp q17, q19, [x13], 32 + sub x3, x3, x2 + sub x9, x9, x2 sub x1, x1, 16 b.ne outer_loop diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-amd64-avx512vnni.S index 34b63a18d39..a1d36fcaa02 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x16-minmax-asm-amd64-avx512vnni.S @@ -50,7 +50,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x16c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 528], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with k_sum * input zero point. vmovaps zmm6, [r9 + 0] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x32-minmax-asm-amd64-avx512vnni.S index 955f2135499..e925e817701 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x32-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x32-minmax-asm-amd64-avx512vnni.S @@ -50,7 +50,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x32c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 528], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with k_sum * input zero point. vmovaps zmm6, [r9 + 0] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x64-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x64-minmax-asm-amd64-avx512vnni.S index 9dc74973e55..c7f7e099b88 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x64-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x64-minmax-asm-amd64-avx512vnni.S @@ -50,7 +50,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x64c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 528], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with k_sum * input zero point. 
vmovaps zmm6, [r9 + 0] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8-minmax-asm-aarch64-neondot-ld32.S index 3e2a668453b..43b2d05e26e 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8-minmax-asm-aarch64-neondot-ld32.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-2x8-minmax-asm-aarch64-neondot-ld32.S @@ -32,8 +32,8 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_2x8c4__asm_aarch64_neondot_l csel x13, x6, x13, LO outer_loop: - # Zero k counter. - eor x20, x20, x20 + # Initialize k counter. + mov x20, x2 # Initialize accumulators with k_sum * input zero point. ldp q10, q11, [x24] ldp q2, q3, [x5, 0] @@ -44,15 +44,14 @@ outer_loop: add x5, x5, 32 inner_loop: - ldr d2, [x3, x20] - ldr d3, [x9, x20] + ldr s2, [x3], 4 + ldr s3, [x9], 4 ldp q6, q7, [x5], 32 sdot v12.4s, v6.16b, v2.4b[0] sdot v13.4s, v6.16b, v3.4b[0] sdot v14.4s, v7.16b, v2.4b[0] sdot v15.4s, v7.16b, v3.4b[0] - add x20, x20, 4 - cmp x2, x20 + subs x20, x20, 4 bne inner_loop # Convert from int32 to float. @@ -96,6 +95,8 @@ inner_loop: b.lo tail_4 stp q12, q14, [x6], 32 stp q13, q15, [x13], 32 + sub x3, x3, x2 + sub x9, x9, x2 sub x1, x1, 8 b.ne outer_loop diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-aarch64-neondot-ld32.S index 4720444c9cd..e0df3501360 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-aarch64-neondot-ld32.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-aarch64-neondot-ld32.S @@ -36,8 +36,8 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_aarch64_neondot_ csel x14, x13, x14, LS outer_loop: - # Zero k counter. - eor x20, x20, x20 + # Initialize k counter. + mov x20, x2 # Initialize accumulators with k_sum * input zero point. ldp q10, q11, [x24] ldr q10, [x24] @@ -58,9 +58,9 @@ outer_loop: add x5, x5, 64 inner_loop: - ldr d2, [x3, x20] - ldr d3, [x9, x20] - ldr d4, [x10, x20] + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 ldp q6, q7, [x5], 32 ldp q8, q9, [x5], 32 sdot v12.4s, v6.16b, v2.4b[0] @@ -75,8 +75,7 @@ inner_loop: sdot v21.4s, v9.16b, v2.4b[0] sdot v22.4s, v9.16b, v3.4b[0] sdot v23.4s, v9.16b, v4.4b[0] - add x20, x20, 4 - cmp x2, x20 + subs x20, x20, 4 bne inner_loop # Convert from int32 to float. @@ -174,6 +173,9 @@ inner_loop: stp q19, q22, [x13], 32 stp q14, q17, [x14], 32 stp q20, q23, [x14], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 sub x1, x1, 16 b.ne outer_loop diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-amd64-avx512vnni.S index 7d65da61eb9..ff0bedd1102 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x16-minmax-asm-amd64-avx512vnni.S @@ -62,7 +62,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x16c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 592], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with k_sum * input zero point. 
vmovaps zmm6, [r9 + 0] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x32-minmax-asm-amd64-avx512vnni.S index 836ea2aaa28..1231fb19cd4 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x32-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x32-minmax-asm-amd64-avx512vnni.S @@ -62,7 +62,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x32c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 592], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with k_sum * input zero point. vmovaps zmm6, [r9 + 0] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x64-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x64-minmax-asm-amd64-avx512vnni.S index 0382f92d220..6518bd7b0c9 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x64-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x64-minmax-asm-amd64-avx512vnni.S @@ -62,7 +62,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x64c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 592], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with k_sum * input zero point. vmovaps zmm6, [r9 + 0] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8-minmax-asm-aarch64-neondot-ld32.S index f947587a63e..dc270f1d7d7 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8-minmax-asm-aarch64-neondot-ld32.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-3x8-minmax-asm-aarch64-neondot-ld32.S @@ -36,8 +36,8 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_3x8c4__asm_aarch64_neondot_l csel x14, x13, x14, LS outer_loop: - # Zero k counter. - eor x20, x20, x20 + # Initialize k counter. + mov x20, x2 # Initialize accumulators with k_sum * input zero point. ldp q10, q11, [x24] ldr q10, [x24] @@ -51,9 +51,9 @@ outer_loop: add x5, x5, 32 inner_loop: - ldr d2, [x3, x20] - ldr d3, [x9, x20] - ldr d4, [x10, x20] + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 ldp q6, q7, [x5], 32 sdot v12.4s, v6.16b, v2.4b[0] sdot v13.4s, v6.16b, v3.4b[0] @@ -61,8 +61,7 @@ inner_loop: sdot v15.4s, v7.16b, v2.4b[0] sdot v16.4s, v7.16b, v3.4b[0] sdot v17.4s, v7.16b, v4.4b[0] - add x20, x20, 4 - cmp x2, x20 + subs x20, x20, 4 bne inner_loop # Convert from int32 to float. @@ -119,6 +118,9 @@ inner_loop: stp q12, q15, [x6], 32 stp q13, q16, [x13], 32 stp q14, q17, [x14], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 sub x1, x1, 8 b.ne outer_loop diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-aarch64-neondot-ld32.S index 8663d8c685c..6150aafa155 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-aarch64-neondot-ld32.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-aarch64-neondot-ld32.S @@ -42,8 +42,8 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_aarch64_neondot_ csel x15, x14, x15, LO outer_loop: - # Zero k counter. - eor x20, x20, x20 + # Initialize k counter. + mov x20, x2 # Initialize accumulators with k_sum * input zero point. 
ldp q10, q11, [x24] ldp q2, q3, [x5, 0] @@ -67,10 +67,10 @@ outer_loop: add x5, x5, 64 inner_loop: - ldr d2, [x3, x20] - ldr d3, [x9, x20] - ldr d4, [x10, x20] - ldr d5, [x11, x20] + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 + ldr s5, [x11], 4 ldp q6, q7, [x5], 32 ldp q8, q9, [x5], 32 sdot v12.4s, v6.16b, v2.4b[0] @@ -89,8 +89,7 @@ inner_loop: sdot v25.4s, v9.16b, v3.4b[0] sdot v26.4s, v9.16b, v4.4b[0] sdot v27.4s, v9.16b, v5.4b[0] - add x20, x20, 4 - cmp x2, x20 + subs x20, x20, 4 bne inner_loop # Convert from int32 to float. @@ -214,6 +213,10 @@ inner_loop: stp q22, q26, [x14], 32 stp q15, q19, [x15], 32 stp q23, q27, [x15], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 + sub x11, x11, x2 sub x1, x1, 16 b.ne outer_loop diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-amd64-avx512vnni.S index 02e801e11dc..088e175bf91 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x16-minmax-asm-amd64-avx512vnni.S @@ -74,7 +74,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x16c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 656], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with k_sum * input zero point. vmovaps zmm6, [r9 + 0] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x32-minmax-asm-amd64-avx512vnni.S index 3e7674a87fe..eed325b907c 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x32-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x32-minmax-asm-amd64-avx512vnni.S @@ -74,7 +74,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x32c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 656], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with k_sum * input zero point. vmovaps zmm6, [r9 + 0] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x64-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x64-minmax-asm-amd64-avx512vnni.S index 09d697c5aeb..709c05c9e3d 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x64-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x64-minmax-asm-amd64-avx512vnni.S @@ -74,7 +74,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x64c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 656], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with k_sum * input zero point. vmovaps zmm6, [r9 + 0] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8-minmax-asm-aarch64-neondot-ld32.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8-minmax-asm-aarch64-neondot-ld32.S index de28ec7033a..4150079a166 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8-minmax-asm-aarch64-neondot-ld32.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-4x8-minmax-asm-aarch64-neondot-ld32.S @@ -42,8 +42,8 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_4x8c4__asm_aarch64_neondot_l csel x15, x14, x15, LO outer_loop: - # Zero k counter. - eor x20, x20, x20 + # Initialize k counter. + mov x20, x2 # Initialize accumulators with k_sum * input zero point. 
ldp q10, q11, [x24] ldp q2, q3, [x5, 0] @@ -58,10 +58,10 @@ outer_loop: add x5, x5, 32 inner_loop: - ldr d2, [x3, x20] - ldr d3, [x9, x20] - ldr d4, [x10, x20] - ldr d5, [x11, x20] + ldr s2, [x3], 4 + ldr s3, [x9], 4 + ldr s4, [x10], 4 + ldr s5, [x11], 4 ldp q6, q7, [x5], 32 sdot v12.4s, v6.16b, v2.4b[0] sdot v13.4s, v6.16b, v3.4b[0] @@ -71,8 +71,7 @@ inner_loop: sdot v17.4s, v7.16b, v3.4b[0] sdot v18.4s, v7.16b, v4.4b[0] sdot v19.4s, v7.16b, v5.4b[0] - add x20, x20, 4 - cmp x2, x20 + subs x20, x20, 4 bne inner_loop # Convert from int32 to float. @@ -142,6 +141,10 @@ inner_loop: stp q13, q17, [x13], 32 stp q14, q18, [x14], 32 stp q15, q19, [x15], 32 + sub x3, x3, x2 + sub x9, x9, x2 + sub x10, x10, x2 + sub x11, x11, x2 sub x1, x1, 8 b.ne outer_loop diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16-minmax-asm-amd64-avx512vnni.S index 7e6dfb52b2b..87080cb6b9c 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x16-minmax-asm-amd64-avx512vnni.S @@ -86,7 +86,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x16c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 720], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with k_sum * input zero point. vmovaps zmm6, [r9 + 0] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x32-minmax-asm-amd64-avx512vnni.S index 505d355c8ab..1fe64f5ee56 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x32-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x32-minmax-asm-amd64-avx512vnni.S @@ -86,7 +86,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x32c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 720], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with k_sum * input zero point. vmovaps zmm6, [r9 + 0] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S index 68e56d0a1ba..0744f774016 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-5x64-minmax-asm-amd64-avx512vnni.S @@ -86,7 +86,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_5x64c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 720], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Initialize accumulators with k_sum * input zero point. vmovaps zmm6, [r9 + 0] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16-minmax-asm-amd64-avx512vnni.S index 651c87fedbc..161931d2168 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x16-minmax-asm-amd64-avx512vnni.S @@ -117,7 +117,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x16c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 784], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. 
mov rsi, [rsp - 128] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x32-minmax-asm-amd64-avx512vnni.S index 96605c933cd..392f5e81e14 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x32-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-6x32-minmax-asm-amd64-avx512vnni.S @@ -117,7 +117,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_6x32c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 784], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16-minmax-asm-amd64-avx512vnni.S index 6ce0f8014c4..522dd0d1ab4 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x16-minmax-asm-amd64-avx512vnni.S @@ -132,7 +132,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x16c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 848], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x32-minmax-asm-amd64-avx512vnni.S index 1f5c69e519c..dc51096698d 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x32-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-7x32-minmax-asm-amd64-avx512vnni.S @@ -132,7 +132,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_7x32c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 848], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16-minmax-asm-amd64-avx512vnni.S index 59ad3e9cbac..0f27cae0e2e 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x16-minmax-asm-amd64-avx512vnni.S @@ -147,7 +147,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x16c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 912], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x32-minmax-asm-amd64-avx512vnni.S index ce9a3bc9cef..80eea889ec2 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x32-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-8x32-minmax-asm-amd64-avx512vnni.S @@ -147,7 +147,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_8x32c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 912], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. 
mov rsi, [rsp - 128] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16-minmax-asm-amd64-avx512vnni.S index c778316dfdb..023f5837e7e 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x16-minmax-asm-amd64-avx512vnni.S @@ -162,7 +162,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x16c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 976], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x32-minmax-asm-amd64-avx512vnni.S b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x32-minmax-asm-amd64-avx512vnni.S index 642dcbecfac..217f56b415e 100644 --- a/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x32-minmax-asm-amd64-avx512vnni.S +++ b/src/qd8-f32-qc8w-gemm/gen/qd8-f32-qc8w-gemm-9x32-minmax-asm-amd64-avx512vnni.S @@ -162,7 +162,7 @@ BEGIN_FUNCTION xnn_qd8_f32_qc8w_gemm_minmax_ukernel_9x32c4__asm_amd64_avx512vnni vmovups zmmword ptr [rsp + 976], zmm6 outer_loop: - # Zero k counter. + # Initialize k counter. mov r11, 0 # Read a pointers from stack into GP registers. mov rsi, [rsp - 128] diff --git a/src/subgraph.c b/src/subgraph.c index 3c0c35c48bc..f9c3c620d65 100644 --- a/src/subgraph.c +++ b/src/subgraph.c @@ -931,9 +931,11 @@ bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph) break; case xnn_node_type_fully_connected: if (subgraph->values[node->inputs[0]].datatype == xnn_datatype_qdint8 || + subgraph->values[node->inputs[0]].datatype == + xnn_datatype_qduint8 || subgraph->values[node->inputs[0]].datatype == xnn_datatype_qpint8) { - // TODO(b/340399245) - Coerce any `qpint8` values back to `qdint8` for - // conversion to fp16. + // TODO(b/340399245) - Coerce any `qpint8` or `qduint8` values back to + // `qdint8` for conversion to fp16. subgraph->values[node->inputs[0]].datatype = xnn_datatype_qdint8; subgraph->values[node->outputs[0]].fp16_compatible = true; } else if (subgraph->values[node->inputs[0]].datatype == @@ -952,8 +954,16 @@ bool xnn_subgraph_rewrite_for_fp16(xnn_subgraph_t subgraph) subgraph->values[node->inputs[0]].fp16_compatible = true; subgraph->values[node->outputs[0]].fp16_compatible = true; } else { - xnn_log_warning("FP16 rewrite aborted: node #%" PRIu32 " (%s). Invalid compute type", - n, xnn_node_type_to_string(node->type)); + xnn_log_warning( + "FP16 rewrite aborted: node #%" PRIu32 + " (%s). Invalid compute type (input=%s, weights=%s, output=%s)", + n, xnn_node_type_to_string(node->type), + xnn_datatype_to_string( + subgraph->values[node->inputs[0]].datatype), + xnn_datatype_to_string( + subgraph->values[node->inputs[1]].datatype), + xnn_datatype_to_string( + subgraph->values[node->outputs[0]].datatype)); return false; } break; diff --git a/src/xnnpack/compute.h b/src/xnnpack/compute.h index 6fde480accd..b6a88d6d9a1 100644 --- a/src/xnnpack/compute.h +++ b/src/xnnpack/compute.h @@ -3,17 +3,16 @@ // This source code is licensed under the BSD-style license found in the // LICENSE file in the root directory of this source tree. 
-#pragma once +#ifndef THIRD_PARTY_XNNPACK_SRC_XNNPACK_COMPUTE_H_ +#define THIRD_PARTY_XNNPACK_SRC_XNNPACK_COMPUTE_H_ #include <stddef.h> #include <stdint.h> #include "xnnpack.h" #include "xnnpack/common.h" -#include "xnnpack/math.h" #include "xnnpack/microfnptr.h" #include "xnnpack/microparams.h" - #include "pthreadpool.h" enum xnn_parallelization_type { @@ -25,10 +24,13 @@ enum xnn_parallelization_type { xnn_parallelization_type_2d_with_thread, xnn_parallelization_type_2d_tile_1d, xnn_parallelization_type_2d_tile_2d, + xnn_parallelization_type_2d_tile_1d_dynamic, + xnn_parallelization_type_2d_tile_2d_dynamic, xnn_parallelization_type_3d, xnn_parallelization_type_3d_tile_1d, xnn_parallelization_type_3d_tile_1d_with_thread, xnn_parallelization_type_3d_tile_2d, + xnn_parallelization_type_3d_tile_2d_dynamic, xnn_parallelization_type_4d, xnn_parallelization_type_4d_tile_2d, xnn_parallelization_type_5d, @@ -54,10 +56,13 @@ struct compute_parameters { pthreadpool_task_2d_with_thread_t task_2d_with_thread; pthreadpool_task_2d_tile_1d_t task_2d_tile_1d; pthreadpool_task_2d_tile_2d_t task_2d_tile_2d; + pthreadpool_task_2d_tile_1d_dynamic_t task_2d_tile_1d_dynamic; + pthreadpool_task_2d_tile_2d_dynamic_t task_2d_tile_2d_dynamic; pthreadpool_task_3d_t task_3d; pthreadpool_task_3d_tile_1d_t task_3d_tile_1d; pthreadpool_task_3d_tile_1d_with_thread_t task_3d_tile_1d_with_thread; pthreadpool_task_3d_tile_2d_t task_3d_tile_2d; + pthreadpool_task_3d_tile_2d_dynamic_t task_3d_tile_2d_dynamic; pthreadpool_task_4d_t task_4d; pthreadpool_task_4d_tile_2d_t task_4d_tile_2d; pthreadpool_task_5d_t task_5d; @@ -67,13 +72,14 @@ struct compute_parameters { pthreadpool_task_2d_tile_1d_with_id_t task_2d_tile_1d_with_id; pthreadpool_task_2d_tile_2d_with_id_t task_2d_tile_2d_with_id; pthreadpool_task_3d_tile_1d_with_id_t task_3d_tile_1d_with_id; - pthreadpool_task_3d_tile_1d_with_id_with_thread_t task_3d_tile_1d_with_id_with_thread; + pthreadpool_task_3d_tile_1d_with_id_with_thread_t + task_3d_tile_1d_with_id_with_thread; pthreadpool_task_3d_tile_2d_with_id_t task_3d_tile_2d_with_id; pthreadpool_task_4d_tile_2d_with_id_t task_4d_tile_2d_with_id; #endif // XNN_MAX_UARCH_TYPES > 1 }; - // Offset of the invocation context w.r.t. xnn_operator.context - // Typically 0, but can be non-zero when an operator does multiple invocations. + // Offset of the invocation context w.r.t. xnn_operator.context. Typically 0, + // but can be non-zero when an operator does multiple invocations.
size_t context_offset; size_t range[6]; size_t tile[2]; @@ -91,97 +97,48 @@ struct transpose_context { }; XNN_PRIVATE void xnn_compute_transposec_2d( - const struct transpose_context* context, - size_t i, - size_t j, - size_t tile_i, + const struct transpose_context* context, size_t i, size_t j, size_t tile_i, size_t tile_j); XNN_PRIVATE void xnn_compute_transposec_3d( - const struct transpose_context* context, - size_t i, - size_t j, - size_t k, - size_t tile_j, - size_t tile_k); + const struct transpose_context* context, size_t i, size_t j, size_t k, + size_t tile_j, size_t tile_k); XNN_PRIVATE void xnn_compute_transposec_4d( - const struct transpose_context* context, - size_t i, - size_t j, - size_t k, - size_t l, - size_t tile_k, - size_t tile_l); + const struct transpose_context* context, size_t i, size_t j, size_t k, + size_t l, size_t tile_k, size_t tile_l); XNN_PRIVATE void xnn_compute_transposec_5d( - const struct transpose_context* context, - size_t i, - size_t j, - size_t k, - size_t l, - size_t m, - size_t tile_l, - size_t tile_m); + const struct transpose_context* context, size_t i, size_t j, size_t k, + size_t l, size_t m, size_t tile_l, size_t tile_m); XNN_PRIVATE void xnn_compute_transposec_6d( - const struct transpose_context* context, - size_t i, - size_t j, - size_t k, - size_t l, - size_t m, - size_t n, - size_t tile_m, - size_t tile_n); + const struct transpose_context* context, size_t i, size_t j, size_t k, + size_t l, size_t m, size_t n, size_t tile_m, size_t tile_n); XNN_PRIVATE void xnn_compute_transposev_2d( - const struct transpose_context* context, - size_t i, - size_t j, - size_t tile_i, + const struct transpose_context* context, size_t i, size_t j, size_t tile_i, size_t tile_j); XNN_PRIVATE void xnn_compute_transposev_3d( - const struct transpose_context* context, - size_t i, - size_t j, - size_t k, - size_t tile_j, - size_t tile_k); + const struct transpose_context* context, size_t i, size_t j, size_t k, + size_t tile_j, size_t tile_k); XNN_PRIVATE void xnn_compute_transposev_4d( - const struct transpose_context* context, - size_t i, - size_t j, - size_t k, - size_t l, - size_t tile_k, - size_t tile_l); + const struct transpose_context* context, size_t i, size_t j, size_t k, + size_t l, size_t tile_k, size_t tile_l); XNN_PRIVATE void xnn_compute_transposev_5d( - const struct transpose_context* context, - size_t i, - size_t j, - size_t k, - size_t l, - size_t m, - size_t tile_l, - size_t tile_m); + const struct transpose_context* context, size_t i, size_t j, size_t k, + size_t l, size_t m, size_t tile_l, size_t tile_m); XNN_PRIVATE void xnn_compute_transposev_6d( - const struct transpose_context* context, - size_t i, - size_t j, - size_t k, - size_t l, - size_t m, - size_t n, - size_t tile_m, - size_t tile_n); - -// Context for Packing Weights (packw) for GEMM microkernels in Group-OutputChannels-InputChannels layout. -// Kernel has shape GxNxK, bias has shape GxN. + const struct transpose_context* context, size_t i, size_t j, size_t k, + size_t l, size_t m, size_t n, size_t tile_m, size_t tile_n); + +// Context for Packing Weights (packw) for GEMM microkernels in +// Group-OutputChannels-InputChannels layout. Kernel has shape GxNxK, bias has +// shape GxN. struct packw_gemm_goi_context { // Number of input channels. size_t kc; @@ -211,26 +168,24 @@ struct packw_gemm_goi_context { size_t gc_stride; // Packing params passed to the packing microkernel. - const void *params; + const void* params; // Microkernel to perform packing.
xnn_packw_gemm_goi_ukernel_fn packw_gemm_goi; }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_packw_gemm_goi( - const struct packw_gemm_goi_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t n_block_start, - size_t n_block_size); - XNN_PRIVATE void xnn_compute_batched_packw_gemm_goi( - const struct packw_gemm_goi_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t n_block_start, - size_t n_block_size); +XNN_PRIVATE void xnn_compute_packw_gemm_goi( + const struct packw_gemm_goi_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t n_block_start, size_t n_block_size); +XNN_PRIVATE void xnn_compute_batched_packw_gemm_goi( + const struct packw_gemm_goi_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t n_block_start, size_t n_block_size); #endif -// Context for Packing Weights (packw) for GEMM microkernels in Groups-InputChannels-OutputChannels layout. -// Kernel has shape GxKxN, bias has shape GxN. +// Context for Packing Weights (packw) for GEMM microkernels in +// Groups-InputChannels-OutputChannels layout. Kernel has shape GxKxN, bias +// has shape GxN. struct packw_gemm_gio_context { // Number of input channels. size_t kc; @@ -266,15 +221,12 @@ struct packw_gemm_gio_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_packw_gemm_gio( - const struct packw_gemm_gio_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t n_block_start, - size_t n_block_size); - XNN_PRIVATE void xnn_compute_batched_packw_gemm_gio( - const struct packw_gemm_gio_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t n_block_start, - size_t n_block_size); +XNN_PRIVATE void xnn_compute_packw_gemm_gio( + const struct packw_gemm_gio_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t n_block_start, size_t n_block_size); +XNN_PRIVATE void xnn_compute_batched_packw_gemm_gio( + const struct packw_gemm_gio_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t n_block_start, size_t n_block_size); #endif // Context for Dense Matrix Multiplication. 
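// Note: the GEMM entry points in the hunk below take (mr_block, nr_block)
// tile coordinates; the *_tile_2d_dynamic parallelization types and task
// fields added earlier in this header are the hooks that let such 2-d tiles
// be scheduled dynamically across threads.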
@@ -342,118 +294,99 @@ struct gemm_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_grouped_gemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t group_index, - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size); - - XNN_PRIVATE void xnn_compute_grouped_qp8gemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t group_index, size_t mr_block_start, size_t nr_block_start, - size_t mr_block_size, size_t nr_block_size); - - XNN_PRIVATE void xnn_compute_dqgemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size); - - XNN_PRIVATE void xnn_compute_gemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size); - - XNN_PRIVATE void xnn_compute_qp8gemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t mr_block_start, size_t nr_block_start, size_t mr_block_size, - size_t nr_block_size); +XNN_PRIVATE void xnn_compute_grouped_gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t group_index, size_t mr_block_start, size_t nr_block_start, + size_t mr_block_size, size_t nr_block_size); + +XNN_PRIVATE void xnn_compute_grouped_qp8gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t group_index, size_t mr_block_start, size_t nr_block_start, + size_t mr_block_size, size_t nr_block_size); + +XNN_PRIVATE void xnn_compute_dqgemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t mr_block_start, size_t nr_block_start, size_t mr_block_size, + size_t nr_block_size); + +XNN_PRIVATE void xnn_compute_gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t mr_block_start, size_t nr_block_start, size_t mr_block_size, + size_t nr_block_size); + +XNN_PRIVATE void xnn_compute_qp8gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t mr_block_start, size_t nr_block_start, size_t mr_block_size, + size_t nr_block_size); #if XNN_MAX_UARCH_TYPES > 1 - XNN_PRIVATE void xnn_compute_hmp_grouped_gemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, - size_t group_index, - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size); - - XNN_PRIVATE void xnn_compute_hmp_grouped_qp8gemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, size_t group_index, size_t mr_block_start, - size_t nr_block_start, size_t mr_block_size, size_t nr_block_size); - - XNN_PRIVATE void xnn_compute_hmp_gemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size); - - XNN_PRIVATE void xnn_compute_hmp_dqgemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, - size_t mr_block_start, - size_t nr_block_start, - size_t mr_block_size, - size_t nr_block_size); - - XNN_PRIVATE void xnn_compute_hmp_qp8gemm( - const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start, - size_t mr_block_size, size_t nr_block_size); - #endif // XNN_MAX_UARCH_TYPES > 1 +XNN_PRIVATE void xnn_compute_hmp_grouped_gemm( + const struct 
gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + uint32_t uarch_index, size_t group_index, size_t mr_block_start, + size_t nr_block_start, size_t mr_block_size, size_t nr_block_size); + +XNN_PRIVATE void xnn_compute_hmp_grouped_qp8gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + uint32_t uarch_index, size_t group_index, size_t mr_block_start, + size_t nr_block_start, size_t mr_block_size, size_t nr_block_size); + +XNN_PRIVATE void xnn_compute_hmp_gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start, + size_t mr_block_size, size_t nr_block_size); + +XNN_PRIVATE void xnn_compute_hmp_dqgemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start, + size_t mr_block_size, size_t nr_block_size); + +XNN_PRIVATE void xnn_compute_hmp_qp8gemm( + const struct gemm_context context[restrict XNN_MIN_ELEMENTS(1)], + uint32_t uarch_index, size_t mr_block_start, size_t nr_block_start, + size_t mr_block_size, size_t nr_block_size); +#endif // XNN_MAX_UARCH_TYPES > 1 #endif - // Context for Sparse Matrix-Dense Matrix Multiplication. - // C [MxN] := A [MxK] * B [KxN] + bias [N] - // A and C are dense matrices with row-major storage, B is a sparse matrix. - struct spmm_context { - // N dimension of the B and C matrices. - // Corresponds to number of output channels in 1x1 convolution. - size_t n; - // M dimension of the A and C matrices, pre-scaled by sizeof(element - // size). Corresponds to the stride, in bytes, between adjacent rows of C - // matrix. - size_t scaled_m; - // Input matrix A. - const void* input; - // Packed bias elements and non-zero filter elements. - const void* nonzero_weights; - // Input pointer increments, in bytes, after each processed non-zero - // weight. - const int32_t* input_increments; - // Number of non-zero filter elements per each N (output channel) - // dimension. - const uint32_t* output_channel_nonzeros; - // Output matrix C. - void* output; - // Stride, in bytes, between matrices A corresponding to different images - // in batched 1x1 Convolution - size_t batched_input_stride; - // Stride, in bytes, between matrices C corresponding to different images - // in batched 1x1 Convolution - size_t batched_output_stride; - // Micro-kernel function pointer. - xnn_spmm_ukernel_fn ukernel; - // Output activation parameters. - union { - union xnn_f32_minmax_params f32; - } params; +// Context for Sparse Matrix-Dense Matrix Multiplication. +// C [MxN] := A [MxK] * B [KxN] + bias [N] +// A and C are dense matrices with row-major storage, B is a sparse matrix. +struct spmm_context { + // N dimension of the B and C matrices. + // Corresponds to number of output channels in 1x1 convolution. + size_t n; + // M dimension of the A and C matrices, pre-scaled by sizeof(element + // size). Corresponds to the stride, in bytes, between adjacent rows of C + // matrix. + size_t scaled_m; + // Input matrix A. + const void* input; + // Packed bias elements and non-zero filter elements. + const void* nonzero_weights; + // Input pointer increments, in bytes, after each processed non-zero + // weight. + const int32_t* input_increments; + // Number of non-zero filter elements per each N (output channel) + // dimension. + const uint32_t* output_channel_nonzeros; + // Output matrix C. 
+ void* output; + // Stride, in bytes, between matrices A corresponding to different images + // in batched 1x1 Convolution + size_t batched_input_stride; + // Stride, in bytes, between matrices C corresponding to different images + // in batched 1x1 Convolution + size_t batched_output_stride; + // Micro-kernel function pointer. + xnn_spmm_ukernel_fn ukernel; + // Output activation parameters. + union { + union xnn_f32_minmax_params f32; + } params; }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_spmm( +XNN_PRIVATE void xnn_compute_spmm( const struct spmm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t mr_block_start, - size_t mr_block_size); + size_t batch_index, size_t mr_block_start, size_t mr_block_size); #endif // Context for initializing the indirection buffer for conv2d igemm. @@ -652,26 +585,17 @@ struct subgemm_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_grouped_subgemm2d( - const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t group_index, - size_t subkernel_index, - size_t slice_y, - size_t slice_x_start, - size_t nc_block_start, - size_t slice_x_max, - size_t nc_block_size); - - XNN_PRIVATE void xnn_compute_subgemm2d( - const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t subkernel_index, - size_t slice_y, - size_t slice_x_start, - size_t nc_block_start, - size_t slice_x_max, - size_t nc_block_size); +XNN_PRIVATE void xnn_compute_grouped_subgemm2d( + const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t group_index, size_t subkernel_index, + size_t slice_y, size_t slice_x_start, size_t nc_block_start, + size_t slice_x_max, size_t nc_block_size); + +XNN_PRIVATE void xnn_compute_subgemm2d( + const struct subgemm_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t subkernel_index, size_t slice_y, + size_t slice_x_start, size_t nc_block_start, size_t slice_x_max, + size_t nc_block_size); #endif struct subconv_context { @@ -707,51 +631,33 @@ struct subconv_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_dq_zero_buffer_subconv( +XNN_PRIVATE void xnn_compute_dq_zero_buffer_subconv( const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)], size_t size); - XNN_PRIVATE void xnn_compute_grouped_subconv2d( - const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t group_index, - size_t subkernel_index, - size_t slice_y, - size_t slice_x_start, - size_t nr_block_start, - size_t slice_x_max, - size_t nr_block_size); - - XNN_PRIVATE void xnn_compute_grouped_dqsubconv2d( +XNN_PRIVATE void xnn_compute_grouped_subconv2d( const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t group_index, - size_t subkernel_index, - size_t slice_y, - size_t slice_x_start, - size_t nr_block_start, - size_t slice_x_max, + size_t batch_index, size_t group_index, size_t subkernel_index, + size_t slice_y, size_t slice_x_start, size_t nr_block_start, + size_t slice_x_max, size_t nr_block_size); + +XNN_PRIVATE void xnn_compute_grouped_dqsubconv2d( + const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t group_index, size_t subkernel_index, + size_t slice_y, size_t slice_x_start, size_t nr_block_start, + size_t slice_x_max, size_t nr_block_size); + +XNN_PRIVATE void xnn_compute_subconv2d( + const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)], + 
size_t batch_index, size_t subkernel_index, size_t slice_y, + size_t slice_x_start, size_t nr_block_start, size_t slice_x_max, size_t nr_block_size); - XNN_PRIVATE void xnn_compute_subconv2d( - const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t subkernel_index, - size_t slice_y, - size_t slice_x_start, - size_t nr_block_start, - size_t slice_x_max, - size_t nr_block_size); - - XNN_PRIVATE void xnn_compute_dqsubconv2d( - const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t subkernel_index, - size_t slice_y, - size_t slice_x_start, - size_t nr_block_start, - size_t slice_x_max, - size_t nr_block_size); +XNN_PRIVATE void xnn_compute_dqsubconv2d( + const struct subconv_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t subkernel_index, size_t slice_y, + size_t slice_x_start, size_t nr_block_start, size_t slice_x_max, + size_t nr_block_size); #endif struct conv2d_context { @@ -776,11 +682,9 @@ struct conv2d_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_conv2d_hwc2chw( - const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t output_y_start, - size_t output_y_slice); +XNN_PRIVATE void xnn_compute_conv2d_hwc2chw( + const struct conv2d_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t output_y_start, size_t output_y_slice); #endif // Context for initializing the indirection buffer for dwconv. @@ -839,23 +743,19 @@ struct dwconv_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_dwconv_indirection( - const struct dwconv_indirection_init_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t output_y_start, - size_t output_y_tile); - XNN_PRIVATE void xnn_compute_dwconv_unipass( - const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t output_y); - XNN_PRIVATE void xnn_compute_dwconv_multipass( - const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t output_y); - XNN_PRIVATE void xnn_compute_dwconv_multipass_with_thread( - const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t thread_index, - size_t batch_index, - size_t output_y); +XNN_PRIVATE void xnn_compute_dwconv_indirection( + const struct dwconv_indirection_init_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t output_y_start, size_t output_y_tile); +XNN_PRIVATE void xnn_compute_dwconv_unipass( + const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t output_y); +XNN_PRIVATE void xnn_compute_dwconv_multipass( + const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t output_y); +XNN_PRIVATE void xnn_compute_dwconv_multipass_with_thread( + const struct dwconv_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t thread_index, size_t batch_index, size_t output_y); #endif struct dwconv2d_context { @@ -880,10 +780,9 @@ struct dwconv2d_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_dwconv2d_chw( - const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t channel); +XNN_PRIVATE void xnn_compute_dwconv2d_chw( + const struct dwconv2d_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t channel); #endif struct max_pooling_context { @@ -909,10 +808,9 @@ struct max_pooling_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_max_pooling( - const struct 
max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t output_y); +XNN_PRIVATE void xnn_compute_max_pooling( + const struct max_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t output_y); #endif struct unpooling_context { @@ -932,10 +830,9 @@ struct unpooling_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_unpooling( - const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t input_y, - size_t input_x); +XNN_PRIVATE void xnn_compute_unpooling( + const struct unpooling_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t input_y, size_t input_x); #endif struct argmax_pooling_context { @@ -967,23 +864,19 @@ struct argmax_pooling_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_argmax_pooling_unipass( - const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t output_y); - - // Workspace sized based on batch size * output height. - XNN_PRIVATE void xnn_compute_argmax_pooling_multipass( - const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t output_y); - - // Workspace sized based on number of threads. - XNN_PRIVATE void xnn_compute_argmax_pooling_multipass_with_thread( - const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t thread_index, - size_t batch_index, - size_t output_y); +XNN_PRIVATE void xnn_compute_argmax_pooling_unipass( + const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t output_y); + +// Workspace sized based on batch size * output height. +XNN_PRIVATE void xnn_compute_argmax_pooling_multipass( + const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t output_y); + +// Workspace sized based on number of threads. +XNN_PRIVATE void xnn_compute_argmax_pooling_multipass_with_thread( + const struct argmax_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t thread_index, size_t batch_index, size_t output_y); #endif struct average_pooling_context { @@ -992,11 +885,14 @@ struct average_pooling_context { size_t input_offset; size_t input_batch_stride; - // Stride to get to the next y of input. Used when we have compressed indirection buffers (i.e. indirection buffers - // contain only pointers to the first row of input). + // Stride to get to the next y of input. Used when we have compressed + // indirection buffers (i.e. indirection buffers contain only pointers to the + // first row of input). size_t input_y_stride; - size_t indirect_top_height; // Number of output rows that form the top section of indirection buffer. - size_t indirect_bot_start; // Smallest output row y for the bottom section of indirection buffer. + size_t indirect_top_height; // Number of output rows that form the top + // section of indirection buffer. + size_t indirect_bot_start; // Smallest output row y for the bottom section of + // indirection buffer. 
void* output; size_t output_batch_stride; @@ -1022,21 +918,17 @@ struct average_pooling_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_average_pooling_unipass( - const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t output_y); - - XNN_PRIVATE void xnn_compute_average_pooling_multipass( - const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t output_y); - - XNN_PRIVATE void xnn_compute_average_pooling_multipass_with_thread( - const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t thread_index, - size_t batch_index, - size_t output_y); +XNN_PRIVATE void xnn_compute_average_pooling_unipass( + const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t output_y); + +XNN_PRIVATE void xnn_compute_average_pooling_multipass( + const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t output_y); + +XNN_PRIVATE void xnn_compute_average_pooling_multipass_with_thread( + const struct average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t thread_index, size_t batch_index, size_t output_y); #endif struct pixelwise_average_pooling_context { @@ -1045,11 +937,14 @@ struct pixelwise_average_pooling_context { size_t input_offset; size_t input_batch_stride; - // Stride to get to the next y of input. Used when we have compressed indirection buffers (i.e. indirection buffers - // contain only pointers to the first row of input). + // Stride to get to the next y of input. Used when we have compressed + // indirection buffers (i.e. indirection buffers contain only pointers to the + // first row of input). size_t input_y_stride; - size_t indirect_top_height; // Number of output rows that form the top section of indirection buffer. - size_t indirect_bot_start; // Smallest output row y for the bottom section of indirection buffer. + size_t indirect_top_height; // Number of output rows that form the top + // section of indirection buffer. + size_t indirect_bot_start; // Smallest output row y for the bottom section of + // indirection buffer. 
const void* pixelwise_buffer; size_t pixelwise_buffer_height_stride; @@ -1077,21 +972,20 @@ struct pixelwise_average_pooling_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_unipass( - const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t output_y); - - XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_multipass( - const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t output_y); - - XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_multipass_with_thread( - const struct pixelwise_average_pooling_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t thread_index, - size_t batch_index, - size_t output_y); +XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_unipass( + const struct pixelwise_average_pooling_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t output_y); + +XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_multipass( + const struct pixelwise_average_pooling_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t output_y); + +XNN_PRIVATE void xnn_compute_pixelwise_average_pooling_multipass_with_thread( + const struct pixelwise_average_pooling_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t thread_index, size_t batch_index, size_t output_y); #endif struct resize_bilinear_nhwc_indirection_init_context { @@ -1158,20 +1052,17 @@ struct resize_bilinear_chw_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_resize_bilinear_indirection( - const struct resize_bilinear_nhwc_indirection_init_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t output_y_start, - size_t output_y_tile); - XNN_PRIVATE void xnn_compute_resize_bilinear( - const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t pixel_start, - size_t pixel_range); - XNN_PRIVATE void xnn_compute_resize_bilinear_chw( - const struct resize_bilinear_chw_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t pixel_start, - size_t pixel_range); +XNN_PRIVATE void xnn_compute_resize_bilinear_indirection( + const struct resize_bilinear_nhwc_indirection_init_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t output_y_start, size_t output_y_tile); +XNN_PRIVATE void xnn_compute_resize_bilinear( + const struct resize_bilinear_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t pixel_start, size_t pixel_range); +XNN_PRIVATE void xnn_compute_resize_bilinear_chw( + const struct resize_bilinear_chw_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t pixel_start, size_t pixel_range); #endif struct elementwise_binary_context { @@ -1188,24 +1079,30 @@ struct elementwise_binary_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_elementwise_binary_1d_tile( - const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t offset, size_t tile); - XNN_PRIVATE void xnn_compute_elementwise_binary_1d( - const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t i); - XNN_PRIVATE void xnn_compute_elementwise_binary_2d( - const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t i, size_t j); - XNN_PRIVATE void xnn_compute_elementwise_binary_3d( - const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t i, size_t j, size_t k); - XNN_PRIVATE void 
xnn_compute_elementwise_binary_4d( - const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t i, size_t j, size_t k, size_t l); - XNN_PRIVATE void xnn_compute_elementwise_binary_5d( - const struct elementwise_binary_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t i, size_t j, size_t k, size_t l, size_t m); +XNN_PRIVATE void xnn_compute_elementwise_binary_1d_tile( + const struct elementwise_binary_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t offset, size_t tile); +XNN_PRIVATE void xnn_compute_elementwise_binary_1d( + const struct elementwise_binary_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t i); +XNN_PRIVATE void xnn_compute_elementwise_binary_2d( + const struct elementwise_binary_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t i, size_t j); +XNN_PRIVATE void xnn_compute_elementwise_binary_3d( + const struct elementwise_binary_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t i, size_t j, size_t k); +XNN_PRIVATE void xnn_compute_elementwise_binary_4d( + const struct elementwise_binary_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t i, size_t j, size_t k, size_t l); +XNN_PRIVATE void xnn_compute_elementwise_binary_5d( + const struct elementwise_binary_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t i, size_t j, size_t k, size_t l, size_t m); #endif struct lut_strided_context { @@ -1219,9 +1116,9 @@ struct lut_strided_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_lut_strided( - const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index); +XNN_PRIVATE void xnn_compute_lut_strided( + const struct lut_strided_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index); #endif struct lut_contiguous_context { @@ -1234,10 +1131,9 @@ struct lut_contiguous_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_lut_contiguous( - const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t offset, - size_t size); +XNN_PRIVATE void xnn_compute_lut_contiguous( + const struct lut_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t offset, size_t size); #endif struct univector_strided_context { @@ -1251,10 +1147,10 @@ struct univector_strided_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_univector_strided( - const struct univector_strided_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t batch_range); +XNN_PRIVATE void xnn_compute_univector_strided( + const struct univector_strided_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t batch_range); #endif struct univector_contiguous_context { @@ -1267,10 +1163,10 @@ struct univector_contiguous_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_univector_contiguous( - const struct univector_contiguous_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t offset, - size_t size); +XNN_PRIVATE void xnn_compute_univector_contiguous( + const struct univector_contiguous_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t offset, size_t size); #endif struct reduce_context { @@ -1294,27 +1190,21 @@ struct reduce_context { }; #ifndef __cplusplus -// Compute contiguous reduction over the 1st, 3rd and 5th dimensions of the input -// tensor. 
- XNN_PRIVATE void xnn_compute_contiguous_reduce( - const struct reduce_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t output_idx0, - size_t output_idx1, - size_t output_idx2, - size_t output1_block_size, - size_t output2_block_size); +// Compute contiguous reduction over the 1st, 3rd and 5th dimensions of the +// input tensor. +XNN_PRIVATE void xnn_compute_contiguous_reduce( + const struct reduce_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t output_idx0, size_t output_idx1, size_t output_idx2, + size_t output1_block_size, size_t output2_block_size); #endif #ifndef __cplusplus -// Compute discontiguous reduction over the 0st, 2rd and 4th dimensions of the input -// tensor. - XNN_PRIVATE void xnn_compute_discontiguous_reduce( - const struct reduce_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t output_idx0, - size_t output_idx1, - size_t output_idx2, - size_t output1_block_size, - size_t output2_block_size); +// Compute discontiguous reduction over the 0th, 2nd and 4th dimensions of the +// input tensor. +XNN_PRIVATE void xnn_compute_discontiguous_reduce( + const struct reduce_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t output_idx0, size_t output_idx1, size_t output_idx2, + size_t output1_block_size, size_t output2_block_size); #endif struct vmulcaddc_context { @@ -1332,10 +1222,9 @@ }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_vmulcaddc( - const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_start, - size_t batch_size); +XNN_PRIVATE void xnn_compute_vmulcaddc( + const struct vmulcaddc_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_start, size_t batch_size); #endif struct pad_context { @@ -1353,9 +1242,9 @@ }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_pad_5d( - const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t i, size_t j, size_t k, size_t l, size_t m); +XNN_PRIVATE void xnn_compute_pad_5d( + const struct pad_context context[restrict XNN_MIN_ELEMENTS(1)], size_t i, + size_t j, size_t k, size_t l, size_t m); #endif struct slice_context { @@ -1370,21 +1259,20 @@ }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_slice_1d( - const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t i); - XNN_PRIVATE void xnn_compute_slice_2d( - const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t i, size_t j); - XNN_PRIVATE void xnn_compute_slice_3d( - const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t i, size_t j, size_t k); - XNN_PRIVATE void xnn_compute_slice_4d( - const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t i, size_t j, size_t k, size_t l); - XNN_PRIVATE void xnn_compute_slice_5d( - const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t i, size_t j, size_t k, size_t l, size_t m); +XNN_PRIVATE void xnn_compute_slice_1d( + const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)], size_t i); +XNN_PRIVATE void xnn_compute_slice_2d( + const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)], size_t i, + size_t j); +XNN_PRIVATE void xnn_compute_slice_3d( + const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)], size_t i, + size_t j, size_t k); +XNN_PRIVATE void xnn_compute_slice_4d( + const struct slice_context context[restrict XNN_MIN_ELEMENTS(1)], size_t i, + size_t j, size_t k, size_t l); +XNN_PRIVATE void xnn_compute_slice_5d( + const struct slice_context context[restrict
XNN_MIN_ELEMENTS(1)], size_t i, + size_t j, size_t k, size_t l, size_t m); #endif struct f16_qd8_convert_context { @@ -1420,81 +1308,79 @@ struct f32_qd8_convert_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_f16_qd8_convert( - const struct f16_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index); - - XNN_PRIVATE void xnn_compute_f16_qdu8_convert( - const struct f16_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index); - - XNN_PRIVATE void xnn_compute_f32_qd8_convert( - const struct f32_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index); - - XNN_PRIVATE void xnn_compute_f32_qdu8_convert( - const struct f32_qd8_convert_context - context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index); - - XNN_PRIVATE void xnn_compute_pad_qd8_params( - const struct f32_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index); +XNN_PRIVATE void xnn_compute_f16_qd8_convert( + const struct f16_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index); + +XNN_PRIVATE void xnn_compute_f16_qdu8_convert( + const struct f16_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index); + +XNN_PRIVATE void xnn_compute_f32_qd8_convert( + const struct f32_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index); + +XNN_PRIVATE void xnn_compute_f32_qdu8_convert( + const struct f32_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index); + +XNN_PRIVATE void xnn_compute_pad_qd8_params( + const struct f32_qd8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index); #endif - struct x32_pack_lh_context { - size_t m; - size_t k; - size_t mr; - size_t kr; - size_t sr; - const float* XNN_RESTRICT lhs; - size_t lhs_stride; - float* XNN_RESTRICT lhs_packed; - xnn_x32_pack_lh_ukernel_fn pack_lh_ukernel; - }; +struct x32_pack_lh_context { + size_t m; + size_t k; + size_t mr; + size_t kr; + size_t sr; + const float* XNN_RESTRICT lhs; + size_t lhs_stride; + float* XNN_RESTRICT lhs_packed; + xnn_x32_pack_lh_ukernel_fn pack_lh_ukernel; +}; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_x32_pack_lh( - const struct x32_pack_lh_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t m_idx_start, size_t tile); +XNN_PRIVATE void xnn_compute_x32_pack_lh( + const struct x32_pack_lh_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t m_idx_start, size_t tile); #endif - struct f32_qp8_convert_context { - size_t m; - size_t k; - size_t mr; - size_t kr; - size_t sr; - size_t group_stride; - const float* XNN_RESTRICT lhs; - size_t lhs_stride; - int8_t* XNN_RESTRICT lhs_packed; - xnn_x8_packq_f32qp8_ukernel_fn packq_ukernel; - }; +struct f32_qp8_convert_context { + size_t m; + size_t k; + size_t mr; + size_t kr; + size_t sr; + size_t group_stride; + const float* XNN_RESTRICT lhs; + size_t lhs_stride; + int8_t* XNN_RESTRICT lhs_packed; + xnn_x8_packq_f32qp8_ukernel_fn packq_ukernel; +}; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_f32_qp8_convert( - const struct f32_qp8_convert_context - context[restrict XNN_MIN_ELEMENTS(1)], - size_t group_idx, size_t m_idx_start, size_t m_tile); +XNN_PRIVATE void xnn_compute_f32_qp8_convert( + const struct f32_qp8_convert_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t group_idx, size_t m_idx_start, size_t m_tile); #endif - struct u8_softmax_context { - size_t n; - const uint8_t* x; - size_t x_stride; - const uint32_t* t; - uint8_t* y; - size_t 
y_stride; - xnn_u8_rmax_ukernel_fn rmax_ukernel; - xnn_u8_lut32norm_ukernel_fn lut_norm_ukernel; - }; +struct u8_softmax_context { + size_t n; + const uint8_t* x; + size_t x_stride; + const uint32_t* t; + uint8_t* y; + size_t y_stride; + xnn_u8_rmax_ukernel_fn rmax_ukernel; + xnn_u8_lut32norm_ukernel_fn lut_norm_ukernel; +}; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_u8_softmax( - const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index); +XNN_PRIVATE void xnn_compute_u8_softmax( + const struct u8_softmax_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index); #endif typedef void (*xnn_compute_reciprocal_fn)(const void* input, void* output); @@ -1524,9 +1410,10 @@ struct floating_point_softmax_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_floating_point_softmax( - const struct floating_point_softmax_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index); +XNN_PRIVATE void xnn_compute_floating_point_softmax( + const struct floating_point_softmax_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index); #endif struct rope_context { @@ -1544,11 +1431,9 @@ struct rope_context { }; #ifndef __cplusplus - XNN_PRIVATE void xnn_compute_rope( - const struct rope_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t head_index, - size_t sequence_index); +XNN_PRIVATE void xnn_compute_rope( + const struct rope_context context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t head_index, size_t sequence_index); #endif struct attention_logits_cap { @@ -1655,8 +1540,8 @@ struct scaled_dot_product_attention_context { // - packed keys // - packed values // - output of Q * K (known as logits) - // These are the offsets into the workspace that can be used to read/write the intermediates. - // These are set during reshape, and then used during setup. + // These are the offsets into the workspace that can be used to read/write the + // intermediates. These are set during reshape, and then used during setup. size_t scaled_query_offset; size_t packed_k_offset; size_t packed_v_offset; @@ -1664,37 +1549,32 @@ struct scaled_dot_product_attention_context { }; #ifndef __cplusplus - // We have 4 variations of compute scaled dot product attention: - // 1. micro-architecture aware and not micro-architecture aware - // 2. whether the workspace size is based on batch_size or number of heads. - // The workspace size is chosen based on which one requires a smaller memory allocation for workspace. - // Batch size (times query heads and query tokens) is compared to number of threads (times MR). 
- XNN_PRIVATE void xnn_compute_scaled_dot_product_attention( - const struct scaled_dot_product_attention_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t batch_index, - size_t head_index, - size_t tokens_start, - size_t tokens_block_size); - XNN_PRIVATE void xnn_compute_scaled_dot_product_attention_with_thread( - const struct scaled_dot_product_attention_context context[restrict XNN_MIN_ELEMENTS(1)], - size_t thread_index, - size_t batch_index, - size_t head_index, - size_t tokens_start, - size_t tokens_block_size); - XNN_PRIVATE void xnn_compute_hmp_scaled_dot_product_attention( - const struct scaled_dot_product_attention_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, - size_t batch_index, - size_t head_index, - size_t tokens_start, - size_t tokens_block_size); - XNN_PRIVATE void xnn_compute_hmp_scaled_dot_product_attention_with_thread( - const struct scaled_dot_product_attention_context context[restrict XNN_MIN_ELEMENTS(1)], - uint32_t uarch_index, - size_t thread_index, - size_t batch_index, - size_t head_index, - size_t tokens_start, - size_t tokens_block_size); +// We have 4 variations of compute scaled dot product attention: +// 1. whether the kernel is micro-architecture aware or not +// 2. whether the workspace size is based on batch_size or number of heads. +// The workspace size is chosen based on which one requires a smaller memory +// allocation for workspace. Batch size (times query heads and query tokens) is +// compared to number of threads (times MR). +XNN_PRIVATE void xnn_compute_scaled_dot_product_attention( + const struct scaled_dot_product_attention_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t batch_index, size_t head_index, size_t tokens_start, + size_t tokens_block_size); +XNN_PRIVATE void xnn_compute_scaled_dot_product_attention_with_thread( + const struct scaled_dot_product_attention_context + context[restrict XNN_MIN_ELEMENTS(1)], + size_t thread_index, size_t batch_index, size_t head_index, + size_t tokens_start, size_t tokens_block_size); +XNN_PRIVATE void xnn_compute_hmp_scaled_dot_product_attention( + const struct scaled_dot_product_attention_context + context[restrict XNN_MIN_ELEMENTS(1)], + uint32_t uarch_index, size_t batch_index, size_t head_index, + size_t tokens_start, size_t tokens_block_size); +XNN_PRIVATE void xnn_compute_hmp_scaled_dot_product_attention_with_thread( + const struct scaled_dot_product_attention_context + context[restrict XNN_MIN_ELEMENTS(1)], + uint32_t uarch_index, size_t thread_index, size_t batch_index, + size_t head_index, size_t tokens_start, size_t tokens_block_size); #endif + +#endif // THIRD_PARTY_XNNPACK_SRC_XNNPACK_COMPUTE_H_ diff --git a/test/BUILD.bazel b/test/BUILD.bazel index c95c2c9ebb7..608ef462d97 100644 --- a/test/BUILD.bazel +++ b/test/BUILD.bazel @@ -976,7 +976,7 @@ xnnpack_unit_test( xnnpack_unit_test( name = "qs8_qc8w_igemm_minmax_fp32_test", - timeout = "moderate", + timeout = "long", srcs = [ "qs8-qc8w-igemm-minmax-fp32.cc", "qs8-qc8w-igemm-minmax-fp32-2.cc", diff --git a/test/fully-connected.cc b/test/fully-connected.cc index 55de4b9e199..8871c3e8755 100644 --- a/test/fully-connected.cc +++ b/test/fully-connected.cc @@ -2855,7 +2855,7 @@ TEST_F(FullyConnectedTestQD8F16QC8W, return xnn_quantization_params{w8dist(rng), f32dist(rng)}; }); std::generate(convert_input.begin(), convert_input.end(), - [&]() { return f32dist(rng); }); + [&]() { return xnn_float16_from_float(f32dist(rng)); }); const float output_min = -std::numeric_limits<float>::infinity(); const float output_max
= std::numeric_limits<float>::infinity(); diff --git a/test/reduce-nd.cc b/test/reduce-nd.cc index 2e58d3d5cff..cd165cccf94 100644 --- a/test/reduce-nd.cc +++ b/test/reduce-nd.cc @@ -193,7 +193,7 @@ class ReduceOperatorTester { using StorageType = float; using AccumulatorType = double; - static double GetTolerance() { return 3e-6; } + static double GetTolerance() { return 5e-6; } static xnn_datatype GetXNNDatatype() { return xnn_datatype_fp32; }; static std::uniform_real_distribution<float> BuildRngDistribution() { diff --git a/test/workspace.cc b/test/workspace.cc index 6c0040fe902..5d565cb0eef 100644 --- a/test/workspace.cc +++ b/test/workspace.cc @@ -876,6 +876,7 @@ TEST(WORKSPACE, internally_allocated_dynamic_quantization_parameters) const xnn_value* value = &runtime->values[i]; switch (value->datatype) { case xnn_datatype_qdint8: + case xnn_datatype_qduint8: ASSERT_NE(value->quantization.dynamic_params, nullptr); XNN_FALLTHROUGH; case xnn_datatype_qpint8: diff --git a/tools/generate-gemm-test.py b/tools/generate-gemm-test.py index bd7ffd8794b..7a31c51f278 100755 --- a/tools/generate-gemm-test.py +++ b/tools/generate-gemm-test.py @@ -960,7 +960,7 @@ def main(args): bench_outputs += """\n #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif """ diff --git a/tools/generate-rdsum-benchmark.py b/tools/generate-rdsum-benchmark.py index 0f62538396c..fcfecd882d5 100755 --- a/tools/generate-rdsum-benchmark.py +++ b/tools/generate-rdsum-benchmark.py @@ -130,7 +130,7 @@ def main(args): # Footer with `main` function. benches += "\n\n" + """\ #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif """ diff --git a/tools/generate-spmm-test.py b/tools/generate-spmm-test.py index 9a5f3f14836..8b9534274a5 100755 --- a/tools/generate-spmm-test.py +++ b/tools/generate-spmm-test.py @@ -520,7 +520,7 @@ def main(args): bench_outputs += """\n #ifndef XNNPACK_BENCHMARK_NO_MAIN -BENCHMARK_MAIN(); +XNN_BENCHMARK_MAIN(); #endif """ for output_name in options.output_test: