Skip to content

Commit

Permalink
Use the new API to support both LP64 and ILP64 API
Browse files Browse the repository at this point in the history
  • Loading branch information
amontoison committed Jul 7, 2024
1 parent 5e64b9d commit 5e5b95b
Show file tree
Hide file tree
Showing 7 changed files with 228 additions and 202 deletions.
6 changes: 3 additions & 3 deletions src/MKLSparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,20 +21,20 @@ end
INTERFACE_GNU
end

function set_threading_layer(layer::Threading = THREADING_SEQUENTIAL)
function set_threading_layer(layer::Threading = THREADING_INTEL)
err = @ccall libmkl_rt.MKL_Set_Threading_Layer(layer::Cint)::Cint
(err == -1) && error("MKL_Set_Threading_Layer() returned -1")
return nothing
end

function set_interface_layer(interface::Interface = INTERFACE_ILP64)
function set_interface_layer(interface::Interface = INTERFACE_LP64)
err = @ccall libmkl_rt.MKL_Set_Interface_Layer(interface::Cint)::Cint
(err == -1) && error("MKL_Set_Interface_Layer() returned -1")
return nothing
end

function __init__()
set_interface_layer(Base.USE_BLAS64 ? INTERFACE_ILP64 : INTERFACE_LP64)
set_interface_layer(INTERFACE_LP64)
end

# Wrappers generated by Clang.jl
Expand Down
8 changes: 4 additions & 4 deletions src/deprecated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ function _check_mat_mult_matvec(C, A, B, tA)
end

function cscmv!(transa::Char, α::T, matdescra::String,
A::SparseMatrixCSC{T, BlasInt}, x::StridedVector{T},
A::SparseMatrixCSC{T, Int32}, x::StridedVector{T},
β::T, y::StridedVector{T}) where {T <: BlasFloat}
_check_transa(transa)
_check_mat_mult_matvec(y, A, x, transa)
Expand All @@ -52,7 +52,7 @@ function cscmv!(transa::Char, α::T, matdescra::String,
end

function cscmm!(transa::Char, α::T, matdescra::String,
A::SparseMatrixCSC{T, BlasInt}, B::StridedMatrix{T},
A::SparseMatrixCSC{T, Int32}, B::StridedMatrix{T},
β::T, C::StridedMatrix{T}) where {T <: BlasFloat}
_check_transa(transa)
_check_mat_mult_matvec(C, A, B, transa)
Expand All @@ -67,7 +67,7 @@ function cscmm!(transa::Char, α::T, matdescra::String,
end

function cscsv!(transa::Char, α::T, matdescra::String,
A::SparseMatrixCSC{T, BlasInt}, x::StridedVector{T},
A::SparseMatrixCSC{T, Int32}, x::StridedVector{T},
y::StridedVector{T}) where {T <: BlasFloat}
n = checksquare(A)
_check_transa(transa)
Expand All @@ -81,7 +81,7 @@ function cscsv!(transa::Char, α::T, matdescra::String,
end

function cscsm!(transa::Char, α::T, matdescra::String,
A::SparseMatrixCSC{T, BlasInt}, B::StridedMatrix{T},
A::SparseMatrixCSC{T, Int32}, B::StridedMatrix{T},
C::StridedMatrix{T}) where {T <: BlasFloat}
mB, nB = size(B)
mC, nC = size(C)
Expand Down
85 changes: 44 additions & 41 deletions src/generic.jl
Original file line number Diff line number Diff line change
@@ -1,50 +1,53 @@
for T in (:Float32, :Float64, :ComplexF32, :ComplexF64)
for SparseMatrix in (:(SparseMatrixCSC{$T,BlasInt}), :(MKLSparse.SparseMatrixCSR{$T,BlasInt}), :(MKLSparse.SparseMatrixCOO{$T,BlasInt}))
INT_TYPES = Base.USE_BLAS64 ? (:Int32, :Int64) : (:Int32,)
for INT in INT_TYPES
for SparseMatrix in (:(SparseMatrixCSC{$T,$INT}), :(MKLSparse.SparseMatrixCSR{$T,$INT}), :(MKLSparse.SparseMatrixCOO{$T,$INT}))

fname_mv = Symbol("mkl_sparse_", mkl_type_specifier(T), "_mv")
fname_mm = Symbol("mkl_sparse_", mkl_type_specifier(T), "_mm")
fname_trsv = Symbol("mkl_sparse_", mkl_type_specifier(T), "_trsv")
fname_trsm = Symbol("mkl_sparse_", mkl_type_specifier(T), "_trsm")
fname_mv = Symbol("mkl_sparse_", mkl_type_specifier(T), "_mv" , mkl_integer_specifier(INT))
fname_mm = Symbol("mkl_sparse_", mkl_type_specifier(T), "_mm" , mkl_integer_specifier(INT))
fname_trsv = Symbol("mkl_sparse_", mkl_type_specifier(T), "_trsv", mkl_integer_specifier(INT))
fname_trsm = Symbol("mkl_sparse_", mkl_type_specifier(T), "_trsm", mkl_integer_specifier(INT))

@eval begin
function mv!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedVector{$T}, beta::$T, y::StridedVector{$T})
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
$fname_mv(operation, alpha, MKLSparseMatrix(A), descr, x, beta, y)
return y
end
@eval begin
function mv!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedVector{$T}, beta::$T, y::StridedVector{$T})
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
$fname_mv(operation, alpha, MKLSparseMatrix(A), descr, x, beta, y)
return y
end

function mm!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedMatrix{$T}, beta::$T, y::StridedMatrix{$T})
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
columns = size(y, 2)
ldx = stride(x, 2)
ldy = stride(y, 2)
$fname_mm(operation, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, beta, y, ldy)
return y
end
function mm!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedMatrix{$T}, beta::$T, y::StridedMatrix{$T})
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
columns = size(y, 2)
ldx = stride(x, 2)
ldy = stride(y, 2)
$fname_mm(operation, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, beta, y, ldy)
return y
end

