Update kernel, qwen-7b needs USE_REF_DQ=1
sshlyapn committed Oct 10, 2024
1 parent e38d2aa commit 390502c
Showing 2 changed files with 31 additions and 17 deletions.
File 1 of 2: OpenCL kernel (dynamic_quantize_gpu_opt_generic)

@@ -35,33 +35,33 @@ inline uint FUNC(get_scales_offset)(OPTIONAL_SHAPE_INFO_ARG uint b, uint f, uint
 }
 
 #define SUBGROUP_SIZE 16
-#define HEAD_SIZE 128
-#define NUM_HEADS 32
+#define INNERMOST_DIM_VALUE INPUT0_SIZE_X
 #define INPUT_BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, 1, ptr, offset)
 #define OUTPUT_BLOCK_WRITE(ptr, offset, val) BLOCK_WRITEN(OUTPUT_TYPE, 1, ptr, offset, val)
 
-__attribute__((reqd_work_group_size(1, NUM_HEADS * SUBGROUP_SIZE, 1)))
+__attribute__((reqd_work_group_size(SUBGROUP_SIZE, SUBGROUPS_NUMBER, 1)))
 REQD_SUB_GROUP_SIZE(SUBGROUP_SIZE)
-KERNEL(dynamic_quantize_gpu_opt)(
+KERNEL(dynamic_quantize_gpu_opt_generic)(
     OPTIONAL_SHAPE_INFO_ARG
     const __global INPUT0_TYPE* input,
     __global OUTPUT_TYPE* output,
     __global OUTPUT1_TYPE* output_scale)
 {
-    const uint batch_indexes = get_global_id(0);
-    const uint head_idx = get_global_id(1) / SUBGROUP_SIZE;
     const uint sglid = get_sub_group_local_id();
-    // const uint data_indexes = get_global_id(1);
+    const uint grouped_indexes = get_global_id(1);
+    const uint batch_indexes = get_global_id(2);
 
     DECLARE_BATCHED_DIMS_INDEXES(batch_indexes);
-    const uint f = head_idx;
+    DECLARE_GROUPED_DIMS_INDEXES(grouped_indexes);
 
+    // the innermost dimension is always handled in the loop inside the kernel
+    const uint x = 0;
 
     half max_value = 0.0001h;
-    half val[HEAD_SIZE / SUBGROUP_SIZE];
+    half val[INNERMOST_DIM_VALUE / SUBGROUP_SIZE];
 
     const uint data_offset = INPUT0_GET_INDEX(b, f, y, x);
-    unroll_for (uint i = 0; i < HEAD_SIZE / SUBGROUP_SIZE; i++) {
+    unroll_for (uint i = 0; i < INNERMOST_DIM_VALUE / SUBGROUP_SIZE; i++) {
         // val[i] = input[data_offset + i * SUBGROUP_SIZE + sglid];
         val[i] = INPUT_BLOCK_READ(input, data_offset + i * SUBGROUP_SIZE);
         max_value = fmax(max_value, fabs(val[i]));

@@ -71,7 +71,7 @@ KERNEL(dynamic_quantize_gpu_opt)(
 
     half scale = 127.0h / max_value;
 
-    unroll_for (uint i = 0; i < HEAD_SIZE / SUBGROUP_SIZE; i++) {
+    unroll_for (uint i = 0; i < INNERMOST_DIM_VALUE / SUBGROUP_SIZE; i++) {
         OUTPUT_BLOCK_WRITE(output, data_offset + i * SUBGROUP_SIZE, convert_char(val[i] * scale));
         // output[data_offset + i * SUBGROUP_SIZE + sglid] = convert_char(val[i] * scale);
     }

@@ -82,6 +82,6 @@ KERNEL(dynamic_quantize_gpu_opt)(
     const uint scale_idx = OUTPUT1_GET_INDEX_SAFE(b, f, y, x);
 #endif
 
-    if (get_global_id(1) == 0 && get_global_id(2) == 0)
+    if (grouped_indexes == 0 && sglid == 0)
         output_scale[scale_idx] = 1.0h / scale;
 }
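For context, the kernel above implements per-group symmetric int8 quantization: each subgroup scans its slice of the innermost dimension for the maximum absolute value, derives scale = 127 / max, writes the scaled values as chars, and stores the reciprocal scale for later dequantization. Below is a minimal scalar C++ sketch of the same scheme; the function and variable names are illustrative, not part of the commit.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <vector>

// Scalar reference for the per-group quantization done by the kernel above:
// one scale per group, derived from the maximum absolute value.
void quantize_group(const std::vector<float>& group,
                    std::vector<int8_t>& quantized,
                    float& inv_scale) {
    float max_value = 0.0001f;  // same epsilon the kernel starts from, avoids div-by-zero
    for (float v : group)
        max_value = std::max(max_value, std::fabs(v));

    const float scale = 127.0f / max_value;  // map [-max, max] onto [-127, 127]
    quantized.resize(group.size());
    for (std::size_t i = 0; i < group.size(); ++i)
        quantized[i] = static_cast<int8_t>(group[i] * scale);  // truncates toward zero, like convert_char

    inv_scale = 1.0f / scale;  // stored per group so consumers can dequantize: v ~= q * inv_scale
}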
File 2 of 2: kernel selector (DynamicQuantizeKernelOptGeneric)
@@ -7,7 +7,7 @@
 #include <string>
 
 
-static constexpr size_t simd = 16;
+static constexpr size_t subgroup_size = 16;
 
 namespace kernel_selector {
 static Tensor::NDims get_normalized_dims(const DataTensor& tensor) {
@@ -145,12 +145,18 @@ JitConstants DynamicQuantizeKernelOptGeneric::GetJitConstants(const dynamic_quan
         else
             grouped_dims.push_back(default_dims[i]);
     }
 
+    const auto& input_dims = get_normalized_dims(params.inputs[0]);
     const auto total_grouped_elements = get_elements_number_per_group(params);
     const auto per_iter_elements_number = get_per_iter_elements_number(params);
+    const auto total_subgroups_number = total_grouped_elements / input_dims.back().v;
+
+    // drop the last dimension, since it will be processed inside the kernel
+    grouped_dims.pop_back();
 
     jit.AddConstant(MakeJitConstant("DECLARE_BATCHED_DIMS_INDEXES(data_idx)", generate_dims_indexes_calculation(batch_dims)));
     jit.AddConstant(MakeJitConstant("DECLARE_GROUPED_DIMS_INDEXES(data_idx)", generate_dims_indexes_calculation(grouped_dims)));
-    jit.AddConstant(MakeJitConstant("LWS_SIZE", per_iter_elements_number));
+    jit.AddConstant(MakeJitConstant("SUBGROUPS_NUMBER", total_subgroups_number));
 
     const auto iterations_number = total_grouped_elements / per_iter_elements_number;
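To make the new SUBGROUPS_NUMBER constant concrete, assume a normalized input whose quantization group spans 32 rows of width 128, so total_grouped_elements is 4096 and the innermost dimension is 128. The shape and numbers in this sketch are hypothetical; only the formulas mirror the code above.

#include <cassert>
#include <cstddef>
#include <vector>

int main() {
    // Hypothetical normalized shape; the group spans the two innermost dims.
    const std::vector<std::size_t> input_dims = {1, 1, 32, 128};
    const std::size_t total_grouped_elements = 32 * 128;  // elements in one quantization group
    const std::size_t innermost = input_dims.back();      // handled by the in-kernel loop

    // Mirrors: total_subgroups_number = total_grouped_elements / input_dims.back().v
    const std::size_t total_subgroups_number = total_grouped_elements / innermost;
    assert(total_subgroups_number == 32);  // becomes the SUBGROUPS_NUMBER jit constant

    // Each subgroup keeps innermost / SUBGROUP_SIZE values in private memory,
    // i.e. the val[INNERMOST_DIM_VALUE / SUBGROUP_SIZE] array in the kernel.
    const std::size_t subgroup_size = 16;
    assert(innermost / subgroup_size == 8);
    return 0;
}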

@@ -192,12 +198,16 @@ JitConstants DynamicQuantizeKernelOptGeneric::GetJitConstants(const dynamic_quan
 CommonDispatchData DynamicQuantizeKernelOptGeneric::SetDefault(const dynamic_quantize_params& params) const {
     CommonDispatchData dispatchData;
 
+    const auto& input_dims = get_normalized_dims(params.inputs[0]);
     const auto total_batched_elements = get_elements_number_per_batch(params);
-    // const auto total_grouped_elements = get_elements_number_per_group(params);
+    const auto total_grouped_elements = get_elements_number_per_group(params);
+    const auto total_subgroups_number = total_grouped_elements / input_dims.back().v;
     // const auto per_iter_elements_number = get_per_iter_elements_number(params);
 
-    dispatchData.gws = {total_batched_elements, 32 * 16, 1};
-    dispatchData.lws = {1, 32 * 16, 1};
+    // TODO: add a check that input_dims.back().v / SUBGROUP_SIZE is small enough to fit the private array inside the kernel
+
+    dispatchData.gws = {subgroup_size, total_subgroups_number, total_batched_elements};
+    dispatchData.lws = {subgroup_size, total_subgroups_number, 1};
 
     return dispatchData;
 }
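The reworked dispatch replaces the hard-coded 32 * 16 work-group (NUM_HEADS * SUBGROUP_SIZE in the old kernel) with shape-derived sizes: subgroup lanes on axis 0, one subgroup per innermost slice on axis 1, and the batched dims on axis 2, matching get_global_id(1) and get_global_id(2) in the kernel. A hedged C++ sketch of the computation follows; the struct and parameters are stand-ins, not the selector's real types.

#include <array>
#include <cstddef>

// Stand-in for CommonDispatchData: a global and a local ND-range.
struct Dispatch {
    std::array<std::size_t, 3> gws;
    std::array<std::size_t, 3> lws;
};

Dispatch make_dispatch(std::size_t total_batched_elements,
                       std::size_t total_grouped_elements,
                       std::size_t innermost_dim) {
    constexpr std::size_t subgroup_size = 16;
    const std::size_t total_subgroups_number = total_grouped_elements / innermost_dim;

    // Axis 0: subgroup lanes; axis 1: one subgroup per innermost slice;
    // axis 2: independent batches. lws keeps a whole group in one work-group.
    return {{subgroup_size, total_subgroups_number, total_batched_elements},
            {subgroup_size, total_subgroups_number, 1}};
}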
@@ -274,6 +284,10 @@ bool DynamicQuantizeKernelOptGeneric::Validate(const Params& params) const {
         }
     }
 
+    // the last dimension must be static, covered by the group_sizes configuration, and divisible by the subgroup size
+    if (group_sizes.back() == 1 || input_dims.back().is_dynamic || input_dims.back().v % subgroup_size != 0)
+        return false;
+
     if (dq_params.inputs[0].GetPaddedVal() != 0 || dq_params.outputs[0].GetPaddedVal() != 0)
         return false;
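Worked through the new guard: a static last dimension of 128 that is covered by the grouping passes (128 % 16 == 0), while an ungrouped, dynamic, or non-multiple-of-16 last dimension makes Validate() reject this kernel so it is skipped. A sketch of the predicate with stand-in types; Dim here is hypothetical, not the selector's dimension class.

#include <cstddef>
#include <cstdint>
#include <vector>

// Stand-in for the selector's dimension type.
struct Dim { std::size_t v; bool is_dynamic; };

bool last_dim_supported(const std::vector<Dim>& input_dims,
                        const std::vector<std::uint64_t>& group_sizes) {
    constexpr std::size_t subgroup_size = 16;
    if (group_sizes.back() == 1)                     // last dim must be part of a group
        return false;
    if (input_dims.back().is_dynamic)                // and statically known
        return false;
    if (input_dims.back().v % subgroup_size != 0)    // and a multiple of the 16 lanes
        return false;
    return true;
}
// e.g. a {128, false} last dim with a grouping over it passes;
// a 120-wide or dynamic last dim fails and the kernel is not used.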
