Skip to content

Commit

Permalink
allow arbitrary dict structures for parameters, values, and processes
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Dec 1, 2023
1 parent 8c79443 commit 7a0b815
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 54 deletions.
4 changes: 1 addition & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -39,9 +39,7 @@ jax.config.update("jax_enable_x64", True)

# define a simple model with two processes and two parameters
class MyModel(dlx.Model):
def __call__(
self, processes: dict, parameters: dict[str, dlx.Parameter]
) -> dlx.Result:
def __call__(self, processes: dict, parameters: dict) -> dlx.Result:
res = dlx.Result()

# signal
Expand Down
7 changes: 1 addition & 6 deletions examples/model.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,12 @@
from __future__ import annotations

import jax
import jax.numpy as jnp

import dilax as dlx


class SPlusBModel(dlx.Model):
def __call__(
self,
processes: dict,
parameters: dict[str, jax.Array],
) -> dlx.Result:
def __call__(self, processes: dict, parameters: dict) -> dlx.Result:
res = dlx.Result()

mu_modifier = dlx.modifier(
Expand Down
10 changes: 5 additions & 5 deletions src/dilax/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ class NLL(BaseModule):
def logpdf(self, *args, **kwargs) -> jax.Array:
return jax.scipy.stats.poisson.logpmf(*args, **kwargs)

def __call__(self, values: dict[str, jax.Array] | Sentinel = _NoValue) -> jax.Array:
def __call__(self, values: dict | Sentinel = _NoValue) -> jax.Array:
if values is _NoValue:
values = self.model.parameter_values
model = self.model.update(values=values)
Expand All @@ -68,11 +68,11 @@ def __init__(self, model: Model, observation: jax.Array) -> None:
super().__init__(model=model, observation=observation)
self.NLL = NLL(model=model, observation=observation)

def __call__(self, values: dict[str, jax.Array] | Sentinel = _NoValue) -> jax.Array:
def __call__(self, values: dict | Sentinel = _NoValue) -> jax.Array:
if values is _NoValue:
values = self.model.parameter_values
if TYPE_CHECKING:
values = cast(dict[str, jax.Array], values)
values = cast(dict, values)
hessian = jax.hessian(self.NLL, argnums=0)(values)
hessian, _ = jax.tree_util.tree_flatten(hessian)
hessian = jnp.array(hessian)
Expand All @@ -85,7 +85,7 @@ class CovMatrix(Hessian):
Covariance matrix.
"""

def __call__(self, values: dict[str, jax.Array] | Sentinel = _NoValue) -> jax.Array:
def __call__(self, values: dict | Sentinel = _NoValue) -> jax.Array:
if values is _NoValue:
values = self.model.parameter_values
hessian = super().__call__(values=values)
Expand All @@ -105,7 +105,7 @@ def __init__(self, model: Model, observation: jax.Array) -> None:

def __call__(
self,
values: dict[str, jax.Array] | Sentinel = _NoValue,
values: dict | Sentinel = _NoValue,
key: jax.Array | Sentinel = _NoValue,
) -> dict[str, jax.Array]:
if values is _NoValue:
Expand Down
100 changes: 60 additions & 40 deletions src/dilax/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import jax.numpy as jnp

from dilax.parameter import Parameter
from dilax.util import Sentinel, _NoValue
from dilax.util import Sentinel, _NoValue, deep_update

__all__ = [
"Result",
Expand All @@ -34,6 +34,14 @@ def expectation(self) -> jax.Array:
return cast(jax.Array, sum(jax.tree_util.tree_leaves(self.expectations)))


def _is_parameter(leaf: Any) -> bool:
return isinstance(leaf, Parameter)


def _is_none_or_is_parameter(leaf: Any) -> bool:
return leaf is None or _is_parameter(leaf)


class Model(eqx.Module):
"""
A model describing nuisance parameters, templates (histograms), and how they interact.
Expand All @@ -52,7 +60,7 @@ class Model(eqx.Module):
# Define a simple model with two processes and two parameters
class MyModel(dlx.Model):
def __call__(self, processes: dict, parameters: dict[str, dlx.Parameter]) -> dlx.Result:
def __call__(self, processes: dict, parameters: dict) -> dlx.Result:
res = dlx.Result()
# signal
Expand Down Expand Up @@ -101,7 +109,7 @@ def eval(model) -> jax.Array:
def __init__(
self,
processes: dict,
parameters: dict[str, Parameter],
parameters: dict,
auxiliary: Any | Sentinel = _NoValue,
) -> None:
self.processes = processes
Expand All @@ -111,68 +119,80 @@ def __init__(
self.auxiliary = auxiliary

@property
def parameter_values(self) -> dict[str, jax.Array]:
return {key: param.value for key, param in self.parameters.items()}

def parameter_constraints(self) -> dict[str, jax.Array]:
constraints = {}
for name, param in self.parameters.items():
# skip if the parameter was not used / has no constraint
if not param.constraints:
continue
if not len(param.constraints) <= 1:
msg = f"More than one constraint per parameter is not allowed. Got: {param.constraint}"
raise ValueError(msg)
constraint = next(iter(param.constraints))
constraints[name] = constraint.logpdf(param.value)
return constraints
def parameter_values(self) -> dict:
return jax.tree_util.tree_map(
lambda l: l.value, # noqa: E741
self.parameters,
is_leaf=_is_parameter,
)

def parameter_constraints(self) -> dict:
def _constraint(param: Parameter) -> jax.Array:
if param.constraints:
if len(param.constraints) > 1:
msg = f"More than one constraint per parameter is not allowed. Got: {param.constraint}"
raise ValueError(msg)
return next(iter(param.constraints)).logpdf(param.value)
return jnp.array([0.0])

return jax.tree_util.tree_map(
_constraint,
self.parameters,
is_leaf=_is_parameter,
)

def update(
self,
processes: dict | Sentinel = _NoValue,
values: dict[str, jax.Array] | Sentinel = _NoValue,
values: dict | Sentinel = _NoValue,
) -> Model:
if values is _NoValue:
values = {}
if processes is _NoValue:
processes = {}

if TYPE_CHECKING:
values = cast(dict[str, jax.Array], values)
values = cast(dict, values)
processes = cast(dict, processes)

# patch original processes with new ones
new_processes = {}
for key, old_process in self.processes.items():
if key in processes:
new_process = processes[key]
new_processes[key] = new_process
else:
new_processes[key] = old_process
new_processes = deep_update(self.processes, processes)

# patch original parameters with new ones
new_parameters = {}
for key, old_parameter in self.parameters.items():
if key in values:
new_parameter = old_parameter.update(value=values[key])
new_parameters[key] = new_parameter
else:
new_parameters[key] = old_parameter
_updates = deep_update(
jax.tree_util.tree_map(
lambda _: None, self.parameters, is_leaf=_is_parameter
),
values,
)

def _update_params(update: jax.Array | None, param: Parameter) -> Parameter:
if update is None:
return param
return param.update(value=update)

new_parameters = jax.tree_util.tree_map(
_update_params,
_updates,
self.parameters,
is_leaf=_is_none_or_is_parameter,
)

return eqx.tree_at(
lambda t: (t.processes, t.parameters), self, (new_processes, new_parameters)
)

def nll_boundary_penalty(self) -> jax.Array:
penalty = jnp.array([0.0])

for param in self.parameters.values():
penalty += param.boundary_penalty
params = jax.tree_util.tree_leaves(self.parameters, is_leaf=_is_parameter)

return penalty
return sum(
jax.tree_util.tree_map(
lambda p: p.boundary_penalty, params, is_leaf=_is_parameter
)
)

@abc.abstractmethod
def __call__(self, processes: dict, parameters: dict[str, Parameter]) -> Result:
def __call__(self, processes: dict, parameters: dict) -> Result:
...

def evaluate(self) -> Result:
Expand Down
18 changes: 18 additions & 0 deletions src/dilax/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"as1darray",
"dump_hlo_graph",
"dump_jaxpr",
"deep_update",
]


Expand Down Expand Up @@ -328,3 +329,20 @@ def f(x: jax.Array) -> jax.Array:
filepath.write_text(dump_hlo_graph(f, x), encoding='ascii')
"""
return jax.xla_computation(fun)(*args, **kwargs).as_hlo_dot_graph()


def deep_update(
mapping: dict[K, Any],
new_mapping: dict[K, Any],
) -> dict[K, Any]:
updated_mapping = mapping.copy()
for k, v in new_mapping.items():
if (
k in updated_mapping
and isinstance(updated_mapping[k], dict)
and isinstance(v, dict)
):
updated_mapping[k] = deep_update(updated_mapping[k], v)
else:
updated_mapping[k] = v
return updated_mapping

0 comments on commit 7a0b815

Please sign in to comment.