Syntax changes/poisson pdf (#16)
* change syntax for evaluation of the Poisson pdf

  The Poisson pdf now takes the expectation directly as a value (not as the parameter); see the sketch below. The `get_log_probs` function for the parameters was updated accordingly.

* change normalization of the Poisson logpdf

  Normalization is now optional, with `normalize=True` as the default.

* [visualization] fix penzai v1 deprecation

* update Poisson pdf test
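
A minimal before/after sketch of the new call convention; the `expectation` and `observation` arrays here are hypothetical placeholders for a model output and observed counts:

```python
import jax.numpy as jnp

import evermore as evm

expectation = jnp.array([10.0, 20.0])  # hypothetical model expectation
observation = jnp.array([9.0, 22.0])  # hypothetical observed counts

# before this commit:
# log_likelihood = evm.loss.PoissonLogLikelihood()(expectation, observation)

# after this commit: the expectation is passed directly as `lamb`,
# and the logpdf is evaluated at the observed value
log_likelihood = evm.pdf.Poisson(lamb=expectation).log_prob(observation)
```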
felixzinn authored Aug 13, 2024
1 parent ecda94c commit 14f9cb2
Showing 10 changed files with 23 additions and 28 deletions.
README.md (2 changes: 1 addition & 1 deletion)

```diff
@@ -60,7 +60,7 @@ def loss(
     params = eqx.combine(diffable, static)
     expectation = model(params, hists)
     # Poisson NLL of the expectation and observation
-    log_likelihood = evm.loss.PoissonLogLikelihood()(expectation, observation)
+    log_likelihood = evm.pdf.Poisson(lamb=expectation).log_prob(observation)
     # Add parameter constraints from logpdfs
     constraints = evm.loss.get_log_probs(params)
     log_likelihood += evm.util.sum_over_leaves(constraints)
```
docs/binned_likelihood.md (6 changes: 2 additions & 4 deletions)

```diff
@@ -52,10 +52,8 @@ def NLL(dynamic_params, static_params, hists, observation):
     expectations = model(params, hists)
     # first product of Eq. 1 (Poisson term)
-    log_likelihood = evm.loss.PoissonLogLikelihood()
-    loss_val = log_likelihood(
-        expectation=evm.util.sum_over_leaves(expectations),
-        observation=observation,
+    loss_val = evm.pdf.Poisson(lamb=evm.util.sum_over_leaves(expectations)).log_prob(
+        observation
     )
     # second product of Eq. 1 (constraint)
```
docs/index.md (2 changes: 1 addition & 1 deletion)

```diff
@@ -47,7 +47,7 @@ def loss(
     params = eqx.combine(diffable, static)
     expectation = model(params, hists)
     # Poisson NLL of the expectation and observation
-    log_likelihood = evm.loss.PoissonLogLikelihood()(expectation, observation)
+    log_likelihood = evm.pdf.Poisson(expectation).log_prob(observation)
     # Add parameter constraints from logpdfs
     constraints = evm.loss.get_log_probs(params)
     log_likelihood += evm.util.sum_over_leaves(constraints)
```
examples/grad_nll.py (7 changes: 2 additions & 5 deletions)

```diff
@@ -4,16 +4,13 @@

 import evermore as evm

-log_likelihood = evm.loss.PoissonLogLikelihood()
-

 @eqx.filter_jit
 def loss(model, hists, observation):
     expectations = model(hists)
     constraints = evm.loss.get_log_probs(model)
-    loss_val = log_likelihood(
-        expectation=evm.util.sum_over_leaves(expectations),
-        observation=observation,
+    loss_val = evm.pdf.Poisson(lamb=evm.util.sum_over_leaves(expectations)).log_prob(
+        observation,
     )
     # add constraint
     loss_val += evm.util.sum_over_leaves(constraints)
```
examples/nll_fit.py (7 changes: 2 additions & 5 deletions)

```diff
@@ -9,17 +9,14 @@
 optim = optax.sgd(learning_rate=1e-2)
 opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

-log_likelihood = evm.loss.PoissonLogLikelihood()
-

 @eqx.filter_jit
 def loss(dynamic_model, static_model, hists, observation):
     model = eqx.combine(dynamic_model, static_model)
     expectations = model(hists)
     constraints = evm.loss.get_log_probs(model)
-    loss_val = log_likelihood(
-        expectation=evm.util.sum_over_leaves(expectations),
-        observation=observation,
+    loss_val = evm.pdf.Poisson(evm.util.sum_over_leaves(expectations)).log_prob(
+        observation
     )
     # add constraint
     loss_val += evm.util.sum_over_leaves(constraints)
```
examples/nll_profiling.py (9 changes: 3 additions & 6 deletions)

```diff
@@ -10,8 +10,6 @@
 def fixed_mu_fit(mu: Array) -> Array:
     from model import hists, model, observation

-    log_likelihood = evm.loss.PoissonLogLikelihood()
-
     optim = optax.sgd(learning_rate=1e-2)
     opt_state = optim.init(eqx.filter(model, eqx.is_inexact_array))

@@ -31,10 +29,9 @@ def loss(dynamic_model, static_model, hists, observation):
         model = eqx.combine(dynamic_model, static_model)
         expectations = model(hists)
         constraints = evm.loss.get_log_probs(model)
-        loss_val = log_likelihood(
-            expectation=evm.util.sum_over_leaves(expectations),
-            observation=observation,
-        )
+        loss_val = evm.pdf.Poisson(
+            lamb=evm.util.sum_over_leaves(expectations)
+        ).log_prob(observation)
         # add constraint
         loss_val += evm.util.sum_over_leaves(constraints)
         return -2 * jnp.sum(loss_val)
```
src/evermore/loss.py (4 changes: 4 additions & 0 deletions)

```diff
@@ -6,6 +6,7 @@
 import jax.numpy as jnp
 from jaxtyping import Array, PyTree

+from evermore import pdf
 from evermore.custom_types import PDFLike
 from evermore.parameter import Parameter, params_map

@@ -25,6 +26,9 @@ def _constraint(param: Parameter) -> Array:
     prior = param.prior
     if isinstance(prior, PDFLike):
         prior = cast(PDFLike, prior)
+        if isinstance(prior, pdf.Poisson):
+            # expectation for the Poisson pdf: (x + 1) * lambda
+            return prior.log_prob((param.value + 1) * prior.lamb)
         return prior.log_prob(param.value)
     return jnp.array([0.0])
```
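
The `(param.value + 1) * lamb` shift means a parameter at its default value 0 is evaluated at the nominal expectation `lamb`, where the normalized logpdf peaks. A small sketch of that identity (the `lamb` value is an arbitrary choice):

```python
import jax.numpy as jnp

import evermore as evm

poisson = evm.pdf.Poisson(lamb=jnp.array(10.0))
value = jnp.array(0.0)  # a parameter at its default value

# at value == 0 the logpdf is evaluated at lamb itself, its maximum under
# normalize=True, so the constraint contributes exactly 0 to the NLL
constraint = poisson.log_prob((value + 1) * poisson.lamb)  # == 0.0
```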
src/evermore/pdf.py (10 changes: 6 additions & 4 deletions)

```diff
@@ -46,13 +46,15 @@ def sample(self, key: PRNGKeyArray) -> Array:
 class Poisson(PDF):
     lamb: Array = eqx.field(converter=jnp.atleast_1d)

-    def log_prob(self, x: Array) -> Array:
+    def log_prob(self, x: Array, normalize: bool = True) -> Array:
         def _continous_poisson_log_prob(x, lamb):
             return xlogy(x, lamb) - lamb - gammaln(x + 1)

-        logpdf_max = _continous_poisson_log_prob(self.lamb, self.lamb)
-        unnormalized = _continous_poisson_log_prob((x + 1) * self.lamb, self.lamb)
-        return unnormalized - logpdf_max
+        unnormalized = _continous_poisson_log_prob(x, self.lamb)
+        if normalize:
+            logpdf_max = _continous_poisson_log_prob(x, x)
+            return unnormalized - logpdf_max
+        return unnormalized

     def sample(self, key: PRNGKeyArray) -> Array:
         # this samples only integers, do we want that?
```
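
With `normalize=True`, the logpdf is shifted by its maximum over `lamb`, which sits at `lamb = x`, so the normalized value is 0 exactly when the observation equals the expectation. A self-contained sketch of the same arithmetic:

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln, xlogy


def continuous_poisson_log_prob(x, lamb):
    # continuous extension of the Poisson log-pmf via the gamma function
    return xlogy(x, lamb) - lamb - gammaln(x + 1)


x, lamb = jnp.array(5.0), jnp.array(10.0)
unnormalized = continuous_poisson_log_prob(x, lamb)
# subtracting the maximum over lamb (attained at lamb = x) makes the
# normalized logpdf peak at 0 when x == lamb
normalized = unnormalized - continuous_poisson_log_prob(x, x)
```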
src/evermore/visualization.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -83,7 +83,7 @@ def _is_evm_cls(leaf: Any, cls: Any) -> bool:


 def _convert(leaf: Any, cls: Any) -> Any:
-    from penzai import pz
+    from penzai.deprecated.v1 import pz

     if isinstance(leaf, cls) and dataclasses.is_dataclass(leaf):
         fields = dataclasses.fields(leaf)
```
tests/test_pdf.py (2 changes: 1 addition & 1 deletion)

```diff
@@ -15,4 +15,4 @@ def test_Normal():
 def test_Poisson():
     pdf = Poisson(lamb=jnp.array(10))

-    assert pdf.log_prob(jnp.array(-0.5)) == pytest.approx(-1.196003)
+    assert pdf.log_prob(jnp.array(5.0)) == pytest.approx(-1.5342636)
```
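
The new reference value follows directly from the normalized logpdf above: with `x = 5` and `lamb = 10`, the unnormalized term is `xlogy(5, 10) - 10 - gammaln(6) ≈ -3.2746`, the maximum term is `xlogy(5, 5) - 5 - gammaln(6) ≈ -1.7403`, and their difference is `≈ -1.5343`, matching `pytest.approx(-1.5342636)`.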
