Skip to content

Commit

Permalink
[Dy2St][PIR] Enable some dy2st control flow unittests (#61370)
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo authored Feb 1, 2024
1 parent 20f87c4 commit be7e923
Show file tree
Hide file tree
Showing 4 changed files with 56 additions and 60 deletions.
5 changes: 2 additions & 3 deletions test/dygraph_to_static/test_break_continue.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
Dy2StTestBase,
enable_to_static_guard,
test_ast_only,
test_legacy_and_pt,
test_legacy_and_pt_and_pir,
)

Expand All @@ -37,6 +36,7 @@ def setUp(self):
self.error = "Your if/else have different number of return value."

@test_ast_only
@test_legacy_and_pt_and_pir
def test_error(self):
if self.dyfunc:
with self.assertRaisesRegex(Dygraph2StaticException, self.error):
Expand Down Expand Up @@ -241,8 +241,8 @@ def test_transformed_static_result(self):
)


# TODO(pir-control-flow): Fix this after we support control-flow in PIR
class TestContinueNotPirBase(TestContinueInFor):
@test_legacy_and_pt_and_pir
def test_transformed_static_result(self):
self.init_dygraph_func()
dygraph_res = self.run_dygraph_mode()
Expand Down Expand Up @@ -355,7 +355,6 @@ def init_dygraph_func(self):
self.dygraph_func = test_optim_break_in_while

# TODO: Open PIR test when while_loop dy2st fixed
@test_legacy_and_pt
def test_transformed_static_result(self):
self.init_dygraph_func()
dygraph_res = self.run_dygraph_mode()
Expand Down
1 change: 1 addition & 0 deletions test/dygraph_to_static/test_jit_setitem.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ def run_dygraph(self, func):
y = func()
return (y,)

@test_legacy_and_pt_and_pir
def test_case(self):
func = self.init_func()
dy_res = self.run_dygraph(func)
Expand Down
102 changes: 46 additions & 56 deletions test/dygraph_to_static/test_tensor_shape.py
Original file line number Diff line number Diff line change
Expand Up @@ -230,6 +230,32 @@ def dyfunc_dict_assign_shape():
a['shape'] = x.shape[0]


def walk(block, fn):
fn(block)
for op in block.ops:
for sub_block in op.blocks():
walk(sub_block, fn)


def get_op_num_in_block(block, op_name):
num_ops = 0
for op in block.ops:
if op.name() == op_name:
num_ops += 1
return num_ops


def get_op_num_in_program(program, op_name):
num_ops = 0

def _calc_op_num(block):
nonlocal num_ops
num_ops += get_op_num_in_block(block, op_name)

walk(program.global_block(), _calc_op_num)
return num_ops


# 1. Basic tests without control flow
class TestTensorShapeBasic(Dy2StTestBase):
def setUp(self):
Expand Down Expand Up @@ -277,7 +303,7 @@ def _set_expected_op_num(self):

def _set_pir_expected_op_num(self):
self.pir_expected_op_num = 3
self.pir_expected_shape_op_num = 1
self.pir_expected_shape_op_num = 0
self.pir_expected_slice_op_num = 0

def _compute_op_num(self, program):
Expand All @@ -292,23 +318,8 @@ def _compute_op_num(self, program):

def _compute_pir_op_num(self, program):
op_num = program.global_block().num_ops()
shape_op_num = 0
slice_op_num = 0

shape_op_num += len(
[
op
for op in program.global_block().ops
if op.name() == "pd_op.reshape"
]
)
slice_op_num += len(
[
op
for op in program.global_block().ops
if op.name() == "pd_op.slice"
]
)
shape_op_num = get_op_num_in_program(program, "pd_op.shape")
slice_op_num = get_op_num_in_program(program, "pd_op.slice")
return op_num, shape_op_num, slice_op_num

@test_ast_only
Expand Down Expand Up @@ -342,7 +353,7 @@ def _set_expected_op_num(self):

def _set_pir_expected_op_num(self):
self.pir_expected_op_num = 3
self.pir_expected_shape_op_num = 1
self.pir_expected_shape_op_num = 0
self.pir_expected_slice_op_num = 0


Expand All @@ -357,7 +368,7 @@ def _set_expected_op_num(self):

def _set_pir_expected_op_num(self):
self.pir_expected_op_num = 4
self.pir_expected_shape_op_num = 1
self.pir_expected_shape_op_num = 0
self.pir_expected_slice_op_num = 0


Expand All @@ -377,7 +388,7 @@ def _set_expected_op_num(self):

