Skip to content

Commit

Permalink
mul!(dense | sparse, sparse, sparse) support
Browse files Browse the repository at this point in the history
  • Loading branch information
alyst committed Sep 17, 2024
1 parent c26ad83 commit 2002b71
Show file tree
Hide file tree
Showing 2 changed files with 205 additions and 0 deletions.
141 changes: 141 additions & 0 deletions src/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,147 @@ function mm!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_d
return C
end

# C := op(A) * B, where C is sparse.
#
# Checks `transA` and the operand sizes, wraps `A` and `B` into temporary MKL
# handles and calls `mkl_call` for `:mkl_sparse_spmmI`.
# Returns the product as an MKL-owned handle wrapped into `MKLSparseMatrix{typeof(A)}`
# (same convention as `sp2m`: the storage format of the result is assumed to match `A`).
# The temporary handles to `A` and `B` are destroyed on every exit path,
# including when handle creation, the MKL call or `check_status` throws.
function spmm(transA::Char, A::AbstractSparseMatrix{T}, B::AbstractSparseMatrix{T}) where T
    check_trans(transA)
    check_mat_op_sizes(nothing, A, transA, B, 'N')
    Cout = Ref{sparse_matrix_t}()
    hA = MKLSparseMatrix(A)
    try
        hB = MKLSparseMatrix(B)
        try
            res = mkl_call(Val{:mkl_sparse_spmmI}(), typeof(A),
                           transA, hA, hB, Cout)
            check_status(res)
        finally
            destroy(hB)
        end
    finally
        destroy(hA)
    end
    # NOTE: we are guessing what is the storage format of C
    return MKLSparseMatrix{typeof(A)}(Cout[])
end

# C := op(A) * B, where C is dense.
#
# Checks `transa` and the sizes of `C`, `A` and `B`, wraps `A` and `B` into temporary
# MKL handles and calls `mkl_call` for `:mkl_sparse_T_spmmdI`, which writes into `C`.
# `dense_layout` is forwarded to MKL as the layout of `C`; `stride(C, 2)` is passed
# as its leading dimension.
# The temporary handles are destroyed on every exit path, including when
# handle creation, the MKL call or `check_status` throws.
# Returns `C`.
function spmmd!(transa::Char, A::AbstractSparseMatrix{T}, B::AbstractSparseMatrix{T},
                C::StridedMatrix{T};
                dense_layout::sparse_layout_t = SPARSE_LAYOUT_COLUMN_MAJOR
) where T
    check_trans(transa)
    check_mat_op_sizes(C, A, transa, B, 'N')
    ldC = stride(C, 2)
    hA = MKLSparseMatrix(A)
    try
        hB = MKLSparseMatrix(B)
        try
            res = mkl_call(Val{:mkl_sparse_T_spmmdI}(), typeof(A),
                           transa, hA, hB, dense_layout, C, ldC)
            check_status(res)
        finally
            destroy(hB)
        end
    finally
        destroy(hA)
    end
    return C
end

# C := opA(A) * opB(B), where C is sparse.
#
# Checks both operation codes and the operand sizes, wraps `A` and `B` into temporary
# MKL handles and calls `mkl_call` for `:mkl_sparse_sp2mI` with `SPARSE_STAGE_FULL_MULT`.
# Returns the product as an MKL-owned handle wrapped into `MKLSparseMatrix{typeof(A)}`.
# The temporary handles to `A` and `B` are destroyed on every exit path,
# including when handle creation, the MKL call or `check_status` throws.
function sp2m(transA::Char, A::AbstractSparseMatrix{T}, descrA::matrix_descr,
              transB::Char, B::AbstractSparseMatrix{T}, descrB::matrix_descr) where T
    check_trans(transA)
    check_trans(transB)
    check_mat_op_sizes(nothing, A, transA, B, transB)
    Cout = Ref{sparse_matrix_t}()
    hA = MKLSparseMatrix(A)
    try
        hB = MKLSparseMatrix(B)
        try
            res = mkl_call(Val{:mkl_sparse_sp2mI}(), typeof(A),
                           transA, descrA, hA, transB, descrB, hB,
                           SPARSE_STAGE_FULL_MULT, Cout)
            check_status(res)
        finally
            destroy(hB)
        end
    finally
        destroy(hA)
    end
    # NOTE: we are guessing what is the storage format of C
    return MKLSparseMatrix{typeof(A)}(Cout[])
end

