Skip to content

Commit

Permalink
WIP: [GPU] PA scores output
Browse files Browse the repository at this point in the history
  • Loading branch information
sshlyapn committed Dec 18, 2024
1 parent 328feb6 commit ee8c63f
Show file tree
Hide file tree
Showing 16 changed files with 659 additions and 163 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,10 @@ struct paged_attention : public primitive_base<paged_attention> {
OPENVINO_ASSERT(inputs.size() == 13, "[GPU] Unexpected inputs number for PagedAttention primitive: ", inputs.size());
}

bool has_scores_output() const {
return num_outputs == 2;
}

// PagedAttention adds no primitive-specific fields that affect equality,
// so comparing the common primitive parameters is sufficient.
bool operator==(const primitive& rhs) const override {
    return compare_common_params(rhs);
}
Expand Down
214 changes: 150 additions & 64 deletions src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp

Large diffs are not rendered by default.

15 changes: 9 additions & 6 deletions src/plugins/intel_gpu/src/graph/include/paged_attention_inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,18 @@

#include "intel_gpu/primitives/paged_attention.hpp"
#include "primitive_inst.h"
#include "sdpa/pa_sdpa_kernel_opt.h"

namespace cldnn {

enum PagedAttentionStage {
GENERATE = 0,
PREFILL = 1,
MIXED = 2,
UNKNOWN = 3
};
// The stage enum now lives in the kernel selector so the graph and kernel code
// share a single definition (GENERATE / PREFILL / MIXED / UNKNOWN); the local
// commented-out duplicate has been removed.
using PagedAttentionStage = kernel_selector::PagedAttentionStage;

// Determines the execution stage of a PagedAttention primitive from its
// runtime parameters.
PagedAttentionStage get_paged_attention_stage(const kernel_impl_params& impl_param);

Expand Down
28 changes: 25 additions & 3 deletions src/plugins/intel_gpu/src/graph/paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,14 +48,35 @@ layout paged_attention_inst::calc_output_layout(const paged_attention_node& /*no

template<typename ShapeType>
std::vector<layout> paged_attention_inst::calc_output_layouts(paged_attention_node const& /*node*/, kernel_impl_params const& impl_param) {
    const auto& desc = impl_param.typed_desc<paged_attention>();
    // First output keeps the shape/type of the query input.
    auto data_layout = impl_param.get_input_layout(0);

    // Key cache layout is [*, *, *, block_size]; validate the static block size.
    const auto& key_cache_ps = impl_param.get_input_layout(3).get_partial_shape();
    bool valid_block_size = key_cache_ps[3].is_dynamic() || key_cache_ps[3].get_length() == paged_attention::block_size;
    OPENVINO_ASSERT(valid_block_size, "[GPU] Incorrect block size for Paged Attention operation. "
                    "Expected ", paged_attention::block_size, ", but got ", key_cache_ps[3].get_length());

    std::vector<layout> output_layouts{ data_layout };

    if (desc->has_scores_output()) {
        // The scores output is a flat 1D tensor sized as the total (past) sequence
        // length across all subsequences, so past_lens must be available on host.
        constexpr size_t past_lens_idx = 5;
        const auto& memory_deps = impl_param.memory_deps;
        OPENVINO_ASSERT(memory_deps.count(past_lens_idx) > 0,
                        "[GPU] Missing past_lens memory dependency required to infer PagedAttention scores output shape");
        const auto past_lens_mem = memory_deps.at(past_lens_idx);
        mem_lock<int32_t, mem_lock_type::read> past_lens_mem_lock(past_lens_mem, *impl_param.strm);

        // Use a 64-bit accumulator (was `long int`, which is 32-bit on LLP64
        // platforms such as Windows) to avoid overflow for long contexts.
        ov::Dimension::value_type total_size = 0;
        const auto past_lens_size = past_lens_mem_lock.size();
        for (size_t i = 0; i < past_lens_size; i++) {
            total_size += past_lens_mem_lock[i];
        }

        auto scores_output = data_layout;
        scores_output.set_partial_shape(ov::PartialShape{total_size});

        output_layouts.push_back(scores_output);
    }

    return output_layouts;
}

template std::vector<layout>
Expand Down Expand Up @@ -110,7 +131,8 @@ void paged_attention_inst::on_execute() {
std::unique_ptr<mem_lock<int32_t, mem_lock_type::write>> sequential_gws_subseq_mapping_lock = nullptr;

if (stage == PagedAttentionStage::MIXED) {
const auto sequential_gws_subseq_mapping_idx = 6;
const auto& desc = _impl_params->typed_desc<paged_attention>();
const size_t sequential_gws_subseq_mapping_idx = desc->has_scores_output() ? 7 : 6;

OPENVINO_ASSERT(_intermediates_memory.size() > sequential_gws_subseq_mapping_idx,
"Unexpected number of intermediates buffers for Paged Attention for mixed stage");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -500,3 +500,72 @@ KERNEL(pa_sdpa_finalization_stage)(
}

#endif

#ifdef SDPA_STAGE_2

// Computes the per-token attention scores output by rescaling the partitioned
// softmax results with the globally-corrected exp sums and accumulating over
// the i-loop (spans HEAD_SIZE -- NOTE(review): confirm this should not be
// HEADS_NUM; the indexing of exp_sums/softmax_output uses HEAD_SIZE as stride).
REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
KERNEL(pa_sdpa_scores_calculation)(
    const __global INPUT3_TYPE* past_lens,
    const __global INPUT6_TYPE* subsequence_begins,
    __global OUTPUT1_TYPE* scores_output,
    const __global SOFTMAX_ACCUMULATOR_TYPE* softmax_output,
    const __global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
    const __global SOFTMAX_ACCUMULATOR_TYPE* max_logits) {
    const uint subsequence_idx = get_global_id(0);
    const uint partition_global_idx = get_global_id(2);
    const uint partition_idx = get_group_id(2);
    // Fixed: get_group_size() is not an OpenCL C builtin; the work-group size
    // along dim 2 is obtained with get_local_size().
    const uint partition_size = get_local_size(2);
    const uint max_seq_len = get_global_size(2);
    const uint partitions_num = get_num_groups(2);

    const int subsequence_begin = subsequence_begins[subsequence_idx];
    const int subsequence_end = subsequence_begins[subsequence_idx + 1];
    const uint seq_len = (subsequence_end - subsequence_begin) + past_lens[subsequence_idx];

    const uint num_of_partitions = CEIL_DIV(seq_len, partition_size);

    if (partition_idx >= num_of_partitions)
        return;

    SOFTMAX_ACCUMULATOR_TYPE total_score = SOFTMAX_ACCUMULATOR_VAL_ZERO;
    for (uint i = 0; i < HEAD_SIZE; i++) {
        SOFTMAX_ACCUMULATOR_TYPE exp_sum = SOFTMAX_ACCUMULATOR_VAL_ZERO;
        SOFTMAX_ACCUMULATOR_TYPE max_logit = SOFTMAX_ACCUMULATOR_VAL_MIN;

        const uint exp_sums_offset = subsequence_idx * HEAD_SIZE * partitions_num + i * partitions_num;
        if (partition_global_idx < num_of_partitions) {
            exp_sum = exp_sums[exp_sums_offset + partition_global_idx];
            max_logit = max_logits[exp_sums_offset + partition_global_idx];
        }

        // Rescale each partition's exp sum by the global max logit so the
        // partial softmax results can be combined consistently.
        SOFTMAX_ACCUMULATOR_TYPE global_max_logit = work_group_reduce_max(max_logit);
        SOFTMAX_ACCUMULATOR_TYPE adjusted_exp_sum = exp_sum * native_exp(max_logit - global_max_logit);
        // NOTE(review): work_group_broadcast() expects a local id, but
        // partition_idx is a group id -- verify the intended source lane.
        SOFTMAX_ACCUMULATOR_TYPE current_exp_sum = work_group_broadcast(adjusted_exp_sum, partition_idx);

        SOFTMAX_ACCUMULATOR_TYPE global_exp_sum = work_group_reduce_add(adjusted_exp_sum);

        SOFTMAX_ACCUMULATOR_TYPE softmax_value = SOFTMAX_ACCUMULATOR_VAL_ZERO;
        if (partition_idx < num_of_partitions) {
            const uint input_offset = subsequence_idx * HEAD_SIZE * max_seq_len + i * max_seq_len + partition_global_idx;
            softmax_value = softmax_output[input_offset];
        }

        softmax_value = softmax_value * current_exp_sum / global_exp_sum;
        // Fixed: was `total_score + softmax_value;` -- the result was discarded,
        // leaving total_score permanently zero.
        total_score += softmax_value;
    }

    // WA: need to pass additional input with offsets
    uint total_seq_len = 0;
    for (uint i = 0; i < subsequence_idx; i++) {
        // Renamed from subsequence_begin/subsequence_end, which shadowed the
        // outer-scope variables of the same names.
        const int prev_begin = subsequence_begins[i];
        const int prev_end = subsequence_begins[i + 1];
        total_seq_len += (prev_end - prev_begin) + past_lens[i];
    }

    if (partition_global_idx < seq_len) {
        // Fixed: was writing `softmax_value`, which is declared inside the loop
        // above (out of scope here) and would only hold the last iteration's
        // contribution; the accumulated score is the intended output.
        scores_output[total_seq_len + partition_global_idx] = total_score;
    }
}

#endif
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/kernel_selector/common_types.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ enum class KernelType {
BEAM_TABLE_UPDATE,
PA_KV_CACHE_UPDATE,
PA_SDPA,
PA_SCORES_CALCULATION,
CONVOLUTION,
DECONVOLUTION,
DFT,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "pa_scores_calculation_ref.h"
#include "sdpa_kernel_base.h"

#include "kernel_selector_params.h"
#include "kernel_selector_utils.h"

namespace kernel_selector {

constexpr size_t subgroup_size = 16;
constexpr size_t paged_attention_block_size = 16;

// Returns the largest per-work-item block size (4, 2 or 1 elements) such that
// head_size is evenly covered by block_size * subgroup_size lanes; falls back
// to 1 when no candidate divides head_size.
static size_t get_generate_stage_block_size(size_t head_size) {
    for (const size_t candidate : {size_t{4}, size_t{2}, size_t{1}}) {
        const bool evenly_covered = (head_size % (candidate * subgroup_size)) == 0;
        if (evenly_covered)
            return candidate;
    }

    return 1;
}

// Builds the kernel data for the PagedAttention scores-calculation kernel:
// a single CL kernel plus three internal buffers and one scalar argument.
KernelsData PAScoresCalculation::GetKernelsData(const Params& p) const {
    if (!Validate(p)) {
        return {};
    }

    KernelData kd = KernelData::Default<pa_scores_calculation>(p);
    // Only one kernel is emitted, so no inter-kernel synchronization is needed.
    kd.needs_sub_kernels_sync = false;
    GetUpdateDispatchDataFunc(kd);

    const auto& params = static_cast<const pa_scores_calculation&>(p);
    const auto dispatch_data = SetDefault(params);
    const auto entry_point = GetEntryPoint(kernelName, params.layerID, p);
    const auto jit_constants = GetJitConstants(params);
    const auto jit = CreateJit(kernelName, jit_constants, entry_point);

    auto& kernel = kd.kernels[0];
    // No exec_type suffix, no weights, no bias; argument count mirrors the
    // declared inputs/outputs of the params.
    FillCLKernelData(kernel,
                     dispatch_data,
                     params.engineInfo,
                     kernelName,
                     jit,
                     entry_point,
                     {},
                     false,
                     false,
                     static_cast<int>(params.inputs.size()),
                     GetFusedPrimitiveInputsCount(params),
                     static_cast<int>(params.outputs.size()),
                     params.is_shape_agnostic);

    // Three internal (intermediate) buffers; sizes are assigned later in
    // update_dispatch_data_func. NOTE(review): their exact roles are not
    // established here -- presumably indexes/offsets buffers; confirm against
    // the kernel's argument list.
    kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0});
    kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1});
    kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2});
    kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0});

    // Scalar 0 is a prefill-stage flag, initialized to 0 here; the runtime
    // update of this value is still commented out in GetUpdateDispatchDataFunc.
    ScalarDescriptor is_prefill_stage;
    is_prefill_stage.t = ScalarDescriptor::Types::UINT32;
    is_prefill_stage.v.u32 = static_cast<uint32_t>(0);
    kernel.params.scalars.push_back(is_prefill_stage);

    return {kd};
}

