Skip to content

Commit

Permalink
feat: make broadcasting preserve behaviors
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Oct 19, 2023
1 parent f4ac46f commit 3102d03
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 6 deletions.
12 changes: 9 additions & 3 deletions src/awkward/operations/ak_broadcast_arrays.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import awkward as ak
from awkward._backends.dispatch import backend_of
from awkward._backends.numpy import NumpyBackend
from awkward._behavior import behavior_of
from awkward._behavior import behavior_of_obj
from awkward._connect.numpy import UNSUPPORTED
from awkward._dispatch import high_level_function
from awkward._layout import wrap_layout
Expand Down Expand Up @@ -228,7 +228,6 @@ def action(inputs, depth, **kwargs):
else:
return None

behavior = behavior_of(*arrays, behavior=behavior)
out = ak._broadcasting.broadcast_and_apply(
inputs,
action,
Expand All @@ -238,7 +237,14 @@ def action(inputs, depth, **kwargs):
numpy_to_regular=True,
)
assert isinstance(out, tuple)
return [wrap_layout(x, behavior, highlevel) for x in out]
return [
wrap_layout(
content,
behavior=behavior_of_obj(array, behavior=behavior),
highlevel=highlevel,
)
for content, array in zip(out, arrays)
]


@ak._connect.numpy.implements("broadcast_arrays")
Expand Down
10 changes: 7 additions & 3 deletions src/awkward/operations/ak_broadcast_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import awkward as ak
from awkward._backends.dispatch import backend_of
from awkward._backends.numpy import NumpyBackend
from awkward._behavior import behavior_of
from awkward._behavior import behavior_of_obj
from awkward._dispatch import high_level_function
from awkward._layout import wrap_layout

Expand Down Expand Up @@ -58,7 +58,6 @@ def broadcast_fields(*arrays, highlevel=True, behavior=None):
def _impl(arrays, highlevel, behavior):
backend = backend_of(*arrays, default=cpu, coerce_to_common=True)
layouts = [ak.to_layout(x).to_backend(backend) for x in arrays]
behavior = behavior_of(*arrays, behavior=behavior)

def identity(content):
return content
Expand Down Expand Up @@ -156,5 +155,10 @@ def recurse(inputs):
return [pull(layout) for pull, layout in zip(pullbacks, inner_layouts)]

return [
wrap_layout(x, highlevel=highlevel, behavior=behavior) for x in recurse(layouts)
wrap_layout(
content,
behavior=behavior_of_obj(array, behavior=behavior),
highlevel=highlevel,
)
for content, array in zip(recurse(layouts), arrays)
]

0 comments on commit 3102d03

Please sign in to comment.