Skip to content

Commit

Permalink
[Dy2St] pir dy2st unittest verification - Part 10 (#59276)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: SigureMo <[email protected]>
  • Loading branch information
gouzil and SigureMo authored Nov 28, 2023
1 parent 797c800 commit f9812f6
Show file tree
Hide file tree
Showing 4 changed files with 157 additions and 100 deletions.
4 changes: 3 additions & 1 deletion test/dygraph_to_static/test_deepcopy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,15 @@
from copy import deepcopy

import numpy as np
from dygraph_to_static_utils import Dy2StTestBase
from dygraph_to_static_utils import Dy2StTestBase, test_legacy_and_pt_and_pir
from test_rollback import Net, foo

import paddle
from paddle.jit.dy2static.program_translator import StaticFunction


class TestDeepCopy(Dy2StTestBase):
@test_legacy_and_pt_and_pir
def test_net(self):
net = Net()
net = paddle.jit.to_static(net)
Expand All @@ -39,6 +40,7 @@ def test_net(self):
self.assertTrue(id(copy_net), id(copy_net.forward.__self__))
np.testing.assert_array_equal(src_out.numpy(), copy_out.numpy())

@test_legacy_and_pt_and_pir
def test_func(self):
st_foo = paddle.jit.to_static(foo)
x = paddle.randn([3, 4])
Expand Down
57 changes: 32 additions & 25 deletions test/dygraph_to_static/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,15 @@
# limitations under the License.

import inspect
import tempfile
import unittest

import numpy as np
from dygraph_to_static_utils import Dy2StTestBase, test_sot_only
from dygraph_to_static_utils import (
Dy2StTestBase,
test_legacy_and_pt_and_pir,
test_sot_only,
)

import paddle
import paddle.nn.functional as F
Expand Down Expand Up @@ -248,6 +253,7 @@ def setUp(self):

self.nested_for_loop_func = nested_for_loop_dyfunc

@test_legacy_and_pt_and_pir
def test_loop_vars(self):
for i in range(len(self.loop_funcs)):
func = self.loop_funcs[i]
Expand All @@ -263,6 +269,7 @@ def test_loop_vars(self):
self.assertEqual(loop_var_names, self.loop_var_names[i])
self.assertEqual(create_var_names, self.create_var_names[i])

@test_legacy_and_pt_and_pir
def test_nested_loop_vars(self):
func = self.nested_for_loop_func
test_func = inspect.getsource(func)
Expand Down Expand Up @@ -303,9 +310,9 @@ def test_nested_loop_vars(self):
class TestTransformWhileLoop(Dy2StTestBase):
def setUp(self):
self.place = (
base.CUDAPlace(0)
if base.is_compiled_with_cuda()
else base.CPUPlace()
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)
self.x = np.zeros(shape=(1), dtype=np.int32)
self._init_dyfunc()
Expand All @@ -320,17 +327,16 @@ def _run_dygraph(self):
return self._run(to_static=False)

def _run(self, to_static):
with base.dygraph.guard(self.place):
# Set the input of dyfunc to Tensor
tensor_x = base.dygraph.to_variable(self.x, zero_copy=False)
if to_static:
ret = paddle.jit.to_static(self.dyfunc)(tensor_x)
else:
ret = self.dyfunc(tensor_x)
if hasattr(ret, "numpy"):
return ret.numpy()
else:
return ret
# Set the input of dyfunc to Tensor
tensor_x = base.dygraph.to_variable(self.x, zero_copy=False)
if to_static:
ret = paddle.jit.to_static(self.dyfunc)(tensor_x)
else:
ret = self.dyfunc(tensor_x)
if hasattr(ret, "numpy"):
return ret.numpy()
else:
return ret

@test_sot_only
def test_ast_to_func(self):
Expand Down Expand Up @@ -383,9 +389,9 @@ def _init_dyfunc(self):
class TestTransformForLoop(Dy2StTestBase):
def setUp(self):
self.place = (
base.CUDAPlace(0)
if base.is_compiled_with_cuda()
else base.CPUPlace()
paddle.CUDAPlace(0)
if paddle.is_compiled_with_cuda()
else paddle.CPUPlace()
)
self.len = 100
self._init_dyfunc()
Expand All @@ -400,12 +406,11 @@ def _run_dygraph(self):
return self._run(to_static=False)

def _run(self, to_static):
with base.dygraph.guard(self.place):
if to_static:
ret = paddle.jit.to_static(self.dyfunc)(self.len)
else:
ret = self.dyfunc(self.len)
return ret.numpy()
if to_static:
ret = paddle.jit.to_static(self.dyfunc)(self.len)
else:
ret = self.dyfunc(self.len)
return ret.numpy()

@test_sot_only
def test_ast_to_func(self):
Expand Down Expand Up @@ -474,7 +479,9 @@ def test_start(self):
)
],
)
paddle.jit.save(model, "./inference/inference")
temp_dir = tempfile.TemporaryDirectory()
paddle.jit.save(model, temp_dir.name)
temp_dir.cleanup()


