Update RopeFusion for qwen, gpt-j models to support new pattern SDPA to PA conversion #28512

Open · wants to merge 6 commits into base: master
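For context on the subgraphs below: RoPEFusionGPTJ matches the interleaved (GPT-J style) rotary embedding, where adjacent even/odd channels form a pair rotated by a position-dependent angle — that is what the slice-with-stride-2, negate, and concat-on-last-axis nodes in the pattern encode. A minimal sketch of the per-token computation (not code from this PR; names are illustrative):

```cpp
// Interleaved GPT-J-style RoPE for a single token and head.
// x has rotary_ndims elements; cos_tab/sin_tab hold one entry per
// even/odd channel pair (the repeat_interleave in the pattern expands them).
#include <cstddef>
#include <vector>

std::vector<float> rope_interleaved(const std::vector<float>& x,
                                    const std::vector<float>& cos_tab,
                                    const std::vector<float>& sin_tab) {
    std::vector<float> y(x.size());
    for (size_t i = 0; i + 1 < x.size(); i += 2) {
        const float c = cos_tab[i / 2];
        const float s = sin_tab[i / 2];
        // Each (even, odd) pair is rotated by the same angle; the
        // "-x[i+1] * s" term is what the neg_Multiply/Concat subgraph builds.
        y[i] = x[i] * c - x[i + 1] * s;
        y[i + 1] = x[i + 1] * c + x[i] * s;
    }
    return y;
}
```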
@@ -397,9 +397,11 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
     auto varsplit = makePattern<opset1::VariadicSplit>({gather_sin_cos, -1, {ndims / 2, -1}});
     varsplit->set_output_size(2);
     // Reshape or Unsqueeze should both be supported
-    auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {1, -1, 1, 32}}) |
+    auto dim0 = ov::gen_pattern::Symbol("dim0");
+    auto dim1 = ov::gen_pattern::Symbol("dim1");
+    auto unsqueeze_sin = makePattern<opset1::Reshape>({varsplit->output(0), {dim0, dim1, 1, 32}}) |
                          makePattern<opset1::Unsqueeze>({varsplit->output(0), 2});
-    auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {1, -1, 1, 32}}) |
+    auto unsqueeze_cos = makePattern<opset1::Reshape>({varsplit->output(1), {dim0, dim1, 1, 32}}) |
                          makePattern<opset1::Unsqueeze>({varsplit->output(1), 2});
     // repeat cos/sin table
     auto const_idx = makeConst(ov::element::i32, ov::PartialShape::dynamic(), [](const ov::op::v0::Constant& node) {
@@ -419,10 +421,17 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
 
     auto neg_Multiply_1177 = makePattern<opset1::Multiply>({slice_Slice_1174, -1.0f}, {{"auto_broadcast", "numpy"}});
     auto Unsqueeze_65524 = makePattern<opset1::Unsqueeze>({neg_Multiply_1177, -1});
+    auto head_num = ov::gen_pattern::Symbol("head_num");
+    auto Unsqueeze_28998 =
+        makePattern<opset1::Reshape>({neg_Multiply_1177, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}});
 
     auto slice_Slice_1168 = GenSlice(slice_Slice_965 | varsplit_view_Reshape->output(0), 0, int32_max, 2, 3);
     auto Unsqueeze_65525 = makePattern<opset1::Unsqueeze>({slice_Slice_1168, -1});
-    auto stack_1182 = makePattern<opset1::Concat>({Unsqueeze_65524, Unsqueeze_65525}, {{"axis", -1}});
+    auto Unsqueeze_28999 =
+        makePattern<opset1::Reshape>({slice_Slice_1168, {-1, 1, head_num, 32, 1}}, {{"special_zero", false}});
+    auto stack_1182 =
+        makePattern<opset1::Concat>({Unsqueeze_28998 | Unsqueeze_65524, Unsqueeze_65525 | Unsqueeze_28999},
+                                    {{"axis", -1}});
 
     auto ShapeOf_169068 = makePattern<opset1::ShapeOf>({stack_1182});
     auto flatten_Slice_1194 = GenSlice(ShapeOf_169068, 0, 3, 1, 0);
@@ -447,7 +456,7 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
         makePattern<opset1::Concat>({rotary_emb, slice_Slice_971 | varsplit_view_Reshape->output(1)}, {{"axis", -1}});
     auto permute_Transpose_1213 = makePattern<opset1::Transpose>({cat_Concat_1211, {0, 2, 1, 3}});
 
-    auto result = permute_Transpose_1213;
+    auto result = cat_Concat_1211 | permute_Transpose_1213;
 
     matcher_pass_callback callback = [=](ov::pass::pattern::Matcher& m) {
         const auto& pattern_map = m.get_pattern_value_map();
@@ -461,7 +470,8 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
         OutputVector new_args;
         config.rotary_ndims = static_cast<size_t>(validator["ndims"]);
 
-        config.output_trans0213 = true;
+        if (pattern_map.count(permute_Transpose_1213))
+            config.output_trans0213 = true;
         config.is_interleaved = true;
 
         // input is [B,L,H,S]
@@ -478,14 +488,11 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
                               pattern_map.at(repeat_interleave_sin).get_node_shared_ptr(),
                               pattern_map.at(repeat_interleave_cos).get_node_shared_ptr(),
                               pattern_map.at(neg_Multiply_1177).get_node_shared_ptr(),
-                              pattern_map.at(Unsqueeze_65524).get_node_shared_ptr(),
-                              pattern_map.at(Unsqueeze_65525).get_node_shared_ptr(),
                               pattern_map.at(stack_1182).get_node_shared_ptr(),
                               pattern_map.at(mul_cos).get_node_shared_ptr(),
                               pattern_map.at(mul_sin).get_node_shared_ptr(),
                               pattern_map.at(rotary_emb).get_node_shared_ptr(),
-                              pattern_map.at(cat_Concat_1211).get_node_shared_ptr(),
-                              pattern_map.at(permute_Transpose_1213).get_node_shared_ptr()},
+                              pattern_map.at(cat_Concat_1211).get_node_shared_ptr()},
                              new_node);
         ov::replace_node(old_node, new_node);
         // shapeof may be moved up from transpose to add,
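The callback change above is the core of the GPT-J update: output_trans0213 now reflects whether the matched graph actually ended in the {0, 2, 1, 3} Transpose. As a rough illustration (hypothetical helper, not part of the PR), the flag corresponds to the permutation below from [B,L,H,S] to [B,H,L,S]; when the new SDPA-to-PA flavor matches, the Transpose is absent, the flag stays false, and the fused op's output remains [B,L,H,S]:

```cpp
// Hypothetical reference for the output_trans0213 flag: permute a dense
// [B,L,H,S] tensor to [B,H,L,S]. Illustration only, not PR code.
#include <cstddef>
#include <vector>

std::vector<float> transpose_0213(const std::vector<float>& x,
                                  size_t B, size_t L, size_t H, size_t S) {
    std::vector<float> y(x.size());
    for (size_t b = 0; b < B; ++b)
        for (size_t l = 0; l < L; ++l)
            for (size_t h = 0; h < H; ++h)
                for (size_t s = 0; s < S; ++s)
                    // y[b][h][l][s] = x[b][l][h][s]
                    y[((b * H + h) * L + l) * S + s] =
                        x[((b * L + l) * H + h) * S + s];
    return y;
}
```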
@@ -705,6 +712,7 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
     auto rotary_emb_cos = makePattern("[1,?,1,?]"); // [1,..4096,1,128]
     auto rotary_emb_sin = makePattern("[1,?,1,?]"); // [1,..4096,1,128]
     auto qkv_proj = makePattern("[?,?,?]"); // [?,?,12288]
+    auto position_ids = makePattern();
 
     auto head_cnt = ov::gen_pattern::Symbol("head_cnt");
     auto head_size = ov::gen_pattern::Symbol("head_size");
@@ -731,14 +739,19 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
     auto ScatterUpdate_463814 = makePattern<opset3::ScatterUpdate>({{0, 0}, {1}, Gather_377635 | neg_Multiply, {0}});
     auto slice_Slice_446 =
         makePattern<ov::opset8::Slice>({rotary_emb_cos, Gather_377635 | neg_Multiply, {INT_MAX}, {1}, {1}});
+
+    auto gather_cos_by_pos_ids = makePattern<opset8::Gather>({rotary_emb_cos, position_ids, 1}, {{"batch_dims", 0}});
+    auto reshape_cos_to_expected_layout =
+        makePattern<opset8::Reshape>({gather_cos_by_pos_ids, {-1, 1, 1, 128}}, {{"special_zero", false}});
+
     auto slice_StridedSlice_446 = GenStridedSlice(rotary_emb_cos,
                                                   ScatterUpdate_463814,
                                                   {0, INT_MAX},
                                                   {1, 1},
                                                   1); // tensor_array<f32[1,..4096,1,128]>
-    auto mul_Multiply_552 =
-        makePattern<opset1::Multiply>({slice_Slice_543, slice_StridedSlice_446 | slice_Slice_446},
-                                      {{"auto_broadcast", "numpy"}}); // tensor_array<f32[?,?,32,128]>
+    auto mul_Multiply_552 = makePattern<opset1::Multiply>(
+        {slice_Slice_543, slice_StridedSlice_446 | slice_Slice_446 | reshape_cos_to_expected_layout},
+        {{"auto_broadcast", "numpy"}}); // tensor_array<f32[?,?,32,128]>
 
     auto reshape_opt1 = [&](std::shared_ptr<Node> input_BLHS) {
         auto ShapeOf_485814 = makePattern<opset1::ShapeOf>({input_BLHS}, {});
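The new alternative above is the Qwen-side analogue of the SDPA-to-PA change: instead of a StridedSlice that keeps the trailing positions of the cos table, the converted graph selects rows explicitly with a Gather over position_ids and reshapes them to [-1, 1, 1, 128]. A small sketch of the selection semantics (illustrative helper, not from the PR), with the table flattened to [max_pos, head_size] row-major:

```cpp
// Illustration only: pick one cos/sin row per token by position id, as the
// Gather(axis=1, batch_dims=0) + Reshape alternative in the pattern does.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> gather_rows(const std::vector<float>& table,  // [max_pos * head_size]
                               size_t head_size,
                               const std::vector<int64_t>& position_ids) {
    std::vector<float> out;
    out.reserve(position_ids.size() * head_size);
    for (int64_t p : position_ids) {
        const float* row = table.data() + static_cast<size_t>(p) * head_size;
        out.insert(out.end(), row, row + head_size);  // one row per token
    }
    return out;
}
```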
@@ -772,18 +785,28 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
         makePattern<opset1::Squeeze>({Multiply_567527, -2}); // tensor_array<f32[?,?,32,64]>
     auto ListUnpack_586_Squeeze =
         makePattern<opset1::Squeeze>({ListUnpack_586_Split->output(0), -2}); // tensor_array<f32[?,?,32,64]>
-    auto cat_Concat_593 = makePattern<opset1::Concat>({ListUnpack_586_Squeeze_0, ListUnpack_586_Squeeze},
-                                                      {{"axis", -1}}); // tensor_array<f32[?,?,32,128]>
+
+    auto ListUnpack_Squeeze_0_1 =
+        makePattern<opset1::Reshape>({Multiply_567527, {-1, 1, 32, 64}}, {{"special_zero", false}});
+    auto ListUnpack_Squeeze_1 =
+        makePattern<opset1::Reshape>({ListUnpack_586_Split->output(0), {-1, 1, 32, 64}}, {{"special_zero", false}});
+
+    auto cat_Concat_593 = makePattern<opset1::Concat>(
+        {ListUnpack_586_Squeeze_0 | ListUnpack_Squeeze_0_1, ListUnpack_586_Squeeze | ListUnpack_Squeeze_1},
+        {{"axis", -1}}); // tensor_array<f32[?,?,32,128]>
     auto slice_StridedSlice_470 = GenStridedSlice(rotary_emb_sin,
                                                   ScatterUpdate_463814,
                                                   {0, INT_MAX},
                                                   {1, 1},
                                                   1); // tensor_array<f32[1,..4096,1,128]>
     auto slice_Slice_470 =
         makePattern<opset8::Slice>({rotary_emb_sin, Gather_377635 | neg_Multiply, {INT_MAX}, {1}, {1}});
-    auto mul_Multiply_594 =
-        makePattern<opset1::Multiply>({cat_Concat_593, slice_StridedSlice_470 | slice_Slice_470},
-                                      {{"auto_broadcast", "numpy"}}); // tensor_array<f32[?,?,32,128]>
+    auto gather_sin_by_pos_ids = makePattern<opset8::Gather>({rotary_emb_sin, position_ids, 1}, {{"batch_dims", 0}});
+    auto reshape_sin_to_expected_layout =
+        makePattern<opset8::Reshape>({gather_sin_by_pos_ids, {-1, 1, 1, 128}}, {{"special_zero", false}});
+    auto mul_Multiply_594 = makePattern<opset1::Multiply>(
+        {cat_Concat_593, slice_StridedSlice_470 | slice_Slice_470 | reshape_sin_to_expected_layout},
+        {{"auto_broadcast", "numpy"}}); // tensor_array<f32[?,?,32,128]>
     auto add_Add_597 = makePattern<opset1::Add>({mul_Multiply_552, mul_Multiply_594},
                                                 {{"auto_broadcast", "numpy"}}); // tensor_array<f32[?,?,32,128]>
 
@@ -844,8 +867,8 @@ ov::pass::RoPEFusionQwen::RoPEFusionQwen(int split_output_id) {
         auto new_node = std::make_shared<op::internal::RoPE>(new_args, config);
         new_node->set_friendly_name(old_node->get_friendly_name());
         ov::copy_runtime_info({pattern_map.at(Multiply_567527).get_node_shared_ptr(),
-                               pattern_map.at(ListUnpack_586_Squeeze_0).get_node_shared_ptr(),
-                               pattern_map.at(ListUnpack_586_Squeeze).get_node_shared_ptr(),
+                               // pattern_map.at(ListUnpack_586_Squeeze_0).get_node_shared_ptr(),
+                               // pattern_map.at(ListUnpack_586_Squeeze).get_node_shared_ptr(),
                               pattern_map.at(cat_Concat_593).get_node_shared_ptr(),
                               pattern_map.at(mul_Multiply_594).get_node_shared_ptr(),
                               pattern_map.at(add_Add_597).get_node_shared_ptr()},

Review comment from the Contributor Author on the two commented-out lines: debug? need to double check