diff --git a/lineax/_operator.py b/lineax/_operator.py index 491051b..c81dab4 100644 --- a/lineax/_operator.py +++ b/lineax/_operator.py @@ -568,7 +568,14 @@ def __init__( def mv(self, vector): fn = _NoAuxOut(_NoAuxIn(self.fn, self.args)) - _, out = jax.jvp(fn, (self.x,), (vector,)) + if self.jac == "fwd": + _, out = jax.jvp(fn, (self.x,), (vector,)) + else: + assert self.jac == "bwd" + jac = jax.jacrev(fn)(self.x) + out = PyTreeLinearOperator(jac, output_structure=self.out_structure()).mv( + vector + ) return out def as_matrix(self): diff --git a/tests/test_operator.py b/tests/test_operator.py index 36be4ae..c7e4bc7 100644 --- a/tests/test_operator.py +++ b/tests/test_operator.py @@ -377,3 +377,34 @@ def test_zero_pytree_as_matrix(): struct = jax.ShapeDtypeStruct((2, 1, 0), a.dtype) op = lx.PyTreeLinearOperator(a, struct) assert op.as_matrix().shape == (0, 0) + + +def test_jacrev_operator(): + @jax.custom_vjp + def f(x, _): + return dict(foo=x["bar"] + 2) + + def f_fwd(x, _): + return f(x, None), None + + def f_bwd(_, g): + return dict(bar=g["foo"] + 5), None + + f.defvjp(f_fwd, f_bwd) + + x = dict(bar=jnp.arange(2.0)) + rev_op = lx.JacobianLinearOperator(f, x, jac="bwd") + as_matrix = jnp.array([[6.0, 5.0], [5.0, 6.0]]) + assert tree_allclose(rev_op.as_matrix(), as_matrix) + + y = dict(bar=jnp.arange(2.0) + 1) + true_out = dict(foo=jnp.array([16.0, 17.0])) + for op in (rev_op, lx.materialise(rev_op)): + out = op.mv(y) + assert tree_allclose(out, true_out) + + fwd_op = lx.JacobianLinearOperator(f, x, jac="fwd") + with pytest.raises(TypeError, match="can't apply forward-mode autodiff"): + fwd_op.mv(y) + with pytest.raises(TypeError, match="can't apply forward-mode autodiff"): + lx.materialise(fwd_op)