diff --git a/cinn/frontend/op_mappers/paddle/matmul.cc b/cinn/frontend/op_mappers/paddle/matmul.cc
index 7db1c86fea..be5e735a85 100644
--- a/cinn/frontend/op_mappers/paddle/matmul.cc
+++ b/cinn/frontend/op_mappers/paddle/matmul.cc
@@ -12,6 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
+#include <string>
+
 #include "cinn/frontend/op_mapper_registry.h"
 #include "cinn/frontend/op_mappers/common_utils.h"
 
@@ -46,6 +48,72 @@ void MatMulOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c
   ctx.AddVarModelToProgram(out_name, out->id);
 }
 
+void MatMulGradOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& ctx) {
+  // get dout, the gradient of Out
+  CHECK_EQ(op_desc.Input(paddle::GradVarName("Out")).size(), 1UL);
+  auto dout_name = op_desc.Input(paddle::GradVarName("Out")).front();
+
+  // get the forward inputs X and Y
+  CHECK_EQ(op_desc.Input("X").size(), 1UL);
+  auto x_name = op_desc.Input("X").front();
+  CHECK_EQ(op_desc.Input("Y").size(), 1UL);
+  auto y_name = op_desc.Input("Y").front();
+
+  // get the output gradient names; either dX or dY may be absent
+  std::string dx_name, dy_name;
+  bool has_dx = !op_desc.Output(paddle::GradVarName("X")).empty();
+  bool has_dy = !op_desc.Output(paddle::GradVarName("Y")).empty();
+  if (has_dx) {
+    CHECK_EQ(op_desc.Output(paddle::GradVarName("X")).size(), 1UL);
+    dx_name = op_desc.Output(paddle::GradVarName("X")).front();
+  }
+  if (has_dy) {
+    CHECK_EQ(op_desc.Output(paddle::GradVarName("Y")).size(), 1UL);
+    dy_name = op_desc.Output(paddle::GradVarName("Y")).front();
+  }
+
+  // get attrs
+  auto trans_x = utils::GetAttrOrDefault<bool>(op_desc, "trans_x", false);
+  trans_x = utils::GetAttrOrDefault<bool>(op_desc, "transpose_X", trans_x);
+
+  auto trans_y = utils::GetAttrOrDefault<bool>(op_desc, "trans_y", false);
+  trans_y = utils::GetAttrOrDefault<bool>(op_desc, "transpose_Y", trans_y);
+
+  auto alpha = utils::GetAttrOrDefault<float>(op_desc, "alpha", 1.0f);
+
+  auto x = ctx.GetVar(x_name);
+  auto y = ctx.GetVar(y_name);
+  auto dout = ctx.GetVar(dout_name);
+  if (has_dx) {
+    absl::optional<Variable> dx;
+    if (trans_x && trans_y) {
+      dx = ctx.Builder()->Matmul(y, dout, true, true, alpha);  // Out = X^T*Y^T: dX = Y^T*dOut^T
+    } else if (trans_x) {
+      dx = ctx.Builder()->Matmul(y, dout, false, true, alpha);  // Out = X^T*Y: dX = Y*dOut^T
+    } else if (trans_y) {
+      dx = ctx.Builder()->Matmul(dout, y, false, false, alpha);  // Out = X*Y^T: dX = dOut*Y
+    } else {
+      dx = ctx.Builder()->Matmul(dout, y, false, true, alpha);  // Out = X*Y: dX = dOut*Y^T
+    }
+    ctx.AddVar(dx_name, dx.value());
+    ctx.AddVarModelToProgram(dx_name, dx.value()->id);
+  }
+  if (has_dy) {
+    absl::optional<Variable> dy;
+    if (trans_x && trans_y) {
+      dy = ctx.Builder()->Matmul(dout, x, true, true, alpha);  // Out = X^T*Y^T: dY = dOut^T*X^T
+    } else if (trans_x) {
+      dy = ctx.Builder()->Matmul(x, dout, false, false, alpha);  // Out = X^T*Y: dY = X*dOut
+    } else if (trans_y) {
+      dy = ctx.Builder()->Matmul(dout, x, true, false, alpha);  // Out = X*Y^T: dY = dOut^T*X
+    } else {
+      dy = ctx.Builder()->Matmul(x, dout, true, false, alpha);  // Out = X*Y: dY = X^T*dOut
+    }
+    ctx.AddVar(dy_name, dy.value());
+    ctx.AddVarModelToProgram(dy_name, dy.value()->id);
+  }
+}
+
 }  // namespace paddle_mappers
 }  // namespace frontend
 }  // namespace cinn
@@ -53,5 +121,6 @@ void MatMulOpMapper(const paddle::cpp::OpDesc& op_desc, const OpMapperContext& c
 CINN_REGISTER_HELPER(paddle_matmul) {
   CINN_REGISTER_OP_MAPPER(matmul, cinn::frontend::paddle_mappers::MatMulOpMapper)
   CINN_REGISTER_OP_MAPPER(matmul_v2, cinn::frontend::paddle_mappers::MatMulOpMapper)
+  CINN_REGISTER_OP_MAPPER(matmul_v2_grad, cinn::frontend::paddle_mappers::MatMulGradOpMapper)
   return true;
 }
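
The eight Matmul calls in MatMulGradOpMapper select the standard matmul gradient formulas for the four transpose combinations, where Matmul(a, b, trans_a, trans_b, alpha) computes alpha * op(a) * op(b). The NumPy snippet below is a standalone sanity check of those formulas against numerical gradients (alpha omitted); the helper names and shapes are illustrative only and it is not part of the patch.

import numpy as np

def matmul(a, b, ta, tb):
    # mirrors Builder()->Matmul(a, b, trans_a, trans_b) with alpha = 1
    return (a.T if ta else a) @ (b.T if tb else b)

def grads(x, y, dout, ta, tb):
    # same four branches as MatMulGradOpMapper (alpha omitted)
    if ta and tb:
        return y.T @ dout.T, dout.T @ x.T
    if ta:
        return y @ dout.T, x @ dout
    if tb:
        return dout @ y, dout.T @ x
    return dout @ y.T, x.T @ dout

def numeric_grad(f, a, eps=1e-5):
    # central-difference gradient of a scalar function f at a
    g = np.zeros_like(a)
    for idx in np.ndindex(a.shape):
        p, m = a.copy(), a.copy()
        p[idx] += eps
        m[idx] -= eps
        g[idx] = (f(p) - f(m)) / (2 * eps)
    return g

rng = np.random.default_rng(0)
for ta in (False, True):
    for tb in (False, True):
        x = rng.standard_normal((4, 3) if ta else (3, 4))
        y = rng.standard_normal((5, 4) if tb else (4, 5))
        dout = rng.standard_normal((3, 5))  # gradient flowing into Out
        dx, dy = grads(x, y, dout, ta, tb)
        ndx = numeric_grad(lambda a: np.sum(matmul(a, y, ta, tb) * dout), x)
        ndy = numeric_grad(lambda b: np.sum(matmul(x, b, ta, tb) * dout), y)
        assert np.allclose(dx, ndx, atol=1e-4) and np.allclose(dy, ndy, atol=1e-4)
print("all four transpose cases match the analytic gradients")
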
diff --git a/python/tests/op_mappers/test_matmul_op.py b/python/tests/op_mappers/test_matmul_op.py
new file mode 100644
index 0000000000..777b552c84
--- /dev/null
+++ b/python/tests/op_mappers/test_matmul_op.py
@@ -0,0 +1,51 @@
+#!/usr/bin/env python3
+
+# Copyright (c) 2023 CINN Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from op_mapper_test import OpMapperTest, logger
+import paddle
+
+
+class TestMatmulOp(OpMapperTest):
+    def init_input_data(self):
+        self.feed_data = {
+            "x": self.random([16, 32], "float32"),
+            "y": self.random([32, 16], "float32")
+        }
+
+    def set_op_type(self):
+        return "matmul"
+
+    def set_op_inputs(self):
+        x = paddle.static.data('X', self.feed_data["x"].shape,
+                               self.feed_data["x"].dtype)
+        y = paddle.static.data('Y', self.feed_data["y"].shape,
+                               self.feed_data["y"].dtype)
+        return {'X': [x], 'Y': [y]}
+
+    def set_op_attrs(self):
+        return {"trans_x": False, "trans_y": False}
+
+    def set_op_outputs(self):
+        return {'Out': [str(self.feed_data['x'].dtype)]}
+
+    def test_check_results(self):
+        self.check_outputs_and_grads()
+
+
+if __name__ == "__main__":
+    unittest.main()
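
The new test only feeds the non-transposed path (trans_x and trans_y both False), so check_outputs_and_grads exercises just the final branch of each gradient block. If the OpMapperTest harness allows it, the transposed branches could be covered by reusing the class above; the subclass below is only a sketch under that assumption, with an illustrative name and shapes, and is not part of the patch.

class TestMatmulTransposeOp(TestMatmulOp):
    """Sketch: same harness, but exercising the trans_x/trans_y branches."""

    def init_input_data(self):
        # With both inputs transposed, x.T @ y.T needs x: [32, 16] and
        # y: [16, 32] to keep the [16, 16] output of the base case.
        self.feed_data = {
            "x": self.random([32, 16], "float32"),
            "y": self.random([16, 32], "float32")
        }

    def set_op_attrs(self):
        return {"trans_x": True, "trans_y": True}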