Skip to content

Commit

Permalink
corrected the issues for the PR
Browse files Browse the repository at this point in the history
  • Loading branch information
Houssam Zenati committed Dec 13, 2023
1 parent a4daeee commit 362c530
Show file tree
Hide file tree
Showing 7 changed files with 219 additions and 80 deletions.
Binary file added src/.DS_Store
Binary file not shown.
114 changes: 61 additions & 53 deletions src/med_bench/get_estimation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,16 @@
from rpy2.rinterface_lib.embedded import RRuntimeError
import pandas as pd
import numpy as np
from .mediation import mediation_IPW, mediation_coefficient_products, mediation_g_formula,mediation_multiply_robust, r_g_estimator, r_medDML, r_mediate, mediation_DML

from .mediation import (
mediation_IPW,
mediation_coefficient_products,
mediation_g_computation,
mediation_multiply_robust,
mediation_DML,
r_mediation_g_computation,
r_mediation_DML,
r_mediate,
)

def get_estimation(x, t, m, y, estimator, config):
"""Wrapper estimator fonction ; calls an estimator given mediation data
Expand Down Expand Up @@ -285,9 +293,9 @@ def get_estimation(x, t, m, y, estimator, config):
calibration=True,
calib_method="isotonic",
)
elif estimator == "g_computation_noreg":
elif estimator == "mediation_g_computation_noreg":
if config in (0, 1, 2):
effects = mediation_g_formula(
effects = mediation_g_computation(
y,
t,
m,
Expand All @@ -298,9 +306,9 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=False,
calibration=False,
)
elif estimator == "g_computation_noreg_cf":
elif estimator == "mediation_g_computation_noreg_cf":
if config in (0, 1, 2):
effects = mediation_g_formula(
effects = mediation_g_computation(
y,
t,
m,
Expand All @@ -311,9 +319,9 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=False,
calibration=False,
)
elif estimator == "g_computation_reg":
elif estimator == "mediation_g_computation_reg":
if config in (0, 1, 2):
effects = mediation_g_formula(
effects = mediation_g_computation(
y,
t,
m,
Expand All @@ -324,9 +332,9 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=True,
calibration=False,
)
elif estimator == "g_computation_reg_cf":
elif estimator == "mediation_g_computation_reg_cf":
if config in (0, 1, 2):
effects = mediation_g_formula(
effects = mediation_g_computation(
y,
t,
m,
Expand All @@ -337,9 +345,9 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=True,
calibration=False,
)
elif estimator == "g_computation_reg_calibration":
elif estimator == "mediation_g_computation_reg_calibration":
if config in (0, 1, 2):
effects = mediation_g_formula(
effects = mediation_g_computation(
y,
t,
m,
Expand All @@ -350,9 +358,9 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=True,
calibration=True,
)
elif estimator == "g_computation_reg_calibration_iso":
elif estimator == "mediation_g_computation_reg_calibration_iso":
if config in (0, 1, 2):
effects = mediation_g_formula(
effects = mediation_g_computation(
y,
t,
m,
Expand All @@ -364,9 +372,9 @@ def get_estimation(x, t, m, y, estimator, config):
calibration=True,
calib_method="isotonic",
)
elif estimator == "g_computation_reg_calibration_cf":
elif estimator == "mediation_g_computation_reg_calibration_cf":
if config in (0, 1, 2):
effects = mediation_g_formula(
effects = mediation_g_computation(
y,
t,
m,
Expand All @@ -377,9 +385,9 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=True,
calibration=True,
)
elif estimator == "g_computation_reg_calibration_iso_cf":
elif estimator == "mediation_g_computation_reg_calibration_iso_cf":
if config in (0, 1, 2):
effects = mediation_g_formula(
effects = mediation_g_computation(
y,
t,
m,
Expand All @@ -391,9 +399,9 @@ def get_estimation(x, t, m, y, estimator, config):
calibration=True,
calib_method="isotonic",
)
elif estimator == "g_computation_forest":
elif estimator == "mediation_g_computation_forest":
if config in (0, 1, 2):
effects = mediation_g_formula(
effects = mediation_g_computation(
y,
t,
m,
Expand All @@ -404,9 +412,9 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=True,
calibration=False,
)
elif estimator == "g_computation_forest_cf":
elif estimator == "mediation_g_computation_forest_cf":
if config in (0, 1, 2):
effects = mediation_g_formula(
effects = mediation_g_computation(
y,
t,
m,
Expand All @@ -417,9 +425,9 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=True,
calibration=False,
)
elif estimator == "g_computation_forest_calibration":
elif estimator == "mediation_g_computation_forest_calibration":
if config in (0, 1, 2):
effects = mediation_g_formula(
effects = mediation_g_computation(
y,
t,
m,
Expand All @@ -430,9 +438,9 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=True,
calibration=True,
)
elif estimator == "g_computation_forest_calibration_iso":
elif estimator == "mediation_g_computation_forest_calibration_iso":
if config in (0, 1, 2):
effects = mediation_g_formula(
effects = mediation_g_computation(
y,
t,
m,
Expand All @@ -444,9 +452,9 @@ def get_estimation(x, t, m, y, estimator, config):
calibration=True,
calib_method="isotonic",
)
elif estimator == "g_computation_forest_calibration_cf":
elif estimator == "mediation_g_computation_forest_calibration_cf":
if config in (0, 1, 2):
effects = mediation_g_formula(
effects = mediation_g_computation(
y,
t,
m,
Expand All @@ -457,9 +465,9 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=True,
calibration=True,
)
elif estimator == "g_computation_forest_calibration_iso_cf":
elif estimator == "mediation_g_computation_forest_calibration_iso_cf":
if config in (0, 1, 2):
effects = mediation_g_formula(
effects = mediation_g_computation(
y,
t,
m,
Expand All @@ -471,7 +479,7 @@ def get_estimation(x, t, m, y, estimator, config):
calibration=True,
calib_method="isotonic",
)
elif estimator == "multiply_robust_noreg":
elif estimator == "mediation_multiply_robust_noreg":
if config in (0, 1, 2):
effects = mediation_multiply_robust(
y,
Expand All @@ -485,7 +493,7 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=False,
calibration=False,
)
elif estimator == "multiply_robust_noreg_cf":
elif estimator == "mediation_multiply_robust_noreg_cf":
if config in (0, 1, 2):
effects = mediation_multiply_robust(
y,
Expand All @@ -499,7 +507,7 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=False,
calibration=False,
)
elif estimator == "multiply_robust_reg":
elif estimator == "mediation_multiply_robust_reg":
if config in (0, 1, 2):
effects = mediation_multiply_robust(
y,
Expand All @@ -513,7 +521,7 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=True,
calibration=False,
)
elif estimator == "multiply_robust_reg_cf":
elif estimator == "mediation_multiply_robust_reg_cf":
if config in (0, 1, 2):
effects = mediation_multiply_robust(
y,
Expand All @@ -527,7 +535,7 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=True,
calibration=False,
)
elif estimator == "multiply_robust_reg_calibration":
elif estimator == "mediation_multiply_robust_reg_calibration":
if config in (0, 1, 2):
effects = mediation_multiply_robust(
y,
Expand All @@ -541,7 +549,7 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=True,
calibration=True,
)
elif estimator == "multiply_robust_reg_calibration_iso":
elif estimator == "mediation_multiply_robust_reg_calibration_iso":
if config in (0, 1, 2):
effects = mediation_multiply_robust(
y,
Expand All @@ -556,7 +564,7 @@ def get_estimation(x, t, m, y, estimator, config):
calibration=True,
calib_method="isotonic",
)
elif estimator == "multiply_robust_reg_calibration_cf":
elif estimator == "mediation_multiply_robust_reg_calibration_cf":
if config in (0, 1, 2):
effects = mediation_multiply_robust(
y,
Expand All @@ -570,7 +578,7 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=True,
calibration=True,
)
elif estimator == "multiply_robust_reg_calibration_iso_cf":
elif estimator == "mediation_multiply_robust_reg_calibration_iso_cf":
if config in (0, 1, 2):
effects = mediation_multiply_robust(
y,
Expand All @@ -585,7 +593,7 @@ def get_estimation(x, t, m, y, estimator, config):
calibration=True,
calib_method="isotonic",
)
elif estimator == "multiply_robust_forest":
elif estimator == "mediation_multiply_robust_forest":
if config in (0, 1, 2):
effects = mediation_multiply_robust(
y,
Expand All @@ -599,7 +607,7 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=True,
calibration=False,
)
elif estimator == "multiply_robust_forest_cf":
elif estimator == "mediation_multiply_robust_forest_cf":
if config in (0, 1, 2):
effects = mediation_multiply_robust(
y,
Expand All @@ -613,7 +621,7 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=True,
calibration=False,
)
elif estimator == "multiply_robust_forest_calibration":
elif estimator == "mediation_multiply_robust_forest_calibration":
if config in (0, 1, 2):
effects = mediation_multiply_robust(
y,
Expand All @@ -627,7 +635,7 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=True,
calibration=True,
)
elif estimator == "multiply_robust_forest_calibration_iso":
elif estimator == "mediation_multiply_robust_forest_calibration_iso":
if config in (0, 1, 2):
effects = mediation_multiply_robust(
y,
Expand All @@ -642,7 +650,7 @@ def get_estimation(x, t, m, y, estimator, config):
calibration=True,
calib_method="isotonic",
)
elif estimator == "multiply_robust_forest_calibration_cf":
elif estimator == "mediation_multiply_robust_forest_calibration_cf":
if config in (0, 1, 2):
effects = mediation_multiply_robust(
y,
Expand All @@ -656,7 +664,7 @@ def get_estimation(x, t, m, y, estimator, config):
regularization=True,
calibration=True,
)
elif estimator == "multiply_robust_forest_calibration_iso_cf":
elif estimator == "mediation_multiply_robust_forest_calibration_iso_cf":
if config in (0, 1, 2):
effects = mediation_multiply_robust(
y,
Expand All @@ -674,22 +682,22 @@ def get_estimation(x, t, m, y, estimator, config):
elif estimator == "simulation_based":
if config in (0, 1, 2):
effects = r_mediate(y, t, m, x, interaction=False)
elif estimator == "DML_mediation":
elif estimator == "mediation_DML":
if config > 0:
effects = r_medDML(y, t, m, x, trim=0.0, order=1)
elif estimator == "med_dml_noreg":
effects = r_mediation_DML(y, t, m, x, trim=0.0, order=1)
elif estimator == "mediation_DML_noreg":
effects = mediation_DML(x, t, m, y, trim=0, regularization=False)
elif estimator == "med_dml_reg":
elif estimator == "mediation_DML_reg":
effects = mediation_DML(x, t, m, y, trim=0)
elif estimator == "med_dml_reg_fixed_seed":
elif estimator == "mediation_DML_reg_fixed_seed":
effects = mediation_DML(x, t, m, y, trim=0, random_state=321)
elif estimator == "med_dml_noreg_cf":
elif estimator == "mediation_DML_noreg_cf":
effects = mediation_DML(x, t, m, y, trim=0, crossfit=4, regularization=False)
elif estimator == "med_dml_reg_cf":
elif estimator == "mediation_DML_reg_cf":
effects = mediation_DML(x, t, m, y, trim=0, crossfit=4)
elif estimator == "G_estimator":
elif estimator == "mediation_g_computation":
if config in (0, 1, 2):
effects = r_g_estimator(y, t, m, x)
effects = r_mediation_g_computation(y, t, m, x)
else:
raise ValueError("Unrecognized estimator label.")
if effects is None:
Expand Down
6 changes: 3 additions & 3 deletions src/med_bench/mediation.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,7 +323,7 @@ def mediation_coefficient_products(y, t, m, x, interaction=False, regularization
None]


def mediation_g_formula(y, t, m, x, interaction=False, forest=False,
def mediation_g_computation(y, t, m, x, interaction=False, forest=False,
crossfit=0, calibration=True, regularization=True,
calib_method='sigmoid'):
"""
Expand Down Expand Up @@ -889,7 +889,7 @@ def r_mediate(y, t, m, x, interaction=False):
return to_return + [None]


def r_g_estimator(y, t, m, x):
def r_mediation_g_computation(y, t, m, x):
m = m.ravel()
var_names = [[y, 'y'],
[t, 't'],
Expand Down Expand Up @@ -926,7 +926,7 @@ def r_g_estimator(y, t, m, x):
indirect_effect,
None]

def r_medDML(y, t, m, x, trim=0.05, order=1):
def r_mediation_DML(y, t, m, x, trim=0.05, order=1):
"""
y array-like, shape (n_samples)
outcome value for each unit, continuous
Expand Down
Binary file added src/tests/.DS_Store
Binary file not shown.
Loading

0 comments on commit 362c530

Please sign in to comment.