function trsv!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedVector{$T}, y::StridedVector{$T})
checksquare(A)
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
$fname_trsv(operation, alpha, MKLSparseMatrix(A), descr, x, y)
return y
end
function trsv!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedVector{$T}, y::StridedVector{$T})
checksquare(A)
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
$fname_trsv(operation, alpha, MKLSparseMatrix(A), descr, x, y)
return y
end

function trsm!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedMatrix{$T}, y::StridedMatrix{$T})
checksquare(A)
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
columns = size(y, 2)
ldx = stride(x, 2)
ldy = stride(y, 2)
$fname_trsm(operation, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, y, ldy)
return y
function trsm!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedMatrix{$T}, y::StridedMatrix{$T})
checksquare(A)
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
columns = size(y, 2)
ldx = stride(x, 2)
ldy = stride(y, 2)
$fname_trsm(operation, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, y, ldy)
return y
end
end
end
end
Expand Down
155 changes: 79 additions & 76 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -1,85 +1,88 @@
import Base: \, *
import LinearAlgebra: mul!, ldiv!

for T in (Float32, Float64, ComplexF32, ComplexF64)

tag_wrappers = ((identity , identity ),
(M -> :(Symmetric{$T, $M}), A -> :(parent($A))),
(M -> :(Hermitian{$T, $M}), A -> :(parent($A))))

triangle_wrappers = ((M -> :(LowerTriangular{$T, $M}) , A -> :(parent($A))),
(M -> :(UnitLowerTriangular{$T, $M}), A -> :(parent($A))),
(M -> :(UpperTriangular{$T, $M}) , A -> :(parent($A))),
(M -> :(UnitUpperTriangular{$T, $M}), A -> :(parent($A))))

op_wrappers = ((identity , 'N', identity ),
(M -> :(Transpose{$T, $M}), 'T', A -> :(parent($A))),
(M -> :(Adjoint{$T, $M}) , 'C', A -> :(parent($A))))

for SparseMatrixType in (:(SparseMatrixCSC{$T, $BlasInt}), :(MKLSparse.SparseMatrixCOO{$T, $BlasInt}), :(MKLSparse.SparseMatrixCSR{$T, $BlasInt}))
for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers
TypeA = wrapa(taga(SparseMatrixType))

@eval begin
function LinearAlgebra.mul!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}, alpha::Number, beta::Number)
# return cscmv!($transa, $T(alpha), $matdescra(A), $(untaga(unwrapa(:A))), x, $T(beta), y)
return mv!($transa, $T(alpha), $(untaga(unwrapa(:A))), $matrixdescra(A), x, $T(beta), y)
end

function LinearAlgebra.mul!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}, alpha::Number, beta::Number)
# return cscmm!($transa, $T(alpha), $matdescra(A), $(untaga(unwrapa(:A))), B, $T(beta), C)
return mm!($transa, $T(alpha), $(untaga(unwrapa(:A))), $matrixdescra(A), B, $T(beta), C)
for T in (:Float32, :Float64, :ComplexF32, :ComplexF64)
INT_TYPES = Base.USE_BLAS64 ? (:Int32, :Int64) : (:Int32,)
for INT in INT_TYPES

tag_wrappers = ((identity , identity ),
(M -> :(Symmetric{$T, $M}), A -> :(parent($A))),
(M -> :(Hermitian{$T, $M}), A -> :(parent($A))))

triangle_wrappers = ((M -> :(LowerTriangular{$T, $M}) , A -> :(parent($A))),
(M -> :(UnitLowerTriangular{$T, $M}), A -> :(parent($A))),
(M -> :(UpperTriangular{$T, $M}) , A -> :(parent($A))),
(M -> :(UnitUpperTriangular{$T, $M}), A -> :(parent($A))))

