forked from openvinotoolkit/openvino
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
16 changed files
with
659 additions
and
163 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
214 changes: 150 additions & 64 deletions
214
src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
182 changes: 182 additions & 0 deletions
182
src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_scores_calculation_ref.cpp
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#include "pa_scores_calculation_ref.h" | ||
#include "sdpa_kernel_base.h" | ||
|
||
#include "kernel_selector_params.h" | ||
#include "kernel_selector_utils.h" | ||
|
||
namespace kernel_selector { | ||
|
||
constexpr size_t subgroup_size = 16; | ||
constexpr size_t paged_attention_block_size = 16; | ||
|
||
static size_t get_generate_stage_block_size(size_t head_size) { | ||
auto preferred_block_size = { 4, 2, 1 }; | ||
for (const auto& block_size : preferred_block_size) { | ||
if (head_size % (block_size * subgroup_size) == 0) { | ||
return block_size; | ||
} | ||
} | ||
|
||
return 1; | ||
} | ||
|
||
KernelsData PAScoresCalculation::GetKernelsData(const Params& p) const { | ||
if (!Validate(p)) { | ||
return {}; | ||
} | ||
|
||
KernelData kd = KernelData::Default<pa_scores_calculation>(p); | ||
kd.needs_sub_kernels_sync = false; | ||
GetUpdateDispatchDataFunc(kd); | ||
|
||
const auto& params = static_cast<const pa_scores_calculation&>(p); | ||
const auto dispatch_data = SetDefault(params); | ||
const auto entry_point = GetEntryPoint(kernelName, params.layerID, p); | ||
const auto jit_constants = GetJitConstants(params); | ||
const auto jit = CreateJit(kernelName, jit_constants, entry_point); | ||
|
||
auto& kernel = kd.kernels[0]; | ||
FillCLKernelData(kernel, | ||
dispatch_data, | ||
params.engineInfo, | ||
kernelName, | ||
jit, | ||
entry_point, | ||
{}, | ||
false, | ||
false, | ||
static_cast<int>(params.inputs.size()), | ||
GetFusedPrimitiveInputsCount(params), | ||
static_cast<int>(params.outputs.size()), | ||
params.is_shape_agnostic); | ||
|
||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 0}); | ||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 1}); | ||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::INTERNAL_BUFFER, 2}); | ||
kernel.params.arguments.push_back({ArgumentDescriptor::Types::SCALAR, 0}); | ||
|
||
ScalarDescriptor is_prefill_stage; | ||
is_prefill_stage.t = ScalarDescriptor::Types::UINT32; | ||
is_prefill_stage.v.u32 = static_cast<uint32_t>(0); | ||
kernel.params.scalars.push_back(is_prefill_stage); | ||
|
||
return {kd}; | ||
} | ||
|
||
ParamsKey PAScoresCalculation::GetSupportedKey() const { | ||
ParamsKey k; | ||
|
||
k.EnableInputDataType(Datatype::F16); | ||
k.EnableInputDataType(Datatype::F32); | ||
k.EnableInputDataType(Datatype::INT32); | ||
|
||
k.EnableOutputDataType(Datatype::F16); | ||
k.EnableOutputDataType(Datatype::F32); | ||
k.EnableOutputDataType(Datatype::INT32); | ||
|
||
k.EnableInputLayout(DataLayout::bfyx); | ||
k.EnableOutputLayout(DataLayout::bfyx); | ||
k.EnableOutputLayout(DataLayout::bfzyx); | ||
|
||
k.EnableDifferentTypes(); | ||
k.EnableTensorOffset(); | ||
k.EnableTensorPitches(); | ||
k.EnableBatching(); | ||
k.EnableDynamicShapesSupport(); | ||
|
||
return k; | ||
} | ||
|
||
bool PAScoresCalculation::Validate(const Params& params) const { | ||
if (params.GetType() != KernelType::PA_SCORES_CALCULATION) | ||
return false; | ||
|
||
const auto& kernel_params = dynamic_cast<const pa_scores_calculation&>(params); | ||
if (kernel_params.inputs.size() != 6) | ||
return false; | ||
|
||
if (kernel_params.outputs.size() != 2) | ||
return false; | ||
|
||
if (!kernel_params.conf.is_paged_attention) | ||
return false; | ||
|
||
if (kernel_params.conf.paged_attention_block_size != static_cast<int64_t>(paged_attention_block_size)) | ||
return false; | ||
|
||
return true; | ||
} | ||
|
||
JitConstants PAScoresCalculation::GetJitConstants(const pa_scores_calculation& params) const { | ||
JitConstants jit = MakeBaseParamsJitConstants(params); | ||
|
||
jit.AddConstant(MakeJitConstant("HEAD_SIZE", params.conf.head_size)); | ||
jit.AddConstant(MakeJitConstant("HEADS_NUM", params.conf.heads_num)); | ||
jit.AddConstant(MakeJitConstant("KV_HEADS_NUM", params.conf.kv_heads_num)); | ||
jit.AddConstant(MakeJitConstant("PAGED_ATTENTION_BLOCK_SIZE", paged_attention_block_size)); | ||
jit.AddConstant(MakeJitConstant("SUBGROUP_SIZE", subgroup_size)); | ||
jit.AddConstant(MakeJitConstant("GENERATE_STAGE_BLOCK_SIZE", get_generate_stage_block_size(params.conf.head_size))); | ||
|
||
return jit; | ||
} | ||
|
||
CommonDispatchData PAScoresCalculation::SetDefault(const pa_scores_calculation& params) { | ||
CommonDispatchData dispatch_data; | ||
|
||
const auto& key_cache = params.outputs[0]; | ||
const auto& value_cache = params.outputs[1]; | ||
if (!value_cache.is_dynamic() && !key_cache.is_dynamic()) { | ||
auto heads_number = static_cast<size_t>(params.conf.kv_heads_num); | ||
|
||
// if (is_prefill) { | ||
// const auto blocks_number = params.conf.paged_attention_aligned_seq_len / paged_attention_block_size; | ||
|
||
// dispatch_data.gws = { blocks_number, | ||
// heads_number, | ||
// subgroup_size }; | ||
// dispatch_data.lws = { 1, 1, subgroup_size }; | ||
// } else { | ||
// const auto& key_input = params.inputs[0]; | ||
// const auto sequences_number = key_input.Batch().v; | ||
|
||
// dispatch_data.gws = { sequences_number, | ||
// heads_number, | ||
// subgroup_size }; | ||
// dispatch_data.lws = { 1, 1, subgroup_size }; | ||
// } | ||
} | ||
|
||
return dispatch_data; | ||
} | ||
|
||
void PAScoresCalculation::GetUpdateDispatchDataFunc(KernelData& kd) const { | ||
kd.update_dispatch_data_func = [](const Params& params, KernelData& kd) { | ||
const auto& prim_params = static_cast<const pa_scores_calculation&>(params); | ||
|
||
auto dispatch_data = SetDefault(prim_params); | ||
|
||
OPENVINO_ASSERT(kd.kernels.size() == 1, "[GPU] Invalid kernels size for update dispatch data func"); | ||
kd.kernels[0].params.workGroups.global = dispatch_data.gws; | ||
kd.kernels[0].params.workGroups.local = dispatch_data.lws; | ||
kd.kernels[0].skip_execution = false; | ||
|
||
const auto indexes_dt = Datatype::INT32; | ||
const auto target_seq_len_block_size = 16; | ||
const auto target_seq_len = prim_params.conf.paged_attention_aligned_seq_len; | ||
const auto indexes_buf_size = CeilDiv(target_seq_len, target_seq_len_block_size) * BytesPerElement(indexes_dt); | ||
|
||
kd.internalBufferSizes.clear(); | ||
kd.internalBufferSizes.push_back(indexes_buf_size); | ||
kd.internalBufferSizes.push_back(indexes_buf_size); | ||
kd.internalBufferSizes.push_back(indexes_buf_size); | ||
kd.internalBufferDataType = indexes_dt; | ||
|
||
// kd.kernels[0].params.scalars[0].v.s32 = static_cast<int32_t>(prim_params.is_prefill); | ||
}; | ||
} | ||
|
||
} // namespace kernel_selector |
32 changes: 32 additions & 0 deletions
32
src/plugins/intel_gpu/src/kernel_selector/kernels/sdpa/pa_scores_calculation_ref.h
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
// Copyright (C) 2024 Intel Corporation | ||
// SPDX-License-Identifier: Apache-2.0 | ||
// | ||
|
||
#pragma once | ||
|
||
#include "kernel_base_opencl.h" | ||
#include "sdpa_kernel_base.h" | ||
|
||
namespace kernel_selector { | ||
|
||
struct pa_scores_calculation : base_params { | ||
pa_scores_calculation() : base_params(KernelType::PA_SCORES_CALCULATION) {} | ||
|
||
sdpa_configuration conf; | ||
}; | ||
|
||
class PAScoresCalculation : public KernelBaseOpenCL { | ||
public: | ||
PAScoresCalculation() : KernelBaseOpenCL{"pa_scores_calc"} {} | ||
KernelsData GetKernelsData(const Params& params) const override; | ||
ParamsKey GetSupportedKey() const override; | ||
virtual ~PAScoresCalculation() {} | ||
|
||
protected: | ||
bool Validate(const Params& params) const override; | ||
JitConstants GetJitConstants(const pa_scores_calculation& kernel_params) const; | ||
static CommonDispatchData SetDefault(const pa_scores_calculation& kernel_params); | ||
void GetUpdateDispatchDataFunc(KernelData& kd) const override; | ||
}; | ||
|
||
} // namespace kernel_selector |
Oops, something went wrong.