From 02858f895f89f7e6292dc3a8d7e079d596885286 Mon Sep 17 00:00:00 2001 From: Topher Cawlfield <4094385+tcawlfield@users.noreply.github.com> Date: Tue, 18 Jun 2024 14:58:49 -0600 Subject: [PATCH] feat: Adding awkward.semipublic submodule (#3152) * Renaming _prettyprint to prettyprint (no longer private) * _prettyprint remains as a deprecated submodule * highlevel.py: removing unnecessary prettyprint imports * Fixing (deprecated) _prettyprint with explicit imports. * Putting remove_structure in contents.remove_structure submodule. Also adding a unit test for this and prettyprint. --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> --- src/awkward/_do.py | 39 +- src/awkward/_prettyprint.py | 462 ++---------------- src/awkward/contents/remove_structure.py | 46 ++ src/awkward/highlevel.py | 27 +- src/awkward/prettyprint.py | 435 +++++++++++++++++ src/awkward/types/recordtype.py | 6 +- ...export_remove_structure_and_prettyprint.py | 30 ++ 7 files changed, 554 insertions(+), 491 deletions(-) create mode 100644 src/awkward/contents/remove_structure.py create mode 100644 src/awkward/prettyprint.py create mode 100644 tests/test_2856_export_remove_structure_and_prettyprint.py diff --git a/src/awkward/_do.py b/src/awkward/_do.py index 6fafd68309..0e2a73a295 100644 --- a/src/awkward/_do.py +++ b/src/awkward/_do.py @@ -11,6 +11,7 @@ from awkward._nplikes.numpy_like import NumpyMetadata from awkward._typing import Any, AxisMaybeNone, Literal from awkward.contents.content import ActionType, Content +from awkward.contents.remove_structure import remove_structure from awkward.errors import AxisError from awkward.forms import form from awkward.record import Record @@ -191,44 +192,6 @@ def pad_none( return layout._pad_none(length, axis, 1, clip) -def remove_structure( - layout: Content | Record, - backend: Backend | None = None, - flatten_records: bool = True, - function_name: str | None = None, - drop_nones: bool = True, - keepdims: 
bool = False, - allow_records: bool = False, - list_to_regular: bool = False, -): - if isinstance(layout, Record): - return remove_structure( - layout._array[layout._at : layout._at + 1], - backend, - flatten_records, - function_name, - drop_nones, - keepdims, - allow_records, - ) - - else: - if backend is None: - backend = layout._backend - arrays = layout._remove_structure( - backend, - { - "flatten_records": flatten_records, - "function_name": function_name, - "drop_nones": drop_nones, - "keepdims": keepdims, - "allow_records": allow_records, - "list_to_regular": list_to_regular, - }, - ) - return tuple(arrays) - - def flatten(layout: Content, axis: int = 1) -> Content: offsets, flattened = layout._offsets_and_flattened(axis, 1) return flattened diff --git a/src/awkward/_prettyprint.py b/src/awkward/_prettyprint.py index fb2b8199f7..7ba6e0b0a0 100644 --- a/src/awkward/_prettyprint.py +++ b/src/awkward/_prettyprint.py @@ -2,434 +2,34 @@ from __future__ import annotations -import math -import re -from collections.abc import Callable - -import awkward as ak -from awkward._layout import wrap_layout -from awkward._nplikes.numpy import Numpy, NumpyMetadata -from awkward._typing import TYPE_CHECKING, Any, TypeAlias, TypedDict - -if TYPE_CHECKING: - from awkward.contents.content import Content - - -FormatterType: TypeAlias = "Callable[[Any], str]" - - -class FormatterOptions(TypedDict, total=False): - bool: FormatterType - int: FormatterType - timedelta: FormatterType - datetime: FormatterType - float: FormatterType - longfloat: FormatterType - complexfloat: FormatterType - longcomplexfloat: FormatterType - numpystr: FormatterType - object: FormatterType - all: FormatterType - int_kind: FormatterType - float_kind: FormatterType - complex_kind: FormatterType - str_kind: FormatterType - str: FormatterType - bytes: FormatterType - - -np = NumpyMetadata.instance() -numpy = Numpy.instance() - - -def half(integer: int) -> int: - return int(math.ceil(integer / 2)) - - -def 
alternate(length: int): - halfindex = half(length) - forward = iter(range(halfindex)) - backward = iter(range(length - 1, halfindex - 1, -1)) - going_forward, going_backward = True, True - while going_forward or going_backward: - if going_forward: - try: - yield True, next(forward) - except StopIteration: - going_forward = False - if going_backward: - try: - yield False, next(backward) - except StopIteration: - going_backward = False - - -is_identifier = re.compile(r"^[A-Za-z_][A-Za-z_0-9]*$") - - -# avoid recursion in which ak.Array.__getitem__ calls prettyprint -# to form an error string: private reimplementation of ak.Array.__getitem__ - - -class PlaceholderValue: - def __str__(self): - return "??" - - -def get_at(data: Content, index: int): - if data._layout._is_getitem_at_placeholder(): - return PlaceholderValue() - out = data._layout._getitem_at(index) - if isinstance(out, ak.contents.NumpyArray): - array_param = out.parameter("__array__") - if array_param == "byte": - return ak._util.tobytes(out._raw(numpy)) - elif array_param == "char": - return ak._util.tobytes(out._raw(numpy)).decode(errors="surrogateescape") - if isinstance(out, (ak.contents.Content, ak.record.Record)): - return wrap_layout(out, data._behavior) - else: - return out - - -def get_field(data: Content, field: str): - out = data._layout._getitem_field(field) - if isinstance(out, ak.contents.NumpyArray): - array_param = out.parameter("__array__") - if array_param == "byte": - return ak._util.tobytes(out._raw(numpy)) - elif array_param == "char": - return ak._util.tobytes(out._raw(numpy)).decode(errors="surrogateescape") - if isinstance(out, (ak.contents.Content, ak.record.Record)): - return wrap_layout(out, data._behavior) - else: - return out - - -def custom_str(current: Any) -> str | None: - if ( - issubclass(type(current), ak.highlevel.Record) - and type(current).__str__ is not ak.highlevel.Record.__str__ - ) or ( - issubclass(type(current), ak.highlevel.Array) - and type(current).__str__ 
is not ak.highlevel.Array.__str__ - ): - return str(current) - - elif ( - issubclass(type(current), ak.highlevel.Record) - and type(current).__repr__ is not ak.highlevel.Record.__repr__ - ) or ( - issubclass(type(current), ak.highlevel.Array) - and type(current).__repr__ is not ak.highlevel.Array.__repr__ - ): - return repr(current) - - else: - return None - - -def valuestr_horiz( - data: Any, limit_cols: int, formatter: Formatter -) -> tuple[int, list[str]]: - if isinstance(data, (ak.highlevel.Array, ak.highlevel.Record)) and ( - not data.layout.backend.nplike.known_data - ): - if isinstance(data, ak.highlevel.Array): - return 5, ["[...]"] - - original_limit_cols = limit_cols - - if isinstance(data, ak.highlevel.Array): - front, back = ["["], ["]"] - limit_cols -= 2 - - if len(data) == 0: - return 2, front + back - - elif len(data) == 1: - cols_taken, strs = valuestr_horiz(get_at(data, 0), limit_cols, formatter) - return 2 + cols_taken, front + strs + back - - else: - limit_cols -= 5 # anticipate the ", ..." - which = 0 - for forward, index in alternate(len(data)): - current = get_at(data, index) - if forward: - for_comma = 0 if which == 0 else 2 - cols_taken, strs = valuestr_horiz( - current, limit_cols - for_comma, formatter - ) - - custom = custom_str(current) - if custom is not None: - strs = custom - - if limit_cols - (for_comma + cols_taken) >= 0: - if which != 0: - front.append(", ") - limit_cols -= 2 - front.extend(strs) - limit_cols -= cols_taken - else: - break - else: - cols_taken, strs = valuestr_horiz( - current, limit_cols - 2, formatter - ) - - custom = custom_str(current) - if custom is not None: - strs = custom - - if limit_cols - (2 + cols_taken) >= 0: - back[:0] = strs - back.insert(0, ", ") - limit_cols -= 2 + cols_taken - else: - break - - which += 1 - - if which == 0: - front.append("...") - limit_cols -= 3 - elif which != len(data): - front.append(", ...") - limit_cols -= 5 - - limit_cols += 5 # credit the ", ..." 
- return original_limit_cols - limit_cols, front + back - - elif isinstance(data, ak.highlevel.Record): - is_tuple = data.layout.is_tuple - - front = ["("] if is_tuple else ["{"] - limit_cols -= 2 # both the opening and closing brackets - limit_cols -= 5 # anticipate the ", ..." - - which = 0 - fields = data.fields - for key in fields: - for_comma = 0 if which == 0 else 2 - if is_tuple: - key_str = "" - else: - if is_identifier.match(key) is None: - key_str = repr(key) + ": " - if key_str.startswith("u"): - key_str = key_str[1:] - else: - key_str = key + ": " - - if limit_cols - (for_comma + len(key_str) + 3) >= 0: - if which != 0: - front.append(", ") - limit_cols -= 2 - front.append(key_str) - limit_cols -= len(key_str) - which += 1 - - target = limit_cols if len(fields) == 1 else half(limit_cols) - cols_taken, strs = valuestr_horiz( - get_field(data, key), target, formatter - ) - if limit_cols - cols_taken >= 0: - front.extend(strs) - limit_cols -= cols_taken - else: - front.append("...") - limit_cols -= 3 - break - - else: - break - - which += 1 - - if len(fields) != 0: - if which == 0: - front.append("...") - limit_cols -= 3 - elif which != 2 * len(fields): - front.append(", ...") - limit_cols -= 5 - - limit_cols += 5 # credit the ", ..." 
- front.append(")" if is_tuple else "}") - return original_limit_cols - limit_cols, front - - else: - out = formatter(data) - return len(out), [out] - - -class Formatter: - def __init__(self, formatters: FormatterOptions | None = None, precision: int = 3): - self._formatters: FormatterOptions = formatters or {} - self._precision: int = precision - self._cache: dict[type, FormatterType] = {} - - def __call__(self, obj: Any) -> str: - try: - impl = self._cache[type(obj)] - except KeyError: - impl = self._find_formatter_impl(type(obj)) - self._cache[type(obj)] = impl - return impl(obj) - - def _format_complex(self, data: complex) -> str: - return f"{data.real:.{self._precision}g}+{data.imag:.{self._precision}g}j" - - def _format_real(self, data: float) -> str: - return f"{data:.{self._precision}g}" - - def _find_formatter_impl(self, cls: type) -> FormatterType: - if issubclass(cls, np.bool_): - try: - return self._formatters["bool"] - except KeyError: - return str - elif issubclass(cls, np.integer): - try: - return self._formatters["int"] - except KeyError: - return self._formatters.get("int_kind", str) - elif issubclass(cls, (np.float64, np.float32)): - try: - return self._formatters["float"] - except KeyError: - return self._formatters.get("float_kind", self._format_real) - elif hasattr(np, "float128") and issubclass(cls, np.float128): - try: - return self._formatters["longfloat"] - except KeyError: - return self._formatters.get("float_kind", self._format_real) - elif issubclass(cls, (np.complex64, np.complex128)): - try: - return self._formatters["complexfloat"] - except KeyError: - return self._formatters.get("complex_kind", self._format_complex) - elif hasattr(np, "complex256") and issubclass(cls, np.complex256): - try: - return self._formatters["longcomplexfloat"] - except KeyError: - return self._formatters.get("complex_kind", self._format_complex) - elif issubclass(cls, np.datetime64): - try: - return self._formatters["datetime"] - except KeyError: - return 
str - elif issubclass(cls, np.timedelta64): - try: - return self._formatters["timedelta"] - except KeyError: - return str - elif issubclass(cls, str): - try: - return self._formatters["str"] - except KeyError: - return self._formatters.get("str_kind", repr) - elif issubclass(cls, bytes): - try: - return self._formatters["bytes"] - except KeyError: - return self._formatters.get("str_kind", repr) - else: - return str - - -def valuestr( - data: Any, limit_rows: int, limit_cols: int, formatter: Formatter | None = None -) -> str: - if formatter is None: - formatter = Formatter() - - if isinstance(data, (ak.highlevel.Array, ak.highlevel.Record)) and ( - not data.layout.backend.nplike.known_data - ): - if isinstance(data, ak.highlevel.Array): - return "[...]" - - if limit_rows <= 1: - _, strs = valuestr_horiz(data, limit_cols, formatter) - return "".join(strs) - - elif isinstance(data, ak.highlevel.Array): - front, back = [], [] - which = 0 - for forward, index in alternate(len(data)): - _, strs = valuestr_horiz(get_at(data, index), limit_cols - 2, formatter) - if forward: - front.append("".join(strs)) - else: - back.insert(0, "".join(strs)) - - which += 1 - if which >= limit_rows: - break - - if len(data) != 0 and which != len(data): - back[0] = "..." 
- - out = front + back - for i, val in enumerate(out): - if i > 0: - val = out[i] = " " + val - else: - val = out[i] = "[" + val - if i < len(out) - 1: - out[i] = val + "," - else: - out[i] = val + "]" - - return "\n".join(out) - - elif isinstance(data, ak.highlevel.Record): - is_tuple = data.layout.is_tuple - - front = [] - - which = 0 - fields = data.fields - for key in fields: - if is_tuple: - key_str = "" - else: - if is_identifier.match(key) is None: - key_str = repr(key) + ": " - if key_str.startswith("u"): - key_str = key_str[1:] - else: - key_str = key + ": " - _, strs = valuestr_horiz( - get_field(data, key), limit_cols - 2 - len(key_str), formatter - ) - front.append(key_str + "".join(strs)) - - which += 1 - if which >= limit_rows: - break - - if len(fields) != 0 and which != len(fields): - front[-1] = "..." - - out = front - for i, val in enumerate(out): - if i > 0: - val = out[i] = " " + val - elif data.is_tuple: - val = out[i] = "(" + val - else: - val = out[i] = "{" + val - if i < len(out) - 1: - out[i] = val + "," - elif data.is_tuple: - out[i] = val + ")" - else: - out[i] = val + "}" - return "\n".join(out) - - else: - raise AssertionError(type(data)) +# We're renaming awkward._prettyprint to this module, and gently deprecating the +# private submodule. 
+from awkward.prettyprint import ( + Formatter, + FormatterOptions, + FormatterType, + PlaceholderValue, + alternate, + custom_str, + get_at, + get_field, + half, + is_identifier, + valuestr, + valuestr_horiz, +) + +__all__ = [ + "Formatter", + "FormatterOptions", + "FormatterType", + "PlaceholderValue", + "alternate", + "custom_str", + "get_at", + "get_field", + "half", + "is_identifier", + "valuestr", + "valuestr_horiz", +] diff --git a/src/awkward/contents/remove_structure.py b/src/awkward/contents/remove_structure.py new file mode 100644 index 0000000000..d43dae346c --- /dev/null +++ b/src/awkward/contents/remove_structure.py @@ -0,0 +1,46 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +from awkward._backends.backend import Backend +from awkward.contents.content import Content +from awkward.record import Record + + +# Note: remove_structure is semi-public, exposed via awkward.contents.remove_structure. +def remove_structure( + layout: Content | Record, + backend: Backend | None = None, + flatten_records: bool = True, + function_name: str | None = None, + drop_nones: bool = True, + keepdims: bool = False, + allow_records: bool = False, + list_to_regular: bool = False, +): + if isinstance(layout, Record): + return remove_structure( + layout._array[layout._at : layout._at + 1], + backend, + flatten_records, + function_name, + drop_nones, + keepdims, + allow_records, + ) + + else: + if backend is None: + backend = layout._backend + arrays = layout._remove_structure( + backend, + { + "flatten_records": flatten_records, + "function_name": function_name, + "drop_nones": drop_nones, + "keepdims": keepdims, + "allow_records": allow_records, + "list_to_regular": list_to_regular, + }, + ) + return tuple(arrays) diff --git a/src/awkward/highlevel.py b/src/awkward/highlevel.py index e8a0c92b3e..f315945511 100644 --- a/src/awkward/highlevel.py +++ b/src/awkward/highlevel.py @@ -31,10 +31,11 @@
unpickle_array_schema_1, unpickle_record_schema_1, ) -from awkward._prettyprint import Formatter from awkward._regularize import is_non_string_like_iterable from awkward._typing import Any, TypeVar from awkward._util import STDOUT +from awkward.prettyprint import Formatter +from awkward.prettyprint import valuestr as prettyprint_valuestr __all__ = ("Array", "ArrayBuilder", "Record") @@ -1291,16 +1292,12 @@ def __dir__(self): ) def __str__(self): - import awkward._prettyprint - - return awkward._prettyprint.valuestr(self, 1, 80) + return prettyprint_valuestr(self, 1, 80) def __repr__(self): return self._repr(80) def _repr(self, limit_cols): - import awkward._prettyprint - try: pytype = super().__getattribute__("__name__") except AttributeError: @@ -1322,7 +1319,7 @@ def _repr(self, limit_cols): limit_cols - len(pytype) - len(" type='...'") - 3, ), ) - valuestr = valuestr + " " + awkward._prettyprint.valuestr(self, 1, strwidth) + valuestr = valuestr + " " + prettyprint_valuestr(self, 1, strwidth) length = max(3, limit_cols - len(pytype) - len("type='...'") - len(valuestr)) if len(typestr) > length: @@ -1363,11 +1360,9 @@ def show( key is ignored; instead, a `"bytes"` and/or `"str"` key is considered when formatting string values, falling back upon `"str_kind"`. 
""" - import awkward._prettyprint - formatter_impl = Formatter(formatter, precision=precision) - valuestr = awkward._prettyprint.valuestr( + valuestr = prettyprint_valuestr( self, limit_rows, limit_cols, formatter=formatter_impl ) if type: @@ -2161,16 +2156,12 @@ def __dir__(self): ) def __str__(self): - import awkward._prettyprint - - return awkward._prettyprint.valuestr(self, 1, 80) + return prettyprint_valuestr(self, 1, 80) def __repr__(self): return self._repr(80) def _repr(self, limit_cols): - import awkward._prettyprint - pytype = type(self).__name__ typestr = repr(str(self.type))[1:-1] @@ -2189,7 +2180,7 @@ def _repr(self, limit_cols): limit_cols - len(pytype) - len(" type='...'") - 3, ), ) - valuestr = valuestr + " " + awkward._prettyprint.valuestr(self, 1, strwidth) + valuestr = valuestr + " " + prettyprint_valuestr(self, 1, strwidth) length = max(3, limit_cols - len(pytype) - len("type='...'") - len(valuestr)) if len(typestr) > length: @@ -2229,10 +2220,8 @@ def show( key is ignored; instead, a `"bytes"` and/or `"str"` key is considered when formatting string values, falling back upon `"str_kind"`. 
""" - import awkward._prettyprint - formatter_impl = Formatter(formatter, precision=precision) - valuestr = awkward._prettyprint.valuestr( + valuestr = prettyprint_valuestr( self, limit_rows, limit_cols, formatter=formatter_impl ) if type: diff --git a/src/awkward/prettyprint.py b/src/awkward/prettyprint.py new file mode 100644 index 0000000000..fb2b8199f7 --- /dev/null +++ b/src/awkward/prettyprint.py @@ -0,0 +1,435 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + +import math +import re +from collections.abc import Callable + +import awkward as ak +from awkward._layout import wrap_layout +from awkward._nplikes.numpy import Numpy, NumpyMetadata +from awkward._typing import TYPE_CHECKING, Any, TypeAlias, TypedDict + +if TYPE_CHECKING: + from awkward.contents.content import Content + + +FormatterType: TypeAlias = "Callable[[Any], str]" + + +class FormatterOptions(TypedDict, total=False): + bool: FormatterType + int: FormatterType + timedelta: FormatterType + datetime: FormatterType + float: FormatterType + longfloat: FormatterType + complexfloat: FormatterType + longcomplexfloat: FormatterType + numpystr: FormatterType + object: FormatterType + all: FormatterType + int_kind: FormatterType + float_kind: FormatterType + complex_kind: FormatterType + str_kind: FormatterType + str: FormatterType + bytes: FormatterType + + +np = NumpyMetadata.instance() +numpy = Numpy.instance() + + +def half(integer: int) -> int: + return int(math.ceil(integer / 2)) + + +def alternate(length: int): + halfindex = half(length) + forward = iter(range(halfindex)) + backward = iter(range(length - 1, halfindex - 1, -1)) + going_forward, going_backward = True, True + while going_forward or going_backward: + if going_forward: + try: + yield True, next(forward) + except StopIteration: + going_forward = False + if going_backward: + try: + yield False, next(backward) + except StopIteration: + going_backward = False + + 
+is_identifier = re.compile(r"^[A-Za-z_][A-Za-z_0-9]*$") + + +# avoid recursion in which ak.Array.__getitem__ calls prettyprint +# to form an error string: private reimplementation of ak.Array.__getitem__ + + +class PlaceholderValue: + def __str__(self): + return "??" + + +def get_at(data: Content, index: int): + if data._layout._is_getitem_at_placeholder(): + return PlaceholderValue() + out = data._layout._getitem_at(index) + if isinstance(out, ak.contents.NumpyArray): + array_param = out.parameter("__array__") + if array_param == "byte": + return ak._util.tobytes(out._raw(numpy)) + elif array_param == "char": + return ak._util.tobytes(out._raw(numpy)).decode(errors="surrogateescape") + if isinstance(out, (ak.contents.Content, ak.record.Record)): + return wrap_layout(out, data._behavior) + else: + return out + + +def get_field(data: Content, field: str): + out = data._layout._getitem_field(field) + if isinstance(out, ak.contents.NumpyArray): + array_param = out.parameter("__array__") + if array_param == "byte": + return ak._util.tobytes(out._raw(numpy)) + elif array_param == "char": + return ak._util.tobytes(out._raw(numpy)).decode(errors="surrogateescape") + if isinstance(out, (ak.contents.Content, ak.record.Record)): + return wrap_layout(out, data._behavior) + else: + return out + + +def custom_str(current: Any) -> str | None: + if ( + issubclass(type(current), ak.highlevel.Record) + and type(current).__str__ is not ak.highlevel.Record.__str__ + ) or ( + issubclass(type(current), ak.highlevel.Array) + and type(current).__str__ is not ak.highlevel.Array.__str__ + ): + return str(current) + + elif ( + issubclass(type(current), ak.highlevel.Record) + and type(current).__repr__ is not ak.highlevel.Record.__repr__ + ) or ( + issubclass(type(current), ak.highlevel.Array) + and type(current).__repr__ is not ak.highlevel.Array.__repr__ + ): + return repr(current) + + else: + return None + + +def valuestr_horiz( + data: Any, limit_cols: int, formatter: Formatter +) -> 
tuple[int, list[str]]: + if isinstance(data, (ak.highlevel.Array, ak.highlevel.Record)) and ( + not data.layout.backend.nplike.known_data + ): + if isinstance(data, ak.highlevel.Array): + return 5, ["[...]"] + + original_limit_cols = limit_cols + + if isinstance(data, ak.highlevel.Array): + front, back = ["["], ["]"] + limit_cols -= 2 + + if len(data) == 0: + return 2, front + back + + elif len(data) == 1: + cols_taken, strs = valuestr_horiz(get_at(data, 0), limit_cols, formatter) + return 2 + cols_taken, front + strs + back + + else: + limit_cols -= 5 # anticipate the ", ..." + which = 0 + for forward, index in alternate(len(data)): + current = get_at(data, index) + if forward: + for_comma = 0 if which == 0 else 2 + cols_taken, strs = valuestr_horiz( + current, limit_cols - for_comma, formatter + ) + + custom = custom_str(current) + if custom is not None: + strs = custom + + if limit_cols - (for_comma + cols_taken) >= 0: + if which != 0: + front.append(", ") + limit_cols -= 2 + front.extend(strs) + limit_cols -= cols_taken + else: + break + else: + cols_taken, strs = valuestr_horiz( + current, limit_cols - 2, formatter + ) + + custom = custom_str(current) + if custom is not None: + strs = custom + + if limit_cols - (2 + cols_taken) >= 0: + back[:0] = strs + back.insert(0, ", ") + limit_cols -= 2 + cols_taken + else: + break + + which += 1 + + if which == 0: + front.append("...") + limit_cols -= 3 + elif which != len(data): + front.append(", ...") + limit_cols -= 5 + + limit_cols += 5 # credit the ", ..." + return original_limit_cols - limit_cols, front + back + + elif isinstance(data, ak.highlevel.Record): + is_tuple = data.layout.is_tuple + + front = ["("] if is_tuple else ["{"] + limit_cols -= 2 # both the opening and closing brackets + limit_cols -= 5 # anticipate the ", ..." 
+ + which = 0 + fields = data.fields + for key in fields: + for_comma = 0 if which == 0 else 2 + if is_tuple: + key_str = "" + else: + if is_identifier.match(key) is None: + key_str = repr(key) + ": " + if key_str.startswith("u"): + key_str = key_str[1:] + else: + key_str = key + ": " + + if limit_cols - (for_comma + len(key_str) + 3) >= 0: + if which != 0: + front.append(", ") + limit_cols -= 2 + front.append(key_str) + limit_cols -= len(key_str) + which += 1 + + target = limit_cols if len(fields) == 1 else half(limit_cols) + cols_taken, strs = valuestr_horiz( + get_field(data, key), target, formatter + ) + if limit_cols - cols_taken >= 0: + front.extend(strs) + limit_cols -= cols_taken + else: + front.append("...") + limit_cols -= 3 + break + + else: + break + + which += 1 + + if len(fields) != 0: + if which == 0: + front.append("...") + limit_cols -= 3 + elif which != 2 * len(fields): + front.append(", ...") + limit_cols -= 5 + + limit_cols += 5 # credit the ", ..." + front.append(")" if is_tuple else "}") + return original_limit_cols - limit_cols, front + + else: + out = formatter(data) + return len(out), [out] + + +class Formatter: + def __init__(self, formatters: FormatterOptions | None = None, precision: int = 3): + self._formatters: FormatterOptions = formatters or {} + self._precision: int = precision + self._cache: dict[type, FormatterType] = {} + + def __call__(self, obj: Any) -> str: + try: + impl = self._cache[type(obj)] + except KeyError: + impl = self._find_formatter_impl(type(obj)) + self._cache[type(obj)] = impl + return impl(obj) + + def _format_complex(self, data: complex) -> str: + return f"{data.real:.{self._precision}g}+{data.imag:.{self._precision}g}j" + + def _format_real(self, data: float) -> str: + return f"{data:.{self._precision}g}" + + def _find_formatter_impl(self, cls: type) -> FormatterType: + if issubclass(cls, np.bool_): + try: + return self._formatters["bool"] + except KeyError: + return str + elif issubclass(cls, np.integer): + 
try: + return self._formatters["int"] + except KeyError: + return self._formatters.get("int_kind", str) + elif issubclass(cls, (np.float64, np.float32)): + try: + return self._formatters["float"] + except KeyError: + return self._formatters.get("float_kind", self._format_real) + elif hasattr(np, "float128") and issubclass(cls, np.float128): + try: + return self._formatters["longfloat"] + except KeyError: + return self._formatters.get("float_kind", self._format_real) + elif issubclass(cls, (np.complex64, np.complex128)): + try: + return self._formatters["complexfloat"] + except KeyError: + return self._formatters.get("complex_kind", self._format_complex) + elif hasattr(np, "complex256") and issubclass(cls, np.complex256): + try: + return self._formatters["longcomplexfloat"] + except KeyError: + return self._formatters.get("complex_kind", self._format_complex) + elif issubclass(cls, np.datetime64): + try: + return self._formatters["datetime"] + except KeyError: + return str + elif issubclass(cls, np.timedelta64): + try: + return self._formatters["timedelta"] + except KeyError: + return str + elif issubclass(cls, str): + try: + return self._formatters["str"] + except KeyError: + return self._formatters.get("str_kind", repr) + elif issubclass(cls, bytes): + try: + return self._formatters["bytes"] + except KeyError: + return self._formatters.get("str_kind", repr) + else: + return str + + +def valuestr( + data: Any, limit_rows: int, limit_cols: int, formatter: Formatter | None = None +) -> str: + if formatter is None: + formatter = Formatter() + + if isinstance(data, (ak.highlevel.Array, ak.highlevel.Record)) and ( + not data.layout.backend.nplike.known_data + ): + if isinstance(data, ak.highlevel.Array): + return "[...]" + + if limit_rows <= 1: + _, strs = valuestr_horiz(data, limit_cols, formatter) + return "".join(strs) + + elif isinstance(data, ak.highlevel.Array): + front, back = [], [] + which = 0 + for forward, index in alternate(len(data)): + _, strs = 
valuestr_horiz(get_at(data, index), limit_cols - 2, formatter) + if forward: + front.append("".join(strs)) + else: + back.insert(0, "".join(strs)) + + which += 1 + if which >= limit_rows: + break + + if len(data) != 0 and which != len(data): + back[0] = "..." + + out = front + back + for i, val in enumerate(out): + if i > 0: + val = out[i] = " " + val + else: + val = out[i] = "[" + val + if i < len(out) - 1: + out[i] = val + "," + else: + out[i] = val + "]" + + return "\n".join(out) + + elif isinstance(data, ak.highlevel.Record): + is_tuple = data.layout.is_tuple + + front = [] + + which = 0 + fields = data.fields + for key in fields: + if is_tuple: + key_str = "" + else: + if is_identifier.match(key) is None: + key_str = repr(key) + ": " + if key_str.startswith("u"): + key_str = key_str[1:] + else: + key_str = key + ": " + _, strs = valuestr_horiz( + get_field(data, key), limit_cols - 2 - len(key_str), formatter + ) + front.append(key_str + "".join(strs)) + + which += 1 + if which >= limit_rows: + break + + if len(fields) != 0 and which != len(fields): + front[-1] = "..." 
+ + out = front + for i, val in enumerate(out): + if i > 0: + val = out[i] = " " + val + elif data.is_tuple: + val = out[i] = "(" + val + else: + val = out[i] = "{" + val + if i < len(out) - 1: + out[i] = val + "," + elif data.is_tuple: + out[i] = val + ")" + else: + out[i] = val + "}" + return "\n".join(out) + + else: + raise AssertionError(type(data)) diff --git a/src/awkward/types/recordtype.py b/src/awkward/types/recordtype.py index 8a58f8372f..2695343432 100644 --- a/src/awkward/types/recordtype.py +++ b/src/awkward/types/recordtype.py @@ -6,7 +6,7 @@ from collections.abc import Iterable, Mapping import awkward as ak -import awkward._prettyprint +import awkward.prettyprint from awkward._behavior import find_record_typestr from awkward._parameters import parameters_are_equal, type_parameters_equal from awkward._regularize import is_integer @@ -115,7 +115,7 @@ def _str(self, indent: str, compact: bool, behavior: Mapping | None) -> list[str if name is not None: if ( - not ak._prettyprint.is_identifier.match(name) + not ak.prettyprint.is_identifier.match(name) or name in ( "unknown", @@ -143,7 +143,7 @@ def _str(self, indent: str, compact: bool, behavior: Mapping | None) -> list[str if not self.is_tuple: pairs = [] for k, v in zip(self._fields, children): - if ak._prettyprint.is_identifier.match(k) is None: + if ak.prettyprint.is_identifier.match(k) is None: key_str = json.dumps(k) else: key_str = k diff --git a/tests/test_2856_export_remove_structure_and_prettyprint.py b/tests/test_2856_export_remove_structure_and_prettyprint.py new file mode 100644 index 0000000000..0af5a92602 --- /dev/null +++ b/tests/test_2856_export_remove_structure_and_prettyprint.py @@ -0,0 +1,30 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE + +from __future__ import annotations + + +def test_prettyprint_rename(): + import awkward._prettyprint as deprecated_prettyprint + import awkward.prettyprint as new_prettyprint + + assert 
new_prettyprint.Formatter is deprecated_prettyprint.Formatter + assert new_prettyprint.FormatterOptions is deprecated_prettyprint.FormatterOptions + assert new_prettyprint.FormatterType is deprecated_prettyprint.FormatterType + assert new_prettyprint.PlaceholderValue is deprecated_prettyprint.PlaceholderValue + assert new_prettyprint.alternate is deprecated_prettyprint.alternate + assert new_prettyprint.custom_str is deprecated_prettyprint.custom_str + assert new_prettyprint.get_at is deprecated_prettyprint.get_at + assert new_prettyprint.get_field is deprecated_prettyprint.get_field + assert new_prettyprint.half is deprecated_prettyprint.half + assert new_prettyprint.is_identifier is deprecated_prettyprint.is_identifier + assert new_prettyprint.valuestr is deprecated_prettyprint.valuestr + assert new_prettyprint.valuestr_horiz is deprecated_prettyprint.valuestr_horiz + + +def test_remove_structure_rename(): + from awkward._do import remove_structure as deprecated_remove_structure + from awkward.contents.remove_structure import ( + remove_structure as new_remove_structure, + ) + + assert new_remove_structure is deprecated_remove_structure