Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2St] pir dy2st unittest verification - Part 12 #59378

Merged
merged 22 commits into from
Dec 6, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion paddle/fluid/pir/dialect/op_generator/api_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -572,7 +572,7 @@ def _gen_check_data_type(self, op_info, op_name):

if (
op_name.endswith(('_grad', '_grad_', '_grad_dense', '_grad_sparse'))
or op_name in ["print", "hardshrink", "det"]
or op_name in ["print", "hardshrink", "det", "assign_out_"]
or len(mapping_name_to_type) == 0
):
return ""
Expand Down
6 changes: 4 additions & 2 deletions paddle/fluid/pybind/pir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -623,11 +623,13 @@ void BindValue(py::module *m) {
[](Value self) {
if (auto param_op = self.defining_op<::pir::ParameterOp>()) {
return param_op.param_name();
} else if (auto data_op =
self.defining_op<paddle::dialect::DataOp>()) {
return data_op.attribute<pir::StrAttribute>("name").AsString();
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Currently, we can only get name of Value that "
"is "
"persistable"));
"is persistable"));
}
})
.def_property(
Expand Down
5 changes: 3 additions & 2 deletions python/paddle/jit/dy2static/program_translator.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
)
from paddle.framework import in_dynamic_mode, use_pir_api
from paddle.nn.layer import layers
from paddle.pir import OpResult
from paddle.utils import flatten, gast

from . import error, logging_utils
Expand Down Expand Up @@ -1032,7 +1033,7 @@ def inputs(self):
inputs = [
var
for var in flatten(concrete_program.inputs)
if isinstance(var, framework.Variable)
if isinstance(var, (framework.Variable, OpResult))
]
return inputs

Expand All @@ -1046,7 +1047,7 @@ def outputs(self):
outputs = [
var
for var in flatten(concrete_program.outputs)
if isinstance(var, framework.Variable)
if isinstance(var, (framework.Variable, OpResult))
]

return outputs
Expand Down
9 changes: 7 additions & 2 deletions python/paddle/nn/layer/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1677,7 +1677,10 @@ def _remove_if_exist(*dicts):

_remove_if_exist(self.__dict__, self._buffers, self._sub_layers)
params[name] = value
elif isinstance(value, paddle.pir.OpResult) and value.persistable:
elif (
isinstance(value, paddle.pir.OpResult)
and value.get_defining_op().name() == 'builtin.parameter'
):
if params is None:
raise ValueError("super().__init__() should be called first")
_remove_if_exist(self.__dict__, self._buffers, self._sub_layers)
Expand Down Expand Up @@ -1729,7 +1732,9 @@ def _remove_if_exist(*dicts):
# Note(Aurelius84): In Dy2stat, the value of the Buffer may be modified in
# decorated function, such as `self.buffer = new_tensor`. So we update its
# value via `assign`.
if type(value) == framework.Variable:
if type(value) == framework.Variable or isinstance(
value, paddle.pir.OpResult
):
from paddle import assign

# Note(zhhsplendid): the condition below happens in PaddleGan model,
Expand Down
41 changes: 26 additions & 15 deletions test/dygraph_to_static/test_basic_api_transformation.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,11 @@
import unittest

import numpy as np
from dygraph_to_static_utils import Dy2StTestBase, test_default_mode_only
from dygraph_to_static_utils import (
Dy2StTestBase,
static_guard,
test_default_mode_only,
)

import paddle
from paddle import base, to_tensor
Expand All @@ -31,7 +35,6 @@

# TODO(zhhsplendid): This test is old so that use a static graph style
# mark it as TODO, to refactoring the code of this file.
paddle.enable_static()


def dyfunc_to_variable(x):
Expand Down Expand Up @@ -105,11 +108,12 @@ def get_static_output(self):

@test_default_mode_only
def test_transformed_static_result(self):
for func in self.test_funcs:
self.dygraph_func = func
dygraph_res = self.get_dygraph_output()
static_res = self.get_static_output()
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)
with static_guard():
for func in self.test_funcs:
self.dygraph_func = func
dygraph_res = self.get_dygraph_output()
static_res = self.get_static_output()
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)


# 1. test Apis that inherit from layers.Layer
Expand Down Expand Up @@ -263,9 +267,10 @@ def get_static_output(self):

@test_default_mode_only
def test_transformed_static_result(self):
dygraph_res = self.get_dygraph_output()
static_res = self.get_static_output()
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)
with static_guard():
dygraph_res = self.get_dygraph_output()
static_res = self.get_static_output()
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)


class TestDygraphBasicApi_BilinearTensorProduct(TestDygraphBasicApi):
Expand Down Expand Up @@ -421,9 +426,10 @@ def get_static_output(self):

@test_default_mode_only
def test_transformed_static_result(self):
dygraph_res = self.get_dygraph_output()
static_res = self.get_static_output()
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)
with static_guard():
dygraph_res = self.get_dygraph_output()
static_res = self.get_static_output()
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)


class TestDygraphBasicApi_ExponentialDecay(TestDygraphBasicApi_CosineDecay):
Expand Down Expand Up @@ -548,8 +554,13 @@ def _get_static_ast_node(self):

@test_default_mode_only
def test_dygraph_api(self):
self.assertTrue(is_dygraph_api(self._get_dygraph_ast_node()) is True)
self.assertTrue(is_dygraph_api(self._get_static_ast_node()) is False)
with static_guard():
self.assertTrue(
is_dygraph_api(self._get_dygraph_ast_node()) is True
)
self.assertTrue(
is_dygraph_api(self._get_static_ast_node()) is False
)


if __name__ == '__main__':
Expand Down
49 changes: 25 additions & 24 deletions test/dygraph_to_static/test_convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from dygraph_to_static_utils import (
Dy2StTestBase,
test_ast_only,
test_legacy_and_pt_and_pir,
)

