Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

赛题七-开发grad_fn、next_functions两个API 并暴露到python端-v1 #54838

Merged
merged 6 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions paddle/fluid/eager/grad_node_info.cc
Original file line number Diff line number Diff line change
Expand Up @@ -559,4 +559,20 @@ void GradNodeBase::HandleComplexGradToRealGrad(
}
}

std::vector<std::shared_ptr<GradNodeBase>> GradNodeBase::NextFunctions() {
std::vector<std::shared_ptr<GradNodeBase>> next_nodes;
const paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>&
metas = OutputMeta();

for (const auto& meta_list : metas) {
for (const GradSlotMeta& meta : meta_list) {
const auto& edge = meta.GetEdge();
std::shared_ptr<GradNodeBase> next_node = edge.GetMutableGradNode();
next_nodes.push_back(next_node);
}
}

return next_nodes;
}

} // namespace egr
2 changes: 2 additions & 0 deletions paddle/fluid/eager/grad_node_info.h
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,8 @@ class GradNodeBase {
return true;
}

std::vector<std::shared_ptr<egr::GradNodeBase>> NextFunctions();

/**
* Apply GradientHook
* **/
Expand Down
42 changes: 42 additions & 0 deletions paddle/fluid/pybind/eager_properties.cc
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ limitations under the License. */
#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/memory/allocation/allocator.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/enforce.h"
Expand All @@ -31,6 +32,7 @@ limitations under the License. */
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"

#pragma GCC diagnostic ignored "-Wwrite-strings"

