diff --git a/src/common/transformations/include/ov_ops/rotary_positional_embeddings.hpp b/src/common/transformations/include/ov_ops/rotary_positional_embeddings.hpp index dcb9aef187d2d9..08c1aa8e3f5ad8 100644 --- a/src/common/transformations/include/ov_ops/rotary_positional_embeddings.hpp +++ b/src/common/transformations/include/ov_ops/rotary_positional_embeddings.hpp @@ -23,13 +23,14 @@ class TRANSFORMATIONS_API RoPE : public Op { struct Config { size_t slice_start = 0; // slice inner-most dimensions of input size_t slice_stop = 0; - bool input_trans0213 = false; // transpose input dim 1&2 - bool is_interleaved = false; // interleaved mode, implies trans0213 happens after RoPE - size_t rotary_ndims = 0; // dimensions to be embedded (d in the description) - bool is_chatglm = false; // chatglm is special which overrides other setting - bool support_2d_rope = false; // 2d rope mode, Support 2 dimentional rope which is independant of batch and - // each head. change input order to [batch, head_cnt, 4608] to support 2d rope - bool is_qwen = false; // Qwen is special which overrides other setting + bool input_trans0213 = false; // transpose input dim 1&2 + bool output_trans0213 = false; // implies trans0213 happens after RoPE + bool is_interleaved = false; // coordinates are interleaved + size_t rotary_ndims = 0; // dimensions to be embedded (d in the description) + bool is_chatglm = false; // chatglm is special which overrides other setting + bool support_2d_rope = false; // 2d rope mode, Support 2 dimentional rope which is independant of batch and + // each head. 
change input order to [batch, head_cnt, 4608] to support 2d rope + bool is_qwen = false; // Qwen is special which overrides other setting size_t head_cnt = 0; size_t head_size = 0; int gather_position_arg_id = diff --git a/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp b/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp index eb1c92bcf9607f..3449151ab93ac5 100644 --- a/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp +++ b/src/common/transformations/include/transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp @@ -12,6 +12,7 @@ namespace pass { class TRANSFORMATIONS_API RoPEFusion; class TRANSFORMATIONS_API RoPEFusionGPTNEOX; +class TRANSFORMATIONS_API RoPEFusionFlux; class TRANSFORMATIONS_API RoPEFusionGPTJ; class TRANSFORMATIONS_API RoPEFusionChatGLM; class TRANSFORMATIONS_API RoPEFusionQwen; @@ -29,6 +30,12 @@ class ov::pass::RoPEFusionGPTNEOX : public ov::pass::MatcherPass { RoPEFusionGPTNEOX(); }; +class ov::pass::RoPEFusionFlux : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("RoPEFusionFlux", "0"); + RoPEFusionFlux(); +}; + class ov::pass::RoPEFusionGPTJ : public ov::pass::MatcherPass { public: OPENVINO_RTTI("RoPEFusionGPTJ", "0"); @@ -85,6 +92,7 @@ class ov::pass::RoPEFusion : public ov::pass::GraphRewrite { public: OPENVINO_RTTI("RoPEFusion", "0"); RoPEFusion(bool support_2d_rope = false) { + add_matcher(); add_matcher(); add_matcher(); // optional heads & tails are fused in separate matcher pass, diff --git a/src/common/transformations/src/ov_ops/rotary_positional_embeddings.cpp b/src/common/transformations/src/ov_ops/rotary_positional_embeddings.cpp index 3e75e2b88df266..88a42a7f456db1 100644 --- a/src/common/transformations/src/ov_ops/rotary_positional_embeddings.cpp +++ b/src/common/transformations/src/ov_ops/rotary_positional_embeddings.cpp 
@@ -76,7 +76,7 @@ void RoPE::validate_and_infer_types() { if (m_config.input_trans0213) { // transpose 0213 ([B,L,H,S]=>[B,H,L,S]) happens before RoPE std::swap(input_pshape[2], input_pshape[1]); - } else if (m_config.is_interleaved) { + } else if (m_config.output_trans0213) { // transpose 0213 ([B,L,H,S]=>[B,H,L,S]) happens after RoPE std::swap(input_pshape[2], input_pshape[1]); } @@ -90,6 +90,7 @@ bool RoPE::visit_attributes(ov::AttributeVisitor& visitor) { visitor.on_attribute("slice_start", m_config.slice_start); visitor.on_attribute("slice_stop", m_config.slice_stop); visitor.on_attribute("input_trans0213", m_config.input_trans0213); + visitor.on_attribute("output_trans0213", m_config.output_trans0213); visitor.on_attribute("is_interleaved", m_config.is_interleaved); visitor.on_attribute("rotary_ndims", m_config.rotary_ndims); visitor.on_attribute("is_chatglm", m_config.is_chatglm); diff --git a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp index f002e0043a8744..ec49dd7152fed1 100644 --- a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -23,6 +23,94 @@ using namespace ov::gen_pattern; +ov::pass::RoPEFusionFlux::RoPEFusionFlux() { + MATCHER_SCOPE(RoPEFusionFlux); + // x[?,24,?,128] + // x1 = reshape(x, [?,24,?,64,2]) + // x1_0, x1_1 = split(x1, -1) + // x2 = concat(x1_0, x1_1 * (-1), -1) + // x3 = reshape(x2, [?,24,?,128]) + // y1 = x * t_cos + // y2 = x3 * t_sin + // y = y1 + y2 + auto x = makePattern(ov::Rank(4)); + auto t_cos = makePattern(ov::Rank(4)); + auto t_sin = makePattern(ov::Rank(4)); + + auto num_heads = ov::gen_pattern::Symbol("num_heads"); + auto head_size = ov::gen_pattern::Symbol("head_size"); + + auto 
x1_target_shape = makeConst({0, num_heads, 0, -1, 2}); + auto x1 = makePattern({x, x1_target_shape}, {{"special_zero", true}}); + auto split = makePattern({x1, -1}, {{"num_splits", 2}}); + split->set_output_size(2); + + // 3 versions of multiply by -1 depending on transformations execution prior to this pass + auto x1_1_neg_1 = makePattern({split->output(1), -1.0f}, {{"auto_broadcast", "numpy"}}); + + auto squeeze_2 = makePattern({split->output(1), -1}); + auto x1_1_neg_2 = makePattern({squeeze_2, -1.0f}, {{"auto_broadcast", "numpy"}}); + auto unsqueeze_2 = makePattern({x1_1_neg_2, -1}); + + auto x1_1_neg_3 = makePattern({split->output(1), -1.0f}, {{"auto_broadcast", "numpy"}}); + auto squeeze_3 = makePattern({x1_1_neg_3, -1}); + auto unsqueeze_3 = makePattern({squeeze_3, -1}); + + auto x2 = makePattern({x1_1_neg_1 | unsqueeze_2 | unsqueeze_3, split->output(0)}, {{"axis", -1}}); + auto x3_target_shape = makeConst({0, num_heads, 0, head_size}); + auto x3 = makePattern({x2, x3_target_shape}, {{"special_zero", true}}); + + auto y1 = makePattern({x, t_cos}, {{"auto_broadcast", "numpy"}}); + auto y2 = makePattern({x3, t_sin}, {{"auto_broadcast", "numpy"}}); + + auto y = makePattern({y1, y2}, {{"auto_broadcast", "numpy"}}); + auto result = y; + + matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) { + PatternValidator validator(m); + if (!validator) { + return false; + } + + const auto& pattern_map = m.get_pattern_value_map(); + auto root = m.get_match_root(); + + op::internal::RoPE::Config config; + config.head_cnt = static_cast(validator["num_heads"]); + config.head_size = static_cast(validator["head_size"]); + config.rotary_ndims = config.head_size; + config.is_interleaved = true; + config.output_trans0213 = false; + + OutputVector new_args; + new_args.push_back(pattern_map.at(x)); + new_args.push_back(pattern_map.at(t_cos)); + new_args.push_back(pattern_map.at(t_sin)); + + auto old_node = root; + auto new_node = 
std::make_shared(new_args, config); + new_node->set_friendly_name(old_node->get_friendly_name()); + ov::copy_runtime_info({pattern_map.at(x1).get_node_shared_ptr(), + pattern_map.at(split).get_node_shared_ptr(), + pattern_map.at(x2).get_node_shared_ptr(), + pattern_map.at(x3).get_node_shared_ptr(), + pattern_map.at(y1).get_node_shared_ptr(), + pattern_map.at(y2).get_node_shared_ptr(), + pattern_map.at(result).get_node_shared_ptr()}, + new_node); + + ov::replace_node(old_node, new_node); + + // this new node may match following additional matchers + register_new_node(new_node); + + return true; + }; + + auto m = std::make_shared(result, matcher_name); + this->register_matcher(m, callback); +} + ov::pass::RoPEFusionGPTNEOX::RoPEFusionGPTNEOX() { MATCHER_SCOPE(RoPEFusionGPTNEOX); @@ -373,6 +461,7 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { OutputVector new_args; config.rotary_ndims = static_cast(validator["ndims"]); + config.output_trans0213 = true; config.is_interleaved = true; // input is [B,L,H,S] diff --git a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp index ea928de5c01702..a42e11120d7276 100644 --- a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -6,7 +6,9 @@ #include +#include "common_test_utils/graph_comparator.hpp" #include "common_test_utils/ov_test_utils.hpp" +#include "openvino/core/node_vector.hpp" #include "openvino/opsets/opset1.hpp" #include "openvino/opsets/opset3.hpp" #include "ov_ops/rotary_positional_embeddings.hpp" @@ -133,6 +135,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_LLama2_no_gather) { {{"config.slice_start", 0}, {"config.slice_stop", 0}, {"config.input_trans0213", true}, + {"config.output_trans0213", false}, {"config.is_interleaved", false}, 
{"config.is_chatglm", false}, {"config.support_2d_rope", false}, @@ -169,6 +172,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_LLama2_with_gather) { {{"config.slice_start", 0}, {"config.slice_stop", 0}, {"config.input_trans0213", true}, + {"config.output_trans0213", false}, {"config.is_interleaved", false}, {"config.is_chatglm", false}, {"config.support_2d_rope", false}, @@ -308,6 +312,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_GPTNEOX_no_gather) { {{"config.slice_start", 0}, {"config.slice_stop", ndims}, {"config.input_trans0213", true}, + {"config.output_trans0213", false}, {"config.is_interleaved", false}, {"config.is_chatglm", false}, {"config.support_2d_rope", false}, @@ -343,6 +348,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_GPTNEOX_with_gather) { {{"config.slice_start", 0}, {"config.slice_stop", ndims}, {"config.input_trans0213", true}, + {"config.output_trans0213", false}, {"config.is_interleaved", false}, {"config.is_chatglm", false}, {"config.support_2d_rope", false}, @@ -459,6 +465,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_GPTJ) { {{"config.slice_start", 0}, {"config.slice_stop", 0}, {"config.input_trans0213", false}, + {"config.output_trans0213", true}, {"config.is_interleaved", true}, {"config.is_chatglm", false}, {"config.support_2d_rope", false}, @@ -568,6 +575,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGML) { {{"config.slice_start", 0}, {"config.slice_stop", 4096}, {"config.input_trans0213", false}, + {"config.output_trans0213", false}, {"config.is_interleaved", false}, {"config.rotary_ndims", rotary_ndims}, {"config.is_chatglm", true}, @@ -646,6 +654,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGML_Slice) { {{"config.slice_start", 0}, {"config.slice_stop", 4096}, {"config.input_trans0213", false}, + {"config.output_trans0213", false}, {"config.is_interleaved", false}, {"config.rotary_ndims", rotary_ndims}, {"config.is_chatglm", true}, @@ -728,6 +737,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_GPTJ_Slice) { 
{{"config.slice_start", 0}, {"config.slice_stop", 0}, {"config.input_trans0213", false}, + {"config.output_trans0213", true}, {"config.is_interleaved", true}, {"config.is_chatglm", false}, {"config.support_2d_rope", false}, @@ -843,6 +853,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGML_2d_rope) { {{"config.slice_start", 0}, {"config.slice_stop", 4096}, {"config.input_trans0213", false}, + {"config.output_trans0213", false}, {"config.is_interleaved", false}, {"config.rotary_ndims", rotary_ndims}, {"config.is_chatglm", true}, @@ -951,6 +962,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGML_nano_2d_rope) { {{"config.slice_start", 0}, {"config.slice_stop", 2048}, {"config.input_trans0213", false}, + {"config.output_trans0213", false}, {"config.is_interleaved", false}, {"config.rotary_ndims", rotary_ndims}, {"config.is_chatglm", true}, @@ -962,4 +974,160 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGML_nano_2d_rope) { model_ref = std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, cos_sin_cache, position_ids}); } -} \ No newline at end of file +} + +TEST_F(TransformationTestsF, ConvertToROPE_Flux_mul) { + disable_rt_info_check(); + const int batch = 2; + const int num_heads = 32; + const int ndims = 128; + { + auto x = + std::make_shared(ov::element::f32, ov::PartialShape{batch, num_heads, -1, ndims}); + auto t_cos = std::make_shared(ov::element::f32, ov::PartialShape{1, 1, -1, ndims}); + auto t_sin = std::make_shared(ov::element::f32, ov::PartialShape{1, 1, -1, ndims}); + + auto x1_shape = makeConst(ov::element::i64, ov::Shape({5}), {0, num_heads, 0, -1, 2}); + auto x1 = std::make_shared(x, x1_shape, true); + + auto split_axis = makeConst(ov::element::i64, ov::Shape(), {-1}); + auto split = std::make_shared(x1, split_axis, 2); + + auto minus_one = makeConst(ov::element::f32, ov::Shape({}), {-1.0f}); + auto x1_1_neg = std::make_shared(split->output(1), minus_one); + + auto x2 = std::make_shared(ov::OutputVector{x1_1_neg->output(0), 
split->output(0)}, -1); + + auto x3_shape = makeConst(ov::element::i64, ov::Shape({4}), {0, num_heads, 0, ndims}); + auto x3 = std::make_shared(x2, x3_shape, true); + + auto y1 = std::make_shared(x, t_cos); + auto y2 = std::make_shared(x3, t_sin); + auto y = std::make_shared(y1, y2); + + model = std::make_shared(ov::NodeVector{y}, ov::ParameterVector{x, t_cos, t_sin}); + } + manager.register_pass(true); + { + auto x = + std::make_shared(ov::element::f32, ov::PartialShape{batch, num_heads, -1, ndims}); + auto t_cos = std::make_shared(ov::element::f32, ov::PartialShape{1, 1, -1, ndims}); + auto t_sin = std::make_shared(ov::element::f32, ov::PartialShape{1, 1, -1, ndims}); + ov::op::internal::RoPE::Config config; + config.is_interleaved = true; + config.rotary_ndims = ndims; + config.head_cnt = num_heads; + config.head_size = ndims; + auto rope = std::make_shared(ov::OutputVector{x, t_cos, t_sin}, config); + model_ref = std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{x, t_cos, t_sin}); + } + comparator.enable(FunctionsComparator::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, ConvertToROPE_Flux_squeeze_mul_unsqueeze) { + disable_rt_info_check(); + const int batch = 2; + const int num_heads = 32; + const int ndims = 128; + { + auto x = + std::make_shared(ov::element::f32, ov::PartialShape{batch, num_heads, -1, ndims}); + auto t_cos = std::make_shared(ov::element::f32, ov::PartialShape{1, 1, -1, ndims}); + auto t_sin = std::make_shared(ov::element::f32, ov::PartialShape{1, 1, -1, ndims}); + + auto x1_shape = makeConst(ov::element::i64, ov::Shape({5}), {0, num_heads, 0, -1, 2}); + auto x1 = std::make_shared(x, x1_shape, true); + + auto split_axis = makeConst(ov::element::i64, ov::Shape(), {-1}); + auto split = std::make_shared(x1, split_axis, 2); + + auto squeeze_axis = makeConst(ov::element::i32, ov::Shape({}), {-1}); + auto squeeze = std::make_shared(split->output(1), squeeze_axis); + + auto minus_one = makeConst(ov::element::f32, ov::Shape({}), {-1.0f}); + 
auto x1_1_neg = std::make_shared(squeeze, minus_one); + + auto unsqueeze_axis = makeConst(ov::element::i32, ov::Shape({}), {-1}); + auto unsqueeze = std::make_shared(x1_1_neg, unsqueeze_axis); + + auto x2 = std::make_shared(ov::OutputVector{unsqueeze->output(0), split->output(0)}, -1); + + auto x3_shape = makeConst(ov::element::i64, ov::Shape({4}), {0, num_heads, 0, ndims}); + auto x3 = std::make_shared(x2, x3_shape, true); + + auto y1 = std::make_shared(x, t_cos); + auto y2 = std::make_shared(x3, t_sin); + auto y = std::make_shared(y1, y2); + + model = std::make_shared(ov::NodeVector{y}, ov::ParameterVector{x, t_cos, t_sin}); + } + manager.register_pass(true); + { + auto x = + std::make_shared(ov::element::f32, ov::PartialShape{batch, num_heads, -1, ndims}); + auto t_cos = std::make_shared(ov::element::f32, ov::PartialShape{1, 1, -1, ndims}); + auto t_sin = std::make_shared(ov::element::f32, ov::PartialShape{1, 1, -1, ndims}); + ov::op::internal::RoPE::Config config; + config.is_interleaved = true; + config.rotary_ndims = ndims; + config.head_cnt = num_heads; + config.head_size = ndims; + auto rope = std::make_shared(ov::OutputVector{x, t_cos, t_sin}, config); + model_ref = std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{x, t_cos, t_sin}); + } + comparator.enable(FunctionsComparator::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, ConvertToROPE_Flux_mul_squeeze_unsqueeze) { + disable_rt_info_check(); + const int batch = 2; + const int num_heads = 32; + const int ndims = 128; + { + auto x = + std::make_shared(ov::element::f32, ov::PartialShape{batch, num_heads, -1, ndims}); + auto t_cos = std::make_shared(ov::element::f32, ov::PartialShape{1, 1, -1, ndims}); + auto t_sin = std::make_shared(ov::element::f32, ov::PartialShape{1, 1, -1, ndims}); + + auto x1_shape = makeConst(ov::element::i64, ov::Shape({5}), {0, num_heads, 0, -1, 2}); + auto x1 = std::make_shared(x, x1_shape, true); + + auto split_axis = makeConst(ov::element::i64, ov::Shape(), {-1}); + 
auto split = std::make_shared(x1, split_axis, 2); + + auto minus_one = makeConst(ov::element::f32, ov::Shape({}), {-1.0f}); + auto x1_1_neg = std::make_shared(split->output(1), minus_one); + + auto squeeze_axis = makeConst(ov::element::i32, ov::Shape({}), {-1}); + auto squeeze = std::make_shared(x1_1_neg, squeeze_axis); + + auto unsqueeze_axis = makeConst(ov::element::i32, ov::Shape({}), {-1}); + auto unsqueeze = std::make_shared(squeeze, unsqueeze_axis); + + auto x2 = std::make_shared(ov::OutputVector{unsqueeze->output(0), split->output(0)}, -1); + + auto x3_shape = makeConst(ov::element::i64, ov::Shape({4}), {0, num_heads, 0, ndims}); + auto x3 = std::make_shared(x2, x3_shape, true); + + auto y1 = std::make_shared(x, t_cos); + auto y2 = std::make_shared(x3, t_sin); + auto y = std::make_shared(y1, y2); + + model = std::make_shared(ov::NodeVector{y}, ov::ParameterVector{x, t_cos, t_sin}); + } + manager.register_pass(true); + { + auto x = + std::make_shared(ov::element::f32, ov::PartialShape{batch, num_heads, -1, ndims}); + auto t_cos = std::make_shared(ov::element::f32, ov::PartialShape{1, 1, -1, ndims}); + auto t_sin = std::make_shared(ov::element::f32, ov::PartialShape{1, 1, -1, ndims}); + ov::op::internal::RoPE::Config config; + config.is_interleaved = true; + config.rotary_ndims = ndims; + config.head_cnt = num_heads; + config.head_size = ndims; + auto rope = std::make_shared(ov::OutputVector{x, t_cos, t_sin}, config); + model_ref = std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{x, t_cos, t_sin}); + } + comparator.enable(FunctionsComparator::ATTRIBUTES); +} diff --git a/src/plugins/intel_cpu/src/nodes/rope.cpp b/src/plugins/intel_cpu/src/nodes/rope.cpp index f089b67a122beb..73a23c5e7cdcd7 100644 --- a/src/plugins/intel_cpu/src/nodes/rope.cpp +++ b/src/plugins/intel_cpu/src/nodes/rope.cpp @@ -392,7 +392,7 @@ void RoPE::initSupportedPrimitiveDescriptors() { m_executor = std::make_shared>(m_config); rtPrecision = ov::element::f32; } - } else if 
(m_config.is_interleaved) { + } else if (m_config.is_interleaved && m_config.output_trans0213) { OPENVINO_ASSERT(m_config.input_trans0213 == false); OPENVINO_ASSERT(m_config.slice_start == 0); OPENVINO_ASSERT(m_config.slice_stop == 0); diff --git a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp index 9dd1da2d471e5a..27afb95a73a1e9 100644 --- a/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp +++ b/src/plugins/intel_cpu/src/transformations/transformation_pipeline.cpp @@ -842,6 +842,7 @@ void Transformations::PostLpt() { CPU_REGISTER_PASS_X64(postLPTPassManager, ov::pass::RoPEFusion, true); CPU_REGISTER_PASS_ARM64(postLPTPassManager, ov::pass::RoPEFusion, true); + CPU_DISABLE_PASS_COMMON(postLPTPassManager, ov::pass::RoPEFusionFlux); CPU_REGISTER_PASS_X64(postLPTPassManager, CausalMaskPreprocessFusion); #if defined(OPENVINO_ARCH_X86_64) diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp index d7933e2180fe6f..a90caaa8a8cb9f 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/rope.hpp @@ -43,6 +43,7 @@ struct rope : public primitive_base { seed = hash_combine(seed, config.input_trans0213); seed = hash_combine(seed, config.is_chatglm); seed = hash_combine(seed, config.support_2d_rope); + seed = hash_combine(seed, config.output_trans0213); seed = hash_combine(seed, config.is_interleaved); seed = hash_combine(seed, config.is_qwen); seed = hash_combine(seed, config.rotary_ndims); @@ -64,6 +65,7 @@ struct rope : public primitive_base { config.input_trans0213 == rhs_casted.config.input_trans0213 && config.is_chatglm == rhs_casted.config.is_chatglm && config.support_2d_rope == rhs_casted.config.support_2d_rope && + config.output_trans0213 == rhs_casted.config.output_trans0213 && config.is_interleaved == 
rhs_casted.config.is_interleaved && config.is_qwen == rhs_casted.config.is_qwen && config.rotary_ndims == rhs_casted.config.rotary_ndims && @@ -80,6 +82,7 @@ struct rope : public primitive_base { ob << config.input_trans0213; ob << config.is_chatglm; ob << config.support_2d_rope; + ob << config.output_trans0213; ob << config.is_interleaved; ob << config.is_qwen; ob << config.rotary_ndims; @@ -96,6 +99,7 @@ struct rope : public primitive_base { ib >> config.input_trans0213; ib >> config.is_chatglm; ib >> config.support_2d_rope; + ib >> config.output_trans0213; ib >> config.is_interleaved; ib >> config.is_qwen; ib >> config.rotary_ndims; diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp index 27ce085ab83c3f..d06e643c71ad18 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/rope.cpp @@ -45,7 +45,7 @@ struct rope_impl : typed_primitive_impl_ocl { params.slice_stop = primitive->config.slice_stop; params.axis = primitive->config.is_qwen || primitive->config.is_chatglm ? 2 : 3; - params.num_of_inputs = primitive->config.is_chatglm || primitive->config.is_interleaved ? 2 : 3; + params.num_of_inputs = primitive->config.is_chatglm || (primitive->config.output_trans0213 && primitive->config.is_interleaved) ? 
2 : 3; if (params.gather_rank > 0) { params.num_of_inputs++; @@ -53,6 +53,7 @@ struct rope_impl : typed_primitive_impl_ocl { params.is_qwen = primitive->config.is_qwen; params.is_chatglm = primitive->config.is_chatglm; + params.is_interleaved = primitive->config.is_interleaved; params.support_2d_rope = primitive->config.support_2d_rope; params.transposed_input = primitive->config.input_trans0213; diff --git a/src/plugins/intel_gpu/src/graph/rope.cpp b/src/plugins/intel_gpu/src/graph/rope.cpp index e168626f8d69a2..bef3f6dfcd93c0 100644 --- a/src/plugins/intel_gpu/src/graph/rope.cpp +++ b/src/plugins/intel_gpu/src/graph/rope.cpp @@ -54,7 +54,7 @@ std::vector rope_inst::calc_output_layouts(rope_node const& node, kernel output_shape[3] = input_slice_size; } - if (desc->config.input_trans0213 || desc->config.is_interleaved) { + if (desc->config.input_trans0213 || desc->config.output_trans0213) { std::swap(output_shape[2], output_shape[1]); } } @@ -77,6 +77,7 @@ std::string rope_inst::to_string(rope_node const& node) { rope_info.add("input_trans0213", desc->config.input_trans0213); rope_info.add("is_chatglm", desc->config.is_chatglm); rope_info.add("support_2d_rope", desc->config.support_2d_rope); + rope_info.add("output_trans0213", desc->config.output_trans0213); rope_info.add("is_interleaved", desc->config.is_interleaved); rope_info.add("is_qwen", desc->config.is_qwen); rope_info.add("rotary_ndims", desc->config.rotary_ndims); diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl index 133440a21301f2..d429916b46d69a 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl @@ -160,3 +160,42 @@ KERNEL(rope_ref)( sin[sin_idx + HALF_ROTARY_NDIMS + r] * in1; } #endif + +#ifdef RotateInterleaved +KERNEL(rope_ref)( + OPTIONAL_SHAPE_INFO_ARG + const __global INPUT0_TYPE* input, + const __global 
INPUT1_TYPE* cos, + const __global INPUT2_TYPE* sin, + __global OUTPUT_TYPE* output) +{ + const uint b = get_global_id(0); + const uint h = get_global_id(1); + const uint p = (uint)get_global_id(2) / HALF_ROTARY_NDIMS; + const uint r = 2 * ((uint)get_global_id(2) % HALF_ROTARY_NDIMS); + + uint input_idx = INPUT0_GET_INDEX(b, h, p, 0); + + uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0; + uint cos_sin_h = h < INPUT1_FEATURE_NUM ? h : 0; + uint cos_sin_p = p < INPUT1_SIZE_Y ? p : 0; + +#ifndef SIN_COS_HAVE_DYNAMIC_PADDINGS + uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0); + + uint cos_idx = cos_sin_idx; + uint sin_idx = cos_sin_idx; +#else + uint cos_idx = INPUT1_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0); + uint sin_idx = INPUT2_GET_INDEX(cos_sin_b, cos_sin_h, cos_sin_p, 0); +#endif + + uint output_idx = OUTPUT_GET_INDEX(b, h, p, 0); + + INPUT0_TYPE in1 = input[input_idx + r]; + INPUT0_TYPE in2 = input[input_idx + r + 1]; + + output[output_idx + r] = cos[cos_idx + r] * in1 - sin[sin_idx + r] * in2; + output[output_idx + r + 1] = cos[cos_idx + r + 1] * in2 + sin[sin_idx + r + 1] * in1; +} +#endif diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.cpp index 130c5a69d4262c..98212254be9e3c 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.cpp @@ -51,6 +51,8 @@ JitConstants RoPEKernelBase::GetJitConstants(const rope_params& params, RoPEKern jit.AddConstant(MakeJitConstant("SUPPORT_2D_ROPE", true)); } jit.AddConstant(MakeJitConstant("CHATGLM", true)); + } else if (params.is_interleaved) { + jit.AddConstant(MakeJitConstant("RotateInterleaved", true)); } else { jit.AddConstant(MakeJitConstant("RotateHalf", true)); } diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.h 
b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.h index 472131eba5d82f..8e95c12d9a78dd 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/rope/rope_kernel_base.h @@ -24,6 +24,7 @@ struct rope_params : public base_params { bool is_qwen = false; bool is_chatglm = false; + bool is_interleaved = false; bool support_2d_rope = false; bool transposed_input = false; }; diff --git a/src/plugins/intel_gpu/src/plugin/ops/rope.cpp b/src/plugins/intel_gpu/src/plugin/ops/rope.cpp index 321342b3395660..04eae769612bfc 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/rope.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/rope.cpp @@ -29,6 +29,8 @@ static void CreateRoPEOp(ProgramBuilder& p, const std::shared_ptrget_input_partial_shape(config.gather_position_arg_id).size(); } + OPENVINO_ASSERT(!config.is_interleaved || !config.output_trans0213, "[GPU] Unsupported ROPE parameters"); + auto rope = cldnn::rope(layer_type_name_ID(op), inputs, config, diff --git a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp index 741014b461e7f0..3f7fe91da86d93 100644 --- a/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp +++ b/src/plugins/intel_gpu/tests/functional/shared_tests_instances/subgraph_tests/rotary_pos_emb.cpp @@ -7,6 +7,11 @@ namespace ov { namespace test { +INSTANTIATE_TEST_SUITE_P(smoke_RoPETestFlux, + RoPETestFlux, + ::testing::Values(ov::test::utils::DEVICE_GPU), + RoPETestFlux::getTestCaseName); + INSTANTIATE_TEST_SUITE_P(smoke_RoPETestChatGLM, RoPETestChatGLMStridedSlice, ::testing::Values(ov::test::utils::DEVICE_GPU), diff --git a/src/tests/functional/plugin/shared/include/subgraph_tests/rotary_pos_emb.hpp 
b/src/tests/functional/plugin/shared/include/subgraph_tests/rotary_pos_emb.hpp index 7100ddca1083e3..c3f0b8ef0b6015 100644 --- a/src/tests/functional/plugin/shared/include/subgraph_tests/rotary_pos_emb.hpp +++ b/src/tests/functional/plugin/shared/include/subgraph_tests/rotary_pos_emb.hpp @@ -24,6 +24,13 @@ inline void CheckNumberOfNodesWithType(std::shared_ptr function ASSERT_EQ(num_ops, expectedCount); } +TEST_P(RoPETestFlux, CompareWithRefs) { + SKIP_IF_CURRENT_TEST_IS_DISABLED(); + run(); + auto function = compiledModel.get_runtime_model(); + CheckNumberOfNodesWithType(function, {"RoPE"}, 1); +}; + TEST_P(RoPETestLlama2StridedSlice, CompareWithRefs) { SKIP_IF_CURRENT_TEST_IS_DISABLED(); run(); diff --git a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/rotary_pos_emb.hpp b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/rotary_pos_emb.hpp index e1182bd3b16e13..39cdb871710e64 100644 --- a/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/rotary_pos_emb.hpp +++ b/src/tests/functional/shared_test_classes/include/shared_test_classes/subgraph/rotary_pos_emb.hpp @@ -9,6 +9,20 @@ namespace ov { namespace test { +class RoPETestFlux : public SubgraphBaseTest, public testing::WithParamInterface { +private: + std::shared_ptr build_rope_flux(int batch, + int seq_length, + int num_head, + int ndims); +protected: + void generate_inputs(const std::vector& targetInputStaticShapes) override; + void SetUp() override; + +public: + static std::string getTestCaseName(const testing::TestParamInfo& obj); +}; + class RoPETestLlama2StridedSlice : public SubgraphBaseTest, public testing::WithParamInterface { private: std::shared_ptr buildROPE_Llama2(int batch, diff --git a/src/tests/functional/shared_test_classes/src/subgraph/rotary_pos_emb.cpp b/src/tests/functional/shared_test_classes/src/subgraph/rotary_pos_emb.cpp index a1848903bb76a2..1a078d9b49ebb7 100644 --- 
a/src/tests/functional/shared_test_classes/src/subgraph/rotary_pos_emb.cpp +++ b/src/tests/functional/shared_test_classes/src/subgraph/rotary_pos_emb.cpp @@ -5,6 +5,7 @@ #include "shared_test_classes/subgraph/rotary_pos_emb.hpp" #include "common_test_utils/ov_tensor_utils.hpp" +#include "openvino/core/node_vector.hpp" #include "transformations/utils/gen_pattern.hpp" using namespace ov::gen_pattern; @@ -13,6 +14,85 @@ using namespace ov; namespace ov { namespace test { +std::shared_ptr RoPETestFlux::build_rope_flux(int batch, + int seq_length, + int num_head, + int ndims) { + auto x = std::make_shared(ov::element::f32, PartialShape{batch, num_head, seq_length, ndims}); + auto t_cos = std::make_shared(ov::element::f32, PartialShape{1, 1, seq_length, ndims}); + auto t_sin = std::make_shared(ov::element::f32, PartialShape{1, 1, seq_length, ndims}); + + auto x1_shape = makeConst(element::i64, ov::Shape({5}), {0, num_head, 0, -1, 2}); + auto x1 = std::make_shared(x, x1_shape, true); + + auto split_axis = makeConst(element::i64, ov::Shape(), {-1}); + auto split = std::make_shared(x1, split_axis, 2); + + auto minus_one = makeConst(element::f32, ov::Shape({}), {-1.0f}); + auto x1_1_neg = std::make_shared(split->output(1), minus_one); + + auto x2 = std::make_shared(OutputVector{x1_1_neg->output(0), split->output(0)}, -1); + + auto x3_shape = makeConst(element::i64, ov::Shape({4}), {0, num_head, 0, ndims}); + auto x3 = std::make_shared(x2, x3_shape, true); + + auto y1 = std::make_shared(x, t_cos); + auto y2 = std::make_shared(x3, t_sin); + auto y = std::make_shared(y1, y2); + + return std::make_shared(ov::NodeVector{y}, ov::ParameterVector{x, t_cos, t_sin}); +} + +void RoPETestFlux::generate_inputs(const std::vector& targetInputStaticShapes) { + const auto& funcInputs = function->inputs(); + + ov::test::utils::InputGenerateData in_data; + in_data.start_from = -1; + in_data.range = 2; + in_data.resolution = 32768; + + auto cos_data = in_data; + cos_data.seed = 10; + + auto 
sin_data = in_data; + sin_data.seed = 20; + + ov::Tensor t_input = utils::create_and_fill_tensor(funcInputs[0].get_element_type(), targetInputStaticShapes[0], in_data); + ov::Tensor t_cos = utils::create_and_fill_tensor(funcInputs[1].get_element_type(), targetInputStaticShapes[1], cos_data); + ov::Tensor t_sin = utils::create_and_fill_tensor(funcInputs[2].get_element_type(), targetInputStaticShapes[2], sin_data); + + inputs.clear(); + inputs.insert({funcInputs[0].get_node_shared_ptr(), t_input}); + inputs.insert({funcInputs[1].get_node_shared_ptr(), t_cos}); + inputs.insert({funcInputs[2].get_node_shared_ptr(), t_sin}); +} + +void RoPETestFlux::SetUp() { + targetDevice = this->GetParam(); + + const int batch = 1; + const int seq_length = 7; + const size_t max_position_embeddings = 2048; + const size_t ndims = 128; + const size_t num_head = 24; + + std::vector input_shapes = { + {{batch, num_head, seq_length, ndims}, {{batch, num_head, seq_length, ndims}}}, + {{1, 1, seq_length, ndims}, {{1, 1, seq_length, ndims}}}, + {{1, 1, seq_length, ndims}, {{1, 1, seq_length, ndims}}} + }; + init_input_shapes(input_shapes); + function = build_rope_flux(batch, -1, num_head, ndims); +} + +std::string RoPETestFlux::getTestCaseName(const testing::TestParamInfo& obj) { + std::string targetDevice = obj.param; + std::ostringstream result; + result << "targetDevice=" << targetDevice; + return result.str(); +} + + ov::OutputVector RoPETestLlama2StridedSlice::makeCosSinCache(int max_position_embeddings, int rotary_ndims) { std::vector lut_sin(max_position_embeddings * rotary_ndims, 0.0f); std::vector lut_cos(max_position_embeddings * rotary_ndims, 0.0f);