import paddle
Expand Down Expand Up @@ -46,32 +47,32 @@ def forward(self):
class TestConvertCall(Dy2StTestBase):
# fallback mode will raise a InnerError, it's ok.
@test_ast_only
@test_legacy_and_pt_and_pir
def test_class_exception(self):
@paddle.jit.to_static
def call_not_exist():
net = CallNotExist()
return net()

with self.assertRaises(AttributeError):
call_not_exist()
paddle.jit.to_static(call_not_exist())

@paddle.jit.to_static
def forward_not_exist():
return net()

with self.assertRaises(AttributeError):
forward_not_exist()
paddle.jit.to_static(forward_not_exist)()

@test_legacy_and_pt_and_pir
def test_callable_list(self):
@paddle.jit.to_static
def callable_list(x, y):
callable_list = CallableList()
return callable_list(x) + y

self.assertEqual(callable_list(1, 2), 3)
self.assertEqual(paddle.jit.to_static(callable_list)(1, 2), 3)


class TestConvertShapeCompare(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_non_variable(self):
self.assertEqual(
paddle.jit.dy2static.convert_shape_compare(1, "<", 2), True
Expand Down Expand Up @@ -135,9 +136,9 @@ def error_func():

def test_variable(self):
paddle.enable_static()
with paddle.static.program_guard(
paddle.static.Program(), paddle.static.Program()
):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.static.data(name='x', shape=[3, 2], dtype='float32')
y = paddle.static.data(name='y', shape=[3, 2], dtype='float32')
self.assertEqual(
Expand Down Expand Up @@ -196,7 +197,6 @@ class ShapeLayer(paddle.nn.Layer):
def __init__(self):
super().__init__()

@paddle.jit.to_static(input_spec=[paddle.static.InputSpec(shape=[None, 1])])
def forward(self, x):
x = paddle.reshape(x, [-1, x.shape[1]])
bs = x.shape[0] # -1
Expand All @@ -207,19 +207,23 @@ def forward(self, x):


class TestChooseShapeAttrOrApiWithLayer(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_tensor_shape(self):
x = paddle.zeros(shape=[4, 1], dtype='float32')
net = ShapeLayer()
net = paddle.jit.to_static(
function=ShapeLayer(),
input_spec=[paddle.static.InputSpec(shape=[None, 1])],
)
out = net(x)

np.testing.assert_array_equal(out.numpy(), x.numpy())


class TestIfElseNoValue(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_else_ret_none(self):
input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])

@paddle.jit.to_static
def with_common_value(x, use_cache=False):
if use_cache:
y = x + 1
Expand All @@ -230,7 +234,6 @@ def with_common_value(x, use_cache=False):
z = x - 1
return None

@paddle.jit.to_static
def without_common_value(x, use_cache=False):
if use_cache:
y = x + 1
Expand All @@ -240,15 +243,15 @@ def without_common_value(x, use_cache=False):
c = x + 1
return None

out = with_common_value(input_x, False)
out = paddle.jit.to_static(with_common_value)(input_x, False)
self.assertIsNone(out)
out = without_common_value(input_x, False)
out = paddle.jit.to_static(without_common_value)(input_x, False)
self.assertIsNone(out)

@test_legacy_and_pt_and_pir
def test_else_ret_c(self):
input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])

@paddle.jit.to_static
def with_common_value(x, use_cache=False):
if use_cache:
y = x + 1
Expand All @@ -259,7 +262,6 @@ def with_common_value(x, use_cache=False):
z = x - 1
return c

@paddle.jit.to_static
def without_common_value(x, use_cache=False):
if use_cache:
y = x + 1
Expand All @@ -269,18 +271,18 @@ def without_common_value(x, use_cache=False):
c = x + 1
return c

out = with_common_value(input_x, False)
out = paddle.jit.to_static(with_common_value)(input_x, False)
self.assertListEqual(paddle.tolist(out), paddle.tolist(input_x + 1))
out = without_common_value(input_x, False)
out = paddle.jit.to_static(without_common_value)(input_x, False)
self.assertListEqual(paddle.tolist(out), paddle.tolist(input_x + 1))
y, z = with_common_value(input_x, True)
y, z = paddle.jit.to_static(with_common_value)(input_x, True)
self.assertListEqual(paddle.tolist(y), paddle.tolist(input_x + 1))
self.assertListEqual(paddle.tolist(z), paddle.tolist(input_x + 2))

@test_legacy_and_pt_and_pir
def test_else_ret_cz(self):
input_x = paddle.to_tensor([[1, 2, 3], [4, 5, 6]])

@paddle.jit.to_static
def with_common_value(x, use_cache=False):
if use_cache:
y = x + 1
Expand All @@ -291,7 +293,6 @@ def with_common_value(x, use_cache=False):
z = x - 1
return c, z

@paddle.jit.to_static
def without_common_value(x, use_cache=False):
if use_cache:
y = x + 1
Expand All @@ -302,10 +303,10 @@ def without_common_value(x, use_cache=False):
d = x - 1
return c, d

c, z = with_common_value(input_x, False)
c, z = paddle.jit.to_static(with_common_value)(input_x, False)
self.assertListEqual(paddle.tolist(c), paddle.tolist(input_x + 1))
self.assertListEqual(paddle.tolist(z), paddle.tolist(input_x - 1))
c, d = without_common_value(input_x, False)
c, d = paddle.jit.to_static(without_common_value)(input_x, False)
self.assertListEqual(paddle.tolist(c), paddle.tolist(input_x + 1))
self.assertListEqual(paddle.tolist(d), paddle.tolist(input_x - 1))

Expand Down
Loading