diff --git a/src/awkward/_nplikes/array_module.py b/src/awkward/_nplikes/array_module.py index 5bead34287..56250fb365 100644 --- a/src/awkward/_nplikes/array_module.py +++ b/src/awkward/_nplikes/array_module.py @@ -4,6 +4,7 @@ import math import numpy +import packaging.version from awkward._nplikes.numpylike import ( ArrayLike, @@ -18,6 +19,9 @@ from awkward._typing import Any, Final, Literal np = NumpyMetadata.instance() +NUMPY_HAS_NEP_50 = packaging.version.Version( + numpy.__version__ +) >= packaging.version.Version("1.24") class ArrayModuleNumpyLike(NumpyLike): @@ -157,15 +161,22 @@ def apply_ufunc( args: list[Any], kwargs: dict[str, Any] | None = None, ) -> ArrayLike | tuple[ArrayLike]: - # Determine input argument dtypes - input_arg_dtypes = [getattr(obj, "dtype", type(obj)) for obj in args] - # Resolve these for the given ufunc - arg_dtypes = tuple(input_arg_dtypes + [None] * ufunc.nout) - resolved_dtypes = ufunc.resolve_dtypes(arg_dtypes) - # Interpret the arguments under these dtypes - resolved_args = [ - self.asarray(arg, dtype=dtype) for arg, dtype in zip(args, resolved_dtypes) - ] + if NUMPY_HAS_NEP_50: + # Determine input argument dtypes + input_arg_dtypes = [getattr(obj, "dtype", type(obj)) for obj in args] + # Resolve these for the given ufunc + arg_dtypes = tuple(input_arg_dtypes + [None] * ufunc.nout) + resolved_dtypes = ufunc.resolve_dtypes(arg_dtypes) + # Interpret the arguments under these dtypes + resolved_args = [ + self.asarray(arg, dtype=dtype) + for arg, dtype in zip(args, resolved_dtypes) + ] + else: + resolved_args = [ + self.asarray(arg, dtype=arg.dtype) if hasattr(arg, "dtype") else arg + for arg in args + ] # Broadcast these resolved arguments broadcasted_args = self.broadcast_arrays(*resolved_args) # Allow other nplikes to replace implementation diff --git a/src/awkward/_nplikes/typetracer.py b/src/awkward/_nplikes/typetracer.py index e0474fff07..ddbb6992c9 100644 --- a/src/awkward/_nplikes/typetracer.py +++ b/src/awkward/_nplikes/typetracer.py @@ -5,6 +5,7 @@ from typing import Callable import numpy +import packaging.version import awkward as ak from awkward._nplikes.dispatch import register_nplike @@ -30,6 +31,9 @@ ) np = NumpyMetadata.instance() +NUMPY_HAS_NEP_50 = packaging.version.Version( + numpy.__version__ +) >= packaging.version.Version("1.24") def is_unknown_length(array: Any) -> bool: @@ -516,26 +520,46 @@ def apply_ufunc( # Unwrap options, assume they don't occur args = [x.content if isinstance(x, MaybeNone) else x for x in args] - # Determine input argument dtypes - input_arg_dtypes = [getattr(obj, "dtype", type(obj)) for obj in args] - # Resolve these for the given ufunc - arg_dtypes = tuple(input_arg_dtypes + [None] * ufunc.nout) - resolved_dtypes = ufunc.resolve_dtypes(arg_dtypes) - # Interpret the arguments under these dtypes - resolved_args = [ - self.asarray(arg, dtype=dtype) for arg, dtype in zip(args, resolved_dtypes) - ] - # Broadcast these resolved arguments - broadcasted_args = self.broadcast_arrays(*resolved_args) - result_dtypes = resolved_dtypes[ufunc.nin :] + if NUMPY_HAS_NEP_50: + # Determine input argument dtypes + input_arg_dtypes = [getattr(obj, "dtype", type(obj)) for obj in args] + # Resolve these for the given ufunc + arg_dtypes = tuple(input_arg_dtypes + [None] * ufunc.nout) + resolved_dtypes = ufunc.resolve_dtypes(arg_dtypes) + # Interpret the arguments under these dtypes + resolved_args = [ + self.asarray(arg, dtype=dtype) + for arg, dtype in zip(args, resolved_dtypes) + ] + # Broadcast these resolved arguments + broadcasted_args = self.broadcast_arrays(*resolved_args) + broadcasted_shape = broadcasted_args[0].shape + result_dtypes = resolved_dtypes[ufunc.nin :] + else: + array_like_args = [ + self.asarray(arg, dtype=arg.dtype) + for arg in args + if hasattr(arg, "dtype") + ] + broadcasted_args = self.broadcast_arrays(*array_like_args) + broadcasted_shape = broadcasted_args[0].shape + + numpy_args = [ + (numpy.empty(0, dtype=x.dtype) if hasattr(x, "dtype") else x) + for x in args + ] + numpy_result = ufunc(*numpy_args, **kwargs) + if ufunc.nout == 1: + result_dtypes = [numpy_result.dtype] + else: + result_dtypes = [x.dtype for x in numpy_result] + if len(result_dtypes) == 1: - return TypeTracerArray._new( - result_dtypes[0], shape=broadcasted_args[0].shape - ) + return TypeTracerArray._new(result_dtypes[0], shape=broadcasted_shape) else: return ( - TypeTracerArray._new(dtype, shape=b.shape) - for dtype, b in zip(result_dtypes, broadcasted_args) + TypeTracerArray._new(dtype, shape=broadcasted_shape) + for dtype in result_dtypes ) def _axis_is_valid(self, axis: int, ndim: int) -> bool: