Skip to content

Commit

Permalink
move to local branch
Browse files Browse the repository at this point in the history
  • Loading branch information
ianna committed Aug 5, 2024
1 parent fff5354 commit 2ed4032
Showing 1 changed file with 15 additions and 5 deletions.
20 changes: 15 additions & 5 deletions tests/test_spec_elementwise_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,22 @@

with warnings.catch_warnings():
warnings.simplefilter("ignore")
import numpy.array_api as xp

import pytest

import ragged

has_complex_dtype = True

if np.lib.NumpyVersion(np.__version__) < "2.0.0b1":
with warnings.catch_warnings():
warnings.simplefilter("ignore")
import array_api_strict as xp # type: ignore[import-not-found]

has_complex_dtype = np.dtype("complex128") in xp._dtypes._all_dtypes
else:
xp = np

devices = ["cpu"]
try:
import cupy as cp
Expand Down Expand Up @@ -384,7 +394,7 @@ def test_ceil_int(device, x_int):


@pytest.mark.skipif(
np.dtype("complex128") not in xp._dtypes._all_dtypes,
not has_complex_dtype,
reason=f"complex not allowed in np.array_api version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
Expand Down Expand Up @@ -571,7 +581,7 @@ def test_greater_equal_method(device, x, y):


@pytest.mark.skipif(
np.dtype("complex128") not in xp._dtypes._all_dtypes,
not has_complex_dtype,
reason=f"complex not allowed in np.array_api version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
Expand Down Expand Up @@ -838,7 +848,7 @@ def test_pow_inplace_method(device, x, y):


@pytest.mark.skipif(
np.dtype("complex128") not in xp._dtypes._all_dtypes,
not has_complex_dtype,
reason=f"complex not allowed in np.array_api version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
Expand Down Expand Up @@ -888,7 +898,7 @@ def test_round(device, x):


@pytest.mark.skipif(
np.dtype("complex128") not in xp._dtypes._all_dtypes,
not has_complex_dtype,
reason=f"complex not allowed in np.array_api version {np.__version__}",
)
@pytest.mark.parametrize("device", devices)
Expand Down

0 comments on commit 2ed4032

Please sign in to comment.