Skip to content

Commit

Permalink
[GPU] Handle runtime scale value for PagedAttention
Browse files Browse the repository at this point in the history
  • Loading branch information
sshlyapn committed Oct 23, 2024
1 parent adeb3d2 commit 61bd104
Show file tree
Hide file tree
Showing 7 changed files with 79 additions and 14 deletions.
42 changes: 39 additions & 3 deletions src/plugins/intel_gpu/src/graph/impls/ocl/paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
instance.value_memory_ptr(),
instance.subsequence_begins_memory_ptr() };

if (!desc->scale_val.has_value()) {
args.inputs.push_back(instance.input_memory_ptr(9));
}

if (desc->has_alibi) {
args.inputs.push_back(instance.alibi_memory_ptr());
}
Expand All @@ -144,6 +148,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
args.inputs.push_back(instance.subsequence_begins_memory_ptr());
}

if (!desc->scale_val.has_value()) {
args.inputs.push_back(instance.input_memory_ptr(9));
}

if (desc->has_alibi) {
args.inputs.push_back(instance.alibi_memory_ptr());
}
Expand Down Expand Up @@ -343,8 +351,10 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
config.paged_attention_block_size = static_cast<int64_t>(paged_attention::block_size);

if (desc->scale_val.has_value()) {
config.has_scale_val = true;
config.has_const_scale_val = true;
config.scale_val = desc->scale_val.value();
} else {
config.has_const_scale_val = false;
}

if (desc->heads_num != desc->kv_heads_num) {
Expand Down Expand Up @@ -409,16 +419,22 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
}

static sdpa_kernel_params_t get_sdpa_kernel_params(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, bool is_dynamic = false) {
const auto desc = impl_param.typed_desc<paged_attention>();
auto params = get_default_params<sdpa_kernel_params_t>(impl_param, is_dynamic);

const auto& query_layout = impl_param.get_input_layout(0);
const auto& key_layout = impl_param.get_input_layout(1);
const auto& value_layout = impl_param.get_input_layout(2);
const auto& subsequence_begins_layout = impl_param.get_input_layout(6);
const auto& scale_layout = impl_param.get_input_layout(9);
const auto& alibi_layout = impl_param.get_input_layout(11);
const auto has_alibi = alibi_layout.count() > 0;
const auto has_scale_input = !desc->scale_val.has_value();

auto inputs_number = 4;
if (has_scale_input)
inputs_number++;

if (has_alibi)
inputs_number++;

Expand All @@ -429,6 +445,9 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
params.inputs[input_idx++] = convert_data_tensor(value_layout);
params.inputs[input_idx++] = convert_data_tensor(subsequence_begins_layout);

if (has_scale_input)
params.inputs[input_idx++] = convert_data_tensor(scale_layout);

if (has_alibi)
params.inputs[input_idx++] = convert_data_tensor(alibi_layout);

Expand All @@ -446,8 +465,12 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
{0, out_offsets_map.at(0)},
};

input_idx = 4;
if (has_scale_input)
in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(9)});

if (has_alibi)
in_tensor_to_offset_map.insert({4, in_offsets_map.at(11)});
in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(11)});

if ((stage == PagedAttentionStage::PREFILL || stage == PagedAttentionStage::MIXED) && !is_dynamic)
params.conf.paged_attention_aligned_seq_len = get_aligned_seq_len(impl_param, stage);
Expand All @@ -458,6 +481,7 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
}

