Skip to content

Commit

Permalink
Fix dtype casting for apply_csr
Browse files Browse the repository at this point in the history
  • Loading branch information
kburns committed Dec 23, 2023
1 parent 1fc2048 commit 3b6412b
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 9 deletions.
8 changes: 4 additions & 4 deletions dedalus/core/basis.py
Original file line number Diff line number Diff line change
Expand Up @@ -5228,9 +5228,9 @@ def radial_matrix(self, spinindex_in, spinindex_out, m):
raise ValueError("This should never happen.")
n_size = self.input_basis.n_size(m)
if m == 0:
return sparse.identity(n_size)
return sparse.identity(n_size).tocsr()
else:
return sparse.csr_matrix((0, n_size), dtype=self.dtype)
return sparse.csr_matrix((0, n_size), dtype=self.dtype).tocsr()


class SphereAzimuthalAverage(AzimuthalAverage, operators.Average, operators.SpectralOperator):
Expand Down Expand Up @@ -5343,9 +5343,9 @@ def radial_matrix(self, regindex_in, regindex_out, ell):
raise ValueError("This should never happen.")
n_size = self.input_basis.n_size(ell)
if ell == 0:
return sparse.identity(n_size)
return sparse.identity(n_size).tocsr()
else:
return sparse.csr_matrix((0, n_size), dtype=self.dtype)
return sparse.csr_matrix((0, n_size), dtype=self.dtype).tocsr()


class IntegrateSpinBasis(operators.PolarMOperator):
Expand Down
9 changes: 5 additions & 4 deletions dedalus/core/subsystems.py
Original file line number Diff line number Diff line change
Expand Up @@ -549,10 +549,11 @@ def build_matrices(self, names):
right_perm = right_permutation(self, vars, tau_left=solver.tau_left, interleave_components=solver.interleave_components).tocsr()

# Preconditioners
self.pre_left = drop_empty_rows(left_perm @ valid_eqn).tocsr()
self.pre_left_pinv = self.pre_left.T.tocsr()
self.pre_right_pinv = drop_empty_rows(right_perm @ valid_var).tocsr()
self.pre_right = self.pre_right_pinv.T.tocsr()
# TODO: remove astype casting, requires dealing with used types in apply_sparse
self.pre_left = drop_empty_rows(left_perm @ valid_eqn).tocsr().astype(dtype)
self.pre_left_pinv = self.pre_left.T.tocsr().astype(dtype)
self.pre_right_pinv = drop_empty_rows(right_perm @ valid_var).tocsr().astype(dtype)
self.pre_right = self.pre_right_pinv.T.tocsr().astype(dtype)

# Check preconditioner pseudoinverses
assert_sparse_pinv(self.pre_left, self.pre_left_pinv)
Expand Down
7 changes: 6 additions & 1 deletion dedalus/tools/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,8 +193,13 @@ def apply_sparse(matrix, array, axis, out=None, check_shapes=False, num_threads=
if OLD_CSR_MATVECS and array.ndim == 2 and axis == 0:
out.fill(0)
return csr_matvecs(matrix, array, out)
# Promote datatypes
# TODO: find way to optimize this with fused types
matrix_data = matrix.data
if matrix_data.dtype != out.dtype:
matrix_data = matrix_data.astype(out.dtype)
# Call cython routine
cython_linalg.apply_csr(matrix.indptr, matrix.indices, matrix.data, array, out, axis, num_threads)
cython_linalg.apply_csr(matrix.indptr, matrix.indices, matrix_data, array, out, axis, num_threads)
return out


Expand Down

0 comments on commit 3b6412b

Please sign in to comment.