From b840082ac11b1608f349d9554b020498c328164f Mon Sep 17 00:00:00 2001 From: Mingyu Kim Date: Mon, 9 Dec 2024 14:09:30 +0900 Subject: [PATCH 1/8] [GPU] Integrate dynamic quantization for onednn (#26940) ### Details: - Integrated grouped dynamic quantization from onednn - Integrated asymmetric per-token dynamic quantization from onednn - These are not yet enabled by default ### Tickets: - 148732, 157869, 157589 --- .../op/fully_connected_compressed.hpp | 1 + .../intel_gpu/primitives/dynamic_quantize.hpp | 13 +- .../intel_gpu/primitives/fully_connected.hpp | 18 +++ .../intel_gpu/runtime/debug_configuration.hpp | 1 + .../prepare_primitive_fusing.cpp | 2 + .../src/graph/impls/ocl/dynamic_quantize.cpp | 8 +- .../impls/onednn/fully_connected_onednn.cpp | 47 +++++-- .../impls/onednn/fully_connected_onednn.hpp | 2 +- .../cl_kernels/dynamic_quantize_gpu_opt.cl | 133 ++++++++++++++++-- .../cl_kernels/dynamic_quantize_gpu_ref.cl | 50 ++++--- .../dynamic_quantize_kernel_opt.cpp | 56 +++++--- .../dynamic_quantize_kernel_ref.cpp | 18 ++- .../fully_connected_kernel_bf_tiled.cpp | 20 +-- .../src/plugin/ops/dynamic_quantize.cpp | 3 +- .../src/plugin/ops/fully_connected.cpp | 4 +- .../intel_gpu/src/plugin/program_builder.cpp | 4 + .../dynamic_quantize_fully_connected.cpp | 30 ++-- .../op/fully_connected_compressed.cpp | 5 +- .../src/plugin/transformations_pipeline.cpp | 22 ++- .../src/runtime/debug_configuration.cpp | 3 + .../src/runtime/execution_config.cpp | 7 +- .../dynamic/matmul_weights_decompression.cpp | 33 +++-- .../test_cases/dynamic_quantize_gpu_test.cpp | 61 +++++--- .../test_cases/fully_connected_gpu_test.cpp | 24 ++-- .../unit/test_cases/hash_key_gpu_test.cpp | 8 +- 25 files changed, 420 insertions(+), 153 deletions(-) diff --git a/src/plugins/intel_gpu/include/intel_gpu/op/fully_connected_compressed.hpp b/src/plugins/intel_gpu/include/intel_gpu/op/fully_connected_compressed.hpp index 1112a3785317a3..e58c6ab4cb17f1 100644 --- 
a/src/plugins/intel_gpu/include/intel_gpu/op/fully_connected_compressed.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/op/fully_connected_compressed.hpp @@ -22,6 +22,7 @@ class FullyConnectedCompressed : public FullyConnected { const ov::Output &w_decompression_scale, const ov::Output &w_decompression_zero_point, const ov::Output &a_decompression_scale, + const ov::Output &a_decompression_zero_point, const ov::element::Type output_type = ov::element::undefined); diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/dynamic_quantize.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/dynamic_quantize.hpp index 79af223e32cdaa..8dd1ebf2809782 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/dynamic_quantize.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/dynamic_quantize.hpp @@ -26,9 +26,11 @@ struct dynamic_quantize : public primitive_base { /// @param output_size Output data size of the primitive dynamic_quantize(const primitive_id& id, const input_info& input, - const Attributes& attrs) + const Attributes& attrs, + const size_t input_size = 3) : primitive_base(id, {input}) - , attrs(attrs) { + , attrs(attrs) + , input_size(input_size) { num_outputs = 2; if (attrs.quantization_type == ov::op::internal::DynamicQuantize::QuantizationType::Asymmetric && attrs.output_storage_type == ov::op::internal::DynamicQuantize::OutputStorageType::Planar) @@ -36,6 +38,7 @@ struct dynamic_quantize : public primitive_base { } Attributes attrs; + size_t input_size; size_t hash() const override { size_t seed = primitive::hash(); @@ -46,6 +49,7 @@ struct dynamic_quantize : public primitive_base { seed = hash_combine(seed, attrs.scale_dt.hash()); seed = hash_combine(seed, attrs.zp_dt.hash()); seed = hash_combine(seed, attrs.output_storage_type); + seed = hash_combine(seed, input_size); return seed; } @@ -62,7 +66,8 @@ struct dynamic_quantize : public primitive_base { attrs.quantization_dt == rhs_casted.attrs.quantization_dt && attrs.scale_dt 
== rhs_casted.attrs.scale_dt && attrs.zp_dt == rhs_casted.attrs.zp_dt && - attrs.quantization_type == rhs_casted.attrs.quantization_type;; + attrs.quantization_type == rhs_casted.attrs.quantization_type && + input_size == rhs_casted.input_size; } void save(BinaryOutputBuffer& ob) const override { @@ -75,6 +80,7 @@ struct dynamic_quantize : public primitive_base { ob << make_data(&attrs.output_storage_type, sizeof(attrs.output_storage_type)); ob << attrs.scales_zp_output_order; ob << attrs.group_sizes; + ob << input_size; } void load(BinaryInputBuffer& ib) override { @@ -87,6 +93,7 @@ struct dynamic_quantize : public primitive_base { ib >> make_data(&attrs.output_storage_type, sizeof(attrs.output_storage_type)); ib >> attrs.scales_zp_output_order; ib >> attrs.group_sizes; + ib >> input_size; } }; } // namespace cldnn diff --git a/src/plugins/intel_gpu/include/intel_gpu/primitives/fully_connected.hpp b/src/plugins/intel_gpu/include/intel_gpu/primitives/fully_connected.hpp index e39078cb1011cc..0819a39534696d 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/primitives/fully_connected.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/primitives/fully_connected.hpp @@ -96,6 +96,7 @@ struct fully_connected : public primitive_base { decompression_scale(decompression_scale), decompression_zero_point(decompression_zero_point), dynamic_quantized_activation(false), + dynamic_quantized_activation_zp(false), input_size(input_size), weights_rank(weights_rank) { OPENVINO_ASSERT(!decompression_scale.empty(), "[GPU] Compressed fully connected requires at least decompression scale input"); @@ -109,6 +110,7 @@ struct fully_connected : public primitive_base { /// @param compression_scale Primitive id containing scale factors for weights decompression. /// @param compression_zero_point Primitive id containing zero points for weights decompression. /// @param activation_scale Primitive id containing scale factor for activation. 
+ /// @param activation_zero_point Primitive id containing zero point for activation. fully_connected(const primitive_id& id, const input_info& input, const primitive_id& weights, @@ -116,6 +118,7 @@ struct fully_connected : public primitive_base { const primitive_id& decompression_scale, const primitive_id& decompression_zero_point, const input_info& activation_scale, + const input_info& activation_zero_point, const data_types data_type, const size_t input_size = 2, const size_t weights_rank = 2) @@ -126,11 +129,15 @@ struct fully_connected : public primitive_base { decompression_scale(decompression_scale), decompression_zero_point(decompression_zero_point), dynamic_quantized_activation(false), + dynamic_quantized_activation_zp(false), activation_scale(activation_scale), + activation_zero_point(activation_zero_point), input_size(input_size), weights_rank(weights_rank) { if (activation_scale.is_valid()) dynamic_quantized_activation = true; + if (activation_zero_point.is_valid()) + dynamic_quantized_activation_zp = true; OPENVINO_ASSERT(!decompression_scale.empty(), "[GPU] Compressed fully connected requires at least decompression scale input"); } @@ -144,7 +151,9 @@ struct fully_connected : public primitive_base { primitive_id decompression_scale = ""; primitive_id decompression_zero_point = ""; bool dynamic_quantized_activation = false; + bool dynamic_quantized_activation_zp = false; input_info activation_scale = {"", 0}; + input_info activation_zero_point = {"", 0}; optional_value decompression_zero_point_scalar = optional_value(); /// @brief Primitive dimension size. 
@@ -161,6 +170,7 @@ struct fully_connected : public primitive_base { seed = hash_combine(seed, !decompression_scale.empty()); seed = hash_combine(seed, !decompression_zero_point.empty()); seed = hash_combine(seed, activation_scale.is_valid()); + seed = hash_combine(seed, activation_zero_point.is_valid()); seed = hash_combine(seed, decompression_zero_point_scalar.has_value()); seed = hash_combine(seed, decompression_zero_point_scalar.value_or(0.0f)); return seed; @@ -179,6 +189,7 @@ struct fully_connected : public primitive_base { decompression_scale.empty() == rhs_casted.decompression_scale.empty() && decompression_zero_point.empty() == rhs_casted.decompression_zero_point.empty() && activation_scale.is_valid() == rhs_casted.activation_scale.is_valid() && + activation_zero_point.is_valid() == rhs_casted.activation_zero_point.is_valid() && decompression_zero_point_scalar.value_or(0.0f) == rhs_casted.decompression_zero_point_scalar.value_or(0.0f); } @@ -190,9 +201,11 @@ struct fully_connected : public primitive_base { ob << decompression_scale; ob << decompression_zero_point; ob << activation_scale; + ob << activation_zero_point; ob << input_size; ob << weights_rank; ob << dynamic_quantized_activation; + ob << dynamic_quantized_activation_zp; if (decompression_zero_point_scalar.has_value()) { ob << true; @@ -211,9 +224,11 @@ struct fully_connected : public primitive_base { ib >> decompression_scale; ib >> decompression_zero_point; ib >> activation_scale; + ib >> activation_zero_point; ib >> input_size; ib >> weights_rank; ib >> dynamic_quantized_activation; + ib >> dynamic_quantized_activation_zp; bool has_value; ib >> has_value; @@ -243,6 +258,9 @@ struct fully_connected : public primitive_base { if (activation_scale.is_valid()) ret.push_back(activation_scale); + if (activation_zero_point.is_valid()) + ret.push_back(activation_zero_point); + return ret; } }; diff --git a/src/plugins/intel_gpu/include/intel_gpu/runtime/debug_configuration.hpp 
b/src/plugins/intel_gpu/include/intel_gpu/runtime/debug_configuration.hpp index a7a8ae1f229a72..52d828353fa155 100644 --- a/src/plugins/intel_gpu/include/intel_gpu/runtime/debug_configuration.hpp +++ b/src/plugins/intel_gpu/include/intel_gpu/runtime/debug_configuration.hpp @@ -146,6 +146,7 @@ class debug_configuration { std::vector dynamic_quantize_layers_without_onednn; // Specify Fully-connected layers which enable Dynamic quantization int use_kv_cache_compression; // Enable KV-cache compression int dynamic_quantize_group_size; // Enable Dynamic quantization for fully connected primitive by specified group size + int dynamic_quantize_asym; // Use asymmetric dynamic quantization int disable_horizontal_fc_fusion; // Disable fc horizontal fusion int disable_fc_swiglu_fusion; // Disable swiglu fusion to fc std::set dump_iteration; // Dump n-th execution of network. diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp index 29b7cf58a19b54..93f0905b3a1ef7 100644 --- a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp +++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_primitive_fusing.cpp @@ -463,7 +463,9 @@ void prepare_primitive_fusing::fuse_bias(program &p) { if (desc->decompression_zero_point_scalar.has_value()) fc_with_bias_prim->decompression_zero_point_scalar = desc->decompression_zero_point_scalar.value(); fc_with_bias_prim->activation_scale = desc->activation_scale; + fc_with_bias_prim->activation_zero_point = desc->activation_zero_point; fc_with_bias_prim->dynamic_quantized_activation = desc->dynamic_quantized_activation; + fc_with_bias_prim->dynamic_quantized_activation_zp = desc->dynamic_quantized_activation_zp; } auto& new_fc_node = p.get_or_create(fc_with_bias_prim); fuse_bias_f(fc, new_fc_node, bias_node, eltw_node); diff --git a/src/plugins/intel_gpu/src/graph/impls/ocl/dynamic_quantize.cpp 
b/src/plugins/intel_gpu/src/graph/impls/ocl/dynamic_quantize.cpp index b9fe00ac525720..ca628a48ac76e0 100644 --- a/src/plugins/intel_gpu/src/graph/impls/ocl/dynamic_quantize.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/ocl/dynamic_quantize.cpp @@ -35,6 +35,7 @@ struct dynamic_quantize_impl : typed_primitive_impl_ocl { static kernel_params_t get_kernel_params(const kernel_impl_params& impl_param, bool is_shape_agnostic = false) { auto params = get_default_params(impl_param, is_shape_agnostic); + const auto& primitive = impl_param.typed_desc(); params.outputs.push_back(convert_data_tensor(impl_param.get_output_layout(1))); // In Some model, the feature size could be dynamic in input0. @@ -48,6 +49,10 @@ struct dynamic_quantize_impl : typed_primitive_impl_ocl { if (impl_param.output_layouts.size() > 2) params.outputs.push_back(convert_data_tensor(impl_param.get_output_layout(2))); + // Keep 2d data as bf layout + if (primitive->input_size == 2) + params.outputs[0] = params.outputs[0].FlattenFeatureAndSpatials(); + const auto& desc = impl_param.typed_desc(); params.group_sizes = desc->attrs.group_sizes; params.scales_output_order = desc->attrs.scales_zp_output_order; @@ -68,7 +73,8 @@ namespace detail { attach_dynamic_quantize_impl::attach_dynamic_quantize_impl() { auto types = { data_types::f16, - data_types::i8 + data_types::i8, + data_types::u8 }; auto formats = { diff --git a/src/plugins/intel_gpu/src/graph/impls/onednn/fully_connected_onednn.cpp b/src/plugins/intel_gpu/src/graph/impls/onednn/fully_connected_onednn.cpp index 6b93b279129812..6cca9848af3472 100644 --- a/src/plugins/intel_gpu/src/graph/impls/onednn/fully_connected_onednn.cpp +++ b/src/plugins/intel_gpu/src/graph/impls/onednn/fully_connected_onednn.cpp @@ -83,10 +83,16 @@ struct fully_connected_onednn : typed_primitive_onednn_impl { if (prim->activation_scale.is_valid()) { auto activation_scale_idx = idx++; auto act_scale_mem = instance.dep_memory_ptr(activation_scale_idx); - // TODO: handle 
group_size here - dnnl::memory::desc desc = onednn::layout_to_memory_desc(act_scale_mem->get_layout(), dnnl::memory::format_tag::a, true); + dnnl::memory::desc desc = onednn::layout_to_memory_desc(act_scale_mem->get_layout(), dnnl::memory::format_tag::ab, true); args.insert({DNNL_ARG_ATTR_SCALES | DNNL_ARG_SRC_0, act_scale_mem->get_onednn_memory(desc)}); } + + if (prim->activation_zero_point.is_valid()) { + auto activation_zp_idx = idx++; + auto act_zp_mem = instance.dep_memory_ptr(activation_zp_idx); + dnnl::memory::desc desc = onednn::layout_to_memory_desc(act_zp_mem->get_layout(), dnnl::memory::format_tag::ab, true); + args.insert({DNNL_ARG_ATTR_ZERO_POINTS | DNNL_ARG_SRC_0, act_zp_mem->get_onednn_memory(desc)}); + } } return args; @@ -245,6 +251,7 @@ struct fully_connected_onednn : typed_primitive_onednn_impl { ob << has_bias; ob << is_compressed; ob << prim->dynamic_quantized_activation; + ob << prim->dynamic_quantized_activation_zp; bool has_decompression_scale = !prim->decompression_scale.empty(); if (has_decompression_scale) { @@ -271,10 +278,12 @@ struct fully_connected_onednn : typed_primitive_onednn_impl { bool has_bias = false; bool is_compressed = false; bool dynamic_quantized_activation; + bool dynamic_quantized_activation_zp; ib >> input_size; ib >> has_bias; ib >> is_compressed; ib >> dynamic_quantized_activation; + ib >> dynamic_quantized_activation_zp; const kernel_impl_params* impl_params = reinterpret_cast(ib.getKernelImplParams()); auto prim = impl_params->typed_desc(); @@ -293,11 +302,12 @@ struct fully_connected_onednn : typed_primitive_onednn_impl { bool has_decompression_zp = !prim->decompression_zero_point.empty() || prim->decompression_zero_point_scalar.has_value(); auto& arg = impl_params->get_program().get_node(impl_params->desc->id).as(); - int idx = !arg.bias_term() ? 3 : 4; + int idx = !arg.bias_term() ? 
2 : 3; if (has_decompression_zp) { ib >> make_data(&_dzp_data_type, sizeof(dnnl::memory::data_type)); - auto dzp_layout = arg.get_dependency(idx++).get_output_layout(); + auto decompression_zp_idx = ++idx; + auto dzp_layout = arg.get_dependency(decompression_zp_idx).get_output_layout(); if (dzp_layout.count() == 1) { _attrs->set_zero_points(DNNL_ARG_WEIGHTS, COMMON, dnnl::memory::dims{}, _dzp_data_type); @@ -312,12 +322,17 @@ struct fully_connected_onednn : typed_primitive_onednn_impl { } if (dynamic_quantized_activation) { - // TODO: it supports per-token activation scale only + auto src_scale_idx = ++idx; auto partial_shape = impl_params->get_input_layout(0).get_partial_shape(); auto innermost_len = partial_shape[partial_shape.size() - 1].get_length(); - - auto act_scale_data_type = convert_data_type(impl_params->get_input_layout(idx).data_type); - _attrs->set_scales(DNNL_ARG_SRC, GROUPED, dnnl::memory::dims{1, innermost_len}, act_scale_data_type); + auto& src_scale_shape = impl_params->input_layouts[src_scale_idx].get_partial_shape(); + int src_scale_ngroups = src_scale_shape[src_scale_shape.size() - 1].get_length(); + int src_group_size = innermost_len / src_scale_ngroups; + + auto act_scale_data_type = convert_data_type(impl_params->get_input_layout(src_scale_idx).data_type); + _attrs->set_scales(DNNL_ARG_SRC, GROUPED, dnnl::memory::dims{1, src_group_size}, act_scale_data_type); + if (dynamic_quantized_activation_zp) + _attrs->set_zero_points(DNNL_ARG_SRC, GROUPED, dnnl::memory::dims{1, src_group_size}, dnnl::memory::data_type::u8); } if (is_compressed) { @@ -387,15 +402,21 @@ struct fully_connected_onednn : typed_primitive_onednn_impl { } if (prim->dynamic_quantized_activation) { - // Note: it supports per-token activation scale only - ++idx; - auto partial_shape = impl_params.input_layouts[0].get_partial_shape(); + auto src_scale_idx = ++idx; + auto& partial_shape = impl_params.input_layouts[0].get_partial_shape(); auto innermost_len = 
partial_shape[partial_shape.size() - 1].get_length(); + auto& src_scale_shape = impl_params.input_layouts[src_scale_idx].get_partial_shape(); + int src_scale_ngroups = src_scale_shape[src_scale_shape.size() - 1].get_length(); + int src_group_size = innermost_len / src_scale_ngroups; - auto act_scale_data_type = convert_data_type(impl_params.input_layouts[idx].data_type); - attr->set_scales(DNNL_ARG_SRC, GROUPED, dnnl::memory::dims{1, innermost_len}, act_scale_data_type); + auto act_scale_data_type = convert_data_type(impl_params.input_layouts[src_scale_idx].data_type); + attr->set_scales(DNNL_ARG_SRC, GROUPED, dnnl::memory::dims{1, src_group_size}, act_scale_data_type); + + if (prim->activation_zero_point.is_valid()) + attr->set_zero_points(DNNL_ARG_SRC, GROUPED, dnnl::memory::dims{1, src_group_size}, dnnl::memory::data_type::u8); } + auto prim_desc = get_matmul_primitive_descriptor(impl_params, impl_params.prog->get_engine(), prim->input_size, !prim->bias.empty(), *attr); diff --git a/src/plugins/intel_gpu/src/graph/impls/onednn/fully_connected_onednn.hpp b/src/plugins/intel_gpu/src/graph/impls/onednn/fully_connected_onednn.hpp index 17498831a542d1..62129866927ea4 100644 --- a/src/plugins/intel_gpu/src/graph/impls/onednn/fully_connected_onednn.hpp +++ b/src/plugins/intel_gpu/src/graph/impls/onednn/fully_connected_onednn.hpp @@ -48,7 +48,7 @@ struct FullyConnectedImplementationManager : public ImplementationManager { one_of(wei_dt, {data_types::i8, data_types::u8}) && one_of(out_dt, {data_types::f16, data_types::f32, data_types::i32, data_types::i8, data_types::u8}); bool compressed_case = fc_prim->compressed_weights && - one_of(in0_dt, {data_types::f16, data_types::f32, data_types::i8}) && + one_of(in0_dt, {data_types::f16, data_types::f32, data_types::i8, data_types::u8}) && one_of(wei_dt, {data_types::u8, data_types::i8, data_types::u4, data_types::i4}) && one_of(out_dt, {data_types::f16, data_types::f32, data_types::u8, data_types::i8}); if (!f16f16_case && 
!f32f32_case && !u8s8_case && !compressed_case) diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/dynamic_quantize_gpu_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/dynamic_quantize_gpu_opt.cl index 6db1790844e501..22c620d712770c 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/dynamic_quantize_gpu_opt.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/dynamic_quantize_gpu_opt.cl @@ -4,77 +4,180 @@ #include "include/batch_headers/fetch_data.cl" -#if OUTPUT_DIMS != 4 +#if OUTPUT_DIMS != 4 && OUTPUT_DIMS != 2 #error "dynamic_quantize_gpu_opt.cl: Unsupported output dimension" #endif #define VLOAD_N CAT(vload, VEC_SIZE) #define VSTORE_N CAT(vstore, VEC_SIZE) +#define CONVERT_UCHAR_N CAT(convert_uchar, VEC_SIZE) #define CONVERT_CHAR_N CAT(convert_char, VEC_SIZE) #define AS_TYPE_N_(type, n, x) as_##type##n(x) #define AS_TYPE_N(type, n, x) AS_TYPE_N_(type, n, x) #define AS_INPUT_TYPE_N(x) AS_TYPE_N(INPUT0_TYPE, VEC_SIZE, x) +#if QUANTIZE_GROUP_SIZE <= 128 + +#if ASYMMETRIC_QUANTIZATION +#error "UNIMPLMENTED: asymmetric quantization when group size is small" +#endif + +KERNEL(dynamic_quantize_gpu_opt)( + OPTIONAL_SHAPE_INFO_ARG + const __global INPUT0_TYPE* input, + __global OUTPUT_TYPE* output, + __global OUTPUT1_TYPE* output_scale + ) { + +#if OUTPUT_DIMS == 2 + const uint b = get_global_id(0); + const uint f_grp = get_global_id(1); + const uint input_offset = INPUT0_GET_INDEX(b, f_grp * QUANTIZE_GROUP_SIZE, 0, 0); + const uint output_offset = OUTPUT_GET_INDEX(b, f_grp * QUANTIZE_GROUP_SIZE, 0, 0); +#else + const uint bf = get_global_id(0); + const uint b = bf / INPUT0_FEATURE_NUM; + const uint f = bf % INPUT0_FEATURE_NUM; + const uint y_grp = get_global_id(1); + const uint input_offset = INPUT0_GET_INDEX(b, f, y_grp * QUANTIZE_GROUP_SIZE, 0); + const uint output_offset = OUTPUT_GET_INDEX(b, f, y_grp * QUANTIZE_GROUP_SIZE, 0); + +#endif + const uint quantize_block = QUANTIZE_GROUP_SIZE / 4; + half4 
input_0[quantize_block]; + char4 quantized_value[quantize_block]; + half max[quantize_block]; + + unroll_for (uint i = 0 ; i < quantize_block; ++i) { + input_0[i] = vload4(0, &input[input_offset + i * 4]); + max[i] = fmax(fmax(fabs(input_0[i][0]), fabs(input_0[i][1])), fmax(fabs(input_0[i][2]), fabs(input_0[i][3]))); + } + + half max_value = fmax(0.001h, max[0]); + for (uint i = 1; i < quantize_block; i++) { + max_value = fmax(max_value, max[i]); + } + + half quan_scale = 128.0h / max_value; + + unroll_for (uint i = 0 ; i < quantize_block; ++i) { + quantized_value[i] = convert_char4(input_0[i] * (half4)quan_scale); + vstore4(quantized_value[i], 0, &output[output_offset + i * 4]); + } + +#if OUTPUT_DIMS == 2 + output_scale[OUTPUT1_GET_INDEX(b, f_grp, 0, 0)] = 1.0h / quan_scale; +#else + output_scale[OUTPUT1_GET_INDEX(b, f, y_grp, 0)] = 1.0h / quan_scale; +#endif +} + +#else // !(QUANTIZE_GROUP_SIZE <= 128) + REQD_SUB_GROUP_SIZE(SIMD) KERNEL(dynamic_quantize_gpu_opt)( OPTIONAL_SHAPE_INFO_ARG const __global INPUT0_TYPE* input, __global OUTPUT_TYPE* output, - __global OUTPUT1_TYPE* output_scale) + __global OUTPUT1_TYPE* output_scale +#if ASYMMETRIC_QUANTIZATION + , __global OUTPUT2_TYPE* output_zp +#endif + ) { const uint bf = (uint)get_global_id(2); const uint sglid = get_sub_group_local_id(); const uint local_id = (uint)get_local_id(1); const uint block_size = SIMD * VEC_SIZE; +#if OUTPUT_DIMS == 2 + const uint b_offset = bf * INPUT0_BATCH_PITCH; +#else const uint b_offset = bf * INPUT0_FEATURE_PITCH; - +#endif const uint offset = b_offset + VEC_SIZE * sglid; const uint iteration = ALIGNED_BLOCK_NUM / BLOCK_NUM; - __local half local_mem[BLOCK_NUM]; + __local half local_mem_max[BLOCK_NUM]; + __local half local_mem_min[BLOCK_NUM]; MAKE_VECTOR_TYPE(INPUT0_TYPE, VEC_SIZE) val[iteration]; MAKE_VECTOR_TYPE(INPUT0_TYPE, VEC_SIZE) abs_val; - half max = 0.0h; half grp_max = 0.001h; - half max_value; + half grp_min = 0.001h; + half max_value = 0.0h; + half min_value = 0.0h; 
unroll_for(int i = 0; i < iteration; ++i) { if ((local_id * iteration + i) >= TOTAL_BLOCK_NUM) continue; val[i] = AS_INPUT_TYPE_N(VLOAD_N(0, input + offset + ((local_id * iteration + i) * block_size))); - abs_val = fabs(val[i]); - +#if ASYMMETRIC_QUANTIZATION unroll_for (int j = 0; j < VEC_SIZE; j++) { - max = fmax(max, abs_val[j]); + max_value = fmax(max_value, val[i][j]); + min_value = fmin(min_value, val[i][j]); } + grp_max = fmax(grp_max, max_value); + grp_min = fmin(grp_min, min_value); +#else + abs_val = fabs(val[i]); + + unroll_for (int j = 0; j < VEC_SIZE; j++) + max_value = fmax(max_value, abs_val[j]); - grp_max = fmax(grp_max, max); + grp_max = fmax(grp_max, max_value); +#endif } max_value = sub_group_reduce_max(grp_max); - if (sglid == 0) - local_mem[local_id] = max_value; +#if ASYMMETRIC_QUANTIZATION + min_value = sub_group_reduce_min(grp_min); +#endif + + if (sglid == 0) { + local_mem_max[local_id] = max_value; +#if ASYMMETRIC_QUANTIZATION + local_mem_min[local_id] = min_value; +#endif + } barrier(CLK_LOCAL_MEM_FENCE); for (int j = 0; j < BLOCK_NUM; j++) { - max_value = fmax(max_value, local_mem[j]); + max_value = fmax(max_value, local_mem_max[j]); +#if ASYMMETRIC_QUANTIZATION + min_value = fmin(min_value, local_mem_min[j]); +#endif } - half scale = 127.0h / max_value; +#if ASYMMETRIC_QUANTIZATION + OUTPUT1_TYPE scale = (OUTPUT1_TYPE)((CHAR_MAX - CHAR_MIN) / (max_value - min_value)); + OUTPUT2_TYPE zp = (OUTPUT2_TYPE)(-min_value * scale); +#else + OUTPUT1_TYPE scale = 127.0h / max_value; +#endif + unroll_for(int i = 0; i < iteration; ++i) { if ((local_id * iteration + i) >= TOTAL_BLOCK_NUM) continue; val[i] *= scale; +#if ASYMMETRIC_QUANTIZATION + val[i] += zp; + VSTORE_N(CAT(CONVERT_UCHAR_N, _rte)(val[i]), 0, output + offset + ((local_id * iteration + i) * block_size)); +#else VSTORE_N(CAT(CONVERT_CHAR_N, _rte)(val[i]), 0, output + offset + ((local_id * iteration + i) * block_size)); +#endif } - if (sglid == 0 && local_id == 0) + if (sglid == 0 && 
local_id == 0) { output_scale[bf] = 1.0h / scale; +#if ASYMMETRIC_QUANTIZATION + output_zp[bf] = convert_uchar_rte(zp); +#endif + } } +#endif // QUANTIZE_GROUP_SIZE <= 128 diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/dynamic_quantize_gpu_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/dynamic_quantize_gpu_ref.cl index 62482b8b9b5047..4acf87eb37ceb0 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/dynamic_quantize_gpu_ref.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/dynamic_quantize_gpu_ref.cl @@ -4,6 +4,16 @@ #include "include/batch_headers/fetch_data.cl" +#define UINT64_MAX 0xFFFFFFFFFFFFFFFF + +#if ASYMMETRIC_QUANTIZATION && UNSIGNED_OUTPUT + #define TO_OUTPUT_TYPE_RTE(val) convert_uchar_rte(val) + #define TO_OUTPUT_VEC_TYPE_RTE(val) convert_uchar8_rte(val) +#else + #define TO_OUTPUT_TYPE_RTE(val) convert_char_rte(val) + #define TO_OUTPUT_VEC_TYPE_RTE(val) convert_char8_rte(val) +#endif + #if OUTPUT_DIMS != 4 #error "dynamic_quantize_gpu_ref.cl: Unsupported output dimension" #endif @@ -33,19 +43,21 @@ KERNEL(dynamic_quantize_gpu_ref)( const uint bf = (uint)get_global_id(0); const uint b = bf / INPUT0_FEATURE_NUM; const uint f = bf % INPUT0_FEATURE_NUM; - const uint y = (uint)get_global_id(1); + const uint out_y = (uint)get_global_id(1); + const uint y = out_y * GROUP_SIZE_DIM2; // quantization may be grouped for y axis const uint x = (uint)get_global_id(2); #ifdef SCALES_OUTPUT_ORDER - const uint scale_idx = FUNC_CALL(get_scales_offset)(OPTIONAL_SHAPE_INFO_TENSOR b, f, y, x); + const uint scale_idx = FUNC_CALL(get_scales_offset)(OPTIONAL_SHAPE_INFO_TENSOR b, f, out_y, x); #else - const uint scale_idx = OUTPUT1_GET_INDEX_SAFE(b, f, y, x); + const uint scale_idx = OUTPUT1_GET_INDEX_SAFE(b, f, out_y, x); #endif half max_val = INPUT0_VAL_MIN; half min_val = INPUT0_VAL_MAX; for (int b_off = 0; b_off < (GROUP_SIZE_DIM0 == 1 ? 
1 : INPUT0_BATCH_NUM); b_off++) { for (int f_off = 0; f_off < (GROUP_SIZE_DIM1 == 1 ? 1 : INPUT0_FEATURE_NUM); f_off++) { - for (int y_off = 0; y_off < (GROUP_SIZE_DIM2 == 1 ? 1 : INPUT0_SIZE_Y); y_off++) { + for (int y_off = 0; y_off < (GROUP_SIZE_DIM2 == UINT64_MAX ? INPUT0_SIZE_Y : GROUP_SIZE_DIM2); y_off++) { + // It is assumed that grouped quantization happens only for 3d input case where we don't have x axis #if GROUP_SIZE_DIM3 == 1 const uint offset = INPUT0_GET_INDEX(b + b_off, f + f_off, y + y_off, x); half val = input[offset]; @@ -88,53 +100,49 @@ KERNEL(dynamic_quantize_gpu_ref)( #if ASYMMETRIC_QUANTIZATION OUTPUT1_TYPE scale = (OUTPUT1_TYPE)((CHAR_MAX - CHAR_MIN) / (max_val - min_val)); +# if UNSIGNED_OUTPUT + OUTPUT1_TYPE zp = (OUTPUT1_TYPE)(-min_val * scale); +# else // !UNSIGNED_OUTPUT OUTPUT1_TYPE zp = (OUTPUT1_TYPE)(-min_val * scale) - CHAR_MAX; -#else +# endif +#else // !ASYMMETRIC_QUANTIZATION max_val = work_group_reduce_max(max_val); OUTPUT1_TYPE scale = 127.0h / max_val; #endif for (int b_off = 0; b_off < (GROUP_SIZE_DIM0 == 1 ? 1 : INPUT0_BATCH_NUM); b_off++) { for (int f_off = 0; f_off < (GROUP_SIZE_DIM1 == 1 ? 1 : INPUT0_FEATURE_NUM); f_off++) { - for (int y_off = 0; y_off < (GROUP_SIZE_DIM2 == 1 ? 1 : INPUT0_SIZE_Y); y_off++) { + for (int y_off = 0; y_off < (GROUP_SIZE_DIM2 == UINT64_MAX ? 
INPUT0_SIZE_Y : GROUP_SIZE_DIM2); y_off++) { #if GROUP_SIZE_DIM3 == 1 const uint in_offset = INPUT0_GET_INDEX(b + b_off, f + f_off, y + y_off, x); const uint out_offset = OUTPUT_GET_INDEX(b + b_off, f + f_off, y + y_off, x); half val = input[in_offset]; -#if ASYMMETRIC_QUANTIZATION val *= scale; +#if ASYMMETRIC_QUANTIZATION val += zp; - output[out_offset] = convert_char_rte(val); -#else - val *= scale; - output[out_offset] = convert_char_rte(val); #endif + output[out_offset] = TO_OUTPUT_TYPE_RTE(val); #else const uint in_offset = INPUT0_GET_INDEX(b + b_off, f + f_off, y + y_off, 0); const uint out_offset = OUTPUT_GET_INDEX(b + b_off, f + f_off, y + y_off, 0); int x; for (x = 0; x < INPUT0_SIZE_X / 8; x++) { half8 val = as_half8(vload8(0, (ushort*)input + in_offset + x * 8)); -#if ASYMMETRIC_QUANTIZATION val *= scale; +#if ASYMMETRIC_QUANTIZATION val += zp; -#else - val *= scale; #endif - vstore8(convert_char8_rte(val), 0, output + out_offset + x * 8); + vstore8(TO_OUTPUT_VEC_TYPE_RTE(val), 0, output + out_offset + x * 8); } x *= 8; for (; x < INPUT0_SIZE_X; x++) { half val = input[in_offset + x]; -#if ASYMMETRIC_QUANTIZATION val *= scale; +#if ASYMMETRIC_QUANTIZATION val += zp; - output[out_offset + x] = convert_char_rte(val); -#else - val *= scale; - output[out_offset + x] = convert_char_rte(val); #endif + output[out_offset + x] = TO_OUTPUT_TYPE_RTE(val); } #endif } @@ -145,6 +153,6 @@ KERNEL(dynamic_quantize_gpu_ref)( #if ASYMMETRIC_QUANTIZATION && GROUP_SCALES_WITH_ZP output_scale[scale_idx + 1] = zp; #elif ASYMMETRIC_QUANTIZATION - output_zp[scale_idx] = zp; + output_zp[scale_idx] = convert_uchar_rte(zp); #endif } diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/dynamic_quantize/dynamic_quantize_kernel_opt.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/dynamic_quantize/dynamic_quantize_kernel_opt.cpp index 52a648679499f2..b4f667475f26f1 100644 --- 
a/src/plugins/intel_gpu/src/kernel_selector/kernels/dynamic_quantize/dynamic_quantize_kernel_opt.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/dynamic_quantize/dynamic_quantize_kernel_opt.cpp @@ -30,9 +30,11 @@ static std::pair get_input_bf_size(const dynamic_quantize_params static size_t get_match_vector_size(const dynamic_quantize_params& params) { auto block_sizes = { 8, 4, 2 }; + auto bf = get_input_bf_size(params); + auto f = bf.second; for (auto block_size : block_sizes) { - if (((params.inputs[0].X().v * params.inputs[0].Y().v) / simd) % block_size == 0) { + if ((f / simd) % block_size == 0) { return block_size; } } @@ -43,10 +45,13 @@ static size_t get_match_vector_size(const dynamic_quantize_params& params) { ParamsKey DynamicQuantizeKernelOpt::GetSupportedKey() const { ParamsKey k; k.EnableInputDataType(Datatype::F16); + k.EnableOutputDataType(Datatype::UINT8); k.EnableOutputDataType(Datatype::INT8); k.EnableDifferentTypes(); - k.EnableAllInputLayout(); - k.EnableAllOutputLayout(); + k.EnableInputLayout(DataLayout::bf); + k.EnableInputLayout(DataLayout::bfyx); + k.EnableOutputLayout(DataLayout::bf); + k.EnableOutputLayout(DataLayout::bfyx); k.EnableTensorOffset(); k.EnableTensorPitches(); k.EnableBatching(); @@ -68,6 +73,8 @@ JitConstants DynamicQuantizeKernelOpt::GetJitConstants(const dynamic_quantize_pa jit.AddConstant(MakeJitConstant("TOTAL_BLOCK_NUM", total_block_num)); jit.AddConstant(MakeJitConstant("ALIGNED_BLOCK_NUM", aligned_block_num)); jit.AddConstant(MakeJitConstant("BLOCK_NUM", block_num)); + jit.AddConstant(MakeJitConstant("QUANTIZE_GROUP_SIZE", params.group_sizes.back())); + jit.AddConstant(MakeJitConstant("ASYMMETRIC_QUANTIZATION", params.use_asymmetric_quantization)); jit.Merge(GetTensorFriendlyWorkGroupsJit(params.outputs[0])); return jit; @@ -76,15 +83,20 @@ JitConstants DynamicQuantizeKernelOpt::GetJitConstants(const dynamic_quantize_pa CommonDispatchData DynamicQuantizeKernelOpt::SetDefault(const dynamic_quantize_params& 
params) const { CommonDispatchData dispatchData; - auto vec_size = get_match_vector_size(params); - auto bf_size = get_input_bf_size(params); - size_t total_block_num = bf_size.second / (simd * vec_size); - size_t batch = get_input_bf_size(params).first; - size_t block_num = (total_block_num > 32) ? 32 : total_block_num; - - dispatchData.gws = {simd, block_num, batch}; - dispatchData.lws = {simd, block_num, 1}; - + if (params.group_sizes.back() <= 128) { + auto bf_size = get_input_bf_size(params); + dispatchData.gws = {bf_size.first, bf_size.second / params.group_sizes.back(), 1}; + dispatchData.lws = {1, 1, 1}; + } else { + auto vec_size = get_match_vector_size(params); + auto bf_size = get_input_bf_size(params); + size_t total_block_num = bf_size.second / (simd * vec_size); + size_t batch = get_input_bf_size(params).first; + size_t block_num = (total_block_num > 32) ? 32 : total_block_num; + + dispatchData.gws = {simd, block_num, batch}; + dispatchData.lws = {simd, block_num, 1}; + } return dispatchData; } @@ -147,8 +159,9 @@ bool DynamicQuantizeKernelOpt::Validate(const Params& params) const { const auto& dq_params = static_cast(params); - // Todo : Add proper exception here - if (((dq_params.inputs[0].X().v * dq_params.inputs[0].Y().v) % (simd * 2)) != 0) + + auto bf = get_input_bf_size(dq_params); + if (((bf.second) % (simd * 2)) != 0) return false; if (dq_params.inputs[0].GetPaddedVal() != 0 || dq_params.outputs[0].GetPaddedVal() != 0) @@ -157,8 +170,10 @@ bool DynamicQuantizeKernelOpt::Validate(const Params& params) const { if (dq_params.append_axis != -1) return false; - if (dq_params.group_sizes.back() != UINT64_MAX) - return false; + for (size_t i = 0; i < dq_params.group_sizes.size() - 1; i++) { + if (dq_params.group_sizes[i] != 1) + return false; + } // Allow only default scales order const auto& scales_output_order = dq_params.scales_output_order; @@ -168,7 +183,16 @@ bool DynamicQuantizeKernelOpt::Validate(const Params& params) const { return false; } 
+ if (dq_params.use_asymmetric_quantization) { + if (dq_params.combine_scales_and_zp) + return false; + if (dq_params.outputs[0].GetDType() != Datatype::UINT8) + return false; + } + return true; } + + } // namespace kernel_selector diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/dynamic_quantize/dynamic_quantize_kernel_ref.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/dynamic_quantize/dynamic_quantize_kernel_ref.cpp index bd3d0f87cdc931..f432fa6ac5756d 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/dynamic_quantize/dynamic_quantize_kernel_ref.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/dynamic_quantize/dynamic_quantize_kernel_ref.cpp @@ -11,6 +11,7 @@ ParamsKey DynamicQuantizeKernelRef::GetSupportedKey() const { ParamsKey k; k.EnableInputDataType(Datatype::F16); k.EnableOutputDataType(Datatype::INT8); + k.EnableOutputDataType(Datatype::UINT8); k.EnableInputLayout(DataLayout::bfyx); k.EnableOutputLayout(DataLayout::bfyx); k.EnableTensorOffset(); @@ -53,6 +54,7 @@ JitConstants DynamicQuantizeKernelRef::GetJitConstants(const dynamic_quantize_pa jit.AddConstant(MakeJitConstant("ASYMMETRIC_QUANTIZATION", params.use_asymmetric_quantization)); jit.AddConstant(MakeJitConstant("GROUP_SCALES_WITH_ZP", params.combine_scales_and_zp)); + jit.AddConstant(MakeJitConstant("UNSIGNED_OUTPUT", params.outputs[0].GetDType() == Datatype::UINT8 ? 1 : 0)); auto group_sizes = params.group_sizes; group_sizes.resize(std::min((size_t)4, group_sizes.size()), 1); @@ -71,12 +73,26 @@ CommonDispatchData DynamicQuantizeKernelRef::SetDefault(const dynamic_quantize_p OPENVINO_ASSERT(params.outputs[0].GetLayout() == DataLayout::bfyx, "It supports only 4d tensor"); auto group_sizes = params.group_sizes; - group_sizes.resize(std::min((size_t)4, group_sizes.size()), 1); + group_sizes.resize(std::max((size_t)4, group_sizes.size()), 1); auto batch_size = group_sizes[0] == 1 ? params.outputs[0].Batch().v : 1; auto feature_size = group_sizes[1] == 1 ? 
params.outputs[0].Feature().v : 1; auto y_size = group_sizes[2] == 1 ? params.outputs[0].Y().v : 1; auto x_size = group_sizes[3] == 1 ? params.outputs[0].X().v : 1; + OPENVINO_ASSERT( + (group_sizes[0] == 1 || group_sizes[0] == params.outputs[0].Batch().v || group_sizes[0] == UINT64_MAX) && + (group_sizes[1] == 1 || group_sizes[1] == params.outputs[0].Feature().v || group_sizes[1] == UINT64_MAX) && + (group_sizes[2] == 1 || group_sizes[2] == params.outputs[0].Y().v || group_sizes[2] == UINT64_MAX + || (params.outputs[0].Y().v % group_sizes[2] == 0 && params.outputs[0].X().v == 1)) && // Grouped quantization is only supported for 3d case + (group_sizes[3] == 1 || group_sizes[3] == params.outputs[0].X().v || group_sizes[3] == UINT64_MAX), + "[GPU] Unsupported dynamic quantization configuration: (", + group_sizes[0], ",", group_sizes[1], ",", group_sizes[2], ",", group_sizes[3], ") - (", + params.outputs[0].Batch().v, ",", params.outputs[0].Feature().v, ",", params.outputs[0].Y().v, ",", params.outputs[0].X().v, ")"); + + // Grouped quantization is supported only over y axis + if (params.group_sizes[2] > 1 && params.group_sizes[2] != UINT64_MAX) + y_size = params.outputs[0].Y().v / params.group_sizes[2]; + dispatchData.gws = {batch_size * feature_size, y_size, x_size}; dispatchData.lws = {1, 1, 1}; diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp index 46e8f7f1104f0d..68da7aea7b1fe6 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp @@ -124,16 +124,16 @@ static bool should_dynamic_quantize(const fully_connected_params& params, bool p if ((scale_group_size % simd == 0) && (input_f % dynamic_quantization_group_size == 0) && 
(params.is_shape_agnostic || (params.inputs[0].Batch().v > 1 && input_b > min_slm_size)) && params.inputs[0].GetDType() == Datatype::F16 && is_weight_dyn_quantizable(params)) { - if (print_log) { - GPU_DEBUG_TRACE_DETAIL << " Dynamic quantizing for FC : scale_group_size: " << scale_group_size << - ", Dyn-quan group size: " << dynamic_quantization_group_size << - ", Type(I:" << kernel_selector::toString(params.inputs[0].GetDType()) << - ", O:" << kernel_selector::toString(params.outputs[0].GetDType()) << - ", W:" << kernel_selector::toString(params.weights.GetDType()) << - "), Format(W:" << kernel_selector::toString(params.weights.GetLayout()) << - ") B: " << params.inputs[0].Batch().v << ", F: " << params.inputs[0].Feature().v << - ", Y: " << params.inputs[0].Y().v << std ::endl; - } + if (print_log) { + GPU_DEBUG_TRACE_DETAIL << " Dynamic quantizing for FC : scale_group_size: " << scale_group_size << + ", Dyn-quan group size: " << dynamic_quantization_group_size << + ", Type(I:" << kernel_selector::toString(params.inputs[0].GetDType()) << + ", O:" << kernel_selector::toString(params.outputs[0].GetDType()) << + ", W:" << kernel_selector::toString(params.weights.GetDType()) << + "), Format(W:" << kernel_selector::toString(params.weights.GetLayout()) << + ") B: " << params.inputs[0].Batch().v << ", F: " << params.inputs[0].Feature().v << + ", Y: " << params.inputs[0].Y().v << std ::endl; + } return true; } diff --git a/src/plugins/intel_gpu/src/plugin/ops/dynamic_quantize.cpp b/src/plugins/intel_gpu/src/plugin/ops/dynamic_quantize.cpp index 85f28cbd711678..4c11bdb21971e9 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/dynamic_quantize.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/dynamic_quantize.cpp @@ -18,7 +18,8 @@ static void CreateDynamicQuantizeOp(ProgramBuilder& p, const std::shared_ptrget_attrs()); + op->get_attrs(), + op->get_input_partial_shape(0).size()); prim.num_outputs = op->get_output_size(); diff --git 
a/src/plugins/intel_gpu/src/plugin/ops/fully_connected.cpp b/src/plugins/intel_gpu/src/plugin/ops/fully_connected.cpp index 7b0aa921ef3ad5..5f4fe19c5c4c08 100644 --- a/src/plugins/intel_gpu/src/plugin/ops/fully_connected.cpp +++ b/src/plugins/intel_gpu/src/plugin/ops/fully_connected.cpp @@ -26,7 +26,7 @@ namespace ov { namespace intel_gpu { static void CreateFullyConnectedCompressedOp(ProgramBuilder& p, const std::shared_ptr& op) { - validate_inputs_count(op, {4, 5, 6}); + validate_inputs_count(op, {4, 5, 6, 7}); auto inputs = p.GetInputInfo(op); std::string primitive_name = layer_type_name_ID(op); auto supports_immad = p.get_engine().get_device_info().supports_immad; @@ -39,6 +39,7 @@ static void CreateFullyConnectedCompressedOp(ProgramBuilder& p, const std::share const size_t W_ZP_IDX = input_idx; std::string zp_name = op->get_input_size() > input_idx ? inputs[input_idx++].pid : ""; auto activation_scale_input = op->get_input_size() > input_idx ? inputs[input_idx++] : cldnn::input_info(); + auto activation_zero_point_input = op->get_input_size() > input_idx ? inputs[input_idx++] : cldnn::input_info(); float zp_value = 0.0f; bool has_scalar_zp = false; @@ -58,6 +59,7 @@ static void CreateFullyConnectedCompressedOp(ProgramBuilder& p, const std::share scale_name, has_scalar_zp && !supports_immad ? 
"" : zp_name, activation_scale_input, + activation_zero_point_input, cldnn::element_type_to_data_type(op->get_output_element_type(0)), op->get_input_partial_shape(0).size(), op->get_input_partial_shape(1).size()); diff --git a/src/plugins/intel_gpu/src/plugin/program_builder.cpp b/src/plugins/intel_gpu/src/plugin/program_builder.cpp index b623c86fabe02c..368e25abe2ddac 100644 --- a/src/plugins/intel_gpu/src/plugin/program_builder.cpp +++ b/src/plugins/intel_gpu/src/plugin/program_builder.cpp @@ -10,6 +10,7 @@ #include "openvino/op/lstm_sequence.hpp" #include "openvino/op/loop.hpp" #include "openvino/op/search_sorted.hpp" +#include "ov_ops/dynamic_quantize.hpp" #include "intel_gpu/plugin/common_utils.hpp" #include "intel_gpu/plugin/program_builder.hpp" @@ -357,6 +358,9 @@ bool ProgramBuilder::requires_new_shape_infer(const std::shared_ptr& o if (ov::is_type(op)) return true; + if (ov::is_type(op)) + return true; + if (ov::is_type(op)) { const auto body_function = std::static_pointer_cast(op)->get_function(); if (body_function->is_dynamic()) diff --git a/src/plugins/intel_gpu/src/plugin/transformations/dynamic_quantize_fully_connected.cpp b/src/plugins/intel_gpu/src/plugin/transformations/dynamic_quantize_fully_connected.cpp index c36212713ae717..61dc40e2713800 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/dynamic_quantize_fully_connected.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/dynamic_quantize_fully_connected.cpp @@ -21,24 +21,11 @@ DynamicQuantizeFullyConnected::DynamicQuantizeFullyConnected(uint64_t group_size : ov::pass::MatcherPass() { GPU_DEBUG_GET_INSTANCE(debug_config); using namespace ov::pass::pattern; - - // per-token quantization is supported - if (group_size != UINT64_MAX) { - GPU_DEBUG_TRACE << "Dynamic quantization is disabled " << group_size << std::endl; - return; - } - auto is_dynamic = [](const ov::Output& output) -> bool { - bool is_dynamic = output.get_node_shared_ptr()->get_output_partial_shape(0).is_dynamic(); 
- size_t num_inputs = output.get_node_shared_ptr()->get_input_size(); - for (size_t idx = 0; idx < num_inputs; idx++) { - is_dynamic |= output.get_node_shared_ptr()->get_input_partial_shape(idx).is_dynamic(); - } - return is_dynamic; - }; + using QuantizationType = ov::op::internal::DynamicQuantize::QuantizationType; auto data = any_input(); - auto fully_connected_compressed3 = wrap_type({data, any_input(), any_input(), any_input()}, is_dynamic); - auto fully_connected_compressed4 = wrap_type({data, any_input(), any_input(), any_input(), any_input()}, is_dynamic); + auto fully_connected_compressed3 = wrap_type({data, any_input(), any_input(), any_input()}); + auto fully_connected_compressed4 = wrap_type({data, any_input(), any_input(), any_input(), any_input()}); auto fully_connected_compressed = std::make_shared(OutputVector{fully_connected_compressed3, fully_connected_compressed4}); @@ -65,12 +52,20 @@ DynamicQuantizeFullyConnected::DynamicQuantizeFullyConnected(uint64_t group_size ov::op::internal::DynamicQuantize::Attributes config; config.quantization_dt = element::i8; - config.quantization_type = ov::op::internal::DynamicQuantize::QuantizationType::Symmetric; + config.quantization_type = QuantizationType::Symmetric; config.scale_dt = element::f16; config.group_sizes = shape_group_size; + if (debug_config->dynamic_quantize_asym) { + config.quantization_type = QuantizationType::Asymmetric; + config.quantization_dt = element::u8; + config.zp_dt = element::u8; // it supports u8 only now + } + auto dyn_quan = std::make_shared(m_data, config); auto optional_w_zp = m_fc->get_input_size() > 4 ? m_fc->get_input_node_shared_ptr(4) : std::make_shared(); + auto optional_a_zp = config.quantization_type == QuantizationType::Symmetric ? 
+ std::make_shared() : dyn_quan->output(2); auto output_type = m_fc->get_output_type(); if (output_type == ov::element::undefined) @@ -82,6 +77,7 @@ DynamicQuantizeFullyConnected::DynamicQuantizeFullyConnected(uint64_t group_size m_fc->get_input_node_shared_ptr(3), optional_w_zp, dyn_quan->output(1), + optional_a_zp, output_type); ov::replace_node(m_fc, new_fc); diff --git a/src/plugins/intel_gpu/src/plugin/transformations/op/fully_connected_compressed.cpp b/src/plugins/intel_gpu/src/plugin/transformations/op/fully_connected_compressed.cpp index 2e3819d7e850ee..dd5c555b1e6bc8 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations/op/fully_connected_compressed.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations/op/fully_connected_compressed.cpp @@ -14,11 +14,13 @@ FullyConnectedCompressed::FullyConnectedCompressed(const ov::Output& A, const ov::Output& w_decompression_scale, const ov::Output& w_decompression_zero_point, const ov::Output& a_decompression_scale, + const ov::Output& a_decompression_zero_point, const ov::element::Type output_type) : FullyConnected(A, B, bias, output_type) { set_argument(3, w_decompression_scale); set_argument(4, w_decompression_zero_point); set_argument(5, a_decompression_scale); + set_argument(6, a_decompression_zero_point); validate_and_infer_types(); } @@ -60,12 +62,13 @@ std::shared_ptr FullyConnectedCompressed::clone_with_new_inputs(const new_args.at(3), new_args.at(4), m_output_type); - else if (new_args.size() == 6) + else if (new_args.size() == 7) return std::make_shared(new_args.at(0), new_args.at(1), new_args.at(2), new_args.at(3), new_args.at(4), + new_args.at(5), new_args.at(6), m_output_type); else diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index e47ccbb09a9c43..50eecf51b945b7 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ 
b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -975,18 +975,34 @@ void TransformationsPipeline::apply(std::shared_ptr func) { // This Validate is needed for proper data type propagation after applying IncreasePositionIdsPrecision pass manager.register_pass(); - auto dynamic_quantization_group_size = config.get_property(ov::hint::dynamic_quantization_group_size); if (device_info.supports_immad) { + auto dynamic_quantization_group_size = config.get_property(ov::hint::dynamic_quantization_group_size); pass_config->set_callback([=](const_node_ptr& root) -> bool { if (root->get_input_node_shared_ptr(0)->get_element_type() == ov::element::Type_t::f32) { - GPU_DEBUG_TRACE << root->get_friendly_name() << " Dynamic quantization is turned off because input type is not supported" << std::endl; + GPU_DEBUG_TRACE << root->get_friendly_name() << " dyn_quan is turned off: input type is not supported" << std::endl; return true; } auto weight_shape = root->get_input_partial_shape(1); const size_t innermost_size = weight_shape[weight_shape.size() - 1].get_length(); if (innermost_size < 32) { - GPU_DEBUG_TRACE << "Dynamic quantization: shape is too small " << innermost_size << " / " << dynamic_quantization_group_size << std::endl; + GPU_DEBUG_TRACE << root->get_friendly_name() << " dyn_quan is turned off: shape is too small - " << innermost_size << std::endl; + return true; + } + + // AZP does not support 8bit weight + if (debug_config->dynamic_quantize_asym + && (root->get_input_element_type(1) == ov::element::i8 || root->get_input_element_type(1) == ov::element::u8)) { + GPU_DEBUG_TRACE << root->get_friendly_name() << " dyn_quan is turned off: asym quantization does not support 8bit weight" << std::endl; + return true; + } + + bool has_wzp = root->get_input_size() > 4; + if ((root->get_input_element_type(1) == ov::element::i8 || root->get_input_element_type(1) == ov::element::u8) + && has_wzp + && dynamic_quantization_group_size != UINT64_MAX) { + 
GPU_DEBUG_TRACE << root->get_friendly_name() << " dyn_quan is turned off:" + " asym 8bit weight does not support grouped quantization" << std::endl; return true; } return false; diff --git a/src/plugins/intel_gpu/src/runtime/debug_configuration.cpp b/src/plugins/intel_gpu/src/runtime/debug_configuration.cpp index 65ca31f16c720c..380480dccc68bf 100644 --- a/src/plugins/intel_gpu/src/runtime/debug_configuration.cpp +++ b/src/plugins/intel_gpu/src/runtime/debug_configuration.cpp @@ -190,6 +190,7 @@ static void print_help_messages() { "separated by space. Support case-insensitive and regular expression. For example .*fully_connected.*"); message_list.emplace_back("OV_GPU_DynamicQuantizeGroupSize", "Specify a group size of dynamic quantization to enable " "dynamic quantization for Fully-connected primitive."); + message_list.emplace_back("OV_GPU_DynamicQuantizeAsym", "Enable asymmetric dynamic quantization when set as 1."); message_list.emplace_back("OV_GPU_DisableHorizontalFCFusion", "Disable horizontal fc fusion"); message_list.emplace_back("OV_GPU_DisableFCSwigluFusion", "Disable fc + swiglu fusion"); message_list.emplace_back("OV_GPU_DumpIteration", "Dump n-th execution of network, separated by space."); @@ -260,6 +261,7 @@ debug_configuration::debug_configuration() , use_usm_host(0) , use_kv_cache_compression(-1) , dynamic_quantize_group_size(DYNAMIC_QUANTIZE_GROUP_SIZE_NOT_SET) + , dynamic_quantize_asym(0) , disable_horizontal_fc_fusion(0) , disable_fc_swiglu_fusion(0) { #ifdef GPU_DEBUG_CONFIG @@ -315,6 +317,7 @@ debug_configuration::debug_configuration() get_gpu_debug_env_var("UseUsmHost", use_usm_host); get_gpu_debug_env_var("KVCacheCompression", use_kv_cache_compression); get_gpu_debug_env_var("DynamicQuantizeGroupSize", dynamic_quantize_group_size); + get_gpu_debug_env_var("DynamicQuantizeAsym", dynamic_quantize_asym); get_gpu_debug_env_var("DisableHorizontalFCFusion", disable_horizontal_fc_fusion); get_gpu_debug_env_var("DisableFCSwigluFusion", 
disable_fc_swiglu_fusion); std::string dump_iteration_str; diff --git a/src/plugins/intel_gpu/src/runtime/execution_config.cpp b/src/plugins/intel_gpu/src/runtime/execution_config.cpp index 30a9477e1600dd..804ad81f2d3735 100644 --- a/src/plugins/intel_gpu/src/runtime/execution_config.cpp +++ b/src/plugins/intel_gpu/src/runtime/execution_config.cpp @@ -57,7 +57,7 @@ void ExecutionConfig::set_default() { std::make_tuple(ov::internal::query_model_ratio, 1.0f), std::make_tuple(ov::cache_mode, ov::CacheMode::OPTIMIZE_SPEED), std::make_tuple(ov::cache_encryption_callbacks, EncryptionCallbacks{}), - std::make_tuple(ov::hint::dynamic_quantization_group_size, 32), + std::make_tuple(ov::hint::dynamic_quantization_group_size, 0), std::make_tuple(ov::hint::kv_cache_precision, ov::element::undefined), std::make_tuple(ov::intel_gpu::hint::enable_kernels_reuse, false), std::make_tuple(ov::weights_path, ""), @@ -254,6 +254,11 @@ void ExecutionConfig::apply_user_properties(const cldnn::device_info& info) { set_property(ov::hint::kv_cache_precision(ov::element::i8)); } + // Enable dynamic quantization by default for non-systolic platforms + if (!is_set_by_user(ov::hint::dynamic_quantization_group_size) && !info.supports_immad) { + set_property(ov::hint::dynamic_quantization_group_size(32)); + } + user_properties.clear(); } diff --git a/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/matmul_weights_decompression.cpp b/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/matmul_weights_decompression.cpp index 27c57aa072878d..b430884decb71a 100644 --- a/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/matmul_weights_decompression.cpp +++ b/src/plugins/intel_gpu/tests/functional/subgraph_tests/dynamic/matmul_weights_decompression.cpp @@ -58,7 +58,8 @@ using MatmulWeightsDecompressionParams = std::tuple; class MatmulWeightsDecompression : public testing::WithParamInterface, @@ -74,6 +75,7 @@ class MatmulWeightsDecompression : public 
testing::WithParamInterface(dyn_input_ps.size(), 1); - group_sizes.back() = UINT64_MAX; + group_sizes.back() = group_size; - auto input_data = rg.generate_random_1d(ov::shape_size(data_shape), -16.0f, 16.0f); + auto input_data = rg.generate_random_1d(ov::shape_size(data_shape), -16.0f, 20.0f); set_values(input_mem, input_data); auto in_layout_f32 = input_shape.is_dynamic() ? layout{ dyn_input_ps, data_types::f32, format::bfyx } @@ -53,17 +58,15 @@ class dynamic_quantization_gpu_tests: public ::testing::Test { dynamic_quantize::Attributes dq_config; dq_config.quantization_type = quantization_type; - dq_config.quantization_dt = data_types::i8; + dq_config.quantization_dt = quant_dt; dq_config.scale_dt = data_types::f16; - dq_config.zp_dt = data_types::undefined; + dq_config.zp_dt = zp_dt; dq_config.group_sizes = group_sizes; - dq_config.scales_zp_output_order = { 0, 1, 2, 3 }; - dq_config.output_storage_type = ov::op::internal::DynamicQuantize::OutputStorageType::Planar; + dq_config.scales_zp_output_order = { 0, 1, 2}; - if (quantization_type == QuantizationType::Asymmetric) { - dq_config.zp_dt = data_types::f16; - dq_config.output_storage_type = ov::op::internal::DynamicQuantize::OutputStorageType::InterleavedScalesZP; - } + if (data_shape.size() == 4) + dq_config.scales_zp_output_order.emplace_back(3); + dq_config.output_storage_type = storage_type; auto reorder_1 = reorder("reorder_1", input_info("input"), layout{ input_ps, data_types::f16, format::bfyx }); auto dyn_quan_prim = dynamic_quantize("dyn_quan_prim", input_info("reorder_1"), dq_config); @@ -156,6 +159,19 @@ TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_single_batch) { this->test_dynamic_quantization(false, {-1, 1, 1, 4096}, {1, 1, 1, 4096}); } +TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_asym_act) { + this->test_dynamic_quantization(false, {-1, 1, 1, 4096}, {1, 1, 1, 4096}, QuantizationType::Asymmetric, UINT64_MAX, + data_types::u8, data_types::u8, OutputStorageType::Planar); +} 
+ +TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_small_size_grouped) { + this->test_dynamic_quantization(false, {1, 1, 4096}, {64, 1, 4096}, QuantizationType::Symmetric, 32); +} + +TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_single_batch_grouped) { + this->test_dynamic_quantization(false, {-1, 1, 4096}, {1, 1, 4096}, QuantizationType::Symmetric, 32); +} + TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_ref_only) { this->test_dynamic_quantization(false, {-1, 1, 1, 33}, {16, 1, 1, 33}); } @@ -177,33 +193,36 @@ TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_unaligned_dynamic) { } TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache) { - this->test_dynamic_quantization(false, {-1, 8, -1, 96}, {1, 8, 1, 96}, QuantizationType::Symmetric, "dynamic_quantize_gpu_kv_cache"); + this->test_dynamic_quantization(false, {-1, 8, -1, 96}, {1, 8, 1, 96}, QuantizationType::Symmetric, UINT64_MAX, + data_types::i8, data_types::undefined, OutputStorageType::Planar, "dynamic_quantize_gpu_kv_cache"); } TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_batched) { - this->test_dynamic_quantization(false, {-1, 4, -1, 64}, {1, 4, 35, 64}, QuantizationType::Symmetric, "dynamic_quantize_gpu_kv_cache"); + this->test_dynamic_quantization(false, {-1, 4, -1, 64}, {1, 4, 35, 64}, QuantizationType::Symmetric, UINT64_MAX, + data_types::i8, data_types::undefined, OutputStorageType::Planar, "dynamic_quantize_gpu_kv_cache"); } TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_reordered) { - this->test_dynamic_quantization(false, {-1, -1, 8, 96}, {1, 1, 8, 96}, QuantizationType::Symmetric, "dynamic_quantize_gpu_kv_cache"); + this->test_dynamic_quantization(false, {-1, -1, 8, 96}, {1, 1, 8, 96}, QuantizationType::Symmetric, UINT64_MAX, + data_types::i8, data_types::undefined, OutputStorageType::Planar, "dynamic_quantize_gpu_kv_cache"); } TEST_F(dynamic_quantization_gpu_tests, 
simple_quantizing_kv_cache_batched_reordered) { - this->test_dynamic_quantization(false, {-1, -1, 4, 64}, {1, 35, 4, 64}, QuantizationType::Symmetric, "dynamic_quantize_gpu_kv_cache"); + this->test_dynamic_quantization(false, {-1, -1, 4, 64}, {1, 35, 4, 64}, QuantizationType::Symmetric, UINT64_MAX, + data_types::i8, data_types::undefined, OutputStorageType::Planar, "dynamic_quantize_gpu_kv_cache"); } TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_asym) { - this->test_dynamic_quantization(false, {-1, 8, -1, 96}, {1, 8, 1, 96}, QuantizationType::Asymmetric, "dynamic_quantize_gpu_kv_cache"); + this->test_dynamic_quantization(false, {-1, 8, -1, 96}, {1, 8, 1, 96}, QuantizationType::Asymmetric, UINT64_MAX, + data_types::i8, data_types::f16, OutputStorageType::InterleavedScalesZP, "dynamic_quantize_gpu_kv_cache"); } TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_batched_asym) { - this->test_dynamic_quantization(false, {-1, 4, -1, 64}, {1, 4, 35, 64}, QuantizationType::Asymmetric, "dynamic_quantize_gpu_kv_cache"); + this->test_dynamic_quantization(false, {-1, 4, -1, 64}, {1, 4, 35, 64}, QuantizationType::Asymmetric, UINT64_MAX, + data_types::i8, data_types::f16, OutputStorageType::InterleavedScalesZP, "dynamic_quantize_gpu_kv_cache"); } TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_reordered_asym) { - this->test_dynamic_quantization(false, {-1, -1, 8, 96}, {1, 1, 8, 96}, QuantizationType::Asymmetric, "dynamic_quantize_gpu_kv_cache"); -} - -TEST_F(dynamic_quantization_gpu_tests, simple_quantizing_kv_cache_batched_reordered_asym) { - this->test_dynamic_quantization(false, {-1, -1, 4, 64}, {1, 35, 4, 64}, QuantizationType::Asymmetric, "dynamic_quantize_gpu_kv_cache"); + this->test_dynamic_quantization(false, {-1, -1, 8, 96}, {1, 1, 8, 96}, QuantizationType::Asymmetric, UINT64_MAX, + data_types::i8, data_types::f16, OutputStorageType::InterleavedScalesZP, "dynamic_quantize_gpu_kv_cache"); } diff --git 
a/src/plugins/intel_gpu/tests/unit/test_cases/fully_connected_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/fully_connected_gpu_test.cpp index 6bf44a31add0f4..f59dc5c42cffc1 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/fully_connected_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/fully_connected_gpu_test.cpp @@ -1555,7 +1555,7 @@ class fully_connected_gpu_tests: public ::testing::Test { auto config = get_test_default_config(engine); config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); config.set_property(ov::intel_gpu::optimize_data(true)); - config.set_property(ov::hint::dynamic_quantization_group_size(32)); + config.set_user_property(ov::hint::dynamic_quantization_group_size(32)); network::ptr network = get_network(engine, topology, config, get_test_stream_ptr(), is_caching_test); @@ -1643,7 +1643,7 @@ class fully_connected_gpu_tests: public ::testing::Test { config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); ov::intel_gpu::ImplementationDesc fc_impl_desc = { format::bfyx, "fully_connected_gpu_bfyx_ref", impl_types::ocl }; config.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ {"fc_prim", fc_impl_desc} })); - config.set_property(ov::hint::dynamic_quantization_group_size(0)); + config.set_user_property(ov::hint::dynamic_quantization_group_size(0)); network network(engine, topology, config); network.set_input_data("input", input_mem); @@ -1669,7 +1669,7 @@ class fully_connected_gpu_tests: public ::testing::Test { auto config = get_test_default_config(engine); config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); config.set_property(ov::intel_gpu::optimize_data(true)); - config.set_property(ov::hint::dynamic_quantization_group_size(0)); + config.set_user_property(ov::hint::dynamic_quantization_group_size(0)); network::ptr network = get_network(engine, topology, config, get_test_stream_ptr(), is_caching_test); @@ -1753,7 +1753,7 @@ class 
fully_connected_gpu_tests: public ::testing::Test { config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); ov::intel_gpu::ImplementationDesc fc_impl_desc = { format::bfyx, "fully_connected_gpu_bfyx_ref", impl_types::ocl }; config.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ {"fc_prim", fc_impl_desc} })); - config.set_property(ov::hint::dynamic_quantization_group_size(0)); + config.set_user_property(ov::hint::dynamic_quantization_group_size(0)); network network(engine, topology, config); network.set_input_data("input", input_mem); @@ -1780,9 +1780,9 @@ class fully_connected_gpu_tests: public ::testing::Test { config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); config.set_property(ov::intel_gpu::optimize_data(true)); if (is_dyn_quan) { - config.set_property(ov::hint::dynamic_quantization_group_size(32)); + config.set_user_property(ov::hint::dynamic_quantization_group_size(32)); } else { - config.set_property(ov::hint::dynamic_quantization_group_size(0)); + config.set_user_property(ov::hint::dynamic_quantization_group_size(0)); } network::ptr network = get_network(engine, topology, config, get_test_stream_ptr(), is_caching_test); @@ -1923,7 +1923,7 @@ class fully_connected_gpu_tests: public ::testing::Test { config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); ov::intel_gpu::ImplementationDesc fc_impl = { in_layout.format, "", impl_types::ocl }; config.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ { "fc_prim1", fc_impl }, { "fc_prim2", fc_impl } })); - config.set_property(ov::hint::dynamic_quantization_group_size(0)); + config.set_user_property(ov::hint::dynamic_quantization_group_size(0)); network network(engine, topology, config); network.set_input_data("input", input_mem); @@ -1952,7 +1952,7 @@ class fully_connected_gpu_tests: public ::testing::Test { auto config = get_test_default_config(engine); config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); 
config.set_property(ov::intel_gpu::optimize_data(true)); - config.set_property(ov::hint::dynamic_quantization_group_size(0)); + config.set_user_property(ov::hint::dynamic_quantization_group_size(0)); network::ptr network = get_network(engine, topology, config, get_test_stream_ptr(), is_caching_test); @@ -2905,7 +2905,7 @@ class fully_connected_gpu_tests: public ::testing::Test { config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); ov::intel_gpu::ImplementationDesc fc_impl_desc = { format::bfyx, "fully_connected_gpu_bfyx_ref", impl_types::ocl }; config.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ {"fc_prim", fc_impl_desc} })); - config.set_property(ov::hint::dynamic_quantization_group_size(0)); + config.set_user_property(ov::hint::dynamic_quantization_group_size(0)); network network(engine, topo, config); network.set_input_data("input", input_mem); @@ -2931,7 +2931,7 @@ class fully_connected_gpu_tests: public ::testing::Test { auto config = get_test_default_config(engine); config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); config.set_property(ov::intel_gpu::optimize_data(true)); - config.set_property(ov::hint::dynamic_quantization_group_size(quantize_group_size)); + config.set_user_property(ov::hint::dynamic_quantization_group_size(quantize_group_size)); network::ptr network = get_network(engine, topology, config, get_test_stream_ptr(), false); @@ -3031,7 +3031,7 @@ class fully_connected_gpu_tests: public ::testing::Test { config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); ov::intel_gpu::ImplementationDesc fc_impl_desc = { format::bfyx, "fully_connected_gpu_bf_tiled", impl_types::ocl }; config.set_property(ov::intel_gpu::force_implementations(ov::intel_gpu::ImplForcingMap{ {"fc_prim", fc_impl_desc} })); - config.set_property(ov::hint::dynamic_quantization_group_size(0)); + config.set_user_property(ov::hint::dynamic_quantization_group_size(0)); network network(engine, topo, config); 
network.set_input_data("input", input_mem); @@ -3057,7 +3057,7 @@ class fully_connected_gpu_tests: public ::testing::Test { auto config = get_test_default_config(engine); config.set_property(ov::intel_gpu::allow_new_shape_infer(true)); config.set_property(ov::intel_gpu::optimize_data(true)); - config.set_property(ov::hint::dynamic_quantization_group_size(quantize_group_size)); + config.set_user_property(ov::hint::dynamic_quantization_group_size(quantize_group_size)); network::ptr network = get_network(engine, topology, config, get_test_stream_ptr(), false); diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/hash_key_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/hash_key_gpu_test.cpp index fb30222998008b..3384fb1ed514f6 100644 --- a/src/plugins/intel_gpu/tests/unit/test_cases/hash_key_gpu_test.cpp +++ b/src/plugins/intel_gpu/tests/unit/test_cases/hash_key_gpu_test.cpp @@ -71,11 +71,11 @@ class check_hash_value: public ::testing::Test { const auto primitive_hash = primitve->hash(); const auto params_hash = primitve->type->get_fake_aligned_params(*prim_inst->get_impl_params()).hash(); if (!engine.get_device_info().supports_immad) { - ASSERT_EQ(primitive_hash, 8017451717095756666UL); - ASSERT_EQ(params_hash, 8889154389021912103UL); + ASSERT_EQ(primitive_hash, 9510988594087947885UL); + ASSERT_EQ(params_hash, 7833603199176871790UL); } else { - ASSERT_EQ(primitive_hash, 8017451717095756666UL); - ASSERT_EQ(params_hash, 10847775446937354749UL); + ASSERT_EQ(primitive_hash, 9510988594087947885UL); + ASSERT_EQ(params_hash, 16259702189938020305UL); } } From a3f4edb3d8f12769c7ae7d39206730502fae711f Mon Sep 17 00:00:00 2001 From: Taylor Yeonbok Lee Date: Mon, 9 Dec 2024 14:47:37 +0900 Subject: [PATCH 2/8] [GPU] Fix crash on swiglu fused case (due to outer_ofm == 1) (#27972) ### Details: - fixed crash happens in minicpm-1b-sft int4 model ### Tickets: - *ticket-id* --- .../fully_connected_kernel_bf_tiled.cpp | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 
deletions(-) diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp index 68da7aea7b1fe6..d0f881adcd88b1 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp +++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp @@ -435,10 +435,14 @@ FullyConnected_bf_tiled::GetAutoTuneParams(const fully_connected_params& params, return selector.Default(tune_params(1, 1, 4, 4, 1, 1, 1, EXE_MODE_DEFAULT)); } } else if (is_weight_small_kn(params, output_f)) { - if (params.weights.GetLayout() == WeightsLayout::os_is_yx_osv32_isv2) - return selector.Default(tune_params(1, 1, 4, 2, 1, 1, 1, EXE_MODE_DEFAULT)); - else + if (params.weights.GetLayout() == WeightsLayout::os_is_yx_osv32_isv2) { + if (swiglu_fused) + return selector.Default(tune_params(1, 1, 4, 2, 2, 1, 1, EXE_MODE_DEFAULT)); + else + return selector.Default(tune_params(1, 1, 4, 2, 1, 1, 1, EXE_MODE_DEFAULT)); + } else { return selector.Default(tune_params(1, 2, 4, 2, 1, 1, 1, EXE_MODE_DEFAULT)); + } } else { if (params.weights.GetLayout() == WeightsLayout::os_iyx_osv16) { return selector.Default(tune_params(1, 1, 4, 4, 1, 1, 1, EXE_MODE_DEFAULT)); @@ -865,7 +869,9 @@ KernelsData FullyConnected_bf_tiled::GetTunedKernelsDataByIndex(const Params &pa auto output_f = get_output_aligned_bf_size(fc_params, false).second; WeightsLayout weights_layout = WeightsLayout::os_iyx_osv16; - if (!is_swiglu_fused(fc_params) && fc_params.compressed && fc_params.inputs[0].GetDType() == Datatype::F16 + if (is_swiglu_fused(fc_params)) { + weights_layout = WeightsLayout::os_is_yx_osv32_isv2; + } else if (fc_params.compressed && fc_params.inputs[0].GetDType() == Datatype::F16 && (fc_params.weights.GetLayout() == WeightsLayout::oiyx || fc_params.weights.GetLayout() == 
WeightsLayout::os_is_yx_osv64_isv2) && (fc_params.weights.GetDType() == WeightsType::INT4 || fc_params.weights.GetDType() == WeightsType::UINT4) && is_weight_horizontal(fc_params, output_f)) { From 27138a8af6b9cd8e79b394ab5b56b4c61fd7deba Mon Sep 17 00:00:00 2001 From: Sebastian Golebiewski Date: Mon, 9 Dec 2024 07:40:37 +0100 Subject: [PATCH 3/8] [DOCS] saveModelSync method in Node.js addon (#27960) Porting: #27958 Signed-off-by: sgolebiewski-intel --- docs/sphinx_setup/api/nodejs_api/addon.rst | 37 ++++++++++++++++++++++ 1 file changed, 37 insertions(+) diff --git a/docs/sphinx_setup/api/nodejs_api/addon.rst b/docs/sphinx_setup/api/nodejs_api/addon.rst index f6ee4ab7b15836..7c42824bcd88a3 100644 --- a/docs/sphinx_setup/api/nodejs_api/addon.rst +++ b/docs/sphinx_setup/api/nodejs_api/addon.rst @@ -49,6 +49,7 @@ The **openvino-node** package exports ``addon`` which contains the following pro resizeAlgorithm: typeof resizeAlgorithm; PrePostProcessor: PrePostProcessorConstructor; }; + saveModelSync(model: Model, path: string, compressToFp16?: boolean): void; element: typeof element; } @@ -142,3 +143,39 @@ Properties - **Defined in:** `addon.ts:674 `__ + +.. rubric:: saveModelSync + +* + + .. code-block:: ts + + saveModelSync(model: Model, path: string, compressToFp16?: boolean): void; + + + This method saves a model to IR (xml and bin files), applying all + necessary transformations that are usually added during model conversion. + Particularly, weights are compressed to FP16 by default, and debug information + in model nodes is cleaned up. + + * **Parameters:** + + - model: :doc:`Model ` + + A model which will be converted to IR and saved. + + - path: string + + A path for saving the model. + + - ``Optional`` + + - compressToFp16: boolean + + Compression of weights to FP16 floating point precision. The default value is `true` . 
+ + * **Returns:** void + + * **Defined in:** + `addon.ts:692 `__ + From 15a9b617fcfd591a14daf632cdeecbe99255bd64 Mon Sep 17 00:00:00 2001 From: Roman Kazantsev Date: Mon, 9 Dec 2024 12:33:16 +0400 Subject: [PATCH 4/8] [TF FE] Run If tests on all platforms (#27966) **Details:** Run If tests on all platforms **Ticket:** TBD --------- Signed-off-by: Kazantsev, Roman --- .../tensorflow_tests/test_tf_If.py | 44 ++++++++----------- 1 file changed, 18 insertions(+), 26 deletions(-) diff --git a/tests/layer_tests/tensorflow_tests/test_tf_If.py b/tests/layer_tests/tensorflow_tests/test_tf_If.py index 67686ef53a5750..21dee5aa28616d 100644 --- a/tests/layer_tests/tensorflow_tests/test_tf_If.py +++ b/tests/layer_tests/tensorflow_tests/test_tf_If.py @@ -1,13 +1,13 @@ # Copyright (C) 2018-2024 Intel Corporation # SPDX-License-Identifier: Apache-2.0 -import platform - import numpy as np import pytest import tensorflow as tf from common.tf_layer_test_class import CommonTFLayerTest +rng = np.random.default_rng(32345) + class TestIfFloat(CommonTFLayerTest): def _prepare_input(self, inputs_info): @@ -18,9 +18,9 @@ def _prepare_input(self, inputs_info): x_shape = inputs_info['x:0'] y_shape = inputs_info['y:0'] inputs_data = {} - inputs_data['cond:0'] = np.random.randint(0, 2, cond_shape).astype(bool) - inputs_data['x:0'] = np.random.randint(1, 10, x_shape).astype(np.float32) - inputs_data['y:0'] = np.random.randint(-50, 50, y_shape).astype(np.float32) + inputs_data['cond:0'] = rng.integers(0, 2, cond_shape).astype(bool) + inputs_data['x:0'] = rng.integers(1, 10, x_shape).astype(np.float32) + inputs_data['y:0'] = rng.integers(-50, 50, y_shape).astype(np.float32) return inputs_data def create_if_net(self, x_shape, y_shape, lower_control_flow): @@ -69,12 +69,10 @@ def else_branch(): @pytest.mark.parametrize("params", test_data_basic) @pytest.mark.precommit @pytest.mark.nightly - @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', - 
reason='Ticket - 122716') def test_if_basic(self, params, ie_device, precision, ir_version, temp_dir, use_legacy_frontend): if ie_device == 'GPU': - pytest.xfail('104855') + pytest.xfail('104855: If operation is not supported by GPU') self._test(*self.create_if_net(**params), ie_device, precision, ir_version, temp_dir=temp_dir, use_legacy_frontend=use_legacy_frontend) @@ -89,9 +87,9 @@ def _prepare_input(self, inputs_info): ind_shape = inputs_info['ind:0'] y_shape = inputs_info['y:0'] inputs_data = {} - inputs_data['cond:0'] = np.random.randint(0, 2, cond_shape).astype(bool) - inputs_data['ind:0'] = np.random.randint(1, 10, ind_shape).astype(np.int32) - inputs_data['y:0'] = np.random.randint(-50, 50, y_shape).astype(np.float32) + inputs_data['cond:0'] = rng.integers(0, 2, cond_shape).astype(bool) + inputs_data['ind:0'] = rng.integers(1, 10, ind_shape).astype(np.int32) + inputs_data['y:0'] = rng.integers(-50, 50, y_shape).astype(np.float32) return inputs_data def create_if_net(self, ind_shape, y_shape, lower_control_flow): @@ -141,12 +139,10 @@ def else_branch(): @pytest.mark.parametrize("params", test_data_basic) @pytest.mark.precommit @pytest.mark.nightly - @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', - reason='Ticket - 122716') def test_if_basic(self, params, ie_device, precision, ir_version, temp_dir, use_legacy_frontend): if ie_device == 'GPU': - pytest.xfail('104855') + pytest.xfail('104855: If operation is not supported by GPU') self._test(*self.create_if_net(**params), ie_device, precision, ir_version, temp_dir=temp_dir, use_legacy_frontend=use_legacy_frontend) @@ -161,9 +157,9 @@ def _prepare_input(self, inputs_info): y_shape = inputs_info['y:0'] z_shape = inputs_info['z:0'] inputs_data = {} - inputs_data['x:0'] = np.random.randint(0, 6, x_shape).astype(np.int32) - inputs_data['y:0'] = np.random.randint(1, 10, y_shape).astype(np.float32) - inputs_data['z:0'] = np.random.randint(-50, 50, 
z_shape).astype(np.float32) + inputs_data['x:0'] = rng.integers(0, 6, x_shape).astype(np.int32) + inputs_data['y:0'] = rng.integers(1, 10, y_shape).astype(np.float32) + inputs_data['z:0'] = rng.integers(-50, 50, z_shape).astype(np.float32) return inputs_data def create_if_net(self, y_shape, z_shape, lower_control_flow): @@ -221,12 +217,10 @@ def else_branch(): @pytest.mark.parametrize("params", test_data_basic) @pytest.mark.precommit @pytest.mark.nightly - @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', - reason='Ticket - 122716') def test_if_basic(self, params, ie_device, precision, ir_version, temp_dir, use_legacy_frontend): if ie_device == 'GPU': - pytest.xfail('104855') + pytest.xfail('104855: If operation is not supported by GPU') self._test(*self.create_if_net(**params), ie_device, precision, ir_version, temp_dir=temp_dir, use_legacy_frontend=use_legacy_frontend) @@ -241,9 +235,9 @@ def _prepare_input(self, inputs_info): x_shape = inputs_info['x:0'] y_shape = inputs_info['y:0'] inputs_data = {} - inputs_data['cond:0'] = np.random.randint(0, 2, cond_shape).astype(bool) - inputs_data['x:0'] = np.random.randint(1, 10, x_shape).astype(np.float32) - inputs_data['y:0'] = np.random.randint(-50, 50, y_shape).astype(np.float32) + inputs_data['cond:0'] = rng.integers(0, 2, cond_shape).astype(bool) + inputs_data['x:0'] = rng.integers(1, 10, x_shape).astype(np.float32) + inputs_data['y:0'] = rng.integers(-50, 50, y_shape).astype(np.float32) return inputs_data def create_sequential_ifs_net(self, x_shape, y_shape, lower_control_flow): @@ -313,12 +307,10 @@ def else_branch(): @pytest.mark.parametrize("params", test_data_basic) @pytest.mark.precommit @pytest.mark.nightly - @pytest.mark.xfail(condition=platform.system() == 'Darwin' and platform.machine() == 'arm64', - reason='Ticket - 122716') def test_if_basic(self, params, ie_device, precision, ir_version, temp_dir, use_legacy_frontend): if ie_device == 'GPU': - 
pytest.xfail('104855') + pytest.xfail('104855: If operation is not supported by GPU') self._test(*self.create_sequential_ifs_net(**params), ie_device, precision, ir_version, temp_dir=temp_dir, use_legacy_frontend=use_legacy_frontend) From 408a5e065200b1fcb41200f9361094fa1c7df5d7 Mon Sep 17 00:00:00 2001 From: Mingyu Kim Date: Mon, 9 Dec 2024 17:49:45 +0900 Subject: [PATCH 5/8] [GPU] update onednn to latest 3.7-pc (#27811) --- src/plugins/intel_gpu/thirdparty/onednn_gpu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/plugins/intel_gpu/thirdparty/onednn_gpu b/src/plugins/intel_gpu/thirdparty/onednn_gpu index 0f269193c74663..36e090a367a431 160000 --- a/src/plugins/intel_gpu/thirdparty/onednn_gpu +++ b/src/plugins/intel_gpu/thirdparty/onednn_gpu @@ -1 +1 @@ -Subproject commit 0f269193c7466313888d3338209d0d06a22cc6fa +Subproject commit 36e090a367a4312a1caa2db9e95fb94d17d7573b From de949b4a2b59faf1bf701528dd37b7ecd076d4e0 Mon Sep 17 00:00:00 2001 From: Yuan Hu Date: Mon, 9 Dec 2024 17:08:40 +0800 Subject: [PATCH 6/8] [CPU] enable brdgmm kernel in CPU plugin (#27589) ### Details: - *replace impl string brdgmm with brgconv* - *add test case* - *remove skip CVS-56143 config, CVS-56143 is already closed* - *remove skip CVS-53578 config, CVS-53578 is already closed* - *use new ticket CVS-157596 to track leftover test case* ### Tickets: - *CVS-156792* --------- Signed-off-by: HU Yuan2 --- src/plugins/intel_cpu/src/nodes/conv.cpp | 13 +- .../intel_cpu/src/onednn/iml_type_mapper.cpp | 3 + .../intel_cpu/src/onednn/iml_type_mapper.h | 3 + .../single_layer_tests/group_convolution.cpp | 126 +++++++++++++++++- .../skip_tests_config.cpp | 10 +- 5 files changed, 140 insertions(+), 15 deletions(-) diff --git a/src/plugins/intel_cpu/src/nodes/conv.cpp b/src/plugins/intel_cpu/src/nodes/conv.cpp index 7cf7698e989343..53d53d093cfabf 100644 --- a/src/plugins/intel_cpu/src/nodes/conv.cpp +++ b/src/plugins/intel_cpu/src/nodes/conv.cpp @@ -343,6 +343,7 @@ const 
std::vector& Convolution::getDefaultImplPriority() { impl_desc_type::winograd_acl, impl_desc_type::gemm_acl, impl_desc_type::acl, + impl_desc_type::brgconv_avx512_dw, impl_desc_type::brgconv_avx512_amx_1x1, impl_desc_type::brgconv_avx512_amx, impl_desc_type::jit_avx512_amx_dw, @@ -353,6 +354,7 @@ const std::vector& Convolution::getDefaultImplPriority() { impl_desc_type::jit_avx512_dw, impl_desc_type::jit_avx512_1x1, impl_desc_type::jit_avx512, + impl_desc_type::brgconv_avx2_dw, impl_desc_type::brgconv_avx2_1x1, impl_desc_type::brgconv_avx2, impl_desc_type::jit_uni_dw, @@ -815,7 +817,11 @@ void Convolution::initSupportedPrimitiveDescriptors() { #endif for (size_t dIdx = 0; dIdx < descs.size(); dIdx++) { auto& desc = descs[dIdx]; - auto first_desc = dnnl::primitive_desc(DnnlExtensionUtils::clone_primitive_desc(desc.get())); + auto primitive_desc = desc.get(true); //true mean allow empty + if (primitive_desc == nullptr) { + continue; + } + auto first_desc = dnnl::primitive_desc(DnnlExtensionUtils::clone_primitive_desc(primitive_desc)); auto add_supported_desc = [&](dnnl::primitive_desc& desc) { addSupportedPrimitiveDescriptor(desc); @@ -823,7 +829,7 @@ void Convolution::initSupportedPrimitiveDescriptors() { }; const bool first_match = customImplPriorities.empty(); - DEBUG_LOG("#", getName(), + DEBUG_LOG("#", getName(), ",descIndex:", dIdx + 1, "/", descs.size(), ", itpd.impl_info_str(): ", desc.impl_info_str(), ", parsed imp_type: ", impl_type_to_string(parse_impl_name(desc.impl_info_str())), ", first_match: ", first_match ? 
"true" : "false"); @@ -944,8 +950,7 @@ void Convolution::createDescriptor(const std::vector& inputDesc, const auto desc = createDescriptorInternal(getEngine(), inDnnlDesc, weightDnnlDesc, biasDnnlDesc, outDnnlDesc, withBiases, stride, dilation, paddingL, paddingR, alg, attr); - if (desc) - descs.emplace_back(desc); + descs.emplace_back(desc); } } } diff --git a/src/plugins/intel_cpu/src/onednn/iml_type_mapper.cpp b/src/plugins/intel_cpu/src/onednn/iml_type_mapper.cpp index d7a1e5979ddad9..5c57a94f69f67d 100644 --- a/src/plugins/intel_cpu/src/onednn/iml_type_mapper.cpp +++ b/src/plugins/intel_cpu/src/onednn/iml_type_mapper.cpp @@ -17,6 +17,7 @@ impl_desc_type parse_impl_name(std::string impl_desc_name) { if (pos != std::string::npos) impl_desc_name.replace(pos, std::string(#_wrd).length(), #_sub); } // Replace the ONEDNN pd name with OV definition. REPLACE_WORD(brg_conv, brgconv); + REPLACE_WORD(brdgmm, brgconv); REPLACE_WORD(avx10_1_512, avx512); REPLACE_WORD(brg_matmul, brgemm); @@ -119,6 +120,8 @@ const char* impl_type_to_string(impl_desc_type type) { CASE(brgconv_sse42_1x1); CASE(brgconv_uni_1x1); CASE(brgconv_avx512_amx_1x1); + CASE(brgconv_avx512_dw); + CASE(brgconv_avx2_dw); CASE(brgemm_avx512); CASE(brgemm_avx2); CASE(brgemm_avx); diff --git a/src/plugins/intel_cpu/src/onednn/iml_type_mapper.h b/src/plugins/intel_cpu/src/onednn/iml_type_mapper.h index 3fd79716c7cd72..45a71bdb88dd33 100644 --- a/src/plugins/intel_cpu/src/onednn/iml_type_mapper.h +++ b/src/plugins/intel_cpu/src/onednn/iml_type_mapper.h @@ -98,6 +98,9 @@ enum impl_desc_type : int64_t { brgconv_uni_1x1 = brgconv | uni | _1x1, brgconv_avx512_amx_1x1 = brgconv | avx512 | amx | _1x1, + brgconv_avx2_dw = brgconv_avx2 | _dw, + brgconv_avx512_dw = brgconv_avx512 | _dw, + brgemm_avx512 = brgemm | avx512, brgemm_avx2 = brgemm | avx2, brgemm_avx = brgemm | avx, diff --git a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/group_convolution.cpp 
b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/group_convolution.cpp index 47d7d3072b7337..f3f5b1f2e07975 100644 --- a/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/group_convolution.cpp +++ b/src/plugins/intel_cpu/tests/functional/custom/single_layer_tests/group_convolution.cpp @@ -5,6 +5,7 @@ #include "shared_test_classes/single_op/group_convolution.hpp" #include "common_test_utils/node_builders/group_convolution.hpp" +#include "openvino/runtime/system_conf.hpp" #include "shared_test_classes/base/ov_subgraph.hpp" #include "utils/convolution_params.hpp" #include "utils/cpu_test_utils.hpp" @@ -176,14 +177,15 @@ class GroupConvolutionLayerCPUTest : public testing::WithParamInterface()) { - selectedType += "_bf16"; - rel_threshold = 1e-2f; - } else { - selectedType = makeSelectedTypeStr(selectedType, netType); + const auto& it = configuration.find(ov::hint::inference_precision.name()); + if (it != configuration.end()) { + if (ov::element::bf16 == it->second.as()) { + rel_threshold = 1e-2f; + } else if (ov::element::f16 == it->second.as()) { + rel_threshold = 0.00125f; + } } + selectedType = makeSelectedTypeStr(selectedType, deduce_expected_precision(netType, configuration)); // according to range propagation feature, resolution of generated inputs data for parameters moved from 32 to 32768 // 'real' part of input data was changed and some fails became visible for cases with Elu and FakeQuantize, so let's setup abs_threshold @@ -289,6 +291,7 @@ std::vector filterCPUInfoForDeviceSupportBF16(std::vector fusingParamsSetBF16{emptyFusingSpec, // sum fusingSum}; +const std::vector fusingParamsSet_Brdgmm{emptyFusingSpec, + // eltwise + fusingRelu, + fusingPRelu1D, + // depthwise + fusingReluScaleShift, + // fake quantize + fusingFakeQuantizePerTensorRelu, + fusingFakeQuantizePerChannelRelu + // sum + // comment out sum due to MFDNN-12841 + //fusingSumEluFQ, + //fusingSum + }; + +const std::vector 
fusingParamsSetBF16_Brdgmm{emptyFusingSpec, + // eltwise + fusingRelu, + // depthwise + fusingReluScaleShift + // sum + // comment out sum due to MFDNN-12841 + //fusingSum + }; + +const std::vector fusingParamsSetFP16_Brdgmm = fusingParamsSetBF16_Brdgmm; + /* ============= GroupConvolution params (planar layout) ============= */ const std::vector numOutChannels_Gemm = {6}; const std::vector numGroups_Gemm = {2, 3}; @@ -1299,6 +1329,38 @@ INSTANTIATE_TEST_SUITE_P(smoke_GroupConv_2D_DW_FP32, ::testing::Values(empty_plugin_config)), GroupConvolutionLayerCPUTest::getTestCaseName); +const std::vector> dilations2d_Brdgmm = {{1, 1}}; +const auto groupConvParams_ExplicitPadding_DW_2D_Brdgmm = ::testing::Combine(::testing::ValuesIn(kernels2d), + ::testing::ValuesIn(strides2d), + ::testing::ValuesIn(padBegins2d), + ::testing::ValuesIn(padEnds2d), + ::testing::ValuesIn(dilations2d_Brdgmm), + ::testing::ValuesIn(numOutChannels_DW), + ::testing::ValuesIn(numGroups_DW), + ::testing::Values(ov::op::PadType::EXPLICIT)); +const auto BrdgmmCPUSpec = []()-> std::vector { + std::string isaStr; + if (ov::with_cpu_x86_avx512f()) { + isaStr = "avx512"; + } else { + isaStr = "avx2"; + } + return {CPUSpecificParams{{}, {}, {}, "brgconv_" + isaStr + "_dw"}}; +}; + +INSTANTIATE_TEST_SUITE_P(smoke_GroupConv_2D_DW_FP32_Brdgmm, + GroupConvolutionLayerCPUTest, + ::testing::Combine(::testing::Combine(groupConvParams_ExplicitPadding_DW_2D_Brdgmm, + ::testing::Values(ElementType::f32), + ::testing::Values(ElementType::undefined), + ::testing::Values(ElementType::undefined), + ::testing::ValuesIn(inputShapes2dDW), + ::testing::Values(ov::test::utils::DEVICE_CPU)), + ::testing::ValuesIn(filterCPUInfoForDevice(BrdgmmCPUSpec())), + ::testing::ValuesIn(fusingParamsSet_Brdgmm), + ::testing::Values(empty_plugin_config)), + GroupConvolutionLayerCPUTest::getTestCaseName); + INSTANTIATE_TEST_SUITE_P(smoke_GroupConv_2D_DW_BF16, GroupConvolutionLayerCPUTest, 
::testing::Combine(::testing::Combine(groupConvParams_ExplicitPadding_DW_2D, @@ -1313,6 +1375,32 @@ INSTANTIATE_TEST_SUITE_P(smoke_GroupConv_2D_DW_BF16, ::testing::Values(cpu_bf16_plugin_config)), GroupConvolutionLayerCPUTest::getTestCaseName); +INSTANTIATE_TEST_SUITE_P(smoke_GroupConv_2D_DW_BF16_Brdgmm, + GroupConvolutionLayerCPUTest, + ::testing::Combine(::testing::Combine(groupConvParams_ExplicitPadding_DW_2D_Brdgmm, + ::testing::Values(ElementType::f32), + ::testing::Values(ElementType::undefined), + ::testing::Values(ElementType::undefined), + ::testing::ValuesIn(inputShapes2dDW), + ::testing::Values(ov::test::utils::DEVICE_CPU)), + ::testing::ValuesIn(filterCPUInfoForDeviceSupportBF16(BrdgmmCPUSpec())), + ::testing::ValuesIn(fusingParamsSetBF16_Brdgmm), + ::testing::Values(cpu_bf16_plugin_config)), + GroupConvolutionLayerCPUTest::getTestCaseName); + +INSTANTIATE_TEST_SUITE_P(smoke_GroupConv_2D_DW_FP16_Brdgmm, + GroupConvolutionLayerCPUTest, + ::testing::Combine(::testing::Combine(groupConvParams_ExplicitPadding_DW_2D_Brdgmm, + ::testing::Values(ElementType::f32), + ::testing::Values(ElementType::undefined), + ::testing::Values(ElementType::undefined), + ::testing::ValuesIn(inputShapes2dDW), + ::testing::Values(ov::test::utils::DEVICE_CPU)), + ::testing::ValuesIn(filterCPUInfoForDevice(BrdgmmCPUSpec())), + ::testing::ValuesIn(fusingParamsSetFP16_Brdgmm), + ::testing::Values(cpu_f16_plugin_config)), + GroupConvolutionLayerCPUTest::getTestCaseName); + /* ============= GroupConvolution (DW 3D) ============= */ const auto groupConvParams_ExplicitPadding_DW_3D = ::testing::Combine(::testing::ValuesIn(kernels3d), ::testing::ValuesIn(strides3d), @@ -1349,6 +1437,30 @@ INSTANTIATE_TEST_SUITE_P(smoke_GroupConv_3D_DW_FP32, ::testing::ValuesIn(fusingParamsSet), ::testing::Values(empty_plugin_config)), GroupConvolutionLayerCPUTest::getTestCaseName); + +const std::vector> dilations3d_Brdgmm = {{1, 1, 1}}; +const auto groupConvParams_ExplicitPadding_DW_3D_Brdgmm = 
::testing::Combine(::testing::ValuesIn(kernels3d), + ::testing::ValuesIn(strides3d), + ::testing::ValuesIn(padBegins3d), + ::testing::ValuesIn(padEnds3d), + ::testing::ValuesIn(dilations3d_Brdgmm), + ::testing::ValuesIn(numOutChannels_DW), + ::testing::ValuesIn(numGroups_DW), + ::testing::Values(ov::op::PadType::EXPLICIT)); + +INSTANTIATE_TEST_SUITE_P(smoke_GroupConv_3D_DW_FP32_Brdgmm, + GroupConvolutionLayerCPUTest, + ::testing::Combine(::testing::Combine(groupConvParams_ExplicitPadding_DW_3D_Brdgmm, + ::testing::Values(ElementType::f32), + ::testing::Values(ElementType::undefined), + ::testing::Values(ElementType::undefined), + ::testing::ValuesIn(inputShapes3dDW), + ::testing::Values(ov::test::utils::DEVICE_CPU)), + ::testing::ValuesIn(filterCPUInfoForDevice(BrdgmmCPUSpec())), + ::testing::ValuesIn(fusingParamsSet_Brdgmm), + ::testing::Values(empty_plugin_config)), + GroupConvolutionLayerCPUTest::getTestCaseName); + /* ========= */ /* ============= SINGLE TEST CASES ============= */ diff --git a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp index b675a7c2da7d42..089a03b4d6bba7 100644 --- a/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp +++ b/src/plugins/intel_cpu/tests/functional/shared_tests_instances/skip_tests_config.cpp @@ -40,10 +40,12 @@ std::vector disabledTestPatterns() { R"(.*BinaryConvolutionLayerTest.*)", // TODO: 53618. BF16 gemm ncsp convolution crash R"(.*_GroupConv.*_inFmts=nc.*_primitive=jit_gemm.*ENFORCE_BF16=YES.*)", - // TODO: 53578. fork DW bf16 convolution does not support 3d cases yet - R"(.*_DW_GroupConv.*_inFmts=(ndhwc|nCdhw16c).*ENFORCE_BF16=YES.*)", - // TODO: 56143. 
Enable nspc convolutions for bf16 precision - R"(.*ConvolutionLayerCPUTest.*_inFmts=(ndhwc|nhwc).*INFERENCE_PRECISION_HINT=bf16.*)", + // TODO: 157596 convolution bf16 leftover test case + R"(smoke_JIT_AVX512_DW_GroupConv/GroupConvolutionLayerCPUTest.*ndhwc.*jit_avx512_dw.*INFERENCE_PRECISION_HINT=bf16.*)", + R"(smoke_Conv_1D_1x1_BF16/ConvolutionLayerCPUTest\.CompareWithRefs/IS=\[\]_TS=\(\((1|2)\.6(4|7)\.7\)_\)_K\(1\)_S\(1\)_PB\(0\)_PE\(0\)_D=\(1\)_O=63_AP=explicit_netPRC=f32_inPRC=undefined_outPRC=undefined_trgDev=CPU_inFmts=nhwc_outFmts=nhwc_primitive=jit_avx512_1x1_.*PluginConf_INFERENCE_PRECISION_HINT=bf16)", + R"(smoke_Conv_1D_1x1_BF16/ConvolutionLayerCPUTest\.CompareWithRefs/IS=\[1\.\.200\.64\.\?\]_TS=\(\(2\.64\.7\)_\(1\.64\.5\)_\)_K\(1\)_S\(1\)_PB\(0\)_PE\(0\)_D=\(1\)_O=63_AP=explicit_netPRC=f32_inPRC=undefined_outPRC=undefined_trgDev=CPU_inFmts=nhwc_outFmts=nhwc_primitive=jit_avx512_1x1_.*PluginConf_INFERENCE_PRECISION_HINT=bf16)", + R"(smoke_Conv_1D_1x1_BF16/ConvolutionLayerCPUTest\.CompareWithRefs/IS=\[\?\.6(4|7)\.1\.\.200\]_TS=\(\(2\.6(4|7)\.7\)_\(1\.6(4|7)\.9\)_\)_K\(1\)_S\(1\)_PB\(0\)_PE\(0\)_D=\(1\)_O=63_AP=explicit_netPRC=f32_inPRC=undefined_outPRC=undefined_trgDev=CPU_inFmts=nhwc_outFmts=nhwc_primitive=jit_avx512_1x1_.*PluginConf_INFERENCE_PRECISION_HINT=bf16)", + R"(smoke_GroupConv_brgemm_2D_BF16/GroupConvolutionLayerCPUTest\.CompareWithRefs/IS=\[\]_TS=\(\(1\.64\.7\.7\)_\)_K\(3\.3\)_S\(2\.2\)_PB\((0|1)\.(0|1)\)_PE\(0\.0\)_D=\(2\.2\)_O=64_G=2_AP=explicit_netPRC=f32_inPRC=undefined_outPRC=undefined_trgDev=CPU_inFmts=nhwc_outFmts=nhwc_primitive=brgconv_avx512_amx_.*PluginConf_INFERENCE_PRECISION_HINT=bf16)", // TODO: 56827. 
Sporadic test failures R"(.*smoke_Conv.+_FP32.ConvolutionLayerCPUTest\.CompareWithRefs.*TS=\(\(.\.67.+\).*inFmts=n.+c.*_primitive=jit_avx2.*)", // incorrect jit_uni_planar_convolution with dilation = {1, 2, 1} and output channel 1 From de776f279c87e542c640acc8140aaf87f278c991 Mon Sep 17 00:00:00 2001 From: Andrei Kashchikhin Date: Mon, 9 Dec 2024 09:27:11 +0000 Subject: [PATCH 7/8] [CI] [GHA] Introduce additional Python (3.9-3.12) API tests on macOS (#27666) ### Details: - Based on #27304, should be reviewed after it. ### Tickets: - *152690* --- .github/workflows/job_python_api_tests.yml | 142 ++++++++++++++++++++ .github/workflows/job_python_unit_tests.yml | 54 ++------ .github/workflows/job_samples_tests.yml | 14 +- .github/workflows/linux_arm64.yml | 10 ++ .github/workflows/mac.yml | 60 ++++++++- .github/workflows/mac_arm64.yml | 57 +++++++- .github/workflows/ubuntu_22.yml | 10 ++ .github/workflows/ubuntu_24.yml | 10 ++ 8 files changed, 304 insertions(+), 53 deletions(-) create mode 100644 .github/workflows/job_python_api_tests.yml diff --git a/.github/workflows/job_python_api_tests.yml b/.github/workflows/job_python_api_tests.yml new file mode 100644 index 00000000000000..541a14e2b1b6df --- /dev/null +++ b/.github/workflows/job_python_api_tests.yml @@ -0,0 +1,142 @@ +name: Python API tests + +on: + workflow_call: + inputs: + runner: + description: 'Machine on which the tests would run' + type: string + required: true + container: + description: 'JSON to be converted to the value of the "container" configuration for the job' + type: string + required: false + default: '{"image": null}' + python-version: + description: 'Python version to setup. 
E.g., "3.11"' + type: string + required: true + +permissions: read-all + +env: + PIP_CACHE_PATH: /mount/caches/pip/linux + +jobs: + Python_Unit_Tests: + name: Python API tests + timeout-minutes: 30 + runs-on: ${{ inputs.runner }} + container: ${{ fromJSON(inputs.container) }} + defaults: + run: + shell: bash + env: + DEBIAN_FRONTEND: noninteractive # to prevent apt-get from waiting user input + OPENVINO_REPO: ${{ github.workspace }}/openvino + INSTALL_DIR: ${{ github.workspace }}/install + INSTALL_TEST_DIR: ${{ github.workspace }}/install/openvino_tests + INSTALL_WHEELS_DIR: ${{ github.workspace }}/install/openvino_wheels + steps: + - name: Download OpenVINO artifacts (tarballs and wheels) + uses: actions/download-artifact@fa0a91b85d4f404e444e00e005971372dc801d16 # v4.1.8 + with: + pattern: openvino_@(wheels|tests) + path: ${{ env.INSTALL_DIR }} + + # Needed as ${{ github.workspace }} is not working correctly when using Docker + - name: Setup Variables + run: | + echo "OPENVINO_REPO=$GITHUB_WORKSPACE/openvino" >> "$GITHUB_ENV" + echo "INSTALL_DIR=$GITHUB_WORKSPACE/install" >> "$GITHUB_ENV" + echo "INSTALL_TEST_DIR=$GITHUB_WORKSPACE/install/openvino_tests" >> "$GITHUB_ENV" + echo "INSTALL_WHEELS_DIR=$GITHUB_WORKSPACE/install/openvino_wheels" >> "$GITHUB_ENV" + + - name: Install OpenVINO dependencies (mac) + if: runner.os == 'macOS' + run: brew install pigz + + - name: Extract OpenVINO packages + run: pigz -dc openvino_tests.tar.gz | tar -xf - -C ${INSTALL_TEST_DIR} + working-directory: ${{ env.INSTALL_TEST_DIR }} + + - name: Fetch setup_python and install wheels actions + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 + with: + sparse-checkout: | + .github/actions/setup_python/action.yml + .github/actions/install_ov_wheels/action.yml + sparse-checkout-cone-mode: false + path: 'action_root' + + - name: Setup Python ${{ inputs.python-version }} + uses: ./action_root/.github/actions/setup_python + with: + version: ${{ inputs.python-version }} 
+ pip-cache-path: ${{ runner.os == 'Linux' && env.PIP_CACHE_PATH || '' }} + should-setup-pip-paths: ${{ runner.os == 'Linux' }} + self-hosted-runner: ${{ runner.os == 'Linux' }} + + # + # Tests + # + - name: Install OpenVINO Python wheels + uses: ./action_root/.github/actions/install_ov_wheels + with: + wheels-dir-path: ${{ env.INSTALL_WHEELS_DIR }} + wheels-to-install: 'openvino' + + - name: Install Python API tests dependencies + run: python3 -m pip install -r ${INSTALL_TEST_DIR}/tests/bindings/python/requirements_test.txt + + # + # Tests + # + + - name: Python API Tests + run: | + # for 'template' extension + export LD_LIBRARY_PATH=${INSTALL_TEST_DIR}/tests/:$LD_LIBRARY_PATH + python3 -m pytest -sv ${INSTALL_TEST_DIR}/tests/pyopenvino \ + --junitxml=${INSTALL_TEST_DIR}/TEST-Pyngraph.xml \ + --ignore=${INSTALL_TEST_DIR}/tests/pyopenvino/tests/test_utils/test_utils.py + + - name: Python API Tests -- numpy>=2.0.0 + run: | + python3 -m pip uninstall -y numpy + python3 -m pip install "numpy~=2.0.0" + python3 -m pip install -r ${INSTALL_TEST_DIR}/tests/bindings/python/requirements_test.txt + # for 'template' extension + export LD_LIBRARY_PATH=${INSTALL_TEST_DIR}/tests/:$LD_LIBRARY_PATH + python3 -m pytest -sv ${INSTALL_TEST_DIR}/tests/pyopenvino \ + --junitxml=${INSTALL_TEST_DIR}/TEST-Pyngraph_new_numpy.xml \ + --ignore=${INSTALL_TEST_DIR}/tests/pyopenvino/tests/test_utils/test_utils.py + + - name: Clone API snippets + if: runner.os != 'macOS' + uses: actions/checkout@d632683dd7b4114ad314bca15554477dd762a938 # v4.2.0 + with: + sparse-checkout: docs/articles_en/assets/snippets + path: ${{ env.OPENVINO_REPO }} + submodules: 'false' + + - name: Docs Python snippets + if: runner.os != 'macOS' + run: | + # torch, onnx + python3 -m pip install -r ${INSTALL_TEST_DIR}/tests/python/preprocess/torchvision/requirements.txt -r ${INSTALL_TEST_DIR}/tests/requirements_onnx + # to find 'snippets' module in docs + export PYTHONPATH=${OPENVINO_REPO}/docs/articles_en/assets + # for 
'template' extension + export LD_LIBRARY_PATH=${INSTALL_TEST_DIR}/tests/:$LD_LIBRARY_PATH + python3 ${OPENVINO_REPO}/docs/articles_en/assets/snippets/main.py + + - name: Upload Test Results + uses: actions/upload-artifact@50769540e7f4bd5e21e526ee35c689e35e0d6874 # v4.4.0 + if: ${{ !cancelled() }} + with: + name: test-results-python-api-${{ inputs.python-version }} + path: | + ${{ env.INSTALL_TEST_DIR }}/TEST*.html + ${{ env.INSTALL_TEST_DIR }}/TEST*.xml + if-no-files-found: 'warn' diff --git a/.github/workflows/job_python_unit_tests.yml b/.github/workflows/job_python_unit_tests.yml index 8075f3299fe063..47506c83bf0945 100644 --- a/.github/workflows/job_python_unit_tests.yml +++ b/.github/workflows/job_python_unit_tests.yml @@ -65,21 +65,22 @@ jobs: echo "INSTALL_DIR=$GITHUB_WORKSPACE/install" >> "$GITHUB_ENV" echo "INSTALL_TEST_DIR=$GITHUB_WORKSPACE/install/tests" >> "$GITHUB_ENV" echo "LAYER_TESTS_INSTALL_DIR=$GITHUB_WORKSPACE/install/tests/layer_tests" >> "$GITHUB_ENV" + echo "INSTALL_WHEELS_DIR=$GITHUB_WORKSPACE/install/wheels" >> "$GITHUB_ENV" - name: Install OpenVINO dependencies (mac) if: runner.os == 'macOS' run: brew install pigz - name: Extract OpenVINO packages - run: | - pigz -dc openvino_tests.tar.gz | tar -xf - -C ${INSTALL_DIR} + run: pigz -dc openvino_tests.tar.gz | tar -xf - -C ${INSTALL_DIR} working-directory: ${{ env.INSTALL_DIR }} - - name: Fetch setup_python action + - name: Fetch setup_python and install wheels actions uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: sparse-checkout: | .github/actions/setup_python/action.yml + .github/actions/install_ov_wheels/action.yml sparse-checkout-cone-mode: false path: 'action_root' @@ -92,11 +93,10 @@ jobs: self-hosted-runner: ${{ runner.os == 'Linux' }} - name: Install OpenVINO Python wheels - run: | - # Install the core OV wheel - python3 -m pip install ./openvino-*.whl - - working-directory: ${{ env.INSTALL_WHEELS_DIR }} + uses: 
./action_root/.github/actions/install_ov_wheels + with: + wheels-dir-path: ${{ env.INSTALL_WHEELS_DIR }} + wheels-to-install: 'openvino' - name: Install Python API tests dependencies run: | @@ -121,15 +121,6 @@ jobs: # Tests # - - name: Python API Tests - if: ${{ fromJSON(inputs.affected-components).Python_API.test }} - run: | - # for 'template' extension - export LD_LIBRARY_PATH=${INSTALL_TEST_DIR}:$LD_LIBRARY_PATH - python3 -m pytest -sv ${INSTALL_TEST_DIR}/pyopenvino \ - --junitxml=${INSTALL_TEST_DIR}/TEST-Pyngraph.xml \ - --ignore=${INSTALL_TEST_DIR}/pyopenvino/tests/test_utils/test_utils.py - - name: Python ONNX operators tests if: (fromJSON(inputs.affected-components).Python_API.test || fromJSON(inputs.affected-components).ONNX_FE.test) && @@ -185,35 +176,6 @@ jobs: TEST_DEVICE: CPU TEST_PRECISION: FP16 - - name: Clone API snippets - if: runner.os != 'macOS' - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - sparse-checkout: docs/articles_en/assets/snippets - path: ${{ env.OPENVINO_REPO }} - submodules: 'false' - - - name: Docs Python snippets - if: runner.os != 'macOS' - run: | - # to find 'snippets' module in docs - export PYTHONPATH=${OPENVINO_REPO}/docs/articles_en/assets - # for 'template' extension - export LD_LIBRARY_PATH=${INSTALL_TEST_DIR}:$LD_LIBRARY_PATH - python3 ${OPENVINO_REPO}/docs/articles_en/assets/snippets/main.py - - - name: Python API Tests -- numpy>=2.0.0 - if: ${{ fromJSON(inputs.affected-components).Python_API.test }} - run: | - python3 -m pip uninstall -y numpy - python3 -m pip install "numpy>=2.0.0,<2.2.0" - python3 -m pip install -r ${INSTALL_TEST_DIR}/bindings/python/requirements_test.txt - # for 'template' extension - export LD_LIBRARY_PATH=${INSTALL_TEST_DIR}:$LD_LIBRARY_PATH - python3 -m pytest -sv ${INSTALL_TEST_DIR}/pyopenvino \ - --junitxml=${INSTALL_TEST_DIR}/TEST-Pyngraph.xml \ - --ignore=${INSTALL_TEST_DIR}/pyopenvino/tests/test_utils/test_utils.py - - name: Upload Test Results uses: 
actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 if: ${{ !cancelled() }} diff --git a/.github/workflows/job_samples_tests.yml b/.github/workflows/job_samples_tests.yml index e144aa0cfb95aa..6f95d316abfc3f 100644 --- a/.github/workflows/job_samples_tests.yml +++ b/.github/workflows/job_samples_tests.yml @@ -54,6 +54,7 @@ jobs: echo "INSTALL_DIR=$GITHUB_WORKSPACE/install" >> "$GITHUB_ENV" echo "INSTALL_TEST_DIR=$GITHUB_WORKSPACE/install/tests" >> "$GITHUB_ENV" echo "BUILD_DIR=$GITHUB_WORKSPACE/build" >> "$GITHUB_ENV" + echo "INSTALL_WHEELS_DIR=$GITHUB_WORKSPACE/install/wheels" >> "$GITHUB_ENV" - name: Install OpenVINO dependencies (mac) if: runner.os == 'macOS' @@ -65,13 +66,12 @@ jobs: pigz -dc openvino_tests.tar.gz | tar -xf - -C ${INSTALL_DIR} working-directory: ${{ env.INSTALL_DIR }} - - name: Fetch setup_python action - # Python is already installed on Ubuntu within Dockerfile - if: runner.os != 'Linux' + - name: Fetch setup_python and install wheels actions uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 with: sparse-checkout: | .github/actions/setup_python/action.yml + .github/actions/install_ov_wheels/action.yml sparse-checkout-cone-mode: false path: 'openvino' @@ -113,6 +113,12 @@ jobs: # Tests # + - name: Install OpenVINO Python wheels + uses: ./openvino/.github/actions/install_ov_wheels + with: + wheels-dir-path: ${{ env.INSTALL_WHEELS_DIR }} + wheels-to-install: 'openvino' + - name: Samples tests if: fromJSON(inputs.affected-components).samples.test run: | @@ -122,7 +128,7 @@ jobs: export SHARE=$INSTALL_TEST_DIR/smoke_tests/samples_smoke_tests_data # Install Python benchmark_app by installing openvino-*.whl - python3 -m pip install --ignore-installed PyYAML -r $INSTALL_TEST_DIR/smoke_tests/requirements.txt $INSTALL_WHEELS_DIR/openvino-*.whl + python3 -m pip install --ignore-installed PyYAML -r $INSTALL_TEST_DIR/smoke_tests/requirements.txt export LD_LIBRARY_PATH=${IE_APP_PATH}:$LD_LIBRARY_PATH source 
${INSTALL_DIR}/setupvars.sh diff --git a/.github/workflows/linux_arm64.yml b/.github/workflows/linux_arm64.yml index 66ce9461f05fe8..e1aaa886d631c7 100644 --- a/.github/workflows/linux_arm64.yml +++ b/.github/workflows/linux_arm64.yml @@ -169,6 +169,16 @@ jobs: affected-components: ${{ needs.smart_ci.outputs.affected_components }} python-version: '3.11' + Python_API_Tests: + name: Python API tests + needs: [ Docker, Build, Smart_CI ] + uses: ./.github/workflows/job_python_api_tests.yml + with: + runner: 'aks-linux-16-cores-arm' + container: '{"image": "${{ fromJSON(needs.docker.outputs.images).ov_test.ubuntu_20_04_arm64 }}", "volumes": ["/mount:/mount"]}' + python-version: '3.11' + if: fromJSON(needs.smart_ci.outputs.affected_components).Python_API.test + TensorFlow_Layer_Tests: name: TensorFlow Layer Tests needs: [ Build, Docker, Smart_CI, Openvino_tokenizers ] diff --git a/.github/workflows/mac.yml b/.github/workflows/mac.yml index c587c5ad7323b3..26289e969c4e00 100644 --- a/.github/workflows/mac.yml +++ b/.github/workflows/mac.yml @@ -151,6 +151,7 @@ jobs: -DENABLE_CPPLINT=OFF \ -DENABLE_NCC_STYLE=OFF \ -DENABLE_TESTS=ON \ + -DENABLE_WHEEL=OFF \ -DCMAKE_COMPILE_WARNING_AS_ERROR=OFF \ -DENABLE_STRICT_DEPENDENCIES=OFF \ -DCMAKE_CXX_COMPILER_LAUNCHER=${{ env.CMAKE_CXX_COMPILER_LAUNCHER }} \ @@ -168,7 +169,6 @@ jobs: run: | cmake -DCMAKE_INSTALL_PREFIX=${{ env.INSTALL_DIR }} -P ${{ env.BUILD_DIR }}/cmake_install.cmake cmake -DCMAKE_INSTALL_PREFIX=${{ env.INSTALL_TEST_DIR }} -DCOMPONENT=tests -P ${{ env.BUILD_DIR }}/cmake_install.cmake - cmake -DCMAKE_INSTALL_PREFIX=${{ env.INSTALL_WHEELS_DIR }} -DCOMPONENT=python_wheels -P ${{ env.BUILD_DIR }}/cmake_install.cmake - name: Pack Artifacts run: | @@ -179,6 +179,48 @@ jobs: tar -cvf - * | pigz > ${{ env.BUILD_DIR }}/openvino_tests.tar.gz popd + # Setup additional Python versions for wheels building + - name: Setup Python 3.9 + uses: ./openvino/.github/actions/setup_python + with: + version: "3.9" + 
should-setup-pip-paths: 'false' + self-hosted-runner: 'false' + + - name: Setup Python 3.10 + uses: ./openvino/.github/actions/setup_python + with: + version: "3.10" + should-setup-pip-paths: 'false' + self-hosted-runner: 'false' + + - name: Setup Python 3.12 + uses: ./openvino/.github/actions/setup_python + with: + version: "3.12" + should-setup-pip-paths: 'false' + self-hosted-runner: 'false' + + - name: Build additional Python wheels + run: | + for py_version in "3.9" "3.10" "3.11" "3.12" + do + python_exec_path=$(python$py_version -c "import sys; print(sys.executable)") + $python_exec_path -m pip install -r ${{ env.OPENVINO_REPO }}/src/bindings/python/wheel/requirements-dev.txt + + cmake -DPython3_EXECUTABLE=$python_exec_path -DENABLE_WHEEL=ON -DOpenVINODeveloperPackage_DIR=${{ env.BUILD_DIR }} -S ${{ env.OPENVINO_REPO }}/src/bindings/python -B ${{ github.workspace }}/py$py_version + cmake --build ${{ github.workspace }}/py$py_version --parallel + cmake --install ${{ github.workspace }}/py$py_version --config ${{ env.CMAKE_BUILD_TYPE }} --prefix ${{ env.INSTALL_WHEELS_DIR }} --component python_wheels + done + + # Setup Python 3.11 as the default one + - name: Setup Python ${{ env.PYTHON_VERSION }} + uses: ./openvino/.github/actions/setup_python + with: + version: ${{ env.PYTHON_VERSION }} + should-setup-pip-paths: 'false' + self-hosted-runner: 'false' + - name: Cmake & Build - OpenVINO Contrib run: | cmake \ @@ -199,6 +241,7 @@ jobs: cmake --build ${{ env.BUILD_DIR }} --parallel $(nproc) cmake -DCMAKE_INSTALL_PREFIX=${{ env.INSTALL_DIR_JS }} -P ${{ env.BUILD_DIR }}/cmake_install.cmake + # # Upload build artifacts # @@ -210,7 +253,7 @@ jobs: name: openvino_package path: ${{ env.BUILD_DIR }}/openvino_package.tar.gz if-no-files-found: 'error' - + - name: Upload openvino wheels uses: actions/upload-artifact@b4b15b8c7c6ac21ea08fcf65892d2ee8f75cf882 # v4.4.3 with: @@ -270,6 +313,19 @@ jobs: affected-components: ${{ needs.smart_ci.outputs.affected_components }} os: 
'mac_13' + Python_API_Tests: + name: Python API tests + needs: [ Build, Smart_CI ] + uses: ./.github/workflows/job_python_api_tests.yml + strategy: + fail-fast: false + matrix: + python-version: [ '3.9', '3.10', '3.11', '3.12' ] + with: + runner: 'macos-13' + python-version: ${{ matrix.python-version }} + if: fromJSON(needs.smart_ci.outputs.affected_components).Python_API.test + Python_Unit_Tests: name: Python unit tests needs: [ Build, Smart_CI ] diff --git a/.github/workflows/mac_arm64.yml b/.github/workflows/mac_arm64.yml index 0708a844fe6b8b..d3fb10082adfd4 100644 --- a/.github/workflows/mac_arm64.yml +++ b/.github/workflows/mac_arm64.yml @@ -151,6 +151,7 @@ jobs: -DENABLE_CPPLINT=OFF \ -DENABLE_NCC_STYLE=OFF \ -DENABLE_TESTS=ON \ + -DENABLE_WHEEL=OFF \ -DCMAKE_COMPILE_WARNING_AS_ERROR=OFF \ -DENABLE_STRICT_DEPENDENCIES=OFF \ -DCMAKE_CXX_COMPILER_LAUNCHER=${{ env.CMAKE_CXX_COMPILER_LAUNCHER }} \ @@ -168,7 +169,6 @@ jobs: run: | cmake -DCMAKE_INSTALL_PREFIX=${{ env.INSTALL_DIR }} -P ${{ env.BUILD_DIR }}/cmake_install.cmake cmake -DCMAKE_INSTALL_PREFIX=${{ env.INSTALL_TEST_DIR }} -DCOMPONENT=tests -P ${{ env.BUILD_DIR }}/cmake_install.cmake - cmake -DCMAKE_INSTALL_PREFIX=${{ env.INSTALL_WHEELS_DIR }} -DCOMPONENT=python_wheels -P ${{ env.BUILD_DIR }}/cmake_install.cmake - name: Pack Artifacts run: | @@ -180,6 +180,48 @@ jobs: tar -cvf - * | pigz > ${{ env.BUILD_DIR }}/openvino_tests.tar.gz popd + # Setup additional Python versions for wheels building + - name: Setup Python 3.9 + uses: ./openvino/.github/actions/setup_python + with: + version: "3.9" + should-setup-pip-paths: 'false' + self-hosted-runner: 'false' + + - name: Setup Python 3.10 + uses: ./openvino/.github/actions/setup_python + with: + version: "3.10" + should-setup-pip-paths: 'false' + self-hosted-runner: 'false' + + - name: Setup Python 3.12 + uses: ./openvino/.github/actions/setup_python + with: + version: "3.12" + should-setup-pip-paths: 'false' + self-hosted-runner: 'false' + + - name: Build 
additional Python wheels + run: | + for py_version in "3.9" "3.10" "3.11" "3.12" + do + python_exec_path=$(python$py_version -c "import sys; print(sys.executable)") + $python_exec_path -m pip install -r ${{ env.OPENVINO_REPO }}/src/bindings/python/wheel/requirements-dev.txt + + cmake -DPython3_EXECUTABLE=$python_exec_path -DENABLE_WHEEL=ON -DOpenVINODeveloperPackage_DIR=${{ env.BUILD_DIR }} -S ${{ env.OPENVINO_REPO }}/src/bindings/python -B ${{ github.workspace }}/py$py_version + cmake --build ${{ github.workspace }}/py$py_version --parallel + cmake --install ${{ github.workspace }}/py$py_version --config ${{ env.CMAKE_BUILD_TYPE }} --prefix ${{ env.INSTALL_WHEELS_DIR }} --component python_wheels + done + + # Setup Python 3.11 as the default one + - name: Setup Python ${{ env.PYTHON_VERSION }} + uses: ./openvino/.github/actions/setup_python + with: + version: ${{ env.PYTHON_VERSION }} + should-setup-pip-paths: 'false' + self-hosted-runner: 'false' + - name: Cmake & Build - OpenVINO Contrib run: | cmake \ @@ -279,6 +321,19 @@ jobs: affected-components: ${{ needs.smart_ci.outputs.affected_components }} python-version: '3.11' + Python_API_Tests: + name: Python API tests + needs: [ Build, Smart_CI ] + uses: ./.github/workflows/job_python_api_tests.yml + strategy: + fail-fast: false + matrix: + python-version: [ '3.9', '3.10', '3.11', '3.12' ] + with: + runner: 'macos-13-xlarge' + python-version: ${{ matrix.python-version }} + if: fromJSON(needs.smart_ci.outputs.affected_components).Python_API.test + TensorFlow_Layer_Tests: name: TensorFlow Layer Tests needs: [ Build, Smart_CI, Openvino_tokenizers ] diff --git a/.github/workflows/ubuntu_22.yml b/.github/workflows/ubuntu_22.yml index f4caec8b2458a0..4fc93d73213f78 100644 --- a/.github/workflows/ubuntu_22.yml +++ b/.github/workflows/ubuntu_22.yml @@ -300,6 +300,16 @@ jobs: affected-components: ${{ needs.smart_ci.outputs.affected_components }} python-version: '3.11' + Python_API_Tests: + name: Python API tests + needs: [ 
Docker, Build, Smart_CI ] + uses: ./.github/workflows/job_python_api_tests.yml + with: + runner: 'aks-linux-4-cores-16gb' + container: '{"image": "${{ fromJSON(needs.docker.outputs.images).ov_test.ubuntu_22_04_x64 }}", "volumes": ["/mount:/mount"]}' + python-version: '3.11' + if: fromJSON(needs.smart_ci.outputs.affected_components).Python_API.test + TensorFlow_Layer_Tests: name: TensorFlow Layer Tests needs: [ Docker, Build, Smart_CI, Openvino_tokenizers ] diff --git a/.github/workflows/ubuntu_24.yml b/.github/workflows/ubuntu_24.yml index d874e06a189232..1ad3951ecd3347 100644 --- a/.github/workflows/ubuntu_24.yml +++ b/.github/workflows/ubuntu_24.yml @@ -134,6 +134,16 @@ jobs: affected-components: ${{ needs.smart_ci.outputs.affected_components }} python-version: '3.12' + Python_API_Tests: + name: Python API tests + needs: [ Docker, Build, Smart_CI ] + uses: ./.github/workflows/job_python_api_tests.yml + with: + runner: 'aks-linux-4-cores-16gb' + container: '{"image": "${{ fromJSON(needs.docker.outputs.images).ov_test.ubuntu_24_04_x64 }}", "volumes": ["/mount:/mount"]}' + python-version: '3.12' + if: fromJSON(needs.smart_ci.outputs.affected_components).Python_API.test + Pytorch_Layer_Tests: name: Pytorch Layer Tests needs: [ Docker, Build, Smart_CI ] From 67f253764c4d0a9b7ab5a8f9706d063e488d7b5b Mon Sep 17 00:00:00 2001 From: Alina Kladieva Date: Mon, 9 Dec 2024 19:27:32 +0100 Subject: [PATCH 8/8] [GHA][ov-provider] Exclude custom release packages from matching (#27979) To filter out automatically picking unwanted custom release builds like https://storage.openvinotoolkit.org/repositories/openvino/packages/2024.5/windows_vc_mt Test run: https://github.com/openvinotoolkit/openvino_tokenizers/actions/runs/12237578864/job/34133648815?pr=338 (now the regular "windows" package is picked) Signed-off-by: Alina Kladieva --- .github/actions/openvino_provider/get_s3_package.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git 
a/.github/actions/openvino_provider/get_s3_package.py b/.github/actions/openvino_provider/get_s3_package.py index df253a422421ec..02ea99cb2f3403 100644 --- a/.github/actions/openvino_provider/get_s3_package.py +++ b/.github/actions/openvino_provider/get_s3_package.py @@ -54,6 +54,10 @@ def main(product, version_pattern, platform, arch, folder): matching_files = filter_files_by_criteria(all_files, product, version_pattern, platform, arch, folder) if matching_files: logger.info(f"Matching packages: {sorted(matching_files)}") + if len(matching_files) > 1: + custom_release_build_pattern = fr".*/{version_pattern}/(linux_|windows_|macos_).*/.*" + # Exclude custom release builds, if any, from matches + matching_files = [file for file in matching_files if not re.search(custom_release_build_pattern, file)] package_url = f"https://storage.openvinotoolkit.org{sorted(matching_files)[-1]}" logger.info(f"Returning package URL: {package_url}") action_utils.set_github_output("package_url", package_url)