PyTorch inline constants in dispatch to avoid graph breaks
ricardoV94 committed Dec 12, 2024
1 parent 231a977 commit c08d288
Showing 4 changed files with 74 additions and 9 deletions.
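
All four files follow the same pattern: when an input such as an axis, shape, or index is a Constant in the PyTensor graph, its value is extracted once at dispatch time and closed over as a plain Python value, so the generated function never has to read it out of a tensor while torch.compile is tracing. A standalone sketch of the underlying issue (illustrative only, not part of this commit):

import torch


def cat_dynamic(axis, *tensors):
    # Reading the axis out of a tensor at trace time typically forces
    # torch.compile to break the graph at this point.
    return torch.cat(tensors, dim=int(axis.item()))


def make_cat_static(axis: int):
    def cat_static(*tensors):
        # The axis is a Python int baked into the closure, so tracing
        # never needs to touch tensor data.
        return torch.cat(tensors, dim=axis)

    return cat_static


# fullgraph=True makes torch.compile raise instead of silently splitting;
# the static variant traces as a single graph.
compiled = torch.compile(make_cat_static(1), fullgraph=True)
out = compiled(torch.ones(2, 3), torch.zeros(2, 3))  # shape (2, 6)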
43 changes: 37 additions & 6 deletions pytensor/link/pytorch/dispatch/basic.py
@@ -8,6 +8,7 @@
 from pytensor.compile import PYTORCH
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.ops import DeepCopyOp
+from pytensor.graph.basic import Constant
 from pytensor.graph.fg import FunctionGraph
 from pytensor.ifelse import IfElse
 from pytensor.link.utils import fgraph_to_python
@@ -19,6 +20,7 @@
     Eye,
     Join,
     MakeVector,
+    Split,
     TensorFromScalar,
 )
@@ -120,14 +122,23 @@ def arange(start, stop, step):
 
 
 @pytorch_funcify.register(Join)
-def pytorch_funcify_Join(op, **kwargs):
-    def join(axis, *tensors):
-        # tensors could also be tuples, and in this case they don't have a ndim
-        tensors = [torch.tensor(tensor) for tensor in tensors]
+def pytorch_funcify_Join(op, node, **kwargs):
+    axis = node.inputs[0]
+
+    if isinstance(axis, Constant):
+        axis = int(axis.data)
 
-        return torch.cat(tensors, dim=axis)
+        def join_constant_axis(_, *tensors):
+            return torch.cat(tensors, dim=axis)
+
+        return join_constant_axis
+
+    else:
 
-    return join
+        def join(axis, *tensors):
+            return torch.cat(tensors, dim=axis)
+
+        return join
 
 
 @pytorch_funcify.register(Eye)
@@ -185,3 +196,23 @@ def tensorfromscalar(x):
         return torch.as_tensor(x)
 
     return tensorfromscalar
+
+
+@pytorch_funcify.register(Split)
+def pytorch_funcify_Split(op, node, **kwargs):
+    x, dim, split_sizes = node.inputs
+    if isinstance(dim, Constant) and isinstance(split_sizes, Constant):
+        dim = int(dim.data)
+        split_sizes = tuple(int(size) for size in split_sizes.data)
+
+        def split_constant_axis_and_sizes(x, *_):
+            return x.split(split_sizes, dim=dim)
+
+        return split_constant_axis_and_sizes
+
+    else:
+
+        def inner_fn(x, dim, split_amounts):
+            return x.split(split_amounts.tolist(), dim=dim.item())
+
+        return inner_fn
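
To exercise the constant branches above, something like the following should work (a sketch, assuming a PyTensor build with the PyTorch backend and its "PYTORCH" mode available):

import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")
y = pt.matrix("y")

# The literal 1 becomes a Constant in the graph, so dispatch returns
# join_constant_axis with the axis inlined as a Python int.
joined = pt.join(1, x, y)

# Constant axis and sizes likewise select split_constant_axis_and_sizes.
parts = pt.split(joined, splits_size=[2, 2], n_splits=2, axis=1)

fn = pytensor.function([x, y], parts, mode="PYTORCH")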
6 changes: 6 additions & 0 deletions pytensor/link/pytorch/dispatch/scalar.py
@@ -5,12 +5,18 @@
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.scalar.basic import (
     Cast,
+    Invert,
     ScalarOp,
 )
 from pytensor.scalar.loop import ScalarLoop
 from pytensor.scalar.math import Softplus
 
 
+@pytorch_funcify.register(Invert)
+def pytorch_funcify_invert(op, node, **kwargs):
+    return torch.bitwise_not
+
+
 @pytorch_funcify.register(ScalarOp)
 def pytorch_funcify_ScalarOp(op, node, **kwargs):
     """Return pytorch function that implements the same computation as the Scalar Op.
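
Invert is PyTensor's bitwise NOT, so returning torch.bitwise_not directly is enough; no constant handling is needed here. A quick sketch of the mapping (assuming the same "PYTORCH" mode as above):

import numpy as np
import pytensor
import pytensor.tensor as pt

b = pt.vector("b", dtype="bool")
fn = pytensor.function([b], pt.invert(b), mode="PYTORCH")
fn(np.array([True, False]))  # torch.bitwise_not flips each element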
19 changes: 16 additions & 3 deletions pytensor/link/pytorch/dispatch/shape.py
@@ -1,15 +1,28 @@
 import torch
 
+from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
 
 
 @pytorch_funcify.register(Reshape)
 def pytorch_funcify_Reshape(op, node, **kwargs):
-    def reshape(x, shape):
-        return torch.reshape(x, tuple(shape))
+    _, shape = node.inputs
 
-    return reshape
+    if isinstance(shape, Constant):
+        constant_shape = tuple(int(dim) for dim in shape.data)
+
+        def reshape_constant_shape(x, *_):
+            return torch.reshape(x, constant_shape)
+
+        return reshape_constant_shape
+
+    else:
+
+        def reshape(x, shape):
+            return torch.reshape(x, tuple(shape))
+
+        return reshape
 
 
 @pytorch_funcify.register(Shape)
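
The same split applies to Reshape: a graph-constant target shape is converted to a tuple of Python ints once at dispatch time, while a symbolic shape still goes through tuple(shape) at runtime (typically at the cost of a graph break). A sketch:

import pytensor
import pytensor.tensor as pt

x = pt.matrix("x")

# Constant target shape: (2, 3) is folded into the closure as Python ints.
static = x.reshape((2, 3))

# Symbolic target shape: only known at runtime, so the fallback
# reshape variant is dispatched instead (ndim must be given explicitly).
s = pt.lvector("s")
dynamic = x.reshape(s, ndim=2)

fn = pytensor.function([x, s], [static, dynamic], mode="PYTORCH")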
15 changes: 15 additions & 0 deletions pytensor/link/pytorch/dispatch/subtensor.py
@@ -1,3 +1,4 @@
+from pytensor.graph.basic import Constant
 from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
 from pytensor.tensor.subtensor import (
     AdvancedIncSubtensor,
@@ -23,7 +24,21 @@ def check_negative_steps(indices):
 @pytorch_funcify.register(Subtensor)
 def pytorch_funcify_Subtensor(op, node, **kwargs):
     idx_list = op.idx_list
+    x, *idxs = node.inputs
+
+    if all(isinstance(idx, Constant) for idx in idxs):
+        # Use constant indices to avoid graph break
+        constant_indices = indices_from_subtensor(
+            [int(idx.data) for idx in idxs], idx_list
+        )
+        check_negative_steps(constant_indices)
+
+        def constant_index_subtensor(x, *_):
+            return x[constant_indices]
+
+        return constant_index_subtensor
 
+    # Fallback that will introduce a graph break
    def subtensor(x, *flattened_indices):
        indices = indices_from_subtensor(flattened_indices, idx_list)
        check_negative_steps(indices)
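
And for Subtensor: basic indexing whose index inputs are all Constants is resolved to concrete Python indices when the function is built, while symbolic indices take the fallback path flagged above as introducing a graph break. A sketch:

import pytensor
import pytensor.tensor as pt

x = pt.vector("x")

# Literal slice bounds: every index input is a Constant, so
# constant_index_subtensor indexes with a plain Python slice.
static = x[1:3]

# Symbolic bound: the fallback rebuilds the indices at runtime.
i = pt.lscalar("i")
dynamic = x[:i]

fn = pytensor.function([x, i], [static, dynamic], mode="PYTORCH")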
