[GPU] Relax UnsqueezeBroadcastReshapeSDPAFusion #27515

Merged
@@ -106,18 +106,21 @@ std::vector<ov::PartialShape> shape_infer(const KVCache* op, const std::vector<ov::PartialShape>& input_shapes)
     const auto& gather_axis = op->get_gather_axis();
     const auto& concat_axis = ov::util::normalize(op->get_concat_axis(), input_shapes[0].size());
+    // We update output shape with input1 shape by default, as input1 is always new, and in some situations, input0 shape
+    // has zeros in some dimensions. For example to concat input0 [-1, 0, 0, 0] + input1 [-1, 4, -1, 128] along axis 2,
+    // we could (and should) infer dim value of axis 1 and 3 in this case.
     if (op->get_output_size() >= 2) {
-        out_shapes[0] = input_shapes[0];
+        out_shapes[0] = input_shapes[1];
         out_shapes[0][gather_axis] = input_shapes[2][0];
-        out_shapes[0][concat_axis] += input_shapes[1][concat_axis];
+        out_shapes[0][concat_axis] += input_shapes[0][concat_axis];

         std::vector<ov::Dimension> dims(out_shapes[0].size(), 1);
         dims[gather_axis] = out_shapes[0][gather_axis];
         dims[concat_axis] = out_shapes[0][concat_axis];
         out_shapes[1] = dims;
     } else {
-        out_shapes[0] = input_shapes[0];
-        out_shapes[0][concat_axis] += input_shapes[1][concat_axis];
+        out_shapes[0] = input_shapes[1];
+        out_shapes[0][concat_axis] += input_shapes[0][concat_axis];
     }

     return out_shapes;
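To make the new comment in the diff above concrete, here is a minimal standalone sketch (my own illustration, not code from the PR; it only assumes an OpenVINO development environment providing ov::PartialShape / ov::Dimension) of why basing the output shape on input1 lets the dims of axes 1 and 3 be inferred when input0 still carries zero dimensions:

// Standalone illustration (not PR code): input0 mimics a cached KV buffer whose
// shape still contains zero dims; input1 is the newly appended KV tensor.
#include <cstddef>
#include <iostream>
#include "openvino/core/partial_shape.hpp"

int main() {
    const auto dyn = ov::Dimension::dynamic();
    const ov::PartialShape input0{dyn, 0, 0, 0};
    const ov::PartialShape input1{dyn, 4, dyn, 128};
    const std::size_t concat_axis = 2;

    // Old behaviour: start from input0, so axes 1 and 3 remain 0.
    ov::PartialShape old_out = input0;
    old_out[concat_axis] += input1[concat_axis];

    // New behaviour: start from input1, so axes 1 and 3 come out as 4 and 128.
    ov::PartialShape new_out = input1;
    new_out[concat_axis] += input0[concat_axis];

    // Prints roughly: old: [?,0,?,0]  new: [?,4,?,128]
    std::cout << "old: " << old_out << "  new: " << new_out << std::endl;
    return 0;
}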
@@ -144,9 +144,12 @@ std::vector<ov::PartialShape> shape_infer(const SDPA* op,
     if (is_broadcastable) {
         size_t max_rank = shape_q_t.size();
         for (size_t i = 0; i < max_rank; ++i) {
-            if (shape_q_t[i].is_static() && shape_k_t[i].is_static() && shape_v_t[i].is_static()) {
+            if (shape_q_t[i].is_static() && shape_k_t[i].is_static()) {
                 auto broadcasted_dim = shape_q_t[i].get_length();
                 shape_k_t[i] = broadcasted_dim;
+            }
+            if (shape_q_t[i].is_static() && shape_v_t[i].is_static()) {
+                auto broadcasted_dim = shape_q_t[i].get_length();
                 shape_v_t[i] = broadcasted_dim;
             }
         }
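A short standalone sketch (again my own illustration, with made-up shapes, not the PR's code) of what the decoupled check changes: previously a still-dynamic dimension of V also prevented K from being aligned to Q at that axis, while the split handles K and V independently:

// Standalone illustration of the now-decoupled K/V broadcast in the SDPA shape inference.
#include <cstddef>
#include <iostream>
#include "openvino/core/partial_shape.hpp"

int main() {
    const auto dyn = ov::Dimension::dynamic();
    ov::PartialShape shape_q_t{dyn, 32, dyn, 128};
    ov::PartialShape shape_k_t{dyn, 4, dyn, 128};    // fewer heads than Q (GQA)
    ov::PartialShape shape_v_t{dyn, dyn, dyn, 128};  // head-count dim still dynamic

    for (std::size_t i = 0; i < shape_q_t.size(); ++i) {
        // The old code required Q, K and V to all be static at axis i, so the dynamic
        // head dim of V would also have blocked aligning K to Q here.
        if (shape_q_t[i].is_static() && shape_k_t[i].is_static())
            shape_k_t[i] = shape_q_t[i].get_length();
        if (shape_q_t[i].is_static() && shape_v_t[i].is_static())
            shape_v_t[i] = shape_q_t[i].get_length();
    }
    // Prints roughly: k: [?,32,?,128]  v: [?,?,?,128]
    std::cout << "k: " << shape_k_t << "  v: " << shape_v_t << std::endl;
    return 0;
}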
@@ -23,10 +23,6 @@ using ov::pass::pattern::op::Or;
 UnsqueezeBroadcastReshapeSDPAFusion::UnsqueezeBroadcastReshapeSDPAFusion() {
     using namespace ov::pass::pattern;

-    auto not_reshape = [](const ov::Output<ov::Node>& output) -> bool {
-        return std::dynamic_pointer_cast<ov::op::v1::Reshape>(output.get_node_shared_ptr()) == nullptr;
-    };
-
     auto unsqueeze_predicate = [](const ov::Output<ov::Node>& output) -> bool {
         return rank_equals(5)(output) && consumers_count(1);
     };

@@ -42,7 +38,7 @@ UnsqueezeBroadcastReshapeSDPAFusion::UnsqueezeBroadcastReshapeSDPAFusion() {
         return rank_equals(4)(output) && consumers_count(1);
     };

-    auto input_a_m = any_input(not_reshape);
+    auto input_a_m = any_input();
     auto input_attn_mask = any_input();
     auto input_scale = any_input();
     auto input_b_m = wrap_type<ov::intel_gpu::op::KVCache>({any_input(), any_input()});

Review thread on the "auto input_a_m = any_input();" change:

ceciliapeng2011 (Contributor, Author), Nov 21, 2024:
    @sshlyapn May I know why "not_reshape" was required here previously? Is there any problem if I remove this check?

Contributor:
    Originally it was copied from the UnsqueezeBroadcastReshapeMatmulFusion transformation, but it seems okay to me to relax this.

ceciliapeng2011 (Contributor, Author):
    It seems you allowed a Reshape on the query input, so we need to check whether both sdpa_opt and sdpa_micro support a dynamically padded query input.
    E.g., fused QKV gemm => VariadicSplit (crop + optimized out) => Reshape (optimized out) => SDPA query input.
    I am not quickly sure which model contains such a pattern. Maybe you can create a functional test that has the above pattern, and then check that the values are correct.

    @yeonbok This relaxation is a GQA pattern optimization that removes the broadcast nodes from the key and value input paths. The SDPA GPU node was already in the exec graph before this optimization, so the correctness of the special case you mentioned should already have been assured.
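For context on the review discussion above, below is a hedged sketch of the kind of key/value broadcast subgraph (Unsqueeze -> Broadcast -> Reshape between the KVCache output and the SDPA input) that this fusion targets for GQA. The opset versions, unsqueeze axis, head counts, and the Parameter standing in for the KVCache output are illustrative assumptions of mine, not taken from the PR:

// Sketch of a GQA key/value "repeat" subgraph that UnsqueezeBroadcastReshapeSDPAFusion
// is meant to remove: [B, H_kv, L, S] -> rank-5 unsqueeze -> broadcast over the group
// dim -> rank-4 reshape to [B, H_q, L, S].
#include <memory>
#include "openvino/op/broadcast.hpp"
#include "openvino/op/constant.hpp"
#include "openvino/op/parameter.hpp"
#include "openvino/op/reshape.hpp"
#include "openvino/op/unsqueeze.hpp"

std::shared_ptr<ov::Node> make_gqa_kv_broadcast() {
    const int64_t kv_heads = 4, groups = 8, head_size = 128;  // hypothetical: H_q = 32
    const auto dyn = ov::Dimension::dynamic();

    // A Parameter stands in for the KVCache output: [batch, H_kv, seq_len, head_size].
    auto kv = std::make_shared<ov::op::v0::Parameter>(
        ov::element::f16, ov::PartialShape{dyn, kv_heads, dyn, head_size});

    // Unsqueeze to rank 5: [batch, H_kv, 1, seq_len, head_size].
    auto axes = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{1}, std::vector<int64_t>{2});
    auto unsqueeze = std::make_shared<ov::op::v0::Unsqueeze>(kv, axes);

    // Broadcast the group dimension: [batch, H_kv, groups, seq_len, head_size].
    auto target = ov::op::v0::Constant::create(
        ov::element::i64, ov::Shape{5}, std::vector<int64_t>{1, 1, groups, 1, 1});
    auto broadcast = std::make_shared<ov::op::v3::Broadcast>(
        unsqueeze, target, ov::op::BroadcastType::BIDIRECTIONAL);

    // Reshape back to rank 4: [batch, H_q, seq_len, head_size], with H_q = H_kv * groups.
    auto pattern = ov::op::v0::Constant::create(
        ov::element::i64, ov::Shape{4}, std::vector<int64_t>{0, kv_heads * groups, -1, head_size});
    return std::make_shared<ov::op::v1::Reshape>(broadcast, pattern, true);
}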