Skip to content

Commit

Permalink
optimise constants
Browse files Browse the repository at this point in the history
  • Loading branch information
brash6 committed Mar 19, 2024
1 parent bab13ab commit 684de63
Showing 1 changed file with 51 additions and 87 deletions.
138 changes: 51 additions & 87 deletions src/med_bench/utils/constants.py
Original file line number Diff line number Diff line change
@@ -1,64 +1,64 @@
import itertools
import os
import numpy as np
from numpy.random import default_rng

# CONSTANTS USED FOR TESTS

# TOLERANCE THRESHOLDS

SMALL_ATE_TOLERANCE = 0.05
SMALL_DIRECT_TOLERANCE = 0.05
SMALL_INDIRECT_TOLERANCE = 0.2

MEDIUM_ATE_TOLERANCE = 0.10
MEDIUM_DIRECT_TOLERANCE = 0.10
MEDIUM_INDIRECT_TOLERANCE = 0.4

LARGE_ATE_TOLERANCE = 0.15
LARGE_DIRECT_TOLERANCE = 0.15
LARGE_INDIRECT_TOLERANCE = 0.8
# indirect effect is weak, leading to a large relative error

SMALL_TOLERANCE = np.array(
[
SMALL_ATE_TOLERANCE,
SMALL_DIRECT_TOLERANCE,
SMALL_DIRECT_TOLERANCE,
SMALL_INDIRECT_TOLERANCE,
SMALL_INDIRECT_TOLERANCE,
]
)
TOLERANCE_THRESHOLDS = {
"SMALL": {
"ATE": 0.05,
"DIRECT": 0.05,
"INDIRECT": 0.2,
},
"MEDIUM": {
"ATE": 0.10,
"DIRECT": 0.10,
"INDIRECT": 0.4,
},
"LARGE": {
"ATE": 0.15,
"DIRECT": 0.15,
"INDIRECT": 0.8,
},
"INFINITE": {
"ATE": np.inf,
"DIRECT": np.inf,
"INDIRECT": np.inf,
},
}

MEDIUM_TOLERANCE = np.array(
[
MEDIUM_ATE_TOLERANCE,
MEDIUM_DIRECT_TOLERANCE,
MEDIUM_DIRECT_TOLERANCE,
MEDIUM_INDIRECT_TOLERANCE,
MEDIUM_INDIRECT_TOLERANCE,
]
)

LARGE_TOLERANCE = np.array(
[
LARGE_ATE_TOLERANCE,
LARGE_DIRECT_TOLERANCE,
LARGE_DIRECT_TOLERANCE,
LARGE_INDIRECT_TOLERANCE,
LARGE_INDIRECT_TOLERANCE,
]
)
def get_tolerance_array(tolerance_size: str) -> np.array:
"""Get tolerance array for tolerance tests
INFINITE_TOLERANCE = np.array(
[
np.inf,
np.inf,
np.inf,
np.inf,
np.inf,
]
)
Parameters
----------
tolerance_size : str
tolerance size, can be "SMALL", "MEDIUM", "LARGE" or "INFINITE"
Returns
-------
np.array
array of size 5 containing the ATE, DIRECT and INDIRECT effects tolerance
"""

return np.array(
[
TOLERANCE_THRESHOLDS[tolerance_size]["ATE"],
TOLERANCE_THRESHOLDS[tolerance_size]["DIRECT"],
TOLERANCE_THRESHOLDS[tolerance_size]["DIRECT"],
TOLERANCE_THRESHOLDS[tolerance_size]["INDIRECT"],
TOLERANCE_THRESHOLDS[tolerance_size]["INDIRECT"],
]
)


SMALL_TOLERANCE = get_tolerance_array("SMALL")
MEDIUM_TOLERANCE = get_tolerance_array("MEDIUM")
LARGE_TOLERANCE = get_tolerance_array("LARGE")
INFINITE_TOLERANCE = get_tolerance_array("INFINITE")

TOLERANCE_DICT = {
"coefficient_product": LARGE_TOLERANCE,
Expand Down Expand Up @@ -98,43 +98,7 @@
"mediation_multiply_robust_forest_calibration_cf": INFINITE_TOLERANCE,
}

ESTIMATORS = [
"coefficient_product",
"mediation_ipw_noreg",
"mediation_ipw_reg",
"mediation_ipw_reg_calibration",
"mediation_ipw_forest",
"mediation_ipw_forest_calibration",
"mediation_g_computation_noreg",
"mediation_g_computation_reg",
"mediation_g_computation_reg_calibration",
"mediation_g_computation_forest",
"mediation_g_computation_forest_calibration",
"mediation_multiply_robust_noreg",
"mediation_multiply_robust_reg",
"mediation_multiply_robust_reg_calibration",
"mediation_multiply_robust_forest",
"mediation_multiply_robust_forest_calibration",
"simulation_based",
"mediation_DML",
"mediation_DML_reg_fixed_seed",
"mediation_g_estimator",
"mediation_ipw_noreg_cf",
"mediation_ipw_reg_cf",
"mediation_ipw_reg_calibration_cf",
"mediation_ipw_forest_cf",
"mediation_ipw_forest_calibration_cf",
"mediation_g_computation_noreg_cf",
"mediation_g_computation_reg_cf",
"mediation_g_computation_reg_calibration_cf",
"mediation_g_computation_forest_cf",
"mediation_g_computation_forest_calibration_cf",
"mediation_multiply_robust_noreg_cf",
"mediation_multiply_robust_reg_cf",
"mediation_multiply_robust_reg_calibration_cf",
"mediation_multiply_robust_forest_cf",
"mediation_multiply_robust_forest_calibration_cf",
]
ESTIMATORS = list(TOLERANCE_DICT.keys())

# PARAMETERS VALUES FOR DATA GENERATION

Expand Down

0 comments on commit 684de63

Please sign in to comment.