From f78ea7d46ea4c4c574c9d8ba76fc1282e6229f7d Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Thu, 2 May 2024 16:46:20 -0500 Subject: [PATCH] feature: ragged.sort and ragged.argsort (#48) --- src/ragged/_spec_sorting_functions.py | 26 +++++---- tests/test_spec_sorting_functions.py | 81 +++++++++++++++++++++++++++ 2 files changed, 96 insertions(+), 11 deletions(-) diff --git a/src/ragged/_spec_sorting_functions.py b/src/ragged/_spec_sorting_functions.py index d75e3ff..fa4f7a5 100644 --- a/src/ragged/_spec_sorting_functions.py +++ b/src/ragged/_spec_sorting_functions.py @@ -6,7 +6,9 @@ from __future__ import annotations -from ._spec_array_object import array +import awkward as ak + +from ._spec_array_object import _box, _unbox, array def argsort( @@ -34,11 +36,12 @@ def argsort( https://data-apis.org/array-api/latest/API_specification/generated/array_api.argsort.html """ - x # noqa: B018, pylint: disable=W0104 - axis # noqa: B018, pylint: disable=W0104 - descending # noqa: B018, pylint: disable=W0104 - stable # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 132") # noqa: EM101 + (impl,) = _unbox(x) + if not isinstance(impl, ak.Array): + msg = f"axis {axis} is out of bounds for array of dimension 0" + raise ak.errors.AxisError(msg) + out = ak.argsort(impl, axis=axis, ascending=not descending, stable=stable) + return _box(type(x), out) def sort( @@ -66,8 +69,9 @@ def sort( https://data-apis.org/array-api/latest/API_specification/generated/array_api.sort.html """ - x # noqa: B018, pylint: disable=W0104 - axis # noqa: B018, pylint: disable=W0104 - descending # noqa: B018, pylint: disable=W0104 - stable # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 133") # noqa: EM101 + (impl,) = _unbox(x) + if not isinstance(impl, ak.Array): + msg = f"axis {axis} is out of bounds for array of dimension 0" + raise ak.errors.AxisError(msg) + out = ak.sort(impl, axis=axis, ascending=not descending, stable=stable) + return _box(type(x), out) diff --git a/tests/test_spec_sorting_functions.py b/tests/test_spec_sorting_functions.py index ebf82f7..353af11 100644 --- a/tests/test_spec_sorting_functions.py +++ b/tests/test_spec_sorting_functions.py @@ -6,9 +6,90 @@ from __future__ import annotations +import pytest + import ragged +devices = ["cpu"] +try: + import cupy as cp + + # FIXME! + # devices.append("cuda") +except ModuleNotFoundError: + cp = None + def test_existence(): assert ragged.argsort is not None assert ragged.sort is not None + + +@pytest.mark.parametrize("device", devices) +def test_argsort(device): + x = ragged.array( + [[1.1, 0, 2.2], [], [3.3, 4.4], [5.5], [9.9, 7.7, 8.8, 6.6]], device=device + ) + assert ragged.argsort(x, axis=1, stable=True, descending=False).tolist() == [ # type: ignore[comparison-overlap] + [1, 0, 2], + [], + [0, 1], + [0], + [3, 1, 2, 0], + ] + assert ragged.argsort(x, axis=1, stable=True, descending=True).tolist() == [ # type: ignore[comparison-overlap] + [2, 0, 1], + [], + [1, 0], + [0], + [0, 2, 1, 3], + ] + assert ragged.argsort(x, axis=0, stable=True, descending=False).tolist() == [ # type: ignore[comparison-overlap] + [0, 0, 0], + [], + [2, 2], + [3], + [4, 4, 4, 4], + ] + assert ragged.argsort(x, axis=0, stable=True, descending=True).tolist() == [ # type: ignore[comparison-overlap] + [4, 4, 4], + [], + [3, 2], + [2], + [0, 0, 0, 4], + ] + + +@pytest.mark.parametrize("device", devices) +def test_sort(device): + x = ragged.array( + [[1.1, 0, 2.2], [], [3.3, 4.4], [5.5], [9.9, 7.7, 8.8, 6.6]], device=device + ) + assert ragged.sort(x, axis=1, stable=True, descending=False).tolist() == [ # type: ignore[comparison-overlap] + [0, 1.1, 2.2], + [], + [3.3, 4.4], + [5.5], + [6.6, 7.7, 8.8, 9.9], + ] + assert ragged.sort(x, axis=1, stable=True, descending=True).tolist() == [ # type: ignore[comparison-overlap] + [2.2, 1.1, 0], + [], + [4.4, 3.3], + [5.5], + [9.9, 8.8, 7.7, 6.6], + ] + assert ragged.sort(x, axis=0, stable=True, descending=False).tolist() == [ # type: ignore[comparison-overlap] + [1.1, 0.0, 2.2], + [], + [3.3, 4.4], + [5.5], + [9.9, 7.7, 8.8, 6.6], + ] + assert ragged.sort(x, axis=0, stable=True, descending=True).tolist() == [ # type: ignore[comparison-overlap] + [9.9, 7.7, 8.8], + [], + [5.5, 4.4], + [3.3], + [1.1, 0.0, 2.2, 6.6], + ]