Skip to content

Commit

Permalink
[Dy2St][PIR] Support setitem for TensorArray (#61440)
Browse files Browse the repository at this point in the history
  • Loading branch information
SigureMo authored Feb 1, 2024
1 parent a680a39 commit 8a56f09
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 19 deletions.
8 changes: 4 additions & 4 deletions paddle/fluid/pybind/manual_static_op_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -379,17 +379,17 @@ static PyObject *static_api_slice_array(PyObject *self,
starts_tmp, phi::DataType::INT64, phi::CPUPlace());
}

PyObject *ends_obj = PyTuple_GET_ITEM(args, 1);
PyObject *ends_obj = PyTuple_GET_ITEM(args, 2);
pir::Value ends;
if (PyObject_CheckIRValue(ends_obj)) {
ends = CastPyArg2Value(ends_obj, "slice_array", 1);
ends = CastPyArg2Value(ends_obj, "slice_array", 2);
} else if (PyObject_CheckIRVectorOfValue(ends_obj)) {
std::vector<pir::Value> ends_tmp =
CastPyArg2VectorOfValue(ends_obj, "slice_array", 1);
CastPyArg2VectorOfValue(ends_obj, "slice_array", 2);
ends = paddle::dialect::stack(ends_tmp, /*axis*/ 0);
} else {
std::vector<int64_t> ends_tmp =
CastPyArg2Longs(ends_obj, "slice_array", 1);
CastPyArg2Longs(ends_obj, "slice_array", 2);
ends = paddle::dialect::full_int_array(
ends_tmp, phi::DataType::INT64, phi::CPUPlace());
}
Expand Down
20 changes: 13 additions & 7 deletions python/paddle/base/variable_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ def _setitem_for_tensor_array(var, item, value):
assert (
not paddle.in_dynamic_mode()
), "setitem for tensor_array must be called in static graph mode."
if isinstance(item, (Variable, int)):
if isinstance(item, (Variable, paddle.pir.Value, int)):
from paddle.jit.dy2static.convert_operators import to_static_variable
from paddle.tensor import array_write

Expand Down Expand Up @@ -248,17 +248,21 @@ def slice_is_same_to_original(start, end, step):
return start == 0 and end == MAX_INTEGER and step == 1


def parse_index(x, indices):
def is_tensor_array_type(value):
from .framework import in_pir_mode

if in_pir_mode():
is_tensor_array = x.is_dense_tensor_array_type()
return value.is_dense_tensor_array_type()
else:
is_tensor_array = (
hasattr(x, "desc")
and x.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY
return (
hasattr(value, "desc")
and value.desc.type() == core.VarDesc.VarType.LOD_TENSOR_ARRAY
)


def parse_index(x, indices):
is_tensor_array = is_tensor_array_type(x)

advanced_index = (
[] if is_tensor_array else [None] * 2 * len(x.shape)
) # content is (dim, index)
Expand Down Expand Up @@ -448,7 +452,9 @@ def _setitem_static(x, indices, values):
from . import in_dynamic_or_pir_mode
from .framework import Variable, default_main_program, in_pir_mode

if x.type == paddle.base.core.VarDesc.VarType.LOD_TENSOR_ARRAY:
is_tensor_array = is_tensor_array_type(x)

if is_tensor_array:
return _setitem_for_tensor_array(x, indices, values)

# step1: parsing the index and recording them
Expand Down
9 changes: 1 addition & 8 deletions test/dygraph_to_static/test_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,7 @@ class TestSliceInIf(TestSliceBase):
def init_dygraph_func(self):
self.dygraph_func = test_slice_in_if

@test_legacy_and_pt_and_pir
def test_transformed_static_result(self):
self.init_dygraph_func()
static_res = self.run_static_mode()
Expand All @@ -179,14 +180,6 @@ def init_input(self):
def init_dygraph_func(self):
self.dygraph_func = test_set_value

# TODO(pir-control-flow): Delete this code after supporting control flow
@test_legacy_and_pt_and_pir
def test_transformed_static_result(self):
self.init_dygraph_func()
static_res = self.run_static_mode()
dygraph_res = self.run_dygraph_mode()
np.testing.assert_allclose(dygraph_res, static_res, rtol=1e-05)


class TestSetValueWithLayerAndSave(Dy2StTestBase):
def setUp(self):
Expand Down

0 comments on commit 8a56f09

Please sign in to comment.