-
Notifications
You must be signed in to change notification settings - Fork 116
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
PyTorch inline constants in dispatch to avoid graph breaks #1118
Open
ricardoV94
wants to merge
7
commits into
pymc-devs:main
Choose a base branch
from
ricardoV94:torch_constant_dispatch
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
7 commits
Select commit
Hold shift + click to select a range
15163f2
Split and inverse
ricardoV94 45bc12b
PyTorch inline constants in dispatch to avoid graph breaks
ricardoV94 cb43d9a
Don't check runtime broadcast
ricardoV94 b627684
Readd bcast check
7b4c426
Add compiler flags to help hint more to torch
dbc95e4
Wrong flags
Ch0ronomato eb3ff29
Add split test
Ch0ronomato File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,7 +12,7 @@ | |
from pytensor.compile.sharedvalue import SharedVariable, shared | ||
from pytensor.configdefaults import config | ||
from pytensor.graph import RewriteDatabaseQuery | ||
from pytensor.graph.basic import Apply | ||
from pytensor.graph.basic import Apply, Constant | ||
from pytensor.graph.fg import FunctionGraph | ||
from pytensor.graph.op import Op | ||
from pytensor.ifelse import ifelse | ||
|
@@ -38,6 +38,11 @@ | |
py_mode = Mode(linker="py", optimizer=None) | ||
|
||
|
||
def set_test_value(x, v): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Not test value stuff, we're trying to get rid of it! |
||
x.tag.test_value = v | ||
return x | ||
|
||
|
||
def compare_pytorch_and_py( | ||
fgraph: FunctionGraph, | ||
test_inputs: Iterable, | ||
|
@@ -471,3 +476,66 @@ def test_ScalarLoop_Elemwise_multi_carries(): | |
compare_pytorch_and_py( | ||
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6) | ||
) | ||
|
||
|
||
rng = np.random.default_rng(42849) | ||
|
||
|
||
@pytest.mark.parametrize( | ||
"n_splits, axis, values, sizes", | ||
[ | ||
( | ||
0, | ||
0, | ||
set_test_value(pt.vector(), rng.normal(size=20).astype(config.floatX)), | ||
set_test_value(pt.vector(dtype="int64"), []), | ||
), | ||
( | ||
5, | ||
0, | ||
set_test_value(pt.vector(), rng.normal(size=5).astype(config.floatX)), | ||
set_test_value( | ||
pt.vector(dtype="int64"), rng.multinomial(5, np.ones(5) / 5) | ||
), | ||
), | ||
( | ||
5, | ||
0, | ||
set_test_value(pt.vector(), rng.normal(size=10).astype(config.floatX)), | ||
set_test_value( | ||
pt.vector(dtype="int64"), rng.multinomial(10, np.ones(5) / 5) | ||
), | ||
), | ||
( | ||
5, | ||
-1, | ||
set_test_value(pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)), | ||
set_test_value( | ||
pt.vector(dtype="int64"), rng.multinomial(7, np.ones(5) / 5) | ||
), | ||
), | ||
( | ||
5, | ||
-2, | ||
set_test_value(pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)), | ||
set_test_value( | ||
pt.vector(dtype="int64"), rng.multinomial(11, np.ones(5) / 5) | ||
), | ||
), | ||
], | ||
) | ||
def test_Split(n_splits, axis, values, sizes): | ||
g = pt.split(values, sizes, n_splits, axis=axis) | ||
assert len(g) == n_splits | ||
if n_splits == 0: | ||
return | ||
g_fg = FunctionGraph(outputs=[g] if n_splits == 1 else g) | ||
|
||
compare_pytorch_and_py( | ||
g_fg, | ||
[ | ||
i.tag.test_value | ||
for i in g_fg.inputs | ||
if not isinstance(i, SharedVariable | Constant) | ||
], | ||
) |
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hopefully when #1159 gets merged we can just delete this flag altogether since torch will know these aren't dynamic