Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the new API #33

Merged
merged 3 commits into from
Jul 7, 2024
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Use the new API to support both LP64 and ILP64 API
amontoison committed Jul 7, 2024
commit 5e5b95b1e0ab5d8cbdd0912a336e72971747c301
6 changes: 3 additions & 3 deletions src/MKLSparse.jl
Original file line number Diff line number Diff line change
@@ -21,20 +21,20 @@ end
INTERFACE_GNU
end

function set_threading_layer(layer::Threading = THREADING_SEQUENTIAL)
function set_threading_layer(layer::Threading = THREADING_INTEL)
err = @ccall libmkl_rt.MKL_Set_Threading_Layer(layer::Cint)::Cint
(err == -1) && error("MKL_Set_Threading_Layer() returned -1")
return nothing
end

function set_interface_layer(interface::Interface = INTERFACE_ILP64)
function set_interface_layer(interface::Interface = INTERFACE_LP64)
err = @ccall libmkl_rt.MKL_Set_Interface_Layer(interface::Cint)::Cint
(err == -1) && error("MKL_Set_Interface_Layer() returned -1")
return nothing
end

function __init__()
set_interface_layer(Base.USE_BLAS64 ? INTERFACE_ILP64 : INTERFACE_LP64)
set_interface_layer(INTERFACE_LP64)
end

# Wrappers generated by Clang.jl
8 changes: 4 additions & 4 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -39,7 +39,7 @@ function _check_mat_mult_matvec(C, A, B, tA)
end

function cscmv!(transa::Char, α::T, matdescra::String,
A::SparseMatrixCSC{T, BlasInt}, x::StridedVector{T},
A::SparseMatrixCSC{T, Int32}, x::StridedVector{T},
β::T, y::StridedVector{T}) where {T <: BlasFloat}
_check_transa(transa)
_check_mat_mult_matvec(y, A, x, transa)
@@ -52,7 +52,7 @@ function cscmv!(transa::Char, α::T, matdescra::String,
end

function cscmm!(transa::Char, α::T, matdescra::String,
A::SparseMatrixCSC{T, BlasInt}, B::StridedMatrix{T},
A::SparseMatrixCSC{T, Int32}, B::StridedMatrix{T},
β::T, C::StridedMatrix{T}) where {T <: BlasFloat}
_check_transa(transa)
_check_mat_mult_matvec(C, A, B, transa)
@@ -67,7 +67,7 @@ function cscmm!(transa::Char, α::T, matdescra::String,
end

function cscsv!(transa::Char, α::T, matdescra::String,
A::SparseMatrixCSC{T, BlasInt}, x::StridedVector{T},
A::SparseMatrixCSC{T, Int32}, x::StridedVector{T},
y::StridedVector{T}) where {T <: BlasFloat}
n = checksquare(A)
_check_transa(transa)
@@ -81,7 +81,7 @@ function cscsv!(transa::Char, α::T, matdescra::String,
end

function cscsm!(transa::Char, α::T, matdescra::String,
A::SparseMatrixCSC{T, BlasInt}, B::StridedMatrix{T},
A::SparseMatrixCSC{T, Int32}, B::StridedMatrix{T},
C::StridedMatrix{T}) where {T <: BlasFloat}
mB, nB = size(B)
mC, nC = size(C)
85 changes: 44 additions & 41 deletions src/generic.jl
Original file line number Diff line number Diff line change
@@ -1,50 +1,53 @@
for T in (:Float32, :Float64, :ComplexF32, :ComplexF64)
for SparseMatrix in (:(SparseMatrixCSC{$T,BlasInt}), :(MKLSparse.SparseMatrixCSR{$T,BlasInt}), :(MKLSparse.SparseMatrixCOO{$T,BlasInt}))
INT_TYPES = Base.USE_BLAS64 ? (:Int32, :Int64) : (:Int32,)
for INT in INT_TYPES
for SparseMatrix in (:(SparseMatrixCSC{$T,$INT}), :(MKLSparse.SparseMatrixCSR{$T,$INT}), :(MKLSparse.SparseMatrixCOO{$T,$INT}))

fname_mv = Symbol("mkl_sparse_", mkl_type_specifier(T), "_mv")
fname_mm = Symbol("mkl_sparse_", mkl_type_specifier(T), "_mm")
fname_trsv = Symbol("mkl_sparse_", mkl_type_specifier(T), "_trsv")
fname_trsm = Symbol("mkl_sparse_", mkl_type_specifier(T), "_trsm")
fname_mv = Symbol("mkl_sparse_", mkl_type_specifier(T), "_mv" , mkl_integer_specifier(INT))
fname_mm = Symbol("mkl_sparse_", mkl_type_specifier(T), "_mm" , mkl_integer_specifier(INT))
fname_trsv = Symbol("mkl_sparse_", mkl_type_specifier(T), "_trsv", mkl_integer_specifier(INT))
fname_trsm = Symbol("mkl_sparse_", mkl_type_specifier(T), "_trsm", mkl_integer_specifier(INT))

@eval begin
function mv!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedVector{$T}, beta::$T, y::StridedVector{$T})
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
$fname_mv(operation, alpha, MKLSparseMatrix(A), descr, x, beta, y)
return y
end
@eval begin
function mv!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedVector{$T}, beta::$T, y::StridedVector{$T})
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
$fname_mv(operation, alpha, MKLSparseMatrix(A), descr, x, beta, y)
return y
end

function mm!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedMatrix{$T}, beta::$T, y::StridedMatrix{$T})
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
columns = size(y, 2)
ldx = stride(x, 2)
ldy = stride(y, 2)
$fname_mm(operation, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, beta, y, ldy)
return y
end
function mm!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedMatrix{$T}, beta::$T, y::StridedMatrix{$T})
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
columns = size(y, 2)
ldx = stride(x, 2)
ldy = stride(y, 2)
$fname_mm(operation, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, beta, y, ldy)
return y
end

