Skip to content

Commit

Permalink
Relax the type that is needed for evaluation metric function (pymc-la…
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 authored Jan 19, 2025
1 parent 6f12316 commit 5e83235
Show file tree
Hide file tree
Showing 3 changed files with 103 additions and 31 deletions.
46 changes: 34 additions & 12 deletions pymc_marketing/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,8 +148,10 @@
from typing import Any, Literal

import arviz as az
import numpy as np
import numpy.typing as npt
import pandas as pd
import pymc as pm
import xarray as xr
from pymc.model.core import Model
from pytensor.tensor import TensorVariable

Expand Down Expand Up @@ -460,18 +462,19 @@ def log_inference_data(


def log_mmm_evaluation_metrics(
y_true: np.ndarray,
y_pred: np.ndarray,
y_true: npt.NDArray | pd.Series,
y_pred: npt.NDArray | xr.DataArray,
metrics_to_calculate: list[str] | None = None,
hdi_prob: float = 0.94,
prefix: str = "",
) -> None:
"""Log evaluation metrics produced by `pymc_marketing.mmm.evaluation.compute_summary_metrics()` to MLflow.
Parameters
----------
y_true : np.ndarray
y_true : npt.NDArray | pd.Series
The true values of the target variable.
y_pred : np.ndarray
y_pred : npt.NDArray | xr.DataArray
The predicted values of the target variable.
metrics_to_calculate : list of str or None, optional
List of metrics to calculate. If None, all available metrics will be calculated.
Expand All @@ -484,23 +487,42 @@ def log_mmm_evaluation_metrics(
* `mape`: Mean Absolute Percentage Error.
hdi_prob : float, optional
The probability mass of the highest density interval. Defaults to 0.94.
prefix : str, optional
Prefix to add to the metric names. If non-empty and not already ending in "_", an underscore separator is appended. Defaults to "".
"""
# Convert y_true and y_pred to numpy arrays if they're not already
y_true_np = y_true.to_numpy() if hasattr(y_true, "to_numpy") else np.array(y_true)
y_pred_np = y_pred.to_numpy() if hasattr(y_pred, "to_numpy") else np.array(y_pred)
Examples
--------
Log in-sample evaluation metrics for a PyMC-Marketing MMM model:
.. code-block:: python
import mlflow
from pymc_marketing.mmm import MMM
mmm = MMM(...)
mmm.fit(X, y)
predictions = mmm.sample_posterior_predictive(X)
with mlflow.start_run():
log_mmm_evaluation_metrics(y, predictions["y"])
"""
metric_summaries = compute_summary_metrics(
y_true=y_true_np,
y_pred=y_pred_np,
y_true=y_true,
y_pred=y_pred,
metrics_to_calculate=metrics_to_calculate,
hdi_prob=hdi_prob,
)

if prefix and not prefix.endswith("_"):
prefix = f"{prefix}_"

for metric, stats in metric_summaries.items():
for stat, value in stats.items():
# mlflow doesn't support % in metric names
mlflow.log_metric(f"{metric}_{stat.replace('%', '')}", value)
mlflow.log_metric(f"{prefix}{metric}_{stat.replace('%', '')}", value)


class MMMWrapper(mlflow.pyfunc.PythonModel):
Expand Down
61 changes: 51 additions & 10 deletions pymc_marketing/mmm/evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,13 @@
# limitations under the License.
"""Evaluation and diagnostics for MMM models."""

from typing import cast

import arviz as az
import numpy as np
import numpy.typing as npt
import pandas as pd
import xarray as xr
from sklearn.metrics import (
mean_absolute_error,
mean_absolute_percentage_error,
Expand All @@ -26,17 +30,17 @@


def calculate_metric_distributions(
y_true: npt.NDArray,
y_pred: npt.NDArray,
y_true: npt.NDArray | pd.Series,
y_pred: npt.NDArray | xr.DataArray,
metrics_to_calculate: list[str] | None = None,
) -> dict[str, npt.NDArray]:
"""Calculate distributions of evaluation metrics for posterior samples.
Parameters
----------
y_true : npt.NDArray
y_true : npt.NDArray | pd.Series
True values for the dataset. Shape: (date,)
y_pred : npt.NDArray
y_pred : npt.NDArray | xr.DataArray
Posterior predictive samples. Shape: (date, sample)
metrics_to_calculate : list of str or None, optional
List of metrics to calculate. Options include:
Expand All @@ -53,6 +57,12 @@ def calculate_metric_distributions(
dict of str to npt.NDArray
A dictionary containing calculated metric distributions.
"""
if isinstance(y_true, pd.Series):
y_true = cast(np.ndarray, y_true.to_numpy())

if isinstance(y_pred, xr.DataArray):
y_pred = y_pred.values

metric_functions = {
"r_squared": lambda y_true, y_pred: az.r2_score(y_true, y_pred.T)["r2"],
"rmse": root_mean_squared_error,
Expand Down Expand Up @@ -131,8 +141,8 @@ def summarize_metric_distributions(


def compute_summary_metrics(
y_true: npt.NDArray,
y_pred: npt.NDArray,
y_true: npt.NDArray | pd.Series,
y_pred: npt.NDArray | xr.DataArray,
metrics_to_calculate: list[str] | None = None,
hdi_prob: float = 0.94,
) -> dict[str, dict[str, float]]:
Expand All @@ -143,9 +153,9 @@ def compute_summary_metrics(
Parameters
----------
y_true : npt.NDArray
y_true : npt.NDArray | pd.Series
The true values of the target variable.
y_pred : npt.NDArray
y_pred : npt.NDArray | xr.DataArray
The predicted values of the target variable.
metrics_to_calculate : list of str or None, optional
List of metrics to calculate. Options include:
Expand Down Expand Up @@ -179,6 +189,7 @@ def compute_summary_metrics(
.. code-block:: python
import pandas as pd
from pymc_marketing.mmm import (
GeometricAdstock,
LogisticSaturation,
Expand Down Expand Up @@ -213,7 +224,7 @@ def compute_summary_metrics(
results = compute_summary_metrics(
y_true=mmm.y,
y_pred=posterior_preds.y,
metrics_to_calculate=['r_squared', 'rmse', 'mae'],
metrics_to_calculate=["r_squared", "rmse", "mae"],
hdi_prob=0.89
)
Expand All @@ -223,9 +234,39 @@ def compute_summary_metrics(
for stat, value in stats.items():
print(f" {stat}: {value:.4f}")
print()
# r_squared:
# mean: 0.9055
# median: 0.9061
# std: 0.0098
# min: 0.8669
# max: 0.9371
# 89%_hdi_lower: 0.8891
# 89%_hdi_upper: 0.9198
#
# rmse:
# mean: 351.9120
# median: 351.0219
# std: 19.4732
# min: 290.6544
# max: 418.0821
# 89%_hdi_lower: 317.0673
# 89%_hdi_upper: 378.1048
#
# mae:
# mean: 281.6953
# median: 281.2757
# std: 16.3375
# min: 234.1462
# max: 337.9461
# 89%_hdi_lower: 255.7273
# 89%_hdi_upper: 307.2391
"""
metric_distributions = calculate_metric_distributions(
y_true, y_pred, metrics_to_calculate
y_true,
y_pred,
metrics_to_calculate,
)
metric_summaries = summarize_metric_distributions(metric_distributions, hdi_prob)
return metric_summaries
27 changes: 18 additions & 9 deletions tests/mmm/test_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
# limitations under the License.

import numpy as np
import pandas as pd
import pytest
import xarray as xr

from pymc_marketing.mmm.evaluation import (
calculate_metric_distributions,
Expand Down Expand Up @@ -43,8 +45,17 @@ def sample_data(manage_random_state) -> tuple[np.ndarray, np.ndarray]:
return y_true, y_pred


def test_calculate_metric_distributions_all_metrics(sample_data) -> None:
@pytest.mark.parametrize("y_true_cls", [np.array, pd.Series])
@pytest.mark.parametrize("y_pred_cls", [np.array, xr.DataArray])
def test_calculate_metric_distributions_all_metrics(
sample_data,
y_true_cls,
y_pred_cls,
) -> None:
y_true, y_pred = sample_data
y_true = y_true_cls(y_true)
y_pred = y_pred_cls(y_pred)

metrics = ["r_squared", "rmse", "nrmse", "mae", "nmae", "mape"]

results = calculate_metric_distributions(y_true, y_pred, metrics)
Expand All @@ -56,14 +67,12 @@ def test_calculate_metric_distributions_all_metrics(sample_data) -> None:
assert all(len(results[metric]) == y_pred.shape[1] for metric in metrics)

# Value range checks
assert all(0 <= results["r_squared"]) and all(
results["r_squared"] <= 1
) # R² between 0 and 1
assert all(results["rmse"] >= 0) # RMSE non-negative
assert all(results["nrmse"] >= 0) # NRMSE non-negative
assert all(results["mae"] >= 0) # MAE non-negative
assert all(results["nmae"] >= 0) # NMAE non-negative
assert all(results["mape"] >= 0) # MAPE non-negative
assert all(0 <= results["r_squared"]) and all(results["r_squared"] <= 1)
assert all(results["rmse"] >= 0)
assert all(results["nrmse"] >= 0)
assert all(results["mae"] >= 0)
assert all(results["nmae"] >= 0)
assert all(results["mape"] >= 0)


def test_calculate_metric_distributions_default_metrics(sample_data) -> None:
Expand Down

0 comments on commit 5e83235

Please sign in to comment.