Skip to content

Commit

Permalink
Set the right shape (n, m) for the regrid matrices. Fixes #233
Browse files Browse the repository at this point in the history
  • Loading branch information
Huite committed Apr 23, 2024
1 parent c81c6fc commit 31aa2e4
Show file tree
Hide file tree
Showing 3 changed files with 40 additions and 11 deletions.
12 changes: 12 additions & 0 deletions tests/test_sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,3 +61,15 @@ def test_csr_to_coo(csr_matrix):
assert np.array_equal(coo_matrix.row, [0, 0, 1, 1, 2, 2, 3, 3, 4, 4])
assert np.array_equal(coo_matrix.col, np.arange(10))
assert coo_matrix.nnz == 10


def test_shape():
source_index = np.arange(10)
target_index = np.repeat(np.arange(5), 2)
weights = np.full(10, 0.5)
matrix = sparse.MatrixCSR.from_triplet(target_index, source_index, weights, n=20)
assert matrix.n == 20
assert matrix.m == 10
matrix = sparse.MatrixCSR.from_triplet(target_index, source_index, weights, m=20)
assert matrix.n == 5
assert matrix.m == 20
27 changes: 19 additions & 8 deletions xugrid/core/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,26 +45,31 @@ class MatrixCOO(NamedTuple):
nnz: int

@staticmethod
def from_triplet(row, col, data) -> "MatrixCOO":
return MatrixCOO(data, row, col, row.max(), col.max(), data.size)
def from_triplet(row, col, data, n=None, m=None) -> "MatrixCOO":
if n is None:
n = row.max() + 1
if m is None:
m = col.max() + 1
nnz = data.size
return MatrixCOO(data, row, col, n, m, nnz)

def to_csr(self) -> "MatrixCSR":
"""
Convert COO matrix to CSR matrix.
Assumes the COO matrix indices are already sorted by row number!
"""
i = np.cumsum(np.bincount(self.row))
i = np.cumsum(np.bincount(self.row, minlength=self.n))
indptr = np.empty(i.size + 1, dtype=IntDType)
indptr[0] = 0
indptr[1:] = i
return MatrixCSR(
self.data,
self.col,
indptr,
indptr.size - 1,
self.col.max(),
self.data.size,
self.n,
self.m,
self.nnz,
)


Expand Down Expand Up @@ -107,8 +112,14 @@ def from_csr_matrix(A: sparse.csr_matrix) -> "MatrixCSR":
return MatrixCSR(A.data, A.indices, A.indptr, n, m, A.nnz)

@staticmethod
def from_triplet(row, col, data) -> "MatrixCSR":
return MatrixCOO.from_triplet(row, col, data).to_csr()
def from_triplet(
row,
col,
data,
n=None,
m=None,
) -> "MatrixCSR":
return MatrixCOO.from_triplet(row, col, data, n, m).to_csr()

def to_coo(self) -> MatrixCOO:
"""
Expand Down
12 changes: 9 additions & 3 deletions xugrid/regrid/regridder.py
Original file line number Diff line number Diff line change
Expand Up @@ -326,7 +326,11 @@ def _compute_weights(self, source, target):
source, target = convert_to_match(source, target)
source_index, target_index, weight_values = source.locate_centroids(target)
self._weights = MatrixCOO.from_triplet(
target_index, source_index, weight_values
target_index,
source_index,
weight_values,
n=target.size,
m=source.size,
)
return

Expand Down Expand Up @@ -364,7 +368,7 @@ def _compute_weights(self, source, target, relative: bool) -> None:
target, relative=relative
)
self._weights = MatrixCSR.from_triplet(
target_index, source_index, weight_values
target_index, source_index, weight_values, n=target.size, m=source.size
)
return

Expand Down Expand Up @@ -526,7 +530,9 @@ def _compute_weights(self, source, target):
source_index, target_index, weights = source.linear_weights(target)
else:
source_index, target_index, weights = source.barycentric(target)
self._weights = MatrixCSR.from_triplet(target_index, source_index, weights)
self._weights = MatrixCSR.from_triplet(
target_index, source_index, weights, n=target.size, m=source.size
)
return

@property
Expand Down

0 comments on commit 31aa2e4

Please sign in to comment.