function trsv!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedVector{$T}, y::StridedVector{$T})
checksquare(A)
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
$fname_trsv(operation, alpha, MKLSparseMatrix(A), descr, x, y)
return y
end
function trsv!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedVector{$T}, y::StridedVector{$T})
checksquare(A)
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
$fname_trsv(operation, alpha, MKLSparseMatrix(A), descr, x, y)
return y
end

function trsm!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedMatrix{$T}, y::StridedMatrix{$T})
checksquare(A)
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
columns = size(y, 2)
ldx = stride(x, 2)
ldy = stride(y, 2)
$fname_trsm(operation, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, y, ldy)
return y
function trsm!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedMatrix{$T}, y::StridedMatrix{$T})
checksquare(A)
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
columns = size(y, 2)
ldx = stride(x, 2)
ldy = stride(y, 2)
$fname_trsm(operation, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, y, ldy)
return y
end
end
end
end
155 changes: 79 additions & 76 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -1,85 +1,88 @@
import Base: \, *
import LinearAlgebra: mul!, ldiv!

for T in (Float32, Float64, ComplexF32, ComplexF64)

tag_wrappers = ((identity , identity ),
(M -> :(Symmetric{$T, $M}), A -> :(parent($A))),
(M -> :(Hermitian{$T, $M}), A -> :(parent($A))))

triangle_wrappers = ((M -> :(LowerTriangular{$T, $M}) , A -> :(parent($A))),
(M -> :(UnitLowerTriangular{$T, $M}), A -> :(parent($A))),
(M -> :(UpperTriangular{$T, $M}) , A -> :(parent($A))),
(M -> :(UnitUpperTriangular{$T, $M}), A -> :(parent($A))))

op_wrappers = ((identity , 'N', identity ),
(M -> :(Transpose{$T, $M}), 'T', A -> :(parent($A))),
(M -> :(Adjoint{$T, $M}) , 'C', A -> :(parent($A))))

for SparseMatrixType in (:(SparseMatrixCSC{$T, $BlasInt}), :(MKLSparse.SparseMatrixCOO{$T, $BlasInt}), :(MKLSparse.SparseMatrixCSR{$T, $BlasInt}))
for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers
TypeA = wrapa(taga(SparseMatrixType))

@eval begin
function LinearAlgebra.mul!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}, alpha::Number, beta::Number)
# return cscmv!($transa, $T(alpha), $matdescra(A), $(untaga(unwrapa(:A))), x, $T(beta), y)
return mv!($transa, $T(alpha), $(untaga(unwrapa(:A))), $matrixdescra(A), x, $T(beta), y)
end

function LinearAlgebra.mul!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}, alpha::Number, beta::Number)
# return cscmm!($transa, $T(alpha), $matdescra(A), $(untaga(unwrapa(:A))), B, $T(beta), C)
return mm!($transa, $T(alpha), $(untaga(unwrapa(:A))), $matrixdescra(A), B, $T(beta), C)
for T in (:Float32, :Float64, :ComplexF32, :ComplexF64)
INT_TYPES = Base.USE_BLAS64 ? (:Int32, :Int64) : (:Int32,)
for INT in INT_TYPES

tag_wrappers = ((identity , identity ),
(M -> :(Symmetric{$T, $M}), A -> :(parent($A))),
(M -> :(Hermitian{$T, $M}), A -> :(parent($A))))

triangle_wrappers = ((M -> :(LowerTriangular{$T, $M}) , A -> :(parent($A))),
(M -> :(UnitLowerTriangular{$T, $M}), A -> :(parent($A))),
(M -> :(UpperTriangular{$T, $M}) , A -> :(parent($A))),
(M -> :(UnitUpperTriangular{$T, $M}), A -> :(parent($A))))

