diff --git a/diffrax/_brownian/path.py b/diffrax/_brownian/path.py index c627ee50..c88e9ee2 100644 --- a/diffrax/_brownian/path.py +++ b/diffrax/_brownian/path.py @@ -7,11 +7,11 @@ import jax.numpy as jnp import jax.random as jr import jax.tree_util as jtu +import lineax.internal as lxi from jaxtyping import Array, PRNGKeyArray, PyTree from .._custom_types import levy_tree_transpose, LevyArea, LevyVal, RealScalarLike from .._misc import ( - default_floating_dtype, force_bitcast_convert_type, is_tuple_of_ints, split_by_tree, @@ -52,7 +52,7 @@ def __init__( levy_area: LevyArea = "", ): self.shape = ( - jax.ShapeDtypeStruct(shape, default_floating_dtype()) + jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype()) if is_tuple_of_ints(shape) else shape ) @@ -87,8 +87,14 @@ def evaluate( ) -> Union[PyTree[Array], LevyVal]: del left if t1 is None: + dtype = jnp.result_type(t0) t1 = t0 - t0 = 0 + t0 = jnp.array(0, dtype) + else: + with jax.numpy_dtype_promotion("standard"): + dtype = jnp.result_type(t0, t1) + t0 = jnp.astype(t0, dtype) + t1 = jnp.astype(t1, dtype) t0 = eqxi.nondifferentiable(t0, name="t0") t1 = eqxi.nondifferentiable(t1, name="t1") t1 = cast(RealScalarLike, t1) diff --git a/diffrax/_brownian/tree.py b/diffrax/_brownian/tree.py index b5335829..bf972526 100644 --- a/diffrax/_brownian/tree.py +++ b/diffrax/_brownian/tree.py @@ -9,6 +9,7 @@ import jax.numpy as jnp import jax.random as jr import jax.tree_util as jtu +import lineax.internal as lxi from jaxtyping import Array, Float, PRNGKeyArray, PyTree from .._custom_types import ( @@ -20,7 +21,6 @@ RealScalarLike, ) from .._misc import ( - default_floating_dtype, is_tuple_of_ints, linear_rescale, split_by_tree, @@ -179,7 +179,7 @@ def __init__( self.levy_area = levy_area self._spline = _spline self.shape = ( - jax.ShapeDtypeStruct(shape, default_floating_dtype()) + jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype()) if is_tuple_of_ints(shape) else shape ) diff --git a/diffrax/_integrate.py 
b/diffrax/_integrate.py index d33452bb..b6501e41 100644 --- a/diffrax/_integrate.py +++ b/diffrax/_integrate.py @@ -10,6 +10,7 @@ import jax.core import jax.numpy as jnp import jax.tree_util as jtu +import lineax.internal as lxi from jaxtyping import Array, ArrayLike, Float, Inexact, PyTree, Real from ._adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint @@ -299,10 +300,10 @@ def body_fun_aux(state): # Count the number of steps, just for statistical purposes. num_steps = state.num_steps + 1 - num_accepted_steps = state.num_accepted_steps + keep_step + num_accepted_steps = state.num_accepted_steps + jnp.where(keep_step, 1, 0) # Not just ~keep_step, which does the wrong thing when keep_step is a non-array # bool True/False. - num_rejected_steps = state.num_rejected_steps + jnp.invert(keep_step) + num_rejected_steps = state.num_rejected_steps + jnp.where(keep_step, 0, 1) # # Store the output produced from this numerical step. @@ -369,7 +370,7 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: subsaveat.fn(tprev, y, args), save_state.ys, ) - save_index = save_state.save_index + keep_step + save_index = save_state.save_index + jnp.where(keep_step, 1, 0) save_state = eqx.tree_at( lambda s: [s.ts, s.ys, s.save_index], save_state, @@ -388,7 +389,7 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState: dense_info, dense_infos, ) - dense_save_index = dense_save_index + keep_step + dense_save_index = dense_save_index + jnp.where(keep_step, 1, 0) new_state = State( y=y, @@ -625,7 +626,7 @@ def diffeqsolve( f"t0 with value {t0} and type {type(t0)}, " f"dt0 with value {dt0} and type {type(dt0)}" ) - with jax.ensure_compile_time_eval(): + with jax.ensure_compile_time_eval(), jax.numpy_dtype_promotion("standard"): pred = (t1 - t0) * dt0 < 0 dt0 = eqxi.error_if(jnp.array(dt0), pred, msg) @@ -641,7 +642,8 @@ def diffeqsolve( ) warnings.warn( "Complex dtype support is work in progress, please read " - 
"https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully." + "https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully.", + stacklevel=2, ) # Backward compatibility @@ -653,7 +655,8 @@ def diffeqsolve( f"{solver.__class__.__name__} is deprecated in favour of " "`terms=MultiTerm(ODETerm(...), SomeOtherTerm(...))`. This means that " "the same terms can now be passed used for both general and SDE-specific " - "solvers!" + "solvers!", + stacklevel=2, ) terms = MultiTerm(*terms) @@ -668,7 +671,8 @@ def diffeqsolve( if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)): warnings.warn( f"`{type(solver).__name__}` is not marked as converging to either the " - "Itô or the Stratonovich solution." + "Itô or the Stratonovich solution.", + stacklevel=2, ) if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): # Specific check to not work even if using HalfSolver(Euler()) @@ -684,11 +688,22 @@ def diffeqsolve( ) # Allow setting e.g. t0 as an int with dt0 as a float. - timelikes = [jnp.array(0.0), t0, t1, dt0] + [ + timelikes = [t0, t1, dt0] + [ s.ts for s in jtu.tree_leaves(saveat.subs, is_leaf=_is_subsaveat) ] timelikes = [x for x in timelikes if x is not None] - time_dtype = jnp.result_type(*timelikes) + with jax.numpy_dtype_promotion("standard"): + time_dtype = jnp.result_type(*timelikes) + if jnp.issubdtype(time_dtype, jnp.complexfloating): + raise ValueError( + "Cannot use complex dtype for `t0`, `t1`, `dt0`, or `SaveAt(ts=...)`." 
+ ) + elif jnp.issubdtype(time_dtype, jnp.floating): + pass + elif jnp.issubdtype(time_dtype, jnp.integer): + time_dtype = lxi.default_floating_dtype() + else: + raise ValueError(f"Unrecognised time dtype {time_dtype}.") t0 = jnp.asarray(t0, dtype=time_dtype) t1 = jnp.asarray(t1, dtype=time_dtype) if dt0 is not None: @@ -708,7 +723,8 @@ def _get_subsaveat_ts(saveat): # fixing issue with float64 and weak dtypes, see discussion at: # https://github.com/patrick-kidger/diffrax/pull/197#discussion_r1130173527 def _promote(yi): - _dtype = jnp.result_type(yi, time_dtype) # noqa: F821 + with jax.numpy_dtype_promotion("standard"): + _dtype = jnp.result_type(yi, time_dtype) # noqa: F821 return jnp.asarray(yi, dtype=_dtype) y0 = jtu.tree_map(_promote, y0) diff --git a/diffrax/_misc.py b/diffrax/_misc.py index 1d8ffa1c..5fe777b4 100644 --- a/diffrax/_misc.py +++ b/diffrax/_misc.py @@ -162,8 +162,30 @@ def static_select(pred: BoolScalarLike, a: ArrayLike, b: ArrayLike) -> ArrayLike return lax.select(pred, a, b) -def default_floating_dtype(): - if jax.config.jax_enable_x64: # pyright: ignore - return jnp.float64 - else: - return jnp.float32 +def upcast_or_raise( + x: ArrayLike, array_for_dtype: ArrayLike, x_name: str, dtype_name: str +): + """If `JAX_NUMPY_DTYPE_PROMOTION=strict`, then this will raise an error if + `jnp.result_type(x, array_for_dtype)` is not the same as `array_for_dtype.dtype`. + It will then cast `x` to `jnp.result_type(x, array_for_dtype)`. + + Thus if `JAX_NUMPY_DTYPE_PROMOTION=standard`, then the usual anything-goes behaviour + will apply. If `JAX_NUMPY_DTYPE_PROMOTION=strict` then we loosen from prohibiting + all dtype casting, to still allowing upcasting. 
+ """ + x_dtype = jnp.result_type(x) + target_dtype = jnp.result_type(array_for_dtype) + with jax.numpy_dtype_promotion("standard"): + promote_dtype = jnp.result_type(x_dtype, target_dtype) + config_value = jax.config.jax_numpy_dtype_promotion + if config_value == "strict": + if target_dtype != promote_dtype: + raise ValueError( + f"When `JAX_NUMPY_DTYPE_PROMOTION=strict`, then {x_name} must have " + f"a dtype that can be promoted to the dtype of {dtype_name}. " + f"However {x_name} had dtype {x_dtype} and {dtype_name} had dtype " + f"{target_dtype}." + ) + elif config_value != "standard": + assert False, f"Unrecognised `JAX_NUMPY_DTYPE_PROMOTION={config_value}`" + return jnp.astype(x, promote_dtype) diff --git a/diffrax/_step_size_controller/adaptive.py b/diffrax/_step_size_controller/adaptive.py index 847cfa81..4b3d432a 100644 --- a/diffrax/_step_size_controller/adaptive.py +++ b/diffrax/_step_size_controller/adaptive.py @@ -27,6 +27,7 @@ VF, Y, ) +from .._misc import upcast_or_raise from .._solution import RESULTS from .._term import AbstractTerm, ODETerm from .base import AbstractStepSizeController @@ -325,6 +326,14 @@ class PIDController( safety: RealScalarLike = 0.9 error_order: Optional[RealScalarLike] = None + def __check_init__(self): + if self.jump_ts is not None and not jnp.issubdtype( + self.jump_ts.dtype, jnp.inexact + ): + raise ValueError( + f"jump_ts must be floating point, not {self.jump_ts.dtype}" + ) + def wrap(self, direction: IntScalarLike): step_ts = None if self.step_ts is None else self.step_ts * direction jump_ts = None if self.jump_ts is None else self.jump_ts * direction @@ -632,10 +641,22 @@ def _clip_step_ts(self, t0: RealScalarLike, t1: RealScalarLike) -> RealScalarLik if self.step_ts is None: return t1 + step_ts0 = upcast_or_raise( + self.step_ts, + t0, + "`PIDController.step_ts`", + "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", + ) + step_ts1 = upcast_or_raise( + self.step_ts, + t1, + 
"`PIDController.step_ts`", + "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", + ) # TODO: it should be possible to switch this O(nlogn) for just O(n) by keeping # track of where we were last, and using that as a hint for the next search. - t0_index = jnp.searchsorted(self.step_ts, t0, side="right") - t1_index = jnp.searchsorted(self.step_ts, t1, side="right") + t0_index = jnp.searchsorted(step_ts0, t0, side="right") + t1_index = jnp.searchsorted(step_ts1, t1, side="right") # This minimum may or may not actually be necessary. The left branch is taken # iff t0_index < t1_index <= len(self.step_ts), so all valid t0_index s must # already satisfy the minimum. @@ -643,7 +664,7 @@ def _clip_step_ts(self, t0: RealScalarLike, t1: RealScalarLike) -> RealScalarLik # so we clamp it just to be sure we're not hitting undefined behaviour. t1 = jnp.where( t0_index < t1_index, - self.step_ts[jnp.minimum(t0_index, len(self.step_ts) - 1)], + step_ts1[jnp.minimum(t0_index, len(self.step_ts) - 1)], t1, ) return t1 @@ -653,23 +674,35 @@ def _clip_jump_ts( ) -> tuple[RealScalarLike, BoolScalarLike]: if self.jump_ts is None: return t1, False - if self.jump_ts is not None and not jnp.issubdtype( - self.jump_ts.dtype, jnp.inexact - ): + assert jnp.issubdtype(self.jump_ts.dtype, jnp.inexact) + if not jnp.issubdtype(jnp.result_type(t0), jnp.inexact): raise ValueError( - f"jump_ts must be floating point, not {self.jump_ts.dtype}" + "`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. " + f"Got {jnp.result_type(t0)}." ) if not jnp.issubdtype(jnp.result_type(t1), jnp.inexact): raise ValueError( - "t0, t1, dt0 must be floating point when specifying jump_t. Got " - f"{jnp.result_type(t1)}." + "`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. " + f"Got {jnp.result_type(t1)}." 
) - t0_index = jnp.searchsorted(self.jump_ts, t0, side="right") - t1_index = jnp.searchsorted(self.jump_ts, t1, side="right") + jump_ts0 = upcast_or_raise( + self.jump_ts, + t0, + "`PIDController.jump_ts`", + "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", + ) + jump_ts1 = upcast_or_raise( + self.jump_ts, + t1, + "`PIDController.jump_ts`", + "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)", + ) + t0_index = jnp.searchsorted(jump_ts0, t0, side="right") + t1_index = jnp.searchsorted(jump_ts1, t1, side="right") next_made_jump = t0_index < t1_index t1 = jnp.where( next_made_jump, - eqxi.prevbefore(self.jump_ts[jnp.minimum(t0_index, len(self.jump_ts) - 1)]), + eqxi.prevbefore(jump_ts1[jnp.minimum(t0_index, len(self.jump_ts) - 1)]), t1, ) return t1, next_made_jump diff --git a/diffrax/_term.py b/diffrax/_term.py index 9bea8cf8..277089b0 100644 --- a/diffrax/_term.py +++ b/diffrax/_term.py @@ -12,6 +12,7 @@ from jaxtyping import Array, ArrayLike, PyTree, PyTreeDef from ._custom_types import Args, Control, IntScalarLike, RealScalarLike, VF, Y +from ._misc import upcast_or_raise from ._path import AbstractPath @@ -159,7 +160,8 @@ class ODETerm(AbstractTerm): appearing on the right hand side of an ODE, in which the control is time. `vector_field` should return some PyTree, with the same structure as the initial - state `y0`, and with every leaf broadcastable to the equivalent leaf in `y0`. + state `y0`, and with every leaf shape-broadcastable and dtype-upcastable to the + equivalent leaf in `y0`. !!! example @@ -179,13 +181,33 @@ def vf(self, t: RealScalarLike, y: Y, args: Args) -> VF: "The vector field inside `ODETerm` must return a pytree with the " "same structure as `y0`." 
) - return jtu.tree_map(lambda o, yi: jnp.broadcast_to(o, jnp.shape(yi)), out, y) + + def _broadcast_and_upcast(oi, yi): + oi = jnp.broadcast_to(oi, jnp.shape(yi)) + oi = upcast_or_raise( + oi, + yi, + "the vector field passed to `ODETerm`", + "the corresponding leaf of `y`", + ) + return oi + + return jtu.tree_map(_broadcast_and_upcast, out, y) def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> RealScalarLike: return t1 - t0 def prod(self, vf: VF, control: RealScalarLike) -> Y: - return jtu.tree_map(lambda v: control * v, vf) + def _mul(v): + c = upcast_or_raise( + control, + v, + "the output of `ODETerm.contr(...)`", + "the output of `ODETerm.vf(...)`", + ) + return c * v + + return jtu.tree_map(_mul, vf) ODETerm.__init__.__doc__ = """**Arguments:** diff --git a/test/conftest.py b/test/conftest.py index 989d28f1..55f09f71 100644 --- a/test/conftest.py +++ b/test/conftest.py @@ -4,6 +4,8 @@ jax.config.update("jax_enable_x64", True) # pyright: ignore +jax.config.update("jax_numpy_rank_promotion", "raise") # pyright: ignore +jax.config.update("jax_numpy_dtype_promotion", "strict") # pyright: ignore @pytest.fixture() diff --git a/test/test_adaptive_stepsize_controller.py b/test/test_adaptive_stepsize_controller.py index 4d3ffd7a..4cc996c8 100644 --- a/test/test_adaptive_stepsize_controller.py +++ b/test/test_adaptive_stepsize_controller.py @@ -83,7 +83,8 @@ def run(ys, controller, state): _, tprev, tnext, _, state, _ = controller.adapt_step_size( 0, 1, y0, y1_candidate, None, y_error, 5, state ) - return tprev + tnext + sum(jnp.sum(x) for x in jtu.tree_leaves(state)) + with jax.numpy_dtype_promotion("standard"): + return tprev + tnext + sum(jnp.sum(x) for x in jtu.tree_leaves(state)) y0 = jnp.array(1.0) y1_candidate = jnp.array(2.0) diff --git a/test/test_brownian.py b/test/test_brownian.py index 76454a4a..4fe293d7 100644 --- a/test/test_brownian.py +++ b/test/test_brownian.py @@ -1,3 +1,4 @@ +import contextlib import math from typing import Literal from 
typing_extensions import TypeAlias @@ -33,8 +34,8 @@ def _make_struct(shape, dtype): @pytest.mark.parametrize("levy_area", ["", "space-time"]) @pytest.mark.parametrize("use_levy", (False, True)) def test_shape_and_dtype(ctr, levy_area, use_levy, getkey): - t0 = 0 - t1 = 2 + t0 = 0.0 + t1 = 2.0 shapes = ( (), @@ -91,11 +92,14 @@ def is_tuple_of_ints(obj): if dtype is None: shape = jtu.tree_map(_make_struct, shape, dtype, is_leaf=is_tuple_of_ints) - for _t0 in _vals.values(): - for _t1 in _vals.values(): - t0, _ = _t0 - _, t1 = _t1 - out = path.evaluate(t0, t1, use_levy=use_levy) + for t0_dtype, (t0, _) in _vals.items(): + for t1_dtype, (_, t1) in _vals.items(): + if all(x in (float, jnp.float32) for x in (t0_dtype, t1_dtype)): + context = contextlib.nullcontext() + else: + context = jax.numpy_dtype_promotion("standard") + with context: + out = path.evaluate(t0, t1, use_levy=use_levy) if use_levy: assert isinstance(out, diffrax.LevyVal) w = out.W diff --git a/test/test_detest.py b/test/test_detest.py index 47d38f73..6dbb20e3 100644 --- a/test/test_detest.py +++ b/test/test_detest.py @@ -12,6 +12,7 @@ import math import diffrax +import jax import jax.flatten_util as fu import jax.numpy as jnp import jax.tree_util as jtu @@ -147,15 +148,15 @@ def _c2(): A = ( jnp.zeros((10, 10)) .at[jnp.arange(9), jnp.arange(9)] - .set(-jnp.arange(1, 10)) + .set(-jnp.arange(1.0, 10.0)) .at[jnp.arange(1, 10), jnp.arange(9)] - .set(jnp.arange(1, 10)) + .set(jnp.arange(1.0, 10.0)) ) def diffeq(t, y, args): return A @ y - init = jnp.zeros(10).at[0].set(1) + init = jnp.zeros(10).at[0].set(1.0) return diffeq, init @@ -215,12 +216,13 @@ def diffeq(t, y, args): r_cubed_j = r_cubed_k = jnp.sum(y_ij**2, axis=0) ** 1.5 d_cubed_jk = jnp.sum((y_ij[:, :, None] - y_ij[:, None, :]) ** 2, axis=0) ** 1.5 - term1_ij = -(m0 + m_j) * y_ij / r_cubed_j - term2_ijk = (y_ij[:, None, :] - y_ij[:, :, None]) / d_cubed_jk - term3_ik = y_ik / r_cubed_k - term4_ijk = m_k * (term2_ijk - term3_ik[:, None]) - term4_ijk
= term4_ijk.at[:, jnp.arange(5), jnp.arange(5)].set(0) - term5_ij = jnp.sum(term4_ijk, axis=-1) + with jax.numpy_rank_promotion("allow"): + term1_ij = -(m0 + m_j) * y_ij / r_cubed_j + term2_ijk = (y_ij[:, None, :] - y_ij[:, :, None]) / d_cubed_jk + term3_ik = y_ik / r_cubed_k + term4_ijk = m_k * (term2_ijk - term3_ik[:, None]) + term4_ijk = term4_ijk.at[:, jnp.arange(5), jnp.arange(5)].set(0) + term5_ij = jnp.sum(term4_ijk, axis=-1) ddy_ij = k2 * (term1_ij + term5_ij) return dy_ij, ddy_ij diff --git a/test/test_global_interpolation.py b/test/test_global_interpolation.py index ef3650de..968cee24 100644 --- a/test/test_global_interpolation.py +++ b/test/test_global_interpolation.py @@ -290,9 +290,10 @@ def test_interpolation_classes(mode, getkey): def _test(firstval, vals, y0, y1): vals = jnp.concatenate([firstval[None], vals]) - true_vals = y0 + ((points - t0) / (t1 - t0))[:, None] * ( - y1 - y0 - ) + with jax.numpy_rank_promotion("allow"): + true_vals = y0 + ((points - t0) / (t1 - t0))[:, None] * ( + y1 - y0 + ) assert tree_allclose(vals, true_vals) jtu.tree_map(_test, firstval, vals, y0, y1) diff --git a/test/test_integrate.py b/test/test_integrate.py index 05a8576b..3bef5ce6 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -1,3 +1,4 @@ +import contextlib import math import operator from typing import cast @@ -77,15 +78,16 @@ def test_basic(solver, t_dtype, y_dtype, treedef, stepsize_controller, getkey): ) and treedef == jtu.tree_structure(None): return - if jnp.iscomplexobj(y_dtype): - - def f(t, y, args): - return jtu.tree_map(lambda _y: operator.mul(-1j, _y), y) - + if jnp.iscomplexobj(y_dtype) and treedef != jtu.tree_structure(None): if isinstance(solver, diffrax.AbstractImplicitSolver): return + else: + complex_warn = pytest.warns(match="Complex dtype") + def f(t, y, args): + return jtu.tree_map(lambda yi: -1j * yi, y) else: + complex_warn = contextlib.nullcontext() def f(t, y, args): return jtu.tree_map(operator.neg, y) @@ -110,15 +112,16 
@@ def f(t, y, args): raise ValueError y0 = random_pytree(getkey(), treedef, dtype=y_dtype) try: - sol = diffrax.diffeqsolve( - diffrax.ODETerm(f), - solver, - t0, - t1, - dt0, - y0, - stepsize_controller=stepsize_controller, - ) + with complex_warn: + sol = diffrax.diffeqsolve( + diffrax.ODETerm(f), + solver, + t0, + t1, + dt0, + y0, + stepsize_controller=stepsize_controller, + ) except Exception as e: if isinstance(stepsize_controller, diffrax.ConstantStepSize) and str( e diff --git a/test/test_saveat_solution.py b/test/test_saveat_solution.py index 63168a2b..cc73bf99 100644 --- a/test/test_saveat_solution.py +++ b/test/test_saveat_solution.py @@ -118,7 +118,8 @@ def test_saveat_solution(): assert sol.ts.shape == (4096,) # pyright: ignore assert sol.ys.shape == (4096, 1) # pyright: ignore _ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts) - _ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None] + with jax.numpy_rank_promotion("allow"): + _ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None] _ys = jnp.where(jnp.isnan(_ys), jnp.inf, _ys) assert tree_allclose(sol.ys, _ys) assert sol.controller_state is None