Skip to content

Commit

Permalink
[Op][Internal][GPU] SwiGLU op internal common (#27579)
Browse files Browse the repository at this point in the history
### Details:
- This PR makes GPU SwiGLU to be internal op and swiglu_fusion available
in common_optimizations, possible to be reused by other plugins
- Only necessary updates including style alignment, no logic or
functional changes intended, basically the op has been moved as is, to
not complicate review and avoid issues
 
- Needed to link openvino::runtime::dev
src/plugins/intel_gpu/src/kernel_selector/CMakeLists.txt
https://github.com/openvinotoolkit/openvino/blob/7566bc94c58a501389647b4cc4d7c21df311fa63/src/plugins/intel_gpu/src/kernel_selector/CMakeLists.txt#L78
for visibility of `ov_ops/swiglu.hpp` in

https://github.com/openvinotoolkit/openvino/blob/7566bc94c58a501389647b4cc4d7c21df311fa63/src/plugins/intel_gpu/src/kernel_selector/kernels/swiglu/swiglu_kernel_base.h#L9

Tickets for follow ups: 157623, 157615.

### Tickets:
 - 155542, 138911
  • Loading branch information
mitruska authored Nov 21, 2024
1 parent 9d36703 commit 477722d
Show file tree
Hide file tree
Showing 14 changed files with 158 additions and 118 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,22 +5,19 @@
#pragma once

#include "openvino/op/op.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace intel_gpu {
namespace op {
namespace internal {

/// \brief Operator performing Swish Gated Linear Unit Activation
/// This operation performs gated linear unit activation that combines swish or gelu activation function
class SwiGLU : public ov::op::Op {
class TRANSFORMATIONS_API SwiGLU : public ov::op::Op {
public:
OPENVINO_OP("SwiGLU", "gpu_opset");
OPENVINO_OP("SwiGLU", "ie_internal_opset");

enum GluType {
Swish = 0,
Gelu,
Gelu_Tanh
};
enum GluType { Swish = 0, Gelu, Gelu_Tanh };

SwiGLU() = default;
/// \brief Constructs an SwiGLU operation.
Expand All @@ -44,26 +41,44 @@ class SwiGLU : public ov::op::Op {

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

int64_t get_axis() const { return m_axis; }
int64_t get_split_lengths() const { return m_split_lengths; }
GluType get_glu_type() const { return m_glu_type; }
size_t get_split_to_glu_idx() const { return m_split_to_glu_idx; }
int64_t get_axis() const {
return m_axis;
}
int64_t get_split_lengths() const {
return m_split_lengths;
}
GluType get_glu_type() const {
return m_glu_type;
}
size_t get_split_to_glu_idx() const {
return m_split_to_glu_idx;
}

void set_axis(int64_t axis) { m_axis = axis; }
void set_split_lengths(int64_t split_lengths) { m_split_lengths = split_lengths; }
void set_glu_type(GluType glu_type) { m_glu_type = glu_type; }
void set_split_to_glu_idx(size_t split_to_glu_idx) { m_split_to_glu_idx = split_to_glu_idx; }
void set_axis(int64_t axis) {
m_axis = axis;
}
void set_split_lengths(int64_t split_lengths) {
m_split_lengths = split_lengths;
}
void set_glu_type(GluType glu_type) {
m_glu_type = glu_type;
}
void set_split_to_glu_idx(size_t split_to_glu_idx) {
m_split_to_glu_idx = split_to_glu_idx;
}

private:
int64_t m_axis = 0;
int64_t m_split_lengths = 0;
GluType m_glu_type = GluType::Swish;
size_t m_split_to_glu_idx = 0;
ov::element::Type m_output_type;
ov::element::Type m_output_type{};
};

std::vector<ov::PartialShape> shape_infer(const SwiGLU* op, std::vector<ov::PartialShape> input_shapes);
// TODO 157615: Move to shape_inference
TRANSFORMATIONS_API std::vector<ov::PartialShape> shape_infer(const SwiGLU* op,
std::vector<ov::PartialShape> input_shapes);

} // namespace op
} // namespace intel_gpu
} // namespace ov
} // namespace internal
} // namespace op
} // namespace ov
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/pass/manager.hpp"
#include "openvino/pass/matcher_pass.hpp"
#include "transformations_visibility.hpp"

