Skip to content

Commit

Permalink
#76 stop evaluating for multiple units per time point
Browse files Browse the repository at this point in the history
  • Loading branch information
drbenvincent committed Dec 26, 2022
1 parent 0ed7240 commit e3bc8cd
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 79 deletions.
19 changes: 15 additions & 4 deletions causalpy/pymc_experiments.py
Original file line number Diff line number Diff line change
Expand Up @@ -298,9 +298,13 @@ def __init__(
self.x_pred_control = (
self.data
# just the untreated group
.query(f"{self.group_variable_name} == @self.untreated") # 🔥
.query(f"{self.group_variable_name} == @self.untreated")
# drop the outcome variable
.drop(self.outcome_variable_name, axis=1)
# We may have multiple units per time point, we only want one time point
.groupby(self.time_variable_name)
.first()
.reset_index()
)
assert not self.x_pred_control.empty
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_control)
Expand All @@ -310,9 +314,13 @@ def __init__(
self.x_pred_treatment = (
self.data
# just the treated group
.query(f"{self.group_variable_name} == @self.treated") # 🔥
.query(f"{self.group_variable_name} == @self.treated")
# drop the outcome variable
.drop(self.outcome_variable_name, axis=1)
# We may have multiple units per time point, we only want one time point
.groupby(self.time_variable_name)
.first()
.reset_index()
)
assert not self.x_pred_treatment.empty
(new_x,) = build_design_matrices([self._x_design_info], self.x_pred_treatment)
Expand All @@ -322,14 +330,17 @@ def __init__(
self.x_pred_counterfactual = (
self.data
# just the treated group
.query(f"{self.group_variable_name} == @self.treated") # 🔥
.query(f"{self.group_variable_name} == @self.treated")
# just the treatment period(s)
# TODO: the line below might need some work to be more robust
.query("post_treatment == True")
# drop the outcome variable
.drop(self.outcome_variable_name, axis=1)
# DO AN INTERVENTION. Set the post_treatment variable to False
.assign(post_treatment=False)
# We may have multiple units per time point, we only want one time point
.groupby(self.time_variable_name)
.first()
.reset_index()
)
assert not self.x_pred_counterfactual.empty
(new_x,) = build_design_matrices(
Expand Down
44 changes: 21 additions & 23 deletions docs/notebooks/did_pymc.ipynb

Large diffs are not rendered by default.

122 changes: 70 additions & 52 deletions docs/notebooks/did_pymc_banks.ipynb

Large diffs are not rendered by default.

0 comments on commit e3bc8cd

Please sign in to comment.