Skip to content

Commit

Permalink
add more meaningful checks
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Nov 24, 2023
1 parent e2aa07a commit 21acd81
Showing 1 changed file with 37 additions and 1 deletion.
38 changes: 37 additions & 1 deletion src/dilax/modifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,9 @@ def __check_init__(self):
"must be the same."
)
raise ValueError(msg)
if not self.threshold > 0.0:
msg = f"Threshold must be >= 0.0, got: {self.threshold}"
raise ValueError(msg)

def scale_factor(self, sumw: jax.Array) -> jax.Array:
from functools import partial
Expand Down Expand Up @@ -293,10 +296,43 @@ class Mode(eqx.Enumeration):

sumw: dict[str, jax.Array]
sumw2: dict[str, jax.Array]
masks: dict[str, jax.Array]
threshold: float = 10.0
mode: str = Mode.barlow_beeston_lite
key_template: str = "__staterror_{process}__"

def __init__(
self,
sumw: dict[str, jax.Array],
sumw2: dict[str, jax.Array],
threshold: float = 10.0,
mode: str = Mode.barlow_beeston_lite,
key_template: str = "__staterror_{process}__",
) -> None:
self.sumw = sumw
self.sumw2 = sumw2
self.masks = {p: _sumw < threshold for p, _sumw in sumw.items()}
self.threshold = threshold
self.mode = mode
self.key_template = key_template

def __check_init__(self):
if jax.tree_util.tree_structure(self.sumw) != jax.tree_util.tree_structure(
self.sumw2
): # type: ignore[operator]
msg = (
"The structure of `sumw` and `sumw2` needs to be identical, got "
f"`sumw`: {jax.tree_util.tree_structure(self.sumw)}) and "
f"`sumw2`: {jax.tree_util.tree_structure(self.sumw2)})"
)
raise ValueError(msg)
if not self.threshold > 0.0:
msg = f"Threshold must be >= 0.0, got: {self.threshold}"
raise ValueError(msg)
if not isinstance(self.mode, self.Mode):
msg = f"Mode must be of type {self.Mode}, got: {self.mode}"
raise ValueError(msg)

def prepare(
self
) -> tuple[dict[str, dict[str, Parameter]], dict[str, dict[str, eqx.Partial]]]:
Expand Down Expand Up @@ -354,7 +390,7 @@ def prepare(
for process, _sumw in self.sumw.items():
key = self.key_template.format(process=process)
process_parameters = parameters[key] = {}
mask = _sumw < self.threshold
mask = self.masks[process]
for i in range(len(_sumw)):
pkey = f"{process}_{i}"
if self.mode == self.Mode.barlow_beeston_lite and not mask[i]:
Expand Down

0 comments on commit 21acd81

Please sign in to comment.