Rewrite batched dots that do not reduce as multiplication #1178

Merged: 3 commits into pymc-devs:main from dot_as_mul on Jan 28, 2025

Conversation

@ricardoV94 (Member) commented on Jan 27, 2025

Dot is basically a fused broadcasted multiplication and reduction. In cases that correspond to just a multiplication there is no advantage to using Dot, especially in vectorized graphs where we'll have Blockwise versions of the Dot.

This PR adds a rewrite for these cases.
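
To make the idea concrete, here is a minimal NumPy sketch (the shapes are my own illustration, not from the PR): when the contracted dimension has length 1, the dot reduces over a single element, so it is exactly a broadcasted multiply.

```python
import numpy as np

# Batched "dot" whose contracted axis has length 1: nothing is actually
# summed, so matmul and broadcasted multiplication agree elementwise.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 5, 1))  # batch of column vectors
y = rng.normal(size=(4, 1, 3))  # batch of row vectors

np.testing.assert_allclose(np.matmul(x, y), x * y)
```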

Initially I was rewriting the non-Blockwise versions as well, but those interfere with the BLAS pipeline. Testing locally, BLAS for the outer product is not always faster (it depends on the size), but it is generally faster for the rank-1 matrix update, so I don't want to mess with that for now.

I think we should introduce BLAS in those cases only when doing an update, not when doing the outer product alone. The rewrite is easy to toggle so the non-Blockwise version can be addressed later.
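
For context, a rough sketch of the distinction being drawn (sizes assumed by me, not the PR's benchmark): the outer product alone is just a broadcasted multiply, while the fused rank-1 update is the pattern BLAS GER targets.

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.normal(size=(512, 512))
x = rng.normal(size=(512, 1))
y = rng.normal(size=(1, 512))

outer_mul = x * y    # outer product via broadcasting; no BLAS needed
updated = A + x @ y  # fused rank-1 update; this is the GER pattern

np.testing.assert_allclose(updated, A + outer_mul)
```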


📚 Documentation preview 📚: https://pytensor--1178.org.readthedocs.build/en/1178/

@ricardoV94 force-pushed the dot_as_mul branch 4 times, most recently from e0cb086 to 6b7dcd4 on January 27, 2025 at 11:27
@ricardoV94 force-pushed the dot_as_mul branch 2 times, most recently from 9a7bad2 to 11af25a on January 27, 2025 at 14:06
@ricardoV94 changed the title from "Rewrite dots as multiplication without summation" to "Rewrite batched dots that do not reduce as multiplication" on Jan 27, 2025
@ricardoV94 added the linalg (Linear algebra) label on Jan 27, 2025
codecov bot commented on Jan 27, 2025

Codecov Report

Attention: Patch coverage is 85.71429% with 7 lines in your changes missing coverage. Please review.

Project coverage is 82.27%. Comparing base (b065112) to head (8e08d2f).
Report is 6 commits behind head on main.

Files with missing lines            Patch %   Lines
pytensor/tensor/rewriting/math.py   86.66%    2 Missing and 2 partials ⚠️
pytensor/tensor/math.py             84.21%    2 Missing and 1 partial ⚠️

❌ Your patch status has failed because the patch coverage (85.71%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files:

@@           Coverage Diff           @@
##             main    #1178   +/-   ##
=======================================
  Coverage   82.27%   82.27%           
=======================================
  Files         186      186           
  Lines       48009    48053   +44     
  Branches     8624     8633    +9     
=======================================
+ Hits        39499    39538   +39     
- Misses       6353     6356    +3     
- Partials     2157     2159    +2     
Files with missing lines            Coverage Δ
pytensor/tensor/math.py             91.87% <84.21%> (+0.01%) ⬆️
pytensor/tensor/rewriting/math.py   88.93% <86.66%> (-0.05%) ⬇️

@jessegrabowski (Member) left a comment:
Looks nice. Added some nitpicks, nothing major. Have you benchmarked this?

return vectorize_node_fallback(op, node, batched_x, batched_y)
old_x_ndim = old_x.type.ndim
old_y_ndim = old_y.type.ndim
match (old_x_ndim, old_y_ndim):
@jessegrabowski (Member), on the match statement above:

it's match week here at pymc-devs!
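
For readers unfamiliar with the construct, a hypothetical sketch (not the PR's code) of this dispatch style, matching on the core ranks of the two dot operands:

```python
def describe_dot_case(x_ndim: int, y_ndim: int) -> str:
    # Structural pattern matching on the pair of operand ranks.
    match (x_ndim, y_ndim):
        case (1, 1):
            return "inner product"
        case (2, 1):
            return "matrix-vector product"
        case (1, 2):
            return "vector-matrix product"
        case (2, 2):
            return "matrix-matrix product"
        case _:
            raise ValueError("Dot core inputs must be 1d or 2d")
```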

Two review threads on pytensor/tensor/rewriting/math.py (outdated, resolved).

out = alloc(x_v, 5, 3)
f = pytensor.function([x_v], out, mode=mode)
TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
@jessegrabowski (Member), on the test above:

Why did you refactor check_alloc_runtime_broadcast out of this TestAlloc class but not this one?

@ricardoV94 (Member, Author):

I refactored the other one out because trying to use the method from another module led pytest to execute all the tests of the class. A class calling a method on itself should be fine, which I assume is what's happening here?
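
A hedged illustration of the pytest behavior being described (module layout and names invented for the example): pytest collects any class named Test* that appears at module level in a test module, including imported ones, so importing TestAlloc just to reuse a helper would re-run all of its tests.

```python
# tests/test_alloc.py (simplified)
class TestAlloc:
    @staticmethod
    def check_allocs_in_fgraph(fgraph, n):
        """Helper shared by several tests; not a test itself."""

    def test_runtime_broadcast(self):
        """An actual test; runs wherever TestAlloc is collected."""

# tests/test_other.py
# from tests.test_alloc import TestAlloc  # pytest would now collect and
#                                         # run TestAlloc's tests here too,
#                                         # which is why the helper was
#                                         # refactored out instead
```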

@ricardoV94 (Member, Author), replying to "Have you benchmarked this?":

I'm sure it's faster even without benchmarking, because Blockwise will be a Python loop over the batch dims. Even the core case without Blockwise was often faster than with dot -> GER (except n=512, haha), but I don't want to mess with the rest of the BLAS pipeline, as I mentioned.
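
A quick way to see what the rewrite targets, as a hedged sketch (shapes assumed; whether the rewrite fires depends on the compilation mode): build a batched outer product and inspect the compiled graph.

```python
import pytensor
import pytensor.tensor as pt

# Batched outer product: the contracted core dimension has length 1.
x = pt.tensor("x", shape=(None, 5, 1))
y = pt.tensor("y", shape=(None, 1, 3))

f = pytensor.function([x, y], x @ y)
# With this rewrite applied, the printed graph should show a Mul
# (broadcasted multiplication) rather than Blockwise(Dot).
pytensor.dprint(f)
```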

@ricardoV94 merged commit 911c6a3 into pymc-devs:main on Jan 28, 2025
63 of 64 checks passed