diff --git a/paddle/fluid/eager/to_static/run_program_op_node.h b/paddle/fluid/eager/to_static/run_program_op_node.h
index f79649f71069d..7ce5e637fcc86 100644
--- a/paddle/fluid/eager/to_static/run_program_op_node.h
+++ b/paddle/fluid/eager/to_static/run_program_op_node.h
@@ -17,6 +17,7 @@
 #include "paddle/fluid/eager/api/utils/global_utils.h"
 #include "paddle/fluid/eager/grad_node_info.h"
 #include "paddle/fluid/eager/tensor_wrapper.h"
+#include "paddle/fluid/framework/executor_cache.h"
 #include "paddle/fluid/framework/new_executor/interpretercore.h"
 #include "paddle/fluid/framework/variable_helper.h"
 #include "paddle/fluid/ir_adaptor/translator/program_translator.h"
@@ -289,6 +290,7 @@ static void ShareTensorsFromScopeByValue(
   for (size_t i = 0; i < tensors.size(); ++i) {
     auto &name = names[i];
     auto &value = values[i];
+    VLOG(2) << "share " << name << " from scope";
     if (value.impl() == nullptr) {
       // skip stop_gradient.
       continue;
@@ -306,7 +308,7 @@ static void ShareTensorsFromScopeByValue(
       auto &src_tensor = var->Get<phi::DenseTensor>();
       auto *dst_tensor = const_cast<phi::DenseTensor *>(
           dynamic_cast<const phi::DenseTensor *>(tensors[i]->impl().get()));
-      VLOG(2) << "share " << name << " from scope";
+      VLOG(2) << "actually do sharing " << name << " from scope";
       *dst_tensor = src_tensor;
     } else if (var->IsType<phi::SelectedRows>()) {
       auto &src_tensor = var->Get<phi::SelectedRows>();
@@ -500,10 +502,16 @@ inline void PirRunProgramAPI(
     details::ShareTensorsIntoScopeByValue(
         forward_global_block, params, param_values, global_inner_scope);
     // Step 2. create new interpretercore
-    auto kernel_forward_program =
-        paddle::dialect::PdOpLowerToKernelPass(forward_program, place);
+    auto passed_kernel_program =
+        paddle::framework::ApplyIrPass(forward_program, place);
+    if (FLAGS_print_ir) {
+      std::ostringstream print_stream;
+      print_stream << "LoweredProgram( AfterPass ) is :\n";
+      passed_kernel_program->Print(print_stream);
+      std::cout << print_stream.str() << std::endl;
+    }
     interpreter_core = paddle::framework::CreatePirInterpreterCoreInfoToCache(
-        std::move(kernel_forward_program),
+        std::move(passed_kernel_program),
         place,
         /*is_grad=*/false,
         program_id,
@@ -1037,10 +1045,16 @@ inline void PirRunProgramGradAPI(
         1);
     VLOG(2) << "No interpretercore cahce, so create a new interpretercore";
     // Step 1. share input_vars & parameters into scope
-    auto kernel_backward_program =
-        paddle::dialect::PdOpLowerToKernelPass(backward_program, place);
+    auto passed_kernel_program =
+        paddle::framework::ApplyIrPass(backward_program, place);
+    if (FLAGS_print_ir) {
+      std::ostringstream print_stream;
+      print_stream << "LoweredProgram( AfterPass | Backward ) is :\n";
+      passed_kernel_program->Print(print_stream);
+      std::cout << print_stream.str() << std::endl;
+    }
     interpreter_core = paddle::framework::CreatePirInterpreterCoreInfoToCache(
-        std::move(kernel_backward_program),
+        std::move(passed_kernel_program),
         place,
         /*is_grad=*/true,
         program_id,
@@ -1346,9 +1360,7 @@ class PirGradNodeRunProgram : public egr::GradNodeBase {
       x_grad_ptr.emplace_back(&i);
     }
     for (auto &i : params_grad) {
-      if (i.defined()) {
-        params_grad_ptr.emplace_back(&i);
-      }
+      params_grad_ptr.emplace_back(&i);
     }
   }

diff --git a/paddle/fluid/framework/executor_cache.cc b/paddle/fluid/framework/executor_cache.cc
index 6a6be34c3eebc..6f64cf44bf69c 100644
--- a/paddle/fluid/framework/executor_cache.cc
+++ b/paddle/fluid/framework/executor_cache.cc
@@ -358,6 +358,24 @@ bool TensorSortHelper(const paddle::Tensor &t1, const paddle::Tensor &t2) {
   return t1.name() < t2.name();
 }

+std::unique_ptr<::pir::Program> ApplyIrPass(::pir::Program *program,
+                                            phi::Place place) {
+  auto ir_res = paddle::dialect::PdOpLowerToKernelPass(program, place);
+
+  if (FLAGS_pir_apply_inplace_pass) {
+    ::pir::PassManager pm(::pir::IrContext::Instance(), 3);
+    pm.AddPass(::pir::CreateInplacePass());
+    pm.Run(ir_res.get());
+
+    if (FLAGS_print_ir) {
+      std::cout << "IR After inplace -------------------" << std::endl;
+      std::cout << *ir_res << std::endl;
+    }
+  }
+
+  return ir_res;
+}
+
 std::unique_ptr<::pir::Program> ConstructFowardIrProgram(
     const paddle::framework::BlockDesc *forward_global_block,
     const paddle::framework::BlockDesc *backward_global_block,
@@ -456,21 +474,7 @@ std::unique_ptr<::pir::Program> ConstructFowardIrProgram(
       program.get());
   program_translator.Translate();
-
-  auto ir_res = paddle::dialect::PdOpLowerToKernelPass(program.get(), place);
-
-  if (FLAGS_pir_apply_inplace_pass) {
-    ::pir::PassManager pm(::pir::IrContext::Instance(), 3);
-    pm.AddPass(::pir::CreateInplacePass());
-    pm.Run(ir_res.get());
-
-    if (FLAGS_print_ir) {
-      std::cout << "IR After inplace -------------------" << std::endl;
-      std::cout << *ir_res << std::endl;
-    }
-  }
-
-  return ir_res;
+  return ApplyIrPass(program.get(), place);
 }

 std::unique_ptr<::pir::Program> ConstructBackwardIrProgram(
diff --git a/paddle/fluid/framework/executor_cache.h b/paddle/fluid/framework/executor_cache.h
index ad94fcbeca107..e6da435a903aa 100644
--- a/paddle/fluid/framework/executor_cache.h
+++ b/paddle/fluid/framework/executor_cache.h
@@ -252,6 +252,9 @@ std::shared_ptr<InterpreterCore> CreatePirInterpreterCoreInfoToCache(
     int64_t program_id,
     framework::Scope* scope);

+std::unique_ptr<::pir::Program> ApplyIrPass(::pir::Program* program,
+                                            phi::Place place);
+
 std::unique_ptr<::pir::Program> ConstructFowardIrProgram(
     const paddle::framework::BlockDesc* forward_global_block,
     const paddle::framework::BlockDesc* backward_global_block,
diff --git a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
index ed5c193aa2eca..fae7a3c9fc283 100644
--- a/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
+++ b/paddle/fluid/pir/dialect/op_generator/ops_api_gen.py
@@ -82,6 +82,7 @@
     'generate_sequence_xpu',
     'layer_norm_act_xpu',
     'memcpy',
+    'batch_norm_',
     'multi_encoder_xpu',
     'multihead_matmul',
     'squeeze_excitation_block',
@@ -104,7 +105,6 @@
     'add_n_',
     'add_n_with_kernel',
     'assign_value',
-    'batch_norm_',
     'c_allgather',
     'c_allreduce_max',
     'c_allreduce_sum',
diff --git a/paddle/fluid/pybind/pir.cc b/paddle/fluid/pybind/pir.cc
index 58a8012e021d0..e7964c3ae3368 100644
--- a/paddle/fluid/pybind/pir.cc
+++ b/paddle/fluid/pybind/pir.cc
@@ -136,6 +136,9 @@ inline void SetProgramInt64Attr(std::shared_ptr<Program> program,
 }

 std::string GetValueInfo(Value v) {
+  if (v.impl() == nullptr) {
+    return "nullptr value";
+  }
   std::stringstream ss;
   if (auto op_result = v.dyn_cast<OpResult>()) {
     ss << "define_op_name=" << op_result.owner()->name();
@@ -1058,12 +1061,11 @@ int AppendSetParameters(Program *forward_program,
   std::unordered_set<pir::Value> added_op_result;

   for (const auto &result : outputs_op_result) {
-    if (!added_op_result.count(result)) {
+    if (!added_op_result.count(result) || IsFakeOpResult(result)) {
       std::string parameter_name = name_prefix + std::to_string(counter);
       AppendSetParameter(
           forward_program, result, parameter_name, start_point + counter);
       counter += 1;
-      added_op_result.insert(result);
     }
   }

diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py
index 4cd13ec19846a..f7d32bb61908b 100644
--- a/python/paddle/nn/layer/norm.py
+++ b/python/paddle/nn/layer/norm.py
@@ -42,6 +42,7 @@
     _global_flags,
     get_default_dtype,
     in_dynamic_or_pir_mode,
+    in_pir_mode,
     no_grad,
 )
 from .. import functional as F
@@ -1056,7 +1057,7 @@ def __init__(
         self._trainable_statistics = trainable_statistics

     def forward(self, input):
-        if in_dynamic_or_pir_mode():
+        if in_dynamic_mode():
             batch_norm_out, t1, t2, t3, t4, _ = _C_ops.batch_norm(
                 input,
                 self._mean,
@@ -1072,13 +1073,29 @@ def forward(self, input):
             )
             if self._act is None:
                 return batch_norm_out
-            if in_dynamic_mode():
-                return dygraph_utils._append_activation_in_dygraph(
-                    batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn
-                )
-            else:
-                act_op = getattr(_C_ops, self._act)
-                return act_op(batch_norm_out)
+
+            return dygraph_utils._append_activation_in_dygraph(
+                batch_norm_out, act=self._act, use_mkldnn=self._use_mkldnn
+            )
+        elif in_pir_mode():
+            batch_norm_out, t1, t2, t3, t4, _ = _C_ops.batch_norm_(
+                input,
+                self._mean,
+                self._variance,
+                self.weight,
+                self.bias,
+                not self.training,
+                self._momentum,
+                self._epsilon,
+                self._data_layout,
+                self._use_global_stats,
+                self._trainable_statistics,
+            )
+            if self._act is None:
+                return batch_norm_out
+
+            act_op = getattr(_C_ops, self._act)
+            return act_op(batch_norm_out)
         else:
             # create output
             # mean and mean_out share the same memory
diff --git a/test/dygraph_to_static/test_resnet.py b/test/dygraph_to_static/test_resnet.py
index d86024400272f..7f72b900133a9 100644
--- a/test/dygraph_to_static/test_resnet.py
+++ b/test/dygraph_to_static/test_resnet.py
@@ -21,10 +21,7 @@
 import numpy as np
 from dygraph_to_static_utils import (
     Dy2StTestBase,
-    test_default_mode_only,
-    test_legacy_only,
-    test_pt_only,
-    test_sot_only,
+    test_default_and_pir,
 )
 from predictor_utils import PredictorTools

@@ -341,7 +338,12 @@ def do_train(self, to_static):
                 )
                 if batch_id == 10:
                     if to_static:
-                        paddle.jit.save(resnet, self.model_save_prefix)
+                        # TODO(@xiongkun): open after save / load supported in pir.
+                        if (
+                            to_static
+                            and not paddle.base.framework.use_pir_api()
+                        ):
+                            paddle.jit.save(resnet, self.model_save_prefix)
                     else:
                         paddle.save(
                             resnet.state_dict(),
@@ -442,20 +444,7 @@ def verify_predict(self):
             err_msg=f'predictor_pre:\n {predictor_pre}\n, st_pre: \n{st_pre}.',
         )

-    @test_sot_only
-    @test_pt_only
-    def test_resnet_pir(self):
-        static_loss = self.train(to_static=True)
-        dygraph_loss = self.train(to_static=False)
-        np.testing.assert_allclose(
-            static_loss,
-            dygraph_loss,
-            rtol=1e-05,
-            err_msg=f'static_loss: {static_loss} \n dygraph_loss: {dygraph_loss}',
-        )
-
-    @test_sot_only
-    @test_legacy_only
+    @test_default_and_pir
     def test_resnet(self):
         static_loss = self.train(to_static=True)
         dygraph_loss = self.train(to_static=False)
@@ -465,9 +454,11 @@ def test_resnet(self):
             rtol=1e-05,
             err_msg=f'static_loss: {static_loss} \n dygraph_loss: {dygraph_loss}',
         )
-        self.verify_predict()
+        # TODO(@xiongkun): open after save / load supported in pir.
+        if not paddle.base.framework.use_pir_api():
+            self.verify_predict()

-    @test_default_mode_only
+    @test_default_and_pir
     def test_resnet_composite(self):
         core._set_prim_backward_enabled(True)
         core._add_skip_comp_ops("batch_norm")
@@ -481,7 +472,7 @@ def test_resnet_composite(self):
             err_msg=f'static_loss: {static_loss} \n dygraph_loss: {dygraph_loss}',
         )

-    @test_default_mode_only
+    @test_default_and_pir
     def test_in_static_mode_mkldnn(self):
         paddle.base.set_flags({'FLAGS_use_mkldnn': True})
         try: