Skip to content

Commit

Permalink
Merge pull request #11 from pfackeldey/rethink/models_as_pytrees
Browse files Browse the repository at this point in the history
Be more close to equinox' philosophy of PyTrees
  • Loading branch information
pfackeldey authored Mar 11, 2024
2 parents 1855676 + 578c237 commit 37b955c
Show file tree
Hide file tree
Showing 30 changed files with 4,107 additions and 1,555 deletions.
10 changes: 5 additions & 5 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ jobs:
python -m pytest -ra --cov --cov-report=xml --cov-report=term
--durations=20
- name: Upload coverage report
uses: codecov/[email protected]
# - name: Upload coverage report
# uses: codecov/[email protected]

- name: Test examples
run: >-
for f in examples/*.py; do echo "run $f" && python "$f"; done
# - name: Test examples
# run: >-
# for f in examples/*.py; do echo "run $f" && python "$f"; done
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,10 @@ Thumbs.db
# if examples are run
examples/*.eqx

# analysis optimisation example
examples/analysis_opt.py
plots_analysis_opt/

test/
.vscode/
.zed/
Expand Down
2 changes: 2 additions & 0 deletions docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
"use_edit_page_button": True,
}
html_context = {"default_mode": "light"}
html_logo = "../assets/favicon.png"
html_favicon = "../assets/favicon.png"

extensions = [
"sphinx.ext.autodoc",
Expand Down
68 changes: 68 additions & 0 deletions examples/bin_by_bin_uncs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
from __future__ import annotations

import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import Array, PyTree

import evermore as evm


class SPlusBModel(eqx.Module):
mu: evm.Parameter
norm1: evm.Parameter
norm2: evm.Parameter
staterrors: PyTree[evm.Parameter]

def __init__(self, hists: dict[str, Array]) -> None:
self.mu = evm.Parameter(value=1.0, lower=0.0, upper=10.0)
self.staterrors = evm.parameter.staterrors(hists=hists)
self = evm.parameter.auto_init(self)

def __call__(self, hists: dict, histsw2: dict) -> dict[str, Array]:
expectations = {}

# calculate widths of the sum of the nominal histograms for gaussian MC stat
sqrtw2 = jtu.tree_map(jnp.sqrt, histsw2)
widths = evm.util.sum_leaves(sqrtw2) / evm.util.sum_leaves(hists)
gauss_mcstat = self.staterrors["gauss"].gauss(widths)
# barlow-beeston-like condition: above 10 use gauss, below use poisson
mask = evm.util.sum_leaves(hists) > 10

# signal process
signal_poisson = self.staterrors["poisson"]["signal"].poisson(hists["signal"])
signal_mc_stats = evm.modifier.where(mask, gauss_mcstat, signal_poisson)
mu_mod = self.mu.unconstrained()
expectations["signal"] = (signal_mc_stats @ mu_mod)(hists["signal"])

# bkg1 process
bkg1_poisson = self.staterrors["poisson"]["bkg1"].poisson(hists["bkg1"])
bkg1_mc_stats = evm.modifier.where(mask, gauss_mcstat, bkg1_poisson)
norm1_mod = self.norm1.lnN(jnp.array([0.9, 1.1]))
expectations["bkg1"] = (bkg1_mc_stats @ norm1_mod)(hists["bkg1"])

# bkg2 process
bkg2_poisson = self.staterrors["poisson"]["bkg2"].poisson(hists["bkg2"])
bkg2_mc_stats = evm.modifier.where(mask, gauss_mcstat, bkg2_poisson)
norm2_mod = self.norm2.lnN(jnp.array([0.95, 1.05]))
expectations["bkg2"] = (bkg2_mc_stats @ norm2_mod)(hists["bkg2"])

# return the modified expectations
return expectations


hists = {
"signal": jnp.array([3]),
"bkg1": jnp.array([10]),
"bkg2": jnp.array([20]),
}
histsw2 = {
"signal": jnp.array([5]),
"bkg1": jnp.array([11]),
"bkg2": jnp.array([25]),
}

model = SPlusBModel(hists)

# test the model
expectations = model(hists, histsw2)
45 changes: 45 additions & 0 deletions examples/dnn_weights_constraint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import equinox as eqx
import jax
import jax.numpy as jnp

import evermore as evm


class LinearConstrained(eqx.Module):
weights: evm.Parameter
biases: jax.Array

def __init__(self, in_size, out_size, key):
wkey, bkey = jax.random.split(key)
# weights
constraint = evm.pdf.Gauss(
mean=jnp.zeros((out_size, in_size)),
width=jnp.full((out_size, in_size), 0.5),
)
self.weights = evm.Parameter(
value=jax.random.normal(wkey, (out_size, in_size)), constraint=constraint
)

# biases
self.biases = jax.random.normal(bkey, (out_size,))

def __call__(self, x: jax.Array):
return self.weights.value @ x + self.biases


@eqx.filter_jit
def loss_fn(model, x, y):
pred_y = jax.vmap(model)(x)
mse = jax.numpy.mean((y - pred_y) ** 2)
constraints = evm.loss.get_param_constraints(model)
# sum them all up for each weight
constraints = jax.tree_util.tree_map(jnp.sum, constraints)
return mse + evm.util.sum_leaves(constraints)


batch_size, in_size, out_size = 32, 2, 3
model = LinearConstrained(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))
loss_val = loss_fn(model, x, y)
grads = eqx.filter_grad(loss_fn)(model, x, y)
39 changes: 16 additions & 23 deletions examples/grad_nll.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,24 @@
from __future__ import annotations

import equinox as eqx
from jax import config
from model import init_values, model, observation, optimizer
import jax.numpy as jnp
from model import hists, model, observation

import evermore as evm

config.update("jax_enable_x64", True)

# create negative log likelihood
nll = evm.likelihood.NLL(model=model, observation=observation)
nll = evm.loss.PoissonNLL()

# fit
params, state = optimizer.fit(fun=nll, init_values=init_values)

# gradients of nll of fitted model
fast_grad_nll = eqx.filter_jit(eqx.filter_grad(nll))
grads = fast_grad_nll(params)
# gradients of nll of fitted model only wrt to `mu`
# basically: pass the parameters dict of which you want the gradients
params_ = {k: v for k, v in params.items() if k == "mu"}
grad_mu = fast_grad_nll(params_)
@eqx.filter_jit
def loss(model, hists, observation):
expectations = model(hists)
constraints = evm.loss.get_param_constraints(model)
loss_val = nll(
expectation=evm.util.sum_leaves(expectations),
observation=observation,
)
# add constraint
loss_val += evm.util.sum_leaves(constraints)
return -jnp.sum(loss_val)

# hessian + cov_matrix of nll of fitted model
hessian = eqx.filter_jit(evm.likelihood.Hessian(model=model, observation=observation))()

# covariance matrix of fitted model
covmatrix = eqx.filter_jit(
evm.likelihood.CovMatrix(model=model, observation=observation)
)()
loss_val = loss(model, hists, observation)
grads = eqx.filter_grad(loss)(model, hists, observation)
133 changes: 58 additions & 75 deletions examples/model.py
Original file line number Diff line number Diff line change
@@ -1,94 +1,77 @@
from __future__ import annotations

import equinox as eqx
import jax.numpy as jnp
from jaxtyping import Array

import evermore as evm


class SPlusBModel(evm.Model):
def __call__(self, processes: dict, parameters: dict) -> evm.Result:
res = evm.Result()
class SPlusBModel(eqx.Module):
mu: evm.Parameter
norm1: evm.Parameter
norm2: evm.Parameter
shape1: evm.Parameter

mu_modifier = evm.modifier(
name="mu", parameter=parameters["mu"], effect=evm.effect.unconstrained()
)
res.add(
process="signal",
expectation=mu_modifier(processes[("signal", "nominal")]),
)
def __init__(self, hist: dict[str, Array], histw2: dict[str, Array]) -> None:
self.mu = evm.Parameter(value=jnp.array([1.0]))
self = evm.parameter.auto_init(self)

bkg1_modifier = evm.compose(
evm.modifier(
name="lnN1",
parameter=parameters["norm1"],
effect=evm.effect.lnN((0.9, 1.1)),
),
evm.modifier(
name="shape1_bkg1",
parameter=parameters["shape1"],
effect=evm.effect.shape(
up=processes[("background1", "shape_up")],
down=processes[("background1", "shape_down")],
),
),
)
res.add(
process="background1",
expectation=bkg1_modifier(processes[("background1", "nominal")]),
)
def __call__(self, hists: dict) -> dict[str, Array]:
expectations = {}

bkg2_modifier = evm.compose(
evm.modifier(
name="lnN2",
parameter=parameters["norm2"],
effect=evm.effect.lnN((0.95, 1.05)),
),
evm.modifier(
name="shape1_bkg2",
parameter=parameters["shape1"],
effect=evm.effect.shape(
up=processes[("background2", "shape_up")],
down=processes[("background2", "shape_down")],
),
),
# signal process
sig_mod = self.mu.unconstrained()
expectations["signal"] = sig_mod(hists["nominal"]["signal"])

# bkg1 process
bkg1_lnN = self.norm1.lnN(width=jnp.array([0.9, 1.1]))
bkg1_shape = self.shape1.shape(
up=hists["shape_up"]["bkg1"],
down=hists["shape_down"]["bkg1"],
)
res.add(
process="background2",
expectation=bkg2_modifier(processes[("background2", "nominal")]),
# combine modifiers
bkg1_mod = bkg1_lnN @ bkg1_shape
expectations["bkg1"] = bkg1_mod(hists["nominal"]["bkg1"])

# bkg2 process
bkg2_lnN = self.norm2.lnN(width=jnp.array([0.95, 1.05]))
bkg2_shape = self.shape1.shape(
up=hists["shape_up"]["bkg2"],
down=hists["shape_down"]["bkg2"],
)
return res
# combine modifiers
bkg2_mod = bkg2_lnN @ bkg2_shape
expectations["bkg2"] = bkg2_mod(hists["nominal"]["bkg2"])

# return the modified expectations
return expectations

def create_model():
processes = {
("signal", "nominal"): jnp.array([3]),
("background1", "nominal"): jnp.array([10]),
("background2", "nominal"): jnp.array([20]),
("background1", "shape_up"): jnp.array([12]),
("background1", "shape_down"): jnp.array([8]),
("background2", "shape_up"): jnp.array([23]),
("background2", "shape_down"): jnp.array([19]),
}
parameters = {
"mu": evm.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
"norm1": evm.Parameter(value=jnp.array([0.0])),
"norm2": evm.Parameter(value=jnp.array([0.0])),
"shape1": evm.Parameter(value=jnp.array([0.0])),
}

# return model
return SPlusBModel(processes=processes, parameters=parameters)
hists = {
"nominal": {
"signal": jnp.array([3]),
"bkg1": jnp.array([10]),
"bkg2": jnp.array([20]),
},
"shape_up": {
"bkg1": jnp.array([12]),
"bkg2": jnp.array([23]),
},
"shape_down": {
"bkg1": jnp.array([8]),
"bkg2": jnp.array([19]),
},
}

hist = hists["nominal"]
histw2 = {
"signal": jnp.array([5]),
"bkg1": jnp.array([11]),
"bkg2": jnp.array([25]),
}

model = create_model()
model = SPlusBModel(hist, histw2)

init_values = model.parameter_values
observation = jnp.array([37])
asimov = model.evaluate().expectation()


# create optimizer (from `jaxopt`)
optimizer = evm.optimizer.JaxOptimizer.make(
name="LBFGS",
settings={"maxiter": 5, "jit": True, "unroll": True},
)
expectations = model(hists)
Loading

0 comments on commit 37b955c

Please sign in to comment.