def _set_pir_expected_op_num(self):
self.pir_expected_op_num = 3
self.pir_expected_shape_op_num = 1
self.pir_expected_shape_op_num = 0
self.pir_expected_slice_op_num = 0


Expand All @@ -392,7 +403,7 @@ def _set_expected_op_num(self):

def _set_pir_expected_op_num(self):
self.pir_expected_op_num = 3
self.pir_expected_shape_op_num = 1
self.pir_expected_shape_op_num = 0
self.pir_expected_slice_op_num = 0


Expand Down Expand Up @@ -464,7 +475,7 @@ def _set_expected_op_num(self):

def _set_pir_expected_op_num(self):
self.pir_expected_op_num = 14
self.pir_expected_shape_op_num = 1
self.pir_expected_shape_op_num = 2
self.pir_expected_slice_op_num = 2


Expand All @@ -480,7 +491,7 @@ def _set_expected_op_num(self):

def _set_pir_expected_op_num(self):
self.pir_expected_op_num = 3
self.pir_expected_shape_op_num = 1
self.pir_expected_shape_op_num = 0
self.pir_expected_slice_op_num = 0


Expand Down Expand Up @@ -647,23 +658,8 @@ def _compute_op_num(self, program):

def _compute_pir_op_num(self, program):
op_num = program.global_block().num_ops()
shape_op_num = 0
slice_op_num = 0

shape_op_num += len(
[
op
for op in program.global_block().ops
if op.name() == "pd_op.reshape"
]
)
slice_op_num += len(
[
op
for op in program.global_block().ops
if op.name() == "pd_op.slice"
]
)
shape_op_num = get_op_num_in_program(program, "pd_op.shape")
slice_op_num = get_op_num_in_program(program, "pd_op.slice")
return op_num, shape_op_num, slice_op_num

@test_ast_only
Expand Down Expand Up @@ -698,7 +694,7 @@ def _set_expected_op_num(self):

def _set_pir_expected_op_num(self):
self.pir_expected_op_num = 15
self.pir_expected_shape_op_num = 1
self.pir_expected_shape_op_num = 2
self.pir_expected_slice_op_num = 2


Expand Down Expand Up @@ -728,8 +724,8 @@ def _set_expected_op_num(self):

def _set_pir_expected_op_num(self):
self.pir_expected_op_num = 41
self.pir_expected_shape_op_num = 1
self.pir_expected_slice_op_num = 1
self.pir_expected_shape_op_num = 4
self.pir_expected_slice_op_num = 4


class TestOpNumWithTensorShapeInFor1(TestOpNumBasicWithTensorShape):
Expand All @@ -742,15 +738,9 @@ def _set_expected_op_num(self):
self.expected_slice_op_num = 3

def _set_pir_expected_op_num(self):
self.pir_expected_op_num = 2
self.pir_expected_shape_op_num = 0
self.pir_expected_slice_op_num = 0

@test_ast_only
@test_pir_only
def test_pir_op_num(self):
# Remove this after we support control flow
pass
self.pir_expected_op_num = 35
self.pir_expected_shape_op_num = 2
self.pir_expected_slice_op_num = 3


class TestOpNumWithTensorShapeInWhile1(TestOpNumBasicWithTensorShape):
Expand All @@ -764,8 +754,8 @@ def _set_expected_op_num(self):

def _set_pir_expected_op_num(self):
self.pir_expected_op_num = 27
self.pir_expected_shape_op_num = 0
self.pir_expected_slice_op_num = 2
self.pir_expected_shape_op_num = 3
self.pir_expected_slice_op_num = 3


class TestChangeShapeAfterAssign(TestTensorShapeBasic):
Expand All @@ -783,7 +773,7 @@ def _set_expected_op_num(self):

def _set_pir_expected_op_num(self):
self.pir_expected_op_num = 12
self.pir_expected_shape_op_num = 2
self.pir_expected_shape_op_num = 1
self.pir_expected_slice_op_num = 1


Expand Down
8 changes: 7 additions & 1 deletion test/dygraph_to_static/test_write_python_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,12 @@

import unittest

from dygraph_to_static_utils import Dy2StTestBase, test_ast_only, test_sot_only
from dygraph_to_static_utils import (
Dy2StTestBase,
test_ast_only,
test_legacy_and_pt_and_pir,
test_sot_only,
)

import paddle

Expand Down Expand Up @@ -113,6 +118,7 @@ def get_raw_value(self, container, getitem_path):
return out

@test_sot_only
@test_legacy_and_pt_and_pir
def test_write_container_sot(self):
func_static = paddle.jit.to_static(self.func)
input = paddle.to_tensor([1, 2, 3])
Expand Down

0 comments on commit be7e923

Please sign in to comment.