From 3102d03a7bb44357a83ed6031818dbbb34296888 Mon Sep 17 00:00:00 2001 From: Angus Hollands Date: Thu, 19 Oct 2023 17:18:36 +0100 Subject: [PATCH] feat: make broadcasting preserve behaviors --- src/awkward/operations/ak_broadcast_arrays.py | 12 +++++++++--- src/awkward/operations/ak_broadcast_fields.py | 10 +++++++--- 2 files changed, 16 insertions(+), 6 deletions(-) diff --git a/src/awkward/operations/ak_broadcast_arrays.py b/src/awkward/operations/ak_broadcast_arrays.py index 18943e48e6..906c1a357e 100644 --- a/src/awkward/operations/ak_broadcast_arrays.py +++ b/src/awkward/operations/ak_broadcast_arrays.py @@ -3,7 +3,7 @@ import awkward as ak from awkward._backends.dispatch import backend_of from awkward._backends.numpy import NumpyBackend -from awkward._behavior import behavior_of +from awkward._behavior import behavior_of_obj from awkward._connect.numpy import UNSUPPORTED from awkward._dispatch import high_level_function from awkward._layout import wrap_layout @@ -228,7 +228,6 @@ def action(inputs, depth, **kwargs): else: return None - behavior = behavior_of(*arrays, behavior=behavior) out = ak._broadcasting.broadcast_and_apply( inputs, action, @@ -238,7 +237,14 @@ def action(inputs, depth, **kwargs): numpy_to_regular=True, ) assert isinstance(out, tuple) - return [wrap_layout(x, behavior, highlevel) for x in out] + return [ + wrap_layout( + content, + behavior=behavior_of_obj(array, behavior=behavior), + highlevel=highlevel, + ) + for content, array in zip(out, arrays) + ] @ak._connect.numpy.implements("broadcast_arrays") diff --git a/src/awkward/operations/ak_broadcast_fields.py b/src/awkward/operations/ak_broadcast_fields.py index 64c0c20ab7..fbae8d7851 100644 --- a/src/awkward/operations/ak_broadcast_fields.py +++ b/src/awkward/operations/ak_broadcast_fields.py @@ -3,7 +3,7 @@ import awkward as ak from awkward._backends.dispatch import backend_of from awkward._backends.numpy import NumpyBackend -from awkward._behavior import behavior_of +from awkward._behavior import behavior_of_obj from awkward._dispatch import high_level_function from awkward._layout import wrap_layout @@ -58,7 +58,6 @@ def broadcast_fields(*arrays, highlevel=True, behavior=None): def _impl(arrays, highlevel, behavior): backend = backend_of(*arrays, default=cpu, coerce_to_common=True) layouts = [ak.to_layout(x).to_backend(backend) for x in arrays] - behavior = behavior_of(*arrays, behavior=behavior) def identity(content): return content @@ -156,5 +155,10 @@ def recurse(inputs): return [pull(layout) for pull, layout in zip(pullbacks, inner_layouts)] return [ - wrap_layout(x, highlevel=highlevel, behavior=behavior) for x in recurse(layouts) + wrap_layout( + content, + behavior=behavior_of_obj(array, behavior=behavior), + highlevel=highlevel, + ) + for content, array in zip(recurse(layouts), arrays) ]