[Dy2St] Use ShadowOutputOp to get dy2st output #60363

Merged
@@ -498,6 +498,10 @@ void HandleForSpecialOp(pir::Operation* op,
// change operand name to param_name
auto orig_name = value_exe_info->GetValue2VarName().at(value);

if (var_name == orig_name) {
return;
}

Comment on lines +501 to +504
Member Author

The reason for this change is the same as in #59928 (comment).

if (value_exe_info->GetScope()->FindVar(var_name) != nullptr) {
const_cast<Scope*>(value_exe_info->GetScope())->EraseVars({var_name});
VLOG(1) << "var " << var_name << " has been removed from scope";
44 changes: 22 additions & 22 deletions paddle/fluid/pybind/pir.cc
@@ -1053,14 +1053,14 @@ std::pair<std::shared_ptr<Program>, OpResultMap> CloneProgram(
std::make_pair(associated_array_key, associated_array_value));
}

void AppendSetParameter(Program *forward_program,
void AppendShadowOutput(Program *forward_program,
const pir::OpResult &result,
const std::string &name,
size_t start_point) {
pir::IrContext *ctx = pir::IrContext::Instance();
auto op_info = ctx->GetRegisteredOpInfo(pir::SetParameterOp::name());
auto op_info = ctx->GetRegisteredOpInfo(pir::ShadowOutputOp::name());
pir::AttributeMap attribute_map = {
{"parameter_name", pir::StrAttribute::get(ctx, name)},
{"output_name", pir::StrAttribute::get(ctx, name)},
};
pir::Operation *operation =
pir::Operation::Create({result}, attribute_map, {}, op_info);
@@ -1073,7 +1073,7 @@ void AppendSetParameter(Program *forward_program,
}
}

int AppendSetParameters(Program *forward_program,
int AppendShadowOutputs(Program *forward_program,
const std::vector<pir::OpResult> &outputs_op_result,
int start_point,
std::string name_prefix) {
@@ -1082,9 +1082,9 @@ int AppendSetParameters(Program *forward_program,

for (const auto &result : outputs_op_result) {
if (!added_op_result.count(result) || IsFakeOpResult(result)) {
std::string parameter_name = name_prefix + std::to_string(counter);
AppendSetParameter(
forward_program, result, parameter_name, start_point + counter);
std::string shadow_output_name = name_prefix + std::to_string(counter);
AppendShadowOutput(
forward_program, result, shadow_output_name, start_point + counter);
counter += 1;
added_op_result.insert(result);
}
@@ -1200,20 +1200,20 @@ SplitedResult SplitForwardBackward(
if (v.impl() == nullptr) {
return;
}
// NOTE(Aurelius84): we should skip insert SetParameterOp repeatly by
// NOTE(Aurelius84): we should skip insert ShadowOutputOp repeatly by
// calling SplitForwardBackward multi-times.
std::string parameter_name =
std::string shadow_output_name =
std::string("output_") + std::to_string(counter);
std::unordered_set<pir::Value> inserted_value;
for (auto it = forward_program->block()->rbegin();
it != forward_program->block()->rend();
++it) {
if (it->isa<pir::SetParameterOp>()) {
if (it->isa<pir::ShadowOutputOp>()) {
auto out_name =
it->attribute<pir::StrAttribute>("parameter_name").AsString();
if (out_name == parameter_name) {
it->attribute<pir::StrAttribute>("output_name").AsString();
if (out_name == shadow_output_name) {
VLOG(4) << out_name
<< " has been inserted SetParameterOp, skip it now.";
<< " has been inserted ShadowOutputOp, skip it now.";
return;
}

@@ -1224,9 +1224,9 @@ SplitedResult SplitForwardBackward(
if (inserted_value.count(forward_value_map[v])) {
return;
}
auto op_info = ctx->GetRegisteredOpInfo(pir::SetParameterOp::name());
auto op_info = ctx->GetRegisteredOpInfo(pir::ShadowOutputOp::name());
pir::AttributeMap attribute_map = {
{"parameter_name", pir::StrAttribute::get(ctx, parameter_name)},
{"output_name", pir::StrAttribute::get(ctx, shadow_output_name)},
};
pir::Operation *operation = pir::Operation::Create(
{forward_value_map[v]}, attribute_map, {}, op_info);
@@ -1241,9 +1241,9 @@ SplitedResult SplitForwardBackward(
if (v.impl() == nullptr) {
return;
}
auto op_info = ctx->GetRegisteredOpInfo(pir::SetParameterOp::name());
auto op_info = ctx->GetRegisteredOpInfo(pir::ShadowOutputOp::name());
pir::AttributeMap attribute_map = {
{"parameter_name",
{"output_name",
pir::StrAttribute::get(
ctx, std::string("output_") + std::to_string(counter))},
};
@@ -1368,10 +1368,10 @@ pir::Type CreateSelectedRowsTypeByDenseTensor(pir::Type dense_tensor_type) {
}
}

void ResetParameterName(pir::Operation *op, const std::string &name) {
void ResetShadowOutputName(pir::Operation *op, const std::string &name) {
pir::IrContext *ctx = pir::IrContext::Instance();
if (op->isa<pir::SetParameterOp>()) {
op->set_attribute("parameter_name", pir::StrAttribute::get(ctx, name));
if (op->isa<pir::ShadowOutputOp>()) {
op->set_attribute("output_name", pir::StrAttribute::get(ctx, name));
}
}

@@ -1406,9 +1406,9 @@ std::map<int, int> GetOpInplaceInfo(const pir::Operation *op) {
void BindUtils(pybind11::module *m) {
m->def("clone_program", CloneProgram);
m->def("get_op_inplace_info", GetOpInplaceInfo);
m->def("reset_parameter_name", ResetParameterName);
m->def("reset_shadow_output_name", ResetShadowOutputName);
m->def("split_program", SplitForwardBackward);
m->def("append_set_parameters", AppendSetParameters);
m->def("append_shadow_outputs", AppendShadowOutputs);
m->def("fake_op_result", FakeOpResult);
m->def("is_fake_op_result", IsFakeOpResult);
m->def("get_current_insertion_point", []() -> PyInsertionPoint {
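
For quick reference, here is a minimal sketch of how the renamed bindings can be driven from Python, mirroring the call sites in pir_partial_program.py further down. `program` and `outputs` are placeholders, and the "output_" prefix is only an illustrative assumption; the binding names and argument order are taken from the registrations above.

```python
# Sketch only: assumes a Paddle build with PIR enabled, `program` is a PIR
# Program and `outputs` is a list of pir.OpResult values defined in it.
import paddle


def expose_outputs(program, outputs):
    # Append one builtin.shadow_output op per result, starting after the last
    # existing op in the block; names become "<prefix>0", "<prefix>1", ...
    start = len(program.global_block().ops)
    return paddle.base.libpaddle.pir.append_shadow_outputs(
        program, outputs, start, "output_"
    )


def rename_output(op, new_name):
    # Only takes effect when `op` is a builtin.shadow_output op; it rewrites
    # the "output_name" attribute (formerly "parameter_name" on set_parameter).
    paddle.core.pir.reset_shadow_output_name(op, new_name)
```
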
35 changes: 20 additions & 15 deletions python/paddle/jit/dy2static/pir_partial_program.py
@@ -103,7 +103,7 @@ def union(self, x, y):
self.father[father_x] = father_y

def find_root(self, x):
if not self.father.__contains__(x):
if x not in self.father:
self.father[x] = x
if self.father[x].is_same(x):
return x
@@ -135,24 +135,29 @@ def _get_value_name_map_from_program(cls, program):
ret = ValueDict()
ret[fake_op_result()] = "FakeVar"
for op in program.global_block().ops:
if op.name() == "pd_op.data":
ret[op.result(0)] = op.attrs()["name"]
if op.name() == "builtin.set_parameter":
ret[op.operand(0).source()] = op.attrs()["parameter_name"]
if op.name() == "builtin.parameter":
elif op.name() == "builtin.parameter":
ret[op.result(0)] = op.attrs()["parameter_name"]
elif op.name() == "builtin.shadow_output":
ret[op.operand(0).source()] = op.attrs()["output_name"]
elif op.name() == "pd_op.data":
ret[op.result(0)] = op.attrs()["name"]
return ret

@classmethod
def _get_name_defining_op(cls, program, value):
for op in program.global_block().ops:
if op.name() == "pd_op.data":
if op.name() == "builtin.set_parameter":
if value.is_same(op.operand(0).source()):
return op
elif op.name() == "builtin.parameter":
if value.is_same(op.result(0)):
return op
if op.name() == "builtin.set_parameter":
elif op.name() == "builtin.shadow_output":
if value.is_same(op.operand(0).source()):
return op
if op.name() == "builtin.parameter":
elif op.name() == "pd_op.data":
if value.is_same(op.result(0)):
return op
return None
@@ -291,7 +296,7 @@ def _forward_backward_program(self):
def program_attr(self):
assert (
self.finish_pass is False
), "program_attr() is called by PartialProgramLayer, don't call it matually, use program_name_attr instead."
), "program_attr() is called by PartialProgramLayer, don't call it manually, use program_name_attr instead."
# can't apply pass after call this function.
self.finish_pass = True
fwd_map = {
@@ -346,7 +351,7 @@ def has_name(value):
if has_name(ufset.find_root(value)):
name_defining_op = self._get_name_defining_op(program, value)
if name_defining_op:
paddle.core.pir.reset_parameter_name(
paddle.core.pir.reset_shadow_output_name(
name_defining_op, value2name[ufset.find_root(value)]
)

@@ -384,8 +389,8 @@ class PirPassContext:
"""

INPUT_OP_NAME = "pd_op.data"
PARM_OP_NAME = "builtin.parameter"
OUTPUT_OP_NAME = "builtin.set_parameter"
PARAM_OP_NAME = "builtin.parameter"
OUTPUT_OP_NAME = "builtin.shadow_output"

@classmethod
def apply(cls, runable_program, build_strategy):
@@ -419,7 +424,7 @@ def _prepare_attr(cls, program):
op_name = op.name()
if op_name == cls.INPUT_OP_NAME:
inputs.append(op.result(0))
elif op_name == cls.PARM_OP_NAME:
elif op_name == cls.PARAM_OP_NAME:
params.append(op.result(0))
elif op_name == cls.OUTPUT_OP_NAME:
outputs.append(op.operand(0).source())
@@ -546,7 +551,7 @@ def origin_runable_program(self):
inputs = list(self._inputs.var_list)
outputs = list(self._outputs.var_list)
params = self._param_values
paddle.base.libpaddle.pir.append_set_parameters(
paddle.base.libpaddle.pir.append_shadow_outputs(
self._origin_main_program,
outputs,
len(self._origin_main_program.global_block().ops),
@@ -796,7 +801,7 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram):
dtype=out_op_result.dtype,
)
forward_outputs_grads.append(value)
paddle.base.libpaddle.pir.append_set_parameters(
paddle.base.libpaddle.pir.append_shadow_outputs(
program,
forward_outputs_grads,
len(program.global_block().ops),
@@ -861,7 +866,7 @@ def _append_backward_desc(self, train_runnable_program: RunableProgram):
)
)
backward_end_op_index = len(program.global_block().ops)
paddle.base.libpaddle.pir.append_set_parameters(
paddle.base.libpaddle.pir.append_shadow_outputs(
program,
output_grads_to_append,
backward_end_op_index,
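
After this change, builtin.shadow_output (attribute "output_name") is the op that names a program output, next to pd_op.data ("name") and builtin.parameter ("parameter_name"). A minimal sketch of scanning a block with the same conventions as _prepare_attr; `program` is a placeholder, and only the op names, attribute keys, and accessors come from the diff above.

```python
# Sketch: classify the values of a PIR program block the way PirPassContext
# and _get_value_name_map_from_program do after this change.
def collect_io(program):
    inputs, params, outputs = [], [], []
    for op in program.global_block().ops:
        op_name = op.name()
        if op_name == "pd_op.data":               # feed value, named by "name"
            inputs.append((op.attrs()["name"], op.result(0)))
        elif op_name == "builtin.parameter":      # trainable parameter
            params.append((op.attrs()["parameter_name"], op.result(0)))
        elif op_name == "builtin.shadow_output":  # program output (was set_parameter)
            outputs.append((op.attrs()["output_name"], op.operand(0).source()))
    return inputs, params, outputs
```
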
2 changes: 1 addition & 1 deletion python/paddle/jit/pir_dy2static/parameter_recorder.py
@@ -81,7 +81,7 @@ def get(self, program, value):
return None
root_var = inplace_dict[value]
saved = []
while inplace_dict.__contains__(root_var):
while root_var in inplace_dict:
saved.append(root_var)
root_var = inplace_dict[root_var]
for var in saved:
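
The loop above walks an in-place alias chain down to its root; the only change is the idiomatic membership test. A self-contained toy illustration with a plain dict (no Paddle objects involved):

```python
# Toy version of the chain walk in parameter_recorder.py: follow
# value -> inplace_dict[value] until the root, collecting intermediates.
inplace_dict = {"c": "b", "b": "a"}  # "a" is the root


def find_root(value):
    saved = []
    root_var = inplace_dict[value]
    while root_var in inplace_dict:  # preferred over __contains__
        saved.append(root_var)
        root_var = inplace_dict[root_var]
    return root_var, saved


print(find_root("c"))  # ('a', ['b'])
```
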
3 changes: 1 addition & 2 deletions test/dygraph_to_static/test_tensor_memcpy_on_cpu.py
@@ -18,7 +18,6 @@
from dygraph_to_static_utils import (
Dy2StTestBase,
enable_to_static_guard,
test_legacy_and_pt,
test_legacy_and_pt_and_pir,
)

@@ -69,7 +68,7 @@ def _run(self):
x2 = paddle.jit.to_static(tensor_copy_to_cuda)(x1)
return x1.place, x2.place, x2.numpy()

@test_legacy_and_pt
@test_legacy_and_pt_and_pir
def test_tensor_cuda_on_default_cpu(self):
if not paddle.is_compiled_with_cuda():
return
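
With outputs now exposed through ShadowOutputOp, this memcpy test also runs under PIR, hence the decorator swap. A hypothetical minimal test in the same style; the doubling function and the assertion are made up, and only Dy2StTestBase, the decorator, and paddle.jit.to_static come from the file above.

```python
# Hypothetical example, not part of this PR.
import numpy as np

import paddle
from dygraph_to_static_utils import Dy2StTestBase, test_legacy_and_pt_and_pir


def double(x):
    return x * 2


class TestDoubleDy2St(Dy2StTestBase):
    @test_legacy_and_pt_and_pir  # exercise legacy IR, PT, and PIR modes
    def test_double(self):
        x = paddle.to_tensor([1.0, 2.0])
        y = paddle.jit.to_static(double)(x)
        np.testing.assert_allclose(y.numpy(), np.array([2.0, 4.0]))
```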