From 9f0070564bae6cbf43a13cbf6b37e789a1d6da35 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Mon, 15 Jan 2024 15:53:25 -0600 Subject: [PATCH] take (that was the last one) --- src/ragged/_spec_indexing_functions.py | 33 ++++++++++++++++++++++---- tests/test_spec_indexing.py | 21 +++++++++++++++- 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/src/ragged/_spec_indexing_functions.py b/src/ragged/_spec_indexing_functions.py index bf54c4f..3744d56 100644 --- a/src/ragged/_spec_indexing_functions.py +++ b/src/ragged/_spec_indexing_functions.py @@ -6,7 +6,10 @@ from __future__ import annotations -from ._spec_array_object import array +import awkward as ak +import numpy as np + +from ._spec_array_object import _box, array def take(x: array, indices: array, /, *, axis: None | int = None) -> array: @@ -37,7 +40,27 @@ def take(x: array, indices: array, /, *, axis: None | int = None) -> array: https://data-apis.org/array-api/latest/API_specification/generated/array_api.take.html """ - x # noqa: B018, pylint: disable=W0104 - indices # noqa: B018, pylint: disable=W0104 - axis # noqa: B018, pylint: disable=W0104 - raise NotImplementedError("TODO 109") # noqa: EM101 + if axis is None: + if x.ndim <= 1: + axis = 0 + else: + msg = f"for an {x.ndim}-dimensional array (greater than 1-dimensional), the 'axis' argument is required" + raise TypeError(msg) + + original_axis = axis + if axis < 0: + axis += x.ndim + 1 + if not 0 <= axis < x.ndim: + msg = f"axis {original_axis} is out of bounds for array of dimension {x.ndim}" + raise ak.errors.AxisError(msg) + + toslice = x._impl # pylint: disable=W0212 + if not isinstance(toslice, ak.Array): + toslice = ak.Array(toslice[np.newaxis]) # type: ignore[index] + + if not isinstance(indices, array): + indices = array(indices) # type: ignore[unreachable] + indexarray = indices._impl # pylint: disable=W0212 + + slicer = (slice(None),) * axis + (indexarray,) + return _box(type(x), toslice[slicer]) diff --git a/tests/test_spec_indexing.py b/tests/test_spec_indexing.py index 1776afa..b11efd8 100644 --- a/tests/test_spec_indexing.py +++ b/tests/test_spec_indexing.py @@ -6,6 +6,25 @@ from __future__ import annotations +import ragged + def test(): - pass + # slices are extensively tested in Awkward Array, just check 'axis' argument + + a = ragged.array([0.0, 1.1, 2.2, 3.3, 4.4, 5.5, 6.6, 7.7, 8.8, 9.9]) + assert ragged.take(a, ragged.array([5, 3, 3, 9, 0, 1]), axis=0).tolist() == [ + 5.5, + 3.3, + 3.3, + 9.9, + 0, + 1.1, + ] + + b = ragged.array([[0.0, 1.1, 2.2], [3.3, 4.4], [5.5, 6.6, 7.7, 8.8, 9.9]]) + assert ragged.take(b, ragged.array([0, 1, 1, 0]), axis=1).tolist() == [ + [0, 1.1, 1.1, 0], + [3.3, 4.4, 4.4, 3.3], + [5.5, 6.6, 6.6, 5.5], + ]