From ec1006c10e8afb9cff9278db607e052877fa6f95 Mon Sep 17 00:00:00 2001 From: Lindsey Gray Date: Fri, 14 Apr 2023 11:21:23 -0500 Subject: [PATCH] get _ensure_systematics daskified --- src/coffea/nanoevents/methods/base.py | 44 ++++++++++++++++++++++----- 1 file changed, 36 insertions(+), 8 deletions(-) diff --git a/src/coffea/nanoevents/methods/base.py b/src/coffea/nanoevents/methods/base.py index 3e84252d78..6af9411534 100644 --- a/src/coffea/nanoevents/methods/base.py +++ b/src/coffea/nanoevents/methods/base.py @@ -18,9 +18,12 @@ class _ClassMethodFn: def __init__(self, attr: str, **kwargs: Any) -> None: self.attr = attr + self.kwargs = kwargs def __call__(self, coll: awkward.Array, *args: Any, **kwargs: Any) -> awkward.Array: - return getattr(coll, self.attr)(*args, **kwargs) + allkwargs = self.kwargs + allkwargs.update(kwargs) + return getattr(coll, self.attr)(*args, **allkwargs) @awkward.mixin_class(behavior) @@ -36,12 +39,35 @@ def add_kind(cls, kind: str): """ cls._systematic_kinds.add(kind) - def _ensure_systematics(self): + def _ensure_systematics(self, _dask_array_=None): """ Make sure that the parent object always has a field called '__systematics__'. """ if "__systematics__" not in awkward.fields(self): - self["__systematics__"] = {} + if _dask_array_ is not None: + x = awkward.Array( + awkward.Array([{}]).layout.to_typetracer(forget_length=True) + ) + _dask_array_._meta["__systematics__"] = x + + def add_systematics_hack(array): + if awkward.backend(array) == "typetracer": + array["__systematics__"] = x + return array + array["__systematics__"] = {} + return array + + temp = dask_awkward.map_partitions( + add_systematics_hack, + _dask_array_, + label="ensure-systematics", + meta=_dask_array_._meta, + ) + _dask_array_._meta = temp._meta + _dask_array_._dask = temp._dask + _dask_array_._name = temp._name + else: + self["__systematics__"] = {} @property def systematics(self): @@ -109,11 +135,13 @@ def add_systematic( print("vf ", varying_function) print("da ", _dask_array_, type(_dask_array_)) _dask_array_.map_partitions( - _ClassMethodFn("add_systematic"), - name, - kind, - what, - varying_function, + _ClassMethodFn( + "add_systematic", + name=name, + kind=kind, + varying_function=varying_function, + ), + what=what, ) self._ensure_systematics()