Skip to content

Commit

Permalink
#76 improve DiD plotting + rerun notebooks
Browse files Browse the repository at this point in the history
  • Loading branch information
drbenvincent committed Dec 26, 2022
1 parent f0aefd4 commit 36b4511
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 83 deletions.
83 changes: 41 additions & 42 deletions causalpy/pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,53 +370,52 @@ def plot(self):
alpha=0.5,
ax=ax,
)

# Plot model fit to control group
parts = ax.violinplot(
az.extract(
self.y_pred_control, group="posterior_predictive", var_names="mu"
).values.T,
positions=self.x_pred_control[self.time_variable_name].values,
showmeans=False,
showmedians=False,
widths=0.2,
)
for pc in parts["bodies"]:
pc.set_facecolor("C0")
pc.set_edgecolor("None")
pc.set_alpha(0.5)
time_points = self.x_pred_control[self.time_variable_name].values
plot_xY(
time_points,
self.y_pred_control.posterior_predictive.y_hat,
ax=ax,
plot_hdi_kwargs={"color": "C0"},
)

# Plot model fit to treatment group
parts = ax.violinplot(
az.extract(
self.y_pred_treatment, group="posterior_predictive", var_names="mu"
).values.T,
positions=self.x_pred_treatment[self.time_variable_name].values,
showmeans=False,
showmedians=False,
widths=0.2,
)

for pc in parts["bodies"]:
pc.set_facecolor("C1")
pc.set_edgecolor("None")
pc.set_alpha(0.5)
time_points = self.x_pred_control[self.time_variable_name].values
plot_xY(
time_points,
self.y_pred_treatment.posterior_predictive.y_hat,
ax=ax,
plot_hdi_kwargs={"color": "C1"},
)

# Plot counterfactual - post-test for treatment group IF no treatment
# had occurred.
parts = ax.violinplot(
az.extract(
self.y_pred_counterfactual,
group="posterior_predictive",
var_names="mu",
).values.T,
positions=self.x_pred_counterfactual[self.time_variable_name].values,
showmeans=False,
showmedians=False,
widths=0.2,
)
for pc in parts["bodies"]:
pc.set_facecolor("C2")
pc.set_edgecolor("None")
pc.set_alpha(0.5)
time_points = self.x_pred_counterfactual[self.time_variable_name].values
if len(time_points) == 1:
parts = ax.violinplot(
az.extract(
self.y_pred_counterfactual,
group="posterior_predictive",
var_names="mu",
).values.T,
positions=self.x_pred_counterfactual[self.time_variable_name].values,
showmeans=False,
showmedians=False,
widths=0.2,
)
for pc in parts["bodies"]:
pc.set_facecolor("C2")
pc.set_edgecolor("None")
pc.set_alpha(0.5)
else:
plot_xY(
time_points,
self.y_pred_counterfactual.posterior_predictive.y_hat,
ax=ax,
plot_hdi_kwargs={"color": "C2"},
)

# arrow to label the causal impact
self._plot_causal_impact_arrow(ax)
# formatting
Expand Down
10 changes: 5 additions & 5 deletions docs/notebooks/did_pymc.ipynb

Large diffs are not rendered by default.

56 changes: 20 additions & 36 deletions docs/notebooks/did_pymc_banks.ipynb

Large diffs are not rendered by default.

0 comments on commit 36b4511

Please sign in to comment.