Skip to content

Commit

Permalink
[GPU] Initial cache rotation implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
sshlyapn committed Dec 24, 2024
1 parent b92b0e8 commit 23207cb
Show file tree
Hide file tree
Showing 12 changed files with 374 additions and 11 deletions.
6 changes: 5 additions & 1 deletion src/core/src/op/paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -176,8 +176,12 @@ void PagedAttentionExtension::validate_and_infer_types() {
"Input `rotation_trig_lut` should either have rank 2 or be omitted, but it has rank ",
get_input_partial_shape(15).rank().get_length(),
".");
if (get_input_element_type(15) != element::f32) {
std::cout << "PA WARNING: " << "Element type of `rotation_trig_lut` input should be f32, but it is " <<
get_input_element_type(15) << ".\n";
}
NODE_VALIDATION_CHECK(this,
get_input_element_type(15).is_dynamic() || get_input_element_type(15) == element::f32,
get_input_element_type(15).is_dynamic() || get_input_element_type(15).is_real(),
"Element type of `rotation_trig_lut` input should be f32, but it is ",
get_input_element_type(15),
".");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ struct paged_attention : public primitive_base<paged_attention> {
/// Constructs a PagedAttention primitive over the given inputs.
/// @param id      Primitive id.
/// @param inputs  Either 13 inputs (base PagedAttention) or 16 inputs
///                (base + rotated_block_indices, rotation_deltas,
///                rotation_trig_lut for KV-cache rotation).
paged_attention(const primitive_id& id,
                const std::vector<input_info>& inputs)
    : primitive_base(id, inputs) {
    // Note: the diff residue kept both the pre-rotation (== 15) and
    // post-rotation (== 16) checks; only the 16-input form is valid now.
    const auto num_inputs = inputs.size();
    OPENVINO_ASSERT((num_inputs == 13) || (num_inputs == 16),
                    "[GPU] Unexpected inputs number for PagedAttention primitive: ",
                    num_inputs);
}
Expand Down
87 changes: 80 additions & 7 deletions src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include "sdpa/sdpa_kernel_base.h"
#include "sdpa/sdpa_kernel_selector.h"
#include "sdpa/pa_kv_cache_rotate_kernel_ref.h"
#include "sdpa/pa_kv_cache_update_kernel_ref.h"
#include "sdpa/pa_sdpa_kernel_opt.h"

Expand All @@ -28,6 +29,9 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
using pa_sdpa_kernel_selector_t = kernel_selector::pa_sdpa_kernel_selector;
using pa_sdpa_kernel_params_t = kernel_selector::pa_sdpa_params;

using kv_cache_rotate_kernel_selector_t = kernel_selector::kv_cache_rotate_kernel_selector;
using kv_cache_rotate_kernel_params_t = kernel_selector::kv_cache_rotate_params;

using kv_cache_update_kernel_selector_t = kernel_selector::kv_cache_update_kernel_selector;
using kv_cache_update_kernel_params_t = kernel_selector::kv_cache_update_params;

Expand All @@ -50,6 +54,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
KV_CACHE_UPDATE,
SDPA,
PA_SDPA,
KV_CACHE_ROTATE,
};

