From 40b5e6e09cb56627c338b7e7c054a7a09a6e1786 Mon Sep 17 00:00:00 2001 From: Shriram Shastry Date: Thu, 13 Jun 2024 11:06:22 +0530 Subject: [PATCH] Math: Optimize 16-bit matrix multiplication function Implemented optimizations in the 16-bit matrix multiplication function by changing accumulator data type from int64_t to int32_t. This reduces the instruction cycle count i.e. by ~8.18% for matrix multiplication. Enhance pointer arithmetic within loops for better readability and compiler optimization. Eliminate unnecessary conditionals by directly handling Q0 data in the algorithm core logic. Performance gains from these optimisation include a 36.31% reduction in memory usage for matrix multiplication function Signed-off-by: Shriram Shastry --- src/math/matrix.c | 69 +++++++++++++++-------------------------------- 1 file changed, 22 insertions(+), 47 deletions(-) diff --git a/src/math/matrix.c b/src/math/matrix.c index 7ff418178650..e1a160b51516 100644 --- a/src/math/matrix.c +++ b/src/math/matrix.c @@ -24,58 +24,41 @@ * -EINVAL if input dimensions do not allow for multiplication. * -ERANGE if the shift operation might cause integer overflow. */ -int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c) +int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, + struct mat_matrix_16b *c) { /* Validate matrix dimensions are compatible for multiplication */ if (a->columns != b->rows || a->rows != c->rows || b->columns != c->columns) return -EINVAL; - int64_t s; - int16_t *x; - int16_t *y; - int16_t *z = c->data; - int i, j, k; - int y_inc = b->columns; - const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1; - - + int32_t acc; /* Accumulator for dot product calculation */ + int16_t *x, *y, *z = c->data; /* Pointers for matrices a, b, and c */ + int i, j, k; /* Loop counters */ + int y_inc = b->columns; /* Column increment for matrix b elements */ + /* Calculate shift amount for adjusting fractional bits in the result */ + const int shift = a->fractions + b->fractions - c->fractions - 1; /* Check shift to ensure no integer overflow occurs during shifting */ if (shift < -1 || shift > 31) return -ERANGE; - /* If all data is Q0 */ - if (shift_minus_one == -1) { - for (i = 0; i < a->rows; i++) { - for (j = 0; j < b->columns; j++) { - s = 0; - x = a->data + a->columns * i; - y = b->data + j; - for (k = 0; k < b->rows; k++) { - s += (int32_t)(*x) * (*y); - x++; - y += y_inc; - } - *z = (int16_t)s; /* For Q16.0 */ - z++; - } - } - - return 0; - } - + /* Matrix multiplication loop */ for (i = 0; i < a->rows; i++) { for (j = 0; j < b->columns; j++) { - s = 0; - x = a->data + a->columns * i; - y = b->data + j; + acc = 0; /* Initialize accumulator for each element */ + x = a->data + a->columns * i; /* Set x at the start of ith row of a */ + y = b->data + j; /* Set y at the top of jth column of b */ + /* Dot product loop */ for (k = 0; k < b->rows; k++) { - s += (int32_t)(*x) * (*y); - x++; - y += y_inc; + acc += (int32_t)(*x++) * (*y); /* Multiply & accumulate */ + y += y_inc; /* Move to next row in the current column of b */ } - *z = (int16_t)(((s >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */ - z++; + /* Assign computed value to c matrix, adjusting for fractional bits */ + if (shift == -1) + *z = (int16_t)acc; + else + *z = (int16_t)(((acc >> shift) + 1) >> 1); + z++; /* Move to the next element in the output matrix */ } } return 0; @@ -98,15 +81,7 @@ int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_ */ int mat_multiply_elementwise(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c) -{ - /* Validate matrix dimensions and non-null pointers */ - if (!a || !b || !c || - a->columns != b->columns || a->rows != b->rows || - c->columns != a->columns || c->rows != a->rows) { - return -EINVAL; - } - - int64_t p; +{ int64_t p; int16_t *x = a->data; int16_t *y = b->data; int16_t *z = c->data;