Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use the new API #33

Merged
merged 3 commits into from
Jul 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions src/MKLSparse.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,31 +21,32 @@ end
INTERFACE_GNU
end

function set_threading_layer(layer::Threading = THREADING_SEQUENTIAL)
function set_threading_layer(layer::Threading = THREADING_INTEL)
err = @ccall libmkl_rt.MKL_Set_Threading_Layer(layer::Cint)::Cint
(err == -1) && error("MKL_Set_Threading_Layer() returned -1")
return nothing
end

function set_interface_layer(interface::Interface = INTERFACE_ILP64)
function set_interface_layer(interface::Interface = INTERFACE_LP64)
err = @ccall libmkl_rt.MKL_Set_Interface_Layer(interface::Cint)::Cint
(err == -1) && error("MKL_Set_Interface_Layer() returned -1")
return nothing
end

function __init__()
set_interface_layer(Base.USE_BLAS64 ? INTERFACE_ILP64 : INTERFACE_LP64)
set_interface_layer(INTERFACE_LP64)
end

# Wrappers generated by Clang.jl
include("libmklsparse.jl")
include("types.jl")
include("mklsparsematrix.jl")

# TODO: BLAS1

# BLAS2 and BLAS3
include("matdescra.jl")
include("generator.jl")
include("matmul.jl")
include("deprecated.jl")
include("generic.jl")
include("interface.jl")

end # module
20 changes: 15 additions & 5 deletions src/generator.jl → src/deprecated.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,14 @@
matdescra(A::LowerTriangular) = "TLNF"
matdescra(A::UpperTriangular) = "TUNF"
matdescra(A::Diagonal) = "DUNF"
matdescra(A::UnitLowerTriangular) = "TLUF"
matdescra(A::UnitUpperTriangular) = "TUUF"
matdescra(A::Symmetric) = string('S', A.uplo, 'N', 'F')
matdescra(A::Hermitian) = string('H', A.uplo, 'N', 'F')
matdescra(A::SparseMatrixCSC) = "GUUF"
matdescra(A::Transpose) = matdescra(A.parent)
matdescra(A::Adjoint) = matdescra(A.parent)

# The increments to the `__counter` variable is for testing purposes

function _check_transa(t::Char)
Expand Down Expand Up @@ -28,12 +39,11 @@ function _check_mat_mult_matvec(C, A, B, tA)
end

function cscmv!(transa::Char, α::T, matdescra::String,
A::SparseMatrixCSC{T, BlasInt}, x::StridedVector{T},
A::SparseMatrixCSC{T, Int32}, x::StridedVector{T},
β::T, y::StridedVector{T}) where {T <: BlasFloat}
_check_transa(transa)
_check_mat_mult_matvec(y, A, x, transa)
__counter[] += 1

T == Float32 && (mkl_scscmv(transa, A.m, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, β, y))
T == Float64 && (mkl_dcscmv(transa, A.m, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, β, y))
T == ComplexF32 && (mkl_ccscmv(transa, A.m, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, β, y))
Expand All @@ -42,7 +52,7 @@ function cscmv!(transa::Char, α::T, matdescra::String,
end

function cscmm!(transa::Char, α::T, matdescra::String,
A::SparseMatrixCSC{T, BlasInt}, B::StridedMatrix{T},
A::SparseMatrixCSC{T, Int32}, B::StridedMatrix{T},
β::T, C::StridedMatrix{T}) where {T <: BlasFloat}
_check_transa(transa)
_check_mat_mult_matvec(C, A, B, transa)
Expand All @@ -57,7 +67,7 @@ function cscmm!(transa::Char, α::T, matdescra::String,
end

function cscsv!(transa::Char, α::T, matdescra::String,
A::SparseMatrixCSC{T, BlasInt}, x::StridedVector{T},
A::SparseMatrixCSC{T, Int32}, x::StridedVector{T},
y::StridedVector{T}) where {T <: BlasFloat}
n = checksquare(A)
_check_transa(transa)
Expand All @@ -71,7 +81,7 @@ function cscsv!(transa::Char, α::T, matdescra::String,
end

function cscsm!(transa::Char, α::T, matdescra::String,
A::SparseMatrixCSC{T, BlasInt}, B::StridedMatrix{T},
A::SparseMatrixCSC{T, Int32}, B::StridedMatrix{T},
C::StridedMatrix{T}) where {T <: BlasFloat}
mB, nB = size(B)
mC, nC = size(C)
Expand Down
54 changes: 54 additions & 0 deletions src/generic.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
for T in (:Float32, :Float64, :ComplexF32, :ComplexF64)
INT_TYPES = Base.USE_BLAS64 ? (:Int32, :Int64) : (:Int32,)
for INT in INT_TYPES
for SparseMatrix in (:(SparseMatrixCSC{$T,$INT}), :(MKLSparse.SparseMatrixCSR{$T,$INT}), :(MKLSparse.SparseMatrixCOO{$T,$INT}))

fname_mv = Symbol("mkl_sparse_", mkl_type_specifier(T), "_mv" , mkl_integer_specifier(INT))
fname_mm = Symbol("mkl_sparse_", mkl_type_specifier(T), "_mm" , mkl_integer_specifier(INT))
fname_trsv = Symbol("mkl_sparse_", mkl_type_specifier(T), "_trsv", mkl_integer_specifier(INT))
fname_trsm = Symbol("mkl_sparse_", mkl_type_specifier(T), "_trsm", mkl_integer_specifier(INT))

@eval begin
function mv!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedVector{$T}, beta::$T, y::StridedVector{$T})
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
$fname_mv(operation, alpha, MKLSparseMatrix(A), descr, x, beta, y)
return y
end

function mm!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedMatrix{$T}, beta::$T, y::StridedMatrix{$T})
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
columns = size(y, 2)
ldx = stride(x, 2)
ldy = stride(y, 2)
$fname_mm(operation, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, beta, y, ldy)
return y
end

function trsv!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedVector{$T}, y::StridedVector{$T})
checksquare(A)
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
$fname_trsv(operation, alpha, MKLSparseMatrix(A), descr, x, y)
return y
end

