From 78ae76b4deeef6945755fcb7f9cfba347e64192d Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Thu, 19 Jan 2023 18:42:11 +0800 Subject: [PATCH 1/3] add unstack axis check --- .../fluid/tests/unittests/test_unstack_op.py | 20 +++++++++++++++++++ python/paddle/tensor/manipulation.py | 4 ++++ 2 files changed, 24 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_unstack_op.py b/python/paddle/fluid/tests/unittests/test_unstack_op.py index 1dda05fb0a6b8..4012e5dd6a248 100755 --- a/python/paddle/fluid/tests/unittests/test_unstack_op.py +++ b/python/paddle/fluid/tests/unittests/test_unstack_op.py @@ -84,5 +84,25 @@ def initParameters(self): self.axis = 2 +class TestUnstackZeroInputOp(unittest.TestCase): + def unstack_zero_input_static(self): + + paddle.enable_static() + + array = np.array([], dtype=np.float32) + x = paddle.to_tensor(np.reshape(array, [0]), dtype='float32') + paddle.unstack(x, axis=1) + + def unstack_zero_input_dynamic(self): + + array = np.array([], dtype=np.float32) + x = paddle.to_tensor(np.reshape(array, [0]), dtype='float32') + paddle.unstack(x, axis=1) + + def test_type_error(self): + self.assertRaises(IndexError, self.unstack_zero_input_dynamic) + self.assertRaises(IndexError, self.unstack_zero_input_static) + + if __name__ == '__main__': unittest.main() diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index bdd903ee8f196..923e6923d6d63 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -543,6 +543,10 @@ def unstack(x, axis=0, num=None): y = paddle.unstack(x, axis=1) # unstack with second axis, which results 3 tensors with shape=[2, 5] """ + if not (-x.ndim <= axis < x.ndim): + raise ValueError( + '`axis` must be in the range [-{0}, {0})'.format(x.ndim) + ) if in_dygraph_mode(): if num is None: num = x.shape[axis] From 52a6afbbc9b42e251e458c72d1038da557dddcdd Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Thu, 19 Jan 2023 18:46:01 +0800 Subject: [PATCH 2/3] IndexErr -> ValueError --- python/paddle/fluid/tests/unittests/test_unstack_op.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_unstack_op.py b/python/paddle/fluid/tests/unittests/test_unstack_op.py index 4012e5dd6a248..5fe988e623a0d 100755 --- a/python/paddle/fluid/tests/unittests/test_unstack_op.py +++ b/python/paddle/fluid/tests/unittests/test_unstack_op.py @@ -100,8 +100,8 @@ def unstack_zero_input_dynamic(self): paddle.unstack(x, axis=1) def test_type_error(self): - self.assertRaises(IndexError, self.unstack_zero_input_dynamic) - self.assertRaises(IndexError, self.unstack_zero_input_static) + self.assertRaises(ValueError, self.unstack_zero_input_dynamic) + self.assertRaises(ValueError, self.unstack_zero_input_static) if __name__ == '__main__': From 71aa8595ea7c6969e56969888c99c8e70f1f5a5c Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Thu, 19 Jan 2023 21:36:55 +0800 Subject: [PATCH 3/3] add static select --- python/paddle/fluid/tests/unittests/test_unstack_op.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/python/paddle/fluid/tests/unittests/test_unstack_op.py b/python/paddle/fluid/tests/unittests/test_unstack_op.py index 5fe988e623a0d..745e14983a56a 100755 --- a/python/paddle/fluid/tests/unittests/test_unstack_op.py +++ b/python/paddle/fluid/tests/unittests/test_unstack_op.py @@ -100,9 +100,13 @@ def unstack_zero_input_dynamic(self): paddle.unstack(x, axis=1) def test_type_error(self): + paddle.disable_static() + self.assertRaises(ValueError, self.unstack_zero_input_dynamic) self.assertRaises(ValueError, self.unstack_zero_input_static) + paddle.disable_static() + if __name__ == '__main__': unittest.main()