diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl
index 6a5c9e54a8e904..8b351e133d83ad 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl
+++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/fully_connected_gpu_bf_tiled.cl
@@ -17,8 +17,6 @@
 // DISPATCH_FSV - output coordinates for each sub-group are calculated from linearized coordinates
 // DISPATCH_BSV   as if they laid in bs_fs_bsv_fsv format, these macros describe fsv and bsv factors;
 
-#define INPUT_LOAD_SIZE 4
-
 #if FC_KERNEL_DYNAMIC_QUANTIZE
 KERNEL(quantize_input)(
     const __global INPUT0_TYPE* input,
@@ -28,40 +26,41 @@ KERNEL(quantize_input)(
     const uint offset = get_global_id(0);
 
     const uint input_offset = offset * QUANTIZE_GROUP_SIZE;
-    const uint quantize_block = QUANTIZE_GROUP_SIZE / 4;
-    MAKE_VECTOR_TYPE(INPUT0_TYPE, INPUT_LOAD_SIZE) input_0[quantize_block];
-    MAKE_VECTOR_TYPE(DQ_TYPE, INPUT_LOAD_SIZE) quantized_value[quantize_block];
+    const uint quantize_block = QUANTIZE_GROUP_SIZE / INPUT_LOAD_SIZE;
+    MAKE_VECTOR_TYPE(INPUT0_TYPE, INPUT_LOAD_SIZE) input_0;
+    MAKE_VECTOR_TYPE(DQ_TYPE, INPUT_LOAD_SIZE) quantized_value;
     INPUT0_TYPE max[quantize_block];
 
     unroll_for (uint i = 0 ; i < quantize_block ; ++i) {
-        input_0[i] = vload4(0, &input[input_offset + i * 4]);
-        max[i] = fmax(fmax(fabs(input_0[i][0]), fabs(input_0[i][1])), fmax(fabs(input_0[i][2]), fabs(input_0[i][3])));
+        input_0 = vload4(0, &input[input_offset + i * 4]);
+        max[i] = fmax(fmax(fabs(input_0[0]), fabs(input_0[1])), fmax(fabs(input_0[2]), fabs(input_0[3])));
    }

-    INPUT0_TYPE max_value = 0.001;
+    INPUT0_TYPE max_value = 0.001h;
    for (uint i = 0 ; i < quantize_block ; i+=8) {
        INPUT0_TYPE temp = fmax(fmax(fmax(max[i], max[i+1]), fmax(max[i+2], max[i+3])),
                                fmax(fmax(max[i+4], max[i+5]), fmax(max[i+6], max[i+7])));
        max_value = fmax(max_value, temp);
    }

-    half quan_scale = (half)max_value / 127;
+    float quan_scale = (float)max_value / 127.f;
 #if COMPRESSED_WEIGHTS_INT8
     int quantized_sum = 0;
 #endif
     for (uint i = 0 ; i < quantize_block ; ++i) {
-        half4 buff = input_0[i] / (half4)quan_scale;
-        quantized_value[i] = CAT(CAT(convert_, MAKE_VECTOR_TYPE(DQ_TYPE, INPUT_LOAD_SIZE)), _rte)(buff);
+        input_0 = vload4(0, &input[input_offset + i * 4]);
+        float4 buff = convert_float4(input_0) / quan_scale;
+        quantized_value = CAT(CAT(convert_, MAKE_VECTOR_TYPE(DQ_TYPE, INPUT_LOAD_SIZE)), _rte)(buff);
 #if COMPRESSED_WEIGHTS_INT8
-        quantized_sum += quantized_value[i][0] + quantized_value[i][1] + quantized_value[i][2] + quantized_value[i][3];
+        quantized_sum += quantized_value[0] + quantized_value[1] + quantized_value[2] + quantized_value[3];
 #endif
-        vstore4(quantized_value[i], 0, &quantized_input[input_offset + i * 4]);
+        vstore4(quantized_value, 0, &quantized_input[input_offset + i * 4]);
     }
 
     // Pair of quantizing_scale and quantized activation_sum for each group
-    quan_var[offset * 2] = quan_scale;
+    quan_var[offset * 2] = convert_half(quan_scale);
 #if COMPRESSED_WEIGHTS_INT8
-    quan_var[(offset * 2) + 1] = CAT(CAT(convert_, INPUT0_TYPE), _rte)(quantized_sum);
+    quan_var[(offset * 2) + 1] = convert_half(quantized_sum);
 #endif
 }
 #else // !FC_KERNEL_DYNAMIC_QUANTIZE
@@ -808,9 +807,6 @@ inline void FUNC(fc_bf_tiled_kernel_default)(
     // =====================================================================================================================================
 }
 
-
-
-// Dyc Quantize
 #if USE_SLM && DYNAMIC_QUANTIZE
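The quantize_input changes above amount to a per-group symmetric int8 quantization of the activations, now computed through float: take max|x| over the group (with a small 0.001 floor), derive scale = max/127, round each element with round-to-nearest-even, and store the scale (as half) together with the group's activation sum for the asymmetric-weight path. Below is a minimal host-side C++ sketch of that scheme for reference; the names (QuantizedGroup, quantize_group) are illustrative and are not part of the kernel.

```cpp
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// One dynamic-quantization group (illustrative type, not kernel code).
struct QuantizedGroup {
    std::vector<int8_t> values;  // quantized activations (DQ_TYPE == char in the kernel)
    float scale;                 // de-quantization scale, stored as half in quan_var
    int activation_sum;          // sum of quantized values, used for INT8 asymmetric weights
};

QuantizedGroup quantize_group(const std::vector<float>& x) {
    // max|x| with the same small floor as the kernel's 0.001h
    float max_abs = 0.001f;
    for (float v : x) max_abs = std::max(max_abs, std::fabs(v));

    QuantizedGroup g;
    g.scale = max_abs / 127.0f;   // float scale, as in the updated kernel
    g.activation_sum = 0;
    g.values.reserve(x.size());
    for (float v : x) {
        // std::nearbyint uses the current rounding mode, by default
        // round-to-nearest-even, which matches the kernel's _rte conversion
        int q = static_cast<int>(std::nearbyint(v / g.scale));
        q = std::min(127, std::max(-127, q));   // guard only; values already fit by construction
        g.values.push_back(static_cast<int8_t>(q));
        g.activation_sum += q;
    }
    return g;
}
```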
@@ -974,11 +970,38 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
     // =====================================================================================================================================
     // Main computation loop
     const uint iterations = MAIN_LOOP_ELEMENTS_COUNT / TILE_IFM_ELEMENTS_SIZE;  // TILE_IFM_ELEMENTS_SIZE : (TILE_IFM * SIMD)
-    // Each sub-group loads 2 Batch
-    uint idx_sglid = (sglid * TILE_K) % TILE_IFM_ELEMENTS_SIZE;       // same index for sglid 0~7 : to tile_k direction
-    uint batch_sglid = (sglid * TILE_K) / TILE_IFM_ELEMENTS_SIZE;     // 0 to 1 : to batch direction
-
+    // Each sub-group loads 2 batches
+    const uint idx_sglid = (sglid * TILE_K) % TILE_IFM_ELEMENTS_SIZE;       // same index for sglid 0~7 : to tile_k direction
+    const uint batch_sglid = (sglid * TILE_K) / TILE_IFM_ELEMENTS_SIZE;     // 0 to 1 : to batch direction
     const uint scale_pitch = (TILE_IN_B_PITCH / QUANTIZE_GROUP_SIZE);
+
+    #if PER_TOKEN_SIZE_DYN_QUANTIZE
+        // Each token is quantized only once, so all MAIN_LOOP_ELEMENTS_COUNT iterations share a single quantizing variable
+        uint per_token_offset = input_offset / QUANTIZE_GROUP_SIZE;
+        unroll_for (uint bi = 0; bi < TILE_B; ++bi) {
+            de_quantize_scale[bi] = TO_INPUT0_TYPE(quan_var[per_token_offset * 2]);
+            #if COMPRESSED_WEIGHTS_INT8
+                activation_sum[bi] = TO_INPUT0_TYPE(quan_var[per_token_offset * 2 + 1]);
+            #endif
+            per_token_offset += scale_pitch;
+        }
+    #endif
+
+    #if COMPRESSED_WEIGHTS_INT8
+        ACCUMULATOR_TYPE wei_zp[TILE_OFM] = { };
+        unroll_for(uint fi = 0; fi < TILE_OFM; ++fi) {
+            #if DECOMPRESSION_ZP_TERM
+                #if DECOMPRESSION_ZP_SCALAR
+                    wei_zp[fi] = (TO_ACCUMULATOR_TYPE)(DECOMPRESSION_ZP_VALUE);
+                #elif DECOMPRESSION_ZP_GROUPS_NUM == 1
+                    wei_zp[fi] = TO_ACCUMULATOR_TYPE(d_zps[fi % DECOMPRESSION_ZP_LENGTH]);
+                #endif
+            #else
+                wei_zp[fi] = ACCUMULATOR_VAL_ZERO;
+            #endif
+        }
+    #endif
+
+    MAKE_VECTOR_TYPE(int, TILE_B) acc_tmp[TILE_OFM] = { };
 
     __attribute__((opencl_unroll_hint(1)))
     for (uint ni = 0; ni < iterations; ++ni) {
@@ -993,7 +1016,7 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
             // Next batch
             in_offset += (TILE_IN_B_PITCH * 2);
 
-            #if NUM_LOOP_IN_DYN_QUAN_GROUP == 1
+            #if !PER_TOKEN_SIZE_DYN_QUANTIZE && (NUM_LOOP_IN_DYN_QUAN_GROUP == 1)
                 de_quantize_scale[bi * 2] = quan_var[scale_offset * 2];
                 de_quantize_scale[bi * 2 + 1] = quan_var[scale_offset * 2 + scale_pitch * 2];
                 #if COMPRESSED_WEIGHTS_INT8
@@ -1006,7 +1029,7 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
             #endif
         }
 
-        #if NUM_LOOP_IN_DYN_QUAN_GROUP > 1
+        #if !PER_TOKEN_SIZE_DYN_QUANTIZE && (NUM_LOOP_IN_DYN_QUAN_GROUP > 1)
            if (ni % NUM_LOOP_IN_DYN_QUAN_GROUP == 0) {
                 unroll_for (uint bi = 0; bi < TILE_B; ++bi) {
                     de_quantize_scale[bi] = quan_var[scale_offset * 2];
@@ -1045,10 +1068,6 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
         #endif
 
         uint wei_local_idx = local_id * SIMD * FILTER_LOAD_ITERS * (FILTER_LOAD_BLOCK_SIZE/2) + sglid * 2;
 
-        #if COMPRESSED_WEIGHTS_INT8
-            ACCUMULATOR_TYPE wei_zp[TILE_OFM] = { };
-        #endif
-
         // DQ_DECOMPRESSION_SCALE_POST_OP SHOULD be enabled for dynamic quantize FC : scale is ACCUMULATOR_VAL_ONE
         unroll_for(uint load_iter = 0; load_iter < FILTER_LOAD_ITERS; ++load_iter) {
             #if COMPRESSED_WEIGHTS_INT4
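With PER_TOKEN_SIZE_DYN_QUANTIZE, the hunk above loads de_quantize_scale (and, for int8 weights, activation_sum) once per batch row before the main loop, and pre-computes the weight zero-point per OFM tile, instead of re-reading them every NUM_LOOP_IN_DYN_QUAN_GROUP iterations. The payoff is the usual integer-GEMM structure: accumulate in int32 inside the K loop and apply the scales once afterwards. A simplified C++ sketch of that structure follows; the names are illustrative and not taken from the kernel.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Sketch only: one token's dot product with per-token scale handling.
// 'acc' plays the role of acc_tmp, 'act_scale' of de_quantize_scale.
float dot_per_token(const std::vector<int8_t>& act,   // quantized activations of one token
                    const std::vector<int8_t>& wei,   // quantized weights of one output channel
                    float act_scale,                  // loaded once per token, outside the K loop
                    float wei_scale) {                // weight decompression scale
    int acc = 0;
    for (size_t k = 0; k < act.size(); ++k)
        acc += int(act[k]) * int(wei[k]);             // pure integer MAC loop, no scales inside
    return float(acc) * act_scale * wei_scale;        // scales applied once, after the loop
}
```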
@@ -1110,31 +1129,6 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
                 #endif
             #endif
 
-            #if COMPRESSED_WEIGHTS_INT8
-                unroll_for(uint fi = 0; fi < TILE_OFM; ++fi) {
-                    #if DECOMPRESSION_ZP_TERM
-                        #if DECOMPRESSION_ZP_SCALAR
-                            wei_zp[fi] = (TO_ACCUMULATOR_TYPE)(DECOMPRESSION_ZP_VALUE);
-                        #elif DECOMPRESSION_ZP_GROUPS_NUM > 1
-                            #if FILTER_LOAD_BLOCK_SIZE % DECOMPRESSION_ZP_GROUP_SIZE != 0
-                                #error "FC bf_tiled kernel: Not support DECOMPRESSION_ZP_GROUPS_NUM > 1"
-                            #endif
-
-                            const uint ni_offset = ni * TILE_IFM * SIMD + local_id * FILTER_LOAD_ITERS * FILTER_LOAD_BLOCK_SIZE;
-                            const uint offset_ofm = out_f + fi*SIMD + sglid;
-                            const uint offset_ifm = ni_offset + load_iter * FILTER_LOAD_BLOCK_SIZE;
-                            const uint zp_offset = (offset_ofm % DECOMPRESSION_ZP_BATCH_NUM) * DECOMPRESSION_ZP_BATCH_PITCH +
-                                                   (offset_ifm / DECOMPRESSION_ZP_GROUP_SIZE) * DECOMPRESSION_ZP_FEATURE_PITCH;
-                            wei_zp[fi] = TO_ACCUMULATOR_TYPE(decompression_zp[zp_offset]);
-                        #else
-                            wei_zp[fi] = TO_ACCUMULATOR_TYPE(d_zps[fi % DECOMPRESSION_ZP_LENGTH]);
-                        #endif
-                    #else
-                        wei_zp[fi] = ACCUMULATOR_VAL_ZERO;
-                    #endif
-                }
-            #endif
-
             #if FILTER_LOAD_BLOCK_SIZE == 2
                 SLM_WEIGHT_VEC wei_1 = {dq_wei_unpacked.s01, dq_wei_unpacked.s23};
                 char_slm_weight[wei_local_idx] = as_uint(wei_1);
@@ -1162,6 +1156,21 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
             #else
                 weights_idx += SIMD * FILTER_ACTUAL_LOAD_BLOCK_SIZE;
             #endif
+
+            #if COMPRESSED_WEIGHTS_INT8 && DECOMPRESSION_ZP_TERM && DECOMPRESSION_ZP_GROUPS_NUM > 1 && !DECOMPRESSION_ZP_SCALAR
+                unroll_for(uint fi = 0; fi < TILE_OFM; ++fi) {
+                    #if FILTER_LOAD_BLOCK_SIZE % DECOMPRESSION_ZP_GROUP_SIZE != 0
+                        #error "FC bf_tiled kernel: Not support DECOMPRESSION_ZP_GROUPS_NUM > 1"
+                    #endif
+
+                    const uint ni_offset = ni * TILE_IFM * SIMD + local_id * FILTER_LOAD_ITERS * FILTER_LOAD_BLOCK_SIZE;
+                    const uint offset_ofm = out_f + fi*SIMD + sglid;
+                    const uint offset_ifm = ni_offset + load_iter * FILTER_LOAD_BLOCK_SIZE;
+                    const uint zp_offset = (offset_ofm % DECOMPRESSION_ZP_BATCH_NUM) * DECOMPRESSION_ZP_BATCH_PITCH +
+                                           (offset_ifm / DECOMPRESSION_ZP_GROUP_SIZE) * DECOMPRESSION_ZP_FEATURE_PITCH;
+                    wei_zp[fi] = TO_ACCUMULATOR_TYPE(decompression_zp[zp_offset]);
+                }
+            #endif
         }
 
         wei_local_idx = sglid * 2;
@@ -1199,7 +1208,7 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
                     #endif
 
                     #if COMPRESSED_WEIGHTS_INT8
-                        ACCUM_DQ_TYPE modified_calc_buff = ((int *)(&acc_tmp[fi]))[bi] - ((float)(wei_zp[fi]) * (convert_float)(activation_sum[bi]));
+                        ACCUM_DQ_TYPE modified_calc_buff = ((int *)(&acc_tmp[fi]))[bi] - ((float)(wei_zp[fi]) * activation_sum[bi]);
                         ((ACCUMULATOR_TYPE*)(&acc[bi]))[fi] += (convert_half)(convert_float(modified_calc_buff) * (float)ds * (float)de_quantize_scale[bi]);
                     #else
                         ((ACCUMULATOR_TYPE*)(&acc[bi]))[fi] += convert_half(((int *)(&acc_tmp[fi]))[bi]) * ds * de_quantize_scale[bi];
@@ -1210,7 +1219,7 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
            #endif
        }  // Whole tile_k elements of each iteration : ki

-        #if DQ_DECOMPRESSION_SCALE_POST_OP && (TILE_IFM_ELEMENTS_SIZE <= DECOMPRESSION_SCALE_GROUP_SIZE)
+        #if !PER_TOKEN_SIZE_DYN_QUANTIZE && DQ_DECOMPRESSION_SCALE_POST_OP && (TILE_IFM_ELEMENTS_SIZE <= DECOMPRESSION_SCALE_GROUP_SIZE)
            // Dynamic-quantizing group size set to same or smaller than scale group size
            if ((ni % NUM_LOOP_IN_DYN_QUAN_GROUP) == (NUM_LOOP_IN_DYN_QUAN_GROUP - 1)) {
                const uint ni_offset = ((ni*TILE_IFM*SIMD) / DECOMPRESSION_SCALE_GROUP_SIZE)*DECOMPRESSION_SCALE_FEATURE_PITCH;
@@ -1226,7 +1235,7 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
                    #endif

                    #if COMPRESSED_WEIGHTS_INT8
-                        ACCUM_DQ_TYPE modified_calc_buff = ((int *)(&acc_tmp[fi]))[bi] - ((float)(wei_zp[fi]) * (convert_float)(activation_sum[bi]));
+                        ACCUM_DQ_TYPE modified_calc_buff = ((float)((int *)(&acc_tmp[fi]))[bi]) - ((float)(wei_zp[fi]) * activation_sum[bi]);
                        ((ACCUMULATOR_TYPE*)(&acc[bi]))[fi] += (convert_half)(convert_float(modified_calc_buff) * (float)ds * (float)de_quantize_scale[bi]);
                    #else
                        ((ACCUMULATOR_TYPE*)(&acc[bi]))[fi] += convert_half(((int *)(&acc_tmp[fi]))[bi]) * ds * de_quantize_scale[bi];
@@ -1238,6 +1247,20 @@ inline void FUNC(fc_bf_tiled_kernel_dyn_quan)(
        #endif
    }  // Main compute loop : ni

+    #if PER_TOKEN_SIZE_DYN_QUANTIZE
+        unroll_for (uint bi = 0; bi < TILE_B; ++bi) {
+            unroll_for(uint fi = 0; fi < TILE_OFM; ++fi) {
+                ACCUMULATOR_TYPE ds = d_scales[fi % DECOMPRESSION_SCALE_LENGTH];
+                #if COMPRESSED_WEIGHTS_INT8
+                    float modified_calc_buff = ((float)((int *)(&acc_tmp[fi]))[bi]) - ((float)(wei_zp[fi]) * activation_sum[bi]);
+                    ((ACCUMULATOR_TYPE*)(&acc[bi]))[fi] = (convert_half)(modified_calc_buff) * ds * de_quantize_scale[bi];
+                #else
+                    ((ACCUMULATOR_TYPE*)(&acc[bi]))[fi] = convert_half(((int *)(&acc_tmp[fi]))[bi]) * ds * de_quantize_scale[bi];
+                #endif
+            }
+        }
+    #endif
+
    // =====================================================================================================================================
    // Post-processing: bias, activation, fused-ops
    for (uint bi = 0; bi < TILE_B; ++bi) {
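The activation_sum bookkeeping used by the int8 path above relies on factoring the weight zero-point out of the integer dot product: sum_k (w_k - zp) * a_k = sum_k w_k * a_k - zp * sum_k a_k, so the kernel only needs the raw int32 accumulator plus one activation sum per quantization group. A small self-contained C++ check of that identity (illustrative, not kernel code):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Verifies: sum((w - zp) * a) == sum(w * a) - zp * sum(a)
int main() {
    std::vector<uint8_t> w = {12, 200, 37, 255};   // asymmetric uint8 weights
    std::vector<int8_t>  a = {-5, 17, 88, -127};   // dyn-quantized activations
    int zp = 128;                                  // weight zero point

    int direct = 0, raw = 0, act_sum = 0;
    for (size_t k = 0; k < w.size(); ++k) {
        direct  += (int(w[k]) - zp) * int(a[k]);   // reference: subtract zp per element
        raw     += int(w[k]) * int(a[k]);          // what acc_tmp accumulates
        act_sum += int(a[k]);                      // what quan_var's activation_sum holds
    }
    assert(direct == raw - zp * act_sum);          // compensation applied once per group
    return 0;
}
```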
diff --git a/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp b/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp
index 0774c62add1643..51dfac1b2b4fee 100644
--- a/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp
+++ b/src/plugins/intel_gpu/src/kernel_selector/kernels/fully_connected/fully_connected_kernel_bf_tiled.cpp
@@ -11,7 +11,8 @@
 static constexpr size_t lws_batches = 8;
 static constexpr size_t simd = 16;
-static constexpr size_t min_quantize_grp_size = 32;
+static constexpr size_t input_load_size = 4;
+static constexpr size_t min_quantize_grp_size = (simd * 2); // SIMD * (min value of tile_ifm)
 static constexpr size_t min_slm_size = 256;
 static std::vector available_quantize_grp_size = {128, 64, 32};
 
@@ -52,17 +53,38 @@ static std::pair get_output_aligned_bf_size(const fully_connecte
     return {output_b, output_f};
 }
 
+static size_t get_scale_group_size(const fully_connected_params& params) {
+    return params.weights.IFM().v / params.decompression_scale.Feature().v;
+}
+
+static bool is_8bit_asym_wei(const fully_connected_params& params) {
+    auto weight_type = params.weights.GetDType();
+    // UINT8 weight type is supported by FC dyn-quantize(with SLM).
+    if (weight_type == WeightsType::UINT8 && params.has_decompression_zp)
+        return true;
+
+    return false;
+}
+
 static bool is_weight_dyn_quantizable(const fully_connected_params& params) {
     auto weight_type = params.weights.GetDType();
     if (weight_type == WeightsType::INT4 || weight_type == WeightsType::UINT4)
         return true;
-    // UINT8 weight type is supported by FC dyn-quantize(with SLM).
-    if (weight_type == WeightsType::UINT8)
+    // Symmetric 8-bit weights have no validated case yet; only asymmetric (zero-point) 8-bit weights are enabled
+    if (is_8bit_asym_wei(params))
         return true;
 
     return false;
 }
 
+static bool is_per_token_dynamic_quantize(const fully_connected_params& params) {
+    auto dynamic_quantization_group_size = params.dynamic_quantization_group_size;
+    if (dynamic_quantization_group_size == UINT64_MAX)
+        return true;
+
+    return false;
+}
+
 // DYNAMIC_QUANTIZE
 static size_t get_dynamic_quantize_group_size(const fully_connected_params& params) {
     auto dynamic_quantization_group_size = params.dynamic_quantization_group_size;
@@ -87,9 +109,35 @@ static size_t get_dynamic_quantize_group_size(const fully_connected_params& para
         }
     }
 
-    const size_t scale_group_size = params.weights.IFM().v / params.decompression_scale.Feature().v;
+    size_t scale_group_size = get_scale_group_size(params);
+    size_t zp_group_num = params.decompression_zero_point.Feature().v;
+    size_t zp_group_size = 0;
+    if (params.has_decompression_zp)
+        zp_group_size = params.weights.IFM().v / params.decompression_zero_point.Feature().v;
+
+    // Per-token dyn-quan
+    if (dynamic_quantization_group_size >= min_quantize_grp_size && is_per_token_dynamic_quantize(params)) {
+        // Validate that the dyn-quan group fits the weight-scale and weight-zp group sizes
+        if ((scale_group_size % min_quantize_grp_size) == 0 && scale_group_size > min_quantize_grp_size) {
+            dynamic_quantization_group_size = scale_group_size;
+
+            // For an asymmetric int8 model, activation_sum must align with the weight zero-point group
+            if (is_8bit_asym_wei(params) && params.has_decompression_zp == true &&
+                dynamic_quantization_group_size > zp_group_size && (zp_group_size % input_load_size) == 0) {
+                dynamic_quantization_group_size = zp_group_size;
+            }
+
+            GPU_DEBUG_LOG << "FC dyn-quantize by per-token. Actual dyn_quan_group_size(" << dynamic_quantization_group_size
+                          << ") : From scale_group_size(" << scale_group_size << "), zp_group_size(" << zp_group_size
+                          << "), zp_group_num(" << zp_group_num << "), ifm_size(" << get_input_bf_size(params).second << ")" << std::endl;
+            return (size_t)dynamic_quantization_group_size;
+        }
+    }
+
+    // Grouped-size dyn-quan : use one of the aligned sizes in 'available_quantize_grp_size'
     for (auto group_size : available_quantize_grp_size) {
-        if (dynamic_quantization_group_size >= group_size) {
+        if (dynamic_quantization_group_size >= group_size &&
+            (scale_group_size % group_size) == 0) {
             dynamic_quantization_group_size = group_size;
 
             if (dynamic_quantization_group_size > scale_group_size) {
@@ -97,6 +145,7 @@ static size_t get_dynamic_quantize_group_size(const fully_connected_params& para
                     << dynamic_quantization_group_size << ". Reduce FC dyn-quan group size to scale size."
                     << std::endl;
                 dynamic_quantization_group_size = scale_group_size;
             }
+            return (size_t)dynamic_quantization_group_size;
         }
     }
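The group-size policy above reads as follows: when the requested size is the per-token sentinel (UINT64_MAX) and the weight-scale group is a multiple of min_quantize_grp_size, quantize whole scale groups (optionally clamped to the zero-point group for asymmetric int8); otherwise fall back to the largest entry of available_quantize_grp_size that divides the scale group. A condensed C++ sketch of that policy is shown below; it is simplified (the zero-point clamp and debug logging are omitted) and is not the production implementation.

```cpp
#include <array>
#include <cstddef>
#include <cstdint>

// Constants mirror the .cpp file; the function itself is an illustration.
constexpr size_t simd = 16;
constexpr size_t min_quantize_grp_size = simd * 2;   // SIMD * min(TILE_IFM)
constexpr std::array<size_t, 3> available_quantize_grp_size = {128, 64, 32};

size_t pick_group_size(size_t requested, size_t scale_group_size) {
    const bool per_token = (requested == UINT64_MAX);   // "unlimited" request means per-token
    if (per_token && requested >= min_quantize_grp_size &&
        scale_group_size > min_quantize_grp_size &&
        scale_group_size % min_quantize_grp_size == 0) {
        return scale_group_size;                         // quantize one whole scale group per token
    }
    for (size_t candidate : available_quantize_grp_size) {
        if (requested >= candidate && scale_group_size % candidate == 0)
            return candidate > scale_group_size ? scale_group_size : candidate;
    }
    return 0;                                            // dyn-quan disabled
}
```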
@@ -104,7 +153,7 @@ static size_t get_dynamic_quantize_group_size(const fully_connected_params& para
     return 0;
 }
 
-static bool should_dynamic_quantize(const fully_connected_params& params, bool print_log = false) {
+static bool should_dynamic_quantize(const fully_connected_params& params) {
     size_t dynamic_quantization_group_size = get_dynamic_quantize_group_size(params);
 
     if (params.inputs[0].GetFirstElementOffset() != 0)
@@ -116,24 +165,25 @@ static bool should_dynamic_quantize(const fully_connected_params& params, bool p
         return false;
     }
 
+    const size_t scale_group_size = get_scale_group_size(params);
+    if ((scale_group_size % min_quantize_grp_size) != 0)
+        return false;
+
     auto threads = get_input_bf_size(params);
     auto input_b = threads.first;
     auto input_f = threads.second;
 
-    const size_t scale_group_size = params.weights.IFM().v / params.decompression_scale.Feature().v;
     if ((scale_group_size % simd == 0) && (input_f % dynamic_quantization_group_size == 0) &&
         (params.is_shape_agnostic || (params.inputs[0].Batch().v > 1 && input_b > min_slm_size)) &&
         params.inputs[0].GetDType() == Datatype::F16 && is_weight_dyn_quantizable(params)) {
-        if (print_log) {
-            GPU_DEBUG_TRACE_DETAIL << " Dynamic quantizing for FC : scale_group_size: " << scale_group_size <<
-                ", Dyn-quan group size: " << dynamic_quantization_group_size <<
-                ", Type(I:" << kernel_selector::toString(params.inputs[0].GetDType()) <<
-                ", O:" << kernel_selector::toString(params.outputs[0].GetDType()) <<
-                ", W:" << kernel_selector::toString(params.weights.GetDType()) <<
-                "), Format(W:" << kernel_selector::toString(params.weights.GetLayout()) <<
-                ") B: " << params.inputs[0].Batch().v << ", F: " << params.inputs[0].Feature().v <<
-                ", Y: " << params.inputs[0].Y().v << std ::endl;
-        }
+        GPU_DEBUG_TRACE_DETAIL << " Dynamic quantizing for FC : scale_group_size: " << scale_group_size <<
+            ", Dyn-quan group size: " << dynamic_quantization_group_size <<
+            ", Type(I:" << kernel_selector::toString(params.inputs[0].GetDType()) <<
+            ", O:" << kernel_selector::toString(params.outputs[0].GetDType()) <<
+            ", W:" << kernel_selector::toString(params.weights.GetDType()) <<
+            "), Format(W:" << kernel_selector::toString(params.weights.GetLayout()) <<
+            ") B: " << params.inputs[0].Batch().v << ", F: " << params.inputs[0].Feature().v <<
+            ", Y: " << params.inputs[0].Y().v << std::endl;
         return true;
     }
 
@@ -621,7 +671,7 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
     if (weights_dt == WeightsType::UINT4 || weights_dt == WeightsType::INT4) {
         tile_k_ofm_packed /= 2;
         jit.Merge(make_int4_packed_type_jit_constant("INT4_PACKED_TYPE", weights_dt, tile_k_ofm));
-        const size_t scale_group_size = params.weights.IFM().v / params.decompression_scale.Feature().v;
+        const size_t scale_group_size = get_scale_group_size(params);
         // Do not use SCALE_POST_OP for SLM kernel, since it demonstrates worse performance
         if (scale_group_size % simd == 0 && !dispatchData.use_slm)
             add_decompress_scale_post_op = true;
@@ -703,16 +753,21 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
             jit.AddConstant(MakeJitConstant("DYNAMIC_QUANTIZE", 1));
             jit.AddConstant(MakeJitConstant("DQ_DECOMPRESSION_SCALE_POST_OP", 1));
             jit.AddConstant(MakeJitConstant("QUANTIZE_GROUP_SIZE", quantize_grp_size));
+
+            jit.AddConstant(MakeJitConstant("PER_TOKEN_SIZE_DYN_QUANTIZE",
+                                            is_per_token_dynamic_quantize(params) && quantize_grp_size == get_input_bf_size(params).second));
         } else {
             if (add_decompress_scale_post_op)
                 jit.AddConstant(MakeJitConstant("DECOMPRESSION_SCALE_POST_OP", 1));
             jit.AddConstant(MakeJitConstant("DYNAMIC_QUANTIZE", 0));
             jit.AddConstant(MakeJitConstant("QUANTIZE_GROUP_SIZE", min_quantize_grp_size));
         }
-        jit.AddConstant(MakeJitConstant("DQ_TYPE", "char"));
 
+        jit.AddConstant(MakeJitConstant("INPUT_LOAD_SIZE", input_load_size));
+        jit.AddConstant(MakeJitConstant("DQ_TYPE", "char"));
         jit.AddConstant(MakeJitConstant("IFM_SIZE", get_input_bf_size(params).second));
         jit.AddConstant(MakeJitConstant("SIMD", simd));
+
         jit.AddConstant(MakeJitConstant("TILE_B", dispatchData.tile_m));
         jit.AddConstant(MakeJitConstant("HALF_TILE_B", dispatchData.tile_m/2));
         jit.AddConstant(MakeJitConstant("TILE_OFM", dispatchData.tile_n));
@@ -840,19 +895,21 @@ void FullyConnected_bf_tiled::GetUpdateDispatchDataFunc(KernelData& kd) const {
                 kd.kernels[0].skip_execution = false;
                 size_t input_f = get_input_bf_size(prim_params).second;
                 size_t input_size = input_f * dispatchData.tile_m * dispatchData.gws[2];
+                OPENVINO_ASSERT(quantize_grp_size != 0, "Error: quantize_grp_size is zero.");
+                // half type of de_quan_scale and activation sum for each quantized group
+                size_t quan_var_size = (input_size / quantize_grp_size) * 2 * 2;
 
-                if (kd.internalBufferSizes[0] < input_size) {
+                if (kd.internalBufferSizes[0] < input_size ||
+                    kd.internalBufferSizes[1] < quan_var_size) {
                     kd.internalBufferSizes.clear();
                     // quantized input is char type
                     kd.internalBufferSizes.push_back(input_size);
-                    // half type of de_quan_scale and activation sum for each quantized group
-                    OPENVINO_ASSERT(quantize_grp_size != 0, "Error: quantize_grp_size is zero.");
-                    kd.internalBufferSizes.push_back((input_size / quantize_grp_size) * 2 * 2);
+                    // half type of de_quan_scale and activation sum for each quantized group
+                    kd.internalBufferSizes.push_back(quan_var_size);
                 }
 
-                OPENVINO_ASSERT(quantize_grp_size != 0, "Error: quantize_grp_size is zero.");
-                kd.kernels[0].params.workGroups.global = {std::max((input_size / quantize_grp_size), (size_t)1), 1, 1};
-                kd.kernels[0].params.workGroups.local = {16, 1, 1};
+                kd.kernels[0].params.workGroups.global = {(std::max((input_size / quantize_grp_size), (size_t)1)), 1, 1};
+                kd.kernels[0].params.workGroups.local = {1, 1, 1};
             }
         }
     };
@@ -1027,8 +1084,10 @@ KernelsData FullyConnected_bf_tiled::GetMultiKernelsData(const Params &params,
             auto input_size = std::max(fc_params.inputs[0].PhysicalSize(), get_input_bf_size(fc_params).second);
             if (!params.is_shape_agnostic)
                 input_size = std::max(input_size, Align(get_input_bf_size(fc_params).first, lws_batches) * get_input_bf_size(fc_params).second);
-            dyn_quan_dispatch.gws = {input_size / quantize_grp_size, 1, 1};
-            dyn_quan_dispatch.lws = {16, 1, 1};
+
+            dyn_quan_dispatch.gws = {(input_size / quantize_grp_size), 1, 1};
+            dyn_quan_dispatch.lws = {1, 1, 1};
+
             quan_kernel.params.workGroups.global = dyn_quan_dispatch.gws;
             quan_kernel.params.workGroups.local = dyn_quan_dispatch.lws;
             quan_kernel.skip_execution = false;
@@ -1059,7 +1118,7 @@ KernelsData FullyConnected_bf_tiled::GetMultiKernelsData(const Params &params,
             // char type quantized input
             kd.internalBufferSizes.push_back(input_size);
             // half type of de_quan_scale and activation sum for each quantized group
-            kd.internalBufferSizes.push_back(input_size / quantize_grp_size * 2 * 2);
+            kd.internalBufferSizes.push_back((input_size / quantize_grp_size) * 2 * 2);
             kernel_number++;
         }
         kd.internalBufferDataType = Datatype::F16;
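The internal-buffer sizing above follows from the quan_var layout written by quantize_input: per quantization group the kernel stores two fp16 values (the de-quantization scale and the activation sum), hence (input_size / quantize_grp_size) * 2 * 2 bytes, while the quantized activations themselves take one byte per element. An illustrative C++ helper mirroring this arithmetic (not part of the kernel selector):

```cpp
#include <cstddef>

// Mirrors the "* 2 * 2" above: one int8 byte per quantized activation, and per
// group two fp16 values (de-quantization scale and activation sum) in quan_var.
struct DynQuanBuffers {
    size_t quantized_input_bytes;  // char-typed quantized activations
    size_t quan_var_bytes;         // 2 values * sizeof(half) per group
};

DynQuanBuffers dyn_quan_buffer_sizes(size_t input_elements, size_t quantize_grp_size) {
    DynQuanBuffers b;
    b.quantized_input_bytes = input_elements;                        // 1 byte per element
    b.quan_var_bytes = (input_elements / quantize_grp_size) * 2 * 2; // groups * 2 halves * 2 bytes
    return b;
}
// Example: 4096 elements with a 128-element group -> 32 groups,
// quan_var_bytes = 32 * 2 * 2 = 128 bytes.
```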
diff --git a/src/plugins/intel_gpu/tests/unit/test_cases/fully_connected_gpu_test.cpp b/src/plugins/intel_gpu/tests/unit/test_cases/fully_connected_gpu_test.cpp
index 5bc7e403d3bf74..8a90d137ed7e5a 100644
--- a/src/plugins/intel_gpu/tests/unit/test_cases/fully_connected_gpu_test.cpp
+++ b/src/plugins/intel_gpu/tests/unit/test_cases/fully_connected_gpu_test.cpp
@@ -2846,7 +2846,7 @@ class fully_connected_gpu_tests: public ::testing::Test {
     }
 
     void test_compressed_int4_scale_dyn_quan_weight_i4(bool is_dynamic, int batch = 1, int ifm = 512, int ofm = 2048,
-                                                       int quantize_group_size = 32, int scales_group_size = 128,
+                                                       size_t quantize_group_size = 32, int scales_group_size = 128,
                                                        bool is_wzp_test = false, bool is_wzp_scalar = false) {
         tests::random_generator rg(GET_SUITE_NAME);
         auto& engine = get_test_engine();
@@ -2972,7 +2972,7 @@ class fully_connected_gpu_tests: public ::testing::Test {
     }
 
     void test_compressed_int8_scale_dyn_quan_weight_u8(bool is_dynamic, int batch = 1, int ifm = 512, int ofm = 2048,
-                                                       int quantize_group_size = 32, int scales_group_size = 128,
+                                                       size_t quantize_group_size = 32, int scales_group_size = 128,
                                                        bool is_wzp_test = false, bool is_wzp_scalar = false) {
         tests::random_generator rg(GET_SUITE_NAME);
         auto& engine = get_test_engine();
@@ -3065,8 +3065,9 @@ class fully_connected_gpu_tests: public ::testing::Test {
             auto inst = network->get_primitive("fc_prim");
             auto impl = inst->get_impl();
             ASSERT_TRUE(impl != NULL);
-            auto kernel_num = (is_dynamic) ? 3 : 2;
-            kernel_num = (quantize_group_size < 32) ? 2 : kernel_num;
+            // For UINT8 weights, the SLM kernel (without dyn-quan) is not selected
+            auto kernel_num = (is_dynamic) ? 3 : 1;
+            kernel_num = (quantize_group_size < 32) ? 1 : kernel_num;
             ASSERT_EQ(impl->get_kernels().size(), size_t(kernel_num));
         }
 
@@ -4194,6 +4195,7 @@ TEST_F(fully_connected_gpu_tests, compressed_int4_scale_dynamic_quantize_wzp_sta
     this->test_compressed_int4_scale_dyn_quan_weight_i4(false, 320, 1024, 1024, 32, 32, true);
 }
 
+// Test weight zero-point for asymmetric INT8 weights
 TEST_F(fully_connected_gpu_tests, compressed_int8_scale_dynamic_quantize_wzp_128_large) {
     this->test_compressed_int8_scale_dyn_quan_weight_u8(true, 320, 4096, 4096, 128, 128, true);
 }
@@ -4210,10 +4212,6 @@ TEST_F(fully_connected_gpu_tests, compressed_int8_scale_dynamic_quantize_wzp_32_
     this->test_compressed_int8_scale_dyn_quan_weight_u8(true, 320, 4096, 4096, 32, 32, true);
 }
 
-TEST_F(fully_connected_gpu_tests, compressed_int8_scale_dynamic_quantize_wzp_32_large_unaligned) {
-    this->test_compressed_int8_scale_dyn_quan_weight_u8(true, 310, 1024, 1024, 32, 32, true);
-}
-
 TEST_F(fully_connected_gpu_tests, compressed_int8_scale_dynamic_quantize_wzp_128_small) {
     this->test_compressed_int8_scale_dyn_quan_weight_u8(true, 16, 1024, 1024, 128, 128, true);
 }
@@ -4222,6 +4220,23 @@
     this->test_compressed_int8_scale_dyn_quan_weight_u8(true, 1, 1024, 1024, 128, 128, true);
 }
 
+// Test per-token dyn-quan
+TEST_F(fully_connected_gpu_tests, compressed_int4_scale_dynamic_quantize_test_fake_per_token) {
+    this->test_compressed_int4_scale_dyn_quan_weight_i4(true, 600, 1024, 2048, -1, 32, true);
+}
+
+TEST_F(fully_connected_gpu_tests, compressed_int4_scale_dynamic_quantize_test_per_token) {
+    this->test_compressed_int4_scale_dyn_quan_weight_i4(true, 600, 1024, 2048, -1, 1024, true);
+}
+
+TEST_F(fully_connected_gpu_tests, compressed_int8_scale_dynamic_quantize_test_per_token_small_scale) {
+    this->test_compressed_int8_scale_dyn_quan_weight_u8(true, 600, 1024, 2048, -1, 32, true);
+}
+
+TEST_F(fully_connected_gpu_tests, compressed_int8_scale_dynamic_quantize_test_per_token_full_scale) {
+    this->test_compressed_int8_scale_dyn_quan_weight_u8(true, 600, 1024, 2048, -1, 1024, true);
+}
+
 TEST_F(fully_connected_gpu_tests, compressed_scale_bias) {
     this->test_compressed_scale_bias(false);
 }
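The new per-token tests pass -1 as quantize_group_size; with the helper parameter now declared as size_t, that value wraps to SIZE_MAX, which on a 64-bit build equals the UINT64_MAX sentinel checked by is_per_token_dynamic_quantize(). A short C++ illustration of that assumption (the 64-bit build is an assumption, not something the tests assert):

```cpp
#include <cassert>
#include <cstdint>

int main() {
    // -1 passed where a size_t group size is expected wraps to SIZE_MAX;
    // on a 64-bit build that equals UINT64_MAX, the per-token sentinel.
    size_t group_size = static_cast<size_t>(-1);
    assert(group_size == SIZE_MAX);
    return 0;
}
```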