diff --git a/tests/test_adjoint.py b/tests/test_adjoint.py index 71c0604..7215418 100644 --- a/tests/test_adjoint.py +++ b/tests/test_adjoint.py @@ -7,6 +7,7 @@ from .helpers import ( make_diagonal_operator, + make_identity_operator, make_operators, make_tridiagonal_operator, tree_allclose, @@ -16,7 +17,10 @@ @pytest.mark.parametrize("make_operator", make_operators) @pytest.mark.parametrize("dtype", (jnp.float64, jnp.complex128)) def test_adjoint(make_operator, dtype, getkey): - if make_operator is make_diagonal_operator: + if ( + make_operator is make_diagonal_operator + or make_operator is make_identity_operator + ): matrix = jnp.eye(4, dtype=dtype) tags = lx.diagonal_tag in_size = out_size = 4