Skip to content

Commit

Permalink
Now using strict shape/dtype promotion rules.
Browse files Browse the repository at this point in the history
This means that:

1. Tests now pass using `JAX_NUMPY_DTYPE_PROMOTION=strict` and `JAX_NUMPY_RANK_PROMOTION=raise`, and these are enabled in tests by default.
2. The values passed to `diffeqsolve` now determine the dtype used in the integration more carefully (previously this was mostly ad hoc: the dtype was simply whatever the various interacting arrays happened to promote to):
    a. The dtype of timelike values is the `jnp.result_type` of `t0`, `t1`, `dt0`, and `SaveAt(ts=...)`. If any of these are complex an error is raised. If these are all integers we use the default floating-point dtype.
    b. The dtype of each leaf of `y0` is the `jnp.result_type` of the time dtype and that leaf.
3. Of course, `diffeqsolve` accepts user-specified functions (e.g. the vector field of an `ODETerm`), and these could potentially return arrays with dtypes that do not match the ones we have selected above, which might cause further upcasting. For the sake of backward compatibility we don't try to prohibit this -- a user who feels strongly about this should enable `JAX_NUMPY_DTYPE_PROMOTION=strict` and fix their vector fields appropriately. (And can then be assured that the dtypes of these quantities are exactly as specified by the rules above.) So the key thing this commit enables is that using this flag to enforce this is now possible, without any false positives from Diffrax itself!
  • Loading branch information
patrick-kidger committed Jan 8, 2024
1 parent 7965e89 commit 0ee47c9
Show file tree
Hide file tree
Showing 13 changed files with 184 additions and 71 deletions.
12 changes: 9 additions & 3 deletions diffrax/_brownian/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,11 @@
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import lineax.internal as lxi
from jaxtyping import Array, PRNGKeyArray, PyTree

from .._custom_types import levy_tree_transpose, LevyArea, LevyVal, RealScalarLike
from .._misc import (
default_floating_dtype,
force_bitcast_convert_type,
is_tuple_of_ints,
split_by_tree,
Expand Down Expand Up @@ -52,7 +52,7 @@ def __init__(
levy_area: LevyArea = "",
):
self.shape = (
jax.ShapeDtypeStruct(shape, default_floating_dtype())
jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype())
if is_tuple_of_ints(shape)
else shape
)
Expand Down Expand Up @@ -87,8 +87,14 @@ def evaluate(
) -> Union[PyTree[Array], LevyVal]:
del left
if t1 is None:
dtype = jnp.result_type(t0)
t1 = t0
t0 = 0
t0 = jnp.array(0, dtype)
else:
with jax.numpy_dtype_promotion("standard"):
dtype = jnp.result_type(t0, t1)
t0 = jnp.astype(t0, dtype)
t1 = jnp.astype(t1, dtype)
t0 = eqxi.nondifferentiable(t0, name="t0")
t1 = eqxi.nondifferentiable(t1, name="t1")
t1 = cast(RealScalarLike, t1)
Expand Down
4 changes: 2 additions & 2 deletions diffrax/_brownian/tree.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import jax.numpy as jnp
import jax.random as jr
import jax.tree_util as jtu
import lineax.internal as lxi
from jaxtyping import Array, Float, PRNGKeyArray, PyTree

