PyTorch inline constants in dispatch to avoid graph breaks #1118
base: main
Conversation
Still need to do something about the runtime broadcast check in Elemwise. Can we use
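For context, here is a rough sketch (an assumption, not the actual PyTensor dispatch code) of the kind of runtime broadcast check under discussion. The eager Python-level shape comparison is exactly the sort of call-time control flow that `torch.compile` has to guard on or break the graph around:

```python
import torch

# Hypothetical stand-in for the Elemwise runtime broadcast check, not the
# real dispatch code. The shape comparison happens in Python at call time,
# so Dynamo must install guards on it (or give up and break the graph).
def elemwise_add_with_check(x, y):
    # reject silent runtime broadcasting: require exactly matching shapes
    if x.shape != y.shape:
        raise ValueError("runtime broadcasting detected")
    return x + y
```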
Codecov Report

❌ Attention: Your patch check has failed because the patch coverage (68.88%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

```
@@            Coverage Diff             @@
##             main    #1118      +/-   ##
==========================================
- Coverage   82.27%   82.25%   -0.02%
==========================================
  Files         186      186
  Lines       48000    48037      +37
  Branches     8621     8625       +4
==========================================
+ Hits        39490    39512      +22
- Misses       6353     6366      +13
- Partials     2157     2159       +2
```
Even without the runtime broadcast check, Elemwise seems to break the graph.
Force-pushed from c08d288 to 566145a
Did you get a chance to profile this PR?
Btw I did profile this. My machine actually failed to even compile dlogp for a model, but I suspect that's unrelated. The logp method did show some improvement. The thing that intrigued me is that this change reduced the number of guards by a lot (it was 10:1 versus the other ones). I thought that maybe that was the cause of the runtime switch, but it didn't have the payoff I was expecting.
The cost of the guards may be non-linear, so we should try to remove all of them.
Idk about removing all, since guards are the primitive that ensures runtime correctness. Significantly reduce, I agree.
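If it helps, guard and graph-break counts can be inspected directly. A minimal sketch using `torch._dynamo.explain` (the toy `fn` and its inputs are placeholders, not code from this PR):

```python
import torch

def fn(x, y):
    return (x + y).sum()

# explain() traces the function with Dynamo and reports the number of
# captured graphs, graph breaks, break reasons, and installed guards.
explanation = torch._dynamo.explain(fn)(torch.ones(3), torch.ones(3))
print(explanation)
```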
Btw, for the actual perf benefit, these are the numbers I see:

```
# ricardo shape:    772 μs ± 12 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
# no ricardo shape: 818 μs ± 9.48 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
```

So it's like ~5%, probably more on slower CPUs. The graph breaks are definitely a problem :(
If we add these two flags with the changes in this PR:

we come down to almost 500 μs.
@Ch0ronomato can you revert the removal of the Elemwise bcast check (for now) and add those flags? Then we can merge this PR and keep playing with stuff.
The CI doesn't like those flags. I'll investigate.
I think the path to fix this is to not use those flags by default, but only when we have a shape operation (see the sketch below). The torch compiler might be really restrictive.
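A hypothetical sketch of that idea, flipping the Dynamo flag only when the FunctionGraph actually contains shape operations. The `graph_has_shape_ops` helper and the op list are assumptions for illustration, not existing PyTensor API:

```python
import torch
from pytensor.tensor.shape import Reshape, Shape

def graph_has_shape_ops(fgraph):
    # hypothetical helper: scan the graph for shape-related ops
    return any(
        isinstance(node.op, (Shape, Reshape)) for node in fgraph.apply_nodes
    )

def jit_compile(fgraph, fn):
    if graph_has_shape_ops(fgraph):
        # only relax Dynamo for graphs that actually need it
        torch._dynamo.config.capture_dynamic_output_shape_ops = True
    return torch.compile(fn)
```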
Force-pushed from c5f26fd to dbc95e4
```diff
@@ -37,6 +37,9 @@ def conversion_func_register(*args, **kwargs):
     def jit_compile(self, fn):
         import torch

+        # flag that tends to help our graphs
+        torch._dynamo.config.capture_dynamic_output_shape_ops = True
```
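To illustrate what this flag buys us (a hedged example, not from the PR): ops with data-dependent output shapes, such as `torch.nonzero`, normally force a graph break under Dynamo, but can be captured when the flag is set:

```python
import torch

# With the default capture_dynamic_output_shape_ops=False, compiling this
# with fullgraph=True fails, because nonzero's output shape depends on the
# data. With the flag enabled, the op can be captured in the graph.
torch._dynamo.config.capture_dynamic_output_shape_ops = True

@torch.compile(fullgraph=True)
def nonzero_indices(x):
    return torch.nonzero(x)

print(nonzero_indices(torch.tensor([0.0, 1.0, 0.0, 2.0])))
```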
Hopefully when #1159 gets merged we can just delete this flag altogether, since torch will know these aren't dynamic.
When we have static inputs, inlining helps torch avoid breaking the graph.
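As a toy illustration of the inlining idea (assumed for exposition, not this PR's actual code): baking a compile-time constant into the traced function lets Dynamo specialize on it once, instead of guarding on an extra runtime argument every call:

```python
import torch

target_shape = (3, 4)  # constant known at dispatch time

def reshape_inlined(x):
    # the constant tuple is inlined into the traced code, so Dynamo treats
    # it as static rather than installing guards on a runtime input
    return x.reshape(target_shape)

compiled = torch.compile(reshape_inlined)
print(compiled(torch.arange(12.0)).shape)  # torch.Size([3, 4])
```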
Related to #1110
📚 Documentation preview 📚: https://pytensor--1118.org.readthedocs.build/en/1118/