diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl index f38b6cd89974ad..06e83c5adb3e6b 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl @@ -112,7 +112,14 @@ KERNEL(pa_sdpa_opt)( head_num_idx * HEAD_SIZE + query_idx_local; - slm_query[query_idx_local] = BLOCK_READN(INPUT0_TYPE, 1, query, query_idx); + INPUT0_TYPE q_val = BLOCK_READN(INPUT0_TYPE, 1, query, query_idx); + + // Apply scale value directly to the query input to improve accuracy in case of a high range of input data +#ifdef SCALE_VAL + q_val = TO_INPUT0_TYPE(SCALE_VAL) * q_val; +#endif + + slm_query[query_idx_local] = q_val; barrier(CLK_LOCAL_MEM_FENCE); #else @@ -122,6 +129,11 @@ KERNEL(pa_sdpa_opt)( head_num_idx * HEAD_SIZE + i * SUBGROUP_SIZE; q_val[i] = BLOCK_READN(INPUT0_TYPE, 1, query, query_idx); + + // Apply scale value directly to the query input to improve accuracy in case of a high range of input data +#ifdef SCALE_VAL + q_val[i] = TO_INPUT0_TYPE(SCALE_VAL) * q_val[i]; +#endif } #endif @@ -162,10 +174,6 @@ KERNEL(pa_sdpa_opt)( } } -#ifdef SCALE_VAL - qk_acc = TO_INPUT0_TYPE(SCALE_VAL) * qk_acc; -#endif - const uint token_idx = partition_idx * SEQ_LEN_PARTITION_SIZE + block_num * SUBGROUPS_PER_WG * SUBGROUP_SIZE + sgid * SUBGROUP_SIZE + sglid; #ifdef HAS_ALIBI diff --git a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp index 5d07488c676847..e4e7dcb77e03fb 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp @@ -30,7 +30,17 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared auto query_ps = op->get_input_partial_shape(0); auto head_size = key_cache_ps[2].get_length(); auto kv_heads_num = key_cache_ps[1].get_length(); - auto heads_num = query_ps[1].get_length() / head_size; + + // WA: in some cases, the query input may have a bounded dimension + // Use input shape of the input node in such cases + auto heads_num = 0; + auto query_merged_dim = query_ps[1]; + if (query_merged_dim.is_static()) { + heads_num = query_merged_dim.get_length() / head_size; + } else { + auto reshape_input = op->get_input_node_shared_ptr(0)->get_input_partial_shape(0); + heads_num = reshape_input[2].get_length(); + } prim.head_size = head_size; prim.kv_heads_num = kv_heads_num;