From 5139cb3aa64861026464e6488c2d243e39883774 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Mon, 31 Oct 2022 19:44:35 -0400 Subject: [PATCH 1/3] Use the new API --- src/MKLSparse.jl | 7 +- src/{generator.jl => deprecated.jl} | 12 +++- src/generic.jl | 51 ++++++++++++++ src/interface.jl | 87 ++++++++++++++++++++++++ src/matdescra.jl | 8 --- src/matmul.jl | 101 ---------------------------- src/mklsparsematrix.jl | 67 ++++++++++++++++++ src/types.jl | 25 +++++++ test/test_BLAS.jl | 98 ++++++++++++++++----------- 9 files changed, 304 insertions(+), 152 deletions(-) rename src/{generator.jl => deprecated.jl} (90%) create mode 100644 src/generic.jl create mode 100644 src/interface.jl delete mode 100644 src/matdescra.jl delete mode 100644 src/matmul.jl create mode 100644 src/mklsparsematrix.jl diff --git a/src/MKLSparse.jl b/src/MKLSparse.jl index 21c211f..6997b42 100644 --- a/src/MKLSparse.jl +++ b/src/MKLSparse.jl @@ -40,12 +40,13 @@ end # Wrappers generated by Clang.jl include("libmklsparse.jl") include("types.jl") +include("mklsparsematrix.jl") # TODO: BLAS1 # BLAS2 and BLAS3 -include("matdescra.jl") -include("generator.jl") -include("matmul.jl") +include("deprecated.jl") +include("generic.jl") +include("interface.jl") end # module diff --git a/src/generator.jl b/src/deprecated.jl similarity index 90% rename from src/generator.jl rename to src/deprecated.jl index 68ca14f..0ec0fe3 100644 --- a/src/generator.jl +++ b/src/deprecated.jl @@ -1,3 +1,14 @@ +matdescra(A::LowerTriangular) = "TLNF" +matdescra(A::UpperTriangular) = "TUNF" +matdescra(A::Diagonal) = "DUNF" +matdescra(A::UnitLowerTriangular) = "TLUF" +matdescra(A::UnitUpperTriangular) = "TUUF" +matdescra(A::Symmetric) = string('S', A.uplo, 'N', 'F') +matdescra(A::Hermitian) = string('H', A.uplo, 'N', 'F') +matdescra(A::SparseMatrixCSC) = "GUUF" +matdescra(A::Transpose) = matdescra(A.parent) +matdescra(A::Adjoint) = matdescra(A.parent) + # The increments to the `__counter` variable is for testing purposes function _check_transa(t::Char) @@ -33,7 +44,6 @@ function cscmv!(transa::Char, α::T, matdescra::String, _check_transa(transa) _check_mat_mult_matvec(y, A, x, transa) __counter[] += 1 - T == Float32 && (mkl_scscmv(transa, A.m, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, β, y)) T == Float64 && (mkl_dcscmv(transa, A.m, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, β, y)) T == ComplexF32 && (mkl_ccscmv(transa, A.m, A.n, α, matdescra, A.nzval, A.rowval, A.colptr, pointer(A.colptr, 2), x, β, y)) diff --git a/src/generic.jl b/src/generic.jl new file mode 100644 index 0000000..cb43d40 --- /dev/null +++ b/src/generic.jl @@ -0,0 +1,51 @@ +for T in (:Float32, :Float64, :ComplexF32, :ComplexF64) + for SparseMatrix in (:(SparseMatrixCSC{$T,BlasInt}), :(MKLSparse.SparseMatrixCSR{$T,BlasInt}), :(MKLSparse.SparseMatrixCOO{$T,BlasInt})) + + fname_mv = Symbol("mkl_sparse_", mkl_type_specifier(T), "_mv") + fname_mm = Symbol("mkl_sparse_", mkl_type_specifier(T), "_mm") + fname_trsv = Symbol("mkl_sparse_", mkl_type_specifier(T), "_trsv") + fname_trsm = Symbol("mkl_sparse_", mkl_type_specifier(T), "_trsm") + + @eval begin + function mv!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedVector{$T}, beta::$T, y::StridedVector{$T}) + _check_transa(operation) + _check_mat_mult_matvec(y, A, x, operation) + __counter[] += 1 + $fname_mv(operation, alpha, MKLSparseMatrix(A), descr, x, beta, y) + return y + end + + function mm!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedMatrix{$T}, beta::$T, y::StridedMatrix{$T}) + _check_transa(operation) + _check_mat_mult_matvec(y, A, x, operation) + __counter[] += 1 + columns = size(y, 2) + ldx = stride(x, 2) + ldy = stride(y, 2) + $fname_mm(operation, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, beta, y, ldy) + return y + end + + function trsv!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedVector{$T}, y::StridedVector{$T}) + checksquare(A) + _check_transa(operation) + _check_mat_mult_matvec(y, A, x, operation) + __counter[] += 1 + $fname_trsv(operation, alpha, MKLSparseMatrix(A), descr, x, y) + return y + end + + function trsm!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedMatrix{$T}, y::StridedMatrix{$T}) + checksquare(A) + _check_transa(operation) + _check_mat_mult_matvec(y, A, x, operation) + __counter[] += 1 + columns = size(y, 2) + ldx = stride(x, 2) + ldy = stride(y, 2) + $fname_trsm(operation, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, y, ldy) + return y + end + end + end +end diff --git a/src/interface.jl b/src/interface.jl new file mode 100644 index 0000000..aaa1e31 --- /dev/null +++ b/src/interface.jl @@ -0,0 +1,87 @@ +import Base: \, * +import LinearAlgebra: mul!, ldiv! + +for T in (Float32, Float64, ComplexF32, ComplexF64) + + tag_wrappers = ((identity , identity ), + (M -> :(Symmetric{$T, $M}), A -> :(parent($A))), + (M -> :(Hermitian{$T, $M}), A -> :(parent($A)))) + + triangle_wrappers = ((M -> :(LowerTriangular{$T, $M}) , A -> :(parent($A))), + (M -> :(UnitLowerTriangular{$T, $M}), A -> :(parent($A))), + (M -> :(UpperTriangular{$T, $M}) , A -> :(parent($A))), + (M -> :(UnitUpperTriangular{$T, $M}), A -> :(parent($A)))) + + op_wrappers = ((identity , 'N', identity ), + (M -> :(Transpose{$T, $M}), 'T', A -> :(parent($A))), + (M -> :(Adjoint{$T, $M}) , 'C', A -> :(parent($A)))) + + for SparseMatrixType in (:(SparseMatrixCSC{$T, $BlasInt}), :(MKLSparse.SparseMatrixCOO{$T, $BlasInt}), :(MKLSparse.SparseMatrixCSR{$T, $BlasInt})) + for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers + TypeA = wrapa(taga(SparseMatrixType)) + + @eval begin + function LinearAlgebra.mul!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}, alpha::Number, beta::Number) + # return cscmv!($transa, $T(alpha), $matdescra(A), $(untaga(unwrapa(:A))), x, $T(beta), y) + return mv!($transa, $T(alpha), $(untaga(unwrapa(:A))), $matrixdescra(A), x, $T(beta), y) + end + + function LinearAlgebra.mul!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}, alpha::Number, beta::Number) + # return cscmm!($transa, $T(alpha), $matdescra(A), $(untaga(unwrapa(:A))), B, $T(beta), C) + return mm!($transa, $T(alpha), $(untaga(unwrapa(:A))), $matrixdescra(A), B, $T(beta), C) + end + end + end + + for (trianglea, untrianglea) in triangle_wrappers, (wrapa, transa, unwrapa) in op_wrappers + TypeA = wrapa(trianglea(SparseMatrixType)) + + @eval begin + function LinearAlgebra.mul!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}, alpha::Number, beta::Number) + # return cscmv!($transa, $T(alpha), $matdescra(A), $(untrianglea(unwrapa(:A))), x, $T(beta), y) + return mv!($transa, $T(alpha), $(untrianglea(unwrapa(:A))), $matrixdescra(A), x, $T(beta), y) + end + + function LinearAlgebra.mul!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}, alpha::Number, beta::Number) + # return cscmm!($transa, $T(alpha), $matdescra(A), $(untrianglea(unwrapa(:A))), B, $T(beta), C) + return mm!($transa, $T(alpha), $(untrianglea(unwrapa(:A))), $matrixdescra(A), B, $T(beta), C) + end + + function LinearAlgebra.ldiv!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}) + # return cscsv!($transa, one($T), $matdescra(A), $(untrianglea(unwrapa(:A))), x, y) + return trsv!($transa, one($T), $(untrianglea(unwrapa(:A))), $matrixdescra(A), x, y) + end + + function LinearAlgebra.ldiv!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}) + # return cscsm!($transa, one($T), $matdescra(A), $(untrianglea(unwrapa(:A))), B, C) + return trsm!($transa, one($T), $(untrianglea(unwrapa(:A))), $matrixdescra(A), B, C) + end + + function (*)(A::$TypeA, x::StridedVector{$T}) + m, n = size(A) + y = Vector{$T}(undef, m) + return mul!(y, A, x, one($T), zero($T)) + end + + function (*)(A::$TypeA, B::StridedMatrix{$T}) + m, k = size(A) + p, n = size(B) + C = Matrix{$T}(undef, m, n) + return mul!(C, A, B, one($T), zero($T)) + end + + function (\)(A::$TypeA, x::StridedVector{$T}) + n = length(x) + y = Vector{$T}(undef, n) + return ldiv!(y, A, x) + end + + function (\)(A::$TypeA, B::StridedMatrix{$T}) + m, n = size(B) + C = Matrix{$T}(undef, m, n) + return ldiv!(C, A, B) + end + end + end + end +end diff --git a/src/matdescra.jl b/src/matdescra.jl deleted file mode 100644 index 247275a..0000000 --- a/src/matdescra.jl +++ /dev/null @@ -1,8 +0,0 @@ -matdescra(A::LowerTriangular) = "TLNF" -matdescra(A::UpperTriangular) = "TUNF" -matdescra(A::Diagonal) = "DUNF" -matdescra(A::UnitLowerTriangular) = "TLUF" -matdescra(A::UnitUpperTriangular) = "TUUF" -matdescra(A::Symmetric) = string('S', A.uplo, 'N', 'F') -matdescra(A::Hermitian) = string('H', A.uplo, 'N', 'F') -matdescra(A::SparseMatrixCSC) = "GUUF" diff --git a/src/matmul.jl b/src/matmul.jl deleted file mode 100644 index 6baa8a3..0000000 --- a/src/matmul.jl +++ /dev/null @@ -1,101 +0,0 @@ -import Base: *, \ -import LinearAlgebra: mul!, ldiv! - -_get_data(A::LowerTriangular) = tril(A.data) -_get_data(A::UpperTriangular) = triu(A.data) -_get_data(A::UnitLowerTriangular) = tril(A.data) -_get_data(A::UnitUpperTriangular) = triu(A.data) -_get_data(A::Symmetric) = A.data - -_unwrap_adj(x::Union{Adjoint,Transpose}) = parent(x) -_unwrap_adj(x) = x - -const SparseMatrices{T} = Union{SparseMatrixCSC{T,BlasInt}, - Symmetric{T,SparseMatrixCSC{T,BlasInt}}, - LowerTriangular{T, SparseMatrixCSC{T,BlasInt}}, - UnitLowerTriangular{T, SparseMatrixCSC{T,BlasInt}}, - UpperTriangular{T, SparseMatrixCSC{T,BlasInt}}, - UnitUpperTriangular{T, SparseMatrixCSC{T,BlasInt}}} - -for T in [Complex{Float32}, Complex{Float64}, Float32, Float64] -for mat in (:StridedVector, :StridedMatrix) -for (tchar, ttype) in (('N', :()), - ('C', :Adjoint), - ('T', :Transpose)) - AT = tchar == 'N' ? :(SparseMatrixCSC{$T,BlasInt}) : :($ttype{$T,SparseMatrixCSC{$T,BlasInt}}) - @eval begin - function mul!(C::$mat{$T}, adjA::$AT, B::$mat{$T}, α::Number, β::Number) - A = _unwrap_adj(adjA) - if isa(B, AbstractVector) - return cscmv!($tchar, $T(α), matdescra(A), A, B, $T(β), C) - else - return cscmm!($tchar, $T(α), matdescra(A), A, B, $T(β), C) - end - end - - mul!(C::$mat{$T}, adjA::$AT, B::$mat{$T}) = mul!(C, adjA, B, one($T), zero($T)) - - function (*)(adjA::$AT, B::$mat{$T}) - A = _unwrap_adj(adjA) - if isa(B,AbstractVector) - return mul!(zeros($T, mkl_size($tchar, A)[1]), adjA, B) - else - return mul!(zeros($T, mkl_size($tchar, A)[1], size(B,2)), adjA, B) - end - end - end - - for w in (:Symmetric, :LowerTriangular, :UnitLowerTriangular, :UpperTriangular, :UnitUpperTriangular) - AT = tchar == 'N' ? - :($w{$T,SparseMatrixCSC{$T,BlasInt}}) : - :($ttype{$T,$w{$T,SparseMatrixCSC{$T,BlasInt}}}) - @eval begin - function mul!(C::$mat{$T}, adjA::$AT, B::$mat{$T}, α::Number, β::Number) - A = _unwrap_adj(adjA) - if isa(B,AbstractVector) - return cscmv!($tchar, $T(α), matdescra(A), _get_data(A), B, $T(β), C) - else - return cscmm!($tchar, $T(α), matdescra(A), _get_data(A), B, $T(β), C) - end - end - - mul!(C::$mat{$T}, adjA::$AT, B::$mat{$T}) = mul!(C, adjA, B, one($T), zero($T)) - - function (*)(adjA::$AT, B::$mat{$T}) - A = _unwrap_adj(adjA) - if isa(B,AbstractVector) - return mul!(zeros($T, mkl_size($tchar, A)[1]), adjA, B) - else - return mul!(zeros($T, mkl_size($tchar, A)[1], size(B,2)), adjA, B) - end - end - end - - if w != :Symmetric - @eval begin - function ldiv!(α::Number, adjA::$AT, - B::$mat{$T}, C::$mat{$T}) - A = _unwrap_adj(adjA) - if isa(B,AbstractVector) - return cscsv!($tchar, $T(α), matdescra(A), _get_data(A), B, C) - else - return cscsm!($tchar, $T(α), matdescra(A), _get_data(A), B, C) - end - end - - ldiv!(C::$mat{$T}, A::$AT, B::$mat{$T}) = - ldiv!(one($T), A, B, C) - - function (\)(A::$AT, B::$mat{$T}) - if isa(B,AbstractVector) - return ldiv!(zeros($T, size(A,1)), A, B) - else - return ldiv!(zeros($T, size(A,1), size(B,2)), A, B) - end - end - end - end - end -end -end # mat -end # T diff --git a/src/mklsparsematrix.jl b/src/mklsparsematrix.jl new file mode 100644 index 0000000..9072786 --- /dev/null +++ b/src/mklsparsematrix.jl @@ -0,0 +1,67 @@ +## MKL sparse matrix + +# https://github.com/JuliaSmoothOptimizers/SparseMatricesCOO.jl +mutable struct SparseMatrixCOO{Tv,Ti} <: AbstractSparseMatrix{Tv,Ti} + m::Int + n::Int + rows::Vector{Ti} + cols::Vector{Ti} + vals::Vector{Tv} +end + +# https://github.com/gridap/SparseMatricesCSR.jl +mutable struct SparseMatrixCSR{Tv,Ti} <: AbstractSparseMatrix{Tv,Ti} + m::Int + n::Int + rowptr::Vector{Ti} + colval::Vector{Ti} + nzval::Vector{Tv} +end + +SparseArrays.nnz(A::MKLSparse.SparseMatrixCOO) = length(A.vals) +SparseArrays.nnz(A::MKLSparse.SparseMatrixCSR) = length(A.nzval) + +matrixdescra(A::MKLSparse.SparseMatrixCSR) = matrix_descr('G', 'F', 'N') +matrixdescra(A::MKLSparse.SparseMatrixCOO) = matrix_descr('G', 'F', 'N') + +mutable struct MKLSparseMatrix + handle::sparse_matrix_t +end + +Base.unsafe_convert(::Type{sparse_matrix_t}, desc::MKLSparseMatrix) = desc.handle + +for T in (:Float32, :Float64, :ComplexF32, :ComplexF64) + + create_coo = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_coo") + create_csc = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_csc") + create_csr = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_csr") + + @eval begin + # SparseMatrixCOO + function MKLSparseMatrix(A::MKLSparse.SparseMatrixCOO{$T, BlasInt}, IndexBase::Char='O') + descr_ref = Ref{sparse_matrix_t}() + $create_coo(descr_ref, IndexBase, A.m, A.n, nnz(A), A.rows, A.cols, A.vals) + obj = MKLSparseMatrix(descr_ref[]) + finalizer(mkl_sparse_destroy, obj) + return obj + end + + # SparseMatrixCSR + function MKLSparseMatrix(A::MKLSparse.SparseMatrixCSR{$T, BlasInt}, IndexBase::Char='O') + descr_ref = Ref{sparse_matrix_t}() + $create_csr(descr_ref, IndexBase, A.m, A.n, A.rowptr, pointer(A.rowptr, 2), A.colval, A.nzval) + obj = MKLSparseMatrix(descr_ref[]) + finalizer(mkl_sparse_destroy, obj) + return obj + end + + # SparseMatrixCSC + function MKLSparseMatrix(A::SparseMatrixCSC{$T, BlasInt}, IndexBase::Char='O') + descr_ref = Ref{sparse_matrix_t}() + $create_csc(descr_ref, IndexBase, A.m, A.n, A.colptr, pointer(A.colptr, 2), A.rowval, A.nzval) + obj = MKLSparseMatrix(descr_ref[]) + finalizer(mkl_sparse_destroy, obj) + return obj + end + end +end diff --git a/src/types.jl b/src/types.jl index f319c48..0546bcb 100644 --- a/src/types.jl +++ b/src/types.jl @@ -1,5 +1,30 @@ # MKL sparse types +function mkl_type_specifier(T::Symbol) + if T == :Float32 + 's' + elseif T == :Float64 + 'd' + elseif T == :ComplexF32 + 'c' + elseif T == :ComplexF64 + 'z' + else + throw(ArgumentError("Unsupported numeric type $T")) + end +end + +matrixdescra(A::LowerTriangular) = matrix_descr('T','L','N') +matrixdescra(A::UpperTriangular) = matrix_descr('T','U','N') +matrixdescra(A::Diagonal) = matrix_descr('D','F','N') +matrixdescra(A::UnitLowerTriangular) = matrix_descr('T','L','U') +matrixdescra(A::UnitUpperTriangular) = matrix_descr('T','U','U') +matrixdescra(A::Symmetric) = matrix_descr('S', A.uplo, 'N') +matrixdescra(A::Hermitian) = matrix_descr('H', A.uplo, 'N') +matrixdescra(A::SparseMatrixCSC) = matrix_descr('G', 'F', 'N') +matrixdescra(A::Transpose) = matrixdescra(A.parent) +matrixdescra(A::Adjoint) = matrixdescra(A.parent) + function Base.convert(::Type{sparse_operation_t}, trans::Char) if trans == 'N' SPARSE_OPERATION_NON_TRANSPOSE diff --git a/test/test_BLAS.jl b/test/test_BLAS.jl index c569eee..1d11f83 100644 --- a/test/test_BLAS.jl +++ b/test/test_BLAS.jl @@ -22,23 +22,68 @@ macro test_blas(ex) end end -@testset "matrix-vector multiplication (non-square)" begin - for i = 1:5 - a = sprand(10, 5, 0.5) - b = rand(5) - @test_blas maximum(abs.(a*b - Array(a)*b)) < 100*eps() - b = rand(5, 5) - @test_blas maximum(abs.(a*b - Array(a)*b)) < 100*eps() - b = rand(10) - @test_blas maximum(abs.(a'*b - Array(a)'*b)) < 100*eps() - @test_blas maximum(abs.(transpose(a)*b - Array(a)'*b)) < 100*eps() - b = rand(10,10) - @test_blas maximum(abs.(a'*b - Array(a)'*b)) < 100*eps() - @test_blas maximum(abs.(transpose(a)*b - Array(a)'*b)) < 100*eps() +for T in (Float64, ComplexF64) + @testset "matrix-vector and matrix-matrix multiplications (non-square) -- $T" begin + for i = 1:5 + a = sprand(T, 10, 5, 0.5) + b = rand(T, 5) + @test_blas a*b ≈ Array(a)*b + B = rand(T, 5, 5) + @test_blas a*B ≈ Array(a)*B + b = rand(T, 10) + @test_blas a'*b ≈ Array(a)'*b + @test_blas transpose(a)*b ≈ transpose(Array(a))*b + B = rand(T, 10, 10) + @test_blas a'*B ≈ Array(a)'*B + @test_blas transpose(a)*B ≈ transpose(Array(a))*B + end + end + + @testset "Symmetric / Hermitian -- $T" begin + n = 10 + A = sprandn(T, n, n, 0.5) + sqrt(n)*I + b = rand(T, n) + B = rand(T, n, 3) + symA = A + transpose(A) + hermA = A + adjoint(A) + @test_blas Symmetric(symA) * b ≈ Array(Symmetric(symA)) * b + @test_blas Hermitian(hermA) * b ≈ Array(Hermitian(hermA)) * b + @test_blas Symmetric(symA) * B ≈ Array(Symmetric(symA)) * B + @test_blas Hermitian(hermA) * B ≈ Array(Hermitian(hermA)) * B + end + + @testset "triangular -- $T" begin + n = 10 + A = sprandn(T, n, n, 0.5) + sqrt(n)*I + b = rand(T, n) + B = rand(T, n, 3) + trilA = tril(A) + triuA = triu(A) + trilUA = tril(A, -1) + I + triuUA = triu(A, 1) + I + + @test_blas LowerTriangular(trilA) \ b ≈ Array(LowerTriangular(trilA)) \ b + @test_blas LowerTriangular(trilA) * b ≈ Array(LowerTriangular(trilA)) * b + @test_blas LowerTriangular(trilA) \ B ≈ Array(LowerTriangular(trilA)) \ B + @test_blas LowerTriangular(trilA) * B ≈ Array(LowerTriangular(trilA)) * B + + @test_blas UpperTriangular(triuA) \ b ≈ Array(UpperTriangular(triuA)) \ b + @test_blas UpperTriangular(triuA) * b ≈ Array(UpperTriangular(triuA)) * b + @test_blas UpperTriangular(triuA) \ B ≈ Array(UpperTriangular(triuA)) \ B + @test_blas UpperTriangular(triuA) * B ≈ Array(UpperTriangular(triuA)) * B + + @test_blas UnitLowerTriangular(trilUA) \ b ≈ Array(UnitLowerTriangular(trilUA)) \ b + @test_blas UnitLowerTriangular(trilUA) * b ≈ Array(UnitLowerTriangular(trilUA)) * b + @test_blas UnitLowerTriangular(trilUA) \ B ≈ Array(UnitLowerTriangular(trilUA)) \ B + @test_blas UnitLowerTriangular(trilUA) * B ≈ Array(UnitLowerTriangular(trilUA)) * B + + @test_blas UnitUpperTriangular(triuUA) \ b ≈ Array(UnitUpperTriangular(triuUA)) \ b + @test_blas UnitUpperTriangular(triuUA) * b ≈ Array(UnitUpperTriangular(triuUA)) * b + @test_blas UnitUpperTriangular(triuUA) \ B ≈ Array(UnitUpperTriangular(triuUA)) \ B + @test_blas UnitUpperTriangular(triuUA) * B ≈ Array(UnitUpperTriangular(triuUA)) * B end end -#? @testset "complex matrix-vector multiplication" begin for i = 1:5 a = I + im * 0.1*sprandn(5, 5, 0.2) @@ -68,28 +113,3 @@ end @test_throws DimensionMismatch a.*c end end - -@testset "triangular" begin - n = 100 - A = sprandn(n, n, 0.5) + sqrt(n)*I - b = rand(n) - symA = A + transpose(A) - trilA = tril(A) - triuA = triu(A) - trilUA = tril(A, -1) + I - triuUA = triu(A, 1) + I - - @test_blas LowerTriangular(trilA) \ b ≈ Array(LowerTriangular(trilA)) \ b - @test_blas LowerTriangular(trilA) * b ≈ Array(LowerTriangular(trilA)) * b - - @test_blas UpperTriangular(triuA) \ b ≈ Array(UpperTriangular(triuA)) \ b - @test_blas UpperTriangular(triuA) * b ≈ Array(UpperTriangular(triuA)) * b - - @test_blas UnitLowerTriangular(trilUA) \ b ≈ Array(UnitLowerTriangular(trilUA)) \ b - @test_blas UnitLowerTriangular(trilUA) * b ≈ Array(UnitLowerTriangular(trilUA)) * b - - @test_blas UnitUpperTriangular(triuUA) \ b ≈ Array(UnitUpperTriangular(triuUA)) \ b - @test_blas UnitUpperTriangular(triuUA) * b ≈ Array(UnitUpperTriangular(triuUA)) * b - - @test_blas Symmetric(symA) * b ≈ Array(Symmetric(symA)) * b -end From 5e64b9d8ce4f4b1f5f24ca906341b539d3896c32 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Mon, 31 Oct 2022 21:34:43 -0400 Subject: [PATCH 2/3] Add size for SparseMatrixCSR and SparseMatrixCOO --- src/mklsparsematrix.jl | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/mklsparsematrix.jl b/src/mklsparsematrix.jl index 9072786..c5ddc2f 100644 --- a/src/mklsparsematrix.jl +++ b/src/mklsparsematrix.jl @@ -18,6 +18,9 @@ mutable struct SparseMatrixCSR{Tv,Ti} <: AbstractSparseMatrix{Tv,Ti} nzval::Vector{Tv} end +Base.size(A::MKLSparse.SparseMatrixCOO) = (A.m, A.n) +Base.size(A::MKLSparse.SparseMatrixCSR) = (A.m, A.n) + SparseArrays.nnz(A::MKLSparse.SparseMatrixCOO) = length(A.vals) SparseArrays.nnz(A::MKLSparse.SparseMatrixCSR) = length(A.nzval) From 5e5b95b1e0ab5d8cbdd0912a336e72971747c301 Mon Sep 17 00:00:00 2001 From: Alexis Montoison Date: Sun, 7 Jul 2024 18:31:08 -0400 Subject: [PATCH 3/3] Use the new API to support both LP64 and ILP64 API --- src/MKLSparse.jl | 6 +- src/deprecated.jl | 8 +-- src/generic.jl | 85 +++++++++++----------- src/interface.jl | 155 +++++++++++++++++++++-------------------- src/mklsparsematrix.jl | 58 ++++++++------- src/types.jl | 10 +++ test/test_BLAS.jl | 108 ++++++++++++++-------------- 7 files changed, 228 insertions(+), 202 deletions(-) diff --git a/src/MKLSparse.jl b/src/MKLSparse.jl index 6997b42..088130d 100644 --- a/src/MKLSparse.jl +++ b/src/MKLSparse.jl @@ -21,20 +21,20 @@ end INTERFACE_GNU end -function set_threading_layer(layer::Threading = THREADING_SEQUENTIAL) +function set_threading_layer(layer::Threading = THREADING_INTEL) err = @ccall libmkl_rt.MKL_Set_Threading_Layer(layer::Cint)::Cint (err == -1) && error("MKL_Set_Threading_Layer() returned -1") return nothing end -function set_interface_layer(interface::Interface = INTERFACE_ILP64) +function set_interface_layer(interface::Interface = INTERFACE_LP64) err = @ccall libmkl_rt.MKL_Set_Interface_Layer(interface::Cint)::Cint (err == -1) && error("MKL_Set_Interface_Layer() returned -1") return nothing end function __init__() - set_interface_layer(Base.USE_BLAS64 ? INTERFACE_ILP64 : INTERFACE_LP64) + set_interface_layer(INTERFACE_LP64) end # Wrappers generated by Clang.jl diff --git a/src/deprecated.jl b/src/deprecated.jl index 0ec0fe3..0ea21d1 100644 --- a/src/deprecated.jl +++ b/src/deprecated.jl @@ -39,7 +39,7 @@ function _check_mat_mult_matvec(C, A, B, tA) end function cscmv!(transa::Char, α::T, matdescra::String, - A::SparseMatrixCSC{T, BlasInt}, x::StridedVector{T}, + A::SparseMatrixCSC{T, Int32}, x::StridedVector{T}, β::T, y::StridedVector{T}) where {T <: BlasFloat} _check_transa(transa) _check_mat_mult_matvec(y, A, x, transa) @@ -52,7 +52,7 @@ function cscmv!(transa::Char, α::T, matdescra::String, end function cscmm!(transa::Char, α::T, matdescra::String, - A::SparseMatrixCSC{T, BlasInt}, B::StridedMatrix{T}, + A::SparseMatrixCSC{T, Int32}, B::StridedMatrix{T}, β::T, C::StridedMatrix{T}) where {T <: BlasFloat} _check_transa(transa) _check_mat_mult_matvec(C, A, B, transa) @@ -67,7 +67,7 @@ function cscmm!(transa::Char, α::T, matdescra::String, end function cscsv!(transa::Char, α::T, matdescra::String, - A::SparseMatrixCSC{T, BlasInt}, x::StridedVector{T}, + A::SparseMatrixCSC{T, Int32}, x::StridedVector{T}, y::StridedVector{T}) where {T <: BlasFloat} n = checksquare(A) _check_transa(transa) @@ -81,7 +81,7 @@ function cscsv!(transa::Char, α::T, matdescra::String, end function cscsm!(transa::Char, α::T, matdescra::String, - A::SparseMatrixCSC{T, BlasInt}, B::StridedMatrix{T}, + A::SparseMatrixCSC{T, Int32}, B::StridedMatrix{T}, C::StridedMatrix{T}) where {T <: BlasFloat} mB, nB = size(B) mC, nC = size(C) diff --git a/src/generic.jl b/src/generic.jl index cb43d40..6c21597 100644 --- a/src/generic.jl +++ b/src/generic.jl @@ -1,50 +1,53 @@ for T in (:Float32, :Float64, :ComplexF32, :ComplexF64) - for SparseMatrix in (:(SparseMatrixCSC{$T,BlasInt}), :(MKLSparse.SparseMatrixCSR{$T,BlasInt}), :(MKLSparse.SparseMatrixCOO{$T,BlasInt})) + INT_TYPES = Base.USE_BLAS64 ? (:Int32, :Int64) : (:Int32,) + for INT in INT_TYPES + for SparseMatrix in (:(SparseMatrixCSC{$T,$INT}), :(MKLSparse.SparseMatrixCSR{$T,$INT}), :(MKLSparse.SparseMatrixCOO{$T,$INT})) - fname_mv = Symbol("mkl_sparse_", mkl_type_specifier(T), "_mv") - fname_mm = Symbol("mkl_sparse_", mkl_type_specifier(T), "_mm") - fname_trsv = Symbol("mkl_sparse_", mkl_type_specifier(T), "_trsv") - fname_trsm = Symbol("mkl_sparse_", mkl_type_specifier(T), "_trsm") + fname_mv = Symbol("mkl_sparse_", mkl_type_specifier(T), "_mv" , mkl_integer_specifier(INT)) + fname_mm = Symbol("mkl_sparse_", mkl_type_specifier(T), "_mm" , mkl_integer_specifier(INT)) + fname_trsv = Symbol("mkl_sparse_", mkl_type_specifier(T), "_trsv", mkl_integer_specifier(INT)) + fname_trsm = Symbol("mkl_sparse_", mkl_type_specifier(T), "_trsm", mkl_integer_specifier(INT)) - @eval begin - function mv!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedVector{$T}, beta::$T, y::StridedVector{$T}) - _check_transa(operation) - _check_mat_mult_matvec(y, A, x, operation) - __counter[] += 1 - $fname_mv(operation, alpha, MKLSparseMatrix(A), descr, x, beta, y) - return y - end + @eval begin + function mv!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedVector{$T}, beta::$T, y::StridedVector{$T}) + _check_transa(operation) + _check_mat_mult_matvec(y, A, x, operation) + __counter[] += 1 + $fname_mv(operation, alpha, MKLSparseMatrix(A), descr, x, beta, y) + return y + end - function mm!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedMatrix{$T}, beta::$T, y::StridedMatrix{$T}) - _check_transa(operation) - _check_mat_mult_matvec(y, A, x, operation) - __counter[] += 1 - columns = size(y, 2) - ldx = stride(x, 2) - ldy = stride(y, 2) - $fname_mm(operation, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, beta, y, ldy) - return y - end + function mm!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedMatrix{$T}, beta::$T, y::StridedMatrix{$T}) + _check_transa(operation) + _check_mat_mult_matvec(y, A, x, operation) + __counter[] += 1 + columns = size(y, 2) + ldx = stride(x, 2) + ldy = stride(y, 2) + $fname_mm(operation, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, beta, y, ldy) + return y + end - function trsv!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedVector{$T}, y::StridedVector{$T}) - checksquare(A) - _check_transa(operation) - _check_mat_mult_matvec(y, A, x, operation) - __counter[] += 1 - $fname_trsv(operation, alpha, MKLSparseMatrix(A), descr, x, y) - return y - end + function trsv!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedVector{$T}, y::StridedVector{$T}) + checksquare(A) + _check_transa(operation) + _check_mat_mult_matvec(y, A, x, operation) + __counter[] += 1 + $fname_trsv(operation, alpha, MKLSparseMatrix(A), descr, x, y) + return y + end - function trsm!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedMatrix{$T}, y::StridedMatrix{$T}) - checksquare(A) - _check_transa(operation) - _check_mat_mult_matvec(y, A, x, operation) - __counter[] += 1 - columns = size(y, 2) - ldx = stride(x, 2) - ldy = stride(y, 2) - $fname_trsm(operation, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, y, ldy) - return y + function trsm!(operation::Char, alpha::$T, A::$SparseMatrix, descr::matrix_descr, x::StridedMatrix{$T}, y::StridedMatrix{$T}) + checksquare(A) + _check_transa(operation) + _check_mat_mult_matvec(y, A, x, operation) + __counter[] += 1 + columns = size(y, 2) + ldx = stride(x, 2) + ldy = stride(y, 2) + $fname_trsm(operation, alpha, MKLSparseMatrix(A), descr, 'C', x, columns, ldx, y, ldy) + return y + end end end end diff --git a/src/interface.jl b/src/interface.jl index aaa1e31..a3fdaad 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -1,85 +1,88 @@ import Base: \, * import LinearAlgebra: mul!, ldiv! -for T in (Float32, Float64, ComplexF32, ComplexF64) - - tag_wrappers = ((identity , identity ), - (M -> :(Symmetric{$T, $M}), A -> :(parent($A))), - (M -> :(Hermitian{$T, $M}), A -> :(parent($A)))) - - triangle_wrappers = ((M -> :(LowerTriangular{$T, $M}) , A -> :(parent($A))), - (M -> :(UnitLowerTriangular{$T, $M}), A -> :(parent($A))), - (M -> :(UpperTriangular{$T, $M}) , A -> :(parent($A))), - (M -> :(UnitUpperTriangular{$T, $M}), A -> :(parent($A)))) - - op_wrappers = ((identity , 'N', identity ), - (M -> :(Transpose{$T, $M}), 'T', A -> :(parent($A))), - (M -> :(Adjoint{$T, $M}) , 'C', A -> :(parent($A)))) - - for SparseMatrixType in (:(SparseMatrixCSC{$T, $BlasInt}), :(MKLSparse.SparseMatrixCOO{$T, $BlasInt}), :(MKLSparse.SparseMatrixCSR{$T, $BlasInt})) - for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers - TypeA = wrapa(taga(SparseMatrixType)) - - @eval begin - function LinearAlgebra.mul!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}, alpha::Number, beta::Number) - # return cscmv!($transa, $T(alpha), $matdescra(A), $(untaga(unwrapa(:A))), x, $T(beta), y) - return mv!($transa, $T(alpha), $(untaga(unwrapa(:A))), $matrixdescra(A), x, $T(beta), y) - end - - function LinearAlgebra.mul!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}, alpha::Number, beta::Number) - # return cscmm!($transa, $T(alpha), $matdescra(A), $(untaga(unwrapa(:A))), B, $T(beta), C) - return mm!($transa, $T(alpha), $(untaga(unwrapa(:A))), $matrixdescra(A), B, $T(beta), C) +for T in (:Float32, :Float64, :ComplexF32, :ComplexF64) + INT_TYPES = Base.USE_BLAS64 ? (:Int32, :Int64) : (:Int32,) + for INT in INT_TYPES + + tag_wrappers = ((identity , identity ), + (M -> :(Symmetric{$T, $M}), A -> :(parent($A))), + (M -> :(Hermitian{$T, $M}), A -> :(parent($A)))) + + triangle_wrappers = ((M -> :(LowerTriangular{$T, $M}) , A -> :(parent($A))), + (M -> :(UnitLowerTriangular{$T, $M}), A -> :(parent($A))), + (M -> :(UpperTriangular{$T, $M}) , A -> :(parent($A))), + (M -> :(UnitUpperTriangular{$T, $M}), A -> :(parent($A)))) + + op_wrappers = ((identity , 'N', identity ), + (M -> :(Transpose{$T, $M}), 'T', A -> :(parent($A))), + (M -> :(Adjoint{$T, $M}) , 'C', A -> :(parent($A)))) + + for SparseMatrixType in (:(SparseMatrixCSC{$T, $INT}), :(MKLSparse.SparseMatrixCOO{$T, $INT}), :(MKLSparse.SparseMatrixCSR{$T, $INT})) + for (taga, untaga) in tag_wrappers, (wrapa, transa, unwrapa) in op_wrappers + TypeA = wrapa(taga(SparseMatrixType)) + + @eval begin + function LinearAlgebra.mul!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}, alpha::Number, beta::Number) + # return cscmv!($transa, $T(alpha), $matdescra(A), $(untaga(unwrapa(:A))), x, $T(beta), y) + return mv!($transa, $T(alpha), $(untaga(unwrapa(:A))), $matrixdescra(A), x, $T(beta), y) + end + + function LinearAlgebra.mul!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}, alpha::Number, beta::Number) + # return cscmm!($transa, $T(alpha), $matdescra(A), $(untaga(unwrapa(:A))), B, $T(beta), C) + return mm!($transa, $T(alpha), $(untaga(unwrapa(:A))), $matrixdescra(A), B, $T(beta), C) + end end end - end - - for (trianglea, untrianglea) in triangle_wrappers, (wrapa, transa, unwrapa) in op_wrappers - TypeA = wrapa(trianglea(SparseMatrixType)) - - @eval begin - function LinearAlgebra.mul!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}, alpha::Number, beta::Number) - # return cscmv!($transa, $T(alpha), $matdescra(A), $(untrianglea(unwrapa(:A))), x, $T(beta), y) - return mv!($transa, $T(alpha), $(untrianglea(unwrapa(:A))), $matrixdescra(A), x, $T(beta), y) - end - - function LinearAlgebra.mul!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}, alpha::Number, beta::Number) - # return cscmm!($transa, $T(alpha), $matdescra(A), $(untrianglea(unwrapa(:A))), B, $T(beta), C) - return mm!($transa, $T(alpha), $(untrianglea(unwrapa(:A))), $matrixdescra(A), B, $T(beta), C) - end - - function LinearAlgebra.ldiv!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}) - # return cscsv!($transa, one($T), $matdescra(A), $(untrianglea(unwrapa(:A))), x, y) - return trsv!($transa, one($T), $(untrianglea(unwrapa(:A))), $matrixdescra(A), x, y) - end - - function LinearAlgebra.ldiv!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}) - # return cscsm!($transa, one($T), $matdescra(A), $(untrianglea(unwrapa(:A))), B, C) - return trsm!($transa, one($T), $(untrianglea(unwrapa(:A))), $matrixdescra(A), B, C) - end - - function (*)(A::$TypeA, x::StridedVector{$T}) - m, n = size(A) - y = Vector{$T}(undef, m) - return mul!(y, A, x, one($T), zero($T)) - end - - function (*)(A::$TypeA, B::StridedMatrix{$T}) - m, k = size(A) - p, n = size(B) - C = Matrix{$T}(undef, m, n) - return mul!(C, A, B, one($T), zero($T)) - end - - function (\)(A::$TypeA, x::StridedVector{$T}) - n = length(x) - y = Vector{$T}(undef, n) - return ldiv!(y, A, x) - end - function (\)(A::$TypeA, B::StridedMatrix{$T}) - m, n = size(B) - C = Matrix{$T}(undef, m, n) - return ldiv!(C, A, B) + for (trianglea, untrianglea) in triangle_wrappers, (wrapa, transa, unwrapa) in op_wrappers + TypeA = wrapa(trianglea(SparseMatrixType)) + + @eval begin + function LinearAlgebra.mul!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}, alpha::Number, beta::Number) + # return cscmv!($transa, $T(alpha), $matdescra(A), $(untrianglea(unwrapa(:A))), x, $T(beta), y) + return mv!($transa, $T(alpha), $(untrianglea(unwrapa(:A))), $matrixdescra(A), x, $T(beta), y) + end + + function LinearAlgebra.mul!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}, alpha::Number, beta::Number) + # return cscmm!($transa, $T(alpha), $matdescra(A), $(untrianglea(unwrapa(:A))), B, $T(beta), C) + return mm!($transa, $T(alpha), $(untrianglea(unwrapa(:A))), $matrixdescra(A), B, $T(beta), C) + end + + function LinearAlgebra.ldiv!(y::StridedVector{$T}, A::$TypeA, x::StridedVector{$T}) + # return cscsv!($transa, one($T), $matdescra(A), $(untrianglea(unwrapa(:A))), x, y) + return trsv!($transa, one($T), $(untrianglea(unwrapa(:A))), $matrixdescra(A), x, y) + end + + function LinearAlgebra.ldiv!(C::StridedMatrix{$T}, A::$TypeA, B::StridedMatrix{$T}) + # return cscsm!($transa, one($T), $matdescra(A), $(untrianglea(unwrapa(:A))), B, C) + return trsm!($transa, one($T), $(untrianglea(unwrapa(:A))), $matrixdescra(A), B, C) + end + + function (*)(A::$TypeA, x::StridedVector{$T}) + m, n = size(A) + y = Vector{$T}(undef, m) + return mul!(y, A, x, one($T), zero($T)) + end + + function (*)(A::$TypeA, B::StridedMatrix{$T}) + m, k = size(A) + p, n = size(B) + C = Matrix{$T}(undef, m, n) + return mul!(C, A, B, one($T), zero($T)) + end + + function (\)(A::$TypeA, x::StridedVector{$T}) + n = length(x) + y = Vector{$T}(undef, n) + return ldiv!(y, A, x) + end + + function (\)(A::$TypeA, B::StridedMatrix{$T}) + m, n = size(B) + C = Matrix{$T}(undef, m, n) + return ldiv!(C, A, B) + end end end end diff --git a/src/mklsparsematrix.jl b/src/mklsparsematrix.jl index c5ddc2f..dfff3ad 100644 --- a/src/mklsparsematrix.jl +++ b/src/mklsparsematrix.jl @@ -34,37 +34,41 @@ end Base.unsafe_convert(::Type{sparse_matrix_t}, desc::MKLSparseMatrix) = desc.handle for T in (:Float32, :Float64, :ComplexF32, :ComplexF64) + INT_TYPES = Base.USE_BLAS64 ? (:Int32, :Int64) : (:Int32,) + for INT in INT_TYPES - create_coo = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_coo") - create_csc = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_csc") - create_csr = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_csr") + create_coo = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_coo", mkl_integer_specifier(INT)) + create_csc = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_csc", mkl_integer_specifier(INT)) + create_csr = Symbol("mkl_sparse_", mkl_type_specifier(T), "_create_csr", mkl_integer_specifier(INT)) + sparse_destroy = (INT == :Int32) ? :mkl_sparse_destroy : :mkl_sparse_destroy_64 - @eval begin - # SparseMatrixCOO - function MKLSparseMatrix(A::MKLSparse.SparseMatrixCOO{$T, BlasInt}, IndexBase::Char='O') - descr_ref = Ref{sparse_matrix_t}() - $create_coo(descr_ref, IndexBase, A.m, A.n, nnz(A), A.rows, A.cols, A.vals) - obj = MKLSparseMatrix(descr_ref[]) - finalizer(mkl_sparse_destroy, obj) - return obj - end + @eval begin + # SparseMatrixCOO + function MKLSparseMatrix(A::MKLSparse.SparseMatrixCOO{$T, $INT}, IndexBase::Char='O') + descr_ref = Ref{sparse_matrix_t}() + $create_coo(descr_ref, IndexBase, A.m, A.n, nnz(A), A.rows, A.cols, A.vals) + obj = MKLSparseMatrix(descr_ref[]) + finalizer($sparse_destroy, obj) + return obj + end - # SparseMatrixCSR - function MKLSparseMatrix(A::MKLSparse.SparseMatrixCSR{$T, BlasInt}, IndexBase::Char='O') - descr_ref = Ref{sparse_matrix_t}() - $create_csr(descr_ref, IndexBase, A.m, A.n, A.rowptr, pointer(A.rowptr, 2), A.colval, A.nzval) - obj = MKLSparseMatrix(descr_ref[]) - finalizer(mkl_sparse_destroy, obj) - return obj - end + # SparseMatrixCSR + function MKLSparseMatrix(A::MKLSparse.SparseMatrixCSR{$T, $INT}, IndexBase::Char='O') + descr_ref = Ref{sparse_matrix_t}() + $create_csr(descr_ref, IndexBase, A.m, A.n, A.rowptr, pointer(A.rowptr, 2), A.colval, A.nzval) + obj = MKLSparseMatrix(descr_ref[]) + finalizer($sparse_destroy, obj) + return obj + end - # SparseMatrixCSC - function MKLSparseMatrix(A::SparseMatrixCSC{$T, BlasInt}, IndexBase::Char='O') - descr_ref = Ref{sparse_matrix_t}() - $create_csc(descr_ref, IndexBase, A.m, A.n, A.colptr, pointer(A.colptr, 2), A.rowval, A.nzval) - obj = MKLSparseMatrix(descr_ref[]) - finalizer(mkl_sparse_destroy, obj) - return obj + # SparseMatrixCSC + function MKLSparseMatrix(A::SparseMatrixCSC{$T, $INT}, IndexBase::Char='O') + descr_ref = Ref{sparse_matrix_t}() + $create_csc(descr_ref, IndexBase, A.m, A.n, A.colptr, pointer(A.colptr, 2), A.rowval, A.nzval) + obj = MKLSparseMatrix(descr_ref[]) + finalizer($sparse_destroy, obj) + return obj + end end end end diff --git a/src/types.jl b/src/types.jl index 0546bcb..55a160a 100644 --- a/src/types.jl +++ b/src/types.jl @@ -14,6 +14,16 @@ function mkl_type_specifier(T::Symbol) end end +function mkl_integer_specifier(INT::Symbol) + if INT == :Int32 + "" + elseif INT == :Int64 + "_64" + else + throw(ArgumentError("Unsupported numeric type $INT")) + end +end + matrixdescra(A::LowerTriangular) = matrix_descr('T','L','N') matrixdescra(A::UpperTriangular) = matrix_descr('T','U','N') matrixdescra(A::Diagonal) = matrix_descr('D','F','N') diff --git a/test/test_BLAS.jl b/test/test_BLAS.jl index 1d11f83..fe505a2 100644 --- a/test/test_BLAS.jl +++ b/test/test_BLAS.jl @@ -23,64 +23,70 @@ macro test_blas(ex) end for T in (Float64, ComplexF64) - @testset "matrix-vector and matrix-matrix multiplications (non-square) -- $T" begin - for i = 1:5 - a = sprand(T, 10, 5, 0.5) - b = rand(T, 5) - @test_blas a*b ≈ Array(a)*b - B = rand(T, 5, 5) - @test_blas a*B ≈ Array(a)*B - b = rand(T, 10) - @test_blas a'*b ≈ Array(a)'*b - @test_blas transpose(a)*b ≈ transpose(Array(a))*b - B = rand(T, 10, 10) - @test_blas a'*B ≈ Array(a)'*B - @test_blas transpose(a)*B ≈ transpose(Array(a))*B + INT_TYPES = Base.USE_BLAS64 ? (Int32, Int64) : (Int32,) + for INT in INT_TYPES + @testset "matrix-vector and matrix-matrix multiplications (non-square) -- $T -- $INT" begin + for i = 1:5 + A = sprand(T, 10, 5, 0.5) + A = SparseMatrixCSC{T,INT}(A) + b = rand(T, 5) + @test_blas A*b ≈ Array(A)*b + B = rand(T, 5, 5) + @test_blas A*B ≈ Array(A)*B + b = rand(T, 10) + @test_blas A'*b ≈ Array(A)'*b + @test_blas transpose(A)*b ≈ transpose(Array(A))*b + B = rand(T, 10, 10) + @test_blas A'*B ≈ Array(A)'*B + @test_blas transpose(A)*B ≈ transpose(Array(A))*B + end end - end - @testset "Symmetric / Hermitian -- $T" begin - n = 10 - A = sprandn(T, n, n, 0.5) + sqrt(n)*I - b = rand(T, n) - B = rand(T, n, 3) - symA = A + transpose(A) - hermA = A + adjoint(A) - @test_blas Symmetric(symA) * b ≈ Array(Symmetric(symA)) * b - @test_blas Hermitian(hermA) * b ≈ Array(Hermitian(hermA)) * b - @test_blas Symmetric(symA) * B ≈ Array(Symmetric(symA)) * B - @test_blas Hermitian(hermA) * B ≈ Array(Hermitian(hermA)) * B - end + @testset "Symmetric / Hermitian -- $T -- $INT" begin + n = 10 + A = sprandn(T, n, n, 0.5) + sqrt(n)*I + A = SparseMatrixCSC{T,INT}(A) + b = rand(T, n) + B = rand(T, n, 3) + symA = A + transpose(A) + hermA = A + adjoint(A) + @test_blas Symmetric(symA) * b ≈ Array(Symmetric(symA)) * b + @test_blas Hermitian(hermA) * b ≈ Array(Hermitian(hermA)) * b + @test_blas Symmetric(symA) * B ≈ Array(Symmetric(symA)) * B + @test_blas Hermitian(hermA) * B ≈ Array(Hermitian(hermA)) * B + end - @testset "triangular -- $T" begin - n = 10 - A = sprandn(T, n, n, 0.5) + sqrt(n)*I - b = rand(T, n) - B = rand(T, n, 3) - trilA = tril(A) - triuA = triu(A) - trilUA = tril(A, -1) + I - triuUA = triu(A, 1) + I + @testset "triangular -- $T -- $INT" begin + n = 10 + A = sprandn(T, n, n, 0.5) + sqrt(n)*I + A = SparseMatrixCSC{T,INT}(A) + b = rand(T, n) + B = rand(T, n, 3) + trilA = tril(A) + triuA = triu(A) + trilUA = tril(A, -1) + I + triuUA = triu(A, 1) + I - @test_blas LowerTriangular(trilA) \ b ≈ Array(LowerTriangular(trilA)) \ b - @test_blas LowerTriangular(trilA) * b ≈ Array(LowerTriangular(trilA)) * b - @test_blas LowerTriangular(trilA) \ B ≈ Array(LowerTriangular(trilA)) \ B - @test_blas LowerTriangular(trilA) * B ≈ Array(LowerTriangular(trilA)) * B + @test_blas LowerTriangular(trilA) \ b ≈ Array(LowerTriangular(trilA)) \ b + @test_blas LowerTriangular(trilA) * b ≈ Array(LowerTriangular(trilA)) * b + @test_blas LowerTriangular(trilA) \ B ≈ Array(LowerTriangular(trilA)) \ B + @test_blas LowerTriangular(trilA) * B ≈ Array(LowerTriangular(trilA)) * B - @test_blas UpperTriangular(triuA) \ b ≈ Array(UpperTriangular(triuA)) \ b - @test_blas UpperTriangular(triuA) * b ≈ Array(UpperTriangular(triuA)) * b - @test_blas UpperTriangular(triuA) \ B ≈ Array(UpperTriangular(triuA)) \ B - @test_blas UpperTriangular(triuA) * B ≈ Array(UpperTriangular(triuA)) * B + @test_blas UpperTriangular(triuA) \ b ≈ Array(UpperTriangular(triuA)) \ b + @test_blas UpperTriangular(triuA) * b ≈ Array(UpperTriangular(triuA)) * b + @test_blas UpperTriangular(triuA) \ B ≈ Array(UpperTriangular(triuA)) \ B + @test_blas UpperTriangular(triuA) * B ≈ Array(UpperTriangular(triuA)) * B - @test_blas UnitLowerTriangular(trilUA) \ b ≈ Array(UnitLowerTriangular(trilUA)) \ b - @test_blas UnitLowerTriangular(trilUA) * b ≈ Array(UnitLowerTriangular(trilUA)) * b - @test_blas UnitLowerTriangular(trilUA) \ B ≈ Array(UnitLowerTriangular(trilUA)) \ B - @test_blas UnitLowerTriangular(trilUA) * B ≈ Array(UnitLowerTriangular(trilUA)) * B + @test_blas UnitLowerTriangular(trilUA) \ b ≈ Array(UnitLowerTriangular(trilUA)) \ b + @test_blas UnitLowerTriangular(trilUA) * b ≈ Array(UnitLowerTriangular(trilUA)) * b + @test_blas UnitLowerTriangular(trilUA) \ B ≈ Array(UnitLowerTriangular(trilUA)) \ B + @test_blas UnitLowerTriangular(trilUA) * B ≈ Array(UnitLowerTriangular(trilUA)) * B - @test_blas UnitUpperTriangular(triuUA) \ b ≈ Array(UnitUpperTriangular(triuUA)) \ b - @test_blas UnitUpperTriangular(triuUA) * b ≈ Array(UnitUpperTriangular(triuUA)) * b - @test_blas UnitUpperTriangular(triuUA) \ B ≈ Array(UnitUpperTriangular(triuUA)) \ B - @test_blas UnitUpperTriangular(triuUA) * B ≈ Array(UnitUpperTriangular(triuUA)) * B + @test_blas UnitUpperTriangular(triuUA) \ b ≈ Array(UnitUpperTriangular(triuUA)) \ b + @test_blas UnitUpperTriangular(triuUA) * b ≈ Array(UnitUpperTriangular(triuUA)) * b + @test_blas UnitUpperTriangular(triuUA) \ B ≈ Array(UnitUpperTriangular(triuUA)) \ B + @test_blas UnitUpperTriangular(triuUA) * B ≈ Array(UnitUpperTriangular(triuUA)) * B + end end end