From 5f1978de2fefd8eea16fdb25a1837a4ce1b61ea4 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 17 Aug 2023 14:27:30 -0700 Subject: [PATCH 01/12] Explicitly marked BacksolveAdjoint as incompatible with events. --- diffrax/adjoint.py | 8 +++++++- test/test_event.py | 50 ++++++++++++++++++++++++++++++++++++++++++++-- 2 files changed, 55 insertions(+), 3 deletions(-) diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py index 59a19b0c..e64e4970 100644 --- a/diffrax/adjoint.py +++ b/diffrax/adjoint.py @@ -530,6 +530,7 @@ def _loop_backsolve_bwd( throw, init_state, ): + assert discrete_terminating_event is None # # Unpack our various arguments. Delete a lot of things just to make sure we're not @@ -565,7 +566,6 @@ def _loop_backsolve_bwd( adjoint=self, solver=solver, stepsize_controller=stepsize_controller, - discrete_terminating_event=discrete_terminating_event, terms=adjoint_terms, dt0=None if dt0 is None else -dt0, max_steps=max_steps, @@ -744,6 +744,7 @@ def loop( init_state, passed_solver_state, passed_controller_state, + discrete_terminating_event, **kwargs, ): if jtu.tree_structure(saveat.subs, is_leaf=_is_subsaveat) != jtu.tree_structure( @@ -785,6 +786,10 @@ def loop( "`diffrax.BacksolveAdjoint` is only compatible with solvers that take " "a single term." ) + if discrete_terminating_event is not None: + raise NotImplementedError( + "`diffrax.BacksolveAdjoint` is not compatible with events." + ) y = init_state.y init_state = eqx.tree_at(lambda s: s.y, init_state, object()) @@ -798,6 +803,7 @@ def loop( saveat=saveat, init_state=init_state, solver=solver, + discrete_terminating_event=discrete_terminating_event, **kwargs, ) final_state = _only_transpose_ys(final_state) diff --git a/test/test_event.py b/test/test_event.py index 3129be6c..5ba8cc87 100644 --- a/test/test_event.py +++ b/test/test_event.py @@ -1,5 +1,7 @@ import diffrax +import jax import jax.numpy as jnp +import pytest def test_discrete_terminate1(): @@ -9,7 +11,12 @@ def test_discrete_terminate1(): t1 = jnp.inf dt0 = 1 y0 = 1.0 - event = diffrax.DiscreteTerminatingEvent(lambda state, **kwargs: state.y > 10) + + def event_fn(state, **kwargs): + assert isinstance(state.y, jax.Array) + return state.tprev > 10 + + event = diffrax.DiscreteTerminatingEvent(event_fn) sol = diffrax.diffeqsolve( term, solver, t0, t1, dt0, y0, discrete_terminating_event=event ) @@ -23,11 +30,50 @@ def test_discrete_terminate2(): t1 = jnp.inf dt0 = 1 y0 = 1.0 - event = diffrax.DiscreteTerminatingEvent(lambda state, **kwargs: state.tprev > 10) + + def event_fn(state, **kwargs): + assert isinstance(state.y, jax.Array) + return state.tprev > 10 + + event = diffrax.DiscreteTerminatingEvent(event_fn) sol = diffrax.diffeqsolve( term, solver, t0, t1, dt0, y0, discrete_terminating_event=event ) assert jnp.all(sol.ts > 10) +def test_event_backsolve(): + term = diffrax.ODETerm(lambda t, y, args: y) + solver = diffrax.Tsit5() + t0 = 0 + t1 = jnp.inf + dt0 = 1 + y0 = 1.0 + + def event_fn(state, **kwargs): + assert isinstance(state.y, jax.Array) + return state.tprev > 10 + + event = diffrax.DiscreteTerminatingEvent(event_fn) + + @jax.jit + @jax.grad + def run(y0): + sol = diffrax.diffeqsolve( + term, + solver, + t0, + t1, + dt0, + y0, + discrete_terminating_event=event, + adjoint=diffrax.BacksolveAdjoint(), + ) + return jnp.sum(sol.ys) + + # And in particular not some other error. + with pytest.raises(NotImplementedError): + run(y0) + + # diffrax.SteadyStateEvent tested as part of test_adjoint.py::test_implicit From 104936ab1c5861235f766bdf3256018525670213 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 20 Aug 2023 04:46:26 -0700 Subject: [PATCH 02/12] Updated error messages to be more precise. --- diffrax/solution.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/diffrax/solution.py b/diffrax/solution.py index 8ae10f31..0d1eb485 100644 --- a/diffrax/solution.py +++ b/diffrax/solution.py @@ -13,12 +13,16 @@ class RESULTS(metaclass=eqxi.ContainerMeta): successful = "" discrete_terminating_event_occurred = ( - "Terminating solve because a discrete event occurred." + "Terminating differential equation solve because a discrete terminating event " + "occurred." ) max_steps_reached = ( - "The maximum number of solver steps was reached. Try increasing `max_steps`." + "The maximum number of steps was reached in the differential equation solver. " + "Try increasing `diffrax.diffeqsolve(..., max_steps=...)`." + ) + dt_min_reached = ( + "The minimum step size was reached in the differential equation solver." ) - dt_min_reached = "The minimum step size was reached." implicit_divergence = "Implicit method diverged." implicit_nonconvergence = ( "Implicit method did not converge within the required number of iterations." From 737bf39b0d6fe751228cb44f575546497dc4e35b Mon Sep 17 00:00:00 2001 From: Victor Velev Date: Fri, 25 Aug 2023 11:54:31 -0500 Subject: [PATCH 03/12] Fix manual-stepping example (#297) --- docs/usage/manual-stepping.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/usage/manual-stepping.md b/docs/usage/manual-stepping.md index 93b8f12d..68ff5878 100644 --- a/docs/usage/manual-stepping.md +++ b/docs/usage/manual-stepping.md @@ -11,6 +11,7 @@ In the following example, we solve an ODE using [`diffrax.Tsit5`][], and print o See the [Abstract solvers](../api/solvers/abstract_solvers.md) page for a reference on the solver methods (`init`, `step`) used here. ```python +import jax.numpy as jnp from diffrax import ODETerm, Tsit5 vector_field = lambda t, y, args: -y @@ -20,7 +21,7 @@ solver = Tsit5() t0 = 0 dt0 = 0.05 t1 = 1 -y0 = 1 +y0 = jnp.array(1.0) args = None tprev = t0 From 097e7bc00acf42a5d88b4c90f48c93f07bb7ec34 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Thu, 31 Aug 2023 09:06:16 -0700 Subject: [PATCH 04/12] Moved static_field -> field(static=True) --- diffrax/brownian/path.py | 2 +- diffrax/brownian/tree.py | 2 +- diffrax/global_interpolation.py | 2 +- examples/nonlinear_heat_pde.ipynb | 4 ++-- 4 files changed, 5 insertions(+), 5 deletions(-) diff --git a/diffrax/brownian/path.py b/diffrax/brownian/path.py index c1d5f95f..a844450a 100644 --- a/diffrax/brownian/path.py +++ b/diffrax/brownian/path.py @@ -32,7 +32,7 @@ class UnsafeBrownianPath(AbstractBrownianPath): correlation structure isn't needed.) """ - shape: PyTree[jax.ShapeDtypeStruct] = eqx.static_field() + shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True) # Handled as a string because PRNGKey is actually a function, not a class, which # makes it appearly badly in autogenerated documentation. key: "jax.random.PRNGKey" # noqa: F821 diff --git a/diffrax/brownian/tree.py b/diffrax/brownian/tree.py index ad93f641..577092ce 100644 --- a/diffrax/brownian/tree.py +++ b/diffrax/brownian/tree.py @@ -60,7 +60,7 @@ class VirtualBrownianTree(AbstractBrownianPath): t0: Scalar = field(init=True) t1: Scalar = field(init=True) # override init=False in AbstractPath tol: Scalar - shape: PyTree[jax.ShapeDtypeStruct] = eqx.static_field() + shape: PyTree[jax.ShapeDtypeStruct] = eqx.field(static=True) key: "jax.random.PRNGKey" # noqa: F821 def __init__( diff --git a/diffrax/global_interpolation.py b/diffrax/global_interpolation.py index 5900bb45..12975c8a 100644 --- a/diffrax/global_interpolation.py +++ b/diffrax/global_interpolation.py @@ -287,7 +287,7 @@ def derivative(self, t: Scalar, left: bool = True) -> PyTree: class DenseInterpolation(AbstractGlobalInterpolation): ts_size: Int # Takes values in {1, 2, 3, ...} infos: DenseInfos - interpolation_cls: Type[AbstractLocalInterpolation] = eqx.static_field() + interpolation_cls: Type[AbstractLocalInterpolation] = eqx.field(static=True) direction: Scalar t0_if_trivial: Array y0_if_trivial: PyTree[Array] diff --git a/examples/nonlinear_heat_pde.ipynb b/examples/nonlinear_heat_pde.ipynb index c7fc00cd..eaf2c9e8 100644 --- a/examples/nonlinear_heat_pde.ipynb +++ b/examples/nonlinear_heat_pde.ipynb @@ -85,8 +85,8 @@ "source": [ "# Represents the interval [x0, x_final] discretised into n equally-spaced points.\n", "class SpatialDiscretisation(eqx.Module):\n", - " x0: float = eqx.static_field()\n", - " x_final: float = eqx.static_field()\n", + " x0: float = eqx.field(static=True)\n", + " x_final: float = eqx.field(static=True)\n", " vals: Float[Array, \"n\"]\n", "\n", " @classmethod\n", From e0c12d69e529d44dc0060c49f0cc11fb06248db8 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 11 Sep 2023 10:29:57 -0700 Subject: [PATCH 05/12] Added missing super-post-init calls --- diffrax/global_interpolation.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/diffrax/global_interpolation.py b/diffrax/global_interpolation.py index 12975c8a..e24224d6 100644 --- a/diffrax/global_interpolation.py +++ b/diffrax/global_interpolation.py @@ -67,6 +67,8 @@ class LinearInterpolation(AbstractGlobalInterpolation): ys: PyTree[Array["times", ...]] # noqa: F821 def __post_init__(self): + super().__post_init__() + def _check(_ys): if _ys.shape[0] != self.ts.shape[0]: raise ValueError( @@ -179,6 +181,8 @@ class CubicInterpolation(AbstractGlobalInterpolation): ] def __post_init__(self): + super().__post_init__() + def _check(d, c, b, a): error_msg = ( "Each cubic coefficient must have `times - 1` entries, where " @@ -293,6 +297,8 @@ class DenseInterpolation(AbstractGlobalInterpolation): y0_if_trivial: PyTree[Array] def __post_init__(self): + super().__post_init__() + def _check(_infos): assert _infos.shape[0] + 1 == self.ts.shape[0] From 15b6c6e7c9021f5df5566a03521c8dfb8da76ed9 Mon Sep 17 00:00:00 2001 From: Rembert Daems Date: Fri, 15 Sep 2023 12:20:19 +0200 Subject: [PATCH 06/12] Update getting-started.md Removed out of date info box with confusing information about how to define the terms for specific SDE solvers --- docs/usage/getting-started.md | 9 --------- 1 file changed, 9 deletions(-) diff --git a/docs/usage/getting-started.md b/docs/usage/getting-started.md index c690c52e..f024b309 100644 --- a/docs/usage/getting-started.md +++ b/docs/usage/getting-started.md @@ -99,15 +99,6 @@ print(sol.evaluate(1.1)) # DeviceArray(0.89436394) As you can see, basically nothing has changed compared to the ODE example; all the same APIs are used. The only difference is that we created an SDE solver rather than an ODE solver. -!!! info - - If using some SDE-specific solvers, for example [`diffrax.ItoMilstein`][], then the solver makes a distinction between drift and diffusion. (In the previous example, the solver [`diffrax.Euler`][] is completely oblivious to this distinction. In this case the drift and diffusion should be passed separately as a 2-tuple of terms, rather than wrapped into a single [`diffrax.MultiTerm`][]. This would involve changing the above example with: - - ```python - terms = (ODETerm(drift), ControlTerm(diffusion, brownian_motion)) - solver = ItoMilstein() - ``` - !!! info To do adaptive stepping with an SDE, then the typical approach is to wrap the solver like so -- and to use the following default step size controller: From 9126ce0c7656951945a8b779722e19b7ebddca33 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Tue, 3 Oct 2023 22:12:01 -0700 Subject: [PATCH 07/12] Fixed BacksolveAdjoint+VirtualBrownianTree on TPU without JIT. Quite the edge case! This combination was hitting an error in which the VirtualBrownianTree refused to be evaluated outside of [t0, t1]. The proximal cause turned out to be that `WrapTerm.is_vf_expensive` was missing a couple of lines to flip around t0 and t1 if solving backwards-in-time. Backwards-in-time solves occur most commonly when using BacksolveAdjoint. The lack of JIT is important as this call can actually be DCE'd: it is only used for its static metadata, not its runtime value. So this error does not appear when running under JIT, which is also the reason it didn't get caught in our tests. The final piece of the puzzle, TPUs, is still a bit of a mystery. That the error should be raised when running eagerly makes sense. What is surprising is that CPU-vs-TPU seem to have different behaviour here. Why does CPU DCE this at all? --- diffrax/term.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/diffrax/term.py b/diffrax/term.py index ba8bd506..2136440b 100644 --- a/diffrax/term.py +++ b/diffrax/term.py @@ -409,7 +409,9 @@ def is_vf_expensive( y: Tuple[PyTree, PyTree, PyTree, PyTree], args: PyTree, ) -> bool: - return self.term.is_vf_expensive(t0, t1, y, args) + _t0 = jnp.where(self.direction == 1, t0, -t1) + _t1 = jnp.where(self.direction == 1, t1, -t0) + return self.term.is_vf_expensive(_t0, _t1, y, args) class AdjointTerm(AbstractTerm): @@ -422,8 +424,8 @@ def is_vf_expensive( y: Tuple[PyTree, PyTree, PyTree, PyTree], args: PyTree, ) -> bool: - control = self.contr(t0, t1) - if sum(c.size for c in jtu.tree_leaves(control)) in (0, 1): + control_struct = jax.eval_shape(self.contr, t0, t1) + if sum(c.size for c in jtu.tree_leaves(control_struct)) in (0, 1): return False else: return True From ad8eac5d682fdaa27c36a240fa0299360692fb52 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 6 Oct 2023 15:44:46 +0100 Subject: [PATCH 08/12] Update ecosystem links --- README.md | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 98655ca7..5e2507a4 100644 --- a/README.md +++ b/README.md @@ -61,16 +61,22 @@ If you found this library useful in academic research, please cite: [(arXiv link ## See also: other libraries in the JAX ecosystem +[jaxtyping](https://github.com/google/jaxtyping): type annotations for shape/dtype of arrays. + [Equinox](https://github.com/patrick-kidger/equinox): neural networks. [Optax](https://github.com/deepmind/optax): first-order gradient (SGD, Adam, ...) optimisers. -[Lineax](https://github.com/google/lineax): linear solvers and linear least squares. +[Optimistix](https://github.com/patrick-kidger/optimistix): root finding, minimisation, fixed points, and least squares. -[jaxtyping](https://github.com/google/jaxtyping): type annotations for shape/dtype of arrays. +[Lineax](https://github.com/google/lineax): linear solvers. -[Eqxvision](https://github.com/paganpasta/eqxvision): computer vision models. +[BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling. [sympy2jax](https://github.com/google/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent. +[Eqxvision](https://github.com/paganpasta/eqxvision): computer vision models. + [Levanter](https://github.com/stanford-crfm/levanter): scalable+reliable training of foundation models (e.g. LLMs). + +[PySR](https://github.com/milesCranmer/PySR): symbolic regression. (Non-JAX honourable mention!) From 05f9f693c7aca4b7c4d6d5fd940c272556501a66 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Fri, 6 Oct 2023 15:52:11 +0100 Subject: [PATCH 09/12] Add Orbax to ecosystem list. --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 5e2507a4..c8dea370 100644 --- a/README.md +++ b/README.md @@ -73,6 +73,8 @@ If you found this library useful in academic research, please cite: [(arXiv link [BlackJAX](https://github.com/blackjax-devs/blackjax): probabilistic+Bayesian sampling. +[Orbax](https://github.com/google/orbax): checkpointing (async/multi-host/multi-device). + [sympy2jax](https://github.com/google/sympy2jax): SymPy<->JAX conversion; train symbolic expressions via gradient descent. [Eqxvision](https://github.com/paganpasta/eqxvision): computer vision models. From ec651ce807038f93d54610d70de17a1314cecdb9 Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Mon, 9 Oct 2023 19:14:28 +0100 Subject: [PATCH 10/12] Fix for ZeroDivisionError without JIT --- diffrax/integrate.py | 5 +++++ diffrax/nonlinear_solver/newton.py | 2 +- test/test_integrate.py | 24 ++++++++++++++++++++++++ 3 files changed, 30 insertions(+), 1 deletion(-) diff --git a/diffrax/integrate.py b/diffrax/integrate.py index 814c90d9..2d904eae 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -415,6 +415,11 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: ) new_state = eqx.tree_at(lambda s: s.result, new_state, result) + if not _filtering: + # This is only necessary for Equinox <0.11.1. + # After that, this fix has been upstreamed to Equinox. + # TODO: remove once we make Equinox >=0.11.1 required. + new_state = jtu.tree_map(jnp.array, new_state) return new_state _filtering = True diff --git a/diffrax/nonlinear_solver/newton.py b/diffrax/nonlinear_solver/newton.py index 1458a8d0..c167bb1c 100644 --- a/diffrax/nonlinear_solver/newton.py +++ b/diffrax/nonlinear_solver/newton.py @@ -141,7 +141,7 @@ def body_fn(val): val = (flat, step + 1, diffsize, diffsize_prev) return val - val = (flat, 0, 0.0, 0.0) + val = (flat, 0, jnp.array(0.0), jnp.array(0.0)) val = lax.while_loop(cond_fn, body_fn, val) flat, num_steps, diffsize, diffsize_prev = val diff --git a/test/test_integrate.py b/test/test_integrate.py index d2e9cd84..0df3bff6 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -418,3 +418,27 @@ def run(y0): assert sol.made_jump is False run(1) + + +def test_no_jit(): + # https://github.com/patrick-kidger/diffrax/issues/293 + # https://github.com/patrick-kidger/diffrax/issues/321 + + # Test that this doesn't crash. + with jax.disable_jit(): + + def vector_field(t, y, args): + return jnp.zeros_like(y) + + term = diffrax.ODETerm(vector_field) + y = jnp.zeros((1,)) + stepsize_controller = diffrax.PIDController(rtol=1e-5, atol=1e-5) + diffrax.diffeqsolve( + term, + diffrax.Kvaerno4(), + t0=0, + t1=1e-2, + dt0=1e-3, + stepsize_controller=stepsize_controller, + y0=y, + ) From 2cc447f0323f181cd17d01de6a912e75c1b1695b Mon Sep 17 00:00:00 2001 From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com> Date: Sun, 8 Oct 2023 23:18:12 -0700 Subject: [PATCH 11/12] Fixed SDEs being unnecessarily slow to solve when max_steps!=None. The main changes needed to make this happen are in https://github.com/patrick-kidger/equinox/pull/548 and as such this commit is fairly small -- it declares a dependency on a new (as yet unreleasd) version of Equinox, and removes the compatibility shim that was there before. **How things used to be.** To explain what's going on here a little more carefully: JAX only recently added support for communicating to a `custom_vjp` which input arguments were being perturbed, and which output cotangents were symbolic zeros. Prior to that, a `custom_vjp` basically just had to differentiate all inexact arrays. However, there are some nondifferentiable inexact arrays, in the sense that attempting to differentiate them will raise an error. Solving SDEs has one such array: the nondifferentiable input to a VirtualBrownianTree, which is guarded by a `eqxi.nondifferentiable` to reflect the fact that Brownian motion is nondifferentiable. So it used to be the case that the `custom_vjp` underlying `RecursiveCheckpointAdjoint` would differentiate the overall make-a-step function with respect to all inexact arrays, including the time variable, hit the `nondifferentiable` guard, and crash. One (unsafe) fix would have just been to remove the `nondifferentiable` guard. In practice I previously took the slower, safer option: silently switch out `RecursiveCheckpointAdjoint` for a `DirectAdjoint`. The latter is much less efficient, but uses no `custom_vjp`, and thus used the perturbation and symbolic-zero propagation rules already present in JAX's AD machinery, and was thus safe to use here. **How things are now.** And so, what has now changed: JAX has now added support for tracking which inputs are perturbed, and which cotangents are symbolic zeros. The Equinox PR above uses this functionality to determine which parts of the carry need to be differentiated; no longer is it just "all inexact arrays". And thus this Diffrax PR removes the compatibility shim that is no longer needed. **Implications.** SDE solving should now be much faster. In particular this fixes the speed issue reported in #317. --- diffrax/integrate.py | 18 +----------------- setup.py | 2 +- 2 files changed, 2 insertions(+), 18 deletions(-) diff --git a/diffrax/integrate.py b/diffrax/integrate.py index 2d904eae..d1a1b5a1 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -10,7 +10,7 @@ import jax.tree_util as jtu from jax.typing import ArrayLike -from .adjoint import AbstractAdjoint, DirectAdjoint, RecursiveCheckpointAdjoint +from .adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint from .custom_types import Array, Bool, Int, PyTree, Scalar from .event import AbstractDiscreteTerminatingEvent from .global_interpolation import DenseInterpolation @@ -638,22 +638,6 @@ def diffeqsolve( "An SDE should not be solved with adaptive step sizes with Euler's " "method, as it may not converge to the correct solution." ) - # TODO: remove these lines. - # - # These are to work around an edge case: on the backward pass, - # RecursiveCheckpointAdjoint currently tries to differentiate the overall - # per-step function wrt all floating-point arrays. In particular this includes - # `state.tprev`, which feeds into the control, which feeds into - # VirtualBrownianTree, which can't be differentiated. - # We're waiting on JAX to offer a way of specifying which arguments to a - # custom_vjp have symbolic zero *tangents* (not cotangents) so that we can more - # precisely determine what to differentiate wrt. - # - # We don't replace this in the case of an unsafe SDE because - # RecursiveCheckpointAdjoint will raise an error in that case anyway, so we - # should let the normal error be raised. - if isinstance(adjoint, RecursiveCheckpointAdjoint) and not is_unsafe_sde(terms): - adjoint = DirectAdjoint() if is_unsafe_sde(terms): if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): raise ValueError( diff --git a/setup.py b/setup.py index 908c1d58..618c8a79 100644 --- a/setup.py +++ b/setup.py @@ -46,7 +46,7 @@ python_requires = "~=3.9" -install_requires = ["jax>=0.4.13", "equinox>=0.10.11"] +install_requires = ["jax>=0.4.13", "equinox>=0.11.1"] setuptools.setup( name=name, From 712c20807fcb260691dbaad6240698f4e1fe3dae Mon Sep 17 00:00:00 2001 From: Jason Rader <38091354+packquickly@users.noreply.github.com> Date: Fri, 20 Oct 2023 13:55:34 +0300 Subject: [PATCH 12/12] Update comment on adaptive step-sizing MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit I noticed the description of the comment on adaptive step-sizing defined δ_{n, n} as the inverse of what it looks to be in code. If I'm wrong, just close this issue --- diffrax/step_size_controller/adaptive.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/diffrax/step_size_controller/adaptive.py b/diffrax/step_size_controller/adaptive.py index f794cfb6..71d2761d 100644 --- a/diffrax/step_size_controller/adaptive.py +++ b/diffrax/step_size_controller/adaptive.py @@ -423,8 +423,8 @@ def adapt_step_size( # h_n is the nth step size # ε_n = atol + norm(y) * rtol with y on the nth step # r_n = norm(y_error) with y_error on the nth step - # δ_{n,m} = norm(y_error / (atol + norm(y) * rtol)) with y_error on the nth - # step and y on the mth step + # δ_{n,m} = norm(y_error / (atol + norm(y) * rtol))^(-1) with y_error on the nth + # step and y on the mth step # β_1 = pcoeff + icoeff + dcoeff # β_2 = -(pcoeff + 2 * dcoeff) # β_3 = dcoeff