Skip to content

Commit

Permalink
Register new media transformations automatically (pymc-labs#1320)
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 authored Dec 28, 2024
1 parent 1630589 commit 327ac97
Show file tree
Hide file tree
Showing 6 changed files with 140 additions and 43 deletions.
20 changes: 8 additions & 12 deletions pymc_marketing/mmm/components/adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ def function(self, x, alpha):
"""

from __future__ import annotations

import numpy as np
import xarray as xr
from pydantic import Field, validate_call
Expand All @@ -60,6 +62,7 @@ def function(self, x, alpha):
from pymc_marketing.mmm.components.base import (
SupportedPrior,
Transformation,
create_registration_meta,
)
from pymc_marketing.mmm.transformers import (
ConvMode,
Expand All @@ -70,8 +73,12 @@ def function(self, x, alpha):
)
from pymc_marketing.prior import Prior

ADSTOCK_TRANSFORMATIONS: dict[str, type[AdstockTransformation]] = {}

AdstockRegistrationMeta: type[type] = create_registration_meta(ADSTOCK_TRANSFORMATIONS)


class AdstockTransformation(Transformation):
class AdstockTransformation(Transformation, metaclass=AdstockRegistrationMeta): # type: ignore
"""Subclass for all adstock functions.
In order to use a custom saturation function, inherit from this class and define:
Expand Down Expand Up @@ -322,17 +329,6 @@ def function(self, x, lam, k):
}


ADSTOCK_TRANSFORMATIONS: dict[str, type[AdstockTransformation]] = {
cls.lookup_name: cls # type: ignore
for cls in [
GeometricAdstock,
DelayedAdstock,
WeibullPDFAdstock,
WeibullCDFAdstock,
]
}


def register_adstock_transformation(cls: type[AdstockTransformation]) -> None:
"""Register a new adstock transformation."""
ADSTOCK_TRANSFORMATIONS[cls.lookup_name] = cls
Expand Down
44 changes: 44 additions & 0 deletions pymc_marketing/mmm/components/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -592,3 +592,47 @@ def _serialize_value(value: Any) -> Any:
return value.tolist()

return value


class DuplicatedTransformationError(Exception):
"""Exception when a transformation is duplicated."""

def __init__(self, name: str, lookup_name: str):
self.name = name
self.lookup_name = lookup_name
super().__init__(f"Duplicate {name}. The name {lookup_name!r} already exists.")


def create_registration_meta(subclasses: dict[str, Any]) -> type[type]:
"""Create a metaclass for registering subclasses.
Parameters
----------
subclasses : dict[str, type[Transformation]]
The subclasses to register.
Returns
-------
type
The metaclass for registering subclasses.
"""

class RegistrationMeta(type):
def __new__(cls, name, bases, attrs):
new_cls = super().__new__(cls, name, bases, attrs)

if "lookup_name" not in attrs:
return new_cls

base_name = bases[0].__name__

lookup_name = attrs["lookup_name"]
if lookup_name in subclasses:
raise DuplicatedTransformationError(base_name, lookup_name)

subclasses[lookup_name] = new_cls

return new_cls

return RegistrationMeta
24 changes: 8 additions & 16 deletions pymc_marketing/mmm/components/saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ def function(self, x, b):
"""

from __future__ import annotations

import numpy as np
import pytensor.tensor as pt
import xarray as xr
Expand All @@ -79,6 +81,7 @@ def function(self, x, b):
from pymc_marketing.deserialize import deserialize, register_deserialization
from pymc_marketing.mmm.components.base import (
Transformation,
create_registration_meta,
)
from pymc_marketing.mmm.transformers import (
hill_function,
Expand All @@ -92,8 +95,12 @@ def function(self, x, b):
)
from pymc_marketing.prior import Prior

SATURATION_TRANSFORMATIONS: dict[str, type[SaturationTransformation]] = {}

SaturationRegistrationMeta = create_registration_meta(SATURATION_TRANSFORMATIONS)


class SaturationTransformation(Transformation):
class SaturationTransformation(Transformation, metaclass=SaturationRegistrationMeta): # type: ignore
"""Subclass for all saturation transformations.
In order to use a custom saturation transformation, subclass and define:
Expand Down Expand Up @@ -452,21 +459,6 @@ def function(self, x, alpha, beta):
}


SATURATION_TRANSFORMATIONS: dict[str, type[SaturationTransformation]] = {
cls.lookup_name: cls
for cls in [
LogisticSaturation,
InverseScaledLogisticSaturation,
TanhSaturation,
TanhSaturationBaselined,
MichaelisMentenSaturation,
HillSaturation,
HillSaturationSigmoid,
RootSaturation,
]
}


