diff --git a/tests/mypy.ini.txt b/tests/mypy.ini.txt new file mode 100644 index 0000000..1215375 --- /dev/null +++ b/tests/mypy.ini.txt @@ -0,0 +1,2 @@ +[mypy] +ignore_missing_imports = True \ No newline at end of file diff --git a/tests/test_spec_elementwise_functions.py b/tests/test_spec_elementwise_functions.py index fdac4a9..af3f640 100644 --- a/tests/test_spec_elementwise_functions.py +++ b/tests/test_spec_elementwise_functions.py @@ -12,13 +12,23 @@ import awkward as ak import numpy as np -with warnings.catch_warnings(): - warnings.simplefilter("ignore") - import numpy.array_api as xp - import pytest import ragged +""" +with warnings.catch_warnings(): + warnings.simplefilter("ignore") + import array_api_strict as xp +""" +has_complex_dtype = True + +if np.lib.NumpyVersion(np.__version__) < '2.0.0b1': + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + import array_api_strict as xp + has_complex_dtype = np.dtype("complex128") in xp._dtypes._all_dtypes +else: + xp = np devices = ["cpu"] try: @@ -377,14 +387,20 @@ def test_ceil(device, x): @pytest.mark.parametrize("device", devices) def test_ceil_int(device, x_int): result = ragged.ceil(x_int.to_device(device)) + print(x_int.dtype) + print(ragged.ceil(x_int.to_device(device)).dtype) assert type(result) is type(x_int) assert result.shape == x_int.shape assert xp.ceil(first(x_int)) == first(result) - assert xp.ceil(first(x_int)).dtype == result.dtype - + print((first(x_int)).dtype) + print((xp.ceil(first(x_int))).dtype) + print(first(result).dtype) + assert np.ceil(first(x_int)).dtype == result.dtype + print((np.ceil(first(x_int))).dtype) + print(result.dtype) @pytest.mark.skipif( - np.dtype("complex128") not in xp._dtypes._all_dtypes, + not has_complex_dtype, reason=f"complex not allowed in np.array_api version {np.__version__}", ) @pytest.mark.parametrize("device", devices) @@ -571,7 +587,7 @@ def test_greater_equal_method(device, x, y): @pytest.mark.skipif( - np.dtype("complex128") not in xp._dtypes._all_dtypes, + not has_complex_dtype, reason=f"complex not allowed in np.array_api version {np.__version__}", ) @pytest.mark.parametrize("device", devices) @@ -838,7 +854,7 @@ def test_pow_inplace_method(device, x, y): @pytest.mark.skipif( - np.dtype("complex128") not in xp._dtypes._all_dtypes, + not has_complex_dtype, reason=f"complex not allowed in np.array_api version {np.__version__}", ) @pytest.mark.parametrize("device", devices) @@ -888,7 +904,7 @@ def test_round(device, x): @pytest.mark.skipif( - np.dtype("complex128") not in xp._dtypes._all_dtypes, + not has_complex_dtype, reason=f"complex not allowed in np.array_api version {np.__version__}", ) @pytest.mark.parametrize("device", devices)