Skip to content

Commit

Permalink
Math: Optimize 16-bit matrix multiplication function
Browse files Browse the repository at this point in the history
Implemented optimizations in the 16-bit matrix multiplication function
by changing accumulator data type from int64_t to int32_t. This reduces
the instruction cycle count i.e. by ~8.18% for matrix multiplication.

Enhance pointer arithmetic within loops for better readability and
compiler optimization. Eliminate unnecessary conditionals by
directly handling Q0 data in the algorithm core logic.

Performance gains from these optimisation include a 36.31% reduction
in memory usage for matrix multiplication function

Signed-off-by: Shriram Shastry <[email protected]>
  • Loading branch information
ShriramShastry committed Jun 13, 2024
1 parent 862c645 commit 40b5e6e
Showing 1 changed file with 22 additions and 47 deletions.
69 changes: 22 additions & 47 deletions src/math/matrix.c
Original file line number Diff line number Diff line change
Expand Up @@ -24,58 +24,41 @@
* -EINVAL if input dimensions do not allow for multiplication.
* -ERANGE if the shift operation might cause integer overflow.
*/
int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c)
int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b,
struct mat_matrix_16b *c)
{
/* Validate matrix dimensions are compatible for multiplication */
if (a->columns != b->rows || a->rows != c->rows || b->columns != c->columns)
return -EINVAL;

int64_t s;
int16_t *x;
int16_t *y;
int16_t *z = c->data;
int i, j, k;
int y_inc = b->columns;
const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;


int32_t acc; /* Accumulator for dot product calculation */
int16_t *x, *y, *z = c->data; /* Pointers for matrices a, b, and c */
int i, j, k; /* Loop counters */
int y_inc = b->columns; /* Column increment for matrix b elements */
/* Calculate shift amount for adjusting fractional bits in the result */
const int shift = a->fractions + b->fractions - c->fractions - 1;

/* Check shift to ensure no integer overflow occurs during shifting */
if (shift < -1 || shift > 31)
return -ERANGE;

/* If all data is Q0 */
if (shift_minus_one == -1) {
for (i = 0; i < a->rows; i++) {
for (j = 0; j < b->columns; j++) {
s = 0;
x = a->data + a->columns * i;
y = b->data + j;
for (k = 0; k < b->rows; k++) {
s += (int32_t)(*x) * (*y);
x++;
y += y_inc;
}
*z = (int16_t)s; /* For Q16.0 */
z++;
}
}

return 0;
}

/* Matrix multiplication loop */
for (i = 0; i < a->rows; i++) {
for (j = 0; j < b->columns; j++) {
s = 0;
x = a->data + a->columns * i;
y = b->data + j;
acc = 0; /* Initialize accumulator for each element */
x = a->data + a->columns * i; /* Set x at the start of ith row of a */
y = b->data + j; /* Set y at the top of jth column of b */
/* Dot product loop */
for (k = 0; k < b->rows; k++) {
s += (int32_t)(*x) * (*y);
x++;
y += y_inc;
acc += (int32_t)(*x++) * (*y); /* Multiply & accumulate */
y += y_inc; /* Move to next row in the current column of b */
}
*z = (int16_t)(((s >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */
z++;
/* Assign computed value to c matrix, adjusting for fractional bits */
if (shift == -1)
*z = (int16_t)acc;
else
*z = (int16_t)(((acc >> shift) + 1) >> 1);
z++; /* Move to the next element in the output matrix */
}
}
return 0;
Expand All @@ -98,15 +81,7 @@ int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_
*/
int mat_multiply_elementwise(struct mat_matrix_16b *a, struct mat_matrix_16b *b,
struct mat_matrix_16b *c)
{
/* Validate matrix dimensions and non-null pointers */
if (!a || !b || !c ||
a->columns != b->columns || a->rows != b->rows ||
c->columns != a->columns || c->rows != a->rows) {
return -EINVAL;
}

int64_t p;
{ int64_t p;
int16_t *x = a->data;
int16_t *y = b->data;
int16_t *z = c->data;
Expand Down

0 comments on commit 40b5e6e

Please sign in to comment.