diff --git a/src/Metis.jl b/src/Metis.jl index ad9b9a2..fc0b780 100644 --- a/src/Metis.jl +++ b/src/Metis.jl @@ -3,7 +3,7 @@ module Metis using SparseArrays -using LinearAlgebra: ishermitian +using LinearAlgebra: ishermitian, Hermitian, Symmetric using METIS_jll: libmetis # Metis C API: Clang.jl auto-generated bindings and some manual methods @@ -33,6 +33,7 @@ struct Graph end end + """ Metis.graph(G::SparseMatrixCSC; weights=false, check_hermitian=true) @@ -71,6 +72,102 @@ function graph(G::SparseMatrixCSC; weights::Bool=false, check_hermitian::Bool=tr return Graph(idx_t(N), xadj, adjncy, vwgt, adjwgt) end +const HermOrSymCSC{Tv, Ti} = Union{ + Hermitian{Tv, SparseMatrixCSC{Tv, Ti}}, Symmetric{Tv, SparseMatrixCSC{Tv, Ti}}, +} + +if VERSION < v"1.10" + # From https://github.com/JuliaSparse/SparseArrays.jl/blob/313a04f4a78bbc534f89b6b4d9c598453e2af17c/src/linalg.jl#L1106-L1117 + # MIT license: https://github.com/JuliaSparse/SparseArrays.jl/blob/main/LICENSE.md + function nzrangeup(A, i, excl=false) + r = nzrange(A, i); r1 = r.start; r2 = r.stop + rv = rowvals(A) + @inbounds r2 < r1 || rv[r2] <= i - excl ? r : r1:(searchsortedlast(view(rv, r1:r2), i - excl) + r1-1) + end + function nzrangelo(A, i, excl=false) + r = nzrange(A, i); r1 = r.start; r2 = r.stop + rv = rowvals(A) + @inbounds r2 < r1 || rv[r1] >= i + excl ? r : (searchsortedfirst(view(rv, r1:r2), i + excl) + r1-1):r2 + end +else + using SparseArrays: nzrangeup, nzrangelo +end + +""" + Metis.graph(G::Union{Hermitian, Symmetric}; weights::Bool=false) + +Construct the 1-based CSR representation of the `Hermitian` or `Symmetric` wrapped sparse +matrix `G`. +Weights are not currently supported for this method so passing `weights=true` will throw an +error. +""" +function graph(H::HermOrSymCSC; weights::Bool=false) + # This method is derived from the method `SparseMatrixCSC(::HermOrSymCSC)` from + # SparseArrays.jl + # (https://github.com/JuliaSparse/SparseArrays.jl/blob/313a04f4a78bbc534f89b6b4d9c598453e2af17c/src/sparseconvert.jl#L124-L173) + # with MIT license + # (https://github.com/JuliaSparse/SparseArrays.jl/blob/main/LICENSE.md). + weights && throw(ArgumentError("weights not supported yet")) + # Extract data + A = H.data + upper = H.uplo == 'U' + rowval = rowvals(A) + m, n = size(A) + @assert m == n + # New colptr for the full matrix + newcolptr = Vector{idx_t}(undef, n + 1) + newcolptr[1] = 1 + # SparseArrays.nzrange for the upper/lower part excluding the diagonal + nzrng = if upper + (A, col) -> nzrangeup(A, col, #=exclude diagonal=# true) + else + (A, col) -> nzrangelo(A, col, #=exclude diagonal=# true) + end + # If the upper part is stored we loop forward, otherwise backwards + colrange = upper ? (1:1:n) : (n:-1:1) + @inbounds for col in colrange + r = nzrng(A, col) + # Number of entries in the stored part of this column, excluding the diagonal entry + newcolptr[col + 1] = length(r) + # Increment columnptr corresponding to the stored rows + for k in r + row = rowval[k] + @assert upper ? row < col : row > col + @assert row != col # Diagonal entries should not be here + newcolptr[row + 1] += 1 + end + end + # Accumulate the colptr and allocate new rowval + cumsum!(newcolptr, newcolptr) + nz = newcolptr[n + 1] - 1 + newrowval = Vector{idx_t}(undef, nz) + # Populate the rowvals + @inbounds for col = 1:n + newk = newcolptr[col] + for k in nzrng(A, col) + row = rowval[k] + @assert col != row + newrowval[newk] = row + newk += 1 + ni = newcolptr[row] + newrowval[ni] = col + newcolptr[row] = ni + 1 + end + newcolptr[col] = newk + end + # Shuffle back the colptrs + @inbounds for j = n:-1:1 + newcolptr[j+1] = newcolptr[j] + end + newcolptr[1] = 1 + # Return Graph + N = n + xadj = newcolptr + adjncy = newrowval + vwgt = C_NULL + adjwgt = C_NULL + return Graph(idx_t(N), xadj, adjncy, vwgt, adjwgt) +end """ perm, iperm = Metis.permutation(G) diff --git a/test/runtests.jl b/test/runtests.jl index 85e4338..7ba3e11 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using Metis using Random using Test using SparseArrays +using LinearAlgebra: Symmetric, Hermitian import LightGraphs, Graphs @testset "Metis.graph(::SparseMatrixCSC)" begin @@ -20,6 +21,24 @@ import LightGraphs, Graphs @test iszero(S - GW) end +@testset "Metis.graph(::Union{Hermitian, Symmetric})" begin + rng = MersenneTwister(0) + for T in (Symmetric, Hermitian), uplo in (:U, :L) + S = sprand(rng, Int, 10, 10, 0.2); fill!(S.nzval, 1) + TS = T(S, uplo) + CSCS = SparseMatrixCSC(TS) + @test TS == CSCS + g1 = Metis.graph(TS) + g2 = Metis.graph(CSCS) + @test g1.nvtxs == g2.nvtxs + @test g1.xadj == g2.xadj + @test g1.adjncy == g2.adjncy + @test g1.vwgt == g2.vwgt == C_NULL + @test g1.adjwgt == g2.adjwgt == C_NULL + @test_throws ArgumentError Metis.graph(TS; weights = true) + end +end + @testset "Metis.permutation" begin rng = MersenneTwister(0) S = sprand(rng, 10, 10, 0.5); S = S + S'; fill!(S.nzval, 1)