diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 34ee945f..82573314 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -14,7 +14,7 @@ jobs: with: python-version: "3.8" test-script: | - python -m pip install pytest psutil jax jaxlib equinox scipy + python -m pip install pytest psutil jax jaxlib equinox scipy optax cp -r ${{ github.workspace }}/test ./test pytest pypi-token: ${{ secrets.pypi_token }} diff --git a/.github/workflows/run_tests.yml b/.github/workflows/run_tests.yml index 806a3ff8..486c8b75 100644 --- a/.github/workflows/run_tests.yml +++ b/.github/workflows/run_tests.yml @@ -23,7 +23,7 @@ jobs: - name: Install dependencies run: | python -m pip install --upgrade pip - python -m pip install pytest psutil wheel scipy numpy jaxlib + python -m pip install pytest psutil wheel scipy numpy optax jaxlib - name: Checks with pre-commit uses: pre-commit/action@v2.0.3 diff --git a/diffrax/__init__.py b/diffrax/__init__.py index c4ef915c..466bd7bc 100644 --- a/diffrax/__init__.py +++ b/diffrax/__init__.py @@ -1,10 +1,16 @@ from .adjoint import ( AbstractAdjoint, BacksolveAdjoint, + ImplicitAdjoint, NoAdjoint, RecursiveCheckpointAdjoint, ) from .brownian import AbstractBrownianPath, UnsafeBrownianPath, VirtualBrownianTree +from .event import ( + AbstractDiscreteTerminatingEvent, + DiscreteTerminatingEvent, + SteadyStateEvent, +) from .global_interpolation import ( AbstractGlobalInterpolation, backward_hermite_coefficients, @@ -31,7 +37,6 @@ from .saveat import SaveAt from .solution import RESULTS, Solution from .solver import ( - AbstractAdaptiveSDESolver, AbstractAdaptiveSolver, AbstractDIRK, AbstractERK, @@ -45,6 +50,7 @@ AbstractWrappedSolver, Bosh3, ButcherTableau, + CalculateJacobian, Dopri5, Dopri8, Euler, @@ -81,4 +87,4 @@ ) -__version__ = "0.1.2" +__version__ = "0.2.0" diff --git a/diffrax/adjoint.py b/diffrax/adjoint.py index 7955d991..af091448 100644 --- a/diffrax/adjoint.py +++ b/diffrax/adjoint.py @@ -6,7 +6,7 @@ import jax.lax as lax import jax.numpy as jnp -from .misc import nondifferentiable_output, ω +from .misc import implicit_jvp, nondifferentiable_output, ω from .saveat import SaveAt from .term import AbstractTerm, AdjointTerm @@ -22,6 +22,7 @@ def loop( terms, solver, stepsize_controller, + discrete_terminating_event, saveat, t0, t1, @@ -92,6 +93,73 @@ def loop(self, *, throw, **kwargs): return final_state, aux_stats +def _vf(ys, residual, args__terms, closure): + state_no_y, _ = residual + t = state_no_y.tprev + (y,) = ys # unpack length-1 dimension + args, terms = args__terms + _, _, solver, _, _ = closure + return solver.func(terms, t, y, args) + + +def _solve(args__terms, closure): + args, terms = args__terms + self, kwargs, solver, saveat, init_state = closure + final_state, aux_stats = self._loop_fn( + **kwargs, + args=args, + terms=terms, + solver=solver, + saveat=saveat, + init_state=init_state, + is_bounded=False, + ) + # Note that we use .ys not .y here. The former is what is actually returned + # by diffeqsolve, so it is the thing we want to attach the tangent to. + return final_state.ys, ( + eqx.tree_at(lambda s: s.ys, final_state, None), + aux_stats, + ) + + +class ImplicitAdjoint(AbstractAdjoint): + r"""Backpropagate via the [implicit function theorem](https://en.wikipedia.org/wiki/Implicit_function_theorem#Statement_of_the_theorem). + + This is used when solving towards a steady state, typically using + [`diffrax.SteadyStateEvent`][]. In this case, the output of the solver is $y(θ)$ + for which $f(t, y(θ), θ) = 0$. (Where $θ$ corresponds to all parameters found + through `terms` and `args`, but not `y0`.) Then we can skip backpropagating through + the solver and instead directly compute + $\frac{\mathrm{d}y}{\mathrm{d}θ} = - (\frac{\mathrm{d}f}{\mathrm{d}y})^{-1}\frac{\mathrm{d}f}{\mathrm{d}θ}$ + via the implicit function theorem. + """ # noqa: E501 + + def loop(self, *, args, terms, solver, saveat, throw, init_state, **kwargs): + del throw + + # `is` check because this may return a Tracer from SaveAt(ts=) + if eqx.tree_equal(saveat, SaveAt(t1=True)) is not True: + raise ValueError( + "Can only use `adjoint=ImplicitAdjoint()` with `SaveAt(t1=True)`." + ) + + init_state = eqx.tree_at( + lambda s: (s.y, s.solver_state, s.controller_state), + init_state, + replace_fn=lax.stop_gradient, + ) + closure = (self, kwargs, solver, saveat, init_state) + ys, residual = implicit_jvp(_solve, _vf, (args, terms), closure) + + final_state_no_ys, aux_stats = residual + return ( + eqx.tree_at( + lambda s: s.ys, final_state_no_ys, ys, is_leaf=lambda x: x is None + ), + aux_stats, + ) + + # Compute derivatives with respect to the first argument: # - y, corresponding to the initial state; # - args, corresponding to explicit parameters; @@ -116,7 +184,6 @@ def _loop_backsolve_fwd(y__args__terms, **kwargs): return (final_state, aux_stats), (ts, ys) -# TODO: implement this as a single diffeqsolve with events, once events are supported. def _loop_backsolve_bwd( residuals, grad_final_state__aux_stats, @@ -125,6 +192,7 @@ def _loop_backsolve_bwd( self, solver, stepsize_controller, + discrete_terminating_event, saveat, t0, t1, @@ -162,6 +230,7 @@ def _loop_backsolve_bwd( adjoint=self, solver=solver, stepsize_controller=stepsize_controller, + discrete_terminating_event=discrete_terminating_event, terms=adjoint_terms, dt0=None if dt0 is None else -dt0, max_steps=max_steps, diff --git a/diffrax/event.py b/diffrax/event.py new file mode 100644 index 00000000..72952cc5 --- /dev/null +++ b/diffrax/event.py @@ -0,0 +1,98 @@ +import abc +from typing import Callable, Optional + +import equinox as eqx + +from .custom_types import Bool, PyTree, Scalar +from .misc import rms_norm +from .step_size_controller import AbstractAdaptiveStepSizeController + + +class AbstractDiscreteTerminatingEvent(eqx.Module): + """Evaluated at the end of each integration step. If true then the solve is stopped + at that time. + """ + + @abc.abstractmethod + def __call__(self, state, **kwargs): + """**Arguments:** + + - `state`: a dataclass of the evolving state of the system, including in + particular the solution `state.y` at time `state.tprev`. + - `**kwargs`: the integration options held constant throughout the solve + are passed as keyword arguments: `terms`, `solver`, `args`. etc. + + **Returns** + + A boolean. If true then the solve is terminated. + """ + + +class DiscreteTerminatingEvent(AbstractDiscreteTerminatingEvent): + """Terminates the solve if its condition is ever active.""" + + cond_fn: Callable[..., Bool] + + def __call__(self, state, **kwargs): + return self.cond_fn(state, **kwargs) + + +DiscreteTerminatingEvent.__init__.__doc__ = """**Arguments:** + +- `cond_fn`: A function `(state, **kwargs) -> bool` that is evaluated on every step of + the differential equation solve. If it returns `True` then the solve is finished at + that timestep. `state` is a dataclass of the evolving state of the system, + including in particular the solution `state.y` at time `state.tprev`. Passed as + keyword arguments are the `terms`, `solver`, `args` etc. that are constant + throughout the solve. +""" + + +class SteadyStateEvent(AbstractDiscreteTerminatingEvent): + """Terminates the solve once it reaches a steady state.""" + + rtol: Optional[float] = None + atol: Optional[float] = None + norm: Callable[[PyTree], Scalar] = rms_norm + + def __call__(self, state, *, terms, args, solver, stepsize_controller, **kwargs): + del kwargs + _error = False + if self.rtol is None: + if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): + _rtol = stepsize_controller.rtol + else: + _error = True + else: + _rtol = self.rtol + if self.atol is None: + if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): + _atol = stepsize_controller.atol + else: + _error = True + else: + _atol = self.atol + if _error: + raise ValueError( + "The `rtol` and `atol` tolerances for `SteadyStateEvent` default " + "to the `rtol` and `atol` used with an adaptive step size " + "controller (such as `diffrax.PIDController`). Either use an " + "adaptive step size controller, or specify these tolerances " + "manually." + ) + + # TODO: this makes an additional function evaluation that in practice has + # probably already been made by the solver. + vf = solver.func(terms, state.tprev, state.y, args) + return self.norm(vf) < _atol + _rtol * self.norm(state.y) + + +SteadyStateEvent.__init__.__doc__ = """**Arguments:** + +- `rtol`: The relative tolerance for determining convergence. Defaults to the + same `rtol` as passed to an adaptive step controller if one is used. +- `atol`: The absolute tolerance for determining convergence. Defaults to the + same `atol` as passed to an adaptive step controller if one is used. +- `norm`: A function `PyTree -> Scalar`, which is called to determine whether + the vector field is close to zero. +""" diff --git a/diffrax/integrate.py b/diffrax/integrate.py index 97efdfbd..05653d52 100644 --- a/diffrax/integrate.py +++ b/diffrax/integrate.py @@ -14,6 +14,7 @@ RecursiveCheckpointAdjoint, ) from .custom_types import Array, Bool, Int, PyTree, Scalar +from .event import AbstractDiscreteTerminatingEvent from .global_interpolation import DenseInterpolation from .heuristics import is_sde, is_unsafe_sde from .misc import ( @@ -25,14 +26,8 @@ unvmap_max, ) from .saveat import SaveAt -from .solution import RESULTS, Solution -from .solver import ( - AbstractAdaptiveSDESolver, - AbstractItoSolver, - AbstractSolver, - AbstractStratonovichSolver, - Euler, -) +from .solution import is_okay, is_successful, RESULTS, Solution +from .solver import AbstractItoSolver, AbstractSolver, AbstractStratonovichSolver, Euler from .step_size_controller import ( AbstractAdaptiveStepSizeController, AbstractStepSizeController, @@ -89,6 +84,9 @@ def _save(state: _State, t: Scalar) -> _State: def _clip_to_end(tprev, tnext, t1, keep_step): + # The tolerance means that we don't end up with too-small intervals for + # dense output, which then gives numerically unstable answers due to floating + # point errors. if tnext.dtype is jnp.dtype("float64"): tol = 1e-10 else: @@ -102,6 +100,7 @@ def loop( *, solver, stepsize_controller, + discrete_terminating_event, saveat, t0, t1, @@ -121,7 +120,7 @@ def loop( init_state = eqx.tree_at(lambda s: s.dense_ts, init_state, dense_ts) def cond_fun(state): - return (state.tprev < t1) & (state.result == RESULTS.successful) + return (state.tprev < t1) & is_successful(state.result) def body_fun(state, inplace): @@ -164,9 +163,6 @@ def body_fun(state, inplace): # Do some book-keeping. # - # The 1e-6 tolerance means that we don't end up with too-small intervals for - # dense output, which then gives numerically unstable answers due to floating - # point errors. tprev = jnp.minimum(tprev, t1) tnext = _clip_to_end(tprev, tnext, t1, keep_step) @@ -179,7 +175,7 @@ def body_fun(state, inplace): made_jump = keep(made_jump, state.made_jump) solver_result = keep(solver_result, RESULTS.successful) - # TODO: if we ever support events, then they should go in here. + # TODO: if we ever support non-terminating events, then they should go in here. # In particular the thing to be careful about is in the `if saveat.steps` # branch below, where we want to make sure that it is the value of `y` at # `tprev` that is actually saved. (And not just the value of `y` at the @@ -187,10 +183,8 @@ def body_fun(state, inplace): # Store the first unsuccessful result we get whilst iterating (if any). result = state.result - result = jnp.where(result == RESULTS.successful, solver_result, result) - result = jnp.where( - result == RESULTS.successful, stepsize_controller_result, result - ) + result = jnp.where(is_okay(result), solver_result, result) + result = jnp.where(is_okay(result), stepsize_controller_result, result) # Count the number of steps, just for statistical purposes. num_steps = state.num_steps + 1 @@ -366,6 +360,26 @@ def maybe_inplace(i, x, u): dense_save_index=dense_save_index, ) + if discrete_terminating_event is not None: + discrete_terminating_event_occurred = discrete_terminating_event( + new_state, + solver=solver, + stepsize_controller=stepsize_controller, + saveat=saveat, + t0=t0, + t1=t1, + dt0=dt0, + max_steps=max_steps, + terms=terms, + args=args, + ) + result = jnp.where( + discrete_terminating_event_occurred, + RESULTS.discrete_terminating_event_occurred, + result, + ) + new_state = eqx.tree_at(lambda s: s.result, new_state, result) + return new_state if is_bounded: @@ -466,7 +480,9 @@ def _cond_fun(state): if saveat.t1 and not saveat.steps: # if saveat.steps then the final value is already saved. - final_state = _save(final_state, t1) + # Using `tprev` instead of `t1` in case of an event terminating the solve + # early. (And absent such an event then `tprev == t1`.) + final_state = _save(final_state, final_state.tprev) result = jnp.where( cond_fun(final_state), RESULTS.max_steps_reached, final_state.result ) @@ -487,6 +503,7 @@ def diffeqsolve( saveat: SaveAt = SaveAt(t1=True), stepsize_controller: AbstractStepSizeController = ConstantStepSize(), adjoint: AbstractAdjoint = RecursiveCheckpointAdjoint(), + discrete_terminating_event: Optional[AbstractDiscreteTerminatingEvent] = None, max_steps: Optional[int] = 16**3, throw: bool = True, solver_state: Optional[PyTree] = None, @@ -510,8 +527,7 @@ def diffeqsolve( (For non-ordinary differential equations (SDEs, CDEs), this also specifies the Brownian motion or the control.) - `solver`: The solver for the differential equation. See the guide on [how to - choose a solver](../usage/how-to-choose-a-solver.md), or the [complete list of - solvers](../api/solver.md). + choose a solver](../usage/how-to-choose-a-solver.md). - `t0`: The start of the region of integration. - `t1`: The end of the region of integration. - `dt0`: The step size to use for the first step. If using fixed step sizes then @@ -538,13 +554,15 @@ def diffeqsolve( checkpointing, which is usually the best option for most problems. See the page on [Adjoints](./adjoints.md) for more information. + - `discrete_terminating_event`: A discrete event at which to terminate the solve + early. See the page on [Events](./events.md) for more information. + - `max_steps`: The maximum number of steps to take before quitting the computation unconditionally. Can also be set to `None` to allow an arbitrary number of steps, although this - will disable backpropagation via discretise-then-optimise (backpropagation via - optimise-then-discretise will still work), and also disables - `saveat.steps=True` and `saveat.dense=True`. + is incompatible with `adjoint=RecursiveCheckpointAdjoint()` (the default) and + is incompatible with `saveat=SaveAt(steps=True)` or `saveat=SaveAt(dense=True)`. Note that (a) compile times; and (b) backpropagation run times; will increase as `max_steps` increases. (Specifically, each time `max_steps` passes a power @@ -571,16 +589,15 @@ def diffeqsolve( of the returned solution object, to determine which batch elements succeeded and which failed. - - `solver_state`: Some initial state for the solver. Can be useful when for example - using a reversible solver to recompute a solution. Generally obtained by - `SaveAt(solver_state=True)`. It is unlikely you will need to use this option. + - `solver_state`: Some initial state for the solver. Generally obtained by + `SaveAt(solver_state=True)` from a previous solve. - `controller_state`: Some initial state for the step size controller. Generally - obtained by `SaveAt(controller_state=True)`. It is unlikely you will need to - use this option. + obtained by `SaveAt(controller_state=True)` from a previous solve. - `made_jump`: Whether a jump has just been made at `t0`. Used to update - `solver_state` (if passed). It is unlikely you will need to use this option. + `solver_state` (if passed). Generally obtained by `SaveAt(made_jump=True)` + from a previous solve. **Returns:** @@ -656,11 +673,6 @@ def diffeqsolve( "An SDE should not be solved with adaptive step sizes with Euler's " "method; it will not converge to the correct solution." ) - if not isinstance(solver, AbstractAdaptiveSDESolver): - raise ValueError( - "An adaptive step size controller is being used with a solver " - "that does not provide error estimates suitable for SDEs." - ) if is_unsafe_sde(terms): if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): raise ValueError( @@ -731,12 +743,12 @@ def _promote(yi): error_order = solver.error_order(terms) if controller_state is None: (tnext, controller_state) = stepsize_controller.init( - terms, t0, t1, y0, dt0, args, solver.func_for_init, error_order + terms, t0, t1, y0, dt0, args, solver.func, error_order ) else: if dt0 is None: (tnext, _) = stepsize_controller.init( - terms, t0, t1, y0, dt0, args, solver.func_for_init, error_order + terms, t0, t1, y0, dt0, args, solver.func, error_order ) else: tnext = t0 + dt0 @@ -822,6 +834,7 @@ def _promote(yi): terms=terms, solver=solver, stepsize_controller=stepsize_controller, + discrete_terminating_event=discrete_terminating_event, saveat=saveat, t0=t0, t1=t1, @@ -883,7 +896,7 @@ def _promote(yi): result = final_state.result error_index = unvmap_max(result) branched_error_if( - throw & (result != RESULTS.successful), + throw & jnp.invert(is_okay(result)), error_index, RESULTS.reverse_lookup, ) diff --git a/diffrax/misc/__init__.py b/diffrax/misc/__init__.py index 084c00f2..74e7dbc3 100644 --- a/diffrax/misc/__init__.py +++ b/diffrax/misc/__init__.py @@ -1,4 +1,9 @@ -from .ad import fixed_custom_jvp, nondifferentiable_input, nondifferentiable_output +from .ad import ( + fixed_custom_jvp, + implicit_jvp, + nondifferentiable_input, + nondifferentiable_output, +) from .bounded_while_loop import bounded_while_loop, HadInplaceUpdate from .errors import branched_error_if, error_if from .misc import ( diff --git a/diffrax/misc/ad.py b/diffrax/misc/ad.py index b92c9516..21d47681 100644 --- a/diffrax/misc/ad.py +++ b/diffrax/misc/ad.py @@ -1,12 +1,15 @@ +import functools as ft from typing import Any import equinox as eqx import jax +import jax.flatten_util as fu import jax.interpreters.ad as ad import jax.interpreters.batching as batching import jax.interpreters.mlir as mlir import jax.interpreters.xla as xla import jax.lax as lax +import jax.numpy as jnp from ..custom_types import PyTree @@ -109,3 +112,75 @@ def __call__(self, *args): ) nondiff_args_tracer = jax.tree_map(lax.stop_gradient, nondiff_args_tracer) return self.fn(nondiff_args_nontracer, nondiff_args_tracer, diff_args) + + +# TODO: I think the jacfwd and the jvp can probably be combined, as they both +# basically do the same thing. That might improve efficiency via parallelism. +def implicit_jvp(fn_primal, fn_rewrite, args, closure): + """ + Takes a function `fn_primal : (args, closure) -> (root, residual)` and a function + `fn_rewrite : (root, residual, args, closure) -> arb`. + + Has primals `fn_primal(args, closure)[0]` with auxiliary information + `fn_primal(args, closure)[1]`. + Has tangents `-(d(fn_rewrite)/d(root))^-1 d(fn_rewrite)/d(args)`, evaluated at + `(root, residual, args, closure)`. + + This is used for rewriting gradients via the implicit function theorem. + + Note that due to limitations with JAX's custom autodiff, both `fn_primal` and + `fn_rewrite` should be global functions (i.e. they should not capture any JAX array + via closure, even if it does not participate in autodiff). + """ + diff_args, nondiff_args = eqx.partition(args, eqx.is_inexact_array) + root, residual = _implicit_backprop( + fn_primal, fn_rewrite, nondiff_args, closure, diff_args + ) + # Trim off the zero tangents we added to `residual`. + return root, jax.tree_map(lax.stop_gradient, residual) + + +@ft.partial(fixed_custom_jvp, nondiff_argnums=(0, 1, 2, 3)) +def _implicit_backprop(fn_primal, fn_rewrite, nondiff_args, closure, diff_args): + del fn_rewrite + args = eqx.combine(diff_args, nondiff_args) + return fn_primal(args, closure) + + +@_implicit_backprop.defjvp +def _implicit_backprop_jvp( + fn_primal, fn_rewrite, nondiff_args, closure, diff_args, tang_diff_args +): + (diff_args,) = diff_args + (tang_diff_args,) = tang_diff_args + root, residual = _implicit_backprop( + fn_primal, fn_rewrite, nondiff_args, closure, diff_args + ) + + flat_root, unflatten_root = fu.ravel_pytree(root) + args = eqx.combine(nondiff_args, diff_args) + + def _for_jac(_root): + _root = unflatten_root(_root) + _out = fn_rewrite(_root, residual, args, closure) + _out, _ = fu.ravel_pytree(_out) + return _out + + jac_flat_root = jax.jacfwd(_for_jac)(flat_root) + + flat_diff_args, unflatten_diff_args = fu.ravel_pytree(diff_args) + flat_tang_diff_args, _ = fu.ravel_pytree(tang_diff_args) + + def _for_jvp(_diff_args): + _diff_args = unflatten_diff_args(_diff_args) + _args = eqx.combine(nondiff_args, _diff_args) + _out = fn_rewrite(root, residual, _args, closure) + _out, _ = fu.ravel_pytree(_out) + return _out + + _, jvp_flat_diff_args = jax.jvp(_for_jvp, (flat_diff_args,), (flat_tang_diff_args,)) + + tang_root = -jnp.linalg.solve(jac_flat_root, jvp_flat_diff_args) + tang_root = unflatten_root(tang_root) + tang_residual = jax.tree_map(jnp.zeros_like, residual) + return (root, residual), (tang_root, tang_residual) diff --git a/diffrax/misc/misc.py b/diffrax/misc/misc.py index e35bcc03..875e731a 100644 --- a/diffrax/misc/misc.py +++ b/diffrax/misc/misc.py @@ -1,4 +1,3 @@ -import typing from typing import Optional, Tuple import jax @@ -29,11 +28,6 @@ def force_bitcast_convert_type(val, new_type): class ContainerMeta(type): def __new__(cls, name, bases, dict): - - if getattr(typing, "GENERATING_DOCUMENTATION", False): - # Display containers as ints in documentation - return int - assert "reverse_lookup" not in dict _dict = {} reverse_lookup = [] diff --git a/diffrax/nonlinear_solver/base.py b/diffrax/nonlinear_solver/base.py index 6f24796c..24872ae5 100644 --- a/diffrax/nonlinear_solver/base.py +++ b/diffrax/nonlinear_solver/base.py @@ -4,11 +4,12 @@ import equinox as eqx import jax import jax.flatten_util as fu +import jax.lax as lax import jax.numpy as jnp import jax.scipy as jsp from ..custom_types import Int, PyTree, Scalar -from ..misc import fixed_custom_jvp +from ..misc import implicit_jvp from ..solution import RESULTS @@ -21,6 +22,18 @@ class NonlinearSolution(eqx.Module): result: RESULTS +def _primal(diff_args, closure): + self, fn, x, jac, nondiff_args = closure + nsol = self._solve(fn, x, jac, nondiff_args, diff_args) + return nsol.root, eqx.tree_at(lambda s: s.root, nsol, None) + + +def _rewrite(root, _, diff_args, closure): + _, fn, _, _, nondiff_args = closure + args = eqx.combine(diff_args, nondiff_args) + return fn(root, args) + + class AbstractNonlinearSolver(eqx.Module): """Abstract base class for all nonlinear root-finding algorithms. @@ -30,13 +43,6 @@ class AbstractNonlinearSolver(eqx.Module): rtol: Optional[Scalar] = None atol: Optional[Scalar] = None - def __init_subclass__(cls, **kwargs): - super().__init_subclass__(**kwargs) - # Note that this breaks the descriptor protocol so we have to pass self - # manually in __call__. - cls._solve = fixed_custom_jvp(cls._solve, nondiff_argnums=(0, 1, 2, 3, 4)) - cls._solve.defjvp(_root_solve_jvp) - @abc.abstractmethod def _solve( self, @@ -80,10 +86,13 @@ def __call__( whether the solver managed to converge or not. """ - # TODO: switch from is_inexact_array to is_perturbed once JAX issue #9567 is - # fixed. + x = lax.stop_gradient(x) diff_args, nondiff_args = eqx.partition(args, eqx.is_inexact_array) - return self._solve(self, fn, x, jac, nondiff_args, diff_args) + closure = (self, fn, x, jac, nondiff_args) + root, nsol_no_root = implicit_jvp(_primal, _rewrite, diff_args, closure) + return eqx.tree_at( + lambda s: s.root, nsol_no_root, root, is_leaf=lambda z: z is None + ) @staticmethod def jac(fn: Callable, x: PyTree, args: PyTree) -> LU_Jacobian: @@ -98,65 +107,3 @@ def jac(fn: Callable, x: PyTree, args: PyTree) -> LU_Jacobian: # Handle integer arguments flat = flat.astype(jnp.float32) return jsp.linalg.lu_factor(jax.jacfwd(curried)(flat)) - - -# TODO: I think the jacfwd and the jvp can probably be combined, as they both -# basically do the same thing. That might improve efficiency via parallelism. -# TODO: support differentiating wrt `fn`? This isn't terribly hard -- just pass it as -# part of `diff_args` and use a custom "apply" instead of `fn`. However I can see that -# stating "differentiating wrt `fn` is allowed" might result in confusion if an attempt -# is made to differentiate wrt anything `fn` closes over. (Which is the behaviour of -# `lax.custom_root`. Such closure-differentiation is "magical" behaviour that I won't -# ever put into code I write; if differentiating wrt "closed over values" is expected -# then it's much safer to require that `fn` be a PyTree a la Equinox, but at time of -# writing that isn't yet culturally widespread enough.) -def _root_solve_jvp( - self: AbstractNonlinearSolver, - fn: callable, - x: PyTree, - jac: Optional[LU_Jacobian], - nondiff_args: PyTree, - diff_args: PyTree, - tang_diff_args: PyTree, -): - """JVP for differentiably solving for the root of a function, via the implicit - function theorem. - - Gradients are computed with respect to diff_args. - - This is a lot like lax.custom_root -- we just use less magic. Rather than creating - gradients for whatever the function happened to close over, we create gradients for - just diff_args. - """ - - (diff_args,) = diff_args - (tang_diff_args,) = tang_diff_args - solution = self._solve(self, fn, x, jac, nondiff_args, diff_args) - root = solution.root - - flat_root, unflatten_root = fu.ravel_pytree(root) - args = eqx.combine(nondiff_args, diff_args) - - def _for_jac(_root): - _root = unflatten_root(_root) - _out = fn(_root, args) - _out, _ = fu.ravel_pytree(_out) - return _out - - jac_flat_root = jax.jacfwd(_for_jac)(flat_root) - - flat_diff_args, unflatten_diff_args = fu.ravel_pytree(diff_args) - flat_tang_diff_args, _ = fu.ravel_pytree(tang_diff_args) - - def _for_jvp(_diff_args): - _diff_args = unflatten_diff_args(_diff_args) - _args = eqx.combine(nondiff_args, _diff_args) - _out = fn(root, _args) - _out, _ = fu.ravel_pytree(_out) - return _out - - _, jvp_flat_diff_args = jax.jvp(_for_jvp, (flat_diff_args,), (flat_tang_diff_args,)) - - tang_root = -jnp.linalg.solve(jac_flat_root, jvp_flat_diff_args) - tang_root = unflatten_root(tang_root) - return solution, NonlinearSolution(root=tang_root, num_steps=0, result=0) diff --git a/diffrax/nonlinear_solver/newton.py b/diffrax/nonlinear_solver/newton.py index 08d2af79..1458a8d0 100644 --- a/diffrax/nonlinear_solver/newton.py +++ b/diffrax/nonlinear_solver/newton.py @@ -93,9 +93,21 @@ def _solve( diff_args: PyTree, ) -> Tuple[PyTree, RESULTS]: args = eqx.combine(nondiff_args, diff_args) - rtol = 1e-3 if self.rtol is None else self.rtol - atol = 1e-6 if self.atol is None else self.atol - scale = atol + rtol * self.norm(x) + if self.rtol is None or self.atol is None: + raise ValueError( + "The `rtol` and `atol` tolerances for `NewtonNonlinearSolver` default " + "to the `rtol` and `atol` used with an adaptive step size " + "controller (such as `diffrax.PIDController`). Either use an " + "adaptive step size controller, or specify these tolerances " + "manually.\n" + "Note that this changed in Diffrax version 0.2.0. If you want to match " + "the previous defaults then specify `rtol=1e-3`, `atol=1e-6`. For " + "example:\n" + "```\n" + "diffrax.NewtonNonlinearSolver(rtol=1e-3, atol=1e-6)\n" + "```\n" + ) + scale = self.atol + self.rtol * self.norm(x) flat, unflatten = fu.ravel_pytree(x) if flat.size == 0: return NonlinearSolution(root=x, num_steps=0, result=RESULTS.successful) @@ -153,12 +165,12 @@ def body_fn(val): NewtonNonlinearSolver.__init__.__doc__ = """ **Arguments:** +- `rtol`: The relative tolerance for determining convergence. Defaults to the same + `rtol` as passed to an adaptive step controller if one is used. +- `atol`: The absolute tolerance for determining convergence. Defaults to the same + `atol` as passed to an adaptive step controller if one is used. - `max_steps`: The maximum number of steps allowed. If more than this are required then the iteration fails. Set to `None` to allow an arbitrary number of steps. -- `rtol`: The relative tolerance for determining convergence. If using an adaptive step - size controller, will default to the same `rtol`. Else defaults to `1e-3`. -- `atol`: The absolute tolerance for determining convergence. If using an adaptive step - size controller, will default to the same `atol`. Else defaults to `1e-6`. - `kappa`: The kappa value for determining convergence. - `norm`: A function `PyTree -> Scalar`, which is called to determine the size of the current value. (Used in determining convergence.) diff --git a/diffrax/solution.py b/diffrax/solution.py index e08f2be1..a49b9fe0 100644 --- a/diffrax/solution.py +++ b/diffrax/solution.py @@ -1,6 +1,8 @@ from dataclasses import field from typing import Any, Dict, Optional +import jax.numpy as jnp + from .custom_types import Array, Bool, PyTree, Scalar from .global_interpolation import DenseInterpolation from .misc import ContainerMeta @@ -18,6 +20,38 @@ class RESULTS(metaclass=ContainerMeta): implicit_nonconvergence = ( "Implicit method did not converge within the required number of iterations." ) + discrete_terminating_event_occurred = ( + "Terminating solve because a discrete event occurred." + ) + + +def is_okay(result: RESULTS) -> Bool: + return is_successful(result) | is_event(result) + + +def is_successful(result: RESULTS) -> Bool: + return result == RESULTS.successful + + +# TODO: In the future we may support other event types, in which case this function +# should be updated. +def is_event(result: RESULTS) -> Bool: + return result == RESULTS.discrete_terminating_event_occurred + + +def update_result(old_result: RESULTS, new_result: RESULTS) -> RESULTS: + """ + Returns: + + old | success event_o error_o + new | + --------+------------------------- + success | success event_o error_o + event_n | event_n event_o error_o + error_n | error_n error_n error_o + """ + out_result = jnp.where(is_okay(old_result), new_result, old_result) + return jnp.where(is_okay(new_result) & is_event(old_result), old_result, out_result) class Solution(AbstractPath): diff --git a/diffrax/solver/__init__.py b/diffrax/solver/__init__.py index d05a9fee..8c8dabf9 100644 --- a/diffrax/solver/__init__.py +++ b/diffrax/solver/__init__.py @@ -1,5 +1,4 @@ from .base import ( - AbstractAdaptiveSDESolver, AbstractAdaptiveSolver, AbstractImplicitSolver, AbstractItoSolver, @@ -31,6 +30,7 @@ AbstractRungeKutta, AbstractSDIRK, ButcherTableau, + CalculateJacobian, ) from .semi_implicit_euler import SemiImplicitEuler from .tsit5 import Tsit5 diff --git a/diffrax/solver/base.py b/diffrax/solver/base.py index 3c8eacab..004f8a2c 100644 --- a/diffrax/solver/base.py +++ b/diffrax/solver/base.py @@ -23,12 +23,20 @@ def vector_tree_dot(a, b): class _MetaAbstractSolver(type(eqx.Module)): def __instancecheck__(cls, obj): if super(_MetaAbstractSolver, AbstractWrappedSolver).__instancecheck__(obj): - obj = obj.solver - return super().__instancecheck__(obj) + # Either one will suffice. + return super().__instancecheck__(obj) or super().__instancecheck__( + obj.solver + ) + else: + return super().__instancecheck__(obj) class AbstractSolver(eqx.Module, metaclass=_MetaAbstractSolver): - """Abstract base class for all differential equation solvers.""" + """Abstract base class for all differential equation solvers. + + Subclasses should have a class-level attribute `terms`, specifying the PyTree + structure of `terms` in `diffeqsolve(terms, ...)`. + """ @property @abc.abstractmethod @@ -136,57 +144,76 @@ def step( happened successfully, or if (unusually) it failed for some reason. """ - def func_for_init( + @abc.abstractmethod + def func( self, terms: PyTree[AbstractTerm], t0: Scalar, y0: PyTree, args: PyTree ) -> PyTree: - """Provides vector field evaluations to select the initial step size. + """Evaluate the vector field at a point. (This is unlike + [`diffrax.AbstractSolver.step`][], which operates over an interval.) - This is used to make a point evaluation. This is unlike - [`diffrax.AbstractSolver.step`][], which operates over an interval. - - In general differential equation solvers are interval-based. There is precisely - one place where point evaluations are needed: selecting the initial step size - automatically in an ODE solve. And that is what this function is for. + For most operations differential equation solvers are interval-based, so this + opertion should be used sparingly. This operation is needed for things like + selecting an initial step size. **Arguments:** As [`diffrax.diffeqsolve`][] **Returns:** - The evaluation of the vector field at `t0`. + The evaluation of the vector field at `t0`, `y0`. """ - raise ValueError( - "An initial step size cannot be selected automatically. The most common " - "scenario for this error to occur is when trying to use adaptive step " - "size solvers with SDEs. Please specify an initial `dt0` instead." - ) - class AbstractImplicitSolver(AbstractSolver): + """Indicates that this is an implicit differential equation solver, and as such + that it should take a nonlinear solver as an argument. + """ + nonlinear_solver: AbstractNonlinearSolver = NewtonNonlinearSolver() +AbstractImplicitSolver.__init__.__doc__ = """**Arguments:** + +- `nonlinear_solver`: The nonlinear solver to use. Defaults to a Newton solver. +""" + + class AbstractItoSolver(AbstractSolver): - pass + """Indicates that when used as an SDE solver that this solver will converge to the + Itô solution. + """ class AbstractStratonovichSolver(AbstractSolver): - pass + """Indicates that when used as an SDE solver that this solver will converge to the + Stratonovich solution. + """ class AbstractAdaptiveSolver(AbstractSolver): - pass + """Indicates that this solver provides error estimates, and that as such it may be + used with an adaptive step size controller. + """ -class AbstractAdaptiveSDESolver(AbstractAdaptiveSolver): - pass +class AbstractWrappedSolver(AbstractSolver): + """Wraps another solver "transparently", in the sense that all `isinstance` checks + will be forwarded on to the wrapped solver, e.g. when testing whether the solver is + implicit/adaptive/SDE-compatible/etc. + Inherit from this class if that is desired behaviour. (Do not inherit from this + class if that is not desired behaviour.) + """ -class AbstractWrappedSolver(AbstractSolver): solver: AbstractSolver -class HalfSolver(AbstractWrappedSolver, AbstractAdaptiveSDESolver): +AbstractWrappedSolver.__init__.__doc__ = """**Arguments:** + +- `solver`: The solver to wrap. +""" + + +class HalfSolver(AbstractAdaptiveSolver, AbstractWrappedSolver): """Wraps another solver, trading cost in order to provide error estimates. (That is, it means the solver can be used with an adaptive step size controller, regardless of whether the underlying solver supports adaptive step sizing.) @@ -195,13 +222,15 @@ class HalfSolver(AbstractWrappedSolver, AbstractAdaptiveSDESolver): and comparing the results between the full step and the two half steps. Hence the name "HalfSolver". - As such each step costs 3 times the computational cost of the wrapped solver. + As such each step costs 3 times the computational cost of the wrapped solver, + whilst producing results that are roughly twice as accurate, in addition to + producing error estimates. !!! tip Many solvers already provide error estimates, making `HalfSolver` primarily useful when using a solver that doesn't provide error estimates -- e.g. - [`diffrax.Euler`][] -- such solvers are most common when solving SDEs. + [`diffrax.Euler`][]. Such solvers are most common when solving SDEs. """ @property @@ -270,10 +299,8 @@ def step( return y1, y_error, dense_info, solver_state, result - def func_for_init( - self, terms: PyTree[AbstractTerm], t0: Scalar, y0: PyTree, args: PyTree - ): - return self.solver.func_for_init(terms, t0, y0, args) + def func(self, terms: PyTree[AbstractTerm], t0: Scalar, y0: PyTree, args: PyTree): + return self.solver.func(terms, t0, y0, args) HalfSolver.__init__.__doc__ = """**Arguments:** diff --git a/diffrax/solver/bosh3.py b/diffrax/solver/bosh3.py index f3a0db84..69378721 100644 --- a/diffrax/solver/bosh3.py +++ b/diffrax/solver/bosh3.py @@ -21,6 +21,9 @@ class Bosh3(AbstractERK): 3rd order explicit Runge--Kutta method. Has an embedded 2nd order method for adaptive step sizing. + + Also sometimes known as "Heun's third order method". (Not to be confused with + [`diffrax.Heun`][], which is a second order method). """ tableau = _bosh3_tableau diff --git a/diffrax/solver/euler.py b/diffrax/solver/euler.py index b5ae3e70..01a1d820 100644 --- a/diffrax/solver/euler.py +++ b/diffrax/solver/euler.py @@ -7,14 +7,14 @@ from ..misc import ω from ..solution import RESULTS from ..term import AbstractTerm -from .base import AbstractItoSolver, AbstractSolver +from .base import AbstractItoSolver _ErrorEstimate = None _SolverState = None -class Euler(AbstractItoSolver, AbstractSolver): +class Euler(AbstractItoSolver): """Euler's method. 1st order explicit Runge--Kutta method. Does not support adaptive step sizing. @@ -47,11 +47,11 @@ def step( dense_info = dict(y0=y0, y1=y1) return y1, None, dense_info, None, RESULTS.successful - def func_for_init( + def func( self, terms: AbstractTerm, t0: Scalar, y0: PyTree, args: PyTree, ) -> PyTree: - return terms.func_for_init(t0, y0, args) + return terms.vf(t0, y0, args) diff --git a/diffrax/solver/euler_heun.py b/diffrax/solver/euler_heun.py index a2831b7b..8b02e616 100644 --- a/diffrax/solver/euler_heun.py +++ b/diffrax/solver/euler_heun.py @@ -7,14 +7,14 @@ from ..misc import ω from ..solution import RESULTS from ..term import AbstractTerm -from .base import AbstractSolver, AbstractStratonovichSolver +from .base import AbstractStratonovichSolver _ErrorEstimate = None _SolverState = None -class EulerHeun(AbstractStratonovichSolver, AbstractSolver): +class EulerHeun(AbstractStratonovichSolver): """Euler-Heun method. Used to solve SDEs, and converges to the Stratonovich solution. @@ -55,3 +55,13 @@ def step( dense_info = dict(y0=y0, y1=y1) return y1, None, dense_info, None, RESULTS.successful + + def func( + self, + terms: Tuple[AbstractTerm, AbstractTerm], + t0: Scalar, + y0: PyTree, + args: PyTree, + ) -> PyTree: + drift, diffusion = terms + return drift.vf(t0, y0, args), diffusion.vf(t0, y0, args) diff --git a/diffrax/solver/heun.py b/diffrax/solver/heun.py index df5024ae..a868b0e9 100644 --- a/diffrax/solver/heun.py +++ b/diffrax/solver/heun.py @@ -1,7 +1,7 @@ import numpy as np from ..local_interpolation import ThirdOrderHermitePolynomialInterpolation -from .base import AbstractAdaptiveSDESolver, AbstractStratonovichSolver +from .base import AbstractStratonovichSolver from .runge_kutta import AbstractERK, ButcherTableau @@ -13,7 +13,7 @@ ) -class Heun(AbstractERK, AbstractStratonovichSolver, AbstractAdaptiveSDESolver): +class Heun(AbstractERK, AbstractStratonovichSolver): """Heun's method. 2nd order explicit Runge--Kutta method. Has an embedded Euler method for adaptive @@ -23,7 +23,8 @@ class Heun(AbstractERK, AbstractStratonovichSolver, AbstractAdaptiveSDESolver): or "explicit trapezoidal rule". Should not be confused with Heun's third order method, which is a different (higher - order) method occasionally also just referred to as "Heun's method". + order) method occasionally also just referred to as "Heun's method". (Which is + available in Diffrax as [`diffrax.Bosh3`][].) When used to solve SDEs, converges to the Stratonovich solution. """ diff --git a/diffrax/solver/implicit_euler.py b/diffrax/solver/implicit_euler.py index 7c28dd10..ac6a78f2 100644 --- a/diffrax/solver/implicit_euler.py +++ b/diffrax/solver/implicit_euler.py @@ -56,11 +56,11 @@ def step( dense_info = dict(y0=y0, y1=y1) return y1, None, dense_info, None, nonlinear_sol.result - def func_for_init( + def func( self, terms: AbstractTerm, t0: Scalar, y0: PyTree, args: PyTree, ) -> PyTree: - return terms.func_for_init(t0, y0, args) + return terms.vf(t0, y0, args) diff --git a/diffrax/solver/leapfrog_midpoint.py b/diffrax/solver/leapfrog_midpoint.py index 11ce0950..5f3894a2 100644 --- a/diffrax/solver/leapfrog_midpoint.py +++ b/diffrax/solver/leapfrog_midpoint.py @@ -23,7 +23,8 @@ class LeapfrogMidpoint(AbstractSolver): Note that this is referred to as the "leapfrog/midpoint method" as this is the name used by Shampine in the reference below. It should not be confused with any of the many other "leapfrog methods" (there are several), or with the "midpoint method" - (which is usually taken to refer to an explicit Runge--Kutta method). + (which is usually taken to refer to the explicit Runge--Kutta method + [`diffrax.Midpoint`][]). ??? cite "Reference" @@ -71,11 +72,11 @@ def step( solver_state = (t0, y0) return y1, None, dense_info, solver_state, RESULTS.successful - def func_for_init( + def func( self, terms: AbstractTerm, t0: Scalar, y0: PyTree, args: PyTree, ) -> PyTree: - return terms.func_for_init(t0, y0, args) + return terms.vf(t0, y0, args) diff --git a/diffrax/solver/midpoint.py b/diffrax/solver/midpoint.py index d9a9b5a0..8a8b50fe 100644 --- a/diffrax/solver/midpoint.py +++ b/diffrax/solver/midpoint.py @@ -1,7 +1,7 @@ import numpy as np from ..local_interpolation import ThirdOrderHermitePolynomialInterpolation -from .base import AbstractAdaptiveSDESolver, AbstractStratonovichSolver +from .base import AbstractStratonovichSolver from .runge_kutta import AbstractERK, ButcherTableau @@ -13,7 +13,7 @@ ) -class Midpoint(AbstractERK, AbstractStratonovichSolver, AbstractAdaptiveSDESolver): +class Midpoint(AbstractERK, AbstractStratonovichSolver): """Midpoint method. 2nd order explicit Runge--Kutta method. Has an embedded Euler method for adaptive diff --git a/diffrax/solver/milstein.py b/diffrax/solver/milstein.py index beecfb4b..78fc632a 100644 --- a/diffrax/solver/milstein.py +++ b/diffrax/solver/milstein.py @@ -71,6 +71,16 @@ def _to_jvp(_y0): dense_info = dict(y0=y0, y1=y1) return y1, None, dense_info, None, RESULTS.successful + def func( + self, + terms: Tuple[AbstractTerm, AbstractTerm], + t0: Scalar, + y0: PyTree, + args: PyTree, + ) -> PyTree: + drift, diffusion = terms + return drift.vf(t0, y0, args), diffusion.vf(t0, y0, args) + class ItoMilstein(AbstractItoSolver): r"""Milstein's method; Itô version. @@ -312,3 +322,13 @@ def _dot(_, _v0): dense_info = dict(y0=y0, y1=y1) return y1, None, dense_info, None, RESULTS.successful + + def func( + self, + terms: Tuple[AbstractTerm, AbstractTerm], + t0: Scalar, + y0: PyTree, + args: PyTree, + ) -> PyTree: + drift, diffusion = terms + return drift.vf(t0, y0, args), diffusion.vf(t0, y0, args) diff --git a/diffrax/solver/ralston.py b/diffrax/solver/ralston.py index ece3ef8d..be3321b9 100644 --- a/diffrax/solver/ralston.py +++ b/diffrax/solver/ralston.py @@ -1,7 +1,7 @@ import numpy as np from ..local_interpolation import ThirdOrderHermitePolynomialInterpolation -from .base import AbstractAdaptiveSDESolver, AbstractStratonovichSolver +from .base import AbstractStratonovichSolver from .runge_kutta import AbstractERK, ButcherTableau @@ -22,11 +22,13 @@ ) -class Ralston(AbstractERK, AbstractStratonovichSolver, AbstractAdaptiveSDESolver): +class Ralston(AbstractERK, AbstractStratonovichSolver): """Ralston's method. 2nd order explicit Runge--Kutta method. Has an embedded Euler method for adaptive step sizing. + + When used to solve SDEs, converges to the Stratonovich solution. """ tableau = _ralston_tableau diff --git a/diffrax/solver/reversible_heun.py b/diffrax/solver/reversible_heun.py index 5195644f..73a50893 100644 --- a/diffrax/solver/reversible_heun.py +++ b/diffrax/solver/reversible_heun.py @@ -14,7 +14,7 @@ _SolverState = Tuple[PyTree, PyTree] -class ReversibleHeun(AbstractStratonovichSolver, AbstractAdaptiveSolver): +class ReversibleHeun(AbstractAdaptiveSolver, AbstractStratonovichSolver): """Reversible Heun method. Algebraically reversible 2nd order method. Has an embedded 1st order method for @@ -75,11 +75,11 @@ def step( solver_state = (yhat1, vf1) return y1, y1_error, dense_info, solver_state, RESULTS.successful - def func_for_init( + def func( self, terms: AbstractTerm, t0: Scalar, y0: PyTree, args: PyTree, ) -> PyTree: - return terms.func_for_init(t0, y0, args) + return terms.vf(t0, y0, args) diff --git a/diffrax/solver/runge_kutta.py b/diffrax/solver/runge_kutta.py index f7589fdc..dea0fda7 100644 --- a/diffrax/solver/runge_kutta.py +++ b/diffrax/solver/runge_kutta.py @@ -10,7 +10,7 @@ from ..custom_types import Bool, DenseInfo, PyTree, Scalar from ..misc import ContainerMeta, ω -from ..solution import RESULTS +from ..solution import is_okay, RESULTS, update_result from ..term import AbstractTerm from .base import AbstractAdaptiveSolver, AbstractImplicitSolver, vector_tree_dot @@ -35,6 +35,8 @@ def _check(_x): # trace time. @dataclass(frozen=True) class ButcherTableau: + """The Butcher tableau for an explicit or diagonal Runge--Kutta method.""" + # Explicit RK methods c: np.ndarray b_sol: np.ndarray @@ -98,9 +100,50 @@ def __post_init__(self): ) -class _CalculateJacobian(metaclass=ContainerMeta): +ButcherTableau.__init__.__doc__ = """**Arguments:** + +Let `k` denote the number of stages of the solver. + +- `a_lower`: the lower triangle (without the diagonal) of the Butcher tableau. Should + be a tuple of NumPy arrays, corresponding to the rows of this lower triangle. The + first array represents the should be of shape `(1,)`. Each subsequent array should + be of shape `(2,)`, `(3,)` etc. The final array should have shape `(k - 1,)`. +- `b_sol`: the linear combination of stages to take to produce the output at each step. + Should be a NumPy array of shape `(k,)`. +- `b_error`: the linear combination of stages to take to produce the error estimate at + each step. Should be a NumPy array of shape `(k,)`. Note that this is *not* + differenced against `b_sol` prior to evaluation. (i.e. `b_error` gives the linear + combination for producing the error estimate directly, not for producing some + alternate solution that is compared against the main solution). +- `c`: the time increments used in the Butcher tableau. +- `a_diagonal`: optional. The diagonal of the Butcher tableau. Should be `None` or a + NumPy array of shape `(k,)`. Used for diagonal implicit Runge--Kutta methods only. +- `a_predictor`: optional. Used in a similar way to `a_lower`; specifies the linear + combination of previous stages to use as a predictor for the solution to the + implicit problem at that stage. See + [the developer documentation](../../devdocs/predictor_dirk). U#sed for diagonal + implicit Runge--Kutta methods only. + +Whether the solver exhibits either the FSAL or SSAL properties is determined +automatically. +""" + + +class CalculateJacobian(metaclass=ContainerMeta): + """An enumeration of possible ways a Runga--Kutta method may wish to calculate a + Jacobian. + + `never`: used for explicit Runga--Kutta methods. + + `every_step`: the Jacobian is calculated once per step; in particular it is + calculated at the start of the step and re-used for every stage in the step. + Used for SDIRK and ESDIRK methods. + + `every_stage`: the Jacobian is calculated once per stage. Used for DIRK methods. + """ + never = "never" - at_start = "at_start" + every_step = "every_step" every_stage = "every_stage" @@ -135,6 +178,19 @@ def _implicit_relation_k(ki, nonlinear_solve_args): class AbstractRungeKutta(AbstractAdaptiveSolver): + """Abstract base class for all Runge--Kutta solvers. (Other than fully-implicit + Runge--Kutta methods, which have a different computational structure.) + + Whilst this class can be subclassed directly, when defining your own Runge--Kutta + methods, it is usally better to work with [`diffrax.AbstractERK`][], + [`diffrax.AbstractDIRK`][], [`diffrax.AbstractSDIRK`][], + [`diffrax.AbstractESDIRK`][] directly. + + Subclasses should specify two class-level attributes. The first is `tableau`, an + instance of [`diffrax.ButcherTableau`][]. The second is `calculate_jacobian`, an + instance of [`diffrax.CalculateJacobian`][]. + """ + scan_stages: bool = False term_structure = jax.tree_structure(0) @@ -146,17 +202,17 @@ def tableau(self) -> ButcherTableau: @property @abc.abstractmethod - def _calculate_jacobian(self) -> _CalculateJacobian: + def calculate_jacobian(self) -> CalculateJacobian: pass - def func_for_init( + def func( self, terms: AbstractTerm, t0: Scalar, y0: PyTree, args: PyTree, ) -> PyTree: - return terms.func_for_init(t0, y0, args) + return terms.vf(t0, y0, args) def init( self, @@ -259,7 +315,7 @@ def step( ) else: if ( - self._calculate_jacobian == _CalculateJacobian.at_start + self.calculate_jacobian == CalculateJacobian.every_step or implicit_first_stage or not self.scan_stages ): @@ -286,7 +342,7 @@ def step( jac_f = None jac_k = None - if self._calculate_jacobian == _CalculateJacobian.at_start: + if self.calculate_jacobian == CalculateJacobian.every_step: assert self.tableau.a_diagonal is not None # Skipping the first element to account for ESDIRK methods. assert all( @@ -472,7 +528,7 @@ def eval_stage(_carry, _input): else: _k_pred = _vector_tree_dot(_a_predictor_i, ks, _i) # noqa: F821 # Determine Jacobian to use at this stage - if self._calculate_jacobian == _CalculateJacobian.every_stage: + if self.calculate_jacobian == CalculateJacobian.every_stage: if _return_fi: _jac_f = self.nonlinear_solver.jac( _implicit_relation_f, @@ -504,7 +560,7 @@ def eval_stage(_carry, _input): ), ) else: - assert self._calculate_jacobian == _CalculateJacobian.at_start + assert self.calculate_jacobian == CalculateJacobian.every_step _jac_f = jac_f _jac_k = jac_k # Solve nonlinear problem @@ -551,9 +607,7 @@ def eval_stage(_carry, _input): _ki = _nonlinear_sol.root else: assert False - _result = jnp.where( - _result == RESULTS.successful, _nonlinear_sol.result, _result - ) + _result = update_result(_result, _nonlinear_sol.result) del _nonlinear_sol else: # Explicit stage @@ -684,7 +738,7 @@ def eval_stage(_carry, _input): else: y_error = vector_tree_dot(self.tableau.b_error, ks) y_error = jax.tree_map( - lambda _y_error: jnp.where(result == RESULTS.successful, _y_error, jnp.inf), + lambda _y_error: jnp.where(is_okay(result), _y_error, jnp.inf), y_error, ) # i.e. an implicit step failed to converge @@ -713,24 +767,49 @@ def eval_stage(_carry, _input): class AbstractERK(AbstractRungeKutta): - _calculate_jacobian = _CalculateJacobian.never + """Abstract base class for all Explicit Runge--Kutta solvers. + + Subclasses should include a class-level attribute `tableau`, an instance of + [`diffrax.ButcherTableau`][]. + """ + + calculate_jacobian = CalculateJacobian.never class AbstractDIRK(AbstractRungeKutta, AbstractImplicitSolver): - _calculate_jacobian = _CalculateJacobian.every_stage + """Abstract base class for all Diagonal Implicit Runge--Kutta solvers. + + Subclasses should include a class-level attribute `tableau`, an instance of + [`diffrax.ButcherTableau`][]. + """ + + calculate_jacobian = CalculateJacobian.every_stage class AbstractSDIRK(AbstractDIRK): + """Abstract base class for all Singular Diagonal Implict Runge--Kutta solvers. + + Subclasses should include a class-level attribute `tableau`, an instance of + [`diffrax.ButcherTableau`][]. + """ + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) if cls.tableau is not None: # Abstract subclasses may not have a tableau. diagonal = cls.tableau.a_diagonal[0] assert (cls.tableau.a_diagonal == diagonal).all() - _calculate_jacobian = _CalculateJacobian.at_start + calculate_jacobian = CalculateJacobian.every_step class AbstractESDIRK(AbstractDIRK): + """Abstract base class for all Explicit Singular Diagonal Implicit Runge--Kutta + solvers. + + Subclasses should include a class-level attribute `tableau`, an instance of + [`diffrax.ButcherTableau`][]. + """ + def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) if cls.tableau is not None: # Abstract subclasses may not have a tableau. @@ -738,4 +817,4 @@ def __init_subclass__(cls, **kwargs): diagonal = cls.tableau.a_diagonal[1] assert (cls.tableau.a_diagonal[1:] == diagonal).all() - _calculate_jacobian = _CalculateJacobian.at_start + calculate_jacobian = CalculateJacobian.every_step diff --git a/diffrax/solver/semi_implicit_euler.py b/diffrax/solver/semi_implicit_euler.py index 34650cca..bf07c2f9 100644 --- a/diffrax/solver/semi_implicit_euler.py +++ b/diffrax/solver/semi_implicit_euler.py @@ -50,7 +50,7 @@ def step( dense_info = dict(y0=y0, y1=y1) return y1, None, dense_info, None, RESULTS.successful - def func_for_init( + def func( self, terms: Tuple[AbstractTerm, AbstractTerm], t0: Scalar, @@ -60,6 +60,6 @@ def func_for_init( term_1, term_2 = terms y0_1, y0_2 = y0 - f1 = term_1.func_for_init(t0, y0_2, args) - f2 = term_2.func_for_init(t0, y0_1, args) + f1 = term_1.func(t0, y0_2, args) + f2 = term_2.func(t0, y0_1, args) return (f1, f2) diff --git a/diffrax/step_size_controller/adaptive.py b/diffrax/step_size_controller/adaptive.py index 3be27bdd..1cc82dc9 100644 --- a/diffrax/step_size_controller/adaptive.py +++ b/diffrax/step_size_controller/adaptive.py @@ -19,13 +19,13 @@ def _select_initial_step( t0: Scalar, y0: PyTree, args: PyTree, - func_for_init: Callable[[Scalar, PyTree, PyTree], PyTree], + func: Callable[[Scalar, PyTree, PyTree], PyTree], error_order: Scalar, rtol: Scalar, atol: Scalar, norm: Callable[[PyTree], Scalar], ) -> Scalar: - f0 = func_for_init(terms, t0, y0, args) + f0 = func(terms, t0, y0, args) scale = (atol + ω(y0).call(jnp.abs) * rtol).ω d0 = norm((y0**ω / scale**ω).ω) d1 = norm((f0**ω / scale**ω).ω) @@ -36,7 +36,7 @@ def _select_initial_step( t1 = t0 + h0 y1 = (y0**ω + h0 * f0**ω).ω - f1 = func_for_init(terms, t1, y1, args) + f1 = func(terms, t1, y1, args) d2 = norm(((f1**ω - f0**ω) / scale**ω).ω) / h0 max_d = jnp.maximum(d1, d2) @@ -61,6 +61,14 @@ def __repr__(self): class AbstractAdaptiveStepSizeController(AbstractStepSizeController): + """Indicates an adaptive step size controller. + + Accepts tolerances `rtol` and `atol`. When used in conjunction with an implicit + solver ([`diffrax.AbstractImplicitSolver`][]), then these tolerances will + automatically be used as the tolerances for the nonlinear solver passed to the + implicit solver, if they are not specified manually. + """ + rtol: Optional[Scalar] = None atol: Optional[Scalar] = None @@ -297,7 +305,7 @@ def init( y0: PyTree, dt0: Optional[Scalar], args: PyTree, - func_for_init: Callable[[Scalar, PyTree, PyTree], PyTree], + func: Callable[[Scalar, PyTree, PyTree], PyTree], error_order: Optional[Scalar], ) -> Tuple[Scalar, _ControllerState]: del t1 @@ -308,7 +316,7 @@ def init( t0, y0, args, - func_for_init, + func, error_order, self.rtol, self.atol, @@ -383,14 +391,9 @@ def adapt_step_size( # `scaled_error = norm(y_error) / (atol + norm(y) * rtol)` (2) # We do (1). torchdiffeq and torchsde do (1). Soderlind's papers and # OrdinaryDiffEq.jl do (2). - # We choose to do (1) by considering what were to happen if we were to increase - # the dimensionality of `y` and `y_error` with zeros. (i.e. append as many - # `dy/dt=0` problems as we please, and then solve them perfectly) Assuming that - # `norm` normalises by the number of dimensions (e.g. like an RMS norm) then - # (2) will see `norm(y_error) -> 0`, `norm(y) -> 0`, and therefore `atol` - # playing a larger and larger role. In contrast (2) simply scales things down - # without `atol` taking on extra importance. (This is quite thin justification - # though.) + # We choose to do (1) by considering what if `y` were to contain different + # components at very different scales. The errors in the small components may + # be drowned out by the errors in the big components if we were using (2). # # Some will put the multiplication by `safety` outside the `coeff/error_order` # exponent. (1) Some will put it inside. (2) diff --git a/diffrax/step_size_controller/base.py b/diffrax/step_size_controller/base.py index 6cd3251f..fac76513 100644 --- a/diffrax/step_size_controller/base.py +++ b/diffrax/step_size_controller/base.py @@ -59,7 +59,7 @@ def init( y0: PyTree, dt0: Optional[Scalar], args: PyTree, - func_for_init: Callable[[Scalar, PyTree, PyTree], PyTree], + func: Callable[[Scalar, PyTree, PyTree], PyTree], error_order: Optional[Scalar], ) -> Tuple[Scalar, _ControllerState]: r"""Determines the size of the first step, and initialise any hidden state for @@ -67,7 +67,7 @@ def init( **Arguments:** As `diffeqsolve`. - - `func_for_init`: The value of `solver.func_for_init`. + - `func`: The value of `solver.func`. - `error_order`: The order of the error estimate. If solving an ODE this will typically be `solver.order()`. If solving an SDE this will typically be `solver.strong_order() + 0.5`. diff --git a/diffrax/step_size_controller/constant.py b/diffrax/step_size_controller/constant.py index c431a15c..7f905d54 100644 --- a/diffrax/step_size_controller/constant.py +++ b/diffrax/step_size_controller/constant.py @@ -28,10 +28,10 @@ def init( y0: PyTree, dt0: Optional[Scalar], args: PyTree, - func_for_init: Callable[[Scalar, PyTree, PyTree], PyTree], + func: Callable[[Scalar, PyTree, PyTree], PyTree], error_order: Optional[Scalar], ) -> Tuple[Scalar, Scalar]: - del terms, t1, y0, args, func_for_init, error_order + del terms, t1, y0, args, func, error_order if dt0 is None: raise ValueError( "Constant step size solvers cannot select step size automatically; " @@ -108,10 +108,10 @@ def init( y0: PyTree, dt0: None, args: PyTree, - func_for_init: Callable[[Scalar, PyTree, PyTree], PyTree], + func: Callable[[Scalar, PyTree, PyTree], PyTree], error_order: Optional[Scalar], ) -> Tuple[Scalar, int]: - del y0, args, func_for_init, error_order + del y0, args, func, error_order if dt0 is not None: raise ValueError( "`dt0` should be `None`. Step location is already determined " diff --git a/diffrax/term.py b/diffrax/term.py index d9281dfb..4a09d61c 100644 --- a/diffrax/term.py +++ b/diffrax/term.py @@ -136,39 +136,6 @@ def vf_prod(self, t: Scalar, y: PyTree, args: PyTree, control: PyTree) -> PyTree """ return self.prod(self.vf(t, y, args), control) - # This is a pinhole break in our vector-field/control abstraction. - # Everywhere else we get to evaluate over some interval, which allows us to - # evaluate our control over that interval. However to select the initial point in - # an adaptive step size scheme, the standard heuristic is to start by making - # evaluations at just the initial point -- no intervals involved. - def func_for_init(self, t: Scalar, y: PyTree, args: PyTree) -> PyTree: - """This is a special-cased version of [`diffrax.AbstractTerm.vf`][]. - - If it so happens that the PyTree structures $T$ and $S$ are the same, then a - subclass of `AbstractTerm` shoud set `func_for_init = vf`. - - This case is used when selecting the initial step size of an ODE solve - automatically. - - See [`diffrax.AbstractSolver.func_for_init`][]. - """ - - # Heuristic for whether it's safe to select an initial step automatically. - vf = self.vf(t, y, args) - flat_vf, tree_vf = jax.tree_flatten(vf) - flat_y, tree_y = jax.tree_flatten(y) - if tree_vf != tree_y or any( - jnp.shape(x) != jnp.shape(y) for x, y in zip(flat_vf, flat_y) - ): - raise ValueError( - "An initial step size cannot be selected automatically. The most " - "common scenario for this error to occur is when trying to use " - "adaptive step size solvers with SDEs, or with CDEs without " - "`ControlTerm(...).to_ode()`. Please specify an initial `dt0` instead." - ) - else: - return vf - def is_vf_expensive( self, t0: Scalar, diff --git a/docs/api/adjoints.md b/docs/api/adjoints.md index bafba014..39cef0d4 100644 --- a/docs/api/adjoints.md +++ b/docs/api/adjoints.md @@ -31,6 +31,10 @@ There are multiple ways to backpropagate through a differential equation (to com selection: members: false +::: diffrax.ImplicitAdjoint + selection: + members: false + ::: diffrax.BacksolveAdjoint selection: members: diff --git a/docs/api/events.md b/docs/api/events.md new file mode 100644 index 00000000..263a195c --- /dev/null +++ b/docs/api/events.md @@ -0,0 +1,24 @@ +# Events + +Events allow for interrupting a differential equation solve, and changing its internal state, or terminating the solve before `t1` is reached. + +At the moment a single kind of event is supported: discrete events which are checked at the end of every step, and which halt the integration once they become true. + +??? abstract "`diffrax.AbstractDiscreteTerminatingEvent`" + + ::: diffrax.AbstractDiscreteTerminatingEvent + selection: + members: + - __call__ + +--- + +::: diffrax.DiscreteTerminatingEvent + selection: + members: + - __init__ + +::: diffrax.SteadyStateEvent + selection: + members: + - __init__ diff --git a/docs/api/solver.md b/docs/api/solver.md deleted file mode 100644 index e240245c..00000000 --- a/docs/api/solver.md +++ /dev/null @@ -1,173 +0,0 @@ -# Solvers - -The complete list of solvers, categorised by type, is as follows. See also [How to choose a solver](../usage/how-to-choose-a-solver.md). - -!!! info "Term structure" - - The type of solver chosen determines how the `terms` argument of `diffeqsolve` should be laid out. Most of them demand that it should be a single `AbstractTerm`. But for example [`diffrax.SemiImplicitEuler`][] demands that it be a 2-tuple `(AbstractTerm, AbstractTerm)`, to represent the two vector fields that solver uses. - - If it is different from this default, then you can find the appropriate structure documented below, and available programmatically under `.term_structure`. - -!!! info "Stochastic differential equations" - - Little distinction is made between solvers for different kinds of differential equation, like between ODEs and SDEs. Diffrax's term system allows for treating them all in a unified way. - - Those ODE solvers that make sense as SDE solvers are documented as such below. For the common case of an SDE with drift and Brownian-motion-driven diffusion, they can be used by combining drift and diffusion into a single term: - - ```python - drift = lambda t, y, args: -y - diffusion = lambda t, y, args: y[..., None] - bm = UnsafeBrownianPath(shape=(1,), key=...) - terms = MultiTerm(ODETerm(drift), ControlTerm(diffusion, bm)) - diffeqsolve(terms, solver=Euler(), ...) - ``` - - In addition there are some [SDE-specific solvers](#sde-only-solvers). - - -??? abstract "`diffrax.AbstractSolver`" - - All of the classes implement the following interface specified by [`diffrax.AbstractSolver`][]. - - The exact details of this interface are only really useful if you're using the [Manual stepping](../usage/manual-stepping.md) interface or defining your own solvers; otherwise this is all just internal to the library. - - ::: diffrax.AbstractSolver - selection: - members: - - order - - strong_order - - error_order - - term_structure - - init - - step - - func_for_init - ---- - -### Explicit Runge--Kutta (ERK) methods - -::: diffrax.Euler - selection: - members: false - -::: diffrax.Heun - selection: - members: false - -::: diffrax.Midpoint - selection: - members: false - -::: diffrax.Ralston - selection: - members: false - -::: diffrax.Bosh3 - selection: - members: false - -::: diffrax.Tsit5 - selection: - members: false - -::: diffrax.Dopri5 - selection: - members: false - -::: diffrax.Dopri8 - selection: - members: false - ---- - -### Implicit Runge--Kutta (IRK) methods - -::: diffrax.ImplicitEuler - selection: - members: false - -::: diffrax.Kvaerno3 - selection: - members: false - -::: diffrax.Kvaerno4 - selection: - members: false - -::: diffrax.Kvaerno5 - selection: - members: false - ---- - -### Symplectic methods - -??? info "Term and state structure" - - The state of the system (the initial value of which is given by `y0` to [`diffrax.diffeqsolve`][]) must be a 2-tuple (of PyTrees). The terms (given by the value of `terms` to [`diffrax.diffeqsolve`][]) must be a 2-tuple of `AbstractTerms`. - - Letting `v, w = y0` and `f, g = terms`, then `v` is updated according to `f(t, w, args) * dt` and `w` is updated according to `g(t, v, args) * dt`. - - See also this [Wikipedia page](https://en.wikipedia.org/wiki/Semi-implicit_Euler_method#Setting). - -::: diffrax.SemiImplicitEuler - selection: - members: false - ---- - -### Reversible methods - -::: diffrax.ReversibleHeun - selection: - members: false - ---- - -### Linear multistep methods - -::: diffrax.LeapfrogMidpoint - selection: - members: false - ---- - -### SDE-only solvers - -!!! tip "Other SDE solvers" - - Many low-order ODE solvers can also be used as SDE solvers: - - **Itô:** - - - [`diffrax.Euler`][] - - **Stratonovich:** - - - [`diffrax.Heun`][] - - [`diffrax.Midpoint`][] - - [`diffrax.ReversibleHeun`][] - -!!! info "Term structure" - - For these SDE-specific solvers, the terms (given by the value of `terms` to [`diffrax.diffeqsolve`][]) must be a 2-tuple `(AbstractTerm, AbstractTerm)`, representing the drift and diffusion respectively. Typically that means `(ODETerm(...), ControlTerm(..., ...))`. - -::: diffrax.EulerHeun - selection: - members: false - -::: diffrax.ItoMilstein - selection: - members: false - -::: diffrax.StratonovichMilstein - selection: - members: false - ---- - -### Wrapper solvers - -::: diffrax.HalfSolver - selection: - members: false diff --git a/docs/api/solvers/abstract_solvers.md b/docs/api/solvers/abstract_solvers.md new file mode 100644 index 00000000..2942c989 --- /dev/null +++ b/docs/api/solvers/abstract_solvers.md @@ -0,0 +1,86 @@ +# Abstract solvers + +All of the solvers (both ODE and SDE solvers) implement the following interface specified by [`diffrax.AbstractSolver`][]. + +The exact details of this interface are only really useful if you're using the [Manual stepping](../../usage/manual-stepping.md) interface or defining your own solvers; otherwise this is all just internal to the library. + +Also see [Extending Diffrax](../../usage/extending.md) for more information on defining your own solvers. + +In addition [`diffrax.AbstractSolver`][] has several subclasses that you can use to mark your custom solver as exhibiting particular behaviour. + +--- + +::: diffrax.AbstractSolver + selection: + members: + - order + - strong_order + - error_order + - init + - step + - func + +--- + +::: diffrax.AbstractImplicitSolver + selection: + members: + - __init__ + +--- + +::: diffrax.AbstractAdaptiveSolver + selection: + members: false + +--- + +::: diffrax.AbstractItoSolver + selection: + members: false + +--- + +::: diffrax.AbstractStratonovichSolver + selection: + members: false + +--- + +::: diffrax.AbstractWrappedSolver + selection: + members: + - __init__ + +--- + +### Abstract Runge--Kutta solvers + +::: diffrax.AbstractRungeKutta + selection: + members: false + +::: diffrax.AbstractERK + selection: + members: false + +::: diffrax.AbstractDIRK + selection: + members: false + +::: diffrax.AbstractSDIRK + selection: + members: false + +::: diffrax.AbstractESDIRK + selection: + members: false + +::: diffrax.ButcherTableau + selection: + members: + - __init__ + +::: diffrax.CalculateJacobian + selection: + members: false diff --git a/docs/api/solvers/ode_solvers.md b/docs/api/solvers/ode_solvers.md new file mode 100644 index 00000000..04087a4c --- /dev/null +++ b/docs/api/solvers/ode_solvers.md @@ -0,0 +1,117 @@ +# ODE solvers + +See also [How to choose a solver](../../usage/how-to-choose-a-solver.md#ordinary-differential-equations). + +!!! info "Term structure" + + The type of solver chosen determines how the `terms` argument of `diffeqsolve` should be laid out. Most of them demand that it should be a single `AbstractTerm`. But for example [`diffrax.SemiImplicitEuler`][] demands that it be a 2-tuple `(AbstractTerm, AbstractTerm)`, to represent the two vector fields that solver uses. + + If it is different from this default, then you can find the appropriate structure documented below, and available programmatically under `.term_structure`. + +--- + +### Explicit Runge--Kutta (ERK) methods + +These methods are suitable for most problems. + +Each of these takes a `scan_stages` argument at initialisation, defaulting to `False`. Set to `True` to substantially improve compilation speed in return for a slight reduction in runtime speed. + +::: diffrax.Euler + selection: + members: false + +::: diffrax.Heun + selection: + members: false + +::: diffrax.Midpoint + selection: + members: false + +::: diffrax.Ralston + selection: + members: false + +::: diffrax.Bosh3 + selection: + members: false + +::: diffrax.Tsit5 + selection: + members: false + +::: diffrax.Dopri5 + selection: + members: false + +::: diffrax.Dopri8 + selection: + members: false + +--- + +### Implicit Runge--Kutta (IRK) methods + +These methods are suitable for stiff problems. + +Each of these takes a `scan_stages` argument at initialisation, which [behaves the same as for the explicit Runge--Kutta methods](#explicit-runge-kutta-erk-methods). In addition, each of these takes a `nonlinear_solver` argument at initialisation, defaulting to a Newton solver, which is used to solve the implicit problem at each step. See the page on [nonlinear solvers](../nonlinear_solver.md). + +::: diffrax.ImplicitEuler + selection: + members: false + +::: diffrax.Kvaerno3 + selection: + members: false + +::: diffrax.Kvaerno4 + selection: + members: false + +::: diffrax.Kvaerno5 + selection: + members: false + +--- + +### Symplectic methods + +These methods are suitable for problems with symplectic structure; that is to say those ODEs of the form + +$\frac{\mathrm{d}v}{\mathrm{d}t}(t) = f(t, w(t))$ + +$\frac{\mathrm{d}w}{\mathrm{d}t}(t) = g(t, v(t))$ + +In particular this includes Hamiltonian systems. + +??? info "Term and state structure" + + The state of the system (the initial value of which is given by `y0` to [`diffrax.diffeqsolve`][]) must be a 2-tuple (of PyTrees). The terms (given by the value of `terms` to [`diffrax.diffeqsolve`][]) must be a 2-tuple of `AbstractTerms`. + + Letting `v, w = y0` and `f, g = terms`, then `v` is updated according to `f(t, w, args)` and `w` is updated according to `g(t, v, args)`. + + See also this [Wikipedia page](https://en.wikipedia.org/wiki/Semi-implicit_Euler_method#Setting). + +::: diffrax.SemiImplicitEuler + selection: + members: false + +--- + +### Reversible methods + +These methods can be run "in reverse": solving from an initial condition `y0` to obtain some terminal value `y1`, it is possible to reconstruct `y0` from `y1` with zero truncation error. (There will still be a small amount of floating point error.) This can be done via `SaveAt(solver_state=True)` to save the final solver state, and then passing it as `diffeqsolve(..., solver_state=solver_state)` on the backwards-in-time pass. + +In addition all [symplectic methods](#symplectic-methods) are reversible, as are some linear multistep methods. (Below are the non-symplectic reversible solvers.) + +::: diffrax.ReversibleHeun + selection: + members: false + +--- + +### Linear multistep methods + +::: diffrax.LeapfrogMidpoint + selection: + members: false diff --git a/docs/api/solvers/sde_solvers.md b/docs/api/solvers/sde_solvers.md new file mode 100644 index 00000000..39b38039 --- /dev/null +++ b/docs/api/solvers/sde_solvers.md @@ -0,0 +1,84 @@ +# SDE solvers + +See also [How to choose a solver](../../usage/how-to-choose-a-solver.md#stochastic-differential-equations). + +!!! info "Term structure" + + The type of solver chosen determines how the `terms` argument of `diffeqsolve` should be laid out. Most of them operate in the same way whether they are solving an ODE or an SDE, and as such expected that it should be a single `AbstractTerm`. For SDEs that typically means a [`diffrax.MultiTerm`][] wrapping together a drift ([`diffrax.ODETerm`][]) and diffusion ([`diffrax.ControlTerm`][]). (Although you could also include any other term, e.g. an exogenous forcing term, if you wished.) For example: + + ```python + drift = lambda t, y, args: -y + diffusion = lambda t, y, args: y[..., None] + bm = UnsafeBrownianPath(shape=(1,), key=...) + terms = MultiTerm(ODETerm(drift), ControlTerm(diffusion, bm)) + diffeqsolve(terms, solver=Euler(), ...) + ``` + + Some solvers are SDE-specific. For these, such as for example [`diffrax.StratonovichMilstein`][], then `terms` should be a 2-tuple `(AbstractTerm, AbstractTerm)`, representing the drift and diffusion separately. + + For those SDE-specific solvers then this is documented below, and the term structure is available programmatically under `.term_structure`. + +--- + +### Explicit Runge--Kutta (ERK) methods + +Each of these takes a `scan_stages` argument at initialisation, which [behaves the same as as the explicit Runge--Kutta methods for ODEs](./ode_solvers.md#explicit-runge-kutta-erk-methods). + +::: diffrax.Euler + selection: + members: false + +::: diffrax.Heun + selection: + members: false + +::: diffrax.Midpoint + selection: + members: false + +::: diffrax.Ralston + selection: + members: false + +!!! info + + In addition to the solvers above, then most higher-order ODE solvers can actually also be used as SDE solvers. They will typically converge to the Stratonovich solution. In practice this is computationally wasteful as they will not obtain more accurate solutions when applied to SDEs. + +--- + +### Reversible methods + +These are reversible in the same way as when applied to ODEs. [See here.](./ode_solvers.md#reversible-methods) + +::: diffrax.ReversibleHeun + selection: + members: false + +--- + +### SDE-only solvers + +!!! info "Term structure" + + For these SDE-specific solvers, the terms (given by the value of `terms` to [`diffrax.diffeqsolve`][]) must be a 2-tuple `(AbstractTerm, AbstractTerm)`, representing the drift and diffusion respectively. Typically that means `(ODETerm(...), ControlTerm(..., ...))`. + +::: diffrax.EulerHeun + selection: + members: false + +::: diffrax.ItoMilstein + selection: + members: false + +::: diffrax.StratonovichMilstein + selection: + members: false + +--- + +### Wrapper solvers + +::: diffrax.HalfSolver + selection: + members: + - __init__ diff --git a/docs/api/stepsize_controller.md b/docs/api/stepsize_controller.md index 1667fe9c..cdd10b7d 100644 --- a/docs/api/stepsize_controller.md +++ b/docs/api/stepsize_controller.md @@ -2,8 +2,12 @@ The list of step size controllers is as follows. The most common cases are fixed step sizes with [`diffrax.ConstantStepSize`][] and adaptive step sizes with [`diffrax.PIDController`][]. +!!! warning -??? abstract "`diffrax.AbstractStepSizeController`" + To perform adaptive stepping with SDEs requires [commutative noise](../usage/how-to-choose-a-solver.md#stochastic-differential-equations). Note that this commutativity condition is not checked. + + +??? abstract "Abtract base classes" All of the classes implement the following interface specified by [`diffrax.AbstractStepSizeController`][]. @@ -17,6 +21,12 @@ The list of step size controllers is as follows. The most common cases are fixed - init - adapt_step_size + ::: diffrax.AbstractAdaptiveStepSizeController + selection: + members: + - rtol + - atol + --- ::: diffrax.ConstantStepSize @@ -28,23 +38,8 @@ The list of step size controllers is as follows. The most common cases are fixed selection: members: - __init__ - - ts ::: diffrax.PIDController selection: members: - __init__ - - pcoeff - - icoeff - - dcoeff - - rtol - - atol - - dtmin - - dtmax - - force_dtmin - - step_ts - - jump_ts - - factormin - - factormax - - norm - - safety diff --git a/docs/further_details/faq.md b/docs/further_details/faq.md index 626926a4..dfad2f7d 100644 --- a/docs/further_details/faq.md +++ b/docs/further_details/faq.md @@ -1,5 +1,9 @@ # FAQ +### Compilation is taking a long time. + +If you're using a Runge--Kutta method like [`diffrax.Dopri5`][] etc., then try setting `scan_stages=True` when initialisating the solver, for example `Dopri5(scan_stages=True)`. This will substantially reduce compile time at the expense of a slightly slower run time. + ### The solve is taking loads of steps / I'm getting NaN gradients / other weird behaviour. Try switching to 64-bit precision. (Instead of the 32-bit that is the default in JAX.) [See here](https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision). diff --git a/docs/requirements.txt b/docs/requirements.txt index 72780a04..9ffa1e5d 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -4,7 +4,7 @@ mkdocs-material==7.3.6 # Theme pymdown-extensions==9.4 # Markdown extensions e.g. to handle LaTeX. mkdocstrings==0.17.0 # Autogenerate documentation from docstrings. mknotebooks==0.7.1 # Turn Jupyter Lab notebooks into webpages. -pytkdocs_tweaks==0.0.4 # Tweaks mkdocstrings to improve various aspects +pytkdocs_tweaks==0.0.6 # Tweaks mkdocstrings to improve various aspects mkdocs_include_exclude_files==0.0.1 # Allow for customising which files get included jinja2==3.0.3 # Older version. After 3.1.0 seems to be incompatible with current versions of mkdocstrings. diff --git a/docs/usage/extending.md b/docs/usage/extending.md index e7d76d39..50866af0 100644 --- a/docs/usage/extending.md +++ b/docs/usage/extending.md @@ -14,7 +14,6 @@ The main points of extension are as follows: - `diffrax.AbstractImplicitSolver` are those solvers that solve implicit problems (and therefore take a `nonlinear_solver` argument). - `diffrax.AbstractAdaptiveSolver` are those solvers capable of providing error estimates (and thus can be used with adaptive step size controllers). - `diffrax.AbstractItoSolver` and `diffrax.AbstractStratonovichSolver` are used to specify which SDE solution a particular solver is known to converge to. - - `diffrax.AbstractAdaptiveSDESolver` are those solvers whose error estimates are suitable for SDEs. (That is, both the method and the embedded method both converge to the same choice of Itô/Stratonovich.) - `diffrax.AbstractWrappedSolver` indicates that the solver is used to wrap another solver, and so e.g. it will be treated as an implicit solver/etc. if the wrapped solver is also an implicit solver/etc. - **Custom step size controllers** should inherit from [`diffrax.AbstractStepSizeController`][]. diff --git a/docs/usage/getting-started.md b/docs/usage/getting-started.md index 8708540d..c690c52e 100644 --- a/docs/usage/getting-started.md +++ b/docs/usage/getting-started.md @@ -49,7 +49,7 @@ print(sol.ys) # DeviceArray([1. , 0.368, 0.135, 0.0498]) - The numerical solver (here `Dopri5`) can be switched out. - See the guide on [How to choose a solver](./how-to-choose-a-solver.md). - - See the [Solvers](../api/solver.md) page for the full list of solvers. + - See the [ODE solvers](../api/solvers/ode_solvers.md) page for the full list of solvers. - Where to save the result (e.g. to obtain dense output) can be adjusted by changing [`diffrax.SaveAt`][]. - Step sizes and locations can be changed. - The initial step size can be selected adaptively by setting `dt0=None`. @@ -92,6 +92,7 @@ print(sol.evaluate(1.1)) # DeviceArray(0.89436394) - The numerical solver used is `Euler()`. (Also known as Euler--Maruyama when applied to SDEs.) - There's no clever hackery behind the scenes: `Euler()` for an SDE simply works in exactly the same way as `Euler()` for an ODE -- we just need to specify the extra diffusion term. - This converges to an Itô SDE because of the choice of solver. (Whether an SDE solver converges to Itô or Stratonovich SDE is a property of the solver.) + - See the [SDE solvers](../api/solvers/sde_solvers.md) page for the full list of solvers. - The solution is saved densely -- a continuous path is the output. We can then evaluate it at any point in the interval; in this case `0.1`. - No step size controller is specified so by default a constant step size is used. diff --git a/docs/usage/how-to-choose-a-solver.md b/docs/usage/how-to-choose-a-solver.md index c0cad17e..ef086b6f 100644 --- a/docs/usage/how-to-choose-a-solver.md +++ b/docs/usage/how-to-choose-a-solver.md @@ -1,9 +1,9 @@ # How to choose a solver -The full list of solvers is available on the [Solvers](../api/solver.md) page. - ## Ordinary differential equations +The full list of ODE solvers is available on the [ODE solvers](../api/solvers/ode_solvers.md) page. + !!! info ODE problems are informally divided into "stiff" and "non-stiff" problems. "Stiffness" generally refers to how difficult an equation is to solve numerically. Non-stiff problems are quite common, and usually solved using straightforward techniques like explicit Runge--Kutta methods. Stiff problems usually require more computationally expensive techniques, like implicit Runge--Kutta methods. diff --git a/docs/usage/manual-stepping.md b/docs/usage/manual-stepping.md index 17182fac..4b494098 100644 --- a/docs/usage/manual-stepping.md +++ b/docs/usage/manual-stepping.md @@ -12,7 +12,7 @@ In the following example, we solve an ODE using [`diffrax.Tsit5`][], and print o !!! note - See the [Solvers](../api/solver.md) page for a reference on the solver methods (`init`, `step`) used here. + See the [Abstract solvers](../api/solvers/abstract_solvers.md) page for a reference on the solver methods (`init`, `step`) used here. ```python from diffrax import ODETerm, Tsit5 diff --git a/examples/steady_state.ipynb b/examples/steady_state.ipynb new file mode 100644 index 00000000..1cbd1653 --- /dev/null +++ b/examples/steady_state.ipynb @@ -0,0 +1,300 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "9877bdcb", + "metadata": {}, + "source": [ + "# Steady states" + ] + }, + { + "cell_type": "markdown", + "id": "3a16df75", + "metadata": {}, + "source": [ + "This example demonstrates how to use Diffrax to solve an ODE until it reaches a steady state. The key feature will be the use of event handling to detect that the steady state has been reached.\n", + "\n", + "In addition, for this example we need to backpropagate through the procedure of finding a steady state. We can do this efficiently using the implicit function theorem.\n", + "\n", + "This example is available as a Jupyter notebook [here](https://github.com/patrick-kidger/diffrax/blob/main/examples/steady_state.ipynb)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "7053a132", + "metadata": { + "execute": { + "shell": { + "execute_reply": "2022-07-15T17:49:27.190533+00:00" + } + }, + "iopub": { + "execute_input": "2022-07-15T17:46:56.174218+00:00", + "status": { + "busy": "2022-07-15T17:46:56.173283+00:00", + "idle": "2022-07-15T17:49:27.191890+00:00" + } + } + }, + "outputs": [], + "source": [ + "import diffrax\n", + "import equinox as eqx # https://github.com/patrick-kidger/equinox\n", + "import jax.numpy as jnp\n", + "import optax # https://github.com/deepmind/optax" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "5a39737f", + "metadata": { + "execute": { + "shell": { + "execute_reply": "2022-07-15T17:49:27.200260+00:00" + } + }, + "iopub": { + "execute_input": "2022-07-15T17:49:27.194682+00:00", + "status": { + "busy": "2022-07-15T17:49:27.194211+00:00", + "idle": "2022-07-15T17:49:27.201694+00:00" + } + } + }, + "outputs": [], + "source": [ + "class ExponentialDecayToSteadyState(eqx.Module):\n", + " steady_state: float\n", + "\n", + " def __call__(self, t, y, args):\n", + " return self.steady_state - y" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "525b5430", + "metadata": { + "execute": { + "shell": { + "execute_reply": "2022-07-15T17:49:27.210168+00:00" + } + }, + "iopub": { + "execute_input": "2022-07-15T17:49:27.203780+00:00", + "status": { + "busy": "2022-07-15T17:49:27.202937+00:00", + "idle": "2022-07-15T17:49:27.211528+00:00" + } + } + }, + "outputs": [], + "source": [ + "def loss(model, target_steady_state):\n", + " term = diffrax.ODETerm(model)\n", + " solver = diffrax.Tsit5()\n", + " t0 = 0\n", + " t1 = jnp.inf\n", + " dt0 = None\n", + " y0 = 1.0\n", + " max_steps = None\n", + " controller = diffrax.PIDController(rtol=1e-3, atol=1e-6)\n", + " event = diffrax.SteadyStateEvent()\n", + " adjoint = diffrax.ImplicitAdjoint()\n", + " # This combination of event, t1, max_steps, adjoint is particularly\n", + " # natural: we keep integration forever until we hit the event, with\n", + " # no maximum time or number of steps. Backpropagation happens via\n", + " # the implicit function theorem.\n", + " sol = diffrax.diffeqsolve(\n", + " term,\n", + " solver,\n", + " t0,\n", + " t1,\n", + " dt0,\n", + " y0,\n", + " max_steps=max_steps,\n", + " stepsize_controller=controller,\n", + " discrete_terminating_event=event,\n", + " adjoint=adjoint,\n", + " )\n", + " (y1,) = sol.ys\n", + " return (y1 - target_steady_state) ** 2" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "ec634466", + "metadata": { + "execute": { + "shell": { + "execute_reply": "2022-07-15T17:49:28.926507+00:00" + } + }, + "iopub": { + "execute_input": "2022-07-15T17:49:27.214972+00:00", + "status": { + "busy": "2022-07-15T17:49:27.214240+00:00", + "idle": "2022-07-15T17:49:28.927742+00:00" + } + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Step: 0 Steady State: 0.025839969515800476\n", + "Step: 1 Steady State: 0.058249037712812424\n", + "Step: 2 Steady State: 0.09451574087142944\n", + "Step: 3 Steady State: 0.13270404934883118\n", + "Step: 4 Steady State: 0.17144456505775452\n", + "Step: 5 Steady State: 0.2097906768321991\n", + "Step: 6 Steady State: 0.24709917604923248\n", + "Step: 7 Steady State: 0.28294336795806885\n", + "Step: 8 Steady State: 0.3170691728591919\n", + "Step: 9 Steady State: 0.34933507442474365\n", + "Step: 10 Steady State: 0.37968066334724426\n", + "Step: 11 Steady State: 0.4081019163131714\n", + "Step: 12 Steady State: 0.43463483452796936\n", + "Step: 13 Steady State: 0.45934173464775085\n", + "Step: 14 Steady State: 0.4823019802570343\n", + "Step: 15 Steady State: 0.5035936236381531\n", + "Step: 16 Steady State: 0.5233209133148193\n", + "Step: 17 Steady State: 0.5415788888931274\n", + "Step: 18 Steady State: 0.5584676265716553\n", + "Step: 19 Steady State: 0.5740787982940674\n", + "Step: 20 Steady State: 0.5885017514228821\n", + "Step: 21 Steady State: 0.6018210053443909\n", + "Step: 22 Steady State: 0.6141175627708435\n", + "Step: 23 Steady State: 0.6254667043685913\n", + "Step: 24 Steady State: 0.6359376907348633\n", + "Step: 25 Steady State: 0.6455990076065063\n", + "Step: 26 Steady State: 0.6545112729072571\n", + "Step: 27 Steady State: 0.6627309322357178\n", + "Step: 28 Steady State: 0.6703115701675415\n", + "Step: 29 Steady State: 0.6773026585578918\n", + "Step: 30 Steady State: 0.6837494373321533\n", + "Step: 31 Steady State: 0.6896938681602478\n", + "Step: 32 Steady State: 0.6951748728752136\n", + "Step: 33 Steady State: 0.7002284526824951\n", + "Step: 34 Steady State: 0.7048872113227844\n", + "Step: 35 Steady State: 0.7091819047927856\n", + "Step: 36 Steady State: 0.7131412029266357\n", + "Step: 37 Steady State: 0.7167739868164062\n", + "Step: 38 Steady State: 0.7201183438301086\n", + "Step: 39 Steady State: 0.7231980562210083\n", + "Step: 40 Steady State: 0.7260348796844482\n", + "Step: 41 Steady State: 0.7286462187767029\n", + "Step: 42 Steady State: 0.7310511469841003\n", + "Step: 43 Steady State: 0.733269989490509\n", + "Step: 44 Steady State: 0.7353137731552124\n", + "Step: 45 Steady State: 0.7371994853019714\n", + "Step: 46 Steady State: 0.7389383912086487\n", + "Step: 47 Steady State: 0.740541934967041\n", + "Step: 48 Steady State: 0.7420334219932556\n", + "Step: 49 Steady State: 0.7434003353118896\n", + "Step: 50 Steady State: 0.7446598410606384\n", + "Step: 51 Steady State: 0.7458205819129944\n", + "Step: 52 Steady State: 0.7468900680541992\n", + "Step: 53 Steady State: 0.7478761672973633\n", + "Step: 54 Steady State: 0.7487852573394775\n", + "Step: 55 Steady State: 0.7496234178543091\n", + "Step: 56 Steady State: 0.750394344329834\n", + "Step: 57 Steady State: 0.7511063814163208\n", + "Step: 58 Steady State: 0.751763105392456\n", + "Step: 59 Steady State: 0.7523672580718994\n", + "Step: 60 Steady State: 0.7529228329658508\n", + "Step: 61 Steady State: 0.753433346748352\n", + "Step: 62 Steady State: 0.7539049983024597\n", + "Step: 63 Steady State: 0.7543382048606873\n", + "Step: 64 Steady State: 0.7547407746315002\n", + "Step: 65 Steady State: 0.7551127672195435\n", + "Step: 66 Steady State: 0.7554563879966736\n", + "Step: 67 Steady State: 0.7557693123817444\n", + "Step: 68 Steady State: 0.7560611367225647\n", + "Step: 69 Steady State: 0.7563308477401733\n", + "Step: 70 Steady State: 0.7565800547599792\n", + "Step: 71 Steady State: 0.756810188293457\n", + "Step: 72 Steady State: 0.7570226788520813\n", + "Step: 73 Steady State: 0.7572163343429565\n", + "Step: 74 Steady State: 0.7573966979980469\n", + "Step: 75 Steady State: 0.7575633525848389\n", + "Step: 76 Steady State: 0.7577127814292908\n", + "Step: 77 Steady State: 0.7578537464141846\n", + "Step: 78 Steady State: 0.7579842805862427\n", + "Step: 79 Steady State: 0.7581048607826233\n", + "Step: 80 Steady State: 0.7582123279571533\n", + "Step: 81 Steady State: 0.7583134770393372\n", + "Step: 82 Steady State: 0.7584078907966614\n", + "Step: 83 Steady State: 0.7584953904151917\n", + "Step: 84 Steady State: 0.758575975894928\n", + "Step: 85 Steady State: 0.7586501836776733\n", + "Step: 86 Steady State: 0.7587193250656128\n", + "Step: 87 Steady State: 0.7587832808494568\n", + "Step: 88 Steady State: 0.7588424682617188\n", + "Step: 89 Steady State: 0.7588958144187927\n", + "Step: 90 Steady State: 0.7589460015296936\n", + "Step: 91 Steady State: 0.7589924931526184\n", + "Step: 92 Steady State: 0.7590354681015015\n", + "Step: 93 Steady State: 0.7590752243995667\n", + "Step: 94 Steady State: 0.7591111063957214\n", + "Step: 95 Steady State: 0.7591448426246643\n", + "Step: 96 Steady State: 0.7591760754585266\n", + "Step: 97 Steady State: 0.7592049241065979\n", + "Step: 98 Steady State: 0.7592315673828125\n", + "Step: 99 Steady State: 0.7592562437057495\n", + "Target: 0.7599999904632568\n" + ] + } + ], + "source": [ + "model = ExponentialDecayToSteadyState(\n", + " jnp.array(0.0)\n", + ") # initial steady state guess is 0.\n", + "# target steady state is 0.76\n", + "target_steady_state = jnp.array(0.76)\n", + "optim = optax.sgd(1e-2, momentum=0.7, nesterov=True)\n", + "opt_state = optim.init(model)\n", + "\n", + "\n", + "@eqx.filter_jit\n", + "def make_step(model, opt_state, target_steady_state):\n", + " grads = eqx.filter_grad(loss)(model, target_steady_state)\n", + " updates, opt_state = optim.update(grads, opt_state)\n", + " model = eqx.apply_updates(model, updates)\n", + " return model, opt_state\n", + "\n", + "\n", + "for step in range(100):\n", + " model, opt_state = make_step(model, opt_state, target_steady_state)\n", + " print(f\"Step: {step} Steady State: {model.steady_state}\")\n", + "print(f\"Target: {target_steady_state}\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "py37", + "language": "python", + "name": "py37" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.13" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/mkdocs.yml b/mkdocs.yml index 0df0ecc2..ea7612d2 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -105,15 +105,20 @@ nav: - Continuous Normalising Flow: 'examples/continuous_normalising_flow.ipynb' - Symbolic Regression: 'examples/symbolic_regression.ipynb' - Stiff ODE: 'examples/stiff_ode.ipynb' + - Steady State: 'examples/steady_state.ipynb' - Basic API: - 'api/type_terminology.md' - 'api/diffeqsolve.md' - - 'api/solver.md' + - Solvers: + - 'api/solvers/ode_solvers.md' + - 'api/solvers/sde_solvers.md' + - 'api/solvers/abstract_solvers.md' - 'api/saveat.md' - 'api/stepsize_controller.md' - 'api/solution.md' - Advanced API: - 'api/adjoints.md' + - 'api/events.md' - 'api/terms.md' - 'api/path.md' - 'api/interpolation.md' diff --git a/test/helpers.py b/test/helpers.py index a4eb7fbe..27c021cd 100644 --- a/test/helpers.py +++ b/test/helpers.py @@ -4,39 +4,51 @@ import time import diffrax +import equinox as eqx import jax import jax.numpy as jnp import jax.random as jrandom all_ode_solvers = ( - (diffrax.Bosh3, dict(scan_stages=False)), - (diffrax.Bosh3, dict(scan_stages=True)), - (diffrax.Dopri5, dict(scan_stages=False)), - (diffrax.Dopri5, dict(scan_stages=True)), - (diffrax.Dopri8, dict(scan_stages=False)), - (diffrax.Dopri8, dict(scan_stages=True)), - (diffrax.Euler, {}), - (diffrax.Ralston, dict(scan_stages=False)), - (diffrax.Ralston, dict(scan_stages=True)), - (diffrax.Midpoint, dict(scan_stages=False)), - (diffrax.Midpoint, dict(scan_stages=True)), - (diffrax.Heun, dict(scan_stages=False)), - (diffrax.Heun, dict(scan_stages=True)), - (diffrax.LeapfrogMidpoint, {}), - (diffrax.ReversibleHeun, {}), - (diffrax.Tsit5, dict(scan_stages=False)), - (diffrax.Tsit5, dict(scan_stages=True)), - (diffrax.ImplicitEuler, {}), - (diffrax.Kvaerno3, dict(scan_stages=False)), - (diffrax.Kvaerno3, dict(scan_stages=True)), - (diffrax.Kvaerno4, dict(scan_stages=False)), - (diffrax.Kvaerno4, dict(scan_stages=True)), - (diffrax.Kvaerno5, dict(scan_stages=False)), - (diffrax.Kvaerno5, dict(scan_stages=True)), + diffrax.Bosh3(scan_stages=False), + diffrax.Bosh3(scan_stages=True), + diffrax.Dopri5(scan_stages=False), + diffrax.Dopri5(scan_stages=True), + diffrax.Dopri8(scan_stages=False), + diffrax.Dopri8(scan_stages=True), + diffrax.Euler(), + diffrax.Ralston(scan_stages=False), + diffrax.Ralston(scan_stages=True), + diffrax.Midpoint(scan_stages=False), + diffrax.Midpoint(scan_stages=True), + diffrax.Heun(scan_stages=False), + diffrax.Heun(scan_stages=True), + diffrax.LeapfrogMidpoint(), + diffrax.ReversibleHeun(), + diffrax.Tsit5(scan_stages=False), + diffrax.Tsit5(scan_stages=True), + diffrax.ImplicitEuler(), + diffrax.Kvaerno3(scan_stages=False), + diffrax.Kvaerno3(scan_stages=True), + diffrax.Kvaerno4(scan_stages=False), + diffrax.Kvaerno4(scan_stages=True), + diffrax.Kvaerno5(scan_stages=False), + diffrax.Kvaerno5(scan_stages=True), ) +def implicit_tol(solver): + if isinstance(solver, diffrax.AbstractImplicitSolver): + return eqx.tree_at( + lambda s: (s.nonlinear_solver.rtol, s.nonlinear_solver.atol), + solver, + (1e-3, 1e-6), + is_leaf=lambda x: x is None, + ) + return solver + + def random_pytree(key, treedef): keys = jrandom.split(key, treedef.num_leaves) leaves = [] diff --git a/test/test_adjoint.py b/test/test_adjoint.py index fe1a6283..ed5befaa 100644 --- a/test/test_adjoint.py +++ b/test/test_adjoint.py @@ -1,10 +1,12 @@ import math +from typing import Any import diffrax import equinox as eqx import jax import jax.numpy as jnp import jax.random as jrandom +import optax import pytest from .helpers import shaped_allclose @@ -193,3 +195,56 @@ def run(model): return jnp.sum(sol.ys) run(mlp) + + +def test_implicit(): + class ExponentialDecayToSteadyState(eqx.Module): + steady_state: float + non_jax_type: Any + + def __call__(self, t, y, args): + return self.steady_state - y + + def loss(model, target_steady_state): + term = diffrax.ODETerm(model) + solver = diffrax.Tsit5() + t0 = 0 + t1 = jnp.inf + dt0 = None + y0 = 1.0 + max_steps = None + controller = diffrax.PIDController(rtol=1e-3, atol=1e-6) + event = diffrax.SteadyStateEvent() + adjoint = diffrax.ImplicitAdjoint() + sol = diffrax.diffeqsolve( + term, + solver, + t0, + t1, + dt0, + y0, + max_steps=max_steps, + stepsize_controller=controller, + discrete_terminating_event=event, + adjoint=adjoint, + ) + (y1,) = sol.ys + return (y1 - target_steady_state) ** 2 + + model = ExponentialDecayToSteadyState(jnp.array(0.0), object()) + target_steady_state = jnp.array(0.76) + optim = optax.sgd(1e-2, momentum=0.7, nesterov=True) + opt_state = optim.init(eqx.filter(model, eqx.is_array)) + + @eqx.filter_jit + def make_step(model, opt_state, target_steady_state): + grads = eqx.filter_grad(loss)(model, target_steady_state) + updates, opt_state = optim.update(grads, opt_state) + model = eqx.apply_updates(model, updates) + return model, opt_state + + for step in range(100): + model, opt_state = make_step(model, opt_state, target_steady_state) + assert shaped_allclose( + model.steady_state, target_steady_state, rtol=1e-2, atol=1e-2 + ) diff --git a/test/test_detest.py b/test/test_detest.py index d56e3723..5d5e9962 100644 --- a/test/test_detest.py +++ b/test/test_detest.py @@ -18,7 +18,7 @@ import pytest import scipy.integrate as integrate -from .helpers import all_ode_solvers, shaped_allclose +from .helpers import all_ode_solvers, implicit_tol, shaped_allclose # @@ -355,47 +355,47 @@ def diffeq(t, y, args): return diffeq, init -@pytest.mark.parametrize("solver_ctr,solver_kwargs", all_ode_solvers) -def test_a(solver_ctr, solver_kwargs): - if solver_ctr in (diffrax.Euler, diffrax.ImplicitEuler): +@pytest.mark.parametrize("solver", all_ode_solvers) +def test_a(solver): + if isinstance(solver, (diffrax.Euler, diffrax.ImplicitEuler)): # Euler is pretty bad at solving things, so only do some simple tests. - _test(solver_ctr, solver_kwargs, [_a1, _a2], higher=False) + _test(solver, [_a1, _a2], higher=False) else: - _test(solver_ctr, solver_kwargs, [_a1, _a2, _a3, _a4, _a5], higher=False) + _test(solver, [_a1, _a2, _a3, _a4, _a5], higher=False) -@pytest.mark.parametrize("solver_ctr,solver_kwargs", all_ode_solvers) -def test_b(solver_ctr, solver_kwargs): - _test(solver_ctr, solver_kwargs, [_b1, _b2, _b3, _b4, _b5], higher=True) +@pytest.mark.parametrize("solver", all_ode_solvers) +def test_b(solver): + _test(solver, [_b1, _b2, _b3, _b4, _b5], higher=True) -@pytest.mark.parametrize("solver_ctr,solver_kwargs", all_ode_solvers) -def test_c(solver_ctr, solver_kwargs): - _test(solver_ctr, solver_kwargs, [_c1, _c2, _c3, _c4, _c5], higher=True) +@pytest.mark.parametrize("solver", all_ode_solvers) +def test_c(solver): + _test(solver, [_c1, _c2, _c3, _c4, _c5], higher=True) -@pytest.mark.parametrize("solver_ctr,solver_kwargs", all_ode_solvers) -def test_d(solver_ctr, solver_kwargs): - _test(solver_ctr, solver_kwargs, [_d1, _d2, _d3, _d4, _d5], higher=True) +@pytest.mark.parametrize("solver", all_ode_solvers) +def test_d(solver): + _test(solver, [_d1, _d2, _d3, _d4, _d5], higher=True) -@pytest.mark.parametrize("solver_ctr,solver_kwargs", all_ode_solvers) -def test_e(solver_ctr, solver_kwargs): - _test(solver_ctr, solver_kwargs, [_e1, _e2, _e3, _e4, _e5], higher=True) +@pytest.mark.parametrize("solver", all_ode_solvers) +def test_e(solver): + _test(solver, [_e1, _e2, _e3, _e4, _e5], higher=True) -def _test(solver_ctr, solver_kwargs, problems, higher): +def _test(solver, problems, higher): for problem in problems: vector_field, init = problem() term = diffrax.ODETerm(vector_field) - solver = solver_ctr(**solver_kwargs) if higher and solver.order(term) < 4: # Too difficult to get accurate solutions with a low-order solver return max_steps = 16**4 - if not issubclass(solver_ctr, diffrax.AbstractAdaptiveSolver): + if not isinstance(solver, diffrax.AbstractAdaptiveSolver): + solver = implicit_tol(solver) dt0 = 0.01 - if solver_ctr is diffrax.LeapfrogMidpoint: + if type(solver) is diffrax.LeapfrogMidpoint: # This is an *awful* long-time-horizon solver. # It gets decent results to begin with, but then the oscillations # build up by t=20. @@ -403,7 +403,7 @@ def _test(solver_ctr, solver_kwargs, problems, higher): dt0 = 0.000001 max_steps = 20_000_001 stepsize_controller = diffrax.ConstantStepSize() - elif solver_ctr is diffrax.ReversibleHeun and problem is _a1: + elif type(solver) is diffrax.ReversibleHeun and problem is _a1: # ReversibleHeun is a bit like LeapfrogMidpoint, and therefore bad over # long time horizons. (It develops very large oscillations over long time # horizons.) diff --git a/test/test_event.py b/test/test_event.py new file mode 100644 index 00000000..3129be6c --- /dev/null +++ b/test/test_event.py @@ -0,0 +1,33 @@ +import diffrax +import jax.numpy as jnp + + +def test_discrete_terminate1(): + term = diffrax.ODETerm(lambda t, y, args: y) + solver = diffrax.Tsit5() + t0 = 0 + t1 = jnp.inf + dt0 = 1 + y0 = 1.0 + event = diffrax.DiscreteTerminatingEvent(lambda state, **kwargs: state.y > 10) + sol = diffrax.diffeqsolve( + term, solver, t0, t1, dt0, y0, discrete_terminating_event=event + ) + assert jnp.all(sol.ys > 10) + + +def test_discrete_terminate2(): + term = diffrax.ODETerm(lambda t, y, args: y) + solver = diffrax.Tsit5() + t0 = 0 + t1 = jnp.inf + dt0 = 1 + y0 = 1.0 + event = diffrax.DiscreteTerminatingEvent(lambda state, **kwargs: state.tprev > 10) + sol = diffrax.diffeqsolve( + term, solver, t0, t1, dt0, y0, discrete_terminating_event=event + ) + assert jnp.all(sol.ts > 10) + + +# diffrax.SteadyStateEvent tested as part of test_adjoint.py::test_implicit diff --git a/test/test_global_interpolation.py b/test/test_global_interpolation.py index 89664216..6a9a1e0c 100644 --- a/test/test_global_interpolation.py +++ b/test/test_global_interpolation.py @@ -7,7 +7,7 @@ import jax.random as jrandom import pytest -from .helpers import all_ode_solvers, shaped_allclose +from .helpers import all_ode_solvers, implicit_tol, shaped_allclose @pytest.mark.parametrize("mode", ["linear", "linear2", "cubic"]) @@ -311,12 +311,12 @@ def _test(firstderiv, derivs, y0, y1): jax.tree_map(_test, firstderiv, derivs, y0, y1) -def _test_dense_interpolation(solver_ctr, solver_kwargs, key, t1): +def _test_dense_interpolation(solver, key, t1): y0 = jrandom.uniform(key, (), minval=0.4, maxval=2) dt0 = t1 / 1e3 sol = diffrax.diffeqsolve( diffrax.ODETerm(lambda t, y, args: -y), - solver=solver_ctr(**solver_kwargs), + solver=solver, t0=0, t1=t1, dt0=dt0, @@ -333,18 +333,17 @@ def _test_dense_interpolation(solver_ctr, solver_kwargs, key, t1): return vals, true_vals, derivs, true_derivs -@pytest.mark.parametrize("solver_ctr,solver_kwargs", all_ode_solvers) -def test_dense_interpolation(solver_ctr, solver_kwargs, getkey): +@pytest.mark.parametrize("solver", all_ode_solvers) +def test_dense_interpolation(solver, getkey): + solver = implicit_tol(solver) key = jrandom.PRNGKey(5678) - vals, true_vals, derivs, true_derivs = _test_dense_interpolation( - solver_ctr, solver_kwargs, key, 1 - ) + vals, true_vals, derivs, true_derivs = _test_dense_interpolation(solver, key, 1) assert jnp.array_equal(vals[0], true_vals[0]) val_tol = { diffrax.Euler: 1e-3, diffrax.ImplicitEuler: 1e-3, diffrax.LeapfrogMidpoint: 1e-5, - }.get(solver_ctr, 1e-6) + }.get(type(solver), 1e-6) assert shaped_allclose(vals, true_vals, atol=val_tol, rtol=val_tol) deriv_tol = { diffrax.ReversibleHeun: 1e-2, @@ -353,17 +352,18 @@ def test_dense_interpolation(solver_ctr, solver_kwargs, getkey): diffrax.Euler: 1e-3, diffrax.ImplicitEuler: 1e-3, diffrax.Ralston: 1e-3, - }.get(solver_ctr, 1e-6) + }.get(type(solver), 1e-6) assert shaped_allclose(derivs, true_derivs, atol=deriv_tol, rtol=deriv_tol) # When vmap'ing then it can happen that some batch elements take more steps to solve # than others. This means some padding is used to make things line up; here we test # that all of this works as intended. -@pytest.mark.parametrize("solver_ctr,solver_kwargs", all_ode_solvers) -def test_dense_interpolation_vmap(solver_ctr, solver_kwargs, getkey): +@pytest.mark.parametrize("solver", all_ode_solvers) +def test_dense_interpolation_vmap(solver, getkey): + solver = implicit_tol(solver) key = jrandom.PRNGKey(5678) - _test_dense = ft.partial(_test_dense_interpolation, solver_ctr, solver_kwargs, key) + _test_dense = ft.partial(_test_dense_interpolation, solver, key) _test_dense_vmap = jax.vmap(_test_dense) vals, true_vals, derivs, true_derivs = _test_dense_vmap(jnp.array([0.5, 1.0])) assert jnp.array_equal(vals[:, 0], true_vals[:, 0]) @@ -371,7 +371,7 @@ def test_dense_interpolation_vmap(solver_ctr, solver_kwargs, getkey): diffrax.Euler: 1e-3, diffrax.ImplicitEuler: 1e-3, diffrax.LeapfrogMidpoint: 1e-5, - }.get(solver_ctr, 1e-6) + }.get(type(solver), 1e-6) assert shaped_allclose(vals, true_vals, atol=val_tol, rtol=val_tol) deriv_tol = { diffrax.ReversibleHeun: 1e-2, @@ -380,5 +380,5 @@ def test_dense_interpolation_vmap(solver_ctr, solver_kwargs, getkey): diffrax.Euler: 1e-3, diffrax.ImplicitEuler: 1e-3, diffrax.Ralston: 1e-3, - }.get(solver_ctr, 1e-6) + }.get(type(solver), 1e-6) assert shaped_allclose(derivs, true_derivs, atol=deriv_tol, rtol=deriv_tol) diff --git a/test/test_integrate.py b/test/test_integrate.py index 6f04791c..46e64841 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -10,7 +10,13 @@ import scipy.stats from diffrax.misc import ω -from .helpers import all_ode_solvers, random_pytree, shaped_allclose, treedefs +from .helpers import ( + all_ode_solvers, + implicit_tol, + random_pytree, + shaped_allclose, + treedefs, +) def _all_pairs(*args): @@ -32,16 +38,20 @@ def _all_pairs(*args): @pytest.mark.parametrize( - "solver_ctr,t_dtype,treedef,stepsize_controller", + "solver,t_dtype,treedef,stepsize_controller", _all_pairs( dict( - default=diffrax.Euler, + default=diffrax.Euler(), opts=( - diffrax.LeapfrogMidpoint, - diffrax.ReversibleHeun, - diffrax.Tsit5, - diffrax.ImplicitEuler, - diffrax.Kvaerno3, + diffrax.LeapfrogMidpoint(), + diffrax.ReversibleHeun(), + diffrax.Tsit5(), + diffrax.ImplicitEuler( + nonlinear_solver=diffrax.NewtonNonlinearSolver(rtol=1e-3, atol=1e-6) + ), + diffrax.Kvaerno3( + nonlinear_solver=diffrax.NewtonNonlinearSolver(rtol=1e-3, atol=1e-6) + ), ), ), dict(default=jnp.float32, opts=(int, float, jnp.int32)), @@ -52,8 +62,8 @@ def _all_pairs(*args): ), ), ) -def test_basic(solver_ctr, t_dtype, treedef, stepsize_controller, getkey): - if not issubclass(solver_ctr, diffrax.AbstractAdaptiveSolver) and isinstance( +def test_basic(solver, t_dtype, treedef, stepsize_controller, getkey): + if not isinstance(solver, diffrax.AbstractAdaptiveSolver) and isinstance( stepsize_controller, diffrax.PIDController ): return @@ -83,7 +93,7 @@ def f(t, y, args): try: sol = diffrax.diffeqsolve( diffrax.ODETerm(f), - solver_ctr(), + solver, t0, t1, dt0, @@ -104,8 +114,9 @@ def f(t, y, args): assert shaped_allclose(y1, true_y1, atol=1e-2, rtol=1e-2) -@pytest.mark.parametrize("solver_ctr,solver_kwargs", all_ode_solvers) -def test_ode_order(solver_ctr, solver_kwargs): +@pytest.mark.parametrize("solver", all_ode_solvers) +def test_ode_order(solver): + solver = implicit_tol(solver) key = jrandom.PRNGKey(5678) akey, ykey = jrandom.split(key, 2) @@ -115,7 +126,6 @@ def f(t, y, args): return A @ y term = diffrax.ODETerm(f) - solver = solver_ctr(**solver_kwargs) t0 = 0 t1 = 4 y0 = jrandom.normal(ykey, (10,), dtype=jnp.float64) @@ -522,7 +532,17 @@ def test_compile_time_steps(): assert shaped_allclose(sol.stats["compiled_num_steps"], jnp.array([20, 20])) -@pytest.mark.parametrize("solver", [diffrax.ImplicitEuler(), diffrax.Kvaerno5()]) +@pytest.mark.parametrize( + "solver", + [ + diffrax.ImplicitEuler( + nonlinear_solver=diffrax.NewtonNonlinearSolver(rtol=1e-3, atol=1e-6) + ), + diffrax.Kvaerno5( + nonlinear_solver=diffrax.NewtonNonlinearSolver(rtol=1e-3, atol=1e-6) + ), + ], +) def test_grad_implicit_solve(solver): # Check that we work around JAX issue #9374 # Whilst we're at -- for efficiency -- check the use of PyTree-valued state with diff --git a/test/test_interpolation.py b/test/test_interpolation.py index d14c9e72..2c280579 100644 --- a/test/test_interpolation.py +++ b/test/test_interpolation.py @@ -3,7 +3,7 @@ import jax.numpy as jnp import jax.random as jrandom -from .helpers import all_ode_solvers, shaped_allclose +from .helpers import all_ode_solvers, implicit_tol, shaped_allclose def _test_path_derivative(path, name): @@ -54,13 +54,12 @@ def test_derivative(getkey): ) paths.append((local_linear_interp, "local linear", ys[0], ys[-1])) - for solver_ctr, solver_kwargs in all_ode_solvers: - if solver_ctr is diffrax.Tsit5: - continue + for solver in all_ode_solvers: + solver = implicit_tol(solver) y0 = jrandom.normal(getkey(), (3,)) solution = diffrax.diffeqsolve( diffrax.ODETerm(lambda t, y, p: -y), - solver_ctr(**solver_kwargs), + solver, 0, 1, 0.01, @@ -68,7 +67,7 @@ def test_derivative(getkey): saveat=diffrax.SaveAt(dense=True, t1=True), ) y1 = solution.ys[-1] - paths.append((solution, solver_ctr.__name__, y0, y1)) + paths.append((solution, type(solver).__name__, y0, y1)) # actually do tests diff --git a/test/test_newton_solver.py b/test/test_newton_solver.py index f8449cf2..38a3acf8 100644 --- a/test/test_newton_solver.py +++ b/test/test_newton_solver.py @@ -20,7 +20,7 @@ def _fn2(x, args): return a - b, b -# Nontrivial nteractions between inputs +# Nontrivial interactions between inputs @jax.jit def _fn3(x, args): mlp = eqx.nn.MLP(4, 4, 256, 2, key=jrandom.PRNGKey(678))