diff --git a/src/.DS_Store b/src/.DS_Store new file mode 100644 index 0000000..6d2aa7c Binary files /dev/null and b/src/.DS_Store differ diff --git a/src/med_bench/get_estimation.py b/src/med_bench/get_estimation.py index dbb7b99..a909b95 100644 --- a/src/med_bench/get_estimation.py +++ b/src/med_bench/get_estimation.py @@ -7,8 +7,16 @@ from rpy2.rinterface_lib.embedded import RRuntimeError import pandas as pd import numpy as np -from .mediation import mediation_IPW, mediation_coefficient_products, mediation_g_formula,mediation_multiply_robust, r_g_estimator, r_medDML, r_mediate, mediation_DML - +from .mediation import ( + mediation_IPW, + mediation_coefficient_products, + mediation_g_computation, + mediation_multiply_robust, + mediation_DML, + r_mediation_g_computation, + r_mediation_DML, + r_mediate, +) def get_estimation(x, t, m, y, estimator, config): """Wrapper estimator fonction ; calls an estimator given mediation data @@ -285,9 +293,9 @@ def get_estimation(x, t, m, y, estimator, config): calibration=True, calib_method="isotonic", ) - elif estimator == "g_computation_noreg": + elif estimator == "mediation_g_computation_noreg": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -298,9 +306,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=False, calibration=False, ) - elif estimator == "g_computation_noreg_cf": + elif estimator == "mediation_g_computation_noreg_cf": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -311,9 +319,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=False, calibration=False, ) - elif estimator == "g_computation_reg": + elif estimator == "mediation_g_computation_reg": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -324,9 +332,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=False, ) - elif estimator == "g_computation_reg_cf": + elif estimator == "mediation_g_computation_reg_cf": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -337,9 +345,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=False, ) - elif estimator == "g_computation_reg_calibration": + elif estimator == "mediation_g_computation_reg_calibration": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -350,9 +358,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=True, ) - elif estimator == "g_computation_reg_calibration_iso": + elif estimator == "mediation_g_computation_reg_calibration_iso": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -364,9 +372,9 @@ def get_estimation(x, t, m, y, estimator, config): calibration=True, calib_method="isotonic", ) - elif estimator == "g_computation_reg_calibration_cf": + elif estimator == "mediation_g_computation_reg_calibration_cf": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -377,9 +385,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=True, ) - elif estimator == "g_computation_reg_calibration_iso_cf": + elif estimator == "mediation_g_computation_reg_calibration_iso_cf": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -391,9 +399,9 @@ def get_estimation(x, t, m, y, estimator, config): calibration=True, calib_method="isotonic", ) - elif estimator == "g_computation_forest": + elif estimator == "mediation_g_computation_forest": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -404,9 +412,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=False, ) - elif estimator == "g_computation_forest_cf": + elif estimator == "mediation_g_computation_forest_cf": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -417,9 +425,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=False, ) - elif estimator == "g_computation_forest_calibration": + elif estimator == "mediation_g_computation_forest_calibration": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -430,9 +438,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=True, ) - elif estimator == "g_computation_forest_calibration_iso": + elif estimator == "mediation_g_computation_forest_calibration_iso": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -444,9 +452,9 @@ def get_estimation(x, t, m, y, estimator, config): calibration=True, calib_method="isotonic", ) - elif estimator == "g_computation_forest_calibration_cf": + elif estimator == "mediation_g_computation_forest_calibration_cf": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -457,9 +465,9 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=True, ) - elif estimator == "g_computation_forest_calibration_iso_cf": + elif estimator == "mediation_g_computation_forest_calibration_iso_cf": if config in (0, 1, 2): - effects = mediation_g_formula( + effects = mediation_g_computation( y, t, m, @@ -471,7 +479,7 @@ def get_estimation(x, t, m, y, estimator, config): calibration=True, calib_method="isotonic", ) - elif estimator == "multiply_robust_noreg": + elif estimator == "mediation_multiply_robust_noreg": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -485,7 +493,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=False, calibration=False, ) - elif estimator == "multiply_robust_noreg_cf": + elif estimator == "mediation_multiply_robust_noreg_cf": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -499,7 +507,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=False, calibration=False, ) - elif estimator == "multiply_robust_reg": + elif estimator == "mediation_multiply_robust_reg": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -513,7 +521,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=False, ) - elif estimator == "multiply_robust_reg_cf": + elif estimator == "mediation_multiply_robust_reg_cf": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -527,7 +535,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=False, ) - elif estimator == "multiply_robust_reg_calibration": + elif estimator == "mediation_multiply_robust_reg_calibration": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -541,7 +549,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=True, ) - elif estimator == "multiply_robust_reg_calibration_iso": + elif estimator == "mediation_multiply_robust_reg_calibration_iso": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -556,7 +564,7 @@ def get_estimation(x, t, m, y, estimator, config): calibration=True, calib_method="isotonic", ) - elif estimator == "multiply_robust_reg_calibration_cf": + elif estimator == "mediation_multiply_robust_reg_calibration_cf": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -570,7 +578,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=True, ) - elif estimator == "multiply_robust_reg_calibration_iso_cf": + elif estimator == "mediation_multiply_robust_reg_calibration_iso_cf": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -585,7 +593,7 @@ def get_estimation(x, t, m, y, estimator, config): calibration=True, calib_method="isotonic", ) - elif estimator == "multiply_robust_forest": + elif estimator == "mediation_multiply_robust_forest": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -599,7 +607,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=False, ) - elif estimator == "multiply_robust_forest_cf": + elif estimator == "mediation_multiply_robust_forest_cf": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -613,7 +621,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=False, ) - elif estimator == "multiply_robust_forest_calibration": + elif estimator == "mediation_multiply_robust_forest_calibration": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -627,7 +635,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=True, ) - elif estimator == "multiply_robust_forest_calibration_iso": + elif estimator == "mediation_multiply_robust_forest_calibration_iso": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -642,7 +650,7 @@ def get_estimation(x, t, m, y, estimator, config): calibration=True, calib_method="isotonic", ) - elif estimator == "multiply_robust_forest_calibration_cf": + elif estimator == "mediation_multiply_robust_forest_calibration_cf": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -656,7 +664,7 @@ def get_estimation(x, t, m, y, estimator, config): regularization=True, calibration=True, ) - elif estimator == "multiply_robust_forest_calibration_iso_cf": + elif estimator == "mediation_multiply_robust_forest_calibration_iso_cf": if config in (0, 1, 2): effects = mediation_multiply_robust( y, @@ -674,22 +682,22 @@ def get_estimation(x, t, m, y, estimator, config): elif estimator == "simulation_based": if config in (0, 1, 2): effects = r_mediate(y, t, m, x, interaction=False) - elif estimator == "DML_mediation": + elif estimator == "mediation_DML": if config > 0: - effects = r_medDML(y, t, m, x, trim=0.0, order=1) - elif estimator == "med_dml_noreg": + effects = r_mediation_DML(y, t, m, x, trim=0.0, order=1) + elif estimator == "mediation_DML_noreg": effects = mediation_DML(x, t, m, y, trim=0, regularization=False) - elif estimator == "med_dml_reg": + elif estimator == "mediation_DML_reg": effects = mediation_DML(x, t, m, y, trim=0) - elif estimator == "med_dml_reg_fixed_seed": + elif estimator == "mediation_DML_reg_fixed_seed": effects = mediation_DML(x, t, m, y, trim=0, random_state=321) - elif estimator == "med_dml_noreg_cf": + elif estimator == "mediation_DML_noreg_cf": effects = mediation_DML(x, t, m, y, trim=0, crossfit=4, regularization=False) - elif estimator == "med_dml_reg_cf": + elif estimator == "mediation_DML_reg_cf": effects = mediation_DML(x, t, m, y, trim=0, crossfit=4) - elif estimator == "G_estimator": + elif estimator == "mediation_g_computation": if config in (0, 1, 2): - effects = r_g_estimator(y, t, m, x) + effects = r_mediation_g_computation(y, t, m, x) else: raise ValueError("Unrecognized estimator label.") if effects is None: diff --git a/src/med_bench/mediation.py b/src/med_bench/mediation.py index 72d111e..2edd551 100644 --- a/src/med_bench/mediation.py +++ b/src/med_bench/mediation.py @@ -323,7 +323,7 @@ def mediation_coefficient_products(y, t, m, x, interaction=False, regularization None] -def mediation_g_formula(y, t, m, x, interaction=False, forest=False, +def mediation_g_computation(y, t, m, x, interaction=False, forest=False, crossfit=0, calibration=True, regularization=True, calib_method='sigmoid'): """ @@ -889,7 +889,7 @@ def r_mediate(y, t, m, x, interaction=False): return to_return + [None] -def r_g_estimator(y, t, m, x): +def r_mediation_g_computation(y, t, m, x): m = m.ravel() var_names = [[y, 'y'], [t, 't'], @@ -926,7 +926,7 @@ def r_g_estimator(y, t, m, x): indirect_effect, None] -def r_medDML(y, t, m, x, trim=0.05, order=1): +def r_mediation_DML(y, t, m, x, trim=0.05, order=1): """ y array-like, shape (n_samples) outcome value for each unit, continuous diff --git a/src/tests/.DS_Store b/src/tests/.DS_Store new file mode 100644 index 0000000..70d7305 Binary files /dev/null and b/src/tests/.DS_Store differ diff --git a/src/test_get_estimation.py b/src/tests/estimation/test_get_estimation.py similarity index 80% rename from src/test_get_estimation.py rename to src/tests/estimation/test_get_estimation.py index 76c57eb..f95a4c3 100644 --- a/src/test_get_estimation.py +++ b/src/tests/estimation/test_get_estimation.py @@ -81,35 +81,35 @@ "mediation_ipw_reg_calibration": INFINITE_TOLERANCE, "mediation_ipw_forest": INFINITE_TOLERANCE, "mediation_ipw_forest_calibration": INFINITE_TOLERANCE, - "g_computation_noreg": LARGE_TOLERANCE, - "g_computation_reg": MEDIUM_TOLERANCE, - "g_computation_reg_calibration": LARGE_TOLERANCE, - "g_computation_forest": LARGE_TOLERANCE, - "g_computation_forest_calibration": INFINITE_TOLERANCE, - "multiply_robust_noreg": INFINITE_TOLERANCE, - "multiply_robust_reg": LARGE_TOLERANCE, - "multiply_robust_reg_calibration": LARGE_TOLERANCE, - "multiply_robust_forest": INFINITE_TOLERANCE, - "multiply_robust_forest_calibration": LARGE_TOLERANCE, + "mediation_g_computation_noreg": LARGE_TOLERANCE, + "mediation_g_computation_reg": MEDIUM_TOLERANCE, + "mediation_g_computation_reg_calibration": LARGE_TOLERANCE, + "mediation_g_computation_forest": LARGE_TOLERANCE, + "mediation_g_computation_forest_calibration": INFINITE_TOLERANCE, + "mediation_multiply_robust_noreg": INFINITE_TOLERANCE, + "mediation_multiply_robust_reg": LARGE_TOLERANCE, + "mediation_multiply_robust_reg_calibration": LARGE_TOLERANCE, + "mediation_multiply_robust_forest": INFINITE_TOLERANCE, + "mediation_multiply_robust_forest_calibration": LARGE_TOLERANCE, "simulation_based": LARGE_TOLERANCE, - "DML_mediation": INFINITE_TOLERANCE, - "med_dml_reg_fixed_seed": INFINITE_TOLERANCE, - "G_estimator": SMALL_TOLERANCE, + "mediation_DML": INFINITE_TOLERANCE, + "mediation_DML_reg_fixed_seed": INFINITE_TOLERANCE, + "mediation_g_computation": SMALL_TOLERANCE, "mediation_ipw_noreg_cf": INFINITE_TOLERANCE, "mediation_ipw_reg_cf": INFINITE_TOLERANCE, "mediation_ipw_reg_calibration_cf": INFINITE_TOLERANCE, "mediation_ipw_forest_cf": INFINITE_TOLERANCE, "mediation_ipw_forest_calibration_cf": INFINITE_TOLERANCE, - "g_computation_noreg_cf": SMALL_TOLERANCE, - "g_computation_reg_cf": LARGE_TOLERANCE, - "g_computation_reg_calibration_cf": LARGE_TOLERANCE, - "g_computation_forest_cf": INFINITE_TOLERANCE, - "g_computation_forest_calibration_cf": LARGE_TOLERANCE, - "multiply_robust_noreg_cf": MEDIUM_TOLERANCE, - "multiply_robust_reg_cf": LARGE_TOLERANCE, - "multiply_robust_reg_calibration_cf": MEDIUM_TOLERANCE, - "multiply_robust_forest_cf": INFINITE_TOLERANCE, - "multiply_robust_forest_calibration_cf": INFINITE_TOLERANCE, + "mediation_g_computation_noreg_cf": SMALL_TOLERANCE, + "mediation_g_computation_reg_cf": LARGE_TOLERANCE, + "mediation_g_computation_reg_calibration_cf": LARGE_TOLERANCE, + "mediation_g_computation_forest_cf": INFINITE_TOLERANCE, + "mediation_g_computation_forest_calibration_cf": LARGE_TOLERANCE, + "mediation_multiply_robust_noreg_cf": MEDIUM_TOLERANCE, + "mediation_multiply_robust_reg_cf": LARGE_TOLERANCE, + "mediation_multiply_robust_reg_calibration_cf": MEDIUM_TOLERANCE, + "mediation_multiply_robust_forest_cf": INFINITE_TOLERANCE, + "mediation_multiply_robust_forest_calibration_cf": INFINITE_TOLERANCE, } diff --git a/src/test_get_simulated_data.py b/src/tests/simulate_data/test_get_simulated_data.py similarity index 99% rename from src/test_get_simulated_data.py rename to src/tests/simulate_data/test_get_simulated_data.py index 25ee8ba..e13f4d8 100644 --- a/src/test_get_simulated_data.py +++ b/src/tests/simulate_data/test_get_simulated_data.py @@ -340,7 +340,7 @@ def test_null_sigma_m_makes_nan(): seed=1, type_m="continuous", sigma_y=0.5, - sigma_m=0.5, + sigma_m=0., beta_t_factor=1, beta_m_factor=1, ) diff --git a/src/treatment_effect.py b/src/treatment_effect.py index e69de29..2d802f7 100644 --- a/src/treatment_effect.py +++ b/src/treatment_effect.py @@ -0,0 +1,131 @@ +""" +the objective of this script is to implement estimators without mediation in +causal inference, simulate data, and evaluate and compare estimators +""" + +# using rpy2 to have the same data in R and python... + +import rpy2.robjects as robjects +import rpy2.robjects.packages as rpackages +from rpy2.robjects import pandas2ri, numpy2ri +from sklearn.linear_model import LogisticRegressionCV, RidgeCV, LassoCV +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor +from sklearn.preprocessing import PolynomialFeatures +from sklearn.calibration import CalibratedClassifierCV +import numpy as np +from numpy.random import default_rng +from scipy import stats +import pandas as pd +from pathlib import Path +from scipy.stats import bernoulli +from scipy.special import expit + +from itertools import combinations +from sklearn.model_selection import KFold + +pandas2ri.activate() +numpy2ri.activate() + +causalweight = rpackages.importr('causalweight') +mediation = rpackages.importr('mediation') +Rstats = rpackages.importr('stats') +base = rpackages.importr('base') +grf = rpackages.importr('grf') +plmed = rpackages.importr('plmed') + +ALPHAS = np.logspace(-5, 5, 8) +CV_FOLDS = 5 +TINY = 1.e-12 + + +def plain_IPW(y, t, x, trim=0.01, regularization=True): + """ + plain IPW estimator without mediation + """ + if regularization: + cs = ALPHAS + else: + cs = [np.inf] + if len(x.shape) == 1: + x = x.reshape(-1, 1) + p_x_clf = LogisticRegressionCV(Cs=cs, cv=CV_FOLDS).fit(x, t) + p_x = p_x_clf.predict_proba(x)[:, 1] + # trimming + p_x[p_x < trim] = trim + p_x[p_x > 1 - trim] = 1 - trim + y1m1 = np.sum(y * t / p_x) / np.sum(t / p_x) + y0m0 = np.sum(y * (1 - t) / (1 - p_x)) /\ + np.sum((1 - t) / (1 - p_x)) + return y1m1 - y0m0 + + +def AIPW(y, t, m, x, clip=0.01, forest=False, crossfit=0, forest_r=False, + regularization=True): + """ + AIPW estimator + """ + if regularization: + alphas = ALPHAS + cs = ALPHAS + else: + alphas = [TINY] + cs = [np.inf] + n = len(y) + if len(x.shape) == 1: + x = x.reshape(-1, 1) + if crossfit < 2: + train_test_list = [[np.arange(n), np.arange(n)]] + else: + kf = KFold(n_splits=crossfit) + train_test_list = list() + for train_index, test_index in kf.split(x): + train_test_list.append([train_index, test_index]) + + if not forest_r: + mu_1x, mu_0x, e_x = [np.zeros(n) for h in range(3)] + + for train_index, test_index in train_test_list: + treated_train_index = np.array( + list(set(train_index).intersection(np.where(t == 1)[0]))) + control_train_index = np.array( + list(set(train_index).intersection(np.where(t == 0)[0]))) + if not forest: + y_reg_treated = RidgeCV(alphas=alphas, cv=CV_FOLDS)\ + .fit(x[treated_train_index, :], y[treated_train_index]) + y_reg_control = RidgeCV(alphas=alphas, cv=CV_FOLDS)\ + .fit(x[control_train_index, :], y[control_train_index]) + t_prob = LogisticRegressionCV(Cs=cs, cv=CV_FOLDS)\ + .fit(x[train_index, :], t[train_index]) + else: + y_reg_treated = RandomForestRegressor(max_depth=3, + min_samples_leaf=5)\ + .fit(x[treated_train_index, :], y[treated_train_index]) + y_reg_control = RandomForestRegressor(max_depth=3, + min_samples_leaf=5)\ + .fit(x[control_train_index, :], y[control_train_index]) + t_prob = CalibratedClassifierCV( + RandomForestClassifier(max_depth=3, min_samples_leaf=5))\ + .fit(x[train_index, :], t[train_index]) + mu_1x[test_index] = y_reg_treated.predict(x[test_index, :]) + mu_0x[test_index] = y_reg_control.predict(x[test_index, :]) + e_x[test_index] = t_prob.predict_proba(x[test_index, :])[:, 1] + e_x = np.clip(e_x, clip, 1 - clip) + total_effect = np.mean(mu_1x - mu_0x + t * (y - mu_1x) / e_x - + (1 - t) * (y - mu_0x) / (1 - e_x)) + else: + x_r, t_r, y_r = [_convert_array_to_R(uu) for uu in (x, t, y)] + cf = grf.causal_forest(x_r, y_r, t_r, num_trees=500) + total_effect = grf.average_treatment_effect(cf)[0] + return [total_effect] + [None] * 5 + + +def bart(y, t, m, x, tmle=False): + x_r, t_r, y_r = [_convert_array_to_R(uu) for uu in (x, t, y)] + if tmle: + bart_model = brc.bartc(y_r, t_r, x_r, method_rsp='tmle', + estimand='ate') + else: + bart_model = brc.bartc(y_r, t_r, x_r, method_rsp='p.weight', + estimand='ate') + ate = brc.summary_bartcFit(bart_model).rx2('estimates')[0][0] + return [ate] + [None] * 5 \ No newline at end of file