diff --git a/optimistix/_ad.py b/optimistix/_ad.py index 2ea2a79..e0523c4 100644 --- a/optimistix/_ad.py +++ b/optimistix/_ad.py @@ -88,9 +88,7 @@ def _is_none(x): def _for_jac(root, args): - fn_rewrite, residual, _inputs = args - iterate, inputs, while_loop = _inputs - del iterate, while_loop + fn_rewrite, residual, inputs = args return fn_rewrite(root, residual, inputs) @@ -114,9 +112,7 @@ def _implicit_impl_jvp(primals, tangents): def _for_jvp(_diff): _inputs = eqx.combine(_diff, nondiff) - iterate, inputs, while_loop = _inputs - del iterate, while_loop - return fn_rewrite(root, residual, inputs) + return fn_rewrite(root, residual, _inputs) operator = lx.JacobianLinearOperator( _for_jac, root, (fn_rewrite, residual, inputs), tags=tags diff --git a/optimistix/_adjoint.py b/optimistix/_adjoint.py index 736b111..4a679cb 100644 --- a/optimistix/_adjoint.py +++ b/optimistix/_adjoint.py @@ -125,12 +125,7 @@ def apply(self, primal_fn, rewrite_fn, inputs, tags): while_loop = ft.partial( eqxi.while_loop, kind="checkpointed", checkpoints=self.checkpoints ) - return primal_fn(inputs, while_loop) - - -def _primal_fn(inputs): - primal_fn, inputs, while_loop = inputs - return primal_fn(inputs, while_loop) + return primal_fn(inputs + (while_loop,)) class ImplicitAdjoint(AbstractAdjoint): @@ -149,8 +144,8 @@ class ImplicitAdjoint(AbstractAdjoint): linear_solver: lx.AbstractLinearSolver = lx.AutoLinearSolver(well_posed=None) def apply(self, primal_fn, rewrite_fn, inputs, tags): - _inputs = (primal_fn, inputs, ft.partial(eqxi.while_loop, kind="lax")) - return implicit_jvp(_primal_fn, rewrite_fn, _inputs, tags, self.linear_solver) + inputs = inputs + (ft.partial(eqxi.while_loop, kind="lax"),) + return implicit_jvp(primal_fn, rewrite_fn, inputs, tags, self.linear_solver) RecursiveCheckpointAdjoint.__init__.__doc__ = """**Arguments:** diff --git a/optimistix/_iterate.py b/optimistix/_iterate.py index bf49210..834bff3 100644 --- a/optimistix/_iterate.py +++ b/optimistix/_iterate.py @@ -216,8 +216,19 @@ def _zero(x): return x -def _iterate(inputs, while_loop): - fn, solver, y0, args, options, max_steps, f_struct, aux_struct, tags = inputs +def _iterate(inputs): + ( + fn, + solver, + y0, + args, + options, + max_steps, + f_struct, + aux_struct, + tags, + while_loop, + ) = inputs del inputs static_leaf = lambda x: isinstance(x, eqxi.Static) f_struct = jtu.tree_map(lambda x: x.value, f_struct, is_leaf=static_leaf) diff --git a/tests/helpers.py b/tests/helpers.py index 7502a72..ed56845 100644 --- a/tests/helpers.py +++ b/tests/helpers.py @@ -785,4 +785,4 @@ class PiggybackAdjoint(optx.AbstractAdjoint): def apply(self, primal_fn, rewrite_fn, inputs, tags): del rewrite_fn, tags while_loop = ft.partial(eqxi.while_loop, kind="lax") - return primal_fn(inputs, while_loop) + return primal_fn(inputs + (while_loop,))