From a18071e299b13a367304052817ecc6afdcb53180 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Fri, 15 Dec 2023 20:08:07 +0800 Subject: [PATCH 1/2] add automatic data format for upsample --- python/paddle/nn/functional/common.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index b678c80344d30..0199182e5e82a 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -906,6 +906,15 @@ def upsample( [2, 3, 12, 12] """ + if data_format == 'NCHW': + x_shape = len(x.shape) + if x_shape == 3: + data_format = 'NCW' + elif x_shape == 4: + data_format == 'NCHW' + elif x_shape == 5: + data_format == 'NCDHW' + return interpolate( x, size, scale_factor, mode, align_corners, align_mode, data_format ) From e04626f7e777102a4d157364350566ae8509a045 Mon Sep 17 00:00:00 2001 From: iLeGend <824040212@qq.com> Date: Fri, 15 Dec 2023 21:39:12 +0800 Subject: [PATCH 2/2] fix --- python/paddle/nn/functional/common.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/nn/functional/common.py b/python/paddle/nn/functional/common.py index 0199182e5e82a..97a647fbc876d 100644 --- a/python/paddle/nn/functional/common.py +++ b/python/paddle/nn/functional/common.py @@ -911,9 +911,9 @@ def upsample( if x_shape == 3: data_format = 'NCW' elif x_shape == 4: - data_format == 'NCHW' + data_format = 'NCHW' elif x_shape == 5: - data_format == 'NCDHW' + data_format = 'NCDHW' return interpolate( x, size, scale_factor, mode, align_corners, align_mode, data_format