diff --git a/test/dygraph_to_static/test_break_continue.py b/test/dygraph_to_static/test_break_continue.py index ef5efd047e247..8fed82dc91c45 100644 --- a/test/dygraph_to_static/test_break_continue.py +++ b/test/dygraph_to_static/test_break_continue.py @@ -19,7 +19,6 @@ Dy2StTestBase, enable_to_static_guard, test_ast_only, - test_legacy_and_pt, test_legacy_and_pt_and_pir, ) @@ -37,6 +36,7 @@ def setUp(self): self.error = "Your if/else have different number of return value." @test_ast_only + @test_legacy_and_pt_and_pir def test_error(self): if self.dyfunc: with self.assertRaisesRegex(Dygraph2StaticException, self.error): @@ -241,8 +241,8 @@ def test_transformed_static_result(self): ) -# TODO(pir-control-flow): Fix this after we support control-flow in PIR class TestContinueNotPirBase(TestContinueInFor): + @test_legacy_and_pt_and_pir def test_transformed_static_result(self): self.init_dygraph_func() dygraph_res = self.run_dygraph_mode() @@ -355,7 +355,6 @@ def init_dygraph_func(self): self.dygraph_func = test_optim_break_in_while # TODO: Open PIR test when while_loop dy2st fixed - @test_legacy_and_pt def test_transformed_static_result(self): self.init_dygraph_func() dygraph_res = self.run_dygraph_mode() diff --git a/test/dygraph_to_static/test_jit_setitem.py b/test/dygraph_to_static/test_jit_setitem.py index 5dbce38790300..1b36323bab99c 100644 --- a/test/dygraph_to_static/test_jit_setitem.py +++ b/test/dygraph_to_static/test_jit_setitem.py @@ -202,6 +202,7 @@ def run_dygraph(self, func): y = func() return (y,) + @test_legacy_and_pt_and_pir def test_case(self): func = self.init_func() dy_res = self.run_dygraph(func) diff --git a/test/dygraph_to_static/test_tensor_shape.py b/test/dygraph_to_static/test_tensor_shape.py index bbfe6703b8246..01315d328397b 100644 --- a/test/dygraph_to_static/test_tensor_shape.py +++ b/test/dygraph_to_static/test_tensor_shape.py @@ -230,6 +230,32 @@ def dyfunc_dict_assign_shape(): a['shape'] = x.shape[0] +def walk(block, fn): + fn(block) + for op in block.ops: + for sub_block in op.blocks(): + walk(sub_block, fn) + + +def get_op_num_in_block(block, op_name): + num_ops = 0 + for op in block.ops: + if op.name() == op_name: + num_ops += 1 + return num_ops + + +def get_op_num_in_program(program, op_name): + num_ops = 0 + + def _calc_op_num(block): + nonlocal num_ops + num_ops += get_op_num_in_block(block, op_name) + + walk(program.global_block(), _calc_op_num) + return num_ops + + # 1. Basic tests without control flow class TestTensorShapeBasic(Dy2StTestBase): def setUp(self): @@ -277,7 +303,7 @@ def _set_expected_op_num(self): def _set_pir_expected_op_num(self): self.pir_expected_op_num = 3 - self.pir_expected_shape_op_num = 1 + self.pir_expected_shape_op_num = 0 self.pir_expected_slice_op_num = 0 def _compute_op_num(self, program): @@ -292,23 +318,8 @@ def _compute_op_num(self, program): def _compute_pir_op_num(self, program): op_num = program.global_block().num_ops() - shape_op_num = 0 - slice_op_num = 0 - - shape_op_num += len( - [ - op - for op in program.global_block().ops - if op.name() == "pd_op.reshape" - ] - ) - slice_op_num += len( - [ - op - for op in program.global_block().ops - if op.name() == "pd_op.slice" - ] - ) + shape_op_num = get_op_num_in_program(program, "pd_op.shape") + slice_op_num = get_op_num_in_program(program, "pd_op.slice") return op_num, shape_op_num, slice_op_num @test_ast_only @@ -342,7 +353,7 @@ def _set_expected_op_num(self): def _set_pir_expected_op_num(self): self.pir_expected_op_num = 3 - self.pir_expected_shape_op_num = 1 + self.pir_expected_shape_op_num = 0 self.pir_expected_slice_op_num = 0 @@ -357,7 +368,7 @@ def _set_expected_op_num(self): def _set_pir_expected_op_num(self): self.pir_expected_op_num = 4 - self.pir_expected_shape_op_num = 1 + self.pir_expected_shape_op_num = 0 self.pir_expected_slice_op_num = 0 @@ -377,7 +388,7 @@ def _set_expected_op_num(self): def _set_pir_expected_op_num(self): self.pir_expected_op_num = 3 - self.pir_expected_shape_op_num = 1 + self.pir_expected_shape_op_num = 0 self.pir_expected_slice_op_num = 0 @@ -392,7 +403,7 @@ def _set_expected_op_num(self): def _set_pir_expected_op_num(self): self.pir_expected_op_num = 3 - self.pir_expected_shape_op_num = 1 + self.pir_expected_shape_op_num = 0 self.pir_expected_slice_op_num = 0 @@ -464,7 +475,7 @@ def _set_expected_op_num(self): def _set_pir_expected_op_num(self): self.pir_expected_op_num = 14 - self.pir_expected_shape_op_num = 1 + self.pir_expected_shape_op_num = 2 self.pir_expected_slice_op_num = 2 @@ -480,7 +491,7 @@ def _set_expected_op_num(self): def _set_pir_expected_op_num(self): self.pir_expected_op_num = 3 - self.pir_expected_shape_op_num = 1 + self.pir_expected_shape_op_num = 0 self.pir_expected_slice_op_num = 0 @@ -647,23 +658,8 @@ def _compute_op_num(self, program): def _compute_pir_op_num(self, program): op_num = program.global_block().num_ops() - shape_op_num = 0 - slice_op_num = 0 - - shape_op_num += len( - [ - op - for op in program.global_block().ops - if op.name() == "pd_op.reshape" - ] - ) - slice_op_num += len( - [ - op - for op in program.global_block().ops - if op.name() == "pd_op.slice" - ] - ) + shape_op_num = get_op_num_in_program(program, "pd_op.shape") + slice_op_num = get_op_num_in_program(program, "pd_op.slice") return op_num, shape_op_num, slice_op_num @test_ast_only @@ -698,7 +694,7 @@ def _set_expected_op_num(self): def _set_pir_expected_op_num(self): self.pir_expected_op_num = 15 - self.pir_expected_shape_op_num = 1 + self.pir_expected_shape_op_num = 2 self.pir_expected_slice_op_num = 2 @@ -728,8 +724,8 @@ def _set_expected_op_num(self): def _set_pir_expected_op_num(self): self.pir_expected_op_num = 41 - self.pir_expected_shape_op_num = 1 - self.pir_expected_slice_op_num = 1 + self.pir_expected_shape_op_num = 4 + self.pir_expected_slice_op_num = 4 class TestOpNumWithTensorShapeInFor1(TestOpNumBasicWithTensorShape): @@ -742,15 +738,9 @@ def _set_expected_op_num(self): self.expected_slice_op_num = 3 def _set_pir_expected_op_num(self): - self.pir_expected_op_num = 2 - self.pir_expected_shape_op_num = 0 - self.pir_expected_slice_op_num = 0 - - @test_ast_only - @test_pir_only - def test_pir_op_num(self): - # Remove this after we support control flow - pass + self.pir_expected_op_num = 35 + self.pir_expected_shape_op_num = 2 + self.pir_expected_slice_op_num = 3 class TestOpNumWithTensorShapeInWhile1(TestOpNumBasicWithTensorShape): @@ -764,8 +754,8 @@ def _set_expected_op_num(self): def _set_pir_expected_op_num(self): self.pir_expected_op_num = 27 - self.pir_expected_shape_op_num = 0 - self.pir_expected_slice_op_num = 2 + self.pir_expected_shape_op_num = 3 + self.pir_expected_slice_op_num = 3 class TestChangeShapeAfterAssign(TestTensorShapeBasic): @@ -783,7 +773,7 @@ def _set_expected_op_num(self): def _set_pir_expected_op_num(self): self.pir_expected_op_num = 12 - self.pir_expected_shape_op_num = 2 + self.pir_expected_shape_op_num = 1 self.pir_expected_slice_op_num = 1 diff --git a/test/dygraph_to_static/test_write_python_container.py b/test/dygraph_to_static/test_write_python_container.py index ba9de76351a00..c7960e1cc87ca 100644 --- a/test/dygraph_to_static/test_write_python_container.py +++ b/test/dygraph_to_static/test_write_python_container.py @@ -14,7 +14,12 @@ import unittest -from dygraph_to_static_utils import Dy2StTestBase, test_ast_only, test_sot_only +from dygraph_to_static_utils import ( + Dy2StTestBase, + test_ast_only, + test_legacy_and_pt_and_pir, + test_sot_only, +) import paddle @@ -113,6 +118,7 @@ def get_raw_value(self, container, getitem_path): return out @test_sot_only + @test_legacy_and_pt_and_pir def test_write_container_sot(self): func_static = paddle.jit.to_static(self.func) input = paddle.to_tensor([1, 2, 3])