diff --git a/lineax/_operator.py b/lineax/_operator.py index ad1e4f5..d8ab745 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -733,7 +733,14 @@ def mv(self, vector): return jtu.tree_unflatten(treedef, shaped) def as_matrix(self): - return jnp.eye(self.out_size(), self.in_size()) + leaves = jtu.tree_leaves(self.in_structure()) + with jax.numpy_dtype_promotion("standard"): + dtype = ( + default_floating_dtype() + if len(leaves) == 0 + else jnp.result_type(*leaves) + ) + return jnp.eye(self.out_size(), self.in_size(), dtype=dtype) def transpose(self): return IdentityLinearOperator(self.out_structure(), self.in_structure()) diff --git a/tests/helpers.py b/tests/helpers.py index 32f81d0..ddbdaef 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -198,6 +198,12 @@ def make_diagonal_operator(getkey, matrix, tags): return lx.DiagonalLinearOperator(diag) +@_operators_append +def make_identity_operator(getkey, matrix, tags): + in_struct = jax.ShapeDtypeStruct((matrix.shape[-1],), matrix.dtype) + return lx.IdentityLinearOperator(input_structure=in_struct) + + @_operators_append def make_tridiagonal_operator(getkey, matrix, tags): diag1 = jnp.diag(matrix) diff --git a/tests/test_operator.py b/tests/test_operator.py index e0e1e7a..06e4ddf 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -23,22 +23,37 @@ from .helpers import ( make_diagonal_operator, + make_identity_operator, make_operators, make_tridiagonal_operator, tree_allclose, ) -def test_ops(getkey): - matrix1 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3))) - matrix2 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3))) - scalar = jr.normal(getkey(), ()) +@pytest.mark.parametrize("make_operator", make_operators) +@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) +def test_ops(make_operator, getkey, dtype): + if ( + make_operator is make_diagonal_operator + or make_operator is make_identity_operator + ): + matrix = jnp.eye(3, dtype=dtype) + tags = lx.diagonal_tag + elif make_operator is make_tridiagonal_operator: + matrix = jnp.eye(3, dtype=dtype) + tags = lx.tridiagonal_tag + else: + matrix = jr.normal(getkey(), (3, 3), dtype=dtype) + tags = () + matrix1 = make_operator(getkey, matrix, tags) + matrix2 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3), dtype=dtype)) + scalar = jr.normal(getkey(), (), dtype=dtype) add = matrix1 + matrix2 composed = matrix1 @ matrix2 mul = matrix1 * scalar rmul = cast(lx.AbstractLinearOperator, scalar * matrix1) div = matrix1 / scalar - vec = jr.normal(getkey(), (3,)) + vec = jr.normal(getkey(), (3,), dtype=dtype) assert tree_allclose(matrix1.mv(vec) + matrix2.mv(vec), add.mv(vec)) assert tree_allclose(matrix1.mv(matrix2.mv(vec)), composed.mv(vec)) @@ -66,7 +81,10 @@ def test_ops(getkey): @pytest.mark.parametrize("make_operator", make_operators) def test_structures_vector(make_operator, getkey): - if make_operator is make_diagonal_operator: + if ( + make_operator is make_diagonal_operator + or make_operator is make_identity_operator + ): matrix = jnp.eye(4) tags = lx.diagonal_tag in_size = out_size = 4 @@ -96,6 +114,12 @@ def _setup(getkey, matrix, tag: Union[object, frozenset[object]] = frozenset()): lx.symmetric_tag, ): continue + if make_operator is make_identity_operator and tag not in ( + lx.tridiagonal_tag, + lx.diagonal_tag, + lx.symmetric_tag, + ): + continue operator = make_operator(getkey, matrix, tag) yield operator @@ -327,7 +351,7 @@ def test_identity_with_different_structures(): assert op1.T == op2 # assert op2.transpose((True, False)) == op3 - assert jnp.array_equal(op1.as_matrix(), jnp.eye(5, 7)) + assert jnp.array_equal(op1.as_matrix(), jnp.eye(5, 7, dtype=jnp.float32)) assert op1.in_size() == 7 assert op1.out_size() == 5 vec1 = ( @@ -356,7 +380,7 @@ def test_identity_with_different_structures_complex(): assert op1.T == op2 # assert op2.transpose((True, False)) == op3 - assert jnp.array_equal(op1.as_matrix(), jnp.eye(5, 7)) + assert jnp.array_equal(op1.as_matrix(), jnp.eye(5, 7, dtype=jnp.complex128)) assert op1.in_size() == 7 assert op1.out_size() == 5 vec1 = (