From 151340bb2fea04ccb73bb9fd35e9393ea098a1ee Mon Sep 17 00:00:00 2001
From: Hoonkyung Cho
Date: Mon, 20 Jan 2025 20:05:58 +0900
Subject: [PATCH] feat: support output_padding argument in deconv converter

---
 .../dynamo/conversion/aten_ops_converters.py | 10 +--
 .../dynamo/conversion/impl/deconv.py         | 23 +++++++
 .../fx/converters/aten_ops_converters.py     | 56 +++++++++++++++++
 .../conversion/test_deconvolution_aten.py    | 61 +++++++++++++++++--
 4 files changed, 136 insertions(+), 14 deletions(-)

diff --git a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
index 5fd2f3ff37..eac8a4b70c 100644
--- a/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
+++ b/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
@@ -2447,16 +2447,8 @@ def aten_ops_le(
     )
 
 
-def conv_param_validator(
-    conv_node: Node, settings: Optional[CompilationSettings] = None
-) -> bool:
-
-    return conv_node.args[7] in ([0], [0, 0], [0, 0, 0])
-
-
 @dynamo_tensorrt_converter(
     torch.ops.aten.convolution.default,
-    capability_validator=conv_param_validator,
     supports_dynamic_shapes=True,
 )
 @enforce_tensor_types(
@@ -2502,7 +2494,7 @@ def aten_ops_convolution(
         stride=args[3],
         padding=args[4],
         dilation=args[5],
-        # output_padding=args[7],
+        output_padding=args[7],
         groups=args[8],
     )
 
diff --git a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py
index 03a209e2a5..d19a92e646 100644
--- a/py/torch_tensorrt/dynamo/conversion/impl/deconv.py
+++ b/py/torch_tensorrt/dynamo/conversion/impl/deconv.py
@@ -6,6 +6,7 @@
 import tensorrt as trt
 import torch
 from torch.fx.node import Target
+
 from torch_tensorrt.dynamo.conversion import impl
 from torch_tensorrt.dynamo.conversion._ConversionContext import ConversionContext
 from torch_tensorrt.dynamo.conversion.converter_utils import (
@@ -105,6 +106,9 @@ def deconvNd(
     padding = (padding,) if isinstance(padding, int) else padding
     stride = (stride,) if isinstance(stride, int) else stride
     dilation = (dilation,) if isinstance(dilation, int) else dilation
+    output_padding = (
+        (output_padding,) if isinstance(output_padding, int) else output_padding
+    )
 
     # Expand parameters manually for Conv1D computations
     if is_deconv1d:
@@ -113,6 +117,11 @@ def deconvNd(
         dilation = (
             extend_attr_to_tuple(dilation, 2) if dilation is not None else dilation
         )
+        output_padding = (
+            (tuple(output_padding) + (0,))
+            if output_padding is not None
+            else output_padding
+        )
 
     set_layer_name(deconv_layer, target, name, source_ir)
 
@@ -126,6 +135,20 @@ def deconvNd(
     if groups is not None:
         deconv_layer.num_groups = groups
 
+    # Map output_padding onto TensorRT's asymmetric padding: trim `padding`
+    # elements from the start of each spatial dim but only
+    # `padding - output_padding` from the end, mirroring PyTorch's
+    # ConvTranspose semantics. Treat output_padding=None as zero, consistent
+    # with the None handling in the deconv1d branch above.
+    ndims = len(padding)
+    pre_padding_values = []
+    post_padding_values = []
+
+    for dim in range(ndims):
+        pre_padding = padding[dim]
+        post_padding = padding[dim] - (
+            output_padding[dim] if output_padding is not None else 0
+        )
+
+        pre_padding_values.append(pre_padding)
+        post_padding_values.append(post_padding)
+
+    deconv_layer.pre_padding = tuple(pre_padding_values)
+    deconv_layer.post_padding = tuple(post_padding_values)
+
     # Handle quantization cases
     if scale is not None and zero_point is not None:
         # Assume the dtype of activation is torch.quint8
diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
index 795ae7c4d9..a725ce8aa3 100644
--- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py
+++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py
@@ -104,6 +104,62 @@ def aten_ops_batch_norm(
     )
 
 
+@tensorrt_converter(torch.ops.aten.convolution.default)
+def aten_ops_convolution(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
+    kwargs_new = {
+        "input": args[0],
+        "weight": args[1],
+        "bias": args[2],
+        "stride": args[3],
+        "padding": args[4],
+        "dilation": args[5],
+        "groups": args[8],
+    }
+    # We do not handle transposed convolution.
+    if args[6] is True:
+        raise RuntimeError(f"Target {target} does not support `transposed=True`")
+    # We do not handle non-zero output_padding.
+    if args[7] not in ([0], [0, 0], [0, 0, 0]):
+        raise RuntimeError(f"Target {target} has non-zero output_padding")
+
+    if len(kwargs_new["stride"]) == 1:
+        return convolution.convNd(
+            network,
+            target,
+            source_ir=SourceIR.ATEN,
+            name=name,
+            is_conv1d=True,
+            input_val=kwargs_new["input"],
+            weight=kwargs_new["weight"],
+            bias=kwargs_new["bias"],
+            stride=kwargs_new["stride"],
+            padding=kwargs_new["padding"],
+            dilation=kwargs_new["dilation"],
+            groups=kwargs_new["groups"],
+        )
+    else:
+        return convolution.convNd(
+            network,
+            target,
+            source_ir=SourceIR.ATEN,
+            name=name,
+            is_conv1d=False,
+            input_val=kwargs_new["input"],
+            weight=kwargs_new["weight"],
+            bias=kwargs_new["bias"],
+            stride=kwargs_new["stride"],
+            padding=kwargs_new["padding"],
+            dilation=kwargs_new["dilation"],
+            groups=kwargs_new["groups"],
+        )
+
+
 @tensorrt_converter(torch.ops.aten.div.default)
 @tensorrt_converter(torch.ops.aten.div.Tensor_mode)
 @tensorrt_converter(torch.ops.aten.div.Tensor)
diff --git a/tests/py/dynamo/conversion/test_deconvolution_aten.py b/tests/py/dynamo/conversion/test_deconvolution_aten.py
index d6cbc0579f..046c646871 100644
--- a/tests/py/dynamo/conversion/test_deconvolution_aten.py
+++ b/tests/py/dynamo/conversion/test_deconvolution_aten.py
@@ -1,6 +1,7 @@
 import torch
 from parameterized import param, parameterized
 from torch.testing._internal.common_utils import run_tests
+
 from torch_tensorrt import Input
 
 from .harness import DispatchTestCase
@@ -15,6 +16,21 @@ class TestDeconvolutionConverter(DispatchTestCase):
             param("non_zero_padding", 1, padding=1),
             param("dilation", 1, dilation=2),
             param("groups", 1, groups=3),
+            param("output_padding_1", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_2", 3, stride=2, padding=2, output_padding=1),
+            param("output_padding_3", 3, stride=2, padding=3, output_padding=1),
+            param("output_padding_4", 3, stride=3, padding=2, output_padding=1),
+            param("output_padding_5", 3, stride=3, padding=3, output_padding=1),
+            param("output_padding_6", 3, stride=3, padding=3, output_padding=2),
+            param(
+                "combined_params",
+                3,
+                stride=3,
+                padding=3,
+                dilation=2,
+                groups=3,
+                output_padding=2,
+            ),
         ]
     )
     def test_deconv1d(
@@ -26,6 +42,7 @@ def test_deconv1d(
         dilation=1,
         groups=1,
         bias=True,
+        output_padding=0,
     ):
         class TestModule(torch.nn.Module):
             def __init__(self):
@@ -36,9 +53,10 @@ def __init__(self):
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
-                    dilation=dilation,
+                    output_padding=output_padding,
                     groups=groups,
                     bias=bias,
+                    dilation=dilation,
                 )
 
             def forward(self, x):
@@ -101,6 +119,22 @@ def forward(self, x):
             param("non_zero_padding", 1, padding=1),
             param("dilation", 1, dilation=2),
             param("groups", 1, groups=3),
+            param("output_padding_1", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_2", 3, stride=2, padding=2, output_padding=1),
+            param("output_padding_3", 3, stride=2, padding=3, output_padding=1),
+            param("output_padding_4", 3, stride=3, padding=2, output_padding=1),
+            param("output_padding_5", 3, stride=3, padding=3, output_padding=1),
+            param("output_padding_6", 3, stride=3, padding=3, output_padding=2),
+            param(
+                "combined_params",
+                3,
+                stride=3,
+                padding=3,
+                dilation=2,
+                groups=3,
+                output_padding=2,
+            ),
         ]
     )
     def test_deconv2d(
@@ -112,6 +146,7 @@ def test_deconv2d(
         dilation=1,
         groups=1,
         bias=True,
+        output_padding=0,
     ):
         class TestModule(torch.nn.Module):
             def __init__(self):
@@ -122,9 +157,10 @@ def __init__(self):
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
-                    dilation=dilation,
+                    output_padding=output_padding,
                     groups=groups,
                     bias=bias,
+                    dilation=dilation,
                 )
 
             def forward(self, x):
@@ -172,6 +208,19 @@ def forward(self, x):
             param("non_zero_padding", 1, padding=1),
             param("dilation", 1, dilation=2),
             param("groups", 1, groups=3),
+            param("output_padding_1", 3, stride=2, padding=1, output_padding=1),
+            param("output_padding_2", 3, stride=2, padding=2, output_padding=1),
+            param("output_padding_3", 3, stride=3, padding=3, output_padding=1),
+            param("output_padding_4", 3, stride=3, padding=3, output_padding=2),
+            param(
+                "combined_params",
+                3,
+                stride=3,
+                padding=3,
+                dilation=2,
+                groups=3,
+                output_padding=2,
+            ),
         ]
     )
     def test_deconv3d(
@@ -183,6 +232,7 @@ def test_deconv3d(
         dilation=1,
         groups=1,
         bias=True,
+        output_padding=0,
    ):
         class TestModule(torch.nn.Module):
             def __init__(self):
@@ -193,9 +243,10 @@ def __init__(self):
                     kernel_size=kernel_size,
                     stride=stride,
                     padding=padding,
-                    dilation=dilation,
+                    output_padding=output_padding,
                     groups=groups,
                     bias=bias,
+                    dilation=dilation,
                 )
 
             def forward(self, x):
@@ -209,8 +260,8 @@ def forward(self, x):
             enable_passes=True,
         )
 
-    # Testing with (-1, -1, -1, -1, -1) results into Error:
-    # AssertionError: Channel dim can't be dynamic for deconvolution.
+    # Testing with (-1, -1, -1, -1, -1) results in an error:
+    # AssertionError: Channel dim can't be dynamic for deconvolution.
    def test_deconv3d_with_dynamic_shape(self):
         class TestModule(torch.nn.Module):
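
Note (reviewer sketch, not part of the patch): the pre/post padding mapping in deconv.py relies on the transposed-convolution size identity out = (in - 1) * stride + dilation * (kernel - 1) + 1 - pre_padding - post_padding, so choosing pre_padding = padding and post_padding = padding - output_padding reproduces PyTorch's ConvTranspose output shape. The standalone snippet below is a minimal sanity check of that arithmetic against torch.nn.ConvTranspose2d; the helper name trt_deconv_out_size is invented for illustration, and it assumes TensorRT's IDeconvolutionLayer trims pre_padding/post_padding elements from each end of the output, as the converter above does.

    import torch

    def trt_deconv_out_size(in_size, kernel, stride, padding, dilation, output_padding):
        # Hypothetical helper mirroring the converter's pre/post padding math:
        # full (unpadded) deconv output size, minus the pre and post trims.
        full = (in_size - 1) * stride + dilation * (kernel - 1) + 1
        pre = padding
        post = padding - output_padding
        return full - pre - post

    # A few of the (stride, padding, output_padding) combinations exercised
    # by the tests above, with dilation=2 throughout.
    for stride, padding, output_padding in [(2, 1, 1), (2, 3, 1), (3, 3, 2)]:
        deconv = torch.nn.ConvTranspose2d(
            in_channels=3,
            out_channels=3,
            kernel_size=3,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            dilation=2,
        )
        out = deconv(torch.randn(1, 3, 10, 10))
        expected = trt_deconv_out_size(10, 3, stride, padding, 2, output_padding)
        assert out.shape[-1] == out.shape[-2] == expected

If the trim were instead left symmetric (padding removed from both ends), every output_padding case above would come out smaller by exactly output_padding in each spatial dimension, which is why the converter splits the trim asymmetrically between pre_padding and post_padding.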