Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PIR/Dy2static】fix 5 unittest- 3 yellow; 2 green #59894

Merged
merged 1 commit into from
Dec 12, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 6 additions & 4 deletions test/dygraph_to_static/test_build_strategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_default_mode_only,
test_default_and_pir,
test_legacy_and_pt_and_pir,
)
from test_resnet import ResNetHelper
Expand Down Expand Up @@ -66,7 +66,7 @@ def verify_predict(self):
err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
)

@test_default_mode_only
@test_default_and_pir
def test_resnet(self):
static_loss = self.train(to_static=True)
dygraph_loss = self.train(to_static=False)
Expand All @@ -76,9 +76,11 @@ def test_resnet(self):
rtol=1e-05,
err_msg=f'static_loss: {static_loss} \n dygraph_loss: {dygraph_loss}',
)
self.verify_predict()
# TODO(@xiongkun): open after save / load supported in pir.
if not paddle.base.framework.use_pir_api():
self.verify_predict()

@test_default_mode_only
@test_default_and_pir
def test_in_static_mode_mkldnn(self):
paddle.set_flags({'FLAGS_use_mkldnn': True})
try:
Expand Down
20 changes: 10 additions & 10 deletions test/dygraph_to_static/test_mobile_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@
import unittest

import numpy as np
from dygraph_to_static_utils import Dy2StTestBase, test_pt_only
from dygraph_to_static_utils import (
Dy2StTestBase,
test_legacy_and_pt_and_pir,
)
from predictor_utils import PredictorTools

import paddle
Expand Down Expand Up @@ -590,7 +593,8 @@ def train_mobilenet(args, to_static):
batch_id += 1
t_last = time.time()
if batch_id > args.train_step:
if to_static:
# TODO(@xiongkun): open after save / load supported in pir.
if to_static and not paddle.base.framework.use_pir_api():
paddle.jit.save(net, args.model_save_prefix)
else:
paddle.save(
Expand Down Expand Up @@ -734,20 +738,16 @@ def assert_same_predict(self, model_name):
err_msg=f'inference_pred_res:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
)

@test_pt_only
def test_mobile_net_pir(self):
# MobileNet-V1
self.assert_same_loss("MobileNetV1")
# MobileNet-V2
self.assert_same_loss("MobileNetV2")

@test_legacy_and_pt_and_pir
def test_mobile_net(self):
# MobileNet-V1
self.assert_same_loss("MobileNetV1")
# MobileNet-V2
self.assert_same_loss("MobileNetV2")

self.verify_predict()
# TODO(@xiongkun): open after save / load supported in pir.
if not paddle.base.framework.use_pir_api():
self.verify_predict()

def verify_predict(self):
# MobileNet-V1
Expand Down
18 changes: 13 additions & 5 deletions test/dygraph_to_static/test_se_resnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@
import unittest

import numpy as np
from dygraph_to_static_utils import Dy2StTestBase, test_ast_only, test_pt_only
from dygraph_to_static_utils import (
Dy2StTestBase,
test_default_and_pir,
)
from predictor_utils import PredictorTools

import paddle
Expand Down Expand Up @@ -449,7 +452,11 @@ def train(self, train_reader, to_static):

step_idx += 1
if step_idx == STEP_NUM:
if to_static:
# TODO(@xiongkun): open after save / load supported in pir.
if (
to_static
and not paddle.base.framework.use_pir_api()
):
paddle.jit.save(
se_resnext,
self.model_save_prefix,
Expand Down Expand Up @@ -565,8 +572,7 @@ def verify_predict(self):
),
)

@test_ast_only
@test_pt_only
@test_default_and_pir
def test_check_result(self):
pred_1, loss_1, acc1_1, acc5_1 = self.train(
self.train_reader, to_static=False
Expand Down Expand Up @@ -600,7 +606,9 @@ def test_check_result(self):
err_msg=f'static acc5: {acc5_1} \ndygraph acc5: {acc5_2}',
)

self.verify_predict()
# TODO(@xiongkun): open after save / load supported in pir.
if not paddle.base.framework.use_pir_api():
self.verify_predict()


if __name__ == '__main__':
Expand Down
4 changes: 2 additions & 2 deletions test/dygraph_to_static/test_tsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
test_default_mode_only,
test_default_and_pir,
)
from tsm_config_utils import merge_configs, parse_config, print_configs

Expand Down Expand Up @@ -375,7 +375,7 @@ def train(args, fake_data_reader, to_static):


class TestTsm(Dy2StTestBase):
@test_default_mode_only
@test_default_and_pir
def test_dygraph_static_same_loss(self):
if paddle.is_compiled_with_cuda():
paddle.set_flags({"FLAGS_cudnn_deterministic": True})
Expand Down
8 changes: 6 additions & 2 deletions test/dygraph_to_static/test_yolov3.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,14 +19,18 @@
import numpy as np
from dygraph_to_static_utils import (
Dy2StTestBase,
test_default_mode_only,
test_default_and_pir,
)
from yolov3 import YOLOv3, cfg

import paddle

if paddle.is_compiled_with_cuda():
paddle.base.set_flags({'FLAGS_cudnn_deterministic': True})

random.seed(0)
np.random.seed(0)
paddle.seed(0)


class SmoothedValue:
Expand Down Expand Up @@ -166,7 +170,7 @@ def train(to_static):


class TestYolov3(Dy2StTestBase):
@test_default_mode_only
@test_default_and_pir
def test_dygraph_static_same_loss(self):
dygraph_loss = train(to_static=False)
static_loss = train(to_static=True)
Expand Down