Input check #80

Closed
wants to merge 43 commits
43 commits
b0a1ddd
add random_state to scikit_learn functions
brash6 Mar 5, 2024
ab6e721
remove unexpected random_state
brash6 Mar 5, 2024
b75a0bb
tests CI
Mar 5, 2024
0bd98ff
tests without R dependencies
Mar 6, 2024
036921c
tests ignoring R dependencies
Mar 8, 2024
eff7730
tests ignoring R dependencies 2
Mar 8, 2024
cf675e8
warnings if R or an R package is not installed (mediation.py)
Mar 14, 2024
fa52e96
warnings if R or an R package is not installed
zbakhm Mar 14, 2024
7ebf5f4
workflow update
Mar 14, 2024
d8b5add
workflow update 2
Mar 14, 2024
0d2982f
workflow update 3
Mar 14, 2024
09faa47
workflow update 4
Mar 14, 2024
3960b4d
workflow update 5
Mar 14, 2024
7ec0143
saving packages to cache
Mar 14, 2024
cd19682
saving packages to cache 2
Mar 14, 2024
48824ba
saving packages to cache 3
Mar 14, 2024
be088d0
tests using cache
Mar 14, 2024
7095310
update workflows
Mar 14, 2024
14c97cd
last update
Mar 14, 2024
666174c
Removing commented lines & making sure to apply the 80 characters lim…
zbakhm Mar 15, 2024
edd684f
add DS_STORE in gitignore file
brash6 Mar 18, 2024
5eaef73
force pandas version to 1.2.1 - add pytest to setup file
brash6 Mar 18, 2024
84f8922
add generate_tests_results.py and tests_results.npy files
brash6 Mar 18, 2024
2b8a7bc
add constants file
brash6 Mar 18, 2024
43bc76d
get rid of unused imports
brash6 Mar 18, 2024
0badf9c
minor reformatting
brash6 Mar 18, 2024
795d35f
handle glmnet error in tolerance tests
brash6 Mar 18, 2024
86afa28
add exact estimation tests
brash6 Mar 18, 2024
c57cdec
add DS_STORE to gitignore
brash6 Mar 18, 2024
35726c5
Merge branch 'develop' into create_new_tests
brash6 Mar 18, 2024
bab06f3
use a dictionary for tolerance thresholds
brash6 Mar 19, 2024
3ec2c60
discard previous changes (tolerance threshold dict)
brash6 Mar 19, 2024
bab13ab
estimators now return a tuple instead of a list
brash6 Mar 19, 2024
684de63
optimise constants
brash6 Mar 19, 2024
5a9a3ec
Merge pull request #73 from judithabk6/test-CI
brash6 Mar 19, 2024
4058a89
Merge branch 'develop' into create_new_tests
brash6 Mar 19, 2024
9d249f2
Solving issue #76: fail import of _get_interactions
zbakhm Mar 20, 2024
d0286ad
Update src/med_bench/utils/constants.py
brash6 Mar 20, 2024
7b93aa6
remove useless comment in nuisances.py
brash6 Mar 20, 2024
f42a6d1
Enhance readability of generate_tests_results.py functions
brash6 Mar 20, 2024
0b8020c
Merge pull request #74 from judithabk6/create_new_tests
brash6 Mar 20, 2024
3792895
enforce input check at the function level to avoid issues with input …
judithabk6 Mar 28, 2024
2922cea
small fixes in expected mediator type
judithabk6 Mar 28, 2024
3 changes: 0 additions & 3 deletions .github/workflows/main.yaml

This file was deleted.

50 changes: 50 additions & 0 deletions .github/workflows/save-packages-cache.yaml
@@ -0,0 +1,50 @@
name: cache-R

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.11' # Specify the Python version you want to use

      - name: Install Package in Editable Mode with Python Dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e ".[dev]"

      - name: Setup R
        uses: r-lib/actions/setup-r@v2
        with:
          r-version: '4.3.2' # Use the R version you prefer

      - name: Install R packages
        uses: r-lib/actions/setup-r-dependencies@v2
        with:
          cache: true
          cache-version: 1
          dependencies: 'NA'
          install-pandoc: false
          packages: |
            grf
            causalweight
            mediation

      - name: Install plmed package
        run: |
          R -e "pak::pkg_install('ohines/plmed')"

      - name: Install Pytest
        run: |
          pip install pytest
54 changes: 54 additions & 0 deletions .github/workflows/tests-with-R.yaml
@@ -0,0 +1,54 @@
name: CI-with-R

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.11' # Specify the Python version you want to use

      - name: Install Package in Editable Mode with Python Dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e ".[dev]"

      - name: Setup R
        uses: r-lib/actions/setup-r@v2
        with:
          r-version: '4.3.2' # Use the R version you prefer

      - name: Install R packages
        uses: r-lib/actions/setup-r-dependencies@v2
        with:
          cache: true
          cache-version: 1
          dependencies: 'NA'
          install-pandoc: false
          packages: |
            grf
            causalweight
            mediation

      - name: Install plmed package
        run: |
          R -e "pak::pkg_install('ohines/plmed')"

      - name: Install Pytest
        run: |
          pip install pytest

      - name: Run tests
        run: |
          pytest
33 changes: 33 additions & 0 deletions .github/workflows/tests-without-R.yaml
@@ -0,0 +1,33 @@
name: CI-without-R

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test:
    runs-on: ubuntu-latest

    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Python
        uses: actions/setup-python@v2
        with:
          python-version: '3.11' # Specify the Python version you want to use

      - name: Install Package in Editable Mode with Python Dependencies
        run: |
          python -m pip install --upgrade pip
          pip install -e ".[dev]"

      - name: Install Pytest
        run: |
          pip install pytest

      - name: Run tests
        run: |
          pytest
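
The CI-with-R and CI-without-R workflows above implement the commits about running the test suite with and without the R toolchain ("tests without R dependencies", "warnings if R or an R package is not installed"). A rough sketch of how the Python side can tolerate a missing R installation follows; this is not the actual code in mediation.py, and the flag name and warning text are made up for illustration:

# Sketch only: guard the rpy2 import so the package still loads when R is absent.
import warnings

try:
    from rpy2.rinterface_lib.embedded import RRuntimeError  # noqa: F401
    _R_AVAILABLE = True
except ImportError:
    _R_AVAILABLE = False
    warnings.warn("rpy2/R not found; R-dependent estimators are disabled.")

# R-dependent tests could then be skipped under CI-without-R, for example:
# import pytest
# pytestmark = pytest.mark.skipif(not _R_AVAILABLE, reason="requires R and rpy2")
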
4 changes: 4 additions & 0 deletions .gitignore
@@ -127,3 +127,7 @@ dmypy.json

# Pyre type checker
.pyre/

# DS_STORE files
src/.DS_Store
.DS_Store
5 changes: 3 additions & 2 deletions setup.py
@@ -17,13 +17,14 @@
),
package_dir={"": "src"},
install_requires=[
'pandas>=1.2.1',
'pandas==1.2.1',
'scikit-learn>=0.22.1',
'numpy>=1.19.2',
'rpy2>=2.9.4',
'scipy>=1.5.2',
'seaborn>=0.11.1',
'matplotlib>=3.3.2'
'matplotlib>=3.3.2',
"pytest"
],
classifiers=[
'Programming Language :: Python :: 3',
7 changes: 2 additions & 5 deletions src/med_bench/get_estimation.py
@@ -1,12 +1,8 @@
#!/usr/bin/env python
# -*- coding:utf-8 -*-


import time
import sys
from rpy2.rinterface_lib.embedded import RRuntimeError
import pandas as pd
import numpy as np

from .mediation import (
mediation_IPW,
mediation_coefficient_product,
@@ -18,6 +14,7 @@
r_mediate,
)


def get_estimation(x, t, m, y, estimator, config):
"""Wrapper estimator fonction ; calls an estimator given mediation data
in order to estimate total, direct, and indirect effects.
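
For reference, get_estimation takes the mediation data plus an estimator name and a config value, as shown in the hunk above. A minimal call sketch with placeholder inputs; the estimator label 'coefficient_product' and the config value 0 are assumptions for illustration, not values confirmed by this diff:

# Sketch with placeholder data; check get_estimation's docstring for the
# accepted estimator labels and config values.
import numpy as np
from med_bench.get_estimation import get_estimation

n = 100
x = np.random.randn(n, 5)               # covariates, shape (n, dim_x)
t = np.random.randint(2, size=(n, 1))   # binary treatment, shape (n, 1)
m = np.random.randn(n, 1)               # mediator, shape (n, dim_m)
y = np.random.randn(n, 1)               # outcome, shape (n, 1)

effects = get_estimation(x, t, m, y, 'coefficient_product', 0)
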
83 changes: 43 additions & 40 deletions src/med_bench/get_simulated_data.py
@@ -1,13 +1,7 @@
import numpy as np
from numpy.random import default_rng
from scipy import stats
import pandas as pd
from pathlib import Path
from scipy.stats import bernoulli
from scipy.special import expit
import matplotlib.pyplot as plt
import pathlib
import seaborn as sns


def simulate_data(n,
@@ -23,38 +17,38 @@ def simulate_data(n,
beta_t_factor=1,
beta_m_factor=1):
"""Simulate data for mediation analysis

Parameters
----------
n: :obj:`int`,
Number of samples to generate.

rg: RandomState instance,
Controls the pseudo random number generator used to generate the
data at fit time.

mis_spec_m: obj:`bool`,
Whether the mediator generation is misspecified or not
defaults to False

mis_spec_y: obj:`bool`,
Whether the output model is misspecified or not
defaults to False

dim_x: :obj:`int`, optional,
Number of covariates in the input.
Defaults to 1

dim_m: :obj:`int`, optional,
Number of mediatiors to generate.
Defaults to 1

seed: :obj:`int` or None, optional,
Controls the pseudo random number generator used to generate the
coefficients of the model.
Pass an int for reproducible output across multiple function calls.
Defaults to None

type_m: :obj:`str`,
Whether the mediator is binary or continuous
Defaults to 'binary',
@@ -66,26 +60,26 @@
sigma_m :obj:`float`,
noise variance on mediator
Defaults to 0.5,

beta_t_factor: :obj:`float`,
scaling factor on treatment effect,
Defaults to 1,

beta_m_factor: :obj:`float`,
scaling factor on mediator,
Defaults to 1,

returns
-------
x: ndarray of shape (n, dim_x)
the simulated covariates

t: ndarray of shape (n, 1)
the simulated treatment

m: ndarray of shape (n, dim_m)
the simulated mediators

y: ndarray of shape (n, 1)
the simulated outcome

@@ -137,9 +131,11 @@ def simulate_data(n,
m = m_2d[np.arange(n), t[:, 0]].reshape(-1, 1)
else:
random_noise = sigma_m * rg.standard_normal((n, dim_m))
m0 = x.dot(beta_x) + t0.dot(beta_t) + t0 * (x.dot(beta_xt)) + random_noise
m1 = x.dot(beta_x) + t1.dot(beta_t) + t1 * (x.dot(beta_xt)) + random_noise
m = x.dot(beta_x) + t.dot(beta_t) + t * (x.dot(beta_xt)) + random_noise
m0 = x.dot(beta_x) + t0.dot(beta_t) + t0 * \
(x.dot(beta_xt)) + random_noise
m1 = x.dot(beta_x) + t1.dot(beta_t) + t1 * \
(x.dot(beta_xt)) + random_noise
m = x.dot(beta_x) + t.dot(beta_t) + t * (x.dot(beta_xt)) + random_noise

# generate the outcome Y
gamma_m = np.ones((dim_m, 1)) * 0.5 / dim_m * beta_m_factor
@@ -150,47 +146,54 @@
else:
gamma_t_m = np.zeros((dim_m, 1))

y = x.dot(gamma_x) + gamma_t * t + m.dot(gamma_m) + m.dot(gamma_t_m) * t + sigma_y * rg.standard_normal((n, 1))
y = x.dot(gamma_x) + gamma_t * t + m.dot(gamma_m) + \
m.dot(gamma_t_m) * t + sigma_y * rg.standard_normal((n, 1))

# Compute differents types of effects
if type_m == 'binary':
theta_1 = gamma_t + gamma_t_m * np.mean(p_m1)
theta_0 = gamma_t + gamma_t_m * np.mean(p_m0)
delta_1 = np.mean((p_m1 - p_m0) * (gamma_m.flatten() + gamma_t_m.dot(t1.T)))
delta_0 = np.mean((p_m1 - p_m0) * (gamma_m.flatten() + gamma_t_m.dot(t0.T)))
delta_1 = np.mean(
(p_m1 - p_m0) * (gamma_m.flatten() + gamma_t_m.dot(t1.T)))
delta_0 = np.mean(
(p_m1 - p_m0) * (gamma_m.flatten() + gamma_t_m.dot(t0.T)))
else:
theta_1 = gamma_t + gamma_t_m.T.dot(np.mean(m1, axis=0)) # to do mean(m1) pour avoir un vecteur de taille dim_m
# to do mean(m1) pour avoir un vecteur de taille dim_m
theta_1 = gamma_t + gamma_t_m.T.dot(np.mean(m1, axis=0))
theta_0 = gamma_t + gamma_t_m.T.dot(np.mean(m0, axis=0))
delta_1 = (gamma_t * t1 + m1.dot(gamma_m) + m1.dot(gamma_t_m) * t1 - (gamma_t * t1 + m0.dot(gamma_m) + m0.dot(gamma_t_m) * t1)).mean()
delta_0 = (gamma_t * t0 + m1.dot(gamma_m) + m1.dot(gamma_t_m) * t0 - (gamma_t * t0 + m0.dot(gamma_m) + m0.dot(gamma_t_m) * t0)).mean()
delta_1 = (gamma_t * t1 + m1.dot(gamma_m) + m1.dot(gamma_t_m) * t1 -
(gamma_t * t1 + m0.dot(gamma_m) + m0.dot(gamma_t_m) * t1)).mean()
delta_0 = (gamma_t * t0 + m1.dot(gamma_m) + m1.dot(gamma_t_m) * t0 -
(gamma_t * t0 + m0.dot(gamma_m) + m0.dot(gamma_t_m) * t0)).mean()

if type_m == 'binary':
pre_pm = np.hstack((p_m0.reshape(-1, 1), p_m1.reshape(-1, 1)))
pre_pm[m.ravel()==0, :] = 1 - pre_pm[m.ravel()==0, :]
pre_pm[m.ravel() == 0, :] = 1 - pre_pm[m.ravel() == 0, :]
pm = pre_pm[:, 1].reshape(-1, 1)
else:
p_m0 = np.prod(stats.norm.pdf((m - x.dot(beta_x)) - t0.dot(beta_t) - t0 * (x.dot(beta_xt)) / sigma_m), axis=1)
p_m1 = np.prod(stats.norm.pdf((m - x.dot(beta_x)) - t1.dot(beta_t) - t1 * (x.dot(beta_xt)) / sigma_m), axis=1)
p_m0 = np.prod(stats.norm.pdf((m - x.dot(beta_x)) -
t0.dot(beta_t) - t0 * (x.dot(beta_xt)) / sigma_m), axis=1)
p_m1 = np.prod(stats.norm.pdf((m - x.dot(beta_x)) -
t1.dot(beta_t) - t1 * (x.dot(beta_xt)) / sigma_m), axis=1)
pre_pm = np.hstack((p_m0.reshape(-1, 1), p_m1.reshape(-1, 1)))
pm = pre_pm[:, 1].reshape(-1, 1)


px = np.prod(stats.norm.pdf(x), axis=1)

pre_pt = np.hstack(((1-p_t).reshape(-1, 1), p_t.reshape(-1, 1)))
double_px = np.hstack((px.reshape(-1, 1), px.reshape(-1, 1)))
denom = np.sum(pre_pm * pre_pt * double_px, axis=1)
num = pm.ravel() * p_t.ravel() * px.ravel()
th_p_t_mx = num.ravel() / denom
return (x,
t,
m,
y,

return (x,
t,
m,
y,
theta_1.flatten()[0] + delta_0.flatten()[0],
theta_1.flatten()[0],
theta_0.flatten()[0],
theta_1.flatten()[0],
theta_0.flatten()[0],
delta_1.flatten()[0],
delta_0.flatten()[0],
p_t,
th_p_t_mx)
delta_0.flatten()[0],
p_t,
th_p_t_mx)
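
Whitespace reformatting aside, simulate_data still returns the eleven values listed above. A usage sketch assembled from its docstring; parameter names and defaults are taken from that docstring, so treat the call as illustrative rather than verified:

# Sketch: simulate a mediation dataset and unpack the documented return tuple.
from numpy.random import default_rng
from med_bench.get_simulated_data import simulate_data

rg = default_rng(42)
(x, t, m, y,
 total, theta_1, theta_0, delta_1, delta_0,
 p_t, th_p_t_mx) = simulate_data(n=500, rg=rg, dim_x=3, dim_m=1,
                                 seed=1, type_m='binary')

print(x.shape, t.shape, m.shape, y.shape)  # (500, 3) (500, 1) (500, 1) (500, 1)
print(total, theta_1 + delta_0)            # the total effect is theta_1 + delta_0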