-
Notifications
You must be signed in to change notification settings - Fork 3
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Corrected cross-fitting loop #42
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM overall, thx !
sumscore3 = np.mean(t * (1 - ptmx) / (ptmx * (1 - ptx))) | ||
sumscore4 = np.mean((1 - t) * ptmx / ((1 - ptmx) * ptx)) | ||
y1m1 = (t / ptx * (y - mu_t1_x)) / sumscore1 + mu_t1_x | ||
y0m0 = ((1 - t) / (1 - ptx) * (y - mu_t0_x)) / sumscore2 + mu_t0_x | ||
y1m0 = ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Are there tests that check whether this snippet computes the right thing ?
Such computation blocks should better be isolated into small functions bzw.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There aren't any test, I just made sure it gives similar performances with and without normalization on several dataset. I can declared that as an issue, #44
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If we want this lib to survive, we need to add a test suite in highest priority.
See #45
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, let me get back to testing. There should be a unit test that checks that when normalized=True
, the numerical result is fine.
ok. So to recap
As for #43 only the code in get_estimation is tested, so the package is not properly equipped to check specifically that this PR does not break the code. Should we merge this as it is, and implement a more specific test suite? @bthirion @houssamzenati |
Previous cross-fitting loop was using an array of numpy array, and we suspected it to slow down the computation time since
multiply_robust_efficient
is faster thanmed_dml
. This commit optimizes the loop.In the end it appears the loop wasn't slowing down anything (
med_dml
is slower because it calls forest instances more often thanmultiply_robust
).Also the new loop seems to give slightly more accurate estimation.