diff --git a/src/awkward/_nplikes/array_module.py b/src/awkward/_nplikes/array_module.py index 23b1128292..88d5d93de7 100644 --- a/src/awkward/_nplikes/array_module.py +++ b/src/awkward/_nplikes/array_module.py @@ -17,7 +17,7 @@ ) from awkward._nplikes.placeholder import PlaceholderArray from awkward._nplikes.shape import ShapeItem, unknown_length -from awkward._typing import Any, Final, Literal +from awkward._typing import Any, cast, Final, Literal, DType np = NumpyMetadata.instance() NUMPY_HAS_NEP_50 = packaging.version.Version( @@ -37,7 +37,8 @@ def _nplike_concatenate_has_casting(module: Any) -> bool: class ArrayModuleNumpyLike(NumpyLike): - known_data: Final = True + known_data: Final[bool] = True + _module: Any def prepare_ufunc(self, ufunc: UfuncLike) -> UfuncLike: return ufunc @@ -48,7 +49,7 @@ def asarray( self, obj, *, - dtype: numpy.dtype | None = None, + dtype: DType | None = None, copy: bool | None = None, ) -> ArrayLike: if isinstance(obj, PlaceholderArray): @@ -73,7 +74,7 @@ def ascontiguousarray(self, x: ArrayLike) -> ArrayLike: return self._module.ascontiguousarray(x) def frombuffer( - self, buffer, *, dtype: np.dtype | None = None, count: ShapeItem = -1 + self, buffer, *, dtype: DType | None = None, count: ShapeItem = -1 ) -> ArrayLike: if isinstance(buffer, PlaceholderArray): raise TypeError("placeholder arrays are not supported in `frombuffer`") @@ -83,39 +84,43 @@ def from_dlpack(self, x: Any) -> ArrayLike: return self._module.from_dlpack(x) def zeros( - self, shape: int | tuple[int, ...], *, dtype: np.dtype | None = None + self, shape: ShapeItem | tuple[ShapeItem, ...], *, dtype: DType | None = None ) -> ArrayLike: return self._module.zeros(shape, dtype=dtype) def ones( - self, shape: int | tuple[int, ...], *, dtype: np.dtype | None = None + self, shape: ShapeItem | tuple[ShapeItem, ...], *, dtype: DType | None = None ) -> ArrayLike: return self._module.ones(shape, dtype=dtype) def empty( - self, shape: int | tuple[int, ...], *, dtype: np.dtype | None = None + self, shape: ShapeItem | tuple[ShapeItem, ...], *, dtype: DType | None = None ) -> ArrayLike: return self._module.empty(shape, dtype=dtype) def full( - self, shape: int | tuple[int, ...], fill_value, *, dtype: np.dtype | None = None + self, + shape: ShapeItem | tuple[ShapeItem, ...], + fill_value, + *, + dtype: DType | None = None, ) -> ArrayLike: return self._module.full(shape, fill_value, dtype=dtype) - def zeros_like(self, x: ArrayLike, *, dtype: np.dtype | None = None) -> ArrayLike: + def zeros_like(self, x: ArrayLike, *, dtype: DType | None = None) -> ArrayLike: if isinstance(x, PlaceholderArray): return self.zeros(x.shape, dtype=dtype or x.dtype) else: return self._module.zeros_like(x, dtype=dtype) - def ones_like(self, x: ArrayLike, *, dtype: np.dtype | None = None) -> ArrayLike: + def ones_like(self, x: ArrayLike, *, dtype: DType | None = None) -> ArrayLike: if isinstance(x, PlaceholderArray): return self.ones(x.shape, dtype=dtype or x.dtype) else: return self._module.ones_like(x, dtype=dtype) def full_like( - self, x: ArrayLike, fill_value, *, dtype: np.dtype | None = None + self, x: ArrayLike, fill_value, *, dtype: DType | None = None ) -> ArrayLike: if isinstance(x, PlaceholderArray): return self.full(x.shape, fill_value, dtype=dtype or x.dtype) @@ -128,7 +133,7 @@ def arange( stop: float | int | None = None, step: float | int = 1, *, - dtype: np.dtype | None = None, + dtype: DType | None = None, ) -> ArrayLike: assert not isinstance(start, PlaceholderArray) assert not isinstance(stop, PlaceholderArray) @@ -230,7 +235,7 @@ def _compute_compatible_shape( ) -> tuple[ShapeItem, ...]: next_shape = list(shape) j = None - length_factor = 1 + length_factor: ShapeItem = 1 for i, item in enumerate(shape): if item != -1: length_factor *= item @@ -313,6 +318,8 @@ def regularize_index_for_length( Returns regularized index that is guaranteed to be in-bounds. """ # We have known length and index + length = cast(int, length) + if index < 0: index = index + length @@ -429,12 +436,12 @@ def broadcast_to(self, x: ArrayLike, shape: tuple[ShapeItem, ...]) -> ArrayLike: def strides(self, x: ArrayLike) -> tuple[ShapeItem, ...]: if isinstance(x, PlaceholderArray): # Assume contiguous - strides = (x.itemsize,) + strides: tuple[int, ...] = (x.itemsize,) for item in x.shape[-1:0:-1]: strides = (item * strides[0], *strides) return strides - return x.strides + return x.strides # type: ignore[attr-defined] ############################ ufuncs @@ -517,7 +524,7 @@ def all( self, x: ArrayLike, *, - axis: int | tuple[int, ...] | None = None, + axis: ShapeItem | tuple[ShapeItem, ...] | None = None, keepdims: bool = False, maybe_out: ArrayLike | None = None, ) -> ArrayLike: @@ -528,7 +535,7 @@ def any( self, x: ArrayLike, *, - axis: int | tuple[int, ...] | None = None, + axis: ShapeItem | tuple[ShapeItem, ...] | None = None, keepdims: bool = False, maybe_out: ArrayLike | None = None, ) -> ArrayLike: @@ -539,7 +546,7 @@ def min( self, x: ArrayLike, *, - axis: int | tuple[int, ...] | None = None, + axis: ShapeItem | tuple[ShapeItem, ...] | None = None, keepdims: bool = False, maybe_out: ArrayLike | None = None, ) -> ArrayLike: @@ -550,7 +557,7 @@ def max( self, x: ArrayLike, *, - axis: int | tuple[int, ...] | None = None, + axis: ShapeItem | tuple[ShapeItem, ...] | None = None, keepdims: bool = False, maybe_out: ArrayLike | None = None, ) -> ArrayLike: @@ -558,7 +565,7 @@ def max( return self._module.max(x, axis=axis, keepdims=keepdims, out=maybe_out) def count_nonzero( - self, x: ArrayLike, *, axis: int | tuple[int, ...] | None = None + self, x: ArrayLike, *, axis: ShapeItem | tuple[ShapeItem, ...] | None = None ) -> ArrayLike: assert not isinstance(x, PlaceholderArray) assert isinstance(axis, int) or axis is None @@ -594,9 +601,9 @@ def astype( self, x: ArrayLike, dtype: numpy.dtype, *, copy: bool | None = True ) -> ArrayLike: assert not isinstance(x, PlaceholderArray) - return x.astype(dtype, copy=copy) + return x.astype(dtype, copy=copy) # type: ignore - def can_cast(self, from_: np.dtype | ArrayLike, to: np.dtype | ArrayLike) -> bool: + def can_cast(self, from_: DType | ArrayLike, to: DType | ArrayLike) -> bool: return self._module.can_cast(from_, to, casting="same_kind") @classmethod