diff --git a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp
index 5813cfe79c23ff..96959fd292ee9e 100644
--- a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp
+++ b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp
@@ -14,6 +14,7 @@
 #include "ov_ops/rotary_positional_embeddings.hpp"
 #include "ov_ops/type_relaxed.hpp"
 #include "transformations/utils/gen_pattern.hpp"
+#include "transformations/utils/print_model.hpp"
 
 using namespace testing;
 using namespace ov::gen_pattern;
@@ -1298,4 +1299,116 @@ TEST_F(TransformationTestsF, ConvertToROPE_GPTJ_PagedAttention) {
         model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope},
                                                 ov::ParameterVector{input, aten_gather_GatherElements});
     }
-}
\ No newline at end of file
+}
+
+TEST_F(TransformationTestsF, ConvertToROPE_Qwen_PagedAttention) {
+    using namespace ov;
+
+    {
+        auto position_ids = std::make_shared<ov::opset1::Parameter>(ov::element::i64, ov::PartialShape{-1, -1});
+        auto qkv = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, 1, 3 * 4096});
+
+        auto qkv_proj = makeOP<opset1::VariadicSplit>({qkv, 2, {4096, 4096, -1}});
+
+        auto view_Reshape = makeOP<opset1::Reshape>({qkv_proj->output(0), {0, 0, 32, 128}}, {{"special_zero", true}});
+        auto slice_Slice_4 = makeOP<opset8::Slice>({view_Reshape, {0}, {128}, {1}, {3}});
+        auto slice_Slice = makeConst(element::f32,
+                                     ov::Shape({
+                                         1,
+                                         4096,
+                                         1,
+                                         128,
+                                     }),
+                                     {1});
+
+        auto Convert_50535 = makeOP<opset1::Convert>({position_ids}, {{"destination_type", "i32"}});
+        auto Unsqueeze_23750 = makeOP<opset1::Reshape>({Convert_50535, {-1, 1}}, {{"special_zero", false}});
+
+        auto slice_Slice_1 = makeOP<opset8::Gather>({slice_Slice, Unsqueeze_23750, 1}, {{"batch_dims", 0}});
+        auto Reshape_27400 = makeOP<opset1::Reshape>({slice_Slice_1, {-1, 1, 1, 128}}, {{"special_zero", false}});
+
+        auto mul_Multiply = makeOP<opset1::Multiply>({slice_Slice_4, Reshape_27400}, {{"auto_broadcast", "numpy"}});
+        auto reshape_Reshape = makeOP<opset1::Reshape>({slice_Slice_4, {0, 0, 32, 2, 64}}, {{"special_zero", true}});
+        auto ListUnpack_Split = makeOP<opset1::Split>({reshape_Reshape, -2}, {{"num_splits", 2}});
+        auto Multiply_54136 =
+            makeOP<opset1::Multiply>({ListUnpack_Split->output(1), -1.000000f}, {{"auto_broadcast", "numpy"}});
+        auto ListUnpack_Squeeze_0 =
+            makeOP<opset1::Reshape>({Multiply_54136, {-1, 1, 32, 64}}, {{"special_zero", false}});
+        auto ListUnpack_Squeeze =
+            makeOP<opset1::Reshape>({ListUnpack_Split->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}});
+        auto cat_Concat = makeOP<opset1::Concat>({ListUnpack_Squeeze_0, ListUnpack_Squeeze}, {{"axis", -1}});
+
+        auto slice_Slice_2 = makeConst(element::f32,
+                                       ov::Shape({
+                                           1,
+                                           4096,
+                                           1,
+                                           128,
+                                       }),
+                                       {1});
+        auto slice_Slice_6 = makeOP<opset8::Gather>({slice_Slice_2, Unsqueeze_23750, 1}, {{"batch_dims", 0}});
+        auto Reshape_27408 = makeOP<opset1::Reshape>({slice_Slice_6, {-1, 1, 1, 128}}, {{"special_zero", false}});
+        auto mul_Multiply_1 = makeOP<opset1::Multiply>({cat_Concat, Reshape_27408}, {{"auto_broadcast", "numpy"}});
+        auto add_Add = makeOP<opset1::Add>({mul_Multiply, mul_Multiply_1}, {{"auto_broadcast", "numpy"}});
+
+        auto slice_Slice_10 = makeConst(element::f32,
+                                        ov::Shape({
+                                            1,
+                                            32767,
+                                            1,
+                                            1,
+                                        }),
+                                        {1});
+        auto view_Reshape_1 = makeOP<opset1::Reshape>({qkv_proj->output(1), {0, 0, 32, 128}}, {{"special_zero", true}});
+        auto slice_Slice_11 = makeOP<opset8::Slice>({view_Reshape_1, {0}, {128}, {1}, {3}});
+        auto mul_Multiply_2 = makeOP<opset1::Multiply>({slice_Slice_11, Reshape_27400}, {{"auto_broadcast", "numpy"}});
+        auto reshape_Reshape_1 = makeOP<opset1::Reshape>({slice_Slice_11, {0, 0, 32, 2, 64}}, {{"special_zero", true}});
+        auto ListUnpack_Split_1 = makeOP<opset1::Split>({reshape_Reshape_1, -2}, {{"num_splits", 2}});
+        auto Multiply_54139 =
+            makeOP<opset1::Multiply>({ListUnpack_Split_1->output(1), -1.000000f}, {{"auto_broadcast", "numpy"}});
+        auto ListUnpack_Squeeze_0_1 =
+            makeOP<opset1::Reshape>({Multiply_54139, {-1, 1, 32, 64}}, {{"special_zero", false}});
+        auto ListUnpack_Squeeze_1 =
+            makeOP<opset1::Reshape>({ListUnpack_Split_1->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}});
+        auto cat_Concat_2 = makeOP<opset1::Concat>({ListUnpack_Squeeze_0_1, ListUnpack_Squeeze_1}, {{"axis", -1}});
+        auto mul_Multiply_3 = makeOP<opset1::Multiply>({cat_Concat_2, Reshape_27408}, {{"auto_broadcast", "numpy"}});
+        auto add_Add_1 = makeOP<opset1::Add>({mul_Multiply_2, mul_Multiply_3}, {{"auto_broadcast", "numpy"}});
+        model = std::make_shared<ov::Model>(ov::NodeVector{add_Add_1}, ov::ParameterVector{position_ids, qkv});
+    }
+
+    manager.register_pass<ov::pass::RoPEFusion>(false);
+
+    {
+        auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, 1, 4096 * 3});
+        auto rotary_emp_sin = makeConst(element::f32,
+                                        ov::Shape({
+                                            1,
+                                            4096,
+                                            1,
+                                            128,
+                                        }),
+                                        {1});
+        auto rotary_emp_cos = makeConst(element::f32,
+                                        ov::Shape({
+                                            1,
+                                            4096,
+                                            1,
+                                            128,
+                                        }),
+                                        {1});
+        auto rope = makeOP<ov::op::internal::RoPE>({input, rotary_emp_sin, rotary_emp_cos},
+                                                   {{"config.slice_start", 4096},
+                                                    {"config.slice_stop", 8192},
+                                                    {"config.input_trans0213", false},
+                                                    {"config.output_trans0213", false},
+                                                    {"config.is_interleaved", false},
+                                                    {"config.rotary_ndims", 128},
+                                                    {"config.is_chatglm", false},
+                                                    {"config.support_2d_rope", false},
+                                                    {"config.is_qwen", true},
+                                                    {"config.head_cnt", 32},
+                                                    {"config.head_size", 128},
+                                                    {"config.gather_position_arg_id", 0}});
+        model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input});
+    }
+}
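
Note (illustrative, not part of the patch): the reference model above encodes the fused Qwen key-branch RoPE purely through "config.*" attributes, where slice_start = 4096 and slice_stop = 8192 select the K chunk out of the fused [-1, 1, 3 * 4096] QKV tensor and head_cnt * head_size = 32 * 128 = 4096 channels per projection. The sketch below shows how the same node could be built directly through the op's Config struct; the helper name make_qwen_key_rope is hypothetical, and the RoPE(OutputVector, Config) constructor plus the Config field names are assumed to mirror the attribute keys used in the test.

// Illustrative sketch only, under the assumptions stated above.
#include <memory>

#include "openvino/core/node.hpp"
#include "ov_ops/rotary_positional_embeddings.hpp"

std::shared_ptr<ov::Node> make_qwen_key_rope(const ov::Output<ov::Node>& fused_qkv,
                                             const ov::Output<ov::Node>& sin_table,
                                             const ov::Output<ov::Node>& cos_table) {
    ov::op::internal::RoPE::Config config;
    config.is_qwen = true;
    config.slice_start = 4096;  // K occupies channels [4096, 8192) of the 3 * 4096 fused QKV tensor
    config.slice_stop = 8192;
    config.head_cnt = 32;       // 32 heads * 128 head_size = 4096 channels per projection
    config.head_size = 128;
    config.rotary_ndims = 128;  // the full head dimension is rotated
    return std::make_shared<ov::op::internal::RoPE>(ov::OutputVector{fused_qkv, sin_table, cos_table}, config);
}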