// Declares the data types, layouts and features this kernel supports.
ParamsKey PAScoresCalculation::GetSupportedKey() const {
    ParamsKey key;

    // The same set of data types is accepted on inputs and outputs.
    for (const auto dt : { Datatype::F16, Datatype::F32, Datatype::INT32 }) {
        key.EnableInputDataType(dt);
        key.EnableOutputDataType(dt);
    }

    key.EnableInputLayout(DataLayout::bfyx);
    key.EnableOutputLayout(DataLayout::bfyx);
    key.EnableOutputLayout(DataLayout::bfzyx);

    key.EnableDifferentTypes();
    key.EnableTensorOffset();
    key.EnableTensorPitches();
    key.EnableBatching();
    key.EnableDynamicShapesSupport();

    return key;
}

// Accepts only PA_SCORES_CALCULATION params with the expected number of
// inputs/outputs and a matching paged-attention configuration.
bool PAScoresCalculation::Validate(const Params& params) const {
    if (params.GetType() != KernelType::PA_SCORES_CALCULATION)
        return false;

    const auto& kernel_params = dynamic_cast<const pa_scores_calculation&>(params);

    const bool io_counts_ok = kernel_params.inputs.size() == 6 &&
                              kernel_params.outputs.size() == 2;
    const bool conf_ok = kernel_params.conf.is_paged_attention &&
                         kernel_params.conf.paged_attention_block_size == static_cast<int64_t>(paged_attention_block_size);

    return io_counts_ok && conf_ok;
}

