Skip to content

Commit

Permalink
adding namedtuple & corresponding test fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
ohrechykha committed Aug 19, 2024
1 parent 6c814a7 commit 83dabd9
Show file tree
Hide file tree
Showing 2 changed files with 129 additions and 77 deletions.
120 changes: 75 additions & 45 deletions src/ragged/_spec_set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,22 +52,31 @@ def unique_all(x: array, /) -> tuple[array, array, array, array]:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_all.html
"""

if not isinstance(x, ragged.array):
raise TypeError(f"Expected ragged type but got {type(x)}")

if len(x) == 1:
return ragged.array(x), ragged.array([0]), ragged.array([0]), ragged.array([1])

x_flat = ak.ravel(x._impl)
values, indices, inverse_indices, counts = np.unique(
x_flat.layout.data, return_index=True, return_inverse=True, return_counts=True
)
return (
ragged.array(values),
ragged.array(indices),
ragged.array(inverse_indices),
ragged.array(counts),
)
if isinstance(x, ragged.array):
if len(x) == 1:
return unique_all_result(
values=ragged.array(x),
indices=ragged.array([0]),
inverse_indices=ragged.array([0]),
counts=ragged.array([1]),
)
else:
x_flat = ak.ravel(x._impl)
values, indices, inverse_indices, counts = np.unique(
x_flat.layout.data,
return_index=True,
return_inverse=True,
return_counts=True,
)
return unique_all_result(
values=ragged.array(values),
indices=ragged.array(indices),
inverse_indices=ragged.array(inverse_indices),
counts=ragged.array(counts),
)
else:
msg = f"Expected ragged type but got {type(x)}"
raise TypeError(msg)


unique_counts_result = namedtuple( # pylint: disable=C0103
Expand Down Expand Up @@ -96,15 +105,24 @@ def unique_counts(x: array, /) -> tuple[array, array]:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_counts.html
"""
if not isinstance(x, ragged.array):
raise TypeError(f"Expected ragged type but got {type(x)}")

if len(x) == 1:
return ragged.array(x), ragged.array([1])

x_flat = ak.ravel(x._impl)
values, counts = np.unique(x_flat.layout.data, return_counts=True)
return ragged.array(values), ragged.array(counts)
if isinstance(x, ragged.array):
if x.ndim == 0:
return unique_counts_result(
values=ragged.array([x]), counts=ragged.array([1])
)
elif len(x) == 1:
return unique_counts_result(
values=ragged.array(x), counts=ragged.array([1])
)
else:
x_flat = ak.ravel(x._impl)
values, counts = np.unique(x_flat.layout.data, return_counts=True)
return unique_counts_result(
values=ragged.array(values), counts=ragged.array(counts)
)
else:
msg = f"Expected ragged type but got {type(x)}"
raise TypeError(msg)


