[TRANSFORMATIONS] SDPAToPagedAttention transformation: support decompression case in the Qwen-7b-Chat pattern #28514

Status: Open. Wants to merge 2 commits into base branch releases/2025/0.
@@ -72,8 +72,15 @@ ov::pass::PositionIDsReplacerQwen::PositionIDsReplacerQwen(const Output<Node>& p

auto p_neg_const = wrap_type<v0::Constant>();
auto p_neg_mul = wrap_type<v1::Multiply>({p_current_len, p_neg_const});

+ // For now, it has always been a constant, but this may change in the future.
+ // In case of the model being in FP16, there will be a decompressing subgraph,
+ // i.e. Constant -> Convert -> Slice.
+ //
+ // Also, it hasn't been observed yet, but theoretically there can also be a
+ // dequantizing subgraph, so it's going to be any_input() here.
+ auto p_rotary_emb_sincos = pattern::any_input();
// the rotary_emb_cos/rotary_emb_sin are sliced by the total length [1,..4096,1,128]
- auto p_rotary_emb_sincos = wrap_type<v0::Constant>();
auto p_slice_1 = wrap_type<v8::Slice>({p_rotary_emb_sincos, _const(), p_opt_reshape, _const(), _const()});
auto p_slice_2 = wrap_type<v8::Slice>({p_slice_1, p_neg_mul, _const(), _const(), _const()});

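For context, a minimal standalone sketch (not part of the diff) of the two producer chains the relaxed pattern has to accept, assuming only the public OpenVINO pattern API:

#include "openvino/op/slice.hpp"
#include "openvino/pass/pattern/op/label.hpp"
#include "openvino/pass/pattern/op/wrap_type.hpp"

// FP32 model:  Constant(f32) -> Slice
// FP16 model:  Constant(f16) -> Convert(f32) -> Slice   (decompression)
// wrap_type<v0::Constant>() matches only the first chain; any_input()
// accepts any producer, so both chains (and a hypothetical dequantizing
// subgraph) are matched.
void rotary_emb_pattern_sketch() {
    using namespace ov::pass;
    auto p_rotary_emb_sincos = pattern::any_input();
    auto p_slice = pattern::wrap_type<ov::op::v8::Slice>({p_rotary_emb_sincos,
                                                          pattern::any_input(),    // start
                                                          pattern::any_input(),    // stop
                                                          pattern::any_input(),    // step
                                                          pattern::any_input()});  // axes
    (void)p_slice;  // the real transformation registers this root with a Matcher
}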
@@ -29,7 +29,6 @@
#include "openvino/op/subtract.hpp"
#include "openvino/op/transpose.hpp"
#include "openvino/op/unsqueeze.hpp"
#include "openvino/pass/visualize_tree.hpp"
#include "transformations/sdpa_to_paged_attention/prev_sequence_length_pattern.hpp"
#include "transformations/sdpa_to_paged_attention/state_management_pattern.hpp"
#include "transformations/sdpa_to_paged_attention/total_sequence_length_pattern.hpp"
@@ -186,17 +185,25 @@ class Qwen7bChatSDPA {

static std::shared_ptr<Node> gen_rope_emb_sin(const std::shared_ptr<Node>& total_seq_len,
const std::shared_ptr<Node>& neg_mul,
- std::shared_ptr<Node>& head_size) {
- auto sin = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE);
+ std::shared_ptr<Node>& head_size,
+ element::Type model_precision) {
+ auto sin = makeConst(model_precision, {1, 4096, 1, 128}, MOCK_VALUE);
+ if (model_precision != element::f32) {
+ sin = makeOP<v0::Convert>({sin}, {dest_type_f32});
+ }
auto sliced_sin_by_total = makeOP<v8::Slice>({sin, {0}, total_seq_len, {1}, {1}});
auto rotary_emb_sin_shape = makeOP<v3::ShapeOf>({sliced_sin_by_total}, {{"output_type", "i64"}});
head_size = makeOP<v8::Gather>({rotary_emb_sin_shape, {3}, 0}, {{"batch_dims", 0}});
return makeOP<v8::Slice>({sliced_sin_by_total, neg_mul, {LLONG_MAX}, {1}, {1}});
}

static std::shared_ptr<Node> gen_rope_emb_cos(const std::shared_ptr<Node>& total_seq_len,
- const std::shared_ptr<Node>& neg_mul) {
- auto cos = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE);
+ const std::shared_ptr<Node>& neg_mul,
+ element::Type model_precision) {
+ auto cos = makeConst(model_precision, {1, 4096, 1, 128}, MOCK_VALUE);
+ if (model_precision != element::f32) {
+ cos = makeOP<v0::Convert>({cos}, {dest_type_f32});
+ }
auto sliced_cos_by_total = makeOP<v8::Slice>({cos, {0}, total_seq_len, {1}, {1}});
return makeOP<v8::Slice>({sliced_cos_by_total, neg_mul, {LLONG_MAX}, {1}, {1}});
}
@@ -343,8 +350,12 @@ class Qwen7bChatPA {

static std::shared_ptr<Node> gen_rope_emb_sin(const std::shared_ptr<Node>& max_context_len,
const std::shared_ptr<Node>& position_ids,
- std::shared_ptr<Node>& head_size) {
- auto sin = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE);
+ std::shared_ptr<Node>& head_size,
+ element::Type model_precision) {
+ auto sin = makeConst(model_precision, {1, 4096, 1, 128}, MOCK_VALUE);
+ if (model_precision != element::f32) {
+ sin = makeOP<v0::Convert>({sin}, {dest_type_f32});
+ }
auto slice_sin = makeOP<v8::Gather>({sin, position_ids, 1}, {{"batch_dims", 0}});

auto slice = makeOP<v8::Slice>({sin, {0}, max_context_len, {1}, {1}});
@@ -355,8 +366,12 @@
}

static std::shared_ptr<Node> gen_rope_emb_cos(const std::shared_ptr<Node>& max_context_len,
- const std::shared_ptr<Node>& position_ids) {
- auto cos = makeConst(element::f32, {1, 4096, 1, 128}, MOCK_VALUE);
+ const std::shared_ptr<Node>& position_ids,
+ element::Type model_precision) {
+ auto cos = makeConst(model_precision, {1, 4096, 1, 128}, MOCK_VALUE);
+ if (model_precision != element::f32) {
+ cos = makeOP<v0::Convert>({cos}, {dest_type_f32});
+ }
auto slice = makeOP<v8::Gather>({cos, position_ids, 1}, {{"batch_dims", 0}});
return makeOP<v1::Reshape>({slice, {-1, 1, 1, 128}}, {{"special_zero", false}});
}
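The makeConst-plus-conditional-Convert idiom above now repeats in all four sin/cos helpers. A possible factoring, shown only as an illustrative sketch (gen_sincos_table is not a helper in this PR), reusing makeConst/makeOP/dest_type_f32 from this test file:

// Hypothetical helper: builds the sin/cos table in the model precision and
// decompresses it to f32 when needed, mirroring what an FP16 IR looks like
// after weight compression.
static std::shared_ptr<Node> gen_sincos_table(element::Type model_precision) {
    std::shared_ptr<Node> table = makeConst(model_precision, {1, 4096, 1, 128}, MOCK_VALUE);
    if (model_precision != element::f32) {
        table = makeOP<v0::Convert>({table}, {dest_type_f32});
    }
    return table;
}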
@@ -425,7 +440,10 @@ class Qwen7bChatPA {

} // namespace

- TEST_F(TransformationTestsF, SDPAToPA_Qwen) {
+ class SDPAToPATest : public TransformationTestsF, public ::testing::WithParamInterface<element::Type> {};
+
+ TEST_P(SDPAToPATest, SDPAToPA_Qwen7bChat_General) {
+ const auto model_precision = GetParam();
{
// Inputs to SDPA transformer:
auto beam_idx = makeOP<v0::Parameter>({}, {{"shape", PartialShape{DYN}}, el_type_i64});
@@ -455,8 +473,9 @@ TEST_F(TransformationTestsF, SDPAToPA_Qwen) {
// RoPE emb sin/cos init:
auto neg_cur_seq_len = Qwen7bChatSDPA::neg_mul(current_seq_len);
auto head_size = shared_ptr<Node>();
- auto rope_emb_sin = Qwen7bChatSDPA::gen_rope_emb_sin(total_seq_len, neg_cur_seq_len, head_size);
- auto rope_emb_cos = Qwen7bChatSDPA::gen_rope_emb_cos(total_seq_len, neg_cur_seq_len);
+ auto rope_emb_sin =
+ Qwen7bChatSDPA::gen_rope_emb_sin(total_seq_len, neg_cur_seq_len, head_size, model_precision);
+ auto rope_emb_cos = Qwen7bChatSDPA::gen_rope_emb_cos(total_seq_len, neg_cur_seq_len, model_precision);

// RoPE for Q,K inputs:
auto rope_q = Qwen7bChatSDPA::gen_rope(QKV::Q, qkv_proj, head_size, rope_emb_sin, rope_emb_cos);
@@ -515,8 +534,10 @@ TEST_F(TransformationTestsF, SDPAToPA_Qwen) {

// RoPE emb sin/cos init:
auto head_size = shared_ptr<Node>();
- auto rope_emb_sin = Qwen7bChatPA::gen_rope_emb_sin(max_context_len_aligned, position_ids_aligned, head_size);
- auto rope_emb_cos = Qwen7bChatPA::gen_rope_emb_cos(max_context_len_aligned, position_ids_aligned);
+ auto rope_emb_sin =
+ Qwen7bChatPA::gen_rope_emb_sin(max_context_len_aligned, position_ids_aligned, head_size, model_precision);
+ auto rope_emb_cos =
+ Qwen7bChatPA::gen_rope_emb_cos(max_context_len_aligned, position_ids_aligned, model_precision);

// rope Q, K:
auto rope_Q = Qwen7bChatPA::gen_rope(QKV::Q, qkv_proj, head_size, rope_emb_sin, rope_emb_cos);
@@ -564,7 +585,7 @@ TEST_F(TransformationTestsF, SDPAToPA_Qwen) {
disable_rt_info_check();
}

- TEST_F(TransformationTestsF, SDPAToPA_TotalSequenceLengthPatternQwen) {
+ TEST_P(SDPAToPATest, SDPAToPA_Qwen7bChat_TotalSequenceLengthPattern) {
{
// Inputs to SDPA transformer:
auto beam_idx = makeOP<v0::Parameter>({}, {{"shape", PartialShape{DYN}}, el_type_i64});
@@ -632,7 +653,7 @@ static std::shared_ptr<ov::Node> make_param(const PartialShape& pshape,
// TODO: write a test for StateManagementPattern only (because changes for Alibi are inside it)
// TODO: align precisions, check the copying of "fuse_names" attr in SDPAToPagedAttention
// checking the graph structure and names, other checks are temporarily disabled:
- TEST_F(TransformationTestsF, SDPAToPA_Baichuan2_13b_general_test) {
+ TEST_P(SDPAToPATest, SDPAToPA_Baichuan2_13b_General) {
{
auto beam_idx = make_param(PartialShape{DYN}, element::i32, "beam_idx");
auto position_ids = make_param(PartialShape{DYN, DYN}, element::i64, "position_ids");
@@ -881,4 +902,17 @@ TEST_F(TransformationTestsF, SDPAToPA_Baichuan2_13b_general_test) {
disable_result_friendly_names_check();
disable_rt_info_check();
}
}

+ /*
+ As there's often a need to cover specific model architectures in these
+ tests, please make sure you name the tests in the following manner:
+ SDPAToPA_MODELNAME_PATTERNYOUCOVER,
+ e.g. SDPAToPA_Qwen7bChat_TotalSequenceLengthPattern, or
+ SDPAToPA_Baichuan2_13b_General if this is a test for the
+ entire SDPAToPA transformation.
+ */

+ const std::vector<ov::element::Type> element_types = {element::f16, element::f32};

+ INSTANTIATE_TEST_SUITE_P(SDPAToPATest_Conversion, SDPAToPATest, testing::ValuesIn(element_types));
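
For readers less familiar with value-parameterized GoogleTest suites, this is the mechanism that makes each test above run once per element type in element_types. A self-contained sketch (names here are illustrative, not from the PR):

#include <gtest/gtest.h>

// Each TEST_P body runs once per value supplied by INSTANTIATE_TEST_SUITE_P;
// GetParam() returns the value for the current repetition, which is how
// model_precision is obtained in the tests above.
class PrecisionSuite : public ::testing::TestWithParam<int> {};

TEST_P(PrecisionSuite, BodyRunsPerParam) {
    const int bits = GetParam();  // 16 on the first run, 32 on the second
    EXPECT_TRUE(bits == 16 || bits == 32);
}

INSTANTIATE_TEST_SUITE_P(Widths, PrecisionSuite, ::testing::Values(16, 32));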