diff --git a/src/core/src/op/paged_attention.cpp b/src/core/src/op/paged_attention.cpp index 8d13e8411fd0c5..f9835c8098f2cd 100644 --- a/src/core/src/op/paged_attention.cpp +++ b/src/core/src/op/paged_attention.cpp @@ -176,8 +176,8 @@ void PagedAttentionExtension::validate_and_infer_types() { "Input `rotation_trig_lut` should either have rank 2 or be omitted, but it has rank ", get_input_partial_shape(15).rank().get_length(), "."); NODE_VALIDATION_CHECK(this, - get_input_element_type(15).is_dynamic() || get_input_element_type(15) == element::f32, + get_input_element_type(15).is_dynamic() || get_input_element_type(15).is_real(), - "Element type of `rotation_trig_lut` input should be f32, but it is ", + "Element type of `rotation_trig_lut` input should be a floating-point type, but it is ", get_input_element_type(15), "."); diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp index 35ab6b16726ae7..a5514aad1b726e 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/paged_attention.hpp @@ -21,7 +21,7 @@ struct paged_attention : public primitive_base { paged_attention(const primitive_id& id, const std::vector& inputs) : primitive_base(id, inputs) { - OPENVINO_ASSERT((inputs.size() == 13) || (inputs.size() == 15), + OPENVINO_ASSERT((inputs.size() == 13) || (inputs.size() == 16), "[GPU] Unexpected inputs number for PagedAttention primitive: ", inputs.size()); } diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp index 5471a4804c0280..2907973639f2a6 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp @@
-12,6 +12,7 @@ #include "sdpa/sdpa_kernel_base.h" #include "sdpa/sdpa_kernel_selector.h" +#include "sdpa/pa_kv_cache_rotate_kernel_ref.h" #include "sdpa/pa_kv_cache_update_kernel_ref.h" #include "sdpa/pa_sdpa_kernel_opt.h" @@ -28,6 +29,9 @@ struct paged_attention_impl : multi_stage_primitive { using pa_sdpa_kernel_selector_t = kernel_selector::pa_sdpa_kernel_selector; using pa_sdpa_kernel_params_t = kernel_selector::pa_sdpa_params; + using kv_cache_rotate_kernel_selector_t = kernel_selector::kv_cache_rotate_kernel_selector; + using kv_cache_rotate_kernel_params_t = kernel_selector::kv_cache_rotate_params; + using kv_cache_update_kernel_selector_t = kernel_selector::kv_cache_update_kernel_selector; using kv_cache_update_kernel_params_t = kernel_selector::kv_cache_update_params; @@ -50,6 +54,7 @@ struct paged_attention_impl : multi_stage_primitive { KV_CACHE_UPDATE, SDPA, PA_SDPA, + KV_CACHE_ROTATE, }; bool requires_update(primitive_inst& inst, const kernel_impl_params& impl_params) const override { @@ -127,10 +132,16 @@ struct paged_attention_impl : multi_stage_primitive { const auto desc = instance.get_node().as().get_primitive(); kernel_arguments_data args; - if (stage == Stage::KV_CACHE_UPDATE || stage == Stage::SDPA) + if (stage == Stage::KV_CACHE_UPDATE || stage == Stage::SDPA || stage == Stage::KV_CACHE_ROTATE) args.shape_info = instance.shape_info_memory_ptr(); - if (stage == Stage::KV_CACHE_UPDATE) { + if (stage == Stage::KV_CACHE_ROTATE) { + args.inputs = { instance.rotated_block_indices_ptr(), + instance.rotation_deltas_ptr(), + instance.rotation_trig_lut_ptr() }; + + args.outputs = { instance.key_cache_memory_ptr() }; + } else if (stage == Stage::KV_CACHE_UPDATE) { args.inputs = { instance.key_memory_ptr(), instance.value_memory_ptr(), instance.past_lens_memory_ptr(), @@ -232,7 +243,7 @@ struct paged_attention_impl : multi_stage_primitive { if (stage == Stage::PA_SDPA) { internal_buffers_offset = 
_kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes.size(); internal_buffers_count = _kernels_data[Stage::PA_SDPA].internalBufferSizes.size(); - } else { + } else if (stage == Stage::KV_CACHE_UPDATE || stage == Stage::SDPA) { internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes.size(); if (stage == Stage::SDPA) { @@ -305,6 +316,11 @@ struct paged_attention_impl : multi_stage_primitive { const auto stage = get_paged_attention_stage(*instance.get_impl_params()); const auto is_mixed_mode = stage == PagedAttentionStage::MIXED; + const auto& desc = instance.get_impl_params()->typed_desc(); + if (desc->has_rotation_coefficients) { + execute_stage(events, instance, res_events, Stage::KV_CACHE_ROTATE, is_mixed_mode); + } + GPU_DEBUG_TRACE_DETAIL << "Stage::KV_CACHE_UPDATE\n"; execute_stage(events, instance, res_events, Stage::KV_CACHE_UPDATE, is_mixed_mode); @@ -314,7 +330,6 @@ struct paged_attention_impl : multi_stage_primitive { execute_stage(dep_events, instance, res_events, Stage::SDPA, is_mixed_mode); } - const auto& desc = instance.get_impl_params()->typed_desc(); if (stage == PagedAttentionStage::GENERATE || stage == PagedAttentionStage::MIXED || desc->has_scores_output()) { GPU_DEBUG_TRACE_DETAIL << "stage: " << static_cast(stage) << " " << is_mixed_mode << "\n"; execute_stage(dep_events, instance, res_events, Stage::PA_SDPA, is_mixed_mode); @@ -428,6 +443,8 @@ struct paged_attention_impl : multi_stage_primitive { config.has_const_scale_val = false; } + config.has_rotation_coefficients_input = desc->has_rotation_coefficients; + if (desc->heads_num != desc->kv_heads_num) { config.broadcast_axis = 1; config.group_size = desc->heads_num / desc->kv_heads_num; @@ -447,6 +464,42 @@ struct paged_attention_impl : multi_stage_primitive { return config; } + static
kv_cache_rotate_kernel_params_t get_kv_cache_rotate_kernel_params(const kernel_impl_params& impl_param, + const kernel_selector::MultiDataTensor& input_tensors, + bool is_dynamic = false) { + auto params = get_default_params(impl_param, is_dynamic); + + const auto& key_cache_tensor = input_tensors[3]; + const auto& rotated_block_indices_tensor = input_tensors[13]; + const auto& rotation_deltas_tensor = input_tensors[14]; + const auto& rotation_trig_lut_tensor = input_tensors[15]; + + const auto inputs_number = 3; + const auto outputs_number = 1; + params.inputs.resize(inputs_number); + params.outputs.resize(outputs_number); + params.inputs[0] = rotated_block_indices_tensor; + params.inputs[1] = rotation_deltas_tensor; + params.inputs[2] = rotation_trig_lut_tensor; + params.outputs[0] = key_cache_tensor; + + params.conf = get_sdpa_configuration(impl_param, is_dynamic); + + const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset; + std::map in_tensor_to_offset_map = { + {0, in_offsets_map.at(13)}, + {1, in_offsets_map.at(14)}, + {2, in_offsets_map.at(15)}, + }; + std::map out_tensor_to_offset_map = { + {0, in_offsets_map.at(3)}, + }; + + params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map); + + return params; + } + static kv_cache_update_kernel_params_t get_kv_cache_update_kernel_params(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, const kernel_selector::MultiDataTensor& input_tensors, @@ -687,6 +746,11 @@ struct paged_attention_impl : multi_stage_primitive { for (const auto& input_layout : impl_param.input_layouts) input_tensors.emplace_back(convert_data_tensor(input_layout)); + if (desc->has_rotation_coefficients) { + auto kv_cache_rotate_kernel_params = get_kv_cache_rotate_kernel_params(impl_param, input_tensors, impl_param.is_dynamic()); + (_kernels_data[Stage::KV_CACHE_ROTATE].update_dispatch_data_func)(kv_cache_rotate_kernel_params, _kernels_data[Stage::KV_CACHE_ROTATE]); + } + auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic()); (_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func)(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]); @@ -709,6 +773,7 @@ struct paged_attention_impl : multi_stage_primitive { for (const auto& input_layout : impl_param.input_layouts) input_tensors.emplace_back(convert_data_tensor(input_layout)); + const auto& desc = impl_param.typed_desc(); auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic()); auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance(); kernels_data.push_back(kv_cache_update_kernel_selector.get_best_kernel(kv_cache_update_kernel_params)); @@ -721,10 +786,14 @@ struct paged_attention_impl : multi_stage_primitive { auto& pa_sdpa_kernel_selector = pa_sdpa_kernel_selector_t::Instance(); kernels_data.push_back(pa_sdpa_kernel_selector.get_best_kernel(pa_sdpa_kernel_params)); - auto pa_impl = cldnn::make_unique(kernels_data); + if (desc->has_rotation_coefficients) { + auto kv_cache_rotate_kernel_params = get_kv_cache_rotate_kernel_params(impl_param, input_tensors, impl_param.is_dynamic()); + auto& kv_cache_rotate_kernel_selector = kv_cache_rotate_kernel_selector_t::Instance(); + kernels_data.push_back(kv_cache_rotate_kernel_selector.get_best_kernel(kv_cache_rotate_kernel_params)); + } - // TODO: Check if this is enough - const auto& desc = impl_param.typed_desc(); + // TODO: Check if this is enough for all the cases + auto pa_impl = cldnn::make_unique(kernels_data); pa_impl->has_scores_output = desc->has_scores_output(); return pa_impl; diff --git a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h
b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h index dd120127b171d8..be41bbb51a819b 100644 --- a/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h @@ -64,6 +64,9 @@ class typed_primitive_inst : public typed_primitive_inst_base

(p); + kd.needs_sub_kernels_sync = false; + GetUpdateDispatchDataFunc(kd); + + const auto& params = static_cast(p); + const auto dispatch_data = SetDefault(params); + const auto entry_point = GetEntryPoint(kernelName, params.layerID, p); + const auto jit_constants = GetJitConstants(params); + const auto jit = CreateJit(kernelName, jit_constants, entry_point); + + auto& kernel = kd.kernels[0]; + FillCLKernelData(kernel, + dispatch_data, + params.engineInfo, + kernelName, + jit, + entry_point, + {}, + false, + false, + static_cast(params.inputs.size()), + GetFusedPrimitiveInputsCount(params), + static_cast(params.outputs.size()), + params.is_shape_agnostic); + + return {kd}; +} + +ParamsKey KVCacheRotateKernelRef::GetSupportedKey() const { + ParamsKey k; + + k.EnableInputDataType(Datatype::F16); + k.EnableInputDataType(Datatype::F32); + k.EnableInputDataType(Datatype::INT32); + + k.EnableOutputDataType(Datatype::F16); + k.EnableOutputDataType(Datatype::F32); + k.EnableOutputDataType(Datatype::INT32); + + k.EnableInputLayout(DataLayout::bfyx); + k.EnableOutputLayout(DataLayout::bfyx); + k.EnableOutputLayout(DataLayout::bfzyx); + + k.EnableDifferentTypes(); + k.EnableTensorOffset(); + k.EnableTensorPitches(); + k.EnableBatching(); + k.EnableDynamicShapesSupport(); + + return k; +} + +bool KVCacheRotateKernelRef::Validate(const Params& params) const { + if (params.GetType() != KernelType::PA_KV_CACHE_ROTATE) + return false; + + const auto& kernel_params = dynamic_cast(params); + if (kernel_params.inputs.size() != 3) + return false; + + if (kernel_params.outputs.size() != 1) + return false; + + if (!kernel_params.conf.is_paged_attention) + return false; + + if (kernel_params.conf.paged_attention_block_size != static_cast(paged_attention_block_size)) + return false; + + return true; +} + +JitConstants KVCacheRotateKernelRef::GetJitConstants(const kv_cache_rotate_params& params) const { + JitConstants jit = MakeBaseParamsJitConstants(params); + + 
jit.AddConstant(MakeJitConstant("HEAD_SIZE", params.conf.head_size)); + jit.AddConstant(MakeJitConstant("HEADS_NUM", params.conf.heads_num)); + jit.AddConstant(MakeJitConstant("KV_HEADS_NUM", params.conf.kv_heads_num)); + jit.AddConstant(MakeJitConstant("PAGED_ATTENTION_BLOCK_SIZE", paged_attention_block_size)); + jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size)); + + return jit; +} + +CommonDispatchData KVCacheRotateKernelRef::SetDefault(const kv_cache_rotate_params& params) { + CommonDispatchData dispatch_data; + + const auto& rotated_block_indices_input = params.inputs[0]; + if (!rotated_block_indices_input.is_dynamic()) { + auto heads_number = static_cast(params.conf.kv_heads_num); + auto blocks_to_rotate = static_cast(rotated_block_indices_input.Batch().v); + + dispatch_data.gws = { subgroup_size, + heads_number, + blocks_to_rotate }; + dispatch_data.lws = { subgroup_size, heads_number, 1 }; + } + + return dispatch_data; +} + +void KVCacheRotateKernelRef::GetUpdateDispatchDataFunc(KernelData& kd) const { + kd.update_dispatch_data_func = [](const Params& params, KernelData& kd) { + const auto& prim_params = static_cast(params); + const auto& rotated_block_indices_input = prim_params.inputs[0]; + + auto dispatch_data = SetDefault(prim_params); + + OPENVINO_ASSERT(kd.kernels.size() == 1, "[GPU] Invalid kernels size for update dispatch data func"); + kd.kernels[0].params.workGroups.global = dispatch_data.gws; + kd.kernels[0].params.workGroups.local = dispatch_data.lws; + kd.kernels[0].skip_execution = rotated_block_indices_input.Batch().v == 0; + }; +} + +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_rotate_kernel_ref.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_rotate_kernel_ref.h new file mode 100644 index 00000000000000..514a209e00d03f --- /dev/null +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_kv_cache_rotate_kernel_ref.h @@ -0,0 +1,33 @@ 
+// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "kernel_base_opencl.h" +#include "sdpa_kernel_base.h" + +namespace kernel_selector { + +struct kv_cache_rotate_params : base_params { + kv_cache_rotate_params() : base_params(KernelType::PA_KV_CACHE_ROTATE) {} + + bool is_prefill = false; + sdpa_configuration conf; +}; + +class KVCacheRotateKernelRef : public KernelBaseOpenCL { +public: + KVCacheRotateKernelRef() : KernelBaseOpenCL{"pa_kv_cache_rotate_ref"} {} + KernelsData GetKernelsData(const Params& params) const override; + ParamsKey GetSupportedKey() const override; + virtual ~KVCacheRotateKernelRef() {} + +protected: + bool Validate(const Params& params) const override; + JitConstants GetJitConstants(const kv_cache_rotate_params& kernel_params) const; + static CommonDispatchData SetDefault(const kv_cache_rotate_params& kernel_params); + void GetUpdateDispatchDataFunc(KernelData& kd) const override; +}; + +} // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.cpp index e65fd7fd10976b..20cfecd2c49121 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.cpp @@ -9,6 +9,7 @@ #include "pa_sdpa_kernel_opt.h" #include "pa_kv_cache_update_kernel_ref.h" +#include "pa_kv_cache_rotate_kernel_ref.h" namespace kernel_selector { @@ -32,6 +33,14 @@ KernelsData kv_cache_update_kernel_selector::GetBestKernels(const Params& params return GetNaiveBestKernel(params, KernelType::PA_KV_CACHE_UPDATE); } +kv_cache_rotate_kernel_selector::kv_cache_rotate_kernel_selector() { + Attach(); +} + +KernelsData kv_cache_rotate_kernel_selector::GetBestKernels(const Params& params) const { + return GetNaiveBestKernel(params, KernelType::PA_KV_CACHE_ROTATE); +} + 
pa_sdpa_kernel_selector::pa_sdpa_kernel_selector() { Attach(); } diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.h b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.h index ea2a948e555268..e90d100f17d14c 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.h +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/sdpa_kernel_selector.h @@ -35,6 +35,20 @@ class kv_cache_update_kernel_selector : public kernel_selector_base { KernelsData GetBestKernels(const Params& params) const override; }; +class kv_cache_rotate_kernel_selector : public kernel_selector_base { +public: + static kv_cache_rotate_kernel_selector& Instance() { + static kv_cache_rotate_kernel_selector instance_; + return instance_; + } + + kv_cache_rotate_kernel_selector(); + + virtual ~kv_cache_rotate_kernel_selector() {} + + KernelsData GetBestKernels(const Params& params) const override; +}; + class pa_sdpa_kernel_selector : public kernel_selector_base { public: static pa_sdpa_kernel_selector& Instance() { diff --git a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp index ee70a1850aee90..4540cd97107231 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp @@ -22,7 +22,7 @@ namespace ov { namespace intel_gpu { static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared_ptr& op) { - validate_inputs_count(op, {13}); + validate_inputs_count(op, {13, 16}); auto inputs = p.GetInputInfo(op); auto prim = cldnn::paged_attention(layer_type_name_ID(op), inputs); @@ -48,7 +48,6 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared const size_t scale_idx = 9; const size_t alibi_idx = 11; - const size_t rotation_coefficients_idx = 13; std::shared_ptr scale_const = 
std::dynamic_pointer_cast(op->get_input_node_shared_ptr(scale_idx)); if (scale_const) { @@ -62,6 +61,8 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared OPENVINO_ASSERT(alibi_const != nullptr); prim.has_alibi = ov::shape_size(alibi_const->get_output_shape(0)) > 0; + prim.has_rotation_coefficients = op->get_input_size() == 16; + prim.num_outputs = 1; if (op->get_output_size() > 1) { const auto scores_output_idx = 1; diff --git a/src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp b/src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp index f87f9af5275722..36a19898d770a3 100644 --- a/src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp +++ b/src/plugins/intel_gpu/src/plugin/sync_infer_request.cpp @@ -793,6 +793,8 @@ std::vector SyncInferRequest::prepare_input(const std::string auto need_lockable_mem = network->does_node_need_lockable_output(internal_name); + GPU_DEBUG_TRACE_DETAIL << "need_lockable_mem=" << need_lockable_mem << "\n"; + OPENVINO_ASSERT(pshape.compatible(ov::PartialShape(user_tensor->get_shape())) || is_batched_input(port), "[GPU] The input tensor size is not equal to model port shape, can't handle input tensor with name: ", internal_name,