diff --git a/README.md b/README.md index 9ae914f..cae92d3 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,8 @@ pip install jpc ``` Requires Python 3.9+, JAX 0.4.23+, [Equinox](https://github.com/patrick-kidger/equinox) -0.11.2+, [Diffrax](https://github.com/patrick-kidger/diffrax) 0.5.1+, and +0.11.2+, [Diffrax](https://github.com/patrick-kidger/diffrax) 0.5.1+, +[Optax](https://github.com/google-deepmind/optax) 0.2.2+, and [Jaxtyping](https://github.com/patrick-kidger/jaxtyping) 0.2.24+. ## Documentation diff --git a/docs/index.md b/docs/index.md index 32c3fe1..1d4e0b1 100644 --- a/docs/index.md +++ b/docs/index.md @@ -18,7 +18,8 @@ pip install jpc ``` Requires Python 3.9+, JAX 0.4.23+, [Equinox](https://github.com/patrick-kidger/equinox) -0.11.2+, [Diffrax](https://github.com/patrick-kidger/diffrax) 0.5.1+, and +0.11.2+, [Diffrax](https://github.com/patrick-kidger/diffrax) 0.5.1+, +[Optax](https://github.com/google-deepmind/optax) 0.2.2+, and [Jaxtyping](https://github.com/patrick-kidger/jaxtyping) 0.2.24+. ## Quick example diff --git a/pyproject.toml b/pyproject.toml index dbc405e..d040f0c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,6 +36,7 @@ dependencies = [ "jax>=0.4.23", "equinox>=0.11.2", "diffrax>=0.5.1", + "optax>=0.2.2", "jaxtyping>=0.2.24" ]