Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Dy2St] Filter out non-Value vars in while loop #61355

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 112 additions & 41 deletions python/paddle/static/nn/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from functools import cached_property, partial, reduce

Expand Down Expand Up @@ -785,14 +787,60 @@ def create_fake_value_for_undefined_var():
next_vars = [next_vars]
next_cond = cond(*next_vars)
next_cond.stop_gradient = True
cf_yield([next_cond, *flatten(next_vars)])
# Reset type of UndefinedVar from next_vars
for idx, value in undefined_var_mapping.items():
value_new_type = flatten(next_vars)[idx].type()
value.set_type(value_new_type)
cur_block.args()[idx].set_type(value_new_type)
while_op.as_operation().results()[idx].set_type(value_new_type)
return pack_sequence_as(loop_vars, while_op.optimize_update())

# Filter out the constants from next_vars, we only pass the variables (Value) into cf_yield.
# And pass the original fake value directly to constants position.
flattened_next_vars = flatten(next_vars)
(
variable_next_var_indices,
constant_next_var_indices,
) = get_indices_by_discriminator(
flattened_next_vars,
lambda x: isinstance(x, paddle.pir.Value),
)
variable_next_vars, constant_next_vars = select_by_indices(
flattened_next_vars,
variable_next_var_indices,
constant_next_var_indices,
)
(fake_constant_next_vars,) = select_by_indices(
cur_block.args(), constant_next_var_indices
)
unified_next_vars = create_container_by_items_and_indices(
(variable_next_vars, variable_next_var_indices),
(fake_constant_next_vars, constant_next_var_indices),
)
cf_yield([next_cond, *unified_next_vars])

# Reset type of UndefinedVar from next_vars
for idx, value in undefined_var_mapping.items():
if idx in constant_next_var_indices:
continue
value_new_type = flatten(next_vars)[idx].type()
value.set_type(value_new_type)
cur_block.args()[idx].set_type(value_new_type)
while_op.as_operation().results()[idx].set_type(value_new_type)

# Restore the outputs by variable and constants
optimized_results = while_op.optimize_update()
(optimized_variable_results,) = select_by_indices(
optimized_results, variable_next_var_indices
)
# Prune unused fake values
for fake_value in undefined_var_mapping.values():
if fake_value.use_empty():
fake_value_def_op = fake_value.get_defining_op()
fake_value_def_op.get_parent_block().remove_op(
fake_value_def_op
)

return pack_sequence_as(
loop_vars,
create_container_by_items_and_indices(
(optimized_variable_results, variable_next_var_indices),
(constant_next_vars, constant_next_var_indices),
),
)

if in_dygraph_mode():
now_cond = pre_cond.item()
Expand Down Expand Up @@ -1194,6 +1242,40 @@ def _check_args(branch_index, branch_fns, default):
return final_fn()


def get_indices_by_discriminator(container, *discriminators):
buckets = [[] for _ in range(len(discriminators) + 1)]
for idx, item in enumerate(container):
for i, cond in enumerate(discriminators):
if cond(item):
buckets[i].append(idx)
break
else:
buckets[-1].append(idx)
return buckets


def select_by_indices(container, *index_groups):
buckets = [[] for _ in range(len(index_groups))]
for idx, item in enumerate(container):
for i, indices in enumerate(index_groups):
if idx in indices:
buckets[i].append(item)
break
return buckets


def create_container_by_items_and_indices(*items_indices_pairs):
total_length = reduce(
lambda acc, pair: acc + len(pair[0]), items_indices_pairs, 0
)
container = [None for _ in range(total_length)]
for partial_container, indices in items_indices_pairs:
assert len(partial_container) == len(indices)
for idx, item in zip(indices, partial_container):
container[idx] = item
return container


class OutputSelector:
def __init__(
self, if_op, flattened_true_output, flattened_false_output, names
Expand All @@ -1210,7 +1292,6 @@ def __init__(
def unified_output(self):
unified_true_output = []
unified_false_output = []
variable_indices = []
for true_out, false_out, name in zip(
self.true_output, self.false_output, self.names
):
Expand All @@ -1224,14 +1305,9 @@ def unified_output(self):
],
name,
)
if isinstance(true_out, paddle.pir.Value):
assert isinstance(
false_out, paddle.pir.Value
), "true_out and false_out should be both paddle.pir.Value"
variable_indices.append(len(unified_true_output))
unified_true_output.append(true_out)
unified_false_output.append(false_out)
return unified_true_output, unified_false_output, variable_indices
return unified_true_output, unified_false_output