op_wrappers = ((identity , 'N', identity ),
(M -> :(Transpose{$T, $M}), 'T', A -> :(parent($A))),
(M -> :(Adjoint{$T, $M}) , 'C', A -> :(parent($A))))

for SparseMatrixType in (:(SparseMatrixCSC{$T, $INT}), :(MKLSparse.SparseMatrixCOO{$T, $INT}), :(MKLSparse.SparseMatrixCSR{$T, $INT}))
for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers
TypeA = wrapa(taga(SparseMatrixType))

@eval begin
function LinearAlgebra.mul!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}, alpha::Number, beta::Number)
# return cscmv!($transa, $T(alpha), $matdescra(A), $(untaga(unwrapa(:A))), x, $T(beta), y)
return mv!($transa, $T(alpha), $(untaga(unwrapa(:A))), $matrixdescra(A), x, $T(beta), y)
end

function LinearAlgebra.mul!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}, alpha::Number, beta::Number)
# return cscmm!($transa, $T(alpha), $matdescra(A), $(untaga(unwrapa(:A))), B, $T(beta), C)
return mm!($transa, $T(alpha), $(untaga(unwrapa(:A))), $matrixdescra(A), B, $T(beta), C)
end
end
end
end

for (trianglea, untrianglea) in triangle_wrappers, (wrapa, transa, unwrapa) in op_wrappers
TypeA = wrapa(trianglea(SparseMatrixType))

@eval begin
function LinearAlgebra.mul!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}, alpha::Number, beta::Number)
# return cscmv!($transa, $T(alpha), $matdescra(A), $(untrianglea(unwrapa(:A))), x, $T(beta), y)
return mv!($transa, $T(alpha), $(untrianglea(unwrapa(:A))), $matrixdescra(A), x, $T(beta), y)
end

function LinearAlgebra.mul!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}, alpha::Number, beta::Number)
# return cscmm!($transa, $T(alpha), $matdescra(A), $(untrianglea(unwrapa(:A))), B, $T(beta), C)
return mm!($transa, $T(alpha), $(untrianglea(unwrapa(:A))), $matrixdescra(A), B, $T(beta), C)
end

function LinearAlgebra.ldiv!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T})
# return cscsv!($transa, one($T), $matdescra(A), $(untrianglea(unwrapa(:A))), x, y)
return trsv!($transa, one($T), $(untrianglea(unwrapa(:A))), $matrixdescra(A), x, y)
end

function LinearAlgebra.ldiv!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T})
# return cscsm!($transa, one($T), $matdescra(A), $(untrianglea(unwrapa(:A))), B, C)
return trsm!($transa, one($T), $(untrianglea(unwrapa(:A))), $matrixdescra(A), B, C)
end

function (*)(A::$TypeA, x::StridedVector{$T})
m, n = size(A)
y = Vector{$T}(undef, m)
return mul!(y, A, x, one($T), zero($T))
end

function (*)(A::$TypeA, B::StridedMatrix{$T})
m, k = size(A)
p, n = size(B)
C = Matrix{$T}(undef, m, n)
return mul!(C, A, B, one($T), zero($T))
end

function (\)(A::$TypeA, x::StridedVector{$T})
n = length(x)
y = Vector{$T}(undef, n)
return ldiv!(y, A, x)
end

function (\)(A::$TypeA, B::StridedMatrix{$T})
m, n = size(B)
C = Matrix{$T}(undef, m, n)
return ldiv!(C, A, B)
for (trianglea, untrianglea) in triangle_wrappers, (wrapa, transa, unwrapa) in op_wrappers
TypeA = wrapa(trianglea(SparseMatrixType))

@eval begin
function LinearAlgebra.mul!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}, alpha::Number, beta::Number)
# return cscmv!($transa, $T(alpha), $matdescra(A), $(untrianglea(unwrapa(:A))), x, $T(beta), y)
return mv!($transa, $T(alpha), $(untrianglea(unwrapa(:A))), $matrixdescra(A), x, $T(beta), y)
end

function LinearAlgebra.mul!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}, alpha::Number, beta::Number)
# return cscmm!($transa, $T(alpha), $matdescra(A), $(untrianglea(unwrapa(:A))), B, $T(beta), C)
return mm!($transa, $T(alpha), $(untrianglea(unwrapa(:A))), $matrixdescra(A), B, $T(beta), C)
end

function LinearAlgebra.ldiv!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T})
# return cscsv!($transa, one($T), $matdescra(A), $(untrianglea(unwrapa(:A))), x, y)
return trsv!($transa, one($T), $(untrianglea(unwrapa(:A))), $matrixdescra(A), x, y)
end