function trsm!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedMatrix{$T}, y::StridedMatrix{$T})
checksquare(A)
_check_transa(operation)
_check_mat_mult_matvec(y, A, x, operation)
__counter[] += 1
columns = size(y, 2)
ldx = stride(x, 2)
ldy = stride(y, 2)
$fname_trsm(operation, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, y, ldy)
return y
end
end
end
end
end
90 changes: 90 additions & 0 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
import Base: \, *
import LinearAlgebra: mul!, ldiv!

for T in (:Float32, :Float64, :ComplexF32, :ComplexF64)
INT_TYPES = Base.USE_BLAS64 ? (:Int32, :Int64) : (:Int32,)
for INT in INT_TYPES

tag_wrappers = ((identity , identity ),
(M -> :(Symmetric{$T, $M}), A -> :(parent($A))),
(M -> :(Hermitian{$T, $M}), A -> :(parent($A))))

triangle_wrappers = ((M -> :(LowerTriangular{$T, $M}) , A -> :(parent($A))),
(M -> :(UnitLowerTriangular{$T, $M}), A -> :(parent($A))),
(M -> :(UpperTriangular{$T, $M}) , A -> :(parent($A))),
(M -> :(UnitUpperTriangular{$T, $M}), A -> :(parent($A))))

op_wrappers = ((identity , 'N', identity ),
(M -> :(Transpose{$T, $M}), 'T', A -> :(parent($A))),
(M -> :(Adjoint{$T, $M}) , 'C', A -> :(parent($A))))

for SparseMatrixType in (:(SparseMatrixCSC{$T, $INT}), :(MKLSparse.SparseMatrixCOO{$T, $INT}), :(MKLSparse.SparseMatrixCSR{$T, $INT}))
for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers
TypeA = wrapa(taga(SparseMatrixType))

@eval begin
function LinearAlgebra.mul!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}, alpha::Number, beta::Number)
# return cscmv!($transa, $T(alpha), $matdescra(A), $(untaga(unwrapa(:A))), x, $T(beta), y)
return mv!($transa, $T(alpha), $(untaga(unwrapa(:A))), $matrixdescra(A), x, $T(beta), y)
end

function LinearAlgebra.mul!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}, alpha::Number, beta::Number)
# return cscmm!($transa, $T(alpha), $matdescra(A), $(untaga(unwrapa(:A))), B, $T(beta), C)
return mm!($transa, $T(alpha), $(untaga(unwrapa(:A))), $matrixdescra(A), B, $T(beta), C)
end
end
end

for (trianglea, untrianglea) in triangle_wrappers, (wrapa, transa, unwrapa) in op_wrappers
TypeA = wrapa(trianglea(SparseMatrixType))

@eval begin
function LinearAlgebra.mul!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}, alpha::Number, beta::Number)
# return cscmv!($transa, $T(alpha), $matdescra(A), $(untrianglea(unwrapa(:A))), x, $T(beta), y)
return mv!($transa, $T(alpha), $(untrianglea(unwrapa(:A))), $matrixdescra(A), x, $T(beta), y)
end

function LinearAlgebra.mul!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}, alpha::Number, beta::Number)
# return cscmm!($transa, $T(alpha), $matdescra(A), $(untrianglea(unwrapa(:A))), B, $T(beta), C)
return mm!($transa, $T(alpha), $(untrianglea(unwrapa(:A))), $matrixdescra(A), B, $T(beta), C)
end

function LinearAlgebra.ldiv!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T})
# return cscsv!($transa, one($T), $matdescra(A), $(untrianglea(unwrapa(:A))), x, y)
return trsv!($transa, one($T), $(untrianglea(unwrapa(:A))), $matrixdescra(A), x, y)
end

function LinearAlgebra.ldiv!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T})
# return cscsm!($transa, one($T), $matdescra(A), $(untrianglea(unwrapa(:A))), B, C)
return trsm!($transa, one($T), $(untrianglea(unwrapa(:A))), $matrixdescra(A), B, C)
end

function (*)(A::$TypeA, x::StridedVector{$T})
m, n = size(A)
y = Vector{$T}(undef, m)
return mul!(y, A, x, one($T), zero($T))
end

function (*)(A::$TypeA, B::StridedMatrix{$T})
m, k = size(A)
p, n = size(B)
C = Matrix{$T}(undef, m, n)
return mul!(C, A, B, one($T), zero($T))
end

function (\)(A::$TypeA, x::StridedVector{$T})
n = length(x)
y = Vector{$T}(undef, n)
return ldiv!(y, A, x)
end

function (\)(A::$TypeA, B::StridedMatrix{$T})
m, n = size(B)
C = Matrix{$T}(undef, m, n)
return ldiv!(C, A, B)
end
end
end
end
end
end
8 changes: 0 additions & 8 deletions src/matdescra.jl

This file was deleted.

101 changes: 0 additions & 101 deletions src/matmul.jl

This file was deleted.

Loading