WIP: [GPU] KV-cache compression micro_sdpa kernel
sshlyapn committed Nov 26, 2024
1 parent f6e0ba0 commit cec52fd
Showing 18 changed files with 463 additions and 88 deletions.
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/graph/dynamic_quantize.cpp
@@ -40,6 +40,7 @@ std::vector<layout> dynamic_quantize_inst::__calc_output_layouts(const layout &a
 
     if (attrs.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric &&
         attrs.output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::Planar) {
+        GPU_DEBUG_TRACE_DETAIL << "Set 3d output: " << layout(output_shapes[2], attrs.zp_dt, output_format).to_short_string() << "\n";
         output_layouts.emplace_back(layout(output_shapes[2], attrs.zp_dt, output_format));
     }
 
@@ -348,6 +348,7 @@ void kernels_cache::build_batch(const batch_program& batch, compiled_kernels& co
 
     // Run compilation
     if (precompiled_kernels.empty()) {
+        GPU_DEBUG_TRACE_DETAIL << "Compiling " << batch.kernels_counter << " " << batch.has_microkernels << "\n";
        cl::Program program(cl_build_device.get_context(), batch.source);
        {
            OV_ITT_SCOPED_TASK(ov::intel_gpu::itt::domains::intel_gpu_plugin, "KernelsCache::BuildProgram::RunCompilation");
42 changes: 30 additions & 12 deletions src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp
@@ -230,7 +230,7 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
 
         if (desc->get_compression_zp_inputs_num() > 0) {
             // Copy zero points to the new buffer if needed
-            execute_stage(events, instance, res_events, scale_concat_stage, zp_concat_stage);
+            execute_stage(events, instance, res_events, zp_concat_stage, zp_concat_stage);
         }
 
         // Perform dynamic quantization of new token data and append result to the KV-cache
@@ -417,15 +417,19 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
         return params;
     }
 
-    static kernel_params_t get_compression_scale_update_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) {
+    static kernel_params_t get_compression_scale_update_kernel_params(const kernel_impl_params& impl_param,
+                                                                      bool is_scale = true,
+                                                                      bool is_shape_agnostic = false) {
         auto params = get_default_params<kernel_selector::concatenation_params>(impl_param, is_shape_agnostic);
 
         const auto concat_axis = 2;
         params.axis = convert_axis(concat_axis, impl_param.get_output_layout().get_rank());
 
-        auto inputs_count = 1;
-        auto comp_scale_past_layout = impl_param.input_layouts[3];
-        auto comp_scale_present_layout = impl_param.output_layouts[2];
+        const auto inputs_count = 1;
+        const auto input_idx = is_scale ? 3 : 4; // scale or zp
+        const auto output_idx = is_scale ? 2 : 3; // scale or zp
+        auto comp_scale_past_layout = impl_param.input_layouts[input_idx];
+        auto comp_scale_present_layout = impl_param.output_layouts[output_idx];
 
         params.inputs.resize(inputs_count);
         params.inputs[0] = convert_data_tensor(comp_scale_past_layout);
@@ -435,10 +439,10 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
         const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset;
 
         std::map<size_t, size_t> in_tensor_to_offset_map = {
-            {0, in_offsets_map.at(3)}, // compression_scale_past
+            {0, in_offsets_map.at(input_idx)}, // compression_[scale/zp]_past
         };
         std::map<size_t, size_t> out_tensor_to_offset_map = {
-            {0, out_offsets_map.at(2)}, // compression_scale_present
+            {0, out_offsets_map.at(output_idx)}, // compression_[scale/zp]_present
         };
 
         params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map);
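
The helper above now serves both compression tensors; as used at the call sites later in this commit, is_scale selects which past/present pair the concat kernel operates on:

    // is_scale == true  -> past scale = input_layouts[3], present scale = output_layouts[2]
    // is_scale == false -> past zp    = input_layouts[4], present zp    = output_layouts[3]
    auto scale_params = get_compression_scale_update_kernel_params(impl_param, true,  impl_param.is_dynamic());
    auto zp_params    = get_compression_scale_update_kernel_params(impl_param, false, impl_param.is_dynamic());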
@@ -451,8 +455,11 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
         auto concat_kernel_params = get_concat_kernel_params(impl_param, impl_param.is_dynamic());
         auto& concat_kernel_selector = kernel_selector_t::Instance();
         kernels_data.push_back(concat_kernel_selector.get_best_kernel(concat_kernel_params));
-        const bool indirect = impl_param.typed_desc<kv_cache>()->indirect;
-        const bool compressed = impl_param.typed_desc<kv_cache>()->compressed;
+
+        const auto desc = impl_param.typed_desc<kv_cache>();
+        const bool indirect = desc->indirect;
+        const bool compressed = desc->compressed;
+        const bool has_zp_input = desc->get_compression_zp_inputs_num() > 0;
         if (indirect) {
             auto bt_update_kernel_params = get_bt_update_kernel_params(impl_param, false);
             auto& bt_update_kernel_selector = bt_kernel_selector_t::Instance();
@@ -464,9 +471,14 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
             auto& dq_kernel_selector = dq_kernel_selector_t::Instance();
             kernels_data.push_back(dq_kernel_selector.get_best_kernel(dq_kernel_params));
 
-            auto concat_scale_zp_kernel_params = get_compression_scale_update_kernel_params(impl_param, impl_param.is_dynamic());
             auto& concat_scale_zp_kernel_selector = kernel_selector_t::Instance();
-            kernels_data.push_back(concat_scale_zp_kernel_selector.get_best_kernel(concat_scale_zp_kernel_params));
+            auto concat_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, true, impl_param.is_dynamic());
+            kernels_data.push_back(concat_scale_zp_kernel_selector.get_best_kernel(concat_scale_kernel_params));
+
+            if (has_zp_input) {
+                auto concat_zp_kernel_params = get_compression_scale_update_kernel_params(impl_param, false, impl_param.is_dynamic());
+                kernels_data.push_back(concat_scale_zp_kernel_selector.get_best_kernel(concat_zp_kernel_params));
+            }
         }
         return cldnn::make_unique<kv_cache_impl>(kernels_data);
     }
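
Because kernels_data is filled in a fixed order, each optional stage shifts the indices of the stages behind it, and the zp concat stage exists only when the primitive has zero-point inputs. A self-contained sketch of that ordering rule (hypothetical helper, not part of the commit; the real code uses the named stage constants concat_stage, scale_concat_stage and zp_concat_stage seen elsewhere in this diff):

#include <cstddef>
#include <cstdint>

// Mirrors the push_back order in get_kernels_data above: concat first, then
// the optional beam-table, dynamic-quantization, scale-concat and zp-concat
// stages. SIZE_MAX marks a stage that is not instantiated.
struct stage_indices {
    size_t concat = 0;
    size_t beam_table = SIZE_MAX;
    size_t dq = SIZE_MAX;
    size_t scale_concat = SIZE_MAX;
    size_t zp_concat = SIZE_MAX;
};

stage_indices make_stage_indices(bool indirect, bool compressed, bool has_zp_input) {
    stage_indices s;
    size_t next = 1;
    if (indirect)
        s.beam_table = next++;
    if (compressed) {
        s.dq = next++;
        s.scale_concat = next++;
        if (has_zp_input)
            s.zp_concat = next++;
    }
    return s;
}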
@@ -494,9 +506,15 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
             _kernels_data[concat_stage].kernels[1].skip_execution = true;
 
             // Update dynamic quantization parameters
-            auto comp_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, impl_param.is_dynamic());
+            auto comp_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, true, impl_param.is_dynamic());
             (_kernels_data[scale_concat_stage].update_dispatch_data_func)(comp_scale_kernel_params, _kernels_data[scale_concat_stage]);
             _kernels_data[scale_concat_stage].kernels[0].skip_execution = impl_param._can_be_optimized || impl_param.get_input_layout(3).count() == 0;
+
+            if (impl_param.typed_desc<kv_cache>()->get_compression_zp_inputs_num() > 0) {
+                auto comp_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, false, impl_param.is_dynamic());
+                (_kernels_data[zp_concat_stage].update_dispatch_data_func)(comp_scale_kernel_params, _kernels_data[zp_concat_stage]);
+                _kernels_data[zp_concat_stage].kernels[0].skip_execution = impl_param._can_be_optimized || impl_param.get_input_layout(4).count() == 0;
+            }
         }
     }
 };
@@ -133,6 +133,20 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
         for (size_t i = 0; i < instance.get_intermediates_memories().size(); i++)
             args.intermediates.push_back(instance.get_intermediates_memories()[i]);
 
+        GPU_DEBUG_TRACE_DETAIL << "Configured kernel arguments:\n";
+        for (size_t i = 0; i < _kernels_data[stage].kernels[kd_idx].params.arguments.size(); i++) {
+            GPU_DEBUG_TRACE_DETAIL << "\t" << i << ": type=" << static_cast<int>(_kernels_data[stage].kernels[kd_idx].params.arguments[i].t) << " "
+                                   << "index=" << _kernels_data[stage].kernels[kd_idx].params.arguments[i].index << "\n";
+        }
+
+        GPU_DEBUG_TRACE_DETAIL << "Memory buffers:"
+                               << "shape_info=" << args.shape_info << " "
+                               << "inputs=" << args.inputs.size() << " "
+                               << "outputs=" << args.outputs.size() << " "
+                               << "intermediates=" << args.intermediates.size() << " "
+                               << "weights=" << args.weights << " "
+                               << "scalars=" << (args.scalars ? args.scalars->size() : 0) << "\n";
+
         stream.set_arguments(*_kernels[idx_final], _kernels_data[stage].kernels[kd_idx].params, args);
 
         const auto& gws = params.workGroups.global;
@@ -241,7 +255,7 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
         if (has_indirect_inputs(impl_param))
             data_inputs_num--;
 
-        auto has_zp_input_buffers = false;
+        auto has_zp_input_buffers = desc->get_compression_zp_inputs_num() > 0;
         if (desc->is_kv_compressed) {
             data_inputs_num -= 2; // key and value compression scales are handled separately
 
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/graph/include/read_value_inst.h
@@ -37,6 +37,8 @@ class typed_primitive_inst<read_value> : public typed_primitive_inst_base<read_v
 
         for (size_t i = 0; i < desc->num_outputs; i++) {
             const auto& default_layout = desc->output_layouts[i];
+            // if (impl_param.state_layouts.size() <= i)
+            //     std::cout << "Use default layout\n";
             output_layouts.push_back(impl_param.state_layouts.size() > i ? impl_param.state_layouts[i] : default_layout);
         }
 
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/src/graph/kv_cache.cpp
@@ -68,6 +68,7 @@ std::vector<layout> kv_cache_inst::calc_output_layouts(kv_cache_node const& /*no
     for (size_t i = 0; i < desc->num_outputs; i++) {
         auto out_type = desc->output_data_types[i].value_or(impl_param.get_input_layout(ports_map.at(i)).data_type);
         out_layouts.emplace_back(output_shapes[i], out_type, impl_param.get_output_layout(i).format);
+        GPU_DEBUG_TRACE_DETAIL << "NEW: kv_cache " << i << ": " << output_shapes[i] << " " << out_type << "\n";
     }
 
     return out_layouts;
9 changes: 7 additions & 2 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -343,6 +343,7 @@ void primitive_inst::update_shape() {
             if (compressed_cache_variable->has_zp_state()) {
                 auto scales_state = compressed_cache_variable->get_compression_zp_state();
                 auto new_zp_layout = compressed_cache_variable->get_compression_zp_state()->get_layout();
+                GPU_DEBUG_TRACE_DETAIL << "NEW: Update state_layouts:" << new_zp_layout << "\n";
                 update_state_layout(*scales_state, new_zp_layout, 2);
             }
         }
@@ -969,8 +970,9 @@ void primitive_inst::realloc_if_needed() {
             compressed_cache_variable->get_compression_scale_state()->set_memory(_outputs[2], present_scales_layout);
             if (compressed_cache_variable->has_zp_state()) {
                 auto present_zp_layout = present_scales_layout;
+                present_zp_layout.data_type = _impl_params->output_layouts[3].data_type;
 
-                _impl_params->output_layouts[3] = present_scales_layout;
+                _impl_params->output_layouts[3] = present_zp_layout;
                 compressed_cache_variable->get_compression_zp_state()->set_memory(_outputs[3], present_zp_layout);
             }
         }
@@ -1360,7 +1362,7 @@ void primitive_inst::do_runtime_in_place_kv_cache() {
             GPU_DEBUG_TRACE_DETAIL << "[do runtime_in_place_kv_cache] " << id()
                                    << " Updated present_zp_layout's pad : " << present_scales_layout.to_string() << std::endl;
 
-            compressed_cache_variable->get_compression_zp_state()->set_layout(present_scales_layout);
+            compressed_cache_variable->get_compression_zp_state()->set_layout(present_zp_layout);
         }
     }
 
@@ -2076,6 +2078,9 @@ primitive_inst::primitive_inst(network & network, program_node const& node, bool
             _outputs = allocate_outputs();
         }
     }
+    if (_node) {
+        GPU_DEBUG_TRACE_DETAIL << _node->type()->to_string(*_node) << "\n";
+    }
     _impls_factory = std::make_shared<ImplementationsFactory>(_node);
     _impl_params->strm = _network.get_stream_ptr();
     for (size_t i = 0; i < get_node().get_output_layouts().size(); ++i) {
7 changes: 5 additions & 2 deletions src/plugins/intel_gpu/src/graph/program.cpp
@@ -223,8 +223,11 @@ void program::init_program() {
     pm = std::unique_ptr<pass_manager>(new pass_manager(*this));
     new_shape_infer = _config.get_property(ov::intel_gpu::allow_new_shape_infer);
 
-    if (_task_executor == nullptr)
-        _task_executor = program::make_task_executor(_config);
+    if (true) {
+        auto config = _config;
+        config.set_property(ov::compilation_num_threads(1));
+        _task_executor = program::make_task_executor(config);
+    }
     _kernels_cache = std::unique_ptr<kernels_cache>(new kernels_cache(_engine, _config, prog_id, _task_executor,
                                                     kernel_selector::KernelBase::get_db().get_batch_headers()));
 
@@ -84,7 +84,7 @@ KERNEL(dynamic_quantize_gpu_kv_cache)(
     min_value = work_group_reduce_min(min_value);
     max_value = work_group_reduce_max(max_value);
     ACCUMULATOR_TYPE scale = (ACCUMULATOR_TYPE)((CHAR_MAX - CHAR_MIN) / (max_value - min_value));
-    ACCUMULATOR_TYPE zp = (ACCUMULATOR_TYPE)(-min_value * scale) - CHAR_MAX;
+    ACCUMULATOR_TYPE zp = (ACCUMULATOR_TYPE)(-min_value * scale) + CHAR_MIN;
 #else
     max_value = work_group_reduce_max(max_value);
     ACCUMULATOR_TYPE scale = 127.0h / max_value;
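
The sign of the zero-point offset is the actual fix here: with scale = (CHAR_MAX - CHAR_MIN) / (max_value - min_value), the quantized value q(x) = x * scale + zp has to span exactly [CHAR_MIN, CHAR_MAX]; the old "- CHAR_MAX" variant shifted the range to [-127, 128] and overflowed signed char at the top. A standalone check (not part of the commit; the sample range is arbitrary):

#include <cassert>
#include <climits>
#include <cmath>
#include <cstdio>

int main() {
    const float min_value = -3.5f, max_value = 2.25f;  // arbitrary sample range
    const float scale = (CHAR_MAX - CHAR_MIN) / (max_value - min_value);

    const float zp_new = -min_value * scale + CHAR_MIN;  // fixed formula
    const float zp_old = -min_value * scale - CHAR_MAX;  // previous formula

    // The fixed offset maps the range ends exactly onto the int8 limits.
    assert(std::lround(min_value * scale + zp_new) == CHAR_MIN);  // -128
    assert(std::lround(max_value * scale + zp_new) == CHAR_MAX);  //  127

    // The old offset pushed the top of the range one past CHAR_MAX.
    std::printf("old q(max) = %ld\n", std::lround(max_value * scale + zp_old));  // prints 128
    return 0;
}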
@@ -112,7 +112,11 @@ KERNEL(dynamic_quantize_gpu_kv_cache)(
 #if GROUP_SCALES_WITH_ZP
     output_scale[scale_idx + 1] = zp;
 #else
+#if OUTPUT2_IS_FP
     output_zp[scale_idx] = zp;
+#else
+    output_zp[scale_idx] = convert_char_rte(zp);
+#endif
 #endif
 #else
     output_scale[scale_idx] = 1.0h / scale;
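
The new OUTPUT2_IS_FP branch stores the zero point in its natural floating-point form when the zp output tensor is FP, and only rounds to int8 otherwise. A host-side sketch of the two paths (hypothetical helpers, not part of the commit; convert_char_rte rounds to nearest-even, which std::nearbyint matches in the default rounding mode):

#include <cmath>
#include <cstdint>

float store_zp_fp(float zp) {
    return zp;  // FP zp tensor keeps the exact value
}

int8_t store_zp_i8(float zp) {
    return static_cast<int8_t>(std::nearbyint(zp));  // round-to-nearest-even, like convert_char_rte
}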
@@ -88,7 +88,7 @@ KERNEL(dynamic_quantize_gpu_ref)(
 
 #if ASYMMETRIC_QUANTIZATION
     OUTPUT1_TYPE scale = (OUTPUT1_TYPE)((CHAR_MAX - CHAR_MIN) / (max_val - min_val));
-    OUTPUT1_TYPE zp = (OUTPUT1_TYPE)(-min_val * scale) - CHAR_MAX;
+    OUTPUT1_TYPE zp = (OUTPUT1_TYPE)(-min_val * scale) + CHAR_MIN;
 #else
     max_val = work_group_reduce_max(max_val);
     OUTPUT1_TYPE scale = 127.0h / max_val;
@@ -145,6 +145,10 @@ KERNEL(dynamic_quantize_gpu_ref)(
 #if ASYMMETRIC_QUANTIZATION && GROUP_SCALES_WITH_ZP
     output_scale[scale_idx + 1] = zp;
 #elif ASYMMETRIC_QUANTIZATION
-    output_zp[scale_idx] = zp;
+#if OUTPUT2_IS_FP
+    output_zp[scale_idx] = zp;
+#else
+    output_zp[scale_idx] = convert_char_rte(zp);
+#endif
 #endif
 }