Skip to content

Commit

Permalink
WIP: zp support
Browse files Browse the repository at this point in the history
  • Loading branch information
sshlyapn committed Oct 16, 2024
1 parent 08cafc3 commit 775d01a
Show file tree
Hide file tree
Showing 27 changed files with 243 additions and 49 deletions.
11 changes: 10 additions & 1 deletion src/common/transformations/include/ov_ops/dynamic_quantize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,18 @@ class TRANSFORMATIONS_API DynamicQuantize : public ov::op::Op {
public:
OPENVINO_OP("DynamicQuantize", "gpu_opset");

// Quantization scheme produced by this op.
enum class QuantizationMode {
Asymmetric, // per-group scale plus zero-point
Symmetric // per-group scale only; zero-point is implicitly zero
};

DynamicQuantize() = default;
/// \brief Constructs a DynamicQuantize operation.
///
/// \param data Input tensor with data
/// \param group_sizes Group sizes for dynamic quantization
/// \param dt_scale Data type for scale output
/// \param mode Quantization mode (Symmetric or Asymmetric)
/// \param scales_output_order Optional dimension order for the scales output
DynamicQuantize(const Output<Node>& data, std::vector<uint64_t> group_sizes, element::Type dt_scale, std::vector<uint64_t> scales_output_order = {});
DynamicQuantize(const Output<Node>& data, std::vector<uint64_t> group_sizes, element::Type dt_scale, QuantizationMode mode, std::vector<uint64_t> scales_output_order = {});

void validate_and_infer_types() override;

Expand All @@ -33,12 +38,16 @@ class TRANSFORMATIONS_API DynamicQuantize : public ov::op::Op {
const std::vector<uint64_t>& get_scales_output_order() const {
return m_scales_output_order;
};
// Returns the quantization mode (Symmetric or Asymmetric) this op was constructed with.
QuantizationMode get_quantization_mode() const {
return m_mode;
};
static std::vector<ov::PartialShape> shape_infer(const DynamicQuantize* op,
const std::vector<ov::PartialShape>& input_shapes,
const std::vector<uint64_t>& group_sizes,
const std::vector<uint64_t>& scales_output_order = {});

private:
QuantizationMode m_mode;
std::vector<uint64_t> m_group_sizes;
std::vector<uint64_t> m_scales_output_order;
element::Type m_dt_scale;
Expand Down
5 changes: 3 additions & 2 deletions src/common/transformations/src/ov_ops/dynamic_quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,9 @@ namespace ov {
namespace op {
namespace internal {

DynamicQuantize::DynamicQuantize(const Output<Node>& data, std::vector<uint64_t> group_sizes, element::Type dt_scale, std::vector<uint64_t> scales_output_order)
DynamicQuantize::DynamicQuantize(const Output<Node>& data, std::vector<uint64_t> group_sizes, element::Type dt_scale, QuantizationMode mode, std::vector<uint64_t> scales_output_order)
: Op({data}),
m_mode(mode),
m_group_sizes(std::move(group_sizes)),
m_scales_output_order(std::move(scales_output_order)),
m_dt_scale(dt_scale) {
Expand All @@ -38,7 +39,7 @@ void DynamicQuantize::validate_and_infer_types() {

// Creates a copy of this op bound to \p new_args.
// Forwards every attribute — group sizes, scale data type, quantization mode AND
// the scales output order — so the clone is fully equivalent to the original.
// (Previously m_scales_output_order was dropped, silently resetting it to {} on clone.)
std::shared_ptr<Node> DynamicQuantize::clone_with_new_inputs(const ov::OutputVector& new_args) const {
    check_new_args_count(this, new_args);
    return std::make_shared<DynamicQuantize>(new_args.at(0), m_group_sizes, m_dt_scale, m_mode, m_scales_output_order);
}

std::vector<ov::PartialShape> DynamicQuantize::shape_infer(const DynamicQuantize* op,
Expand Down
3 changes: 3 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/indirect_sdpa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ class IndirectSDPA : public ov::intel_gpu::op::SDPA {

std::shared_ptr<Node> clone_with_new_inputs(const ov::OutputVector& new_args) const override;

// Marks whether the compressed KV-cache uses asymmetric quantization
// (i.e. zero-points accompany the scales). NOTE(review): consider the more
// explicit name set_asymmetric_quantization used by KVCache for consistency.
void set_asym(bool val) { m_is_asym_compressed = val; }
// Returns true if KV-cache compression is asymmetric.
bool get_asym() const { return m_is_asym_compressed; }

ov::element::Type get_output_type() const { return m_output_type; }

int64_t get_indirect_axis() const { return m_indirect_axis; }
Expand Down
9 changes: 9 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/kv_cache.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,13 +70,22 @@ class KVCache : public ov::op::Op, public ov::op::util::VariableExtension {
const std::vector<uint64_t>& get_group_sizes() const { return m_group_sizes; };
const std::vector<uint64_t>& get_scales_output_order() const { return m_scales_output_order; };

// Whether KV-cache compression uses asymmetric quantization (zero-points
// stored alongside scales) rather than symmetric (scales only).
bool get_asymmetric_quantization() const {
return m_use_asymmetric_quantization;
}
// Enables/disables asymmetric quantization for the compressed KV-cache.
void set_asymmetric_quantization(bool val) {
m_use_asymmetric_quantization = val;
}

private:
int64_t m_concat_axis = 0;
int64_t m_gather_axis = 0;
bool m_indirect = false;

// KV-cache compression parameters
// TODO: move these parameters to separate structure
bool m_compressed = false;
bool m_use_asymmetric_quantization = false;
std::vector<uint64_t> m_group_sizes = {};
std::vector<uint64_t> m_scales_output_order = {};
ov::element::Type m_compression_type = ov::element::undefined;
Expand Down
1 change: 1 addition & 0 deletions src/plugins/intel_gpu/include/intel_gpu/op/sdpa.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ class SDPA : public ov::op::v13::ScaledDotProductAttention {
protected:
bool m_is_causal;
bool m_is_kv_compressed;
bool m_is_asym_compressed;
std::vector<int64_t> m_order_q;
std::vector<int64_t> m_order_k;
std::vector<int64_t> m_order_v;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ struct dynamic_quantize : public primitive_base<dynamic_quantize> {
group_sizes(group_sizes),
scales_output_order(scales_output_order) {}

bool use_asymmetric_quantization = false;
std::vector<uint64_t> group_sizes;
std::vector<uint64_t> scales_output_order;

Expand All @@ -50,6 +51,7 @@ struct dynamic_quantize : public primitive_base<dynamic_quantize> {
void save(BinaryOutputBuffer& ob) const override {
primitive_base<dynamic_quantize>::save(ob);
ob << group_sizes;
// TODO: add more parameters
}

void load(BinaryInputBuffer& ib) override {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ struct kv_cache : public primitive_base<kv_cache> {
int64_t gather_axis = 0;
bool indirect = false;
bool compressed = false;
bool use_asymmetric_quantization = false;
std::vector<uint64_t> group_sizes = {};
std::vector<uint64_t> scales_output_order = {};
ov::element::Type compression_type = ov::element::undefined;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a

bool is_causal = false;
bool is_kv_compressed = false;
bool is_asym_compressed = false;
bool has_attn_mask_input = false;
bool has_scale_input = false;
int64_t indirect_axis = -1;
Expand Down Expand Up @@ -95,6 +96,7 @@ struct scaled_dot_product_attention : public primitive_base<scaled_dot_product_a
ob << input_k_transpose_order;
ob << input_v_transpose_order;
ob << output_transpose_order;
// TODO: add new params
}

void load(BinaryInputBuffer& ib) override {
Expand Down
11 changes: 8 additions & 3 deletions src/plugins/intel_gpu/src/graph/dynamic_quantize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ layout dynamic_quantize_inst::calc_output_layout(dynamic_quantize_node const& no
}

template<typename ShapeType>
std::vector<layout> dynamic_quantize_inst::__calc_output_layouts(const layout &act_layout, const std::vector<uint64_t>& group_sizes, const std::vector<uint64_t>& scales_output_order) {
std::vector<layout> dynamic_quantize_inst::__calc_output_layouts(const layout &act_layout, const std::vector<uint64_t>& group_sizes, const std::vector<uint64_t>& scales_output_order, bool use_asymmetric_quantization) {
ov::op::internal::DynamicQuantize op;
auto output_format = act_layout.format;

Expand All @@ -44,16 +44,21 @@ std::vector<layout> dynamic_quantize_inst::__calc_output_layouts(const layout &a
auto output_shapes = ov::op::internal::DynamicQuantize::shape_infer(&op, input_shapes, group_sizes, scales_output_order);
GPU_DEBUG_TRACE_DETAIL << "shape infer dynamic" << output_shapes[0] << " " << output_shapes[1] << "\n";

if (use_asymmetric_quantization) {
output_shapes[1][3] *= 2;
}

return { layout(output_shapes[0], data_types::i8, output_format), layout(output_shapes[1], data_types::f16, output_format) };
}

template std::vector<layout> dynamic_quantize_inst::__calc_output_layouts<ov::PartialShape>(const layout &act_layout, const std::vector<uint64_t>& group_sizes, const std::vector<uint64_t>& scales_output_order);
template std::vector<layout> dynamic_quantize_inst::__calc_output_layouts<ov::PartialShape>(const layout &act_layout, const std::vector<uint64_t>& group_sizes, const std::vector<uint64_t>& scales_output_order, bool use_asymmetric_quantization);

// Computes the output layouts (quantized data + scales) for a dynamic_quantize node.
// Delegates to the shared helper __calc_output_layouts, which is also used from
// fake-alignment; the asymmetric-quantization flag lets the helper widen the
// scales layout to hold interleaved zero-points.
// (Removed an unreachable duplicate return statement left over from a merge.)
template<typename ShapeType>
std::vector<layout> dynamic_quantize_inst::calc_output_layouts(dynamic_quantize_node const& /*node*/, const kernel_impl_params& impl_param) {
    auto desc = impl_param.typed_desc<dynamic_quantize>();
    const auto& input_layout = impl_param.get_input_layout();
    return __calc_output_layouts<ov::PartialShape>(input_layout, desc->group_sizes, desc->scales_output_order, desc->use_asymmetric_quantization);
}

template std::vector<layout> dynamic_quantize_inst::calc_output_layouts<ov::PartialShape>(dynamic_quantize_node const& node,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ struct dynamic_quantize_impl : typed_primitive_impl_ocl<dynamic_quantize> {
const auto& desc = impl_param.typed_desc<dynamic_quantize>();
params.group_sizes = desc->group_sizes;
params.scales_output_order = desc->scales_output_order;
params.use_asymmetric_quantization = desc->use_asymmetric_quantization;

return params;
}
Expand Down
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,6 +411,8 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
params.append_axis = primitive->concat_axis;
params.group_sizes = primitive->group_sizes;
params.scales_output_order = primitive->scales_output_order;
params.use_asymmetric_quantization = primitive->use_asymmetric_quantization;
params.group_scales_with_zp = true;

if (!is_shape_agnostic) {
const auto& past_kv_cache_shape = impl_param.input_layouts[0].get_partial_shape();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,10 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod

config.is_causal = desc->is_causal;
config.is_kv_compressed = desc->is_kv_compressed;
config.is_asym_compressed = desc->is_asym_compressed;

GPU_DEBUG_TRACE << "Set is_kv_compressed to " << config.is_kv_compressed << "\n";
GPU_DEBUG_TRACE << "Set is_asym_compressed to " << config.is_asym_compressed << "\n";

return config;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class typed_primitive_inst<dynamic_quantize> : public typed_primitive_inst_base<

// Internal function to be used from fakealignment
template<typename ShapeType>
static std::vector<layout> __calc_output_layouts(const layout &act_layout, const std::vector<uint64_t>& group_size, const std::vector<uint64_t>& scales_output_order);
static std::vector<layout> __calc_output_layouts(const layout &act_layout, const std::vector<uint64_t>& group_size, const std::vector<uint64_t>& scales_output_order, bool use_asymmetric_quantization);
static std::string to_string(dynamic_quantize_node const& node);

typed_primitive_inst(network& network, dynamic_quantize_node const& node);
Expand Down
8 changes: 8 additions & 0 deletions src/plugins/intel_gpu/src/graph/kv_cache.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,17 @@ std::vector<layout> kv_cache_inst::calc_output_layouts(kv_cache_node const& node
// input_shapes.push_back(impl_param.get_input_layout(4).get<ShapeType>());
}

if (desc->compressed && desc->use_asymmetric_quantization) {
input_shapes[3][3] /= 2;
}

std::vector<ShapeType> output_shapes = desc->compressed ? shape_infer(&op, input_shapes, desc->group_sizes, desc->scales_output_order)
: shape_infer(&op, input_shapes);

if (desc->compressed && desc->use_asymmetric_quantization) {
output_shapes[2][3] *= 2;
}

if (desc->num_outputs == 3)
GPU_DEBUG_TRACE_DETAIL << desc->id << " scales output calculated shape: " << output_shapes[2] << "\n";

Expand Down
2 changes: 1 addition & 1 deletion src/plugins/intel_gpu/src/graph/primitive_inst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -688,7 +688,7 @@ event::ptr primitive_inst::realloc_if_needed() {
// dynamic quantization is only applied to activation of FC
if (get_node().is_type<dynamic_quantize>()) {
const auto& desc = get_node().as<dynamic_quantize>().get_primitive();
auto dyn_quan_scale_layout = dynamic_quantize_inst::__calc_output_layouts<ov::PartialShape>(updated_layouts[dep_idx], desc->group_sizes, desc->scales_output_order);
auto dyn_quan_scale_layout = dynamic_quantize_inst::__calc_output_layouts<ov::PartialShape>(updated_layouts[dep_idx], desc->group_sizes, desc->scales_output_order, desc->use_asymmetric_quantization);
GPU_DEBUG_TRACE_DETAIL << "update layout of dynamic quantize scale parameter layout "
<< dyn_quan_scale_layout[1].to_short_string() << std::endl;
updated_params.output_layouts[1] = dyn_quan_scale_layout[1];
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,7 @@ inline uint FUNC(get_scales_offset_nt)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, u
return OUTPUT1_GET_INDEX(b, f, y, x);
}

inline uint FUNC(get_scales_offset)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint y, uint x, uint axis_offset) {
#ifdef APPEND_MODE
APPEND_AXIS_NAME += axis_offset;
#endif
inline uint FUNC(get_scales_offset)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint y, uint x) {
#ifdef SCALES_OUTPUT_ORDER
return FUNC_CALL(get_scales_offset_nt)(OPTIONAL_SHAPE_INFO_TENSOR SCALES_OUTPUT_ORDER);
#else
Expand All @@ -49,6 +46,9 @@ KERNEL(dynamic_quantize_gpu_opt_generic)(
const __global INPUT0_TYPE* input,
__global OUTPUT_TYPE* output,
__global OUTPUT1_TYPE* output_scale
#if ASYMMETRIC_QUANTIZATION && !GROUP_SCALES_WITH_ZP
, __global OUTPUT1_TYPE* output_zp
#endif
#ifdef APPEND_MODE
, const uint axis_offset
#endif
Expand All @@ -64,42 +64,71 @@ KERNEL(dynamic_quantize_gpu_opt_generic)(
// the innermost dimension is always handled in the loop inside the kernel
const uint x = 0;

half max_value = 0.0001h;
half max_value = INPUT0_VAL_MIN;
half min_value = INPUT0_VAL_MAX;

half val[INNERMOST_DIM_VALUE / SUBGROUP_SIZE];

const uint input_offset = INPUT0_GET_INDEX(b, f, y, x);
unroll_for (uint i = 0; i < INNERMOST_DIM_VALUE / SUBGROUP_SIZE; i++) {
val[i] = INPUT_BLOCK_READ(input, input_offset + i * SUBGROUP_SIZE);
#if ASYMMETRIC_QUANTIZATION
max_value = fmax(max_value, val[i]);
min_value = fmin(min_value, val[i]);
#else
max_value = fmax(max_value, fabs(val[i]));
#endif
}

#if ASYMMETRIC_QUANTIZATION
min_value = work_group_reduce_min(min_value);
max_value = work_group_reduce_max(max_value);

half scale = 127.0h / max_value;
OUTPUT1_TYPE scale = (OUTPUT1_TYPE)((CHAR_MAX - CHAR_MIN) / (max_value - min_value));
OUTPUT1_TYPE zp = (OUTPUT1_TYPE)(-min_value * scale) - CHAR_MAX;
#else
max_value = work_group_reduce_max(max_value);
OUTPUT1_TYPE scale = 127.0h / max_value;
#endif

#ifdef APPEND_MODE
APPEND_AXIS_NAME += axis_offset;
#endif

const uint output_offset = OUTPUT_GET_INDEX(b, f, y, x);
unroll_for (uint i = 0; i < INNERMOST_DIM_VALUE / SUBGROUP_SIZE; i++) {
OUTPUT_BLOCK_WRITE(output, output_offset + i * SUBGROUP_SIZE, convert_char(val[i] * scale));
}

#ifdef APPEND_MODE
// const uint scale_axis_offset = axis_offset;
const uint scale_axis_offset = 0;
#if ASYMMETRIC_QUANTIZATION
OUTPUT_TYPE res = convert_char(val[i] * scale + zp);
#else
const uint scale_axis_offset = 0;
OUTPUT_TYPE res = convert_char(val[i] * scale);
#endif
const uint scale_idx = FUNC_CALL(get_scales_offset)(OPTIONAL_SHAPE_INFO_TENSOR b, f, y, x, scale_axis_offset);
OUTPUT_BLOCK_WRITE(output, output_offset + i * SUBGROUP_SIZE, res);
}

const uint scale_idx = FUNC_CALL(get_scales_offset)(OPTIONAL_SHAPE_INFO_TENSOR b, f, y, x);

if (grouped_indexes == 0 && sglid == 0) {
#ifdef APPEND_MODE
// if (axis_offset > 0) {
// printf("Save scale_idx=%d, axis_offset=%d; output=%p, scale=%p; val=%f\n", scale_idx, axis_offset, output, output_scale, 1.0h / scale);
// }
#if GROUP_SCALES_WITH_ZP
// half result0 = (convert_half(convert_char(val[0] * scale + zp)) - zp) * (1.0h / scale);
// half result1 = (convert_half(convert_char(val[1] * scale + zp)) - zp) * (1.0h / scale);
// half result2 = (convert_half(convert_char(val[2] * scale + zp)) - zp) * (1.0h / scale);
// half result3 = (convert_half(convert_char(val[3] * scale + zp)) - zp) * (1.0h / scale);
// printf("Save scale_idx=%d, axis_offset=%d; scale=%f; zp=%f, min=%f, max=%f; orig=(%f %f %f %f), compressed=(%d %d %d %d), decompressed=(%f %f)\n", scale_idx, axis_offset, scale, zp, min_value, max_value,
// val[0], val[1], val[2], val[3],
// convert_char(val[0] * scale + zp), convert_char(val[1] * scale + zp), convert_char(val[2] * scale + zp), convert_char(val[3] * scale + zp),
// result0,
// result1);
#endif
#endif
#if ASYMMETRIC_QUANTIZATION
output_scale[scale_idx] = 1.0h / scale;
#if GROUP_SCALES_WITH_ZP
output_scale[scale_idx + 1] = zp;
#else
output_zp[scale_idx] = zp;
#endif
#else
output_scale[scale_idx] = 1.0h / scale;
#endif
}
}
Loading

0 comments on commit 775d01a

Please sign in to comment.