from .._custom_types import (
Expand All @@ -20,7 +21,6 @@
RealScalarLike,
)
from .._misc import (
default_floating_dtype,
is_tuple_of_ints,
linear_rescale,
split_by_tree,
Expand Down Expand Up @@ -179,7 +179,7 @@ def __init__(
self.levy_area = levy_area
self._spline = _spline
self.shape = (
jax.ShapeDtypeStruct(shape, default_floating_dtype())
jax.ShapeDtypeStruct(shape, lxi.default_floating_dtype())
if is_tuple_of_ints(shape)
else shape
)
Expand Down
38 changes: 27 additions & 11 deletions diffrax/_integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import jax.core
import jax.numpy as jnp
import jax.tree_util as jtu
import lineax.internal as lxi
from jaxtyping import Array, ArrayLike, Float, Inexact, PyTree, Real

from ._adjoint import AbstractAdjoint, RecursiveCheckpointAdjoint
Expand Down Expand Up @@ -299,10 +300,10 @@ def body_fun_aux(state):

# Count the number of steps, just for statistical purposes.
num_steps = state.num_steps + 1
num_accepted_steps = state.num_accepted_steps + keep_step
num_accepted_steps = state.num_accepted_steps + jnp.where(keep_step, 1, 0)
# Not just ~keep_step, which does the wrong thing when keep_step is a non-array
# bool True/False.
num_rejected_steps = state.num_rejected_steps + jnp.invert(keep_step)
num_rejected_steps = state.num_rejected_steps + jnp.where(keep_step, 0, 1)

#
# Store the output produced from this numerical step.
Expand Down Expand Up @@ -369,7 +370,7 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
subsaveat.fn(tprev, y, args),
save_state.ys,
)
save_index = save_state.save_index + keep_step
save_index = save_state.save_index + jnp.where(keep_step, 1, 0)
save_state = eqx.tree_at(
lambda s: [s.ts, s.ys, s.save_index],
save_state,
Expand All @@ -388,7 +389,7 @@ def save_steps(subsaveat: SubSaveAt, save_state: SaveState) -> SaveState:
dense_info,
dense_infos,
)
dense_save_index = dense_save_index + keep_step
dense_save_index = dense_save_index + jnp.where(keep_step, 1, 0)

new_state = State(
y=y,
Expand Down Expand Up @@ -625,7 +626,7 @@ def diffeqsolve(
f"t0 with value {t0} and type {type(t0)}, "
f"dt0 with value {dt0} and type {type(dt0)}"
)
with jax.ensure_compile_time_eval():
with jax.ensure_compile_time_eval(), jax.numpy_dtype_promotion("standard"):
pred = (t1 - t0) * dt0 < 0
dt0 = eqxi.error_if(jnp.array(dt0), pred, msg)

Expand All @@ -641,7 +642,8 @@ def diffeqsolve(
)
warnings.warn(
"Complex dtype support is work in progress, please read "
"https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully."
"https://github.com/patrick-kidger/diffrax/pull/197 and proceed carefully.",
stacklevel=2,
)

# Backward compatibility
Expand All @@ -653,7 +655,8 @@ def diffeqsolve(
f"{solver.__class__.__name__} is deprecated in favour of "
"`terms=MultiTerm(ODETerm(...), SomeOtherTerm(...))`. This means that "
"the same terms can now be passed used for both general and SDE-specific "
"solvers!"
"solvers!",
stacklevel=2,
)
terms = MultiTerm(*terms)

Expand All @@ -668,7 +671,8 @@ def diffeqsolve(
if not isinstance(solver, (AbstractItoSolver, AbstractStratonovichSolver)):
warnings.warn(
f"`{type(solver).__name__}` is not marked as converging to either the "
"Itô or the Stratonovich solution."
"Itô or the Stratonovich solution.",
stacklevel=2,
)
if isinstance(stepsize_controller, AbstractAdaptiveStepSizeController):
# Specific check to not work even if using HalfSolver(Euler())
Expand All @@ -684,11 +688,22 @@ def diffeqsolve(
)

# Allow setting e.g. t0 as an int with dt0 as a float.
timelikes = [jnp.array(0.0), t0, t1, dt0] + [
timelikes = [t0, t1, dt0] + [
s.ts for s in jtu.tree_leaves(saveat.subs, is_leaf=_is_subsaveat)
]
timelikes = [x for x in timelikes if x is not None]
time_dtype = jnp.result_type(*timelikes)
with jax.numpy_dtype_promotion("standard"):
time_dtype = jnp.result_type(*timelikes)
if jnp.issubdtype(time_dtype, jnp.complexfloating):
raise ValueError(
"Cannot use complex dtype for `t0`, `t1`, `dt0`, or `SaveAt(ts=...)`."
)
elif jnp.issubdtype(time_dtype, jnp.floating):
pass
elif jnp.issubdtype(time_dtype, jnp.integer):
time_dtype = lxi.default_floating_dtype()
else:
raise ValueError(f"Unrecognised time dtype {time_dtype}.")
t0 = jnp.asarray(t0, dtype=time_dtype)
t1 = jnp.asarray(t1, dtype=time_dtype)
if dt0 is not None:
Expand All @@ -708,7 +723,8 @@ def _get_subsaveat_ts(saveat):
# fixing issue with float64 and weak dtypes, see discussion at:
# https://github.com/patrick-kidger/diffrax/pull/197#discussion_r1130173527
def _promote(yi):
_dtype = jnp.result_type(yi, time_dtype) # noqa: F821
with jax.numpy_dtype_promotion("standard"):
_dtype = jnp.result_type(yi, time_dtype) # noqa: F821
return jnp.asarray(yi, dtype=_dtype)

y0 = jtu.tree_map(_promote, y0)
Expand Down
32 changes: 27 additions & 5 deletions diffrax/_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,8 +162,30 @@ def static_select(pred: BoolScalarLike, a: ArrayLike, b: ArrayLike) -> ArrayLike
return lax.select(pred, a, b)


def default_floating_dtype():
    """Return `jnp.float64` if JAX's x64 mode is enabled, else `jnp.float32`."""
    x64_enabled = jax.config.jax_enable_x64  # pyright: ignore
    return jnp.float64 if x64_enabled else jnp.float32
def upcast_or_raise(
    x: ArrayLike, array_for_dtype: ArrayLike, x_name: str, dtype_name: str
):
    """Cast `x` to the promotion of its dtype with that of `array_for_dtype`.

    If `JAX_NUMPY_DTYPE_PROMOTION=strict`, then this will raise an error if
    `jnp.result_type(x, array_for_dtype)` is not the same as the dtype of
    `array_for_dtype`. It will then cast `x` to
    `jnp.result_type(x, array_for_dtype)`.

    Thus if `JAX_NUMPY_DTYPE_PROMOTION=standard`, then the usual anything-goes
    behaviour will apply. If `JAX_NUMPY_DTYPE_PROMOTION=strict` then we loosen from
    prohibiting all dtype casting, to still allowing upcasting.

    **Arguments:**

    - `x`: the value to cast.
    - `array_for_dtype`: the value whose dtype `x` should be upcastable to.
    - `x_name`: how to refer to `x` in the error message.
    - `dtype_name`: how to refer to `array_for_dtype` in the error message.

    **Returns:**

    `x`, cast to the promoted dtype.

    **Raises:**

    - `ValueError`: under strict promotion, if promoting would change the dtype of
        `array_for_dtype`.
    - `AssertionError`: if the promotion config is neither `"strict"` nor
        `"standard"`.
    """
    x_dtype = jnp.result_type(x)
    target_dtype = jnp.result_type(array_for_dtype)
    # Computing the promoted dtype is itself a promotion, so it must be done under
    # the standard rules regardless of the ambient configuration.
    with jax.numpy_dtype_promotion("standard"):
        promote_dtype = jnp.result_type(x_dtype, target_dtype)
    config_value = jax.config.jax_numpy_dtype_promotion
    if config_value == "strict":
        if target_dtype != promote_dtype:
            raise ValueError(
                f"When `JAX_NUMPY_DTYPE_PROMOTION=strict`, then {x_name} must have "
                f"a dtype that can be promoted to the dtype of {dtype_name}. "
                f"However {x_name} had dtype {x_dtype} and {dtype_name} had dtype "
                f"{target_dtype}."
            )
    elif config_value != "standard":
        # Raise explicitly rather than via `assert False`: assert statements are
        # removed under `python -O`, which would let an unrecognised mode silently
        # fall through to the cast below.
        raise AssertionError(
            f"Unrecognised `JAX_NUMPY_DTYPE_PROMOTION={config_value}`"
        )
    return jnp.astype(x, promote_dtype)
57 changes: 45 additions & 12 deletions diffrax/_step_size_controller/adaptive.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
VF,
Y,
)
from .._misc import upcast_or_raise
from .._solution import RESULTS
from .._term import AbstractTerm, ODETerm
from .base import AbstractStepSizeController
Expand Down Expand Up @@ -325,6 +326,14 @@ class PIDController(
safety: RealScalarLike = 0.9
error_order: Optional[RealScalarLike] = None

def __check_init__(self):
    """Validate `jump_ts` after initialisation.

    Does nothing if `jump_ts` is `None` or has a dtype that is a subdtype of
    `jnp.inexact`; otherwise raises a `ValueError`.
    """
    jump_ts = self.jump_ts
    if jump_ts is None:
        return
    if jnp.issubdtype(jump_ts.dtype, jnp.inexact):
        return
    raise ValueError(f"jump_ts must be floating point, not {self.jump_ts.dtype}")

def wrap(self, direction: IntScalarLike):
step_ts = None if self.step_ts is None else self.step_ts * direction
jump_ts = None if self.jump_ts is None else self.jump_ts * direction
Expand Down Expand Up @@ -632,18 +641,30 @@ def _clip_step_ts(self, t0: RealScalarLike, t1: RealScalarLike) -> RealScalarLik
if self.step_ts is None:
return t1

step_ts0 = upcast_or_raise(
self.step_ts,
t0,
"`PIDController.step_ts`",
"time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
)
step_ts1 = upcast_or_raise(
self.step_ts,
t1,
"`PIDController.step_ts`",
"time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
)
# TODO: it should be possible to switch this O(nlogn) for just O(n) by keeping
# track of where we were last, and using that as a hint for the next search.
t0_index = jnp.searchsorted(self.step_ts, t0, side="right")
t1_index = jnp.searchsorted(self.step_ts, t1, side="right")
t0_index = jnp.searchsorted(step_ts0, t0, side="right")
t1_index = jnp.searchsorted(step_ts1, t1, side="right")
# This minimum may or may not actually be necessary. The left branch is taken
# iff t0_index < t1_index <= len(self.step_ts), so all valid t0_index s must
# already satisfy the minimum.
# However, that branch is actually executed unconditionally and then where'd,
# so we clamp it just to be sure we're not hitting undefined behaviour.
t1 = jnp.where(
t0_index < t1_index,
self.step_ts[jnp.minimum(t0_index, len(self.step_ts) - 1)],
step_ts1[jnp.minimum(t0_index, len(self.step_ts) - 1)],
t1,
)
return t1
Expand All @@ -653,23 +674,35 @@ def _clip_jump_ts(
) -> tuple[RealScalarLike, BoolScalarLike]:
if self.jump_ts is None:
return t1, False
if self.jump_ts is not None and not jnp.issubdtype(
self.jump_ts.dtype, jnp.inexact
):
assert jnp.issubdtype(self.jump_ts.dtype, jnp.inexact)
if not jnp.issubdtype(jnp.result_type(t0), jnp.inexact):
raise ValueError(
f"jump_ts must be floating point, not {self.jump_ts.dtype}"
"`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. "
f"Got {jnp.result_type(t0)}."
)
if not jnp.issubdtype(jnp.result_type(t1), jnp.inexact):
raise ValueError(
"t0, t1, dt0 must be floating point when specifying jump_t. Got "
f"{jnp.result_type(t1)}."
"`t0`, `t1`, `dt0` must be floating point when specifying `jump_ts`. "
f"Got {jnp.result_type(t1)}."
)
t0_index = jnp.searchsorted(self.jump_ts, t0, side="right")
t1_index = jnp.searchsorted(self.jump_ts, t1, side="right")
jump_ts0 = upcast_or_raise(
self.jump_ts,
t0,
"`PIDController.jump_ts`",
"time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
)
jump_ts1 = upcast_or_raise(
self.jump_ts,
t1,
"`PIDController.jump_ts`",
"time (the result type of `t0`, `t1`, `dt0`, `SaveAt(ts=...)` etc.)",
)
t0_index = jnp.searchsorted(jump_ts0, t0, side="right")
t1_index = jnp.searchsorted(jump_ts1, t1, side="right")
next_made_jump = t0_index < t1_index
t1 = jnp.where(
next_made_jump,
eqxi.prevbefore(self.jump_ts[jnp.minimum(t0_index, len(self.jump_ts) - 1)]),
eqxi.prevbefore(jump_ts1[jnp.minimum(t0_index, len(self.jump_ts) - 1)]),
t1,
)
return t1, next_made_jump
Expand Down
28 changes: 25 additions & 3 deletions diffrax/_term.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from jaxtyping import Array, ArrayLike, PyTree, PyTreeDef

from ._custom_types import Args, Control, IntScalarLike, RealScalarLike, VF, Y
from ._misc import upcast_or_raise
from ._path import AbstractPath


Expand Down Expand Up @@ -159,7 +160,8 @@ class ODETerm(AbstractTerm):
appearing on the right hand side of an ODE, in which the control is time.
`vector_field` should return some PyTree, with the same structure as the initial
state `y0`, and with every leaf broadcastable to the equivalent leaf in `y0`.
state `y0`, and with every leaf shape-broadcastable and dtype-upcastable to the
equivalent leaf in `y0`.
!!! example
Expand All @@ -179,13 +181,33 @@ def vf(self, t: RealScalarLike, y: Y, args: Args) -> VF:
"The vector field inside `ODETerm` must return a pytree with the "
"same structure as `y0`."
)
return jtu.tree_map(lambda o, yi: jnp.broadcast_to(o, jnp.shape(yi)), out, y)

def _broadcast_and_upcast(oi, yi):
oi = jnp.broadcast_to(oi, jnp.shape(yi))
oi = upcast_or_raise(
oi,
yi,
"the vector field passed to `ODETerm`",
"the corresponding leaf of `y`",
)
return oi

return jtu.tree_map(_broadcast_and_upcast, out, y)

def contr(self, t0: RealScalarLike, t1: RealScalarLike) -> RealScalarLike:
    """Return the control over `[t0, t1]`: the elapsed time `t1 - t0`."""
    elapsed = t1 - t0
    return elapsed

def prod(self, vf: VF, control: RealScalarLike) -> Y:
return jtu.tree_map(lambda v: control * v, vf)
def _mul(v):
c = upcast_or_raise(
control,
v,
"the output of `ODETerm.contr(...)`",
"the output of `ODETerm.vf(...)`",
)
return c * v

return jtu.tree_map(_mul, vf)


ODETerm.__init__.__doc__ = """**Arguments:**
Expand Down
2 changes: 2 additions & 0 deletions test/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@


jax.config.update("jax_enable_x64", True) # pyright: ignore
jax.config.update("jax_numpy_rank_promotion", "raise") # pyright: ignore
jax.config.update("jax_numpy_dtype_promotion", "strict") # pyright: ignore


@pytest.fixture()
Expand Down
3 changes: 2 additions & 1 deletion test/test_adaptive_stepsize_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,8 @@ def run(ys, controller, state):
_, tprev, tnext, _, state, _ = controller.adapt_step_size(
0, 1, y0, y1_candidate, None, y_error, 5, state
)
return tprev + tnext + sum(jnp.sum(x) for x in jtu.tree_leaves(state))
with jax.numpy_dtype_promotion("standard"):
return tprev + tnext + sum(jnp.sum(x) for x in jtu.tree_leaves(state))

y0 = jnp.array(1.0)
y1_candidate = jnp.array(2.0)
Expand Down
Loading

0 comments on commit 0ee47c9

Please sign in to comment.