Skip to content

Commit

Permalink
WIP: KV-cache initial version
Browse files Browse the repository at this point in the history
  • Loading branch information
sshlyapn committed Sep 27, 2024
1 parent f00ac41 commit fd905b3
Show file tree
Hide file tree
Showing 50 changed files with 1,206 additions and 97 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class IndirectSDPA : public ov::intel_gpu::op::SDPA {
IndirectSDPA(const OutputVector& data_inputs,
const ov::Output<Node>& beam_table,
const bool is_causal,
const bool is_kv_compressed,
const int64_t indirect_axis,
const std::vector<int64_t>& order_q,
const std::vector<int64_t>& order_k,
Expand Down
11 changes: 11 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension {
int64_t concat_axis,
const ov::element::Type output_type = ov::element::undefined);

KVCache(const Output<Node>& past,
const Output<Node>& new_token_data,
const Output<Node>& new_token_scale,
const Output<Node>& beam_idx,
const std::shared_ptr<ov::op::util::Variable>& past_values,
int64_t concat_axis,
int64_t gather_axis,
const ov::element::Type output_type = ov::element::undefined);

bool visit_attributes(ov::AttributeVisitor& visitor) override;

void validate_and_infer_types() override;
Expand All @@ -52,11 +61,13 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension {
void set_gather_axis(int64_t axis) { m_gather_axis = axis; }

bool get_indirect() const { return m_indirect; }
bool get_compressed() const { return m_compressed; }

private:
int64_t m_concat_axis = 0;
int64_t m_gather_axis = 0;
bool m_indirect = false;
bool m_compressed = false;
ov::element::Type m_output_type;
};

Expand Down
24 changes: 23 additions & 1 deletion src/plugins/intel_gpu/include/intel_gpu/op/read_value.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace ov {
namespace intel_gpu {
namespace op {

/// \brief Similar to common v6::ReadValue, but it's not derived from ReadValueBase class to avoid ReadValue-Assign pairing check
/// \brief Similar to common v6::ReadValue, but it's not derived from ReadValueBase class to avoid ReadValue-Assign pairing check
/// This is needed to have ReadValue-KVCache pair instead of ReadValue-Assign
class ReadValue : public ov::op::Op, public ov::op::util::VariableExtension {
public:
Expand All @@ -35,6 +35,28 @@ class ReadValue : public ov::op::Op, public ov::op::util::VariableExtension {
}
};

/// \brief Variant of intel_gpu::op::ReadValue for compressed KV-cache variables: alongside the
/// data initializer it takes a second initializer input carrying the compression scales
class CompressedReadValue : public ReadValue {
public:
    OPENVINO_OP("CompressedReadValue", "gpu_opset");

    CompressedReadValue() = default;

    /// \brief Constructs a CompressedReadValue with a data initializer and its compression scale.
    /// \param compressed_variable_initializer        Initializer subgraph producing the compressed variable data.
    /// \param compressed_variable_initializer_scale  Initializer subgraph producing the corresponding (de)compression scales.
    /// \param variable                               The variable this operation reads.
    CompressedReadValue(const Output<Node>& compressed_variable_initializer, const Output<Node>& compressed_variable_initializer_scale, const std::shared_ptr<ov::op::util::Variable>& variable);

    bool visit_attributes(ov::AttributeVisitor& visitor) override;

    void validate_and_infer_types() override;

    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    // Returns the id of the underlying variable; asserts that the variable has been set.
    // NOTE(review): body appears identical to ReadValue::get_variable_id — likely inheritable
    // instead of re-defined; confirm against the base class before removing.
    std::string get_variable_id() const override {
        OPENVINO_ASSERT(m_variable, "Variable is not initialized. Variable_id is unavailable");
        return m_variable->get_info().variable_id;
    }
};

} // namespace op
} // namespace intel_gpu
} // namespace ov
3 changes: 3 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/sdpa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class SDPA : public ov::op::v13::ScaledDotProductAttention {

SDPA(const OutputVector& inputs,
const bool is_causal,
const bool is_kv_compressed,
const std::vector<int64_t>& order_q,
const std::vector<int64_t>& order_k,
const std::vector<int64_t>& order_v,
Expand All @@ -34,6 +35,7 @@ class SDPA : public ov::op::v13::ScaledDotProductAttention {
std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

bool get_causal() const { return m_is_causal; }
bool get_kv_compressed() const { return m_is_kv_compressed; }

std::vector<int64_t> get_input0_transpose_order() const { return m_order_q; }
std::vector<int64_t> get_input1_transpose_order() const { return m_order_k; }
Expand All @@ -49,6 +51,7 @@ class SDPA : public ov::op::v13::ScaledDotProductAttention {

protected:
bool m_is_causal;
bool m_is_kv_compressed;
std::vector<int64_t> m_order_q;
std::vector<int64_t> m_order_k;
std::vector<int64_t> m_order_v;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ class MultiTensorState : public VariableStateBase {
};

// This is multi-tensor state for Indirect KV-Cache + Gemm pattern
// Internally it stores KV Cache state + Beam Table state
// Internally it stores KV Cache state + Beam Table state (+ scale state for kv cache compression)
class VariableStateIndirectKVCache : public MultiTensorState {
public:
VariableStateIndirectKVCache(const VariableStateInfo& info,
std::shared_ptr<RemoteContextImpl> context,
std::shared_ptr<cldnn::ShapePredictor> shape_predictor,
size_t beam_idx,
size_t concat_idx);
size_t concat_idx,
bool has_compression_scale = false);
using Ptr = std::shared_ptr<VariableStateIndirectKVCache>;

void reset() override;
Expand All @@ -41,9 +42,13 @@ class VariableStateIndirectKVCache : public MultiTensorState {
VariableState::Ptr get_beam_table_state() const;
ov::PartialShape get_beam_table_shape(const ov::PartialShape& kv_cache_shape);

VariableState::Ptr get_compression_scale_state() const;
ov::PartialShape get_compression_scale_shape(const ov::PartialShape& kv_cache_shape);

private:
size_t m_beam_axis = 0;
size_t m_concat_axis = 0;
bool m_has_compression_scale = false;
};

} // namespace intel_gpu
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -284,6 +284,7 @@ REGISTER_FACTORY(internal, RMS);
REGISTER_FACTORY(internal, GatherCompressed);
REGISTER_FACTORY(internal, KVCache);
REGISTER_FACTORY(internal, ReadValue);
REGISTER_FACTORY(internal, CompressedReadValue);
REGISTER_FACTORY(internal, Gemm);
REGISTER_FACTORY(internal, SwiGLU);
REGISTER_FACTORY(internal, IndirectGemm);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,26 +12,26 @@ namespace cldnn {
struct dynamic_quantize : public primitive_base<dynamic_quantize> {
CLDNN_DECLARE_PRIMITIVE(dynamic_quantize);

dynamic_quantize() : primitive_base("", {}), group_size(0) {}
dynamic_quantize() : primitive_base("", {}), group_sizes{} {}

/// @brief Constructs dynamic_quantize primitive
/// @param id This primitive id
/// @param input Input primitive id
/// @param group_size Quantization group size
/// @param group_sizes Quantization group sizes (one entry per dimension of the input)
/// @param data_type Output data type of quantized
/// @param output_size Output data size of the primitive
dynamic_quantize(const primitive_id& id,
const input_info& input,
const uint64_t group_size,
const std::vector<uint64_t>& group_sizes,
const std::vector<optional_data_type> data_types = {optional_data_type(data_types::f16), optional_data_type(data_types::i8)})
: primitive_base(id, {input}, 2, data_types),
group_size(group_size) {}
group_sizes(group_sizes) {}

uint64_t group_size = 0;
std::vector<uint64_t> group_sizes;

size_t hash() const override {
size_t seed = primitive::hash();
seed = hash_combine(seed, group_size);
seed = hash_range(seed, group_sizes.begin(), group_sizes.end());
return seed;
}

Expand All @@ -41,17 +41,17 @@ struct dynamic_quantize : public primitive_base<dynamic_quantize> {

auto rhs_casted = downcast<const dynamic_quantize>(rhs);

return group_size == rhs_casted.group_size;
return group_sizes == rhs_casted.group_sizes;
}

void save(BinaryOutputBuffer& ob) const override {
primitive_base<dynamic_quantize>::save(ob);
ob << group_size;
ob << group_sizes;
}

void load(BinaryInputBuffer& ib) override {
primitive_base<dynamic_quantize>::load(ib);
ib >> group_size;
ib >> group_sizes;
}
};
} // namespace cldnn
13 changes: 10 additions & 3 deletions src/plugins/intel_gpu/include/intel_gpu/primitives/kv_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,23 +21,27 @@ struct kv_cache : public primitive_base<kv_cache> {
const ov::op::util::VariableInfo& variable_info,
const int64_t concat_axis,
const int64_t gather_axis,
const bool indirect)
const bool indirect,
const bool compressed)
: primitive_base(id, inputs)
, variable_info(variable_info)
, concat_axis(concat_axis)
, gather_axis(gather_axis)
, indirect(indirect) {}
, indirect(indirect)
, compressed(compressed) {}

ov::op::util::VariableInfo variable_info;
int64_t concat_axis = 0;
int64_t gather_axis = 0;
bool indirect = false;
bool compressed = false;

size_t hash() const override {
size_t seed = primitive::hash();
seed = hash_combine(seed, concat_axis);
seed = hash_combine(seed, gather_axis);
seed = hash_combine(seed, indirect);
seed = hash_combine(seed, compressed);
return seed;
}

Expand All @@ -50,7 +54,8 @@ struct kv_cache : public primitive_base<kv_cache> {
return variable_info == rhs_casted.variable_info &&
concat_axis == rhs_casted.concat_axis &&
gather_axis == rhs_casted.gather_axis &&
indirect == rhs_casted.indirect;
indirect == rhs_casted.indirect &&
compressed == rhs_casted.compressed;
}

void save(BinaryOutputBuffer& ob) const override {
Expand All @@ -62,6 +67,7 @@ struct kv_cache : public primitive_base<kv_cache> {
ob << concat_axis;
ob << gather_axis;
ob << indirect;
ob << compressed;
}

void load(BinaryInputBuffer& ib) override {
Expand All @@ -76,6 +82,7 @@ struct kv_cache : public primitive_base<kv_cache> {
ib >> concat_axis;
ib >> gather_axis;
ib >> indirect;
ib >> compressed;
}
};
} // namespace cldnn
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,19 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
/// @param id This primitive id.
/// @param inputs Input data primitives id (query, keys, values, [attention_mask], [scale]).
/// @param is_causal If true, assumes causal attention masking. In this case attention_mask input is ignored.
/// @param is_kv_compressed If true, assumes KV cache is compressed into int8.
scaled_dot_product_attention(const primitive_id& id,
const std::vector<cldnn::input_info> inputs,
bool is_causal,
bool is_kv_compressed = false,
int64_t indirect_axis = -1,
const std::vector<int64_t>& input_q_transpose_order = {},
const std::vector<int64_t>& input_k_transpose_order = {},
const std::vector<int64_t>& input_v_transpose_order = {},
const std::vector<int64_t>& output_transpose_order = {})
: primitive_base(id, inputs)
, is_causal(is_causal)
, is_kv_compressed(is_kv_compressed)
, indirect_axis(indirect_axis)
, input_q_transpose_order(input_q_transpose_order)
, input_k_transpose_order(input_k_transpose_order)
Expand All @@ -34,12 +37,13 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
auto data_inputs_num = inputs.size();
if (indirect_axis != -1)
data_inputs_num--;

has_attn_mask_input = data_inputs_num > 3;
has_scale_input = data_inputs_num > 4;
size_t scale_value_cnt = is_kv_compressed ? 2 : 0;
has_attn_mask_input = data_inputs_num > 3 + scale_value_cnt;
has_scale_input = data_inputs_num > 4 + scale_value_cnt;
}

bool is_causal = false;
bool is_kv_compressed = false;
bool has_attn_mask_input = false;
bool has_scale_input = false;
int64_t indirect_axis = -1;
Expand All @@ -52,6 +56,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
size_t hash() const override {
size_t seed = primitive::hash();
seed = hash_combine(seed, is_causal);
seed = hash_combine(seed, is_kv_compressed);
seed = hash_combine(seed, has_attn_mask_input);
seed = hash_combine(seed, has_scale_input);
seed = hash_combine(seed, indirect_axis);
Expand All @@ -69,6 +74,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
auto rhs_casted = downcast<const scaled_dot_product_attention>(rhs);

return is_causal == rhs_casted.is_causal &&
is_kv_compressed == rhs_casted.is_kv_compressed &&
has_attn_mask_input == rhs_casted.has_attn_mask_input &&
has_scale_input == rhs_casted.has_scale_input &&
indirect_axis == rhs_casted.indirect_axis &&
Expand All @@ -81,6 +87,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
void save(BinaryOutputBuffer& ob) const override {
primitive_base<scaled_dot_product_attention>::save(ob);
ob << is_causal;
ob << is_kv_compressed;
ob << has_attn_mask_input;
ob << has_scale_input;
ob << indirect_axis;
Expand All @@ -93,6 +100,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
void load(BinaryInputBuffer& ib) override {
primitive_base<scaled_dot_product_attention>::load(ib);
ib >> is_causal;
ib >> is_kv_compressed;
ib >> has_attn_mask_input;
ib >> has_scale_input;
ib >> indirect_axis;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -145,6 +145,7 @@ class debug_configuration {
std::vector<std::string> dynamic_quantize_layers_without_onednn; // Specify Fully-connected layers which enable Dynamic quantization
int dynamic_quantize_group_size; // Enable Dynamic quantization for fully connected primitive by specified group size
int disable_horizontal_fc_fusion; // Disable fc horizontal fusion
int enable_kv_cache_compression; // Enable KV cache compression
std::set<int64_t> dump_iteration; // Dump n-th execution of network.
std::vector<std::string> load_layers_raw_dump; // List of layers to load dumped raw binary and filenames
static const debug_configuration *get_instance();
Expand Down
18 changes: 11 additions & 7 deletions src/plugins/intel_gpu/src/graph/dynamic_quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,29 +22,27 @@ layout dynamic_quantize_inst::calc_output_layout(dynamic_quantize_node const& no
}

template<typename ShapeType>
std::vector<layout> dynamic_quantize_inst::__calc_output_layouts(const layout &act_layout, uint64_t group_size) {
std::vector<layout> dynamic_quantize_inst::__calc_output_layouts(const layout &act_layout, const std::vector<uint64_t>& group_sizes) {
ov::op::internal::DynamicQuantize op;
auto output_format = act_layout.format;

std::vector<ShapeType> input_shapes = {
act_layout.get<ShapeType>(),
};

std::vector<uint64_t> shape_group_size(act_layout.get<ShapeType>().size(), 1);
shape_group_size.back() = group_size;

auto output_shapes = ov::op::internal::DynamicQuantize::shape_infer(&op, input_shapes, shape_group_size);
auto output_shapes = ov::op::internal::DynamicQuantize::shape_infer(&op, input_shapes, group_sizes);
GPU_DEBUG_TRACE_DETAIL << "shape infer dynamic" << output_shapes[0] << " " << output_shapes[1] << "\n";

return { layout(output_shapes[0], data_types::i8, output_format), layout(output_shapes[1], data_types::f16, output_format) };
}

template std::vector<layout> dynamic_quantize_inst::__calc_output_layouts<ov::PartialShape>(const layout &act_layout, uint64_t group_size);
template std::vector<layout> dynamic_quantize_inst::__calc_output_layouts<ov::PartialShape>(const layout &act_layout, const std::vector<uint64_t>& group_sizes);

template<typename ShapeType>
std::vector<layout> dynamic_quantize_inst::calc_output_layouts(dynamic_quantize_node const& /*node*/, const kernel_impl_params& impl_param) {
auto desc = impl_param.typed_desc<dynamic_quantize>();
const auto& input_layout = impl_param.get_input_layout();
return __calc_output_layouts<ov::PartialShape>(input_layout, UINT64_MAX /* TODO: handle group_size here */);
return __calc_output_layouts<ov::PartialShape>(input_layout, desc->group_sizes);
}

template std::vector<layout> dynamic_quantize_inst::calc_output_layouts<ov::PartialShape>(dynamic_quantize_node const& node,
Expand All @@ -56,6 +54,12 @@ std::string dynamic_quantize_inst::to_string(dynamic_quantize_node const& node)

std::stringstream primitive_description;

json_composite dynamic_quantize_info;
dynamic_quantize_info.add("group size", desc->group_sizes);
dynamic_quantize_info.add("activation dt", desc->get_output_data_type(0).value_or(data_types::undefined));
dynamic_quantize_info.add("scale dt", desc->get_output_data_type(1).value_or(data_types::undefined));

node_info->add("dynamic_quantize info", dynamic_quantize_info);
node_info->dump(primitive_description);

return primitive_description.str();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "program_helpers.h"

#include "intel_gpu/runtime/itt.hpp"
#include "intel_gpu/runtime/debug_configuration.hpp"

using namespace cldnn;

Expand All @@ -19,13 +20,15 @@ void build_implementations::run(program& p) {
for (auto& n : p.get_processing_order()) {
if (auto impl = n->get_selected_impl()) {
auto params = n->get_kernel_impl_params();
GPU_DEBUG_TRACE << "add_kernels_source: " << params->desc->id << std::endl;
cache.add_kernels_source(*params, impl->get_kernels_source());
}
}
cache.build_all();
for (auto& n : p.get_processing_order()) {
if (auto impl = n->get_selected_impl()) {
auto params = n->get_kernel_impl_params();
GPU_DEBUG_TRACE << "init_kernels: " << params->desc->id << std::endl;
impl->init_kernels(cache, *params);
impl->reset_kernels_source();
}
Expand Down
Loading

0 comments on commit fd905b3

Please sign in to comment.