Rewrite batched dots that do not reduce as multiplication #1178
Conversation
Codecov Report

❌ Attention: Patch coverage is 85.71%, which is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:

```diff
@@           Coverage Diff           @@
##             main    #1178   +/-   ##
=======================================
  Coverage   82.27%   82.27%
=======================================
  Files         186      186
  Lines       48009    48053      +44
  Branches     8624     8633       +9
=======================================
+ Hits        39499    39538      +39
- Misses       6353     6356       +3
- Partials     2157     2159       +2
```
Looks nice. Added some nitpicks, nothing major. Have you benchmarked this?
```python
return vectorize_node_fallback(op, node, batched_x, batched_y)
...
old_x_ndim = old_x.type.ndim
old_y_ndim = old_y.type.ndim
match (old_x_ndim, old_y_ndim):
```
it's match week here at pymc-devs!
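For context, here is a minimal NumPy sketch of the idea this `match` dispatches on — not the PR's actual code: when the contracted dimension of a dot has static length 1, the "reduction" sums a single element, so the dot is equivalent to a broadcasted multiplication.

```python
import numpy as np

def dot_without_reduction(x, y):
    # Hypothetical helper, not the PR's implementation: a core dot whose
    # contracted dimension has length 1 reduces over a single element,
    # so it equals a broadcasted multiply (plus shape bookkeeping).
    match (x.ndim, y.ndim):
        case (1, 1):   # (1,) @ (1,) -> scalar
            return (x * y).squeeze()
        case (2, 1):   # (m, 1) @ (1,) -> (m,)
            return (x * y).squeeze(-1)
        case (1, 2):   # (1,) @ (1, n) -> (n,)
            return (x * y).squeeze(0)
        case (2, 2):   # (m, 1) @ (1, n) -> (m, n), the outer product
            return x * y
        case _:
            raise NotImplementedError("core Dot inputs are at most 2D")

x = np.ones((3, 1))
y = np.ones((1, 4))
assert np.allclose(dot_without_reduction(x, y), x @ y)
```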
```python
out = alloc(x_v, 5, 3)
f = pytensor.function([x_v], out, mode=mode)
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
```
Why did you refactor `check_alloc_runtime_broadcast` out of this `TestAlloc` class but not this one?
I refactored the other one out because trying to use the method from another module led pytest to execute all the tests of the class. The class calling a method on itself should be fine, which I assume is what's happening here?

I'm sure it's faster even without benchmarking, because Blockwise will be a Python loop over the batch dims. Even the core case without Blockwise was often faster than going through dot -> GER (except n=512, haha), but I don't want to mess with the rest of the BLAS pipeline, as I mentioned.
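A rough NumPy stand-in for that claim (hypothetical shapes; the Python loop plays the role of Blockwise iterating over the batch dimension):

```python
import timeit
import numpy as np

rng = np.random.default_rng(0)
x = rng.standard_normal((256, 32, 1))   # batched column vectors
y = rng.standard_normal((256, 1, 32))   # batched row vectors

def loop_dot():
    # One core dot per batch entry, like a Blockwise fallback would do.
    return np.stack([xi @ yi for xi, yi in zip(x, y)])

def broadcast_mul():
    # The rewritten form: a single broadcasted multiplication.
    return x * y

assert np.allclose(loop_dot(), broadcast_mul())
print("loop:", timeit.timeit(loop_dot, number=100))
print("mul :", timeit.timeit(broadcast_mul, number=100))
```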
Dot is basically a fused broadcasted multiplication and reduction. In cases that correspond to just a multiplication there is no advantage to using Dot, especially in vectorized graphs where we'll have Blockwised versions of the Dot.
This PR adds a rewrite for these cases.
Initially I was rewriting the non-Blockwise versions as well, but those interfere with the BLAS pipeline. Testing locally, BLAS for the outer product is not always faster (it depends on the size), but it is generally faster when doing the rank-1 matrix update, so I don't want to mess with that for now.
I think we should introduce BLAS in those cases only when doing an update, not when doing the outer product alone. The rewrite is easy to toggle to also address the non-Blockwise version.
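As a concrete illustration of the kind of graph the rewrite targets (a hedged sketch; the exact rewrite name and trigger conditions live in the PR and are not asserted here): a batched outer product written as a Blockwise matmul should compile down to a plain Mul.

```python
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")  # shape (batch, n): a batch of vectors
y = pt.matrix("y")  # shape (batch, m)

# Batched outer product expressed as (batch, n, 1) @ (batch, 1, m);
# no real reduction happens, since the contracted dim has length 1.
outer = x[:, :, None] @ y[:, None, :]

f = pytensor.function([x, y], outer)
# After this PR's rewrite, the compiled graph should contain a
# broadcasted Mul rather than a Blockwise(Dot)/matmul node.
pytensor.dprint(f)
```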
📚 Documentation preview 📚: https://pytensor--1178.org.readthedocs.build/en/1178/