Skip to content

Commit

Permalink
[Dy2St] Filter out non-Value vars in while loop
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo committed Jan 30, 2024
1 parent b88fd78 commit d9a54ad
Show file tree
Hide file tree
Showing 2 changed files with 105 additions and 55 deletions.
146 changes: 105 additions & 41 deletions python/paddle/static/nn/control_flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import annotations

import warnings
from functools import cached_property, partial, reduce

Expand Down Expand Up @@ -785,14 +787,53 @@ def create_fake_value_for_undefined_var():
next_vars = [next_vars]
next_cond = cond(*next_vars)
next_cond.stop_gradient = True
cf_yield([next_cond, *flatten(next_vars)])
# Reset type of UndefinedVar from next_vars
for idx, value in undefined_var_mapping.items():
value_new_type = flatten(next_vars)[idx].type()
value.set_type(value_new_type)
cur_block.args()[idx].set_type(value_new_type)
while_op.as_operation().results()[idx].set_type(value_new_type)
return pack_sequence_as(loop_vars, while_op.optimize_update())

# Filter out the constants from next_vars, we only pass the variables (Value) into cf_yield.
# And pass the original fake value directly to constants position.
flattened_next_vars = flatten(next_vars)
(
variable_next_var_indices,
constant_next_var_indices,
) = get_indices_by_discriminator(
flattened_next_vars,
lambda x: isinstance(x, paddle.pir.Value),
)
variable_next_vars, constant_next_vars = select_by_indices(
flattened_next_vars,
variable_next_var_indices,
constant_next_var_indices,
)
(fake_constant_next_vars,) = select_by_indices(
unified_loop_vars, constant_next_var_indices
)
unified_next_vars = create_container_by_items_and_indices(
(variable_next_vars, variable_next_var_indices),
(fake_constant_next_vars, constant_next_var_indices),
)
cf_yield([next_cond, *unified_next_vars])

# Reset type of UndefinedVar from next_vars
for idx, value in undefined_var_mapping.items():
if idx in constant_next_var_indices:
continue
value_new_type = flatten(next_vars)[idx].type()
value.set_type(value_new_type)
cur_block.args()[idx].set_type(value_new_type)
while_op.as_operation().results()[idx].set_type(value_new_type)

# Restore the outputs by variable and constants
optimized_results = while_op.optimize_update()
(optimized_variable_results,) = select_by_indices(
optimized_results, variable_next_var_indices
)

return pack_sequence_as(
loop_vars,
create_container_by_items_and_indices(
(optimized_variable_results, variable_next_var_indices),
(constant_next_vars, constant_next_var_indices),
),
)

if in_dygraph_mode():
now_cond = pre_cond.item()
Expand Down Expand Up @@ -1194,6 +1235,40 @@ def _check_args(branch_index, branch_fns, default):
return final_fn()


def get_indices_by_discriminator(container, *discriminators):
buckets = [[] for _ in range(len(discriminators) + 1)]
for idx, item in enumerate(container):
for i, cond in enumerate(discriminators):
if cond(item):
buckets[i].append(idx)
break
else:
buckets[-1].append(idx)
return buckets


def select_by_indices(container, *index_groups):
buckets = [[] for _ in range(len(index_groups))]
for idx, item in enumerate(container):
for i, indices in enumerate(index_groups):
if idx in indices:
buckets[i].append(item)
break
return buckets


def create_container_by_items_and_indices(*items_indices_pairs):
total_length = reduce(
lambda acc, pair: acc + len(pair[0]), items_indices_pairs, 0
)
container = [None for _ in range(total_length)]
for partial_container, indices in items_indices_pairs:
assert len(partial_container) == len(indices)
for idx, item in zip(indices, partial_container):
container[idx] = item
return container


class OutputSelector:
def __init__(
self, if_op, flattened_true_output, flattened_false_output, names
Expand All @@ -1210,7 +1285,6 @@ def __init__(
def unified_output(self):
unified_true_output = []
unified_false_output = []
variable_indices = []
for true_out, false_out, name in zip(
self.true_output, self.false_output, self.names
):
Expand All @@ -1224,14 +1298,9 @@ def unified_output(self):
],
name,
)
if isinstance(true_out, paddle.pir.Value):
assert isinstance(
false_out, paddle.pir.Value
), "true_out and false_out should be both paddle.pir.Value"
variable_indices.append(len(unified_true_output))
unified_true_output.append(true_out)
unified_false_output.append(false_out)
return unified_true_output, unified_false_output, variable_indices
return unified_true_output, unified_false_output

