-
Notifications
You must be signed in to change notification settings - Fork 2.4k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[GPU] Relax UnsqueezeBroadcastReshapeSDPAFusion #27515
[GPU] Relax UnsqueezeBroadcastReshapeSDPAFusion #27515
Conversation
…branch from non-reshape.
Hi, cecilia
can you have a look, thank you. |
…hen read_value of init shape full 0.
@@ -42,7 +38,7 @@ UnsqueezeBroadcastReshapeSDPAFusion::UnsqueezeBroadcastReshapeSDPAFusion() { | |||
return rank_equals(4)(output) && consumers_count(1); | |||
}; | |||
|
|||
auto input_a_m = any_input(not_reshape); | |||
auto input_a_m = any_input(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@sshlyapn May I know why "not_reshape" is asked here previously? Any problem here if I remove this check?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Originally it was copied from UnsqueezeBroadcastReshapeMatmulFusion transformation, but it seems okay to me to relax this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Seems that you allowed query input’s reshape, and seems that we need to check whether both sdpa_opt and sdpa_micro supports dynamic padded query input.
E.g.,
Fused QKV gemm => VariadicSplit (crop + optimized out) => reshape (optimized out) => sdpa query input
Not quickly sure which model contains such a pattern.
Maybe you can just create a functional test, which has above pattern, and then check the values are correct.
@yeonbok This relax is an GQA pattern optimizing by removing broadcast nodes from key and value input paths. The sdpa gpu node was in the exec graph already before this optimizing. So correctness of this special case you mentioned should have been assured already.
@@ -42,7 +38,7 @@ UnsqueezeBroadcastReshapeSDPAFusion::UnsqueezeBroadcastReshapeSDPAFusion() { | |||
return rank_equals(4)(output) && consumers_count(1); | |||
}; | |||
|
|||
auto input_a_m = any_input(not_reshape); | |||
auto input_a_m = any_input(); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Originally it was copied from UnsqueezeBroadcastReshapeMatmulFusion transformation, but it seems okay to me to relax this
// Sometime input0 shape has zeros (or even dynamic dim) in several dimensions, for | ||
// example concat [-1, 0, 0, 0] + [-1, 4, -1, 128] along axis 2, we could (and should) infer | ||
// dim value of axis 1 and 4 in this case. | ||
for (int64_t i = 0; i < static_cast<int64_t>(out_shapes[0].size()); ++i) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we can just initialize out_shapes[0] with input_shapes[1] instead of input_shapes[0], since the input_shapes[1] input shape is always "new_token" input and input_shapes[0] is "past". It seems [1] shape should always be more detailed
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
### Details: - By relaxing UnsqueezeBroadcastReshapeSDPAFusion, GQA pattern is enabled and Broadcasting nodes overheads in paths of key and value are removed, thus improves performance of GLM4 model significantly. - Fix for GLM4V, which has initial state shape (-1, 0, 0, 0), and shape infer failed. ### Tickets: - *CVS-157263* --------- Co-authored-by: Chen Peter <[email protected]>
### Details: - By relaxing UnsqueezeBroadcastReshapeSDPAFusion, GQA pattern is enabled and Broadcasting nodes overheads in paths of key and value are removed, thus improves performance of GLM4 model significantly. - Fix for GLM4V, which has initial state shape (-1, 0, 0, 0), and shape infer failed. ### Tickets: - *CVS-157263* Co-authored-by: Chen Peter <[email protected]>
Details:
Tickets: