Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add WeeklyFourier #1443

Merged
merged 10 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pymc_marketing/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@

DAYS_IN_YEAR: float = 365.25
DAYS_IN_MONTH: float = DAYS_IN_YEAR / 12
DAYS_IN_WEEK: int = 7
3 changes: 2 additions & 1 deletion pymc_marketing/mmm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
TanhSaturationBaselined,
saturation_from_dict,
)
from pymc_marketing.mmm.fourier import MonthlyFourier, YearlyFourier
from pymc_marketing.mmm.fourier import MonthlyFourier, WeeklyFourier, YearlyFourier
from pymc_marketing.mmm.hsgp import (
HSGP,
CovFunc,
Expand Down Expand Up @@ -85,6 +85,7 @@
"SaturationTransformation",
"TanhSaturation",
"TanhSaturationBaselined",
"WeeklyFourier",
"WeibullCDFAdstock",
"WeibullPDFAdstock",
"YearlyFourier",
Expand Down
133 changes: 120 additions & 13 deletions pymc_marketing/mmm/fourier.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@

- Yearly Fourier: A yearly seasonality with a period of 365.25 days
- Monthly Fourier: A monthly seasonality with a period of 365.25 / 12 days
- Weekly Fourier: A weekly seasonality with a period of 7 days

.. plot::
:context: close-figs
Expand Down Expand Up @@ -221,7 +222,7 @@
from pydantic import BaseModel, Field, InstanceOf, field_serializer, model_validator
from typing_extensions import Self

from pymc_marketing.constants import DAYS_IN_MONTH, DAYS_IN_YEAR
from pymc_marketing.constants import DAYS_IN_MONTH, DAYS_IN_WEEK, DAYS_IN_YEAR
from pymc_marketing.deserialize import deserialize, register_deserialization
from pymc_marketing.plot import SelToString, plot_curve, plot_hdi, plot_samples
from pymc_marketing.prior import Prior, VariableFactory, create_dim_handler
Expand Down Expand Up @@ -383,9 +384,20 @@
"""
pass # pragma: no cover

@abstractmethod
def _get_days_in_period(self, dates: pd.DatetimeIndex) -> pd.Index:
"""Return the relevant day within the characteristic periodicity.

Returns
-------
int or float
The relevant period within the characteristic periodicity
"""
pass

Check warning on line 396 in pymc_marketing/mmm/fourier.py

View check run for this annotation

Codecov / codecov/patch

pymc_marketing/mmm/fourier.py#L396

Added line #L396 was not covered by tests

def apply(
self,
dayofyear: pt.TensorLike,
dayofperiod: pt.TensorLike,
result_callback: Callable[[pt.TensorVariable], None] | None = None,
) -> pt.TensorVariable:
"""Apply fourier seasonality to day of year.
Expand All @@ -394,8 +406,8 @@

Parameters
----------
dayofyear : pt.TensorLike
Day of year.
dayofperiod : pt.TensorLike
Day of year or weekday
result_callback : Callable[[pt.TensorVariable], None], optional
Callback function to apply to the result, by default None

Expand Down Expand Up @@ -431,7 +443,7 @@
fourier.apply(dayofyear, result_callback=callback)

"""
periods = dayofyear / self.days_in_period
periods = dayofperiod / self.days_in_period

model = pm.modelcontext(None)
model.add_coord(self.prefix, self.nodes)
Expand Down Expand Up @@ -506,15 +518,15 @@
start_date = self.get_default_start_date(start_date=start_date)
date_range = pd.date_range(
start=start_date,
periods=int(self.days_in_period) + 1,
periods=np.ceil(self.days_in_period) + 1,
freq="D",
)
coords["date"] = date_range.to_numpy()
dayofyear = date_range.dayofyear.to_numpy()
dayofperiod = self._get_days_in_period(date_range).to_numpy()

else:
coords["day"] = full_period
dayofyear = full_period
dayofperiod = full_period

for key, values in parameters[self.variable_name].coords.items():
if key in {"chain", "draw", self.prefix}:
Expand All @@ -525,7 +537,7 @@
name = f"{self.prefix}_trend"
pm.Deterministic(
name,
self.apply(dayofyear=dayofyear),
self.apply(dayofperiod=dayofperiod),
dims=tuple(coords.keys()),
)

Expand Down Expand Up @@ -777,6 +789,16 @@
current_year = datetime.datetime.now().year
return datetime.datetime(year=current_year, month=1, day=1)

def _get_days_in_period(self, dates: pd.DatetimeIndex) -> pd.Index:
"""Return the dayofyear within the yearly periodicity.

Returns
-------
int or float
The relevant period within the characteristic periodicity
"""
return dates.dayofyear


class MonthlyFourier(FourierBase):
"""Monthly fourier seasonality.
Expand All @@ -799,11 +821,11 @@
mu = np.array([0, 0, 0.5, 0])
b = 0.075
dist = Prior("Laplace", mu=mu, b=b, dims="fourier")
yearly = MonthlyFourier(n_order=2, prior=dist)
prior = yearly.sample_prior(samples=100)
curve = yearly.sample_curve(prior)
monthly = MonthlyFourier(n_order=2, prior=dist)
prior = monthly.sample_prior(samples=100)
curve = monthly.sample_curve(prior)

_, axes = yearly.plot_curve(curve)
_, axes = monthly.plot_curve(curve)
axes[0].set(title="Monthly Fourier Seasonality")
plt.show()

Expand Down Expand Up @@ -832,6 +854,83 @@
now = datetime.datetime.now()
return datetime.datetime(year=now.year, month=now.month, day=1)

def _get_days_in_period(self, dates: pd.DatetimeIndex) -> pd.Index:
"""Return the dayofyear within the yearly periodicity.

Returns
-------
int or float
The relevant period within the characteristic periodicity
"""
return dates.dayofyear


class WeeklyFourier(FourierBase):
"""Weekly fourier seasonality.

.. plot::
:context: close-figs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Think a space is needed here. It is not rendering in the docs for some reason

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wrong import + missing from __init__.py. Should be fixed now (tested locally)


import arviz as az
import matplotlib.pyplot as plt
import numpy as np

from pymc_marketing.mmm import WeeklyFourier
from pymc_marketing.prior import Prior

az.style.use("arviz-white")

seed = sum(map(ord, "Weekly"))
rng = np.random.default_rng(seed)

mu = np.array([0, 0, 0.5, 0])
b = 0.075
dist = Prior("Laplace", mu=mu, b=b, dims="fourier")
weekly = WeeklyFourier(n_order=2, prior=dist)
prior = weekly.sample_prior(samples=100)
curve = weekly.sample_curve(prior)

_, axes = weekly.plot_curve(curve)
axes[0].set(title="Weekly Fourier Seasonality")
plt.show()

n_order : int
Number of fourier modes to use.
prefix : str, optional
Alternative prefix for the fourier seasonality, by default None or
"fourier"
prior : Prior | VariableFactory, optional
Prior distribution or VariableFactory for the fourier seasonality beta parameters, by
default `Prior("Laplace", mu=0, b=1)`
name : str, optional
Name of the variable that multiplies the fourier modes, by default None
variable_name : str, optional
Name of the variable that multiplies the fourier modes, by default None

"""

days_in_period: float = DAYS_IN_WEEK

def _get_default_start_date(self) -> datetime.datetime:
"""Get the default start date for weekly seasonality.

Returns the first day of the current month.
"""
now = datetime.datetime.now()
return datetime.datetime.fromisocalendar(
year=now.year, week=now.isocalendar().week, day=1
)

def _get_days_in_period(self, dates: pd.DatetimeIndex) -> pd.Index:
"""Return the weekday within the weekly periodicity.

Returns
-------
int or float
The relevant period within the characteristic periodicity
"""
return dates.weekday


def _is_yearly_fourier(data: Any) -> bool:
return data.get("class") == "YearlyFourier"
Expand All @@ -841,6 +940,10 @@
return data.get("class") == "MonthlyFourier"


def _is_weekly_fourier(data: Any) -> bool:
return data.get("class") == "WeeklyFourier"


register_deserialization(
is_type=_is_yearly_fourier,
deserialize=lambda data: YearlyFourier.from_dict(data),
Expand All @@ -850,3 +953,7 @@
is_type=_is_monthly_fourier,
deserialize=lambda data: MonthlyFourier.from_dict(data),
)

register_deserialization(
is_type=_is_weekly_fourier, deserialize=lambda data: WeeklyFourier.from_dict(data)
)
Loading