namespace paddle {
Expand Down Expand Up @@ -301,6 +303,41 @@ PyObject* tensor_properties_get_dtype(TensorObject* self, void* closure) {
EAGER_CATCH_AND_THROW_RETURN_NULL
}

PyObject* tensor_properties_get_grad_fn(TensorObject* self, void* closure) {
EAGER_TRY
if (!self->tensor.defined()) {
// Handle undefined tensors if necessary; otherwise, return nullptr or an
// appropriate PyObject. In this case, I will return Py_None.
Py_INCREF(Py_None);
return Py_None;
}

// Get GradNode from the tensor
auto meta = egr::EagerUtils::nullable_autograd_meta(
self->tensor); // If meta exists, get the GradNode

if (meta) {
// Get the GradNode from meta
auto grad_node = meta->GradNode(); // Convert GradNode to a Python object
// The conversion will depend on the structure of GradNode.

if (!grad_node) {
Py_INCREF(Py_None);
return Py_None;
}

PyObject* py_grad_node = ToPyObject(grad_node);

return py_grad_node;
} else {
// If meta does not exist, return an appropriate Python object (e.g., None
// or a special value).
Py_INCREF(Py_None);
return Py_None;
}
EAGER_CATCH_AND_THROW_RETURN_NULL
}

struct PyGetSetDef variable_properties[] = {
{"grad",
(getter)tensor_properties_get_grad,
Expand Down Expand Up @@ -341,6 +378,11 @@ struct PyGetSetDef variable_properties[] = {
{"dtype", (getter)tensor_properties_get_dtype, nullptr, nullptr, nullptr},
{"type", (getter)tensor_properties_get_type, nullptr, nullptr, nullptr},
{"is_leaf", (getter)tensor_properties_is_leaf, nullptr, nullptr, nullptr},
{"grad_fn",
(getter)tensor_properties_get_grad_fn,
nullptr,
nullptr,
nullptr},
{nullptr, nullptr, nullptr, nullptr, nullptr}};

// variable_properties for core.eager.StringTensor
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/pybind/eager_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,15 @@ paddle::optional<paddle::Tensor> GetOptionalTensorFromArgs(
}
}


PyObject* ToPyObject(egr::GradNodeBase* grad_node) {
py::object py_obj = py::cast(grad_node, py::return_value_policy::reference);
py::handle py_handle = py::handle(py_obj);
PyObject* py_grad_node = py_handle.ptr();
Py_INCREF(py_grad_node);
return py_grad_node;
}

static paddle::Tensor& GetTensorFromPyObject(const std::string& op_type,
const std::string& arg_name,
PyObject* obj,
Expand Down
3 changes: 3 additions & 0 deletions paddle/fluid/pybind/eager_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ typedef SSIZE_T ssize_t;
#include "paddle/utils/pybind.h"
#include "pybind11/pybind11.h"
#include "pybind11/stl.h"
#include "paddle/fluid/eager/grad_node_info.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/phi/core/distributed/auto_parallel/dist_attr.h"
#include "paddle/phi/core/distributed/auto_parallel/dist_tensor.h"
Expand Down Expand Up @@ -125,6 +126,8 @@ PyObject* ToPyObject(
const std::unordered_map<std::string, std::vector<std::string>>& value);
PyObject* ToPyObject(const paddle::framework::Vocab& value);

PyObject* ToPyObject(egr::GradNodeBase* grad_node);

class PyTensorHook : public egr::TensorHook {
public:
explicit PyTensorHook(PyObject* func) : py_func_(func) {
Expand Down
9 changes: 9 additions & 0 deletions paddle/fluid/pybind/pybind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <Python.h>
#include "paddle/fluid/eager/grad_node_info.h"

// Avoid a problem with copysign defined in pyconfig.h on Windows.
#ifdef copysign
#undef copysign
Expand Down Expand Up @@ -776,6 +778,13 @@ PYBIND11_MODULE(libpaddle, m) {
}
});

py::class_<egr::GradNodeBase>(m, "GradNodeBase")
.def("name", &egr::GradNodeBase::name)
.def_property_readonly("next_functions",
&egr::GradNodeBase::NextFunctions)
.def("input_meta", &egr::GradNodeBase::InputMeta)
.def("output_meta", &egr::GradNodeBase::OutputMeta);

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
m.def("cudnn_version", &platform::DnnVersion);
m.def("gpu_memory_available", []() {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
import paddle.nn as nn

class Testmodel(nn.Layer):
def __init__(self):
super(Testmodel, self).__init__()

def forward(self, x):
y = x ** 2
y = x + y
return y

class TestAnonmousSurvey(unittest.TestCase):

def init_graph(self):
""" define reversed graph

func_name [str]: represents the name of the operator node
next_funcs [dict]: represents the operator node
"""
self.grad_fn_1 = {
"func_name": "GradNodeAccumulation",
"next_funcs": {}
}
self.grad_fn_2 = {
"func_name": "PowGradNode",
"next_funcs": {
"GradNodeAccumulation": self.grad_fn_1
}
}
self.grad_fn_3 = {
"func_name": "AddGradNode",
"next_funcs": {
"GradNodeAccumulation": self.grad_fn_1,
"PowGradNode": self.grad_fn_2
}
}
self.output_grad_fn = {
"grad_fn": self.grad_fn_3
}

def init_data(self):
""" define output of model

the final output will be saved self.output
"""
model = Testmodel()
x = paddle.randn([1, 3, 24, 24])
x.stop_gradient = False
self.output = model(x)


def setUp(self):
self.init_graph()
self.init_data()


def test_grad_fn_and_next_funs(self):
self.check_func(self.output.grad_fn, self.output_grad_fn["grad_fn"])


def check_func(self, grad_fn, grad_fn_json):
"""check each node

:param grad_fn: grad_fn of node
:return grad_fn_json: gead_node_json of node
"""
# print(grad_fn.name())
# assert func name
self.assertEqual(grad_fn.name(), grad_fn_json["func_name"])
# Recursively test other nodes
if hasattr(grad_fn, 'next_functions') and grad_fn.next_functions[0]:
next_funcs_json = grad_fn_json["next_funcs"]
for u in grad_fn.next_functions:
self.check_func(u, next_funcs_json[u.name()])

unittest.main()