From eaa43ff31d37c525e33a8226693ed273909568d5 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Tue, 24 Sep 2024 16:19:16 -0500 Subject: [PATCH] perf: avoid inflating UnmaskedArrays in broadcasting when you can (#3254) --- src/awkward/_broadcasting.py | 34 +++++++++++++++++++++++++++++++++- 1 file changed, 33 insertions(+), 1 deletion(-) diff --git a/src/awkward/_broadcasting.py b/src/awkward/_broadcasting.py index 7c69212dc2..de5699f73a 100644 --- a/src/awkward/_broadcasting.py +++ b/src/awkward/_broadcasting.py @@ -701,6 +701,36 @@ def broadcast_any_list(): for x, p in zip(outcontent, parameters) ) + def broadcast_any_option_all_UnmaskedArray(): + nextinputs = [] + nextparameters = [] + for x in inputs: + if isinstance(x, UnmaskedArray): + nextinputs.append(x.content) + nextparameters.append(x._parameters) + elif isinstance(x, Content): + nextinputs.append(x) + nextparameters.append(x._parameters) + else: + nextinputs.append(x) + nextparameters.append(NO_PARAMETERS) + + outcontent = apply_step( + backend, + nextinputs, + action, + depth, + copy.copy(depth_context), + lateral_context, + options, + ) + assert isinstance(outcontent, tuple) + parameters = parameters_factory(nextparameters, len(outcontent)) + + return tuple( + UnmaskedArray(x, parameters=p) for x, p in zip(outcontent, parameters) + ) + def broadcast_any_option(): mask = None for x in contents: @@ -1045,7 +1075,9 @@ def continuation(): # Any option-types? elif any(x.is_option for x in contents): - if options["function_name"] == "ak.where": + if all(not x.is_option or isinstance(x, UnmaskedArray) for x in contents): + return broadcast_any_option_all_UnmaskedArray() + elif options["function_name"] == "ak.where": return broadcast_any_option_akwhere() else: return broadcast_any_option()