diff --git a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp index 6b12f56215ca83..4abc1a35841823 100644 --- a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -397,9 +397,11 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { auto varsplit = makePattern({gather_sin_cos, -1, {ndims / 2, -1}}); varsplit->set_output_size(2); // Reshape or UnSqueeze should both be support - auto unsqueeze_sin = makePattern({varsplit->output(0), {1, -1, 1, 32}}) | + auto dim0 = ov::gen_pattern::Symbol("dim0"); + auto dim1 = ov::gen_pattern::Symbol("dim1"); + auto unsqueeze_sin = makePattern({varsplit->output(0), {dim0, dim1, 1, 32}}) | makePattern({varsplit->output(0), 2}); - auto unsqueeze_cos = makePattern({varsplit->output(1), {1, -1, 1, 32}}) | + auto unsqueeze_cos = makePattern({varsplit->output(1), {dim0, dim1, 1, 32}}) | makePattern({varsplit->output(1), 2}); // repeate cos/sin table auto const_idx = makeConst(ov::element::i32, ov::PartialShape::dynamic(), [](const ov::op::v0::Constant& node) { @@ -419,10 +421,17 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { auto neg_Multiply_1177 = makePattern({slice_Slice_1174, -1.0f}, {{"auto_broadcast", "numpy"}}); auto Unsqueeze_65524 = makePattern({neg_Multiply_1177, -1}); + auto head_num = ov::gen_pattern::Symbol("head_num"); + auto Unsqueeze_28998 = + makePattern({neg_Multiply_1177, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}}); auto slice_Slice_1168 = GenSlice(slice_Slice_965 | varsplit_view_Reshape->output(0), 0, int32_max, 2, 3); auto Unsqueeze_65525 = makePattern({slice_Slice_1168, -1}); - auto stack_1182 = makePattern({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}}); + auto Unsqueeze_28999 = + makePattern({slice_Slice_1168, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}}); + auto stack_1182 = + makePattern({Unsqueeze_28998 | Unsqueeze_65524, Unsqueeze_65525 | Unsqueeze_28999}, + {{"axis", -1}}); auto ShapeOf_169068 = makePattern({stack_1182}); auto flatten_Slice_1194 = GenSlice(ShapeOf_169068, 0, 3, 1, 0); @@ -447,7 +456,7 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { makePattern({rotary_emb, slice_Slice_971 | varsplit_view_Reshape->output(1)}, {{"axis", -1}}); auto permute_Transpose_1213 = makePattern({cat_Concat_1211, {0, 2, 1, 3}}); - auto result = permute_Transpose_1213; + auto result = cat_Concat_1211 | permute_Transpose_1213; matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); @@ -461,7 +470,8 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { OutputVector new_args; config.rotary_ndims = static_cast(validator["ndims"]); - config.output_trans0213 = true; + if (pattern_map.count(permute_Transpose_1213)) + config.output_trans0213 = true; config.is_interleaved = true; // input is [B,L,H,S] @@ -478,14 +488,11 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { pattern_map.at(repeat_interleave_sin).get_node_shared_ptr(), pattern_map.at(repeat_interleave_cos).get_node_shared_ptr(), pattern_map.at(neg_Multiply_1177).get_node_shared_ptr(), - pattern_map.at(Unsqueeze_65524).get_node_shared_ptr(), - pattern_map.at(Unsqueeze_65525).get_node_shared_ptr(), pattern_map.at(stack_1182).get_node_shared_ptr(), pattern_map.at(mul_cos).get_node_shared_ptr(), pattern_map.at(mul_sin).get_node_shared_ptr(), pattern_map.at(rotary_emb).get_node_shared_ptr(), - pattern_map.at(cat_Concat_1211).get_node_shared_ptr(), - pattern_map.at(permute_Transpose_1213).get_node_shared_ptr()}, + pattern_map.at(cat_Concat_1211).get_node_shared_ptr()}, new_node); ov::replace_node(old_node, new_node); // shapeof may be moved up from transpose to add, @@ -705,6 +712,7 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) { auto rotary_emb_cos = makePattern("[1,?,1,?]"); // [1,..4096,1,128] auto rotary_emb_sin = makePattern("[1,?,1,?]"); // [1,..4096,1,128] auto qkv_proj = makePattern("[?,?,?]"); // [?,?,12288] + auto position_ids = makePattern(); auto head_cnt = ov::gen_pattern::Symbol("head_cnt"); auto head_size = ov::gen_pattern::Symbol("head_size"); @@ -731,14 +739,19 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) { auto ScatterUpdate_463814 = makePattern({{0, 0}, {1}, Gather_377635 | neg_Multiply, {0}}); auto slice_Slice_446 = makePattern({rotary_emb_cos, Gather_377635 | neg_Multiply, {INT_MAX}, {1}, {1}}); + + auto gather_cos_by_pos_ids = makePattern({rotary_emb_cos, position_ids, 1}, {{"batch_dims", 0}}); + auto reshape_cos_to_expected_layout = + makePattern({gather_cos_by_pos_ids, {-1, 1, 1, 128}}, {{"special_zero", false}}); + auto slice_StridedSlice_446 = GenStridedSlice(rotary_emb_cos, ScatterUpdate_463814, {0, INT_MAX}, {1, 1}, 1); // tensor_array - auto mul_Multiply_552 = - makePattern({slice_Slice_543, slice_StridedSlice_446 | slice_Slice_446}, - {{"auto_broadcast", "numpy"}}); // tensor_array + auto mul_Multiply_552 = makePattern( + {slice_Slice_543, slice_StridedSlice_446 | slice_Slice_446 | reshape_cos_to_expected_layout}, + {{"auto_broadcast", "numpy"}}); // tensor_array auto reshape_opt1 = [&](std::shared_ptr input_BLHS) { auto ShapeOf_485814 = makePattern({input_BLHS}, {}); @@ -772,8 +785,15 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) { makePattern({Multiply_567527, -2}); // tensor_array auto ListUnpack_586_Squeeze = makePattern({ListUnpack_586_Split->output(0), -2}); // tensor_array - auto cat_Concat_593 = makePattern({ListUnpack_586_Squeeze_0, ListUnpack_586_Squeeze}, - {{"axis", -1}}); // tensor_array + + auto ListUnpack_Squeeze_0_1 = + makePattern({Multiply_567527, {-1, 1, 32, 64}}, {{"special_zero", false}}); + auto ListUnpack_Squeeze_1 = + makePattern({ListUnpack_586_Split->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}}); + + auto cat_Concat_593 = makePattern( + {ListUnpack_586_Squeeze_0 | ListUnpack_Squeeze_0_1, ListUnpack_586_Squeeze | ListUnpack_Squeeze_1}, + {{"axis", -1}}); // tensor_array auto slice_StridedSlice_470 = GenStridedSlice(rotary_emb_sin, ScatterUpdate_463814, {0, INT_MAX}, @@ -781,9 +801,12 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) { 1); // tensor_array auto slice_Slice_470 = makePattern({rotary_emb_sin, Gather_377635 | neg_Multiply, {INT_MAX}, {1}, {1}}); - auto mul_Multiply_594 = - makePattern({cat_Concat_593, slice_StridedSlice_470 | slice_Slice_470}, - {{"auto_broadcast", "numpy"}}); // tensor_array + auto gather_sin_by_pos_ids = makePattern({rotary_emb_sin, position_ids, 1}, {{"batch_dims", 0}}); + auto reshape_sin_to_expected_layout = + makePattern({gather_sin_by_pos_ids, {-1, 1, 1, 128}}, {{"special_zero", false}}); + auto mul_Multiply_594 = makePattern( + {cat_Concat_593, slice_StridedSlice_470 | slice_Slice_470 | reshape_sin_to_expected_layout}, + {{"auto_broadcast", "numpy"}}); // tensor_array auto add_Add_597 = makePattern({mul_Multiply_552, mul_Multiply_594}, {{"auto_broadcast", "numpy"}}); // tensor_array @@ -844,8 +867,8 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) { auto new_node = std::make_shared(new_args, config); new_node->set_friendly_name(old_node->get_friendly_name()); ov::copy_runtime_info({pattern_map.at(Multiply_567527).get_node_shared_ptr(), - pattern_map.at(ListUnpack_586_Squeeze_0).get_node_shared_ptr(), - pattern_map.at(ListUnpack_586_Squeeze).get_node_shared_ptr(), + // pattern_map.at(ListUnpack_586_Squeeze_0).get_node_shared_ptr(), + // pattern_map.at(ListUnpack_586_Squeeze).get_node_shared_ptr(), pattern_map.at(cat_Concat_593).get_node_shared_ptr(), pattern_map.at(mul_Multiply_594).get_node_shared_ptr(), pattern_map.at(add_Add_597).get_node_shared_ptr()}, diff --git a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp index 1b34e0c4423d3d..96959fd292ee9e 100644 --- a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -14,6 +14,7 @@ #include "ov_ops/rotary_positional_embeddings.hpp" #include "ov_ops/type_relaxed.hpp" #include "transformations/utils/gen_pattern.hpp" +#include "transformations/utils/print_model.hpp" using namespace testing; using namespace ov::gen_pattern; @@ -1215,4 +1216,199 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGLM3_PagedAttention) { {"config.gather_position_arg_id", 0}}); model_ref = std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, gather_cos_sin}); } -} \ No newline at end of file +} + +TEST_F(TransformationTestsF, ConvertToROPE_GPTJ_PagedAttention) { + disable_rt_info_check(); + const int batch = -1; + const int num_heads = 16; + const int ndims = 256; + const int rotary_ndims = 64; + using namespace ov; + { + std::vector rpi_idx(rotary_ndims); + for (int i = 0, index = 0; i < rotary_ndims; i += 2, index++) { + rpi_idx[i] = index; + rpi_idx[i + 1] = index; + } + auto repeat_interleave_index = makeConst(ov::element::i32, ov::Shape({rotary_ndims}), rpi_idx); + + auto input = + std::make_shared(ov::element::f32, ov::PartialShape{batch, 1, num_heads, ndims}); + auto aten_gather_GatherElements = + std::make_shared(ov::element::f32, ov::PartialShape{-1, 1, rotary_ndims}); + + auto prim_ListUnpack_VariadicSplit = + makeOP({aten_gather_GatherElements, -1, {rotary_ndims / 2, -1}}); + auto aten_unsqueeze_Unsqueeze_1 = + makeOP({prim_ListUnpack_VariadicSplit->output(1), {-1, 1, 1, rotary_ndims / 2}}, + {{"special_zero", false}}); + auto aten_repeat_interleave_Gather_1 = + makeOP({aten_unsqueeze_Unsqueeze_1, repeat_interleave_index, 3}, {{"batch_dims", 0}}); + + auto aten_unsqueeze_Unsqueeze_2 = + makeOP({prim_ListUnpack_VariadicSplit->output(0), {-1, 1, 1, rotary_ndims / 2}}, + {{"special_zero", false}}); + auto aten_repeat_interleave_Gather_3 = + makeOP({aten_unsqueeze_Unsqueeze_2, repeat_interleave_index, 3}, {{"batch_dims", 0}}); + + auto VariadicSplit_32371 = makeOP({input, 3, {rotary_ndims, ndims - rotary_ndims}}); + auto aten_mul_Multiply = + makeOP({VariadicSplit_32371->output(0), aten_repeat_interleave_Gather_1}, + {{"auto_broadcast", "numpy"}}); + auto aten_slice_Slice_10 = makeOP({VariadicSplit_32371->output(0), {1}, {INT_MAX}, {2}, {3}}); + auto Constant_65243 = makeConst(element::f32, ov::Shape({1, 1, 1, 1}), {-1.000000f}); + auto aten_neg_Multiply = + makeOP({aten_slice_Slice_10, Constant_65243}, {{"auto_broadcast", "numpy"}}); + auto Unsqueeze_28998 = makeOP({aten_neg_Multiply, {-1, 1, num_heads, rotary_ndims / 2, 1}}, + {{"special_zero", false}}); + auto aten_slice_Slice_14 = makeOP({VariadicSplit_32371->output(0), {0}, {INT_MAX}, {2}, {3}}); + auto Unsqueeze_28999 = makeOP({aten_slice_Slice_14, {-1, 1, num_heads, rotary_ndims / 2, 1}}, + {{"special_zero", false}}); + auto aten_stack = makeOP({Unsqueeze_28998, Unsqueeze_28999}, {{"axis", -1}}); + auto aten_flatten_Reshape = + makeOP({aten_stack, {0, 0, num_heads, rotary_ndims}}, {{"special_zero", true}}); + auto aten_mul_Multiply_1 = makeOP({aten_flatten_Reshape, aten_repeat_interleave_Gather_3}, + {{"auto_broadcast", "numpy"}}); + auto aten_add_Add = + makeOP({aten_mul_Multiply, aten_mul_Multiply_1}, {{"auto_broadcast", "numpy"}}); + auto aten_cat_Concat_1 = makeOP({aten_add_Add, VariadicSplit_32371->output(1)}, {{"axis", -1}}); + + model = std::make_shared(ov::NodeVector{aten_cat_Concat_1}, + ov::ParameterVector{input, aten_gather_GatherElements}); + } + manager.register_pass(false); + { + auto input = + std::make_shared(ov::element::f32, ov::PartialShape{batch, 1, num_heads, ndims}); + auto aten_gather_GatherElements = + std::make_shared(ov::element::f32, ov::PartialShape{-1, 1, 64}); + auto rope = makeOP({input, aten_gather_GatherElements, aten_gather_GatherElements}, + {{"config.slice_start", 0}, + {"config.slice_stop", 0}, + {"config.input_trans0213", false}, + {"config.output_trans0213", false}, + {"config.is_interleaved", true}, + {"config.rotary_ndims", rotary_ndims}, + {"config.is_chatglm", false}, + {"config.support_2d_rope", false}, + {"config.is_qwen", false}, + {"config.head_cnt", 0}, + {"config.head_size", 0}, + {"config.gather_position_arg_id", 0}}); + model_ref = + std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, aten_gather_GatherElements}); + } +} + +TEST_F(TransformationTestsF, ConvertToROPE_Qwen_PagedAttention) { + using namespace ov; + + { + auto position_ids = std::make_shared(ov::element::i64, ov::PartialShape{-1, -1}); + auto qkv = std::make_shared(ov::element::f32, ov::PartialShape{-1, 1, 3 * 4096}); + + auto qkv_proj = makeOP({qkv, 2, {4096, 4096, -1}}); + + auto view_Reshape = makeOP({qkv_proj->output(0), {0, 0, 32, 128}}, {{"special_zero", true}}); + auto slice_Slice_4 = makeOP({view_Reshape, {0}, {128}, {1}, {3}}); + auto slice_Slice = makeConst(element::f32, + ov::Shape({ + 1, + 4096, + 1, + 128, + }), + {1}); + + auto Convert_50535 = makeOP({position_ids}, {{"destination_type", "i32"}}); + auto Unsqueeze_23750 = makeOP({Convert_50535, {-1, 1}}, {{"special_zero", false}}); + + auto slice_Slice_1 = makeOP({slice_Slice, Unsqueeze_23750, 1}, {{"batch_dims", 0}}); + auto Reshape_27400 = makeOP({slice_Slice_1, {-1, 1, 1, 128}}, {{"special_zero", false}}); + + auto mul_Multiply = makeOP({slice_Slice_4, Reshape_27400}, {{"auto_broadcast", "numpy"}}); + auto reshape_Reshape = makeOP({slice_Slice_4, {0, 0, 32, 2, 64}}, {{"special_zero", true}}); + auto ListUnpack_Split = makeOP({reshape_Reshape, -2}, {{"num_splits", 2}}); + auto Multiply_54136 = + makeOP({ListUnpack_Split->output(1), -1.000000f}, {{"auto_broadcast", "numpy"}}); + auto ListUnpack_Squeeze_0 = + makeOP({Multiply_54136, {-1, 1, 32, 64}}, {{"special_zero", false}}); + auto ListUnpack_Squeeze = + makeOP({ListUnpack_Split->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}}); + auto cat_Concat = makeOP({ListUnpack_Squeeze_0, ListUnpack_Squeeze}, {{"axis", -1}}); + + auto slice_Slice_2 = makeConst(element::f32, + ov::Shape({ + 1, + 4096, + 1, + 128, + }), + {1}); + auto slice_Slice_6 = makeOP({slice_Slice_2, Unsqueeze_23750, 1}, {{"batch_dims", 0}}); + auto Reshape_27408 = makeOP({slice_Slice_6, {-1, 1, 1, 128}}, {{"special_zero", false}}); + auto mul_Multiply_1 = makeOP({cat_Concat, Reshape_27408}, {{"auto_broadcast", "numpy"}}); + auto add_Add = makeOP({mul_Multiply, mul_Multiply_1}, {{"auto_broadcast", "numpy"}}); + + auto slice_Slice_10 = makeConst(element::f32, + ov::Shape({ + 1, + 32767, + 1, + 1, + }), + {1}); + auto view_Reshape_1 = makeOP({qkv_proj->output(1), {0, 0, 32, 128}}, {{"special_zero", true}}); + auto slice_Slice_11 = makeOP({view_Reshape_1, {0}, {128}, {1}, {3}}); + auto mul_Multiply_2 = makeOP({slice_Slice_11, Reshape_27400}, {{"auto_broadcast", "numpy"}}); + auto reshape_Reshape_1 = makeOP({slice_Slice_11, {0, 0, 32, 2, 64}}, {{"special_zero", true}}); + auto ListUnpack_Split_1 = makeOP({reshape_Reshape_1, -2}, {{"num_splits", 2}}); + auto Multiply_54139 = + makeOP({ListUnpack_Split_1->output(1), -1.000000f}, {{"auto_broadcast", "numpy"}}); + auto ListUnpack_Squeeze_0_1 = + makeOP({Multiply_54139, {-1, 1, 32, 64}}, {{"special_zero", false}}); + auto ListUnpack_Squeeze_1 = + makeOP({ListUnpack_Split_1->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}}); + auto cat_Concat_2 = makeOP({ListUnpack_Squeeze_0_1, ListUnpack_Squeeze_1}, {{"axis", -1}}); + auto mul_Multiply_3 = makeOP({cat_Concat_2, Reshape_27408}, {{"auto_broadcast", "numpy"}}); + auto add_Add_1 = makeOP({mul_Multiply_2, mul_Multiply_3}, {{"auto_broadcast", "numpy"}}); + model = std::make_shared(ov::NodeVector{add_Add_1}, ov::ParameterVector{position_ids, qkv}); + } + + manager.register_pass(false); + + { + auto input = std::make_shared(ov::element::f32, ov::PartialShape{-1, 1, 4096 * 3}); + auto rotary_emp_sin = makeConst(element::f32, + ov::Shape({ + 1, + 4096, + 1, + 128, + }), + {1}); + auto rotary_emp_cos = makeConst(element::f32, + ov::Shape({ + 1, + 4096, + 1, + 128, + }), + {1}); + auto rope = makeOP({input, rotary_emp_sin, rotary_emp_cos}, + {{"config.slice_start", 4096}, + {"config.slice_stop", 8192}, + {"config.input_trans0213", false}, + {"config.output_trans0213", false}, + {"config.is_interleaved", false}, + {"config.rotary_ndims", 128}, + {"config.is_chatglm", false}, + {"config.support_2d_rope", false}, + {"config.is_qwen", true}, + {"config.head_cnt", 32}, + {"config.head_size", 128}, + {"config.gather_position_arg_id", 0}}); + model_ref = std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input}); + } +}