Skip to content

Commit

Permalink
rename typealias MKLSparseMat to SparseMat
Browse files Browse the repository at this point in the history
to distringuish from MKLSparseMatrix
  • Loading branch information
Alexey Stukalov authored and alyst committed Sep 17, 2024
1 parent c8509e5 commit 43d27e6
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions src/interface.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import Base: \, *
import LinearAlgebra: mul!, ldiv!

MKLSparseMat{T} = Union{SparseArrays.AbstractSparseMatrixCSC{T}, SparseMatrixCSR{T}, SparseMatrixCOO{T}}
SparseMat{T} = Union{SparseArrays.AbstractSparseMatrixCSC{T}, SparseMatrixCSR{T}, SparseMatrixCOO{T}}

AdjOrTranspMat{T, M} = Union{Adjoint{T, <:M}, Transpose{T,<:M}}

Expand Down Expand Up @@ -41,7 +41,7 @@ describe_and_unwrap(A::Hermitian{<:Any, T}) where T <: Union{Adjoint, Transpose}
# mul!(vec, sparse, vec, a, b)
function mul!(y::StridedVector{T}, A::SimpleOrSpecialOrAdjMat{T, S},
x::StridedVector{T}, alpha::Number, beta::Number
) where {T <: BlasFloat, S <: MKLSparseMat{T}}
) where {T <: BlasFloat, S <: SparseMat{T}}
transA, descrA, unwrapA = describe_and_unwrap(A)
# fix the strange behaviour of multipling adjoint vectors by triangular matrices
# looks like wrong the triangle is being used
Expand All @@ -54,7 +54,7 @@ end
# mul!(dense, sparse, dense, a, b)
function mul!(C::StridedMatrix{T}, A::SimpleOrSpecialOrAdjMat{T, S},
B::StridedMatrix{T}, alpha::Number, beta::Number
) where {T <: BlasFloat, S <: MKLSparseMat{T}}
) where {T <: BlasFloat, S <: SparseMat{T}}
transA, descrA, unwrapA = describe_and_unwrap(A)
mm!(transA, T(alpha), unwrapA, descrA, B, T(beta), C)
end
Expand All @@ -78,56 +78,56 @@ end

# 3-arg mul!() calls 5-arg mul!()
mul!(y::StridedVector{T}, A::SimpleOrSpecialOrAdjMat{T, S},
x::StridedVector{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}} =
x::StridedVector{T}) where {T <: BlasFloat, S <: SparseMat{T}} =
mul!(y, A, x, one(T), zero(T))
mul!(C::StridedMatrix{T}, A::SimpleOrSpecialOrAdjMat{T, S},
B::StridedMatrix{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}} =
B::StridedMatrix{T}) where {T <: BlasFloat, S <: SparseMat{T}} =
mul!(C, A, B, one(T), zero(T))
mul!(C::StridedMatrix{T}, A::StridedMatrix{T},
B::SimpleOrSpecialOrAdjMat{T, S}) where {T <: BlasFloat, S <: MKLSparseMat{T}} =
B::SimpleOrSpecialOrAdjMat{T, S}) where {T <: BlasFloat, S <: SparseMat{T}} =
mul!(C, A, B, one(T), zero(T))

# define 4-arg ldiv!(C, A, B, a) (C := alpha*inv(A)*B) that is not present in standard LinearAlgrebra
# redefine 3-arg ldiv!(C, A, B) using 4-arg ldiv!(C, A, B, 1)
function ldiv!(y::StridedVector{T}, A::SimpleOrSpecialOrAdjMat{T, S},
x::StridedVector{T}, alpha::Number = one(T)) where {T <: BlasFloat, S <: MKLSparseMat{T}}
x::StridedVector{T}, alpha::Number = one(T)) where {T <: BlasFloat, S <: SparseMat{T}}
transA, descrA, unwrapA = describe_and_unwrap(A)
trsv!(transA, alpha, unwrapA, descrA, x, y)
end

function LinearAlgebra.ldiv!(C::StridedMatrix{T}, A::SimpleOrSpecialOrAdjMat{T, S},
B::StridedMatrix{T}, alpha::Number = one(T)) where {T <: BlasFloat, S <: MKLSparseMat{T}}
B::StridedMatrix{T}, alpha::Number = one(T)) where {T <: BlasFloat, S <: SparseMat{T}}
transA, descrA, unwrapA = describe_and_unwrap(A)
trsm!(transA, alpha, unwrapA, descrA, B, C)
end

if VERSION < v"1.10"
# stdlib v1.9 does not provide these methods

