Skip to content

Commit

Permalink
#76 improve DID plotting + improve data pre-processing
Browse files Browse the repository at this point in the history
  • Loading branch information
drbenvincent committed Nov 20, 2022
1 parent 90cc898 commit b1310a6
Show file tree
Hide file tree
Showing 3 changed files with 261 additions and 234 deletions.
93 changes: 55 additions & 38 deletions causalpy/pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,8 @@ def __init__(

# TODO: check that data in column self.group_variable_name has TWO levels

# TODO: check we have `unit` as a predictor column which is a vector of labels of unique units

# TODO: `treated` is a deterministic function of group and time, so this should be a function rather than supplied data

# DEVIATION FROM SKL EXPERIMENT CODE =============================
Expand Down Expand Up @@ -303,18 +305,17 @@ def plot(self):

# Plot raw data
# NOTE: This will not work when there is just ONE unit in each group
# sns.lineplot(
# self.data,
# x=self.time_variable_name,
# y=self.outcome_variable_name,
# hue=self.group_variable_name,
# # units="unit",
# estimator=None,
# alpha=0.25,
# ax=ax,
# )
sns.lineplot(
self.data,
x=self.time_variable_name,
y=self.outcome_variable_name,
hue=self.group_variable_name,
units="unit", # NOTE: assumes we have a `unit` predictor variable
estimator=None,
alpha=0.5,
ax=ax,
)
# Plot model fit to control group
# NOTE: This will not work when there is just ONE unit in each group
parts = ax.violinplot(
az.extract(
self.y_pred_control, group="posterior_predictive", var_names="mu"
Expand All @@ -330,7 +331,6 @@ def plot(self):
pc.set_alpha(0.5)

# Plot model fit to treatment group
# NOTE: This will not work when there is just ONE unit in each group
parts = ax.violinplot(
az.extract(
self.y_pred_treatment, group="posterior_predictive", var_names="mu"
Expand All @@ -340,20 +340,41 @@ def plot(self):
showmedians=False,
widths=0.2,
)
# # Plot counterfactual - post-test for treatment group IF no treatment had occurred.
# # NOTE: This will not work when there is just ONE unit in each group
# parts = ax.violinplot(
# az.extract(
# self.y_pred_counterfactual,
# group="posterior_predictive",
# var_names="mu",
# ).values.T,
# positions=self.x_pred_counterfactual[self.time_variable_name].values,
# showmeans=False,
# showmedians=False,
# widths=0.2,
# )
for pc in parts["bodies"]:
pc.set_facecolor("C1")
pc.set_edgecolor("None")
pc.set_alpha(0.5)
# Plot counterfactual - post-test for treatment group IF no treatment had occurred.
parts = ax.violinplot(
az.extract(
self.y_pred_counterfactual,
group="posterior_predictive",
var_names="mu",
).values.T,
positions=self.x_pred_counterfactual[self.time_variable_name].values,
showmeans=False,
showmedians=False,
widths=0.2,
)
for pc in parts["bodies"]:
pc.set_facecolor("C2")
pc.set_edgecolor("None")
pc.set_alpha(0.5)
# arrow to label the causal impact
self._plot_causal_impact_arrow(ax)
# formatting
ax.set(
xticks=self.x_pred_treatment[self.time_variable_name].values,
title=self._causal_impact_summary_stat(),
)
ax.legend(fontsize=LEGEND_FONT_SIZE)
return (fig, ax)

def _plot_causal_impact_arrow(self, ax):
"""
draw a vertical arrow between `y_pred_treatment` and `y_pred_counterfactual`
"""
# Calculate y values to plot the arrow between
y_pred_treatment = (
self.y_pred_treatment["posterior_predictive"]
.mu.isel({"obs_ind": 1})
Expand All @@ -363,32 +384,28 @@ def plot(self):
y_pred_counterfactual = (
self.y_pred_counterfactual["posterior_predictive"].mu.mean().data
)
# Calculate the x position to plot at
diff = np.ptp(self.x_pred_treatment[self.time_variable_name].values)
x = np.max(self.x_pred_treatment[self.time_variable_name].values) + 0.1 * diff
# Plot the arrow
ax.annotate(
"",
xy=(1.15, y_pred_counterfactual),
xy=(x, y_pred_counterfactual),
xycoords="data",
xytext=(1.15, y_pred_treatment),
xytext=(x, y_pred_treatment),
textcoords="data",
arrowprops={"arrowstyle": "<->", "color": "green", "lw": 3},
arrowprops={"arrowstyle": "<-", "color": "green", "lw": 3},
)
# Plot text annotation next to arrow
ax.annotate(
"causal\nimpact",
xy=(1.15, np.mean([y_pred_counterfactual, y_pred_treatment])),
xy=(x, np.mean([y_pred_counterfactual, y_pred_treatment])),
xycoords="data",
xytext=(5, 0),
textcoords="offset points",
color="green",
va="center",
)
# formatting
ax.set(
# xlim=[-0.15, 1.25],
xticks=self.x_pred_treatment[self.time_variable_name].values,
# xticklabels=["pre", "post"],
title=self._causal_impact_summary_stat(),
)
ax.legend(fontsize=LEGEND_FONT_SIZE)
return (fig, ax)

def _causal_impact_summary_stat(self):
percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values
Expand Down
138 changes: 112 additions & 26 deletions docs/notebooks/did_pymc.ipynb

Large diffs are not rendered by default.

264 changes: 94 additions & 170 deletions docs/notebooks/did_pymc_banks.ipynb

Large diffs are not rendered by default.

0 comments on commit b1310a6

Please sign in to comment.