From 9fec71e990bb0cd15061ce2f09a05654ba581278 Mon Sep 17 00:00:00 2001 From: Angus Hollands Date: Sun, 22 Oct 2023 22:14:59 +0100 Subject: [PATCH] refactor: prepare scalars for broadcasting --- src/awkward/_connect/numpy.py | 10 ++++++--- src/awkward/_nplikes/array_module.py | 32 +++++++++++++++++++++++++--- src/awkward/_nplikes/numpylike.py | 4 ++++ src/awkward/_nplikes/typetracer.py | 32 +++++++++++++++++----------- 4 files changed, 60 insertions(+), 18 deletions(-) diff --git a/src/awkward/_connect/numpy.py b/src/awkward/_connect/numpy.py index f81d909d66..d22f25353c 100644 --- a/src/awkward/_connect/numpy.py +++ b/src/awkward/_connect/numpy.py @@ -416,11 +416,15 @@ def action(inputs, **ignore): args = [x.data if isinstance(x, NumpyArray) else x for x in inputs] + # Explicitly promote and broadcast inputs, to ensure the correct promotion behavior + broadcasted_args = backend.nplike.broadcast_arrays( + *backend.nplike.promote_scalars(*args[: ufunc.nargs - 1]) + ) + non_broadcasted_args = args[ufunc.nargs - 1 :] + # Give backend a chance to change the ufunc implementation impl = backend.prepare_ufunc(ufunc) - - # Invoke ufunc - result = impl(*args, **kwargs) + result = impl(*broadcasted_args, *non_broadcasted_args, **kwargs) if isinstance(result, tuple): return tuple( diff --git a/src/awkward/_nplikes/array_module.py b/src/awkward/_nplikes/array_module.py index 291fec4623..56f2a6f2d8 100644 --- a/src/awkward/_nplikes/array_module.py +++ b/src/awkward/_nplikes/array_module.py @@ -2,6 +2,7 @@ from __future__ import annotations import math +from numbers import Number import numpy @@ -35,16 +36,22 @@ def asarray( assert obj.dtype == dtype or dtype is None return obj elif copy: - return self._module.array(obj, dtype=dtype, copy=True) + result = self._module.array(obj, dtype=dtype, copy=True) + assert result.dtype != np.dtype("object") + return result elif copy is None: - return self._module.asarray(obj, dtype=dtype) + result = self._module.asarray(obj, dtype=dtype) + 
assert result.dtype != np.dtype("object") + return result else: if getattr(obj, "dtype", dtype) != dtype: raise ValueError( "asarray was called with copy=False for an array of a different dtype" ) else: - return self._module.asarray(obj, dtype=dtype) + result = self._module.asarray(obj, dtype=dtype) + assert result.dtype != np.dtype("object") + return result def ascontiguousarray(self, x: ArrayLike) -> ArrayLike: if isinstance(x, PlaceholderArray): @@ -146,8 +153,27 @@ def searchsorted( ############################ manipulation + def promote_scalars(self, *array_or_scalars: Any) -> list[ArrayLike]: + return [ + x if hasattr(x, "ndim") else self.promote_scalar(x) + for x in array_or_scalars + ] + + def promote_scalar(self, obj) -> ArrayLike: + assert not isinstance(obj, PlaceholderArray) + if hasattr(obj, "ndim"): + assert obj.ndim == 0 + return obj + elif isinstance(obj, (Number, bool)): + return self.asarray(obj) + else: + raise TypeError(f"expected scalar type, received {obj}") + def broadcast_arrays(self, *arrays: ArrayLike) -> list[ArrayLike]: assert not any(isinstance(x, PlaceholderArray) for x in arrays) + # Ensure we have arrays! + assert all(hasattr(x, "ndim") for x in arrays) + return self._module.broadcast_arrays(*arrays) def _compute_compatible_shape( diff --git a/src/awkward/_nplikes/numpylike.py b/src/awkward/_nplikes/numpylike.py index b2bbd29f5f..54744723a7 100644 --- a/src/awkward/_nplikes/numpylike.py +++ b/src/awkward/_nplikes/numpylike.py @@ -377,6 +377,10 @@ def searchsorted( ############################ manipulation + @abstractmethod + def promote_scalars(self, *array_or_scalars: Any) -> list[ArrayLike]: + ... + @abstractmethod def broadcast_arrays(self, *arrays: ArrayLike) -> list[ArrayLike]: ... 
diff --git a/src/awkward/_nplikes/typetracer.py b/src/awkward/_nplikes/typetracer.py index 357b222ce9..f0a3352991 100644 --- a/src/awkward/_nplikes/typetracer.py +++ b/src/awkward/_nplikes/typetracer.py @@ -387,8 +387,7 @@ def __getitem__( try_touch_data(item) try_touch_data(self) - if is_unknown_scalar(item): - item = self.nplike.promote_scalar(item) + item = self.nplike.promote_scalar(item) # If this is the first advanced index, insert the location if not advanced_shapes: @@ -506,7 +505,8 @@ class TypeTracer(NumpyLike): def _apply_ufunc(self, ufunc, *inputs): for x in inputs: - assert not isinstance(x, PlaceholderArray) + # This routine shouldn't operate on scalars! + assert isinstance(x, TypeTracerArray) try_touch_data(x) inputs = [x.content if isinstance(x, MaybeNone) else x for x in inputs] @@ -802,14 +802,23 @@ def searchsorted( ############################ manipulation + def promote_scalars(self, *array_or_scalars: Any) -> list[TypeTracerArray]: + return [ + x if hasattr(x, "ndim") else self.promote_scalar(x) + for x in array_or_scalars + ] + def promote_scalar(self, obj) -> TypeTracerArray: assert not isinstance(obj, PlaceholderArray) if is_unknown_scalar(obj): return obj + # NOTE: This routine is value-based, but will be consistent for typetracer + # because these are Python scalars which should not have been + # produced by kernels. 
i.e., they should be static constants + # The known-data equivalent of this function should check for + # instances of np.generic to catch this elif isinstance(obj, (Number, bool)): - # TODO: statically define these types for all nplikes - as_array = numpy.asarray(obj) - return TypeTracerArray._new(as_array.dtype, ()) + return self.asarray(obj) else: raise TypeError(f"expected scalar type, received {obj}") @@ -969,16 +978,15 @@ def broadcast_arrays(self, *arrays: ArrayLike) -> list[TypeTracerArray]: if len(arrays) == 0: return [] - all_arrays = [] for x in arrays: - if not hasattr(x, "shape"): - x = self.promote_scalar(x) - all_arrays.append(x) + assert hasattr( + x, "shape" + ), "encountered non-ArrayLike in `broadcast_arrays`" - shapes = [x.shape for x in all_arrays] + shapes = [x.shape for x in arrays] shape = self.broadcast_shapes(*shapes) - return [TypeTracerArray._new(x.dtype, shape=shape) for x in all_arrays] + return [TypeTracerArray._new(x.dtype, shape=shape) for x in arrays] def broadcast_to( self, x: ArrayLike, shape: tuple[ShapeItem, ...]