Skip to content

Commit

Permalink
Fixed and added tests for JacobianLinearOperator(jac="bwd")
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Mar 18, 2024
1 parent 9b923c8 commit 928bcf3
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
9 changes: 8 additions & 1 deletion lineax/_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,7 +568,14 @@ def __init__(

def mv(self, vector):
fn = _NoAuxOut(_NoAuxIn(self.fn, self.args))
_, out = jax.jvp(fn, (self.x,), (vector,))
if self.jac == "fwd":
_, out = jax.jvp(fn, (self.x,), (vector,))
else:
assert self.jac == "bwd"
jac = jax.jacrev(fn)(self.x)
out = PyTreeLinearOperator(jac, output_structure=self.out_structure()).mv(
vector
)
return out

def as_matrix(self):
Expand Down
31 changes: 31 additions & 0 deletions tests/test_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,3 +377,34 @@ def test_zero_pytree_as_matrix():
struct = jax.ShapeDtypeStruct((2, 1, 0), a.dtype)
op = lx.PyTreeLinearOperator(a, struct)
assert op.as_matrix().shape == (0, 0)


def test_jacrev_operator():
@jax.custom_vjp
def f(x, _):
return dict(foo=x["bar"] + 2)

def f_fwd(x, _):
return f(x, None), None

def f_bwd(_, g):
return dict(bar=g["foo"] + 5), None

f.defvjp(f_fwd, f_bwd)

x = dict(bar=jnp.arange(2.0))
rev_op = lx.JacobianLinearOperator(f, x, jac="bwd")
as_matrix = jnp.array([[6.0, 5.0], [5.0, 6.0]])
assert tree_allclose(rev_op.as_matrix(), as_matrix)

y = dict(bar=jnp.arange(2.0) + 1)
true_out = dict(foo=jnp.array([16.0, 17.0]))
for op in (rev_op, lx.materialise(rev_op)):
out = op.mv(y)
assert tree_allclose(out, true_out)

fwd_op = lx.JacobianLinearOperator(f, x, jac="fwd")
with pytest.raises(TypeError, match="can't apply forward-mode autodiff"):
fwd_op.mv(y)
with pytest.raises(TypeError, match="can't apply forward-mode autodiff"):
lx.materialise(fwd_op)

0 comments on commit 928bcf3

Please sign in to comment.