diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index df1d5518b7..fb9c6a68e3 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -66,3 +66,10 @@ repos: hooks: - id: check-github-workflows args: ["--verbose"] + +- repo: https://github.com/pre-commit/mirrors-mypy + rev: v1.6.1 + hooks: + - id: mypy + additional_dependencies: + - numpy>=1.24 diff --git a/pyproject.toml b/pyproject.toml index bab2387515..c9081e4472 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -220,40 +220,21 @@ disable = [ [tool.mypy] files = ["src/awkward/**/*.py"] -exclude = ["^src/awkward/[^/]+\\.py$"] plugins = [ "numpy.typing.mypy_plugin" ] python_version = "3.11" +ignore_errors = true +ignore_missing_imports = true [[tool.mypy.overrides]] module = [ - 'awkward.__init__', - 'awkward._connect.*', - 'awkward._cpu_kernels', - 'awkward._errors', - 'awkward._kernel_signatures', - 'awkward._libawkward', - 'awkward._util', - 'awkward.forms', - 'awkward.forth', - 'awkward.highlevel', - 'awkward.nplike', - 'awkward.numba', - 'awkward.types', - 'awkward.types._awkward_datashape_parser', - 'numba.*', - 'llvmlite.*', - 'ROOT.*', - 'cppyy.*', - 'jax.*', - 'pandas.*', - 'cupy.*', - 'pyarrow.*', - 'fsspec.*', - 'numexpr.*', + 'awkward._nplikes.*', + 'awkward._behavior.*', + 'awkward._backends.*', + 'awkward.forms.*', ] -ignore_errors = true +ignore_errors = false ignore_missing_imports = true [tool.ruff] @@ -306,6 +287,7 @@ mccabe.max-complexity = 100 [tool.ruff.lint.per-file-ignores] "dev/*" = ["T20", "TID251"] "src/awkward/numba/*" = ["TID251"] +"src/awkward/_typing.py" = ["TID251"] "src/awkward/_errors.py" = ["TID251"] "src/awkward/errors.py" = ["TID251"] "src/awkward/_connect/*" = ["TID251"] diff --git a/src/awkward/_backends/backend.py b/src/awkward/_backends/backend.py index f7a927c4ce..e815c8be2e 100644 --- a/src/awkward/_backends/backend.py +++ b/src/awkward/_backends/backend.py @@ -2,20 +2,21 @@ from __future__ import annotations from abc import ABC, abstractmethod +from typing import Any import awkward as ak from awkward._kernels import KernelError from awkward._nplikes.numpy import Numpy from awkward._nplikes.numpylike import NumpyLike, NumpyMetadata from awkward._singleton import PublicSingleton -from awkward._typing import Callable, Tuple, TypeAlias, TypeVar, Unpack +from awkward._typing import Callable, Tuple, TypeAlias, TypeVar np = NumpyMetadata.instance() numpy = Numpy.instance() T_co = TypeVar("T_co", covariant=True) -KernelKeyType: TypeAlias = Tuple[str, Unpack[Tuple[np.dtype, ...]]] +KernelKeyType: TypeAlias = Tuple[Any, ...] KernelType: TypeAlias = "Callable[..., KernelError | None]" @@ -49,6 +50,7 @@ def format_kernel_error( errors="surrogateescape" ).lstrip("\n").lstrip("(") + assert error.str is not None message = error.str.decode(errors="surrogateescape") if error.attempt != ak._util.kSliceNone: diff --git a/src/awkward/_backends/cupy.py b/src/awkward/_backends/cupy.py index 2bce1cb764..ce5ff53bc2 100644 --- a/src/awkward/_backends/cupy.py +++ b/src/awkward/_backends/cupy.py @@ -13,7 +13,7 @@ numpy = Numpy.instance() -@register_backend(Cupy) +@register_backend(Cupy) # type: ignore[type-abstract] class CupyBackend(Backend): name: Final[str] = "cuda" diff --git a/src/awkward/_backends/dispatch.py b/src/awkward/_backends/dispatch.py index 4150364244..16fe82b9dd 100644 --- a/src/awkward/_backends/dispatch.py +++ b/src/awkward/_backends/dispatch.py @@ -6,8 +6,8 @@ from awkward._backends.backend import Backend from awkward._nplikes.numpy import Numpy from awkward._nplikes.numpylike import NumpyLike, NumpyMetadata -from awkward._typing import Callable, TypeAlias, TypeVar -from awkward._util import UNSET +from awkward._typing import Callable, TypeAlias, TypeVar, cast +from awkward._util import UNSET, Sentinel np = NumpyMetadata.instance() numpy = Numpy.instance() @@ -19,7 +19,7 @@ BackendLookupFactory: TypeAlias = "Callable[[type[T]], BackendLookup[T]]" -_type_to_backend_lookup: dict[type[T], BackendLookup] = {} +_type_to_backend_lookup: dict[type, BackendLookup] = {} _backend_lookup_factories: list[BackendLookupFactory] = [] _name_to_backend_cls: dict[str, type[Backend]] = {} @@ -68,7 +68,7 @@ def common_backend(backends: Collection[Backend]) -> Backend: ) -def backend_of_obj(obj, default: D = UNSET) -> Backend | D: +def backend_of_obj(obj, default: D | Sentinel = UNSET) -> Backend | D: cls = type(obj) try: lookup = _type_to_backend_lookup[cls] @@ -82,13 +82,13 @@ def backend_of_obj(obj, default: D = UNSET) -> Backend | D: if default is UNSET: raise TypeError(f"cannot find backend for {cls.__name__}") else: - return default + return cast(D, default) _type_to_backend_lookup[cls] = maybe_lookup return maybe_lookup(obj) def backend_of( - *objects, default: D = UNSET, coerce_to_common: bool = False + *objects, default: D | Sentinel = UNSET, coerce_to_common: bool = False ) -> Backend | D: """ Args: @@ -108,7 +108,7 @@ def backend_of( if default is UNSET: raise ValueError("could not find backend for", objects) else: - return default + return cast(D, default) elif len(unique_backends) == 1: return next(iter(unique_backends)) elif coerce_to_common: diff --git a/src/awkward/_backends/jax.py b/src/awkward/_backends/jax.py index ef63dca017..4f53837698 100644 --- a/src/awkward/_backends/jax.py +++ b/src/awkward/_backends/jax.py @@ -16,7 +16,7 @@ numpy = Numpy.instance() -@register_backend(Jax) +@register_backend(Jax) # type: ignore[type-abstract] class JaxBackend(Backend): name: Final[str] = "jax" diff --git a/src/awkward/_backends/numpy.py b/src/awkward/_backends/numpy.py index b40817b729..e8ea90f2d0 100644 --- a/src/awkward/_backends/numpy.py +++ b/src/awkward/_backends/numpy.py @@ -14,7 +14,7 @@ numpy = Numpy.instance() -@register_backend(Numpy) +@register_backend(Numpy) # type: ignore[type-abstract] class NumpyBackend(Backend): name: Final[str] = "cpu" diff --git a/src/awkward/_backends/typetracer.py b/src/awkward/_backends/typetracer.py index cef237ed21..d9ef2e5668 100644 --- a/src/awkward/_backends/typetracer.py +++ b/src/awkward/_backends/typetracer.py @@ -13,7 +13,7 @@ numpy = Numpy.instance() -@register_backend(TypeTracer) +@register_backend(TypeTracer) # type: ignore[type-abstract] class TypeTracerBackend(Backend): name: Final[str] = "typetracer" diff --git a/src/awkward/_behavior.py b/src/awkward/_behavior.py index 8433c3199a..4f4f6bea26 100644 --- a/src/awkward/_behavior.py +++ b/src/awkward/_behavior.py @@ -2,14 +2,22 @@ from __future__ import annotations from collections import ChainMap -from collections.abc import Mapping +from collections.abc import Callable, Mapping import awkward as ak from awkward._nplikes import ufuncs -from awkward._typing import Any +from awkward._typing import TYPE_CHECKING, Any + +if TYPE_CHECKING: + from awkward._nplikes.numpylike import UfuncLike + from awkward._reducers import Reducer + from awkward.contents.content import Content + from awkward.highlevel import Array + from awkward.highlevel import Record as HighLevelRecord + from awkward.record import Record -def overlay_behavior(behavior: dict | None) -> Mapping: +def overlay_behavior(behavior: Mapping | None) -> Mapping: """ Args: behavior: behavior dictionary, or None @@ -19,10 +27,11 @@ def overlay_behavior(behavior: dict | None) -> Mapping: """ if behavior is None: return ak.behavior - return ChainMap(behavior, ak.behavior) + else: + return ChainMap(behavior, ak.behavior) # type: ignore[arg-type] -def get_array_class(layout, behavior): +def get_array_class(layout: Content, behavior: Mapping | None) -> type[Array]: from awkward.highlevel import Array behavior = overlay_behavior(behavior) @@ -39,7 +48,7 @@ def get_array_class(layout, behavior): return Array -def get_record_class(layout, behavior): +def get_record_class(layout: Record, behavior: Mapping | None) -> type[HighLevelRecord]: from awkward.highlevel import Record behavior = overlay_behavior(behavior) @@ -51,14 +60,20 @@ def get_record_class(layout, behavior): return Record -def find_record_reducer(reducer, layout, behavior): +def find_record_reducer( + reducer: Reducer, layout: Record, behavior: Mapping | None +) -> Callable[[Array, bool], Any] | None: behavior = overlay_behavior(behavior) rec = layout.parameter("__record__") if isinstance(rec, str): return behavior.get((reducer.highlevel_function(), rec)) + else: + return None -def find_custom_cast(obj, behavior): +def find_custom_cast( + obj: Any, behavior: Mapping | None +) -> Callable[[Any], Content | Record] | None: behavior = overlay_behavior(behavior) for cls in type(obj).__mro__: fcn = behavior.get(("__cast__", cls)) @@ -67,7 +82,9 @@ def find_custom_cast(obj, behavior): return None -def find_ufunc_generic(ufunc, layout, behavior): +def find_ufunc_generic( + ufunc: UfuncLike, layout: Content, behavior: Mapping | None +) -> Callable[[UfuncLike, str, list, dict], Any] | None: behavior = overlay_behavior(behavior) custom = layout.parameter("__list__") if custom is None: @@ -81,8 +98,8 @@ def find_ufunc_generic(ufunc, layout, behavior): return None -def find_ufunc(behavior, signature: tuple): - if not any(s is None for s in signature): +def find_ufunc(behavior: Mapping | None, signature: tuple) -> UfuncLike | None: + if all(s is not None for s in signature): behavior = overlay_behavior(behavior) # Special case all strings or hashable types. @@ -105,13 +122,14 @@ def find_ufunc(behavior, signature: tuple): ) ): return custom + return None def find_record_typestr( behavior: None | Mapping, parameters: None | Mapping[str, Any], default: str | None = None, -): +) -> str | None: if parameters is None: return default behavior = overlay_behavior(behavior) @@ -122,14 +140,14 @@ def find_array_typestr( behavior: None | Mapping, parameters: None | Mapping[str, Any], default: str | None = None, -): +) -> str | None: if parameters is None: return default behavior = overlay_behavior(behavior) return behavior.get(("__typestr__", parameters.get("__list__")), default) -def behavior_of_obj(obj, behavior: Mapping | None = None) -> Mapping | None: +def behavior_of_obj(obj: Any, behavior: Mapping | None = None) -> Mapping | None: from awkward.highlevel import Array, ArrayBuilder, Record if behavior is not None: @@ -140,7 +158,7 @@ def behavior_of_obj(obj, behavior: Mapping | None = None) -> Mapping | None: return None -def behavior_of(*arrays, behavior: Mapping | None = None) -> Mapping | None: +def behavior_of(*arrays: Any, behavior: Mapping | None = None) -> Mapping | None: if behavior is not None: # An explicit 'behavior' always wins. return behavior @@ -158,5 +176,6 @@ def behavior_of(*arrays, behavior: Mapping | None = None) -> Mapping | None: behavior.update(x_behavior) copied = True else: + assert isinstance(behavior, dict) behavior.update(x_behavior) return behavior diff --git a/src/awkward/_kernels.py b/src/awkward/_kernels.py index b0f418a052..e5b3b979f3 100644 --- a/src/awkward/_kernels.py +++ b/src/awkward/_kernels.py @@ -21,8 +21,8 @@ class KernelError(Protocol): - filename: str | None # pylint: disable=E0602 - str: str | None + filename: bytes | None # pylint: disable=E0602 + str: bytes | None attempt: int id: int diff --git a/src/awkward/_nplikes/__init__.py b/src/awkward/_nplikes/__init__.py index 83cada5622..eddedce0c6 100644 --- a/src/awkward/_nplikes/__init__.py +++ b/src/awkward/_nplikes/__init__.py @@ -33,6 +33,6 @@ def to_nplike( if isinstance(from_nplike, awkward._nplikes.cupy.Cupy) and not isinstance( nplike, awkward._nplikes.cupy.Cupy ): - array = array.get() + array = array.get() # type: ignore[attr-defined] return nplike.asarray(array) diff --git a/src/awkward/_nplikes/array_module.py b/src/awkward/_nplikes/array_module.py index 88d5d93de7..25c3df20b8 100644 --- a/src/awkward/_nplikes/array_module.py +++ b/src/awkward/_nplikes/array_module.py @@ -17,7 +17,10 @@ ) from awkward._nplikes.placeholder import PlaceholderArray from awkward._nplikes.shape import ShapeItem, unknown_length -from awkward._typing import Any, cast, Final, Literal, DType +from awkward._typing import TYPE_CHECKING, Any, Final, Literal, cast + +if TYPE_CHECKING: + from numpy.typing import DTypeLike np = NumpyMetadata.instance() NUMPY_HAS_NEP_50 = packaging.version.Version( @@ -49,7 +52,7 @@ def asarray( self, obj, *, - dtype: DType | None = None, + dtype: DTypeLike | None = None, copy: bool | None = None, ) -> ArrayLike: if isinstance(obj, PlaceholderArray): @@ -74,7 +77,7 @@ def ascontiguousarray(self, x: ArrayLike) -> ArrayLike: return self._module.ascontiguousarray(x) def frombuffer( - self, buffer, *, dtype: DType | None = None, count: ShapeItem = -1 + self, buffer, *, dtype: DTypeLike | None = None, count: ShapeItem = -1 ) -> ArrayLike: if isinstance(buffer, PlaceholderArray): raise TypeError("placeholder arrays are not supported in `frombuffer`") @@ -84,17 +87,26 @@ def from_dlpack(self, x: Any) -> ArrayLike: return self._module.from_dlpack(x) def zeros( - self, shape: ShapeItem | tuple[ShapeItem, ...], *, dtype: DType | None = None + self, + shape: ShapeItem | tuple[ShapeItem, ...], + *, + dtype: DTypeLike | None = None, ) -> ArrayLike: return self._module.zeros(shape, dtype=dtype) def ones( - self, shape: ShapeItem | tuple[ShapeItem, ...], *, dtype: DType | None = None + self, + shape: ShapeItem | tuple[ShapeItem, ...], + *, + dtype: DTypeLike | None = None, ) -> ArrayLike: return self._module.ones(shape, dtype=dtype) def empty( - self, shape: ShapeItem | tuple[ShapeItem, ...], *, dtype: DType | None = None + self, + shape: ShapeItem | tuple[ShapeItem, ...], + *, + dtype: DTypeLike | None = None, ) -> ArrayLike: return self._module.empty(shape, dtype=dtype) @@ -103,24 +115,24 @@ def full( shape: ShapeItem | tuple[ShapeItem, ...], fill_value, *, - dtype: DType | None = None, + dtype: DTypeLike | None = None, ) -> ArrayLike: return self._module.full(shape, fill_value, dtype=dtype) - def zeros_like(self, x: ArrayLike, *, dtype: DType | None = None) -> ArrayLike: + def zeros_like(self, x: ArrayLike, *, dtype: DTypeLike | None = None) -> ArrayLike: if isinstance(x, PlaceholderArray): return self.zeros(x.shape, dtype=dtype or x.dtype) else: return self._module.zeros_like(x, dtype=dtype) - def ones_like(self, x: ArrayLike, *, dtype: DType | None = None) -> ArrayLike: + def ones_like(self, x: ArrayLike, *, dtype: DTypeLike | None = None) -> ArrayLike: if isinstance(x, PlaceholderArray): return self.ones(x.shape, dtype=dtype or x.dtype) else: return self._module.ones_like(x, dtype=dtype) def full_like( - self, x: ArrayLike, fill_value, *, dtype: DType | None = None + self, x: ArrayLike, fill_value, *, dtype: DTypeLike | None = None ) -> ArrayLike: if isinstance(x, PlaceholderArray): return self.full(x.shape, fill_value, dtype=dtype or x.dtype) @@ -133,7 +145,7 @@ def arange( stop: float | int | None = None, step: float | int = 1, *, - dtype: DType | None = None, + dtype: DTypeLike | None = None, ) -> ArrayLike: assert not isinstance(start, PlaceholderArray) assert not isinstance(stop, PlaceholderArray) @@ -212,14 +224,13 @@ def apply_ufunc( ) -> ArrayLike | tuple[ArrayLike]: # Convert np.generic to scalar arrays resolved_args = [ - self.asarray(arg, dtype=arg.dtype) if hasattr(arg, "dtype") else arg + self.asarray(arg, dtype=arg.dtype if hasattr(arg, "dtype") else None) for arg in args ] broadcasted_args = self.broadcast_arrays(*resolved_args) # Choose the broadcasted argument if it wasn't a Python scalar non_generic_value_promoted_args = [ - y if hasattr(x, "ndim") else x - for x, y in zip(resolved_args, broadcasted_args) + y if hasattr(x, "ndim") else x for x, y in zip(args, broadcasted_args) ] # Allow other nplikes to replace implementation impl = self.prepare_ufunc(ufunc) @@ -436,7 +447,7 @@ def broadcast_to(self, x: ArrayLike, shape: tuple[ShapeItem, ...]) -> ArrayLike: def strides(self, x: ArrayLike) -> tuple[ShapeItem, ...]: if isinstance(x, PlaceholderArray): # Assume contiguous - strides: tuple[int, ...] = (x.itemsize,) + strides: tuple[ShapeItem, ...] = (x.dtype.itemsize,) for item in x.shape[-1:0:-1]: strides = (item * strides[0], *strides) return strides @@ -598,12 +609,12 @@ def array_str( ) def astype( - self, x: ArrayLike, dtype: numpy.dtype, *, copy: bool | None = True + self, x: ArrayLike, dtype: DTypeLike, *, copy: bool | None = True ) -> ArrayLike: assert not isinstance(x, PlaceholderArray) - return x.astype(dtype, copy=copy) # type: ignore + return x.astype(dtype, copy=copy) # type: ignore[attr-defined] - def can_cast(self, from_: DType | ArrayLike, to: DType | ArrayLike) -> bool: + def can_cast(self, from_: DTypeLike | ArrayLike, to: DTypeLike | ArrayLike) -> bool: return self._module.can_cast(from_, to, casting="same_kind") @classmethod diff --git a/src/awkward/_nplikes/cupy.py b/src/awkward/_nplikes/cupy.py index f242447407..d278945ec4 100644 --- a/src/awkward/_nplikes/cupy.py +++ b/src/awkward/_nplikes/cupy.py @@ -8,7 +8,11 @@ from awkward._nplikes.dispatch import register_nplike from awkward._nplikes.numpylike import ArrayLike from awkward._nplikes.placeholder import PlaceholderArray -from awkward._typing import Final +from awkward._nplikes.shape import ShapeItem +from awkward._typing import TYPE_CHECKING, Final + +if TYPE_CHECKING: + from numpy.typing import DTypeLike @register_nplike @@ -40,7 +44,7 @@ def ndarray(self): return self._module.ndarray def frombuffer( - self, buffer, *, dtype: numpy.dtype | None = None, count: int = -1 + self, buffer, *, dtype: DTypeLike | None = None, count: ShapeItem = -1 ) -> ArrayLike: assert not isinstance(buffer, PlaceholderArray) assert not isinstance(count, PlaceholderArray) @@ -81,7 +85,7 @@ def all( self, x: ArrayLike, *, - axis: int | tuple[int, ...] | None = None, + axis: ShapeItem | tuple[ShapeItem, ...] | None = None, keepdims: bool = False, maybe_out: ArrayLike | None = None, ) -> ArrayLike: @@ -96,7 +100,7 @@ def any( self, x: ArrayLike, *, - axis: int | tuple[int, ...] | None = None, + axis: ShapeItem | tuple[ShapeItem, ...] | None = None, keepdims: bool = False, maybe_out: ArrayLike | None = None, ) -> ArrayLike: @@ -108,7 +112,7 @@ def any( return out def count_nonzero( - self, x: ArrayLike, *, axis: int | tuple[int, ...] | None = None + self, x: ArrayLike, *, axis: ShapeItem | tuple[ShapeItem, ...] | None = None ) -> ArrayLike: assert not isinstance(x, PlaceholderArray) assert isinstance(axis, int) or axis is None @@ -122,7 +126,7 @@ def min( self, x: ArrayLike, *, - axis: int | tuple[int, ...] | None = None, + axis: ShapeItem | tuple[ShapeItem, ...] | None = None, keepdims: bool = False, maybe_out: ArrayLike | None = None, ) -> ArrayLike: @@ -137,7 +141,7 @@ def max( self, x: ArrayLike, *, - axis: int | tuple[int, ...] | None = None, + axis: ShapeItem | tuple[ShapeItem, ...] | None = None, keepdims: bool = False, maybe_out: ArrayLike | None = None, ) -> ArrayLike: @@ -164,4 +168,4 @@ def is_c_contiguous(self, x: ArrayLike) -> bool: if isinstance(x, PlaceholderArray): return True else: - return x.flags["C_CONTIGUOUS"] + return x.flags["C_CONTIGUOUS"] # type: ignore[attr-defined] diff --git a/src/awkward/_nplikes/dispatch.py b/src/awkward/_nplikes/dispatch.py index 3948dbf30b..6d6bd01c88 100644 --- a/src/awkward/_nplikes/dispatch.py +++ b/src/awkward/_nplikes/dispatch.py @@ -1,8 +1,8 @@ from __future__ import annotations from awkward._nplikes.numpylike import NumpyLike -from awkward._typing import TypeVar -from awkward._util import UNSET +from awkward._typing import Any, TypeVar, cast +from awkward._util import UNSET, Sentinel D = TypeVar("D") @@ -19,7 +19,7 @@ def register_nplike(cls: N) -> N: return cls -def nplike_of_obj(obj, *, default: D = UNSET) -> NumpyLike | D: +def nplike_of_obj(obj: Any, *, default: D | Sentinel = UNSET) -> NumpyLike | D: """ Args: *arrays: iterable of possible array objects @@ -44,6 +44,6 @@ def nplike_of_obj(obj, *, default: D = UNSET) -> NumpyLike | D: if default is UNSET: raise TypeError(f"cannot find nplike for {cls.__name__}") else: - return default + return cast(D, default) _type_to_nplike[cls] = nplike return nplike diff --git a/src/awkward/_nplikes/jax.py b/src/awkward/_nplikes/jax.py index d1aeeb9ced..03d1fa74bb 100644 --- a/src/awkward/_nplikes/jax.py +++ b/src/awkward/_nplikes/jax.py @@ -5,8 +5,7 @@ from awkward._nplikes.array_module import ArrayModuleNumpyLike from awkward._nplikes.dispatch import register_nplike from awkward._nplikes.numpylike import ArrayLike, UfuncLike -from awkward._nplikes.shape import ShapeItem -from awkward._typing import Final +from awkward._typing import Final, cast @register_nplike @@ -82,9 +81,8 @@ def is_c_contiguous(self, x: ArrayLike) -> bool: def ascontiguousarray(self, x: ArrayLike) -> ArrayLike: return x - def strides(self, x: ArrayLike) -> tuple[ShapeItem, ...]: - x.touch_shape() - out = (x._dtype.itemsize,) - for item in x._shape[-1:0:-1]: + def strides(self, x: ArrayLike) -> tuple[int, ...]: + out: tuple[int, ...] = (x.dtype.itemsize,) + for item in cast(tuple[int, ...], x.shape[-1:0:-1]): out = (item * out[0], *out) return out diff --git a/src/awkward/_nplikes/numpy.py b/src/awkward/_nplikes/numpy.py index 43497ab24d..e7ee6513b1 100644 --- a/src/awkward/_nplikes/numpy.py +++ b/src/awkward/_nplikes/numpy.py @@ -47,7 +47,7 @@ def is_c_contiguous(self, x: ArrayLike) -> bool: if isinstance(x, PlaceholderArray): return True else: - return x.flags["C_CONTIGUOUS"] + return x.flags["C_CONTIGUOUS"] # type: ignore[attr-defined] def packbits( self, @@ -57,7 +57,7 @@ def packbits( bitorder: Literal["big", "little"] = "big", ): assert not isinstance(x, PlaceholderArray) - return numpy.packbits(x, axis=axis, bitorder=bitorder) + return numpy.packbits(x, axis=axis, bitorder=bitorder) # type: ignore[arg-type] def unpackbits( self, @@ -68,4 +68,4 @@ def unpackbits( bitorder: Literal["big", "little"] = "big", ): assert not isinstance(x, PlaceholderArray) - return numpy.unpackbits(x, axis=axis, count=count, bitorder=bitorder) + return numpy.unpackbits(x, axis=axis, count=count, bitorder=bitorder) # type: ignore[arg-type] diff --git a/src/awkward/_nplikes/numpylike.py b/src/awkward/_nplikes/numpylike.py index 26ac7c63a7..6e95402829 100644 --- a/src/awkward/_nplikes/numpylike.py +++ b/src/awkward/_nplikes/numpylike.py @@ -5,10 +5,12 @@ import numpy -from awkward._nplikes.shape import ShapeItem, unknown_length +from awkward._nplikes.shape import ShapeItem from awkward._singleton import PublicSingleton from awkward._typing import ( + TYPE_CHECKING, Any, + DType, Literal, NamedTuple, Protocol, @@ -18,6 +20,12 @@ overload, ) +if TYPE_CHECKING: + from types import EllipsisType + + from numpy.typing import DTypeLike + + IndexType: TypeAlias = "int | ArrayLike" @@ -31,7 +39,7 @@ class UniqueAllResult(NamedTuple): class ArrayLike(Protocol): @property @abstractmethod - def dtype(self) -> dtype: + def dtype(self) -> DType: ... @property @@ -64,8 +72,8 @@ def __getitem__( self, key: SupportsIndex | slice - | Ellipsis - | tuple[SupportsIndex | slice | Ellipsis | ArrayLike, ...] + | EllipsisType + | tuple[SupportsIndex | slice | EllipsisType | ArrayLike, ...] | ArrayLike, ) -> Self: ... @@ -73,23 +81,23 @@ def __getitem__( @overload def __setitem__( self, - key: SupportsIndex - | slice - | Ellipsis - | tuple[SupportsIndex | slice | Ellipsis | ArrayLike, ...] + key: slice + | EllipsisType + | tuple[SupportsIndex | slice | EllipsisType, ...] | ArrayLike, - value: int | float | bool | complex | ArrayLike, + value: int | float | bool | complex, ): ... @overload def __setitem__( self, - key: slice - | Ellipsis - | tuple[SupportsIndex | slice | Ellipsis, ...] + key: SupportsIndex + | slice + | EllipsisType + | tuple[SupportsIndex | slice | EllipsisType | ArrayLike, ...] | ArrayLike, - value: int | float | bool | complex, + value: int | float | bool | complex | ArrayLike, ): ... @@ -114,7 +122,7 @@ def __len__(self) -> int: ... @abstractmethod - def view(self, dtype: dtype) -> Self: + def view(self, dtype: DType) -> Self: ... # Scalar UFUNCS @@ -155,7 +163,7 @@ def __le__(self, other: int | complex | float | Self) -> Self: ... @abstractmethod - def __eq__(self, other: int | complex | float | bool | Self) -> Self: + def __eq__(self, other: int | complex | float | bool | Self) -> Self: # type: ignore[override] ... @abstractmethod @@ -170,9 +178,11 @@ def __or__(self, other: int | bool | Self) -> Self: def __invert__(self) -> Self: ... + @abstractmethod def __dlpack_device__(self) -> tuple[int, int]: ... + @abstractmethod def __dlpack__(self, stream: Any = None) -> Any: ... @@ -195,6 +205,9 @@ class NumpyMetadata(PublicSingleton): str_ = numpy.str_ bytes_ = numpy.bytes_ + datetime64 = numpy.datetime64 + timedelta64 = numpy.timedelta64 + intp = numpy.intp integer = numpy.integer signedinteger = numpy.signedinteger @@ -208,8 +221,9 @@ class NumpyMetadata(PublicSingleton): dtype = numpy.dtype ufunc = numpy.ufunc iinfo = numpy.iinfo + finfo = numpy.finfo errstate = numpy.errstate - newaxis = numpy.newaxis + newaxis: None = numpy.newaxis ndarray = numpy.ndarray @@ -217,29 +231,20 @@ class NumpyMetadata(PublicSingleton): inf = numpy.inf nat = numpy.datetime64("NaT") - datetime_data = numpy.datetime_data - - @property - def issubdtype(self): - return numpy.issubdtype + datetime_data = staticmethod(numpy.datetime_data) + issubdtype = staticmethod(numpy.issubdtype) AxisError = numpy.AxisError if hasattr(numpy, "float16"): - NumpyMetadata.float16 = numpy.float16 + NumpyMetadata.float16 = numpy.float16 # type: ignore[attr-defined] if hasattr(numpy, "float128"): - NumpyMetadata.float128 = numpy.float128 + NumpyMetadata.float128 = numpy.float128 # type: ignore[attr-defined] if hasattr(numpy, "complex256"): - NumpyMetadata.complex256 = numpy.complex256 - -if hasattr(numpy, "datetime64"): - NumpyMetadata.datetime64 = numpy.datetime64 - -if hasattr(numpy, "timedelta64"): - NumpyMetadata.timedelta64 = numpy.timedelta64 + NumpyMetadata.complex256 = numpy.complex256 # type: ignore[attr-defined] class UfuncLike(Protocol): @@ -266,7 +271,7 @@ def apply_ufunc( method: str, args: list[Any], kwargs: dict[str, Any] | None = None, - ) -> ArrayLike | tuple[ArrayLike]: + ) -> ArrayLike | tuple[ArrayLike, ...]: ... @property @@ -291,7 +296,7 @@ def asarray( self, obj, *, - dtype: numpy.dtype | None = None, + dtype: DTypeLike | None = None, copy: bool | None = None, ) -> ArrayLike: ... @@ -302,7 +307,7 @@ def ascontiguousarray(self, x: ArrayLike) -> ArrayLike: @abstractmethod def frombuffer( - self, buffer, *, dtype: numpy.dtype | None = None, count: int = -1 + self, buffer, *, dtype: DTypeLike | None = None, count: ShapeItem = -1 ) -> ArrayLike: ... @@ -315,7 +320,7 @@ def zeros( self, shape: ShapeItem | tuple[ShapeItem, ...], *, - dtype: numpy.dtype | None = None, + dtype: DTypeLike | None = None, ) -> ArrayLike: ... @@ -324,7 +329,7 @@ def ones( self, shape: ShapeItem | tuple[ShapeItem, ...], *, - dtype: numpy.dtype | None = None, + dtype: DTypeLike | None = None, ) -> ArrayLike: ... @@ -333,7 +338,7 @@ def empty( self, shape: ShapeItem | tuple[ShapeItem, ...], *, - dtype: numpy.dtype | None = None, + dtype: DTypeLike | None = None, ) -> ArrayLike: ... @@ -343,23 +348,21 @@ def full( shape: ShapeItem | tuple[ShapeItem, ...], fill_value, *, - dtype: numpy.dtype | None = None, + dtype: DTypeLike | None = None, ) -> ArrayLike: ... @abstractmethod - def zeros_like( - self, x: ArrayLike, *, dtype: numpy.dtype | None = None - ) -> ArrayLike: + def zeros_like(self, x: ArrayLike, *, dtype: DTypeLike | None = None) -> ArrayLike: ... @abstractmethod - def ones_like(self, x: ArrayLike, *, dtype: numpy.dtype | None = None) -> ArrayLike: + def ones_like(self, x: ArrayLike, *, dtype: DTypeLike | None = None) -> ArrayLike: ... @abstractmethod def full_like( - self, x: ArrayLike, fill_value, *, dtype: numpy.dtype | None = None + self, x: ArrayLike, fill_value, *, dtype: DTypeLike | None = None ) -> ArrayLike: ... @@ -370,7 +373,7 @@ def arange( stop: float | int | None = None, step: float | int = 1, *, - dtype: numpy.dtype | None = None, + dtype: DTypeLike | None = None, ) -> ArrayLike: ... @@ -385,7 +388,7 @@ def meshgrid( @abstractmethod def array_equal( self, x1: ArrayLike, x2: ArrayLike, *, equal_nan: bool = False - ) -> ArrayLike: + ) -> bool: ... @abstractmethod @@ -409,16 +412,8 @@ def broadcast_arrays(self, *arrays: ArrayLike) -> list[ArrayLike]: def broadcast_to(self, x: ArrayLike, shape: tuple[ShapeItem, ...]) -> ArrayLike: ... - @overload - def shape_item_as_index(self, x1: int) -> int: - ... - - @overload - def shape_item_as_index(self, x1: type[unknown_length]) -> ArrayLike: - ... - @abstractmethod - def shape_item_as_index(self, x1): + def shape_item_as_index(self, x1: ShapeItem) -> int | ArrayLike: ... @abstractmethod @@ -666,14 +661,12 @@ def array_str( @abstractmethod def astype( - self, x: ArrayLike, dtype: numpy.dtype, *, copy: bool | None = True + self, x: ArrayLike, dtype: DTypeLike, *, copy: bool | None = True ) -> ArrayLike: ... @abstractmethod - def can_cast( - self, from_: numpy.dtype | ArrayLike, to: numpy.dtype | ArrayLike - ) -> bool: + def can_cast(self, from_: DType | ArrayLike, to: DType | ArrayLike) -> bool: ... @abstractmethod diff --git a/src/awkward/_nplikes/placeholder.py b/src/awkward/_nplikes/placeholder.py index c401121d03..ea6ac07156 100644 --- a/src/awkward/_nplikes/placeholder.py +++ b/src/awkward/_nplikes/placeholder.py @@ -6,25 +6,23 @@ from awkward._nplikes.numpylike import ArrayLike, NumpyLike, NumpyMetadata from awkward._nplikes.shape import ShapeItem, unknown_length -from awkward._typing import Self +from awkward._typing import Any, DType, Self np = NumpyMetadata.instance() class PlaceholderArray(ArrayLike): - def __init__( - self, nplike: NumpyLike, shape: tuple[ShapeItem, ...], dtype: np.dtype - ): + def __init__(self, nplike: NumpyLike, shape: tuple[ShapeItem, ...], dtype: DType): self._nplike = nplike self._shape = shape self._dtype = np.dtype(dtype) @property - def dtype(self) -> np.dtype: + def dtype(self) -> DType: return self._dtype @property - def shape(self) -> tuple[int, ...]: + def shape(self) -> tuple[ShapeItem, ...]: return self._shape @property @@ -32,7 +30,7 @@ def ndim(self) -> int: return len(self._shape) @property - def size(self) -> int: + def size(self) -> ShapeItem: return reduce(mul, self._shape) @property @@ -40,8 +38,8 @@ def nbytes(self) -> int: return 0 @property - def strides(self) -> tuple[int, ...]: - out = (self._dtype.itemsize,) + def strides(self) -> tuple[ShapeItem, ...]: + out: tuple[ShapeItem, ...] = (self._dtype.itemsize,) for item in reversed(self._shape): out = (item * out[0], *out) return out @@ -50,7 +48,7 @@ def strides(self) -> tuple[int, ...]: def T(self): return type(self)(self._nplike, self._shape[::-1], self._dtype) - def view(self, dtype: dtype) -> Self: + def view(self, dtype: DType) -> Self: dtype = np.dtype(dtype) if len(self._shape) >= 1: last, remainder = divmod( @@ -108,7 +106,7 @@ def __index__(self) -> int: raise RuntimeError def __len__(self) -> int: - return self._shape[0] + return int(self._shape[0]) def __add__(self, other): raise RuntimeError @@ -149,4 +147,10 @@ def __sub__(self, other): def __truediv__(self, other): raise RuntimeError - __iter__ = None + __iter__: None = None + + def __dlpack_device__(self) -> tuple[int, int]: + raise RuntimeError + + def __dlpack__(self, stream: Any = None) -> Any: + raise RuntimeError diff --git a/src/awkward/_nplikes/shape.py b/src/awkward/_nplikes/shape.py index c7b50c41fe..82493b1b56 100644 --- a/src/awkward/_nplikes/shape.py +++ b/src/awkward/_nplikes/shape.py @@ -2,15 +2,20 @@ from __future__ import annotations from awkward._singleton import PrivateSingleton -from awkward._typing import Self, TypeAlias +from awkward._typing import TYPE_CHECKING, Self, TypeAlias -ShapeItem: TypeAlias = "int | type[unknown_length]" +__all__ = ("ShapeItem", "UnknownLength", "unknown_length") +ShapeItem: TypeAlias = "int | UnknownLength" -class _UnknownLength(PrivateSingleton): +if TYPE_CHECKING: + from types import NotImplementedType + + +class UnknownLength(PrivateSingleton): _instance_name: str - def __add__(self, other) -> Self | NotImplemented: + def __add__(self, other) -> Self | NotImplementedType: if isinstance(other, int) or other is self: return self else: @@ -19,7 +24,7 @@ def __add__(self, other) -> Self | NotImplemented: __radd__ = __add__ __iadd__ = __add__ - def __sub__(self, other) -> Self | NotImplemented: + def __sub__(self, other) -> Self | NotImplementedType: if isinstance(other, int) or other is self: return self else: @@ -28,7 +33,7 @@ def __sub__(self, other) -> Self | NotImplemented: __rsub__ = __sub__ __isub__ = __sub__ - def __mul__(self, other) -> Self | NotImplemented: + def __mul__(self, other) -> Self | NotImplementedType: if isinstance(other, int) or other is self: return self else: @@ -37,7 +42,7 @@ def __mul__(self, other) -> Self | NotImplemented: __rmul__ = __mul__ __imul__ = __mul__ - def __floordiv__(self, other) -> Self | NotImplemented: + def __floordiv__(self, other) -> Self | NotImplementedType: if isinstance(other, int) or other is self: return self else: @@ -81,7 +86,7 @@ def __int__(self): # Inform the singleton if its module name -_UnknownLength._instance_name = f"{__name__}.unknown_length" +UnknownLength._instance_name = f"{__name__}.unknown_length" # Ensure we have a single instance -unknown_length = _UnknownLength._new() +unknown_length = UnknownLength._new() diff --git a/src/awkward/_nplikes/typetracer.py b/src/awkward/_nplikes/typetracer.py index 7693cd173c..7bba4d83af 100644 --- a/src/awkward/_nplikes/typetracer.py +++ b/src/awkward/_nplikes/typetracer.py @@ -1,6 +1,7 @@ # BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE from __future__ import annotations +from collections.abc import Sequence from numbers import Number from typing import Callable @@ -22,14 +23,27 @@ from awkward._operators import NDArrayOperatorsMixin from awkward._regularize import is_integer, is_non_string_like_sequence from awkward._typing import ( + TYPE_CHECKING, Any, + DType, Final, Literal, Self, SupportsIndex, + TypeGuard, TypeVar, + cast, ) +if TYPE_CHECKING: + from types import EllipsisType + + from numpy.typing import DTypeLike + + from awkward.contents.content import Content + from awkward.forms.form import Form + + np = NumpyMetadata.instance() NUMPY_HAS_NEP_50 = packaging.version.Version( numpy.__version__ @@ -40,15 +54,15 @@ def is_unknown_length(array: Any) -> bool: return array is unknown_length -def is_unknown_scalar(array: Any) -> bool: +def is_unknown_scalar(array: Any) -> TypeGuard[TypeTracerArray]: return isinstance(array, TypeTracerArray) and array.ndim == 0 -def is_unknown_integer(array: Any) -> bool: +def is_unknown_integer(array: Any) -> TypeGuard[TypeTracerArray]: return is_unknown_scalar(array) and np.issubdtype(array.dtype, np.integer) -def is_unknown_array(array: Any) -> bool: +def is_unknown_array(array: Any) -> TypeGuard[TypeTracerArray]: return isinstance(array, TypeTracerArray) and array.ndim > 0 @@ -165,7 +179,7 @@ def __reduce__(self): @classmethod def _new( cls, - dtype: np.dtype, + dtype: DType, shape: tuple[ShapeItem, ...], form_key: str | None = None, report: TypeTracerReport | None = None, @@ -178,8 +192,10 @@ def _new( raise TypeError("typetracer shape must be a tuple") if any(is_unknown_scalar(x) for x in shape): raise TypeError("typetracer shape must be integers or unknown-length") + if not isinstance(dtype, np.dtype): + raise TypeError("typetracer dtype must be an instance of np.dtype") self._shape = shape - self._dtype = np.dtype(dtype) + self._dtype = dtype return self @@ -205,12 +221,12 @@ def T(self) -> Self: ) @property - def dtype(self) -> np.dtype: + def dtype(self) -> DType: return self._dtype @property def size(self) -> ShapeItem: - size = 1 + size: ShapeItem = 1 for item in self._shape: size *= item return size @@ -260,7 +276,7 @@ def ndim(self) -> int: def nbytes(self) -> ShapeItem: return self.size * self._dtype.itemsize - def view(self, dtype: np.dtype) -> Self: + def view(self, dtype: DTypeLike) -> Self: dtype = np.dtype(dtype) if len(self._shape) >= 1: last, remainder = divmod( @@ -312,10 +328,10 @@ def __getitem__( self, key: SupportsIndex | slice - | Ellipsis - | tuple[SupportsIndex | slice | Ellipsis | ArrayLike, ...] + | EllipsisType + | tuple[SupportsIndex | slice | EllipsisType | ArrayLike, ...] | ArrayLike, - ) -> Self | int | float | bool | complex: + ) -> Self: if not isinstance(key, tuple): key = (key,) @@ -356,7 +372,7 @@ def __getitem__( ) # 2. Normalise Ellipsis and boolean arrays - key_parts = [] + key_parts: list[SupportsIndex | slice | ArrayLike] = [] for item in key: if item is Ellipsis: # How many more dimensions do we have than the index provides @@ -371,9 +387,9 @@ def __getitem__( # 3. Apply Indexing advanced_is_at_front = False previous_item_is_basic = True - advanced_shapes = [] - adjacent_advanced_shape = [] - result_shape_parts = [] + advanced_shapes: list[tuple[ShapeItem, ...]] = [] + adjacent_advanced_shape: list[ShapeItem] = [] + result_shape_parts: list[Sequence[ShapeItem]] = [] iter_shape = iter(self.shape) for item in key: # New axes don't reference existing dimensions @@ -447,8 +463,8 @@ def __setitem__( self, key: SupportsIndex | slice - | Ellipsis - | tuple[SupportsIndex | slice | Ellipsis | ArrayLike, ...] + | EllipsisType + | tuple[SupportsIndex | slice | EllipsisType | ArrayLike, ...] | ArrayLike, value: int | float | bool | complex | ArrayLike, ): @@ -484,20 +500,26 @@ def __int__(self) -> int: def __index__(self) -> int: raise RuntimeError("cannot realise an unknown value") + def __dlpack_device__(self) -> tuple[int, int]: + raise RuntimeError("cannot realise an unknown value") + + def __dlpack__(self, stream: Any = None) -> Any: + raise RuntimeError("cannot realise an unknown value") + -def _scalar_type_of(obj) -> numpy.dtype: +def _scalar_type_of(obj) -> DType: if is_unknown_scalar(obj): return obj.dtype else: return numpy.obj2sctype(obj) -def try_touch_data(array): +def try_touch_data(array: Any): if isinstance(array, TypeTracerArray): array.touch_data() -def try_touch_shape(array): +def try_touch_shape(array: Any): if isinstance(array, TypeTracerArray): array.touch_shape() @@ -514,9 +536,9 @@ def apply_ufunc( self, ufunc: UfuncLike, method: str, - args: list[Any], + args: Sequence[Any], kwargs: dict[str, Any] | None = None, - ) -> TypeTracerArray | tuple[TypeTracerArray]: + ) -> TypeTracerArray | tuple[TypeTracerArray, ...]: for x in args: try_touch_data(x) @@ -540,7 +562,7 @@ def apply_ufunc( if len(result_dtypes) == 1: return TypeTracerArray._new(result_dtypes[0], shape=broadcasted_shape) else: - return ( + return tuple( TypeTracerArray._new(dtype, shape=broadcasted_shape) for dtype in result_dtypes ) @@ -551,9 +573,9 @@ def apply_ufunc( self, ufunc: UfuncLike, method: str, - args: list[Any], + args: Sequence[Any], kwargs: dict[str, Any] | None = None, - ) -> TypeTracerArray | tuple[TypeTracerArray]: + ) -> TypeTracerArray | tuple[TypeTracerArray, ...]: for x in args: try_touch_data(x) @@ -561,7 +583,7 @@ def apply_ufunc( args = [x.content if isinstance(x, MaybeNone) else x for x in args] # Convert np.generic to scalar arrays resolved_args = [ - self.asarray(arg, dtype=arg.dtype) if hasattr(arg, "dtype") else arg + self.asarray(arg, dtype=arg.dtype if hasattr(arg, "dtype") else None) for arg in args ] # Broadcast all inputs together @@ -569,8 +591,7 @@ def apply_ufunc( broadcasted_shape = broadcasted_args[0].shape # Choose the broadcasted argument if it wasn't a Python scalar non_generic_value_promoted_args = [ - y if hasattr(x, "ndim") else x - for x, y in zip(resolved_args, broadcasted_args) + y if hasattr(x, "ndim") else x for x, y in zip(args, broadcasted_args) ] # Build proxy (empty) arrays proxy_args = [ @@ -582,12 +603,13 @@ def apply_ufunc( if ufunc.nout == 1: result_dtypes = [proxy_result.dtype] else: + assert isinstance(proxy_result, tuple) result_dtypes = [x.dtype for x in proxy_result] if len(result_dtypes) == 1: return TypeTracerArray._new(result_dtypes[0], shape=broadcasted_shape) else: - return ( + return tuple( TypeTracerArray._new(dtype, shape=broadcasted_shape) for dtype in result_dtypes ) @@ -615,11 +637,14 @@ def asarray( self, obj, *, - dtype: numpy.dtype | None = None, + dtype: DTypeLike | None = None, copy: bool | None = None, ) -> TypeTracerArray: assert not isinstance(obj, PlaceholderArray) + if dtype is not None: + dtype = np.dtype(dtype) + if isinstance(obj, ak.index.Index): obj = obj.data @@ -663,7 +688,7 @@ def asarray( return TypeTracerArray._new(as_array.dtype, ()) elif is_non_string_like_sequence(obj): - shape = [] + shape: list[ShapeItem] = [] flat_items = [] has_seen_leaf = False @@ -709,17 +734,17 @@ def populate_shape_and_items(node, dim): raise TypeError def ascontiguousarray(self, x: ArrayLike) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) return TypeTracerArray._new( x.dtype, shape=x.shape, form_key=x.form_key, report=x.report ) def frombuffer( - self, buffer, *, dtype: np.dtype | None = None, count: int = -1 + self, buffer, *, dtype: DTypeLike | None = None, count: ShapeItem = -1 ) -> TypeTracerArray: - for x in (buffer, count): - assert not isinstance(x, PlaceholderArray) - try_touch_data(x) + assert not isinstance(buffer, PlaceholderArray) + try_touch_data(buffer) + try_touch_data(count) if isinstance(buffer, TypeTracerArray) or is_unknown_scalar(count): raise NotImplementedError @@ -727,29 +752,50 @@ def frombuffer( return self.asarray(numpy.frombuffer(buffer, dtype=dtype, count=count)) def from_dlpack(self, x: Any) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) raise NotImplementedError def zeros( - self, shape: ShapeItem | tuple[ShapeItem, ...], *, dtype: np.dtype | None = None + self, + shape: ShapeItem | tuple[ShapeItem, ...], + *, + dtype: DTypeLike | None = None, ) -> TypeTracerArray: if not isinstance(shape, tuple): shape = (shape,) + if dtype is None: + dtype = np.dtype(np.finfo(float).dtype) + else: + dtype = np.dtype(dtype) return TypeTracerArray._new(dtype, shape) def ones( - self, shape: ShapeItem | tuple[ShapeItem, ...], *, dtype: np.dtype | None = None + self, + shape: ShapeItem | tuple[ShapeItem, ...], + *, + dtype: DTypeLike | None = None, ) -> TypeTracerArray: if not isinstance(shape, tuple): shape = (shape,) + if dtype is None: + dtype = np.dtype(np.finfo(float).dtype) + else: + dtype = np.dtype(dtype) return TypeTracerArray._new(dtype, shape) def empty( - self, shape: ShapeItem | tuple[ShapeItem, ...], *, dtype: np.dtype | None = None + self, + shape: ShapeItem | tuple[ShapeItem, ...], + *, + dtype: DTypeLike | None = None, ) -> TypeTracerArray: if not isinstance(shape, tuple): shape = (shape,) + if dtype is None: + dtype = np.dtype(np.finfo(float).dtype) + else: + dtype = np.dtype(dtype) return TypeTracerArray._new(dtype, shape) def full( @@ -757,35 +803,39 @@ def full( shape: ShapeItem | tuple[ShapeItem, ...], fill_value, *, - dtype: np.dtype | None = None, + dtype: DTypeLike | None = None, ) -> TypeTracerArray: assert not isinstance(fill_value, PlaceholderArray) if not isinstance(shape, tuple): shape = (shape,) - dtype = _scalar_type_of(fill_value) if dtype is None else dtype + dtype = _scalar_type_of(fill_value) if dtype is None else np.dtype(dtype) return TypeTracerArray._new(dtype, shape) def zeros_like( - self, x: ArrayLike, *, dtype: np.dtype | None = None + self, x: ArrayLike, *, dtype: DTypeLike | None = None ) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_shape(x) + if dtype is None: + dtype = x.dtype + else: + dtype = np.dtype(dtype) if is_unknown_scalar(x): - return TypeTracerArray._new(dtype or x.dtype, shape=()) + return TypeTracerArray._new(dtype, shape=()) else: - return TypeTracerArray._new(dtype or x.dtype, shape=x.shape) + return TypeTracerArray._new(dtype, shape=x.shape) def ones_like( - self, x: ArrayLike, *, dtype: np.dtype | None = None + self, x: ArrayLike, *, dtype: DTypeLike | None = None ) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_shape(x) return self.zeros_like(x, dtype=dtype) def full_like( - self, x: ArrayLike, fill_value, *, dtype: np.dtype | None = None + self, x: ArrayLike, fill_value, *, dtype: DTypeLike | None = None ) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_shape(x) return self.zeros_like(x, dtype=dtype) @@ -795,7 +845,7 @@ def arange( stop: float | int | None = None, step: float | int = 1, *, - dtype: np.dtype | None = None, + dtype: DTypeLike | None = None, ) -> TypeTracerArray: assert not isinstance(start, PlaceholderArray) assert not isinstance(stop, PlaceholderArray) @@ -806,40 +856,40 @@ def arange( if stop is None: start, stop = 0, start + length: ShapeItem if is_integer(start) and is_integer(stop) and is_integer(step): - length = max(0, (stop - start + (step - (1 if step > 0 else -1))) // step) + length = max(0, (stop - start + (step - (1 if step > 0 else -1))) // step) # type: ignore[assignment] else: length = unknown_length - default_int_type = np.int32 if (ak._util.win or ak._util.bits32) else np.int64 - return TypeTracerArray._new(dtype or default_int_type, (length,)) + if dtype is None: + dtype = np.dtype(np.iinfo(int).dtype) + else: + dtype = np.dtype(dtype) + return TypeTracerArray._new(dtype, (length,)) def meshgrid( self, *arrays: ArrayLike, indexing: Literal["xy", "ij"] = "xy" - ) -> list[TypeTracerArray]: + ) -> list[ArrayLike]: for x in arrays: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) assert x.ndim == 1 - shape = tuple(x.size for x in arrays) + shape: list[ShapeItem] = [x.size for x in arrays] if indexing == "xy": shape[:2] = shape[1], shape[0] dtype = numpy.result_type(*arrays) - return [TypeTracerArray._new(dtype, shape=shape) for _ in arrays] + return [TypeTracerArray._new(dtype, shape=tuple(shape)) for _ in arrays] ############################ testing def array_equal( self, x1: ArrayLike, x2: ArrayLike, *, equal_nan: bool = False - ) -> TypeTracerArray: - assert not isinstance(x1, PlaceholderArray) - assert not isinstance(x2, PlaceholderArray) - try_touch_data(x1) - try_touch_data(x2) - return TypeTracerArray._new(np.bool_, shape=()) + ) -> bool: + raise RuntimeError def searchsorted( self, @@ -849,7 +899,7 @@ def searchsorted( side: Literal["left", "right"] = "left", sorter: ArrayLike | None = None, ) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) assert not isinstance(values, PlaceholderArray) assert not isinstance(sorter, PlaceholderArray) try_touch_data(x) @@ -870,7 +920,7 @@ def searchsorted( ############################ manipulation def shape_item_as_index(self, x1: ShapeItem) -> IndexType: if x1 is unknown_length: - return TypeTracerArray._new(np.int64, shape=()) + return TypeTracerArray._new(np.dtype(np.int64), shape=()) elif isinstance(x1, int): return x1 else: @@ -901,6 +951,7 @@ def regularize_index_for_length( if length is unknown_length: return length_scalar + length = cast(int, length) # We have known length and index if index < 0: index = index + length @@ -1016,9 +1067,9 @@ def broadcast_shapes(self, *shapes: tuple[ShapeItem, ...]) -> tuple[ShapeItem, . ) return tuple(result) - def broadcast_arrays(self, *arrays: ArrayLike) -> list[TypeTracerArray]: + def broadcast_arrays(self, *arrays: ArrayLike) -> list[ArrayLike]: for x in arrays: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) if len(arrays) == 0: @@ -1038,7 +1089,7 @@ def broadcast_arrays(self, *arrays: ArrayLike) -> list[TypeTracerArray]: def broadcast_to( self, x: ArrayLike, shape: tuple[ShapeItem, ...] ) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) new_shape = self.broadcast_shapes(x.shape, shape) # broadcast_to is asymmetric, whilst broadcast_shapes is not @@ -1060,14 +1111,14 @@ def broadcast_to( def reshape( self, x: ArrayLike, shape: tuple[ShapeItem, ...], *, copy: bool | None = None ) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) x.touch_shape() size = x.size # Validate new shape to ensure that it only contains at-most one placeholder n_placeholders = 0 - new_size = 1 + new_size: ShapeItem = 1 for item in shape: if item is unknown_length: # Size is no longer defined @@ -1103,7 +1154,7 @@ def cumsum( axis: int | None = None, maybe_out: ArrayLike | None = None, ) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) if axis is None: return TypeTracerArray._new(x.dtype, (x.size,)) @@ -1112,10 +1163,12 @@ def cumsum( return TypeTracerArray._new(x.dtype, x.shape) def nonzero(self, x: ArrayLike) -> tuple[TypeTracerArray, ...]: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) # array try_touch_data(x) - return (TypeTracerArray._new(np.int64, (unknown_length,)),) * len(x.shape) + return (TypeTracerArray._new(np.dtype(np.int64), (unknown_length,)),) * len( + x.shape + ) def where( self, condition: ArrayLike, x1: ArrayLike, x2: ArrayLike @@ -1128,18 +1181,18 @@ def where( return TypeTracerArray._new(result_dtype, shape=condition.shape) def unique_values(self, x: ArrayLike) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) return TypeTracerArray._new(x.dtype, shape=(unknown_length,)) def unique_all(self, x: ArrayLike) -> UniqueAllResult: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) return UniqueAllResult( TypeTracerArray._new(x.dtype, shape=(unknown_length,)), - TypeTracerArray._new(np.int64, shape=(unknown_length,)), - TypeTracerArray._new(np.int64, shape=x.shape), - TypeTracerArray._new(np.int64, shape=(unknown_length,)), + TypeTracerArray._new(np.dtype(np.int64), shape=(unknown_length,)), + TypeTracerArray._new(np.dtype(np.int64), shape=x.shape), + TypeTracerArray._new(np.dtype(np.int64), shape=(unknown_length,)), ) def sort( @@ -1150,7 +1203,7 @@ def sort( descending: bool = False, stable: bool = True, ) -> ArrayLike: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) return TypeTracerArray._new(x.dtype, shape=x.shape) @@ -1165,7 +1218,7 @@ def concat(self, arrays, *, axis: int | None = 0) -> TypeTracerArray: inner_shape = None emptyarrays = [] for x in arrays: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) if inner_shape is None: inner_shape = x.shape[1:] elif inner_shape != x.shape[1:]: @@ -1190,7 +1243,7 @@ def repeat( *, axis: int | None = None, ) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) assert not isinstance(repeats, PlaceholderArray) try_touch_data(x) try_touch_data(repeats) @@ -1217,7 +1270,7 @@ def stack( axis: int = 0, ) -> TypeTracerArray: for x in arrays: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) raise NotImplementedError @@ -1228,7 +1281,7 @@ def packbits( axis: int | None = None, bitorder: Literal["big", "little"] = "big", ) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) raise NotImplementedError @@ -1240,14 +1293,14 @@ def unpackbits( count: int | None = None, bitorder: Literal["big", "little"] = "big", ) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) raise NotImplementedError def strides(self, x: ArrayLike) -> tuple[ShapeItem, ...]: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) x.touch_shape() - out = (x._dtype.itemsize,) + out: tuple[ShapeItem, ...] = (x._dtype.itemsize,) for item in reversed(x._shape): out = (item * out[0], *out) return out @@ -1261,7 +1314,7 @@ def add( maybe_out: ArrayLike | None = None, ) -> TypeTracerArray: assert not isinstance(x1, PlaceholderArray) - return self.apply_ufunc(numpy.add, "__call__", (x1, x2)) + return self.apply_ufunc(numpy.add, "__call__", (x1, x2)) # type: ignore[arg-type,return-value] def logical_and( self, @@ -1270,7 +1323,7 @@ def logical_and( maybe_out: ArrayLike | None = None, ) -> TypeTracerArray: assert not isinstance(x1, PlaceholderArray) - return self.apply_ufunc(numpy.logical_and, "__call__", (x1, x2)) + return self.apply_ufunc(numpy.logical_and, "__call__", (x1, x2)) # type: ignore[arg-type,return-value] def logical_or( self, @@ -1280,21 +1333,21 @@ def logical_or( ) -> TypeTracerArray: assert not isinstance(x1, PlaceholderArray) assert not isinstance(x2, PlaceholderArray) - return self.apply_ufunc(numpy.logical_or, "__call__", (x1, x2)) + return self.apply_ufunc(numpy.logical_or, "__call__", (x1, x2)) # type: ignore[arg-type,return-value] def logical_not( self, x: ArrayLike, maybe_out: ArrayLike | None = None ) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) - return self.apply_ufunc(numpy.logical_not, "__call__", (x,)) + assert isinstance(x, TypeTracerArray) + return self.apply_ufunc(numpy.logical_not, "__call__", (x,)) # type: ignore[arg-type,return-value] def sqrt(self, x: ArrayLike, maybe_out: ArrayLike | None = None) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) - return self.apply_ufunc(numpy.sqrt, "__call__", (x,)) + assert isinstance(x, TypeTracerArray) + return self.apply_ufunc(numpy.sqrt, "__call__", (x,)) # type: ignore[arg-type,return-value] def exp(self, x: ArrayLike, maybe_out: ArrayLike | None = None) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) - return self.apply_ufunc(numpy.exp, "__call__", (x,)) + assert isinstance(x, TypeTracerArray) + return self.apply_ufunc(numpy.exp, "__call__", (x,)) # type: ignore[arg-type,return-value] def divide( self, @@ -1304,7 +1357,7 @@ def divide( ) -> TypeTracerArray: assert not isinstance(x1, PlaceholderArray) assert not isinstance(x2, PlaceholderArray) - return self.apply_ufunc(numpy.divide, "__call__", (x1, x2)) + return self.apply_ufunc(numpy.divide, "__call__", (x1, x2)) # type: ignore[arg-type,return-value] ############################ almost-ufuncs @@ -1317,7 +1370,7 @@ def nan_to_num( posinf: int | float | None = None, neginf: int | float | None = None, ) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) return TypeTracerArray._new(x.dtype, shape=x.shape) @@ -1335,12 +1388,12 @@ def isclose( try_touch_data(x1) try_touch_data(x2) out, _ = self.broadcast_arrays(x1, x2) - return TypeTracerArray._new(np.bool_, shape=out.shape) + return TypeTracerArray._new(np.dtype(np.bool_), shape=out.shape) def isnan(self, x: ArrayLike) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) - return TypeTracerArray._new(np.bool_, shape=x.shape) + return TypeTracerArray._new(np.dtype(np.bool_), shape=x.shape) ############################ reducers @@ -1352,10 +1405,10 @@ def all( keepdims: bool = False, maybe_out: ArrayLike | None = None, ) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) if axis is None: - return TypeTracerArray._new(np.bool_, shape=()) + return TypeTracerArray._new(np.dtype(np.bool_), shape=()) else: raise NotImplementedError @@ -1367,20 +1420,20 @@ def any( keepdims: bool = False, maybe_out: ArrayLike | None = None, ) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) if axis is None: - return TypeTracerArray._new(np.bool_, shape=()) + return TypeTracerArray._new(np.dtype(np.bool_), shape=()) else: raise NotImplementedError def count_nonzero( - self, x: ArrayLike, *, axis: int | None = None + self, x: ArrayLike, *, axis: int | tuple[int, ...] | None = None ) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) if axis is None: - return TypeTracerArray._new(np.intp, shape=()) + return TypeTracerArray._new(np.dtype(np.intp), shape=()) else: raise NotImplementedError @@ -1392,7 +1445,7 @@ def min( keepdims: bool = False, maybe_out: ArrayLike | None = None, ) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) raise NotImplementedError @@ -1404,7 +1457,7 @@ def max( keepdims: bool = False, maybe_out: ArrayLike | None = None, ) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) if axis is None: return TypeTracerArray._new(x.dtype, shape=()) @@ -1419,18 +1472,18 @@ def array_str( precision: int | None = None, suppress_small: bool | None = None, ): - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) try_touch_data(x) return "[## ... ##]" def astype( - self, x: ArrayLike, dtype: numpy.dtype, *, copy: bool | None = True + self, x: ArrayLike, dtype: DTypeLike, *, copy: bool | None = True ) -> TypeTracerArray: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) x.touch_data() return TypeTracerArray._new(np.dtype(dtype), x.shape) - def can_cast(self, from_: np.dtype | ArrayLike, to: np.dtype | ArrayLike) -> bool: + def can_cast(self, from_: DTypeLike | ArrayLike, to: DTypeLike | ArrayLike) -> bool: return numpy.can_cast(from_, to, casting="same_kind") @classmethod @@ -1442,7 +1495,7 @@ def is_own_array(cls, obj) -> bool: return cls.is_own_array_type(type(obj)) def is_c_contiguous(self, x: ArrayLike) -> bool: - assert not isinstance(x, PlaceholderArray) + assert isinstance(x, TypeTracerArray) return True def __dlpack_device__(self) -> tuple[int, int]: @@ -1453,10 +1506,10 @@ def __dlpack__(self, stream=None): def _attach_report( - layout: ak.contents.Content, - form: ak.forms.Form, + layout: Content, + form: Form, report: TypeTracerReport, - getkey: Callable[[ak.forms.form, str], str], + getkey: Callable[[Form, str], str], ): if isinstance(layout, (ak.contents.BitMaskedArray, ak.contents.ByteMaskedArray)): assert isinstance(form, (ak.forms.BitMaskedForm, ak.forms.ByteMaskedForm)) @@ -1489,8 +1542,8 @@ def _attach_report( elif isinstance(layout, ak.contents.NumpyArray): assert isinstance(form, ak.forms.NumpyForm) - layout.data.form_key = getkey(form, "data") - layout.data.report = report + layout.data.form_key = getkey(form, "data") # type: ignore[attr-defined] + layout.data.report = report # type: ignore[attr-defined] elif isinstance(layout, ak.contents.RecordArray): assert isinstance(form, ak.forms.RecordForm) @@ -1516,7 +1569,7 @@ def _attach_report( def typetracer_with_report( form: ak.forms.Form, - getkey: Callable[[ak.forms.form, str], str], + getkey: Callable[[Form, str], str], forget_length: bool = True, ) -> tuple[ak.contents.Content, TypeTracerReport]: layout = form.length_zero_array(highlevel=False).to_typetracer( diff --git a/src/awkward/_typing.py b/src/awkward/_typing.py index 8a892306d5..286c6ad1a8 100644 --- a/src/awkward/_typing.py +++ b/src/awkward/_typing.py @@ -2,11 +2,12 @@ # ruff: noqa: PLE0604 from __future__ import annotations -import numpy import sys import typing from typing import * # noqa: F403 +import numpy + __all__ = list( { "ClassVar", @@ -16,6 +17,7 @@ "Protocol", "Unpack", "TypeAlias", + "TypeGuard", "runtime_checkable", "AxisMaybeNone", "TypedDict", @@ -37,6 +39,7 @@ Self, TypeAlias, TypedDict, + TypeGuard, Unpack, final, ) @@ -50,6 +53,7 @@ SupportsIndex, TypeAlias, TypedDict, + TypeGuard, Unpack, final, runtime_checkable, @@ -61,4 +65,4 @@ ) JSONMapping: TypeAlias = "dict[str, JSONSerializable]" -DType = TypeVar("DType", bound=numpy.dtype) +DType: TypeAlias = numpy.dtype diff --git a/src/awkward/forms/bitmaskedform.py b/src/awkward/forms/bitmaskedform.py index c2af2013da..a9a84e68b5 100644 --- a/src/awkward/forms/bitmaskedform.py +++ b/src/awkward/forms/bitmaskedform.py @@ -8,9 +8,9 @@ import awkward as ak from awkward._nplikes.numpylike import NumpyMetadata from awkward._parameters import type_parameters_equal -from awkward._typing import Iterator, JSONSerializable, Self, final +from awkward._typing import DType, Iterator, JSONSerializable, Self, final from awkward._util import UNSET -from awkward.forms.form import Form, index_to_dtype +from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype np = NumpyMetadata.instance() @@ -209,24 +209,10 @@ def _prune_columns(self, is_inside_record_or_union: bool) -> Self | None: if next_content is None: return None else: - return BitMaskedForm( - self._mask, - next_content, - self._valid_when, - self._lsb_order, - parameters=self._parameters, - form_key=self._form_key, - ) + return self.copy(content=next_content) - def _select_columns(self, match_specifier): - return BitMaskedForm( - self._mask, - self._content._select_columns(match_specifier), - self._valid_when, - self._lsb_order, - parameters=self._parameters, - form_key=self._form_key, - ) + def _select_columns(self, match_specifier: _SpecifierMatcher) -> Self: + return self.copy(content=self._content._select_columns(match_specifier)) def _column_types(self): return self._content._column_types() @@ -263,7 +249,7 @@ def __setstate__(self, state): def _expected_from_buffers( self, getkey: Callable[[Form, str], str], recursive: bool - ) -> Iterator[tuple[str, np.dtype]]: + ) -> Iterator[tuple[str, DType]]: yield (getkey(self, "mask"), index_to_dtype[self._mask]) if recursive: yield from self._content._expected_from_buffers(getkey, recursive) diff --git a/src/awkward/forms/bytemaskedform.py b/src/awkward/forms/bytemaskedform.py index 539fe77518..2e30057b41 100644 --- a/src/awkward/forms/bytemaskedform.py +++ b/src/awkward/forms/bytemaskedform.py @@ -7,9 +7,9 @@ import awkward as ak from awkward._nplikes.numpylike import NumpyMetadata from awkward._parameters import type_parameters_equal -from awkward._typing import Iterator, JSONSerializable, Self, final +from awkward._typing import DType, Iterator, JSONSerializable, Self, final from awkward._util import UNSET -from awkward.forms.form import Form, index_to_dtype +from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype np = NumpyMetadata.instance() @@ -187,22 +187,10 @@ def _prune_columns(self, is_inside_record_or_union: bool) -> Self | None: if next_content is None: return None else: - return ByteMaskedForm( - self._mask, - next_content, - self._valid_when, - parameters=self._parameters, - form_key=self._form_key, - ) + return self.copy(content=next_content) - def _select_columns(self, match_specifier): - return ByteMaskedForm( - self._mask, - self._content._select_columns(match_specifier), - self._valid_when, - parameters=self._parameters, - form_key=self._form_key, - ) + def _select_columns(self, match_specifier: _SpecifierMatcher) -> Self: + return self.copy(content=self._content._select_columns(match_specifier)) def _column_types(self): return self._content._column_types() @@ -226,7 +214,7 @@ def __setstate__(self, state): def _expected_from_buffers( self, getkey: Callable[[Form, str], str], recursive: bool - ) -> Iterator[tuple[str, np.dtype]]: + ) -> Iterator[tuple[str, DType]]: yield (getkey(self, "mask"), index_to_dtype[self._mask]) if recursive: yield from self._content._expected_from_buffers(getkey, recursive) diff --git a/src/awkward/forms/emptyform.py b/src/awkward/forms/emptyform.py index 3ca9622a08..0896ae928a 100644 --- a/src/awkward/forms/emptyform.py +++ b/src/awkward/forms/emptyform.py @@ -10,9 +10,9 @@ from awkward._errors import deprecate from awkward._nplikes.numpylike import NumpyMetadata from awkward._nplikes.shape import ShapeItem -from awkward._typing import Iterator, JSONSerializable, Self, final -from awkward._util import UNSET -from awkward.forms.form import Form, JSONMapping +from awkward._typing import DType, Iterator, JSONSerializable, Self, final +from awkward._util import UNSET, Sentinel +from awkward.forms.form import Form, JSONMapping, _SpecifierMatcher np = NumpyMetadata.instance() @@ -22,22 +22,29 @@ class EmptyForm(Form): is_numpy = True is_unknown = True - def __init__(self, *, parameters: JSONMapping | None = None, form_key=None): + def __init__( + self, *, parameters: JSONMapping | None = None, form_key: str | None = None + ): if not (parameters is None or len(parameters) == 0): raise TypeError(f"{type(self).__name__} cannot contain parameters") self._init(parameters=parameters, form_key=form_key) def copy( - self, *, parameters: JSONMapping | None = UNSET, form_key=UNSET + self, + *, + parameters: JSONMapping | Sentinel | None = UNSET, + form_key: str | Sentinel | None = UNSET, ) -> EmptyForm: - if not (parameters is UNSET or parameters is None or len(parameters) == 0): + if not (parameters is UNSET or parameters is None or len(parameters) == 0): # type: ignore[arg-type] raise TypeError(f"{type(self).__name__} cannot contain parameters") return EmptyForm( - form_key=self._form_key if form_key is UNSET else form_key, + form_key=self._form_key if form_key is UNSET else form_key, # type: ignore[arg-type] ) @classmethod - def simplified(cls, *, parameters=None, form_key=None) -> Form: + def simplified( + cls, *, parameters: JSONMapping | None = None, form_key: str | None = None + ) -> Form: if not (parameters is None or len(parameters) == 0): raise TypeError(f"{cls.__name__} cannot contain parameters") return cls(parameters=parameters, form_key=form_key) @@ -123,7 +130,7 @@ def dimension_optiontype(self) -> bool: def _columns(self, path, output, list_indicator): output.append(".".join(path)) - def _select_columns(self, match_specifier): + def _select_columns(self, match_specifier: _SpecifierMatcher) -> Self: return self def _prune_columns(self, is_inside_record_or_union: bool) -> Self: @@ -152,5 +159,5 @@ def __setstate__(self, state): def _expected_from_buffers( self, getkey: Callable[[Form, str], str], recursive: bool - ) -> Iterator[tuple[str, np.dtype]]: + ) -> Iterator[tuple[str, DType]]: yield from () diff --git a/src/awkward/forms/form.py b/src/awkward/forms/form.py index 90765d8331..3272e3fd73 100644 --- a/src/awkward/forms/form.py +++ b/src/awkward/forms/form.py @@ -17,7 +17,15 @@ from awkward._nplikes.numpylike import NumpyMetadata from awkward._nplikes.shape import ShapeItem, unknown_length from awkward._parameters import parameters_union -from awkward._typing import Final, Iterator, JSONMapping, JSONSerializable, Self +from awkward._typing import ( + ClassVar, + DType, + Final, + Iterator, + JSONMapping, + JSONSerializable, + Self, +) np = NumpyMetadata.instance() numpy_backend = NumpyBackend.instance() @@ -328,16 +336,14 @@ def __call__(self, field: str, *, next_match_if_empty: bool = False) -> Self | N next_specifiers.extend(self._match_to_next_specifiers[pattern]) if has_matched: - return _SpecifierMatcher( - next_specifiers, match_if_empty=next_match_if_empty - ) + return type(self)(next_specifiers, match_if_empty=next_match_if_empty) elif self.is_empty and self._match_if_empty: return self else: - return + return None -def regularize_buffer_key(buffer_key: str | callable) -> Callable[[Form, str], str]: +def regularize_buffer_key(buffer_key: str | Callable) -> Callable[[Form, str], str]: if isinstance(buffer_key, str): def getkey(form, attribute): @@ -358,7 +364,7 @@ def getkey(form, attribute): ) -index_to_dtype: Final[dict[str, np.dtype]] = { +index_to_dtype: Final[dict[str, DType]] = { "i8": np.dtype(" Self | None: + def _prune_columns(self, is_inside_record_or_union: bool) -> Form | None: raise NotImplementedError - def _select_columns(self, match_specifier) -> Self | None: + def _select_columns(self, match_specifier: _SpecifierMatcher) -> Form | None: raise NotImplementedError def _column_types(self): @@ -658,7 +666,7 @@ def prepare(form, multiplier): def _expected_from_buffers( self, getkey: Callable[[Form, str], str], recursive: bool - ) -> Iterator[tuple[str, np.dtype]]: + ) -> Iterator[tuple[str, DType]]: raise NotImplementedError def expected_from_buffers( diff --git a/src/awkward/forms/indexedform.py b/src/awkward/forms/indexedform.py index 9fd4c6e105..299f9dc30f 100644 --- a/src/awkward/forms/indexedform.py +++ b/src/awkward/forms/indexedform.py @@ -8,9 +8,9 @@ import awkward as ak from awkward._nplikes.numpylike import NumpyMetadata from awkward._parameters import parameters_union, type_parameters_equal -from awkward._typing import Iterator, JSONSerializable, Self, final +from awkward._typing import DType, Iterator, JSONSerializable, Self, final from awkward._util import UNSET -from awkward.forms.form import Form, index_to_dtype +from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype np = NumpyMetadata.instance() @@ -195,20 +195,10 @@ def _prune_columns(self, is_inside_record_or_union: bool) -> Self | None: if next_content is None: return None else: - return IndexedForm( - self._index, - next_content, - parameters=self._parameters, - form_key=self._form_key, - ) + return self.copy(content=next_content) - def _select_columns(self, match_specifier): - return IndexedForm( - self._index, - self._content._select_columns(match_specifier), - parameters=self._parameters, - form_key=self._form_key, - ) + def _select_columns(self, match_specifier: _SpecifierMatcher) -> Self: + return self.copy(content=self._content._select_columns(match_specifier)) def _column_types(self): return self._content._column_types() @@ -230,7 +220,7 @@ def __setstate__(self, state): def _expected_from_buffers( self, getkey: Callable[[Form, str], str], recursive: bool - ) -> Iterator[tuple[str, np.dtype]]: + ) -> Iterator[tuple[str, DType]]: yield (getkey(self, "index"), index_to_dtype[self._index]) if recursive: yield from self._content._expected_from_buffers(getkey, recursive) diff --git a/src/awkward/forms/indexedoptionform.py b/src/awkward/forms/indexedoptionform.py index 6eb3ab6eec..1f661315cb 100644 --- a/src/awkward/forms/indexedoptionform.py +++ b/src/awkward/forms/indexedoptionform.py @@ -7,9 +7,9 @@ import awkward as ak from awkward._nplikes.numpylike import NumpyMetadata from awkward._parameters import parameters_union, type_parameters_equal -from awkward._typing import Iterator, JSONSerializable, Self, final +from awkward._typing import DType, Iterator, JSONSerializable, Self, final from awkward._util import UNSET -from awkward.forms.form import Form, index_to_dtype +from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype np = NumpyMetadata.instance() @@ -176,20 +176,10 @@ def _prune_columns(self, is_inside_record_or_union: bool) -> Self | None: if next_content is None: return None else: - return IndexedOptionForm( - self._index, - next_content, - parameters=self._parameters, - form_key=self._form_key, - ) + return self.copy(content=next_content) - def _select_columns(self, match_specifier): - return IndexedOptionForm( - self._index, - self._content._select_columns(match_specifier), - parameters=self._parameters, - form_key=self._form_key, - ) + def _select_columns(self, match_specifier: _SpecifierMatcher) -> Self: + return self.copy(content=self._content._select_columns(match_specifier)) def _column_types(self): return self._content._column_types() @@ -211,7 +201,7 @@ def __setstate__(self, state): def _expected_from_buffers( self, getkey: Callable[[Form, str], str], recursive: bool - ) -> Iterator[tuple[str, np.dtype]]: + ) -> Iterator[tuple[str, DType]]: yield (getkey(self, "index"), index_to_dtype[self._index]) if recursive: yield from self._content._expected_from_buffers(getkey, recursive) diff --git a/src/awkward/forms/listform.py b/src/awkward/forms/listform.py index bbeddec7f1..87867d1cc3 100644 --- a/src/awkward/forms/listform.py +++ b/src/awkward/forms/listform.py @@ -7,9 +7,9 @@ import awkward as ak from awkward._nplikes.numpylike import NumpyMetadata from awkward._parameters import type_parameters_equal -from awkward._typing import Iterator, JSONSerializable, final +from awkward._typing import DType, Iterator, JSONSerializable, Self, final from awkward._util import UNSET -from awkward.forms.form import Form, index_to_dtype +from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype np = NumpyMetadata.instance() @@ -190,27 +190,15 @@ def _columns(self, path, output, list_indicator): path = (*path, list_indicator) self._content._columns(path, output, list_indicator) - def _prune_columns(self, is_inside_record_or_union: bool): + def _prune_columns(self, is_inside_record_or_union: bool) -> Self | None: next_content = self._content._prune_columns(is_inside_record_or_union) if next_content is None: return None else: - return ListForm( - self._starts, - self._stops, - next_content, - parameters=self._parameters, - form_key=self._form_key, - ) + return self.copy(content=next_content) - def _select_columns(self, match_specifier): - return ListForm( - self._starts, - self._stops, - self._content._select_columns(match_specifier), - parameters=self._parameters, - form_key=self._form_key, - ) + def _select_columns(self, match_specifier: _SpecifierMatcher) -> Self: + return self.copy(content=self._content._select_columns(match_specifier)) def _column_types(self): if self.parameter("__array__") in ("string", "bytestring"): @@ -237,7 +225,7 @@ def __setstate__(self, state): def _expected_from_buffers( self, getkey: Callable[[Form, str], str], recursive: bool - ) -> Iterator[tuple[str, np.dtype]]: + ) -> Iterator[tuple[str, DType]]: yield (getkey(self, "starts"), index_to_dtype[self._starts]) yield (getkey(self, "stops"), index_to_dtype[self._stops]) if recursive: diff --git a/src/awkward/forms/listoffsetform.py b/src/awkward/forms/listoffsetform.py index 7f6f8214df..4b51859e5a 100644 --- a/src/awkward/forms/listoffsetform.py +++ b/src/awkward/forms/listoffsetform.py @@ -8,9 +8,16 @@ import awkward as ak from awkward._nplikes.numpylike import NumpyMetadata from awkward._parameters import type_parameters_equal -from awkward._typing import Iterator, JSONMapping, JSONSerializable, final +from awkward._typing import ( + DType, + Iterator, + JSONMapping, + JSONSerializable, + Self, + final, +) from awkward._util import UNSET -from awkward.forms.form import Form, index_to_dtype +from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype np = NumpyMetadata.instance() @@ -156,25 +163,15 @@ def _columns(self, path, output, list_indicator): path = (*path, list_indicator) self._content._columns(path, output, list_indicator) - def _prune_columns(self, is_inside_record_or_union: bool): + def _prune_columns(self, is_inside_record_or_union: bool) -> Self | None: next_content = self._content._prune_columns(is_inside_record_or_union) if next_content is None: return None else: - return ListOffsetForm( - self._offsets, - next_content, - parameters=self._parameters, - form_key=self._form_key, - ) + return self.copy(content=next_content) - def _select_columns(self, match_specifier): - return ListOffsetForm( - self._offsets, - self._content._select_columns(match_specifier), - parameters=self._parameters, - form_key=self._form_key, - ) + def _select_columns(self, match_specifier: _SpecifierMatcher) -> Self: + return self.copy(content=self._content._select_columns(match_specifier)) def _column_types(self): if self.parameter("__array__") in ("string", "bytestring"): @@ -199,7 +196,7 @@ def __setstate__(self, state): def _expected_from_buffers( self, getkey: Callable[[Form, str], str], recursive: bool - ) -> Iterator[tuple[str, np.dtype]]: + ) -> Iterator[tuple[str, DType]]: yield (getkey(self, "offsets"), index_to_dtype[self._offsets]) if recursive: yield from self._content._expected_from_buffers(getkey, recursive) diff --git a/src/awkward/forms/numpyform.py b/src/awkward/forms/numpyform.py index 7e7e3b2564..c41c8762c1 100644 --- a/src/awkward/forms/numpyform.py +++ b/src/awkward/forms/numpyform.py @@ -9,14 +9,19 @@ from awkward._nplikes.numpylike import NumpyMetadata from awkward._nplikes.shape import unknown_length from awkward._parameters import type_parameters_equal -from awkward._typing import JSONSerializable, Self, final -from awkward._util import UNSET -from awkward.forms.form import Form +from awkward._typing import DType, JSONMapping, JSONSerializable, Self, final +from awkward._util import UNSET, Sentinel +from awkward.forms.form import Form, _SpecifierMatcher np = NumpyMetadata.instance() -def from_dtype(dtype, parameters=None, *, time_units_as_parameter: bool = UNSET): +def from_dtype( + dtype, + parameters: JSONMapping | None = None, + *, + time_units_as_parameter: bool | Sentinel = UNSET, +): if dtype.subdtype is None: inner_shape = () else: @@ -182,6 +187,7 @@ def purelist_parameters(self, *keys: str) -> JSONSerializable: for key in keys: if key in self._parameters: return self._parameters[key] + return None @property def purelist_isregular(self): @@ -219,7 +225,7 @@ def dimension_optiontype(self): def _columns(self, path, output, list_indicator): output.append(".".join(path)) - def _select_columns(self, match_specifier): + def _select_columns(self, match_specifier: _SpecifierMatcher) -> Self: return self def _prune_columns(self, is_inside_record_or_union: bool) -> Self: @@ -294,7 +300,7 @@ def __setstate__(self, state): def _expected_from_buffers( self, getkey: Callable[[Form, str], str], recursive: bool - ) -> Iterator[tuple[str, np.dtype]]: + ) -> Iterator[tuple[str, DType]]: from awkward.types.numpytype import primitive_to_dtype yield (getkey(self, "data"), primitive_to_dtype(self.primitive)) diff --git a/src/awkward/forms/recordform.py b/src/awkward/forms/recordform.py index 26f0f1bc98..67947303c3 100644 --- a/src/awkward/forms/recordform.py +++ b/src/awkward/forms/recordform.py @@ -8,10 +8,10 @@ from awkward._nplikes.numpylike import NumpyMetadata from awkward._parameters import type_parameters_equal from awkward._regularize import is_integer -from awkward._typing import JSONSerializable, final +from awkward._typing import DType, JSONSerializable, Self, final from awkward._util import UNSET from awkward.errors import FieldNotFoundError -from awkward.forms.form import Form +from awkward.forms.form import Form, _SpecifierMatcher np = NumpyMetadata.instance() @@ -198,6 +198,7 @@ def purelist_parameters(self, *keys: str) -> JSONSerializable: for key in keys: if key in self._parameters: return self._parameters[key] + return None @property def purelist_isregular(self): @@ -246,7 +247,7 @@ def _columns(self, path, output, list_indicator): for content, field in zip(self._contents, self.fields): content._columns((*path, field), output, list_indicator) - def _prune_columns(self, is_inside_record_or_union: bool): + def _prune_columns(self, is_inside_record_or_union: bool) -> Self | None: contents = [] fields = [] for content, field in zip(self._contents, self.fields): @@ -256,18 +257,13 @@ def _prune_columns(self, is_inside_record_or_union: bool): contents.append(next_content) fields.append(field) - if fields or not is_inside_record_or_union: - return RecordForm( - contents, - fields, - parameters=self._parameters, - form_key=self._form_key, - ) # If all subtrees are pruned (or we have no subtrees!), then we should prune this form too - else: + if not fields and is_inside_record_or_union: return None + else: + return self.copy(contents=contents, fields=fields) - def _select_columns(self, match_specifier): + def _select_columns(self, match_specifier: _SpecifierMatcher) -> Self: contents = [] fields = [] for content, field in zip(self._contents, self.fields): @@ -283,12 +279,7 @@ def _select_columns(self, match_specifier): contents.append(next_content) fields.append(field) - return RecordForm( - contents, - fields, - parameters=self._parameters, - form_key=self._form_key, - ) + return self.copy(contents=contents, fields=fields) def _column_types(self): return sum((x._column_types() for x in self._contents), ()) @@ -312,7 +303,7 @@ def __setstate__(self, state): def _expected_from_buffers( self, getkey: Callable[[Form, str], str], recursive: bool - ) -> Iterator[tuple[str, np.dtype]]: + ) -> Iterator[tuple[str, DType]]: if recursive: for content in self._contents: yield from content._expected_from_buffers(getkey, recursive) diff --git a/src/awkward/forms/regularform.py b/src/awkward/forms/regularform.py index 9f82666a56..974a579783 100644 --- a/src/awkward/forms/regularform.py +++ b/src/awkward/forms/regularform.py @@ -9,9 +9,9 @@ from awkward._nplikes.shape import unknown_length from awkward._parameters import type_parameters_equal from awkward._regularize import is_integer -from awkward._typing import JSONSerializable, final +from awkward._typing import DType, JSONSerializable, Self, final from awkward._util import UNSET -from awkward.forms.form import Form +from awkward.forms.form import Form, _SpecifierMatcher np = NumpyMetadata.instance() @@ -151,25 +151,15 @@ def _columns(self, path, output, list_indicator): path = (*path, list_indicator) self._content._columns(path, output, list_indicator) - def _prune_columns(self, is_inside_record_or_union: bool): + def _prune_columns(self, is_inside_record_or_union: bool) -> Self | None: next_content = self._content._prune_columns(is_inside_record_or_union) if next_content is None: return None else: - return RegularForm( - next_content, - self._size, - parameters=self._parameters, - form_key=self._form_key, - ) + return self.copy(content=next_content) - def _select_columns(self, match_specifier): - return RegularForm( - self._content._select_columns(match_specifier), - self._size, - parameters=self._parameters, - form_key=self._form_key, - ) + def _select_columns(self, match_specifier: _SpecifierMatcher) -> Self: + return self.copy(content=self._content._select_columns(match_specifier)) def _column_types(self): if self.parameter("__array__") in ("string", "bytestring"): @@ -194,6 +184,6 @@ def __setstate__(self, state): def _expected_from_buffers( self, getkey: Callable[[Form, str], str], recursive: bool - ) -> Iterator[tuple[str, np.dtype]]: + ) -> Iterator[tuple[str, DType]]: if recursive: yield from self._content._expected_from_buffers(getkey, recursive) diff --git a/src/awkward/forms/unionform.py b/src/awkward/forms/unionform.py index 8ad5c577da..fca998a679 100644 --- a/src/awkward/forms/unionform.py +++ b/src/awkward/forms/unionform.py @@ -8,9 +8,9 @@ import awkward as ak from awkward._nplikes.numpylike import NumpyMetadata from awkward._parameters import type_parameters_equal -from awkward._typing import Iterator, JSONSerializable, Self, final +from awkward._typing import DType, Iterator, JSONSerializable, Self, final from awkward._util import UNSET -from awkward.forms.form import Form, index_to_dtype +from awkward.forms.form import Form, _SpecifierMatcher, index_to_dtype np = NumpyMetadata.instance() @@ -168,6 +168,8 @@ def purelist_parameters(self, *keys: str) -> JSONSerializable: return None return out + return None + @property def purelist_isregular(self): for content in self._contents: @@ -236,7 +238,7 @@ def _columns(self, path, output, list_indicator): for content, field in zip(self._contents, self.fields): content._columns((*path, field), output, list_indicator) - def _prune_columns(self, is_inside_record_or_union: bool) -> Self | None: + def _prune_columns(self, is_inside_record_or_union: bool) -> Form | None: contents = [] for content in self._contents: next_content = content._prune_columns(True) @@ -253,26 +255,14 @@ def _prune_columns(self, is_inside_record_or_union: bool) -> Self | None: elif len(contents) == 1: return contents[0] else: - return UnionForm( - self._tags, - self._index, - contents, - parameters=self._parameters, - form_key=self._form_key, - ) + return self.copy(contents=contents) - def _select_columns(self, match_specifier): + def _select_columns(self, match_specifier: _SpecifierMatcher) -> Self: contents = [ content._select_columns(match_specifier) for content in self._contents ] - return UnionForm( - self._tags, - self._index, - contents, - parameters=self._parameters, - form_key=self._form_key, - ) + return self.copy(contents=contents) def _column_types(self): return sum((x._column_types() for x in self._contents), ()) @@ -296,7 +286,7 @@ def __setstate__(self, state): def _expected_from_buffers( self, getkey: Callable[[Form, str], str], recursive: bool - ) -> Iterator[tuple[str, np.dtype]]: + ) -> Iterator[tuple[str, DType]]: yield (getkey(self, "tags"), index_to_dtype[self._tags]) yield (getkey(self, "index"), index_to_dtype[self._index]) if recursive: diff --git a/src/awkward/forms/unmaskedform.py b/src/awkward/forms/unmaskedform.py index c6ea768e9d..722885f6fb 100644 --- a/src/awkward/forms/unmaskedform.py +++ b/src/awkward/forms/unmaskedform.py @@ -7,9 +7,9 @@ import awkward as ak from awkward._nplikes.numpylike import NumpyMetadata from awkward._parameters import parameters_union, type_parameters_equal -from awkward._typing import JSONSerializable, Self, final +from awkward._typing import DType, JSONSerializable, Self, final from awkward._util import UNSET -from awkward.forms.form import Form +from awkward.forms.form import Form, _SpecifierMatcher np = NumpyMetadata.instance() @@ -150,18 +150,10 @@ def _prune_columns(self, is_inside_record_or_union: bool) -> Self | None: if next_content is None: return None else: - return UnmaskedForm( - next_content, - parameters=self._parameters, - form_key=self._form_key, - ) + return self.copy(content=next_content) - def _select_columns(self, match_specifier): - return UnmaskedForm( - self._content._select_columns(match_specifier), - parameters=self._parameters, - form_key=self._form_key, - ) + def _select_columns(self, match_specifier: _SpecifierMatcher) -> Self: + return self.copy(content=self._content._select_columns(match_specifier)) def _column_types(self): return self._content._column_types() @@ -183,6 +175,6 @@ def __setstate__(self, state): def _expected_from_buffers( self, getkey: Callable[[Form, str], str], recursive: bool - ) -> Iterator[tuple[str, np.dtype]]: + ) -> Iterator[tuple[str, DType]]: if recursive: yield from self._content._expected_from_buffers(getkey, recursive)