From 3b6412b95514a5834249b7a2f78a7056e2e03e08 Mon Sep 17 00:00:00 2001 From: "Keaton J. Burns" Date: Sat, 23 Dec 2023 09:19:20 -0500 Subject: [PATCH] Fix dtype casting for apply_csr --- dedalus/core/basis.py | 8 ++++---- dedalus/core/subsystems.py | 9 +++++---- dedalus/tools/array.py | 7 ++++++- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/dedalus/core/basis.py b/dedalus/core/basis.py index bd06be95..0791f470 100644 --- a/dedalus/core/basis.py +++ b/dedalus/core/basis.py @@ -5228,9 +5228,9 @@ def radial_matrix(self, spinindex_in, spinindex_out, m): raise ValueError("This should never happen.") n_size = self.input_basis.n_size(m) if m == 0: - return sparse.identity(n_size) + return sparse.identity(n_size).tocsr() else: - return sparse.csr_matrix((0, n_size), dtype=self.dtype) + return sparse.csr_matrix((0, n_size), dtype=self.dtype).tocsr() class SphereAzimuthalAverage(AzimuthalAverage, operators.Average, operators.SpectralOperator): @@ -5343,9 +5343,9 @@ def radial_matrix(self, regindex_in, regindex_out, ell): raise ValueError("This should never happen.") n_size = self.input_basis.n_size(ell) if ell == 0: - return sparse.identity(n_size) + return sparse.identity(n_size).tocsr() else: - return sparse.csr_matrix((0, n_size), dtype=self.dtype) + return sparse.csr_matrix((0, n_size), dtype=self.dtype).tocsr() class IntegrateSpinBasis(operators.PolarMOperator): diff --git a/dedalus/core/subsystems.py b/dedalus/core/subsystems.py index dc48be63..2818f376 100644 --- a/dedalus/core/subsystems.py +++ b/dedalus/core/subsystems.py @@ -549,10 +549,11 @@ def build_matrices(self, names): right_perm = right_permutation(self, vars, tau_left=solver.tau_left, interleave_components=solver.interleave_components).tocsr() # Preconditioners - self.pre_left = drop_empty_rows(left_perm @ valid_eqn).tocsr() - self.pre_left_pinv = self.pre_left.T.tocsr() - self.pre_right_pinv = drop_empty_rows(right_perm @ valid_var).tocsr() - self.pre_right = self.pre_right_pinv.T.tocsr() + # TODO: remove astype casting, requires dealing with used types in apply_sparse + self.pre_left = drop_empty_rows(left_perm @ valid_eqn).tocsr().astype(dtype) + self.pre_left_pinv = self.pre_left.T.tocsr().astype(dtype) + self.pre_right_pinv = drop_empty_rows(right_perm @ valid_var).tocsr().astype(dtype) + self.pre_right = self.pre_right_pinv.T.tocsr().astype(dtype) # Check preconditioner pseudoinverses assert_sparse_pinv(self.pre_left, self.pre_left_pinv) diff --git a/dedalus/tools/array.py b/dedalus/tools/array.py index 51d28b45..e03a0fbd 100644 --- a/dedalus/tools/array.py +++ b/dedalus/tools/array.py @@ -193,8 +193,13 @@ def apply_sparse(matrix, array, axis, out=None, check_shapes=False, num_threads= if OLD_CSR_MATVECS and array.ndim == 2 and axis == 0: out.fill(0) return csr_matvecs(matrix, array, out) + # Promote datatypes + # TODO: find way to optimize this with fused types + matrix_data = matrix.data + if matrix_data.dtype != out.dtype: + matrix_data = matrix_data.astype(out.dtype) # Call cython routine - cython_linalg.apply_csr(matrix.indptr, matrix.indices, matrix.data, array, out, axis, num_threads) + cython_linalg.apply_csr(matrix.indptr, matrix.indices, matrix_data, array, out, axis, num_threads) return out