unique_inverse_result = namedtuple( # pylint: disable=C0103
Expand Down Expand Up @@ -133,16 +151,26 @@ def unique_inverse(x: array, /) -> tuple[array, array]:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_inverse.html
"""
if not isinstance(x, ragged.array):
raise TypeError(f"Expected ragged type but got {type(x)}")

if len(x) == 1:
return ragged.array(x), ragged.array([0])

x_flat = ak.ravel(x._impl)
values, inverse_indices = np.unique(x_flat.layout.data, return_inverse=True)

return ragged.array(values), ragged.array(inverse_indices)
if isinstance(x, ragged.array):
if x.ndim == 0:
return unique_inverse_result(
values=ragged.array([x]), inverse_indices=ragged.array([0])
)
elif len(x) == 1:
return unique_inverse_result(
values=ragged.array(x), inverse_indices=ragged.array([0])
)
else:
x_flat = ak.ravel(x._impl)
values, inverse_indices = np.unique(x_flat.layout.data, return_inverse=True)

return unique_inverse_result(
values=ragged.array(values),
inverse_indices=ragged.array(inverse_indices),
)
else:
msg = f"Expected ragged type but got {type(x)}"
raise TypeError(msg)


def unique_values(x: array, /) -> array:
Expand All @@ -160,13 +188,15 @@ def unique_values(x: array, /) -> array:
https://data-apis.org/array-api/latest/API_specification/generated/array_api.unique_values.html
"""
if not isinstance(x, ragged.array):
raise TypeError(f"Expected ragged type but got {type(x)}")

if len(x) == 1:
return ragged.array(x)

x_flat = ak.ravel(x._impl)
values = np.unique(x_flat.layout.data)

return ragged.array(values)
if isinstance(x, ragged.array):
if x.ndim == 0:
return ragged.array([x])

if len(x) == 1:
return ragged.array(x)
else:
x_flat = ak.ravel(x._impl)
return ragged.array(np.unique(x_flat.layout.data))
else:
err = f"Expected ragged type but got {type(x)}"
raise TypeError(err)
86 changes: 54 additions & 32 deletions tests/test_spec_set_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,12 @@

from __future__ import annotations

import re

import awkward as ak
import pytest

import ragged
import re
# Specific algorithm for unique_values:
# 1. Take an input array.
# 2. Flatten the input array unless it is already 1-D.
# 3. Remember the first element, then loop through the rest of the list to see if there are copies:
#    if yes, discard the copy and repeat the step;
#    if not, add the element to the output and repeat the step.
# 4. Once the loop is over, return an array of the unique elements in the input array (the output must be of the same type as the input array).


def test_existence():
Expand All @@ -28,23 +22,35 @@ def test_existence():


# unique_values tests
#def test_can_take_none():
# with pytest.raises(TypeError, match=f"Expected ragged type but got {type(None)}"):
# assert ragged.unique_values(None) is None
def test_can_take_none():
    """Expect a TypeError when ``unique_values`` is fed ``ragged.array(None)``.

    The array is constructed inside the ``pytest.raises`` block, so the test
    passes whether the TypeError comes from building the array or from
    ``unique_values`` itself.
    """
    with pytest.raises(TypeError):
        none_input = ragged.array(None)
        assert ragged.unique_values(none_input) is None


def test_can_take_list():
with pytest.raises(TypeError, match=f"Expected ragged type but got <class 'list'>"):
assert ragged.unique_values([1, 2, 4, 3, 4, 5, 6, 20])


#def test_can_take_empty_arr():
# with pytest.raises(TypeError):
# assert ragged.unique_values(ragged.array([]))
with pytest.raises(
ValueError,
match=re.escape(
"the truth value of an array whose length is not 1 is ambiguous;"
),
):
assert ragged.unique_values(
ragged.array([1, 2, 4, 3, 4, 5, 6, 20])
) == ragged.array([1, 2, 3, 4, 5, 6, 20])


def test_can_take_empty_arr():
    """Expect a TypeError somewhere in the empty-array round trip.

    Everything runs inside the ``pytest.raises`` block: computing the unique
    values of an empty ragged array, building the expected empty array, and
    asserting on their ``==`` comparison. A TypeError from any of those steps
    satisfies the test.
    """
    with pytest.raises(TypeError):
        uniques = ragged.unique_values(ragged.array([]))
        expected = ragged.array([])
        assert uniques == expected


def test_can_take_moredimensions():
with pytest.raises(ValueError,match=re.escape("the truth value of an array whose length is not 1 is ambiguous; use ak.any() or ak.all()")):
with pytest.raises(
ValueError,
match=re.escape(
"the truth value of an array whose length is not 1 is ambiguous;"
),
):
assert ragged.unique_values(ragged.array([[1, 2, 3, 4], [5, 6]]))


Expand All @@ -55,14 +61,16 @@ def test_can_take_1d_array():


# unique_counts tests
#def test_can_count_none():
# with pytest.raises(TypeError):
# assert ragged.unique_counts(None) is None
def test_can_count_none():
    """Expect a TypeError when ``unique_counts`` is fed ``ragged.array(None)``.

    The array is constructed inside the ``pytest.raises`` block, so the test
    passes whether the TypeError comes from building the array or from
    ``unique_counts`` itself.
    """
    with pytest.raises(TypeError):
        none_input = ragged.array(None)
        assert ragged.unique_counts(none_input) is None


def test_can_count_list():
with pytest.raises(TypeError):
assert ragged.unique_counts([1, 2, 4, 3, 4, 5, 6, 20]) is None
assert ragged.unique_counts(
ragged.array([1, 2, 4, 3, 4, 5, 6, 20])
) == ragged.array([1, 2, 3, 4, 5, 6, 20]), ragged.array([1, 1, 2, 1, 1, 1, 1])


def test_can_count_simple_array():
Expand Down Expand Up @@ -93,14 +101,18 @@ def test_can_count_scalar():


# unique_inverse tests
#def test_can_inverse_none():
# with pytest.raises(TypeError):
# assert ragged.unique_inverse(None) is None
def test_can_inverse_none():
    """Expect a TypeError when ``unique_inverse`` is fed ``ragged.array(None)``.

    The array is constructed inside the ``pytest.raises`` block, so the test
    passes whether the TypeError comes from building the array or from
    ``unique_inverse`` itself.
    """
    with pytest.raises(TypeError):
        none_input = ragged.array(None)
        assert ragged.unique_inverse(none_input) is None


def test_can_inverse_list():
with pytest.raises(TypeError):
assert ragged.unique_inverse([1, 2, 4, 3, 4, 5, 6, 20]) is None
assert ragged.unique_inverse(
ragged.array([1, 2, 4, 3, 4, 5, 6, 20])
) == ragged.array([1, 2, 3, 4, 5, 6, 20]), ragged.array(
[0, 1, 3, 2, 3, 4, 5, 6]
)


def test_can_take_simple_array():
Expand Down Expand Up @@ -131,14 +143,24 @@ def test_can_take_scalar():


# unique_all tests
#def test_can_all_none():
# with pytest.raises(TypeError):
# assert ragged.unique_all(None) is None
def test_can_all_none():
    """Expect a TypeError when ``unique_all`` is fed ``ragged.array(None)``.

    The array is constructed inside the ``pytest.raises`` block, so the test
    passes whether the TypeError comes from building the array or from
    ``unique_all`` itself.
    """
    with pytest.raises(TypeError):
        none_input = ragged.array(None)
        assert ragged.unique_all(none_input) is None


def test_can_all_list():
with pytest.raises(TypeError):
assert ragged.unique_all([1, 2, 4, 3, 4, 5, 6, 20]) is None
with pytest.raises(
ValueError,
match=re.escape(
"the truth value of an array whose length is not 1 is ambiguous;"
),
):
assert ragged.unique_all(ragged.array([1, 2, 4, 3, 4, 5, 6, 20])) == (
ragged.array([1, 2, 3, 4, 5, 6, 20]),
ragged.array([0, 1, 3, 2, 5, 6, 7]),
ragged.array([0, 1, 3, 2, 3, 4, 5, 6]),
ragged.array([1, 1, 1, 2, 1, 1, 1]),
)


def test_can_all_simple_array():
Expand Down

0 comments on commit 83dabd9

Please sign in to comment.