Skip to content

Commit

Permalink
Rudimentary tests for SupportsIndex in indexing methods
Browse files Browse the repository at this point in the history
  • Loading branch information
honno committed Mar 27, 2024
1 parent a168e5a commit ac2bb06
Showing 1 changed file with 35 additions and 0 deletions.
35 changes: 35 additions & 0 deletions array_api_tests/test_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,41 @@ def test_setitem(shape, dtypes, data):
)


class AwkwardIndexable:
def __init__(self, value: int):
self._value = value

def __int__(self):
raise TypeError("__int__() should not be called")

def __index__(self):
return self._value


@pytest.mark.parametrize(
"x, key",
[
(xp.asarray([0, 1]), AwkwardIndexable(1)),
(xp.asarray([[0, 1], [2, 3]]), (0, AwkwardIndexable(1))),
]
)
def test_getitem_supports_index(x, key):
out = x[key]
assert out == xp.asarray(1)


@pytest.mark.parametrize(
"x, key, expected",
[
(xp.asarray([0, 1]), AwkwardIndexable(1), xp.asarray([0, 42])),
(xp.asarray([[0, 1], [2, 3]]), (0, AwkwardIndexable(1)), xp.asarray([[0, 42], [2, 3]])),
]
)
def test_setitem_supports_index(x, key, expected):
x[key] = xp.asarray(42)
ph.assert_array_elements("__setitem__", out=x, expected=expected, out_repr="x")


@pytest.mark.unvectorized
@pytest.mark.data_dependent_shapes
@given(hh.shapes(), st.data())
Expand Down

0 comments on commit ac2bb06

Please sign in to comment.