diff --git a/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp b/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp
index 12d961be6d337a..6721d0f9ebd608 100644
--- a/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp
+++ b/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp
@@ -106,18 +106,21 @@ std::vector<ov::PartialShape> shape_infer(const KVCache* op, const std::vector<ov::PartialShape>& input_shapes) {
     const auto& gather_axis = op->get_gather_axis();
     const auto& concat_axis = ov::util::normalize(op->get_concat_axis(), input_shapes[0].size());
 
+    // Update the output shape from input1 by default: input1 always carries the newly appended data, while in some
+    // situations input0's shape has zeros in some dimensions. For example, when concatenating input0 [-1, 0, 0, 0]
+    // with input1 [-1, 4, -1, 128] along axis 2, the dimension values of axes 1 and 3 can (and should) be inferred.
     if (op->get_output_size() >= 2) {
-        out_shapes[0] = input_shapes[0];
+        out_shapes[0] = input_shapes[1];
         out_shapes[0][gather_axis] = input_shapes[2][0];
-        out_shapes[0][concat_axis] += input_shapes[1][concat_axis];
+        out_shapes[0][concat_axis] += input_shapes[0][concat_axis];
 
         std::vector<ov::Dimension> dims(out_shapes[0].size(), 1);
         dims[gather_axis] = out_shapes[0][gather_axis];
         dims[concat_axis] = out_shapes[0][concat_axis];
         out_shapes[1] = dims;
     } else {
-        out_shapes[0] = input_shapes[0];
-        out_shapes[0][concat_axis] += input_shapes[1][concat_axis];
+        out_shapes[0] = input_shapes[1];
+        out_shapes[0][concat_axis] += input_shapes[0][concat_axis];
     }
 
     return out_shapes;
diff --git a/src/plugins/intel_gpu/src/plugin/transformations/op/sdpa.cpp b/src/plugins/intel_gpu/src/plugin/transformations/op/sdpa.cpp
index 09513d99153a1f..3988306ba5eff4 100644
--- a/src/plugins/intel_gpu/src/plugin/transformations/op/sdpa.cpp
+++ b/src/plugins/intel_gpu/src/plugin/transformations/op/sdpa.cpp
@@ -144,9 +144,12 @@ std::vector<ov::PartialShape> shape_infer(const SDPA* op,
     if (is_broadcastable) {
         size_t max_rank = shape_q_t.size();
         for (size_t i = 0; i < max_rank; ++i) {
-            if (shape_q_t[i].is_static() && shape_k_t[i].is_static() && shape_v_t[i].is_static()) {
+            if (shape_q_t[i].is_static() && shape_k_t[i].is_static()) {
                 auto broadcasted_dim = shape_q_t[i].get_length();
                 shape_k_t[i] = broadcasted_dim;
+            }
+            if (shape_q_t[i].is_static() && shape_v_t[i].is_static()) {
+                auto broadcasted_dim = shape_q_t[i].get_length();
                 shape_v_t[i] = broadcasted_dim;
             }
         }
diff --git a/src/plugins/intel_gpu/src/plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.cpp b/src/plugins/intel_gpu/src/plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.cpp
index d525792ccd8d06..2b0d2ed5eaf145 100644
--- a/src/plugins/intel_gpu/src/plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.cpp
+++ b/src/plugins/intel_gpu/src/plugin/transformations/unsqueeze_broadcast_reshape_sdpa_fusion.cpp
@@ -23,10 +23,6 @@ using ov::pass::pattern::op::Or;
 UnsqueezeBroadcastReshapeSDPAFusion::UnsqueezeBroadcastReshapeSDPAFusion() {
     using namespace ov::pass::pattern;
 
-    auto not_reshape = [](const ov::Output<ov::Node>& output) -> bool {
-        return std::dynamic_pointer_cast<ov::op::v1::Reshape>(output.get_node_shared_ptr()) == nullptr;
-    };
-
     auto unsqueeze_predicate = [](const ov::Output<ov::Node>& output) -> bool {
         return rank_equals(5)(output) && consumers_count(1);
     };
@@ -42,7 +38,7 @@ UnsqueezeBroadcastReshapeSDPAFusion::UnsqueezeBroadcastReshapeSDPAFusion() {
         return rank_equals(4)(output) && consumers_count(1);
     };
 
-    auto input_a_m = any_input(not_reshape);
+    auto input_a_m = any_input();
     auto input_attn_mask = any_input();
     auto input_scale = any_input();
     auto input_b_m = wrap_type<ov::intel_gpu::op::KVCache>({any_input(), any_input()});
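
Reviewer note: below is a minimal, self-contained sketch of the shape-inference rule introduced in kv_cache.cpp above. It is not the plugin code: the `Shape` alias, the `-1` dynamic-dimension convention, and the helper name `concat_shape_infer` are invented for illustration and stand in for the `ov::PartialShape`/`ov::Dimension` arithmetic used in the real implementation.

#include <cassert>
#include <cstdint>
#include <vector>

// Simplified stand-in for ov::PartialShape: -1 marks a dynamic dimension.
using Shape = std::vector<int64_t>;

// Sketch of the new rule: start from input1 (the newly appended K/V data), whose
// per-head dims are always known, then extend the concat axis by input0's
// (possibly zero or dynamic) length along that axis.
Shape concat_shape_infer(const Shape& input0, const Shape& input1, size_t concat_axis) {
    Shape out = input1;  // previously the code started from input0
    if (input0[concat_axis] < 0 || input1[concat_axis] < 0)
        out[concat_axis] = -1;  // dynamic + anything stays dynamic
    else
        out[concat_axis] = input0[concat_axis] + input1[concat_axis];
    return out;
}

int main() {
    // The example from the new comment: an empty cache concatenated with new tokens along axis 2.
    Shape cache  = {-1, 0, 0, 0};     // input0: past KV cache with zero-sized dims
    Shape new_kv = {-1, 4, -1, 128};  // input1: newly computed K/V
    Shape out = concat_shape_infer(cache, new_kv, 2);
    // Axes 1 and 3 are inferred from input1; starting from input0 would have left them at 0.
    assert((out == Shape{-1, 4, -1, 128}));
    return 0;
}

Starting from input1 matters because the cache input can legally arrive with zero-sized dimensions on the first token, and copying its shape verbatim would propagate those zeros into downstream shape inference.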