Skip to content

Commit

Permalink
Add shape_infer for GLU op
Browse files Browse the repository at this point in the history
  • Loading branch information
mitruska committed Nov 25, 2024
1 parent 0942797 commit 87eacdf
Show file tree
Hide file tree
Showing 5 changed files with 57 additions and 23 deletions.
1 change: 0 additions & 1 deletion src/common/transformations/include/ov_ops/glu.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@ class TRANSFORMATIONS_API GLU : public ov::op::Op {
ov::element::Type m_output_type{};
};

// TODO 157615: Move to shape_inference
TRANSFORMATIONS_API std::vector<ov::PartialShape> shape_infer(const GLU* op,
std::vector<ov::PartialShape> input_shapes);

Expand Down
23 changes: 6 additions & 17 deletions src/common/transformations/src/ov_ops/glu.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,9 @@

#include "ov_ops/glu.hpp"

#include "glu_shape_inference.hpp"
#include "openvino/core/partial_shape.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/variadic_split.hpp"
#include "variadic_split_shape_inference.hpp"

namespace ov {
namespace op {
Expand Down Expand Up @@ -38,11 +37,11 @@ bool GLU::visit_attributes(ov::AttributeVisitor& visitor) {
void GLU::validate_and_infer_types() {
auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type;

std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0),
ov::PartialShape(ov::Shape{}),
ov::PartialShape(ov::Shape{2})};
set_output_type(0, output_type, shape_infer(this, {get_input_partial_shape(0)})[0]);

set_output_type(0, output_type, shape_infer(this, input_shapes)[0]);
// const auto input_shapes = ov::util::get_node_input_partial_shapes(*this);
// const auto output_shapes = shape_infer(this, input_shapes);
// set_output_type(0, output_type, output_shapes[0]);
}

std::shared_ptr<Node> GLU::clone_with_new_inputs(const ov::OutputVector& new_args) const {
Expand All @@ -56,17 +55,7 @@ std::shared_ptr<Node> GLU::clone_with_new_inputs(const ov::OutputVector& new_arg
}

std::vector<ov::PartialShape> shape_infer(const GLU* op, std::vector<ov::PartialShape> input_shapes) {
ov::op::v1::VariadicSplit variadic_split;
std::vector<int64_t> axis = {op->get_axis()};
std::vector<int64_t> split_lengths = {op->get_split_lengths(), -1};

std::unordered_map<size_t, ov::Tensor> const_data;
const_data.emplace(1, ov::Tensor(ov::element::i64, ov::Shape{}, static_cast<void*>(axis.data())));
const_data.emplace(
2,
ov::Tensor(ov::element::i64, ov::Shape{split_lengths.size()}, static_cast<void*>(split_lengths.data())));

return ov::op::v1::shape_infer(&variadic_split, input_shapes, ov::make_tensor_accessor(const_data));
return glu_shape_infer(op, input_shapes);
}

} // namespace internal
Expand Down
35 changes: 35 additions & 0 deletions src/core/shape_inference/include/glu_shape_inference.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
// Copyright (C) 2018-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "ov_ops/glu.hpp"
#include "utils.hpp"
#include "variadic_split_shape_inference.hpp"

namespace ov {
namespace op {
namespace internal {
template <class TShape, class TRShape = result_shape_t<TShape>>
std::vector<TRShape> glu_shape_infer(const GLU* op,
const std::vector<TShape>& input_shapes,
const ITensorAccessor& tensor_accessor = make_tensor_accessor()) {
const auto inputs_count = input_shapes.size();
NODE_SHAPE_INFER_CHECK(op, input_shapes, inputs_count == 1);

int64_t axis = op->get_axis();
std::vector<int64_t> split_lengths = {op->get_split_lengths(), -1};
std::vector<TShape> variadic_split_input_shapes = {input_shapes[0], TShape{}, TShape{2}};

std::unordered_map<size_t, ov::Tensor> const_data;
const_data.emplace(1, ov::Tensor(ov::element::i64, ov::Shape{}, &axis));
const_data.emplace(2, ov::Tensor(ov::element::i64, ov::Shape{split_lengths.size()}, split_lengths.data()));

return ov::op::v1::variadic_split::shape_infer(op,
variadic_split_input_shapes,
ov::make_tensor_accessor(const_data));
}
} // namespace internal
} // namespace op
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,9 @@ namespace ov {
namespace op {
namespace v1 {

namespace variadic_split {
template <typename T, class TRShape = result_shape_t<T>>
std::vector<TRShape> shape_infer(const VariadicSplit* op,
std::vector<TRShape> shape_infer(const Node* op,
const std::vector<T>& input_shapes,
const ITensorAccessor& ta = make_tensor_accessor()) {
constexpr bool is_dynamic_shape = std::is_base_of<ov::PartialShape, T>::value;
Expand Down Expand Up @@ -120,6 +121,14 @@ std::vector<TRShape> shape_infer(const VariadicSplit* op,
}
return output_shapes;
}
} // namespace variadic_split

template <typename T, class TRShape = result_shape_t<T>>
std::vector<TRShape> shape_infer(const VariadicSplit* op,
const std::vector<T>& input_shapes,
const ITensorAccessor& ta = make_tensor_accessor()) {
return variadic_split::shape_infer(op, input_shapes, ta);
}

} // namespace v1
} // namespace op
Expand Down
10 changes: 6 additions & 4 deletions src/core/src/op/variadic_split.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ bool VariadicSplit::evaluate(TensorVector& outputs, const TensorVector& inputs)
++out_partial_shape;
}

return variadic_split::evaluate(outputs, inputs);
return op::variadic_split::evaluate(outputs, inputs);
} else {
return false;
}
Expand All @@ -112,16 +112,18 @@ bool VariadicSplit::has_evaluate() const {

bool VariadicSplit::evaluate_lower(TensorVector& output_values) const {
OV_OP_SCOPE(v1_Split_evaluate_lower);
return variadic_split::has_axis_and_splits_bound_set(this) && default_lower_bound_evaluator(this, output_values);
return op::variadic_split::has_axis_and_splits_bound_set(this) &&
default_lower_bound_evaluator(this, output_values);
}

bool VariadicSplit::evaluate_upper(TensorVector& output_values) const {
OV_OP_SCOPE(v1_Split_evaluate_upper);
return variadic_split::has_axis_and_splits_bound_set(this) && default_upper_bound_evaluator(this, output_values);
return op::variadic_split::has_axis_and_splits_bound_set(this) &&
default_upper_bound_evaluator(this, output_values);
}

bool VariadicSplit::evaluate_symbol(TensorSymbolVector& output_symbols) const {
return variadic_split::has_axis_and_splits_bound_set(this) &&
return op::variadic_split::has_axis_and_splits_bound_set(this) &&
ov::util::default_symbol_evaluator(this, output_symbols);
}
} // namespace v1
Expand Down

0 comments on commit 87eacdf

Please sign in to comment.