Skip to content

Commit

Permalink
grad-of-vmap-of-linear_solve with symbolic zero cotangents no longer …
Browse files Browse the repository at this point in the history
…crashes
  • Loading branch information
patrick-kidger committed Apr 14, 2024
1 parent 9b923c8 commit 0adb209
Show file tree
Hide file tree
Showing 2 changed files with 19 additions and 0 deletions.
1 change: 1 addition & 0 deletions lineax/_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -267,6 +267,7 @@ def _linear_solve_transpose(inputs, cts_out):
jtu.tree_map(
_assert_defined, (operator, state, options, solver), is_leaf=_is_undefined
)
cts_solution = jtu.tree_map(ft.partial(eqxi.materialise_zeros, allow_struct=True), operator.in_structure(), cts_solution)
operator_transpose = operator.transpose()
state_transpose, options_transpose = solver.transpose(state, options)
cts_vector, _, _ = eqxi.filter_primitive_bind(
Expand Down
18 changes: 18 additions & 0 deletions tests/test_solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,3 +138,21 @@ def fn(y):
grad, sol = jax.grad(f, has_aux=True)(x, z)
assert tree_allclose(grad, -z / (x**2))
assert tree_allclose(sol, z / x)


def test_grad_vmap_symbolic_cotangent():
def f(x):
return x[0], x[1]

@jax.vmap
def to_vmap(x):
op = lx.FunctionLinearOperator(f, jax.eval_shape(lambda: x))
sol = lx.linear_solve(op, x)
return sol.value[0]

@jax.grad
def to_grad(x):
return jnp.sum(to_vmap(x))

x = (jnp.arange(3.0), jnp.arange(3.0))
to_grad(x)

0 comments on commit 0adb209

Please sign in to comment.