From 192ea20e89f87a9dbbcb7e5f4cd625cfeb1df784 Mon Sep 17 00:00:00 2001 From: brash6 Date: Wed, 13 Nov 2024 23:14:27 +0100 Subject: [PATCH] tests refactor --- src/med_bench/estimation/base.py | 2 +- .../estimation/mediation_g_computation.py | 8 +- src/med_bench/get_estimation.py | 135 +++-- src/med_bench/get_estimation_results.py | 529 ++++++++++++++++++ src/tests/estimation/test_exact_estimation.py | 109 ---- src/tests/estimation/test_get_estimation.py | 122 +++- 6 files changed, 723 insertions(+), 182 deletions(-) create mode 100644 src/med_bench/get_estimation_results.py delete mode 100644 src/tests/estimation/test_exact_estimation.py diff --git a/src/med_bench/estimation/base.py b/src/med_bench/estimation/base.py index cdd8b73..6bfa501 100644 --- a/src/med_bench/estimation/base.py +++ b/src/med_bench/estimation/base.py @@ -156,7 +156,7 @@ def _fit_treatment_propensity_xm_nuisance(self, t, m, x): return self # TODO : Enable any sklearn object as classifier or regressor - def _fit_mediator_nuisance(self, t, m, x): + def _fit_mediator_nuisance(self, t, m, x, y): """Fits the nuisance parameter for the density f(M=m|T, X)""" # estimate mediator densities clf_param_grid = {} diff --git a/src/med_bench/estimation/mediation_g_computation.py b/src/med_bench/estimation/mediation_g_computation.py index b531640..a813b41 100644 --- a/src/med_bench/estimation/mediation_g_computation.py +++ b/src/med_bench/estimation/mediation_g_computation.py @@ -13,9 +13,9 @@ def __init__(self, regressor, classifier, **kwargs): Parameters ---------- - regressor + regressor Regressor used for mu estimation, can be any object with a fit and predict method - classifier + classifier Classifier used for propensity estimation, can be any object with a fit and predict_proba method """ super().__init__(**kwargs) @@ -37,7 +37,7 @@ def fit(self, t, m, x, y): if is_array_integer(m): self._fit_mediator_nuisance(t, m, x, y) - self._fit_conditional_mean_outcome_nuisance + 
self._fit_conditional_mean_outcome_nuisance(t, m, x, y) else: self._fit_cross_conditional_mean_outcome_nuisance(t, m, x, y) @@ -48,7 +48,7 @@ def fit(self, t, m, x, y): return self - @fitted + @ fitted def estimate(self, t, m, x, y): """Estimates causal effect on data diff --git a/src/med_bench/get_estimation.py b/src/med_bench/get_estimation.py index 8245c81..c58e2a2 100644 --- a/src/med_bench/get_estimation.py +++ b/src/med_bench/get_estimation.py @@ -17,7 +17,7 @@ from med_bench.estimation.mediation_coefficient_product import CoefficientProduct from med_bench.estimation.mediation_dml import DoubleMachineLearning from med_bench.estimation.mediation_g_computation import GComputation -from med_bench.estimation.mediation_ipw import ImportanceWeighting +from med_bench.estimation.mediation_ipw import InversePropensityWeighting from med_bench.estimation.mediation_mr import MultiplyRobust from med_bench.nuisances.utils import _get_regularization_parameters from med_bench.utils.constants import CV_FOLDS @@ -26,6 +26,7 @@ from sklearn.linear_model import LogisticRegressionCV, RidgeCV from sklearn.calibration import CalibratedClassifierCV + def transform_outputs(causal_effects): """Transforms outputs in the old format @@ -42,6 +43,7 @@ def transform_outputs(causal_effects): indirect_control = causal_effects['indirect_effect_control'] return [total, direct_treated, direct_control, indirect_treated, indirect_control, 0] + def get_estimation(x, t, m, y, estimator, config): """Wrapper estimator fonction ; calls an estimator given mediation data in order to estimate total, direct, and indirect effects. 
@@ -89,9 +91,12 @@ def get_estimation(x, t, m, y, estimator, config): effects = raw_res_R[0, :] elif estimator == "coefficient_product": effects = mediation_coefficient_product(y, t, m, x) - clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10) - reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42) - estimator = CoefficientProduct(regressor=reg, classifier=clf, regularize=True) + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) + estimator = CoefficientProduct( + regressor=reg, classifier=clf, regularize=True) estimator.fit(t, m, x, y) causal_effects = estimator.estimate(t, m, x, y) effects = transform_outputs(causal_effects) @@ -112,7 +117,7 @@ def get_estimation(x, t, m, y, estimator, config): cs, alphas = _get_regularization_parameters(regularization=False) clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS) reg = RidgeCV(alphas=alphas, cv=CV_FOLDS) - estimator = ImportanceWeighting( + estimator = InversePropensityWeighting( clip=1e-6, trim=0, regressor=reg, classifier=clf ) estimator.fit(t, m, x, y) @@ -147,7 +152,7 @@ def get_estimation(x, t, m, y, estimator, config): cs, alphas = _get_regularization_parameters(regularization=True) clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS) reg = RidgeCV(alphas=alphas, cv=CV_FOLDS) - estimator = ImportanceWeighting( + estimator = InversePropensityWeighting( clip=1e-6, trim=0, regressor=reg, classifier=clf ) estimator.fit(t, m, x, y) @@ -182,7 +187,7 @@ def get_estimation(x, t, m, y, estimator, config): cs, alphas = _get_regularization_parameters(regularization=True) clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS) reg = RidgeCV(alphas=alphas, cv=CV_FOLDS) - estimator = ImportanceWeighting( + estimator = InversePropensityWeighting( clip=1e-6, trim=0, regressor=reg, 
classifier=CalibratedClassifierCV(clf, method="sigmoid") ) estimator.fit(t, m, x, y) @@ -204,7 +209,7 @@ def get_estimation(x, t, m, y, estimator, config): cs, alphas = _get_regularization_parameters(regularization=True) clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS) reg = RidgeCV(alphas=alphas, cv=CV_FOLDS) - estimator = ImportanceWeighting( + estimator = InversePropensityWeighting( clip=1e-6, trim=0, regressor=reg, classifier=CalibratedClassifierCV(clf, method="isotonic") ) estimator.fit(t, m, x, y) @@ -250,9 +255,11 @@ def get_estimation(x, t, m, y, estimator, config): calibration=None, ) cs, alphas = _get_regularization_parameters(regularization=True) - clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10) - reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42) - estimator = ImportanceWeighting( + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) + estimator = InversePropensityWeighting( clip=1e-6, trim=0, regressor=reg, classifier=clf ) estimator.fit(t, m, x, y) @@ -285,9 +292,11 @@ def get_estimation(x, t, m, y, estimator, config): calibration=None, ) cs, alphas = _get_regularization_parameters(regularization=True) - clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10) - reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42) - estimator = ImportanceWeighting( + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) + estimator = InversePropensityWeighting( clip=1e-6, trim=0, regressor=reg, classifier=CalibratedClassifierCV(clf, method="sigmoid") ) estimator.fit(t, m, x, y) @@ -307,9 +316,11 @@ def get_estimation(x, t, m, y, estimator, config): calibration="isotonic", ) cs, 
alphas = _get_regularization_parameters(regularization=True) - clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10) - reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42) - estimator = ImportanceWeighting( + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) + estimator = InversePropensityWeighting( clip=1e-6, trim=0, regressor=reg, classifier=CalibratedClassifierCV(clf, method="isotonic") ) estimator.fit(t, m, x, y) @@ -423,7 +434,8 @@ def get_estimation(x, t, m, y, estimator, config): cs, alphas = _get_regularization_parameters(regularization=True) clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS) reg = RidgeCV(alphas=alphas, cv=CV_FOLDS) - estimator = GComputation(regressor=reg, classifier=CalibratedClassifierCV(clf, method="sigmoid")) + estimator = GComputation( + regressor=reg, classifier=CalibratedClassifierCV(clf, method="sigmoid")) estimator.fit(t, m, x, y) causal_effects = estimator.estimate(t, m, x, y) effects = transform_outputs(causal_effects) @@ -443,7 +455,8 @@ def get_estimation(x, t, m, y, estimator, config): cs, alphas = _get_regularization_parameters(regularization=True) clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS) reg = RidgeCV(alphas=alphas, cv=CV_FOLDS) - estimator = GComputation(regressor=reg, classifier=CalibratedClassifierCV(clf, method="isotonic")) + estimator = GComputation( + regressor=reg, classifier=CalibratedClassifierCV(clf, method="isotonic")) estimator.fit(t, m, x, y) causal_effects = estimator.estimate(t, m, x, y) effects = transform_outputs(causal_effects) @@ -487,8 +500,10 @@ def get_estimation(x, t, m, y, estimator, config): calibration=None, ) cs, alphas = _get_regularization_parameters(regularization=True) - clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10) - reg = 
RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42) + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) estimator = GComputation(regressor=reg, classifier=clf) estimator.fit(t, m, x, y) causal_effects = estimator.estimate(t, m, x, y) @@ -520,9 +535,12 @@ def get_estimation(x, t, m, y, estimator, config): calibration='sigmoid', ) cs, alphas = _get_regularization_parameters(regularization=True) - clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10) - reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42) - estimator = GComputation(regressor=reg, classifier=CalibratedClassifierCV(clf, method="sigmoid")) + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) + estimator = GComputation( + regressor=reg, classifier=CalibratedClassifierCV(clf, method="sigmoid")) estimator.fit(t, m, x, y) causal_effects = estimator.estimate(t, m, x, y) effects = transform_outputs(causal_effects) @@ -541,9 +559,12 @@ def get_estimation(x, t, m, y, estimator, config): calibration="isotonic", ) cs, alphas = _get_regularization_parameters(regularization=True) - clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10) - reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42) - estimator = GComputation(regressor=reg, classifier=CalibratedClassifierCV(clf, method="isotonic")) + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) + estimator = GComputation( + regressor=reg, classifier=CalibratedClassifierCV(clf, method="isotonic")) estimator.fit(t, m, x, y) causal_effects = 
estimator.estimate(t, m, x, y) effects = transform_outputs(causal_effects) @@ -592,8 +613,8 @@ def get_estimation(x, t, m, y, estimator, config): clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS) reg = RidgeCV(alphas=alphas, cv=CV_FOLDS) estimator = MultiplyRobust( - clip=1e-6, ratio="propensities", normalized=True, regressor=reg, - classifier=clf) + clip=1e-6, ratio="propensities", normalized=True, regressor=reg, + classifier=clf) estimator.fit(t, m, x, y) causal_effects = estimator.estimate(t, m, x, y) effects = transform_outputs(causal_effects) @@ -630,8 +651,8 @@ def get_estimation(x, t, m, y, estimator, config): clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS) reg = RidgeCV(alphas=alphas, cv=CV_FOLDS) estimator = MultiplyRobust( - clip=1e-6, ratio="propensities", normalized=True, regressor=reg, - classifier=clf) + clip=1e-6, ratio="propensities", normalized=True, regressor=reg, + classifier=clf) estimator.fit(t, m, x, y) causal_effects = estimator.estimate(t, m, x, y) effects = transform_outputs(causal_effects) @@ -667,8 +688,8 @@ def get_estimation(x, t, m, y, estimator, config): clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS) reg = RidgeCV(alphas=alphas, cv=CV_FOLDS) estimator = MultiplyRobust( - clip=1e-6, ratio="propensities", normalized=True, regressor=reg, - classifier=CalibratedClassifierCV(clf, method="sigmoid")) + clip=1e-6, ratio="propensities", normalized=True, regressor=reg, + classifier=CalibratedClassifierCV(clf, method="sigmoid")) estimator.fit(t, m, x, y) causal_effects = estimator.estimate(t, m, x, y) effects = transform_outputs(causal_effects) @@ -690,8 +711,8 @@ def get_estimation(x, t, m, y, estimator, config): clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS) reg = RidgeCV(alphas=alphas, cv=CV_FOLDS) estimator = MultiplyRobust( - clip=1e-6, ratio="propensities", normalized=True, regressor=reg, - classifier=CalibratedClassifierCV(clf, method="isotonic")) + clip=1e-6, 
ratio="propensities", normalized=True, regressor=reg, + classifier=CalibratedClassifierCV(clf, method="isotonic")) estimator.fit(t, m, x, y) causal_effects = estimator.estimate(t, m, x, y) effects = transform_outputs(causal_effects) @@ -738,11 +759,13 @@ def get_estimation(x, t, m, y, estimator, config): calibration=None, ) cs, alphas = _get_regularization_parameters(regularization=False) - clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10) - reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42) + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) estimator = MultiplyRobust( - clip=1e-6, ratio="propensities", normalized=True, regressor=reg, - classifier=clf) + clip=1e-6, ratio="propensities", normalized=True, regressor=reg, + classifier=clf) estimator.fit(t, m, x, y) causal_effects = estimator.estimate(t, m, x, y) effects = transform_outputs(causal_effects) @@ -775,11 +798,13 @@ def get_estimation(x, t, m, y, estimator, config): calibration='sigmoid', ) cs, alphas = _get_regularization_parameters(regularization=False) - clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10) - reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42) + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) estimator = MultiplyRobust( - clip=1e-6, ratio="propensities", normalized=True, regressor=reg, - classifier=CalibratedClassifierCV(clf, method="sigmoid")) + clip=1e-6, ratio="propensities", normalized=True, regressor=reg, + classifier=CalibratedClassifierCV(clf, method="sigmoid")) estimator.fit(t, m, x, y) causal_effects = estimator.estimate(t, m, x, y) effects = transform_outputs(causal_effects) @@ -798,11 +823,13 @@ 
def get_estimation(x, t, m, y, estimator, config): calibration="isotonic", ) cs, alphas = _get_regularization_parameters(regularization=False) - clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10) - reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42) + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) estimator = MultiplyRobust( - clip=1e-6, ratio="propensities", normalized=True, regressor=reg, - classifier=CalibratedClassifierCV(clf, method="isotonic")) + clip=1e-6, ratio="propensities", normalized=True, regressor=reg, + classifier=CalibratedClassifierCV(clf, method="isotonic")) estimator.fit(t, m, x, y) causal_effects = estimator.estimate(t, m, x, y) effects = transform_outputs(causal_effects) @@ -859,7 +886,7 @@ def get_estimation(x, t, m, y, estimator, config): estimator.fit(t, m, x, y) causal_effects = estimator.estimate(t, m, x, y) effects = transform_outputs(causal_effects) - + elif estimator == "mediation_dml_reg": effects = mediation_dml( y, t, m, x, trim=0, clip=1e-6, calibration=None) @@ -913,8 +940,10 @@ def get_estimation(x, t, m, y, estimator, config): calibration=None, forest=True) cs, alphas = _get_regularization_parameters(regularization=True) - clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10) - reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42) + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) estimator = DoubleMachineLearning( clip=1e-6, trim=0, normalized=True, regressor=reg, classifier=clf ) @@ -933,8 +962,10 @@ def get_estimation(x, t, m, y, estimator, config): calibration='sigmoid', forest=True) cs, alphas = _get_regularization_parameters(regularization=True) 
- clf = RandomForestClassifier(random_state=42, n_estimators=100, min_samples_leaf=10) - reg = RandomForestRegressor(n_estimators=100, min_samples_leaf=10, random_state=42) + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) estimator = DoubleMachineLearning( clip=1e-6, trim=0, normalized=True, regressor=reg, classifier=CalibratedClassifierCV(clf, method="sigmoid") ) diff --git a/src/med_bench/get_estimation_results.py b/src/med_bench/get_estimation_results.py new file mode 100644 index 0000000..5498021 --- /dev/null +++ b/src/med_bench/get_estimation_results.py @@ -0,0 +1,529 @@ +#!/usr/bin/env python +# -*- coding:utf-8 -*- + +import numpy as np + +from .mediation import ( + mediation_IPW, + mediation_g_formula, + mediation_multiply_robust, + mediation_dml, + r_mediation_g_estimator, + r_mediation_dml, + r_mediate, +) + +from med_bench.estimation.mediation_coefficient_product import CoefficientProduct +from med_bench.estimation.mediation_dml import DoubleMachineLearning +from med_bench.estimation.mediation_g_computation import GComputation +from med_bench.estimation.mediation_ipw import InversePropensityWeighting +from med_bench.estimation.mediation_mr import MultiplyRobust +from med_bench.nuisances.utils import _get_regularization_parameters +from med_bench.utils.constants import CV_FOLDS + +from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor +from sklearn.linear_model import LogisticRegressionCV, RidgeCV +from sklearn.calibration import CalibratedClassifierCV + + +def transform_outputs(causal_effects): + """Transforms outputs in the old format + + Args: + causal_effects (dict): dictionary of causal effects + + Returns: + list: list of causal effects + """ + total = causal_effects['total_effect'] + direct_treated = causal_effects['direct_effect_treated'] + direct_control = causal_effects['direct_effect_control'] + 
indirect_treated = causal_effects['indirect_effect_treated'] + indirect_control = causal_effects['indirect_effect_control'] + return [total, direct_treated, direct_control, indirect_treated, indirect_control, 0] + + +def get_estimation_results(x, t, m, y, estimator, config): + """Dynamically selects and calls an estimator (class-based or legacy function) to estimate total, direct, and indirect effects.""" + + effects = None # Initialize variable to store the effects + + # Helper function for regularized regressor and classifier initialization + def get_regularized_regressor_and_classifier(regularize=True, calibration=None, method="sigmoid"): + cs, alphas = _get_regularization_parameters(regularization=regularize) + clf = LogisticRegressionCV(random_state=42, Cs=cs, cv=CV_FOLDS) + reg = RidgeCV(alphas=alphas, cv=CV_FOLDS) + if calibration: + clf = CalibratedClassifierCV(clf, method=method) + return clf, reg + + if estimator == "mediation_IPW_R": + # Use R-based mediation estimator with direct output extraction + x_r, t_r, m_r, y_r = [_convert_array_to_R(uu) for uu in (x, t, m, y)] + output_w = causalweight.medweight( + y=y_r, d=t_r, m=m_r, x=x_r, trim=0.0, ATET="FALSE", logit="TRUE", boot=2 + ) + raw_res_R = np.array(output_w.rx2("results")) + effects = raw_res_R[0, :] + + elif estimator == "coefficient_product": + # Class-based implementation for CoefficientProduct + estimator_obj = CoefficientProduct(regularize=True) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + elif estimator == "mediation_ipw_noreg": + # Class-based implementation for InversePropensityWeighting without regularization + clf, reg = get_regularized_regressor_and_classifier(regularize=False) + estimator_obj = InversePropensityWeighting( + clip=1e-6, trim=0, regressor=reg, classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = 
transform_outputs(causal_effects) + + elif estimator == "mediation_ipw_noreg_cf": + # Legacy function for crossfit with no regularization + effects = mediation_IPW( + y, t, m, x, trim=0, regularization=False, forest=False, crossfit=2, clip=1e-6, calibration=None + ) + + elif estimator == "mediation_ipw_reg": + # Class-based implementation with regularization + clf, reg = get_regularized_regressor_and_classifier(regularize=True) + estimator_obj = InversePropensityWeighting( + clip=1e-6, trim=0, regressor=reg, classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + elif estimator == "mediation_ipw_reg_cf": + # Legacy function with crossfit and regularization + effects = mediation_IPW( + y, t, m, x, trim=0, regularization=True, forest=False, crossfit=2, clip=1e-6, calibration=None + ) + + elif estimator == "mediation_ipw_reg_calibration": + # Class-based implementation with regularization and calibration (sigmoid) + clf, reg = get_regularized_regressor_and_classifier( + regularize=True, calibration=True, method="sigmoid") + estimator_obj = InversePropensityWeighting( + clip=1e-6, trim=0, regressor=reg, classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + elif estimator == "mediation_ipw_reg_calibration_iso": + # Class-based implementation with isotonic calibration + clf, reg = get_regularized_regressor_and_classifier( + regularize=True, calibration=True, method="isotonic") + estimator_obj = InversePropensityWeighting( + clip=1e-6, trim=0, regressor=reg, classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + elif estimator == "mediation_ipw_reg_calibration_cf": + # Legacy function with crossfit and sigmoid calibration + effects = mediation_IPW( + y, t, m, x, trim=0, 
regularization=True, forest=False, crossfit=2, clip=1e-6, calibration="sigmoid" + ) + + elif estimator == "mediation_ipw_reg_calibration_iso_cf": + # Legacy function with crossfit and isotonic calibration + effects = mediation_IPW( + y, t, m, x, trim=0, regularization=True, forest=False, crossfit=2, clip=1e-6, calibration="isotonic" + ) + + elif estimator == "mediation_ipw_forest": + # Class-based implementation with forest models + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) + estimator_obj = InversePropensityWeighting( + clip=1e-6, trim=0, regressor=reg, classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + elif estimator == "mediation_ipw_forest_cf": + # Legacy function with forest and crossfit + effects = mediation_IPW( + y, t, m, x, trim=0, regularization=True, forest=True, crossfit=2, clip=1e-6, calibration=None + ) + + elif estimator == "mediation_ipw_forest_calibration": + # Class-based implementation with forest and calibrated sigmoid + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) + calibrated_clf = CalibratedClassifierCV(clf, method="sigmoid") + estimator_obj = InversePropensityWeighting( + clip=1e-6, trim=0, regressor=reg, classifier=calibrated_clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + elif estimator == "mediation_ipw_forest_calibration_iso": + # Class-based implementation with isotonic calibration + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) + calibrated_clf = 
CalibratedClassifierCV(clf, method="isotonic") + estimator_obj = InversePropensityWeighting( + clip=1e-6, trim=0, regressor=reg, classifier=calibrated_clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + elif estimator == "mediation_g_computation_noreg": + # Class-based implementation of GComputation without regularization + clf, reg = get_regularized_regressor_and_classifier(regularize=False) + estimator_obj = GComputation(regressor=reg, classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + elif estimator == "mediation_g_computation_noreg_cf": + # Legacy function with crossfit and no regularization + effects = mediation_g_formula( + y, t, m, x, interaction=False, forest=False, crossfit=2, regularization=False, calibration=None + ) + + elif estimator == "mediation_g_computation_reg": + # Class-based implementation of GComputation with regularization + clf, reg = get_regularized_regressor_and_classifier(regularize=True) + estimator_obj = GComputation(regressor=reg, classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + elif estimator == "mediation_g_computation_reg_cf": + # Legacy function with regularization and crossfit + effects = mediation_g_formula( + y, t, m, x, interaction=False, forest=False, crossfit=2, regularization=True, calibration=None + ) + + elif estimator == "mediation_g_computation_forest_cf": + if config in (0, 1, 2): + effects = mediation_g_formula( + y, + t, + m, + x, + interaction=False, + forest=True, + crossfit=2, + regularization=True, + calibration=None, + ) + + elif estimator == "mediation_g_computation_reg_calibration": + # Class-based implementation with regularization and calibrated sigmoid + clf, reg = get_regularized_regressor_and_classifier( + 
regularize=True, calibration=True, method="sigmoid") + estimator_obj = GComputation(regressor=reg, classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + elif estimator == "mediation_g_computation_reg_calibration_iso": + # Class-based implementation with isotonic calibration + clf, reg = get_regularized_regressor_and_classifier( + regularize=True, calibration=True, method="isotonic") + estimator_obj = GComputation(regressor=reg, classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + elif estimator == "mediation_g_computation_forest": + # Class-based implementation with forest models + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) + estimator_obj = GComputation(regressor=reg, classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + elif estimator == "mediation_multiply_robust_noreg": + # Class-based implementation for MultiplyRobust without regularization + clf, reg = get_regularized_regressor_and_classifier(regularize=False) + estimator_obj = MultiplyRobust( + ratio="propensities", normalized=True, regressor=reg, classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + elif estimator == "simulation_based": + # R-based function for simulation + effects = r_mediate(y, t, m, x, interaction=False) + + elif estimator == "mediation_dml": + # R-based function for Double Machine Learning with legacy config + effects = r_mediation_dml(y, t, m, x, trim=0.0, order=1) + + elif estimator == "mediation_dml_noreg": + # Class-based implementation for DoubleMachineLearning 
without regularization + clf, reg = get_regularized_regressor_and_classifier(regularize=False) + estimator_obj = DoubleMachineLearning( + clip=1e-6, trim=0, normalized=True, regressor=reg, classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + # Regularized, crossfitting, calibration (isotonic) for InversePropensityWeighting + elif estimator == "mediation_ipw_reg_calibration_iso_cf": + effects = mediation_IPW( + y, t, m, x, trim=0, regularization=True, forest=False, crossfit=2, clip=1e-6, calibration="isotonic" + ) + + # Forest and crossfit with sigmoid calibration for InversePropensityWeighting + elif estimator == "mediation_ipw_forest_calibration_cf": + effects = mediation_IPW( + y, t, m, x, trim=0, regularization=True, forest=True, crossfit=2, clip=1e-6, calibration="sigmoid" + ) + + # Forest and crossfit with isotonic calibration for InversePropensityWeighting + elif estimator == "mediation_ipw_forest_calibration_iso_cf": + effects = mediation_IPW( + y, t, m, x, trim=0, regularization=True, forest=True, crossfit=2, clip=1e-6, calibration="isotonic" + ) + + # MultiplyRobust without regularization, with crossfitting + elif estimator == "mediation_multiply_robust_noreg_cf": + effects = mediation_multiply_robust( + y, t, m.astype(int), x, interaction=False, forest=False, crossfit=2, clip=1e-6, regularization=False, calibration=None + ) + + # Regularized MultiplyRobust estimator + elif estimator == "mediation_multiply_robust_reg": + clf, reg = get_regularized_regressor_and_classifier(regularize=True) + estimator_obj = MultiplyRobust( + ratio="propensities", normalized=True, regressor=reg, classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + # Regularized MultiplyRobust with crossfitting + elif estimator == "mediation_multiply_robust_reg_cf": + effects = 
mediation_multiply_robust( + y, t, m.astype(int), x, interaction=False, forest=False, crossfit=2, clip=1e-6, regularization=True, calibration=None + ) + + # Regularized MultiplyRobust with sigmoid calibration + elif estimator == "mediation_multiply_robust_reg_calibration": + clf, reg = get_regularized_regressor_and_classifier( + regularize=True, calibration=True, method="sigmoid") + estimator_obj = MultiplyRobust( + ratio="propensities", normalized=True, regressor=reg, classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + # Regularized MultiplyRobust with isotonic calibration + elif estimator == "mediation_multiply_robust_reg_calibration_iso": + clf, reg = get_regularized_regressor_and_classifier( + regularize=True, calibration=True, method="isotonic") + estimator_obj = MultiplyRobust( + ratio="propensities", normalized=True, regressor=reg, classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + # Regularized MultiplyRobust with sigmoid calibration and crossfitting + elif estimator == "mediation_multiply_robust_reg_calibration_cf": + effects = mediation_multiply_robust( + y, t, m.astype(int), x, interaction=False, forest=False, crossfit=2, clip=1e-6, regularization=True, calibration="sigmoid" + ) + + # Regularized MultiplyRobust with isotonic calibration and crossfitting + elif estimator == "mediation_multiply_robust_reg_calibration_iso_cf": + effects = mediation_multiply_robust( + y, t, m.astype(int), x, interaction=False, forest=False, crossfit=2, clip=1e-6, regularization=True, calibration="isotonic" + ) + + elif estimator == "mediation_multiply_robust_forest": + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) + estimator_obj = MultiplyRobust( + 
ratio="propensities", normalized=True, regressor=reg, + classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + # MultiplyRobust with forest and crossfitting + elif estimator == "mediation_multiply_robust_forest_cf": + effects = mediation_multiply_robust( + y, t, m.astype(int), x, interaction=False, forest=True, crossfit=2, clip=1e-6, regularization=True, calibration=None + ) + + # MultiplyRobust with forest and sigmoid calibration + elif estimator == "mediation_multiply_robust_forest_calibration": + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) + calibrated_clf = CalibratedClassifierCV(clf, method="sigmoid") + estimator_obj = MultiplyRobust( + ratio="propensities", normalized=True, regressor=reg, classifier=calibrated_clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + # MultiplyRobust with forest and isotonic calibration + elif estimator == "mediation_multiply_robust_forest_calibration_iso": + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) + calibrated_clf = CalibratedClassifierCV(clf, method="isotonic") + estimator_obj = MultiplyRobust( + ratio="propensities", normalized=True, regressor=reg, classifier=calibrated_clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + # MultiplyRobust with forest, sigmoid calibration, and crossfitting + elif estimator == "mediation_multiply_robust_forest_calibration_cf": + effects = mediation_multiply_robust( + y, t, m.astype(int), x, interaction=False, forest=True, crossfit=2, clip=1e-6, regularization=True, 
calibration="sigmoid" + ) + + # MultiplyRobust with forest, isotonic calibration, and crossfitting + elif estimator == "mediation_multiply_robust_forest_calibration_iso_cf": + effects = mediation_multiply_robust( + y, t, m.astype(int), x, interaction=False, forest=True, crossfit=2, clip=1e-6, regularization=True, calibration="isotonic" + ) + + # Regularized Double Machine Learning + elif estimator == "mediation_dml_reg": + clf, reg = get_regularized_regressor_and_classifier(regularize=True) + estimator_obj = DoubleMachineLearning( + clip=1e-6, trim=0, normalized=True, regressor=reg, classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + # Double Machine Learning with fixed seed + elif estimator == "mediation_dml_reg_fixed_seed": + effects = mediation_dml( + y, t, m, x, trim=0, clip=1e-6, random_state=321, calibration=None) + + # Double Machine Learning without regularization, with crossfitting + elif estimator == "mediation_dml_noreg_cf": + effects = mediation_dml(y, t, m, x, trim=0, clip=1e-6, + crossfit=2, regularization=False, calibration=None) + + # Regularized Double Machine Learning with crossfitting + elif estimator == "mediation_dml_reg_cf": + effects = mediation_dml( + y, t, m, x, trim=0, clip=1e-6, crossfit=2, calibration=None) + + # Regularized Double Machine Learning with sigmoid calibration + elif estimator == "mediation_dml_reg_calibration": + clf, reg = get_regularized_regressor_and_classifier( + regularize=True, calibration=True, method="sigmoid") + estimator_obj = DoubleMachineLearning( + clip=1e-6, trim=0, normalized=True, regressor=reg, classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + # Regularized Double Machine Learning with forest models + elif estimator == "mediation_dml_forest": + clf = RandomForestClassifier( + random_state=42, 
n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) + estimator_obj = DoubleMachineLearning( + clip=1e-6, trim=0, normalized=True, regressor=reg, classifier=clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + # Double Machine Learning with forest and calibrated sigmoid + elif estimator == "mediation_dml_forest_calibration": + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) + calibrated_clf = CalibratedClassifierCV(clf, method="sigmoid") + estimator_obj = DoubleMachineLearning( + clip=1e-6, trim=0, normalized=True, regressor=reg, classifier=calibrated_clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + # Double Machine Learning with forest, crossfitting, and sigmoid calibration + elif estimator == "mediation_dml_reg_calibration_cf": + effects = mediation_dml( + y, t, m, x, trim=0, clip=1e-6, crossfit=2, calibration="sigmoid", forest=False) + + # Double Machine Learning with forest and crossfitting + elif estimator == "mediation_dml_forest_cf": + effects = mediation_dml( + y, t, m, x, trim=0, clip=1e-6, crossfit=2, calibration=None, forest=True) + + # Double Machine Learning with forest, crossfitting, and calibrated sigmoid + elif estimator == "mediation_dml_forest_calibration_cf": + effects = mediation_dml( + y, t, m, x, trim=0, clip=1e-6, crossfit=2, calibration="sigmoid", forest=True) + + # GComputation with regularization, crossfitting, and sigmoid calibration + elif estimator == "mediation_g_computation_reg_calibration_cf": + effects = mediation_g_formula( + y, t, m, x, interaction=False, forest=False, crossfit=2, regularization=True, calibration="sigmoid") + + # GComputation with 
forest and sigmoid calibration + elif estimator == "mediation_g_computation_forest_calibration": + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) + calibrated_clf = CalibratedClassifierCV(clf, method="sigmoid") + estimator_obj = GComputation(regressor=reg, classifier=calibrated_clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + # GComputation with forest and isotonic calibration + elif estimator == "mediation_g_computation_forest_calibration_iso": + clf = RandomForestClassifier( + random_state=42, n_estimators=100, min_samples_leaf=10) + reg = RandomForestRegressor( + n_estimators=100, min_samples_leaf=10, random_state=42) + calibrated_clf = CalibratedClassifierCV(clf, method="isotonic") + estimator_obj = GComputation(regressor=reg, classifier=calibrated_clf) + estimator_obj.fit(t, m, x, y) + causal_effects = estimator_obj.estimate(t, m, x, y) + effects = transform_outputs(causal_effects) + + # GComputation with forest, crossfitting, and sigmoid calibration + elif estimator == "mediation_g_computation_forest_calibration_cf": + effects = mediation_g_formula( + y, t, m, x, interaction=False, forest=True, crossfit=2, regularization=True, calibration="sigmoid") + + # GComputation with forest, crossfitting, and isotonic calibration + elif estimator == "mediation_g_computation_forest_calibration_iso_cf": + effects = mediation_g_formula( + y, t, m, x, interaction=False, forest=True, crossfit=2, regularization=True, calibration="isotonic") + + elif estimator == "mediation_g_estimator": + if config in (0, 1, 2): + effects = r_mediation_g_estimator(y, t, m, x) + else: + raise ValueError("Unrecognized estimator label.") + + # Catch unsupported estimators and raise an error + if effects is None: + raise ValueError( + f"Estimation failed for {estimator}. 
Check inputs or configuration.") + return effects diff --git a/src/tests/estimation/test_exact_estimation.py b/src/tests/estimation/test_exact_estimation.py deleted file mode 100644 index 3e4fab1..0000000 --- a/src/tests/estimation/test_exact_estimation.py +++ /dev/null @@ -1,109 +0,0 @@ -""" -Pytest file for get_estimation.py - -It tests all the benchmark_mediation estimators : -- for a certain tolerance -- whether their effects satisfy "total = direct + indirect" -- whether they support (n,1) and (n,) inputs - -To be robust to future updates, tests are adjusted with a smaller tolerance when possible. -The test is skipped if estimator has not been implemented yet, i.e. if ValueError is raised. -The test fails for any other unwanted behavior. -""" - -from pprint import pprint -import pytest -import os -import numpy as np - -from med_bench.get_estimation import get_estimation -from med_bench.utils.constants import R_DEPENDENT_ESTIMATORS -from med_bench.utils.utils import DependencyNotInstalledError, check_r_dependencies - -current_dir = os.path.dirname(__file__) -true_estimations_file_path = os.path.join(current_dir, 'tests_results.npy') -TRUE_ESTIMATIONS = np.load(true_estimations_file_path, allow_pickle=True) - - -@pytest.fixture(params=range(TRUE_ESTIMATIONS.shape[0])) -def tests_results_idx(request): - return request.param - - -@pytest.fixture -def data(tests_results_idx): - return TRUE_ESTIMATIONS[tests_results_idx] - - -@pytest.fixture -def estimator(data): - return data[0] - - -@pytest.fixture -def x(data): - return data[1] - - -# t is raveled because some estimators fail with (n,1) inputs -@pytest.fixture -def t(data): - return data[2] - - -@pytest.fixture -def m(data): - return data[3] - - -@pytest.fixture -def y(data): - return data[4] - - -@pytest.fixture -def config(data): - return data[5] - - -@pytest.fixture -def result(data): - return data[6] - - -@pytest.fixture -def effects_chap(x, t, m, y, estimator, config): - # try whether estimator is 
implemented or not - - try: - res = get_estimation(x, t, m, y, estimator, config)[0:5] - - # NaN situations - if np.all(np.isnan(res)): - pytest.xfail("all effects are NaN") - elif np.any(np.isnan(res)): - pprint("NaN found") - - except Exception as e: - if str(e) in ( - "Estimator only supports 1D binary mediator.", - "Estimator does not support 1D binary mediator.", - ): - pytest.skip(f"{e}") - - # We skip the test if an error with function from glmet rpy2 package occurs - elif "glmnet::glmnet" in str(e): - pytest.skip(f"{e}") - - elif estimator in R_DEPENDENT_ESTIMATORS and not check_r_dependencies(): - assert isinstance(e, DependencyNotInstalledError) == True - pytest.skip(f"{e}") - - else: - pytest.fail(f"{e}") - - return res - - -def test_estimation_exactness(result, effects_chap): - assert np.all(effects_chap == pytest.approx(result, abs=1.e-3)) diff --git a/src/tests/estimation/test_get_estimation.py b/src/tests/estimation/test_get_estimation.py index 9a99b90..04d57c8 100644 --- a/src/tests/estimation/test_get_estimation.py +++ b/src/tests/estimation/test_get_estimation.py @@ -14,48 +14,53 @@ from pprint import pprint import pytest import numpy as np +import os +from med_bench.get_estimation_results import get_estimation_results from med_bench.get_simulated_data import simulate_data -from med_bench.get_estimation import get_estimation - from med_bench.utils.utils import DependencyNotInstalledError, check_r_dependencies from med_bench.utils.constants import PARAMETER_LIST, PARAMETER_NAME, R_DEPENDENT_ESTIMATORS, TOLERANCE_DICT +current_dir = os.path.dirname(__file__) +true_estimations_file_path = os.path.join(current_dir, 'tests_results.npy') +TRUE_ESTIMATIONS = np.load(true_estimations_file_path, allow_pickle=True) + @pytest.fixture(params=PARAMETER_LIST) def dict_param(request): return dict(zip(PARAMETER_NAME, request.param)) +# Two distinct data fixtures @pytest.fixture -def data(dict_param): +def data_simulated(dict_param): return 
simulate_data(**dict_param) @pytest.fixture -def x(data): - return data[0] +def x(data_simulated): + return data_simulated[0] # t is raveled because some estimators fail with (n,1) inputs @pytest.fixture -def t(data): - return data[1].ravel() +def t(data_simulated): + return data_simulated[1].ravel() @pytest.fixture -def m(data): - return data[2] +def m(data_simulated): + return data_simulated[2] @pytest.fixture -def y(data): - return data[3].ravel() # same reason as t +def y(data_simulated): + return data_simulated[3].ravel() # same reason as t @pytest.fixture -def effects(data): - return np.array(data[4:9]) +def effects(data_simulated): + return np.array(data_simulated[4:9]) @pytest.fixture(params=list(TOLERANCE_DICT.keys())) @@ -80,7 +85,7 @@ def effects_chap(x, t, m, y, estimator, config): # try whether estimator is implemented or not try: - res = get_estimation(x, t, m, y, estimator, config)[0:5] + res = get_estimation_results(x, t, m, y, estimator, config)[0:5] except Exception as e: if str(e) in ( "Estimator only supports 1D binary mediator.", @@ -126,9 +131,94 @@ def test_robustness_to_ravel_format(data, estimator, config, effects_chap): if "forest" in estimator: pytest.skip("Forest estimator skipped") assert np.all( - get_estimation(data[0], data[1], data[2], - data[3], estimator, config)[0:5] + get_estimation_results(data[0], data[1], data[2], + data[3], estimator, config)[0:5] == pytest.approx( effects_chap, nan_ok=True ) # effects_chap is obtained with data[1].ravel() and data[3].ravel() ) + + +@pytest.fixture(params=range(TRUE_ESTIMATIONS.shape[0])) +def tests_results_idx(request): + return request.param + + +@pytest.fixture +def data_true(tests_results_idx): + return TRUE_ESTIMATIONS[tests_results_idx] + + +@pytest.fixture +def estimator_true(data_true): + return data_true[0] + + +@pytest.fixture +def x_true(data_true): + return data_true[1] + + +# t is raveled because some estimators fail with (n,1) inputs +@pytest.fixture +def t_true(data_true): + 
return data_true[2] + + +@pytest.fixture +def m_true(data_true): + return data_true[3] + + +@pytest.fixture +def y_true(data_true): + return data_true[4] + + +@pytest.fixture +def config_true(data_true): + return data_true[5] + + +@pytest.fixture +def result_true(data_true): + return data_true[6] + + +@pytest.fixture +def effects_chap_true(x_true, t_true, m_true, y_true, estimator_true, config_true): + # try whether estimator is implemented or not + + try: + res = get_estimation_results(x_true, t_true, m_true, + y_true, estimator_true, config_true)[0:5] + + # NaN situations + if np.all(np.isnan(res)): + pytest.xfail("all effects are NaN") + elif np.any(np.isnan(res)): + pprint("NaN found") + + except Exception as e: + if str(e) in ( + "Estimator only supports 1D binary mediator.", + "Estimator does not support 1D binary mediator.", + ): + pytest.skip(f"{e}") + + # We skip the test if an error with function from glmnet rpy2 package occurs + elif "glmnet::glmnet" in str(e): + pytest.skip(f"{e}") + + elif estimator_true in R_DEPENDENT_ESTIMATORS and not check_r_dependencies(): + assert isinstance(e, DependencyNotInstalledError) == True + pytest.skip(f"{e}") + + else: + pytest.fail(f"{e}") + + return res + + +def test_estimation_exactness(result_true, effects_chap_true): + assert np.all(effects_chap_true == pytest.approx(result_true, abs=1.e-3))