Skip to content

Commit

Permalink
fix: unexpected behaviour of ak.where with arrays containing Nones (#3168)
Browse files Browse the repository at this point in the history

Adding functionality to ak.where (mostly in _broadcasting.py) to support
operating on arrays with optional values.

* Adding unit tests for 3098

* My spotless (haha) code got ruffed up a little.

* First attempted fix, partial success

This still has failures and multiple TODOs in the code. But one group of unit tests passes!

* More fixes, more limited success

* ak_where Fixes for unknown arrays, more unit tests, new failure.

* Fixing string arrays, removing broken parameter support

* Ruffing, removing print statements

* Removing print statements from ak_where.py

---------

Co-authored-by: Ianna Osborne <[email protected]>
  • Loading branch information
tcawlfield and ianna authored Jul 19, 2024
1 parent edfea15 commit a9e48eb
Show file tree
Hide file tree
Showing 3 changed files with 319 additions and 2 deletions.
142 changes: 141 additions & 1 deletion src/awkward/_broadcasting.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from awkward.contents.regulararray import RegularArray
from awkward.contents.unionarray import UnionArray
from awkward.contents.unmaskedarray import UnmaskedArray
from awkward.forms import ByteMaskedForm
from awkward.index import ( # IndexU8, ; Index32, ; IndexU32, ; noqa: F401
Index8,
Index64,
Expand Down Expand Up @@ -756,6 +757,142 @@ def broadcast_any_option():
for x, p in zip(outcontent, parameters)
)

def broadcast_any_option_akwhere():
    """
    Broadcast the three ak.where inputs when any of them is option-type.

    ak.where is a bit like the ternary operator. Due to asymmetries in the
    three inputs (their roles are distinct), special handling is required
    for option-types. Every input is split into plain values plus a byte
    mask, using the convention that an element is masked when mask == 1.
    The result is then assembled in four steps:

    1. Apply the ak.where action to the unmasked (x, y, condition).
    2. Apply the same action to (x-mask, y-mask, unmasked condition), so the
       mask that comes out follows the same x-or-y selection as step 1.
    3. OR the mask from step 2 with the condition's own mask: an element is
       None when the condition is None or the selected element is None.
    4. Wrap the values from step 1 in ByteMaskedArrays built from the mask
       of step 3.

    Returns the (un-indexed) result of the final apply_step call.
    """

    def nothing_masked(length):
        # All-zero byte mask: with mask == 1 meaning "masked", no element is.
        return NumpyArray(backend.nplike.zeros(length, dtype=np.int8))

    def split_values_and_mask(xyc):
        # Split one input into (values without option-ness, NumpyArray byte mask).
        if not isinstance(xyc, Content):
            # Non-Content inputs are passed through untouched; their mask
            # takes its length from the condition (inputs[2]).
            return xyc, nothing_masked(len(inputs[2]))
        if not xyc.is_option:
            return xyc, nothing_masked(xyc.length)
        if xyc.is_indexed and xyc.content.is_unknown:
            # Unknown arrays cannot use to_ByteMaskedArray. Create a stand-in
            # array of similar shape and any dtype (bool here) and mark every
            # element as masked.
            stand_in = NumpyArray(backend.nplike.zeros(xyc.length, dtype=np.bool_))
            all_masked = NumpyArray(backend.nplike.ones(xyc.length, dtype=np.int8))
            return stand_in, all_masked
        if (
            xyc.is_indexed
            or not isinstance(xyc.form, ByteMaskedForm)
            or xyc.form.valid_when
        ):
            # Indexed arrays have no array elements where None, which is a
            # problem for us: we don't care what a masked element's value is,
            # just that there *is* a value. Other option layouts must have
            # their mask converted to our convention. Both are handled by the
            # same conversion.
            xyc = xyc.to_ByteMaskedArray(valid_when=False)
        return xyc.content, NumpyArray(xyc.mask.data)

    unmasked = []  # Contents of inputs-as-ByteMaskedArrays or non-Content-type
    masks = []  # One NumpyArray byte mask per input, wrapped so they can be broadcast
    # Byte masks (not bits) so we can pass them as (x, y) to ak_where's action()
    for xyc in inputs:  # from ak_where, inputs are (x, y, condition)
        values, byte_mask = split_values_and_mask(xyc)
        unmasked.append(values)
        masks.append(byte_mask)

    # (1) Apply ak_where action to unmasked inputs
    outcontent = apply_step(
        backend,
        unmasked,
        action,
        depth,
        copy.copy(depth_context),
        lateral_context,
        options,
    )
    assert isinstance(outcontent, tuple) and len(outcontent) == 1
    xy_unmasked = outcontent[0]

    # (2) Now apply ak_where action to unmasked condition and mask arrays for x and y
    which_mask = (
        masks[0],  # Now x is the x-mask
        masks[1],  # y-mask
        unmasked[2],  # But same condition as previous
    )
    outmasks = apply_step(
        backend,
        which_mask,
        action,
        depth,
        copy.copy(depth_context),
        lateral_context,
        options,
    )
    assert len(outmasks) == 1
    xy_mask = outmasks[0]

    simple_options = BroadcastOptions(
        allow_records=True,
        left_broadcast=True,
        right_broadcast=True,
        numpy_to_regular=True,
        regular_to_jagged=False,
        function_name=None,
        broadcast_parameters_rule=BroadcastParameterRule.INTERSECT,
    )

    # (3) Since condition may be tree-like, use apply_step to OR condition and result masks
    def action_logical_or(inputs, backend, **kwargs):
        # Return None when condition is None or selected element is None
        m1, m2 = inputs
        if all(isinstance(x, NumpyArray) for x in inputs):
            out = NumpyArray(backend.nplike.logical_or(m1.data, m2.data))
            return (out,)
        # Not both NumpyArrays: return None (no result at this level).
        return None

    cond_mask = masks[2]
    mask = apply_step(
        backend,
        (xy_mask, cond_mask),
        action_logical_or,
        0,
        None,
        lateral_context,
        simple_options,
    )[0]

    # (4) Apply mask to unmasked selection results, recursively
    def apply_mask_action(inputs, backend, **kwargs):
        if all(
            x.is_leaf or (x.branch_depth == (False, 1) and is_string_like(x))
            for x in inputs
        ):
            content, mask = inputs
            # The mask bytes are taken from mask.content when the mask node
            # has a content attribute, otherwise from the mask node itself.
            if hasattr(mask, "content"):
                mask_as_idx = Index8(mask.content.data)
            else:
                mask_as_idx = Index8(mask.data)
            out = ByteMaskedArray(
                mask_as_idx,
                content,
                valid_when=False,
            )
            return (out,)
        # Some input is neither a leaf nor string-like: return None.
        return None

    masked = apply_step(
        backend,
        (xy_unmasked, mask),
        apply_mask_action,
        0,
        None,
        lateral_context,
        simple_options,
    )
    return masked

def broadcast_any_union():
nextparameters = []

Expand Down Expand Up @@ -908,7 +1045,10 @@ def continuation():

# Any option-types?
elif any(x.is_option for x in contents):
return broadcast_any_option()
if options["function_name"] == "ak.where":
return broadcast_any_option_akwhere()
else:
return broadcast_any_option()

# Any non-string list-types?
elif any(x.is_list and not is_string_like(x) for x in contents):
Expand Down
4 changes: 3 additions & 1 deletion src/awkward/operations/ak_where.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ def action(inputs, backend, **kwargs):
else:
return None

out = ak._broadcasting.broadcast_and_apply(layouts, action, numpy_to_regular=True)
out = ak._broadcasting.broadcast_and_apply(
layouts, action, numpy_to_regular=True, function_name="ak.where"
)

return ctx.wrap(out[0], highlevel=highlevel)
175 changes: 175 additions & 0 deletions tests/test_3098_ak_where_with_arrays_containing_optionals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,175 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import numpy as np

import awkward as ak
from awkward.operations import to_list


def test_ak_where_with_optional_unknowns():
    """
    Regression test built from the example in GitHub issue 3098
    (names changed a little from the issue).

    When these tests were written, ak.where returned [None] instead of [1]
    whenever the condition was option-type and the value that was *not*
    selected had type ?unknown (value None).
    """
    opt_true_cond = ak.Array([[True], [None]])[0]  # <Array [True] type='1 * ?bool'>
    true_cond = ak.Array([True])  # <Array [True] type='1 * bool'>
    none_alternative = ak.Array([None])  # <Array [None] type='1 * ?unknown'>
    zero_alternative = ak.Array([0])  # <Array [0] type='1 * int64'>
    opt_zero_alternative = ak.Array([[0], [None]])[0]  # <Array [0] type='1 * ?int64'>

    # none_alternative comes first: it is the case that used to yield [None]
    # with an option-type condition.
    alternatives = (none_alternative, zero_alternative, opt_zero_alternative)

    # Option-type condition first, then the plain-bool condition, which
    # already passed when the test was written.
    for condition in (opt_true_cond, true_cond):
        for alternative in alternatives:
            assert ak.where(condition, 1, alternative).to_list() == [1]

    # Mirror image: the unselected value is now in the X slot, demonstrating
    # that the problem was symmetric with respect to the X and Y arrays.
    for alternative in alternatives:
        assert ak.where(~opt_true_cond, alternative, 1).to_list() == [1]


def test_ak_where_with_optionals():
    """
    ?unknown arrays are not needed to trigger the issue: a None (masked
    value) in an element that is selected *against* is enough.

    At the time these tests were written, ak.where() produced None when:
      1. the condition array had OptionType (at least ?bool or ?int64), *AND*
      2. the value-array element *NOT* selected had option type and held None.
    In that case the result for that element was, incorrectly, None,
    regardless of the type or value of the element that *was* selected.
    """
    condition_values = [True, False, None]
    expected = [1, 5, None]  # A condition of None creates a None in the result.

    xy_cases = [
        # No option types in X or Y; passed before the fix.
        ([1, 2, 3], [4, 5, 6]),
        # Trailing Nones force option types; also passed before the fix.
        ([1, 2, None], [4, 5, None]),
        # None in the unselected Y element; used to give [None, 5, None].
        ([1, 2, None], [None, 5, None]),
        # None in the unselected X element; used to give [1, None, None].
        ([1, None, None], [4, 5, None]),
        # Same, but Y is not even option-type; used to give [1, None, None].
        ([1, None, None], [4, 5, 6]),
    ]

    for x_values, y_values in xy_cases:
        result = ak.where(
            ak.Array(condition_values),
            ak.Array(x_values),
            ak.Array(y_values),
        )
        assert to_list(result) == expected


def test_ak_where_with_optionals_multidim():
    """Check ak.where on nested lists, with and without option types."""
    # Each case is (condition, x, y, expected result).
    cases = [
        # No option types at all: this needs to continue to work.
        (
            [True, False],
            [[1, 2], [3, 4]],
            [[10, 11], [12, 13]],
            [[1, 2], [12, 13]],
        ),
        # Option types only in X, not condition.
        (
            [[True, True], [False, False]],
            [[1, 2], None],
            [[10, 11], [12, 13]],
            [[1, 2], [12, 13]],
        ),
        # Option types in condition and X, only one level of depth.
        # Note: [[1, 2], [13, 14], [None, 16], None] might seem more natural,
        # but broadcasting expands these arrays out.
        (
            [[True, True], [False, False], [True, False], None],
            [[1, 2], None, None, [7, 8]],
            [[11, 12], [13, 14], [15, 16], [17, 18]],
            [[1, 2], [13, 14], [None, 16], [None, None]],
        ),
        # Nested option-type condition with flat X and option-type Y.
        (
            [[True, False], [True, None]],
            [1, 2],
            [None, 12],
            [[1, None], [2, None]],
        ),
    ]

    for condition, x_values, y_values, expected in cases:
        result = ak.where(
            ak.Array(condition),
            ak.Array(x_values),
            ak.Array(y_values),
        )
        assert to_list(result) == expected


def test_ak_where_more_option_types():
    """Check ak.where with strings, bit-masked, unmasked and union inputs."""
    # Option-type strings in X, option-type condition.
    string_result = ak.where(
        ak.Array([False, True, None]),
        ak.Array(["this", None, "that"]),
        ak.Array(["foo", "bar", "baz"]),
    )
    assert to_list(string_result) == ["foo", None, None]

    # Bits 2 and 4 are set; with valid_when=False and lsb_order=True those
    # positions are the masked ones.
    mask_bits = ak.index.Index(np.array([0b10100], dtype=np.uint8))
    bitmasked5 = ak.contents.BitMaskedArray(
        mask=mask_bits,
        content=ak.contents.NumpyArray(np.arange(5)),
        valid_when=False,
        length=5,
        lsb_order=True,
        parameters={"_my_param": "boysenberry"},
    )  # [0, 1, None, 3, None]
    unmasked5 = ak.contents.UnmaskedArray(
        ak.contents.NumpyArray(np.arange(10, 15))
    )  # [10, 11, 12, 13, 14]
    union5 = ak.Array([True, None, "two", 3, 4.4])

    mixed_condition = ak.Array([True, None, True, False, True])
    mixed_result = ak.where(mixed_condition, bitmasked5, unmasked5)
    assert to_list(mixed_result) == [0, None, None, 13, None]
    # Params not preserved.
    assert mixed_result.layout.parameters.get("_my_param") is None

    union_condition = ak.Array([True, True, True, False, None])
    union_result = ak.where(union_condition, union5, unmasked5)
    assert to_list(union_result) == [True, None, "two", 13, None]

0 comments on commit a9e48eb

Please sign in to comment.