alternative approach
GuyAv46 committed Dec 30, 2024
1 parent f2b1bfb commit 8ebd561
Showing 3 changed files with 57 additions and 15 deletions.
25 changes: 22 additions & 3 deletions src/VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_INT8.h
@@ -6,7 +6,7 @@
 
 #include "VecSim/spaces/space_includes.h"
 
-static inline void InnerProductStep(int8_t *&pVect1, int8_t *&pVect2, __m512i &sum) {
+static inline void InnerProductHalfStep(int8_t *&pVect1, int8_t *&pVect2, __m512i &sum) {
     __m256i temp_a = _mm256_loadu_epi8(pVect1);
     __m512i va = _mm512_cvtepi8_epi16(temp_a);
     pVect1 += 32;
@@ -22,6 +22,26 @@ static inline void InnerProductStep(int8_t *&pVect1, int8_t *&pVect2, __m512i &sum) {
     sum = _mm512_dpwssd_epi32(sum, va, vb);
 }
 
+static inline void InnerProductStep(int8_t *&pVect1, int8_t *&pVect2, __m512i &sum) {
+    __m512i zeros = _mm512_setzero_si512();
+
+    __m512i va = _mm512_loadu_epi8(pVect1); // AVX512BW
+    pVect1 += 64;
+    __m512i va_ext = _mm512_movm_epi8(_mm512_cmplt_epi8_mask(va, zeros)); // AVX512BW
+
+    __m512i vb = _mm512_loadu_epi8(pVect2); // AVX512BW
+    pVect2 += 64;
+    __m512i vb_ext = _mm512_movm_epi8(_mm512_cmplt_epi8_mask(vb, zeros)); // AVX512BW
+
+    __m512i va_lo = _mm512_unpacklo_epi8(va, va_ext); // AVX512BW
+    __m512i vb_lo = _mm512_unpacklo_epi8(vb, vb_ext);
+    sum = _mm512_dpwssd_epi32(sum, va_lo, vb_lo);
+
+    __m512i va_hi = _mm512_unpackhi_epi8(va, va_ext); // AVX512BW
+    __m512i vb_hi = _mm512_unpackhi_epi8(vb, vb_ext);
+    sum = _mm512_dpwssd_epi32(sum, va_hi, vb_hi);
+}
+
 template <unsigned char residual> // 0..63
 static inline int INT8_InnerProductImp(const void *pVect1v, const void *pVect2v, size_t dimension) {
     int8_t *pVect1 = (int8_t *)pVect1v;
@@ -47,13 +67,12 @@ static inline int INT8_InnerProductImp(const void *pVect1v, const void *pVect2v, size_t dimension) {
     }
 
     if constexpr (residual >= 32) {
-        InnerProductStep(pVect1, pVect2, sum);
+        InnerProductHalfStep(pVect1, pVect2, sum);
     }
 
     // We dealt with the residual part. We are left with some multiple of 64-int_8.
     while (pVect1 < pEnd1) {
         InnerProductStep(pVect1, pVect2, sum);
-        InnerProductStep(pVect1, pVect2, sum);
     }
 
     return _mm512_reduce_add_epi32(sum);
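Note on the new INT8 InnerProductStep above: each 16-bit lane is built by interleaving a source byte with its sign byte (0x00 or 0xFF, produced by _mm512_movm_epi8 from the _mm512_cmplt_epi8_mask result), which on a little-endian target is exactly sign extension to int16. _mm512_unpacklo_epi8 and _mm512_unpackhi_epi8 permute lanes within each 128-bit block, but va and vb are permuted identically and the accumulator is later collapsed with _mm512_reduce_add_epi32, so the final sum is unaffected. A minimal scalar sketch of the value the kernel is expected to produce (a hypothetical reference, not part of this commit; the helper name is made up):

#include <cstddef>
#include <cstdint>

// Hypothetical scalar reference for the int8 inner product; the SIMD kernel
// should match this for any dimension, regardless of lane permutation.
static int INT8_InnerProduct_Ref(const int8_t *a, const int8_t *b, size_t dim) {
    int sum = 0;
    for (size_t i = 0; i < dim; i++) {
        sum += static_cast<int>(a[i]) * static_cast<int>(b[i]); // widen before multiplying
    }
    return sum;
}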
20 changes: 11 additions & 9 deletions src/VecSim/spaces/IP/IP_AVX512F_BW_VL_VNNI_UINT8.h
@@ -7,17 +7,19 @@
 #include "VecSim/spaces/space_includes.h"
 
 static inline void InnerProductStep(uint8_t *&pVect1, uint8_t *&pVect2, __m512i &sum) {
-    for (int i = 0; i < 2; i++) {
-        __m256i temp_a = _mm256_loadu_epi8(pVect1); // AVX512BW
-        __m512i va = _mm512_cvtepu8_epi16(temp_a);
-        pVect1 += 32;
+    __m512i va = _mm512_loadu_epi8(pVect1); // AVX512BW
+    pVect1 += 64;
 
-        __m256i temp_b = _mm256_loadu_epi8(pVect2); // AVX512BW
-        __m512i vb = _mm512_cvtepu8_epi16(temp_b);
-        pVect2 += 32;
+    __m512i vb = _mm512_loadu_epi8(pVect2); // AVX512BW
+    pVect2 += 64;
 
-        sum = _mm512_dpwssd_epi32(sum, va, vb);
-    }
+    __m512i va_lo = _mm512_unpacklo_epi8(va, _mm512_setzero_si512()); // AVX512BW
+    __m512i vb_lo = _mm512_unpacklo_epi8(vb, _mm512_setzero_si512());
+    sum = _mm512_dpwssd_epi32(sum, va_lo, vb_lo);
+
+    __m512i va_hi = _mm512_unpackhi_epi8(va, _mm512_setzero_si512()); // AVX512BW
+    __m512i vb_hi = _mm512_unpackhi_epi8(vb, _mm512_setzero_si512());
+    sum = _mm512_dpwssd_epi32(sum, va_hi, vb_hi);
 
     // _mm512_dpwssd_epi32(src, a, b)
     // Multiply groups of 2 adjacent pairs of signed 16-bit integers in `a` with corresponding
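Note on the UINT8 variant above: the bytes are zero-extended rather than sign-extended, so the unpack partner is simply _mm512_setzero_si512(). The resulting 16-bit lanes hold values in 0..255, which are valid signed 16-bit inputs for _mm512_dpwssd_epi32, and each pair of products (at most 2 * 255 * 255) fits comfortably in a 32-bit accumulator lane. A matching scalar sketch (hypothetical, not from the repository):

#include <cstddef>
#include <cstdint>

// Hypothetical scalar reference for the uint8 inner product.
static int UINT8_InnerProduct_Ref(const uint8_t *a, const uint8_t *b, size_t dim) {
    int sum = 0;
    for (size_t i = 0; i < dim; i++) {
        sum += static_cast<int>(a[i]) * static_cast<int>(b[i]); // each factor is 0..255
    }
    return sum;
}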
27 changes: 24 additions & 3 deletions src/VecSim/spaces/L2/L2_AVX512F_BW_VL_VNNI_INT8.h
@@ -6,7 +6,7 @@
 
 #include "VecSim/spaces/space_includes.h"
 
-static inline void L2SqrStep(int8_t *&pVect1, int8_t *&pVect2, __m512i &sum) {
+static inline void L2SqrHalfStep(int8_t *&pVect1, int8_t *&pVect2, __m512i &sum) {
     __m256i temp_a = _mm256_loadu_epi8(pVect1);
     __m512i va = _mm512_cvtepi8_epi16(temp_a);
     pVect1 += 32;
@@ -23,6 +23,28 @@ static inline void L2SqrStep(int8_t *&pVect1, int8_t *&pVect2, __m512i &sum) {
     sum = _mm512_dpwssd_epi32(sum, diff, diff);
 }
 
+static inline void L2SqrStep(int8_t *&pVect1, int8_t *&pVect2, __m512i &sum) {
+    __m512i zeros = _mm512_setzero_si512();
+
+    __m512i va = _mm512_loadu_epi8(pVect1); // AVX512BW
+    pVect1 += 64;
+    __m512i va_ext = _mm512_movm_epi8(_mm512_cmplt_epi8_mask(va, zeros)); // AVX512BW
+
+    __m512i vb = _mm512_loadu_epi8(pVect2); // AVX512BW
+    pVect2 += 64;
+    __m512i vb_ext = _mm512_movm_epi8(_mm512_cmplt_epi8_mask(vb, zeros)); // AVX512BW
+
+    __m512i va_lo = _mm512_unpacklo_epi8(va, va_ext); // AVX512BW
+    __m512i vb_lo = _mm512_unpacklo_epi8(vb, vb_ext);
+    __m512i diff_lo = _mm512_sub_epi16(va_lo, vb_lo);
+    sum = _mm512_dpwssd_epi32(sum, diff_lo, diff_lo);
+
+    __m512i va_hi = _mm512_unpackhi_epi8(va, va_ext); // AVX512BW
+    __m512i vb_hi = _mm512_unpackhi_epi8(vb, vb_ext);
+    __m512i diff_hi = _mm512_sub_epi16(va_hi, vb_hi);
+    sum = _mm512_dpwssd_epi32(sum, diff_hi, diff_hi);
+}
+
 template <unsigned char residual> // 0..63
 float INT8_L2SqrSIMD64_AVX512F_BW_VL_VNNI(const void *pVect1v, const void *pVect2v,
                                           size_t dimension) {
@@ -50,13 +72,12 @@ float INT8_L2SqrSIMD64_AVX512F_BW_VL_VNNI(const void *pVect1v, const void *pVect2v,
     }
 
     if constexpr (residual >= 32) {
-        L2SqrStep(pVect1, pVect2, sum);
+        L2SqrHalfStep(pVect1, pVect2, sum);
     }
 
     // We dealt with the residual part. We are left with some multiple of 64-int_8.
     while (pVect1 < pEnd1) {
         L2SqrStep(pVect1, pVect2, sum);
-        L2SqrStep(pVect1, pVect2, sum);
     }
 
     return _mm512_reduce_add_epi32(sum);
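Note on the new L2SqrStep above: after sign-extending both inputs to 16 bits, _mm512_sub_epi16 yields differences in [-255, 255], so each squared difference is at most 65025 and the pairwise sums accumulated by _mm512_dpwssd_epi32 stay well within a 32-bit lane for typical dimensions. A scalar sketch of the intended result (hypothetical reference, not part of this commit):

#include <cstddef>
#include <cstdint>

// Hypothetical scalar reference for the int8 squared-L2 distance.
static int INT8_L2Sqr_Ref(const int8_t *a, const int8_t *b, size_t dim) {
    int sum = 0;
    for (size_t i = 0; i < dim; i++) {
        int d = static_cast<int>(a[i]) - static_cast<int>(b[i]); // in [-255, 255]
        sum += d * d;
    }
    return sum;
}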