bool requires_update(primitive_inst& inst, const kernel_impl_params& impl_params) const override {
Expand Down Expand Up @@ -127,10 +132,16 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
const auto desc = instance.get_node().as<paged_attention>().get_primitive();

kernel_arguments_data args;
if (stage == Stage::KV_CACHE_UPDATE || stage == Stage::SDPA)
if (stage == Stage::KV_CACHE_UPDATE || stage == Stage::SDPA || stage == Stage::KV_CACHE_ROTATE)
args.shape_info = instance.shape_info_memory_ptr();

if (stage == Stage::KV_CACHE_UPDATE) {
if (stage == Stage::KV_CACHE_ROTATE) {
args.inputs = { instance.rotated_block_indices_ptr(),
instance.rotation_deltas_ptr(),
instance.rotation_trig_lut_ptr() };

args.outputs = { instance.key_cache_memory_ptr() };
} else if (stage == Stage::KV_CACHE_UPDATE) {
args.inputs = { instance.key_memory_ptr(),
instance.value_memory_ptr(),
instance.past_lens_memory_ptr(),
Expand Down Expand Up @@ -232,7 +243,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
if (stage == Stage::PA_SDPA) {
internal_buffers_offset = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes.size();
internal_buffers_count = _kernels_data[Stage::PA_SDPA].internalBufferSizes.size();
} else {
} else if (stage == Stage::KV_CACHE_UPDATE || stage == Stage::SDPA) {
internal_buffers_count = _kernels_data[Stage::KV_CACHE_UPDATE].internalBufferSizes.size();

if (stage == Stage::SDPA) {
Expand Down Expand Up @@ -305,6 +316,17 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
const auto stage = get_paged_attention_stage(*instance.get_impl_params());
const auto is_mixed_mode = stage == PagedAttentionStage::MIXED;

const auto& desc = instance.get_impl_params()->typed_desc<paged_attention>();
if (desc->has_rotation_coefficients) {
int SKIP_ROTATION = 0;
if (const auto env_var = std::getenv("SKIP_ROTATION")) {
std::istringstream ss(env_var);
ss >> SKIP_ROTATION;
}
if (!SKIP_ROTATION)
execute_stage(events, instance, res_events, Stage::KV_CACHE_ROTATE, is_mixed_mode);
}

GPU_DEBUG_TRACE_DETAIL << "Stage::KV_CACHE_UPDATE\n";
execute_stage(events, instance, res_events, Stage::KV_CACHE_UPDATE, is_mixed_mode);

Expand All @@ -314,7 +336,6 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
execute_stage(dep_events, instance, res_events, Stage::SDPA, is_mixed_mode);
}

const auto& desc = instance.get_impl_params()->typed_desc<paged_attention>();
if (stage == PagedAttentionStage::GENERATE || stage == PagedAttentionStage::MIXED || desc->has_scores_output()) {
GPU_DEBUG_TRACE_DETAIL << "stage: " << static_cast<int>(stage) << " " << is_mixed_mode << "\n";
execute_stage(dep_events, instance, res_events, Stage::PA_SDPA, is_mixed_mode);
Expand Down Expand Up @@ -428,6 +449,8 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
config.has_const_scale_val = false;
}

config.has_rotation_coefficients_input = desc->has_rotation_coefficients;

if (desc->heads_num != desc->kv_heads_num) {
config.broadcast_axis = 1;
config.group_size = desc->heads_num / desc->kv_heads_num;
Expand All @@ -447,6 +470,42 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
return config;
}

// Builds the kernel-selector parameters for the KV-cache rotation stage.
// The stage consumes the rotation-related PagedAttention inputs
// (rotated_block_indices, rotation_deltas, rotation_trig_lut) and updates
// the key cache tensor in-place, so the key cache is wired as the output.
static kv_cache_rotate_kernel_params_t get_kv_cache_rotate_kernel_params(const kernel_impl_params& impl_param,
                                                                         const kernel_selector::MultiDataTensor& input_tensors,
                                                                         bool is_dynamic = false) {
    auto params = get_default_params<kv_cache_rotate_kernel_params_t>(impl_param, is_dynamic);

    // PagedAttention input port numbers used by this stage.
    const size_t key_cache_port = 3;
    const size_t rotated_block_indices_port = 13;
    const size_t rotation_deltas_port = 14;
    const size_t rotation_trig_lut_port = 15;

    params.inputs.resize(3);
    params.outputs.resize(1);
    params.inputs[0] = input_tensors[rotated_block_indices_port];
    params.inputs[1] = input_tensors[rotation_deltas_port];
    params.inputs[2] = input_tensors[rotation_trig_lut_port];
    params.outputs[0] = input_tensors[key_cache_port];

    params.conf = get_sdpa_configuration(impl_param, is_dynamic);

    // Map kernel tensor slots to dynamic shape_info offsets; the output
    // offset comes from the *input* map since key_cache is an in/out tensor.
    const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
    std::map<size_t, size_t> in_tensor_to_offset_map = {
        {0, in_offsets_map.at(rotated_block_indices_port)},
        {1, in_offsets_map.at(rotation_deltas_port)},
        {2, in_offsets_map.at(rotation_trig_lut_port)},
    };
    std::map<size_t, size_t> out_tensor_to_offset_map = {
        {0, in_offsets_map.at(key_cache_port)},
    };

    params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map);

    return params;
}

static kv_cache_update_kernel_params_t get_kv_cache_update_kernel_params(const kernel_impl_params& impl_param,
const PagedAttentionStage& stage,
const kernel_selector::MultiDataTensor& input_tensors,
Expand Down Expand Up @@ -687,6 +746,15 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
for (const auto& input_layout : impl_param.input_layouts)
input_tensors.emplace_back(convert_data_tensor(input_layout));

if (desc->has_rotation_coefficients) {
auto kv_cache_rotate_kernel_params = get_kv_cache_rotate_kernel_params(impl_param, input_tensors, impl_param.is_dynamic());
(_kernels_data[Stage::KV_CACHE_ROTATE].update_dispatch_data_func)(kv_cache_rotate_kernel_params, _kernels_data[Stage::KV_CACHE_ROTATE]);

if (_kernels_data[Stage::KV_CACHE_ROTATE].kernels[0].skip_execution == false) {
std::cout << "GPU: Rotate KV-cache\n";
}
}

auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
(_kernels_data[Stage::KV_CACHE_UPDATE].update_dispatch_data_func)(kv_cache_update_kernel_params, _kernels_data[Stage::KV_CACHE_UPDATE]);

Expand All @@ -709,6 +777,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
for (const auto& input_layout : impl_param.input_layouts)
input_tensors.emplace_back(convert_data_tensor(input_layout));

const auto& desc = impl_param.typed_desc<paged_attention>();
auto kv_cache_update_kernel_params = get_kv_cache_update_kernel_params(impl_param, stage, input_tensors, impl_param.is_dynamic());
auto& kv_cache_update_kernel_selector = kv_cache_update_kernel_selector_t::Instance();
kernels_data.push_back(kv_cache_update_kernel_selector.get_best_kernel(kv_cache_update_kernel_params));
Expand All @@ -721,10 +790,14 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
auto& pa_sdpa_kernel_selector = pa_sdpa_kernel_selector_t::Instance();
kernels_data.push_back(pa_sdpa_kernel_selector.get_best_kernel(pa_sdpa_kernel_params));

auto pa_impl = cldnn::make_unique<paged_attention_impl>(kernels_data);
if (desc->has_rotation_coefficients) {
auto kv_cache_rotate_kernel_params = get_kv_cache_rotate_kernel_params(impl_param, input_tensors, impl_param.is_dynamic());
auto& kv_cache_rotate_kernel_selector = kv_cache_rotate_kernel_selector_t::Instance();
kernels_data.push_back(kv_cache_rotate_kernel_selector.get_best_kernel(kv_cache_rotate_kernel_params));
}

// TODO: Check if this is enough
const auto& desc = impl_param.typed_desc<paged_attention>();
// TODO: Check if this is enough for all the cases
auto pa_impl = cldnn::make_unique<paged_attention_impl>(kernels_data);
pa_impl->has_scores_output = desc->has_scores_output();

return pa_impl;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ class typed_primitive_inst<paged_attention> : public typed_primitive_inst_base<p
// Memory accessors for PagedAttention inputs; the numeric argument is the
// primitive's input port index.
memory::ptr block_indices_memory_ptr() const { return input_memory_ptr(7); }
memory::ptr block_indices_begins_memory_ptr() const { return input_memory_ptr(8); }
memory::ptr alibi_memory_ptr() const { return input_memory_ptr(11); }
// Cache-rotation inputs (ports 13..15, present only in the 16-input form).
// NOTE(review): rotation_coefficients_memory_ptr() below also reads port 13;
// it looks like a leftover from an earlier revision — confirm and remove.
memory::ptr rotated_block_indices_ptr() const { return input_memory_ptr(13); }
memory::ptr rotation_deltas_ptr() const { return input_memory_ptr(14); }
memory::ptr rotation_trig_lut_ptr() const { return input_memory_ptr(15); }

memory::ptr rotation_coefficients_memory_ptr() const {
return input_memory_ptr(13);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "include/batch_headers/common.cl"

#define SUBGROUPS_PER_WG KV_HEADS_NUM

// Rotates cached key values in-place for the blocks listed in
// rotated_block_indices, using cos/sin coefficients from rotation_trig_lut
// selected by rotation_deltas (RoPE-style half-split rotation, see loop below).
// Work-group mapping: one WG per rotated block (get_global_id(2)); one
// subgroup per KV head (get_global_id(1)); lanes index tokens in the block.
// NOTE(review): the indexing assumes HEAD_SIZE is a multiple of SUBGROUP_SIZE
// and PAGED_ATTENTION_BLOCK_SIZE == SUBGROUP_SIZE (sglid is used both as a
// lane id and as a token index) — confirm against the host-side dispatch.
REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
__attribute__((reqd_work_group_size(SUBGROUP_SIZE, KV_HEADS_NUM, 1)))
KERNEL(pa_kv_cache_rotate)(
    OPTIONAL_SHAPE_INFO_ARG
    __global const INPUT0_TYPE* rotated_block_indices,
    __global const INPUT1_TYPE* rotation_deltas,
    __global const INPUT2_TYPE* rotation_trig_lut,
    __global OUTPUT_TYPE* key_cache
) {
    // Input shapes:
    // rotated_block_indices: [num_blocks_to_rotate]
    // rotation_deltas: [num_blocks_to_rotate, PAGED_ATTENTION_BLOCK_SIZE] || [num_blocks_to_rotate, 1]
    // rotation_trig_lut: [max_num_batched_tokens / PAGED_ATTENTION_BLOCK_SIZE, HEAD_SIZE] || [max_num_batched_tokens, HEAD_SIZE]
    // key_cache: [num_blocks, HEADS_NUM, HEAD_SIZE, PAGED_ATTENTION_BLOCK_SIZE]

    // Output shapes:
    // key_cache (updated): [num_blocks, HEADS_NUM, HEAD_SIZE, PAGED_ATTENTION_BLOCK_SIZE]

    const uint head_idx = get_global_id(1);
    const uint block_idx = get_global_id(2); // index into rotated_block_indices, not the cache block id
    const uint sglid = get_sub_group_local_id();
    const uint sgid = get_sub_group_id();

    // Per-block staging of trig coefficients, indexed [head_dim][token].
    // Rows [0, HEAD_SIZE/2) are read as cos values, rows [HEAD_SIZE/2,
    // HEAD_SIZE) as sin values in the rotation loop below.
    __local INPUT2_TYPE rotation_coefficients[HEAD_SIZE][PAGED_ATTENTION_BLOCK_SIZE];

    // rotation_deltas carries either one delta per token of the block, or a
    // single delta shared by the whole block (feature dim == 1).
    const bool per_token_rotation = INPUT1_FEATURE_NUM == PAGED_ATTENTION_BLOCK_SIZE;

    if (per_token_rotation) {
        // Need to load HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE coefficients in total, each subgroup loads SUBGROUP_SIZE values
        // (loop is strided by subgroup id so all KV_HEADS_NUM subgroups cooperate).
        for (uint i = sgid; i < HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE / SUBGROUP_SIZE; i += SUBGROUPS_PER_WG) {
            const uint token_idx = (i / (HEAD_SIZE / SUBGROUP_SIZE));
            // rotation_deltas holds a row index into rotation_trig_lut for this token.
            const uint rotation_trig_lut_start_offset = rotation_deltas[block_idx * INPUT1_FEATURE_NUM + token_idx] * HEAD_SIZE;
            const uint inner_offset = (i % (HEAD_SIZE / SUBGROUP_SIZE)) * SUBGROUP_SIZE;
            const uint rotation_trig_lut_offset = rotation_trig_lut_start_offset + inner_offset;

            INPUT2_TYPE coefficient = rotation_trig_lut[rotation_trig_lut_offset + sglid];

            rotation_coefficients[inner_offset + sglid][token_idx] = coefficient;
        }
    } else {
        // Need to load HEAD_SIZE coefficients in total, each subgroup loads SUBGROUP_SIZE values
        // (single shared delta: only column 0 of the local buffer is filled).
        for (uint i = sgid; i < HEAD_SIZE / SUBGROUP_SIZE; i += SUBGROUPS_PER_WG) {
            const uint token_idx = 0;
            const uint rotation_trig_lut_start_offset = rotation_deltas[block_idx * INPUT1_FEATURE_NUM + token_idx] * HEAD_SIZE;
            const uint inner_offset = i * SUBGROUP_SIZE;
            const uint rotation_trig_lut_offset = rotation_trig_lut_start_offset + inner_offset;

            INPUT2_TYPE coefficient = rotation_trig_lut[rotation_trig_lut_offset + sglid];

            rotation_coefficients[inner_offset + sglid][token_idx] = coefficient;
        }
    }

    // All subgroups must finish staging coefficients before any lane reads them.
    barrier(CLK_LOCAL_MEM_FENCE);

    // Each lane rotates one token column of this head's cache block. Key cache
    // layout is [block, head, head_dim, token], so consecutive head_dim values
    // are PAGED_ATTENTION_BLOCK_SIZE elements apart.
    const uint token_coefficient_idx = per_token_rotation ? sglid : 0;
    const uint block_offset = rotated_block_indices[block_idx] * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
                              head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + sglid;
    for (uint i = 0; i < HEAD_SIZE / 2; i++) {
        // Rotate the 2D pair (dim i, dim i + HEAD_SIZE/2) by the token's angle.
        const uint cache_offset = block_offset + i * PAGED_ATTENTION_BLOCK_SIZE;
        OUTPUT_TYPE cache_value_first = key_cache[cache_offset];
        OUTPUT_TYPE cache_value_second = key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE];

        INPUT2_TYPE rotation_value_cos = rotation_coefficients[i][token_coefficient_idx];
        INPUT2_TYPE rotation_value_sin = rotation_coefficients[i + (HEAD_SIZE / 2)][token_coefficient_idx];

        OUTPUT_TYPE new_cache_value_first = cache_value_first * rotation_value_cos - cache_value_second * rotation_value_sin;
        OUTPUT_TYPE new_cache_value_second = cache_value_first * rotation_value_sin + cache_value_second * rotation_value_cos;

        key_cache[cache_offset] = new_cache_value_first;
        key_cache[cache_offset + (HEAD_SIZE / 2) * PAGED_ATTENTION_BLOCK_SIZE] = new_cache_value_second;
    }
}

#undef SUBGROUPS_PER_WG
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/kernel_selector/common_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ enum class KernelType {
ARG_MAX_MIN,
BEAM_TABLE_UPDATE,
PA_KV_CACHE_UPDATE,
PA_KV_CACHE_ROTATE,
PA_SDPA,
PA_SCORES_CALCULATION,
CONVOLUTION,
Expand Down
Loading

0 comments on commit 23207cb

Please sign in to comment.