From 97f38dbefcd500da7f3ba45b8d1b8778dbc1950b Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Thu, 30 Nov 2023 14:46:30 +0000 Subject: [PATCH 1/6] add test_program_translator --- .../test_program_translator.py | 41 +++++++++++++------ 1 file changed, 28 insertions(+), 13 deletions(-) diff --git a/test/dygraph_to_static/test_program_translator.py b/test/dygraph_to_static/test_program_translator.py index a05779df7f113..739a4ff69bc2c 100644 --- a/test/dygraph_to_static/test_program_translator.py +++ b/test/dygraph_to_static/test_program_translator.py @@ -18,7 +18,11 @@ import astor import numpy as np -from dygraph_to_static_utils import Dy2StTestBase, test_ast_only +from dygraph_to_static_utils import ( + Dy2StTestBase, + test_ast_only, + test_legacy_and_pt_and_pir, +) from ifelse_simple_func import ( dyfunc_with_if_else_early_return1, dyfunc_with_if_else_early_return2, @@ -46,7 +50,6 @@ def simple_func(x, weight_numpy): return z -@to_static def decorated_simple_func(x, weight_numpy): x = base.dygraph.to_variable(x) w = base.dygraph.to_variable(weight_numpy) @@ -205,7 +208,8 @@ def false_fn_3(): class NetWithError(paddle.nn.Layer): - @to_static(full_graph=True) + __name__ = 'NetWithError' + def forward(self, x): linear = paddle.nn.Linear(32, 64) y = linear(x) @@ -218,21 +222,27 @@ def setUp(self): self.weight = np.random.randn(32, 64).astype('float32') @test_ast_only + @test_legacy_and_pt_and_pir def test_raise_error(self): with base.dygraph.guard(): paddle.jit.enable_to_static(True) - net = NetWithError() + net = to_static(full_graph=True)(NetWithError()) with self.assertRaises(ValueError): net(base.dygraph.to_variable(self.x)) + @test_legacy_and_pt_and_pir def test_enable_disable_declarative(self): paddle.jit.enable_to_static(True) with base.dygraph.guard(): - static_output = decorated_simple_func(self.x, self.weight) + static_output = to_static(decorated_simple_func)( + self.x, self.weight + ) paddle.jit.enable_to_static(False) with base.dygraph.guard(): - dygraph_output = decorated_simple_func(self.x, self.weight) + dygraph_output = to_static(decorated_simple_func)( + self.x, self.weight + ) np.testing.assert_allclose( static_output.numpy(), dygraph_output.numpy(), @@ -253,22 +263,25 @@ class SwitchModeNet(paddle.nn.Layer): def __init__(self): super().__init__() - @paddle.jit.to_static def forward(self, x): return x + 1 - @paddle.jit.to_static def foo(self): return True -@paddle.jit.to_static(full_graph=True) def switch_mode_function(): return True +switch_mode_function = paddle.jit.to_static(full_graph=True)( + switch_mode_function +) + + class TestFunctionTrainEvalMode(Dy2StTestBase): @test_ast_only + @test_legacy_and_pt_and_pir def test_switch_mode(self): paddle.disable_static() switch_mode_function.eval() @@ -283,9 +296,10 @@ def test_switch_mode(self): _, partial_layer = switch_mode_function.program_cache.last()[-1] self.assertEqual(partial_layer.training, True) + @test_legacy_and_pt_and_pir def test_raise_error(self): paddle.disable_static() - net = SwitchModeNet() + net = paddle.jit.to_static(SwitchModeNet()) self.assertEqual(net.training, True) with self.assertRaises(RuntimeError): @@ -294,7 +308,7 @@ def test_raise_error(self): net.eval() self.assertEqual(net.training, False) with self.assertRaises(RuntimeError): - net.foo.train() + paddle.jit.to_static(net.foo).train() class TestIfElseEarlyReturn(Dy2StTestBase): @@ -319,6 +333,7 @@ def func_with_comment(self): # Comment3 y = paddle.to_tensor([4, 5, 6]) + @test_legacy_and_pt_and_pir def test_remove_comment(self): code_string = func_to_source_code(self.func_with_comment) self.assertEqual('#' not in code_string, True) @@ -341,18 +356,18 @@ def __init__(self): self.layer1 = paddle.nn.Linear(10, 10) def forward(self, data): - @paddle.jit.to_static def func(ins, x, loss_fn): x = ins.layer1(x) return loss_fn(x) def func1(x): - return func(self, x, obj.func) + return paddle.jit.to_static(func)(self, x, obj.func) return func1(data) class TestParameterRecorder(Dy2StTestBase): + @test_legacy_and_pt_and_pir def test_recorder(self): """function calls nn.Layer case.""" net = Net() From eaef78a40bf934975dffed859b31f41938b72046 Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Fri, 1 Dec 2023 02:36:26 +0000 Subject: [PATCH 2/6] add test_pir_selectedrows --- .../test_pir_selectedrows.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/test/dygraph_to_static/test_pir_selectedrows.py b/test/dygraph_to_static/test_pir_selectedrows.py index 6b6daa5edfb98..792658bd8071e 100644 --- a/test/dygraph_to_static/test_pir_selectedrows.py +++ b/test/dygraph_to_static/test_pir_selectedrows.py @@ -15,7 +15,7 @@ import random import unittest -from dygraph_to_static_utils import Dy2StTestBase +from dygraph_to_static_utils import Dy2StTestBase, test_legacy_and_pt_and_pir import paddle from paddle.jit.api import to_static @@ -55,7 +55,7 @@ def forward(self, x): def train(net, adam, x): loss_data = [] - for i in range(10): + for _ in range(10): out = net(x) loss = paddle.mean(out) loss.backward() @@ -90,13 +90,18 @@ def train_static(): class TestSimnet(Dy2StTestBase): - def test_dygraph_static_same_loss(self): - dygraph_loss = train_dygraph() + def test(self): static_loss = train_static() - self.assertEqual(len(dygraph_loss), len(static_loss)) - for i in range(len(dygraph_loss)): - self.assertAlmostEqual(dygraph_loss[i], static_loss[i].numpy()) + @test_legacy_and_pt_and_pir + def test_dygraph_static_same_loss(self): + dygraph_loss = train_dygraph() + + self.assertEqual(len(dygraph_loss), len(static_loss)) + for i in range(len(dygraph_loss)): + self.assertAlmostEqual(dygraph_loss[i], static_loss[i].numpy()) + + test_dygraph_static_same_loss(self) if __name__ == '__main__': From 9332440275ab63260c72b66e84628441cb25583b Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Fri, 1 Dec 2023 09:08:21 +0000 Subject: [PATCH 3/6] roll back test_pir_selectedrows --- .../test_pir_selectedrows.py | 19 ++++++++----------- 1 file changed, 8 insertions(+), 11 deletions(-) diff --git a/test/dygraph_to_static/test_pir_selectedrows.py b/test/dygraph_to_static/test_pir_selectedrows.py index 792658bd8071e..54d0b7ce2e798 100644 --- a/test/dygraph_to_static/test_pir_selectedrows.py +++ b/test/dygraph_to_static/test_pir_selectedrows.py @@ -15,7 +15,9 @@ import random import unittest -from dygraph_to_static_utils import Dy2StTestBase, test_legacy_and_pt_and_pir +from dygraph_to_static_utils import ( + Dy2StTestBase, +) import paddle from paddle.jit.api import to_static @@ -90,18 +92,13 @@ def train_static(): class TestSimnet(Dy2StTestBase): - def test(self): + def test_dygraph_static_same_loss(self): + dygraph_loss = train_dygraph() static_loss = train_static() - @test_legacy_and_pt_and_pir - def test_dygraph_static_same_loss(self): - dygraph_loss = train_dygraph() - - self.assertEqual(len(dygraph_loss), len(static_loss)) - for i in range(len(dygraph_loss)): - self.assertAlmostEqual(dygraph_loss[i], static_loss[i].numpy()) - - test_dygraph_static_same_loss(self) + self.assertEqual(len(dygraph_loss), len(static_loss)) + for i in range(len(dygraph_loss)): + self.assertAlmostEqual(dygraph_loss[i], static_loss[i].numpy()) if __name__ == '__main__': From 5eb396aacf3441d690fc8aa06f0a09487a62c440 Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Mon, 4 Dec 2023 03:29:28 +0000 Subject: [PATCH 4/6] no need backward() --- .../test_pir_selectedrows.py | 51 +++++++------------ 1 file changed, 17 insertions(+), 34 deletions(-) diff --git a/test/dygraph_to_static/test_pir_selectedrows.py b/test/dygraph_to_static/test_pir_selectedrows.py index 54d0b7ce2e798..4b8b55f9e44f6 100644 --- a/test/dygraph_to_static/test_pir_selectedrows.py +++ b/test/dygraph_to_static/test_pir_selectedrows.py @@ -17,6 +17,7 @@ from dygraph_to_static_utils import ( Dy2StTestBase, + test_legacy_and_pt_and_pir, ) import paddle @@ -29,17 +30,9 @@ class IRSelectedRowsTestNet(paddle.nn.Layer): def __init__(self): super().__init__() - self.embedding = paddle.nn.Embedding(4, 3, sparse=False) - - w0 = paddle.to_tensor( - [ - [0.0, 0.0, 0.0], - [1.0, 1.0, 1.0], - [2.0, 2.0, 2.0], - [3.0, 3.0, 3.0], - ], - dtype="float32", - ) + self.embedding = paddle.nn.Embedding(128, 3, sparse=False) + + w0 = paddle.rand([128, 3]) self.embedding.weight.set_value(w0) self.linear = paddle.nn.Linear( @@ -55,50 +48,40 @@ def forward(self, x): return x -def train(net, adam, x): +def forward(net, x): loss_data = [] for _ in range(10): out = net(x) loss = paddle.mean(out) - loss.backward() - adam.step() - adam.clear_grad() loss_data.append(loss.numpy()) return loss_data -def train_dygraph(): +def forward_dygraph(): paddle.seed(100) net = IRSelectedRowsTestNet() - x = paddle.to_tensor([[0], [1], [3]], dtype="int64", stop_gradient=False) - clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0) - adam = paddle.optimizer.Adam( - parameters=net.parameters(), learning_rate=0.01, grad_clip=clip - ) + x = paddle.randint(low=0, high=128, shape=[64], dtype="int64") - return train(net, adam, x) + return forward(net, x) -def train_static(): +def forward_static(): paddle.seed(100) net = IRSelectedRowsTestNet() - x = paddle.to_tensor([[0], [1], [3]], dtype="int64", stop_gradient=False) - clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0) - adam = paddle.optimizer.Adam( - parameters=net.parameters(), learning_rate=0.01, grad_clip=clip - ) + x = paddle.randint(low=0, high=128, shape=[64], dtype="int64") - return to_static(train, full_graph=True)(net, adam, x) + return to_static(forward, full_graph=True)(net, x) class TestSimnet(Dy2StTestBase): + @test_legacy_and_pt_and_pir def test_dygraph_static_same_loss(self): - dygraph_loss = train_dygraph() - static_loss = train_static() + dygraph_value = forward_dygraph() + static_value = forward_static() - self.assertEqual(len(dygraph_loss), len(static_loss)) - for i in range(len(dygraph_loss)): - self.assertAlmostEqual(dygraph_loss[i], static_loss[i].numpy()) + self.assertEqual(len(dygraph_value), len(static_value)) + for i in range(len(dygraph_value)): + self.assertAlmostEqual(dygraph_value[i], static_value[i].numpy()) if __name__ == '__main__': From da7acff4377ab25764ac4584a15ccb2e62167f2d Mon Sep 17 00:00:00 2001 From: SigureMo Date: Mon, 11 Dec 2023 11:36:32 +0000 Subject: [PATCH 5/6] update test_bert --- test/dygraph_to_static/test_bert.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/test/dygraph_to_static/test_bert.py b/test/dygraph_to_static/test_bert.py index f9d8620956f24..b4358ec07ce54 100644 --- a/test/dygraph_to_static/test_bert.py +++ b/test/dygraph_to_static/test_bert.py @@ -23,8 +23,6 @@ from dygraph_to_static_utils import ( Dy2StTestBase, enable_to_static_guard, - test_ast_only, - test_pt_only, test_sot_only, ) from predictor_utils import PredictorTools @@ -269,18 +267,6 @@ def predict_analysis_inference(self, data): out = output() return out - @test_pt_only - def test_train_pir(self): - static_loss, static_ppl = self.train_static( - self.bert_config, self.data_reader - ) - dygraph_loss, dygraph_ppl = self.train_dygraph( - self.bert_config, self.data_reader - ) - np.testing.assert_allclose(static_loss, dygraph_loss, rtol=1e-05) - np.testing.assert_allclose(static_ppl, dygraph_ppl, rtol=1e-05) - - @test_ast_only def test_train(self): static_loss, static_ppl = self.train_static( self.bert_config, self.data_reader From cecb51bf7bf9a81cca1994e07b4188cb18e5f518 Mon Sep 17 00:00:00 2001 From: SigureMo Date: Mon, 11 Dec 2023 11:52:00 +0000 Subject: [PATCH 6/6] update test_program_translator --- .../test_program_translator.py | 33 +++++++------------ 1 file changed, 12 insertions(+), 21 deletions(-) diff --git a/test/dygraph_to_static/test_program_translator.py b/test/dygraph_to_static/test_program_translator.py index ce98f5d987f51..73e1d457882d4 100644 --- a/test/dygraph_to_static/test_program_translator.py +++ b/test/dygraph_to_static/test_program_translator.py @@ -23,6 +23,7 @@ IrMode, ToStaticMode, disable_test_case, + enable_to_static_guard, test_ast_only, test_legacy_and_pt_and_pir, test_legacy_only, @@ -34,7 +35,6 @@ import paddle import paddle.jit.dy2static as _jst -from paddle import base from paddle.jit.api import to_static from paddle.jit.dy2static.utils import func_to_source_code from paddle.utils import gast @@ -42,21 +42,21 @@ np.random.seed(0) -# TODO(Aurelius): Currently, `declarative` don't support decorate the function +# TODO(Aurelius): Currently, `to_static` don't support decorate the function # that contains layers with initialized operation, like `fc = linear(10, 3)`. # Because initialized ops will be added into program and be executed many times. # The parameters are assumed to initialized outside of the function. def simple_func(x, weight_numpy): - x = base.dygraph.to_variable(x) - w = base.dygraph.to_variable(weight_numpy) + x = paddle.to_tensor(x) + w = paddle.to_tensor(weight_numpy) y = paddle.matmul(x, w) z = paddle.mean(y) return z def decorated_simple_func(x, weight_numpy): - x = base.dygraph.to_variable(x) - w = base.dygraph.to_variable(weight_numpy) + x = paddle.to_tensor(x) + w = paddle.to_tensor(weight_numpy) y = paddle.matmul(x, w) z = paddle.mean(y) return z @@ -228,22 +228,15 @@ def setUp(self): @test_ast_only @test_legacy_and_pt_and_pir def test_raise_error(self): - with base.dygraph.guard(): - paddle.jit.enable_to_static(True) - net = to_static(full_graph=True)(NetWithError()) - with self.assertRaises(ValueError): - net(base.dygraph.to_variable(self.x)) + net = to_static(full_graph=True)(NetWithError()) + with self.assertRaises(ValueError): + net(paddle.to_tensor(self.x)) @test_legacy_and_pt_and_pir - def test_enable_disable_declarative(self): - paddle.jit.enable_to_static(True) - with base.dygraph.guard(): - static_output = to_static(decorated_simple_func)( - self.x, self.weight - ) + def test_enable_disable_to_static(self): + static_output = to_static(decorated_simple_func)(self.x, self.weight) - paddle.jit.enable_to_static(False) - with base.dygraph.guard(): + with enable_to_static_guard(False): dygraph_output = to_static(decorated_simple_func)( self.x, self.weight ) @@ -287,7 +280,6 @@ class TestFunctionTrainEvalMode(Dy2StTestBase): @test_ast_only @test_legacy_and_pt_and_pir def test_switch_mode(self): - paddle.disable_static() switch_mode_function.eval() switch_mode_function() self.assertEqual(switch_mode_function._training, False) @@ -302,7 +294,6 @@ def test_switch_mode(self): @test_legacy_and_pt_and_pir def test_raise_error(self): - paddle.disable_static() net = paddle.jit.to_static(SwitchModeNet()) self.assertEqual(net.training, True)