From 2dda2c26d6fd1cbc68108292d192854063d61797 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Tue, 6 Jun 2023 09:45:40 +0300 Subject: [PATCH 01/44] +add AVX2 optimizations of function DescrIntDecode16f. --- src/Simd/SimdAvx2DescrInt.cpp | 74 +++++++++++++++++++++++++++++++++++ 1 file changed, 74 insertions(+) diff --git a/src/Simd/SimdAvx2DescrInt.cpp b/src/Simd/SimdAvx2DescrInt.cpp index 36ddf2db5b..3a6baf81d7 100644 --- a/src/Simd/SimdAvx2DescrInt.cpp +++ b/src/Simd/SimdAvx2DescrInt.cpp @@ -363,6 +363,77 @@ namespace Simd //------------------------------------------------------------------------------------------------- + static void Decode16f6(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _shift = _mm256_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16) + { + __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, C6_SHFL), C6_MULLO), 10); + _mm_storeu_si128((__m128i*)dst + 0, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift), 0)); + _mm_storeu_si128((__m128i*)dst + 1, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift), 0)); + src += 12; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s6 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s6, Sse41::C6_SHFL0), Sse41::C6_MULLO), 10); + _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift), 0)); + src += 6; + dst += 8; + } + } + + static void Decode16f7(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _shift = _mm256_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16) + { + __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, C7_SHFL), C7_MULLO), 9); + _mm_storeu_si128((__m128i*)dst + 0, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift), 0)); + _mm_storeu_si128((__m128i*)dst + 1, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift), 0)); + src += 14; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s7 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s7, Sse41::C7_SHFL0), Sse41::C7_MULLO), 9); + _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift), 0)); + src += 7; + dst += 8; + } + } + + static void Decode16f8(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _shift = _mm256_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16) + { + __m128i u8 = _mm_loadu_si128((__m128i*)(src + i)); + _mm_storeu_si128((__m128i*)(dst + i) + 0, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(u8)), _scale, _shift), 0)); + _mm_storeu_si128((__m128i*)(dst + i) + 1, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_srli_si128(u8, 8))), _scale, _shift), 0)); + } + for (; i < size; i += 8) + { + __m256 _src = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)(src + i)))); + _mm_storeu_si128((__m128i*)(dst + i), _mm256_cvtps_ph(_mm256_fmadd_ps(_src, _scale, _shift), 0)); + } + } + + //------------------------------------------------------------------------------------------------- + template int32_t Correlation(const uint8_t* a, const uint8_t* b, size_t size); template<> int32_t Correlation<6>(const uint8_t* a, const uint8_t* b, size_t size) @@ -796,6 +867,7 @@ namespace Simd _encode32f = Encode32f6; _encode16f = Encode16f6; _decode32f = Decode32f6; + _decode16f = Decode16f6; _cosineDistance = Avx2::CosineDistance<6>; _macroCosineDistances = Avx2::MacroCosineDistances<6>; break; @@ -805,6 +877,7 @@ namespace Simd _encode32f = Encode32f7; _encode16f = Encode16f7; _decode32f = Decode32f7; + _decode16f = Decode16f7; _cosineDistance = Avx2::CosineDistance<7>; _macroCosineDistances = Avx2::MacroCosineDistances<7>; break; @@ -814,6 +887,7 @@ namespace Simd _encode32f = Encode32f8; _encode16f = Encode16f8; _decode32f = Decode32f8; + _decode16f = Decode16f8; _cosineDistance = Avx2::CosineDistance<8>; _macroCosineDistances = Avx2::MacroCosineDistances<8>; break; From 4941316c9e6b9584c1d14f85aed71c1863a88f9b Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Tue, 6 Jun 2023 16:47:48 +0300 Subject: [PATCH 02/44] +add AVX-512BW optimizations of functions DescrIntEncode16f, DescrIntDecode16f. --- src/Simd/SimdAvx512bwDescrInt.cpp | 298 ++++++++++++++++++++++-------- src/Simd/SimdDescrIntCommon.h | 51 +++++ 2 files changed, 276 insertions(+), 73 deletions(-) diff --git a/src/Simd/SimdAvx512bwDescrInt.cpp b/src/Simd/SimdAvx512bwDescrInt.cpp index 16766c0769..93bde177f0 100644 --- a/src/Simd/SimdAvx512bwDescrInt.cpp +++ b/src/Simd/SimdAvx512bwDescrInt.cpp @@ -59,49 +59,57 @@ namespace Simd //------------------------------------------------------------------------------------------------- - SIMD_INLINE __m512i Encode32f(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) + static void MinMax16f(const uint16_t* src, size_t size, float& min, float& max) + { + assert(size % 8 == 0); + __m512 _min = _mm512_set1_ps(FLT_MAX); + __m512 _max = _mm512_set1_ps(-FLT_MAX); + size_t i = 0, sizeF = AlignLo(size, F); + for (; i < sizeF; i += F) + { + __m512 _src = _mm512_cvtph_ps(_mm256_loadu_si256((__m256i*)(src + i))); + _min = _mm512_min_ps(_src, _min); + _max = _mm512_max_ps(_src, _max); + } + for (; i < size; i += 8) + { + __m512 _src = _mm512_cvtph_ps(_mm256_maskz_loadu_epi16(0xFF, src + i)); + _min = _mm512_mask_min_ps(_min, 0xFF, _src, _min); + _max = _mm512_mask_max_ps(_max, 0xFF, _src, _max); + } + MinVal32f(_min, min); + MaxVal32f(_max, max); + } + + //------------------------------------------------------------------------------------------------- + + SIMD_INLINE __m512i Encode32f(__m512 src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) { - __m512i value = _mm512_cvtps_epi32(_mm512_mul_ps(_mm512_sub_ps(_mm512_maskz_loadu_ps(mask, src), min), scale)); + __m512i value = _mm512_cvtps_epi32(_mm512_mul_ps(_mm512_sub_ps(src, min), scale)); sum = _mm512_add_epi32(value, sum); sqsum = _mm512_add_epi32(_mm512_madd_epi16(value, value), sqsum); return value; } + SIMD_INLINE __m512i Encode32f(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) + { + return Encode32f(_mm512_maskz_loadu_ps(mask, src), scale, min, sum, sqsum); + } + static SIMD_INLINE __m128i Encode32f6x2(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) { - static const __m256i SHIFT = SIMD_MM256_SETR_EPI16(256, 64, 16, 4, 256, 64, 16, 4, 256, 64, 16, 4, 256, 64, 16, 4); - static const __m256i SHFL0 = SIMD_MM256_SETR_EPI8( - 0x1, 0x3, 0x5, 0x9, 0xB, 0xD, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, 0x1, 0x3, 0x5, 0x9, 0xB, 0xD, -1, -1, -1, -1); - static const __m256i SHFL1 = SIMD_MM256_SETR_EPI8( - 0x2, 0x4, 0x6, 0xA, 0xC, 0xE, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, 0x2, 0x4, 0x6, 0xA, 0xC, 0xE, -1, -1, -1, -1); __m512i i0 = Encode32f(src, scale, min, sum, sqsum, mask); - __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), SHIFT); - __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, SHFL0), _mm256_shuffle_epi8(s0, SHFL1)); + __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), Avx2::E6_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, Avx2::E6_SHFL0), _mm256_shuffle_epi8(s0, Avx2::E6_SHFL1)); return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); } static SIMD_INLINE __m256i Encode32f6x4(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) { - static const __m512i SHIFT = SIMD_MM512_SETR_EPI16( - 256, 64, 16, 4, 256, 64, 16, 4, 256, 64, 16, 4, 256, 64, 16, 4, - 256, 64, 16, 4, 256, 64, 16, 4, 256, 64, 16, 4, 256, 64, 16, 4); - static const __m512i SHFL0 = SIMD_MM512_SETR_EPI8( - -1, -1, -1, -1, 0x1, 0x3, 0x5, 0x9, 0xB, 0xD, -1, -1, -1, -1, -1, -1, - 0x1, 0x3, 0x5, 0x9, 0xB, 0xD, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x1, 0x3, 0x5, 0x9, 0xB, 0xD, - -1, -1, -1, -1, -1, -1, 0x1, 0x3, 0x5, 0x9, 0xB, 0xD, -1, -1, -1, -1); - static const __m512i SHFL1 = SIMD_MM512_SETR_EPI8( - -1, -1, -1, -1, 0x2, 0x4, 0x6, 0xA, 0xC, 0xE, -1, -1, -1, -1, -1, -1, - 0x2, 0x4, 0x6, 0xA, 0xC, 0xE, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x2, 0x4, 0x6, 0xA, 0xC, 0xE, - -1, -1, -1, -1, -1, -1, 0x2, 0x4, 0x6, 0xA, 0xC, 0xE, -1, -1, -1, -1); - static const __m512i PERM = SIMD_MM512_SETR_EPI64(0, 2, 1, 3, 4, 6, 5, 7); __m512i i0 = Encode32f(src + 0 * F, scale, min, sum, sqsum); __m512i i1 = Encode32f(src + 1 * F, scale, min, sum, sqsum); - __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(PERM, _mm512_packus_epi32(i0, i1)), SHIFT); - __m512i e0 = _mm512_or_si512(_mm512_shuffle_epi8(s0, SHFL0), _mm512_shuffle_epi8(s0, SHFL1)); + __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(EX_PERM, _mm512_packus_epi32(i0, i1)), E6_MULLO); + __m512i e0 = _mm512_or_si512(_mm512_shuffle_epi8(s0, E6_SHFL0), _mm512_shuffle_epi8(s0, E6_SHFL1)); return _mm256_or_si256(_mm512_castsi512_si256(e0), _mm512_extracti32x8_epi32(e0, 1)); } @@ -125,39 +133,18 @@ namespace Simd static SIMD_INLINE __m128i Encode32f7x2(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) { - static const __m256i SHIFT = SIMD_MM256_SETR_EPI16(256, 128, 64, 32, 16, 8, 4, 2, 256, 128, 64, 32, 16, 8, 4, 2); - static const __m256i SHFL0 = SIMD_MM256_SETR_EPI8( - 0x1, 0x3, 0x5, 0x7, 0x9, 0xB, 0xD, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, 0x1, 0x3, 0x5, 0x7, 0x9, 0xB, 0xD, -1, -1); - static const __m256i SHFL1 = SIMD_MM256_SETR_EPI8( - 0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, 0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, -1, -1); __m512i i0 = Encode32f(src, scale, min, sum, sqsum, mask); - __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), SHIFT); - __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, SHFL0), _mm256_shuffle_epi8(s0, SHFL1)); + __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), Avx2::E7_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, Avx2::E7_SHFL0), _mm256_shuffle_epi8(s0, Avx2::E7_SHFL1)); return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); } static SIMD_INLINE __m256i Encode32f7x4(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) { - static const __m512i SHIFT = SIMD_MM512_SETR_EPI16( - 256, 128, 64, 32, 16, 8, 4, 2, 256, 128, 64, 32, 16, 8, 4, 2, - 256, 128, 64, 32, 16, 8, 4, 2, 256, 128, 64, 32, 16, 8, 4, 2); - static const __m512i SHFL0 = SIMD_MM512_SETR_EPI8( - -1, -1, 0x1, 0x3, 0x5, 0x7, 0x9, 0xB, 0xD, -1, -1, -1, -1, -1, -1, -1, - 0x1, 0x3, 0x5, 0x7, 0x9, 0xB, 0xD, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x1, 0x3, 0x5, 0x7, 0x9, 0xB, 0xD, - -1, -1, -1, -1, -1, -1, -1, 0x1, 0x3, 0x5, 0x7, 0x9, 0xB, 0xD, -1, -1); - static const __m512i SHFL1 = SIMD_MM512_SETR_EPI8( - -1, -1, 0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, -1, -1, -1, -1, -1, -1, -1, - 0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, -1, -1, -1, -1, -1, -1, -1, -1, -1, - -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, - -1, -1, -1, -1, -1, -1, -1, 0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, -1, -1); - static const __m512i PERM = SIMD_MM512_SETR_EPI64(0, 2, 1, 3, 4, 6, 5, 7); __m512i i0 = Encode32f(src + 0 * F, scale, min, sum, sqsum); __m512i i1 = Encode32f(src + 1 * F, scale, min, sum, sqsum); - __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(PERM, _mm512_packus_epi32(i0, i1)), SHIFT); - __m512i e0 = _mm512_or_si512(_mm512_shuffle_epi8(s0, SHFL0), _mm512_shuffle_epi8(s0, SHFL1)); + __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(EX_PERM, _mm512_packus_epi32(i0, i1)), E7_MULLO); + __m512i e0 = _mm512_or_si512(_mm512_shuffle_epi8(s0, E7_SHFL0), _mm512_shuffle_epi8(s0, E7_SHFL1)); return _mm256_or_si256(_mm512_castsi512_si256(e0), _mm512_extracti32x8_epi32(e0, 1)); } @@ -211,28 +198,110 @@ namespace Simd //------------------------------------------------------------------------------------------------- + SIMD_INLINE __m512i Encode16f(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) + { + return Encode32f(_mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, src)), scale, min, sum, sqsum); + } + + static SIMD_INLINE __m128i Encode16f6x2(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) + { + __m512i i0 = Encode16f(src, scale, min, sum, sqsum, mask); + __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), Avx2::E6_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, Avx2::E6_SHFL0), _mm256_shuffle_epi8(s0, Avx2::E6_SHFL1)); + return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); + } - const __m512i C6_PERM = SIMD_MM512_SETR_EPI32( - 0x0, 0x1, 0x0, 0x0, 0x1, 0x2, 0x0, 0x0, 0x3, 0x4, 0x0, 0x0, 0x4, 0x5, 0x0, 0x0); - const __m512i C6_SHFL = SIMD_MM512_SETR_EPI8( - 0x0, 0x0, 0x0, 0x1, 0x1, 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x5, - 0x2, 0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x4, 0x5, 0x5, 0x5, 0x6, 0x6, 0x7, 0x7, 0x7, - 0x0, 0x0, 0x0, 0x1, 0x1, 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x5, - 0x2, 0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x4, 0x5, 0x5, 0x5, 0x6, 0x6, 0x7, 0x7, 0x7); - const __m512i C6_MULLO = SIMD_MM512_SETR_EPI16( - 4, 16, 64, 256, 4, 16, 64, 256, 4, 16, 64, 256, 4, 16, 64, 256, - 4, 16, 64, 256, 4, 16, 64, 256, 4, 16, 64, 256, 4, 16, 64, 256); - - const __m512i C7_PERM = SIMD_MM512_SETR_EPI32( - 0x0, 0x1, 0x0, 0x0, 0x1, 0x2, 0x3, 0x0, 0x3, 0x4, 0x5, 0x0, 0x5, 0x6, 0x0, 0x0); - const __m512i C7_SHFL = SIMD_MM512_SETR_EPI8( - 0x0, 0x0, 0x0, 0x1, 0x1, 0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x6, 0x6, 0x6, - 0x3, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8, 0x9, 0x9, 0x9, - 0x2, 0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8, 0x8, - 0x1, 0x1, 0x1, 0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x6, 0x6, 0x7, 0x7, 0x7); - const __m512i C7_MULLO = SIMD_MM512_SETR_EPI16( - 2, 4, 8, 16, 32, 64, 128, 256, 2, 4, 8, 16, 32, 64, 128, 256, - 2, 4, 8, 16, 32, 64, 128, 256, 2, 4, 8, 16, 32, 64, 128, 256); + static SIMD_INLINE __m256i Encode16f6x4(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) + { + __m512i i0 = Encode16f(src + 0 * F, scale, min, sum, sqsum); + __m512i i1 = Encode16f(src + 1 * F, scale, min, sum, sqsum); + __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(EX_PERM, _mm512_packus_epi32(i0, i1)), E6_MULLO); + __m512i e0 = _mm512_or_si512(_mm512_shuffle_epi8(s0, E6_SHFL0), _mm512_shuffle_epi8(s0, E6_SHFL1)); + return _mm256_or_si256(_mm512_castsi512_si256(e0), _mm512_extracti32x8_epi32(e0, 1)); + } + + static void Encode16f6(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t size16 = AlignLo(size, 16), size32 = AlignLo(size, 32), i = 0; + __m512 _scale = _mm512_set1_ps(scale); + __m512 _min = _mm512_set1_ps(min); + __m512i _sum = _mm512_setzero_si512(); + __m512i _sqsum = _mm512_setzero_si512(); + for (; i < size32; i += 32, src += 32, dst += 24) + _mm256_mask_storeu_epi8(dst - 4, 0x0FFFFFF0, Encode16f6x4(src, _scale, _min, _sum, _sqsum)); + for (; i < size16; i += 16, src += 16, dst += 12) + _mm_mask_storeu_epi8(dst, 0x0FFF, Encode16f6x2(src, _scale, _min, _sum, _sqsum)); + if (i < size) + _mm_mask_storeu_epi8(dst, 0x003F, Encode16f6x2(src, _scale, _min, _sum, _sqsum, 0x00FF)); + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + static SIMD_INLINE __m128i Encode16f7x2(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) + { + __m512i i0 = Encode16f(src, scale, min, sum, sqsum, mask); + __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), Avx2::E7_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, Avx2::E7_SHFL0), _mm256_shuffle_epi8(s0, Avx2::E7_SHFL1)); + return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); + } + + static SIMD_INLINE __m256i Encode16f7x4(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) + { + __m512i i0 = Encode16f(src + 0 * F, scale, min, sum, sqsum); + __m512i i1 = Encode16f(src + 1 * F, scale, min, sum, sqsum); + __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(EX_PERM, _mm512_packus_epi32(i0, i1)), E7_MULLO); + __m512i e0 = _mm512_or_si512(_mm512_shuffle_epi8(s0, E7_SHFL0), _mm512_shuffle_epi8(s0, E7_SHFL1)); + return _mm256_or_si256(_mm512_castsi512_si256(e0), _mm512_extracti32x8_epi32(e0, 1)); + } + + static void Encode16f7(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t size16 = AlignLo(size, 16), size32 = AlignLo(size, 32), i = 0; + __m512 _scale = _mm512_set1_ps(scale); + __m512 _min = _mm512_set1_ps(min); + __m512i _sum = _mm512_setzero_si512(); + __m512i _sqsum = _mm512_setzero_si512(); + for (; i < size32; i += 32, src += 32, dst += 28) + _mm256_mask_storeu_epi8(dst - 2, 0x3FFFFFFC, Encode16f7x4(src, _scale, _min, _sum, _sqsum)); + for (; i < size16; i += 16, src += 16, dst += 14) + _mm_mask_storeu_epi8(dst, 0x3FFF, Encode16f7x2(src, _scale, _min, _sum, _sqsum)); + if (i < size) + _mm_mask_storeu_epi8(dst, 0x007F, Encode16f7x2(src, _scale, _min, _sum, _sqsum, 0x00FF)); + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + static void Encode16f8(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t sizeF = AlignLo(size, F), sizeA = AlignLo(size, A), i = 0; + __m512 _scale = _mm512_set1_ps(scale); + __m512 _min = _mm512_set1_ps(min); + __m512i _sum = _mm512_setzero_si512(); + __m512i _sqsum = _mm512_setzero_si512(); + for (; i < sizeA; i += A) + { + __m512i d0 = Encode16f(src + i + 0 * F, _scale, _min, _sum, _sqsum); + __m512i d1 = Encode16f(src + i + 1 * F, _scale, _min, _sum, _sqsum); + __m512i d2 = Encode16f(src + i + 2 * F, _scale, _min, _sum, _sqsum); + __m512i d3 = Encode16f(src + i + 3 * F, _scale, _min, _sum, _sqsum); + _mm512_storeu_si512((__m512i*)(dst + i), PackI16ToU8(PackI32ToI16(d0, d1), PackI32ToI16(d2, d3))); + } + for (; i < sizeF; i += F) + { + __m512i d0 = Encode16f(src + i, _scale, _min, _sum, _sqsum); + _mm_storeu_si128((__m128i*)(dst + i), _mm512_castsi512_si128(PackI16ToU8(PackI32ToI16(d0)))); + } + if (i < size) + { + __m512i d0 = Encode16f(src + i, _scale, _min, _sum, _sqsum, 0xFF); + _mm_mask_storeu_epi8(dst + i, 0xFF, _mm512_castsi512_si128(PackI16ToU8(PackI32ToI16(d0)))); + } + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } //------------------------------------------------------------------------------------------------- @@ -312,6 +381,82 @@ namespace Simd //------------------------------------------------------------------------------------------------- + static void Decode16f6(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _shift = _mm512_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); + for (; i < size16; i += 16) + { + __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, Avx2::C6_SHFL), Avx2::C6_MULLO), 10); + _mm256_storeu_si256((__m256i*)dst, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift), 0)); + src += 12; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s6 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s6, Sse41::C6_SHFL0), Sse41::C6_MULLO), 10); + _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift)), 0)); + src += 6; + dst += 8; + } + } + + static void Decode16f7(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _shift = _mm512_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); + for (; i < size16; i += 16) + { + __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, Avx2::C7_SHFL), Avx2::C7_MULLO), 9); + _mm256_storeu_si256((__m256i*)dst, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift), 0)); + src += 14; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s7 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s7, Sse41::C7_SHFL0), Sse41::C7_MULLO), 9); + _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift)), 0)); + src += 7; + dst += 8; + } + } + + static void Decode16f8(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _shift = _mm512_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16), size64 = AlignLo(size, 64); + for (; i < size64; i += 64) + { + __m512i u8 = _mm512_loadu_si512((__m512i*)(src + i)); + _mm256_storeu_si256((__m256i*)(dst + i) + 0, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 0))), _scale, _shift), 0)); + _mm256_storeu_si256((__m256i*)(dst + i) + 1, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 1))), _scale, _shift), 0)); + _mm256_storeu_si256((__m256i*)(dst + i) + 2, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 2))), _scale, _shift), 0)); + _mm256_storeu_si256((__m256i*)(dst + i) + 3, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 3))), _scale, _shift), 0)); + } + for (; i < size16; i += 16) + { + __m128i u8 = _mm_loadu_si128((__m128i*)(src + i)); + _mm256_storeu_si256((__m256i*)(dst + i), _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(u8)), _scale, _shift), 0)); + } + if (i < size) + { + __m256 _src = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)(src + i)))); + _mm_storeu_si128((__m128i*)(dst + i), _mm256_cvtps_ph(_mm256_fmadd_ps(_src, _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift)), 0)); + } + } + + //------------------------------------------------------------------------------------------------- + template int32_t Correlation(const uint8_t* a, const uint8_t* b, size_t size); SIMD_INLINE __m512i Load6(const uint8_t* ptr, __mmask32 mask = 0x00FFFFFF) @@ -855,12 +1000,15 @@ namespace Simd : Avx2::DescrInt(size, depth) { _minMax32f = MinMax32f; + _minMax16f = MinMax16f; switch (depth) { case 6: { _encode32f = Encode32f6; + _encode16f = Encode16f6; _decode32f = Decode32f6; + _decode16f = Decode16f6; _cosineDistance = Avx512bw::CosineDistance<6>; _macroCosineDistances = Avx512bw::MacroCosineDistances<6>; break; @@ -868,7 +1016,9 @@ namespace Simd case 7: { _encode32f = Encode32f7; + _encode16f = Encode16f7; _decode32f = Decode32f7; + _decode16f = Decode16f7; _cosineDistance = Avx512bw::CosineDistance<7>; _macroCosineDistances = Avx512bw::MacroCosineDistances<7>; break; @@ -876,7 +1026,9 @@ namespace Simd case 8: { _encode32f = Encode32f8; + _encode16f = Encode16f8; _decode32f = Decode32f8; + _decode16f = Decode16f8; _cosineDistance = Avx512bw::CosineDistance<8>; _macroCosineDistances = Avx512bw::MacroCosineDistances<8>; _microM = 4; diff --git a/src/Simd/SimdDescrIntCommon.h b/src/Simd/SimdDescrIntCommon.h index 4ccd61365f..e02d971af8 100644 --- a/src/Simd/SimdDescrIntCommon.h +++ b/src/Simd/SimdDescrIntCommon.h @@ -131,6 +131,57 @@ namespace Simd #ifdef SIMD_AVX512BW_ENABLE namespace Avx512bw { + const __m512i EX_PERM = SIMD_MM512_SETR_EPI64(0, 2, 1, 3, 4, 6, 5, 7); + + const __m512i E6_MULLO = SIMD_MM512_SETR_EPI16( + 256, 64, 16, 4, 256, 64, 16, 4, 256, 64, 16, 4, 256, 64, 16, 4, + 256, 64, 16, 4, 256, 64, 16, 4, 256, 64, 16, 4, 256, 64, 16, 4); + const __m512i E6_SHFL0 = SIMD_MM512_SETR_EPI8( + -1, -1, -1, -1, 0x1, 0x3, 0x5, 0x9, 0xB, 0xD, -1, -1, -1, -1, -1, -1, + 0x1, 0x3, 0x5, 0x9, 0xB, 0xD, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x1, 0x3, 0x5, 0x9, 0xB, 0xD, + -1, -1, -1, -1, -1, -1, 0x1, 0x3, 0x5, 0x9, 0xB, 0xD, -1, -1, -1, -1); + const __m512i E6_SHFL1 = SIMD_MM512_SETR_EPI8( + -1, -1, -1, -1, 0x2, 0x4, 0x6, 0xA, 0xC, 0xE, -1, -1, -1, -1, -1, -1, + 0x2, 0x4, 0x6, 0xA, 0xC, 0xE, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x2, 0x4, 0x6, 0xA, 0xC, 0xE, + -1, -1, -1, -1, -1, -1, 0x2, 0x4, 0x6, 0xA, 0xC, 0xE, -1, -1, -1, -1); + + const __m512i E7_MULLO = SIMD_MM512_SETR_EPI16( + 256, 128, 64, 32, 16, 8, 4, 2, 256, 128, 64, 32, 16, 8, 4, 2, + 256, 128, 64, 32, 16, 8, 4, 2, 256, 128, 64, 32, 16, 8, 4, 2); + const __m512i E7_SHFL0 = SIMD_MM512_SETR_EPI8( + -1, -1, 0x1, 0x3, 0x5, 0x7, 0x9, 0xB, 0xD, -1, -1, -1, -1, -1, -1, -1, + 0x1, 0x3, 0x5, 0x7, 0x9, 0xB, 0xD, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x1, 0x3, 0x5, 0x7, 0x9, 0xB, 0xD, + -1, -1, -1, -1, -1, -1, -1, 0x1, 0x3, 0x5, 0x7, 0x9, 0xB, 0xD, -1, -1); + const __m512i E7_SHFL1 = SIMD_MM512_SETR_EPI8( + -1, -1, 0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, -1, -1, -1, -1, -1, -1, -1, + 0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, + -1, -1, -1, -1, -1, -1, -1, 0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, -1, -1); + + const __m512i C6_PERM = SIMD_MM512_SETR_EPI32( + 0x0, 0x1, 0x0, 0x0, 0x1, 0x2, 0x0, 0x0, 0x3, 0x4, 0x0, 0x0, 0x4, 0x5, 0x0, 0x0); + const __m512i C6_SHFL = SIMD_MM512_SETR_EPI8( + 0x0, 0x0, 0x0, 0x1, 0x1, 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x5, + 0x2, 0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x4, 0x5, 0x5, 0x5, 0x6, 0x6, 0x7, 0x7, 0x7, + 0x0, 0x0, 0x0, 0x1, 0x1, 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x5, + 0x2, 0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x4, 0x5, 0x5, 0x5, 0x6, 0x6, 0x7, 0x7, 0x7); + const __m512i C6_MULLO = SIMD_MM512_SETR_EPI16( + 4, 16, 64, 256, 4, 16, 64, 256, 4, 16, 64, 256, 4, 16, 64, 256, + 4, 16, 64, 256, 4, 16, 64, 256, 4, 16, 64, 256, 4, 16, 64, 256); + + const __m512i C7_PERM = SIMD_MM512_SETR_EPI32( + 0x0, 0x1, 0x0, 0x0, 0x1, 0x2, 0x3, 0x0, 0x3, 0x4, 0x5, 0x0, 0x5, 0x6, 0x0, 0x0); + const __m512i C7_SHFL = SIMD_MM512_SETR_EPI8( + 0x0, 0x0, 0x0, 0x1, 0x1, 0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x6, 0x6, 0x6, + 0x3, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8, 0x9, 0x9, 0x9, + 0x2, 0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8, 0x8, + 0x1, 0x1, 0x1, 0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x6, 0x6, 0x7, 0x7, 0x7); + const __m512i C7_MULLO = SIMD_MM512_SETR_EPI16( + 2, 4, 8, 16, 32, 64, 128, 256, 2, 4, 8, 16, 32, 64, 128, 256, + 2, 4, 8, 16, 32, 64, 128, 256, 2, 4, 8, 16, 32, 64, 128, 256); } #endif } From 0270246255110d52dd6bb7d3f4ad78ccbb9a2a56 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Wed, 7 Jun 2023 18:17:49 +0300 Subject: [PATCH 03/44] *refactoring of class ImagePngLoader. --- prj/vs2019/Base.vcxproj | 1 + prj/vs2019/Base.vcxproj.filters | 3 + prj/vs2022/Base.vcxproj | 1 + prj/vs2022/Base.vcxproj.filters | 3 + src/Simd/SimdAlignment.h | 2 +- src/Simd/SimdAllocator.hpp | 2 +- src/Simd/SimdAvx1SynetDeconvolution32f.cpp | 2 +- src/Simd/SimdAvx2DescrInt.cpp | 2 +- src/Simd/SimdAvx2Gemm32f.cpp | 2 +- src/Simd/SimdAvx2SynetNormalize.cpp | 2 +- src/Simd/SimdAvx512bwDescrInt.cpp | 2 +- src/Simd/SimdAvx512bwGemm32fNN.cpp | 2 +- src/Simd/SimdAvx512bwGemm32fNT.cpp | 2 +- src/Simd/SimdAvx512bwSynetNormalize.cpp | 2 +- src/Simd/SimdBaseCpu.cpp | 2 +- src/Simd/SimdBaseDescrInt.cpp | 2 +- src/Simd/SimdBaseImageLoadPng.cpp | 53 ++++++++-------- src/Simd/SimdBaseSynetNormalize.cpp | 2 +- src/Simd/SimdExp.h | 2 +- src/Simd/SimdExtract.h | 2 +- src/Simd/SimdImageLoad.h | 2 +- src/Simd/SimdImageLoadPng.h | 70 ++++++++++++++++++++++ src/Simd/SimdImageMatcher.hpp | 2 +- src/Simd/SimdPoint.hpp | 2 +- src/Simd/SimdSse41DescrInt.cpp | 2 +- src/Simd/SimdSse41SynetNormalize.cpp | 2 +- src/Simd/SimdSynetDeconvolution32f.h | 2 +- src/Test/TestCheckCpp.cpp | 2 +- src/Test/TestCompare.cpp | 2 +- src/Test/TestFile.cpp | 2 +- src/Test/TestFile.h | 2 +- src/Test/TestFont.cpp | 2 +- src/Test/TestImageMatcher.cpp | 2 +- src/Test/TestRandom.cpp | 2 +- src/Test/TestSynetInnerProduct.cpp | 2 +- src/Test/TestSynetNormalize.cpp | 2 +- src/Use/Use.cpp | 2 +- 37 files changed, 133 insertions(+), 60 deletions(-) create mode 100644 src/Simd/SimdImageLoadPng.h diff --git a/prj/vs2019/Base.vcxproj b/prj/vs2019/Base.vcxproj index c97b735ea3..16465b567b 100644 --- a/prj/vs2019/Base.vcxproj +++ b/prj/vs2019/Base.vcxproj @@ -48,6 +48,7 @@ + diff --git a/prj/vs2019/Base.vcxproj.filters b/prj/vs2019/Base.vcxproj.filters index 87b49edf51..9472abaff6 100644 --- a/prj/vs2019/Base.vcxproj.filters +++ b/prj/vs2019/Base.vcxproj.filters @@ -555,6 +555,9 @@ Inc + + Inc + diff --git a/prj/vs2022/Base.vcxproj b/prj/vs2022/Base.vcxproj index c97b735ea3..16465b567b 100644 --- a/prj/vs2022/Base.vcxproj +++ b/prj/vs2022/Base.vcxproj @@ -48,6 +48,7 @@ + diff --git a/prj/vs2022/Base.vcxproj.filters b/prj/vs2022/Base.vcxproj.filters index 87b49edf51..9472abaff6 100644 --- a/prj/vs2022/Base.vcxproj.filters +++ b/prj/vs2022/Base.vcxproj.filters @@ -555,6 +555,9 @@ Inc + + Inc + diff --git a/src/Simd/SimdAlignment.h b/src/Simd/SimdAlignment.h index 5426d6a850..9b9c6755e8 100644 --- a/src/Simd/SimdAlignment.h +++ b/src/Simd/SimdAlignment.h @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdAllocator.hpp b/src/Simd/SimdAllocator.hpp index 0aa7166462..ce7c5c4b49 100644 --- a/src/Simd/SimdAllocator.hpp +++ b/src/Simd/SimdAllocator.hpp @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2020 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdAvx1SynetDeconvolution32f.cpp b/src/Simd/SimdAvx1SynetDeconvolution32f.cpp index 914aaae2ef..35607f0358 100644 --- a/src/Simd/SimdAvx1SynetDeconvolution32f.cpp +++ b/src/Simd/SimdAvx1SynetDeconvolution32f.cpp @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdAvx2DescrInt.cpp b/src/Simd/SimdAvx2DescrInt.cpp index 3a6baf81d7..4cdd8d65b2 100644 --- a/src/Simd/SimdAvx2DescrInt.cpp +++ b/src/Simd/SimdAvx2DescrInt.cpp @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdAvx2Gemm32f.cpp b/src/Simd/SimdAvx2Gemm32f.cpp index 0a12a5942e..4fca65b392 100644 --- a/src/Simd/SimdAvx2Gemm32f.cpp +++ b/src/Simd/SimdAvx2Gemm32f.cpp @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdAvx2SynetNormalize.cpp b/src/Simd/SimdAvx2SynetNormalize.cpp index 1b836e6bc5..4c3421d7ad 100644 --- a/src/Simd/SimdAvx2SynetNormalize.cpp +++ b/src/Simd/SimdAvx2SynetNormalize.cpp @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdAvx512bwDescrInt.cpp b/src/Simd/SimdAvx512bwDescrInt.cpp index 93bde177f0..33e2a2b970 100644 --- a/src/Simd/SimdAvx512bwDescrInt.cpp +++ b/src/Simd/SimdAvx512bwDescrInt.cpp @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdAvx512bwGemm32fNN.cpp b/src/Simd/SimdAvx512bwGemm32fNN.cpp index d886e183b0..1e22b77953 100644 --- a/src/Simd/SimdAvx512bwGemm32fNN.cpp +++ b/src/Simd/SimdAvx512bwGemm32fNN.cpp @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdAvx512bwGemm32fNT.cpp b/src/Simd/SimdAvx512bwGemm32fNT.cpp index 0c80b4fa2b..dfc97c9348 100644 --- a/src/Simd/SimdAvx512bwGemm32fNT.cpp +++ b/src/Simd/SimdAvx512bwGemm32fNT.cpp @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdAvx512bwSynetNormalize.cpp b/src/Simd/SimdAvx512bwSynetNormalize.cpp index 117c52c336..b5fc9bf2d4 100644 --- a/src/Simd/SimdAvx512bwSynetNormalize.cpp +++ b/src/Simd/SimdAvx512bwSynetNormalize.cpp @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdBaseCpu.cpp b/src/Simd/SimdBaseCpu.cpp index 637911d322..c828c3d164 100644 --- a/src/Simd/SimdBaseCpu.cpp +++ b/src/Simd/SimdBaseCpu.cpp @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar, +* Copyright (c) 2011-2023 Yermalayeu Ihar, * 2022-2022 Souriya Trinh, * 2022-2022 Fabien Spindler. * diff --git a/src/Simd/SimdBaseDescrInt.cpp b/src/Simd/SimdBaseDescrInt.cpp index 538c054c5e..b11d05c6c0 100644 --- a/src/Simd/SimdBaseDescrInt.cpp +++ b/src/Simd/SimdBaseDescrInt.cpp @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdBaseImageLoadPng.cpp b/src/Simd/SimdBaseImageLoadPng.cpp index 74300361dd..ada685bd0d 100644 --- a/src/Simd/SimdBaseImageLoadPng.cpp +++ b/src/Simd/SimdBaseImageLoadPng.cpp @@ -23,6 +23,7 @@ * SOFTWARE. */ #include "Simd/SimdImageLoad.h" +#include "Simd/SimdImageLoadPng.h" #include "Simd/SimdImageSavePng.h" #include "Simd/SimdArray.h" #include "Simd/SimdCpu.h" @@ -32,12 +33,6 @@ namespace Simd { namespace Base { - SIMD_INLINE int PngError(const char* str, const char* stub) - { - std::cout << "PNG load error: " << str << ", " << stub << "!" << std::endl; - return 0; - } - namespace Zlib { const size_t ZFAST_BITS = 9; @@ -65,7 +60,7 @@ namespace Simd sizes[0] = 0; for (i = 1; i < 16; ++i) if (sizes[i] > (1 << i)) - return PngError("bad sizes", "Corrupt PNG"); + return CorruptPngError("bad sizes"); code = 0; for (i = 1; i < 16; ++i) { @@ -74,12 +69,12 @@ namespace Simd firstSymbol[i] = (uint16_t)k; code = (code + sizes[i]); if (sizes[i] && code - 1 >= (1 << i)) - return PngError("bad codelengths", "Corrupt PNG"); - maxCode[i] = code << (16 - i); // preshift for inner loop + return CorruptPngError("bad codelengths"); + maxCode[i] = code << (16 - i); code <<= 1; k += sizes[i]; } - maxCode[16] = 0x10000; // sentinel + maxCode[16] = 0x10000; for (i = 0; i < num; ++i) { int s = sizelist[i]; @@ -163,7 +158,7 @@ namespace Simd if (z < 256) { if (z < 0) - return PngError("bad huffman code", "Corrupt PNG"); + return CorruptPngError("bad huffman code"); if (dst >= end) { os.Reserve(end - beg + 1); @@ -187,12 +182,12 @@ namespace Simd len += (int)is.ReadBits(zlengthExtra[z]); z = ZhuffmanDecode(is, zDistance); if (z < 0) - return PngError("bad huffman code", "Corrupt PNG"); + return CorruptPngError("bad huffman code"); dist = zdistBase[z]; if (zdistExtra[z]) dist += (int)is.ReadBits(zdistExtra[z]); if (dst - beg < dist) - return PngError("bad dist", "Corrupt PNG"); + return CorruptPngError("bad dist"); if (dst + len > end) { os.Reserve(dst - beg + len); @@ -258,7 +253,7 @@ namespace Simd { int c = ZhuffmanDecode(is, z_codelength); if (c < 0 || c >= 19) - return PngError("bad codelengths", "Corrupt PNG"); + return CorruptPngError("bad codelengths"); if (c < 16) lencodes[n++] = (uint8_t)c; else @@ -267,7 +262,7 @@ namespace Simd if (c == 16) { c = (int)is.ReadBits(2) + 3; - if (n == 0) return PngError("bad codelengths", "Corrupt PNG"); + if (n == 0) return CorruptPngError("bad codelengths"); fill = lencodes[n - 1]; } else if (c == 17) @@ -275,15 +270,15 @@ namespace Simd else if (c == 18) c = (int)is.ReadBits(7) + 11; else - return PngError("bad codelengths", "Corrupt PNG"); + return CorruptPngError("bad codelengths"); if (ntot - n < c) - return PngError("bad codelengths", "Corrupt PNG"); + return CorruptPngError("bad codelengths"); memset(lencodes + n, fill, c); n += c; } } if (n != ntot) - return PngError("bad codelengths", "Corrupt PNG"); + return CorruptPngError("bad codelengths"); if (!zLength.Build(lencodes, hlit)) return 0; if (!zDistance.Build(lencodes + hlit, hdist)) @@ -296,9 +291,9 @@ namespace Simd is.ClearBits(); uint16_t len, nlen; if (!is.Read16u(len) || !is.Read16u(nlen) || nlen != (len ^ 0xffff)) - return PngError("zlib corrupt", "Corrupt PNG"); + return CorruptPngError("zlib corrupt"); if (!os.Write(is, len)) - return PngError("read past buffer", "Corrupt PNG"); + return CorruptPngError("read past buffer"); return 1; } @@ -306,13 +301,13 @@ namespace Simd { uint8_t cmf, flg; if (!(is.Read8u(cmf) && is.Read8u(flg))) - return PngError("bad zlib header", "Corrupt PNG"); + return CorruptPngError("bad zlib header"); if ((int(cmf) * 256 + flg) % 31 != 0) - return PngError("bad zlib header", "Corrupt PNG"); + return CorruptPngError("bad zlib header"); if (flg & 32) - return PngError("no preset dict", "Corrupt PNG"); + return CorruptPngError("no preset dict"); if ((cmf & 15) != 8) - return PngError("bad compression", "Corrupt PNG"); + return CorruptPngError("bad compression"); return 1; } @@ -427,13 +422,13 @@ namespace Simd a.buf0.Resize(x * y * output_bytes); if (a.buf0.Empty()) - return PngError("outofmem", "Out of memory"); + return PngLoadError("outofmem", "Out of memory"); img_width_bytes = (img_n * x * depth + 7) >> 3; img_len = (img_width_bytes + 1) * y; if (raw_len < img_len) - return PngError("not enough pixels", "Corrupt PNG"); + return CorruptPngError("not enough pixels"); for (j = 0; j < y; ++j) { @@ -442,12 +437,12 @@ namespace Simd int filter = *raw++; if (filter > 4) - return PngError("invalid filter", "Corrupt PNG"); + return CorruptPngError("invalid filter"); if (depth < 8) { if (img_width_bytes > x) - return PngError("invalid width", "Corrupt PNG"); + return CorruptPngError("invalid width"); cur += x * out_n - img_width_bytes; // store output to the rightmost img_len bytes, so we can decode in place filter_bytes = 1; width = img_width_bytes; @@ -701,7 +696,7 @@ namespace Simd a.buf1.Resize(pixel_count * a.img_out_n); if(a.buf1.Empty()) - return PngError("outofmem", "Out of memory"); + return PngLoadError("outofmem", "Out of memory"); uint8_t* p = a.buf1.data; if (a.img_out_n == 3) diff --git a/src/Simd/SimdBaseSynetNormalize.cpp b/src/Simd/SimdBaseSynetNormalize.cpp index 7fec7c1a7b..76ec8170e0 100644 --- a/src/Simd/SimdBaseSynetNormalize.cpp +++ b/src/Simd/SimdBaseSynetNormalize.cpp @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdExp.h b/src/Simd/SimdExp.h index e6676dcbd1..99d94df40d 100644 --- a/src/Simd/SimdExp.h +++ b/src/Simd/SimdExp.h @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdExtract.h b/src/Simd/SimdExtract.h index 23fa192fb6..3909e2b005 100644 --- a/src/Simd/SimdExtract.h +++ b/src/Simd/SimdExtract.h @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdImageLoad.h b/src/Simd/SimdImageLoad.h index d941818143..ad090762ca 100644 --- a/src/Simd/SimdImageLoad.h +++ b/src/Simd/SimdImageLoad.h @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2021 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdImageLoadPng.h b/src/Simd/SimdImageLoadPng.h new file mode 100644 index 0000000000..810d6348da --- /dev/null +++ b/src/Simd/SimdImageLoadPng.h @@ -0,0 +1,70 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2021 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#ifndef __SimdImageLoadPng_h__ +#define __SimdImageLoadPng_h__ + +#include "Simd/SimdImageLoad.h" + +namespace Simd +{ + namespace Base + { + SIMD_INLINE int PngLoadError(const char* text, const char* type) + { + std::cout << "PNG load error: " << text << ", " << type << "!" << std::endl; + return 0; + } + + SIMD_INLINE int CorruptPngError(const char* text) + { + return PngLoadError(text, "Corrupt PNG"); + } + } + +#ifdef SIMD_SSE41_ENABLE + namespace Sse41 + { + } +#endif + +#ifdef SIMD_AVX2_ENABLE + namespace Avx2 + { + } +#endif + +#ifdef SIMD_AVX512BW_ENABLE + namespace Avx512bw + { + } +#endif + +#ifdef SIMD_NEON_ENABLE + namespace Neon + { + } +#endif +} + +#endif//__SimdImageLoadPng_h__ diff --git a/src/Simd/SimdImageMatcher.hpp b/src/Simd/SimdImageMatcher.hpp index d01c3f0a35..435f44041c 100644 --- a/src/Simd/SimdImageMatcher.hpp +++ b/src/Simd/SimdImageMatcher.hpp @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2017 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdPoint.hpp b/src/Simd/SimdPoint.hpp index 1a0fa39f2a..e5ac8d500b 100644 --- a/src/Simd/SimdPoint.hpp +++ b/src/Simd/SimdPoint.hpp @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2018 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdSse41DescrInt.cpp b/src/Simd/SimdSse41DescrInt.cpp index eda3f3048a..94035c0a7f 100644 --- a/src/Simd/SimdSse41DescrInt.cpp +++ b/src/Simd/SimdSse41DescrInt.cpp @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdSse41SynetNormalize.cpp b/src/Simd/SimdSse41SynetNormalize.cpp index b620f2db18..74e0cff766 100644 --- a/src/Simd/SimdSse41SynetNormalize.cpp +++ b/src/Simd/SimdSse41SynetNormalize.cpp @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Simd/SimdSynetDeconvolution32f.h b/src/Simd/SimdSynetDeconvolution32f.h index 8ba439afab..a45827c3b6 100644 --- a/src/Simd/SimdSynetDeconvolution32f.h +++ b/src/Simd/SimdSynetDeconvolution32f.h @@ -1,7 +1,7 @@ /* * Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Test/TestCheckCpp.cpp b/src/Test/TestCheckCpp.cpp index 7e1abd5470..db445cf1bb 100644 --- a/src/Test/TestCheckCpp.cpp +++ b/src/Test/TestCheckCpp.cpp @@ -1,7 +1,7 @@ /* * Tests for Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Test/TestCompare.cpp b/src/Test/TestCompare.cpp index 371f9e4aeb..2ab356a3ed 100644 --- a/src/Test/TestCompare.cpp +++ b/src/Test/TestCompare.cpp @@ -1,7 +1,7 @@ /* * Tests for Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Test/TestFile.cpp b/src/Test/TestFile.cpp index e30cb0ac2b..caf3a156a4 100644 --- a/src/Test/TestFile.cpp +++ b/src/Test/TestFile.cpp @@ -1,7 +1,7 @@ /* * Tests for Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Test/TestFile.h b/src/Test/TestFile.h index 2cbff5a869..55df3e93af 100644 --- a/src/Test/TestFile.h +++ b/src/Test/TestFile.h @@ -1,7 +1,7 @@ /* * Tests for Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Test/TestFont.cpp b/src/Test/TestFont.cpp index 89571e26ff..b6c6c6ba42 100644 --- a/src/Test/TestFont.cpp +++ b/src/Test/TestFont.cpp @@ -1,7 +1,7 @@ /* * Tests for Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Test/TestImageMatcher.cpp b/src/Test/TestImageMatcher.cpp index 76de62c569..eacfd6bef1 100644 --- a/src/Test/TestImageMatcher.cpp +++ b/src/Test/TestImageMatcher.cpp @@ -1,7 +1,7 @@ /* * Tests for Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Test/TestRandom.cpp b/src/Test/TestRandom.cpp index 892452619b..2d08b37102 100644 --- a/src/Test/TestRandom.cpp +++ b/src/Test/TestRandom.cpp @@ -1,7 +1,7 @@ /* * Tests for Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Test/TestSynetInnerProduct.cpp b/src/Test/TestSynetInnerProduct.cpp index a24d86ec51..b20e7ce04e 100644 --- a/src/Test/TestSynetInnerProduct.cpp +++ b/src/Test/TestSynetInnerProduct.cpp @@ -1,7 +1,7 @@ /* * Tests for Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Test/TestSynetNormalize.cpp b/src/Test/TestSynetNormalize.cpp index f064b40b73..b8f7d06db7 100644 --- a/src/Test/TestSynetNormalize.cpp +++ b/src/Test/TestSynetNormalize.cpp @@ -1,7 +1,7 @@ /* * Tests for Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2022 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal diff --git a/src/Use/Use.cpp b/src/Use/Use.cpp index 34bca94e69..34f2cc6547 100644 --- a/src/Use/Use.cpp +++ b/src/Use/Use.cpp @@ -1,7 +1,7 @@ /* * The use examples of Simd Library (http://ermig1979.github.io/Simd). * -* Copyright (c) 2011-2021 Yermalayeu Ihar. +* Copyright (c) 2011-2023 Yermalayeu Ihar. * * Permission is hereby granted, free of charge, to any person obtaining a copy * of this software and associated documentation files (the "Software"), to deal From 72e2d9e54be6588c650d56be2d45f45ace2c63c0 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Thu, 8 Jun 2023 11:59:48 +0300 Subject: [PATCH 04/44] *refactoring of class ImagePngLoader (part 2). --- src/Simd/SimdBaseImageLoadPng.cpp | 210 ++++++++++++++---------------- src/Simd/SimdImageLoad.h | 10 +- 2 files changed, 104 insertions(+), 116 deletions(-) diff --git a/src/Simd/SimdBaseImageLoadPng.cpp b/src/Simd/SimdBaseImageLoadPng.cpp index ada685bd0d..0889a86d23 100644 --- a/src/Simd/SimdBaseImageLoadPng.cpp +++ b/src/Simd/SimdBaseImageLoadPng.cpp @@ -370,20 +370,6 @@ namespace Simd #define PNG__BYTECAST(x) ((uint8_t) ((x) & 255)) // truncate int to byte without warnings - struct Png - { - uint32_t width, height; - int channels, img_out_n; - uint8_t depth; - Array8u buf0, buf1; - - SIMD_INLINE int Swap() - { - buf0.Swap(buf1); - return 1; - } - }; - enum { PNG__F_none = 0, @@ -406,22 +392,22 @@ namespace Simd static const uint8_t DepthScaleTable[9] = { 0, 0xff, 0x55, 0, 0x11, 0,0,0, 0x01 }; - static int CreatePngImageRaw(Png& a, const uint8_t* raw, uint32_t raw_len, int out_n, uint32_t x, uint32_t y, int depth, int color) + static int CreatePngImageRaw(const uint8_t* raw, uint32_t raw_len, int outN, uint32_t x, uint32_t y, int depth, int color, int channels, Array8u & dst) { int bytes = (depth == 16 ? 2 : 1); - uint32_t i, j, stride = x * out_n * bytes; + uint32_t i, j, stride = x * outN * bytes; uint32_t img_len, img_width_bytes; int k; - int img_n = a.channels; + int img_n = channels; - int output_bytes = out_n * bytes; + int output_bytes = outN * bytes; int filter_bytes = img_n * bytes; int width = x; - assert(out_n == a.channels || out_n == a.channels + 1); + assert(outN == channels || outN == channels + 1); - a.buf0.Resize(x * y * output_bytes); - if (a.buf0.Empty()) + dst.Resize(x * y * output_bytes); + if (dst.Empty()) return PngLoadError("outofmem", "Out of memory"); img_width_bytes = (img_n * x * depth + 7) >> 3; @@ -432,7 +418,7 @@ namespace Simd for (j = 0; j < y; ++j) { - uint8_t* cur = a.buf0.data + stride * j; + uint8_t* cur = dst.data + stride * j; uint8_t* prior; int filter = *raw++; @@ -443,7 +429,7 @@ namespace Simd { if (img_width_bytes > x) return CorruptPngError("invalid width"); - cur += x * out_n - img_width_bytes; // store output to the rightmost img_len bytes, so we can decode in place + cur += x * outN - img_width_bytes; // store output to the rightmost img_len bytes, so we can decode in place filter_bytes = 1; width = img_width_bytes; } @@ -467,15 +453,15 @@ namespace Simd if (depth == 8) { - if (img_n != out_n) + if (img_n != outN) cur[img_n] = 255; // first pixel raw += img_n; - cur += out_n; - prior += out_n; + cur += outN; + prior += outN; } else if (depth == 16) { - if (img_n != out_n) + if (img_n != outN) { cur[filter_bytes] = 255; // first pixel top byte cur[filter_bytes + 1] = 255; // first pixel bottom byte @@ -490,7 +476,7 @@ namespace Simd cur += 1; prior += 1; } - if (depth < 8 || img_n == out_n) + if (depth < 8 || img_n == outN) { int nk = (width - 1) * filter_bytes; #define PNG__CASE(f) \ @@ -510,7 +496,7 @@ namespace Simd } else { - assert(img_n + 1 == out_n); + assert(img_n + 1 == outN); #define PNG__CASE(f) \ case f: \ for (i=x-1; i >= 1; --i, cur[filter_bytes]=255,raw+=filter_bytes,cur+=output_bytes,prior+=output_bytes) \ @@ -527,7 +513,7 @@ namespace Simd #undef PNG__CASE if (depth == 16) { - cur = a.buf0.data + stride * j; + cur = dst.data + stride * j; for (i = 0; i < x; ++i, cur += output_bytes) cur[filter_bytes + 1] = 255; } @@ -537,8 +523,8 @@ namespace Simd { for (j = 0; j < y; ++j) { - uint8_t* cur = a.buf0.data + stride * j; - const uint8_t* in = a.buf0.data + stride * j + x * out_n - img_width_bytes; + uint8_t* cur = dst.data + stride * j; + const uint8_t* in = dst.data + stride * j + x * outN - img_width_bytes; uint8_t scale = (color == 0) ? DepthScaleTable[depth] : 1; if (depth == 4) { @@ -587,10 +573,10 @@ namespace Simd if (k > 5) *cur++ = scale * ((*in >> 2) & 0x01); if (k > 6) *cur++ = scale * ((*in >> 1) & 0x01); } - if (img_n != out_n) + if (img_n != outN) { int q; - cur = a.buf0.data + stride * j; + cur = dst.data + stride * j; if (img_n == 1) { for (q = x - 1; q >= 0; --q) @@ -615,24 +601,24 @@ namespace Simd } else if (depth == 16) { - uint8_t* cur = a.buf0.data; + uint8_t* cur = dst.data; uint16_t* cur16 = (uint16_t*)cur; - for (i = 0; i < x * y * out_n; ++i, cur16++, cur += 2) + for (i = 0; i < x * y * outN; ++i, cur16++, cur += 2) *cur16 = (cur[0] << 8) | cur[1]; } return 1; } - static int CreatePngImage(Png& a, const uint8_t* image_data, uint32_t image_data_len, int out_n, int depth, int color, int interlaced) + static int CreatePngImage(const uint8_t* image_data, uint32_t image_data_len, int outN, int depth, int color, int interlaced, int width, int height, int channels, Array8u& dst) { SIMD_PERF_FUNC(); int bytes = (depth == 16 ? 2 : 1); - int out_bytes = out_n * bytes; + int out_bytes = outN * bytes; if (!interlaced) - return CreatePngImageRaw(a, image_data, image_data_len, out_n, a.width, a.height, depth, color); + return CreatePngImageRaw(image_data, image_data_len, outN, width, height, depth, color, channels, dst); - a.buf1.Resize(a.width * a.height * out_bytes); + Array8u buf(width * height * out_bytes); for (int p = 0; p < 7; ++p) { int xorig[] = { 0,4,0,2,0,1,0 }; @@ -640,12 +626,12 @@ namespace Simd int xspc[] = { 8,8,4,4,2,2,1 }; int yspc[] = { 8,8,8,4,4,2,2 }; int i, j, x, y; - x = (a.width - xorig[p] + xspc[p] - 1) / xspc[p]; - y = (a.height - yorig[p] + yspc[p] - 1) / yspc[p]; + x = (width - xorig[p] + xspc[p] - 1) / xspc[p]; + y = (height - yorig[p] + yspc[p] - 1) / yspc[p]; if (x && y) { - uint32_t img_len = ((((a.channels * x * depth) + 7) >> 3) + 1) * y; - if (!CreatePngImageRaw(a, image_data, image_data_len, out_n, x, y, depth, color)) + uint32_t img_len = ((((channels * x * depth) + 7) >> 3) + 1) * y; + if (!CreatePngImageRaw(image_data, image_data_len, outN, x, y, depth, color, channels, dst)) { return 0; } @@ -655,20 +641,23 @@ namespace Simd { int out_y = j * yspc[p] + yorig[p]; int out_x = i * xspc[p] + xorig[p]; - memcpy(a.buf1.data + out_y * a.width * out_bytes + out_x * out_bytes, - a.buf0.data + (j * x + i) * out_bytes, out_bytes); + memcpy(buf.data + out_y * width * out_bytes + out_x * out_bytes, + dst.data + (j * x + i) * out_bytes, out_bytes); } } image_data += img_len; image_data_len -= img_len; } } - return a.Swap(); + dst.Swap(buf); + return 1; } - template void ComputeTransparency(T * dst, size_t size, size_t out_n, T tc[3]) + //------------------------------------------------------------------------------------------------- + + template void ComputeTransparency(T * dst, size_t size, size_t outN, T tc[3]) { - if (out_n == 2) + if (outN == 2) { for (size_t i = 0; i < size; ++i) { @@ -676,7 +665,7 @@ namespace Simd dst += 2; } } - else if (out_n == 4) + else if (outN == 4) { for (size_t i = 0; i < size; ++i) { @@ -689,40 +678,35 @@ namespace Simd assert(0); } - static int ExpandPalette(Png & a, const uint8_t* palette) - { - uint32_t i, pixel_count = a.width * a.height; - uint8_t * orig = a.buf0.data; - - a.buf1.Resize(pixel_count * a.img_out_n); - if(a.buf1.Empty()) - return PngLoadError("outofmem", "Out of memory"); + //------------------------------------------------------------------------------------------------- - uint8_t* p = a.buf1.data; - if (a.img_out_n == 3) + static void ExpandPalette(const uint8_t* src, size_t size, int outN, const uint8_t* palette, uint8_t* dst) + { + if (outN == 3) { - for (i = 0; i < pixel_count; ++i) + for (size_t i = 0; i < size; ++i) { - int n = orig[i] * 4; - p[0] = palette[n]; - p[1] = palette[n + 1]; - p[2] = palette[n + 2]; - p += 3; + int n = src[i] * 4; + dst[0] = palette[n]; + dst[1] = palette[n + 1]; + dst[2] = palette[n + 2]; + dst += 3; } } - else + else if (outN == 4) { - for (i = 0; i < pixel_count; ++i) + for (size_t i = 0; i < size; ++i) { - int n = orig[i] * 4; - p[0] = palette[n]; - p[1] = palette[n + 1]; - p[2] = palette[n + 2]; - p[3] = palette[n + 3]; - p += 4; + int n = src[i] * 4; + dst[0] = palette[n]; + dst[1] = palette[n + 1]; + dst[2] = palette[n + 2]; + dst[3] = palette[n + 3]; + dst += 4; } } - return a.Swap(); + else + assert(0); } //------------------------------------------------------------------------------------------------- @@ -887,11 +871,12 @@ namespace Simd { if (_param.format == SimdPixelFormatNone) _param.format = SimdPixelFormatRgba32; + _expandPalette = Base::ExpandPalette; } - void ImagePngLoader::SetConverter(int channels) + void ImagePngLoader::SetConverter() { - _converter = GetConverter(_depth, channels, _param.format); + _converter = GetConverter(_depth, _outN, _param.format); } #ifdef SIMD_CPP_2011_ENABLE @@ -908,54 +893,25 @@ namespace Simd if (!ParseFile()) return false; - Png p; - p.width = _width; - p.height = _height; - p.channels = _channels; - p.depth = _depth; - InputMemoryStream zSrc = MergedDataStream(); OutputMemoryStream zDst(AlignHi(size_t(_width) * _depth, 8) * _height * _channels + _height); if(!Zlib::Decode(zSrc, zDst, !_iPhone)) return false; - int req_comp = 4; - if (Image::ChannelCount((Image::Format)_param.format) == _channels && _depth != 16) - req_comp = _channels; - - if ((req_comp == p.channels + 1 && req_comp != 3 && !_paletteChannels) || _hasTrans) - p.img_out_n = p.channels + 1; - else - p.img_out_n = p.channels; - if (!CreatePngImage(p, zDst.Data(), (int)zDst.Size(), p.img_out_n, p.depth, _color, _interlace)) + if (!CreatePngImage(zDst.Data(), (int)zDst.Size(), _outN, _depth, _color, _interlace, _width, _height, _channels, _buffer)) return 0; + if (_hasTrans) { - if (p.depth == 16) - ComputeTransparency((uint16_t*)p.buf0.data, p.width * p.height, p.img_out_n, _tc16); - + if (_depth == 16) + ComputeTransparency((uint16_t*)_buffer.data, _width * _height, _outN, _tc16); else - ComputeTransparency(p.buf0.data, p.width * p.height, p.img_out_n, _tc); + ComputeTransparency(_buffer.data, _width * _height, _outN, _tc); } - if (_paletteChannels) - { - p.channels = _paletteChannels; - p.img_out_n = _paletteChannels; - if (req_comp >= 3) - p.img_out_n = req_comp; - if (!ExpandPalette(p, _palette.data)) - return false; - } - else if (_hasTrans) - ++p.channels; - if (!(p.depth <= 8 || p.depth == 16)) - return false; + ExpandPalette(); - SIMD_PERF_BEG("conversion"); - SetConverter(p.img_out_n); - _image.Recreate(p.width, p.height, (Image::Format)_param.format); - _converter(p.buf0.data, p.width, p.height, p.width * p.img_out_n, _image.data, _image.stride); + ConvertImage(); return true; } @@ -1011,6 +967,15 @@ namespace Simd if (!_stream.ReadBe32u(crc32)) return false; } + int reqN = 4; + if (Image::ChannelCount((Image::Format)_param.format) == _channels && _depth != 16) + reqN = _channels; + else + reqN = 4; + if ((reqN == _channels + 1 && reqN != 3 && !_paletteChannels) || _hasTrans) + _outN = _channels + 1; + else + _outN = _channels; return _idats.size() != 0; } @@ -1154,5 +1119,24 @@ namespace Simd return InputMemoryStream(_idat.data, _idat.size); } } + + void ImagePngLoader::ExpandPalette() + { + if (_paletteChannels) + { + _outN = Max(_paletteChannels, _outN); + Array8u buf(_width * _height * _outN); + _expandPalette(_buffer.data, _width * _height, _outN, _palette.data, buf.data); + _buffer.Swap(buf); + } + } + + void ImagePngLoader::ConvertImage() + { + SIMD_PERF_FUNC(); + SetConverter(); + _image.Recreate(_width, _height, (Image::Format)_param.format); + _converter(_buffer.data, _width, _height, _width * _outN, _image.data, _image.stride); + } } } diff --git a/src/Simd/SimdImageLoad.h b/src/Simd/SimdImageLoad.h index ad090762ca..2f11853e32 100644 --- a/src/Simd/SimdImageLoad.h +++ b/src/Simd/SimdImageLoad.h @@ -154,19 +154,21 @@ namespace Simd virtual bool FromStream(); + typedef void (*ExpandPalettePtr)(const uint8_t* src, size_t size, int outN, const uint8_t* palette, uint8_t* dst); typedef void (*ConverterPtr)(const uint8_t* src, size_t width, size_t height, size_t srcStride, uint8_t* dst, size_t dstStride); protected: + ExpandPalettePtr _expandPalette; ConverterPtr _converter; - virtual void SetConverter(int channels); + virtual void SetConverter(); private: bool _first, _hasTrans, _iPhone; - uint32_t _width, _height, _channels; + uint32_t _width, _height, _channels, _outN; uint16_t _tc16[3]; uint8_t _depth, _color, _interlace, _paletteChannels, _tc[3]; - Array8u _palette, _idat; + Array8u _palette, _idat, _buffer; struct Chunk { @@ -185,6 +187,8 @@ namespace Simd bool ReadTransparency(const Chunk& chunk); bool ReadData(const Chunk& chunk); InputMemoryStream MergedDataStream(); + void ExpandPalette(); + void ConvertImage(); }; class ImageJpegLoader : public ImageLoader From c147637fb0eedc8eaf0a21e842a9712c826c08b1 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Fri, 9 Jun 2023 15:53:50 +0300 Subject: [PATCH 05/44] *refactoring of class ImagePngLoader (part 3). --- src/Simd/SimdBaseImageLoadPng.cpp | 522 +++++++++++++++--------------- src/Simd/SimdImageLoad.h | 2 + 2 files changed, 261 insertions(+), 263 deletions(-) diff --git a/src/Simd/SimdBaseImageLoadPng.cpp b/src/Simd/SimdBaseImageLoadPng.cpp index 0889a86d23..f35e3256ce 100644 --- a/src/Simd/SimdBaseImageLoadPng.cpp +++ b/src/Simd/SimdBaseImageLoadPng.cpp @@ -368,6 +368,8 @@ namespace Simd } } + //------------------------------------------------------------------------------------------------- + #define PNG__BYTECAST(x) ((uint8_t) ((x) & 255)) // truncate int to byte without warnings enum @@ -392,267 +394,6 @@ namespace Simd static const uint8_t DepthScaleTable[9] = { 0, 0xff, 0x55, 0, 0x11, 0,0,0, 0x01 }; - static int CreatePngImageRaw(const uint8_t* raw, uint32_t raw_len, int outN, uint32_t x, uint32_t y, int depth, int color, int channels, Array8u & dst) - { - int bytes = (depth == 16 ? 2 : 1); - uint32_t i, j, stride = x * outN * bytes; - uint32_t img_len, img_width_bytes; - int k; - int img_n = channels; - - int output_bytes = outN * bytes; - int filter_bytes = img_n * bytes; - int width = x; - - assert(outN == channels || outN == channels + 1); - - dst.Resize(x * y * output_bytes); - if (dst.Empty()) - return PngLoadError("outofmem", "Out of memory"); - - img_width_bytes = (img_n * x * depth + 7) >> 3; - img_len = (img_width_bytes + 1) * y; - - if (raw_len < img_len) - return CorruptPngError("not enough pixels"); - - for (j = 0; j < y; ++j) - { - uint8_t* cur = dst.data + stride * j; - uint8_t* prior; - int filter = *raw++; - - if (filter > 4) - return CorruptPngError("invalid filter"); - - if (depth < 8) - { - if (img_width_bytes > x) - return CorruptPngError("invalid width"); - cur += x * outN - img_width_bytes; // store output to the rightmost img_len bytes, so we can decode in place - filter_bytes = 1; - width = img_width_bytes; - } - prior = cur - stride; // bugfix: need to compute this after 'cur +=' computation above - if (j == 0) - filter = FirstRowFilter[filter]; - - for (k = 0; k < filter_bytes; ++k) - { - switch (filter) - { - case PNG__F_none: cur[k] = raw[k]; break; - case PNG__F_sub: cur[k] = raw[k]; break; - case PNG__F_up: cur[k] = PNG__BYTECAST(raw[k] + prior[k]); break; - case PNG__F_avg: cur[k] = PNG__BYTECAST(raw[k] + (prior[k] >> 1)); break; - case PNG__F_paeth: cur[k] = PNG__BYTECAST(raw[k] + Paeth(0, prior[k], 0)); break; - case PNG__F_avg_first: cur[k] = raw[k]; break; - case PNG__F_paeth_first: cur[k] = raw[k]; break; - } - } - - if (depth == 8) - { - if (img_n != outN) - cur[img_n] = 255; // first pixel - raw += img_n; - cur += outN; - prior += outN; - } - else if (depth == 16) - { - if (img_n != outN) - { - cur[filter_bytes] = 255; // first pixel top byte - cur[filter_bytes + 1] = 255; // first pixel bottom byte - } - raw += filter_bytes; - cur += output_bytes; - prior += output_bytes; - } - else - { - raw += 1; - cur += 1; - prior += 1; - } - if (depth < 8 || img_n == outN) - { - int nk = (width - 1) * filter_bytes; -#define PNG__CASE(f) \ - case f: \ - for (k=0; k < nk; ++k) - switch (filter) { - case PNG__F_none: memcpy(cur, raw, nk); break; - PNG__CASE(PNG__F_sub) { cur[k] = PNG__BYTECAST(raw[k] + cur[k - filter_bytes]); } break; - PNG__CASE(PNG__F_up) { cur[k] = PNG__BYTECAST(raw[k] + prior[k]); } break; - PNG__CASE(PNG__F_avg) { cur[k] = PNG__BYTECAST(raw[k] + ((prior[k] + cur[k - filter_bytes]) >> 1)); } break; - PNG__CASE(PNG__F_paeth) { cur[k] = PNG__BYTECAST(raw[k] + Paeth(cur[k - filter_bytes], prior[k], prior[k - filter_bytes])); } break; - PNG__CASE(PNG__F_avg_first) { cur[k] = PNG__BYTECAST(raw[k] + (cur[k - filter_bytes] >> 1)); } break; - PNG__CASE(PNG__F_paeth_first) { cur[k] = PNG__BYTECAST(raw[k] + Paeth(cur[k - filter_bytes], 0, 0)); } break; - } -#undef PNG__CASE - raw += nk; - } - else - { - assert(img_n + 1 == outN); -#define PNG__CASE(f) \ - case f: \ - for (i=x-1; i >= 1; --i, cur[filter_bytes]=255,raw+=filter_bytes,cur+=output_bytes,prior+=output_bytes) \ - for (k=0; k < filter_bytes; ++k) - switch (filter) { - PNG__CASE(PNG__F_none) { cur[k] = raw[k]; } break; - PNG__CASE(PNG__F_sub) { cur[k] = PNG__BYTECAST(raw[k] + cur[k - output_bytes]); } break; - PNG__CASE(PNG__F_up) { cur[k] = PNG__BYTECAST(raw[k] + prior[k]); } break; - PNG__CASE(PNG__F_avg) { cur[k] = PNG__BYTECAST(raw[k] + ((prior[k] + cur[k - output_bytes]) >> 1)); } break; - PNG__CASE(PNG__F_paeth) { cur[k] = PNG__BYTECAST(raw[k] + Paeth(cur[k - output_bytes], prior[k], prior[k - output_bytes])); } break; - PNG__CASE(PNG__F_avg_first) { cur[k] = PNG__BYTECAST(raw[k] + (cur[k - output_bytes] >> 1)); } break; - PNG__CASE(PNG__F_paeth_first) { cur[k] = PNG__BYTECAST(raw[k] + Paeth(cur[k - output_bytes], 0, 0)); } break; - } -#undef PNG__CASE - if (depth == 16) - { - cur = dst.data + stride * j; - for (i = 0; i < x; ++i, cur += output_bytes) - cur[filter_bytes + 1] = 255; - } - } - } - if (depth < 8) - { - for (j = 0; j < y; ++j) - { - uint8_t* cur = dst.data + stride * j; - const uint8_t* in = dst.data + stride * j + x * outN - img_width_bytes; - uint8_t scale = (color == 0) ? DepthScaleTable[depth] : 1; - if (depth == 4) - { - for (k = x * img_n; k >= 2; k -= 2, ++in) - { - *cur++ = scale * ((*in >> 4)); - *cur++ = scale * ((*in) & 0x0f); - } - if (k > 0) - *cur++ = scale * ((*in >> 4)); - } - else if (depth == 2) - { - for (k = x * img_n; k >= 4; k -= 4, ++in) - { - *cur++ = scale * ((*in >> 6)); - *cur++ = scale * ((*in >> 4) & 0x03); - *cur++ = scale * ((*in >> 2) & 0x03); - *cur++ = scale * ((*in) & 0x03); - } - if (k > 0) - *cur++ = scale * ((*in >> 6)); - if (k > 1) - *cur++ = scale * ((*in >> 4) & 0x03); - if (k > 2) - *cur++ = scale * ((*in >> 2) & 0x03); - } - else if (depth == 1) - { - for (k = x * img_n; k >= 8; k -= 8, ++in) - { - *cur++ = scale * ((*in >> 7)); - *cur++ = scale * ((*in >> 6) & 0x01); - *cur++ = scale * ((*in >> 5) & 0x01); - *cur++ = scale * ((*in >> 4) & 0x01); - *cur++ = scale * ((*in >> 3) & 0x01); - *cur++ = scale * ((*in >> 2) & 0x01); - *cur++ = scale * ((*in >> 1) & 0x01); - *cur++ = scale * ((*in) & 0x01); - } - if (k > 0) *cur++ = scale * ((*in >> 7)); - if (k > 1) *cur++ = scale * ((*in >> 6) & 0x01); - if (k > 2) *cur++ = scale * ((*in >> 5) & 0x01); - if (k > 3) *cur++ = scale * ((*in >> 4) & 0x01); - if (k > 4) *cur++ = scale * ((*in >> 3) & 0x01); - if (k > 5) *cur++ = scale * ((*in >> 2) & 0x01); - if (k > 6) *cur++ = scale * ((*in >> 1) & 0x01); - } - if (img_n != outN) - { - int q; - cur = dst.data + stride * j; - if (img_n == 1) - { - for (q = x - 1; q >= 0; --q) - { - cur[q * 2 + 1] = 255; - cur[q * 2 + 0] = cur[q]; - } - } - else - { - assert(img_n == 3); - for (q = x - 1; q >= 0; --q) - { - cur[q * 4 + 3] = 255; - cur[q * 4 + 2] = cur[q * 3 + 2]; - cur[q * 4 + 1] = cur[q * 3 + 1]; - cur[q * 4 + 0] = cur[q * 3 + 0]; - } - } - } - } - } - else if (depth == 16) - { - uint8_t* cur = dst.data; - uint16_t* cur16 = (uint16_t*)cur; - for (i = 0; i < x * y * outN; ++i, cur16++, cur += 2) - *cur16 = (cur[0] << 8) | cur[1]; - } - return 1; - } - - static int CreatePngImage(const uint8_t* image_data, uint32_t image_data_len, int outN, int depth, int color, int interlaced, int width, int height, int channels, Array8u& dst) - { - SIMD_PERF_FUNC(); - - int bytes = (depth == 16 ? 2 : 1); - int out_bytes = outN * bytes; - if (!interlaced) - return CreatePngImageRaw(image_data, image_data_len, outN, width, height, depth, color, channels, dst); - - Array8u buf(width * height * out_bytes); - for (int p = 0; p < 7; ++p) - { - int xorig[] = { 0,4,0,2,0,1,0 }; - int yorig[] = { 0,0,4,0,2,0,1 }; - int xspc[] = { 8,8,4,4,2,2,1 }; - int yspc[] = { 8,8,8,4,4,2,2 }; - int i, j, x, y; - x = (width - xorig[p] + xspc[p] - 1) / xspc[p]; - y = (height - yorig[p] + yspc[p] - 1) / yspc[p]; - if (x && y) - { - uint32_t img_len = ((((channels * x * depth) + 7) >> 3) + 1) * y; - if (!CreatePngImageRaw(image_data, image_data_len, outN, x, y, depth, color, channels, dst)) - { - return 0; - } - for (j = 0; j < y; ++j) - { - for (i = 0; i < x; ++i) - { - int out_y = j * yspc[p] + yorig[p]; - int out_x = i * xspc[p] + xorig[p]; - memcpy(buf.data + out_y * width * out_bytes + out_x * out_bytes, - dst.data + (j * x + i) * out_bytes, out_bytes); - } - } - image_data += img_len; - image_data_len -= img_len; - } - } - dst.Swap(buf); - return 1; - } - //------------------------------------------------------------------------------------------------- template void ComputeTransparency(T * dst, size_t size, size_t outN, T tc[3]) @@ -898,8 +639,8 @@ namespace Simd if(!Zlib::Decode(zSrc, zDst, !_iPhone)) return false; - if (!CreatePngImage(zDst.Data(), (int)zDst.Size(), _outN, _depth, _color, _interlace, _width, _height, _channels, _buffer)) - return 0; + if (!CreateImage(zDst.Data(), zDst.Size())) + return false; if (_hasTrans) { @@ -1120,6 +861,261 @@ namespace Simd } } + bool ImagePngLoader::CreateImage(const uint8_t* data, size_t size) + { + SIMD_PERF_FUNC(); + int outS = _outN * (_depth == 16 ? 2 : 1); + if (!_interlace) + return CreateImageRaw(data, (int)size, _width, _height); + Array8u buf(_width * _height * outS); + for (int p = 0; p < 7; ++p) + { + int xorig[] = { 0,4,0,2,0,1,0 }; + int yorig[] = { 0,0,4,0,2,0,1 }; + int xspc[] = { 8,8,4,4,2,2,1 }; + int yspc[] = { 8,8,8,4,4,2,2 }; + int i, j, x, y; + x = (_width - xorig[p] + xspc[p] - 1) / xspc[p]; + y = (_height - yorig[p] + yspc[p] - 1) / yspc[p]; + if (x && y) + { + uint32_t img_len = ((((_channels * x * _depth) + 7) >> 3) + 1) * y; + if (!CreateImageRaw(data, (int)size, x, y)) + return false; + for (j = 0; j < y; ++j) + { + for (i = 0; i < x; ++i) + { + int out_y = j * yspc[p] + yorig[p]; + int out_x = i * xspc[p] + xorig[p]; + memcpy(buf.data + out_y * _width * outS + out_x * outS, _buffer.data + (j * x + i) * outS, outS); + } + } + data += img_len; + size -= img_len; + } + } + _buffer.Swap(buf); + return true; + } + + bool ImagePngLoader::CreateImageRaw(const uint8_t* data, uint32_t size, uint32_t width, uint32_t height) + { + int bytes = (_depth == 16 ? 2 : 1); + uint32_t i, j, stride = width * _outN * bytes; + uint32_t img_len, img_width_bytes; + int k; + int img_n = _channels; + int width_ = width; + + int output_bytes = _outN * bytes; + int filter_bytes = img_n * bytes; + + assert(_outN == _channels || _outN == _channels + 1); + + _buffer.Resize(width * height * output_bytes); + if (_buffer.Empty()) + return PngLoadError("outofmem", "Out of memory"); + + img_width_bytes = (img_n * width * _depth + 7) >> 3; + img_len = (img_width_bytes + 1) * height; + + if (size < img_len) + return CorruptPngError("not enough pixels"); + + for (j = 0; j < height; ++j) + { + uint8_t* cur = _buffer.data + stride * j; + uint8_t* prior; + int filter = *data++; + + if (filter > 4) + return CorruptPngError("invalid filter"); + + if (_depth < 8) + { + if (img_width_bytes > width) + return CorruptPngError("invalid width"); + cur += width * _outN - img_width_bytes; // store output to the rightmost img_len bytes, so we can decode in place + filter_bytes = 1; + width_ = img_width_bytes; + } + prior = cur - stride; // bugfix: need to compute this after 'cur +=' computation above + if (j == 0) + filter = FirstRowFilter[filter]; + + for (k = 0; k < filter_bytes; ++k) + { + switch (filter) + { + case PNG__F_none: cur[k] = data[k]; break; + case PNG__F_sub: cur[k] = data[k]; break; + case PNG__F_up: cur[k] = PNG__BYTECAST(data[k] + prior[k]); break; + case PNG__F_avg: cur[k] = PNG__BYTECAST(data[k] + (prior[k] >> 1)); break; + case PNG__F_paeth: cur[k] = PNG__BYTECAST(data[k] + Paeth(0, prior[k], 0)); break; + case PNG__F_avg_first: cur[k] = data[k]; break; + case PNG__F_paeth_first: cur[k] = data[k]; break; + } + } + + if (_depth == 8) + { + if (img_n != _outN) + cur[img_n] = 255; // first pixel + data += img_n; + cur += _outN; + prior += _outN; + } + else if (_depth == 16) + { + if (img_n != _outN) + { + cur[filter_bytes] = 255; // first pixel top byte + cur[filter_bytes + 1] = 255; // first pixel bottom byte + } + data += filter_bytes; + cur += output_bytes; + prior += output_bytes; + } + else + { + data += 1; + cur += 1; + prior += 1; + } + if (_depth < 8 || img_n == _outN) + { + int nk = (width_ - 1) * filter_bytes; +#define PNG__CASE(f) \ + case f: \ + for (k=0; k < nk; ++k) + switch (filter) { + case PNG__F_none: memcpy(cur, data, nk); break; + PNG__CASE(PNG__F_sub) { cur[k] = PNG__BYTECAST(data[k] + cur[k - filter_bytes]); } break; + PNG__CASE(PNG__F_up) { cur[k] = PNG__BYTECAST(data[k] + prior[k]); } break; + PNG__CASE(PNG__F_avg) { cur[k] = PNG__BYTECAST(data[k] + ((prior[k] + cur[k - filter_bytes]) >> 1)); } break; + PNG__CASE(PNG__F_paeth) { cur[k] = PNG__BYTECAST(data[k] + Paeth(cur[k - filter_bytes], prior[k], prior[k - filter_bytes])); } break; + PNG__CASE(PNG__F_avg_first) { cur[k] = PNG__BYTECAST(data[k] + (cur[k - filter_bytes] >> 1)); } break; + PNG__CASE(PNG__F_paeth_first) { cur[k] = PNG__BYTECAST(data[k] + Paeth(cur[k - filter_bytes], 0, 0)); } break; + } +#undef PNG__CASE + data += nk; + } + else + { + assert(img_n + 1 == _outN); +#define PNG__CASE(f) \ + case f: \ + for (i=width-1; i >= 1; --i, cur[filter_bytes]=255,data+=filter_bytes,cur+=output_bytes,prior+=output_bytes) \ + for (k=0; k < filter_bytes; ++k) + switch (filter) { + PNG__CASE(PNG__F_none) { cur[k] = data[k]; } break; + PNG__CASE(PNG__F_sub) { cur[k] = PNG__BYTECAST(data[k] + cur[k - output_bytes]); } break; + PNG__CASE(PNG__F_up) { cur[k] = PNG__BYTECAST(data[k] + prior[k]); } break; + PNG__CASE(PNG__F_avg) { cur[k] = PNG__BYTECAST(data[k] + ((prior[k] + cur[k - output_bytes]) >> 1)); } break; + PNG__CASE(PNG__F_paeth) { cur[k] = PNG__BYTECAST(data[k] + Paeth(cur[k - output_bytes], prior[k], prior[k - output_bytes])); } break; + PNG__CASE(PNG__F_avg_first) { cur[k] = PNG__BYTECAST(data[k] + (cur[k - output_bytes] >> 1)); } break; + PNG__CASE(PNG__F_paeth_first) { cur[k] = PNG__BYTECAST(data[k] + Paeth(cur[k - output_bytes], 0, 0)); } break; + } +#undef PNG__CASE + if (_depth == 16) + { + cur = _buffer.data + stride * j; + for (i = 0; i < width; ++i, cur += output_bytes) + cur[filter_bytes + 1] = 255; + } + } + } + if (_depth < 8) + { + for (j = 0; j < height; ++j) + { + uint8_t* cur = _buffer.data + stride * j; + const uint8_t* in = _buffer.data + stride * j + width * _outN - img_width_bytes; + uint8_t scale = (_color == 0) ? DepthScaleTable[_depth] : 1; + if (_depth == 4) + { + for (k = width * img_n; k >= 2; k -= 2, ++in) + { + *cur++ = scale * ((*in >> 4)); + *cur++ = scale * ((*in) & 0x0f); + } + if (k > 0) + *cur++ = scale * ((*in >> 4)); + } + else if (_depth == 2) + { + for (k = width * img_n; k >= 4; k -= 4, ++in) + { + *cur++ = scale * ((*in >> 6)); + *cur++ = scale * ((*in >> 4) & 0x03); + *cur++ = scale * ((*in >> 2) & 0x03); + *cur++ = scale * ((*in) & 0x03); + } + if (k > 0) + *cur++ = scale * ((*in >> 6)); + if (k > 1) + *cur++ = scale * ((*in >> 4) & 0x03); + if (k > 2) + *cur++ = scale * ((*in >> 2) & 0x03); + } + else if (_depth == 1) + { + for (k = width * img_n; k >= 8; k -= 8, ++in) + { + *cur++ = scale * ((*in >> 7)); + *cur++ = scale * ((*in >> 6) & 0x01); + *cur++ = scale * ((*in >> 5) & 0x01); + *cur++ = scale * ((*in >> 4) & 0x01); + *cur++ = scale * ((*in >> 3) & 0x01); + *cur++ = scale * ((*in >> 2) & 0x01); + *cur++ = scale * ((*in >> 1) & 0x01); + *cur++ = scale * ((*in) & 0x01); + } + if (k > 0) *cur++ = scale * ((*in >> 7)); + if (k > 1) *cur++ = scale * ((*in >> 6) & 0x01); + if (k > 2) *cur++ = scale * ((*in >> 5) & 0x01); + if (k > 3) *cur++ = scale * ((*in >> 4) & 0x01); + if (k > 4) *cur++ = scale * ((*in >> 3) & 0x01); + if (k > 5) *cur++ = scale * ((*in >> 2) & 0x01); + if (k > 6) *cur++ = scale * ((*in >> 1) & 0x01); + } + if (img_n != _outN) + { + int q; + cur = _buffer.data + stride * j; + if (img_n == 1) + { + for (q = width - 1; q >= 0; --q) + { + cur[q * 2 + 1] = 255; + cur[q * 2 + 0] = cur[q]; + } + } + else + { + assert(img_n == 3); + for (q = width - 1; q >= 0; --q) + { + cur[q * 4 + 3] = 255; + cur[q * 4 + 2] = cur[q * 3 + 2]; + cur[q * 4 + 1] = cur[q * 3 + 1]; + cur[q * 4 + 0] = cur[q * 3 + 0]; + } + } + } + } + } + else if (_depth == 16) + { + uint8_t* cur = _buffer.data; + uint16_t* cur16 = (uint16_t*)cur; + for (i = 0; i < width * height * _outN; ++i, cur16++, cur += 2) + *cur16 = (cur[0] << 8) | cur[1]; + } + return 1; + } + void ImagePngLoader::ExpandPalette() { if (_paletteChannels) diff --git a/src/Simd/SimdImageLoad.h b/src/Simd/SimdImageLoad.h index 2f11853e32..82c56fd137 100644 --- a/src/Simd/SimdImageLoad.h +++ b/src/Simd/SimdImageLoad.h @@ -187,6 +187,8 @@ namespace Simd bool ReadTransparency(const Chunk& chunk); bool ReadData(const Chunk& chunk); InputMemoryStream MergedDataStream(); + bool CreateImage(const uint8_t* data, size_t size); + bool CreateImageRaw(const uint8_t* data, uint32_t size, uint32_t width, uint32_t height); void ExpandPalette(); void ConvertImage(); }; From c83d50ed8c5420300dae5fecf93b3f174857632c Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Mon, 12 Jun 2023 20:40:04 +0300 Subject: [PATCH 06/44] *refactoring of class ImagePngLoader (part 4: add line decoders). --- src/Simd/SimdBaseImageLoadPng.cpp | 327 ++++++++++++++++++++---------- src/Simd/SimdImageLoad.h | 2 + 2 files changed, 220 insertions(+), 109 deletions(-) diff --git a/src/Simd/SimdBaseImageLoadPng.cpp b/src/Simd/SimdBaseImageLoadPng.cpp index f35e3256ce..f647315ffa 100644 --- a/src/Simd/SimdBaseImageLoadPng.cpp +++ b/src/Simd/SimdBaseImageLoadPng.cpp @@ -370,29 +370,208 @@ namespace Simd //------------------------------------------------------------------------------------------------- -#define PNG__BYTECAST(x) ((uint8_t) ((x) & 255)) // truncate int to byte without warnings + static const uint8_t DepthScaleTable[9] = { 0, 0xff, 0x55, 0, 0x11, 0,0,0, 0x01 }; - enum + static void DecodeLine0(const uint8_t* curr, const uint8_t* prev, int width, int srcN, int dstN, uint8_t* dst) { - PNG__F_none = 0, - PNG__F_sub = 1, - PNG__F_up = 2, - PNG__F_avg = 3, - PNG__F_paeth = 4, - PNG__F_avg_first, - PNG__F_paeth_first - }; + if (srcN == dstN) + memcpy(dst, curr, width * srcN); + else + { + for (int x = 0; x < width; ++x) + { + int i = 0; + for (; i < srcN; ++i) + dst[i] = curr[i]; + for (; i < dstN; ++i) + dst[i] = 0xFF; + curr += srcN; + dst += dstN; + } + } + } - static uint8_t FirstRowFilter[5] = + static void DecodeLine1(const uint8_t* curr, const uint8_t* prev, int width, int srcN, int dstN, uint8_t* dst) { - PNG__F_none, - PNG__F_sub, - PNG__F_none, - PNG__F_avg_first, - PNG__F_paeth_first - }; + if (srcN == dstN) + { + for (int i = 0; i < srcN; ++i) + dst[i] = curr[i]; + for (int i = srcN, n = srcN * width; i < n; ++i) + dst[i] = curr[i] + dst[i - dstN]; + } + else + { + int i = 0; + for (; i < srcN; ++i) + dst[i] = curr[i]; + for (; i < dstN; ++i) + dst[i] = 0xFF; + curr += srcN; + dst += dstN; + for (int x = 1; x < width; ++x) + { + int i = 0; + for (; i < srcN; ++i) + dst[i] = curr[i] + dst[i - dstN]; + for (; i < dstN; ++i) + dst[i] = 0xFF; + curr += srcN; + dst += dstN; + } + } + } - static const uint8_t DepthScaleTable[9] = { 0, 0xff, 0x55, 0, 0x11, 0,0,0, 0x01 }; + static void DecodeLine2(const uint8_t* curr, const uint8_t* prev, int width, int srcN, int dstN, uint8_t* dst) + { + if (srcN == dstN) + { + for (int i = 0, n = srcN * width; i < n; ++i) + dst[i] = curr[i] + prev[i]; + } + else + { + for (int x = 0; x < width; ++x) + { + int i = 0; + for (; i < srcN; ++i) + dst[i] = curr[i] + prev[i]; + for (; i < dstN; ++i) + dst[i] = 0xFF; + curr += srcN; + prev += dstN; + dst += dstN; + } + } + } + + static void DecodeLine3(const uint8_t* curr, const uint8_t* prev, int width, int srcN, int dstN, uint8_t* dst) + { + if (srcN == dstN) + { + for (int i = 0; i < srcN; ++i) + dst[i] = curr[i] + (prev[i] >> 1); + for (int i = srcN, n = srcN * width; i < n; ++i) + dst[i] = curr[i] + ((prev[i] + dst[i - dstN]) >> 1); + } + else + { + int i = 0; + for (; i < srcN; ++i) + dst[i] = curr[i] + (prev[i] >> 1); + for (; i < dstN; ++i) + dst[i] = 0xFF; + curr += srcN; + prev += dstN; + dst += dstN; + for (int x = 1; x < width; ++x) + { + int i = 0; + for (; i < srcN; ++i) + dst[i] = curr[i] + ((prev[i] + dst[i - dstN]) >> 1); + for (; i < dstN; ++i) + dst[i] = 0xFF; + curr += srcN; + prev += dstN; + dst += dstN; + } + } + } + + static void DecodeLine4(const uint8_t* curr, const uint8_t* prev, int width, int srcN, int dstN, uint8_t* dst) + { + if (srcN == dstN) + { + for (int i = 0; i < srcN; ++i) + dst[i] = curr[i] + Paeth(0, prev[i], 0); + for (int i = srcN, n = srcN * width; i < n; ++i) + dst[i] = curr[i] + Paeth(dst[i - dstN], prev[i], prev[i - dstN]); + } + else + { + int i = 0; + for (; i < srcN; ++i) + dst[i] = curr[i] + Paeth(0, prev[i], 0); + for (; i < dstN; ++i) + dst[i] = 0xFF; + curr += srcN; + prev += dstN; + dst += dstN; + for (int x = 1; x < width; ++x) + { + int i = 0; + for (; i < srcN; ++i) + dst[i] = curr[i] + Paeth(dst[i - dstN], prev[i], prev[i - dstN]); + for (; i < dstN; ++i) + dst[i] = 0xFF; + curr += srcN; + prev += dstN; + dst += dstN; + } + } + } + + static void DecodeLine5(const uint8_t* curr, const uint8_t* prev, int width, int srcN, int dstN, uint8_t* dst) + { + if (srcN == dstN) + { + for (int i = 0; i < srcN; ++i) + dst[i] = curr[i]; + for (int i = srcN, n = srcN * width; i < n; ++i) + dst[i] = curr[i] + (dst[i - dstN] >> 1); + } + else + { + int i = 0; + for (; i < srcN; ++i) + dst[i] = curr[i]; + for (; i < dstN; ++i) + dst[i] = 0xFF; + curr += srcN; + dst += dstN; + for (int x = 1; x < width; ++x) + { + int i = 0; + for (; i < srcN; ++i) + dst[i] = curr[i] + (dst[i - dstN] >> 1);; + for (; i < dstN; ++i) + dst[i] = 0xFF; + curr += srcN; + dst += dstN; + } + } + } + + static void DecodeLine6(const uint8_t* curr, const uint8_t* prev, int width, int srcN, int dstN, uint8_t* dst) + { + if (srcN == dstN) + { + for (int i = 0; i < srcN; ++i) + dst[i] = curr[i]; + for (int i = srcN, n = srcN * width; i < n; ++i) + dst[i] = curr[i] + Paeth(dst[i - dstN], 0, 0); + } + else + { + int i = 0; + for (; i < srcN; ++i) + dst[i] = curr[i]; + for (; i < dstN; ++i) + dst[i] = 0xFF; + curr += srcN; + dst += dstN; + for (int x = 1; x < width; ++x) + { + int i = 0; + for (; i < srcN; ++i) + dst[i] = curr[i] + Paeth(dst[i - dstN], 0, 0); + for (; i < dstN; ++i) + dst[i] = 0xFF; + curr += srcN; + dst += dstN; + } + } + } //------------------------------------------------------------------------------------------------- @@ -612,6 +791,13 @@ namespace Simd { if (_param.format == SimdPixelFormatNone) _param.format = SimdPixelFormatRgba32; + _decodeLine[0] = Base::DecodeLine0; + _decodeLine[1] = Base::DecodeLine1; + _decodeLine[2] = Base::DecodeLine2; + _decodeLine[3] = Base::DecodeLine3; + _decodeLine[4] = Base::DecodeLine4; + _decodeLine[5] = Base::DecodeLine5; + _decodeLine[6] = Base::DecodeLine6; _expandPalette = Base::ExpandPalette; } @@ -901,15 +1087,15 @@ namespace Simd bool ImagePngLoader::CreateImageRaw(const uint8_t* data, uint32_t size, uint32_t width, uint32_t height) { + static uint8_t FirstRowFilter[5] = { 0, 1, 0, 5, 6 }; int bytes = (_depth == 16 ? 2 : 1); uint32_t i, j, stride = width * _outN * bytes; uint32_t img_len, img_width_bytes; int k; - int img_n = _channels; int width_ = width; int output_bytes = _outN * bytes; - int filter_bytes = img_n * bytes; + int filter_bytes = _channels * bytes; assert(_outN == _channels || _outN == _channels + 1); @@ -917,7 +1103,7 @@ namespace Simd if (_buffer.Empty()) return PngLoadError("outofmem", "Out of memory"); - img_width_bytes = (img_n * width * _depth + 7) >> 3; + img_width_bytes = (_channels * width * _depth + 7) >> 3; img_len = (img_width_bytes + 1) * height; if (size < img_len) @@ -940,91 +1126,14 @@ namespace Simd filter_bytes = 1; width_ = img_width_bytes; } - prior = cur - stride; // bugfix: need to compute this after 'cur +=' computation above + prior = cur - stride; if (j == 0) filter = FirstRowFilter[filter]; - for (k = 0; k < filter_bytes; ++k) - { - switch (filter) - { - case PNG__F_none: cur[k] = data[k]; break; - case PNG__F_sub: cur[k] = data[k]; break; - case PNG__F_up: cur[k] = PNG__BYTECAST(data[k] + prior[k]); break; - case PNG__F_avg: cur[k] = PNG__BYTECAST(data[k] + (prior[k] >> 1)); break; - case PNG__F_paeth: cur[k] = PNG__BYTECAST(data[k] + Paeth(0, prior[k], 0)); break; - case PNG__F_avg_first: cur[k] = data[k]; break; - case PNG__F_paeth_first: cur[k] = data[k]; break; - } - } - - if (_depth == 8) - { - if (img_n != _outN) - cur[img_n] = 255; // first pixel - data += img_n; - cur += _outN; - prior += _outN; - } - else if (_depth == 16) - { - if (img_n != _outN) - { - cur[filter_bytes] = 255; // first pixel top byte - cur[filter_bytes + 1] = 255; // first pixel bottom byte - } - data += filter_bytes; - cur += output_bytes; - prior += output_bytes; - } - else - { - data += 1; - cur += 1; - prior += 1; - } - if (_depth < 8 || img_n == _outN) - { - int nk = (width_ - 1) * filter_bytes; -#define PNG__CASE(f) \ - case f: \ - for (k=0; k < nk; ++k) - switch (filter) { - case PNG__F_none: memcpy(cur, data, nk); break; - PNG__CASE(PNG__F_sub) { cur[k] = PNG__BYTECAST(data[k] + cur[k - filter_bytes]); } break; - PNG__CASE(PNG__F_up) { cur[k] = PNG__BYTECAST(data[k] + prior[k]); } break; - PNG__CASE(PNG__F_avg) { cur[k] = PNG__BYTECAST(data[k] + ((prior[k] + cur[k - filter_bytes]) >> 1)); } break; - PNG__CASE(PNG__F_paeth) { cur[k] = PNG__BYTECAST(data[k] + Paeth(cur[k - filter_bytes], prior[k], prior[k - filter_bytes])); } break; - PNG__CASE(PNG__F_avg_first) { cur[k] = PNG__BYTECAST(data[k] + (cur[k - filter_bytes] >> 1)); } break; - PNG__CASE(PNG__F_paeth_first) { cur[k] = PNG__BYTECAST(data[k] + Paeth(cur[k - filter_bytes], 0, 0)); } break; - } -#undef PNG__CASE - data += nk; - } - else - { - assert(img_n + 1 == _outN); -#define PNG__CASE(f) \ - case f: \ - for (i=width-1; i >= 1; --i, cur[filter_bytes]=255,data+=filter_bytes,cur+=output_bytes,prior+=output_bytes) \ - for (k=0; k < filter_bytes; ++k) - switch (filter) { - PNG__CASE(PNG__F_none) { cur[k] = data[k]; } break; - PNG__CASE(PNG__F_sub) { cur[k] = PNG__BYTECAST(data[k] + cur[k - output_bytes]); } break; - PNG__CASE(PNG__F_up) { cur[k] = PNG__BYTECAST(data[k] + prior[k]); } break; - PNG__CASE(PNG__F_avg) { cur[k] = PNG__BYTECAST(data[k] + ((prior[k] + cur[k - output_bytes]) >> 1)); } break; - PNG__CASE(PNG__F_paeth) { cur[k] = PNG__BYTECAST(data[k] + Paeth(cur[k - output_bytes], prior[k], prior[k - output_bytes])); } break; - PNG__CASE(PNG__F_avg_first) { cur[k] = PNG__BYTECAST(data[k] + (cur[k - output_bytes] >> 1)); } break; - PNG__CASE(PNG__F_paeth_first) { cur[k] = PNG__BYTECAST(data[k] + Paeth(cur[k - output_bytes], 0, 0)); } break; - } -#undef PNG__CASE - if (_depth == 16) - { - cur = _buffer.data + stride * j; - for (i = 0; i < width; ++i, cur += output_bytes) - cur[filter_bytes + 1] = 255; - } - } + int size = (_depth < 8 || _channels == _outN ? width_ : width); + int dstN = _depth < 8 || _channels == _outN ? filter_bytes : output_bytes; + _decodeLine[filter](data, cur - stride, size, filter_bytes, dstN, cur); + data += size * filter_bytes; } if (_depth < 8) { @@ -1035,7 +1144,7 @@ namespace Simd uint8_t scale = (_color == 0) ? DepthScaleTable[_depth] : 1; if (_depth == 4) { - for (k = width * img_n; k >= 2; k -= 2, ++in) + for (k = width * _channels; k >= 2; k -= 2, ++in) { *cur++ = scale * ((*in >> 4)); *cur++ = scale * ((*in) & 0x0f); @@ -1045,7 +1154,7 @@ namespace Simd } else if (_depth == 2) { - for (k = width * img_n; k >= 4; k -= 4, ++in) + for (k = width * _channels; k >= 4; k -= 4, ++in) { *cur++ = scale * ((*in >> 6)); *cur++ = scale * ((*in >> 4) & 0x03); @@ -1061,7 +1170,7 @@ namespace Simd } else if (_depth == 1) { - for (k = width * img_n; k >= 8; k -= 8, ++in) + for (k = width * _channels; k >= 8; k -= 8, ++in) { *cur++ = scale * ((*in >> 7)); *cur++ = scale * ((*in >> 6) & 0x01); @@ -1080,11 +1189,11 @@ namespace Simd if (k > 5) *cur++ = scale * ((*in >> 2) & 0x01); if (k > 6) *cur++ = scale * ((*in >> 1) & 0x01); } - if (img_n != _outN) + if (_channels != _outN) { int q; cur = _buffer.data + stride * j; - if (img_n == 1) + if (_channels == 1) { for (q = width - 1; q >= 0; --q) { @@ -1094,7 +1203,7 @@ namespace Simd } else { - assert(img_n == 3); + assert(_channels == 3); for (q = width - 1; q >= 0; --q) { cur[q * 4 + 3] = 255; diff --git a/src/Simd/SimdImageLoad.h b/src/Simd/SimdImageLoad.h index 82c56fd137..e0184bd60c 100644 --- a/src/Simd/SimdImageLoad.h +++ b/src/Simd/SimdImageLoad.h @@ -154,11 +154,13 @@ namespace Simd virtual bool FromStream(); + typedef void (*DecodeLinePtr)(const uint8_t* curr, const uint8_t* prev, int width, int srcN, int dstN, uint8_t* dst); typedef void (*ExpandPalettePtr)(const uint8_t* src, size_t size, int outN, const uint8_t* palette, uint8_t* dst); typedef void (*ConverterPtr)(const uint8_t* src, size_t width, size_t height, size_t srcStride, uint8_t* dst, size_t dstStride); protected: + DecodeLinePtr _decodeLine[7]; ExpandPalettePtr _expandPalette; ConverterPtr _converter; virtual void SetConverter(); From 7b3fd80e03051672848f0ccaa310e19f729b6a52 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Tue, 13 Jun 2023 12:17:59 +0300 Subject: [PATCH 07/44] *refactoring of ImageJpegLoader (part 1). --- prj/vs2019/Base.vcxproj | 1 + prj/vs2019/Base.vcxproj.filters | 3 + prj/vs2019/Sse41.vcxproj | 2 + prj/vs2019/Sse41.vcxproj.filters | 6 + prj/vs2022/Base.vcxproj | 1 + prj/vs2022/Base.vcxproj.filters | 3 + prj/vs2022/Sse41.vcxproj | 2 + prj/vs2022/Sse41.vcxproj.filters | 6 + src/Simd/SimdBaseImageLoadJpeg.cpp | 118 +- src/Simd/SimdBaseImageLoadPng.cpp | 13 +- src/Simd/SimdImageLoad.h | 8 + src/Simd/SimdImageLoadJpeg.h | 65 + src/Simd/SimdSse41ImageLoad.cpp | 2 +- src/Simd/SimdSse41ImageLoadJpeg.cpp | 2243 +++++++++++++++++++++++++++ src/Test/TestImageIO.cpp | 2 +- 15 files changed, 2402 insertions(+), 73 deletions(-) create mode 100644 src/Simd/SimdImageLoadJpeg.h create mode 100644 src/Simd/SimdSse41ImageLoadJpeg.cpp diff --git a/prj/vs2019/Base.vcxproj b/prj/vs2019/Base.vcxproj index 16465b567b..fe6321e8c2 100644 --- a/prj/vs2019/Base.vcxproj +++ b/prj/vs2019/Base.vcxproj @@ -48,6 +48,7 @@ + diff --git a/prj/vs2019/Base.vcxproj.filters b/prj/vs2019/Base.vcxproj.filters index 9472abaff6..69ea38857d 100644 --- a/prj/vs2019/Base.vcxproj.filters +++ b/prj/vs2019/Base.vcxproj.filters @@ -558,6 +558,9 @@ Inc + + Inc + diff --git a/prj/vs2019/Sse41.vcxproj b/prj/vs2019/Sse41.vcxproj index a2e30fc8c9..83398668ff 100644 --- a/prj/vs2019/Sse41.vcxproj +++ b/prj/vs2019/Sse41.vcxproj @@ -51,6 +51,7 @@ + @@ -165,6 +166,7 @@ + diff --git a/prj/vs2019/Sse41.vcxproj.filters b/prj/vs2019/Sse41.vcxproj.filters index e97b5543c7..ba52d13052 100644 --- a/prj/vs2019/Sse41.vcxproj.filters +++ b/prj/vs2019/Sse41.vcxproj.filters @@ -373,6 +373,9 @@ Sse41 + + Sse41 + @@ -617,5 +620,8 @@ Inc + + Inc + \ No newline at end of file diff --git a/prj/vs2022/Base.vcxproj b/prj/vs2022/Base.vcxproj index 16465b567b..fe6321e8c2 100644 --- a/prj/vs2022/Base.vcxproj +++ b/prj/vs2022/Base.vcxproj @@ -48,6 +48,7 @@ + diff --git a/prj/vs2022/Base.vcxproj.filters b/prj/vs2022/Base.vcxproj.filters index 9472abaff6..69ea38857d 100644 --- a/prj/vs2022/Base.vcxproj.filters +++ b/prj/vs2022/Base.vcxproj.filters @@ -558,6 +558,9 @@ Inc + + Inc + diff --git a/prj/vs2022/Sse41.vcxproj b/prj/vs2022/Sse41.vcxproj index a2e30fc8c9..83398668ff 100644 --- a/prj/vs2022/Sse41.vcxproj +++ b/prj/vs2022/Sse41.vcxproj @@ -51,6 +51,7 @@ + @@ -165,6 +166,7 @@ + diff --git a/prj/vs2022/Sse41.vcxproj.filters b/prj/vs2022/Sse41.vcxproj.filters index e97b5543c7..ba52d13052 100644 --- a/prj/vs2022/Sse41.vcxproj.filters +++ b/prj/vs2022/Sse41.vcxproj.filters @@ -373,6 +373,9 @@ Sse41 + + Sse41 + @@ -617,5 +620,8 @@ Inc + + Inc + \ No newline at end of file diff --git a/src/Simd/SimdBaseImageLoadJpeg.cpp b/src/Simd/SimdBaseImageLoadJpeg.cpp index d7acfb6cda..88b095afbd 100644 --- a/src/Simd/SimdBaseImageLoadJpeg.cpp +++ b/src/Simd/SimdBaseImageLoadJpeg.cpp @@ -22,6 +22,7 @@ * SOFTWARE. */ #include "Simd/SimdImageLoad.h" +#include "Simd/SimdImageLoadJpeg.h" #include "Simd/SimdArray.h" #include "Simd/SimdCpu.h" #include "Simd/SimdBase.h" @@ -56,9 +57,6 @@ namespace Simd int (*eof) (void* user); // returns nonzero if we are at end of file/data } jpeg_io_callbacks; -#define jpeg_inline SIMD_INLINE -#define JPEG_ASSERT assert - #ifdef _MSC_VER #define JPEG_NOTUSED(v) (void)(v) #else @@ -82,19 +80,7 @@ namespace Simd jpeg_uc* img_buffer_original, * img_buffer_original_end; } jpeg__context; - static int jpeg__err(const char* str) - { - //jpeg__g_failure_reason = str; - return 0; - } - - static int jpeg__err(const char* str1, const char* str2) - { - //jpeg__g_failure_reason = str; - return 0; - } - -#define jpeg__errpuc(x,y) ((unsigned char *)(size_t) (jpeg__err(x,y)?NULL:NULL)) +#define jpeg__errpuc(x,y) ((unsigned char *)(size_t) (JpegLoadError(x,y)?NULL:NULL)) static void jpeg__refill_buffer(jpeg__context* s) { @@ -114,7 +100,7 @@ namespace Simd } } - jpeg_inline static jpeg_uc jpeg__get8(jpeg__context* s) + SIMD_INLINE static jpeg_uc jpeg__get8(jpeg__context* s) { if (s->img_buffer < s->img_buffer_end) return *s->img_buffer++; @@ -153,7 +139,7 @@ namespace Simd s->img_buffer += n; } - jpeg_inline static int jpeg__at_eof(jpeg__context* s) + SIMD_INLINE static int jpeg__at_eof(jpeg__context* s) { if (s->io.read) { if (!(s->io.eof)(s->io_user_data)) return 0; @@ -341,7 +327,7 @@ namespace Simd if (h->size[k] == j) { while (h->size[k] == j) h->code[k++] = (jpeg__uint16)(code++); - if (code - 1 >= (1u << j)) return jpeg__err("bad code lengths", "Corrupt JPEG"); + if (code - 1 >= (1u << j)) return JpegLoadError("bad code lengths", "Corrupt JPEG"); } // compute largest code + 1 for this size, preshifted as needed later h->maxcode[j] = code << (16 - j); @@ -413,7 +399,7 @@ namespace Simd static const jpeg__uint32 jpeg__bmask[17] = { 0,1,3,7,15,31,63,127,255,511,1023,2047,4095,8191,16383,32767,65535 }; // decode a jpeg huffman value from the bitstream - jpeg_inline static int jpeg__jpeg_huff_decode(jpeg__jpeg* j, jpeg__huffman* h) + SIMD_INLINE static int jpeg__jpeg_huff_decode(jpeg__jpeg* j, jpeg__huffman* h) { unsigned int temp; int c, k; @@ -454,7 +440,7 @@ namespace Simd // convert the huffman code to the symbol id c = ((j->code_buffer >> (32 - k)) & jpeg__bmask[k]) + h->delta[k]; - JPEG_ASSERT((((j->code_buffer) >> (32 - h->size[c])) & jpeg__bmask[h->size[c]]) == h->code[c]); + assert((((j->code_buffer) >> (32 - h->size[c])) & jpeg__bmask[h->size[c]]) == h->code[c]); // convert the id to a symbol j->code_bits -= k; @@ -467,7 +453,7 @@ namespace Simd // combined JPEG 'receive' and JPEG 'extend', since baseline // always extends everything it receives. - jpeg_inline static int jpeg__extend_receive(jpeg__jpeg* j, int n) + SIMD_INLINE static int jpeg__extend_receive(jpeg__jpeg* j, int n) { unsigned int k; int sgn; @@ -483,7 +469,7 @@ namespace Simd } // get some unsigned bits - jpeg_inline static int jpeg__jpeg_get_bits(jpeg__jpeg* j, int n) + SIMD_INLINE static int jpeg__jpeg_get_bits(jpeg__jpeg* j, int n) { unsigned int k; if (j->code_bits < n) jpeg__grow_buffer_unsafe(j); @@ -494,7 +480,7 @@ namespace Simd return k; } - jpeg_inline static int jpeg__jpeg_get_bit(jpeg__jpeg* j) + SIMD_INLINE static int jpeg__jpeg_get_bit(jpeg__jpeg* j) { unsigned int k; if (j->code_bits < 1) jpeg__grow_buffer_unsafe(j); @@ -529,7 +515,7 @@ namespace Simd if (j->code_bits < 16) jpeg__grow_buffer_unsafe(j); t = jpeg__jpeg_huff_decode(j, hdc); - if (t < 0) return jpeg__err("bad huffman code", "Corrupt JPEG"); + if (t < 0) return JpegLoadError("bad huffman code", "Corrupt JPEG"); // 0 all the ac values now so we can do it 32-bits at a time memset(data, 0, 64 * sizeof(data[0])); @@ -558,7 +544,7 @@ namespace Simd } else { int rs = jpeg__jpeg_huff_decode(j, hac); - if (rs < 0) return jpeg__err("bad huffman code", "Corrupt JPEG"); + if (rs < 0) return JpegLoadError("bad huffman code", "Corrupt JPEG"); s = rs & 15; r = rs >> 4; if (s == 0) { @@ -580,7 +566,7 @@ namespace Simd { int diff, dc; int t; - if (j->spec_end != 0) return jpeg__err("can't merge dc and ac", "Corrupt JPEG"); + if (j->spec_end != 0) return JpegLoadError("can't merge dc and ac", "Corrupt JPEG"); if (j->code_bits < 16) jpeg__grow_buffer_unsafe(j); @@ -588,7 +574,7 @@ namespace Simd // first scan for DC coefficient, must be first memset(data, 0, 64 * sizeof(data[0])); // 0 all the ac values now t = jpeg__jpeg_huff_decode(j, hdc); - if (t == -1) return jpeg__err("can't merge dc and ac", "Corrupt JPEG"); + if (t == -1) return JpegLoadError("can't merge dc and ac", "Corrupt JPEG"); diff = t ? jpeg__extend_receive(j, t) : 0; dc = j->img_comp[b].dc_pred + diff; @@ -608,7 +594,7 @@ namespace Simd static int jpeg__jpeg_decode_block_prog_ac(jpeg__jpeg* j, short data[64], jpeg__huffman* hac, jpeg__int16* fac) { int k; - if (j->spec_start == 0) return jpeg__err("can't merge dc and ac", "Corrupt JPEG"); + if (j->spec_start == 0) return JpegLoadError("can't merge dc and ac", "Corrupt JPEG"); if (j->succ_high == 0) { int shift = j->succ_low; @@ -635,7 +621,7 @@ namespace Simd } else { int rs = jpeg__jpeg_huff_decode(j, hac); - if (rs < 0) return jpeg__err("bad huffman code", "Corrupt JPEG"); + if (rs < 0) return JpegLoadError("bad huffman code", "Corrupt JPEG"); s = rs & 15; r = rs >> 4; if (s == 0) { @@ -680,7 +666,7 @@ namespace Simd do { int r, s; int rs = jpeg__jpeg_huff_decode(j, hac); // @OPTIMIZE see if we can use the fast path here, advance-by-r is so slow, eh - if (rs < 0) return jpeg__err("bad huffman code", "Corrupt JPEG"); + if (rs < 0) return JpegLoadError("bad huffman code", "Corrupt JPEG"); s = rs & 15; r = rs >> 4; if (s == 0) { @@ -697,7 +683,7 @@ namespace Simd } } else { - if (s != 1) return jpeg__err("bad huffman code", "Corrupt JPEG"); + if (s != 1) return JpegLoadError("bad huffman code", "Corrupt JPEG"); // sign bit if (jpeg__jpeg_get_bit(j)) s = bit; @@ -732,7 +718,7 @@ namespace Simd } // take a -128..127 value and jpeg__clamp it and convert to 0..255 - jpeg_inline static jpeg_uc jpeg__clamp(int x) + SIMD_INLINE static jpeg_uc jpeg__clamp(int x) { // trick to use a single test to catch both cases if ((unsigned int)x > 255) { @@ -1425,10 +1411,10 @@ namespace Simd int L; switch (m) { case JPEG__MARKER_none: // no marker found - return jpeg__err("expected marker", "Corrupt JPEG"); + return JpegLoadError("expected marker", "Corrupt JPEG"); case 0xDD: // DRI - specify restart interval - if (jpeg__get16be(z->s) != 4) return jpeg__err("bad DRI len", "Corrupt JPEG"); + if (jpeg__get16be(z->s) != 4) return JpegLoadError("bad DRI len", "Corrupt JPEG"); z->restart_interval = jpeg__get16be(z->s); return 1; @@ -1438,8 +1424,8 @@ namespace Simd int q = jpeg__get8(z->s); int p = q >> 4, sixteen = (p != 0); int t = q & 15, i; - if (p != 0 && p != 1) return jpeg__err("bad DQT type", "Corrupt JPEG"); - if (t > 3) return jpeg__err("bad DQT table", "Corrupt JPEG"); + if (p != 0 && p != 1) return JpegLoadError("bad DQT type", "Corrupt JPEG"); + if (t > 3) return JpegLoadError("bad DQT table", "Corrupt JPEG"); for (i = 0; i < 64; ++i) z->dequant[t][jpeg__jpeg_dezigzag[i]] = (jpeg__uint16)(sixteen ? jpeg__get16be(z->s) : jpeg__get8(z->s)); @@ -1455,7 +1441,7 @@ namespace Simd int q = jpeg__get8(z->s); int tc = q >> 4; int th = q & 15; - if (tc > 1 || th > 3) return jpeg__err("bad DHT header", "Corrupt JPEG"); + if (tc > 1 || th > 3) return JpegLoadError("bad DHT header", "Corrupt JPEG"); for (i = 0; i < 16; ++i) { sizes[i] = jpeg__get8(z->s); n += sizes[i]; @@ -1483,9 +1469,9 @@ namespace Simd L = jpeg__get16be(z->s); if (L < 2) { if (m == 0xFE) - return jpeg__err("bad COM len", "Corrupt JPEG"); + return JpegLoadError("bad COM len", "Corrupt JPEG"); else - return jpeg__err("bad APP len", "Corrupt JPEG"); + return JpegLoadError("bad APP len", "Corrupt JPEG"); } L -= 2; @@ -1521,7 +1507,7 @@ namespace Simd return 1; } - return jpeg__err("unknown marker", "Corrupt JPEG"); + return JpegLoadError("unknown marker", "Corrupt JPEG"); } // after we see SOS @@ -1530,8 +1516,8 @@ namespace Simd int i; int Ls = jpeg__get16be(z->s); z->scan_n = jpeg__get8(z->s); - if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int)z->s->img_n) return jpeg__err("bad SOS component count", "Corrupt JPEG"); - if (Ls != 6 + 2 * z->scan_n) return jpeg__err("bad SOS len", "Corrupt JPEG"); + if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int)z->s->img_n) return JpegLoadError("bad SOS component count", "Corrupt JPEG"); + if (Ls != 6 + 2 * z->scan_n) return JpegLoadError("bad SOS len", "Corrupt JPEG"); for (i = 0; i < z->scan_n; ++i) { int id = jpeg__get8(z->s), which; int q = jpeg__get8(z->s); @@ -1539,8 +1525,8 @@ namespace Simd if (z->img_comp[which].id == id) break; if (which == z->s->img_n) return 0; // no match - z->img_comp[which].hd = q >> 4; if (z->img_comp[which].hd > 3) return jpeg__err("bad DC huff", "Corrupt JPEG"); - z->img_comp[which].ha = q & 15; if (z->img_comp[which].ha > 3) return jpeg__err("bad AC huff", "Corrupt JPEG"); + z->img_comp[which].hd = q >> 4; if (z->img_comp[which].hd > 3) return JpegLoadError("bad DC huff", "Corrupt JPEG"); + z->img_comp[which].ha = q & 15; if (z->img_comp[which].ha > 3) return JpegLoadError("bad AC huff", "Corrupt JPEG"); z->order[i] = which; } @@ -1553,11 +1539,11 @@ namespace Simd z->succ_low = (aa & 15); if (z->progressive) { if (z->spec_start > 63 || z->spec_end > 63 || z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) - return jpeg__err("bad SOS", "Corrupt JPEG"); + return JpegLoadError("bad SOS", "Corrupt JPEG"); } else { - if (z->spec_start != 0) return jpeg__err("bad SOS", "Corrupt JPEG"); - if (z->succ_high != 0 || z->succ_low != 0) return jpeg__err("bad SOS", "Corrupt JPEG"); + if (z->spec_start != 0) return JpegLoadError("bad SOS", "Corrupt JPEG"); + if (z->succ_high != 0 || z->succ_low != 0) return JpegLoadError("bad SOS", "Corrupt JPEG"); z->spec_end = 63; } } @@ -1591,21 +1577,21 @@ namespace Simd { jpeg__context* s = z->s; int Lf, p, i, q, h_max = 1, v_max = 1, c; - Lf = jpeg__get16be(s); if (Lf < 11) return jpeg__err("bad SOF len", "Corrupt JPEG"); // JPEG - p = jpeg__get8(s); if (p != 8) return jpeg__err("only 8-bit", "JPEG format not supported: 8-bit only"); // JPEG baseline - s->img_y = jpeg__get16be(s); if (s->img_y == 0) return jpeg__err("no header height", "JPEG format not supported: delayed height"); // Legal, but we don't handle it--but neither does IJG - s->img_x = jpeg__get16be(s); if (s->img_x == 0) return jpeg__err("0 width", "Corrupt JPEG"); // JPEG requires - if (s->img_y > JPEG_MAX_DIMENSIONS) return jpeg__err("too large", "Very large image (corrupt?)"); - if (s->img_x > JPEG_MAX_DIMENSIONS) return jpeg__err("too large", "Very large image (corrupt?)"); + Lf = jpeg__get16be(s); if (Lf < 11) return JpegLoadError("bad SOF len", "Corrupt JPEG"); // JPEG + p = jpeg__get8(s); if (p != 8) return JpegLoadError("only 8-bit", "JPEG format not supported: 8-bit only"); // JPEG baseline + s->img_y = jpeg__get16be(s); if (s->img_y == 0) return JpegLoadError("no header height", "JPEG format not supported: delayed height"); // Legal, but we don't handle it--but neither does IJG + s->img_x = jpeg__get16be(s); if (s->img_x == 0) return JpegLoadError("0 width", "Corrupt JPEG"); // JPEG requires + if (s->img_y > JPEG_MAX_DIMENSIONS) return JpegLoadError("too large", "Very large image (corrupt?)"); + if (s->img_x > JPEG_MAX_DIMENSIONS) return JpegLoadError("too large", "Very large image (corrupt?)"); c = jpeg__get8(s); - if (c != 3 && c != 1 && c != 4) return jpeg__err("bad component count", "Corrupt JPEG"); + if (c != 3 && c != 1 && c != 4) return JpegLoadError("bad component count", "Corrupt JPEG"); s->img_n = c; for (i = 0; i < c; ++i) { z->img_comp[i].data = NULL; z->img_comp[i].linebuf = NULL; } - if (Lf != 8 + 3 * s->img_n) return jpeg__err("bad SOF len", "Corrupt JPEG"); + if (Lf != 8 + 3 * s->img_n) return JpegLoadError("bad SOF len", "Corrupt JPEG"); z->rgb = 0; for (i = 0; i < s->img_n; ++i) { @@ -1614,14 +1600,14 @@ namespace Simd if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) ++z->rgb; q = jpeg__get8(s); - z->img_comp[i].h = (q >> 4); if (!z->img_comp[i].h || z->img_comp[i].h > 4) return jpeg__err("bad H", "Corrupt JPEG"); - z->img_comp[i].v = q & 15; if (!z->img_comp[i].v || z->img_comp[i].v > 4) return jpeg__err("bad V", "Corrupt JPEG"); - z->img_comp[i].tq = jpeg__get8(s); if (z->img_comp[i].tq > 3) return jpeg__err("bad TQ", "Corrupt JPEG"); + z->img_comp[i].h = (q >> 4); if (!z->img_comp[i].h || z->img_comp[i].h > 4) return JpegLoadError("bad H", "Corrupt JPEG"); + z->img_comp[i].v = q & 15; if (!z->img_comp[i].v || z->img_comp[i].v > 4) return JpegLoadError("bad V", "Corrupt JPEG"); + z->img_comp[i].tq = jpeg__get8(s); if (z->img_comp[i].tq > 3) return JpegLoadError("bad TQ", "Corrupt JPEG"); } if (scan != JPEG__SCAN_load) return 1; - if (!jpeg__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) return jpeg__err("too large", "Image too large to decode"); + if (!jpeg__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) return JpegLoadError("too large", "Image too large to decode"); for (i = 0; i < s->img_n; ++i) { if (z->img_comp[i].h > h_max) h_max = z->img_comp[i].h; @@ -1655,7 +1641,7 @@ namespace Simd z->img_comp[i].linebuf = NULL; z->img_comp[i].raw_data = jpeg__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); if (z->img_comp[i].raw_data == NULL) - return jpeg__free_jpeg_components(z, i + 1, jpeg__err("outofmem", "Out of memory")); + return jpeg__free_jpeg_components(z, i + 1, JpegLoadError("outofmem", "Out of memory")); // align blocks for idct using mmx/sse z->img_comp[i].data = (jpeg_uc*)(((size_t)z->img_comp[i].raw_data + 15) & ~15); if (z->progressive) { @@ -1664,7 +1650,7 @@ namespace Simd z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; z->img_comp[i].raw_coeff = jpeg__malloc_mad3(z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); if (z->img_comp[i].raw_coeff == NULL) - return jpeg__free_jpeg_components(z, i + 1, jpeg__err("outofmem", "Out of memory")); + return jpeg__free_jpeg_components(z, i + 1, JpegLoadError("outofmem", "Out of memory")); z->img_comp[i].coeff = (short*)(((size_t)z->img_comp[i].raw_coeff + 15) & ~15); } } @@ -1688,7 +1674,7 @@ namespace Simd z->app14_color_transform = -1; // valid values are 0,1,2 z->marker = JPEG__MARKER_none; // initialize cached marker to empty m = jpeg__get_marker(z); - if (!jpeg__SOI(m)) return jpeg__err("no SOI", "Corrupt JPEG"); + if (!jpeg__SOI(m)) return JpegLoadError("no SOI", "Corrupt JPEG"); if (scan == JPEG__SCAN_type) return 1; m = jpeg__get_marker(z); while (!jpeg__SOF(m)) { @@ -1696,7 +1682,7 @@ namespace Simd m = jpeg__get_marker(z); while (m == JPEG__MARKER_none) { // some files have extra padding after their blocks, so ok, we'll scan - if (jpeg__at_eof(z->s)) return jpeg__err("no SOF", "Corrupt JPEG"); + if (jpeg__at_eof(z->s)) return JpegLoadError("no SOF", "Corrupt JPEG"); m = jpeg__get_marker(z); } } @@ -1735,8 +1721,8 @@ namespace Simd else if (jpeg__DNL(m)) { int Ld = jpeg__get16be(j->s); jpeg__uint32 NL = jpeg__get16be(j->s); - if (Ld != 4) return jpeg__err("bad DNL len", "Corrupt JPEG"); - if (NL != j->s->img_y) return jpeg__err("bad DNL height", "Corrupt JPEG"); + if (Ld != 4) return JpegLoadError("bad DNL len", "Corrupt JPEG"); + if (NL != j->s->img_y) return JpegLoadError("bad DNL height", "Corrupt JPEG"); } else { if (!jpeg__process_marker(j, m)) return 0; diff --git a/src/Simd/SimdBaseImageLoadPng.cpp b/src/Simd/SimdBaseImageLoadPng.cpp index f647315ffa..1c793a0395 100644 --- a/src/Simd/SimdBaseImageLoadPng.cpp +++ b/src/Simd/SimdBaseImageLoadPng.cpp @@ -817,6 +817,8 @@ namespace Simd bool ImagePngLoader::FromStream() { + SIMD_PERF_FUNC(); + if (!ParseFile()) return false; @@ -1050,16 +1052,17 @@ namespace Simd bool ImagePngLoader::CreateImage(const uint8_t* data, size_t size) { SIMD_PERF_FUNC(); + int outS = _outN * (_depth == 16 ? 2 : 1); if (!_interlace) return CreateImageRaw(data, (int)size, _width, _height); Array8u buf(_width * _height * outS); for (int p = 0; p < 7; ++p) { - int xorig[] = { 0,4,0,2,0,1,0 }; - int yorig[] = { 0,0,4,0,2,0,1 }; - int xspc[] = { 8,8,4,4,2,2,1 }; - int yspc[] = { 8,8,8,4,4,2,2 }; + static const int xorig[] = { 0,4,0,2,0,1,0 }; + static const int yorig[] = { 0,0,4,0,2,0,1 }; + static const int xspc[] = { 8,8,4,4,2,2,1 }; + static const int yspc[] = { 8,8,8,4,4,2,2 }; int i, j, x, y; x = (_width - xorig[p] + xspc[p] - 1) / xspc[p]; y = (_height - yorig[p] + yspc[p] - 1) / yspc[p]; @@ -1087,7 +1090,7 @@ namespace Simd bool ImagePngLoader::CreateImageRaw(const uint8_t* data, uint32_t size, uint32_t width, uint32_t height) { - static uint8_t FirstRowFilter[5] = { 0, 1, 0, 5, 6 }; + static const uint8_t FirstRowFilter[5] = { 0, 1, 0, 5, 6 }; int bytes = (_depth == 16 ? 2 : 1); uint32_t i, j, stride = width * _outN * bytes; uint32_t img_len, img_width_bytes; diff --git a/src/Simd/SimdImageLoad.h b/src/Simd/SimdImageLoad.h index e0184bd60c..0d83d6e9f9 100644 --- a/src/Simd/SimdImageLoad.h +++ b/src/Simd/SimdImageLoad.h @@ -255,6 +255,14 @@ namespace Simd virtual bool FromStream(); }; + class ImageJpegLoader : public Base::ImageJpegLoader + { + public: + ImageJpegLoader(const ImageLoaderParam& param); + + virtual bool FromStream(); + }; + //--------------------------------------------------------------------- uint8_t* ImageLoadFromMemory(const uint8_t* data, size_t size, size_t* stride, size_t* width, size_t* height, SimdPixelFormatType* format); diff --git a/src/Simd/SimdImageLoadJpeg.h b/src/Simd/SimdImageLoadJpeg.h new file mode 100644 index 0000000000..bdb48e7734 --- /dev/null +++ b/src/Simd/SimdImageLoadJpeg.h @@ -0,0 +1,65 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2021 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#ifndef __SimdImageLoadJpeg_h__ +#define __SimdImageLoadJpeg_h__ + +#include "Simd/SimdImageLoad.h" + +namespace Simd +{ + namespace Base + { + static SIMD_INLINE int JpegLoadError(const char* text, const char* type) + { + std::cout << "JPEG load error: " << text << ", " << type << "!" << std::endl; + return 0; + } + } + +#ifdef SIMD_SSE41_ENABLE + namespace Sse41 + { + } +#endif + +#ifdef SIMD_AVX2_ENABLE + namespace Avx2 + { + } +#endif + +#ifdef SIMD_AVX512BW_ENABLE + namespace Avx512bw + { + } +#endif + +#ifdef SIMD_NEON_ENABLE + namespace Neon + { + } +#endif +} + +#endif diff --git a/src/Simd/SimdSse41ImageLoad.cpp b/src/Simd/SimdSse41ImageLoad.cpp index 317abe95ad..d0b7900ba9 100644 --- a/src/Simd/SimdSse41ImageLoad.cpp +++ b/src/Simd/SimdSse41ImageLoad.cpp @@ -134,7 +134,7 @@ namespace Simd case SimdImageFilePpmTxt: return new ImagePpmTxtLoader(param); case SimdImageFilePpmBin: return new ImagePpmBinLoader(param); case SimdImageFilePng: return new ImagePngLoader(param); - case SimdImageFileJpeg: return new Base::ImageJpegLoader(param); + case SimdImageFileJpeg: return new ImageJpegLoader(param); default: return NULL; } diff --git a/src/Simd/SimdSse41ImageLoadJpeg.cpp b/src/Simd/SimdSse41ImageLoadJpeg.cpp new file mode 100644 index 0000000000..ac41a75962 --- /dev/null +++ b/src/Simd/SimdSse41ImageLoadJpeg.cpp @@ -0,0 +1,2243 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2022 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#include "Simd/SimdImageLoad.h" +#include "Simd/SimdImageLoadJpeg.h" +#include "Simd/SimdArray.h" +#include "Simd/SimdCpu.h" +#include "Simd/SimdBase.h" +#include "Simd/SimdSse41.h" + +namespace Simd +{ +#ifdef SIMD_SSE41_ENABLE + namespace Sse41 + { +#if defined(SIMD_X64_ENABLE) && !defined(SIMD_SSE41_DISABLE) +#define JPEG_SSE2 + static int jpeg__sse2_available(void) + { + return 1; + } +#endif + +#if defined(SIMD_ARM64_ENABLE) && !defined(SIMD_NEON_DISABLE) +#define JPEG_NEON +#endif + + typedef unsigned char jpeg_uc; + typedef unsigned short jpeg_us; + typedef unsigned short jpeg__uint16; + typedef signed short jpeg__int16; + typedef unsigned int jpeg__uint32; + typedef signed int jpeg__int32; + + typedef struct + { + int (*read) (void* user, char* data, int size); // fill 'data' with 'size' bytes. return number of bytes actually read + void (*skip) (void* user, int n); // skip the next 'n' bytes, or 'unget' the last -n bytes if negative + int (*eof) (void* user); // returns nonzero if we are at end of file/data + } jpeg_io_callbacks; + +#ifdef _MSC_VER +#define JPEG_NOTUSED(v) (void)(v) +#else +#define JPEG_NOTUSED(v) (void)sizeof(v) +#endif + + typedef struct + { + jpeg__uint32 img_x, img_y; + int img_n, img_out_n; + + jpeg_io_callbacks io; + void* io_user_data; + + int read_from_callbacks; + int buflen; + jpeg_uc buffer_start[128]; + int callback_already_read; + + jpeg_uc* img_buffer, * img_buffer_end; + jpeg_uc* img_buffer_original, * img_buffer_original_end; + } jpeg__context; + + static int JpegLoadError(const char* str1, const char* str2) + { + //jpeg__g_failure_reason = str; + return 0; + } + +#define jpeg__errpuc(x,y) ((unsigned char *)(size_t) (JpegLoadError(x,y)?NULL:NULL)) + + static void jpeg__refill_buffer(jpeg__context* s) + { + int n = (s->io.read)(s->io_user_data, (char*)s->buffer_start, s->buflen); + s->callback_already_read += (int)(s->img_buffer - s->img_buffer_original); + if (n == 0) { + // at end of file, treat same as if from memory, but need to handle case + // where s->img_buffer isn't pointing to safe memory, e.g. 0-byte file + s->read_from_callbacks = 0; + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start + 1; + *s->img_buffer = 0; + } + else { + s->img_buffer = s->buffer_start; + s->img_buffer_end = s->buffer_start + n; + } + } + + SIMD_INLINE static jpeg_uc jpeg__get8(jpeg__context* s) + { + if (s->img_buffer < s->img_buffer_end) + return *s->img_buffer++; + if (s->read_from_callbacks) { + jpeg__refill_buffer(s); + return *s->img_buffer++; + } + return 0; + } + +#define jpeg_lrot(x,y) (((x) << (y)) | ((x) >> (32 - (y)))) + +#define JPEG_SIMD_ALIGN(type, name) SIMD_ALIGNED(16) type name + + static int jpeg__get16be(jpeg__context* s) + { + int z = jpeg__get8(s); + return (z << 8) + jpeg__get8(s); + } + + static void jpeg__skip(jpeg__context* s, int n) + { + if (n == 0) return; // already there! + if (n < 0) { + s->img_buffer = s->img_buffer_end; + return; + } + if (s->io.read) { + int blen = (int)(s->img_buffer_end - s->img_buffer); + if (blen < n) { + s->img_buffer = s->img_buffer_end; + (s->io.skip)(s->io_user_data, n - blen); + return; + } + } + s->img_buffer += n; + } + + SIMD_INLINE static int jpeg__at_eof(jpeg__context* s) + { + if (s->io.read) { + if (!(s->io.eof)(s->io_user_data)) return 0; + // if feof() is true, check if buffer = end + // special case: we've only got the special 0 character at the end + if (s->read_from_callbacks == 0) return 1; + } + + return s->img_buffer >= s->img_buffer_end; + } + +#define JPEG_MALLOC(sz) malloc(sz) +#define JPEG_REALLOC(p,newsz) realloc(p,newsz) +#define JPEG_FREE(p) free(p) + +#define JPEG_MAX_DIMENSIONS (1 << 24) + + enum + { + JPEG__SCAN_load = 0, + JPEG__SCAN_type, + JPEG__SCAN_header + }; + + static void* jpeg__malloc(size_t size) + { + return JPEG_MALLOC(size); + } + + static int jpeg__addsizes_valid(int a, int b) + { + if (b < 0) return 0; + // now 0 <= b <= INT_MAX, hence also + // 0 <= INT_MAX - b <= INTMAX. + // And "a + b <= INT_MAX" (which might overflow) is the + // same as a <= INT_MAX - b (no overflow) + return a <= INT_MAX - b; + } + + static int jpeg__mul2sizes_valid(int a, int b) + { + if (a < 0 || b < 0) return 0; + if (b == 0) return 1; // mul-by-0 is always safe + // portable way to check for no overflows in a*b + return a <= INT_MAX / b; + } + + static int jpeg__mad2sizes_valid(int a, int b, int add) + { + return jpeg__mul2sizes_valid(a, b) && jpeg__addsizes_valid(a * b, add); + } + + // returns 1 if "a*b*c + add" has no negative terms/factors and doesn't overflow + static int jpeg__mad3sizes_valid(int a, int b, int c, int add) + { + return jpeg__mul2sizes_valid(a, b) && jpeg__mul2sizes_valid(a * b, c) && + jpeg__addsizes_valid(a * b * c, add); + } + + static int jpeg__mad4sizes_valid(int a, int b, int c, int d, int add) + { + return jpeg__mul2sizes_valid(a, b) && jpeg__mul2sizes_valid(a * b, c) && + jpeg__mul2sizes_valid(a * b * c, d) && jpeg__addsizes_valid(a * b * c * d, add); + } + + static void* jpeg__malloc_mad2(int a, int b, int add) + { + if (!jpeg__mad2sizes_valid(a, b, add)) return NULL; + return jpeg__malloc(a * b + add); + } + + static void* jpeg__malloc_mad3(int a, int b, int c, int add) + { + if (!jpeg__mad3sizes_valid(a, b, c, add)) return NULL; + return jpeg__malloc(a * b * c + add); + } + + static jpeg_uc jpeg__compute_y(int r, int g, int b) + { + return (jpeg_uc)(((r * 77) + (g * 150) + (29 * b)) >> 8); + } + + typedef struct + { + int bits_per_channel; + int num_channels; + int channel_order; + } jpeg__result_info; + + static void jpeg__rewind(jpeg__context* s) + { + // conceptually rewind SHOULD rewind to the beginning of the stream, + // but we just rewind to the beginning of the initial buffer, because + // we only use it after doing 'test', which only ever looks at at most 92 bytes + s->img_buffer = s->img_buffer_original; + s->img_buffer_end = s->img_buffer_original_end; + } + + //------------------------------------------------------------------------------ + + // huffman decoding acceleration +#define FAST_BITS 9 // larger handles more cases; smaller stomps less cache + + typedef struct + { + jpeg_uc fast[1 << FAST_BITS]; + // weirdly, repacking this into AoS is a 10% speed loss, instead of a win + jpeg__uint16 code[256]; + jpeg_uc values[256]; + jpeg_uc size[257]; + unsigned int maxcode[18]; + int delta[17]; // old 'firstsymbol' - old 'firstcode' + } jpeg__huffman; + + typedef struct + { + jpeg__context* s; + jpeg__huffman huff_dc[4]; + jpeg__huffman huff_ac[4]; + jpeg__uint16 dequant[4][64]; + jpeg__int16 fast_ac[4][1 << FAST_BITS]; + + // sizes for components, interleaved MCUs + int img_h_max, img_v_max; + int img_mcu_x, img_mcu_y; + int img_mcu_w, img_mcu_h; + + // definition of jpeg image component + struct + { + int id; + int h, v; + int tq; + int hd, ha; + int dc_pred; + + int x, y, w2, h2; + jpeg_uc* data; + void* raw_data, * raw_coeff; + jpeg_uc* linebuf; + short* coeff; // progressive only + int coeff_w, coeff_h; // number of 8x8 coefficient blocks + } img_comp[4]; + + jpeg__uint32 code_buffer; // jpeg entropy-coded buffer + int code_bits; // number of valid bits + unsigned char marker; // marker seen while filling entropy buffer + int nomore; // flag if we saw a marker so must stop + + int progressive; + int spec_start; + int spec_end; + int succ_high; + int succ_low; + int eob_run; + int jfif; + int app14_color_transform; // Adobe APP14 tag + int rgb; + + int scan_n, order[4]; + int restart_interval, todo; + + // kernels + void (*idct_block_kernel)(jpeg_uc* out, int out_stride, short data[64]); + void (*YCbCr_to_RGB_kernel)(jpeg_uc* out, const jpeg_uc* y, const jpeg_uc* pcb, const jpeg_uc* pcr, int count, int step); + jpeg_uc* (*resample_row_hv_2_kernel)(jpeg_uc* out, jpeg_uc* in_near, jpeg_uc* in_far, int w, int hs); + } jpeg__jpeg; + + static int jpeg__build_huffman(jpeg__huffman* h, int* count) + { + int i, j, k = 0; + unsigned int code; + // build size list for each symbol (from JPEG spec) + for (i = 0; i < 16; ++i) + for (j = 0; j < count[i]; ++j) + h->size[k++] = (jpeg_uc)(i + 1); + h->size[k] = 0; + + // compute actual symbols (from jpeg spec) + code = 0; + k = 0; + for (j = 1; j <= 16; ++j) { + // compute delta to add to code to compute symbol id + h->delta[j] = k - code; + if (h->size[k] == j) { + while (h->size[k] == j) + h->code[k++] = (jpeg__uint16)(code++); + if (code - 1 >= (1u << j)) return JpegLoadError("bad code lengths", "Corrupt JPEG"); + } + // compute largest code + 1 for this size, preshifted as needed later + h->maxcode[j] = code << (16 - j); + code <<= 1; + } + h->maxcode[j] = 0xffffffff; + + // build non-spec acceleration table; 255 is flag for not-accelerated + memset(h->fast, 255, 1 << FAST_BITS); + for (i = 0; i < k; ++i) { + int s = h->size[i]; + if (s <= FAST_BITS) { + int c = h->code[i] << (FAST_BITS - s); + int m = 1 << (FAST_BITS - s); + for (j = 0; j < m; ++j) { + h->fast[c + j] = (jpeg_uc)i; + } + } + } + return 1; + } + + // build a table that decodes both magnitude and value of small ACs in + // one go. + static void jpeg__build_fast_ac(jpeg__int16* fast_ac, jpeg__huffman* h) + { + int i; + for (i = 0; i < (1 << FAST_BITS); ++i) { + jpeg_uc fast = h->fast[i]; + fast_ac[i] = 0; + if (fast < 255) { + int rs = h->values[fast]; + int run = (rs >> 4) & 15; + int magbits = rs & 15; + int len = h->size[fast]; + + if (magbits && len + magbits <= FAST_BITS) { + // magnitude code followed by receive_extend code + int k = ((i << len) & ((1 << FAST_BITS) - 1)) >> (FAST_BITS - magbits); + int m = 1 << (magbits - 1); + if (k < m) k += (~0U << magbits) + 1; + // if the result is small enough, we can fit it in fast_ac table + if (k >= -128 && k <= 127) + fast_ac[i] = (jpeg__int16)((k * 256) + (run * 16) + (len + magbits)); + } + } + } + } + + static void jpeg__grow_buffer_unsafe(jpeg__jpeg* j) + { + do { + unsigned int b = j->nomore ? 0 : jpeg__get8(j->s); + if (b == 0xff) { + int c = jpeg__get8(j->s); + while (c == 0xff) c = jpeg__get8(j->s); // consume fill bytes + if (c != 0) { + j->marker = (unsigned char)c; + j->nomore = 1; + return; + } + } + j->code_buffer |= b << (24 - j->code_bits); + j->code_bits += 8; + } while (j->code_bits <= 24); + } + + // (1 << n) - 1 + static const jpeg__uint32 jpeg__bmask[17] = { 0,1,3,7,15,31,63,127,255,511,1023,2047,4095,8191,16383,32767,65535 }; + + // decode a jpeg huffman value from the bitstream + SIMD_INLINE static int jpeg__jpeg_huff_decode(jpeg__jpeg* j, jpeg__huffman* h) + { + unsigned int temp; + int c, k; + + if (j->code_bits < 16) jpeg__grow_buffer_unsafe(j); + + // look at the top FAST_BITS and determine what symbol ID it is, + // if the code is <= FAST_BITS + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + k = h->fast[c]; + if (k < 255) { + int s = h->size[k]; + if (s > j->code_bits) + return -1; + j->code_buffer <<= s; + j->code_bits -= s; + return h->values[k]; + } + + // naive test is to shift the code_buffer down so k bits are + // valid, then test against maxcode. To speed this up, we've + // preshifted maxcode left so that it has (16-k) 0s at the + // end; in other words, regardless of the number of bits, it + // wants to be compared against something shifted to have 16; + // that way we don't need to shift inside the loop. + temp = j->code_buffer >> 16; + for (k = FAST_BITS + 1; ; ++k) + if (temp < h->maxcode[k]) + break; + if (k == 17) { + // error! code not found + j->code_bits -= 16; + return -1; + } + + if (k > j->code_bits) + return -1; + + // convert the huffman code to the symbol id + c = ((j->code_buffer >> (32 - k)) & jpeg__bmask[k]) + h->delta[k]; + assert((((j->code_buffer) >> (32 - h->size[c])) & jpeg__bmask[h->size[c]]) == h->code[c]); + + // convert the id to a symbol + j->code_bits -= k; + j->code_buffer <<= k; + return h->values[c]; + } + + // bias[n] = (-1<code_bits < n) jpeg__grow_buffer_unsafe(j); + + sgn = (jpeg__int32)j->code_buffer >> 31; // sign bit is always in MSB + k = jpeg_lrot(j->code_buffer, n); + if (n < 0 || n >= (int)(sizeof(jpeg__bmask) / sizeof(*jpeg__bmask))) return 0; + j->code_buffer = k & ~jpeg__bmask[n]; + k &= jpeg__bmask[n]; + j->code_bits -= n; + return k + (jpeg__jbias[n] & ~sgn); + } + + // get some unsigned bits + SIMD_INLINE static int jpeg__jpeg_get_bits(jpeg__jpeg* j, int n) + { + unsigned int k; + if (j->code_bits < n) jpeg__grow_buffer_unsafe(j); + k = jpeg_lrot(j->code_buffer, n); + j->code_buffer = k & ~jpeg__bmask[n]; + k &= jpeg__bmask[n]; + j->code_bits -= n; + return k; + } + + SIMD_INLINE static int jpeg__jpeg_get_bit(jpeg__jpeg* j) + { + unsigned int k; + if (j->code_bits < 1) jpeg__grow_buffer_unsafe(j); + k = j->code_buffer; + j->code_buffer <<= 1; + --j->code_bits; + return k & 0x80000000; + } + + // given a value that's at position X in the zigzag stream, + // where does it appear in the 8x8 matrix coded as row-major? + static const jpeg_uc jpeg__jpeg_dezigzag[64 + 15] = + { + 0, 1, 8, 16, 9, 2, 3, 10, + 17, 24, 32, 25, 18, 11, 4, 5, + 12, 19, 26, 33, 40, 48, 41, 34, + 27, 20, 13, 6, 7, 14, 21, 28, + 35, 42, 49, 56, 57, 50, 43, 36, + 29, 22, 15, 23, 30, 37, 44, 51, + 58, 59, 52, 45, 38, 31, 39, 46, + 53, 60, 61, 54, 47, 55, 62, 63, + // let corrupt input sample past end + 63, 63, 63, 63, 63, 63, 63, 63, + 63, 63, 63, 63, 63, 63, 63 + }; + + // decode one 64-entry block-- + static int jpeg__jpeg_decode_block(jpeg__jpeg* j, short data[64], jpeg__huffman* hdc, jpeg__huffman* hac, jpeg__int16* fac, int b, jpeg__uint16* dequant) + { + int diff, dc, k; + int t; + + if (j->code_bits < 16) jpeg__grow_buffer_unsafe(j); + t = jpeg__jpeg_huff_decode(j, hdc); + if (t < 0) return JpegLoadError("bad huffman code", "Corrupt JPEG"); + + // 0 all the ac values now so we can do it 32-bits at a time + memset(data, 0, 64 * sizeof(data[0])); + + diff = t ? jpeg__extend_receive(j, t) : 0; + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + data[0] = (short)(dc * dequant[0]); + + // decode AC components, see JPEG spec + k = 1; + do { + unsigned int zig; + int c, r, s; + if (j->code_bits < 16) jpeg__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + j->code_buffer <<= s; + j->code_bits -= s; + // decode into unzigzag'd location + zig = jpeg__jpeg_dezigzag[k++]; + data[zig] = (short)((r >> 8) * dequant[zig]); + } + else { + int rs = jpeg__jpeg_huff_decode(j, hac); + if (rs < 0) return JpegLoadError("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (rs != 0xf0) break; // end block + k += 16; + } + else { + k += r; + // decode into unzigzag'd location + zig = jpeg__jpeg_dezigzag[k++]; + data[zig] = (short)(jpeg__extend_receive(j, s) * dequant[zig]); + } + } + } while (k < 64); + return 1; + } + + static int jpeg__jpeg_decode_block_prog_dc(jpeg__jpeg* j, short data[64], jpeg__huffman* hdc, int b) + { + int diff, dc; + int t; + if (j->spec_end != 0) return JpegLoadError("can't merge dc and ac", "Corrupt JPEG"); + + if (j->code_bits < 16) jpeg__grow_buffer_unsafe(j); + + if (j->succ_high == 0) { + // first scan for DC coefficient, must be first + memset(data, 0, 64 * sizeof(data[0])); // 0 all the ac values now + t = jpeg__jpeg_huff_decode(j, hdc); + if (t == -1) return JpegLoadError("can't merge dc and ac", "Corrupt JPEG"); + diff = t ? jpeg__extend_receive(j, t) : 0; + + dc = j->img_comp[b].dc_pred + diff; + j->img_comp[b].dc_pred = dc; + data[0] = (short)(dc << j->succ_low); + } + else { + // refinement scan for DC coefficient + if (jpeg__jpeg_get_bit(j)) + data[0] += (short)(1 << j->succ_low); + } + return 1; + } + + // @OPTIMIZE: store non-zigzagged during the decode passes, + // and only de-zigzag when dequantizing + static int jpeg__jpeg_decode_block_prog_ac(jpeg__jpeg* j, short data[64], jpeg__huffman* hac, jpeg__int16* fac) + { + int k; + if (j->spec_start == 0) return JpegLoadError("can't merge dc and ac", "Corrupt JPEG"); + + if (j->succ_high == 0) { + int shift = j->succ_low; + + if (j->eob_run) { + --j->eob_run; + return 1; + } + + k = j->spec_start; + do { + unsigned int zig; + int c, r, s; + if (j->code_bits < 16) jpeg__grow_buffer_unsafe(j); + c = (j->code_buffer >> (32 - FAST_BITS)) & ((1 << FAST_BITS) - 1); + r = fac[c]; + if (r) { // fast-AC path + k += (r >> 4) & 15; // run + s = r & 15; // combined length + j->code_buffer <<= s; + j->code_bits -= s; + zig = jpeg__jpeg_dezigzag[k++]; + data[zig] = (short)((r >> 8) << shift); + } + else { + int rs = jpeg__jpeg_huff_decode(j, hac); + if (rs < 0) return JpegLoadError("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r); + if (r) + j->eob_run += jpeg__jpeg_get_bits(j, r); + --j->eob_run; + break; + } + k += 16; + } + else { + k += r; + zig = jpeg__jpeg_dezigzag[k++]; + data[zig] = (short)(jpeg__extend_receive(j, s) << shift); + } + } + } while (k <= j->spec_end); + } + else { + // refinement scan for these AC coefficients + + short bit = (short)(1 << j->succ_low); + + if (j->eob_run) { + --j->eob_run; + for (k = j->spec_start; k <= j->spec_end; ++k) { + short* p = &data[jpeg__jpeg_dezigzag[k]]; + if (*p != 0) + if (jpeg__jpeg_get_bit(j)) + if ((*p & bit) == 0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } + } + else { + k = j->spec_start; + do { + int r, s; + int rs = jpeg__jpeg_huff_decode(j, hac); // @OPTIMIZE see if we can use the fast path here, advance-by-r is so slow, eh + if (rs < 0) return JpegLoadError("bad huffman code", "Corrupt JPEG"); + s = rs & 15; + r = rs >> 4; + if (s == 0) { + if (r < 15) { + j->eob_run = (1 << r) - 1; + if (r) + j->eob_run += jpeg__jpeg_get_bits(j, r); + r = 64; // force end of block + } + else { + // r=15 s=0 should write 16 0s, so we just do + // a run of 15 0s and then write s (which is 0), + // so we don't have to do anything special here + } + } + else { + if (s != 1) return JpegLoadError("bad huffman code", "Corrupt JPEG"); + // sign bit + if (jpeg__jpeg_get_bit(j)) + s = bit; + else + s = -bit; + } + + // advance by r + while (k <= j->spec_end) { + short* p = &data[jpeg__jpeg_dezigzag[k++]]; + if (*p != 0) { + if (jpeg__jpeg_get_bit(j)) + if ((*p & bit) == 0) { + if (*p > 0) + *p += bit; + else + *p -= bit; + } + } + else { + if (r == 0) { + *p = (short)s; + break; + } + --r; + } + } + } while (k <= j->spec_end); + } + } + return 1; + } + + // take a -128..127 value and jpeg__clamp it and convert to 0..255 + SIMD_INLINE static jpeg_uc jpeg__clamp(int x) + { + // trick to use a single test to catch both cases + if ((unsigned int)x > 255) { + if (x < 0) return 0; + if (x > 255) return 255; + } + return (jpeg_uc)x; + } + +#define jpeg__f2f(x) ((int) (((x) * 4096 + 0.5))) +#define jpeg__fsh(x) ((x) * 4096) + + // derived from jidctint -- DCT_ISLOW +#define JPEG__IDCT_1D(s0,s1,s2,s3,s4,s5,s6,s7) \ + int t0,t1,t2,t3,p1,p2,p3,p4,p5,x0,x1,x2,x3; \ + p2 = s2; \ + p3 = s6; \ + p1 = (p2+p3) * jpeg__f2f(0.5411961f); \ + t2 = p1 + p3*jpeg__f2f(-1.847759065f); \ + t3 = p1 + p2*jpeg__f2f( 0.765366865f); \ + p2 = s0; \ + p3 = s4; \ + t0 = jpeg__fsh(p2+p3); \ + t1 = jpeg__fsh(p2-p3); \ + x0 = t0+t3; \ + x3 = t0-t3; \ + x1 = t1+t2; \ + x2 = t1-t2; \ + t0 = s7; \ + t1 = s5; \ + t2 = s3; \ + t3 = s1; \ + p3 = t0+t2; \ + p4 = t1+t3; \ + p1 = t0+t3; \ + p2 = t1+t2; \ + p5 = (p3+p4)*jpeg__f2f( 1.175875602f); \ + t0 = t0*jpeg__f2f( 0.298631336f); \ + t1 = t1*jpeg__f2f( 2.053119869f); \ + t2 = t2*jpeg__f2f( 3.072711026f); \ + t3 = t3*jpeg__f2f( 1.501321110f); \ + p1 = p5 + p1*jpeg__f2f(-0.899976223f); \ + p2 = p5 + p2*jpeg__f2f(-2.562915447f); \ + p3 = p3*jpeg__f2f(-1.961570560f); \ + p4 = p4*jpeg__f2f(-0.390180644f); \ + t3 += p1+p4; \ + t2 += p2+p3; \ + t1 += p2+p4; \ + t0 += p1+p3; + + static void jpeg__idct_block(jpeg_uc* out, int out_stride, short data[64]) + { + int i, val[64], * v = val; + jpeg_uc* o; + short* d = data; + + // columns + for (i = 0; i < 8; ++i, ++d, ++v) { + // if all zeroes, shortcut -- this avoids dequantizing 0s and IDCTing + if (d[8] == 0 && d[16] == 0 && d[24] == 0 && d[32] == 0 + && d[40] == 0 && d[48] == 0 && d[56] == 0) { + // no shortcut 0 seconds + // (1|2|3|4|5|6|7)==0 0 seconds + // all separate -0.047 seconds + // 1 && 2|3 && 4|5 && 6|7: -0.047 seconds + int dcterm = d[0] * 4; + v[0] = v[8] = v[16] = v[24] = v[32] = v[40] = v[48] = v[56] = dcterm; + } + else { + JPEG__IDCT_1D(d[0], d[8], d[16], d[24], d[32], d[40], d[48], d[56]) + // constants scaled things up by 1<<12; let's bring them back + // down, but keep 2 extra bits of precision + x0 += 512; x1 += 512; x2 += 512; x3 += 512; + v[0] = (x0 + t3) >> 10; + v[56] = (x0 - t3) >> 10; + v[8] = (x1 + t2) >> 10; + v[48] = (x1 - t2) >> 10; + v[16] = (x2 + t1) >> 10; + v[40] = (x2 - t1) >> 10; + v[24] = (x3 + t0) >> 10; + v[32] = (x3 - t0) >> 10; + } + } + + for (i = 0, v = val, o = out; i < 8; ++i, v += 8, o += out_stride) { + // no fast case since the first 1D IDCT spread components out + JPEG__IDCT_1D(v[0], v[1], v[2], v[3], v[4], v[5], v[6], v[7]) + // constants scaled things up by 1<<12, plus we had 1<<2 from first + // loop, plus horizontal and vertical each scale by sqrt(8) so together + // we've got an extra 1<<3, so 1<<17 total we need to remove. + // so we want to round that, which means adding 0.5 * 1<<17, + // aka 65536. Also, we'll end up with -128 to 127 that we want + // to encode as 0..255 by adding 128, so we'll add that before the shift + x0 += 65536 + (128 << 17); + x1 += 65536 + (128 << 17); + x2 += 65536 + (128 << 17); + x3 += 65536 + (128 << 17); + // tried computing the shifts into temps, or'ing the temps to see + // if any were out of range, but that was slower + o[0] = jpeg__clamp((x0 + t3) >> 17); + o[7] = jpeg__clamp((x0 - t3) >> 17); + o[1] = jpeg__clamp((x1 + t2) >> 17); + o[6] = jpeg__clamp((x1 - t2) >> 17); + o[2] = jpeg__clamp((x2 + t1) >> 17); + o[5] = jpeg__clamp((x2 - t1) >> 17); + o[3] = jpeg__clamp((x3 + t0) >> 17); + o[4] = jpeg__clamp((x3 - t0) >> 17); + } + } + +#ifdef JPEG_SSE2 + // sse2 integer IDCT. not the fastest possible implementation but it + // produces bit-identical results to the generic C version so it's + // fully "transparent". + static void jpeg__idct_simd(jpeg_uc* out, int out_stride, short data[64]) + { + // This is constructed to match our regular (generic) integer IDCT exactly. + __m128i row0, row1, row2, row3, row4, row5, row6, row7; + __m128i tmp; + + // dot product constant: even elems=x, odd elems=y +#define dct_const(x,y) _mm_setr_epi16((x),(y),(x),(y),(x),(y),(x),(y)) + +// out(0) = c0[even]*x + c0[odd]*y (c0, x, y 16-bit, out 32-bit) +// out(1) = c1[even]*x + c1[odd]*y +#define dct_rot(out0,out1, x,y,c0,c1) \ + __m128i c0##lo = _mm_unpacklo_epi16((x),(y)); \ + __m128i c0##hi = _mm_unpackhi_epi16((x),(y)); \ + __m128i out0##_l = _mm_madd_epi16(c0##lo, c0); \ + __m128i out0##_h = _mm_madd_epi16(c0##hi, c0); \ + __m128i out1##_l = _mm_madd_epi16(c0##lo, c1); \ + __m128i out1##_h = _mm_madd_epi16(c0##hi, c1) + + // out = in << 12 (in 16-bit, out 32-bit) +#define dct_widen(out, in) \ + __m128i out##_l = _mm_srai_epi32(_mm_unpacklo_epi16(_mm_setzero_si128(), (in)), 4); \ + __m128i out##_h = _mm_srai_epi32(_mm_unpackhi_epi16(_mm_setzero_si128(), (in)), 4) + + // wide add +#define dct_wadd(out, a, b) \ + __m128i out##_l = _mm_add_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_add_epi32(a##_h, b##_h) + + // wide sub +#define dct_wsub(out, a, b) \ + __m128i out##_l = _mm_sub_epi32(a##_l, b##_l); \ + __m128i out##_h = _mm_sub_epi32(a##_h, b##_h) + + // butterfly a/b, add bias, then shift by "s" and pack +#define dct_bfly32o(out0, out1, a,b,bias,s) \ + { \ + __m128i abiased_l = _mm_add_epi32(a##_l, bias); \ + __m128i abiased_h = _mm_add_epi32(a##_h, bias); \ + dct_wadd(sum, abiased, b); \ + dct_wsub(dif, abiased, b); \ + out0 = _mm_packs_epi32(_mm_srai_epi32(sum_l, s), _mm_srai_epi32(sum_h, s)); \ + out1 = _mm_packs_epi32(_mm_srai_epi32(dif_l, s), _mm_srai_epi32(dif_h, s)); \ + } + + // 8-bit interleave step (for transposes) +#define dct_interleave8(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi8(a, b); \ + b = _mm_unpackhi_epi8(tmp, b) + + // 16-bit interleave step (for transposes) +#define dct_interleave16(a, b) \ + tmp = a; \ + a = _mm_unpacklo_epi16(a, b); \ + b = _mm_unpackhi_epi16(tmp, b) + +#define dct_pass(bias,shift) \ + { \ + /* even part */ \ + dct_rot(t2e,t3e, row2,row6, rot0_0,rot0_1); \ + __m128i sum04 = _mm_add_epi16(row0, row4); \ + __m128i dif04 = _mm_sub_epi16(row0, row4); \ + dct_widen(t0e, sum04); \ + dct_widen(t1e, dif04); \ + dct_wadd(x0, t0e, t3e); \ + dct_wsub(x3, t0e, t3e); \ + dct_wadd(x1, t1e, t2e); \ + dct_wsub(x2, t1e, t2e); \ + /* odd part */ \ + dct_rot(y0o,y2o, row7,row3, rot2_0,rot2_1); \ + dct_rot(y1o,y3o, row5,row1, rot3_0,rot3_1); \ + __m128i sum17 = _mm_add_epi16(row1, row7); \ + __m128i sum35 = _mm_add_epi16(row3, row5); \ + dct_rot(y4o,y5o, sum17,sum35, rot1_0,rot1_1); \ + dct_wadd(x4, y0o, y4o); \ + dct_wadd(x5, y1o, y5o); \ + dct_wadd(x6, y2o, y5o); \ + dct_wadd(x7, y3o, y4o); \ + dct_bfly32o(row0,row7, x0,x7,bias,shift); \ + dct_bfly32o(row1,row6, x1,x6,bias,shift); \ + dct_bfly32o(row2,row5, x2,x5,bias,shift); \ + dct_bfly32o(row3,row4, x3,x4,bias,shift); \ + } + + __m128i rot0_0 = dct_const(jpeg__f2f(0.5411961f), jpeg__f2f(0.5411961f) + jpeg__f2f(-1.847759065f)); + __m128i rot0_1 = dct_const(jpeg__f2f(0.5411961f) + jpeg__f2f(0.765366865f), jpeg__f2f(0.5411961f)); + __m128i rot1_0 = dct_const(jpeg__f2f(1.175875602f) + jpeg__f2f(-0.899976223f), jpeg__f2f(1.175875602f)); + __m128i rot1_1 = dct_const(jpeg__f2f(1.175875602f), jpeg__f2f(1.175875602f) + jpeg__f2f(-2.562915447f)); + __m128i rot2_0 = dct_const(jpeg__f2f(-1.961570560f) + jpeg__f2f(0.298631336f), jpeg__f2f(-1.961570560f)); + __m128i rot2_1 = dct_const(jpeg__f2f(-1.961570560f), jpeg__f2f(-1.961570560f) + jpeg__f2f(3.072711026f)); + __m128i rot3_0 = dct_const(jpeg__f2f(-0.390180644f) + jpeg__f2f(2.053119869f), jpeg__f2f(-0.390180644f)); + __m128i rot3_1 = dct_const(jpeg__f2f(-0.390180644f), jpeg__f2f(-0.390180644f) + jpeg__f2f(1.501321110f)); + + // rounding biases in column/row passes, see jpeg__idct_block for explanation. + __m128i bias_0 = _mm_set1_epi32(512); + __m128i bias_1 = _mm_set1_epi32(65536 + (128 << 17)); + + // load + row0 = _mm_load_si128((const __m128i*) (data + 0 * 8)); + row1 = _mm_load_si128((const __m128i*) (data + 1 * 8)); + row2 = _mm_load_si128((const __m128i*) (data + 2 * 8)); + row3 = _mm_load_si128((const __m128i*) (data + 3 * 8)); + row4 = _mm_load_si128((const __m128i*) (data + 4 * 8)); + row5 = _mm_load_si128((const __m128i*) (data + 5 * 8)); + row6 = _mm_load_si128((const __m128i*) (data + 6 * 8)); + row7 = _mm_load_si128((const __m128i*) (data + 7 * 8)); + + // column pass + dct_pass(bias_0, 10); + + { + // 16bit 8x8 transpose pass 1 + dct_interleave16(row0, row4); + dct_interleave16(row1, row5); + dct_interleave16(row2, row6); + dct_interleave16(row3, row7); + + // transpose pass 2 + dct_interleave16(row0, row2); + dct_interleave16(row1, row3); + dct_interleave16(row4, row6); + dct_interleave16(row5, row7); + + // transpose pass 3 + dct_interleave16(row0, row1); + dct_interleave16(row2, row3); + dct_interleave16(row4, row5); + dct_interleave16(row6, row7); + } + + // row pass + dct_pass(bias_1, 17); + + { + // pack + __m128i p0 = _mm_packus_epi16(row0, row1); // a0a1a2a3...a7b0b1b2b3...b7 + __m128i p1 = _mm_packus_epi16(row2, row3); + __m128i p2 = _mm_packus_epi16(row4, row5); + __m128i p3 = _mm_packus_epi16(row6, row7); + + // 8bit 8x8 transpose pass 1 + dct_interleave8(p0, p2); // a0e0a1e1... + dct_interleave8(p1, p3); // c0g0c1g1... + + // transpose pass 2 + dct_interleave8(p0, p1); // a0c0e0g0... + dct_interleave8(p2, p3); // b0d0f0h0... + + // transpose pass 3 + dct_interleave8(p0, p2); // a0b0c0d0... + dct_interleave8(p1, p3); // a4b4c4d4... + + // store + _mm_storel_epi64((__m128i*) out, p0); out += out_stride; + _mm_storel_epi64((__m128i*) out, _mm_shuffle_epi32(p0, 0x4e)); out += out_stride; + _mm_storel_epi64((__m128i*) out, p2); out += out_stride; + _mm_storel_epi64((__m128i*) out, _mm_shuffle_epi32(p2, 0x4e)); out += out_stride; + _mm_storel_epi64((__m128i*) out, p1); out += out_stride; + _mm_storel_epi64((__m128i*) out, _mm_shuffle_epi32(p1, 0x4e)); out += out_stride; + _mm_storel_epi64((__m128i*) out, p3); out += out_stride; + _mm_storel_epi64((__m128i*) out, _mm_shuffle_epi32(p3, 0x4e)); + } + +#undef dct_const +#undef dct_rot +#undef dct_widen +#undef dct_wadd +#undef dct_wsub +#undef dct_bfly32o +#undef dct_interleave8 +#undef dct_interleave16 +#undef dct_pass + } + +#endif // JPEG_SSE2 + +#define JPEG__MARKER_none 0xff + // if there's a pending marker from the entropy stream, return that + // otherwise, fetch from the stream and get a marker. if there's no + // marker, return 0xff, which is never a valid marker value + static jpeg_uc jpeg__get_marker(jpeg__jpeg* j) + { + jpeg_uc x; + if (j->marker != JPEG__MARKER_none) { x = j->marker; j->marker = JPEG__MARKER_none; return x; } + x = jpeg__get8(j->s); + if (x != 0xff) return JPEG__MARKER_none; + while (x == 0xff) + x = jpeg__get8(j->s); // consume repeated 0xff fill bytes + return x; + } + + // in each scan, we'll have scan_n components, and the order + // of the components is specified by order[] +#define JPEG__RESTART(x) ((x) >= 0xd0 && (x) <= 0xd7) + +// after a restart interval, jpeg__jpeg_reset the entropy decoder and +// the dc prediction + static void jpeg__jpeg_reset(jpeg__jpeg* j) + { + j->code_bits = 0; + j->code_buffer = 0; + j->nomore = 0; + j->img_comp[0].dc_pred = j->img_comp[1].dc_pred = j->img_comp[2].dc_pred = j->img_comp[3].dc_pred = 0; + j->marker = JPEG__MARKER_none; + j->todo = j->restart_interval ? j->restart_interval : 0x7fffffff; + j->eob_run = 0; + // no more than 1<<31 MCUs if no restart_interal? that's plenty safe, + // since we don't even allow 1<<30 pixels + } + + static int jpeg__parse_entropy_coded_data(jpeg__jpeg* z) + { + jpeg__jpeg_reset(z); + if (!z->progressive) { + if (z->scan_n == 1) { + int i, j; + JPEG_SIMD_ALIGN(short, data[64]); + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + int ha = z->img_comp[n].ha; + if (!jpeg__jpeg_decode_block(z, data, z->huff_dc + z->img_comp[n].hd, z->huff_ac + ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; + z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + i * 8, z->img_comp[n].w2, data); + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) jpeg__grow_buffer_unsafe(z); + // if it's NOT a restart, then just bail, so we get corrupt data + // rather than no data + if (!JPEG__RESTART(z->marker)) return 1; + jpeg__jpeg_reset(z); + } + } + } + return 1; + } + else { // interleaved + int i, j, k, x, y; + JPEG_SIMD_ALIGN(short, data[64]); + for (j = 0; j < z->img_mcu_y; ++j) { + for (i = 0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... process scan_n components in order + for (k = 0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y = 0; y < z->img_comp[n].v; ++y) { + for (x = 0; x < z->img_comp[n].h; ++x) { + int x2 = (i * z->img_comp[n].h + x) * 8; + int y2 = (j * z->img_comp[n].v + y) * 8; + int ha = z->img_comp[n].ha; + if (!jpeg__jpeg_decode_block(z, data, z->huff_dc + z->img_comp[n].hd, z->huff_ac + ha, z->fast_ac[ha], n, z->dequant[z->img_comp[n].tq])) return 0; + z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * y2 + x2, z->img_comp[n].w2, data); + } + } + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) jpeg__grow_buffer_unsafe(z); + if (!JPEG__RESTART(z->marker)) return 1; + jpeg__jpeg_reset(z); + } + } + } + return 1; + } + } + else { + if (z->scan_n == 1) { + int i, j; + int n = z->order[0]; + // non-interleaved data, we just need to process one block at a time, + // in trivial scanline order + // number of blocks to do just depends on how many actual "pixels" this + // component has, independent of interleaved MCU blocking and such + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + short* data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + if (z->spec_start == 0) { + if (!jpeg__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } + else { + int ha = z->img_comp[n].ha; + if (!jpeg__jpeg_decode_block_prog_ac(z, data, &z->huff_ac[ha], z->fast_ac[ha])) + return 0; + } + // every data block is an MCU, so countdown the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) jpeg__grow_buffer_unsafe(z); + if (!JPEG__RESTART(z->marker)) return 1; + jpeg__jpeg_reset(z); + } + } + } + return 1; + } + else { // interleaved + int i, j, k, x, y; + for (j = 0; j < z->img_mcu_y; ++j) { + for (i = 0; i < z->img_mcu_x; ++i) { + // scan an interleaved mcu... process scan_n components in order + for (k = 0; k < z->scan_n; ++k) { + int n = z->order[k]; + // scan out an mcu's worth of this component; that's just determined + // by the basic H and V specified for the component + for (y = 0; y < z->img_comp[n].v; ++y) { + for (x = 0; x < z->img_comp[n].h; ++x) { + int x2 = (i * z->img_comp[n].h + x); + int y2 = (j * z->img_comp[n].v + y); + short* data = z->img_comp[n].coeff + 64 * (x2 + y2 * z->img_comp[n].coeff_w); + if (!jpeg__jpeg_decode_block_prog_dc(z, data, &z->huff_dc[z->img_comp[n].hd], n)) + return 0; + } + } + } + // after all interleaved components, that's an interleaved MCU, + // so now count down the restart interval + if (--z->todo <= 0) { + if (z->code_bits < 24) jpeg__grow_buffer_unsafe(z); + if (!JPEG__RESTART(z->marker)) return 1; + jpeg__jpeg_reset(z); + } + } + } + return 1; + } + } + } + + static void jpeg__jpeg_dequantize(short* data, jpeg__uint16* dequant) + { + int i; + for (i = 0; i < 64; ++i) + data[i] *= dequant[i]; + } + + static void jpeg__jpeg_finish(jpeg__jpeg* z) + { + if (z->progressive) { + // dequantize and idct the data + int i, j, n; + for (n = 0; n < z->s->img_n; ++n) { + int w = (z->img_comp[n].x + 7) >> 3; + int h = (z->img_comp[n].y + 7) >> 3; + for (j = 0; j < h; ++j) { + for (i = 0; i < w; ++i) { + short* data = z->img_comp[n].coeff + 64 * (i + j * z->img_comp[n].coeff_w); + jpeg__jpeg_dequantize(data, z->dequant[z->img_comp[n].tq]); + z->idct_block_kernel(z->img_comp[n].data + z->img_comp[n].w2 * j * 8 + i * 8, z->img_comp[n].w2, data); + } + } + } + } + } + + static int jpeg__process_marker(jpeg__jpeg* z, int m) + { + int L; + switch (m) { + case JPEG__MARKER_none: // no marker found + return JpegLoadError("expected marker", "Corrupt JPEG"); + + case 0xDD: // DRI - specify restart interval + if (jpeg__get16be(z->s) != 4) return JpegLoadError("bad DRI len", "Corrupt JPEG"); + z->restart_interval = jpeg__get16be(z->s); + return 1; + + case 0xDB: // DQT - define quantization table + L = jpeg__get16be(z->s) - 2; + while (L > 0) { + int q = jpeg__get8(z->s); + int p = q >> 4, sixteen = (p != 0); + int t = q & 15, i; + if (p != 0 && p != 1) return JpegLoadError("bad DQT type", "Corrupt JPEG"); + if (t > 3) return JpegLoadError("bad DQT table", "Corrupt JPEG"); + + for (i = 0; i < 64; ++i) + z->dequant[t][jpeg__jpeg_dezigzag[i]] = (jpeg__uint16)(sixteen ? jpeg__get16be(z->s) : jpeg__get8(z->s)); + L -= (sixteen ? 129 : 65); + } + return L == 0; + + case 0xC4: // DHT - define huffman table + L = jpeg__get16be(z->s) - 2; + while (L > 0) { + jpeg_uc* v; + int sizes[16], i, n = 0; + int q = jpeg__get8(z->s); + int tc = q >> 4; + int th = q & 15; + if (tc > 1 || th > 3) return JpegLoadError("bad DHT header", "Corrupt JPEG"); + for (i = 0; i < 16; ++i) { + sizes[i] = jpeg__get8(z->s); + n += sizes[i]; + } + L -= 17; + if (tc == 0) { + if (!jpeg__build_huffman(z->huff_dc + th, sizes)) return 0; + v = z->huff_dc[th].values; + } + else { + if (!jpeg__build_huffman(z->huff_ac + th, sizes)) return 0; + v = z->huff_ac[th].values; + } + for (i = 0; i < n; ++i) + v[i] = jpeg__get8(z->s); + if (tc != 0) + jpeg__build_fast_ac(z->fast_ac[th], z->huff_ac + th); + L -= n; + } + return L == 0; + } + + // check for comment block or APP blocks + if ((m >= 0xE0 && m <= 0xEF) || m == 0xFE) { + L = jpeg__get16be(z->s); + if (L < 2) { + if (m == 0xFE) + return JpegLoadError("bad COM len", "Corrupt JPEG"); + else + return JpegLoadError("bad APP len", "Corrupt JPEG"); + } + L -= 2; + + if (m == 0xE0 && L >= 5) { // JFIF APP0 segment + static const unsigned char tag[5] = { 'J','F','I','F','\0' }; + int ok = 1; + int i; + for (i = 0; i < 5; ++i) + if (jpeg__get8(z->s) != tag[i]) + ok = 0; + L -= 5; + if (ok) + z->jfif = 1; + } + else if (m == 0xEE && L >= 12) { // Adobe APP14 segment + static const unsigned char tag[6] = { 'A','d','o','b','e','\0' }; + int ok = 1; + int i; + for (i = 0; i < 6; ++i) + if (jpeg__get8(z->s) != tag[i]) + ok = 0; + L -= 6; + if (ok) { + jpeg__get8(z->s); // version + jpeg__get16be(z->s); // flags0 + jpeg__get16be(z->s); // flags1 + z->app14_color_transform = jpeg__get8(z->s); // color transform + L -= 6; + } + } + + jpeg__skip(z->s, L); + return 1; + } + + return JpegLoadError("unknown marker", "Corrupt JPEG"); + } + + // after we see SOS + static int jpeg__process_scan_header(jpeg__jpeg* z) + { + int i; + int Ls = jpeg__get16be(z->s); + z->scan_n = jpeg__get8(z->s); + if (z->scan_n < 1 || z->scan_n > 4 || z->scan_n > (int)z->s->img_n) return JpegLoadError("bad SOS component count", "Corrupt JPEG"); + if (Ls != 6 + 2 * z->scan_n) return JpegLoadError("bad SOS len", "Corrupt JPEG"); + for (i = 0; i < z->scan_n; ++i) { + int id = jpeg__get8(z->s), which; + int q = jpeg__get8(z->s); + for (which = 0; which < z->s->img_n; ++which) + if (z->img_comp[which].id == id) + break; + if (which == z->s->img_n) return 0; // no match + z->img_comp[which].hd = q >> 4; if (z->img_comp[which].hd > 3) return JpegLoadError("bad DC huff", "Corrupt JPEG"); + z->img_comp[which].ha = q & 15; if (z->img_comp[which].ha > 3) return JpegLoadError("bad AC huff", "Corrupt JPEG"); + z->order[i] = which; + } + + { + int aa; + z->spec_start = jpeg__get8(z->s); + z->spec_end = jpeg__get8(z->s); // should be 63, but might be 0 + aa = jpeg__get8(z->s); + z->succ_high = (aa >> 4); + z->succ_low = (aa & 15); + if (z->progressive) { + if (z->spec_start > 63 || z->spec_end > 63 || z->spec_start > z->spec_end || z->succ_high > 13 || z->succ_low > 13) + return JpegLoadError("bad SOS", "Corrupt JPEG"); + } + else { + if (z->spec_start != 0) return JpegLoadError("bad SOS", "Corrupt JPEG"); + if (z->succ_high != 0 || z->succ_low != 0) return JpegLoadError("bad SOS", "Corrupt JPEG"); + z->spec_end = 63; + } + } + + return 1; + } + + static int jpeg__free_jpeg_components(jpeg__jpeg* z, int ncomp, int why) + { + int i; + for (i = 0; i < ncomp; ++i) { + if (z->img_comp[i].raw_data) { + JPEG_FREE(z->img_comp[i].raw_data); + z->img_comp[i].raw_data = NULL; + z->img_comp[i].data = NULL; + } + if (z->img_comp[i].raw_coeff) { + JPEG_FREE(z->img_comp[i].raw_coeff); + z->img_comp[i].raw_coeff = 0; + z->img_comp[i].coeff = 0; + } + if (z->img_comp[i].linebuf) { + JPEG_FREE(z->img_comp[i].linebuf); + z->img_comp[i].linebuf = NULL; + } + } + return why; + } + + static int jpeg__process_frame_header(jpeg__jpeg* z, int scan) + { + jpeg__context* s = z->s; + int Lf, p, i, q, h_max = 1, v_max = 1, c; + Lf = jpeg__get16be(s); if (Lf < 11) return JpegLoadError("bad SOF len", "Corrupt JPEG"); // JPEG + p = jpeg__get8(s); if (p != 8) return JpegLoadError("only 8-bit", "JPEG format not supported: 8-bit only"); // JPEG baseline + s->img_y = jpeg__get16be(s); if (s->img_y == 0) return JpegLoadError("no header height", "JPEG format not supported: delayed height"); // Legal, but we don't handle it--but neither does IJG + s->img_x = jpeg__get16be(s); if (s->img_x == 0) return JpegLoadError("0 width", "Corrupt JPEG"); // JPEG requires + if (s->img_y > JPEG_MAX_DIMENSIONS) return JpegLoadError("too large", "Very large image (corrupt?)"); + if (s->img_x > JPEG_MAX_DIMENSIONS) return JpegLoadError("too large", "Very large image (corrupt?)"); + c = jpeg__get8(s); + if (c != 3 && c != 1 && c != 4) return JpegLoadError("bad component count", "Corrupt JPEG"); + s->img_n = c; + for (i = 0; i < c; ++i) { + z->img_comp[i].data = NULL; + z->img_comp[i].linebuf = NULL; + } + + if (Lf != 8 + 3 * s->img_n) return JpegLoadError("bad SOF len", "Corrupt JPEG"); + + z->rgb = 0; + for (i = 0; i < s->img_n; ++i) { + static const unsigned char rgb[3] = { 'R', 'G', 'B' }; + z->img_comp[i].id = jpeg__get8(s); + if (s->img_n == 3 && z->img_comp[i].id == rgb[i]) + ++z->rgb; + q = jpeg__get8(s); + z->img_comp[i].h = (q >> 4); if (!z->img_comp[i].h || z->img_comp[i].h > 4) return JpegLoadError("bad H", "Corrupt JPEG"); + z->img_comp[i].v = q & 15; if (!z->img_comp[i].v || z->img_comp[i].v > 4) return JpegLoadError("bad V", "Corrupt JPEG"); + z->img_comp[i].tq = jpeg__get8(s); if (z->img_comp[i].tq > 3) return JpegLoadError("bad TQ", "Corrupt JPEG"); + } + + if (scan != JPEG__SCAN_load) return 1; + + if (!jpeg__mad3sizes_valid(s->img_x, s->img_y, s->img_n, 0)) return JpegLoadError("too large", "Image too large to decode"); + + for (i = 0; i < s->img_n; ++i) { + if (z->img_comp[i].h > h_max) h_max = z->img_comp[i].h; + if (z->img_comp[i].v > v_max) v_max = z->img_comp[i].v; + } + + // compute interleaved mcu info + z->img_h_max = h_max; + z->img_v_max = v_max; + z->img_mcu_w = h_max * 8; + z->img_mcu_h = v_max * 8; + // these sizes can't be more than 17 bits + z->img_mcu_x = (s->img_x + z->img_mcu_w - 1) / z->img_mcu_w; + z->img_mcu_y = (s->img_y + z->img_mcu_h - 1) / z->img_mcu_h; + + for (i = 0; i < s->img_n; ++i) { + // number of effective pixels (e.g. for non-interleaved MCU) + z->img_comp[i].x = (s->img_x * z->img_comp[i].h + h_max - 1) / h_max; + z->img_comp[i].y = (s->img_y * z->img_comp[i].v + v_max - 1) / v_max; + // to simplify generation, we'll allocate enough memory to decode + // the bogus oversized data from using interleaved MCUs and their + // big blocks (e.g. a 16x16 iMCU on an image of width 33); we won't + // discard the extra data until colorspace conversion + // + // img_mcu_x, img_mcu_y: <=17 bits; comp[i].h and .v are <=4 (checked earlier) + // so these muls can't overflow with 32-bit ints (which we require) + z->img_comp[i].w2 = z->img_mcu_x * z->img_comp[i].h * 8; + z->img_comp[i].h2 = z->img_mcu_y * z->img_comp[i].v * 8; + z->img_comp[i].coeff = 0; + z->img_comp[i].raw_coeff = 0; + z->img_comp[i].linebuf = NULL; + z->img_comp[i].raw_data = jpeg__malloc_mad2(z->img_comp[i].w2, z->img_comp[i].h2, 15); + if (z->img_comp[i].raw_data == NULL) + return jpeg__free_jpeg_components(z, i + 1, JpegLoadError("outofmem", "Out of memory")); + // align blocks for idct using mmx/sse + z->img_comp[i].data = (jpeg_uc*)(((size_t)z->img_comp[i].raw_data + 15) & ~15); + if (z->progressive) { + // w2, h2 are multiples of 8 (see above) + z->img_comp[i].coeff_w = z->img_comp[i].w2 / 8; + z->img_comp[i].coeff_h = z->img_comp[i].h2 / 8; + z->img_comp[i].raw_coeff = jpeg__malloc_mad3(z->img_comp[i].w2, z->img_comp[i].h2, sizeof(short), 15); + if (z->img_comp[i].raw_coeff == NULL) + return jpeg__free_jpeg_components(z, i + 1, JpegLoadError("outofmem", "Out of memory")); + z->img_comp[i].coeff = (short*)(((size_t)z->img_comp[i].raw_coeff + 15) & ~15); + } + } + + return 1; + } + + // use comparisons since in some cases we handle more than one case (e.g. SOF) +#define jpeg__DNL(x) ((x) == 0xdc) +#define jpeg__SOI(x) ((x) == 0xd8) +#define jpeg__EOI(x) ((x) == 0xd9) +#define jpeg__SOF(x) ((x) == 0xc0 || (x) == 0xc1 || (x) == 0xc2) +#define jpeg__SOS(x) ((x) == 0xda) + +#define jpeg__SOF_progressive(x) ((x) == 0xc2) + + static int jpeg__decode_jpeg_header(jpeg__jpeg* z, int scan) + { + int m; + z->jfif = 0; + z->app14_color_transform = -1; // valid values are 0,1,2 + z->marker = JPEG__MARKER_none; // initialize cached marker to empty + m = jpeg__get_marker(z); + if (!jpeg__SOI(m)) return JpegLoadError("no SOI", "Corrupt JPEG"); + if (scan == JPEG__SCAN_type) return 1; + m = jpeg__get_marker(z); + while (!jpeg__SOF(m)) { + if (!jpeg__process_marker(z, m)) return 0; + m = jpeg__get_marker(z); + while (m == JPEG__MARKER_none) { + // some files have extra padding after their blocks, so ok, we'll scan + if (jpeg__at_eof(z->s)) return JpegLoadError("no SOF", "Corrupt JPEG"); + m = jpeg__get_marker(z); + } + } + z->progressive = jpeg__SOF_progressive(m); + if (!jpeg__process_frame_header(z, scan)) return 0; + return 1; + } + + // decode image to YCbCr format + static int jpeg__decode_jpeg_image(jpeg__jpeg* j) + { + int m; + for (m = 0; m < 4; m++) { + j->img_comp[m].raw_data = NULL; + j->img_comp[m].raw_coeff = NULL; + } + j->restart_interval = 0; + if (!jpeg__decode_jpeg_header(j, JPEG__SCAN_load)) return 0; + m = jpeg__get_marker(j); + while (!jpeg__EOI(m)) { + if (jpeg__SOS(m)) { + if (!jpeg__process_scan_header(j)) return 0; + if (!jpeg__parse_entropy_coded_data(j)) return 0; + if (j->marker == JPEG__MARKER_none) { + // handle 0s at the end of image data from IP Kamera 9060 + while (!jpeg__at_eof(j->s)) { + int x = jpeg__get8(j->s); + if (x == 255) { + j->marker = jpeg__get8(j->s); + break; + } + } + // if we reach eof without hitting a marker, jpeg__get_marker() below will fail and we'll eventually return 0 + } + } + else if (jpeg__DNL(m)) { + int Ld = jpeg__get16be(j->s); + jpeg__uint32 NL = jpeg__get16be(j->s); + if (Ld != 4) return JpegLoadError("bad DNL len", "Corrupt JPEG"); + if (NL != j->s->img_y) return JpegLoadError("bad DNL height", "Corrupt JPEG"); + } + else { + if (!jpeg__process_marker(j, m)) return 0; + } + m = jpeg__get_marker(j); + } + if (j->progressive) + jpeg__jpeg_finish(j); + return 1; + } + + // static jfif-centered resampling (across block boundaries) + + typedef jpeg_uc* (*resample_row_func)(jpeg_uc* out, jpeg_uc* in0, jpeg_uc* in1, + int w, int hs); + +#define jpeg__div4(x) ((jpeg_uc) ((x) >> 2)) + + static jpeg_uc* resample_row_1(jpeg_uc* out, jpeg_uc* in_near, jpeg_uc* in_far, int w, int hs) + { + JPEG_NOTUSED(out); + JPEG_NOTUSED(in_far); + JPEG_NOTUSED(w); + JPEG_NOTUSED(hs); + return in_near; + } + + static jpeg_uc* jpeg__resample_row_v_2(jpeg_uc* out, jpeg_uc* in_near, jpeg_uc* in_far, int w, int hs) + { + // need to generate two samples vertically for every one in input + int i; + JPEG_NOTUSED(hs); + for (i = 0; i < w; ++i) + out[i] = jpeg__div4(3 * in_near[i] + in_far[i] + 2); + return out; + } + + static jpeg_uc* jpeg__resample_row_h_2(jpeg_uc* out, jpeg_uc* in_near, jpeg_uc* in_far, int w, int hs) + { + // need to generate two samples horizontally for every one in input + int i; + jpeg_uc* input = in_near; + + if (w == 1) { + // if only one sample, can't do any interpolation + out[0] = out[1] = input[0]; + return out; + } + + out[0] = input[0]; + out[1] = jpeg__div4(input[0] * 3 + input[1] + 2); + for (i = 1; i < w - 1; ++i) { + int n = 3 * input[i] + 2; + out[i * 2 + 0] = jpeg__div4(n + input[i - 1]); + out[i * 2 + 1] = jpeg__div4(n + input[i + 1]); + } + out[i * 2 + 0] = jpeg__div4(input[w - 2] * 3 + input[w - 1] + 2); + out[i * 2 + 1] = input[w - 1]; + + JPEG_NOTUSED(in_far); + JPEG_NOTUSED(hs); + + return out; + } + +#define jpeg__div16(x) ((jpeg_uc) ((x) >> 4)) + + static jpeg_uc* jpeg__resample_row_hv_2(jpeg_uc* out, jpeg_uc* in_near, jpeg_uc* in_far, int w, int hs) + { + // need to generate 2x2 samples for every one in input + int i, t0, t1; + if (w == 1) { + out[0] = out[1] = jpeg__div4(3 * in_near[0] + in_far[0] + 2); + return out; + } + + t1 = 3 * in_near[0] + in_far[0]; + out[0] = jpeg__div4(t1 + 2); + for (i = 1; i < w; ++i) { + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2 - 1] = jpeg__div16(3 * t0 + t1 + 8); + out[i * 2] = jpeg__div16(3 * t1 + t0 + 8); + } + out[w * 2 - 1] = jpeg__div4(t1 + 2); + + JPEG_NOTUSED(hs); + + return out; + } + +#if defined(JPEG_SSE2) || defined(JPEG_NEON) + static jpeg_uc* jpeg__resample_row_hv_2_simd(jpeg_uc* out, jpeg_uc* in_near, jpeg_uc* in_far, int w, int hs) + { + // need to generate 2x2 samples for every one in input + int i = 0, t0, t1; + + if (w == 1) { + out[0] = out[1] = jpeg__div4(3 * in_near[0] + in_far[0] + 2); + return out; + } + + t1 = 3 * in_near[0] + in_far[0]; + // process groups of 8 pixels for as long as we can. + // note we can't handle the last pixel in a row in this loop + // because we need to handle the filter boundary conditions. + for (; i < ((w - 1) & ~7); i += 8) { +#if defined(JPEG_SSE2) + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + __m128i zero = _mm_setzero_si128(); + __m128i farb = _mm_loadl_epi64((__m128i*) (in_far + i)); + __m128i nearb = _mm_loadl_epi64((__m128i*) (in_near + i)); + __m128i farw = _mm_unpacklo_epi8(farb, zero); + __m128i nearw = _mm_unpacklo_epi8(nearb, zero); + __m128i diff = _mm_sub_epi16(farw, nearw); + __m128i nears = _mm_slli_epi16(nearw, 2); + __m128i curr = _mm_add_epi16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + __m128i prv0 = _mm_slli_si128(curr, 2); + __m128i nxt0 = _mm_srli_si128(curr, 2); + __m128i prev = _mm_insert_epi16(prv0, t1, 0); + __m128i next = _mm_insert_epi16(nxt0, 3 * in_near[i + 8] + in_far[i + 8], 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + __m128i bias = _mm_set1_epi16(8); + __m128i curs = _mm_slli_epi16(curr, 2); + __m128i prvd = _mm_sub_epi16(prev, curr); + __m128i nxtd = _mm_sub_epi16(next, curr); + __m128i curb = _mm_add_epi16(curs, bias); + __m128i even = _mm_add_epi16(prvd, curb); + __m128i odd = _mm_add_epi16(nxtd, curb); + + // interleave even and odd pixels, then undo scaling. + __m128i int0 = _mm_unpacklo_epi16(even, odd); + __m128i int1 = _mm_unpackhi_epi16(even, odd); + __m128i de0 = _mm_srli_epi16(int0, 4); + __m128i de1 = _mm_srli_epi16(int1, 4); + + // pack and write output + __m128i outv = _mm_packus_epi16(de0, de1); + _mm_storeu_si128((__m128i*) (out + i * 2), outv); +#elif defined(JPEG_NEON) + // load and perform the vertical filtering pass + // this uses 3*x + y = 4*x + (y - x) + uint8x8_t farb = vld1_u8(in_far + i); + uint8x8_t nearb = vld1_u8(in_near + i); + int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(farb, nearb)); + int16x8_t nears = vreinterpretq_s16_u16(vshll_n_u8(nearb, 2)); + int16x8_t curr = vaddq_s16(nears, diff); // current row + + // horizontal filter works the same based on shifted vers of current + // row. "prev" is current row shifted right by 1 pixel; we need to + // insert the previous pixel value (from t1). + // "next" is current row shifted left by 1 pixel, with first pixel + // of next block of 8 pixels added in. + int16x8_t prv0 = vextq_s16(curr, curr, 7); + int16x8_t nxt0 = vextq_s16(curr, curr, 1); + int16x8_t prev = vsetq_lane_s16(t1, prv0, 0); + int16x8_t next = vsetq_lane_s16(3 * in_near[i + 8] + in_far[i + 8], nxt0, 7); + + // horizontal filter, polyphase implementation since it's convenient: + // even pixels = 3*cur + prev = cur*4 + (prev - cur) + // odd pixels = 3*cur + next = cur*4 + (next - cur) + // note the shared term. + int16x8_t curs = vshlq_n_s16(curr, 2); + int16x8_t prvd = vsubq_s16(prev, curr); + int16x8_t nxtd = vsubq_s16(next, curr); + int16x8_t even = vaddq_s16(curs, prvd); + int16x8_t odd = vaddq_s16(curs, nxtd); + + // undo scaling and round, then store with even/odd phases interleaved + uint8x8x2_t o; + o.val[0] = vqrshrun_n_s16(even, 4); + o.val[1] = vqrshrun_n_s16(odd, 4); + vst2_u8(out + i * 2, o); +#endif + + // "previous" value for next iter + t1 = 3 * in_near[i + 7] + in_far[i + 7]; + } + + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2] = jpeg__div16(3 * t1 + t0 + 8); + + for (++i; i < w; ++i) { + t0 = t1; + t1 = 3 * in_near[i] + in_far[i]; + out[i * 2 - 1] = jpeg__div16(3 * t0 + t1 + 8); + out[i * 2] = jpeg__div16(3 * t1 + t0 + 8); + } + out[w * 2 - 1] = jpeg__div4(t1 + 2); + + JPEG_NOTUSED(hs); + + return out; + } +#endif + + static jpeg_uc* jpeg__resample_row_generic(jpeg_uc* out, jpeg_uc* in_near, jpeg_uc* in_far, int w, int hs) + { + // resample with nearest-neighbor + int i, j; + JPEG_NOTUSED(in_far); + for (i = 0; i < w; ++i) + for (j = 0; j < hs; ++j) + out[i * hs + j] = in_near[i]; + return out; + } + + // this is a reduced-precision calculation of YCbCr-to-RGB introduced + // to make sure the code produces the same results in both SIMD and scalar +#define jpeg__float2fixed(x) (((int) ((x) * 4096.0f + 0.5f)) << 8) + static void jpeg__YCbCr_to_RGB_row(jpeg_uc* out, const jpeg_uc* y, const jpeg_uc* pcb, const jpeg_uc* pcr, int count, int step) + { + int i; + for (i = 0; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1 << 19); // rounding + int r, g, b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr * jpeg__float2fixed(1.40200f); + g = y_fixed + (cr * -jpeg__float2fixed(0.71414f)) + ((cb * -jpeg__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb * jpeg__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned)r > 255) { if (r < 0) r = 0; else r = 255; } + if ((unsigned)g > 255) { if (g < 0) g = 0; else g = 255; } + if ((unsigned)b > 255) { if (b < 0) b = 0; else b = 255; } + out[0] = (jpeg_uc)r; + out[1] = (jpeg_uc)g; + out[2] = (jpeg_uc)b; + out[3] = 255; + out += step; + } + } + +#if defined(JPEG_SSE2) || defined(JPEG_NEON) + static void jpeg__YCbCr_to_RGB_simd(jpeg_uc* out, jpeg_uc const* y, jpeg_uc const* pcb, jpeg_uc const* pcr, int count, int step) + { + int i = 0; + +#ifdef JPEG_SSE2 + // step == 3 is pretty ugly on the final interleave, and i'm not convinced + // it's useful in practice (you wouldn't use it for textures, for example). + // so just accelerate step == 4 case. + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. + __m128i signflip = _mm_set1_epi8(-0x80); + __m128i cr_const0 = _mm_set1_epi16((short)(1.40200f * 4096.0f + 0.5f)); + __m128i cr_const1 = _mm_set1_epi16(-(short)(0.71414f * 4096.0f + 0.5f)); + __m128i cb_const0 = _mm_set1_epi16(-(short)(0.34414f * 4096.0f + 0.5f)); + __m128i cb_const1 = _mm_set1_epi16((short)(1.77200f * 4096.0f + 0.5f)); + __m128i y_bias = _mm_set1_epi8((char)(unsigned char)128); + __m128i xw = _mm_set1_epi16(255); // alpha channel + + for (; i + 7 < count; i += 8) { + // load + __m128i y_bytes = _mm_loadl_epi64((__m128i*) (y + i)); + __m128i cr_bytes = _mm_loadl_epi64((__m128i*) (pcr + i)); + __m128i cb_bytes = _mm_loadl_epi64((__m128i*) (pcb + i)); + __m128i cr_biased = _mm_xor_si128(cr_bytes, signflip); // -128 + __m128i cb_biased = _mm_xor_si128(cb_bytes, signflip); // -128 + + // unpack to short (and left-shift cr, cb by 8) + __m128i yw = _mm_unpacklo_epi8(y_bias, y_bytes); + __m128i crw = _mm_unpacklo_epi8(_mm_setzero_si128(), cr_biased); + __m128i cbw = _mm_unpacklo_epi8(_mm_setzero_si128(), cb_biased); + + // color transform + __m128i yws = _mm_srli_epi16(yw, 4); + __m128i cr0 = _mm_mulhi_epi16(cr_const0, crw); + __m128i cb0 = _mm_mulhi_epi16(cb_const0, cbw); + __m128i cb1 = _mm_mulhi_epi16(cbw, cb_const1); + __m128i cr1 = _mm_mulhi_epi16(crw, cr_const1); + __m128i rws = _mm_add_epi16(cr0, yws); + __m128i gwt = _mm_add_epi16(cb0, yws); + __m128i bws = _mm_add_epi16(yws, cb1); + __m128i gws = _mm_add_epi16(gwt, cr1); + + // descale + __m128i rw = _mm_srai_epi16(rws, 4); + __m128i bw = _mm_srai_epi16(bws, 4); + __m128i gw = _mm_srai_epi16(gws, 4); + + // back to byte, set up for transpose + __m128i brb = _mm_packus_epi16(rw, bw); + __m128i gxb = _mm_packus_epi16(gw, xw); + + // transpose to interleave channels + __m128i t0 = _mm_unpacklo_epi8(brb, gxb); + __m128i t1 = _mm_unpackhi_epi8(brb, gxb); + __m128i o0 = _mm_unpacklo_epi16(t0, t1); + __m128i o1 = _mm_unpackhi_epi16(t0, t1); + + // store + _mm_storeu_si128((__m128i*) (out + 0), o0); + _mm_storeu_si128((__m128i*) (out + 16), o1); + out += 32; + } + } +#endif + +#ifdef JPEG_NEON + // in this version, step=3 support would be easy to add. but is there demand? + if (step == 4) { + // this is a fairly straightforward implementation and not super-optimized. + uint8x8_t signflip = vdup_n_u8(0x80); + int16x8_t cr_const0 = vdupq_n_s16((short)(1.40200f * 4096.0f + 0.5f)); + int16x8_t cr_const1 = vdupq_n_s16(-(short)(0.71414f * 4096.0f + 0.5f)); + int16x8_t cb_const0 = vdupq_n_s16(-(short)(0.34414f * 4096.0f + 0.5f)); + int16x8_t cb_const1 = vdupq_n_s16((short)(1.77200f * 4096.0f + 0.5f)); + + for (; i + 7 < count; i += 8) { + // load + uint8x8_t y_bytes = vld1_u8(y + i); + uint8x8_t cr_bytes = vld1_u8(pcr + i); + uint8x8_t cb_bytes = vld1_u8(pcb + i); + int8x8_t cr_biased = vreinterpret_s8_u8(vsub_u8(cr_bytes, signflip)); + int8x8_t cb_biased = vreinterpret_s8_u8(vsub_u8(cb_bytes, signflip)); + + // expand to s16 + int16x8_t yws = vreinterpretq_s16_u16(vshll_n_u8(y_bytes, 4)); + int16x8_t crw = vshll_n_s8(cr_biased, 7); + int16x8_t cbw = vshll_n_s8(cb_biased, 7); + + // color transform + int16x8_t cr0 = vqdmulhq_s16(crw, cr_const0); + int16x8_t cb0 = vqdmulhq_s16(cbw, cb_const0); + int16x8_t cr1 = vqdmulhq_s16(crw, cr_const1); + int16x8_t cb1 = vqdmulhq_s16(cbw, cb_const1); + int16x8_t rws = vaddq_s16(yws, cr0); + int16x8_t gws = vaddq_s16(vaddq_s16(yws, cb0), cr1); + int16x8_t bws = vaddq_s16(yws, cb1); + + // undo scaling, round, convert to byte + uint8x8x4_t o; + o.val[0] = vqrshrun_n_s16(rws, 4); + o.val[1] = vqrshrun_n_s16(gws, 4); + o.val[2] = vqrshrun_n_s16(bws, 4); + o.val[3] = vdup_n_u8(255); + + // store, interleaving r/g/b/a + vst4_u8(out, o); + out += 8 * 4; + } + } +#endif + + for (; i < count; ++i) { + int y_fixed = (y[i] << 20) + (1 << 19); // rounding + int r, g, b; + int cr = pcr[i] - 128; + int cb = pcb[i] - 128; + r = y_fixed + cr * jpeg__float2fixed(1.40200f); + g = y_fixed + cr * -jpeg__float2fixed(0.71414f) + ((cb * -jpeg__float2fixed(0.34414f)) & 0xffff0000); + b = y_fixed + cb * jpeg__float2fixed(1.77200f); + r >>= 20; + g >>= 20; + b >>= 20; + if ((unsigned)r > 255) { if (r < 0) r = 0; else r = 255; } + if ((unsigned)g > 255) { if (g < 0) g = 0; else g = 255; } + if ((unsigned)b > 255) { if (b < 0) b = 0; else b = 255; } + out[0] = (jpeg_uc)r; + out[1] = (jpeg_uc)g; + out[2] = (jpeg_uc)b; + out[3] = 255; + out += step; + } + } +#endif + + // set up the kernels + static void jpeg__setup_jpeg(jpeg__jpeg* j) + { + j->idct_block_kernel = jpeg__idct_block; + j->YCbCr_to_RGB_kernel = jpeg__YCbCr_to_RGB_row; + j->resample_row_hv_2_kernel = jpeg__resample_row_hv_2; + +#ifdef JPEG_SSE2 + if (jpeg__sse2_available()) { + j->idct_block_kernel = jpeg__idct_simd; + j->YCbCr_to_RGB_kernel = jpeg__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = jpeg__resample_row_hv_2_simd; + } +#endif + +#ifdef JPEG_NEON + j->idct_block_kernel = jpeg__idct_simd; + j->YCbCr_to_RGB_kernel = jpeg__YCbCr_to_RGB_simd; + j->resample_row_hv_2_kernel = jpeg__resample_row_hv_2_simd; +#endif + } + + // clean up the temporary component buffers + static void jpeg__cleanup_jpeg(jpeg__jpeg* j) + { + jpeg__free_jpeg_components(j, j->s->img_n, 0); + } + + typedef struct + { + resample_row_func resample; + jpeg_uc* line0, * line1; + int hs, vs; // expansion factor in each axis + int w_lores; // horizontal pixels pre-expansion + int ystep; // how far through vertical expansion we are + int ypos; // which pre-expansion row we're on + } jpeg__resample; + + // fast 0..255 * 0..255 => 0..255 rounded multiplication + static jpeg_uc jpeg__blinn_8x8(jpeg_uc x, jpeg_uc y) + { + unsigned int t = x * y + 128; + return (jpeg_uc)((t + (t >> 8)) >> 8); + } + + static jpeg_uc* load_jpeg_image(jpeg__jpeg* z, int* out_x, int* out_y, int* comp, int req_comp) + { + int n, decode_n, is_rgb; + z->s->img_n = 0; // make jpeg__cleanup_jpeg safe + + // validate req_comp + if (req_comp < 0 || req_comp > 4) return jpeg__errpuc("bad req_comp", "Internal error"); + + // load a jpeg image from whichever source, but leave in YCbCr format + if (!jpeg__decode_jpeg_image(z)) { jpeg__cleanup_jpeg(z); return NULL; } + + // determine actual number of components to generate + n = req_comp ? req_comp : z->s->img_n >= 3 ? 3 : 1; + + is_rgb = z->s->img_n == 3 && (z->rgb == 3 || (z->app14_color_transform == 0 && !z->jfif)); + + if (z->s->img_n == 3 && n < 3 && !is_rgb) + decode_n = 1; + else + decode_n = z->s->img_n; + + // resample and color-convert + { + int k; + unsigned int i, j; + jpeg_uc* output; + jpeg_uc* coutput[4] = { NULL, NULL, NULL, NULL }; + + jpeg__resample res_comp[4]; + + for (k = 0; k < decode_n; ++k) { + jpeg__resample* r = &res_comp[k]; + + // allocate line buffer big enough for upsampling off the edges + // with upsample factor of 4 + z->img_comp[k].linebuf = (jpeg_uc*)jpeg__malloc(z->s->img_x + 3); + if (!z->img_comp[k].linebuf) { jpeg__cleanup_jpeg(z); return jpeg__errpuc("outofmem", "Out of memory"); } + + r->hs = z->img_h_max / z->img_comp[k].h; + r->vs = z->img_v_max / z->img_comp[k].v; + r->ystep = r->vs >> 1; + r->w_lores = (z->s->img_x + r->hs - 1) / r->hs; + r->ypos = 0; + r->line0 = r->line1 = z->img_comp[k].data; + + if (r->hs == 1 && r->vs == 1) r->resample = resample_row_1; + else if (r->hs == 1 && r->vs == 2) r->resample = jpeg__resample_row_v_2; + else if (r->hs == 2 && r->vs == 1) r->resample = jpeg__resample_row_h_2; + else if (r->hs == 2 && r->vs == 2) r->resample = z->resample_row_hv_2_kernel; + else r->resample = jpeg__resample_row_generic; + } + + // can't error after this so, this is safe + output = (jpeg_uc*)jpeg__malloc_mad3(n, z->s->img_x, z->s->img_y, 1); + if (!output) { jpeg__cleanup_jpeg(z); return jpeg__errpuc("outofmem", "Out of memory"); } + + // now go ahead and resample + for (j = 0; j < z->s->img_y; ++j) { + jpeg_uc* out = output + n * z->s->img_x * j; + for (k = 0; k < decode_n; ++k) { + jpeg__resample* r = &res_comp[k]; + int y_bot = r->ystep >= (r->vs >> 1); + coutput[k] = r->resample(z->img_comp[k].linebuf, + y_bot ? r->line1 : r->line0, + y_bot ? r->line0 : r->line1, + r->w_lores, r->hs); + if (++r->ystep >= r->vs) { + r->ystep = 0; + r->line0 = r->line1; + if (++r->ypos < z->img_comp[k].y) + r->line1 += z->img_comp[k].w2; + } + } + if (n >= 3) { + jpeg_uc* y = coutput[0]; + if (z->s->img_n == 3) { + if (is_rgb) { + for (i = 0; i < z->s->img_x; ++i) { + out[0] = y[i]; + out[1] = coutput[1][i]; + out[2] = coutput[2][i]; + out[3] = 255; + out += n; + } + } + else { + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); + } + } + else if (z->s->img_n == 4) { + if (z->app14_color_transform == 0) { // CMYK + for (i = 0; i < z->s->img_x; ++i) { + jpeg_uc m = coutput[3][i]; + out[0] = jpeg__blinn_8x8(coutput[0][i], m); + out[1] = jpeg__blinn_8x8(coutput[1][i], m); + out[2] = jpeg__blinn_8x8(coutput[2][i], m); + out[3] = 255; + out += n; + } + } + else if (z->app14_color_transform == 2) { // YCCK + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); + for (i = 0; i < z->s->img_x; ++i) { + jpeg_uc m = coutput[3][i]; + out[0] = jpeg__blinn_8x8(255 - out[0], m); + out[1] = jpeg__blinn_8x8(255 - out[1], m); + out[2] = jpeg__blinn_8x8(255 - out[2], m); + out += n; + } + } + else { // YCbCr + alpha? Ignore the fourth channel for now + z->YCbCr_to_RGB_kernel(out, y, coutput[1], coutput[2], z->s->img_x, n); + } + } + else + for (i = 0; i < z->s->img_x; ++i) { + out[0] = out[1] = out[2] = y[i]; + out[3] = 255; // not used if n==3 + out += n; + } + } + else { + if (is_rgb) { + if (n == 1) + for (i = 0; i < z->s->img_x; ++i) + *out++ = jpeg__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + else { + for (i = 0; i < z->s->img_x; ++i, out += 2) { + out[0] = jpeg__compute_y(coutput[0][i], coutput[1][i], coutput[2][i]); + out[1] = 255; + } + } + } + else if (z->s->img_n == 4 && z->app14_color_transform == 0) { + for (i = 0; i < z->s->img_x; ++i) { + jpeg_uc m = coutput[3][i]; + jpeg_uc r = jpeg__blinn_8x8(coutput[0][i], m); + jpeg_uc g = jpeg__blinn_8x8(coutput[1][i], m); + jpeg_uc b = jpeg__blinn_8x8(coutput[2][i], m); + out[0] = jpeg__compute_y(r, g, b); + out[1] = 255; + out += n; + } + } + else if (z->s->img_n == 4 && z->app14_color_transform == 2) { + for (i = 0; i < z->s->img_x; ++i) { + out[0] = jpeg__blinn_8x8(255 - coutput[0][i], coutput[3][i]); + out[1] = 255; + out += n; + } + } + else { + jpeg_uc* y = coutput[0]; + if (n == 1) + for (i = 0; i < z->s->img_x; ++i) out[i] = y[i]; + else + for (i = 0; i < z->s->img_x; ++i) { *out++ = y[i]; *out++ = 255; } + } + } + } + jpeg__cleanup_jpeg(z); + *out_x = z->s->img_x; + *out_y = z->s->img_y; + if (comp) *comp = z->s->img_n >= 3 ? 3 : 1; // report original components, not output + return output; + } + } + + static void* jpeg__jpeg_load(jpeg__context* s, int* x, int* y, int* comp, int req_comp, jpeg__result_info* ri) + { + unsigned char* result; + jpeg__jpeg* j = (jpeg__jpeg*)jpeg__malloc(sizeof(jpeg__jpeg)); + JPEG_NOTUSED(ri); + j->s = s; + jpeg__setup_jpeg(j); + result = load_jpeg_image(j, x, y, comp, req_comp); + JPEG_FREE(j); + return result; + } + + static int jpeg__jpeg_test(jpeg__context* s) + { + int r; + jpeg__jpeg* j = (jpeg__jpeg*)jpeg__malloc(sizeof(jpeg__jpeg)); + j->s = s; + jpeg__setup_jpeg(j); + r = jpeg__decode_jpeg_header(j, JPEG__SCAN_type); + jpeg__rewind(s); + JPEG_FREE(j); + return r; + } + + static int jpeg__jpeg_info_raw(jpeg__jpeg* j, int* x, int* y, int* comp) + { + if (!jpeg__decode_jpeg_header(j, JPEG__SCAN_header)) { + jpeg__rewind(j->s); + return 0; + } + if (x) *x = j->s->img_x; + if (y) *y = j->s->img_y; + if (comp) *comp = j->s->img_n >= 3 ? 3 : 1; + return 1; + } + + static int jpeg__jpeg_info(jpeg__context* s, int* x, int* y, int* comp) + { + int result; + jpeg__jpeg* j = (jpeg__jpeg*)(jpeg__malloc(sizeof(jpeg__jpeg))); + j->s = s; + result = jpeg__jpeg_info_raw(j, x, y, comp); + JPEG_FREE(j); + return result; + } + + //------------------------------------------------------------------------ + + static int jpeg__stdio_read(void* user, char* data, int size) + { + InputMemoryStream* stream = (InputMemoryStream*)user; + return (int)stream->Read(size, data); + } + + static void jpeg__stdio_skip(void* user, int n) + { + InputMemoryStream* stream = (InputMemoryStream*)user; + stream->Skip(n); + } + + static int jpeg__stdio_eof(void* user) + { + InputMemoryStream* stream = (InputMemoryStream*)user; + return stream->Pos() == stream->Size() ? 1 : 0; + } + + //--------------------------------------------------------------------- + + ImageJpegLoader::ImageJpegLoader(const ImageLoaderParam& param) + : Base::ImageJpegLoader(param) + { + if (_param.format == SimdPixelFormatNone) + _param.format = SimdPixelFormatRgb24; + } + + bool ImageJpegLoader::FromStream() + { + int x, y, comp; + jpeg__context s; + s.io.eof = jpeg__stdio_eof; + s.io.read = jpeg__stdio_read; + s.io.skip = jpeg__stdio_skip; + s.io_user_data = &_stream; + s.buflen = sizeof(s.buffer_start); + s.read_from_callbacks = 1; + s.callback_already_read = 0; + s.img_buffer = s.img_buffer_original = s.buffer_start; + jpeg__refill_buffer(&s); + s.img_buffer_original_end = s.img_buffer_end; + jpeg__result_info ri; + uint8_t * data = (uint8_t*)jpeg__jpeg_load(&s, &x, &y, &comp, 3, &ri); + if (data) + { + size_t stride = 3 * x; + _image.Recreate(x, y, (Image::Format)_param.format); + switch (_param.format) + { + case SimdPixelFormatGray8: + Sse41::RgbToGray(data, x, y, stride, _image.data, _image.stride); + break; + case SimdPixelFormatBgr24: + Sse41::BgrToRgb(data, x, y, stride, _image.data, _image.stride); + break; + case SimdPixelFormatBgra32: + Sse41::RgbToBgra(data, x, y, stride, _image.data, _image.stride, 0xFF); + break; + case SimdPixelFormatRgb24: + Base::Copy(data, stride, x, y, 3, _image.data, _image.stride); + break; + case SimdPixelFormatRgba32: + Sse41::BgrToBgra(data, x, y, stride, _image.data, _image.stride, 0xFF); + break; + default: + break; + } + JPEG_FREE(data); + return true; + } + return false; + } + } +#endif +} diff --git a/src/Test/TestImageIO.cpp b/src/Test/TestImageIO.cpp index 10937b0caf..85484adcb5 100644 --- a/src/Test/TestImageIO.cpp +++ b/src/Test/TestImageIO.cpp @@ -665,7 +665,7 @@ namespace Test std::vector formats = { View::Gray8, View::Bgr24, View::Bgra32, View::Rgb24, View::Rgba32 }; for (size_t format = 0; format < formats.size(); format++) { - for (int file = (int)SimdImageFilePng; file <= (int)SimdImageFilePng; file++) + for (int file = (int)SimdImageFileJpeg; file <= (int)SimdImageFileJpeg; file++) { if (file == SimdImageFileJpeg) { From 12b9586e0d1abca006587395258366be66eae07e Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Tue, 13 Jun 2023 17:08:01 +0300 Subject: [PATCH 08/44] +add Support of 5-bit depth in Base implementation of functions DescrIntEncode32f, DescrIntEncode16f, DescrIntDecode32f, DescrIntDecode16f, DescrIntCosineDistance, DescrIntCosineDistancesMxNp, DescrIntCosineDistancesMxNa. --- docs/2023.html | 8 ++ docs/help/group__descrint.html | 2 +- src/Simd/SimdAvx2DescrInt.cpp | 4 + src/Simd/SimdAvx512bwDescrInt.cpp | 4 + src/Simd/SimdBaseDescrInt.cpp | 144 +++++++++++++++++++++++++++- src/Simd/SimdBaseImageLoadJpeg.cpp | 47 +++++---- src/Simd/SimdDescrInt.h | 11 +-- src/Simd/SimdImageLoadJpeg.h | 4 + src/Simd/SimdLib.h | 2 +- src/Simd/SimdSse41DescrInt.cpp | 4 + src/Simd/SimdSse41ImageLoadJpeg.cpp | 31 ++---- src/Test/TestDescrInt.cpp | 19 ++-- 12 files changed, 214 insertions(+), 66 deletions(-) diff --git a/docs/2023.html b/docs/2023.html index d0f5c6d5b4..5611c588a0 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -39,6 +39,14 @@
New features
  • Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntEncode16f.
  • Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntDecode16f.
  • +
  • Support of 5-bit depth in Base implementation of function DescrIntEncode32f.
  • +
  • Support of 5-bit depth in Base implementation of function DescrIntEncode16f.
  • +
  • Support of 5-bit depth in Base implementation of function DescrIntDecode32f.
  • +
  • Support of 5-bit depth in Base implementation of function DescrIntDecode16f.
  • +
  • Support of 5-bit depth in Base implementation of function DescrIntCosineDistance.
  • +
  • Support of 5-bit depth in Base implementation of function DescrIntCosineDistancesMxNp.
  • +
  • Support of 5-bit depth in Base implementation of function DescrIntCosineDistancesMxNa.
  • +
Renaming
+
Bug fixing
+
    +
  • Compiler error in file SimdYuvToBgr.h.
Renaming
    diff --git a/src/Simd/SimdYuvToBgr.h b/src/Simd/SimdYuvToBgr.h index 4bdc5e541e..ed38351873 100644 --- a/src/Simd/SimdYuvToBgr.h +++ b/src/Simd/SimdYuvToBgr.h @@ -29,6 +29,7 @@ #include "Simd/SimdMath.h" #include "Simd/SimdUnpack.h" #include "Simd/SimdLog.h" +#include "Simd/SimdLoad.h" namespace Simd { From af40a6a02424af4d39e027913d869b0b46feebbd Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Wed, 14 Jun 2023 16:11:33 +0300 Subject: [PATCH 12/44] +add Support of 4-bit depth in SSE4.1 optimizations of functions DescrIntEncode32f, DescrIntEncode16f. --- src/Simd/SimdDescrIntCommon.h | 2 + src/Simd/SimdSse41DescrInt.cpp | 75 ++++++++++++++++++++++++++++++++++ 2 files changed, 77 insertions(+) diff --git a/src/Simd/SimdDescrIntCommon.h b/src/Simd/SimdDescrIntCommon.h index 8841d34c75..e75451c668 100644 --- a/src/Simd/SimdDescrIntCommon.h +++ b/src/Simd/SimdDescrIntCommon.h @@ -49,6 +49,8 @@ namespace Simd #ifdef SIMD_SSE41_ENABLE namespace Sse41 { + const __m128i E4_MULLO = SIMD_MM_SETR_EPI16(4096, 1, 4096, 1, 4096, 1, 4096, 1); + const __m128i E5_MULLO = SIMD_MM_SETR_EPI16(256, 32, 4, 128, 16, 2, 64, 8); const __m128i E5_SHFL0 = SIMD_MM_SETR_EPI8(0x1, 0x3, 0x7, 0x9, 0xD, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1); const __m128i E5_SHFL1 = SIMD_MM_SETR_EPI8(0x2, 0x4, 0x8, 0xA, 0xE, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1); diff --git a/src/Simd/SimdSse41DescrInt.cpp b/src/Simd/SimdSse41DescrInt.cpp index 31074b1810..8c74336402 100644 --- a/src/Simd/SimdSse41DescrInt.cpp +++ b/src/Simd/SimdSse41DescrInt.cpp @@ -86,6 +86,42 @@ namespace Simd return Encode32f(_mm_loadu_ps(src), scale, min, sum, sqsum); } + static SIMD_INLINE __m128i Encode32f4(const float* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + __m128i i0 = Encode32f(src + 0, scale, min, sum, sqsum); + __m128i i4 = Encode32f(src + 4, scale, min, sum, sqsum); + return _mm_srli_epi32(_mm_mullo_epi16(_mm_packus_epi32(i0, i4), E4_MULLO), 12); + } + + static SIMD_INLINE __m128i Encode32f4x8(const float* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + __m128i s0 = Encode32f4(src + 0 * 8, scale, min, sum, sqsum); + return _mm_packus_epi16(_mm_packus_epi32(s0, K_ZERO), K_ZERO); + } + + static SIMD_INLINE __m128i Encode32f4x16(const float* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + __m128i s0 = Encode32f4(src + 0 * 8, scale, min, sum, sqsum); + __m128i s1 = Encode32f4(src + 1 * 8, scale, min, sum, sqsum); + return _mm_packus_epi16(_mm_packus_epi32(s0, s1), K_ZERO); + } + + static void Encode32f4(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, size16 = AlignLo(size, 16); + __m128 _scale = _mm_set1_ps(scale); + __m128 _min = _mm_set1_ps(min); + __m128i _sum = _mm_setzero_si128(); + __m128i _sqsum = _mm_setzero_si128(); + for (; i < size16; i += 16, src += 16, dst += 8) + _mm_storel_epi64((__m128i*)dst, Encode32f4x16(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 4) + *(uint32_t*)(dst) = _mm_extract_epi32(Encode32f4x8(src, _scale, _min, _sum, _sqsum), 0); + sum = ExtractInt32Sum(_sum); + sqsum = ExtractInt32Sum(_sqsum); + } + static SIMD_INLINE __m128i Encode32f5(const float* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) { __m128i i0 = Encode32f(src + 0, scale, min, sum, sqsum); @@ -198,6 +234,43 @@ namespace Simd //------------------------------------------------------------------------------------------------- + static SIMD_INLINE __m128i Encode16f4(const uint16_t* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + __m128i u0 = _mm_loadu_si128((__m128i*)(src)); + __m128i i0 = Encode32f(Float16ToFloat32(UnpackU16<0>(u0)), scale, min, sum, sqsum); + __m128i i4 = Encode32f(Float16ToFloat32(UnpackU16<1>(u0)), scale, min, sum, sqsum); + return _mm_srli_epi32(_mm_mullo_epi16(_mm_packus_epi32(i0, i4), E4_MULLO), 12); + } + + static SIMD_INLINE __m128i Encode16f4x8(const uint16_t* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + __m128i s0 = Encode16f4(src + 0 * 8, scale, min, sum, sqsum); + return _mm_packus_epi16(_mm_packus_epi32(s0, K_ZERO), K_ZERO); + } + + static SIMD_INLINE __m128i Encode16f4x16(const uint16_t* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + __m128i s0 = Encode16f4(src + 0 * 8, scale, min, sum, sqsum); + __m128i s1 = Encode16f4(src + 1 * 8, scale, min, sum, sqsum); + return _mm_packus_epi16(_mm_packus_epi32(s0, s1), K_ZERO); + } + + static void Encode16f4(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, size16 = AlignLo(size, 16); + __m128 _scale = _mm_set1_ps(scale); + __m128 _min = _mm_set1_ps(min); + __m128i _sum = _mm_setzero_si128(); + __m128i _sqsum = _mm_setzero_si128(); + for (; i < size16; i += 16, src += 16, dst += 8) + _mm_storel_epi64((__m128i*)dst, Encode16f4x16(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 4) + *(uint32_t*)(dst) = _mm_extract_epi32(Encode16f4x8(src, _scale, _min, _sum, _sqsum), 0); + sum = ExtractInt32Sum(_sum); + sqsum = ExtractInt32Sum(_sqsum); + } + static SIMD_INLINE __m128i Encode16f5(const uint16_t* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) { __m128i u0 = _mm_loadu_si128((__m128i*)(src)); @@ -1159,6 +1232,8 @@ namespace Simd { case 4: { + _encode32f = Encode32f4; + _encode16f = Encode16f4; break; } case 5: From 4d6e34e33f1bb6f87aa36084fe980f709122ee22 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Thu, 15 Jun 2023 10:57:48 +0300 Subject: [PATCH 13/44] +add Support of 5-bit depth in SSE4.1 optimizations of functions DescrIntDecode32f, DescrIntDecode16f, DescrIntCosineDistance, DescrIntCosineDistancesMxNp, DescrIntCosineDistancesMxNa. --- src/Simd/SimdConst.h | 1 + src/Simd/SimdDescrIntCommon.h | 3 + src/Simd/SimdSse41DescrInt.cpp | 198 +++++++++++++++++++++++++++++++++ 3 files changed, 202 insertions(+) diff --git a/src/Simd/SimdConst.h b/src/Simd/SimdConst.h index c369c9d134..2db5a5d9c9 100644 --- a/src/Simd/SimdConst.h +++ b/src/Simd/SimdConst.h @@ -99,6 +99,7 @@ namespace Simd const __m128i K8_04 = SIMD_MM_SET1_EPI8(0x04); const __m128i K8_07 = SIMD_MM_SET1_EPI8(0x07); const __m128i K8_08 = SIMD_MM_SET1_EPI8(0x08); + const __m128i K8_0F = SIMD_MM_SET1_EPI8(0x0F); const __m128i K8_10 = SIMD_MM_SET1_EPI8(0x10); const __m128i K8_20 = SIMD_MM_SET1_EPI8(0x20); const __m128i K8_40 = SIMD_MM_SET1_EPI8(0x40); diff --git a/src/Simd/SimdDescrIntCommon.h b/src/Simd/SimdDescrIntCommon.h index e75451c668..fde104d86f 100644 --- a/src/Simd/SimdDescrIntCommon.h +++ b/src/Simd/SimdDescrIntCommon.h @@ -64,6 +64,9 @@ namespace Simd const __m128i E7_SHFL0 = SIMD_MM_SETR_EPI8(0x1, 0x3, 0x5, 0x7, 0x9, 0xB, 0xD, -1, -1, -1, -1, -1, -1, -1, -1, -1); const __m128i E7_SHFL1 = SIMD_MM_SETR_EPI8(0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, -1, -1, -1, -1, -1, -1, -1, -1, -1); + const __m128i C4_MULLO = SIMD_MM_SETR_EPI16(4096, 256, 4096, 256, 4096, 256, 4096, 256); + const __m128i C4_SHFL0 = SIMD_MM_SETR_EPI8(0x0, 0x0, 0x0, 0x0, 0x1, 0x1, 0x1, 0x1, 0x2, 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x3); + const __m128i C5_SHFL0 = SIMD_MM_SETR_EPI8(0x0, 0x0, 0x0, 0x1, 0x1, 0x1, 0x1, 0x2, 0x2, 0x3, 0x3, 0x3, 0x3, 0x4, 0x4, 0x4); const __m128i C5_SHFL1 = SIMD_MM_SETR_EPI8(0x5, 0x5, 0x5, 0x6, 0x6, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8, 0x8, 0x8, 0x9, 0x9, 0x9); const __m128i C5_MULLO = SIMD_MM_SETR_EPI16(8, 64, 2, 16, 128, 4, 32, 256); diff --git a/src/Simd/SimdSse41DescrInt.cpp b/src/Simd/SimdSse41DescrInt.cpp index 8c74336402..85e30f43df 100644 --- a/src/Simd/SimdSse41DescrInt.cpp +++ b/src/Simd/SimdSse41DescrInt.cpp @@ -389,6 +389,22 @@ namespace Simd //------------------------------------------------------------------------------------------------- + static void Decode32f4(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m128 _scale = _mm_set1_ps(scale); + __m128 _shift = _mm_set1_ps(shift); + for (size_t i = 0; i < size; i += 8) + { + __m128i s4 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s4, C4_SHFL0), C4_MULLO), 12); + _mm_storeu_ps(dst + 0, _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<0>(s16)), _scale), _shift)); + _mm_storeu_ps(dst + 4, _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<1>(s16)), _scale), _shift)); + src += 4; + dst += 8; + } + } + static void Decode32f5(const uint8_t* src, float scale, float shift, size_t size, float* dst) { assert(size % 8 == 0); @@ -452,6 +468,23 @@ namespace Simd //------------------------------------------------------------------------------------------------- + static void Decode16f4(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m128 _scale = _mm_set1_ps(scale); + __m128 _shift = _mm_set1_ps(shift); + for (size_t i = 0; i < size; i += 8) + { + __m128i s4 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s4, C4_SHFL0), C4_MULLO), 12); + __m128i d0 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<0>(s16)), _scale), _shift)); + __m128i d4 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<1>(s16)), _scale), _shift)); + _mm_storeu_si128((__m128i*)dst, _mm_packus_epi32(d0, d4)); + src += 4; + dst += 8; + } + } + static void Decode16f5(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) { assert(size % 8 == 0); @@ -524,6 +557,28 @@ namespace Simd template int32_t Correlation(const uint8_t* a, const uint8_t* b, size_t size); + template<> int32_t Correlation<4>(const uint8_t* a, const uint8_t* b, size_t size) + { + assert(size % 8 == 0); + __m128i ab32 = _mm_setzero_si128(); + size_t i = 0, size32 = AlignLo(size, 32); + for (; i < size32; i += 32, a += 16, b += 16) + { + __m128i _a = _mm_loadu_si128((__m128i*)a); + __m128i _b = _mm_loadu_si128((__m128i*)b); + __m128i ab16 = _mm_maddubs_epi16(_mm_and_si128(_a, K8_0F), _mm_and_si128(_b, K8_0F)); + ab16 = _mm_add_epi16(ab16, _mm_maddubs_epi16(_mm_and_si128(_mm_srli_epi16(_a, 4), K8_0F), _mm_and_si128(_mm_srli_epi16(_b, 4), K8_0F))); + ab32 = _mm_add_epi32(ab32, _mm_madd_epi16(ab16, K16_0001)); + } + for (; i < size; i += 8, a += 4, b += 4) + { + __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), C4_SHFL0), C4_MULLO), 12); + __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), C4_SHFL0), C4_MULLO), 12); + ab32 = _mm_add_epi32(_mm_madd_epi16(_a, _b), ab32); + } + return ExtractInt32Sum(ab32); + } + template<> int32_t Correlation<5>(const uint8_t* a, const uint8_t* b, size_t size) { assert(size % 8 == 0); @@ -629,6 +684,86 @@ namespace Simd template void MicroCosineDistances2x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + template<> void MicroCosineDistances2x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size32 = AlignLo(size, 32), o = 16; + __m128i a0, a1, b0; + __m128i ab00 = _mm_setzero_si128(); + __m128i ab01 = _mm_setzero_si128(); + __m128i ab02 = _mm_setzero_si128(); + __m128i ab03 = _mm_setzero_si128(); + __m128i ab10 = _mm_setzero_si128(); + __m128i ab11 = _mm_setzero_si128(); + __m128i ab12 = _mm_setzero_si128(); + __m128i ab13 = _mm_setzero_si128(); + for (; i < size32; i += 32, o += 16) + { + a0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(A[0] + o)), K8_0F); + a1 = _mm_and_si128(_mm_loadu_si128((__m128i*)(A[1] + o)), K8_0F); + + b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[0] + o)), K8_0F); + ab00 = _mm_add_epi32(ab00, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + ab10 = _mm_add_epi32(ab10, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[1] + o)), K8_0F); + ab01 = _mm_add_epi32(ab01, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + ab11 = _mm_add_epi32(ab11, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[2] + o)), K8_0F); + ab02 = _mm_add_epi32(ab02, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + ab12 = _mm_add_epi32(ab12, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[3] + o)), K8_0F); + ab03 = _mm_add_epi32(ab03, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + ab13 = _mm_add_epi32(ab13, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); + + a0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(A[0] + o)), 4), K8_0F); + a1 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(A[1] + o)), 4), K8_0F); + + b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[0] + o)), 4), K8_0F); + ab00 = _mm_add_epi32(ab00, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + ab10 = _mm_add_epi32(ab10, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[1] + o)), 4), K8_0F); + ab01 = _mm_add_epi32(ab01, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + ab11 = _mm_add_epi32(ab11, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[2] + o)), 4), K8_0F); + ab02 = _mm_add_epi32(ab02, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + ab12 = _mm_add_epi32(ab12, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[3] + o)), 4), K8_0F); + ab03 = _mm_add_epi32(ab03, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + ab13 = _mm_add_epi32(ab13, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); + } + for (; i < size; i += 8, o += 4) + { + a0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[0] + o)), C4_SHFL0), C4_MULLO), 12); + a1 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[1] + o)), C4_SHFL0), C4_MULLO), 12); + + b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[0] + o)), C4_SHFL0), C4_MULLO), 12); + ab00 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab00); + ab10 = _mm_add_epi32(_mm_madd_epi16(a1, b0), ab10); + + b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[1] + o)), C4_SHFL0), C4_MULLO), 12); + ab01 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab01); + ab11 = _mm_add_epi32(_mm_madd_epi16(a1, b0), ab11); + + b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[2] + o)), C4_SHFL0), C4_MULLO), 12); + ab02 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab02); + ab12 = _mm_add_epi32(_mm_madd_epi16(a1, b0), ab12); + + b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[3] + o)), C4_SHFL0), C4_MULLO), 12); + ab03 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab03); + ab13 = _mm_add_epi32(_mm_madd_epi16(a1, b0), ab13); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); + __m128 _size = _mm_set1_ps(float(size)); + DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + DecodeCosineDistances(A[1], B, ab1, _size, distances + 1 * stride); + } + template<> void MicroCosineDistances2x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; @@ -955,6 +1090,65 @@ namespace Simd template void MicroCosineDistances1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + template<> void MicroCosineDistances1x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size32 = AlignLo(size, 32), o = 16; + __m128i a0, a1, b0; + __m128i ab00 = _mm_setzero_si128(); + __m128i ab01 = _mm_setzero_si128(); + __m128i ab02 = _mm_setzero_si128(); + __m128i ab03 = _mm_setzero_si128(); + for (; i < size32; i += 32, o += 16) + { + a0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(A[0] + o)), K8_0F); + + b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[0] + o)), K8_0F); + ab00 = _mm_add_epi32(ab00, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[1] + o)), K8_0F); + ab01 = _mm_add_epi32(ab01, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[2] + o)), K8_0F); + ab02 = _mm_add_epi32(ab02, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[3] + o)), K8_0F); + ab03 = _mm_add_epi32(ab03, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + + a0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(A[0] + o)), 4), K8_0F); + + b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[0] + o)), 4), K8_0F); + ab00 = _mm_add_epi32(ab00, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[1] + o)), 4), K8_0F); + ab01 = _mm_add_epi32(ab01, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[2] + o)), 4), K8_0F); + ab02 = _mm_add_epi32(ab02, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[3] + o)), 4), K8_0F); + ab03 = _mm_add_epi32(ab03, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + } + for (; i < size; i += 8, o += 4) + { + a0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[0] + o)), C4_SHFL0), C4_MULLO), 12); + + b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[0] + o)), C4_SHFL0), C4_MULLO), 12); + ab00 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab00); + + b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[1] + o)), C4_SHFL0), C4_MULLO), 12); + ab01 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab01); + + b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[2] + o)), C4_SHFL0), C4_MULLO), 12); + ab02 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab02); + + b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[3] + o)), C4_SHFL0), C4_MULLO), 12); + ab03 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab03); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + __m128 _size = _mm_set1_ps(float(size)); + DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + } + template<> void MicroCosineDistances1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; @@ -1234,6 +1428,10 @@ namespace Simd { _encode32f = Encode32f4; _encode16f = Encode16f4; + _decode32f = Decode32f4; + _decode16f = Decode16f4; + _cosineDistance = Sse41::CosineDistance<4>; + _macroCosineDistances = Sse41::MacroCosineDistances<4>; break; } case 5: From 84675f939989bfab381072378eab0276dcaa05ed Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Thu, 15 Jun 2023 12:02:37 +0300 Subject: [PATCH 14/44] +add Support of 5-bit depth in AVX2 optimizations of functions DescrIntEncode32f, DescrIntEncode16f, DescrIntDecode32f, DescrIntDecode16f, DescrIntCosineDistance, DescrIntCosineDistancesMxNp, DescrIntCosineDistancesMxNa. --- docs/2023.html | 14 +- src/Simd/SimdAvx2DescrInt.cpp | 258 ++++++++++++++++++++++++++++++++++ src/Simd/SimdDescrIntCommon.h | 16 +++ 3 files changed, 281 insertions(+), 7 deletions(-) diff --git a/docs/2023.html b/docs/2023.html index dc9390421f..e1afb922c2 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -39,13 +39,13 @@
    New features
    • Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntEncode16f.
    • Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntDecode16f.
    • -
    • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1 optimizations of function DescrIntEncode32f.
    • -
    • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1 optimizations of function DescrIntEncode16f.
    • -
    • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1 optimizations of function DescrIntDecode32f.
    • -
    • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1 optimizations of function DescrIntDecode16f.
    • -
    • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1 optimizations of function DescrIntCosineDistance.
    • -
    • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1 optimizations of function DescrIntCosineDistancesMxNp.
    • -
    • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1 optimizations of function DescrIntCosineDistancesMxNa.
    • +
    • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2 optimizations of function DescrIntEncode32f.
    • +
    • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2 optimizations of function DescrIntEncode16f.
    • +
    • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2 optimizations of function DescrIntDecode32f.
    • +
    • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2 optimizations of function DescrIntDecode16f.
    • +
    • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2 optimizations of function DescrIntCosineDistance.
    • +
    • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2 optimizations of function DescrIntCosineDistancesMxNp.
    • +
    • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2 optimizations of function DescrIntCosineDistancesMxNa.
    Bug fixing
      diff --git a/src/Simd/SimdAvx2DescrInt.cpp b/src/Simd/SimdAvx2DescrInt.cpp index 303d5818c5..0c836dfd24 100644 --- a/src/Simd/SimdAvx2DescrInt.cpp +++ b/src/Simd/SimdAvx2DescrInt.cpp @@ -84,6 +84,44 @@ namespace Simd return Encode32f(_mm256_loadu_ps(src), scale, min, sum, sqsum); } + static SIMD_INLINE __m128i Encode32f5x1(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(src + 0, scale, min, sum, sqsum); + __m128i s0 = _mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E5_MULLO); + return _mm_or_si128(_mm_or_si128(_mm_shuffle_epi8(s0, Sse41::E5_SHFL0), _mm_shuffle_epi8(s0, Sse41::E5_SHFL1)), _mm_shuffle_epi8(s0, Sse41::E5_SHFL2)); + } + + static SIMD_INLINE __m128i Encode32f5x2(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(src + 0, scale, min, sum, sqsum); + __m256i i8 = Encode32f(src + 8, scale, min, sum, sqsum); + __m256i s0 = _mm256_mullo_epi16(PackU32ToI16(i0, i8), E5_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_or_si256(_mm256_shuffle_epi8(s0, E5_SHFL0), _mm256_shuffle_epi8(s0, E5_SHFL1)), _mm256_shuffle_epi8(s0, E5_SHFL2)); + return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); + } + + static void Encode32f5(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, main = size - 8, main16 = AlignLo(main, 16); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _min = _mm256_set1_ps(min); + __m256i _sum = _mm256_setzero_si256(); + __m256i _sqsum = _mm256_setzero_si256(); + for (; i < main16; i += 16, src += 16, dst += 10) + _mm_storeu_si128((__m128i*)dst, Encode32f5x2(src, _scale, _min, _sum, _sqsum)); + for (; i < main; i += 8, src += 8, dst += 5) + _mm_storel_epi64((__m128i*)dst, Encode32f5x1(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 5) + { + __m128i d0 = Encode32f5x1(src, _scale, _min, _sum, _sqsum); + *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); + *(uint8_t*)(dst + 4) = _mm_extract_epi8(d0, 4); + } + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + static SIMD_INLINE __m128i Encode32f6x1(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) { __m256i i0 = Encode32f(src + 0, scale, min, sum, sqsum); @@ -188,6 +226,44 @@ namespace Simd //------------------------------------------------------------------------------------------------- + static SIMD_INLINE __m128i Encode16f5x1(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src)), scale, min, sum, sqsum); + __m128i s0 = _mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E5_MULLO); + return _mm_or_si128(_mm_or_si128(_mm_shuffle_epi8(s0, Sse41::E5_SHFL0), _mm_shuffle_epi8(s0, Sse41::E5_SHFL1)), _mm_shuffle_epi8(s0, Sse41::E5_SHFL2)); + } + + static SIMD_INLINE __m128i Encode16f5x2(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 0)), scale, min, sum, sqsum); + __m256i i8 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 1)), scale, min, sum, sqsum); + __m256i s0 = _mm256_mullo_epi16(PackU32ToI16(i0, i8), E5_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_or_si256(_mm256_shuffle_epi8(s0, E5_SHFL0), _mm256_shuffle_epi8(s0, E5_SHFL1)), _mm256_shuffle_epi8(s0, E5_SHFL2)); + return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); + } + + static void Encode16f5(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, main = size - 8, main16 = AlignLo(main, 16); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _min = _mm256_set1_ps(min); + __m256i _sum = _mm256_setzero_si256(); + __m256i _sqsum = _mm256_setzero_si256(); + for (; i < main16; i += 16, src += 16, dst += 10) + _mm_storeu_si128((__m128i*)dst, Encode16f5x2(src, _scale, _min, _sum, _sqsum)); + for (; i < main; i += 8, src += 8, dst += 5) + _mm_storel_epi64((__m128i*)dst, Encode16f5x1(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 5) + { + __m128i d0 = Encode16f5x1(src, _scale, _min, _sum, _sqsum); + *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); + *(uint8_t*)(dst + 4) = _mm_extract_epi8(d0, 4); + } + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + static SIMD_INLINE __m128i Encode16f6x1(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) { __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src)), scale, min, sum, sqsum); @@ -292,6 +368,31 @@ namespace Simd //------------------------------------------------------------------------------------------------- + static void Decode32f5(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _shift = _mm256_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16) + { + __m256i s5 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s5, C5_SHFL), C5_MULLO), 11); + _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift)); + _mm256_storeu_ps(dst + 8, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift)); + src += 10; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s5 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s5, Sse41::C5_SHFL0), Sse41::C5_MULLO), 11); + _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift)); + src += 5; + dst += 8; + } + } + static void Decode32f6(const uint8_t* src, float scale, float shift, size_t size, float* dst) { assert(size % 8 == 0); @@ -363,6 +464,31 @@ namespace Simd //------------------------------------------------------------------------------------------------- + static void Decode16f5(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _shift = _mm256_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16) + { + __m256i s5 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s5, C5_SHFL), C5_MULLO), 11); + _mm_storeu_si128((__m128i*)dst + 0, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift), 0)); + _mm_storeu_si128((__m128i*)dst + 1, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift), 0)); + src += 10; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s5 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s5, Sse41::C5_SHFL0), Sse41::C5_MULLO), 11); + _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift), 0)); + src += 5; + dst += 8; + } + } + static void Decode16f6(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) { assert(size % 8 == 0); @@ -436,6 +562,26 @@ namespace Simd template int32_t Correlation(const uint8_t* a, const uint8_t* b, size_t size); + template<> int32_t Correlation<5>(const uint8_t* a, const uint8_t* b, size_t size) + { + assert(size % 8 == 0); + __m256i _ab = _mm256_setzero_si256(); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16, a += 10, b += 10) + { + __m256i _a = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)a)), C5_SHFL), C5_MULLO), 11); + __m256i _b = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)b)), C5_SHFL), C5_MULLO), 11); + _ab = _mm256_add_epi32(_mm256_madd_epi16(_a, _b), _ab); + } + for (; i < size; i += 8, a += 5, b += 5) + { + __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), Sse41::C5_SHFL0), Sse41::C5_MULLO), 11); + __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), Sse41::C5_SHFL0), Sse41::C5_MULLO), 11); + _ab = _mm256_add_epi32(_mm256_madd_epi16(_mm256_castsi128_si256(_a), _mm256_castsi128_si256(_b)), _ab); + } + return ExtractSum(_ab); + } + template<> int32_t Correlation<6>(const uint8_t* a, const uint8_t* b, size_t size) { assert(size % 8 == 0); @@ -505,6 +651,67 @@ namespace Simd template void MicroCosineDistances2x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + template<> void MicroCosineDistances2x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size16 = AlignLo(size, 16), o = 16; + __m256i a0, a1, b0; + __m256i ab00 = _mm256_setzero_si256(); + __m256i ab01 = _mm256_setzero_si256(); + __m256i ab02 = _mm256_setzero_si256(); + __m256i ab03 = _mm256_setzero_si256(); + __m256i ab10 = _mm256_setzero_si256(); + __m256i ab11 = _mm256_setzero_si256(); + __m256i ab12 = _mm256_setzero_si256(); + __m256i ab13 = _mm256_setzero_si256(); + for (; i < size16; i += 16, o += 10) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[0] + o))), C5_SHFL), C5_MULLO), 11); + a1 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[1] + o))), C5_SHFL), C5_MULLO), 11); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[0] + o))), C5_SHFL), C5_MULLO), 11); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[1] + o))), C5_SHFL), C5_MULLO), 11); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[2] + o))), C5_SHFL), C5_MULLO), 11); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[3] + o))), C5_SHFL), C5_MULLO), 11); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); + } + for (; i < size; i += 8, o += 5) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C5_SHFL), C5_MULLO), 11); + a1 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[1] + o))), C5_SHFL), C5_MULLO), 11); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C5_SHFL), C5_MULLO), 11); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C5_SHFL), C5_MULLO), 11); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C5_SHFL), C5_MULLO), 11); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C5_SHFL), C5_MULLO), 11); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); + __m128 _size = _mm_set1_ps(float(size)); + Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + Sse41::DecodeCosineDistances(A[1], B, ab1, _size, distances + 1 * stride); + } + template<> void MicroCosineDistances2x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; @@ -690,6 +897,51 @@ namespace Simd template void MicroCosineDistances1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + template<> void MicroCosineDistances1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size16 = AlignLo(size, 16), o = 16; + __m256i a0, b0; + __m256i ab00 = _mm256_setzero_si256(); + __m256i ab01 = _mm256_setzero_si256(); + __m256i ab02 = _mm256_setzero_si256(); + __m256i ab03 = _mm256_setzero_si256(); + for (; i < size16; i += 16, o += 10) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[0] + o))), C5_SHFL), C5_MULLO), 11); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[0] + o))), C5_SHFL), C5_MULLO), 11); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[1] + o))), C5_SHFL), C5_MULLO), 11); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[2] + o))), C5_SHFL), C5_MULLO), 11); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[3] + o))), C5_SHFL), C5_MULLO), 11); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + } + for (; i < size; i += 8, o += 5) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C5_SHFL), C5_MULLO), 11); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C5_SHFL), C5_MULLO), 11); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C5_SHFL), C5_MULLO), 11); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C5_SHFL), C5_MULLO), 11); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C5_SHFL), C5_MULLO), 11); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + __m128 _size = _mm_set1_ps(float(size)); + Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + } + template<> void MicroCosineDistances1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; @@ -868,6 +1120,12 @@ namespace Simd } case 5: { + _encode32f = Encode32f5; + _encode16f = Encode16f5; + _decode32f = Decode32f5; + _decode16f = Decode16f5; + _cosineDistance = Avx2::CosineDistance<5>; + _macroCosineDistances = Avx2::MacroCosineDistances<5>; break; } case 6: diff --git a/src/Simd/SimdDescrIntCommon.h b/src/Simd/SimdDescrIntCommon.h index fde104d86f..bdb2b09294 100644 --- a/src/Simd/SimdDescrIntCommon.h +++ b/src/Simd/SimdDescrIntCommon.h @@ -114,6 +114,17 @@ namespace Simd #ifdef SIMD_AVX2_ENABLE namespace Avx2 { + const __m256i E5_MULLO = SIMD_MM256_SETR_EPI16(256, 32, 4, 128, 16, 2, 64, 8, 256, 32, 4, 128, 16, 2, 64, 8); + const __m256i E5_SHFL0 = SIMD_MM256_SETR_EPI8( + 0x1, 0x3, 0x7, 0x9, 0xD, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, 0x1, 0x3, 0x7, 0x9, 0xD, -1, -1, -1, -1, -1, -1); + const __m256i E5_SHFL1 = SIMD_MM256_SETR_EPI8( + 0x2, 0x4, 0x8, 0xA, 0xE, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, 0x2, 0x4, 0x8, 0xA, 0xE, -1, -1, -1, -1, -1, -1); + const __m256i E5_SHFL2 = SIMD_MM256_SETR_EPI8( + -1, 0x6, -1, 0xC, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, 0x6, -1, 0xC, -1, -1, -1, -1, -1, -1, -1); + const __m256i E6_MULLO = SIMD_MM256_SETR_EPI16(256, 64, 16, 4, 256, 64, 16, 4, 256, 64, 16, 4, 256, 64, 16, 4); const __m256i E6_SHFL0 = SIMD_MM256_SETR_EPI8( 0x1, 0x3, 0x5, 0x9, 0xB, 0xD, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, @@ -130,6 +141,11 @@ namespace Simd 0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, -1, -1); + const __m256i C5_SHFL = SIMD_MM256_SETR_EPI8( + 0x0, 0x0, 0x0, 0x1, 0x1, 0x1, 0x1, 0x2, 0x2, 0x3, 0x3, 0x3, 0x3, 0x4, 0x4, 0x4, + 0x5, 0x5, 0x5, 0x6, 0x6, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8, 0x8, 0x8, 0x9, 0x9, 0x9); + const __m256i C5_MULLO = SIMD_MM256_SETR_EPI16(8, 64, 2, 16, 128, 4, 32, 256, 8, 64, 2, 16, 128, 4, 32, 256); + const __m256i C6_SHFL = SIMD_MM256_SETR_EPI8( 0x0, 0x0, 0x0, 0x1, 0x1, 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x5, 0x6, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8, 0x8, 0x9, 0x9, 0x9, 0xA, 0xA, 0xB, 0xB, 0xB); From 1a4c5c0524646cd49e4cc47edae0ecc70889a025 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Thu, 15 Jun 2023 17:29:57 +0300 Subject: [PATCH 15/44] +add Support of 4-bit depth in AVX2 optimizations of functions DescrIntEncode32f, DescrIntEncode16f. --- src/Simd/SimdAvx2DescrInt.cpp | 72 +++++++++++++++++++++++++++++++++++ src/Simd/SimdDescrIntCommon.h | 2 + 2 files changed, 74 insertions(+) diff --git a/src/Simd/SimdAvx2DescrInt.cpp b/src/Simd/SimdAvx2DescrInt.cpp index 0c836dfd24..5c3bcf9cfe 100644 --- a/src/Simd/SimdAvx2DescrInt.cpp +++ b/src/Simd/SimdAvx2DescrInt.cpp @@ -84,6 +84,41 @@ namespace Simd return Encode32f(_mm256_loadu_ps(src), scale, min, sum, sqsum); } + static SIMD_INLINE __m128i Encode32f4x8(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(src + 0 * 8, scale, min, sum, sqsum); + __m128i s0 = _mm_srli_epi32(_mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E4_MULLO), 12); + return _mm_packus_epi16(_mm_packus_epi32(s0, Sse41::K_ZERO), Sse41::K_ZERO); + } + + static SIMD_INLINE __m128i Encode32f4x32(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(src + 0 * 8, scale, min, sum, sqsum); + __m256i i1 = Encode32f(src + 1 * 8, scale, min, sum, sqsum); + __m256i s0 = _mm256_srli_epi32(_mm256_mullo_epi16(PackU32ToI16(i0, i1), E4_MULLO), 12); + __m256i i2 = Encode32f(src + 2 * 8, scale, min, sum, sqsum); + __m256i i3 = Encode32f(src + 3 * 8, scale, min, sum, sqsum); + __m256i s1 = _mm256_srli_epi32(_mm256_mullo_epi16(PackU32ToI16(i2, i3), E4_MULLO), 12); + return _mm_packus_epi16(_mm_packus_epi32(_mm256_castsi256_si128(s0), _mm256_extracti128_si256(s0, 1)), + _mm_packus_epi32(_mm256_castsi256_si128(s1), _mm256_extracti128_si256(s1, 1))); + } + + static void Encode32f4(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, size32 = AlignLo(size, 32); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _min = _mm256_set1_ps(min); + __m256i _sum = _mm256_setzero_si256(); + __m256i _sqsum = _mm256_setzero_si256(); + for (; i < size32; i += 32, src += 32, dst += 16) + _mm_storeu_si128((__m128i*)dst, Encode32f4x32(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 4) + *(uint32_t*)(dst) = _mm_extract_epi32(Encode32f4x8(src, _scale, _min, _sum, _sqsum), 0); + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + static SIMD_INLINE __m128i Encode32f5x1(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) { __m256i i0 = Encode32f(src + 0, scale, min, sum, sqsum); @@ -226,6 +261,41 @@ namespace Simd //------------------------------------------------------------------------------------------------- + static SIMD_INLINE __m128i Encode16f4x8(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src)), scale, min, sum, sqsum); + __m128i s0 = _mm_srli_epi32(_mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E4_MULLO), 12); + return _mm_packus_epi16(_mm_packus_epi32(s0, Sse41::K_ZERO), Sse41::K_ZERO); + } + + static SIMD_INLINE __m128i Encode16f4x32(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 0)), scale, min, sum, sqsum); + __m256i i1 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 1)), scale, min, sum, sqsum); + __m256i s0 = _mm256_srli_epi32(_mm256_mullo_epi16(PackU32ToI16(i0, i1), E4_MULLO), 12); + __m256i i2 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 2)), scale, min, sum, sqsum); + __m256i i3 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 3)), scale, min, sum, sqsum); + __m256i s1 = _mm256_srli_epi32(_mm256_mullo_epi16(PackU32ToI16(i2, i3), E4_MULLO), 12); + return _mm_packus_epi16(_mm_packus_epi32(_mm256_castsi256_si128(s0), _mm256_extracti128_si256(s0, 1)), + _mm_packus_epi32(_mm256_castsi256_si128(s1), _mm256_extracti128_si256(s1, 1))); + } + + static void Encode16f4(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, size32 = AlignLo(size, 32); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _min = _mm256_set1_ps(min); + __m256i _sum = _mm256_setzero_si256(); + __m256i _sqsum = _mm256_setzero_si256(); + for (; i < size32; i += 32, src += 32, dst += 16) + _mm_storeu_si128((__m128i*)dst, Encode16f4x32(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 4) + *(uint32_t*)(dst) = _mm_extract_epi32(Encode16f4x8(src, _scale, _min, _sum, _sqsum), 0); + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + static SIMD_INLINE __m128i Encode16f5x1(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) { __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src)), scale, min, sum, sqsum); @@ -1116,6 +1186,8 @@ namespace Simd { case 4: { + _encode32f = Encode32f4; + _encode16f = Encode16f4; break; } case 5: diff --git a/src/Simd/SimdDescrIntCommon.h b/src/Simd/SimdDescrIntCommon.h index bdb2b09294..92b04e3209 100644 --- a/src/Simd/SimdDescrIntCommon.h +++ b/src/Simd/SimdDescrIntCommon.h @@ -114,6 +114,8 @@ namespace Simd #ifdef SIMD_AVX2_ENABLE namespace Avx2 { + const __m256i E4_MULLO = SIMD_MM256_SETR_EPI16(4096, 1, 4096, 1, 4096, 1, 4096, 1, 4096, 1, 4096, 1, 4096, 1, 4096, 1); + const __m256i E5_MULLO = SIMD_MM256_SETR_EPI16(256, 32, 4, 128, 16, 2, 64, 8, 256, 32, 4, 128, 16, 2, 64, 8); const __m256i E5_SHFL0 = SIMD_MM256_SETR_EPI8( 0x1, 0x3, 0x7, 0x9, 0xD, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, From 20ab6d2f288b52c137c01f109092f272b48a1c48 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Thu, 15 Jun 2023 18:32:42 +0300 Subject: [PATCH 16/44] +add Support of 4-bit depth in AVX2 optimizations of functions DescrIntDecode32f, DescrIntDecode16f, DescrIntCosineDistance, DescrIntCosineDistancesMxNp, DescrIntCosineDistancesMxNa. --- src/Simd/SimdAvx2DescrInt.cpp | 215 ++++++++++++++++++++++++++++++++++ src/Simd/SimdConst.h | 1 + src/Simd/SimdDescrIntCommon.h | 5 + 3 files changed, 221 insertions(+) diff --git a/src/Simd/SimdAvx2DescrInt.cpp b/src/Simd/SimdAvx2DescrInt.cpp index 5c3bcf9cfe..26eaf62f2f 100644 --- a/src/Simd/SimdAvx2DescrInt.cpp +++ b/src/Simd/SimdAvx2DescrInt.cpp @@ -438,6 +438,31 @@ namespace Simd //------------------------------------------------------------------------------------------------- + static void Decode32f4(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _shift = _mm256_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16) + { + __m256i s4 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s4, C4_SHFL), C4_MULLO), 12); + _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift)); + _mm256_storeu_ps(dst + 8, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift)); + src += 8; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s4 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s4, Sse41::C4_SHFL0), Sse41::C4_MULLO), 12); + _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift)); + src += 4; + dst += 8; + } + } + static void Decode32f5(const uint8_t* src, float scale, float shift, size_t size, float* dst) { assert(size % 8 == 0); @@ -534,6 +559,31 @@ namespace Simd //------------------------------------------------------------------------------------------------- + static void Decode16f4(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _shift = _mm256_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16) + { + __m256i s4 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s4, C4_SHFL), C4_MULLO), 12); + _mm_storeu_si128((__m128i*)dst + 0, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift), 0)); + _mm_storeu_si128((__m128i*)dst + 1, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift), 0)); + src += 8; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s4 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s4, Sse41::C4_SHFL0), Sse41::C4_MULLO), 12); + _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift), 0)); + src += 4; + dst += 8; + } + } + static void Decode16f5(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) { assert(size % 8 == 0); @@ -632,6 +682,28 @@ namespace Simd template int32_t Correlation(const uint8_t* a, const uint8_t* b, size_t size); + template<> int32_t Correlation<4>(const uint8_t* a, const uint8_t* b, size_t size) + { + assert(size % 8 == 0); + __m256i ab32 = _mm256_setzero_si256(); + size_t i = 0, size64 = AlignLo(size, 64); + for (; i < size64; i += 64, a += 32, b += 32) + { + __m256i _a = _mm256_loadu_si256((__m256i*)a); + __m256i _b = _mm256_loadu_si256((__m256i*)b); + __m256i ab16 = _mm256_maddubs_epi16(_mm256_and_si256(_a, K8_0F), _mm256_and_si256(_b, K8_0F)); + ab16 = _mm256_add_epi16(ab16, _mm256_maddubs_epi16(_mm256_and_si256(_mm256_srli_epi16(_a, 4), K8_0F), _mm256_and_si256(_mm256_srli_epi16(_b, 4), K8_0F))); + ab32 = _mm256_add_epi32(ab32, _mm256_madd_epi16(ab16, K16_0001)); + } + for (; i < size; i += 8, a += 4, b += 4) + { + __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), Sse41::C4_SHFL0), Sse41::C4_MULLO), 12); + __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), Sse41::C4_SHFL0), Sse41::C4_MULLO), 12); + ab32 = _mm256_add_epi32(_mm256_madd_epi16(_mm256_castsi128_si256(_a), _mm256_castsi128_si256(_b)), ab32); + } + return ExtractSum(ab32); + } + template<> int32_t Correlation<5>(const uint8_t* a, const uint8_t* b, size_t size) { assert(size % 8 == 0); @@ -721,6 +793,86 @@ namespace Simd template void MicroCosineDistances2x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + template<> void MicroCosineDistances2x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size64 = AlignLo(size, 64), o = 16; + __m256i a0, a1, b0; + __m256i ab00 = _mm256_setzero_si256(); + __m256i ab01 = _mm256_setzero_si256(); + __m256i ab02 = _mm256_setzero_si256(); + __m256i ab03 = _mm256_setzero_si256(); + __m256i ab10 = _mm256_setzero_si256(); + __m256i ab11 = _mm256_setzero_si256(); + __m256i ab12 = _mm256_setzero_si256(); + __m256i ab13 = _mm256_setzero_si256(); + for (; i < size64; i += 64, o += 32) + { + a0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(A[0] + o)), K8_0F); + a1 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(A[1] + o)), K8_0F); + + b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[0] + o)), K8_0F); + ab00 = _mm256_add_epi32(ab00, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + ab10 = _mm256_add_epi32(ab10, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[1] + o)), K8_0F); + ab01 = _mm256_add_epi32(ab01, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + ab11 = _mm256_add_epi32(ab11, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[2] + o)), K8_0F); + ab02 = _mm256_add_epi32(ab02, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + ab12 = _mm256_add_epi32(ab12, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[3] + o)), K8_0F); + ab03 = _mm256_add_epi32(ab03, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + ab13 = _mm256_add_epi32(ab13, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); + + a0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(A[0] + o)), 4), K8_0F); + a1 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(A[1] + o)), 4), K8_0F); + + b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[0] + o)), 4), K8_0F); + ab00 = _mm256_add_epi32(ab00, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + ab10 = _mm256_add_epi32(ab10, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[1] + o)), 4), K8_0F); + ab01 = _mm256_add_epi32(ab01, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + ab11 = _mm256_add_epi32(ab11, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[2] + o)), 4), K8_0F); + ab02 = _mm256_add_epi32(ab02, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + ab12 = _mm256_add_epi32(ab12, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[3] + o)), 4), K8_0F); + ab03 = _mm256_add_epi32(ab03, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + ab13 = _mm256_add_epi32(ab13, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); + } + for (; i < size; i += 8, o += 4) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C4_SHFL), C4_MULLO), 12); + a1 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[1] + o))), C4_SHFL), C4_MULLO), 12); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C4_SHFL), C4_MULLO), 12); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C4_SHFL), C4_MULLO), 12); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C4_SHFL), C4_MULLO), 12); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C4_SHFL), C4_MULLO), 12); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); + __m128 _size = _mm_set1_ps(float(size)); + Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + Sse41::DecodeCosineDistances(A[1], B, ab1, _size, distances + 1 * stride); + } + template<> void MicroCosineDistances2x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; @@ -967,6 +1119,65 @@ namespace Simd template void MicroCosineDistances1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + template<> void MicroCosineDistances1x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size64 = AlignLo(size, 64), o = 16; + __m256i a0, b0; + __m256i ab00 = _mm256_setzero_si256(); + __m256i ab01 = _mm256_setzero_si256(); + __m256i ab02 = _mm256_setzero_si256(); + __m256i ab03 = _mm256_setzero_si256(); + for (; i < size64; i += 64, o += 32) + { + a0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(A[0] + o)), K8_0F); + + b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[0] + o)), K8_0F); + ab00 = _mm256_add_epi32(ab00, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[1] + o)), K8_0F); + ab01 = _mm256_add_epi32(ab01, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[2] + o)), K8_0F); + ab02 = _mm256_add_epi32(ab02, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[3] + o)), K8_0F); + ab03 = _mm256_add_epi32(ab03, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + + a0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(A[0] + o)), 4), K8_0F); + + b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[0] + o)), 4), K8_0F); + ab00 = _mm256_add_epi32(ab00, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[1] + o)), 4), K8_0F); + ab01 = _mm256_add_epi32(ab01, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[2] + o)), 4), K8_0F); + ab02 = _mm256_add_epi32(ab02, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[3] + o)), 4), K8_0F); + ab03 = _mm256_add_epi32(ab03, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + } + for (; i < size; i += 8, o += 4) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C4_SHFL), C4_MULLO), 12); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C4_SHFL), C4_MULLO), 12); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C4_SHFL), C4_MULLO), 12); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C4_SHFL), C4_MULLO), 12); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C4_SHFL), C4_MULLO), 12); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + __m128 _size = _mm_set1_ps(float(size)); + Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + } + template<> void MicroCosineDistances1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; @@ -1188,6 +1399,10 @@ namespace Simd { _encode32f = Encode32f4; _encode16f = Encode16f4; + _decode32f = Decode32f4; + _decode16f = Decode16f4; + _cosineDistance = Avx2::CosineDistance<4>; + _macroCosineDistances = Avx2::MacroCosineDistances<4>; break; } case 5: diff --git a/src/Simd/SimdConst.h b/src/Simd/SimdConst.h index 2db5a5d9c9..360dbb32f7 100644 --- a/src/Simd/SimdConst.h +++ b/src/Simd/SimdConst.h @@ -216,6 +216,7 @@ namespace Simd const __m256i K8_04 = SIMD_MM256_SET1_EPI8(0x04); const __m256i K8_07 = SIMD_MM256_SET1_EPI8(0x07); const __m256i K8_08 = SIMD_MM256_SET1_EPI8(0x08); + const __m256i K8_0F = SIMD_MM256_SET1_EPI8(0x0F); const __m256i K8_10 = SIMD_MM256_SET1_EPI8(0x10); const __m256i K8_20 = SIMD_MM256_SET1_EPI8(0x20); const __m256i K8_40 = SIMD_MM256_SET1_EPI8(0x40); diff --git a/src/Simd/SimdDescrIntCommon.h b/src/Simd/SimdDescrIntCommon.h index 92b04e3209..cef4ef1574 100644 --- a/src/Simd/SimdDescrIntCommon.h +++ b/src/Simd/SimdDescrIntCommon.h @@ -143,6 +143,11 @@ namespace Simd 0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, -1, -1); + const __m256i C4_SHFL = SIMD_MM256_SETR_EPI8( + 0x0, 0x0, 0x0, 0x0, 0x1, 0x1, 0x1, 0x1, 0x2, 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x3, + 0x4, 0x4, 0x4, 0x4, 0x5, 0x5, 0x5, 0x5, 0x6, 0x6, 0x6, 0x6, 0x7, 0x7, 0x7, 0x7); + const __m256i C4_MULLO = SIMD_MM256_SETR_EPI16(4096, 256, 4096, 256, 4096, 256, 4096, 256, 4096, 256, 4096, 256, 4096, 256, 4096, 256); + const __m256i C5_SHFL = SIMD_MM256_SETR_EPI8( 0x0, 0x0, 0x0, 0x1, 0x1, 0x1, 0x1, 0x2, 0x2, 0x3, 0x3, 0x3, 0x3, 0x4, 0x4, 0x4, 0x5, 0x5, 0x5, 0x6, 0x6, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8, 0x8, 0x8, 0x9, 0x9, 0x9); From b323ec9a698d464e89a2a23495934669e9355786 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Thu, 15 Jun 2023 19:55:46 +0300 Subject: [PATCH 17/44] +add Avx2::DecodeCosineDistances. --- src/Simd/SimdAvx2DescrInt.cpp | 52 ++++++++++--------------------- src/Simd/SimdAvx512bwDescrInt.cpp | 38 ++++++++++------------ src/Simd/SimdBaseDescrInt.cpp | 2 +- src/Simd/SimdDescrIntCommon.h | 39 +++++++++++++++++++++-- src/Simd/SimdExtract.h | 9 ++++++ src/Simd/SimdSse41DescrInt.cpp | 42 ++++++++++--------------- 6 files changed, 95 insertions(+), 87 deletions(-) diff --git a/src/Simd/SimdAvx2DescrInt.cpp b/src/Simd/SimdAvx2DescrInt.cpp index 26eaf62f2f..83dfe38ab0 100644 --- a/src/Simd/SimdAvx2DescrInt.cpp +++ b/src/Simd/SimdAvx2DescrInt.cpp @@ -786,7 +786,7 @@ namespace Simd template void CosineDistance(const uint8_t* a, const uint8_t* b, size_t size, float* distance) { float abSum = (float)Correlation(a + 16, b + 16, size); - Base::DecodeCosineDistance(a, b, abSum, (float)size, distance); + Base::DecodeCosineDistance(a, b, abSum, distance); } //------------------------------------------------------------------------------------------------- @@ -866,11 +866,8 @@ namespace Simd ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - __m128 _size = _mm_set1_ps(float(size)); - Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); - Sse41::DecodeCosineDistances(A[1], B, ab1, _size, distances + 1 * stride); + __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); + DecodeCosineDistances(A, B, ab, distances, stride); } template<> void MicroCosineDistances2x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -927,11 +924,8 @@ namespace Simd ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - __m128 _size = _mm_set1_ps(float(size)); - Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); - Sse41::DecodeCosineDistances(A[1], B, ab1, _size, distances + 1 * stride); + __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); + DecodeCosineDistances(A, B, ab, distances, stride); } template<> void MicroCosineDistances2x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -988,11 +982,8 @@ namespace Simd ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - __m128 _size = _mm_set1_ps(float(size)); - Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); - Sse41::DecodeCosineDistances(A[1], B, ab1, _size, distances + 1 * stride); + __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); + DecodeCosineDistances(A, B, ab, distances, stride); } template<> void MicroCosineDistances2x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1049,11 +1040,8 @@ namespace Simd ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - __m128 _size = _mm_set1_ps(float(size)); - Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); - Sse41::DecodeCosineDistances(A[1], B, ab1, _size, distances + 1 * stride); + __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); + DecodeCosineDistances(A, B, ab, distances, stride); } template<> void MicroCosineDistances2x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1110,11 +1098,8 @@ namespace Simd ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - __m128 _size = _mm_set1_ps(float(size)); - Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); - Sse41::DecodeCosineDistances(A[1], B, ab1, _size, distances + 1 * stride); + __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); + DecodeCosineDistances(A, B, ab, distances, stride); } template void MicroCosineDistances1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); @@ -1174,8 +1159,7 @@ namespace Simd ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 _size = _mm_set1_ps(float(size)); - Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1219,8 +1203,7 @@ namespace Simd ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 _size = _mm_set1_ps(float(size)); - Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1264,8 +1247,7 @@ namespace Simd ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 _size = _mm_set1_ps(float(size)); - Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1309,8 +1291,7 @@ namespace Simd ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 _size = _mm_set1_ps(float(size)); - Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1354,8 +1335,7 @@ namespace Simd ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 _size = _mm_set1_ps(float(size)); - Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); } template void MacroCosineDistances(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) diff --git a/src/Simd/SimdAvx512bwDescrInt.cpp b/src/Simd/SimdAvx512bwDescrInt.cpp index a4c222a75c..dc2f4d1884 100644 --- a/src/Simd/SimdAvx512bwDescrInt.cpp +++ b/src/Simd/SimdAvx512bwDescrInt.cpp @@ -535,7 +535,7 @@ namespace Simd template void CosineDistance(const uint8_t* a, const uint8_t* b, size_t size, float* distance) { float abSum = (float)Correlation(a + 16, b + 16, size); - Base::DecodeCosineDistance(a, b, abSum, (float)size, distance); + Base::DecodeCosineDistance(a, b, abSum, distance); } //------------------------------------------------------------------------------------------------- @@ -629,11 +629,10 @@ namespace Simd __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); - __m128 _size = _mm_set1_ps(float(size)); - Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); - Sse41::DecodeCosineDistances(A[1], B, ab1, _size, distances + 1 * stride); - Sse41::DecodeCosineDistances(A[2], B, ab2, _size, distances + 2 * stride); - Sse41::DecodeCosineDistances(A[3], B, ab3, _size, distances + 3 * stride); + Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances(A[1], B, ab1, distances + 1 * stride); + Sse41::DecodeCosineDistances(A[2], B, ab2, distances + 2 * stride); + Sse41::DecodeCosineDistances(A[3], B, ab3, distances + 3 * stride); } template<> void MicroCosineDistances4x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -723,11 +722,10 @@ namespace Simd __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); - __m128 _size = _mm_set1_ps(float(size)); - Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); - Sse41::DecodeCosineDistances(A[1], B, ab1, _size, distances + 1 * stride); - Sse41::DecodeCosineDistances(A[2], B, ab2, _size, distances + 2 * stride); - Sse41::DecodeCosineDistances(A[3], B, ab3, _size, distances + 3 * stride); + Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances(A[1], B, ab1, distances + 1 * stride); + Sse41::DecodeCosineDistances(A[2], B, ab2, distances + 2 * stride); + Sse41::DecodeCosineDistances(A[3], B, ab3, distances + 3 * stride); } template<> void MicroCosineDistances4x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -817,11 +815,10 @@ namespace Simd __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); - __m128 _size = _mm_set1_ps(float(size)); - Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); - Sse41::DecodeCosineDistances(A[1], B, ab1, _size, distances + 1 * stride); - Sse41::DecodeCosineDistances(A[2], B, ab2, _size, distances + 2 * stride); - Sse41::DecodeCosineDistances(A[3], B, ab3, _size, distances + 3 * stride); + Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances(A[1], B, ab1, distances + 1 * stride); + Sse41::DecodeCosineDistances(A[2], B, ab2, distances + 2 * stride); + Sse41::DecodeCosineDistances(A[3], B, ab3, distances + 3 * stride); } template void MicroCosineDistances1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); @@ -868,8 +865,7 @@ namespace Simd ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 _size = _mm_set1_ps(float(size)); - Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -914,8 +910,7 @@ namespace Simd ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 _size = _mm_set1_ps(float(size)); - Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -960,8 +955,7 @@ namespace Simd ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 _size = _mm_set1_ps(float(size)); - Sse41::DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); } template void MacroCosineDistances(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) diff --git a/src/Simd/SimdBaseDescrInt.cpp b/src/Simd/SimdBaseDescrInt.cpp index 343813c5df..b643a35bb9 100644 --- a/src/Simd/SimdBaseDescrInt.cpp +++ b/src/Simd/SimdBaseDescrInt.cpp @@ -508,7 +508,7 @@ namespace Simd template void CosineDistance(const uint8_t* a, const uint8_t* b, size_t size, float* distance) { float abSum = (float)Correlation(a + 16, b + 16, size); - Base::DecodeCosineDistance(a, b, abSum, float(size), distance); + Base::DecodeCosineDistance(a, b, abSum, distance); } //------------------------------------------------------------------------------------------------- diff --git a/src/Simd/SimdDescrIntCommon.h b/src/Simd/SimdDescrIntCommon.h index cef4ef1574..f81c42d4d2 100644 --- a/src/Simd/SimdDescrIntCommon.h +++ b/src/Simd/SimdDescrIntCommon.h @@ -26,12 +26,15 @@ #include "Simd/SimdDefs.h" #include "Simd/SimdMath.h" +#include "Simd/SimdLoad.h" +#include "Simd/SimdSet.h" +#include "Simd/SimdStore.h" namespace Simd { namespace Base { - SIMD_INLINE void DecodeCosineDistance(const uint8_t* a, const uint8_t* b, float abSum, float size, float* distance) + SIMD_INLINE void DecodeCosineDistance(const uint8_t* a, const uint8_t* b, float abSum, float* distance) { float aScale = ((float*)a)[0]; float aShift = ((float*)a)[1]; @@ -81,7 +84,7 @@ namespace Simd //------------------------------------------------------------------------------------------------- - SIMD_INLINE void DecodeCosineDistances(const uint8_t* a, const uint8_t* const* B, __m128 abSum, __m128 size, float* distances) + SIMD_INLINE void DecodeCosineDistances(const uint8_t* a, const uint8_t* const* B, __m128 abSum, float* distances) { __m128 aScale, aShift, aMean, aNorm, bScale, bShift, bMean, bNorm; bScale = _mm_loadu_ps((float*)B[0]); @@ -162,6 +165,38 @@ namespace Simd 0x0, 0x0, 0x0, 0x1, 0x1, 0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x6, 0x6, 0x6, 0x7, 0x7, 0x7, 0x8, 0x8, 0x9, 0x9, 0xA, 0xA, 0xB, 0xB, 0xC, 0xC, 0xD, 0xD, 0xD); const __m256i C7_MULLO = SIMD_MM256_SETR_EPI16(2, 4, 8, 16, 32, 64, 128, 256, 2, 4, 8, 16, 32, 64, 128, 256); + + //------------------------------------------------------------------------------------------------- + + SIMD_INLINE void DecodeCosineDistances(const uint8_t* const* A, const uint8_t* const* B, __m256 abSum, float* distances, size_t stride) + { + __m256 aScale, aShift, aMean, aNorm, bScale, bShift, bMean, bNorm; + bScale = _mm256_broadcast_ps((__m128*)B[0]); + bShift = _mm256_broadcast_ps((__m128*)B[1]); + bMean = _mm256_broadcast_ps((__m128*)B[2]); + bNorm = _mm256_broadcast_ps((__m128*)B[3]); + aScale = _mm256_unpacklo_ps(bScale, bMean); + aShift = _mm256_unpacklo_ps(bShift, bNorm); + aMean = _mm256_unpackhi_ps(bScale, bMean); + aNorm = _mm256_unpackhi_ps(bShift, bNorm); + bScale = _mm256_unpacklo_ps(aScale, aShift); + bShift = _mm256_unpackhi_ps(aScale, aShift); + bMean = _mm256_unpacklo_ps(aMean, aNorm); + bNorm = _mm256_unpackhi_ps(aMean, aNorm); + + aNorm = Avx::Load((float*)A[0], (float*)A[1]); + aScale = Broadcast<0>(aNorm); + aShift = Broadcast<1>(aNorm); + aMean = Broadcast<2>(aNorm); + aNorm = Broadcast<3>(aNorm); + + __m256 ab = _mm256_mul_ps(abSum, _mm256_mul_ps(aScale, bScale)); + ab = _mm256_fmadd_ps(aMean, bShift, ab); + ab = _mm256_fmadd_ps(bMean, aShift, ab); + + Avx::Store(distances + 0 * stride, distances + 1 * stride, + _mm256_min_ps(_mm256_max_ps(_mm256_sub_ps(_mm256_set1_ps(1.0f), _mm256_div_ps(ab, _mm256_mul_ps(aNorm, bNorm))), _mm256_setzero_ps()), _mm256_set1_ps(2.0f))); + } } #endif diff --git a/src/Simd/SimdExtract.h b/src/Simd/SimdExtract.h index 3909e2b005..5cbdc60f86 100644 --- a/src/Simd/SimdExtract.h +++ b/src/Simd/SimdExtract.h @@ -175,6 +175,15 @@ namespace Simd __m256i b = _mm256_hadd_epi32(_mm256_hadd_epi32(a0, a1), _mm256_hadd_epi32(a2, a3)); return _mm_add_epi32(_mm256_castsi256_si128(b), _mm256_extracti128_si256(b, 1)); } + + SIMD_INLINE __m256i Extract8Sums( + const __m256i& a0, const __m256i& a1, const __m256i& a2, const __m256i& a3, + const __m256i& a4, const __m256i& a5, const __m256i& a6, const __m256i& a7) + { + __m256i b0 = _mm256_hadd_epi32(_mm256_hadd_epi32(a0, a1), _mm256_hadd_epi32(a2, a3)); + __m256i b1 = _mm256_hadd_epi32(_mm256_hadd_epi32(a4, a5), _mm256_hadd_epi32(a6, a7)); + return _mm256_add_epi32(_mm256_permute2x128_si256(b0, b1, 0x20), _mm256_permute2x128_si256(b0, b1, 0x31)); + } } #endif// SIMD_AVX2_ENABLE diff --git a/src/Simd/SimdSse41DescrInt.cpp b/src/Simd/SimdSse41DescrInt.cpp index 85e30f43df..9d74a3ebc9 100644 --- a/src/Simd/SimdSse41DescrInt.cpp +++ b/src/Simd/SimdSse41DescrInt.cpp @@ -677,7 +677,7 @@ namespace Simd template void CosineDistance(const uint8_t* a, const uint8_t* b, size_t size, float* distance) { float abSum = (float)Correlation(a + 16, b + 16, size); - Base::DecodeCosineDistance(a, b, abSum, (float)size, distance); + Base::DecodeCosineDistance(a, b, abSum, distance); } //------------------------------------------------------------------------------------------------- @@ -759,9 +759,8 @@ namespace Simd } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - __m128 _size = _mm_set1_ps(float(size)); - DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); - DecodeCosineDistances(A[1], B, ab1, _size, distances + 1 * stride); + DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances(A[1], B, ab1, distances + 1 * stride); } template<> void MicroCosineDistances2x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -840,9 +839,8 @@ namespace Simd } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - __m128 _size = _mm_set1_ps(float(size)); - DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); - DecodeCosineDistances(A[1], B, ab1, _size, distances + 1 * stride); + DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances(A[1], B, ab1, distances + 1 * stride); } template<> void MicroCosineDistances2x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -921,9 +919,8 @@ namespace Simd } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - __m128 _size = _mm_set1_ps(float(size)); - DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); - DecodeCosineDistances(A[1], B, ab1, _size, distances + 1 * stride); + DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances(A[1], B, ab1, distances + 1 * stride); } template<> void MicroCosineDistances2x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1002,9 +999,8 @@ namespace Simd } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - __m128 _size = _mm_set1_ps(float(size)); - DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); - DecodeCosineDistances(A[1], B, ab1, _size, distances + 1 * stride); + DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances(A[1], B, ab1, distances + 1 * stride); } template<> void MicroCosineDistances2x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1083,9 +1079,8 @@ namespace Simd } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - __m128 _size = _mm_set1_ps(float(size)); - DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); - DecodeCosineDistances(A[1], B, ab1, _size, distances + 1 * stride); + DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances(A[1], B, ab1, distances + 1 * stride); } template void MicroCosineDistances1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); @@ -1145,8 +1140,7 @@ namespace Simd ab03 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 _size = _mm_set1_ps(float(size)); - DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1204,8 +1198,7 @@ namespace Simd ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 _size = _mm_set1_ps(float(size)); - DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1263,8 +1256,7 @@ namespace Simd ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 _size = _mm_set1_ps(float(size)); - DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1322,8 +1314,7 @@ namespace Simd ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 _size = _mm_set1_ps(float(size)); - DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1381,8 +1372,7 @@ namespace Simd ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 _size = _mm_set1_ps(float(size)); - DecodeCosineDistances(A[0], B, ab0, _size, distances + 0 * stride); + DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); } template void MacroCosineDistances(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) From fa4cbd0ed6c409e8a6493baa4b3adf2052f01ef8 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Fri, 16 Jun 2023 21:40:48 +0300 Subject: [PATCH 18/44] +add Support of 5-bit depth in AVX-512BW optimizations of functions DescrIntEncode32f, DescrIntEncode16f. --- src/Simd/SimdAvx512bwDescrInt.cpp | 72 +++++++++++++++++++++++++++++++ src/Simd/SimdDescrIntCommon.h | 19 ++++++++ 2 files changed, 91 insertions(+) diff --git a/src/Simd/SimdAvx512bwDescrInt.cpp b/src/Simd/SimdAvx512bwDescrInt.cpp index dc2f4d1884..98f02aa2d7 100644 --- a/src/Simd/SimdAvx512bwDescrInt.cpp +++ b/src/Simd/SimdAvx512bwDescrInt.cpp @@ -96,6 +96,41 @@ namespace Simd return Encode32f(_mm512_maskz_loadu_ps(mask, src), scale, min, sum, sqsum); } + static SIMD_INLINE __m128i Encode32f5x2(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) + { + __m512i i0 = Encode32f(src, scale, min, sum, sqsum, mask); + __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), Avx2::E5_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, Avx2::E5_SHFL0), _mm256_shuffle_epi8(s0, Avx2::E5_SHFL1)); + return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); + } + + static SIMD_INLINE __m256i Encode32f5x4(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) + { + __m512i i0 = Encode32f(src + 0 * F, scale, min, sum, sqsum); + __m512i i1 = Encode32f(src + 1 * F, scale, min, sum, sqsum); + __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(EX_PERM, _mm512_packus_epi32(i0, i1)), E5_MULLO); + __m512i e0 = _mm512_or_si512(_mm512_or_si512(_mm512_shuffle_epi8(s0, E5_SHFL0), _mm512_shuffle_epi8(s0, E5_SHFL1)), _mm512_shuffle_epi8(s0, E5_SHFL2)); + return _mm256_or_si256(_mm512_castsi512_si256(e0), _mm512_extracti32x8_epi32(e0, 1)); + } + + static void Encode32f5(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t size16 = AlignLo(size, 16), size32 = AlignLo(size, 32), i = 0; + __m512 _scale = _mm512_set1_ps(scale); + __m512 _min = _mm512_set1_ps(min); + __m512i _sum = _mm512_setzero_si512(); + __m512i _sqsum = _mm512_setzero_si512(); + for (; i < size32; i += 32, src += 32, dst += 20) + _mm256_mask_storeu_epi8(dst - 6, 0x03FFFFC0, Encode32f5x4(src, _scale, _min, _sum, _sqsum)); + for (; i < size16; i += 16, src += 16, dst += 10) + _mm_mask_storeu_epi8(dst, 0x03FF, Encode32f5x2(src, _scale, _min, _sum, _sqsum)); + if (i < size) + _mm_mask_storeu_epi8(dst, 0x001F, Encode32f5x2(src, _scale, _min, _sum, _sqsum, 0x00FF)); + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + static SIMD_INLINE __m128i Encode32f6x2(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) { __m512i i0 = Encode32f(src, scale, min, sum, sqsum, mask); @@ -203,6 +238,41 @@ namespace Simd return Encode32f(_mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, src)), scale, min, sum, sqsum); } + static SIMD_INLINE __m128i Encode16f5x2(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) + { + __m512i i0 = Encode16f(src, scale, min, sum, sqsum, mask); + __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), Avx2::E5_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, Avx2::E5_SHFL0), _mm256_shuffle_epi8(s0, Avx2::E5_SHFL1)); + return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); + } + + static SIMD_INLINE __m256i Encode16f5x4(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) + { + __m512i i0 = Encode16f(src + 0 * F, scale, min, sum, sqsum); + __m512i i1 = Encode16f(src + 1 * F, scale, min, sum, sqsum); + __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(EX_PERM, _mm512_packus_epi32(i0, i1)), E5_MULLO); + __m512i e0 = _mm512_or_si512(_mm512_or_si512(_mm512_shuffle_epi8(s0, E5_SHFL0), _mm512_shuffle_epi8(s0, E5_SHFL1)), _mm512_shuffle_epi8(s0, E5_SHFL2)); + return _mm256_or_si256(_mm512_castsi512_si256(e0), _mm512_extracti32x8_epi32(e0, 1)); + } + + static void Encode16f5(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t size16 = AlignLo(size, 16), size32 = AlignLo(size, 32), i = 0; + __m512 _scale = _mm512_set1_ps(scale); + __m512 _min = _mm512_set1_ps(min); + __m512i _sum = _mm512_setzero_si512(); + __m512i _sqsum = _mm512_setzero_si512(); + for (; i < size32; i += 32, src += 32, dst += 20) + _mm256_mask_storeu_epi8(dst - 6, 0x03FFFFC0, Encode16f5x4(src, _scale, _min, _sum, _sqsum)); + for (; i < size16; i += 16, src += 16, dst += 10) + _mm_mask_storeu_epi8(dst, 0x03FF, Encode16f5x2(src, _scale, _min, _sum, _sqsum)); + if (i < size) + _mm_mask_storeu_epi8(dst, 0x001F, Encode16f5x2(src, _scale, _min, _sum, _sqsum, 0x00FF)); + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + static SIMD_INLINE __m128i Encode16f6x2(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) { __m512i i0 = Encode16f(src, scale, min, sum, sqsum, mask); @@ -1003,6 +1073,8 @@ namespace Simd } case 5: { + _encode32f = Encode32f5; + _encode16f = Encode16f5; break; } case 6: diff --git a/src/Simd/SimdDescrIntCommon.h b/src/Simd/SimdDescrIntCommon.h index f81c42d4d2..ea2458b7a0 100644 --- a/src/Simd/SimdDescrIntCommon.h +++ b/src/Simd/SimdDescrIntCommon.h @@ -205,6 +205,25 @@ namespace Simd { const __m512i EX_PERM = SIMD_MM512_SETR_EPI64(0, 2, 1, 3, 4, 6, 5, 7); + const __m512i E5_MULLO = SIMD_MM512_SETR_EPI16( + 256, 32, 4, 128, 16, 2, 64, 8, 256, 32, 4, 128, 16, 2, 64, 8, + 256, 32, 4, 128, 16, 2, 64, 8, 256, 32, 4, 128, 16, 2, 64, 8); + const __m512i E5_SHFL0 = SIMD_MM512_SETR_EPI8( + -1, -1, -1, -1, -1, -1, 0x1, 0x3, 0x7, 0x9, 0xD, -1, -1, -1, -1, -1, + 0x1, 0x3, 0x7, 0x9, 0xD, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x1, 0x3, 0x7, 0x9, 0xD, + -1, -1, -1, -1, -1, 0x1, 0x3, 0x7, 0x9, 0xD, -1, -1, -1, -1, -1, -1); + const __m512i E5_SHFL1 = SIMD_MM512_SETR_EPI8( + -1, -1, -1, -1, -1, -1, 0x2, 0x4, 0x8, 0xA, 0xE, -1, -1, -1, -1, -1, + 0x2, 0x4, 0x8, 0xA, 0xE, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x2, 0x4, 0x8, 0xA, 0xE, + -1, -1, -1, -1, -1, 0x2, 0x4, 0x8, 0xA, 0xE, -1, -1, -1, -1, -1, -1); + const __m512i E5_SHFL2 = SIMD_MM512_SETR_EPI8( + -1, -1, -1, -1, -1, -1, -1, 0x6, -1, 0xC, -1, -1, -1, -1, -1, -1, + -1, 0x6, -1, 0xC, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x6, -1, 0xC, -1, + -1, -1, -1, -1, -1, -1, 0x6, -1, 0xC, -1, -1, -1, -1, -1, -1, -1); + const __m512i E6_MULLO = SIMD_MM512_SETR_EPI16( 256, 64, 16, 4, 256, 64, 16, 4, 256, 64, 16, 4, 256, 64, 16, 4, 256, 64, 16, 4, 256, 64, 16, 4, 256, 64, 16, 4, 256, 64, 16, 4); From 046cb7d7917c28bbe88292da0ec7cf8d9575dad9 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Wed, 21 Jun 2023 12:22:03 +0300 Subject: [PATCH 19/44] +add Support of 5-bit depth in AVX-512BW optimizations of functions DescrIntDecode32f, DescrIntDecode16f, DescrIntCosineDistance, DescrIntCosineDistancesMxNp, DescrIntCosineDistancesMxNa. --- docs/2023.html | 4 +- src/Simd/SimdAvx512bwDescrInt.cpp | 217 ++++++++++++++++++++++++++++++ src/Simd/SimdDescrIntCommon.h | 11 ++ 3 files changed, 230 insertions(+), 2 deletions(-) diff --git a/docs/2023.html b/docs/2023.html index e1afb922c2..a12607e53e 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -39,8 +39,8 @@
      New features
      • Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntEncode16f.
      • Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntDecode16f.
      • -
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2 optimizations of function DescrIntEncode32f.
      • -
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2 optimizations of function DescrIntEncode16f.
      • +
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntEncode32f.
      • +
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntEncode16f.
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2 optimizations of function DescrIntDecode32f.
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2 optimizations of function DescrIntDecode16f.
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2 optimizations of function DescrIntCosineDistance.
      • diff --git a/src/Simd/SimdAvx512bwDescrInt.cpp b/src/Simd/SimdAvx512bwDescrInt.cpp index 98f02aa2d7..94561b4738 100644 --- a/src/Simd/SimdAvx512bwDescrInt.cpp +++ b/src/Simd/SimdAvx512bwDescrInt.cpp @@ -375,6 +375,30 @@ namespace Simd //------------------------------------------------------------------------------------------------- + static void Decode32f5(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _shift = _mm512_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); + for (; i < size16; i += 16) + { + __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, Avx2::C5_SHFL), Avx2::C5_MULLO), 11); + _mm512_storeu_ps(dst + 0, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift)); + src += 10; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s5 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s5, Sse41::C5_SHFL0), Sse41::C5_MULLO), 11); + _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift))); + src += 5; + dst += 8; + } + } + static void Decode32f6(const uint8_t* src, float scale, float shift, size_t size, float* dst) { assert(size % 8 == 0); @@ -451,6 +475,31 @@ namespace Simd //------------------------------------------------------------------------------------------------- + + static void Decode16f5(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _shift = _mm512_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); + for (; i < size16; i += 16) + { + __m256i s5 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s5, Avx2::C5_SHFL), Avx2::C5_MULLO), 11); + _mm256_storeu_si256((__m256i*)dst, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift), 0)); + src += 10; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s5 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s5, Sse41::C5_SHFL0), Sse41::C5_MULLO), 11); + _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift)), 0)); + src += 5; + dst += 8; + } + } + static void Decode16f6(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) { assert(size % 8 == 0); @@ -529,6 +578,32 @@ namespace Simd template int32_t Correlation(const uint8_t* a, const uint8_t* b, size_t size); + SIMD_INLINE __m512i Load5(const uint8_t* ptr, __mmask32 mask = 0x000FFFFF) + { + return _mm512_srli_epi16(_mm512_mullo_epi16(_mm512_shuffle_epi8(_mm512_permutexvar_epi32(C5_PERM, _mm512_castsi256_si512(_mm256_maskz_loadu_epi8(mask, ptr))), C5_SHFL), C5_MULLO), 11); + } + + template<> int32_t Correlation<5>(const uint8_t* a, const uint8_t* b, size_t size) + { + assert(size % 8 == 0); + __m512i _ab = _mm512_setzero_si512(); + size_t i = 0, size32 = AlignLo(size, 32); + for (; i < size32; i += 32, a += 20, b += 20) + { + __m512i _a = Load5(a); + __m512i _b = Load5(b); + _ab = _mm512_add_epi32(_mm512_madd_epi16(_a, _b), _ab); + } + if (i < size) + { + __mmask32 mask = TailMask32((size - i) / 8 * 5); + __m512i _a = Load5(a, mask); + __m512i _b = Load5(b, mask); + _ab = _mm512_add_epi32(_mm512_madd_epi16(_a, _b), _ab); + } + return ExtractSum(_ab); + } + SIMD_INLINE __m512i Load6(const uint8_t* ptr, __mmask32 mask = 0x00FFFFFF) { return _mm512_srli_epi16(_mm512_mullo_epi16(_mm512_shuffle_epi8(_mm512_permutexvar_epi32(C6_PERM, _mm512_castsi256_si512(_mm256_maskz_loadu_epi8(mask, ptr))), C6_SHFL), C6_MULLO), 10); @@ -612,6 +687,99 @@ namespace Simd template void MicroCosineDistances4x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + template<> void MicroCosineDistances4x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size32 = AlignLo(size, 32), o = 16; + __m512i a0, a1, a2, a3, b0; + __m512i ab00 = _mm512_setzero_si512(); + __m512i ab01 = _mm512_setzero_si512(); + __m512i ab02 = _mm512_setzero_si512(); + __m512i ab03 = _mm512_setzero_si512(); + __m512i ab10 = _mm512_setzero_si512(); + __m512i ab11 = _mm512_setzero_si512(); + __m512i ab12 = _mm512_setzero_si512(); + __m512i ab13 = _mm512_setzero_si512(); + __m512i ab20 = _mm512_setzero_si512(); + __m512i ab21 = _mm512_setzero_si512(); + __m512i ab22 = _mm512_setzero_si512(); + __m512i ab23 = _mm512_setzero_si512(); + __m512i ab30 = _mm512_setzero_si512(); + __m512i ab31 = _mm512_setzero_si512(); + __m512i ab32 = _mm512_setzero_si512(); + __m512i ab33 = _mm512_setzero_si512(); + for (; i < size32; i += 32, o += 20) + { + a0 = Load5(A[0] + o); + a1 = Load5(A[1] + o); + a2 = Load5(A[2] + o); + a3 = Load5(A[3] + o); + + b0 = Load5(B[0] + o); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + ab10 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab10); + ab20 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab20); + ab30 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab30); + + b0 = Load5(B[1] + o); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + ab11 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab11); + ab21 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab21); + ab31 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab31); + + b0 = Load5(B[2] + o); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + ab12 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab12); + ab22 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab22); + ab32 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab32); + + b0 = Load5(B[3] + o); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + ab13 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab13); + ab23 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab23); + ab33 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab33); + } + if (i < size) + { + __mmask32 mask = TailMask32((size - i) / 8 * 5); + a0 = Load5(A[0] + o, mask); + a1 = Load5(A[1] + o, mask); + a2 = Load5(A[2] + o, mask); + a3 = Load5(A[3] + o, mask); + + b0 = Load5(B[0] + o, mask); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + ab10 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab10); + ab20 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab20); + ab30 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab30); + + b0 = Load5(B[1] + o, mask); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + ab11 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab11); + ab21 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab21); + ab31 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab31); + + b0 = Load5(B[2] + o, mask); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + ab12 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab12); + ab22 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab22); + ab32 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab32); + + b0 = Load5(B[3] + o, mask); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + ab13 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab13); + ab23 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab23); + ab33 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab33); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); + __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); + __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); + Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances(A[1], B, ab1, distances + 1 * stride); + Sse41::DecodeCosineDistances(A[2], B, ab2, distances + 2 * stride); + Sse41::DecodeCosineDistances(A[3], B, ab3, distances + 3 * stride); + } + template<> void MicroCosineDistances4x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size32 = AlignLo(size, 32), o = 16; @@ -893,6 +1061,51 @@ namespace Simd template void MicroCosineDistances1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + template<> void MicroCosineDistances1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size32 = AlignLo(size, 32), o = 16; + __m512i a0, b0; + __m512i ab00 = _mm512_setzero_si512(); + __m512i ab01 = _mm512_setzero_si512(); + __m512i ab02 = _mm512_setzero_si512(); + __m512i ab03 = _mm512_setzero_si512(); + for (; i < size32; i += 32, o += 20) + { + a0 = Load5(A[0] + o); + + b0 = Load5(B[0] + o); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + + b0 = Load5(B[1] + o); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + + b0 = Load5(B[2] + o); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + + b0 = Load5(B[3] + o); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + } + if (i < size) + { + __mmask32 mask = TailMask32((size - i) / 8 * 5); + a0 = Load5(A[0] + o, mask); + + b0 = Load5(B[0] + o, mask); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + + b0 = Load5(B[1] + o, mask); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + + b0 = Load5(B[2] + o, mask); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + + b0 = Load5(B[3] + o, mask); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + } + template<> void MicroCosineDistances1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size32 = AlignLo(size, 32), o = 16; @@ -1075,6 +1288,10 @@ namespace Simd { _encode32f = Encode32f5; _encode16f = Encode16f5; + _decode32f = Decode32f5; + _decode16f = Decode16f5; + _cosineDistance = Avx512bw::CosineDistance<5>; + _macroCosineDistances = Avx512bw::MacroCosineDistances<5>; break; } case 6: diff --git a/src/Simd/SimdDescrIntCommon.h b/src/Simd/SimdDescrIntCommon.h index ea2458b7a0..5d7a021815 100644 --- a/src/Simd/SimdDescrIntCommon.h +++ b/src/Simd/SimdDescrIntCommon.h @@ -252,6 +252,17 @@ namespace Simd -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, -1, -1, -1, -1, -1, -1, -1, 0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, -1, -1); + const __m512i C5_PERM = SIMD_MM512_SETR_EPI32( + 0x0, 0x1, 0x0, 0x0, 0x1, 0x2, 0x0, 0x0, 0x2, 0x3, 0x0, 0x0, 0x3, 0x4, 0x0, 0x0); + const __m512i C5_SHFL = SIMD_MM512_SETR_EPI8( + 0x0, 0x0, 0x0, 0x1, 0x1, 0x1, 0x1, 0x2, 0x2, 0x3, 0x3, 0x3, 0x3, 0x4, 0x4, 0x4, + 0x1, 0x1, 0x1, 0x2, 0x2, 0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x4, 0x4, 0x5, 0x5, 0x5, + 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x5, 0x5, 0x6, 0x6, 0x6, + 0x3, 0x3, 0x3, 0x4, 0x4, 0x4, 0x4, 0x5, 0x5, 0x6, 0x6, 0x6, 0x6, 0x7, 0x7, 0x7); + const __m512i C5_MULLO = SIMD_MM512_SETR_EPI16( + 8, 64, 2, 16, 128, 4, 32, 256, 8, 64, 2, 16, 128, 4, 32, 256, + 8, 64, 2, 16, 128, 4, 32, 256, 8, 64, 2, 16, 128, 4, 32, 256); + const __m512i C6_PERM = SIMD_MM512_SETR_EPI32( 0x0, 0x1, 0x0, 0x0, 0x1, 0x2, 0x0, 0x0, 0x3, 0x4, 0x0, 0x0, 0x4, 0x5, 0x0, 0x0); const __m512i C6_SHFL = SIMD_MM512_SETR_EPI8( From 195b06ad8d3e3520df684262dfaa17893ccc5371 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Wed, 21 Jun 2023 16:20:56 +0300 Subject: [PATCH 20/44] +add Support of 4-bit depth in AVX-512BW optimizations of functions DescrIntEncode32f, DescrIntEncode16f. --- docs/2023.html | 10 ++-- src/Simd/SimdAvx512bwDescrInt.cpp | 86 +++++++++++++++++++++++++++++++ src/Simd/SimdDescrIntCommon.h | 4 ++ 3 files changed, 95 insertions(+), 5 deletions(-) diff --git a/docs/2023.html b/docs/2023.html index a12607e53e..1bc1aa41a9 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -41,11 +41,11 @@
        New features
      • Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntDecode16f.
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntEncode32f.
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntEncode16f.
      • -
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2 optimizations of function DescrIntDecode32f.
      • -
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2 optimizations of function DescrIntDecode16f.
      • -
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2 optimizations of function DescrIntCosineDistance.
      • -
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2 optimizations of function DescrIntCosineDistancesMxNp.
      • -
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2 optimizations of function DescrIntCosineDistancesMxNa.
      • +
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntDecode32f.
      • +
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntDecode16f.
      • +
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistance.
      • +
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistancesMxNp.
      • +
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistancesMxNa.
      Bug fixing
        diff --git a/src/Simd/SimdAvx512bwDescrInt.cpp b/src/Simd/SimdAvx512bwDescrInt.cpp index 94561b4738..bacc9ca15e 100644 --- a/src/Simd/SimdAvx512bwDescrInt.cpp +++ b/src/Simd/SimdAvx512bwDescrInt.cpp @@ -96,6 +96,48 @@ namespace Simd return Encode32f(_mm512_maskz_loadu_ps(mask, src), scale, min, sum, sqsum); } + static SIMD_INLINE __m128i Encode32f4x4(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 m0, __mmask16 m1) + { + __m512i i0 = Encode32f(src + 0 * F, scale, min, sum, sqsum, m0); + __m512i i1 = Encode32f(src + 1 * F, scale, min, sum, sqsum, m1); + __m512i s0 = _mm512_srli_epi32(_mm512_mullo_epi16(PackU32ToI16(i0, i1), E4_MULLO), 12); + return _mm256_castsi256_si128(Avx2::PackI16ToU8(_mm512_cvtepi32_epi16(s0), Avx2::K_ZERO)); + } + + static SIMD_INLINE __m256i Encode32f4x8(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) + { + __m512i i0 = Encode32f(src + 0 * F, scale, min, sum, sqsum); + __m512i i1 = Encode32f(src + 1 * F, scale, min, sum, sqsum); + __m512i i2 = Encode32f(src + 2 * F, scale, min, sum, sqsum); + __m512i i3 = Encode32f(src + 3 * F, scale, min, sum, sqsum); + __m512i s0 = _mm512_srli_epi32(_mm512_mullo_epi16(PackU32ToI16(i0, i1), E4_MULLO), 12); + __m512i s1 = _mm512_srli_epi32(_mm512_mullo_epi16(PackU32ToI16(i2, i3), E4_MULLO), 12); + return Avx2::PackI16ToU8(_mm512_cvtepi32_epi16(s0), _mm512_cvtepi32_epi16(s1)); + } + + static void Encode32f4(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, size32 = AlignLo(size, 32), size64 = AlignLo(size, 64); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _min = _mm512_set1_ps(min); + __m512i _sum = _mm512_setzero_si512(); + __m512i _sqsum = _mm512_setzero_si512(); + for (; i < size64; i += 64, src += 64, dst += 32) + _mm256_storeu_si256((__m256i*)dst, Encode32f4x8(src, _scale, _min, _sum, _sqsum)); + for (; i < size32; i += 32, src += 32, dst += 16) + _mm_mask_storeu_epi8(dst, -1, Encode32f4x4(src, _scale, _min, _sum, _sqsum, -1, -1)); + if (i < size) + { + __mmask16 ms0 = TailMask16(size - size32 - 0 * F); + __mmask16 ms1 = TailMask16(size - size32 - 1 * F); + __mmask16 md= TailMask16((size - size32) / 2); + _mm_mask_storeu_epi8(dst, md, Encode32f4x4(src, _scale, _min, _sum, _sqsum, ms0, ms1)); + } + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + static SIMD_INLINE __m128i Encode32f5x2(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) { __m512i i0 = Encode32f(src, scale, min, sum, sqsum, mask); @@ -238,6 +280,48 @@ namespace Simd return Encode32f(_mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, src)), scale, min, sum, sqsum); } + static SIMD_INLINE __m128i Encode16f4x4(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 m0, __mmask16 m1) + { + __m512i i0 = Encode16f(src + 0 * F, scale, min, sum, sqsum, m0); + __m512i i1 = Encode16f(src + 1 * F, scale, min, sum, sqsum, m1); + __m512i s0 = _mm512_srli_epi32(_mm512_mullo_epi16(PackU32ToI16(i0, i1), E4_MULLO), 12); + return _mm256_castsi256_si128(Avx2::PackI16ToU8(_mm512_cvtepi32_epi16(s0), Avx2::K_ZERO)); + } + + static SIMD_INLINE __m256i Encode16f4x8(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) + { + __m512i i0 = Encode16f(src + 0 * F, scale, min, sum, sqsum); + __m512i i1 = Encode16f(src + 1 * F, scale, min, sum, sqsum); + __m512i i2 = Encode16f(src + 2 * F, scale, min, sum, sqsum); + __m512i i3 = Encode16f(src + 3 * F, scale, min, sum, sqsum); + __m512i s0 = _mm512_srli_epi32(_mm512_mullo_epi16(PackU32ToI16(i0, i1), E4_MULLO), 12); + __m512i s1 = _mm512_srli_epi32(_mm512_mullo_epi16(PackU32ToI16(i2, i3), E4_MULLO), 12); + return Avx2::PackI16ToU8(_mm512_cvtepi32_epi16(s0), _mm512_cvtepi32_epi16(s1)); + } + + static void Encode16f4(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, size32 = AlignLo(size, 32), size64 = AlignLo(size, 64); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _min = _mm512_set1_ps(min); + __m512i _sum = _mm512_setzero_si512(); + __m512i _sqsum = _mm512_setzero_si512(); + for (; i < size64; i += 64, src += 64, dst += 32) + _mm256_storeu_si256((__m256i*)dst, Encode16f4x8(src, _scale, _min, _sum, _sqsum)); + for (; i < size32; i += 32, src += 32, dst += 16) + _mm_mask_storeu_epi8(dst, -1, Encode16f4x4(src, _scale, _min, _sum, _sqsum, -1, -1)); + if (i < size) + { + __mmask16 ms0 = TailMask16(size - size32 - 0 * F); + __mmask16 ms1 = TailMask16(size - size32 - 1 * F); + __mmask16 md = TailMask16((size - size32) / 2); + _mm_mask_storeu_epi8(dst, md, Encode16f4x4(src, _scale, _min, _sum, _sqsum, ms0, ms1)); + } + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + static SIMD_INLINE __m128i Encode16f5x2(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) { __m512i i0 = Encode16f(src, scale, min, sum, sqsum, mask); @@ -1282,6 +1366,8 @@ namespace Simd { case 4: { + _encode32f = Encode32f4; + _encode16f = Encode16f4; break; } case 5: diff --git a/src/Simd/SimdDescrIntCommon.h b/src/Simd/SimdDescrIntCommon.h index 5d7a021815..9529c7b6a9 100644 --- a/src/Simd/SimdDescrIntCommon.h +++ b/src/Simd/SimdDescrIntCommon.h @@ -205,6 +205,10 @@ namespace Simd { const __m512i EX_PERM = SIMD_MM512_SETR_EPI64(0, 2, 1, 3, 4, 6, 5, 7); + const __m512i E4_MULLO = SIMD_MM512_SETR_EPI16( + 4096, 1, 4096, 1, 4096, 1, 4096, 1, 4096, 1, 4096, 1, 4096, 1, 4096, 1, + 4096, 1, 4096, 1, 4096, 1, 4096, 1, 4096, 1, 4096, 1, 4096, 1, 4096, 1); + const __m512i E5_MULLO = SIMD_MM512_SETR_EPI16( 256, 32, 4, 128, 16, 2, 64, 8, 256, 32, 4, 128, 16, 2, 64, 8, 256, 32, 4, 128, 16, 2, 64, 8, 256, 32, 4, 128, 16, 2, 64, 8); From 4089000c590866db87e62479c999991ac706c000 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Wed, 21 Jun 2023 18:51:02 +0300 Subject: [PATCH 21/44] +add Support of 4-bit depth in AVX-512BW optimizations of functions DescrIntDecode32f, DescrIntDecode16f, DescrIntCosineDistance, DescrIntCosineDistancesMxNp, DescrIntCosineDistancesMxNa. --- src/Simd/SimdAvx2DescrInt.cpp | 20 +- src/Simd/SimdAvx512bwDescrInt.cpp | 305 ++++++++++++++++++++++++++++-- src/Simd/SimdConst.h | 1 + src/Simd/SimdDescrIntCommon.h | 4 +- src/Simd/SimdSse41DescrInt.cpp | 30 +-- 5 files changed, 313 insertions(+), 47 deletions(-) diff --git a/src/Simd/SimdAvx2DescrInt.cpp b/src/Simd/SimdAvx2DescrInt.cpp index 83dfe38ab0..348b568700 100644 --- a/src/Simd/SimdAvx2DescrInt.cpp +++ b/src/Simd/SimdAvx2DescrInt.cpp @@ -867,7 +867,7 @@ namespace Simd ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); } __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); - DecodeCosineDistances(A, B, ab, distances, stride); + DecodeCosineDistances2x4(A, B, ab, distances, stride); } template<> void MicroCosineDistances2x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -925,7 +925,7 @@ namespace Simd ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); } __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); - DecodeCosineDistances(A, B, ab, distances, stride); + DecodeCosineDistances2x4(A, B, ab, distances, stride); } template<> void MicroCosineDistances2x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -983,7 +983,7 @@ namespace Simd ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); } __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); - DecodeCosineDistances(A, B, ab, distances, stride); + DecodeCosineDistances2x4(A, B, ab, distances, stride); } template<> void MicroCosineDistances2x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1041,7 +1041,7 @@ namespace Simd ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); } __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); - DecodeCosineDistances(A, B, ab, distances, stride); + DecodeCosineDistances2x4(A, B, ab, distances, stride); } template<> void MicroCosineDistances2x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1099,7 +1099,7 @@ namespace Simd ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); } __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); - DecodeCosineDistances(A, B, ab, distances, stride); + DecodeCosineDistances2x4(A, B, ab, distances, stride); } template void MicroCosineDistances1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); @@ -1159,7 +1159,7 @@ namespace Simd ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1203,7 +1203,7 @@ namespace Simd ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1247,7 +1247,7 @@ namespace Simd ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1291,7 +1291,7 @@ namespace Simd ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1335,7 +1335,7 @@ namespace Simd ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } template void MacroCosineDistances(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) diff --git a/src/Simd/SimdAvx512bwDescrInt.cpp b/src/Simd/SimdAvx512bwDescrInt.cpp index bacc9ca15e..7c44c1533d 100644 --- a/src/Simd/SimdAvx512bwDescrInt.cpp +++ b/src/Simd/SimdAvx512bwDescrInt.cpp @@ -459,6 +459,30 @@ namespace Simd //------------------------------------------------------------------------------------------------- + static void Decode32f4(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _shift = _mm512_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); + for (; i < size16; i += 16) + { + __m256i s4 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s4, Avx2::C4_SHFL), Avx2::C4_MULLO), 12); + _mm512_storeu_ps(dst + 0, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift)); + src += 8; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s4 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s4, Sse41::C4_SHFL0), Sse41::C4_MULLO), 12); + _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift))); + src += 4; + dst += 8; + } + } + static void Decode32f5(const uint8_t* src, float scale, float shift, size_t size, float* dst) { assert(size % 8 == 0); @@ -559,6 +583,29 @@ namespace Simd //------------------------------------------------------------------------------------------------- + static void Decode16f4(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _shift = _mm512_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); + for (; i < size16; i += 16) + { + __m256i s4 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s4, Avx2::C4_SHFL), Avx2::C4_MULLO), 12); + _mm256_storeu_si256((__m256i*)dst, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift), 0)); + src += 8; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s4 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s4, Sse41::C4_SHFL0), Sse41::C4_MULLO), 12); + _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift)), 0)); + src += 4; + dst += 8; + } + } static void Decode16f5(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) { @@ -662,6 +709,31 @@ namespace Simd template int32_t Correlation(const uint8_t* a, const uint8_t* b, size_t size); + template<> int32_t Correlation<4>(const uint8_t* a, const uint8_t* b, size_t size) + { + assert(size % 8 == 0); + __m512i ab32 = _mm512_setzero_si512(); + size_t i = 0, size128 = AlignLo(size, 128); + for (; i < size128; i += 128, a += 64, b += 64) + { + __m512i _a = _mm512_loadu_si512((__m512i*)a); + __m512i _b = _mm512_loadu_si512((__m512i*)b); + __m512i ab16 = _mm512_maddubs_epi16(_mm512_and_si512(_a, K8_0F), _mm512_and_si512(_b, K8_0F)); + ab16 = _mm512_add_epi16(ab16, _mm512_maddubs_epi16(_mm512_and_si512(_mm512_srli_epi16(_a, 4), K8_0F), _mm512_and_si512(_mm512_srli_epi16(_b, 4), K8_0F))); + ab32 = _mm512_add_epi32(ab32, _mm512_madd_epi16(ab16, K16_0001)); + } + if(i < size) + { + __mmask16 mask = TailMask16((size - i) / 8); + __m512i _a = _mm512_maskz_loadu_epi32(mask, a); + __m512i _b = _mm512_maskz_loadu_epi32(mask, b); + __m512i ab16 = _mm512_maddubs_epi16(_mm512_and_si512(_a, K8_0F), _mm512_and_si512(_b, K8_0F)); + ab16 = _mm512_add_epi16(ab16, _mm512_maddubs_epi16(_mm512_and_si512(_mm512_srli_epi16(_a, 4), K8_0F), _mm512_and_si512(_mm512_srli_epi16(_b, 4), K8_0F))); + ab32 = _mm512_add_epi32(ab32, _mm512_madd_epi16(ab16, K16_0001)); + } + return ExtractSum(ab32); + } + SIMD_INLINE __m512i Load5(const uint8_t* ptr, __mmask32 mask = 0x000FFFFF) { return _mm512_srli_epi16(_mm512_mullo_epi16(_mm512_shuffle_epi8(_mm512_permutexvar_epi32(C5_PERM, _mm512_castsi256_si512(_mm256_maskz_loadu_epi8(mask, ptr))), C5_SHFL), C5_MULLO), 11); @@ -771,6 +843,131 @@ namespace Simd template void MicroCosineDistances4x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + template<> void MicroCosineDistances4x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size128 = AlignLo(size, 128), o = 16; + __m512i a00, a10, a20, a30, a01, a11, a21, a31, b00, b01; + __m512i ab00 = _mm512_setzero_si512(); + __m512i ab01 = _mm512_setzero_si512(); + __m512i ab02 = _mm512_setzero_si512(); + __m512i ab03 = _mm512_setzero_si512(); + __m512i ab10 = _mm512_setzero_si512(); + __m512i ab11 = _mm512_setzero_si512(); + __m512i ab12 = _mm512_setzero_si512(); + __m512i ab13 = _mm512_setzero_si512(); + __m512i ab20 = _mm512_setzero_si512(); + __m512i ab21 = _mm512_setzero_si512(); + __m512i ab22 = _mm512_setzero_si512(); + __m512i ab23 = _mm512_setzero_si512(); + __m512i ab30 = _mm512_setzero_si512(); + __m512i ab31 = _mm512_setzero_si512(); + __m512i ab32 = _mm512_setzero_si512(); + __m512i ab33 = _mm512_setzero_si512(); + for (; i < size128; i += 128, o += 64) + { + a01 = _mm512_loadu_si512((__m512i*)(A[0] + o)); + a00 = _mm512_and_si512(a01, K8_0F); + a01 = _mm512_and_si512(_mm512_srli_epi16(a01, 4), K8_0F); + a11 = _mm512_loadu_si512((__m512i*)(A[1] + o)); + a10 = _mm512_and_si512(a11, K8_0F); + a11 = _mm512_and_si512(_mm512_srli_epi16(a11, 4), K8_0F); + a21 = _mm512_loadu_si512((__m512i*)(A[2] + o)); + a20 = _mm512_and_si512(a21, K8_0F); + a21 = _mm512_and_si512(_mm512_srli_epi16(a21, 4), K8_0F); + a31 = _mm512_loadu_si512((__m512i*)(A[3] + o)); + a30 = _mm512_and_si512(a31, K8_0F); + a31 = _mm512_and_si512(_mm512_srli_epi16(a31, 4), K8_0F); + + b01 = _mm512_loadu_si512((__m512i*)(B[0] + o)); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab00 = _mm512_add_epi32(ab00, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + ab10 = _mm512_add_epi32(ab10, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); + ab20 = _mm512_add_epi32(ab20, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); + ab30 = _mm512_add_epi32(ab30, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); + + b01 = _mm512_loadu_si512((__m512i*)(B[1] + o)); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab01 = _mm512_add_epi32(ab01, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + ab11 = _mm512_add_epi32(ab11, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); + ab21 = _mm512_add_epi32(ab21, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); + ab31 = _mm512_add_epi32(ab31, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); + + b01 = _mm512_loadu_si512((__m512i*)(B[2] + o)); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab02 = _mm512_add_epi32(ab02, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + ab12 = _mm512_add_epi32(ab12, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); + ab22 = _mm512_add_epi32(ab22, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); + ab32 = _mm512_add_epi32(ab32, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); + + b01 = _mm512_loadu_si512((__m512i*)(B[3] + o)); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab03 = _mm512_add_epi32(ab03, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + ab13 = _mm512_add_epi32(ab13, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); + ab23 = _mm512_add_epi32(ab23, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); + ab33 = _mm512_add_epi32(ab33, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); + } + if (i < size) + { + __mmask16 mask = TailMask32((size - i) / 8); + a01 = _mm512_maskz_loadu_epi32(mask, A[0] + o); + a00 = _mm512_and_si512(a01, K8_0F); + a01 = _mm512_and_si512(_mm512_srli_epi16(a01, 4), K8_0F); + a11 = _mm512_maskz_loadu_epi32(mask, A[1] + o); + a10 = _mm512_and_si512(a11, K8_0F); + a11 = _mm512_and_si512(_mm512_srli_epi16(a11, 4), K8_0F); + a21 = _mm512_maskz_loadu_epi32(mask, A[2] + o); + a20 = _mm512_and_si512(a21, K8_0F); + a21 = _mm512_and_si512(_mm512_srli_epi16(a21, 4), K8_0F); + a31 = _mm512_maskz_loadu_epi32(mask, A[3] + o); + a30 = _mm512_and_si512(a31, K8_0F); + a31 = _mm512_and_si512(_mm512_srli_epi16(a31, 4), K8_0F); + + b01 = _mm512_maskz_loadu_epi32(mask, B[0] + o); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab00 = _mm512_add_epi32(ab00, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + ab10 = _mm512_add_epi32(ab10, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); + ab20 = _mm512_add_epi32(ab20, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); + ab30 = _mm512_add_epi32(ab30, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); + + b01 = _mm512_maskz_loadu_epi32(mask, B[1] + o); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab01 = _mm512_add_epi32(ab01, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + ab11 = _mm512_add_epi32(ab11, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); + ab21 = _mm512_add_epi32(ab21, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); + ab31 = _mm512_add_epi32(ab31, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); + + b01 = _mm512_maskz_loadu_epi32(mask, B[2] + o); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab02 = _mm512_add_epi32(ab02, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + ab12 = _mm512_add_epi32(ab12, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); + ab22 = _mm512_add_epi32(ab22, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); + ab32 = _mm512_add_epi32(ab32, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); + + b01 = _mm512_maskz_loadu_epi32(mask, B[3] + o); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab03 = _mm512_add_epi32(ab03, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + ab13 = _mm512_add_epi32(ab13, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); + ab23 = _mm512_add_epi32(ab23, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); + ab33 = _mm512_add_epi32(ab33, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); + __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); + __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); + Sse41::DecodeCosineDistances1x4(A[2], B, ab2, distances + 2 * stride); + Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); + } + template<> void MicroCosineDistances4x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size32 = AlignLo(size, 32), o = 16; @@ -858,10 +1055,10 @@ namespace Simd __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); - Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); - Sse41::DecodeCosineDistances(A[1], B, ab1, distances + 1 * stride); - Sse41::DecodeCosineDistances(A[2], B, ab2, distances + 2 * stride); - Sse41::DecodeCosineDistances(A[3], B, ab3, distances + 3 * stride); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); + Sse41::DecodeCosineDistances1x4(A[2], B, ab2, distances + 2 * stride); + Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); } template<> void MicroCosineDistances4x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -951,10 +1148,10 @@ namespace Simd __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); - Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); - Sse41::DecodeCosineDistances(A[1], B, ab1, distances + 1 * stride); - Sse41::DecodeCosineDistances(A[2], B, ab2, distances + 2 * stride); - Sse41::DecodeCosineDistances(A[3], B, ab3, distances + 3 * stride); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); + Sse41::DecodeCosineDistances1x4(A[2], B, ab2, distances + 2 * stride); + Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); } template<> void MicroCosineDistances4x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1044,10 +1241,10 @@ namespace Simd __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); - Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); - Sse41::DecodeCosineDistances(A[1], B, ab1, distances + 1 * stride); - Sse41::DecodeCosineDistances(A[2], B, ab2, distances + 2 * stride); - Sse41::DecodeCosineDistances(A[3], B, ab3, distances + 3 * stride); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); + Sse41::DecodeCosineDistances1x4(A[2], B, ab2, distances + 2 * stride); + Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); } template<> void MicroCosineDistances4x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1137,14 +1334,78 @@ namespace Simd __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); - Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); - Sse41::DecodeCosineDistances(A[1], B, ab1, distances + 1 * stride); - Sse41::DecodeCosineDistances(A[2], B, ab2, distances + 2 * stride); - Sse41::DecodeCosineDistances(A[3], B, ab3, distances + 3 * stride); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); + Sse41::DecodeCosineDistances1x4(A[2], B, ab2, distances + 2 * stride); + Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); } template void MicroCosineDistances1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + template<> void MicroCosineDistances1x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size128 = AlignLo(size, 128), o = 16; + __m512i a00, a01, b00, b01; + __m512i ab00 = _mm512_setzero_si512(); + __m512i ab01 = _mm512_setzero_si512(); + __m512i ab02 = _mm512_setzero_si512(); + __m512i ab03 = _mm512_setzero_si512(); + for (; i < size128; i += 128, o += 64) + { + a01 = _mm512_loadu_si512((__m512i*)(A[0] + o)); + a00 = _mm512_and_si512(a01, K8_0F); + a01 = _mm512_and_si512(_mm512_srli_epi16(a01, 4), K8_0F); + + b01 = _mm512_loadu_si512((__m512i*)(B[0] + o)); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab00 = _mm512_add_epi32(ab00, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + + b01 = _mm512_loadu_si512((__m512i*)(B[1] + o)); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab01 = _mm512_add_epi32(ab01, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + + b01 = _mm512_loadu_si512((__m512i*)(B[2] + o)); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab02 = _mm512_add_epi32(ab02, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + + b01 = _mm512_loadu_si512((__m512i*)(B[3] + o)); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab03 = _mm512_add_epi32(ab03, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + } + if (i < size) + { + __mmask16 mask = TailMask32((size - i) / 8); + a01 = _mm512_maskz_loadu_epi32(mask, A[0] + o); + a00 = _mm512_and_si512(a01, K8_0F); + a01 = _mm512_and_si512(_mm512_srli_epi16(a01, 4), K8_0F); + + b01 = _mm512_maskz_loadu_epi32(mask, B[0] + o); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab00 = _mm512_add_epi32(ab00, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + + b01 = _mm512_maskz_loadu_epi32(mask, B[1] + o); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab01 = _mm512_add_epi32(ab01, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + b01 = _mm512_maskz_loadu_epi32(mask, B[2] + o); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab02 = _mm512_add_epi32(ab02, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + + b01 = _mm512_maskz_loadu_epi32(mask, B[3] + o); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab03 = _mm512_add_epi32(ab03, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + } + template<> void MicroCosineDistances1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size32 = AlignLo(size, 32), o = 16; @@ -1187,7 +1448,7 @@ namespace Simd ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1232,7 +1493,7 @@ namespace Simd ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1277,7 +1538,7 @@ namespace Simd ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1322,7 +1583,7 @@ namespace Simd ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } template void MacroCosineDistances(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1368,6 +1629,10 @@ namespace Simd { _encode32f = Encode32f4; _encode16f = Encode16f4; + _decode32f = Decode32f4; + _decode16f = Decode16f4; + _cosineDistance = Avx512bw::CosineDistance<4>; + _macroCosineDistances = Avx512bw::MacroCosineDistances<4>; break; } case 5: diff --git a/src/Simd/SimdConst.h b/src/Simd/SimdConst.h index 360dbb32f7..f95bf5857a 100644 --- a/src/Simd/SimdConst.h +++ b/src/Simd/SimdConst.h @@ -385,6 +385,7 @@ namespace Simd const __m512i K8_02 = SIMD_MM512_SET1_EPI8(0x02); const __m512i K8_03 = SIMD_MM512_SET1_EPI8(0x03); const __m512i K8_07 = SIMD_MM512_SET1_EPI8(0x07); + const __m512i K8_0F = SIMD_MM512_SET1_EPI8(0x0F); const __m512i K8_01_FF = SIMD_MM512_SET2_EPI8(0x01, 0xFF); diff --git a/src/Simd/SimdDescrIntCommon.h b/src/Simd/SimdDescrIntCommon.h index 9529c7b6a9..a53751fd47 100644 --- a/src/Simd/SimdDescrIntCommon.h +++ b/src/Simd/SimdDescrIntCommon.h @@ -84,7 +84,7 @@ namespace Simd //------------------------------------------------------------------------------------------------- - SIMD_INLINE void DecodeCosineDistances(const uint8_t* a, const uint8_t* const* B, __m128 abSum, float* distances) + SIMD_INLINE void DecodeCosineDistances1x4(const uint8_t* a, const uint8_t* const* B, __m128 abSum, float* distances) { __m128 aScale, aShift, aMean, aNorm, bScale, bShift, bMean, bNorm; bScale = _mm_loadu_ps((float*)B[0]); @@ -168,7 +168,7 @@ namespace Simd //------------------------------------------------------------------------------------------------- - SIMD_INLINE void DecodeCosineDistances(const uint8_t* const* A, const uint8_t* const* B, __m256 abSum, float* distances, size_t stride) + SIMD_INLINE void DecodeCosineDistances2x4(const uint8_t* const* A, const uint8_t* const* B, __m256 abSum, float* distances, size_t stride) { __m256 aScale, aShift, aMean, aNorm, bScale, bShift, bMean, bNorm; bScale = _mm256_broadcast_ps((__m128*)B[0]); diff --git a/src/Simd/SimdSse41DescrInt.cpp b/src/Simd/SimdSse41DescrInt.cpp index 9d74a3ebc9..1ad8d01033 100644 --- a/src/Simd/SimdSse41DescrInt.cpp +++ b/src/Simd/SimdSse41DescrInt.cpp @@ -759,8 +759,8 @@ namespace Simd } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); - DecodeCosineDistances(A[1], B, ab1, distances + 1 * stride); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); } template<> void MicroCosineDistances2x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -839,8 +839,8 @@ namespace Simd } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); - DecodeCosineDistances(A[1], B, ab1, distances + 1 * stride); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); } template<> void MicroCosineDistances2x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -919,8 +919,8 @@ namespace Simd } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); - DecodeCosineDistances(A[1], B, ab1, distances + 1 * stride); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); } template<> void MicroCosineDistances2x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -999,8 +999,8 @@ namespace Simd } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); - DecodeCosineDistances(A[1], B, ab1, distances + 1 * stride); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); } template<> void MicroCosineDistances2x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1079,8 +1079,8 @@ namespace Simd } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); - DecodeCosineDistances(A[1], B, ab1, distances + 1 * stride); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); } template void MicroCosineDistances1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); @@ -1140,7 +1140,7 @@ namespace Simd ab03 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1198,7 +1198,7 @@ namespace Simd ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1256,7 +1256,7 @@ namespace Simd ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1314,7 +1314,7 @@ namespace Simd ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } template<> void MicroCosineDistances1x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) @@ -1372,7 +1372,7 @@ namespace Simd ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); } __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - DecodeCosineDistances(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } template void MacroCosineDistances(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) From 52b168c87416c09f5c88b028741ce8709ec68f75 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Thu, 22 Jun 2023 13:42:54 +0300 Subject: [PATCH 22/44] *change test parameters. --- src/Test/TestGemm.cpp | 21 +++++++++++++++++---- src/Test/TestSynetConvolution8i.cpp | 3 ++- 2 files changed, 19 insertions(+), 5 deletions(-) diff --git a/src/Test/TestGemm.cpp b/src/Test/TestGemm.cpp index 2e9ac3d0a9..f00f5990b3 100644 --- a/src/Test/TestGemm.cpp +++ b/src/Test/TestGemm.cpp @@ -190,8 +190,8 @@ namespace Test //result = result && Gemm32fAutoTest(0, 0, 1002, 1001, 3000, f1, f2); //result = result && Gemm32fAutoTest(0, 0, 49, 49, 32, f1, f2); - result = result && Gemm32fAutoTest(0, 0, 64, 64, 64, f1, f2); - result = result && Gemm32fAutoTest(0, 0, 49, 32, 49, f1, f2); + //result = result && Gemm32fAutoTest(0, 0, 64, 64, 64, f1, f2); + //result = result && Gemm32fAutoTest(0, 0, 49, 32, 49, f1, f2); //result = result && Gemm32fAutoTest(0, 0, 48, 48, 32, f1, f2); //result = result && Gemm32fAutoTest(0, 0, 48, 32, 48, f1, f2); //result = result && Gemm32fAutoTest(0, 0, 48, 48, 48, f1, f2); @@ -199,6 +199,14 @@ namespace Test //result = result && Gemm32fAutoTest(0, 0, 32, 32, 32, f1, f2); //result = result && Gemm32fAutoTest(0, 0, 32, 32, 64, f1, f2); + result = result && Gemm32fAutoTest(0, 0, 128, 128, 512, f1, f2); + result = result && Gemm32fAutoTest(0, 0, 32, 1024, 256, f1, f2); + result = result && Gemm32fAutoTest(0, 0, 64, 512, 256, f1, f2); + result = result && Gemm32fAutoTest(0, 0, 128, 256, 256, f1, f2); + result = result && Gemm32fAutoTest(0, 0, 256, 128, 256, f1, f2); + result = result && Gemm32fAutoTest(0, 0, 512, 64, 256, f1, f2); + result = result && Gemm32fAutoTest(0, 0, 1024, 32, 256, f1, f2); + return result; } @@ -268,8 +276,13 @@ namespace Test //result = result && Gemm32fAutoTest(0, 1, 997, 998, 999, f1, f2); - result = result && Gemm32fAutoTest(0, 1, 49, 49, 32, f1, f2); - result = result && Gemm32fAutoTest(0, 1, 49, 32, 49, f1, f2); + //result = result && Gemm32fAutoTest(0, 1, 49, 49, 32, f1, f2); + //result = result && Gemm32fAutoTest(0, 1, 49, 32, 49, f1, f2); + result = result && Gemm32fAutoTest(0, 1, 128, 128, 512, f1, f2); + result = result && Gemm32fAutoTest(0, 1, 128, 256, 256, f1, f2); + result = result && Gemm32fAutoTest(0, 1, 64, 512, 256, f1, f2); + result = result && Gemm32fAutoTest(0, 1, 32, 1024, 256, f1, f2); + return result; } diff --git a/src/Test/TestSynetConvolution8i.cpp b/src/Test/TestSynetConvolution8i.cpp index 3b733d8110..da9cbbcd5c 100644 --- a/src/Test/TestSynetConvolution8i.cpp +++ b/src/Test/TestSynetConvolution8i.cpp @@ -239,13 +239,14 @@ namespace Test //result = result && SynetConvolution8iForwardAutoTest(e, Param(1, 64, 8, 32, 64, _3, _1, _1, _1, _1, 1, aPr, t1, u8, u8), 1, c, f1, f2); //result = result && SynetConvolution8iForwardAutoTest(e, Param(1, 63, 8, 32, 64, _3, _1, _1, _1, _1, 1, aPr, t1, u8, u8), 1, c, f1, f2); //result = result && SynetConvolution8iForwardAutoTest(e, Param(1, 386, 50, 70, 76, _3, _1, _1, _1, _1, 1, aPr, t1, u8, u8), 1, c, f1, f2); - result = result && SynetConvolution8iForwardAutoTest(e, Param(1, 386, 50, 70, 76, _3, _1, _1, _1, _1, 1, aGe, t1, u8, u8), 1, c, f1, f2); + //result = result && SynetConvolution8iForwardAutoTest(e, Param(1, 386, 50, 70, 76, _3, _1, _1, _1, _1, 1, aGe, t1, u8, u8), 1, c, f1, f2); //result = result && SynetConvolution8iForwardAutoTest(e, Param(1, 80, 100, 100, 80, _1, _1, _1, _0, _0, 1, aSw, t1, u8, u8), 0, c, f1, f2); //result = result && SynetConvolution8iForwardAutoTest(e, Param(1, 64, 8, 32, 64, _3, _1, _1, _1, _1, 1, aPr, t1, u8, u8), 1, c, f1, f2); //result = result && SynetConvolution8iForwardAutoTest(e, Param(1, 384, 8, 12, 256, _3, _1, _1, _1, _1, 1, aPr, t1, u8, u8), 1, c, f1, f2); //result = result && SynetConvolution8iForwardAutoTest(e, Param(1, 5000, 30, 30, 400, _1, _1, _1, _0, _0, 1, aRe, t1, f32, u8), 0, c, f1, f2); //result = result && SynetConvolution8iForwardAutoTest(e, Param(1, 2000, 30, 30, 64, _1, _1, _1, _0, _0, 1, aLr, t1, u8, f32), 1, c, f1, f2); + result = result && SynetConvolution8iForwardAutoTest(e, Param(1, 256, 16, 16, 128, _1, _1, _1, _0, _0, 1, aGe, t1, u8, f32), 1, c, f1, f2); #endif #else //result = result && SynetConvolution8iForwardAutoTest(e, Param(1, 2000, 30, 30, 64, _1, _1, _1, _0, _0, 1, aRe, t1, f32, u8), 0, c, f1, f2); From 9f1a133d92dbb2dca71dd4af87104622dbff664d Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Thu, 22 Jun 2023 15:31:37 +0300 Subject: [PATCH 23/44] +add Base implementation of function SynetNormalizeLayerForwardV3. --- docs/2023.html | 2 + docs/help/group__synet__normalize.html | 116 +++++++++++++++++++++++++ src/Simd/SimdBase.h | 3 + src/Simd/SimdBaseSynetNormalize.cpp | 81 +++++++++++++++++ src/Simd/SimdLib.cpp | 15 ++++ src/Simd/SimdLib.h | 43 +++++++++ src/Test/Test.cpp | 1 + src/Test/TestSynetNormalize.cpp | 26 +++++- 8 files changed, 286 insertions(+), 1 deletion(-) diff --git a/docs/2023.html b/docs/2023.html index 1bc1aa41a9..3f1298091f 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -46,6 +46,7 @@
        New features
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistance.
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistancesMxNp.
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistancesMxNa.
      • +
      • Base implementation of function SynetNormalizeLayerForwardV3.
      Bug fixing
        @@ -62,6 +63,7 @@
        New features
        • Tests for verifying functionality of function DescrIntEncode16f.
        • Tests for verifying functionality of function DescrIntDecode16f.
        • +
        • Tests for verifying functionality of function SynetNormalizeLayerForwardV3
        Home diff --git a/docs/help/group__synet__normalize.html b/docs/help/group__synet__normalize.html index a9ee451842..ce8b59c272 100644 --- a/docs/help/group__synet__normalize.html +++ b/docs/help/group__synet__normalize.html @@ -56,6 +56,9 @@

        Simd Library Documentation.

        SIMD_API void SimdSynetNormalizeLayerForwardV2 (const float *src, size_t batch, size_t channels, size_t spatial, const float *scale, const float *shift, const float *eps, SimdTensorFormatType format, float *buf, float *dst)  Performs forward propagation of NormalizeLayer (Version 2). More...
          +SIMD_API void SimdSynetNormalizeLayerForwardV3 (const float *src, size_t batch, size_t channels, size_t spatial, const float *scale, const float *shift, const float *eps, SimdTensorFormatType format, float *buf, float *dst) + Performs forward propagation of NormalizeLayer (Version 3). More...

        Detailed Description

        Functions to acceleratе NormalizeLayer in Synet Framework.

        @@ -272,6 +275,119 @@

        +

        ◆ SimdSynetNormalizeLayerForwardV3()

        + +
        +
        + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
        void SimdSynetNormalizeLayerForwardV3 (const float * src,
        size_t batch,
        size_t channels,
        size_t spatial,
        const float * scale,
        const float * shift,
        const float * eps,
        SimdTensorFormatType format,
        float * buf,
        float * dst 
        )
        +
        + +

        Performs forward propagation of NormalizeLayer (Version 3).

        +
        Algorithm's details:
        +\verbatim
        +for(b = 0; b < batch; ++b)
        +    for(c = 0; c < channels; ++c)
        +    {
        +        sum = 0;
        +        for (s = 0; s < spatial; ++s)
        +            sum += src[b, ñ, s];
        +        mean = sum / spatial;
        +        for (s = 0; s < spatial; ++s)
        +            dst[b, c, s] = src[b, c, s] - mean;
        +
        +        sqsum = 0;
        +        for (s = 0; s < spatial; ++s)
        +            sqsum += Square(dst[b, c, s]);
        +        norm = 1 / Sqrt(sqsum / spatial + eps);
        +        for (s = 0; s < spatial; ++s)
        +            dst[b, c, s] = dst[b, c, s] * norm * scale[c] + shift[c];
        +    }
        +
        Note
        This function is used in Synet Framework.
        +
        Parameters
        + + + + + + + + + + + +
        [in]src- a pointer to the input 32-bit float tensor.
        [in]batch- a batch size of input and output tensor.
        [in]channels- a number of channels in input and output tensor.
        [in]spatial- a spatial size (height*width) of input and output tensor.
        [in]scale- an array with scale parameters. The size of the array is equal to channels.
        [in]shift- an array with shift parameters. The size of the array is equal to channels.
        [in]eps- a pointer to epsilon parameter. It is used to prevent division by zero.
        [in]format- a format of input and output tensor. It can be SimdTensorFormatNchw, SimdTensorFormatNhwc.
        [out]buf- a pointer to external temporary buffer. The size of the buffer must be equal to channels. Can be NULL (it causes usage of internal buffer).
        [out]dst- a pointer to the output 32-bit float tensor.
        +
        +
        +
        diff --git a/src/Simd/SimdBase.h b/src/Simd/SimdBase.h index bddf1364d3..b41f9959d6 100644 --- a/src/Simd/SimdBase.h +++ b/src/Simd/SimdBase.h @@ -635,6 +635,9 @@ namespace Simd void SynetNormalizeLayerForwardV2(const float* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst); + void SynetNormalizeLayerForwardV3(const float* src, size_t batch, size_t channels, size_t spatial, + const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst); + void SynetPoolingAverage(const float * src, size_t srcC, size_t srcH, size_t srcW, size_t kernelY, size_t kernelX, size_t strideY, size_t strideX, size_t padY, size_t padX, float* dst, size_t dstH, size_t dstW, SimdBool excludePad, SimdTensorFormatType format); diff --git a/src/Simd/SimdBaseSynetNormalize.cpp b/src/Simd/SimdBaseSynetNormalize.cpp index 76ec8170e0..71c5887c92 100644 --- a/src/Simd/SimdBaseSynetNormalize.cpp +++ b/src/Simd/SimdBaseSynetNormalize.cpp @@ -205,6 +205,87 @@ namespace Simd else assert(0); } + + //------------------------------------------------------------------------------------------------- + + void SynetNormalizeLayerForwardV3(const float* src, size_t batch, size_t channels, size_t spatial, + const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst) + { + float k = 1.0f / float(spatial), e = *eps; + + if (format == SimdTensorFormatNchw) + { + for (size_t b = 0; b < batch; ++b) + { + for (size_t c = 0; c < channels; ++c) + { + float sum = 0; + for (size_t s = 0; s < spatial; ++s) + sum += src[s]; + float mean = sum * k; + for (size_t s = 0; s < spatial; ++s) + dst[s] = src[s] - mean; + + float sqsum = 0; + for (size_t s = 0; s < spatial; ++s) + sqsum += Square(dst[s]); + float norm = 1.0f / ::sqrt(sqsum * k + e); + for (size_t s = 0; s < spatial; ++s) + dst[s] = dst[s] * norm * scale[c] + shift[c]; + + dst += spatial; + src += spatial; + } + } + + } + else if (format == SimdTensorFormatNhwc) + { + Array32f _buf; + if (buf == NULL) + { + _buf.Resize(channels); + buf = _buf.data; + } + for (size_t b = 0; b < batch; ++b) + { + for (size_t c = 0; c < channels; ++c) + buf[c] = 0; + for (size_t s = 0, o = 0; s < spatial; ++s) + { + for (size_t c = 0; c < channels; ++c, ++o) + buf[c] += src[o]; + } + for (size_t c = 0; c < channels; ++c) + buf[c] = buf[c] * k; + for (size_t s = 0, o = 0; s < spatial; ++s) + { + for (size_t c = 0; c < channels; ++c, ++o) + dst[o] = src[o] - buf[c]; + } + + for (size_t c = 0; c < channels; ++c) + buf[c] = 0; + for (size_t s = 0, o = 0; s < spatial; ++s) + { + for (size_t c = 0; c < channels; ++c, ++o) + buf[c] += Square(dst[o]); + } + for (size_t c = 0; c < channels; ++c) + buf[c] = 1.0f / ::sqrt(buf[c] * k + e); + for (size_t s = 0, o = 0; s < spatial; ++s) + { + for (size_t c = 0; c < channels; ++c, ++o) + dst[o] = dst[o] * buf[c] * scale[c] + shift[c]; + } + + src += channels * spatial; + dst += channels * spatial; + } + } + else + assert(0); + } } #endif } diff --git a/src/Simd/SimdLib.cpp b/src/Simd/SimdLib.cpp index 32ca248fb2..1260a263b1 100644 --- a/src/Simd/SimdLib.cpp +++ b/src/Simd/SimdLib.cpp @@ -6583,6 +6583,21 @@ SIMD_API void SimdSynetNormalizeLayerForwardV2(const float* src, size_t batch, s #endif } +SIMD_API void SimdSynetNormalizeLayerForwardV3(const float* src, size_t batch, size_t channels, size_t spatial, + const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst) +{ + SIMD_EMPTY(); +#if defined(SIMD_SYNET_ENABLE) + typedef void(*SimdSynetNormalizeLayerForwardV3Ptr) (const float* src, size_t batch, size_t channels, size_t spatial, + const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst); + const static SimdSynetNormalizeLayerForwardV3Ptr simdSynetNormalizeLayerForwardV3 = SIMD_FUNC0(SynetNormalizeLayerForwardV3);// , SIMD_AVX512BW_FUNC, SIMD_AVX2_FUNC, SIMD_SSE41_FUNC, SIMD_NEON_FUNC); + + simdSynetNormalizeLayerForwardV3(src, batch, channels, spatial, scale, shift, eps, format, buf, dst); +#else + assert(0); +#endif +} + SIMD_API void* SimdSynetPermuteInit(const size_t* shape, const size_t* order, size_t count, SimdTensorDataType type) { SIMD_EMPTY(); diff --git a/src/Simd/SimdLib.h b/src/Simd/SimdLib.h index 699fbe84a9..1cdf0109da 100644 --- a/src/Simd/SimdLib.h +++ b/src/Simd/SimdLib.h @@ -7672,6 +7672,49 @@ extern "C" SIMD_API void SimdSynetNormalizeLayerForwardV2(const float* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst); + /*! @ingroup synet_normalize + + \fn void SimdSynetNormalizeLayerForwardV3(const float* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst); + + \short Performs forward propagation of NormalizeLayer (Version 3). + + Algorithm's details: + \verbatim + for(b = 0; b < batch; ++b) + for(c = 0; c < channels; ++c) + { + sum = 0; + for (s = 0; s < spatial; ++s) + sum += src[b, ñ, s]; + mean = sum / spatial; + for (s = 0; s < spatial; ++s) + dst[b, c, s] = src[b, c, s] - mean; + + sqsum = 0; + for (s = 0; s < spatial; ++s) + sqsum += Square(dst[b, c, s]); + norm = 1 / Sqrt(sqsum / spatial + eps); + for (s = 0; s < spatial; ++s) + dst[b, c, s] = dst[b, c, s] * norm * scale[c] + shift[c]; + } + \endverbatim + + \note This function is used in Synet Framework. + + \param [in] src - a pointer to the input 32-bit float tensor. + \param [in] batch - a batch size of input and output tensor. + \param [in] channels - a number of channels in input and output tensor. + \param [in] spatial - a spatial size (height*width) of input and output tensor. + \param [in] scale - an array with scale parameters. The size of the array is equal to channels. + \param [in] shift - an array with shift parameters. The size of the array is equal to channels. + \param [in] eps - a pointer to epsilon parameter. It is used to prevent division by zero. + \param [in] format - a format of input and output tensor. It can be ::SimdTensorFormatNchw, ::SimdTensorFormatNhwc. + \param [out] buf - a pointer to external temporary buffer. The size of the buffer must be equal to channels. Can be NULL (it causes usage of internal buffer). + \param [out] dst - a pointer to the output 32-bit float tensor. + */ + SIMD_API void SimdSynetNormalizeLayerForwardV3(const float* src, size_t batch, size_t channels, size_t spatial, + const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst); + /*! @ingroup synet_permute \fn void* SimdSynetPermuteInit(const size_t * shape, const size_t* order, size_t count, SimdTensorDataType type); diff --git a/src/Test/Test.cpp b/src/Test/Test.cpp index f6d6af17db..768a6de1cd 100644 --- a/src/Test/Test.cpp +++ b/src/Test/Test.cpp @@ -415,6 +415,7 @@ namespace Test TEST_ADD_GROUP_A0(SynetNormalizeLayerForward); TEST_ADD_GROUP_A0(SynetNormalizeLayerForwardV2); + TEST_ADD_GROUP_A0(SynetNormalizeLayerForwardV3); TEST_ADD_GROUP_A0(SynetPermute); diff --git a/src/Test/TestSynetNormalize.cpp b/src/Test/TestSynetNormalize.cpp index b8f7d06db7..25e7160aeb 100644 --- a/src/Test/TestSynetNormalize.cpp +++ b/src/Test/TestSynetNormalize.cpp @@ -181,7 +181,7 @@ namespace Test Tensor32f shift(ToShape(channels)); Tensor32f buf; if (extBuf) - buf.Reshape(ToShape(spatial)); + buf.Reshape(ToShape(Simd::Max(spatial, channels))); Tensor32f dst1(ToShape(batch, channels, 1, spatial, format)); Tensor32f dst2(ToShape(batch, channels, 1, spatial, format)); @@ -239,5 +239,29 @@ namespace Test return result; } + bool SynetNormalizeLayerForwardV3AutoTest() + { + bool result = true; + + result = result && SynetNormalizeLayerForwardV2AutoTest(FUNC_SNLF2(Simd::Base::SynetNormalizeLayerForwardV3), FUNC_SNLF2(SimdSynetNormalizeLayerForwardV3)); + +//#ifdef SIMD_SSE41_ENABLE +// if (Simd::Sse41::Enable) +// result = result && SynetNormalizeLayerForwardV2AutoTest(FUNC_SNLF2(Simd::Sse41::SynetNormalizeLayerForwardV3), FUNC_SNLF2(SimdSynetNormalizeLayerForwardV3)); +//#endif +// +//#ifdef SIMD_AVX2_ENABLE +// if (Simd::Avx2::Enable) +// result = result && SynetNormalizeLayerForwardV2AutoTest(FUNC_SNLF2(Simd::Avx2::SynetNormalizeLayerForwardV3), FUNC_SNLF2(SimdSynetNormalizeLayerForwardV3)); +//#endif +// +//#ifdef SIMD_AVX512BW_ENABLE +// if (Simd::Avx512bw::Enable) +// result = result && SynetNormalizeLayerForwardV2AutoTest(FUNC_SNLF2(Simd::Avx512bw::SynetNormalizeLayerForwardV3), FUNC_SNLF2(SimdSynetNormalizeLayerForwardV3)); +//#endif + + return result; + } + #endif } From 836ff63a87c017fbf7366f224d4ecb7960fa1d02 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Thu, 22 Jun 2023 16:30:26 +0300 Subject: [PATCH 24/44] +add SSE4.1 optimizations of function SynetNormalizeLayerForwardV3. --- docs/2023.html | 2 +- src/Simd/SimdBaseSynetNormalize.cpp | 4 +- src/Simd/SimdLib.cpp | 2 +- src/Simd/SimdSse41.h | 3 + src/Simd/SimdSse41SynetNormalize.cpp | 155 +++++++++++++++++++++++++++ src/Test/TestSynetNormalize.cpp | 10 +- 6 files changed, 167 insertions(+), 9 deletions(-) diff --git a/docs/2023.html b/docs/2023.html index 3f1298091f..c681050e69 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -46,7 +46,7 @@
        New features
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistance.
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistancesMxNp.
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistancesMxNa.
      • -
      • Base implementation of function SynetNormalizeLayerForwardV3.
      • +
      • Base implementation, SSE4.1 optimizations of function SynetNormalizeLayerForwardV3.
      Bug fixing
        diff --git a/src/Simd/SimdBaseSynetNormalize.cpp b/src/Simd/SimdBaseSynetNormalize.cpp index 71c5887c92..b4573e983e 100644 --- a/src/Simd/SimdBaseSynetNormalize.cpp +++ b/src/Simd/SimdBaseSynetNormalize.cpp @@ -228,7 +228,7 @@ namespace Simd float sqsum = 0; for (size_t s = 0; s < spatial; ++s) - sqsum += Square(dst[s]); + sqsum += Simd::Square(dst[s]); float norm = 1.0f / ::sqrt(sqsum * k + e); for (size_t s = 0; s < spatial; ++s) dst[s] = dst[s] * norm * scale[c] + shift[c]; @@ -269,7 +269,7 @@ namespace Simd for (size_t s = 0, o = 0; s < spatial; ++s) { for (size_t c = 0; c < channels; ++c, ++o) - buf[c] += Square(dst[o]); + buf[c] += Simd::Square(dst[o]); } for (size_t c = 0; c < channels; ++c) buf[c] = 1.0f / ::sqrt(buf[c] * k + e); diff --git a/src/Simd/SimdLib.cpp b/src/Simd/SimdLib.cpp index 1260a263b1..4fecd2183a 100644 --- a/src/Simd/SimdLib.cpp +++ b/src/Simd/SimdLib.cpp @@ -6590,7 +6590,7 @@ SIMD_API void SimdSynetNormalizeLayerForwardV3(const float* src, size_t batch, s #if defined(SIMD_SYNET_ENABLE) typedef void(*SimdSynetNormalizeLayerForwardV3Ptr) (const float* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst); - const static SimdSynetNormalizeLayerForwardV3Ptr simdSynetNormalizeLayerForwardV3 = SIMD_FUNC0(SynetNormalizeLayerForwardV3);// , SIMD_AVX512BW_FUNC, SIMD_AVX2_FUNC, SIMD_SSE41_FUNC, SIMD_NEON_FUNC); + const static SimdSynetNormalizeLayerForwardV3Ptr simdSynetNormalizeLayerForwardV3 = SIMD_FUNC1(SynetNormalizeLayerForwardV3, SIMD_SSE41_FUNC);// , SIMD_AVX512BW_FUNC, SIMD_AVX2_FUNC, SIMD_NEON_FUNC); simdSynetNormalizeLayerForwardV3(src, batch, channels, spatial, scale, shift, eps, format, buf, dst); #else diff --git a/src/Simd/SimdSse41.h b/src/Simd/SimdSse41.h index 8b82cb90a9..4193083277 100644 --- a/src/Simd/SimdSse41.h +++ b/src/Simd/SimdSse41.h @@ -575,6 +575,9 @@ namespace Simd void SynetNormalizeLayerForwardV2(const float* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst); + void SynetNormalizeLayerForwardV3(const float* src, size_t batch, size_t channels, size_t spatial, + const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst); + void SynetPoolingAverage(const float* src, size_t srcC, size_t srcH, size_t srcW, size_t kernelY, size_t kernelX, size_t strideY, size_t strideX, size_t padY, size_t padX, float* dst, size_t dstH, size_t dstW, SimdBool excludePad, SimdTensorFormatType format); diff --git a/src/Simd/SimdSse41SynetNormalize.cpp b/src/Simd/SimdSse41SynetNormalize.cpp index 74e0cff766..fac2c4dee5 100644 --- a/src/Simd/SimdSse41SynetNormalize.cpp +++ b/src/Simd/SimdSse41SynetNormalize.cpp @@ -347,6 +347,161 @@ namespace Simd else assert(0); } + + //------------------------------------------------------------------------------------------------- + + + void NormalizeNchwV3(const float* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, float eps, float* dst) + { + float k = 1.0f / float(spatial); + size_t spatialF = AlignLo(spatial, F), s; + for (size_t b = 0; b < batch; ++b) + { + for (size_t c = 0; c < channels; ++c) + { + __m128 _sum = _mm_setzero_ps(); + for (s = 0; s < spatialF; s += F) + _sum = _mm_add_ps(_mm_loadu_ps(src + s), _sum); + float sum = ExtractSum(_sum); + for (; s < spatial; ++s) + sum += src[s]; + __m128 mean = _mm_set1_ps(sum * k); + for (s = 0; s < spatialF; s += F) + _mm_storeu_ps(dst + s, _mm_sub_ps(_mm_loadu_ps(src + s), mean)); + for (; s < spatial; ++s) + _mm_store_ss(dst + s, _mm_sub_ss(_mm_load_ss(src + s), mean)); + + __m128 _sqsum = _mm_setzero_ps(); + for (s = 0; s < spatialF; s += F) + _sqsum = _mm_add_ps(Square(_mm_loadu_ps(dst + s)), _sqsum); + float sqsum = ExtractSum(_sqsum); + for (; s < spatial; ++s) + sqsum += Simd::Square(dst[s]); + __m128 norm = _mm_set1_ps(1.0f / ::sqrt(sqsum * k + eps)); + __m128 _scale = _mm_set1_ps(scale[c]); + __m128 _shift = _mm_set1_ps(shift[c]); + for (s = 0; s < spatialF; s += F) + _mm_storeu_ps(dst + s, _mm_add_ps(_mm_mul_ps(_mm_mul_ps(_mm_loadu_ps(dst + s), norm), _scale), _shift)); + for (; s < spatial; ++s) + _mm_store_ss(dst + s, _mm_add_ss(_mm_mul_ss(_mm_mul_ss(_mm_load_ss(dst + s), norm), _scale), _shift)); + + dst += spatial; + src += spatial; + } + } + } + + void NormalizeNhwcV3(const float* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, float eps, float* buf, float* dst) + { + float k = 1.0f / float(spatial); + Array32f _buf; + if (buf == NULL) + { + _buf.Resize(spatial); + buf = _buf.data; + } + size_t channelsF = AlignLo(channels, F); + __m128 _eps = _mm_set1_ps(eps), _k = _mm_set1_ps(k), _1 = _mm_set1_ps(1.0f); + for (size_t b = 0, c; b < batch; ++b) + { + for (c = 0; c < channelsF; c += F) + _mm_storeu_ps(buf + c, _mm_setzero_ps()); + for (; c < channels; ++c) + _mm_store_ss(buf + c, _mm_setzero_ps()); + for (size_t s = 0; s < spatial; ++s) + { + const float* ps = src + s * channels; + for (c = 0; c < channelsF; c += F) + { + __m128 _src = _mm_loadu_ps(ps + c); + __m128 _sum = _mm_loadu_ps(buf + c); + _mm_storeu_ps(buf + c, _mm_add_ps(_sum, _src)); + } + for (; c < channels; ++c) + { + __m128 _src = _mm_load_ss(ps + c); + __m128 _sum = _mm_load_ss(buf + c); + _mm_store_ss(buf + c, _mm_add_ss(_sum, _src)); + } + } + for (c = 0; c < channelsF; c += F) + _mm_storeu_ps(buf + c, _mm_mul_ps(_mm_loadu_ps(buf + c), _k)); + for (; c < channels; ++c) + _mm_store_ss(buf + c, _mm_mul_ss(_mm_load_ss(buf + c), _k)); + for (size_t s = 0; s < spatial; ++s) + { + const float* ps = src + s * channels; + float* pd = dst + s * channels; + for (c = 0; c < channelsF; c += F) + { + __m128 _src = _mm_loadu_ps(ps + c); + __m128 mean = _mm_loadu_ps(buf + c); + _mm_storeu_ps(pd + c, _mm_sub_ps(_src, mean)); + } + for (; c < channels; ++c) + { + __m128 _src = _mm_load_ss(ps + c); + __m128 mean = _mm_load_ss(buf + c); + _mm_store_ss(pd + c, _mm_sub_ps(_src, mean)); + } + } + + for (c = 0; c < channelsF; c += F) + _mm_storeu_ps(buf + c, _mm_setzero_ps()); + for (; c < channels; ++c) + _mm_store_ss(buf + c, _mm_setzero_ps()); + for (size_t s = 0; s < spatial; ++s) + { + const float* pd = dst + s * channels; + for (c = 0; c < channelsF; c += F) + { + __m128 _dst = _mm_loadu_ps(pd + c); + __m128 _sum = _mm_loadu_ps(buf + c); + _mm_storeu_ps(buf + c, _mm_add_ps(_sum, _mm_mul_ps(_dst, _dst))); + } + for (; c < channels; ++c) + { + __m128 _dst = _mm_load_ss(pd + c); + __m128 _sum = _mm_load_ss(buf + c); + _mm_store_ss(buf + c, _mm_add_ss(_sum, _mm_mul_ss(_dst, _dst))); + } + } + for (c = 0; c < channelsF; c += F) + _mm_storeu_ps(buf + c, _mm_div_ps(_1, _mm_sqrt_ps(_mm_add_ps(_mm_mul_ps(_mm_loadu_ps(buf + c), _k), _eps)))); + for (; c < channels; ++c) + _mm_store_ss(buf + c, _mm_div_ss(_1, _mm_sqrt_ss(_mm_add_ss(_mm_mul_ss(_mm_load_ss(buf + c), _k), _eps)))); + for (size_t s = 0; s < spatial; ++s) + { + float* pd = dst + s * channels; + for (c = 0; c < channelsF; c += F) + { + __m128 _dst = _mm_loadu_ps(pd + c); + __m128 norm = _mm_loadu_ps(buf + c); + _mm_storeu_ps(pd + c, _mm_add_ps(_mm_mul_ps(_mm_mul_ps(_dst, norm), _mm_loadu_ps(scale + c)), _mm_loadu_ps(shift + c))); + } + for (; c < channels; ++c) + { + __m128 _dst = _mm_load_ss(pd + c); + __m128 norm = _mm_load_ss(buf + c); + _mm_store_ss(pd + c, _mm_add_ss(_mm_mul_ss(_mm_mul_ss(_dst, norm), _mm_load_ss(scale + c)), _mm_load_ss(shift + c))); + } + } + + src += channels * spatial; + dst += channels * spatial; + } + } + + void SynetNormalizeLayerForwardV3(const float* src, size_t batch, size_t channels, size_t spatial, + const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst) + { + if (format == SimdTensorFormatNchw) + NormalizeNchwV3(src, batch, channels, spatial, scale, shift, *eps, dst); + else if (format == SimdTensorFormatNhwc) + NormalizeNhwcV3(src, batch, channels, spatial, scale, shift, *eps, buf, dst); + else + assert(0); + } } #endif } diff --git a/src/Test/TestSynetNormalize.cpp b/src/Test/TestSynetNormalize.cpp index 25e7160aeb..fabf231b62 100644 --- a/src/Test/TestSynetNormalize.cpp +++ b/src/Test/TestSynetNormalize.cpp @@ -245,11 +245,11 @@ namespace Test result = result && SynetNormalizeLayerForwardV2AutoTest(FUNC_SNLF2(Simd::Base::SynetNormalizeLayerForwardV3), FUNC_SNLF2(SimdSynetNormalizeLayerForwardV3)); -//#ifdef SIMD_SSE41_ENABLE -// if (Simd::Sse41::Enable) -// result = result && SynetNormalizeLayerForwardV2AutoTest(FUNC_SNLF2(Simd::Sse41::SynetNormalizeLayerForwardV3), FUNC_SNLF2(SimdSynetNormalizeLayerForwardV3)); -//#endif -// +#ifdef SIMD_SSE41_ENABLE + if (Simd::Sse41::Enable) + result = result && SynetNormalizeLayerForwardV2AutoTest(FUNC_SNLF2(Simd::Sse41::SynetNormalizeLayerForwardV3), FUNC_SNLF2(SimdSynetNormalizeLayerForwardV3)); +#endif + //#ifdef SIMD_AVX2_ENABLE // if (Simd::Avx2::Enable) // result = result && SynetNormalizeLayerForwardV2AutoTest(FUNC_SNLF2(Simd::Avx2::SynetNormalizeLayerForwardV3), FUNC_SNLF2(SimdSynetNormalizeLayerForwardV3)); From 1b7c9943536470f6f5527b76a6201cf25def653e Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Thu, 22 Jun 2023 17:35:56 +0300 Subject: [PATCH 25/44] +add AVX2 optimizations of function SynetNormalizeLayerForwardV3. --- docs/2023.html | 2 +- src/Simd/SimdAvx2.h | 3 + src/Simd/SimdAvx2SynetNormalize.cpp | 154 +++++++++++++++++++++++++++ src/Simd/SimdLib.cpp | 2 +- src/Simd/SimdSse41SynetNormalize.cpp | 1 - src/Test/TestSynetNormalize.cpp | 10 +- 6 files changed, 164 insertions(+), 8 deletions(-) diff --git a/docs/2023.html b/docs/2023.html index c681050e69..27e2f9fe0b 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -46,7 +46,7 @@
        New features
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistance.
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistancesMxNp.
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistancesMxNa.
      • -
      • Base implementation, SSE4.1 optimizations of function SynetNormalizeLayerForwardV3.
      • +
      • Base implementation, SSE4.1, AVX2 optimizations of function SynetNormalizeLayerForwardV3.
      Bug fixing
        diff --git a/src/Simd/SimdAvx2.h b/src/Simd/SimdAvx2.h index 782c99a7e5..64a0baec65 100644 --- a/src/Simd/SimdAvx2.h +++ b/src/Simd/SimdAvx2.h @@ -518,6 +518,9 @@ namespace Simd void SynetNormalizeLayerForwardV2(const float* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst); + void SynetNormalizeLayerForwardV3(const float* src, size_t batch, size_t channels, size_t spatial, + const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst); + void SynetPoolingMax32f(const float* src, size_t srcC, size_t srcH, size_t srcW, size_t kernelC, size_t kernelY, size_t kernelX, size_t strideC, size_t strideY, size_t strideX, size_t padC, size_t padY, size_t padX, float* dst, size_t dstC, size_t dstH, size_t dstW, SimdTensorFormatType format); diff --git a/src/Simd/SimdAvx2SynetNormalize.cpp b/src/Simd/SimdAvx2SynetNormalize.cpp index 4c3421d7ad..a691365b1a 100644 --- a/src/Simd/SimdAvx2SynetNormalize.cpp +++ b/src/Simd/SimdAvx2SynetNormalize.cpp @@ -349,6 +349,160 @@ namespace Simd else assert(0); } + + //------------------------------------------------------------------------------------------------- + + void NormalizeNchwV3(const float* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, float eps, float* dst) + { + float k = 1.0f / float(spatial); + size_t spatialF = AlignLo(spatial, F), s; + for (size_t b = 0; b < batch; ++b) + { + for (size_t c = 0; c < channels; ++c) + { + __m256 _sum = _mm256_setzero_ps(); + for (s = 0; s < spatialF; s += F) + _sum = _mm256_add_ps(_mm256_loadu_ps(src + s), _sum); + float sum = Avx::ExtractSum(_sum); + for (; s < spatial; ++s) + sum += src[s]; + __m256 mean = _mm256_set1_ps(sum * k); + for (s = 0; s < spatialF; s += F) + _mm256_storeu_ps(dst + s, _mm256_sub_ps(_mm256_loadu_ps(src + s), mean)); + for (; s < spatial; ++s) + _mm_store_ss(dst + s, _mm_sub_ss(_mm_load_ss(src + s), _mm256_castps256_ps128(mean))); + + __m256 _sqsum = _mm256_setzero_ps(); + for (s = 0; s < spatialF; s += F) + _sqsum = _mm256_add_ps(Square(_mm256_loadu_ps(dst + s)), _sqsum); + float sqsum = Avx::ExtractSum(_sqsum); + for (; s < spatial; ++s) + sqsum += Simd::Square(dst[s]); + __m256 norm = _mm256_set1_ps(1.0f / ::sqrt(sqsum * k + eps)); + __m256 _scale = _mm256_set1_ps(scale[c]); + __m256 _shift = _mm256_set1_ps(shift[c]); + for (s = 0; s < spatialF; s += F) + _mm256_storeu_ps(dst + s, _mm256_add_ps(_mm256_mul_ps(_mm256_mul_ps(_mm256_loadu_ps(dst + s), norm), _scale), _shift)); + for (; s < spatial; ++s) + _mm_store_ss(dst + s, _mm_add_ss(_mm_mul_ss(_mm_mul_ss(_mm_load_ss(dst + s), _mm256_castps256_ps128(norm)), _mm256_castps256_ps128(_scale)), _mm256_castps256_ps128(_shift))); + + dst += spatial; + src += spatial; + } + } + } + + void NormalizeNhwcV3(const float* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, float eps, float* buf, float* dst) + { + float k = 1.0f / float(spatial); + Array32f _buf; + if (buf == NULL) + { + _buf.Resize(spatial); + buf = _buf.data; + } + size_t channelsF = AlignLo(channels, F); + __m256 _eps = _mm256_set1_ps(eps), _k = _mm256_set1_ps(k), _1 = _mm256_set1_ps(1.0f); + for (size_t b = 0, c; b < batch; ++b) + { + for (c = 0; c < channelsF; c += F) + _mm256_storeu_ps(buf + c, _mm256_setzero_ps()); + for (; c < channels; ++c) + _mm_store_ss(buf + c, _mm_setzero_ps()); + for (size_t s = 0; s < spatial; ++s) + { + const float* ps = src + s * channels; + for (c = 0; c < channelsF; c += F) + { + __m256 _src = _mm256_loadu_ps(ps + c); + __m256 _sum = _mm256_loadu_ps(buf + c); + _mm256_storeu_ps(buf + c, _mm256_add_ps(_sum, _src)); + } + for (; c < channels; ++c) + { + __m128 _src = _mm_load_ss(ps + c); + __m128 _sum = _mm_load_ss(buf + c); + _mm_store_ss(buf + c, _mm_add_ss(_sum, _src)); + } + } + for (c = 0; c < channelsF; c += F) + _mm256_storeu_ps(buf + c, _mm256_mul_ps(_mm256_loadu_ps(buf + c), _k)); + for (; c < channels; ++c) + _mm_store_ss(buf + c, _mm_mul_ss(_mm_load_ss(buf + c), _mm256_castps256_ps128(_k))); + for (size_t s = 0; s < spatial; ++s) + { + const float* ps = src + s * channels; + float* pd = dst + s * channels; + for (c = 0; c < channelsF; c += F) + { + __m256 _src = _mm256_loadu_ps(ps + c); + __m256 mean = _mm256_loadu_ps(buf + c); + _mm256_storeu_ps(pd + c, _mm256_sub_ps(_src, mean)); + } + for (; c < channels; ++c) + { + __m128 _src = _mm_load_ss(ps + c); + __m128 mean = _mm_load_ss(buf + c); + _mm_store_ss(pd + c, _mm_sub_ps(_src, mean)); + } + } + + for (c = 0; c < channelsF; c += F) + _mm256_storeu_ps(buf + c, _mm256_setzero_ps()); + for (; c < channels; ++c) + _mm_store_ss(buf + c, _mm_setzero_ps()); + for (size_t s = 0; s < spatial; ++s) + { + const float* pd = dst + s * channels; + for (c = 0; c < channelsF; c += F) + { + __m256 _dst = _mm256_loadu_ps(pd + c); + __m256 _sum = _mm256_loadu_ps(buf + c); + _mm256_storeu_ps(buf + c, _mm256_add_ps(_sum, _mm256_mul_ps(_dst, _dst))); + } + for (; c < channels; ++c) + { + __m128 _dst = _mm_load_ss(pd + c); + __m128 _sum = _mm_load_ss(buf + c); + _mm_store_ss(buf + c, _mm_add_ss(_sum, _mm_mul_ss(_dst, _dst))); + } + } + for (c = 0; c < channelsF; c += F) + _mm256_storeu_ps(buf + c, _mm256_div_ps(_1, _mm256_sqrt_ps(_mm256_add_ps(_mm256_mul_ps(_mm256_loadu_ps(buf + c), _k), _eps)))); + for (; c < channels; ++c) + _mm_store_ss(buf + c, _mm_div_ss(_mm256_castps256_ps128(_1), _mm_sqrt_ss(_mm_add_ss(_mm_mul_ss(_mm_load_ss(buf + c), _mm256_castps256_ps128(_k)), _mm256_castps256_ps128(_eps))))); + for (size_t s = 0; s < spatial; ++s) + { + float* pd = dst + s * channels; + for (c = 0; c < channelsF; c += F) + { + __m256 _dst = _mm256_loadu_ps(pd + c); + __m256 norm = _mm256_loadu_ps(buf + c); + _mm256_storeu_ps(pd + c, _mm256_add_ps(_mm256_mul_ps(_mm256_mul_ps(_dst, norm), _mm256_loadu_ps(scale + c)), _mm256_loadu_ps(shift + c))); + } + for (; c < channels; ++c) + { + __m128 _dst = _mm_load_ss(pd + c); + __m128 norm = _mm_load_ss(buf + c); + _mm_store_ss(pd + c, _mm_add_ss(_mm_mul_ss(_mm_mul_ss(_dst, norm), _mm_load_ss(scale + c)), _mm_load_ss(shift + c))); + } + } + + src += channels * spatial; + dst += channels * spatial; + } + } + + void SynetNormalizeLayerForwardV3(const float* src, size_t batch, size_t channels, size_t spatial, + const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst) + { + if (format == SimdTensorFormatNchw) + NormalizeNchwV3(src, batch, channels, spatial, scale, shift, *eps, dst); + else if (format == SimdTensorFormatNhwc) + NormalizeNhwcV3(src, batch, channels, spatial, scale, shift, *eps, buf, dst); + else + assert(0); + } } #endif } diff --git a/src/Simd/SimdLib.cpp b/src/Simd/SimdLib.cpp index 4fecd2183a..5371520109 100644 --- a/src/Simd/SimdLib.cpp +++ b/src/Simd/SimdLib.cpp @@ -6590,7 +6590,7 @@ SIMD_API void SimdSynetNormalizeLayerForwardV3(const float* src, size_t batch, s #if defined(SIMD_SYNET_ENABLE) typedef void(*SimdSynetNormalizeLayerForwardV3Ptr) (const float* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst); - const static SimdSynetNormalizeLayerForwardV3Ptr simdSynetNormalizeLayerForwardV3 = SIMD_FUNC1(SynetNormalizeLayerForwardV3, SIMD_SSE41_FUNC);// , SIMD_AVX512BW_FUNC, SIMD_AVX2_FUNC, SIMD_NEON_FUNC); + const static SimdSynetNormalizeLayerForwardV3Ptr simdSynetNormalizeLayerForwardV3 = SIMD_FUNC2(SynetNormalizeLayerForwardV3, SIMD_AVX2_FUNC, SIMD_SSE41_FUNC);// , SIMD_AVX512BW_FUNC, SIMD_NEON_FUNC); simdSynetNormalizeLayerForwardV3(src, batch, channels, spatial, scale, shift, eps, format, buf, dst); #else diff --git a/src/Simd/SimdSse41SynetNormalize.cpp b/src/Simd/SimdSse41SynetNormalize.cpp index fac2c4dee5..62ed1661ce 100644 --- a/src/Simd/SimdSse41SynetNormalize.cpp +++ b/src/Simd/SimdSse41SynetNormalize.cpp @@ -350,7 +350,6 @@ namespace Simd //------------------------------------------------------------------------------------------------- - void NormalizeNchwV3(const float* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, float eps, float* dst) { float k = 1.0f / float(spatial); diff --git a/src/Test/TestSynetNormalize.cpp b/src/Test/TestSynetNormalize.cpp index fabf231b62..2070dd58b6 100644 --- a/src/Test/TestSynetNormalize.cpp +++ b/src/Test/TestSynetNormalize.cpp @@ -250,11 +250,11 @@ namespace Test result = result && SynetNormalizeLayerForwardV2AutoTest(FUNC_SNLF2(Simd::Sse41::SynetNormalizeLayerForwardV3), FUNC_SNLF2(SimdSynetNormalizeLayerForwardV3)); #endif -//#ifdef SIMD_AVX2_ENABLE -// if (Simd::Avx2::Enable) -// result = result && SynetNormalizeLayerForwardV2AutoTest(FUNC_SNLF2(Simd::Avx2::SynetNormalizeLayerForwardV3), FUNC_SNLF2(SimdSynetNormalizeLayerForwardV3)); -//#endif -// +#ifdef SIMD_AVX2_ENABLE + if (Simd::Avx2::Enable) + result = result && SynetNormalizeLayerForwardV2AutoTest(FUNC_SNLF2(Simd::Avx2::SynetNormalizeLayerForwardV3), FUNC_SNLF2(SimdSynetNormalizeLayerForwardV3)); +#endif + //#ifdef SIMD_AVX512BW_ENABLE // if (Simd::Avx512bw::Enable) // result = result && SynetNormalizeLayerForwardV2AutoTest(FUNC_SNLF2(Simd::Avx512bw::SynetNormalizeLayerForwardV3), FUNC_SNLF2(SimdSynetNormalizeLayerForwardV3)); From 600771514021b8707b67713d2e6ba29abb1f537c Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Thu, 22 Jun 2023 18:02:22 +0300 Subject: [PATCH 26/44] +add AVX-512BW optimizations of function SynetNormalizeLayerForwardV3. --- docs/2023.html | 2 +- src/Simd/SimdAvx512bw.h | 3 + src/Simd/SimdAvx512bwSynetNormalize.cpp | 162 ++++++++++++++++++++++++ src/Simd/SimdLib.cpp | 2 +- src/Test/TestSynetNormalize.cpp | 8 +- 5 files changed, 171 insertions(+), 6 deletions(-) diff --git a/docs/2023.html b/docs/2023.html index 27e2f9fe0b..c3b284b54c 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -46,7 +46,7 @@
        New features
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistance.
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistancesMxNp.
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistancesMxNa.
      • -
      • Base implementation, SSE4.1, AVX2 optimizations of function SynetNormalizeLayerForwardV3.
      • +
      • Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function SynetNormalizeLayerForwardV3.
      Bug fixing
        diff --git a/src/Simd/SimdAvx512bw.h b/src/Simd/SimdAvx512bw.h index 62a20be362..289e118bfb 100644 --- a/src/Simd/SimdAvx512bw.h +++ b/src/Simd/SimdAvx512bw.h @@ -556,6 +556,9 @@ namespace Simd void SynetNormalizeLayerForwardV2(const float* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst); + void SynetNormalizeLayerForwardV3(const float* src, size_t batch, size_t channels, size_t spatial, + const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst); + void SynetPoolingAverage(const float* src, size_t srcC, size_t srcH, size_t srcW, size_t kernelY, size_t kernelX, size_t strideY, size_t strideX, size_t padY, size_t padX, float* dst, size_t dstH, size_t dstW, SimdBool excludePad, SimdTensorFormatType format); diff --git a/src/Simd/SimdAvx512bwSynetNormalize.cpp b/src/Simd/SimdAvx512bwSynetNormalize.cpp index b5fc9bf2d4..66bdb2d318 100644 --- a/src/Simd/SimdAvx512bwSynetNormalize.cpp +++ b/src/Simd/SimdAvx512bwSynetNormalize.cpp @@ -368,6 +368,168 @@ namespace Simd else assert(0); } + + //------------------------------------------------------------------------------------------------- + + void NormalizeNchwV3(const float* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, float eps, float* dst) + { + float k = 1.0f / float(spatial); + size_t spatialF = AlignLo(spatial, F), s; + __mmask16 spatialM = TailMask16(spatial - spatialF); + for (size_t b = 0; b < batch; ++b) + { + for (size_t c = 0; c < channels; ++c) + { + __m512 _sum = _mm512_setzero_ps(); + for (s = 0; s < spatialF; s += F) + _sum = _mm512_add_ps(_mm512_loadu_ps(src + s), _sum); + if(s < spatial) + _sum = _mm512_add_ps(_mm512_maskz_loadu_ps(spatialM, src + s), _sum); + float sum = ExtractSum(_sum); + __m512 mean = _mm512_set1_ps(sum * k); + for (s = 0; s < spatialF; s += F) + _mm512_storeu_ps(dst + s, _mm512_sub_ps(_mm512_loadu_ps(src + s), mean)); + if (s < spatial) + _mm512_mask_storeu_ps(dst + s, spatialM, _mm512_sub_ps(_mm512_maskz_loadu_ps(spatialM, src + s), mean)); + + __m512 _sqsum = _mm512_setzero_ps(); + for (s = 0; s < spatialF; s += F) + { + __m512 _dst = _mm512_loadu_ps(dst + s); + _sqsum = _mm512_fmadd_ps(_dst, _dst, _sqsum); + } + if (s < spatial) + { + __m512 _dst = _mm512_maskz_loadu_ps(spatialM, dst + s); + _sqsum = _mm512_fmadd_ps(_dst, _dst, _sqsum); + } + float sqsum = ExtractSum(_sqsum); + __m512 norm = _mm512_set1_ps(1.0f / ::sqrt(sqsum * k + eps)); + __m512 _scale = _mm512_set1_ps(scale[c]); + __m512 _shift = _mm512_set1_ps(shift[c]); + for (s = 0; s < spatialF; s += F) + _mm512_storeu_ps(dst + s, _mm512_add_ps(_mm512_mul_ps(_mm512_mul_ps(_mm512_loadu_ps(dst + s), norm), _scale), _shift)); + if (s < spatial) + _mm512_mask_storeu_ps(dst + s, spatialM, _mm512_add_ps(_mm512_mul_ps(_mm512_mul_ps(_mm512_maskz_loadu_ps(spatialM, dst + s), norm), _scale), _shift)); + + dst += spatial; + src += spatial; + } + } + } + + void NormalizeNhwcV3(const float* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, float eps, float* buf, float* dst) + { + float k = 1.0f / float(spatial); + Array32f _buf; + if (buf == NULL) + { + _buf.Resize(spatial); + buf = _buf.data; + } + size_t channelsF = AlignLo(channels, F); + __mmask16 channelsM = TailMask16(channels - channelsF); + __m512 _eps = _mm512_set1_ps(eps), _k = _mm512_set1_ps(k), _1 = _mm512_set1_ps(1.0f); + for (size_t b = 0, c; b < batch; ++b) + { + for (c = 0; c < channelsF; c += F) + _mm512_storeu_ps(buf + c, _mm512_setzero_ps()); + if(c < channels) + _mm512_mask_storeu_ps(buf + c, channelsM, _mm512_setzero_ps()); + for (size_t s = 0; s < spatial; ++s) + { + const float* ps = src + s * channels; + for (c = 0; c < channelsF; c += F) + { + __m512 _src = _mm512_loadu_ps(ps + c); + __m512 _sum = _mm512_loadu_ps(buf + c); + _mm512_storeu_ps(buf + c, _mm512_add_ps(_sum, _src)); + } + if (c < channels) + { + __m512 _src = _mm512_maskz_loadu_ps(channelsM, ps + c); + __m512 _sum = _mm512_maskz_loadu_ps(channelsM, buf + c); + _mm512_mask_storeu_ps(buf + c, channelsM, _mm512_add_ps(_sum, _src)); + } + } + for (c = 0; c < channelsF; c += F) + _mm512_storeu_ps(buf + c, _mm512_mul_ps(_mm512_loadu_ps(buf + c), _k)); + if (c < channels) + _mm512_mask_storeu_ps(buf + c, channelsM, _mm512_mul_ps(_mm512_maskz_loadu_ps(channelsM, buf + c), _k)); + for (size_t s = 0; s < spatial; ++s) + { + const float* ps = src + s * channels; + float* pd = dst + s * channels; + for (c = 0; c < channelsF; c += F) + { + __m512 _src = _mm512_loadu_ps(ps + c); + __m512 mean = _mm512_loadu_ps(buf + c); + _mm512_storeu_ps(pd + c, _mm512_sub_ps(_src, mean)); + } + if (c < channels) + { + __m512 _src = _mm512_maskz_loadu_ps(channelsM, ps + c); + __m512 mean = _mm512_maskz_loadu_ps(channelsM, buf + c); + _mm512_mask_storeu_ps(pd + c, channelsM, _mm512_sub_ps(_src, mean)); + } + } + + for (c = 0; c < channelsF; c += F) + _mm512_storeu_ps(buf + c, _mm512_setzero_ps()); + if (c < channels) + _mm512_mask_storeu_ps(buf + c, channelsM, _mm512_setzero_ps()); + for (size_t s = 0; s < spatial; ++s) + { + const float* pd = dst + s * channels; + for (c = 0; c < channelsF; c += F) + { + __m512 _dst = _mm512_loadu_ps(pd + c); + __m512 _sum = _mm512_loadu_ps(buf + c); + _mm512_storeu_ps(buf + c, _mm512_fmadd_ps(_dst, _dst, _sum)); + } + if (c < channels) + { + __m512 _dst = _mm512_maskz_loadu_ps(channelsM, pd + c); + __m512 _sum = _mm512_maskz_loadu_ps(channelsM, buf + c); + _mm512_mask_storeu_ps(buf + c, channelsM, _mm512_fmadd_ps(_dst, _dst, _sum)); + } + } + for (c = 0; c < channelsF; c += F) + _mm512_storeu_ps(buf + c, _mm512_div_ps(_1, _mm512_sqrt_ps(_mm512_add_ps(_mm512_mul_ps(_mm512_loadu_ps(buf + c), _k), _eps)))); + if (c < channels) + _mm512_mask_storeu_ps(buf + c, channelsM, _mm512_div_ps(_1, _mm512_sqrt_ps(_mm512_add_ps(_mm512_mul_ps(_mm512_maskz_loadu_ps(channelsM, buf + c), _k), _eps)))); + for (size_t s = 0; s < spatial; ++s) + { + float* pd = dst + s * channels; + for (c = 0; c < channelsF; c += F) + { + __m512 _dst = _mm512_loadu_ps(pd + c); + __m512 norm = _mm512_loadu_ps(buf + c); + _mm512_storeu_ps(pd + c, _mm512_add_ps(_mm512_mul_ps(_mm512_mul_ps(_dst, norm), _mm512_loadu_ps(scale + c)), _mm512_loadu_ps(shift + c))); + } + if (c < channels) + { + __m512 _dst = _mm512_maskz_loadu_ps(channelsM, pd + c); + __m512 norm = _mm512_maskz_loadu_ps(channelsM, buf + c); + _mm512_mask_storeu_ps(pd + c, channelsM, _mm512_add_ps(_mm512_mul_ps(_mm512_mul_ps(_dst, norm), _mm512_maskz_loadu_ps(channelsM, scale + c)), _mm512_maskz_loadu_ps(channelsM, shift + c))); + } + } + + src += channels * spatial; + dst += channels * spatial; + } + } + + void SynetNormalizeLayerForwardV3(const float* src, size_t batch, size_t channels, size_t spatial, + const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst) + { + if (format == SimdTensorFormatNchw) + NormalizeNchwV3(src, batch, channels, spatial, scale, shift, *eps, dst); + else if (format == SimdTensorFormatNhwc) + NormalizeNhwcV3(src, batch, channels, spatial, scale, shift, *eps, buf, dst); + else + assert(0); + } } #endif } diff --git a/src/Simd/SimdLib.cpp b/src/Simd/SimdLib.cpp index 5371520109..ad5de4e830 100644 --- a/src/Simd/SimdLib.cpp +++ b/src/Simd/SimdLib.cpp @@ -6590,7 +6590,7 @@ SIMD_API void SimdSynetNormalizeLayerForwardV3(const float* src, size_t batch, s #if defined(SIMD_SYNET_ENABLE) typedef void(*SimdSynetNormalizeLayerForwardV3Ptr) (const float* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst); - const static SimdSynetNormalizeLayerForwardV3Ptr simdSynetNormalizeLayerForwardV3 = SIMD_FUNC2(SynetNormalizeLayerForwardV3, SIMD_AVX2_FUNC, SIMD_SSE41_FUNC);// , SIMD_AVX512BW_FUNC, SIMD_NEON_FUNC); + const static SimdSynetNormalizeLayerForwardV3Ptr simdSynetNormalizeLayerForwardV3 = SIMD_FUNC3(SynetNormalizeLayerForwardV3, SIMD_AVX512BW_FUNC, SIMD_AVX2_FUNC, SIMD_SSE41_FUNC);// , SIMD_NEON_FUNC); simdSynetNormalizeLayerForwardV3(src, batch, channels, spatial, scale, shift, eps, format, buf, dst); #else diff --git a/src/Test/TestSynetNormalize.cpp b/src/Test/TestSynetNormalize.cpp index 2070dd58b6..21fde63a34 100644 --- a/src/Test/TestSynetNormalize.cpp +++ b/src/Test/TestSynetNormalize.cpp @@ -255,10 +255,10 @@ namespace Test result = result && SynetNormalizeLayerForwardV2AutoTest(FUNC_SNLF2(Simd::Avx2::SynetNormalizeLayerForwardV3), FUNC_SNLF2(SimdSynetNormalizeLayerForwardV3)); #endif -//#ifdef SIMD_AVX512BW_ENABLE -// if (Simd::Avx512bw::Enable) -// result = result && SynetNormalizeLayerForwardV2AutoTest(FUNC_SNLF2(Simd::Avx512bw::SynetNormalizeLayerForwardV3), FUNC_SNLF2(SimdSynetNormalizeLayerForwardV3)); -//#endif +#ifdef SIMD_AVX512BW_ENABLE + if (Simd::Avx512bw::Enable) + result = result && SynetNormalizeLayerForwardV2AutoTest(FUNC_SNLF2(Simd::Avx512bw::SynetNormalizeLayerForwardV3), FUNC_SNLF2(SimdSynetNormalizeLayerForwardV3)); +#endif return result; } From 5f3e8c4e31b1e94bfa2fa78160ddef2d3ebb4456 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Fri, 23 Jun 2023 22:47:49 +0300 Subject: [PATCH 27/44] *refactoring of DescrInt. --- src/Simd/SimdAvx2DescrInt.cpp | 40 +++++++-------- src/Simd/SimdAvx512bwDescrInt.cpp | 42 ++++++++-------- src/Simd/SimdBaseDescrInt.cpp | 16 +++--- src/Simd/SimdDescrInt.h | 11 +++-- src/Simd/SimdSse41DescrInt.cpp | 82 ++++++++++++++----------------- 5 files changed, 93 insertions(+), 98 deletions(-) diff --git a/src/Simd/SimdAvx2DescrInt.cpp b/src/Simd/SimdAvx2DescrInt.cpp index 348b568700..d56b63afb1 100644 --- a/src/Simd/SimdAvx2DescrInt.cpp +++ b/src/Simd/SimdAvx2DescrInt.cpp @@ -791,9 +791,9 @@ namespace Simd //------------------------------------------------------------------------------------------------- - template void MicroCosineDistances2x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + template void MicroCosineDistancesDirect2x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); - template<> void MicroCosineDistances2x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect2x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size64 = AlignLo(size, 64), o = 16; __m256i a0, a1, b0; @@ -870,7 +870,7 @@ namespace Simd DecodeCosineDistances2x4(A, B, ab, distances, stride); } - template<> void MicroCosineDistances2x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect2x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; __m256i a0, a1, b0; @@ -928,7 +928,7 @@ namespace Simd DecodeCosineDistances2x4(A, B, ab, distances, stride); } - template<> void MicroCosineDistances2x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect2x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; __m256i a0, a1, b0; @@ -986,7 +986,7 @@ namespace Simd DecodeCosineDistances2x4(A, B, ab, distances, stride); } - template<> void MicroCosineDistances2x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect2x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; __m256i a0, a1, b0; @@ -1044,7 +1044,7 @@ namespace Simd DecodeCosineDistances2x4(A, B, ab, distances, stride); } - template<> void MicroCosineDistances2x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect2x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; __m256i a0, a1, b0; @@ -1102,9 +1102,9 @@ namespace Simd DecodeCosineDistances2x4(A, B, ab, distances, stride); } - template void MicroCosineDistances1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + template void MicroCosineDistancesDirect1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); - template<> void MicroCosineDistances1x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect1x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size64 = AlignLo(size, 64), o = 16; __m256i a0, b0; @@ -1162,7 +1162,7 @@ namespace Simd Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } - template<> void MicroCosineDistances1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; __m256i a0, b0; @@ -1206,7 +1206,7 @@ namespace Simd Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } - template<> void MicroCosineDistances1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; __m256i a0, b0; @@ -1250,7 +1250,7 @@ namespace Simd Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } - template<> void MicroCosineDistances1x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect1x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; __m256i a0, b0; @@ -1294,7 +1294,7 @@ namespace Simd Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } - template<> void MicroCosineDistances1x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect1x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; __m256i a0, b0; @@ -1338,7 +1338,7 @@ namespace Simd Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } - template void MacroCosineDistances(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template void MacroCosineDistancesDirect(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t M2 = AlignLoAny(M, 2); size_t N4 = AlignLoAny(N, 4); @@ -1347,7 +1347,7 @@ namespace Simd { size_t j = 0; for (; j < N4; j += 4) - MicroCosineDistances2x4(A + i, B + j, size, distances + j, stride); + MicroCosineDistancesDirect2x4(A + i, B + j, size, distances + j, stride); for (; j < N; j += 1) { CosineDistance(A[i + 0], B[j], size, distances + j + 0 * stride); @@ -1359,7 +1359,7 @@ namespace Simd { size_t j = 0; for (; j < N4; j += 4) - MicroCosineDistances1x4(A + i, B + j, size, distances + j, stride); + MicroCosineDistancesDirect1x4(A + i, B + j, size, distances + j, stride); for (; j < N; j += 1) CosineDistance(A[i], B[j], size, distances + j); distances += 1 * stride; @@ -1382,7 +1382,7 @@ namespace Simd _decode32f = Decode32f4; _decode16f = Decode16f4; _cosineDistance = Avx2::CosineDistance<4>; - _macroCosineDistances = Avx2::MacroCosineDistances<4>; + _macroCosineDistancesDirect = Avx2::MacroCosineDistancesDirect<4>; break; } case 5: @@ -1392,7 +1392,7 @@ namespace Simd _decode32f = Decode32f5; _decode16f = Decode16f5; _cosineDistance = Avx2::CosineDistance<5>; - _macroCosineDistances = Avx2::MacroCosineDistances<5>; + _macroCosineDistancesDirect = Avx2::MacroCosineDistancesDirect<5>; break; } case 6: @@ -1402,7 +1402,7 @@ namespace Simd _decode32f = Decode32f6; _decode16f = Decode16f6; _cosineDistance = Avx2::CosineDistance<6>; - _macroCosineDistances = Avx2::MacroCosineDistances<6>; + _macroCosineDistancesDirect = Avx2::MacroCosineDistancesDirect<6>; break; } case 7: @@ -1412,7 +1412,7 @@ namespace Simd _decode32f = Decode32f7; _decode16f = Decode16f7; _cosineDistance = Avx2::CosineDistance<7>; - _macroCosineDistances = Avx2::MacroCosineDistances<7>; + _macroCosineDistancesDirect = Avx2::MacroCosineDistancesDirect<7>; break; } case 8: @@ -1422,7 +1422,7 @@ namespace Simd _decode32f = Decode32f8; _decode16f = Decode16f8; _cosineDistance = Avx2::CosineDistance<8>; - _macroCosineDistances = Avx2::MacroCosineDistances<8>; + _macroCosineDistancesDirect = Avx2::MacroCosineDistancesDirect<8>; break; } default: diff --git a/src/Simd/SimdAvx512bwDescrInt.cpp b/src/Simd/SimdAvx512bwDescrInt.cpp index 7c44c1533d..286fa71cbb 100644 --- a/src/Simd/SimdAvx512bwDescrInt.cpp +++ b/src/Simd/SimdAvx512bwDescrInt.cpp @@ -841,9 +841,9 @@ namespace Simd //------------------------------------------------------------------------------------------------- - template void MicroCosineDistances4x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + template void MicroCosineDistancesDirect4x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); - template<> void MicroCosineDistances4x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect4x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size128 = AlignLo(size, 128), o = 16; __m512i a00, a10, a20, a30, a01, a11, a21, a31, b00, b01; @@ -968,7 +968,7 @@ namespace Simd Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); } - template<> void MicroCosineDistances4x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect4x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size32 = AlignLo(size, 32), o = 16; __m512i a0, a1, a2, a3, b0; @@ -1061,7 +1061,7 @@ namespace Simd Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); } - template<> void MicroCosineDistances4x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect4x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size32 = AlignLo(size, 32), o = 16; __m512i a0, a1, a2, a3, b0; @@ -1154,7 +1154,7 @@ namespace Simd Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); } - template<> void MicroCosineDistances4x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect4x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size32 = AlignLo(size, 32), o = 16; __m512i a0, a1, a2, a3, b0; @@ -1247,7 +1247,7 @@ namespace Simd Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); } - template<> void MicroCosineDistances4x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect4x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size32 = AlignLo(size, 32), o = 16; __m512i a0, a1, a2, a3, b0; @@ -1340,9 +1340,9 @@ namespace Simd Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); } - template void MicroCosineDistances1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + template void MicroCosineDistancesDirect1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); - template<> void MicroCosineDistances1x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect1x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size128 = AlignLo(size, 128), o = 16; __m512i a00, a01, b00, b01; @@ -1406,7 +1406,7 @@ namespace Simd Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } - template<> void MicroCosineDistances1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size32 = AlignLo(size, 32), o = 16; __m512i a0, b0; @@ -1451,7 +1451,7 @@ namespace Simd Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } - template<> void MicroCosineDistances1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size32 = AlignLo(size, 32), o = 16; __m512i a0, b0; @@ -1496,7 +1496,7 @@ namespace Simd Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } - template<> void MicroCosineDistances1x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect1x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size32 = AlignLo(size, 32), o = 16; __m512i a0, b0; @@ -1541,7 +1541,7 @@ namespace Simd Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } - template<> void MicroCosineDistances1x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect1x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size32 = AlignLo(size, 32), o = 16; __m512i a0, b0; @@ -1586,7 +1586,7 @@ namespace Simd Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } - template void MacroCosineDistances(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template void MacroCosineDistancesDirect(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t M4 = AlignLoAny(M, 4); size_t N4 = AlignLoAny(N, 4); @@ -1595,7 +1595,7 @@ namespace Simd { size_t j = 0; for (; j < N4; j += 4) - MicroCosineDistances4x4(A + i, B + j, size, distances + j, stride); + MicroCosineDistancesDirect4x4(A + i, B + j, size, distances + j, stride); for (; j < N; j += 1) { CosineDistance(A[i + 0], B[j], size, distances + j + 0 * stride); @@ -1609,7 +1609,7 @@ namespace Simd { size_t j = 0; for (; j < N4; j += 4) - MicroCosineDistances1x4(A + i, B + j, size, distances + j, stride); + MicroCosineDistancesDirect1x4(A + i, B + j, size, distances + j, stride); for (; j < N; j += 1) CosineDistance(A[i], B[j], size, distances + j); distances += 1 * stride; @@ -1632,7 +1632,7 @@ namespace Simd _decode32f = Decode32f4; _decode16f = Decode16f4; _cosineDistance = Avx512bw::CosineDistance<4>; - _macroCosineDistances = Avx512bw::MacroCosineDistances<4>; + _macroCosineDistancesDirect = Avx512bw::MacroCosineDistancesDirect<4>; break; } case 5: @@ -1642,7 +1642,7 @@ namespace Simd _decode32f = Decode32f5; _decode16f = Decode16f5; _cosineDistance = Avx512bw::CosineDistance<5>; - _macroCosineDistances = Avx512bw::MacroCosineDistances<5>; + _macroCosineDistancesDirect = Avx512bw::MacroCosineDistancesDirect<5>; break; } case 6: @@ -1652,7 +1652,7 @@ namespace Simd _decode32f = Decode32f6; _decode16f = Decode16f6; _cosineDistance = Avx512bw::CosineDistance<6>; - _macroCosineDistances = Avx512bw::MacroCosineDistances<6>; + _macroCosineDistancesDirect = Avx512bw::MacroCosineDistancesDirect<6>; break; } case 7: @@ -1662,7 +1662,7 @@ namespace Simd _decode32f = Decode32f7; _decode16f = Decode16f7; _cosineDistance = Avx512bw::CosineDistance<7>; - _macroCosineDistances = Avx512bw::MacroCosineDistances<7>; + _macroCosineDistancesDirect = Avx512bw::MacroCosineDistancesDirect<7>; break; } case 8: @@ -1672,8 +1672,8 @@ namespace Simd _decode32f = Decode32f8; _decode16f = Decode16f8; _cosineDistance = Avx512bw::CosineDistance<8>; - _macroCosineDistances = Avx512bw::MacroCosineDistances<8>; - _microM = 4; + _macroCosineDistancesDirect = Avx512bw::MacroCosineDistancesDirect<8>; + _microMd = 4; break; } default: diff --git a/src/Simd/SimdBaseDescrInt.cpp b/src/Simd/SimdBaseDescrInt.cpp index b643a35bb9..807377e547 100644 --- a/src/Simd/SimdBaseDescrInt.cpp +++ b/src/Simd/SimdBaseDescrInt.cpp @@ -513,7 +513,7 @@ namespace Simd //------------------------------------------------------------------------------------------------- - template void MacroCosineDistances(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template void MacroCosineDistancesDirect(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { for (size_t i = 0; i < M; ++i) { @@ -543,8 +543,8 @@ namespace Simd { _encSize = 16 + DivHi(size * depth, 8); _range = float((1 << _depth) - 1); - _microM = 1; - _microN = 1; + _microMd = 1; + _microNd = 1; _minMax32f = MinMax32f; _minMax16f = MinMax16f; switch (depth) @@ -556,7 +556,7 @@ namespace Simd _decode32f = Decode32f4; _decode16f = Decode16f4; _cosineDistance = Base::CosineDistance<4>; - _macroCosineDistances = Base::MacroCosineDistances<4>; + _macroCosineDistancesDirect = Base::MacroCosineDistancesDirect<4>; break; } case 5: @@ -566,7 +566,7 @@ namespace Simd _decode32f = Decode32f5; _decode16f = Decode16f5; _cosineDistance = Base::CosineDistance<5>; - _macroCosineDistances = Base::MacroCosineDistances<5>; + _macroCosineDistancesDirect = Base::MacroCosineDistancesDirect<5>; break; } case 6: @@ -576,7 +576,7 @@ namespace Simd _decode32f = Decode32f6; _decode16f = Decode16f6; _cosineDistance = Base::CosineDistance<6>; - _macroCosineDistances = Base::MacroCosineDistances<6>; + _macroCosineDistancesDirect = Base::MacroCosineDistancesDirect<6>; break; } case 7: @@ -586,7 +586,7 @@ namespace Simd _decode32f = Decode32f7; _decode16f = Decode16f7; _cosineDistance = Base::CosineDistance<7>; - _macroCosineDistances = Base::MacroCosineDistances<7>; + _macroCosineDistancesDirect = Base::MacroCosineDistancesDirect<7>; break; } case 8: @@ -596,7 +596,7 @@ namespace Simd _decode32f = Decode32f8; _decode16f = Decode16f8; _cosineDistance = Base::CosineDistance<8>; - _macroCosineDistances = Base::MacroCosineDistances<8>; + _macroCosineDistancesDirect = Base::MacroCosineDistancesDirect<8>; break; } default: diff --git a/src/Simd/SimdDescrInt.h b/src/Simd/SimdDescrInt.h index a91bc59eb7..4cb23ae0a7 100644 --- a/src/Simd/SimdDescrInt.h +++ b/src/Simd/SimdDescrInt.h @@ -61,8 +61,7 @@ namespace Simd typedef void (*Decode32fPtr)(const uint8_t * src, float scale, float shift, size_t size, float* dst); typedef void (*Decode16fPtr)(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst); typedef void (*CosineDistancePtr)(const uint8_t* a, const uint8_t* b, size_t size, float* distance); - typedef void (*MacroCosineDistancesPtr)(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); - + typedef void (*MacroCosineDistancesDirectPtr)(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); MinMax32fPtr _minMax32f; MinMax16fPtr _minMax16f; @@ -71,8 +70,8 @@ namespace Simd Decode32fPtr _decode32f; Decode16fPtr _decode16f; CosineDistancePtr _cosineDistance; - MacroCosineDistancesPtr _macroCosineDistances; - size_t _size, _depth, _encSize, _microM, _microN; + MacroCosineDistancesDirectPtr _macroCosineDistancesDirect; + size_t _size, _depth, _encSize, _microMd, _microNd; float _range; }; @@ -91,6 +90,10 @@ namespace Simd virtual void CosineDistancesMxNa(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, float* distances) const; virtual void CosineDistancesMxNp(size_t M, size_t N, const uint8_t* A, const uint8_t* B, float* distances) const; + + protected: + void CosineDistancesDirect(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, float* distances) const; + }; //------------------------------------------------------------------------------------------------- diff --git a/src/Simd/SimdSse41DescrInt.cpp b/src/Simd/SimdSse41DescrInt.cpp index 1ad8d01033..0e88b9e34e 100644 --- a/src/Simd/SimdSse41DescrInt.cpp +++ b/src/Simd/SimdSse41DescrInt.cpp @@ -682,9 +682,9 @@ namespace Simd //------------------------------------------------------------------------------------------------- - template void MicroCosineDistances2x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + template void MicroCosineDistancesDirect2x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); - template<> void MicroCosineDistances2x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect2x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size32 = AlignLo(size, 32), o = 16; __m128i a0, a1, b0; @@ -763,7 +763,7 @@ namespace Simd DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); } - template<> void MicroCosineDistances2x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect2x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; __m128i a00, a01, a10, a11, b00, b01; @@ -843,7 +843,7 @@ namespace Simd DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); } - template<> void MicroCosineDistances2x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect2x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; __m128i a00, a01, a10, a11, b00, b01; @@ -923,7 +923,7 @@ namespace Simd DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); } - template<> void MicroCosineDistances2x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect2x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; __m128i a00, a01, a10, a11, b00, b01; @@ -1003,7 +1003,7 @@ namespace Simd DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); } - template<> void MicroCosineDistances2x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect2x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; __m128i a00, a01, a10, a11, b00, b01; @@ -1083,9 +1083,9 @@ namespace Simd DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); } - template void MicroCosineDistances1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + template void MicroCosineDistancesDirect1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); - template<> void MicroCosineDistances1x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect1x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size32 = AlignLo(size, 32), o = 16; __m128i a0, a1, b0; @@ -1143,7 +1143,7 @@ namespace Simd DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } - template<> void MicroCosineDistances1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; __m128i a00, a01, b00, b01; @@ -1201,7 +1201,7 @@ namespace Simd DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } - template<> void MicroCosineDistances1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; __m128i a00, a01, b00, b01; @@ -1259,7 +1259,7 @@ namespace Simd DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } - template<> void MicroCosineDistances1x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect1x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; __m128i a00, a01, b00, b01; @@ -1317,7 +1317,7 @@ namespace Simd DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } - template<> void MicroCosineDistances1x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template<> void MicroCosineDistancesDirect1x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size16 = AlignLo(size, 16), o = 16; __m128i a00, a01, b00, b01; @@ -1375,7 +1375,7 @@ namespace Simd DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); } - template void MacroCosineDistances(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + template void MacroCosineDistancesDirect(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t M2 = AlignLoAny(M, 2); size_t N4 = AlignLo(N, 4); @@ -1384,7 +1384,7 @@ namespace Simd { size_t j = 0; for (; j < N4; j += 4) - MicroCosineDistances2x4(A + i, B + j, size, distances + j, stride); + MicroCosineDistancesDirect2x4(A + i, B + j, size, distances + j, stride); for (; j < N; j += 1) { CosineDistance(A[i + 0], B[j], size, distances + j + 0 * stride); @@ -1396,7 +1396,7 @@ namespace Simd { size_t j = 0; for (; j < N4; j += 4) - MicroCosineDistances1x4(A + i, B + j, size, distances + j, stride); + MicroCosineDistancesDirect1x4(A + i, B + j, size, distances + j, stride); for (; j < N; j += 1) CosineDistance(A[i], B[j], size, distances + j); distances += 1 * stride; @@ -1410,8 +1410,8 @@ namespace Simd { _minMax32f = MinMax32f; _minMax16f = MinMax16f; - _microM = 2; - _microN = 4; + _microMd = 2; + _microNd = 4; switch (depth) { case 4: @@ -1421,7 +1421,7 @@ namespace Simd _decode32f = Decode32f4; _decode16f = Decode16f4; _cosineDistance = Sse41::CosineDistance<4>; - _macroCosineDistances = Sse41::MacroCosineDistances<4>; + _macroCosineDistancesDirect = Sse41::MacroCosineDistancesDirect<4>; break; } case 5: @@ -1431,7 +1431,7 @@ namespace Simd _decode32f = Decode32f5; _decode16f = Decode16f5; _cosineDistance = Sse41::CosineDistance<5>; - _macroCosineDistances = Sse41::MacroCosineDistances<5>; + _macroCosineDistancesDirect = Sse41::MacroCosineDistancesDirect<5>; break; } case 6: @@ -1441,7 +1441,7 @@ namespace Simd _decode32f = Decode32f6; _decode16f = Decode16f6; _cosineDistance = Sse41::CosineDistance<6>; - _macroCosineDistances = Sse41::MacroCosineDistances<6>; + _macroCosineDistancesDirect = Sse41::MacroCosineDistancesDirect<6>; break; } case 7: @@ -1451,7 +1451,7 @@ namespace Simd _decode32f = Decode32f7; _decode16f = Decode16f7; _cosineDistance = Sse41::CosineDistance<7>; - _macroCosineDistances = Sse41::MacroCosineDistances<7>; + _macroCosineDistancesDirect = Sse41::MacroCosineDistancesDirect<7>; break; } case 8: @@ -1461,7 +1461,7 @@ namespace Simd _decode32f = Decode32f8; _decode16f = Decode16f8; _cosineDistance = Sse41::CosineDistance<8>; - _macroCosineDistances = Sse41::MacroCosineDistances<8>; + _macroCosineDistancesDirect = Sse41::MacroCosineDistancesDirect<8>; break; } default: @@ -1471,40 +1471,32 @@ namespace Simd void DescrInt::CosineDistancesMxNa(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, float* distances) const { - const size_t L2 = Base::AlgCacheL2(); - size_t mN = AlignLoAny(L2 / _encSize, _microN); - size_t mM = AlignLoAny(L2 / _encSize, _microM); - for (size_t i = 0; i < M; i += mM) - { - size_t dM = Simd::Min(M, i + mM) - i; - for (size_t j = 0; j < N; j += mN) - { - size_t dN = Simd::Min(N, j + mN) - j; - _macroCosineDistances(dM, dN, A + i, B + j, _size, distances + i * N + j, N); - } - } + CosineDistancesDirect(M, N, A, B, distances); } void DescrInt::CosineDistancesMxNp(size_t M, size_t N, const uint8_t* A, const uint8_t* B, float* distances) const + { + Array8ucp a(M); + for (size_t i = 0; i < M; ++i) + a[i] = A + i * _encSize; + Array8ucp b(N); + for (size_t j = 0; j < N; ++j) + b[j] = B + j * _encSize; + CosineDistancesMxNa(M, N, a.data, b.data, distances); + } + + void DescrInt::CosineDistancesDirect(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, float* distances) const { const size_t L2 = Base::AlgCacheL2(); - size_t mN = AlignLoAny(L2 / _encSize, _microN); - size_t mM = AlignLoAny(L2 / _encSize, _microM); - Array8ucp ap(mM), bp(N); + size_t mN = AlignLoAny(L2 / _encSize, _microNd); + size_t mM = AlignLoAny(L2 / _encSize, _microMd); for (size_t i = 0; i < M; i += mM) { size_t dM = Simd::Min(M, i + mM) - i; - for (size_t k = 0; k < dM; ++k) - ap[k] = A + (i + k) * _encSize; for (size_t j = 0; j < N; j += mN) { size_t dN = Simd::Min(N, j + mN) - j; - if (i == 0) - { - for (size_t k = j, n = j + dN; k < n; ++k) - bp[k] = B + k * _encSize; - } - _macroCosineDistances(dM, dN, ap.data, bp.data + j, _size, distances + i * N + j, N); + _macroCosineDistancesDirect(dM, dN, A + i, B + j, _size, distances + i * N + j, N); } } } From 307db9895d5e8f92765db631c0cdc814c85dec89 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Mon, 26 Jun 2023 18:11:38 +0300 Subject: [PATCH 28/44] +add Sse41::DescrInt::CosineDistancesUnpack (part 1). --- src/Simd/SimdBaseDescrInt.cpp | 7 ---- src/Simd/SimdDescrInt.h | 17 ++++++-- src/Simd/SimdSse41DescrInt.cpp | 71 +++++++++++++++++++++++++++++++++- 3 files changed, 84 insertions(+), 11 deletions(-) diff --git a/src/Simd/SimdBaseDescrInt.cpp b/src/Simd/SimdBaseDescrInt.cpp index 807377e547..99a232eb6b 100644 --- a/src/Simd/SimdBaseDescrInt.cpp +++ b/src/Simd/SimdBaseDescrInt.cpp @@ -543,8 +543,6 @@ namespace Simd { _encSize = 16 + DivHi(size * depth, 8); _range = float((1 << _depth) - 1); - _microMd = 1; - _microNd = 1; _minMax32f = MinMax32f; _minMax16f = MinMax16f; switch (depth) @@ -556,7 +554,6 @@ namespace Simd _decode32f = Decode32f4; _decode16f = Decode16f4; _cosineDistance = Base::CosineDistance<4>; - _macroCosineDistancesDirect = Base::MacroCosineDistancesDirect<4>; break; } case 5: @@ -566,7 +563,6 @@ namespace Simd _decode32f = Decode32f5; _decode16f = Decode16f5; _cosineDistance = Base::CosineDistance<5>; - _macroCosineDistancesDirect = Base::MacroCosineDistancesDirect<5>; break; } case 6: @@ -576,7 +572,6 @@ namespace Simd _decode32f = Decode32f6; _decode16f = Decode16f6; _cosineDistance = Base::CosineDistance<6>; - _macroCosineDistancesDirect = Base::MacroCosineDistancesDirect<6>; break; } case 7: @@ -586,7 +581,6 @@ namespace Simd _decode32f = Decode32f7; _decode16f = Decode16f7; _cosineDistance = Base::CosineDistance<7>; - _macroCosineDistancesDirect = Base::MacroCosineDistancesDirect<7>; break; } case 8: @@ -596,7 +590,6 @@ namespace Simd _decode32f = Decode32f8; _decode16f = Decode16f8; _cosineDistance = Base::CosineDistance<8>; - _macroCosineDistancesDirect = Base::MacroCosineDistancesDirect<8>; break; } default: diff --git a/src/Simd/SimdDescrInt.h b/src/Simd/SimdDescrInt.h index 4cb23ae0a7..eaf3b8a435 100644 --- a/src/Simd/SimdDescrInt.h +++ b/src/Simd/SimdDescrInt.h @@ -61,7 +61,6 @@ namespace Simd typedef void (*Decode32fPtr)(const uint8_t * src, float scale, float shift, size_t size, float* dst); typedef void (*Decode16fPtr)(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst); typedef void (*CosineDistancePtr)(const uint8_t* a, const uint8_t* b, size_t size, float* distance); - typedef void (*MacroCosineDistancesDirectPtr)(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); MinMax32fPtr _minMax32f; MinMax16fPtr _minMax16f; @@ -70,8 +69,7 @@ namespace Simd Decode32fPtr _decode32f; Decode16fPtr _decode16f; CosineDistancePtr _cosineDistance; - MacroCosineDistancesDirectPtr _macroCosineDistancesDirect; - size_t _size, _depth, _encSize, _microMd, _microNd; + size_t _size, _depth, _encSize; float _range; }; @@ -94,6 +92,19 @@ namespace Simd protected: void CosineDistancesDirect(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, float* distances) const; + typedef void (*MacroCosineDistancesDirectPtr)(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + MacroCosineDistancesDirectPtr _macroCosineDistancesDirect; + size_t _microMd, _microNd; + + void CosineDistancesUnpack(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, float* distances) const; + + typedef void (*UnpackNormPtr)(size_t count, const uint8_t* const* src, float* dst, size_t stride); + typedef void (*UnpackDataPtr)(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride); + typedef void (*MacroCosineDistancesUnpackPtr)(size_t M, size_t N, const uint8_t* ad, const float * an, const uint8_t* bd, const float* bn, size_t size, float* distances, size_t stride); + UnpackNormPtr _unpackNormA, _unpackNormB; + UnpackDataPtr _unpackDataA, _unpackDataB; + MacroCosineDistancesUnpackPtr _macroCosineDistancesUnpack; + size_t _microMu, _microNu, _unpSize; }; //------------------------------------------------------------------------------------------------- diff --git a/src/Simd/SimdSse41DescrInt.cpp b/src/Simd/SimdSse41DescrInt.cpp index 0e88b9e34e..0447ddd5a3 100644 --- a/src/Simd/SimdSse41DescrInt.cpp +++ b/src/Simd/SimdSse41DescrInt.cpp @@ -1405,13 +1405,56 @@ namespace Simd //------------------------------------------------------------------------------------------------- + static void UnpackNormA(size_t count, const uint8_t* const* src, float* dst, size_t stride) + { + for (size_t i = 0; i < count; ++i) + _mm_storeu_si128((__m128i*)dst + i, _mm_loadu_si128((__m128i*)src[i])); + } + + //------------------------------------------------------------------------------------------------- + + + static void UnpackNormB(size_t count, const uint8_t* const* src, float* dst, size_t stride) + { + size_t count4 = AlignLo(count, 4), i = 0; + for (; i < count4; i += 4, src += 4, dst += 4) + { + __m128 s0 = _mm_loadu_ps((float*)src[0]); + __m128 s1 = _mm_loadu_ps((float*)src[1]); + __m128 s2 = _mm_loadu_ps((float*)src[2]); + __m128 s3 = _mm_loadu_ps((float*)src[3]); + __m128 s00 = _mm_unpacklo_ps(s0, s2); + __m128 s01 = _mm_unpacklo_ps(s1, s3); + __m128 s10 = _mm_unpackhi_ps(s0, s2); + __m128 s11 = _mm_unpackhi_ps(s1, s3); + _mm_storeu_ps(dst + 0 * stride, _mm_unpacklo_ps(s00, s01)); + _mm_storeu_ps(dst + 1 * stride, _mm_unpackhi_ps(s00, s01)); + _mm_storeu_ps(dst + 2 * stride, _mm_unpacklo_ps(s10, s11)); + _mm_storeu_ps(dst + 3 * stride, _mm_unpackhi_ps(s10, s11)); + } + for (; i < count; i++, src++, dst++) + { + dst[0 * stride] = ((float*)src)[0]; + dst[1 * stride] = ((float*)src)[1]; + dst[2 * stride] = ((float*)src)[2]; + dst[3 * stride] = ((float*)src)[3]; + } + } + + //------------------------------------------------------------------------------------------------- + DescrInt::DescrInt(size_t size, size_t depth) : Base::DescrInt(size, depth) { _minMax32f = MinMax32f; _minMax16f = MinMax16f; + _unpackNormA = UnpackNormA; + _unpackNormB = UnpackNormB; _microMd = 2; _microNd = 4; + _unpSize = _size * (_depth == 8 ? 2 : 1); + _microMu = 5; + _microNu = 8; switch (depth) { case 4: @@ -1471,7 +1514,10 @@ namespace Simd void DescrInt::CosineDistancesMxNa(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, float* distances) const { - CosineDistancesDirect(M, N, A, B, distances); + if(_unpSize * _microNu > Base::AlgCacheL1() || N * 2 < _microNu || 1) + CosineDistancesDirect(M, N, A, B, distances); + else + CosineDistancesUnpack(M, N, A, B, distances); } void DescrInt::CosineDistancesMxNp(size_t M, size_t N, const uint8_t* A, const uint8_t* B, float* distances) const @@ -1501,6 +1547,29 @@ namespace Simd } } + void DescrInt::CosineDistancesUnpack(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, float* distances) const + { + size_t macroM = AlignLoAny(Base::AlgCacheL2() / _unpSize, _microMu); + size_t macroN = AlignLoAny(Base::AlgCacheL3() / _unpSize, _microNu); + Array8u dA(Min(macroM, M) * _unpSize); + Array8u dB(Min(macroN, N) * _unpSize); + Array32f nA(Min(macroM, M) * 4); + Array32f nB(AlignHi(Min(macroN, N), _microNu) * 4); + for (size_t i = 0; i < M; i += macroM) + { + size_t dM = Simd::Min(M, i + macroM) - i; + _unpackNormA(dM, A + i, nA.data, 1); + //_unpackDataA(dM, A + i, _size, dA.data, 1); + for (size_t j = 0; j < N; j += macroN) + { + size_t dN = Simd::Min(N, j + macroN) - j; + _unpackNormB(dN, B + j, nB.data, dN); + //_unpackDataB(dN, B + j, _size, dB.data, _microNu); + //_macroCosineDistancesUnpack(dM, dN, dA.data, nA.data, dB.data, nB.data, _size, distances + i * N + j, N); + } + } + } + //------------------------------------------------------------------------------------------------- void* DescrIntInit(size_t size, size_t depth) From 0649860a7d3afcc2203b23e0ed448b3664dcf951 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Tue, 27 Jun 2023 10:03:59 +0300 Subject: [PATCH 29/44] +add Sse41::DescrInt::CosineDistancesUnpack (part 2, depth=8). --- src/Simd/SimdDescrInt.h | 2 +- src/Simd/SimdDescrIntCommon.h | 24 ++++ src/Simd/SimdSse41DescrInt.cpp | 214 ++++++++++++++++++++++++++++++++- 3 files changed, 235 insertions(+), 5 deletions(-) diff --git a/src/Simd/SimdDescrInt.h b/src/Simd/SimdDescrInt.h index eaf3b8a435..d73f525f4f 100644 --- a/src/Simd/SimdDescrInt.h +++ b/src/Simd/SimdDescrInt.h @@ -100,7 +100,7 @@ namespace Simd typedef void (*UnpackNormPtr)(size_t count, const uint8_t* const* src, float* dst, size_t stride); typedef void (*UnpackDataPtr)(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride); - typedef void (*MacroCosineDistancesUnpackPtr)(size_t M, size_t N, const uint8_t* ad, const float * an, const uint8_t* bd, const float* bn, size_t size, float* distances, size_t stride); + typedef void (*MacroCosineDistancesUnpackPtr)(size_t M, size_t N, size_t K, const uint8_t* ad, const float * an, const uint8_t* bd, const float* bn, float* distances, size_t stride); UnpackNormPtr _unpackNormA, _unpackNormB; UnpackDataPtr _unpackDataA, _unpackDataB; MacroCosineDistancesUnpackPtr _macroCosineDistancesUnpack; diff --git a/src/Simd/SimdDescrIntCommon.h b/src/Simd/SimdDescrIntCommon.h index a53751fd47..050ac789ff 100644 --- a/src/Simd/SimdDescrIntCommon.h +++ b/src/Simd/SimdDescrIntCommon.h @@ -111,6 +111,30 @@ namespace Simd _mm_storeu_ps(distances, _mm_min_ps(_mm_max_ps(_mm_sub_ps(_mm_set1_ps(1.0f), _mm_div_ps(ab, _mm_mul_ps(aNorm, bNorm))), _mm_setzero_ps()), _mm_set1_ps(2.0f))); } + + SIMD_INLINE void DecodeCosineDistances1x4(const float* a, const float *b, size_t stride, __m128i abSum, float* distances) + { + __m128 aScale = _mm_set1_ps(a[0]); + __m128 aShift = _mm_set1_ps(a[1]); + __m128 aMean = _mm_set1_ps(a[2]); + __m128 aNorm = _mm_set1_ps(a[3]); + __m128 bScale = _mm_loadu_ps(b + 0 * stride); + __m128 bShift = _mm_loadu_ps(b + 1 * stride); + __m128 bMean = _mm_loadu_ps(b + 2 * stride); + __m128 bNorm = _mm_loadu_ps(b + 3 * stride); + __m128 ab = _mm_mul_ps(_mm_cvtepi32_ps(abSum), _mm_mul_ps(aScale, bScale)); + ab = _mm_add_ps(_mm_mul_ps(aMean, bShift), ab); + ab = _mm_add_ps(_mm_mul_ps(bMean, aShift), ab); + _mm_storeu_ps(distances, _mm_min_ps(_mm_max_ps(_mm_sub_ps(_mm_set1_ps(1.0f), _mm_div_ps(ab, _mm_mul_ps(aNorm, bNorm))), _mm_setzero_ps()), _mm_set1_ps(2.0f))); + } + + SIMD_INLINE void DecodeCosineDistances1x4(const float* a, const float* b, size_t stride, __m128i abSum, float* distances, size_t N) + { + float d[4]; + DecodeCosineDistances1x4(a, b, stride, abSum, d); + for (size_t i = 0; i < N; ++i) + distances[i] = d[i]; + } } #endif diff --git a/src/Simd/SimdSse41DescrInt.cpp b/src/Simd/SimdSse41DescrInt.cpp index 0447ddd5a3..c55bbb966a 100644 --- a/src/Simd/SimdSse41DescrInt.cpp +++ b/src/Simd/SimdSse41DescrInt.cpp @@ -1088,7 +1088,7 @@ namespace Simd template<> void MicroCosineDistancesDirect1x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) { size_t i = 0, size32 = AlignLo(size, 32), o = 16; - __m128i a0, a1, b0; + __m128i a0, b0; __m128i ab00 = _mm_setzero_si128(); __m128i ab01 = _mm_setzero_si128(); __m128i ab02 = _mm_setzero_si128(); @@ -1443,6 +1443,209 @@ namespace Simd //------------------------------------------------------------------------------------------------- + static void UnpackDataA8(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride) + { + size_t size16 = AlignLo(size, 16); + for (size_t i = 0, j; i < count; i++) + { + const uint8_t* ps = src[i] + 16; + uint16_t* pd = (uint16_t*)dst + i * size; + for (j = 0; j < size16; j += 16, ps += 16, pd += 16) + { + __m128i s = _mm_loadu_si128((__m128i*)ps); + _mm_storeu_si128((__m128i*)pd + 0, UnpackU8<0>(s)); + _mm_storeu_si128((__m128i*)pd + 1, UnpackU8<1>(s)); + } + for (; j < size; j += 8, ps += 8, pd += 8) + { + __m128i s = _mm_loadl_epi64((__m128i*)ps); + _mm_storeu_si128((__m128i*)pd, UnpackU8<0>(s)); + } + } + } + + //------------------------------------------------------------------------------------------------- + + SIMD_INLINE void UnpackDataB8x4(const uint8_t* const* src, size_t offset, uint8_t* dst) + { + __m128i a0 = UnpackU8<0>(_mm_loadl_epi64((__m128i*)(src[0] + offset))); + __m128i a1 = UnpackU8<0>(_mm_loadl_epi64((__m128i*)(src[1] + offset))); + __m128i a2 = UnpackU8<0>(_mm_loadl_epi64((__m128i*)(src[2] + offset))); + __m128i a3 = UnpackU8<0>(_mm_loadl_epi64((__m128i*)(src[3] + offset))); + __m128i b0 = _mm_unpacklo_epi32(a0, a2); + __m128i b1 = _mm_unpacklo_epi32(a1, a3); + __m128i b2 = _mm_unpackhi_epi32(a0, a2); + __m128i b3 = _mm_unpackhi_epi32(a1, a3); + _mm_storeu_si128((__m128i*)dst + 0, _mm_unpacklo_epi32(b0, b1)); + _mm_storeu_si128((__m128i*)dst + 2, _mm_unpackhi_epi32(b0, b1)); + _mm_storeu_si128((__m128i*)dst + 4, _mm_unpacklo_epi32(b2, b3)); + _mm_storeu_si128((__m128i*)dst + 6, _mm_unpackhi_epi32(b2, b3)); + } + + static void UnpackDataB8(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride) + { + size_t count8 = AlignLo(count, 8), i; + for (i = 0, size += 16; i < count8; i += 8, src += 8) + { + for (size_t j = 16; j < size; j += 8, dst += 8 * A) + { + UnpackDataB8x4(src + 0, j, dst + 0); + UnpackDataB8x4(src + 4, j, dst + A); + } + } + if (i < count) + { + const uint8_t* _src[8]; + for (size_t j = 0; j < 8; i++, j++) + _src[j] = i < count ? *src++ : src[-1]; + for (size_t j = 16; j < size; j += 8, dst += 8 * A) + { + UnpackDataB8x4(_src + 0, j, dst + 0); + UnpackDataB8x4(_src + 4, j, dst + A); + } + } + } + + //------------------------------------------------------------------------------------------------- + + SIMD_INLINE __m128i Set2(const int16_t* src) + { + return _mm_set1_epi32(*(int32_t*)src); + } + + SIMD_INLINE void Madd2(__m128i& ab, __m128i a, __m128i b) + { + ab = _mm_add_epi32(ab, _mm_madd_epi16(a, b)); + } + + template void Correlation16_2xM(size_t N, size_t K, const int16_t* ad0, const int16_t* bd, const float *an, const float *bn, size_t bnStride, float* distances, size_t stride) + { + __m128i ab00, ab01, ab10, ab11, ab20, ab21, ab30, ab31, ab40, ab41, ab50, ab51, a0, b0, b1; + const int16_t* ad1 = ad0 + 1 * K; + const int16_t* ad2 = ad0 + 2 * K; + const int16_t* ad3 = ad0 + 3 * K; + const int16_t* ad4 = ad0 + 4 * K; + const int16_t* ad5 = ad0 + 5 * K; + if (N > 4) + { + if (M > 0) ab00 = _mm_setzero_si128(), ab01 = _mm_setzero_si128(); + if (M > 1) ab10 = _mm_setzero_si128(), ab11 = _mm_setzero_si128(); + if (M > 2) ab20 = _mm_setzero_si128(), ab21 = _mm_setzero_si128(); + if (M > 3) ab30 = _mm_setzero_si128(), ab31 = _mm_setzero_si128(); + if (M > 4) ab40 = _mm_setzero_si128(), ab41 = _mm_setzero_si128(); + if (M > 5) ab50 = _mm_setzero_si128(), ab51 = _mm_setzero_si128(); + for (size_t k = 0; k < K; k += 2) + { + b0 = _mm_loadu_si128((__m128i*)bd + 0); + b1 = _mm_loadu_si128((__m128i*)bd + 1); + if (M > 0) a0 = Set2(ad0 + k), Madd2(ab00, a0, b0), Madd2(ab01, a0, b1); + if (M > 1) a0 = Set2(ad1 + k), Madd2(ab10, a0, b0), Madd2(ab11, a0, b1); + if (M > 2) a0 = Set2(ad2 + k), Madd2(ab20, a0, b0), Madd2(ab21, a0, b1); + if (M > 3) a0 = Set2(ad3 + k), Madd2(ab30, a0, b0), Madd2(ab31, a0, b1); + if (M > 4) a0 = Set2(ad4 + k), Madd2(ab40, a0, b0), Madd2(ab41, a0, b1); + if (M > 5) a0 = Set2(ad5 + k), Madd2(ab50, a0, b0), Madd2(ab51, a0, b1); + bd += 16; + } + if (N == 8) + { + if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab01, distances + 4), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab11, distances + 4), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab21, distances + 4), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab31, distances + 4), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab41, distances + 4), an += 4, distances += stride; + if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab51, distances + 4), an += 4, distances += stride; + } + else + { + if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab01, distances + 4, N - 4), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab11, distances + 4, N - 4), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab21, distances + 4, N - 4), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab31, distances + 4, N - 4), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab41, distances + 4, N - 4), an += 4, distances += stride; + if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab51, distances + 4, N - 4), an += 4, distances += stride; + } + } + else + { + if (M > 0) ab00 = _mm_setzero_si128(); + if (M > 1) ab10 = _mm_setzero_si128(); + if (M > 2) ab20 = _mm_setzero_si128(); + if (M > 3) ab30 = _mm_setzero_si128(); + if (M > 4) ab40 = _mm_setzero_si128(); + if (M > 5) ab50 = _mm_setzero_si128(); + for (size_t k = 0; k < K; k += 2) + { + b0 = _mm_loadu_si128((__m128i*)bd + 0); + if (M > 0) a0 = Set2(ad0 + k), Madd2(ab00, a0, b0); + if (M > 1) a0 = Set2(ad1 + k), Madd2(ab10, a0, b0); + if (M > 2) a0 = Set2(ad2 + k), Madd2(ab20, a0, b0); + if (M > 3) a0 = Set2(ad3 + k), Madd2(ab30, a0, b0); + if (M > 4) a0 = Set2(ad4 + k), Madd2(ab40, a0, b0); + if (M > 5) a0 = Set2(ad5 + k), Madd2(ab50, a0, b0); + bd += 16; + } + if (N == 4) + { + if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), an += 4, distances += stride; + if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0), an += 4, distances += stride; + } + else + { + if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0, N), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0, N), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0, N), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0, N), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0, N), an += 4, distances += stride; + if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0, N), an += 4, distances += stride; + } + } + } + + typedef void(*Correlation16_2xM_Ptr)(size_t N, size_t K, const int16_t* ad0, const int16_t* bd, const float* an, const float* bn, size_t bnStride, float* distances, size_t stride); + + SIMD_INLINE Correlation16_2xM_Ptr GetCorrelation16_2xM(size_t M) + { + switch (M) + { + case 0: return NULL; + case 1: return Correlation16_2xM<1>; + case 2: return Correlation16_2xM<2>; + case 3: return Correlation16_2xM<3>; + case 4: return Correlation16_2xM<4>; + case 5: return Correlation16_2xM<5>; + case 6: return Correlation16_2xM<6>; + } + assert(0); + return NULL; + } + + void MacroCorrelation16(size_t M, size_t N, size_t K, const uint8_t* ad, const float* an, const uint8_t* bd, const float* bn, float* distances, size_t stride) + { + size_t M6 = AlignLoAny(M, 6); + Correlation16_2xM_Ptr correlation_2x6 = GetCorrelation16_2xM(6); + Correlation16_2xM_Ptr correlation_2xT = GetCorrelation16_2xM(M - M6); + const int16_t* a = (int16_t*)ad; + const int16_t* b = (int16_t*)bd; + for (size_t j = 0; j < N; j += 8) + { + size_t dN = Simd::Min(8, N - j); + size_t i = 0; + for (; i < M6; i += 6) + correlation_2x6(dN, K, a + i * K, b, an + i * 4, bn, N, distances + i * stride, stride); + if(i < M) + correlation_2xT(dN, K, a + i * K, b, an + i * 4, bn, N, distances + i * stride, stride); + b += K * 8; + bn += 8; + distances += 8; + } + } + + //------------------------------------------------------------------------------------------------- + DescrInt::DescrInt(size_t size, size_t depth) : Base::DescrInt(size, depth) { @@ -1505,6 +1708,9 @@ namespace Simd _decode16f = Decode16f8; _cosineDistance = Sse41::CosineDistance<8>; _macroCosineDistancesDirect = Sse41::MacroCosineDistancesDirect<8>; + _unpackDataA = UnpackDataA8; + _unpackDataB = UnpackDataB8; + _macroCosineDistancesUnpack = MacroCorrelation16; break; } default: @@ -1559,13 +1765,13 @@ namespace Simd { size_t dM = Simd::Min(M, i + macroM) - i; _unpackNormA(dM, A + i, nA.data, 1); - //_unpackDataA(dM, A + i, _size, dA.data, 1); + _unpackDataA(dM, A + i, _size, dA.data, _unpSize); for (size_t j = 0; j < N; j += macroN) { size_t dN = Simd::Min(N, j + macroN) - j; _unpackNormB(dN, B + j, nB.data, dN); - //_unpackDataB(dN, B + j, _size, dB.data, _microNu); - //_macroCosineDistancesUnpack(dM, dN, dA.data, nA.data, dB.data, nB.data, _size, distances + i * N + j, N); + _unpackDataB(dN, B + j, _size, dB.data, 1); + _macroCosineDistancesUnpack(dM, dN, _size, dA.data, nA.data, dB.data, nB.data, distances + i * N + j, N); } } } From 12812098f09642c9d9155b52ddc13e2f64bbd53f Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Tue, 27 Jun 2023 12:18:44 +0300 Subject: [PATCH 30/44] *improve SSE4.1 optimizations of functions DescrIntCosineDistancesMxNp, DescrIntCosineDistancesMxNa for 7-bit depth. --- docs/2023.html | 5 + prj/vs2019/Sse41.vcxproj | 4 + prj/vs2019/Sse41.vcxproj.filters | 12 + prj/vs2022/Sse41.vcxproj | 4 + prj/vs2022/Sse41.vcxproj.filters | 12 + src/Simd/SimdDescrInt.h | 31 +- src/Simd/SimdSse41DescrInt.cpp | 1613 +---------------------------- src/Simd/SimdSse41DescrIntDec.cpp | 232 +++++ src/Simd/SimdSse41DescrIntEnc.cpp | 382 +++++++ src/Simd/SimdSse41DescrIntScd.cpp | 916 ++++++++++++++++ src/Simd/SimdSse41DescrIntScu.cpp | 468 +++++++++ 11 files changed, 2075 insertions(+), 1604 deletions(-) create mode 100644 src/Simd/SimdSse41DescrIntDec.cpp create mode 100644 src/Simd/SimdSse41DescrIntEnc.cpp create mode 100644 src/Simd/SimdSse41DescrIntScd.cpp create mode 100644 src/Simd/SimdSse41DescrIntScu.cpp diff --git a/docs/2023.html b/docs/2023.html index c3b284b54c..00e519968e 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -48,6 +48,11 @@
        New features
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistancesMxNa.
      • Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function SynetNormalizeLayerForwardV3.
      +
      Improving
      +
        +
      • SSE4.1 optimizations of function DescrIntCosineDistancesMxNp for 7-bit depth.
      • +
      • SSE4.1 optimizations of function DescrIntCosineDistancesMxNa for 7-bit depth.
      • +
      Bug fixing
      • Compiler error in file SimdYuvToBgr.h.
      • diff --git a/prj/vs2019/Sse41.vcxproj b/prj/vs2019/Sse41.vcxproj index 83398668ff..a181de12c6 100644 --- a/prj/vs2019/Sse41.vcxproj +++ b/prj/vs2019/Sse41.vcxproj @@ -35,6 +35,10 @@ + + + + diff --git a/prj/vs2019/Sse41.vcxproj.filters b/prj/vs2019/Sse41.vcxproj.filters index ba52d13052..4e224f91b1 100644 --- a/prj/vs2019/Sse41.vcxproj.filters +++ b/prj/vs2019/Sse41.vcxproj.filters @@ -376,6 +376,18 @@ Sse41 + + Sse41 + + + Sse41 + + + Sse41 + + + Sse41 + diff --git a/prj/vs2022/Sse41.vcxproj b/prj/vs2022/Sse41.vcxproj index 83398668ff..a181de12c6 100644 --- a/prj/vs2022/Sse41.vcxproj +++ b/prj/vs2022/Sse41.vcxproj @@ -35,6 +35,10 @@ + + + + diff --git a/prj/vs2022/Sse41.vcxproj.filters b/prj/vs2022/Sse41.vcxproj.filters index ba52d13052..4e224f91b1 100644 --- a/prj/vs2022/Sse41.vcxproj.filters +++ b/prj/vs2022/Sse41.vcxproj.filters @@ -376,6 +376,18 @@ Sse41 + + Sse41 + + + Sse41 + + + Sse41 + + + Sse41 + diff --git a/src/Simd/SimdDescrInt.h b/src/Simd/SimdDescrInt.h index d73f525f4f..6f0ecdd048 100644 --- a/src/Simd/SimdDescrInt.h +++ b/src/Simd/SimdDescrInt.h @@ -53,15 +53,16 @@ namespace Simd void VectorNorm(const uint8_t* a, float* norm) const; - protected: - typedef void (*MinMax32fPtr)(const float* src, size_t size, float &min, float &max); - typedef void (*MinMax16fPtr)(const uint16_t* src, size_t size, float& min, float& max); typedef void (*Encode32fPtr)(const float* src, float scale, float min, size_t size, int32_t &sum, int32_t& sqsum, uint8_t* dst); typedef void (*Encode16fPtr)(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst); typedef void (*Decode32fPtr)(const uint8_t * src, float scale, float shift, size_t size, float* dst); typedef void (*Decode16fPtr)(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst); typedef void (*CosineDistancePtr)(const uint8_t* a, const uint8_t* b, size_t size, float* distance); + protected: + typedef void (*MinMax32fPtr)(const float* src, size_t size, float &min, float &max); + typedef void (*MinMax16fPtr)(const uint16_t* src, size_t size, float& min, float& max); + MinMax32fPtr _minMax32f; MinMax16fPtr _minMax16f; Encode32fPtr _encode32f; @@ -89,18 +90,20 @@ namespace Simd virtual void CosineDistancesMxNa(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, float* distances) const; virtual void CosineDistancesMxNp(size_t M, size_t N, const uint8_t* A, const uint8_t* B, float* distances) const; + typedef void (*MacroCosineDistancesDirectPtr)(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + typedef void (*UnpackDataPtr)(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride); + typedef void (*MacroCosineDistancesUnpackPtr)(size_t M, size_t N, size_t K, const uint8_t* ad, const float * an, const uint8_t* bd, const float* bn, float* distances, size_t stride); + protected: + typedef void (*UnpackNormPtr)(size_t count, const uint8_t* const* src, float* dst, size_t stride); + void CosineDistancesDirect(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, float* distances) const; - typedef void (*MacroCosineDistancesDirectPtr)(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); MacroCosineDistancesDirectPtr _macroCosineDistancesDirect; size_t _microMd, _microNd; void CosineDistancesUnpack(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, float* distances) const; - typedef void (*UnpackNormPtr)(size_t count, const uint8_t* const* src, float* dst, size_t stride); - typedef void (*UnpackDataPtr)(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride); - typedef void (*MacroCosineDistancesUnpackPtr)(size_t M, size_t N, size_t K, const uint8_t* ad, const float * an, const uint8_t* bd, const float* bn, float* distances, size_t stride); UnpackNormPtr _unpackNormA, _unpackNormB; UnpackDataPtr _unpackDataA, _unpackDataB; MacroCosineDistancesUnpackPtr _macroCosineDistancesUnpack; @@ -109,6 +112,20 @@ namespace Simd //------------------------------------------------------------------------------------------------- + Base::DescrInt::Encode32fPtr GetEncode32f(size_t depth); + Base::DescrInt::Encode16fPtr GetEncode16f(size_t depth); + + Base::DescrInt::Decode32fPtr GetDecode32f(size_t depth); + Base::DescrInt::Decode16fPtr GetDecode16f(size_t depth); + + Base::DescrInt::CosineDistancePtr GetCosineDistance(size_t depth); + Sse41::DescrInt::MacroCosineDistancesDirectPtr GetMacroCosineDistancesDirect(size_t depth); + + Sse41::DescrInt::UnpackDataPtr GetUnpackData(size_t depth, bool transpose); + Sse41::DescrInt::MacroCosineDistancesUnpackPtr GetMacroCosineDistancesUnpack(size_t depth); + + //------------------------------------------------------------------------------------------------- + void* DescrIntInit(size_t size, size_t depth); } #endif diff --git a/src/Simd/SimdSse41DescrInt.cpp b/src/Simd/SimdSse41DescrInt.cpp index c55bbb966a..85efd6742a 100644 --- a/src/Simd/SimdSse41DescrInt.cpp +++ b/src/Simd/SimdSse41DescrInt.cpp @@ -73,1338 +73,6 @@ namespace Simd //------------------------------------------------------------------------------------------------- - SIMD_INLINE __m128i Encode32f(__m128 src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) - { - __m128i value = _mm_cvtps_epi32(_mm_mul_ps(_mm_sub_ps(src, min), scale)); - sum = _mm_add_epi32(value, sum); - sqsum = _mm_add_epi32(_mm_madd_epi16(value, value), sqsum); - return value; - } - - SIMD_INLINE __m128i Encode32f(const float* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) - { - return Encode32f(_mm_loadu_ps(src), scale, min, sum, sqsum); - } - - static SIMD_INLINE __m128i Encode32f4(const float* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) - { - __m128i i0 = Encode32f(src + 0, scale, min, sum, sqsum); - __m128i i4 = Encode32f(src + 4, scale, min, sum, sqsum); - return _mm_srli_epi32(_mm_mullo_epi16(_mm_packus_epi32(i0, i4), E4_MULLO), 12); - } - - static SIMD_INLINE __m128i Encode32f4x8(const float* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) - { - __m128i s0 = Encode32f4(src + 0 * 8, scale, min, sum, sqsum); - return _mm_packus_epi16(_mm_packus_epi32(s0, K_ZERO), K_ZERO); - } - - static SIMD_INLINE __m128i Encode32f4x16(const float* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) - { - __m128i s0 = Encode32f4(src + 0 * 8, scale, min, sum, sqsum); - __m128i s1 = Encode32f4(src + 1 * 8, scale, min, sum, sqsum); - return _mm_packus_epi16(_mm_packus_epi32(s0, s1), K_ZERO); - } - - static void Encode32f4(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t i = 0, size16 = AlignLo(size, 16); - __m128 _scale = _mm_set1_ps(scale); - __m128 _min = _mm_set1_ps(min); - __m128i _sum = _mm_setzero_si128(); - __m128i _sqsum = _mm_setzero_si128(); - for (; i < size16; i += 16, src += 16, dst += 8) - _mm_storel_epi64((__m128i*)dst, Encode32f4x16(src, _scale, _min, _sum, _sqsum)); - for (; i < size; i += 8, src += 8, dst += 4) - *(uint32_t*)(dst) = _mm_extract_epi32(Encode32f4x8(src, _scale, _min, _sum, _sqsum), 0); - sum = ExtractInt32Sum(_sum); - sqsum = ExtractInt32Sum(_sqsum); - } - - static SIMD_INLINE __m128i Encode32f5(const float* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) - { - __m128i i0 = Encode32f(src + 0, scale, min, sum, sqsum); - __m128i i4 = Encode32f(src + 4, scale, min, sum, sqsum); - __m128i s0 = _mm_mullo_epi16(_mm_packus_epi32(i0, i4), E5_MULLO); - return _mm_or_si128(_mm_or_si128(_mm_shuffle_epi8(s0, E5_SHFL0), _mm_shuffle_epi8(s0, E5_SHFL1)), _mm_shuffle_epi8(s0, E5_SHFL2)); - } - - static void Encode32f5(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t i = 0, main = size - 8; - __m128 _scale = _mm_set1_ps(scale); - __m128 _min = _mm_set1_ps(min); - __m128i _sum = _mm_setzero_si128(); - __m128i _sqsum = _mm_setzero_si128(); - for (; i < main; i += 8, src += 8, dst += 5) - _mm_storel_epi64((__m128i*)dst, Encode32f5(src, _scale, _min, _sum, _sqsum)); - for (; i < size; i += 8, src += 8, dst += 5) - { - __m128i d0 = Encode32f5(src, _scale, _min, _sum, _sqsum); - *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); - *(uint8_t*)(dst + 4) = _mm_extract_epi8(d0, 4); - } - sum = ExtractInt32Sum(_sum); - sqsum = ExtractInt32Sum(_sqsum); - } - - static SIMD_INLINE __m128i Encode32f6(const float* src, __m128 scale, __m128 min, __m128i & sum, __m128i & sqsum) - { - __m128i i0 = Encode32f(src + 0, scale, min, sum, sqsum); - __m128i i4 = Encode32f(src + 4, scale, min, sum, sqsum); - __m128i s0 = _mm_mullo_epi16(_mm_packus_epi32(i0, i4), E6_MULLO); - return _mm_or_si128(_mm_shuffle_epi8(s0, E6_SHFL0), _mm_shuffle_epi8(s0, E6_SHFL1)); - } - - static void Encode32f6(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t i = 0, main = size - 8; - __m128 _scale = _mm_set1_ps(scale); - __m128 _min = _mm_set1_ps(min); - __m128i _sum = _mm_setzero_si128(); - __m128i _sqsum = _mm_setzero_si128(); - for (; i < main; i += 8, src += 8, dst += 6) - _mm_storel_epi64((__m128i*)dst, Encode32f6(src, _scale, _min, _sum, _sqsum)); - for (; i < size; i += 8, src += 8, dst += 6) - { - __m128i d0 = Encode32f6(src, _scale, _min, _sum, _sqsum); - *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); - *(uint16_t*)(dst + 4) = _mm_extract_epi16(d0, 2); - } - sum = ExtractInt32Sum(_sum); - sqsum = ExtractInt32Sum(_sqsum); - } - - static SIMD_INLINE __m128i Encode32f7(const float* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) - { - __m128i i0 = Encode32f(src + 0, scale, min, sum, sqsum); - __m128i i4 = Encode32f(src + 4, scale, min, sum, sqsum); - __m128i s0 = _mm_mullo_epi16(_mm_packus_epi32(i0, i4), E7_MULLO); - return _mm_or_si128(_mm_shuffle_epi8(s0, E7_SHFL0), _mm_shuffle_epi8(s0, E7_SHFL1)); - } - - static void Encode32f7(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t i = 0, main = size - 8; - __m128 _scale = _mm_set1_ps(scale); - __m128 _min = _mm_set1_ps(min); - __m128i _sum = _mm_setzero_si128(); - __m128i _sqsum = _mm_setzero_si128(); - for (; i < main; i += 8, src += 8, dst += 7) - _mm_storel_epi64((__m128i*)dst, Encode32f7(src, _scale, _min, _sum, _sqsum)); - for (; i < size; i += 8, src += 8, dst += 7) - { - __m128i d0 = Encode32f7(src, _scale, _min, _sum, _sqsum); - *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); - *(uint16_t*)(dst + 4) = _mm_extract_epi16(d0, 2); - *(uint8_t*)(dst + 6) = _mm_extract_epi8(d0, 6); - } - sum = ExtractInt32Sum(_sum); - sqsum = ExtractInt32Sum(_sqsum); - } - - static void Encode32f8(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t sizeA = AlignLo(size, A), i = 0; - __m128 _scale = _mm_set1_ps(scale); - __m128 _min = _mm_set1_ps(min); - __m128i _sum = _mm_setzero_si128(); - __m128i _sqsum = _mm_setzero_si128(); - for (; i < sizeA; i += A) - { - __m128i d0 = Encode32f(src + i + 0 * F, _scale, _min, _sum, _sqsum); - __m128i d1 = Encode32f(src + i + 1 * F, _scale, _min, _sum, _sqsum); - __m128i d2 = Encode32f(src + i + 2 * F, _scale, _min, _sum, _sqsum); - __m128i d3 = Encode32f(src + i + 3 * F, _scale, _min, _sum, _sqsum); - _mm_storeu_si128((__m128i*)(dst + i), _mm_packus_epi16(_mm_packus_epi32(d0, d1), _mm_packus_epi32(d2, d3))); - } - for (; i < size; i += F) - { - __m128i d0 = Encode32f(src + i, _scale, _min, _sum, _sqsum); - *(uint32_t*)(dst + i) = _mm_cvtsi128_si32(_mm_packus_epi16(_mm_packus_epi32(d0, _mm_setzero_si128()), _mm_setzero_si128())); - } - sum = ExtractInt32Sum(_sum); - sqsum = ExtractInt32Sum(_sqsum); - } - - //------------------------------------------------------------------------------------------------- - - static SIMD_INLINE __m128i Encode16f4(const uint16_t* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) - { - __m128i u0 = _mm_loadu_si128((__m128i*)(src)); - __m128i i0 = Encode32f(Float16ToFloat32(UnpackU16<0>(u0)), scale, min, sum, sqsum); - __m128i i4 = Encode32f(Float16ToFloat32(UnpackU16<1>(u0)), scale, min, sum, sqsum); - return _mm_srli_epi32(_mm_mullo_epi16(_mm_packus_epi32(i0, i4), E4_MULLO), 12); - } - - static SIMD_INLINE __m128i Encode16f4x8(const uint16_t* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) - { - __m128i s0 = Encode16f4(src + 0 * 8, scale, min, sum, sqsum); - return _mm_packus_epi16(_mm_packus_epi32(s0, K_ZERO), K_ZERO); - } - - static SIMD_INLINE __m128i Encode16f4x16(const uint16_t* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) - { - __m128i s0 = Encode16f4(src + 0 * 8, scale, min, sum, sqsum); - __m128i s1 = Encode16f4(src + 1 * 8, scale, min, sum, sqsum); - return _mm_packus_epi16(_mm_packus_epi32(s0, s1), K_ZERO); - } - - static void Encode16f4(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t i = 0, size16 = AlignLo(size, 16); - __m128 _scale = _mm_set1_ps(scale); - __m128 _min = _mm_set1_ps(min); - __m128i _sum = _mm_setzero_si128(); - __m128i _sqsum = _mm_setzero_si128(); - for (; i < size16; i += 16, src += 16, dst += 8) - _mm_storel_epi64((__m128i*)dst, Encode16f4x16(src, _scale, _min, _sum, _sqsum)); - for (; i < size; i += 8, src += 8, dst += 4) - *(uint32_t*)(dst) = _mm_extract_epi32(Encode16f4x8(src, _scale, _min, _sum, _sqsum), 0); - sum = ExtractInt32Sum(_sum); - sqsum = ExtractInt32Sum(_sqsum); - } - - static SIMD_INLINE __m128i Encode16f5(const uint16_t* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) - { - __m128i u0 = _mm_loadu_si128((__m128i*)(src)); - __m128i i0 = Encode32f(Float16ToFloat32(UnpackU16<0>(u0)), scale, min, sum, sqsum); - __m128i i4 = Encode32f(Float16ToFloat32(UnpackU16<1>(u0)), scale, min, sum, sqsum); - __m128i s0 = _mm_mullo_epi16(_mm_packus_epi32(i0, i4), E5_MULLO); - return _mm_or_si128(_mm_or_si128(_mm_shuffle_epi8(s0, E5_SHFL0), _mm_shuffle_epi8(s0, E5_SHFL1)), _mm_shuffle_epi8(s0, E5_SHFL2)); - } - - static void Encode16f5(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t i = 0, main = size - 8; - __m128 _scale = _mm_set1_ps(scale); - __m128 _min = _mm_set1_ps(min); - __m128i _sum = _mm_setzero_si128(); - __m128i _sqsum = _mm_setzero_si128(); - for (; i < main; i += 8, src += 8, dst += 5) - _mm_storel_epi64((__m128i*)dst, Encode16f5(src, _scale, _min, _sum, _sqsum)); - for (; i < size; i += 8, src += 8, dst += 5) - { - __m128i d0 = Encode16f5(src, _scale, _min, _sum, _sqsum); - *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); - *(uint8_t*)(dst + 4) = _mm_extract_epi8(d0, 4); - } - sum = ExtractInt32Sum(_sum); - sqsum = ExtractInt32Sum(_sqsum); - } - - static SIMD_INLINE __m128i Encode16f6(const uint16_t* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) - { - __m128i u0 = _mm_loadu_si128((__m128i*)(src)); - __m128i i0 = Encode32f(Float16ToFloat32(UnpackU16<0>(u0)), scale, min, sum, sqsum); - __m128i i4 = Encode32f(Float16ToFloat32(UnpackU16<1>(u0)), scale, min, sum, sqsum); - __m128i s0 = _mm_mullo_epi16(_mm_packus_epi32(i0, i4), E6_MULLO); - return _mm_or_si128(_mm_shuffle_epi8(s0, E6_SHFL0), _mm_shuffle_epi8(s0, E6_SHFL1)); - } - - static void Encode16f6(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t i = 0, main = size - 8; - __m128 _scale = _mm_set1_ps(scale); - __m128 _min = _mm_set1_ps(min); - __m128i _sum = _mm_setzero_si128(); - __m128i _sqsum = _mm_setzero_si128(); - for (; i < main; i += 8, src += 8, dst += 6) - _mm_storel_epi64((__m128i*)dst, Encode16f6(src, _scale, _min, _sum, _sqsum)); - for (; i < size; i += 8, src += 8, dst += 6) - { - __m128i d0 = Encode16f6(src, _scale, _min, _sum, _sqsum); - *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); - *(uint16_t*)(dst + 4) = _mm_extract_epi16(d0, 2); - } - sum = ExtractInt32Sum(_sum); - sqsum = ExtractInt32Sum(_sqsum); - } - - static SIMD_INLINE __m128i Encode16f7(const uint16_t* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) - { - __m128i u0 = _mm_loadu_si128((__m128i*)(src)); - __m128i i0 = Encode32f(Float16ToFloat32(UnpackU16<0>(u0)), scale, min, sum, sqsum); - __m128i i4 = Encode32f(Float16ToFloat32(UnpackU16<1>(u0)), scale, min, sum, sqsum); - __m128i s0 = _mm_mullo_epi16(_mm_packus_epi32(i0, i4), E7_MULLO); - return _mm_or_si128(_mm_shuffle_epi8(s0, E7_SHFL0), _mm_shuffle_epi8(s0, E7_SHFL1)); - } - - static void Encode16f7(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t i = 0, main = size - 8; - __m128 _scale = _mm_set1_ps(scale); - __m128 _min = _mm_set1_ps(min); - __m128i _sum = _mm_setzero_si128(); - __m128i _sqsum = _mm_setzero_si128(); - for (; i < main; i += 8, src += 8, dst += 7) - _mm_storel_epi64((__m128i*)dst, Encode16f7(src, _scale, _min, _sum, _sqsum)); - for (; i < size; i += 8, src += 8, dst += 7) - { - __m128i d0 = Encode16f7(src, _scale, _min, _sum, _sqsum); - *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); - *(uint16_t*)(dst + 4) = _mm_extract_epi16(d0, 2); - *(uint8_t*)(dst + 6) = _mm_extract_epi8(d0, 6); - } - sum = ExtractInt32Sum(_sum); - sqsum = ExtractInt32Sum(_sqsum); - } - - static void Encode16f8(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t sizeA = AlignLo(size, A), i = 0; - __m128 _scale = _mm_set1_ps(scale); - __m128 _min = _mm_set1_ps(min); - __m128i _sum = _mm_setzero_si128(); - __m128i _sqsum = _mm_setzero_si128(); - for (; i < sizeA; i += A) - { - __m128i u0 = _mm_loadu_si128((__m128i*)(src + i + 0 * F)); - __m128i d0 = Encode32f(Float16ToFloat32(UnpackU16<0>(u0)), _scale, _min, _sum, _sqsum); - __m128i d1 = Encode32f(Float16ToFloat32(UnpackU16<1>(u0)), _scale, _min, _sum, _sqsum); - __m128i u2 = _mm_loadu_si128((__m128i*)(src + i + 2 * F)); - __m128i d2 = Encode32f(Float16ToFloat32(UnpackU16<0>(u2)), _scale, _min, _sum, _sqsum); - __m128i d3 = Encode32f(Float16ToFloat32(UnpackU16<1>(u2)), _scale, _min, _sum, _sqsum); - _mm_storeu_si128((__m128i*)(dst + i), _mm_packus_epi16(_mm_packus_epi32(d0, d1), _mm_packus_epi32(d2, d3))); - } - for (; i < size; i += F) - { - __m128i u0 = _mm_loadl_epi64((__m128i*)(src + i)); - __m128i d0 = Encode32f(Float16ToFloat32(UnpackU16<0>(u0)), _scale, _min, _sum, _sqsum); - *(uint32_t*)(dst + i) = _mm_cvtsi128_si32(_mm_packus_epi16(_mm_packus_epi32(d0, _mm_setzero_si128()), _mm_setzero_si128())); - } - sum = ExtractInt32Sum(_sum); - sqsum = ExtractInt32Sum(_sqsum); - } - - //------------------------------------------------------------------------------------------------- - - static void Decode32f4(const uint8_t* src, float scale, float shift, size_t size, float* dst) - { - assert(size % 8 == 0); - __m128 _scale = _mm_set1_ps(scale); - __m128 _shift = _mm_set1_ps(shift); - for (size_t i = 0; i < size; i += 8) - { - __m128i s4 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s4, C4_SHFL0), C4_MULLO), 12); - _mm_storeu_ps(dst + 0, _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<0>(s16)), _scale), _shift)); - _mm_storeu_ps(dst + 4, _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<1>(s16)), _scale), _shift)); - src += 4; - dst += 8; - } - } - - static void Decode32f5(const uint8_t* src, float scale, float shift, size_t size, float* dst) - { - assert(size % 8 == 0); - __m128 _scale = _mm_set1_ps(scale); - __m128 _shift = _mm_set1_ps(shift); - for (size_t i = 0; i < size; i += 8) - { - __m128i s5 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s5, C5_SHFL0), C5_MULLO), 11); - _mm_storeu_ps(dst + 0, _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<0>(s16)), _scale), _shift)); - _mm_storeu_ps(dst + 4, _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<1>(s16)), _scale), _shift)); - src += 5; - dst += 8; - } - } - - static void Decode32f6(const uint8_t* src, float scale, float shift, size_t size, float* dst) - { - assert(size % 8 == 0); - __m128 _scale = _mm_set1_ps(scale); - __m128 _shift = _mm_set1_ps(shift); - for (size_t i = 0; i < size; i += 8) - { - __m128i s6 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s6, C6_SHFL0), C6_MULLO), 10); - _mm_storeu_ps(dst + 0, _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<0>(s16)), _scale), _shift)); - _mm_storeu_ps(dst + 4, _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<1>(s16)), _scale), _shift)); - src += 6; - dst += 8; - } - } - - static void Decode32f7(const uint8_t* src, float scale, float shift, size_t size, float* dst) - { - assert(size % 8 == 0); - __m128 _scale = _mm_set1_ps(scale); - __m128 _shift = _mm_set1_ps(shift); - for (size_t i = 0; i < size; i += 8) - { - __m128i s7 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s7, C7_SHFL0), C7_MULLO), 9); - _mm_storeu_ps(dst + 0, _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<0>(s16)), _scale), _shift)); - _mm_storeu_ps(dst + 4, _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<1>(s16)), _scale), _shift)); - src += 7; - dst += 8; - } - } - - static void Decode32f8(const uint8_t* src, float scale, float shift, size_t size, float* dst) - { - assert(size % 8 == 0); - __m128 _scale = _mm_set1_ps(scale); - __m128 _shift = _mm_set1_ps(shift); - size_t i = 0; - for (; i < size; i += 4) - { - __m128 _src = _mm_cvtepi32_ps(_mm_cvtepu8_epi32(_mm_cvtsi32_si128(*(uint32_t*)(src + i)))); - _mm_storeu_ps(dst + i, _mm_add_ps(_mm_mul_ps(_src, _scale), _shift)); - } - } - - //------------------------------------------------------------------------------------------------- - - static void Decode16f4(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) - { - assert(size % 8 == 0); - __m128 _scale = _mm_set1_ps(scale); - __m128 _shift = _mm_set1_ps(shift); - for (size_t i = 0; i < size; i += 8) - { - __m128i s4 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s4, C4_SHFL0), C4_MULLO), 12); - __m128i d0 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<0>(s16)), _scale), _shift)); - __m128i d4 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<1>(s16)), _scale), _shift)); - _mm_storeu_si128((__m128i*)dst, _mm_packus_epi32(d0, d4)); - src += 4; - dst += 8; - } - } - - static void Decode16f5(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) - { - assert(size % 8 == 0); - __m128 _scale = _mm_set1_ps(scale); - __m128 _shift = _mm_set1_ps(shift); - for (size_t i = 0; i < size; i += 8) - { - __m128i s5 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s5, C5_SHFL0), C5_MULLO), 11); - __m128i d0 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<0>(s16)), _scale), _shift)); - __m128i d4 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<1>(s16)), _scale), _shift)); - _mm_storeu_si128((__m128i*)dst, _mm_packus_epi32(d0, d4)); - src += 5; - dst += 8; - } - } - - static void Decode16f6(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) - { - assert(size % 8 == 0); - __m128 _scale = _mm_set1_ps(scale); - __m128 _shift = _mm_set1_ps(shift); - for (size_t i = 0; i < size; i += 8) - { - __m128i s6 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s6, C6_SHFL0), C6_MULLO), 10); - __m128i d0 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<0>(s16)), _scale), _shift)); - __m128i d4 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<1>(s16)), _scale), _shift)); - _mm_storeu_si128((__m128i*)dst, _mm_packus_epi32(d0, d4)); - src += 6; - dst += 8; - } - } - - static void Decode16f7(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) - { - assert(size % 8 == 0); - __m128 _scale = _mm_set1_ps(scale); - __m128 _shift = _mm_set1_ps(shift); - for (size_t i = 0; i < size; i += 8) - { - __m128i s7 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s7, C7_SHFL0), C7_MULLO), 9); - __m128i d0 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<0>(s16)), _scale), _shift)); - __m128i d4 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<1>(s16)), _scale), _shift)); - _mm_storeu_si128((__m128i*)dst, _mm_packus_epi32(d0, d4)); - src += 7; - dst += 8; - } - } - - static void Decode16f8(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) - { - assert(size % 8 == 0); - __m128 _scale = _mm_set1_ps(scale); - __m128 _shift = _mm_set1_ps(shift); - size_t i = 0; - for (; i < size; i += 8) - { - __m128i s8 = _mm_loadl_epi64((__m128i*)(src + i)); - __m128 s0 = _mm_cvtepi32_ps(_mm_cvtepu8_epi32(_mm_srli_si128(s8, 0))); - __m128 s4 = _mm_cvtepi32_ps(_mm_cvtepu8_epi32(_mm_srli_si128(s8, 4))); - __m128i d0 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(s0, _scale), _shift)); - __m128i d4 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(s4, _scale), _shift)); - _mm_storeu_si128((__m128i*)(dst + i), _mm_packus_epi32(d0, d4)); - } - } - - //------------------------------------------------------------------------------------------------- - - template int32_t Correlation(const uint8_t* a, const uint8_t* b, size_t size); - - template<> int32_t Correlation<4>(const uint8_t* a, const uint8_t* b, size_t size) - { - assert(size % 8 == 0); - __m128i ab32 = _mm_setzero_si128(); - size_t i = 0, size32 = AlignLo(size, 32); - for (; i < size32; i += 32, a += 16, b += 16) - { - __m128i _a = _mm_loadu_si128((__m128i*)a); - __m128i _b = _mm_loadu_si128((__m128i*)b); - __m128i ab16 = _mm_maddubs_epi16(_mm_and_si128(_a, K8_0F), _mm_and_si128(_b, K8_0F)); - ab16 = _mm_add_epi16(ab16, _mm_maddubs_epi16(_mm_and_si128(_mm_srli_epi16(_a, 4), K8_0F), _mm_and_si128(_mm_srli_epi16(_b, 4), K8_0F))); - ab32 = _mm_add_epi32(ab32, _mm_madd_epi16(ab16, K16_0001)); - } - for (; i < size; i += 8, a += 4, b += 4) - { - __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), C4_SHFL0), C4_MULLO), 12); - __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), C4_SHFL0), C4_MULLO), 12); - ab32 = _mm_add_epi32(_mm_madd_epi16(_a, _b), ab32); - } - return ExtractInt32Sum(ab32); - } - - template<> int32_t Correlation<5>(const uint8_t* a, const uint8_t* b, size_t size) - { - assert(size % 8 == 0); - __m128i _ab = _mm_setzero_si128(); - size_t i = 0, sizeA = AlignLo(size, A); - for (; i < sizeA; i += A, a += 10, b += 10) - { - __m128i _a = _mm_loadu_si128((__m128i*)a); - __m128i _b = _mm_loadu_si128((__m128i*)b); - __m128i a0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_a, C5_SHFL0), C5_MULLO), 11); - __m128i b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_b, C5_SHFL0), C5_MULLO), 11); - _ab = _mm_add_epi32(_mm_madd_epi16(a0, b0), _ab); - __m128i a1 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_a, C5_SHFL1), C5_MULLO), 11); - __m128i b1 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_b, C5_SHFL1), C5_MULLO), 11); - _ab = _mm_add_epi32(_mm_madd_epi16(a1, b1), _ab); - } - for (; i < size; i += 8, a += 5, b += 5) - { - __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), C5_SHFL0), C5_MULLO), 11); - __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), C5_SHFL0), C5_MULLO), 11); - _ab = _mm_add_epi32(_mm_madd_epi16(_a, _b), _ab); - } - return ExtractInt32Sum(_ab); - } - - template<> int32_t Correlation<6>(const uint8_t* a, const uint8_t* b, size_t size) - { - assert(size % 8 == 0); - __m128i _ab = _mm_setzero_si128(); - size_t i = 0, sizeA = AlignLo(size, A); - for (; i < sizeA; i += A, a += 12, b += 12) - { - __m128i _a = _mm_loadu_si128((__m128i*)a); - __m128i _b = _mm_loadu_si128((__m128i*)b); - __m128i a0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_a, C6_SHFL0), C6_MULLO), 10); - __m128i b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_b, C6_SHFL0), C6_MULLO), 10); - _ab = _mm_add_epi32(_mm_madd_epi16(a0, b0), _ab); - __m128i a1 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_a, C6_SHFL1), C6_MULLO), 10); - __m128i b1 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_b, C6_SHFL1), C6_MULLO), 10); - _ab = _mm_add_epi32(_mm_madd_epi16(a1, b1), _ab); - } - for (; i < size; i += 8, a += 6, b += 6) - { - __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), C6_SHFL0), C6_MULLO), 10); - __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), C6_SHFL0), C6_MULLO), 10); - _ab = _mm_add_epi32(_mm_madd_epi16(_a, _b), _ab); - } - return ExtractInt32Sum(_ab); - } - - template<> int32_t Correlation<7>(const uint8_t* a, const uint8_t* b, size_t size) - { - assert(size % 8 == 0); - __m128i _ab = _mm_setzero_si128(); - size_t i = 0, sizeA = AlignLo(size, A); - for (; i < sizeA; i += A, a += 14, b += 14) - { - __m128i _a = _mm_loadu_si128((__m128i*)a); - __m128i _b = _mm_loadu_si128((__m128i*)b); - __m128i a0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_a, C7_SHFL0), C7_MULLO), 9); - __m128i b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_b, C7_SHFL0), C7_MULLO), 9); - _ab = _mm_add_epi32(_mm_madd_epi16(a0, b0), _ab); - __m128i a1 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_a, C7_SHFL1), C7_MULLO), 9); - __m128i b1 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_b, C7_SHFL1), C7_MULLO), 9); - _ab = _mm_add_epi32(_mm_madd_epi16(a1, b1), _ab); - } - for (; i < size; i += 8, a += 7, b += 7) - { - __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), C7_SHFL0), C7_MULLO), 9); - __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), C7_SHFL0), C7_MULLO), 9); - _ab = _mm_add_epi32(_mm_madd_epi16(_a, _b), _ab); - } - return ExtractInt32Sum(_ab); - } - - template<> int32_t Correlation<8>(const uint8_t* a, const uint8_t* b, size_t size) - { - size_t i = 0, sizeA = AlignLo(size, A); - __m128i _ab = _mm_setzero_si128(); - for (; i < sizeA; i += A) - { - __m128i _a = _mm_loadu_si128((__m128i*)(a + i)); - __m128i _b = _mm_loadu_si128((__m128i*)(b + i)); - _ab = _mm_add_epi32(_mm_madd_epi16(UnpackU8<0>(_a), UnpackU8<0>(_b)), _ab); - _ab = _mm_add_epi32(_mm_madd_epi16(UnpackU8<1>(_a), UnpackU8<1>(_b)), _ab); - } - for (; i < size; i += 8) - { - __m128i _a = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(a + i))); - __m128i _b = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(b + i))); - _ab = _mm_add_epi32(_mm_madd_epi16(_a, _b), _ab); - } - return ExtractInt32Sum(_ab); - } - - template void CosineDistance(const uint8_t* a, const uint8_t* b, size_t size, float* distance) - { - float abSum = (float)Correlation(a + 16, b + 16, size); - Base::DecodeCosineDistance(a, b, abSum, distance); - } - - //------------------------------------------------------------------------------------------------- - - template void MicroCosineDistancesDirect2x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); - - template<> void MicroCosineDistancesDirect2x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size32 = AlignLo(size, 32), o = 16; - __m128i a0, a1, b0; - __m128i ab00 = _mm_setzero_si128(); - __m128i ab01 = _mm_setzero_si128(); - __m128i ab02 = _mm_setzero_si128(); - __m128i ab03 = _mm_setzero_si128(); - __m128i ab10 = _mm_setzero_si128(); - __m128i ab11 = _mm_setzero_si128(); - __m128i ab12 = _mm_setzero_si128(); - __m128i ab13 = _mm_setzero_si128(); - for (; i < size32; i += 32, o += 16) - { - a0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(A[0] + o)), K8_0F); - a1 = _mm_and_si128(_mm_loadu_si128((__m128i*)(A[1] + o)), K8_0F); - - b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[0] + o)), K8_0F); - ab00 = _mm_add_epi32(ab00, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); - ab10 = _mm_add_epi32(ab10, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); - - b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[1] + o)), K8_0F); - ab01 = _mm_add_epi32(ab01, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); - ab11 = _mm_add_epi32(ab11, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); - - b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[2] + o)), K8_0F); - ab02 = _mm_add_epi32(ab02, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); - ab12 = _mm_add_epi32(ab12, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); - - b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[3] + o)), K8_0F); - ab03 = _mm_add_epi32(ab03, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); - ab13 = _mm_add_epi32(ab13, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); - - a0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(A[0] + o)), 4), K8_0F); - a1 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(A[1] + o)), 4), K8_0F); - - b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[0] + o)), 4), K8_0F); - ab00 = _mm_add_epi32(ab00, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); - ab10 = _mm_add_epi32(ab10, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); - - b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[1] + o)), 4), K8_0F); - ab01 = _mm_add_epi32(ab01, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); - ab11 = _mm_add_epi32(ab11, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); - - b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[2] + o)), 4), K8_0F); - ab02 = _mm_add_epi32(ab02, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); - ab12 = _mm_add_epi32(ab12, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); - - b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[3] + o)), 4), K8_0F); - ab03 = _mm_add_epi32(ab03, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); - ab13 = _mm_add_epi32(ab13, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); - } - for (; i < size; i += 8, o += 4) - { - a0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[0] + o)), C4_SHFL0), C4_MULLO), 12); - a1 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[1] + o)), C4_SHFL0), C4_MULLO), 12); - - b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[0] + o)), C4_SHFL0), C4_MULLO), 12); - ab00 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab00); - ab10 = _mm_add_epi32(_mm_madd_epi16(a1, b0), ab10); - - b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[1] + o)), C4_SHFL0), C4_MULLO), 12); - ab01 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab01); - ab11 = _mm_add_epi32(_mm_madd_epi16(a1, b0), ab11); - - b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[2] + o)), C4_SHFL0), C4_MULLO), 12); - ab02 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab02); - ab12 = _mm_add_epi32(_mm_madd_epi16(a1, b0), ab12); - - b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[3] + o)), C4_SHFL0), C4_MULLO), 12); - ab03 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab03); - ab13 = _mm_add_epi32(_mm_madd_epi16(a1, b0), ab13); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); - } - - template<> void MicroCosineDistancesDirect2x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size16 = AlignLo(size, 16), o = 16; - __m128i a00, a01, a10, a11, b00, b01; - __m128i ab00 = _mm_setzero_si128(); - __m128i ab01 = _mm_setzero_si128(); - __m128i ab02 = _mm_setzero_si128(); - __m128i ab03 = _mm_setzero_si128(); - __m128i ab10 = _mm_setzero_si128(); - __m128i ab11 = _mm_setzero_si128(); - __m128i ab12 = _mm_setzero_si128(); - __m128i ab13 = _mm_setzero_si128(); - for (; i < size16; i += 16, o += 10) - { - a01 = _mm_loadu_si128((__m128i*)(A[0] + o)); - a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(a01, C5_SHFL0), C5_MULLO), 11); - a01 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(a01, 5), C5_SHFL0), C5_MULLO), 11); - a11 = _mm_loadu_si128((__m128i*)(A[1] + o)); - a10 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(a11, C5_SHFL0), C5_MULLO), 11); - a11 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(a11, 5), C5_SHFL0), C5_MULLO), 11); - - b01 = _mm_loadu_si128((__m128i*)(B[0] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C5_SHFL0), C5_MULLO), 11); - ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); - ab10 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab10); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 5), C5_SHFL0), C5_MULLO), 11); - ab00 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab00); - ab10 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab10); - - b01 = _mm_loadu_si128((__m128i*)(B[1] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C5_SHFL0), C5_MULLO), 11); - ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); - ab11 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab11); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 5), C5_SHFL0), C5_MULLO), 11); - ab01 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab01); - ab11 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab11); - - b01 = _mm_loadu_si128((__m128i*)(B[2] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C5_SHFL0), C5_MULLO), 11); - ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); - ab12 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab12); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 5), C5_SHFL0), C5_MULLO), 11); - ab02 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab02); - ab12 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab12); - - b01 = _mm_loadu_si128((__m128i*)(B[3] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C5_SHFL0), C5_MULLO), 11); - ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); - ab13 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab13); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 5), C5_SHFL0), C5_MULLO), 11); - ab03 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab03); - ab13 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab13); - } - for (; i < size; i += 8, o += 5) - { - a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[0] + o)), C5_SHFL0), C5_MULLO), 11); - a10 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[1] + o)), C5_SHFL0), C5_MULLO), 11); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[0] + o)), C5_SHFL0), C5_MULLO), 11); - ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); - ab10 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab10); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[1] + o)), C5_SHFL0), C5_MULLO), 11); - ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); - ab11 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab11); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[2] + o)), C5_SHFL0), C5_MULLO), 11); - ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); - ab12 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab12); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[3] + o)), C5_SHFL0), C5_MULLO), 11); - ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); - ab13 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab13); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); - } - - template<> void MicroCosineDistancesDirect2x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size16 = AlignLo(size, 16), o = 16; - __m128i a00, a01, a10, a11, b00, b01; - __m128i ab00 = _mm_setzero_si128(); - __m128i ab01 = _mm_setzero_si128(); - __m128i ab02 = _mm_setzero_si128(); - __m128i ab03 = _mm_setzero_si128(); - __m128i ab10 = _mm_setzero_si128(); - __m128i ab11 = _mm_setzero_si128(); - __m128i ab12 = _mm_setzero_si128(); - __m128i ab13 = _mm_setzero_si128(); - for (; i < size16; i += 16, o += 12) - { - a01 = _mm_loadu_si128((__m128i*)(A[0] + o)); - a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(a01, C6_SHFL0), C6_MULLO), 10); - a01 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(a01, 6), C6_SHFL0), C6_MULLO), 10); - a11 = _mm_loadu_si128((__m128i*)(A[1] + o)); - a10 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(a11, C6_SHFL0), C6_MULLO), 10); - a11 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(a11, 6), C6_SHFL0), C6_MULLO), 10); - - b01 = _mm_loadu_si128((__m128i*)(B[0] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C6_SHFL0), C6_MULLO), 10); - ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); - ab10 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab10); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 6), C6_SHFL0), C6_MULLO), 10); - ab00 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab00); - ab10 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab10); - - b01 = _mm_loadu_si128((__m128i*)(B[1] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C6_SHFL0), C6_MULLO), 10); - ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); - ab11 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab11); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 6), C6_SHFL0), C6_MULLO), 10); - ab01 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab01); - ab11 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab11); - - b01 = _mm_loadu_si128((__m128i*)(B[2] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C6_SHFL0), C6_MULLO), 10); - ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); - ab12 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab12); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 6), C6_SHFL0), C6_MULLO), 10); - ab02 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab02); - ab12 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab12); - - b01 = _mm_loadu_si128((__m128i*)(B[3] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C6_SHFL0), C6_MULLO), 10); - ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); - ab13 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab13); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 6), C6_SHFL0), C6_MULLO), 10); - ab03 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab03); - ab13 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab13); - } - for (; i < size; i += 8, o += 6) - { - a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[0] + o)), C6_SHFL0), C6_MULLO), 10); - a10 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[1] + o)), C6_SHFL0), C6_MULLO), 10); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[0] + o)), C6_SHFL0), C6_MULLO), 10); - ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); - ab10 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab10); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[1] + o)), C6_SHFL0), C6_MULLO), 10); - ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); - ab11 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab11); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[2] + o)), C6_SHFL0), C6_MULLO), 10); - ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); - ab12 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab12); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[3] + o)), C6_SHFL0), C6_MULLO), 10); - ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); - ab13 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab13); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); - } - - template<> void MicroCosineDistancesDirect2x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size16 = AlignLo(size, 16), o = 16; - __m128i a00, a01, a10, a11, b00, b01; - __m128i ab00 = _mm_setzero_si128(); - __m128i ab01 = _mm_setzero_si128(); - __m128i ab02 = _mm_setzero_si128(); - __m128i ab03 = _mm_setzero_si128(); - __m128i ab10 = _mm_setzero_si128(); - __m128i ab11 = _mm_setzero_si128(); - __m128i ab12 = _mm_setzero_si128(); - __m128i ab13 = _mm_setzero_si128(); - for (; i < size16; i += 16, o += 14) - { - a01 = _mm_loadu_si128((__m128i*)(A[0] + o)); - a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(a01, C7_SHFL0), C7_MULLO), 9); - a01 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(a01, 7), C7_SHFL0), C7_MULLO), 9); - a11 = _mm_loadu_si128((__m128i*)(A[1] + o)); - a10 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(a11, C7_SHFL0), C7_MULLO), 9); - a11 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(a11, 7), C7_SHFL0), C7_MULLO), 9); - - b01 = _mm_loadu_si128((__m128i*)(B[0] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C7_SHFL0), C7_MULLO), 9); - ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); - ab10 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab10); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 7), C7_SHFL0), C7_MULLO), 9); - ab00 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab00); - ab10 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab10); - - b01 = _mm_loadu_si128((__m128i*)(B[1] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C7_SHFL0), C7_MULLO), 9); - ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); - ab11 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab11); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 7), C7_SHFL0), C7_MULLO), 9); - ab01 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab01); - ab11 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab11); - - b01 = _mm_loadu_si128((__m128i*)(B[2] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C7_SHFL0), C7_MULLO), 9); - ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); - ab12 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab12); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 7), C7_SHFL0), C7_MULLO), 9); - ab02 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab02); - ab12 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab12); - - b01 = _mm_loadu_si128((__m128i*)(B[3] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C7_SHFL0), C7_MULLO), 9); - ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); - ab13 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab13); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 7), C7_SHFL0), C7_MULLO), 9); - ab03 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab03); - ab13 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab13); - } - for (; i < size; i += 8, o += 7) - { - a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[0] + o)), C7_SHFL0), C7_MULLO), 9); - a10 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[1] + o)), C7_SHFL0), C7_MULLO), 9); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[0] + o)), C7_SHFL0), C7_MULLO), 9); - ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); - ab10 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab10); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[1] + o)), C7_SHFL0), C7_MULLO), 9); - ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); - ab11 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab11); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[2] + o)), C7_SHFL0), C7_MULLO), 9); - ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); - ab12 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab12); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[3] + o)), C7_SHFL0), C7_MULLO), 9); - ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); - ab13 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab13); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); - } - - template<> void MicroCosineDistancesDirect2x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size16 = AlignLo(size, 16), o = 16; - __m128i a00, a01, a10, a11, b00, b01; - __m128i ab00 = _mm_setzero_si128(); - __m128i ab01 = _mm_setzero_si128(); - __m128i ab02 = _mm_setzero_si128(); - __m128i ab03 = _mm_setzero_si128(); - __m128i ab10 = _mm_setzero_si128(); - __m128i ab11 = _mm_setzero_si128(); - __m128i ab12 = _mm_setzero_si128(); - __m128i ab13 = _mm_setzero_si128(); - for (; i < size16; i += 16, o += 16) - { - a01 = _mm_loadu_si128((__m128i*)(A[0] + o)); - a00 = UnpackU8<0>(a01); - a01 = UnpackU8<1>(a01); - a11 = _mm_loadu_si128((__m128i*)(A[1] + o)); - a10 = UnpackU8<0>(a11); - a11 = UnpackU8<1>(a11); - - b01 = _mm_loadu_si128((__m128i*)(B[0] + o)); - b00 = UnpackU8<0>(b01); - ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); - ab10 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab10); - b00 = UnpackU8<1>(b01); - ab00 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab00); - ab10 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab10); - - b01 = _mm_loadu_si128((__m128i*)(B[1] + o)); - b00 = UnpackU8<0>(b01); - ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); - ab11 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab11); - b00 = UnpackU8<1>(b01); - ab01 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab01); - ab11 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab11); - - b01 = _mm_loadu_si128((__m128i*)(B[2] + o)); - b00 = UnpackU8<0>(b01); - ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); - ab12 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab12); - b00 = UnpackU8<1>(b01); - ab02 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab02); - ab12 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab12); - - b01 = _mm_loadu_si128((__m128i*)(B[3] + o)); - b00 = UnpackU8<0>(b01); - ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); - ab13 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab13); - b00 = UnpackU8<1>(b01); - ab03 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab03); - ab13 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab13); - } - for (; i < size; i += 8, o += 8) - { - a00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(A[0] + o))); - a10 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(A[1] + o))); - - b00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[0] + o))); - ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); - ab10 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab10); - - b00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[1] + o))); - ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); - ab11 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab11); - - b00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[2] + o))); - ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); - ab12 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab12); - - b00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[3] + o))); - ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); - ab13 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab13); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); - } - - template void MicroCosineDistancesDirect1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); - - template<> void MicroCosineDistancesDirect1x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size32 = AlignLo(size, 32), o = 16; - __m128i a0, b0; - __m128i ab00 = _mm_setzero_si128(); - __m128i ab01 = _mm_setzero_si128(); - __m128i ab02 = _mm_setzero_si128(); - __m128i ab03 = _mm_setzero_si128(); - for (; i < size32; i += 32, o += 16) - { - a0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(A[0] + o)), K8_0F); - - b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[0] + o)), K8_0F); - ab00 = _mm_add_epi32(ab00, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); - - b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[1] + o)), K8_0F); - ab01 = _mm_add_epi32(ab01, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); - - b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[2] + o)), K8_0F); - ab02 = _mm_add_epi32(ab02, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); - - b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[3] + o)), K8_0F); - ab03 = _mm_add_epi32(ab03, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); - - a0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(A[0] + o)), 4), K8_0F); - - b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[0] + o)), 4), K8_0F); - ab00 = _mm_add_epi32(ab00, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); - - b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[1] + o)), 4), K8_0F); - ab01 = _mm_add_epi32(ab01, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); - - b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[2] + o)), 4), K8_0F); - ab02 = _mm_add_epi32(ab02, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); - - b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[3] + o)), 4), K8_0F); - ab03 = _mm_add_epi32(ab03, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); - } - for (; i < size; i += 8, o += 4) - { - a0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[0] + o)), C4_SHFL0), C4_MULLO), 12); - - b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[0] + o)), C4_SHFL0), C4_MULLO), 12); - ab00 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab00); - - b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[1] + o)), C4_SHFL0), C4_MULLO), 12); - ab01 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab01); - - b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[2] + o)), C4_SHFL0), C4_MULLO), 12); - ab02 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab02); - - b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[3] + o)), C4_SHFL0), C4_MULLO), 12); - ab03 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab03); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - } - - template<> void MicroCosineDistancesDirect1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size16 = AlignLo(size, 16), o = 16; - __m128i a00, a01, b00, b01; - __m128i ab00 = _mm_setzero_si128(); - __m128i ab01 = _mm_setzero_si128(); - __m128i ab02 = _mm_setzero_si128(); - __m128i ab03 = _mm_setzero_si128(); - for (; i < size16; i += 16, o += 10) - { - a01 = _mm_loadu_si128((__m128i*)(A[0] + o)); - a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(a01, C5_SHFL0), C5_MULLO), 11); - a01 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(a01, 5), C5_SHFL0), C5_MULLO), 11); - - b01 = _mm_loadu_si128((__m128i*)(B[0] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C5_SHFL0), C5_MULLO), 11); - ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 5), C5_SHFL0), C5_MULLO), 11); - ab00 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab00); - - b01 = _mm_loadu_si128((__m128i*)(B[1] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C5_SHFL0), C5_MULLO), 11); - ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 5), C5_SHFL0), C5_MULLO), 11); - ab01 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab01); - - b01 = _mm_loadu_si128((__m128i*)(B[2] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C5_SHFL0), C5_MULLO), 11); - ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 5), C5_SHFL0), C5_MULLO), 11); - ab02 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab02); - - b01 = _mm_loadu_si128((__m128i*)(B[3] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C5_SHFL0), C5_MULLO), 11); - ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 5), C5_SHFL0), C5_MULLO), 11); - ab03 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab03); - } - for (; i < size; i += 8, o += 5) - { - a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[0] + o)), C5_SHFL0), C5_MULLO), 11); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[0] + o)), C5_SHFL0), C5_MULLO), 11); - ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[1] + o)), C5_SHFL0), C5_MULLO), 11); - ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[2] + o)), C5_SHFL0), C5_MULLO), 11); - ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[3] + o)), C5_SHFL0), C5_MULLO), 11); - ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - } - - template<> void MicroCosineDistancesDirect1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size16 = AlignLo(size, 16), o = 16; - __m128i a00, a01, b00, b01; - __m128i ab00 = _mm_setzero_si128(); - __m128i ab01 = _mm_setzero_si128(); - __m128i ab02 = _mm_setzero_si128(); - __m128i ab03 = _mm_setzero_si128(); - for (; i < size16; i += 16, o += 12) - { - a01 = _mm_loadu_si128((__m128i*)(A[0] + o)); - a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(a01, C6_SHFL0), C6_MULLO), 10); - a01 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(a01, 6), C6_SHFL0), C6_MULLO), 10); - - b01 = _mm_loadu_si128((__m128i*)(B[0] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C6_SHFL0), C6_MULLO), 10); - ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 6), C6_SHFL0), C6_MULLO), 10); - ab00 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab00); - - b01 = _mm_loadu_si128((__m128i*)(B[1] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C6_SHFL0), C6_MULLO), 10); - ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 6), C6_SHFL0), C6_MULLO), 10); - ab01 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab01); - - b01 = _mm_loadu_si128((__m128i*)(B[2] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C6_SHFL0), C6_MULLO), 10); - ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 6), C6_SHFL0), C6_MULLO), 10); - ab02 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab02); - - b01 = _mm_loadu_si128((__m128i*)(B[3] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C6_SHFL0), C6_MULLO), 10); - ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 6), C6_SHFL0), C6_MULLO), 10); - ab03 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab03); - } - for (; i < size; i += 8, o += 6) - { - a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[0] + o)), C6_SHFL0), C6_MULLO), 10); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[0] + o)), C6_SHFL0), C6_MULLO), 10); - ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[1] + o)), C6_SHFL0), C6_MULLO), 10); - ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[2] + o)), C6_SHFL0), C6_MULLO), 10); - ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[3] + o)), C6_SHFL0), C6_MULLO), 10); - ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - } - - template<> void MicroCosineDistancesDirect1x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size16 = AlignLo(size, 16), o = 16; - __m128i a00, a01, b00, b01; - __m128i ab00 = _mm_setzero_si128(); - __m128i ab01 = _mm_setzero_si128(); - __m128i ab02 = _mm_setzero_si128(); - __m128i ab03 = _mm_setzero_si128(); - for (; i < size16; i += 16, o += 14) - { - a01 = _mm_loadu_si128((__m128i*)(A[0] + o)); - a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(a01, C7_SHFL0), C7_MULLO), 9); - a01 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(a01, 7), C7_SHFL0), C7_MULLO), 9); - - b01 = _mm_loadu_si128((__m128i*)(B[0] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C7_SHFL0), C7_MULLO), 9); - ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 7), C7_SHFL0), C7_MULLO), 9); - ab00 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab00); - - b01 = _mm_loadu_si128((__m128i*)(B[1] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C7_SHFL0), C7_MULLO), 9); - ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 7), C7_SHFL0), C7_MULLO), 9); - ab01 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab01); - - b01 = _mm_loadu_si128((__m128i*)(B[2] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C7_SHFL0), C7_MULLO), 9); - ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 7), C7_SHFL0), C7_MULLO), 9); - ab02 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab02); - - b01 = _mm_loadu_si128((__m128i*)(B[3] + o)); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C7_SHFL0), C7_MULLO), 9); - ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 7), C7_SHFL0), C7_MULLO), 9); - ab03 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab03); - } - for (; i < size; i += 8, o += 7) - { - a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[0] + o)), C7_SHFL0), C7_MULLO), 9); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[0] + o)), C7_SHFL0), C7_MULLO), 9); - ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[1] + o)), C7_SHFL0), C7_MULLO), 9); - ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[2] + o)), C7_SHFL0), C7_MULLO), 9); - ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); - - b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[3] + o)), C7_SHFL0), C7_MULLO), 9); - ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - } - - template<> void MicroCosineDistancesDirect1x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size16 = AlignLo(size, 16), o = 16; - __m128i a00, a01, b00, b01; - __m128i ab00 = _mm_setzero_si128(); - __m128i ab01 = _mm_setzero_si128(); - __m128i ab02 = _mm_setzero_si128(); - __m128i ab03 = _mm_setzero_si128(); - for (; i < size16; i += 16, o += 16) - { - a01 = _mm_loadu_si128((__m128i*)(A[0] + o)); - a00 = UnpackU8<0>(a01); - a01 = UnpackU8<1>(a01); - - b01 = _mm_loadu_si128((__m128i*)(B[0] + o)); - b00 = UnpackU8<0>(b01); - ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); - b00 = UnpackU8<1>(b01); - ab00 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab00); - - b01 = _mm_loadu_si128((__m128i*)(B[1] + o)); - b00 = UnpackU8<0>(b01); - ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); - b00 = UnpackU8<1>(b01); - ab01 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab01); - - b01 = _mm_loadu_si128((__m128i*)(B[2] + o)); - b00 = UnpackU8<0>(b01); - ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); - b00 = UnpackU8<1>(b01); - ab02 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab02); - - b01 = _mm_loadu_si128((__m128i*)(B[3] + o)); - b00 = UnpackU8<0>(b01); - ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); - b00 = UnpackU8<1>(b01); - ab03 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab03); - } - for (; i < size; i += 8, o += 8) - { - a00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(A[0] + o))); - - b00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[0] + o))); - ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); - - b00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[1] + o))); - ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); - - b00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[2] + o))); - ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); - - b00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[3] + o))); - ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - } - - template void MacroCosineDistancesDirect(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t M2 = AlignLoAny(M, 2); - size_t N4 = AlignLo(N, 4); - size_t i = 0; - for (; i < M2; i += 2) - { - size_t j = 0; - for (; j < N4; j += 4) - MicroCosineDistancesDirect2x4(A + i, B + j, size, distances + j, stride); - for (; j < N; j += 1) - { - CosineDistance(A[i + 0], B[j], size, distances + j + 0 * stride); - CosineDistance(A[i + 1], B[j], size, distances + j + 1 * stride); - } - distances += 2 * stride; - } - for (; i < M; i++) - { - size_t j = 0; - for (; j < N4; j += 4) - MicroCosineDistancesDirect1x4(A + i, B + j, size, distances + j, stride); - for (; j < N; j += 1) - CosineDistance(A[i], B[j], size, distances + j); - distances += 1 * stride; - } - } - - //------------------------------------------------------------------------------------------------- - static void UnpackNormA(size_t count, const uint8_t* const* src, float* dst, size_t stride) { for (size_t i = 0; i < count; ++i) @@ -1443,284 +111,35 @@ namespace Simd //------------------------------------------------------------------------------------------------- - static void UnpackDataA8(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride) - { - size_t size16 = AlignLo(size, 16); - for (size_t i = 0, j; i < count; i++) - { - const uint8_t* ps = src[i] + 16; - uint16_t* pd = (uint16_t*)dst + i * size; - for (j = 0; j < size16; j += 16, ps += 16, pd += 16) - { - __m128i s = _mm_loadu_si128((__m128i*)ps); - _mm_storeu_si128((__m128i*)pd + 0, UnpackU8<0>(s)); - _mm_storeu_si128((__m128i*)pd + 1, UnpackU8<1>(s)); - } - for (; j < size; j += 8, ps += 8, pd += 8) - { - __m128i s = _mm_loadl_epi64((__m128i*)ps); - _mm_storeu_si128((__m128i*)pd, UnpackU8<0>(s)); - } - } - } - - //------------------------------------------------------------------------------------------------- - - SIMD_INLINE void UnpackDataB8x4(const uint8_t* const* src, size_t offset, uint8_t* dst) - { - __m128i a0 = UnpackU8<0>(_mm_loadl_epi64((__m128i*)(src[0] + offset))); - __m128i a1 = UnpackU8<0>(_mm_loadl_epi64((__m128i*)(src[1] + offset))); - __m128i a2 = UnpackU8<0>(_mm_loadl_epi64((__m128i*)(src[2] + offset))); - __m128i a3 = UnpackU8<0>(_mm_loadl_epi64((__m128i*)(src[3] + offset))); - __m128i b0 = _mm_unpacklo_epi32(a0, a2); - __m128i b1 = _mm_unpacklo_epi32(a1, a3); - __m128i b2 = _mm_unpackhi_epi32(a0, a2); - __m128i b3 = _mm_unpackhi_epi32(a1, a3); - _mm_storeu_si128((__m128i*)dst + 0, _mm_unpacklo_epi32(b0, b1)); - _mm_storeu_si128((__m128i*)dst + 2, _mm_unpackhi_epi32(b0, b1)); - _mm_storeu_si128((__m128i*)dst + 4, _mm_unpacklo_epi32(b2, b3)); - _mm_storeu_si128((__m128i*)dst + 6, _mm_unpackhi_epi32(b2, b3)); - } - - static void UnpackDataB8(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride) - { - size_t count8 = AlignLo(count, 8), i; - for (i = 0, size += 16; i < count8; i += 8, src += 8) - { - for (size_t j = 16; j < size; j += 8, dst += 8 * A) - { - UnpackDataB8x4(src + 0, j, dst + 0); - UnpackDataB8x4(src + 4, j, dst + A); - } - } - if (i < count) - { - const uint8_t* _src[8]; - for (size_t j = 0; j < 8; i++, j++) - _src[j] = i < count ? *src++ : src[-1]; - for (size_t j = 16; j < size; j += 8, dst += 8 * A) - { - UnpackDataB8x4(_src + 0, j, dst + 0); - UnpackDataB8x4(_src + 4, j, dst + A); - } - } - } - - //------------------------------------------------------------------------------------------------- - - SIMD_INLINE __m128i Set2(const int16_t* src) - { - return _mm_set1_epi32(*(int32_t*)src); - } - - SIMD_INLINE void Madd2(__m128i& ab, __m128i a, __m128i b) - { - ab = _mm_add_epi32(ab, _mm_madd_epi16(a, b)); - } - - template void Correlation16_2xM(size_t N, size_t K, const int16_t* ad0, const int16_t* bd, const float *an, const float *bn, size_t bnStride, float* distances, size_t stride) - { - __m128i ab00, ab01, ab10, ab11, ab20, ab21, ab30, ab31, ab40, ab41, ab50, ab51, a0, b0, b1; - const int16_t* ad1 = ad0 + 1 * K; - const int16_t* ad2 = ad0 + 2 * K; - const int16_t* ad3 = ad0 + 3 * K; - const int16_t* ad4 = ad0 + 4 * K; - const int16_t* ad5 = ad0 + 5 * K; - if (N > 4) - { - if (M > 0) ab00 = _mm_setzero_si128(), ab01 = _mm_setzero_si128(); - if (M > 1) ab10 = _mm_setzero_si128(), ab11 = _mm_setzero_si128(); - if (M > 2) ab20 = _mm_setzero_si128(), ab21 = _mm_setzero_si128(); - if (M > 3) ab30 = _mm_setzero_si128(), ab31 = _mm_setzero_si128(); - if (M > 4) ab40 = _mm_setzero_si128(), ab41 = _mm_setzero_si128(); - if (M > 5) ab50 = _mm_setzero_si128(), ab51 = _mm_setzero_si128(); - for (size_t k = 0; k < K; k += 2) - { - b0 = _mm_loadu_si128((__m128i*)bd + 0); - b1 = _mm_loadu_si128((__m128i*)bd + 1); - if (M > 0) a0 = Set2(ad0 + k), Madd2(ab00, a0, b0), Madd2(ab01, a0, b1); - if (M > 1) a0 = Set2(ad1 + k), Madd2(ab10, a0, b0), Madd2(ab11, a0, b1); - if (M > 2) a0 = Set2(ad2 + k), Madd2(ab20, a0, b0), Madd2(ab21, a0, b1); - if (M > 3) a0 = Set2(ad3 + k), Madd2(ab30, a0, b0), Madd2(ab31, a0, b1); - if (M > 4) a0 = Set2(ad4 + k), Madd2(ab40, a0, b0), Madd2(ab41, a0, b1); - if (M > 5) a0 = Set2(ad5 + k), Madd2(ab50, a0, b0), Madd2(ab51, a0, b1); - bd += 16; - } - if (N == 8) - { - if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab01, distances + 4), an += 4, distances += stride; - if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab11, distances + 4), an += 4, distances += stride; - if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab21, distances + 4), an += 4, distances += stride; - if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab31, distances + 4), an += 4, distances += stride; - if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab41, distances + 4), an += 4, distances += stride; - if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab51, distances + 4), an += 4, distances += stride; - } - else - { - if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab01, distances + 4, N - 4), an += 4, distances += stride; - if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab11, distances + 4, N - 4), an += 4, distances += stride; - if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab21, distances + 4, N - 4), an += 4, distances += stride; - if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab31, distances + 4, N - 4), an += 4, distances += stride; - if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab41, distances + 4, N - 4), an += 4, distances += stride; - if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab51, distances + 4, N - 4), an += 4, distances += stride; - } - } - else - { - if (M > 0) ab00 = _mm_setzero_si128(); - if (M > 1) ab10 = _mm_setzero_si128(); - if (M > 2) ab20 = _mm_setzero_si128(); - if (M > 3) ab30 = _mm_setzero_si128(); - if (M > 4) ab40 = _mm_setzero_si128(); - if (M > 5) ab50 = _mm_setzero_si128(); - for (size_t k = 0; k < K; k += 2) - { - b0 = _mm_loadu_si128((__m128i*)bd + 0); - if (M > 0) a0 = Set2(ad0 + k), Madd2(ab00, a0, b0); - if (M > 1) a0 = Set2(ad1 + k), Madd2(ab10, a0, b0); - if (M > 2) a0 = Set2(ad2 + k), Madd2(ab20, a0, b0); - if (M > 3) a0 = Set2(ad3 + k), Madd2(ab30, a0, b0); - if (M > 4) a0 = Set2(ad4 + k), Madd2(ab40, a0, b0); - if (M > 5) a0 = Set2(ad5 + k), Madd2(ab50, a0, b0); - bd += 16; - } - if (N == 4) - { - if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), an += 4, distances += stride; - if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), an += 4, distances += stride; - if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), an += 4, distances += stride; - if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), an += 4, distances += stride; - if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), an += 4, distances += stride; - if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0), an += 4, distances += stride; - } - else - { - if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0, N), an += 4, distances += stride; - if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0, N), an += 4, distances += stride; - if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0, N), an += 4, distances += stride; - if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0, N), an += 4, distances += stride; - if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0, N), an += 4, distances += stride; - if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0, N), an += 4, distances += stride; - } - } - } - - typedef void(*Correlation16_2xM_Ptr)(size_t N, size_t K, const int16_t* ad0, const int16_t* bd, const float* an, const float* bn, size_t bnStride, float* distances, size_t stride); - - SIMD_INLINE Correlation16_2xM_Ptr GetCorrelation16_2xM(size_t M) - { - switch (M) - { - case 0: return NULL; - case 1: return Correlation16_2xM<1>; - case 2: return Correlation16_2xM<2>; - case 3: return Correlation16_2xM<3>; - case 4: return Correlation16_2xM<4>; - case 5: return Correlation16_2xM<5>; - case 6: return Correlation16_2xM<6>; - } - assert(0); - return NULL; - } - - void MacroCorrelation16(size_t M, size_t N, size_t K, const uint8_t* ad, const float* an, const uint8_t* bd, const float* bn, float* distances, size_t stride) - { - size_t M6 = AlignLoAny(M, 6); - Correlation16_2xM_Ptr correlation_2x6 = GetCorrelation16_2xM(6); - Correlation16_2xM_Ptr correlation_2xT = GetCorrelation16_2xM(M - M6); - const int16_t* a = (int16_t*)ad; - const int16_t* b = (int16_t*)bd; - for (size_t j = 0; j < N; j += 8) - { - size_t dN = Simd::Min(8, N - j); - size_t i = 0; - for (; i < M6; i += 6) - correlation_2x6(dN, K, a + i * K, b, an + i * 4, bn, N, distances + i * stride, stride); - if(i < M) - correlation_2xT(dN, K, a + i * K, b, an + i * 4, bn, N, distances + i * stride, stride); - b += K * 8; - bn += 8; - distances += 8; - } - } - - //------------------------------------------------------------------------------------------------- - DescrInt::DescrInt(size_t size, size_t depth) : Base::DescrInt(size, depth) { _minMax32f = MinMax32f; _minMax16f = MinMax16f; - _unpackNormA = UnpackNormA; - _unpackNormB = UnpackNormB; + _encode32f = GetEncode32f(_depth); + _encode16f = GetEncode16f(_depth); + + _decode32f = GetDecode32f(_depth); + _decode16f = GetDecode16f(_depth); + + _cosineDistance = GetCosineDistance(_depth); + _macroCosineDistancesDirect = GetMacroCosineDistancesDirect(_depth); _microMd = 2; _microNd = 4; + + _unpackNormA = UnpackNormA; + _unpackNormB = UnpackNormB; + _unpackDataA = GetUnpackData(_depth, false); + _unpackDataB = GetUnpackData(_depth, true); + _macroCosineDistancesUnpack = GetMacroCosineDistancesUnpack(_depth); _unpSize = _size * (_depth == 8 ? 2 : 1); - _microMu = 5; + _microMu = _depth == 8 ? 6 : 5; _microNu = 8; - switch (depth) - { - case 4: - { - _encode32f = Encode32f4; - _encode16f = Encode16f4; - _decode32f = Decode32f4; - _decode16f = Decode16f4; - _cosineDistance = Sse41::CosineDistance<4>; - _macroCosineDistancesDirect = Sse41::MacroCosineDistancesDirect<4>; - break; - } - case 5: - { - _encode32f = Encode32f5; - _encode16f = Encode16f5; - _decode32f = Decode32f5; - _decode16f = Decode16f5; - _cosineDistance = Sse41::CosineDistance<5>; - _macroCosineDistancesDirect = Sse41::MacroCosineDistancesDirect<5>; - break; - } - case 6: - { - _encode32f = Encode32f6; - _encode16f = Encode16f6; - _decode32f = Decode32f6; - _decode16f = Decode16f6; - _cosineDistance = Sse41::CosineDistance<6>; - _macroCosineDistancesDirect = Sse41::MacroCosineDistancesDirect<6>; - break; - } - case 7: - { - _encode32f = Encode32f7; - _encode16f = Encode16f7; - _decode32f = Decode32f7; - _decode16f = Decode16f7; - _cosineDistance = Sse41::CosineDistance<7>; - _macroCosineDistancesDirect = Sse41::MacroCosineDistancesDirect<7>; - break; - } - case 8: - { - _encode32f = Encode32f8; - _encode16f = Encode16f8; - _decode32f = Decode32f8; - _decode16f = Decode16f8; - _cosineDistance = Sse41::CosineDistance<8>; - _macroCosineDistancesDirect = Sse41::MacroCosineDistancesDirect<8>; - _unpackDataA = UnpackDataA8; - _unpackDataB = UnpackDataB8; - _macroCosineDistancesUnpack = MacroCorrelation16; - break; - } - default: - assert(0); - } } void DescrInt::CosineDistancesMxNa(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, float* distances) const { - if(_unpSize * _microNu > Base::AlgCacheL1() || N * 2 < _microNu || 1) + if(_unpSize * _microNu > Base::AlgCacheL1() || N * 2 < _microNu || _depth < 7 || _depth == 8) CosineDistancesDirect(M, N, A, B, distances); else CosineDistancesUnpack(M, N, A, B, distances); diff --git a/src/Simd/SimdSse41DescrIntDec.cpp b/src/Simd/SimdSse41DescrIntDec.cpp new file mode 100644 index 0000000000..8c680648e5 --- /dev/null +++ b/src/Simd/SimdSse41DescrIntDec.cpp @@ -0,0 +1,232 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2023 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#include "Simd/SimdMemory.h" +#include "Simd/SimdStore.h" +#include "Simd/SimdExtract.h" +#include "Simd/SimdArray.h" +#include "Simd/SimdUnpack.h" +#include "Simd/SimdDescrInt.h" +#include "Simd/SimdDescrIntCommon.h" +#include "Simd/SimdCpu.h" +#include "Simd/SimdFloat16.h" + +namespace Simd +{ +#ifdef SIMD_SSE41_ENABLE + namespace Sse41 + { + static void Decode32f4(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m128 _scale = _mm_set1_ps(scale); + __m128 _shift = _mm_set1_ps(shift); + for (size_t i = 0; i < size; i += 8) + { + __m128i s4 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s4, C4_SHFL0), C4_MULLO), 12); + _mm_storeu_ps(dst + 0, _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<0>(s16)), _scale), _shift)); + _mm_storeu_ps(dst + 4, _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<1>(s16)), _scale), _shift)); + src += 4; + dst += 8; + } + } + + static void Decode32f5(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m128 _scale = _mm_set1_ps(scale); + __m128 _shift = _mm_set1_ps(shift); + for (size_t i = 0; i < size; i += 8) + { + __m128i s5 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s5, C5_SHFL0), C5_MULLO), 11); + _mm_storeu_ps(dst + 0, _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<0>(s16)), _scale), _shift)); + _mm_storeu_ps(dst + 4, _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<1>(s16)), _scale), _shift)); + src += 5; + dst += 8; + } + } + + static void Decode32f6(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m128 _scale = _mm_set1_ps(scale); + __m128 _shift = _mm_set1_ps(shift); + for (size_t i = 0; i < size; i += 8) + { + __m128i s6 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s6, C6_SHFL0), C6_MULLO), 10); + _mm_storeu_ps(dst + 0, _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<0>(s16)), _scale), _shift)); + _mm_storeu_ps(dst + 4, _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<1>(s16)), _scale), _shift)); + src += 6; + dst += 8; + } + } + + static void Decode32f7(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m128 _scale = _mm_set1_ps(scale); + __m128 _shift = _mm_set1_ps(shift); + for (size_t i = 0; i < size; i += 8) + { + __m128i s7 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s7, C7_SHFL0), C7_MULLO), 9); + _mm_storeu_ps(dst + 0, _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<0>(s16)), _scale), _shift)); + _mm_storeu_ps(dst + 4, _mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<1>(s16)), _scale), _shift)); + src += 7; + dst += 8; + } + } + + static void Decode32f8(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m128 _scale = _mm_set1_ps(scale); + __m128 _shift = _mm_set1_ps(shift); + size_t i = 0; + for (; i < size; i += 4) + { + __m128 _src = _mm_cvtepi32_ps(_mm_cvtepu8_epi32(_mm_cvtsi32_si128(*(uint32_t*)(src + i)))); + _mm_storeu_ps(dst + i, _mm_add_ps(_mm_mul_ps(_src, _scale), _shift)); + } + } + + //------------------------------------------------------------------------------------------------- + + static void Decode16f4(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m128 _scale = _mm_set1_ps(scale); + __m128 _shift = _mm_set1_ps(shift); + for (size_t i = 0; i < size; i += 8) + { + __m128i s4 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s4, C4_SHFL0), C4_MULLO), 12); + __m128i d0 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<0>(s16)), _scale), _shift)); + __m128i d4 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<1>(s16)), _scale), _shift)); + _mm_storeu_si128((__m128i*)dst, _mm_packus_epi32(d0, d4)); + src += 4; + dst += 8; + } + } + + static void Decode16f5(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m128 _scale = _mm_set1_ps(scale); + __m128 _shift = _mm_set1_ps(shift); + for (size_t i = 0; i < size; i += 8) + { + __m128i s5 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s5, C5_SHFL0), C5_MULLO), 11); + __m128i d0 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<0>(s16)), _scale), _shift)); + __m128i d4 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<1>(s16)), _scale), _shift)); + _mm_storeu_si128((__m128i*)dst, _mm_packus_epi32(d0, d4)); + src += 5; + dst += 8; + } + } + + static void Decode16f6(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m128 _scale = _mm_set1_ps(scale); + __m128 _shift = _mm_set1_ps(shift); + for (size_t i = 0; i < size; i += 8) + { + __m128i s6 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s6, C6_SHFL0), C6_MULLO), 10); + __m128i d0 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<0>(s16)), _scale), _shift)); + __m128i d4 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<1>(s16)), _scale), _shift)); + _mm_storeu_si128((__m128i*)dst, _mm_packus_epi32(d0, d4)); + src += 6; + dst += 8; + } + } + + static void Decode16f7(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m128 _scale = _mm_set1_ps(scale); + __m128 _shift = _mm_set1_ps(shift); + for (size_t i = 0; i < size; i += 8) + { + __m128i s7 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s7, C7_SHFL0), C7_MULLO), 9); + __m128i d0 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<0>(s16)), _scale), _shift)); + __m128i d4 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(_mm_cvtepi32_ps(UnpackU16<1>(s16)), _scale), _shift)); + _mm_storeu_si128((__m128i*)dst, _mm_packus_epi32(d0, d4)); + src += 7; + dst += 8; + } + } + + static void Decode16f8(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m128 _scale = _mm_set1_ps(scale); + __m128 _shift = _mm_set1_ps(shift); + size_t i = 0; + for (; i < size; i += 8) + { + __m128i s8 = _mm_loadl_epi64((__m128i*)(src + i)); + __m128 s0 = _mm_cvtepi32_ps(_mm_cvtepu8_epi32(_mm_srli_si128(s8, 0))); + __m128 s4 = _mm_cvtepi32_ps(_mm_cvtepu8_epi32(_mm_srli_si128(s8, 4))); + __m128i d0 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(s0, _scale), _shift)); + __m128i d4 = Float32ToFloat16(_mm_add_ps(_mm_mul_ps(s4, _scale), _shift)); + _mm_storeu_si128((__m128i*)(dst + i), _mm_packus_epi32(d0, d4)); + } + } + + //------------------------------------------------------------------------------------------------- + + Base::DescrInt::Decode32fPtr GetDecode32f(size_t depth) + { + switch (depth) + { + case 4: return Decode32f4; + case 5: return Decode32f5; + case 6: return Decode32f6; + case 7: return Decode32f7; + case 8: return Decode32f8; + default: assert(0); return NULL; + } + } + + Base::DescrInt::Decode16fPtr GetDecode16f(size_t depth) + { + switch (depth) + { + case 4: return Decode16f4; + case 5: return Decode16f5; + case 6: return Decode16f6; + case 7: return Decode16f7; + case 8: return Decode16f8; + default: assert(0); return NULL; + } + } + } +#endif +} diff --git a/src/Simd/SimdSse41DescrIntEnc.cpp b/src/Simd/SimdSse41DescrIntEnc.cpp new file mode 100644 index 0000000000..92411100ef --- /dev/null +++ b/src/Simd/SimdSse41DescrIntEnc.cpp @@ -0,0 +1,382 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2023 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#include "Simd/SimdMemory.h" +#include "Simd/SimdStore.h" +#include "Simd/SimdExtract.h" +#include "Simd/SimdArray.h" +#include "Simd/SimdUnpack.h" +#include "Simd/SimdDescrInt.h" +#include "Simd/SimdDescrIntCommon.h" +#include "Simd/SimdCpu.h" +#include "Simd/SimdFloat16.h" + +namespace Simd +{ +#ifdef SIMD_SSE41_ENABLE + namespace Sse41 + { + SIMD_INLINE __m128i Encode32f(__m128 src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + __m128i value = _mm_cvtps_epi32(_mm_mul_ps(_mm_sub_ps(src, min), scale)); + sum = _mm_add_epi32(value, sum); + sqsum = _mm_add_epi32(_mm_madd_epi16(value, value), sqsum); + return value; + } + + SIMD_INLINE __m128i Encode32f(const float* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + return Encode32f(_mm_loadu_ps(src), scale, min, sum, sqsum); + } + + static SIMD_INLINE __m128i Encode32f4(const float* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + __m128i i0 = Encode32f(src + 0, scale, min, sum, sqsum); + __m128i i4 = Encode32f(src + 4, scale, min, sum, sqsum); + return _mm_srli_epi32(_mm_mullo_epi16(_mm_packus_epi32(i0, i4), E4_MULLO), 12); + } + + static SIMD_INLINE __m128i Encode32f4x8(const float* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + __m128i s0 = Encode32f4(src + 0 * 8, scale, min, sum, sqsum); + return _mm_packus_epi16(_mm_packus_epi32(s0, K_ZERO), K_ZERO); + } + + static SIMD_INLINE __m128i Encode32f4x16(const float* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + __m128i s0 = Encode32f4(src + 0 * 8, scale, min, sum, sqsum); + __m128i s1 = Encode32f4(src + 1 * 8, scale, min, sum, sqsum); + return _mm_packus_epi16(_mm_packus_epi32(s0, s1), K_ZERO); + } + + static void Encode32f4(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, size16 = AlignLo(size, 16); + __m128 _scale = _mm_set1_ps(scale); + __m128 _min = _mm_set1_ps(min); + __m128i _sum = _mm_setzero_si128(); + __m128i _sqsum = _mm_setzero_si128(); + for (; i < size16; i += 16, src += 16, dst += 8) + _mm_storel_epi64((__m128i*)dst, Encode32f4x16(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 4) + *(uint32_t*)(dst) = _mm_extract_epi32(Encode32f4x8(src, _scale, _min, _sum, _sqsum), 0); + sum = ExtractInt32Sum(_sum); + sqsum = ExtractInt32Sum(_sqsum); + } + + static SIMD_INLINE __m128i Encode32f5(const float* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + __m128i i0 = Encode32f(src + 0, scale, min, sum, sqsum); + __m128i i4 = Encode32f(src + 4, scale, min, sum, sqsum); + __m128i s0 = _mm_mullo_epi16(_mm_packus_epi32(i0, i4), E5_MULLO); + return _mm_or_si128(_mm_or_si128(_mm_shuffle_epi8(s0, E5_SHFL0), _mm_shuffle_epi8(s0, E5_SHFL1)), _mm_shuffle_epi8(s0, E5_SHFL2)); + } + + static void Encode32f5(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, main = size - 8; + __m128 _scale = _mm_set1_ps(scale); + __m128 _min = _mm_set1_ps(min); + __m128i _sum = _mm_setzero_si128(); + __m128i _sqsum = _mm_setzero_si128(); + for (; i < main; i += 8, src += 8, dst += 5) + _mm_storel_epi64((__m128i*)dst, Encode32f5(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 5) + { + __m128i d0 = Encode32f5(src, _scale, _min, _sum, _sqsum); + *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); + *(uint8_t*)(dst + 4) = _mm_extract_epi8(d0, 4); + } + sum = ExtractInt32Sum(_sum); + sqsum = ExtractInt32Sum(_sqsum); + } + + static SIMD_INLINE __m128i Encode32f6(const float* src, __m128 scale, __m128 min, __m128i & sum, __m128i & sqsum) + { + __m128i i0 = Encode32f(src + 0, scale, min, sum, sqsum); + __m128i i4 = Encode32f(src + 4, scale, min, sum, sqsum); + __m128i s0 = _mm_mullo_epi16(_mm_packus_epi32(i0, i4), E6_MULLO); + return _mm_or_si128(_mm_shuffle_epi8(s0, E6_SHFL0), _mm_shuffle_epi8(s0, E6_SHFL1)); + } + + static void Encode32f6(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, main = size - 8; + __m128 _scale = _mm_set1_ps(scale); + __m128 _min = _mm_set1_ps(min); + __m128i _sum = _mm_setzero_si128(); + __m128i _sqsum = _mm_setzero_si128(); + for (; i < main; i += 8, src += 8, dst += 6) + _mm_storel_epi64((__m128i*)dst, Encode32f6(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 6) + { + __m128i d0 = Encode32f6(src, _scale, _min, _sum, _sqsum); + *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); + *(uint16_t*)(dst + 4) = _mm_extract_epi16(d0, 2); + } + sum = ExtractInt32Sum(_sum); + sqsum = ExtractInt32Sum(_sqsum); + } + + static SIMD_INLINE __m128i Encode32f7(const float* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + __m128i i0 = Encode32f(src + 0, scale, min, sum, sqsum); + __m128i i4 = Encode32f(src + 4, scale, min, sum, sqsum); + __m128i s0 = _mm_mullo_epi16(_mm_packus_epi32(i0, i4), E7_MULLO); + return _mm_or_si128(_mm_shuffle_epi8(s0, E7_SHFL0), _mm_shuffle_epi8(s0, E7_SHFL1)); + } + + static void Encode32f7(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, main = size - 8; + __m128 _scale = _mm_set1_ps(scale); + __m128 _min = _mm_set1_ps(min); + __m128i _sum = _mm_setzero_si128(); + __m128i _sqsum = _mm_setzero_si128(); + for (; i < main; i += 8, src += 8, dst += 7) + _mm_storel_epi64((__m128i*)dst, Encode32f7(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 7) + { + __m128i d0 = Encode32f7(src, _scale, _min, _sum, _sqsum); + *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); + *(uint16_t*)(dst + 4) = _mm_extract_epi16(d0, 2); + *(uint8_t*)(dst + 6) = _mm_extract_epi8(d0, 6); + } + sum = ExtractInt32Sum(_sum); + sqsum = ExtractInt32Sum(_sqsum); + } + + static void Encode32f8(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t sizeA = AlignLo(size, A), i = 0; + __m128 _scale = _mm_set1_ps(scale); + __m128 _min = _mm_set1_ps(min); + __m128i _sum = _mm_setzero_si128(); + __m128i _sqsum = _mm_setzero_si128(); + for (; i < sizeA; i += A) + { + __m128i d0 = Encode32f(src + i + 0 * F, _scale, _min, _sum, _sqsum); + __m128i d1 = Encode32f(src + i + 1 * F, _scale, _min, _sum, _sqsum); + __m128i d2 = Encode32f(src + i + 2 * F, _scale, _min, _sum, _sqsum); + __m128i d3 = Encode32f(src + i + 3 * F, _scale, _min, _sum, _sqsum); + _mm_storeu_si128((__m128i*)(dst + i), _mm_packus_epi16(_mm_packus_epi32(d0, d1), _mm_packus_epi32(d2, d3))); + } + for (; i < size; i += F) + { + __m128i d0 = Encode32f(src + i, _scale, _min, _sum, _sqsum); + *(uint32_t*)(dst + i) = _mm_cvtsi128_si32(_mm_packus_epi16(_mm_packus_epi32(d0, _mm_setzero_si128()), _mm_setzero_si128())); + } + sum = ExtractInt32Sum(_sum); + sqsum = ExtractInt32Sum(_sqsum); + } + + //------------------------------------------------------------------------------------------------- + + static SIMD_INLINE __m128i Encode16f4(const uint16_t* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + __m128i u0 = _mm_loadu_si128((__m128i*)(src)); + __m128i i0 = Encode32f(Float16ToFloat32(UnpackU16<0>(u0)), scale, min, sum, sqsum); + __m128i i4 = Encode32f(Float16ToFloat32(UnpackU16<1>(u0)), scale, min, sum, sqsum); + return _mm_srli_epi32(_mm_mullo_epi16(_mm_packus_epi32(i0, i4), E4_MULLO), 12); + } + + static SIMD_INLINE __m128i Encode16f4x8(const uint16_t* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + __m128i s0 = Encode16f4(src + 0 * 8, scale, min, sum, sqsum); + return _mm_packus_epi16(_mm_packus_epi32(s0, K_ZERO), K_ZERO); + } + + static SIMD_INLINE __m128i Encode16f4x16(const uint16_t* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + __m128i s0 = Encode16f4(src + 0 * 8, scale, min, sum, sqsum); + __m128i s1 = Encode16f4(src + 1 * 8, scale, min, sum, sqsum); + return _mm_packus_epi16(_mm_packus_epi32(s0, s1), K_ZERO); + } + + static void Encode16f4(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, size16 = AlignLo(size, 16); + __m128 _scale = _mm_set1_ps(scale); + __m128 _min = _mm_set1_ps(min); + __m128i _sum = _mm_setzero_si128(); + __m128i _sqsum = _mm_setzero_si128(); + for (; i < size16; i += 16, src += 16, dst += 8) + _mm_storel_epi64((__m128i*)dst, Encode16f4x16(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 4) + *(uint32_t*)(dst) = _mm_extract_epi32(Encode16f4x8(src, _scale, _min, _sum, _sqsum), 0); + sum = ExtractInt32Sum(_sum); + sqsum = ExtractInt32Sum(_sqsum); + } + + static SIMD_INLINE __m128i Encode16f5(const uint16_t* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + __m128i u0 = _mm_loadu_si128((__m128i*)(src)); + __m128i i0 = Encode32f(Float16ToFloat32(UnpackU16<0>(u0)), scale, min, sum, sqsum); + __m128i i4 = Encode32f(Float16ToFloat32(UnpackU16<1>(u0)), scale, min, sum, sqsum); + __m128i s0 = _mm_mullo_epi16(_mm_packus_epi32(i0, i4), E5_MULLO); + return _mm_or_si128(_mm_or_si128(_mm_shuffle_epi8(s0, E5_SHFL0), _mm_shuffle_epi8(s0, E5_SHFL1)), _mm_shuffle_epi8(s0, E5_SHFL2)); + } + + static void Encode16f5(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, main = size - 8; + __m128 _scale = _mm_set1_ps(scale); + __m128 _min = _mm_set1_ps(min); + __m128i _sum = _mm_setzero_si128(); + __m128i _sqsum = _mm_setzero_si128(); + for (; i < main; i += 8, src += 8, dst += 5) + _mm_storel_epi64((__m128i*)dst, Encode16f5(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 5) + { + __m128i d0 = Encode16f5(src, _scale, _min, _sum, _sqsum); + *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); + *(uint8_t*)(dst + 4) = _mm_extract_epi8(d0, 4); + } + sum = ExtractInt32Sum(_sum); + sqsum = ExtractInt32Sum(_sqsum); + } + + static SIMD_INLINE __m128i Encode16f6(const uint16_t* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + __m128i u0 = _mm_loadu_si128((__m128i*)(src)); + __m128i i0 = Encode32f(Float16ToFloat32(UnpackU16<0>(u0)), scale, min, sum, sqsum); + __m128i i4 = Encode32f(Float16ToFloat32(UnpackU16<1>(u0)), scale, min, sum, sqsum); + __m128i s0 = _mm_mullo_epi16(_mm_packus_epi32(i0, i4), E6_MULLO); + return _mm_or_si128(_mm_shuffle_epi8(s0, E6_SHFL0), _mm_shuffle_epi8(s0, E6_SHFL1)); + } + + static void Encode16f6(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, main = size - 8; + __m128 _scale = _mm_set1_ps(scale); + __m128 _min = _mm_set1_ps(min); + __m128i _sum = _mm_setzero_si128(); + __m128i _sqsum = _mm_setzero_si128(); + for (; i < main; i += 8, src += 8, dst += 6) + _mm_storel_epi64((__m128i*)dst, Encode16f6(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 6) + { + __m128i d0 = Encode16f6(src, _scale, _min, _sum, _sqsum); + *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); + *(uint16_t*)(dst + 4) = _mm_extract_epi16(d0, 2); + } + sum = ExtractInt32Sum(_sum); + sqsum = ExtractInt32Sum(_sqsum); + } + + static SIMD_INLINE __m128i Encode16f7(const uint16_t* src, __m128 scale, __m128 min, __m128i& sum, __m128i& sqsum) + { + __m128i u0 = _mm_loadu_si128((__m128i*)(src)); + __m128i i0 = Encode32f(Float16ToFloat32(UnpackU16<0>(u0)), scale, min, sum, sqsum); + __m128i i4 = Encode32f(Float16ToFloat32(UnpackU16<1>(u0)), scale, min, sum, sqsum); + __m128i s0 = _mm_mullo_epi16(_mm_packus_epi32(i0, i4), E7_MULLO); + return _mm_or_si128(_mm_shuffle_epi8(s0, E7_SHFL0), _mm_shuffle_epi8(s0, E7_SHFL1)); + } + + static void Encode16f7(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, main = size - 8; + __m128 _scale = _mm_set1_ps(scale); + __m128 _min = _mm_set1_ps(min); + __m128i _sum = _mm_setzero_si128(); + __m128i _sqsum = _mm_setzero_si128(); + for (; i < main; i += 8, src += 8, dst += 7) + _mm_storel_epi64((__m128i*)dst, Encode16f7(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 7) + { + __m128i d0 = Encode16f7(src, _scale, _min, _sum, _sqsum); + *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); + *(uint16_t*)(dst + 4) = _mm_extract_epi16(d0, 2); + *(uint8_t*)(dst + 6) = _mm_extract_epi8(d0, 6); + } + sum = ExtractInt32Sum(_sum); + sqsum = ExtractInt32Sum(_sqsum); + } + + static void Encode16f8(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t sizeA = AlignLo(size, A), i = 0; + __m128 _scale = _mm_set1_ps(scale); + __m128 _min = _mm_set1_ps(min); + __m128i _sum = _mm_setzero_si128(); + __m128i _sqsum = _mm_setzero_si128(); + for (; i < sizeA; i += A) + { + __m128i u0 = _mm_loadu_si128((__m128i*)(src + i + 0 * F)); + __m128i d0 = Encode32f(Float16ToFloat32(UnpackU16<0>(u0)), _scale, _min, _sum, _sqsum); + __m128i d1 = Encode32f(Float16ToFloat32(UnpackU16<1>(u0)), _scale, _min, _sum, _sqsum); + __m128i u2 = _mm_loadu_si128((__m128i*)(src + i + 2 * F)); + __m128i d2 = Encode32f(Float16ToFloat32(UnpackU16<0>(u2)), _scale, _min, _sum, _sqsum); + __m128i d3 = Encode32f(Float16ToFloat32(UnpackU16<1>(u2)), _scale, _min, _sum, _sqsum); + _mm_storeu_si128((__m128i*)(dst + i), _mm_packus_epi16(_mm_packus_epi32(d0, d1), _mm_packus_epi32(d2, d3))); + } + for (; i < size; i += F) + { + __m128i u0 = _mm_loadl_epi64((__m128i*)(src + i)); + __m128i d0 = Encode32f(Float16ToFloat32(UnpackU16<0>(u0)), _scale, _min, _sum, _sqsum); + *(uint32_t*)(dst + i) = _mm_cvtsi128_si32(_mm_packus_epi16(_mm_packus_epi32(d0, _mm_setzero_si128()), _mm_setzero_si128())); + } + sum = ExtractInt32Sum(_sum); + sqsum = ExtractInt32Sum(_sqsum); + } + + //------------------------------------------------------------------------------------------------- + + Base::DescrInt::Encode32fPtr GetEncode32f(size_t depth) + { + switch (depth) + { + case 4: return Encode32f4; + case 5: return Encode32f5; + case 6: return Encode32f6; + case 7: return Encode32f7; + case 8: return Encode32f8; + default: assert(0); return NULL; + } + } + + Base::DescrInt::Encode16fPtr GetEncode16f(size_t depth) + { + switch (depth) + { + case 4: return Encode16f4; + case 5: return Encode16f5; + case 6: return Encode16f6; + case 7: return Encode16f7; + case 8: return Encode16f8; + default: assert(0); return NULL; + } + } + } +#endif +} diff --git a/src/Simd/SimdSse41DescrIntScd.cpp b/src/Simd/SimdSse41DescrIntScd.cpp new file mode 100644 index 0000000000..bb8592f254 --- /dev/null +++ b/src/Simd/SimdSse41DescrIntScd.cpp @@ -0,0 +1,916 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2023 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#include "Simd/SimdMemory.h" +#include "Simd/SimdStore.h" +#include "Simd/SimdExtract.h" +#include "Simd/SimdArray.h" +#include "Simd/SimdUnpack.h" +#include "Simd/SimdDescrInt.h" +#include "Simd/SimdDescrIntCommon.h" +#include "Simd/SimdCpu.h" +#include "Simd/SimdFloat16.h" + +namespace Simd +{ +#ifdef SIMD_SSE41_ENABLE + namespace Sse41 + { + template int32_t Correlation(const uint8_t* a, const uint8_t* b, size_t size); + + template<> int32_t Correlation<4>(const uint8_t* a, const uint8_t* b, size_t size) + { + assert(size % 8 == 0); + __m128i ab32 = _mm_setzero_si128(); + size_t i = 0, size32 = AlignLo(size, 32); + for (; i < size32; i += 32, a += 16, b += 16) + { + __m128i _a = _mm_loadu_si128((__m128i*)a); + __m128i _b = _mm_loadu_si128((__m128i*)b); + __m128i ab16 = _mm_maddubs_epi16(_mm_and_si128(_a, K8_0F), _mm_and_si128(_b, K8_0F)); + ab16 = _mm_add_epi16(ab16, _mm_maddubs_epi16(_mm_and_si128(_mm_srli_epi16(_a, 4), K8_0F), _mm_and_si128(_mm_srli_epi16(_b, 4), K8_0F))); + ab32 = _mm_add_epi32(ab32, _mm_madd_epi16(ab16, K16_0001)); + } + for (; i < size; i += 8, a += 4, b += 4) + { + __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), C4_SHFL0), C4_MULLO), 12); + __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), C4_SHFL0), C4_MULLO), 12); + ab32 = _mm_add_epi32(_mm_madd_epi16(_a, _b), ab32); + } + return ExtractInt32Sum(ab32); + } + + template<> int32_t Correlation<5>(const uint8_t* a, const uint8_t* b, size_t size) + { + assert(size % 8 == 0); + __m128i _ab = _mm_setzero_si128(); + size_t i = 0, sizeA = AlignLo(size, A); + for (; i < sizeA; i += A, a += 10, b += 10) + { + __m128i _a = _mm_loadu_si128((__m128i*)a); + __m128i _b = _mm_loadu_si128((__m128i*)b); + __m128i a0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_a, C5_SHFL0), C5_MULLO), 11); + __m128i b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_b, C5_SHFL0), C5_MULLO), 11); + _ab = _mm_add_epi32(_mm_madd_epi16(a0, b0), _ab); + __m128i a1 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_a, C5_SHFL1), C5_MULLO), 11); + __m128i b1 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_b, C5_SHFL1), C5_MULLO), 11); + _ab = _mm_add_epi32(_mm_madd_epi16(a1, b1), _ab); + } + for (; i < size; i += 8, a += 5, b += 5) + { + __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), C5_SHFL0), C5_MULLO), 11); + __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), C5_SHFL0), C5_MULLO), 11); + _ab = _mm_add_epi32(_mm_madd_epi16(_a, _b), _ab); + } + return ExtractInt32Sum(_ab); + } + + template<> int32_t Correlation<6>(const uint8_t* a, const uint8_t* b, size_t size) + { + assert(size % 8 == 0); + __m128i _ab = _mm_setzero_si128(); + size_t i = 0, sizeA = AlignLo(size, A); + for (; i < sizeA; i += A, a += 12, b += 12) + { + __m128i _a = _mm_loadu_si128((__m128i*)a); + __m128i _b = _mm_loadu_si128((__m128i*)b); + __m128i a0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_a, C6_SHFL0), C6_MULLO), 10); + __m128i b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_b, C6_SHFL0), C6_MULLO), 10); + _ab = _mm_add_epi32(_mm_madd_epi16(a0, b0), _ab); + __m128i a1 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_a, C6_SHFL1), C6_MULLO), 10); + __m128i b1 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_b, C6_SHFL1), C6_MULLO), 10); + _ab = _mm_add_epi32(_mm_madd_epi16(a1, b1), _ab); + } + for (; i < size; i += 8, a += 6, b += 6) + { + __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), C6_SHFL0), C6_MULLO), 10); + __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), C6_SHFL0), C6_MULLO), 10); + _ab = _mm_add_epi32(_mm_madd_epi16(_a, _b), _ab); + } + return ExtractInt32Sum(_ab); + } + + template<> int32_t Correlation<7>(const uint8_t* a, const uint8_t* b, size_t size) + { + assert(size % 8 == 0); + __m128i _ab = _mm_setzero_si128(); + size_t i = 0, sizeA = AlignLo(size, A); + for (; i < sizeA; i += A, a += 14, b += 14) + { + __m128i _a = _mm_loadu_si128((__m128i*)a); + __m128i _b = _mm_loadu_si128((__m128i*)b); + __m128i a0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_a, C7_SHFL0), C7_MULLO), 9); + __m128i b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_b, C7_SHFL0), C7_MULLO), 9); + _ab = _mm_add_epi32(_mm_madd_epi16(a0, b0), _ab); + __m128i a1 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_a, C7_SHFL1), C7_MULLO), 9); + __m128i b1 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_b, C7_SHFL1), C7_MULLO), 9); + _ab = _mm_add_epi32(_mm_madd_epi16(a1, b1), _ab); + } + for (; i < size; i += 8, a += 7, b += 7) + { + __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), C7_SHFL0), C7_MULLO), 9); + __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), C7_SHFL0), C7_MULLO), 9); + _ab = _mm_add_epi32(_mm_madd_epi16(_a, _b), _ab); + } + return ExtractInt32Sum(_ab); + } + + template<> int32_t Correlation<8>(const uint8_t* a, const uint8_t* b, size_t size) + { + size_t i = 0, sizeA = AlignLo(size, A); + __m128i _ab = _mm_setzero_si128(); + for (; i < sizeA; i += A) + { + __m128i _a = _mm_loadu_si128((__m128i*)(a + i)); + __m128i _b = _mm_loadu_si128((__m128i*)(b + i)); + _ab = _mm_add_epi32(_mm_madd_epi16(UnpackU8<0>(_a), UnpackU8<0>(_b)), _ab); + _ab = _mm_add_epi32(_mm_madd_epi16(UnpackU8<1>(_a), UnpackU8<1>(_b)), _ab); + } + for (; i < size; i += 8) + { + __m128i _a = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(a + i))); + __m128i _b = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(b + i))); + _ab = _mm_add_epi32(_mm_madd_epi16(_a, _b), _ab); + } + return ExtractInt32Sum(_ab); + } + + template void CosineDistance(const uint8_t* a, const uint8_t* b, size_t size, float* distance) + { + float abSum = (float)Correlation(a + 16, b + 16, size); + Base::DecodeCosineDistance(a, b, abSum, distance); + } + + //------------------------------------------------------------------------------------------------- + + template void MicroCosineDistancesDirect2x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + + template<> void MicroCosineDistancesDirect2x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size32 = AlignLo(size, 32), o = 16; + __m128i a0, a1, b0; + __m128i ab00 = _mm_setzero_si128(); + __m128i ab01 = _mm_setzero_si128(); + __m128i ab02 = _mm_setzero_si128(); + __m128i ab03 = _mm_setzero_si128(); + __m128i ab10 = _mm_setzero_si128(); + __m128i ab11 = _mm_setzero_si128(); + __m128i ab12 = _mm_setzero_si128(); + __m128i ab13 = _mm_setzero_si128(); + for (; i < size32; i += 32, o += 16) + { + a0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(A[0] + o)), K8_0F); + a1 = _mm_and_si128(_mm_loadu_si128((__m128i*)(A[1] + o)), K8_0F); + + b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[0] + o)), K8_0F); + ab00 = _mm_add_epi32(ab00, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + ab10 = _mm_add_epi32(ab10, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[1] + o)), K8_0F); + ab01 = _mm_add_epi32(ab01, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + ab11 = _mm_add_epi32(ab11, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[2] + o)), K8_0F); + ab02 = _mm_add_epi32(ab02, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + ab12 = _mm_add_epi32(ab12, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[3] + o)), K8_0F); + ab03 = _mm_add_epi32(ab03, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + ab13 = _mm_add_epi32(ab13, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); + + a0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(A[0] + o)), 4), K8_0F); + a1 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(A[1] + o)), 4), K8_0F); + + b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[0] + o)), 4), K8_0F); + ab00 = _mm_add_epi32(ab00, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + ab10 = _mm_add_epi32(ab10, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[1] + o)), 4), K8_0F); + ab01 = _mm_add_epi32(ab01, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + ab11 = _mm_add_epi32(ab11, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[2] + o)), 4), K8_0F); + ab02 = _mm_add_epi32(ab02, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + ab12 = _mm_add_epi32(ab12, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[3] + o)), 4), K8_0F); + ab03 = _mm_add_epi32(ab03, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + ab13 = _mm_add_epi32(ab13, _mm_madd_epi16(_mm_maddubs_epi16(a1, b0), K16_0001)); + } + for (; i < size; i += 8, o += 4) + { + a0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[0] + o)), C4_SHFL0), C4_MULLO), 12); + a1 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[1] + o)), C4_SHFL0), C4_MULLO), 12); + + b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[0] + o)), C4_SHFL0), C4_MULLO), 12); + ab00 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab00); + ab10 = _mm_add_epi32(_mm_madd_epi16(a1, b0), ab10); + + b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[1] + o)), C4_SHFL0), C4_MULLO), 12); + ab01 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab01); + ab11 = _mm_add_epi32(_mm_madd_epi16(a1, b0), ab11); + + b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[2] + o)), C4_SHFL0), C4_MULLO), 12); + ab02 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab02); + ab12 = _mm_add_epi32(_mm_madd_epi16(a1, b0), ab12); + + b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[3] + o)), C4_SHFL0), C4_MULLO), 12); + ab03 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab03); + ab13 = _mm_add_epi32(_mm_madd_epi16(a1, b0), ab13); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); + } + + template<> void MicroCosineDistancesDirect2x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size16 = AlignLo(size, 16), o = 16; + __m128i a00, a01, a10, a11, b00, b01; + __m128i ab00 = _mm_setzero_si128(); + __m128i ab01 = _mm_setzero_si128(); + __m128i ab02 = _mm_setzero_si128(); + __m128i ab03 = _mm_setzero_si128(); + __m128i ab10 = _mm_setzero_si128(); + __m128i ab11 = _mm_setzero_si128(); + __m128i ab12 = _mm_setzero_si128(); + __m128i ab13 = _mm_setzero_si128(); + for (; i < size16; i += 16, o += 10) + { + a01 = _mm_loadu_si128((__m128i*)(A[0] + o)); + a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(a01, C5_SHFL0), C5_MULLO), 11); + a01 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(a01, 5), C5_SHFL0), C5_MULLO), 11); + a11 = _mm_loadu_si128((__m128i*)(A[1] + o)); + a10 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(a11, C5_SHFL0), C5_MULLO), 11); + a11 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(a11, 5), C5_SHFL0), C5_MULLO), 11); + + b01 = _mm_loadu_si128((__m128i*)(B[0] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C5_SHFL0), C5_MULLO), 11); + ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); + ab10 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab10); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 5), C5_SHFL0), C5_MULLO), 11); + ab00 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab00); + ab10 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab10); + + b01 = _mm_loadu_si128((__m128i*)(B[1] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C5_SHFL0), C5_MULLO), 11); + ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); + ab11 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab11); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 5), C5_SHFL0), C5_MULLO), 11); + ab01 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab01); + ab11 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab11); + + b01 = _mm_loadu_si128((__m128i*)(B[2] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C5_SHFL0), C5_MULLO), 11); + ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); + ab12 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab12); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 5), C5_SHFL0), C5_MULLO), 11); + ab02 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab02); + ab12 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab12); + + b01 = _mm_loadu_si128((__m128i*)(B[3] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C5_SHFL0), C5_MULLO), 11); + ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); + ab13 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab13); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 5), C5_SHFL0), C5_MULLO), 11); + ab03 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab03); + ab13 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab13); + } + for (; i < size; i += 8, o += 5) + { + a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[0] + o)), C5_SHFL0), C5_MULLO), 11); + a10 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[1] + o)), C5_SHFL0), C5_MULLO), 11); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[0] + o)), C5_SHFL0), C5_MULLO), 11); + ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); + ab10 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab10); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[1] + o)), C5_SHFL0), C5_MULLO), 11); + ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); + ab11 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab11); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[2] + o)), C5_SHFL0), C5_MULLO), 11); + ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); + ab12 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab12); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[3] + o)), C5_SHFL0), C5_MULLO), 11); + ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); + ab13 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab13); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); + } + + template<> void MicroCosineDistancesDirect2x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size16 = AlignLo(size, 16), o = 16; + __m128i a00, a01, a10, a11, b00, b01; + __m128i ab00 = _mm_setzero_si128(); + __m128i ab01 = _mm_setzero_si128(); + __m128i ab02 = _mm_setzero_si128(); + __m128i ab03 = _mm_setzero_si128(); + __m128i ab10 = _mm_setzero_si128(); + __m128i ab11 = _mm_setzero_si128(); + __m128i ab12 = _mm_setzero_si128(); + __m128i ab13 = _mm_setzero_si128(); + for (; i < size16; i += 16, o += 12) + { + a01 = _mm_loadu_si128((__m128i*)(A[0] + o)); + a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(a01, C6_SHFL0), C6_MULLO), 10); + a01 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(a01, 6), C6_SHFL0), C6_MULLO), 10); + a11 = _mm_loadu_si128((__m128i*)(A[1] + o)); + a10 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(a11, C6_SHFL0), C6_MULLO), 10); + a11 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(a11, 6), C6_SHFL0), C6_MULLO), 10); + + b01 = _mm_loadu_si128((__m128i*)(B[0] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C6_SHFL0), C6_MULLO), 10); + ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); + ab10 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab10); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 6), C6_SHFL0), C6_MULLO), 10); + ab00 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab00); + ab10 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab10); + + b01 = _mm_loadu_si128((__m128i*)(B[1] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C6_SHFL0), C6_MULLO), 10); + ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); + ab11 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab11); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 6), C6_SHFL0), C6_MULLO), 10); + ab01 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab01); + ab11 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab11); + + b01 = _mm_loadu_si128((__m128i*)(B[2] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C6_SHFL0), C6_MULLO), 10); + ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); + ab12 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab12); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 6), C6_SHFL0), C6_MULLO), 10); + ab02 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab02); + ab12 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab12); + + b01 = _mm_loadu_si128((__m128i*)(B[3] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C6_SHFL0), C6_MULLO), 10); + ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); + ab13 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab13); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 6), C6_SHFL0), C6_MULLO), 10); + ab03 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab03); + ab13 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab13); + } + for (; i < size; i += 8, o += 6) + { + a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[0] + o)), C6_SHFL0), C6_MULLO), 10); + a10 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[1] + o)), C6_SHFL0), C6_MULLO), 10); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[0] + o)), C6_SHFL0), C6_MULLO), 10); + ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); + ab10 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab10); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[1] + o)), C6_SHFL0), C6_MULLO), 10); + ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); + ab11 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab11); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[2] + o)), C6_SHFL0), C6_MULLO), 10); + ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); + ab12 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab12); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[3] + o)), C6_SHFL0), C6_MULLO), 10); + ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); + ab13 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab13); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); + } + + template<> void MicroCosineDistancesDirect2x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size16 = AlignLo(size, 16), o = 16; + __m128i a00, a01, a10, a11, b00, b01; + __m128i ab00 = _mm_setzero_si128(); + __m128i ab01 = _mm_setzero_si128(); + __m128i ab02 = _mm_setzero_si128(); + __m128i ab03 = _mm_setzero_si128(); + __m128i ab10 = _mm_setzero_si128(); + __m128i ab11 = _mm_setzero_si128(); + __m128i ab12 = _mm_setzero_si128(); + __m128i ab13 = _mm_setzero_si128(); + for (; i < size16; i += 16, o += 14) + { + a01 = _mm_loadu_si128((__m128i*)(A[0] + o)); + a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(a01, C7_SHFL0), C7_MULLO), 9); + a01 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(a01, 7), C7_SHFL0), C7_MULLO), 9); + a11 = _mm_loadu_si128((__m128i*)(A[1] + o)); + a10 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(a11, C7_SHFL0), C7_MULLO), 9); + a11 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(a11, 7), C7_SHFL0), C7_MULLO), 9); + + b01 = _mm_loadu_si128((__m128i*)(B[0] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C7_SHFL0), C7_MULLO), 9); + ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); + ab10 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab10); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 7), C7_SHFL0), C7_MULLO), 9); + ab00 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab00); + ab10 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab10); + + b01 = _mm_loadu_si128((__m128i*)(B[1] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C7_SHFL0), C7_MULLO), 9); + ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); + ab11 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab11); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 7), C7_SHFL0), C7_MULLO), 9); + ab01 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab01); + ab11 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab11); + + b01 = _mm_loadu_si128((__m128i*)(B[2] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C7_SHFL0), C7_MULLO), 9); + ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); + ab12 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab12); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 7), C7_SHFL0), C7_MULLO), 9); + ab02 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab02); + ab12 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab12); + + b01 = _mm_loadu_si128((__m128i*)(B[3] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C7_SHFL0), C7_MULLO), 9); + ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); + ab13 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab13); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 7), C7_SHFL0), C7_MULLO), 9); + ab03 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab03); + ab13 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab13); + } + for (; i < size; i += 8, o += 7) + { + a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[0] + o)), C7_SHFL0), C7_MULLO), 9); + a10 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[1] + o)), C7_SHFL0), C7_MULLO), 9); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[0] + o)), C7_SHFL0), C7_MULLO), 9); + ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); + ab10 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab10); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[1] + o)), C7_SHFL0), C7_MULLO), 9); + ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); + ab11 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab11); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[2] + o)), C7_SHFL0), C7_MULLO), 9); + ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); + ab12 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab12); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[3] + o)), C7_SHFL0), C7_MULLO), 9); + ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); + ab13 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab13); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); + } + + template<> void MicroCosineDistancesDirect2x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size16 = AlignLo(size, 16), o = 16; + __m128i a00, a01, a10, a11, b00, b01; + __m128i ab00 = _mm_setzero_si128(); + __m128i ab01 = _mm_setzero_si128(); + __m128i ab02 = _mm_setzero_si128(); + __m128i ab03 = _mm_setzero_si128(); + __m128i ab10 = _mm_setzero_si128(); + __m128i ab11 = _mm_setzero_si128(); + __m128i ab12 = _mm_setzero_si128(); + __m128i ab13 = _mm_setzero_si128(); + for (; i < size16; i += 16, o += 16) + { + a01 = _mm_loadu_si128((__m128i*)(A[0] + o)); + a00 = UnpackU8<0>(a01); + a01 = UnpackU8<1>(a01); + a11 = _mm_loadu_si128((__m128i*)(A[1] + o)); + a10 = UnpackU8<0>(a11); + a11 = UnpackU8<1>(a11); + + b01 = _mm_loadu_si128((__m128i*)(B[0] + o)); + b00 = UnpackU8<0>(b01); + ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); + ab10 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab10); + b00 = UnpackU8<1>(b01); + ab00 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab00); + ab10 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab10); + + b01 = _mm_loadu_si128((__m128i*)(B[1] + o)); + b00 = UnpackU8<0>(b01); + ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); + ab11 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab11); + b00 = UnpackU8<1>(b01); + ab01 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab01); + ab11 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab11); + + b01 = _mm_loadu_si128((__m128i*)(B[2] + o)); + b00 = UnpackU8<0>(b01); + ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); + ab12 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab12); + b00 = UnpackU8<1>(b01); + ab02 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab02); + ab12 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab12); + + b01 = _mm_loadu_si128((__m128i*)(B[3] + o)); + b00 = UnpackU8<0>(b01); + ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); + ab13 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab13); + b00 = UnpackU8<1>(b01); + ab03 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab03); + ab13 = _mm_add_epi32(_mm_madd_epi16(a11, b00), ab13); + } + for (; i < size; i += 8, o += 8) + { + a00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(A[0] + o))); + a10 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(A[1] + o))); + + b00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[0] + o))); + ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); + ab10 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab10); + + b00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[1] + o))); + ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); + ab11 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab11); + + b00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[2] + o))); + ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); + ab12 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab12); + + b00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[3] + o))); + ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); + ab13 = _mm_add_epi32(_mm_madd_epi16(a10, b00), ab13); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); + } + + template void MicroCosineDistancesDirect1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + + template<> void MicroCosineDistancesDirect1x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size32 = AlignLo(size, 32), o = 16; + __m128i a0, b0; + __m128i ab00 = _mm_setzero_si128(); + __m128i ab01 = _mm_setzero_si128(); + __m128i ab02 = _mm_setzero_si128(); + __m128i ab03 = _mm_setzero_si128(); + for (; i < size32; i += 32, o += 16) + { + a0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(A[0] + o)), K8_0F); + + b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[0] + o)), K8_0F); + ab00 = _mm_add_epi32(ab00, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[1] + o)), K8_0F); + ab01 = _mm_add_epi32(ab01, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[2] + o)), K8_0F); + ab02 = _mm_add_epi32(ab02, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_loadu_si128((__m128i*)(B[3] + o)), K8_0F); + ab03 = _mm_add_epi32(ab03, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + + a0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(A[0] + o)), 4), K8_0F); + + b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[0] + o)), 4), K8_0F); + ab00 = _mm_add_epi32(ab00, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[1] + o)), 4), K8_0F); + ab01 = _mm_add_epi32(ab01, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[2] + o)), 4), K8_0F); + ab02 = _mm_add_epi32(ab02, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm_and_si128(_mm_srli_epi16(_mm_loadu_si128((__m128i*)(B[3] + o)), 4), K8_0F); + ab03 = _mm_add_epi32(ab03, _mm_madd_epi16(_mm_maddubs_epi16(a0, b0), K16_0001)); + } + for (; i < size; i += 8, o += 4) + { + a0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[0] + o)), C4_SHFL0), C4_MULLO), 12); + + b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[0] + o)), C4_SHFL0), C4_MULLO), 12); + ab00 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab00); + + b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[1] + o)), C4_SHFL0), C4_MULLO), 12); + ab01 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab01); + + b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[2] + o)), C4_SHFL0), C4_MULLO), 12); + ab02 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab02); + + b0 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[3] + o)), C4_SHFL0), C4_MULLO), 12); + ab03 = _mm_add_epi32(_mm_madd_epi16(a0, b0), ab03); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + } + + template<> void MicroCosineDistancesDirect1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size16 = AlignLo(size, 16), o = 16; + __m128i a00, a01, b00, b01; + __m128i ab00 = _mm_setzero_si128(); + __m128i ab01 = _mm_setzero_si128(); + __m128i ab02 = _mm_setzero_si128(); + __m128i ab03 = _mm_setzero_si128(); + for (; i < size16; i += 16, o += 10) + { + a01 = _mm_loadu_si128((__m128i*)(A[0] + o)); + a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(a01, C5_SHFL0), C5_MULLO), 11); + a01 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(a01, 5), C5_SHFL0), C5_MULLO), 11); + + b01 = _mm_loadu_si128((__m128i*)(B[0] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C5_SHFL0), C5_MULLO), 11); + ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 5), C5_SHFL0), C5_MULLO), 11); + ab00 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab00); + + b01 = _mm_loadu_si128((__m128i*)(B[1] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C5_SHFL0), C5_MULLO), 11); + ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 5), C5_SHFL0), C5_MULLO), 11); + ab01 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab01); + + b01 = _mm_loadu_si128((__m128i*)(B[2] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C5_SHFL0), C5_MULLO), 11); + ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 5), C5_SHFL0), C5_MULLO), 11); + ab02 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab02); + + b01 = _mm_loadu_si128((__m128i*)(B[3] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C5_SHFL0), C5_MULLO), 11); + ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 5), C5_SHFL0), C5_MULLO), 11); + ab03 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab03); + } + for (; i < size; i += 8, o += 5) + { + a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[0] + o)), C5_SHFL0), C5_MULLO), 11); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[0] + o)), C5_SHFL0), C5_MULLO), 11); + ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[1] + o)), C5_SHFL0), C5_MULLO), 11); + ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[2] + o)), C5_SHFL0), C5_MULLO), 11); + ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[3] + o)), C5_SHFL0), C5_MULLO), 11); + ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + } + + template<> void MicroCosineDistancesDirect1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size16 = AlignLo(size, 16), o = 16; + __m128i a00, a01, b00, b01; + __m128i ab00 = _mm_setzero_si128(); + __m128i ab01 = _mm_setzero_si128(); + __m128i ab02 = _mm_setzero_si128(); + __m128i ab03 = _mm_setzero_si128(); + for (; i < size16; i += 16, o += 12) + { + a01 = _mm_loadu_si128((__m128i*)(A[0] + o)); + a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(a01, C6_SHFL0), C6_MULLO), 10); + a01 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(a01, 6), C6_SHFL0), C6_MULLO), 10); + + b01 = _mm_loadu_si128((__m128i*)(B[0] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C6_SHFL0), C6_MULLO), 10); + ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 6), C6_SHFL0), C6_MULLO), 10); + ab00 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab00); + + b01 = _mm_loadu_si128((__m128i*)(B[1] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C6_SHFL0), C6_MULLO), 10); + ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 6), C6_SHFL0), C6_MULLO), 10); + ab01 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab01); + + b01 = _mm_loadu_si128((__m128i*)(B[2] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C6_SHFL0), C6_MULLO), 10); + ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 6), C6_SHFL0), C6_MULLO), 10); + ab02 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab02); + + b01 = _mm_loadu_si128((__m128i*)(B[3] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C6_SHFL0), C6_MULLO), 10); + ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 6), C6_SHFL0), C6_MULLO), 10); + ab03 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab03); + } + for (; i < size; i += 8, o += 6) + { + a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[0] + o)), C6_SHFL0), C6_MULLO), 10); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[0] + o)), C6_SHFL0), C6_MULLO), 10); + ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[1] + o)), C6_SHFL0), C6_MULLO), 10); + ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[2] + o)), C6_SHFL0), C6_MULLO), 10); + ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[3] + o)), C6_SHFL0), C6_MULLO), 10); + ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + } + + template<> void MicroCosineDistancesDirect1x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size16 = AlignLo(size, 16), o = 16; + __m128i a00, a01, b00, b01; + __m128i ab00 = _mm_setzero_si128(); + __m128i ab01 = _mm_setzero_si128(); + __m128i ab02 = _mm_setzero_si128(); + __m128i ab03 = _mm_setzero_si128(); + for (; i < size16; i += 16, o += 14) + { + a01 = _mm_loadu_si128((__m128i*)(A[0] + o)); + a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(a01, C7_SHFL0), C7_MULLO), 9); + a01 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(a01, 7), C7_SHFL0), C7_MULLO), 9); + + b01 = _mm_loadu_si128((__m128i*)(B[0] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C7_SHFL0), C7_MULLO), 9); + ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 7), C7_SHFL0), C7_MULLO), 9); + ab00 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab00); + + b01 = _mm_loadu_si128((__m128i*)(B[1] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C7_SHFL0), C7_MULLO), 9); + ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 7), C7_SHFL0), C7_MULLO), 9); + ab01 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab01); + + b01 = _mm_loadu_si128((__m128i*)(B[2] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C7_SHFL0), C7_MULLO), 9); + ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 7), C7_SHFL0), C7_MULLO), 9); + ab02 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab02); + + b01 = _mm_loadu_si128((__m128i*)(B[3] + o)); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(b01, C7_SHFL0), C7_MULLO), 9); + ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_srli_si128(b01, 7), C7_SHFL0), C7_MULLO), 9); + ab03 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab03); + } + for (; i < size; i += 8, o += 7) + { + a00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(A[0] + o)), C7_SHFL0), C7_MULLO), 9); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[0] + o)), C7_SHFL0), C7_MULLO), 9); + ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[1] + o)), C7_SHFL0), C7_MULLO), 9); + ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[2] + o)), C7_SHFL0), C7_MULLO), 9); + ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); + + b00 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)(B[3] + o)), C7_SHFL0), C7_MULLO), 9); + ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + } + + template<> void MicroCosineDistancesDirect1x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size16 = AlignLo(size, 16), o = 16; + __m128i a00, a01, b00, b01; + __m128i ab00 = _mm_setzero_si128(); + __m128i ab01 = _mm_setzero_si128(); + __m128i ab02 = _mm_setzero_si128(); + __m128i ab03 = _mm_setzero_si128(); + for (; i < size16; i += 16, o += 16) + { + a01 = _mm_loadu_si128((__m128i*)(A[0] + o)); + a00 = UnpackU8<0>(a01); + a01 = UnpackU8<1>(a01); + + b01 = _mm_loadu_si128((__m128i*)(B[0] + o)); + b00 = UnpackU8<0>(b01); + ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); + b00 = UnpackU8<1>(b01); + ab00 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab00); + + b01 = _mm_loadu_si128((__m128i*)(B[1] + o)); + b00 = UnpackU8<0>(b01); + ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); + b00 = UnpackU8<1>(b01); + ab01 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab01); + + b01 = _mm_loadu_si128((__m128i*)(B[2] + o)); + b00 = UnpackU8<0>(b01); + ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); + b00 = UnpackU8<1>(b01); + ab02 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab02); + + b01 = _mm_loadu_si128((__m128i*)(B[3] + o)); + b00 = UnpackU8<0>(b01); + ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); + b00 = UnpackU8<1>(b01); + ab03 = _mm_add_epi32(_mm_madd_epi16(a01, b00), ab03); + } + for (; i < size; i += 8, o += 8) + { + a00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(A[0] + o))); + + b00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[0] + o))); + ab00 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab00); + + b00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[1] + o))); + ab01 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab01); + + b00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[2] + o))); + ab02 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab02); + + b00 = _mm_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[3] + o))); + ab03 = _mm_add_epi32(_mm_madd_epi16(a00, b00), ab03); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + } + + template void MacroCosineDistancesDirect(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t M2 = AlignLoAny(M, 2); + size_t N4 = AlignLo(N, 4); + size_t i = 0; + for (; i < M2; i += 2) + { + size_t j = 0; + for (; j < N4; j += 4) + MicroCosineDistancesDirect2x4(A + i, B + j, size, distances + j, stride); + for (; j < N; j += 1) + { + CosineDistance(A[i + 0], B[j], size, distances + j + 0 * stride); + CosineDistance(A[i + 1], B[j], size, distances + j + 1 * stride); + } + distances += 2 * stride; + } + for (; i < M; i++) + { + size_t j = 0; + for (; j < N4; j += 4) + MicroCosineDistancesDirect1x4(A + i, B + j, size, distances + j, stride); + for (; j < N; j += 1) + CosineDistance(A[i], B[j], size, distances + j); + distances += 1 * stride; + } + } + + //------------------------------------------------------------------------------------------------- + + Base::DescrInt::CosineDistancePtr GetCosineDistance(size_t depth) + { + switch (depth) + { + case 4: return CosineDistance<4>; + case 5: return CosineDistance<5>; + case 6: return CosineDistance<6>; + case 7: return CosineDistance<7>; + case 8: return CosineDistance<8>; + default: assert(0); return NULL; + } + } + + Sse41::DescrInt::MacroCosineDistancesDirectPtr GetMacroCosineDistancesDirect(size_t depth) + { + switch (depth) + { + case 4: return MacroCosineDistancesDirect<4>; + case 5: return MacroCosineDistancesDirect<5>; + case 6: return MacroCosineDistancesDirect<6>; + case 7: return MacroCosineDistancesDirect<7>; + case 8: return MacroCosineDistancesDirect<8>; + default: assert(0); return NULL; + } + } + } +#endif +} diff --git a/src/Simd/SimdSse41DescrIntScu.cpp b/src/Simd/SimdSse41DescrIntScu.cpp new file mode 100644 index 0000000000..8d1fdcdcf4 --- /dev/null +++ b/src/Simd/SimdSse41DescrIntScu.cpp @@ -0,0 +1,468 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2023 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#include "Simd/SimdMemory.h" +#include "Simd/SimdStore.h" +#include "Simd/SimdExtract.h" +#include "Simd/SimdArray.h" +#include "Simd/SimdUnpack.h" +#include "Simd/SimdDescrInt.h" +#include "Simd/SimdDescrIntCommon.h" +#include "Simd/SimdCpu.h" +#include "Simd/SimdFloat16.h" +#include "Simd/SimdSynet.h" + +namespace Simd +{ +#ifdef SIMD_SSE41_ENABLE + namespace Sse41 + { + SIMD_INLINE __m128i UnpackData7x8(const uint8_t* src) + { + __m128i _src = _mm_loadl_epi64((__m128i*)src); + __m128i lo = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C7_SHFL0), C7_MULLO), 9); + return _mm_packus_epi16(lo, K_ZERO); + } + + SIMD_INLINE __m128i UnpackData7x16(const uint8_t* src) + { + __m128i _src = _mm_loadu_si128((__m128i*)src); + __m128i lo = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C7_SHFL0), C7_MULLO), 9); + __m128i hi = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C7_SHFL1), C7_MULLO), 9); + return _mm_packus_epi16(lo, hi); + } + + static void UnpackDataA7(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride) + { + size_t size16 = AlignLo(size, 16); + for (size_t i = 0; i < count; i++) + { + const uint8_t* ps = src[i] + 16; + uint8_t* pd = (uint8_t*)dst + i * size; + size_t j = 0; + for (; j < size16; j += 16, ps += 14, pd += 16) + _mm_storeu_si128((__m128i*)pd, UnpackData7x16(ps)); + for (; j < size; j += 8, ps += 7, pd += 8) + _mm_storel_epi64((__m128i*)pd, UnpackData7x8(ps)); + } + } + + static void UnpackDataA8(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride) + { + size_t size16 = AlignLo(size, 16); + for (size_t i = 0, j; i < count; i++) + { + const uint8_t* ps = src[i] + 16; + uint16_t* pd = (uint16_t*)dst + i * size; + for (j = 0; j < size16; j += 16, ps += 16, pd += 16) + { + __m128i s = _mm_loadu_si128((__m128i*)ps); + _mm_storeu_si128((__m128i*)pd + 0, UnpackU8<0>(s)); + _mm_storeu_si128((__m128i*)pd + 1, UnpackU8<1>(s)); + } + for (; j < size; j += 8, ps += 8, pd += 8) + { + __m128i s = _mm_loadl_epi64((__m128i*)ps); + _mm_storeu_si128((__m128i*)pd, UnpackU8<0>(s)); + } + } + } + + //------------------------------------------------------------------------------------------------- + + SIMD_INLINE void UnpackDataB7x4x16(const uint8_t* const* src, size_t offset, uint8_t* dst) + { + __m128i a0 = UnpackData7x16(src[0] + offset); + __m128i a1 = UnpackData7x16(src[1] + offset); + __m128i a2 = UnpackData7x16(src[2] + offset); + __m128i a3 = UnpackData7x16(src[3] + offset); + __m128i b0 = _mm_unpacklo_epi32(a0, a2); + __m128i b1 = _mm_unpacklo_epi32(a1, a3); + __m128i b2 = _mm_unpackhi_epi32(a0, a2); + __m128i b3 = _mm_unpackhi_epi32(a1, a3); + _mm_storeu_si128((__m128i*)dst + 0, _mm_unpacklo_epi32(b0, b1)); + _mm_storeu_si128((__m128i*)dst + 2, _mm_unpackhi_epi32(b0, b1)); + _mm_storeu_si128((__m128i*)dst + 4, _mm_unpacklo_epi32(b2, b3)); + _mm_storeu_si128((__m128i*)dst + 6, _mm_unpackhi_epi32(b2, b3)); + } + + SIMD_INLINE void UnpackDataB7x4x8(const uint8_t* const* src, size_t offset, uint8_t* dst) + { + __m128i a0 = UnpackData7x8(src[0] + offset); + __m128i a1 = UnpackData7x8(src[1] + offset); + __m128i a2 = UnpackData7x8(src[2] + offset); + __m128i a3 = UnpackData7x8(src[3] + offset); + __m128i b0 = _mm_unpacklo_epi32(a0, a2); + __m128i b1 = _mm_unpacklo_epi32(a1, a3); + _mm_storeu_si128((__m128i*)dst + 0, _mm_unpacklo_epi32(b0, b1)); + _mm_storeu_si128((__m128i*)dst + 2, _mm_unpackhi_epi32(b0, b1)); + } + + static void UnpackDataB7(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride) + { + size_t count8 = AlignLo(count, 8), size16 = AlignLo(size, 16), i, j, o; + for (i = 0; i < count8; i += 8, src += 8) + { + for (j = 0, o = 16; j < size16; j += 16, o += 14, dst += 8 * A) + { + UnpackDataB7x4x16(src + 0, o, dst + 0); + UnpackDataB7x4x16(src + 4, o, dst + A); + } + for (; j < size; j += 8, o += 7, dst += 4 * A) + { + UnpackDataB7x4x8(src + 0, o, dst + 0); + UnpackDataB7x4x8(src + 2, o, dst + A); + } + } + if (i < count) + { + const uint8_t* _src[8]; + for (size_t j = 0; j < 8; i++, j++) + _src[j] = i < count ? *src++ : src[-1]; + for (j = 0, o = 16; j < size16; j += 16, o += 14, dst += 8 * A) + { + UnpackDataB7x4x16(src + 0, o, dst + 0); + UnpackDataB7x4x16(src + 4, o, dst + A); + } + for (; j < size; j += 8, o += 7, dst += 4 * A) + { + UnpackDataB7x4x8(src + 0, o, dst + 0); + UnpackDataB7x4x8(src + 2, o, dst + A); + } + } + } + + SIMD_INLINE void UnpackDataB8x4(const uint8_t* const* src, size_t offset, uint8_t* dst) + { + __m128i a0 = UnpackU8<0>(_mm_loadl_epi64((__m128i*)(src[0] + offset))); + __m128i a1 = UnpackU8<0>(_mm_loadl_epi64((__m128i*)(src[1] + offset))); + __m128i a2 = UnpackU8<0>(_mm_loadl_epi64((__m128i*)(src[2] + offset))); + __m128i a3 = UnpackU8<0>(_mm_loadl_epi64((__m128i*)(src[3] + offset))); + __m128i b0 = _mm_unpacklo_epi32(a0, a2); + __m128i b1 = _mm_unpacklo_epi32(a1, a3); + __m128i b2 = _mm_unpackhi_epi32(a0, a2); + __m128i b3 = _mm_unpackhi_epi32(a1, a3); + _mm_storeu_si128((__m128i*)dst + 0, _mm_unpacklo_epi32(b0, b1)); + _mm_storeu_si128((__m128i*)dst + 2, _mm_unpackhi_epi32(b0, b1)); + _mm_storeu_si128((__m128i*)dst + 4, _mm_unpacklo_epi32(b2, b3)); + _mm_storeu_si128((__m128i*)dst + 6, _mm_unpackhi_epi32(b2, b3)); + } + + static void UnpackDataB8(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride) + { + size_t count8 = AlignLo(count, 8), i; + for (i = 0, size += 16; i < count8; i += 8, src += 8) + { + for (size_t j = 16; j < size; j += 8, dst += 8 * A) + { + UnpackDataB8x4(src + 0, j, dst + 0); + UnpackDataB8x4(src + 4, j, dst + A); + } + } + if (i < count) + { + const uint8_t* _src[8]; + for (size_t j = 0; j < 8; i++, j++) + _src[j] = i < count ? *src++ : src[-1]; + for (size_t j = 16; j < size; j += 8, dst += 8 * A) + { + UnpackDataB8x4(_src + 0, j, dst + 0); + UnpackDataB8x4(_src + 4, j, dst + A); + } + } + } + + //------------------------------------------------------------------------------------------------- + + SIMD_INLINE __m128i Set2(const int16_t* src) + { + return _mm_set1_epi32(*(int32_t*)src); + } + + SIMD_INLINE void Madd2(__m128i& ab, __m128i a, __m128i b) + { + ab = _mm_add_epi32(ab, _mm_madd_epi16(a, b)); + } + + template void Correlation16_2xM(size_t N, size_t K, const int16_t* ad0, const int16_t* bd, const float *an, const float *bn, size_t bnStride, float* distances, size_t stride) + { + __m128i ab00, ab01, ab10, ab11, ab20, ab21, ab30, ab31, ab40, ab41, ab50, ab51, a0, b0, b1; + const int16_t* ad1 = ad0 + 1 * K; + const int16_t* ad2 = ad0 + 2 * K; + const int16_t* ad3 = ad0 + 3 * K; + const int16_t* ad4 = ad0 + 4 * K; + const int16_t* ad5 = ad0 + 5 * K; + if (N > 4) + { + if (M > 0) ab00 = _mm_setzero_si128(), ab01 = _mm_setzero_si128(); + if (M > 1) ab10 = _mm_setzero_si128(), ab11 = _mm_setzero_si128(); + if (M > 2) ab20 = _mm_setzero_si128(), ab21 = _mm_setzero_si128(); + if (M > 3) ab30 = _mm_setzero_si128(), ab31 = _mm_setzero_si128(); + if (M > 4) ab40 = _mm_setzero_si128(), ab41 = _mm_setzero_si128(); + if (M > 5) ab50 = _mm_setzero_si128(), ab51 = _mm_setzero_si128(); + for (size_t k = 0; k < K; k += 2) + { + b0 = _mm_loadu_si128((__m128i*)bd + 0); + b1 = _mm_loadu_si128((__m128i*)bd + 1); + if (M > 0) a0 = Set2(ad0 + k), Madd2(ab00, a0, b0), Madd2(ab01, a0, b1); + if (M > 1) a0 = Set2(ad1 + k), Madd2(ab10, a0, b0), Madd2(ab11, a0, b1); + if (M > 2) a0 = Set2(ad2 + k), Madd2(ab20, a0, b0), Madd2(ab21, a0, b1); + if (M > 3) a0 = Set2(ad3 + k), Madd2(ab30, a0, b0), Madd2(ab31, a0, b1); + if (M > 4) a0 = Set2(ad4 + k), Madd2(ab40, a0, b0), Madd2(ab41, a0, b1); + if (M > 5) a0 = Set2(ad5 + k), Madd2(ab50, a0, b0), Madd2(ab51, a0, b1); + bd += 16; + } + if (N == 8) + { + if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab01, distances + 4), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab11, distances + 4), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab21, distances + 4), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab31, distances + 4), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab41, distances + 4), an += 4, distances += stride; + if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab51, distances + 4), an += 4, distances += stride; + } + else + { + if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab01, distances + 4, N - 4), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab11, distances + 4, N - 4), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab21, distances + 4, N - 4), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab31, distances + 4, N - 4), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab41, distances + 4, N - 4), an += 4, distances += stride; + if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab51, distances + 4, N - 4), an += 4, distances += stride; + } + } + else + { + if (M > 0) ab00 = _mm_setzero_si128(); + if (M > 1) ab10 = _mm_setzero_si128(); + if (M > 2) ab20 = _mm_setzero_si128(); + if (M > 3) ab30 = _mm_setzero_si128(); + if (M > 4) ab40 = _mm_setzero_si128(); + if (M > 5) ab50 = _mm_setzero_si128(); + for (size_t k = 0; k < K; k += 2) + { + b0 = _mm_loadu_si128((__m128i*)bd + 0); + if (M > 0) a0 = Set2(ad0 + k), Madd2(ab00, a0, b0); + if (M > 1) a0 = Set2(ad1 + k), Madd2(ab10, a0, b0); + if (M > 2) a0 = Set2(ad2 + k), Madd2(ab20, a0, b0); + if (M > 3) a0 = Set2(ad3 + k), Madd2(ab30, a0, b0); + if (M > 4) a0 = Set2(ad4 + k), Madd2(ab40, a0, b0); + if (M > 5) a0 = Set2(ad5 + k), Madd2(ab50, a0, b0); + bd += 16; + } + if (N == 4) + { + if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), an += 4, distances += stride; + if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0), an += 4, distances += stride; + } + else + { + if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0, N), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0, N), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0, N), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0, N), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0, N), an += 4, distances += stride; + if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0, N), an += 4, distances += stride; + } + } + } + + typedef void(*Correlation16_2xM_Ptr)(size_t N, size_t K, const int16_t* ad0, const int16_t* bd, const float* an, const float* bn, size_t bnStride, float* distances, size_t stride); + + SIMD_INLINE Correlation16_2xM_Ptr GetCorrelation16_2xM(size_t M) + { + switch (M) + { + case 0: return NULL; + case 1: return Correlation16_2xM<1>; + case 2: return Correlation16_2xM<2>; + case 3: return Correlation16_2xM<3>; + case 4: return Correlation16_2xM<4>; + case 5: return Correlation16_2xM<5>; + case 6: return Correlation16_2xM<6>; + } + assert(0); + return NULL; + } + + void MacroCorrelation16(size_t M, size_t N, size_t K, const uint8_t* ad, const float* an, const uint8_t* bd, const float* bn, float* distances, size_t stride) + { + size_t M6 = AlignLoAny(M, 6); + Correlation16_2xM_Ptr correlation_2x6 = GetCorrelation16_2xM(6); + Correlation16_2xM_Ptr correlation_2xT = GetCorrelation16_2xM(M - M6); + const int16_t* a = (int16_t*)ad; + const int16_t* b = (int16_t*)bd; + for (size_t j = 0; j < N; j += 8) + { + size_t dN = Simd::Min(8, N - j); + size_t i = 0; + for (; i < M6; i += 6) + correlation_2x6(dN, K, a + i * K, b, an + i * 4, bn, N, distances + i * stride, stride); + if(i < M) + correlation_2xT(dN, K, a + i * K, b, an + i * 4, bn, N, distances + i * stride, stride); + b += K * 8; + bn += 8; + distances += 8; + } + } + + //------------------------------------------------------------------------------------------------- + + template void Correlation8_2xM(size_t N, size_t K, const uint8_t* ad0, const uint8_t* bd, const float* an, const float* bn, size_t bnStride, float* distances, size_t stride) + { + __m128i ab00, ab01, ab10, ab11, ab20, ab21, ab30, ab31, ab40, ab41, a0, b0, b1; + const uint8_t* ad1 = ad0 + 1 * K; + const uint8_t* ad2 = ad0 + 2 * K; + const uint8_t* ad3 = ad0 + 3 * K; + const uint8_t* ad4 = ad0 + 4 * K; + if (N > 4) + { + if (M > 0) ab00 = _mm_setzero_si128(), ab01 = _mm_setzero_si128(); + if (M > 1) ab10 = _mm_setzero_si128(), ab11 = _mm_setzero_si128(); + if (M > 2) ab20 = _mm_setzero_si128(), ab21 = _mm_setzero_si128(); + if (M > 3) ab30 = _mm_setzero_si128(), ab31 = _mm_setzero_si128(); + if (M > 4) ab40 = _mm_setzero_si128(), ab41 = _mm_setzero_si128(); + for (size_t k = 0; k < K; k += 4) + { + b0 = _mm_loadu_si128((__m128i*)bd + 0); + b1 = _mm_loadu_si128((__m128i*)bd + 1); + if (M > 0) a0 = Set4(ad0 + k), Madd4(ab00, a0, b0), Madd4(ab01, a0, b1); + if (M > 1) a0 = Set4(ad1 + k), Madd4(ab10, a0, b0), Madd4(ab11, a0, b1); + if (M > 2) a0 = Set4(ad2 + k), Madd4(ab20, a0, b0), Madd4(ab21, a0, b1); + if (M > 3) a0 = Set4(ad3 + k), Madd4(ab30, a0, b0), Madd4(ab31, a0, b1); + if (M > 4) a0 = Set4(ad4 + k), Madd4(ab40, a0, b0), Madd4(ab41, a0, b1); + bd += DA; + } + if (N == 8) + { + if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab01, distances + 4), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab11, distances + 4), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab21, distances + 4), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab31, distances + 4), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab41, distances + 4), an += 4, distances += stride; + } + else + { + if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab01, distances + 4, N - 4), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab11, distances + 4, N - 4), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab21, distances + 4, N - 4), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab31, distances + 4, N - 4), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab41, distances + 4, N - 4), an += 4, distances += stride; + } + } + else + { + if (M > 0) ab00 = _mm_setzero_si128(); + if (M > 1) ab10 = _mm_setzero_si128(); + if (M > 2) ab20 = _mm_setzero_si128(); + if (M > 3) ab30 = _mm_setzero_si128(); + if (M > 4) ab40 = _mm_setzero_si128(); + for (size_t k = 0; k < K; k += 4) + { + b0 = _mm_loadu_si128((__m128i*)bd + 0); + if (M > 0) a0 = Set4(ad0 + k), Madd4(ab00, a0, b0); + if (M > 1) a0 = Set4(ad1 + k), Madd4(ab10, a0, b0); + if (M > 2) a0 = Set4(ad2 + k), Madd4(ab20, a0, b0); + if (M > 3) a0 = Set4(ad3 + k), Madd4(ab30, a0, b0); + if (M > 4) a0 = Set4(ad4 + k), Madd4(ab40, a0, b0); + bd += DA; + } + if (N == 4) + { + if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), an += 4, distances += stride; + } + else + { + if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0, N), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0, N), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0, N), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0, N), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0, N), an += 4, distances += stride; + } + } + } + + typedef void(*Correlation8_2xM_Ptr)(size_t N, size_t K, const uint8_t* ad0, const uint8_t* bd, const float* an, const float* bn, size_t bnStride, float* distances, size_t stride); + + SIMD_INLINE Correlation8_2xM_Ptr GetCorrelation8_2xM(size_t M) + { + switch (M) + { + case 0: return NULL; + case 1: return Correlation8_2xM<1>; + case 2: return Correlation8_2xM<2>; + case 3: return Correlation8_2xM<3>; + case 4: return Correlation8_2xM<4>; + case 5: return Correlation8_2xM<5>; + } + assert(0); + return NULL; + } + + void MacroCorrelation8(size_t M, size_t N, size_t K, const uint8_t* ad, const float* an, const uint8_t* bd, const float* bn, float* distances, size_t stride) + { + size_t M5 = AlignLoAny(M, 5); + Correlation8_2xM_Ptr correlation_2x5 = GetCorrelation8_2xM(5); + Correlation8_2xM_Ptr correlation_2xT = GetCorrelation8_2xM(M - M5); + for (size_t j = 0; j < N; j += 8) + { + size_t dN = Simd::Min(8, N - j); + size_t i = 0; + for (; i < M5; i += 5) + correlation_2x5(dN, K, ad + i * K, bd, an + i * 4, bn, N, distances + i * stride, stride); + if (i < M) + correlation_2xT(dN, K, ad + i * K, bd, an + i * 4, bn, N, distances + i * stride, stride); + bd += K * 8; + bn += 8; + distances += 8; + } + } + + + //------------------------------------------------------------------------------------------------- + + Sse41::DescrInt::UnpackDataPtr GetUnpackData(size_t depth, bool transpose) + { + switch (depth) + { + case 7: return transpose ? UnpackDataB7 : UnpackDataA7; + case 8: return transpose ? UnpackDataB8 : UnpackDataA8; + default: return NULL; + } + } + + Sse41::DescrInt::MacroCosineDistancesUnpackPtr GetMacroCosineDistancesUnpack(size_t depth) + { + return depth == 8 ? MacroCorrelation16 : MacroCorrelation8; + } + } +#endif +} From ebf98bbdd8f0283d22dbe0337c07c2dfea4ab340 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Tue, 27 Jun 2023 13:15:44 +0300 Subject: [PATCH 31/44] *improve: SSE4.1 optimizations of functions DescrIntCosineDistancesMxNp and DescrIntCosineDistancesMxNa for 4, 5, 6-bits depth. --- docs/2023.html | 4 +- src/Simd/SimdDescrIntCommon.h | 1 + src/Simd/SimdSse41DescrInt.cpp | 2 +- src/Simd/SimdSse41DescrIntScu.cpp | 122 ++++++++++++++++++++++-------- 4 files changed, 95 insertions(+), 34 deletions(-) diff --git a/docs/2023.html b/docs/2023.html index 00e519968e..58ebbd6698 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -50,8 +50,8 @@
        New features
      Improving
        -
      • SSE4.1 optimizations of function DescrIntCosineDistancesMxNp for 7-bit depth.
      • -
      • SSE4.1 optimizations of function DescrIntCosineDistancesMxNa for 7-bit depth.
      • +
      • SSE4.1 optimizations of function DescrIntCosineDistancesMxNp for 4, 5, 6, 7-bits depth.
      • +
      • SSE4.1 optimizations of function DescrIntCosineDistancesMxNa for 4, 5, 6, 7-bits depth.
      Bug fixing
        diff --git a/src/Simd/SimdDescrIntCommon.h b/src/Simd/SimdDescrIntCommon.h index 050ac789ff..74c6a46996 100644 --- a/src/Simd/SimdDescrIntCommon.h +++ b/src/Simd/SimdDescrIntCommon.h @@ -69,6 +69,7 @@ namespace Simd const __m128i C4_MULLO = SIMD_MM_SETR_EPI16(4096, 256, 4096, 256, 4096, 256, 4096, 256); const __m128i C4_SHFL0 = SIMD_MM_SETR_EPI8(0x0, 0x0, 0x0, 0x0, 0x1, 0x1, 0x1, 0x1, 0x2, 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x3); + const __m128i C4_SHFL1 = SIMD_MM_SETR_EPI8(0x4, 0x4, 0x4, 0x4, 0x5, 0x5, 0x5, 0x5, 0x6, 0x6, 0x6, 0x6, 0x7, 0x7, 0x7, 0x7); const __m128i C5_SHFL0 = SIMD_MM_SETR_EPI8(0x0, 0x0, 0x0, 0x1, 0x1, 0x1, 0x1, 0x2, 0x2, 0x3, 0x3, 0x3, 0x3, 0x4, 0x4, 0x4); const __m128i C5_SHFL1 = SIMD_MM_SETR_EPI8(0x5, 0x5, 0x5, 0x6, 0x6, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8, 0x8, 0x8, 0x9, 0x9, 0x9); diff --git a/src/Simd/SimdSse41DescrInt.cpp b/src/Simd/SimdSse41DescrInt.cpp index 85efd6742a..5ca2328edb 100644 --- a/src/Simd/SimdSse41DescrInt.cpp +++ b/src/Simd/SimdSse41DescrInt.cpp @@ -139,7 +139,7 @@ namespace Simd void DescrInt::CosineDistancesMxNa(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, float* distances) const { - if(_unpSize * _microNu > Base::AlgCacheL1() || N * 2 < _microNu || _depth < 7 || _depth == 8) + if(_unpSize * _microNu > Base::AlgCacheL1() || N * 2 < _microNu || _depth < 5 || _depth == 8) CosineDistancesDirect(M, N, A, B, distances); else CosineDistancesUnpack(M, N, A, B, distances); diff --git a/src/Simd/SimdSse41DescrIntScu.cpp b/src/Simd/SimdSse41DescrIntScu.cpp index 8d1fdcdcf4..dc2a2bfb13 100644 --- a/src/Simd/SimdSse41DescrIntScu.cpp +++ b/src/Simd/SimdSse41DescrIntScu.cpp @@ -37,14 +37,65 @@ namespace Simd #ifdef SIMD_SSE41_ENABLE namespace Sse41 { - SIMD_INLINE __m128i UnpackData7x8(const uint8_t* src) + template __m128i UnpackData8(const uint8_t* src); + + template<> SIMD_INLINE __m128i UnpackData8<4>(const uint8_t* src) + { + __m128i _src = _mm_loadl_epi64((__m128i*)src); + __m128i lo = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C4_SHFL0), C4_MULLO), 12); + return _mm_packus_epi16(lo, K_ZERO); + } + + template<> SIMD_INLINE __m128i UnpackData8<5>(const uint8_t* src) + { + __m128i _src = _mm_loadl_epi64((__m128i*)src); + __m128i lo = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C5_SHFL0), C5_MULLO), 11); + return _mm_packus_epi16(lo, K_ZERO); + } + + template<> SIMD_INLINE __m128i UnpackData8<6>(const uint8_t* src) + { + __m128i _src = _mm_loadl_epi64((__m128i*)src); + __m128i lo = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C6_SHFL0), C6_MULLO), 10); + return _mm_packus_epi16(lo, K_ZERO); + } + + template<> SIMD_INLINE __m128i UnpackData8<7>(const uint8_t* src) { __m128i _src = _mm_loadl_epi64((__m128i*)src); __m128i lo = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C7_SHFL0), C7_MULLO), 9); return _mm_packus_epi16(lo, K_ZERO); } - SIMD_INLINE __m128i UnpackData7x16(const uint8_t* src) + //------------------------------------------------------------------------------------------------- + + template __m128i UnpackData16(const uint8_t* src); + + template<> SIMD_INLINE __m128i UnpackData16<4>(const uint8_t* src) + { + __m128i _src = _mm_loadu_si128((__m128i*)src); + __m128i lo = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C4_SHFL0), C4_MULLO), 12); + __m128i hi = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C4_SHFL1), C4_MULLO), 12); + return _mm_packus_epi16(lo, hi); + } + + template<> SIMD_INLINE __m128i UnpackData16<5>(const uint8_t* src) + { + __m128i _src = _mm_loadu_si128((__m128i*)src); + __m128i lo = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C5_SHFL0), C5_MULLO), 11); + __m128i hi = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C5_SHFL1), C5_MULLO), 11); + return _mm_packus_epi16(lo, hi); + } + + template<> SIMD_INLINE __m128i UnpackData16<6>(const uint8_t* src) + { + __m128i _src = _mm_loadu_si128((__m128i*)src); + __m128i lo = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C6_SHFL0), C6_MULLO), 10); + __m128i hi = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C6_SHFL1), C6_MULLO), 10); + return _mm_packus_epi16(lo, hi); + } + + template<> SIMD_INLINE __m128i UnpackData16<7>(const uint8_t* src) { __m128i _src = _mm_loadu_si128((__m128i*)src); __m128i lo = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C7_SHFL0), C7_MULLO), 9); @@ -52,7 +103,9 @@ namespace Simd return _mm_packus_epi16(lo, hi); } - static void UnpackDataA7(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride) + //------------------------------------------------------------------------------------------------- + + template void UnpackDataA(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride) { size_t size16 = AlignLo(size, 16); for (size_t i = 0; i < count; i++) @@ -60,13 +113,15 @@ namespace Simd const uint8_t* ps = src[i] + 16; uint8_t* pd = (uint8_t*)dst + i * size; size_t j = 0; - for (; j < size16; j += 16, ps += 14, pd += 16) - _mm_storeu_si128((__m128i*)pd, UnpackData7x16(ps)); - for (; j < size; j += 8, ps += 7, pd += 8) - _mm_storel_epi64((__m128i*)pd, UnpackData7x8(ps)); + for (; j < size16; j += 16, ps += 2 * bits, pd += 16) + _mm_storeu_si128((__m128i*)pd, UnpackData16(ps)); + for (; j < size; j += 8, ps += bits, pd += 8) + _mm_storel_epi64((__m128i*)pd, UnpackData8(ps)); } } + //------------------------------------------------------------------------------------------------- + static void UnpackDataA8(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride) { size_t size16 = AlignLo(size, 16); @@ -90,12 +145,12 @@ namespace Simd //------------------------------------------------------------------------------------------------- - SIMD_INLINE void UnpackDataB7x4x16(const uint8_t* const* src, size_t offset, uint8_t* dst) + template SIMD_INLINE void UnpackDataBx4x16(const uint8_t* const* src, size_t offset, uint8_t* dst) { - __m128i a0 = UnpackData7x16(src[0] + offset); - __m128i a1 = UnpackData7x16(src[1] + offset); - __m128i a2 = UnpackData7x16(src[2] + offset); - __m128i a3 = UnpackData7x16(src[3] + offset); + __m128i a0 = UnpackData16(src[0] + offset); + __m128i a1 = UnpackData16(src[1] + offset); + __m128i a2 = UnpackData16(src[2] + offset); + __m128i a3 = UnpackData16(src[3] + offset); __m128i b0 = _mm_unpacklo_epi32(a0, a2); __m128i b1 = _mm_unpacklo_epi32(a1, a3); __m128i b2 = _mm_unpackhi_epi32(a0, a2); @@ -106,32 +161,32 @@ namespace Simd _mm_storeu_si128((__m128i*)dst + 6, _mm_unpackhi_epi32(b2, b3)); } - SIMD_INLINE void UnpackDataB7x4x8(const uint8_t* const* src, size_t offset, uint8_t* dst) + template SIMD_INLINE void UnpackDataBx4x8(const uint8_t* const* src, size_t offset, uint8_t* dst) { - __m128i a0 = UnpackData7x8(src[0] + offset); - __m128i a1 = UnpackData7x8(src[1] + offset); - __m128i a2 = UnpackData7x8(src[2] + offset); - __m128i a3 = UnpackData7x8(src[3] + offset); + __m128i a0 = UnpackData8(src[0] + offset); + __m128i a1 = UnpackData8(src[1] + offset); + __m128i a2 = UnpackData8(src[2] + offset); + __m128i a3 = UnpackData8(src[3] + offset); __m128i b0 = _mm_unpacklo_epi32(a0, a2); __m128i b1 = _mm_unpacklo_epi32(a1, a3); _mm_storeu_si128((__m128i*)dst + 0, _mm_unpacklo_epi32(b0, b1)); _mm_storeu_si128((__m128i*)dst + 2, _mm_unpackhi_epi32(b0, b1)); } - static void UnpackDataB7(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride) + template void UnpackDataB(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride) { size_t count8 = AlignLo(count, 8), size16 = AlignLo(size, 16), i, j, o; for (i = 0; i < count8; i += 8, src += 8) { - for (j = 0, o = 16; j < size16; j += 16, o += 14, dst += 8 * A) + for (j = 0, o = 16; j < size16; j += 16, o += 2 * bits, dst += 8 * A) { - UnpackDataB7x4x16(src + 0, o, dst + 0); - UnpackDataB7x4x16(src + 4, o, dst + A); + UnpackDataBx4x16(src + 0, o, dst + 0); + UnpackDataBx4x16(src + 4, o, dst + A); } - for (; j < size; j += 8, o += 7, dst += 4 * A) + for (; j < size; j += 8, o += bits, dst += 4 * A) { - UnpackDataB7x4x8(src + 0, o, dst + 0); - UnpackDataB7x4x8(src + 2, o, dst + A); + UnpackDataBx4x8(src + 0, o, dst + 0); + UnpackDataBx4x8(src + 2, o, dst + A); } } if (i < count) @@ -139,19 +194,21 @@ namespace Simd const uint8_t* _src[8]; for (size_t j = 0; j < 8; i++, j++) _src[j] = i < count ? *src++ : src[-1]; - for (j = 0, o = 16; j < size16; j += 16, o += 14, dst += 8 * A) + for (j = 0, o = 16; j < size16; j += 16, o += 2 * bits, dst += 8 * A) { - UnpackDataB7x4x16(src + 0, o, dst + 0); - UnpackDataB7x4x16(src + 4, o, dst + A); + UnpackDataBx4x16(src + 0, o, dst + 0); + UnpackDataBx4x16(src + 4, o, dst + A); } - for (; j < size; j += 8, o += 7, dst += 4 * A) + for (; j < size; j += 8, o += bits, dst += 4 * A) { - UnpackDataB7x4x8(src + 0, o, dst + 0); - UnpackDataB7x4x8(src + 2, o, dst + A); + UnpackDataBx4x8(src + 0, o, dst + 0); + UnpackDataBx4x8(src + 2, o, dst + A); } } } + //------------------------------------------------------------------------------------------------- + SIMD_INLINE void UnpackDataB8x4(const uint8_t* const* src, size_t offset, uint8_t* dst) { __m128i a0 = UnpackU8<0>(_mm_loadl_epi64((__m128i*)(src[0] + offset))); @@ -453,7 +510,10 @@ namespace Simd { switch (depth) { - case 7: return transpose ? UnpackDataB7 : UnpackDataA7; + case 4: return transpose ? UnpackDataB<4> : UnpackDataA<4>; + case 5: return transpose ? UnpackDataB<5> : UnpackDataA<5>; + case 6: return transpose ? UnpackDataB<6> : UnpackDataA<6>; + case 7: return transpose ? UnpackDataB<7> : UnpackDataA<7>; case 8: return transpose ? UnpackDataB8 : UnpackDataA8; default: return NULL; } From eb33159f9c76adddb0c0cac0fa95dc949f52306e Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Tue, 27 Jun 2023 15:59:50 +0300 Subject: [PATCH 32/44] *improve AVX2 optimizations of functions DescrIntCosineDistancesMxNp, DescrIntCosineDistancesMxNa for 4, 5, 6, 7-bits depth. --- prj/vs2019/Avx2.vcxproj | 4 + prj/vs2019/Avx2.vcxproj.filters | 12 + prj/vs2019/Sse41.vcxproj | 4 +- prj/vs2019/Sse41.vcxproj.filters | 4 +- prj/vs2022/Avx2.vcxproj | 4 + prj/vs2022/Avx2.vcxproj.filters | 12 + prj/vs2022/Sse41.vcxproj | 4 +- prj/vs2022/Sse41.vcxproj.filters | 4 +- src/Simd/SimdAvx2DescrInt.cpp | 1391 +---------------- src/Simd/SimdAvx2DescrIntCdd.cpp | 753 +++++++++ src/Simd/SimdAvx2DescrIntCdu.cpp | 359 +++++ src/Simd/SimdAvx2DescrIntDec.cpp | 307 ++++ src/Simd/SimdAvx2DescrIntEnc.cpp | 432 +++++ src/Simd/SimdDescrInt.h | 14 + src/Simd/SimdDescrIntCommon.h | 66 +- src/Simd/SimdSse41DescrInt.cpp | 2 +- ...scrIntScd.cpp => SimdSse41DescrIntCdd.cpp} | 0 ...scrIntScu.cpp => SimdSse41DescrIntCdu.cpp} | 135 +- 18 files changed, 2076 insertions(+), 1431 deletions(-) create mode 100644 src/Simd/SimdAvx2DescrIntCdd.cpp create mode 100644 src/Simd/SimdAvx2DescrIntCdu.cpp create mode 100644 src/Simd/SimdAvx2DescrIntDec.cpp create mode 100644 src/Simd/SimdAvx2DescrIntEnc.cpp rename src/Simd/{SimdSse41DescrIntScd.cpp => SimdSse41DescrIntCdd.cpp} (100%) rename src/Simd/{SimdSse41DescrIntScu.cpp => SimdSse41DescrIntCdu.cpp} (77%) diff --git a/prj/vs2019/Avx2.vcxproj b/prj/vs2019/Avx2.vcxproj index 56c2e4fa5f..a037021cb5 100644 --- a/prj/vs2019/Avx2.vcxproj +++ b/prj/vs2019/Avx2.vcxproj @@ -32,6 +32,10 @@ + + + + diff --git a/prj/vs2019/Avx2.vcxproj.filters b/prj/vs2019/Avx2.vcxproj.filters index a5336ca79a..782a659ea2 100644 --- a/prj/vs2019/Avx2.vcxproj.filters +++ b/prj/vs2019/Avx2.vcxproj.filters @@ -337,6 +337,18 @@ Avx2 + + Avx2 + + + Avx2 + + + Avx2 + + + Avx2 + diff --git a/prj/vs2019/Sse41.vcxproj b/prj/vs2019/Sse41.vcxproj index a181de12c6..d1f80d774f 100644 --- a/prj/vs2019/Sse41.vcxproj +++ b/prj/vs2019/Sse41.vcxproj @@ -37,8 +37,8 @@ - - + + diff --git a/prj/vs2019/Sse41.vcxproj.filters b/prj/vs2019/Sse41.vcxproj.filters index 4e224f91b1..2ba7ae9585 100644 --- a/prj/vs2019/Sse41.vcxproj.filters +++ b/prj/vs2019/Sse41.vcxproj.filters @@ -382,10 +382,10 @@ Sse41 - + Sse41 - + Sse41 diff --git a/prj/vs2022/Avx2.vcxproj b/prj/vs2022/Avx2.vcxproj index 56c2e4fa5f..a037021cb5 100644 --- a/prj/vs2022/Avx2.vcxproj +++ b/prj/vs2022/Avx2.vcxproj @@ -32,6 +32,10 @@ + + + + diff --git a/prj/vs2022/Avx2.vcxproj.filters b/prj/vs2022/Avx2.vcxproj.filters index a5336ca79a..782a659ea2 100644 --- a/prj/vs2022/Avx2.vcxproj.filters +++ b/prj/vs2022/Avx2.vcxproj.filters @@ -337,6 +337,18 @@ Avx2 + + Avx2 + + + Avx2 + + + Avx2 + + + Avx2 + diff --git a/prj/vs2022/Sse41.vcxproj b/prj/vs2022/Sse41.vcxproj index a181de12c6..d1f80d774f 100644 --- a/prj/vs2022/Sse41.vcxproj +++ b/prj/vs2022/Sse41.vcxproj @@ -37,8 +37,8 @@ - - + + diff --git a/prj/vs2022/Sse41.vcxproj.filters b/prj/vs2022/Sse41.vcxproj.filters index 4e224f91b1..2ba7ae9585 100644 --- a/prj/vs2022/Sse41.vcxproj.filters +++ b/prj/vs2022/Sse41.vcxproj.filters @@ -382,10 +382,10 @@ Sse41 - + Sse41 - + Sse41 diff --git a/src/Simd/SimdAvx2DescrInt.cpp b/src/Simd/SimdAvx2DescrInt.cpp index d56b63afb1..0c37f672d7 100644 --- a/src/Simd/SimdAvx2DescrInt.cpp +++ b/src/Simd/SimdAvx2DescrInt.cpp @@ -29,6 +29,7 @@ #include "Simd/SimdDescrInt.h" #include "Simd/SimdDescrIntCommon.h" #include "Simd/SimdCpu.h" +#include "Simd/SimdLoad.h" namespace Simd { @@ -71,1298 +72,56 @@ namespace Simd //------------------------------------------------------------------------------------------------- - SIMD_INLINE __m256i Encode32f(__m256 src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + static void UnpackNormA(size_t count, const uint8_t* const* src, float* dst, size_t stride) { - __m256i value = _mm256_cvtps_epi32(_mm256_mul_ps(_mm256_sub_ps(src, min), scale)); - sum = _mm256_add_epi32(value, sum); - sqsum = _mm256_add_epi32(_mm256_madd_epi16(value, value), sqsum); - return value; - } - - SIMD_INLINE __m256i Encode32f(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) - { - return Encode32f(_mm256_loadu_ps(src), scale, min, sum, sqsum); - } - - static SIMD_INLINE __m128i Encode32f4x8(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) - { - __m256i i0 = Encode32f(src + 0 * 8, scale, min, sum, sqsum); - __m128i s0 = _mm_srli_epi32(_mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E4_MULLO), 12); - return _mm_packus_epi16(_mm_packus_epi32(s0, Sse41::K_ZERO), Sse41::K_ZERO); - } - - static SIMD_INLINE __m128i Encode32f4x32(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) - { - __m256i i0 = Encode32f(src + 0 * 8, scale, min, sum, sqsum); - __m256i i1 = Encode32f(src + 1 * 8, scale, min, sum, sqsum); - __m256i s0 = _mm256_srli_epi32(_mm256_mullo_epi16(PackU32ToI16(i0, i1), E4_MULLO), 12); - __m256i i2 = Encode32f(src + 2 * 8, scale, min, sum, sqsum); - __m256i i3 = Encode32f(src + 3 * 8, scale, min, sum, sqsum); - __m256i s1 = _mm256_srli_epi32(_mm256_mullo_epi16(PackU32ToI16(i2, i3), E4_MULLO), 12); - return _mm_packus_epi16(_mm_packus_epi32(_mm256_castsi256_si128(s0), _mm256_extracti128_si256(s0, 1)), - _mm_packus_epi32(_mm256_castsi256_si128(s1), _mm256_extracti128_si256(s1, 1))); - } - - static void Encode32f4(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t i = 0, size32 = AlignLo(size, 32); - __m256 _scale = _mm256_set1_ps(scale); - __m256 _min = _mm256_set1_ps(min); - __m256i _sum = _mm256_setzero_si256(); - __m256i _sqsum = _mm256_setzero_si256(); - for (; i < size32; i += 32, src += 32, dst += 16) - _mm_storeu_si128((__m128i*)dst, Encode32f4x32(src, _scale, _min, _sum, _sqsum)); - for (; i < size; i += 8, src += 8, dst += 4) - *(uint32_t*)(dst) = _mm_extract_epi32(Encode32f4x8(src, _scale, _min, _sum, _sqsum), 0); - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - static SIMD_INLINE __m128i Encode32f5x1(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) - { - __m256i i0 = Encode32f(src + 0, scale, min, sum, sqsum); - __m128i s0 = _mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E5_MULLO); - return _mm_or_si128(_mm_or_si128(_mm_shuffle_epi8(s0, Sse41::E5_SHFL0), _mm_shuffle_epi8(s0, Sse41::E5_SHFL1)), _mm_shuffle_epi8(s0, Sse41::E5_SHFL2)); - } - - static SIMD_INLINE __m128i Encode32f5x2(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) - { - __m256i i0 = Encode32f(src + 0, scale, min, sum, sqsum); - __m256i i8 = Encode32f(src + 8, scale, min, sum, sqsum); - __m256i s0 = _mm256_mullo_epi16(PackU32ToI16(i0, i8), E5_MULLO); - __m256i e0 = _mm256_or_si256(_mm256_or_si256(_mm256_shuffle_epi8(s0, E5_SHFL0), _mm256_shuffle_epi8(s0, E5_SHFL1)), _mm256_shuffle_epi8(s0, E5_SHFL2)); - return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); - } - - static void Encode32f5(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t i = 0, main = size - 8, main16 = AlignLo(main, 16); - __m256 _scale = _mm256_set1_ps(scale); - __m256 _min = _mm256_set1_ps(min); - __m256i _sum = _mm256_setzero_si256(); - __m256i _sqsum = _mm256_setzero_si256(); - for (; i < main16; i += 16, src += 16, dst += 10) - _mm_storeu_si128((__m128i*)dst, Encode32f5x2(src, _scale, _min, _sum, _sqsum)); - for (; i < main; i += 8, src += 8, dst += 5) - _mm_storel_epi64((__m128i*)dst, Encode32f5x1(src, _scale, _min, _sum, _sqsum)); - for (; i < size; i += 8, src += 8, dst += 5) - { - __m128i d0 = Encode32f5x1(src, _scale, _min, _sum, _sqsum); - *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); - *(uint8_t*)(dst + 4) = _mm_extract_epi8(d0, 4); - } - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - static SIMD_INLINE __m128i Encode32f6x1(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) - { - __m256i i0 = Encode32f(src + 0, scale, min, sum, sqsum); - __m128i s0 = _mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E6_MULLO); - return _mm_or_si128(_mm_shuffle_epi8(s0, Sse41::E6_SHFL0), _mm_shuffle_epi8(s0, Sse41::E6_SHFL1)); - } - - static SIMD_INLINE __m128i Encode32f6x2(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) - { - __m256i i0 = Encode32f(src + 0, scale, min, sum, sqsum); - __m256i i8 = Encode32f(src + 8, scale, min, sum, sqsum); - __m256i s0 = _mm256_mullo_epi16(PackU32ToI16(i0, i8), E6_MULLO); - __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, E6_SHFL0), _mm256_shuffle_epi8(s0, E6_SHFL1)); - return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); - } - - static void Encode32f6(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t i = 0, main = size - 8, main16 = AlignLo(main, 16); - __m256 _scale = _mm256_set1_ps(scale); - __m256 _min = _mm256_set1_ps(min); - __m256i _sum = _mm256_setzero_si256(); - __m256i _sqsum = _mm256_setzero_si256(); - for (; i < main16; i += 16, src += 16, dst += 12) - _mm_storeu_si128((__m128i*)dst, Encode32f6x2(src, _scale, _min, _sum, _sqsum)); - for (; i < main; i += 8, src += 8, dst += 6) - _mm_storel_epi64((__m128i*)dst, Encode32f6x1(src, _scale, _min, _sum, _sqsum)); - for (; i < size; i += 8, src += 8, dst += 6) - { - __m128i d0 = Encode32f6x1(src, _scale, _min, _sum, _sqsum); - *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); - *(uint16_t*)(dst + 4) = _mm_extract_epi16(d0, 2); - } - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - static SIMD_INLINE __m128i Encode32f7x1(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) - { - __m256i i0 = Encode32f(src + 0, scale, min, sum, sqsum); - __m128i s0 = _mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E7_MULLO); - return _mm_or_si128(_mm_shuffle_epi8(s0, Sse41::E7_SHFL0), _mm_shuffle_epi8(s0, Sse41::E7_SHFL1)); - } - - static SIMD_INLINE __m128i Encode32f7x2(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) - { - __m256i i0 = Encode32f(src + 0, scale, min, sum, sqsum); - __m256i i8 = Encode32f(src + 8, scale, min, sum, sqsum); - __m256i s0 = _mm256_mullo_epi16(PackU32ToI16(i0, i8), E7_MULLO); - __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, E7_SHFL0), _mm256_shuffle_epi8(s0, E7_SHFL1)); - return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); - } - - static void Encode32f7(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t i = 0, main = size - 8, main16 = AlignLo(main, 16); - __m256 _scale = _mm256_set1_ps(scale); - __m256 _min = _mm256_set1_ps(min); - __m256i _sum = _mm256_setzero_si256(); - __m256i _sqsum = _mm256_setzero_si256(); - for (; i < main16; i += 16, src += 16, dst += 14) - _mm_storeu_si128((__m128i*)dst, Encode32f7x2(src, _scale, _min, _sum, _sqsum)); - for (; i < main; i += 8, src += 8, dst += 7) - _mm_storel_epi64((__m128i*)dst, Encode32f7x1(src, _scale, _min, _sum, _sqsum)); - for (; i < size; i += 8, src += 8, dst += 7) - { - __m128i d0 = Encode32f7x1(src, _scale, _min, _sum, _sqsum); - *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); - *(uint16_t*)(dst + 4) = _mm_extract_epi16(d0, 2); - *(uint8_t*)(dst + 6) = _mm_extract_epi8(d0, 6); - } - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - static void Encode32f8(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t sizeA = AlignLo(size, A), i = 0; - __m256 _scale = _mm256_set1_ps(scale); - __m256 _min = _mm256_set1_ps(min); - __m256i _sum = _mm256_setzero_si256(); - __m256i _sqsum = _mm256_setzero_si256(); - for (; i < sizeA; i += A) - { - __m256i d0 = Encode32f(src + i + 0 * F, _scale, _min, _sum, _sqsum); - __m256i d1 = Encode32f(src + i + 1 * F, _scale, _min, _sum, _sqsum); - __m256i d2 = Encode32f(src + i + 2 * F, _scale, _min, _sum, _sqsum); - __m256i d3 = Encode32f(src + i + 3 * F, _scale, _min, _sum, _sqsum); - _mm256_storeu_si256((__m256i*)(dst + i), PackI16ToU8(PackI32ToI16(d0, d1), PackI32ToI16(d2, d3))); - } - for (; i < size; i += F) - { - __m256i d0 = Encode32f(src + i, _scale, _min, _sum, _sqsum); - _mm_storel_epi64((__m128i*)(dst + i), _mm256_castsi256_si128(PackI16ToU8(PackI32ToI16(d0, _mm256_setzero_si256()), _mm256_setzero_si256()))); - } - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - //------------------------------------------------------------------------------------------------- - - static SIMD_INLINE __m128i Encode16f4x8(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) - { - __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src)), scale, min, sum, sqsum); - __m128i s0 = _mm_srli_epi32(_mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E4_MULLO), 12); - return _mm_packus_epi16(_mm_packus_epi32(s0, Sse41::K_ZERO), Sse41::K_ZERO); - } - - static SIMD_INLINE __m128i Encode16f4x32(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) - { - __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 0)), scale, min, sum, sqsum); - __m256i i1 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 1)), scale, min, sum, sqsum); - __m256i s0 = _mm256_srli_epi32(_mm256_mullo_epi16(PackU32ToI16(i0, i1), E4_MULLO), 12); - __m256i i2 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 2)), scale, min, sum, sqsum); - __m256i i3 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 3)), scale, min, sum, sqsum); - __m256i s1 = _mm256_srli_epi32(_mm256_mullo_epi16(PackU32ToI16(i2, i3), E4_MULLO), 12); - return _mm_packus_epi16(_mm_packus_epi32(_mm256_castsi256_si128(s0), _mm256_extracti128_si256(s0, 1)), - _mm_packus_epi32(_mm256_castsi256_si128(s1), _mm256_extracti128_si256(s1, 1))); - } - - static void Encode16f4(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t i = 0, size32 = AlignLo(size, 32); - __m256 _scale = _mm256_set1_ps(scale); - __m256 _min = _mm256_set1_ps(min); - __m256i _sum = _mm256_setzero_si256(); - __m256i _sqsum = _mm256_setzero_si256(); - for (; i < size32; i += 32, src += 32, dst += 16) - _mm_storeu_si128((__m128i*)dst, Encode16f4x32(src, _scale, _min, _sum, _sqsum)); - for (; i < size; i += 8, src += 8, dst += 4) - *(uint32_t*)(dst) = _mm_extract_epi32(Encode16f4x8(src, _scale, _min, _sum, _sqsum), 0); - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - static SIMD_INLINE __m128i Encode16f5x1(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) - { - __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src)), scale, min, sum, sqsum); - __m128i s0 = _mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E5_MULLO); - return _mm_or_si128(_mm_or_si128(_mm_shuffle_epi8(s0, Sse41::E5_SHFL0), _mm_shuffle_epi8(s0, Sse41::E5_SHFL1)), _mm_shuffle_epi8(s0, Sse41::E5_SHFL2)); - } - - static SIMD_INLINE __m128i Encode16f5x2(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) - { - __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 0)), scale, min, sum, sqsum); - __m256i i8 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 1)), scale, min, sum, sqsum); - __m256i s0 = _mm256_mullo_epi16(PackU32ToI16(i0, i8), E5_MULLO); - __m256i e0 = _mm256_or_si256(_mm256_or_si256(_mm256_shuffle_epi8(s0, E5_SHFL0), _mm256_shuffle_epi8(s0, E5_SHFL1)), _mm256_shuffle_epi8(s0, E5_SHFL2)); - return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); - } - - static void Encode16f5(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t i = 0, main = size - 8, main16 = AlignLo(main, 16); - __m256 _scale = _mm256_set1_ps(scale); - __m256 _min = _mm256_set1_ps(min); - __m256i _sum = _mm256_setzero_si256(); - __m256i _sqsum = _mm256_setzero_si256(); - for (; i < main16; i += 16, src += 16, dst += 10) - _mm_storeu_si128((__m128i*)dst, Encode16f5x2(src, _scale, _min, _sum, _sqsum)); - for (; i < main; i += 8, src += 8, dst += 5) - _mm_storel_epi64((__m128i*)dst, Encode16f5x1(src, _scale, _min, _sum, _sqsum)); - for (; i < size; i += 8, src += 8, dst += 5) - { - __m128i d0 = Encode16f5x1(src, _scale, _min, _sum, _sqsum); - *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); - *(uint8_t*)(dst + 4) = _mm_extract_epi8(d0, 4); - } - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - static SIMD_INLINE __m128i Encode16f6x1(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) - { - __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src)), scale, min, sum, sqsum); - __m128i s0 = _mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E6_MULLO); - return _mm_or_si128(_mm_shuffle_epi8(s0, Sse41::E6_SHFL0), _mm_shuffle_epi8(s0, Sse41::E6_SHFL1)); - } - - static SIMD_INLINE __m128i Encode16f6x2(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) - { - __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 0)), scale, min, sum, sqsum); - __m256i i8 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 1)), scale, min, sum, sqsum); - __m256i s0 = _mm256_mullo_epi16(PackU32ToI16(i0, i8), E6_MULLO); - __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, E6_SHFL0), _mm256_shuffle_epi8(s0, E6_SHFL1)); - return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); - } - - static void Encode16f6(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t i = 0, main = size - 8, main16 = AlignLo(main, 16); - __m256 _scale = _mm256_set1_ps(scale); - __m256 _min = _mm256_set1_ps(min); - __m256i _sum = _mm256_setzero_si256(); - __m256i _sqsum = _mm256_setzero_si256(); - for (; i < main16; i += 16, src += 16, dst += 12) - _mm_storeu_si128((__m128i*)dst, Encode16f6x2(src, _scale, _min, _sum, _sqsum)); - for (; i < main; i += 8, src += 8, dst += 6) - _mm_storel_epi64((__m128i*)dst, Encode16f6x1(src, _scale, _min, _sum, _sqsum)); - for (; i < size; i += 8, src += 8, dst += 6) - { - __m128i d0 = Encode16f6x1(src, _scale, _min, _sum, _sqsum); - *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); - *(uint16_t*)(dst + 4) = _mm_extract_epi16(d0, 2); - } - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - static SIMD_INLINE __m128i Encode16f7x1(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) - { - __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src)), scale, min, sum, sqsum); - __m128i s0 = _mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E7_MULLO); - return _mm_or_si128(_mm_shuffle_epi8(s0, Sse41::E7_SHFL0), _mm_shuffle_epi8(s0, Sse41::E7_SHFL1)); - } - - static SIMD_INLINE __m128i Encode16f7x2(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) - { - __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 0)), scale, min, sum, sqsum); - __m256i i8 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 1)), scale, min, sum, sqsum); - __m256i s0 = _mm256_mullo_epi16(PackU32ToI16(i0, i8), E7_MULLO); - __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, E7_SHFL0), _mm256_shuffle_epi8(s0, E7_SHFL1)); - return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); - } - - static void Encode16f7(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t i = 0, main = size - 8, main16 = AlignLo(main, 16); - __m256 _scale = _mm256_set1_ps(scale); - __m256 _min = _mm256_set1_ps(min); - __m256i _sum = _mm256_setzero_si256(); - __m256i _sqsum = _mm256_setzero_si256(); - for (; i < main16; i += 16, src += 16, dst += 14) - _mm_storeu_si128((__m128i*)dst, Encode16f7x2(src, _scale, _min, _sum, _sqsum)); - for (; i < main; i += 8, src += 8, dst += 7) - _mm_storel_epi64((__m128i*)dst, Encode16f7x1(src, _scale, _min, _sum, _sqsum)); - for (; i < size; i += 8, src += 8, dst += 7) - { - __m128i d0 = Encode16f7x1(src, _scale, _min, _sum, _sqsum); - *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); - *(uint16_t*)(dst + 4) = _mm_extract_epi16(d0, 2); - *(uint8_t*)(dst + 6) = _mm_extract_epi8(d0, 6); - } - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - static void Encode16f8(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t sizeA = AlignLo(size, A), i = 0; - __m256 _scale = _mm256_set1_ps(scale); - __m256 _min = _mm256_set1_ps(min); - __m256i _sum = _mm256_setzero_si256(); - __m256i _sqsum = _mm256_setzero_si256(); - for (; i < sizeA; i += A) - { - __m256i d0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src + i) + 0)), _scale, _min, _sum, _sqsum); - __m256i d1 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src + i) + 1)), _scale, _min, _sum, _sqsum); - __m256i d2 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src + i) + 2)), _scale, _min, _sum, _sqsum); - __m256i d3 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src + i) + 3)), _scale, _min, _sum, _sqsum); - _mm256_storeu_si256((__m256i*)(dst + i), PackI16ToU8(PackI32ToI16(d0, d1), PackI32ToI16(d2, d3))); - } - for (; i < size; i += F) - { - __m256i d0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src + i))), _scale, _min, _sum, _sqsum); - _mm_storel_epi64((__m128i*)(dst + i), _mm256_castsi256_si128(PackI16ToU8(PackI32ToI16(d0, _mm256_setzero_si256()), _mm256_setzero_si256()))); - } - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); + size_t count2 = AlignLo(count, 2), i = 0; + for (; i < count2; i += 2, src += 2, dst += 8) + _mm256_storeu_ps(dst, Avx::Load((float*)src[0], (float*)src[1])); + for (; i < count; ++i, src += 1, dst += 4) + _mm_storeu_ps(dst, _mm_loadu_ps((float*)src[0])); } //------------------------------------------------------------------------------------------------- - static void Decode32f4(const uint8_t* src, float scale, float shift, size_t size, float* dst) - { - assert(size % 8 == 0); - __m256 _scale = _mm256_set1_ps(scale); - __m256 _shift = _mm256_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16); - for (; i < size16; i += 16) - { - __m256i s4 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); - __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s4, C4_SHFL), C4_MULLO), 12); - _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift)); - _mm256_storeu_ps(dst + 8, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift)); - src += 8; - dst += 16; - } - for (; i < size; i += 8) - { - __m128i s4 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s4, Sse41::C4_SHFL0), Sse41::C4_MULLO), 12); - _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift)); - src += 4; - dst += 8; - } - } - - static void Decode32f5(const uint8_t* src, float scale, float shift, size_t size, float* dst) - { - assert(size % 8 == 0); - __m256 _scale = _mm256_set1_ps(scale); - __m256 _shift = _mm256_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16); - for (; i < size16; i += 16) - { - __m256i s5 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); - __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s5, C5_SHFL), C5_MULLO), 11); - _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift)); - _mm256_storeu_ps(dst + 8, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift)); - src += 10; - dst += 16; - } - for (; i < size; i += 8) - { - __m128i s5 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s5, Sse41::C5_SHFL0), Sse41::C5_MULLO), 11); - _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift)); - src += 5; - dst += 8; - } - } - - static void Decode32f6(const uint8_t* src, float scale, float shift, size_t size, float* dst) - { - assert(size % 8 == 0); - __m256 _scale = _mm256_set1_ps(scale); - __m256 _shift = _mm256_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16); - for (; i < size16; i += 16) - { - __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); - __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, C6_SHFL), C6_MULLO), 10); - _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift)); - _mm256_storeu_ps(dst + 8, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift)); - src += 12; - dst += 16; - } - for (; i < size; i += 8) - { - __m128i s6 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s6, Sse41::C6_SHFL0), Sse41::C6_MULLO), 10); - _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift)); - src += 6; - dst += 8; - } - } - - static void Decode32f7(const uint8_t* src, float scale, float shift, size_t size, float* dst) + static void UnpackNormB(size_t count, const uint8_t* const* src, float* dst, size_t stride) { - assert(size % 8 == 0); - __m256 _scale = _mm256_set1_ps(scale); - __m256 _shift = _mm256_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16); - for (; i < size16; i += 16) + size_t count8 = AlignLo(count, 8), count4 = AlignLo(count, 4), i = 0; + for (; i < count8; i += 8, src += 8, dst += 8) { - __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); - __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, C7_SHFL), C7_MULLO), 9); - _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift)); - _mm256_storeu_ps(dst + 8, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift)); - src += 14; - dst += 16; + __m256 s0 = Avx::Load((float*)src[0], (float*)src[4]); + __m256 s1 = Avx::Load((float*)src[1], (float*)src[5]); + __m256 s2 = Avx::Load((float*)src[2], (float*)src[6]); + __m256 s3 = Avx::Load((float*)src[3], (float*)src[7]); + __m256 s00 = _mm256_unpacklo_ps(s0, s2); + __m256 s01 = _mm256_unpacklo_ps(s1, s3); + __m256 s10 = _mm256_unpackhi_ps(s0, s2); + __m256 s11 = _mm256_unpackhi_ps(s1, s3); + _mm256_storeu_ps(dst + 0 * stride, _mm256_unpacklo_ps(s00, s01)); + _mm256_storeu_ps(dst + 1 * stride, _mm256_unpackhi_ps(s00, s01)); + _mm256_storeu_ps(dst + 2 * stride, _mm256_unpacklo_ps(s10, s11)); + _mm256_storeu_ps(dst + 3 * stride, _mm256_unpackhi_ps(s10, s11)); } - for (; i < size; i += 8) + for (; i < count4; i += 4, src += 4, dst += 4) { - __m128i s7 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s7, Sse41::C7_SHFL0), Sse41::C7_MULLO), 9); - _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift)); - src += 7; - dst += 8; + __m128 s0 = _mm_loadu_ps((float*)src[0]); + __m128 s1 = _mm_loadu_ps((float*)src[1]); + __m128 s2 = _mm_loadu_ps((float*)src[2]); + __m128 s3 = _mm_loadu_ps((float*)src[3]); + __m128 s00 = _mm_unpacklo_ps(s0, s2); + __m128 s01 = _mm_unpacklo_ps(s1, s3); + __m128 s10 = _mm_unpackhi_ps(s0, s2); + __m128 s11 = _mm_unpackhi_ps(s1, s3); + _mm_storeu_ps(dst + 0 * stride, _mm_unpacklo_ps(s00, s01)); + _mm_storeu_ps(dst + 1 * stride, _mm_unpackhi_ps(s00, s01)); + _mm_storeu_ps(dst + 2 * stride, _mm_unpacklo_ps(s10, s11)); + _mm_storeu_ps(dst + 3 * stride, _mm_unpackhi_ps(s10, s11)); } - } - - static void Decode32f8(const uint8_t* src, float scale, float shift, size_t size, float* dst) - { - assert(size % 8 == 0); - __m256 _scale = _mm256_set1_ps(scale); - __m256 _shift = _mm256_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16); - for (; i < size16; i += 16) + for (; i < count; i++, src++, dst++) { - __m128i u8 = _mm_loadu_si128((__m128i*)(src + i)); - _mm256_storeu_ps(dst + i + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(u8)), _scale, _shift)); - _mm256_storeu_ps(dst + i + F, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_srli_si128(u8, 8))), _scale, _shift)); - } - for (; i < size; i += 8) - { - __m256 _src = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)(src + i)))); - _mm256_storeu_ps(dst + i, _mm256_fmadd_ps(_src, _scale, _shift)); - } - } - - //------------------------------------------------------------------------------------------------- - - static void Decode16f4(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) - { - assert(size % 8 == 0); - __m256 _scale = _mm256_set1_ps(scale); - __m256 _shift = _mm256_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16); - for (; i < size16; i += 16) - { - __m256i s4 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); - __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s4, C4_SHFL), C4_MULLO), 12); - _mm_storeu_si128((__m128i*)dst + 0, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift), 0)); - _mm_storeu_si128((__m128i*)dst + 1, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift), 0)); - src += 8; - dst += 16; - } - for (; i < size; i += 8) - { - __m128i s4 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s4, Sse41::C4_SHFL0), Sse41::C4_MULLO), 12); - _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift), 0)); - src += 4; - dst += 8; - } - } - - static void Decode16f5(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) - { - assert(size % 8 == 0); - __m256 _scale = _mm256_set1_ps(scale); - __m256 _shift = _mm256_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16); - for (; i < size16; i += 16) - { - __m256i s5 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); - __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s5, C5_SHFL), C5_MULLO), 11); - _mm_storeu_si128((__m128i*)dst + 0, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift), 0)); - _mm_storeu_si128((__m128i*)dst + 1, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift), 0)); - src += 10; - dst += 16; - } - for (; i < size; i += 8) - { - __m128i s5 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s5, Sse41::C5_SHFL0), Sse41::C5_MULLO), 11); - _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift), 0)); - src += 5; - dst += 8; - } - } - - static void Decode16f6(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) - { - assert(size % 8 == 0); - __m256 _scale = _mm256_set1_ps(scale); - __m256 _shift = _mm256_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16); - for (; i < size16; i += 16) - { - __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); - __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, C6_SHFL), C6_MULLO), 10); - _mm_storeu_si128((__m128i*)dst + 0, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift), 0)); - _mm_storeu_si128((__m128i*)dst + 1, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift), 0)); - src += 12; - dst += 16; - } - for (; i < size; i += 8) - { - __m128i s6 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s6, Sse41::C6_SHFL0), Sse41::C6_MULLO), 10); - _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift), 0)); - src += 6; - dst += 8; - } - } - - static void Decode16f7(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) - { - assert(size % 8 == 0); - __m256 _scale = _mm256_set1_ps(scale); - __m256 _shift = _mm256_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16); - for (; i < size16; i += 16) - { - __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); - __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, C7_SHFL), C7_MULLO), 9); - _mm_storeu_si128((__m128i*)dst + 0, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift), 0)); - _mm_storeu_si128((__m128i*)dst + 1, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift), 0)); - src += 14; - dst += 16; - } - for (; i < size; i += 8) - { - __m128i s7 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s7, Sse41::C7_SHFL0), Sse41::C7_MULLO), 9); - _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift), 0)); - src += 7; - dst += 8; - } - } - - static void Decode16f8(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) - { - assert(size % 8 == 0); - __m256 _scale = _mm256_set1_ps(scale); - __m256 _shift = _mm256_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16); - for (; i < size16; i += 16) - { - __m128i u8 = _mm_loadu_si128((__m128i*)(src + i)); - _mm_storeu_si128((__m128i*)(dst + i) + 0, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(u8)), _scale, _shift), 0)); - _mm_storeu_si128((__m128i*)(dst + i) + 1, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_srli_si128(u8, 8))), _scale, _shift), 0)); - } - for (; i < size; i += 8) - { - __m256 _src = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)(src + i)))); - _mm_storeu_si128((__m128i*)(dst + i), _mm256_cvtps_ph(_mm256_fmadd_ps(_src, _scale, _shift), 0)); - } - } - - //------------------------------------------------------------------------------------------------- - - template int32_t Correlation(const uint8_t* a, const uint8_t* b, size_t size); - - template<> int32_t Correlation<4>(const uint8_t* a, const uint8_t* b, size_t size) - { - assert(size % 8 == 0); - __m256i ab32 = _mm256_setzero_si256(); - size_t i = 0, size64 = AlignLo(size, 64); - for (; i < size64; i += 64, a += 32, b += 32) - { - __m256i _a = _mm256_loadu_si256((__m256i*)a); - __m256i _b = _mm256_loadu_si256((__m256i*)b); - __m256i ab16 = _mm256_maddubs_epi16(_mm256_and_si256(_a, K8_0F), _mm256_and_si256(_b, K8_0F)); - ab16 = _mm256_add_epi16(ab16, _mm256_maddubs_epi16(_mm256_and_si256(_mm256_srli_epi16(_a, 4), K8_0F), _mm256_and_si256(_mm256_srli_epi16(_b, 4), K8_0F))); - ab32 = _mm256_add_epi32(ab32, _mm256_madd_epi16(ab16, K16_0001)); - } - for (; i < size; i += 8, a += 4, b += 4) - { - __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), Sse41::C4_SHFL0), Sse41::C4_MULLO), 12); - __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), Sse41::C4_SHFL0), Sse41::C4_MULLO), 12); - ab32 = _mm256_add_epi32(_mm256_madd_epi16(_mm256_castsi128_si256(_a), _mm256_castsi128_si256(_b)), ab32); - } - return ExtractSum(ab32); - } - - template<> int32_t Correlation<5>(const uint8_t* a, const uint8_t* b, size_t size) - { - assert(size % 8 == 0); - __m256i _ab = _mm256_setzero_si256(); - size_t i = 0, size16 = AlignLo(size, 16); - for (; i < size16; i += 16, a += 10, b += 10) - { - __m256i _a = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)a)), C5_SHFL), C5_MULLO), 11); - __m256i _b = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)b)), C5_SHFL), C5_MULLO), 11); - _ab = _mm256_add_epi32(_mm256_madd_epi16(_a, _b), _ab); - } - for (; i < size; i += 8, a += 5, b += 5) - { - __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), Sse41::C5_SHFL0), Sse41::C5_MULLO), 11); - __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), Sse41::C5_SHFL0), Sse41::C5_MULLO), 11); - _ab = _mm256_add_epi32(_mm256_madd_epi16(_mm256_castsi128_si256(_a), _mm256_castsi128_si256(_b)), _ab); - } - return ExtractSum(_ab); - } - - template<> int32_t Correlation<6>(const uint8_t* a, const uint8_t* b, size_t size) - { - assert(size % 8 == 0); - __m256i _ab = _mm256_setzero_si256(); - size_t i = 0, size16 = AlignLo(size, 16); - for (; i < size16; i += 16, a += 12, b += 12) - { - __m256i _a = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)a)), C6_SHFL), C6_MULLO), 10); - __m256i _b = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)b)), C6_SHFL), C6_MULLO), 10); - _ab = _mm256_add_epi32(_mm256_madd_epi16(_a, _b), _ab); - } - for (; i < size; i += 8, a += 6, b += 6) - { - __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), Sse41::C6_SHFL0), Sse41::C6_MULLO), 10); - __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), Sse41::C6_SHFL0), Sse41::C6_MULLO), 10); - _ab = _mm256_add_epi32(_mm256_madd_epi16(_mm256_castsi128_si256(_a), _mm256_castsi128_si256(_b)), _ab); - } - return ExtractSum(_ab); - } - - template<> int32_t Correlation<7>(const uint8_t* a, const uint8_t* b, size_t size) - { - assert(size % 8 == 0); - __m256i _ab = _mm256_setzero_si256(); - size_t i = 0, size16 = AlignLo(size, 16); - for (; i < size16; i += 16, a += 14, b += 14) - { - __m256i _a = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)a)), C7_SHFL), C7_MULLO), 9); - __m256i _b = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)b)), C7_SHFL), C7_MULLO), 9); - _ab = _mm256_add_epi32(_mm256_madd_epi16(_a, _b), _ab); - } - for (; i < size; i += 8, a += 7, b += 7) - { - __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), Sse41::C7_SHFL0), Sse41::C7_MULLO), 9); - __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), Sse41::C7_SHFL0), Sse41::C7_MULLO), 9); - _ab = _mm256_add_epi32(_mm256_madd_epi16(_mm256_castsi128_si256(_a), _mm256_castsi128_si256(_b)), _ab); - } - return ExtractSum(_ab); - } - - template<> int32_t Correlation<8>(const uint8_t* a, const uint8_t* b, size_t size) - { - size_t i = 0, size16 = AlignLo(size, 16); - __m256i _ab = _mm256_setzero_si256(); - for (; i < size16; i += 16) - { - __m256i _a = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(a + i))); - __m256i _b = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(b + i))); - _ab = _mm256_add_epi32(_mm256_madd_epi16(_a, _b), _ab); - } - for (; i < size; i += 8) - { - __m256i _a = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(a + i))); - __m256i _b = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(b + i))); - _ab = _mm256_add_epi32(_mm256_madd_epi16(_a, _b), _ab); - } - return ExtractSum(_ab); - } - - template void CosineDistance(const uint8_t* a, const uint8_t* b, size_t size, float* distance) - { - float abSum = (float)Correlation(a + 16, b + 16, size); - Base::DecodeCosineDistance(a, b, abSum, distance); - } - - //------------------------------------------------------------------------------------------------- - - template void MicroCosineDistancesDirect2x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); - - template<> void MicroCosineDistancesDirect2x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size64 = AlignLo(size, 64), o = 16; - __m256i a0, a1, b0; - __m256i ab00 = _mm256_setzero_si256(); - __m256i ab01 = _mm256_setzero_si256(); - __m256i ab02 = _mm256_setzero_si256(); - __m256i ab03 = _mm256_setzero_si256(); - __m256i ab10 = _mm256_setzero_si256(); - __m256i ab11 = _mm256_setzero_si256(); - __m256i ab12 = _mm256_setzero_si256(); - __m256i ab13 = _mm256_setzero_si256(); - for (; i < size64; i += 64, o += 32) - { - a0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(A[0] + o)), K8_0F); - a1 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(A[1] + o)), K8_0F); - - b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[0] + o)), K8_0F); - ab00 = _mm256_add_epi32(ab00, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); - ab10 = _mm256_add_epi32(ab10, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); - - b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[1] + o)), K8_0F); - ab01 = _mm256_add_epi32(ab01, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); - ab11 = _mm256_add_epi32(ab11, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); - - b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[2] + o)), K8_0F); - ab02 = _mm256_add_epi32(ab02, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); - ab12 = _mm256_add_epi32(ab12, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); - - b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[3] + o)), K8_0F); - ab03 = _mm256_add_epi32(ab03, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); - ab13 = _mm256_add_epi32(ab13, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); - - a0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(A[0] + o)), 4), K8_0F); - a1 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(A[1] + o)), 4), K8_0F); - - b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[0] + o)), 4), K8_0F); - ab00 = _mm256_add_epi32(ab00, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); - ab10 = _mm256_add_epi32(ab10, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); - - b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[1] + o)), 4), K8_0F); - ab01 = _mm256_add_epi32(ab01, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); - ab11 = _mm256_add_epi32(ab11, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); - - b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[2] + o)), 4), K8_0F); - ab02 = _mm256_add_epi32(ab02, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); - ab12 = _mm256_add_epi32(ab12, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); - - b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[3] + o)), 4), K8_0F); - ab03 = _mm256_add_epi32(ab03, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); - ab13 = _mm256_add_epi32(ab13, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); - } - for (; i < size; i += 8, o += 4) - { - a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C4_SHFL), C4_MULLO), 12); - a1 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[1] + o))), C4_SHFL), C4_MULLO), 12); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C4_SHFL), C4_MULLO), 12); - ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); - ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C4_SHFL), C4_MULLO), 12); - ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); - ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C4_SHFL), C4_MULLO), 12); - ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); - ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C4_SHFL), C4_MULLO), 12); - ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); - ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); - } - __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); - DecodeCosineDistances2x4(A, B, ab, distances, stride); - } - - template<> void MicroCosineDistancesDirect2x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size16 = AlignLo(size, 16), o = 16; - __m256i a0, a1, b0; - __m256i ab00 = _mm256_setzero_si256(); - __m256i ab01 = _mm256_setzero_si256(); - __m256i ab02 = _mm256_setzero_si256(); - __m256i ab03 = _mm256_setzero_si256(); - __m256i ab10 = _mm256_setzero_si256(); - __m256i ab11 = _mm256_setzero_si256(); - __m256i ab12 = _mm256_setzero_si256(); - __m256i ab13 = _mm256_setzero_si256(); - for (; i < size16; i += 16, o += 10) - { - a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[0] + o))), C5_SHFL), C5_MULLO), 11); - a1 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[1] + o))), C5_SHFL), C5_MULLO), 11); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[0] + o))), C5_SHFL), C5_MULLO), 11); - ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); - ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[1] + o))), C5_SHFL), C5_MULLO), 11); - ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); - ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[2] + o))), C5_SHFL), C5_MULLO), 11); - ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); - ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[3] + o))), C5_SHFL), C5_MULLO), 11); - ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); - ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); - } - for (; i < size; i += 8, o += 5) - { - a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C5_SHFL), C5_MULLO), 11); - a1 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[1] + o))), C5_SHFL), C5_MULLO), 11); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C5_SHFL), C5_MULLO), 11); - ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); - ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C5_SHFL), C5_MULLO), 11); - ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); - ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C5_SHFL), C5_MULLO), 11); - ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); - ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C5_SHFL), C5_MULLO), 11); - ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); - ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); - } - __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); - DecodeCosineDistances2x4(A, B, ab, distances, stride); - } - - template<> void MicroCosineDistancesDirect2x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size16 = AlignLo(size, 16), o = 16; - __m256i a0, a1, b0; - __m256i ab00 = _mm256_setzero_si256(); - __m256i ab01 = _mm256_setzero_si256(); - __m256i ab02 = _mm256_setzero_si256(); - __m256i ab03 = _mm256_setzero_si256(); - __m256i ab10 = _mm256_setzero_si256(); - __m256i ab11 = _mm256_setzero_si256(); - __m256i ab12 = _mm256_setzero_si256(); - __m256i ab13 = _mm256_setzero_si256(); - for (; i < size16; i += 16, o += 12) - { - a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[0] + o))), C6_SHFL), C6_MULLO), 10); - a1 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[1] + o))), C6_SHFL), C6_MULLO), 10); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[0] + o))), C6_SHFL), C6_MULLO), 10); - ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); - ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[1] + o))), C6_SHFL), C6_MULLO), 10); - ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); - ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[2] + o))), C6_SHFL), C6_MULLO), 10); - ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); - ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[3] + o))), C6_SHFL), C6_MULLO), 10); - ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); - ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); - } - for (; i < size; i += 8, o += 6) - { - a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C6_SHFL), C6_MULLO), 10); - a1 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[1] + o))), C6_SHFL), C6_MULLO), 10); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C6_SHFL), C6_MULLO), 10); - ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); - ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C6_SHFL), C6_MULLO), 10); - ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); - ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C6_SHFL), C6_MULLO), 10); - ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); - ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C6_SHFL), C6_MULLO), 10); - ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); - ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); - } - __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); - DecodeCosineDistances2x4(A, B, ab, distances, stride); - } - - template<> void MicroCosineDistancesDirect2x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size16 = AlignLo(size, 16), o = 16; - __m256i a0, a1, b0; - __m256i ab00 = _mm256_setzero_si256(); - __m256i ab01 = _mm256_setzero_si256(); - __m256i ab02 = _mm256_setzero_si256(); - __m256i ab03 = _mm256_setzero_si256(); - __m256i ab10 = _mm256_setzero_si256(); - __m256i ab11 = _mm256_setzero_si256(); - __m256i ab12 = _mm256_setzero_si256(); - __m256i ab13 = _mm256_setzero_si256(); - for (; i < size16; i += 16, o += 14) - { - a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[0] + o))), C7_SHFL), C7_MULLO), 9); - a1 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[1] + o))), C7_SHFL), C7_MULLO), 9); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[0] + o))), C7_SHFL), C7_MULLO), 9); - ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); - ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[1] + o))), C7_SHFL), C7_MULLO), 9); - ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); - ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[2] + o))), C7_SHFL), C7_MULLO), 9); - ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); - ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[3] + o))), C7_SHFL), C7_MULLO), 9); - ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); - ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); - } - for (; i < size; i += 8, o += 7) - { - a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C7_SHFL), C7_MULLO), 9); - a1 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[1] + o))), C7_SHFL), C7_MULLO), 9); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C7_SHFL), C7_MULLO), 9); - ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); - ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C7_SHFL), C7_MULLO), 9); - ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); - ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C7_SHFL), C7_MULLO), 9); - ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); - ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C7_SHFL), C7_MULLO), 9); - ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); - ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); - } - __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); - DecodeCosineDistances2x4(A, B, ab, distances, stride); - } - - template<> void MicroCosineDistancesDirect2x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size16 = AlignLo(size, 16), o = 16; - __m256i a0, a1, b0; - __m256i ab00 = _mm256_setzero_si256(); - __m256i ab01 = _mm256_setzero_si256(); - __m256i ab02 = _mm256_setzero_si256(); - __m256i ab03 = _mm256_setzero_si256(); - __m256i ab10 = _mm256_setzero_si256(); - __m256i ab11 = _mm256_setzero_si256(); - __m256i ab12 = _mm256_setzero_si256(); - __m256i ab13 = _mm256_setzero_si256(); - for (; i < size16; i += 16, o += 16) - { - a0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(A[0] + o))); - a1 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(A[1] + o))); - - b0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(B[0] + o))); - ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); - ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); - - b0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(B[1] + o))); - ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); - ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); - - b0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(B[2] + o))); - ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); - ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); - - b0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(B[3] + o))); - ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); - ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); - } - for (; i < size; i += 8, o += 8) - { - a0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(A[0] + o))); - a1 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(A[1] + o))); - - b0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[0] + o))); - ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); - ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); - - b0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[1] + o))); - ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); - ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); - - b0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[2] + o))); - ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); - ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); - - b0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[3] + o))); - ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); - ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); - } - __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); - DecodeCosineDistances2x4(A, B, ab, distances, stride); - } - - template void MicroCosineDistancesDirect1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); - - template<> void MicroCosineDistancesDirect1x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size64 = AlignLo(size, 64), o = 16; - __m256i a0, b0; - __m256i ab00 = _mm256_setzero_si256(); - __m256i ab01 = _mm256_setzero_si256(); - __m256i ab02 = _mm256_setzero_si256(); - __m256i ab03 = _mm256_setzero_si256(); - for (; i < size64; i += 64, o += 32) - { - a0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(A[0] + o)), K8_0F); - - b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[0] + o)), K8_0F); - ab00 = _mm256_add_epi32(ab00, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); - - b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[1] + o)), K8_0F); - ab01 = _mm256_add_epi32(ab01, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); - - b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[2] + o)), K8_0F); - ab02 = _mm256_add_epi32(ab02, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); - - b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[3] + o)), K8_0F); - ab03 = _mm256_add_epi32(ab03, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); - - a0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(A[0] + o)), 4), K8_0F); - - b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[0] + o)), 4), K8_0F); - ab00 = _mm256_add_epi32(ab00, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); - - b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[1] + o)), 4), K8_0F); - ab01 = _mm256_add_epi32(ab01, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); - - b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[2] + o)), 4), K8_0F); - ab02 = _mm256_add_epi32(ab02, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); - - b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[3] + o)), 4), K8_0F); - ab03 = _mm256_add_epi32(ab03, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); - } - for (; i < size; i += 8, o += 4) - { - a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C4_SHFL), C4_MULLO), 12); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C4_SHFL), C4_MULLO), 12); - ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C4_SHFL), C4_MULLO), 12); - ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C4_SHFL), C4_MULLO), 12); - ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C4_SHFL), C4_MULLO), 12); - ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - } - - template<> void MicroCosineDistancesDirect1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size16 = AlignLo(size, 16), o = 16; - __m256i a0, b0; - __m256i ab00 = _mm256_setzero_si256(); - __m256i ab01 = _mm256_setzero_si256(); - __m256i ab02 = _mm256_setzero_si256(); - __m256i ab03 = _mm256_setzero_si256(); - for (; i < size16; i += 16, o += 10) - { - a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[0] + o))), C5_SHFL), C5_MULLO), 11); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[0] + o))), C5_SHFL), C5_MULLO), 11); - ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[1] + o))), C5_SHFL), C5_MULLO), 11); - ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[2] + o))), C5_SHFL), C5_MULLO), 11); - ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[3] + o))), C5_SHFL), C5_MULLO), 11); - ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); - } - for (; i < size; i += 8, o += 5) - { - a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C5_SHFL), C5_MULLO), 11); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C5_SHFL), C5_MULLO), 11); - ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C5_SHFL), C5_MULLO), 11); - ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C5_SHFL), C5_MULLO), 11); - ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C5_SHFL), C5_MULLO), 11); - ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - } - - template<> void MicroCosineDistancesDirect1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size16 = AlignLo(size, 16), o = 16; - __m256i a0, b0; - __m256i ab00 = _mm256_setzero_si256(); - __m256i ab01 = _mm256_setzero_si256(); - __m256i ab02 = _mm256_setzero_si256(); - __m256i ab03 = _mm256_setzero_si256(); - for (; i < size16; i += 16, o += 12) - { - a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[0] + o))), C6_SHFL), C6_MULLO), 10); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[0] + o))), C6_SHFL), C6_MULLO), 10); - ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[1] + o))), C6_SHFL), C6_MULLO), 10); - ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[2] + o))), C6_SHFL), C6_MULLO), 10); - ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[3] + o))), C6_SHFL), C6_MULLO), 10); - ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); - } - for (; i < size; i += 8, o += 6) - { - a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C6_SHFL), C6_MULLO), 10); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C6_SHFL), C6_MULLO), 10); - ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C6_SHFL), C6_MULLO), 10); - ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C6_SHFL), C6_MULLO), 10); - ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C6_SHFL), C6_MULLO), 10); - ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - } - - template<> void MicroCosineDistancesDirect1x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size16 = AlignLo(size, 16), o = 16; - __m256i a0, b0; - __m256i ab00 = _mm256_setzero_si256(); - __m256i ab01 = _mm256_setzero_si256(); - __m256i ab02 = _mm256_setzero_si256(); - __m256i ab03 = _mm256_setzero_si256(); - for (; i < size16; i += 16, o += 14) - { - a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[0] + o))), C7_SHFL), C7_MULLO), 9); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[0] + o))), C7_SHFL), C7_MULLO), 9); - ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[1] + o))), C7_SHFL), C7_MULLO), 9); - ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[2] + o))), C7_SHFL), C7_MULLO), 9); - ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[3] + o))), C7_SHFL), C7_MULLO), 9); - ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); - } - for (; i < size; i += 8, o += 7) - { - a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C7_SHFL), C7_MULLO), 9); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C7_SHFL), C7_MULLO), 9); - ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C7_SHFL), C7_MULLO), 9); - ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C7_SHFL), C7_MULLO), 9); - ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); - - b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C7_SHFL), C7_MULLO), 9); - ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - } - - template<> void MicroCosineDistancesDirect1x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size16 = AlignLo(size, 16), o = 16; - __m256i a0, b0; - __m256i ab00 = _mm256_setzero_si256(); - __m256i ab01 = _mm256_setzero_si256(); - __m256i ab02 = _mm256_setzero_si256(); - __m256i ab03 = _mm256_setzero_si256(); - for (; i < size16; i += 16, o += 16) - { - a0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(A[0] + o))); - - b0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(B[0] + o))); - ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); - - b0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(B[1] + o))); - ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); - - b0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(B[2] + o))); - ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); - - b0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(B[3] + o))); - ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); - } - for (; i < size; i += 8, o += 8) - { - a0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(A[0] + o))); - - b0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[0] + o))); - ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); - - b0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[1] + o))); - ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); - - b0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[2] + o))); - ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); - - b0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[3] + o))); - ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - } - - template void MacroCosineDistancesDirect(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t M2 = AlignLoAny(M, 2); - size_t N4 = AlignLoAny(N, 4); - size_t i = 0; - for (; i < M2; i += 2) - { - size_t j = 0; - for (; j < N4; j += 4) - MicroCosineDistancesDirect2x4(A + i, B + j, size, distances + j, stride); - for (; j < N; j += 1) - { - CosineDistance(A[i + 0], B[j], size, distances + j + 0 * stride); - CosineDistance(A[i + 1], B[j], size, distances + j + 1 * stride); - } - distances += 2 * stride; - } - for (; i < M; i++) - { - size_t j = 0; - for (; j < N4; j += 4) - MicroCosineDistancesDirect1x4(A + i, B + j, size, distances + j, stride); - for (; j < N; j += 1) - CosineDistance(A[i], B[j], size, distances + j); - distances += 1 * stride; + dst[0 * stride] = ((float*)src)[0]; + dst[1 * stride] = ((float*)src)[1]; + dst[2 * stride] = ((float*)src)[2]; + dst[3 * stride] = ((float*)src)[3]; } } @@ -1373,60 +132,24 @@ namespace Simd { _minMax32f = MinMax32f; _minMax16f = MinMax16f; - switch (depth) - { - case 4: - { - _encode32f = Encode32f4; - _encode16f = Encode16f4; - _decode32f = Decode32f4; - _decode16f = Decode16f4; - _cosineDistance = Avx2::CosineDistance<4>; - _macroCosineDistancesDirect = Avx2::MacroCosineDistancesDirect<4>; - break; - } - case 5: - { - _encode32f = Encode32f5; - _encode16f = Encode16f5; - _decode32f = Decode32f5; - _decode16f = Decode16f5; - _cosineDistance = Avx2::CosineDistance<5>; - _macroCosineDistancesDirect = Avx2::MacroCosineDistancesDirect<5>; - break; - } - case 6: - { - _encode32f = Encode32f6; - _encode16f = Encode16f6; - _decode32f = Decode32f6; - _decode16f = Decode16f6; - _cosineDistance = Avx2::CosineDistance<6>; - _macroCosineDistancesDirect = Avx2::MacroCosineDistancesDirect<6>; - break; - } - case 7: - { - _encode32f = Encode32f7; - _encode16f = Encode16f7; - _decode32f = Decode32f7; - _decode16f = Decode16f7; - _cosineDistance = Avx2::CosineDistance<7>; - _macroCosineDistancesDirect = Avx2::MacroCosineDistancesDirect<7>; - break; - } - case 8: + _encode32f = GetEncode32f(_depth); + _encode16f = GetEncode16f(_depth); + + _decode32f = GetDecode32f(_depth); + _decode16f = GetDecode16f(_depth); + + _cosineDistance = GetCosineDistance(_depth); + _macroCosineDistancesDirect = GetMacroCosineDistancesDirect(_depth); + + _unpackNormA = UnpackNormA; + _unpackNormB = UnpackNormB; + if (_depth != 8) { - _encode32f = Encode32f8; - _encode16f = Encode16f8; - _decode32f = Decode32f8; - _decode16f = Decode16f8; - _cosineDistance = Avx2::CosineDistance<8>; - _macroCosineDistancesDirect = Avx2::MacroCosineDistancesDirect<8>; - break; - } - default: - assert(0); + _unpackDataA = GetUnpackData(_depth, false); + _unpackDataB = GetUnpackData(_depth, true); + _macroCosineDistancesUnpack = GetMacroCosineDistancesUnpack(_depth); + _microMu = 5; + _microNu = 16; } } diff --git a/src/Simd/SimdAvx2DescrIntCdd.cpp b/src/Simd/SimdAvx2DescrIntCdd.cpp new file mode 100644 index 0000000000..ccccb5502e --- /dev/null +++ b/src/Simd/SimdAvx2DescrIntCdd.cpp @@ -0,0 +1,753 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2023 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#include "Simd/SimdMemory.h" +#include "Simd/SimdStore.h" +#include "Simd/SimdExtract.h" +#include "Simd/SimdArray.h" +#include "Simd/SimdUnpack.h" +#include "Simd/SimdDescrInt.h" +#include "Simd/SimdDescrIntCommon.h" +#include "Simd/SimdCpu.h" + +namespace Simd +{ +#ifdef SIMD_AVX2_ENABLE + namespace Avx2 + { + template int32_t Correlation(const uint8_t* a, const uint8_t* b, size_t size); + + template<> int32_t Correlation<4>(const uint8_t* a, const uint8_t* b, size_t size) + { + assert(size % 8 == 0); + __m256i ab32 = _mm256_setzero_si256(); + size_t i = 0, size64 = AlignLo(size, 64); + for (; i < size64; i += 64, a += 32, b += 32) + { + __m256i _a = _mm256_loadu_si256((__m256i*)a); + __m256i _b = _mm256_loadu_si256((__m256i*)b); + __m256i ab16 = _mm256_maddubs_epi16(_mm256_and_si256(_a, K8_0F), _mm256_and_si256(_b, K8_0F)); + ab16 = _mm256_add_epi16(ab16, _mm256_maddubs_epi16(_mm256_and_si256(_mm256_srli_epi16(_a, 4), K8_0F), _mm256_and_si256(_mm256_srli_epi16(_b, 4), K8_0F))); + ab32 = _mm256_add_epi32(ab32, _mm256_madd_epi16(ab16, K16_0001)); + } + for (; i < size; i += 8, a += 4, b += 4) + { + __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), Sse41::C4_SHFL0), Sse41::C4_MULLO), 12); + __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), Sse41::C4_SHFL0), Sse41::C4_MULLO), 12); + ab32 = _mm256_add_epi32(_mm256_madd_epi16(_mm256_castsi128_si256(_a), _mm256_castsi128_si256(_b)), ab32); + } + return ExtractSum(ab32); + } + + template<> int32_t Correlation<5>(const uint8_t* a, const uint8_t* b, size_t size) + { + assert(size % 8 == 0); + __m256i _ab = _mm256_setzero_si256(); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16, a += 10, b += 10) + { + __m256i _a = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)a)), C5_SHFL), C5_MULLO), 11); + __m256i _b = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)b)), C5_SHFL), C5_MULLO), 11); + _ab = _mm256_add_epi32(_mm256_madd_epi16(_a, _b), _ab); + } + for (; i < size; i += 8, a += 5, b += 5) + { + __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), Sse41::C5_SHFL0), Sse41::C5_MULLO), 11); + __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), Sse41::C5_SHFL0), Sse41::C5_MULLO), 11); + _ab = _mm256_add_epi32(_mm256_madd_epi16(_mm256_castsi128_si256(_a), _mm256_castsi128_si256(_b)), _ab); + } + return ExtractSum(_ab); + } + + template<> int32_t Correlation<6>(const uint8_t* a, const uint8_t* b, size_t size) + { + assert(size % 8 == 0); + __m256i _ab = _mm256_setzero_si256(); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16, a += 12, b += 12) + { + __m256i _a = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)a)), C6_SHFL), C6_MULLO), 10); + __m256i _b = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)b)), C6_SHFL), C6_MULLO), 10); + _ab = _mm256_add_epi32(_mm256_madd_epi16(_a, _b), _ab); + } + for (; i < size; i += 8, a += 6, b += 6) + { + __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), Sse41::C6_SHFL0), Sse41::C6_MULLO), 10); + __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), Sse41::C6_SHFL0), Sse41::C6_MULLO), 10); + _ab = _mm256_add_epi32(_mm256_madd_epi16(_mm256_castsi128_si256(_a), _mm256_castsi128_si256(_b)), _ab); + } + return ExtractSum(_ab); + } + + template<> int32_t Correlation<7>(const uint8_t* a, const uint8_t* b, size_t size) + { + assert(size % 8 == 0); + __m256i _ab = _mm256_setzero_si256(); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16, a += 14, b += 14) + { + __m256i _a = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)a)), C7_SHFL), C7_MULLO), 9); + __m256i _b = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)b)), C7_SHFL), C7_MULLO), 9); + _ab = _mm256_add_epi32(_mm256_madd_epi16(_a, _b), _ab); + } + for (; i < size; i += 8, a += 7, b += 7) + { + __m128i _a = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)a), Sse41::C7_SHFL0), Sse41::C7_MULLO), 9); + __m128i _b = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_mm_loadl_epi64((__m128i*)b), Sse41::C7_SHFL0), Sse41::C7_MULLO), 9); + _ab = _mm256_add_epi32(_mm256_madd_epi16(_mm256_castsi128_si256(_a), _mm256_castsi128_si256(_b)), _ab); + } + return ExtractSum(_ab); + } + + template<> int32_t Correlation<8>(const uint8_t* a, const uint8_t* b, size_t size) + { + size_t i = 0, size16 = AlignLo(size, 16); + __m256i _ab = _mm256_setzero_si256(); + for (; i < size16; i += 16) + { + __m256i _a = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(a + i))); + __m256i _b = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(b + i))); + _ab = _mm256_add_epi32(_mm256_madd_epi16(_a, _b), _ab); + } + for (; i < size; i += 8) + { + __m256i _a = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(a + i))); + __m256i _b = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(b + i))); + _ab = _mm256_add_epi32(_mm256_madd_epi16(_a, _b), _ab); + } + return ExtractSum(_ab); + } + + template void CosineDistance(const uint8_t* a, const uint8_t* b, size_t size, float* distance) + { + float abSum = (float)Correlation(a + 16, b + 16, size); + Base::DecodeCosineDistance(a, b, abSum, distance); + } + + //------------------------------------------------------------------------------------------------- + + template void MicroCosineDistancesDirect2x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + + template<> void MicroCosineDistancesDirect2x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size64 = AlignLo(size, 64), o = 16; + __m256i a0, a1, b0; + __m256i ab00 = _mm256_setzero_si256(); + __m256i ab01 = _mm256_setzero_si256(); + __m256i ab02 = _mm256_setzero_si256(); + __m256i ab03 = _mm256_setzero_si256(); + __m256i ab10 = _mm256_setzero_si256(); + __m256i ab11 = _mm256_setzero_si256(); + __m256i ab12 = _mm256_setzero_si256(); + __m256i ab13 = _mm256_setzero_si256(); + for (; i < size64; i += 64, o += 32) + { + a0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(A[0] + o)), K8_0F); + a1 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(A[1] + o)), K8_0F); + + b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[0] + o)), K8_0F); + ab00 = _mm256_add_epi32(ab00, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + ab10 = _mm256_add_epi32(ab10, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[1] + o)), K8_0F); + ab01 = _mm256_add_epi32(ab01, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + ab11 = _mm256_add_epi32(ab11, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[2] + o)), K8_0F); + ab02 = _mm256_add_epi32(ab02, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + ab12 = _mm256_add_epi32(ab12, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[3] + o)), K8_0F); + ab03 = _mm256_add_epi32(ab03, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + ab13 = _mm256_add_epi32(ab13, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); + + a0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(A[0] + o)), 4), K8_0F); + a1 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(A[1] + o)), 4), K8_0F); + + b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[0] + o)), 4), K8_0F); + ab00 = _mm256_add_epi32(ab00, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + ab10 = _mm256_add_epi32(ab10, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[1] + o)), 4), K8_0F); + ab01 = _mm256_add_epi32(ab01, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + ab11 = _mm256_add_epi32(ab11, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[2] + o)), 4), K8_0F); + ab02 = _mm256_add_epi32(ab02, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + ab12 = _mm256_add_epi32(ab12, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[3] + o)), 4), K8_0F); + ab03 = _mm256_add_epi32(ab03, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + ab13 = _mm256_add_epi32(ab13, _mm256_madd_epi16(_mm256_maddubs_epi16(a1, b0), K16_0001)); + } + for (; i < size; i += 8, o += 4) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C4_SHFL), C4_MULLO), 12); + a1 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[1] + o))), C4_SHFL), C4_MULLO), 12); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C4_SHFL), C4_MULLO), 12); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C4_SHFL), C4_MULLO), 12); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C4_SHFL), C4_MULLO), 12); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C4_SHFL), C4_MULLO), 12); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); + } + __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); + DecodeCosineDistances2x4(A, B, ab, distances, stride); + } + + template<> void MicroCosineDistancesDirect2x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size16 = AlignLo(size, 16), o = 16; + __m256i a0, a1, b0; + __m256i ab00 = _mm256_setzero_si256(); + __m256i ab01 = _mm256_setzero_si256(); + __m256i ab02 = _mm256_setzero_si256(); + __m256i ab03 = _mm256_setzero_si256(); + __m256i ab10 = _mm256_setzero_si256(); + __m256i ab11 = _mm256_setzero_si256(); + __m256i ab12 = _mm256_setzero_si256(); + __m256i ab13 = _mm256_setzero_si256(); + for (; i < size16; i += 16, o += 10) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[0] + o))), C5_SHFL), C5_MULLO), 11); + a1 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[1] + o))), C5_SHFL), C5_MULLO), 11); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[0] + o))), C5_SHFL), C5_MULLO), 11); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[1] + o))), C5_SHFL), C5_MULLO), 11); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[2] + o))), C5_SHFL), C5_MULLO), 11); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[3] + o))), C5_SHFL), C5_MULLO), 11); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); + } + for (; i < size; i += 8, o += 5) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C5_SHFL), C5_MULLO), 11); + a1 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[1] + o))), C5_SHFL), C5_MULLO), 11); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C5_SHFL), C5_MULLO), 11); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C5_SHFL), C5_MULLO), 11); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C5_SHFL), C5_MULLO), 11); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C5_SHFL), C5_MULLO), 11); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); + } + __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); + DecodeCosineDistances2x4(A, B, ab, distances, stride); + } + + template<> void MicroCosineDistancesDirect2x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size16 = AlignLo(size, 16), o = 16; + __m256i a0, a1, b0; + __m256i ab00 = _mm256_setzero_si256(); + __m256i ab01 = _mm256_setzero_si256(); + __m256i ab02 = _mm256_setzero_si256(); + __m256i ab03 = _mm256_setzero_si256(); + __m256i ab10 = _mm256_setzero_si256(); + __m256i ab11 = _mm256_setzero_si256(); + __m256i ab12 = _mm256_setzero_si256(); + __m256i ab13 = _mm256_setzero_si256(); + for (; i < size16; i += 16, o += 12) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[0] + o))), C6_SHFL), C6_MULLO), 10); + a1 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[1] + o))), C6_SHFL), C6_MULLO), 10); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[0] + o))), C6_SHFL), C6_MULLO), 10); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[1] + o))), C6_SHFL), C6_MULLO), 10); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[2] + o))), C6_SHFL), C6_MULLO), 10); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[3] + o))), C6_SHFL), C6_MULLO), 10); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); + } + for (; i < size; i += 8, o += 6) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C6_SHFL), C6_MULLO), 10); + a1 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[1] + o))), C6_SHFL), C6_MULLO), 10); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C6_SHFL), C6_MULLO), 10); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C6_SHFL), C6_MULLO), 10); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C6_SHFL), C6_MULLO), 10); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C6_SHFL), C6_MULLO), 10); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); + } + __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); + DecodeCosineDistances2x4(A, B, ab, distances, stride); + } + + template<> void MicroCosineDistancesDirect2x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size16 = AlignLo(size, 16), o = 16; + __m256i a0, a1, b0; + __m256i ab00 = _mm256_setzero_si256(); + __m256i ab01 = _mm256_setzero_si256(); + __m256i ab02 = _mm256_setzero_si256(); + __m256i ab03 = _mm256_setzero_si256(); + __m256i ab10 = _mm256_setzero_si256(); + __m256i ab11 = _mm256_setzero_si256(); + __m256i ab12 = _mm256_setzero_si256(); + __m256i ab13 = _mm256_setzero_si256(); + for (; i < size16; i += 16, o += 14) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[0] + o))), C7_SHFL), C7_MULLO), 9); + a1 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[1] + o))), C7_SHFL), C7_MULLO), 9); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[0] + o))), C7_SHFL), C7_MULLO), 9); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[1] + o))), C7_SHFL), C7_MULLO), 9); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[2] + o))), C7_SHFL), C7_MULLO), 9); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[3] + o))), C7_SHFL), C7_MULLO), 9); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); + } + for (; i < size; i += 8, o += 7) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C7_SHFL), C7_MULLO), 9); + a1 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[1] + o))), C7_SHFL), C7_MULLO), 9); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C7_SHFL), C7_MULLO), 9); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C7_SHFL), C7_MULLO), 9); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C7_SHFL), C7_MULLO), 9); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C7_SHFL), C7_MULLO), 9); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); + } + __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); + DecodeCosineDistances2x4(A, B, ab, distances, stride); + } + + template<> void MicroCosineDistancesDirect2x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size16 = AlignLo(size, 16), o = 16; + __m256i a0, a1, b0; + __m256i ab00 = _mm256_setzero_si256(); + __m256i ab01 = _mm256_setzero_si256(); + __m256i ab02 = _mm256_setzero_si256(); + __m256i ab03 = _mm256_setzero_si256(); + __m256i ab10 = _mm256_setzero_si256(); + __m256i ab11 = _mm256_setzero_si256(); + __m256i ab12 = _mm256_setzero_si256(); + __m256i ab13 = _mm256_setzero_si256(); + for (; i < size16; i += 16, o += 16) + { + a0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(A[0] + o))); + a1 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(A[1] + o))); + + b0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(B[0] + o))); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); + + b0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(B[1] + o))); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); + + b0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(B[2] + o))); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); + + b0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(B[3] + o))); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); + } + for (; i < size; i += 8, o += 8) + { + a0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(A[0] + o))); + a1 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(A[1] + o))); + + b0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[0] + o))); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + ab10 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab10); + + b0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[1] + o))); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + ab11 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab11); + + b0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[2] + o))); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + ab12 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab12); + + b0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[3] + o))); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + ab13 = _mm256_add_epi32(_mm256_madd_epi16(a1, b0), ab13); + } + __m256 ab = _mm256_cvtepi32_ps(Extract8Sums(ab00, ab01, ab02, ab03, ab10, ab11, ab12, ab13)); + DecodeCosineDistances2x4(A, B, ab, distances, stride); + } + + template void MicroCosineDistancesDirect1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + + template<> void MicroCosineDistancesDirect1x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size64 = AlignLo(size, 64), o = 16; + __m256i a0, b0; + __m256i ab00 = _mm256_setzero_si256(); + __m256i ab01 = _mm256_setzero_si256(); + __m256i ab02 = _mm256_setzero_si256(); + __m256i ab03 = _mm256_setzero_si256(); + for (; i < size64; i += 64, o += 32) + { + a0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(A[0] + o)), K8_0F); + + b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[0] + o)), K8_0F); + ab00 = _mm256_add_epi32(ab00, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[1] + o)), K8_0F); + ab01 = _mm256_add_epi32(ab01, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[2] + o)), K8_0F); + ab02 = _mm256_add_epi32(ab02, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_loadu_si256((__m256i*)(B[3] + o)), K8_0F); + ab03 = _mm256_add_epi32(ab03, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + + a0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(A[0] + o)), 4), K8_0F); + + b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[0] + o)), 4), K8_0F); + ab00 = _mm256_add_epi32(ab00, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[1] + o)), 4), K8_0F); + ab01 = _mm256_add_epi32(ab01, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[2] + o)), 4), K8_0F); + ab02 = _mm256_add_epi32(ab02, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + + b0 = _mm256_and_si256(_mm256_srli_epi16(_mm256_loadu_si256((__m256i*)(B[3] + o)), 4), K8_0F); + ab03 = _mm256_add_epi32(ab03, _mm256_madd_epi16(_mm256_maddubs_epi16(a0, b0), K16_0001)); + } + for (; i < size; i += 8, o += 4) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C4_SHFL), C4_MULLO), 12); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C4_SHFL), C4_MULLO), 12); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C4_SHFL), C4_MULLO), 12); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C4_SHFL), C4_MULLO), 12); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C4_SHFL), C4_MULLO), 12); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + } + + template<> void MicroCosineDistancesDirect1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size16 = AlignLo(size, 16), o = 16; + __m256i a0, b0; + __m256i ab00 = _mm256_setzero_si256(); + __m256i ab01 = _mm256_setzero_si256(); + __m256i ab02 = _mm256_setzero_si256(); + __m256i ab03 = _mm256_setzero_si256(); + for (; i < size16; i += 16, o += 10) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[0] + o))), C5_SHFL), C5_MULLO), 11); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[0] + o))), C5_SHFL), C5_MULLO), 11); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[1] + o))), C5_SHFL), C5_MULLO), 11); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[2] + o))), C5_SHFL), C5_MULLO), 11); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[3] + o))), C5_SHFL), C5_MULLO), 11); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + } + for (; i < size; i += 8, o += 5) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C5_SHFL), C5_MULLO), 11); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C5_SHFL), C5_MULLO), 11); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C5_SHFL), C5_MULLO), 11); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C5_SHFL), C5_MULLO), 11); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C5_SHFL), C5_MULLO), 11); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + } + + template<> void MicroCosineDistancesDirect1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size16 = AlignLo(size, 16), o = 16; + __m256i a0, b0; + __m256i ab00 = _mm256_setzero_si256(); + __m256i ab01 = _mm256_setzero_si256(); + __m256i ab02 = _mm256_setzero_si256(); + __m256i ab03 = _mm256_setzero_si256(); + for (; i < size16; i += 16, o += 12) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[0] + o))), C6_SHFL), C6_MULLO), 10); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[0] + o))), C6_SHFL), C6_MULLO), 10); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[1] + o))), C6_SHFL), C6_MULLO), 10); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[2] + o))), C6_SHFL), C6_MULLO), 10); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[3] + o))), C6_SHFL), C6_MULLO), 10); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + } + for (; i < size; i += 8, o += 6) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C6_SHFL), C6_MULLO), 10); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C6_SHFL), C6_MULLO), 10); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C6_SHFL), C6_MULLO), 10); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C6_SHFL), C6_MULLO), 10); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C6_SHFL), C6_MULLO), 10); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + } + + template<> void MicroCosineDistancesDirect1x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size16 = AlignLo(size, 16), o = 16; + __m256i a0, b0; + __m256i ab00 = _mm256_setzero_si256(); + __m256i ab01 = _mm256_setzero_si256(); + __m256i ab02 = _mm256_setzero_si256(); + __m256i ab03 = _mm256_setzero_si256(); + for (; i < size16; i += 16, o += 14) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(A[0] + o))), C7_SHFL), C7_MULLO), 9); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[0] + o))), C7_SHFL), C7_MULLO), 9); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[1] + o))), C7_SHFL), C7_MULLO), 9); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[2] + o))), C7_SHFL), C7_MULLO), 9); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(B[3] + o))), C7_SHFL), C7_MULLO), 9); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + } + for (; i < size; i += 8, o += 7) + { + a0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(A[0] + o))), C7_SHFL), C7_MULLO), 9); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[0] + o))), C7_SHFL), C7_MULLO), 9); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[1] + o))), C7_SHFL), C7_MULLO), 9); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[2] + o))), C7_SHFL), C7_MULLO), 9); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + + b0 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_castsi128_si256(_mm_loadl_epi64((__m128i*)(B[3] + o))), C7_SHFL), C7_MULLO), 9); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + } + + template<> void MicroCosineDistancesDirect1x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size16 = AlignLo(size, 16), o = 16; + __m256i a0, b0; + __m256i ab00 = _mm256_setzero_si256(); + __m256i ab01 = _mm256_setzero_si256(); + __m256i ab02 = _mm256_setzero_si256(); + __m256i ab03 = _mm256_setzero_si256(); + for (; i < size16; i += 16, o += 16) + { + a0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(A[0] + o))); + + b0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(B[0] + o))); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + + b0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(B[1] + o))); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + + b0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(B[2] + o))); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + + b0 = _mm256_cvtepu8_epi16(_mm_loadu_si128((__m128i*)(B[3] + o))); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + } + for (; i < size; i += 8, o += 8) + { + a0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(A[0] + o))); + + b0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[0] + o))); + ab00 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab00); + + b0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[1] + o))); + ab01 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab01); + + b0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[2] + o))); + ab02 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab02); + + b0 = _mm256_cvtepu8_epi16(_mm_loadl_epi64((__m128i*)(B[3] + o))); + ab03 = _mm256_add_epi32(_mm256_madd_epi16(a0, b0), ab03); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + } + + template void MacroCosineDistancesDirect(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t M2 = AlignLoAny(M, 2); + size_t N4 = AlignLoAny(N, 4); + size_t i = 0; + for (; i < M2; i += 2) + { + size_t j = 0; + for (; j < N4; j += 4) + MicroCosineDistancesDirect2x4(A + i, B + j, size, distances + j, stride); + for (; j < N; j += 1) + { + CosineDistance(A[i + 0], B[j], size, distances + j + 0 * stride); + CosineDistance(A[i + 1], B[j], size, distances + j + 1 * stride); + } + distances += 2 * stride; + } + for (; i < M; i++) + { + size_t j = 0; + for (; j < N4; j += 4) + MicroCosineDistancesDirect1x4(A + i, B + j, size, distances + j, stride); + for (; j < N; j += 1) + CosineDistance(A[i], B[j], size, distances + j); + distances += 1 * stride; + } + } + + //------------------------------------------------------------------------------------------------- + + Base::DescrInt::CosineDistancePtr GetCosineDistance(size_t depth) + { + switch (depth) + { + case 4: return CosineDistance<4>; + case 5: return CosineDistance<5>; + case 6: return CosineDistance<6>; + case 7: return CosineDistance<7>; + case 8: return CosineDistance<8>; + default: assert(0); return NULL; + } + } + + Sse41::DescrInt::MacroCosineDistancesDirectPtr GetMacroCosineDistancesDirect(size_t depth) + { + switch (depth) + { + case 4: return MacroCosineDistancesDirect<4>; + case 5: return MacroCosineDistancesDirect<5>; + case 6: return MacroCosineDistancesDirect<6>; + case 7: return MacroCosineDistancesDirect<7>; + case 8: return MacroCosineDistancesDirect<8>; + default: assert(0); return NULL; + } + } + } +#endif +} diff --git a/src/Simd/SimdAvx2DescrIntCdu.cpp b/src/Simd/SimdAvx2DescrIntCdu.cpp new file mode 100644 index 0000000000..950ca9f75d --- /dev/null +++ b/src/Simd/SimdAvx2DescrIntCdu.cpp @@ -0,0 +1,359 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2023 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#include "Simd/SimdMemory.h" +#include "Simd/SimdStore.h" +#include "Simd/SimdExtract.h" +#include "Simd/SimdArray.h" +#include "Simd/SimdUnpack.h" +#include "Simd/SimdDescrInt.h" +#include "Simd/SimdDescrIntCommon.h" +#include "Simd/SimdCpu.h" +#include "Simd/SimdSynet.h" + +namespace Simd +{ +#ifdef SIMD_AVX2_ENABLE + namespace Avx2 + { + template __m128i UnpackData16(const uint8_t* src); + + template<> SIMD_INLINE __m128i UnpackData16<4>(const uint8_t* src) + { + __m256i s4 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s4, C4_SHFL), C4_MULLO), 12); + return _mm256_castsi256_si128(PackI16ToU8(s16, K_ZERO)); + } + + template<> SIMD_INLINE __m128i UnpackData16<5>(const uint8_t* src) + { + __m256i s5 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s5, C5_SHFL), C5_MULLO), 11); + return _mm256_castsi256_si128(PackI16ToU8(s16, K_ZERO)); + } + + template<> SIMD_INLINE __m128i UnpackData16<6>(const uint8_t* src) + { + __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, C6_SHFL), C6_MULLO), 10); + return _mm256_castsi256_si128(PackI16ToU8(s16, K_ZERO)); + } + + template<> SIMD_INLINE __m128i UnpackData16<7>(const uint8_t* src) + { + __m256i s7 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s7, C7_SHFL), C7_MULLO), 9); + return _mm256_castsi256_si128(PackI16ToU8(s16, K_ZERO)); + } + + //------------------------------------------------------------------------------------------------- + + template __m256i UnpackData32(const uint8_t* src); + + template<> SIMD_INLINE __m256i UnpackData32<4>(const uint8_t* src) + { + __m256i lo = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(src + 0))), C4_SHFL), C4_MULLO), 12); + __m256i hi = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(src + 8))), C4_SHFL), C4_MULLO), 12); + return PackI16ToU8(lo, hi); + } + + template<> SIMD_INLINE __m256i UnpackData32<5>(const uint8_t* src) + { + __m256i lo = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(src + 0))), C5_SHFL), C5_MULLO), 11); + __m256i hi = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(src + 10))), C5_SHFL), C5_MULLO), 11); + return PackI16ToU8(lo, hi); + } + + template<> SIMD_INLINE __m256i UnpackData32<6>(const uint8_t* src) + { + __m256i lo = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(src + 0))), C6_SHFL), C6_MULLO), 10); + __m256i hi = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(src + 12))), C6_SHFL), C6_MULLO), 10); + return PackI16ToU8(lo, hi); + } + + template<> SIMD_INLINE __m256i UnpackData32<7>(const uint8_t* src) + { + __m256i lo = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(src + 0))), C7_SHFL), C7_MULLO), 9); + __m256i hi = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(_mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)(src + 14))), C7_SHFL), C7_MULLO), 9); + return PackI16ToU8(lo, hi); + } + + //------------------------------------------------------------------------------------------------- + + + template void UnpackDataA(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride) + { + size_t size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); + for (size_t i = 0; i < count; i++) + { + const uint8_t* ps = src[i] + 16; + uint8_t* pd = (uint8_t*)dst + i * size; + size_t j = 0; + for (; j < size32; j += 32, ps += 4 * bits, pd += 32) + _mm256_storeu_si256((__m256i*)pd, UnpackData32(ps)); + for (; j < size16; j += 16, ps += 2 * bits, pd += 16) + _mm_storeu_si128((__m128i*)pd, UnpackData16(ps)); + for (; j < size; j += 8, ps += bits, pd += 8) + _mm_storel_epi64((__m128i*)pd, Sse41::UnpackData8(ps)); + } + } + + //------------------------------------------------------------------------------------------------- + + template SIMD_INLINE void UnpackDataBx4x32(const uint8_t* const* src, size_t offset, uint8_t* dst) + { + __m256i a0 = UnpackData32(src[0] + offset); + __m256i a1 = UnpackData32(src[1] + offset); + __m256i a2 = UnpackData32(src[2] + offset); + __m256i a3 = UnpackData32(src[3] + offset); + __m256i b0 = _mm256_unpacklo_epi32(a0, a2); + __m256i b1 = _mm256_unpacklo_epi32(a1, a3); + __m256i b2 = _mm256_unpackhi_epi32(a0, a2); + __m256i b3 = _mm256_unpackhi_epi32(a1, a3); + Store((__m128i*)dst + 0, (__m128i*)dst + 16, _mm256_unpacklo_epi32(b0, b1)); + Store((__m128i*)dst + 4, (__m128i*)dst + 20, _mm256_unpackhi_epi32(b0, b1)); + Store((__m128i*)dst + 8, (__m128i*)dst + 24, _mm256_unpacklo_epi32(b2, b3)); + Store((__m128i*)dst + 12, (__m128i*)dst + 28, _mm256_unpackhi_epi32(b2, b3)); + } + + template SIMD_INLINE void UnpackDataBx4x16(const uint8_t* const* src, size_t offset, uint8_t* dst) + { + __m128i a0 = UnpackData16(src[0] + offset); + __m128i a1 = UnpackData16(src[1] + offset); + __m128i a2 = UnpackData16(src[2] + offset); + __m128i a3 = UnpackData16(src[3] + offset); + __m128i b0 = _mm_unpacklo_epi32(a0, a2); + __m128i b1 = _mm_unpacklo_epi32(a1, a3); + __m128i b2 = _mm_unpackhi_epi32(a0, a2); + __m128i b3 = _mm_unpackhi_epi32(a1, a3); + _mm_storeu_si128((__m128i*)dst + 0, _mm_unpacklo_epi32(b0, b1)); + _mm_storeu_si128((__m128i*)dst + 4, _mm_unpackhi_epi32(b0, b1)); + _mm_storeu_si128((__m128i*)dst + 8, _mm_unpacklo_epi32(b2, b3)); + _mm_storeu_si128((__m128i*)dst + 12, _mm_unpackhi_epi32(b2, b3)); + } + + template SIMD_INLINE void UnpackDataBx4x8(const uint8_t* const* src, size_t offset, uint8_t* dst) + { + __m128i a0 = Sse41::UnpackData8(src[0] + offset); + __m128i a1 = Sse41::UnpackData8(src[1] + offset); + __m128i a2 = Sse41::UnpackData8(src[2] + offset); + __m128i a3 = Sse41::UnpackData8(src[3] + offset); + __m128i b0 = _mm_unpacklo_epi32(a0, a2); + __m128i b1 = _mm_unpacklo_epi32(a1, a3); + _mm_storeu_si128((__m128i*)dst + 0, _mm_unpacklo_epi32(b0, b1)); + _mm_storeu_si128((__m128i*)dst + 4, _mm_unpackhi_epi32(b0, b1)); + } + + template void UnpackDataB(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride) + { + size_t countDF = AlignLo(count, DF), size16 = AlignLo(size, 16), size32 = AlignLo(size, 32), i, j, o; + for (i = 0; i < countDF; i += DF, src += DF) + { + for (j = 0, o = 16; j < size32; j += 32, o += 4 * bits, dst += 16 * A) + { + UnpackDataBx4x32(src + 0, o, dst + 0 * Sse41::A); + UnpackDataBx4x32(src + 4, o, dst + 1 * Sse41::A); + UnpackDataBx4x32(src + 8, o, dst + 2 * Sse41::A); + UnpackDataBx4x32(src + 12, o, dst + 3 * Sse41::A); + } + for (; j < size16; j += 16, o += 2 * bits, dst += 16 * Sse41::A) + { + UnpackDataBx4x16(src + 0, o, dst + 0 * Sse41::A); + UnpackDataBx4x16(src + 4, o, dst + 1 * Sse41::A); + UnpackDataBx4x16(src + 8, o, dst + 2 * Sse41::A); + UnpackDataBx4x16(src + 12, o, dst + 3 * Sse41::A); + } + for (; j < size; j += 8, o += bits, dst += 8 * Sse41::A) + { + UnpackDataBx4x8(src + 0, o, dst + 0 * Sse41::A); + UnpackDataBx4x8(src + 2, o, dst + 1 * Sse41::A); + UnpackDataBx4x8(src + 4, o, dst + 2 * Sse41::A); + UnpackDataBx4x8(src + 6, o, dst + 3 * Sse41::A); + } + } + if (i < count) + { + const uint8_t* _src[DF]; + for (size_t j = 0; j < DF; i++, j++) + _src[j] = i < count ? *src++ : src[-1]; + for (j = 0, o = 16; j < size32; j += 32, o += 4 * bits, dst += 16 * A) + { + UnpackDataBx4x32(src + 0, o, dst + 0 * Sse41::A); + UnpackDataBx4x32(src + 4, o, dst + 1 * Sse41::A); + UnpackDataBx4x32(src + 8, o, dst + 2 * Sse41::A); + UnpackDataBx4x32(src + 12, o, dst + 3 * Sse41::A); + } + for (; j < size16; j += 16, o += 2 * bits, dst += 16 * Sse41::A) + { + UnpackDataBx4x16(src + 0, o, dst + 0 * Sse41::A); + UnpackDataBx4x16(src + 4, o, dst + 1 * Sse41::A); + UnpackDataBx4x16(src + 8, o, dst + 2 * Sse41::A); + UnpackDataBx4x16(src + 12, o, dst + 3 * Sse41::A); + } + for (; j < size; j += 8, o += bits, dst += 8 * Sse41::A) + { + UnpackDataBx4x8(src + 0, o, dst + 0 * Sse41::A); + UnpackDataBx4x8(src + 2, o, dst + 1 * Sse41::A); + UnpackDataBx4x8(src + 4, o, dst + 2 * Sse41::A); + UnpackDataBx4x8(src + 6, o, dst + 3 * Sse41::A); + } + } + } + + //------------------------------------------------------------------------------------------------- + + template void Correlation8_2xM(size_t N, size_t K, const uint8_t* ad0, const uint8_t* bd, const float* an, const float* bn, size_t bnStride, float* distances, size_t stride) + { + __m256i ab00, ab01, ab10, ab11, ab20, ab21, ab30, ab31, ab40, ab41, a0, b0, b1; + const uint8_t* ad1 = ad0 + 1 * K; + const uint8_t* ad2 = ad0 + 2 * K; + const uint8_t* ad3 = ad0 + 3 * K; + const uint8_t* ad4 = ad0 + 4 * K; + if (N > 4) + { + if (M > 0) ab00 = _mm256_setzero_si256(), ab01 = _mm256_setzero_si256(); + if (M > 1) ab10 = _mm256_setzero_si256(), ab11 = _mm256_setzero_si256(); + if (M > 2) ab20 = _mm256_setzero_si256(), ab21 = _mm256_setzero_si256(); + if (M > 3) ab30 = _mm256_setzero_si256(), ab31 = _mm256_setzero_si256(); + if (M > 4) ab40 = _mm256_setzero_si256(), ab41 = _mm256_setzero_si256(); + for (size_t k = 0; k < K; k += 4) + { + b0 = _mm256_loadu_si256((__m256i*)bd + 0); + b1 = _mm256_loadu_si256((__m256i*)bd + 1); + if (M > 0) a0 = Set4(ad0 + k), Madd4(ab00, a0, b0), Madd4(ab01, a0, b1); + if (M > 1) a0 = Set4(ad1 + k), Madd4(ab10, a0, b0), Madd4(ab11, a0, b1); + if (M > 2) a0 = Set4(ad2 + k), Madd4(ab20, a0, b0), Madd4(ab21, a0, b1); + if (M > 3) a0 = Set4(ad3 + k), Madd4(ab30, a0, b0), Madd4(ab31, a0, b1); + if (M > 4) a0 = Set4(ad4 + k), Madd4(ab40, a0, b0), Madd4(ab41, a0, b1); + bd += DA; + } + if (N == DF) + { + if (M > 0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab01, distances + F), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab11, distances + F), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab21, distances + F), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab31, distances + F), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab41, distances + F), an += 4, distances += stride; + } + else + { + if (M > 0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab01, distances + F, N - F), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab11, distances + F, N - F), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab21, distances + F, N - F), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab31, distances + F, N - F), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab41, distances + F, N - F), an += 4, distances += stride; + } + } + else + { + if (M > 0) ab00 = _mm256_setzero_si256(); + if (M > 1) ab10 = _mm256_setzero_si256(); + if (M > 2) ab20 = _mm256_setzero_si256(); + if (M > 3) ab30 = _mm256_setzero_si256(); + if (M > 4) ab40 = _mm256_setzero_si256(); + for (size_t k = 0; k < K; k += 4) + { + b0 = _mm256_loadu_si256((__m256i*)bd + 0); + if (M > 0) a0 = Set4(ad0 + k), Madd4(ab00, a0, b0); + if (M > 1) a0 = Set4(ad1 + k), Madd4(ab10, a0, b0); + if (M > 2) a0 = Set4(ad2 + k), Madd4(ab20, a0, b0); + if (M > 3) a0 = Set4(ad3 + k), Madd4(ab30, a0, b0); + if (M > 4) a0 = Set4(ad4 + k), Madd4(ab40, a0, b0); + bd += DA; + } + if (N == F) + { + if (M > 0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab30, distances + 0), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab40, distances + 0), an += 4, distances += stride; + } + else + { + if (M > 0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0, N), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0, N), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0, N), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab30, distances + 0, N), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab40, distances + 0, N), an += 4, distances += stride; + } + } + } + + typedef void(*Correlation8_2xM_Ptr)(size_t N, size_t K, const uint8_t* ad0, const uint8_t* bd, const float* an, const float* bn, size_t bnStride, float* distances, size_t stride); + + SIMD_INLINE Correlation8_2xM_Ptr GetCorrelation8_2xM(size_t M) + { + switch (M) + { + case 0: return NULL; + case 1: return Correlation8_2xM<1>; + case 2: return Correlation8_2xM<2>; + case 3: return Correlation8_2xM<3>; + case 4: return Correlation8_2xM<4>; + case 5: return Correlation8_2xM<5>; + } + assert(0); + return NULL; + } + + void MacroCorrelation8(size_t M, size_t N, size_t K, const uint8_t* ad, const float* an, const uint8_t* bd, const float* bn, float* distances, size_t stride) + { + size_t M5 = AlignLoAny(M, 5); + Correlation8_2xM_Ptr correlation_2x5 = GetCorrelation8_2xM(5); + Correlation8_2xM_Ptr correlation_2xT = GetCorrelation8_2xM(M - M5); + for (size_t j = 0; j < N; j += DF) + { + size_t dN = Simd::Min(DF, N - j); + size_t i = 0; + for (; i < M5; i += 5) + correlation_2x5(dN, K, ad + i * K, bd, an + i * 4, bn, N, distances + i * stride, stride); + if (i < M) + correlation_2xT(dN, K, ad + i * K, bd, an + i * 4, bn, N, distances + i * stride, stride); + bd += K * DF; + bn += DF; + distances += DF; + } + } + + //------------------------------------------------------------------------------------------------- + + Sse41::DescrInt::UnpackDataPtr GetUnpackData(size_t depth, bool transpose) + { + switch (depth) + { + case 4: return transpose ? UnpackDataB<4> : UnpackDataA<4>; + case 5: return transpose ? UnpackDataB<5> : UnpackDataA<5>; + case 6: return transpose ? UnpackDataB<6> : UnpackDataA<6>; + case 7: return transpose ? UnpackDataB<7> : UnpackDataA<7>; + default: return NULL; + } + } + + Sse41::DescrInt::MacroCosineDistancesUnpackPtr GetMacroCosineDistancesUnpack(size_t depth) + { + return depth == 8 ? NULL : MacroCorrelation8; + } + } +#endif +} diff --git a/src/Simd/SimdAvx2DescrIntDec.cpp b/src/Simd/SimdAvx2DescrIntDec.cpp new file mode 100644 index 0000000000..3059d335d2 --- /dev/null +++ b/src/Simd/SimdAvx2DescrIntDec.cpp @@ -0,0 +1,307 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2023 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#include "Simd/SimdMemory.h" +#include "Simd/SimdStore.h" +#include "Simd/SimdExtract.h" +#include "Simd/SimdArray.h" +#include "Simd/SimdUnpack.h" +#include "Simd/SimdDescrInt.h" +#include "Simd/SimdDescrIntCommon.h" +#include "Simd/SimdCpu.h" + +namespace Simd +{ +#ifdef SIMD_AVX2_ENABLE + namespace Avx2 + { + static void Decode32f4(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _shift = _mm256_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16) + { + __m256i s4 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s4, C4_SHFL), C4_MULLO), 12); + _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift)); + _mm256_storeu_ps(dst + 8, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift)); + src += 8; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s4 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s4, Sse41::C4_SHFL0), Sse41::C4_MULLO), 12); + _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift)); + src += 4; + dst += 8; + } + } + + static void Decode32f5(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _shift = _mm256_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16) + { + __m256i s5 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s5, C5_SHFL), C5_MULLO), 11); + _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift)); + _mm256_storeu_ps(dst + 8, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift)); + src += 10; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s5 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s5, Sse41::C5_SHFL0), Sse41::C5_MULLO), 11); + _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift)); + src += 5; + dst += 8; + } + } + + static void Decode32f6(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _shift = _mm256_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16) + { + __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, C6_SHFL), C6_MULLO), 10); + _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift)); + _mm256_storeu_ps(dst + 8, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift)); + src += 12; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s6 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s6, Sse41::C6_SHFL0), Sse41::C6_MULLO), 10); + _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift)); + src += 6; + dst += 8; + } + } + + static void Decode32f7(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _shift = _mm256_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16) + { + __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, C7_SHFL), C7_MULLO), 9); + _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift)); + _mm256_storeu_ps(dst + 8, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift)); + src += 14; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s7 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s7, Sse41::C7_SHFL0), Sse41::C7_MULLO), 9); + _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift)); + src += 7; + dst += 8; + } + } + + static void Decode32f8(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _shift = _mm256_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16) + { + __m128i u8 = _mm_loadu_si128((__m128i*)(src + i)); + _mm256_storeu_ps(dst + i + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(u8)), _scale, _shift)); + _mm256_storeu_ps(dst + i + F, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_srli_si128(u8, 8))), _scale, _shift)); + } + for (; i < size; i += 8) + { + __m256 _src = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)(src + i)))); + _mm256_storeu_ps(dst + i, _mm256_fmadd_ps(_src, _scale, _shift)); + } + } + + //------------------------------------------------------------------------------------------------- + + static void Decode16f4(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _shift = _mm256_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16) + { + __m256i s4 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s4, C4_SHFL), C4_MULLO), 12); + _mm_storeu_si128((__m128i*)dst + 0, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift), 0)); + _mm_storeu_si128((__m128i*)dst + 1, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift), 0)); + src += 8; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s4 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s4, Sse41::C4_SHFL0), Sse41::C4_MULLO), 12); + _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift), 0)); + src += 4; + dst += 8; + } + } + + static void Decode16f5(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _shift = _mm256_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16) + { + __m256i s5 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s5, C5_SHFL), C5_MULLO), 11); + _mm_storeu_si128((__m128i*)dst + 0, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift), 0)); + _mm_storeu_si128((__m128i*)dst + 1, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift), 0)); + src += 10; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s5 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s5, Sse41::C5_SHFL0), Sse41::C5_MULLO), 11); + _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift), 0)); + src += 5; + dst += 8; + } + } + + static void Decode16f6(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _shift = _mm256_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16) + { + __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, C6_SHFL), C6_MULLO), 10); + _mm_storeu_si128((__m128i*)dst + 0, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift), 0)); + _mm_storeu_si128((__m128i*)dst + 1, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift), 0)); + src += 12; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s6 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s6, Sse41::C6_SHFL0), Sse41::C6_MULLO), 10); + _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift), 0)); + src += 6; + dst += 8; + } + } + + static void Decode16f7(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _shift = _mm256_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16) + { + __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, C7_SHFL), C7_MULLO), 9); + _mm_storeu_si128((__m128i*)dst + 0, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 0))), _scale, _shift), 0)); + _mm_storeu_si128((__m128i*)dst + 1, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(_mm256_extracti128_si256(s16, 1))), _scale, _shift), 0)); + src += 14; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s7 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s7, Sse41::C7_SHFL0), Sse41::C7_MULLO), 9); + _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _scale, _shift), 0)); + src += 7; + dst += 8; + } + } + + static void Decode16f8(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _shift = _mm256_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16); + for (; i < size16; i += 16) + { + __m128i u8 = _mm_loadu_si128((__m128i*)(src + i)); + _mm_storeu_si128((__m128i*)(dst + i) + 0, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(u8)), _scale, _shift), 0)); + _mm_storeu_si128((__m128i*)(dst + i) + 1, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_srli_si128(u8, 8))), _scale, _shift), 0)); + } + for (; i < size; i += 8) + { + __m256 _src = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)(src + i)))); + _mm_storeu_si128((__m128i*)(dst + i), _mm256_cvtps_ph(_mm256_fmadd_ps(_src, _scale, _shift), 0)); + } + } + + //------------------------------------------------------------------------------------------------- + + Base::DescrInt::Decode32fPtr GetDecode32f(size_t depth) + { + switch (depth) + { + case 4: return Decode32f4; + case 5: return Decode32f5; + case 6: return Decode32f6; + case 7: return Decode32f7; + case 8: return Decode32f8; + default: assert(0); return NULL; + } + } + + Base::DescrInt::Decode16fPtr GetDecode16f(size_t depth) + { + switch (depth) + { + case 4: return Decode16f4; + case 5: return Decode16f5; + case 6: return Decode16f6; + case 7: return Decode16f7; + case 8: return Decode16f8; + default: assert(0); return NULL; + } + } + } +#endif +} diff --git a/src/Simd/SimdAvx2DescrIntEnc.cpp b/src/Simd/SimdAvx2DescrIntEnc.cpp new file mode 100644 index 0000000000..a8f02ec08b --- /dev/null +++ b/src/Simd/SimdAvx2DescrIntEnc.cpp @@ -0,0 +1,432 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2023 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#include "Simd/SimdMemory.h" +#include "Simd/SimdStore.h" +#include "Simd/SimdExtract.h" +#include "Simd/SimdArray.h" +#include "Simd/SimdUnpack.h" +#include "Simd/SimdDescrInt.h" +#include "Simd/SimdDescrIntCommon.h" +#include "Simd/SimdCpu.h" + +namespace Simd +{ +#ifdef SIMD_AVX2_ENABLE + namespace Avx2 + { + SIMD_INLINE __m256i Encode32f(__m256 src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i value = _mm256_cvtps_epi32(_mm256_mul_ps(_mm256_sub_ps(src, min), scale)); + sum = _mm256_add_epi32(value, sum); + sqsum = _mm256_add_epi32(_mm256_madd_epi16(value, value), sqsum); + return value; + } + + SIMD_INLINE __m256i Encode32f(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + return Encode32f(_mm256_loadu_ps(src), scale, min, sum, sqsum); + } + + static SIMD_INLINE __m128i Encode32f4x8(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(src + 0 * 8, scale, min, sum, sqsum); + __m128i s0 = _mm_srli_epi32(_mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E4_MULLO), 12); + return _mm_packus_epi16(_mm_packus_epi32(s0, Sse41::K_ZERO), Sse41::K_ZERO); + } + + static SIMD_INLINE __m128i Encode32f4x32(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(src + 0 * 8, scale, min, sum, sqsum); + __m256i i1 = Encode32f(src + 1 * 8, scale, min, sum, sqsum); + __m256i s0 = _mm256_srli_epi32(_mm256_mullo_epi16(PackU32ToI16(i0, i1), E4_MULLO), 12); + __m256i i2 = Encode32f(src + 2 * 8, scale, min, sum, sqsum); + __m256i i3 = Encode32f(src + 3 * 8, scale, min, sum, sqsum); + __m256i s1 = _mm256_srli_epi32(_mm256_mullo_epi16(PackU32ToI16(i2, i3), E4_MULLO), 12); + return _mm_packus_epi16(_mm_packus_epi32(_mm256_castsi256_si128(s0), _mm256_extracti128_si256(s0, 1)), + _mm_packus_epi32(_mm256_castsi256_si128(s1), _mm256_extracti128_si256(s1, 1))); + } + + static void Encode32f4(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, size32 = AlignLo(size, 32); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _min = _mm256_set1_ps(min); + __m256i _sum = _mm256_setzero_si256(); + __m256i _sqsum = _mm256_setzero_si256(); + for (; i < size32; i += 32, src += 32, dst += 16) + _mm_storeu_si128((__m128i*)dst, Encode32f4x32(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 4) + *(uint32_t*)(dst) = _mm_extract_epi32(Encode32f4x8(src, _scale, _min, _sum, _sqsum), 0); + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + static SIMD_INLINE __m128i Encode32f5x1(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(src + 0, scale, min, sum, sqsum); + __m128i s0 = _mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E5_MULLO); + return _mm_or_si128(_mm_or_si128(_mm_shuffle_epi8(s0, Sse41::E5_SHFL0), _mm_shuffle_epi8(s0, Sse41::E5_SHFL1)), _mm_shuffle_epi8(s0, Sse41::E5_SHFL2)); + } + + static SIMD_INLINE __m128i Encode32f5x2(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(src + 0, scale, min, sum, sqsum); + __m256i i8 = Encode32f(src + 8, scale, min, sum, sqsum); + __m256i s0 = _mm256_mullo_epi16(PackU32ToI16(i0, i8), E5_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_or_si256(_mm256_shuffle_epi8(s0, E5_SHFL0), _mm256_shuffle_epi8(s0, E5_SHFL1)), _mm256_shuffle_epi8(s0, E5_SHFL2)); + return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); + } + + static void Encode32f5(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, main = size - 8, main16 = AlignLo(main, 16); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _min = _mm256_set1_ps(min); + __m256i _sum = _mm256_setzero_si256(); + __m256i _sqsum = _mm256_setzero_si256(); + for (; i < main16; i += 16, src += 16, dst += 10) + _mm_storeu_si128((__m128i*)dst, Encode32f5x2(src, _scale, _min, _sum, _sqsum)); + for (; i < main; i += 8, src += 8, dst += 5) + _mm_storel_epi64((__m128i*)dst, Encode32f5x1(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 5) + { + __m128i d0 = Encode32f5x1(src, _scale, _min, _sum, _sqsum); + *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); + *(uint8_t*)(dst + 4) = _mm_extract_epi8(d0, 4); + } + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + static SIMD_INLINE __m128i Encode32f6x1(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(src + 0, scale, min, sum, sqsum); + __m128i s0 = _mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E6_MULLO); + return _mm_or_si128(_mm_shuffle_epi8(s0, Sse41::E6_SHFL0), _mm_shuffle_epi8(s0, Sse41::E6_SHFL1)); + } + + static SIMD_INLINE __m128i Encode32f6x2(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(src + 0, scale, min, sum, sqsum); + __m256i i8 = Encode32f(src + 8, scale, min, sum, sqsum); + __m256i s0 = _mm256_mullo_epi16(PackU32ToI16(i0, i8), E6_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, E6_SHFL0), _mm256_shuffle_epi8(s0, E6_SHFL1)); + return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); + } + + static void Encode32f6(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, main = size - 8, main16 = AlignLo(main, 16); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _min = _mm256_set1_ps(min); + __m256i _sum = _mm256_setzero_si256(); + __m256i _sqsum = _mm256_setzero_si256(); + for (; i < main16; i += 16, src += 16, dst += 12) + _mm_storeu_si128((__m128i*)dst, Encode32f6x2(src, _scale, _min, _sum, _sqsum)); + for (; i < main; i += 8, src += 8, dst += 6) + _mm_storel_epi64((__m128i*)dst, Encode32f6x1(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 6) + { + __m128i d0 = Encode32f6x1(src, _scale, _min, _sum, _sqsum); + *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); + *(uint16_t*)(dst + 4) = _mm_extract_epi16(d0, 2); + } + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + static SIMD_INLINE __m128i Encode32f7x1(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(src + 0, scale, min, sum, sqsum); + __m128i s0 = _mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E7_MULLO); + return _mm_or_si128(_mm_shuffle_epi8(s0, Sse41::E7_SHFL0), _mm_shuffle_epi8(s0, Sse41::E7_SHFL1)); + } + + static SIMD_INLINE __m128i Encode32f7x2(const float* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(src + 0, scale, min, sum, sqsum); + __m256i i8 = Encode32f(src + 8, scale, min, sum, sqsum); + __m256i s0 = _mm256_mullo_epi16(PackU32ToI16(i0, i8), E7_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, E7_SHFL0), _mm256_shuffle_epi8(s0, E7_SHFL1)); + return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); + } + + static void Encode32f7(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, main = size - 8, main16 = AlignLo(main, 16); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _min = _mm256_set1_ps(min); + __m256i _sum = _mm256_setzero_si256(); + __m256i _sqsum = _mm256_setzero_si256(); + for (; i < main16; i += 16, src += 16, dst += 14) + _mm_storeu_si128((__m128i*)dst, Encode32f7x2(src, _scale, _min, _sum, _sqsum)); + for (; i < main; i += 8, src += 8, dst += 7) + _mm_storel_epi64((__m128i*)dst, Encode32f7x1(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 7) + { + __m128i d0 = Encode32f7x1(src, _scale, _min, _sum, _sqsum); + *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); + *(uint16_t*)(dst + 4) = _mm_extract_epi16(d0, 2); + *(uint8_t*)(dst + 6) = _mm_extract_epi8(d0, 6); + } + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + static void Encode32f8(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t sizeA = AlignLo(size, A), i = 0; + __m256 _scale = _mm256_set1_ps(scale); + __m256 _min = _mm256_set1_ps(min); + __m256i _sum = _mm256_setzero_si256(); + __m256i _sqsum = _mm256_setzero_si256(); + for (; i < sizeA; i += A) + { + __m256i d0 = Encode32f(src + i + 0 * F, _scale, _min, _sum, _sqsum); + __m256i d1 = Encode32f(src + i + 1 * F, _scale, _min, _sum, _sqsum); + __m256i d2 = Encode32f(src + i + 2 * F, _scale, _min, _sum, _sqsum); + __m256i d3 = Encode32f(src + i + 3 * F, _scale, _min, _sum, _sqsum); + _mm256_storeu_si256((__m256i*)(dst + i), PackI16ToU8(PackI32ToI16(d0, d1), PackI32ToI16(d2, d3))); + } + for (; i < size; i += F) + { + __m256i d0 = Encode32f(src + i, _scale, _min, _sum, _sqsum); + _mm_storel_epi64((__m128i*)(dst + i), _mm256_castsi256_si128(PackI16ToU8(PackI32ToI16(d0, _mm256_setzero_si256()), _mm256_setzero_si256()))); + } + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + //------------------------------------------------------------------------------------------------- + + static SIMD_INLINE __m128i Encode16f4x8(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src)), scale, min, sum, sqsum); + __m128i s0 = _mm_srli_epi32(_mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E4_MULLO), 12); + return _mm_packus_epi16(_mm_packus_epi32(s0, Sse41::K_ZERO), Sse41::K_ZERO); + } + + static SIMD_INLINE __m128i Encode16f4x32(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 0)), scale, min, sum, sqsum); + __m256i i1 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 1)), scale, min, sum, sqsum); + __m256i s0 = _mm256_srli_epi32(_mm256_mullo_epi16(PackU32ToI16(i0, i1), E4_MULLO), 12); + __m256i i2 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 2)), scale, min, sum, sqsum); + __m256i i3 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 3)), scale, min, sum, sqsum); + __m256i s1 = _mm256_srli_epi32(_mm256_mullo_epi16(PackU32ToI16(i2, i3), E4_MULLO), 12); + return _mm_packus_epi16(_mm_packus_epi32(_mm256_castsi256_si128(s0), _mm256_extracti128_si256(s0, 1)), + _mm_packus_epi32(_mm256_castsi256_si128(s1), _mm256_extracti128_si256(s1, 1))); + } + + static void Encode16f4(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, size32 = AlignLo(size, 32); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _min = _mm256_set1_ps(min); + __m256i _sum = _mm256_setzero_si256(); + __m256i _sqsum = _mm256_setzero_si256(); + for (; i < size32; i += 32, src += 32, dst += 16) + _mm_storeu_si128((__m128i*)dst, Encode16f4x32(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 4) + *(uint32_t*)(dst) = _mm_extract_epi32(Encode16f4x8(src, _scale, _min, _sum, _sqsum), 0); + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + static SIMD_INLINE __m128i Encode16f5x1(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src)), scale, min, sum, sqsum); + __m128i s0 = _mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E5_MULLO); + return _mm_or_si128(_mm_or_si128(_mm_shuffle_epi8(s0, Sse41::E5_SHFL0), _mm_shuffle_epi8(s0, Sse41::E5_SHFL1)), _mm_shuffle_epi8(s0, Sse41::E5_SHFL2)); + } + + static SIMD_INLINE __m128i Encode16f5x2(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 0)), scale, min, sum, sqsum); + __m256i i8 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 1)), scale, min, sum, sqsum); + __m256i s0 = _mm256_mullo_epi16(PackU32ToI16(i0, i8), E5_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_or_si256(_mm256_shuffle_epi8(s0, E5_SHFL0), _mm256_shuffle_epi8(s0, E5_SHFL1)), _mm256_shuffle_epi8(s0, E5_SHFL2)); + return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); + } + + static void Encode16f5(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, main = size - 8, main16 = AlignLo(main, 16); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _min = _mm256_set1_ps(min); + __m256i _sum = _mm256_setzero_si256(); + __m256i _sqsum = _mm256_setzero_si256(); + for (; i < main16; i += 16, src += 16, dst += 10) + _mm_storeu_si128((__m128i*)dst, Encode16f5x2(src, _scale, _min, _sum, _sqsum)); + for (; i < main; i += 8, src += 8, dst += 5) + _mm_storel_epi64((__m128i*)dst, Encode16f5x1(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 5) + { + __m128i d0 = Encode16f5x1(src, _scale, _min, _sum, _sqsum); + *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); + *(uint8_t*)(dst + 4) = _mm_extract_epi8(d0, 4); + } + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + static SIMD_INLINE __m128i Encode16f6x1(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src)), scale, min, sum, sqsum); + __m128i s0 = _mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E6_MULLO); + return _mm_or_si128(_mm_shuffle_epi8(s0, Sse41::E6_SHFL0), _mm_shuffle_epi8(s0, Sse41::E6_SHFL1)); + } + + static SIMD_INLINE __m128i Encode16f6x2(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 0)), scale, min, sum, sqsum); + __m256i i8 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 1)), scale, min, sum, sqsum); + __m256i s0 = _mm256_mullo_epi16(PackU32ToI16(i0, i8), E6_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, E6_SHFL0), _mm256_shuffle_epi8(s0, E6_SHFL1)); + return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); + } + + static void Encode16f6(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, main = size - 8, main16 = AlignLo(main, 16); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _min = _mm256_set1_ps(min); + __m256i _sum = _mm256_setzero_si256(); + __m256i _sqsum = _mm256_setzero_si256(); + for (; i < main16; i += 16, src += 16, dst += 12) + _mm_storeu_si128((__m128i*)dst, Encode16f6x2(src, _scale, _min, _sum, _sqsum)); + for (; i < main; i += 8, src += 8, dst += 6) + _mm_storel_epi64((__m128i*)dst, Encode16f6x1(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 6) + { + __m128i d0 = Encode16f6x1(src, _scale, _min, _sum, _sqsum); + *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); + *(uint16_t*)(dst + 4) = _mm_extract_epi16(d0, 2); + } + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + static SIMD_INLINE __m128i Encode16f7x1(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src)), scale, min, sum, sqsum); + __m128i s0 = _mm_mullo_epi16(_mm256_castsi256_si128(PackU32ToI16(i0, _mm256_setzero_si256())), Sse41::E7_MULLO); + return _mm_or_si128(_mm_shuffle_epi8(s0, Sse41::E7_SHFL0), _mm_shuffle_epi8(s0, Sse41::E7_SHFL1)); + } + + static SIMD_INLINE __m128i Encode16f7x2(const uint16_t* src, __m256 scale, __m256 min, __m256i& sum, __m256i& sqsum) + { + __m256i i0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 0)), scale, min, sum, sqsum); + __m256i i8 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)src + 1)), scale, min, sum, sqsum); + __m256i s0 = _mm256_mullo_epi16(PackU32ToI16(i0, i8), E7_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, E7_SHFL0), _mm256_shuffle_epi8(s0, E7_SHFL1)); + return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); + } + + static void Encode16f7(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, main = size - 8, main16 = AlignLo(main, 16); + __m256 _scale = _mm256_set1_ps(scale); + __m256 _min = _mm256_set1_ps(min); + __m256i _sum = _mm256_setzero_si256(); + __m256i _sqsum = _mm256_setzero_si256(); + for (; i < main16; i += 16, src += 16, dst += 14) + _mm_storeu_si128((__m128i*)dst, Encode16f7x2(src, _scale, _min, _sum, _sqsum)); + for (; i < main; i += 8, src += 8, dst += 7) + _mm_storel_epi64((__m128i*)dst, Encode16f7x1(src, _scale, _min, _sum, _sqsum)); + for (; i < size; i += 8, src += 8, dst += 7) + { + __m128i d0 = Encode16f7x1(src, _scale, _min, _sum, _sqsum); + *(uint32_t*)(dst + 0) = _mm_extract_epi32(d0, 0); + *(uint16_t*)(dst + 4) = _mm_extract_epi16(d0, 2); + *(uint8_t*)(dst + 6) = _mm_extract_epi8(d0, 6); + } + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + static void Encode16f8(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t sizeA = AlignLo(size, A), i = 0; + __m256 _scale = _mm256_set1_ps(scale); + __m256 _min = _mm256_set1_ps(min); + __m256i _sum = _mm256_setzero_si256(); + __m256i _sqsum = _mm256_setzero_si256(); + for (; i < sizeA; i += A) + { + __m256i d0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src + i) + 0)), _scale, _min, _sum, _sqsum); + __m256i d1 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src + i) + 1)), _scale, _min, _sum, _sqsum); + __m256i d2 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src + i) + 2)), _scale, _min, _sum, _sqsum); + __m256i d3 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src + i) + 3)), _scale, _min, _sum, _sqsum); + _mm256_storeu_si256((__m256i*)(dst + i), PackI16ToU8(PackI32ToI16(d0, d1), PackI32ToI16(d2, d3))); + } + for (; i < size; i += F) + { + __m256i d0 = Encode32f(_mm256_cvtph_ps(_mm_loadu_si128((__m128i*)(src + i))), _scale, _min, _sum, _sqsum); + _mm_storel_epi64((__m128i*)(dst + i), _mm256_castsi256_si128(PackI16ToU8(PackI32ToI16(d0, _mm256_setzero_si256()), _mm256_setzero_si256()))); + } + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + //------------------------------------------------------------------------------------------------- + + Base::DescrInt::Encode32fPtr GetEncode32f(size_t depth) + { + switch (depth) + { + case 4: return Encode32f4; + case 5: return Encode32f5; + case 6: return Encode32f6; + case 7: return Encode32f7; + case 8: return Encode32f8; + default: assert(0); return NULL; + } + } + + Base::DescrInt::Encode16fPtr GetEncode16f(size_t depth) + { + switch (depth) + { + case 4: return Encode16f4; + case 5: return Encode16f5; + case 6: return Encode16f6; + case 7: return Encode16f7; + case 8: return Encode16f8; + default: assert(0); return NULL; + } + } + } +#endif +} diff --git a/src/Simd/SimdDescrInt.h b/src/Simd/SimdDescrInt.h index 6f0ecdd048..554bc4ce78 100644 --- a/src/Simd/SimdDescrInt.h +++ b/src/Simd/SimdDescrInt.h @@ -141,6 +141,20 @@ namespace Simd //------------------------------------------------------------------------------------------------- + Base::DescrInt::Encode32fPtr GetEncode32f(size_t depth); + Base::DescrInt::Encode16fPtr GetEncode16f(size_t depth); + + Base::DescrInt::Decode32fPtr GetDecode32f(size_t depth); + Base::DescrInt::Decode16fPtr GetDecode16f(size_t depth); + + Base::DescrInt::CosineDistancePtr GetCosineDistance(size_t depth); + Sse41::DescrInt::MacroCosineDistancesDirectPtr GetMacroCosineDistancesDirect(size_t depth); + + Sse41::DescrInt::UnpackDataPtr GetUnpackData(size_t depth, bool transpose); + Sse41::DescrInt::MacroCosineDistancesUnpackPtr GetMacroCosineDistancesUnpack(size_t depth); + + //------------------------------------------------------------------------------------------------- + void* DescrIntInit(size_t size, size_t depth); } #endif diff --git a/src/Simd/SimdDescrIntCommon.h b/src/Simd/SimdDescrIntCommon.h index 74c6a46996..8a78019204 100644 --- a/src/Simd/SimdDescrIntCommon.h +++ b/src/Simd/SimdDescrIntCommon.h @@ -85,6 +85,38 @@ namespace Simd //------------------------------------------------------------------------------------------------- + template __m128i UnpackData8(const uint8_t* src); + + template<> SIMD_INLINE __m128i UnpackData8<4>(const uint8_t* src) + { + __m128i _src = _mm_loadl_epi64((__m128i*)src); + __m128i lo = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C4_SHFL0), C4_MULLO), 12); + return _mm_packus_epi16(lo, K_ZERO); + } + + template<> SIMD_INLINE __m128i UnpackData8<5>(const uint8_t* src) + { + __m128i _src = _mm_loadl_epi64((__m128i*)src); + __m128i lo = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C5_SHFL0), C5_MULLO), 11); + return _mm_packus_epi16(lo, K_ZERO); + } + + template<> SIMD_INLINE __m128i UnpackData8<6>(const uint8_t* src) + { + __m128i _src = _mm_loadl_epi64((__m128i*)src); + __m128i lo = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C6_SHFL0), C6_MULLO), 10); + return _mm_packus_epi16(lo, K_ZERO); + } + + template<> SIMD_INLINE __m128i UnpackData8<7>(const uint8_t* src) + { + __m128i _src = _mm_loadl_epi64((__m128i*)src); + __m128i lo = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C7_SHFL0), C7_MULLO), 9); + return _mm_packus_epi16(lo, K_ZERO); + } + + //------------------------------------------------------------------------------------------------- + SIMD_INLINE void DecodeCosineDistances1x4(const uint8_t* a, const uint8_t* const* B, __m128 abSum, float* distances) { __m128 aScale, aShift, aMean, aNorm, bScale, bShift, bMean, bNorm; @@ -113,7 +145,9 @@ namespace Simd _mm_storeu_ps(distances, _mm_min_ps(_mm_max_ps(_mm_sub_ps(_mm_set1_ps(1.0f), _mm_div_ps(ab, _mm_mul_ps(aNorm, bNorm))), _mm_setzero_ps()), _mm_set1_ps(2.0f))); } - SIMD_INLINE void DecodeCosineDistances1x4(const float* a, const float *b, size_t stride, __m128i abSum, float* distances) + //------------------------------------------------------------------------------------------------- + + SIMD_INLINE void DecodeCosineDistances1xF(const float* a, const float *b, size_t stride, __m128i abSum, float* distances) { __m128 aScale = _mm_set1_ps(a[0]); __m128 aShift = _mm_set1_ps(a[1]); @@ -129,10 +163,10 @@ namespace Simd _mm_storeu_ps(distances, _mm_min_ps(_mm_max_ps(_mm_sub_ps(_mm_set1_ps(1.0f), _mm_div_ps(ab, _mm_mul_ps(aNorm, bNorm))), _mm_setzero_ps()), _mm_set1_ps(2.0f))); } - SIMD_INLINE void DecodeCosineDistances1x4(const float* a, const float* b, size_t stride, __m128i abSum, float* distances, size_t N) + SIMD_INLINE void DecodeCosineDistances1xF(const float* a, const float* b, size_t stride, __m128i abSum, float* distances, size_t N) { - float d[4]; - DecodeCosineDistances1x4(a, b, stride, abSum, d); + float d[F]; + DecodeCosineDistances1xF(a, b, stride, abSum, d); for (size_t i = 0; i < N; ++i) distances[i] = d[i]; } @@ -222,6 +256,30 @@ namespace Simd Avx::Store(distances + 0 * stride, distances + 1 * stride, _mm256_min_ps(_mm256_max_ps(_mm256_sub_ps(_mm256_set1_ps(1.0f), _mm256_div_ps(ab, _mm256_mul_ps(aNorm, bNorm))), _mm256_setzero_ps()), _mm256_set1_ps(2.0f))); } + + SIMD_INLINE void DecodeCosineDistances1xF(const float* a, const float* b, size_t stride, __m256i abSum, float* distances) + { + __m256 aScale = _mm256_set1_ps(a[0]); + __m256 aShift = _mm256_set1_ps(a[1]); + __m256 aMean = _mm256_set1_ps(a[2]); + __m256 aNorm = _mm256_set1_ps(a[3]); + __m256 bScale = _mm256_loadu_ps(b + 0 * stride); + __m256 bShift = _mm256_loadu_ps(b + 1 * stride); + __m256 bMean = _mm256_loadu_ps(b + 2 * stride); + __m256 bNorm = _mm256_loadu_ps(b + 3 * stride); + __m256 ab = _mm256_mul_ps(_mm256_cvtepi32_ps(abSum), _mm256_mul_ps(aScale, bScale)); + ab = _mm256_add_ps(_mm256_mul_ps(aMean, bShift), ab); + ab = _mm256_add_ps(_mm256_mul_ps(bMean, aShift), ab); + _mm256_storeu_ps(distances, _mm256_min_ps(_mm256_max_ps(_mm256_sub_ps(_mm256_set1_ps(1.0f), _mm256_div_ps(ab, _mm256_mul_ps(aNorm, bNorm))), _mm256_setzero_ps()), _mm256_set1_ps(2.0f))); + } + + SIMD_INLINE void DecodeCosineDistances1xF(const float* a, const float* b, size_t stride, __m256i abSum, float* distances, size_t N) + { + float d[F]; + DecodeCosineDistances1xF(a, b, stride, abSum, d); + for (size_t i = 0; i < N; ++i) + distances[i] = d[i]; + } } #endif diff --git a/src/Simd/SimdSse41DescrInt.cpp b/src/Simd/SimdSse41DescrInt.cpp index 5ca2328edb..5ea215b91a 100644 --- a/src/Simd/SimdSse41DescrInt.cpp +++ b/src/Simd/SimdSse41DescrInt.cpp @@ -139,7 +139,7 @@ namespace Simd void DescrInt::CosineDistancesMxNa(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, float* distances) const { - if(_unpSize * _microNu > Base::AlgCacheL1() || N * 2 < _microNu || _depth < 5 || _depth == 8) + if(_unpSize * _microNu > Base::AlgCacheL1() || N * 2 < _microNu || _depth == 8) CosineDistancesDirect(M, N, A, B, distances); else CosineDistancesUnpack(M, N, A, B, distances); diff --git a/src/Simd/SimdSse41DescrIntScd.cpp b/src/Simd/SimdSse41DescrIntCdd.cpp similarity index 100% rename from src/Simd/SimdSse41DescrIntScd.cpp rename to src/Simd/SimdSse41DescrIntCdd.cpp diff --git a/src/Simd/SimdSse41DescrIntScu.cpp b/src/Simd/SimdSse41DescrIntCdu.cpp similarity index 77% rename from src/Simd/SimdSse41DescrIntScu.cpp rename to src/Simd/SimdSse41DescrIntCdu.cpp index dc2a2bfb13..83de89f27e 100644 --- a/src/Simd/SimdSse41DescrIntScu.cpp +++ b/src/Simd/SimdSse41DescrIntCdu.cpp @@ -37,38 +37,6 @@ namespace Simd #ifdef SIMD_SSE41_ENABLE namespace Sse41 { - template __m128i UnpackData8(const uint8_t* src); - - template<> SIMD_INLINE __m128i UnpackData8<4>(const uint8_t* src) - { - __m128i _src = _mm_loadl_epi64((__m128i*)src); - __m128i lo = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C4_SHFL0), C4_MULLO), 12); - return _mm_packus_epi16(lo, K_ZERO); - } - - template<> SIMD_INLINE __m128i UnpackData8<5>(const uint8_t* src) - { - __m128i _src = _mm_loadl_epi64((__m128i*)src); - __m128i lo = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C5_SHFL0), C5_MULLO), 11); - return _mm_packus_epi16(lo, K_ZERO); - } - - template<> SIMD_INLINE __m128i UnpackData8<6>(const uint8_t* src) - { - __m128i _src = _mm_loadl_epi64((__m128i*)src); - __m128i lo = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C6_SHFL0), C6_MULLO), 10); - return _mm_packus_epi16(lo, K_ZERO); - } - - template<> SIMD_INLINE __m128i UnpackData8<7>(const uint8_t* src) - { - __m128i _src = _mm_loadl_epi64((__m128i*)src); - __m128i lo = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(_src, C7_SHFL0), C7_MULLO), 9); - return _mm_packus_epi16(lo, K_ZERO); - } - - //------------------------------------------------------------------------------------------------- - template __m128i UnpackData16(const uint8_t* src); template<> SIMD_INLINE __m128i UnpackData16<4>(const uint8_t* src) @@ -291,21 +259,21 @@ namespace Simd } if (N == 8) { - if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab01, distances + 4), an += 4, distances += stride; - if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab11, distances + 4), an += 4, distances += stride; - if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab21, distances + 4), an += 4, distances += stride; - if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab31, distances + 4), an += 4, distances += stride; - if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab41, distances + 4), an += 4, distances += stride; - if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab51, distances + 4), an += 4, distances += stride; + if (M > 0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1xF(an, bn + 4, bnStride, ab01, distances + 4), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1xF(an, bn + 4, bnStride, ab11, distances + 4), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1xF(an, bn + 4, bnStride, ab21, distances + 4), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1xF(an, bn + 4, bnStride, ab31, distances + 4), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1xF(an, bn + 4, bnStride, ab41, distances + 4), an += 4, distances += stride; + if (M > 5) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab50, distances + 0), DecodeCosineDistances1xF(an, bn + 4, bnStride, ab51, distances + 4), an += 4, distances += stride; } else { - if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab01, distances + 4, N - 4), an += 4, distances += stride; - if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab11, distances + 4, N - 4), an += 4, distances += stride; - if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab21, distances + 4, N - 4), an += 4, distances += stride; - if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab31, distances + 4, N - 4), an += 4, distances += stride; - if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab41, distances + 4, N - 4), an += 4, distances += stride; - if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab51, distances + 4, N - 4), an += 4, distances += stride; + if (M > 0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1xF(an, bn + 4, bnStride, ab01, distances + 4, N - 4), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1xF(an, bn + 4, bnStride, ab11, distances + 4, N - 4), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1xF(an, bn + 4, bnStride, ab21, distances + 4, N - 4), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1xF(an, bn + 4, bnStride, ab31, distances + 4, N - 4), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1xF(an, bn + 4, bnStride, ab41, distances + 4, N - 4), an += 4, distances += stride; + if (M > 5) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab50, distances + 0), DecodeCosineDistances1xF(an, bn + 4, bnStride, ab51, distances + 4, N - 4), an += 4, distances += stride; } } else @@ -329,21 +297,21 @@ namespace Simd } if (N == 4) { - if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), an += 4, distances += stride; - if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), an += 4, distances += stride; - if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), an += 4, distances += stride; - if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), an += 4, distances += stride; - if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), an += 4, distances += stride; - if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0), an += 4, distances += stride; + if (M > 0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab30, distances + 0), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab40, distances + 0), an += 4, distances += stride; + if (M > 5) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab50, distances + 0), an += 4, distances += stride; } else { - if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0, N), an += 4, distances += stride; - if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0, N), an += 4, distances += stride; - if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0, N), an += 4, distances += stride; - if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0, N), an += 4, distances += stride; - if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0, N), an += 4, distances += stride; - if (M > 5) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab50, distances + 0, N), an += 4, distances += stride; + if (M > 0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0, N), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0, N), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0, N), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab30, distances + 0, N), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab40, distances + 0, N), an += 4, distances += stride; + if (M > 5) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab50, distances + 0, N), an += 4, distances += stride; } } } @@ -414,21 +382,21 @@ namespace Simd if (M > 4) a0 = Set4(ad4 + k), Madd4(ab40, a0, b0), Madd4(ab41, a0, b1); bd += DA; } - if (N == 8) + if (N == DF) { - if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab01, distances + 4), an += 4, distances += stride; - if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab11, distances + 4), an += 4, distances += stride; - if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab21, distances + 4), an += 4, distances += stride; - if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab31, distances + 4), an += 4, distances += stride; - if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab41, distances + 4), an += 4, distances += stride; + if (M > 0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab01, distances + F), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab11, distances + F), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab21, distances + F), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab31, distances + F), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab41, distances + F), an += 4, distances += stride; } else { - if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab01, distances + 4, N - 4), an += 4, distances += stride; - if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab11, distances + 4, N - 4), an += 4, distances += stride; - if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab21, distances + 4, N - 4), an += 4, distances += stride; - if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab31, distances + 4, N - 4), an += 4, distances += stride; - if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1x4(an, bn + 4, bnStride, ab41, distances + 4, N - 4), an += 4, distances += stride; + if (M > 0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab01, distances + F, N - F), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab11, distances + F, N - F), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab21, distances + F, N - F), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab31, distances + F, N - F), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab41, distances + F, N - F), an += 4, distances += stride; } } else @@ -448,21 +416,21 @@ namespace Simd if (M > 4) a0 = Set4(ad4 + k), Madd4(ab40, a0, b0); bd += DA; } - if (N == 4) + if (N == F) { - if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0), an += 4, distances += stride; - if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0), an += 4, distances += stride; - if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0), an += 4, distances += stride; - if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0), an += 4, distances += stride; - if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0), an += 4, distances += stride; + if (M > 0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab30, distances + 0), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab40, distances + 0), an += 4, distances += stride; } else { - if (M > 0) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab00, distances + 0, N), an += 4, distances += stride; - if (M > 1) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab10, distances + 0, N), an += 4, distances += stride; - if (M > 2) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab20, distances + 0, N), an += 4, distances += stride; - if (M > 3) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab30, distances + 0, N), an += 4, distances += stride; - if (M > 4) DecodeCosineDistances1x4(an, bn + 0, bnStride, ab40, distances + 0, N), an += 4, distances += stride; + if (M > 0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0, N), an += 4, distances += stride; + if (M > 1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0, N), an += 4, distances += stride; + if (M > 2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0, N), an += 4, distances += stride; + if (M > 3) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab30, distances + 0, N), an += 4, distances += stride; + if (M > 4) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab40, distances + 0, N), an += 4, distances += stride; } } } @@ -489,21 +457,20 @@ namespace Simd size_t M5 = AlignLoAny(M, 5); Correlation8_2xM_Ptr correlation_2x5 = GetCorrelation8_2xM(5); Correlation8_2xM_Ptr correlation_2xT = GetCorrelation8_2xM(M - M5); - for (size_t j = 0; j < N; j += 8) + for (size_t j = 0; j < N; j += DF) { - size_t dN = Simd::Min(8, N - j); + size_t dN = Simd::Min(DF, N - j); size_t i = 0; for (; i < M5; i += 5) correlation_2x5(dN, K, ad + i * K, bd, an + i * 4, bn, N, distances + i * stride, stride); if (i < M) correlation_2xT(dN, K, ad + i * K, bd, an + i * 4, bn, N, distances + i * stride, stride); - bd += K * 8; - bn += 8; - distances += 8; + bd += K * DF; + bn += DF; + distances += DF; } } - //------------------------------------------------------------------------------------------------- Sse41::DescrInt::UnpackDataPtr GetUnpackData(size_t depth, bool transpose) From a4b0c7355abccd3a1c4e63c6c8876564d956383f Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Wed, 28 Jun 2023 16:56:50 +0300 Subject: [PATCH 33/44] *improve AVX-512BW optimizations of functions DescrIntCosineDistancesMxNp, DescrIntCosineDistancesMxNa for 4, 5, 6, 7-bits depth. --- docs/2023.html | 4 +- prj/vs2019/Avx512bw.vcxproj | 4 + prj/vs2019/Avx512bw.vcxproj.filters | 12 + prj/vs2022/Avx512bw.vcxproj | 4 + prj/vs2022/Avx512bw.vcxproj.filters | 12 + src/Simd/SimdAvx2DescrIntCdu.cpp | 2 +- src/Simd/SimdAvx512bwDescrInt.cpp | 1662 ++------------------------ src/Simd/SimdAvx512bwDescrIntCdd.cpp | 976 +++++++++++++++ src/Simd/SimdAvx512bwDescrIntCdu.cpp | 481 ++++++++ src/Simd/SimdAvx512bwDescrIntDec.cpp | 313 +++++ src/Simd/SimdAvx512bwDescrIntEnc.cpp | 441 +++++++ src/Simd/SimdDescrInt.h | 14 + src/Simd/SimdDescrIntCommon.h | 24 + src/Simd/SimdSse41DescrInt.cpp | 7 +- 14 files changed, 2369 insertions(+), 1587 deletions(-) create mode 100644 src/Simd/SimdAvx512bwDescrIntCdd.cpp create mode 100644 src/Simd/SimdAvx512bwDescrIntCdu.cpp create mode 100644 src/Simd/SimdAvx512bwDescrIntDec.cpp create mode 100644 src/Simd/SimdAvx512bwDescrIntEnc.cpp diff --git a/docs/2023.html b/docs/2023.html index 58ebbd6698..9563fe0c7b 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -50,8 +50,8 @@
        New features
      Improving
        -
      • SSE4.1 optimizations of function DescrIntCosineDistancesMxNp for 4, 5, 6, 7-bits depth.
      • -
      • SSE4.1 optimizations of function DescrIntCosineDistancesMxNa for 4, 5, 6, 7-bits depth.
      • +
      • SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistancesMxNp for 4, 5, 6, 7-bits depth.
      • +
      • SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistancesMxNa for 4, 5, 6, 7-bits depth.
      Bug fixing
        diff --git a/prj/vs2019/Avx512bw.vcxproj b/prj/vs2019/Avx512bw.vcxproj index 5ce3ab6cc1..3b8d66991c 100644 --- a/prj/vs2019/Avx512bw.vcxproj +++ b/prj/vs2019/Avx512bw.vcxproj @@ -34,6 +34,10 @@ + + + + diff --git a/prj/vs2019/Avx512bw.vcxproj.filters b/prj/vs2019/Avx512bw.vcxproj.filters index a43d951db7..9d4927c449 100644 --- a/prj/vs2019/Avx512bw.vcxproj.filters +++ b/prj/vs2019/Avx512bw.vcxproj.filters @@ -370,6 +370,18 @@ Avx512bw + + Avx512bw + + + Avx512bw + + + Avx512bw + + + Avx512bw + diff --git a/prj/vs2022/Avx512bw.vcxproj b/prj/vs2022/Avx512bw.vcxproj index 5ce3ab6cc1..3b8d66991c 100644 --- a/prj/vs2022/Avx512bw.vcxproj +++ b/prj/vs2022/Avx512bw.vcxproj @@ -34,6 +34,10 @@ + + + + diff --git a/prj/vs2022/Avx512bw.vcxproj.filters b/prj/vs2022/Avx512bw.vcxproj.filters index a43d951db7..9d4927c449 100644 --- a/prj/vs2022/Avx512bw.vcxproj.filters +++ b/prj/vs2022/Avx512bw.vcxproj.filters @@ -370,6 +370,18 @@ Avx512bw + + Avx512bw + + + Avx512bw + + + Avx512bw + + + Avx512bw + diff --git a/src/Simd/SimdAvx2DescrIntCdu.cpp b/src/Simd/SimdAvx2DescrIntCdu.cpp index 950ca9f75d..4bfa8fb776 100644 --- a/src/Simd/SimdAvx2DescrIntCdu.cpp +++ b/src/Simd/SimdAvx2DescrIntCdu.cpp @@ -229,7 +229,7 @@ namespace Simd const uint8_t* ad2 = ad0 + 2 * K; const uint8_t* ad3 = ad0 + 3 * K; const uint8_t* ad4 = ad0 + 4 * K; - if (N > 4) + if (N > F) { if (M > 0) ab00 = _mm256_setzero_si256(), ab01 = _mm256_setzero_si256(); if (M > 1) ab10 = _mm256_setzero_si256(), ab11 = _mm256_setzero_si256(); diff --git a/src/Simd/SimdAvx512bwDescrInt.cpp b/src/Simd/SimdAvx512bwDescrInt.cpp index 286fa71cbb..8ea93b61da 100644 --- a/src/Simd/SimdAvx512bwDescrInt.cpp +++ b/src/Simd/SimdAvx512bwDescrInt.cpp @@ -83,1536 +83,73 @@ namespace Simd //------------------------------------------------------------------------------------------------- - SIMD_INLINE __m512i Encode32f(__m512 src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) + static void UnpackNormA(size_t count, const uint8_t* const* src, float* dst, size_t stride) { - __m512i value = _mm512_cvtps_epi32(_mm512_mul_ps(_mm512_sub_ps(src, min), scale)); - sum = _mm512_add_epi32(value, sum); - sqsum = _mm512_add_epi32(_mm512_madd_epi16(value, value), sqsum); - return value; - } - - SIMD_INLINE __m512i Encode32f(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) - { - return Encode32f(_mm512_maskz_loadu_ps(mask, src), scale, min, sum, sqsum); - } - - static SIMD_INLINE __m128i Encode32f4x4(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 m0, __mmask16 m1) - { - __m512i i0 = Encode32f(src + 0 * F, scale, min, sum, sqsum, m0); - __m512i i1 = Encode32f(src + 1 * F, scale, min, sum, sqsum, m1); - __m512i s0 = _mm512_srli_epi32(_mm512_mullo_epi16(PackU32ToI16(i0, i1), E4_MULLO), 12); - return _mm256_castsi256_si128(Avx2::PackI16ToU8(_mm512_cvtepi32_epi16(s0), Avx2::K_ZERO)); - } - - static SIMD_INLINE __m256i Encode32f4x8(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) - { - __m512i i0 = Encode32f(src + 0 * F, scale, min, sum, sqsum); - __m512i i1 = Encode32f(src + 1 * F, scale, min, sum, sqsum); - __m512i i2 = Encode32f(src + 2 * F, scale, min, sum, sqsum); - __m512i i3 = Encode32f(src + 3 * F, scale, min, sum, sqsum); - __m512i s0 = _mm512_srli_epi32(_mm512_mullo_epi16(PackU32ToI16(i0, i1), E4_MULLO), 12); - __m512i s1 = _mm512_srli_epi32(_mm512_mullo_epi16(PackU32ToI16(i2, i3), E4_MULLO), 12); - return Avx2::PackI16ToU8(_mm512_cvtepi32_epi16(s0), _mm512_cvtepi32_epi16(s1)); - } - - static void Encode32f4(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t i = 0, size32 = AlignLo(size, 32), size64 = AlignLo(size, 64); - __m512 _scale = _mm512_set1_ps(scale); - __m512 _min = _mm512_set1_ps(min); - __m512i _sum = _mm512_setzero_si512(); - __m512i _sqsum = _mm512_setzero_si512(); - for (; i < size64; i += 64, src += 64, dst += 32) - _mm256_storeu_si256((__m256i*)dst, Encode32f4x8(src, _scale, _min, _sum, _sqsum)); - for (; i < size32; i += 32, src += 32, dst += 16) - _mm_mask_storeu_epi8(dst, -1, Encode32f4x4(src, _scale, _min, _sum, _sqsum, -1, -1)); - if (i < size) - { - __mmask16 ms0 = TailMask16(size - size32 - 0 * F); - __mmask16 ms1 = TailMask16(size - size32 - 1 * F); - __mmask16 md= TailMask16((size - size32) / 2); - _mm_mask_storeu_epi8(dst, md, Encode32f4x4(src, _scale, _min, _sum, _sqsum, ms0, ms1)); - } - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - static SIMD_INLINE __m128i Encode32f5x2(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) - { - __m512i i0 = Encode32f(src, scale, min, sum, sqsum, mask); - __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), Avx2::E5_MULLO); - __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, Avx2::E5_SHFL0), _mm256_shuffle_epi8(s0, Avx2::E5_SHFL1)); - return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); - } - - static SIMD_INLINE __m256i Encode32f5x4(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) - { - __m512i i0 = Encode32f(src + 0 * F, scale, min, sum, sqsum); - __m512i i1 = Encode32f(src + 1 * F, scale, min, sum, sqsum); - __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(EX_PERM, _mm512_packus_epi32(i0, i1)), E5_MULLO); - __m512i e0 = _mm512_or_si512(_mm512_or_si512(_mm512_shuffle_epi8(s0, E5_SHFL0), _mm512_shuffle_epi8(s0, E5_SHFL1)), _mm512_shuffle_epi8(s0, E5_SHFL2)); - return _mm256_or_si256(_mm512_castsi512_si256(e0), _mm512_extracti32x8_epi32(e0, 1)); - } - - static void Encode32f5(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t size16 = AlignLo(size, 16), size32 = AlignLo(size, 32), i = 0; - __m512 _scale = _mm512_set1_ps(scale); - __m512 _min = _mm512_set1_ps(min); - __m512i _sum = _mm512_setzero_si512(); - __m512i _sqsum = _mm512_setzero_si512(); - for (; i < size32; i += 32, src += 32, dst += 20) - _mm256_mask_storeu_epi8(dst - 6, 0x03FFFFC0, Encode32f5x4(src, _scale, _min, _sum, _sqsum)); - for (; i < size16; i += 16, src += 16, dst += 10) - _mm_mask_storeu_epi8(dst, 0x03FF, Encode32f5x2(src, _scale, _min, _sum, _sqsum)); - if (i < size) - _mm_mask_storeu_epi8(dst, 0x001F, Encode32f5x2(src, _scale, _min, _sum, _sqsum, 0x00FF)); - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - static SIMD_INLINE __m128i Encode32f6x2(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) - { - __m512i i0 = Encode32f(src, scale, min, sum, sqsum, mask); - __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), Avx2::E6_MULLO); - __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, Avx2::E6_SHFL0), _mm256_shuffle_epi8(s0, Avx2::E6_SHFL1)); - return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); - } - - static SIMD_INLINE __m256i Encode32f6x4(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) - { - __m512i i0 = Encode32f(src + 0 * F, scale, min, sum, sqsum); - __m512i i1 = Encode32f(src + 1 * F, scale, min, sum, sqsum); - __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(EX_PERM, _mm512_packus_epi32(i0, i1)), E6_MULLO); - __m512i e0 = _mm512_or_si512(_mm512_shuffle_epi8(s0, E6_SHFL0), _mm512_shuffle_epi8(s0, E6_SHFL1)); - return _mm256_or_si256(_mm512_castsi512_si256(e0), _mm512_extracti32x8_epi32(e0, 1)); - } - - static void Encode32f6(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t size16 = AlignLo(size, 16), size32 = AlignLo(size, 32), i = 0; - __m512 _scale = _mm512_set1_ps(scale); - __m512 _min = _mm512_set1_ps(min); - __m512i _sum = _mm512_setzero_si512(); - __m512i _sqsum = _mm512_setzero_si512(); - for (; i < size32; i += 32, src += 32, dst += 24) - _mm256_mask_storeu_epi8(dst - 4, 0x0FFFFFF0, Encode32f6x4(src, _scale, _min, _sum, _sqsum)); - for (; i < size16; i += 16, src += 16, dst += 12) - _mm_mask_storeu_epi8(dst, 0x0FFF, Encode32f6x2(src, _scale, _min, _sum, _sqsum)); - if (i < size) - _mm_mask_storeu_epi8(dst, 0x003F, Encode32f6x2(src, _scale, _min, _sum, _sqsum, 0x00FF)); - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - static SIMD_INLINE __m128i Encode32f7x2(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) - { - __m512i i0 = Encode32f(src, scale, min, sum, sqsum, mask); - __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), Avx2::E7_MULLO); - __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, Avx2::E7_SHFL0), _mm256_shuffle_epi8(s0, Avx2::E7_SHFL1)); - return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); - } - - static SIMD_INLINE __m256i Encode32f7x4(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) - { - __m512i i0 = Encode32f(src + 0 * F, scale, min, sum, sqsum); - __m512i i1 = Encode32f(src + 1 * F, scale, min, sum, sqsum); - __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(EX_PERM, _mm512_packus_epi32(i0, i1)), E7_MULLO); - __m512i e0 = _mm512_or_si512(_mm512_shuffle_epi8(s0, E7_SHFL0), _mm512_shuffle_epi8(s0, E7_SHFL1)); - return _mm256_or_si256(_mm512_castsi512_si256(e0), _mm512_extracti32x8_epi32(e0, 1)); - } - - static void Encode32f7(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t size16 = AlignLo(size, 16), size32 = AlignLo(size, 32), i = 0; - __m512 _scale = _mm512_set1_ps(scale); - __m512 _min = _mm512_set1_ps(min); - __m512i _sum = _mm512_setzero_si512(); - __m512i _sqsum = _mm512_setzero_si512(); - for (; i < size32; i += 32, src += 32, dst += 28) - _mm256_mask_storeu_epi8(dst - 2, 0x3FFFFFFC, Encode32f7x4(src, _scale, _min, _sum, _sqsum)); - for (; i < size16; i += 16, src += 16, dst += 14) - _mm_mask_storeu_epi8(dst, 0x3FFF, Encode32f7x2(src, _scale, _min, _sum, _sqsum)); - if (i < size) - _mm_mask_storeu_epi8(dst, 0x007F, Encode32f7x2(src, _scale, _min, _sum, _sqsum, 0x00FF)); - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - static void Encode32f8(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t sizeF = AlignLo(size, F), sizeA = AlignLo(size, A), i = 0; - __m512 _scale = _mm512_set1_ps(scale); - __m512 _min = _mm512_set1_ps(min); - __m512i _sum = _mm512_setzero_si512(); - __m512i _sqsum = _mm512_setzero_si512(); - for (; i < sizeA; i += A) - { - __m512i d0 = Encode32f(src + i + 0 * F, _scale, _min, _sum, _sqsum); - __m512i d1 = Encode32f(src + i + 1 * F, _scale, _min, _sum, _sqsum); - __m512i d2 = Encode32f(src + i + 2 * F, _scale, _min, _sum, _sqsum); - __m512i d3 = Encode32f(src + i + 3 * F, _scale, _min, _sum, _sqsum); - _mm512_storeu_si512((__m512i*)(dst + i), PackI16ToU8(PackI32ToI16(d0, d1), PackI32ToI16(d2, d3))); - } - for (; i < sizeF; i += F) - { - __m512i d0 = Encode32f(src + i, _scale, _min, _sum, _sqsum); - _mm_storeu_si128((__m128i*)(dst + i), _mm512_castsi512_si128(PackI16ToU8(PackI32ToI16(d0)))); - } - if (i < size) - { - __m512i d0 = Encode32f(src + i, _scale, _min, _sum, _sqsum, 0xFF); - _mm_mask_storeu_epi8(dst + i, 0xFF, _mm512_castsi512_si128(PackI16ToU8(PackI32ToI16(d0)))); - } - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - //------------------------------------------------------------------------------------------------- - - SIMD_INLINE __m512i Encode16f(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) - { - return Encode32f(_mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, src)), scale, min, sum, sqsum); - } - - static SIMD_INLINE __m128i Encode16f4x4(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 m0, __mmask16 m1) - { - __m512i i0 = Encode16f(src + 0 * F, scale, min, sum, sqsum, m0); - __m512i i1 = Encode16f(src + 1 * F, scale, min, sum, sqsum, m1); - __m512i s0 = _mm512_srli_epi32(_mm512_mullo_epi16(PackU32ToI16(i0, i1), E4_MULLO), 12); - return _mm256_castsi256_si128(Avx2::PackI16ToU8(_mm512_cvtepi32_epi16(s0), Avx2::K_ZERO)); - } - - static SIMD_INLINE __m256i Encode16f4x8(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) - { - __m512i i0 = Encode16f(src + 0 * F, scale, min, sum, sqsum); - __m512i i1 = Encode16f(src + 1 * F, scale, min, sum, sqsum); - __m512i i2 = Encode16f(src + 2 * F, scale, min, sum, sqsum); - __m512i i3 = Encode16f(src + 3 * F, scale, min, sum, sqsum); - __m512i s0 = _mm512_srli_epi32(_mm512_mullo_epi16(PackU32ToI16(i0, i1), E4_MULLO), 12); - __m512i s1 = _mm512_srli_epi32(_mm512_mullo_epi16(PackU32ToI16(i2, i3), E4_MULLO), 12); - return Avx2::PackI16ToU8(_mm512_cvtepi32_epi16(s0), _mm512_cvtepi32_epi16(s1)); - } - - static void Encode16f4(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t i = 0, size32 = AlignLo(size, 32), size64 = AlignLo(size, 64); - __m512 _scale = _mm512_set1_ps(scale); - __m512 _min = _mm512_set1_ps(min); - __m512i _sum = _mm512_setzero_si512(); - __m512i _sqsum = _mm512_setzero_si512(); - for (; i < size64; i += 64, src += 64, dst += 32) - _mm256_storeu_si256((__m256i*)dst, Encode16f4x8(src, _scale, _min, _sum, _sqsum)); - for (; i < size32; i += 32, src += 32, dst += 16) - _mm_mask_storeu_epi8(dst, -1, Encode16f4x4(src, _scale, _min, _sum, _sqsum, -1, -1)); - if (i < size) - { - __mmask16 ms0 = TailMask16(size - size32 - 0 * F); - __mmask16 ms1 = TailMask16(size - size32 - 1 * F); - __mmask16 md = TailMask16((size - size32) / 2); - _mm_mask_storeu_epi8(dst, md, Encode16f4x4(src, _scale, _min, _sum, _sqsum, ms0, ms1)); - } - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - static SIMD_INLINE __m128i Encode16f5x2(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) - { - __m512i i0 = Encode16f(src, scale, min, sum, sqsum, mask); - __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), Avx2::E5_MULLO); - __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, Avx2::E5_SHFL0), _mm256_shuffle_epi8(s0, Avx2::E5_SHFL1)); - return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); - } - - static SIMD_INLINE __m256i Encode16f5x4(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) - { - __m512i i0 = Encode16f(src + 0 * F, scale, min, sum, sqsum); - __m512i i1 = Encode16f(src + 1 * F, scale, min, sum, sqsum); - __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(EX_PERM, _mm512_packus_epi32(i0, i1)), E5_MULLO); - __m512i e0 = _mm512_or_si512(_mm512_or_si512(_mm512_shuffle_epi8(s0, E5_SHFL0), _mm512_shuffle_epi8(s0, E5_SHFL1)), _mm512_shuffle_epi8(s0, E5_SHFL2)); - return _mm256_or_si256(_mm512_castsi512_si256(e0), _mm512_extracti32x8_epi32(e0, 1)); - } - - static void Encode16f5(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t size16 = AlignLo(size, 16), size32 = AlignLo(size, 32), i = 0; - __m512 _scale = _mm512_set1_ps(scale); - __m512 _min = _mm512_set1_ps(min); - __m512i _sum = _mm512_setzero_si512(); - __m512i _sqsum = _mm512_setzero_si512(); - for (; i < size32; i += 32, src += 32, dst += 20) - _mm256_mask_storeu_epi8(dst - 6, 0x03FFFFC0, Encode16f5x4(src, _scale, _min, _sum, _sqsum)); - for (; i < size16; i += 16, src += 16, dst += 10) - _mm_mask_storeu_epi8(dst, 0x03FF, Encode16f5x2(src, _scale, _min, _sum, _sqsum)); - if (i < size) - _mm_mask_storeu_epi8(dst, 0x001F, Encode16f5x2(src, _scale, _min, _sum, _sqsum, 0x00FF)); - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - static SIMD_INLINE __m128i Encode16f6x2(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) - { - __m512i i0 = Encode16f(src, scale, min, sum, sqsum, mask); - __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), Avx2::E6_MULLO); - __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, Avx2::E6_SHFL0), _mm256_shuffle_epi8(s0, Avx2::E6_SHFL1)); - return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); - } - - static SIMD_INLINE __m256i Encode16f6x4(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) - { - __m512i i0 = Encode16f(src + 0 * F, scale, min, sum, sqsum); - __m512i i1 = Encode16f(src + 1 * F, scale, min, sum, sqsum); - __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(EX_PERM, _mm512_packus_epi32(i0, i1)), E6_MULLO); - __m512i e0 = _mm512_or_si512(_mm512_shuffle_epi8(s0, E6_SHFL0), _mm512_shuffle_epi8(s0, E6_SHFL1)); - return _mm256_or_si256(_mm512_castsi512_si256(e0), _mm512_extracti32x8_epi32(e0, 1)); - } - - static void Encode16f6(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t size16 = AlignLo(size, 16), size32 = AlignLo(size, 32), i = 0; - __m512 _scale = _mm512_set1_ps(scale); - __m512 _min = _mm512_set1_ps(min); - __m512i _sum = _mm512_setzero_si512(); - __m512i _sqsum = _mm512_setzero_si512(); - for (; i < size32; i += 32, src += 32, dst += 24) - _mm256_mask_storeu_epi8(dst - 4, 0x0FFFFFF0, Encode16f6x4(src, _scale, _min, _sum, _sqsum)); - for (; i < size16; i += 16, src += 16, dst += 12) - _mm_mask_storeu_epi8(dst, 0x0FFF, Encode16f6x2(src, _scale, _min, _sum, _sqsum)); - if (i < size) - _mm_mask_storeu_epi8(dst, 0x003F, Encode16f6x2(src, _scale, _min, _sum, _sqsum, 0x00FF)); - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - static SIMD_INLINE __m128i Encode16f7x2(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) - { - __m512i i0 = Encode16f(src, scale, min, sum, sqsum, mask); - __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), Avx2::E7_MULLO); - __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, Avx2::E7_SHFL0), _mm256_shuffle_epi8(s0, Avx2::E7_SHFL1)); - return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); - } - - static SIMD_INLINE __m256i Encode16f7x4(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) - { - __m512i i0 = Encode16f(src + 0 * F, scale, min, sum, sqsum); - __m512i i1 = Encode16f(src + 1 * F, scale, min, sum, sqsum); - __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(EX_PERM, _mm512_packus_epi32(i0, i1)), E7_MULLO); - __m512i e0 = _mm512_or_si512(_mm512_shuffle_epi8(s0, E7_SHFL0), _mm512_shuffle_epi8(s0, E7_SHFL1)); - return _mm256_or_si256(_mm512_castsi512_si256(e0), _mm512_extracti32x8_epi32(e0, 1)); - } - - static void Encode16f7(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t size16 = AlignLo(size, 16), size32 = AlignLo(size, 32), i = 0; - __m512 _scale = _mm512_set1_ps(scale); - __m512 _min = _mm512_set1_ps(min); - __m512i _sum = _mm512_setzero_si512(); - __m512i _sqsum = _mm512_setzero_si512(); - for (; i < size32; i += 32, src += 32, dst += 28) - _mm256_mask_storeu_epi8(dst - 2, 0x3FFFFFFC, Encode16f7x4(src, _scale, _min, _sum, _sqsum)); - for (; i < size16; i += 16, src += 16, dst += 14) - _mm_mask_storeu_epi8(dst, 0x3FFF, Encode16f7x2(src, _scale, _min, _sum, _sqsum)); - if (i < size) - _mm_mask_storeu_epi8(dst, 0x007F, Encode16f7x2(src, _scale, _min, _sum, _sqsum, 0x00FF)); - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - static void Encode16f8(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) - { - assert(size % 8 == 0); - size_t sizeF = AlignLo(size, F), sizeA = AlignLo(size, A), i = 0; - __m512 _scale = _mm512_set1_ps(scale); - __m512 _min = _mm512_set1_ps(min); - __m512i _sum = _mm512_setzero_si512(); - __m512i _sqsum = _mm512_setzero_si512(); - for (; i < sizeA; i += A) - { - __m512i d0 = Encode16f(src + i + 0 * F, _scale, _min, _sum, _sqsum); - __m512i d1 = Encode16f(src + i + 1 * F, _scale, _min, _sum, _sqsum); - __m512i d2 = Encode16f(src + i + 2 * F, _scale, _min, _sum, _sqsum); - __m512i d3 = Encode16f(src + i + 3 * F, _scale, _min, _sum, _sqsum); - _mm512_storeu_si512((__m512i*)(dst + i), PackI16ToU8(PackI32ToI16(d0, d1), PackI32ToI16(d2, d3))); - } - for (; i < sizeF; i += F) - { - __m512i d0 = Encode16f(src + i, _scale, _min, _sum, _sqsum); - _mm_storeu_si128((__m128i*)(dst + i), _mm512_castsi512_si128(PackI16ToU8(PackI32ToI16(d0)))); - } - if (i < size) - { - __m512i d0 = Encode16f(src + i, _scale, _min, _sum, _sqsum, 0xFF); - _mm_mask_storeu_epi8(dst + i, 0xFF, _mm512_castsi512_si128(PackI16ToU8(PackI32ToI16(d0)))); - } - sum = ExtractSum(_sum); - sqsum = ExtractSum(_sqsum); - } - - //------------------------------------------------------------------------------------------------- - - static void Decode32f4(const uint8_t* src, float scale, float shift, size_t size, float* dst) - { - assert(size % 8 == 0); - __m512 _scale = _mm512_set1_ps(scale); - __m512 _shift = _mm512_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); - for (; i < size16; i += 16) - { - __m256i s4 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); - __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s4, Avx2::C4_SHFL), Avx2::C4_MULLO), 12); - _mm512_storeu_ps(dst + 0, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift)); - src += 8; - dst += 16; - } - for (; i < size; i += 8) - { - __m128i s4 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s4, Sse41::C4_SHFL0), Sse41::C4_MULLO), 12); - _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift))); - src += 4; - dst += 8; - } - } - - static void Decode32f5(const uint8_t* src, float scale, float shift, size_t size, float* dst) - { - assert(size % 8 == 0); - __m512 _scale = _mm512_set1_ps(scale); - __m512 _shift = _mm512_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); - for (; i < size16; i += 16) - { - __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); - __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, Avx2::C5_SHFL), Avx2::C5_MULLO), 11); - _mm512_storeu_ps(dst + 0, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift)); - src += 10; - dst += 16; - } - for (; i < size; i += 8) - { - __m128i s5 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s5, Sse41::C5_SHFL0), Sse41::C5_MULLO), 11); - _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift))); - src += 5; - dst += 8; - } - } - - static void Decode32f6(const uint8_t* src, float scale, float shift, size_t size, float* dst) - { - assert(size % 8 == 0); - __m512 _scale = _mm512_set1_ps(scale); - __m512 _shift = _mm512_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); - for (; i < size16; i += 16) - { - __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); - __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, Avx2::C6_SHFL), Avx2::C6_MULLO), 10); - _mm512_storeu_ps(dst + 0, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift)); - src += 12; - dst += 16; - } - for (; i < size; i += 8) - { - __m128i s6 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s6, Sse41::C6_SHFL0), Sse41::C6_MULLO), 10); - _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift))); - src += 6; - dst += 8; - } - } - - static void Decode32f7(const uint8_t* src, float scale, float shift, size_t size, float* dst) - { - assert(size % 8 == 0); - __m512 _scale = _mm512_set1_ps(scale); - __m512 _shift = _mm512_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); - for (; i < size16; i += 16) - { - __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); - __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, Avx2::C7_SHFL), Avx2::C7_MULLO), 9); - _mm512_storeu_ps(dst + 0, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift)); - src += 14; - dst += 16; - } - for (; i < size; i += 8) - { - __m128i s7 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s7, Sse41::C7_SHFL0), Sse41::C7_MULLO), 9); - _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift))); - src += 7; - dst += 8; - } - } - - static void Decode32f8(const uint8_t* src, float scale, float shift, size_t size, float* dst) - { - assert(size % 8 == 0); - __m512 _scale = _mm512_set1_ps(scale); - __m512 _shift = _mm512_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16), size64 = AlignLo(size, 64); - for (; i < size64; i += 64) - { - __m512i u8 = _mm512_loadu_si512((__m512i*)(src + i)); - _mm512_storeu_ps(dst + i + 0 * F, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 0))), _scale, _shift)); - _mm512_storeu_ps(dst + i + 1 * F, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 1))), _scale, _shift)); - _mm512_storeu_ps(dst + i + 2 * F, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 2))), _scale, _shift)); - _mm512_storeu_ps(dst + i + 3 * F, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 3))), _scale, _shift)); - } - for (; i < size16; i += 16) - { - __m128i u8 = _mm_loadu_si128((__m128i*)(src + i)); - _mm512_storeu_ps(dst + i, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(u8)), _scale, _shift)); - } - if (i < size) - { - __m256 _src = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)(src + i)))); - _mm256_storeu_ps(dst + i, _mm256_fmadd_ps(_src, _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift))); - } - } - - //------------------------------------------------------------------------------------------------- - - static void Decode16f4(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) - { - assert(size % 8 == 0); - __m512 _scale = _mm512_set1_ps(scale); - __m512 _shift = _mm512_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); - for (; i < size16; i += 16) - { - __m256i s4 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); - __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s4, Avx2::C4_SHFL), Avx2::C4_MULLO), 12); - _mm256_storeu_si256((__m256i*)dst, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift), 0)); - src += 8; - dst += 16; - } - for (; i < size; i += 8) - { - __m128i s4 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s4, Sse41::C4_SHFL0), Sse41::C4_MULLO), 12); - _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift)), 0)); - src += 4; - dst += 8; - } - } - - static void Decode16f5(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) - { - assert(size % 8 == 0); - __m512 _scale = _mm512_set1_ps(scale); - __m512 _shift = _mm512_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); - for (; i < size16; i += 16) - { - __m256i s5 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); - __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s5, Avx2::C5_SHFL), Avx2::C5_MULLO), 11); - _mm256_storeu_si256((__m256i*)dst, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift), 0)); - src += 10; - dst += 16; - } - for (; i < size; i += 8) - { - __m128i s5 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s5, Sse41::C5_SHFL0), Sse41::C5_MULLO), 11); - _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift)), 0)); - src += 5; - dst += 8; - } - } - - static void Decode16f6(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) - { - assert(size % 8 == 0); - __m512 _scale = _mm512_set1_ps(scale); - __m512 _shift = _mm512_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); - for (; i < size16; i += 16) - { - __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); - __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, Avx2::C6_SHFL), Avx2::C6_MULLO), 10); - _mm256_storeu_si256((__m256i*)dst, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift), 0)); - src += 12; - dst += 16; - } - for (; i < size; i += 8) - { - __m128i s6 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s6, Sse41::C6_SHFL0), Sse41::C6_MULLO), 10); - _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift)), 0)); - src += 6; - dst += 8; - } - } - - static void Decode16f7(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) - { - assert(size % 8 == 0); - __m512 _scale = _mm512_set1_ps(scale); - __m512 _shift = _mm512_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); - for (; i < size16; i += 16) - { - __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); - __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, Avx2::C7_SHFL), Avx2::C7_MULLO), 9); - _mm256_storeu_si256((__m256i*)dst, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift), 0)); - src += 14; - dst += 16; - } - for (; i < size; i += 8) - { - __m128i s7 = _mm_loadl_epi64((__m128i*)src); - __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s7, Sse41::C7_SHFL0), Sse41::C7_MULLO), 9); - _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift)), 0)); - src += 7; - dst += 8; - } - } - - static void Decode16f8(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) - { - assert(size % 8 == 0); - __m512 _scale = _mm512_set1_ps(scale); - __m512 _shift = _mm512_set1_ps(shift); - size_t i = 0, size16 = AlignLo(size, 16), size64 = AlignLo(size, 64); - for (; i < size64; i += 64) - { - __m512i u8 = _mm512_loadu_si512((__m512i*)(src + i)); - _mm256_storeu_si256((__m256i*)(dst + i) + 0, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 0))), _scale, _shift), 0)); - _mm256_storeu_si256((__m256i*)(dst + i) + 1, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 1))), _scale, _shift), 0)); - _mm256_storeu_si256((__m256i*)(dst + i) + 2, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 2))), _scale, _shift), 0)); - _mm256_storeu_si256((__m256i*)(dst + i) + 3, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 3))), _scale, _shift), 0)); - } - for (; i < size16; i += 16) - { - __m128i u8 = _mm_loadu_si128((__m128i*)(src + i)); - _mm256_storeu_si256((__m256i*)(dst + i), _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(u8)), _scale, _shift), 0)); - } - if (i < size) - { - __m256 _src = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)(src + i)))); - _mm_storeu_si128((__m128i*)(dst + i), _mm256_cvtps_ph(_mm256_fmadd_ps(_src, _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift)), 0)); - } - } - - //------------------------------------------------------------------------------------------------- - - template int32_t Correlation(const uint8_t* a, const uint8_t* b, size_t size); - - template<> int32_t Correlation<4>(const uint8_t* a, const uint8_t* b, size_t size) - { - assert(size % 8 == 0); - __m512i ab32 = _mm512_setzero_si512(); - size_t i = 0, size128 = AlignLo(size, 128); - for (; i < size128; i += 128, a += 64, b += 64) - { - __m512i _a = _mm512_loadu_si512((__m512i*)a); - __m512i _b = _mm512_loadu_si512((__m512i*)b); - __m512i ab16 = _mm512_maddubs_epi16(_mm512_and_si512(_a, K8_0F), _mm512_and_si512(_b, K8_0F)); - ab16 = _mm512_add_epi16(ab16, _mm512_maddubs_epi16(_mm512_and_si512(_mm512_srli_epi16(_a, 4), K8_0F), _mm512_and_si512(_mm512_srli_epi16(_b, 4), K8_0F))); - ab32 = _mm512_add_epi32(ab32, _mm512_madd_epi16(ab16, K16_0001)); - } - if(i < size) - { - __mmask16 mask = TailMask16((size - i) / 8); - __m512i _a = _mm512_maskz_loadu_epi32(mask, a); - __m512i _b = _mm512_maskz_loadu_epi32(mask, b); - __m512i ab16 = _mm512_maddubs_epi16(_mm512_and_si512(_a, K8_0F), _mm512_and_si512(_b, K8_0F)); - ab16 = _mm512_add_epi16(ab16, _mm512_maddubs_epi16(_mm512_and_si512(_mm512_srli_epi16(_a, 4), K8_0F), _mm512_and_si512(_mm512_srli_epi16(_b, 4), K8_0F))); - ab32 = _mm512_add_epi32(ab32, _mm512_madd_epi16(ab16, K16_0001)); - } - return ExtractSum(ab32); - } - - SIMD_INLINE __m512i Load5(const uint8_t* ptr, __mmask32 mask = 0x000FFFFF) - { - return _mm512_srli_epi16(_mm512_mullo_epi16(_mm512_shuffle_epi8(_mm512_permutexvar_epi32(C5_PERM, _mm512_castsi256_si512(_mm256_maskz_loadu_epi8(mask, ptr))), C5_SHFL), C5_MULLO), 11); - } - - template<> int32_t Correlation<5>(const uint8_t* a, const uint8_t* b, size_t size) - { - assert(size % 8 == 0); - __m512i _ab = _mm512_setzero_si512(); - size_t i = 0, size32 = AlignLo(size, 32); - for (; i < size32; i += 32, a += 20, b += 20) - { - __m512i _a = Load5(a); - __m512i _b = Load5(b); - _ab = _mm512_add_epi32(_mm512_madd_epi16(_a, _b), _ab); - } - if (i < size) - { - __mmask32 mask = TailMask32((size - i) / 8 * 5); - __m512i _a = Load5(a, mask); - __m512i _b = Load5(b, mask); - _ab = _mm512_add_epi32(_mm512_madd_epi16(_a, _b), _ab); - } - return ExtractSum(_ab); - } - - SIMD_INLINE __m512i Load6(const uint8_t* ptr, __mmask32 mask = 0x00FFFFFF) - { - return _mm512_srli_epi16(_mm512_mullo_epi16(_mm512_shuffle_epi8(_mm512_permutexvar_epi32(C6_PERM, _mm512_castsi256_si512(_mm256_maskz_loadu_epi8(mask, ptr))), C6_SHFL), C6_MULLO), 10); - } - - template<> int32_t Correlation<6>(const uint8_t* a, const uint8_t* b, size_t size) - { - assert(size % 8 == 0); - __m512i _ab = _mm512_setzero_si512(); - size_t i = 0, size32 = AlignLo(size, 32); - for (; i < size32; i += 32, a += 24, b += 24) - { - __m512i _a = Load6(a); - __m512i _b = Load6(b); - _ab = _mm512_add_epi32(_mm512_madd_epi16(_a, _b), _ab); - } - if (i < size) - { - __mmask32 mask = TailMask32((size - i) / 8 * 6); - __m512i _a = Load6(a, mask); - __m512i _b = Load6(b, mask); - _ab = _mm512_add_epi32(_mm512_madd_epi16(_a, _b), _ab); - } - return ExtractSum(_ab); - } - - SIMD_INLINE __m512i Load7(const uint8_t* ptr, __mmask32 mask = 0x0FFFFFFF) - { - return _mm512_srli_epi16(_mm512_mullo_epi16(_mm512_shuffle_epi8(_mm512_permutexvar_epi32(C7_PERM, _mm512_castsi256_si512(_mm256_maskz_loadu_epi8(mask, ptr))), C7_SHFL), C7_MULLO), 9); - } - - template<> int32_t Correlation<7>(const uint8_t* a, const uint8_t* b, size_t size) - { - assert(size % 8 == 0); - __m512i _ab = _mm512_setzero_si512(); - size_t i = 0, size32 = AlignLo(size, 32); - for (; i < size32; i += 32, a += 28, b += 28) - { - __m512i _a = Load7(a); - __m512i _b = Load7(b); - _ab = _mm512_add_epi32(_mm512_madd_epi16(_a, _b), _ab); - } - if (i < size) - { - __mmask32 mask = TailMask32((size - i) / 8 * 7); - __m512i _a = Load7(a, mask); - __m512i _b = Load7(b, mask); - _ab = _mm512_add_epi32(_mm512_madd_epi16(_a, _b), _ab); - } - return ExtractSum(_ab); - } - - template<> int32_t Correlation<8>(const uint8_t* a, const uint8_t* b, size_t size) - { - assert(size % 8 == 0); - size_t i = 0, size32 = AlignLo(size, 32); - __m512i _ab = _mm512_setzero_si512(); - for (; i < size32; i += 32) - { - __m512i _a = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(a + i))); - __m512i _b = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(b + i))); - _ab = _mm512_add_epi32(_mm512_madd_epi16(_a, _b), _ab); - } - if ( i < size) - { - __mmask32 mask = TailMask32(size - i); - __m512i _a = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, a + i)); - __m512i _b = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, b + i)); - _ab = _mm512_add_epi32(_mm512_madd_epi16(_a, _b), _ab); - } - return ExtractSum(_ab); - } - - template void CosineDistance(const uint8_t* a, const uint8_t* b, size_t size, float* distance) - { - float abSum = (float)Correlation(a + 16, b + 16, size); - Base::DecodeCosineDistance(a, b, abSum, distance); + size_t count2 = AlignLo(count, 2), count4 = AlignLo(count, 4), i = 0; + for (; i < count4; i += 4, src += 4, dst += 16) + _mm512_storeu_ps(dst, Load((float*)src[0], (float*)src[1], (float*)src[2], (float*)src[3])); + for (; i < count2; i += 2, src += 2, dst += 8) + _mm256_storeu_ps(dst, Avx::Load((float*)src[0], (float*)src[1])); + for (; i < count; ++i, src += 1, dst += 4) + _mm_storeu_ps(dst, _mm_loadu_ps((float*)src[0])); } //------------------------------------------------------------------------------------------------- - template void MicroCosineDistancesDirect4x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); - - template<> void MicroCosineDistancesDirect4x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size128 = AlignLo(size, 128), o = 16; - __m512i a00, a10, a20, a30, a01, a11, a21, a31, b00, b01; - __m512i ab00 = _mm512_setzero_si512(); - __m512i ab01 = _mm512_setzero_si512(); - __m512i ab02 = _mm512_setzero_si512(); - __m512i ab03 = _mm512_setzero_si512(); - __m512i ab10 = _mm512_setzero_si512(); - __m512i ab11 = _mm512_setzero_si512(); - __m512i ab12 = _mm512_setzero_si512(); - __m512i ab13 = _mm512_setzero_si512(); - __m512i ab20 = _mm512_setzero_si512(); - __m512i ab21 = _mm512_setzero_si512(); - __m512i ab22 = _mm512_setzero_si512(); - __m512i ab23 = _mm512_setzero_si512(); - __m512i ab30 = _mm512_setzero_si512(); - __m512i ab31 = _mm512_setzero_si512(); - __m512i ab32 = _mm512_setzero_si512(); - __m512i ab33 = _mm512_setzero_si512(); - for (; i < size128; i += 128, o += 64) - { - a01 = _mm512_loadu_si512((__m512i*)(A[0] + o)); - a00 = _mm512_and_si512(a01, K8_0F); - a01 = _mm512_and_si512(_mm512_srli_epi16(a01, 4), K8_0F); - a11 = _mm512_loadu_si512((__m512i*)(A[1] + o)); - a10 = _mm512_and_si512(a11, K8_0F); - a11 = _mm512_and_si512(_mm512_srli_epi16(a11, 4), K8_0F); - a21 = _mm512_loadu_si512((__m512i*)(A[2] + o)); - a20 = _mm512_and_si512(a21, K8_0F); - a21 = _mm512_and_si512(_mm512_srli_epi16(a21, 4), K8_0F); - a31 = _mm512_loadu_si512((__m512i*)(A[3] + o)); - a30 = _mm512_and_si512(a31, K8_0F); - a31 = _mm512_and_si512(_mm512_srli_epi16(a31, 4), K8_0F); - - b01 = _mm512_loadu_si512((__m512i*)(B[0] + o)); - b00 = _mm512_and_si512(b01, K8_0F); - b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); - ab00 = _mm512_add_epi32(ab00, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); - ab10 = _mm512_add_epi32(ab10, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); - ab20 = _mm512_add_epi32(ab20, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); - ab30 = _mm512_add_epi32(ab30, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); - - b01 = _mm512_loadu_si512((__m512i*)(B[1] + o)); - b00 = _mm512_and_si512(b01, K8_0F); - b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); - ab01 = _mm512_add_epi32(ab01, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); - ab11 = _mm512_add_epi32(ab11, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); - ab21 = _mm512_add_epi32(ab21, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); - ab31 = _mm512_add_epi32(ab31, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); - - b01 = _mm512_loadu_si512((__m512i*)(B[2] + o)); - b00 = _mm512_and_si512(b01, K8_0F); - b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); - ab02 = _mm512_add_epi32(ab02, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); - ab12 = _mm512_add_epi32(ab12, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); - ab22 = _mm512_add_epi32(ab22, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); - ab32 = _mm512_add_epi32(ab32, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); - - b01 = _mm512_loadu_si512((__m512i*)(B[3] + o)); - b00 = _mm512_and_si512(b01, K8_0F); - b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); - ab03 = _mm512_add_epi32(ab03, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); - ab13 = _mm512_add_epi32(ab13, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); - ab23 = _mm512_add_epi32(ab23, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); - ab33 = _mm512_add_epi32(ab33, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); - } - if (i < size) - { - __mmask16 mask = TailMask32((size - i) / 8); - a01 = _mm512_maskz_loadu_epi32(mask, A[0] + o); - a00 = _mm512_and_si512(a01, K8_0F); - a01 = _mm512_and_si512(_mm512_srli_epi16(a01, 4), K8_0F); - a11 = _mm512_maskz_loadu_epi32(mask, A[1] + o); - a10 = _mm512_and_si512(a11, K8_0F); - a11 = _mm512_and_si512(_mm512_srli_epi16(a11, 4), K8_0F); - a21 = _mm512_maskz_loadu_epi32(mask, A[2] + o); - a20 = _mm512_and_si512(a21, K8_0F); - a21 = _mm512_and_si512(_mm512_srli_epi16(a21, 4), K8_0F); - a31 = _mm512_maskz_loadu_epi32(mask, A[3] + o); - a30 = _mm512_and_si512(a31, K8_0F); - a31 = _mm512_and_si512(_mm512_srli_epi16(a31, 4), K8_0F); - - b01 = _mm512_maskz_loadu_epi32(mask, B[0] + o); - b00 = _mm512_and_si512(b01, K8_0F); - b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); - ab00 = _mm512_add_epi32(ab00, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); - ab10 = _mm512_add_epi32(ab10, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); - ab20 = _mm512_add_epi32(ab20, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); - ab30 = _mm512_add_epi32(ab30, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); - - b01 = _mm512_maskz_loadu_epi32(mask, B[1] + o); - b00 = _mm512_and_si512(b01, K8_0F); - b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); - ab01 = _mm512_add_epi32(ab01, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); - ab11 = _mm512_add_epi32(ab11, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); - ab21 = _mm512_add_epi32(ab21, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); - ab31 = _mm512_add_epi32(ab31, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); - - b01 = _mm512_maskz_loadu_epi32(mask, B[2] + o); - b00 = _mm512_and_si512(b01, K8_0F); - b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); - ab02 = _mm512_add_epi32(ab02, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); - ab12 = _mm512_add_epi32(ab12, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); - ab22 = _mm512_add_epi32(ab22, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); - ab32 = _mm512_add_epi32(ab32, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); - - b01 = _mm512_maskz_loadu_epi32(mask, B[3] + o); - b00 = _mm512_and_si512(b01, K8_0F); - b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); - ab03 = _mm512_add_epi32(ab03, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); - ab13 = _mm512_add_epi32(ab13, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); - ab23 = _mm512_add_epi32(ab23, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); - ab33 = _mm512_add_epi32(ab33, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); - __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); - Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - Sse41::DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); - Sse41::DecodeCosineDistances1x4(A[2], B, ab2, distances + 2 * stride); - Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); - } - - template<> void MicroCosineDistancesDirect4x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size32 = AlignLo(size, 32), o = 16; - __m512i a0, a1, a2, a3, b0; - __m512i ab00 = _mm512_setzero_si512(); - __m512i ab01 = _mm512_setzero_si512(); - __m512i ab02 = _mm512_setzero_si512(); - __m512i ab03 = _mm512_setzero_si512(); - __m512i ab10 = _mm512_setzero_si512(); - __m512i ab11 = _mm512_setzero_si512(); - __m512i ab12 = _mm512_setzero_si512(); - __m512i ab13 = _mm512_setzero_si512(); - __m512i ab20 = _mm512_setzero_si512(); - __m512i ab21 = _mm512_setzero_si512(); - __m512i ab22 = _mm512_setzero_si512(); - __m512i ab23 = _mm512_setzero_si512(); - __m512i ab30 = _mm512_setzero_si512(); - __m512i ab31 = _mm512_setzero_si512(); - __m512i ab32 = _mm512_setzero_si512(); - __m512i ab33 = _mm512_setzero_si512(); - for (; i < size32; i += 32, o += 20) - { - a0 = Load5(A[0] + o); - a1 = Load5(A[1] + o); - a2 = Load5(A[2] + o); - a3 = Load5(A[3] + o); - - b0 = Load5(B[0] + o); - ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); - ab10 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab10); - ab20 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab20); - ab30 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab30); - - b0 = Load5(B[1] + o); - ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); - ab11 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab11); - ab21 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab21); - ab31 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab31); - - b0 = Load5(B[2] + o); - ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); - ab12 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab12); - ab22 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab22); - ab32 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab32); - - b0 = Load5(B[3] + o); - ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); - ab13 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab13); - ab23 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab23); - ab33 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab33); - } - if (i < size) - { - __mmask32 mask = TailMask32((size - i) / 8 * 5); - a0 = Load5(A[0] + o, mask); - a1 = Load5(A[1] + o, mask); - a2 = Load5(A[2] + o, mask); - a3 = Load5(A[3] + o, mask); - - b0 = Load5(B[0] + o, mask); - ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); - ab10 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab10); - ab20 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab20); - ab30 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab30); - - b0 = Load5(B[1] + o, mask); - ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); - ab11 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab11); - ab21 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab21); - ab31 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab31); - - b0 = Load5(B[2] + o, mask); - ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); - ab12 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab12); - ab22 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab22); - ab32 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab32); - - b0 = Load5(B[3] + o, mask); - ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); - ab13 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab13); - ab23 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab23); - ab33 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab33); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); - __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); - Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - Sse41::DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); - Sse41::DecodeCosineDistances1x4(A[2], B, ab2, distances + 2 * stride); - Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); - } - - template<> void MicroCosineDistancesDirect4x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size32 = AlignLo(size, 32), o = 16; - __m512i a0, a1, a2, a3, b0; - __m512i ab00 = _mm512_setzero_si512(); - __m512i ab01 = _mm512_setzero_si512(); - __m512i ab02 = _mm512_setzero_si512(); - __m512i ab03 = _mm512_setzero_si512(); - __m512i ab10 = _mm512_setzero_si512(); - __m512i ab11 = _mm512_setzero_si512(); - __m512i ab12 = _mm512_setzero_si512(); - __m512i ab13 = _mm512_setzero_si512(); - __m512i ab20 = _mm512_setzero_si512(); - __m512i ab21 = _mm512_setzero_si512(); - __m512i ab22 = _mm512_setzero_si512(); - __m512i ab23 = _mm512_setzero_si512(); - __m512i ab30 = _mm512_setzero_si512(); - __m512i ab31 = _mm512_setzero_si512(); - __m512i ab32 = _mm512_setzero_si512(); - __m512i ab33 = _mm512_setzero_si512(); - for (; i < size32; i += 32, o += 24) - { - a0 = Load6(A[0] + o); - a1 = Load6(A[1] + o); - a2 = Load6(A[2] + o); - a3 = Load6(A[3] + o); - - b0 = Load6(B[0] + o); - ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); - ab10 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab10); - ab20 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab20); - ab30 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab30); - - b0 = Load6(B[1] + o); - ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); - ab11 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab11); - ab21 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab21); - ab31 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab31); - - b0 = Load6(B[2] + o); - ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); - ab12 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab12); - ab22 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab22); - ab32 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab32); - - b0 = Load6(B[3] + o); - ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); - ab13 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab13); - ab23 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab23); - ab33 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab33); - } - if (i < size) - { - __mmask32 mask = TailMask32((size - i) / 8 * 6); - a0 = Load6(A[0] + o, mask); - a1 = Load6(A[1] + o, mask); - a2 = Load6(A[2] + o, mask); - a3 = Load6(A[3] + o, mask); - - b0 = Load6(B[0] + o, mask); - ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); - ab10 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab10); - ab20 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab20); - ab30 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab30); - - b0 = Load6(B[1] + o, mask); - ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); - ab11 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab11); - ab21 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab21); - ab31 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab31); - - b0 = Load6(B[2] + o, mask); - ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); - ab12 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab12); - ab22 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab22); - ab32 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab32); - - b0 = Load6(B[3] + o, mask); - ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); - ab13 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab13); - ab23 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab23); - ab33 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab33); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); - __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); - Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - Sse41::DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); - Sse41::DecodeCosineDistances1x4(A[2], B, ab2, distances + 2 * stride); - Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); - } - - template<> void MicroCosineDistancesDirect4x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size32 = AlignLo(size, 32), o = 16; - __m512i a0, a1, a2, a3, b0; - __m512i ab00 = _mm512_setzero_si512(); - __m512i ab01 = _mm512_setzero_si512(); - __m512i ab02 = _mm512_setzero_si512(); - __m512i ab03 = _mm512_setzero_si512(); - __m512i ab10 = _mm512_setzero_si512(); - __m512i ab11 = _mm512_setzero_si512(); - __m512i ab12 = _mm512_setzero_si512(); - __m512i ab13 = _mm512_setzero_si512(); - __m512i ab20 = _mm512_setzero_si512(); - __m512i ab21 = _mm512_setzero_si512(); - __m512i ab22 = _mm512_setzero_si512(); - __m512i ab23 = _mm512_setzero_si512(); - __m512i ab30 = _mm512_setzero_si512(); - __m512i ab31 = _mm512_setzero_si512(); - __m512i ab32 = _mm512_setzero_si512(); - __m512i ab33 = _mm512_setzero_si512(); - for (; i < size32; i += 32, o += 28) - { - a0 = Load7(A[0] + o); - a1 = Load7(A[1] + o); - a2 = Load7(A[2] + o); - a3 = Load7(A[3] + o); - - b0 = Load7(B[0] + o); - ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); - ab10 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab10); - ab20 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab20); - ab30 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab30); - - b0 = Load7(B[1] + o); - ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); - ab11 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab11); - ab21 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab21); - ab31 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab31); - - b0 = Load7(B[2] + o); - ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); - ab12 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab12); - ab22 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab22); - ab32 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab32); - - b0 = Load7(B[3] + o); - ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); - ab13 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab13); - ab23 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab23); - ab33 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab33); - } - if (i < size) - { - __mmask32 mask = TailMask32((size - i) / 8 * 7); - a0 = Load7(A[0] + o, mask); - a1 = Load7(A[1] + o, mask); - a2 = Load7(A[2] + o, mask); - a3 = Load7(A[3] + o, mask); - - b0 = Load7(B[0] + o, mask); - ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); - ab10 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab10); - ab20 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab20); - ab30 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab30); - - b0 = Load7(B[1] + o, mask); - ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); - ab11 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab11); - ab21 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab21); - ab31 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab31); - - b0 = Load7(B[2] + o, mask); - ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); - ab12 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab12); - ab22 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab22); - ab32 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab32); - - b0 = Load7(B[3] + o, mask); - ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); - ab13 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab13); - ab23 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab23); - ab33 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab33); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); - __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); - Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - Sse41::DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); - Sse41::DecodeCosineDistances1x4(A[2], B, ab2, distances + 2 * stride); - Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); - } - - template<> void MicroCosineDistancesDirect4x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size32 = AlignLo(size, 32), o = 16; - __m512i a0, a1, a2, a3, b0; - __m512i ab00 = _mm512_setzero_si512(); - __m512i ab01 = _mm512_setzero_si512(); - __m512i ab02 = _mm512_setzero_si512(); - __m512i ab03 = _mm512_setzero_si512(); - __m512i ab10 = _mm512_setzero_si512(); - __m512i ab11 = _mm512_setzero_si512(); - __m512i ab12 = _mm512_setzero_si512(); - __m512i ab13 = _mm512_setzero_si512(); - __m512i ab20 = _mm512_setzero_si512(); - __m512i ab21 = _mm512_setzero_si512(); - __m512i ab22 = _mm512_setzero_si512(); - __m512i ab23 = _mm512_setzero_si512(); - __m512i ab30 = _mm512_setzero_si512(); - __m512i ab31 = _mm512_setzero_si512(); - __m512i ab32 = _mm512_setzero_si512(); - __m512i ab33 = _mm512_setzero_si512(); - for (; i < size32; i += 32, o += 32) - { - a0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(A[0] + o))); - a1 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(A[1] + o))); - a2 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(A[2] + o))); - a3 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(A[3] + o))); - - b0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(B[0] + o))); - ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); - ab10 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab10); - ab20 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab20); - ab30 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab30); - - b0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(B[1] + o))); - ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); - ab11 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab11); - ab21 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab21); - ab31 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab31); - - b0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(B[2] + o))); - ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); - ab12 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab12); - ab22 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab22); - ab32 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab32); - - b0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(B[3] + o))); - ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); - ab13 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab13); - ab23 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab23); - ab33 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab33); - } - if (i < size) - { - __mmask32 mask = TailMask32(size - i); - a0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, A[0] + o)); - a1 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, A[1] + o)); - a2 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, A[2] + o)); - a3 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, A[3] + o)); - - b0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, B[0] + o)); - ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); - ab10 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab10); - ab20 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab20); - ab30 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab30); - - b0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, B[1] + o)); - ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); - ab11 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab11); - ab21 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab21); - ab31 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab31); - - b0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, B[2] + o)); - ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); - ab12 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab12); - ab22 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab22); - ab32 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab32); - - b0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, B[3] + o)); - ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); - ab13 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab13); - ab23 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab23); - ab33 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab33); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); - __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); - __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); - Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - Sse41::DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); - Sse41::DecodeCosineDistances1x4(A[2], B, ab2, distances + 2 * stride); - Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); - } - - template void MicroCosineDistancesDirect1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); - - template<> void MicroCosineDistancesDirect1x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size128 = AlignLo(size, 128), o = 16; - __m512i a00, a01, b00, b01; - __m512i ab00 = _mm512_setzero_si512(); - __m512i ab01 = _mm512_setzero_si512(); - __m512i ab02 = _mm512_setzero_si512(); - __m512i ab03 = _mm512_setzero_si512(); - for (; i < size128; i += 128, o += 64) - { - a01 = _mm512_loadu_si512((__m512i*)(A[0] + o)); - a00 = _mm512_and_si512(a01, K8_0F); - a01 = _mm512_and_si512(_mm512_srli_epi16(a01, 4), K8_0F); - - b01 = _mm512_loadu_si512((__m512i*)(B[0] + o)); - b00 = _mm512_and_si512(b01, K8_0F); - b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); - ab00 = _mm512_add_epi32(ab00, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); - - b01 = _mm512_loadu_si512((__m512i*)(B[1] + o)); - b00 = _mm512_and_si512(b01, K8_0F); - b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); - ab01 = _mm512_add_epi32(ab01, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); - - b01 = _mm512_loadu_si512((__m512i*)(B[2] + o)); - b00 = _mm512_and_si512(b01, K8_0F); - b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); - ab02 = _mm512_add_epi32(ab02, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); - - b01 = _mm512_loadu_si512((__m512i*)(B[3] + o)); - b00 = _mm512_and_si512(b01, K8_0F); - b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); - ab03 = _mm512_add_epi32(ab03, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); - } - if (i < size) - { - __mmask16 mask = TailMask32((size - i) / 8); - a01 = _mm512_maskz_loadu_epi32(mask, A[0] + o); - a00 = _mm512_and_si512(a01, K8_0F); - a01 = _mm512_and_si512(_mm512_srli_epi16(a01, 4), K8_0F); - - b01 = _mm512_maskz_loadu_epi32(mask, B[0] + o); - b00 = _mm512_and_si512(b01, K8_0F); - b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); - ab00 = _mm512_add_epi32(ab00, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); - - b01 = _mm512_maskz_loadu_epi32(mask, B[1] + o); - b00 = _mm512_and_si512(b01, K8_0F); - b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); - ab01 = _mm512_add_epi32(ab01, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); - b01 = _mm512_maskz_loadu_epi32(mask, B[2] + o); - b00 = _mm512_and_si512(b01, K8_0F); - b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); - ab02 = _mm512_add_epi32(ab02, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); - - b01 = _mm512_maskz_loadu_epi32(mask, B[3] + o); - b00 = _mm512_and_si512(b01, K8_0F); - b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); - ab03 = _mm512_add_epi32(ab03, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - } - - template<> void MicroCosineDistancesDirect1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size32 = AlignLo(size, 32), o = 16; - __m512i a0, b0; - __m512i ab00 = _mm512_setzero_si512(); - __m512i ab01 = _mm512_setzero_si512(); - __m512i ab02 = _mm512_setzero_si512(); - __m512i ab03 = _mm512_setzero_si512(); - for (; i < size32; i += 32, o += 20) - { - a0 = Load5(A[0] + o); - - b0 = Load5(B[0] + o); - ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); - - b0 = Load5(B[1] + o); - ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); - - b0 = Load5(B[2] + o); - ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); - - b0 = Load5(B[3] + o); - ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); - } - if (i < size) - { - __mmask32 mask = TailMask32((size - i) / 8 * 5); - a0 = Load5(A[0] + o, mask); - - b0 = Load5(B[0] + o, mask); - ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); - - b0 = Load5(B[1] + o, mask); - ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); - - b0 = Load5(B[2] + o, mask); - ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); - - b0 = Load5(B[3] + o, mask); - ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - } - - template<> void MicroCosineDistancesDirect1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size32 = AlignLo(size, 32), o = 16; - __m512i a0, b0; - __m512i ab00 = _mm512_setzero_si512(); - __m512i ab01 = _mm512_setzero_si512(); - __m512i ab02 = _mm512_setzero_si512(); - __m512i ab03 = _mm512_setzero_si512(); - for (; i < size32; i += 32, o += 24) - { - a0 = Load6(A[0] + o); - - b0 = Load6(B[0] + o); - ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); - - b0 = Load6(B[1] + o); - ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); - - b0 = Load6(B[2] + o); - ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); - - b0 = Load6(B[3] + o); - ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); - } - if (i < size) - { - __mmask32 mask = TailMask32((size - i) / 8 * 6); - a0 = Load6(A[0] + o, mask); - - b0 = Load6(B[0] + o, mask); - ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); - - b0 = Load6(B[1] + o, mask); - ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); - - b0 = Load6(B[2] + o, mask); - ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); - - b0 = Load6(B[3] + o, mask); - ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - } - - template<> void MicroCosineDistancesDirect1x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size32 = AlignLo(size, 32), o = 16; - __m512i a0, b0; - __m512i ab00 = _mm512_setzero_si512(); - __m512i ab01 = _mm512_setzero_si512(); - __m512i ab02 = _mm512_setzero_si512(); - __m512i ab03 = _mm512_setzero_si512(); - for (; i < size32; i += 32, o += 28) - { - a0 = Load7(A[0] + o); - - b0 = Load7(B[0] + o); - ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); - - b0 = Load7(B[1] + o); - ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); - - b0 = Load7(B[2] + o); - ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); - - b0 = Load7(B[3] + o); - ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); - } - if (i < size) - { - __mmask32 mask = TailMask32((size - i) / 8 * 7); - a0 = Load7(A[0] + o, mask); - - b0 = Load7(B[0] + o, mask); - ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); - - b0 = Load7(B[1] + o, mask); - ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); - - b0 = Load7(B[2] + o, mask); - ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); - - b0 = Load7(B[3] + o, mask); - ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - } - - template<> void MicroCosineDistancesDirect1x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t i = 0, size32 = AlignLo(size, 32), o = 16; - __m512i a0, b0; - __m512i ab00 = _mm512_setzero_si512(); - __m512i ab01 = _mm512_setzero_si512(); - __m512i ab02 = _mm512_setzero_si512(); - __m512i ab03 = _mm512_setzero_si512(); - for (; i < size32; i += 32, o += 32) - { - a0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(A[0] + o))); - - b0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(B[0] + o))); - ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); - - b0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(B[1] + o))); - ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); - - b0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(B[2] + o))); - ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); - - b0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(B[3] + o))); - ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); - } - if (i < size) - { - __mmask32 mask = TailMask32(size - i); - a0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, A[0] + o)); - - b0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, B[0] + o)); - ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); - - b0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, B[1] + o)); - ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); - - b0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, B[2] + o)); - ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); - - b0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, B[3] + o)); - ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); - } - __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); - Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); - } - - template void MacroCosineDistancesDirect(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) - { - size_t M4 = AlignLoAny(M, 4); - size_t N4 = AlignLoAny(N, 4); - size_t i = 0; - for (; i < M4; i += 4) - { - size_t j = 0; - for (; j < N4; j += 4) - MicroCosineDistancesDirect4x4(A + i, B + j, size, distances + j, stride); - for (; j < N; j += 1) - { - CosineDistance(A[i + 0], B[j], size, distances + j + 0 * stride); - CosineDistance(A[i + 1], B[j], size, distances + j + 1 * stride); - CosineDistance(A[i + 2], B[j], size, distances + j + 2 * stride); - CosineDistance(A[i + 3], B[j], size, distances + j + 3 * stride); - } - distances += 4 * stride; - } - for (; i < M; i++) - { - size_t j = 0; - for (; j < N4; j += 4) - MicroCosineDistancesDirect1x4(A + i, B + j, size, distances + j, stride); - for (; j < N; j += 1) - CosineDistance(A[i], B[j], size, distances + j); - distances += 1 * stride; + static void UnpackNormB(size_t count, const uint8_t* const* src, float* dst, size_t stride) + { + size_t count16 = AlignLo(count, 16), count8 = AlignLo(count, 8), count4 = AlignLo(count, 4), i = 0; + for (; i < count16; i += 16, src += 16, dst += 16) + { + __m512 s0 = Load((float*)src[0], (float*)src[4], (float*)src[8], (float*)src[12]); + __m512 s1 = Load((float*)src[1], (float*)src[5], (float*)src[9], (float*)src[13]); + __m512 s2 = Load((float*)src[2], (float*)src[6], (float*)src[10], (float*)src[14]); + __m512 s3 = Load((float*)src[3], (float*)src[7], (float*)src[11], (float*)src[15]); + __m512 s00 = _mm512_unpacklo_ps(s0, s2); + __m512 s01 = _mm512_unpacklo_ps(s1, s3); + __m512 s10 = _mm512_unpackhi_ps(s0, s2); + __m512 s11 = _mm512_unpackhi_ps(s1, s3); + _mm512_storeu_ps(dst + 0 * stride, _mm512_unpacklo_ps(s00, s01)); + _mm512_storeu_ps(dst + 1 * stride, _mm512_unpackhi_ps(s00, s01)); + _mm512_storeu_ps(dst + 2 * stride, _mm512_unpacklo_ps(s10, s11)); + _mm512_storeu_ps(dst + 3 * stride, _mm512_unpackhi_ps(s10, s11)); + } + for (; i < count8; i += 8, src += 8, dst += 8) + { + __m256 s0 = Avx::Load((float*)src[0], (float*)src[4]); + __m256 s1 = Avx::Load((float*)src[1], (float*)src[5]); + __m256 s2 = Avx::Load((float*)src[2], (float*)src[6]); + __m256 s3 = Avx::Load((float*)src[3], (float*)src[7]); + __m256 s00 = _mm256_unpacklo_ps(s0, s2); + __m256 s01 = _mm256_unpacklo_ps(s1, s3); + __m256 s10 = _mm256_unpackhi_ps(s0, s2); + __m256 s11 = _mm256_unpackhi_ps(s1, s3); + _mm256_storeu_ps(dst + 0 * stride, _mm256_unpacklo_ps(s00, s01)); + _mm256_storeu_ps(dst + 1 * stride, _mm256_unpackhi_ps(s00, s01)); + _mm256_storeu_ps(dst + 2 * stride, _mm256_unpacklo_ps(s10, s11)); + _mm256_storeu_ps(dst + 3 * stride, _mm256_unpackhi_ps(s10, s11)); + } + for (; i < count4; i += 4, src += 4, dst += 4) + { + __m128 s0 = _mm_loadu_ps((float*)src[0]); + __m128 s1 = _mm_loadu_ps((float*)src[1]); + __m128 s2 = _mm_loadu_ps((float*)src[2]); + __m128 s3 = _mm_loadu_ps((float*)src[3]); + __m128 s00 = _mm_unpacklo_ps(s0, s2); + __m128 s01 = _mm_unpacklo_ps(s1, s3); + __m128 s10 = _mm_unpackhi_ps(s0, s2); + __m128 s11 = _mm_unpackhi_ps(s1, s3); + _mm_storeu_ps(dst + 0 * stride, _mm_unpacklo_ps(s00, s01)); + _mm_storeu_ps(dst + 1 * stride, _mm_unpackhi_ps(s00, s01)); + _mm_storeu_ps(dst + 2 * stride, _mm_unpacklo_ps(s10, s11)); + _mm_storeu_ps(dst + 3 * stride, _mm_unpackhi_ps(s10, s11)); + } + for (; i < count; i++, src++, dst++) + { + dst[0 * stride] = ((float*)src)[0]; + dst[1 * stride] = ((float*)src)[1]; + dst[2 * stride] = ((float*)src)[2]; + dst[3 * stride] = ((float*)src)[3]; } } @@ -1623,61 +160,26 @@ namespace Simd { _minMax32f = MinMax32f; _minMax16f = MinMax16f; - switch (depth) - { - case 4: - { - _encode32f = Encode32f4; - _encode16f = Encode16f4; - _decode32f = Decode32f4; - _decode16f = Decode16f4; - _cosineDistance = Avx512bw::CosineDistance<4>; - _macroCosineDistancesDirect = Avx512bw::MacroCosineDistancesDirect<4>; - break; - } - case 5: - { - _encode32f = Encode32f5; - _encode16f = Encode16f5; - _decode32f = Decode32f5; - _decode16f = Decode16f5; - _cosineDistance = Avx512bw::CosineDistance<5>; - _macroCosineDistancesDirect = Avx512bw::MacroCosineDistancesDirect<5>; - break; - } - case 6: - { - _encode32f = Encode32f6; - _encode16f = Encode16f6; - _decode32f = Decode32f6; - _decode16f = Decode16f6; - _cosineDistance = Avx512bw::CosineDistance<6>; - _macroCosineDistancesDirect = Avx512bw::MacroCosineDistancesDirect<6>; - break; - } - case 7: - { - _encode32f = Encode32f7; - _encode16f = Encode16f7; - _decode32f = Decode32f7; - _decode16f = Decode16f7; - _cosineDistance = Avx512bw::CosineDistance<7>; - _macroCosineDistancesDirect = Avx512bw::MacroCosineDistancesDirect<7>; - break; - } - case 8: - { - _encode32f = Encode32f8; - _encode16f = Encode16f8; - _decode32f = Decode32f8; - _decode16f = Decode16f8; - _cosineDistance = Avx512bw::CosineDistance<8>; - _macroCosineDistancesDirect = Avx512bw::MacroCosineDistancesDirect<8>; - _microMd = 4; - break; - } - default: - assert(0); + _encode32f = GetEncode32f(_depth); + _encode16f = GetEncode16f(_depth); + + _decode32f = GetDecode32f(_depth); + _decode16f = GetDecode16f(_depth); + + _cosineDistance = GetCosineDistance(_depth); + _macroCosineDistancesDirect = GetMacroCosineDistancesDirect(_depth); + _microMd = 4; + _microNd = 4; + + _unpackNormA = UnpackNormA; + _unpackNormB = UnpackNormB; + if (_depth != 8) + { + _unpackDataA = GetUnpackData(_depth, false); + _unpackDataB = GetUnpackData(_depth, true); + _macroCosineDistancesUnpack = GetMacroCosineDistancesUnpack(_depth); + _microMu = 12; + _microNu = 32; } } diff --git a/src/Simd/SimdAvx512bwDescrIntCdd.cpp b/src/Simd/SimdAvx512bwDescrIntCdd.cpp new file mode 100644 index 0000000000..f307c80a3f --- /dev/null +++ b/src/Simd/SimdAvx512bwDescrIntCdd.cpp @@ -0,0 +1,976 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2023 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#include "Simd/SimdMemory.h" +#include "Simd/SimdStore.h" +#include "Simd/SimdExtract.h" +#include "Simd/SimdArray.h" +#include "Simd/SimdUnpack.h" +#include "Simd/SimdDescrInt.h" +#include "Simd/SimdDescrIntCommon.h" +#include "Simd/SimdCpu.h" + +namespace Simd +{ +#ifdef SIMD_AVX512BW_ENABLE + namespace Avx512bw + { + template int32_t Correlation(const uint8_t* a, const uint8_t* b, size_t size); + + template<> int32_t Correlation<4>(const uint8_t* a, const uint8_t* b, size_t size) + { + assert(size % 8 == 0); + __m512i ab32 = _mm512_setzero_si512(); + size_t i = 0, size128 = AlignLo(size, 128); + for (; i < size128; i += 128, a += 64, b += 64) + { + __m512i _a = _mm512_loadu_si512((__m512i*)a); + __m512i _b = _mm512_loadu_si512((__m512i*)b); + __m512i ab16 = _mm512_maddubs_epi16(_mm512_and_si512(_a, K8_0F), _mm512_and_si512(_b, K8_0F)); + ab16 = _mm512_add_epi16(ab16, _mm512_maddubs_epi16(_mm512_and_si512(_mm512_srli_epi16(_a, 4), K8_0F), _mm512_and_si512(_mm512_srli_epi16(_b, 4), K8_0F))); + ab32 = _mm512_add_epi32(ab32, _mm512_madd_epi16(ab16, K16_0001)); + } + if(i < size) + { + __mmask16 mask = TailMask16((size - i) / 8); + __m512i _a = _mm512_maskz_loadu_epi32(mask, a); + __m512i _b = _mm512_maskz_loadu_epi32(mask, b); + __m512i ab16 = _mm512_maddubs_epi16(_mm512_and_si512(_a, K8_0F), _mm512_and_si512(_b, K8_0F)); + ab16 = _mm512_add_epi16(ab16, _mm512_maddubs_epi16(_mm512_and_si512(_mm512_srli_epi16(_a, 4), K8_0F), _mm512_and_si512(_mm512_srli_epi16(_b, 4), K8_0F))); + ab32 = _mm512_add_epi32(ab32, _mm512_madd_epi16(ab16, K16_0001)); + } + return ExtractSum(ab32); + } + + SIMD_INLINE __m512i Load5(const uint8_t* ptr, __mmask32 mask = 0x000FFFFF) + { + return _mm512_srli_epi16(_mm512_mullo_epi16(_mm512_shuffle_epi8(_mm512_permutexvar_epi32(C5_PERM, _mm512_castsi256_si512(_mm256_maskz_loadu_epi8(mask, ptr))), C5_SHFL), C5_MULLO), 11); + } + + template<> int32_t Correlation<5>(const uint8_t* a, const uint8_t* b, size_t size) + { + assert(size % 8 == 0); + __m512i _ab = _mm512_setzero_si512(); + size_t i = 0, size32 = AlignLo(size, 32); + for (; i < size32; i += 32, a += 20, b += 20) + { + __m512i _a = Load5(a); + __m512i _b = Load5(b); + _ab = _mm512_add_epi32(_mm512_madd_epi16(_a, _b), _ab); + } + if (i < size) + { + __mmask32 mask = TailMask32((size - i) / 8 * 5); + __m512i _a = Load5(a, mask); + __m512i _b = Load5(b, mask); + _ab = _mm512_add_epi32(_mm512_madd_epi16(_a, _b), _ab); + } + return ExtractSum(_ab); + } + + SIMD_INLINE __m512i Load6(const uint8_t* ptr, __mmask32 mask = 0x00FFFFFF) + { + return _mm512_srli_epi16(_mm512_mullo_epi16(_mm512_shuffle_epi8(_mm512_permutexvar_epi32(C6_PERM, _mm512_castsi256_si512(_mm256_maskz_loadu_epi8(mask, ptr))), C6_SHFL), C6_MULLO), 10); + } + + template<> int32_t Correlation<6>(const uint8_t* a, const uint8_t* b, size_t size) + { + assert(size % 8 == 0); + __m512i _ab = _mm512_setzero_si512(); + size_t i = 0, size32 = AlignLo(size, 32); + for (; i < size32; i += 32, a += 24, b += 24) + { + __m512i _a = Load6(a); + __m512i _b = Load6(b); + _ab = _mm512_add_epi32(_mm512_madd_epi16(_a, _b), _ab); + } + if (i < size) + { + __mmask32 mask = TailMask32((size - i) / 8 * 6); + __m512i _a = Load6(a, mask); + __m512i _b = Load6(b, mask); + _ab = _mm512_add_epi32(_mm512_madd_epi16(_a, _b), _ab); + } + return ExtractSum(_ab); + } + + SIMD_INLINE __m512i Load7(const uint8_t* ptr, __mmask32 mask = 0x0FFFFFFF) + { + return _mm512_srli_epi16(_mm512_mullo_epi16(_mm512_shuffle_epi8(_mm512_permutexvar_epi32(C7_PERM, _mm512_castsi256_si512(_mm256_maskz_loadu_epi8(mask, ptr))), C7_SHFL), C7_MULLO), 9); + } + + template<> int32_t Correlation<7>(const uint8_t* a, const uint8_t* b, size_t size) + { + assert(size % 8 == 0); + __m512i _ab = _mm512_setzero_si512(); + size_t i = 0, size32 = AlignLo(size, 32); + for (; i < size32; i += 32, a += 28, b += 28) + { + __m512i _a = Load7(a); + __m512i _b = Load7(b); + _ab = _mm512_add_epi32(_mm512_madd_epi16(_a, _b), _ab); + } + if (i < size) + { + __mmask32 mask = TailMask32((size - i) / 8 * 7); + __m512i _a = Load7(a, mask); + __m512i _b = Load7(b, mask); + _ab = _mm512_add_epi32(_mm512_madd_epi16(_a, _b), _ab); + } + return ExtractSum(_ab); + } + + template<> int32_t Correlation<8>(const uint8_t* a, const uint8_t* b, size_t size) + { + assert(size % 8 == 0); + size_t i = 0, size32 = AlignLo(size, 32); + __m512i _ab = _mm512_setzero_si512(); + for (; i < size32; i += 32) + { + __m512i _a = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(a + i))); + __m512i _b = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(b + i))); + _ab = _mm512_add_epi32(_mm512_madd_epi16(_a, _b), _ab); + } + if ( i < size) + { + __mmask32 mask = TailMask32(size - i); + __m512i _a = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, a + i)); + __m512i _b = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, b + i)); + _ab = _mm512_add_epi32(_mm512_madd_epi16(_a, _b), _ab); + } + return ExtractSum(_ab); + } + + template void CosineDistance(const uint8_t* a, const uint8_t* b, size_t size, float* distance) + { + float abSum = (float)Correlation(a + 16, b + 16, size); + Base::DecodeCosineDistance(a, b, abSum, distance); + } + + //------------------------------------------------------------------------------------------------- + + template void MicroCosineDistancesDirect4x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + + template<> void MicroCosineDistancesDirect4x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size128 = AlignLo(size, 128), o = 16; + __m512i a00, a10, a20, a30, a01, a11, a21, a31, b00, b01; + __m512i ab00 = _mm512_setzero_si512(); + __m512i ab01 = _mm512_setzero_si512(); + __m512i ab02 = _mm512_setzero_si512(); + __m512i ab03 = _mm512_setzero_si512(); + __m512i ab10 = _mm512_setzero_si512(); + __m512i ab11 = _mm512_setzero_si512(); + __m512i ab12 = _mm512_setzero_si512(); + __m512i ab13 = _mm512_setzero_si512(); + __m512i ab20 = _mm512_setzero_si512(); + __m512i ab21 = _mm512_setzero_si512(); + __m512i ab22 = _mm512_setzero_si512(); + __m512i ab23 = _mm512_setzero_si512(); + __m512i ab30 = _mm512_setzero_si512(); + __m512i ab31 = _mm512_setzero_si512(); + __m512i ab32 = _mm512_setzero_si512(); + __m512i ab33 = _mm512_setzero_si512(); + for (; i < size128; i += 128, o += 64) + { + a01 = _mm512_loadu_si512((__m512i*)(A[0] + o)); + a00 = _mm512_and_si512(a01, K8_0F); + a01 = _mm512_and_si512(_mm512_srli_epi16(a01, 4), K8_0F); + a11 = _mm512_loadu_si512((__m512i*)(A[1] + o)); + a10 = _mm512_and_si512(a11, K8_0F); + a11 = _mm512_and_si512(_mm512_srli_epi16(a11, 4), K8_0F); + a21 = _mm512_loadu_si512((__m512i*)(A[2] + o)); + a20 = _mm512_and_si512(a21, K8_0F); + a21 = _mm512_and_si512(_mm512_srli_epi16(a21, 4), K8_0F); + a31 = _mm512_loadu_si512((__m512i*)(A[3] + o)); + a30 = _mm512_and_si512(a31, K8_0F); + a31 = _mm512_and_si512(_mm512_srli_epi16(a31, 4), K8_0F); + + b01 = _mm512_loadu_si512((__m512i*)(B[0] + o)); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab00 = _mm512_add_epi32(ab00, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + ab10 = _mm512_add_epi32(ab10, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); + ab20 = _mm512_add_epi32(ab20, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); + ab30 = _mm512_add_epi32(ab30, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); + + b01 = _mm512_loadu_si512((__m512i*)(B[1] + o)); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab01 = _mm512_add_epi32(ab01, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + ab11 = _mm512_add_epi32(ab11, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); + ab21 = _mm512_add_epi32(ab21, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); + ab31 = _mm512_add_epi32(ab31, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); + + b01 = _mm512_loadu_si512((__m512i*)(B[2] + o)); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab02 = _mm512_add_epi32(ab02, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + ab12 = _mm512_add_epi32(ab12, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); + ab22 = _mm512_add_epi32(ab22, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); + ab32 = _mm512_add_epi32(ab32, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); + + b01 = _mm512_loadu_si512((__m512i*)(B[3] + o)); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab03 = _mm512_add_epi32(ab03, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + ab13 = _mm512_add_epi32(ab13, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); + ab23 = _mm512_add_epi32(ab23, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); + ab33 = _mm512_add_epi32(ab33, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); + } + if (i < size) + { + __mmask16 mask = TailMask32((size - i) / 8); + a01 = _mm512_maskz_loadu_epi32(mask, A[0] + o); + a00 = _mm512_and_si512(a01, K8_0F); + a01 = _mm512_and_si512(_mm512_srli_epi16(a01, 4), K8_0F); + a11 = _mm512_maskz_loadu_epi32(mask, A[1] + o); + a10 = _mm512_and_si512(a11, K8_0F); + a11 = _mm512_and_si512(_mm512_srli_epi16(a11, 4), K8_0F); + a21 = _mm512_maskz_loadu_epi32(mask, A[2] + o); + a20 = _mm512_and_si512(a21, K8_0F); + a21 = _mm512_and_si512(_mm512_srli_epi16(a21, 4), K8_0F); + a31 = _mm512_maskz_loadu_epi32(mask, A[3] + o); + a30 = _mm512_and_si512(a31, K8_0F); + a31 = _mm512_and_si512(_mm512_srli_epi16(a31, 4), K8_0F); + + b01 = _mm512_maskz_loadu_epi32(mask, B[0] + o); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab00 = _mm512_add_epi32(ab00, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + ab10 = _mm512_add_epi32(ab10, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); + ab20 = _mm512_add_epi32(ab20, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); + ab30 = _mm512_add_epi32(ab30, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); + + b01 = _mm512_maskz_loadu_epi32(mask, B[1] + o); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab01 = _mm512_add_epi32(ab01, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + ab11 = _mm512_add_epi32(ab11, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); + ab21 = _mm512_add_epi32(ab21, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); + ab31 = _mm512_add_epi32(ab31, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); + + b01 = _mm512_maskz_loadu_epi32(mask, B[2] + o); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab02 = _mm512_add_epi32(ab02, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + ab12 = _mm512_add_epi32(ab12, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); + ab22 = _mm512_add_epi32(ab22, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); + ab32 = _mm512_add_epi32(ab32, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); + + b01 = _mm512_maskz_loadu_epi32(mask, B[3] + o); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab03 = _mm512_add_epi32(ab03, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + ab13 = _mm512_add_epi32(ab13, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a10, b00), _mm512_maddubs_epi16(a11, b01)), K16_0001)); + ab23 = _mm512_add_epi32(ab23, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a20, b00), _mm512_maddubs_epi16(a21, b01)), K16_0001)); + ab33 = _mm512_add_epi32(ab33, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a30, b00), _mm512_maddubs_epi16(a31, b01)), K16_0001)); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); + __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); + __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); + Sse41::DecodeCosineDistances1x4(A[2], B, ab2, distances + 2 * stride); + Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); + } + + template<> void MicroCosineDistancesDirect4x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size32 = AlignLo(size, 32), o = 16; + __m512i a0, a1, a2, a3, b0; + __m512i ab00 = _mm512_setzero_si512(); + __m512i ab01 = _mm512_setzero_si512(); + __m512i ab02 = _mm512_setzero_si512(); + __m512i ab03 = _mm512_setzero_si512(); + __m512i ab10 = _mm512_setzero_si512(); + __m512i ab11 = _mm512_setzero_si512(); + __m512i ab12 = _mm512_setzero_si512(); + __m512i ab13 = _mm512_setzero_si512(); + __m512i ab20 = _mm512_setzero_si512(); + __m512i ab21 = _mm512_setzero_si512(); + __m512i ab22 = _mm512_setzero_si512(); + __m512i ab23 = _mm512_setzero_si512(); + __m512i ab30 = _mm512_setzero_si512(); + __m512i ab31 = _mm512_setzero_si512(); + __m512i ab32 = _mm512_setzero_si512(); + __m512i ab33 = _mm512_setzero_si512(); + for (; i < size32; i += 32, o += 20) + { + a0 = Load5(A[0] + o); + a1 = Load5(A[1] + o); + a2 = Load5(A[2] + o); + a3 = Load5(A[3] + o); + + b0 = Load5(B[0] + o); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + ab10 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab10); + ab20 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab20); + ab30 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab30); + + b0 = Load5(B[1] + o); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + ab11 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab11); + ab21 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab21); + ab31 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab31); + + b0 = Load5(B[2] + o); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + ab12 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab12); + ab22 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab22); + ab32 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab32); + + b0 = Load5(B[3] + o); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + ab13 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab13); + ab23 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab23); + ab33 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab33); + } + if (i < size) + { + __mmask32 mask = TailMask32((size - i) / 8 * 5); + a0 = Load5(A[0] + o, mask); + a1 = Load5(A[1] + o, mask); + a2 = Load5(A[2] + o, mask); + a3 = Load5(A[3] + o, mask); + + b0 = Load5(B[0] + o, mask); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + ab10 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab10); + ab20 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab20); + ab30 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab30); + + b0 = Load5(B[1] + o, mask); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + ab11 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab11); + ab21 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab21); + ab31 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab31); + + b0 = Load5(B[2] + o, mask); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + ab12 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab12); + ab22 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab22); + ab32 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab32); + + b0 = Load5(B[3] + o, mask); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + ab13 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab13); + ab23 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab23); + ab33 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab33); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); + __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); + __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); + Sse41::DecodeCosineDistances1x4(A[2], B, ab2, distances + 2 * stride); + Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); + } + + template<> void MicroCosineDistancesDirect4x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size32 = AlignLo(size, 32), o = 16; + __m512i a0, a1, a2, a3, b0; + __m512i ab00 = _mm512_setzero_si512(); + __m512i ab01 = _mm512_setzero_si512(); + __m512i ab02 = _mm512_setzero_si512(); + __m512i ab03 = _mm512_setzero_si512(); + __m512i ab10 = _mm512_setzero_si512(); + __m512i ab11 = _mm512_setzero_si512(); + __m512i ab12 = _mm512_setzero_si512(); + __m512i ab13 = _mm512_setzero_si512(); + __m512i ab20 = _mm512_setzero_si512(); + __m512i ab21 = _mm512_setzero_si512(); + __m512i ab22 = _mm512_setzero_si512(); + __m512i ab23 = _mm512_setzero_si512(); + __m512i ab30 = _mm512_setzero_si512(); + __m512i ab31 = _mm512_setzero_si512(); + __m512i ab32 = _mm512_setzero_si512(); + __m512i ab33 = _mm512_setzero_si512(); + for (; i < size32; i += 32, o += 24) + { + a0 = Load6(A[0] + o); + a1 = Load6(A[1] + o); + a2 = Load6(A[2] + o); + a3 = Load6(A[3] + o); + + b0 = Load6(B[0] + o); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + ab10 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab10); + ab20 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab20); + ab30 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab30); + + b0 = Load6(B[1] + o); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + ab11 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab11); + ab21 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab21); + ab31 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab31); + + b0 = Load6(B[2] + o); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + ab12 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab12); + ab22 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab22); + ab32 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab32); + + b0 = Load6(B[3] + o); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + ab13 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab13); + ab23 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab23); + ab33 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab33); + } + if (i < size) + { + __mmask32 mask = TailMask32((size - i) / 8 * 6); + a0 = Load6(A[0] + o, mask); + a1 = Load6(A[1] + o, mask); + a2 = Load6(A[2] + o, mask); + a3 = Load6(A[3] + o, mask); + + b0 = Load6(B[0] + o, mask); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + ab10 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab10); + ab20 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab20); + ab30 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab30); + + b0 = Load6(B[1] + o, mask); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + ab11 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab11); + ab21 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab21); + ab31 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab31); + + b0 = Load6(B[2] + o, mask); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + ab12 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab12); + ab22 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab22); + ab32 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab32); + + b0 = Load6(B[3] + o, mask); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + ab13 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab13); + ab23 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab23); + ab33 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab33); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); + __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); + __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); + Sse41::DecodeCosineDistances1x4(A[2], B, ab2, distances + 2 * stride); + Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); + } + + template<> void MicroCosineDistancesDirect4x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size32 = AlignLo(size, 32), o = 16; + __m512i a0, a1, a2, a3, b0; + __m512i ab00 = _mm512_setzero_si512(); + __m512i ab01 = _mm512_setzero_si512(); + __m512i ab02 = _mm512_setzero_si512(); + __m512i ab03 = _mm512_setzero_si512(); + __m512i ab10 = _mm512_setzero_si512(); + __m512i ab11 = _mm512_setzero_si512(); + __m512i ab12 = _mm512_setzero_si512(); + __m512i ab13 = _mm512_setzero_si512(); + __m512i ab20 = _mm512_setzero_si512(); + __m512i ab21 = _mm512_setzero_si512(); + __m512i ab22 = _mm512_setzero_si512(); + __m512i ab23 = _mm512_setzero_si512(); + __m512i ab30 = _mm512_setzero_si512(); + __m512i ab31 = _mm512_setzero_si512(); + __m512i ab32 = _mm512_setzero_si512(); + __m512i ab33 = _mm512_setzero_si512(); + for (; i < size32; i += 32, o += 28) + { + a0 = Load7(A[0] + o); + a1 = Load7(A[1] + o); + a2 = Load7(A[2] + o); + a3 = Load7(A[3] + o); + + b0 = Load7(B[0] + o); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + ab10 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab10); + ab20 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab20); + ab30 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab30); + + b0 = Load7(B[1] + o); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + ab11 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab11); + ab21 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab21); + ab31 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab31); + + b0 = Load7(B[2] + o); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + ab12 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab12); + ab22 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab22); + ab32 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab32); + + b0 = Load7(B[3] + o); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + ab13 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab13); + ab23 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab23); + ab33 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab33); + } + if (i < size) + { + __mmask32 mask = TailMask32((size - i) / 8 * 7); + a0 = Load7(A[0] + o, mask); + a1 = Load7(A[1] + o, mask); + a2 = Load7(A[2] + o, mask); + a3 = Load7(A[3] + o, mask); + + b0 = Load7(B[0] + o, mask); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + ab10 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab10); + ab20 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab20); + ab30 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab30); + + b0 = Load7(B[1] + o, mask); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + ab11 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab11); + ab21 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab21); + ab31 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab31); + + b0 = Load7(B[2] + o, mask); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + ab12 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab12); + ab22 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab22); + ab32 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab32); + + b0 = Load7(B[3] + o, mask); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + ab13 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab13); + ab23 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab23); + ab33 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab33); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); + __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); + __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); + Sse41::DecodeCosineDistances1x4(A[2], B, ab2, distances + 2 * stride); + Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); + } + + template<> void MicroCosineDistancesDirect4x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size32 = AlignLo(size, 32), o = 16; + __m512i a0, a1, a2, a3, b0; + __m512i ab00 = _mm512_setzero_si512(); + __m512i ab01 = _mm512_setzero_si512(); + __m512i ab02 = _mm512_setzero_si512(); + __m512i ab03 = _mm512_setzero_si512(); + __m512i ab10 = _mm512_setzero_si512(); + __m512i ab11 = _mm512_setzero_si512(); + __m512i ab12 = _mm512_setzero_si512(); + __m512i ab13 = _mm512_setzero_si512(); + __m512i ab20 = _mm512_setzero_si512(); + __m512i ab21 = _mm512_setzero_si512(); + __m512i ab22 = _mm512_setzero_si512(); + __m512i ab23 = _mm512_setzero_si512(); + __m512i ab30 = _mm512_setzero_si512(); + __m512i ab31 = _mm512_setzero_si512(); + __m512i ab32 = _mm512_setzero_si512(); + __m512i ab33 = _mm512_setzero_si512(); + for (; i < size32; i += 32, o += 32) + { + a0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(A[0] + o))); + a1 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(A[1] + o))); + a2 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(A[2] + o))); + a3 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(A[3] + o))); + + b0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(B[0] + o))); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + ab10 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab10); + ab20 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab20); + ab30 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab30); + + b0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(B[1] + o))); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + ab11 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab11); + ab21 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab21); + ab31 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab31); + + b0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(B[2] + o))); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + ab12 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab12); + ab22 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab22); + ab32 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab32); + + b0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(B[3] + o))); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + ab13 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab13); + ab23 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab23); + ab33 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab33); + } + if (i < size) + { + __mmask32 mask = TailMask32(size - i); + a0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, A[0] + o)); + a1 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, A[1] + o)); + a2 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, A[2] + o)); + a3 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, A[3] + o)); + + b0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, B[0] + o)); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + ab10 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab10); + ab20 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab20); + ab30 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab30); + + b0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, B[1] + o)); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + ab11 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab11); + ab21 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab21); + ab31 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab31); + + b0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, B[2] + o)); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + ab12 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab12); + ab22 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab22); + ab32 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab32); + + b0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, B[3] + o)); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + ab13 = _mm512_add_epi32(_mm512_madd_epi16(a1, b0), ab13); + ab23 = _mm512_add_epi32(_mm512_madd_epi16(a2, b0), ab23); + ab33 = _mm512_add_epi32(_mm512_madd_epi16(a3, b0), ab33); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + __m128 ab1 = _mm_cvtepi32_ps(Extract4Sums(ab10, ab11, ab12, ab13)); + __m128 ab2 = _mm_cvtepi32_ps(Extract4Sums(ab20, ab21, ab22, ab23)); + __m128 ab3 = _mm_cvtepi32_ps(Extract4Sums(ab30, ab31, ab32, ab33)); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + Sse41::DecodeCosineDistances1x4(A[1], B, ab1, distances + 1 * stride); + Sse41::DecodeCosineDistances1x4(A[2], B, ab2, distances + 2 * stride); + Sse41::DecodeCosineDistances1x4(A[3], B, ab3, distances + 3 * stride); + } + + template void MicroCosineDistancesDirect1x4(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride); + + template<> void MicroCosineDistancesDirect1x4<4>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size128 = AlignLo(size, 128), o = 16; + __m512i a00, a01, b00, b01; + __m512i ab00 = _mm512_setzero_si512(); + __m512i ab01 = _mm512_setzero_si512(); + __m512i ab02 = _mm512_setzero_si512(); + __m512i ab03 = _mm512_setzero_si512(); + for (; i < size128; i += 128, o += 64) + { + a01 = _mm512_loadu_si512((__m512i*)(A[0] + o)); + a00 = _mm512_and_si512(a01, K8_0F); + a01 = _mm512_and_si512(_mm512_srli_epi16(a01, 4), K8_0F); + + b01 = _mm512_loadu_si512((__m512i*)(B[0] + o)); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab00 = _mm512_add_epi32(ab00, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + + b01 = _mm512_loadu_si512((__m512i*)(B[1] + o)); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab01 = _mm512_add_epi32(ab01, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + + b01 = _mm512_loadu_si512((__m512i*)(B[2] + o)); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab02 = _mm512_add_epi32(ab02, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + + b01 = _mm512_loadu_si512((__m512i*)(B[3] + o)); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab03 = _mm512_add_epi32(ab03, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + } + if (i < size) + { + __mmask16 mask = TailMask32((size - i) / 8); + a01 = _mm512_maskz_loadu_epi32(mask, A[0] + o); + a00 = _mm512_and_si512(a01, K8_0F); + a01 = _mm512_and_si512(_mm512_srli_epi16(a01, 4), K8_0F); + + b01 = _mm512_maskz_loadu_epi32(mask, B[0] + o); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab00 = _mm512_add_epi32(ab00, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + + b01 = _mm512_maskz_loadu_epi32(mask, B[1] + o); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab01 = _mm512_add_epi32(ab01, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + b01 = _mm512_maskz_loadu_epi32(mask, B[2] + o); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab02 = _mm512_add_epi32(ab02, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + + b01 = _mm512_maskz_loadu_epi32(mask, B[3] + o); + b00 = _mm512_and_si512(b01, K8_0F); + b01 = _mm512_and_si512(_mm512_srli_epi16(b01, 4), K8_0F); + ab03 = _mm512_add_epi32(ab03, _mm512_madd_epi16(_mm512_add_epi16(_mm512_maddubs_epi16(a00, b00), _mm512_maddubs_epi16(a01, b01)), K16_0001)); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + } + + template<> void MicroCosineDistancesDirect1x4<5>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size32 = AlignLo(size, 32), o = 16; + __m512i a0, b0; + __m512i ab00 = _mm512_setzero_si512(); + __m512i ab01 = _mm512_setzero_si512(); + __m512i ab02 = _mm512_setzero_si512(); + __m512i ab03 = _mm512_setzero_si512(); + for (; i < size32; i += 32, o += 20) + { + a0 = Load5(A[0] + o); + + b0 = Load5(B[0] + o); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + + b0 = Load5(B[1] + o); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + + b0 = Load5(B[2] + o); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + + b0 = Load5(B[3] + o); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + } + if (i < size) + { + __mmask32 mask = TailMask32((size - i) / 8 * 5); + a0 = Load5(A[0] + o, mask); + + b0 = Load5(B[0] + o, mask); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + + b0 = Load5(B[1] + o, mask); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + + b0 = Load5(B[2] + o, mask); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + + b0 = Load5(B[3] + o, mask); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + } + + template<> void MicroCosineDistancesDirect1x4<6>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size32 = AlignLo(size, 32), o = 16; + __m512i a0, b0; + __m512i ab00 = _mm512_setzero_si512(); + __m512i ab01 = _mm512_setzero_si512(); + __m512i ab02 = _mm512_setzero_si512(); + __m512i ab03 = _mm512_setzero_si512(); + for (; i < size32; i += 32, o += 24) + { + a0 = Load6(A[0] + o); + + b0 = Load6(B[0] + o); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + + b0 = Load6(B[1] + o); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + + b0 = Load6(B[2] + o); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + + b0 = Load6(B[3] + o); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + } + if (i < size) + { + __mmask32 mask = TailMask32((size - i) / 8 * 6); + a0 = Load6(A[0] + o, mask); + + b0 = Load6(B[0] + o, mask); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + + b0 = Load6(B[1] + o, mask); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + + b0 = Load6(B[2] + o, mask); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + + b0 = Load6(B[3] + o, mask); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + } + + template<> void MicroCosineDistancesDirect1x4<7>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size32 = AlignLo(size, 32), o = 16; + __m512i a0, b0; + __m512i ab00 = _mm512_setzero_si512(); + __m512i ab01 = _mm512_setzero_si512(); + __m512i ab02 = _mm512_setzero_si512(); + __m512i ab03 = _mm512_setzero_si512(); + for (; i < size32; i += 32, o += 28) + { + a0 = Load7(A[0] + o); + + b0 = Load7(B[0] + o); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + + b0 = Load7(B[1] + o); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + + b0 = Load7(B[2] + o); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + + b0 = Load7(B[3] + o); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + } + if (i < size) + { + __mmask32 mask = TailMask32((size - i) / 8 * 7); + a0 = Load7(A[0] + o, mask); + + b0 = Load7(B[0] + o, mask); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + + b0 = Load7(B[1] + o, mask); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + + b0 = Load7(B[2] + o, mask); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + + b0 = Load7(B[3] + o, mask); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + } + + template<> void MicroCosineDistancesDirect1x4<8>(const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t i = 0, size32 = AlignLo(size, 32), o = 16; + __m512i a0, b0; + __m512i ab00 = _mm512_setzero_si512(); + __m512i ab01 = _mm512_setzero_si512(); + __m512i ab02 = _mm512_setzero_si512(); + __m512i ab03 = _mm512_setzero_si512(); + for (; i < size32; i += 32, o += 32) + { + a0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(A[0] + o))); + + b0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(B[0] + o))); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + + b0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(B[1] + o))); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + + b0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(B[2] + o))); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + + b0 = _mm512_cvtepu8_epi16(_mm256_loadu_si256((__m256i*)(B[3] + o))); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + } + if (i < size) + { + __mmask32 mask = TailMask32(size - i); + a0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, A[0] + o)); + + b0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, B[0] + o)); + ab00 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab00); + + b0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, B[1] + o)); + ab01 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab01); + + b0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, B[2] + o)); + ab02 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab02); + + b0 = _mm512_cvtepu8_epi16(_mm256_maskz_loadu_epi8(mask, B[3] + o)); + ab03 = _mm512_add_epi32(_mm512_madd_epi16(a0, b0), ab03); + } + __m128 ab0 = _mm_cvtepi32_ps(Extract4Sums(ab00, ab01, ab02, ab03)); + Sse41::DecodeCosineDistances1x4(A[0], B, ab0, distances + 0 * stride); + } + + template void MacroCosineDistancesDirect(size_t M, size_t N, const uint8_t* const* A, const uint8_t* const* B, size_t size, float* distances, size_t stride) + { + size_t M4 = AlignLoAny(M, 4); + size_t N4 = AlignLoAny(N, 4); + size_t i = 0; + for (; i < M4; i += 4) + { + size_t j = 0; + for (; j < N4; j += 4) + MicroCosineDistancesDirect4x4(A + i, B + j, size, distances + j, stride); + for (; j < N; j += 1) + { + CosineDistance(A[i + 0], B[j], size, distances + j + 0 * stride); + CosineDistance(A[i + 1], B[j], size, distances + j + 1 * stride); + CosineDistance(A[i + 2], B[j], size, distances + j + 2 * stride); + CosineDistance(A[i + 3], B[j], size, distances + j + 3 * stride); + } + distances += 4 * stride; + } + for (; i < M; i++) + { + size_t j = 0; + for (; j < N4; j += 4) + MicroCosineDistancesDirect1x4(A + i, B + j, size, distances + j, stride); + for (; j < N; j += 1) + CosineDistance(A[i], B[j], size, distances + j); + distances += 1 * stride; + } + } + + //------------------------------------------------------------------------------------------------- + + Base::DescrInt::CosineDistancePtr GetCosineDistance(size_t depth) + { + switch (depth) + { + case 4: return CosineDistance<4>; + case 5: return CosineDistance<5>; + case 6: return CosineDistance<6>; + case 7: return CosineDistance<7>; + case 8: return CosineDistance<8>; + default: assert(0); return NULL; + } + } + + Sse41::DescrInt::MacroCosineDistancesDirectPtr GetMacroCosineDistancesDirect(size_t depth) + { + switch (depth) + { + case 4: return MacroCosineDistancesDirect<4>; + case 5: return MacroCosineDistancesDirect<5>; + case 6: return MacroCosineDistancesDirect<6>; + case 7: return MacroCosineDistancesDirect<7>; + case 8: return MacroCosineDistancesDirect<8>; + default: assert(0); return NULL; + } + } + } +#endif +} diff --git a/src/Simd/SimdAvx512bwDescrIntCdu.cpp b/src/Simd/SimdAvx512bwDescrIntCdu.cpp new file mode 100644 index 0000000000..3b7120ebfd --- /dev/null +++ b/src/Simd/SimdAvx512bwDescrIntCdu.cpp @@ -0,0 +1,481 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2023 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#include "Simd/SimdMemory.h" +#include "Simd/SimdStore.h" +#include "Simd/SimdExtract.h" +#include "Simd/SimdArray.h" +#include "Simd/SimdUnpack.h" +#include "Simd/SimdDescrInt.h" +#include "Simd/SimdDescrIntCommon.h" +#include "Simd/SimdCpu.h" +#include "Simd/SimdSynet.h" + +namespace Simd +{ +#ifdef SIMD_AVX512BW_ENABLE + namespace Avx512bw + { + const __m512i U4_PERM = SIMD_MM512_SETR_EPI32( + 0x0, 0x1,-1, -1, 0x2, 0x3, -1, -1, 0x4, 0x5, -1, -1, 0x6, 0x7, -1, -1); + const __m512i U4_SHFL0 = SIMD_MM512_SETR_EPI8( + 0x0, 0x0, 0x0, 0x0, 0x1, 0x1, 0x1, 0x1, 0x2, 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x3, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x1, 0x1, 0x1, 0x2, 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x3, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x1, 0x1, 0x1, 0x2, 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x3, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x1, 0x1, 0x1, 0x2, 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x3); + const __m512i U4_SHFL1 = SIMD_MM512_SETR_EPI8( + 0x4, 0x4, 0x4, 0x4, 0x5, 0x5, 0x5, 0x5, 0x6, 0x6, 0x6, 0x6, 0x7, 0x7, 0x7, 0x7, + 0x4, 0x4, 0x4, 0x4, 0x5, 0x5, 0x5, 0x5, 0x6, 0x6, 0x6, 0x6, 0x7, 0x7, 0x7, 0x7, + 0x4, 0x4, 0x4, 0x4, 0x5, 0x5, 0x5, 0x5, 0x6, 0x6, 0x6, 0x6, 0x7, 0x7, 0x7, 0x7, + 0x4, 0x4, 0x4, 0x4, 0x5, 0x5, 0x5, 0x5, 0x6, 0x6, 0x6, 0x6, 0x7, 0x7, 0x7, 0x7); + + const __m512i U5_PERM = SIMD_MM512_SETR_EPI32( + 0x0, 0x1, 0x2, -1, 0x2, 0x3, 0x4, -1, 0x5, 0x6, 0x7, -1, 0x7, 0x8, 0x9, -1); + const __m512i U5_SHFL0 = SIMD_MM512_SETR_EPI8( + 0x0, 0x0, 0x0, 0x1, 0x1, 0x1, 0x1, 0x2, 0x2, 0x3, 0x3, 0x3, 0x3, 0x4, 0x4, 0x4, + 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x5, 0x5, 0x6, 0x6, 0x6, + 0x0, 0x0, 0x0, 0x1, 0x1, 0x1, 0x1, 0x2, 0x2, 0x3, 0x3, 0x3, 0x3, 0x4, 0x4, 0x4, + 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x5, 0x5, 0x6, 0x6, 0x6); + const __m512i U5_SHFL1 = SIMD_MM512_SETR_EPI8( + 0x5, 0x5, 0x5, 0x6, 0x6, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8, 0x8, 0x8, 0x9, 0x9, 0x9, + 0x7, 0x7, 0x7, 0x8, 0x8, 0x8, 0x8, 0x9, 0x9, 0xA, 0xA, 0xA, 0xA, 0xB, 0xB, 0xB, + 0x5, 0x5, 0x5, 0x6, 0x6, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8, 0x8, 0x8, 0x9, 0x9, 0x9, + 0x7, 0x7, 0x7, 0x8, 0x8, 0x8, 0x8, 0x9, 0x9, 0xA, 0xA, 0xA, 0xA, 0xB, 0xB, 0xB); + + const __m512i U6_PERM = SIMD_MM512_SETR_EPI32( + 0x0, 0x1, 0x2, -1, 0x3, 0x4, 0x5, -1, 0x6, 0x7, 0x8, -1, 0x9, 0xA, 0xB, -1); + const __m512i U6_SHFL0 = SIMD_MM512_SETR_EPI8( + 0x0, 0x0, 0x0, 0x1, 0x1, 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x5, + 0x0, 0x0, 0x0, 0x1, 0x1, 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x5, + 0x0, 0x0, 0x0, 0x1, 0x1, 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x5, + 0x0, 0x0, 0x0, 0x1, 0x1, 0x2, 0x2, 0x2, 0x3, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x5); + const __m512i U6_SHFL1 = SIMD_MM512_SETR_EPI8( + 0x6, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8, 0x8, 0x9, 0x9, 0x9, 0xA, 0xA, 0xB, 0xB, 0xB, + 0x6, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8, 0x8, 0x9, 0x9, 0x9, 0xA, 0xA, 0xB, 0xB, 0xB, + 0x6, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8, 0x8, 0x9, 0x9, 0x9, 0xA, 0xA, 0xB, 0xB, 0xB, + 0x6, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8, 0x8, 0x9, 0x9, 0x9, 0xA, 0xA, 0xB, 0xB, 0xB); + + const __m512i U7_PERM = SIMD_MM512_SETR_EPI32( + 0x0, 0x1, 0x2, 0x3, 0x3, 0x4, 0x5, 0x6, 0x7, 0x8, 0x9, 0xA, 0xA, 0xB, 0xC, 0xD); + const __m512i U7_SHFL0 = SIMD_MM512_SETR_EPI8( + 0x0, 0x0, 0x0, 0x1, 0x1, 0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x6, 0x6, 0x6, + 0x2, 0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8, 0x8, + 0x0, 0x0, 0x0, 0x1, 0x1, 0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x6, 0x6, 0x6, + 0x2, 0x2, 0x2, 0x3, 0x3, 0x4, 0x4, 0x5, 0x5, 0x6, 0x6, 0x7, 0x7, 0x8, 0x8, 0x8); + const __m512i U7_SHFL1 = SIMD_MM512_SETR_EPI8( + 0x7, 0x7, 0x7, 0x8, 0x8, 0x9, 0x9, 0xA, 0xA, 0xB, 0xB, 0xC, 0xC, 0xD, 0xD, 0xD, + 0x9, 0x9, 0x9, 0xA, 0xA, 0xB, 0xB, 0xC, 0xC, 0xD, 0xD, 0xE, 0xE, 0xF, 0xF, 0xF, + 0x7, 0x7, 0x7, 0x8, 0x8, 0x9, 0x9, 0xA, 0xA, 0xB, 0xB, 0xC, 0xC, 0xD, 0xD, 0xD, + 0x9, 0x9, 0x9, 0xA, 0xA, 0xB, 0xB, 0xC, 0xC, 0xD, 0xD, 0xE, 0xE, 0xF, 0xF, 0xF); + + //------------------------------------------------------------------------------------------------- + + template __m512i UnpackData64(const uint8_t* src, __mmask64 mask); + + template<> SIMD_INLINE __m512i UnpackData64<4>(const uint8_t* src, __mmask64 mask) + { + __m512i val = _mm512_permutexvar_epi32(U4_PERM, _mm512_maskz_loadu_epi8(mask, src)); + __m512i lo = _mm512_srli_epi16(_mm512_mullo_epi16(_mm512_shuffle_epi8(val, U4_SHFL0), C4_MULLO), 12); + __m512i hi = _mm512_srli_epi16(_mm512_mullo_epi16(_mm512_shuffle_epi8(val, U4_SHFL1), C4_MULLO), 12); + return _mm512_packus_epi16(lo, hi); + } + + template<> SIMD_INLINE __m512i UnpackData64<5>(const uint8_t* src, __mmask64 mask) + { + __m512i val = _mm512_permutexvar_epi32(U5_PERM, _mm512_maskz_loadu_epi8(mask, src)); + __m512i lo = _mm512_srli_epi16(_mm512_mullo_epi16(_mm512_shuffle_epi8(val, U5_SHFL0), C5_MULLO), 11); + __m512i hi = _mm512_srli_epi16(_mm512_mullo_epi16(_mm512_shuffle_epi8(val, U5_SHFL1), C5_MULLO), 11); + return _mm512_packus_epi16(lo, hi); + } + + template<> SIMD_INLINE __m512i UnpackData64<6>(const uint8_t* src, __mmask64 mask) + { + __m512i val = _mm512_permutexvar_epi32(U6_PERM, _mm512_maskz_loadu_epi8(mask, src)); + __m512i lo = _mm512_srli_epi16(_mm512_mullo_epi16(_mm512_shuffle_epi8(val, U6_SHFL0), C6_MULLO), 10); + __m512i hi = _mm512_srli_epi16(_mm512_mullo_epi16(_mm512_shuffle_epi8(val, U6_SHFL1), C6_MULLO), 10); + return _mm512_packus_epi16(lo, hi); + } + + template<> SIMD_INLINE __m512i UnpackData64<7>(const uint8_t* src, __mmask64 mask) + { + __m512i val = _mm512_permutexvar_epi32(U7_PERM, _mm512_maskz_loadu_epi8(mask, src)); + __m512i lo = _mm512_srli_epi16(_mm512_mullo_epi16(_mm512_shuffle_epi8(val, U7_SHFL0), C7_MULLO), 9); + __m512i hi = _mm512_srli_epi16(_mm512_mullo_epi16(_mm512_shuffle_epi8(val, U7_SHFL1), C7_MULLO), 9); + return _mm512_packus_epi16(lo, hi); + } + + //------------------------------------------------------------------------------------------------- + + template void UnpackDataA(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride) + { + size_t size64 = AlignLo(size, 64); + __mmask64 srcBody = TailMask64(8 * bits), dstBody = __mmask64(-1), srcTail = TailMask64((size - size64) / 8 * bits), dstTail = TailMask64(size - size64); + for (size_t i = 0; i < count; i++) + { + const uint8_t* ps = src[i] + 16; + uint8_t* pd = (uint8_t*)dst + i * size; + size_t j = 0; + for (; j < size64; j += 64, ps += 8 * bits, pd += 64) + _mm512_mask_storeu_epi8(pd, dstBody, UnpackData64(ps, srcBody)); + if(j < size64) + _mm512_mask_storeu_epi8(pd, dstTail, UnpackData64(ps, srcTail)); + } + } + + //------------------------------------------------------------------------------------------------- + + template SIMD_INLINE void UnpackDataBx16xN(const uint8_t* const* src, size_t offset, uint8_t* dst) + { + __mmask64 mask = TailMask64(bits * N); + __m512i a0, a1, a2, a3, a4, a5, a6, a7, a8, a9, aA, aB, aC, aD, aE, aF, b0, b1, b2, b3, b4, b5, b6, b7, b8, b9, bA, bB, bC, bD, bE, bF; + + a0 = UnpackData64(src[0x0] + offset, mask); + a1 = UnpackData64(src[0x1] + offset, mask); + a2 = UnpackData64(src[0x2] + offset, mask); + a3 = UnpackData64(src[0x3] + offset, mask); + a4 = UnpackData64(src[0x4] + offset, mask); + a5 = UnpackData64(src[0x5] + offset, mask); + a6 = UnpackData64(src[0x6] + offset, mask); + a7 = UnpackData64(src[0x7] + offset, mask); + a8 = UnpackData64(src[0x8] + offset, mask); + a9 = UnpackData64(src[0x9] + offset, mask); + aA = UnpackData64(src[0xA] + offset, mask); + aB = UnpackData64(src[0xB] + offset, mask); + aC = UnpackData64(src[0xC] + offset, mask); + aD = UnpackData64(src[0xD] + offset, mask); + aE = UnpackData64(src[0xE] + offset, mask); + aF = UnpackData64(src[0xF] + offset, mask); + + b0 = _mm512_unpacklo_epi32(a0, a2); + b1 = _mm512_unpacklo_epi32(a1, a3); + b2 = _mm512_unpackhi_epi32(a0, a2); + b3 = _mm512_unpackhi_epi32(a1, a3); + b4 = _mm512_unpacklo_epi32(a4, a6); + b5 = _mm512_unpacklo_epi32(a5, a7); + b6 = _mm512_unpackhi_epi32(a4, a6); + b7 = _mm512_unpackhi_epi32(a5, a7); + b8 = _mm512_unpacklo_epi32(a8, aA); + b9 = _mm512_unpacklo_epi32(a9, aB); + bA = _mm512_unpackhi_epi32(a8, aA); + bB = _mm512_unpackhi_epi32(a9, aB); + bC = _mm512_unpacklo_epi32(aC, aE); + bD = _mm512_unpacklo_epi32(aD, aF); + bE = _mm512_unpackhi_epi32(aC, aE); + bF = _mm512_unpackhi_epi32(aD, aF); + + a0 = _mm512_unpacklo_epi32(b0, b1); + a1 = _mm512_unpackhi_epi32(b0, b1); + a2 = _mm512_unpacklo_epi32(b2, b3); + a3 = _mm512_unpackhi_epi32(b2, b3); + a4 = _mm512_unpacklo_epi32(b4, b5); + a5 = _mm512_unpackhi_epi32(b4, b5); + a6 = _mm512_unpacklo_epi32(b6, b7); + a7 = _mm512_unpackhi_epi32(b6, b7); + a8 = _mm512_unpacklo_epi32(b8, b9); + a9 = _mm512_unpackhi_epi32(b8, b9); + aA = _mm512_unpacklo_epi32(bA, bB); + aB = _mm512_unpackhi_epi32(bA, bB); + aC = _mm512_unpacklo_epi32(bC, bD); + aD = _mm512_unpackhi_epi32(bC, bD); + aE = _mm512_unpacklo_epi32(bE, bF); + aF = _mm512_unpackhi_epi32(bE, bF); + + b0 = _mm512_shuffle_i32x4(a0, a4, 0x44); + b1 = _mm512_shuffle_i32x4(a1, a5, 0x44); + b2 = _mm512_shuffle_i32x4(a2, a6, 0x44); + b3 = _mm512_shuffle_i32x4(a3, a7, 0x44); + b4 = _mm512_shuffle_i32x4(a0, a4, 0xEE); + b5 = _mm512_shuffle_i32x4(a1, a5, 0xEE); + b6 = _mm512_shuffle_i32x4(a2, a6, 0xEE); + b7 = _mm512_shuffle_i32x4(a3, a7, 0xEE); + b8 = _mm512_shuffle_i32x4(a8, aC, 0x44); + b9 = _mm512_shuffle_i32x4(a9, aD, 0x44); + bA = _mm512_shuffle_i32x4(aA, aE, 0x44); + bB = _mm512_shuffle_i32x4(aB, aF, 0x44); + bC = _mm512_shuffle_i32x4(a8, aC, 0xEE); + bD = _mm512_shuffle_i32x4(a9, aD, 0xEE); + bE = _mm512_shuffle_i32x4(aA, aE, 0xEE); + bF = _mm512_shuffle_i32x4(aB, aF, 0xEE); + + a0 = _mm512_shuffle_i32x4(b0, b8, 0x88); + a1 = _mm512_shuffle_i32x4(b1, b9, 0x88); + a2 = _mm512_shuffle_i32x4(b2, bA, 0x88); + a3 = _mm512_shuffle_i32x4(b3, bB, 0x88); + a4 = _mm512_shuffle_i32x4(b0, b8, 0xDD); + a5 = _mm512_shuffle_i32x4(b1, b9, 0xDD); + a6 = _mm512_shuffle_i32x4(b2, bA, 0xDD); + a7 = _mm512_shuffle_i32x4(b3, bB, 0xDD); + a8 = _mm512_shuffle_i32x4(b4, bC, 0x88); + a9 = _mm512_shuffle_i32x4(b5, bD, 0x88); + aA = _mm512_shuffle_i32x4(b6, bE, 0x88); + aB = _mm512_shuffle_i32x4(b7, bF, 0x88); + aC = _mm512_shuffle_i32x4(b4, bC, 0xDD); + aD = _mm512_shuffle_i32x4(b5, bD, 0xDD); + aE = _mm512_shuffle_i32x4(b6, bE, 0xDD); + aF = _mm512_shuffle_i32x4(b7, bF, 0xDD); + + if (N > 0) _mm512_storeu_si512(dst + 0x0 * DA, a0); + if (N > 0) _mm512_storeu_si512(dst + 0x1 * DA, a1); + if (N > 1) _mm512_storeu_si512(dst + 0x2 * DA, a2); + if (N > 1) _mm512_storeu_si512(dst + 0x3 * DA, a3); + if (N > 2) _mm512_storeu_si512(dst + 0x4 * DA, a4); + if (N > 2) _mm512_storeu_si512(dst + 0x5 * DA, a5); + if (N > 3) _mm512_storeu_si512(dst + 0x6 * DA, a6); + if (N > 3) _mm512_storeu_si512(dst + 0x7 * DA, a7); + if (N > 4) _mm512_storeu_si512(dst + 0x8 * DA, a8); + if (N > 4) _mm512_storeu_si512(dst + 0x9 * DA, a9); + if (N > 5) _mm512_storeu_si512(dst + 0xA * DA, aA); + if (N > 5) _mm512_storeu_si512(dst + 0xB * DA, aB); + if (N > 6) _mm512_storeu_si512(dst + 0xC * DA, aC); + if (N > 6) _mm512_storeu_si512(dst + 0xD * DA, aD); + if (N > 7) _mm512_storeu_si512(dst + 0xE * DA, aE); + if (N > 7) _mm512_storeu_si512(dst + 0xF * DA, aF); + } + + typedef void (*UnpackDataBx16xN_Ptr)(const uint8_t* const* src, size_t offset, uint8_t* dst); + + template UnpackDataBx16xN_Ptr GetUnpackDataBx16xN(int tail) + { + switch (tail / 8) + { + case 0: return NULL; + case 1: return UnpackDataBx16xN; + case 2: return UnpackDataBx16xN; + case 3: return UnpackDataBx16xN; + case 4: return UnpackDataBx16xN; + case 5: return UnpackDataBx16xN; + case 6: return UnpackDataBx16xN; + case 7: return UnpackDataBx16xN; + case 8: return UnpackDataBx16xN; + default: + assert(0); return NULL; + } + } + + template void UnpackDataB(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride) + { + size_t countDF = AlignLo(count, DF), size64 = AlignLo(size, 64), i, j, o; + UnpackDataBx16xN_Ptr unpackDataMain = GetUnpackDataBx16xN(64); + UnpackDataBx16xN_Ptr unpackDataTail = GetUnpackDataBx16xN(size - size64); + for (i = 0; i < countDF; i += DF, src += DF) + { + for (j = 0, o = 16; j < size64; j += 64, o += 8 * bits, dst += 32 * A) + { + unpackDataMain(src + 0, o, dst + 0); + unpackDataMain(src + F, o, dst + A); + } + if (j < size) + { + unpackDataTail(src + 0, o, dst + 0); + unpackDataTail(src + F, o, dst + A); + } + } + if (i < count) + { + size_t tail = count - countDF; + const uint8_t* _src[DF]; + for (size_t j = 0; j < DF; i++, j++) + _src[j] = i < count ? *src++ : src[-1]; + for (j = 0, o = 16; j < size64; j += 64, o += 8 * bits, dst += 32 * A) + { + unpackDataMain(_src + 0, o, dst + 0); + if(tail > F) + unpackDataMain(_src + F, o, dst + A); + } + if (j < size) + { + unpackDataTail(_src + 0, o, dst + 0); + if (tail > F) + unpackDataTail(_src + F, o, dst + A); + } + } + } + + //------------------------------------------------------------------------------------------------- + + template void Correlation8_2xM(size_t N, size_t K, const uint8_t* ad0, const uint8_t* bd, const float* an, const float* bn, size_t bnStride, float* distances, size_t stride) + { + __m512i ab00, ab01, ab10, ab11, ab20, ab21, ab30, ab31, ab40, ab41, ab50, ab51, ab60, ab61, ab70, ab71, ab80, ab81, ab90, ab91, abA0, abA1, abB0, abB1, a0, b0, b1; + const uint8_t* ad1 = ad0 + 1 * K; + const uint8_t* ad2 = ad0 + 2 * K; + const uint8_t* ad3 = ad0 + 3 * K; + const uint8_t* ad4 = ad0 + 4 * K; + const uint8_t* ad5 = ad0 + 5 * K; + if (N > F) + { + if (M > 0x0) ab00 = _mm512_setzero_si512(), ab01 = _mm512_setzero_si512(); + if (M > 0x1) ab10 = _mm512_setzero_si512(), ab11 = _mm512_setzero_si512(); + if (M > 0x2) ab20 = _mm512_setzero_si512(), ab21 = _mm512_setzero_si512(); + if (M > 0x3) ab30 = _mm512_setzero_si512(), ab31 = _mm512_setzero_si512(); + if (M > 0x4) ab40 = _mm512_setzero_si512(), ab41 = _mm512_setzero_si512(); + if (M > 0x5) ab50 = _mm512_setzero_si512(), ab51 = _mm512_setzero_si512(); + if (M > 0x6) ab60 = _mm512_setzero_si512(), ab61 = _mm512_setzero_si512(); + if (M > 0x7) ab70 = _mm512_setzero_si512(), ab71 = _mm512_setzero_si512(); + if (M > 0x8) ab80 = _mm512_setzero_si512(), ab81 = _mm512_setzero_si512(); + if (M > 0x9) ab90 = _mm512_setzero_si512(), ab91 = _mm512_setzero_si512(); + if (M > 0xA) abA0 = _mm512_setzero_si512(), abA1 = _mm512_setzero_si512(); + if (M > 0xB) abB0 = _mm512_setzero_si512(), abB1 = _mm512_setzero_si512(); + for (size_t k0 = 0, k6 = K * 6; k0 < K; k0 += 4, k6 += 4) + { + b0 = _mm512_loadu_si512((__m512i*)bd + 0); + b1 = _mm512_loadu_si512((__m512i*)bd + 1); + if (M > 0x0) a0 = Set4(ad0 + k0), Madd4(ab00, a0, b0), Madd4(ab01, a0, b1); + if (M > 0x1) a0 = Set4(ad1 + k0), Madd4(ab10, a0, b0), Madd4(ab11, a0, b1); + if (M > 0x2) a0 = Set4(ad2 + k0), Madd4(ab20, a0, b0), Madd4(ab21, a0, b1); + if (M > 0x3) a0 = Set4(ad3 + k0), Madd4(ab30, a0, b0), Madd4(ab31, a0, b1); + if (M > 0x4) a0 = Set4(ad4 + k0), Madd4(ab40, a0, b0), Madd4(ab41, a0, b1); + if (M > 0x5) a0 = Set4(ad5 + k0), Madd4(ab50, a0, b0), Madd4(ab51, a0, b1); + if (M > 0x6) a0 = Set4(ad0 + k6), Madd4(ab60, a0, b0), Madd4(ab61, a0, b1); + if (M > 0x7) a0 = Set4(ad1 + k6), Madd4(ab70, a0, b0), Madd4(ab71, a0, b1); + if (M > 0x8) a0 = Set4(ad2 + k6), Madd4(ab80, a0, b0), Madd4(ab81, a0, b1); + if (M > 0x9) a0 = Set4(ad3 + k6), Madd4(ab90, a0, b0), Madd4(ab91, a0, b1); + if (M > 0xA) a0 = Set4(ad4 + k6), Madd4(abA0, a0, b0), Madd4(abA1, a0, b1); + if (M > 0xB) a0 = Set4(ad5 + k6), Madd4(abB0, a0, b0), Madd4(abB1, a0, b1); + bd += DA; + } + __mmask16 tail = TailMask16(N - F); + if (M > 0x0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab01, distances + F, tail), an += 4, distances += stride; + if (M > 0x1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab11, distances + F, tail), an += 4, distances += stride; + if (M > 0x2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab21, distances + F, tail), an += 4, distances += stride; + if (M > 0x3) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab31, distances + F, tail), an += 4, distances += stride; + if (M > 0x4) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab41, distances + F, tail), an += 4, distances += stride; + if (M > 0x5) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab50, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab51, distances + F, tail), an += 4, distances += stride; + if (M > 0x6) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab60, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab61, distances + F, tail), an += 4, distances += stride; + if (M > 0x7) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab70, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab71, distances + F, tail), an += 4, distances += stride; + if (M > 0x8) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab80, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab81, distances + F, tail), an += 4, distances += stride; + if (M > 0x9) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab90, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab91, distances + F, tail), an += 4, distances += stride; + if (M > 0xA) DecodeCosineDistances1xF(an, bn + 0, bnStride, abA0, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, abA1, distances + F, tail), an += 4, distances += stride; + if (M > 0xB) DecodeCosineDistances1xF(an, bn + 0, bnStride, abB0, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, abB1, distances + F, tail), an += 4, distances += stride; + } + else + { + if (M > 0x0) ab00 = _mm512_setzero_si512(); + if (M > 0x1) ab10 = _mm512_setzero_si512(); + if (M > 0x2) ab20 = _mm512_setzero_si512(); + if (M > 0x3) ab30 = _mm512_setzero_si512(); + if (M > 0x4) ab40 = _mm512_setzero_si512(); + if (M > 0x5) ab50 = _mm512_setzero_si512(); + if (M > 0x6) ab60 = _mm512_setzero_si512(); + if (M > 0x7) ab70 = _mm512_setzero_si512(); + if (M > 0x8) ab80 = _mm512_setzero_si512(); + if (M > 0x9) ab90 = _mm512_setzero_si512(); + if (M > 0xA) abA0 = _mm512_setzero_si512(); + if (M > 0xB) abB0 = _mm512_setzero_si512(); + for (size_t k0 = 0, k6 = K * 6; k0 < K; k0 += 4, k6 += 4) + { + b0 = _mm512_loadu_si512((__m512i*)bd + 0); + if (M > 0x0) a0 = Set4(ad0 + k0), Madd4(ab00, a0, b0); + if (M > 0x1) a0 = Set4(ad1 + k0), Madd4(ab10, a0, b0); + if (M > 0x2) a0 = Set4(ad2 + k0), Madd4(ab20, a0, b0); + if (M > 0x3) a0 = Set4(ad3 + k0), Madd4(ab30, a0, b0); + if (M > 0x4) a0 = Set4(ad4 + k0), Madd4(ab40, a0, b0); + if (M > 0x5) a0 = Set4(ad5 + k0), Madd4(ab50, a0, b0); + if (M > 0x6) a0 = Set4(ad0 + k6), Madd4(ab60, a0, b0); + if (M > 0x7) a0 = Set4(ad1 + k6), Madd4(ab70, a0, b0); + if (M > 0x8) a0 = Set4(ad2 + k6), Madd4(ab80, a0, b0); + if (M > 0x9) a0 = Set4(ad3 + k6), Madd4(ab90, a0, b0); + if (M > 0xA) a0 = Set4(ad4 + k6), Madd4(abA0, a0, b0); + if (M > 0xB) a0 = Set4(ad5 + k6), Madd4(abB0, a0, b0); + bd += DA; + } + __mmask16 tail = TailMask16(N - F); + if (M > 0x0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0, tail), an += 4, distances += stride; + if (M > 0x1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0, tail), an += 4, distances += stride; + if (M > 0x2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0, tail), an += 4, distances += stride; + if (M > 0x3) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab30, distances + 0, tail), an += 4, distances += stride; + if (M > 0x4) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab40, distances + 0, tail), an += 4, distances += stride; + if (M > 0x5) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab50, distances + 0, tail), an += 4, distances += stride; + if (M > 0x6) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab60, distances + 0, tail), an += 4, distances += stride; + if (M > 0x7) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab70, distances + 0, tail), an += 4, distances += stride; + if (M > 0x8) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab80, distances + 0, tail), an += 4, distances += stride; + if (M > 0x9) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab90, distances + 0, tail), an += 4, distances += stride; + if (M > 0xA) DecodeCosineDistances1xF(an, bn + 0, bnStride, abA0, distances + 0, tail), an += 4, distances += stride; + if (M > 0xB) DecodeCosineDistances1xF(an, bn + 0, bnStride, abB0, distances + 0, tail), an += 4, distances += stride; + } + } + + typedef void(*Correlation8_2xM_Ptr)(size_t N, size_t K, const uint8_t* ad0, const uint8_t* bd, const float* an, const float* bn, size_t bnStride, float* distances, size_t stride); + + SIMD_INLINE Correlation8_2xM_Ptr GetCorrelation8_2xM(size_t M) + { + switch (M) + { + case 0x0: return NULL; + case 0x1: return Correlation8_2xM<0x1>; + case 0x2: return Correlation8_2xM<0x2>; + case 0x3: return Correlation8_2xM<0x3>; + case 0x4: return Correlation8_2xM<0x4>; + case 0x5: return Correlation8_2xM<0x5>; + case 0x6: return Correlation8_2xM<0x6>; + case 0x7: return Correlation8_2xM<0x7>; + case 0x8: return Correlation8_2xM<0x8>; + case 0x9: return Correlation8_2xM<0x9>; + case 0xA: return Correlation8_2xM<0xA>; + case 0xB: return Correlation8_2xM<0xB>; + case 0xC: return Correlation8_2xM<0xC>; + } + assert(0); + return NULL; + } + + void MacroCorrelation8(size_t M, size_t N, size_t K, const uint8_t* ad, const float* an, const uint8_t* bd, const float* bn, float* distances, size_t stride) + { + size_t M12 = AlignLoAny(M, 12); + Correlation8_2xM_Ptr correlation_2x12 = GetCorrelation8_2xM(12); + Correlation8_2xM_Ptr correlation_2xT = GetCorrelation8_2xM(M - M12); + for (size_t j = 0; j < N; j += DF) + { + size_t dN = Simd::Min(DF, N - j); + size_t i = 0; + for (; i < M12; i += 12) + correlation_2x12(dN, K, ad + i * K, bd, an + i * 4, bn, N, distances + i * stride, stride); + if (i < M) + correlation_2xT(dN, K, ad + i * K, bd, an + i * 4, bn, N, distances + i * stride, stride); + bd += K * DF; + bn += DF; + distances += DF; + } + } + + + //------------------------------------------------------------------------------------------------- + + Sse41::DescrInt::UnpackDataPtr GetUnpackData(size_t depth, bool transpose) + { + switch (depth) + { + case 4: return transpose ? UnpackDataB<4> : UnpackDataA<4>; + case 5: return transpose ? UnpackDataB<5> : UnpackDataA<5>; + case 6: return transpose ? UnpackDataB<6> : UnpackDataA<6>; + case 7: return transpose ? UnpackDataB<7> : UnpackDataA<7>; + default: return NULL; + } + } + + Sse41::DescrInt::MacroCosineDistancesUnpackPtr GetMacroCosineDistancesUnpack(size_t depth) + { + return depth == 8 ? NULL : MacroCorrelation8; + } + } +#endif +} diff --git a/src/Simd/SimdAvx512bwDescrIntDec.cpp b/src/Simd/SimdAvx512bwDescrIntDec.cpp new file mode 100644 index 0000000000..a81c2a1ebc --- /dev/null +++ b/src/Simd/SimdAvx512bwDescrIntDec.cpp @@ -0,0 +1,313 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2023 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#include "Simd/SimdMemory.h" +#include "Simd/SimdStore.h" +#include "Simd/SimdExtract.h" +#include "Simd/SimdArray.h" +#include "Simd/SimdUnpack.h" +#include "Simd/SimdDescrInt.h" +#include "Simd/SimdDescrIntCommon.h" +#include "Simd/SimdCpu.h" + +namespace Simd +{ +#ifdef SIMD_AVX512BW_ENABLE + namespace Avx512bw + { + static void Decode32f4(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _shift = _mm512_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); + for (; i < size16; i += 16) + { + __m256i s4 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s4, Avx2::C4_SHFL), Avx2::C4_MULLO), 12); + _mm512_storeu_ps(dst + 0, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift)); + src += 8; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s4 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s4, Sse41::C4_SHFL0), Sse41::C4_MULLO), 12); + _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift))); + src += 4; + dst += 8; + } + } + + static void Decode32f5(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _shift = _mm512_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); + for (; i < size16; i += 16) + { + __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, Avx2::C5_SHFL), Avx2::C5_MULLO), 11); + _mm512_storeu_ps(dst + 0, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift)); + src += 10; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s5 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s5, Sse41::C5_SHFL0), Sse41::C5_MULLO), 11); + _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift))); + src += 5; + dst += 8; + } + } + + static void Decode32f6(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _shift = _mm512_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); + for (; i < size16; i += 16) + { + __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, Avx2::C6_SHFL), Avx2::C6_MULLO), 10); + _mm512_storeu_ps(dst + 0, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift)); + src += 12; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s6 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s6, Sse41::C6_SHFL0), Sse41::C6_MULLO), 10); + _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift))); + src += 6; + dst += 8; + } + } + + static void Decode32f7(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _shift = _mm512_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); + for (; i < size16; i += 16) + { + __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, Avx2::C7_SHFL), Avx2::C7_MULLO), 9); + _mm512_storeu_ps(dst + 0, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift)); + src += 14; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s7 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s7, Sse41::C7_SHFL0), Sse41::C7_MULLO), 9); + _mm256_storeu_ps(dst + 0, _mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift))); + src += 7; + dst += 8; + } + } + + static void Decode32f8(const uint8_t* src, float scale, float shift, size_t size, float* dst) + { + assert(size % 8 == 0); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _shift = _mm512_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16), size64 = AlignLo(size, 64); + for (; i < size64; i += 64) + { + __m512i u8 = _mm512_loadu_si512((__m512i*)(src + i)); + _mm512_storeu_ps(dst + i + 0 * F, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 0))), _scale, _shift)); + _mm512_storeu_ps(dst + i + 1 * F, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 1))), _scale, _shift)); + _mm512_storeu_ps(dst + i + 2 * F, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 2))), _scale, _shift)); + _mm512_storeu_ps(dst + i + 3 * F, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 3))), _scale, _shift)); + } + for (; i < size16; i += 16) + { + __m128i u8 = _mm_loadu_si128((__m128i*)(src + i)); + _mm512_storeu_ps(dst + i, _mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(u8)), _scale, _shift)); + } + if (i < size) + { + __m256 _src = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)(src + i)))); + _mm256_storeu_ps(dst + i, _mm256_fmadd_ps(_src, _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift))); + } + } + + //------------------------------------------------------------------------------------------------- + + static void Decode16f4(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _shift = _mm512_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); + for (; i < size16; i += 16) + { + __m256i s4 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s4, Avx2::C4_SHFL), Avx2::C4_MULLO), 12); + _mm256_storeu_si256((__m256i*)dst, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift), 0)); + src += 8; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s4 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s4, Sse41::C4_SHFL0), Sse41::C4_MULLO), 12); + _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift)), 0)); + src += 4; + dst += 8; + } + } + + static void Decode16f5(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _shift = _mm512_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); + for (; i < size16; i += 16) + { + __m256i s5 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s5, Avx2::C5_SHFL), Avx2::C5_MULLO), 11); + _mm256_storeu_si256((__m256i*)dst, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift), 0)); + src += 10; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s5 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s5, Sse41::C5_SHFL0), Sse41::C5_MULLO), 11); + _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift)), 0)); + src += 5; + dst += 8; + } + } + + static void Decode16f6(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _shift = _mm512_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); + for (; i < size16; i += 16) + { + __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, Avx2::C6_SHFL), Avx2::C6_MULLO), 10); + _mm256_storeu_si256((__m256i*)dst, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift), 0)); + src += 12; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s6 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s6, Sse41::C6_SHFL0), Sse41::C6_MULLO), 10); + _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift)), 0)); + src += 6; + dst += 8; + } + } + + static void Decode16f7(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _shift = _mm512_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16), size32 = AlignLo(size, 32); + for (; i < size16; i += 16) + { + __m256i s6 = _mm256_broadcastsi128_si256(_mm_loadu_si128((__m128i*)src)); + __m256i s16 = _mm256_srli_epi16(_mm256_mullo_epi16(_mm256_shuffle_epi8(s6, Avx2::C7_SHFL), Avx2::C7_MULLO), 9); + _mm256_storeu_si256((__m256i*)dst, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu16_epi32(s16)), _scale, _shift), 0)); + src += 14; + dst += 16; + } + for (; i < size; i += 8) + { + __m128i s7 = _mm_loadl_epi64((__m128i*)src); + __m128i s16 = _mm_srli_epi16(_mm_mullo_epi16(_mm_shuffle_epi8(s7, Sse41::C7_SHFL0), Sse41::C7_MULLO), 9); + _mm_storeu_si128((__m128i*)dst, _mm256_cvtps_ph(_mm256_fmadd_ps(_mm256_cvtepi32_ps(_mm256_cvtepu16_epi32(s16)), _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift)), 0)); + src += 7; + dst += 8; + } + } + + static void Decode16f8(const uint8_t* src, float scale, float shift, size_t size, uint16_t* dst) + { + assert(size % 8 == 0); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _shift = _mm512_set1_ps(shift); + size_t i = 0, size16 = AlignLo(size, 16), size64 = AlignLo(size, 64); + for (; i < size64; i += 64) + { + __m512i u8 = _mm512_loadu_si512((__m512i*)(src + i)); + _mm256_storeu_si256((__m256i*)(dst + i) + 0, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 0))), _scale, _shift), 0)); + _mm256_storeu_si256((__m256i*)(dst + i) + 1, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 1))), _scale, _shift), 0)); + _mm256_storeu_si256((__m256i*)(dst + i) + 2, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 2))), _scale, _shift), 0)); + _mm256_storeu_si256((__m256i*)(dst + i) + 3, _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(_mm512_extracti32x4_epi32(u8, 3))), _scale, _shift), 0)); + } + for (; i < size16; i += 16) + { + __m128i u8 = _mm_loadu_si128((__m128i*)(src + i)); + _mm256_storeu_si256((__m256i*)(dst + i), _mm512_cvtps_ph(_mm512_fmadd_ps(_mm512_cvtepi32_ps(_mm512_cvtepu8_epi32(u8)), _scale, _shift), 0)); + } + if (i < size) + { + __m256 _src = _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64((__m128i*)(src + i)))); + _mm_storeu_si128((__m128i*)(dst + i), _mm256_cvtps_ph(_mm256_fmadd_ps(_src, _mm512_castps512_ps256(_scale), _mm512_castps512_ps256(_shift)), 0)); + } + } + + //------------------------------------------------------------------------------------------------- + + Base::DescrInt::Decode32fPtr GetDecode32f(size_t depth) + { + switch (depth) + { + case 4: return Decode32f4; + case 5: return Decode32f5; + case 6: return Decode32f6; + case 7: return Decode32f7; + case 8: return Decode32f8; + default: assert(0); return NULL; + } + } + + Base::DescrInt::Decode16fPtr GetDecode16f(size_t depth) + { + switch (depth) + { + case 4: return Decode16f4; + case 5: return Decode16f5; + case 6: return Decode16f6; + case 7: return Decode16f7; + case 8: return Decode16f8; + default: assert(0); return NULL; + } + } + } +#endif +} diff --git a/src/Simd/SimdAvx512bwDescrIntEnc.cpp b/src/Simd/SimdAvx512bwDescrIntEnc.cpp new file mode 100644 index 0000000000..bc01fcc08d --- /dev/null +++ b/src/Simd/SimdAvx512bwDescrIntEnc.cpp @@ -0,0 +1,441 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2023 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#include "Simd/SimdMemory.h" +#include "Simd/SimdStore.h" +#include "Simd/SimdExtract.h" +#include "Simd/SimdArray.h" +#include "Simd/SimdUnpack.h" +#include "Simd/SimdDescrInt.h" +#include "Simd/SimdDescrIntCommon.h" +#include "Simd/SimdCpu.h" + +namespace Simd +{ +#ifdef SIMD_AVX512BW_ENABLE + namespace Avx512bw + { + SIMD_INLINE __m512i Encode32f(__m512 src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) + { + __m512i value = _mm512_cvtps_epi32(_mm512_mul_ps(_mm512_sub_ps(src, min), scale)); + sum = _mm512_add_epi32(value, sum); + sqsum = _mm512_add_epi32(_mm512_madd_epi16(value, value), sqsum); + return value; + } + + SIMD_INLINE __m512i Encode32f(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) + { + return Encode32f(_mm512_maskz_loadu_ps(mask, src), scale, min, sum, sqsum); + } + + static SIMD_INLINE __m128i Encode32f4x4(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 m0, __mmask16 m1) + { + __m512i i0 = Encode32f(src + 0 * F, scale, min, sum, sqsum, m0); + __m512i i1 = Encode32f(src + 1 * F, scale, min, sum, sqsum, m1); + __m512i s0 = _mm512_srli_epi32(_mm512_mullo_epi16(PackU32ToI16(i0, i1), E4_MULLO), 12); + return _mm256_castsi256_si128(Avx2::PackI16ToU8(_mm512_cvtepi32_epi16(s0), Avx2::K_ZERO)); + } + + static SIMD_INLINE __m256i Encode32f4x8(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) + { + __m512i i0 = Encode32f(src + 0 * F, scale, min, sum, sqsum); + __m512i i1 = Encode32f(src + 1 * F, scale, min, sum, sqsum); + __m512i i2 = Encode32f(src + 2 * F, scale, min, sum, sqsum); + __m512i i3 = Encode32f(src + 3 * F, scale, min, sum, sqsum); + __m512i s0 = _mm512_srli_epi32(_mm512_mullo_epi16(PackU32ToI16(i0, i1), E4_MULLO), 12); + __m512i s1 = _mm512_srli_epi32(_mm512_mullo_epi16(PackU32ToI16(i2, i3), E4_MULLO), 12); + return Avx2::PackI16ToU8(_mm512_cvtepi32_epi16(s0), _mm512_cvtepi32_epi16(s1)); + } + + static void Encode32f4(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, size32 = AlignLo(size, 32), size64 = AlignLo(size, 64); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _min = _mm512_set1_ps(min); + __m512i _sum = _mm512_setzero_si512(); + __m512i _sqsum = _mm512_setzero_si512(); + for (; i < size64; i += 64, src += 64, dst += 32) + _mm256_storeu_si256((__m256i*)dst, Encode32f4x8(src, _scale, _min, _sum, _sqsum)); + for (; i < size32; i += 32, src += 32, dst += 16) + _mm_mask_storeu_epi8(dst, -1, Encode32f4x4(src, _scale, _min, _sum, _sqsum, -1, -1)); + if (i < size) + { + __mmask16 ms0 = TailMask16(size - size32 - 0 * F); + __mmask16 ms1 = TailMask16(size - size32 - 1 * F); + __mmask16 md= TailMask16((size - size32) / 2); + _mm_mask_storeu_epi8(dst, md, Encode32f4x4(src, _scale, _min, _sum, _sqsum, ms0, ms1)); + } + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + static SIMD_INLINE __m128i Encode32f5x2(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) + { + __m512i i0 = Encode32f(src, scale, min, sum, sqsum, mask); + __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), Avx2::E5_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, Avx2::E5_SHFL0), _mm256_shuffle_epi8(s0, Avx2::E5_SHFL1)); + return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); + } + + static SIMD_INLINE __m256i Encode32f5x4(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) + { + __m512i i0 = Encode32f(src + 0 * F, scale, min, sum, sqsum); + __m512i i1 = Encode32f(src + 1 * F, scale, min, sum, sqsum); + __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(EX_PERM, _mm512_packus_epi32(i0, i1)), E5_MULLO); + __m512i e0 = _mm512_or_si512(_mm512_or_si512(_mm512_shuffle_epi8(s0, E5_SHFL0), _mm512_shuffle_epi8(s0, E5_SHFL1)), _mm512_shuffle_epi8(s0, E5_SHFL2)); + return _mm256_or_si256(_mm512_castsi512_si256(e0), _mm512_extracti32x8_epi32(e0, 1)); + } + + static void Encode32f5(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t size16 = AlignLo(size, 16), size32 = AlignLo(size, 32), i = 0; + __m512 _scale = _mm512_set1_ps(scale); + __m512 _min = _mm512_set1_ps(min); + __m512i _sum = _mm512_setzero_si512(); + __m512i _sqsum = _mm512_setzero_si512(); + for (; i < size32; i += 32, src += 32, dst += 20) + _mm256_mask_storeu_epi8(dst - 6, 0x03FFFFC0, Encode32f5x4(src, _scale, _min, _sum, _sqsum)); + for (; i < size16; i += 16, src += 16, dst += 10) + _mm_mask_storeu_epi8(dst, 0x03FF, Encode32f5x2(src, _scale, _min, _sum, _sqsum)); + if (i < size) + _mm_mask_storeu_epi8(dst, 0x001F, Encode32f5x2(src, _scale, _min, _sum, _sqsum, 0x00FF)); + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + static SIMD_INLINE __m128i Encode32f6x2(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) + { + __m512i i0 = Encode32f(src, scale, min, sum, sqsum, mask); + __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), Avx2::E6_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, Avx2::E6_SHFL0), _mm256_shuffle_epi8(s0, Avx2::E6_SHFL1)); + return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); + } + + static SIMD_INLINE __m256i Encode32f6x4(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) + { + __m512i i0 = Encode32f(src + 0 * F, scale, min, sum, sqsum); + __m512i i1 = Encode32f(src + 1 * F, scale, min, sum, sqsum); + __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(EX_PERM, _mm512_packus_epi32(i0, i1)), E6_MULLO); + __m512i e0 = _mm512_or_si512(_mm512_shuffle_epi8(s0, E6_SHFL0), _mm512_shuffle_epi8(s0, E6_SHFL1)); + return _mm256_or_si256(_mm512_castsi512_si256(e0), _mm512_extracti32x8_epi32(e0, 1)); + } + + static void Encode32f6(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t size16 = AlignLo(size, 16), size32 = AlignLo(size, 32), i = 0; + __m512 _scale = _mm512_set1_ps(scale); + __m512 _min = _mm512_set1_ps(min); + __m512i _sum = _mm512_setzero_si512(); + __m512i _sqsum = _mm512_setzero_si512(); + for (; i < size32; i += 32, src += 32, dst += 24) + _mm256_mask_storeu_epi8(dst - 4, 0x0FFFFFF0, Encode32f6x4(src, _scale, _min, _sum, _sqsum)); + for (; i < size16; i += 16, src += 16, dst += 12) + _mm_mask_storeu_epi8(dst, 0x0FFF, Encode32f6x2(src, _scale, _min, _sum, _sqsum)); + if (i < size) + _mm_mask_storeu_epi8(dst, 0x003F, Encode32f6x2(src, _scale, _min, _sum, _sqsum, 0x00FF)); + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + static SIMD_INLINE __m128i Encode32f7x2(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) + { + __m512i i0 = Encode32f(src, scale, min, sum, sqsum, mask); + __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), Avx2::E7_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, Avx2::E7_SHFL0), _mm256_shuffle_epi8(s0, Avx2::E7_SHFL1)); + return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); + } + + static SIMD_INLINE __m256i Encode32f7x4(const float* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) + { + __m512i i0 = Encode32f(src + 0 * F, scale, min, sum, sqsum); + __m512i i1 = Encode32f(src + 1 * F, scale, min, sum, sqsum); + __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(EX_PERM, _mm512_packus_epi32(i0, i1)), E7_MULLO); + __m512i e0 = _mm512_or_si512(_mm512_shuffle_epi8(s0, E7_SHFL0), _mm512_shuffle_epi8(s0, E7_SHFL1)); + return _mm256_or_si256(_mm512_castsi512_si256(e0), _mm512_extracti32x8_epi32(e0, 1)); + } + + static void Encode32f7(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t size16 = AlignLo(size, 16), size32 = AlignLo(size, 32), i = 0; + __m512 _scale = _mm512_set1_ps(scale); + __m512 _min = _mm512_set1_ps(min); + __m512i _sum = _mm512_setzero_si512(); + __m512i _sqsum = _mm512_setzero_si512(); + for (; i < size32; i += 32, src += 32, dst += 28) + _mm256_mask_storeu_epi8(dst - 2, 0x3FFFFFFC, Encode32f7x4(src, _scale, _min, _sum, _sqsum)); + for (; i < size16; i += 16, src += 16, dst += 14) + _mm_mask_storeu_epi8(dst, 0x3FFF, Encode32f7x2(src, _scale, _min, _sum, _sqsum)); + if (i < size) + _mm_mask_storeu_epi8(dst, 0x007F, Encode32f7x2(src, _scale, _min, _sum, _sqsum, 0x00FF)); + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + static void Encode32f8(const float* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t sizeF = AlignLo(size, F), sizeA = AlignLo(size, A), i = 0; + __m512 _scale = _mm512_set1_ps(scale); + __m512 _min = _mm512_set1_ps(min); + __m512i _sum = _mm512_setzero_si512(); + __m512i _sqsum = _mm512_setzero_si512(); + for (; i < sizeA; i += A) + { + __m512i d0 = Encode32f(src + i + 0 * F, _scale, _min, _sum, _sqsum); + __m512i d1 = Encode32f(src + i + 1 * F, _scale, _min, _sum, _sqsum); + __m512i d2 = Encode32f(src + i + 2 * F, _scale, _min, _sum, _sqsum); + __m512i d3 = Encode32f(src + i + 3 * F, _scale, _min, _sum, _sqsum); + _mm512_storeu_si512((__m512i*)(dst + i), PackI16ToU8(PackI32ToI16(d0, d1), PackI32ToI16(d2, d3))); + } + for (; i < sizeF; i += F) + { + __m512i d0 = Encode32f(src + i, _scale, _min, _sum, _sqsum); + _mm_storeu_si128((__m128i*)(dst + i), _mm512_castsi512_si128(PackI16ToU8(PackI32ToI16(d0)))); + } + if (i < size) + { + __m512i d0 = Encode32f(src + i, _scale, _min, _sum, _sqsum, 0xFF); + _mm_mask_storeu_epi8(dst + i, 0xFF, _mm512_castsi512_si128(PackI16ToU8(PackI32ToI16(d0)))); + } + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + //------------------------------------------------------------------------------------------------- + + SIMD_INLINE __m512i Encode16f(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) + { + return Encode32f(_mm512_cvtph_ps(_mm256_maskz_loadu_epi16(mask, src)), scale, min, sum, sqsum); + } + + static SIMD_INLINE __m128i Encode16f4x4(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 m0, __mmask16 m1) + { + __m512i i0 = Encode16f(src + 0 * F, scale, min, sum, sqsum, m0); + __m512i i1 = Encode16f(src + 1 * F, scale, min, sum, sqsum, m1); + __m512i s0 = _mm512_srli_epi32(_mm512_mullo_epi16(PackU32ToI16(i0, i1), E4_MULLO), 12); + return _mm256_castsi256_si128(Avx2::PackI16ToU8(_mm512_cvtepi32_epi16(s0), Avx2::K_ZERO)); + } + + static SIMD_INLINE __m256i Encode16f4x8(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) + { + __m512i i0 = Encode16f(src + 0 * F, scale, min, sum, sqsum); + __m512i i1 = Encode16f(src + 1 * F, scale, min, sum, sqsum); + __m512i i2 = Encode16f(src + 2 * F, scale, min, sum, sqsum); + __m512i i3 = Encode16f(src + 3 * F, scale, min, sum, sqsum); + __m512i s0 = _mm512_srli_epi32(_mm512_mullo_epi16(PackU32ToI16(i0, i1), E4_MULLO), 12); + __m512i s1 = _mm512_srli_epi32(_mm512_mullo_epi16(PackU32ToI16(i2, i3), E4_MULLO), 12); + return Avx2::PackI16ToU8(_mm512_cvtepi32_epi16(s0), _mm512_cvtepi32_epi16(s1)); + } + + static void Encode16f4(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t i = 0, size32 = AlignLo(size, 32), size64 = AlignLo(size, 64); + __m512 _scale = _mm512_set1_ps(scale); + __m512 _min = _mm512_set1_ps(min); + __m512i _sum = _mm512_setzero_si512(); + __m512i _sqsum = _mm512_setzero_si512(); + for (; i < size64; i += 64, src += 64, dst += 32) + _mm256_storeu_si256((__m256i*)dst, Encode16f4x8(src, _scale, _min, _sum, _sqsum)); + for (; i < size32; i += 32, src += 32, dst += 16) + _mm_mask_storeu_epi8(dst, -1, Encode16f4x4(src, _scale, _min, _sum, _sqsum, -1, -1)); + if (i < size) + { + __mmask16 ms0 = TailMask16(size - size32 - 0 * F); + __mmask16 ms1 = TailMask16(size - size32 - 1 * F); + __mmask16 md = TailMask16((size - size32) / 2); + _mm_mask_storeu_epi8(dst, md, Encode16f4x4(src, _scale, _min, _sum, _sqsum, ms0, ms1)); + } + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + static SIMD_INLINE __m128i Encode16f5x2(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) + { + __m512i i0 = Encode16f(src, scale, min, sum, sqsum, mask); + __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), Avx2::E5_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, Avx2::E5_SHFL0), _mm256_shuffle_epi8(s0, Avx2::E5_SHFL1)); + return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); + } + + static SIMD_INLINE __m256i Encode16f5x4(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) + { + __m512i i0 = Encode16f(src + 0 * F, scale, min, sum, sqsum); + __m512i i1 = Encode16f(src + 1 * F, scale, min, sum, sqsum); + __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(EX_PERM, _mm512_packus_epi32(i0, i1)), E5_MULLO); + __m512i e0 = _mm512_or_si512(_mm512_or_si512(_mm512_shuffle_epi8(s0, E5_SHFL0), _mm512_shuffle_epi8(s0, E5_SHFL1)), _mm512_shuffle_epi8(s0, E5_SHFL2)); + return _mm256_or_si256(_mm512_castsi512_si256(e0), _mm512_extracti32x8_epi32(e0, 1)); + } + + static void Encode16f5(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t size16 = AlignLo(size, 16), size32 = AlignLo(size, 32), i = 0; + __m512 _scale = _mm512_set1_ps(scale); + __m512 _min = _mm512_set1_ps(min); + __m512i _sum = _mm512_setzero_si512(); + __m512i _sqsum = _mm512_setzero_si512(); + for (; i < size32; i += 32, src += 32, dst += 20) + _mm256_mask_storeu_epi8(dst - 6, 0x03FFFFC0, Encode16f5x4(src, _scale, _min, _sum, _sqsum)); + for (; i < size16; i += 16, src += 16, dst += 10) + _mm_mask_storeu_epi8(dst, 0x03FF, Encode16f5x2(src, _scale, _min, _sum, _sqsum)); + if (i < size) + _mm_mask_storeu_epi8(dst, 0x001F, Encode16f5x2(src, _scale, _min, _sum, _sqsum, 0x00FF)); + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + static SIMD_INLINE __m128i Encode16f6x2(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) + { + __m512i i0 = Encode16f(src, scale, min, sum, sqsum, mask); + __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), Avx2::E6_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, Avx2::E6_SHFL0), _mm256_shuffle_epi8(s0, Avx2::E6_SHFL1)); + return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); + } + + static SIMD_INLINE __m256i Encode16f6x4(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) + { + __m512i i0 = Encode16f(src + 0 * F, scale, min, sum, sqsum); + __m512i i1 = Encode16f(src + 1 * F, scale, min, sum, sqsum); + __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(EX_PERM, _mm512_packus_epi32(i0, i1)), E6_MULLO); + __m512i e0 = _mm512_or_si512(_mm512_shuffle_epi8(s0, E6_SHFL0), _mm512_shuffle_epi8(s0, E6_SHFL1)); + return _mm256_or_si256(_mm512_castsi512_si256(e0), _mm512_extracti32x8_epi32(e0, 1)); + } + + static void Encode16f6(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t size16 = AlignLo(size, 16), size32 = AlignLo(size, 32), i = 0; + __m512 _scale = _mm512_set1_ps(scale); + __m512 _min = _mm512_set1_ps(min); + __m512i _sum = _mm512_setzero_si512(); + __m512i _sqsum = _mm512_setzero_si512(); + for (; i < size32; i += 32, src += 32, dst += 24) + _mm256_mask_storeu_epi8(dst - 4, 0x0FFFFFF0, Encode16f6x4(src, _scale, _min, _sum, _sqsum)); + for (; i < size16; i += 16, src += 16, dst += 12) + _mm_mask_storeu_epi8(dst, 0x0FFF, Encode16f6x2(src, _scale, _min, _sum, _sqsum)); + if (i < size) + _mm_mask_storeu_epi8(dst, 0x003F, Encode16f6x2(src, _scale, _min, _sum, _sqsum, 0x00FF)); + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + static SIMD_INLINE __m128i Encode16f7x2(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum, __mmask16 mask = -1) + { + __m512i i0 = Encode16f(src, scale, min, sum, sqsum, mask); + __m256i s0 = _mm256_mullo_epi16(_mm512_cvtepi32_epi16(i0), Avx2::E7_MULLO); + __m256i e0 = _mm256_or_si256(_mm256_shuffle_epi8(s0, Avx2::E7_SHFL0), _mm256_shuffle_epi8(s0, Avx2::E7_SHFL1)); + return _mm_or_si128(_mm256_castsi256_si128(e0), _mm256_extracti128_si256(e0, 1)); + } + + static SIMD_INLINE __m256i Encode16f7x4(const uint16_t* src, __m512 scale, __m512 min, __m512i& sum, __m512i& sqsum) + { + __m512i i0 = Encode16f(src + 0 * F, scale, min, sum, sqsum); + __m512i i1 = Encode16f(src + 1 * F, scale, min, sum, sqsum); + __m512i s0 = _mm512_mullo_epi16(_mm512_permutexvar_epi64(EX_PERM, _mm512_packus_epi32(i0, i1)), E7_MULLO); + __m512i e0 = _mm512_or_si512(_mm512_shuffle_epi8(s0, E7_SHFL0), _mm512_shuffle_epi8(s0, E7_SHFL1)); + return _mm256_or_si256(_mm512_castsi512_si256(e0), _mm512_extracti32x8_epi32(e0, 1)); + } + + static void Encode16f7(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t size16 = AlignLo(size, 16), size32 = AlignLo(size, 32), i = 0; + __m512 _scale = _mm512_set1_ps(scale); + __m512 _min = _mm512_set1_ps(min); + __m512i _sum = _mm512_setzero_si512(); + __m512i _sqsum = _mm512_setzero_si512(); + for (; i < size32; i += 32, src += 32, dst += 28) + _mm256_mask_storeu_epi8(dst - 2, 0x3FFFFFFC, Encode16f7x4(src, _scale, _min, _sum, _sqsum)); + for (; i < size16; i += 16, src += 16, dst += 14) + _mm_mask_storeu_epi8(dst, 0x3FFF, Encode16f7x2(src, _scale, _min, _sum, _sqsum)); + if (i < size) + _mm_mask_storeu_epi8(dst, 0x007F, Encode16f7x2(src, _scale, _min, _sum, _sqsum, 0x00FF)); + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + static void Encode16f8(const uint16_t* src, float scale, float min, size_t size, int32_t& sum, int32_t& sqsum, uint8_t* dst) + { + assert(size % 8 == 0); + size_t sizeF = AlignLo(size, F), sizeA = AlignLo(size, A), i = 0; + __m512 _scale = _mm512_set1_ps(scale); + __m512 _min = _mm512_set1_ps(min); + __m512i _sum = _mm512_setzero_si512(); + __m512i _sqsum = _mm512_setzero_si512(); + for (; i < sizeA; i += A) + { + __m512i d0 = Encode16f(src + i + 0 * F, _scale, _min, _sum, _sqsum); + __m512i d1 = Encode16f(src + i + 1 * F, _scale, _min, _sum, _sqsum); + __m512i d2 = Encode16f(src + i + 2 * F, _scale, _min, _sum, _sqsum); + __m512i d3 = Encode16f(src + i + 3 * F, _scale, _min, _sum, _sqsum); + _mm512_storeu_si512((__m512i*)(dst + i), PackI16ToU8(PackI32ToI16(d0, d1), PackI32ToI16(d2, d3))); + } + for (; i < sizeF; i += F) + { + __m512i d0 = Encode16f(src + i, _scale, _min, _sum, _sqsum); + _mm_storeu_si128((__m128i*)(dst + i), _mm512_castsi512_si128(PackI16ToU8(PackI32ToI16(d0)))); + } + if (i < size) + { + __m512i d0 = Encode16f(src + i, _scale, _min, _sum, _sqsum, 0xFF); + _mm_mask_storeu_epi8(dst + i, 0xFF, _mm512_castsi512_si128(PackI16ToU8(PackI32ToI16(d0)))); + } + sum = ExtractSum(_sum); + sqsum = ExtractSum(_sqsum); + } + + //------------------------------------------------------------------------------------------------- + + Base::DescrInt::Encode32fPtr GetEncode32f(size_t depth) + { + switch (depth) + { + case 4: return Encode32f4; + case 5: return Encode32f5; + case 6: return Encode32f6; + case 7: return Encode32f7; + case 8: return Encode32f8; + default: assert(0); return NULL; + } + } + + Base::DescrInt::Encode16fPtr GetEncode16f(size_t depth) + { + switch (depth) + { + case 4: return Encode16f4; + case 5: return Encode16f5; + case 6: return Encode16f6; + case 7: return Encode16f7; + case 8: return Encode16f8; + default: assert(0); return NULL; + } + } + } +#endif +} diff --git a/src/Simd/SimdDescrInt.h b/src/Simd/SimdDescrInt.h index 554bc4ce78..40444bf96b 100644 --- a/src/Simd/SimdDescrInt.h +++ b/src/Simd/SimdDescrInt.h @@ -170,6 +170,20 @@ namespace Simd //------------------------------------------------------------------------------------------------- + Base::DescrInt::Encode32fPtr GetEncode32f(size_t depth); + Base::DescrInt::Encode16fPtr GetEncode16f(size_t depth); + + Base::DescrInt::Decode32fPtr GetDecode32f(size_t depth); + Base::DescrInt::Decode16fPtr GetDecode16f(size_t depth); + + Base::DescrInt::CosineDistancePtr GetCosineDistance(size_t depth); + Sse41::DescrInt::MacroCosineDistancesDirectPtr GetMacroCosineDistancesDirect(size_t depth); + + Sse41::DescrInt::UnpackDataPtr GetUnpackData(size_t depth, bool transpose); + Sse41::DescrInt::MacroCosineDistancesUnpackPtr GetMacroCosineDistancesUnpack(size_t depth); + + //------------------------------------------------------------------------------------------------- + void* DescrIntInit(size_t size, size_t depth); } #endif diff --git a/src/Simd/SimdDescrIntCommon.h b/src/Simd/SimdDescrIntCommon.h index 8a78019204..0b7a99e7f6 100644 --- a/src/Simd/SimdDescrIntCommon.h +++ b/src/Simd/SimdDescrIntCommon.h @@ -257,6 +257,8 @@ namespace Simd _mm256_min_ps(_mm256_max_ps(_mm256_sub_ps(_mm256_set1_ps(1.0f), _mm256_div_ps(ab, _mm256_mul_ps(aNorm, bNorm))), _mm256_setzero_ps()), _mm256_set1_ps(2.0f))); } + //------------------------------------------------------------------------------------------------- + SIMD_INLINE void DecodeCosineDistances1xF(const float* a, const float* b, size_t stride, __m256i abSum, float* distances) { __m256 aScale = _mm256_set1_ps(a[0]); @@ -339,6 +341,10 @@ namespace Simd -1, -1, -1, -1, -1, -1, -1, -1, -1, 0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, -1, -1, -1, -1, -1, -1, -1, 0x2, 0x4, 0x6, 0x8, 0xA, 0xC, 0xE, -1, -1); + const __m512i C4_MULLO = SIMD_MM512_SETR_EPI16( + 4096, 256, 4096, 256, 4096, 256, 4096, 256, 4096, 256, 4096, 256, 4096, 256, 4096, 256, + 4096, 256, 4096, 256, 4096, 256, 4096, 256, 4096, 256, 4096, 256, 4096, 256, 4096, 256); + const __m512i C5_PERM = SIMD_MM512_SETR_EPI32( 0x0, 0x1, 0x0, 0x0, 0x1, 0x2, 0x0, 0x0, 0x2, 0x3, 0x0, 0x0, 0x3, 0x4, 0x0, 0x0); const __m512i C5_SHFL = SIMD_MM512_SETR_EPI8( @@ -371,6 +377,24 @@ namespace Simd const __m512i C7_MULLO = SIMD_MM512_SETR_EPI16( 2, 4, 8, 16, 32, 64, 128, 256, 2, 4, 8, 16, 32, 64, 128, 256, 2, 4, 8, 16, 32, 64, 128, 256, 2, 4, 8, 16, 32, 64, 128, 256); + + //------------------------------------------------------------------------------------------------- + + SIMD_INLINE void DecodeCosineDistances1xF(const float* a, const float* b, size_t stride, __m512i abSum, float* distances, __mmask16 mask = -1) + { + __m512 aScale = _mm512_set1_ps(a[0]); + __m512 aShift = _mm512_set1_ps(a[1]); + __m512 aMean = _mm512_set1_ps(a[2]); + __m512 aNorm = _mm512_set1_ps(a[3]); + __m512 bScale = _mm512_maskz_loadu_ps(mask, b + 0 * stride); + __m512 bShift = _mm512_maskz_loadu_ps(mask, b + 1 * stride); + __m512 bMean = _mm512_maskz_loadu_ps(mask, b + 2 * stride); + __m512 bNorm = _mm512_maskz_loadu_ps(mask, b + 3 * stride); + __m512 ab = _mm512_mul_ps(_mm512_cvtepi32_ps(abSum), _mm512_mul_ps(aScale, bScale)); + ab = _mm512_add_ps(_mm512_mul_ps(aMean, bShift), ab); + ab = _mm512_add_ps(_mm512_mul_ps(bMean, aShift), ab); + _mm512_mask_storeu_ps(distances, mask, _mm512_min_ps(_mm512_max_ps(_mm512_sub_ps(_mm512_set1_ps(1.0f), _mm512_div_ps(ab, _mm512_mul_ps(aNorm, bNorm))), _mm512_setzero_ps()), _mm512_set1_ps(2.0f))); + } } #endif } diff --git a/src/Simd/SimdSse41DescrInt.cpp b/src/Simd/SimdSse41DescrInt.cpp index 5ea215b91a..9fc1ca23d4 100644 --- a/src/Simd/SimdSse41DescrInt.cpp +++ b/src/Simd/SimdSse41DescrInt.cpp @@ -176,10 +176,9 @@ namespace Simd { size_t macroM = AlignLoAny(Base::AlgCacheL2() / _unpSize, _microMu); size_t macroN = AlignLoAny(Base::AlgCacheL3() / _unpSize, _microNu); - Array8u dA(Min(macroM, M) * _unpSize); - Array8u dB(Min(macroN, N) * _unpSize); - Array32f nA(Min(macroM, M) * 4); - Array32f nB(AlignHi(Min(macroN, N), _microNu) * 4); + size_t sizeA = Min(macroM, M), sizeB = AlignHi(Min(macroN, N), _microNu); + Array8u dA(sizeA * _unpSize), dB(sizeB * _unpSize); + Array32f nA(sizeA * 4), nB(sizeB * 4); for (size_t i = 0; i < M; i += macroM) { size_t dM = Simd::Min(M, i + macroM) - i; From 5c67d27d477f91ee07fa8a0419e36a73d5905ac3 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Wed, 28 Jun 2023 17:58:56 +0300 Subject: [PATCH 34/44] *fix bugs in DescrInt. --- src/Simd/SimdAvx2DescrInt.cpp | 8 ++++---- src/Simd/SimdAvx2DescrIntCdu.cpp | 30 ++++++++++++++-------------- src/Simd/SimdAvx512bwDescrInt.cpp | 8 ++++---- src/Simd/SimdAvx512bwDescrIntCdu.cpp | 9 +++++---- src/Simd/SimdSse41DescrInt.cpp | 8 ++++---- src/Simd/SimdSse41DescrIntCdu.cpp | 10 +++++----- src/Test/TestDescrInt.cpp | 1 + 7 files changed, 38 insertions(+), 36 deletions(-) diff --git a/src/Simd/SimdAvx2DescrInt.cpp b/src/Simd/SimdAvx2DescrInt.cpp index 0c37f672d7..77560f974a 100644 --- a/src/Simd/SimdAvx2DescrInt.cpp +++ b/src/Simd/SimdAvx2DescrInt.cpp @@ -118,10 +118,10 @@ namespace Simd } for (; i < count; i++, src++, dst++) { - dst[0 * stride] = ((float*)src)[0]; - dst[1 * stride] = ((float*)src)[1]; - dst[2 * stride] = ((float*)src)[2]; - dst[3 * stride] = ((float*)src)[3]; + dst[0 * stride] = ((float*)src[0])[0]; + dst[1 * stride] = ((float*)src[0])[1]; + dst[2 * stride] = ((float*)src[0])[2]; + dst[3 * stride] = ((float*)src[0])[3]; } } diff --git a/src/Simd/SimdAvx2DescrIntCdu.cpp b/src/Simd/SimdAvx2DescrIntCdu.cpp index 4bfa8fb776..c2588f15ba 100644 --- a/src/Simd/SimdAvx2DescrIntCdu.cpp +++ b/src/Simd/SimdAvx2DescrIntCdu.cpp @@ -186,9 +186,9 @@ namespace Simd for (; j < size; j += 8, o += bits, dst += 8 * Sse41::A) { UnpackDataBx4x8(src + 0, o, dst + 0 * Sse41::A); - UnpackDataBx4x8(src + 2, o, dst + 1 * Sse41::A); - UnpackDataBx4x8(src + 4, o, dst + 2 * Sse41::A); - UnpackDataBx4x8(src + 6, o, dst + 3 * Sse41::A); + UnpackDataBx4x8(src + 4, o, dst + 1 * Sse41::A); + UnpackDataBx4x8(src + 8, o, dst + 2 * Sse41::A); + UnpackDataBx4x8(src + 12, o, dst + 3 * Sse41::A); } } if (i < count) @@ -198,24 +198,24 @@ namespace Simd _src[j] = i < count ? *src++ : src[-1]; for (j = 0, o = 16; j < size32; j += 32, o += 4 * bits, dst += 16 * A) { - UnpackDataBx4x32(src + 0, o, dst + 0 * Sse41::A); - UnpackDataBx4x32(src + 4, o, dst + 1 * Sse41::A); - UnpackDataBx4x32(src + 8, o, dst + 2 * Sse41::A); - UnpackDataBx4x32(src + 12, o, dst + 3 * Sse41::A); + UnpackDataBx4x32(_src + 0, o, dst + 0 * Sse41::A); + UnpackDataBx4x32(_src + 4, o, dst + 1 * Sse41::A); + UnpackDataBx4x32(_src + 8, o, dst + 2 * Sse41::A); + UnpackDataBx4x32(_src + 12, o, dst + 3 * Sse41::A); } for (; j < size16; j += 16, o += 2 * bits, dst += 16 * Sse41::A) { - UnpackDataBx4x16(src + 0, o, dst + 0 * Sse41::A); - UnpackDataBx4x16(src + 4, o, dst + 1 * Sse41::A); - UnpackDataBx4x16(src + 8, o, dst + 2 * Sse41::A); - UnpackDataBx4x16(src + 12, o, dst + 3 * Sse41::A); + UnpackDataBx4x16(_src + 0, o, dst + 0 * Sse41::A); + UnpackDataBx4x16(_src + 4, o, dst + 1 * Sse41::A); + UnpackDataBx4x16(_src + 8, o, dst + 2 * Sse41::A); + UnpackDataBx4x16(_src + 12, o, dst + 3 * Sse41::A); } for (; j < size; j += 8, o += bits, dst += 8 * Sse41::A) { - UnpackDataBx4x8(src + 0, o, dst + 0 * Sse41::A); - UnpackDataBx4x8(src + 2, o, dst + 1 * Sse41::A); - UnpackDataBx4x8(src + 4, o, dst + 2 * Sse41::A); - UnpackDataBx4x8(src + 6, o, dst + 3 * Sse41::A); + UnpackDataBx4x8(_src + 0, o, dst + 0 * Sse41::A); + UnpackDataBx4x8(_src + 4, o, dst + 1 * Sse41::A); + UnpackDataBx4x8(_src + 8, o, dst + 2 * Sse41::A); + UnpackDataBx4x8(_src + 12, o, dst + 3 * Sse41::A); } } } diff --git a/src/Simd/SimdAvx512bwDescrInt.cpp b/src/Simd/SimdAvx512bwDescrInt.cpp index 8ea93b61da..5a4147b842 100644 --- a/src/Simd/SimdAvx512bwDescrInt.cpp +++ b/src/Simd/SimdAvx512bwDescrInt.cpp @@ -146,10 +146,10 @@ namespace Simd } for (; i < count; i++, src++, dst++) { - dst[0 * stride] = ((float*)src)[0]; - dst[1 * stride] = ((float*)src)[1]; - dst[2 * stride] = ((float*)src)[2]; - dst[3 * stride] = ((float*)src)[3]; + dst[0 * stride] = ((float*)src[0])[0]; + dst[1 * stride] = ((float*)src[0])[1]; + dst[2 * stride] = ((float*)src[0])[2]; + dst[3 * stride] = ((float*)src[0])[3]; } } diff --git a/src/Simd/SimdAvx512bwDescrIntCdu.cpp b/src/Simd/SimdAvx512bwDescrIntCdu.cpp index 3b7120ebfd..069c09dea9 100644 --- a/src/Simd/SimdAvx512bwDescrIntCdu.cpp +++ b/src/Simd/SimdAvx512bwDescrIntCdu.cpp @@ -137,7 +137,7 @@ namespace Simd size_t j = 0; for (; j < size64; j += 64, ps += 8 * bits, pd += 64) _mm512_mask_storeu_epi8(pd, dstBody, UnpackData64(ps, srcBody)); - if(j < size64) + if(j < size) _mm512_mask_storeu_epi8(pd, dstTail, UnpackData64(ps, srcTail)); } } @@ -274,9 +274,9 @@ namespace Simd template void UnpackDataB(size_t count, const uint8_t* const* src, size_t size, uint8_t* dst, size_t stride) { - size_t countDF = AlignLo(count, DF), size64 = AlignLo(size, 64), i, j, o; + size_t countDF = AlignLo(count, DF), size64 = AlignLo(size, 64), tail = size - size64, i, j, o; UnpackDataBx16xN_Ptr unpackDataMain = GetUnpackDataBx16xN(64); - UnpackDataBx16xN_Ptr unpackDataTail = GetUnpackDataBx16xN(size - size64); + UnpackDataBx16xN_Ptr unpackDataTail = GetUnpackDataBx16xN(tail); for (i = 0; i < countDF; i += DF, src += DF) { for (j = 0, o = 16; j < size64; j += 64, o += 8 * bits, dst += 32 * A) @@ -288,6 +288,7 @@ namespace Simd { unpackDataTail(src + 0, o, dst + 0); unpackDataTail(src + F, o, dst + A); + dst += tail * DF; } } if (i < count) @@ -398,7 +399,7 @@ namespace Simd if (M > 0xB) a0 = Set4(ad5 + k6), Madd4(abB0, a0, b0); bd += DA; } - __mmask16 tail = TailMask16(N - F); + __mmask16 tail = TailMask16(N); if (M > 0x0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0, tail), an += 4, distances += stride; if (M > 0x1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0, tail), an += 4, distances += stride; if (M > 0x2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0, tail), an += 4, distances += stride; diff --git a/src/Simd/SimdSse41DescrInt.cpp b/src/Simd/SimdSse41DescrInt.cpp index 9fc1ca23d4..4ffe3467f2 100644 --- a/src/Simd/SimdSse41DescrInt.cpp +++ b/src/Simd/SimdSse41DescrInt.cpp @@ -102,10 +102,10 @@ namespace Simd } for (; i < count; i++, src++, dst++) { - dst[0 * stride] = ((float*)src)[0]; - dst[1 * stride] = ((float*)src)[1]; - dst[2 * stride] = ((float*)src)[2]; - dst[3 * stride] = ((float*)src)[3]; + dst[0 * stride] = ((float*)src[0])[0]; + dst[1 * stride] = ((float*)src[0])[1]; + dst[2 * stride] = ((float*)src[0])[2]; + dst[3 * stride] = ((float*)src[0])[3]; } } diff --git a/src/Simd/SimdSse41DescrIntCdu.cpp b/src/Simd/SimdSse41DescrIntCdu.cpp index 83de89f27e..d16aaab8b9 100644 --- a/src/Simd/SimdSse41DescrIntCdu.cpp +++ b/src/Simd/SimdSse41DescrIntCdu.cpp @@ -154,7 +154,7 @@ namespace Simd for (; j < size; j += 8, o += bits, dst += 4 * A) { UnpackDataBx4x8(src + 0, o, dst + 0); - UnpackDataBx4x8(src + 2, o, dst + A); + UnpackDataBx4x8(src + 4, o, dst + A); } } if (i < count) @@ -164,13 +164,13 @@ namespace Simd _src[j] = i < count ? *src++ : src[-1]; for (j = 0, o = 16; j < size16; j += 16, o += 2 * bits, dst += 8 * A) { - UnpackDataBx4x16(src + 0, o, dst + 0); - UnpackDataBx4x16(src + 4, o, dst + A); + UnpackDataBx4x16(_src + 0, o, dst + 0); + UnpackDataBx4x16(_src + 4, o, dst + A); } for (; j < size; j += 8, o += bits, dst += 4 * A) { - UnpackDataBx4x8(src + 0, o, dst + 0); - UnpackDataBx4x8(src + 2, o, dst + A); + UnpackDataBx4x8(_src + 0, o, dst + 0); + UnpackDataBx4x8(_src + 4, o, dst + A); } } } diff --git a/src/Test/TestDescrInt.cpp b/src/Test/TestDescrInt.cpp index 4205ba77d6..5b65e282ab 100644 --- a/src/Test/TestDescrInt.cpp +++ b/src/Test/TestDescrInt.cpp @@ -582,6 +582,7 @@ namespace Test for (size_t depth = 4; depth <= 8; depth++) { + //result = result && DescrIntCosineDistancesMxNaAutoTest(127, 129, 520, depth, f1, f2); result = result && DescrIntCosineDistancesMxNaAutoTest(256, 128, 256, depth, f1, f2); result = result && DescrIntCosineDistancesMxNaAutoTest(128, 128, 512, depth, f1, f2); } From 6057373db5853d2b0199a1a7e9ec683f6289625e Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Wed, 28 Jun 2023 23:19:06 +0300 Subject: [PATCH 35/44] +add AVX-512VNNI optimizations of functions DescrIntCosineDistancesMxNp, DescrIntCosineDistancesMxNa. --- .github/workflows/cmake.yml | 2 +- docs/2023.html | 4 +- prj/vs2019/Avx512vnni.vcxproj | 6 + prj/vs2019/Avx512vnni.vcxproj.filters | 18 +++ prj/vs2022/Avx512vnni.vcxproj | 6 + prj/vs2022/Avx512vnni.vcxproj.filters | 18 +++ src/Simd/SimdAvx512vnniDescrInt.cpp | 59 ++++++++ src/Simd/SimdAvx512vnniDescrIntCdu.cpp | 191 +++++++++++++++++++++++++ src/Simd/SimdDescrInt.h | 19 +++ src/Simd/SimdLib.cpp | 2 +- src/Test/TestDescrInt.cpp | 10 ++ 11 files changed, 331 insertions(+), 4 deletions(-) create mode 100644 src/Simd/SimdAvx512vnniDescrInt.cpp create mode 100644 src/Simd/SimdAvx512vnniDescrIntCdu.cpp diff --git a/.github/workflows/cmake.yml b/.github/workflows/cmake.yml index cf06c53e07..74bcfa6b62 100644 --- a/.github/workflows/cmake.yml +++ b/.github/workflows/cmake.yml @@ -21,7 +21,7 @@ jobs: run: lscpu - name: Configure CMake - run: cmake ./prj/cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{matrix.build_type}} -DSIMD_TEST_FLAGS="-mavx2" + run: cmake ./prj/cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{matrix.build_type}} -DSIMD_AVX512VNNI=ON -DSIMD_TEST_FLAGS="-mavx2" - name: Build run: cmake --build ${{github.workspace}}/build --config ${{matrix.build_type}} --parallel$(nproc) diff --git a/docs/2023.html b/docs/2023.html index 9563fe0c7b..b8708c65ef 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -44,8 +44,8 @@
        New features
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntDecode32f.
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntDecode16f.
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistance.
      • -
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistancesMxNp.
      • -
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function DescrIntCosineDistancesMxNa.
      • +
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW, AVX-512VNNI optimizations of function DescrIntCosineDistancesMxNp.
      • +
      • Support of 4-bit and 5-bit depth in Base implementation, SSE4.1, AVX2, AVX-512BW, AVX-512VNNI optimizations of function DescrIntCosineDistancesMxNa.
      • Base implementation, SSE4.1, AVX2, AVX-512BW optimizations of function SynetNormalizeLayerForwardV3.
      Improving
      diff --git a/prj/vs2019/Avx512vnni.vcxproj b/prj/vs2019/Avx512vnni.vcxproj index 9a9b9be900..7df71eec09 100644 --- a/prj/vs2019/Avx512vnni.vcxproj +++ b/prj/vs2019/Avx512vnni.vcxproj @@ -18,7 +18,10 @@ + + + @@ -32,6 +35,7 @@ + @@ -43,6 +47,8 @@ + + diff --git a/prj/vs2019/Avx512vnni.vcxproj.filters b/prj/vs2019/Avx512vnni.vcxproj.filters index d4c2d4a11a..e8424fccb0 100644 --- a/prj/vs2019/Avx512vnni.vcxproj.filters +++ b/prj/vs2019/Avx512vnni.vcxproj.filters @@ -99,6 +99,18 @@ Inc + + Inc + + + Inc + + + Inc + + + Inc + @@ -125,5 +137,11 @@ Avx512vnni + + Avx512vnni + + + Avx512vnni + \ No newline at end of file diff --git a/prj/vs2022/Avx512vnni.vcxproj b/prj/vs2022/Avx512vnni.vcxproj index 9a9b9be900..7df71eec09 100644 --- a/prj/vs2022/Avx512vnni.vcxproj +++ b/prj/vs2022/Avx512vnni.vcxproj @@ -18,7 +18,10 @@ + + + @@ -32,6 +35,7 @@ + @@ -43,6 +47,8 @@ + + diff --git a/prj/vs2022/Avx512vnni.vcxproj.filters b/prj/vs2022/Avx512vnni.vcxproj.filters index d4c2d4a11a..e8424fccb0 100644 --- a/prj/vs2022/Avx512vnni.vcxproj.filters +++ b/prj/vs2022/Avx512vnni.vcxproj.filters @@ -99,6 +99,18 @@ Inc + + Inc + + + Inc + + + Inc + + + Inc + @@ -125,5 +137,11 @@ Avx512vnni + + Avx512vnni + + + Avx512vnni + \ No newline at end of file diff --git a/src/Simd/SimdAvx512vnniDescrInt.cpp b/src/Simd/SimdAvx512vnniDescrInt.cpp new file mode 100644 index 0000000000..f3fd1897d7 --- /dev/null +++ b/src/Simd/SimdAvx512vnniDescrInt.cpp @@ -0,0 +1,59 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2023 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#include "Simd/SimdMemory.h" +#include "Simd/SimdStore.h" +#include "Simd/SimdExtract.h" +#include "Simd/SimdArray.h" +#include "Simd/SimdUnpack.h" +#include "Simd/SimdDescrInt.h" +#include "Simd/SimdDescrIntCommon.h" +#include "Simd/SimdCpu.h" + +namespace Simd +{ +#ifdef SIMD_AVX512VNNI_ENABLE + namespace Avx512vnni + { + DescrInt::DescrInt(size_t size, size_t depth) + : Avx512bw::DescrInt(size, depth) + { + if (_depth != 8) + { + _macroCosineDistancesUnpack = GetMacroCosineDistancesUnpack(_depth); + _microMu = 12; + _microNu = 32; + } + } + + //------------------------------------------------------------------------------------------------- + + void* DescrIntInit(size_t size, size_t depth) + { + if (!Base::DescrInt::Valid(size, depth)) + return NULL; + return new Avx512vnni::DescrInt(size, depth); + } + } +#endif +} diff --git a/src/Simd/SimdAvx512vnniDescrIntCdu.cpp b/src/Simd/SimdAvx512vnniDescrIntCdu.cpp new file mode 100644 index 0000000000..76aad377a4 --- /dev/null +++ b/src/Simd/SimdAvx512vnniDescrIntCdu.cpp @@ -0,0 +1,191 @@ +/* +* Simd Library (http://ermig1979.github.io/Simd). +* +* Copyright (c) 2011-2023 Yermalayeu Ihar. +* +* Permission is hereby granted, free of charge, to any person obtaining a copy +* of this software and associated documentation files (the "Software"), to deal +* in the Software without restriction, including without limitation the rights +* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +* copies of the Software, and to permit persons to whom the Software is +* furnished to do so, subject to the following conditions: +* +* The above copyright notice and this permission notice shall be included in +* all copies or substantial portions of the Software. +* +* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +* SOFTWARE. +*/ +#include "Simd/SimdMemory.h" +#include "Simd/SimdStore.h" +#include "Simd/SimdExtract.h" +#include "Simd/SimdArray.h" +#include "Simd/SimdUnpack.h" +#include "Simd/SimdDescrInt.h" +#include "Simd/SimdDescrIntCommon.h" +#include "Simd/SimdCpu.h" +#include "Simd/SimdSynet.h" + +namespace Simd +{ +#ifdef SIMD_AVX512VNNI_ENABLE + namespace Avx512vnni + { + template void Correlation8_2xM(size_t N, size_t K, const uint8_t* ad0, const uint8_t* bd, const float* an, const float* bn, size_t bnStride, float* distances, size_t stride) + { + __m512i ab00, ab01, ab10, ab11, ab20, ab21, ab30, ab31, ab40, ab41, ab50, ab51, ab60, ab61, ab70, ab71, ab80, ab81, ab90, ab91, abA0, abA1, abB0, abB1, a0, b0, b1; + const uint8_t* ad1 = ad0 + 1 * K; + const uint8_t* ad2 = ad0 + 2 * K; + const uint8_t* ad3 = ad0 + 3 * K; + const uint8_t* ad4 = ad0 + 4 * K; + const uint8_t* ad5 = ad0 + 5 * K; + if (N > F) + { + if (M > 0x0) ab00 = _mm512_setzero_si512(), ab01 = _mm512_setzero_si512(); + if (M > 0x1) ab10 = _mm512_setzero_si512(), ab11 = _mm512_setzero_si512(); + if (M > 0x2) ab20 = _mm512_setzero_si512(), ab21 = _mm512_setzero_si512(); + if (M > 0x3) ab30 = _mm512_setzero_si512(), ab31 = _mm512_setzero_si512(); + if (M > 0x4) ab40 = _mm512_setzero_si512(), ab41 = _mm512_setzero_si512(); + if (M > 0x5) ab50 = _mm512_setzero_si512(), ab51 = _mm512_setzero_si512(); + if (M > 0x6) ab60 = _mm512_setzero_si512(), ab61 = _mm512_setzero_si512(); + if (M > 0x7) ab70 = _mm512_setzero_si512(), ab71 = _mm512_setzero_si512(); + if (M > 0x8) ab80 = _mm512_setzero_si512(), ab81 = _mm512_setzero_si512(); + if (M > 0x9) ab90 = _mm512_setzero_si512(), ab91 = _mm512_setzero_si512(); + if (M > 0xA) abA0 = _mm512_setzero_si512(), abA1 = _mm512_setzero_si512(); + if (M > 0xB) abB0 = _mm512_setzero_si512(), abB1 = _mm512_setzero_si512(); + for (size_t k0 = 0, k6 = K * 6; k0 < K; k0 += 4, k6 += 4) + { + b0 = _mm512_loadu_si512((__m512i*)bd + 0); + b1 = _mm512_loadu_si512((__m512i*)bd + 1); + if (M > 0x0) a0 = Set4(ad0 + k0), Madd4(ab00, a0, b0), Madd4(ab01, a0, b1); + if (M > 0x1) a0 = Set4(ad1 + k0), Madd4(ab10, a0, b0), Madd4(ab11, a0, b1); + if (M > 0x2) a0 = Set4(ad2 + k0), Madd4(ab20, a0, b0), Madd4(ab21, a0, b1); + if (M > 0x3) a0 = Set4(ad3 + k0), Madd4(ab30, a0, b0), Madd4(ab31, a0, b1); + if (M > 0x4) a0 = Set4(ad4 + k0), Madd4(ab40, a0, b0), Madd4(ab41, a0, b1); + if (M > 0x5) a0 = Set4(ad5 + k0), Madd4(ab50, a0, b0), Madd4(ab51, a0, b1); + if (M > 0x6) a0 = Set4(ad0 + k6), Madd4(ab60, a0, b0), Madd4(ab61, a0, b1); + if (M > 0x7) a0 = Set4(ad1 + k6), Madd4(ab70, a0, b0), Madd4(ab71, a0, b1); + if (M > 0x8) a0 = Set4(ad2 + k6), Madd4(ab80, a0, b0), Madd4(ab81, a0, b1); + if (M > 0x9) a0 = Set4(ad3 + k6), Madd4(ab90, a0, b0), Madd4(ab91, a0, b1); + if (M > 0xA) a0 = Set4(ad4 + k6), Madd4(abA0, a0, b0), Madd4(abA1, a0, b1); + if (M > 0xB) a0 = Set4(ad5 + k6), Madd4(abB0, a0, b0), Madd4(abB1, a0, b1); + bd += DA; + } + __mmask16 tail = TailMask16(N - F); + if (M > 0x0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab01, distances + F, tail), an += 4, distances += stride; + if (M > 0x1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab11, distances + F, tail), an += 4, distances += stride; + if (M > 0x2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab21, distances + F, tail), an += 4, distances += stride; + if (M > 0x3) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab30, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab31, distances + F, tail), an += 4, distances += stride; + if (M > 0x4) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab40, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab41, distances + F, tail), an += 4, distances += stride; + if (M > 0x5) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab50, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab51, distances + F, tail), an += 4, distances += stride; + if (M > 0x6) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab60, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab61, distances + F, tail), an += 4, distances += stride; + if (M > 0x7) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab70, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab71, distances + F, tail), an += 4, distances += stride; + if (M > 0x8) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab80, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab81, distances + F, tail), an += 4, distances += stride; + if (M > 0x9) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab90, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, ab91, distances + F, tail), an += 4, distances += stride; + if (M > 0xA) DecodeCosineDistances1xF(an, bn + 0, bnStride, abA0, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, abA1, distances + F, tail), an += 4, distances += stride; + if (M > 0xB) DecodeCosineDistances1xF(an, bn + 0, bnStride, abB0, distances + 0), DecodeCosineDistances1xF(an, bn + F, bnStride, abB1, distances + F, tail), an += 4, distances += stride; + } + else + { + if (M > 0x0) ab00 = _mm512_setzero_si512(); + if (M > 0x1) ab10 = _mm512_setzero_si512(); + if (M > 0x2) ab20 = _mm512_setzero_si512(); + if (M > 0x3) ab30 = _mm512_setzero_si512(); + if (M > 0x4) ab40 = _mm512_setzero_si512(); + if (M > 0x5) ab50 = _mm512_setzero_si512(); + if (M > 0x6) ab60 = _mm512_setzero_si512(); + if (M > 0x7) ab70 = _mm512_setzero_si512(); + if (M > 0x8) ab80 = _mm512_setzero_si512(); + if (M > 0x9) ab90 = _mm512_setzero_si512(); + if (M > 0xA) abA0 = _mm512_setzero_si512(); + if (M > 0xB) abB0 = _mm512_setzero_si512(); + for (size_t k0 = 0, k6 = K * 6; k0 < K; k0 += 4, k6 += 4) + { + b0 = _mm512_loadu_si512((__m512i*)bd + 0); + if (M > 0x0) a0 = Set4(ad0 + k0), Madd4(ab00, a0, b0); + if (M > 0x1) a0 = Set4(ad1 + k0), Madd4(ab10, a0, b0); + if (M > 0x2) a0 = Set4(ad2 + k0), Madd4(ab20, a0, b0); + if (M > 0x3) a0 = Set4(ad3 + k0), Madd4(ab30, a0, b0); + if (M > 0x4) a0 = Set4(ad4 + k0), Madd4(ab40, a0, b0); + if (M > 0x5) a0 = Set4(ad5 + k0), Madd4(ab50, a0, b0); + if (M > 0x6) a0 = Set4(ad0 + k6), Madd4(ab60, a0, b0); + if (M > 0x7) a0 = Set4(ad1 + k6), Madd4(ab70, a0, b0); + if (M > 0x8) a0 = Set4(ad2 + k6), Madd4(ab80, a0, b0); + if (M > 0x9) a0 = Set4(ad3 + k6), Madd4(ab90, a0, b0); + if (M > 0xA) a0 = Set4(ad4 + k6), Madd4(abA0, a0, b0); + if (M > 0xB) a0 = Set4(ad5 + k6), Madd4(abB0, a0, b0); + bd += DA; + } + __mmask16 tail = TailMask16(N); + if (M > 0x0) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab00, distances + 0, tail), an += 4, distances += stride; + if (M > 0x1) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab10, distances + 0, tail), an += 4, distances += stride; + if (M > 0x2) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab20, distances + 0, tail), an += 4, distances += stride; + if (M > 0x3) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab30, distances + 0, tail), an += 4, distances += stride; + if (M > 0x4) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab40, distances + 0, tail), an += 4, distances += stride; + if (M > 0x5) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab50, distances + 0, tail), an += 4, distances += stride; + if (M > 0x6) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab60, distances + 0, tail), an += 4, distances += stride; + if (M > 0x7) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab70, distances + 0, tail), an += 4, distances += stride; + if (M > 0x8) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab80, distances + 0, tail), an += 4, distances += stride; + if (M > 0x9) DecodeCosineDistances1xF(an, bn + 0, bnStride, ab90, distances + 0, tail), an += 4, distances += stride; + if (M > 0xA) DecodeCosineDistances1xF(an, bn + 0, bnStride, abA0, distances + 0, tail), an += 4, distances += stride; + if (M > 0xB) DecodeCosineDistances1xF(an, bn + 0, bnStride, abB0, distances + 0, tail), an += 4, distances += stride; + } + } + + typedef void(*Correlation8_2xM_Ptr)(size_t N, size_t K, const uint8_t* ad0, const uint8_t* bd, const float* an, const float* bn, size_t bnStride, float* distances, size_t stride); + + SIMD_INLINE Correlation8_2xM_Ptr GetCorrelation8_2xM(size_t M) + { + switch (M) + { + case 0x0: return NULL; + case 0x1: return Correlation8_2xM<0x1>; + case 0x2: return Correlation8_2xM<0x2>; + case 0x3: return Correlation8_2xM<0x3>; + case 0x4: return Correlation8_2xM<0x4>; + case 0x5: return Correlation8_2xM<0x5>; + case 0x6: return Correlation8_2xM<0x6>; + case 0x7: return Correlation8_2xM<0x7>; + case 0x8: return Correlation8_2xM<0x8>; + case 0x9: return Correlation8_2xM<0x9>; + case 0xA: return Correlation8_2xM<0xA>; + case 0xB: return Correlation8_2xM<0xB>; + case 0xC: return Correlation8_2xM<0xC>; + } + assert(0); + return NULL; + } + + void MacroCorrelation8(size_t M, size_t N, size_t K, const uint8_t* ad, const float* an, const uint8_t* bd, const float* bn, float* distances, size_t stride) + { + size_t M12 = AlignLoAny(M, 12); + Correlation8_2xM_Ptr correlation_2x12 = GetCorrelation8_2xM(12); + Correlation8_2xM_Ptr correlation_2xT = GetCorrelation8_2xM(M - M12); + for (size_t j = 0; j < N; j += DF) + { + size_t dN = Simd::Min(DF, N - j); + size_t i = 0; + for (; i < M12; i += 12) + correlation_2x12(dN, K, ad + i * K, bd, an + i * 4, bn, N, distances + i * stride, stride); + if (i < M) + correlation_2xT(dN, K, ad + i * K, bd, an + i * 4, bn, N, distances + i * stride, stride); + bd += K * DF; + bn += DF; + distances += DF; + } + } + + //------------------------------------------------------------------------------------------------- + + Sse41::DescrInt::MacroCosineDistancesUnpackPtr GetMacroCosineDistancesUnpack(size_t depth) + { + return depth == 8 ? NULL : MacroCorrelation8; + } + } +#endif +} diff --git a/src/Simd/SimdDescrInt.h b/src/Simd/SimdDescrInt.h index 40444bf96b..1065b874fd 100644 --- a/src/Simd/SimdDescrInt.h +++ b/src/Simd/SimdDescrInt.h @@ -187,5 +187,24 @@ namespace Simd void* DescrIntInit(size_t size, size_t depth); } #endif + +#ifdef SIMD_AVX512VNNI_ENABLE + namespace Avx512vnni + { + class DescrInt : public Avx512bw::DescrInt + { + public: + DescrInt(size_t size, size_t depth); + }; + + //------------------------------------------------------------------------------------------------- + + Sse41::DescrInt::MacroCosineDistancesUnpackPtr GetMacroCosineDistancesUnpack(size_t depth); + + //------------------------------------------------------------------------------------------------- + + void* DescrIntInit(size_t size, size_t depth); + } +#endif } #endif//__SimdDescrInt_h__ diff --git a/src/Simd/SimdLib.cpp b/src/Simd/SimdLib.cpp index ad5de4e830..d7d0dab8a1 100644 --- a/src/Simd/SimdLib.cpp +++ b/src/Simd/SimdLib.cpp @@ -1922,7 +1922,7 @@ SIMD_API void* SimdDescrIntInit(size_t size, size_t depth) { SIMD_EMPTY(); typedef void* (*SimdDescrIntInitPtr) (size_t size, size_t depth); - const static SimdDescrIntInitPtr simdDescrIntInit = SIMD_FUNC3(DescrIntInit, SIMD_AVX512BW_FUNC, SIMD_AVX2_FUNC, SIMD_SSE41_FUNC);// , SIMD_NEON_FUNC); + const static SimdDescrIntInitPtr simdDescrIntInit = SIMD_FUNC4(DescrIntInit, SIMD_AVX512VNNI_FUNC, SIMD_AVX512BW_FUNC, SIMD_AVX2_FUNC, SIMD_SSE41_FUNC);// , SIMD_NEON_FUNC); return simdDescrIntInit(size, depth); } diff --git a/src/Test/TestDescrInt.cpp b/src/Test/TestDescrInt.cpp index 5b65e282ab..311a335845 100644 --- a/src/Test/TestDescrInt.cpp +++ b/src/Test/TestDescrInt.cpp @@ -610,6 +610,11 @@ namespace Test if (Simd::Avx512bw::Enable) result = result && DescrIntCosineDistancesMxNaAutoTest(FUNC_DI(Simd::Avx512bw::DescrIntInit), FUNC_DI(SimdDescrIntInit)); #endif + +#ifdef SIMD_AVX512VNNI_ENABLE + if (Simd::Avx512vnni::Enable) + result = result && DescrIntCosineDistancesMxNaAutoTest(FUNC_DI(Simd::Avx512vnni::DescrIntInit), FUNC_DI(SimdDescrIntInit)); +#endif //#if defined(SIMD_NEON_ENABLE) // if (Simd::Neon::Enable) @@ -684,6 +689,11 @@ namespace Test result = result && DescrIntCosineDistancesMxNpAutoTest(FUNC_DI(Simd::Avx512bw::DescrIntInit), FUNC_DI(SimdDescrIntInit)); #endif +#ifdef SIMD_AVX512VNNI_ENABLE + if (Simd::Avx512vnni::Enable) + result = result && DescrIntCosineDistancesMxNpAutoTest(FUNC_DI(Simd::Avx512vnni::DescrIntInit), FUNC_DI(SimdDescrIntInit)); +#endif + //#if defined(SIMD_NEON_ENABLE) // if (Simd::Neon::Enable) // result = result && DescrIntCosineDistancesMxNpAutoTest(FUNC_DI(Simd::Neon::DescrIntInit), FUNC_DI(SimdDescrIntInit)); From 05260e243b375bcece9fb6ecf5fcaf40ef7fd635 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Thu, 29 Jun 2023 11:26:45 +0300 Subject: [PATCH 36/44] *improve WIN32 exception handling, Host Properties step in Github actions script for MSBuild; fix bugs in DescrIntCosineDistancesMxNaAutoTest and DescrIntCosineDistancesMxNpAutoTest tests. --- .github/workflows/msbuild.yml | 8 ++++---- docs/2023.html | 10 ++++++++++ src/Test/Test.cpp | 23 +++++++++++++++++++++-- src/Test/TestDescrInt.cpp | 4 ++-- 4 files changed, 37 insertions(+), 8 deletions(-) diff --git a/.github/workflows/msbuild.yml b/.github/workflows/msbuild.yml index 9168cdce5c..fa329770ff 100644 --- a/.github/workflows/msbuild.yml +++ b/.github/workflows/msbuild.yml @@ -29,7 +29,7 @@ jobs: run: nuget restore ./prj/vs2022 -Project2ProjectTimeOut 300 - name: Host properties - run: wmic cpu get /format:value + run: wmic cpu get Name,NumberOfCores,NumberOfLogicalProcessors /format:value - name: Build working-directory: ${{env.GITHUB_WORKSPACE}} @@ -58,7 +58,7 @@ jobs: run: nuget restore ./prj/vs2019 -Project2ProjectTimeOut 300 - name: Host properties - run: wmic cpu get /format:value + run: wmic cpu get Name,NumberOfCores,NumberOfLogicalProcessors /format:value - name: Build working-directory: ${{env.GITHUB_WORKSPACE}} @@ -93,7 +93,7 @@ jobs: Start-Process -Wait sdksetup.exe -ArgumentList "/q", "/norestart", "/features", "OptionId.WindowsDesktopSoftwareDevelopmentKit", "OptionId.NetFxSoftwareDevelopmentKit" - name: Host properties - run: wmic cpu get /format:value + run: wmic cpu get Name,NumberOfCores,NumberOfLogicalProcessors /format:value - name: Build working-directory: ${{env.GITHUB_WORKSPACE}} @@ -129,7 +129,7 @@ jobs: Start-Process -Wait sdksetup.exe -ArgumentList "/q", "/norestart", "/features", "OptionId.WindowsDesktopSoftwareDevelopmentKit", "OptionId.NetFxSoftwareDevelopmentKit" - name: Host properties - run: wmic cpu get /format:value + run: wmic cpu get Name,NumberOfCores,NumberOfLogicalProcessors /format:value - name: Build working-directory: ${{env.GITHUB_WORKSPACE}} diff --git a/docs/2023.html b/docs/2023.html index b8708c65ef..8cb318b504 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -70,6 +70,16 @@
      New features
    • Tests for verifying functionality of function DescrIntDecode16f.
    • Tests for verifying functionality of function SynetNormalizeLayerForwardV3
    +
    Improving
    +
      +
    • WIN32 exception handling.
    • +
    + +

    Infrastructure

    +
    Improving
    +
      +
    • Host Properties step in Github actions script for MSBuild.
    • +
    Home
    diff --git a/src/Test/Test.cpp b/src/Test/Test.cpp index 768a6de1cd..654bdac609 100644 --- a/src/Test/Test.cpp +++ b/src/Test/Test.cpp @@ -28,7 +28,7 @@ #include "Test/TestLog.h" #include "Test/TestString.h" -#if defined(_MSC_VER) +#if defined(_WIN32) #define NOMINMAX #include #endif @@ -555,19 +555,38 @@ namespace Test private: static bool RunGroup(const Group & group) { -#if defined(_MSC_VER) +#if defined(_WIN32) __try { return group.autoTest(); } __except (EXCEPTION_EXECUTE_HANDLER) { + PrintErrorMessage(GetExceptionCode()); return false; } #else return group.autoTest(); #endif } + +#if defined(_WIN32) + static void PrintErrorMessage(int code) + { + String desc; + switch (code) + { + case EXCEPTION_ACCESS_VIOLATION: desc = "Access violation"; break; + case EXCEPTION_FLT_DIVIDE_BY_ZERO: desc = "Float divide by zero"; break; + case EXCEPTION_INT_DIVIDE_BY_ZERO: desc = "Integer divide by zero"; break; + case EXCEPTION_ILLEGAL_INSTRUCTION: desc = "Illegal instruction"; break; + case EXCEPTION_STACK_OVERFLOW: desc = "Stack overflow"; break; + default: + desc = "Unknown error(" + std::to_string(code) + ")"; + } + TEST_LOG_SS(Error, "There is unhandled exception: " << desc << " !"); + } +#endif }; volatile bool Task::s_stopped = false; typedef std::shared_ptr TaskPtr; diff --git a/src/Test/TestDescrInt.cpp b/src/Test/TestDescrInt.cpp index 311a335845..1759d3af90 100644 --- a/src/Test/TestDescrInt.cpp +++ b/src/Test/TestDescrInt.cpp @@ -611,7 +611,7 @@ namespace Test result = result && DescrIntCosineDistancesMxNaAutoTest(FUNC_DI(Simd::Avx512bw::DescrIntInit), FUNC_DI(SimdDescrIntInit)); #endif -#ifdef SIMD_AVX512VNNI_ENABLE +#if defined(SIMD_AVX512VNNI_ENABLE) && !defined(SIMD_AMX_EMULATE) if (Simd::Avx512vnni::Enable) result = result && DescrIntCosineDistancesMxNaAutoTest(FUNC_DI(Simd::Avx512vnni::DescrIntInit), FUNC_DI(SimdDescrIntInit)); #endif @@ -689,7 +689,7 @@ namespace Test result = result && DescrIntCosineDistancesMxNpAutoTest(FUNC_DI(Simd::Avx512bw::DescrIntInit), FUNC_DI(SimdDescrIntInit)); #endif -#ifdef SIMD_AVX512VNNI_ENABLE +#if defined(SIMD_AVX512VNNI_ENABLE) && !defined(SIMD_AMX_EMULATE) if (Simd::Avx512vnni::Enable) result = result && DescrIntCosineDistancesMxNpAutoTest(FUNC_DI(Simd::Avx512vnni::DescrIntInit), FUNC_DI(SimdDescrIntInit)); #endif From 5c3c0e8f4c5801a4cf8c1513ef6400a218c7c4be Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Thu, 29 Jun 2023 12:56:37 +0300 Subject: [PATCH 37/44] *fix bug(mingw, Test.cpp). --- src/Simd/SimdAvx512bwDescrIntCdu.cpp | 2 +- src/Test/Test.cpp | 8 +++++--- src/Test/TestPerformance.cpp | 9 ++++++--- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/src/Simd/SimdAvx512bwDescrIntCdu.cpp b/src/Simd/SimdAvx512bwDescrIntCdu.cpp index 069c09dea9..336b1b0cbe 100644 --- a/src/Simd/SimdAvx512bwDescrIntCdu.cpp +++ b/src/Simd/SimdAvx512bwDescrIntCdu.cpp @@ -254,7 +254,7 @@ namespace Simd typedef void (*UnpackDataBx16xN_Ptr)(const uint8_t* const* src, size_t offset, uint8_t* dst); - template UnpackDataBx16xN_Ptr GetUnpackDataBx16xN(int tail) + template UnpackDataBx16xN_Ptr GetUnpackDataBx16xN(size_t tail) { switch (tail / 8) { diff --git a/src/Test/Test.cpp b/src/Test/Test.cpp index 654bdac609..9a271b390e 100644 --- a/src/Test/Test.cpp +++ b/src/Test/Test.cpp @@ -28,8 +28,10 @@ #include "Test/TestLog.h" #include "Test/TestString.h" -#if defined(_WIN32) +#if defined(_MSC_VER) +#ifndef NOMINMAX #define NOMINMAX +#endif #include #endif @@ -555,7 +557,7 @@ namespace Test private: static bool RunGroup(const Group & group) { -#if defined(_WIN32) +#if defined(_MSC_VER) __try { return group.autoTest(); @@ -570,7 +572,7 @@ namespace Test #endif } -#if defined(_WIN32) +#if defined(_MSC_VER) static void PrintErrorMessage(int code) { String desc; diff --git a/src/Test/TestPerformance.cpp b/src/Test/TestPerformance.cpp index 40f97bba4f..4b8462484e 100644 --- a/src/Test/TestPerformance.cpp +++ b/src/Test/TestPerformance.cpp @@ -29,7 +29,9 @@ #include "Test/TestHtml.h" #if defined(_MSC_VER) +#ifndef NOMINMAX #define NOMINMAX +#endif #include #elif defined(__GNUC__) #include @@ -295,7 +297,7 @@ namespace Test AddToFunction(src, dst.avx512vnni, enable.avx512vnni); if (desc.find("Simd::Avx512bf16::") != std::string::npos) AddToFunction(src, dst.avx512bf16, enable.avx512bf16); - if (desc.find("Simd::Amx::") != std::string::npos) + if (desc.find("Simd::AmxBf16::") != std::string::npos) AddToFunction(src, dst.amx, enable.amx); if (desc.find("Simd::Vmx::") != std::string::npos) AddToFunction(src, dst.vmx, enable.vmx); @@ -444,8 +446,8 @@ namespace Test info << "Execution time: " + GetCurrentDateTimeString(); info << ". Test threads: " << threads; info << ". Simd version: " << SimdVersion() << "."; -#if defined(__linux__) String cpu = "Unknown", mem = "Unknown"; +#if defined(__linux__) ::FILE* c = ::popen("lscpu | grep 'Model name:' | sed -r 's/Model name:\\s{1,}//g'", "r"); if (c) { @@ -464,6 +466,7 @@ namespace Test mem = mem.substr(0, mem.find('\n')); ::pclose(m); } +#endif info << std::endl; info << "CPU: " << cpu; info << "; Sockets: " << SimdCpuInfo(SimdCpuInfoSockets); @@ -485,7 +488,7 @@ namespace Test info << (SimdCpuInfo(SimdCpuInfoVsx) ? " VSX" : ""); info << (SimdCpuInfo(SimdCpuInfoNeon) ? " NEON" : ""); info << "."; -#endif + return info.str(); } From 2e527d266ddcfee5984f6e394ae2301280d678b4 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Tue, 4 Jul 2023 10:57:16 +0300 Subject: [PATCH 38/44] *RELEASE 5.3.127. --- docs/2023.html | 2 +- docs/download.html | 3 +++ prj/txt/UserVersion.txt | 2 +- 3 files changed, 5 insertions(+), 2 deletions(-) diff --git a/docs/2023.html b/docs/2023.html index 8cb318b504..5645e32e33 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -33,7 +33,7 @@

    Simd Library Release Notes (2023).


    -

    July 5, 2023 (version X.X.127)

    +

    July 4, 2023 (version 5.3.127)

    Algorithms

    New features
      diff --git a/docs/download.html b/docs/download.html index ac44b30701..b77a05ee2f 100644 --- a/docs/download.html +++ b/docs/download.html @@ -27,6 +27,9 @@

      Simd Library Download.

      2023

      + + + diff --git a/prj/txt/UserVersion.txt b/prj/txt/UserVersion.txt index 3ec14b919e..e25ce7f876 100644 --- a/prj/txt/UserVersion.txt +++ b/prj/txt/UserVersion.txt @@ -1 +1 @@ -5.3.126 \ No newline at end of file +5.3.127 \ No newline at end of file From baed5793955e051f9ee98f6698b7ae4b3a5ef1f9 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Tue, 4 Jul 2023 15:32:00 +0300 Subject: [PATCH 39/44] *improve WIN32 performance report. --- docs/2023.html | 9 +++++++++ src/Test/TestPerformance.cpp | 31 +++++++++++++++++++++++++++++++ 2 files changed, 40 insertions(+) diff --git a/docs/2023.html b/docs/2023.html index 5645e32e33..962eadad71 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -32,6 +32,15 @@

      Simd Library Release Notes (2023).

      2013 +
      +

      August X, 2023 (version X.X.128)

      + +

      Test framework

      +
      Improving
      +
        +
      • WIN32 performance report.
      • +
      +

      July 4, 2023 (version 5.3.127)

      Algorithms

      diff --git a/src/Test/TestPerformance.cpp b/src/Test/TestPerformance.cpp index 4b8462484e..6d03246cad 100644 --- a/src/Test/TestPerformance.cpp +++ b/src/Test/TestPerformance.cpp @@ -440,6 +440,23 @@ namespace Test return "Simd Library Performance Report:"; } +#if defined(_WIN32) + static std::string Execute(const char* cmd) + { + std::shared_ptr pipe(_popen(cmd, "r"), _pclose); + if (!pipe) + return "ERROR"; + char buffer[MAX_PATH]; + std::string result = ""; + while (!feof(pipe.get())) + { + if (fgets(buffer, MAX_PATH, pipe.get()) != NULL) + result += buffer; + } + return result; + } +#endif + static String TestInfo(size_t threads) { std::stringstream info; @@ -466,6 +483,20 @@ namespace Test mem = mem.substr(0, mem.find('\n')); ::pclose(m); } +#elif defined(_WIN32) + String cpuRaw = Execute("wmic cpu get Name /format:value"); + size_t cpuBeg = cpuRaw.find('=') + 1; + size_t cpuEnd = cpuRaw.find('\r', cpuBeg); + while (cpuRaw[cpuEnd - 1] == ' ') + cpuEnd--; + cpu = cpuRaw.substr(cpuBeg, cpuEnd - cpuBeg); + MEMORYSTATUSEX memorystatusex; + memorystatusex.dwLength = sizeof(memorystatusex); + if (GlobalMemoryStatusEx(&memorystatusex) == TRUE) + { + double memGB = double(memorystatusex.ullTotalPhys) / 1024.0 / 1024.0 / 1024.0; + mem = ToString(memGB, 1, false); + } #endif info << std::endl; info << "CPU: " << cpu; From a14822cf4a440639262df0aebe7b438afb42100a Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Tue, 4 Jul 2023 16:56:34 +0300 Subject: [PATCH 40/44] +add support of SimdCpuInfoRam in function SimdCpuInfo. --- docs/2023.html | 7 ++++- docs/help/class_simd_1_1_font.html | 2 +- docs/help/group__c__types.html | 4 +++ docs/help/group__info.html | 15 ++++++----- src/Simd/SimdBaseCpu.cpp | 41 ++++++++++++++++++++++++++++++ src/Simd/SimdCpu.h | 1 + src/Simd/SimdLib.cpp | 3 ++- src/Simd/SimdLib.h | 16 +++++++----- src/Simd/SimdLib.hpp | 1 + src/Test/TestPerformance.cpp | 26 ++++--------------- 10 files changed, 78 insertions(+), 38 deletions(-) diff --git a/docs/2023.html b/docs/2023.html index 962eadad71..1bc7e71c70 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -34,7 +34,12 @@

      Simd Library Release Notes (2023).


      August X, 2023 (version X.X.128)

      - +

      Algorithms

      +
      New features
      +
        +
      • Support of SimdCpuInfoRam in function SimdCpuInfo.
      • +
      • Support of SimdCpuInfoRam in function Simd::PrintInfo.
      • +

      Test framework

      Improving
        diff --git a/docs/help/class_simd_1_1_font.html b/docs/help/class_simd_1_1_font.html index 6053304747..0b656ab1b5 100644 --- a/docs/help/class_simd_1_1_font.html +++ b/docs/help/class_simd_1_1_font.html @@ -107,7 +107,7 @@

        Simd Library Documentation.

        }
        The Font class provides text drawing.
        Definition: SimdFont.hpp:64
        Simd::View< Simd::Allocator > View
        Definition: SimdFont.hpp:68
        -
        SIMD_INLINE void FillPixel(View< A > &dst, const Pixel &pixel)
        Fills image by value of given pixel.
        Definition: SimdLib.hpp:2064
        +
        SIMD_INLINE void FillPixel(View< A > &dst, const Pixel &pixel)
        Fills image by value of given pixel.
        Definition: SimdLib.hpp:2065
        32-bit BGRA pixel.
        Definition: SimdPixel.hpp:136
        @ Bgra32
        Definition: SimdView.hpp:89
        diff --git a/docs/help/group__c__types.html b/docs/help/group__c__types.html index 17260b77ab..9084d97015 100644 --- a/docs/help/group__c__types.html +++ b/docs/help/group__c__types.html @@ -96,6 +96,8 @@

        Simd Library Documentation.

        ,
          SimdCpuInfoCacheL3 ,
        +  SimdCpuInfoRam +,
          SimdCpuInfoSse41 ,
          SimdCpuInfoAvx @@ -330,6 +332,8 @@

        SimdCpuInfoCacheL3 

      + - + @@ -109,12 +109,13 @@

      pipe(_popen(cmd, "r"), _pclose); if (!pipe) return "ERROR"; - char buffer[MAX_PATH]; + char buffer[260]; std::string result = ""; while (!feof(pipe.get())) { - if (fgets(buffer, MAX_PATH, pipe.get()) != NULL) + if (fgets(buffer, sizeof(buffer), pipe.get()) != NULL) result += buffer; } return result; @@ -463,7 +463,7 @@ namespace Test info << "Execution time: " + GetCurrentDateTimeString(); info << ". Test threads: " << threads; info << ". Simd version: " << SimdVersion() << "."; - String cpu = "Unknown", mem = "Unknown"; + String cpu = "Unknown"; #if defined(__linux__) ::FILE* c = ::popen("lscpu | grep 'Model name:' | sed -r 's/Model name:\\s{1,}//g'", "r"); if (c) @@ -474,15 +474,6 @@ namespace Test cpu = cpu.substr(0, cpu.find('\n')); ::pclose(c); } - ::FILE* m = ::popen("grep MemTotal /proc/meminfo | awk '{printf \"%.1f\", $2 / 1024 / 1024 }'", "r"); - if (m) - { - char buf[PATH_MAX]; - while (::fgets(buf, PATH_MAX, m)); - mem = buf; - mem = mem.substr(0, mem.find('\n')); - ::pclose(m); - } #elif defined(_WIN32) String cpuRaw = Execute("wmic cpu get Name /format:value"); size_t cpuBeg = cpuRaw.find('=') + 1; @@ -490,13 +481,6 @@ namespace Test while (cpuRaw[cpuEnd - 1] == ' ') cpuEnd--; cpu = cpuRaw.substr(cpuBeg, cpuEnd - cpuBeg); - MEMORYSTATUSEX memorystatusex; - memorystatusex.dwLength = sizeof(memorystatusex); - if (GlobalMemoryStatusEx(&memorystatusex) == TRUE) - { - double memGB = double(memorystatusex.ullTotalPhys) / 1024.0 / 1024.0 / 1024.0; - mem = ToString(memGB, 1, false); - } #endif info << std::endl; info << "CPU: " << cpu; @@ -505,8 +489,8 @@ namespace Test info << ", Threads: " << SimdCpuInfo(SimdCpuInfoThreads); info << "; Cache L1D: " << SimdCpuInfo(SimdCpuInfoCacheL1) / 1024 << " KB"; info << ", L2: " << SimdCpuInfo(SimdCpuInfoCacheL2) / 1024 << " KB"; - info << ", L3: " << SimdCpuInfo(SimdCpuInfoCacheL3) / 1024 / 1024 << " MB"; - info << ", RAM: " << mem << " GB"; + info << ", L3: " << ToString(double(SimdCpuInfo(SimdCpuInfoCacheL3) / 1024) / 1024, 1, false) << " MB"; + info << ", RAM: " << ToString(double(SimdCpuInfo(SimdCpuInfoRam)) / 1024 / 1024 / 1024, 1, false) << " GB"; info << "; SIMD:"; info << (SimdCpuInfo(SimdCpuInfoAmx) ? " AMX" : ""); info << (SimdCpuInfo(SimdCpuInfoAvx512bf16) ? " AVX-512BF16" : ""); From af6f49e425c1b5d275a0dfc01ca3c10a8dddb3b7 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Tue, 4 Jul 2023 19:36:43 +0300 Subject: [PATCH 41/44] +add Base implementation of function SimdCpuDesc. --- docs/2023.html | 7 ++++ docs/help/class_simd_1_1_font.html | 2 +- docs/help/group__c__types.html | 24 ++++++++++- docs/help/group__descrint.html | 1 - docs/help/group__info.html | 57 ++++++++++++++++++++++---- docs/help/group__synet__normalize.html | 6 +-- docs/help/namespace_simd.html | 2 +- src/Simd/SimdBaseCpu.cpp | 38 +++++++++++++---- src/Simd/SimdCpu.h | 5 +++ src/Simd/SimdLib.cpp | 10 +++++ src/Simd/SimdLib.h | 44 ++++++++++++++++---- src/Simd/SimdLib.hpp | 3 +- src/Test/TestCheckCpp.cpp | 2 + src/Test/TestPerformance.cpp | 25 ++--------- 14 files changed, 172 insertions(+), 54 deletions(-) diff --git a/docs/2023.html b/docs/2023.html index 1bc7e71c70..34658492f7 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -39,6 +39,7 @@
      New features
      • Support of SimdCpuInfoRam in function SimdCpuInfo.
      • Support of SimdCpuInfoRam in function Simd::PrintInfo.
      • +
      • Base implementation of function SimdCpuDesc.

      Test framework

      Improving
      @@ -46,6 +47,12 @@
      Improving
    • WIN32 performance report.
    • +

      Documentation

      +
      Bug fixing
      +
        +
      • Wrong description of function SimdDescrIntInit.
      • +
      +

      July 4, 2023 (version 5.3.127)

      Algorithms

      diff --git a/docs/help/class_simd_1_1_font.html b/docs/help/class_simd_1_1_font.html index 0b656ab1b5..1a4af07eaa 100644 --- a/docs/help/class_simd_1_1_font.html +++ b/docs/help/class_simd_1_1_font.html @@ -107,7 +107,7 @@

      Simd Library Documentation.

      }
      The Font class provides text drawing.
      Definition: SimdFont.hpp:64
      Simd::View< Simd::Allocator > View
      Definition: SimdFont.hpp:68
      -
      SIMD_INLINE void FillPixel(View< A > &dst, const Pixel &pixel)
      Fills image by value of given pixel.
      Definition: SimdLib.hpp:2065
      +
      SIMD_INLINE void FillPixel(View< A > &dst, const Pixel &pixel)
      Fills image by value of given pixel.
      Definition: SimdLib.hpp:2066
      32-bit BGRA pixel.
      Definition: SimdPixel.hpp:136
      @ Bgra32
      Definition: SimdView.hpp:89
      diff --git a/docs/help/group__c__types.html b/docs/help/group__c__types.html index 9084d97015..806a396d44 100644 --- a/docs/help/group__c__types.html +++ b/docs/help/group__c__types.html @@ -83,6 +83,9 @@

      Simd Library Documentation.


      }

      + + - + diff --git a/src/Simd/SimdBaseCpu.cpp b/src/Simd/SimdBaseCpu.cpp index 94a55e6c75..93e4323509 100644 --- a/src/Simd/SimdBaseCpu.cpp +++ b/src/Simd/SimdBaseCpu.cpp @@ -182,6 +182,15 @@ namespace Simd return 0; } + uint64_t CpuRamSize() + { + MEMORYSTATUSEX memorystatusex; + memorystatusex.dwLength = sizeof(memorystatusex); + if (GlobalMemoryStatusEx(&memorystatusex) == TRUE) + return memorystatusex.ullTotalPhys; + return 0; + } + static std::string Execute(const char* cmd) { std::string result = ""; @@ -199,13 +208,14 @@ namespace Simd return result; } - uint64_t CpuRamSize() + std::string CpuModel() { - MEMORYSTATUSEX memorystatusex; - memorystatusex.dwLength = sizeof(memorystatusex); - if (GlobalMemoryStatusEx(&memorystatusex) == TRUE) - return memorystatusex.ullTotalPhys; - return 0; + std::string raw = Execute("wmic cpu get Name /format:value"); + size_t beg = raw.find('=') + 1; + size_t end = raw.find('\r', beg); + while (raw[end - 1] == ' ') + end--; + return raw.substr(beg, end - beg); } #elif defined(__GNUC__) @@ -280,7 +290,6 @@ namespace Simd } } #endif - uint64_t CpuRamSize() { uint64_t size = 0; @@ -295,6 +304,20 @@ namespace Simd return size; } + std::string CpuModel() + { + std::string model; + ::FILE* file = ::popen("lscpu | grep 'Model name:' | sed -r 's/Model name:\\s{1,}//g'", "r"); + if (file) + { + char buffer[PATH_MAX]; + while (::fgets(buffer, PATH_MAX, file)); + model = buffer; + model = model.substr(0, model.find('\n')); + ::pclose(file); + } + return model; + } #else #error This platform is unsupported! #endif @@ -302,6 +325,7 @@ namespace Simd namespace Cpu { + const std::string CPU_MODEL = Base::CpuModel(); const size_t SOCKET_NUMBER = Base::CpuSocketNumber(); const size_t CORE_NUMBER = Base::CpuCoreNumber(); #ifdef SIMD_CPP_2011_ENABLE diff --git a/src/Simd/SimdCpu.h b/src/Simd/SimdCpu.h index 7ce8d9d700..7b0b6e2e2c 100644 --- a/src/Simd/SimdCpu.h +++ b/src/Simd/SimdCpu.h @@ -27,6 +27,8 @@ #include "Simd/SimdDefs.h" +#include + namespace Simd { #if defined(SIMD_X86_ENABLE) || defined(SIMD_X64_ENABLE) @@ -89,6 +91,7 @@ namespace Simd namespace Cpu { + extern const std::string CPU_MODEL; extern const size_t SOCKET_NUMBER; extern const size_t CORE_NUMBER; extern const size_t THREAD_NUMBER; @@ -110,6 +113,8 @@ namespace Simd bool CheckBit(int at, int bit); #endif + std::string CpuModel(); + size_t CpuSocketNumber(); size_t CpuCoreNumber(); diff --git a/src/Simd/SimdLib.cpp b/src/Simd/SimdLib.cpp index 176014da74..8f5f42495b 100644 --- a/src/Simd/SimdLib.cpp +++ b/src/Simd/SimdLib.cpp @@ -107,6 +107,16 @@ SIMD_API const char * SimdVersion() using namespace Simd; +SIMD_API const char* SimdCpuDesc(SimdCpuDescType type) +{ + switch (type) + { + case SimdCpuDescModel: return Cpu::CPU_MODEL.c_str(); + default: + return NULL; + } +} + SIMD_API uint64_t SimdCpuInfo(SimdCpuInfoType type) { switch (type) diff --git a/src/Simd/SimdLib.h b/src/Simd/SimdLib.h index a3a68b84ed..f8751758cd 100644 --- a/src/Simd/SimdLib.h +++ b/src/Simd/SimdLib.h @@ -220,6 +220,14 @@ typedef enum SimdConvolutionActivationGelu, } SimdConvolutionActivationType; +/*! @ingroup c_types + Describes type of description which can return function ::SimdCpuDesc. +*/ +typedef enum +{ + SimdCpuDescModel, /*!< A CPU model name. */ +} SimdCpuDescType; + /*! @ingroup c_types Describes type of information which can return function ::SimdCpuInfo. */ @@ -685,9 +693,34 @@ extern "C" /*! @ingroup info - \fn size_t SimdCpuInfo(SimdCpuInfoType type); + \fn const char* SimdCpuDesc(SimdCpuDescType type); + + \short Gets description of CPU and %Simd Library. + + \note See enumeration ::SimdCpuDescType. + + Using example: + \verbatim + #include "Simd/SimdLib.h" + #include + + int main() + { + std::cout << "CPU: " << SimdCpuDesc(SimdCpuDescModel) << std::endl; + return 0; + } + \endverbatim + + \param [in] type - a type of required description. + \return a value which contains description of CPU and %Simd Library. + */ + SIMD_API const char* SimdCpuDesc(SimdCpuDescType type); + + /*! @ingroup info + + \fn uint64_t SimdCpuInfo(SimdCpuInfoType type); - \short Gets info about CPU and %Simd Library. + \short Gets information about CPU and %Simd Library. \note See enumeration ::SimdCpuInfoType. @@ -2471,9 +2504,6 @@ extern "C" \short Initilizes Integer Descriptor Engine. - All images must have the same width and height. - This function used for NV12 to YUV420P conversion. - \param [in] size - a length of original (32-bit or 16-bit) float descriptor. It be multiple of 8. Also it must be less or equal than 32768. \param [in] depth - a number of bits in encoded integer descriptor. Supported values: 4, 5, 6, 7, 8. \return a pointer to Integer Descriptor Engine context. On error it returns NULL. It must be released with using of function ::SimdRelease. @@ -7676,9 +7706,9 @@ extern "C" /*! @ingroup synet_normalize - \fn void SimdSynetNormalizeLayerForwardV3(const float* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst); + \fn void SimdSynetNormalizeLayerForwardV3(const float* src, size_t batch, size_t channels, size_t spatial, const float* scale, const float* shift, const float* eps, SimdTensorFormatType format, float* buf, float* dst); - \short Performs forward propagation of NormalizeLayer (Version 3). + \short Performs forward propagation of NormalizeLayer (Version 3). Algorithm's details: \verbatim diff --git a/src/Simd/SimdLib.hpp b/src/Simd/SimdLib.hpp index 958782e2dc..ca68d445b5 100644 --- a/src/Simd/SimdLib.hpp +++ b/src/Simd/SimdLib.hpp @@ -41,13 +41,14 @@ namespace Simd \fn void PrintInfo(std::ostream & os) - \short Prints information about %Simd Library and CPU properties. + \short Prints information about %Simd Library and CPU. \param [in, out] os - output stream. */ SIMD_INLINE void PrintInfo(std::ostream & os) { os << "Simd Library: " << SimdVersion(); + os << "; CPU: " << SimdCpuDesc(SimdCpuDescModel); os << "; System Sockets: " << SimdCpuInfo(SimdCpuInfoSockets); os << ", Cores: " << SimdCpuInfo(SimdCpuInfoCores); os << ", Threads: " << SimdCpuInfo(SimdCpuInfoThreads); diff --git a/src/Test/TestCheckCpp.cpp b/src/Test/TestCheckCpp.cpp index db445cf1bb..0e313d351e 100644 --- a/src/Test/TestCheckCpp.cpp +++ b/src/Test/TestCheckCpp.cpp @@ -44,12 +44,14 @@ namespace Test static void TestCpuInfo() { std::cout << "Simd Library : " << SimdVersion() << std::endl; + std::cout << "CPU : " << SimdCpuDesc(SimdCpuDescModel) << std::endl; std::cout << "Sockets : " << SimdCpuInfo(SimdCpuInfoSockets) << std::endl; std::cout << "Cores : " << SimdCpuInfo(SimdCpuInfoCores) << std::endl; std::cout << "Threads : " << SimdCpuInfo(SimdCpuInfoThreads) << std::endl; std::cout << "L1D Cache : " << SimdCpuInfo(SimdCpuInfoCacheL1) / 1024 << " KB" << std::endl; std::cout << "L2 Cache : " << SimdCpuInfo(SimdCpuInfoCacheL2) / 1024 << " KB" << std::endl; std::cout << "L3 Cache : " << SimdCpuInfo(SimdCpuInfoCacheL3) / 1024 << " KB" << std::endl; + std::cout << "RAM : " << SimdCpuInfo(SimdCpuInfoRam) / 1024 / 1024 << " MB" << std::endl; std::cout << "SSE4.1: " << (SimdCpuInfo(SimdCpuInfoSse41) ? "Yes" : "No") << std::endl; std::cout << "AVX: " << (SimdCpuInfo(SimdCpuInfoAvx) ? "Yes" : "No") << std::endl; std::cout << "AVX2: " << (SimdCpuInfo(SimdCpuInfoAvx2) ? "Yes" : "No") << std::endl; diff --git a/src/Test/TestPerformance.cpp b/src/Test/TestPerformance.cpp index b7a20117b3..de4c038316 100644 --- a/src/Test/TestPerformance.cpp +++ b/src/Test/TestPerformance.cpp @@ -462,29 +462,10 @@ namespace Test std::stringstream info; info << "Execution time: " + GetCurrentDateTimeString(); info << ". Test threads: " << threads; - info << ". Simd version: " << SimdVersion() << "."; - String cpu = "Unknown"; -#if defined(__linux__) - ::FILE* c = ::popen("lscpu | grep 'Model name:' | sed -r 's/Model name:\\s{1,}//g'", "r"); - if (c) - { - char buf[PATH_MAX]; - while (::fgets(buf, PATH_MAX, c)); - cpu = buf; - cpu = cpu.substr(0, cpu.find('\n')); - ::pclose(c); - } -#elif defined(_WIN32) - String cpuRaw = Execute("wmic cpu get Name /format:value"); - size_t cpuBeg = cpuRaw.find('=') + 1; - size_t cpuEnd = cpuRaw.find('\r', cpuBeg); - while (cpuRaw[cpuEnd - 1] == ' ') - cpuEnd--; - cpu = cpuRaw.substr(cpuBeg, cpuEnd - cpuBeg); -#endif + info << ". Simd version: " << SimdVersion(); + info << ". CPU: " << SimdCpuDesc(SimdCpuDescModel) << "."; info << std::endl; - info << "CPU: " << cpu; - info << "; Sockets: " << SimdCpuInfo(SimdCpuInfoSockets); + info << "Sockets: " << SimdCpuInfo(SimdCpuInfoSockets); info << ", Cores: " << SimdCpuInfo(SimdCpuInfoCores); info << ", Threads: " << SimdCpuInfo(SimdCpuInfoThreads); info << "; Cache L1D: " << SimdCpuInfo(SimdCpuInfoCacheL1) / 1024 << " KB"; From 271f540ae4907ccff9e97fe24f01b096b1b83381 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Wed, 5 Jul 2023 14:09:03 +0300 Subject: [PATCH 42/44] *improve cmake github actions script. --- .github/workflows/cmake.yml | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/cmake.yml b/.github/workflows/cmake.yml index 74bcfa6b62..c208668d98 100644 --- a/.github/workflows/cmake.yml +++ b/.github/workflows/cmake.yml @@ -21,7 +21,7 @@ jobs: run: lscpu - name: Configure CMake - run: cmake ./prj/cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{matrix.build_type}} -DSIMD_AVX512VNNI=ON -DSIMD_TEST_FLAGS="-mavx2" + run: cmake ./prj/cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{matrix.build_type}} -DSIMD_AVX512VNNI=ON -DSIMD_TEST_FLAGS="-march=native" - name: Build run: cmake --build ${{github.workspace}}/build --config ${{matrix.build_type}} --parallel$(nproc) @@ -113,7 +113,7 @@ jobs: run: sudo apt-get -y install clang - name: Configure CMake - run: cmake ./prj/cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{matrix.build_type}} -DSIMD_TOOLCHAIN="clang" -DSIMD_TARGET="" + run: cmake ./prj/cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{matrix.build_type}} -DSIMD_TOOLCHAIN="clang" -DSIMD_TARGET="" -DSIMD_TEST_FLAGS="-march=native" - name: Build run: cmake --build ${{github.workspace}}/build --config ${{matrix.build_type}} --parallel$(nproc) @@ -139,7 +139,7 @@ jobs: run: sudo apt-get -y install g++-12 - name: Configure CMake - run: cmake ./prj/cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{matrix.build_type}} -DSIMD_TOOLCHAIN="g++-12" -DSIMD_TARGET="" -DSIMD_AVX512VNNI=ON + run: cmake ./prj/cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{matrix.build_type}} -DSIMD_TOOLCHAIN="g++-12" -DSIMD_TARGET="" -DSIMD_AVX512VNNI=ON -DSIMD_TEST_FLAGS="-march=native" - name: Build run: cmake --build ${{github.workspace}}/build --config ${{matrix.build_type}} --parallel$(nproc) @@ -175,7 +175,7 @@ jobs: run: wmic cpu get /format:value - name: Configure CMake - run: cmake ./prj/cmake -B build -DCMAKE_BUILD_TYPE=${{matrix.build_type}} -DSIMD_TARGET="x86_64" -DSIMD_GET_VERSION=OFF + run: cmake ./prj/cmake -B build -DCMAKE_BUILD_TYPE=${{matrix.build_type}} -DSIMD_TARGET="x86_64" -DSIMD_GET_VERSION=OFF -DSIMD_TEST_FLAGS="-march=native" - name: Build run: cmake --build build --config ${{matrix.build_type}} --parallel2 From d2eae69992ef8f2748cd7db4a68cf058f89c1a82 Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Wed, 5 Jul 2023 15:37:21 +0300 Subject: [PATCH 43/44] *fix bugs in Error in AVX-512VNNI optimizations of classes SynetMergedConvolution8iCdc, SynetMergedConvolution8iCd, SynetMergedConvolution8iDc. --- .github/workflows/cmake.yml | 6 +++--- docs/2023.html | 13 +++++++++++++ .../SimdAvx512vnniSynetMergedConvolution8iInput.cpp | 1 + ...SimdAvx512vnniSynetMergedConvolution8iOutput.cpp | 1 + src/Test/TestCompare.cpp | 1 + 5 files changed, 19 insertions(+), 3 deletions(-) diff --git a/.github/workflows/cmake.yml b/.github/workflows/cmake.yml index c208668d98..f318eee5dd 100644 --- a/.github/workflows/cmake.yml +++ b/.github/workflows/cmake.yml @@ -122,7 +122,7 @@ jobs: working-directory: ${{github.workspace}}/build run: ./Test "-r=.." -m=a -tt=$(nproc) "-ot=log_${{matrix.build_type}}.txt" -ts=10 - build_and_test_gcc_12: + build_and_test_gcc_13: runs-on: ubuntu-latest strategy: @@ -136,10 +136,10 @@ jobs: run: lscpu - name: Install toolchain - run: sudo apt-get -y install g++-12 + run: sudo apt-get -y install g++-13 - name: Configure CMake - run: cmake ./prj/cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{matrix.build_type}} -DSIMD_TOOLCHAIN="g++-12" -DSIMD_TARGET="" -DSIMD_AVX512VNNI=ON -DSIMD_TEST_FLAGS="-march=native" + run: cmake ./prj/cmake -B ${{github.workspace}}/build -DCMAKE_BUILD_TYPE=${{matrix.build_type}} -DSIMD_TOOLCHAIN="g++-13" -DSIMD_TARGET="" -DSIMD_AVX512VNNI=ON -DSIMD_AVX512BF16=ON -DSIMD_TEST_FLAGS="-march=native" - name: Build run: cmake --build ${{github.workspace}}/build --config ${{matrix.build_type}} --parallel$(nproc) diff --git a/docs/2023.html b/docs/2023.html index 34658492f7..74c4fe69ef 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -41,12 +41,25 @@
      New features
    • Support of SimdCpuInfoRam in function Simd::PrintInfo.
    • Base implementation of function SimdCpuDesc.
    • +
      Bug fixing
      +
        +
      • Error in AVX-512VNNI optimizations of class SynetMergedConvolution8iCdc.
      • +
      • Error in AVX-512VNNI optimizations of class SynetMergedConvolution8iCd.
      • +
      • Error in AVX-512VNNI optimizations of class SynetMergedConvolution8iDc.
      • +
      +

      Test framework

      Improving
      • WIN32 performance report.
      +

      Infrastructure

      +
      New features
      +
        +
      • Github actions script for CMake (build and test for GCC-13 (instead of GCC-12), Linux).
      • +
      +

      Documentation

      Bug fixing
        diff --git a/src/Simd/SimdAvx512vnniSynetMergedConvolution8iInput.cpp b/src/Simd/SimdAvx512vnniSynetMergedConvolution8iInput.cpp index 72c46ec86e..70732f3495 100644 --- a/src/Simd/SimdAvx512vnniSynetMergedConvolution8iInput.cpp +++ b/src/Simd/SimdAvx512vnniSynetMergedConvolution8iInput.cpp @@ -873,6 +873,7 @@ namespace Simd case SimdConvolutionActivationElu: SetInput(p, input); break; case SimdConvolutionActivationHswish: SetInput(p, input); break; case SimdConvolutionActivationMish: SetInput(p, input); break; + case SimdConvolutionActivationHardSigmoid: SetInput(p, input); break; case SimdConvolutionActivationSwish: SetInput(p, input); break; case SimdConvolutionActivationGelu: SetInput(p, input); break; } diff --git a/src/Simd/SimdAvx512vnniSynetMergedConvolution8iOutput.cpp b/src/Simd/SimdAvx512vnniSynetMergedConvolution8iOutput.cpp index 5c97c423b1..e1bfee2d12 100644 --- a/src/Simd/SimdAvx512vnniSynetMergedConvolution8iOutput.cpp +++ b/src/Simd/SimdAvx512vnniSynetMergedConvolution8iOutput.cpp @@ -346,6 +346,7 @@ namespace Simd case SimdConvolutionActivationElu: SetOutput(p, output); break; case SimdConvolutionActivationHswish: SetOutput(p, output); break; case SimdConvolutionActivationMish: SetOutput(p, output); break; + case SimdConvolutionActivationHardSigmoid: SetOutput(p, output); break; case SimdConvolutionActivationSwish: SetOutput(p, output); break; case SimdConvolutionActivationGelu: SetOutput(p, output); break; } diff --git a/src/Test/TestCompare.cpp b/src/Test/TestCompare.cpp index 2ab356a3ed..b2d7c52c12 100644 --- a/src/Test/TestCompare.cpp +++ b/src/Test/TestCompare.cpp @@ -225,6 +225,7 @@ namespace Test case DifferenceRelative: error = relative > differenceMax; break; case DifferenceBoth: error = absolute > differenceMax && relative > differenceMax; break; case DifferenceAny: error = absolute > differenceMax || relative > differenceMax; break; + case DifferenceLogical: assert(0); break; } if (error) { From 02350687e659f6c0e652cda8f0300eaf6b022ccd Mon Sep 17 00:00:00 2001 From: Yermalayeu Ihar Date: Tue, 18 Jul 2023 17:17:20 +0300 Subject: [PATCH 44/44] *fix bug: Error (assert) in Base implementation of class ResizerNearest. --- docs/2023.html | 1 + src/Simd/SimdBaseResizerNearest.cpp | 4 ++-- src/Test/TestResize.cpp | 3 +++ 3 files changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/2023.html b/docs/2023.html index 74c4fe69ef..c83c3e9135 100644 --- a/docs/2023.html +++ b/docs/2023.html @@ -46,6 +46,7 @@
        Bug fixing
      • Error in AVX-512VNNI optimizations of class SynetMergedConvolution8iCdc.
      • Error in AVX-512VNNI optimizations of class SynetMergedConvolution8iCd.
      • Error in AVX-512VNNI optimizations of class SynetMergedConvolution8iDc.
      • +
      • Error (assert) in Base implementation of class ResizerNearest.

      Test framework

      diff --git a/src/Simd/SimdBaseResizerNearest.cpp b/src/Simd/SimdBaseResizerNearest.cpp index 62df595898..72e512be83 100644 --- a/src/Simd/SimdBaseResizerNearest.cpp +++ b/src/Simd/SimdBaseResizerNearest.cpp @@ -37,7 +37,7 @@ namespace Simd void ResizerNearest::EstimateIndex(size_t srcSize, size_t dstSize, size_t channelSize, size_t channels, int32_t* indices) { - if (_param.method == SimdResizeMethodNearest) + if (_param.method == SimdResizeMethodNearest || _param.method == SimdResizeMethodBilinear || _param.method == SimdResizeMethodBicubic) { float scale = (float)srcSize / dstSize; for (size_t i = 0; i < dstSize; ++i) @@ -51,7 +51,7 @@ namespace Simd } } } - else if (_param.method == SimdResizeMethodNearestPytorch) + else if (_param.method == SimdResizeMethodNearestPytorch || _param.method == SimdResizeMethodBilinearPytorch) { for (size_t i = 0; i < dstSize; ++i) { diff --git a/src/Test/TestResize.cpp b/src/Test/TestResize.cpp index 8b4dbffd56..8496c649a6 100644 --- a/src/Test/TestResize.cpp +++ b/src/Test/TestResize.cpp @@ -345,6 +345,9 @@ namespace Test { bool result = true; + result = result && ResizerAutoTest(SimdResizeMethodBilinear, SimdResizeChannelByte, 4, 100, 1, 200, 10, f1, f2); + result = result && ResizerAutoTest(SimdResizeMethodBicubic, SimdResizeChannelByte, 4, 100, 2, 200, 10, f1, f2); + #if !defined(__aarch64__) || 1 std::vector methods = { SimdResizeMethodNearest, SimdResizeMethodBilinear, SimdResizeMethodBicubic, SimdResizeMethodArea, SimdResizeMethodAreaFast }; for (size_t m = 0; m < methods.size(); ++m)
      Release Notes Download Link Size
      July 4, 2023 Simd-5.3.127.zip 5.9 MB
      June 5, 2023 Simd-5.3.126.zip 5.9 MB

      A size of level 3 cache.

      SimdCpuInfoRam 

      A size of physical RAM.

      +
      SimdCpuInfoSse41 

      Availability of SSE4.1 (x86).

      SimdCpuInfoAvx 

      Availability of AVX (x86).

      diff --git a/docs/help/group__info.html b/docs/help/group__info.html index adbebcacf5..6a05e68305 100644 --- a/docs/help/group__info.html +++ b/docs/help/group__info.html @@ -53,7 +53,7 @@

      Simd Library Documentation.

      SIMD_API const char * SimdVersion ()
       Gets version of Simd Library. More...
       
      SIMD_API size_t SimdCpuInfo (SimdCpuInfoType type)
      SIMD_API uint64_t SimdCpuInfo (SimdCpuInfoType type)
       Gets info about CPU and Simd Library. More...
       
      SIMD_API const char * SimdPerformanceStatistic ()
       
      enum  SimdCpuDescType { SimdCpuDescModel + }
       
      enum  SimdCpuInfoType {
        SimdCpuInfoSockets ,
      @@ -305,6 +308,25 @@

      +

      ◆ SimdCpuDescType

      + +
      +
      + + + + +
      enum SimdCpuDescType
      +
      +

      Describes type of description which can return function SimdCpuDesc.

      + + +
      Enumerator
      SimdCpuDescModel 

      A CPU model name.

      +
      +
      @@ -318,7 +340,7 @@

      -

      Describes type of information which can return function SimdCpuInfo.

      +

      Describes type of information which can return function SimdCpuInfo.

      diff --git a/docs/help/group__descrint.html b/docs/help/group__descrint.html index 3f770e1d9c..e8af581b92 100644 --- a/docs/help/group__descrint.html +++ b/docs/help/group__descrint.html @@ -114,7 +114,6 @@

      Initilizes Integer Descriptor Engine.

      -

      All images must have the same width and height. This function used for NV12 to YUV420P conversion.

      Parameters

      Enumerator
      SimdCpuInfoSockets 

      A number of sockets.

      diff --git a/docs/help/group__info.html b/docs/help/group__info.html index 6a05e68305..dbb545af25 100644 --- a/docs/help/group__info.html +++ b/docs/help/group__info.html @@ -53,14 +53,17 @@

      Simd Library Documentation.

      - - - + + + + + + - +
      [in]size- a length of original (32-bit or 16-bit) float descriptor. It be multiple of 8. Also it must be less or equal than 32768.
      SIMD_API const char * SimdVersion ()
       Gets version of Simd Library. More...
       
      SIMD_API uint64_t SimdCpuInfo (SimdCpuInfoType type)
       Gets info about CPU and Simd Library. More...
       
      SIMD_API const char * SimdCpuDesc (SimdCpuDescType type)
       Gets description of CPU and Simd Library. More...
       
      SIMD_API uint64_t SimdCpuInfo (SimdCpuInfoType type)
       Gets information about CPU and Simd Library. More...
       
      SIMD_API const char * SimdPerformanceStatistic ()
       Gets internal performance statistics of Simd Library. More...
       
      SIMD_INLINE void PrintInfo (std::ostream &os)
       Prints information about Simd Library and CPU properties. More...
       Prints information about Simd Library and CPU. More...
       

      Detailed Description

      @@ -86,14 +89,50 @@

      -

      ◆ SimdCpuInfo()

      + +

      ◆ SimdCpuDesc()

      - + + + + + + +
      size_t SimdCpuInfo const char * SimdCpuDesc (SimdCpuDescType type)
      +
      + +

      Gets description of CPU and Simd Library.

      +
      Note
      See enumeration SimdCpuDescType.
      +

      Using example:

      #include "Simd/SimdLib.h"
      +#include <iostream>
      +
      +int main()
      +{
      +    std::cout << "CPU: " << SimdCpuDesc(SimdCpuDescModel) << std::endl;
      +    return 0;
      +}
      +
      Parameters
      + + +
      [in]type- a type of required description.
      +
      +
      +
      Returns
      a value which contains description of CPU and Simd Library.
      + +
      +
      + +

      ◆ SimdCpuInfo()

      + +
      +
      + + + @@ -102,7 +141,7 @@

      -

      Gets info about CPU and Simd Library.

      +

      Gets information about CPU and Simd Library.

      Note
      See enumeration SimdCpuInfoType.

      Using example:

      #include "Simd/SimdLib.h"
       #include <iostream>
      @@ -175,7 +214,7 @@ 

      -

      Prints information about Simd Library and CPU properties.

      +

      Prints information about Simd Library and CPU.

      Parameters

      uint64_t SimdCpuInfo ( SimdCpuInfoType  type)
      diff --git a/docs/help/group__synet__normalize.html b/docs/help/group__synet__normalize.html index ce8b59c272..5c04681a62 100644 --- a/docs/help/group__synet__normalize.html +++ b/docs/help/group__synet__normalize.html @@ -352,9 +352,7 @@

      Performs forward propagation of NormalizeLayer (Version 3).

      -
      Algorithm's details:
      -\verbatim
      -for(b = 0; b < batch; ++b)
      +

      Algorithm's details:

      for(b = 0; b < batch; ++b)
           for(c = 0; c < channels; ++c)
           {
               sum = 0;
      @@ -383,7 +381,7 @@ 

      [in]

      - +
      [in,out]os- output stream.
      eps- a pointer to epsilon parameter. It is used to prevent division by zero.
      [in]format- a format of input and output tensor. It can be SimdTensorFormatNchw, SimdTensorFormatNhwc.
      [out]buf- a pointer to external temporary buffer. The size of the buffer must be equal to channels. Can be NULL (it causes usage of internal buffer).
      [out]dst- a pointer to the output 32-bit float tensor.
      [out]dst- a pointer to the output 32-bit float tensor.
      diff --git a/docs/help/namespace_simd.html b/docs/help/namespace_simd.html index a9e0d5cf3c..bc112f2dbe 100644 --- a/docs/help/namespace_simd.html +++ b/docs/help/namespace_simd.html @@ -306,7 +306,7 @@

      Simd Library Documentation.

      SIMD_INLINE bool Compatible (const View< A > &a, const View< A > &b, const View< A > &c, const View< A > &d, const View< A > &e)
       
      SIMD_INLINE void PrintInfo (std::ostream &os)
       Prints information about Simd Library and CPU properties. More...
       Prints information about Simd Library and CPU. More...
       
      template<template< class > class A>
      SIMD_INLINE void AbsDifference (const View< A > &a, const View< A > &b, View< A > &c)