Skip to content

Commit

Permalink
Remove skipping of tests for i1/i2 dtypes since work-around
Browse files Browse the repository at this point in the history
was applied in C++.

Add tests for 2d input arrays, for axis=0 and axis=1

Add a test for non-contiguous input, 0d input, validation

100% coverage of top_k function implementation achieved
  • Loading branch information
oleksandr-pavlyk committed Jan 4, 2025
1 parent 7c7d8f9 commit e1b7540
Showing 1 changed file with 156 additions and 6 deletions.
162 changes: 156 additions & 6 deletions dpctl/tests/test_usm_ndarray_top_k.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def _expected_largest_inds(inp, n, shift, k):
@pytest.mark.parametrize(
"dtype",
[
pytest.param("i1", marks=pytest.mark.skip(reason="CPU bug")),
"i1",
"u1",
"i2",
"u2",
Expand All @@ -74,8 +74,6 @@ def _expected_largest_inds(inp, n, shift, k):
def test_top_k_1d_largest(dtype, n):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)
if dtype == "i1":
pytest.skip()

shift, k = 734, 5
o = dpt.ones(n, dtype=dtype)
Expand All @@ -89,9 +87,9 @@ def test_top_k_1d_largest(dtype, n):
assert s.values.shape == (k,)
assert s.values.dtype == inp.dtype
assert s.indices.shape == (k,)
assert dpt.all(s.indices == expected_inds)
assert dpt.all(s.values == dpt.ones(k, dtype=dtype)), s.values
assert dpt.all(s.values == inp[s.indices]), s.indices
assert dpt.all(s.indices == expected_inds), (s.indices, expected_inds)


def _expected_smallest_inds(inp, n, shift, k):
Expand Down Expand Up @@ -128,7 +126,7 @@ def _expected_smallest_inds(inp, n, shift, k):
@pytest.mark.parametrize(
"dtype",
[
pytest.param("i1", marks=pytest.mark.skip(reason="CPU bug")),
"i1",
"u1",
"i2",
"u2",
Expand Down Expand Up @@ -160,6 +158,158 @@ def test_top_k_1d_smallest(dtype, n):
assert s.values.shape == (k,)
assert s.values.dtype == inp.dtype
assert s.indices.shape == (k,)
assert dpt.all(s.indices == expected_inds)
assert dpt.all(s.values == dpt.zeros(k, dtype=dtype)), s.values
assert dpt.all(s.values == inp[s.indices]), s.indices
assert dpt.all(s.indices == expected_inds), (s.indices, expected_inds)


@pytest.mark.parametrize(
"dtype",
[
# skip short types to ensure that m*n can be represented
# in the type
"i4",
"u4",
"i8",
"u8",
"f2",
"f4",
"f8",
"c8",
"c16",
],
)
@pytest.mark.parametrize("n", [37, 39, 61, 255, 257, 513, 1021, 8193])
def test_top_k_2d_largest(dtype, n):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

m, k = 8, 3
if dtype == "f2" and m * n > 2000:
pytest.skip(
"f2 can not distinguish between large integers used in this test"
)

x = dpt.reshape(dpt.arange(m * n, dtype=dtype), (m, n))

r = dpt.top_k(x, k, axis=1)

assert r.values.shape == (m, k)
assert r.indices.shape == (m, k)
expected_inds = dpt.reshape(dpt.arange(n, dtype=r.indices.dtype), (1, n))[
:, -k:
]
assert expected_inds.shape == (1, k)
assert dpt.all(
dpt.sort(r.indices, axis=1) == dpt.sort(expected_inds, axis=1)
), (r.indices, expected_inds)
expected_vals = x[:, -k:]
assert dpt.all(
dpt.sort(r.values, axis=1) == dpt.sort(expected_vals, axis=1)
)


@pytest.mark.parametrize(
"dtype",
[
# skip short types to ensure that m*n can be represented
# in the type
"i4",
"u4",
"i8",
"u8",
"f2",
"f4",
"f8",
"c8",
"c16",
],
)
@pytest.mark.parametrize("n", [37, 39, 61, 255, 257, 513, 1021, 8193])
def test_top_k_2d_smallest(dtype, n):
q = get_queue_or_skip()
skip_if_dtype_not_supported(dtype, q)

m, k = 8, 3
if dtype == "f2" and m * n > 2000:
pytest.skip(
"f2 can not distinguish between large integers used in this test"
)

x = dpt.reshape(dpt.arange(m * n, dtype=dtype), (m, n))

r = dpt.top_k(x, k, axis=1, mode="smallest")

assert r.values.shape == (m, k)
assert r.indices.shape == (m, k)
expected_inds = dpt.reshape(dpt.arange(n, dtype=r.indices.dtype), (1, n))[
:, :k
]
assert dpt.all(
dpt.sort(r.indices, axis=1) == dpt.sort(expected_inds, axis=1)
)
assert dpt.all(dpt.sort(r.values, axis=1) == dpt.sort(x[:, :k], axis=1))


def test_top_k_0d():
get_queue_or_skip()

a = dpt.ones(tuple(), dtype="i4")
assert a.ndim == 0
assert a.size == 1

r = dpt.top_k(a, 1)
assert r.values == a
assert r.indices == dpt.zeros_like(a, dtype=r.indices.dtype)


def test_top_k_noncontig():
get_queue_or_skip()

a = dpt.arange(256, dtype=dpt.int32)[::2]
r = dpt.top_k(a, 3)

assert dpt.all(dpt.sort(r.values) == dpt.asarray([250, 252, 254])), r.values
assert dpt.all(
dpt.sort(r.indices) == dpt.asarray([125, 126, 127])
), r.indices


def test_top_k_axis0():
get_queue_or_skip()

m, n, k = 128, 8, 3
x = dpt.reshape(dpt.arange(m * n, dtype=dpt.int32), (m, n))

r = dpt.top_k(x, k, axis=0, mode="smallest")
assert r.values.shape == (k, n)
assert r.indices.shape == (k, n)
expected_inds = dpt.reshape(dpt.arange(m, dtype=r.indices.dtype), (m, 1))[
:k, :
]
assert dpt.all(
dpt.sort(r.indices, axis=0) == dpt.sort(expected_inds, axis=0)
)
assert dpt.all(dpt.sort(r.values, axis=0) == dpt.sort(x[:k, :], axis=0))


def test_top_k_validation():
get_queue_or_skip()
x = dpt.ones(10, dtype=dpt.int64)
with pytest.raises(ValueError):
# k must be positive
dpt.top_k(x, -1)
with pytest.raises(TypeError):
# argument should be usm_ndarray
dpt.top_k(list(), 2)
x2 = dpt.reshape(x, (2, 5))
with pytest.raises(ValueError):
# k must not exceed array dimension
# along specified axis
dpt.top_k(x2, 100, axis=1)
with pytest.raises(ValueError):
# for 0d arrays, k must be 1
dpt.top_k(x[0], 2)
with pytest.raises(ValueError):
# mode must be "largest", or "smallest"
dpt.top_k(x, 2, mode="invalid")

0 comments on commit e1b7540

Please sign in to comment.