Skip to content

Commit

Permalink
Use new shape_infer for glu
Browse files Browse the repository at this point in the history
  • Loading branch information
mitruska committed Nov 26, 2024
1 parent 87eacdf commit b8ad0ba
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 21 deletions.
3 changes: 0 additions & 3 deletions src/common/transformations/include/ov_ops/glu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,6 @@ class TRANSFORMATIONS_API GLU : public ov::op::Op {
ov::element::Type m_output_type{};
};

TRANSFORMATIONS_API std::vector<ov::PartialShape> shape_infer(const GLU* op,
std::vector<ov::PartialShape> input_shapes);

} // namespace internal
} // namespace op
} // namespace ov
13 changes: 3 additions & 10 deletions src/common/transformations/src/ov_ops/glu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,9 @@ bool GLU::visit_attributes(ov::AttributeVisitor& visitor) {
void GLU::validate_and_infer_types() {
auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type;

set_output_type(0, output_type, shape_infer(this, {get_input_partial_shape(0)})[0]);

// const auto input_shapes = ov::util::get_node_input_partial_shapes(*this);
// const auto output_shapes = shape_infer(this, input_shapes);
// set_output_type(0, output_type, output_shapes[0]);
const auto input_shapes = ov::util::get_node_input_partial_shapes(*this);
const auto output_shapes = shape_infer(this, input_shapes);
set_output_type(0, output_type, output_shapes[0]);
}

std::shared_ptr<Node> GLU::clone_with_new_inputs(const ov::OutputVector& new_args) const {
Expand All @@ -53,11 +51,6 @@ std::shared_ptr<Node> GLU::clone_with_new_inputs(const ov::OutputVector& new_arg
m_split_to_glu_idx,
m_output_type);
}

std::vector<ov::PartialShape> shape_infer(const GLU* op, std::vector<ov::PartialShape> input_shapes) {
return glu_shape_infer(op, input_shapes);
}

} // namespace internal
} // namespace op
} // namespace ov
6 changes: 3 additions & 3 deletions src/core/shape_inference/include/glu_shape_inference.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ namespace ov {
namespace op {
namespace internal {
template <class TShape, class TRShape = result_shape_t<TShape>>
std::vector<TRShape> glu_shape_infer(const GLU* op,
const std::vector<TShape>& input_shapes,
const ITensorAccessor& tensor_accessor = make_tensor_accessor()) {
std::vector<TRShape> shape_infer(const GLU* op,
const std::vector<TShape>& input_shapes,
const ITensorAccessor& tensor_accessor = make_tensor_accessor()) {
const auto inputs_count = input_shapes.size();
NODE_SHAPE_INFER_CHECK(op, input_shapes, inputs_count == 1);

Expand Down
7 changes: 2 additions & 5 deletions src/plugins/intel_gpu/src/graph/swiglu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
//

#include "ov_ops/glu.hpp"
#include "glu_shape_inference.hpp"
#include "swiglu_inst.h"

#include "primitive_type_base.h"
Expand Down Expand Up @@ -32,11 +33,7 @@ std::vector<layout> swiglu_inst::calc_output_layouts(swiglu_node const& /*node*/
op.set_axis(desc->axis);
op.set_split_lengths(desc->split_lengths);

std::vector<ShapeType> input_shapes = {
impl_param.get_input_layout(0).get<ShapeType>(),
ShapeType(ov::Shape({})),
ShapeType(ov::Shape{2})
};
std::vector<ShapeType> input_shapes = {impl_param.get_input_layout(0).get<ShapeType>()};

std::vector<ShapeType> output_shapes = shape_infer(&op, input_shapes);

Expand Down

0 comments on commit b8ad0ba

Please sign in to comment.