[Dy2St] pir dy2st unittest verification - Part 10 #59276

Merged
merged 9 commits on Nov 28, 2023
test/dygraph_to_static/test_deepcopy.py (4 changes: 3 additions & 1 deletion)
@@ -16,14 +16,15 @@
 from copy import deepcopy

 import numpy as np
-from dygraph_to_static_utils import Dy2StTestBase
+from dygraph_to_static_utils import Dy2StTestBase, test_legacy_and_pt_and_pir
 from test_rollback import Net, foo

 import paddle
 from paddle.jit.dy2static.program_translator import StaticFunction


 class TestDeepCopy(Dy2StTestBase):
+    @test_legacy_and_pt_and_pir
     def test_net(self):
         net = Net()
         net = paddle.jit.to_static(net)
@@ -39,6 +39,7 @@ def test_net(self):
         self.assertTrue(id(copy_net), id(copy_net.forward.__self__))
         np.testing.assert_array_equal(src_out.numpy(), copy_out.numpy())

+    @test_legacy_and_pt_and_pir
     def test_func(self):
         st_foo = paddle.jit.to_static(foo)
         x = paddle.randn([3, 4])
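Note: the only functional change in this file is opting the tests into all three execution paths. `test_legacy_and_pt_and_pir` comes from the shared `dygraph_to_static_utils` helper and, as the name suggests, re-runs the decorated test under the legacy IR, program-translator, and PIR pipelines. A minimal sketch of the general pattern follows; the mode names and flag handling are illustrative assumptions, not the real helper:

import unittest
from functools import wraps

# Hypothetical mode list; the real helper tracks Dy2St IR modes internally.
_MODES = ("legacy_ir", "pt", "pir")

def test_in_all_modes(fn):
    """Re-run one test method once per IR mode (sketch only)."""
    @wraps(fn)
    def wrapper(self, *args, **kwargs):
        for mode in _MODES:
            with self.subTest(ir_mode=mode):
                # A real implementation would flip framework flags here
                # (e.g. enable/disable the PIR executor) before each run.
                fn(self, *args, **kwargs)
    return wrapper

class ExampleTest(unittest.TestCase):
    @test_in_all_modes
    def test_something(self):
        self.assertTrue(True)  # runs three times, once per mode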
test/dygraph_to_static/test_loop.py (57 changes: 32 additions & 25 deletions)
@@ -13,10 +13,15 @@
 # limitations under the License.

 import inspect
+import tempfile
 import unittest

 import numpy as np
-from dygraph_to_static_utils import Dy2StTestBase, test_sot_only
+from dygraph_to_static_utils import (
+    Dy2StTestBase,
+    test_legacy_and_pt_and_pir,
+    test_sot_only,
+)

 import paddle
 import paddle.nn.functional as F
@@ -248,6 +253,7 @@ def setUp(self):

         self.nested_for_loop_func = nested_for_loop_dyfunc

+    @test_legacy_and_pt_and_pir
     def test_loop_vars(self):
         for i in range(len(self.loop_funcs)):
             func = self.loop_funcs[i]
@@ -263,6 +269,7 @@ def test_nested_loop_vars(self):
             self.assertEqual(loop_var_names, self.loop_var_names[i])
             self.assertEqual(create_var_names, self.create_var_names[i])

+    @test_legacy_and_pt_and_pir
     def test_nested_loop_vars(self):
         func = self.nested_for_loop_func
         test_func = inspect.getsource(func)
@@ -303,9 +310,9 @@
 class TestTransformWhileLoop(Dy2StTestBase):
     def setUp(self):
         self.place = (
-            base.CUDAPlace(0)
-            if base.is_compiled_with_cuda()
-            else base.CPUPlace()
+            paddle.CUDAPlace(0)
+            if paddle.is_compiled_with_cuda()
+            else paddle.CPUPlace()
         )
         self.x = np.zeros(shape=(1), dtype=np.int32)
         self._init_dyfunc()
@@ -320,17 +327,16 @@ def _run_dygraph(self):
         return self._run(to_static=False)

     def _run(self, to_static):
-        with base.dygraph.guard(self.place):
-            # Set the input of dyfunc to Tensor
-            tensor_x = base.dygraph.to_variable(self.x, zero_copy=False)
-            if to_static:
-                ret = paddle.jit.to_static(self.dyfunc)(tensor_x)
-            else:
-                ret = self.dyfunc(tensor_x)
-            if hasattr(ret, "numpy"):
-                return ret.numpy()
-            else:
-                return ret
+        # Set the input of dyfunc to Tensor
+        tensor_x = base.dygraph.to_variable(self.x, zero_copy=False)
+        if to_static:
+            ret = paddle.jit.to_static(self.dyfunc)(tensor_x)
+        else:
+            ret = self.dyfunc(tensor_x)
+        if hasattr(ret, "numpy"):
+            return ret.numpy()
+        else:
+            return ret

     @test_sot_only
     def test_ast_to_func(self):
@@ -383,9 +389,9 @@ def _init_dyfunc(self):
 class TestTransformForLoop(Dy2StTestBase):
     def setUp(self):
         self.place = (
-            base.CUDAPlace(0)
-            if base.is_compiled_with_cuda()
-            else base.CPUPlace()
+            paddle.CUDAPlace(0)
+            if paddle.is_compiled_with_cuda()
+            else paddle.CPUPlace()
         )
         self.len = 100
         self._init_dyfunc()
@@ -400,12 +406,11 @@ def _run_dygraph(self):
         return self._run(to_static=False)

     def _run(self, to_static):
-        with base.dygraph.guard(self.place):
-            if to_static:
-                ret = paddle.jit.to_static(self.dyfunc)(self.len)
-            else:
-                ret = self.dyfunc(self.len)
-            return ret.numpy()
+        if to_static:
+            ret = paddle.jit.to_static(self.dyfunc)(self.len)
+        else:
+            ret = self.dyfunc(self.len)
+        return ret.numpy()

     @test_sot_only
     def test_ast_to_func(self):
@@ -474,7 +479,9 @@ def test_start(self):
                 )
             ],
         )
-        paddle.jit.save(model, "./inference/inference")
+        temp_dir = tempfile.TemporaryDirectory()
+        paddle.jit.save(model, temp_dir.name)
+        temp_dir.cleanup()


 if __name__ == '__main__':
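Note on the `test_start` change: writing to a `tempfile.TemporaryDirectory()` keeps the saved inference files out of the working tree. The explicit `cleanup()` in the diff works; the context-manager form below (a sketch, assuming a `to_static`-converted `model` as in the test) also cleans up if `paddle.jit.save` raises:

import os
import tempfile

import paddle

def save_to_tempdir(model):
    # TemporaryDirectory as a context manager removes the directory
    # even when saving fails partway through.
    with tempfile.TemporaryDirectory() as tmp:
        # paddle.jit.save takes a path prefix; it writes files such as
        # <prefix>.pdmodel and <prefix>.pdiparams alongside it.
        paddle.jit.save(model, os.path.join(tmp, "inference"))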
test/dygraph_to_static/test_partial_program.py (179 changes: 113 additions & 66 deletions)
@@ -18,12 +18,14 @@
 from dygraph_to_static_utils import (
     Dy2StTestBase,
     test_ast_only,
+    test_legacy_and_pt,
+    test_legacy_and_pt_and_pir,
+    test_pir_only,
 )
 from test_fetch_feed import Linear

 import paddle
 from paddle import base
-from paddle.jit.api import to_static

 SEED = 2020

@@ -75,19 +77,19 @@ def fake_input(self):
         ]

     def _run(self, to_static):
-        with base.dygraph.guard():
-            if self.x is None or self.y is None:
-                self.fake_input()
-
-            if to_static:
-                out = paddle.jit.to_static(nested_input, full_graph=True)(
-                    self.x, self.y
-                )
-            else:
-                out = nested_input(self.x, self.y)
+        if self.x is None or self.y is None:
+            self.fake_input()
+
+        if to_static:
+            out = paddle.jit.to_static(nested_input, full_graph=True)(
+                self.x, self.y
+            )
+        else:
+            out = nested_input(self.x, self.y)

         return out.numpy()

+    @test_legacy_and_pt_and_pir
     def test_nest(self):
         dygraph_res = self._run(to_static=False)
         static_res = self._run(to_static=True)
@@ -100,20 +102,20 @@ def setUp(self):
         self.y = None

     def _run(self, to_static):
-        with base.dygraph.guard():
-            if self.x is None or self.y is None:
-                self.x = fake_data([10, 16])
-                self.y = fake_data([10, 16])
-
-            if to_static:
-                out = paddle.jit.to_static(nested_output, full_graph=True)(
-                    self.x, self.y
-                )
-            else:
-                out = nested_output(self.x, self.y)
+        if self.x is None or self.y is None:
+            self.x = fake_data([10, 16])
+            self.y = fake_data([10, 16])
+
+        if to_static:
+            out = paddle.jit.to_static(nested_output, full_graph=True)(
+                self.x, self.y
+            )
+        else:
+            out = nested_output(self.x, self.y)

         return out

+    @test_legacy_and_pt_and_pir
     def test_nest(self):
         dygraph_res = self._run(to_static=False)
         dygraph_res = paddle.utils.flatten(dygraph_res)
@@ -124,7 +126,7 @@ def test_nest(self):
         self.assertTrue(len(dygraph_res) == len(static_res))

         for dy_var, st_var in zip(dygraph_res, static_res):
-            if isinstance(dy_var, base.core.eager.Tensor):
+            if isinstance(dy_var, paddle.Tensor):
                 np.testing.assert_allclose(
                     dy_var.numpy(), st_var.numpy(), rtol=1e-05
                 )
@@ -134,53 +136,101 @@

 class TestWithTrainAndEval(Dy2StTestBase):
     @test_ast_only
+    @test_legacy_and_pt
+    def test_legacy_ir_switch_eval_and_train(self):
+        # TODO(cleanup-legacy-ir): Remove this test case
+        linear_net = Linear()
+        linear_net = paddle.jit.to_static(linear_net, full_graph=True)
+        x_data = np.random.random((4, 10)).astype('float32')
+        x = base.dygraph.to_variable(x_data)
+        linear_net(x)
+
+        _, train_partial_layer = linear_net.forward.program_cache.last()[-1]
+        # check default mode is for training
+        self.assertEqual(
+            train_partial_layer.program, train_partial_layer._train_program
+        )
+
+        # switch to run test program after `eval()`
+        linear_net.eval()
+        linear_net(x)
+        _, eval_partial_layer = linear_net.forward.program_cache.last()[-1]
+        self.assertEqual(
+            eval_partial_layer.program, eval_partial_layer._infer_program
+        )
+
+        # switch back into training
+        linear_net.train()
+        linear_net(x)
+        self.assertEqual(
+            train_partial_layer.program, train_partial_layer._train_program
+        )
+
+    @test_ast_only
+    @test_pir_only
     def test_switch_eval_and_train(self):
-        with base.dygraph.guard():
-            linear_net = Linear()
-            linear_net = paddle.jit.to_static(linear_net, full_graph=True)
-            x_data = np.random.random((4, 10)).astype('float32')
-            x = base.dygraph.to_variable(x_data)
-            linear_net(x)
-
-            _, train_partial_layer = linear_net.forward.program_cache.last()[-1]
-            # check default mode is for training
-            self.assertEqual(
-                train_partial_layer.program, train_partial_layer._train_program
-            )
-
-            # switch to run test program after `eval()`
-            linear_net.eval()
-            linear_net(x)
-            _, eval_partial_layer = linear_net.forward.program_cache.last()[-1]
-            self.assertEqual(
-                eval_partial_layer.program, eval_partial_layer._infer_program
-            )
-
-            # switch back into training
-            linear_net.train()
-            linear_net(x)
-            self.assertEqual(
-                train_partial_layer.program, train_partial_layer._train_program
-            )
+        linear_net = Linear()
+        linear_net = paddle.jit.to_static(linear_net, full_graph=True)
+        x_data = np.random.random((4, 10)).astype('float32')
+        x = paddle.to_tensor(x_data)
+        linear_net(x)
+
+        _, train_partial_layer = linear_net.forward.program_cache.last()[-1]
+        # check default mode is for training
+        self.assertEqual(
+            train_partial_layer.program,
+            train_partial_layer.train_program,
+        )
+
+        # switch to run test program after `eval()`
+        linear_net.eval()
+        linear_net(x)
+        _, eval_partial_layer = linear_net.forward.program_cache.last()[-1]
+        self.assertEqual(
+            eval_partial_layer.program, eval_partial_layer.infer_program
+        )
+
+        # switch back into training
+        linear_net.train()
+        linear_net(x)
+        self.assertEqual(
+            train_partial_layer.program, train_partial_layer.train_program
+        )


 class TestWithNoGrad(Dy2StTestBase):
     @test_ast_only
+    @test_legacy_and_pt
+    def test_legacy_ir_with_no_grad(self):
+        # TODO(cleanup-legacy-ir): Remove this test case
+        linear_net = Linear()
+        linear_net = paddle.jit.to_static(linear_net, full_graph=True)
+        x_data = np.random.random((5, 10)).astype('float32')
+        x = paddle.to_tensor(x_data)
+
+        with paddle.no_grad():
+            linear_net.train()
+            linear_net(x)
+        # BUG: we expect an ASTStaticFunction(StaticFunction) here:
+        _, partial_layer = linear_net.forward.program_cache.last()[-1]
+        self.assertEqual(
+            partial_layer.program, partial_layer._train_program
+        )
+
+    @test_ast_only
+    @test_pir_only
     def test_with_no_grad(self):
-        with base.dygraph.guard():
-            linear_net = Linear()
-            linear_net = paddle.jit.to_static(linear_net, full_graph=True)
-            x_data = np.random.random((5, 10)).astype('float32')
-            x = base.dygraph.to_variable(x_data)
-
-            with paddle.no_grad():
-                linear_net.train()
-                linear_net(x)
-            # BUG: we expect an ASTStaticFunction(StaticFunction) here:
-            _, partial_layer = linear_net.forward.program_cache.last()[-1]
-            self.assertEqual(
-                partial_layer.program, partial_layer._train_program
-            )
+        linear_net = Linear()
+        linear_net = paddle.jit.to_static(linear_net, full_graph=True)
+        x_data = np.random.random((5, 10)).astype('float32')
+        x = paddle.to_tensor(x_data)
+
+        with paddle.no_grad():
+            linear_net.train()
+            linear_net(x)
+        # BUG: we expect an ASTStaticFunction(StaticFunction) here:
+        _, partial_layer = linear_net.forward.program_cache.last()[-1]
+        self.assertEqual(partial_layer.program, partial_layer.train_program)


 class GPT2LMHeadModel(paddle.nn.Layer):
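The renamed properties here (`_train_program`/`_infer_program` on the legacy path, `train_program`/`infer_program` on the PIR path) are Paddle internals; the externally visible behavior both variants pin down is just mode switching on a `to_static` layer. Roughly, using the same `Linear` helper the tests import:

import numpy as np
import paddle
from test_fetch_feed import Linear  # same helper the tests import

net = paddle.jit.to_static(Linear(), full_graph=True)
x = paddle.to_tensor(np.random.random((4, 10)).astype('float32'))

net(x)       # first call traces and caches the training program
net.eval()
net(x)       # after eval(), the cached entry serves the inference program
net.train()
net(x)       # train() switches back to the training program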
@@ -192,7 +242,6 @@ def __init__(self):
             np.random.rand(2, 3).astype('float32')
         )

-    @to_static(full_graph=True)
     def forward(self, x):
         x = paddle.reshape(x, shape=[-1, 6])
         x1, x2, x3 = paddle.split(x=x, axis=1, num_or_sections=3)
@@ -203,13 +252,11 @@ class TestPruneUnusedParamInProgram(Dy2StTestBase):
     def test_prune(self):
         input_ids = np.array([[15, 11, 6, 3, 18, 13]]).astype("float32")

-        place = base.CPUPlace()
-        with base.dygraph.guard(place):
-            model = GPT2LMHeadModel()
-            model.eval()
-            input_ids = paddle.to_tensor(input_ids)
-            out = model(input_ids)
-            np.testing.assert_array_equal(out.numpy(), [[15, 11]])
+        model = paddle.jit.to_static(GPT2LMHeadModel())
+        model.eval()
+        input_ids = paddle.to_tensor(input_ids)
+        out = model(input_ids)
+        np.testing.assert_array_equal(out.numpy(), [[15, 11]])


 if __name__ == '__main__':
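A pattern repeated across all three files: the `with base.dygraph.guard():` blocks and `base.dygraph.to_variable(...)` calls date from Paddle 1.x, where dynamic-graph mode had to be enabled explicitly. Since Paddle 2.0, dygraph is the default, so the tests can build tensors directly and pick a place through the public API. A condensed before/after sketch:

import numpy as np
import paddle

x_data = np.random.random((4, 10)).astype('float32')

# Before (Paddle 1.x style):
#     with base.dygraph.guard(place):
#         x = base.dygraph.to_variable(x_data)

# After: dygraph mode is already active.
x = paddle.to_tensor(x_data)
place = (
    paddle.CUDAPlace(0)
    if paddle.is_compiled_with_cuda()
    else paddle.CPUPlace()
)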