// Emits the JIT constants consumed by the scores-calculation OpenCL kernel.
JitConstants PAScoresCalculation::GetJitConstants(const pa_scores_calculation& params) const {
    auto jit = MakeBaseParamsJitConstants(params);

    const auto& conf = params.conf;
    jit.AddConstants({
        MakeJitConstant("HEAD_SIZE", conf.head_size),
        MakeJitConstant("HEADS_NUM", conf.heads_num),
        MakeJitConstant("KV_HEADS_NUM", conf.kv_heads_num),
        MakeJitConstant("PAGED_ATTENTION_BLOCK_SIZE", paged_attention_block_size),
        MakeJitConstant("SUBGROUP_SIZE", subgroup_size),
        MakeJitConstant("GENERATE_STAGE_BLOCK_SIZE", get_generate_stage_block_size(conf.head_size)),
    });

    return jit;
}

// Computes the dispatch (gws/lws) configuration for the kernel.
// NOTE(review): WIP -- gws/lws are never assigned (all sizing logic below is
// commented out), so the returned dispatch data is empty even for static
// shapes; GetUpdateDispatchDataFunc copies these empty sizes into the kernel.
CommonDispatchData PAScoresCalculation::SetDefault(const pa_scores_calculation& params) {
    CommonDispatchData dispatch_data;

    // NOTE(review): the names below look copy-pasted from a KV-cache kernel --
    // these are outputs[0]/outputs[1] of the scores kernel, not caches; rename
    // once the dispatch logic is finalized.
    const auto& key_cache = params.outputs[0];
    const auto& value_cache = params.outputs[1];
    if (!value_cache.is_dynamic() && !key_cache.is_dynamic()) {
        auto heads_number = static_cast<size_t>(params.conf.kv_heads_num);

        // Intended sizing (prefill vs generate), kept for reference:
        // if (is_prefill) {
        //     const auto blocks_number = params.conf.paged_attention_aligned_seq_len / paged_attention_block_size;

        //     dispatch_data.gws = { blocks_number,
        //                           heads_number,
        //                           subgroup_size };
        //     dispatch_data.lws = { 1, 1, subgroup_size };
        // } else {
        //     const auto& key_input = params.inputs[0];
        //     const auto sequences_number = key_input.Batch().v;

        //     dispatch_data.gws = { sequences_number,
        //                           heads_number,
        //                           subgroup_size };
        //     dispatch_data.lws = { 1, 1, subgroup_size };
        // }
    }

    return dispatch_data;
}

