Skip to content

Commit

Permalink
Add Censored wrapper for Prior class (pymc-labs#1309)
Browse files Browse the repository at this point in the history
* support for deserialization of three classes

* add the deserialize logic

* correct the type hint

* add an example

* separate out the error

* catch error at deserialization

* relax the input type

* add test suite

* add test for arbitrary serialization via Prior

* use deserialize within from_json

* test for deserialize support within Prior

* use general deserialization in individual media transformations

* test both deserialize funcs

* test arb deserialization in adstock

* add similar deserialization check for saturation

* better naming of the tests

* support parsing of hsgp kwargs

* add to the module level docstring

* allow VariableFactory in parse_model_config

* implement the censored variable as VariableFactory

* add test against VariableFactory

* test the variables created with censored variable

* add few more censored tests

* add tests for errors

* implement the censored variable as VariableFactory

* add test against VariableFactory

* test the variables created with censored variable

* add few more censored tests

* add tests for errors

* add module to documentation

* add to the module documentation

* Reorder the top level documentation

* add more docstring about the functions

* serialization support for Censored class

* run pre-commit

* checks that lead to deterministics

* switch out with pydantic.dataclasses.dataclass

* handle the pytensor variable case

* add setter for dims
  • Loading branch information
wd60622 authored Dec 25, 2024
1 parent 8600ef3 commit f4fe828
Show file tree
Hide file tree
Showing 2 changed files with 424 additions and 1 deletion.
269 changes: 268 additions & 1 deletion pymc_marketing/prior.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,8 @@ def custom_transform(x):
import pymc as pm
import pytensor.tensor as pt
import xarray as xr
from pydantic import validate_call
from pydantic import InstanceOf, validate_call
from pydantic.dataclasses import dataclass
from pymc.distributions.shape_utils import Dims

from pymc_marketing.deserialize import deserialize, register_deserialization
Expand Down Expand Up @@ -1025,8 +1026,274 @@ def create_likelihood_variable(
return distribution.create_variable(name)


class VariableNotFound(Exception):
"""Variable is not found."""


def _remove_random_variable(var: pt.TensorVariable) -> None:
if var.name is None:
raise ValueError("This isn't removable")

name: str = var.name

model = pm.modelcontext(None)
for idx, free_rv in enumerate(model.free_RVs):
if var == free_rv:
index_to_remove = idx
break
else:
raise VariableNotFound(f"Variable {var.name!r} not found")

var.name = None
model.free_RVs.pop(index_to_remove)
model.named_vars.pop(name)


@dataclass
class Censored:
"""Create censored random variable.
Examples
--------
Create a censored Normal distribution:
.. code-block:: python
from pymc_marketing.prior import Prior, Censored
normal = Prior("Normal")
censored_normal = Censored(normal, lower=0)
Create hierarchical censored Normal distribution:
.. code-block:: python
from pymc_marketing.prior import Prior, Censored
normal = Prior(
"Normal",
mu=Prior("Normal"),
sigma=Prior("HalfNormal"),
dims="channel",
)
censored_normal = Censored(normal, lower=0)
coords = {"channel": range(3)}
samples = censored_normal.sample_prior(coords=coords)
"""

distribution: InstanceOf[Prior]
lower: float | InstanceOf[pt.TensorVariable] = -np.inf
upper: float | InstanceOf[pt.TensorVariable] = np.inf

def __post_init__(self) -> None:
"""Check validity at initialization."""
if not self.distribution.centered:
raise ValueError(
"Censored distribution must be centered so that .dist() API can be used on distribution."
)

if self.distribution.transform is not None:
raise ValueError(
"Censored distribution can't have a transform so that .dist() API can be used on distribution."
)

@property
def dims(self) -> tuple[str, ...]:
"""The dims from the distribution to censor."""
return self.distribution.dims

@dims.setter
def dims(self, dims) -> None:
self.distribution.dims = dims

def create_variable(self, name: str) -> pt.TensorVariable:
"""Create censored random variable."""
dist = self.distribution.create_variable(name)
_remove_random_variable(var=dist)

return pm.Censored(
name,
dist,
lower=self.lower,
upper=self.upper,
dims=self.dims,
)

def to_dict(self) -> dict[str, Any]:
"""Convert the censored distribution to a dictionary."""

def handle_value(value):
if isinstance(value, pt.TensorVariable):
return value.eval().tolist()

return value

return {
"class": "Censored",
"data": {
"dist": self.distribution.to_json(),
"lower": handle_value(self.lower),
"upper": handle_value(self.upper),
},
}

@classmethod
def from_dict(cls, data: dict[str, Any]) -> Censored:
"""Create a censored distribution from a dictionary."""
data = data["data"]
return cls( # type: ignore
distribution=Prior.from_json(data["dist"]),
lower=data["lower"],
upper=data["upper"],
)

def sample_prior(
self,
coords=None,
name: str = "variable",
**sample_prior_predictive_kwargs,
) -> xr.Dataset:
"""Sample the prior distribution for the variable.
Parameters
----------
coords : dict[str, list[str]], optional
The coordinates for the variable, by default None.
Only required if the dims are specified.
name : str, optional
The name of the variable, by default "var".
sample_prior_predictive_kwargs : dict
Additional arguments to pass to `pm.sample_prior_predictive`.
Returns
-------
xr.Dataset
The dataset of the prior samples.
Example
-------
Sample from a censored Gamma distribution.
.. code-block:: python
gamma = Prior("Gamma", mu=1, sigma=1, dims="channel")
dist = Censored(gamma, lower=0.5)
coords = {"channel": ["C1", "C2", "C3"]}
prior = dist.sample_prior(coords=coords)
"""
coords = coords or {}

if missing_keys := set(self.dims) - set(coords.keys()):
raise KeyError(f"Coords are missing the following dims: {missing_keys}")

with pm.Model(coords=coords):
self.create_variable(name)

return pm.sample_prior_predictive(**sample_prior_predictive_kwargs).prior

def to_graph(self):
"""Generate a graph of the variables.
Examples
--------
Create graph for a censored Normal distribution
.. code-block:: python
from pymc_marketing.prior import Prior, Censored
normal = Prior("Normal")
censored_normal = Censored(normal, lower=0)
censored_normal.to_graph()
"""
coords = {name: ["DUMMY"] for name in self.dims}
with pm.Model(coords=coords) as model:
self.create_variable("var")

return pm.model_to_graphviz(model)

def create_likelihood_variable(
self,
name: str,
mu: pt.TensorLike,
observed: pt.TensorLike,
) -> pt.TensorVariable:
"""Create observed censored variable.
Will require that the distribution has a `mu` parameter
and that it has not been set in the parameters.
Parameters
----------
name : str
The name of the variable.
mu : pt.TensorLike
The mu parameter for the likelihood.
observed : pt.TensorLike
The observed data.
Returns
-------
pt.TensorVariable
The PyMC variable.
Examples
--------
Create a censored likelihood variable in a larger PyMC model.
.. code-block:: python
import pymc as pm
from pymc_marketing.prior import Prior, Censored
normal = Prior("Normal", sigma=Prior("HalfNormal"))
dist = Censored(normal, lower=0)
observed = 1
with pm.Model():
# Create the likelihood variable
mu = pm.HalfNormal("mu", sigma=1)
dist.create_likelihood_variable("y", mu=mu, observed=observed)
"""
if "mu" not in _get_pymc_parameters(self.distribution.pymc_distribution):
raise UnsupportedDistributionError(
f"Likelihood distribution {self.distribution.distribution!r} is not supported."
)

if "mu" in self.distribution.parameters:
raise MuAlreadyExistsError(self.distribution)

distribution = self.distribution.deepcopy()
distribution.parameters["mu"] = mu

dist = distribution.create_variable(name)
_remove_random_variable(var=dist)

return pm.Censored(
name,
dist,
observed=observed,
lower=self.lower,
upper=self.upper,
dims=self.dims,
)


def _is_prior_type(data: dict) -> bool:
return "dist" in data


def _is_censored_type(data: dict) -> bool:
return data.keys() == {"class", "data"} and data["class"] == "Censored"


register_deserialization(is_type=_is_prior_type, deserialize=Prior.from_json)
register_deserialization(is_type=_is_censored_type, deserialize=Censored.from_dict)
Loading

0 comments on commit f4fe828

Please sign in to comment.