op_wrappers = ((identity , 'N', identity ),
(M -> :(Transpose{$T, $M}), 'T', A -> :(parent($A))),
(M -> :(Adjoint{$T, $M}) , 'C', A -> :(parent($A))))

for SparseMatrixType in (:(SparseMatrixCSC{$T, $INT}), :(MKLSparse.SparseMatrixCOO{$T, $INT}), :(MKLSparse.SparseMatrixCSR{$T, $INT}))
for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers
TypeA = wrapa(taga(SparseMatrixType))

@eval begin
function LinearAlgebra.mul!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}, alpha::Number, beta::Number)
# return cscmv!($transa, $T(alpha), $matdescra(A), $(untaga(unwrapa(:A))), x, $T(beta), y)
return mv!($transa, $T(alpha), $(untaga(unwrapa(:A))), $matrixdescra(A), x, $T(beta), y)
end

function LinearAlgebra.mul!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}, alpha::Number, beta::Number)
# return cscmm!($transa, $T(alpha), $matdescra(A), $(untaga(unwrapa(:A))), B, $T(beta), C)
return mm!($transa, $T(alpha), $(untaga(unwrapa(:A))), $matrixdescra(A), B, $T(beta), C)
end
end
end
end

for (trianglea, untrianglea) in triangle_wrappers, (wrapa, transa, unwrapa) in op_wrappers
TypeA = wrapa(trianglea(SparseMatrixType))

@eval begin
function LinearAlgebra.mul!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}, alpha::Number, beta::Number)
# return cscmv!($transa, $T(alpha), $matdescra(A), $(untrianglea(unwrapa(:A))), x, $T(beta), y)
return mv!($transa, $T(alpha), $(untrianglea(unwrapa(:A))), $matrixdescra(A), x, $T(beta), y)
end

function LinearAlgebra.mul!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}, alpha::Number, beta::Number)
# return cscmm!($transa, $T(alpha), $matdescra(A), $(untrianglea(unwrapa(:A))), B, $T(beta), C)
return mm!($transa, $T(alpha), $(untrianglea(unwrapa(:A))), $matrixdescra(A), B, $T(beta), C)
end

function LinearAlgebra.ldiv!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T})
# return cscsv!($transa, one($T), $matdescra(A), $(untrianglea(unwrapa(:A))), x, y)
return trsv!($transa, one($T), $(untrianglea(unwrapa(:A))), $matrixdescra(A), x, y)
end

function LinearAlgebra.ldiv!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T})
# return cscsm!($transa, one($T), $matdescra(A), $(untrianglea(unwrapa(:A))), B, C)
return trsm!($transa, one($T), $(untrianglea(unwrapa(:A))), $matrixdescra(A), B, C)
end

function (*)(A::$TypeA, x::StridedVector{$T})
m, n = size(A)
y = Vector{$T}(undef, m)
return mul!(y, A, x, one($T), zero($T))
end

function (*)(A::$TypeA, B::StridedMatrix{$T})
m, k = size(A)
p, n = size(B)
C = Matrix{$T}(undef, m, n)
return mul!(C, A, B, one($T), zero($T))
end

function (\)(A::$TypeA, x::StridedVector{$T})
n = length(x)
y = Vector{$T}(undef, n)
return ldiv!(y, A, x)
end

function (\)(A::$TypeA, B::StridedMatrix{$T})
m, n = size(B)
C = Matrix{$T}(undef, m, n)
return ldiv!(C, A, B)
for (trianglea, untrianglea) in triangle_wrappers, (wrapa, transa, unwrapa) in op_wrappers
TypeA = wrapa(trianglea(SparseMatrixType))

@eval begin
function LinearAlgebra.mul!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}, alpha::Number, beta::Number)
# return cscmv!($transa, $T(alpha), $matdescra(A), $(untrianglea(unwrapa(:A))), x, $T(beta), y)
return mv!($transa, $T(alpha), $(untrianglea(unwrapa(:A))), $matrixdescra(A), x, $T(beta), y)
end

function LinearAlgebra.mul!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}, alpha::Number, beta::Number)
# return cscmm!($transa, $T(alpha), $matdescra(A), $(untrianglea(unwrapa(:A))), B, $T(beta), C)
return mm!($transa, $T(alpha), $(untrianglea(unwrapa(:A))), $matrixdescra(A), B, $T(beta), C)
end

function LinearAlgebra.ldiv!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T})
# return cscsv!($transa, one($T), $matdescra(A), $(untrianglea(unwrapa(:A))), x, y)
return trsv!($transa, one($T), $(untrianglea(unwrapa(:A))), $matrixdescra(A), x, y)
end

