Skip to content

Commit

Permalink
Improve docs to mention all the variant arraylikes we now support
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Nov 17, 2024
1 parent 8823419 commit d28a86c
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 7 deletions.
41 changes: 34 additions & 7 deletions docs/api/array.md
Original file line number Diff line number Diff line change
Expand Up @@ -84,18 +84,45 @@ Float32[Array, "some_shape"]

## Array

The array should usually be a `jaxtyping.Array`, which is an alias for `jax.numpy.ndarray` (which is itself an alias for `jax.Array`).
The array should typically be either one of:
```python
jaxtyping.Array / jax.Array / jax.numpy.ndarray # these are all aliases of one another
np.ndarray
torch.Tensor
tf.Tensor
```
That is -- despite the now-historical name! -- jaxtyping also supports NumPy + PyTorch + TensorFlow.

Some other types are also supported here:

`jaxtyping.ArrayLike` is also available, which is an alias for `jax.typing.ArrayLike`. This is a union over JAX arrays and the builtin `bool`/`int`/`float`/`complex`.
**Unions:** these are unpacked. For example, `SomeDtype[Union[A, B], "some shape"]` is equivalent to `Union[SomeDtype[A, "some shape"], SomeDtype[B, "some shape"]]`. A common example of a union type here is `np.typing.ArrayLike`.

You can use non-JAX types as well. jaxtyping also supports NumPy, TensorFlow, and PyTorch, e.g.:
**Any:** use `typing.Any` to check just the shape/dtype, but not the array type.

**Duck-type arrays:** anything with `.shape` and `.dtype` attributes. For example,
```python
Float[np.ndarray, "..."]
Float[tf.Tensor, "..."]
Float[torch.Tensor, "..."]
class MyDuckArray:
@property
def shape(self) -> tuple[int, ...]:
return (3, 4, 5)

@property
def dtype(self) -> str:
return "my_dtype"

class MyDtype(jaxtyping.AbstractDtype):
dtypes = ["my_dtype"]

x = MyDuckArray()
assert isinstance(x, MyDtype[MyDuckArray, "3 4 5"])
# checks that `type(x) == MyDuckArray`
# and that `x.shape == (3, 4, 5)`
# and that `x.dtype == "my_dtype"`
```

Shape-and-dtype specified jaxtyping arrays can also be used, e.g.
**TypeVars:** in this case the runtime array is checked for matching the bounds or constraints of the `typing.TypeVar`.

**Existing jaxtyped annotations:**
```python
Image = Float[Array, "channels height width"]
BatchImage = Float[Image, "batch"]
Expand Down
2 changes: 2 additions & 0 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ pytkdocs_tweaks==0.0.8 # Tweaks mkdocstrings to improve various aspects
mkdocs_include_exclude_files==0.0.1 # Tweak which files are included/excluded
jinja2==3.0.3 # Older version. After 3.1.0 seems to be incompatible with current versions of mkdocstrings.
pygments==2.14.0
mkdocs-autorefs==1.0.1
mkdocs-material-extensions==1.3.1

# Dependencies of jaxtyping itself.
# Always use most up-to-date versions.
Expand Down

0 comments on commit d28a86c

Please sign in to comment.