Skip to content

Commit

Permalink
Add optional ax argument to CLV plot utilities (pymc-labs#344)
Browse files Browse the repository at this point in the history
  • Loading branch information
wd60622 authored Aug 10, 2023
1 parent 8c38233 commit 601fca8
Showing 1 changed file with 38 additions and 18 deletions.
56 changes: 38 additions & 18 deletions pymc_marketing/clv/plotting.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from typing import Optional

import matplotlib.pyplot as plt
import numpy as np

Expand All @@ -10,16 +12,18 @@
def plot_frequency_recency_matrix(
model,
t=1,
max_frequency=None,
max_recency=None,
title=None,
xlabel="Historical Frequency",
ylabel="Recency",
max_frequency: Optional[int] = None,
max_recency: Optional[int] = None,
title: Optional[str] = None,
xlabel: str = "Customer's Historical Frequency",
ylabel: str = "Customer's Recency",
ax: Optional[plt.Axes] = None,
**kwargs,
) -> plt.Axes:
"""
Plot recency frequency matrix as heatmap.
Plot a figure of expected transactions in T next units of time by a customer's frequency and recency.
Parameters
----------
model: lifetimes model
Expand All @@ -37,8 +41,11 @@ def plot_frequency_recency_matrix(
Figure xlabel
ylabel: str, optional
Figure ylabel
ax: plt.Axes, optional
A matplotlib axes instance. Creates new axes instance by default.
kwargs
Passed into the matplotlib.imshow command.
Returns
-------
axes: matplotlib.AxesSubplot
Expand All @@ -64,19 +71,23 @@ def plot_frequency_recency_matrix(
.mean(("draw", "chain"))
.values.reshape(mesh_recency.shape)
)
if ax is None:
ax = plt.subplot(111)

ax = plt.subplot(111)
pcm = ax.imshow(Z, **kwargs)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
if title is None:
title = (
"Expected Number of Future Purchases for {} Unit{} of Time,".format(
t, "s"[t == 1 :]
)
+ "\nby Frequency and Recency of a Customer"
)
plt.title(title)

ax.set(
xlabel=xlabel,
ylabel=ylabel,
title=title,
)

force_aspect(ax)

Expand All @@ -88,17 +99,19 @@ def plot_frequency_recency_matrix(

def plot_probability_alive_matrix(
model,
max_frequency=None,
max_recency=None,
title="Probability Customer is Alive,\nby Frequency and Recency of a Customer",
xlabel="Customer's Historical Frequency",
ylabel="Customer's Recency",
max_frequency: Optional[int] = None,
max_recency: Optional[int] = None,
title: str = "Probability Customer is Alive,\nby Frequency and Recency of a Customer",
xlabel: str = "Customer's Historical Frequency",
ylabel: str = "Customer's Recency",
ax: Optional[plt.Axes] = None,
**kwargs,
) -> plt.Axes:
"""
Plot probability alive matrix as heatmap.
Plot a figure of the probability a customer is alive based on their
frequency and recency.
Parameters
----------
model: lifetimes model
Expand All @@ -114,8 +127,11 @@ def plot_probability_alive_matrix(
Figure xlabel
ylabel: str, optional
Figure ylabel
ax: plt.Axes, optional
A matplotlib axes instance. Creates new axes instance by default.
kwargs
Passed into the matplotlib.imshow command.
Returns
-------
axes: matplotlib.AxesSubplot
Expand Down Expand Up @@ -143,12 +159,16 @@ def plot_probability_alive_matrix(

interpolation = kwargs.pop("interpolation", "none")

ax = plt.subplot(111)
if ax is None:
ax = plt.subplot(111)

pcm = ax.imshow(Z, interpolation=interpolation, **kwargs)
plt.xlabel(xlabel)
plt.ylabel(ylabel)
plt.title(title)

ax.set(
xlabel=xlabel,
ylabel=ylabel,
title=title,
)
force_aspect(ax)

# plot colorbar beside matrix
Expand Down

0 comments on commit 601fca8

Please sign in to comment.