Skip to content

Commit

Permalink
keep date column in test_scale_lift_measurements (pymc-labs#1316)
Browse files Browse the repository at this point in the history
  • Loading branch information
malitsadok1 authored Jan 14, 2025
1 parent 8d94482 commit 169eb1f
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 3 deletions.
14 changes: 12 additions & 2 deletions pymc_marketing/mmm/lift_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,9 +643,19 @@ def scale_lift_measurements(
target_transform,
)

if "date" in df_lift_test.columns:
return pd.concat(
[
df_lift_test_channel_scaled,
df_target_scaled,
df_sigma_scaled,
pd.Series(df_lift_test["date"]),
],
axis=1,
)

return pd.concat(
[df_lift_test_channel_scaled, df_target_scaled, df_sigma_scaled],
axis=1,
[df_lift_test_channel_scaled, df_target_scaled, df_sigma_scaled], axis=1
)


Expand Down
7 changes: 6 additions & 1 deletion tests/mmm/test_lift_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -474,9 +474,14 @@ def test_scale_lift_measurements(df_lift_test_with_numerics) -> None:
delta_x=lambda row: row["delta_x"] * 2.0,
delta_y=lambda row: row["delta_y"] / 2,
sigma=lambda row: row["sigma"] / 2,
).loc[:, ["channel", "x", "delta_x", "delta_y", "sigma"]]
).loc[
:,
["channel", "x", "delta_x", "delta_y", "sigma"]
+ (["date"] if "date" in df_lift_test_with_numerics.columns else []),
]

pd.testing.assert_frame_equal(
result,
expected,
check_like=True,
)
34 changes: 34 additions & 0 deletions tests/mmm/test_mmm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1101,6 +1101,20 @@ def df_lift_test() -> pd.DataFrame:
)


@pytest.fixture
def df_lift_test_with_date() -> pd.DataFrame:
return pd.DataFrame(
{
"channel": ["channel_1", "channel_1"],
"x": [1, 2],
"delta_x": [1, 1],
"delta_y": [1, 1],
"sigma": [1, 1],
"date": pd.to_datetime(["2020-08-10", "2020-08-31"]),
}
)


def test_add_lift_test_measurements(mmm, toy_X, toy_y, df_lift_test) -> None:
mmm.build_model(X=toy_X, y=toy_y)

Expand Down Expand Up @@ -1314,3 +1328,23 @@ def test_channel_contributions_forward_pass_time_varying_media(toy_X, toy_y) ->
recovered_contributions.to_numpy(),
media_contributions,
)


def test_time_varying_media_with_lift_test(
toy_X, toy_y, df_lift_test_with_date
) -> None:
mmm = MMM(
date_column="date",
channel_columns=["channel_1", "channel_2"],
control_columns=["control_1", "control_2"],
adstock=GeometricAdstock(l_max=2),
saturation=LogisticSaturation(),
time_varying_media=True,
)
mmm.build_model(X=toy_X, y=toy_y)
try:
mmm.add_lift_test_measurements(df_lift_test_with_date)
except Exception as e:
pytest.fail(
f"add_lift_test_measurements for time_varying_media model failed with error {e}"
)

0 comments on commit 169eb1f

Please sign in to comment.