Merge pull request #11 from pfackeldey/rethink/models_as_pytrees
Be closer to equinox's philosophy of PyTrees
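
The thrust of the change: a model is now just an `eqx.Module`, and `eqx.Module`s are PyTrees, so parameters become tree leaves that generic JAX/equinox transformations can reach directly. A minimal sketch of what that means in practice (`TinyModel` is hypothetical; `evm.Parameter(value=...)` is used as in the diffs below):

```python
import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu

import evermore as evm


class TinyModel(eqx.Module):
    # eqx.Module fields are PyTree children: the parameter below is a
    # subtree of the model, with no separate parameter registry needed.
    mu: evm.Parameter

    def __init__(self) -> None:
        self.mu = evm.Parameter(value=jnp.array([1.0]))


model = TinyModel()

# Generic PyTree tooling now applies to the model directly:
print(jtu.tree_leaves(model))  # the raw array leaves, including mu's value
```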
Showing 30 changed files with 4,107 additions and 1,555 deletions.
CI workflow: the coverage upload and example-run steps are commented out (the pinned codecov-action version did not survive extraction and is left as `<version>`):

```diff
@@ -51,9 +51,9 @@ jobs:
         python -m pytest -ra --cov --cov-report=xml --cov-report=term
         --durations=20
-      - name: Upload coverage report
-        uses: codecov/codecov-action@<version>
+      # - name: Upload coverage report
+      #   uses: codecov/codecov-action@<version>
 
-      - name: Test examples
-        run: >-
-          for f in examples/*.py; do echo "run $f" && python "$f"; done
+      # - name: Test examples
+      #   run: >-
+      #     for f in examples/*.py; do echo "run $f" && python "$f"; done
```
New example file (68 lines): an s+b model with per-bin MC statistical uncertainties handled via `evm.parameter.staterrors`:

```python
from __future__ import annotations

import equinox as eqx
import jax.numpy as jnp
import jax.tree_util as jtu
from jaxtyping import Array, PyTree

import evermore as evm


class SPlusBModel(eqx.Module):
    mu: evm.Parameter
    norm1: evm.Parameter
    norm2: evm.Parameter
    staterrors: PyTree[evm.Parameter]

    def __init__(self, hists: dict[str, Array]) -> None:
        self.mu = evm.Parameter(value=1.0, lower=0.0, upper=10.0)
        self.staterrors = evm.parameter.staterrors(hists=hists)
        self = evm.parameter.auto_init(self)

    def __call__(self, hists: dict, histsw2: dict) -> dict[str, Array]:
        expectations = {}

        # calculate widths of the sum of the nominal histograms for gaussian MC stat
        sqrtw2 = jtu.tree_map(jnp.sqrt, histsw2)
        widths = evm.util.sum_leaves(sqrtw2) / evm.util.sum_leaves(hists)
        gauss_mcstat = self.staterrors["gauss"].gauss(widths)
        # barlow-beeston-like condition: above 10 use gauss, below use poisson
        mask = evm.util.sum_leaves(hists) > 10

        # signal process
        signal_poisson = self.staterrors["poisson"]["signal"].poisson(hists["signal"])
        signal_mc_stats = evm.modifier.where(mask, gauss_mcstat, signal_poisson)
        mu_mod = self.mu.unconstrained()
        expectations["signal"] = (signal_mc_stats @ mu_mod)(hists["signal"])

        # bkg1 process
        bkg1_poisson = self.staterrors["poisson"]["bkg1"].poisson(hists["bkg1"])
        bkg1_mc_stats = evm.modifier.where(mask, gauss_mcstat, bkg1_poisson)
        norm1_mod = self.norm1.lnN(jnp.array([0.9, 1.1]))
        expectations["bkg1"] = (bkg1_mc_stats @ norm1_mod)(hists["bkg1"])

        # bkg2 process
        bkg2_poisson = self.staterrors["poisson"]["bkg2"].poisson(hists["bkg2"])
        bkg2_mc_stats = evm.modifier.where(mask, gauss_mcstat, bkg2_poisson)
        norm2_mod = self.norm2.lnN(jnp.array([0.95, 1.05]))
        expectations["bkg2"] = (bkg2_mc_stats @ norm2_mod)(hists["bkg2"])

        # return the modified expectations
        return expectations


hists = {
    "signal": jnp.array([3]),
    "bkg1": jnp.array([10]),
    "bkg2": jnp.array([20]),
}
histsw2 = {
    "signal": jnp.array([5]),
    "bkg1": jnp.array([11]),
    "bkg2": jnp.array([25]),
}

model = SPlusBModel(hists)

# test the model
expectations = model(hists, histsw2)
```
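
The `evm.modifier.where(mask, gauss_mcstat, ..._poisson)` calls implement the Barlow-Beeston-like switch named in the comment: per bin, use the Gaussian MC-stat constraint where the summed bin content exceeds 10, the Poisson one otherwise. A pure-JAX sketch of that selection logic (the scale factors below are made-up numbers; evermore's actual modifier internals may differ):

```python
import jax.numpy as jnp

# hypothetical per-bin scale factors produced by the two constraints
gauss_sf = jnp.array([1.02, 0.97, 1.01])    # Gaussian MC-stat variation
poisson_sf = jnp.array([1.10, 0.85, 1.05])  # Poisson MC-stat variation

# Barlow-Beeston-like switch: Gaussian where the summed bin content is large
summed = jnp.array([25.0, 4.0, 33.0])
mask = summed > 10

sf = jnp.where(mask, gauss_sf, poisson_sf)   # per-bin selection
modified = sf * jnp.array([20.0, 3.0, 30.0])  # apply to a nominal histogram
```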
New example file (45 lines): a linear layer whose weights are an `evm.Parameter` with a Gaussian constraint, mixing evermore into an ordinary equinox network:

```python
import equinox as eqx
import jax
import jax.numpy as jnp

import evermore as evm


class LinearConstrained(eqx.Module):
    weights: evm.Parameter
    biases: jax.Array

    def __init__(self, in_size, out_size, key):
        wkey, bkey = jax.random.split(key)
        # weights
        constraint = evm.pdf.Gauss(
            mean=jnp.zeros((out_size, in_size)),
            width=jnp.full((out_size, in_size), 0.5),
        )
        self.weights = evm.Parameter(
            value=jax.random.normal(wkey, (out_size, in_size)), constraint=constraint
        )
        # biases
        self.biases = jax.random.normal(bkey, (out_size,))

    def __call__(self, x: jax.Array):
        return self.weights.value @ x + self.biases


@eqx.filter_jit
def loss_fn(model, x, y):
    pred_y = jax.vmap(model)(x)
    mse = jax.numpy.mean((y - pred_y) ** 2)
    constraints = evm.loss.get_param_constraints(model)
    # sum them all up for each weight
    constraints = jax.tree_util.tree_map(jnp.sum, constraints)
    return mse + evm.util.sum_leaves(constraints)


batch_size, in_size, out_size = 32, 2, 3
model = LinearConstrained(in_size, out_size, key=jax.random.PRNGKey(0))
x = jax.numpy.zeros((batch_size, in_size))
y = jax.numpy.zeros((batch_size, out_size))
loss_val = loss_fn(model, x, y)
grads = eqx.filter_grad(loss_fn)(model, x, y)
```
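
Because the gradients returned by `eqx.filter_grad` share the model's tree structure, a standard optax loop works without evermore-specific glue. A hedged sketch building on `loss_fn` above (plain optax/equinox idioms; nothing here is evermore API):

```python
import optax

optimizer = optax.adam(learning_rate=1e-3)
# initialize optimizer state on the differentiable leaves of the model
opt_state = optimizer.init(eqx.filter(model, eqx.is_inexact_array))


@eqx.filter_jit
def train_step(model, opt_state, x, y):
    grads = eqx.filter_grad(loss_fn)(model, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    model = eqx.apply_updates(model, updates)  # returns a new model PyTree
    return model, opt_state


model, opt_state = train_step(model, opt_state, x, y)
```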
Example fit script, rewritten for the new API: the likelihood becomes a stateless `evm.loss.PoissonNLL`, and the model itself is what gets differentiated:

```diff
@@ -1,31 +1,24 @@
 from __future__ import annotations
 
 import equinox as eqx
-from jax import config
-from model import init_values, model, observation, optimizer
+import jax.numpy as jnp
+from model import hists, model, observation
 
 import evermore as evm
 
-config.update("jax_enable_x64", True)
-
 # create negative log likelihood
-nll = evm.likelihood.NLL(model=model, observation=observation)
+nll = evm.loss.PoissonNLL()
 
-# fit
-params, state = optimizer.fit(fun=nll, init_values=init_values)
 
-# gradients of nll of fitted model
-fast_grad_nll = eqx.filter_jit(eqx.filter_grad(nll))
-grads = fast_grad_nll(params)
-# gradients of nll of fitted model only wrt to `mu`
-# basically: pass the parameters dict of which you want the gradients
-params_ = {k: v for k, v in params.items() if k == "mu"}
-grad_mu = fast_grad_nll(params_)
+@eqx.filter_jit
+def loss(model, hists, observation):
+    expectations = model(hists)
+    constraints = evm.loss.get_param_constraints(model)
+    loss_val = nll(
+        expectation=evm.util.sum_leaves(expectations),
+        observation=observation,
+    )
+    # add constraint
+    loss_val += evm.util.sum_leaves(constraints)
+    return -jnp.sum(loss_val)
 
-# hessian + cov_matrix of nll of fitted model
-hessian = eqx.filter_jit(evm.likelihood.Hessian(model=model, observation=observation))()
-
-# covariance matrix of fitted model
-covmatrix = eqx.filter_jit(
-    evm.likelihood.CovMatrix(model=model, observation=observation)
-)()
+
+loss_val = loss(model, hists, observation)
+grads = eqx.filter_grad(loss)(model, hists, observation)
```
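
The old `evm.likelihood.Hessian`/`CovMatrix` helpers disappear in this rewrite; with the model as a PyTree, an approximate covariance can still be recovered with plain JAX by flattening the differentiable leaves into one vector. A sketch under that assumption (not evermore API, purely illustrative; `model`, `hists`, `observation`, and `loss` are taken from the example above):

```python
import equinox as eqx
import jax
import jax.flatten_util
import jax.numpy as jnp

# Split the model into differentiable leaves and static structure,
# then ravel the leaves so jax.hessian yields a dense matrix.
params, static = eqx.partition(model, eqx.is_inexact_array)
flat, unravel = jax.flatten_util.ravel_pytree(params)


def flat_loss(flat_params):
    m = eqx.combine(unravel(flat_params), static)
    return loss(m, hists, observation)


hess = jax.hessian(flat_loss)(flat)
cov = jnp.linalg.inv(hess)  # covariance ~ inverse Hessian at the minimum
```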
The example model definition: `evm.Model`/`evm.Result` and the external parameter dict are replaced by a self-contained `eqx.Module`:

```diff
@@ -1,94 +1,77 @@
 from __future__ import annotations
 
+import equinox as eqx
 import jax.numpy as jnp
+from jaxtyping import Array
 
 import evermore as evm
 
 
-class SPlusBModel(evm.Model):
-    def __call__(self, processes: dict, parameters: dict) -> evm.Result:
-        res = evm.Result()
+class SPlusBModel(eqx.Module):
+    mu: evm.Parameter
+    norm1: evm.Parameter
+    norm2: evm.Parameter
+    shape1: evm.Parameter
 
-        mu_modifier = evm.modifier(
-            name="mu", parameter=parameters["mu"], effect=evm.effect.unconstrained()
-        )
-        res.add(
-            process="signal",
-            expectation=mu_modifier(processes[("signal", "nominal")]),
-        )
+    def __init__(self, hist: dict[str, Array], histw2: dict[str, Array]) -> None:
+        self.mu = evm.Parameter(value=jnp.array([1.0]))
+        self = evm.parameter.auto_init(self)
 
-        bkg1_modifier = evm.compose(
-            evm.modifier(
-                name="lnN1",
-                parameter=parameters["norm1"],
-                effect=evm.effect.lnN((0.9, 1.1)),
-            ),
-            evm.modifier(
-                name="shape1_bkg1",
-                parameter=parameters["shape1"],
-                effect=evm.effect.shape(
-                    up=processes[("background1", "shape_up")],
-                    down=processes[("background1", "shape_down")],
-                ),
-            ),
-        )
-        res.add(
-            process="background1",
-            expectation=bkg1_modifier(processes[("background1", "nominal")]),
-        )
+    def __call__(self, hists: dict) -> dict[str, Array]:
+        expectations = {}
 
-        bkg2_modifier = evm.compose(
-            evm.modifier(
-                name="lnN2",
-                parameter=parameters["norm2"],
-                effect=evm.effect.lnN((0.95, 1.05)),
-            ),
-            evm.modifier(
-                name="shape1_bkg2",
-                parameter=parameters["shape1"],
-                effect=evm.effect.shape(
-                    up=processes[("background2", "shape_up")],
-                    down=processes[("background2", "shape_down")],
-                ),
-            ),
-        )
-        res.add(
-            process="background2",
-            expectation=bkg2_modifier(processes[("background2", "nominal")]),
-        )
-        return res
+        # signal process
+        sig_mod = self.mu.unconstrained()
+        expectations["signal"] = sig_mod(hists["nominal"]["signal"])
+
+        # bkg1 process
+        bkg1_lnN = self.norm1.lnN(width=jnp.array([0.9, 1.1]))
+        bkg1_shape = self.shape1.shape(
+            up=hists["shape_up"]["bkg1"],
+            down=hists["shape_down"]["bkg1"],
+        )
+        # combine modifiers
+        bkg1_mod = bkg1_lnN @ bkg1_shape
+        expectations["bkg1"] = bkg1_mod(hists["nominal"]["bkg1"])
+
+        # bkg2 process
+        bkg2_lnN = self.norm2.lnN(width=jnp.array([0.95, 1.05]))
+        bkg2_shape = self.shape1.shape(
+            up=hists["shape_up"]["bkg2"],
+            down=hists["shape_down"]["bkg2"],
+        )
+        # combine modifiers
+        bkg2_mod = bkg2_lnN @ bkg2_shape
+        expectations["bkg2"] = bkg2_mod(hists["nominal"]["bkg2"])
+
+        # return the modified expectations
+        return expectations
 
-def create_model():
-    processes = {
-        ("signal", "nominal"): jnp.array([3]),
-        ("background1", "nominal"): jnp.array([10]),
-        ("background2", "nominal"): jnp.array([20]),
-        ("background1", "shape_up"): jnp.array([12]),
-        ("background1", "shape_down"): jnp.array([8]),
-        ("background2", "shape_up"): jnp.array([23]),
-        ("background2", "shape_down"): jnp.array([19]),
-    }
-    parameters = {
-        "mu": evm.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
-        "norm1": evm.Parameter(value=jnp.array([0.0])),
-        "norm2": evm.Parameter(value=jnp.array([0.0])),
-        "shape1": evm.Parameter(value=jnp.array([0.0])),
-    }
-
-    # return model
-    return SPlusBModel(processes=processes, parameters=parameters)
+hists = {
+    "nominal": {
+        "signal": jnp.array([3]),
+        "bkg1": jnp.array([10]),
+        "bkg2": jnp.array([20]),
+    },
+    "shape_up": {
+        "bkg1": jnp.array([12]),
+        "bkg2": jnp.array([23]),
+    },
+    "shape_down": {
+        "bkg1": jnp.array([8]),
+        "bkg2": jnp.array([19]),
+    },
+}
+
+hist = hists["nominal"]
+histw2 = {
+    "signal": jnp.array([5]),
+    "bkg1": jnp.array([11]),
+    "bkg2": jnp.array([25]),
+}
 
-model = create_model()
+model = SPlusBModel(hist, histw2)
 
-init_values = model.parameter_values
 observation = jnp.array([37])
-asimov = model.evaluate().expectation()
-
-
-# create optimizer (from `jaxopt`)
-optimizer = evm.optimizer.JaxOptimizer.make(
-    name="LBFGS",
-    settings={"maxiter": 5, "jit": True, "unroll": True},
-)
+expectations = model(hists)
```
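
A side benefit of the `eqx.Module` rewrite is model surgery with stock equinox tools, e.g. fixing a parameter by replacing its leaf. A sketch (assumes the `.value` field seen above; illustrative only):

```python
import equinox as eqx
import jax.numpy as jnp

# Fix the signal strength to 1.0 by swapping out its leaf. eqx.tree_at
# returns a new model; the original is untouched, since PyTrees are
# immutable -- exactly the equinox philosophy this PR adopts.
fixed_mu_model = eqx.tree_at(lambda m: m.mu.value, model, jnp.array([1.0]))

expectations_fixed = fixed_mu_model(hists)
```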