Skip to content

Commit

Permalink
#50 #44 alcohol RD example and update corresponding plots
Browse files (browse the repository at this point in the history)
  • Loading branch information
drbenvincent committed Nov 3, 2022
1 parent 58c3e85 commit fb63684
Show file tree
Hide file tree
Showing 6 changed files with 4,044 additions and 6,158 deletions.
51 changes: 16 additions & 35 deletions causalpy/pymc_experiments.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -308,63 +308,44 @@ def __init__(
self.score = self.prediction_model.score(X=self.X, y=self.y)

# get the model predictions of the observed data
xi = np.linspace(np.min(self.data["x"]), np.max(self.data["x"]), 200)
self.x_pred = pd.DataFrame({"x": xi, "treated": self._is_treated(xi)})
xi = np.linspace(
np.min(self.data[self.running_variable_name]),
np.max(self.data[self.running_variable_name]),
200,
)
self.x_pred = pd.DataFrame(
{self.running_variable_name: xi, "treated": self._is_treated(xi)}
)
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred)
self.pred = self.prediction_model.predict(X=np.asarray(new_x))

# calculate the counterfactual
xi = xi[xi > self.treatment_threshold]
self.x_counterfact = pd.DataFrame({"x": xi, "treated": np.zeros(xi.shape)})
self.x_counterfact = pd.DataFrame(
{self.running_variable_name: xi, "treated": np.zeros(xi.shape)}
)
(new_x,) = build_design_matrices([self._x_design_info], self.x_counterfact)
self.pred_counterfac = self.prediction_model.predict(X=np.asarray(new_x))

# Flag which running-variable values count as treated. Delegates to
# np.greater_equal, so the result is True wherever x >= self.treatment_threshold
# (a value exactly at the threshold is flagged as treated). The result is used
# to fill the "treated" column of the prediction DataFrame built in __init__.
def _is_treated(self, x):
return np.greater_equal(x, self.treatment_threshold)

def plot(self):
fig, ax = plt.subplots(2, 1, figsize=(7, 8))
fig, ax = plt.subplots()
# Plot raw data
sns.scatterplot(
self.data,
x=self.running_variable_name,
y=self.outcome_variable_name,
c="k", # hue="treated",
ax=ax[0],
ax=ax,
)
# Plot model fit to data
plot_xY(
self.x_pred[self.running_variable_name],
self.pred["posterior_predictive"].y_hat,
ax=ax[0],
)
# # Plot counterfactual
plot_xY(
self.x_counterfact[self.running_variable_name],
self.pred_counterfac["posterior_predictive"].y_hat,
ax=ax[0],
plot_hdi_kwargs={"color": "C2"},
)
# Shaded causal effect
# TODO
# Intervention line
ax[0].axvline(
x=self.treatment_threshold,
ls="-",
lw=3,
color="r",
label="treatment threshold",
)
ax[0].set(title=f"$R^2$ on all data = {self.score:.3f}")
ax[0].legend(fontsize=LEGEND_FONT_SIZE)

# Plot causal effect estimate ------------------------
coeff_name = (
"treated[T.True]" # NOTE: get rid of this hard coded variable name!
)
beta = self.prediction_model.idata["posterior"]["beta"].sel(
{"coeffs": coeff_name}
ax=ax,
)
az.plot_posterior(beta, ref_val=0, ax=ax[1])
ax[1].set(title=f"Causal impact", xlabel=coeff_name)
ax.set(title=f"$R^2$ on all data = {self.score:.3f}")
ax.legend(fontsize=LEGEND_FONT_SIZE)
return (fig, ax)
56 changes: 11 additions & 45 deletions causalpy/skl_experiments.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -339,65 +339,31 @@ def _is_treated(self, x):
return np.greater_equal(x, self.treatment_threshold)

def plot(self):
fig, ax = plt.subplots(2, 1, sharex=True, figsize=(7, 8))
fig, ax = plt.subplots()
# Plot raw data
sns.scatterplot(
self.data,
x=self.running_variable_name,
y=self.outcome_variable_name,
c="k", # hue="treated",
ax=ax[0],
ax=ax,
)
# Plot model fit to data
ax[0].plot(
ax.plot(
self.x_pred[self.running_variable_name],
self.pred,
"k",
markersize=10,
label="model fit",
)
# Plot counterfactual
ax[0].plot(
self.x_counterfact[self.running_variable_name],
self.pred_counterfac,
markersize=10,
ls=":",
c="k",
label="counterfactual",
)
# Shaded causal effect
ax[0].fill_between(
self.x_counterfact[self.running_variable_name],
y1=np.squeeze(self.pred_counterfac),
y2=np.squeeze(self.pred[-len(np.squeeze(self.pred_counterfac)) :]),
color="C0",
alpha=0.25,
label="inferred causal impact",
)
ax[0].set(title=f"$R^2$ on all data = {self.score:.3f}")

# Plot causal effect ------------------------
# NOTE: plotting residual data point requires another call to self.prediction_model.predict() for the treated unit x-values.
ax[1].plot(
self.x_counterfact[self.running_variable_name],
self.causal_impact,
markersize=10,
ax.set(title=f"$R^2$ on all data = {self.score:.3f}")
# Intervention line
ax.axvline(
x=self.treatment_threshold,
ls="-",
c="k",
label="causal impact",
lw=3,
color="r",
label="treatment threshold",
)
ax[1].axhline(y=0, c="k")
ax[1].set(title=f"Causal impact", xlabel=self.running_variable_name)

# Intervention line
for i in [0, 1]:
ax[i].axvline(
x=self.treatment_threshold,
ls="-",
lw=3,
color="r",
label="treatment threshold",
)
ax[i].legend(fontsize=LEGEND_FONT_SIZE)

ax.legend(fontsize=LEGEND_FONT_SIZE)
return (fig, ax)
Loading

0 comments on commit fb63684

Please sign in to comment.