From f4fe828f51101d3ffaba9e7ac94253e0ee4ff920 Mon Sep 17 00:00:00 2001 From: Will Dean <57733339+wd60622@users.noreply.github.com> Date: Wed, 25 Dec 2024 13:43:22 +0100 Subject: [PATCH] Add Censored wrapper for Prior class (#1309) * support for deserialization of three classes * add the deserialize logic * correct the type hint * add an example * separate out the error * catch error at deserialization * relax the input type * add test suite * add test for arbitrary serialization via Prior * use deserialize within from_json * test for deserialize support within Prior * use general deserialization in individual media transformations * test both deserialize funcs * test arb deserialization in adstock * add similar deserialization check for saturation * better naming of the tests * support parsing of hsgp kwargs * add to the module level docstring * allow VariableFactory in parse_model_config * implement the censored variable as VariableFactory * add test against VariableFactory * test the variables created with censored variable * add few more censored tests * add tests for errors * implement the censored variable as VariableFactory * add test against VariableFactory * test the variables created with censored variable * add few more censored tests * add tests for errors * add module to documentation * add to the module documentation * Reorder the top level documentation * add more docstring about the functions * serialization support for Censored class * run pre-commit * checks that lead to deterministics * switch out with pydantic.dataclasses.dataclass * handle the pytensor variable case * add setter for dims --- pymc_marketing/prior.py | 269 +++++++++++++++++++++++++++++++++++++++- tests/test_prior.py | 156 +++++++++++++++++++++++ 2 files changed, 424 insertions(+), 1 deletion(-) diff --git a/pymc_marketing/prior.py b/pymc_marketing/prior.py index 20b9e052..c649ddaa 100644 --- a/pymc_marketing/prior.py +++ b/pymc_marketing/prior.py @@ -105,7 +105,8 @@ def custom_transform(x): import pymc as pm import pytensor.tensor as pt import xarray as xr -from pydantic import validate_call +from pydantic import InstanceOf, validate_call +from pydantic.dataclasses import dataclass from pymc.distributions.shape_utils import Dims from pymc_marketing.deserialize import deserialize, register_deserialization @@ -1025,8 +1026,274 @@ def create_likelihood_variable( return distribution.create_variable(name) +class VariableNotFound(Exception): + """Variable is not found.""" + + +def _remove_random_variable(var: pt.TensorVariable) -> None: + if var.name is None: + raise ValueError("This isn't removable") + + name: str = var.name + + model = pm.modelcontext(None) + for idx, free_rv in enumerate(model.free_RVs): + if var == free_rv: + index_to_remove = idx + break + else: + raise VariableNotFound(f"Variable {var.name!r} not found") + + var.name = None + model.free_RVs.pop(index_to_remove) + model.named_vars.pop(name) + + +@dataclass +class Censored: + """Create censored random variable. + + Examples + -------- + Create a censored Normal distribution: + + .. code-block:: python + + from pymc_marketing.prior import Prior, Censored + + normal = Prior("Normal") + censored_normal = Censored(normal, lower=0) + + Create hierarchical censored Normal distribution: + + .. code-block:: python + + from pymc_marketing.prior import Prior, Censored + + normal = Prior( + "Normal", + mu=Prior("Normal"), + sigma=Prior("HalfNormal"), + dims="channel", + ) + censored_normal = Censored(normal, lower=0) + + coords = {"channel": range(3)} + samples = censored_normal.sample_prior(coords=coords) + + """ + + distribution: InstanceOf[Prior] + lower: float | InstanceOf[pt.TensorVariable] = -np.inf + upper: float | InstanceOf[pt.TensorVariable] = np.inf + + def __post_init__(self) -> None: + """Check validity at initialization.""" + if not self.distribution.centered: + raise ValueError( + "Censored distribution must be centered so that .dist() API can be used on distribution." + ) + + if self.distribution.transform is not None: + raise ValueError( + "Censored distribution can't have a transform so that .dist() API can be used on distribution." + ) + + @property + def dims(self) -> tuple[str, ...]: + """The dims from the distribution to censor.""" + return self.distribution.dims + + @dims.setter + def dims(self, dims) -> None: + self.distribution.dims = dims + + def create_variable(self, name: str) -> pt.TensorVariable: + """Create censored random variable.""" + dist = self.distribution.create_variable(name) + _remove_random_variable(var=dist) + + return pm.Censored( + name, + dist, + lower=self.lower, + upper=self.upper, + dims=self.dims, + ) + + def to_dict(self) -> dict[str, Any]: + """Convert the censored distribution to a dictionary.""" + + def handle_value(value): + if isinstance(value, pt.TensorVariable): + return value.eval().tolist() + + return value + + return { + "class": "Censored", + "data": { + "dist": self.distribution.to_json(), + "lower": handle_value(self.lower), + "upper": handle_value(self.upper), + }, + } + + @classmethod + def from_dict(cls, data: dict[str, Any]) -> Censored: + """Create a censored distribution from a dictionary.""" + data = data["data"] + return cls( # type: ignore + distribution=Prior.from_json(data["dist"]), + lower=data["lower"], + upper=data["upper"], + ) + + def sample_prior( + self, + coords=None, + name: str = "variable", + **sample_prior_predictive_kwargs, + ) -> xr.Dataset: + """Sample the prior distribution for the variable. + + Parameters + ---------- + coords : dict[str, list[str]], optional + The coordinates for the variable, by default None. + Only required if the dims are specified. + name : str, optional + The name of the variable, by default "var". + sample_prior_predictive_kwargs : dict + Additional arguments to pass to `pm.sample_prior_predictive`. + + Returns + ------- + xr.Dataset + The dataset of the prior samples. + + Example + ------- + Sample from a censored Gamma distribution. + + .. code-block:: python + + gamma = Prior("Gamma", mu=1, sigma=1, dims="channel") + dist = Censored(gamma, lower=0.5) + + coords = {"channel": ["C1", "C2", "C3"]} + prior = dist.sample_prior(coords=coords) + + """ + coords = coords or {} + + if missing_keys := set(self.dims) - set(coords.keys()): + raise KeyError(f"Coords are missing the following dims: {missing_keys}") + + with pm.Model(coords=coords): + self.create_variable(name) + + return pm.sample_prior_predictive(**sample_prior_predictive_kwargs).prior + + def to_graph(self): + """Generate a graph of the variables. + + Examples + -------- + Create graph for a censored Normal distribution + + .. code-block:: python + + from pymc_marketing.prior import Prior, Censored + + normal = Prior("Normal") + censored_normal = Censored(normal, lower=0) + + censored_normal.to_graph() + + """ + coords = {name: ["DUMMY"] for name in self.dims} + with pm.Model(coords=coords) as model: + self.create_variable("var") + + return pm.model_to_graphviz(model) + + def create_likelihood_variable( + self, + name: str, + mu: pt.TensorLike, + observed: pt.TensorLike, + ) -> pt.TensorVariable: + """Create observed censored variable. + + Will require that the distribution has a `mu` parameter + and that it has not been set in the parameters. + + Parameters + ---------- + name : str + The name of the variable. + mu : pt.TensorLike + The mu parameter for the likelihood. + observed : pt.TensorLike + The observed data. + + Returns + ------- + pt.TensorVariable + The PyMC variable. + + Examples + -------- + Create a censored likelihood variable in a larger PyMC model. + + .. code-block:: python + + import pymc as pm + from pymc_marketing.prior import Prior, Censored + + normal = Prior("Normal", sigma=Prior("HalfNormal")) + dist = Censored(normal, lower=0) + + observed = 1 + + with pm.Model(): + # Create the likelihood variable + mu = pm.HalfNormal("mu", sigma=1) + dist.create_likelihood_variable("y", mu=mu, observed=observed) + + """ + if "mu" not in _get_pymc_parameters(self.distribution.pymc_distribution): + raise UnsupportedDistributionError( + f"Likelihood distribution {self.distribution.distribution!r} is not supported." + ) + + if "mu" in self.distribution.parameters: + raise MuAlreadyExistsError(self.distribution) + + distribution = self.distribution.deepcopy() + distribution.parameters["mu"] = mu + + dist = distribution.create_variable(name) + _remove_random_variable(var=dist) + + return pm.Censored( + name, + dist, + observed=observed, + lower=self.lower, + upper=self.upper, + dims=self.dims, + ) + + def _is_prior_type(data: dict) -> bool: return "dist" in data +def _is_censored_type(data: dict) -> bool: + return data.keys() == {"class", "data"} and data["class"] == "Censored" + + register_deserialization(is_type=_is_prior_type, deserialize=Prior.from_json) +register_deserialization(is_type=_is_censored_type, deserialize=Censored.from_dict) diff --git a/tests/test_prior.py b/tests/test_prior.py index df7466ec..74e65da9 100644 --- a/tests/test_prior.py +++ b/tests/test_prior.py @@ -29,12 +29,14 @@ register_deserialization, ) from pymc_marketing.prior import ( + Censored, MuAlreadyExistsError, Prior, UnknownTransformError, UnsupportedDistributionError, UnsupportedParameterizationError, UnsupportedShapeError, + VariableFactory, handle_dims, register_tensor_transform, ) @@ -697,6 +699,138 @@ def test_create_prior_with_arbitrary() -> None: assert fast_eval(var_mu).shape == (len(coords["channel"]),) +def test_censored_is_variable_factory() -> None: + normal = Prior("Normal") + censored_normal = Censored(normal, lower=0) + + assert isinstance(censored_normal, VariableFactory) + + +@pytest.mark.parametrize( + "dims, expected_dims", + [ + ("channel", ("channel",)), + (("channel", "geo"), ("channel", "geo")), + ], + ids=["string", "tuple"], +) +def test_censored_dims_from_distribution(dims, expected_dims) -> None: + normal = Prior("Normal", dims=dims) + censored_normal = Censored(normal, lower=0) + + assert censored_normal.dims == expected_dims + + +def test_censored_variables_created() -> None: + normal = Prior("Normal", mu=Prior("Normal"), dims="dim") + censored_normal = Censored(normal, lower=0) + + coords = {"dim": range(3)} + with pm.Model(coords=coords) as model: + censored_normal.create_variable("var") + + var_names = ["var", "var_mu"] + assert set(var.name for var in model.unobserved_RVs) == set(var_names) + dims = [(3,), ()] + for var_name, dim in zip(var_names, dims, strict=False): + assert fast_eval(model[var_name]).shape == dim + + +def test_censored_sample_prior() -> None: + normal = Prior("Normal", dims="channel") + censored_normal = Censored(normal, lower=0) + + coords = {"channel": ["A", "B", "C"]} + prior = censored_normal.sample_prior(coords=coords, samples=25) + + assert isinstance(prior, xr.Dataset) + assert prior.sizes == {"chain": 1, "draw": 25, "channel": 3} + + +def test_censored_to_graph() -> None: + normal = Prior("Normal", dims="channel") + censored_normal = Censored(normal, lower=0) + + G = censored_normal.to_graph() + assert isinstance(G, Digraph) + + +def test_censored_likelihood_variable() -> None: + normal = Prior("Normal", sigma=Prior("HalfNormal"), dims="channel") + censored_normal = Censored(normal, lower=0) + + coords = {"channel": range(3)} + with pm.Model(coords=coords) as model: + mu = pm.Normal("mu") + variable = censored_normal.create_likelihood_variable( + name="likelihood", + mu=mu, + observed=[1, 2, 3], + ) + + assert isinstance(variable, pt.TensorVariable) + assert model.observed_RVs == [variable] + assert "likelihood_sigma" in model + + +def test_censored_likelihood_unsupported_distribution() -> None: + cauchy = Prior("Cauchy") + censored_cauchy = Censored(cauchy, lower=0) + + with pm.Model(): + mu = pm.Normal("mu") + with pytest.raises(UnsupportedDistributionError): + censored_cauchy.create_likelihood_variable( + name="likelihood", + mu=mu, + observed=1, + ) + + +def test_censored_likelihood_already_has_mu() -> None: + normal = Prior("Normal", mu=Prior("Normal"), sigma=Prior("HalfNormal")) + censored_normal = Censored(normal, lower=0) + + with pm.Model(): + mu = pm.Normal("mu") + with pytest.raises(MuAlreadyExistsError): + censored_normal.create_likelihood_variable( + name="likelihood", + mu=mu, + observed=1, + ) + + +def test_censored_to_dict() -> None: + normal = Prior("Normal", mu=0, sigma=1, dims="channel") + censored_normal = Censored(normal, lower=0) + + data = censored_normal.to_dict() + assert data == { + "class": "Censored", + "data": {"dist": normal.to_json(), "lower": 0, "upper": float("inf")}, + } + + +def test_deserialize_censored() -> None: + data = { + "class": "Censored", + "data": { + "dist": { + "dist": "Normal", + }, + "lower": 0, + "upper": float("inf"), + }, + } + + instance = deserialize(data) + assert isinstance(instance, Censored) + assert isinstance(instance.distribution, Prior) + assert instance.lower == 0 + assert instance.upper == float("inf") + + class ArbitrarySerializable(Arbitrary): def to_dict(self): return {"dims": self.dims} @@ -753,3 +887,25 @@ def test_deserialize_arbitrary_within_prior( dist = deserialize(data) assert isinstance(dist["mu"], ArbitrarySerializable) assert dist["mu"].dims == ("channel",) + + +def test_censored_with_tensor_variable() -> None: + normal = Prior("Normal", dims="channel") + lower = pt.as_tensor_variable([0, 1, 2]) + censored_normal = Censored(normal, lower=lower) + + assert censored_normal.to_dict() == { + "class": "Censored", + "data": { + "dist": normal.to_json(), + "lower": [0, 1, 2], + "upper": float("inf"), + }, + } + + +def test_censored_dims_setter() -> None: + normal = Prior("Normal", dims="channel") + censored_normal = Censored(normal, lower=0) + censored_normal.dims = "date" + assert normal.dims == ("date",)