Skip to content

Commit

Permalink
feat: add .attrs to highlevel objects (#2757)
Browse files Browse the repository at this point in the history
* feat: add initial implementation of attrs

* refactor: move pickling to private module

* fix: implement support for disabling custom pickle in tests

* test: test attrs

* test: fix old test usage (unrelated)

* test: importorskip arrow (unrelated)

* test: raise TypeError for `to_regular`

* feat: change prefix to `@`

* test: update test

* fix: use of `ctx.behavior`

* refactor: store `attrs` on `NumbaLookup`

* fix: more removals
  • Loading branch information
agoose77 authored Nov 8, 2023
1 parent 9ee586d commit 8a2fa20
Show file tree
Hide file tree
Showing 163 changed files with 3,260 additions and 1,968 deletions.
44 changes: 44 additions & 0 deletions src/awkward/_attrs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
from __future__ import annotations

from collections.abc import Mapping

from awkward._typing import Any, JSONMapping


def attrs_of_obj(obj, attrs: Mapping | None = None) -> Mapping | None:
from awkward.highlevel import Array, ArrayBuilder, Record

if attrs is not None:
return attrs
elif isinstance(obj, (Array, Record, ArrayBuilder)):
return obj._attrs
else:
return None


def attrs_of(*arrays, attrs: Mapping | None = None) -> Mapping:
# An explicit 'attrs' always wins.
if attrs is not None:
return attrs

copied = False
for x in reversed(arrays):
x_attrs = attrs_of_obj(x)
if x_attrs is None:
continue
if attrs is None:
attrs = x_attrs
elif attrs is x_attrs:
pass
elif not copied:
attrs = dict(attrs)
attrs.update(x_attrs)
copied = True
else:
attrs.update(x_attrs)
return attrs


def without_transient_attrs(attrs: dict[str, Any]) -> JSONMapping:
return {k: v for k, v in attrs.items() if not k.startswith("@")}
8 changes: 4 additions & 4 deletions src/awkward/_backends/dispatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ def backend_of_obj(obj, default: D | Sentinel = UNSET) -> Backend | D:


def backend_of(
*objects, default: D | Sentinel = UNSET, coerce_to_common: bool = False
*objects, default: D | Sentinel = UNSET, coerce_to_common: bool = True
) -> Backend | D:
"""
Args:
Expand All @@ -116,9 +116,9 @@ def backend_of(
return common_backend(unique_backends)
else:
raise ValueError(
"could not find singular backend for",
objects,
"and coercion is not permitted",
f"could not find singular backend for "
f"{', '.join(type(t).__name__ for t in objects)} "
f"and coercion is not permitted",
)


Expand Down
69 changes: 48 additions & 21 deletions src/awkward/_connect/numba/arrayview.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@
from numba.core.errors import NumbaTypeError

import awkward as ak
from awkward._behavior import behavior_of, overlay_behavior
from awkward._layout import wrap_layout
from awkward._behavior import overlay_behavior
from awkward._layout import HighLevelContext, wrap_layout
from awkward._lookup import Lookup
from awkward._nplikes.numpy_like import NumpyMetadata

np = NumpyMetadata.instance()
Expand Down Expand Up @@ -152,7 +153,17 @@ def to_numbatype(form):
########## Lookup


@numba.extending.typeof_impl.register(ak._lookup.Lookup)
class NumbaLookup(Lookup):
def __init__(self, layout, attrs, generator=None):
super().__init__(layout, generator=generator)
self._attrs = attrs

@property
def attrs(self):
return self._attrs


@numba.extending.typeof_impl.register(NumbaLookup)
def typeof_Lookup(obj, c):
return LookupType()

Expand Down Expand Up @@ -192,15 +203,21 @@ def unbox_Lookup(lookuptype, lookupobj, c):
class ArrayView:
@classmethod
def fromarray(cls, array):
behavior = behavior_of(array)
layout = ak.operations.to_layout(
array, allow_record=False, allow_unknown=False, primitive_policy="error"
)
with HighLevelContext() as ctx:
layout = ctx.unwrap(
array,
allow_record=False,
allow_unknown=False,
use_from_iter=False,
primitive_policy="error",
string_policy="error",
none_policy="error",
)

return ArrayView(
to_numbatype(layout.form),
behavior,
ak._lookup.Lookup(layout),
ctx.behavior,
NumbaLookup(layout, ctx.attrs),
0,
0,
len(layout),
Expand All @@ -219,7 +236,7 @@ def __init__(self, type, behavior, lookup, pos, start, stop, fields):
def toarray(self):
layout = self.type.tolayout(self.lookup, self.pos, self.fields)
sliced = layout._getitem_range(self.start, self.stop)
return wrap_layout(sliced, self.behavior)
return wrap_layout(sliced, behavior=self.behavior, attrs=self.lookup.attrs)


@numba.extending.typeof_impl.register(ArrayView)
Expand Down Expand Up @@ -579,20 +596,28 @@ def lower_iternext(context, builder, sig, args, result):
class RecordView:
@classmethod
def fromrecord(cls, record):
behavior = behavior_of(record)
layout = ak.operations.to_layout(
record, allow_record=True, allow_unknown=False, primitive_policy="error"
)
with HighLevelContext() as ctx:
layout = ctx.unwrap(
record,
allow_record=True,
allow_unknown=False,
use_from_iter=False,
primitive_policy="error",
string_policy="error",
none_policy="error",
)
array_layout = layout.array

assert isinstance(layout, ak.record.Record)
arraylayout = layout.array

return RecordView(
ArrayView(
to_numbatype(arraylayout.form),
behavior,
ak._lookup.Lookup(arraylayout),
to_numbatype(array_layout.form),
ctx.behavior,
NumbaLookup(array_layout, ctx.attrs),
0,
0,
len(arraylayout),
len(array_layout),
(),
),
layout.at,
Expand All @@ -603,9 +628,11 @@ def __init__(self, arrayview, at):
self.at = at

def torecord(self):
arraylayout = self.arrayview.toarray().layout
array = self.arrayview.toarray()
return wrap_layout(
ak.record.Record(arraylayout, self.at), self.arrayview.behavior
ak.record.Record(array.layout, self.at),
behavior=self.arrayview.behavior,
attrs=array.attrs,
)


Expand Down
18 changes: 16 additions & 2 deletions src/awkward/_connect/numba/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,18 @@ def __init__(self, behavior):
@numba.extending.register_model(ArrayBuilderType)
class ArrayBuilderModel(numba.core.datamodel.models.StructModel):
def __init__(self, dmm, fe_type):
members = [("rawptr", numba.types.voidptr), ("pyptr", numba.types.pyobject)]
members = [
("rawptr", numba.types.voidptr),
("pyptr", numba.types.pyobject),
("pyattrs", numba.types.pyobject),
]
super().__init__(dmm, fe_type, members)


@numba.core.imputils.lower_constant(ArrayBuilderType)
def lower_const_ArrayBuilder(context, builder, arraybuildertype, arraybuilder):
layout = arraybuilder._layout
attrs = arraybuilder._attrs
rawptr = context.get_constant(numba.intp, arraybuilder._layout._ptr)
proxyout = context.make_helper(builder, arraybuildertype)
proxyout.rawptr = builder.inttoptr(
Expand All @@ -52,20 +57,26 @@ def lower_const_ArrayBuilder(context, builder, arraybuildertype, arraybuilder):
proxyout.pyptr = context.add_dynamic_addr(
builder, id(layout), info=str(type(layout))
)
proxyout.pyattrs = context.add_dynamic_addr(
builder, id(attrs), info=str(type(attrs))
)
return proxyout._getvalue()


@numba.extending.unbox(ArrayBuilderType)
def unbox_ArrayBuilder(arraybuildertype, arraybuilderobj, c):
attrs_obj = c.pyapi.object_getattr_string(arraybuilderobj, "_attrs")
inner_obj = c.pyapi.object_getattr_string(arraybuilderobj, "_layout")
rawptr_obj = c.pyapi.object_getattr_string(inner_obj, "_ptr")

proxyout = c.context.make_helper(c.builder, arraybuildertype)
proxyout.rawptr = c.pyapi.long_as_voidptr(rawptr_obj)
proxyout.pyptr = inner_obj
proxyout.pyattrs = attrs_obj

c.pyapi.decref(inner_obj)
c.pyapi.decref(rawptr_obj)
c.pyapi.decref(attrs_obj)

is_error = numba.core.cgutils.is_not_null(c.builder, c.pyapi.err_occurred())
return numba.extending.NativeValue(proxyout._getvalue(), is_error)
Expand All @@ -90,8 +101,11 @@ def box_ArrayBuilder(arraybuildertype, arraybuilderval, c):

proxyin = c.context.make_helper(c.builder, arraybuildertype, arraybuilderval)
c.pyapi.incref(proxyin.pyptr)
attrs_obj = proxyin.pyattrs

out = c.pyapi.call_method(ArrayBuilder_obj, "_wrap", (proxyin.pyptr, behavior_obj))
out = c.pyapi.call_method(
ArrayBuilder_obj, "_wrap", (proxyin.pyptr, behavior_obj, attrs_obj)
)

c.pyapi.decref(ArrayBuilder_obj)
c.pyapi.decref(behavior_obj)
Expand Down
10 changes: 2 additions & 8 deletions src/awkward/_connect/numexpr.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,10 +83,7 @@ def evaluate(
names, ex_uses_vml = numexpr.necompiler._names_cache[expr_key]
arguments = getArguments(names, local_dict, global_dict)

arrays = [
ak.operations.to_layout(x, allow_record=True, allow_unknown=True)
for x in arguments
]
arrays = [ak.operations.to_layout(x, allow_unknown=True) for x in arguments]

def action(inputs, **ignore):
if all(
Expand Down Expand Up @@ -131,10 +128,7 @@ def re_evaluate(local_dict=None):
names = numexpr.necompiler._numexpr_last["argnames"]
arguments = getArguments(names, local_dict)

arrays = [
ak.operations.to_layout(x, allow_record=True, allow_unknown=True)
for x in arguments
]
arrays = [ak.operations.to_layout(x, allow_unknown=True) for x in arguments]

def action(inputs, **ignore):
if all(
Expand Down
14 changes: 10 additions & 4 deletions src/awkward/_connect/numpy.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,14 @@ def _to_rectilinear(arg, backend: Backend):
return arg


def array_function(func, types, args, kwargs: dict[str, Any], behavior: Mapping | None):
def array_function(
func,
types,
args,
kwargs: dict[str, Any],
behavior: Mapping | None,
attrs: Mapping[str, Any] | None = None,
):
function = implemented.get(func)
if function is not None:
return function(*args, **kwargs)
Expand All @@ -106,13 +113,13 @@ def array_function(func, types, args, kwargs: dict[str, Any], behavior: Mapping
result,
allow_record=True,
allow_unknown=True,
allow_none=True,
none_policy="pass-through",
regulararray=True,
use_from_iter=True,
primitive_policy="pass-through",
string_policy="pass-through",
)
return wrap_layout(out, behavior=behavior, allow_other=True)
return wrap_layout(out, behavior=behavior, allow_other=True, attrs=attrs)


def implements(numpy_function):
Expand Down Expand Up @@ -152,7 +159,6 @@ def _array_ufunc_custom_cast(inputs, behavior: Mapping | None, backend):
cast_fcn = find_custom_cast(x, behavior)
maybe_layout = ak.operations.to_layout(
x if cast_fcn is None else cast_fcn(x),
allow_record=True,
allow_unknown=True,
primitive_policy="pass-through",
string_policy="pass-through",
Expand Down
Loading

0 comments on commit 8a2fa20

Please sign in to comment.