# C := opA(A) * opB(B), where C is sparse, in-place version.
# C should have the correct size and sparsity pattern.
#
# When `check_nzpattern` is true, the product is first computed at the
# `SPARSE_STAGE_NNZ_COUNT` stage and the total number of nonzeros and the column pointers
# of `C` are compared with it; a mismatch raises an error (the row indices are not compared).
# The product is then computed in full into a temporary MKL matrix and copied into `C`
# with `copy!(C, hCopy; check_nzpattern)`.
# All temporary MKL handles (`A`, `B`, the pattern matrix and the full product) are
# destroyed on every exit path, including when a status check or the copy throws.
# Returns `C`.
function sp2m!(transA::Char, A::AbstractSparseMatrix{T}, descrA::matrix_descr,
               transB::Char, B::AbstractSparseMatrix{T}, descrB::matrix_descr,
               C::SparseMatrixCSC{T};
               check_nzpattern::Bool = true
) where T
    check_trans(transA)
    check_trans(transB)
    check_mat_op_sizes(C, A, transA, B, transB)
    # declared outside of the try blocks so that it is visible after the cleanup
    hCopy_ref = Ref{sparse_matrix_t}()
    hA = MKLSparseMatrix(A)
    try
        hB = MKLSparseMatrix(B)
        try
            if check_nzpattern
                # pre-multiply A * B to get the number of nonzeros per column in the result
                CptnOut = Ref{sparse_matrix_t}()
                ptn_res = mkl_call(Val{:mkl_sparse_sp2mI}(), typeof(A),
                                   transA, descrA, hA, transB, descrB, hB,
                                   SPARSE_STAGE_NNZ_COUNT, CptnOut)
                check_status(ptn_res)
                hCptn = MKLSparseMatrix{typeof(A)}(CptnOut[])
                try
                    # check if C has the same per-column nonzeros as the result
                    _C = extract_data(hCptn)
                    _Cnnz = _C.major_starts[end] - 1
                    nnz(C) == _Cnnz || error(lazy"Number of nonzeros in the destination matrix ($(nnz(C))) does not match the result ($(_Cnnz))")
                    C.colptr == _C.major_starts || error(lazy"Nonzeros structure of the destination matrix does not match the result")
                finally
                    destroy(hCptn)
                end
                # FIXME rowval not checked
            end
            # FIXME the optimal way would be to create the MKLSparse handle to C reusing its arrays
            #       and do SPARSE_STAGE_FINALIZE_MULT to directly write to the C.nzval
            #       but that causes segfaults when the handle is destroyed
            #       (also the partial mkl_sparse_copy(C) workaround to reuse the nz structure segfaults)
            #hC = MKLSparseMatrix(C)
            #hC_ref = Ref(hC)
            #res = mkl_call(Val{:mkl_sparse_sp2mI}(), typeof(A),
            #               transA, descrA, hA, transB, descrB, hB,
            #               SPARSE_STAGE_FINALIZE_MULT, hC_ref)
            #@assert hC_ref[] == hC
            # so instead we do the full multiplication and copy the result into C nzvals
            res = mkl_call(Val{:mkl_sparse_sp2mI}(), typeof(A),
                           transA, descrA, hA, transB, descrB, hB,
                           SPARSE_STAGE_FULL_MULT, hCopy_ref)
            check_status(res)
        finally
            destroy(hB)
        end
    finally
        destroy(hA)
    end
    # a null result handle leaves C untouched
    if hCopy_ref[] != C_NULL
        hCopy = MKLSparseMatrix{typeof(A)}(hCopy_ref[])
        try
            copy!(C, hCopy; check_nzpattern)
        finally
            # release the temporary product even if the copy fails
            destroy(hCopy)
        end
    end
    return C
end

# C := alpha * opA(A) * opB(B) + beta * C, where C is dense.
#
# Checks both operation codes and the sizes of `C`, `A` and `B`, wraps `A` and `B` into
# temporary MKL handles and calls `mkl_call` for `:mkl_sparse_T_sp2mdI`, which writes into `C`.
# `dense_layout` is forwarded to MKL as the layout of `C`; `stride(C, 2)` is passed
# as its leading dimension.
# On a non-success status the operation codes and descriptors are logged at the debug
# level before `check_status` handles the status.
# The temporary handles are destroyed on every exit path.
# Returns `C`.
function sp2md!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, descrA::matrix_descr,
                transB::Char, B::AbstractSparseMatrix{T}, descrB::matrix_descr,
                beta::T, C::StridedMatrix{T};
                dense_layout::sparse_layout_t = SPARSE_LAYOUT_COLUMN_MAJOR
) where T
    check_trans(transA)
    check_trans(transB)
    check_mat_op_sizes(C, A, transA, B, transB)
    ldC = stride(C, 2)
    hA = MKLSparseMatrix(A)
    try
        hB = MKLSparseMatrix(B)
        try
            res = mkl_call(Val{:mkl_sparse_T_sp2mdI}(), typeof(A),
                           transA, descrA, hA, transB, descrB, hB,
                           alpha, beta,
                           C, dense_layout, ldC)
            if res != SPARSE_STATUS_SUCCESS
                # diagnostic context goes to the logger instead of being printed to stdout
                @debug "sp2md! failed" transA descrA transB descrB
            end
            check_status(res)
        finally
            destroy(hB)
        end
    finally
        destroy(hA)
    end
    return C