namespace ov {
namespace pass {

class TRANSFORMATIONS_API SwiGLUFusion : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("SwiGLUFusion", "0");
SwiGLUFusion();
};

} // namespace pass
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -2,24 +2,29 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "intel_gpu/op/swiglu.hpp"
#include "ov_ops/swiglu.hpp"

#include "openvino/core/partial_shape.hpp"
#include "openvino/core/validation_util.hpp"
#include "openvino/op/variadic_split.hpp"
#include "variadic_split_shape_inference.hpp"

namespace ov {
namespace intel_gpu {
namespace op {
namespace internal {

SwiGLU::SwiGLU(const Output<Node>& data,
int64_t axis,
int64_t split_lengths,
const GluType glu_type,
const size_t split_to_glu_idx,
const ov::element::Type output_type)
: Op({data}), m_axis(axis), m_split_lengths(split_lengths),
m_glu_type(glu_type), m_split_to_glu_idx(split_to_glu_idx), m_output_type(output_type) {
: Op({data}),
m_axis(axis),
m_split_lengths(split_lengths),
m_glu_type(glu_type),
m_split_to_glu_idx(split_to_glu_idx),
m_output_type(output_type) {
validate_and_infer_types();
}

Expand All @@ -33,11 +38,9 @@ bool SwiGLU::visit_attributes(ov::AttributeVisitor& visitor) {
void SwiGLU::validate_and_infer_types() {
auto output_type = m_output_type == ov::element::undefined ? get_input_element_type(0) : m_output_type;

std::vector<ov::PartialShape> input_shapes = {
get_input_partial_shape(0),
ov::PartialShape(ov::Shape{}),
ov::PartialShape(ov::Shape{2})
};
std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0),
ov::PartialShape(ov::Shape{}),
ov::PartialShape(ov::Shape{2})};

set_output_type(0, output_type, shape_infer(this, input_shapes)[0]);
}
Expand All @@ -54,16 +57,18 @@ std::shared_ptr<Node> SwiGLU::clone_with_new_inputs(const ov::OutputVector& new_

std::vector<ov::PartialShape> shape_infer(const SwiGLU* op, std::vector<ov::PartialShape> input_shapes) {
ov::op::v1::VariadicSplit variadic_split;
std::vector<int64_t> axis = { op->get_axis() };
std::vector<int64_t> split_lengths = { op->get_split_lengths(), -1 };
std::vector<int64_t> axis = {op->get_axis()};
std::vector<int64_t> split_lengths = {op->get_split_lengths(), -1};

std::unordered_map<size_t, ov::Tensor> const_data;
const_data.emplace(1, ov::Tensor(ov::element::i64, ov::Shape{}, static_cast<void*>(axis.data())));
const_data.emplace(2, ov::Tensor(ov::element::i64, ov::Shape{split_lengths.size()}, static_cast<void*>(split_lengths.data())));
const_data.emplace(
2,
ov::Tensor(ov::element::i64, ov::Shape{split_lengths.size()}, static_cast<void*>(split_lengths.data())));

return ov::op::v1::shape_infer(&variadic_split, input_shapes, ov::make_tensor_accessor(const_data));
}

} // namespace internal
} // namespace op
} // namespace intel_gpu
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,22 @@
// SPDX-License-Identifier: Apache-2.0
//

#include "swiglu_fusion.hpp"

#include "intel_gpu/op/swiglu.hpp"
#include "transformations/common_optimizations/swiglu_fusion.hpp"

#include "openvino/core/rt_info.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/gelu.hpp"
#include "openvino/op/multiply.hpp"
#include "openvino/op/swish.hpp"
#include "openvino/op/gelu.hpp"
#include "openvino/op/variadic_split.hpp"
#include "openvino/pass/manager.hpp"
#include "openvino/pass/pattern/op/or.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "ov_ops/swiglu.hpp"
#include "transformations/utils/utils.hpp"

namespace ov {
namespace intel_gpu {
namespace pass {

SwiGLUFusion::SwiGLUFusion() {
using namespace ov::pass::pattern;
Expand Down Expand Up @@ -60,20 +60,21 @@ SwiGLUFusion::SwiGLUFusion() {
auto isSwiGLU = pattern_map.count(swish_m);
auto isGeGLU = pattern_map.count(gelu_m);
size_t split_to_glu_idx = 0;
ov::intel_gpu::op::SwiGLU::GluType glu_type = ov::intel_gpu::op::SwiGLU::GluType::Swish;
ov::op::internal::SwiGLU::GluType glu_type = ov::op::internal::SwiGLU::GluType::Swish;

if (isSwiGLU) {
auto swish = std::dynamic_pointer_cast<ov::op::v4::Swish>(pattern_map.at(swish_m).get_node_shared_ptr());
glu_type = ov::intel_gpu::op::SwiGLU::GluType::Swish;
glu_type = ov::op::internal::SwiGLU::GluType::Swish;
split_to_glu_idx = swish->input_value(0).get_index();

size_t split_in_idx = ov::is_type<ov::op::v4::Swish>(mul->get_input_node_shared_ptr(0)) ? 1 : 0;
if (mul->input_value(split_in_idx).get_index() == split_to_glu_idx)
return false;
} else if (isGeGLU) {
auto gelu = std::dynamic_pointer_cast<ov::op::v7::Gelu>(pattern_map.at(gelu_m).get_node_shared_ptr());
glu_type = (gelu->get_approximation_mode() == ov::op::GeluApproximationMode::ERF) ? ov::intel_gpu::op::SwiGLU::GluType::Gelu
: ov::intel_gpu::op::SwiGLU::GluType::Gelu_Tanh;
glu_type = (gelu->get_approximation_mode() == ov::op::GeluApproximationMode::ERF)
? ov::op::internal::SwiGLU::GluType::Gelu
: ov::op::internal::SwiGLU::GluType::Gelu_Tanh;
split_to_glu_idx = gelu->input_value(0).get_index();

size_t split_in_idx = ov::is_type<ov::op::v7::Gelu>(mul->get_input_node_shared_ptr(0)) ? 1 : 0;
Expand All @@ -83,7 +84,8 @@ SwiGLUFusion::SwiGLUFusion() {
OPENVINO_THROW("'glu_type' not initialized");
}

auto variadic_split = std::dynamic_pointer_cast<ov::op::v1::VariadicSplit>(pattern_map.at(variadic_split_m).get_node_shared_ptr());
auto variadic_split = std::dynamic_pointer_cast<ov::op::v1::VariadicSplit>(
pattern_map.at(variadic_split_m).get_node_shared_ptr());
auto variadic_split_in_ps = variadic_split->get_input_partial_shape(0);
auto last_dim = variadic_split_in_ps.rank().get_length() - 1;

Expand All @@ -94,7 +96,8 @@ SwiGLUFusion::SwiGLUFusion() {
return false;
auto axis_value = axis->cast_vector<int64_t>()[0];

auto split_lengths = std::dynamic_pointer_cast<ov::op::v0::Constant>(pattern_map.at(split_lengths_const_m).get_node_shared_ptr());
auto split_lengths = std::dynamic_pointer_cast<ov::op::v0::Constant>(
pattern_map.at(split_lengths_const_m).get_node_shared_ptr());
auto split_lengths_value = split_lengths->cast_vector<int64_t>()[0];
// Allow only case that exactly splits in half along the last dimension
auto split_length = variadic_split_in_ps[last_dim].get_length() / 2;
Expand All @@ -104,12 +107,12 @@ SwiGLUFusion::SwiGLUFusion() {
auto data = pattern_map.at(data_m);
auto output_type = m.get_match_root()->get_output_element_type(0);

auto swiglu = std::make_shared<op::SwiGLU>(data,
axis_value,
split_lengths_value,
glu_type,
split_to_glu_idx,
output_type);
auto swiglu = std::make_shared<ov::op::internal::SwiGLU>(data,
axis_value,
split_lengths_value,
glu_type,
split_to_glu_idx,
output_type);
swiglu->set_friendly_name(m.get_match_root()->get_friendly_name());
ov::copy_runtime_info(m.get_matched_nodes(), swiglu);
ov::replace_node(m.get_match_root(), swiglu);
Expand All @@ -121,5 +124,5 @@ SwiGLUFusion::SwiGLUFusion() {
this->register_matcher(m, callback);
}

} // namespace intel_gpu
} // namespace pass
} // namespace ov
Loading

0 comments on commit 477722d

Please sign in to comment.