Skip to content

Commit

Permalink
Fixed implicit_jvp assuming that it was only being used with iterate
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Dec 27, 2023
1 parent 672f358 commit f3b6965
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 17 deletions.
8 changes: 2 additions & 6 deletions optimistix/_ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,7 @@ def _is_none(x):


def _for_jac(root, args):
fn_rewrite, residual, _inputs = args
iterate, inputs, while_loop = _inputs
del iterate, while_loop
fn_rewrite, residual, inputs = args
return fn_rewrite(root, residual, inputs)


Expand All @@ -114,9 +112,7 @@ def _implicit_impl_jvp(primals, tangents):

def _for_jvp(_diff):
_inputs = eqx.combine(_diff, nondiff)
iterate, inputs, while_loop = _inputs
del iterate, while_loop
return fn_rewrite(root, residual, inputs)
return fn_rewrite(root, residual, _inputs)

operator = lx.JacobianLinearOperator(
_for_jac, root, (fn_rewrite, residual, inputs), tags=tags
Expand Down
11 changes: 3 additions & 8 deletions optimistix/_adjoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,12 +125,7 @@ def apply(self, primal_fn, rewrite_fn, inputs, tags):
while_loop = ft.partial(
eqxi.while_loop, kind="checkpointed", checkpoints=self.checkpoints
)
return primal_fn(inputs, while_loop)


def _primal_fn(inputs):
primal_fn, inputs, while_loop = inputs
return primal_fn(inputs, while_loop)
return primal_fn(inputs + (while_loop,))


class ImplicitAdjoint(AbstractAdjoint):
Expand All @@ -149,8 +144,8 @@ class ImplicitAdjoint(AbstractAdjoint):
linear_solver: lx.AbstractLinearSolver = lx.AutoLinearSolver(well_posed=None)

def apply(self, primal_fn, rewrite_fn, inputs, tags):
_inputs = (primal_fn, inputs, ft.partial(eqxi.while_loop, kind="lax"))
return implicit_jvp(_primal_fn, rewrite_fn, _inputs, tags, self.linear_solver)
inputs = inputs + (ft.partial(eqxi.while_loop, kind="lax"),)
return implicit_jvp(primal_fn, rewrite_fn, inputs, tags, self.linear_solver)


RecursiveCheckpointAdjoint.__init__.__doc__ = """**Arguments:**
Expand Down
15 changes: 13 additions & 2 deletions optimistix/_iterate.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,8 +216,19 @@ def _zero(x):
return x


def _iterate(inputs, while_loop):
fn, solver, y0, args, options, max_steps, f_struct, aux_struct, tags = inputs
def _iterate(inputs):
(
fn,
solver,
y0,
args,
options,
max_steps,
f_struct,
aux_struct,
tags,
while_loop,
) = inputs
del inputs
static_leaf = lambda x: isinstance(x, eqxi.Static)
f_struct = jtu.tree_map(lambda x: x.value, f_struct, is_leaf=static_leaf)
Expand Down
2 changes: 1 addition & 1 deletion tests/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,4 +785,4 @@ class PiggybackAdjoint(optx.AbstractAdjoint):
def apply(self, primal_fn, rewrite_fn, inputs, tags):
del rewrite_fn, tags
while_loop = ft.partial(eqxi.while_loop, kind="lax")
return primal_fn(inputs, while_loop)
return primal_fn(inputs + (while_loop,))

0 comments on commit f3b6965

Please sign in to comment.