diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
index 7b5df754023a4a..cfc1e17c87ac6e 100644
--- a/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
+++ b/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
@@ -52,6 +52,15 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         PA_SDPA,
     };
 
+    bool requires_update(primitive_inst& inst, const kernel_impl_params& impl_params) const override {
+        const auto stage = get_paged_attention_stage(impl_params);
+
+        // In case of MIXED mode execution Paged Attention may require dispatch data update and internal
+        // buffers reallocation even if the input shapes haven't been changed. Therefore, check the current execution
+        // mode and update parameters if needed
+        return stage == PagedAttentionStage::MIXED;
+    }
+
     void load(BinaryInputBuffer& ib) override {
         parent::load(ib);
         if (is_dynamic()) {
@@ -90,7 +99,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         return layouts;
     }
 
-    kernel_arguments_data get_arguments(const paged_attention_inst& instance, size_t stage, size_t kernel_idx) const {
+    kernel_arguments_data get_arguments(const paged_attention_inst& instance, size_t stage, size_t kernel_idx, bool is_mixed_mode) const {
         const auto desc = instance.get_node().as<paged_attention>().get_primitive();
 
         kernel_arguments_data args;
@@ -129,7 +138,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
                               instance.block_indices_memory_ptr(),
                               instance.block_indices_begins_memory_ptr() };
 
-            if (kernel_idx == 1) {
+            if (is_mixed_mode) {
                 // Multi tokens kernel version has additional subsequence_begins_memory memory
                 // dependency
                 args.inputs.push_back(instance.subsequence_begins_memory_ptr());
@@ -140,6 +149,12 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
             }
         } else {
             args.inputs = { instance.past_lens_memory_ptr() };
+
+            if (is_mixed_mode) {
+                // Multi tokens kernel version has additional subsequence_begins_memory memory
+                // dependency
+                args.inputs.push_back(instance.subsequence_begins_memory_ptr());
+            }
         }
 
         args.outputs = { instance.output_memory_ptr(0) };
@@ -153,7 +168,11 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
                                 6, /* PA_SDPA multiple tokens mode */ };
     };
 
-    void execute_stage(const std::vector<event::ptr>& events, paged_attention_inst& instance, std::vector<event::ptr>& all_events, size_t stage) {
+    void execute_stage(const std::vector<event::ptr>& events,
+                       paged_attention_inst& instance,
+                       std::vector<event::ptr>& all_events,
+                       size_t stage,
+                       bool is_mixed_mode) {
         stream& stream = instance.get_network().get_stream();
         std::vector<event::ptr> tmp_events(events);
         size_t kernel_offset = 0;
@@ -181,7 +200,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
 
             auto& params = _kernels_data[stage].kernels[kd_idx].params;
 
-            auto args = get_arguments(instance, stage, kd_idx);
+            auto args = get_arguments(instance, stage, kd_idx, is_mixed_mode);
             args.scalars = &params.scalars;
 
             const auto& intermediate_memories = instance.get_intermediates_memories();
@@ -211,14 +230,15 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
     event::ptr execute_impl(const std::vector<event::ptr>& events, paged_attention_inst& instance) override {
         std::vector<event::ptr> res_events;
         const auto stage = get_paged_attention_stage(*instance.get_impl_params());
+        const auto is_mixed_mode = stage == PagedAttentionStage::MIXED;
 
-        execute_stage(events, instance, res_events, Stage::KV_CACHE_UPDATE);
+        execute_stage(events, instance, res_events, Stage::KV_CACHE_UPDATE, is_mixed_mode);
 
         std::vector<event::ptr> dep_events(res_events.begin(), res_events.end());
         if (stage == PagedAttentionStage::PREFILL) {
-            execute_stage(dep_events, instance, res_events, Stage::SDPA);
+            execute_stage(dep_events, instance, res_events, Stage::SDPA, is_mixed_mode);
         } else if (stage == PagedAttentionStage::GENERATE || stage == PagedAttentionStage::MIXED) {
-            execute_stage(dep_events, instance, res_events, Stage::PA_SDPA);
+            execute_stage(dep_events, instance, res_events, Stage::PA_SDPA, is_mixed_mode);
         }
 
         return instance.get_network().get_stream().aggregate_events(res_events, res_events.size() > 1);
    }
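Reviewer note: the new execute flow is easy to lose in the hunks above, so here is a minimal standalone C++ sketch of the dispatch logic. The types and the `run_paged_attention` helper are illustrative stand-ins, not the plugin's real declarations.

```cpp
#include <cstdio>

// Illustrative stand-in, not the plugin's real declaration.
enum class PagedAttentionStage { GENERATE, PREFILL, MIXED };

// Mirrors execute_impl() above: KV_CACHE_UPDATE always runs first; PREFILL-only
// batches take the regular SDPA kernel, everything else takes the PagedAttention
// SDPA kernel. Every stage is told whether the batch mixes prefill and generation
// subsequences, since that decides whether subsequence_begins is appended to the
// kernel argument list in get_arguments().
void run_paged_attention(PagedAttentionStage stage) {
    const bool is_mixed_mode = stage == PagedAttentionStage::MIXED;
    std::printf("KV_CACHE_UPDATE(mixed=%d)\n", is_mixed_mode ? 1 : 0);
    if (stage == PagedAttentionStage::PREFILL) {
        std::printf("SDPA(mixed=%d)\n", is_mixed_mode ? 1 : 0);
    } else {
        std::printf("PA_SDPA(mixed=%d)\n", is_mixed_mode ? 1 : 0);
    }
}

int main() {
    run_paged_attention(PagedAttentionStage::MIXED); // KV_CACHE_UPDATE + PA_SDPA, mixed=1
}
```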
@@ -248,9 +268,45 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
         mem_lock<int32_t, mem_lock_type::read> subsequence_begins_mem_lock(subsequence_begins_mem, *impl_param.strm);
 
         auto aligned_seq_len = 0;
-        for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
-            auto prompt_length = subsequence_begins_mem_lock[i + 1] - subsequence_begins_mem_lock[i];
-            aligned_seq_len += align_to(prompt_length, target_seq_len_block_size);
+        if (stage == PagedAttentionStage::MIXED) {
+            const auto past_lens_idx = 5;
+            const auto past_lens_mem = input_mem.at(past_lens_idx);
+            mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, *impl_param.strm);
+
+            for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
+                auto past_len = past_lens_mem_lock[i];
+                auto seq_length = subsequence_begins_mem_lock[i + 1] - subsequence_begins_mem_lock[i];
+
+                // Since in MIXED execution mode the present KV-cache can be appended to the past KV-cache at any offset inside a block,
+                // to ensure proper alignment and update_kv_cache kernel scheduling, we need to account for the number of unaligned tokens
+                // in the first block.
+                // For example, if we need to store values in the following slots:
+                //
+                // block0: |O|O|O|O|O|O|O|O|O|O|O|O|U|U|U|U|
+                // block1: |U|U|U|U|U|U|U|U|U|U|U|U|U|U|U|U|
+                // block2: |U|U|U|U|U|U|E|E|E|E|E|E|E|E|E|E|
+                // Where O - occupied slots, U - currently being updated slots, E - empty slots
+                //
+                // We need to schedule 3 update_kv_cache operations:
+                // - For ranges of block0: [12-15]
+                // - For ranges of block1: [0-15]
+                // - For ranges of block2: [0-5]
+                //
+                // Therefore, consider an additional increment of aligned_seq_len to properly process all the blocks
+
+                auto occupied_slots_num = past_len % target_seq_len_block_size;
+                if (past_len != 0 && seq_length + occupied_slots_num > target_seq_len_block_size) {
+                    aligned_seq_len += target_seq_len_block_size;
+                    seq_length -= target_seq_len_block_size - occupied_slots_num;
+                }
+
+                aligned_seq_len += align_to(seq_length, target_seq_len_block_size);
+            }
+        } else {
+            for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
+                auto prompt_length = subsequence_begins_mem_lock[i + 1] - subsequence_begins_mem_lock[i];
+                aligned_seq_len += align_to(prompt_length, target_seq_len_block_size);
+            }
         }
 
         return aligned_seq_len;
diff --git a/src/plugins/intel_gpu/src/graph/include/primitive_inst.h b/src/plugins/intel_gpu/src/graph/include/primitive_inst.h
index 6efb2c4c03644f..e4e00e75182bae 100644
--- a/src/plugins/intel_gpu/src/graph/include/primitive_inst.h
+++ b/src/plugins/intel_gpu/src/graph/include/primitive_inst.h
@@ -110,6 +110,11 @@ struct primitive_impl {
         OPENVINO_ASSERT(false, "[GPU] update() is not implemented for dynamic implementation ", _kernel_name);
     }
 
+    virtual bool requires_update(primitive_inst& inst, const kernel_impl_params& impl_params) const {
+        OPENVINO_ASSERT(_is_dynamic, "[GPU] requires_update() is called for static shape implementation ", _kernel_name);
+        return false;
+    }
+
     static kernel_impl_params static_canonicalize_shapes(const kernel_impl_params& impl_params);
 
     virtual kernel_impl_params canonicalize_shapes(const kernel_impl_params& impl_params) const {
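Reviewer note: the alignment rule above is the crux of the MIXED path. Below is a self-contained sketch of the same arithmetic, with plain vectors standing in for the locked GPU buffers and `align_to` assumed to round up to a multiple, as the kernel-selector helper does. It reproduces the three-block example from the comment.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Round value up to the nearest multiple of alignment (assumed semantics of
// the plugin's align_to helper).
int32_t align_to(int32_t value, int32_t alignment) {
    return (value + alignment - 1) / alignment * alignment;
}

// Standalone sketch of the MIXED-mode aligned_seq_len computation above.
int32_t aligned_seq_len_mixed(const std::vector<int32_t>& subsequence_begins,
                              const std::vector<int32_t>& past_lens,
                              int32_t block_size = 16) {
    int32_t aligned_seq_len = 0;
    for (size_t i = 0; i + 1 < subsequence_begins.size(); i++) {
        int32_t past_len = past_lens[i];
        int32_t seq_length = subsequence_begins[i + 1] - subsequence_begins[i];

        // A partially filled first block costs one extra block-sized dispatch.
        int32_t occupied_slots_num = past_len % block_size;
        if (past_len != 0 && seq_length + occupied_slots_num > block_size) {
            aligned_seq_len += block_size;
            seq_length -= block_size - occupied_slots_num;
        }
        aligned_seq_len += align_to(seq_length, block_size);
    }
    return aligned_seq_len;
}

int main() {
    // The example from the comment: past_len = 12 and 26 new tokens give
    // block0 [12-15], block1 [0-15], block2 [0-5], i.e. 3 block-sized dispatches.
    assert(aligned_seq_len_mixed({0, 26}, {12}) == 48);
}
```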
diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp
index 037a6a1e8b04aa..cce0613dda8226 100644
--- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp
+++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp
@@ -93,6 +93,7 @@ void paged_attention_inst::on_execute() {
     const auto blocks_indexes_end_idx = 1;
     const auto blocked_gws_subseq_mapping_idx = 2;
 
+    const auto past_lens_mem = past_lens_memory_ptr();
     auto subsequence_begins_mem = subsequence_begins_memory_ptr();
     auto blocks_indexes_start_mem = _intermediates_memory[blocks_indexes_start_idx];
     auto blocks_indexes_end_mem = _intermediates_memory[blocks_indexes_end_idx];
@@ -100,7 +101,8 @@ void paged_attention_inst::on_execute() {
 
     OPENVINO_ASSERT(subsequence_begins_mem->get_layout().data_type == data_types::i32);
 
-    auto& stream = get_network().get_stream();
+    auto& stream = get_network().get_program()->get_stream();
+    mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, stream);
     mem_lock<int32_t, mem_lock_type::read> subsequence_begins_mem_lock(subsequence_begins_mem, stream);
     mem_lock<int32_t, mem_lock_type::write> blocks_indexes_start_lock(blocks_indexes_start_mem, stream);
     mem_lock<int32_t, mem_lock_type::write> blocks_indexes_end_lock(blocks_indexes_end_mem, stream);
@@ -120,11 +122,28 @@ void paged_attention_inst::on_execute() {
     size_t index = 0;
     const auto target_seq_len_block_size = 16; // TODO: Get block size from the impl
     for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
+        const auto past_len = past_lens_mem_lock[i];
         const auto seq_start = subsequence_begins_mem_lock[i];
         const auto seq_end = subsequence_begins_mem_lock[i + 1];
         const auto seq_length = seq_end - seq_start;
 
-        for (int32_t j = 0; j < seq_length; j += target_seq_len_block_size) {
+        int32_t j = 0;
+        if (past_len != 0) {
+            auto block_start_pos = seq_start;
+            auto empty_slots = target_seq_len_block_size - (past_len % target_seq_len_block_size);
+            auto block_end_pos = seq_start + std::min(empty_slots, seq_length);
+
+            blocks_indexes_start_lock[index] = block_start_pos;
+            blocks_indexes_end_lock[index] = block_end_pos;
+            blocked_gws_subseq_mapping_mem_lock[index] = static_cast<int32_t>(i);
+
+            index++;
+
+            auto added_tokens = block_end_pos - block_start_pos;
+            j += added_tokens;
+        }
+
+        for (; j < seq_length; j += target_seq_len_block_size) {
             auto block_start_pos = subsequence_begins_mem_lock[i] + j;
             auto block_end_pos = std::min(block_start_pos + target_seq_len_block_size, seq_end);
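Reviewer note: the on_execute() loop above turns each subsequence into a list of [start, end) token ranges, one per update_kv_cache dispatch, with a shortened first range when the last cached block is only partially filled. A host-side sketch that reproduces the same splitting (the `make_ranges` helper and `BlockRange` type are illustrative, not plugin code):

```cpp
#include <algorithm>
#include <cstdint>
#include <cstdio>
#include <vector>

struct BlockRange { int32_t start, end, subseq; };

// Standalone sketch of the indexing above: one range for the partially filled
// leading block, then block-sized ranges for the remaining new tokens.
std::vector<BlockRange> make_ranges(const std::vector<int32_t>& subsequence_begins,
                                    const std::vector<int32_t>& past_lens,
                                    int32_t block_size = 16) {
    std::vector<BlockRange> ranges;
    for (size_t i = 0; i + 1 < subsequence_begins.size(); i++) {
        const int32_t past_len = past_lens[i];
        const int32_t seq_start = subsequence_begins[i];
        const int32_t seq_end = subsequence_begins[i + 1];
        const int32_t seq_length = seq_end - seq_start;

        int32_t j = 0;
        if (past_len != 0) {
            // Fill only the empty slots left in the last block of the past KV-cache.
            const int32_t empty_slots = block_size - (past_len % block_size);
            const int32_t block_end = seq_start + std::min(empty_slots, seq_length);
            ranges.push_back({seq_start, block_end, static_cast<int32_t>(i)});
            j += block_end - seq_start;
        }
        for (; j < seq_length; j += block_size) {
            const int32_t block_start = seq_start + j;
            ranges.push_back({block_start, std::min(block_start + block_size, seq_end),
                              static_cast<int32_t>(i)});
        }
    }
    return ranges;
}

int main() {
    // past_len = 12 with 26 new tokens: expect [0,4) [4,20) [20,26).
    for (const auto& r : make_ranges({0, 26}, {12}))
        std::printf("[%d, %d) -> subseq %d\n", r.start, r.end, r.subseq);
}
```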
diff --git a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp
index ad1541177b7dd6..fe2b11e5e36bf2 100644
--- a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp
+++ b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -1648,6 +1648,7 @@ event::ptr primitive_inst::execute(const std::vector<event::ptr>& events) {
         // Try update impl if current impl is dynamic because opt kernel may be added to impl cache through async compilation.
         // Only try update weight and realloc when impl is updated.
         const bool can_use_async_compilation = use_async_compilation();
+        bool is_updated = false;
 
         if (shape_changed() || !_impl || (!shape_changed() && _impl->is_dynamic() && can_use_async_compilation)) {
             if (update_impl(can_use_async_compilation)) {
                 need_args_update = true;
@@ -1657,9 +1658,22 @@ event::ptr primitive_inst::execute(const std::vector<event::ptr>& events) {
                 auto ev_reset = realloc_if_needed();
                 if (ev_reset)
                     dependencies.push_back(ev_reset);
+
+                is_updated = true;
             }
         }
 
+        // Paged Attention may require dispatch data update and internal buffers reallocation
+        // even if the input shapes haven't been changed
+        if (_node->is_type<paged_attention>() && !is_updated && _impl->requires_update(*this, *_impl_params)) {
+            _impl->update(*this, *_impl_params);
+
+            need_args_update = true;
+            auto ev_reset = realloc_if_needed();
+            if (ev_reset)
+                dependencies.push_back(ev_reset);
+        }
+
         OPENVINO_ASSERT(_impl_params->get_output_layout().is_static(),
                         "[GPU] Can't execute ", primitive_id, " primitive as output layout is dynamic in runtime");
diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_update_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_update_ref.cl
index d41e3fd5fc9a80..ef2f78496b2cf2 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_update_ref.cl
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_update_ref.cl
@@ -4,6 +4,8 @@
 
 #include "include/batch_headers/common.cl"
 
+REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
+__attribute__((reqd_work_group_size(1, 1, SUBGROUP_SIZE)))
 KERNEL(pa_kv_cache_update)(
     OPTIONAL_SHAPE_INFO_ARG
     __global const INPUT0_TYPE* key_data,
@@ -77,20 +79,25 @@ KERNEL(pa_kv_cache_update)(
         const uint block_start_pos = blocked_indexes_start[block_idx];
         const uint block_end_pos = blocked_indexes_end[block_idx];
         const uint tokens_num = block_end_pos - block_start_pos;
+        const uint past_len = past_lens[subsequence_idx];
+
+        const uint token_start_pos = (past_len + block_start_pos - subsequence_begin_idx) % PAGED_ATTENTION_BLOCK_SIZE;
 
         uint key_value_in_offset = block_start_pos * KV_HEADS_NUM * HEAD_SIZE + head_idx * HEAD_SIZE;
 
-        const uint cached_blocks_num = past_lens[subsequence_idx] / PAGED_ATTENTION_BLOCK_SIZE;
-        const uint current_block_idx = (block_start_pos - subsequence_begin_idx) / PAGED_ATTENTION_BLOCK_SIZE;
+        const uint current_block_idx = (past_len + block_start_pos - subsequence_begin_idx) / PAGED_ATTENTION_BLOCK_SIZE;
 
-        const uint block_offset = block_indices_begins[subsequence_idx] + cached_blocks_num + current_block_idx;
+        const uint block_offset = block_indices_begins[subsequence_idx] + current_block_idx;
 
         uint key_out_offset = block_indices[block_offset] * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
 
         uint value_out_offset = key_out_offset;
 
+        key_out_offset += token_start_pos;
+        value_out_offset += token_start_pos * HEAD_SIZE;
+
         if (tokens_num == PAGED_ATTENTION_BLOCK_SIZE) {
             unroll_for (uint token_num = 0; token_num < PAGED_ATTENTION_BLOCK_SIZE; token_num++) {
                 uint head_idx_index = 0;
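Reviewer note: the kernel change above boils down to addressing. The destination block and the in-block slot are now derived from past_len plus the token's offset inside the subsequence, instead of assuming writes start at a block boundary. A small host-side sketch of that indexing (the `locate_slot` helper and `SlotAddress` type are illustrative names):

```cpp
#include <cassert>
#include <cstdint>

struct SlotAddress { uint32_t block_idx; uint32_t slot_in_block; };

// Standalone sketch of the kernel's destination indexing: with a non-empty past,
// writes may start mid-block, so both the logical block index and the in-block
// slot (token_start_pos in the kernel) come from past_len plus the token offset.
SlotAddress locate_slot(uint32_t past_len, uint32_t block_start_pos,
                        uint32_t subsequence_begin_idx, uint32_t block_size = 16) {
    const uint32_t logical_pos = past_len + block_start_pos - subsequence_begin_idx;
    return {logical_pos / block_size,   // which entry of this subsequence's block table
            logical_pos % block_size};  // slot inside that block
}

int main() {
    // 12 cached tokens: the first new range lands in block 0 at slot 12,
    // the next range (starting 4 tokens later) opens block 1 at slot 0.
    assert(locate_slot(12, 0, 0).block_idx == 0 && locate_slot(12, 0, 0).slot_in_block == 12);
    assert(locate_slot(12, 4, 0).block_idx == 1 && locate_slot(12, 4, 0).slot_in_block == 0);
}
```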
diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
index 06e83c5adb3e6b..22b561e3d78661 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
@@ -215,7 +215,11 @@ KERNEL(pa_sdpa_opt)(
             // TODO: const uint global_data_idx = partition_idx * SEQ_LEN_PARTITION_SIZE + local_data_idx
             const uint global_data_idx = partition_idx * SEQ_LEN_PARTITION_SIZE + qk_idx * (SUBGROUPS_PER_WG * SUBGROUP_SIZE) + sgid * SUBGROUP_SIZE + sglid;
 
+#if SEQ_LEN_PARTITION_SIZE % (SUBGROUPS_PER_WG * SUBGROUP_SIZE) == 0
             if (global_data_idx < seq_len) {
+#else
+            if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE) {
+#endif
                 SOFTMAX_ACCUMULATOR_TYPE qk_new = native_exp(TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[local_data_idx]) - qk_max);
                 slm_qk_vals[local_data_idx] = TO_OUTPUT_TYPE(qk_new);
 
@@ -242,7 +246,11 @@ KERNEL(pa_sdpa_opt)(
             const uint local_data_idx = qk_idx * (SUBGROUPS_PER_WG * SUBGROUP_SIZE) + sgid * SUBGROUP_SIZE + sglid;
             const uint global_data_idx = partition_idx * SEQ_LEN_PARTITION_SIZE + qk_idx * (SUBGROUPS_PER_WG * SUBGROUP_SIZE) + sgid * SUBGROUP_SIZE + sglid;
 
+#if SEQ_LEN_PARTITION_SIZE % (SUBGROUPS_PER_WG * SUBGROUP_SIZE) == 0
             if (global_data_idx < seq_len) {
+#else
+            if (global_data_idx < seq_len && local_data_idx < SEQ_LEN_PARTITION_SIZE) {
+#endif
                 SOFTMAX_ACCUMULATOR_TYPE qk_new = TO_SOFTMAX_ACCUMULATOR_TYPE(slm_qk_vals[local_data_idx]) / exp_sum;
                 slm_qk_vals[local_data_idx] = TO_OUTPUT_TYPE(qk_new);
             }
@@ -351,17 +359,29 @@
 REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
 KERNEL(pa_sdpa_finalization_stage)(
     const __global INPUT3_TYPE* past_lens,
+#if MULTI_TOKENS_PROCESSING
+    const __global INPUT6_TYPE* subsequence_begins,
+#endif
     __global OUTPUT_TYPE* output,
     const __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
     const __global SOFTMAX_ACCUMULATOR_TYPE* max_logits,
     const __global OUTPUT_TYPE* tmp_out,
+#if MULTI_TOKENS_PROCESSING
+    const __global int* gws_subseq_mapping,
+#endif
     const uint total_partitions_num) {
     const uint seq_idx = get_global_id(0);
     const uint head_num_idx = get_global_id(1);
     const uint head_size_idx = get_global_id(2);
     const uint sglid = get_sub_group_local_id();
 
+#if MULTI_TOKENS_PROCESSING
+    const int subsequence_idx = gws_subseq_mapping[seq_idx];
+    const int subsequence_begin = subsequence_begins[subsequence_idx];
+    const uint seq_len = past_lens[subsequence_idx] + 1 + (seq_idx - subsequence_begin);
+#else
     const uint seq_len = past_lens[seq_idx] + 1;
+#endif
 
     const uint num_of_partitions = CEIL_DIV(seq_len, SEQ_LEN_PARTITION_SIZE);
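Reviewer note on the new `#if` guards: the modulo must bind to the whole per-iteration stride, so the condition is written as `SEQ_LEN_PARTITION_SIZE % (SUBGROUPS_PER_WG * SUBGROUP_SIZE) == 0`; without the parentheses, `%` would bind tighter than `*` and the guard would be wrong. Separately, in MULTI_TOKENS_PROCESSING mode every new token of a subsequence has its own effective sequence length, so finalization recomputes it per work item instead of reading `past_lens[seq_idx]` directly. A sketch of that computation (plain C++, illustrative helper name):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Standalone sketch of the MULTI_TOKENS_PROCESSING seq_len computation above:
// each work item handles one new token; its attention window is the past
// KV-cache plus itself plus the new tokens that precede it in the subsequence.
uint32_t finalization_seq_len(uint32_t seq_idx,
                              const std::vector<int32_t>& gws_subseq_mapping,
                              const std::vector<int32_t>& subsequence_begins,
                              const std::vector<int32_t>& past_lens) {
    const int32_t subsequence_idx = gws_subseq_mapping[seq_idx];
    const int32_t subsequence_begin = subsequence_begins[subsequence_idx];
    return past_lens[subsequence_idx] + 1 + (seq_idx - subsequence_begin);
}

int main() {
    // One subsequence with 12 cached tokens and 3 new ones: the new tokens
    // attend over 13, 14 and 15 positions respectively.
    const std::vector<int32_t> mapping = {0, 0, 0}, begins = {0, 3}, past = {12};
    assert(finalization_seq_len(0, mapping, begins, past) == 13);
    assert(finalization_seq_len(2, mapping, begins, past) == 15);
}
```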
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp
index 2f8ced75052c4f..161c37ab3d3bf7 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp
@@ -14,6 +14,7 @@ enum KernelsTypes {
     SINGLE_TOKEN = 0,
     MULTI_TOKENS,
     FINALIZATION,
+    FINALIZATION_MULTI_TOKENS,
     TOTAL_KERNELS_NUM
 };
 
@@ -32,6 +33,8 @@ static std::string GetKernelName(std::string base_name, KernelsTypes type) {
         kernel_name += "_multi_tokens_seq";
     } else if (type == KernelsTypes::FINALIZATION) {
         kernel_name += "_finalization";
+    } else if (type == KernelsTypes::FINALIZATION_MULTI_TOKENS) {
+        kernel_name += "_finalization_multi_tokens_seq";
     }
 
     return kernel_name;
@@ -45,7 +48,8 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
     const auto& params = static_cast<const pa_sdpa_params&>(p);
     const std::vector<KernelsTypes> kernels_type = { KernelsTypes::SINGLE_TOKEN,
                                                      KernelsTypes::MULTI_TOKENS,
-                                                     KernelsTypes::FINALIZATION };
+                                                     KernelsTypes::FINALIZATION,
+                                                     KernelsTypes::FINALIZATION_MULTI_TOKENS };
 
     KernelData kd = KernelData::Default<pa_sdpa_params>(params, kernels_type.size());
     kd.needs_sub_kernels_sync = true;
@@ -68,6 +72,9 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
         } else if (kernel_type == KernelsTypes::FINALIZATION) {
             // FINALIZATION kernel uses only the past_lens data input
             inputs_num = 1;
+        } else if (kernel_type == KernelsTypes::FINALIZATION_MULTI_TOKENS) {
+            // FINALIZATION_MULTI_TOKENS kernel uses past_lens data input and subsequence_begins
+            inputs_num = 2;
         }
 
         auto& kernel = kd.kernels[kd_kernels_idx++];
@@ -89,11 +96,13 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
         kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1});
         kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2});
 
-        if (kernel_type == KernelsTypes::MULTI_TOKENS) {
+        if (kernel_type == KernelsTypes::MULTI_TOKENS || kernel_type == KernelsTypes::FINALIZATION_MULTI_TOKENS) {
             // MULTIPLE_TOKENS kernels need additional information related to mapping
             // launched kernel instances to subsequence indexes
             kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 3});
-        } else if (kernel_type == KernelsTypes::FINALIZATION) {
+        }
+
+        if (kernel_type == KernelsTypes::FINALIZATION || kernel_type == KernelsTypes::FINALIZATION_MULTI_TOKENS) {
             kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0});
 
             // Remove unused shape_info argument at finalization stage
@@ -164,7 +173,7 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
         jit.AddConstant(MakeJitConstant("BROADCAST_GROUP_SIZE", config.group_size));
     }
 
-    auto sdpa_stage = kernel_idx == KernelsTypes::FINALIZATION ? 1 : 0;
+    auto sdpa_stage = kernel_idx == KernelsTypes::FINALIZATION || kernel_idx == KernelsTypes::FINALIZATION_MULTI_TOKENS ? 1 : 0;
     jit.AddConstant(MakeJitConstant("SDPA_STAGE_" + std::to_string(sdpa_stage), 1));
 
     if (config.has_scale_val)
@@ -173,7 +182,7 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
     if (params.conf.has_alibi_input)
         jit.AddConstant(MakeJitConstant("HAS_ALIBI", 1));
 
-    if (kernel_idx == KernelsTypes::MULTI_TOKENS)
+    if (kernel_idx == KernelsTypes::MULTI_TOKENS || kernel_idx == KernelsTypes::FINALIZATION_MULTI_TOKENS)
         jit.AddConstant(MakeJitConstant("MULTI_TOKENS_PROCESSING", 1));
 
     jit.Merge(MakeTypeJitConstants(softmax_acc_dt, "SOFTMAX_ACCUMULATOR"));
@@ -211,7 +220,7 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
     kd.update_dispatch_data_func = [](const Params& params, KernelData& kd) {
         const auto& prim_params = static_cast<const pa_sdpa_params&>(params);
 
-        const size_t expected_kernels_num = 3;
+        const size_t expected_kernels_num = 4;
         OPENVINO_ASSERT(kd.kernels.size() == expected_kernels_num, "[GPU] Invalid kernels size for update dispatch data func of SDPA kernel");
 
         auto dispatch_data1 = SetDefault(prim_params, KernelsTypes::SINGLE_TOKEN);
@@ -230,13 +239,19 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
         auto dispatch_data2 = SetDefault(prim_params, KernelsTypes::FINALIZATION);
         kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.global = dispatch_data2.gws;
         kd.kernels[KernelsTypes::FINALIZATION].params.workGroups.local = dispatch_data2.lws;
-        kd.kernels[KernelsTypes::FINALIZATION].skip_execution = num_of_partitions == 1;
+        kd.kernels[KernelsTypes::FINALIZATION].skip_execution = num_of_partitions == 1 || prim_params.multi_tokens_mode;
+
+        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.global = dispatch_data2.gws;
+        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.workGroups.local = dispatch_data2.lws;
+        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].skip_execution = num_of_partitions == 1 || !prim_params.multi_tokens_mode;
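Reviewer note: both finalization kernels share dispatch_data2, and their skip_execution flags make them mutually exclusive at runtime. A compact sketch of the gating (illustrative types; `multi_tokens_mode` stands for the flag set for MIXED batches):

```cpp
#include <cassert>
#include <cstddef>

// Standalone sketch of the skip_execution gating above: at most one of the two
// finalization kernels runs, and neither runs when a single partition already
// produced the final result.
struct FinalizationSkip { bool finalization; bool finalization_multi_tokens; };

FinalizationSkip skip_flags(size_t num_of_partitions, bool multi_tokens_mode) {
    return {num_of_partitions == 1 || multi_tokens_mode,
            num_of_partitions == 1 || !multi_tokens_mode};
}

int main() {
    assert(skip_flags(1, false).finalization);              // single partition: skip both
    assert(skip_flags(1, false).finalization_multi_tokens);
    assert(!skip_flags(4, false).finalization);             // plain finalization runs
    assert(!skip_flags(4, true).finalization_multi_tokens); // multi-tokens variant runs
}
```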
 
         ScalarDescriptor num_of_partitions_scalar;
         num_of_partitions_scalar.t = ScalarDescriptor::Types::UINT32;
         num_of_partitions_scalar.v.u32 = static_cast<uint32_t>(num_of_partitions);
         kd.kernels[KernelsTypes::FINALIZATION].params.scalars.resize(1);
         kd.kernels[KernelsTypes::FINALIZATION].params.scalars[0] = num_of_partitions_scalar;
+        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.scalars.resize(1);
+        kd.kernels[KernelsTypes::FINALIZATION_MULTI_TOKENS].params.scalars[0] = num_of_partitions_scalar;
 
         auto buf_dt_size = BytesPerElement(softmax_acc_dt);
         auto buf_elements_count = sequences_number * prim_params.conf.heads_num * num_of_partitions;
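Reviewer note: both finalization variants receive the same num_of_partitions scalar, and the exp_sums/max_logits intermediates are sized per (token, head, partition). A sketch of the sizing with hypothetical example values (the real values come from prim_params):

```cpp
#include <cstddef>
#include <cstdio>

// Standalone sketch of the intermediate buffer sizing above; all numbers here
// are hypothetical examples, not values taken from the plugin.
int main() {
    const size_t sequences_number = 3;   // tokens processed by the SDPA stage
    const size_t heads_num = 32;
    const size_t num_of_partitions = 4;  // CEIL_DIV(max seq_len, SEQ_LEN_PARTITION_SIZE)
    const size_t buf_dt_size = 4;        // float32 softmax accumulator

    const size_t buf_elements_count = sequences_number * heads_num * num_of_partitions;
    std::printf("per-buffer bytes: %zu\n", buf_elements_count * buf_dt_size);
}
```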