diff --git a/diffrax/_solver/tsit5.py b/diffrax/_solver/tsit5.py index ebd96539..44b21986 100644 --- a/diffrax/_solver/tsit5.py +++ b/diffrax/_solver/tsit5.py @@ -1,6 +1,7 @@ from collections.abc import Callable from typing import ClassVar, Optional +import jax import jax.numpy as jnp import numpy as np from equinox.internal import ω @@ -147,10 +148,11 @@ def evaluate( ) b6 = -34.87065786149660974 * (t - 1.2) * (t - 0.666666666666666667) * t**2 b7 = 2.5 * (t - 1) * (t - 0.6) * t**2 - return ( - self.y0**ω - + vector_tree_dot(jnp.stack([b1, b2, b3, b4, b5, b6, b7]), self.k) ** ω - ).ω + with jax.numpy_dtype_promotion("standard"): + return ( + self.y0**ω + + vector_tree_dot(jnp.stack([b1, b2, b3, b4, b5, b6, b7]), self.k) ** ω + ).ω class Tsit5(AbstractERK): diff --git a/diffrax/_step_size_controller/adaptive.py b/diffrax/_step_size_controller/adaptive.py index 5adb147f..343b5920 100644 --- a/diffrax/_step_size_controller/adaptive.py +++ b/diffrax/_step_size_controller/adaptive.py @@ -4,6 +4,7 @@ import equinox as eqx import equinox.internal as eqxi +import jax import jax.lax as lax import jax.numpy as jnp import jax.tree_util as jtu @@ -18,6 +19,7 @@ from equinox import AbstractVar from equinox.internal import ω from jaxtyping import Array, PyTree +from lineax.internal import complex_to_real_dtype from .._custom_types import ( Args, @@ -460,8 +462,8 @@ def init( jump_next_step, at_dtmin, dt0, - jnp.array(1.0, dtype=y_dtype), - jnp.array(1.0, dtype=y_dtype), + jnp.array(1.0, dtype=complex_to_real_dtype(y_dtype)), + jnp.array(1.0, dtype=complex_to_real_dtype(y_dtype)), ) def adapt_step_size( @@ -569,7 +571,8 @@ def _scale(_y0, _y1_candidate, _y_error): _nan = jnp.isnan(_y1_candidate).any() _y1_candidate = jnp.where(_nan, _y0, _y1_candidate) _y = jnp.maximum(jnp.abs(_y0), jnp.abs(_y1_candidate)) - return _y_error / (self.atol + _y * self.rtol) + with jax.numpy_dtype_promotion("standard"): + return _y_error / (self.atol + _y * self.rtol) scaled_error = self.norm(jtu.tree_map(_scale, y0, y1_candidate, y_error)) keep_step = scaled_error < 1 @@ -607,7 +610,8 @@ def _scale(_y0, _y1_candidate, _y_error): # a grad API boundary as part of a larger model.) factor = lax.stop_gradient(factor) factor = eqxi.nondifferentiable(factor) - dt = prev_dt * factor + with jax.numpy_dtype_promotion("standard"): + dt = prev_dt * factor.astype(prev_dt) # E.g. we failed an implicit step, so y_error=inf, so inv_scaled_error=0, # so factor=factormin, and we shrunk our step. diff --git a/pyproject.toml b/pyproject.toml index 574ffc36..6761ff74 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Mathematics", ] urls = {repository = "https://github.com/patrick-kidger/diffrax" } -dependencies = ["jax>=0.4.23", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.2", "lineax>=0.0.4", "optimistix>=0.0.6"] +dependencies = ["jax>=0.4.23", "jaxtyping>=0.2.24", "typing_extensions>=4.5.0", "typeguard==2.13.3", "equinox>=0.11.2", "lineax>=0.0.5", "optimistix>=0.0.6"] [build-system] requires = ["hatchling"] diff --git a/test/test_integrate.py b/test/test_integrate.py index 23781058..acd14b72 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -45,27 +45,25 @@ def _all_pairs(*args): @pytest.mark.parametrize( - "solver,t_dtype,y_dtype,treedef,stepsize_controller", - _all_pairs( - dict( - default=diffrax.Euler(), - opts=( - diffrax.LeapfrogMidpoint(), - diffrax.ReversibleHeun(), - diffrax.Tsit5(), - diffrax.ImplicitEuler( - root_finder=diffrax.VeryChord(rtol=1e-3, atol=1e-6) - ), - diffrax.Kvaerno3(root_finder=diffrax.VeryChord(rtol=1e-3, atol=1e-6)), - ), - ), - dict(default=jnp.float32, opts=(int, float, jnp.int32)), - dict(default=jnp.float32, opts=(jnp.complex64,)), - dict(default=treedefs[0], opts=treedefs[1:]), - dict( - default=diffrax.ConstantStepSize(), - opts=(diffrax.PIDController(rtol=1e-5, atol=1e-8),), - ), + "solver", + ( + diffrax.Euler(), + diffrax.LeapfrogMidpoint(), + diffrax.ReversibleHeun(), + diffrax.Tsit5(), + diffrax.ImplicitEuler(root_finder=diffrax.VeryChord(rtol=1e-3, atol=1e-6)), + diffrax.Kvaerno3(root_finder=diffrax.VeryChord(rtol=1e-3, atol=1e-6)), + ), +) +@pytest.mark.parametrize("t_dtype", (jnp.float32, int, float, jnp.int32)) +@pytest.mark.parametrize("y_dtype", (jnp.float32, jnp.complex64)) +@pytest.mark.parametrize("treedef", treedefs) +@pytest.mark.parametrize( + "stepsize_controller", + ( + diffrax.ConstantStepSize(), + diffrax.PIDController(rtol=1e-5, atol=1e-8), + diffrax.PIDController(rtol=1e-5, atol=1e-8, pcoeff=0.3, icoeff=0.3, dcoeff=0.0), ), ) def test_basic(solver, t_dtype, y_dtype, treedef, stepsize_controller, getkey):