From 6c2a4735aeae85d2b987487d19d037d90c34d846 Mon Sep 17 00:00:00 2001 From: Angus Hollands Date: Sun, 22 Oct 2023 20:34:35 +0100 Subject: [PATCH] fix: always return input for non-promoted scalars --- src/awkward/operations/ak_to_layout.py | 22 +++++++++++++++------- 1 file changed, 15 insertions(+), 7 deletions(-) diff --git a/src/awkward/operations/ak_to_layout.py b/src/awkward/operations/ak_to_layout.py index 8cf7092900..6c8c8ebed9 100644 --- a/src/awkward/operations/ak_to_layout.py +++ b/src/awkward/operations/ak_to_layout.py @@ -95,12 +95,13 @@ def maybe_merge_mappings(primary, secondary): return {**primary, **secondary} -def _handle_as_scalar(obj, layout, *, scalar_policy): +def _handle_as_scalar(obj, *, scalar_policy): assert scalar_policy in ("allow", "promote", "error") if scalar_policy == "allow": - return layout[0] + return obj elif scalar_policy == "promote": + layout = ak.operations.from_iter([obj], highlevel=False) return layout else: assert scalar_policy == "error" @@ -110,8 +111,17 @@ def _handle_as_scalar(obj, layout, *, scalar_policy): def _handle_array_like(obj, layout, *, scalar_policy): + assert scalar_policy in ("allow", "promote", "error") if obj.ndim == 0: - return _handle_as_scalar(obj, layout, scalar_policy=scalar_policy) + if scalar_policy == "allow": + return obj + elif scalar_policy == "promote": + return layout + else: + assert scalar_policy == "error" + raise TypeError( + f"Encountered a scalar ({type(obj).__name__}), but scalars conversion/promotion is disabled" + ) else: return layout @@ -204,12 +214,10 @@ def _impl( f"Encountered a scalar ({type(obj).__name__}), but scalars conversion/promotion is disabled" ) elif isinstance(obj, (datetime, date, time, Number, bool)): - layout = ak.operations.from_iter([obj], highlevel=False) - return _handle_as_scalar(obj, layout, scalar_policy=scalar_policy) + return _handle_as_scalar(obj, scalar_policy=scalar_policy) elif obj is None: if allow_none: - layout = ak.operations.from_iter([obj], highlevel=False) - return _handle_as_scalar(obj, layout, scalar_policy=scalar_policy) + return _handle_as_scalar(obj, scalar_policy=scalar_policy) else: raise TypeError("Encountered None value, and `allow_none` is `False`") # Iterables