WIP: fusion of KV cache and DQ
sshlyapn committed Oct 11, 2024
1 parent b7cfd67 commit 506233c
Showing 11 changed files with 244 additions and 49 deletions.
19 changes: 17 additions & 2 deletions src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp
@@ -38,10 +38,12 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension {
const Output<Node>& new_token_data,
const Output<Node>& beam_idx,
const Output<Node>& past_scale,
const Output<Node>& new_token_scale,
const std::shared_ptr<ov::op::util::Variable>& past_values,
int64_t concat_axis,
int64_t gather_axis,
const ov::element::Type compression_type,
const std::vector<uint64_t>& group_sizes,
const std::vector<uint64_t>& scales_output_order,
const ov::element::Type output_type = ov::element::undefined);

bool visit_attributes(ov::AttributeVisitor& visitor) override;
@@ -62,17 +64,30 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension {
void set_gather_axis(int64_t axis) { m_gather_axis = axis; }

bool get_indirect() const { return m_indirect; }

bool get_compressed() const { return m_compressed; }
ov::element::Type get_compression_type() const { return m_compression_type; }
const std::vector<uint64_t>& get_group_sizes() const { return m_group_sizes; };
const std::vector<uint64_t>& get_scales_output_order() const { return m_scales_output_order; };

private:
int64_t m_concat_axis = 0;
int64_t m_gather_axis = 0;
bool m_indirect = false;

// KV-cache compression parameters
bool m_compressed = false;
std::vector<uint64_t> m_group_sizes = {};
std::vector<uint64_t> m_scales_output_order = {};
ov::element::Type m_compression_type = ov::element::undefined;

ov::element::Type m_output_type;
};

std::vector<ov::PartialShape> shape_infer(const KVCache* op, std::vector<ov::PartialShape> input_shapes);
std::vector<ov::PartialShape> shape_infer(const KVCache* op,
const std::vector<ov::PartialShape>& input_shapes,
const std::vector<uint64_t>& group_sizes = {},
const std::vector<uint64_t>& scales_output_order = {});

} // namespace op
} // namespace intel_gpu
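The new shape_infer overload takes group_sizes and scales_output_order so that the shape of the compression-scales output can be derived from the KV-cache data shape. Below is a minimal editorial sketch of how such a derivation could look; the helper name infer_scales_shape is hypothetical and the code is not part of this commit. It assumes the usual group-quantization convention (a dimension quantized with group size G contributes dim / G scale values, group size 1 leaves the dimension untouched) and treats scales_output_order as a permutation of the resulting dimensions.

#include <cstdint>
#include <vector>
#include "openvino/core/except.hpp"
#include "openvino/core/partial_shape.hpp"

// Hypothetical sketch -- not part of this commit.
ov::PartialShape infer_scales_shape(const ov::PartialShape& data_shape,
                                    const std::vector<uint64_t>& group_sizes,
                                    const std::vector<uint64_t>& scales_output_order) {
    OPENVINO_ASSERT(data_shape.rank().is_static() &&
                    group_sizes.size() == static_cast<size_t>(data_shape.rank().get_length()),
                    "group_sizes must provide one entry per data dimension");
    ov::PartialShape scales_shape = data_shape;
    for (size_t i = 0; i < group_sizes.size(); ++i) {
        // A dimension grouped with size G holds dim / G scale values.
        if (group_sizes[i] > 1 && scales_shape[i].is_static())
            scales_shape[i] = scales_shape[i].get_length() / static_cast<int64_t>(group_sizes[i]);
    }
    if (!scales_output_order.empty()) {
        ov::PartialShape permuted = scales_shape;
        for (size_t i = 0; i < scales_output_order.size(); ++i)
            permuted[i] = scales_shape[scales_output_order[i]];
        scales_shape = permuted;
    }
    return scales_shape;
}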
@@ -35,6 +35,9 @@ struct kv_cache : public primitive_base<kv_cache> {
int64_t gather_axis = 0;
bool indirect = false;
bool compressed = false;
std::vector<uint64_t> group_sizes = {};
std::vector<uint64_t> scales_output_order = {};
ov::element::Type compression_type = ov::element::undefined;

size_t hash() const override {
size_t seed = primitive::hash();
@@ -56,6 +59,7 @@ struct kv_cache : public primitive_base<kv_cache> {
gather_axis == rhs_casted.gather_axis &&
indirect == rhs_casted.indirect &&
compressed == rhs_casted.compressed;
// TODO: add here
}

void save(BinaryOutputBuffer& ob) const override {
@@ -68,6 +72,7 @@ struct kv_cache : public primitive_base<kv_cache> {
ob << gather_axis;
ob << indirect;
ob << compressed;
// TODO: add here
}

void load(BinaryInputBuffer& ib) override {
@@ -83,6 +88,7 @@ struct kv_cache : public primitive_base<kv_cache> {
ib >> gather_axis;
ib >> indirect;
ib >> compressed;
// TODO: add here
}
};
} // namespace cldnn
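For reference, the three "// TODO: add here" markers above will eventually have to cover the newly added members (group_sizes, scales_output_order, compression_type). A possible completion is sketched below purely as an illustration; it assumes the new fields follow the same BinaryOutputBuffer / BinaryInputBuffer patterns as the existing members, and that the std::vector<uint64_t> stream overloads and the make_data helper behave as they do for the fields already serialized. None of this code is part of the commit.

// Hypothetical completion of the TODOs -- not part of this commit.
// operator== (chained onto the existing comparison before its trailing semicolon):
    group_sizes == rhs_casted.group_sizes &&
    scales_output_order == rhs_casted.scales_output_order &&
    compression_type == rhs_casted.compression_type;

// save():
ob << group_sizes;
ob << scales_output_order;
ob << make_data(&compression_type, sizeof(compression_type));

// load():
ib >> group_sizes;
ib >> scales_output_order;
ib >> make_data(&compression_type, sizeof(compression_type));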
120 changes: 105 additions & 15 deletions src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp
@@ -14,6 +14,8 @@
#include "concatenation/concatenation_kernel_base.h"
#include "beam_table_update/beam_table_update_kernel_selector.hpp"
#include "beam_table_update/beam_table_update_kernel_ref.hpp"
#include "dynamic_quantize/dynamic_quantize_kernel_selector.h"
#include "dynamic_quantize/dynamic_quantize_kernel_opt_generic.h"
#include "openvino/core/dimension.hpp"

namespace cldnn {
@@ -58,6 +60,9 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
using bt_kernel_selector_t = kernel_selector::beam_table_update_kernel_selector;
using bt_kernel_params_t = kernel_selector::beam_table_update_params;

using dq_kernel_selector_t = kernel_selector::dynamic_quantize_kernel_selector;
using dq_kernel_params_t = kernel_selector::dynamic_quantize_params;

DECLARE_OBJECT_TYPE_SERIALIZATION(cldnn::ocl::kv_cache_impl)

std::unique_ptr<primitive_impl> clone() const override {
@@ -66,7 +71,8 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {

const size_t concat_stage = 0;
const size_t beam_table_stage = 1;
const size_t scale_stage = 2;
const size_t scale_concat_stage = 2;
const size_t dq_concat_stage = 3;

cldnn::memory::ptr beam_table_prev = nullptr;
cldnn::memory::ptr beam_table_new = nullptr;
@@ -86,8 +92,8 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
// FIXME: indirectness and compression are orthogonal features.
if (_kernels_data.size() == 3) {
auto& scale_kernel_selector = kernel_selector_t::Instance();
auto scale_kernel_impl = scale_kernel_selector.GetImplementation(_kernels_data[scale_stage].kernelName);
scale_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[scale_stage]);
auto scale_kernel_impl = scale_kernel_selector.GetImplementation(_kernels_data[scale_concat_stage].kernelName);
scale_kernel_impl->GetUpdateDispatchDataFunc(_kernels_data[scale_concat_stage]);
}
}
}
@@ -102,10 +108,13 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
} else if (stage == beam_table_stage) {
args.inputs = { beam_table_prev, instance.input_memory_ptr(2) };
args.outputs = { beam_table_new };
} else if (stage == scale_stage) {
} else if (stage == scale_concat_stage) {
// FIXME: indirectness and compression are orthogonal features.
args.inputs = { instance.input_memory_ptr(3), instance.input_memory_ptr(4) }; // [past, new, beam_table, past_scale, new_scale]
args.outputs = { compression_scale };
} else if (stage == dq_concat_stage) {
args.inputs = { instance.input_memory_ptr(1) }; // [past, new, beam_table, past_scale, new_scale]
args.outputs = { instance.output_memory_ptr(0), instance.output_memory_ptr(2) };
}

return args;
@@ -160,6 +169,14 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
auto& variable = instance.get_network().get_variable(desc->variable_info.variable_id);
std::vector<event::ptr> res_events;

if (desc->compressed) {
// When KV-cache compression is enabled, skip the concat's second kernel: appending the new token data
// is handled by the dynamic quantization kernel instead.
// The first kernel is still allowed to run for the cases where the KV-cache optimization cannot be applied
// (the optimization is disabled, or the variable's memory was reallocated and the past KV-cache has to be copied into the new memory).
_kernels_data[concat_stage].kernels[1].skip_execution = true;
}

execute_stage(events, instance, res_events, concat_stage);

const auto& impl_param = *instance.get_impl_params();
@@ -232,11 +249,19 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
GPU_DEBUG_TRACE_DETAIL << "Override Variable memory\n";
comp_scale_state->set_memory(compression_scale, instance.get_impl_params()->output_layouts[2]);

auto comp_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, comp_scale_state->is_set());
(_kernels_data[scale_stage].update_dispatch_data_func)(comp_scale_kernel_params, _kernels_data[scale_stage]);
_kernels_data[scale_stage].kernels[0].skip_execution = skip_first_kernel;
if (!skip_first_kernel) {
auto comp_scale_kernel_params = get_compression_scale_update_kernel_params(impl_param, comp_scale_state->is_set());
(_kernels_data[scale_concat_stage].update_dispatch_data_func)(comp_scale_kernel_params, _kernels_data[scale_concat_stage]);
_kernels_data[scale_concat_stage].kernels[0].skip_execution = skip_first_kernel;
}

execute_stage(events, instance, res_events, scale_concat_stage);


auto dq_params = get_dq_update_kernel_params(impl_param, impl_param.is_dynamic());
(_kernels_data[dq_concat_stage].update_dispatch_data_func)(dq_params, _kernels_data[dq_concat_stage]);
execute_stage(events, instance, res_events, dq_concat_stage);

execute_stage(events, instance, res_events, scale_stage);
comp_scale_state->set();
}

@@ -293,7 +318,9 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
const auto inputs_count = 2;
params.inputs.resize(inputs_count);
for (size_t i = 0; i < inputs_count; ++i) {
params.inputs[i] = convert_data_tensor(impl_param.input_layouts[i]);
auto tmp = impl_param.input_layouts[i];
tmp.data_type = data_types::i8;
params.inputs[i] = convert_data_tensor(tmp);
}

params.axis = convert_axis(axis, impl_param.get_output_layout().get_rank());
@@ -354,8 +381,8 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
params.indirect_axis = indirect_axis;

const bool compressed = impl_param.typed_desc<kv_cache>()->compressed;
const auto beam_table_past_idx = compressed ? 5 : 3;
const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset; // [kv_past, kv_new_token, [beam_idx, compression_scale_past, compression_scale_new, beam_table_past]]
const auto beam_table_past_idx = compressed ? 4 : 3;
const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset; // [kv_past, kv_new_token, [beam_idx, compression_scale_past, beam_table_past]]
const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset; // [kv_present, beam_table_present, compression_scale_present]
std::map<size_t, size_t> in_tensor_to_offset_map = {
{0, in_offsets_map.at(beam_table_past_idx)}, // beam_table_past
@@ -370,6 +397,55 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
return params;
}

static dq_kernel_params_t get_dq_update_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) {
const auto& primitive = impl_param.typed_desc<kv_cache>();
auto params = get_default_params<dq_kernel_params_t>(impl_param, is_shape_agnostic);

params.append_axis = primitive->concat_axis;
params.group_sizes = primitive->group_sizes;
params.scales_output_order = primitive->scales_output_order;

if (!is_shape_agnostic) {
const auto& past_kv_cache_shape = impl_param.input_layouts[0].get_partial_shape();
params.axis_offset = past_kv_cache_shape[primitive->concat_axis].get_length();
} else {
params.axis_offset = 0;
}

auto inputs_count = 1;
auto outputs_count = 2;
params.inputs.resize(inputs_count);
params.outputs.resize(outputs_count);

auto current_token_layout = impl_param.input_layouts[1];
auto present_layout = impl_param.output_layouts[0];
auto present_scales_layout = impl_param.output_layouts[2];
params.inputs[0] = convert_data_tensor(current_token_layout);
params.outputs[0] = convert_data_tensor(present_layout);
params.outputs[1] = convert_data_tensor(present_scales_layout);

const auto& in_offsets_map = impl_param.in_port_to_shape_info_offset;
const auto& out_offsets_map = impl_param.out_port_to_shape_info_offset;

// FIXME: need to handle the index properly when indirect is off
std::map<size_t, size_t> in_tensor_to_offset_map = {
{0, in_offsets_map.at(1)}, // kv_new_token
};
std::map<size_t, size_t> out_tensor_to_offset_map = {
{0, out_offsets_map.at(0)}, // kv_present
{1, out_offsets_map.at(2)}, // compression_scale_present
};

GPU_DEBUG_TRACE_DETAIL << "DQ shapes: " << current_token_layout.to_short_string() << " " << present_layout.to_short_string() << " " << present_scales_layout.to_short_string() << "\n";
GPU_DEBUG_TRACE_DETAIL << "DQ: Dynamic shape in0 " << in_offsets_map.at(1) << "\n";
GPU_DEBUG_TRACE_DETAIL << "DQ: Dynamic shape out " << out_offsets_map.at(0) << "\n";
GPU_DEBUG_TRACE_DETAIL << "DQ: Dynamic shape out " << out_offsets_map.at(2) << "\n";
params.set_dynamic_shape_offsets(in_tensor_to_offset_map, out_tensor_to_offset_map);

return params;
}


static kernel_params_t get_compression_scale_update_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) {
const auto& primitive = impl_param.typed_desc<kv_cache>();
auto params = get_default_params<kernel_selector::concatenation_params>(impl_param, is_shape_agnostic);
@@ -379,7 +455,8 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {

auto inputs_count = 2;
auto comp_scale_past_layout = impl_param.input_layouts[3];
auto comp_scale_new_layout = impl_param.input_layouts[4];
auto comp_scale_new_layout = impl_param.input_layouts[4]; // <-- this should be replaced with inner layout

auto comp_scale_present_layout = impl_param.output_layouts[2];

GPU_DEBUG_TRACE_DETAIL << "Past scale: " << comp_scale_past_layout.to_short_string() << "\n";
Expand Down Expand Up @@ -411,24 +488,37 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
return params;
}


static std::unique_ptr<primitive_impl> create(const typed_program_node<kv_cache>& arg, const kernel_impl_params& impl_param) {
std::vector<kernel_selector::kernel_data> kernels_data;
// if (arg.id().find("kvcache:__module.model.transformer.h.0.attn/aten::cat/Concat_4") != std::string::npos)
// std::cout << "mingyuki: create " << arg.id() << std::endl;
GPU_DEBUG_TRACE_DETAIL << "KVCACHE Select concat\n";
GPU_DEBUG_TRACE_DETAIL << "KVCACHE Select concat\n";
auto concat_kernel_params = get_concat_kernel_params(impl_param, impl_param.is_dynamic());
auto& concat_kernel_selector = kernel_selector_t::Instance();
kernels_data.push_back(concat_kernel_selector.get_best_kernel(concat_kernel_params));
const bool indirect = impl_param.typed_desc<kv_cache>()->indirect;
const bool compressed = impl_param.typed_desc<kv_cache>()->compressed;
GPU_DEBUG_TRACE_DETAIL << "KVCACHE Select beam table\n";
GPU_DEBUG_TRACE_DETAIL << "KVCACHE Select beam table\n";
if (indirect) {
auto bt_update_kernel_params = get_bt_update_kernel_params(impl_param, false);
auto& bt_update_kernel_selector = bt_kernel_selector_t::Instance();
kernels_data.push_back(bt_update_kernel_selector.get_best_kernel(bt_update_kernel_params));
}
GPU_DEBUG_TRACE_DETAIL << "KVCACHE Select DQ\n";
GPU_DEBUG_TRACE_DETAIL << "KVCACHE Select DQ\n";
if (compressed) {
auto comp_scale_update_kernel_params = get_compression_scale_update_kernel_params(impl_param, false);
auto& comp_scale_update_kernel_selector = kernel_selector_t::Instance();
kernels_data.push_back(comp_scale_update_kernel_selector.get_best_kernel(comp_scale_update_kernel_params));
// auto comp_scale_update_kernel_params = get_compression_scale_update_kernel_params(impl_param, impl_param.is_dynamic());
// auto& comp_scale_update_kernel_selector = kernel_selector_t::Instance();
// kernels_data.push_back(comp_scale_update_kernel_selector.get_best_kernel(comp_scale_update_kernel_params));

kernels_data.push_back(kernel_selector::kernel_data());

auto dq_kernel_params = get_dq_update_kernel_params(impl_param, impl_param.is_dynamic());
auto& dq_kernel_selector = dq_kernel_selector_t::Instance();
kernels_data.push_back(dq_kernel_selector.get_best_kernel(dq_kernel_params));
}
return cldnn::make_unique<kv_cache_impl>(kernels_data);
}
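Taken together, the hunks above change kv_cache_impl roughly as sketched below for the compressed path. This is a paraphrased editorial illustration assembled from the visible diff lines (hidden context such as the beam-table update is elided), not code from the commit.

// Paraphrased sketch of execute_impl() for desc->compressed == true -- not code from the commit.
// The new-token append is fused into the dynamic quantization kernel, so the concat stage
// only copies the past data when the cache memory was reallocated.
_kernels_data[concat_stage].kernels[1].skip_execution = true;
execute_stage(events, instance, res_events, concat_stage);

comp_scale_state->set_memory(compression_scale, instance.get_impl_params()->output_layouts[2]);
if (!skip_first_kernel) {  // past scales must be copied into the newly allocated scales buffer
    auto scale_params = get_compression_scale_update_kernel_params(impl_param, comp_scale_state->is_set());
    (_kernels_data[scale_concat_stage].update_dispatch_data_func)(scale_params, _kernels_data[scale_concat_stage]);
    _kernels_data[scale_concat_stage].kernels[0].skip_execution = skip_first_kernel;
}
execute_stage(events, instance, res_events, scale_concat_stage);

// Quantize the new token and write its int8 data and per-group scales directly into the
// present KV-cache (output 0) and present scales (output 2) buffers at the concat-axis offset.
auto dq_params = get_dq_update_kernel_params(impl_param, impl_param.is_dynamic());
(_kernels_data[dq_concat_stage].update_dispatch_data_func)(dq_params, _kernels_data[dq_concat_stage]);
execute_stage(events, instance, res_events, dq_concat_stage);

comp_scale_state->set();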
5 changes: 3 additions & 2 deletions src/plugins/intel_gpu/src/graph/kv_cache.cpp
@@ -42,10 +42,11 @@ std::vector<layout> kv_cache_inst::calc_output_layouts(kv_cache_node const& node

if (desc->compressed) {
input_shapes.push_back(impl_param.get_input_layout(3).get<ShapeType>());
input_shapes.push_back(impl_param.get_input_layout(4).get<ShapeType>());
// input_shapes.push_back(impl_param.get_input_layout(4).get<ShapeType>());
}

std::vector<ShapeType> output_shapes = shape_infer(&op, input_shapes);
std::vector<ShapeType> output_shapes = desc->compressed ? shape_infer(&op, input_shapes, desc->group_sizes, desc->scales_output_order)
: shape_infer(&op, input_shapes);

static const std::map<size_t, size_t> ports_map = {{0, 0}, {1, 2}};

@@ -26,7 +26,10 @@ inline uint FUNC(get_scales_offset_nt)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, u
return OUTPUT1_GET_INDEX(b, f, y, x);
}

inline uint FUNC(get_scales_offset)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint y, uint x) {
inline uint FUNC(get_scales_offset)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint y, uint x, uint axis_offset) {
#ifdef APPEND_MODE
APPEND_AXIS_NAME += axis_offset;
#endif
#ifdef SCALES_OUTPUT_ORDER
return FUNC_CALL(get_scales_offset_nt)(OPTIONAL_SHAPE_INFO_TENSOR SCALES_OUTPUT_ORDER);
#else
@@ -45,7 +48,11 @@ KERNEL(dynamic_quantize_gpu_opt_generic)(
OPTIONAL_SHAPE_INFO_ARG
const __global INPUT0_TYPE* input,
__global OUTPUT_TYPE* output,
__global OUTPUT1_TYPE* output_scale)
__global OUTPUT1_TYPE* output_scale
#ifdef APPEND_MODE
, const uint axis_offset
#endif
)
{
const uint sglid = get_sub_group_local_id();
const uint grouped_indexes = get_global_id(1);
Expand Down Expand Up @@ -75,11 +82,12 @@ KERNEL(dynamic_quantize_gpu_opt_generic)(
OUTPUT_BLOCK_WRITE(output, output_offset + i * SUBGROUP_SIZE, convert_char(val[i] * scale));
}

#ifdef SCALES_OUTPUT_ORDER
const uint scale_idx = FUNC_CALL(get_scales_offset)(OPTIONAL_SHAPE_INFO_TENSOR b, f, y, x);
#ifdef APPEND_MODE
const uint scale_axis_offset = axis_offset;
#else
const uint scale_idx = OUTPUT1_GET_INDEX_SAFE(b, f, y, x);
const uint scale_axis_offset = 0;
#endif
const uint scale_idx = FUNC_CALL(get_scales_offset)(OPTIONAL_SHAPE_INFO_TENSOR b, f, y, x, scale_axis_offset);

if (grouped_indexes == 0 && sglid == 0)
output_scale[scale_idx] = 1.0h / scale;
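To make the APPEND_MODE indexing above concrete, here is a short worked example written as comments. The assumption (not spelled out in the diff) is that APPEND_AXIS_NAME is a jit-generated macro expanding to the coordinate that corresponds to the KV-cache concat axis (y in this example); the token counts are purely illustrative.

// Illustrative comments only -- not part of the kernel.
// Host side: get_dq_update_kernel_params() sets axis_offset = past_kv_cache_shape[concat_axis],
//            e.g. 128 when the cache already holds 128 tokens.
// Kernel side, with APPEND_MODE defined:
//   get_scales_offset(..., b, f, y, x, axis_offset) first shifts the concat-axis coordinate
//   (APPEND_AXIS_NAME, assumed to be y here) by 128, so the scale of the new token
//   (y == 0 in the new-token tensor) lands at position 128 of the shared present-scales
//   buffer, right after the 128 existing scales. This is why the compressed path no longer
//   needs a separate new_token_scale input.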