Skip to content

Commit

Permalink
fix to tensor return nested structure (#57100)
Browse files Browse the repository at this point in the history
  • Loading branch information
feifei-111 authored Sep 9, 2023
1 parent 2b2b2d7 commit 7727108
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 21 deletions.
2 changes: 0 additions & 2 deletions python/paddle/tensor/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -680,8 +680,6 @@ def _to_tensor_static(data, dtype=None, stop_gradient=None):
d, dtype, stop_gradient
)
data = paddle.stack(to_stack_list)
data = paddle.squeeze(data, -1)

else:
raise RuntimeError(
f"Do not support transform type `{type(data)}` to tensor"
Expand Down
2 changes: 1 addition & 1 deletion test/dygraph_to_static/test_cpu_cuda_to_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ def func(x):
x = paddle.to_tensor([3])
np.testing.assert_allclose(
paddle.jit.to_static(func)(x).numpy(),
np.array([1, 2, 3, 4]),
np.array([[1], [2], [3], [4]]),
rtol=1e-05,
)

Expand Down
19 changes: 1 addition & 18 deletions test/dygraph_to_static/test_to_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,7 @@
import unittest

import numpy
from dygraph_to_static_util import (
ast_only_test,
dy2static_unittest,
sot_only_test,
)
from dygraph_to_static_util import dy2static_unittest

import paddle
from paddle.base import core
Expand Down Expand Up @@ -154,7 +150,6 @@ def test_to_tensor_badreturn(self):
self.assertTrue(a.stop_gradient == b.stop_gradient)
self.assertTrue(a.place._equals(b.place))

@ast_only_test
def test_to_tensor_err_log(self):
paddle.disable_static()
x = paddle.to_tensor([3])
Expand All @@ -166,18 +161,6 @@ def test_to_tensor_err_log(self):
in str(e)
)

@sot_only_test
def test_to_tensor_err_log_sot(self):
paddle.disable_static()
x = paddle.to_tensor([3])
try:
a = paddle.jit.to_static(case8)(x)
except Exception as e:
self.assertTrue(
"Can't constructs a 'paddle.Tensor' with data type <class 'dict'>"
in str(e)
)


class TestStatic(unittest.TestCase):
def test_static(self):
Expand Down

0 comments on commit 7727108

Please sign in to comment.