[GPU] Relax UnsqueezeBroadcastReshapeSDPAFusion (openvinotoolkit#27515)
### Details:
- Relaxing UnsqueezeBroadcastReshapeSDPAFusion enables the GQA pattern and removes the overhead of the
Broadcast nodes on the key and value paths, which significantly improves the performance of the GLM4 model
(a shape sketch follows this list).
- Fix for GLM4V, whose initial state has shape (-1, 0, 0, 0) and previously made shape inference fail.
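
For context, a minimal, self-contained sketch of the unsqueeze -> broadcast -> reshape shape arithmetic that the GQA pattern covers; the batch size, head counts, sequence length, and head size below are illustrative assumptions, not values taken from GLM4:

```cpp
// Illustrative sketch only: all sizes are assumed, not taken from GLM4.
#include <cstdint>
#include <iostream>
#include <vector>

static void print(const char* name, const std::vector<int64_t>& s) {
    std::cout << name << ": [";
    for (size_t i = 0; i < s.size(); ++i) std::cout << s[i] << (i + 1 < s.size() ? ", " : "");
    std::cout << "]\n";
}

int main() {
    const int64_t batch = 1, q_heads = 32, kv_heads = 2, seq = 1024, head_size = 128;
    const int64_t group = q_heads / kv_heads;  // query heads sharing one key/value head

    print("key/value from KVCache", {batch, kv_heads, seq, head_size});
    print("after Unsqueeze (rank 5)", {batch, kv_heads, 1, seq, head_size});
    print("after Broadcast (copies materialized)", {batch, kv_heads, group, seq, head_size});
    print("after Reshape (rank 4, matches query heads)", {batch, q_heads, seq, head_size});
    // When the relaxed fusion matches, these nodes are folded into SDPA, so the
    // broadcast copies on the key and value paths are no longer materialized.
    return 0;
}
```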
 
### Tickets:
 - *CVS-157263*

---------

Co-authored-by: Chen Peter <[email protected]>
2 people authored and NishantPrabhuFujitsu committed Nov 26, 2024
1 parent c4c160b commit 57f246a
Showing 3 changed files with 12 additions and 10 deletions.
11 changes: 7 additions & 4 deletions src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp
@@ -106,18 +106,21 @@ std::vector<ov::PartialShape> shape_infer(const KVCache* op, const std::vector<o
 
     const auto& gather_axis = op->get_gather_axis();
     const auto& concat_axis = ov::util::normalize(op->get_concat_axis(), input_shapes[0].size());
+    // We update output shape with input1 shape by default, as input1 is always new, and in some situations, input0 shape
+    // has zeros in some dimensions. For example to concat input0 [-1, 0, 0, 0] + input1 [-1, 4, -1, 128] along axis 2,
+    // we could (and should) infer dim value of axis 1 and 3 in this case.
     if (op->get_output_size() >= 2) {
-        out_shapes[0] = input_shapes[0];
+        out_shapes[0] = input_shapes[1];
         out_shapes[0][gather_axis] = input_shapes[2][0];
-        out_shapes[0][concat_axis] += input_shapes[1][concat_axis];
+        out_shapes[0][concat_axis] += input_shapes[0][concat_axis];
 
         std::vector<ov::Dimension> dims(out_shapes[0].size(), 1);
         dims[gather_axis] = out_shapes[0][gather_axis];
         dims[concat_axis] = out_shapes[0][concat_axis];
         out_shapes[1] = dims;
     } else {
-        out_shapes[0] = input_shapes[0];
-        out_shapes[0][concat_axis] += input_shapes[1][concat_axis];
+        out_shapes[0] = input_shapes[1];
+        out_shapes[0][concat_axis] += input_shapes[0][concat_axis];
     }
 
     return out_shapes;
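
To make the new behavior concrete, here is a minimal stand-alone sketch run on the shapes from the comment above; Dim and infer are simplified stand-ins for illustration, not the real ov::Dimension or the plugin's shape_infer:

```cpp
// Simplified model of the inference above; -1 marks a dynamic dimension.
#include <cstdint>
#include <iostream>
#include <vector>

struct Dim {
    int64_t v;  // -1 == dynamic
    Dim operator+(Dim o) const { return (v < 0 || o.v < 0) ? Dim{-1} : Dim{v + o.v}; }
};

using Shape = std::vector<Dim>;

// New behavior: base the output on input1 (the new token), then add input0 along concat_axis.
Shape infer(const Shape& input0, const Shape& input1, size_t concat_axis) {
    Shape out = input1;
    out[concat_axis] = out[concat_axis] + input0[concat_axis];
    return out;
}

int main() {
    Shape past   = {{-1}, {0}, {0}, {0}};      // GLM4V initial state (input0)
    Shape new_kv = {{-1}, {4}, {-1}, {128}};   // new key/value (input1)
    for (Dim d : infer(past, new_kv, /*concat_axis=*/2)) std::cout << d.v << " ";
    std::cout << "\n";  // prints: -1 4 -1 128, so axes 1 and 3 are recovered from input1
    return 0;
}
```

Basing the output on input0, as the old code did, would keep the zero dims of the initial state in axes 1 and 3, which is what made shape inference fail for GLM4V.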
5 changes: 4 additions & 1 deletion src/plugins/intel_gpu/src/plugin/transformations/op/sdpa.cpp
@@ -144,9 +144,12 @@ std::vector<ov::PartialShape> shape_infer(const SDPA* op,
     if (is_broadcastable) {
         size_t max_rank = shape_q_t.size();
         for (size_t i = 0; i < max_rank; ++i) {
-            if (shape_q_t[i].is_static() && shape_k_t[i].is_static() && shape_v_t[i].is_static()) {
+            if (shape_q_t[i].is_static() && shape_k_t[i].is_static()) {
                 auto broadcasted_dim = shape_q_t[i].get_length();
                 shape_k_t[i] = broadcasted_dim;
+            }
+            if (shape_q_t[i].is_static() && shape_v_t[i].is_static()) {
+                auto broadcasted_dim = shape_q_t[i].get_length();
                 shape_v_t[i] = broadcasted_dim;
             }
         }
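
A small stand-alone sketch of why the key and value checks are split (the dims below are assumptions, with -1 marking a dynamic dimension): under the old joint condition, a dynamic value dim also blocked aligning the key dim to the query, while the independent checks align whichever side is static.

```cpp
// Sketch: -1 marks a dynamic dim; values are illustrative only.
#include <cstdint>
#include <iostream>

int main() {
    int64_t q = 32, k = 2, v = -1;  // example head-count dims; v is still dynamic

    // Old joint check: q, k and v all had to be static, so nothing was aligned here.
    bool old_aligned_k = (q >= 0 && k >= 0 && v >= 0);

    // New independent checks: the key dim is aligned to the query, the value stays dynamic.
    if (q >= 0 && k >= 0) k = q;
    if (q >= 0 && v >= 0) v = q;

    std::cout << "old aligned k? " << old_aligned_k << ", new k=" << k << ", v=" << v << "\n";
    return 0;
}
```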
6 changes: 1 addition & 5 deletions
Expand Up @@ -23,10 +23,6 @@ using ov::pass::pattern::op::Or;
UnsqueezeBroadcastReshapeSDPAFusion::UnsqueezeBroadcastReshapeSDPAFusion() {
using namespace ov::pass::pattern;

auto not_reshape = [](const ov::Output<ov::Node>& output) -> bool {
return std::dynamic_pointer_cast<ov::op::v1::Reshape>(output.get_node_shared_ptr()) == nullptr;
};

auto unsqueeze_predicate = [](const ov::Output<ov::Node>& output) -> bool {
return rank_equals(5)(output) && consumers_count(1);
};
Expand All @@ -42,7 +38,7 @@ UnsqueezeBroadcastReshapeSDPAFusion::UnsqueezeBroadcastReshapeSDPAFusion() {
return rank_equals(4)(output) && consumers_count(1);
};

auto input_a_m = any_input(not_reshape);
auto input_a_m = any_input();
auto input_attn_mask = any_input();
auto input_scale = any_input();
auto input_b_m = wrap_type<ov::intel_gpu::op::KVCache>({any_input(), any_input()});
Expand Down
