Skip to content

Commit

Permalink
remove obsolete poisson log likelihood
Browse files Browse the repository at this point in the history
  • Loading branch information
pfackeldey committed Aug 20, 2024
1 parent 732354e commit 36d55d7
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 43 deletions.
38 changes: 0 additions & 38 deletions src/evermore/loss.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from collections.abc import Callable
from typing import cast

import equinox as eqx
import jax
import jax.numpy as jnp
from jaxtyping import Array, PyTree

Expand All @@ -13,7 +10,6 @@
__all__ = [
"get_log_probs",
"get_boundary_constraints",
"PoissonLogLikelihood",
]


Expand All @@ -38,37 +34,3 @@ def _constraint(param: Parameter) -> Array:

def get_boundary_constraints(module: PyTree) -> PyTree:
return params_map(lambda p: p.boundary_constraint, module)


class PoissonLogLikelihood(eqx.Module):
"""
Poisson log-likelihood.
Usage:
.. code-block:: python
import evermore as evm
nll = evm.loss.PoissonLogLikelihood()
def loss(model, x, y):
expectation = model(x)
loss = nll(expectation, y)
constraints = evm.loss.get_log_probs(model)
loss += evm.util.sum_over_leaves(constraints)
return -jnp.sum(loss)
"""

@property
def log_prob(self) -> Callable:
return jax.scipy.stats.poisson.logpmf

@jax.named_scope("evm.loss.PoissonLogLikelihood")
def __call__(self, expectation: Array, observation: Array) -> Array:
# poisson log-likelihood
return jnp.sum(
self.log_prob(observation, expectation)
- self.log_prob(observation, observation),
axis=-1,
)
5 changes: 0 additions & 5 deletions tests/test_loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,6 @@
import evermore as evm


def test_PoissonLogLikelihood():
f = evm.loss.PoissonLogLikelihood()
assert f(jnp.array([1.0]), jnp.array([1.0])) == 0.0


def test_get_log_probs():
params = {
"a": evm.NormalParameter(value=0.5),
Expand Down

0 comments on commit 36d55d7

Please sign in to comment.