if __name__ == '__main__':
Expand Down
179 changes: 113 additions & 66 deletions test/dygraph_to_static/test_partial_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@
from dygraph_to_static_utils import (
Dy2StTestBase,
test_ast_only,
test_legacy_and_pt,
test_legacy_and_pt_and_pir,
test_pir_only,
)
from test_fetch_feed import Linear

import paddle
from paddle import base
from paddle.jit.api import to_static

SEED = 2020

Expand Down Expand Up @@ -75,19 +77,19 @@ def fake_input(self):
]

def _run(self, to_static):
with base.dygraph.guard():
if self.x is None or self.y is None:
self.fake_input()
if self.x is None or self.y is None:
self.fake_input()

if to_static:
out = paddle.jit.to_static(nested_input, full_graph=True)(
self.x, self.y
)
else:
out = nested_input(self.x, self.y)
if to_static:
out = paddle.jit.to_static(nested_input, full_graph=True)(
self.x, self.y
)
else:
out = nested_input(self.x, self.y)

return out.numpy()

@test_legacy_and_pt_and_pir
def test_nest(self):
dygraph_res = self._run(to_static=False)
static_res = self._run(to_static=True)
Expand All @@ -100,20 +102,20 @@ def setUp(self):
self.y = None

def _run(self, to_static):
with base.dygraph.guard():
if self.x is None or self.y is None:
self.x = fake_data([10, 16])
self.y = fake_data([10, 16])

if to_static:
out = paddle.jit.to_static(nested_output, full_graph=True)(
self.x, self.y
)
else:
out = nested_output(self.x, self.y)
if self.x is None or self.y is None:
self.x = fake_data([10, 16])
self.y = fake_data([10, 16])

if to_static:
out = paddle.jit.to_static(nested_output, full_graph=True)(
self.x, self.y
)
else:
out = nested_output(self.x, self.y)

return out

@test_legacy_and_pt_and_pir
def test_nest(self):
dygraph_res = self._run(to_static=False)
dygraph_res = paddle.utils.flatten(dygraph_res)
Expand All @@ -124,7 +126,7 @@ def test_nest(self):
self.assertTrue(len(dygraph_res) == len(static_res))

for dy_var, st_var in zip(dygraph_res, static_res):
if isinstance(dy_var, base.core.eager.Tensor):
if isinstance(dy_var, paddle.Tensor):
np.testing.assert_allclose(
dy_var.numpy(), st_var.numpy(), rtol=1e-05
)
Expand All @@ -134,53 +136,101 @@ def test_nest(self):

class TestWithTrainAndEval(Dy2StTestBase):
@test_ast_only
@test_legacy_and_pt
def test_legacy_ir_switch_eval_and_train(self):
# TODO(cleanup-legacy-ir): Remove this test case
linear_net = Linear()
linear_net = paddle.jit.to_static(linear_net, full_graph=True)
x_data = np.random.random((4, 10)).astype('float32')
x = base.dygraph.to_variable(x_data)
linear_net(x)

_, train_partial_layer = linear_net.forward.program_cache.last()[-1]
# check default mode is for training
self.assertEqual(
train_partial_layer.program, train_partial_layer._train_program
)

# switch to run test program after `eval()`
linear_net.eval()
linear_net(x)
_, eval_partial_layer = linear_net.forward.program_cache.last()[-1]
self.assertEqual(
eval_partial_layer.program, eval_partial_layer._infer_program
)

# switch back into training
linear_net.train()
linear_net(x)
self.assertEqual(
train_partial_layer.program, train_partial_layer._train_program
)

