From 0ee47c98efe6de80388cce50eae80f91736047d1 Mon Sep 17 00:00:00 2001
From: Patrick Kidger <33688385+patrick-kidger@users.noreply.github.com>
Date: Sun, 31 Dec 2023 04:12:47 -0800
Subject: [PATCH] Now using strict shape/dtype promotion rules.

This means that:

1. Tests now pass using `JAX_NUMPY_DTYPE_PROMOTION=strict` and
   `JAX_NUMPY_RANK_PROMOTION=raise`, and these are enabled in tests by default.

2. The values passed to `diffeqsolve` now more carefully determine the dtype used in
   the integration (previously dtypes were mostly left to behave in an ad-hoc fashion:
   whatever the various interacting arrays happened to promote to):

   a. The dtype of timelike values is the `jnp.result_type` of `t0`, `t1`, `dt0`, and
      `SaveAt(ts=...)`. If any of these are complex an error is raised. If these are
      all integers we use the default floating-point dtype.

   b. The dtype of each leaf of `y0` is the `jnp.result_type` of that leaf and the
      time dtype.

3. Of course, `diffeqsolve` accepts user-specified functions (e.g. the vector field of
   an `ODETerm`), and these could potentially return arrays with dtypes that do not
   match the ones we have selected above, which might cause further upcasting. For the
   sake of backward compatibility we don't try to prohibit this -- a user who feels
   strongly about this should enable `JAX_NUMPY_DTYPE_PROMOTION=strict` and fix their
   vector fields appropriately. (And can then be assured that the dtypes of these
   quantities are exactly as specified by the rules above.)

So the key thing this commit enables is using that flag to enforce the rules above,
without any false positives from Diffrax itself!
---
 diffrax/_brownian/path.py                 | 12 +++--
 diffrax/_brownian/tree.py                 |  4 +-
 diffrax/_integrate.py                     | 38 ++++++++++-----
 diffrax/_misc.py                          | 32 +++++++++++--
 diffrax/_step_size_controller/adaptive.py | 57 ++++++++++++++++++-----
 diffrax/_term.py                          | 28 +++++++++--
 test/conftest.py                          |  2 +
 test/test_adaptive_stepsize_controller.py |  3 +-
 test/test_brownian.py                     | 18 ++++---
 test/test_detest.py                       | 20 ++++----
 test/test_global_interpolation.py         |  7 +--
 test/test_integrate.py                    | 31 ++++++------
 test/test_saveat_solution.py              |  3 +-
 13 files changed, 184 insertions(+), 71 deletions(-)

diff --git a/diffrax/_brownian/path.py b/diffrax/_brownian/path.py
index c627ee50..c88e9ee2 100644
--- a/diffrax/_brownian/path.py
+++ b/diffrax/_brownian/path.py
@@ -7,11 +7,11 @@
 import jax.numpy as jnp
 import jax.random as jr
 import jax.tree_util as jtu
+import lineax.internal as lxi
 from jaxtyping import Array, PRNGKeyArray, PyTree
 
 from .._custom_types import levy_tree_transpose, LevyArea, LevyVal, RealScalarLike
 from .._misc import (
-    default_floating_dtype,
     force_bitcast_convert_type,
     is_tuple_of_ints,
     split_by_tree,
@@ -52,7 +52,7 @@ def __init__(
         levy_area: LevyArea = "",
     ):
         self.shape = (
-            jax.ShapeDtypeStruct(shape, default_floating_dtype())
+            jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype())
             if is_tuple_of_ints(shape)
             else shape
         )
@@ -87,8 +87,14 @@ def evaluate(
     ) -> Union[PyTree[Array], LevyVal]:
         del left
         if t1 is None:
+            dtype = jnp.result_type(t0)
             t1 = t0
-            t0 = 0
+            t0 = jnp.array(0, dtype)
+        else:
+            with jax.numpy_dtype_promotion("standard"):
+                dtype = jnp.result_type(t0, t1)
+            t0 = jnp.astype(t0, dtype)
+            t1 = jnp.astype(t1, dtype)
         t0 = eqxi.nondifferentiable(t0, name="t0")
         t1 = eqxi.nondifferentiable(t1, name="t1")
         t1 = cast(RealScalarLike, t1)
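[Editor's illustration, not part of the patch.] `lxi.default_floating_dtype()` stands in for the helper this patch deletes from `diffrax/_misc.py` further below: float64 when x64 mode is enabled, float32 otherwise. A minimal sketch of the behaviour being relied on (`lineax.internal` is private API, so treat this as an illustration rather than a documented contract):

    import jax
    import jax.numpy as jnp

    # Equivalent of the removed diffrax._misc.default_floating_dtype helper.
    def default_floating_dtype():
        if jax.config.jax_enable_x64:
            return jnp.float64
        else:
            return jnp.float32

    # An integer-tuple `shape` given to a Brownian path is then wrapped as:
    print(jax.ShapeDtypeStruct((3,), default_floating_dtype()))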
diff --git a/diffrax/_brownian/tree.py b/diffrax/_brownian/tree.py
index b5335829..bf972526 100644
--- a/diffrax/_brownian/tree.py
+++ b/diffrax/_brownian/tree.py
@@ -9,6 +9,7 @@
 import jax.numpy as jnp
 import jax.random as jr
 import jax.tree_util as jtu
+import lineax.internal as lxi
 from jaxtyping import Array, Float, PRNGKeyArray, PyTree
 
 from .._custom_types import (
@@ -20,7 +21,6 @@
     RealScalarLike,
 )
 from .._misc import (
-    default_floating_dtype,
     is_tuple_of_ints,
     linear_rescale,
     split_by_tree,
@@ -179,7 +179,7 @@ def __init__(
         self.levy_area = levy_area
         self._spline = _spline
         self.shape = (
-            jax.ShapeDtypeStruct(shape, default_floating_dtype())
+            jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype())
             if is_tuple_of_ints(shape)
             else shape
         )
diff --git a/diffrax/_integrate.py b/diffrax/_integrate.py
index d33452bb..b6501e41 100644
--- a/diffrax/_integrate.py
+++ b/diffrax/_integrate.py
@@ -10,6 +10,7 @@
 import jax.core
 import jax.numpy as jnp
 import jax.tree_util as jtu
+import lineax.internal as lxi
 from jaxtyping import Array, ArrayLike, Float, Inexact, PyTree, Real
 
 from ._adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint
@@ -299,10 +300,10 @@ def body_fun_aux(state):
 
         # Count the number of steps, just for statistical purposes.
         num_steps = state.num_steps + 1
-        num_accepted_steps = state.num_accepted_steps + keep_step
+        num_accepted_steps = state.num_accepted_steps + jnp.where(keep_step, 1, 0)
         # Not just ~keep_step, which does the wrong thing when keep_step is a non-array
         # bool True/False.
-        num_rejected_steps = state.num_rejected_steps + jnp.invert(keep_step)
+        num_rejected_steps = state.num_rejected_steps + jnp.where(keep_step, 0, 1)
 
         #
         # Store the output produced from this numerical step.
@@ -369,7 +370,7 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
                     subsaveat.fn(tprev, y, args),
                     save_state.ys,
                 )
-                save_index = save_state.save_index + keep_step
+                save_index = save_state.save_index + jnp.where(keep_step, 1, 0)
                 save_state = eqx.tree_at(
                     lambda s: [s.ts, s.ys, s.save_index],
                     save_state,
@@ -388,7 +389,7 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
                 dense_info,
                 dense_infos,
             )
-            dense_save_index = dense_save_index + keep_step
+            dense_save_index = dense_save_index + jnp.where(keep_step, 1, 0)
 
         new_state = State(
             y=y,
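[Editor's illustration, not part of the patch.] The counter changes above are what keep the step statistics working under `JAX_NUMPY_DTYPE_PROMOTION=strict`: adding a boolean `keep_step` to an integer counter relies on implicit bool-to-int promotion, which strict mode rejects, whereas `jnp.where(keep_step, 1, 0)` stays in the counter's own dtype. A minimal sketch:

    import jax
    import jax.numpy as jnp

    keep_step = jnp.array(True)
    num_accepted_steps = jnp.array(0, dtype=jnp.int32)

    with jax.numpy_dtype_promotion("strict"):
        # num_accepted_steps + keep_step   # rejected: bool and int32 may not mix
        num_accepted_steps = num_accepted_steps + jnp.where(keep_step, 1, 0)
        num_rejected_steps = jnp.where(keep_step, 0, 1)

    print(num_accepted_steps, num_rejected_steps)  # 1 0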
+ "Itô or the Stratonovich solution.", + stacklevel=2, ) if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController): # Specific check to not work even if using HalfSolver(Euler()) @@ -684,11 +688,22 @@ def diffeqsolve( ) # Allow setting e.g. t0 as an int with dt0 as a float. - timelikes = [jnp.array(0.0), t0, t1, dt0] + [ + timelikes = [t0, t1, dt0] + [ s.ts for s in jtu.tree_leaves(saveat.subs, is_leaf=_is_subsaveat) ] timelikes = [x for x in timelikes if x is not None] - time_dtype = jnp.result_type(*timelikes) + with jax.numpy_dtype_promotion("standard"): + time_dtype = jnp.result_type(*timelikes) + if jnp.issubdtype(time_dtype, jnp.complexfloating): + raise ValueError( + "Cannot use complex dtype for `t0`, `t1`, `dt0`, or `SaveAt(ts=...)`." + ) + elif jnp.issubdtype(time_dtype, jnp.floating): + pass + elif jnp.issubdtype(time_dtype, jnp.integer): + time_dtype = lxi.default_floating_dtype() + else: + raise ValueError(f"Unrecognised time dtype {time_dtype}.") t0 = jnp.asarray(t0, dtype=time_dtype) t1 = jnp.asarray(t1, dtype=time_dtype) if dt0 is not None: @@ -708,7 +723,8 @@ def _get_subsaveat_ts(saveat): # fixing issue with float64 and weak dtypes, see discussion at: # https://github.com/patrick-kidger/diffrax/pull/197#discussion_r1130173527 def _promote(yi): - _dtype = jnp.result_type(yi, time_dtype) # noqa: F821 + with jax.numpy_dtype_promotion("standard"): + _dtype = jnp.result_type(yi, time_dtype) # noqa: F821 return jnp.asarray(yi, dtype=_dtype) y0 = jtu.tree_map(_promote, y0) diff --git a/diffrax/_misc.py b/diffrax/_misc.py index 1d8ffa1c..5fe777b4 100644 --- a/diffrax/_misc.py +++ b/diffrax/_misc.py @@ -162,8 +162,30 @@ def static_select(pred: BoolScalarLike, a: ArrayLike, b: ArrayLike) -> ArrayLike return lax.select(pred, a, b) -def default_floating_dtype(): - if jax.config.jax_enable_x64: # pyright: ignore - return jnp.float64 - else: - return jnp.float32 +def upcast_or_raise( + x: ArrayLike, array_for_dtype: ArrayLike, x_name: str, dtype_name: str +): + """If `JAX_NUMPY_DTYPE_PROMOTION=strict`, then this will raise an error if + `jnp.result_type(x, array_for_dtype)` is not the same as `array_for_dtype.dtype`. + It will then cast `x` to `jnp.result_type(x, array_for_dtype)`. + + Thus if `JAX_NUMPY_DTYPE_PROMOTION=standard`, then the usual anything-goes behaviour + will apply. If `JAX_NUMPY_DTYPE_PROMOTION=strict` then we loosen from prohibiting + all dtype casting, to still allowing upcasting. + """ + x_dtype = jnp.result_type(x) + target_dtype = jnp.result_type(array_for_dtype) + with jax.numpy_dtype_promotion("standard"): + promote_dtype = jnp.result_type(x_dtype, target_dtype) + config_value = jax.config.jax_numpy_dtype_promotion + if config_value == "strict": + if target_dtype != promote_dtype: + raise ValueError( + f"When `JAX_NUMPY_DTYPE_PROMOTION=strict`, then {x_name} must have " + f"a dtype that can be promoted to the dtype of {dtype_name}. " + f"However {x_name} had dtype {x_dtype} and {dtype_name} had dtype " + f"{target_dtype}." 
diff --git a/diffrax/_misc.py b/diffrax/_misc.py
index 1d8ffa1c..5fe777b4 100644
--- a/diffrax/_misc.py
+++ b/diffrax/_misc.py
@@ -162,8 +162,30 @@ def static_select(pred: BoolScalarLike, a: ArrayLike, b: ArrayLike) -> ArrayLike
     return lax.select(pred, a, b)
 
 
-def default_floating_dtype():
-    if jax.config.jax_enable_x64:  # pyright: ignore
-        return jnp.float64
-    else:
-        return jnp.float32
+def upcast_or_raise(
+    x: ArrayLike, array_for_dtype: ArrayLike, x_name: str, dtype_name: str
+):
+    """If `JAX_NUMPY_DTYPE_PROMOTION=strict`, then this will raise an error if
+    `jnp.result_type(x, array_for_dtype)` is not the same as `array_for_dtype.dtype`.
+    It will then cast `x` to `jnp.result_type(x, array_for_dtype)`.
+
+    Thus if `JAX_NUMPY_DTYPE_PROMOTION=standard`, then the usual anything-goes behaviour
+    will apply. If `JAX_NUMPY_DTYPE_PROMOTION=strict` then we loosen from prohibiting
+    all dtype casting, to still allowing upcasting.
+    """
+    x_dtype = jnp.result_type(x)
+    target_dtype = jnp.result_type(array_for_dtype)
+    with jax.numpy_dtype_promotion("standard"):
+        promote_dtype = jnp.result_type(x_dtype, target_dtype)
+    config_value = jax.config.jax_numpy_dtype_promotion
+    if config_value == "strict":
+        if target_dtype != promote_dtype:
+            raise ValueError(
+                f"When `JAX_NUMPY_DTYPE_PROMOTION=strict`, then {x_name} must have "
+                f"a dtype that can be promoted to the dtype of {dtype_name}. "
+                f"However {x_name} had dtype {x_dtype} and {dtype_name} had dtype "
+                f"{target_dtype}."
+            )
+    elif config_value != "standard":
+        assert False, f"Unrecognised `JAX_NUMPY_DTYPE_PROMOTION={config_value}`"
+    return jnp.astype(x, promote_dtype)
diff --git a/diffrax/_step_size_controller/adaptive.py b/diffrax/_step_size_controller/adaptive.py
index 847cfa81..4b3d432a 100644
--- a/diffrax/_step_size_controller/adaptive.py
+++ b/diffrax/_step_size_controller/adaptive.py
@@ -27,6 +27,7 @@
     VF,
     Y,
 )
+from .._misc import upcast_or_raise
 from .._solution import RESULTS
 from .._term import AbstractTerm, ODETerm
 from .base import AbstractStepSizeController
@@ -325,6 +326,14 @@ class PIDController(
     safety: RealScalarLike = 0.9
     error_order: Optional[RealScalarLike] = None
 
+    def __check_init__(self):
+        if self.jump_ts is not None and not jnp.issubdtype(
+            self.jump_ts.dtype, jnp.inexact
+        ):
+            raise ValueError(
+                f"jump_ts must be floating point, not {self.jump_ts.dtype}"
+            )
+
     def wrap(self, direction: IntScalarLike):
         step_ts = None if self.step_ts is None else self.step_ts * direction
         jump_ts = None if self.jump_ts is None else self.jump_ts * direction
@@ -632,10 +641,22 @@ def _clip_step_ts(self, t0: RealScalarLike, t1: RealScalarLike) -> RealScalarLik
         if self.step_ts is None:
            return t1
 
+        step_ts0 = upcast_or_raise(
+            self.step_ts,
+            t0,
+            "`PIDController.step_ts`",
+            "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
+        )
+        step_ts1 = upcast_or_raise(
+            self.step_ts,
+            t1,
+            "`PIDController.step_ts`",
+            "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
+        )
         # TODO: it should be possible to switch this O(nlogn) for just O(n) by keeping
         # track of where we were last, and using that as a hint for the next search.
-        t0_index = jnp.searchsorted(self.step_ts, t0, side="right")
-        t1_index = jnp.searchsorted(self.step_ts, t1, side="right")
+        t0_index = jnp.searchsorted(step_ts0, t0, side="right")
+        t1_index = jnp.searchsorted(step_ts1, t1, side="right")
         # This minimum may or may not actually be necessary. The left branch is taken
         # iff t0_index < t1_index <= len(self.step_ts), so all valid t0_index s must
         # already satisfy the minimum.
@@ -643,7 +664,7 @@ def _clip_step_ts(self, t0: RealScalarLike, t1: RealScalarLike) -> RealScalarLik
        # so we clamp it just to be sure we're not hitting undefined behaviour.
         t1 = jnp.where(
             t0_index < t1_index,
-            self.step_ts[jnp.minimum(t0_index, len(self.step_ts) - 1)],
+            step_ts1[jnp.minimum(t0_index, len(self.step_ts) - 1)],
             t1,
         )
         return t1
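[Editor's illustration, not part of the patch.] How `upcast_or_raise` behaves when `step_ts` (or `jump_ts`) meets the time dtype: under standard promotion it simply upcasts; under strict promotion a genuine upcast is still allowed, but anything that would have to change the time dtype raises. A usage sketch, assuming x64 mode and noting that `diffrax._misc` is a private module:

    import jax
    import jax.numpy as jnp

    jax.config.update("jax_enable_x64", True)

    from diffrax._misc import upcast_or_raise  # private helper added by this patch

    t64 = jnp.array(0.5, dtype=jnp.float64)
    t32 = jnp.array(0.5, dtype=jnp.float32)

    with jax.numpy_dtype_promotion("strict"):
        # float32 step_ts with float64 times is an upcast, so it is permitted:
        out = upcast_or_raise(jnp.arange(10, dtype=jnp.float32), t64, "`step_ts`", "time")
        print(out.dtype)  # float64

        # float64 step_ts cannot be promoted *to* a float32 time dtype, so it raises:
        try:
            upcast_or_raise(jnp.arange(10, dtype=jnp.float64), t32, "`step_ts`", "time")
        except ValueError as e:
            print(e)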
@@ -653,23 +674,35 @@ def _clip_jump_ts(
     ) -> tuple[RealScalarLike, BoolScalarLike]:
         if self.jump_ts is None:
             return t1, False
-        if self.jump_ts is not None and not jnp.issubdtype(
-            self.jump_ts.dtype, jnp.inexact
-        ):
+        assert jnp.issubdtype(self.jump_ts.dtype, jnp.inexact)
+        if not jnp.issubdtype(jnp.result_type(t0), jnp.inexact):
             raise ValueError(
-                f"jump_ts must be floating point, not {self.jump_ts.dtype}"
+                "`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. "
+                f"Got {jnp.result_type(t0)}."
             )
         if not jnp.issubdtype(jnp.result_type(t1), jnp.inexact):
             raise ValueError(
-                "t0, t1, dt0 must be floating point when specifying jump_t. Got "
-                f"{jnp.result_type(t1)}."
+                "`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. "
+                f"Got {jnp.result_type(t1)}."
             )
-        t0_index = jnp.searchsorted(self.jump_ts, t0, side="right")
-        t1_index = jnp.searchsorted(self.jump_ts, t1, side="right")
+        jump_ts0 = upcast_or_raise(
+            self.jump_ts,
+            t0,
+            "`PIDController.jump_ts`",
+            "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
+        )
+        jump_ts1 = upcast_or_raise(
+            self.jump_ts,
+            t1,
+            "`PIDController.jump_ts`",
+            "time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
+        )
+        t0_index = jnp.searchsorted(jump_ts0, t0, side="right")
+        t1_index = jnp.searchsorted(jump_ts1, t1, side="right")
         next_made_jump = t0_index < t1_index
         t1 = jnp.where(
             next_made_jump,
-            eqxi.prevbefore(self.jump_ts[jnp.minimum(t0_index, len(self.jump_ts) - 1)]),
+            eqxi.prevbefore(jump_ts1[jnp.minimum(t0_index, len(self.jump_ts) - 1)]),
             t1,
         )
         return t1, next_made_jump
diff --git a/diffrax/_term.py b/diffrax/_term.py
index 9bea8cf8..277089b0 100644
--- a/diffrax/_term.py
+++ b/diffrax/_term.py
@@ -12,6 +12,7 @@
 from jaxtyping import Array, ArrayLike, PyTree, PyTreeDef
 
 from ._custom_types import Args, Control, IntScalarLike, RealScalarLike, VF, Y
+from ._misc import upcast_or_raise
 from ._path import AbstractPath
 
 
@@ -159,7 +160,8 @@ class ODETerm(AbstractTerm):
     appearing on the right hand side of an ODE, in which the control is time.
 
     `vector_field` should return some PyTree, with the same structure as the initial
-    state `y0`, and with every leaf broadcastable to the equivalent leaf in `y0`.
+    state `y0`, and with every leaf shape-broadcastable and dtype-upcastable to the
+    equivalent leaf in `y0`.
 
     !!! example
 
@@ -179,13 +181,33 @@ def vf(self, t: RealScalarLike, y: Y, args: Args) -> VF:
                 "The vector field inside `ODETerm` must return a pytree with the "
                 "same structure as `y0`."
             )
-        return jtu.tree_map(lambda o, yi: jnp.broadcast_to(o, jnp.shape(yi)), out, y)
+
+        def _broadcast_and_upcast(oi, yi):
+            oi = jnp.broadcast_to(oi, jnp.shape(yi))
+            oi = upcast_or_raise(
+                oi,
+                yi,
+                "the vector field passed to `ODETerm`",
+                "the corresponding leaf of `y`",
+            )
+            return oi
+
+        return jtu.tree_map(_broadcast_and_upcast, out, y)
 
     def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> RealScalarLike:
         return t1 - t0
 
     def prod(self, vf: VF, control: RealScalarLike) -> Y:
-        return jtu.tree_map(lambda v: control * v, vf)
+        def _mul(v):
+            c = upcast_or_raise(
+                control,
+                v,
+                "the output of `ODETerm.contr(...)`",
+                "the output of `ODETerm.vf(...)`",
+            )
+            return c * v
+
+        return jtu.tree_map(_mul, vf)
 
 
 ODETerm.__init__.__doc__ = """**Arguments:**
diff --git a/test/conftest.py b/test/conftest.py
index 989d28f1..55f09f71 100644
--- a/test/conftest.py
+++ b/test/conftest.py
@@ -4,6 +4,8 @@
 
 
 jax.config.update("jax_enable_x64", True)  # pyright: ignore
+jax.config.update("jax_numpy_rank_promotion", "raise")  # pyright: ignore
+jax.config.update("jax_numpy_dtype_promotion", "strict")  # pyright: ignore
 
 
 @pytest.fixture()
diff --git a/test/test_adaptive_stepsize_controller.py b/test/test_adaptive_stepsize_controller.py
index 4d3ffd7a..4cc996c8 100644
--- a/test/test_adaptive_stepsize_controller.py
+++ b/test/test_adaptive_stepsize_controller.py
@@ -83,7 +83,8 @@ def run(ys, controller, state):
         _, tprev, tnext, _, state, _ = controller.adapt_step_size(
             0, 1, y0, y1_candidate, None, y_error, 5, state
         )
-        return tprev + tnext + sum(jnp.sum(x) for x in jtu.tree_leaves(state))
+        with jax.numpy_dtype_promotion("standard"):
+            return tprev + tnext + sum(jnp.sum(x) for x in jtu.tree_leaves(state))
 
     y0 = jnp.array(1.0)
     y1_candidate = jnp.array(2.0)
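[Editor's illustration, not part of the patch.] The `ODETerm` changes above are where point 3 of the commit message shows up for users: a vector field may still return a "wider" dtype than `y`, and under the default standard promotion this upcasts silently exactly as before, but under `JAX_NUMPY_DTYPE_PROMOTION=strict` the new check reports it. A sketch of both behaviours (assumes x64 mode so that the float32 times/state below stay float32 and the float64 output is a real mismatch):

    import jax

    jax.config.update("jax_enable_x64", True)

    import diffrax
    import jax.numpy as jnp

    def bad_vf(t, y, args):
        # Deliberately returns float64 even though y is float32.
        return -jnp.astype(y, jnp.float64)

    term = diffrax.ODETerm(bad_vf)
    kwargs = dict(
        t0=jnp.float32(0.0), t1=jnp.float32(1.0), dt0=jnp.float32(0.1),
        y0=jnp.float32(1.0),
    )

    # Standard promotion (the default): silently upcasts, as before this patch.
    with jax.numpy_dtype_promotion("standard"):
        diffrax.diffeqsolve(term, diffrax.Euler(), **kwargs)

    # Strict promotion: the mismatch in the user's vector field is reported.
    with jax.numpy_dtype_promotion("strict"):
        try:
            diffrax.diffeqsolve(term, diffrax.Euler(), **kwargs)
        except Exception as e:  # a ValueError raised by upcast_or_raise
            print(e)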
b/test/test_brownian.py index 76454a4a..4fe293d7 100644 --- a/test/test_brownian.py +++ b/test/test_brownian.py @@ -1,3 +1,4 @@ +import contextlib import math from typing import Literal from typing_extensions import TypeAlias @@ -33,8 +34,8 @@ def _make_struct(shape, dtype): @pytest.mark.parametrize("levy_area", ["", "space-time"]) @pytest.mark.parametrize("use_levy", (False, True)) def test_shape_and_dtype(ctr, levy_area, use_levy, getkey): - t0 = 0 - t1 = 2 + t0 = 0.0 + t1 = 2.0 shapes = ( (), @@ -91,11 +92,14 @@ def is_tuple_of_ints(obj): if dtype is None: shape = jtu.tree_map(_make_struct, shape, dtype, is_leaf=is_tuple_of_ints) - for _t0 in _vals.values(): - for _t1 in _vals.values(): - t0, _ = _t0 - _, t1 = _t1 - out = path.evaluate(t0, t1, use_levy=use_levy) + for t0_dtype, (t0, _) in _vals.items(): + for t1_dtype, (_, t2) in _vals.items(): + if all(x in (float, jnp.float32) for x in (t0_dtype, t1_dtype)): + context = contextlib.nullcontext() + else: + context = jax.numpy_dtype_promotion("standard") + with context: + out = path.evaluate(t0, t1, use_levy=use_levy) if use_levy: assert isinstance(out, diffrax.LevyVal) w = out.W diff --git a/test/test_detest.py b/test/test_detest.py index 47d38f73..6dbb20e3 100644 --- a/test/test_detest.py +++ b/test/test_detest.py @@ -12,6 +12,7 @@ import math import diffrax +import jax import jax.flatten_util as fu import jax.numpy as jnp import jax.tree_util as jtu @@ -147,15 +148,15 @@ def _c2(): A = ( jnp.zeros((10, 10)) .at[jnp.arange(9), jnp.arange(9)] - .set(-jnp.arange(1, 10)) + .set(-jnp.arange(1.0, 10.0)) .at[jnp.arange(1, 10), jnp.arange(9)] - .set(jnp.arange(1, 10)) + .set(jnp.arange(1.0, 10.0)) ) def diffeq(t, y, args): return A @ y - init = jnp.zeros(10).at[0].set(1) + init = jnp.zeros(10).at[0].set(1.0) return diffeq, init @@ -215,12 +216,13 @@ def diffeq(t, y, args): r_cubed_j = r_cubed_k = jnp.sum(y_ij**2, axis=0) ** 1.5 d_cubed_jk = jnp.sum((y_ij[:, :, None] - y_ij[:, None, :]) ** 2, axis=0) ** 1.5 - term1_ij = -(m0 + m_j) * y_ij / r_cubed_j - term2_ijk = (y_ij[:, None, :] - y_ij[:, :, None]) / d_cubed_jk - term3_ik = y_ik / r_cubed_k - term4_ijk = m_k * (term2_ijk - term3_ik[:, None]) - term4_ijk = term4_ijk.at[:, jnp.arange(5), jnp.arange(5)].set(0) - term5_ij = jnp.sum(term4_ijk, axis=-1) + with jax.numpy_rank_promotion("allow"): + term1_ij = -(m0 + m_j) * y_ij / r_cubed_j + term2_ijk = (y_ij[:, None, :] - y_ij[:, :, None]) / d_cubed_jk + term3_ik = y_ik / r_cubed_k + term4_ijk = m_k * (term2_ijk - term3_ik[:, None]) + term4_ijk = term4_ijk.at[:, jnp.arange(5), jnp.arange(5)].set(0) + term5_ij = jnp.sum(term4_ijk, axis=-1) ddy_ij = k2 * (term1_ij + term5_ij) return dy_ij, ddy_ij diff --git a/test/test_global_interpolation.py b/test/test_global_interpolation.py index ef3650de..968cee24 100644 --- a/test/test_global_interpolation.py +++ b/test/test_global_interpolation.py @@ -290,9 +290,10 @@ def test_interpolation_classes(mode, getkey): def _test(firstval, vals, y0, y1): vals = jnp.concatenate([firstval[None], vals]) - true_vals = y0 + ((points - t0) / (t1 - t0))[:, None] * ( - y1 - y0 - ) + with jax.numpy_rank_promotion("allow"): + true_vals = y0 + ((points - t0) / (t1 - t0))[:, None] * ( + y1 - y0 + ) assert tree_allclose(vals, true_vals) jtu.tree_map(_test, firstval, vals, y0, y1) diff --git a/test/test_integrate.py b/test/test_integrate.py index 05a8576b..3bef5ce6 100644 --- a/test/test_integrate.py +++ b/test/test_integrate.py @@ -1,3 +1,4 @@ +import contextlib import math import operator from typing import cast @@ -77,15 
@@ -77,15 +78,16 @@ def test_basic(solver, t_dtype, y_dtype, treedef, stepsize_controller, getkey):
     ) and treedef == jtu.tree_structure(None):
         return
 
-    if jnp.iscomplexobj(y_dtype):
-
-        def f(t, y, args):
-            return jtu.tree_map(lambda _y: operator.mul(-1j, _y), y)
-
+    if jnp.iscomplexobj(y_dtype) and treedef != jtu.tree_structure(None):
         if isinstance(solver, diffrax.AbstractImplicitSolver):
             return
+        else:
+            complex_warn = pytest.warns(match="Complex dtype")
+
+            def f(t, y, args):
+                return jtu.tree_map(lambda yi: -1j * yi, y)
+
     else:
+        complex_warn = contextlib.nullcontext()
 
         def f(t, y, args):
             return jtu.tree_map(operator.neg, y)
@@ -110,15 +112,16 @@ def f(t, y, args):
         raise ValueError
     y0 = random_pytree(getkey(), treedef, dtype=y_dtype)
     try:
-        sol = diffrax.diffeqsolve(
-            diffrax.ODETerm(f),
-            solver,
-            t0,
-            t1,
-            dt0,
-            y0,
-            stepsize_controller=stepsize_controller,
-        )
+        with complex_warn:
+            sol = diffrax.diffeqsolve(
+                diffrax.ODETerm(f),
+                solver,
+                t0,
+                t1,
+                dt0,
+                y0,
+                stepsize_controller=stepsize_controller,
+            )
     except Exception as e:
         if isinstance(stepsize_controller, diffrax.ConstantStepSize) and str(
             e
diff --git a/test/test_saveat_solution.py b/test/test_saveat_solution.py
index 63168a2b..cc73bf99 100644
--- a/test/test_saveat_solution.py
+++ b/test/test_saveat_solution.py
@@ -118,7 +118,8 @@ def test_saveat_solution():
     assert sol.ts.shape == (4096,)  # pyright: ignore
     assert sol.ys.shape == (4096, 1)  # pyright: ignore
     _ts = jnp.where(sol.ts == jnp.inf, jnp.nan, sol.ts)
-    _ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None]
+    with jax.numpy_rank_promotion("allow"):
+        _ys = _y0 * jnp.exp(-0.5 * (_ts - _t0))[:, None]
     _ys = jnp.where(jnp.isnan(_ys), jnp.inf, _ys)
     assert tree_allclose(sol.ys, _ys)
     assert sol.controller_state is None
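[Editor's illustration, not part of the patch.] Point 3 of the commit message is aimed at downstream users: with this patch in place, the strict flags can be enabled without false positives from Diffrax, either globally (as `test/conftest.py` now does) or locally around a single solve. A sketch:

    import jax

    # Globally, mirroring test/conftest.py; the JAX_NUMPY_DTYPE_PROMOTION and
    # JAX_NUMPY_RANK_PROMOTION environment variables are an equivalent alternative.
    jax.config.update("jax_numpy_dtype_promotion", "strict")
    jax.config.update("jax_numpy_rank_promotion", "raise")

    import diffrax
    import jax.numpy as jnp

    # ...or just locally, around a single solve:
    with jax.numpy_dtype_promotion("strict"), jax.numpy_rank_promotion("raise"):
        term = diffrax.ODETerm(lambda t, y, args: -y)
        sol = diffrax.diffeqsolve(
            term, diffrax.Tsit5(), t0=0.0, t1=1.0, dt0=0.1, y0=jnp.array(1.0)
        )
        print(sol.ys)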