Skip to content

Commit

Permalink
[GPU] KV-cache compression support
Browse files Browse the repository at this point in the history
  • Loading branch information
sshlyapn committed Oct 24, 2024
1 parent c5e16fc commit 1088754
Show file tree
Hide file tree
Showing 69 changed files with 3,084 additions and 281 deletions.
50 changes: 39 additions & 11 deletions src/common/transformations/include/ov_ops/dynamic_quantize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,32 +11,60 @@ namespace ov {
namespace op {
namespace internal {

struct QuantizationConfig {
enum class QuantizationMode { Symmetric, Asymmetric };

QuantizationMode mode = QuantizationMode::Symmetric;
element::Type quantization_dt = element::undefined;
element::Type scale_dt = element::undefined;
element::Type zp_dt = element::undefined;
std::vector<uint64_t> group_sizes = {};

bool operator==(const QuantizationConfig& rhs) const {
return mode == rhs.mode && quantization_dt == rhs.quantization_dt && scale_dt == rhs.scale_dt &&
zp_dt == rhs.zp_dt && group_sizes == rhs.group_sizes;
}

bool is_asymmetric_quantization() const {
return mode == QuantizationMode::Asymmetric;
}
};

/// \brief Operator performing Dynamic Quantize
///
/// Quantizes the input tensor on the fly, producing the quantized data and the
/// scales (plus zero points when the configuration is asymmetric).
class TRANSFORMATIONS_API DynamicQuantize : public ov::op::Op {
public:
    // NOTE(review): diff-scrape residue removed — the stale pre-change lines
    // (duplicate OPENVINO_OP with "gpu_opset", the old (group_sizes, dt_scale)
    // constructor, and the old m_group_sizes/m_dt_scale members) were interleaved
    // with the new config-based API; only the post-change version is kept here.
    OPENVINO_OP("DynamicQuantize", "ie_internal_opset");
    DynamicQuantize() = default;
    /// \brief Constructs a DynamicQuantize operation.
    ///
    /// \param data Input tensor with data
    /// \param config Dynamic quantization configuration
    DynamicQuantize(const Output<Node>& data, const QuantizationConfig& config);

    void validate_and_infer_types() override;

    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    /// \brief Convenience accessor for the group sizes stored in the config.
    const std::vector<uint64_t>& get_group_sizes() const {
        return m_config.group_sizes;
    }

    QuantizationConfig::QuantizationMode get_quantization_mode() const {
        return m_config.mode;
    }

    QuantizationConfig get_quantization_config() const {
        return m_config;
    }

    /// \brief Infers output shapes: [0] data, [1] scales, and — for asymmetric
    ///        configs — [2] zero points.
    static std::vector<ov::PartialShape> shape_infer(const DynamicQuantize* op,
                                                     const std::vector<ov::PartialShape>& input_shapes,
                                                     const QuantizationConfig& config);

protected:
    /// \brief Constructor for derived ops that need a custom number of outputs.
    DynamicQuantize(const Output<Node>& data, const QuantizationConfig& config, size_t outputs_number);

private:
    QuantizationConfig m_config;
};

} // namespace internal
Expand Down
36 changes: 24 additions & 12 deletions src/common/transformations/src/ov_ops/dynamic_quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,35 +13,42 @@ namespace ov {
namespace op {
namespace internal {

// Protected constructor shared by the public one and derived GPU ops.
// NOTE(review): diff-scrape residue removed — the old (group_sizes, dt_scale)
// signature, its init list, and the hard-coded set_output_size(2) were
// interleaved with the new config-based version; only the latter is kept.
DynamicQuantize::DynamicQuantize(const Output<Node>& data, const QuantizationConfig& config, size_t outputs_number)
    : Op({data}),
      m_config(config) {
    // One group size must be supplied per input dimension.
    OPENVINO_ASSERT(data.get_partial_shape().rank() == m_config.group_sizes.size(),
                    "FC input rank should be same as the rank of group_size ",
                    data.get_tensor_ptr()->get_partial_shape().rank(),
                    " / ",
                    m_config.group_sizes.size());
    set_output_size(outputs_number);
}

// Public constructor: symmetric mode produces two outputs (quantized data and
// scales), while asymmetric mode adds a third output holding the zero points.
DynamicQuantize::DynamicQuantize(const Output<Node>& data, const QuantizationConfig& config)
    : DynamicQuantize(data, config, config.mode == QuantizationConfig::QuantizationMode::Symmetric ? 2 : 3) {
    validate_and_infer_types();
}

// Sets output types/shapes from the quantization config.
// NOTE(review): diff-scrape residue removed — the old group_sizes-based
// shape_infer call and hard-coded element::i8 / m_dt_scale output types were
// interleaved with the new config-driven lines; only the latter are kept.
void DynamicQuantize::validate_and_infer_types() {
    std::vector<ov::PartialShape> input_shapes = {get_input_partial_shape(0)};

    auto out_shapes = shape_infer(this, input_shapes, m_config);
    set_output_type(0, m_config.quantization_dt, out_shapes[0]);
    set_output_type(1, m_config.scale_dt, out_shapes[1]);

    // Asymmetric quantization additionally exposes the zero points.
    if (m_config.is_asymmetric_quantization())
        set_output_type(2, m_config.zp_dt, out_shapes[2]);
}

// Clones the op, carrying the quantization config over to the copy.
// NOTE(review): diff-scrape residue removed — the stale return statement using
// the removed m_group_sizes/m_dt_scale members was interleaved with the new
// config-based one; only the latter is kept.
std::shared_ptr<Node> DynamicQuantize::clone_with_new_inputs(const ov::OutputVector& new_args) const {
    check_new_args_count(this, new_args);
    return std::make_shared<DynamicQuantize>(new_args.at(0), m_config);
}

std::vector<ov::PartialShape> DynamicQuantize::shape_infer(const DynamicQuantize* op,
const std::vector<ov::PartialShape>& input_shapes,
const std::vector<uint64_t>& group_sizes) {
const QuantizationConfig& config) {
const auto& group_sizes = config.group_sizes;
std::vector<ov::PartialShape> out_shapes;
out_shapes.push_back(input_shapes[0]);

Expand All @@ -52,7 +59,7 @@ std::vector<ov::PartialShape> DynamicQuantize::shape_infer(const DynamicQuantize
" / ",
group_sizes.size());
for (size_t i = 0; i < scale_shape.size(); i++) {
if (scale_shape[i].is_dynamic())
if (scale_shape[i].is_dynamic() || scale_shape[i] == 0)
continue;

if (group_sizes[i] == UINT64_MAX)
Expand All @@ -63,6 +70,11 @@ std::vector<ov::PartialShape> DynamicQuantize::shape_infer(const DynamicQuantize
}
}
out_shapes.push_back(scale_shape);

// Add zero points shape
if (config.is_asymmetric_quantization())
out_shapes.push_back(scale_shape);

return out_shapes;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ struct kernel_impl_params final {
optional_layout weights_zero_points_layout = optional_layout();
optional_layout activations_zero_points_layout = optional_layout();
optional_layout compensation_layout = optional_layout();
optional_layout state_layout = optional_layout();
std::vector<layout> state_layouts;

std::map<size_t, memory::ptr> memory_deps = {};
size_t primary_input_idx = 0;
Expand Down
57 changes: 57 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/dynamic_quantize.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "openvino/op/op.hpp"
#include "ov_ops/dynamic_quantize.hpp"

namespace ov {
namespace intel_gpu {
namespace op {

/// \brief GPU-specific DynamicQuantize that extends the common internal op with a
///        configurable scales/zero-points output order and an optional mode that
///        packs scales and zero points into a single combined buffer.
class DynamicQuantize : public ov::op::internal::DynamicQuantize {
public:
    OPENVINO_OP("DynamicQuantize", "gpu_opset");

    using QuantizationConfig = ov::op::internal::QuantizationConfig;

    DynamicQuantize() = default;
    /// \brief Constructs a DynamicQuantize operation.
    ///
    /// \param data Input tensor with data
    /// \param config Dynamic quantization configuration
    /// \param scales_zp_output_order Specifies the output order of the scales and zero points
    ///        (presumably a permutation of the output dimensions — TODO confirm against shape_infer impl)
    /// \param combine_scales_and_zp If true, combines scales and zero points into a single buffer, pairing each scale with its corresponding zero point
    DynamicQuantize(const Output<Node>& data,
                    const QuantizationConfig& config,
                    const std::vector<uint64_t>& scales_zp_output_order = {},
                    const bool combine_scales_and_zp = false);

    void validate_and_infer_types() override;

    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

    const std::vector<uint64_t>& get_scales_zp_output_order() const {
        return m_scales_zp_output_order;
    }

    bool get_combine_scales_and_zp() const {
        return m_combine_scales_and_zp;
    }

    /// \brief Infers output shapes taking the GPU-specific output order and the
    ///        combined scales/zp layout into account.
    static std::vector<ov::PartialShape> shape_infer(const DynamicQuantize* op,
                                                     const std::vector<ov::PartialShape>& input_shapes,
                                                     const QuantizationConfig& config,
                                                     const std::vector<uint64_t>& scales_zp_output_order,
                                                     const bool combine_scales_and_zp = false);

private:
    bool m_combine_scales_and_zp = false;            // pack scales+zp into one buffer
    std::vector<uint64_t> m_scales_zp_output_order;  // requested scales/zp ordering
};

} // namespace op
} // namespace intel_gpu
} // namespace ov
12 changes: 12 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/indirect_sdpa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,18 @@ class IndirectSDPA : public ov::intel_gpu::op::SDPA {
const std::vector<int64_t>& order_out,
const ov::element::Type output_type = ov::element::undefined);

IndirectSDPA(const OutputVector& data_inputs,
const ov::Output<Node>& beam_table,
const bool is_causal,
const int64_t indirect_axis,
const std::vector<int64_t>& order_q,
const std::vector<int64_t>& order_k,
const std::vector<int64_t>& order_v,
const std::vector<int64_t>& order_out,
const QuantizationConfig& quantization_config,
const bool combine_scales_and_zp,
const ov::element::Type output_type = ov::element::undefined);

bool visit_attributes(ov::AttributeVisitor &visitor) override;
void validate_and_infer_types() override;

Expand Down
35 changes: 32 additions & 3 deletions src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
#include "openvino/op/op.hpp"
#include "openvino/op/util/variable.hpp"
#include "openvino/op/util/variable_extension.hpp"
#include "intel_gpu/op/dynamic_quantize.hpp"

namespace ov {
namespace intel_gpu {
Expand All @@ -18,20 +19,31 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension {
public:
OPENVINO_OP("KVCache", "gpu_opset");

using QuantizationConfig = ov::op::internal::QuantizationConfig;

KVCache() = default;

KVCache(const Output<Node>& past,
const Output<Node>& new_token_data,
const Output<Node>& beam_idx,
const std::shared_ptr<ov::op::util::Variable>& past_values,
int64_t concat_axis,
int64_t gather_axis,
const ov::element::Type output_type = ov::element::undefined);

KVCache(const Output<Node>& past,
const Output<Node>& new_token_data,
const Output<Node>& beam_idx,
const std::shared_ptr<ov::op::util::Variable>& past_values,
int64_t concat_axis,
int64_t gather_axis,
const ov::element::Type output_type = ov::element::undefined);

KVCache(const OutputVector& inputs,
const std::shared_ptr<ov::op::util::Variable>& past_values,
int64_t concat_axis,
int64_t gather_axis,
bool combine_scales_and_zp,
const QuantizationConfig& config,
const std::vector<uint64_t>& scales_zp_output_order,
const ov::element::Type output_type = ov::element::undefined);

bool visit_attributes(ov::AttributeVisitor& visitor) override;
Expand All @@ -53,14 +65,31 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension {

bool get_indirect() const { return m_indirect; }

bool get_kv_compressed() const { return m_compressed; }
bool get_combine_scales_and_zp() const { return m_combine_scales_and_zp; }
QuantizationConfig get_quantization_config() const { return m_quantization_config; }
std::vector<uint64_t> get_scales_zp_output_order() const { return m_scales_zp_output_order; }

private:
int64_t m_concat_axis = 0;
int64_t m_gather_axis = 0;
bool m_indirect = false;

bool m_compressed = false;
bool m_combine_scales_and_zp = false;
QuantizationConfig m_quantization_config = {};
std::vector<uint64_t> m_scales_zp_output_order = {};

ov::element::Type m_output_type;
};

std::vector<ov::PartialShape> shape_infer(const KVCache* op, std::vector<ov::PartialShape> input_shapes);
std::vector<ov::PartialShape> shape_infer(const KVCache* op, const std::vector<ov::PartialShape>& input_shapes);

std::vector<ov::PartialShape> shape_infer(const KVCache* op,
const std::vector<ov::PartialShape>& input_shapes,
const ov::op::internal::QuantizationConfig& config,
const std::vector<uint64_t>& scales_output_order = {},
bool combine_scales_and_zp = false);

} // namespace op
} // namespace intel_gpu
Expand Down
7 changes: 7 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/read_value.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,20 @@ class ReadValue : public ov::op::Op, public ov::op::util::VariableExtension {
bool visit_attributes(ov::AttributeVisitor& visitor) override;

void validate_and_infer_types() override;
void validate_and_infer_types(size_t output_idx, const ov::op::util::VariableInfo& variable_info);

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

std::string get_variable_id() const override {
OPENVINO_ASSERT(m_variable, "Variable is not initialized. Variable_id is unavailable");
return m_variable->get_info().variable_id;
}

protected:
ReadValue(const std::vector<Output<Node>>& variable_initializers, const std::shared_ptr<ov::op::util::Variable>& variable)
: Op(variable_initializers) {
m_variable = variable;
}
};

} // namespace op
Expand Down
42 changes: 42 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/read_values.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
// Copyright (C) 2023 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "intel_gpu/op/read_value.hpp"

namespace ov {
namespace intel_gpu {
namespace op {

/// \brief This operation handles the OpenVINO GPU Plugin's custom variable
///        representation (which can store multiple states in a single variable)
///        at the graph level.
class ReadValues : public ReadValue {
public:
    OPENVINO_OP("ReadValues", "gpu_opset");

    ReadValues() = default;

    /// \brief Reads a multi-state variable with no initial values.
    /// \param variable The GPU-plugin variable holding all internal states
    /// \param internal_states_infos Per-state metadata (one entry per internal state)
    ReadValues(const std::shared_ptr<ov::op::util::Variable>& variable,
               const std::vector<ov::op::util::VariableInfo>& internal_states_infos);

    /// \brief Reads a multi-state variable with explicit initializers.
    /// \param variable_initializers One initializing input per internal state
    ///        (presumably matched by position to internal_states_infos — TODO confirm)
    ReadValues(const OutputVector& variable_initializers,
               const std::shared_ptr<ov::op::util::Variable>& variable,
               const std::vector<ov::op::util::VariableInfo>& internal_states_infos);

    bool visit_attributes(ov::AttributeVisitor& visitor) override;

    void validate_and_infer_types() override;

    /// \brief Returns the metadata of every internal state stored in the variable.
    std::vector<ov::op::util::VariableInfo> get_all_internal_states_info() const;

    std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

private:
    std::vector<ov::op::util::VariableInfo> m_internal_states_infos;
};

} // namespace op
} // namespace intel_gpu
} // namespace ov
Loading

0 comments on commit 1088754

Please sign in to comment.