// Installs the shape-agnostic update callback: recomputes dispatch sizes and
// internal buffer sizes each time the runtime shapes change.
void PAScoresCalculation::GetUpdateDispatchDataFunc(KernelData& kd) const {
    kd.update_dispatch_data_func = [](const Params& params, KernelData& kd) {
        const auto& prim_params = static_cast<const pa_scores_calculation&>(params);

        auto dispatch_data = SetDefault(prim_params);

        OPENVINO_ASSERT(kd.kernels.size() == 1, "[GPU] Invalid kernels size for update dispatch data func");
        kd.kernels[0].params.workGroups.global = dispatch_data.gws;
        kd.kernels[0].params.workGroups.local = dispatch_data.lws;
        kd.kernels[0].skip_execution = false;

        // Three equally-sized INT32 buffers, one int32 slot per block of
        // target_seq_len_block_size sequence positions.
        // NOTE(review): assumes paged_attention_aligned_seq_len is valid here
        // for both stages -- confirm once the dispatch logic is finalized.
        const auto indexes_dt = Datatype::INT32;
        const auto target_seq_len_block_size = 16;
        const auto target_seq_len = prim_params.conf.paged_attention_aligned_seq_len;
        const auto indexes_buf_size = CeilDiv(target_seq_len, target_seq_len_block_size) * BytesPerElement(indexes_dt);

        kd.internalBufferSizes.clear();
        kd.internalBufferSizes.push_back(indexes_buf_size);
        kd.internalBufferSizes.push_back(indexes_buf_size);
        kd.internalBufferSizes.push_back(indexes_buf_size);
        kd.internalBufferDataType = indexes_dt;

        // WIP: the is_prefill scalar (scalars[0]) set up in GetKernelsData is
        // not updated yet:
        // kd.kernels[0].params.scalars[0].v.s32 = static_cast<int32_t>(prim_params.is_prefill);
    };
}

} // namespace kernel_selector
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "kernel_base_opencl.h"
#include "sdpa_kernel_base.h"

