Skip to content

Commit

Permalink
Added support for custom arrays.
Browse files Browse the repository at this point in the history
  • Loading branch information
patrick-kidger committed Jun 25, 2024
1 parent 099b189 commit b0bbff9
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 9 deletions.
14 changes: 6 additions & 8 deletions jaxtyping/_array_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,14 +188,12 @@ def __instancecheck_str__(cls, obj: Any) -> str:
# TensorFlow
dtype = obj.dtype.as_numpy_dtype.__name__
else:
# PyTorch
repr_dtype = repr(obj.dtype).split(".")
if len(repr_dtype) == 2 and repr_dtype[0] == "torch":
dtype = repr_dtype[1]
else:
raise AnnotationError(
"Unrecognised array/tensor type to extract dtype from"
)
# Everyone else, including PyTorch.
# This offers an escape hatch for anyone looking to use jaxtyping for their
# own array-like types.
dtype = obj.dtype
if not isinstance(dtype, str):
*_, dtype = repr(obj.dtype).rsplit(".", 1)

if cls.dtypes is not _any_dtype:
in_dtypes = False
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "jaxtyping"
version = "0.2.30"
version = "0.2.31"
description = "Type annotations and runtime checking for shape and dtype of JAX arrays, and PyTrees."
readme = "README.md"
requires-python ="~=3.9"
Expand Down
56 changes: 56 additions & 0 deletions test/test_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -710,3 +710,59 @@ def test_scalar_variadic_dim():
def test_scalar_dtype_mismatch():
with pytest.raises(ValueError):
Float[bool, "..."]


def test_custom_array(jaxtyp, typecheck):
class MyArray1:
@property
def dtype(self):
return "foo"

@property
def shape(self):
return (3,)

class MyArray2:
@property
def dtype(self):
return "bar"

@property
def shape(self):
return (3,)

class MyArray3:
@property
def dtype(self):
return "foo"

@property
def shape(self):
return (4,)

class FooDtype(AbstractDtype):
dtypes = ["foo"]

@jaxtyp(typecheck)
def f(x: FooDtype[MyArray1, "3"]):
pass

f(MyArray1())
with pytest.raises(ParamError):
f(MyArray2())
with pytest.raises(ParamError):
f(MyArray3())

@jaxtyp(typecheck)
def g(x: FooDtype[MyArray1, "3"], y: FooDtype[MyArray1, "4"]):
pass

with pytest.raises(ParamError):
g(MyArray1(), MyArray1())

@jaxtyp(typecheck)
def h(x: FooDtype[MyArray1, "3"], y: FooDtype[MyArray3, "4"]):
pass

with pytest.raises(ParamError):
g(MyArray1(), MyArray3())

0 comments on commit b0bbff9

Please sign in to comment.