@property
def unified_true_output(self):
Expand All @@ -1243,7 +1319,18 @@ def unified_false_output(self):

@property
def variable_indices(self):
return self.unified_output[2]
true_variable_indices, _ = get_indices_by_discriminator(
self.unified_true_output,
lambda x: isinstance(x, paddle.pir.Value),
)
false_variable_indices, _ = get_indices_by_discriminator(
self.unified_false_output,
lambda x: isinstance(x, paddle.pir.Value),
)
assert (
true_variable_indices == false_variable_indices
), "true_variable_indices and false_variable_indices should be same"
return true_variable_indices

@property
def constant_indices(self):
Expand All @@ -1254,44 +1341,28 @@ def constant_indices(self):
]

def get_variable_outputs(self):
variable_true_output = self.select_by_indices(
(variable_true_output,) = select_by_indices(
self.unified_true_output,
self.variable_indices,
)
variable_false_output = self.select_by_indices(
(variable_false_output,) = select_by_indices(
self.unified_false_output,
self.variable_indices,
)
return variable_true_output, variable_false_output

def restore_outputs_by_variable_results(self, variable_results):
constant_output = self.select_by_indices(
(constant_output,) = select_by_indices(
self.unified_true_output,
self.constant_indices,
)
restored_output = [None for _ in range(self.num_output)]
self.fill_to_indices(
restored_output,
variable_results,
self.variable_indices,
)
self.fill_to_indices(
restored_output,
constant_output,
self.constant_indices,

restored_output = create_container_by_items_and_indices(
(variable_results, self.variable_indices),
(constant_output, self.constant_indices),
)
return restored_output

@staticmethod
def select_by_indices(unified_args, indices):
return [unified_args[i] for i in indices]

@staticmethod
def fill_to_indices(outputs, partial_outputs, partial_indices):
for i, out in zip(partial_indices, partial_outputs):
outputs[i] = out
return outputs

@staticmethod
def constant_to_variable_promotion(out_with_blocks, name):
from paddle.jit.dy2static.convert_operators import to_static_variable
Expand Down Expand Up @@ -1485,8 +1556,8 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
return None
true_output = None
false_output = None
check_variable_and_dtype(pred, "pred", ['bool'], "base.layers.cond")
check_type(name, "name", (str, type(None)), "base.layers.cond")
check_variable_and_dtype(pred, "pred", ['bool'], "paddle.static.nn.cond")
check_type(name, "name", (str, type(None)), "paddle.static.nn.cond")
if in_pir_mode():
if_op = build_if_op(pred)
if true_fn is not None:
Expand Down
14 changes: 0 additions & 14 deletions test/dygraph_to_static/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,13 +344,6 @@ class TestTransformWhileLoopWithConflicVar(TestTransformWhileLoop):
def _init_dyfunc(self):
self.dyfunc = while_loop_dyfun_with_conflict_var

# This test raises an error about UndefinedVar in pir mode,
# it can be removed after the bug is fixed.
def test_ast_to_func(self):
static_numpy = self._run_static()
dygraph_numpy = self._run_dygraph()
np.testing.assert_allclose(dygraph_numpy, static_numpy, rtol=1e-05)


class TestTransformWhileLoopWithNone(TestTransformWhileLoop):
def _init_dyfunc(self):
Expand Down Expand Up @@ -434,13 +427,6 @@ class TestClassVarInForLoop(TestTransformForLoop):
def _init_dyfunc(self):
self.dyfunc = for_loop_class_var

# This test raises an error about UndefinedVar in pir mode,
# it can be removed after the bug is fixed.
def test_ast_to_func(self):
np.testing.assert_allclose(
self._run_dygraph(), self._run_static(), rtol=1e-05
)


class TestVarCreateInForLoop(TestTransformForLoop):
def _init_dyfunc(self):
Expand Down