namespace kernel_selector {

// Parameters for the PagedAttention scores-calculation kernel.
struct pa_scores_calculation : base_params {
    pa_scores_calculation() : base_params(KernelType::PA_SCORES_CALCULATION) {}

    // SDPA configuration shared with the other PagedAttention kernels
    // (head sizes, heads numbers, paged-attention block size, etc.).
    sdpa_configuration conf;
};

// Reference kernel computing the optional per-token scores output of the
// PagedAttention primitive.
class PAScoresCalculation : public KernelBaseOpenCL {
public:
    PAScoresCalculation() : KernelBaseOpenCL{"pa_scores_calc"} {}
    // Defaulted instead of an empty user-provided body (modern idiom; keeps
    // the class trivially destructible-aware for the compiler).
    virtual ~PAScoresCalculation() = default;

    KernelsData GetKernelsData(const Params& params) const override;
    ParamsKey GetSupportedKey() const override;

protected:
    bool Validate(const Params& params) const override;
    JitConstants GetJitConstants(const pa_scores_calculation& kernel_params) const;
    // Static: computes dispatch data from params only, shared with the
    // update-dispatch callback.
    static CommonDispatchData SetDefault(const pa_scores_calculation& kernel_params);
    void GetUpdateDispatchDataFunc(KernelData& kd) const override;
};

} // namespace kernel_selector
Loading

0 comments on commit ee8c63f

Please sign in to comment.