Skip to content

Commit

Permalink
fix: appease mypy
Browse files Browse the repository at this point in the history
  • Loading branch information
agoose77 committed Oct 8, 2024
1 parent edc8290 commit 197de34
Showing 1 changed file with 19 additions and 11 deletions.
30 changes: 19 additions & 11 deletions src/awkward/_namedaxis.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,17 @@ def _prepare_named_axis_for_attrs(
return _named_axis


def _make_named_int_class(name: tp.Any) -> type[int]:
class NamedInt(int):
def __repr__(self):
value_repr = super().__repr__()
return f"{name!r} (named axis) -> {value_repr} (pos. axis)"

__str__ = __repr__

return NamedInt


def _named_axis_to_positional_axis(
named_axis: AxisMapping,
axis: AxisName,
Expand Down Expand Up @@ -337,16 +348,14 @@ def _named_axis_to_positional_axis(
# in order to properly display it in error messages. This is useful for cases
# where the positional axis is pointing to a non-existing axis. The error message
# will then show the original (named) axis together with the positional axis.
class namedint(int): ... # pylint: disable=multiple-statements
cls = _make_named_int_class(axis)
return cls(named_axis[axis])

def _repr(self):
return f"{axis!r} (named axis) -> {super(namedint, self).__repr__()} (pos. axis)"

namedint.__repr__ = namedint.__str__ = _repr

return namedint(named_axis[axis])
if is_integer(axis) or axis is None:
return axis
if is_integer(axis):
# TODO: is_integer is an external helper function that doesn't specify types
return int(tp.cast(tp.Any, axis))
elif axis is None:
return None
else:
raise ValueError(f"Invalid {axis=} [{type(axis)=}]")

Expand Down Expand Up @@ -777,9 +786,8 @@ def _normalize_named_slice(
idx = tp.cast(int, ax_name)
out_where[idx] = slice_
elif _is_valid_named_axis(ax_name):
idx = _named_axis_to_positional_axis(named_axis, ax_name)
# it's an integer, pyright doesn't get this
idx = tp.cast(int, idx)
idx = tp.cast(int, _named_axis_to_positional_axis(named_axis, ax_name))
out_where[idx] = slice_
else:
raise ValueError(f"Invalid axis name: {ax_name} in slice {where}")
Expand Down

0 comments on commit 197de34

Please sign in to comment.