From 4199b332e6587766c5436df297502e97f9d4ea97 Mon Sep 17 00:00:00 2001 From: Sergey Shlyapnikov Date: Mon, 28 Oct 2024 16:21:28 +0400 Subject: [PATCH] Add KVCacheCompressed operation, replace QuantizationConfig with Attribute structure --- .../include/ov_ops/dynamic_quantize.hpp | 71 +++---- .../src/ov_ops/dynamic_quantize.cpp | 55 +++-- .../include/intel_gpu/op/indirect_sdpa.hpp | 3 +- .../include/intel_gpu/op/kv_cache.hpp | 34 +-- .../intel_gpu/op/kv_cache_compressed.hpp | 56 +++++ .../intel_gpu/include/intel_gpu/op/sdpa.hpp | 11 +- .../intel_gpu/plugin/primitives_list.hpp | 1 + .../intel_gpu/primitives/dynamic_quantize.hpp | 69 +++--- .../include/intel_gpu/primitives/kv_cache.hpp | 62 +++--- .../scaled_dot_product_attention.hpp | 83 +++++--- .../intel_gpu/src/graph/dynamic_quantize.cpp | 38 ++-- .../graph_optimizer/prepare_buffer_fusing.cpp | 2 +- .../src/graph/impls/ocl/dynamic_quantize.cpp | 8 +- .../src/graph/impls/ocl/kv_cache.cpp | 10 +- .../ocl/scaled_dot_product_attention.cpp | 11 +- .../src/graph/include/dynamic_quantize_inst.h | 4 +- src/plugins/intel_gpu/src/graph/kv_cache.cpp | 27 ++- .../intel_gpu/src/graph/primitive_inst.cpp | 4 +- .../graph/scaled_dot_product_attention.cpp | 12 +- .../src/plugin/ops/dynamic_quantize.cpp | 6 +- .../intel_gpu/src/plugin/ops/kv_cache.cpp | 30 ++- .../ops/scaled_dot_product_attention.cpp | 5 +- .../dynamic_quantize_fully_connected.cpp | 4 +- .../transformations/kv_cache_compression.cpp | 90 ++++---- .../transformations/op/indirect_sdpa.cpp | 8 +- .../plugin/transformations/op/kv_cache.cpp | 197 +++++++++--------- .../src/plugin/transformations/op/sdpa.cpp | 9 +- .../src/runtime/execution_config.cpp | 10 + .../test_cases/dynamic_quantize_gpu_test.cpp | 9 +- .../transformations/kv_cache_compression.cpp | 76 +++---- 30 files changed, 531 insertions(+), 474 deletions(-) create mode 100644 src/plugins/intel_gpu/include/intel_gpu/op/kv_cache_compressed.hpp diff --git 
a/src/common/transformations/include/ov_ops/dynamic_quantize.hpp b/src/common/transformations/include/ov_ops/dynamic_quantize.hpp index 7b4b7ea8230898..492a57b9d5590a 100644 --- a/src/common/transformations/include/ov_ops/dynamic_quantize.hpp +++ b/src/common/transformations/include/ov_ops/dynamic_quantize.hpp @@ -11,30 +11,18 @@ namespace ov { namespace op { namespace internal { -struct QuantizationConfig { - enum class QuantizationType { Symmetric, Asymmetric }; - - QuantizationType type = QuantizationType::Symmetric; - element::Type quantization_dt = element::undefined; - element::Type scale_dt = element::undefined; - element::Type zp_dt = element::undefined; - std::vector group_sizes = {}; - - bool operator==(const QuantizationConfig& rhs) const { - return type == rhs.type && quantization_dt == rhs.quantization_dt && scale_dt == rhs.scale_dt && - zp_dt == rhs.zp_dt && group_sizes == rhs.group_sizes; - } - - bool is_asymmetric_quantization() const { - return type == QuantizationType::Asymmetric; - } -}; - /// \brief Operator performing Dynamic Quantize class TRANSFORMATIONS_API DynamicQuantize : public ov::op::Op { public: OPENVINO_OP("DynamicQuantize", "ie_internal_opset"); + /** + * @brief Configuration for the type of quantization applied to the data: + * - Symmetric: Quantization where the zero point is fixed at zero, and the range is symmetric around zero. + * - Asymmetric: Quantization where the zero point is not fixed at zero. + */ + enum class QuantizationType { Symmetric, Asymmetric }; + /** * @brief Configuration for how Activations, Scales and Zero Points will be stored in output buffers: * - Planar: Activations, Scales, and Zero Points are stored in independent buffers. 
@@ -43,51 +31,60 @@ class TRANSFORMATIONS_API DynamicQuantize : public ov::op::Op { */ enum class OutputStorageType { Planar, InterleavedScalesZP, /* InterleavedActivationsScalesZP */ }; + /// \brief Structure that specifies attributes for dynamic quantization + struct Attributes { + QuantizationType quantization_type = QuantizationType::Symmetric; + element::Type quantization_dt = element::undefined; + element::Type scale_dt = element::undefined; + element::Type zp_dt = element::undefined; + + std::vector group_sizes = {}; + std::vector scales_zp_output_order = {}; + OutputStorageType output_storage_type = OutputStorageType::Planar; + }; + DynamicQuantize() = default; /// \brief Constructs an DynamicQuantize operation. /// /// \param data Input tensor with data /// \param config Dynamic quantization configuration DynamicQuantize(const Output& data, - const QuantizationConfig& config, - const OutputStorageType& output_storage = OutputStorageType::Planar, - const std::vector& scales_zp_output_order = {}); + const Attributes& attrs); void validate_and_infer_types() override; std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override; - const std::vector& get_group_sizes() const { - return m_config.group_sizes; + const Attributes& get_attrs() const { + return m_attrs; + } + + void set_attrs(Attributes attrs) { + m_attrs = std::move(attrs); } - QuantizationConfig::QuantizationType get_quantization_type() const { - return m_config.type; + const std::vector& get_group_sizes() const { + return m_attrs.group_sizes; } - QuantizationConfig get_quantization_config() const { - return m_config; + QuantizationType get_quantization_type() const { + return m_attrs.quantization_type; } OutputStorageType get_output_storage_type() const { - return m_output_storage_type; + return m_attrs.output_storage_type; } const std::vector& get_scales_zp_output_order() const { - return m_scales_zp_output_order; + return m_attrs.scales_zp_output_order; } static std::vector 
shape_infer( const DynamicQuantize* op, - const std::vector& input_shapes, - const QuantizationConfig& config, - const OutputStorageType& output_storage = OutputStorageType::Planar, - const std::vector& scales_zp_output_order = {}); + const std::vector& input_shapes); protected: - OutputStorageType m_output_storage_type; - std::vector m_scales_zp_output_order; - QuantizationConfig m_config; + Attributes m_attrs; }; } // namespace internal diff --git a/src/common/transformations/src/ov_ops/dynamic_quantize.cpp b/src/common/transformations/src/ov_ops/dynamic_quantize.cpp index c0df90031af168..748f3c8a0ba304 100644 --- a/src/common/transformations/src/ov_ops/dynamic_quantize.cpp +++ b/src/common/transformations/src/ov_ops/dynamic_quantize.cpp @@ -14,33 +14,29 @@ namespace op { namespace internal { DynamicQuantize::DynamicQuantize(const Output& data, - const QuantizationConfig& config, - const OutputStorageType& output_storage, - const std::vector& scales_zp_output_order) + const Attributes& attrs) : Op({data}), - m_output_storage_type(output_storage), - m_scales_zp_output_order(scales_zp_output_order), - m_config(config) { - if (m_scales_zp_output_order.empty()) { - m_scales_zp_output_order.resize(data.get_partial_shape().size()); - std::iota(m_scales_zp_output_order.begin(), m_scales_zp_output_order.end(), 0); + m_attrs(attrs) { + if (m_attrs.scales_zp_output_order.empty()) { + m_attrs.scales_zp_output_order.resize(data.get_partial_shape().size()); + std::iota(m_attrs.scales_zp_output_order.begin(), m_attrs.scales_zp_output_order.end(), 0); } - OPENVINO_ASSERT(data.get_partial_shape().rank() == m_config.group_sizes.size(), + OPENVINO_ASSERT(data.get_partial_shape().rank() == m_attrs.group_sizes.size(), "DQ input rank should be same as the rank of group_size ", data.get_tensor_ptr()->get_partial_shape().rank(), " / ", - m_config.group_sizes.size()); + m_attrs.group_sizes.size()); - OPENVINO_ASSERT(data.get_partial_shape().size() == m_scales_zp_output_order.size(), + 
OPENVINO_ASSERT(data.get_partial_shape().size() == m_attrs.scales_zp_output_order.size(), "DQ input rank should be same as the rank of scales and zero points output order)"); size_t outputs_number = 2; - if (config.is_asymmetric_quantization() && output_storage == OutputStorageType::Planar) + if (m_attrs.quantization_type == QuantizationType::Asymmetric && m_attrs.output_storage_type == OutputStorageType::Planar) outputs_number = 3; - OPENVINO_ASSERT((output_storage == OutputStorageType::Planar) || - (config.is_asymmetric_quantization() && config.scale_dt == config.zp_dt), + OPENVINO_ASSERT((m_attrs.output_storage_type == OutputStorageType::Planar) || + (m_attrs.quantization_type == QuantizationType::Asymmetric && m_attrs.scale_dt == m_attrs.zp_dt), "Scales and Zero Points should have the same data type to be stored in the single buffer"); set_output_size(outputs_number); @@ -50,29 +46,26 @@ DynamicQuantize::DynamicQuantize(const Output& data, void DynamicQuantize::validate_and_infer_types() { std::vector input_shapes = {get_input_partial_shape(0)}; - auto out_shapes = shape_infer(this, input_shapes, m_config, m_output_storage_type, m_scales_zp_output_order); - set_output_type(0, m_config.quantization_dt, out_shapes[0]); - set_output_type(1, m_config.scale_dt, out_shapes[1]); + auto out_shapes = shape_infer(this, input_shapes); + set_output_type(0, m_attrs.quantization_dt, out_shapes[0]); + set_output_type(1, m_attrs.scale_dt, out_shapes[1]); - if (m_config.is_asymmetric_quantization() && m_output_storage_type == OutputStorageType::Planar) - set_output_type(2, m_config.zp_dt, out_shapes[2]); + if (m_attrs.quantization_type == QuantizationType::Asymmetric && m_attrs.output_storage_type == OutputStorageType::Planar) + set_output_type(2, m_attrs.zp_dt, out_shapes[2]); } std::shared_ptr DynamicQuantize::clone_with_new_inputs(const ov::OutputVector& new_args) const { check_new_args_count(this, new_args); - return std::make_shared(new_args.at(0), m_config, 
m_output_storage_type, m_scales_zp_output_order); + return std::make_shared(new_args.at(0), m_attrs); } std::vector DynamicQuantize::shape_infer(const DynamicQuantize* op, - const std::vector& input_shapes, - const QuantizationConfig& config, - const OutputStorageType& output_storage, - const std::vector& scales_zp_output_order) { - const auto& group_sizes = config.group_sizes; + const std::vector& input_shapes) { std::vector out_shapes; out_shapes.push_back(input_shapes[0]); auto scale_shape = input_shapes[0]; + const auto& group_sizes = op->m_attrs.group_sizes; OPENVINO_ASSERT(scale_shape.size() == group_sizes.size(), "Scale_shape and group_size are supposed to have same rank: ", scale_shape.size(), @@ -91,7 +84,7 @@ std::vector DynamicQuantize::shape_infer(const DynamicQuantize out_shapes.push_back(scale_shape); // Add zero points shape, same as the scales - if (config.is_asymmetric_quantization() && output_storage == OutputStorageType::Planar) + if (op->m_attrs.quantization_type == QuantizationType::Asymmetric && op->m_attrs.output_storage_type == OutputStorageType::Planar) out_shapes.push_back(scale_shape); auto transpose_shape = [](const ov::PartialShape& shape, const std::vector& scales_zp_output_order) { @@ -105,14 +98,16 @@ std::vector DynamicQuantize::shape_infer(const DynamicQuantize }; // Transpose scales and zero points shapes + const auto& scales_zp_output_order = op->m_attrs.scales_zp_output_order; for (size_t i = 1; i < out_shapes.size(); i++) { out_shapes[i] = transpose_shape(out_shapes[i], scales_zp_output_order); } - if (config.is_asymmetric_quantization() && output_storage != OutputStorageType::Planar) { + if (op->m_attrs.quantization_type == QuantizationType::Asymmetric && op->m_attrs.output_storage_type != OutputStorageType::Planar) { // Currently scales and zero points are supposed to be combined over the last dimension only - const auto combine_axis = out_shapes[1].size() - 1; - 
OPENVINO_ASSERT(config.group_sizes[scales_zp_output_order[combine_axis]] != 1); + const auto combine_axis = scales_zp_output_order.empty() ? out_shapes[1].size() - 1 + : scales_zp_output_order[out_shapes[1].size() - 1]; + OPENVINO_ASSERT(group_sizes[combine_axis] != 1); out_shapes[1][combine_axis] *= 2; // [scale, zero_point] pairs } diff --git a/src/plugins/intel_gpu/include/intel_gpu/op/indirect_sdpa.hpp b/src/plugins/intel_gpu/include/intel_gpu/op/indirect_sdpa.hpp index 4ce90a685690e5..7c45c93c7e74f1 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/op/indirect_sdpa.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/op/indirect_sdpa.hpp @@ -37,8 +37,7 @@ class IndirectSDPA : public ov::intel_gpu::op::SDPA { const std::vector& order_k, const std::vector& order_v, const std::vector& order_out, - const QuantizationConfig& quantization_config, - const bool combine_scales_and_zp, + const QuantizationAttribute& quantization_attribute, const ov::element::Type output_type = ov::element::undefined); bool visit_attributes(ov::AttributeVisitor &visitor) override; diff --git a/src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp b/src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp index e3acfa5223412f..7048d5229f25db 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp @@ -19,8 +19,6 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension { public: OPENVINO_OP("KVCache", "gpu_opset"); - using QuantizationConfig = ov::op::internal::QuantizationConfig; - KVCache() = default; KVCache(const Output& past, @@ -37,15 +35,6 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension { int64_t gather_axis, const ov::element::Type output_type = ov::element::undefined); - KVCache(const OutputVector& inputs, - const std::shared_ptr& past_values, - int64_t concat_axis, - int64_t gather_axis, - bool combine_scales_and_zp, - const QuantizationConfig& config, - 
const std::vector& scales_zp_output_order, - const ov::element::Type output_type = ov::element::undefined); - bool visit_attributes(ov::AttributeVisitor& visitor) override; void validate_and_infer_types() override; @@ -65,32 +54,23 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension { bool get_indirect() const { return m_indirect; } - bool get_kv_compressed() const { return m_compressed; } - bool get_combine_scales_and_zp() const { return m_combine_scales_and_zp; } - QuantizationConfig get_quantization_config() const { return m_quantization_config; } - std::vector get_scales_zp_output_order() const { return m_scales_zp_output_order; } +protected: + KVCache(const OutputVector& inputs, + const std::shared_ptr& past_values, + bool indirect, + int64_t concat_axis, + int64_t gather_axis, + const ov::element::Type output_type = ov::element::undefined); -private: int64_t m_concat_axis = 0; int64_t m_gather_axis = 0; bool m_indirect = false; - bool m_compressed = false; - bool m_combine_scales_and_zp = false; - QuantizationConfig m_quantization_config = {}; - std::vector m_scales_zp_output_order = {}; - ov::element::Type m_output_type; }; std::vector shape_infer(const KVCache* op, const std::vector& input_shapes); -std::vector shape_infer(const KVCache* op, - const std::vector& input_shapes, - const ov::op::internal::QuantizationConfig& config, - const std::vector& scales_output_order = {}, - bool combine_scales_and_zp = false); - } // namespace op } // namespace intel_gpu } // namespace ov diff --git a/src/plugins/intel_gpu/include/intel_gpu/op/kv_cache_compressed.hpp b/src/plugins/intel_gpu/include/intel_gpu/op/kv_cache_compressed.hpp new file mode 100644 index 00000000000000..4ee8cb388b61ea --- /dev/null +++ b/src/plugins/intel_gpu/include/intel_gpu/op/kv_cache_compressed.hpp @@ -0,0 +1,56 @@ +// Copyright (C) 2023 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "intel_gpu/op/kv_cache.hpp" +#include 
"ov_ops/dynamic_quantize.hpp" + +namespace ov { +namespace intel_gpu { +namespace op { + +/// \brief Operator that implements Key-Values cache subgraph for large language models. +/// This operation updates data of the corresponding Variable +class KVCacheCompressed : public ov::intel_gpu::op::KVCache { +public: + OPENVINO_OP("KVCacheCompressed", "gpu_opset"); + + using QuantizationAttrs = ov::op::internal::DynamicQuantize::Attributes; + + KVCacheCompressed() = default; + + KVCacheCompressed(const OutputVector& inputs, + const std::shared_ptr& past_values, + int64_t concat_axis, + int64_t gather_axis, + const QuantizationAttrs& quantization_attrs, + const ov::element::Type output_type = ov::element::undefined); + + void validate_and_infer_types() override; + + std::shared_ptr clone_with_new_inputs(const ov::OutputVector& new_args) const override; + + bool get_kv_compressed() const { return m_compressed; } + bool get_combine_scales_and_zp() const { + return m_quantization_attrs.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric && + m_quantization_attrs.output_storage_type != ov::op::internal::DynamicQuantize::OutputStorageType::Planar; + } + + QuantizationAttrs get_quantization_attrs() const { return m_quantization_attrs; } + void set_quantization_attrs(QuantizationAttrs attrs) { m_quantization_attrs = std::move(attrs); } + + std::vector get_scales_zp_output_order() const { return m_quantization_attrs.scales_zp_output_order; } + +private: + bool m_compressed; + QuantizationAttrs m_quantization_attrs = {}; +}; + +std::vector shape_infer(const KVCacheCompressed* op, + const std::vector& input_shapes); + +} // namespace op +} // namespace intel_gpu +} // namespace ov diff --git a/src/plugins/intel_gpu/include/intel_gpu/op/sdpa.hpp b/src/plugins/intel_gpu/include/intel_gpu/op/sdpa.hpp index 30f91d4bd74b86..f7bc0d780ffd38 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/op/sdpa.hpp +++ 
b/src/plugins/intel_gpu/include/intel_gpu/op/sdpa.hpp @@ -18,7 +18,7 @@ class SDPA : public ov::op::v13::ScaledDotProductAttention { public: OPENVINO_OP("SDPA", "gpu_opset"); - using QuantizationConfig = ov::op::internal::QuantizationConfig; + using QuantizationAttribute = ov::op::internal::DynamicQuantize::Attributes; SDPA() = default; @@ -36,8 +36,7 @@ class SDPA : public ov::op::v13::ScaledDotProductAttention { const std::vector& order_k, const std::vector& order_v, const std::vector& order_out, - const QuantizationConfig& quantization_config, - const bool m_combine_scales_and_zp, + const QuantizationAttribute& quantization_attrs, const ov::element::Type output_type = ov::element::undefined); bool visit_attributes(ov::AttributeVisitor &visitor) override; @@ -55,8 +54,7 @@ class SDPA : public ov::op::v13::ScaledDotProductAttention { ov::element::Type get_output_type() const { return m_output_type; } bool get_kv_compressed() const { return m_compressed; } - bool get_combine_scales_and_zp() const { return m_combine_scales_and_zp; } - QuantizationConfig get_quantization_config() const { return m_quantization_config; } + QuantizationAttribute get_quantization_attrs() const { return m_quantization_attrs; } size_t get_compression_inputs_num() const; static std::vector default_order(size_t rank) { @@ -74,8 +72,7 @@ class SDPA : public ov::op::v13::ScaledDotProductAttention { ov::element::Type m_output_type; bool m_compressed = false; - bool m_combine_scales_and_zp = false; - QuantizationConfig m_quantization_config = {}; + QuantizationAttribute m_quantization_attrs = {}; }; std::vector shape_infer(const SDPA* op, diff --git a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp index 04314ef033e019..27e5540a3786ab 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/plugin/primitives_list.hpp @@ -283,6 +283,7 @@ 
REGISTER_FACTORY(internal, FullyConnectedCompressed); REGISTER_FACTORY(internal, RMS); REGISTER_FACTORY(internal, GatherCompressed); REGISTER_FACTORY(internal, KVCache); +REGISTER_FACTORY(internal, KVCacheCompressed); REGISTER_FACTORY(internal, ReadValue); REGISTER_FACTORY(internal, ReadValues); REGISTER_FACTORY(internal, Gemm); diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/dynamic_quantize.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/dynamic_quantize.hpp index 29e7baa1656680..79af223e32cdaa 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/dynamic_quantize.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/dynamic_quantize.hpp @@ -14,7 +14,7 @@ namespace cldnn { struct dynamic_quantize : public primitive_base { CLDNN_DECLARE_PRIMITIVE(dynamic_quantize); - using QuantizationConfig = ov::op::internal::QuantizationConfig; + using Attributes = ov::op::internal::DynamicQuantize::Attributes; dynamic_quantize() : primitive_base("", {}) {} @@ -26,31 +26,26 @@ struct dynamic_quantize : public primitive_base { /// @param output_size Output data size of the primitive dynamic_quantize(const primitive_id& id, const input_info& input, - const QuantizationConfig& config, - const bool combine_scales_and_zp = false, - const std::vector& scales_zp_output_order = {}) + const Attributes& attrs) : primitive_base(id, {input}) - , combine_scales_and_zp(combine_scales_and_zp) - , quantization_config(config) - , scales_zp_output_order(scales_zp_output_order) { + , attrs(attrs) { num_outputs = 2; - if (config.is_asymmetric_quantization() && !combine_scales_and_zp) + if (attrs.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric && + attrs.output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::Planar) num_outputs++; } - bool combine_scales_and_zp = false; - QuantizationConfig quantization_config; - std::vector scales_zp_output_order = {}; + Attributes attrs; size_t hash() const 
override { size_t seed = primitive::hash(); - seed = hash_range(seed, scales_zp_output_order.begin(), scales_zp_output_order.end()); - seed = hash_range(seed, quantization_config.group_sizes.begin(), quantization_config.group_sizes.end()); - seed = hash_combine(seed, quantization_config.type); - seed = hash_combine(seed, quantization_config.quantization_dt.hash()); - seed = hash_combine(seed, quantization_config.scale_dt.hash()); - seed = hash_combine(seed, quantization_config.zp_dt.hash()); - seed = hash_combine(seed, combine_scales_and_zp); + seed = hash_range(seed, attrs.scales_zp_output_order.begin(), attrs.scales_zp_output_order.end()); + seed = hash_range(seed, attrs.group_sizes.begin(), attrs.group_sizes.end()); + seed = hash_combine(seed, attrs.quantization_type); + seed = hash_combine(seed, attrs.quantization_dt.hash()); + seed = hash_combine(seed, attrs.scale_dt.hash()); + seed = hash_combine(seed, attrs.zp_dt.hash()); + seed = hash_combine(seed, attrs.output_storage_type); return seed; } @@ -61,33 +56,37 @@ struct dynamic_quantize : public primitive_base { auto rhs_casted = downcast(rhs); - return scales_zp_output_order == rhs_casted.scales_zp_output_order || - combine_scales_and_zp == rhs_casted.combine_scales_and_zp || - quantization_config == rhs_casted.quantization_config; + return attrs.scales_zp_output_order == rhs_casted.attrs.scales_zp_output_order && + attrs.output_storage_type == rhs_casted.attrs.output_storage_type && + attrs.group_sizes == rhs_casted.attrs.group_sizes && + attrs.quantization_dt == rhs_casted.attrs.quantization_dt && + attrs.scale_dt == rhs_casted.attrs.scale_dt && + attrs.zp_dt == rhs_casted.attrs.zp_dt && + attrs.quantization_type == rhs_casted.attrs.quantization_type; } void save(BinaryOutputBuffer& ob) const override { primitive_base::save(ob); - ob << combine_scales_and_zp; - ob << scales_zp_output_order; - ob << quantization_config.group_sizes; - ob << make_data(&quantization_config.type, 
sizeof(quantization_config.type)); - ob << make_data(&quantization_config.quantization_dt, sizeof(quantization_config.quantization_dt)); - ob << make_data(&quantization_config.scale_dt, sizeof(quantization_config.scale_dt)); - ob << make_data(&quantization_config.zp_dt, sizeof(quantization_config.zp_dt)); + ob << make_data(&attrs.quantization_type, sizeof(attrs.quantization_type)); + ob << make_data(&attrs.quantization_dt, sizeof(attrs.quantization_dt)); + ob << make_data(&attrs.scale_dt, sizeof(attrs.scale_dt)); + ob << make_data(&attrs.zp_dt, sizeof(attrs.zp_dt)); + ob << make_data(&attrs.output_storage_type, sizeof(attrs.output_storage_type)); + ob << attrs.scales_zp_output_order; + ob << attrs.group_sizes; } void load(BinaryInputBuffer& ib) override { primitive_base::load(ib); - ib >> combine_scales_and_zp; - ib >> scales_zp_output_order; - ib >> quantization_config.group_sizes; - ib >> make_data(&quantization_config.type, sizeof(quantization_config.type)); - ib >> make_data(&quantization_config.quantization_dt, sizeof(quantization_config.quantization_dt)); - ib >> make_data(&quantization_config.scale_dt, sizeof(quantization_config.scale_dt)); - ib >> make_data(&quantization_config.zp_dt, sizeof(quantization_config.zp_dt)); + ib >> make_data(&attrs.quantization_type, sizeof(attrs.quantization_type)); + ib >> make_data(&attrs.quantization_dt, sizeof(attrs.quantization_dt)); + ib >> make_data(&attrs.scale_dt, sizeof(attrs.scale_dt)); + ib >> make_data(&attrs.zp_dt, sizeof(attrs.zp_dt)); + ib >> make_data(&attrs.output_storage_type, sizeof(attrs.output_storage_type)); + ib >> attrs.scales_zp_output_order; + ib >> attrs.group_sizes; } }; } // namespace cldnn diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/kv_cache.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/kv_cache.hpp index 5d3abd04978666..1c8f095752aca2 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/kv_cache.hpp +++ 
b/src/plugins/intel_gpu/include/intel_gpu/primitives/kv_cache.hpp @@ -18,7 +18,7 @@ namespace cldnn { struct kv_cache : public primitive_base { CLDNN_DECLARE_PRIMITIVE(kv_cache) - using QuantizationConfig = ov::op::internal::QuantizationConfig; + using QuantizationAttributes = ov::op::internal::DynamicQuantize::Attributes; kv_cache() : primitive_base("", {}) {} @@ -40,9 +40,7 @@ struct kv_cache : public primitive_base { bool indirect = false; bool compressed = false; - bool combine_scales_and_zp = false; - QuantizationConfig quantization_config; - std::vector scales_zp_output_order = {}; + QuantizationAttributes quantization_attributes; size_t hash() const override { size_t seed = primitive::hash(); @@ -50,13 +48,13 @@ struct kv_cache : public primitive_base { seed = hash_combine(seed, gather_axis); seed = hash_combine(seed, indirect); seed = hash_combine(seed, compressed); - seed = hash_combine(seed, combine_scales_and_zp); - seed = hash_range(seed, scales_zp_output_order.begin(), scales_zp_output_order.end()); - seed = hash_range(seed, quantization_config.group_sizes.begin(), quantization_config.group_sizes.end()); - seed = hash_combine(seed, quantization_config.type); - seed = hash_combine(seed, quantization_config.quantization_dt.hash()); - seed = hash_combine(seed, quantization_config.scale_dt.hash()); - seed = hash_combine(seed, quantization_config.zp_dt.hash()); + seed = hash_range(seed, quantization_attributes.scales_zp_output_order.begin(), quantization_attributes.scales_zp_output_order.end()); + seed = hash_range(seed, quantization_attributes.group_sizes.begin(), quantization_attributes.group_sizes.end()); + seed = hash_combine(seed, quantization_attributes.quantization_type); + seed = hash_combine(seed, quantization_attributes.quantization_dt.hash()); + seed = hash_combine(seed, quantization_attributes.scale_dt.hash()); + seed = hash_combine(seed, quantization_attributes.zp_dt.hash()); + seed = hash_combine(seed, 
quantization_attributes.output_storage_type);; return seed; } @@ -72,9 +70,13 @@ struct kv_cache : public primitive_base { gather_axis == rhs_casted.gather_axis && indirect == rhs_casted.indirect && compressed == rhs_casted.compressed && - scales_zp_output_order == rhs_casted.scales_zp_output_order && - combine_scales_and_zp == rhs_casted.combine_scales_and_zp && - quantization_config == rhs_casted.quantization_config; + quantization_attributes.scales_zp_output_order == rhs_casted.quantization_attributes.scales_zp_output_order && + quantization_attributes.output_storage_type == rhs_casted.quantization_attributes.output_storage_type && + quantization_attributes.group_sizes == rhs_casted.quantization_attributes.group_sizes && + quantization_attributes.quantization_dt == rhs_casted.quantization_attributes.quantization_dt && + quantization_attributes.scale_dt == rhs_casted.quantization_attributes.scale_dt && + quantization_attributes.zp_dt == rhs_casted.quantization_attributes.zp_dt && + quantization_attributes.quantization_type == rhs_casted.quantization_attributes.quantization_type; } void save(BinaryOutputBuffer& ob) const override { @@ -87,13 +89,13 @@ struct kv_cache : public primitive_base { ob << gather_axis; ob << indirect; ob << compressed; - ob << combine_scales_and_zp; - ob << scales_zp_output_order; - ob << quantization_config.group_sizes; - ob << make_data(&quantization_config.type, sizeof(quantization_config.type)); - ob << make_data(&quantization_config.quantization_dt, sizeof(quantization_config.quantization_dt)); - ob << make_data(&quantization_config.scale_dt, sizeof(quantization_config.scale_dt)); - ob << make_data(&quantization_config.zp_dt, sizeof(quantization_config.zp_dt)); + ob << make_data(&quantization_attributes.quantization_type, sizeof(quantization_attributes.quantization_type)); + ob << make_data(&quantization_attributes.quantization_dt, sizeof(quantization_attributes.quantization_dt)); + ob << make_data(&quantization_attributes.scale_dt, 
sizeof(quantization_attributes.scale_dt)); + ob << make_data(&quantization_attributes.zp_dt, sizeof(quantization_attributes.zp_dt)); + ob << make_data(&quantization_attributes.output_storage_type, sizeof(quantization_attributes.output_storage_type)); + ob << quantization_attributes.scales_zp_output_order; + ob << quantization_attributes.group_sizes; } void load(BinaryInputBuffer& ib) override { @@ -109,13 +111,13 @@ struct kv_cache : public primitive_base { ib >> gather_axis; ib >> indirect; ib >> compressed; - ib >> combine_scales_and_zp; - ib >> scales_zp_output_order; - ib >> quantization_config.group_sizes; - ib >> make_data(&quantization_config.type, sizeof(quantization_config.type)); - ib >> make_data(&quantization_config.quantization_dt, sizeof(quantization_config.quantization_dt)); - ib >> make_data(&quantization_config.scale_dt, sizeof(quantization_config.scale_dt)); - ib >> make_data(&quantization_config.zp_dt, sizeof(quantization_config.zp_dt)); + ib >> make_data(&quantization_attributes.quantization_type, sizeof(quantization_attributes.quantization_type)); + ib >> make_data(&quantization_attributes.quantization_dt, sizeof(quantization_attributes.quantization_dt)); + ib >> make_data(&quantization_attributes.scale_dt, sizeof(quantization_attributes.scale_dt)); + ib >> make_data(&quantization_attributes.zp_dt, sizeof(quantization_attributes.zp_dt)); + ib >> make_data(&quantization_attributes.output_storage_type, sizeof(quantization_attributes.output_storage_type)); + ib >> quantization_attributes.scales_zp_output_order; + ib >> quantization_attributes.group_sizes; } size_t get_compression_scales_inputs_num() const { @@ -127,7 +129,9 @@ struct kv_cache : public primitive_base { } size_t get_compression_zp_inputs_num() const { - if (compressed && quantization_config.is_asymmetric_quantization() && !combine_scales_and_zp) { + if (compressed && + quantization_attributes.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric && + 
quantization_attributes.output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::Planar) { return 1; } else { return 0; diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/scaled_dot_product_attention.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/scaled_dot_product_attention.hpp index 5b829cea5ead0d..1fd5b43824d0a7 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/scaled_dot_product_attention.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/scaled_dot_product_attention.hpp @@ -12,7 +12,7 @@ namespace cldnn { struct scaled_dot_product_attention : public primitive_base { CLDNN_DECLARE_PRIMITIVE(scaled_dot_product_attention) - using QuantizationConfig = ov::op::internal::QuantizationConfig; + using QuantizationAttributes = ov::op::internal::DynamicQuantize::Attributes; scaled_dot_product_attention() : primitive_base("", {}) {} @@ -28,15 +28,13 @@ struct scaled_dot_product_attention : public primitive_base& input_k_transpose_order = {}, const std::vector& input_v_transpose_order = {}, const std::vector& output_transpose_order = {}, - bool is_kv_compressed = false, - bool combine_scales_and_zp = false, - const QuantizationConfig& quantization_config = {}) + const QuantizationAttributes& quantization_attributes = {}, + bool is_kv_compressed = false) : primitive_base(id, inputs) , is_causal(is_causal) , indirect_axis(indirect_axis) , is_kv_compressed(is_kv_compressed) - , combine_scales_and_zp(combine_scales_and_zp) - , quantization_config(quantization_config) + , quantization_attributes(quantization_attributes) , input_q_transpose_order(input_q_transpose_order) , input_k_transpose_order(input_k_transpose_order) , input_v_transpose_order(input_v_transpose_order) @@ -48,7 +46,8 @@ struct scaled_dot_product_attention : public primitive_base 3; @@ -61,8 +60,7 @@ struct scaled_dot_product_attention : public primitive_base input_q_transpose_order; std::vector input_k_transpose_order; @@ -80,12 +78,14 @@ 
struct scaled_dot_product_attention : public primitive_base> input_k_transpose_order; ib >> input_v_transpose_order; ib >> output_transpose_order; - ib >> combine_scales_and_zp; - ib >> quantization_config.group_sizes; - ib >> make_data(&quantization_config.type, sizeof(quantization_config.type)); - ib >> make_data(&quantization_config.quantization_dt, sizeof(quantization_config.quantization_dt)); - ib >> make_data(&quantization_config.scale_dt, sizeof(quantization_config.scale_dt)); - ib >> make_data(&quantization_config.zp_dt, sizeof(quantization_config.zp_dt)); + ib >> make_data(&quantization_attributes.quantization_type, sizeof(quantization_attributes.quantization_type)); + ib >> make_data(&quantization_attributes.quantization_dt, sizeof(quantization_attributes.quantization_dt)); + ib >> make_data(&quantization_attributes.scale_dt, sizeof(quantization_attributes.scale_dt)); + ib >> make_data(&quantization_attributes.zp_dt, sizeof(quantization_attributes.zp_dt)); + ib >> make_data(&quantization_attributes.output_storage_type, sizeof(quantization_attributes.output_storage_type)); + ib >> quantization_attributes.scales_zp_output_order; + ib >> quantization_attributes.group_sizes; + } + + size_t get_compression_scales_inputs_num() const { + if (is_kv_compressed) { + return 2; + } else { + return 0; + } + } + + size_t get_compression_zp_inputs_num() const { + if (is_kv_compressed && + quantization_attributes.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric && + quantization_attributes.output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::Planar) { + return 2; + } else { + return 0; + } } }; } // namespace cldnn diff --git a/src/plugins/intel_gpu/src/graph/dynamic_quantize.cpp b/src/plugins/intel_gpu/src/graph/dynamic_quantize.cpp index 956b033041309a..8e4957d5f52797 100644 --- a/src/plugins/intel_gpu/src/graph/dynamic_quantize.cpp +++ b/src/plugins/intel_gpu/src/graph/dynamic_quantize.cpp @@ -23,41 +23,38 @@ 
layout dynamic_quantize_inst::calc_output_layout(dynamic_quantize_node const& no template std::vector dynamic_quantize_inst::__calc_output_layouts(const layout &act_layout, - const dynamic_quantize::QuantizationConfig& config, - const std::vector& scales_zp_output_order, - const bool combine_scales_and_zp) { + const dynamic_quantize::Attributes& attrs) { ov::op::internal::DynamicQuantize op; + op.set_attrs(attrs); + auto output_format = act_layout.format; std::vector input_shapes = { act_layout.get(), }; - auto output_storage_type = combine_scales_and_zp ? ov::op::internal::DynamicQuantize::OutputStorageType::InterleavedScalesZP - : ov::op::internal::DynamicQuantize::OutputStorageType::Planar; - auto output_shapes = ov::op::internal::DynamicQuantize::shape_infer(&op, input_shapes, config, output_storage_type, scales_zp_output_order); + auto output_shapes = ov::op::internal::DynamicQuantize::shape_infer(&op, input_shapes); - std::vector output_layouts = { layout(output_shapes[0], config.quantization_dt, output_format), - layout(output_shapes[1], config.scale_dt, output_format) }; + std::vector output_layouts = { layout(output_shapes[0], attrs.quantization_dt, output_format), + layout(output_shapes[1], attrs.scale_dt, output_format) }; - if (config.is_asymmetric_quantization() && !combine_scales_and_zp) { - output_layouts.emplace_back(layout(output_shapes[2], config.zp_dt, output_format)); + if (attrs.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric && + attrs.output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::Planar) { + output_layouts.emplace_back(layout(output_shapes[2], attrs.zp_dt, output_format)); } return output_layouts; } template std::vector dynamic_quantize_inst::__calc_output_layouts(const layout &act_layout, - const dynamic_quantize::QuantizationConfig& config, - const std::vector& scales_zp_output_order, - const bool combine_scales_and_zp); + const dynamic_quantize::Attributes& config); 
template std::vector dynamic_quantize_inst::calc_output_layouts(dynamic_quantize_node const& /*node*/, const kernel_impl_params& impl_param) { auto desc = impl_param.typed_desc(); const auto& input_layout = impl_param.get_input_layout(); - return __calc_output_layouts(input_layout, desc->quantization_config, desc->scales_zp_output_order, desc->combine_scales_and_zp); + return __calc_output_layouts(input_layout, desc->attrs); } template std::vector dynamic_quantize_inst::calc_output_layouts(dynamic_quantize_node const& node, @@ -70,12 +67,13 @@ std::string dynamic_quantize_inst::to_string(dynamic_quantize_node const& node) std::stringstream primitive_description; json_composite dynamic_quantize_info; - dynamic_quantize_info.add("combine_scales_and_zp", desc->combine_scales_and_zp); - dynamic_quantize_info.add("scales_zp_output_order", desc->scales_zp_output_order); - dynamic_quantize_info.add("quantization_dt", desc->quantization_config.quantization_dt); - dynamic_quantize_info.add("scale_dt", desc->quantization_config.scale_dt); - dynamic_quantize_info.add("zp_dt", desc->quantization_config.zp_dt); - dynamic_quantize_info.add("is_asymmetric_quantization", desc->quantization_config.is_asymmetric_quantization()); + dynamic_quantize_info.add("output_storage_type", static_cast(desc->attrs.output_storage_type)); + dynamic_quantize_info.add("scales_zp_output_order", desc->attrs.scales_zp_output_order); + dynamic_quantize_info.add("group_sizes", desc->attrs.group_sizes); + dynamic_quantize_info.add("quantization_dt", desc->attrs.quantization_dt); + dynamic_quantize_info.add("scale_dt", desc->attrs.scale_dt); + dynamic_quantize_info.add("zp_dt", desc->attrs.zp_dt); + dynamic_quantize_info.add("quantization_type", static_cast(desc->attrs.quantization_type)); node_info->add("dynamic_quantize info", dynamic_quantize_info); node_info->dump(primitive_description); diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp 
b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp index e92eefa5b01ec9..6d7d609d232947 100644 --- a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp +++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp @@ -908,7 +908,7 @@ void prepare_buffer_fusing::run(program& p) { if (desc->compressed) { update_scale_zp(2, 1); - if (desc->quantization_config.is_asymmetric_quantization() && !desc->combine_scales_and_zp) { + if (desc->get_compression_zp_inputs_num() > 0) { update_scale_zp(3, 2); } } diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/dynamic_quantize.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/dynamic_quantize.cpp index d33b4b6dade34c..0c212882f9dbbb 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/dynamic_quantize.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/dynamic_quantize.cpp @@ -49,10 +49,10 @@ struct dynamic_quantize_impl : typed_primitive_impl_ocl { params.outputs.push_back(convert_data_tensor(impl_param.get_output_layout(2))); const auto& desc = impl_param.typed_desc(); - params.group_sizes = desc->quantization_config.group_sizes; - params.scales_output_order = desc->scales_zp_output_order; - params.use_asymmetric_quantization = desc->quantization_config.is_asymmetric_quantization(); - params.combine_scales_and_zp = desc->combine_scales_and_zp; + params.group_sizes = desc->attrs.group_sizes; + params.scales_output_order = desc->attrs.scales_zp_output_order; + params.use_asymmetric_quantization = desc->attrs.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric; + params.combine_scales_and_zp = desc->attrs.output_storage_type != ov::op::internal::DynamicQuantize::OutputStorageType::Planar; return params; } diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp index 2ce0b8a5e46b2d..d0fcace0b3f184 100644 --- 
a/src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp @@ -369,10 +369,12 @@ struct kv_cache_impl : multi_stage_primitive { auto params = get_default_params(impl_param, is_shape_agnostic); params.append_axis = primitive->concat_axis; - params.group_sizes = primitive->quantization_config.group_sizes; - params.scales_output_order = primitive->scales_zp_output_order; - params.use_asymmetric_quantization = primitive->quantization_config.is_asymmetric_quantization(); - params.combine_scales_and_zp = primitive->combine_scales_and_zp; + params.group_sizes = primitive->quantization_attributes.group_sizes; + params.scales_output_order = primitive->quantization_attributes.scales_zp_output_order; + params.use_asymmetric_quantization = + primitive->quantization_attributes.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric; + params.combine_scales_and_zp = + primitive->quantization_attributes.output_storage_type != ov::op::internal::DynamicQuantize::OutputStorageType::Planar; const auto& past_kv_cache_shape = impl_param.input_layouts[0].get_partial_shape(); params.axis_offset = past_kv_cache_shape[primitive->concat_axis].is_static() ? 
past_kv_cache_shape[primitive->concat_axis].get_length() : 0; diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp b/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp index 5f3919b41e31dd..f4791d38f88742 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/scaled_dot_product_attention.cpp @@ -218,13 +218,15 @@ struct scaled_dot_product_attention_impl : multi_stage_primitiveis_causal; if (desc->is_kv_compressed) { - const auto& group_sizes = desc->quantization_config.group_sizes; + const auto& group_sizes = desc->quantization_attributes.group_sizes; const auto non_compressed_dims = std::count(group_sizes.begin(), group_sizes.end(), 1); config.per_head_quantization = (group_sizes.size() - non_compressed_dims) == 1; config.is_kv_compressed = desc->is_kv_compressed; - config.use_asymmetric_quantization = desc->quantization_config.is_asymmetric_quantization(); - config.combine_scales_and_zp = desc->combine_scales_and_zp; + config.use_asymmetric_quantization = + desc->quantization_attributes.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric; + config.combine_scales_and_zp = + desc->quantization_attributes.output_storage_type != ov::op::internal::DynamicQuantize::OutputStorageType::Planar; } return config; @@ -243,8 +245,7 @@ struct scaled_dot_product_attention_impl : multi_stage_primitiveis_kv_compressed) { data_inputs_num -= 2; // key and value compression scales are handled separately - has_zp_input_buffers = desc->quantization_config.is_asymmetric_quantization() && !desc->combine_scales_and_zp; - if (has_zp_input_buffers) + if (desc->get_compression_zp_inputs_num() > 0) data_inputs_num -= 2; // key and value compression zp are handled separately } diff --git a/src/plugins/intel_gpu/src/graph/include/dynamic_quantize_inst.h b/src/plugins/intel_gpu/src/graph/include/dynamic_quantize_inst.h index 
bb8ebd093fa696..f96085094ae221 100644 --- a/src/plugins/intel_gpu/src/graph/include/dynamic_quantize_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/dynamic_quantize_inst.h @@ -36,9 +36,7 @@ class typed_primitive_inst : public typed_primitive_inst_base< // Internal function to be used from fakealignment template static std::vector __calc_output_layouts(const layout &act_layout, - const dynamic_quantize::QuantizationConfig& config, - const std::vector& scales_zp_output_order, - const bool combine_scales_and_zp); + const dynamic_quantize::Attributes& config); static std::string to_string(dynamic_quantize_node const& node); typed_primitive_inst(network& network, dynamic_quantize_node const& node); diff --git a/src/plugins/intel_gpu/src/graph/kv_cache.cpp b/src/plugins/intel_gpu/src/graph/kv_cache.cpp index 77e31ed6447443..c65fe5796d6ed9 100644 --- a/src/plugins/intel_gpu/src/graph/kv_cache.cpp +++ b/src/plugins/intel_gpu/src/graph/kv_cache.cpp @@ -3,6 +3,7 @@ // #include "intel_gpu/op/kv_cache.hpp" +#include "intel_gpu/op/kv_cache_compressed.hpp" #include "intel_gpu/plugin/common_utils.hpp" #include "intel_gpu/plugin/multi_tensor_variable_state.hpp" #include "intel_gpu/runtime/optionals.hpp" @@ -29,10 +30,16 @@ template std::vector kv_cache_inst::calc_output_layouts(kv_cache_node const& /*node*/, kernel_impl_params const& impl_param) { auto desc = impl_param.typed_desc(); - ov::intel_gpu::op::KVCache op; - op.set_output_size(desc->num_outputs); - op.set_concat_axis(desc->concat_axis); - op.set_gather_axis(desc->gather_axis); + std::unique_ptr op; + if (!desc->compressed) { + op = make_unique(); + } else { + op = make_unique(); + } + + op->set_output_size(desc->num_outputs); + op->set_concat_axis(desc->concat_axis); + op->set_gather_axis(desc->gather_axis); std::vector input_shapes = {impl_param.get_input_layout(0).get(), impl_param.get_input_layout(1).get()}; @@ -46,13 +53,15 @@ std::vector kv_cache_inst::calc_output_layouts(kv_cache_node const& /*no if 
(desc->get_compression_zp_inputs_num() > 0) { input_shapes.push_back(impl_param.get_input_layout(4).get()); } + + static_cast(op.get())->set_quantization_attrs(desc->quantization_attributes); } std::vector output_shapes; - if (desc->compressed) { - output_shapes = shape_infer(&op, input_shapes, desc->quantization_config, desc->scales_zp_output_order, desc->combine_scales_and_zp); + if (!desc->compressed) { + output_shapes = shape_infer(static_cast(op.get()), input_shapes); } else { - output_shapes = shape_infer(&op, input_shapes); + output_shapes = shape_infer(static_cast(op.get()), input_shapes); } static const std::map ports_map = {{0, 0}, {1, 2}, {2, 3}, {3, 4}}; @@ -79,8 +88,8 @@ std::string kv_cache_inst::to_string(const kv_cache_node& node) { kv_cache_info.add("gather axis", node.get_primitive()->gather_axis); kv_cache_info.add("indirect", node.get_primitive()->indirect); kv_cache_info.add("compressed", node.get_primitive()->compressed); - kv_cache_info.add("combine_scales_and_zp", node.get_primitive()->combine_scales_and_zp); - kv_cache_info.add("scales_zp_output_order", node.get_primitive()->scales_zp_output_order); + kv_cache_info.add("output_storage_type", static_cast(node.get_primitive()->quantization_attributes.output_storage_type)); + kv_cache_info.add("scales_zp_output_order", node.get_primitive()->quantization_attributes.scales_zp_output_order); node_info->add("kv_cache info", kv_cache_info); std::stringstream primitive_description; node_info->dump(primitive_description); diff --git a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp index f4e5f6df460004..b5f0a0f2c98fd0 100644 --- a/src/plugins/intel_gpu/src/graph/primitive_inst.cpp +++ b/src/plugins/intel_gpu/src/graph/primitive_inst.cpp @@ -707,9 +707,7 @@ event::ptr primitive_inst::realloc_if_needed() { const auto& desc = get_node().as().get_primitive(); auto dyn_quan_scale_layout = 
dynamic_quantize_inst::__calc_output_layouts(updated_layouts[dep_idx], - desc->quantization_config, - desc->scales_zp_output_order, - desc->combine_scales_and_zp); + desc->attrs); GPU_DEBUG_TRACE_DETAIL << "update layout of dynamic quantize scale parameter layout " << dyn_quan_scale_layout[1].to_short_string() << std::endl; updated_params.output_layouts[1] = dyn_quan_scale_layout[1]; diff --git a/src/plugins/intel_gpu/src/graph/scaled_dot_product_attention.cpp b/src/plugins/intel_gpu/src/graph/scaled_dot_product_attention.cpp index d24488aa2606aa..e80cb62a534b52 100644 --- a/src/plugins/intel_gpu/src/graph/scaled_dot_product_attention.cpp +++ b/src/plugins/intel_gpu/src/graph/scaled_dot_product_attention.cpp @@ -89,12 +89,12 @@ std::string scaled_dot_product_attention_inst::to_string(scaled_dot_product_atte scaled_dot_product_attention_info.add("input id", input.id()); scaled_dot_product_attention_info.add("is_causal", desc->is_causal); scaled_dot_product_attention_info.add("is_kv_compressed", desc->is_kv_compressed); - scaled_dot_product_attention_info.add("combine_scales_and_zp", desc->combine_scales_and_zp); - scaled_dot_product_attention_info.add("group_size", desc->quantization_config.group_sizes); - scaled_dot_product_attention_info.add("is_asymmetric_quantization", desc->quantization_config.is_asymmetric_quantization()); - scaled_dot_product_attention_info.add("quantization_dt", desc->quantization_config.quantization_dt); - scaled_dot_product_attention_info.add("scale_dt", desc->quantization_config.scale_dt); - scaled_dot_product_attention_info.add("zp_dt", desc->quantization_config.zp_dt); + scaled_dot_product_attention_info.add("output_storage_type", static_cast(node.get_primitive()->quantization_attributes.output_storage_type)); + scaled_dot_product_attention_info.add("group_size", desc->quantization_attributes.group_sizes); + scaled_dot_product_attention_info.add("quantization_type", 
static_cast(node.get_primitive()->quantization_attributes.quantization_type)); + scaled_dot_product_attention_info.add("quantization_dt", desc->quantization_attributes.quantization_dt); + scaled_dot_product_attention_info.add("scale_dt", desc->quantization_attributes.scale_dt); + scaled_dot_product_attention_info.add("zp_dt", desc->quantization_attributes.zp_dt); scaled_dot_product_attention_info.add("indirect_axis", desc->indirect_axis); scaled_dot_product_attention_info.add("has_attn_mask_input", desc->has_attn_mask_input); scaled_dot_product_attention_info.add("has_scale_input", desc->has_scale_input); diff --git a/src/plugins/intel_gpu/src/plugin/ops/dynamic_quantize.cpp b/src/plugins/intel_gpu/src/plugin/ops/dynamic_quantize.cpp index 6bd3d65aaccf7d..85f28cbd711678 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/dynamic_quantize.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/dynamic_quantize.cpp @@ -16,13 +16,9 @@ static void CreateDynamicQuantizeOp(ProgramBuilder& p, const std::shared_ptrget_output_storage_type() == ov::op::internal::DynamicQuantize::OutputStorageType::InterleavedScalesZP; - auto scales_zp_output_order = op->get_scales_zp_output_order(); auto prim = cldnn::dynamic_quantize(primitive_name, inputs[0], - op->get_quantization_config(), - combine_scales_and_zp, - scales_zp_output_order); + op->get_attrs()); prim.num_outputs = op->get_output_size(); diff --git a/src/plugins/intel_gpu/src/plugin/ops/kv_cache.cpp b/src/plugins/intel_gpu/src/plugin/ops/kv_cache.cpp index bfa4df330c2fc8..251c7346db9209 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/kv_cache.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/kv_cache.cpp @@ -3,6 +3,7 @@ // #include "intel_gpu/op/kv_cache.hpp" +#include "intel_gpu/op/kv_cache_compressed.hpp" #include "intel_gpu/plugin/program_builder.hpp" #include "intel_gpu/plugin/common_utils.hpp" #include "intel_gpu/primitives/kv_cache.hpp" @@ -12,6 +13,7 @@ namespace ov { namespace op { namespace internal { using KVCache = 
ov::intel_gpu::op::KVCache; +using KVCacheCompressed = ov::intel_gpu::op::KVCacheCompressed; } // namespace internal } // namespace op } // namespace ov @@ -22,7 +24,7 @@ namespace intel_gpu { namespace { void CreateKVCacheOp(ProgramBuilder& p, const std::shared_ptr& op) { - validate_inputs_count(op, {2, 3, 4, 5}); + validate_inputs_count(op, {2, 3}); auto inputs = p.GetInputInfo(op); int64_t rank = op->get_input_partial_shape(0).size(); auto prim = cldnn::kv_cache(layer_type_name_ID(op), @@ -35,12 +37,25 @@ void CreateKVCacheOp(ProgramBuilder& p, const std::shared_ptrget_output_size(); prim.output_data_types = get_output_data_types(op); - if (op->get_kv_compressed()) { - prim.compressed = true; - prim.combine_scales_and_zp = op->get_combine_scales_and_zp(); - prim.quantization_config = op->get_quantization_config(); - prim.scales_zp_output_order = op->get_scales_zp_output_order(); - } + p.add_primitive(*op, prim); +} + +void CreateKVCacheCompressedOp(ProgramBuilder& p, const std::shared_ptr& op) { + validate_inputs_count(op, {4, 5}); + auto inputs = p.GetInputInfo(op); + int64_t rank = op->get_input_partial_shape(0).size(); + auto prim = cldnn::kv_cache(layer_type_name_ID(op), + inputs, + op->get_variable()->get_info(), + ov::util::normalize(op->get_concat_axis(), rank), + ov::util::normalize(op->get_gather_axis(), rank), + op->get_indirect()); + + prim.compressed = true; + prim.quantization_attributes = op->get_quantization_attrs(); + + prim.num_outputs = op->get_output_size(); + prim.output_data_types = get_output_data_types(op); p.add_primitive(*op, prim); } @@ -48,6 +63,7 @@ void CreateKVCacheOp(ProgramBuilder& p, const std::shared_ptrget_input1_transpose_order(), op->get_input2_transpose_order(), op->get_output_transpose_order(), - op->get_kv_compressed(), - op->get_combine_scales_and_zp(), - op->get_quantization_config()); + op->get_quantization_attrs(), + op->get_kv_compressed()); p.add_primitive(*op, sdpa_prim); } diff --git 
a/src/plugins/intel_gpu/src/plugin/transformations/dynamic_quantize_fully_connected.cpp b/src/plugins/intel_gpu/src/plugin/transformations/dynamic_quantize_fully_connected.cpp index 62947c9e6feedf..68328160a98f82 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/dynamic_quantize_fully_connected.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/dynamic_quantize_fully_connected.cpp @@ -62,9 +62,9 @@ DynamicQuantizeFullyConnected::DynamicQuantizeFullyConnected(uint64_t group_size std::vector shape_group_size(rank, 1); shape_group_size.back() = group_size; - ov::op::internal::QuantizationConfig config; + ov::op::internal::DynamicQuantize::Attributes config; config.quantization_dt = element::i8; - config.type = ov::op::internal::QuantizationConfig::QuantizationType::Symmetric; + config.quantization_type = ov::op::internal::DynamicQuantize::QuantizationType::Symmetric; config.scale_dt = element::f16; config.group_sizes = shape_group_size; diff --git a/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_compression.cpp b/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_compression.cpp index 9072b8c57d7f1f..561822f9661109 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_compression.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/kv_cache_compression.cpp @@ -5,6 +5,7 @@ #include "kv_cache_compression.hpp" #include "intel_gpu/op/kv_cache.hpp" +#include "intel_gpu/op/kv_cache_compressed.hpp" #include "intel_gpu/op/indirect_sdpa.hpp" #include "intel_gpu/op/read_value.hpp" #include "intel_gpu/op/read_values.hpp" @@ -36,32 +37,29 @@ namespace intel_gpu { namespace { std::vector get_variable_infos(const ov::op::util::VariableInfo& data_variable_info, - const ov::op::internal::QuantizationConfig& config, - const std::vector& scales_zp_output_order, - const bool combine_scales_and_zp = false) { + const ov::op::internal::DynamicQuantize::Attributes& quantization_attrs) { std::vector infos; // Add initial data 
variable info infos.push_back(data_variable_info); - // Infer DQ shapes - auto output_storage_type = combine_scales_and_zp ? ov::op::internal::DynamicQuantize::OutputStorageType::InterleavedScalesZP - : ov::op::internal::DynamicQuantize::OutputStorageType::Planar; ov::op::internal::DynamicQuantize dq; - auto dq_shapes = - ov::op::internal::DynamicQuantize::shape_infer(&dq, {data_variable_info.data_shape}, config, output_storage_type, scales_zp_output_order); + dq.set_attrs(quantization_attrs); + + auto dq_shapes = ov::op::internal::DynamicQuantize::shape_infer(&dq, {data_variable_info.data_shape}); const auto variable_id = data_variable_info.variable_id; const auto scale_shape = dq_shapes[1]; - const auto scale_dt = config.scale_dt; + const auto scale_dt = quantization_attrs.scale_dt; // Add scales variable info infos.push_back(ov::op::util::VariableInfo{scale_shape, scale_dt, variable_id}); - if (config.is_asymmetric_quantization() && !combine_scales_and_zp) { + if (quantization_attrs.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric && + quantization_attrs.output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::Planar) { // Add zero points variable info - const auto zp_dt = config.zp_dt; + const auto zp_dt = quantization_attrs.zp_dt; infos.push_back(ov::op::util::VariableInfo{scale_shape, zp_dt, variable_id}); } @@ -70,30 +68,25 @@ std::vector get_variable_infos(const ov::op::util::V std::shared_ptr update_past_read_value(std::shared_ptr past_rv_node, - const ov::op::internal::QuantizationConfig& config, - const std::vector& scales_zp_output_order, - const bool combine_scales_and_zp = false) { + const ov::op::internal::DynamicQuantize::Attributes& quantization_attrs) { auto variable = past_rv_node->get_variable(); - variable->update_data_type(config.quantization_dt); + variable->update_data_type(quantization_attrs.quantization_dt); - auto variable_infos = 
get_variable_infos(past_rv_node->get_variable()->get_info(), config, scales_zp_output_order, combine_scales_and_zp); + auto variable_infos = get_variable_infos(past_rv_node->get_variable()->get_info(), quantization_attrs); auto new_past_rv_node = std::make_shared(); if (past_rv_node->get_input_size() == 0) { new_past_rv_node = std::make_shared(past_rv_node->get_variable(), variable_infos); } else { - auto output_storage_type = combine_scales_and_zp ? ov::op::internal::DynamicQuantize::OutputStorageType::InterleavedScalesZP - : ov::op::internal::DynamicQuantize::OutputStorageType::Planar; auto initializer_dq = std::make_shared(past_rv_node->get_input_node_shared_ptr(0), - config, - output_storage_type, - scales_zp_output_order); + quantization_attrs); initializer_dq->set_friendly_name(past_rv_node->get_input_node_shared_ptr(0)->get_friendly_name() + "_dyn_quan"); ov::copy_runtime_info(past_rv_node->get_input_node_shared_ptr(0), initializer_dq); OutputVector initializer_outputs = { initializer_dq->output(0), initializer_dq->output(1) }; - if (config.is_asymmetric_quantization() && !combine_scales_and_zp) + if (quantization_attrs.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric && + quantization_attrs.output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::Planar) initializer_outputs.push_back(initializer_dq->output(2)); new_past_rv_node = std::make_shared(initializer_outputs, past_rv_node->get_variable(), variable_infos); @@ -105,27 +98,24 @@ std::shared_ptr return new_past_rv_node; } -std::shared_ptr +std::shared_ptr update_kv_cache(std::shared_ptr past_rv_node, std::shared_ptr kv_cache_node, - const ov::op::internal::QuantizationConfig& config, - const std::vector& scales_zp_output_order, - const bool combine_scales_and_zp = false) { + const ov::op::internal::DynamicQuantize::Attributes& quantization_attrs) { OutputVector kv_cache_inputs = { past_rv_node->output(0), 
kv_cache_node->get_input_node_shared_ptr(1), kv_cache_node->get_input_node_shared_ptr(2), past_rv_node->output(1) }; - if (config.is_asymmetric_quantization() && !combine_scales_and_zp) + if (quantization_attrs.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric && + quantization_attrs.output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::Planar) kv_cache_inputs.push_back(past_rv_node->output(2)); - auto new_kv_cache = std::make_shared(kv_cache_inputs, - kv_cache_node->get_variable(), - kv_cache_node->get_concat_axis(), - kv_cache_node->get_gather_axis(), - combine_scales_and_zp, - config, - scales_zp_output_order); + auto new_kv_cache = std::make_shared(kv_cache_inputs, + kv_cache_node->get_variable(), + kv_cache_node->get_concat_axis(), + kv_cache_node->get_gather_axis(), + quantization_attrs); new_kv_cache->set_friendly_name(kv_cache_node->get_friendly_name()); ov::copy_runtime_info(kv_cache_node, new_kv_cache); @@ -146,12 +136,13 @@ KVCacheCompressionMatcher::KVCacheCompressionMatcher(ov::element::Type compressi if (compression_dt != element::i8) return; - auto quantization_type = ov::op::internal::QuantizationConfig::QuantizationType::Asymmetric; - bool combine_scales_and_zp = quantization_type == ov::op::internal::QuantizationConfig::QuantizationType::Asymmetric; + const auto quantization_type = ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric; + const auto output_storage_type = ov::op::internal::DynamicQuantize::OutputStorageType::InterleavedScalesZP; + bool combine_scales_and_zp = output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::InterleavedScalesZP; GPU_DEBUG_LOG << "KV-cache compression configuration: " << "dt=" << compression_dt << ", " - << "asym=" << (quantization_type == ov::op::internal::QuantizationConfig::QuantizationType::Asymmetric) << ", " + << "asym=" << (quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric) << ", " 
<< "single_buffer_for_scales_and_zp=" << combine_scales_and_zp << "\n"; auto query = any_input(); @@ -219,23 +210,22 @@ KVCacheCompressionMatcher::KVCacheCompressionMatcher(ov::element::Type compressi return scales_zp_output_order; }; - auto group_sizes = get_shape_group_sizes(sdpa_node->get_input1_transpose_order()); - auto scales_zp_output_order = get_scales_output_order(sdpa_node->get_input1_transpose_order()); - - ov::op::internal::QuantizationConfig config; - config.type = quantization_type; - config.group_sizes = group_sizes; + ov::op::internal::DynamicQuantize::Attributes config; + config.quantization_type = quantization_type; + config.group_sizes = get_shape_group_sizes(sdpa_node->get_input1_transpose_order()); config.quantization_dt = element::i8; config.scale_dt = query_node->get_output_element_type(0); + config.scales_zp_output_order = get_scales_output_order(sdpa_node->get_input1_transpose_order()); + config.output_storage_type = output_storage_type; - if (config.is_asymmetric_quantization()) + if (config.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric) config.zp_dt = query_node->get_output_element_type(0); - key_past_rv_node = update_past_read_value(key_past_rv_node, config, scales_zp_output_order, combine_scales_and_zp); - value_past_rv_node = update_past_read_value(value_past_rv_node, config, scales_zp_output_order, combine_scales_and_zp); + key_past_rv_node = update_past_read_value(key_past_rv_node, config); + value_past_rv_node = update_past_read_value(value_past_rv_node, config); - auto new_key_cache = update_kv_cache(key_past_rv_node, key_cache_node, config, scales_zp_output_order, combine_scales_and_zp); - auto new_value_cache = update_kv_cache(value_past_rv_node, value_cache_node, config, scales_zp_output_order, combine_scales_and_zp); + auto new_key_cache = update_kv_cache(key_past_rv_node, key_cache_node, config); + auto new_value_cache = update_kv_cache(value_past_rv_node, value_cache_node, config); 
OutputVector sdpa_inputs; // Add Query, Key, Value, attention_mask, scale inputs @@ -251,7 +241,8 @@ KVCacheCompressionMatcher::KVCacheCompressionMatcher(ov::element::Type compressi sdpa_inputs.push_back(new_value_cache->output(2)); // Add Key and Value compression zero points - if (config.is_asymmetric_quantization() && !combine_scales_and_zp) { + if (config.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric && + config.output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::Planar) { sdpa_inputs.push_back(new_key_cache->output(3)); sdpa_inputs.push_back(new_value_cache->output(3)); } @@ -270,7 +261,6 @@ KVCacheCompressionMatcher::KVCacheCompressionMatcher(ov::element::Type compressi input2_transpose_order, output_transpose_order, config, - combine_scales_and_zp, sdpa_node->get_output_type()); new_key_cache->set_friendly_name(key_cache_node->get_friendly_name()); diff --git a/src/plugins/intel_gpu/src/plugin/transformations/op/indirect_sdpa.cpp b/src/plugins/intel_gpu/src/plugin/transformations/op/indirect_sdpa.cpp index a900c99eb6a4af..73e916064a0c1c 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/op/indirect_sdpa.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/op/indirect_sdpa.cpp @@ -33,10 +33,9 @@ IndirectSDPA::IndirectSDPA(const OutputVector& data_inputs, const std::vector& order_k, const std::vector& order_v, const std::vector& order_out, - const QuantizationConfig& quantization_config, - const bool combine_scales_and_zp, + const QuantizationAttribute& quantization_attribute, const ov::element::Type output_type) - : ov::intel_gpu::op::SDPA(data_inputs, is_causal, order_q, order_k, order_v, order_out, quantization_config, combine_scales_and_zp, output_type) + : ov::intel_gpu::op::SDPA(data_inputs, is_causal, order_q, order_k, order_v, order_out, quantization_attribute, output_type) , m_indirect_axis(indirect_axis) { auto beam_table_idx = data_inputs.size(); 
set_argument(beam_table_idx, beam_table); @@ -68,8 +67,7 @@ std::shared_ptr IndirectSDPA::clone_with_new_inputs(const ov::OutputVe m_order_k, m_order_v, m_order_out, - m_quantization_config, - m_combine_scales_and_zp, + m_quantization_attrs, m_output_type); } } diff --git a/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp b/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp index f0e4f5b829c479..12d961be6d337a 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/op/kv_cache.cpp @@ -3,6 +3,7 @@ // #include "intel_gpu/op/kv_cache.hpp" +#include "intel_gpu/op/kv_cache_compressed.hpp" #include "gather_shape_inference.hpp" #include "concat_shape_inference.hpp" #include "openvino/core/partial_shape.hpp" @@ -13,66 +14,40 @@ namespace ov { namespace intel_gpu { namespace op { -KVCache::KVCache(const Output& past, - const Output& new_token_data, - const Output& beam_idx, +KVCache::KVCache(const OutputVector& inputs, const std::shared_ptr& past_variable, + bool indirect, int64_t concat_axis, int64_t gather_axis, const ov::element::Type output_type) - : Op({past, new_token_data, beam_idx}) + : Op(inputs) , m_concat_axis(concat_axis) , m_gather_axis(gather_axis) - , m_indirect(true) - , m_compressed(false) + , m_indirect(indirect) , m_output_type(output_type) { m_variable = past_variable; - if (m_indirect) - set_output_size(2); - validate_and_infer_types(); } KVCache::KVCache(const Output& past, const Output& new_token_data, + const Output& beam_idx, const std::shared_ptr& past_variable, int64_t concat_axis, + int64_t gather_axis, const ov::element::Type output_type) - : Op({past, new_token_data}) - , m_concat_axis(concat_axis) - , m_gather_axis(0) - , m_indirect(false) - , m_compressed(false) - , m_output_type(output_type) { - m_variable = past_variable; + : KVCache({past, new_token_data, beam_idx}, past_variable, true, concat_axis, gather_axis, output_type) { + if 
(m_indirect) + set_output_size(2); validate_and_infer_types(); } -KVCache::KVCache(const OutputVector& inputs, +KVCache::KVCache(const Output& past, + const Output& new_token_data, const std::shared_ptr& past_variable, int64_t concat_axis, - int64_t gather_axis, - bool combine_scales_and_zp, - const QuantizationConfig& config, - const std::vector& scales_zp_output_order, const ov::element::Type output_type) - : Op(inputs) - , m_concat_axis(concat_axis) - , m_gather_axis(gather_axis) - , m_indirect(true) - , m_compressed(true) - , m_combine_scales_and_zp(combine_scales_and_zp) - , m_quantization_config(config) - , m_scales_zp_output_order(scales_zp_output_order) - , m_output_type(output_type) { - OPENVINO_ASSERT(m_quantization_config.quantization_dt == ov::element::i8, - "[GPU] Only I8 data type is currently supported for KV-cache compression"); - + : KVCache({past, new_token_data}, past_variable, false, concat_axis, 0, output_type) { m_variable = past_variable; - size_t output_size = 3; - if (config.is_asymmetric_quantization() && !combine_scales_and_zp) - output_size++; // add zp output - - set_output_size(output_size); validate_and_infer_types(); } @@ -81,15 +56,12 @@ bool KVCache::visit_attributes(ov::AttributeVisitor& visitor) { visitor.on_attribute("gather_axis", m_gather_axis); visitor.on_attribute("indirect", m_indirect); visitor.on_attribute("output_type", m_output_type); - visitor.on_attribute("compressed", m_compressed); return true; } void KVCache::validate_and_infer_types() { auto output_type = m_output_type; - if (m_compressed) { - output_type = m_quantization_config.quantization_dt; - } else if (m_output_type == ov::element::undefined) { + if (m_output_type == ov::element::undefined) { output_type = get_input_element_type(0); } @@ -98,15 +70,7 @@ void KVCache::validate_and_infer_types() { input_shapes.push_back(get_input_partial_shape(2)); } - if (m_compressed) { - input_shapes.push_back(get_input_partial_shape(3)); - - if 
(m_quantization_config.is_asymmetric_quantization() && !m_combine_scales_and_zp) - input_shapes.push_back(get_input_partial_shape(4)); - } - - auto shapes = m_compressed ? shape_infer(this, input_shapes, m_quantization_config, m_scales_zp_output_order, m_combine_scales_and_zp) - : shape_infer(this, input_shapes); + auto shapes = shape_infer(this, input_shapes); size_t out_ports = 0; set_output_type(out_ports++, output_type, shapes[0]); @@ -114,14 +78,6 @@ void KVCache::validate_and_infer_types() { if (m_indirect) { set_output_type(out_ports++, get_input_element_type(2), shapes[1]); } - - if (m_compressed) { - set_output_type(out_ports++, m_quantization_config.scale_dt, shapes[2]); - - if (m_quantization_config.is_asymmetric_quantization() && !m_combine_scales_and_zp) { - set_output_type(out_ports++, m_quantization_config.zp_dt, shapes[3]); - } - } } std::shared_ptr KVCache::clone_with_new_inputs(const ov::OutputVector& new_args) const { @@ -133,7 +89,7 @@ std::shared_ptr KVCache::clone_with_new_inputs(const ov::OutputVector& new m_concat_axis, m_output_type); - } else if (new_args.size() == 3) { + } else { return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), @@ -141,47 +97,7 @@ std::shared_ptr KVCache::clone_with_new_inputs(const ov::OutputVector& new m_concat_axis, m_gather_axis, m_output_type); - } else { - return std::make_shared(new_args, - m_variable, - m_concat_axis, - m_gather_axis, - m_combine_scales_and_zp, - m_quantization_config, - m_scales_zp_output_order, - m_output_type); - } -} - -std::vector shape_infer(const KVCache* op, - const std::vector& input_shapes, - const ov::op::internal::QuantizationConfig& config, - const std::vector& scales_zp_output_order, - bool combine_scales_and_zp) { - std::vector out_shapes = shape_infer(op, input_shapes); - - if (op->get_output_size() >= 3) { - ov::op::internal::DynamicQuantize op; - - auto output_storage_type = combine_scales_and_zp ? 
ov::op::internal::DynamicQuantize::OutputStorageType::InterleavedScalesZP - : ov::op::internal::DynamicQuantize::OutputStorageType::Planar; - auto quantized_data_shapes = - ov::op::internal::DynamicQuantize::shape_infer(&op, { input_shapes[1] }, config, output_storage_type, scales_zp_output_order); - - const auto scales_concat_axis = 2; - ov::PartialShape compression_scale_shape = input_shapes[3]; - compression_scale_shape[scales_concat_axis] += quantized_data_shapes[1][scales_concat_axis]; - out_shapes[2] = compression_scale_shape; - - // add zp output - if (quantized_data_shapes.size() == 3) { - ov::PartialShape compression_zp_shape = input_shapes[4]; - compression_zp_shape[scales_concat_axis] += quantized_data_shapes[2][scales_concat_axis]; - out_shapes[3] = compression_zp_shape; - } } - - return out_shapes; } std::vector shape_infer(const KVCache* op, const std::vector& input_shapes) { @@ -207,6 +123,87 @@ std::vector shape_infer(const KVCache* op, const std::vector& past_variable, + int64_t concat_axis, + int64_t gather_axis, + const QuantizationAttrs& quantization_attrs, + const ov::element::Type output_type) + : KVCache(inputs, past_variable, true, concat_axis, gather_axis, output_type) + , m_compressed(true) + , m_quantization_attrs(quantization_attrs) { + OPENVINO_ASSERT(quantization_attrs.quantization_dt == ov::element::i8, + "[GPU] Only I8 data type is currently supported for KV-cache compression"); + + m_variable = past_variable; + size_t output_size = 3; + if (quantization_attrs.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric && + quantization_attrs.output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::Planar) + output_size++; // add zp output + + set_output_size(output_size); + validate_and_infer_types(); +} + +void KVCacheCompressed::validate_and_infer_types() { + std::vector input_shapes = {m_variable->get_info().data_shape, get_input_partial_shape(1)}; + 
input_shapes.push_back(get_input_partial_shape(2)); + input_shapes.push_back(get_input_partial_shape(3)); + + if (m_quantization_attrs.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric && + m_quantization_attrs.output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::Planar) + input_shapes.push_back(get_input_partial_shape(4)); + + auto shapes = shape_infer(this, input_shapes); + + size_t out_ports = 0; + set_output_type(out_ports++, m_quantization_attrs.quantization_dt, shapes[0]); + set_output_type(out_ports++, get_input_element_type(2), shapes[1]); + set_output_type(out_ports++, m_quantization_attrs.scale_dt, shapes[2]); + + if (m_quantization_attrs.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric && + m_quantization_attrs.output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::Planar) { + set_output_type(out_ports++, m_quantization_attrs.zp_dt, shapes[3]); + } +} + +std::shared_ptr KVCacheCompressed::clone_with_new_inputs(const ov::OutputVector& new_args) const { + check_new_args_count(this, new_args); + return std::make_shared(new_args, + m_variable, + m_concat_axis, + m_gather_axis, + m_quantization_attrs, + m_output_type); +} + +std::vector shape_infer(const KVCacheCompressed* op, + const std::vector& input_shapes) { + std::vector out_shapes = shape_infer(static_cast(op), input_shapes); + + if (op->get_output_size() >= 3) { + ov::op::internal::DynamicQuantize dq_op; + dq_op.set_attrs(op->get_quantization_attrs()); + + auto quantized_data_shapes = + ov::op::internal::DynamicQuantize::shape_infer(&dq_op, { input_shapes[1] }); + + const auto scales_concat_axis = 2; + ov::PartialShape compression_scale_shape = input_shapes[3]; + compression_scale_shape[scales_concat_axis] += quantized_data_shapes[1][scales_concat_axis]; + out_shapes[2] = compression_scale_shape; + + // add zp output + if (quantized_data_shapes.size() == 3) { + ov::PartialShape 
compression_zp_shape = input_shapes[4]; + compression_zp_shape[scales_concat_axis] += quantized_data_shapes[2][scales_concat_axis]; + out_shapes[3] = compression_zp_shape; + } + } + + return out_shapes; +} + } // namespace op } // namespace intel_gpu } // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations/op/sdpa.cpp b/src/plugins/intel_gpu/src/plugin/transformations/op/sdpa.cpp index 65930d5feb6d0a..09513d99153a1f 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/op/sdpa.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/op/sdpa.cpp @@ -39,8 +39,7 @@ SDPA::SDPA(const OutputVector& inputs, const std::vector& order_k, const std::vector& order_v, const std::vector& order_out, - const QuantizationConfig& quantization_config, - const bool combine_scales_and_zp, + const QuantizationAttribute& quantization_attrs, const ov::element::Type output_type) : m_is_causal(is_causal) , m_order_q(order_q) @@ -49,8 +48,7 @@ SDPA::SDPA(const OutputVector& inputs, , m_order_out(order_out) , m_output_type(output_type) , m_compressed(true) - , m_combine_scales_and_zp(combine_scales_and_zp) - , m_quantization_config(quantization_config) { + , m_quantization_attrs(quantization_attrs) { set_arguments(inputs); set_causal(is_causal); validate_and_infer_types(); @@ -108,7 +106,8 @@ size_t SDPA::get_compression_inputs_num() const { if (m_compressed) { compression_inputs += 2; // 2 * scales - if (m_quantization_config.is_asymmetric_quantization() && !m_combine_scales_and_zp) + if (m_quantization_attrs.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric && + m_quantization_attrs.output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::Planar) compression_inputs += 2; // 2 * zp } diff --git a/src/plugins/intel_gpu/src/runtime/execution_config.cpp b/src/plugins/intel_gpu/src/runtime/execution_config.cpp index 09a979c495f207..c48f3f02fa9f6a 100644 --- 
a/src/plugins/intel_gpu/src/runtime/execution_config.cpp +++ b/src/plugins/intel_gpu/src/runtime/execution_config.cpp @@ -211,6 +211,16 @@ void ExecutionConfig::apply_debug_options(const cldnn::device_info& info) { set_property(ov::hint::dynamic_quantization_group_size(debug_config->dynamic_quantize_group_size)); } + int KVCacheCompression = 0; + if (const auto env_var = std::getenv("KVCacheCompression")) { + std::istringstream ss(env_var); + ss >> KVCacheCompression; + } + + if (KVCacheCompression == 1) { + set_property(ov::hint::kv_cache_precision(ov::element::i8)); + } + GPU_DEBUG_IF(debug_config->use_kv_cache_compression != -1) { GPU_DEBUG_IF(debug_config->use_kv_cache_compression == 1) { set_property(ov::hint::kv_cache_precision(ov::element::i8)); diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/dynamic_quantize_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/dynamic_quantize_gpu_test.cpp index 9a00f330bd7018..5a78360eb1f6d8 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/dynamic_quantize_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/dynamic_quantize_gpu_test.cpp @@ -22,7 +22,7 @@ using namespace cldnn; using namespace ::tests; -using QuantizationType = dynamic_quantize::QuantizationConfig::QuantizationType; +using QuantizationType = ov::op::internal::DynamicQuantize::QuantizationType; class dynamic_quantization_gpu_tests: public ::testing::Test { public: @@ -51,15 +51,18 @@ class dynamic_quantization_gpu_tests: public ::testing::Test { auto in_layout = input_shape.is_dynamic() ? 
layout{ dyn_input_ps, data_types::f16, format::bfyx } : layout{ input_ps, data_types::f16, format::bfyx }; - dynamic_quantize::QuantizationConfig dq_config; - dq_config.type = quantization_type; + dynamic_quantize::Attributes dq_config; + dq_config.quantization_type = quantization_type; dq_config.quantization_dt = data_types::i8; dq_config.scale_dt = data_types::f16; dq_config.zp_dt = data_types::undefined; dq_config.group_sizes = group_sizes; + dq_config.scales_zp_output_order = { 0, 1, 2, 3 }; + dq_config.output_storage_type = ov::op::internal::DynamicQuantize::OutputStorageType::Planar; if (quantization_type == QuantizationType::Asymmetric) { dq_config.zp_dt = data_types::f16; + dq_config.output_storage_type = ov::op::internal::DynamicQuantize::OutputStorageType::InterleavedScalesZP; } auto reorder_1 = reorder("reorder_1", input_info("input"), layout{ input_ps, data_types::f16, format::bfyx }); diff --git a/src/plugins/intel_gpu/tests/unit/transformations/kv_cache_compression.cpp b/src/plugins/intel_gpu/tests/unit/transformations/kv_cache_compression.cpp index 689bb835d72936..67123f1d84cfe7 100644 --- a/src/plugins/intel_gpu/tests/unit/transformations/kv_cache_compression.cpp +++ b/src/plugins/intel_gpu/tests/unit/transformations/kv_cache_compression.cpp @@ -16,6 +16,7 @@ #include "intel_gpu/op/read_value.hpp" #include "intel_gpu/op/read_values.hpp" #include "intel_gpu/op/kv_cache.hpp" +#include "intel_gpu/op/kv_cache_compressed.hpp" #include "intel_gpu/op/indirect_sdpa.hpp" #include "plugin/transformations/kv_cache_compression.hpp" @@ -95,14 +96,14 @@ TEST_F(TransformationTestsF, KVCacheCompression) { manager.register_pass(ov::element::i8); } { - bool combine_scales_and_zp = true; - std::vector scales_zp_output_order = { 0, 1, 2, 3 }; - ov::intel_gpu::op::KVCache::QuantizationConfig dq_config; - dq_config.type = ov::intel_gpu::op::KVCache::QuantizationConfig::QuantizationType::Asymmetric; + ov::op::internal::DynamicQuantize::Attributes dq_config; + 
dq_config.quantization_type = ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric; dq_config.quantization_dt = ov::element::i8; dq_config.scale_dt = ov::element::f16; dq_config.zp_dt = ov::element::f16; dq_config.group_sizes = { 1, 1, 1, UINT64_MAX }; + dq_config.scales_zp_output_order = { 0, 1, 2, 3 }; + dq_config.output_storage_type = ov::op::internal::DynamicQuantize::OutputStorageType::InterleavedScalesZP; auto query = std::make_shared(element_type, input_shape); auto beam_idx = std::make_shared(ov::element::i32, ov::PartialShape{1}); @@ -113,13 +114,11 @@ TEST_F(TransformationTestsF, KVCacheCompression) { ov::op::util::VariableInfo{{1, 32, -1, 2}, ov::element::f16, "v0"} }; auto key_past_compressed = std::make_shared(key_variable, key_past_variable_infos); auto key_cache_inputs = ov::OutputVector{ key_past_compressed->output(0), key_current, beam_idx, key_past_compressed->output(1) }; - auto key_cache = std::make_shared(key_cache_inputs, - key_variable, - concat_axis, - gather_axis, - combine_scales_and_zp, - dq_config, - scales_zp_output_order); + auto key_cache = std::make_shared(key_cache_inputs, + key_variable, + concat_axis, + gather_axis, + dq_config); auto value_variable = std::make_shared(ov::op::util::VariableInfo{{1, 32, -1, 80}, ov::element::f16, "v1"}); auto value_current = std::make_shared(ov::element::f16, input_shape); @@ -127,13 +126,11 @@ TEST_F(TransformationTestsF, KVCacheCompression) { ov::op::util::VariableInfo{{1, 32, -1, 2}, ov::element::f16, "v1"} }; auto value_past_compressed = std::make_shared(value_variable, value_past_variable_infos); auto value_cache_inputs = ov::OutputVector{ value_past_compressed->output(0), value_current, beam_idx, value_past_compressed->output(1) }; - auto value_cache = std::make_shared(value_cache_inputs, - value_variable, - concat_axis, - gather_axis, - combine_scales_and_zp, - dq_config, - scales_zp_output_order); + auto value_cache = std::make_shared(value_cache_inputs, + value_variable, + 
concat_axis, + gather_axis, + dq_config); ov::ParameterVector params{ beam_idx, query, key_current, value_current }; @@ -170,8 +167,7 @@ TEST_F(TransformationTestsF, KVCacheCompression) { qkv_order, qkv_order, ov::intel_gpu::op::SDPA::default_order(4), - dq_config, - combine_scales_and_zp); + dq_config); auto result = std::make_shared(sdpa); @@ -251,15 +247,14 @@ TEST_F(TransformationTestsF, KVCacheCompressionWithInitializers) { manager.register_pass(ov::element::i8); } { - bool combine_scales_and_zp = true; - std::vector scales_zp_output_order = { 0, 1, 2, 3 }; - ov::intel_gpu::op::KVCache::QuantizationConfig dq_config; - dq_config.type = ov::intel_gpu::op::KVCache::QuantizationConfig::QuantizationType::Asymmetric; + ov::op::internal::DynamicQuantize::Attributes dq_config; + dq_config.quantization_type = ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric; dq_config.quantization_dt = ov::element::i8; dq_config.scale_dt = ov::element::f16; dq_config.zp_dt = ov::element::f16; dq_config.group_sizes = { 1, 1, 1, UINT64_MAX }; - auto output_storage_type = ov::op::internal::DynamicQuantize::OutputStorageType::InterleavedScalesZP; + dq_config.scales_zp_output_order = { 0, 1, 2, 3 }; + dq_config.output_storage_type = ov::op::internal::DynamicQuantize::OutputStorageType::InterleavedScalesZP; auto query = std::make_shared(element_type, input_shape); auto beam_idx = std::make_shared(ov::element::i32, ov::PartialShape{1}); @@ -271,17 +266,15 @@ TEST_F(TransformationTestsF, KVCacheCompressionWithInitializers) { auto key_variable = std::make_shared(ov::op::util::VariableInfo{{1, 32, -1, 80}, ov::element::f16, "v0"}); auto key_initializer_dq = - std::make_shared(key_variable_initializer, dq_config, output_storage_type, scales_zp_output_order); + std::make_shared(key_variable_initializer, dq_config); auto key_past_initializers = ov::OutputVector{ key_initializer_dq->output(0), key_initializer_dq->output(1) }; auto key_past_compressed = 
std::make_shared(key_past_initializers, key_variable, key_past_variable_infos); auto key_cache_inputs = ov::OutputVector{ key_past_compressed->output(0), key_current, beam_idx, key_past_compressed->output(1) }; - auto key_cache = std::make_shared(key_cache_inputs, - key_variable, - concat_axis, - gather_axis, - combine_scales_and_zp, - dq_config, - scales_zp_output_order); + auto key_cache = std::make_shared(key_cache_inputs, + key_variable, + concat_axis, + gather_axis, + dq_config); auto value_past_variable_infos = { ov::op::util::VariableInfo{{1, 32, -1, 80}, ov::element::i8, "v1"}, ov::op::util::VariableInfo{{1, 32, -1, 2}, ov::element::f16, "v1"} }; @@ -291,17 +284,15 @@ TEST_F(TransformationTestsF, KVCacheCompressionWithInitializers) { auto value_variable = std::make_shared(ov::op::util::VariableInfo{{1, 32, -1, 80}, ov::element::f16, "v1"}); auto value_initializer_dq = - std::make_shared(value_variable_initializer, dq_config, output_storage_type, scales_zp_output_order); + std::make_shared(value_variable_initializer, dq_config); auto value_past_initializers = ov::OutputVector{ value_initializer_dq->output(0), value_initializer_dq->output(1) }; auto value_past_compressed = std::make_shared(value_past_initializers, value_variable, value_past_variable_infos); auto value_cache_inputs = ov::OutputVector{ value_past_compressed->output(0), value_current, beam_idx, value_past_compressed->output(1) }; - auto value_cache = std::make_shared(value_cache_inputs, - value_variable, - concat_axis, - gather_axis, - combine_scales_and_zp, - dq_config, - scales_zp_output_order); + auto value_cache = std::make_shared(value_cache_inputs, + value_variable, + concat_axis, + gather_axis, + dq_config); ov::ParameterVector params{ beam_idx, query, key_current, value_current, key_variable_initializer, value_variable_initializer }; @@ -338,8 +329,7 @@ TEST_F(TransformationTestsF, KVCacheCompressionWithInitializers) { qkv_order, qkv_order, ov::intel_gpu::op::SDPA::default_order(4), - 
dq_config, - combine_scales_and_zp); + dq_config); auto result = std::make_shared(sdpa);