
Merge pull request #231 from patrick-kidger/dev
Version 0.3.0
patrick-kidger authored Feb 21, 2023
2 parents 05d03d8 + 813fe0f commit 9280c3a
Showing 65 changed files with 2,474 additions and 1,991 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/run_tests.yml
@@ -7,7 +7,7 @@ jobs:
run-tests:
strategy:
matrix:
- python-version: [ 3.7, 3.8, 3.9 ]
+ python-version: [ 3.8, 3.9 ]
os: [ ubuntu-latest ]
fail-fast: false
runs-on: ${{ matrix.os }}
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
@@ -4,13 +4,13 @@ repos:
hooks:
- id: black
- repo: https://github.com/nbQA-dev/nbQA
- rev: 1.2.3
+ rev: 1.6.3
hooks:
- id: nbqa-black
- id: nbqa-isort
- id: nbqa-flake8
- repo: https://github.com/PyCQA/isort
- rev: 5.10.1
+ rev: 5.12.0
hooks:
- id: isort
- repo: https://github.com/pycqa/flake8
4 changes: 3 additions & 1 deletion README.md
@@ -21,7 +21,7 @@ _From a technical point of view, the internal structure of the library is pretty
pip install diffrax
```

- Requires Python >=3.7 and JAX >=0.3.4.
+ Requires Python 3.8+, JAX 0.4.3+, and [Equinox](https://github.com/patrick-kidger/equinox) 0.10.0+.

## Documentation

@@ -65,4 +65,6 @@ Neural networks: [Equinox](https://github.com/patrick-kidger/equinox).

Type annotations and runtime checking for PyTrees and shape/dtype of JAX arrays: [jaxtyping](https://github.com/google/jaxtyping).

+ Computer vision models: [Eqxvision](https://github.com/paganpasta/eqxvision).
+
SymPy<->JAX conversion; train symbolic expressions via gradient descent: [sympy2jax](https://github.com/google/sympy2jax).
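
For orientation (not part of the diff): a minimal sketch of solving an ODE with the library at this release, assuming the standard `diffeqsolve`/`ODETerm`/`Dopri5`/`SaveAt` API described in the documentation.

```
import diffrax
import jax.numpy as jnp


def vector_field(t, y, args):
    # dy/dt = -y, whose exact solution is y0 * exp(-t).
    return -y


term = diffrax.ODETerm(vector_field)
solver = diffrax.Dopri5()
saveat = diffrax.SaveAt(ts=jnp.linspace(0.0, 1.0, 5))
sol = diffrax.diffeqsolve(
    term, solver, t0=0.0, t1=1.0, dt0=0.1, y0=jnp.array(1.0), saveat=saveat
)
print(sol.ts, sol.ys)
```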
48 changes: 39 additions & 9 deletions benchmarks/compile_times.py
@@ -3,7 +3,6 @@

import diffrax as dfx
import equinox as eqx
- import fire
import jax
import jax.numpy as jnp
import jax.random as jr
@@ -31,12 +30,12 @@ def __call__(self, t, y, args):
return jnp.stack(y)


- def main(inline: bool, scan_stages: bool, grad: bool, adjoint: str):
- if adjoint == "direct":
+ def run(inline: bool, scan_stages: bool, grad: bool, adjoint_name: str):
+ if adjoint_name == "direct":
adjoint = dfx.DirectAdjoint()
- elif adjoint == "recursive":
+ elif adjoint_name == "recursive":
adjoint = dfx.RecursiveCheckpointAdjoint()
- elif adjoint == "backsolve":
+ elif adjoint_name == "backsolve":
adjoint = dfx.BacksolveAdjoint()
else:
raise ValueError
@@ -72,9 +71,40 @@ def solve(y0):
return jnp.sum(sol.ys)

solve_ = ft.partial(solve, jnp.array([1.0]))
- print("Compile+run time", timeit.timeit(solve_, number=1))
- print("Run time", timeit.timeit(solve_, number=1))
+ compile_time = timeit.timeit(solve_, number=1)
+ print(
+ f"{inline=}, {scan_stages=}, {grad=}, adjoint={adjoint_name}, {compile_time=}"
+ )


if __name__ == "__main__":
- fire.Fire(main)
+ run(inline=False, scan_stages=False, grad=False, adjoint_name="direct")
+ run(inline=False, scan_stages=False, grad=False, adjoint_name="recursive")
+ run(inline=False, scan_stages=False, grad=False, adjoint_name="backsolve")
+
+ run(inline=False, scan_stages=False, grad=True, adjoint_name="direct")
+ run(inline=False, scan_stages=False, grad=True, adjoint_name="recursive")
+ run(inline=False, scan_stages=False, grad=True, adjoint_name="backsolve")
+
+ run(inline=False, scan_stages=True, grad=False, adjoint_name="direct")
+ run(inline=False, scan_stages=True, grad=False, adjoint_name="recursive")
+ run(inline=False, scan_stages=True, grad=False, adjoint_name="backsolve")
+
+ run(inline=False, scan_stages=True, grad=True, adjoint_name="direct")
+ run(inline=False, scan_stages=True, grad=True, adjoint_name="recursive")
+ run(inline=False, scan_stages=True, grad=True, adjoint_name="backsolve")
+
+ run(inline=True, scan_stages=False, grad=False, adjoint_name="direct")
+ run(inline=True, scan_stages=False, grad=False, adjoint_name="recursive")
+ run(inline=True, scan_stages=False, grad=False, adjoint_name="backsolve")
+
+ run(inline=True, scan_stages=False, grad=True, adjoint_name="direct")
+ run(inline=True, scan_stages=False, grad=True, adjoint_name="recursive")
+ run(inline=True, scan_stages=False, grad=True, adjoint_name="backsolve")
+
+ run(inline=True, scan_stages=True, grad=False, adjoint_name="direct")
+ run(inline=True, scan_stages=True, grad=False, adjoint_name="recursive")
+ run(inline=True, scan_stages=True, grad=False, adjoint_name="backsolve")
+
+ run(inline=True, scan_stages=True, grad=True, adjoint_name="direct")
+ run(inline=True, scan_stages=True, grad=True, adjoint_name="recursive")
+ run(inline=True, scan_stages=True, grad=True, adjoint_name="backsolve")
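
The exhaustive sweep above could also be written more compactly with `itertools.product`; a sketch under the same `run` signature, producing the same configurations in the same order (not part of the diff):

```
import itertools

if __name__ == "__main__":
    # Vary adjoint_name fastest, inline slowest, matching the explicit list above.
    for inline, scan_stages, grad, adjoint_name in itertools.product(
        (False, True), (False, True), (False, True), ("direct", "recursive", "backsolve")
    ):
        run(inline=inline, scan_stages=scan_stages, grad=grad, adjoint_name=adjoint_name)
```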
22 changes: 12 additions & 10 deletions benchmarks/scan_stages.py
@@ -1,14 +1,14 @@
"""Benchmarks the effect of `diffrax.AbstractRungeKutta(scan_stages=...)`.
- On my CPU-only machine:
+ On my relatively beefy CPU-only machine:
```
- bash> python scan_stages.py False
- Compile+run time 24.38062646985054
- Run time 0.0018830380868166685
+ scan_stages=True
+ Compile+run time 1.8253102810122073
+ Run time 0.00017526978626847267
- bash> python scan_stages.py True
- Compile+run time 11.418417416978627
- Run time 0.0014536201488226652
+ scan_stages=False
+ Compile+run time 10.679616351146251
+ Run time 0.00021236995235085487
```
"""

@@ -17,7 +17,6 @@

import diffrax as dfx
import equinox as eqx
- import fire
import jax.numpy as jnp
import jax.random as jr

@@ -44,7 +43,7 @@ def __call__(self, t, y, args):
return jnp.stack(y)


- def main(scan_stages):
+ def run(scan_stages):
vf = VectorField(1, 1, 16, 2, key=jr.PRNGKey(0))
term = dfx.ODETerm(vf)
solver = dfx.Dopri8(scan_stages=scan_stages)
@@ -60,8 +59,11 @@ def solve(y0):
)

solve_ = ft.partial(solve, jnp.array([1.0]))
+ print(f"scan_stages={scan_stages}")
print("Compile+run time", timeit.timeit(solve_, number=1))
print("Run time", timeit.timeit(solve_, number=1))


- fire.Fire(main)
+ run(scan_stages=True)
+ print()
+ run(scan_stages=False)
18 changes: 11 additions & 7 deletions benchmarks/scan_stages_cnf.py
@@ -32,7 +32,6 @@

import diffrax
import equinox as eqx
- import fire
import jax
import jax.nn as jnn
import jax.numpy as jnp
@@ -50,7 +49,7 @@ def vector_field_prob(t, input, model):
return f, logp


- @eqx.filter_vmap(args=(None, 0, None, None))
+ @eqx.filter_vmap(in_axes=(None, 0, None, None))
def log_prob(model, y0, scan_stages, backsolve):
term = diffrax.ODETerm(vector_field_prob)
solver = diffrax.Dopri5(scan_stages=scan_stages)
@@ -80,13 +79,18 @@ def solve(model, inputs, scan_stages, backsolve):
return -log_prob(model, inputs, scan_stages, backsolve).mean()


- def main(scan_stages, backsolve):
+ def run(scan_stages, backsolve):
mkey, dkey = jr.split(jr.PRNGKey(0), 2)
model = eqx.nn.MLP(2, 2, 10, 2, activation=jnn.gelu, key=mkey)
x = jr.normal(dkey, (256, 2))
- solve_ = ft.partial(solve, model, x, scan_stages, backsolve)
- print("Compile+run time", timeit.timeit(solve_, number=1))
- print("Run time", timeit.timeit(solve_, number=1))
+ solve2 = ft.partial(solve, model, x, scan_stages, backsolve)
+ print(f"scan_stages={scan_stages}, backsolve={backsolve}")
+ print("Compile+run time", timeit.timeit(solve2, number=1))
+ print("Run time", timeit.timeit(solve2, number=1))
+ print()


- fire.Fire(main)
+ run(scan_stages=False, backsolve=False)
+ run(scan_stages=False, backsolve=True)
+ run(scan_stages=True, backsolve=False)
+ run(scan_stages=True, backsolve=True)
28 changes: 16 additions & 12 deletions benchmarks/small_neural_ode.py
@@ -1,9 +1,10 @@
"""Benchmarks Diffrax vs torchdiffeq vs jax.experimental.ode.odeint"""

import gc
import time

import diffrax
import equinox as eqx
- import fire
import jax
import jax.experimental.ode as experimental
import jax.nn as jnn
@@ -166,7 +167,7 @@ def time_jax(neural_ode_jax, y0, t1, grad):
_eval_jax(neural_ode_jax, y0, t1)


- def main(batch_size=64, t1=100, multiple=False, grad=False):
+ def run(multiple, grad, batch_size=64, t1=100):
neural_ode_torch = NeuralODETorch(multiple)
neural_ode_diffrax = NeuralODEDiffrax(multiple)
neural_ode_experimental = NeuralODEExperimental(multiple)
@@ -180,25 +181,28 @@ def main(batch_size=64, t1=100, multiple=False, grad=False):
func_torch[2].bias.copy_(torch.tensor(np.asarray(func_jax.layers[1].bias)))

y0_jax = jrandom.normal(jrandom.PRNGKey(1), (batch_size, 4))
- y0_torch = torch.tensor(y0_jax.to_py())
+ y0_torch = torch.tensor(np.asarray(y0_jax))

- time_torch(neural_ode_torch, y0_torch, t1, grad)
+ torch_time = time_torch(neural_ode_torch, y0_torch, t1, grad)

- time_jax(neural_ode_diffrax, y0_jax, t1, grad)
- diffrax_time = time_jax(neural_ode_diffrax, y0_jax, t1, grad)
+ time_jax(neural_ode_diffrax, jnp.copy(y0_jax), t1, grad)
+ diffrax_time = time_jax(neural_ode_diffrax, jnp.copy(y0_jax), t1, grad)

- time_jax(neural_ode_experimental, y0_jax, t1, grad)
- experimental_time = time_jax(neural_ode_experimental, y0_jax, t1, grad)
+ time_jax(neural_ode_experimental, jnp.copy(y0_jax), t1, grad)
+ experimental_time = time_jax(neural_ode_experimental, jnp.copy(y0_jax), t1, grad)

print(
- f"""
- torch_time={torch_time}
- diffrax_time={diffrax_time}
- experimetnal_time={experimental_time}
+ f""" multiple={multiple}, grad={grad}
+ torch_time={torch_time}
+ diffrax_time={diffrax_time}
+ experimental_time={experimental_time}
"""
)


if __name__ == "__main__":
- fire.Fire(main)
+ run(multiple=False, grad=False)
+ run(multiple=True, grad=False)
+ run(multiple=False, grad=True)
+ run(multiple=True, grad=True)
10 changes: 5 additions & 5 deletions diffrax/__init__.py
@@ -1,10 +1,11 @@
from .adjoint import (
AbstractAdjoint,
BacksolveAdjoint,
DirectAdjoint,
ImplicitAdjoint,
NoAdjoint,
RecursiveCheckpointAdjoint,
)
from .autocitation import citation, citation_rules
from .brownian import AbstractBrownianPath, UnsafeBrownianPath, VirtualBrownianTree
from .event import (
AbstractDiscreteTerminatingEvent,
@@ -27,14 +28,14 @@
LocalLinearInterpolation,
ThirdOrderHermitePolynomialInterpolation,
)
- from .misc import adjoint_rms_seminorm, sde_kl_divergence
+ from .misc import adjoint_rms_seminorm
from .nonlinear_solver import (
AbstractNonlinearSolver,
NewtonNonlinearSolver,
NonlinearSolution,
)
from .path import AbstractPath
- from .saveat import SaveAt
+ from .saveat import SaveAt, SubSaveAt
from .solution import is_event, is_okay, is_successful, RESULTS, Solution
from .solver import (
AbstractAdaptiveSolver,
Expand All @@ -55,7 +56,6 @@
Dopri8,
Euler,
EulerHeun,
Fehlberg2,
HalfSolver,
Heun,
ImplicitEuler,
@@ -87,4 +87,4 @@
)


- __version__ = "0.2.2"
+ __version__ = "0.3.0"
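
For context on the newly exported `SubSaveAt`, a sketch of how it might be used (not part of the diff; it assumes `SaveAt` accepts a PyTree of `SubSaveAt` via its `subs` argument, with each `SubSaveAt` taking `t0`/`t1`/`ts`/`fn` options):

```
import diffrax
import jax.numpy as jnp

# Save the full state on a time grid, and only the first component at t1.
saveat = diffrax.SaveAt(
    subs=(
        diffrax.SubSaveAt(ts=jnp.linspace(0.0, 1.0, 11)),
        diffrax.SubSaveAt(t1=True, fn=lambda t, y, args: y[0]),
    )
)
sol = diffrax.diffeqsolve(
    diffrax.ODETerm(lambda t, y, args: -y),
    diffrax.Tsit5(),
    t0=0.0,
    t1=1.0,
    dt0=0.1,
    y0=jnp.array([1.0, 2.0]),
    saveat=saveat,
)
grid_ys, final_first_component = sol.ys  # output structure mirrors `subs`
```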
File renamed without changes.
