diff --git a/docs/examples/operators.ipynb b/docs/examples/operators.ipynb index 97ff64c..85ae616 100644 --- a/docs/examples/operators.ipynb +++ b/docs/examples/operators.ipynb @@ -167,7 +167,7 @@ "\n", "- We've already seen some general examples above, like [`lineax.MatrixLinearOperator`][].\n", "- We've already seen some structured examples above, like [`lineax.TridiagonalLinearOperator`][].\n", - "- Given a function $f \\colon \\mathbb{R}^n \\to \\mathbb{R}^m$ and a point $x \\in \\mathbb{R}^n$, then [`lineax.JacobianLinearOperator`][] represents the Jacobian $\\frac{\\mathrm{d}f}{\\mathrm{d}x}(x) \\in \\mathbb{R}^{n \\t\\imes m}$.\n", + "- Given a function $f \\colon \\mathbb{R}^n \\to \\mathbb{R}^m$ and a point $x \\in \\mathbb{R}^n$, then [`lineax.JacobianLinearOperator`][] represents the Jacobian $\\frac{\\mathrm{d}f}{\\mathrm{d}x}(x) \\in \\mathbb{R}^{m \\times n}$.\n", "- Given a linear function $g \\colon \\mathbb{R}^n \\to \\mathbb{R}^m$, then [`lineax.FunctionLinearOperator`][] represents the matrix corresponding to this linear function, i.e. the unique matrix $A$ for which $g(x) = Ax$.\n", "- etc!\n", "\n", diff --git a/lineax/__init__.py b/lineax/__init__.py index 6c9743c..0e63ffc 100644 --- a/lineax/__init__.py +++ b/lineax/__init__.py @@ -18,6 +18,7 @@ AddLinearOperator as AddLinearOperator, AuxLinearOperator as AuxLinearOperator, ComposedLinearOperator as ComposedLinearOperator, + conj as conj, diagonal as diagonal, DiagonalLinearOperator as DiagonalLinearOperator, DivLinearOperator as DivLinearOperator, diff --git a/lineax/_operator.py b/lineax/_operator.py index 410aa80..96548b9 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -511,7 +511,7 @@ class JacobianLinearOperator(AbstractLinearOperator): The Jacobian is not materialised; matrix-vector products, which are in fact Jacobian-vector products, are computed using autodifferentiation, specifically - `jax.jvp`. Thus `JacobianLinearOperator(fn, x).mv(v)` is equivalent to + `jax.jvp`. 
Thus, `JacobianLinearOperator(fn, x).mv(v)` is equivalent to `jax.jvp(fn, (x,), (v,))`. See also [`lineax.linearise`][], which caches the primal computation, i.e. @@ -1917,3 +1917,104 @@ def _(operator): d = has_unit_diagonal(operator.operator1) e = has_unit_diagonal(operator.operator2) return (a or b or c) and d and e + + +# conj + + +@ft.singledispatch +def conj(operator: AbstractLinearOperator) -> AbstractLinearOperator: + """Elementwise conjugate of a linear operator. This returns another linear operator. + + **Arguments:** + + - `operator`: a linear operator. + + **Returns:** + + Another linear operator. + """ + _default_not_implemented("conj", operator) + + +@conj.register(MatrixLinearOperator) +def _(operator): + return MatrixLinearOperator(operator.matrix.conj(), operator.tags) + + +@conj.register(PyTreeLinearOperator) +def _(operator): + pytree_conj = jtu.tree_map(lambda x: x.conj(), operator.pytree) + return PyTreeLinearOperator(pytree_conj, operator.out_structure(), operator.tags) + + +@conj.register(JacobianLinearOperator) +def _(operator): + return conj(linearise(operator)) + + +@conj.register(FunctionLinearOperator) +def _(operator): + return FunctionLinearOperator( + lambda vec: jtu.tree_map(jnp.conj, operator.mv(jtu.tree_map(jnp.conj, vec))), + operator.in_structure(), + operator.tags, + ) + + +@conj.register(IdentityLinearOperator) +def _(operator): + return operator + + +@conj.register(DiagonalLinearOperator) +def _(operator): + return DiagonalLinearOperator(operator.diagonal.conj()) + + +@conj.register(TridiagonalLinearOperator) +def _(operator): + return TridiagonalLinearOperator( + operator.diagonal.conj(), + operator.lower_diagonal.conj(), + operator.upper_diagonal.conj(), + ) + + +@conj.register(TaggedLinearOperator) +def _(operator): + return TaggedLinearOperator(conj(operator.operator), operator.tags) + + +@conj.register(TangentLinearOperator) +def _(operator): + # Should be unreachable: TangentLinearOperator is used for a narrow set of + # 
operations only (mv; transpose) inside the JVP rule linear_solve_p. + raise NotImplementedError( + "Please open a GitHub issue: https://github.com/google/lineax" + ) + + +@conj.register(AddLinearOperator) +def _(operator): + return conj(operator.operator1) + conj(operator.operator2) + + +@conj.register(MulLinearOperator) +def _(operator): + return conj(operator.operator) * operator.scalar.conj() + + +@conj.register(DivLinearOperator) +def _(operator): + return conj(operator.operator) / operator.scalar.conj() + + +@conj.register(ComposedLinearOperator) +def _(operator): + return conj(operator.operator1) @ conj(operator.operator2) + + +@conj.register(AuxLinearOperator) +def _(operator): + return conj(operator.operator) diff --git a/lineax/_solver/cg.py b/lineax/_solver/cg.py index b536f2f..c5a8e63 100644 --- a/lineax/_solver/cg.py +++ b/lineax/_solver/cg.py @@ -32,6 +32,7 @@ from .._misc import max_norm, resolve_rcond, tree_dot, tree_where from .._operator import ( AbstractLinearOperator, + conj, is_negative_semidefinite, is_positive_semidefinite, linearise, @@ -107,14 +108,14 @@ def compute( # If a downstream user wants to avoid this then they can call # ``` # linear_solve( - # operator.T @ operator, operator.mv(b), solver=CG() + # conj(operator.T) @ operator, conj(operator.T).mv(b), solver=CG() # ) # ``` # directly. 
operator = linearise(operator) _mv = operator.mv - _transpose_mv = operator.transpose().mv + _transpose_mv = conj(operator.transpose()).mv def mv(vector: PyTree) -> PyTree: return _transpose_mv(_mv(vector)) diff --git a/tests/test_adjoint.py b/tests/test_adjoint.py new file mode 100644 index 0000000..b3722ff --- /dev/null +++ b/tests/test_adjoint.py @@ -0,0 +1,57 @@ +import jax +import jax.numpy as jnp +import jax.random as jr +import pytest + +import lineax as lx +from lineax import FunctionLinearOperator + +from .helpers import ( + make_diagonal_operator, + make_operators, + make_tridiagonal_operator, + shaped_allclose, +) + + +@pytest.mark.parametrize("make_operator", make_operators) +@pytest.mark.parametrize("dtype", (jnp.float32, jnp.complex64)) +def test_adjoint(make_operator, dtype, getkey): + if make_operator is make_diagonal_operator: + matrix = jnp.eye(4, dtype=dtype) + tags = lx.diagonal_tag + in_size = out_size = 4 + elif make_operator is make_tridiagonal_operator: + matrix = jnp.eye(4, dtype=dtype) + tags = lx.tridiagonal_tag + in_size = out_size = 4 + else: + matrix = jr.normal(getkey(), (3, 5), dtype=dtype) + tags = () + in_size = 5 + out_size = 3 + operator = make_operator(matrix, tags) + v1, v2 = jr.normal(getkey(), (in_size,), dtype=dtype), jr.normal( + getkey(), (out_size,), dtype=dtype + ) + + inner1 = operator.mv(v1) @ v2.conj() + adjoint_op1 = lx.conj(operator).transpose() + ov2 = adjoint_op1.mv(v2) + inner2 = v1 @ ov2.conj() + assert shaped_allclose(inner1, inner2) + + adjoint_op2 = lx.conj(operator.transpose()) + ov2 = adjoint_op2.mv(v2) + inner2 = v1 @ ov2.conj() + assert shaped_allclose(inner1, inner2) + + +def test_functional_pytree_adjoint(): + def fn(y): + return {"b": y["a"]} + + y_struct = jax.eval_shape(lambda: {"a": 0.0}) + operator = FunctionLinearOperator(fn, y_struct) + conj_operator = lx.conj(operator) + assert shaped_allclose(lx.materialise(conj_operator), lx.materialise(operator))