static pa_sdpa_kernel_params_t get_pa_sdpa_params(const kernel_impl_params& impl_param, const PagedAttentionStage& stage, bool is_dynamic = false) {
const auto desc = impl_param.typed_desc<paged_attention>();
auto params = get_default_params<pa_sdpa_kernel_params_t>(impl_param, is_dynamic);

const auto& query_layout = impl_param.get_input_layout(0);
Expand All @@ -467,10 +491,15 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
const auto& block_indices_layout = impl_param.get_input_layout(7);
const auto& block_indices_begins_layout = impl_param.get_input_layout(8);
const auto& subsequence_begins_layout = impl_param.get_input_layout(6);
const auto& scale_layout = impl_param.get_input_layout(9);
const auto& alibi_layout = impl_param.get_input_layout(11);
const auto has_alibi = alibi_layout.count() > 0;
const auto has_scale_input = !desc->scale_val.has_value();

auto inputs_number = 7;
if (has_scale_input)
inputs_number++;

if (has_alibi)
inputs_number++;

Expand All @@ -485,6 +514,9 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
params.inputs[input_idx++] = convert_data_tensor(subsequence_begins_layout);
params.conf = get_sdpa_configuration(impl_param);

if (has_scale_input)
params.inputs[input_idx++] = convert_data_tensor(scale_layout);

if (has_alibi)
params.inputs[input_idx++] = convert_data_tensor(alibi_layout);

Expand Down Expand Up @@ -513,8 +545,12 @@ struct paged_attention_impl : multi_stage_primitive<paged_attention> {
{0, out_offsets_map.at(0)},
};

input_idx = 7;
if (has_scale_input)
in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(9)});

if (has_alibi)
in_tensor_to_offset_map.insert({7, in_offsets_map.at(11)});
in_tensor_to_offset_map.insert({input_idx++, in_offsets_map.at(11)});

