Skip to content

Commit

Permalink
[GPU] Fix accuracy issue of the qwen-2 model caused by a high data ra…
Browse files Browse the repository at this point in the history
…nge of query input by applying scale at the beginning. Also, fix PA operation creation in cases where query input has a dynamic input dimension
  • Loading branch information
sshlyapn committed Sep 3, 2024
1 parent 9db9a78 commit 756299b
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 6 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,14 @@ KERNEL(pa_sdpa_opt)(
head_num_idx * HEAD_SIZE +
query_idx_local;

slm_query[query_idx_local] = BLOCK_READN(INPUT0_TYPE, 1, query, query_idx);
INPUT0_TYPE q_val = BLOCK_READN(INPUT0_TYPE, 1, query, query_idx);

// Apply scale value directly to the query input to improve accuracy in case of a high range of input data
#ifdef SCALE_VAL
q_val = TO_INPUT0_TYPE(SCALE_VAL) * q_val;
#endif

slm_query[query_idx_local] = q_val;

barrier(CLK_LOCAL_MEM_FENCE);
#else
Expand All @@ -122,6 +129,11 @@ KERNEL(pa_sdpa_opt)(
head_num_idx * HEAD_SIZE +
i * SUBGROUP_SIZE;
q_val[i] = BLOCK_READN(INPUT0_TYPE, 1, query, query_idx);

// Apply scale value directly to the query input to improve accuracy in case of a high range of input data
#ifdef SCALE_VAL
q_val[i] = TO_INPUT0_TYPE(SCALE_VAL) * q_val[i];
#endif
}
#endif

Expand Down Expand Up @@ -162,10 +174,6 @@ KERNEL(pa_sdpa_opt)(
}
}

#ifdef SCALE_VAL
qk_acc = TO_INPUT0_TYPE(SCALE_VAL) * qk_acc;
#endif

const uint token_idx = partition_idx * SEQ_LEN_PARTITION_SIZE + block_num * SUBGROUPS_PER_WG * SUBGROUP_SIZE + sgid * SUBGROUP_SIZE + sglid;

#ifdef HAS_ALIBI
Expand Down
12 changes: 11 additions & 1 deletion src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,17 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared
auto query_ps = op->get_input_partial_shape(0);
auto head_size = key_cache_ps[2].get_length();
auto kv_heads_num = key_cache_ps[1].get_length();
auto heads_num = query_ps[1].get_length() / head_size;

// WA: in some cases, the query input may have a bounded dimension
// Use input shape of the input node in such cases
auto heads_num = 0;
auto query_merged_dim = query_ps[1];
if (query_merged_dim.is_static()) {
heads_num = query_merged_dim.get_length() / head_size;
} else {
auto reshape_input = op->get_input_node_shared_ptr(0)->get_input_partial_shape(0);
heads_num = reshape_input[2].get_length();
}

prim.head_size = head_size;
prim.kv_heads_num = kv_heads_num;
Expand Down

0 comments on commit 756299b

Please sign in to comment.