Skip to content

Commit

Permalink
[Dy2St] pir dy2st unittest verification - Part 8 (#59120)
Browse files Browse the repository at this point in the history
---------

Co-authored-by: SigureMo <[email protected]>
  • Loading branch information
gouzil and SigureMo authored Nov 22, 2023
1 parent 6c8178b commit c9af1b6
Show file tree
Hide file tree
Showing 9 changed files with 241 additions and 205 deletions.
26 changes: 21 additions & 5 deletions python/paddle/jit/dy2static/convert_operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,7 +556,7 @@ def convert_len(var):
operations are added in `len` transformation, such as appending
`shape_op` in var.block.
"""
if isinstance(var, (Variable, OpResult)):
if isinstance(var, Variable):
assert var.ndim > 0, "len() of a 0-D tensor is wrong"
if var.type in [
core.VarDesc.VarType.LOD_TENSOR,
Expand All @@ -575,8 +575,24 @@ def convert_len(var):
'len(var) only supports LoDTensor/LoDTensorArray/SelectedRows, but received %s.'
% type(var)
)
elif isinstance(var, OpResult):
assert var.ndim > 0, "len() of a 0-D tensor is wrong"
if var.is_dense_tensor_type() or var.is_selected_row_type():
# Note: Length of var may be known ahead of time in dygraph,
# but it probably represents batch size which can be variant.
# so we return a variable dynamically inferred from var.shape.
if var.shape[0] > 0 and var.is_dense_tensor_type():
return var.shape[0]
return paddle.shape(var)[0]
elif var.is_dense_tensor_array_type():
return paddle.tensor.array_length(var)
else:
raise TypeError(
'len(var) only supports DenseTensor/DenseTensorArray/SelectedRows, '
+ f'but received {type(var)}.'
)
else:
if isinstance(var, (VariableTuple)):
if isinstance(var, VariableTuple):
return var.__len__()
return len(var)

Expand Down Expand Up @@ -625,11 +641,11 @@ def convert_range(*args):
has_variable = any(isinstance(x, (Variable, OpResult)) for x in args)
if has_variable:
if len(args) == 1:
return paddle.arange(0, args[0], 1, paddle.int64)
return paddle.arange(0, args[0], 1, "int64")
if len(args) == 2:
return paddle.arange(args[0], args[1], 1, paddle.int64)
return paddle.arange(args[0], args[1], 1, "int64")
if len(args) == 3:
return paddle.arange(args[0], args[1], args[2], paddle.int64)
return paddle.arange(args[0], args[1], args[2], "int64")
return range(*args)


Expand Down
2 changes: 1 addition & 1 deletion python/paddle/jit/dy2static/pir_partial_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,7 +390,7 @@ class PartialProgramLayer:
**1. This is a very low level API. Users should not use this API
directly. Please use `partial_program_from(concrete_program)`
to create it.
**2. LoDTensorArray is not currently supported in the output.
**2. TensorArray is not currently supported in the output.
Args:
main_program(Program): The main program that contains ops need to be executed.
Expand Down
24 changes: 21 additions & 3 deletions python/paddle/jit/sot/infer_meta.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,10 @@ class VariableCreator:
"""

def __init__(self):
self.var_cache = {}
# TODO(dev): Remove the program and var_cache shims after PIR become default state.
# self.var_cache = {}
# self.main_program = paddle.static.Program()
# self.startup_program = paddle.static.Program()
self.var_name_generator = UniqueNameGenerator("infer_meta_variable_")

def gen_name(self, meta):
Expand All @@ -114,6 +117,21 @@ def gen_name(self, meta):
name += f"_{l}"
return name

@property
def var_cache(self):
if paddle.framework.use_pir_api():
return self.pir_var_cache
else:
return self.legacy_var_cache

@cached_property
def legacy_var_cache(self):
return {}

@cached_property
def pir_var_cache(self):
return {}

@cached_property
def legacy_programs(self):
# Just for PIR and legacy IR compatibility.
Expand All @@ -133,13 +151,13 @@ def main_program(self):

@property
def startup_program(self):
if paddle.base.framework.use_pir_api():
if paddle.framework.use_pir_api():
return self.pir_programs[1]
else:
return self.legacy_programs[1]

def create_var(self, meta):
if paddle.base.framework.use_pir_api():
if paddle.framework.use_pir_api():
with paddle.static.program_guard(
self.main_program, self.startup_program
):
Expand Down
9 changes: 6 additions & 3 deletions test/dygraph_to_static/test_backward_without_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,10 @@
import unittest

import numpy as np
from dygraph_to_static_utils_new import Dy2StTestBase, test_legacy_and_pir
from dygraph_to_static_utils_new import (
Dy2StTestBase,
test_legacy_and_pir_exe_and_pir_api,
)

import paddle

Expand All @@ -30,7 +33,7 @@ def forward(self, x):


class TestBackwardWithoutParams(Dy2StTestBase):
@test_legacy_and_pir
@test_legacy_and_pir_exe_and_pir_api
def test_run(self):
net = paddle.jit.to_static(Net())

Expand All @@ -54,7 +57,7 @@ def forward(self, x):


class TestZeroSizeNet(Dy2StTestBase):
@test_legacy_and_pir
@test_legacy_and_pir_exe_and_pir_api
def test_run(self):
net = paddle.jit.to_static(ZeroSizeNet())
x = paddle.ones([2, 2])
Expand Down
120 changes: 68 additions & 52 deletions test/dygraph_to_static/test_cache_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
from collections import Counter

import numpy as np
from dygraph_to_static_utils_new import Dy2StTestBase
from dygraph_to_static_utils_new import (
Dy2StTestBase,
test_ast_only,
test_legacy_and_pir_exe_and_pir_api,
)
from test_fetch_feed import Linear, Pool2D

import paddle
from paddle import base
from paddle.jit.api import to_static
from paddle.jit.dy2static import convert_to_static


Expand All @@ -31,41 +34,55 @@ def setUp(self):
self.dygraph_class = Pool2D
self.data = np.random.random((1, 2, 4, 4)).astype('float32')

@test_legacy_and_pir_exe_and_pir_api
@test_ast_only
def test_cache(self):
prev_ops, cur_ops = Counter(), Counter()
prev_out, cur_out = None, None
with base.dygraph.guard(base.CPUPlace()):
static_net = self.dygraph_class()
for batch_id in range(self.batch_num):
out = static_net(paddle.to_tensor(self.data))
# Check outputs
prev_out = cur_out
cur_out = out
# Check forward ops
prev_ops = cur_ops
static_net = paddle.jit.to_static(self.dygraph_class())
for batch_id in range(self.batch_num):
out = static_net(paddle.to_tensor(self.data))
# Check outputs
prev_out = cur_out
cur_out = out
# Check forward ops
prev_ops = cur_ops

if paddle.framework.use_pir_api():
cur_ops = Counter(
[op.type for op in base.default_main_program().block(0).ops]
[
op.name()
for op in static_net.forward.concrete_program.main_program.global_block().ops
]
)
if batch_id > 0:
prev_out_numpy = (
prev_out[0].numpy()
if isinstance(prev_out, (tuple, list))
else prev_out.numpy()
)
cur_out_numpy = (
cur_out[0].numpy()
if isinstance(cur_out, (tuple, list))
else cur_out.numpy()
)
np.testing.assert_allclose(
prev_out_numpy,
cur_out_numpy,
rtol=1e-05,
err_msg='Output in previous batch is {}\n Output in current batch is \n{}'.format(
prev_out_numpy, cur_out_numpy
),
)
self.assertEqual(prev_ops, cur_ops)

else:
cur_ops = Counter(
[
op.type
for op in static_net.forward.concrete_program.main_program.global_block().ops
]
)
if batch_id > 0:
prev_out_numpy = (
prev_out[0].numpy()
if isinstance(prev_out, (tuple, list))
else prev_out.numpy()
)
cur_out_numpy = (
cur_out[0].numpy()
if isinstance(cur_out, (tuple, list))
else cur_out.numpy()
)
np.testing.assert_allclose(
prev_out_numpy,
cur_out_numpy,
rtol=1e-05,
err_msg='Output in previous batch is {}\n Output in current batch is \n{}'.format(
prev_out_numpy, cur_out_numpy
),
)
self.assertEqual(prev_ops, cur_ops)


class TestCacheProgram2(TestCacheProgram):
Expand All @@ -90,23 +107,23 @@ def train_dygraph(self):
def train(self, to_static=False):
paddle.jit.enable_to_static(to_static)

with base.dygraph.guard(base.CPUPlace()):
dygraph_net = self.dygraph_class()
adam = paddle.optimizer.Adam(
learning_rate=0.001, parameters=dygraph_net.parameters()
)
loss_data = []
for batch_id in range(self.batch_num):
input = base.dygraph.to_variable(self.data)
pred, avg_loss = dygraph_net(input)

loss_data.append(avg_loss.numpy())
avg_loss.backward()
adam.minimize(avg_loss)
dygraph_net.clear_gradients()
static_net = paddle.jit.to_static(self.dygraph_class())
adam = paddle.optimizer.Adam(
learning_rate=0.001, parameters=static_net.parameters()
)
loss_data = []
for batch_id in range(self.batch_num):
input = paddle.to_tensor(self.data)
pred, avg_loss = static_net(input)

loss_data.append(avg_loss.numpy())
avg_loss.backward()
adam.minimize(avg_loss)
static_net.clear_gradients()

return loss_data

@test_legacy_and_pir_exe_and_pir_api
def test_with_optimizer(self):
dygraph_loss = self.train_dygraph()
static_loss = self.train_static()
Expand All @@ -125,14 +142,14 @@ def simple_func(x):


class TestConvertWithCache(Dy2StTestBase):
@test_legacy_and_pir_exe_and_pir_api
def test_cache(self):
static_func = convert_to_static(simple_func)
# Get transformed function from cache.
cached_func = convert_to_static(simple_func)
self.assertTrue(id(static_func), id(cached_func))


@to_static
def sum_even_until_limit(max_len, limit):
ret_sum = base.dygraph.to_variable(np.zeros(1).astype('int32'))
for i in range(max_len):
Expand All @@ -156,12 +173,11 @@ def sum_under_while(limit):

class TestToOutputWithCache(Dy2StTestBase):
def test_output(self):
with base.dygraph.guard():
ret = sum_even_until_limit(80, 10)
self.assertEqual(ret.numpy(), 30)
ret = paddle.jit.to_static(sum_even_until_limit)(80, 10)
self.assertEqual(ret.numpy(), 30)

ret = to_static(sum_under_while)(100)
self.assertEqual(ret.numpy(), 5050)
ret = paddle.jit.to_static(sum_under_while)(100)
self.assertEqual(ret.numpy(), 5050)


if __name__ == '__main__':
Expand Down
Loading

0 comments on commit c9af1b6

Please sign in to comment.