Skip to content

Commit

Permalink
refactor: prepare scalars for broadcastin
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Oct 22, 2023
1 parent baf2df7 commit 9fec71e
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 18 deletions.
10 changes: 7 additions & 3 deletions src/awkward/_connect/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,11 +416,15 @@ def action(inputs, **ignore):

args = [x.data if isinstance(x, NumpyArray) else x for x in inputs]

# Explicitly promote and broadcast inputs, to ensure the correct promotion behavior
broadcasted_args = backend.nplike.broadcast_arrays(
*backend.nplike.promote_scalars(*args[: ufunc.nargs - 1])
)
non_broadcasted_args = args[ufunc.nargs - 1 :]

# Give backend a chance to change the ufunc implementation
impl = backend.prepare_ufunc(ufunc)

# Invoke ufunc
result = impl(*args, **kwargs)
result = impl(*broadcasted_args, *non_broadcasted_args, **kwargs)

if isinstance(result, tuple):
return tuple(
Expand Down
32 changes: 29 additions & 3 deletions src/awkward/_nplikes/array_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from __future__ import annotations

import math
from numbers import Number

import numpy

Expand Down Expand Up @@ -35,16 +36,22 @@ def asarray(
assert obj.dtype == dtype or dtype is None
return obj
elif copy:
return self._module.array(obj, dtype=dtype, copy=True)
result = self._module.array(obj, dtype=dtype, copy=True)
assert result.dtype != np.dtype("object")
return result
elif copy is None:
return self._module.asarray(obj, dtype=dtype)
result = self._module.asarray(obj, dtype=dtype)
assert result.dtype != np.dtype("object")
return result
else:
if getattr(obj, "dtype", dtype) != dtype:
raise ValueError(
"asarray was called with copy=False for an array of a different dtype"
)
else:
return self._module.asarray(obj, dtype=dtype)
result = self._module.asarray(obj, dtype=dtype)
assert result.dtype != np.dtype("object")
return result

def ascontiguousarray(self, x: ArrayLike) -> ArrayLike:
if isinstance(x, PlaceholderArray):
Expand Down Expand Up @@ -146,8 +153,27 @@ def searchsorted(

############################ manipulation

def promote_scalars(self, *array_or_scalars: Any) -> list[ArrayLike]:
return [
x if hasattr(x, "ndim") else self.promote_scalar(x)
for x in array_or_scalars
]

def promote_scalar(self, obj) -> ArrayLike:
assert not isinstance(obj, PlaceholderArray)
if hasattr(obj, "ndim"):
assert obj.ndim == 0
return obj
elif isinstance(obj, (Number, bool)):
return self.asarray(obj)
else:
raise TypeError(f"expected scalar type, received {obj}")

def broadcast_arrays(self, *arrays: ArrayLike) -> list[ArrayLike]:
assert not any(isinstance(x, PlaceholderArray) for x in arrays)
# Ensure we have arrays!
assert all(hasattr(x, "ndim") for x in arrays)

return self._module.broadcast_arrays(*arrays)

def _compute_compatible_shape(
Expand Down
4 changes: 4 additions & 0 deletions src/awkward/_nplikes/numpylike.py
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,10 @@ def searchsorted(

############################ manipulation

@abstractmethod
def promote_scalars(self, *array_or_scalars: Any) -> list[ArrayLike]:
...

@abstractmethod
def broadcast_arrays(self, *arrays: ArrayLike) -> list[ArrayLike]:
...
Expand Down
32 changes: 20 additions & 12 deletions src/awkward/_nplikes/typetracer.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,8 +387,7 @@ def __getitem__(
try_touch_data(item)
try_touch_data(self)

if is_unknown_scalar(item):
item = self.nplike.promote_scalar(item)
item = self.nplike.promote_scalar(item)

# If this is the first advanced index, insert the location
if not advanced_shapes:
Expand Down Expand Up @@ -506,7 +505,8 @@ class TypeTracer(NumpyLike):

def _apply_ufunc(self, ufunc, *inputs):
for x in inputs:
assert not isinstance(x, PlaceholderArray)
# This routine shouldn't operate on scalars!
assert isinstance(x, TypeTracerArray)
try_touch_data(x)

inputs = [x.content if isinstance(x, MaybeNone) else x for x in inputs]
Expand Down Expand Up @@ -802,14 +802,23 @@ def searchsorted(

############################ manipulation

def promote_scalars(self, *array_or_scalars: Any) -> list[TypeTracerArray]:
return [
x if hasattr(x, "ndim") else self.promote_scalar(x)
for x in array_or_scalars
]

def promote_scalar(self, obj) -> TypeTracerArray:
assert not isinstance(obj, PlaceholderArray)
if is_unknown_scalar(obj):
return obj
# NOTE: This routine is value-based, but will be consistent for typetracer
# because these are Python scalars which should not have been
# produced by kernels. i.e., they should be static constants
# The known-data equivalent of this function should check for
# instances of np.generic to catch this
elif isinstance(obj, (Number, bool)):
# TODO: statically define these types for all nplikes
as_array = numpy.asarray(obj)
return TypeTracerArray._new(as_array.dtype, ())
return self.asarray(obj)
else:
raise TypeError(f"expected scalar type, received {obj}")

Expand Down Expand Up @@ -969,16 +978,15 @@ def broadcast_arrays(self, *arrays: ArrayLike) -> list[TypeTracerArray]:
if len(arrays) == 0:
return []

all_arrays = []
for x in arrays:
if not hasattr(x, "shape"):
x = self.promote_scalar(x)
all_arrays.append(x)
assert hasattr(
x, "shape"
), "encountered non-ArrayLike in `broadcast_arrays`"

shapes = [x.shape for x in all_arrays]
shapes = [x.shape for x in arrays]
shape = self.broadcast_shapes(*shapes)

return [TypeTracerArray._new(x.dtype, shape=shape) for x in all_arrays]
return [TypeTracerArray._new(x.dtype, shape=shape) for x in arrays]

def broadcast_to(
self, x: ArrayLike, shape: tuple[ShapeItem, ...]
Expand Down

0 comments on commit 9fec71e

Please sign in to comment.