From 362c530d088ee3e6e72361632df9d6a841b9560c Mon Sep 17 00:00:00 2001 From: Houssam Zenati Date: Wed, 13 Dec 2023 15:17:44 +0100 Subject: [PATCH] corrected the issues for the PR --- src/.DS_Store | Bin 0 -> 6148 bytes src/med_bench/get_estimation.py | 114 ++++++++------- src/med_bench/mediation.py | 6 +- src/tests/.DS_Store | Bin 0 -> 6148 bytes .../estimation}/test_get_estimation.py | 46 +++--- .../simulate_data}/test_get_simulated_data.py | 2 +- src/treatment_effect.py | 131 ++++++++++++++++++ 7 files changed, 219 insertions(+), 80 deletions(-) create mode 100644 src/.DS_Store create mode 100644 src/tests/.DS_Store rename src/{ => tests/estimation}/test_get_estimation.py (80%) rename src/{ => tests/simulate_data}/test_get_simulated_data.py (99%) diff --git a/src/.DS_Store b/src/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..6d2aa7cc96193e0290092c5210100e7aa38e11f8 GIT binary patch literal 6148 zcmeH~&u`N(6vv;pj;2gVl>mv|EOD(yN71xtm(Y#Ft^~mWP)L@*BC@z@QmUz{lso)G z{3HA?J8j=*ds5bxYXy^E_5Ag(XEz_kc1%QKI7#-1xe%p z-|;BV%BtW0BQ`c$TMyb^+uQa&2j^-UR8ci4htc>Wm)0g#6;+fK z=~xG3Ns2C?K4e9rrb9I;l2XS8roroYov^zzo9*o%9`xj3a4_%5+2PB6Pre%L&*vTQ z>GKyyZ^Q5TMWKG`KTiUim9|G#Z{-&WhL7TBaBXOloPqjQYon-vxeo!2!6lBsS_#|$ D;z_mX literal 0 HcmV?d00001 diff --git a/src/med_bench/get_estimation.py b/src/med_bench/get_estimation.py index dbb7b99..a909b95 100644 --- a/src/med_bench/get_estimation.py +++ b/src/med_bench/get_estimation.py @@ -7,8 +7,16 @@ from rpy2.rinterface_lib.embedded import RRuntimeError import pandas as pd import numpy as np -from .mediation import mediation_IPW, mediation_coefficient_products, mediation_g_formula,mediation_multiply_robust, r_g_estimator, r_medDML, r_mediate, mediation_DML - +from .mediation import ( + mediation_IPW, + mediation_coefficient_products, + mediation_g_computation, + mediation_multiply_robust, + mediation_DML, + r_mediation_g_computation, + r_mediation_DML, + r_mediate, +) def get_estimation(x, t, m, y, estimator, config): """Wrapper estimator fonction ; calls an estimator given mediation data @@ -285,9 +293,9 @@ def get_estimation(x, t, m, y, estimator, config): calibration=True, calib_method="isotonic", ) - elif estimator == "g_computation_noreg": + elif estimator == "mediation_g_computation_noreg": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -298,9 +306,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=False, calibration=False, ) - elif estimator == "g_computation_noreg_cf": + elif estimator == "mediation_g_computation_noreg_cf": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -311,9 +319,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=False, calibration=False, ) - elif estimator == "g_computation_reg": + elif estimator == "mediation_g_computation_reg": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -324,9 +332,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=False, ) - elif estimator == "g_computation_reg_cf": + elif estimator == "mediation_g_computation_reg_cf": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -337,9 +345,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=False, ) - elif estimator == "g_computation_reg_calibration": + elif estimator == "mediation_g_computation_reg_calibration": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -350,9 +358,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=True, ) - elif estimator == "g_computation_reg_calibration_iso": + elif estimator == "mediation_g_computation_reg_calibration_iso": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -364,9 +372,9 @@ def get_estimation(x, t, m, y, estimator, config): calibration=True, calib_method="isotonic", ) - elif estimator == "g_computation_reg_calibration_cf": + elif estimator == "mediation_g_computation_reg_calibration_cf": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -377,9 +385,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=True, ) - elif estimator == "g_computation_reg_calibration_iso_cf": + elif estimator == "mediation_g_computation_reg_calibration_iso_cf": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -391,9 +399,9 @@ def get_estimation(x, t, m, y, estimator, config): calibration=True, calib_method="isotonic", ) - elif estimator == "g_computation_forest": + elif estimator == "mediation_g_computation_forest": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -404,9 +412,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=False, ) - elif estimator == "g_computation_forest_cf": + elif estimator == "mediation_g_computation_forest_cf": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -417,9 +425,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=False, ) - elif estimator == "g_computation_forest_calibration": + elif estimator == "mediation_g_computation_forest_calibration": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -430,9 +438,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=True, ) - elif estimator == "g_computation_forest_calibration_iso": + elif estimator == "mediation_g_computation_forest_calibration_iso": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -444,9 +452,9 @@ def get_estimation(x, t, m, y, estimator, config): calibration=True, calib_method="isotonic", ) - elif estimator == "g_computation_forest_calibration_cf": + elif estimator == "mediation_g_computation_forest_calibration_cf": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -457,9 +465,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=True, ) - elif estimator == "g_computation_forest_calibration_iso_cf": + elif estimator == "mediation_g_computation_forest_calibration_iso_cf": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -471,7 +479,7 @@ def get_estimation(x, t, m, y, estimator, config): calibration=True, calib_method="isotonic", ) - elif estimator == "multiply_robust_noreg": + elif estimator == "mediation_multiply_robust_noreg": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -485,7 +493,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=False, calibration=False, ) - elif estimator == "multiply_robust_noreg_cf": + elif estimator == "mediation_multiply_robust_noreg_cf": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -499,7 +507,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=False, calibration=False, ) - elif estimator == "multiply_robust_reg": + elif estimator == "mediation_multiply_robust_reg": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -513,7 +521,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=False, ) - elif estimator == "multiply_robust_reg_cf": + elif estimator == "mediation_multiply_robust_reg_cf": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -527,7 +535,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=False, ) - elif estimator == "multiply_robust_reg_calibration": + elif estimator == "mediation_multiply_robust_reg_calibration": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -541,7 +549,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=True, ) - elif estimator == "multiply_robust_reg_calibration_iso": + elif estimator == "mediation_multiply_robust_reg_calibration_iso": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -556,7 +564,7 @@ def get_estimation(x, t, m, y, estimator, config): calibration=True, calib_method="isotonic", ) - elif estimator == "multiply_robust_reg_calibration_cf": + elif estimator == "mediation_multiply_robust_reg_calibration_cf": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -570,7 +578,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=True, ) - elif estimator == "multiply_robust_reg_calibration_iso_cf": + elif estimator == "mediation_multiply_robust_reg_calibration_iso_cf": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -585,7 +593,7 @@ def get_estimation(x, t, m, y, estimator, config): calibration=True, calib_method="isotonic", ) - elif estimator == "multiply_robust_forest": + elif estimator == "mediation_multiply_robust_forest": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -599,7 +607,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=False, ) - elif estimator == "multiply_robust_forest_cf": + elif estimator == "mediation_multiply_robust_forest_cf": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -613,7 +621,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=False, ) - elif estimator == "multiply_robust_forest_calibration": + elif estimator == "mediation_multiply_robust_forest_calibration": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -627,7 +635,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=True, ) - elif estimator == "multiply_robust_forest_calibration_iso": + elif estimator == "mediation_multiply_robust_forest_calibration_iso": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -642,7 +650,7 @@ def get_estimation(x, t, m, y, estimator, config): calibration=True, calib_method="isotonic", ) - elif estimator == "multiply_robust_forest_calibration_cf": + elif estimator == "mediation_multiply_robust_forest_calibration_cf": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -656,7 +664,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=True, ) - elif estimator == "multiply_robust_forest_calibration_iso_cf": + elif estimator == "mediation_multiply_robust_forest_calibration_iso_cf": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -674,22 +682,22 @@ def get_estimation(x, t, m, y, estimator, config): elif estimator == "simulation_based": if config in (0, 1, 2): effects = r_mediate(y, t, m, x, interaction=False) - elif estimator == "DML_mediation": + elif estimator == "mediation_DML": if config > 0: - effects = r_medDML(y, t, m, x, trim=0.0, order=1) - elif estimator == "med_dml_noreg": + effects = r_mediation_DML(y, t, m, x, trim=0.0, order=1) + elif estimator == "mediation_DML_noreg": effects = mediation_DML(x, t, m, y, trim=0, regularization=False) - elif estimator == "med_dml_reg": + elif estimator == "mediation_DML_reg": effects = mediation_DML(x, t, m, y, trim=0) - elif estimator == "med_dml_reg_fixed_seed": + elif estimator == "mediation_DML_reg_fixed_seed": effects = mediation_DML(x, t, m, y, trim=0, random_state=321) - elif estimator == "med_dml_noreg_cf": + elif estimator == "mediation_DML_noreg_cf": effects = mediation_DML(x, t, m, y, trim=0, crossfit=4, regularization=False) - elif estimator == "med_dml_reg_cf": + elif estimator == "mediation_DML_reg_cf": effects = mediation_DML(x, t, m, y, trim=0, crossfit=4) - elif estimator == "G_estimator": + elif estimator == "mediation_g_computation": if config in (0, 1, 2): - effects = r_g_estimator(y, t, m, x) + effects = r_mediation_g_computation(y, t, m, x) else: raise ValueError("Unrecognized estimator label.") if effects is None: diff --git a/src/med_bench/mediation.py b/src/med_bench/mediation.py index 72d111e..2edd551 100644 --- a/src/med_bench/mediation.py +++ b/src/med_bench/mediation.py @@ -323,7 +323,7 @@ def mediation_coefficient_products(y, t, m, x, interaction=False, regularization None] -def mediation_g_formula(y, t, m, x, interaction=False, forest=False, +def mediation_g_computation(y, t, m, x, interaction=False, forest=False, crossfit=0, calibration=True, regularization=True, calib_method='sigmoid'): """ @@ -889,7 +889,7 @@ def r_mediate(y, t, m, x, interaction=False): return to_return + [None] -def r_g_estimator(y, t, m, x): +def r_mediation_g_computation(y, t, m, x): m = m.ravel() var_names = [[y, 'y'], [t, 't'], @@ -926,7 +926,7 @@ def r_g_estimator(y, t, m, x): indirect_effect, None] -def r_medDML(y, t, m, x, trim=0.05, order=1): +def r_mediation_DML(y, t, m, x, trim=0.05, order=1): """ y array-like, shape (n_samples) outcome value for each unit, continuous diff --git a/src/tests/.DS_Store b/src/tests/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..70d7305845d50f43d9ab3e9f3379b9b7bc581bf3 GIT binary patch literal 6148 zcmeHKO;6iE5SLlnvP_P_MZ zAK~x7%*`BoYc;{=orN4EyCXRQyw*-2N@B10@c2PhS_<#z~`k>JXeTS7rUOKSgD*z%LBiS&f<05giL*HR# z5qr>-r6O9Ya#svx>DaDap6{@-Xz8HbifZd7k|m^ty4cod#y&lKqsTT%Hksh1Kx@;m$u>+bT(|) X6d?KzD~s5ICI 1 - trim] = 1 - trim + y1m1 = np.sum(y * t / p_x) / np.sum(t / p_x) + y0m0 = np.sum(y * (1 - t) / (1 - p_x)) /\ + np.sum((1 - t) / (1 - p_x)) + return y1m1 - y0m0 + + +def AIPW(y, t, m, x, clip=0.01, forest=False, crossfit=0, forest_r=False, + regularization=True): + """ + AIPW estimator + """ + if regularization: + alphas = ALPHAS + cs = ALPHAS + else: + alphas = [TINY] + cs = [np.inf] + n = len(y) + if len(x.shape) == 1: + x = x.reshape(-1, 1) + if crossfit < 2: + train_test_list = [[np.arange(n), np.arange(n)]] + else: + kf = KFold(n_splits=crossfit) + train_test_list = list() + for train_index, test_index in kf.split(x): + train_test_list.append([train_index, test_index]) + + if not forest_r: + mu_1x, mu_0x, e_x = [np.zeros(n) for h in range(3)] + + for train_index, test_index in train_test_list: + treated_train_index = np.array( + list(set(train_index).intersection(np.where(t == 1)[0]))) + control_train_index = np.array( + list(set(train_index).intersection(np.where(t == 0)[0]))) + if not forest: + y_reg_treated = RidgeCV(alphas=alphas, cv=CV_FOLDS)\ + .fit(x[treated_train_index, :], y[treated_train_index]) + y_reg_control = RidgeCV(alphas=alphas, cv=CV_FOLDS)\ + .fit(x[control_train_index, :], y[control_train_index]) + t_prob = LogisticRegressionCV(Cs=cs, cv=CV_FOLDS)\ + .fit(x[train_index, :], t[train_index]) + else: + y_reg_treated = RandomForestRegressor(max_depth=3, + min_samples_leaf=5)\ + .fit(x[treated_train_index, :], y[treated_train_index]) + y_reg_control = RandomForestRegressor(max_depth=3, + min_samples_leaf=5)\ + .fit(x[control_train_index, :], y[control_train_index]) + t_prob = CalibratedClassifierCV( + RandomForestClassifier(max_depth=3, min_samples_leaf=5))\ + .fit(x[train_index, :], t[train_index]) + mu_1x[test_index] = y_reg_treated.predict(x[test_index, :]) + mu_0x[test_index] = y_reg_control.predict(x[test_index, :]) + e_x[test_index] = t_prob.predict_proba(x[test_index, :])[:, 1] + e_x = np.clip(e_x, clip, 1 - clip) + total_effect = np.mean(mu_1x - mu_0x + t * (y - mu_1x) / e_x - + (1 - t) * (y - mu_0x) / (1 - e_x)) + else: + x_r, t_r, y_r = [_convert_array_to_R(uu) for uu in (x, t, y)] + cf = grf.causal_forest(x_r, y_r, t_r, num_trees=500) + total_effect = grf.average_treatment_effect(cf)[0] + return [total_effect] + [None] * 5 + + +def bart(y, t, m, x, tmle=False): + x_r, t_r, y_r = [_convert_array_to_R(uu) for uu in (x, t, y)] + if tmle: + bart_model = brc.bartc(y_r, t_r, x_r, method_rsp='tmle', + estimand='ate') + else: + bart_model = brc.bartc(y_r, t_r, x_r, method_rsp='p.weight', + estimand='ate') + ate = brc.summary_bartcFit(bart_model).rx2('estimates')[0][0] + return [ate] + [None] * 5 \ No newline at end of file