(*)(A::SimpleOrSpecialOrAdjMat{T, S}, x::StridedVector{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}} =
(*)(A::SimpleOrSpecialOrAdjMat{T, S}, x::StridedVector{T}) where {T <: BlasFloat, S <: SparseMat{T}} =
mul!(Vector{T}(undef, size(A, 1)), A, x)

(*)(A::SimpleOrSpecialOrAdjMat{T, S}, B::StridedMatrix{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}} =
(*)(A::SimpleOrSpecialOrAdjMat{T, S}, B::StridedMatrix{T}) where {T <: BlasFloat, S <: SparseMat{T}} =
mul!(Matrix{T}(undef, size(A, 1), size(B, 2)), A, B)

# xᵀ * B = (Bᵀ * x)ᵀ
(*)(x::Transpose{T, <:StridedVector{T}}, B::SimpleOrSpecialMat{T, S}) where {T <: BlasFloat, S <: MKLSparseMat{T}} =
(*)(x::Transpose{T, <:StridedVector{T}}, B::SimpleOrSpecialMat{T, S}) where {T <: BlasFloat, S <: SparseMat{T}} =
transpose(mul!(similar(x, size(B, 2)), transpose(B), parent(x)))

# xᴴ * B = (Bᴴ * x)ᴴ
(*)(x::Adjoint{T, <:StridedVector{T}}, B::SimpleOrSpecialMat{T, S}) where {T <: BlasFloat, S <: MKLSparseMat{T}} =
(*)(x::Adjoint{T, <:StridedVector{T}}, B::SimpleOrSpecialMat{T, S}) where {T <: BlasFloat, S <: SparseMat{T}} =
adjoint(mul!(similar(x, size(B, 2)), adjoint(B), parent(x)))

end # if VERSION < v"1.10"

(*)(A::StridedMatrix{T}, B::SimpleOrSpecialOrAdjMat{T, S}) where {T <: BlasFloat, S <: MKLSparseMat{T}} =
(*)(A::StridedMatrix{T}, B::SimpleOrSpecialOrAdjMat{T, S}) where {T <: BlasFloat, S <: SparseMat{T}} =
mul!(Matrix{T}(undef, size(A, 1), size(B, 2)), A, B)

# stdlib does not provide these methods for complex types

# xᴴ * Bᵀ = (Bᵀᴴ * x)ᴴ
function (*)(x::Adjoint{T, <:StridedVector{T}}, B::Transpose{T, <:SimpleOrSpecialMat{T, S}}
) where {T <: Union{ComplexF32, ComplexF64}, S <: MKLSparseMat{T}}
) where {T <: Union{ComplexF32, ComplexF64}, S <: SparseMat{T}}
transB, descrB, unwrapB = describe_and_unwrap(parent(B))
y = similar(x, size(B, 2))
adjoint(mv!('C', one(T), lazypermutedims(unwrapB), lazypermutedims(descrB), parent(x),
Expand All @@ -136,20 +136,20 @@ end

# xᵀ * Bᴴ = (Bᵀᴴ * x)ᵀ
function (*)(x::Transpose{T, <:StridedVector{T}}, B::Adjoint{T, <:SimpleOrSpecialMat{T, S}}
) where {T <: Union{ComplexF32, ComplexF64}, S <: MKLSparseMat{T}}
) where {T <: Union{ComplexF32, ComplexF64}, S <: SparseMat{T}}
transB, descrB, unwrapB = describe_and_unwrap(parent(B))
y = similar(x, size(B, 2))
transpose(mv!('C', one(T), lazypermutedims(unwrapB), lazypermutedims(descrB), parent(x),
zero(T), y))
end

function (\)(A::SimpleOrSpecialOrAdjMat{T, S}, x::StridedVector{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}}
function (\)(A::SimpleOrSpecialOrAdjMat{T, S}, x::StridedVector{T}) where {T <: BlasFloat, S <: SparseMat{T}}
n = length(x)
y = Vector{T}(undef, n)
return ldiv!(y, A, x)
end

function (\)(A::SimpleOrSpecialOrAdjMat{T, S}, B::StridedMatrix{T}) where {T <: BlasFloat, S <: MKLSparseMat{T}}
function (\)(A::SimpleOrSpecialOrAdjMat{T, S}, B::StridedMatrix{T}) where {T <: BlasFloat, S <: SparseMat{T}}
m, n = size(B)
C = Matrix{T}(undef, m, n)
return ldiv!(C, A, B)
Expand Down

0 comments on commit 43d27e6

Please sign in to comment.