[Dy2St] pir dy2st unittest verification - Part 3 #59571

Merged: 9 commits, Dec 12, 2023
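The migrations below replace one-off PIR copies of tests with the shared mode-sweeping decorators from dygraph_to_static_utils, so a single test body runs under the legacy IR, PT (PIR-translated), and native PIR execution modes. As a rough mental model (a sketch only, with assumed names; the real decorators in test/dygraph_to_static/dygraph_to_static_utils.py cooperate with Dy2StTestBase and are more elaborate), such a decorator tags the test with the modes it should sweep:

# Illustrative sketch only: how a mode-sweeping test decorator can work.
# IrMode's members and the attribute protocol here are assumptions for
# this sketch, not Paddle's actual implementation.
from enum import Flag, auto


class IrMode(Flag):
    LEGACY_IR = auto()
    PT = auto()   # PIR program translated from the legacy program
    PIR = auto()  # native PIR program


def test_legacy_and_pt_and_pir(fn):
    # Tag the test; a Dy2StTestBase-style base class can read this tag
    # and generate one concrete test case per enabled IR mode.
    fn.ir_modes = IrMode.LEGACY_IR | IrMode.PT | IrMode.PIR
    return fn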
test/dygraph_to_static/test_bert.py: 0 additions, 14 deletions
@@ -23,8 +23,6 @@
 from dygraph_to_static_utils import (
     Dy2StTestBase,
     enable_to_static_guard,
-    test_ast_only,
-    test_pt_only,
     test_sot_only,
 )
 from predictor_utils import PredictorTools
@@ -269,18 +267,6 @@ def predict_analysis_inference(self, data):
         out = output()
         return out
 
-    @test_pt_only
-    def test_train_pir(self):
-        static_loss, static_ppl = self.train_static(
-            self.bert_config, self.data_reader
-        )
-        dygraph_loss, dygraph_ppl = self.train_dygraph(
-            self.bert_config, self.data_reader
-        )
-        np.testing.assert_allclose(static_loss, dygraph_loss, rtol=1e-05)
-        np.testing.assert_allclose(static_ppl, dygraph_ppl, rtol=1e-05)
-
-    @test_ast_only
     def test_train(self):
         static_loss, static_ppl = self.train_static(
             self.bert_config, self.data_reader
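The deleted test_train_pir was a PIR-only duplicate of test_train; once the suite is swept across IR modes by the shared decorators, the copy is redundant. For orientation, a single-mode gate such as test_pt_only can be approximated as a conditional skip (a sketch with an assumed flag name; Paddle's real helper keys off the framework's active IrMode rather than an environment variable):

import os
import unittest


def test_pt_only_sketch(fn):
    # Run the test only when the (assumed) PIR-executor flag is set.
    return unittest.skipUnless(
        os.environ.get("FLAGS_enable_pir_in_executor") == "1",
        "runs only under the PT (PIR-translated) mode",
    )(fn)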
test/dygraph_to_static/test_pir_selectedrows.py: 21 additions, 36 deletions
@@ -15,7 +15,10 @@
 import random
 import unittest
 
-from dygraph_to_static_utils import Dy2StTestBase
+from dygraph_to_static_utils import (
+    Dy2StTestBase,
+    test_legacy_and_pt_and_pir,
+)
 
 import paddle
 from paddle.jit.api import to_static
@@ -27,17 +30,9 @@
 class IRSelectedRowsTestNet(paddle.nn.Layer):
     def __init__(self):
         super().__init__()
-        self.embedding = paddle.nn.Embedding(4, 3, sparse=False)
-
-        w0 = paddle.to_tensor(
-            [
-                [0.0, 0.0, 0.0],
-                [1.0, 1.0, 1.0],
-                [2.0, 2.0, 2.0],
-                [3.0, 3.0, 3.0],
-            ],
-            dtype="float32",
-        )
+        self.embedding = paddle.nn.Embedding(128, 3, sparse=False)
+
+        w0 = paddle.rand([128, 3])
         self.embedding.weight.set_value(w0)
 
         self.linear = paddle.nn.Linear(
@@ -53,50 +48,40 @@ def forward(self, x):
         return x
 
 
-def train(net, adam, x):
+def forward(net, x):
     loss_data = []
-    for i in range(10):
+    for _ in range(10):
        out = net(x)
         loss = paddle.mean(out)
-        loss.backward()
-        adam.step()
-        adam.clear_grad()
         loss_data.append(loss.numpy())
     return loss_data
 
 
-def train_dygraph():
+def forward_dygraph():
     paddle.seed(100)
     net = IRSelectedRowsTestNet()
-    x = paddle.to_tensor([[0], [1], [3]], dtype="int64", stop_gradient=False)
-    clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
-    adam = paddle.optimizer.Adam(
-        parameters=net.parameters(), learning_rate=0.01, grad_clip=clip
-    )
+    x = paddle.randint(low=0, high=128, shape=[64], dtype="int64")
 
-    return train(net, adam, x)
+    return forward(net, x)
 
 
-def train_static():
+def forward_static():
     paddle.seed(100)
     net = IRSelectedRowsTestNet()
-    x = paddle.to_tensor([[0], [1], [3]], dtype="int64", stop_gradient=False)
-    clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
-    adam = paddle.optimizer.Adam(
-        parameters=net.parameters(), learning_rate=0.01, grad_clip=clip
-    )
+    x = paddle.randint(low=0, high=128, shape=[64], dtype="int64")
 
-    return to_static(train, full_graph=True)(net, adam, x)
+    return to_static(forward, full_graph=True)(net, x)
 
 
 class TestSimnet(Dy2StTestBase):
+    @test_legacy_and_pt_and_pir
     def test_dygraph_static_same_loss(self):
-        dygraph_loss = train_dygraph()
-        static_loss = train_static()
+        dygraph_value = forward_dygraph()
+        static_value = forward_static()
 
-        self.assertEqual(len(dygraph_loss), len(static_loss))
-        for i in range(len(dygraph_loss)):
-            self.assertAlmostEqual(dygraph_loss[i], static_loss[i].numpy())
+        self.assertEqual(len(dygraph_value), len(static_value))
+        for i in range(len(dygraph_value)):
+            self.assertAlmostEqual(dygraph_value[i], static_value[i].numpy())
 
 
 if __name__ == '__main__':
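The rewrite above drops the optimizer loop and compares pure forward outputs: both paths seed Paddle identically, so the randomly initialized Embedding table and the random int64 indices match between the eager run and the to_static run. Distilled, the comparison pattern looks roughly like this (a self-contained sketch using only public paddle APIs; the helper name and rtol are illustrative choices, and the real test additionally sweeps IR modes via test_legacy_and_pt_and_pir):

import numpy as np
import paddle


def check_forward_consistency(make_net, make_input):
    # Seed, then build the net and input, so parameters and data are
    # identical across the two runs.
    paddle.seed(100)
    eager_out = make_net()(make_input())

    paddle.seed(100)
    static_net = paddle.jit.to_static(make_net(), full_graph=True)
    static_out = static_net(make_input())

    np.testing.assert_allclose(
        eager_out.numpy(), static_out.numpy(), rtol=1e-5
    )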
test/dygraph_to_static/test_program_translator.py: 34 additions, 31 deletions
@@ -23,7 +23,9 @@
     IrMode,
     ToStaticMode,
     disable_test_case,
+    enable_to_static_guard,
     test_ast_only,
+    test_legacy_and_pt_and_pir,
     test_legacy_only,
 )
 from ifelse_simple_func import (
@@ -33,30 +35,28 @@
 
 import paddle
 import paddle.jit.dy2static as _jst
-from paddle import base
 from paddle.jit.api import to_static
 from paddle.jit.dy2static.utils import func_to_source_code
 from paddle.utils import gast
 
 np.random.seed(0)
 
 
-# TODO(Aurelius): Currently, `declarative` don't support decorate the function
+# TODO(Aurelius): Currently, `to_static` don't support decorate the function
 # that contains layers with initialized operation, like `fc = linear(10, 3)`.
 # Because initialized ops will be added into program and be executed many times.
 # The parameters are assumed to initialized outside of the function.
 def simple_func(x, weight_numpy):
-    x = base.dygraph.to_variable(x)
-    w = base.dygraph.to_variable(weight_numpy)
+    x = paddle.to_tensor(x)
+    w = paddle.to_tensor(weight_numpy)
     y = paddle.matmul(x, w)
     z = paddle.mean(y)
     return z
 
 
-@to_static
 def decorated_simple_func(x, weight_numpy):
-    x = base.dygraph.to_variable(x)
-    w = base.dygraph.to_variable(weight_numpy)
+    x = paddle.to_tensor(x)
+    w = paddle.to_tensor(weight_numpy)
     y = paddle.matmul(x, w)
     z = paddle.mean(y)
     return z
@@ -212,7 +212,8 @@ def false_fn_3():
 
 
 class NetWithError(paddle.nn.Layer):
-    @to_static(full_graph=True)
+    __name__ = 'NetWithError'
+
     def forward(self, x):
         linear = paddle.nn.Linear(32, 64)
         y = linear(x)
@@ -225,21 +226,20 @@ def setUp(self):
         self.weight = np.random.randn(32, 64).astype('float32')
 
     @test_ast_only
+    @test_legacy_and_pt_and_pir
     def test_raise_error(self):
-        with base.dygraph.guard():
-            paddle.jit.enable_to_static(True)
-            net = NetWithError()
-            with self.assertRaises(ValueError):
-                net(base.dygraph.to_variable(self.x))
-
-    def test_enable_disable_declarative(self):
-        paddle.jit.enable_to_static(True)
-        with base.dygraph.guard():
-            static_output = decorated_simple_func(self.x, self.weight)
-
-        paddle.jit.enable_to_static(False)
-        with base.dygraph.guard():
-            dygraph_output = decorated_simple_func(self.x, self.weight)
+        net = to_static(full_graph=True)(NetWithError())
+        with self.assertRaises(ValueError):
+            net(paddle.to_tensor(self.x))
+
+    @test_legacy_and_pt_and_pir
+    def test_enable_disable_to_static(self):
+        static_output = to_static(decorated_simple_func)(self.x, self.weight)
+
+        with enable_to_static_guard(False):
+            dygraph_output = to_static(decorated_simple_func)(
+                self.x, self.weight
+            )
         np.testing.assert_allclose(
             static_output.numpy(),
             dygraph_output.numpy(),
@@ -260,24 +260,26 @@ class SwitchModeNet(paddle.nn.Layer):
     def __init__(self):
         super().__init__()
 
-    @paddle.jit.to_static
     def forward(self, x):
         return x + 1
 
-    @paddle.jit.to_static
     def foo(self):
         return True
 
 
-@paddle.jit.to_static(full_graph=True)
 def switch_mode_function():
     return True
 
 
+switch_mode_function = paddle.jit.to_static(full_graph=True)(
+    switch_mode_function
+)
+
+
 class TestFunctionTrainEvalMode(Dy2StTestBase):
     @test_ast_only
+    @test_legacy_and_pt_and_pir
     def test_switch_mode(self):
-        paddle.disable_static()
         switch_mode_function.eval()
         switch_mode_function()
         self.assertEqual(switch_mode_function._training, False)
@@ -290,9 +292,9 @@ def test_switch_mode(self):
         _, partial_layer = switch_mode_function.program_cache.last()[-1]
         self.assertEqual(partial_layer.training, True)
 
+    @test_legacy_and_pt_and_pir
     def test_raise_error(self):
-        paddle.disable_static()
-        net = SwitchModeNet()
+        net = paddle.jit.to_static(SwitchModeNet())
 
         self.assertEqual(net.training, True)
         with self.assertRaises(RuntimeError):
@@ -301,7 +303,7 @@ def test_raise_error(self):
         net.eval()
         self.assertEqual(net.training, False)
         with self.assertRaises(RuntimeError):
-            net.foo.train()
+            paddle.jit.to_static(net.foo).train()
 
 
 class TestIfElseEarlyReturn(Dy2StTestBase):
@@ -329,6 +331,7 @@ def func_with_comment(self):
         # Comment3
         y = paddle.to_tensor([4, 5, 6])
 
+    @test_legacy_and_pt_and_pir
     def test_remove_comment(self):
        code_string = func_to_source_code(self.func_with_comment)
         self.assertEqual('#' not in code_string, True)
@@ -351,18 +354,18 @@ def __init__(self):
         self.layer1 = paddle.nn.Linear(10, 10)
 
     def forward(self, data):
-        @paddle.jit.to_static
         def func(ins, x, loss_fn):
             x = ins.layer1(x)
             return loss_fn(x)
 
         def func1(x):
-            return func(self, x, obj.func)
+            return paddle.jit.to_static(func)(self, x, obj.func)
 
         return func1(data)
 
 
 class TestParameterRecorder(Dy2StTestBase):
+    @test_legacy_and_pt_and_pir
     def test_recorder(self):
         """function calls nn.Layer case."""
         net = Net()
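Worth noting in the last file: the paired paddle.jit.enable_to_static(True)/enable_to_static(False) calls are replaced by a scoped enable_to_static_guard(False), which restores the switch on exit and so cannot leak a disabled dy2st state into later tests if an assertion fails mid-test. A plausible shape for the guard (a sketch; the actual helper in dygraph_to_static_utils.py captures and restores the previously active value rather than assuming a default):

from contextlib import contextmanager

import paddle


@contextmanager
def enable_to_static_guard(flag: bool, restore_to: bool = True):
    # Flip the global dy2st switch for the duration of the with-block.
    paddle.jit.enable_to_static(flag)
    try:
        yield
    finally:
        # Simplification: restore an assumed default (enabled) instead of
        # the real prior value, which the actual utility records first.
        paddle.jit.enable_to_static(restore_to)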