Corrected cross-fitting loop #42

Merged: 2 commits, Dec 4, 2023


@sami6mz (Contributor) commented Aug 29, 2023

The previous cross-fitting loop used an array of NumPy arrays, and we suspected this was slowing down the computation, since multiply_robust_efficient is faster than med_dml. This commit optimizes the loop.

In the end, it appears the loop wasn't slowing anything down (med_dml is slower because it calls forest instances more often than multiply_robust).
Also, the new loop seems to give slightly more accurate estimates.
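
For illustration, a minimal sketch of a cross-fitting loop that fills one preallocated array fold by fold rather than building an array of arrays. This is not the code from this PR: the function name `cross_fit_predictions` is hypothetical, and an ordinary least-squares fit stands in for the forests that med_dml uses.

```python
import numpy as np


def cross_fit_predictions(x, y, n_splits=2, seed=0):
    """Out-of-fold predictions of E[y | x].

    Hypothetical helper: a least-squares fit is used as a stand-in
    nuisance model, purely to keep the sketch dependency-free.
    """
    n = x.shape[0]
    rng = np.random.default_rng(seed)
    folds = np.array_split(rng.permutation(n), n_splits)

    # One preallocated output array, written into by index.
    mu_hat = np.empty(n)
    design = np.column_stack([np.ones(n), x])

    for test_idx in folds:
        # Train on every sample that is not in the current fold.
        train_mask = np.ones(n, dtype=bool)
        train_mask[test_idx] = False
        coef, *_ = np.linalg.lstsq(
            design[train_mask], y[train_mask], rcond=None
        )
        # Predict only on the held-out fold.
        mu_hat[test_idx] = design[test_idx] @ coef

    return mu_hat
```

Each sample is predicted by a model that never saw it during fitting, which is the point of cross-fitting; writing by index keeps the predictions aligned with the original sample order.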

@bthirion (Collaborator) left a comment

LGTM overall, thx !

src/benchmark_mediation.py
sumscore3 = np.mean(t * (1 - ptmx) / (ptmx * (1 - ptx)))
sumscore4 = np.mean((1 - t) * ptmx / ((1 - ptmx) * ptx))
y1m1 = (t / ptx * (y - mu_t1_x)) / sumscore1 + mu_t1_x
y0m0 = ((1 - t) / (1 - ptx) * (y - mu_t0_x)) / sumscore2 + mu_t0_x
y1m0 = (
Collaborator

Are there tests that check whether this snippet computes the right thing?
Such computation blocks would be better isolated into small functions, btw.
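
As a sketch of what such an isolated function could look like, the two complete lines of the snippet above (`y1m1` and `y0m0`) are wrapped below. The function name `normalized_outcome_scores` is hypothetical, and the definitions of `sumscore1` and `sumscore2` are an assumption: they do not appear in the excerpt, so they are taken here to be the mean inverse-propensity weights, by analogy with `sumscore3` and `sumscore4`.

```python
import numpy as np


def normalized_outcome_scores(y, t, ptx, mu_t1_x, mu_t0_x):
    """Per-sample scores for E[Y(1, M(1))] and E[Y(0, M(0))].

    Assumption (not shown in the reviewed excerpt): sumscore1 and
    sumscore2 are the sample means of the inverse-propensity weights.
    """
    sumscore1 = np.mean(t / ptx)
    sumscore2 = np.mean((1 - t) / (1 - ptx))

    # These two lines are taken from the reviewed snippet.
    y1m1 = (t / ptx * (y - mu_t1_x)) / sumscore1 + mu_t1_x
    y0m0 = ((1 - t) / (1 - ptx) * (y - mu_t0_x)) / sumscore2 + mu_t0_x
    return y1m1, y0m0
```

A function of this size takes plain arrays and returns plain arrays, so it can be tested without running the full benchmark.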

@sami6mz (Contributor, Author) Aug 30, 2023

There aren't any tests; I just made sure it gives similar performance with and without normalization on several datasets. I can file that as an issue: #44

Collaborator

If we want this lib to survive, we need to add a test suite as the highest priority.
See #45

@bthirion (Collaborator) left a comment

lgtm

@bthirion (Collaborator) left a comment

Sorry, let me get back to testing. There should be a unit test that checks that when normalized=True, the numerical result is fine.
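
One possible shape for such a test, as a sketch only: with normalized weights, the weights average to exactly one, so a constant gap between the outcome and the regression prediction should be recovered exactly in the mean score. The helper `y1m1_scores` and its `normalized` flag are hypothetical stand-ins modeled on the `y1m1` line of the reviewed snippet, not the library's actual API.

```python
import numpy as np


def y1m1_scores(y, t, ptx, mu_t1_x, normalized=True):
    """Hypothetical stand-in for the y1m1 computation under review."""
    weights = t / ptx
    if normalized:
        # Rescale so the weights average to one.
        weights = weights / np.mean(weights)
    return weights * (y - mu_t1_x) + mu_t1_x


def test_normalized_scores_absorb_constant_shift():
    rng = np.random.default_rng(0)
    n = 200
    ptx = rng.uniform(0.2, 0.8, size=n)
    t = rng.binomial(1, ptx)
    mu = rng.normal(size=n)
    shift = 3.0
    y = mu + shift  # outcome differs from the prediction by a constant

    scores = y1m1_scores(y, t, ptx, mu, normalized=True)

    # mean(weights) == 1, so the mean score is mean(mu) + shift.
    np.testing.assert_allclose(scores.mean(), mu.mean() + shift)
```

The identity holds for any propensities and any treatment draw with at least one treated unit, so the test does not depend on a particular random seed.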

@judithabk6 (Owner)

ok. So to recap

  • the first commit corrects the way the cross-fitting is done (a pure rewrite that yields the same results; this should be made explicit with a test, obviously)
  • the second commit is just documentation

As for #43, only the code in get_estimation is tested, so the package is not properly equipped to check specifically that this PR does not break the code.
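
A regression check of the kind mentioned above could be sketched as follows: run a reference implementation and its rewrite on the same folds and assert that the outputs agree. Both functions are hypothetical toy loops (the nuisance "model" is just the training-fold mean), not the loops from this PR.

```python
import numpy as np


def crossfit_list_of_arrays(y, folds):
    """Reference style: collect per-fold arrays, then stitch them back."""
    per_fold = []
    for test_idx in folds:
        train_idx = np.setdiff1d(np.arange(y.size), test_idx)
        per_fold.append(np.full(test_idx.size, y[train_idx].mean()))
    out = np.empty(y.size)
    out[np.concatenate(folds)] = np.concatenate(per_fold)
    return out


def crossfit_preallocated(y, folds):
    """Rewritten style: fill one preallocated array by index."""
    out = np.empty(y.size)
    for test_idx in folds:
        train_mask = np.ones(y.size, dtype=bool)
        train_mask[test_idx] = False
        out[test_idx] = y[train_mask].mean()
    return out


def test_rewrite_matches_reference():
    rng = np.random.default_rng(42)
    y = rng.normal(size=30)
    folds = np.array_split(rng.permutation(30), 3)
    np.testing.assert_allclose(
        crossfit_list_of_arrays(y, folds),
        crossfit_preallocated(y, folds),
    )
```

Fixing the folds and the seed in the test keeps the comparison deterministic, so any difference between the two outputs points at the loop logic rather than at sampling noise.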

Should we merge this as it is, and implement a more specific test suite? @bthirion @houssamzenati

@judithabk6 judithabk6 merged commit 5211eae into judithabk6:main Dec 4, 2023

3 participants