Skip to content

Commit

Permalink
[GPU] Allow default scales order for dynamic_quantize_gpu_opt kernel
Browse files Browse the repository at this point in the history
  • Loading branch information
sshlyapn committed Oct 30, 2024
1 parent 8da8a30 commit 21d8c71
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -160,8 +160,13 @@ bool DynamicQuantizeKernelOpt::Validate(const Params& params) const {
if (dq_params.group_sizes.back() != UINT64_MAX)
return false;

if (!dq_params.scales_output_order.empty())
return false;
// Allow only default scales order
const auto& scales_output_order = dq_params.scales_output_order;
if (!scales_output_order.empty()) {
for (size_t i = 0; i < scales_output_order.size(); i++)
if (scales_output_order[i] != i)
return false;
}

return true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,9 @@ JitConstants DynamicQuantizeKernelRef::GetJitConstants(const dynamic_quantize_pa
jit.AddConstant(MakeJitConstant("ASYMMETRIC_QUANTIZATION", params.use_asymmetric_quantization));
jit.AddConstant(MakeJitConstant("GROUP_SCALES_WITH_ZP", params.combine_scales_and_zp));

const auto& group_sizes = params.group_sizes;
auto group_sizes = params.group_sizes;
group_sizes.resize(std::min(4LU, group_sizes.size()), 1);

for (size_t i = 0; i < group_sizes.size(); i++) {
jit.AddConstant(MakeJitConstant("GROUP_SIZE_DIM" + std::to_string(i), group_sizes[i]));
}
Expand All @@ -68,7 +70,8 @@ CommonDispatchData DynamicQuantizeKernelRef::SetDefault(const dynamic_quantize_p

OPENVINO_ASSERT(params.outputs[0].GetLayout() == DataLayout::bfyx, "It supports only 4d tensor");

const auto& group_sizes = params.group_sizes;
auto group_sizes = params.group_sizes;
group_sizes.resize(std::min(4LU, group_sizes.size()), 1);
auto batch_size = group_sizes[0] == 1 ? params.outputs[0].Batch().v : 1;
auto feature_size = group_sizes[1] == 1 ? params.outputs[0].Feature().v : 1;
auto y_size = group_sizes[2] == 1 ? params.outputs[0].Y().v : 1;
Expand Down Expand Up @@ -134,10 +137,6 @@ bool DynamicQuantizeKernelRef::Validate(const Params& params) const {
if (!KernelBaseOpenCL::Validate(params))
return false;

const auto& prim_params = static_cast<const dynamic_quantize_params&>(params);
if (prim_params.group_sizes.size() != 4)
return false;

return true;
}
} // namespace kernel_selector

0 comments on commit 21d8c71

Please sign in to comment.