function LinearAlgebra.ldiv!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T})
# return cscsm!($transa, one($T), $matdescra(A), $(untrianglea(unwrapa(:A))), B, C)
return trsm!($transa, one($T), $(untrianglea(unwrapa(:A))), $matrixdescra(A), B, C)
end

function (*)(A::$TypeA, x::StridedVector{$T})
m, n = size(A)
y = Vector{$T}(undef, m)
return mul!(y, A, x, one($T), zero($T))
end

function (*)(A::$TypeA, B::StridedMatrix{$T})
m, k = size(A)
p, n = size(B)
C = Matrix{$T}(undef, m, n)
return mul!(C, A, B, one($T), zero($T))
end

function (\)(A::$TypeA, x::StridedVector{$T})
n = length(x)
y = Vector{$T}(undef, n)
return ldiv!(y, A, x)
end

function (\)(A::$TypeA, B::StridedMatrix{$T})
m, n = size(B)
C = Matrix{$T}(undef, m, n)
return ldiv!(C, A, B)
end
end
end
end
58 changes: 31 additions & 27 deletions src/mklsparsematrix.jl
Original file line number Diff line number Diff line change
@@ -34,37 +34,41 @@ end
Base.unsafe_convert(::Type{sparse_matrix_t}, desc::MKLSparseMatrix) = desc.handle

for T in (:Float32, :Float64, :ComplexF32, :ComplexF64)
INT_TYPES = Base.USE_BLAS64 ? (:Int32, :Int64) : (:Int32,)
for INT in INT_TYPES

create_coo = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_coo")
create_csc = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_csc")
create_csr = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_csr")
create_coo = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_coo", mkl_integer_specifier(INT))
create_csc = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_csc", mkl_integer_specifier(INT))
create_csr = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_csr", mkl_integer_specifier(INT))
sparse_destroy = (INT == :Int32) ? :mkl_sparse_destroy : :mkl_sparse_destroy_64

@eval begin
# SparseMatrixCOO
function MKLSparseMatrix(A::MKLSparse.SparseMatrixCOO{$T, BlasInt}, IndexBase::Char='O')
descr_ref = Ref{sparse_matrix_t}()
$create_coo(descr_ref, IndexBase, A.m, A.n, nnz(A), A.rows, A.cols, A.vals)
obj = MKLSparseMatrix(descr_ref[])
finalizer(mkl_sparse_destroy, obj)
return obj
end
@eval begin
# SparseMatrixCOO
function MKLSparseMatrix(A::MKLSparse.SparseMatrixCOO{$T, $INT}, IndexBase::Char='O')
descr_ref = Ref{sparse_matrix_t}()
$create_coo(descr_ref, IndexBase, A.m, A.n, nnz(A), A.rows, A.cols, A.vals)
obj = MKLSparseMatrix(descr_ref[])
finalizer($sparse_destroy, obj)
return obj
end

# SparseMatrixCSR
function MKLSparseMatrix(A::MKLSparse.SparseMatrixCSR{$T, BlasInt}, IndexBase::Char='O')
descr_ref = Ref{sparse_matrix_t}()
$create_csr(descr_ref, IndexBase, A.m, A.n, A.rowptr, pointer(A.rowptr, 2), A.colval, A.nzval)
obj = MKLSparseMatrix(descr_ref[])
finalizer(mkl_sparse_destroy, obj)
return obj
end
# SparseMatrixCSR
function MKLSparseMatrix(A::MKLSparse.SparseMatrixCSR{$T, $INT}, IndexBase::Char='O')
descr_ref = Ref{sparse_matrix_t}()
$create_csr(descr_ref, IndexBase, A.m, A.n, A.rowptr, pointer(A.rowptr, 2), A.colval, A.nzval)
obj = MKLSparseMatrix(descr_ref[])
finalizer($sparse_destroy, obj)
return obj
end

# SparseMatrixCSC
function MKLSparseMatrix(A::SparseMatrixCSC{$T, BlasInt}, IndexBase::Char='O')
descr_ref = Ref{sparse_matrix_t}()
$create_csc(descr_ref, IndexBase, A.m, A.n, A.colptr, pointer(A.colptr, 2), A.rowval, A.nzval)
obj = MKLSparseMatrix(descr_ref[])
finalizer(mkl_sparse_destroy, obj)
return obj
# SparseMatrixCSC
function MKLSparseMatrix(A::SparseMatrixCSC{$T, $INT}, IndexBase::Char='O')
descr_ref = Ref{sparse_matrix_t}()
$create_csc(descr_ref, IndexBase, A.m, A.n, A.colptr, pointer(A.colptr, 2), A.rowval, A.nzval)
obj = MKLSparseMatrix(descr_ref[])
finalizer($sparse_destroy, obj)
return obj
end
end
end
end
Loading