Skip to content

Commit

Permalink
copy!(SparseMatrixCSC, MKLSparseMatrix)
Browse files Browse the repository at this point in the history
  • Loading branch information
Alexey Stukalov authored and alyst committed Sep 17, 2024
1 parent 43d27e6 commit c26ad83
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion src/mklsparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,9 +201,27 @@ Base.convert(::Type{SparseMatrixCSC}, A::MKLSparseMatrix{SparseMatrixCSC{Tv, Ti}
function Base.convert(::Type{S}, A::MKLSparseMatrix{S}) where {S <: SparseMatrixCSR}
_A = extract_data(A)
# not converting the col indices depending on index_base
@show length(_A.nzval)
return S(_A.size..., copy(_A.major_starts), copy(_A.minor_val), copy(_A.nzval))
end

Base.convert(::Type{SparseMatrixCSR}, A::MKLSparseMatrix{SparseMatrixCSR{Tv, Ti}}) where {Tv, Ti} =
convert(SparseMatrixCSR{Tv, Ti}, A)

# copy the non-zero values from the MKL Sparse matrix A into the sparse matrix B
# A and B should have the same non-zero pattern
function Base.copy!(B::S, A::MKLSparseMatrix{S};
check_nzpattern::Bool = true) where {S <: SparseMatrixCSC}
_A = extract_data(A)
Ti = eltype(B.rowval)
length(_A.nzval) == nnz(B) || error(lazy"Number of nonzeros in the source ($(length(_A.nzval))) does not match the destination matrix ($(nnz(B)))")
size(B) == _A.size || throw(DimensionMismatch(lazy"Size of the source $(_A.size) does not match the destination $(size(B))"))
if check_nzpattern
B.colptr == _A.major_starts || error("Source and destination colptr do not match")
rowval_match = _A.index_base == SPARSE_INDEX_BASE_ZERO ?
all((a, b) -> a + one(Ti) == b, zip(_A.minor_val, B.rowval)) : # convert to 1-based
_A.minor_val == B.rowval
rowval_match || error("Source and destination rowval do not match")
end
(pointer(B.nzval) != pointer(_A.nzval)) && copy!(B.nzval, _A.nzval)
return B
end

0 comments on commit c26ad83

Please sign in to comment.