Skip to content

Commit

Permalink
take (that was the last one)
Browse files Browse the repository at this point in the history
  • Loading branch information
jpivarski committed Jan 15, 2024
1 parent 66452d3 commit 9f00705
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 6 deletions.
33 changes: 28 additions & 5 deletions src/ragged/_spec_indexing_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@

from __future__ import annotations

from ._spec_array_object import array
import awkward as ak
import numpy as np

from ._spec_array_object import _box, array


def take(x: array, indices: array, /, *, axis: None | int = None) -> array:
Expand Down Expand Up @@ -37,7 +40,27 @@ def take(x: array, indices: array, /, *, axis: None | int = None) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.take.html
"""

x # noqa: B018, pylint: disable=W0104
indices # noqa: B018, pylint: disable=W0104
axis # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 109") # noqa: EM101
if axis is None:
if x.ndim <= 1:
axis = 0
else:
msg = f"for an {x.ndim}-dimensional array (greater than 1-dimensional), the 'axis' argument is required"
raise TypeError(msg)

original_axis = axis
if axis < 0:
axis += x.ndim + 1
if not 0 <= axis < x.ndim:
msg = f"axis {original_axis} is out of bounds for array of dimension {x.ndim}"
raise ak.errors.AxisError(msg)

toslice = x._impl # pylint: disable=W0212
if not isinstance(toslice, ak.Array):
toslice = ak.Array(toslice[np.newaxis]) # type: ignore[index]

if not isinstance(indices, array):
indices = array(indices) # type: ignore[unreachable]
indexarray = indices._impl # pylint: disable=W0212

slicer = (slice(None),) * axis + (indexarray,)
return _box(type(x), toslice[slicer])
21 changes: 20 additions & 1 deletion tests/test_spec_indexing.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,25 @@

from __future__ import annotations

import ragged


def test():
pass
# slices are extensively tested in Awkward Array, just check 'axis' argument

a = ragged.array([0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9])
assert ragged.take(a, ragged.array([5, 3, 3, 9, 0, 1]), axis=0).tolist() == [
5.5,
3.3,
3.3,
9.9,
0,
1.1,
]

b = ragged.array([[0.0, 1.1, 2.2], [3.3, 4.4], [5.5, 6.6, 7.7, 8.8, 9.9]])
assert ragged.take(b, ragged.array([0, 1, 1, 0]), axis=1).tolist() == [
[0, 1.1, 1.1, 0],
[3.3, 4.4, 4.4, 3.3],
[5.5, 6.6, 6.6, 5.5],
]

0 comments on commit 9f00705

Please sign in to comment.