Skip to content

Commit

Permalink
chore: appease mypy for array_module
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Oct 31, 2023
1 parent 0c082a2 commit 865fbf9
Showing 1 changed file with 29 additions and 22 deletions.
51 changes: 29 additions & 22 deletions src/awkward/_nplikes/array_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
)
from awkward._nplikes.placeholder import PlaceholderArray
from awkward._nplikes.shape import ShapeItem, unknown_length
from awkward._typing import Any, Final, Literal
from awkward._typing import Any, cast, Final, Literal, DType

np = NumpyMetadata.instance()
NUMPY_HAS_NEP_50 = packaging.version.Version(
Expand All @@ -37,7 +37,8 @@ def _nplike_concatenate_has_casting(module: Any) -> bool:


class ArrayModuleNumpyLike(NumpyLike):
known_data: Final = True
known_data: Final[bool] = True
_module: Any

def prepare_ufunc(self, ufunc: UfuncLike) -> UfuncLike:
return ufunc
Expand All @@ -48,7 +49,7 @@ def asarray(
self,
obj,
*,
dtype: numpy.dtype | None = None,
dtype: DType | None = None,
copy: bool | None = None,
) -> ArrayLike:
if isinstance(obj, PlaceholderArray):
Expand All @@ -73,7 +74,7 @@ def ascontiguousarray(self, x: ArrayLike) -> ArrayLike:
return self._module.ascontiguousarray(x)

def frombuffer(
self, buffer, *, dtype: np.dtype | None = None, count: ShapeItem = -1
self, buffer, *, dtype: DType | None = None, count: ShapeItem = -1
) -> ArrayLike:
if isinstance(buffer, PlaceholderArray):
raise TypeError("placeholder arrays are not supported in `frombuffer`")
Expand All @@ -83,39 +84,43 @@ def from_dlpack(self, x: Any) -> ArrayLike:
return self._module.from_dlpack(x)

def zeros(
self, shape: int | tuple[int, ...], *, dtype: np.dtype | None = None
self, shape: ShapeItem | tuple[ShapeItem, ...], *, dtype: DType | None = None
) -> ArrayLike:
return self._module.zeros(shape, dtype=dtype)

def ones(
self, shape: int | tuple[int, ...], *, dtype: np.dtype | None = None
self, shape: ShapeItem | tuple[ShapeItem, ...], *, dtype: DType | None = None
) -> ArrayLike:
return self._module.ones(shape, dtype=dtype)

def empty(
self, shape: int | tuple[int, ...], *, dtype: np.dtype | None = None
self, shape: ShapeItem | tuple[ShapeItem, ...], *, dtype: DType | None = None
) -> ArrayLike:
return self._module.empty(shape, dtype=dtype)

def full(
self, shape: int | tuple[int, ...], fill_value, *, dtype: np.dtype | None = None
self,
shape: ShapeItem | tuple[ShapeItem, ...],
fill_value,
*,
dtype: DType | None = None,
) -> ArrayLike:
return self._module.full(shape, fill_value, dtype=dtype)

def zeros_like(self, x: ArrayLike, *, dtype: np.dtype | None = None) -> ArrayLike:
def zeros_like(self, x: ArrayLike, *, dtype: DType | None = None) -> ArrayLike:
if isinstance(x, PlaceholderArray):
return self.zeros(x.shape, dtype=dtype or x.dtype)
else:
return self._module.zeros_like(x, dtype=dtype)

def ones_like(self, x: ArrayLike, *, dtype: np.dtype | None = None) -> ArrayLike:
def ones_like(self, x: ArrayLike, *, dtype: DType | None = None) -> ArrayLike:
if isinstance(x, PlaceholderArray):
return self.ones(x.shape, dtype=dtype or x.dtype)
else:
return self._module.ones_like(x, dtype=dtype)

def full_like(
self, x: ArrayLike, fill_value, *, dtype: np.dtype | None = None
self, x: ArrayLike, fill_value, *, dtype: DType | None = None
) -> ArrayLike:
if isinstance(x, PlaceholderArray):
return self.full(x.shape, fill_value, dtype=dtype or x.dtype)
Expand All @@ -128,7 +133,7 @@ def arange(
stop: float | int | None = None,
step: float | int = 1,
*,
dtype: np.dtype | None = None,
dtype: DType | None = None,
) -> ArrayLike:
assert not isinstance(start, PlaceholderArray)
assert not isinstance(stop, PlaceholderArray)
Expand Down Expand Up @@ -230,7 +235,7 @@ def _compute_compatible_shape(
) -> tuple[ShapeItem, ...]:
next_shape = list(shape)
j = None
length_factor = 1
length_factor: ShapeItem = 1
for i, item in enumerate(shape):
if item != -1:
length_factor *= item
Expand Down Expand Up @@ -313,6 +318,8 @@ def regularize_index_for_length(
Returns regularized index that is guaranteed to be in-bounds.
""" # We have known length and index
length = cast(int, length)

if index < 0:
index = index + length

Expand Down Expand Up @@ -429,12 +436,12 @@ def broadcast_to(self, x: ArrayLike, shape: tuple[ShapeItem, ...]) -> ArrayLike:
def strides(self, x: ArrayLike) -> tuple[ShapeItem, ...]:
if isinstance(x, PlaceholderArray):
# Assume contiguous
strides = (x.itemsize,)
strides: tuple[int, ...] = (x.itemsize,)
for item in x.shape[-1:0:-1]:
strides = (item * strides[0], *strides)
return strides

return x.strides
return x.strides # type: ignore[attr-defined]

############################ ufuncs

Expand Down Expand Up @@ -517,7 +524,7 @@ def all(
self,
x: ArrayLike,
*,
axis: int | tuple[int, ...] | None = None,
axis: ShapeItem | tuple[ShapeItem, ...] | None = None,
keepdims: bool = False,
maybe_out: ArrayLike | None = None,
) -> ArrayLike:
Expand All @@ -528,7 +535,7 @@ def any(
self,
x: ArrayLike,
*,
axis: int | tuple[int, ...] | None = None,
axis: ShapeItem | tuple[ShapeItem, ...] | None = None,
keepdims: bool = False,
maybe_out: ArrayLike | None = None,
) -> ArrayLike:
Expand All @@ -539,7 +546,7 @@ def min(
self,
x: ArrayLike,
*,
axis: int | tuple[int, ...] | None = None,
axis: ShapeItem | tuple[ShapeItem, ...] | None = None,
keepdims: bool = False,
maybe_out: ArrayLike | None = None,
) -> ArrayLike:
Expand All @@ -550,15 +557,15 @@ def max(
self,
x: ArrayLike,
*,
axis: int | tuple[int, ...] | None = None,
axis: ShapeItem | tuple[ShapeItem, ...] | None = None,
keepdims: bool = False,
maybe_out: ArrayLike | None = None,
) -> ArrayLike:
assert not isinstance(x, PlaceholderArray)
return self._module.max(x, axis=axis, keepdims=keepdims, out=maybe_out)

def count_nonzero(
self, x: ArrayLike, *, axis: int | tuple[int, ...] | None = None
self, x: ArrayLike, *, axis: ShapeItem | tuple[ShapeItem, ...] | None = None
) -> ArrayLike:
assert not isinstance(x, PlaceholderArray)
assert isinstance(axis, int) or axis is None
Expand Down Expand Up @@ -594,9 +601,9 @@ def astype(
self, x: ArrayLike, dtype: numpy.dtype, *, copy: bool | None = True
) -> ArrayLike:
assert not isinstance(x, PlaceholderArray)
return x.astype(dtype, copy=copy)
return x.astype(dtype, copy=copy) # type: ignore

def can_cast(self, from_: np.dtype | ArrayLike, to: np.dtype | ArrayLike) -> bool:
def can_cast(self, from_: DType | ArrayLike, to: DType | ArrayLike) -> bool:
return self._module.can_cast(from_, to, casting="same_kind")

@classmethod
Expand Down

0 comments on commit 865fbf9

Please sign in to comment.