diff --git a/src/math/matrix.c b/src/math/matrix.c index d9c76cfeecff..b851e132b2c2 100644 --- a/src/math/matrix.c +++ b/src/math/matrix.c @@ -82,36 +82,36 @@ int mat_multiply(struct mat_matrix_16b *a, struct mat_matrix_16b *b, int mat_multiply_elementwise(struct mat_matrix_16b *a, struct mat_matrix_16b *b, struct mat_matrix_16b *c) { + /* Validate matrix dimensions and non-null pointers */ + if (!a || !b || !c || + a->columns != b->columns || a->rows != b->rows || + c->columns != a->columns || c->rows != a->rows) { + return -EINVAL; + } + int16_t *x = a->data; int16_t *y = b->data; int16_t *z = c->data; - int64_t p; - int i; - const int shift_minus_one = a->fractions + b->fractions - c->fractions - 1; + int32_t prod; - if (a->columns != b->columns || b->columns != c->columns || - a->rows != b->rows || b->rows != c->rows) { - return -EINVAL; - } + /* Compute the total number of elements in the matrices */ + const int total_elements = a->rows * a->columns; + /* Compute the required bit shift based on the fractional part of each matrix */ + const int shift = a->fractions + b->fractions - c->fractions - 1; - /* If all data is Q0 */ - if (shift_minus_one == -1) { - for (i = 0; i < a->rows * a->columns; i++) { + /* Perform multiplication with or without adjusting the fractional bits */ + if (shift == -1) { + /* Direct multiplication when no adjustment for fractional bits is needed */ + for (int i = 0; i < total_elements; i++, x++, y++, z++) *z = *x * *y; - x++; - y++; - z++; + } else { + /* Multiplication with rounding to account for the fractional bits */ + for (int i = 0; i < total_elements; i++, x++, y++, z++) { + /* Multiply elements as int32_t */ + prod = (int32_t)(*x) * *y; + /* Adjust and round the result */ + *z = (int16_t)(((prod >> shift) + 1) >> 1); } - - return 0; - } - - for (i = 0; i < a->rows * a->columns; i++) { - p = (int32_t)(*x) * *y; - *z = (int16_t)(((p >> shift_minus_one) + 1) >> 1); /*Shift to Qx.y */ - x++; - y++; - z++; } return 0;