Skip to content

Commit

Permalink
feature: ragged.sort and ragged.argsort (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski authored May 2, 2024
1 parent 7d96860 commit f78ea7d
Show file tree
Hide file tree
Showing 2 changed files with 96 additions and 11 deletions.
26 changes: 15 additions & 11 deletions src/ragged/_spec_sorting_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,9 @@

from __future__ import annotations

from ._spec_array_object import array
import awkward as ak

from ._spec_array_object import _box, _unbox, array


def argsort(
Expand Down Expand Up @@ -34,11 +36,12 @@ def argsort(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.argsort.html
"""

x # noqa: B018, pylint: disable=W0104
axis # noqa: B018, pylint: disable=W0104
descending # noqa: B018, pylint: disable=W0104
stable # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 132") # noqa: EM101
(impl,) = _unbox(x)
if not isinstance(impl, ak.Array):
msg = f"axis {axis} is out of bounds for array of dimension 0"
raise ak.errors.AxisError(msg)
out = ak.argsort(impl, axis=axis, ascending=not descending, stable=stable)
return _box(type(x), out)


def sort(
Expand Down Expand Up @@ -66,8 +69,9 @@ def sort(
https://data-apis.org/array-api/latest/API_specification/generated/array_api.sort.html
"""

x # noqa: B018, pylint: disable=W0104
axis # noqa: B018, pylint: disable=W0104
descending # noqa: B018, pylint: disable=W0104
stable # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 133") # noqa: EM101
(impl,) = _unbox(x)
if not isinstance(impl, ak.Array):
msg = f"axis {axis} is out of bounds for array of dimension 0"
raise ak.errors.AxisError(msg)
out = ak.sort(impl, axis=axis, ascending=not descending, stable=stable)
return _box(type(x), out)
81 changes: 81 additions & 0 deletions tests/test_spec_sorting_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,90 @@

from __future__ import annotations

import pytest

import ragged

devices = ["cpu"]
try:
import cupy as cp

# FIXME!
# devices.append("cuda")
except ModuleNotFoundError:
cp = None


def test_existence():
assert ragged.argsort is not None
assert ragged.sort is not None


@pytest.mark.parametrize("device", devices)
def test_argsort(device):
x = ragged.array(
[[1.1, 0, 2.2], [], [3.3, 4.4], [5.5], [9.9, 7.7, 8.8, 6.6]], device=device
)
assert ragged.argsort(x, axis=1, stable=True, descending=False).tolist() == [ # type: ignore[comparison-overlap]
[1, 0, 2],
[],
[0, 1],
[0],
[3, 1, 2, 0],
]
assert ragged.argsort(x, axis=1, stable=True, descending=True).tolist() == [ # type: ignore[comparison-overlap]
[2, 0, 1],
[],
[1, 0],
[0],
[0, 2, 1, 3],
]
assert ragged.argsort(x, axis=0, stable=True, descending=False).tolist() == [ # type: ignore[comparison-overlap]
[0, 0, 0],
[],
[2, 2],
[3],
[4, 4, 4, 4],
]
assert ragged.argsort(x, axis=0, stable=True, descending=True).tolist() == [ # type: ignore[comparison-overlap]
[4, 4, 4],
[],
[3, 2],
[2],
[0, 0, 0, 4],
]


@pytest.mark.parametrize("device", devices)
def test_sort(device):
x = ragged.array(
[[1.1, 0, 2.2], [], [3.3, 4.4], [5.5], [9.9, 7.7, 8.8, 6.6]], device=device
)
assert ragged.sort(x, axis=1, stable=True, descending=False).tolist() == [ # type: ignore[comparison-overlap]
[0, 1.1, 2.2],
[],
[3.3, 4.4],
[5.5],
[6.6, 7.7, 8.8, 9.9],
]
assert ragged.sort(x, axis=1, stable=True, descending=True).tolist() == [ # type: ignore[comparison-overlap]
[2.2, 1.1, 0],
[],
[4.4, 3.3],
[5.5],
[9.9, 8.8, 7.7, 6.6],
]
assert ragged.sort(x, axis=0, stable=True, descending=False).tolist() == [ # type: ignore[comparison-overlap]
[1.1, 0.0, 2.2],
[],
[3.3, 4.4],
[5.5],
[9.9, 7.7, 8.8, 6.6],
]
assert ragged.sort(x, axis=0, stable=True, descending=True).tolist() == [ # type: ignore[comparison-overlap]
[9.9, 7.7, 8.8],
[],
[5.5, 4.4],
[3.3],
[1.1, 0.0, 2.2, 6.6],
]

0 comments on commit f78ea7d

Please sign in to comment.