Skip to content

Commit

Permalink
feat: set functions and tests (#57)
Browse files Browse the repository at this point in the history
* adding set functions and tests

* pushing pre-commit changes

* ruff fixes for test_spec_set_functions.py

* further fixes in test_spec_set_functions.py

* fixing mypy unreachable errors in _spec_set_functions.py

* marking tests with None and empty arrays as comments

* adding namedtuple & corresponding test fixes

* function if changes + test standardization

* correcting function ifs + test standardization

* further test standardization

* scalar handling and testing

* _array_object changes, empty array handling + tests

* implementing Jim's suggestion, disabling CI errors

* disabling errors and warnings

* further ignores + adding equal_nan in np.unique instances

* better ignores

* improving ignores

* avoiding code duplication in _spec_array_object

Co-authored-by: Jim Pivarski <[email protected]>

* returning np.empty and input dtype in all functions

---------

Co-authored-by: Jim Pivarski <[email protected]>
Co-authored-by: Jim Pivarski <[email protected]>
  • Loading branch information
3 people authored Sep 13, 2024
1 parent 72230bf commit 0b31f72
Show file tree
Hide file tree
Showing 4 changed files with 334 additions and 12 deletions.
2 changes: 1 addition & 1 deletion src/ragged/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@
nonzero,
where,
)
from ._spec_set_functions import (
from ._spec_set_functions import ( # pylint: disable=R0401
unique_all,
unique_counts,
unique_inverse,
Expand Down
3 changes: 3 additions & 0 deletions src/ragged/_spec_array_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import numpy as np
from awkward.contents import (
Content,
EmptyArray,
ListArray,
ListOffsetArray,
NumpyArray,
Expand Down Expand Up @@ -44,6 +45,8 @@ def _shape_dtype(layout: Content) -> tuple[Shape, Dtype]:
else:
shape = (*shape, None)
node = node.content
if isinstance(node, EmptyArray):
node = node.to_NumpyArray(dtype=np.float64)

if isinstance(node, NumpyArray):
shape = shape + node.data.shape[1:]
Expand Down
111 changes: 100 additions & 11 deletions src/ragged/_spec_set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@

from collections import namedtuple

import awkward as ak
import numpy as np

import ragged

from ._spec_array_object import array

unique_all_result = namedtuple( # pylint: disable=C0103
Expand Down Expand Up @@ -47,8 +52,39 @@ def unique_all(x: array, /) -> tuple[array, array, array, array]:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_all.html
"""

x # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 128") # noqa: EM101
if isinstance(x, ragged.array):
if x.ndim == 0:
return unique_all_result(
values=ragged.array(np.unique(x._impl, equal_nan=False)), # pylint: disable=W0212
indices=ragged.array([0]),
inverse_indices=ragged.array([0]),
counts=ragged.array([1]),
)
else:
x_flat = ak.ravel(x._impl) # pylint: disable=W0212
if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101
return unique_all_result(
values=ragged.array(np.empty(0, x.dtype)),
indices=ragged.array(np.empty(0, np.int64)),
inverse_indices=ragged.array(np.empty(0, np.int64)),
counts=ragged.array(np.empty(0, np.int64)),
)
values, indices, inverse_indices, counts = np.unique(
x_flat.layout.data, # pylint: disable=E1101
return_index=True,
return_inverse=True,
return_counts=True,
equal_nan=False,
)
return unique_all_result(
values=ragged.array(values),
indices=ragged.array(indices),
inverse_indices=ragged.array(inverse_indices),
counts=ragged.array(counts),
)
else:
msg = f"Expected ragged type but got {type(x)}" # type: ignore[unreachable]
raise TypeError(msg)


unique_counts_result = namedtuple( # pylint: disable=C0103
Expand Down Expand Up @@ -77,9 +113,30 @@ def unique_counts(x: array, /) -> tuple[array, array]:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_counts.html
"""

x # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 129") # noqa: EM101
if isinstance(x, ragged.array):
if x.ndim == 0:
return unique_counts_result(
values=ragged.array(np.unique(x._impl, equal_nan=False)), # pylint: disable=W0212
counts=ragged.array([1]), # pylint: disable=W0212
)
else:
x_flat = ak.ravel(x._impl) # pylint: disable=W0212
if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101
return unique_counts_result(
values=ragged.array(np.empty(0, x.dtype)),
counts=ragged.array(np.empty(0, np.int64)),
)
values, counts = np.unique(
x_flat.layout.data, # pylint: disable=E1101
return_counts=True,
equal_nan=False,
)
return unique_counts_result(
values=ragged.array(values), counts=ragged.array(counts)
)
else:
msg = f"Expected ragged type but got {type(x)}" # type: ignore[unreachable]
raise TypeError(msg)


unique_inverse_result = namedtuple( # pylint: disable=C0103
Expand Down Expand Up @@ -108,9 +165,32 @@ def unique_inverse(x: array, /) -> tuple[array, array]:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_inverse.html
"""

x # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 130") # noqa: EM101
if isinstance(x, ragged.array):
if x.ndim == 0:
return unique_inverse_result(
values=ragged.array(np.unique(x._impl, equal_nan=False)), # pylint: disable=W0212
inverse_indices=ragged.array([0]),
)
else:
x_flat = ak.ravel(x._impl) # pylint: disable=W0212
if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101
return unique_inverse_result(
values=ragged.array(np.empty(0, x.dtype)),
inverse_indices=ragged.array(np.empty(0, np.int64)),
)
values, inverse_indices = np.unique(
x_flat.layout.data, # pylint: disable=E1101
return_inverse=True,
equal_nan=False,
)

return unique_inverse_result(
values=ragged.array(values),
inverse_indices=ragged.array(inverse_indices),
)
else:
msg = f"Expected ragged type but got {type(x)}" # type: ignore[unreachable]
raise TypeError(msg)


def unique_values(x: array, /) -> array:
Expand All @@ -128,6 +208,15 @@ def unique_values(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_values.html
"""

x # noqa: B018, pylint: disable=W0104
raise NotImplementedError("TODO 131") # noqa: EM101
if isinstance(x, ragged.array):
if x.ndim == 0:
return ragged.array(np.unique(x._impl, equal_nan=False)) # pylint: disable=W0212

else:
x_flat = ak.ravel(x._impl) # pylint: disable=W0212
if isinstance(x_flat.layout, ak.contents.EmptyArray): # pylint: disable=E1101
return ragged.array(np.empty(0, x.dtype))
return ragged.array(np.unique(x_flat.layout.data, equal_nan=False)) # pylint: disable=E1101
else:
err = f"Expected ragged type but got {type(x)}" # type: ignore[unreachable]
raise TypeError(err)
Loading

0 comments on commit 0b31f72

Please sign in to comment.