Skip to content

Commit

Permalink
spmm() & spmmd!() (untested)
Browse files Browse the repository at this point in the history
  • Loading branch information
alyst committed Jan 3, 2025
1 parent b01f96e commit 3b9b800
Showing 1 changed file with 37 additions and 0 deletions.
37 changes: 37 additions & 0 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,43 @@ function mm!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_d
return C
end

# C := op(A) * B, where C is sparse
function spmm(transA::Char, A::S, B::S) where {S <: AbstractSparseMatrix}
check_trans(transA)
check_mat_op_sizes(nothing, A, transA, B, 'N')
Cout = Ref{sparse_matrix_t}()
hA = MKLSparseMatrix(A)
hB = MKLSparseMatrix(B)
res = mkl_call(Val{:mkl_sparse_spmmI}(), S,
transA, hA, hB, Cout)
destroy(hA)
destroy(hB)
check_status(res)
# NOTE: we are guessing what is the storage format of C
hC = MKLSparseMatrix{S}(Cout[])
C = convert(S, hC)
destroy(hC)
return C
end

# C := op(A) * B, where C is dense
function spmmd!(transa::Char, A::S, B::S,
C::StridedMatrix{T};
dense_layout::sparse_layout_t = SPARSE_LAYOUT_COLUMN_MAJOR
) where {S <: AbstractSparseMatrix{T}} where T
check_trans(transa)
check_mat_op_sizes(C, A, transa, B, 'N')
ldC = stride(C, 2)
hA = MKLSparseMatrix(A)
hB = MKLSparseMatrix(B)
res = mkl_call(Val{:mkl_sparse_T_spmmdI}(), S,
transa, hA, hB, dense_layout, C, ldC)
destroy(hA)
destroy(hB)
check_status(res)
return C
end

# C := opA(A) * opB(B), where C is sparse
function sp2m(transA::Char, A::S, descrA::matrix_descr,
transB::Char, B::S, descrB::matrix_descr
Expand Down

0 comments on commit 3b9b800

Please sign in to comment.