From df333564571e4e560d54b90e1f6b0daed1f76835 Mon Sep 17 00:00:00 2001 From: Juncheng Date: Thu, 2 Jul 2020 16:07:01 +0800 Subject: [PATCH] Add identity op to bn grad (#3118) * Add identity op to bn grad * rename identity to diff_identity * SetTensorDescInferFn use TensorDesc * revert flow.identity api Former-commit-id: 61a2df94b9b4a5e7f72eb4622aecae50cf0e2a1b --- .../customized/kernels/identity_kernel.cpp | 43 ++++++++++++++++ oneflow/customized/ops/identity_op.cpp | 51 +++++++++++++++++++ oneflow/customized/ops/normalization_op.cpp | 20 +++++++- 3 files changed, 112 insertions(+), 2 deletions(-) create mode 100644 oneflow/customized/kernels/identity_kernel.cpp create mode 100644 oneflow/customized/ops/identity_op.cpp diff --git a/oneflow/customized/kernels/identity_kernel.cpp b/oneflow/customized/kernels/identity_kernel.cpp new file mode 100644 index 00000000000..49eba81ad76 --- /dev/null +++ b/oneflow/customized/kernels/identity_kernel.cpp @@ -0,0 +1,43 @@ +#include "oneflow/core/framework/framework.h" +#include "oneflow/core/kernel/new_kernel_util.h" + +namespace oneflow { + +namespace { + +template +class IdentityKernel final : public user_op::OpKernel { + public: + IdentityKernel() = default; + ~IdentityKernel() override = default; + + private: + void Compute(user_op::KernelComputeContext* ctx) const override { + const user_op::Tensor* in = ctx->Tensor4ArgNameAndIndex("in", 0); + user_op::Tensor* out = ctx->Tensor4ArgNameAndIndex("out", 0); + const ShapeView& in_shape = in->shape(); + CHECK_EQ(out->shape(), in_shape); + const DataType in_data_type = in->data_type(); + CHECK_EQ(out->data_type(), in_data_type); + Memcpy(ctx->device_ctx(), out->mut_dptr(), in->dptr(), + in_shape.elem_cnt() * GetSizeOfDataType(in_data_type)); + } + bool AlwaysComputeWhenAllOutputsEmpty() const override { return false; } +}; + +#define REGISTER_IDENTITY_KERNEL(device) \ + REGISTER_USER_KERNEL("identity") \ + .SetCreateFn>() \ + .SetIsMatchedHob(user_op::HobDeviceType() == device) \ + .SetInplaceProposalFn([](const user_op::InferContext&, \ + user_op::AddInplaceArgPair AddInplaceArgPairFn) -> Maybe { \ + OF_RETURN_IF_ERROR(AddInplaceArgPairFn("out", 0, "in", 0, false)); \ + return Maybe::Ok(); \ + }); + +REGISTER_IDENTITY_KERNEL(DeviceType::kCPU) +REGISTER_IDENTITY_KERNEL(DeviceType::kGPU) + +} // namespace + +} // namespace oneflow diff --git a/oneflow/customized/ops/identity_op.cpp b/oneflow/customized/ops/identity_op.cpp new file mode 100644 index 00000000000..5516a488b48 --- /dev/null +++ b/oneflow/customized/ops/identity_op.cpp @@ -0,0 +1,51 @@ +#include "oneflow/core/framework/framework.h" + +namespace oneflow { + +namespace { + +REGISTER_USER_OP("identity") + .Input("in") + .Output("out") + .SetTensorDescInferFn([](user_op::InferContext* ctx) -> Maybe { + const user_op::TensorDesc* in_desc = ctx->TensorDesc4ArgNameAndIndex("in", 0); + user_op::TensorDesc* out_desc = ctx->TensorDesc4ArgNameAndIndex("out", 0); + *out_desc = *in_desc; + return Maybe::Ok(); + }) + .SetBatchAxisInferFn([](user_op::BatchAxisContext* ctx) -> Maybe { + *ctx->BatchAxis4ArgNameAndIndex("out", 0) = *ctx->BatchAxis4ArgNameAndIndex("in", 0); + return Maybe::Ok(); + }) + .SetGetSbpFn([](user_op::SbpContext* ctx) -> Maybe { + const user_op::TensorDesc& in_tensor = ctx->LogicalTensorDesc4InputArgNameAndIndex("in", 0); + FOR_RANGE(int64_t, i, 0, in_tensor.shape().NumAxes()) { + ctx->NewBuilder() + .Split(user_op::OpArg("in", 0), i) + .Split(user_op::OpArg("out", 0), i) + .Build(); + } + ctx->NewBuilder() + .PartialSum(user_op::OpArg("in", 0)) + .PartialSum(user_op::OpArg("out", 0)) + .Build(); + return Maybe::Ok(); + }); + +REGISTER_USER_OP_GRAD("identity") + .SetGenBackwardOpConfFn([](const user_op::UserOpWrapper& op, user_op::AddOpFn AddOp) { + if (op.NeedGenGradTensor4OpInput("in", 0)) { + user_op::UserOpConfWrapperBuilder builder(op.op_name() + "_grad"); + user_op::UserOpConfWrapper identity_op = + builder.Op("identity") + .Input("in", op.GetGradTensorWithOpOutput("out", 0)) + .Output("out") + .Build(); + op.BindGradTensorWithOpInput(identity_op.output("out", 0), "in", 0); + AddOp(identity_op); + } + }); + +} // namespace + +} // namespace oneflow diff --git a/oneflow/customized/ops/normalization_op.cpp b/oneflow/customized/ops/normalization_op.cpp index e15eff530b8..34e6a92cda1 100644 --- a/oneflow/customized/ops/normalization_op.cpp +++ b/oneflow/customized/ops/normalization_op.cpp @@ -283,11 +283,27 @@ REGISTER_USER_OP_GRAD("normalization") need_norm_grad_op = true; } if (op.NeedGenGradTensor4OpInput("gamma", 0)) { - op.BindGradTensorWithOpInput(grad_op.output("gamma_diff", 0), "gamma", 0); + // TODO(liujuncheng): delete identity op when boxing support separated regsts + const auto identity = + user_op::UserOpConfWrapperBuilder(op.op_name() + "_grad_gamma_diff_identity") + .Op("identity") + .Input("in", grad_op.output("gamma_diff", 0)) + .Output("out") + .Build(); + AddOp(identity); + op.BindGradTensorWithOpInput(identity.output("out", 0), "gamma", 0); need_norm_grad_op = true; } if (op.NeedGenGradTensor4OpInput("beta", 0)) { - op.BindGradTensorWithOpInput(grad_op.output("beta_diff", 0), "beta", 0); + // TODO(liujuncheng): delete identity op when boxing support separated regsts + const auto identity = + user_op::UserOpConfWrapperBuilder(op.op_name() + "_grad_beta_diff_identity") + .Op("identity") + .Input("in", grad_op.output("beta_diff", 0)) + .Output("out") + .Build(); + AddOp(identity); + op.BindGradTensorWithOpInput(identity.output("out", 0), "beta", 0); need_norm_grad_op = true; } if (need_norm_grad_op) { AddOp(grad_op); }