From 15f30cb6f1738518d27fd80d4f8aff7036065a6a Mon Sep 17 00:00:00 2001 From: SigureMo Date: Tue, 19 Mar 2024 12:40:17 +0000 Subject: [PATCH 1/2] [Dy2St][PIR] Hold backward program in GradNode --- paddle/fluid/framework/op_desc.cc | 3 +++ paddle/fluid/framework/type_defs.h | 3 ++- paddle/fluid/pybind/op_function_common.cc | 22 ++++++++++++++++--- .../jit/dy2static/pir_partial_program.py | 8 +++---- test/dygraph_to_static/test_no_gradient.py | 4 +++- 5 files changed, 31 insertions(+), 9 deletions(-) diff --git a/paddle/fluid/framework/op_desc.cc b/paddle/fluid/framework/op_desc.cc index 32c520711d978..c777c78b8119e 100644 --- a/paddle/fluid/framework/op_desc.cc +++ b/paddle/fluid/framework/op_desc.cc @@ -977,6 +977,9 @@ struct SetAttrDescVisitor { void operator()(const std::vector &v) const { // just do nothing. } + void operator()(const std::shared_ptr &v) const { + // just do nothing. + } void operator()(const std::vector &v) const { std::vector var_names; for (auto var : v) { diff --git a/paddle/fluid/framework/type_defs.h b/paddle/fluid/framework/type_defs.h index 61f133ceb082a..5147a298e6d4d 100644 --- a/paddle/fluid/framework/type_defs.h +++ b/paddle/fluid/framework/type_defs.h @@ -67,7 +67,8 @@ using Attribute = paddle::variant, ::pir::Block*, - std::vector<::pir::Value>>; + std::vector<::pir::Value>, + std::shared_ptr<::pir::Program>>; using AttributeMap = std::unordered_map; using OpCreator = diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index f8f1424ded243..f64c919baa436 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -858,6 +858,22 @@ void CastPyArg2AttrIRBlock(PyObject* obj, attrs[key] = reinterpret_cast<::pir::Block*&>(vh[0]); } +void CastPyArg2AttrIRProgram(PyObject* obj, + paddle::framework::AttributeMap& attrs, // NOLINT + const std::string& key, + const std::string& op_type, + ssize_t arg_pos) { + VLOG(1) << "After Process shared_ptr"; + ::pybind11::object o = + ::pybind11::reinterpret_borrow<::pybind11::object>(obj); + // ::pybind11::detail::instance* inst = + // (::pybind11::detail::instance*)obj; // NOLINT + // void** vh = inst->simple_layout ? inst->simple_value_holder + // : &inst->nonsimple.values_and_holders[0]; + // attrs[key] = reinterpret_cast>(vh[0]); + attrs[key] = o.cast&>(); +} + void CastPyArg2AttrValues(PyObject* obj, paddle::framework::AttributeMap& attrs, // NOLINT const std::string& key, @@ -1020,9 +1036,9 @@ void ConstructAttrMapForRunProgram( if (std::set({"cuda_graph_capture_mode"}).count(key)) { CastPyArg2AttrString(obj, attrs, key, op_type, arg_pos); - } else if (std::set({"global_block", - "forward_global_block", - "backward_global_block"}) + } else if (std::set({"global_block"}).count(key)) { + CastPyArg2AttrIRBlock(obj, attrs, key, op_type, arg_pos); + } else if (std::set({"forward_program", "backward_program"}) .count(key)) { CastPyArg2AttrIRBlock(obj, attrs, key, op_type, arg_pos); } else if (std::set({"is_test", "use_interpretorcore"}) diff --git a/python/paddle/jit/dy2static/pir_partial_program.py b/python/paddle/jit/dy2static/pir_partial_program.py index f57ccc7b01019..a8250ef2c31fc 100644 --- a/python/paddle/jit/dy2static/pir_partial_program.py +++ b/python/paddle/jit/dy2static/pir_partial_program.py @@ -869,10 +869,10 @@ def _prune_unused_params(self, program): def _prepare_attributes(self): attrs = [ - 'forward_global_block', - self.program.forward_program.global_block(), - 'backward_global_block', - self.program.backward_program.global_block(), + 'forward_program', + self.program.forward_program, + 'backward_program', + self.program.backward_program, 'is_test', not self.training, 'program_id', diff --git a/test/dygraph_to_static/test_no_gradient.py b/test/dygraph_to_static/test_no_gradient.py index 1bd3a02f54ede..391ee176dfb58 100644 --- a/test/dygraph_to_static/test_no_gradient.py +++ b/test/dygraph_to_static/test_no_gradient.py @@ -15,7 +15,7 @@ import unittest import numpy -from dygraph_to_static_utils import Dy2StTestBase +from dygraph_to_static_utils import Dy2StTestBase, test_ast_only, test_pir_only import paddle @@ -33,6 +33,8 @@ def main_func(x, index): class TestNoGradientCase(Dy2StTestBase): + @test_ast_only + @test_pir_only def test_no_gradient(self): paddle.disable_static() x = paddle.randn([10, 3]) From 17c5287e4989cc6c9baf2a4860ccdff7926e541b Mon Sep 17 00:00:00 2001 From: SigureMo Date: Sun, 7 Apr 2024 07:40:31 +0000 Subject: [PATCH 2/2] tmp commit --- paddle/fluid/pybind/op_function_common.cc | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/pybind/op_function_common.cc b/paddle/fluid/pybind/op_function_common.cc index f64c919baa436..3800eab7c79cc 100644 --- a/paddle/fluid/pybind/op_function_common.cc +++ b/paddle/fluid/pybind/op_function_common.cc @@ -36,6 +36,7 @@ #include "paddle/phi/common/complex.h" #include "paddle/pir/include/core/block.h" #include "paddle/pir/include/core/op_result.h" +#include "paddle/pir/include/core/program.h" #include "paddle/pir/include/core/value.h" namespace paddle { @@ -864,14 +865,16 @@ void CastPyArg2AttrIRProgram(PyObject* obj, const std::string& op_type, ssize_t arg_pos) { VLOG(1) << "After Process shared_ptr"; - ::pybind11::object o = - ::pybind11::reinterpret_borrow<::pybind11::object>(obj); + ::pybind11::object o = ::pybind11::reinterpret_steal<::pybind11::object>(obj); + // ::pybind11::object o = + // ::pybind11::reinterpret_borrow<::pybind11::object>(obj); // ::pybind11::detail::instance* inst = // (::pybind11::detail::instance*)obj; // NOLINT // void** vh = inst->simple_layout ? inst->simple_value_holder - // : &inst->nonsimple.values_and_holders[0]; + // : + // &inst->nonsimple.values_and_holders[0]; // attrs[key] = reinterpret_cast>(vh[0]); - attrs[key] = o.cast&>(); + attrs[key] = o.cast>(); } void CastPyArg2AttrValues(PyObject* obj, @@ -1040,7 +1043,7 @@ void ConstructAttrMapForRunProgram( CastPyArg2AttrIRBlock(obj, attrs, key, op_type, arg_pos); } else if (std::set({"forward_program", "backward_program"}) .count(key)) { - CastPyArg2AttrIRBlock(obj, attrs, key, op_type, arg_pos); + CastPyArg2AttrIRProgram(obj, attrs, key, op_type, arg_pos); } else if (std::set({"is_test", "use_interpretorcore"}) .count(key)) { CastPyArg2AttrBoolean(obj, attrs, key, op_type, arg_pos);