Skip to content

Commit

Permalink
Fixed and added tests for JacobianLinearOperator(jac="bwd")
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Mar 20, 2024
1 parent 9b923c8 commit dc71972
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 2 deletions.
12 changes: 10 additions & 2 deletions lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -518,7 +518,7 @@ class JacobianLinearOperator(AbstractLinearOperator, strict=True):
x: PyTree[Inexact[Array, "..."]]
args: PyTree[Any]
tags: frozenset[object] = eqx.field(static=True)
jac: Optional[Literal["fwd", "bwd"]] = None
jac: Optional[Literal["fwd", "bwd"]]

@eqxi.doc_remove_args("closure_convert", "_has_aux")
def __init__(
Expand Down Expand Up @@ -568,7 +568,15 @@ def __init__(

def mv(self, vector):
fn = _NoAuxOut(_NoAuxIn(self.fn, self.args))
_, out = jax.jvp(fn, (self.x,), (vector,))
if self.jac == "fwd" or self.jac is None:
_, out = jax.jvp(fn, (self.x,), (vector,))
elif self.jac == "bwd":
jac = jax.jacrev(fn)(self.x)
out = PyTreeLinearOperator(jac, output_structure=self.out_structure()).mv(
vector
)
else:
raise ValueError("`jac` should be either `'fwd'`, `'bwd'`, or `None`.")
return out

def as_matrix(self):
Expand Down
31 changes: 31 additions & 0 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,34 @@ def test_zero_pytree_as_matrix():
struct = jax.ShapeDtypeStruct((2, 1, 0), a.dtype)
op = lx.PyTreeLinearOperator(a, struct)
assert op.as_matrix().shape == (0, 0)


def test_jacrev_operator():
@jax.custom_vjp
def f(x, _):
return dict(foo=x["bar"] + 2)

def f_fwd(x, _):
return f(x, None), None

def f_bwd(_, g):
return dict(bar=g["foo"] + 5), None

f.defvjp(f_fwd, f_bwd)

x = dict(bar=jnp.arange(2.0))
rev_op = lx.JacobianLinearOperator(f, x, jac="bwd")
as_matrix = jnp.array([[6.0, 5.0], [5.0, 6.0]])
assert tree_allclose(rev_op.as_matrix(), as_matrix)

y = dict(bar=jnp.arange(2.0) + 1)
true_out = dict(foo=jnp.array([16.0, 17.0]))
for op in (rev_op, lx.materialise(rev_op)):
out = op.mv(y)
assert tree_allclose(out, true_out)

fwd_op = lx.JacobianLinearOperator(f, x, jac="fwd")
with pytest.raises(TypeError, match="can't apply forward-mode autodiff"):
fwd_op.mv(y)
with pytest.raises(TypeError, match="can't apply forward-mode autodiff"):
lx.materialise(fwd_op)

0 comments on commit dc71972

Please sign in to comment.