Skip to content

Commit

Permalink
clean code
Browse files Browse the repository at this point in the history
  • Loading branch information
wine99 committed Jan 17, 2025
1 parent 5c7b5a7 commit c48758c
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,7 @@
#include "openvino/pass/pattern/op/wrap_type.hpp"
#include "transformations/utils/utils.hpp"

std::shared_ptr<ov::Node> get_present_state(const std::shared_ptr<ov::Node>& K,
const std::shared_ptr<ov::Node>& V,
const ov::OutputVector& op_inputs);

std::shared_ptr<ov::Node> rotaryEmbedding(ov::Output<ov::Node> input,
ov::Output<ov::Node> past_seqlen,
std::shared_ptr<ov::Node> seqlen_k,
Expand Down Expand Up @@ -113,7 +111,7 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose(
const auto hidden_size = get_dimensions(node_shape, {2});
const auto total_num_heads_node =
v0::Constant::create(ov::element::i64, ov::Shape{1}, {num_heads + kv_num_heads + kv_num_heads});
auto head_size_node = std::make_shared<v1::Divide>(hidden_size, total_num_heads_node);
auto head_size_node = std::make_shared<v1::Divide>(hidden_size, total_num_heads_node); // should be equal to the last dim of past_key

// transpose Q, K and V to (batch_size, num_heads, sequence_len, head_size)
auto perm = v0::Constant::create(ov::element::i64, ov::Shape{4}, {0, 2, 1, 3});
Expand Down Expand Up @@ -175,8 +173,8 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose(

K = construct_kv_cache(past_key, K);
V = construct_kv_cache(past_value, V);
auto present_k = K.get_node_shared_ptr();
auto present_v = V.get_node_shared_ptr();
auto present_k = K;
auto present_v = V;

const size_t kv_num_heads_factor = num_heads / kv_num_heads;
if (kv_num_heads_factor > 1) {
Expand Down Expand Up @@ -232,7 +230,7 @@ ov::OutputVector ov::pass::GroupQueryAttentionDecomposition::decompose(
auto dim_merge_shape = v0::Constant::create(ov::element::i32, ov::Shape{3}, {0, 0, -1});
// reshape the result from (batch_size, sequence_length, num_heads, head_size)
// to (batch_size, sequence_length, num_heads * head_size)
auto output = std::make_shared<v1::Reshape>(qga_output_transposed, dim_merge_shape, true);
auto output = std::make_shared<v1::Reshape>(qga_output_transposed, dim_merge_shape, true)->output(0);

return {output, present_k, present_v};
}
Expand Down
2 changes: 1 addition & 1 deletion src/core/src/op/group_query_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ std::vector<int64_t> get_qkv_sizes(const PartialShape& input_shape, int num_head
return qkv_sizes;
}

// TODO
void GroupQueryAttention::validate_and_infer_types() {
OV_OP_SCOPE(GroupQueryAttention_validate_and_infer_types);
PartialShape input_shape = get_input_partial_shape(0);
Expand All @@ -67,6 +66,7 @@ void GroupQueryAttention::validate_and_infer_types() {
PartialShape kv_past_shape = get_input_partial_shape(3);
// FIXME: Original GQA spec depends on the identical tensor set for input/output, but we cannot know it in advance,
// hence we base on sequence dimension static/dynamic
// https://github.com/openvinotoolkit/openvino/pull/27648
if (kv_past_shape[2].is_dynamic()) {
output_kv_len = kv_past_shape[2] + sequence_len;
} else {
Expand Down

0 comments on commit c48758c

Please sign in to comment.