Skip to content

Commit

Permalink
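Fix naming issues: rename the nuisance-model helpers and variables from "classifiers" to "predictors" (outcome models are now named regressors)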
Browse files (browse the repository at this point in the history)
  • Loading branch information
houssamzenati committed Feb 7, 2024
1 parent 6b96b6c commit 3f1b345
Show file tree
Hide file tree
Showing 2 changed files with 525 additions and 15 deletions.
32 changes: 17 additions & 15 deletions src/med_bench/mediation.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,10 @@
from sklearn.model_selection import KFold
from sklearn.preprocessing import PolynomialFeatures

from .utils.classifiers import (_estimate_f_mu, _estimate_f_mu_cross_mu,
from .utils.nuisances import (_estimate_f_mu, _estimate_f_mu_cross_mu,
_estimate_px, _estimate_px_mu_cross_mu,
_get_x_classifiers, _get_x_y_classifiers,
_get_y_m_classifiers, _get_y_m_x_classifiers)
_get_t_predictors, _get_t_y_predictors,
_get_y_m_predictors, _get_y_m_t_predictors)
from .utils.utils import _convert_array_to_R

pandas2ri.activate()
Expand Down Expand Up @@ -117,8 +117,8 @@ def mediation_IPW(y, t, m, x, w, z, trim, logit, regularization=True, forest=Fal
calibration mode; for example using a sigmoid function
"""

classifier_x, classifier_xm = _get_x_classifiers(regularization, forest, calibration, calib_method)
p_x, p_xm = _estimate_px(t, m, x, crossfit, classifier_x, classifier_xm)
classifier_t_x, classifier_t_xm = _get_t_predictors(regularization, forest, calibration, calib_method)
p_x, p_xm = _estimate_px(t, m, x, crossfit, classifier_t_x, classifier_t_xm)

if z is not None:
raise NotImplementedError
Expand Down Expand Up @@ -251,8 +251,8 @@ def mediation_g_formula(y, t, m, x, interaction=False, forest=False,
1e-5 and 1e5
"""

classifier_y, classifier_m = _get_y_m_classifiers(regularization, forest, calibration, calib_method)
f, mu = _estimate_f_mu(t, m, x, y, crossfit, classifier_y, classifier_m, interaction)
regressor_y, classifier_m = _get_y_m_predictors(regularization, forest, calibration, calib_method)
f, mu = _estimate_f_mu(t, m, x, y, crossfit, regressor_y, classifier_m, interaction)
f_00x, f_01x, f_10x, f_11x = f
mu_11x, mu_10x, mu_01x, mu_00x = mu

Expand Down Expand Up @@ -316,8 +316,8 @@ def alternative_estimator(y, t, m, x, regularization=True):
# computation of direct effect
y_treated_reg_m = RidgeCV(alphas=alphas, cv=CV_FOLDS).fit(np.hstack((x[treated], m[treated])), y[treated])
y_ctrl_reg_m = RidgeCV(alphas=alphas, cv=CV_FOLDS).fit(np.hstack((x[~treated], m[~treated])), y[~treated])
stack_x_m = np.hstack((x, m))
direct_effect = np.sum(y_treated_reg_m.predict(stack_x_m) - y_ctrl_reg_m.predict(stack_x_m)) / len(y)
xm = np.hstack((x, m))
direct_effect = np.sum(y_treated_reg_m.predict(xm) - y_ctrl_reg_m.predict(xm)) / len(y)

# computation of total effect
y_treated_reg = RidgeCV(alphas=alphas, cv=CV_FOLDS).fit(x[treated], y[treated])
Expand Down Expand Up @@ -454,10 +454,10 @@ def mediation_multiply_robust(
if n != len(x) or n != len(m) or n != len(t):
raise ValueError("Inputs don't have the same number of observations")

classifier_y, cross_y_clf, classifier_m, classifier_x = _get_y_m_x_classifiers(regularization, forest, calibration,
calib_method)
p_x, f, mu, cross_mu = _estimate_f_mu_cross_mu(t, m, x, y, crossfit, classifier_y, cross_y_clf, classifier_m,
classifier_x, interaction)
regressor_y, regressor_cross_y, classifier_m, classifier_t_x = _get_y_m_t_predictors(regularization, forest,
calibration, calib_method)
p_x, f, mu, cross_mu = _estimate_f_mu_cross_mu(t, m, x, y, crossfit, regressor_y, regressor_cross_y, classifier_m,
classifier_t_x, interaction)
f_m0x, f_m1x = f
mu_0mx, mu_1mx = mu
E_mu_t0_t0, E_mu_t0_t1, E_mu_t1_t0, E_mu_t1_t1 = cross_mu
Expand Down Expand Up @@ -772,8 +772,10 @@ def mediation_DML(
"mu_0x",
]

clf_x, clf_xm, clf_y, clf_cross_y = _get_x_y_classifiers(regularization, use_forest, calib_method, random_state)
p, mu, cross_mu = _estimate_px_mu_cross_mu(t, m, x, y, crossfit, clf_x, clf_xm, clf_y, clf_cross_y)
classifier_t_x, classifier_t_xm, regressor_y, regressor_cross_y = _get_t_y_predictors(regularization, use_forest,
calib_method, random_state)
p, mu, cross_mu = _estimate_px_mu_cross_mu(t, m, x, y, crossfit, classifier_t_x, classifier_t_xm,
regressor_y, regressor_cross_y)

p_x, p_xm = p
mu_1mx, mu_1mx_nested, mu_0mx, mu_0mx_nested, mu_1x, mu_0x = mu
Expand Down
Loading

0 comments on commit 3f1b345

Please sign in to comment.