From fbc79e68eb1aab8982338a090f548e34d2300f1e Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Wed, 2 Oct 2024 14:15:27 -0500 Subject: [PATCH 1/4] fix: ak.typetracer.length_one_if_typetracer with option and union types --- src/awkward/forms/form.py | 64 ++++++++++++++++--- ...gth_one_if_typetracer_with_option_types.py | 14 ++++ 2 files changed, 69 insertions(+), 9 deletions(-) create mode 100644 tests/test_3264_length_one_if_typetracer_with_option_types.py diff --git a/src/awkward/forms/form.py b/src/awkward/forms/form.py index 3f9bec55eb..49082970eb 100644 --- a/src/awkward/forms/form.py +++ b/src/awkward/forms/form.py @@ -558,6 +558,52 @@ def max_prefer_unknown(this: ShapeItem, that: ShapeItem) -> ShapeItem: container = {} + def prepare_empty(form): + form_key = f"node-{len(container)}" + + if isinstance(form, (ak.forms.BitMaskedForm, ak.forms.ByteMaskedForm)): + container[form_key] = b"" + return form.copy(content=prepare_empty(form.content), form_key=form_key) + + elif isinstance(form, ak.forms.IndexedOptionForm): + container[form_key] = b"" + return form.copy(content=prepare_empty(form.content), form_key=form_key) + + elif isinstance(form, ak.forms.EmptyForm): + return form + + elif isinstance(form, ak.forms.UnmaskedForm): + return form.copy(content=prepare_empty(form.content)) + + elif isinstance(form, (ak.forms.IndexedForm, ak.forms.ListForm)): + container[form_key] = b"" + return form.copy(content=prepare_empty(form.content), form_key=form_key) + + elif isinstance(form, ak.forms.ListOffsetForm): + container[form_key] = b"" + return form.copy(content=prepare_empty(form.content), form_key=form_key) + + elif isinstance(form, ak.forms.RegularForm): + return form.copy(content=prepare_empty(form.content)) + + elif isinstance(form, ak.forms.NumpyForm): + container[form_key] = b"" + return form.copy(form_key=form_key) + + elif isinstance(form, ak.forms.RecordForm): + return form.copy(contents=[prepare_empty(x) for x in form.contents]) + + elif isinstance(form, ak.forms.UnionForm): + # both tags and index will get this buffer + container[form_key] = b"" + return form.copy( + contents=[prepare_empty(x) for x in form.contents], + form_key=form_key, + ) + + else: + raise AssertionError(f"not a Form: {form!r}") + def prepare(form, multiplier): form_key = f"node-{len(container)}" @@ -566,11 +612,13 @@ def prepare(form, multiplier): container[form_key] = b"\x00" * multiplier else: container[form_key] = b"\xff" * multiplier - return form.copy(form_key=form_key) # DO NOT RECURSE + # switch from recursing down `prepare` to `prepare_empty` + return form.copy(content=prepare_empty(form.content), form_key=form_key) elif isinstance(form, ak.forms.IndexedOptionForm): container[form_key] = b"\xff\xff\xff\xff\xff\xff\xff\xff" # -1 - return form.copy(form_key=form_key) # DO NOT RECURSE + # switch from recursing down `prepare` to `prepare_empty` + return form.copy(content=prepare_empty(form.content), form_key=form_key) elif isinstance(form, ak.forms.EmptyForm): # no error if protected by non-recursing node type @@ -624,13 +672,11 @@ def prepare(form, multiplier): elif isinstance(form, ak.forms.UnionForm): # both tags and index will get this buffer, but index is 8 bytes container[form_key] = b"\x00" * (8 * multiplier) - return form.copy( - # only recurse down contents[0] because all index == 0 - contents=( - [prepare(form.contents[0], multiplier)] + form.contents[1:] - ), - form_key=form_key, - ) + # recurse down contents[0] with `prepare`, but others with `prepare_empty` + contents = [prepare(form.contents[0], multiplier)] + for x in form.contents[1:]: + contents.append(prepare_empty(x)) + return form.copy(contents=contents, form_key=form_key) else: raise AssertionError(f"not a Form: {form!r}") diff --git a/tests/test_3264_length_one_if_typetracer_with_option_types.py b/tests/test_3264_length_one_if_typetracer_with_option_types.py new file mode 100644 index 0000000000..dda733248d --- /dev/null +++ b/tests/test_3264_length_one_if_typetracer_with_option_types.py @@ -0,0 +1,14 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE +# ruff: noqa: E402 + +from __future__ import annotations + +import awkward as ak + + +def test(): + arr = ak.Array([[1], [2, 3], [1, 2, 4, 5]])[[0, None, 2]] + l1 = ak.typetracer.length_one_if_typetracer(ak.to_backend(arr, "typetracer")) + + assert l1.to_list() == [None] + assert str(l1.type) == "1 * option[var * int64]" From eaf579467683094e662df7c3d13f4ed3b1bc9b46 Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Wed, 2 Oct 2024 14:15:47 -0500 Subject: [PATCH 2/4] forgot to add the test --- ...t_3264_length_one_if_typetracer_with_option_types.py~ | 9 +++++++++ 1 file changed, 9 insertions(+) create mode 100644 tests/test_3264_length_one_if_typetracer_with_option_types.py~ diff --git a/tests/test_3264_length_one_if_typetracer_with_option_types.py~ b/tests/test_3264_length_one_if_typetracer_with_option_types.py~ new file mode 100644 index 0000000000..515435b59c --- /dev/null +++ b/tests/test_3264_length_one_if_typetracer_with_option_types.py~ @@ -0,0 +1,9 @@ +# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE +# ruff: noqa: E402 + +from __future__ import annotations + +import awkward as ak + +def test(): + From 67be943b459aacc24d97433810d2d570de7fc35f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 2 Oct 2024 19:16:37 +0000 Subject: [PATCH 3/4] style: pre-commit fixes --- tests/test_3264_length_one_if_typetracer_with_option_types.py~ | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_3264_length_one_if_typetracer_with_option_types.py~ b/tests/test_3264_length_one_if_typetracer_with_option_types.py~ index 515435b59c..ac9a2703cd 100644 --- a/tests/test_3264_length_one_if_typetracer_with_option_types.py~ +++ b/tests/test_3264_length_one_if_typetracer_with_option_types.py~ @@ -6,4 +6,4 @@ from __future__ import annotations import awkward as ak def test(): - + From 177a46051dde944766677b34a0f03b5011c38fff Mon Sep 17 00:00:00 2001 From: Jim Pivarski Date: Wed, 2 Oct 2024 14:20:48 -0500 Subject: [PATCH 4/4] no, not the Emacs backup file --- ...t_3264_length_one_if_typetracer_with_option_types.py~ | 9 --------- 1 file changed, 9 deletions(-) delete mode 100644 tests/test_3264_length_one_if_typetracer_with_option_types.py~ diff --git a/tests/test_3264_length_one_if_typetracer_with_option_types.py~ b/tests/test_3264_length_one_if_typetracer_with_option_types.py~ deleted file mode 100644 index ac9a2703cd..0000000000 --- a/tests/test_3264_length_one_if_typetracer_with_option_types.py~ +++ /dev/null @@ -1,9 +0,0 @@ -# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE -# ruff: noqa: E402 - -from __future__ import annotations - -import awkward as ak - -def test(): -