From 2ed403242d292d9fd9f8d0457a08392f5d42f1a3 Mon Sep 17 00:00:00 2001 From: Ianna Osborne Date: Mon, 5 Aug 2024 16:32:33 +0200 Subject: [PATCH] move to local branch --- tests/test_spec_elementwise_functions.py | 20 +++++++++++++++----- 1 file changed, 15 insertions(+), 5 deletions(-) diff --git a/tests/test_spec_elementwise_functions.py b/tests/test_spec_elementwise_functions.py index fdac4a9..0a4c164 100644 --- a/tests/test_spec_elementwise_functions.py +++ b/tests/test_spec_elementwise_functions.py @@ -14,12 +14,22 @@ with warnings.catch_warnings(): warnings.simplefilter("ignore") - import numpy.array_api as xp import pytest import ragged +has_complex_dtype = True + +if np.lib.NumpyVersion(np.__version__) < "2.0.0b1": + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + import array_api_strict as xp # type: ignore[import-not-found] + + has_complex_dtype = np.dtype("complex128") in xp._dtypes._all_dtypes +else: + xp = np + devices = ["cpu"] try: import cupy as cp @@ -384,7 +394,7 @@ def test_ceil_int(device, x_int): @pytest.mark.skipif( - np.dtype("complex128") not in xp._dtypes._all_dtypes, + not has_complex_dtype, reason=f"complex not allowed in np.array_api version {np.__version__}", ) @pytest.mark.parametrize("device", devices) @@ -571,7 +581,7 @@ def test_greater_equal_method(device, x, y): @pytest.mark.skipif( - np.dtype("complex128") not in xp._dtypes._all_dtypes, + not has_complex_dtype, reason=f"complex not allowed in np.array_api version {np.__version__}", ) @pytest.mark.parametrize("device", devices) @@ -838,7 +848,7 @@ def test_pow_inplace_method(device, x, y): @pytest.mark.skipif( - np.dtype("complex128") not in xp._dtypes._all_dtypes, + not has_complex_dtype, reason=f"complex not allowed in np.array_api version {np.__version__}", ) @pytest.mark.parametrize("device", devices) @@ -888,7 +898,7 @@ def test_round(device, x): @pytest.mark.skipif( - np.dtype("complex128") not in xp._dtypes._all_dtypes, + not has_complex_dtype, reason=f"complex not allowed in np.array_api version {np.__version__}", ) @pytest.mark.parametrize("device", devices)