Skip to content

Commit

Permalink
files cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
brash6 committed Nov 14, 2024
1 parent 825bb12 commit 9bff44b
Show file tree
Hide file tree
Showing 9 changed files with 226 additions and 1,777 deletions.
10 changes: 5 additions & 5 deletions src/med_bench/estimation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,9 +325,9 @@ def _estimate_mediator_probability(self, t, m, x, y):
Returns
-------
f_m0, array-like, shape (n_samples)
f_m0x, array-like, shape (n_samples)
probabilities f(M|T=0,X)
f_m1, array-like, shape (n_samples)
f_m1x, array-like, shape (n_samples)
probabilities f(M|T=1,X)
"""
n = len(y)
Expand All @@ -340,10 +340,10 @@ def _estimate_mediator_probability(self, t, m, x, y):
t0_x = np.hstack([t0.reshape(-1, 1), x])
t1_x = np.hstack([t1.reshape(-1, 1), x])

fm_0 = self._classifier_m.predict_proba(t0_x)[:, 1]
fm_1 = self._classifier_m.predict_proba(t1_x)[:, 1]
f_m0x = self._classifier_m.predict_proba(t0_x)[:, m]
f_m1x = self._classifier_m.predict_proba(t1_x)[:, m]

return fm_0, fm_1
return f_m0x, f_m1x

def _estimate_mediators_probabilities(self, t, m, x, y):
"""
Expand Down
10 changes: 5 additions & 5 deletions src/med_bench/estimation/mediation_g_computation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from med_bench.estimation.base import Estimator
from med_bench.utils.decorators import fitted
from med_bench.utils.utils import is_array_integer
from med_bench.utils.utils import is_array_binary


class GComputation(Estimator):
Expand Down Expand Up @@ -35,7 +35,7 @@ def fit(self, t, m, x, y):
"""Fits nuisance parameters to data"""
t, m, x, y = self._resize(t, m, x, y)

if is_array_integer(m):
if is_array_binary(m):
self._fit_mediator_nuisance(t, m, x, y)
self._fit_conditional_mean_outcome_nuisance(t, m, x, y)
else:
Expand All @@ -55,10 +55,10 @@ def estimate(self, t, m, x, y):
"""
t, m, x, y = self._resize(t, m, x, y)

if is_array_integer(m):
mu_00x, mu_01x, mu_10x, mu_11x = self._estimate_mediators_probabilities(
if is_array_binary(m):
f_00x, f_01x, f_10x, f_11x = self._estimate_mediators_probabilities(
t, m, x, y)
f_00x, f_01x, f_10x, f_11x = self._estimate_conditional_mean_outcome(
mu_00x, mu_01x, mu_10x, mu_11x = self._estimate_conditional_mean_outcome(
t, m, x, y)

direct_effect_i1 = mu_11x - mu_01x
Expand Down
118 changes: 0 additions & 118 deletions src/med_bench/example.py

This file was deleted.

Loading

0 comments on commit 9bff44b

Please sign in to comment.