diff --git a/src/evermore/loss.py b/src/evermore/loss.py index fbf77c1..e5190b8 100644 --- a/src/evermore/loss.py +++ b/src/evermore/loss.py @@ -1,8 +1,5 @@ -from collections.abc import Callable from typing import cast -import equinox as eqx -import jax import jax.numpy as jnp from jaxtyping import Array, PyTree @@ -13,7 +10,6 @@ __all__ = [ "get_log_probs", "get_boundary_constraints", - "PoissonLogLikelihood", ] @@ -38,37 +34,3 @@ def _constraint(param: Parameter) -> Array: def get_boundary_constraints(module: PyTree) -> PyTree: return params_map(lambda p: p.boundary_constraint, module) - - -class PoissonLogLikelihood(eqx.Module): - """ - Poisson log-likelihood. - - Usage: - - .. code-block:: python - - import evermore as evm - - nll = evm.loss.PoissonLogLikelihood() - - def loss(model, x, y): - expectation = model(x) - loss = nll(expectation, y) - constraints = evm.loss.get_log_probs(model) - loss += evm.util.sum_over_leaves(constraints) - return -jnp.sum(loss) - """ - - @property - def log_prob(self) -> Callable: - return jax.scipy.stats.poisson.logpmf - - @jax.named_scope("evm.loss.PoissonLogLikelihood") - def __call__(self, expectation: Array, observation: Array) -> Array: - # poisson log-likelihood - return jnp.sum( - self.log_prob(observation, expectation) - - self.log_prob(observation, observation), - axis=-1, - ) diff --git a/tests/test_loss.py b/tests/test_loss.py index 9524624..9f0f8d5 100644 --- a/tests/test_loss.py +++ b/tests/test_loss.py @@ -6,11 +6,6 @@ import evermore as evm -def test_PoissonLogLikelihood(): - f = evm.loss.PoissonLogLikelihood() - assert f(jnp.array([1.0]), jnp.array([1.0])) == 0.0 - - def test_get_log_probs(): params = { "a": evm.NormalParameter(value=0.5),