feat: support output_padding argument in deconv converter
chohk88 committed Jan 20, 2025
1 parent f48f040 commit 151340b
Showing 4 changed files with 136 additions and 14 deletions.
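Summary: the dynamo converter previously rejected aten.convolution nodes with non-zero output_padding via the conv_param_validator removed below; this commit forwards the argument to the deconvolution implementation instead. A minimal sketch of the kind of module this unlocks (the torch_tensorrt.compile invocation is illustrative, based on the public API, and not part of the diff):

import torch
import torch_tensorrt

class Upsample(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # output_padding=1 resolves the output-size ambiguity of stride=2 and
        # was previously rejected by the converter's capability validator.
        self.deconv = torch.nn.ConvTranspose2d(
            3, 8, kernel_size=3, stride=2, padding=1, output_padding=1
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.deconv(x)

model = Upsample().eval().cuda()
inputs = [torch.randn(1, 3, 16, 16).cuda()]
trt_model = torch_tensorrt.compile(model, ir="dynamo", inputs=inputs)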
10 changes: 1 addition & 9 deletions py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2447,16 +2447,8 @@ def aten_ops_le(
     )


-def conv_param_validator(
-    conv_node: Node, settings: Optional[CompilationSettings] = None
-) -> bool:
-
-    return conv_node.args[7] in ([0], [0, 0], [0, 0, 0])
-
-
 @dynamo_tensorrt_converter(
     torch.ops.aten.convolution.default,
-    capability_validator=conv_param_validator,
     supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
@@ -2502,7 +2494,7 @@ def aten_ops_convolution(
         stride=args[3],
         padding=args[4],
         dilation=args[5],
-        # output_padding=args[7],
+        output_padding=args[7],
         groups=args[8],
     )

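For reference, args[7] is output_padding because of the positional schema of torch.ops.aten.convolution.default; the annotation below restates the ATen schema and is not part of the commit:

# aten::convolution(Tensor input, Tensor weight, Tensor? bias,
#                   int[] stride, int[] padding, int[] dilation,
#                   bool transposed, int[] output_padding, int groups)
#
# args[0] input    args[3] stride     args[6] transposed
# args[1] weight   args[4] padding    args[7] output_padding  (now forwarded)
# args[2] bias     args[5] dilation   args[8] groups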
23 changes: 23 additions & 0 deletions py/torch_tensorrt/dynamo/conversion/impl/deconv.py
@@ -6,6 +6,7 @@
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
+
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
@@ -105,6 +106,9 @@ def deconvNd(
     padding = (padding,) if isinstance(padding, int) else padding
     stride = (stride,) if isinstance(stride, int) else stride
     dilation = (dilation,) if isinstance(dilation, int) else dilation
+    output_padding = (
+        (output_padding,) if isinstance(output_padding, int) else output_padding
+    )

     # Expand parameters manually for Conv1D computations
     if is_deconv1d:
@@ -113,6 +117,11 @@
         dilation = (
             extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation
         )
+        output_padding = (
+            (tuple(output_padding) + (0,))
+            if output_padding is not None
+            else output_padding
+        )

     set_layer_name(deconv_layer, target, name, source_ir)

@@ -126,6 +135,20 @@ def deconvNd(
     if groups is not None:
         deconv_layer.num_groups = groups

+    ndims = len(padding)
+    pre_padding_values = []
+    post_padding_values = []
+
+    for dim in range(ndims):
+        pre_padding = padding[dim]
+        post_padding = padding[dim] - output_padding[dim]
+
+        pre_padding_values.append(pre_padding)
+        post_padding_values.append(post_padding)
+
+    deconv_layer.pre_padding = tuple(pre_padding_values)
+    deconv_layer.post_padding = tuple(post_padding_values)
+
     # Handle quantization cases
     if scale is not None and zero_point is not None:
         # Assume the dtype of activation is torch.quint8
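The arithmetic behind the new pre/post padding block: for each spatial dimension, PyTorch defines the transposed-convolution output size as

    L_out = (L_in - 1) * stride - 2 * padding + dilation * (kernel - 1) + output_padding + 1

TensorRT's deconvolution layer instead takes asymmetric pre_padding and post_padding, which are cropped from the un-padded output of size (L_in - 1) * stride + dilation * (kernel - 1) + 1. Setting pre = padding and post = padding - output_padding crops output_padding fewer elements from the trailing edge, which reproduces PyTorch's semantics. This also explains the 1D branch above: output_padding is extended with a trailing 0 (rather than broadcast via extend_attr_to_tuple like stride and dilation) so that the dummy second dimension used to express a 1D deconvolution as 2D receives no extra output elements. A quick numerical check of the two size formulas, in plain PyTorch (a sketch added for this write-up, not part of the commit):

import torch

def expected_len(l_in: int, k: int, s: int, p: int, d: int, op: int) -> int:
    # PyTorch's ConvTranspose output-size formula for one dimension.
    return (l_in - 1) * s - 2 * p + d * (k - 1) + op + 1

def expected_len_trt(l_in: int, k: int, s: int, p: int, d: int, op: int) -> int:
    # TensorRT-style view: crop pre and post padding from the full output.
    full = (l_in - 1) * s + d * (k - 1) + 1
    pre, post = p, p - op
    return full - pre - post

deconv = torch.nn.ConvTranspose1d(
    1, 1, kernel_size=3, stride=2, padding=1, output_padding=1
)
x = torch.randn(1, 1, 16)
out_len = deconv(x).shape[-1]
assert out_len == expected_len(16, 3, 2, 1, 1, 1) == expected_len_trt(16, 3, 2, 1, 1, 1) == 32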
56 changes: 56 additions & 0 deletions py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -104,6 +104,62 @@ def aten_ops_batch_norm(
     )


+@tensorrt_converter(torch.ops.aten.convolution.default)
+def aten_ops_convolution(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "weight": args[1],
+        "bias": args[2],
+        "stride": args[3],
+        "padding": args[4],
+        "dilation": args[5],
+        "groups": args[8],
+    }
+    # we do not handle transposed.
+    if args[6] is True:
+        raise RuntimeError(f"Target {target} does not support `transposed=True` ")
+    # we do not handle output_padding.
+    if args[7] not in ([0], [0, 0], [0, 0, 0]):
+        raise RuntimeError(f"Target {target} has non-0 output_padding")
+
+    if len(kwargs_new["stride"]) == 1:
+        return convolution.convNd(
+            network,
+            target,
+            source_ir=SourceIR.ATEN,
+            name=name,
+            is_conv1d=True,
+            input_val=kwargs_new["input"],
+            weight=kwargs_new["weight"],
+            bias=kwargs_new["bias"],
+            stride=kwargs_new["stride"],
+            padding=kwargs_new["padding"],
+            dilation=kwargs_new["dilation"],
+            groups=kwargs_new["groups"],
+        )
+    else:
+        return convolution.convNd(
+            network,
+            target,
+            source_ir=SourceIR.ATEN,
+            name=name,
+            is_conv1d=False,
+            input_val=kwargs_new["input"],
+            weight=kwargs_new["weight"],
+            bias=kwargs_new["bias"],
+            stride=kwargs_new["stride"],
+            padding=kwargs_new["padding"],
+            dilation=kwargs_new["dilation"],
+            groups=kwargs_new["groups"],
+        )
+
+
 @tensorrt_converter(torch.ops.aten.div.default)
 @tensorrt_converter(torch.ops.aten.div.Tensor_mode)
 @tensorrt_converter(torch.ops.aten.div.Tensor)
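Note the asymmetry between the two frontends after this commit: the legacy fx path gains an aten.convolution converter but keeps exactly the guard that the dynamo path just dropped, so transposed=True and non-zero output_padding still fail there. A distilled restatement of that guard (a sketch for illustration, not code from the diff):

def fx_guard(args: tuple) -> None:
    # Mirrors the two checks in the fx converter above; args follows the
    # aten.convolution.default layout, so args[6] is transposed and
    # args[7] is output_padding.
    if args[6] is True:
        raise RuntimeError("transposed=True is unsupported on the fx path")
    if args[7] not in ([0], [0, 0], [0, 0, 0]):
        raise RuntimeError("non-0 output_padding is unsupported on the fx path")

A transposed convolution exported through the fx frontend therefore fails on args[6] before output_padding is ever inspected; only the dynamo frontend reaches the new deconv implementation.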
61 changes: 56 additions & 5 deletions tests/py/dynamo/conversion/test_deconvolution_aten.py
@@ -1,6 +1,7 @@
 import torch
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
+
 from torch_tensorrt import Input

 from .harness import DispatchTestCase
@@ -15,6 +16,21 @@ class TestDeconvolutionConverter(DispatchTestCase):
             param("non_zero_padding", 1, padding=1),
             param("dilation", 1, dilation=2),
             param("groups", 1, groups=3),
+            param("output_padding_1", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_2", 3, stride=2, padding=2, output_padding=1),
+            param("output_padding_3", 3, stride=2, padding=3, output_padding=1),
+            param("output_padding_4", 3, stride=3, padding=2, output_padding=1),
+            param("output_padding_5", 3, stride=3, padding=3, output_padding=1),
+            param("output_padding_6", 3, stride=3, padding=3, output_padding=2),
+            param(
+                "combined_params",
+                3,
+                stride=3,
+                padding=3,
+                dilation=2,
+                groups=3,
+                output_padding=2,
+            ),
         ]
     )
     def test_deconv1d(
@@ -26,6 +42,7 @@ def test_deconv1d(
         dilation=1,
         groups=1,
         bias=True,
+        output_padding=0,
     ):
         class TestModule(torch.nn.Module):
             def __init__(self):
@@ -36,9 +53,10 @@ def __init__(self):
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
+                    dilation=dilation,
+                    output_padding=output_padding,
                     groups=groups,
                     bias=bias,
-                    dilation=dilation,
                 )

             def forward(self, x):
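(The keyword reordering in the test modules is cosmetic: dilation now sits with the other geometry arguments and output_padding is threaded through. Since these are keyword arguments, behavior is unchanged; the new order simply mirrors the ConvTranspose signature, restated here as a reference comment rather than taken from the diff:)

# torch.nn.ConvTranspose1d(in_channels, out_channels, kernel_size, stride=1,
#                          padding=0, output_padding=0, groups=1, bias=True,
#                          dilation=1, padding_mode='zeros')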
@@ -101,6 +119,22 @@ def forward(self, x):
             param("non_zero_padding", 1, padding=1),
             param("dilation", 1, dilation=2),
             param("groups", 1, groups=3),
+            param("output_padding_1", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_2", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_3", 3, stride=2, padding=2, output_padding=1),
+            param("output_padding_4", 3, stride=2, padding=3, output_padding=1),
+            param("output_padding_5", 3, stride=3, padding=2, output_padding=1),
+            param("output_padding_6", 3, stride=3, padding=3, output_padding=1),
+            param("output_padding_7", 3, stride=3, padding=3, output_padding=2),
+            param(
+                "combined_params",
+                3,
+                stride=3,
+                padding=3,
+                dilation=2,
+                groups=3,
+                output_padding=2,
+            ),
         ]
     )
     def test_deconv2d(
@@ -112,6 +146,7 @@ def test_deconv2d(
         dilation=1,
         groups=1,
         bias=True,
+        output_padding=0,
     ):
         class TestModule(torch.nn.Module):
             def __init__(self):
@@ -122,9 +157,10 @@ def __init__(self):
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
+                    dilation=dilation,
+                    output_padding=output_padding,
                     groups=groups,
                     bias=bias,
-                    dilation=dilation,
                 )

             def forward(self, x):
@@ -172,6 +208,19 @@ def forward(self, x):
             param("non_zero_padding", 1, padding=1),
             param("dilation", 1, dilation=2),
             param("groups", 1, groups=3),
+            param("output_padding_1", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_2", 3, stride=2, padding=2, output_padding=1),
+            param("output_padding_3", 3, stride=3, padding=3, output_padding=1),
+            param("output_padding_4", 3, stride=3, padding=3, output_padding=2),
+            param(
+                "combined_params",
+                3,
+                stride=3,
+                padding=3,
+                dilation=2,
+                groups=3,
+                output_padding=2,
+            ),
         ]
     )
     def test_deconv3d(
@@ -183,6 +232,7 @@ def test_deconv3d(
         dilation=1,
         groups=1,
         bias=True,
+        output_padding=0,
     ):
         class TestModule(torch.nn.Module):
             def __init__(self):
@@ -193,9 +243,10 @@ def __init__(self):
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
+                    dilation=dilation,
+                    output_padding=output_padding,
                     groups=groups,
                     bias=bias,
-                    dilation=dilation,
                 )

             def forward(self, x):
@@ -209,8 +260,8 @@ def forward(self, x):
             enable_passes=True,
         )

-    # Testing with (-1, -1, -1, -1, -1) results into Error:
-    # AssertionError: Channel dim can't be dynamic for deconvolution.
+    # # Testing with (-1, -1, -1, -1, -1) results into Error:
+    # # AssertionError: Channel dim can't be dynamic for deconvolution.

     def test_deconv3d_with_dynamic_shape(self):
         class TestModule(torch.nn.Module):
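As a sanity check on the new parameterizations, expected shapes follow from the output-size formula given earlier; for example, output_padding_1 in the 2D suite (kernel_size=3, stride=2, padding=1, output_padding=1) exactly doubles each spatial dimension. The channel counts and input shape below are illustrative; the real test inputs live in the elided parts of this file:

import torch

# kernel_size=3, stride=2, padding=1, output_padding=1
# H_out = (H_in - 1)*2 - 2*1 + 1*(3 - 1) + 1 + 1 = 2 * H_in
m = torch.nn.ConvTranspose2d(3, 3, kernel_size=3, stride=2, padding=1, output_padding=1)
x = torch.randn(1, 3, 16, 16)  # illustrative input shape
print(m(x).shape)  # torch.Size([1, 3, 32, 32])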
