Skip to content

Commit

Permalink
Make identity operator return matrix with matching dtype
Browse files Browse the repository at this point in the history
  • Loading branch information
Randl committed Aug 8, 2024
1 parent 2272e63 commit ead1e90
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 9 deletions.
9 changes: 8 additions & 1 deletion lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -733,7 +733,14 @@ def mv(self, vector):
return jtu.tree_unflatten(treedef, shaped)

def as_matrix(self):
return jnp.eye(self.out_size(), self.in_size())
leaves = jtu.tree_leaves(self.in_structure())
with jax.numpy_dtype_promotion("standard"):
dtype = (
default_floating_dtype()
if len(leaves) == 0
else jnp.result_type(*leaves)
)
return jnp.eye(self.out_size(), self.in_size(), dtype=dtype)

def transpose(self):
return IdentityLinearOperator(self.out_structure(), self.in_structure())
Expand Down
6 changes: 6 additions & 0 deletions tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,12 @@ def make_diagonal_operator(getkey, matrix, tags):
return lx.DiagonalLinearOperator(diag)


@_operators_append
def make_identity_operator(getkey, matrix, tags):
in_struct = jax.ShapeDtypeStruct((matrix.shape[-1],), matrix.dtype)
return lx.IdentityLinearOperator(input_structure=in_struct)


@_operators_append
def make_tridiagonal_operator(getkey, matrix, tags):
diag1 = jnp.diag(matrix)
Expand Down
40 changes: 32 additions & 8 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,22 +23,37 @@

from .helpers import (
make_diagonal_operator,
make_identity_operator,
make_operators,
make_tridiagonal_operator,
tree_allclose,
)


def test_ops(getkey):
matrix1 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3)))
matrix2 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3)))
scalar = jr.normal(getkey(), ())
@pytest.mark.parametrize("make_operator", make_operators)
@pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128))
def test_ops(make_operator, getkey, dtype):
if (
make_operator is make_diagonal_operator
or make_operator is make_identity_operator
):
matrix = jnp.eye(3, dtype=dtype)
tags = lx.diagonal_tag
elif make_operator is make_tridiagonal_operator:
matrix = jnp.eye(3, dtype=dtype)
tags = lx.tridiagonal_tag
else:
matrix = jr.normal(getkey(), (3, 3), dtype=dtype)
tags = ()
matrix1 = make_operator(getkey, matrix, tags)
matrix2 = lx.MatrixLinearOperator(jr.normal(getkey(), (3, 3), dtype=dtype))
scalar = jr.normal(getkey(), (), dtype=dtype)
add = matrix1 + matrix2
composed = matrix1 @ matrix2
mul = matrix1 * scalar
rmul = cast(lx.AbstractLinearOperator, scalar * matrix1)
div = matrix1 / scalar
vec = jr.normal(getkey(), (3,))
vec = jr.normal(getkey(), (3,), dtype=dtype)

assert tree_allclose(matrix1.mv(vec) + matrix2.mv(vec), add.mv(vec))
assert tree_allclose(matrix1.mv(matrix2.mv(vec)), composed.mv(vec))
Expand Down Expand Up @@ -66,7 +81,10 @@ def test_ops(getkey):

@pytest.mark.parametrize("make_operator", make_operators)
def test_structures_vector(make_operator, getkey):
if make_operator is make_diagonal_operator:
if (
make_operator is make_diagonal_operator
or make_operator is make_identity_operator
):
matrix = jnp.eye(4)
tags = lx.diagonal_tag
in_size = out_size = 4
Expand Down Expand Up @@ -96,6 +114,12 @@ def _setup(getkey, matrix, tag: Union[object, frozenset[object]] = frozenset()):
lx.symmetric_tag,
):
continue
if make_operator is make_identity_operator and tag not in (
lx.tridiagonal_tag,
lx.diagonal_tag,
lx.symmetric_tag,
):
continue
operator = make_operator(getkey, matrix, tag)
yield operator

Expand Down Expand Up @@ -327,7 +351,7 @@ def test_identity_with_different_structures():

assert op1.T == op2
# assert op2.transpose((True, False)) == op3
assert jnp.array_equal(op1.as_matrix(), jnp.eye(5, 7))
assert jnp.array_equal(op1.as_matrix(), jnp.eye(5, 7, dtype=jnp.float32))
assert op1.in_size() == 7
assert op1.out_size() == 5
vec1 = (
Expand Down Expand Up @@ -356,7 +380,7 @@ def test_identity_with_different_structures_complex():

assert op1.T == op2
# assert op2.transpose((True, False)) == op3
assert jnp.array_equal(op1.as_matrix(), jnp.eye(5, 7))
assert jnp.array_equal(op1.as_matrix(), jnp.eye(5, 7, dtype=jnp.complex128))
assert op1.in_size() == 7
assert op1.out_size() == 5
vec1 = (
Expand Down

0 comments on commit ead1e90

Please sign in to comment.