[GPU] PagedAttention prefix support
sshlyapn committed Aug 27, 2024
1 parent d28d40b commit bae45a2
Showing 10 changed files with 319 additions and 94 deletions.
170 changes: 107 additions & 63 deletions src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

Large diffs are not rendered by default.

13 changes: 10 additions & 3 deletions src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h
@@ -9,7 +9,14 @@

namespace cldnn {

bool is_prefill_stage(const kernel_impl_params& impl_param);
enum PagedAttentionStage {
GENERATION = 0,
PREFILL = 1,
MIXED = 2,
UNKNOWN = 3
};

PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param);

template <>
struct typed_program_node<paged_attention> : public typed_program_node_base<paged_attention> {
@@ -20,11 +27,11 @@ struct typed_program_node<paged_attention> : public typed_program_node_base<page
using parent::parent;

std::set<size_t> get_lockable_input_ids() const override {
return { 6 /* subsequence_begins */, 12 /* max_context_len */ };
return { 5 /* past_lens */, 6 /* subsequence_begins */, 12 /* max_context_len */ };
}

std::vector<size_t> get_shape_infer_dependencies() const override {
return { 6 /* subsequence_begins */, 12 /* max_context_len */ };
return { 5 /* past_lens */, 6 /* subsequence_begins */, 12 /* max_context_len */ };
}
};

91 changes: 82 additions & 9 deletions src/plugins/intel_gpu/src/graph/paged_attention.cpp
@@ -13,14 +13,45 @@ GPU_DEFINE_PRIMITIVE_TYPE_ID(paged_attention)

constexpr size_t paged_attention::block_size;

bool is_prefill_stage(const kernel_impl_params& impl_param) {
PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param) {
const auto& query_shape = impl_param.get_input_layout(0).get_partial_shape();
const auto& past_lens_shape = impl_param.get_input_layout(5).get_partial_shape();

if (query_shape.is_static() && past_lens_shape.is_static())
return query_shape[0].get_length() != past_lens_shape[0].get_length();
auto print_arr = [&](mem_lock<int32_t, mem_lock_type::read>& vec, size_t max_len, std::string name) {
std::stringstream ss;
for (size_t i = 0; i < std::min(max_len, vec.size()); i++) {
ss << vec[i] << ", ";
}
GPU_DEBUG_TRACE_DETAIL << "Array " << name << " (len=" << vec.size() << ") content: " << ss.str() << "\n";
};

if (query_shape.is_static() && past_lens_shape.is_static()) {
const auto past_lens_idx = 5;
const auto& memory_deps = impl_param.memory_deps;
const auto past_lens_mem = memory_deps.at(past_lens_idx);
mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, *impl_param.strm);

print_arr(past_lens_mem_lock, past_lens_mem_lock.size(), "past_lens_mem_lock");

if (query_shape[0].get_length() == past_lens_shape[0].get_length()) {
GPU_DEBUG_TRACE_DETAIL << "get_paged_attention_stage GENERATION\n";
return PagedAttentionStage::GENERATION;
}

const auto past_lens_size = past_lens_mem_lock.size();
for (size_t i = 0; i < past_lens_size; i++) {
if (past_lens_mem_lock[i] != 0) {
GPU_DEBUG_TRACE_DETAIL << "get_paged_attention_stage MIXED\n";
return PagedAttentionStage::MIXED;
}
}

GPU_DEBUG_TRACE_DETAIL << "get_paged_attention_stage PREFILL\n";
return PagedAttentionStage::PREFILL;
}

return false;
GPU_DEBUG_TRACE_DETAIL << "get_paged_attention_stage UNKNOWN\n";
return PagedAttentionStage::UNKNOWN;
}
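For reference, a minimal standalone sketch of the classification rule above, using plain std::vector inputs instead of cldnn memory locks (the function name and signature are illustrative only; the UNKNOWN case, returned for dynamic shapes, is omitted):

#include <cstdint>
#include <vector>

enum class Stage { GENERATION, PREFILL, MIXED };

// query_rows: total number of new tokens in the batch (query shape[0]).
// past_lens:  per-subsequence count of tokens already stored in the KV cache.
Stage classify(int64_t query_rows, const std::vector<int32_t>& past_lens) {
    // Exactly one new token per subsequence -> pure generation (2nd+ token).
    if (query_rows == static_cast<int64_t>(past_lens.size()))
        return Stage::GENERATION;
    // Some subsequence already has cached tokens -> mixed prefill/generation batch.
    for (int32_t len : past_lens)
        if (len != 0)
            return Stage::MIXED;
    // Multiple new tokens and no cached context anywhere -> first-token prefill.
    return Stage::PREFILL;
}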

layout paged_attention_inst::calc_output_layout(const paged_attention_node& /*node*/, kernel_impl_params const& impl_param) {
@@ -64,27 +95,44 @@ std::string paged_attention_inst::to_string(const paged_attention_node& node) {
}

void paged_attention_inst::on_execute() {
if (!is_prefill_stage(*_impl_params))
auto stage = get_paged_attention_stage(*_impl_params);

if (stage == PagedAttentionStage::UNKNOWN ||
stage == PagedAttentionStage::GENERATION)
return;

OPENVINO_ASSERT(_intermediates_memory.size() >= 3, "Unexpected number of intermediate buffers for Paged Attention at prefill stage");

GPU_DEBUG_TRACE_DETAIL << "paged attention stage " << stage << "\n";

const auto blocks_indexes_start_idx = 0;
const auto blocks_indexes_end_idx = 1;
const auto gws_seq_indexes_correspondence_idx = 2;
const auto blocked_gws_subseq_mapping_idx = 2;

auto subsequence_begins_mem = subsequence_begins_memory_ptr();
auto blocks_indexes_start_mem = _intermediates_memory[blocks_indexes_start_idx];
auto blocks_indexes_end_mem = _intermediates_memory[blocks_indexes_end_idx];
auto gws_seq_indexes_correspondence_mem = _intermediates_memory[gws_seq_indexes_correspondence_idx];
auto blocked_gws_subseq_mapping_mem = _intermediates_memory[blocked_gws_subseq_mapping_idx];

OPENVINO_ASSERT(subsequence_begins_mem->get_layout().data_type == data_types::i32);

auto& stream = get_network().get_stream();
mem_lock<int32_t, mem_lock_type::read> subsequence_begins_mem_lock(subsequence_begins_mem, stream);
mem_lock<int32_t, mem_lock_type::write> blocks_indexes_start_lock(blocks_indexes_start_mem, stream);
mem_lock<int32_t, mem_lock_type::write> blocks_indexes_end_lock(blocks_indexes_end_mem, stream);
mem_lock<int32_t, mem_lock_type::write> gws_seq_indexes_correspondence_lock(gws_seq_indexes_correspondence_mem, stream);
mem_lock<int32_t, mem_lock_type::write> blocked_gws_subseq_mapping_mem_lock(blocked_gws_subseq_mapping_mem, stream);
std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> sequential_gws_subseq_mapping_lock = nullptr;

if (stage == PagedAttentionStage::MIXED) {
const auto sequential_gws_subseq_mapping_idx = 6;

OPENVINO_ASSERT(_intermediates_memory.size() > sequential_gws_subseq_mapping_idx, "Unexpected index, actual size = ", _intermediates_memory.size());

auto sequential_gws_subseq_mapping_mem = _intermediates_memory[sequential_gws_subseq_mapping_idx];
GPU_DEBUG_TRACE_DETAIL << "gws buffer ptr " << sequential_gws_subseq_mapping_mem->buffer_ptr()
<< " intermediate buffers size=" << _intermediates_memory.size() << "\n";
sequential_gws_subseq_mapping_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(sequential_gws_subseq_mapping_mem, stream));
}

size_t index = 0;
const auto target_seq_len_block_size = 16; // TODO: Get block size from the impl
@@ -99,11 +147,36 @@ void paged_attention_inst::on_execute() {

blocks_indexes_start_lock[index] = block_start_pos;
blocks_indexes_end_lock[index] = block_end_pos;
gws_seq_indexes_correspondence_lock[index] = static_cast<int32_t>(i);
blocked_gws_subseq_mapping_mem_lock[index] = static_cast<int32_t>(i);

index++;
}

if (stage == PagedAttentionStage::MIXED) {
GPU_DEBUG_TRACE_DETAIL << "start=" << seq_start << " end=" << " lock=" << sequential_gws_subseq_mapping_lock.get()
<< " " << sequential_gws_subseq_mapping_lock->size() << " " << seq_end << "\n";
for (int32_t idx = seq_start; idx < seq_end; idx++) {
sequential_gws_subseq_mapping_lock->operator[](idx) = static_cast<int32_t>(i);
}
}
}

auto print_arr = [&](mem_lock<int32_t, mem_lock_type::write>& vec, size_t max_len, std::string name) {
std::stringstream ss;
for (size_t i = 0; i < std::min(max_len, vec.size()); i++) {
ss << vec[i] << ", ";
}
GPU_DEBUG_TRACE_DETAIL << "Array " << name << " (len=" << vec.size() << ") content: " << ss.str() << "\n";
};

if (stage == PagedAttentionStage::MIXED) {
print_arr(*sequential_gws_subseq_mapping_lock, sequential_gws_subseq_mapping_lock->size(), "sequential_gws_subseq_mapping_lock");
}


print_arr(blocks_indexes_start_lock, blocks_indexes_start_lock.size(), "blocks_indexes_start_lock");
print_arr(blocks_indexes_end_lock, blocks_indexes_end_lock.size(), "blocks_indexes_end_lock");
print_arr(blocked_gws_subseq_mapping_mem_lock, blocked_gws_subseq_mapping_mem_lock.size(), "blocked_gws_subseq_mapping_mem_lock");
}
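To make the host-side preparation more concrete, a small worked example follows; it assumes the elided part of the loop simply carves each subsequence's new-token range into 16-token chunks, which matches the visible target_seq_len_block_size but is not fully shown in this hunk:

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

int main() {
    const std::vector<int32_t> subsequence_begins = {0, 35, 40};  // two subsequences: 35 and 5 new tokens
    const int32_t block = 16;                                     // target_seq_len_block_size

    std::vector<int32_t> starts, ends, block_to_subseq;           // blocked index buffers
    std::vector<int32_t> token_to_subseq(40);                     // sequential mapping (MIXED stage only)

    for (size_t i = 0; i + 1 < subsequence_begins.size(); ++i) {
        const int32_t seq_start = subsequence_begins[i];
        const int32_t seq_end = subsequence_begins[i + 1];
        for (int32_t pos = seq_start; pos < seq_end; pos += block) {
            starts.push_back(pos);
            ends.push_back(std::min(pos + block, seq_end));
            block_to_subseq.push_back(static_cast<int32_t>(i));
        }
        for (int32_t t = seq_start; t < seq_end; ++t)
            token_to_subseq[t] = static_cast<int32_t>(i);
    }
    // starts          = {0, 16, 32, 35}
    // ends            = {16, 32, 35, 40}
    // block_to_subseq = {0, 0, 0, 1}
    // token_to_subseq = 0 for tokens 0..34, 1 for tokens 35..39
    return 0;
}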

void paged_attention_inst::update_shape_info_tensor(const kernel_impl_params& params) {
@@ -11,24 +11,26 @@ KERNEL(pa_kv_cache_update)(
__global const INPUT2_TYPE* past_lens,
__global const INPUT3_TYPE* block_indices,
__global const INPUT4_TYPE* block_indices_begins,
__global const INPUT5_TYPE* subsequence_begins,
__global OUTPUT_TYPE* key_cache_data,
__global OUTPUT1_TYPE* value_cache_data,
const __global int* blocked_indexes_start,
const __global int* blocked_indexes_end,
const __global int* gws_seq_indexes_correspondence
const __global int* gws_seq_indexes_correspondence,
const int is_prefill_stage
) {
// If the number of new tokens equals the number of past_lens elements,
// then it's the 2nd+ iteration
if (INPUT0_BATCH_NUM == INPUT2_BATCH_NUM) {
if (!is_prefill_stage) {
// 2nd+ token
const uint seq_idx = (uint)get_global_id(0);
const uint head_idx = (uint)get_global_id(1);
const uint sglid = (uint)get_global_id(2);

const uint seq_len = past_lens[seq_idx];
const uint current_token_pos_in_block = seq_len % PAGED_ATTENTION_BLOCK_SIZE;
const uint seq_last_block_idx = block_indices_begins[seq_idx + 1] - 1;
const uint block_idx = block_indices[seq_last_block_idx];
const uint seq_block_idx = block_indices_begins[seq_idx] + seq_len / PAGED_ATTENTION_BLOCK_SIZE;
const uint block_idx = block_indices[seq_block_idx];

uint key_value_in_offset = seq_idx * KV_HEADS_NUM * HEAD_SIZE + head_idx * HEAD_SIZE;
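A quick numeric check of the updated generation-path indexing (values are illustrative):

// With PAGED_ATTENTION_BLOCK_SIZE = 16 and past_lens[seq_idx] = 20:
//   current_token_pos_in_block = 20 % 16 = 4
//   seq_block_idx = block_indices_begins[seq_idx] + 20 / 16 = block_indices_begins[seq_idx] + 1
//   block_idx     = block_indices[seq_block_idx]
// The new token is written to slot 4 of the subsequence's second block, addressed from the
// start of its block table rather than through its last allocated block entry.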

@@ -69,14 +71,22 @@ KERNEL(pa_kv_cache_update)(
const uint head_idx = get_global_id(1);
const uint sglid = get_global_id(2);

const uint subsequence_idx = gws_seq_indexes_correspondence[block_idx];
const uint subsequence_begin_idx = subsequence_begins[subsequence_idx];

const uint block_start_pos = blocked_indexes_start[block_idx];
const uint block_end_pos = blocked_indexes_end[block_idx];
const uint tokens_num = block_end_pos - block_start_pos;

uint key_value_in_offset = block_start_pos * KV_HEADS_NUM * HEAD_SIZE +
head_idx * HEAD_SIZE;

uint key_out_offset = block_indices[block_idx] * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
const uint cached_blocks_num = past_lens[subsequence_idx] / PAGED_ATTENTION_BLOCK_SIZE;
const uint current_block_idx = (block_start_pos - subsequence_begin_idx) / PAGED_ATTENTION_BLOCK_SIZE;

const uint block_offset = block_indices_begins[subsequence_idx] + cached_blocks_num + current_block_idx;

uint key_out_offset = block_indices[block_offset] * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE +
head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE;

uint value_out_offset = key_out_offset;
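Similarly, a worked example of the prefill/mixed-path block_offset computation above (values are illustrative):

// With PAGED_ATTENTION_BLOCK_SIZE = 16, past_lens[subsequence_idx] = 32 and subsequence_begin_idx = 0:
//   cached_blocks_num = 32 / 16 = 2
//   for the chunk starting at block_start_pos = 48:
//     current_block_idx = (48 - 0) / 16 = 3
//     block_offset      = block_indices_begins[subsequence_idx] + 2 + 3
// New key/value data is written to block_indices[block_offset], so the two blocks holding the
// already-cached prefix are skipped rather than rewritten.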
@@ -24,6 +24,14 @@
#error pa_sdpa_opt.cl
#endif

#if HAS_ALIBI
#if MULTI_TOKENS_PROCESSING
#define ALIBI_INPUT_TYPE INPUT7_TYPE
#else
#define ALIBI_INPUT_TYPE INPUT6_TYPE
#endif
#endif
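For clarity, the resulting input ordering in the two configurations (INPUT0..INPUT2 are not visible in this hunk and are assumed to be the usual query/key_cache/value_cache triple):

// Without MULTI_TOKENS_PROCESSING:        With MULTI_TOKENS_PROCESSING:
//   INPUT0  query                           INPUT0  query
//   INPUT1  key_cache                       INPUT1  key_cache
//   INPUT2  value_cache                     INPUT2  value_cache
//   INPUT3  past_lens                       INPUT3  past_lens
//   INPUT4  block_indices                   INPUT4  block_indices
//   INPUT5  block_indices_begins            INPUT5  block_indices_begins
//   INPUT6  alibi_slopes (HAS_ALIBI)        INPUT6  subsequence_begins
//                                           INPUT7  alibi_slopes (HAS_ALIBI)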

REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
__attribute__((reqd_work_group_size(1, 1, HEAD_SIZE)))
KERNEL(pa_sdpa_opt)(
@@ -34,13 +42,19 @@ KERNEL(pa_sdpa_opt)(
const __global INPUT3_TYPE* past_lens,
const __global INPUT4_TYPE* block_indices,
const __global INPUT5_TYPE* block_indices_begins,
#if MULTI_TOKENS_PROCESSING
const __global INPUT6_TYPE* subsequence_begins,
#endif
#if HAS_ALIBI
const __global INPUT6_TYPE* alibi_slopes,
const __global ALIBI_INPUT_TYPE* alibi_slopes,
#endif
__global OUTPUT_TYPE* output,
__global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
__global SOFTMAX_ACCUMULATOR_TYPE* max_logits,
__global OUTPUT_TYPE* tmp_out
#if MULTI_TOKENS_PROCESSING
, __global const int* gws_subseq_mapping
#endif
) {
// Input shapes:
// query: [sequences_num, HEADS_NUM * HEAD_SIZE]
@@ -66,7 +80,15 @@ KERNEL(pa_sdpa_opt)(

const uint batch_idx = seq_idx;

#if MULTI_TOKENS_PROCESSING
const int subsequence_idx = gws_subseq_mapping[seq_idx];
const int subsequence_begin = subsequence_begins[subsequence_idx];
const int subsequence_end = subsequence_begins[subsequence_idx + 1];
const uint seq_len = past_lens[subsequence_idx] + 1 + (seq_idx - subsequence_begin);
#else
const uint subsequence_idx = seq_idx;
const uint seq_len = past_lens[seq_idx] + 1;
#endif

const uint partition_idx = get_group_id(2);
const uint block_start_idx = partition_idx * SEQ_LEN_PARTITION_SIZE / PAGED_ATTENTION_BLOCK_SIZE;
@@ -79,7 +101,7 @@

#ifdef STORE_QUERY_TO_SLM
// SLM buffer for query inputs
__local INPUT0_TYPE slm_query[HEAD_SIZE * TARGET_SEQ_LEN_BLOCK_SIZE];
__local INPUT0_TYPE slm_query[HEAD_SIZE];
#endif

// SLM for intermediate QK results
@@ -117,7 +139,7 @@ KERNEL(pa_sdpa_opt)(
if (sgid < blocks_num_per_partition % SUBGROUPS_PER_WG)
blocks_num++;

const uint start_block_idx = block_indices_begins[seq_idx] + partition_idx * PAGED_ATTENTION_BLOCKS_PER_PARTITION + sgid;
const uint start_block_idx = block_indices_begins[subsequence_idx] + partition_idx * PAGED_ATTENTION_BLOCKS_PER_PARTITION + sgid;
for (uint block_num = 0; block_num < blocks_num; block_num++) {
#ifdef BROADCAST_GROUP_SIZE
const uint head_idx = head_num_idx / BROADCAST_GROUP_SIZE;
@@ -255,7 +277,7 @@ KERNEL(pa_sdpa_opt)(
blocks_num_per_partition = blocks_num_per_partition - 1;
}

const uint start_block_idx = block_indices_begins[seq_idx] + partition_idx * PAGED_ATTENTION_BLOCKS_PER_PARTITION;
const uint start_block_idx = block_indices_begins[subsequence_idx] + partition_idx * PAGED_ATTENTION_BLOCKS_PER_PARTITION;

for (uint block_num = 0; block_num < blocks_num_per_partition; block_num++) {
#ifdef BROADCAST_GROUP_SIZE
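A worked example of the per-work-item sequence length under MULTI_TOKENS_PROCESSING, based only on the lines shown above (the numbers are illustrative):

// One subsequence with a 10-token cached prefix and 3 new tokens in the batch:
//   past_lens          = {10}
//   subsequence_begins = {0, 3}
//   gws_subseq_mapping = {0, 0, 0}   // filled on the host in paged_attention_inst::on_execute()
// For the work item handling new token seq_idx = 1:
//   subsequence_idx   = gws_subseq_mapping[1] = 0
//   subsequence_begin = 0
//   seq_len = past_lens[0] + 1 + (1 - 0) = 12
// i.e. the token attends to the 10 cached tokens, the first new token and itself, preserving
// causal masking inside the mixed batch.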
@@ -57,6 +57,12 @@ KernelsData KVCacheUpdateKernelRef::GetKernelsData(const Params& p) const {
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2});
kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0});

ScalarDescriptor is_prefill_stage;
is_prefill_stage.t = ScalarDescriptor::Types::UINT32;
is_prefill_stage.v.u32 = static_cast<uint32_t>(0);
kernel.params.scalars.push_back(is_prefill_stage);

return {kd};
}
@@ -90,7 +96,7 @@ bool KVCacheUpdateKernelRef::Validate(const Params& params) const {
return false;

const auto& kernel_params = dynamic_cast<const kv_cache_update_params&>(params);
if (kernel_params.inputs.size() != 5)
if (kernel_params.inputs.size() != 6)
return false;

if (kernel_params.outputs.size() != 2)
@@ -124,16 +130,15 @@ CommonDispatchData KVCacheUpdateKernelRef::SetDefault(const kv_cache_update_para
const auto& key_cache = params.outputs[0];
const auto& value_cache = params.outputs[1];
if (!value_cache.is_dynamic() && !key_cache.is_dynamic()) {
bool is_prefill = params.inputs[0].Batch().v != params.inputs[2].Batch().v;
bool is_prefill = params.is_prefill;
auto heads_number = static_cast<size_t>(params.conf.kv_heads_num);

if (is_prefill) {
const auto& block_indices_input = params.inputs[3];
const auto blocks_number = block_indices_input.Batch().v;
const auto blocks_number = params.conf.paged_attention_aligned_seq_len / paged_attention_block_size;

dispatch_data.gws = { blocks_number,
heads_number,
subgroup_size};
subgroup_size };
dispatch_data.lws = { 1, 1, subgroup_size };
} else {
const auto& key_input = params.inputs[0];
@@ -159,6 +164,8 @@ void KVCacheUpdateKernelRef::GetUpdateDispatchDataFunc(KernelData& kd) const {
kd.kernels[0].params.workGroups.global = dispatch_data.gws;
kd.kernels[0].params.workGroups.local = dispatch_data.lws;
kd.kernels[0].skip_execution = false;

kd.kernels[0].params.scalars[0].v.s32 = static_cast<int32_t>(prim_params.is_prefill);
};
}

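An illustration of the new prefill dispatch sizing (all numbers below are made up for the example):

// paged_attention_aligned_seq_len = 64, paged_attention_block_size = 16,
// conf.kv_heads_num = 8, subgroup_size = 16
//   blocks_number = 64 / 16 = 4
//   gws = { 4, 8, 16 },  lws = { 1, 1, 16 }
// so each work group handles one 16-token block for one KV head.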
@@ -12,6 +12,7 @@ namespace kernel_selector {
struct kv_cache_update_params : base_params {
kv_cache_update_params() : base_params(KernelType::PA_KV_CACHE_UPDATE) {}

bool is_prefill = false;
sdpa_configuration conf;
};

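The new flag is presumably populated by the OCL implementation whose diff is not rendered above; a purely hypothetical sketch of that wiring (an assumption, not the actual code):

// Hypothetical host-side wiring (assumption only; the real code is in the unrendered
// paged_attention.cpp impl diff):
const auto stage = get_paged_attention_stage(impl_param);
params.is_prefill = (stage == PagedAttentionStage::PREFILL) ||
                    (stage == PagedAttentionStage::MIXED);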
