Now using the same jaxpr in the state.
This is quite an important fix!

The bit that matters here is that `f_eval_info.jac` in
`AbstractGaussNewton.step` now throws away the static (non-array) parts
of its PyTree, and instead uses the equivalent static (non-array) parts
of `state.f_info.jac`, i.e. those computed in
`AbstractGaussNewton.init`.
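
Concretely, the new code in `AbstractGaussNewton.step` (it appears again in the `gauss_newton.py` diff below) swaps the static leaves like so:
```python
# New array leaves from this step's evaluation of the Jacobian...
dynamic = eqx.filter(f_eval_info.jac, eqx.is_array)
# ...combined with the static (non-array) leaves -- including the jaxpr -- from `init`.
static = eqx.filter(state.f_info.jac, eqx.is_array, inverse=True)
jac = eqx.combine(dynamic, static)
f_eval_info = eqx.tree_at(lambda f: f.jac, f_eval_info, jac)
```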

Now at a logical level this shouldn't matter at all: the static pieces
should be the same in both cases, as they're just the output of
`_make_f_info` with similarly-structured inputs.

However, `_make_f_info` calls `lx.FunctionLinearOperator` which calls
`eqx.filter_closure_convert` which calls `jax.make_jaxpr` which returns
a jaxpr... and so between the two calls to `_make_f_info`, we actually
end up with two jaxprs. Both encode the same program, but are two
different Python objects. Now jaxprs have `__eq__` defined according to
identity, so these two (functionally identical) jaxprs do not compare
as equal.
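
For example (a standalone sketch, not Optimistix code; tracing the same trivial function twice):
```python
import jax

def f(x):
    return x + 1

jaxpr1 = jax.make_jaxpr(f)(1.0)
jaxpr2 = jax.make_jaxpr(f)(1.0)

# Functionally identical, but two distinct Python objects, and `__eq__` falls
# back to object identity:
print(jaxpr1 is jaxpr2)  # False
print(jaxpr1 == jaxpr2)  # False
```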

Previously we worked around this inside `_iterate.py`: we carefully
removed or wrapped any jaxprs before anything that would try to compare
them for equality. This was a bit ugly, but it worked.

However, it turns out that this still left a problem when manually
stepping an Optimistix solver (in a way akin to stepping an Optax
solver), with something like:
```python
@eqx.filter_jit
def make_step(...):
    ... = solver.step(...)

for ... in ...:  # Python level for-loop
    ... = make_step(...)
```
Then in fact on every iteration of the Python loop, we would end up
recompiling, because we always get a new jaxpr at:
```
state      # state for the Gauss-Newton solver
  .f_info  # as returned by _make_f_info
  .jac     # the FunctionLinearOperator
  .fn      # the closure-converted function
  .jaxpr   # the jaxpr from the closure conversion
```
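To see the mechanism in isolation, here is a standalone sketch (not Optimistix itself) of the same failure mode: `eqx.filter_jit` treats non-array leaves, such as a jaxpr, as static, and a freshly-traced jaxpr is a brand-new static value every time:
```python
import equinox as eqx
import jax
import jax.numpy as jnp

def f(x):
    return x + 1

@eqx.filter_jit
def make_step(x, static_leaf):
    # `static_leaf` is unused, but as a non-array it is part of the compilation
    # cache key. This Python-level print only runs when the function is (re)traced.
    print("compiling!")
    return x + 1

x = jnp.array(0.0)
fixed_jaxpr = jax.make_jaxpr(f)(x)
for _ in range(3):
    x = make_step(x, fixed_jaxpr)           # prints "compiling!" once
for _ in range(3):
    x = make_step(x, jax.make_jaxpr(f)(x))  # prints "compiling!" on every iteration
```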

Now, one fix is simply to demand that manually stepping a solver
requires the same kind of hackery as we had in `_iterate.py`. But maybe
enough is enough, and we should try doing something better instead:
that is, do what this PR does, and just preserve the same jaxpr all the
way through.

For bonus points, this means that we can now remove our special jaxpr
handling from `_iterate.py` (and from `filter_cond`, which also needed
this for the same reason).

Finally, you might be wondering: why do we need to trace two equivalent
jaxprs at all? This seems inefficient -- can we arrange to trace it
just once? The answer is "probably, but not in this PR". This seems to
require that (a) Lineax offer a way to turn off closure conversion
(done in patrick-kidger/lineax#71); but (b) even when using this, we
still seem to trigger a similar issue in JAX, namely its requirement
that the primal and tangent results from `jax.custom_jvp` match. So for
now this is just something to try and tackle later -- once we do, we'll
get slightly better compile times.
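
For reference, a standalone sketch of the double trace being discussed (a hypothetical function and shape; this assumes, as described above, that constructing an `lx.FunctionLinearOperator` closure-converts its function and hence traces a jaxpr):
```python
import jax
import jax.numpy as jnp
import lineax as lx

def matvec(v):
    # Stand-in for the Jacobian-vector product that `_make_f_info` would build.
    return 2.0 * v

in_structure = jax.eval_shape(lambda: jnp.zeros(3))

op1 = lx.FunctionLinearOperator(matvec, in_structure)
op2 = lx.FunctionLinearOperator(matvec, in_structure)

# Same program, traced twice: the closure-converted functions carry distinct
# jaxpr objects, which therefore do not compare as equal.
print(op1.fn.jaxpr is op2.fn.jaxpr)  # False
```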
patrick-kidger committed Dec 27, 2023
1 parent f3b6965 commit 5528cc3
Showing 4 changed files with 19 additions and 28 deletions.
.pre-commit-config.yaml: 2 changes (1 addition, 1 deletion)
@@ -22,7 +22,7 @@ repos:
- id: ruff-format # formatter
types_or: [ python, pyi, jupyter ]
- repo: https://github.com/RobertCraigie/pyright-python
rev: v1.1.330
rev: v1.1.331
hooks:
- id: pyright
additional_dependencies: ["equinox", "jax", "lineax", "pytest", "optax"]
optimistix/_iterate.py: 17 changes (1 addition, 16 deletions)
@@ -37,20 +37,9 @@
_Node = eqxi.doc_repr(Any, "Node")


def _is_jaxpr(x):
return isinstance(x, (jax.core.Jaxpr, jax.core.ClosedJaxpr))


def _is_array_or_jaxpr(x):
return _is_jaxpr(x) or eqx.is_array(x)


class AbstractIterativeSolver(eqx.Module, Generic[Y, Out, Aux, SolverState]):
"""Abstract base class for all iterative solvers."""

# Essentially every solver has an rtol+atol+norm. So for now we're just hardcoding
# that every solver must have these variables, as they're needed when using a
# minimiser or least-squares solver on a root-finding problem.
rtol: AbstractVar[float]
atol: AbstractVar[float]
norm: AbstractVar[Callable[[PyTree], Scalar]]
@@ -255,11 +244,7 @@ def body_fun(carry):
new_y, new_state, aux = solver.step(fn, y, args, options, state, tags)
new_dynamic_state, new_static_state = eqx.partition(new_state, eqx.is_array)

new_static_state_no_jaxpr = eqx.filter(
new_static_state, _is_jaxpr, inverse=True
)
static_state_no_jaxpr = eqx.filter(state, _is_array_or_jaxpr, inverse=True)
assert eqx.tree_equal(static_state_no_jaxpr, new_static_state_no_jaxpr) is True
assert eqx.tree_equal(static_state, new_static_state) is True
return new_y, num_steps + 1, new_dynamic_state, aux

def buffers(carry):
optimistix/_misc.py: 4 changes (1 addition, 3 deletions)
@@ -231,18 +231,16 @@ def _true_fun(_dynamic):
_operands = eqx.combine(_dynamic, static)
_out = true_fun(*_operands)
_dynamic_out, _static_out = eqx.partition(_out, eqx.is_array)
_static_out = wrap_jaxpr(_static_out)
return _dynamic_out, eqxi.Static(_static_out)

def _false_fun(_dynamic):
_operands = eqx.combine(_dynamic, static)
_out = false_fun(*_operands)
_dynamic_out, _static_out = eqx.partition(_out, eqx.is_array)
_static_out = wrap_jaxpr(_static_out)
return _dynamic_out, eqxi.Static(_static_out)

dynamic_out, static_out = lax.cond(pred, _true_fun, _false_fun, dynamic)
return eqx.combine(dynamic_out, unwrap_jaxpr(static_out.value))
return eqx.combine(dynamic_out, static_out.value)


def verbose_print(*args: tuple[bool, str, Any]) -> None:
optimistix/_solver/gauss_newton.py: 24 changes (16 additions, 8 deletions)
@@ -188,14 +188,14 @@ class AbstractGaussNewton(AbstractLeastSquaresSolver[Y, Out, Aux, _GaussNewtonSt
This includes methods such as [`optimistix.GaussNewton`][],
[`optimistix.LevenbergMarquardt`][], and [`optimistix.Dogleg`][].
Subclasses must provide the following abstract attributes, with the following types:
- `rtol: float`
- `atol: float`
- `norm: Callable[[PyTree], Scalar]`
- `descent: AbstractDescent`
- `search: AbstractSearch`
- `verbose: frozenset[str]
Subclasses must provide the following attributes, with the following types:
- `rtol`: `float`
- `atol`: `float`
- `norm`: `Callable[[PyTree], Scalar]`
- `descent`: `AbstractDescent`
- `search`: `AbstractSearch`
- `verbose`: `frozenset[str]`
"""

rtol: AbstractVar[float]
@@ -243,6 +243,14 @@ def step(
tags: frozenset[object],
) -> tuple[Y, _GaussNewtonState, Aux]:
f_eval_info, aux_eval = _make_f_info(fn, state.y_eval, args, tags)
# We have a jaxpr in `f_info.jac`, and jaxprs are compared by identity. Here we
# arrange to use the same one throughout, so that downstream equality checks
# (e.g. in the `filter_cond` below) succeed.
dynamic = eqx.filter(f_eval_info.jac, eqx.is_array)
static = eqx.filter(state.f_info.jac, eqx.is_array, inverse=True)
jac = eqx.combine(dynamic, static)
f_eval_info = eqx.tree_at(lambda f: f.jac, f_eval_info, jac)

step_size, accept, search_result, search_state = self.search.step(
state.first_step,
y,
