PyTorch inline constants in dispatch to avoid graph breaks #1118

Open · wants to merge 7 commits into base: main
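As a hedged, standalone sketch of the pattern these changes apply (plain `torch.compile` and a made-up `make_join` helper, not PyTensor's actual dispatch machinery): when a graph input such as the concatenation axis is a compile-time `Constant`, its Python value is closed over at dispatch time, so the compiled function sees a literal instead of a runtime tensor that would otherwise need an `.item()` call.

```python
import torch


# Hedged sketch of the dispatch pattern in this PR; make_join is a made-up
# helper name, not part of PyTensor.
def make_join(constant_axis: int):
    def join_constant_axis(*tensors):
        # dim is a plain Python int baked into the closure, so the traced
        # code never has to convert a tensor into an int
        return torch.cat(tensors, dim=constant_axis)

    return join_constant_axis


# fullgraph=True would raise if anything forced a graph break
compiled_join = torch.compile(make_join(constant_axis=0), fullgraph=True)
print(compiled_join(torch.ones(2), torch.zeros(3)))
```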
44 changes: 37 additions & 7 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -8,6 +8,7 @@
from pytensor.compile import PYTORCH
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
from pytensor.link.utils import fgraph_to_python
@@ -19,6 +20,7 @@
Eye,
Join,
MakeVector,
Split,
TensorFromScalar,
)

@@ -120,14 +122,23 @@


@pytorch_funcify.register(Join)
-def pytorch_funcify_Join(op, **kwargs):
-    def join(axis, *tensors):
-        # tensors could also be tuples, and in this case they don't have a ndim
-        tensors = [torch.tensor(tensor) for tensor in tensors]
-        return torch.cat(tensors, dim=axis)
-    return join
+def pytorch_funcify_Join(op, node, **kwargs):
+    axis = node.inputs[0]
+
+    if isinstance(axis, Constant):
+        axis = int(axis.data)
+
+        def join_constant_axis(_, *tensors):
+            return torch.cat(tensors, dim=axis)
+
+        return join_constant_axis
+
+    else:
+
+        def join(axis, *tensors):
+            return torch.cat(tensors, dim=axis)
+
+        return join


@pytorch_funcify.register(Eye)
@@ -172,7 +183,6 @@
@pytorch_funcify.register(OpFromGraph)
def pytorch_funcify_OpFromGraph(op, node, **kwargs):
kwargs.pop("storage_map", None)

# Apply inner rewrites
PYTORCH.optimizer(op.fgraph)
fgraph_fn = pytorch_funcify(op.fgraph, **kwargs, squeeze_output=True)
@@ -185,3 +195,23 @@
return torch.as_tensor(x)

return tensorfromscalar


+@pytorch_funcify.register(Split)
+def pytorch_funcify_Split(op, node, **kwargs):
+    x, dim, split_sizes = node.inputs
+    if isinstance(dim, Constant) and isinstance(split_sizes, Constant):
+        dim = int(dim.data)
+        split_sizes = tuple(int(size) for size in split_sizes.data)
+
+        def split_constant_axis_and_sizes(x, *_):
+            return x.split(split_sizes, dim=dim)
+
+        return split_constant_axis_and_sizes
+
+    else:
+
+        def inner_fn(x, dim, split_amounts):
+            return x.split(split_amounts.tolist(), dim=dim.item())
+
+        return inner_fn
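A hedged way to see what the constant branch of `pytorch_funcify_Split` buys (assuming PyTorch ≥ 2.1, where `torch._dynamo.explain` returns an object with a `graph_break_count` attribute, and default dynamo settings, under which `.tolist()`/`.item()` on tensor arguments break the graph) is to count graph breaks for the two calling styles:

```python
import torch


def split_baked(x):
    # sizes and dim are Python values closed over at "dispatch" time
    return x.split([2, 3], dim=0)


def split_runtime(x, sizes, dim):
    # sizes and dim arrive as tensors and must be converted inside the trace
    return x.split(sizes.tolist(), dim=dim.item())


x = torch.arange(5.0)
print(torch._dynamo.explain(split_baked)(x).graph_break_count)  # expected: 0
print(
    torch._dynamo.explain(split_runtime)(x, torch.tensor([2, 3]), torch.tensor(0)).graph_break_count
)  # expected: at least 1
```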
6 changes: 6 additions & 0 deletions pytensor/link/pytorch/dispatch/scalar.py
@@ -5,12 +5,18 @@
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.basic import (
Cast,
Invert,
ScalarOp,
)
from pytensor.scalar.loop import ScalarLoop
from pytensor.scalar.math import Softplus


+@pytorch_funcify.register(Invert)
+def pytorch_funcify_invert(op, node, **kwargs):
+    return torch.bitwise_not


@pytorch_funcify.register(ScalarOp)
def pytorch_funcify_ScalarOp(op, node, **kwargs):
"""Return pytorch function that implements the same computation as the Scalar Op.
19 changes: 16 additions & 3 deletions pytensor/link/pytorch/dispatch/shape.py
@@ -1,15 +1,28 @@
import torch

from pytensor.graph.basic import Constant
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast


@pytorch_funcify.register(Reshape)
def pytorch_funcify_Reshape(op, node, **kwargs):
-    def reshape(x, shape):
-        return torch.reshape(x, tuple(shape))
-    return reshape
+    _, shape = node.inputs
+
+    if isinstance(shape, Constant):
+        constant_shape = tuple(int(dim) for dim in shape.data)
+
+        def reshape_constant_shape(x, *_):
+            return torch.reshape(x, constant_shape)
+
+        return reshape_constant_shape
+
+    else:
+
+        def reshape(x, shape):
+            return torch.reshape(x, tuple(shape))
+
+        return reshape


@pytorch_funcify.register(Shape)
15 changes: 15 additions & 0 deletions pytensor/link/pytorch/dispatch/subtensor.py
@@ -1,3 +1,4 @@
from pytensor.graph.basic import Constant
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
@@ -23,7 +24,21 @@
@pytorch_funcify.register(Subtensor)
def pytorch_funcify_Subtensor(op, node, **kwargs):
    idx_list = op.idx_list
+    x, *idxs = node.inputs
+
+    if all(isinstance(idx, Constant) for idx in idxs):
+        # Use constant indices to avoid graph break
+        constant_indices = indices_from_subtensor(
+            [int(idx.data) for idx in idxs], idx_list
+        )
+        check_negative_steps(constant_indices)
+
+        def constant_index_subtensor(x, *_):
+            return x[constant_indices]
+
+        return constant_index_subtensor
+
+    # Fallback that will introduce a graph break
    def subtensor(x, *flattened_indices):
        indices = indices_from_subtensor(flattened_indices, idx_list)
        check_negative_steps(indices)
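The two comments in the hunk above summarize the trade-off; a hedged sketch of the same distinction in plain `torch.compile` terms (assuming default dynamo settings, where `Tensor.item()` inside a compiled region causes a graph break, and `fullgraph=True` turns any break into an error):

```python
import torch


@torch.compile(fullgraph=True)
def slice_constant(x):
    # the indices are Python ints baked into the trace
    return x[1:4]


@torch.compile(fullgraph=True)
def slice_runtime(x, start):
    # .item() on a tensor argument forces a graph break
    return x[start.item():4]


x = torch.arange(6.0)
print(slice_constant(x))  # compiles as a single graph
# slice_runtime(x, torch.tensor(1))  # expected to raise because of the graph break
```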
3 changes: 3 additions & 0 deletions pytensor/link/pytorch/linker.py
@@ -37,6 +37,9 @@ def conversion_func_register(*args, **kwargs):
def jit_compile(self, fn):
import torch

# flag that tends to help our graphs
torch._dynamo.config.capture_dynamic_output_shape_ops = True
Review comment (Contributor): Hopefully when #1159 gets merged we can just delete this flag altogether since torch will know these aren't dynamic.


from pytensor.link.pytorch.dispatch import pytorch_typify

class wrapper:
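For context on the flag itself, a hedged illustration (assuming PyTorch 2.x dynamo behaviour and the ≥ 2.1 `torch._dynamo.explain` API, not PyTensor code): ops whose output shape depends on tensor values, such as `torch.nonzero`, normally fall back to eager through a graph break; with `capture_dynamic_output_shape_ops` enabled, dynamo captures them with unbacked symbolic shapes instead.

```python
import torch

# Hedged illustration: toggle the flag and count graph breaks for an op whose
# output shape depends on the data.
torch._dynamo.config.capture_dynamic_output_shape_ops = True


def rows_of_nonzero(x):
    return torch.nonzero(x)  # output length depends on the values in x


x = torch.tensor([0.0, 1.5, 0.0, 2.0])
print(torch._dynamo.explain(rows_of_nonzero)(x).graph_break_count)  # expected: 0 with the flag on
```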
70 changes: 69 additions & 1 deletion tests/link/pytorch/test_basic.py
@@ -12,7 +12,7 @@
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.configdefaults import config
from pytensor.graph import RewriteDatabaseQuery
-from pytensor.graph.basic import Apply
+from pytensor.graph.basic import Apply, Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.ifelse import ifelse
@@ -38,6 +38,11 @@
py_mode = Mode(linker="py", optimizer=None)


+def set_test_value(x, v):
+    x.tag.test_value = v
+    return x

Review comment (Member Author): Not test value stuff, we're trying to get rid of it!


def compare_pytorch_and_py(
fgraph: FunctionGraph,
test_inputs: Iterable,
@@ -471,3 +476,66 @@ def test_ScalarLoop_Elemwise_multi_carries():
compare_pytorch_and_py(
f, args, assert_fn=partial(np.testing.assert_allclose, rtol=1e-6)
)


rng = np.random.default_rng(42849)


@pytest.mark.parametrize(
"n_splits, axis, values, sizes",
[
(
0,
0,
set_test_value(pt.vector(), rng.normal(size=20).astype(config.floatX)),
set_test_value(pt.vector(dtype="int64"), []),
),
(
5,
0,
set_test_value(pt.vector(), rng.normal(size=5).astype(config.floatX)),
set_test_value(
pt.vector(dtype="int64"), rng.multinomial(5, np.ones(5) / 5)
),
),
(
5,
0,
set_test_value(pt.vector(), rng.normal(size=10).astype(config.floatX)),
set_test_value(
pt.vector(dtype="int64"), rng.multinomial(10, np.ones(5) / 5)
),
),
(
5,
-1,
set_test_value(pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)),
set_test_value(
pt.vector(dtype="int64"), rng.multinomial(7, np.ones(5) / 5)
),
),
(
5,
-2,
set_test_value(pt.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)),
set_test_value(
pt.vector(dtype="int64"), rng.multinomial(11, np.ones(5) / 5)
),
),
],
)
def test_Split(n_splits, axis, values, sizes):
g = pt.split(values, sizes, n_splits, axis=axis)
assert len(g) == n_splits
if n_splits == 0:
return
g_fg = FunctionGraph(outputs=[g] if n_splits == 1 else g)

compare_pytorch_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, SharedVariable | Constant)
],
)