def register_saturation_transformation(cls: type[SaturationTransformation]) -> None:
"""Register a new saturation transformation.
Expand Down
43 changes: 28 additions & 15 deletions tests/mmm/components/test_adstock.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@
WeibullCDFAdstock,
WeibullPDFAdstock,
adstock_from_dict,
register_adstock_transformation,
)
from pymc_marketing.mmm.components.adstock import (
ADSTOCK_TRANSFORMATIONS,
)
from pymc_marketing.mmm.transformers import ConvMode
from pymc_marketing.prior import Prior
Expand Down Expand Up @@ -161,27 +157,25 @@ def test_adstock_from_dict_without_priors(adstock, deserialize_func) -> None:
}


@pytest.mark.parametrize("deserialize_func", [adstock_from_dict, deserialize])
def test_register_adstock_transformation(deserialize_func) -> None:
class NewTransformation(AdstockTransformation):
lookup_name: str = "new_transformation"
default_priors = {}
class AnotherNewTransformation(AdstockTransformation):
lookup_name: str = "another_new_transformation"
default_priors = {}

def function(self, x):
return x
def function(self, x):
return x

register_adstock_transformation(NewTransformation)
assert "new_transformation" in ADSTOCK_TRANSFORMATIONS

@pytest.mark.parametrize("deserialize_func", [adstock_from_dict, deserialize])
def test_automatic_register_adstock_transformation(deserialize_func) -> None:
data = {
"lookup_name": "new_transformation",
"lookup_name": "another_new_transformation",
"l_max": 10,
"normalize": False,
"mode": "Before",
"priors": {},
}
adstock = deserialize_func(data)
assert adstock == NewTransformation(
assert adstock == AnotherNewTransformation(
l_max=10, mode=ConvMode.Before, normalize=False, priors={}
)

Expand Down Expand Up @@ -239,3 +233,22 @@ def test_deserialization(
assert isinstance(alpha, ArbitraryObject)
assert alpha.msg == "hello"
assert alpha.value == 1


def test_deserialize_new_transformation() -> None:
class NewAdstock(AdstockTransformation):
lookup_name = "new_adstock"

def function(self, x):
return x

default_priors = {}

data = {
"lookup_name": "new_adstock",
"l_max": 10,
}

instance = deserialize(data)
assert isinstance(instance, NewAdstock)
assert instance.l_max == 10
34 changes: 34 additions & 0 deletions tests/mmm/components/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,11 @@
from pytensor.tensor import TensorVariable

from pymc_marketing.mmm.components.base import (
DuplicatedTransformationError,
MissingDataParameter,
ParameterPriorException,
Transformation,
create_registration_meta,
)
from pymc_marketing.prior import Prior

Expand Down Expand Up @@ -417,3 +419,35 @@ def test_serialization(new_transformation_class) -> None:
"b": [1, 2, 3],
},
}


def test_automatic_registration() -> None:
subclasses = {}

RegistrationMeta = create_registration_meta(subclasses)

class BaseTransform:
pass

class Transform(BaseTransform, metaclass=RegistrationMeta):
pass

class NewTransform(Transform):
lookup_name = "new"

assert subclasses == {"new": NewTransform}

class AnotherTransform(Transform):
lookup_name = "another"

assert subclasses == {"new": NewTransform, "another": AnotherTransform}

with pytest.raises(DuplicatedTransformationError) as e:

class _(Transform):
lookup_name = "new"

exception = e.value

assert exception.lookup_name == "new"
assert exception.name == "Transform"
18 changes: 18 additions & 0 deletions tests/mmm/components/test_saturation.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
LogisticSaturation,
MichaelisMentenSaturation,
RootSaturation,
SaturationTransformation,
TanhSaturation,
TanhSaturationBaselined,
saturation_from_dict,
Expand Down Expand Up @@ -287,3 +288,20 @@ def test_deserialization(
assert isinstance(alpha, ArbitraryObject)
assert alpha.msg == "hello"
assert alpha.value == 1


def test_deserialize_new_transformation() -> None:
class NewSaturation(SaturationTransformation):
lookup_name = "new_saturation"

def function(self, x):
return x

default_priors = {}

data = {
"lookup_name": "new_saturation",
}

instance = deserialize(data)
assert isinstance(instance, NewSaturation)

0 comments on commit 327ac97

Please sign in to comment.