diff --git a/tests/test_spec_elementwise_functions.py b/tests/test_spec_elementwise_functions.py index 0a4c164..7afebd2 100644 --- a/tests/test_spec_elementwise_functions.py +++ b/tests/test_spec_elementwise_functions.py @@ -21,16 +21,21 @@ has_complex_dtype = True -if np.lib.NumpyVersion(np.__version__) < "2.0.0b1": - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - import array_api_strict as xp # type: ignore[import-not-found] - - has_complex_dtype = np.dtype("complex128") in xp._dtypes._all_dtypes -else: - xp = np - +# if np.lib.NumpyVersion(np.__version__) < "2.0.0b1": +# with warnings.catch_warnings(): +# warnings.simplefilter("ignore") +# import array_api_strict as xp # type: ignore[import-not-found] + +# has_complex_dtype = np.dtype("complex128") in xp._dtypes._all_dtypes +# else: +# xp = np devices = ["cpu"] +try: + import numpy.array_api as xp + + has_complex_dtype = np.dtype("complex128") in xp._dtypes._all_dtypes +except ModuleNotFoundError: + import numpy as xp # noqa: ICN001 try: import cupy as cp