Skip to content

Commit

Permalink
get _ensure_systematics daskified
Browse files Browse the repository at this point in the history
  • Loading branch information
lgray committed Jan 13, 2024
1 parent 1a5c7ca commit ec1006c
Showing 1 changed file with 36 additions and 8 deletions.
44 changes: 36 additions & 8 deletions src/coffea/nanoevents/methods/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,12 @@
class _ClassMethodFn:
def __init__(self, attr: str, **kwargs: Any) -> None:
self.attr = attr
self.kwargs = kwargs

def __call__(self, coll: awkward.Array, *args: Any, **kwargs: Any) -> awkward.Array:
return getattr(coll, self.attr)(*args, **kwargs)
allkwargs = self.kwargs
allkwargs.update(kwargs)
return getattr(coll, self.attr)(*args, **allkwargs)


@awkward.mixin_class(behavior)
Expand All @@ -36,12 +39,35 @@ def add_kind(cls, kind: str):
"""
cls._systematic_kinds.add(kind)

def _ensure_systematics(self):
def _ensure_systematics(self, _dask_array_=None):
"""
Make sure that the parent object always has a field called '__systematics__'.
"""
if "__systematics__" not in awkward.fields(self):
self["__systematics__"] = {}
if _dask_array_ is not None:
x = awkward.Array(
awkward.Array([{}]).layout.to_typetracer(forget_length=True)
)
_dask_array_._meta["__systematics__"] = x

def add_systematics_hack(array):
if awkward.backend(array) == "typetracer":
array["__systematics__"] = x
return array
array["__systematics__"] = {}
return array

temp = dask_awkward.map_partitions(
add_systematics_hack,
_dask_array_,
label="ensure-systematics",
meta=_dask_array_._meta,
)
_dask_array_._meta = temp._meta
_dask_array_._dask = temp._dask
_dask_array_._name = temp._name
else:
self["__systematics__"] = {}

@property
def systematics(self):
Expand Down Expand Up @@ -109,11 +135,13 @@ def add_systematic(
print("vf ", varying_function)
print("da ", _dask_array_, type(_dask_array_))
_dask_array_.map_partitions(
_ClassMethodFn("add_systematic"),
name,
kind,
what,
varying_function,
_ClassMethodFn(
"add_systematic",
name=name,
kind=kind,
varying_function=varying_function,
),
what=what,
)

self._ensure_systematics()
Expand Down

0 comments on commit ec1006c

Please sign in to comment.