diff --git a/docs/api/array.md b/docs/api/array.md index 6f33c9b..3e97982 100644 --- a/docs/api/array.md +++ b/docs/api/array.md @@ -84,18 +84,45 @@ Float32[Array, "some_shape"] ## Array -The array should usually be a `jaxtyping.Array`, which is an alias for `jax.numpy.ndarray` (which is itself an alias for `jax.Array`). +The array should typically be either one of: +```python +jaxtyping.Array / jax.Array / jax.numpy.ndarray # these are all aliases of one another +np.ndarray +torch.Tensor +tf.Tensor +``` +That is -- despite the now-historical name! -- jaxtyping also supports NumPy + PyTorch + TensorFlow. + +Some other types are also supported here: -`jaxtyping.ArrayLike` is also available, which is an alias for `jax.typing.ArrayLike`. This is a union over JAX arrays and the builtin `bool`/`int`/`float`/`complex`. +**Unions:** these are unpacked. For example, `SomeDtype[Union[A, B], "some shape"]` is equivalent to `Union[SomeDtype[A, "some shape"], SomeDtype[B, "some shape"]]`. A common example of a union type here is `np.typing.ArrayLike`. -You can use non-JAX types as well. jaxtyping also supports NumPy, TensorFlow, and PyTorch, e.g.: +**Any:** use `typing.Any` to check just the shape/dtype, but not the array type. + +**Duck-type arrays:** anything with `.shape` and `.dtype` attributes. For example, ```python -Float[np.ndarray, "..."] -Float[tf.Tensor, "..."] -Float[torch.Tensor, "..."] +class MyDuckArray: + @property + def shape(self) -> tuple[int, ...]: + return (3, 4, 5) + + @property + def dtype(self) -> str: + return "my_dtype" + +class MyDtype(jaxtyping.AbstractDtype): + dtypes = ["my_dtype"] + +x = MyDuckArray() +assert isinstance(x, MyDtype[MyDuckArray, "3 4 5"]) +# checks that `type(x) == MyDuckArray` +# and that `x.shape == (3, 4, 5)` +# and that `x.dtype == "my_dtype"` ``` -Shape-and-dtype specified jaxtyping arrays can also be used, e.g. +**TypeVars:** in this case the runtime array is checked for matching the bounds or constraints of the `typing.TypeVar`. + +**Existing jaxtyped annotations:** ```python Image = Float[Array, "channels height width"] BatchImage = Float[Image, "batch"] diff --git a/docs/requirements.txt b/docs/requirements.txt index 00a1607..23b2733 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -7,6 +7,8 @@ pytkdocs_tweaks==0.0.8 # Tweaks mkdocstrings to improve various aspects mkdocs_include_exclude_files==0.0.1 # Tweak which files are included/excluded jinja2==3.0.3 # Older version. After 3.1.0 seems to be incompatible with current versions of mkdocstrings. pygments==2.14.0 +mkdocs-autorefs==1.0.1 +mkdocs-material-extensions==1.3.1 # Dependencies of jaxtyping itself. # Always use most up-to-date versions.