From dc71972d8820f3d183f45b124aec8ee6b5514eac Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 18 Mar 2024 20:57:30 +0100 Subject: [PATCH] Fixed and added tests for JacobianLinearOperator(jac="bwd") --- lineax/_operator.py | 12 ++++++++++-- tests/test_operator.py | 31 +++++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 2 deletions(-) diff --git a/lineax/_operator.py b/lineax/_operator.py index 491051b..4640f0f 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -518,7 +518,7 @@ class JacobianLinearOperator(AbstractLinearOperator, strict=True): x: PyTree[Inexact[Array, "..."]] args: PyTree[Any] tags: frozenset[object] = eqx.field(static=True) - jac: Optional[Literal["fwd", "bwd"]] = None + jac: Optional[Literal["fwd", "bwd"]] @eqxi.doc_remove_args("closure_convert", "_has_aux") def __init__( @@ -568,7 +568,15 @@ def __init__( def mv(self, vector): fn = _NoAuxOut(_NoAuxIn(self.fn, self.args)) - _, out = jax.jvp(fn, (self.x,), (vector,)) + if self.jac == "fwd" or self.jac is None: + _, out = jax.jvp(fn, (self.x,), (vector,)) + elif self.jac == "bwd": + jac = jax.jacrev(fn)(self.x) + out = PyTreeLinearOperator(jac, output_structure=self.out_structure()).mv( + vector + ) + else: + raise ValueError("`jac` should be either `'fwd'`, `'bwd'`, or `None`.") return out def as_matrix(self): diff --git a/tests/test_operator.py b/tests/test_operator.py index 36be4ae..c7e4bc7 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -377,3 +377,34 @@ def test_zero_pytree_as_matrix(): struct = jax.ShapeDtypeStruct((2, 1, 0), a.dtype) op = lx.PyTreeLinearOperator(a, struct) assert op.as_matrix().shape == (0, 0) + + +def test_jacrev_operator(): + @jax.custom_vjp + def f(x, _): + return dict(foo=x["bar"] + 2) + + def f_fwd(x, _): + return f(x, None), None + + def f_bwd(_, g): + return dict(bar=g["foo"] + 5), None + + f.defvjp(f_fwd, f_bwd) + + x = dict(bar=jnp.arange(2.0)) + rev_op = lx.JacobianLinearOperator(f, x, jac="bwd") + as_matrix = jnp.array([[6.0, 5.0], [5.0, 6.0]]) + assert tree_allclose(rev_op.as_matrix(), as_matrix) + + y = dict(bar=jnp.arange(2.0) + 1) + true_out = dict(foo=jnp.array([16.0, 17.0])) + for op in (rev_op, lx.materialise(rev_op)): + out = op.mv(y) + assert tree_allclose(out, true_out) + + fwd_op = lx.JacobianLinearOperator(f, x, jac="fwd") + with pytest.raises(TypeError, match="can't apply forward-mode autodiff"): + fwd_op.mv(y) + with pytest.raises(TypeError, match="can't apply forward-mode autodiff"): + lx.materialise(fwd_op)