From c8ef5bab9e290a255d6d1716dfd6f6c59ea0937a Mon Sep 17 00:00:00 2001 From: Tikhonov Ivan Date: Wed, 15 Jan 2025 11:38:32 +0400 Subject: [PATCH 1/5] RopeFusion transformation fix after PagedAttention transformation --- .../fuse_rotary_positional_embeddings.cpp | 47 +++-- .../fuse_rotary_positional_embeddings.cpp | 168 ++++++++++++++++++ 2 files changed, 200 insertions(+), 15 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp index c82853ec56e9ed..4a447690b24c90 100644 --- a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -397,9 +397,11 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { auto varsplit = makePattern({gather_sin_cos, -1, {ndims / 2, -1}}); varsplit->set_output_size(2); // Reshape or UnSqueeze should both be support - auto unsqueeze_sin = makePattern({varsplit->output(0), {1, -1, 1, 32}}) | + auto dim0 = ov::gen_pattern::Symbol("dim0"); + auto dim1 = ov::gen_pattern::Symbol("dim1"); + auto unsqueeze_sin = makePattern({varsplit->output(0), {dim0, dim1, 1, 32}}) | makePattern({varsplit->output(0), 2}); - auto unsqueeze_cos = makePattern({varsplit->output(1), {1, -1, 1, 32}}) | + auto unsqueeze_cos = makePattern({varsplit->output(1), {dim0, dim1, 1, 32}}) | makePattern({varsplit->output(1), 2}); // repeate cos/sin table auto const_idx = makeConst(ov::element::i32, ov::PartialShape::dynamic(), [](const ov::op::v0::Constant& node) { @@ -419,10 +421,17 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { auto neg_Multiply_1177 = makePattern({slice_Slice_1174, -1.0f}, {{"auto_broadcast", "numpy"}}); auto Unsqueeze_65524 = makePattern({neg_Multiply_1177, -1}); + auto head_num = ov::gen_pattern::Symbol("head_num"); + auto Unsqueeze_28998 = + makePattern({neg_Multiply_1177, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}}); auto slice_Slice_1168 = GenSlice(slice_Slice_965 | varsplit_view_Reshape->output(0), 0, int32_max, 2, 3); auto Unsqueeze_65525 = makePattern({slice_Slice_1168, -1}); - auto stack_1182 = makePattern({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}}); + auto Unsqueeze_28999 = + makePattern({slice_Slice_1168, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}}); + auto stack_1182 = + makePattern({Unsqueeze_28998 | Unsqueeze_65524, Unsqueeze_65525 | Unsqueeze_28999}, + {{"axis", -1}}); auto ShapeOf_169068 = makePattern({stack_1182}); auto flatten_Slice_1194 = GenSlice(ShapeOf_169068, 0, 3, 1, 0); @@ -447,7 +456,7 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { makePattern({rotary_emb, slice_Slice_971 | varsplit_view_Reshape->output(1)}, {{"axis", -1}}); auto permute_Transpose_1213 = makePattern({cat_Concat_1211, {0, 2, 1, 3}}); - auto result = permute_Transpose_1213; + auto result = cat_Concat_1211 | permute_Transpose_1213; matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); @@ -461,7 +470,8 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { OutputVector new_args; config.rotary_ndims = static_cast(validator["ndims"]); - config.output_trans0213 = true; + if (pattern_map.count(permute_Transpose_1213)) + config.output_trans0213 = true; config.is_interleaved = true; // input is [B,L,H,S] @@ -478,14 +488,11 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { pattern_map.at(repeat_interleave_sin).get_node_shared_ptr(), pattern_map.at(repeat_interleave_cos).get_node_shared_ptr(), pattern_map.at(neg_Multiply_1177).get_node_shared_ptr(), - pattern_map.at(Unsqueeze_65524).get_node_shared_ptr(), - pattern_map.at(Unsqueeze_65525).get_node_shared_ptr(), pattern_map.at(stack_1182).get_node_shared_ptr(), pattern_map.at(mul_cos).get_node_shared_ptr(), pattern_map.at(mul_sin).get_node_shared_ptr(), pattern_map.at(rotary_emb).get_node_shared_ptr(), - pattern_map.at(cat_Concat_1211).get_node_shared_ptr(), - pattern_map.at(permute_Transpose_1213).get_node_shared_ptr()}, + pattern_map.at(cat_Concat_1211).get_node_shared_ptr()}, new_node); ov::replace_node(old_node, new_node); // shapeof may be moved up from transpose to add, @@ -557,9 +564,11 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s } else { auto ListConstruct_452_Concat = makePattern({seq_length, {-1}, {head_cnt}, {ndims / 2}, {2}}, {{"axis", 0}}); + auto const_target_shape_0 = makeConst({0, 0, head_cnt, ndims / 2, 2}); auto const_target_shape_1 = makeConst({seq_len, batch, head_cnt, ndims / 2, 2}); - reshape_Reshape_453 = makePattern( - {slice_Slice_437 | var_split_1->output(0), ListConstruct_452_Concat | const_target_shape_1}); + reshape_Reshape_453 = + makePattern({slice_Slice_437 | var_split_1->output(0), + ListConstruct_452_Concat | const_target_shape_1 | const_target_shape_0}); } auto x_even = makePattern({reshape_Reshape_453, 0, -1}, {{"batch_dims", 0}}); @@ -588,6 +597,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s } else { auto ListConstruct_379_Concat = makePattern({seq_length, {-1}, {1}, {ndims / 2}, {2}}, {{"axis", 0}}); + auto const_target_shape_0 = makeConst({1, -1, 1, ndims / 2, 2}); auto const_target_shape_2 = makeConst({seq_len, batch, 1, ndims / 2, 2}); auto slice_Slice_449 = makePattern({cos_sin_cache, {0}, seq_length, {1}, {0}}); @@ -596,7 +606,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s // [seq_length, 1, batch, half_rotary_dims, 2] view_Reshape_460 = makePattern({slice_StridedSlice_449 | slice_Slice_449 | var_split_2->output(0), - ListConstruct_379_Concat | const_target_shape_2}, + ListConstruct_379_Concat | const_target_shape_0 | const_target_shape_2}, {{"special_zero", false}}); } @@ -609,12 +619,17 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s auto sub_Subtract_469 = makePattern({x_even_cos, neg_x_odd_sin}, {{"auto_broadcast", "numpy"}}); auto y_even = makePattern({sub_Subtract_469, -1}); + auto const_y_even_reshape = makeConst({1, -1, head_cnt, ndims / 2, 1}); + auto y_even_reshape = + makePattern({sub_Subtract_469, const_y_even_reshape}, {{"special_zero", false}}); auto x_odd_cos = makePattern({x_odd, cos_tab}, {{"auto_broadcast", "numpy"}}); auto x_even_sin = makePattern({x_even, sin_tab}, {{"auto_broadcast", "numpy"}}); auto add_Add_476 = makePattern({x_odd_cos, x_even_sin}, {{"auto_broadcast", "numpy"}}); auto y_odd = makePattern({add_Add_476, -1}); + auto const_y_odd_reshape = makeConst({1, -1, head_cnt, ndims / 2, 1}); + auto y_odd_reshape = makePattern({add_Add_476, const_y_odd_reshape}, {{"special_zero", false}}); - auto stack_481 = makePattern({y_even, y_odd}, {{"axis", -1}}); + auto stack_481 = makePattern({y_even | y_even_reshape, y_odd | y_odd_reshape}, {{"axis", -1}}); auto ShapeOf_135133 = makePattern({stack_481}); auto flatten_Slice_497 = GenSlice(ShapeOf_135133, 0, 3, 1, 0); @@ -629,9 +644,11 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s {{"special_zero", true}}); } else { // [length, batch, head_cnt, half_rotary_dims, 2] + auto const_target_shape_0 = makeConst({0, 0, head_cnt, ndims}); const_target_shape_3 = makeConst({seq_len, batch, head_cnt, ndims}); - flatten_Reshape_501 = makePattern({stack_481, flatten_Concat_500 | const_target_shape_3}, - {{"special_zero", true}}); + flatten_Reshape_501 = + makePattern({stack_481, flatten_Concat_500 | const_target_shape_0 | const_target_shape_3}, + {{"special_zero", true}}); } auto slice_Slice_443 = GenSlice(input_key, ndims, INT_MAX, 1, 3); diff --git a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp index 0328831ff1a69c..5813cfe79c23ff 100644 --- a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -1131,3 +1131,171 @@ TEST_F(TransformationTestsF, ConvertToROPE_Flux_mul_squeeze_unsqueeze) { } comparator.enable(FunctionsComparator::ATTRIBUTES); } + +TEST_F(TransformationTestsF, ConvertToROPE_chatGLM3_PagedAttention) { + disable_rt_info_check(); + const int batch = -1; + const int seq_len = 1; + const int num_heads = 32; + const int num_heads_kv = 2; + const int ndims = 128; + const int rotary_ndims = 64; + const int hidden_size = ndims * (num_heads + 2 * num_heads_kv); + const int hidden_size_q = ndims * num_heads; + const int hidden_size_kv = ndims * num_heads_kv; + using namespace ov; + { + auto input = + std::make_shared(ov::element::f32, ov::PartialShape{seq_len, batch, hidden_size}); + auto cos_sin = std::make_shared(ov::element::f32, + ov::PartialShape{seq_len, batch, rotary_ndims / 2, 2}); + auto aten_slice_Slice_1 = makeOP({cos_sin, {0}, {1}, {1}, {0}}); + auto aten_view_Reshape = makeOP({aten_slice_Slice_1, {seq_len, batch, 1, rotary_ndims / 2, 2}}, + {{"special_zero", false}}); + auto aten_select_Gather_1 = makeOP({aten_view_Reshape, 0, -1}, {{"batch_dims", 0}}); + auto aten_select_Gather_3 = makeOP({aten_view_Reshape, 1, -1}, {{"batch_dims", 0}}); + + auto attn_prim_ListUnpack = + makeOP({input, -1, {hidden_size_q, hidden_size_kv, hidden_size_kv}}); + auto attn_aten_view_Reshape_2 = + makeOP({attn_prim_ListUnpack->output(0), {0, 0, num_heads, ndims}}, + {{"special_zero", true}}); + auto VariadicSplit_29663 = + makeOP({attn_aten_view_Reshape_2, 3, {rotary_ndims, ndims - rotary_ndims}}); + auto aten_reshape_Reshape_55 = + makeOP({VariadicSplit_29663->output(0), {0, 0, num_heads, rotary_ndims / 2, 2}}, + {{"special_zero", true}}); + auto aten_select_Gather_440 = makeOP({aten_reshape_Reshape_55, 0, -1}, {{"batch_dims", 0}}); + auto aten_mul_Multiply_276 = + makeOP({aten_select_Gather_440, aten_select_Gather_1}, {{"auto_broadcast", "numpy"}}); + auto aten_select_Gather_442 = makeOP({aten_reshape_Reshape_55, 1, -1}, {{"batch_dims", 0}}); + auto aten_mul_Multiply_277 = + makeOP({aten_select_Gather_442, aten_select_Gather_3}, {{"auto_broadcast", "numpy"}}); + auto Multiply_34833 = + makeOP({aten_mul_Multiply_277, -1.000000f}, {{"auto_broadcast", "numpy"}}); + auto aten_sub_Subtract_55 = + makeOP({aten_mul_Multiply_276, Multiply_34833}, {{"auto_broadcast", "numpy"}}); + auto Unsqueeze_62197 = makeOP({aten_sub_Subtract_55, {1, -1, num_heads, rotary_ndims / 2, 1}}, + {{"special_zero", false}}); + auto aten_mul_Multiply_278 = + makeOP({aten_select_Gather_442, aten_select_Gather_1}, {{"auto_broadcast", "numpy"}}); + auto aten_mul_Multiply_279 = + makeOP({aten_select_Gather_440, aten_select_Gather_3}, {{"auto_broadcast", "numpy"}}); + auto aten_add_Add_55 = + makeOP({aten_mul_Multiply_278, aten_mul_Multiply_279}, {{"auto_broadcast", "numpy"}}); + auto Unsqueeze_62198 = makeOP({aten_add_Add_55, {1, -1, num_heads, rotary_ndims / 2, 1}}, + {{"special_zero", false}}); + auto aten_stack_55 = makeOP({Unsqueeze_62197, Unsqueeze_62198}, {{"axis", -1}}); + auto aten_flatten_Reshape_55 = + makeOP({aten_stack_55, {0, 0, num_heads, rotary_ndims}}, {{"special_zero", true}}); + auto aten_cat_Concat_55 = + makeOP({aten_flatten_Reshape_55, VariadicSplit_29663->output(1)}, {{"axis", -1}}); + + model = std::make_shared(ov::NodeVector{aten_cat_Concat_55}, ov::ParameterVector{input, cos_sin}); + } + manager.register_pass(false); + { + auto input = + std::make_shared(ov::element::f32, ov::PartialShape{seq_len, batch, hidden_size}); + auto gather_cos_sin = + std::make_shared(ov::element::f32, + ov::PartialShape{seq_len, batch, rotary_ndims / 2, 2}); + auto rope = makeOP({input, gather_cos_sin, gather_cos_sin}, + {{"config.slice_start", 0}, + {"config.slice_stop", 4096}, + {"config.input_trans0213", false}, + {"config.output_trans0213", false}, + {"config.is_interleaved", false}, + {"config.rotary_ndims", rotary_ndims}, + {"config.is_chatglm", true}, + {"config.support_2d_rope", false}, + {"config.is_qwen", false}, + {"config.head_cnt", num_heads}, + {"config.head_size", ndims}, + {"config.gather_position_arg_id", 0}}); + model_ref = std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, gather_cos_sin}); + } +} + +TEST_F(TransformationTestsF, ConvertToROPE_GPTJ_PagedAttention) { + disable_rt_info_check(); + const int batch = -1; + const int num_heads = 16; + const int ndims = 256; + const int rotary_ndims = 64; + using namespace ov; + { + std::vector rpi_idx(rotary_ndims); + for (int i = 0, index = 0; i < rotary_ndims; i += 2, index++) { + rpi_idx[i] = index; + rpi_idx[i + 1] = index; + } + auto repeat_interleave_index = makeConst(ov::element::i32, ov::Shape({rotary_ndims}), rpi_idx); + + auto input = + std::make_shared(ov::element::f32, ov::PartialShape{batch, 1, num_heads, ndims}); + auto aten_gather_GatherElements = + std::make_shared(ov::element::f32, ov::PartialShape{-1, 1, rotary_ndims}); + + auto prim_ListUnpack_VariadicSplit = + makeOP({aten_gather_GatherElements, -1, {rotary_ndims / 2, -1}}); + auto aten_unsqueeze_Unsqueeze_1 = + makeOP({prim_ListUnpack_VariadicSplit->output(1), {-1, 1, 1, rotary_ndims / 2}}, + {{"special_zero", false}}); + auto aten_repeat_interleave_Gather_1 = + makeOP({aten_unsqueeze_Unsqueeze_1, repeat_interleave_index, 3}, {{"batch_dims", 0}}); + + auto aten_unsqueeze_Unsqueeze_2 = + makeOP({prim_ListUnpack_VariadicSplit->output(0), {-1, 1, 1, rotary_ndims / 2}}, + {{"special_zero", false}}); + auto aten_repeat_interleave_Gather_3 = + makeOP({aten_unsqueeze_Unsqueeze_2, repeat_interleave_index, 3}, {{"batch_dims", 0}}); + + auto VariadicSplit_32371 = makeOP({input, 3, {rotary_ndims, ndims - rotary_ndims}}); + auto aten_mul_Multiply = + makeOP({VariadicSplit_32371->output(0), aten_repeat_interleave_Gather_1}, + {{"auto_broadcast", "numpy"}}); + auto aten_slice_Slice_10 = makeOP({VariadicSplit_32371->output(0), {1}, {INT_MAX}, {2}, {3}}); + auto Constant_65243 = makeConst(element::f32, ov::Shape({1, 1, 1, 1}), {-1.000000f}); + auto aten_neg_Multiply = + makeOP({aten_slice_Slice_10, Constant_65243}, {{"auto_broadcast", "numpy"}}); + auto Unsqueeze_28998 = makeOP({aten_neg_Multiply, {-1, 1, num_heads, rotary_ndims / 2, 1}}, + {{"special_zero", false}}); + auto aten_slice_Slice_14 = makeOP({VariadicSplit_32371->output(0), {0}, {INT_MAX}, {2}, {3}}); + auto Unsqueeze_28999 = makeOP({aten_slice_Slice_14, {-1, 1, num_heads, rotary_ndims / 2, 1}}, + {{"special_zero", false}}); + auto aten_stack = makeOP({Unsqueeze_28998, Unsqueeze_28999}, {{"axis", -1}}); + auto aten_flatten_Reshape = + makeOP({aten_stack, {0, 0, num_heads, rotary_ndims}}, {{"special_zero", true}}); + auto aten_mul_Multiply_1 = makeOP({aten_flatten_Reshape, aten_repeat_interleave_Gather_3}, + {{"auto_broadcast", "numpy"}}); + auto aten_add_Add = + makeOP({aten_mul_Multiply, aten_mul_Multiply_1}, {{"auto_broadcast", "numpy"}}); + auto aten_cat_Concat_1 = makeOP({aten_add_Add, VariadicSplit_32371->output(1)}, {{"axis", -1}}); + + model = std::make_shared(ov::NodeVector{aten_cat_Concat_1}, + ov::ParameterVector{input, aten_gather_GatherElements}); + } + manager.register_pass(false); + { + auto input = + std::make_shared(ov::element::f32, ov::PartialShape{batch, 1, num_heads, ndims}); + auto aten_gather_GatherElements = + std::make_shared(ov::element::f32, ov::PartialShape{-1, 1, 64}); + auto rope = makeOP({input, aten_gather_GatherElements, aten_gather_GatherElements}, + {{"config.slice_start", 0}, + {"config.slice_stop", 0}, + {"config.input_trans0213", false}, + {"config.output_trans0213", false}, + {"config.is_interleaved", true}, + {"config.rotary_ndims", rotary_ndims}, + {"config.is_chatglm", false}, + {"config.support_2d_rope", false}, + {"config.is_qwen", false}, + {"config.head_cnt", 0}, + {"config.head_size", 0}, + {"config.gather_position_arg_id", 0}}); + model_ref = + std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, aten_gather_GatherElements}); + } +} \ No newline at end of file From 97c1d0fa29686de0ce6e29cd4e8606f91a87e0be Mon Sep 17 00:00:00 2001 From: Tikhonov Ivan Date: Wed, 15 Jan 2025 12:39:20 +0400 Subject: [PATCH 2/5] Revert changes for gpt-j --- .../fuse_rotary_positional_embeddings.cpp | 25 ++---- .../fuse_rotary_positional_embeddings.cpp | 83 ------------------- 2 files changed, 9 insertions(+), 99 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp index 4a447690b24c90..908aef0dffbdb6 100644 --- a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -397,11 +397,9 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { auto varsplit = makePattern({gather_sin_cos, -1, {ndims / 2, -1}}); varsplit->set_output_size(2); // Reshape or UnSqueeze should both be support - auto dim0 = ov::gen_pattern::Symbol("dim0"); - auto dim1 = ov::gen_pattern::Symbol("dim1"); - auto unsqueeze_sin = makePattern({varsplit->output(0), {dim0, dim1, 1, 32}}) | + auto unsqueeze_sin = makePattern({varsplit->output(0), {1, -1, 1, 32}}) | makePattern({varsplit->output(0), 2}); - auto unsqueeze_cos = makePattern({varsplit->output(1), {dim0, dim1, 1, 32}}) | + auto unsqueeze_cos = makePattern({varsplit->output(1), {1, -1, 1, 32}}) | makePattern({varsplit->output(1), 2}); // repeate cos/sin table auto const_idx = makeConst(ov::element::i32, ov::PartialShape::dynamic(), [](const ov::op::v0::Constant& node) { @@ -421,17 +419,10 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { auto neg_Multiply_1177 = makePattern({slice_Slice_1174, -1.0f}, {{"auto_broadcast", "numpy"}}); auto Unsqueeze_65524 = makePattern({neg_Multiply_1177, -1}); - auto head_num = ov::gen_pattern::Symbol("head_num"); - auto Unsqueeze_28998 = - makePattern({neg_Multiply_1177, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}}); auto slice_Slice_1168 = GenSlice(slice_Slice_965 | varsplit_view_Reshape->output(0), 0, int32_max, 2, 3); auto Unsqueeze_65525 = makePattern({slice_Slice_1168, -1}); - auto Unsqueeze_28999 = - makePattern({slice_Slice_1168, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}}); - auto stack_1182 = - makePattern({Unsqueeze_28998 | Unsqueeze_65524, Unsqueeze_65525 | Unsqueeze_28999}, - {{"axis", -1}}); + auto stack_1182 = makePattern({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}}); auto ShapeOf_169068 = makePattern({stack_1182}); auto flatten_Slice_1194 = GenSlice(ShapeOf_169068, 0, 3, 1, 0); @@ -456,7 +447,7 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { makePattern({rotary_emb, slice_Slice_971 | varsplit_view_Reshape->output(1)}, {{"axis", -1}}); auto permute_Transpose_1213 = makePattern({cat_Concat_1211, {0, 2, 1, 3}}); - auto result = cat_Concat_1211 | permute_Transpose_1213; + auto result = permute_Transpose_1213; matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); @@ -470,8 +461,7 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { OutputVector new_args; config.rotary_ndims = static_cast(validator["ndims"]); - if (pattern_map.count(permute_Transpose_1213)) - config.output_trans0213 = true; + config.output_trans0213 = true; config.is_interleaved = true; // input is [B,L,H,S] @@ -488,11 +478,14 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { pattern_map.at(repeat_interleave_sin).get_node_shared_ptr(), pattern_map.at(repeat_interleave_cos).get_node_shared_ptr(), pattern_map.at(neg_Multiply_1177).get_node_shared_ptr(), + pattern_map.at(Unsqueeze_65524).get_node_shared_ptr(), + pattern_map.at(Unsqueeze_65525).get_node_shared_ptr(), pattern_map.at(stack_1182).get_node_shared_ptr(), pattern_map.at(mul_cos).get_node_shared_ptr(), pattern_map.at(mul_sin).get_node_shared_ptr(), pattern_map.at(rotary_emb).get_node_shared_ptr(), - pattern_map.at(cat_Concat_1211).get_node_shared_ptr()}, + pattern_map.at(cat_Concat_1211).get_node_shared_ptr(), + pattern_map.at(permute_Transpose_1213).get_node_shared_ptr()}, new_node); ov::replace_node(old_node, new_node); // shapeof may be moved up from transpose to add, diff --git a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp index 5813cfe79c23ff..1b34e0c4423d3d 100644 --- a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -1215,87 +1215,4 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGLM3_PagedAttention) { {"config.gather_position_arg_id", 0}}); model_ref = std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, gather_cos_sin}); } -} - -TEST_F(TransformationTestsF, ConvertToROPE_GPTJ_PagedAttention) { - disable_rt_info_check(); - const int batch = -1; - const int num_heads = 16; - const int ndims = 256; - const int rotary_ndims = 64; - using namespace ov; - { - std::vector rpi_idx(rotary_ndims); - for (int i = 0, index = 0; i < rotary_ndims; i += 2, index++) { - rpi_idx[i] = index; - rpi_idx[i + 1] = index; - } - auto repeat_interleave_index = makeConst(ov::element::i32, ov::Shape({rotary_ndims}), rpi_idx); - - auto input = - std::make_shared(ov::element::f32, ov::PartialShape{batch, 1, num_heads, ndims}); - auto aten_gather_GatherElements = - std::make_shared(ov::element::f32, ov::PartialShape{-1, 1, rotary_ndims}); - - auto prim_ListUnpack_VariadicSplit = - makeOP({aten_gather_GatherElements, -1, {rotary_ndims / 2, -1}}); - auto aten_unsqueeze_Unsqueeze_1 = - makeOP({prim_ListUnpack_VariadicSplit->output(1), {-1, 1, 1, rotary_ndims / 2}}, - {{"special_zero", false}}); - auto aten_repeat_interleave_Gather_1 = - makeOP({aten_unsqueeze_Unsqueeze_1, repeat_interleave_index, 3}, {{"batch_dims", 0}}); - - auto aten_unsqueeze_Unsqueeze_2 = - makeOP({prim_ListUnpack_VariadicSplit->output(0), {-1, 1, 1, rotary_ndims / 2}}, - {{"special_zero", false}}); - auto aten_repeat_interleave_Gather_3 = - makeOP({aten_unsqueeze_Unsqueeze_2, repeat_interleave_index, 3}, {{"batch_dims", 0}}); - - auto VariadicSplit_32371 = makeOP({input, 3, {rotary_ndims, ndims - rotary_ndims}}); - auto aten_mul_Multiply = - makeOP({VariadicSplit_32371->output(0), aten_repeat_interleave_Gather_1}, - {{"auto_broadcast", "numpy"}}); - auto aten_slice_Slice_10 = makeOP({VariadicSplit_32371->output(0), {1}, {INT_MAX}, {2}, {3}}); - auto Constant_65243 = makeConst(element::f32, ov::Shape({1, 1, 1, 1}), {-1.000000f}); - auto aten_neg_Multiply = - makeOP({aten_slice_Slice_10, Constant_65243}, {{"auto_broadcast", "numpy"}}); - auto Unsqueeze_28998 = makeOP({aten_neg_Multiply, {-1, 1, num_heads, rotary_ndims / 2, 1}}, - {{"special_zero", false}}); - auto aten_slice_Slice_14 = makeOP({VariadicSplit_32371->output(0), {0}, {INT_MAX}, {2}, {3}}); - auto Unsqueeze_28999 = makeOP({aten_slice_Slice_14, {-1, 1, num_heads, rotary_ndims / 2, 1}}, - {{"special_zero", false}}); - auto aten_stack = makeOP({Unsqueeze_28998, Unsqueeze_28999}, {{"axis", -1}}); - auto aten_flatten_Reshape = - makeOP({aten_stack, {0, 0, num_heads, rotary_ndims}}, {{"special_zero", true}}); - auto aten_mul_Multiply_1 = makeOP({aten_flatten_Reshape, aten_repeat_interleave_Gather_3}, - {{"auto_broadcast", "numpy"}}); - auto aten_add_Add = - makeOP({aten_mul_Multiply, aten_mul_Multiply_1}, {{"auto_broadcast", "numpy"}}); - auto aten_cat_Concat_1 = makeOP({aten_add_Add, VariadicSplit_32371->output(1)}, {{"axis", -1}}); - - model = std::make_shared(ov::NodeVector{aten_cat_Concat_1}, - ov::ParameterVector{input, aten_gather_GatherElements}); - } - manager.register_pass(false); - { - auto input = - std::make_shared(ov::element::f32, ov::PartialShape{batch, 1, num_heads, ndims}); - auto aten_gather_GatherElements = - std::make_shared(ov::element::f32, ov::PartialShape{-1, 1, 64}); - auto rope = makeOP({input, aten_gather_GatherElements, aten_gather_GatherElements}, - {{"config.slice_start", 0}, - {"config.slice_stop", 0}, - {"config.input_trans0213", false}, - {"config.output_trans0213", false}, - {"config.is_interleaved", true}, - {"config.rotary_ndims", rotary_ndims}, - {"config.is_chatglm", false}, - {"config.support_2d_rope", false}, - {"config.is_qwen", false}, - {"config.head_cnt", 0}, - {"config.head_size", 0}, - {"config.gather_position_arg_id", 0}}); - model_ref = - std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, aten_gather_GatherElements}); - } } \ No newline at end of file From 3e8a6b9c14abf38990a6b1748356f3359713fe9d Mon Sep 17 00:00:00 2001 From: Tikhonov Ivan Date: Thu, 16 Jan 2025 08:20:42 +0400 Subject: [PATCH 3/5] fix chatglm rope pattern --- .../common_optimizations/fuse_rotary_positional_embeddings.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp index 908aef0dffbdb6..6b12f56215ca83 100644 --- a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -640,7 +640,7 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s auto const_target_shape_0 = makeConst({0, 0, head_cnt, ndims}); const_target_shape_3 = makeConst({seq_len, batch, head_cnt, ndims}); flatten_Reshape_501 = - makePattern({stack_481, flatten_Concat_500 | const_target_shape_0 | const_target_shape_3}, + makePattern({stack_481, flatten_Concat_500 | const_target_shape_3 | const_target_shape_0}, {{"special_zero", true}}); } auto slice_Slice_443 = GenSlice(input_key, ndims, INT_MAX, 1, 3); From 5baf699ad6e1d850f5337bd530d5ce86b2482d54 Mon Sep 17 00:00:00 2001 From: Tikhonov Ivan Date: Fri, 17 Jan 2025 14:16:16 +0400 Subject: [PATCH 4/5] Update RopeFusion for qwen, gpt-j models --- .../fuse_rotary_positional_embeddings.cpp | 61 +++++++++----- .../fuse_rotary_positional_embeddings.cpp | 83 +++++++++++++++++++ 2 files changed, 125 insertions(+), 19 deletions(-) diff --git a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp index 6b12f56215ca83..4abc1a35841823 100644 --- a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -397,9 +397,11 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { auto varsplit = makePattern({gather_sin_cos, -1, {ndims / 2, -1}}); varsplit->set_output_size(2); // Reshape or UnSqueeze should both be support - auto unsqueeze_sin = makePattern({varsplit->output(0), {1, -1, 1, 32}}) | + auto dim0 = ov::gen_pattern::Symbol("dim0"); + auto dim1 = ov::gen_pattern::Symbol("dim1"); + auto unsqueeze_sin = makePattern({varsplit->output(0), {dim0, dim1, 1, 32}}) | makePattern({varsplit->output(0), 2}); - auto unsqueeze_cos = makePattern({varsplit->output(1), {1, -1, 1, 32}}) | + auto unsqueeze_cos = makePattern({varsplit->output(1), {dim0, dim1, 1, 32}}) | makePattern({varsplit->output(1), 2}); // repeate cos/sin table auto const_idx = makeConst(ov::element::i32, ov::PartialShape::dynamic(), [](const ov::op::v0::Constant& node) { @@ -419,10 +421,17 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { auto neg_Multiply_1177 = makePattern({slice_Slice_1174, -1.0f}, {{"auto_broadcast", "numpy"}}); auto Unsqueeze_65524 = makePattern({neg_Multiply_1177, -1}); + auto head_num = ov::gen_pattern::Symbol("head_num"); + auto Unsqueeze_28998 = + makePattern({neg_Multiply_1177, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}}); auto slice_Slice_1168 = GenSlice(slice_Slice_965 | varsplit_view_Reshape->output(0), 0, int32_max, 2, 3); auto Unsqueeze_65525 = makePattern({slice_Slice_1168, -1}); - auto stack_1182 = makePattern({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}}); + auto Unsqueeze_28999 = + makePattern({slice_Slice_1168, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}}); + auto stack_1182 = + makePattern({Unsqueeze_28998 | Unsqueeze_65524, Unsqueeze_65525 | Unsqueeze_28999}, + {{"axis", -1}}); auto ShapeOf_169068 = makePattern({stack_1182}); auto flatten_Slice_1194 = GenSlice(ShapeOf_169068, 0, 3, 1, 0); @@ -447,7 +456,7 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { makePattern({rotary_emb, slice_Slice_971 | varsplit_view_Reshape->output(1)}, {{"axis", -1}}); auto permute_Transpose_1213 = makePattern({cat_Concat_1211, {0, 2, 1, 3}}); - auto result = permute_Transpose_1213; + auto result = cat_Concat_1211 | permute_Transpose_1213; matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) { const auto& pattern_map = m.get_pattern_value_map(); @@ -461,7 +470,8 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { OutputVector new_args; config.rotary_ndims = static_cast(validator["ndims"]); - config.output_trans0213 = true; + if (pattern_map.count(permute_Transpose_1213)) + config.output_trans0213 = true; config.is_interleaved = true; // input is [B,L,H,S] @@ -478,14 +488,11 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() { pattern_map.at(repeat_interleave_sin).get_node_shared_ptr(), pattern_map.at(repeat_interleave_cos).get_node_shared_ptr(), pattern_map.at(neg_Multiply_1177).get_node_shared_ptr(), - pattern_map.at(Unsqueeze_65524).get_node_shared_ptr(), - pattern_map.at(Unsqueeze_65525).get_node_shared_ptr(), pattern_map.at(stack_1182).get_node_shared_ptr(), pattern_map.at(mul_cos).get_node_shared_ptr(), pattern_map.at(mul_sin).get_node_shared_ptr(), pattern_map.at(rotary_emb).get_node_shared_ptr(), - pattern_map.at(cat_Concat_1211).get_node_shared_ptr(), - pattern_map.at(permute_Transpose_1213).get_node_shared_ptr()}, + pattern_map.at(cat_Concat_1211).get_node_shared_ptr()}, new_node); ov::replace_node(old_node, new_node); // shapeof may be moved up from transpose to add, @@ -705,6 +712,7 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) { auto rotary_emb_cos = makePattern("[1,?,1,?]"); // [1,..4096,1,128] auto rotary_emb_sin = makePattern("[1,?,1,?]"); // [1,..4096,1,128] auto qkv_proj = makePattern("[?,?,?]"); // [?,?,12288] + auto position_ids = makePattern(); auto head_cnt = ov::gen_pattern::Symbol("head_cnt"); auto head_size = ov::gen_pattern::Symbol("head_size"); @@ -731,14 +739,19 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) { auto ScatterUpdate_463814 = makePattern({{0, 0}, {1}, Gather_377635 | neg_Multiply, {0}}); auto slice_Slice_446 = makePattern({rotary_emb_cos, Gather_377635 | neg_Multiply, {INT_MAX}, {1}, {1}}); + + auto gather_cos_by_pos_ids = makePattern({rotary_emb_cos, position_ids, 1}, {{"batch_dims", 0}}); + auto reshape_cos_to_expected_layout = + makePattern({gather_cos_by_pos_ids, {-1, 1, 1, 128}}, {{"special_zero", false}}); + auto slice_StridedSlice_446 = GenStridedSlice(rotary_emb_cos, ScatterUpdate_463814, {0, INT_MAX}, {1, 1}, 1); // tensor_array - auto mul_Multiply_552 = - makePattern({slice_Slice_543, slice_StridedSlice_446 | slice_Slice_446}, - {{"auto_broadcast", "numpy"}}); // tensor_array + auto mul_Multiply_552 = makePattern( + {slice_Slice_543, slice_StridedSlice_446 | slice_Slice_446 | reshape_cos_to_expected_layout}, + {{"auto_broadcast", "numpy"}}); // tensor_array auto reshape_opt1 = [&](std::shared_ptr input_BLHS) { auto ShapeOf_485814 = makePattern({input_BLHS}, {}); @@ -772,8 +785,15 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) { makePattern({Multiply_567527, -2}); // tensor_array auto ListUnpack_586_Squeeze = makePattern({ListUnpack_586_Split->output(0), -2}); // tensor_array - auto cat_Concat_593 = makePattern({ListUnpack_586_Squeeze_0, ListUnpack_586_Squeeze}, - {{"axis", -1}}); // tensor_array + + auto ListUnpack_Squeeze_0_1 = + makePattern({Multiply_567527, {-1, 1, 32, 64}}, {{"special_zero", false}}); + auto ListUnpack_Squeeze_1 = + makePattern({ListUnpack_586_Split->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}}); + + auto cat_Concat_593 = makePattern( + {ListUnpack_586_Squeeze_0 | ListUnpack_Squeeze_0_1, ListUnpack_586_Squeeze | ListUnpack_Squeeze_1}, + {{"axis", -1}}); // tensor_array auto slice_StridedSlice_470 = GenStridedSlice(rotary_emb_sin, ScatterUpdate_463814, {0, INT_MAX}, @@ -781,9 +801,12 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) { 1); // tensor_array auto slice_Slice_470 = makePattern({rotary_emb_sin, Gather_377635 | neg_Multiply, {INT_MAX}, {1}, {1}}); - auto mul_Multiply_594 = - makePattern({cat_Concat_593, slice_StridedSlice_470 | slice_Slice_470}, - {{"auto_broadcast", "numpy"}}); // tensor_array + auto gather_sin_by_pos_ids = makePattern({rotary_emb_sin, position_ids, 1}, {{"batch_dims", 0}}); + auto reshape_sin_to_expected_layout = + makePattern({gather_sin_by_pos_ids, {-1, 1, 1, 128}}, {{"special_zero", false}}); + auto mul_Multiply_594 = makePattern( + {cat_Concat_593, slice_StridedSlice_470 | slice_Slice_470 | reshape_sin_to_expected_layout}, + {{"auto_broadcast", "numpy"}}); // tensor_array auto add_Add_597 = makePattern({mul_Multiply_552, mul_Multiply_594}, {{"auto_broadcast", "numpy"}}); // tensor_array @@ -844,8 +867,8 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) { auto new_node = std::make_shared(new_args, config); new_node->set_friendly_name(old_node->get_friendly_name()); ov::copy_runtime_info({pattern_map.at(Multiply_567527).get_node_shared_ptr(), - pattern_map.at(ListUnpack_586_Squeeze_0).get_node_shared_ptr(), - pattern_map.at(ListUnpack_586_Squeeze).get_node_shared_ptr(), + // pattern_map.at(ListUnpack_586_Squeeze_0).get_node_shared_ptr(), + // pattern_map.at(ListUnpack_586_Squeeze).get_node_shared_ptr(), pattern_map.at(cat_Concat_593).get_node_shared_ptr(), pattern_map.at(mul_Multiply_594).get_node_shared_ptr(), pattern_map.at(add_Add_597).get_node_shared_ptr()}, diff --git a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp index 1b34e0c4423d3d..5813cfe79c23ff 100644 --- a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -1215,4 +1215,87 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGLM3_PagedAttention) { {"config.gather_position_arg_id", 0}}); model_ref = std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, gather_cos_sin}); } +} + +TEST_F(TransformationTestsF, ConvertToROPE_GPTJ_PagedAttention) { + disable_rt_info_check(); + const int batch = -1; + const int num_heads = 16; + const int ndims = 256; + const int rotary_ndims = 64; + using namespace ov; + { + std::vector rpi_idx(rotary_ndims); + for (int i = 0, index = 0; i < rotary_ndims; i += 2, index++) { + rpi_idx[i] = index; + rpi_idx[i + 1] = index; + } + auto repeat_interleave_index = makeConst(ov::element::i32, ov::Shape({rotary_ndims}), rpi_idx); + + auto input = + std::make_shared(ov::element::f32, ov::PartialShape{batch, 1, num_heads, ndims}); + auto aten_gather_GatherElements = + std::make_shared(ov::element::f32, ov::PartialShape{-1, 1, rotary_ndims}); + + auto prim_ListUnpack_VariadicSplit = + makeOP({aten_gather_GatherElements, -1, {rotary_ndims / 2, -1}}); + auto aten_unsqueeze_Unsqueeze_1 = + makeOP({prim_ListUnpack_VariadicSplit->output(1), {-1, 1, 1, rotary_ndims / 2}}, + {{"special_zero", false}}); + auto aten_repeat_interleave_Gather_1 = + makeOP({aten_unsqueeze_Unsqueeze_1, repeat_interleave_index, 3}, {{"batch_dims", 0}}); + + auto aten_unsqueeze_Unsqueeze_2 = + makeOP({prim_ListUnpack_VariadicSplit->output(0), {-1, 1, 1, rotary_ndims / 2}}, + {{"special_zero", false}}); + auto aten_repeat_interleave_Gather_3 = + makeOP({aten_unsqueeze_Unsqueeze_2, repeat_interleave_index, 3}, {{"batch_dims", 0}}); + + auto VariadicSplit_32371 = makeOP({input, 3, {rotary_ndims, ndims - rotary_ndims}}); + auto aten_mul_Multiply = + makeOP({VariadicSplit_32371->output(0), aten_repeat_interleave_Gather_1}, + {{"auto_broadcast", "numpy"}}); + auto aten_slice_Slice_10 = makeOP({VariadicSplit_32371->output(0), {1}, {INT_MAX}, {2}, {3}}); + auto Constant_65243 = makeConst(element::f32, ov::Shape({1, 1, 1, 1}), {-1.000000f}); + auto aten_neg_Multiply = + makeOP({aten_slice_Slice_10, Constant_65243}, {{"auto_broadcast", "numpy"}}); + auto Unsqueeze_28998 = makeOP({aten_neg_Multiply, {-1, 1, num_heads, rotary_ndims / 2, 1}}, + {{"special_zero", false}}); + auto aten_slice_Slice_14 = makeOP({VariadicSplit_32371->output(0), {0}, {INT_MAX}, {2}, {3}}); + auto Unsqueeze_28999 = makeOP({aten_slice_Slice_14, {-1, 1, num_heads, rotary_ndims / 2, 1}}, + {{"special_zero", false}}); + auto aten_stack = makeOP({Unsqueeze_28998, Unsqueeze_28999}, {{"axis", -1}}); + auto aten_flatten_Reshape = + makeOP({aten_stack, {0, 0, num_heads, rotary_ndims}}, {{"special_zero", true}}); + auto aten_mul_Multiply_1 = makeOP({aten_flatten_Reshape, aten_repeat_interleave_Gather_3}, + {{"auto_broadcast", "numpy"}}); + auto aten_add_Add = + makeOP({aten_mul_Multiply, aten_mul_Multiply_1}, {{"auto_broadcast", "numpy"}}); + auto aten_cat_Concat_1 = makeOP({aten_add_Add, VariadicSplit_32371->output(1)}, {{"axis", -1}}); + + model = std::make_shared(ov::NodeVector{aten_cat_Concat_1}, + ov::ParameterVector{input, aten_gather_GatherElements}); + } + manager.register_pass(false); + { + auto input = + std::make_shared(ov::element::f32, ov::PartialShape{batch, 1, num_heads, ndims}); + auto aten_gather_GatherElements = + std::make_shared(ov::element::f32, ov::PartialShape{-1, 1, 64}); + auto rope = makeOP({input, aten_gather_GatherElements, aten_gather_GatherElements}, + {{"config.slice_start", 0}, + {"config.slice_stop", 0}, + {"config.input_trans0213", false}, + {"config.output_trans0213", false}, + {"config.is_interleaved", true}, + {"config.rotary_ndims", rotary_ndims}, + {"config.is_chatglm", false}, + {"config.support_2d_rope", false}, + {"config.is_qwen", false}, + {"config.head_cnt", 0}, + {"config.head_size", 0}, + {"config.gather_position_arg_id", 0}}); + model_ref = + std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, aten_gather_GatherElements}); + } } \ No newline at end of file From 36f1184cee73780e86a632d59f831fd19a508286 Mon Sep 17 00:00:00 2001 From: Tikhonov Ivan Date: Fri, 17 Jan 2025 21:34:46 +0400 Subject: [PATCH 5/5] add a unit for qwen model --- .../fuse_rotary_positional_embeddings.cpp | 115 +++++++++++++++++- 1 file changed, 114 insertions(+), 1 deletion(-) diff --git a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp index 5813cfe79c23ff..96959fd292ee9e 100644 --- a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -14,6 +14,7 @@ #include "ov_ops/rotary_positional_embeddings.hpp" #include "ov_ops/type_relaxed.hpp" #include "transformations/utils/gen_pattern.hpp" +#include "transformations/utils/print_model.hpp" using namespace testing; using namespace ov::gen_pattern; @@ -1298,4 +1299,116 @@ TEST_F(TransformationTestsF, ConvertToROPE_GPTJ_PagedAttention) { model_ref = std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input, aten_gather_GatherElements}); } -} \ No newline at end of file +} + +TEST_F(TransformationTestsF, ConvertToROPE_Qwen_PagedAttention) { + using namespace ov; + + { + auto position_ids = std::make_shared(ov::element::i64, ov::PartialShape{-1, -1}); + auto qkv = std::make_shared(ov::element::f32, ov::PartialShape{-1, 1, 3 * 4096}); + + auto qkv_proj = makeOP({qkv, 2, {4096, 4096, -1}}); + + auto view_Reshape = makeOP({qkv_proj->output(0), {0, 0, 32, 128}}, {{"special_zero", true}}); + auto slice_Slice_4 = makeOP({view_Reshape, {0}, {128}, {1}, {3}}); + auto slice_Slice = makeConst(element::f32, + ov::Shape({ + 1, + 4096, + 1, + 128, + }), + {1}); + + auto Convert_50535 = makeOP({position_ids}, {{"destination_type", "i32"}}); + auto Unsqueeze_23750 = makeOP({Convert_50535, {-1, 1}}, {{"special_zero", false}}); + + auto slice_Slice_1 = makeOP({slice_Slice, Unsqueeze_23750, 1}, {{"batch_dims", 0}}); + auto Reshape_27400 = makeOP({slice_Slice_1, {-1, 1, 1, 128}}, {{"special_zero", false}}); + + auto mul_Multiply = makeOP({slice_Slice_4, Reshape_27400}, {{"auto_broadcast", "numpy"}}); + auto reshape_Reshape = makeOP({slice_Slice_4, {0, 0, 32, 2, 64}}, {{"special_zero", true}}); + auto ListUnpack_Split = makeOP({reshape_Reshape, -2}, {{"num_splits", 2}}); + auto Multiply_54136 = + makeOP({ListUnpack_Split->output(1), -1.000000f}, {{"auto_broadcast", "numpy"}}); + auto ListUnpack_Squeeze_0 = + makeOP({Multiply_54136, {-1, 1, 32, 64}}, {{"special_zero", false}}); + auto ListUnpack_Squeeze = + makeOP({ListUnpack_Split->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}}); + auto cat_Concat = makeOP({ListUnpack_Squeeze_0, ListUnpack_Squeeze}, {{"axis", -1}}); + + auto slice_Slice_2 = makeConst(element::f32, + ov::Shape({ + 1, + 4096, + 1, + 128, + }), + {1}); + auto slice_Slice_6 = makeOP({slice_Slice_2, Unsqueeze_23750, 1}, {{"batch_dims", 0}}); + auto Reshape_27408 = makeOP({slice_Slice_6, {-1, 1, 1, 128}}, {{"special_zero", false}}); + auto mul_Multiply_1 = makeOP({cat_Concat, Reshape_27408}, {{"auto_broadcast", "numpy"}}); + auto add_Add = makeOP({mul_Multiply, mul_Multiply_1}, {{"auto_broadcast", "numpy"}}); + + auto slice_Slice_10 = makeConst(element::f32, + ov::Shape({ + 1, + 32767, + 1, + 1, + }), + {1}); + auto view_Reshape_1 = makeOP({qkv_proj->output(1), {0, 0, 32, 128}}, {{"special_zero", true}}); + auto slice_Slice_11 = makeOP({view_Reshape_1, {0}, {128}, {1}, {3}}); + auto mul_Multiply_2 = makeOP({slice_Slice_11, Reshape_27400}, {{"auto_broadcast", "numpy"}}); + auto reshape_Reshape_1 = makeOP({slice_Slice_11, {0, 0, 32, 2, 64}}, {{"special_zero", true}}); + auto ListUnpack_Split_1 = makeOP({reshape_Reshape_1, -2}, {{"num_splits", 2}}); + auto Multiply_54139 = + makeOP({ListUnpack_Split_1->output(1), -1.000000f}, {{"auto_broadcast", "numpy"}}); + auto ListUnpack_Squeeze_0_1 = + makeOP({Multiply_54139, {-1, 1, 32, 64}}, {{"special_zero", false}}); + auto ListUnpack_Squeeze_1 = + makeOP({ListUnpack_Split_1->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}}); + auto cat_Concat_2 = makeOP({ListUnpack_Squeeze_0_1, ListUnpack_Squeeze_1}, {{"axis", -1}}); + auto mul_Multiply_3 = makeOP({cat_Concat_2, Reshape_27408}, {{"auto_broadcast", "numpy"}}); + auto add_Add_1 = makeOP({mul_Multiply_2, mul_Multiply_3}, {{"auto_broadcast", "numpy"}}); + model = std::make_shared(ov::NodeVector{add_Add_1}, ov::ParameterVector{position_ids, qkv}); + } + + manager.register_pass(false); + + { + auto input = std::make_shared(ov::element::f32, ov::PartialShape{-1, 1, 4096 * 3}); + auto rotary_emp_sin = makeConst(element::f32, + ov::Shape({ + 1, + 4096, + 1, + 128, + }), + {1}); + auto rotary_emp_cos = makeConst(element::f32, + ov::Shape({ + 1, + 4096, + 1, + 128, + }), + {1}); + auto rope = makeOP({input, rotary_emp_sin, rotary_emp_cos}, + {{"config.slice_start", 4096}, + {"config.slice_stop", 8192}, + {"config.input_trans0213", false}, + {"config.output_trans0213", false}, + {"config.is_interleaved", false}, + {"config.rotary_ndims", 128}, + {"config.is_chatglm", false}, + {"config.support_2d_rope", false}, + {"config.is_qwen", true}, + {"config.head_cnt", 32}, + {"config.head_size", 128}, + {"config.gather_position_arg_id", 0}}); + model_ref = std::make_shared(ov::NodeVector{rope}, ov::ParameterVector{input}); + } +}