From bae45a2790ddd5bfbb3f0c3cb3cc73190bf5638f Mon Sep 17 00:00:00 2001
From: Sergey Shlyapnikov
Date: Tue, 27 Aug 2024 20:57:07 +0400
Subject: [PATCH] [GPU] PagedAttention prefix support

---
 .../src/graph/impls/ocl/paged_attention.cpp   | 170 +++++++++++-------
 .../src/graph/include/paged_attention_inst.h  |  13 +-
 .../intel_gpu/src/graph/paged_attention.cpp   |  91 +++++++++-
 .../cl_kernels/pa_kv_cache_update_ref.cl      |  20 ++-
 .../kernel_selector/cl_kernels/pa_sdpa_opt.cl |  30 +++-
 .../sdpa/pa_kv_cache_update_kernel_ref.cpp    |  17 +-
 .../sdpa/pa_kv_cache_update_kernel_ref.h      |   1 +
 .../kernels/sdpa/pa_sdpa_kernel_opt.cpp       |  37 +++-
 .../kernels/sdpa/pa_sdpa_kernel_opt.h         |   3 +-
 .../src/plugin/sync_infer_request.cpp         |  31 ++++
 10 files changed, 319 insertions(+), 94 deletions(-)

diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
index e7a64185969a2d..fae7a7525b9953 100644
--- a/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
+++ b/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
@@ -80,6 +80,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
             layout inbuf_layout = {dtype, format::bfyx, // simple linear format (flattern to x channel)
                                    {1, 1, 1, (tensor::value_type)(size / bpp)}};
             layouts.push_back(inbuf_layout);
+            GPU_DEBUG_TRACE_DETAIL << "add layout with size " << size << "\n";
         }
     };
 
@@ -94,7 +95,7 @@
         const auto desc = instance.get_node().as<paged_attention>().get_primitive();
 
         kernel_arguments_data args;
-        if (stage == Stage::KV_CACHE_UPDATE || stage == Stage::SDPA || (stage == Stage::PA_SDPA && kernel_idx == 0))
+        if (stage == Stage::KV_CACHE_UPDATE || stage == Stage::SDPA)
            args.shape_info = instance.shape_info_memory_ptr();
 
         if (stage == Stage::KV_CACHE_UPDATE) {
@@ -102,7 +103,8 @@
                             instance.value_memory_ptr(),
                             instance.past_lens_memory_ptr(),
                             instance.block_indices_memory_ptr(),
-                            instance.block_indices_begins_memory_ptr() };
+                            instance.block_indices_begins_memory_ptr(),
+                            instance.subsequence_begins_memory_ptr() };
 
             args.outputs = { instance.key_cache_memory_ptr(),
                              instance.value_cache_memory_ptr() };
@@ -118,7 +120,9 @@
 
             args.outputs = { instance.output_memory_ptr(0) };
         } else if (stage == Stage::PA_SDPA) {
-            if (kernel_idx == 0) {
+            if (kernel_idx == 0 || kernel_idx == 1) {
+                args.shape_info = instance.shape_info_memory_ptr();
+
                 args.inputs = { instance.input_memory_ptr(0),
                                 instance.key_cache_memory_ptr(),
                                 instance.value_cache_memory_ptr(),
@@ -126,11 +130,17 @@
                                 instance.block_indices_memory_ptr(),
                                 instance.block_indices_begins_memory_ptr() };
 
+                if (kernel_idx == 1) {
+                    // The multi-token kernel variant has an additional
+                    // subsequence_begins memory dependency
+                    args.inputs.push_back(instance.subsequence_begins_memory_ptr());
+                }
+
                 if (desc->has_alibi) {
                     args.inputs.push_back(instance.alibi_memory_ptr());
                 }
             } else {
-                args.inputs = { instance.past_lens_memory_ptr(), };
+                args.inputs = { instance.past_lens_memory_ptr() };
             }
 
             args.outputs = { instance.output_memory_ptr(0) };
@@ -140,7 +150,8 @@
     }
 
     std::set<size_t> get_lockable_internal_buffers() const override {
-        return std::set<size_t>{ 0, 1, 2 }; /* SDPA and KV_CACHE_UPDATE indexes configuration */
+        return std::set<size_t>{ 0, 1, 2, /* SDPA and KV_CACHE_UPDATE indexes configuration */
+                                 6, /* PA_SDPA multiple tokens mode */ };
     };
 
     void execute_stage(const std::vector<event::ptr>& events, paged_attention_inst& instance, std::vector<event::ptr>& all_events, size_t stage) {
@@ -199,19 +210,70 @@
     event::ptr execute_impl(const std::vector<event::ptr>& events, paged_attention_inst& instance) override {
         std::vector<event::ptr> res_events;
+        const auto stage = get_paged_attention_stage(*instance.get_impl_params());
 
         execute_stage(events, instance, res_events, Stage::KV_CACHE_UPDATE);
 
         std::vector<event::ptr> dep_events(res_events.begin(), res_events.end());
-        if (is_prefill_stage(*instance.get_impl_params())) {
+        if (stage == PagedAttentionStage::PREFILL) {
             execute_stage(dep_events, instance, res_events, Stage::SDPA);
-        } else {
+        } else if (stage == PagedAttentionStage::GENERATION || stage == PagedAttentionStage::MIXED) {
             execute_stage(dep_events, instance, res_events, Stage::PA_SDPA);
         }
 
         return aggregate_events(res_events, instance.get_network().get_stream(), res_events.size() > 1);
     }
 
+    static int64_t get_aligned_seq_len(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, int64_t target_seq_len_block_size = 16) {
+        // Since at prefill stage Q, K, V inputs may contain multiple sequences with arbitrary
+        // target sequence lengths each (shape is [sequences_num * target_seq_len, num_heads * head_size]),
+        // to apply blocking to the first dimension (target_seq_len of each sequence), we need to calculate aligned total
+        // target sequence length for proper kernel dispatching
+        // For instance, if input contains two sequences with 35 and 28 sequence lengths each,
+        // the Q, K, V inputs at prefill stage will have shapes [35 + 28, num_heads * head_size]; considering kernel's
+        // target_seq_len_block_size equals 16, we need to launch kernel instances for the following ranges:
+        // [0, 15], [16, 31], [32, 34], [35, 50], [51, 62], so aligned target_seq_len_block_size should be 5 * 16 = 80,
+        // and 5 kernel instances should be launched (for each range, some of them containing leftovers)
+        //
+        // In general, to obtain length for each sequence, we have to parse subsequence_begins input,
+        // which contains begin and end indexes for each sequence (for above example it will contain three values: {0, 35, 63})
+        // However, as long as kernel's target_seq_len_block_size matches with vLLM's block_size,
+        // we can reuse block_indices_shape[0] size to determine total aligned sequences length size, avoiding
+        // memory access at runtime, because vLLM internally uses similar logic to configure blocks for KV cache
+
+        auto calculate_aligned_seq_len = [&]() {
+            const auto& input_mem = impl_param.memory_deps;
+            const auto subsequence_begins_input_idx = 6;
+            const auto subsequence_begins_mem = input_mem.at(subsequence_begins_input_idx);
+            mem_lock<int32_t, mem_lock_type::read> subsequence_begins_mem_lock(subsequence_begins_mem, *impl_param.strm);
+
+            int64_t aligned_seq_len = 0;
+            for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
+                auto prompt_length = subsequence_begins_mem_lock[i + 1] - subsequence_begins_mem_lock[i];
+                aligned_seq_len += align_to(prompt_length, target_seq_len_block_size);
+            }
+
+            return aligned_seq_len;
+        };
+
+        int64_t aligned_seq_len = 0;
+        if (stage == PagedAttentionStage::PREFILL) {
+            const auto desc = impl_param.typed_desc<paged_attention>();
+            if (static_cast<int64_t>(paged_attention::block_size) == target_seq_len_block_size) {
+                const auto block_indices_input_idx = 7;
+                const auto& block_indices_ps = impl_param.get_input_layout(block_indices_input_idx).get_partial_shape();
+
+                aligned_seq_len = block_indices_ps[0].get_length() * target_seq_len_block_size;
+            } else {
+                aligned_seq_len = calculate_aligned_seq_len();
+            }
+        } else {
+            aligned_seq_len = calculate_aligned_seq_len();
+        }
+
+        return aligned_seq_len;
+    }
+
     static kernel_selector::sdpa_configuration get_sdpa_configuration(const kernel_impl_params& impl_param) {
         kernel_selector::sdpa_configuration config;
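The alignment rule implemented by get_aligned_seq_len() above can be restated as a standalone helper. The following sketch is illustrative only (the helper name is hypothetical and plain vectors stand in for GPU memory locks); for the worked example in the comment, subsequence_begins = {0, 35, 63} yields align_to(35, 16) + align_to(28, 16) = 48 + 32 = 80.

    #include <cstdint>
    #include <vector>

    // Sum of per-sequence prompt lengths, each rounded up to the dispatch
    // block size (mirrors the calculate_aligned_seq_len() lambda above).
    int64_t aligned_total_seq_len(const std::vector<int32_t>& subsequence_begins,
                                  int64_t block_size = 16) {
        int64_t aligned_seq_len = 0;
        for (size_t i = 0; i + 1 < subsequence_begins.size(); ++i) {
            const int64_t prompt_length = subsequence_begins[i + 1] - subsequence_begins[i];
            aligned_seq_len += (prompt_length + block_size - 1) / block_size * block_size;
        }
        return aligned_seq_len;  // {0, 35, 63} -> 48 + 32 = 80
    }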
@@ -237,7 +299,9 @@
         return config;
     }
 
-    static kv_cache_update_kernel_params_t get_kv_cache_update_kernel_params(const kernel_impl_params& impl_param, bool is_dynamic = false) {
+    static kv_cache_update_kernel_params_t get_kv_cache_update_kernel_params(const kernel_impl_params& impl_param,
+                                                                             const PagedAttentionStage& stage,
+                                                                             bool is_dynamic = false) {
         auto params = get_default_params<kv_cache_update_kernel_params_t>(impl_param, is_dynamic);
 
         const auto& key_layout = impl_param.get_input_layout(1);
@@ -247,8 +311,9 @@
         const auto& past_lens_layout = impl_param.get_input_layout(5);
         const auto& block_indices_layout = impl_param.get_input_layout(7);
         const auto& block_indices_begins_layout = impl_param.get_input_layout(8);
+        const auto& subsequence_begins_layout = impl_param.get_input_layout(6);
 
-        const auto inputs_number = 5;
+        const auto inputs_number = 6;
         const auto outputs_number = 2;
         params.inputs.resize(inputs_number);
         params.outputs.resize(outputs_number);
@@ -257,11 +322,17 @@
         params.inputs[2] = convert_data_tensor(past_lens_layout);
         params.inputs[3] = convert_data_tensor(block_indices_layout);
         params.inputs[4] = convert_data_tensor(block_indices_begins_layout);
+        params.inputs[5] = convert_data_tensor(subsequence_begins_layout);
         params.outputs[0] = convert_data_tensor(key_cache_layout);
         params.outputs[1] = convert_data_tensor(value_cache_layout);
 
         params.conf = get_sdpa_configuration(impl_param);
 
+        params.is_prefill = stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED;
+
+        if ((stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED) && !is_dynamic)
+            params.conf.paged_attention_aligned_seq_len = get_aligned_seq_len(impl_param, stage);
+
         const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
         std::map<size_t, size_t> in_tensor_to_offset_map = {
             {0, in_offsets_map.at(1)},
@@ -269,6 +340,7 @@
             {2, in_offsets_map.at(5)},
             {3, in_offsets_map.at(7)},
             {4, in_offsets_map.at(8)},
+            {5, in_offsets_map.at(6)},
         };
         std::map<size_t, size_t> out_tensor_to_offset_map = {
             {0, in_offsets_map.at(3)},
@@ -280,7 +352,7 @@
         return params;
     }
 
-    static sdpa_kernel_params_t get_sdpa_kernel_params(const kernel_impl_params& impl_param, bool is_dynamic = false) {
+    static sdpa_kernel_params_t get_sdpa_kernel_params(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, bool is_dynamic = false) {
         auto params = get_default_params<sdpa_kernel_params_t>(impl_param, is_dynamic);
 
         const auto& query_layout = impl_param.get_input_layout(0);
@@ -321,54 +393,15 @@
         if (has_alibi)
             in_tensor_to_offset_map.insert({4, in_offsets_map.at(11)});
 
-        if (is_prefill_stage(impl_param) && !is_dynamic) {
-            auto get_aligned_seq_len = [&](int64_t target_seq_len_block_size = 16) {
-                // Since at prefill stage Q, K, V inputs may contain multiple sequences with arbitrary
-                // target sequence lengths each (shape is [sequences_num * target_seq_len, num_heads * head_size]),
-                // to apply blocking to the first dimension (target_seq_len of each sequence), we need to calculate aligned total
-                // target sequence length for proper kernel dispatching
-                // For instance, if input contains two sequences with 35 and 28 sequence lengths each,
-                // the Q, K, V inputs at prefill stage will have shapes [35 + 28, num_heads * head_size]; considering kernel's
-                // target_seq_len_block_size equals 16, we need to launch kernel instances for the following ranges:
-                // [0, 15], [16, 31], [32, 34], [35, 50], [51, 62], so aligned target_seq_len_block_size should be 5 * 16 = 80,
-                // and 5 kernel instances should be launched (for each range, some of them containing leftovers)
-                //
-                // In general, to obtain length for each sequence, we have to parse subsequence_begins input,
-                // which contains begin and end indexes for each sequence (for above example it will contain three values: {0, 35, 63})
-                // However, as long as kernel's target_seq_len_block_size matches with vLLM's block_size,
-                // we can reuse block_indices_shape[0] size to determine total aligned sequences length size, avoiding
-                // memory access at runtime, because vLLM internally uses similar logic to configure blocks for KV cache
-
-                int64_t aligned_seq_len = 0;
-                const auto desc = impl_param.typed_desc<paged_attention>();
-                if (static_cast<int64_t>(paged_attention::block_size) == target_seq_len_block_size) {
-                    const auto block_indices_input_idx = 7;
-                    const auto& block_indices_ps = impl_param.get_input_layout(block_indices_input_idx).get_partial_shape();
-
-                    aligned_seq_len = block_indices_ps[0].get_length() * target_seq_len_block_size;
-                } else {
-                    const auto& input_mem = impl_param.memory_deps;
-                    const auto subsequence_begins_mem = input_mem.at(6);
-                    mem_lock<int32_t, mem_lock_type::read> subsequence_begins_mem_lock(subsequence_begins_mem, *impl_param.strm);
-
-                    for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
-                        auto prompt_length = subsequence_begins_mem_lock[i + 1] - subsequence_begins_mem_lock[i];
-                        aligned_seq_len += align_to(prompt_length, target_seq_len_block_size);
-                    }
-                }
-
-                return aligned_seq_len;
-            };
-
-            params.conf.paged_attention_aligned_seq_len = get_aligned_seq_len();
-        }
+        if ((stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED) && !is_dynamic)
+            params.conf.paged_attention_aligned_seq_len = get_aligned_seq_len(impl_param, stage);
 
         params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map);
 
         return params;
     }
 
-    static pa_sdpa_kernel_params_t get_pa_sdpa_params(const kernel_impl_params& impl_param, bool is_dynamic = false) {
+    static pa_sdpa_kernel_params_t get_pa_sdpa_params(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, bool is_dynamic = false) {
         auto params = get_default_params<pa_sdpa_kernel_params_t>(impl_param, is_dynamic);
 
         const auto& query_layout = impl_param.get_input_layout(0);
@@ -377,10 +410,11 @@
         const auto& past_lens_layout = impl_param.get_input_layout(5);
         const auto& block_indices_layout = impl_param.get_input_layout(7);
         const auto& block_indices_begins_layout = impl_param.get_input_layout(8);
+        const auto& subsequence_begins_layout = impl_param.get_input_layout(6);
         const auto& alibi_layout = impl_param.get_input_layout(11);
         const auto has_alibi = alibi_layout.count() > 0;
 
-        auto inputs_number = 6;
+        auto inputs_number = 7;
         if (has_alibi)
             inputs_number++;
 
@@ -392,12 +426,15 @@
         params.inputs[input_idx++] = convert_data_tensor(past_lens_layout);
         params.inputs[input_idx++] = convert_data_tensor(block_indices_layout);
         params.inputs[input_idx++] = convert_data_tensor(block_indices_begins_layout);
+        params.inputs[input_idx++] = convert_data_tensor(subsequence_begins_layout);
 
         params.conf = get_sdpa_configuration(impl_param);
 
         if (has_alibi)
             params.inputs[input_idx++] = convert_data_tensor(alibi_layout);
 
-        if (!is_prefill_stage(impl_param) && !is_dynamic) {
+        params.multi_tokens_mode = stage == PagedAttentionStage::MIXED;
+
+        if ((stage == PagedAttentionStage::GENERATION || stage == PagedAttentionStage::MIXED) && !is_dynamic) {
             const auto& input_mem = impl_param.memory_deps;
             const auto max_context_len = input_mem.at(12);
             mem_lock<int32_t, mem_lock_type::read> max_context_len_mem_lock(max_context_len, *impl_param.strm);
@@ -414,13 +451,14 @@
             {3, in_offsets_map.at(5)},
             {4, in_offsets_map.at(7)},
             {5, in_offsets_map.at(8)},
+            {6, in_offsets_map.at(6)},
         };
         std::map<size_t, size_t> out_tensor_to_offset_map = {
             {0, out_offsets_map.at(0)},
         };
 
         if (has_alibi)
-            in_tensor_to_offset_map.insert({6, in_offsets_map.at(11)});
+            in_tensor_to_offset_map.insert({7, in_offsets_map.at(11)});
 
         params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map);
 
@@ -428,30 +466,36 @@
     }
 
     void update_dispatch_data(const kernel_impl_params& impl_param) override {
-        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, impl_param.is_dynamic());
+        const auto stage = get_paged_attention_stage(impl_param);
+
+        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, impl_param.is_dynamic());
         (_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func)(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]);
 
-        if (is_prefill_stage(impl_param)) {
-            auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, impl_param.is_dynamic());
+        if (stage == PagedAttentionStage::PREFILL) {
+            auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, impl_param.is_dynamic());
             (_kernels_data[Stage::SDPA].update_dispatch_data_func)(sdpa_kernel_params, _kernels_data[Stage::SDPA]);
-        } else {
-            auto pa_sdpa_kernel_params = get_pa_sdpa_params(impl_param, impl_param.is_dynamic());
+        } else if (stage == PagedAttentionStage::GENERATION || stage == PagedAttentionStage::MIXED) {
+            auto pa_sdpa_kernel_params = get_pa_sdpa_params(impl_param, stage, impl_param.is_dynamic());
             (_kernels_data[Stage::PA_SDPA].update_dispatch_data_func)(pa_sdpa_kernel_params, _kernels_data[Stage::PA_SDPA]);
         }
     }
 
     static std::unique_ptr<primitive_impl> create(const typed_program_node<paged_attention>& arg, const kernel_impl_params& impl_param) {
         std::vector<kernel_selector::kernel_data> kernels_data;
+        const auto stage = PagedAttentionStage::UNKNOWN;
 
-        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, impl_param.is_dynamic());
+        GPU_DEBUG_TRACE_DETAIL << "Create KV cache update\n";
+        auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, impl_param.is_dynamic());
         auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance();
         kernels_data.push_back(kv_cache_update_kernel_selector.get_best_kernel(kv_cache_update_kernel_params));
 
-        auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, impl_param.is_dynamic());
+        GPU_DEBUG_TRACE_DETAIL << "Create SDPA\n";
+        auto sdpa_kernel_params = get_sdpa_kernel_params(impl_param, stage, impl_param.is_dynamic());
         auto& sdpa_kernel_selector = sdpa_kernel_selector_t::Instance();
         kernels_data.push_back(sdpa_kernel_selector.get_best_kernel(sdpa_kernel_params));
 
-        auto pa_sdpa_kernel_params = get_pa_sdpa_params(impl_param, impl_param.is_dynamic());
+        GPU_DEBUG_TRACE_DETAIL << "Create SDPA paged attention\n";
+        auto pa_sdpa_kernel_params = get_pa_sdpa_params(impl_param, stage, impl_param.is_dynamic());
         auto& pa_sdpa_kernel_selector = pa_sdpa_kernel_selector_t::Instance();
         kernels_data.push_back(pa_sdpa_kernel_selector.get_best_kernel(pa_sdpa_kernel_params));
diff --git a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h
index 5fec71ba9421d0..8e83438c433822 100644
--- a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h
+++ b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h
@@ -9,7 +9,14 @@
 
 namespace cldnn {
 
-bool is_prefill_stage(const kernel_impl_params& impl_param);
+enum PagedAttentionStage {
+    GENERATION = 0,
+    PREFILL = 1,
+    MIXED = 2,
+    UNKNOWN = 3
+};
+
+PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param);
 
 template <>
 struct typed_program_node<paged_attention> : public typed_program_node_base<paged_attention> {
@@ -20,11 +27,11 @@
     }
 
     std::vector<size_t> get_lockable_input_ids() const override {
-        return { 6 /* subsequence_begins */, 12 /* max_context_len */ };
+        return { 5 /* past_lens */, 6 /* subsequence_begins */, 12 /* max_context_len */ };
     }
 
     std::vector<size_t> get_shape_infer_dependencies() const override {
-        return { 6 /* subsequence_begins */, 12 /* max_context_len */ };
+        return { 5 /* past_lens */, 6 /* subsequence_begins */, 12 /* max_context_len */ };
     }
 };
diff --git a/src/plugins/intel_gpu/src/graph/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/paged_attention.cpp
index 094d83b8450867..20d6bd24888ec8 100644
--- a/src/plugins/intel_gpu/src/graph/paged_attention.cpp
+++ b/src/plugins/intel_gpu/src/graph/paged_attention.cpp
@@ -13,14 +13,45 @@ GPU_DEFINE_PRIMITIVE_TYPE_ID(paged_attention)
 
 constexpr size_t paged_attention::block_size;
 
-bool is_prefill_stage(const kernel_impl_params& impl_param) {
+PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param) {
     const auto& query_shape = impl_param.get_input_layout(0).get_partial_shape();
     const auto& past_lens_shape = impl_param.get_input_layout(5).get_partial_shape();
 
-    if (query_shape.is_static() && past_lens_shape.is_static())
-        return query_shape[0].get_length() != past_lens_shape[0].get_length();
+    auto print_arr = [&](mem_lock<int32_t, mem_lock_type::read>& vec, size_t max_len, std::string name) {
+        std::stringstream ss;
+        for (size_t i = 0; i < std::min(max_len, vec.size()); i++) {
+            ss << vec[i] << ", ";
+        }
+        GPU_DEBUG_TRACE_DETAIL << "Array " << name << " (len=" << vec.size() << ") content: " << ss.str() << "\n";
+    };
+
+    if (query_shape.is_static() && past_lens_shape.is_static()) {
+        const auto past_lens_idx = 5;
+        const auto& memory_deps = impl_param.memory_deps;
+        const auto past_lens_mem = memory_deps.at(past_lens_idx);
+        mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, *impl_param.strm);
+
+        print_arr(past_lens_mem_lock, past_lens_mem_lock.size(), "past_lens_mem_lock");
+
+        if (query_shape[0].get_length() == past_lens_shape[0].get_length()) {
+            GPU_DEBUG_TRACE_DETAIL << "get_paged_attention_stage GENERATION\n";
+            return PagedAttentionStage::GENERATION;
+        }
+
+        const auto past_lens_size = past_lens_mem_lock.size();
+        for (size_t i = 0; i < past_lens_size; i++) {
+            if (past_lens_mem_lock[i] != 0) {
+                GPU_DEBUG_TRACE_DETAIL << "get_paged_attention_stage MIXED\n";
+                return PagedAttentionStage::MIXED;
+            }
+        }
+
+        GPU_DEBUG_TRACE_DETAIL << "get_paged_attention_stage PREFILL\n";
+        return PagedAttentionStage::PREFILL;
+    }
 
-    return false;
+    GPU_DEBUG_TRACE_DETAIL << "get_paged_attention_stage UNKNOWN\n";
+    return PagedAttentionStage::UNKNOWN;
 }
 
 layout paged_attention_inst::calc_output_layout(const paged_attention_node& /*node*/, kernel_impl_params const& impl_param) {
@@ -64,19 +95,24 @@ std::string paged_attention_inst::to_string(const paged_attention_node& node) {
 }
 
 void paged_attention_inst::on_execute() {
-    if (!is_prefill_stage(*_impl_params))
+    auto stage = get_paged_attention_stage(*_impl_params);
+
+    if (stage == PagedAttentionStage::UNKNOWN ||
+        stage == PagedAttentionStage::GENERATION)
         return;
 
     OPENVINO_ASSERT(_intermediates_memory.size() >= 3, "Unexpected number of intermediates buffers for Paged Attention at prefill stage");
 
+    GPU_DEBUG_TRACE_DETAIL << "paged attention stage " << stage << "\n";
+
     const auto blocks_indexes_start_idx = 0;
     const auto blocks_indexes_end_idx = 1;
-    const auto gws_seq_indexes_correspondence_idx = 2;
+    const auto blocked_gws_subseq_mapping_idx = 2;
 
     auto subsequence_begins_mem = subsequence_begins_memory_ptr();
     auto blocks_indexes_start_mem = _intermediates_memory[blocks_indexes_start_idx];
     auto blocks_indexes_end_mem = _intermediates_memory[blocks_indexes_end_idx];
-    auto gws_seq_indexes_correspondence_mem = _intermediates_memory[gws_seq_indexes_correspondence_idx];
+    auto blocked_gws_subseq_mapping_mem = _intermediates_memory[blocked_gws_subseq_mapping_idx];
 
     OPENVINO_ASSERT(subsequence_begins_mem->get_layout().data_type == data_types::i32);
 
@@ -84,7 +120,19 @@
     mem_lock<int32_t, mem_lock_type::read> subsequence_begins_mem_lock(subsequence_begins_mem, stream);
     mem_lock<int32_t, mem_lock_type::write> blocks_indexes_start_lock(blocks_indexes_start_mem, stream);
     mem_lock<int32_t, mem_lock_type::write> blocks_indexes_end_lock(blocks_indexes_end_mem, stream);
-    mem_lock<int32_t, mem_lock_type::write> gws_seq_indexes_correspondence_lock(gws_seq_indexes_correspondence_mem, stream);
+    mem_lock<int32_t, mem_lock_type::write> blocked_gws_subseq_mapping_mem_lock(blocked_gws_subseq_mapping_mem, stream);
+    std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> sequential_gws_subseq_mapping_lock = nullptr;
+
+    if (stage == PagedAttentionStage::MIXED) {
+        const size_t sequential_gws_subseq_mapping_idx = 6;
+
+        OPENVINO_ASSERT(_intermediates_memory.size() > sequential_gws_subseq_mapping_idx, "Unexpected index, actual size = ", _intermediates_memory.size());
+
+        auto sequential_gws_subseq_mapping_mem = _intermediates_memory[sequential_gws_subseq_mapping_idx];
+        GPU_DEBUG_TRACE_DETAIL << "gws buffer ptr " << sequential_gws_subseq_mapping_mem->buffer_ptr()
+                               << " intermediate buffers size=" << _intermediates_memory.size() << "\n";
+        sequential_gws_subseq_mapping_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(sequential_gws_subseq_mapping_mem, stream));
+    }
 
     size_t index = 0;
     const auto target_seq_len_block_size = 16; // TODO: Get block size from the impl
@@ -99,11 +147,36 @@
             blocks_indexes_start_lock[index] = block_start_pos;
             blocks_indexes_end_lock[index] = block_end_pos;
-            gws_seq_indexes_correspondence_lock[index] = static_cast<int32_t>(i);
+            blocked_gws_subseq_mapping_mem_lock[index] = static_cast<int32_t>(i);
 
             index++;
         }
+
+        if (stage == PagedAttentionStage::MIXED) {
+            GPU_DEBUG_TRACE_DETAIL << "start=" << seq_start << " end=" << seq_end
+                                   << " lock=" << sequential_gws_subseq_mapping_lock.get()
+                                   << " size=" << sequential_gws_subseq_mapping_lock->size() << "\n";
+            for (int32_t idx = seq_start; idx < seq_end; idx++) {
+                sequential_gws_subseq_mapping_lock->operator[](idx) = static_cast<int32_t>(i);
+            }
+        }
     }
+
+    auto print_arr = [&](mem_lock<int32_t, mem_lock_type::write>& vec, size_t max_len, std::string name) {
+        std::stringstream ss;
+        for (size_t i = 0; i < std::min(max_len, vec.size()); i++) {
+            ss << vec[i] << ", ";
+        }
+        GPU_DEBUG_TRACE_DETAIL << "Array " << name << " (len=" << vec.size() << ") content: " << ss.str() << "\n";
+    };
+
+    if (stage == PagedAttentionStage::MIXED) {
+        print_arr(*sequential_gws_subseq_mapping_lock, sequential_gws_subseq_mapping_lock->size(), "sequential_gws_subseq_mapping_lock");
+    }
+
+    print_arr(blocks_indexes_start_lock, blocks_indexes_start_lock.size(), "blocks_indexes_start_lock");
+    print_arr(blocks_indexes_end_lock, blocks_indexes_end_lock.size(), "blocks_indexes_end_lock");
+    print_arr(blocked_gws_subseq_mapping_mem_lock, blocked_gws_subseq_mapping_mem_lock.size(), "blocked_gws_subseq_mapping_mem_lock");
 }
 
 void paged_attention_inst::update_shape_info_tensor(const kernel_impl_params& params) {
diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_update_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_update_ref.cl
index d0c3ed5b13d859..d41e3fd5fc9a80 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_update_ref.cl
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_update_ref.cl
@@ -11,15 +11,17 @@ KERNEL(pa_kv_cache_update)(
     __global const INPUT2_TYPE* past_lens,
     __global const INPUT3_TYPE* block_indices,
     __global const INPUT4_TYPE* block_indices_begins,
+    __global const INPUT5_TYPE* subsequence_begins,
     __global OUTPUT_TYPE* key_cache_data,
     __global OUTPUT1_TYPE* value_cache_data,
     const __global int* blocked_indexes_start,
     const __global int* blocked_indexes_end,
-    const __global int* gws_seq_indexes_correspondence
+    const __global int* gws_seq_indexes_correspondence,
+    const int is_prefill_stage
 ) {
     // If the the number of new tokens equals to the number of past_lens elements,
     // then it's the 2nd+ iteration
-    if (INPUT0_BATCH_NUM == INPUT2_BATCH_NUM) {
+    if (!is_prefill_stage) {
         // 2nd+ token
         const uint seq_idx = (uint)get_global_id(0);
         const uint head_idx = (uint)get_global_id(1);
@@ -27,8 +29,8 @@
         const uint seq_len = past_lens[seq_idx];
         const uint current_token_pos_in_block = seq_len % PAGED_ATTENTION_BLOCK_SIZE;
-        const uint seq_last_block_idx = block_indices_begins[seq_idx + 1] - 1;
-        const uint block_idx = block_indices[seq_last_block_idx];
+        const uint seq_block_idx = block_indices_begins[seq_idx] + seq_len / PAGED_ATTENTION_BLOCK_SIZE;
+        const uint block_idx = block_indices[seq_block_idx];
 
         uint key_value_in_offset = seq_idx * KV_HEADS_NUM * HEAD_SIZE + head_idx * HEAD_SIZE;
 
@@ -69,6 +71,9 @@
         const uint head_idx = get_global_id(1);
         const uint sglid = get_global_id(2);
 
+        const uint subsequence_idx = gws_seq_indexes_correspondence[block_idx];
+        const uint subsequence_begin_idx = subsequence_begins[subsequence_idx];
+
         const uint block_start_pos = blocked_indexes_start[block_idx];
         const uint block_end_pos = blocked_indexes_end[block_idx];
         const uint tokens_num = block_end_pos - block_start_pos;
@@ -76,7 +81,12 @@
         uint key_value_in_offset = block_start_pos * KV_HEADS_NUM * HEAD_SIZE +
                                    head_idx * HEAD_SIZE;
 
-        uint key_out_offset = block_indices[block_idx] * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
+        const uint cached_blocks_num = past_lens[subsequence_idx] / PAGED_ATTENTION_BLOCK_SIZE;
+        const uint current_block_idx = (block_start_pos - subsequence_begin_idx) / PAGED_ATTENTION_BLOCK_SIZE;
+
+        const uint block_offset = block_indices_begins[subsequence_idx] + cached_blocks_num + current_block_idx;
+
+        uint key_out_offset = block_indices[block_offset] * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
                               head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;
 
         uint value_out_offset = key_out_offset;
diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
index 0d27a797fac920..4627940e20f79b 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
@@ -24,6 +24,14 @@
 #error pa_sdpa_opt.cl
 #endif
 
+#if HAS_ALIBI
+    #if MULTI_TOKENS_PROCESSING
+        #define ALIBI_INPUT_TYPE INPUT7_TYPE
+    #else
+        #define ALIBI_INPUT_TYPE INPUT6_TYPE
+    #endif
+#endif
+
 REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
 __attribute__((reqd_work_group_size(1, 1, HEAD_SIZE)))
 KERNEL(pa_sdpa_opt)(
@@ -34,13 +42,19 @@
     const __global INPUT3_TYPE* past_lens,
     const __global INPUT4_TYPE* block_indices,
     const __global INPUT5_TYPE* block_indices_begins,
+#if MULTI_TOKENS_PROCESSING
+    const __global INPUT6_TYPE* subsequence_begins,
+#endif
 #if HAS_ALIBI
-    const __global INPUT6_TYPE* alibi_slopes,
+    const __global ALIBI_INPUT_TYPE* alibi_slopes,
 #endif
     __global OUTPUT_TYPE* output,
     __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
     __global SOFTMAX_ACCUMULATOR_TYPE* max_logits,
     __global OUTPUT_TYPE* tmp_out
+#if MULTI_TOKENS_PROCESSING
+    , __global const int* gws_subseq_mapping
+#endif
 ) {
     // Input shapes:
     // query: [sequences_num, HEADS_NUM * HEAD_SIZE]
@@ -66,7 +80,15 @@
 
     const uint batch_idx = seq_idx;
 
+#if MULTI_TOKENS_PROCESSING
+    const int subsequence_idx = gws_subseq_mapping[seq_idx];
+    const int subsequence_begin = subsequence_begins[subsequence_idx];
+    const int subsequence_end = subsequence_begins[subsequence_idx + 1];
+    const uint seq_len = past_lens[subsequence_idx] + 1 + (seq_idx - subsequence_begin);
+#else
+    const uint subsequence_idx = seq_idx;
     const uint seq_len = past_lens[seq_idx] + 1;
+#endif
 
     const uint partition_idx = get_group_id(2);
     const uint block_start_idx = partition_idx * SEQ_LEN_PARTITION_SIZE / PAGED_ATTENTION_BLOCK_SIZE;
@@ -79,7 +101,7 @@
 
 #ifdef STORE_QUERY_TO_SLM
     // SLM buffer for query inputs
-    __local INPUT0_TYPE slm_query[HEAD_SIZE * TARGET_SEQ_LEN_BLOCK_SIZE];
+    __local INPUT0_TYPE slm_query[HEAD_SIZE];
 #endif
 
     // SLM for intermediate QK results
@@ -117,7 +139,7 @@
         if (sgid < blocks_num_per_partition % SUBGROUPS_PER_WG)
             blocks_num++;
 
-        const uint start_block_idx = block_indices_begins[seq_idx] + partition_idx * PAGED_ATTENTION_BLOCKS_PER_PARTITION + sgid;
+        const uint start_block_idx = block_indices_begins[subsequence_idx] + partition_idx * PAGED_ATTENTION_BLOCKS_PER_PARTITION + sgid;
         for (uint block_num = 0; block_num < blocks_num; block_num++) {
 #ifdef BROADCAST_GROUP_SIZE
             const uint head_idx = head_num_idx / BROADCAST_GROUP_SIZE;
@@ -255,7 +277,7 @@
             blocks_num_per_partition = blocks_num_per_partition - 1;
         }
 
-        const uint start_block_idx = block_indices_begins[seq_idx] + partition_idx * PAGED_ATTENTION_BLOCKS_PER_PARTITION;
+        const uint start_block_idx = block_indices_begins[subsequence_idx] + partition_idx * PAGED_ATTENTION_BLOCKS_PER_PARTITION;
 
         for (uint block_num = 0; block_num < blocks_num_per_partition; block_num++) {
 #ifdef BROADCAST_GROUP_SIZE
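The new out-offset computation in the prefill branch of pa_kv_cache_update_ref.cl above can be read as a table lookup: skip the blocks already filled by the cached prefix of the subsequence, then advance by the block-aligned position of the new token range inside the subsequence. A host-side restatement (illustrative sketch with a hypothetical helper name; it assumes the cached prefix length is a multiple of the block size, which the kernel's integer division implies):

    #include <cstdint>
    #include <vector>

    // Physical KV-cache block that receives the tokens starting at absolute
    // packed position block_start_pos of subsequence s.
    int32_t physical_block(const std::vector<int32_t>& block_indices,
                           const std::vector<int32_t>& block_indices_begins,
                           const std::vector<int32_t>& past_lens,
                           const std::vector<int32_t>& subsequence_begins,
                           int32_t s, int32_t block_start_pos, int32_t block_size = 16) {
        const int32_t cached_blocks_num = past_lens[s] / block_size;
        const int32_t current_block_idx = (block_start_pos - subsequence_begins[s]) / block_size;
        return block_indices[block_indices_begins[s] + cached_blocks_num + current_block_idx];
    }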
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_update_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_update_kernel_ref.cpp
index 312b340480dbe7..a8c5f7584d9ecf 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_update_kernel_ref.cpp
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_update_kernel_ref.cpp
@@ -57,6 +57,12 @@ KernelsData KVCacheUpdateKernelRef::GetKernelsData(const Params& p) const {
     kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0});
     kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1});
     kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2});
+    kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0});
+
+    ScalarDescriptor is_prefill_stage;
+    is_prefill_stage.t = ScalarDescriptor::Types::UINT32;
+    is_prefill_stage.v.u32 = static_cast<uint32_t>(0);
+    kernel.params.scalars.push_back(is_prefill_stage);
 
     return {kd};
 }
@@ -90,7 +96,7 @@ bool KVCacheUpdateKernelRef::Validate(const Params& params) const {
         return false;
 
     const auto& kernel_params = dynamic_cast<const kv_cache_update_params&>(params);
-    if (kernel_params.inputs.size() != 5)
+    if (kernel_params.inputs.size() != 6)
         return false;
 
     if (kernel_params.outputs.size() != 2)
@@ -124,16 +130,15 @@ CommonDispatchData KVCacheUpdateKernelRef::SetDefault(const kv_cache_update_para
     const auto& key_cache = params.outputs[0];
     const auto& value_cache = params.outputs[1];
     if (!value_cache.is_dynamic() && !key_cache.is_dynamic()) {
-        bool is_prefill = params.inputs[0].Batch().v != params.inputs[2].Batch().v;
+        bool is_prefill = params.is_prefill;
         auto heads_number = static_cast<size_t>(params.conf.kv_heads_num);
 
         if (is_prefill) {
-            const auto& block_indices_input = params.inputs[3];
-            const auto blocks_number = block_indices_input.Batch().v;
+            const auto blocks_number = params.conf.paged_attention_aligned_seq_len / paged_attention_block_size;
 
             dispatch_data.gws = { blocks_number,
                                   heads_number,
-                                  subgroup_size};
+                                  subgroup_size };
             dispatch_data.lws = { 1, 1, subgroup_size };
         } else {
             const auto& key_input = params.inputs[0];
@@ -159,6 +164,8 @@ void KVCacheUpdateKernelRef::GetUpdateDispatchDataFunc(KernelData& kd) const {
         kd.kernels[0].params.workGroups.global = dispatch_data.gws;
         kd.kernels[0].params.workGroups.local = dispatch_data.lws;
         kd.kernels[0].skip_execution = false;
+
+        kd.kernels[0].params.scalars[0].v.u32 = static_cast<uint32_t>(prim_params.is_prefill);
     };
 }
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_update_kernel_ref.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_update_kernel_ref.h
index 3520996ceba44c..020426b0c3b996 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_update_kernel_ref.h
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_update_kernel_ref.h
@@ -12,6 +12,7 @@ namespace kernel_selector {
 struct kv_cache_update_params : base_params {
     kv_cache_update_params() : base_params(KernelType::PA_KV_CACHE_UPDATE) {}
 
+    bool is_prefill = false;
     sdpa_configuration conf;
 };
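With the explicit is_prefill flag, the prefill dispatch above no longer infers the block count from the block_indices shape; it derives it from the precomputed aligned length instead. Worked numbers for the earlier two-prompt example (illustrative only): aligned length 80 and block size 16 give 5 work-group blocks along the first GWS dimension.

    #include <cassert>
    #include <cstddef>
    #include <cstdint>

    int main() {
        const int64_t paged_attention_aligned_seq_len = 80;  // 48 + 32, see get_aligned_seq_len()
        const size_t block_size = 16;                        // paged_attention block size
        const size_t blocks_number = paged_attention_aligned_seq_len / block_size;
        assert(blocks_number == 5);  // gws = { blocks_number, kv_heads_num, subgroup_size }
        return 0;
    }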
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp
index add33747152261..da6d7fad2e0d25 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.cpp
@@ -12,6 +12,7 @@ namespace kernel_selector {
 namespace {
 enum KernelsTypes {
     SINGLE_TOKEN = 0,
+    MULTI_TOKENS,
     FINALIZATION,
     TOTAL_KERNELS_NUM
 };
@@ -27,6 +28,8 @@ static std::string GetKernelName(std::string base_name, KernelsTypes type) {
 
     if (type == KernelsTypes::SINGLE_TOKEN) {
         kernel_name += "_single_token";
+    } else if (type == KernelsTypes::MULTI_TOKENS) {
+        kernel_name += "_multi_tokens_seq";
     } else if (type == KernelsTypes::FINALIZATION) {
         kernel_name += "_finalization";
     }
@@ -41,6 +44,7 @@ KernelsData PagedAttentionSDPAKernelOpt::GetKernelsData(const Params& p) const {
     const auto& params = static_cast<const pa_sdpa_params&>(p);
 
     const std::vector<KernelsTypes> kernels_type = { KernelsTypes::SINGLE_TOKEN,
+                                                     KernelsTypes::MULTI_TOKENS,
                                                      KernelsTypes::FINALIZATION };
 
     KernelData kd = KernelData::Default<pa_sdpa_params>(params, kernels_type.size());
@@ -57,7 +61,15 @@
 
         const auto jit = CreateJit(kernel_name, jit_constants, entry_point);
 
-        const size_t inputs_num = kernel_type == KernelsTypes::SINGLE_TOKEN ? static_cast<size_t>(params.inputs.size()) : 1;
+        size_t inputs_num = static_cast<size_t>(params.inputs.size());
+        if (kernel_type == KernelsTypes::FINALIZATION) {
+            // FINALIZATION kernel uses only the past_lens data input
+            inputs_num = 1;
+        } else if (kernel_type == KernelsTypes::SINGLE_TOKEN) {
+            // SINGLE_TOKEN kernel excludes the subsequence_begins input
+            inputs_num -= 1;
+        }
+
         auto& kernel = kd.kernels[kd_kernels_idx++];
 
         FillCLKernelData(kernel,
                          dispatch_data,
@@ -76,13 +88,16 @@
         kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0});
         kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1});
         kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2});
-        kd.internalBufferDataType = softmax_acc_dt;
 
         if (kernel_type == KernelsTypes::FINALIZATION) {
             kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0});
 
             // Remove unused shape_info argument at finalization stage
             kernel.params.arguments.erase(kernel.params.arguments.begin());
+        } else if (kernel_type == KernelsTypes::MULTI_TOKENS) {
+            // The MULTI_TOKENS kernel needs additional information for mapping
+            // launched kernel instances to subsequence indexes
+            kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 3});
         }
     }
@@ -158,6 +173,9 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
     if (params.conf.has_alibi_input)
         jit.AddConstant(MakeJitConstant("HAS_ALIBI", 1));
 
+    if (kernel_idx == KernelsTypes::MULTI_TOKENS)
+        jit.AddConstant(MakeJitConstant("MULTI_TOKENS_PROCESSING", 1));
+
     jit.Merge(MakeTypeJitConstants(softmax_acc_dt, "SOFTMAX_ACCUMULATOR"));
 
     return jit;
@@ -193,13 +211,17 @@ void PagedAttentionSDPAKernelOpt::GetUpdateDispatchDataFunc(KernelData& kd) cons
     kd.update_dispatch_data_func = [](const Params& params, KernelData& kd) {
         const auto& prim_params = static_cast<const pa_sdpa_params&>(params);
 
-        const size_t expected_kernels_num = 2;
+        const size_t expected_kernels_num = 3;
         OPENVINO_ASSERT(kd.kernels.size() == expected_kernels_num, "[GPU] Invalid kernels size for update dispatch data func of SDPA kernel");
 
         auto dispatch_data1 = SetDefault(prim_params, KernelsTypes::SINGLE_TOKEN);
         kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.global = dispatch_data1.gws;
         kd.kernels[KernelsTypes::SINGLE_TOKEN].params.workGroups.local = dispatch_data1.lws;
-        kd.kernels[KernelsTypes::SINGLE_TOKEN].skip_execution = false;
+        kd.kernels[KernelsTypes::SINGLE_TOKEN].skip_execution = prim_params.multi_tokens_mode;
+
+        kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.global = dispatch_data1.gws;
+        kd.kernels[KernelsTypes::MULTI_TOKENS].params.workGroups.local = dispatch_data1.lws;
+        kd.kernels[KernelsTypes::MULTI_TOKENS].skip_execution = !prim_params.multi_tokens_mode;
 
         const auto& input = prim_params.inputs[0];
         const size_t sequences_number = input.Batch().v;
@@ -229,6 +251,13 @@
         kd.internalBufferSizes.push_back(buf_size);
         kd.internalBufferSizes.push_back(tmp_out_size);
         kd.internalBufferDataType = softmax_acc_dt;
+
+        if (prim_params.multi_tokens_mode) {
+            auto buf_dt_size = BytesPerElement(Datatype::INT32);
+            auto buf_elements_count = sequences_number;
+            auto buf_size = Align(buf_elements_count * buf_dt_size, BytesPerElement(softmax_acc_dt));
+            kd.internalBufferSizes.push_back(buf_size < 4 ? 4 : buf_size);
+        }
     };
 }
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.h
index 2f0d8a52c6eda6..a2456ccd9e2af5 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.h
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_sdpa_kernel_opt.h
@@ -12,8 +12,9 @@ namespace kernel_selector {
 struct pa_sdpa_params : base_params {
     pa_sdpa_params() : base_params(KernelType::PA_SDPA) {}
 
-    sdpa_configuration conf;
+    bool multi_tokens_mode = false;
     size_t max_context_len = 0;
+    sdpa_configuration conf;
 };
 
 class PagedAttentionSDPAKernelOpt : public KernelBaseOpenCL {
diff --git a/src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp b/src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp
index 346b4471779593..467c17f811de09 100644
--- a/src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp
+++ b/src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp
@@ -726,6 +726,37 @@ std::vector<cldnn::event::ptr> SyncInferRequest::prepare_input(const std::string
     bool is_remote_tensor_impl = remote_tensor_impl_ptr != nullptr;
     bool is_usm_host_tensor = usm_host_ptr != nullptr;
 
+    auto print_arr = [&](int32_t* vec, size_t max_len, std::string name) {
+        std::stringstream ss;
+        for (size_t i = 0; i < max_len; i++) {
+            ss << vec[i] << ", ";
+        }
+        GPU_DEBUG_TRACE_DETAIL << "Array " << name << " (len=" << max_len << ") content: " << ss.str() << "\n";
+    };
+    auto print_arr2 = [&](int64_t* vec, size_t max_len, std::string name) {
+        std::stringstream ss;
+        for (size_t i = 0; i < max_len; i++) {
+            ss << vec[i] << ", ";
+        }
+        GPU_DEBUG_TRACE_DETAIL << "Array " << name << " (len=" << max_len << ") content: " << ss.str() << "\n";
+    };
+
+    if (internal_name == "parameter:past_lens" ||
+        internal_name == "parameter:subsequence_begins" ||
+        internal_name == "parameter:block_indices" ||
+        internal_name == "parameter:block_indices_begins" ||
+        internal_name == "parameter:max_context_len") {
+        print_arr(user_tensor->data<int32_t>(), user_tensor->get_size(), internal_name);
+    }
+
+    if (internal_name == "parameter:input_ids") {
+        print_arr2(user_tensor->data<int64_t>(), user_tensor->get_size(), internal_name);
+    }
+
     GPU_DEBUG_TRACE_DETAIL << "Prepare input for " << internal_name << " (is_remote_tensor_impl ? " << is_remote_tensor_impl << ", is_usm_host_tensor ? " << is_usm_host_tensor
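For MIXED batches, the extra INT32 internal buffer sized above is filled on the host (see on_execute() earlier in this patch) with one subsequence index per packed query row, so the _multi_tokens_seq kernel can recover each row's subsequence and its position inside it. An illustrative standalone reconstruction of that mapping and of the per-row context length the kernel derives from it (sketch only, not the plugin code):

    #include <cstdint>
    #include <vector>

    // One subsequence index per packed query row; for subsequence_begins =
    // {0, 35, 63} this produces 35 zeros followed by 28 ones.
    std::vector<int32_t> build_gws_subseq_mapping(const std::vector<int32_t>& subsequence_begins) {
        std::vector<int32_t> mapping(subsequence_begins.back());
        for (size_t i = 0; i + 1 < subsequence_begins.size(); ++i)
            for (int32_t idx = subsequence_begins[i]; idx < subsequence_begins[i + 1]; ++idx)
                mapping[idx] = static_cast<int32_t>(i);
        return mapping;
    }

    // Per-row effective context length, mirroring the MULTI_TOKENS_PROCESSING
    // branch of pa_sdpa_opt.cl: past context + this token + preceding new tokens.
    int32_t row_seq_len(int32_t row,
                        const std::vector<int32_t>& mapping,
                        const std::vector<int32_t>& past_lens,
                        const std::vector<int32_t>& subsequence_begins) {
        const int32_t s = mapping[row];
        return past_lens[s] + 1 + (row - subsequence_begins[s]);
    }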