
Commit

[GPU] Add scores output support for PagedAttention
sshlyapn committed Dec 25, 2024
1 parent b25413c commit 8cdd7ba
Showing 12 changed files with 708 additions and 148 deletions.
@@ -24,6 +24,10 @@ struct paged_attention : public primitive_base<paged_attention> {
OPENVINO_ASSERT(inputs.size() == 13, "[GPU] Unexpected inputs number for PagedAttention primitive: ", inputs.size());
}

bool has_scores_output() const {
return num_outputs == 2;
}

bool operator==(const primitive& rhs) const override {
return compare_common_params(rhs);
}
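Note: the scores output is opt-in. The primitive reports it only when it was created with two outputs, which is exactly what the new has_scores_output() helper checks. A minimal, self-contained sketch of that contract (the struct below is a stand-in, not the real cldnn type):

    #include <cstddef>

    // Stand-in for the primitive's relevant fields; illustrative only, not the cldnn API.
    struct paged_attention_desc {
        size_t num_outputs = 1;
        bool has_scores_output() const { return num_outputs == 2; }
    };

    int main() {
        paged_attention_desc desc;
        desc.num_outputs = 2;                    // a second output requests per-token scores
        return desc.has_scores_output() ? 0 : 1; // output[0]: attention data, output[1]: scores
    }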
277 changes: 207 additions & 70 deletions src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

Large diffs are not rendered by default.

12 changes: 6 additions & 6 deletions src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h
@@ -7,14 +7,11 @@
#include "intel_gpu/primitives/paged_attention.hpp"
#include "primitive_inst.h"

#include "sdpa/pa_sdpa_kernel_opt.h"

namespace cldnn {

enum PagedAttentionStage {
GENERATE = 0,
PREFILL = 1,
MIXED = 2,
UNKNOWN = 3
};
using PagedAttentionStage = kernel_selector::PagedAttentionStage;

PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param);

@@ -61,6 +58,9 @@ class typed_primitive_inst<paged_attention> : public typed_primitive_inst_base<p
memory::ptr block_indices_memory_ptr() const { return input_memory_ptr(7); }
memory::ptr block_indices_begins_memory_ptr() const { return input_memory_ptr(8); }
memory::ptr alibi_memory_ptr() const { return input_memory_ptr(11); }
memory::ptr rotated_block_indices_memory_ptr() const { return input_memory_ptr(13); }
memory::ptr rotation_deltas_memory_ptr() const { return input_memory_ptr(14); }
memory::ptr rotation_trig_lut_memory_ptr() const { return input_memory_ptr(15); }

std::shared_ptr<network> prefill_network;

51 changes: 47 additions & 4 deletions src/plugins/intel_gpu/src/graph/paged_attention.cpp
@@ -48,14 +48,38 @@ layout paged_attention_inst::calc_output_layout(const paged_attention_node& /*no

template<typename ShapeType>
std::vector<layout> paged_attention_inst::calc_output_layouts(paged_attention_node const& /*node*/, kernel_impl_params const& impl_param) {
auto out_layout = impl_param.get_input_layout(0);
auto data_layout = impl_param.get_input_layout(0);

const auto& key_cache_ps = impl_param.get_input_layout(3).get_partial_shape();
bool valid_block_size = key_cache_ps[3].is_dynamic() || key_cache_ps[3].get_length() == paged_attention::block_size;
OPENVINO_ASSERT(valid_block_size, "[GPU] Incorrect block size for Paged Attention operation. "
"Expected ", paged_attention::block_size, ", but got ", key_cache_ps[3].get_length());

return {out_layout};
std::vector<layout> output_layouts{ data_layout };

const auto& desc = impl_param.typed_desc<paged_attention>();
if (desc->has_scores_output()) {
const auto past_lens_idx = 5;
const auto output_dt = data_layout.data_type;
if (impl_param.get_input_layout(past_lens_idx).is_static()) {
const auto& memory_deps = impl_param.memory_deps;
const auto past_lens_mem = memory_deps.at(past_lens_idx);
mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, *impl_param.strm);

long int total_size = 0;
for (size_t i = 0; i < past_lens_mem_lock.size(); i++) {
total_size += past_lens_mem_lock[i];
}

total_size += impl_param.get_input_layout(0).get_shape()[0];

output_layouts.push_back(layout{ov::PartialShape{total_size}, output_dt, format::bfyx});
} else {
output_layouts.push_back(layout{ov::PartialShape::dynamic(1), output_dt, format::bfyx});
}
}

return output_layouts;
}
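The new branch above sizes the optional scores output as the total context length: the sum of all cached lengths in past_lens plus the number of new tokens taken from the first input's shape. A standalone sketch of the same arithmetic with made-up lengths (values are illustrative, not from the commit):

    #include <cstdint>
    #include <iostream>
    #include <vector>

    int main() {
        const std::vector<int32_t> past_lens = {10, 0, 3}; // cached tokens per subsequence
        const int64_t new_tokens = 8;                      // input 0 shape[0]: tokens in the current request

        int64_t total_size = 0;
        for (const auto len : past_lens)
            total_size += len;                             // mirrors the past_lens_mem_lock loop
        total_size += new_tokens;

        // The scores output is a 1-D tensor of this length: {21} for the values above.
        std::cout << "scores output length = " << total_size << "\n";
        return 0;
    }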

template std::vector<layout>
@@ -107,19 +131,33 @@ void paged_attention_inst::on_execute() {
mem_lock<int32_t, mem_lock_type::write> blocks_indexes_start_lock(blocks_indexes_start_mem, stream);
mem_lock<int32_t, mem_lock_type::write> blocks_indexes_end_lock(blocks_indexes_end_mem, stream);
mem_lock<int32_t, mem_lock_type::write> blocked_gws_subseq_mapping_mem_lock(blocked_gws_subseq_mapping_mem, stream);
std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> subsequence_offsets_lock = nullptr;
std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> sequential_gws_subseq_mapping_lock = nullptr;

const auto& desc = _impl_params->typed_desc<paged_attention>();
const bool has_scores_output = desc->has_scores_output();
if (stage == PagedAttentionStage::MIXED) {
const auto sequential_gws_subseq_mapping_idx = 6;
const size_t sequential_gws_subseq_mapping_idx = has_scores_output ? 8 : 6;

OPENVINO_ASSERT(_intermediates_memory.size() > sequential_gws_subseq_mapping_idx,
"Unexpected number of intermediates buffers for Paged Attention for mixed stage");
"[GPU] Unexpected number of intermediate buffers for Paged Attention in mixed stage");

auto sequential_gws_subseq_mapping_mem = _intermediates_memory[sequential_gws_subseq_mapping_idx];
sequential_gws_subseq_mapping_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(sequential_gws_subseq_mapping_mem, stream));
}

if (has_scores_output) {
const size_t subsequence_offsets_idx = 4;

OPENVINO_ASSERT(_intermediates_memory.size() > subsequence_offsets_idx,
"[GPU] Unexpected number of intermediate buffers for Paged Attention scores output calculation");

auto subsequence_offsets_mem = _intermediates_memory[subsequence_offsets_idx];
subsequence_offsets_lock.reset(new mem_lock<int32_t, mem_lock_type::write>(subsequence_offsets_mem, stream));
}

size_t index = 0;
size_t subsequence_offsets_acc = 0;
const auto target_seq_len_block_size = 16; // TODO: Get block size from the impl
for (size_t i = 0; i < subsequence_begins_mem_lock.size() - 1; i++) {
const auto past_len = past_lens_mem_lock[i];
@@ -159,6 +197,11 @@ void paged_attention_inst::on_execute() {
sequential_gws_subseq_mapping_lock->operator[](idx) = static_cast<int32_t>(i);
}
}

if (subsequence_offsets_lock) {
subsequence_offsets_lock->operator[](i) = static_cast<int32_t>(subsequence_offsets_acc);
subsequence_offsets_acc += seq_length + past_len;
}
}
}
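The subsequence_offsets intermediate filled above is an exclusive prefix sum of the per-subsequence context lengths (seq_length + past_len); the scores kernel later uses it as each subsequence's write offset into the flat scores output. A small standalone sketch of the same bookkeeping (lengths are illustrative):

    #include <cstdint>
    #include <cstdio>
    #include <vector>

    int main() {
        const std::vector<int32_t> seq_length = {2, 5, 1};  // new tokens per subsequence
        const std::vector<int32_t> past_len   = {10, 0, 3}; // cached tokens per subsequence

        std::vector<int32_t> subsequence_offsets(seq_length.size());
        int32_t acc = 0;
        for (size_t i = 0; i < seq_length.size(); ++i) {
            subsequence_offsets[i] = acc;           // offset of subsequence i in the scores output
            acc += seq_length[i] + past_len[i];     // same accumulation as subsequence_offsets_acc
        }

        for (const auto off : subsequence_offsets)
            std::printf("%d ", off);                // prints: 0 12 17
        std::printf("\n");
        return 0;
    }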

179 changes: 179 additions & 0 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_sdpa_opt.cl
@@ -44,6 +44,10 @@ KERNEL(pa_sdpa_opt)(
const __global ALIBI_INPUT_TYPE* alibi_slopes,
#endif
__global OUTPUT_TYPE* output,
#if PAGED_ATTENTION_SCORES_OUTPUT
__global SOFTMAX_ACCUMULATOR_TYPE* softmax_results,
const __global int* subsequence_offsets,
#endif
__global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
__global SOFTMAX_ACCUMULATOR_TYPE* max_logits,
__global OUTPUT_TYPE* tmp_out
@@ -276,6 +280,28 @@ KERNEL(pa_sdpa_opt)(
const uint max_logits_offset = exp_sums_offset;
max_logits[max_logits_offset] = qk_max;
}

#if PAGED_ATTENTION_SCORES_OUTPUT
#if MULTI_TOKENS_PROCESSING
const uint subsequence_idx = gws_subseq_mapping[seq_idx];
const uint subsequence_start_pos = subsequence_begins[subsequence_idx];
const uint subsequence_end_pos = subsequence_begins[subsequence_idx + 1];
const bool save_softmax_results = seq_idx == subsequence_end_pos - 1;
#else
const uint subsequence_idx = seq_idx;
const bool save_softmax_results = true;
#endif // MULTI_TOKENS_PROCESSING
// PagedAttention is supposed to save only the last "row" of the QK matrix multiplication,
// so SEQ_LEN_PARTITION_SIZE elements are saved for each partition
if (save_softmax_results) {
const uint output_offset = subsequence_idx * HEADS_NUM * total_partitions_num * SEQ_LEN_PARTITION_SIZE +
head_num_idx * total_partitions_num * SEQ_LEN_PARTITION_SIZE +
partition_idx * SEQ_LEN_PARTITION_SIZE;
for (uint i = sgid * SUBGROUP_SIZE + sglid; i < SEQ_LEN_PARTITION_SIZE; i += SUBGROUPS_PER_WG * SUBGROUP_SIZE) {
softmax_results[output_offset + i] = slm_qk_vals[i];
}
}
#endif // PAGED_ATTENTION_SCORES_OUTPUT
}
}
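For the scores path, this stage keeps the raw per-partition softmax row of each subsequence's last query token, laid out as [subsequence, head, partition, SEQ_LEN_PARTITION_SIZE]. A small helper mirroring the output_offset arithmetic above; the constant values are illustrative, only the formula comes from the kernel (which receives HEADS_NUM / SEQ_LEN_PARTITION_SIZE as compile-time defines):

    #include <cstdio>

    // Illustrative constants, not the kernel's actual configuration.
    constexpr unsigned HEADS_NUM = 32;
    constexpr unsigned SEQ_LEN_PARTITION_SIZE = 256;

    constexpr unsigned softmax_results_offset(unsigned subsequence_idx, unsigned head_idx,
                                              unsigned partition_idx, unsigned total_partitions_num) {
        return subsequence_idx * HEADS_NUM * total_partitions_num * SEQ_LEN_PARTITION_SIZE +
               head_idx * total_partitions_num * SEQ_LEN_PARTITION_SIZE +
               partition_idx * SEQ_LEN_PARTITION_SIZE;
    }

    int main() {
        // e.g. subsequence 1, head 3, partition 2 out of 4 partitions
        std::printf("%u\n", softmax_results_offset(1, 3, 2, 4));
        return 0;
    }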

@@ -370,6 +396,10 @@ KERNEL(pa_sdpa_finalization_stage)(
const __global INPUT6_TYPE* subsequence_begins,
#endif
__global OUTPUT_TYPE* output,
#if PAGED_ATTENTION_SCORES_OUTPUT
__global SOFTMAX_ACCUMULATOR_TYPE* softmax_results,
const __global int* subsequence_offsets,
#endif
const __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
const __global SOFTMAX_ACCUMULATOR_TYPE* max_logits,
const __global OUTPUT_TYPE* tmp_out,
@@ -500,3 +530,152 @@ KERNEL(pa_sdpa_finalization_stage)(
}

#endif

#ifdef SDPA_STAGE_2
#define MAX_PARTITIONS_NUM 128

REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
KERNEL(pa_sdpa_scores_calculation)(
const __global INPUT3_TYPE* past_lens,
const __global INPUT6_TYPE* subsequence_begins,
__global OUTPUT1_TYPE* scores_output,
const __global SOFTMAX_ACCUMULATOR_TYPE* softmax_output,
const __global int* subsequence_offsets,
const __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
const __global SOFTMAX_ACCUMULATOR_TYPE* max_logits,
const __global OUTPUT_TYPE* tmp_out,
const uint is_mixed_mode) {
const uint subsequence_idx = get_global_id(2);
const uint partition_global_idx = get_global_id(0);
const uint local_id = get_local_id(0);
const uint partition_idx = get_group_id(0);
const uint partition_size = get_local_size(0);
const uint max_seq_len = get_global_size(0);
const uint partitions_num = get_num_groups(0);
const uint sgid = get_sub_group_id();
const uint sgid_num = get_num_sub_groups();
const uint sglid = get_sub_group_local_id();

const int subsequence_begin = subsequence_begins[subsequence_idx];
const int subsequence_end = subsequence_begins[subsequence_idx + 1];
const uint seq_len = (subsequence_end - subsequence_begin) + past_lens[subsequence_idx];

const uint num_of_partitions = CEIL_DIV(seq_len, partition_size);

if (partition_idx >= num_of_partitions)
return;

__local SOFTMAX_ACCUMULATOR_TYPE slm_exp_sums[HEADS_NUM];
__local SOFTMAX_ACCUMULATOR_TYPE slm_global_exp_sum[HEADS_NUM];

SOFTMAX_ACCUMULATOR_TYPE total_score = SOFTMAX_ACCUMULATOR_VAL_ZERO;
if (seq_len <= partition_size) {
// If seq_len fits into a single partition, just reduce the results over the heads
for (uint head_idx = 0; head_idx < HEADS_NUM; head_idx++) {
const uint input_offset = subsequence_idx * HEADS_NUM * max_seq_len + head_idx * max_seq_len + partition_global_idx;
SOFTMAX_ACCUMULATOR_TYPE softmax_value = softmax_output[input_offset];
total_score += softmax_value;
}
} else if (seq_len <= partition_size * MAX_PARTITIONS_NUM) {
// Optimized version for longer prompts (up to partition_size * MAX_PARTITIONS_NUM, ~64K tokens)

// Depending on the previous kernel, exp_sums and max_logits may have a different structure:
// For ordinary 1st and 2nd token kernels, there is only a single entry per subsequence.
// However, for mixed mode execution, exp_sums and max_logits include information for all
// tokens of each subsequence, but only the last one is needed for score calculation.
const uint subsequence_pos = is_mixed_mode ? subsequence_end - 1 : subsequence_idx;

for (uint head_idx = sgid; head_idx < HEADS_NUM; head_idx += sgid_num) {
SOFTMAX_ACCUMULATOR_TYPE max_logit[MAX_PARTITIONS_NUM / SUBGROUP_SIZE];
SOFTMAX_ACCUMULATOR_TYPE exp_sum[MAX_PARTITIONS_NUM / SUBGROUP_SIZE];

const uint exp_sums_offset = subsequence_pos * HEADS_NUM * partitions_num + head_idx * partitions_num;
for (int i = 0; i < partitions_num / SUBGROUP_SIZE; i++) {
max_logit[i] = max_logits[exp_sums_offset + i * SUBGROUP_SIZE + sglid];
exp_sum[i] = exp_sums[exp_sums_offset + i * SUBGROUP_SIZE + sglid];
}

const uint partitions_leftovers = partitions_num % SUBGROUP_SIZE;
if (partitions_leftovers != 0) {
const uint idx = partitions_num / SUBGROUP_SIZE;
max_logit[idx] = sglid >= partitions_leftovers ? SOFTMAX_ACCUMULATOR_VAL_MIN : max_logits[exp_sums_offset + idx * SUBGROUP_SIZE + sglid];
exp_sum[idx] = sglid >= partitions_leftovers ? SOFTMAX_ACCUMULATOR_VAL_ZERO : exp_sums[exp_sums_offset + idx * SUBGROUP_SIZE + sglid];
}

SOFTMAX_ACCUMULATOR_TYPE global_max_logit = max_logit[0];
for (uint i = 1; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) {
global_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(global_max_logit, max_logit[i]);
}

global_max_logit = sub_group_reduce_max(global_max_logit);

SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
for (uint i = 0; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) {
SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = exp_sum[i] * native_exp(max_logit[i] - global_max_logit);
if (i * SUBGROUP_SIZE + sglid == partition_idx)
slm_exp_sums[head_idx] = adjusted_exp_sum;
global_exp_sum += adjusted_exp_sum;
}
slm_global_exp_sum[head_idx] = global_exp_sum;
}

barrier(CLK_LOCAL_MEM_FENCE);

for (uint head_idx = 0; head_idx < HEADS_NUM; head_idx++) {
SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = slm_exp_sums[head_idx];
SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = slm_global_exp_sum[head_idx];

const uint input_offset = subsequence_idx * HEADS_NUM * max_seq_len + head_idx * max_seq_len + partition_global_idx;
SOFTMAX_ACCUMULATOR_TYPE softmax_value = softmax_output[input_offset];

softmax_value = softmax_value * adjusted_exp_sum / global_exp_sum;
total_score += softmax_value;
}
} else {
// Non-optimized fallback version
const uint subsequence_pos = is_mixed_mode ? subsequence_end - 1 : subsequence_idx;
for (uint head_idx = 0; head_idx < HEADS_NUM; head_idx++) {
SOFTMAX_ACCUMULATOR_TYPE global_max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN;
const uint max_logits_base_offset = subsequence_pos * HEADS_NUM * partitions_num + head_idx * partitions_num;
for (uint i = 0; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) {
const uint partition_offset = i * SUBGROUP_SIZE + sglid;
SOFTMAX_ACCUMULATOR_TYPE max_logit = partition_offset >= partitions_num ? SOFTMAX_ACCUMULATOR_VAL_MIN : max_logits[max_logits_base_offset + partition_offset];
global_max_logit = SOFTMAX_ACCUMULATOR_MAX_FUNC(global_max_logit, max_logit);
}

global_max_logit = sub_group_reduce_max(global_max_logit);

SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
SOFTMAX_ACCUMULATOR_TYPE partition_adjusted_exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
const uint exp_sums_base_offset = subsequence_pos * HEADS_NUM * partitions_num + head_idx * partitions_num;
for (uint i = 0; i < CEIL_DIV(partitions_num, SUBGROUP_SIZE); i++) {
const uint partition_offset = i * SUBGROUP_SIZE + sglid;
SOFTMAX_ACCUMULATOR_TYPE exp_sum = partition_offset >= partitions_num ? SOFTMAX_ACCUMULATOR_VAL_ZERO : exp_sums[exp_sums_base_offset + partition_offset];
SOFTMAX_ACCUMULATOR_TYPE max_logit = partition_offset >= partitions_num ? SOFTMAX_ACCUMULATOR_VAL_MIN : max_logits[max_logits_base_offset + partition_offset];
SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = exp_sum * native_exp(max_logit - global_max_logit);
global_exp_sum += adjusted_exp_sum;

// Save and broadcast the adjusted exp_sum for the currently being processed partition
if (i == partition_idx / SUBGROUP_SIZE)
partition_adjusted_exp_sum = sub_group_broadcast(adjusted_exp_sum, partition_idx % SUBGROUP_SIZE);
}

global_exp_sum = sub_group_reduce_add(global_exp_sum);

const uint input_offset = subsequence_idx * HEADS_NUM * max_seq_len + head_idx * max_seq_len + partition_global_idx;
SOFTMAX_ACCUMULATOR_TYPE softmax_value = softmax_output[input_offset];

softmax_value = softmax_value * partition_adjusted_exp_sum / global_exp_sum;
total_score += softmax_value;
}
}

const uint output_offset = subsequence_offsets[subsequence_idx];
if (partition_global_idx < seq_len) {
scores_output[output_offset + partition_global_idx] = total_score;
}
}

#undef MAX_PARTITIONS_NUM
#endif
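The score aggregation above is the standard two-pass log-sum-exp combination: each partition stores its local max_logit and exp_sum, and the per-partition (locally normalized) softmax values are rescaled by exp_sum_p * exp(max_logit_p - global_max) / global_exp_sum before being accumulated over heads. A scalar C++ reference of that rescaling for a single head (illustrative numbers, no subgroup reductions):

    #include <algorithm>
    #include <cmath>
    #include <cstddef>
    #include <cstdio>
    #include <vector>

    // Rescale one partition-local softmax value to the global softmax, as done per head above.
    float rescale(float softmax_value, size_t partition_idx,
                  const std::vector<float>& max_logits, const std::vector<float>& exp_sums) {
        const float global_max = *std::max_element(max_logits.begin(), max_logits.end());

        float global_exp_sum = 0.f;
        for (size_t p = 0; p < exp_sums.size(); ++p)
            global_exp_sum += exp_sums[p] * std::exp(max_logits[p] - global_max);

        const float adjusted = exp_sums[partition_idx] * std::exp(max_logits[partition_idx] - global_max);
        return softmax_value * adjusted / global_exp_sum;
    }

    int main() {
        // Two partitions with different local maxima; values are made up.
        const std::vector<float> max_logits = {1.0f, 3.0f};
        const std::vector<float> exp_sums   = {4.0f, 2.0f};
        std::printf("%f\n", rescale(0.25f, 0, max_logits, exp_sums));
        return 0;
    }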
44 changes: 44 additions & 0 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
@@ -842,6 +842,14 @@ KERNEL(sdpa_opt)(
const __global int* blocked_indexes_start,
const __global int* blocked_indexes_end,
const __global int* gws_seq_indexes_correspondence
#if PAGED_ATTENTION_SCORES_OUTPUT
, __global SOFTMAX_ACCUMULATOR_TYPE* softmax_results
, const __global int* subsequence_offsets
, __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums
, __global SOFTMAX_ACCUMULATOR_TYPE* max_logits
, __global OUTPUT_TYPE* tmp_out
, const uint aligned_max_context_len
#endif
#else
__global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
__global SOFTMAX_ACCUMULATOR_TYPE* max_logits,
@@ -1222,6 +1230,42 @@ KERNEL(sdpa_opt)(
slm_qk_vals[sglid * SEQ_LEN_PARTITION_SIZE + sgid * TARGET_SEQ_LEN_BLOCK_SIZE + i] = qk_acc[i];
}

#if PAGED_ATTENTION_SCORES_OUTPUT
const uint subsequence_idx = gws_seq_indexes_correspondence[target_seq_dim];
const uint subsequence_end_pos = subsequence_begins[subsequence_idx + 1];
const uint block_start_pos = blocked_indexes_start[target_seq_dim];
const uint block_end_pos = blocked_indexes_end[target_seq_dim];

// PagedAttention is supposed to save only the last "row" of the QK matrix multiplication,
// so SEQ_LEN_PARTITION_SIZE elements are saved for each partition
if (subsequence_end_pos == block_end_pos) {
const uint last_row_idx = block_end_pos - block_start_pos - 1;
if (sglid == last_row_idx) {
const uint partition_idx = start_partition_idx / SEQ_LEN_PARTITION_SIZE;

if (sgid == 0) {
const uint max_partitions_num = aligned_max_context_len / SEQ_LEN_PARTITION_SIZE;
const uint exp_sums_output_offset = subsequence_idx * NUM_HEADS * max_partitions_num +
num_heads_dim * max_partitions_num +
partition_idx;
exp_sums[exp_sums_output_offset] = exp_sum_new;
max_logits[exp_sums_output_offset] = qk_max_new;
}

const uint output_offset = subsequence_idx * NUM_HEADS * aligned_max_context_len +
num_heads_dim * aligned_max_context_len +
partition_idx * SEQ_LEN_PARTITION_SIZE + sgid * TARGET_SEQ_LEN_BLOCK_SIZE;
for (uint i = 0; i < TARGET_SEQ_LEN_BLOCK_SIZE; i++) {
softmax_results[output_offset + i] = qk_acc[i];
}

}
}
#endif

barrier(CLK_LOCAL_MEM_FENCE);
}

