From b8ad0ba464ca120089fd5827ac2b8a2d9af58192 Mon Sep 17 00:00:00 2001 From: mitruska Date: Tue, 26 Nov 2024 12:36:32 +0100 Subject: [PATCH] Use new shape_infer for glu --- src/common/transformations/include/ov_ops/glu.hpp | 3 --- src/common/transformations/src/ov_ops/glu.cpp | 13 +++---------- .../shape_inference/include/glu_shape_inference.hpp | 6 +++--- src/plugins/intel_gpu/src/graph/swiglu.cpp | 7 ++----- 4 files changed, 8 insertions(+), 21 deletions(-) diff --git a/src/common/transformations/include/ov_ops/glu.hpp b/src/common/transformations/include/ov_ops/glu.hpp index 33442b56dbc844..add8c3a0582525 100644 --- a/src/common/transformations/include/ov_ops/glu.hpp +++ b/src/common/transformations/include/ov_ops/glu.hpp @@ -75,9 +75,6 @@ class TRANSFORMATIONS_API GLU : public ov::op::Op { ov::element::Type m_output_type{}; }; -TRANSFORMATIONS_API std::vector shape_infer(const GLU* op, - std::vector input_shapes); - } // namespace internal } // namespace op } // namespace ov diff --git a/src/common/transformations/src/ov_ops/glu.cpp b/src/common/transformations/src/ov_ops/glu.cpp index 197620adbe7ed4..9b5fb780d36bb8 100644 --- a/src/common/transformations/src/ov_ops/glu.cpp +++ b/src/common/transformations/src/ov_ops/glu.cpp @@ -37,11 +37,9 @@ bool GLU::visit_attributes(ov::AttributeVisitor& visitor) { void GLU::validate_and_infer_types() { auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type; - set_output_type(0, output_type, shape_infer(this, {get_input_partial_shape(0)})[0]); - - // const auto input_shapes = ov::util::get_node_input_partial_shapes(*this); - // const auto output_shapes = shape_infer(this, input_shapes); - // set_output_type(0, output_type, output_shapes[0]); + const auto input_shapes = ov::util::get_node_input_partial_shapes(*this); + const auto output_shapes = shape_infer(this, input_shapes); + set_output_type(0, output_type, output_shapes[0]); } std::shared_ptr GLU::clone_with_new_inputs(const ov::OutputVector& new_args) const { @@ -53,11 +51,6 @@ std::shared_ptr GLU::clone_with_new_inputs(const ov::OutputVector& new_arg m_split_to_glu_idx, m_output_type); } - -std::vector shape_infer(const GLU* op, std::vector input_shapes) { - return glu_shape_infer(op, input_shapes); -} - } // namespace internal } // namespace op } // namespace ov diff --git a/src/core/shape_inference/include/glu_shape_inference.hpp b/src/core/shape_inference/include/glu_shape_inference.hpp index 65780b54ce6ca6..eae62b3129e870 100644 --- a/src/core/shape_inference/include/glu_shape_inference.hpp +++ b/src/core/shape_inference/include/glu_shape_inference.hpp @@ -12,9 +12,9 @@ namespace ov { namespace op { namespace internal { template > -std::vector glu_shape_infer(const GLU* op, - const std::vector& input_shapes, - const ITensorAccessor& tensor_accessor = make_tensor_accessor()) { +std::vector shape_infer(const GLU* op, + const std::vector& input_shapes, + const ITensorAccessor& tensor_accessor = make_tensor_accessor()) { const auto inputs_count = input_shapes.size(); NODE_SHAPE_INFER_CHECK(op, input_shapes, inputs_count == 1); diff --git a/src/plugins/intel_gpu/src/graph/swiglu.cpp b/src/plugins/intel_gpu/src/graph/swiglu.cpp index e82e4e974b1868..ffd5333318cee4 100644 --- a/src/plugins/intel_gpu/src/graph/swiglu.cpp +++ b/src/plugins/intel_gpu/src/graph/swiglu.cpp @@ -3,6 +3,7 @@ // #include "ov_ops/glu.hpp" +#include "glu_shape_inference.hpp" #include "swiglu_inst.h" #include "primitive_type_base.h" @@ -32,11 +33,7 @@ std::vector swiglu_inst::calc_output_layouts(swiglu_node const& /*node*/ op.set_axis(desc->axis); op.set_split_lengths(desc->split_lengths); - std::vector input_shapes = { - impl_param.get_input_layout(0).get(), - ShapeType(ov::Shape({})), - ShapeType(ov::Shape{2}) - }; + std::vector input_shapes = {impl_param.get_input_layout(0).get()}; std::vector output_shapes = shape_infer(&op, input_shapes);