diff --git a/jaxtyping/_array_types.py b/jaxtyping/_array_types.py index d77c496..fb539ae 100644 --- a/jaxtyping/_array_types.py +++ b/jaxtyping/_array_types.py @@ -19,6 +19,7 @@ import enum import functools as ft +import importlib.util import re import sys import types @@ -26,7 +27,17 @@ from dataclasses import dataclass from typing import Any, Literal, NoReturn, Optional, Union -import numpy as np + +# Bit of a hack, but jaxtyping provides nicer error messages than typeguard. This means +# we sometimes want to use it as our runtime type checker everywhere, even in non-array +# use-cases, for which numpy is too heavy a dependency. +# Honestly we should probably consider factoring out part of jaxtyping into a separate +# package. (Specifically (a) the multi-argument checking and (b) the better error +# messages and (c) the import hook that places the checker on the bottom of the +# decorator stack.) And resist the urge to write our own runtime type-checker, I really +# don't want to have to keep that up-to-date with changes in the Python typing spec... +if importlib.util.find_spec("numpy") is not None: + import numpy as np from ._errors import AnnotationError from ._storage import ( diff --git a/pyproject.toml b/pyproject.toml index e50b329..c42cb26 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -23,7 +23,7 @@ classifiers = [ "Topic :: Scientific/Engineering :: Mathematics", ] urls = {repository = "https://github.com/google/jaxtyping" } -dependencies = ["numpy>=1.20.0", "typeguard==2.13.3"] +dependencies = ["typeguard==2.13.3"] entry-points = {pytest11 = {jaxtyping = "jaxtyping._pytest_plugin"}} [build-system]