
Commit

tests refactor
brash6 committed Nov 13, 2024
1 parent 3650467 commit 192ea20
Showing 6 changed files with 723 additions and 182 deletions.
2 changes: 1 addition & 1 deletion src/med_bench/estimation/base.py
@@ -156,7 +156,7 @@ def _fit_treatment_propensity_xm_nuisance(self, t, m, x):
         return self

     # TODO : Enable any sklearn object as classifier or regressor
-    def _fit_mediator_nuisance(self, t, m, x):
+    def _fit_mediator_nuisance(self, t, m, x, y):
         """Fits the nuisance parameter for the density f(M=m|T, X)"""
         # estimate mediator densities
         clf_param_grid = {}
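
The change above widens _fit_mediator_nuisance to accept the outcome y, giving the nuisance-fitting methods a uniform (t, m, x, y) signature so that callers such as GComputation.fit (next file) can invoke them interchangeably; the body shown does not appear to use the new argument. For orientation, a minimal sketch of what a mediator-density nuisance fit with that signature can look like; this is illustrative only, assumes a discrete mediator, and is not med_bench's actual implementation:

import numpy as np
from sklearn.linear_model import LogisticRegressionCV

def fit_mediator_nuisance(t, m, x, y):
    """Fit a classifier for the mediator density f(M=m|T, X); y is unused."""
    tx = np.hstack([t.reshape(-1, 1), x])  # features: treatment + covariates
    clf = LogisticRegressionCV(cv=5, random_state=42)
    clf.fit(tx, m.ravel())                 # predict_proba(tx) then estimates f(m|t, x)
    return clf
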
8 changes: 4 additions & 4 deletions src/med_bench/estimation/mediation_g_computation.py
@@ -13,9 +13,9 @@ def __init__(self, regressor, classifier, **kwargs):
         Parameters
         ----------
-        regressor
+        regressor
             Regressor used for mu estimation, can be any object with a fit and predict method
-        classifier
+        classifier
             Classifier used for propensity estimation, can be any object with a fit and predict_proba method
         """
         super().__init__(**kwargs)
@@ -37,7 +37,7 @@ def fit(self, t, m, x, y):

         if is_array_integer(m):
             self._fit_mediator_nuisance(t, m, x, y)
-            self._fit_conditional_mean_outcome_nuisance
+            self._fit_conditional_mean_outcome_nuisance(t, m, x, y)
         else:
             self._fit_cross_conditional_mean_outcome_nuisance(t, m, x, y)

@@ -48,7 +48,7 @@ def fit(self, t, m, x, y):

         return self

-    @fitted
+    @ fitted
     def estimate(self, t, m, x, y):
         """Estimates causal effect on data
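
The substantive fix here is in fit: the old line referenced self._fit_conditional_mean_outcome_nuisance without parentheses, and a bare bound-method reference is a silent no-op in Python, so the conditional-mean-outcome nuisance was never fitted for integer-coded mediators. The remaining changed pairs in this file differ only in whitespace, which is invisible above. A standalone illustration of the no-op pitfall, in generic Python unrelated to med_bench's classes:

class Demo:
    def setup(self):
        self.ready = True

d = Demo()
d.setup        # evaluates the bound method, runs nothing
assert not hasattr(d, "ready")
d.setup()      # the parentheses make the call happen
assert d.ready
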
135 changes: 83 additions & 52 deletions src/med_bench/get_estimation.py
@@ -17,7 +17,7 @@
 from med_bench.estimation.mediation_coefficient_product import CoefficientProduct
 from med_bench.estimation.mediation_dml import DoubleMachineLearning
 from med_bench.estimation.mediation_g_computation import GComputation
-from med_bench.estimation.mediation_ipw import ImportanceWeighting
+from med_bench.estimation.mediation_ipw import InversePropensityWeighting
 from med_bench.estimation.mediation_mr import MultiplyRobust
 from med_bench.nuisances.utils import _get_regularization_parameters
 from med_bench.utils.constants import CV_FOLDS
@@ -26,6 +26,7 @@
 from sklearn.linear_model import LogisticRegressionCV, RidgeCV
 from sklearn.calibration import CalibratedClassifierCV

+
 def transform_outputs(causal_effects):
     """Transforms outputs in the old format
@@ -42,6 +43,7 @@ def transform_outputs(causal_effects):
     indirect_control = causal_effects['indirect_effect_control']
     return [total, direct_treated, direct_control, indirect_treated, indirect_control, 0]

+
 def get_estimation(x, t, m, y, estimator, config):
     """Wrapper estimator fonction ; calls an estimator given mediation data
     in order to estimate total, direct, and indirect effects.
@@ -89,9 +91,12 @@ def get_estimation(x, t, m, y, estimator, config):
         effects = raw_res_R[0, :]
     elif estimator == "coefficient_product":
         effects = mediation_coefficient_product(y, t, m, x)
-        clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10)
-        reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42)
-        estimator = CoefficientProduct(regressor=reg, classifier=clf, regularize=True)
+        clf = RandomForestClassifier(
+            random_state=42, n_estimators=100, min_samples_leaf=10)
+        reg = RandomForestRegressor(
+            n_estimators=100, min_samples_leaf=10, random_state=42)
+        estimator = CoefficientProduct(
+            regressor=reg, classifier=clf, regularize=True)
         estimator.fit(t, m, x, y)
         causal_effects = estimator.estimate(t, m, x, y)
         effects = transform_outputs(causal_effects)
Expand All @@ -112,7 +117,7 @@ def get_estimation(x, t, m, y, estimator, config):
cs, alphas = _get_regularization_parameters(regularization=False)
clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS)
reg = RidgeCV(alphas=alphas, cv=CV_FOLDS)
estimator = ImportanceWeighting(
estimator = InversePropensityWeighting(
clip=1e-6, trim=0, regressor=reg, classifier=clf
)
estimator.fit(t, m, x, y)
@@ -147,7 +152,7 @@ def get_estimation(x, t, m, y, estimator, config):
         cs, alphas = _get_regularization_parameters(regularization=True)
         clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS)
         reg = RidgeCV(alphas=alphas, cv=CV_FOLDS)
-        estimator = ImportanceWeighting(
+        estimator = InversePropensityWeighting(
             clip=1e-6, trim=0, regressor=reg, classifier=clf
         )
         estimator.fit(t, m, x, y)
@@ -182,7 +187,7 @@ def get_estimation(x, t, m, y, estimator, config):
         cs, alphas = _get_regularization_parameters(regularization=True)
         clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS)
         reg = RidgeCV(alphas=alphas, cv=CV_FOLDS)
-        estimator = ImportanceWeighting(
+        estimator = InversePropensityWeighting(
             clip=1e-6, trim=0, regressor=reg, classifier=CalibratedClassifierCV(clf, method="sigmoid")
         )
         estimator.fit(t, m, x, y)
@@ -204,7 +209,7 @@ def get_estimation(x, t, m, y, estimator, config):
         cs, alphas = _get_regularization_parameters(regularization=True)
         clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS)
         reg = RidgeCV(alphas=alphas, cv=CV_FOLDS)
-        estimator = ImportanceWeighting(
+        estimator = InversePropensityWeighting(
             clip=1e-6, trim=0, regressor=reg, classifier=CalibratedClassifierCV(clf, method="isotonic")
         )
         estimator.fit(t, m, x, y)
@@ -250,9 +255,11 @@ def get_estimation(x, t, m, y, estimator, config):
             calibration=None,
         )
         cs, alphas = _get_regularization_parameters(regularization=True)
-        clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10)
-        reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42)
-        estimator = ImportanceWeighting(
+        clf = RandomForestClassifier(
+            random_state=42, n_estimators=100, min_samples_leaf=10)
+        reg = RandomForestRegressor(
+            n_estimators=100, min_samples_leaf=10, random_state=42)
+        estimator = InversePropensityWeighting(
             clip=1e-6, trim=0, regressor=reg, classifier=clf
         )
         estimator.fit(t, m, x, y)
@@ -285,9 +292,11 @@ def get_estimation(x, t, m, y, estimator, config):
             calibration=None,
         )
         cs, alphas = _get_regularization_parameters(regularization=True)
-        clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10)
-        reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42)
-        estimator = ImportanceWeighting(
+        clf = RandomForestClassifier(
+            random_state=42, n_estimators=100, min_samples_leaf=10)
+        reg = RandomForestRegressor(
+            n_estimators=100, min_samples_leaf=10, random_state=42)
+        estimator = InversePropensityWeighting(
             clip=1e-6, trim=0, regressor=reg, classifier=CalibratedClassifierCV(clf, method="sigmoid")
         )
         estimator.fit(t, m, x, y)
@@ -307,9 +316,11 @@ def get_estimation(x, t, m, y, estimator, config):
             calibration="isotonic",
         )
         cs, alphas = _get_regularization_parameters(regularization=True)
-        clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10)
-        reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42)
-        estimator = ImportanceWeighting(
+        clf = RandomForestClassifier(
+            random_state=42, n_estimators=100, min_samples_leaf=10)
+        reg = RandomForestRegressor(
+            n_estimators=100, min_samples_leaf=10, random_state=42)
+        estimator = InversePropensityWeighting(
             clip=1e-6, trim=0, regressor=reg, classifier=CalibratedClassifierCV(clf, method="isotonic")
         )
         estimator.fit(t, m, x, y)
@@ -423,7 +434,8 @@ def get_estimation(x, t, m, y, estimator, config):
         cs, alphas = _get_regularization_parameters(regularization=True)
         clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS)
         reg = RidgeCV(alphas=alphas, cv=CV_FOLDS)
-        estimator = GComputation(regressor=reg, classifier=CalibratedClassifierCV(clf, method="sigmoid"))
+        estimator = GComputation(
+            regressor=reg, classifier=CalibratedClassifierCV(clf, method="sigmoid"))
         estimator.fit(t, m, x, y)
         causal_effects = estimator.estimate(t, m, x, y)
         effects = transform_outputs(causal_effects)
@@ -443,7 +455,8 @@ def get_estimation(x, t, m, y, estimator, config):
         cs, alphas = _get_regularization_parameters(regularization=True)
         clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS)
         reg = RidgeCV(alphas=alphas, cv=CV_FOLDS)
-        estimator = GComputation(regressor=reg, classifier=CalibratedClassifierCV(clf, method="isotonic"))
+        estimator = GComputation(
+            regressor=reg, classifier=CalibratedClassifierCV(clf, method="isotonic"))
         estimator.fit(t, m, x, y)
         causal_effects = estimator.estimate(t, m, x, y)
         effects = transform_outputs(causal_effects)
@@ -487,8 +500,10 @@ def get_estimation(x, t, m, y, estimator, config):
             calibration=None,
         )
         cs, alphas = _get_regularization_parameters(regularization=True)
-        clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10)
-        reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42)
+        clf = RandomForestClassifier(
+            random_state=42, n_estimators=100, min_samples_leaf=10)
+        reg = RandomForestRegressor(
+            n_estimators=100, min_samples_leaf=10, random_state=42)
         estimator = GComputation(regressor=reg, classifier=clf)
         estimator.fit(t, m, x, y)
         causal_effects = estimator.estimate(t, m, x, y)
@@ -520,9 +535,12 @@ def get_estimation(x, t, m, y, estimator, config):
             calibration='sigmoid',
         )
         cs, alphas = _get_regularization_parameters(regularization=True)
-        clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10)
-        reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42)
-        estimator = GComputation(regressor=reg, classifier=CalibratedClassifierCV(clf, method="sigmoid"))
+        clf = RandomForestClassifier(
+            random_state=42, n_estimators=100, min_samples_leaf=10)
+        reg = RandomForestRegressor(
+            n_estimators=100, min_samples_leaf=10, random_state=42)
+        estimator = GComputation(
+            regressor=reg, classifier=CalibratedClassifierCV(clf, method="sigmoid"))
         estimator.fit(t, m, x, y)
         causal_effects = estimator.estimate(t, m, x, y)
         effects = transform_outputs(causal_effects)
@@ -541,9 +559,12 @@ def get_estimation(x, t, m, y, estimator, config):
             calibration="isotonic",
         )
         cs, alphas = _get_regularization_parameters(regularization=True)
-        clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10)
-        reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42)
-        estimator = GComputation(regressor=reg, classifier=CalibratedClassifierCV(clf, method="isotonic"))
+        clf = RandomForestClassifier(
+            random_state=42, n_estimators=100, min_samples_leaf=10)
+        reg = RandomForestRegressor(
+            n_estimators=100, min_samples_leaf=10, random_state=42)
+        estimator = GComputation(
+            regressor=reg, classifier=CalibratedClassifierCV(clf, method="isotonic"))
         estimator.fit(t, m, x, y)
         causal_effects = estimator.estimate(t, m, x, y)
         effects = transform_outputs(causal_effects)
@@ -592,8 +613,8 @@ def get_estimation(x, t, m, y, estimator, config):
         clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS)
         reg = RidgeCV(alphas=alphas, cv=CV_FOLDS)
         estimator = MultiplyRobust(
-            clip=1e-6, ratio="propensities", normalized=True, regressor=reg,
-            classifier=clf)
+            clip=1e-6, ratio="propensities", normalized=True, regressor=reg,
+            classifier=clf)
         estimator.fit(t, m, x, y)
         causal_effects = estimator.estimate(t, m, x, y)
         effects = transform_outputs(causal_effects)
@@ -630,8 +651,8 @@ def get_estimation(x, t, m, y, estimator, config):
         clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS)
         reg = RidgeCV(alphas=alphas, cv=CV_FOLDS)
         estimator = MultiplyRobust(
-            clip=1e-6, ratio="propensities", normalized=True, regressor=reg,
-            classifier=clf)
+            clip=1e-6, ratio="propensities", normalized=True, regressor=reg,
+            classifier=clf)
         estimator.fit(t, m, x, y)
         causal_effects = estimator.estimate(t, m, x, y)
         effects = transform_outputs(causal_effects)
@@ -667,8 +688,8 @@ def get_estimation(x, t, m, y, estimator, config):
         clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS)
         reg = RidgeCV(alphas=alphas, cv=CV_FOLDS)
         estimator = MultiplyRobust(
-            clip=1e-6, ratio="propensities", normalized=True, regressor=reg,
-            classifier=CalibratedClassifierCV(clf, method="sigmoid"))
+            clip=1e-6, ratio="propensities", normalized=True, regressor=reg,
+            classifier=CalibratedClassifierCV(clf, method="sigmoid"))
         estimator.fit(t, m, x, y)
         causal_effects = estimator.estimate(t, m, x, y)
         effects = transform_outputs(causal_effects)
@@ -690,8 +711,8 @@ def get_estimation(x, t, m, y, estimator, config):
         clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS)
         reg = RidgeCV(alphas=alphas, cv=CV_FOLDS)
         estimator = MultiplyRobust(
-            clip=1e-6, ratio="propensities", normalized=True, regressor=reg,
-            classifier=CalibratedClassifierCV(clf, method="isotonic"))
+            clip=1e-6, ratio="propensities", normalized=True, regressor=reg,
+            classifier=CalibratedClassifierCV(clf, method="isotonic"))
         estimator.fit(t, m, x, y)
         causal_effects = estimator.estimate(t, m, x, y)
         effects = transform_outputs(causal_effects)
@@ -738,11 +759,13 @@ def get_estimation(x, t, m, y, estimator, config):
             calibration=None,
         )
         cs, alphas = _get_regularization_parameters(regularization=False)
-        clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10)
-        reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42)
+        clf = RandomForestClassifier(
+            random_state=42, n_estimators=100, min_samples_leaf=10)
+        reg = RandomForestRegressor(
+            n_estimators=100, min_samples_leaf=10, random_state=42)
         estimator = MultiplyRobust(
-            clip=1e-6, ratio="propensities", normalized=True, regressor=reg,
-            classifier=clf)
+            clip=1e-6, ratio="propensities", normalized=True, regressor=reg,
+            classifier=clf)
         estimator.fit(t, m, x, y)
         causal_effects = estimator.estimate(t, m, x, y)
         effects = transform_outputs(causal_effects)
@@ -775,11 +798,13 @@ def get_estimation(x, t, m, y, estimator, config):
             calibration='sigmoid',
         )
         cs, alphas = _get_regularization_parameters(regularization=False)
-        clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10)
-        reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42)
+        clf = RandomForestClassifier(
+            random_state=42, n_estimators=100, min_samples_leaf=10)
+        reg = RandomForestRegressor(
+            n_estimators=100, min_samples_leaf=10, random_state=42)
         estimator = MultiplyRobust(
-            clip=1e-6, ratio="propensities", normalized=True, regressor=reg,
-            classifier=CalibratedClassifierCV(clf, method="sigmoid"))
+            clip=1e-6, ratio="propensities", normalized=True, regressor=reg,
+            classifier=CalibratedClassifierCV(clf, method="sigmoid"))
         estimator.fit(t, m, x, y)
         causal_effects = estimator.estimate(t, m, x, y)
         effects = transform_outputs(causal_effects)
@@ -798,11 +823,13 @@ def get_estimation(x, t, m, y, estimator, config):
             calibration="isotonic",
         )
         cs, alphas = _get_regularization_parameters(regularization=False)
-        clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10)
-        reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42)
+        clf = RandomForestClassifier(
+            random_state=42, n_estimators=100, min_samples_leaf=10)
+        reg = RandomForestRegressor(
+            n_estimators=100, min_samples_leaf=10, random_state=42)
         estimator = MultiplyRobust(
-            clip=1e-6, ratio="propensities", normalized=True, regressor=reg,
-            classifier=CalibratedClassifierCV(clf, method="isotonic"))
+            clip=1e-6, ratio="propensities", normalized=True, regressor=reg,
+            classifier=CalibratedClassifierCV(clf, method="isotonic"))
         estimator.fit(t, m, x, y)
         causal_effects = estimator.estimate(t, m, x, y)
         effects = transform_outputs(causal_effects)
@@ -859,7 +886,7 @@ def get_estimation(x, t, m, y, estimator, config):
         estimator.fit(t, m, x, y)
         causal_effects = estimator.estimate(t, m, x, y)
         effects = transform_outputs(causal_effects)
-
+
     elif estimator == "mediation_dml_reg":
         effects = mediation_dml(
             y, t, m, x, trim=0, clip=1e-6, calibration=None)
@@ -913,8 +940,10 @@ def get_estimation(x, t, m, y, estimator, config):
             calibration=None,
             forest=True)
         cs, alphas = _get_regularization_parameters(regularization=True)
-        clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10)
-        reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42)
+        clf = RandomForestClassifier(
+            random_state=42, n_estimators=100, min_samples_leaf=10)
+        reg = RandomForestRegressor(
+            n_estimators=100, min_samples_leaf=10, random_state=42)
         estimator = DoubleMachineLearning(
             clip=1e-6, trim=0, normalized=True, regressor=reg, classifier=clf
         )
@@ -933,8 +962,10 @@ def get_estimation(x, t, m, y, estimator, config):
             calibration='sigmoid',
             forest=True)
         cs, alphas = _get_regularization_parameters(regularization=True)
-        clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10)
-        reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42)
+        clf = RandomForestClassifier(
+            random_state=42, n_estimators=100, min_samples_leaf=10)
+        reg = RandomForestRegressor(
+            n_estimators=100, min_samples_leaf=10, random_state=42)
         estimator = DoubleMachineLearning(
             clip=1e-6, trim=0, normalized=True, regressor=reg, classifier=CalibratedClassifierCV(clf, method="sigmoid")
         )
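
Beyond the autopep8-style line rewrapping, the systematic change in this file is the rename ImportanceWeighting -> InversePropensityWeighting at the import and at every construction site; each estimator branch keeps the same fit/estimate pattern. A usage sketch of that pattern on synthetic data follows; the data generation is an assumption for illustration, while the constructor arguments and the fit/estimate calls follow the diff above:

import numpy as np
from sklearn.linear_model import LogisticRegressionCV, RidgeCV
from med_bench.estimation.mediation_ipw import InversePropensityWeighting
from med_bench.get_estimation import transform_outputs

rng = np.random.default_rng(42)
n = 500
x = rng.normal(size=(n, 5))           # covariates
t = rng.binomial(1, 0.5, size=n)      # binary treatment
m = rng.binomial(1, 0.5, size=n)      # binary mediator
y = rng.normal(size=n)                # continuous outcome

clf = LogisticRegressionCV(random_state=42, cv=5)
reg = RidgeCV(alphas=[0.1, 1.0, 10.0], cv=5)
estimator = InversePropensityWeighting(
    clip=1e-6, trim=0, regressor=reg, classifier=clf)
estimator.fit(t, m, x, y)
causal_effects = estimator.estimate(t, m, x, y)
effects = transform_outputs(causal_effects)  # old six-element list format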
(diffs for the remaining three changed files did not load)
