Skip to content

Commit

Permalink
Add a unit test for the Qwen model
Browse files Browse the repository at this point in the history
  • Loading branch information
itikhono committed Jan 17, 2025
1 parent 5baf699 commit 36f1184
Showing 1 changed file with 114 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#include "ov_ops/rotary_positional_embeddings.hpp"
#include "ov_ops/type_relaxed.hpp"
#include "transformations/utils/gen_pattern.hpp"
#include "transformations/utils/print_model.hpp"

using namespace testing;
using namespace ov::gen_pattern;
Expand Down Expand Up @@ -1298,4 +1299,116 @@ TEST_F(TransformationTestsF, ConvertToROPE_GPTJ_PagedAttention) {
model_ref =
std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input, aten_gather_GatherElements});
}
}
}

// Checks that RoPEFusion replaces a Qwen-style rotary-embedding subgraph with a
// single internal RoPE node.
//
// The source model splits a fused `qkv` tensor into three parts along axis 2 and
// applies the rotary pattern to the second part only (VariadicSplit output 1,
// i.e. elements [4096, 8192) of the last axis). That is the slice the reference
// RoPE node is configured for via `config.slice_start` / `config.slice_stop`.
TEST_F(TransformationTestsF, ConvertToROPE_Qwen_PagedAttention) {
    using namespace ov;

    {
        auto position_ids = std::make_shared<opset1::Parameter>(ov::element::i64, ov::PartialShape{-1, -1});
        auto qkv = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, 1, 3 * 4096});

        // Split the fused tensor along axis 2 into 4096 / 4096 / remainder.
        auto qkv_proj = makeOP<opset1::VariadicSplit>({qkv, 2, {4096, 4096, -1}});

        // Position ids are converted to i32 and reshaped to {-1, 1} so they can
        // index both constant lookup tables below.
        auto Convert_50535 = makeOP<opset1::Convert>({position_ids}, {{"destination_type", "i32"}});
        auto Unsqueeze_23750 = makeOP<opset1::Reshape>({Convert_50535, {-1, 1}}, {{"special_zero", false}});

        // First lookup table: gathered along axis 1, then reshaped to
        // {-1, 1, 1, 128}; it scales the un-rotated input.
        auto slice_Slice = makeConst(element::f32, ov::Shape({1, 4096, 1, 128}), {1});
        auto slice_Slice_1 = makeOP<opset8::Gather>({slice_Slice, Unsqueeze_23750, 1}, {{"batch_dims", 0}});
        auto Reshape_27400 = makeOP<opset1::Reshape>({slice_Slice_1, {-1, 1, 1, 128}}, {{"special_zero", false}});

        // Second lookup table: same gather/reshape; it scales the rotated input.
        auto slice_Slice_2 = makeConst(element::f32, ov::Shape({1, 4096, 1, 128}), {1});
        auto slice_Slice_6 = makeOP<opset8::Gather>({slice_Slice_2, Unsqueeze_23750, 1}, {{"batch_dims", 0}});
        auto Reshape_27408 = makeOP<opset1::Reshape>({slice_Slice_6, {-1, 1, 1, 128}}, {{"special_zero", false}});

        // View the second split output as {.., .., 32, 128} and take the first
        // 128 elements of axis 3.
        auto view_Reshape_1 =
            makeOP<opset1::Reshape>({qkv_proj->output(1), {0, 0, 32, 128}}, {{"special_zero", true}});
        auto slice_Slice_11 = makeOP<opset8::Slice>({view_Reshape_1, {0}, {128}, {1}, {3}});
        auto mul_Multiply_2 =
            makeOP<opset1::Multiply>({slice_Slice_11, Reshape_27400}, {{"auto_broadcast", "numpy"}});

        // Rotate-half: split the 128-wide axis into two 64-wide halves, negate
        // the second half and concatenate it in front of the first.
        auto reshape_Reshape_1 =
            makeOP<opset1::Reshape>({slice_Slice_11, {0, 0, 32, 2, 64}}, {{"special_zero", true}});
        auto ListUnpack_Split_1 = makeOP<opset1::Split>({reshape_Reshape_1, -2}, {{"num_splits", 2}});
        auto Multiply_54139 =
            makeOP<opset1::Multiply>({ListUnpack_Split_1->output(1), -1.000000f}, {{"auto_broadcast", "numpy"}});
        auto ListUnpack_Squeeze_0_1 =
            makeOP<opset1::Reshape>({Multiply_54139, {-1, 1, 32, 64}}, {{"special_zero", false}});
        auto ListUnpack_Squeeze_1 =
            makeOP<opset1::Reshape>({ListUnpack_Split_1->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}});
        auto cat_Concat_2 = makeOP<opset1::Concat>({ListUnpack_Squeeze_0_1, ListUnpack_Squeeze_1}, {{"axis", -1}});

        auto mul_Multiply_3 = makeOP<opset1::Multiply>({cat_Concat_2, Reshape_27408}, {{"auto_broadcast", "numpy"}});
        auto add_Add_1 = makeOP<opset1::Add>({mul_Multiply_2, mul_Multiply_3}, {{"auto_broadcast", "numpy"}});

        model = std::make_shared<ov::Model>(ov::NodeVector{add_Add_1}, ov::ParameterVector{position_ids, qkv});
    }

    manager.register_pass<ov::pass::RoPEFusion>(false);

    {
        // NOTE(review): the source model above declares two parameters
        // (position_ids, qkv) while this reference declares only `input` and
        // feeds the constant tables to RoPE directly - confirm this is the
        // graph the pass is expected to produce.
        auto input = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, 1, 4096 * 3});
        auto rotary_emp_sin = makeConst(element::f32, ov::Shape({1, 4096, 1, 128}), {1});
        auto rotary_emp_cos = makeConst(element::f32, ov::Shape({1, 4096, 1, 128}), {1});
        auto rope = makeOP<ov::op::internal::RoPE>({input, rotary_emp_sin, rotary_emp_cos},
                                                   {{"config.slice_start", 4096},
                                                    {"config.slice_stop", 8192},
                                                    {"config.input_trans0213", false},
                                                    {"config.output_trans0213", false},
                                                    {"config.is_interleaved", false},
                                                    {"config.rotary_ndims", 128},
                                                    {"config.is_chatglm", false},
                                                    {"config.support_2d_rope", false},
                                                    {"config.is_qwen", true},
                                                    {"config.head_cnt", 32},
                                                    {"config.head_size", 128},
                                                    {"config.gather_position_arg_id", 0}});
        model_ref = std::make_shared<ov::Model>(ov::NodeVector{rope}, ov::ParameterVector{input});
    }
}

0 comments on commit 36f1184

Please sign in to comment.