Commit bc2766e
update README.md Example
pfackeldey committed Mar 11, 2024
1 parent 524e66c commit bc2766e
Showing 1 changed file with 35 additions and 66 deletions.
README.md (101 changes: 35 additions & 66 deletions)

@@ -35,79 +35,48 @@ _evermore_ in a nutshell:

```python
import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array

import evermore as evm

jax.config.update("jax_enable_x64", True)


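# --- removed in this commit: the old evm.Model-based example ---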
# define a simple model with two processes and two parameters
class MyModel(evm.Model):
    def __call__(self, processes: dict, parameters: dict) -> evm.Result:
        res = evm.Result()

        # signal
        mu_mod = evm.modifier(
            name="mu", parameter=parameters["mu"], effect=evm.effect.unconstrained()
        )
        res.add(process="signal", expectation=mu_mod(processes["signal"]))

        # background
        bkg_mod = evm.modifier(
            name="sigma", parameter=parameters["sigma"], effect=evm.effect.gauss(0.2)
        )
        res.add(process="background", expectation=bkg_mod(processes["background"]))
        return res


# setup model
processes = {"signal": jnp.array([10.0]), "background": jnp.array([50.0])}
parameters = {
    "mu": evm.Parameter(value=jnp.array([1.0]), bounds=(0.0, jnp.inf)),
    "sigma": evm.Parameter(value=jnp.array([0.0])),
}
model = MyModel(processes=processes, parameters=parameters)

# define negative log-likelihood with data (observation)
nll = evm.likelihood.NLL(model=model, observation=jnp.array([64.0]))
# jit it!
fast_nll = eqx.filter_jit(nll)

# setup fit: initial values of parameters and a suitable optimizer
init_values = model.parameter_values
optimizer = evm.optimizer.JaxOptimizer.make(
    name="ScipyMinimize", settings={"method": "trust-constr"}
)

# fit
values, state = optimizer.fit(fun=fast_nll, init_values=init_values)

print(values)
# -> {'mu': Array([1.4], dtype=float64),
#     'sigma': Array([4.04723836e-14], dtype=float64)}

# eval model with fitted values
print(model.update(values=values).evaluate().expectation())
# -> Array([64.], dtype=float64)


# gradients of "prefit" model:
print(eqx.filter_grad(nll)({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([-0.12258065], dtype=float64)}


# gradients of "postfit" model:
@eqx.filter_grad
@eqx.filter_jit
def grad_postfit_nll(where: dict[str, jax.Array]) -> dict[str, jax.Array]:
    nll = evm.likelihood.NLL(
        model=model.update(values=values), observation=jnp.array([64.0])
    )
    return nll(where)


print(grad_postfit_nll({"sigma": jnp.array([0.2])}))
# -> {'sigma': Array([0.5030303], dtype=float64)}
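

# --- added in this commit: the new eqx.Module-based example ---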
class Model(eqx.Module):
    mu: evm.Parameter
    syst: evm.Parameter

    def __call__(self, hists: dict[str, Array]) -> Array:
        mu_modifier = self.mu.unconstrained()
        syst_modifier = self.syst.lnN(width=jnp.array([0.9, 1.1]))
        return mu_modifier(hists["signal"]) + syst_modifier(hists["bkg"])


nll = evm.loss.PoissonNLL()


def loss(model: Model, hists: dict[str, Array], observation: Array) -> Array:
    expectation = model(hists)
    # Poisson NLL of the expectation and observation
    log_likelihood = nll(expectation, observation)
    # add parameter constraints from logpdfs
    constraints = evm.loss.get_param_constraints(model)
    log_likelihood += evm.util.sum_leaves(constraints)
    return -jnp.sum(log_likelihood)


# setup model and data
hists = {"signal": jnp.array([3]), "bkg": jnp.array([10])}
observation = jnp.array([15])
model = Model(mu=evm.Parameter(1.0), syst=evm.Parameter(0.0))

# negative log-likelihood
loss_val = loss(model, hists, observation)
# gradients of negative log-likelihood w.r.t. model parameters
grads = eqx.filter_grad(loss)(model, hists, observation)
print(f"{grads.mu.value=}, {grads.syst.value=}")
# -> grads.mu.value=Array([-0.46153846]), grads.syst.value=Array([-0.15436207])
```
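The updated example stops at the gradient; to actually run a fit you would minimize `loss`. Below is a minimal sketch (not part of this commit) of a plain gradient-descent loop over the `model`, `hists`, and `observation` defined above, using `eqx.filter_grad` and `eqx.apply_updates`; the learning rate and step count are arbitrary choices, and a dedicated optimizer library (e.g. `optax`) would be the usual choice in practice.

```python
import jax.tree_util as jtu

# sketch only: plain gradient descent on the loss from the example above;
# the learning rate (0.05) and number of steps (100) are arbitrary choices
@eqx.filter_jit
def step(model: Model, hists: dict[str, Array], observation: Array) -> Model:
    grads = eqx.filter_grad(loss)(model, hists, observation)
    # move every differentiable leaf a small step against its gradient
    updates = jtu.tree_map(lambda g: -0.05 * g, grads)
    return eqx.apply_updates(model, updates)


for _ in range(100):
    model = step(model, hists, observation)

print(f"{model.mu.value=}, {model.syst.value=}")
```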

## Contributing