diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp index f87f608597a6bb..2638f2ad60cf26 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp @@ -24,6 +24,10 @@ struct paged_attention : public primitive_base { OPENVINO_ASSERT(inputs.size() == 13, "[GPU] Unexpected inputs number for PagedAttention primitive: ", inputs.size()); } + bool has_scores_output() const { + return num_outputs == 2; + } + bool operator==(const primitive& rhs) const override { return compare_common_params(rhs); } diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp index 9cf1a252564934..7ac7f87d9f5640 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp @@ -79,6 +79,29 @@ struct paged_attention_impl : multi_stage_primitive { } std::vector get_internal_buffer_layouts_impl() const override { + /* + * Internal buffers allocation owners and users: + * +-----------------+---------------------+----------------+ + * | | Allocates | Uses | + * +-----------------+---------------------+----------------+ + * | KV_CACHE_UPDATE | [0, 1, 2] | | + * +-----------------+---------------------+----------------+ + * | SDPA | | [0, 1, 2, (3)] | + * +-----------------+---------------------+----------------+ + * | PA_SDPA | [(3), 4, 5, 6, (7)] | | + * +-----------------+---------------------+----------------+ + * + * Description: + * 0, 1, 2 - Buffers used for proper blocks distribution for kv_cache_update and + * sdpa_opt (1st token calculation) block configuration over target_seq_len dimension. Filled + * in paged_attention_inst::on_execute() call. + * 3 - Optional buffer used for PA scores output calculation, stores intermediate + * softmax values by partitions. + * 4, 5, 6 - Used for 2nd+ PA calculation (for softmax exp_sums, max_logits, and intermediate output). + * 7 - Optional buffer used for mixed PA execution mode, maps gws idx to subsequence id. Filled + * in paged_attention_inst::on_execute() call. 
+ */ + auto add_internal_buffers = [](std::vector& layouts, const kernel_selector::KernelData& kd) { if (kd.internalBufferSizes.empty()) return; @@ -155,7 +178,7 @@ struct paged_attention_impl : multi_stage_primitive { if (desc->has_alibi) { args.inputs.push_back(instance.alibi_memory_ptr()); } - } else { + } else if (kernel_idx == 2 || kernel_idx == 3) { args.inputs = { instance.past_lens_memory_ptr() }; if (is_mixed_mode) { @@ -163,6 +186,9 @@ struct paged_attention_impl : multi_stage_primitive { // dependency args.inputs.push_back(instance.subsequence_begins_memory_ptr()); } + } else if (kernel_idx == 4) { + args.inputs = { instance.past_lens_memory_ptr(), + instance.subsequence_begins_memory_ptr() }; } args.outputs = { instance.output_memory_ptr(0) }; @@ -196,6 +222,13 @@ struct paged_attention_impl : multi_stage_primitive { internal_buffers_count = _kernels_data[Stage::PA_SDPA].internalBufferSizes.size(); } else { internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes.size(); + + if (stage == Stage::SDPA) { + const auto desc = instance.get_node().as().get_primitive(); + if (desc->has_scores_output()) { + internal_buffers_count++; // Add softmax intermediate output buffer for scores calculation + } + } } for (size_t kd_idx = 0; kd_idx < _kernels_data[stage].kernels.size(); ++kd_idx) { @@ -338,7 +371,7 @@ struct paged_attention_impl : multi_stage_primitive { return aligned_seq_len; } - static kernel_selector::sdpa_configuration get_sdpa_configuration(const kernel_impl_params& impl_param) { + static kernel_selector::sdpa_configuration get_sdpa_configuration(const kernel_impl_params& impl_param, bool is_dynamic = true) { kernel_selector::sdpa_configuration config; const auto desc = impl_param.typed_desc(); @@ -362,37 +395,49 @@ struct paged_attention_impl : multi_stage_primitive { config.group_size = desc->heads_num / desc->kv_heads_num; } + if (desc->has_scores_output() && !is_dynamic) { + // TODO: remove duplication with get_pa_sdpa_params + const auto& input_mem = impl_param.memory_deps; + const auto max_context_len = input_mem.at(12); + mem_lock max_context_len_mem_lock(max_context_len, *impl_param.strm); + config.paged_attention_max_len = max_context_len_mem_lock[0]; + + const auto& past_lens_layout = impl_param.get_input_layout(5); + config.paged_attention_sequences_num = past_lens_layout.get_partial_shape()[0].get_length(); + } + return config; } static kv_cache_update_kernel_params_t get_kv_cache_update_kernel_params(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, + const kernel_selector::MultiDataTensor& input_tensors, bool is_dynamic = false) { auto params = get_default_params(impl_param, is_dynamic); - const auto& key_layout = impl_param.get_input_layout(1); - const auto& value_layout = impl_param.get_input_layout(2); - const auto& key_cache_layout = impl_param.get_input_layout(3); - const auto& value_cache_layout = impl_param.get_input_layout(4); - const auto& past_lens_layout = impl_param.get_input_layout(5); - const auto& block_indices_layout = impl_param.get_input_layout(7); - const auto& block_indices_begins_layout = impl_param.get_input_layout(8); - const auto& subsequence_begins_layout = impl_param.get_input_layout(6); + const auto& key_tensor = input_tensors[1]; + const auto& value_tensor = input_tensors[2]; + const auto& key_cache_tensor = input_tensors[3]; + const auto& value_cache_tensor = input_tensors[4]; + const auto& past_lens_tensor = input_tensors[5]; + const auto& block_indices_tensor = input_tensors[7]; + 
const auto& block_indices_begins_tensor = input_tensors[8]; + const auto& subsequence_begins_tensor = input_tensors[6]; const auto inputs_number = 6; const auto outputs_number = 2; params.inputs.resize(inputs_number); params.outputs.resize(outputs_number); - params.inputs[0] = convert_data_tensor(key_layout); - params.inputs[1] = convert_data_tensor(value_layout); - params.inputs[2] = convert_data_tensor(past_lens_layout); - params.inputs[3] = convert_data_tensor(block_indices_layout); - params.inputs[4] = convert_data_tensor(block_indices_begins_layout); - params.inputs[5] = convert_data_tensor(subsequence_begins_layout); - params.outputs[0] = convert_data_tensor(key_cache_layout); - params.outputs[1] = convert_data_tensor(value_cache_layout); + params.inputs[0] = key_tensor; + params.inputs[1] = value_tensor; + params.inputs[2] = past_lens_tensor; + params.inputs[3] = block_indices_tensor; + params.inputs[4] = block_indices_begins_tensor; + params.inputs[5] = subsequence_begins_tensor; + params.outputs[0] = key_cache_tensor; + params.outputs[1] = value_cache_tensor; - params.conf = get_sdpa_configuration(impl_param); + params.conf = get_sdpa_configuration(impl_param, is_dynamic); params.is_prefill = stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED; @@ -418,18 +463,22 @@ struct paged_attention_impl : multi_stage_primitive { return params; } - static sdpa_kernel_params_t get_sdpa_kernel_params(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, bool is_dynamic = false) { + static sdpa_kernel_params_t get_sdpa_kernel_params(const kernel_impl_params& impl_param, + const PagedAttentionStage& stage, + const kernel_selector::MultiDataTensor& input_tensors, + bool is_dynamic = false) { const auto desc = impl_param.typed_desc(); auto params = get_default_params(impl_param, is_dynamic); - const auto& query_layout = impl_param.get_input_layout(0); - const auto& key_layout = impl_param.get_input_layout(1); - const auto& value_layout = impl_param.get_input_layout(2); - const auto& subsequence_begins_layout = impl_param.get_input_layout(6); - const auto& scale_layout = impl_param.get_input_layout(9); - const auto& alibi_layout = impl_param.get_input_layout(11); - const auto has_alibi = alibi_layout.count() > 0; + const auto& query_tensor = input_tensors[0]; + const auto& key_tensor = input_tensors[1]; + const auto& value_tensor = input_tensors[2]; + const auto& subsequence_begins_tensor = input_tensors[6]; + const auto& scale_tensor = input_tensors[9]; + const auto& alibi_tensor = input_tensors[11]; + const auto has_alibi = impl_param.get_input_layout(11).count() > 0; const auto has_scale_input = !desc->scale_val.has_value(); + const auto has_scores_output = desc->has_scores_output(); auto inputs_number = 4; if (has_scale_input) @@ -440,18 +489,29 @@ struct paged_attention_impl : multi_stage_primitive { auto input_idx = 0; params.inputs.resize(inputs_number); - params.inputs[input_idx++] = convert_data_tensor(query_layout); - params.inputs[input_idx++] = convert_data_tensor(key_layout); - params.inputs[input_idx++] = convert_data_tensor(value_layout); - params.inputs[input_idx++] = convert_data_tensor(subsequence_begins_layout); + params.inputs[input_idx++] = query_tensor; + params.inputs[input_idx++] = key_tensor; + params.inputs[input_idx++] = value_tensor; + params.inputs[input_idx++] = subsequence_begins_tensor; if (has_scale_input) - params.inputs[input_idx++] = convert_data_tensor(scale_layout); + params.inputs[input_idx++] = scale_tensor; if 
(has_alibi) - params.inputs[input_idx++] = convert_data_tensor(alibi_layout); + params.inputs[input_idx++] = alibi_tensor; - params.conf = get_sdpa_configuration(impl_param); + if (has_scores_output) { + params.outputs.resize(2); + params.outputs[1] = convert_data_tensor(impl_param.get_output_layout(1)); + + // const auto rotation_inputs_start_idx = 12; + // const auto rotation_inputs_num = 3; + // for (size_t i = 0; i < rotation_inputs_num; i++) { + // params.inputs[input_idx++] = convert_data_tensor(impl_param.get_input_layout(rotation_inputs_start_idx + i)); + // } + } + + params.conf = get_sdpa_configuration(impl_param, is_dynamic); const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset; const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset; @@ -475,26 +535,33 @@ struct paged_attention_impl : multi_stage_primitive { if ((stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED) && !is_dynamic) params.conf.paged_attention_aligned_seq_len = get_aligned_seq_len(impl_param, stage); + if (has_scores_output) + out_tensor_to_offset_map.insert({1, out_offsets_map.at(1)}); + params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map); return params; } - static pa_sdpa_kernel_params_t get_pa_sdpa_params(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, bool is_dynamic = false) { + static pa_sdpa_kernel_params_t get_pa_sdpa_params(const kernel_impl_params& impl_param, + const PagedAttentionStage& stage, + const kernel_selector::MultiDataTensor& input_tensors, + bool is_dynamic = false) { const auto desc = impl_param.typed_desc(); auto params = get_default_params(impl_param, is_dynamic); - const auto& query_layout = impl_param.get_input_layout(0); - const auto& key_cache_layout = impl_param.get_input_layout(3); - const auto& value_cache_layout = impl_param.get_input_layout(4); - const auto& past_lens_layout = impl_param.get_input_layout(5); - const auto& block_indices_layout = impl_param.get_input_layout(7); - const auto& block_indices_begins_layout = impl_param.get_input_layout(8); - const auto& subsequence_begins_layout = impl_param.get_input_layout(6); - const auto& scale_layout = impl_param.get_input_layout(9); - const auto& alibi_layout = impl_param.get_input_layout(11); - const auto has_alibi = alibi_layout.count() > 0; + const auto& query_tensor = input_tensors[0]; + const auto& key_cache_tensor = input_tensors[3]; + const auto& value_cache_tensor = input_tensors[4]; + const auto& past_lens_tensor = input_tensors[5]; + const auto& block_indices_tensor = input_tensors[7]; + const auto& block_indices_begins_tensor = input_tensors[8]; + const auto& subsequence_begins_tensor = input_tensors[6]; + const auto& scale_tensor = input_tensors[9]; + const auto& alibi_tensor = input_tensors[11]; + const auto has_alibi = impl_param.get_input_layout(11).count() > 0; const auto has_scale_input = !desc->scale_val.has_value(); + const auto has_scores_output = desc->has_scores_output(); auto inputs_number = 7; if (has_scale_input) @@ -505,28 +572,33 @@ struct paged_attention_impl : multi_stage_primitive { auto input_idx = 0; params.inputs.resize(inputs_number); - params.inputs[input_idx++] = convert_data_tensor(query_layout); - params.inputs[input_idx++] = convert_data_tensor(key_cache_layout); - params.inputs[input_idx++] = convert_data_tensor(value_cache_layout); - params.inputs[input_idx++] = convert_data_tensor(past_lens_layout); - params.inputs[input_idx++] = convert_data_tensor(block_indices_layout); - 
params.inputs[input_idx++] = convert_data_tensor(block_indices_begins_layout); - params.inputs[input_idx++] = convert_data_tensor(subsequence_begins_layout); + params.inputs[input_idx++] = query_tensor; + params.inputs[input_idx++] = key_cache_tensor; + params.inputs[input_idx++] = value_cache_tensor; + params.inputs[input_idx++] = past_lens_tensor; + params.inputs[input_idx++] = block_indices_tensor; + params.inputs[input_idx++] = block_indices_begins_tensor; + params.inputs[input_idx++] = subsequence_begins_tensor; params.conf = get_sdpa_configuration(impl_param); if (has_scale_input) - params.inputs[input_idx++] = convert_data_tensor(scale_layout); + params.inputs[input_idx++] = scale_tensor; if (has_alibi) - params.inputs[input_idx++] = convert_data_tensor(alibi_layout); + params.inputs[input_idx++] = alibi_tensor; + + if (has_scores_output) { + params.outputs.resize(2); + params.outputs[1] = convert_data_tensor(impl_param.get_output_layout(1)); + } - params.multi_tokens_mode = stage == PagedAttentionStage::MIXED; + params.stage = stage; - if ((stage == PagedAttentionStage::GENERATE || stage == PagedAttentionStage::MIXED) && !is_dynamic) { + if (!is_dynamic) { const auto& input_mem = impl_param.memory_deps; const auto max_context_len = input_mem.at(12); mem_lock max_context_len_mem_lock(max_context_len, *impl_param.strm); - params.max_context_len = max_context_len_mem_lock[0]; + params.conf.paged_attention_max_len = max_context_len_mem_lock[0]; } const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset; @@ -552,22 +624,32 @@ struct paged_attention_impl : multi_stage_primitive { if (has_alibi) in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(11)}); + if (has_scores_output) + out_tensor_to_offset_map.insert({1, out_offsets_map.at(1)}); + params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map); return params; } void update_dispatch_data(const kernel_impl_params& impl_param) override { + const auto& desc = impl_param.typed_desc(); const auto stage = get_paged_attention_stage(impl_param); - auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, impl_param.is_dynamic()); + kernel_selector::MultiDataTensor input_tensors; + for (const auto& input_layout : impl_param.input_layouts) + input_tensors.emplace_back(convert_data_tensor(input_layout)); + + auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic()); (_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func)(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]); if (stage == PagedAttentionStage::PREFILL) { - auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, impl_param.is_dynamic()); + auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic()); (_kernels_data[Stage::SDPA].update_dispatch_data_func)(sdpa_kernel_params, _kernels_data[Stage::SDPA]); - } else if (stage == PagedAttentionStage::GENERATE || stage == PagedAttentionStage::MIXED) { - auto pa_sdpa_kernel_params = get_pa_sdpa_params(impl_param, stage, impl_param.is_dynamic()); + } + + if (stage == PagedAttentionStage::GENERATE || stage == PagedAttentionStage::MIXED || desc->has_scores_output()) { + auto pa_sdpa_kernel_params = get_pa_sdpa_params(impl_param, stage, input_tensors, impl_param.is_dynamic()); (_kernels_data[Stage::PA_SDPA].update_dispatch_data_func)(pa_sdpa_kernel_params, _kernels_data[Stage::PA_SDPA]); } } @@ -576,15 +658,19 
@@ struct paged_attention_impl : multi_stage_primitive { std::vector kernels_data; const auto stage = PagedAttentionStage::UNKNOWN; - auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, impl_param.is_dynamic()); + kernel_selector::MultiDataTensor input_tensors; + for (const auto& input_layout : impl_param.input_layouts) + input_tensors.emplace_back(convert_data_tensor(input_layout)); + + auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic()); auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance(); kernels_data.push_back(kv_cache_update_kernel_selector.get_best_kernel(kv_cache_update_kernel_params)); - auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, impl_param.is_dynamic()); + auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic()); auto& sdpa_kernel_selector = sdpa_kernel_selector_t::Instance(); kernels_data.push_back(sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params)); - auto pa_sdpa_kernel_params = get_pa_sdpa_params(impl_param, stage, impl_param.is_dynamic()); + auto pa_sdpa_kernel_params = get_pa_sdpa_params(impl_param, stage, input_tensors, impl_param.is_dynamic()); auto& pa_sdpa_kernel_selector = pa_sdpa_kernel_selector_t::Instance(); kernels_data.push_back(pa_sdpa_kernel_selector.get_best_kernel(pa_sdpa_kernel_params)); diff --git a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h index a7918ba9c3719c..d71986c9278a9f 100644 --- a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h @@ -6,15 +6,18 @@ #include "intel_gpu/primitives/paged_attention.hpp" #include "primitive_inst.h" +#include "sdpa/pa_sdpa_kernel_opt.h" namespace cldnn { -enum PagedAttentionStage { - GENERATE = 0, - PREFILL = 1, - MIXED = 2, - UNKNOWN = 3 -}; +// enum PagedAttentionStage { +// GENERATE = 0, +// PREFILL = 1, +// MIXED = 2, +// UNKNOWN = 3 +// }; + +using PagedAttentionStage = kernel_selector::PagedAttentionStage; PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param); diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp index 787fd184f75b6a..24bbc568003b31 100644 --- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp @@ -48,14 +48,35 @@ layout paged_attention_inst::calc_output_layout(const paged_attention_node& /*no template std::vector paged_attention_inst::calc_output_layouts(paged_attention_node const& /*node*/, kernel_impl_params const& impl_param) { - auto out_layout = impl_param.get_input_layout(0); + const auto& desc = impl_param.typed_desc(); + auto data_layout = impl_param.get_input_layout(0); const auto& key_cache_ps = impl_param.get_input_layout(3).get_partial_shape(); bool valid_block_size = key_cache_ps[3].is_dynamic() || key_cache_ps[3].get_length() == paged_attention::block_size; OPENVINO_ASSERT(valid_block_size, "[GPU] Incorrect block size for Paged Attention operation. 
" "Expected ", paged_attention::block_size, ", but got ", key_cache_ps[3].get_length()); - return {out_layout}; + std::vector output_layouts{ data_layout }; + + if (desc->has_scores_output()) { + const auto past_lens_idx = 5; + const auto& memory_deps = impl_param.memory_deps; + const auto past_lens_mem = memory_deps.at(past_lens_idx); + mem_lock past_lens_mem_lock(past_lens_mem, *impl_param.strm); + + long int total_size = 0; + const auto past_lens_size = past_lens_mem_lock.size(); + for (size_t i = 0; i < past_lens_size; i++) { + total_size += past_lens_mem_lock[i]; + } + + auto scores_output = data_layout; + scores_output.set_partial_shape(ov::PartialShape{total_size}); + + output_layouts.push_back(scores_output); + } + + return output_layouts; } template std::vector @@ -110,7 +131,8 @@ void paged_attention_inst::on_execute() { std::unique_ptr> sequential_gws_subseq_mapping_lock = nullptr; if (stage == PagedAttentionStage::MIXED) { - const auto sequential_gws_subseq_mapping_idx = 6; + const auto& desc = _impl_params->typed_desc(); + const size_t sequential_gws_subseq_mapping_idx = desc->has_scores_output() ? 7 : 6; OPENVINO_ASSERT(_intermediates_memory.size() > sequential_gws_subseq_mapping_idx, "Unexpected number of intermediates buffers for Paged Attention for mixed stage"); diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl index 00c43829d02ea7..f99aae9e69f9f4 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl @@ -500,3 +500,72 @@ KERNEL(pa_sdpa_finalization_stage)( } #endif + +#ifdef SDPA_STAGE_2 + +REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE) +KERNEL(pa_sdpa_scores_calculation)( + const __global INPUT3_TYPE* past_lens, + const __global INPUT6_TYPE* subsequence_begins, + __global OUTPUT1_TYPE* scores_output, + const __global SOFTMAX_ACCUMULATOR_TYPE* softmax_output, + const __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums, + const __global SOFTMAX_ACCUMULATOR_TYPE* max_logits) { + const uint subsequence_idx = get_global_id(0); + const uint partition_global_idx = get_global_id(2); + const uint partition_idx = get_group_id(2); + const uint partition_size = get_group_size(2); + const uint max_seq_len = get_global_size(2); + const uint partitions_num = get_num_groups(2); + const uint sglid = get_sub_group_local_id(); + + const int subsequence_begin = subsequence_begins[subsequence_idx]; + const int subsequence_end = subsequence_begins[subsequence_idx + 1]; + const uint seq_len = (subsequence_end - subsequence_begin) + past_lens[subsequence_idx]; + + const uint num_of_partitions = CEIL_DIV(seq_len, partition_size); + + if (partition_idx >= num_of_partitions) + return; + + SOFTMAX_ACCUMULATOR_TYPE total_score = SOFTMAX_ACCUMULATOR_VAL_ZERO; + for (uint i = 0; i < HEAD_SIZE; i++) { + SOFTMAX_ACCUMULATOR_TYPE exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO; + SOFTMAX_ACCUMULATOR_TYPE max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN; + + const uint exp_sums_offset = subsequence_idx * HEAD_SIZE * partitions_num + i * partitions_num; + if (partition_global_idx < num_of_partitions) { + exp_sum = exp_sums[exp_sums_offset + partition_global_idx]; + max_logit = max_logits[exp_sums_offset + partition_global_idx]; + } + + SOFTMAX_ACCUMULATOR_TYPE global_max_logit = work_group_reduce_max(max_logit); + SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = exp_sum * native_exp(max_logit - global_max_logit); + SOFTMAX_ACCUMULATOR_TYPE 
current_exp_sum = work_group_broadcast(adjusted_exp_sum, partition_idx); + + SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = work_group_reduce_add(adjusted_exp_sum); + + SOFTMAX_ACCUMULATOR_TYPE softmax_value = SOFTMAX_ACCUMULATOR_VAL_ZERO; + if (partition_idx < num_of_partitions) { + const uint input_offset = subsequence_idx * HEAD_SIZE * max_seq_len + i * max_seq_len + partition_global_idx; + softmax_value = softmax_output[input_offset]; + } + + softmax_value = softmax_value * current_exp_sum / global_exp_sum; + total_score += softmax_value; + } + + // WA: need to pass additional input with offsets + uint total_seq_len = 0; + for (uint i = 0; i < subsequence_idx; i++) { + const int subsequence_begin = subsequence_begins[i]; + const int subsequence_end = subsequence_begins[i + 1]; + total_seq_len += (subsequence_end - subsequence_begin) + past_lens[i]; + } + + if (partition_global_idx < seq_len) { + scores_output[total_seq_len + partition_global_idx] = total_score; + } +} + +#endif diff --git a/src/plugins/intel_gpu/src/kernel_selector/common_types.h b/src/plugins/intel_gpu/src/kernel_selector/common_types.h index 06b3e04d40e829..eace4b0b22387a 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/common_types.h +++ b/src/plugins/intel_gpu/src/kernel_selector/common_types.h @@ -17,6 +17,7 @@ enum class KernelType { BEAM_TABLE_UPDATE, PA_KV_CACHE_UPDATE, PA_SDPA, + PA_SCORES_CALCULATION, CONVOLUTION, DECONVOLUTION, DFT, diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_scores_calculation_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_scores_calculation_ref.cpp new file mode 100644 index 00000000000000..a9f425fa515946 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_scores_calculation_ref.cpp @@ -0,0 +1,182 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "pa_scores_calculation_ref.h" +#include "sdpa_kernel_base.h" + +#include "kernel_selector_params.h" +#include "kernel_selector_utils.h" + +namespace kernel_selector { + +constexpr size_t subgroup_size = 16; +constexpr size_t paged_attention_block_size = 16; + +static size_t get_generate_stage_block_size(size_t head_size) { + auto preferred_block_size = { 4, 2, 1 }; + for (const auto& block_size : preferred_block_size) { + if (head_size % (block_size * subgroup_size) == 0) { + return block_size; + } + } + + return 1; +} + +KernelsData PAScoresCalculation::GetKernelsData(const Params& p) const { + if (!Validate(p)) { + return {}; + } + + KernelData kd = KernelData::Default(p); + kd.needs_sub_kernels_sync = false; + GetUpdateDispatchDataFunc(kd); + + const auto& params = static_cast<const pa_scores_calculation&>(p); + const auto dispatch_data = SetDefault(params); + const auto entry_point = GetEntryPoint(kernelName, params.layerID, p); + const auto jit_constants = GetJitConstants(params); + const auto jit = CreateJit(kernelName, jit_constants, entry_point); + + auto& kernel = kd.kernels[0]; + FillCLKernelData(kernel, + dispatch_data, + params.engineInfo, + kernelName, + jit, + entry_point, + {}, + false, + false, + static_cast<int>(params.inputs.size()), + GetFusedPrimitiveInputsCount(params), + static_cast<int>(params.outputs.size()), + params.is_shape_agnostic); + + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2}); + 
kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0}); + + ScalarDescriptor is_prefill_stage; + is_prefill_stage.t = ScalarDescriptor::Types::UINT32; + is_prefill_stage.v.u32 = static_cast(0); + kernel.params.scalars.push_back(is_prefill_stage); + + return {kd}; +} + +ParamsKey PAScoresCalculation::GetSupportedKey() const { + ParamsKey k; + + k.EnableInputDataType(Datatype::F16); + k.EnableInputDataType(Datatype::F32); + k.EnableInputDataType(Datatype::INT32); + + k.EnableOutputDataType(Datatype::F16); + k.EnableOutputDataType(Datatype::F32); + k.EnableOutputDataType(Datatype::INT32); + + k.EnableInputLayout(DataLayout::bfyx); + k.EnableOutputLayout(DataLayout::bfyx); + k.EnableOutputLayout(DataLayout::bfzyx); + + k.EnableDifferentTypes(); + k.EnableTensorOffset(); + k.EnableTensorPitches(); + k.EnableBatching(); + k.EnableDynamicShapesSupport(); + + return k; +} + +bool PAScoresCalculation::Validate(const Params& params) const { + if (params.GetType() != KernelType::PA_SCORES_CALCULATION) + return false; + + const auto& kernel_params = dynamic_cast(params); + if (kernel_params.inputs.size() != 6) + return false; + + if (kernel_params.outputs.size() != 2) + return false; + + if (!kernel_params.conf.is_paged_attention) + return false; + + if (kernel_params.conf.paged_attention_block_size != static_cast(paged_attention_block_size)) + return false; + + return true; +} + +JitConstants PAScoresCalculation::GetJitConstants(const pa_scores_calculation& params) const { + JitConstants jit = MakeBaseParamsJitConstants(params); + + jit.AddConstant(MakeJitConstant("HEAD_SIZE", params.conf.head_size)); + jit.AddConstant(MakeJitConstant("HEADS_NUM", params.conf.heads_num)); + jit.AddConstant(MakeJitConstant("KV_HEADS_NUM", params.conf.kv_heads_num)); + jit.AddConstant(MakeJitConstant("PAGED_ATTENTION_BLOCK_SIZE", paged_attention_block_size)); + jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size)); + jit.AddConstant(MakeJitConstant("GENERATE_STAGE_BLOCK_SIZE", get_generate_stage_block_size(params.conf.head_size))); + + return jit; +} + +CommonDispatchData PAScoresCalculation::SetDefault(const pa_scores_calculation& params) { + CommonDispatchData dispatch_data; + + const auto& key_cache = params.outputs[0]; + const auto& value_cache = params.outputs[1]; + if (!value_cache.is_dynamic() && !key_cache.is_dynamic()) { + auto heads_number = static_cast(params.conf.kv_heads_num); + + // if (is_prefill) { + // const auto blocks_number = params.conf.paged_attention_aligned_seq_len / paged_attention_block_size; + + // dispatch_data.gws = { blocks_number, + // heads_number, + // subgroup_size }; + // dispatch_data.lws = { 1, 1, subgroup_size }; + // } else { + // const auto& key_input = params.inputs[0]; + // const auto sequences_number = key_input.Batch().v; + + // dispatch_data.gws = { sequences_number, + // heads_number, + // subgroup_size }; + // dispatch_data.lws = { 1, 1, subgroup_size }; + // } + } + + return dispatch_data; +} + +void PAScoresCalculation::GetUpdateDispatchDataFunc(KernelData& kd) const { + kd.update_dispatch_data_func = [](const Params& params, KernelData& kd) { + const auto& prim_params = static_cast(params); + + auto dispatch_data = SetDefault(prim_params); + + OPENVINO_ASSERT(kd.kernels.size() == 1, "[GPU] Invalid kernels size for update dispatch data func"); + kd.kernels[0].params.workGroups.global = dispatch_data.gws; + kd.kernels[0].params.workGroups.local = dispatch_data.lws; + kd.kernels[0].skip_execution = false; + + const auto indexes_dt = 
Datatype::INT32; + const auto target_seq_len_block_size = 16; + const auto target_seq_len = prim_params.conf.paged_attention_aligned_seq_len; + const auto indexes_buf_size = CeilDiv(target_seq_len, target_seq_len_block_size) * BytesPerElement(indexes_dt); + + kd.internalBufferSizes.clear(); + kd.internalBufferSizes.push_back(indexes_buf_size); + kd.internalBufferSizes.push_back(indexes_buf_size); + kd.internalBufferSizes.push_back(indexes_buf_size); + kd.internalBufferDataType = indexes_dt; + + // kd.kernels[0].params.scalars[0].v.s32 = static_cast(prim_params.is_prefill); + }; +} + +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_scores_calculation_ref.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_scores_calculation_ref.h new file mode 100644 index 00000000000000..fe5e8432ff5f52 --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_scores_calculation_ref.h @@ -0,0 +1,32 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "kernel_base_opencl.h" +#include "sdpa_kernel_base.h" + +namespace kernel_selector { + +struct pa_scores_calculation : base_params { + pa_scores_calculation() : base_params(KernelType::PA_SCORES_CALCULATION) {} + + sdpa_configuration conf; +}; + +class PAScoresCalculation : public KernelBaseOpenCL { +public: + PAScoresCalculation() : KernelBaseOpenCL{"pa_scores_calc"} {} + KernelsData GetKernelsData(const Params& params) const override; + ParamsKey GetSupportedKey() const override; + virtual ~PAScoresCalculation() {} + +protected: + bool Validate(const Params& params) const override; + JitConstants GetJitConstants(const pa_scores_calculation& kernel_params) const; + static CommonDispatchData SetDefault(const pa_scores_calculation& kernel_params); + void GetUpdateDispatchDataFunc(KernelData& kd) const override; +}; + +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp index 63c5e74160f652..80ed1c1445e258 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp @@ -2,6 +2,7 @@ // SPDX-License-Identifier: Apache-2.0 // +#include "sdpa_kernel_opt.h" #include "pa_sdpa_kernel_opt.h" #include "kernel_selector_params.h" @@ -15,6 +16,7 @@ enum KernelsTypes { MULTI_TOKENS, FINALIZATION, FINALIZATION_MULTI_TOKENS, + SCORES_CALCULATION, TOTAL_KERNELS_NUM }; @@ -35,6 +37,8 @@ static std::string GetKernelName(std::string base_name, KernelsTypes type) { kernel_name += "_finalization"; } else if (type == KernelsTypes::FINALIZATION_MULTI_TOKENS) { kernel_name += "_finalization_multi_tokens_seq"; + } else if (type == KernelsTypes::SCORES_CALCULATION) { + kernel_name += "_scores_calculation"; } return kernel_name; @@ -46,10 +50,15 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const { } const auto& params = static_cast(p); - const std::vector kernels_type = { KernelsTypes::SINGLE_TOKEN, - KernelsTypes::MULTI_TOKENS, - KernelsTypes::FINALIZATION, - KernelsTypes::FINALIZATION_MULTI_TOKENS }; + std::vector kernels_type = { KernelsTypes::SINGLE_TOKEN, + KernelsTypes::MULTI_TOKENS, + KernelsTypes::FINALIZATION, + KernelsTypes::FINALIZATION_MULTI_TOKENS }; + + const auto has_scores_output = params.outputs.size() > 1; + if (has_scores_output) { + 
kernels_type.push_back(KernelsTypes::SCORES_CALCULATION); + } KernelData kd = KernelData::Default(params, kernels_type.size()); kd.needs_sub_kernels_sync = true; @@ -72,8 +81,8 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const { } else if (kernel_type == KernelsTypes::FINALIZATION) { // FINALIZATION kernel uses only the past_lens data input inputs_num = 1; - } else if (kernel_type == KernelsTypes::FINALIZATION_MULTI_TOKENS) { - // FINALIZATION_MULTI_TOKENS kernel uses past_lens data input and subsequence_begins + } else if (kernel_type == KernelsTypes::FINALIZATION_MULTI_TOKENS || kernel_type == KernelsTypes::SCORES_CALCULATION) { + // FINALIZATION_MULTI_TOKENS and SCORES_CALCULATION kernels use past_lens data input and subsequence_begins inputs_num = 2; } @@ -92,14 +101,23 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const { static_cast(params.outputs.size()), params.is_shape_agnostic); - kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0}); - kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1}); - kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2}); + uint32_t internal_buffers_num = 0; + if (has_scores_output) { + // Intermediate softmax results for PA scores output + internal_buffers_num++; + } + + // Softmax's exp_sums, max_logits and intermediate output + internal_buffers_num += 3; if (kernel_type == KernelsTypes::MULTI_TOKENS || kernel_type == KernelsTypes::FINALIZATION_MULTI_TOKENS) { // MULTIPLE_TOKENS kernels needs additional information related to mapping // launched kernel instances to subsequence indexes - kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 3}); + internal_buffers_num++; + } + + for (uint32_t i = 0; i < internal_buffers_num; i++) { + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, i}); } if (kernel_type == KernelsTypes::FINALIZATION || kernel_type == KernelsTypes::FINALIZATION_MULTI_TOKENS) { @@ -108,6 +126,13 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const { // Remove unused shape_info argument at finalization stage kernel.params.arguments.erase(kernel.params.arguments.begin()); } + + if (has_scores_output) { + // Add two scalars for partition_length and partitions_num values, + // as depending on PA stage these parameters may vary + kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0}); + kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 1}); + } } return {kd}; @@ -173,7 +198,12 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params& jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", config.group_size)); } - auto sdpa_stage = kernel_idx == KernelsTypes::FINALIZATION || kernel_idx == KernelsTypes::FINALIZATION_MULTI_TOKENS ? 
1 : 0; + auto sdpa_stage = 0; + if (kernel_idx == KernelsTypes::FINALIZATION || kernel_idx == KernelsTypes::FINALIZATION_MULTI_TOKENS) { + sdpa_stage = 1; + } else if (kernel_idx == KernelsTypes::SCORES_CALCULATION) { + sdpa_stage = 2; + } jit.AddConstant(MakeJitConstant("SDPA_STAGE_" + std::to_string(sdpa_stage), 1)); if (config.has_const_scale_val) { @@ -203,18 +233,35 @@ CommonDispatchData PagedAttentionSDPAKernelOpt::SetDefault(const pa_sdpa_params& const auto& input = params.inputs[0]; if (!input.is_dynamic()) { - const size_t sequences_number = input.Batch().v; - const size_t num_of_partitions = CeilDiv(params.max_context_len, seq_len_partition_size); + const size_t total_tokens = input.Batch().v; + const size_t num_of_partitions = CeilDiv(params.conf.paged_attention_max_len, seq_len_partition_size); const size_t heads_num = static_cast(params.conf.heads_num); const size_t head_size = static_cast(params.conf.head_size); - if (kernel_idx == 0) { - dispatch_data.gws = { sequences_number, + if (kernel_idx == KernelsTypes::SINGLE_TOKEN || kernel_idx == KernelsTypes::MULTI_TOKENS) { + dispatch_data.gws = { total_tokens, heads_num, head_size * num_of_partitions }; dispatch_data.lws = { 1, 1, head_size }; + } else if (kernel_idx == KernelsTypes::SCORES_CALCULATION) { + const auto& past_lens = params.inputs[3]; + const auto subsequences_number = past_lens.Batch().v; // number of past lens values + + size_t partition_size = 0; + size_t num_of_partitions = 0; + if (params.stage == PagedAttentionStage::PREFILL) { + partition_size = SDPAKernelOpt::get_seq_len_partition_size(params, params.conf.head_size, 1); + } else { + partition_size = seq_len_partition_size; + } + num_of_partitions = CeilDiv(params.conf.paged_attention_max_len, seq_len_partition_size); + + dispatch_data.gws = { subsequences_number, + 1, + partition_size * num_of_partitions }; + dispatch_data.lws = { 1, 1, partition_size }; } else { - dispatch_data.gws = { sequences_number, + dispatch_data.gws = { total_tokens, heads_num, head_size }; dispatch_data.lws = { 1, 1, subgroup_size }; @@ -228,30 +275,33 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons kd.update_dispatch_data_func = [](const Params& params, KernelData& kd) { const auto& prim_params = static_cast(params); - const size_t expected_kernels_num = 4; + const size_t expected_kernels_num = KernelsTypes::TOTAL_KERNELS_NUM; OPENVINO_ASSERT(kd.kernels.size() == expected_kernels_num, "[GPU] Invalid kernels size for update dispatch data func of SDPA kernel"); + const auto has_scores_output = prim_params.outputs.size() > 1; + const auto scores_calc_only = has_scores_output && prim_params.stage == PagedAttentionStage::PREFILL; + auto dispatch_data1 = SetDefault(prim_params, KernelsTypes::SINGLE_TOKEN); kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.global = dispatch_data1.gws; kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.local = dispatch_data1.lws; - kd.kernels[KernelsTypes::SINGLE_TOKEN].skip_execution = prim_params.multi_tokens_mode; + kd.kernels[KernelsTypes::SINGLE_TOKEN].skip_execution = prim_params.stage == PagedAttentionStage::MIXED && !scores_calc_only; kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.global = dispatch_data1.gws; kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.local = dispatch_data1.lws; - kd.kernels[KernelsTypes::MULTI_TOKENS].skip_execution = !prim_params.multi_tokens_mode; + kd.kernels[KernelsTypes::MULTI_TOKENS].skip_execution = prim_params.stage != PagedAttentionStage::MIXED && 
!scores_calc_only; const auto& input = prim_params.inputs[0]; - const size_t sequences_number = input.Batch().v; - const size_t num_of_partitions = CeilDiv(prim_params.max_context_len, seq_len_partition_size); + const size_t total_tokens = input.Batch().v; + const size_t num_of_partitions = CeilDiv(prim_params.conf.paged_attention_max_len, seq_len_partition_size); auto dispatch_data2 = SetDefault(prim_params, KernelsTypes::FINALIZATION); kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.global = dispatch_data2.gws; kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.local = dispatch_data2.lws; - kd.kernels[KernelsTypes::FINALIZATION].skip_execution = num_of_partitions == 1 || prim_params.multi_tokens_mode; + kd.kernels[KernelsTypes::FINALIZATION].skip_execution = (num_of_partitions == 1 || prim_params.stage == PagedAttentionStage::MIXED) && !scores_calc_only; kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.global = dispatch_data2.gws; kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.local = dispatch_data2.lws; - kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].skip_execution = num_of_partitions == 1 || !prim_params.multi_tokens_mode; + kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].skip_execution = (num_of_partitions == 1 || prim_params.stage != PagedAttentionStage::MIXED) && !scores_calc_only; ScalarDescriptor num_of_partitions_scalar; num_of_partitions_scalar.t = ScalarDescriptor::Types::UINT32; @@ -261,23 +311,50 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.scalars.resize(1); kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.scalars[0] = num_of_partitions_scalar; + if (has_scores_output) { + auto dispatch_data = SetDefault(prim_params, KernelsTypes::SCORES_CALCULATION); + kd.kernels[KernelsTypes::SCORES_CALCULATION].params.workGroups.global = dispatch_data.gws; + kd.kernels[KernelsTypes::SCORES_CALCULATION].params.workGroups.local = dispatch_data.lws; + kd.kernels[KernelsTypes::SCORES_CALCULATION].skip_execution = false; + } + auto buf_dt_size = BytesPerElement(softmax_acc_dt); - auto buf_elements_count = sequences_number * prim_params.conf.heads_num * num_of_partitions; + auto buf_elements_count = total_tokens * prim_params.conf.heads_num * num_of_partitions; auto buf_size = buf_elements_count * buf_dt_size; auto tmp_out_dt_size = BytesPerElement(softmax_acc_dt); - auto tmp_out_elements_count = sequences_number * prim_params.conf.heads_num * prim_params.conf.head_size * num_of_partitions; + auto tmp_out_elements_count = total_tokens * prim_params.conf.heads_num * prim_params.conf.head_size * num_of_partitions; auto tmp_out_size = tmp_out_elements_count * tmp_out_dt_size; kd.internalBufferSizes.clear(); - kd.internalBufferSizes.push_back(buf_size); - kd.internalBufferSizes.push_back(buf_size); - kd.internalBufferSizes.push_back(tmp_out_size); + + if (has_scores_output) { + const auto& past_lens = prim_params.inputs[3]; + auto softmax_buf_dt_size = BytesPerElement(softmax_acc_dt); + auto subsequences_number = past_lens.Batch().v; + auto softmax_buf_elements_count = subsequences_number * num_of_partitions * seq_len_partition_size; + auto softmax_buf_size = softmax_buf_elements_count * softmax_buf_dt_size; + + kd.internalBufferSizes.push_back(softmax_buf_size); // softmax intermediate output + + if (prim_params.stage == PagedAttentionStage::PREFILL) { + // Recalculate buf_size as in case of PREFILL stage it's not needed to 
allocate buffer per each input token + buf_elements_count = subsequences_number * prim_params.conf.heads_num * prim_params.conf.head_size * num_of_partitions; + buf_size = buf_elements_count * tmp_out_dt_size; + + // Intermediate tmp output buffer is not used for PREFILL stage + tmp_out_size = tmp_out_dt_size; + } + } + + kd.internalBufferSizes.push_back(buf_size); // softmax exp_sums + kd.internalBufferSizes.push_back(buf_size); // softmax max_logits + kd.internalBufferSizes.push_back(tmp_out_size); // intermediate output kd.internalBufferDataType = softmax_acc_dt; - if (prim_params.multi_tokens_mode) { + if (prim_params.stage == PagedAttentionStage::MIXED) { auto buf_dt_size = BytesPerElement(Datatype::INT32); - auto buf_elements_count = sequences_number; + auto buf_elements_count = total_tokens; auto buf_size = Align(buf_elements_count * buf_dt_size, BytesPerElement(softmax_acc_dt)); kd.internalBufferSizes.push_back(buf_size); } diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.h index a2456ccd9e2af5..a52571b03691df 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.h @@ -9,11 +9,17 @@ namespace kernel_selector { +enum PagedAttentionStage { + GENERATE = 0, + PREFILL = 1, + MIXED = 2, + UNKNOWN = 3 +}; + struct pa_sdpa_params : base_params { pa_sdpa_params() : base_params(KernelType::PA_SDPA) {} - bool multi_tokens_mode = false; - size_t max_context_len = 0; + PagedAttentionStage stage = PagedAttentionStage::UNKNOWN; sdpa_configuration conf; }; diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h index 5cd9c384ff2709..423c2106e084b5 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_base.h @@ -97,6 +97,8 @@ struct sdpa_configuration { bool is_paged_attention = false; int64_t paged_attention_aligned_seq_len = -1; int64_t paged_attention_block_size = 0; + int64_t paged_attention_sequences_num = 0; + int64_t paged_attention_max_len = 0; bool has_const_scale_val = false; float scale_val = 0.f; }; diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp index 4e71064efbc895..cb97b78ff02061 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.cpp @@ -10,49 +10,13 @@ namespace kernel_selector { -namespace { -enum KernelsTypes { - SINGLE_TOKEN = 0, - MULTI_TOKENS, - FINALIZATION, - TOTAL_KERNELS_NUM -}; - constexpr size_t subgroup_size = 16; -} // namespace - -static size_t get_sg_number_scale_factor(const sdpa_params& sdpa_params, size_t kernel_type) { - const size_t optimal_scale_factor = 2; - if (kernel_type == KernelsTypes::MULTI_TOKENS) { - if (sdpa_params.conf.head_size * optimal_scale_factor <= sdpa_params.engineInfo.maxWorkGroupSize) { - return optimal_scale_factor; - } - } else if (kernel_type == KernelsTypes::SINGLE_TOKEN) { - if (sdpa_params.conf.head_size * optimal_scale_factor <= sdpa_params.engineInfo.maxWorkGroupSize && - sdpa_params.conf.head_size * optimal_scale_factor / subgroup_size <= subgroup_size) { - return 
optimal_scale_factor; - } - } - - return 1; -} static size_t get_target_seq_len_block_size() { const size_t block_size = 16; return block_size; } -static size_t get_seq_len_partition_size(const sdpa_params& sdpa_params, size_t kernel_type) { - size_t seq_len = 0; - if (kernel_type == KernelsTypes::MULTI_TOKENS) { - seq_len = sdpa_params.conf.head_size * get_sg_number_scale_factor(sdpa_params, kernel_type); - } else { - seq_len = 256; - } - - return seq_len; -} - static Datatype get_softmax_acc_type() { return Datatype::F32; } @@ -65,13 +29,13 @@ static bool is_prefill_stage(const sdpa_params& sdpa_params) { } static size_t get_partitions_num(const sdpa_params& sdpa_params, size_t kernel_type) { - if (sdpa_params.has_dynamic_tensors() || kernel_type == KernelsTypes::MULTI_TOKENS) + if (sdpa_params.has_dynamic_tensors() || kernel_type == SDPAKernelOpt::KernelsTypes::MULTI_TOKENS) return 1; TransposedDimensionAccessHelperBase dims_k(sdpa_params.inputs[1], sdpa_params.input1_order); auto source_seq_len = dims_k.y_dim().v; - return CeilDiv(source_seq_len, get_seq_len_partition_size(sdpa_params, kernel_type)); + return CeilDiv(source_seq_len, SDPAKernelOpt::get_seq_len_partition_size(sdpa_params, kernel_type)); } static std::vector get_internal_buffer_sizes(const sdpa_params& sdpa_params, size_t kernel_type) { @@ -83,7 +47,7 @@ static std::vector get_internal_buffer_sizes(const sdpa_params& sdpa_par return {blocks_indexes_buf_size}; } else { - if (sdpa_params.has_dynamic_tensors() || kernel_type == KernelsTypes::MULTI_TOKENS) { + if (sdpa_params.has_dynamic_tensors() || kernel_type == SDPAKernelOpt::KernelsTypes::MULTI_TOKENS) { const auto default_bytes_count = BytesPerElement(get_softmax_acc_type()); return {default_bytes_count, default_bytes_count}; } else { @@ -107,7 +71,7 @@ static std::vector get_internal_buffer_sizes(const sdpa_params& sdpa_par } } -static std::string GetKernelName(std::string base_name, KernelsTypes type, const sdpa_params& params) { +static std::string GetKernelName(std::string base_name, SDPAKernelOpt::KernelsTypes type, const sdpa_params& params) { const bool is_indirect = params.indirect_axis != -1; const bool is_paged_attention = params.conf.is_paged_attention; @@ -119,17 +83,44 @@ static std::string GetKernelName(std::string base_name, KernelsTypes type, const if (is_indirect) kernel_name += "_ind"; - if (type == KernelsTypes::SINGLE_TOKEN) { + if (type == SDPAKernelOpt::KernelsTypes::SINGLE_TOKEN) { kernel_name += "_single_token"; - } else if (type == KernelsTypes::MULTI_TOKENS) { + } else if (type == SDPAKernelOpt::KernelsTypes::MULTI_TOKENS) { kernel_name += "_multi_tokens"; - } else if (type == KernelsTypes::FINALIZATION) { + } else if (type == SDPAKernelOpt::KernelsTypes::FINALIZATION) { kernel_name += "_finalization"; } return kernel_name; } +size_t SDPAKernelOpt::get_sg_number_scale_factor(const Params& params, size_t head_size, size_t kernel_type) { + const size_t optimal_scale_factor = 2; + if (kernel_type == KernelsTypes::MULTI_TOKENS) { + if (head_size * optimal_scale_factor <= params.engineInfo.maxWorkGroupSize) { + return optimal_scale_factor; + } + } else if (kernel_type == KernelsTypes::SINGLE_TOKEN) { + if (head_size * optimal_scale_factor <= params.engineInfo.maxWorkGroupSize && + head_size * optimal_scale_factor / subgroup_size <= subgroup_size) { + return optimal_scale_factor; + } + } + + return 1; +} + +size_t SDPAKernelOpt::get_seq_len_partition_size(const Params& params, size_t head_size, size_t kernel_type) { + size_t seq_len = 0; + if 
(kernel_type == KernelsTypes::MULTI_TOKENS) { + seq_len = head_size * get_sg_number_scale_factor(params, head_size, kernel_type); + } else { + seq_len = 256; + } + + return seq_len; +} + ParamsKey SDPAKernelOpt::GetSupportedKey() const { ParamsKey k; k.EnableInputDataType(Datatype::INT8); @@ -176,14 +167,14 @@ JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t ke const auto& config = params.conf; jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size)); jit.AddConstant(MakeJitConstant("HEAD_SIZE", config.head_size)); - jit.AddConstant(MakeJitConstant("SEQ_LEN_PARTITION_SIZE", get_seq_len_partition_size(params, kernel_idx))); + jit.AddConstant(MakeJitConstant("SEQ_LEN_PARTITION_SIZE", get_seq_len_partition_size(params, config.head_size, kernel_idx))); auto target_seq_len_block_size = kernel_idx == KernelsTypes::SINGLE_TOKEN ? 1 : get_target_seq_len_block_size(); jit.AddConstant(MakeJitConstant("TARGET_SEQ_LEN_BLOCK_SIZE", target_seq_len_block_size)); auto sdpa_stage = kernel_idx == KernelsTypes::FINALIZATION ? 1 : 0; jit.AddConstant(MakeJitConstant("SDPA_STAGE_" + std::to_string(sdpa_stage), 1)); - jit.AddConstant(MakeJitConstant("SG_SCALE_FACTOR", get_sg_number_scale_factor(params, kernel_idx))); + jit.AddConstant(MakeJitConstant("SG_SCALE_FACTOR", get_sg_number_scale_factor(params, config.head_size, kernel_idx))); if (params.conf.is_paged_attention) { if (params.conf.has_alibi_input) { @@ -218,8 +209,8 @@ CommonDispatchData SDPAKernelOpt::SetDefault(const sdpa_params& params, size_t k if (params.conf.is_paged_attention) { OPENVINO_ASSERT(kernel_idx == KernelsTypes::MULTI_TOKENS); - const size_t sg_num_scale = get_sg_number_scale_factor(params, kernel_idx); const size_t heads_num = static_cast(params.conf.heads_num); + const size_t sg_num_scale = get_sg_number_scale_factor(params, heads_num, kernel_idx); const size_t target_seq_len_block_size = get_target_seq_len_block_size(); const size_t target_seq_len = static_cast(params.conf.paged_attention_aligned_seq_len); const size_t head_size = static_cast(params.conf.head_size); @@ -243,13 +234,13 @@ CommonDispatchData SDPAKernelOpt::SetDefault(const sdpa_params& params, size_t k const size_t target_seq_len_block_size = kernel_idx == 1 ? 
get_target_seq_len_block_size() : 1; if (kernel_idx == KernelsTypes::SINGLE_TOKEN) { - const size_t sg_num_scale = get_sg_number_scale_factor(params, kernel_idx); + const size_t sg_num_scale = get_sg_number_scale_factor(params, head_size, kernel_idx); dispatch_data.gws = { batch_size * heads_num, CeilDiv(target_seq_len, target_seq_len_block_size), head_size * num_of_partitions * sg_num_scale }; dispatch_data.lws = { 1, 1, head_size * sg_num_scale }; } else if (kernel_idx == KernelsTypes::MULTI_TOKENS) { - const size_t sg_num_scale = get_sg_number_scale_factor(params, kernel_idx); + const size_t sg_num_scale = get_sg_number_scale_factor(params, head_size, kernel_idx); dispatch_data.gws = { batch_size * heads_num, CeilDiv(target_seq_len, target_seq_len_block_size), head_size * sg_num_scale }; @@ -339,6 +330,11 @@ KernelsData SDPAKernelOpt::GetKernelsData(const Params& params) const { kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1}); kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2}); + // Intermediate softmax results for PA scores output + if (prim_params.conf.is_paged_attention && prim_params.outputs.size() == 2) { + kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 3}); + } + const auto buf_sizes = get_internal_buffer_sizes(prim_params, kernel_idx); if (!prim_params.conf.is_paged_attention) { kd.internalBufferSizes.clear(); diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.h index 8d7279f5546112..db0d82ad16479a 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_opt.h @@ -9,6 +9,13 @@ namespace kernel_selector { class SDPAKernelOpt : public SDPAKernelBase { public: + enum KernelsTypes { + SINGLE_TOKEN = 0, + MULTI_TOKENS, + FINALIZATION, + TOTAL_KERNELS_NUM + }; + using Parent = SDPAKernelBase; SDPAKernelOpt() : SDPAKernelBase("sdpa_opt") {} virtual ~SDPAKernelOpt() {} @@ -17,6 +24,9 @@ class SDPAKernelOpt : public SDPAKernelBase { KernelsPriority GetKernelsPriority(const Params& params) const override; ParamsKey GetSupportedKey() const override; + static size_t get_sg_number_scale_factor(const Params& params, size_t head_size, size_t kernel_type); + static size_t get_seq_len_partition_size(const Params& params, size_t head_size, size_t kernel_type); + protected: bool Validate(const Params& p) const override; void GetUpdateDispatchDataFunc(KernelData& kd) const override; diff --git a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp index 7425b096b6d324..d82d3a66fed7f7 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp @@ -61,10 +61,13 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared OPENVINO_ASSERT(alibi_const != nullptr); prim.has_alibi = ov::shape_size(alibi_const->get_output_shape(0)) > 0; + prim.num_outputs = 1; if (op->get_output_size() > 1) { const auto scores_output_idx = 1; const auto& users = op->get_output_target_inputs(scores_output_idx); - OPENVINO_ASSERT(users.size() == 0, "[GPU] PagedAttention implementation doesn't support scores output yet"); + if (users.size() > 0) { + prim.num_outputs++; // Add scores output + } } p.add_primitive(*op, prim); diff --git 
a/src/plugins/intel_gpu/src/runtime/ocl/ocl_ext.hpp b/src/plugins/intel_gpu/src/runtime/ocl/ocl_ext.hpp index 759d796a5e87e8..0eb533b6999a47 100644 --- a/src/plugins/intel_gpu/src/runtime/ocl/ocl_ext.hpp +++ b/src/plugins/intel_gpu/src/runtime/ocl/ocl_ext.hpp @@ -990,8 +990,9 @@ class KernelIntel : public Kernel { return KernelIntel(cloned_kernel, _usmHelper); } - cl_int setArgUsm(cl_uint index, const UsmMemory& mem) { - return detail::errHandler(_usmHelper.set_kernel_arg_mem_pointer(*this, index, mem.get()), "[CL_EXT] setArgUsm in KernelIntel failed"); + cl_int setArgUsm(cl_uint index, const UsmMemory& mem, size_t offset = 0) { + return detail::errHandler(_usmHelper.set_kernel_arg_mem_pointer(*this, index, static_cast(mem.get()) + offset), + "[CL_EXT] setArgUsm in KernelIntel failed"); } private: const UsmHelper& _usmHelper; diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/mem_perf_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/mem_perf_test.cpp index c739ac9a0206ec..fe31d07d6aedbf 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/mem_perf_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/mem_perf_test.cpp @@ -348,7 +348,7 @@ TEST(mem_perf_test_to_device, DISABLED_usm_device) { validate_result(static_cast(output_buffer_host.get()), img_size * img_size); } -TEST(mem_perf_test_to_device, DISABLED_usm_device_copy) { +TEST(mem_perf_test_to_device, usm_device_copy) { auto ocl_instance = std::make_shared(); auto& ctx = ocl_instance->_context; auto& device = ocl_instance->_device; @@ -359,16 +359,18 @@ TEST(mem_perf_test_to_device, DISABLED_usm_device_copy) { std::cout << "Time of copying data from host buffer cl::UsmMemory (UsmHost type) to cl::UsmMemory (UsmDevice type)" << std::endl; + auto extra_mem = 1024 * 64; + cl::Program program(ctx, kernel_code); checkStatus(program.build({device}, ""), "build"); cl::UsmMemory input_buffer_host(usm_helper); - input_buffer_host.allocateHost(sizeof(uint8_t) * img_size * img_size); + input_buffer_host.allocateHost(sizeof(uint8_t) * img_size * img_size + extra_mem); cl::UsmMemory input_buffer_device(usm_helper); - input_buffer_device.allocateDevice(sizeof(uint8_t) * img_size * img_size); + input_buffer_device.allocateDevice(sizeof(uint8_t) * img_size * img_size + extra_mem); cl::UsmMemory output_buffer(usm_helper); - output_buffer.allocateDevice(sizeof(float) * img_size * img_size); + output_buffer.allocateDevice(sizeof(float) * img_size * img_size + extra_mem); cl::UsmMemory output_buffer_host(usm_helper); - output_buffer_host.allocateHost(sizeof(float) * img_size * img_size); + output_buffer_host.allocateHost(sizeof(float) * img_size * img_size + extra_mem); cl::Kernel kernel1(program, "simple_reorder"); cl::KernelIntel kernel(kernel1, usm_helper); @@ -385,8 +387,8 @@ TEST(mem_perf_test_to_device, DISABLED_usm_device_copy) { false, nullptr, ©_ev); - kernel.setArgUsm(0, input_buffer_device); - kernel.setArgUsm(1, output_buffer); + kernel.setArgUsm(0, input_buffer_device, extra_mem); + kernel.setArgUsm(1, output_buffer, extra_mem); cl::Event ev; std::vector dep_ev = {copy_ev}; queue.enqueueNDRangeKernel(kernel, cl::NDRange(), cl::NDRange(img_size*img_size), cl::NDRange(16), &dep_ev, &ev);
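
For reference (not part of the diff above), here is a minimal host-side C++ sketch of the cross-partition softmax aggregation that the new pa_sdpa_scores_calculation stage performs: per head, the per-partition exp_sums are rescaled to a common maximum, summed into a global denominator, and the stored per-partition softmax partials are renormalized and accumulated into one score per token. The function name aggregate_scores, the flat buffer layouts, and the float types are illustrative assumptions only.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// softmax_partials: [heads][partitions * partition_size] partial softmax values (normalized per partition)
// exp_sums / max_logits: [heads][partitions] per-partition softmax denominators and maxima
std::vector<float> aggregate_scores(const std::vector<float>& softmax_partials,
                                    const std::vector<float>& exp_sums,
                                    const std::vector<float>& max_logits,
                                    std::size_t heads_num,
                                    std::size_t partitions_num,
                                    std::size_t partition_size,
                                    std::size_t seq_len) {
    std::vector<float> scores(seq_len, 0.0f);
    for (std::size_t h = 0; h < heads_num; ++h) {
        // Global maximum over all partitions of this head, for numerically stable rescaling.
        float global_max = -std::numeric_limits<float>::infinity();
        for (std::size_t p = 0; p < partitions_num; ++p)
            global_max = std::max(global_max, max_logits[h * partitions_num + p]);

        // Rescale each partition's exp_sum to the common maximum and build the global denominator.
        std::vector<float> adjusted(partitions_num);
        float global_exp_sum = 0.0f;
        for (std::size_t p = 0; p < partitions_num; ++p) {
            adjusted[p] = exp_sums[h * partitions_num + p] * std::exp(max_logits[h * partitions_num + p] - global_max);
            global_exp_sum += adjusted[p];
        }

        // Renormalize the per-partition partial softmax values and accumulate them over heads.
        for (std::size_t p = 0; p < partitions_num; ++p) {
            for (std::size_t t = 0; t < partition_size; ++t) {
                const std::size_t token = p * partition_size + t;
                if (token >= seq_len)
                    break;
                const float partial = softmax_partials[h * partitions_num * partition_size + token];
                scores[token] += partial * adjusted[p] / global_exp_sum;
            }
        }
    }
    return scores;
}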