params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,11 @@ KERNEL(pa_sdpa_opt)(
#if MULTI_TOKENS_PROCESSING
const __global INPUT6_TYPE* subsequence_begins,
#endif
#if HAS_SCALE_INPUT
const __global SCALE_INPUT_TYPE* scale,
#endif
#if HAS_ALIBI
const __global INPUT7_TYPE* alibi_slopes,
const __global ALIBI_INPUT_TYPE* alibi_slopes,
#endif
__global OUTPUT_TYPE* output,
__global SOFTMAX_ACCUMULATOR_TYPE* exp_sums,
Expand Down Expand Up @@ -117,6 +120,8 @@ KERNEL(pa_sdpa_opt)(
// Apply scale value directly to the query input to improve accuracy in case of a high range of input data
#ifdef SCALE_VAL
q_val = TO_INPUT0_TYPE(SCALE_VAL) * q_val;
#else
q_val = *scale * q_val;
#endif

slm_query[query_idx_local] = q_val;
Expand All @@ -133,6 +138,8 @@ KERNEL(pa_sdpa_opt)(
// Apply scale value directly to the query input to improve accuracy in case of a high range of input data
#ifdef SCALE_VAL
q_val[i] = TO_INPUT0_TYPE(SCALE_VAL) * q_val[i];
#else
q_val[i] = *scale * q_val[i];
#endif
}
#endif
Expand Down
14 changes: 11 additions & 3 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl
Original file line number Diff line number Diff line change
Expand Up @@ -656,6 +656,14 @@ inline MASK_VECTOR_TYPE FUNC(load_attn_mask)(OPTIONAL_SHAPE_INFO_ARG
return mask_vec;
}

#if IS_PAGED_ATTENTION && HAS_ALIBI
#if HAS_SCALE_INPUT
#define ALIBI_TYPE INPUT5_TYPE
#else
#define ALIBI_TYPE INPUT4_TYPE
#endif
#endif

REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
KERNEL(sdpa_opt)(
OPTIONAL_SHAPE_INFO_ARG
Expand All @@ -664,15 +672,15 @@ KERNEL(sdpa_opt)(
const __global INPUT2_TYPE* value_input,
#if IS_PAGED_ATTENTION
const __global INPUT3_TYPE* subsequence_begins,
#if HAS_ALIBI
const __global INPUT4_TYPE* alibi_slopes,
#endif
#endif
#if HAS_ATTN_MASK_INPUT
const __global INPUT3_TYPE* attn_mask,
#endif
#if HAS_SCALE_INPUT
const __global INPUT4_TYPE* scale,
#endif
#if IS_PAGED_ATTENTION && HAS_ALIBI
const __global ALIBI_TYPE* alibi_slopes,
#endif
__global OUTPUT_TYPE* output,
#ifdef BEAM_TABLE_TYPE
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -176,11 +176,19 @@ JitConstants PagedAttentionSDPAKernelOpt::GetJitConstants(const pa_sdpa_params&
auto sdpa_stage = kernel_idx == KernelsTypes::FINALIZATION || kernel_idx == KernelsTypes::FINALIZATION_MULTI_TOKENS ? 1 : 0;
jit.AddConstant(MakeJitConstant("SDPA_STAGE_" + std::to_string(sdpa_stage), 1));

if (config.has_scale_val)
if (config.has_const_scale_val) {
jit.AddConstant(MakeJitConstant("SCALE_VAL", config.scale_val));
} else {
const size_t scale_input_idx = 7;
jit.AddConstant(MakeJitConstant("HAS_SCALE_INPUT", 1));
jit.Merge(MakeTypeJitConstants(params.inputs[scale_input_idx].GetDType(), "SCALE_INPUT"));
}

if (params.conf.has_alibi_input)
if (params.conf.has_alibi_input) {
const size_t alibi_input_idx = config.has_const_scale_val ? 7 : 8;
jit.AddConstant(MakeJitConstant("HAS_ALIBI", 1));
jit.Merge(MakeTypeJitConstants(params.inputs[alibi_input_idx].GetDType(), "ALIBI_INPUT"));
}

if (kernel_idx == KernelsTypes::MULTI_TOKENS || kernel_idx == KernelsTypes::FINALIZATION_MULTI_TOKENS)
jit.AddConstant(MakeJitConstant("MULTI_TOKENS_PROCESSING", 1));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ struct sdpa_configuration {
bool is_paged_attention = false;
int64_t paged_attention_aligned_seq_len = -1;
int64_t paged_attention_block_size = 0;
bool has_scale_val = false;
bool has_const_scale_val = false;
float scale_val = 0.f;
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,12 @@ JitConstants SDPAKernelOpt::GetJitConstants(const sdpa_params& params, size_t ke
jit.AddConstant(MakeJitConstant("HAS_ALIBI", 1));
}

if (params.conf.has_scale_val) {
if (params.conf.has_const_scale_val) {
jit.AddConstant(MakeJitConstant("STATIC_SCALE_VALUE_INV", 1.0f / params.conf.scale_val));
jit.AddConstant(MakeJitConstant("STATIC_SCALE_VALUE", params.conf.scale_val));
} else {
GPU_DEBUG_TRACE_DETAIL << "HAS_SCALE_INPUT = 1\n";
jit.AddConstant(MakeJitConstant("HAS_SCALE_INPUT", 1));
}
} else if (params.inputs.size() <= 4) {
jit.AddConstant(MakeJitConstant("STATIC_SCALE_VALUE_INV", std::sqrt(static_cast<float>(params.conf.head_size))));
Expand Down
9 changes: 6 additions & 3 deletions src/plugins/intel_gpu/src/plugin/ops/paged_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,12 @@ static void CreatePagedAttentionExtensionOp(ProgramBuilder& p, const std::shared
const size_t alibi_idx = 11;

std::shared_ptr<ov::op::v0::Constant> scale_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(op->get_input_node_shared_ptr(scale_idx));
OPENVINO_ASSERT(scale_const != nullptr);
OPENVINO_ASSERT(ov::shape_size(scale_const->get_output_shape(0)) == 1);
prim.scale_val = scale_const->cast_vector<float>()[0];
if (scale_const) {
OPENVINO_ASSERT(ov::shape_size(scale_const->get_output_shape(0)) == 1);
prim.scale_val = scale_const->cast_vector<float>()[0];
} else {
prim.scale_val = cldnn::optional_value<float>();
}

std::shared_ptr<ov::op::v0::Constant> alibi_const = std::dynamic_pointer_cast<ov::op::v0::Constant>(op->get_input_node_shared_ptr(alibi_idx));
OPENVINO_ASSERT(alibi_const != nullptr);
Expand Down

0 comments on commit 61bd104

Please sign in to comment.