Skip to content

Commit

Permalink
[GPU] Fix sdpa_micro kernel query input remainder check and kernel se…
Browse files Browse the repository at this point in the history
…lection logic in case of scalar-value attention mask
  • Loading branch information
sshlyapn committed Nov 20, 2024
1 parent 1319c64 commit 041078f
Showing 1 changed file with 5 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,10 @@ bool SDPAKernelMicro::Validate(const Params& p) const {
if (params.conf.is_kv_compressed)
return false;

// Do not use sdpa_micro kernel with a scalar-value mask
if (params.inputs.size() > 3 && !params.inputs[3].is_dynamic() && params.inputs[3].LogicalSize() == 1)
return false;

return true;
}

Expand Down Expand Up @@ -391,7 +395,7 @@ JitConstants SDPAKernelMicro::GetJitConstants(const sdpa_params& params, const m
bool d_full = (head_size == d_max);
bool v_full = (head_size == tile_v);
bool k_full = !n_keys.is_dynamic && (n_keys.v % tile_k) == 0;
bool q_full = !n_queries.is_dynamic && (n_queries.v % tile_q) != 0;
bool q_full = !n_queries.is_dynamic && (n_queries.v % tile_q) == 0;

auto Q_num_heads_dim = get_num_heads(Q, params.input0_order);
auto K_num_heads_dim = get_num_heads(K, params.input1_order);
Expand Down

0 comments on commit 041078f

Please sign in to comment.