function LinearAlgebra.ldiv!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T})
# return cscsm!($transa, one($T), $matdescra(A), $(untrianglea(unwrapa(:A))), B, C)
return trsm!($transa, one($T), $(untrianglea(unwrapa(:A))), $matrixdescra(A), B, C)
end

function (*)(A::$TypeA, x::StridedVector{$T})
m, n = size(A)
y = Vector{$T}(undef, m)
return mul!(y, A, x, one($T), zero($T))
end

function (*)(A::$TypeA, B::StridedMatrix{$T})
m, k = size(A)
p, n = size(B)
C = Matrix{$T}(undef, m, n)
return mul!(C, A, B, one($T), zero($T))
end

function (\)(A::$TypeA, x::StridedVector{$T})
n = length(x)
y = Vector{$T}(undef, n)
return ldiv!(y, A, x)
end

function (\)(A::$TypeA, B::StridedMatrix{$T})
m, n = size(B)
C = Matrix{$T}(undef, m, n)
return ldiv!(C, A, B)
end
end
end
end
Expand Down
58 changes: 31 additions & 27 deletions src/mklsparsematrix.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,37 +34,41 @@ end
Base.unsafe_convert(::Type{sparse_matrix_t}, desc::MKLSparseMatrix) = desc.handle

for T in (:Float32, :Float64, :ComplexF32, :ComplexF64)
INT_TYPES = Base.USE_BLAS64 ? (:Int32, :Int64) : (:Int32,)
for INT in INT_TYPES

create_coo = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_coo")
create_csc = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_csc")
create_csr = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_csr")
create_coo = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_coo", mkl_integer_specifier(INT))
create_csc = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_csc", mkl_integer_specifier(INT))
create_csr = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_csr", mkl_integer_specifier(INT))
sparse_destroy = (INT == :Int32) ? :mkl_sparse_destroy : :mkl_sparse_destroy_64

@eval begin
# SparseMatrixCOO
function MKLSparseMatrix(A::MKLSparse.SparseMatrixCOO{$T, BlasInt}, IndexBase::Char='O')
descr_ref = Ref{sparse_matrix_t}()
$create_coo(descr_ref, IndexBase, A.m, A.n, nnz(A), A.rows, A.cols, A.vals)
obj = MKLSparseMatrix(descr_ref[])
finalizer(mkl_sparse_destroy, obj)
return obj
end
@eval begin
# SparseMatrixCOO
function MKLSparseMatrix(A::MKLSparse.SparseMatrixCOO{$T, $INT}, IndexBase::Char='O')
descr_ref = Ref{sparse_matrix_t}()
$create_coo(descr_ref, IndexBase, A.m, A.n, nnz(A), A.rows, A.cols, A.vals)
obj = MKLSparseMatrix(descr_ref[])
finalizer($sparse_destroy, obj)
return obj
end

# SparseMatrixCSR
function MKLSparseMatrix(A::MKLSparse.SparseMatrixCSR{$T, BlasInt}, IndexBase::Char='O')
descr_ref = Ref{sparse_matrix_t}()
$create_csr(descr_ref, IndexBase, A.m, A.n, A.rowptr, pointer(A.rowptr, 2), A.colval, A.nzval)
obj = MKLSparseMatrix(descr_ref[])
finalizer(mkl_sparse_destroy, obj)
return obj
end
# SparseMatrixCSR
function MKLSparseMatrix(A::MKLSparse.SparseMatrixCSR{$T, $INT}, IndexBase::Char='O')
descr_ref = Ref{sparse_matrix_t}()
$create_csr(descr_ref, IndexBase, A.m, A.n, A.rowptr, pointer(A.rowptr, 2), A.colval, A.nzval)
obj = MKLSparseMatrix(descr_ref[])
finalizer($sparse_destroy, obj)
return obj
end

# SparseMatrixCSC
function MKLSparseMatrix(A::SparseMatrixCSC{$T, BlasInt}, IndexBase::Char='O')
descr_ref = Ref{sparse_matrix_t}()
$create_csc(descr_ref, IndexBase, A.m, A.n, A.colptr, pointer(A.colptr, 2), A.rowval, A.nzval)
obj = MKLSparseMatrix(descr_ref[])
finalizer(mkl_sparse_destroy, obj)
return obj
# SparseMatrixCSC
function MKLSparseMatrix(A::SparseMatrixCSC{$T, $INT}, IndexBase::Char='O')
descr_ref = Ref{sparse_matrix_t}()
$create_csc(descr_ref, IndexBase, A.m, A.n, A.colptr, pointer(A.colptr, 2), A.rowval, A.nzval)
obj = MKLSparseMatrix(descr_ref[])
finalizer($sparse_destroy, obj)
return obj
end
end
end
end
Loading

0 comments on commit 5e5b95b

Please sign in to comment.