Skip to content

Commit

Permalink
named_axis: fix typo and avoid dictionary copies where possible
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Oct 8, 2024
1 parent 4bbd5b1 commit f515ea7
Showing 1 changed file with 6 additions and 10 deletions.
16 changes: 6 additions & 10 deletions src/awkward/_namedaxis.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def _get_named_axis(
if isinstance(ctx, MaybeSupportsNamedAxis):
return _get_named_axis(ctx.attrs)
elif isinstance(ctx, tp.Mapping) and NAMED_AXIS_KEY in ctx:
return dict(ctx[NAMED_AXIS_KEY])
return ctx[NAMED_AXIS_KEY]
else:
return {}

Expand Down Expand Up @@ -306,7 +306,7 @@ def _prepare_named_axis_for_attrs(
if isinstance(named_axis, tuple):
_named_axis = _axis_tuple_to_mapping(named_axis)
elif isinstance(named_axis, dict):
_named_axis = dict(named_axis)
_named_axis = named_axis
else:
raise TypeError(
f"named_axis must be a mapping or a tuple, got {named_axis=} [{type(named_axis)=}]"
Expand Down Expand Up @@ -380,7 +380,7 @@ def _named_axis_to_positional_axis(
# The possible strategies are:
# - "keep all" (_keep_named_axis(..., None)): Keep all named axes in the output array, e.g.: `ak.drop_none`
# - "keep one" (_keep_named_axis(..., int)): Keep one named axes in the output array, e.g.: `ak.firsts`
# - "keep up to" (_keep_named_axis_up_to(..., int)): Keep all named axes upto a certain positional axis in the output array, e.g.: `ak.local_index`
# - "keep up to" (_keep_named_axis_up_to(..., int)): Keep all named axes up to a certain positional axis in the output array, e.g.: `ak.local_index`
# - "remove all" (_remove_all_named_axis): Removes all named axis, e.g.: `ak.categories`
# - "remove one" (_remove_named_axis): Remove the named axis from the output array, e.g.: `ak.sum`
# - "add one" (_add_named_axis): Add a new named axis to the output array, e.g.: `ak.concatenate`
Expand Down Expand Up @@ -409,7 +409,7 @@ def _keep_named_axis(
{"x": 0, "y": 1, "z": 2}
"""
if axis is None:
return dict(named_axis)
return named_axis
return {k: 0 for k, v in named_axis.items() if v == axis}


Expand Down Expand Up @@ -577,10 +577,7 @@ def _adjust(pos: int, axis: int, direction: int) -> int:
else:
return pos

out = dict(named_axis)
for k, v in out.items():
out[k] = _adjust(v, axis, direction)
return out
return {k: _adjust(v, axis, direction) for k, v in named_axis.items()}


def _add_named_axis(
Expand Down Expand Up @@ -608,8 +605,7 @@ def _add_named_axis(
if total is None:
total = len(named_axis)

out = dict(named_axis)
return _adjust_pos_axis(out, axis, total, direction=+1)
return _adjust_pos_axis(named_axis, axis, total, direction=+1)


def _unify_named_axis(
Expand Down

0 comments on commit f515ea7

Please sign in to comment.