Skip to content

Commit

Permalink
Add conj function and adjoint tests (#62)
Browse files Browse the repository at this point in the history
* Add conj function and tests for adjoint

* Minor fixes

* Add test for pytree

* Remove self-adjoint tag
  • Loading branch information
Randl authored Nov 12, 2023
1 parent a3684c5 commit 49e83ef
Show file tree
Hide file tree
Showing 5 changed files with 164 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/examples/operators.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,7 @@
"\n",
"- We've already seen some general examples above, like [`lineax.MatrixLinearOperator`][].\n",
"- We've already seen some structured examples above, like [`lineax.TridiagonalLinearOperator`][].\n",
"- Given a function $f \\colon \\mathbb{R}^n \\to \\mathbb{R}^m$ and a point $x \\in \\mathbb{R}^n$, then [`lineax.JacobianLinearOperator`][] represents the Jacobian $\\frac{\\mathrm{d}f}{\\mathrm{d}x}(x) \\in \\mathbb{R}^{n \\t\\imes m}$.\n",
"- Given a function $f \\colon \\mathbb{R}^n \\to \\mathbb{R}^m$ and a point $x \\in \\mathbb{R}^n$, then [`lineax.JacobianLinearOperator`][] represents the Jacobian $\\frac{\\mathrm{d}f}{\\mathrm{d}x}(x) \\in \\mathbb{R}^{n \\times m}$.\n",
"- Given a linear function $g \\colon \\mathbb{R}^n \\to \\mathbb{R}^m$, then [`lineax.FunctionLinearOperator`][] represents the matrix corresponding to this linear function, i.e. the unique matrix $A$ for which $g(x) = Ax$.\n",
"- etc!\n",
"\n",
Expand Down
1 change: 1 addition & 0 deletions lineax/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
AddLinearOperator as AddLinearOperator,
AuxLinearOperator as AuxLinearOperator,
ComposedLinearOperator as ComposedLinearOperator,
conj as conj,
diagonal as diagonal,
DiagonalLinearOperator as DiagonalLinearOperator,
DivLinearOperator as DivLinearOperator,
Expand Down
103 changes: 102 additions & 1 deletion lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,7 @@ class JacobianLinearOperator(AbstractLinearOperator):
The Jacobian is not materialised; matrix-vector products, which are in fact
Jacobian-vector products, are computed using autodifferentiation, specifically
`jax.jvp`. Thus `JacobianLinearOperator(fn, x).mv(v)` is equivalent to
`jax.jvp`. Thus, `JacobianLinearOperator(fn, x).mv(v)` is equivalent to
`jax.jvp(fn, (x,), (v,))`.
See also [`lineax.linearise`][], which caches the primal computation, i.e.
Expand Down Expand Up @@ -1917,3 +1917,104 @@ def _(operator):
d = has_unit_diagonal(operator.operator1)
e = has_unit_diagonal(operator.operator2)
return (a or b or c) and d and e


# conj


@ft.singledispatch
def conj(operator: AbstractLinearOperator) -> AbstractLinearOperator:
"""Elementwise conjugate of a linear operator. This returns another linear operator.
**Arguments:**
- `operator`: a linear operator.
**Returns:**
Another linear operator.
"""
_default_not_implemented("conj", operator)


@conj.register(MatrixLinearOperator)
def _(operator):
return MatrixLinearOperator(operator.matrix.conj(), operator.tags)


@conj.register(PyTreeLinearOperator)
def _(operator):
pytree_conj = jtu.tree_map(lambda x: x.conj(), operator.pytree)
return PyTreeLinearOperator(pytree_conj, operator.out_structure(), operator.tags)


@conj.register(JacobianLinearOperator)
def _(operator):
return conj(linearise(operator))


@conj.register(FunctionLinearOperator)
def _(operator):
return FunctionLinearOperator(
lambda vec: jtu.tree_map(jnp.conj, operator.mv(jtu.tree_map(jnp.conj, vec))),
operator.in_structure(),
operator.tags,
)


@conj.register(IdentityLinearOperator)
def _(operator):
return operator


@conj.register(DiagonalLinearOperator)
def _(operator):
return DiagonalLinearOperator(operator.diagonal.conj())


@conj.register(TridiagonalLinearOperator)
def _(operator):
return TridiagonalLinearOperator(
operator.diagonal.conj(),
operator.lower_diagonal.conj(),
operator.upper_diagonal.conj(),
)


@conj.register(TaggedLinearOperator)
def _(operator):
return TaggedLinearOperator(conj(operator.operator), operator.tags)


@conj.register(TangentLinearOperator)
def _(operator):
# Should be unreachable: TangentLinearOperator is used for a narrow set of
# operations only (mv; transpose) inside the JVP rule linear_solve_p.
raise NotImplementedError(
"Please open a GitHub issue: https://github.com/google/lineax"
)


@conj.register(AddLinearOperator)
def _(operator):
return conj(operator.operator1) + conj(operator.operator2)


@conj.register(MulLinearOperator)
def _(operator):
return conj(operator.operator) * operator.scalar.conj()


@conj.register(DivLinearOperator)
def _(operator):
return conj(operator.operator) / operator.scalar.conj()


@conj.register(ComposedLinearOperator)
def _(operator):
return conj(operator.operator1) @ conj(operator.operator2)


@conj.register(AuxLinearOperator)
def _(operator):
return conj(operator.operator)
5 changes: 3 additions & 2 deletions lineax/_solver/cg.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
from .._misc import max_norm, resolve_rcond, tree_dot, tree_where
from .._operator import (
AbstractLinearOperator,
conj,
is_negative_semidefinite,
is_positive_semidefinite,
linearise,
Expand Down Expand Up @@ -107,14 +108,14 @@ def compute(
# If a downstream user wants to avoid this then they can call
# ```
# linear_solve(
# operator.T @ operator, operator.mv(b), solver=CG()
# conj(operator.T) @ operator, operator.mv(b), solver=CG()
# )
# ```
# directly.
operator = linearise(operator)

_mv = operator.mv
_transpose_mv = operator.transpose().mv
_transpose_mv = conj(operator.transpose()).mv

def mv(vector: PyTree) -> PyTree:
return _transpose_mv(_mv(vector))
Expand Down
57 changes: 57 additions & 0 deletions tests/test_adjoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
import jax
import jax.numpy as jnp
import jax.random as jr
import pytest

import lineax as lx
from lineax import FunctionLinearOperator

from .helpers import (
make_diagonal_operator,
make_operators,
make_tridiagonal_operator,
shaped_allclose,
)


@pytest.mark.parametrize("make_operator", make_operators)
@pytest.mark.parametrize("dtype", (jnp.float32, jnp.complex64))
def test_adjoint(make_operator, dtype, getkey):
if make_operator is make_diagonal_operator:
matrix = jnp.eye(4, dtype=dtype)
tags = lx.diagonal_tag
in_size = out_size = 4
elif make_operator is make_tridiagonal_operator:
matrix = jnp.eye(4, dtype=dtype)
tags = lx.tridiagonal_tag
in_size = out_size = 4
else:
matrix = jr.normal(getkey(), (3, 5), dtype=dtype)
tags = ()
in_size = 5
out_size = 3
operator = make_operator(matrix, tags)
v1, v2 = jr.normal(getkey(), (in_size,), dtype=dtype), jr.normal(
getkey(), (out_size,), dtype=dtype
)

inner1 = operator.mv(v1) @ v2.conj()
adjoint_op1 = lx.conj(operator).transpose()
ov2 = adjoint_op1.mv(v2)
inner2 = v1 @ ov2.conj()
assert shaped_allclose(inner1, inner2)

adjoint_op2 = lx.conj(operator.transpose())
ov2 = adjoint_op2.mv(v2)
inner2 = v1 @ ov2.conj()
assert shaped_allclose(inner1, inner2)


def test_functional_pytree_adjoint():
def fn(y):
return {"b": y["a"]}

y_struct = jax.eval_shape(lambda: {"a": 0.0})
operator = FunctionLinearOperator(fn, y_struct)
conj_operator = lx.conj(operator)
assert shaped_allclose(lx.materialise(conj_operator), lx.materialise(operator))

0 comments on commit 49e83ef

Please sign in to comment.