@property
def unified_true_output(self):
Expand All @@ -1243,7 +1312,18 @@ def unified_false_output(self):

@property
def variable_indices(self):
return self.unified_output[2]
true_variable_indices, _ = get_indices_by_discriminator(
self.unified_true_output,
lambda x: isinstance(x, paddle.pir.Value),
)
false_variable_indices, _ = get_indices_by_discriminator(
self.unified_false_output,
lambda x: isinstance(x, paddle.pir.Value),
)
assert (
true_variable_indices == false_variable_indices
), "true_variable_indices and false_variable_indices should be same"
return true_variable_indices

@property
def constant_indices(self):
Expand All @@ -1254,44 +1334,28 @@ def constant_indices(self):
]

def get_variable_outputs(self):
variable_true_output = self.select_by_indices(
(variable_true_output,) = select_by_indices(
self.unified_true_output,
self.variable_indices,
)
variable_false_output = self.select_by_indices(
(variable_false_output,) = select_by_indices(
self.unified_false_output,
self.variable_indices,
)
return variable_true_output, variable_false_output

def restore_outputs_by_variable_results(self, variable_results):
constant_output = self.select_by_indices(
(constant_output,) = select_by_indices(
self.unified_true_output,
self.constant_indices,
)
restored_output = [None for _ in range(self.num_output)]
self.fill_to_indices(
restored_output,
variable_results,
self.variable_indices,
)
self.fill_to_indices(
restored_output,
constant_output,
self.constant_indices,

restored_output = create_container_by_items_and_indices(
(variable_results, self.variable_indices),
(constant_output, self.constant_indices),
)
return restored_output

@staticmethod
def select_by_indices(unified_args, indices):
return [unified_args[i] for i in indices]

@staticmethod
def fill_to_indices(outputs, partial_outputs, partial_indices):
for i, out in zip(partial_indices, partial_outputs):
outputs[i] = out
return outputs

@staticmethod
def constant_to_variable_promotion(out_with_blocks, name):
from paddle.jit.dy2static.convert_operators import to_static_variable
Expand Down Expand Up @@ -1485,8 +1549,8 @@ def cond(pred, true_fn=None, false_fn=None, name=None, return_names=None):
return None
true_output = None
false_output = None
check_variable_and_dtype(pred, "pred", ['bool'], "base.layers.cond")
check_type(name, "name", (str, type(None)), "base.layers.cond")
check_variable_and_dtype(pred, "pred", ['bool'], "paddle.static.nn.cond")
check_type(name, "name", (str, type(None)), "paddle.static.nn.cond")
if in_pir_mode():
if_op = build_if_op(pred)
if true_fn is not None:
Expand Down
14 changes: 0 additions & 14 deletions test/dygraph_to_static/test_loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -344,13 +344,6 @@ class TestTransformWhileLoopWithConflicVar(TestTransformWhileLoop):
def _init_dyfunc(self):
self.dyfunc = while_loop_dyfun_with_conflict_var

# This test raises an error about UndefinedVar in pir mode,
# it can be removed after the bug is fixed.
def test_ast_to_func(self):
static_numpy = self._run_static()
dygraph_numpy = self._run_dygraph()
np.testing.assert_allclose(dygraph_numpy, static_numpy, rtol=1e-05)


class TestTransformWhileLoopWithNone(TestTransformWhileLoop):
def _init_dyfunc(self):
Expand Down Expand Up @@ -434,13 +427,6 @@ class TestClassVarInForLoop(TestTransformForLoop):
def _init_dyfunc(self):
self.dyfunc = for_loop_class_var

# This test raises an error about UndefinedVar in pir mode,
# it can be removed after the bug is fixed.
def test_ast_to_func(self):
np.testing.assert_allclose(
self._run_dygraph(), self._run_static(), rtol=1e-05
)


class TestVarCreateInForLoop(TestTransformForLoop):
def _init_dyfunc(self):
Expand Down

0 comments on commit d9a54ad

Please sign in to comment.