end


# find y: op(A) * y = alpha * x
function trsv!(transA::Char, alpha::T, A::AbstractSparseMatrix{T}, descr::matrix_descr,
x::StridedVector{T}, y::StridedVector{T}
Expand Down
64 changes: 64 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,27 @@ function mul!(C::StridedMatrix{T}, A::SimpleOrSpecialOrAdjMat{T, S},
mm!(transA, T(alpha), unwrapA, descrA, B, T(beta), C)
end

# mul!(dense, sparse, sparse, a, b)
# C := alpha * A * B + beta * C, where A and B are (possibly wrapped) sparse matrices
# and C is dense; implemented via sp2md!().
#
# `alpha` and `beta` are converted to `T` before the call.
# Throws an `ArgumentError` if `A` or `B` is described as symmetric, but the unwrapped
# matrix is not symmetric: the unwrapped matrix is multiplied as a general one (see FIXME).
function mul!(C::StridedMatrix{T}, A::SimpleOrSpecialOrAdjMat{T, S},
              B::SimpleOrSpecialOrAdjMat{T, S}, alpha::Number, beta::Number
) where {T <: BlasFloat, S <: SparseMat{T}}
    transA, descrA, unwrapA = describe_and_unwrap(A)
    transB, descrB, unwrapB = describe_and_unwrap(B)
    # FIXME only general matrices are supported by sp2m in MKL SparseBLAS
    #       should the elements of the special matrices be fixed?
    # explicit checks instead of @assert, since assertions may be disabled
    if descrA.type == SPARSE_MATRIX_TYPE_SYMMETRIC && !issymmetric(unwrapA)
        throw(ArgumentError("A must be symmetric"))
    end
    if descrB.type == SPARSE_MATRIX_TYPE_SYMMETRIC && !issymmetric(unwrapB)
        throw(ArgumentError("B must be symmetric"))
    end
    # reset both descriptors to a plain general matrix
    descrA = matrix_descr(descrA, type = SPARSE_MATRIX_TYPE_GENERAL, diag = SPARSE_DIAG_NON_UNIT, mode = SPARSE_FILL_MODE_FULL)
    descrB = matrix_descr(descrB, type = SPARSE_MATRIX_TYPE_GENERAL, diag = SPARSE_DIAG_NON_UNIT, mode = SPARSE_FILL_MODE_FULL)
    return sp2md!(transA, T(alpha), unwrapA, descrA,
                  transB, unwrapB, descrB,
                  T(beta), C)
end

# mul!(dense, dense, sparse, a, b)
# ColMajorRes = ColMajorMtx*SparseMatrixCSC is implemented via
# RowMajorRes = SparseMatrixCSR*RowMajorMtx Sparse MKL BLAS calls
Expand Down Expand Up @@ -87,6 +108,34 @@ mul!(C::StridedMatrix{T}, A::StridedMatrix{T},
B::SimpleOrSpecialOrAdjMat{T, S}) where {T <: BlasFloat, S <: SparseMat{T}} =
mul!(C, A, B, one(T), zero(T))

# mul!(dense, sparse, sparse): C := A * B
# forwards to the 5-argument mul!() with alpha = 1 and beta = 0 (which calls sp2md!())
function mul!(C::StridedMatrix{T}, A::SimpleOrSpecialOrAdjMat{T, S},
              B::SimpleOrSpecialOrAdjMat{T, S}
) where {T <: BlasFloat, S <: SparseMat{T}}
    alpha, beta = one(T), zero(T)
    return mul!(C, A, B, alpha, beta)
end

# mul!(sparse, sparse, sparse): C := A * B
# forwards to unsafe_mul!() with the check of the result's non-zero pattern enabled
function mul!(C::SparseMatrixCSC{T}, A::SimpleOrSpecialOrAdjMat{T, S},
              B::SimpleOrSpecialOrAdjMat{T, S}
) where {T <: BlasFloat, S <: SparseMat{T}}
    return unsafe_mul!(C, A, B; check_nzpattern = true)
end

# unsafe_mul!() allows disabling the check for the result's non-zero pattern
#
# C := A * B, where A and B are (possibly wrapped) sparse matrices and C is sparse;
# implemented via sp2m!(), `check_nzpattern` is forwarded to it.
# Throws an `ArgumentError` if `A` or `B` is described as symmetric, but the unwrapped
# matrix is not symmetric: the unwrapped matrix is multiplied as a general one (see FIXME),
# so the product would not match the wrapped matrix.
function unsafe_mul!(C::SparseMatrixCSC{T}, A::SimpleOrSpecialOrAdjMat{T, S},
                     B::SimpleOrSpecialOrAdjMat{T, S};
                     check_nzpattern::Bool = true
) where {T <: BlasFloat, S <: SparseMat{T}}
    transA, descrA, unwrapA = describe_and_unwrap(A)
    transB, descrB, unwrapB = describe_and_unwrap(B)
    # FIXME only general matrices are supported by sp2m in MKL SparseBLAS
    #       should the elements of the special matrices be fixed?
    # same validation as in mul!(dense, sparse, sparse, a, b)
    if descrA.type == SPARSE_MATRIX_TYPE_SYMMETRIC && !issymmetric(unwrapA)
        throw(ArgumentError("A must be symmetric"))
    end
    if descrB.type == SPARSE_MATRIX_TYPE_SYMMETRIC && !issymmetric(unwrapB)
        throw(ArgumentError("B must be symmetric"))
    end
    descrA = matrix_descr(descrA, type = SPARSE_MATRIX_TYPE_GENERAL)
    descrB = matrix_descr(descrB, type = SPARSE_MATRIX_TYPE_GENERAL)
    return sp2m!(transA, unwrapA, descrA,
                 transB, unwrapB, descrB,
                 parent(C); check_nzpattern)
end

# define 4-arg ldiv!(C, A, B, a) (C := alpha*inv(A)*B) that is not present in standard LinearAlgebra
# redefine 3-arg ldiv!(C, A, B) using 4-arg ldiv!(C, A, B, 1)
function ldiv!(y::StridedVector{T}, A::SimpleOrSpecialOrAdjMat{T, S},
Expand All @@ -101,6 +150,21 @@ function LinearAlgebra.ldiv!(C::StridedMatrix{T}, A::SimpleOrSpecialOrAdjMat{T,
trsm!(transA, alpha, unwrapA, descrA, B, C)
end

# sparse := sparse * sparse
#
# Multiplies two (possibly wrapped) sparse matrices via sp2m() and converts
# the resulting MKL matrix into the sparse matrix type `S`.
# Throws an `ArgumentError` if `A` or `B` is described as symmetric, but the unwrapped
# matrix is not symmetric: the unwrapped matrix is multiplied as a general one (see FIXME),
# so the product would not match the wrapped matrix.
function (*)(A::SimpleOrSpecialOrAdjMat{T, S},
             B::SimpleOrSpecialOrAdjMat{T, S}
) where {T <: BlasFloat, S <: SparseMat{T}}
    transA, descrA, unwrapA = describe_and_unwrap(A)
    transB, descrB, unwrapB = describe_and_unwrap(B)
    # FIXME only general matrices are supported by sp2m in MKL SparseBLAS
    #       should the elements of the special matrices be fixed?
    # same validation as in mul!(dense, sparse, sparse, a, b)
    if descrA.type == SPARSE_MATRIX_TYPE_SYMMETRIC && !issymmetric(unwrapA)
        throw(ArgumentError("A must be symmetric"))
    end
    if descrB.type == SPARSE_MATRIX_TYPE_SYMMETRIC && !issymmetric(unwrapB)
        throw(ArgumentError("B must be symmetric"))
    end
    descrA = matrix_descr(descrA, type = SPARSE_MATRIX_TYPE_GENERAL)
    descrB = matrix_descr(descrB, type = SPARSE_MATRIX_TYPE_GENERAL)
    res = sp2m(transA, unwrapA, descrA,
               transB, unwrapB, descrB)
    # NOTE(review): the MKL handle `res` is never destroyed here; if `convert` copies
    # the data out of the handle (TODO confirm), `res` should be destroyed afterwards
    return convert(S, res)
end

if VERSION < v"1.10"
# stdlib v1.9 does not provide these methods

Expand Down

0 comments on commit 2002b71

Please sign in to comment.