@test_ast_only
@test_pir_only
def test_switch_eval_and_train(self):
with base.dygraph.guard():
linear_net = Linear()
linear_net = paddle.jit.to_static(linear_net, full_graph=True)
x_data = np.random.random((4, 10)).astype('float32')
x = base.dygraph.to_variable(x_data)
linear_net(x)
linear_net = Linear()
linear_net = paddle.jit.to_static(linear_net, full_graph=True)
x_data = np.random.random((4, 10)).astype('float32')
x = paddle.to_tensor(x_data)
linear_net(x)

_, train_partial_layer = linear_net.forward.program_cache.last()[-1]
# check default mode is for training
self.assertEqual(
train_partial_layer.program,
train_partial_layer.train_program,
)

_, train_partial_layer = linear_net.forward.program_cache.last()[-1]
# check default mode is for training
self.assertEqual(
train_partial_layer.program, train_partial_layer._train_program
)
# switch to run test program after `eval()`
linear_net.eval()
linear_net(x)
_, eval_partial_layer = linear_net.forward.program_cache.last()[-1]
self.assertEqual(
eval_partial_layer.program, eval_partial_layer.infer_program
)

# switch back into training
linear_net.train()
linear_net(x)
self.assertEqual(
train_partial_layer.program, train_partial_layer.train_program
)

# switch to run test program after `eval()`
linear_net.eval()
linear_net(x)
_, eval_partial_layer = linear_net.forward.program_cache.last()[-1]
self.assertEqual(
eval_partial_layer.program, eval_partial_layer._infer_program
)

# switch back into training
class TestWithNoGrad(Dy2StTestBase):
@test_ast_only
@test_legacy_and_pt
def test_legacy_ir_with_no_grad(self):
# TODO(cleanup-legacy-ir): Remove this test case
linear_net = Linear()
linear_net = paddle.jit.to_static(linear_net, full_graph=True)
x_data = np.random.random((5, 10)).astype('float32')
x = paddle.to_tensor(x_data)

with paddle.no_grad():
linear_net.train()
linear_net(x)
# BUG: 我们希望这里 是 ASTStaticFunction(StaticFunction):
_, partial_layer = linear_net.forward.program_cache.last()[-1]
self.assertEqual(
train_partial_layer.program, train_partial_layer._train_program
partial_layer.program, partial_layer._train_program
)


class TestWithNoGrad(Dy2StTestBase):
@test_ast_only
@test_pir_only
def test_with_no_grad(self):
with base.dygraph.guard():
linear_net = Linear()
linear_net = paddle.jit.to_static(linear_net, full_graph=True)
x_data = np.random.random((5, 10)).astype('float32')
x = base.dygraph.to_variable(x_data)

with paddle.no_grad():
linear_net.train()
linear_net(x)
# BUG: 我们希望这里 是 ASTStaticFunction(StaticFunction):
_, partial_layer = linear_net.forward.program_cache.last()[-1]
self.assertEqual(
partial_layer.program, partial_layer._train_program
)
linear_net = Linear()
linear_net = paddle.jit.to_static(linear_net, full_graph=True)
x_data = np.random.random((5, 10)).astype('float32')
x = paddle.to_tensor(x_data)

with paddle.no_grad():
linear_net.train()
linear_net(x)
# BUG: 我们希望这里 是 ASTStaticFunction(StaticFunction):
_, partial_layer = linear_net.forward.program_cache.last()[-1]
self.assertEqual(partial_layer.program, partial_layer.train_program)


class GPT2LMHeadModel(paddle.nn.Layer):
Expand All @@ -192,7 +242,6 @@ def __init__(self):
np.random.rand(2, 3).astype('float32')
)

@to_static(full_graph=True)
def forward(self, x):
x = paddle.reshape(x, shape=[-1, 6])
x1, x2, x3 = paddle.split(x=x, axis=1, num_or_sections=3)
Expand All @@ -203,13 +252,11 @@ class TestPruneUnusedParamInProgram(Dy2StTestBase):
def test_prune(self):
input_ids = np.array([[15, 11, 6, 3, 18, 13]]).astype("float32")

place = base.CPUPlace()
with base.dygraph.guard(place):
model = GPT2LMHeadModel()
model.eval()
input_ids = paddle.to_tensor(input_ids)
out = model(input_ids)
np.testing.assert_array_equal(out.numpy(), [[15, 11]])
model = paddle.jit.to_static(GPT2LMHeadModel())
model.eval()
input_ids = paddle.to_tensor(input_ids)
out = model(input_ids)
np.testing.assert_array_equal(out.numpy(), [[15, 11]])


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit f9812f6

Please sign in to comment.