Skip to content

Commit

Permalink
Add more tests for 100% coverage of top_k function
Browse files Browse the repository at this point in the history
  • Loading branch information
oleksandr-pavlyk committed Dec 31, 2024
1 parent 5f096b8 commit f78d172
Showing 1 changed file with 40 additions and 0 deletions.
40 changes: 40 additions & 0 deletions dpctl/tests/test_usm_ndarray_top_k.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,3 +276,43 @@ def test_top_k_noncontig():
assert dpt.all(
dpt.sort(r.indices) == dpt.asarray([125, 126, 127])
), r.indices


def test_top_k_axis0():
get_queue_or_skip()

m, n, k = 128, 8, 3
x = dpt.reshape(dpt.arange(m * n, dtype=dpt.int32), (m, n))

r = dpt.top_k(x, k, axis=0, mode="smallest")
assert r.values.shape == (k, n)
assert r.indices.shape == (k, n)
expected_inds = dpt.reshape(dpt.arange(m, dtype=r.indices.dtype), (m, 1))[
:k, :
]
assert dpt.all(
dpt.sort(r.indices, axis=0) == dpt.sort(expected_inds, axis=0)
)
assert dpt.all(dpt.sort(r.values, axis=0) == dpt.sort(x[:k, :], axis=0))


def test_top_k_validation():
get_queue_or_skip()
x = dpt.ones(10, dtype=dpt.int64)
with pytest.raises(ValueError):
# k must be positive
dpt.top_k(x, -1)
with pytest.raises(TypeError):
# argument should be usm_ndarray
dpt.top_k(list(), 2)
x2 = dpt.reshape(x, (2, 5))
with pytest.raises(ValueError):
# k must not exceed array dimension
# along specified axis
dpt.top_k(x2, 100, axis=1)
with pytest.raises(ValueError):
# for 0d arrays, k must be 1
dpt.top_k(x[0], 2)
with pytest.raises(ValueError):
# mode must be "largest", or "smallest"
dpt.top_k(x, 2, mode="invalid")

0 comments on commit f78d172

Please sign in to comment.