Merge branch 'patrick-kidger:main' into main
thibmonsel authored Oct 23, 2023
2 parents 539757d + 712c208 commit fe1ca9a
Showing 16 changed files with 125 additions and 48 deletions.

README.md: 11 additions & 3 deletions

@@ -61,16 +61,24 @@ If you found this library useful in academic research, please cite: [(arXiv link
 
 ## See also: other libraries in the JAX ecosystem
 
+[jaxtyping](https://github.com/google/jaxtyping): type annotations for shape/dtype of arrays.
+
 [Equinox](https://github.com/patrick-kidger/equinox): neural networks.
 
 [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers.
 
-[Lineax](https://github.com/google/lineax): linear solvers and linear least squares.
+[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares.
 
-[jaxtyping](https://github.com/google/jaxtyping): type annotations for shape/dtype of arrays.
+[Lineax](https://github.com/google/lineax): linear solvers.
 
-[Eqxvision](https://github.com/paganpasta/eqxvision): computer vision models.
+[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling.
+
+[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device).
 
 [sympy2jax](https://github.com/google/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent.
 
+[Eqxvision](https://github.com/paganpasta/eqxvision): computer vision models.
+
+[Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs).
 
 [PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!)

diffrax/adjoint.py: 7 additions & 1 deletion

@@ -530,6 +530,7 @@ def _loop_backsolve_bwd(
     throw,
     init_state,
 ):
+    assert discrete_terminating_event is None
 
     #
     # Unpack our various arguments. Delete a lot of things just to make sure we're not
@@ -565,7 +566,6 @@ def _loop_backsolve_bwd(
         adjoint=self,
         solver=solver,
         stepsize_controller=stepsize_controller,
-        discrete_terminating_event=discrete_terminating_event,
         terms=adjoint_terms,
         dt0=None if dt0 is None else -dt0,
         max_steps=max_steps,
@@ -744,6 +744,7 @@ def loop(
         init_state,
         passed_solver_state,
         passed_controller_state,
+        discrete_terminating_event,
         **kwargs,
     ):
         if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure(
@@ -785,6 +786,10 @@
                 "`diffrax.BacksolveAdjoint` is only compatible with solvers that take "
                 "a single term."
             )
+        if discrete_terminating_event is not None:
+            raise NotImplementedError(
+                "`diffrax.BacksolveAdjoint` is not compatible with events."
+            )
 
         y = init_state.y
         init_state = eqx.tree_at(lambda s: s.y, init_state, object())
@@ -798,6 +803,7 @@
             saveat=saveat,
             init_state=init_state,
             solver=solver,
+            discrete_terminating_event=discrete_terminating_event,
             **kwargs,
         )
         final_state = _only_transpose_ys(final_state)

diffrax/brownian/path.py: 1 addition & 1 deletion

@@ -32,7 +32,7 @@ class UnsafeBrownianPath(AbstractBrownianPath):
     correlation structure isn't needed.)
     """
 
-    shape: PyTree[jax.ShapeDtypeStruct] = eqx.static_field()
+    shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
     # Handled as a string because PRNGKey is actually a function, not a class, which
     # makes it appear badly in autogenerated documentation.
     key: "jax.random.PRNGKey"  # noqa: F821
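
Several files in this commit swap `eqx.static_field()` for `eqx.field(static=True)`, tracking the field API used by newer Equinox versions (hence the `equinox>=0.11.1` bump in setup.py below). A minimal sketch of what a static field does; the toy module here is illustrative, not diffrax code:

```python
import equinox as eqx
import jax
import jax.numpy as jnp


class Scale(eqx.Module):
    # A normal field is a pytree leaf: traced and differentiated by JAX.
    weight: jax.Array
    # A static field is baked into the pytree *structure* instead, hidden from
    # JAX transforms. This is the Equinox >=0.11 spelling of the older
    # `eqx.static_field()`.
    name: str = eqx.field(static=True)

    def __call__(self, x):
        return self.weight * x


model = Scale(jnp.array(2.0), "double")
grads = jax.grad(lambda m, x: m(x))(model, 3.0)  # differentiates `weight` only
print(grads.weight)  # 3.0
```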

diffrax/brownian/tree.py: 1 addition & 1 deletion

@@ -60,7 +60,7 @@ class VirtualBrownianTree(AbstractBrownianPath):
     t0: Scalar = field(init=True)
     t1: Scalar = field(init=True)  # override init=False in AbstractPath
     tol: Scalar
-    shape: PyTree[jax.ShapeDtypeStruct] = eqx.static_field()
+    shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True)
     key: "jax.random.PRNGKey"  # noqa: F821
 
     def __init__(

diffrax/global_interpolation.py: 7 additions & 1 deletion

@@ -67,6 +67,8 @@ class LinearInterpolation(AbstractGlobalInterpolation):
     ys: PyTree[Array["times", ...]]  # noqa: F821
 
     def __post_init__(self):
+        super().__post_init__()
+
         def _check(_ys):
             if _ys.shape[0] != self.ts.shape[0]:
                 raise ValueError(
@@ -179,6 +181,8 @@ class CubicInterpolation(AbstractGlobalInterpolation):
     ]
 
     def __post_init__(self):
+        super().__post_init__()
+
         def _check(d, c, b, a):
             error_msg = (
                 "Each cubic coefficient must have `times - 1` entries, where "
@@ -287,12 +291,14 @@ def derivative(self, t: Scalar, left: bool = True) -> PyTree:
 class DenseInterpolation(AbstractGlobalInterpolation):
     ts_size: Int  # Takes values in {1, 2, 3, ...}
     infos: DenseInfos
-    interpolation_cls: Type[AbstractLocalInterpolation] = eqx.static_field()
+    interpolation_cls: Type[AbstractLocalInterpolation] = eqx.field(static=True)
     direction: Scalar
     t0_if_trivial: Array
     y0_if_trivial: PyTree[Array]
 
     def __post_init__(self):
+        super().__post_init__()
+
         def _check(_infos):
             assert _infos.shape[0] + 1 == self.ts.shape[0]
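
The repeated `super().__post_init__()` additions make each `__post_init__` cooperate with the parent class's initialisation rather than silently replacing it, which matters once the base class (here, via Equinox) also defines a `__post_init__`. The standard dataclass pattern, as an illustrative sketch:

```python
import dataclasses


@dataclasses.dataclass
class Base:
    x: int

    def __post_init__(self):
        # Base-level validation that every subclass should preserve.
        if self.x < 0:
            raise ValueError("x must be non-negative")


@dataclasses.dataclass
class Child(Base):
    y: int = 0

    def __post_init__(self):
        super().__post_init__()  # without this, Base's check never runs
        if self.y < self.x:
            raise ValueError("y must be at least x")


Child(x=1, y=2)  # fine
```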

diffrax/integrate.py: 6 additions & 17 deletions

@@ -10,7 +10,7 @@
 import jax.tree_util as jtu
 from jax.typing import ArrayLike
 
-from .adjoint import AbstractAdjoint, DirectAdjoint, RecursiveCheckpointAdjoint
+from .adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint
 from .custom_types import Array, Bool, Int, PyTree, Scalar
 from .event import AbstractDiscreteTerminatingEvent
 from .global_interpolation import DenseInterpolation
@@ -415,6 +415,11 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
     )
     new_state = eqx.tree_at(lambda s: s.result, new_state, result)
 
+    if not _filtering:
+        # This is only necessary for Equinox <0.11.1.
+        # After that, this fix has been upstreamed to Equinox.
+        # TODO: remove once we make Equinox >=0.11.1 required.
+        new_state = jtu.tree_map(jnp.array, new_state)
     return new_state
 
 _filtering = True
@@ -633,22 +638,6 @@
             "An SDE should not be solved with adaptive step sizes with Euler's "
             "method, as it may not converge to the correct solution."
         )
-    # TODO: remove these lines.
-    #
-    # These are to work around an edge case: on the backward pass,
-    # RecursiveCheckpointAdjoint currently tries to differentiate the overall
-    # per-step function wrt all floating-point arrays. In particular this includes
-    # `state.tprev`, which feeds into the control, which feeds into
-    # VirtualBrownianTree, which can't be differentiated.
-    # We're waiting on JAX to offer a way of specifying which arguments to a
-    # custom_vjp have symbolic zero *tangents* (not cotangents) so that we can more
-    # precisely determine what to differentiate wrt.
-    #
-    # We don't replace this in the case of an unsafe SDE because
-    # RecursiveCheckpointAdjoint will raise an error in that case anyway, so we
-    # should let the normal error be raised.
-    if isinstance(adjoint, RecursiveCheckpointAdjoint) and not is_unsafe_sde(terms):
-        adjoint = DirectAdjoint()
     if is_unsafe_sde(terms):
         if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
             raise ValueError(
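
The new `if not _filtering` branch normalises every leaf of the loop state to a concrete array. The constraint it works around, shown in isolation (this toy loop is illustrative, not diffrax's internals): `lax.while_loop` requires the carry returned by the body to match the initial carry's types exactly, so a state containing plain Python scalars must first be promoted with something like `jtu.tree_map(jnp.asarray, ...)`. The `test_no_jit` test added in test/test_integrate.py below exercises this code path.

```python
import jax.numpy as jnp
import jax.tree_util as jtu
from jax import lax

cond = lambda carry: carry[0] < 5


def body(carry):
    step, x = carry
    return step + 1, x * 2.0  # both leaves come back as JAX arrays


# Normalising the initial carry to arrays keeps its types identical to what
# `body` returns, which `lax.while_loop` insists on.
init = jtu.tree_map(jnp.asarray, (0, 1.0))
print(lax.while_loop(cond, body, init))  # (Array(5, ...), Array(32., ...))
```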

diffrax/nonlinear_solver/newton.py: 1 addition & 1 deletion

@@ -141,7 +141,7 @@ def body_fn(val):
             val = (flat, step + 1, diffsize, diffsize_prev)
             return val
 
-        val = (flat, 0, 0.0, 0.0)
+        val = (flat, 0, jnp.array(0.0), jnp.array(0.0))
         val = lax.while_loop(cond_fn, body_fn, val)
         flat, num_steps, diffsize, diffsize_prev = val
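
The same normalisation applied at a single call site: under `jax.disable_jit()` nothing is traced, so a literal `0.0` stays a Python float with no array attributes, whereas `jnp.array(0.0)` behaves identically eagerly and under `jit`. A quick hypothetical illustration, not diffrax code:

```python
import jax
import jax.numpy as jnp

with jax.disable_jit():
    diffsize = 0.0
    # diffsize.dtype           # AttributeError: 'float' has no attribute 'dtype'
    diffsize = jnp.array(0.0)  # a real Array even in eager mode
    print(diffsize.dtype)      # float32 (under JAX's default 32-bit precision)
```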

diffrax/solution.py: 7 additions & 3 deletions

@@ -13,12 +13,16 @@
 class RESULTS(metaclass=eqxi.ContainerMeta):
     successful = ""
     discrete_terminating_event_occurred = (
-        "Terminating solve because a discrete event occurred."
+        "Terminating differential equation solve because a discrete terminating event "
+        "occurred."
     )
     max_steps_reached = (
-        "The maximum number of solver steps was reached. Try increasing `max_steps`."
+        "The maximum number of steps was reached in the differential equation solver. "
+        "Try increasing `diffrax.diffeqsolve(..., max_steps=...)`."
     )
+    dt_min_reached = (
+        "The minimum step size was reached in the differential equation solver."
+    )
-    dt_min_reached = "The minimum step size was reached."
     implicit_divergence = "Implicit method diverged."
     implicit_nonconvergence = (
         "Implicit method did not converge within the required number of iterations."

diffrax/step_size_controller/adaptive.py: 2 additions & 2 deletions

@@ -423,8 +423,8 @@ def adapt_step_size(
         # h_n is the nth step size
         # ε_n = atol + norm(y) * rtol with y on the nth step
         # r_n = norm(y_error) with y_error on the nth step
-        # δ_{n,m} = norm(y_error / (atol + norm(y) * rtol)) with y_error on the nth
-        #     step and y on the mth step
+        # δ_{n,m} = norm(y_error / (atol + norm(y) * rtol))^(-1) with y_error on the nth
+        #     step and y on the mth step
         # β_1 = pcoeff + icoeff + dcoeff
         # β_2 = -(pcoeff + 2 * dcoeff)
         # β_3 = dcoeff
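
For orientation, with δ now defined as the *inverse* scaled error norm (so larger δ means a more accurate step), these β coefficients slot into the usual Söderlind-style PID step-size update. As a hedged sketch only, with k the order of the local error estimate and ignoring the clipping and safety factors an implementation applies:

$$ h_{n+1} = h_n \, \delta_{n,n}^{\beta_1 / k} \, \delta_{n-1,n-1}^{\beta_2 / k} \, \delta_{n-2,n-2}^{\beta_3 / k} $$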

diffrax/term.py: 5 additions & 3 deletions

@@ -409,7 +409,9 @@ def is_vf_expensive(
         y: Tuple[PyTree, PyTree, PyTree, PyTree],
         args: PyTree,
     ) -> bool:
-        return self.term.is_vf_expensive(t0, t1, y, args)
+        _t0 = jnp.where(self.direction == 1, t0, -t1)
+        _t1 = jnp.where(self.direction == 1, t1, -t0)
+        return self.term.is_vf_expensive(_t0, _t1, y, args)
 
 
 class AdjointTerm(AbstractTerm):
@@ -422,8 +424,8 @@ def is_vf_expensive(
         y: Tuple[PyTree, PyTree, PyTree, PyTree],
         args: PyTree,
     ) -> bool:
-        control = self.contr(t0, t1)
-        if sum(c.size for c in jtu.tree_leaves(control)) in (0, 1):
+        control_struct = jax.eval_shape(self.contr, t0, t1)
+        if sum(c.size for c in jtu.tree_leaves(control_struct)) in (0, 1):
             return False
         else:
             return True
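
`jax.eval_shape` traces its function with abstract values only, returning `jax.ShapeDtypeStruct`s instead of arrays, so the control's size can be checked without evaluating it (for a Brownian control, without sampling anything). A small sketch under assumed shapes:

```python
import jax
import jax.numpy as jnp


def control(t0, t1):
    # Stand-in for a potentially expensive control evaluation.
    return jnp.zeros((3, 7)) * (t1 - t0)


struct = jax.eval_shape(control, 0.0, 1.0)  # no FLOPs are performed
print(struct.shape, struct.dtype, struct.size)  # (3, 7) float32 21
```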

docs/usage/getting-started.md: 0 additions & 9 deletions

@@ -99,15 +99,6 @@ print(sol.evaluate(1.1))  # DeviceArray(0.89436394)
 As you can see, basically nothing has changed compared to the ODE example; all the same APIs are used. The only difference is that we created an SDE solver rather than an ODE solver.
 
 
-!!! info
-
-    If using some SDE-specific solvers, for example [`diffrax.ItoMilstein`][], then the solver makes a distinction between drift and diffusion. (In the previous example, the solver [`diffrax.Euler`][] is completely oblivious to this distinction.) In this case the drift and diffusion should be passed separately as a 2-tuple of terms, rather than wrapped into a single [`diffrax.MultiTerm`][]. This would involve changing the above example with:
-
-    ```python
-    terms = (ODETerm(drift), ControlTerm(diffusion, brownian_motion))
-    solver = ItoMilstein()
-    ```
-
 !!! info
 
     To do adaptive stepping with an SDE, then the typical approach is to wrap the solver like so -- and to use the following default step size controller:

docs/usage/manual-stepping.md: 2 additions & 1 deletion

@@ -11,6 +11,7 @@ In the following example, we solve an ODE using [`diffrax.Tsit5`][], and print o
 See the [Abstract solvers](../api/solvers/abstract_solvers.md) page for a reference on the solver methods (`init`, `step`) used here.
 
 ```python
+import jax.numpy as jnp
 from diffrax import ODETerm, Tsit5
 
 vector_field = lambda t, y, args: -y
@@ -20,7 +21,7 @@ solver = Tsit5()
 t0 = 0
 dt0 = 0.05
 t1 = 1
-y0 = 1
+y0 = jnp.array(1.0)
 args = None
 
 tprev = t0
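
For reference, the stepping loop this tutorial builds around these definitions looks roughly like the following (reconstructed from the documented `init`/`step` interface; treat it as a sketch rather than the page's exact continuation):

```python
import jax.numpy as jnp
from diffrax import ODETerm, Tsit5

vector_field = lambda t, y, args: -y
term = ODETerm(vector_field)
solver = Tsit5()

t0, t1, dt0 = 0, 1, 0.05
y0 = jnp.array(1.0)
args = None

tprev = t0
tnext = t0 + dt0
y = y0
state = solver.init(term, tprev, tnext, y0, args)
while tprev < t1:
    # step returns (y, local error estimate, dense-output info, state, result)
    y, _, _, state, _ = solver.step(term, tprev, tnext, y, args, state, made_jump=False)
    print(f"At t={tnext:.2f}, y={y}")
    tprev = tnext
    tnext = min(tprev + dt0, t1)
```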

examples/nonlinear_heat_pde.ipynb: 2 additions & 2 deletions

@@ -85,8 +85,8 @@
 "source": [
 "# Represents the interval [x0, x_final] discretised into n equally-spaced points.\n",
 "class SpatialDiscretisation(eqx.Module):\n",
-"    x0: float = eqx.static_field()\n",
-"    x_final: float = eqx.static_field()\n",
+"    x0: float = eqx.field(static=True)\n",
+"    x_final: float = eqx.field(static=True)\n",
 "    vals: Float[Array, \"n\"]\n",
 "\n",
 "    @classmethod\n",

setup.py: 1 addition & 1 deletion

@@ -46,7 +46,7 @@
 
 python_requires = "~=3.9"
 
-install_requires = ["jax>=0.4.13", "equinox>=0.10.11"]
+install_requires = ["jax>=0.4.13", "equinox>=0.11.1"]
 
 setuptools.setup(
     name=name,

test/test_event.py: 48 additions & 2 deletions

@@ -1,5 +1,7 @@
 import diffrax
+import jax
 import jax.numpy as jnp
+import pytest
 
 
 def test_discrete_terminate1():
@@ -9,7 +11,12 @@ def test_discrete_terminate1():
     t1 = jnp.inf
     dt0 = 1
     y0 = 1.0
-    event = diffrax.DiscreteTerminatingEvent(lambda state, **kwargs: state.y > 10)
+
+    def event_fn(state, **kwargs):
+        assert isinstance(state.y, jax.Array)
+        return state.tprev > 10
+
+    event = diffrax.DiscreteTerminatingEvent(event_fn)
     sol = diffrax.diffeqsolve(
         term, solver, t0, t1, dt0, y0, discrete_terminating_event=event
     )
@@ -23,11 +30,50 @@ def test_discrete_terminate2():
     t1 = jnp.inf
     dt0 = 1
     y0 = 1.0
-    event = diffrax.DiscreteTerminatingEvent(lambda state, **kwargs: state.tprev > 10)
+
+    def event_fn(state, **kwargs):
+        assert isinstance(state.y, jax.Array)
+        return state.tprev > 10
+
+    event = diffrax.DiscreteTerminatingEvent(event_fn)
     sol = diffrax.diffeqsolve(
         term, solver, t0, t1, dt0, y0, discrete_terminating_event=event
     )
     assert jnp.all(sol.ts > 10)
 
 
+def test_event_backsolve():
+    term = diffrax.ODETerm(lambda t, y, args: y)
+    solver = diffrax.Tsit5()
+    t0 = 0
+    t1 = jnp.inf
+    dt0 = 1
+    y0 = 1.0
+
+    def event_fn(state, **kwargs):
+        assert isinstance(state.y, jax.Array)
+        return state.tprev > 10
+
+    event = diffrax.DiscreteTerminatingEvent(event_fn)
+
+    @jax.jit
+    @jax.grad
+    def run(y0):
+        sol = diffrax.diffeqsolve(
+            term,
+            solver,
+            t0,
+            t1,
+            dt0,
+            y0,
+            discrete_terminating_event=event,
+            adjoint=diffrax.BacksolveAdjoint(),
+        )
+        return jnp.sum(sol.ys)
+
+    # Raises NotImplementedError -- and in particular not some other error.
+    with pytest.raises(NotImplementedError):
+        run(y0)
 
 
 # diffrax.SteadyStateEvent tested as part of test_adjoint.py::test_implicit

test/test_integrate.py: 24 additions & 0 deletions

@@ -418,3 +418,27 @@ def run(y0):
         assert sol.made_jump is False
 
     run(1)
+
+
+def test_no_jit():
+    # https://github.com/patrick-kidger/diffrax/issues/293
+    # https://github.com/patrick-kidger/diffrax/issues/321
+
+    # Test that this doesn't crash.
+    with jax.disable_jit():
+
+        def vector_field(t, y, args):
+            return jnp.zeros_like(y)
+
+        term = diffrax.ODETerm(vector_field)
+        y = jnp.zeros((1,))
+        stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5)
+        diffrax.diffeqsolve(
+            term,
+            diffrax.Kvaerno4(),
+            t0=0,
+            t1=1e-2,
+            dt0=1e-3,
+            stepsize_controller=stepsize_controller,
+            y0=y,
+        )