
PyTorch inline constants in dispatch to avoid graph breaks #1118

Open · wants to merge 7 commits into main from torch_constant_dispatch
Conversation

@ricardoV94 (Member) commented Dec 12, 2024

When we have static inputs, inlining them helps torch avoid breaking the graph.

Related to #1110


📚 Documentation preview 📚: https://pytensor--1118.org.readthedocs.build/en/1118/
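For illustration, here is a minimal sketch of the kind of constant inlining described above; the function name and signature are hypothetical, not the PR's actual dispatch code.

import numpy as np
import torch


def funcify_op_with_static_input(static_value):
    # Hypothetical dispatch helper: the static input's data is converted once,
    # at dispatch time, instead of being passed in at every call.
    const = torch.as_tensor(np.asarray(static_value))

    def perform(x):
        # torch.compile now sees `const` as a baked-in constant rather than a
        # runtime argument it has to guard on.
        return x + const

    return perform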

@ricardoV94 ricardoV94 added performance torch PyTorch backend labels Dec 12, 2024
@ricardoV94 (Member, Author)
Still need to do something about the runtime broadcast check in Elemwise. Can we use torch._check for that instead of Python loops/asserts?
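As a point of reference, a torch._check-based version of such a check might look roughly like this; the helper name and the simplified rule (all input dims must match) are assumptions, not PyTensor's actual Elemwise logic.

import torch


def check_no_runtime_broadcast(*inputs):
    # Simplified illustration: require matching dimension sizes so no silent
    # runtime broadcasting happens. torch._check expresses the assertion in a
    # form Dynamo can trace, instead of a plain Python assert.
    for dims in zip(*(t.shape for t in inputs)):
        target = max(dims)
        for dim in dims:
            torch._check(
                dim == target,
                lambda: "Runtime broadcasting between Elemwise inputs is not "
                "allowed; broadcast the inputs explicitly first.",
            )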

codecov bot commented Dec 12, 2024

Codecov Report

Attention: Patch coverage is 68.88889% with 14 lines in your changes missing coverage. Please review.

Project coverage is 82.25%. Comparing base (4ea4259) to head (eb3ff29).
Report is 1 commit behind head on main.

Files with missing lines                      Patch %   Lines
pytensor/link/pytorch/dispatch/basic.py       52.17%    9 Missing and 2 partials ⚠️
pytensor/link/pytorch/dispatch/scalar.py      66.66%    1 Missing ⚠️
pytensor/link/pytorch/dispatch/shape.py       90.00%    1 Missing ⚠️
pytensor/link/pytorch/dispatch/subtensor.py   87.50%    1 Missing ⚠️

❌ Your patch check has failed because the patch coverage (68.88%) is below the target coverage (100.00%). You can increase the patch coverage or adjust the target coverage.

Additional details and impacted files


@@            Coverage Diff             @@
##             main    #1118      +/-   ##
==========================================
- Coverage   82.27%   82.25%   -0.02%     
==========================================
  Files         186      186              
  Lines       48000    48037      +37     
  Branches     8621     8625       +4     
==========================================
+ Hits        39490    39512      +22     
- Misses       6353     6366      +13     
- Partials     2157     2159       +2     
Files with missing lines                      Coverage Δ
pytensor/link/pytorch/linker.py               100.00% <100.00%> (ø)
pytensor/link/pytorch/dispatch/scalar.py      73.68% <66.66%> (-0.39%) ⬇️
pytensor/link/pytorch/dispatch/shape.py       85.71% <90.00%> (ø)
pytensor/link/pytorch/dispatch/subtensor.py   89.53% <87.50%> (-0.21%) ⬇️
pytensor/link/pytorch/dispatch/basic.py       87.40% <52.17%> (-7.10%) ⬇️

... and 1 file with indirect coverage changes

@ricardoV94 (Member, Author)

Even without the runtime broadcast check, elemwise seems to break the graph

@ricardoV94 ricardoV94 force-pushed the torch_constant_dispatch branch from c08d288 to 566145a Compare December 12, 2024 10:48
@Ch0ronomato (Contributor)

Did you get a chance to profile this PR?

@Ch0ronomato (Contributor)

Btw, I did profile this. My machine actually failed to even compile dlogp for a model, but I suspect that's unrelated. The logp method did show some improvement. What intrigued me is that this change reduced the number of guards by a lot (roughly 10:1 compared with the other branches). I thought that might be the cause of the runtime difference, but it didn't have the payoff I was expecting.
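For anyone reproducing the guard counts: assuming a recent PyTorch (2.1+), Dynamo's explain helper reports graphs, graph breaks, and guards for a callable; the toy function below is only an illustration.

import torch


def f(x):
    return (x + 1).sum()


# torch._dynamo.explain(fn) returns a wrapper; calling it with example inputs
# produces an ExplainOutput that includes graph break reasons and the guards.
explanation = torch._dynamo.explain(f)(torch.ones(3))
print(explanation)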

@ricardoV94 (Member, Author)

The cost of the guards may be non-linear, so we should try to remove them all.

@Ch0ronomato (Contributor)

> The cost of the guards may be non-linear, so we should try to remove them all.

Idk about removing all of them, since guards are the primitive that ensures runtime correctness. Significantly reducing them, I agree.

@Ch0ronomato (Contributor)

Btw, for the actual perf benefit, these are the numbers I see.

# ricardo shape: 772 μs ± 12 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)
# no ricardo shape: 818 μs ± 9.48 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)

So it's like ~5%, probably more on slower CPUs. The graph breaks are definitely a problem :(
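For context, numbers like the above typically come from an IPython %timeit on a function compiled with the PyTorch backend; the graph and shapes below are made up, not the benchmark actually used.

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
out = pt.exp(x).sum(axis=1)

fn = pytensor.function([x], out, mode="PYTORCH")
x_val = np.random.normal(size=(256, 256)).astype(x.dtype)

fn(x_val)  # warm-up call so torch.compile tracing is excluded from the timing
# In IPython: %timeit fn(x_val)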

@Ch0ronomato (Contributor) commented Dec 27, 2024

If we add these two flags with the changes in this PR:

torch._dynamo.config.capture_func_transforms = True
torch._dynamo.config.capture_scalar_outputs = True

we come down to almost 500 μs.

504 μs ± 12.2 μs per loop (mean ± std. dev. of 100 runs, 100 loops each)

@ricardoV94 (Member, Author)

@Ch0ronomato can you revert the removal of the Elemwise bcast check (for now), and add those flags? Then we can merge this PR and keep playing with stuff

@Ch0ronomato (Contributor)

The CI doesn't like those flags. I'll investigate.

@Ch0ronomato (Contributor)

I think the path to fixing this is to not use those flags by default, but only when we have a shape operation (rough sketch below). The torch compiler might be really restrictive.
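A rough sketch of that idea; the helper name and the set of ops checked are assumptions, not what was implemented here.

import torch
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape

SHAPE_OPS = (Reshape, Shape, Shape_i, SpecifyShape)


def configure_dynamo_for(fgraph):
    # Only flip the permissive Dynamo flags when the FunctionGraph actually
    # contains shape-related ops, so simple elemwise graphs keep the defaults.
    if any(isinstance(node.op, SHAPE_OPS) for node in fgraph.apply_nodes):
        torch._dynamo.config.capture_scalar_outputs = True
        torch._dynamo.config.capture_func_transforms = True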

@Ch0ronomato Ch0ronomato force-pushed the torch_constant_dispatch branch from c5f26fd to dbc95e4 Compare January 26, 2025 17:46
@@ -37,6 +37,9 @@ def conversion_func_register(*args, **kwargs):
def jit_compile(self, fn):
import torch

# flag that tends to help our graphs
torch._dynamo.config.capture_dynamic_output_shape_ops = True
Contributor:
Hopefully when #1159 gets merged we can just delete this flag altogether since torch will know these aren't dynamic

@Ch0ronomato Ch0ronomato marked this pull request as ready for review January 26, 2025 20:20