Skip to content

Commit

Permalink
Math: Optimize 16Bit elementwise matrix multiplication function
Browse files Browse the repository at this point in the history
Implemented optimizations in the 16-bit elementwise
matrix multiplication function by changing accumulator
data type from int64_t to int32_t. This reduces the
instruction cycle count i.e. reducing cycle count by
~51.18%.

Enhance pointer arithmetic within loops for better
readability and compiler optimization opportunities

Eliminate unnecessary conditionals by directly
handling Q0 data in the algorithm's core logic

Update fractional bit shift and rounding logic for more
accurate fixed-point calcualations

Performance gains from these optimizations include a 1.08%
reduction in memory usage for the elementwise matrix
multiplication.

Signed-off-by: Shriram Shastry <[email protected]>
  • Loading branch information
ShriramShastry committed Jun 13, 2024
1 parent 40b5e6e commit f447221
Showing 1 changed file with 24 additions and 23 deletions.
47 changes: 24 additions & 23 deletions src/math/matrix.c
Original file line number Diff line number Diff line change
Expand Up @@ -81,36 +81,37 @@ int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b,
*/
int mat_multiply_elementwise(struct mat_matrix_16b *a, struct mat_matrix_16b *b,
struct mat_matrix_16b *c)
{ int64_t p;
{
/* Validate matrix dimensions and non-null pointers */
if (!a || !b || !c ||
a->columns != b->columns || a->rows != b->rows ||
c->columns != a->columns || c->rows != a->rows) {
return -EINVAL;
}

int16_t *x = a->data;
int16_t *y = b->data;
int16_t *z = c->data;
int i;
const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1;
int32_t prod;

if (a->columns != b->columns || b->columns != c->columns ||
a->rows != b->rows || b->rows != c->rows) {
return -EINVAL;
}
/* Compute the total number of elements in the matrices */
const int total_elements = a->rows * a->columns;
/* Compute the required bit shift based on the fractional part of each matrix */
const int shift = a->fractions + b->fractions - c->fractions - 1;

/* If all data is Q0 */
if (shift_minus_one == -1) {
for (i = 0; i < a->rows * a->columns; i++) {
/* Perform multiplication with or without adjusting the fractional bits */
if (shift == -1) {
/* Direct multiplication when no adjustment for fractional bits is needed */
for (int i = 0; i < total_elements; i++, x++, y++, z++)
*z = *x * *y;
x++;
y++;
z++;
} else {
/* Multiplication with rounding to account for the fractional bits */
for (int i = 0; i < total_elements; i++, x++, y++, z++) {
/* Multiply elements as int32_t */
prod = (int32_t)(*x) * *y;
/* Adjust and round the result */
*z = (int16_t)(((prod >> shift) + 1) >> 1);
}

return 0;
}

for (i = 0; i < a->rows * a->columns; i++) {
p = (int32_t)(*x) * *y;
*z = (int16_t)(((p >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */
x++;
y++;
z++;
}

return 0;
Expand Down

0 comments on commit f447221

Please sign in to comment.