WIP: [GPU] Debug accuracy issue
sshlyapn committed Dec 28, 2024
1 parent 88cc7de · commit 584b137
Showing 11 changed files with 614 additions and 9 deletions.
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/include/intel_gpu/runtime/memory.hpp
@@ -129,6 +129,8 @@ struct memory {
std::shared_ptr<MemoryTracker> get_mem_tracker() const { return m_mem_tracker; }
GPU_DEBUG_CODE(bool from_memory_pool = false);

// Debug helper: prints this memory's contents through `stream`, interpreted
// as `data_layout` and labeled with `name`; `add_paddings` presumably makes
// the dump include padded elements as well.
void print_memory(stream& stream, layout data_layout, std::string name, bool add_paddings) const;

protected:
engine* _engine;
const layout _layout;
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/graph/debug_helper.cpp
@@ -52,12 +52,14 @@ void dump(memory::ptr mem, stream& stream, std::ofstream& file_stream, bool dump
file_stream << "shape: " << size.to_string() << " ";
file_stream << "(count: " << size.count()
<< ", addr: " << mem->buffer_ptr()
<< ", original dt: " << mem->get_layout().data_type
<< ", original format: " << cldnn::fmt_to_str(mem->get_layout().format) << ")"
<< (dump_raw ? " raw data" : "") << std::endl;
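// Illustrative header line this produces (example values, exact
// to_string() formatting may differ):
// shape: [1,32,128,64] (count: 262144, addr: 0x7f..., original dt: f16, original format: bfyx) raw data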
} else {
file_stream << "shape: " << tmp_size.to_string() << " ";
file_stream << "(count: " << tmp_size.count()
<< ", addr: " << mem->buffer_ptr()
<< ", original dt: " << mem->get_layout().data_type
<< ", original format: " << cldnn::fmt_to_str(mem->get_layout().format)
<< ", original shape: " << size.to_string() << ")"
<< (dump_raw ? " raw data" : "") << std::endl;
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/graph/impls/ocl/kv_cache.cpp
@@ -431,6 +431,8 @@ struct kv_cache_impl : multi_stage_primitive<kv_cache> {
auto comp_scale_past_layout = impl_param.input_layouts[input_idx];
auto comp_scale_present_layout = impl_param.output_layouts[output_idx];

GPU_DEBUG_TRACE_DETAIL << "Update params, input: " << comp_scale_past_layout.to_short_string() << ", output: " << comp_scale_present_layout.to_short_string() << "\n";

params.inputs.resize(inputs_count);
params.inputs[0] = convert_data_tensor(comp_scale_past_layout);
params.outputs[0] = convert_data_tensor(comp_scale_present_layout);
@@ -153,6 +153,17 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
for (size_t i = 0; i < instance.get_intermediates_memories().size(); i++)
args.intermediates.push_back(instance.get_intermediates_memories()[i]);

// Debug-only: block until all queued GPU work completes so the compression
// scale/zp buffers dumped below contain their final values.
stream.finish();

const auto impl_params = instance.get_impl_params();
const auto& desc = impl_params->typed_desc<scaled_dot_product_attention>();
if (desc->is_kv_compressed) {
instance.input_memory_ptr(4)->print_memory(stream, impl_params->get_input_layout(4), desc->id + " key_scale", true);
instance.input_memory_ptr(5)->print_memory(stream, impl_params->get_input_layout(5), desc->id + " val_scale", true);
instance.input_memory_ptr(6)->print_memory(stream, impl_params->get_input_layout(6), desc->id + " key_zp", true);
instance.input_memory_ptr(7)->print_memory(stream, impl_params->get_input_layout(7), desc->id + " val_zp", true);
}

stream.set_arguments(*_kernels[idx_final], _kernels_data[stage].kernels[kd_idx].params, args);

const auto& gws = params.workGroups.global;
@@ -313,7 +324,11 @@ struct scaled_dot_product_attention_impl : multi_stage_primitive<scaled_dot_prod
}

if (desc->is_kv_compressed) {
GPU_DEBUG_TRACE_DETAIL << "Update scale layout: " << impl_param.get_input_layout(data_inputs_num) << "\n";
params.key_cache_comp_scale = convert_data_tensor(impl_param.get_input_layout(data_inputs_num));
GPU_DEBUG_TRACE_DETAIL << "Updated scale tensor pad: " << params.key_cache_comp_scale.Y().pad.before << " "
<< params.key_cache_comp_scale.Y().pad.after << "\n";

params.value_cache_comp_scale = convert_data_tensor(impl_param.get_input_layout(data_inputs_num + 1));

if (has_zp_input_buffers) {
16 changes: 15 additions & 1 deletion src/plugins/intel_gpu/src/graph/kv_cache.cpp
@@ -104,7 +104,21 @@ int32_t kv_cache_inst::get_prealloc_iter_num() {
// iteration.
// - Therefore, to avoid this situation, where allocation and copying occur simultaneously for all the kv_cache_insts,
// we assign a different prealloc size to each kv cache so that we can prevent a memory peak
return 128 + kv_cache_id % 64;
// Debug-only override: the KV_STRIDE environment variable, when set to a
// non-zero value, forces the same fixed prealloc stride for every cache.
int KV_STRIDE = 0;
if (const auto env_var = std::getenv("KV_STRIDE")) {
std::istringstream ss(env_var);
ss >> KV_STRIDE;
static bool print_once = true;
if (print_once) {
std::cout << ">>> KV_STRIDE = " << KV_STRIDE << "\n";
print_once = false;
}
}
if (KV_STRIDE != 0) {
return KV_STRIDE;
}

return 128 + kv_cache_id % 64;
}

void kv_cache_inst::update_shape_info_tensor(const kernel_impl_params& params) {
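For illustration, a minimal standalone C++ sketch of the staggered preallocation logic in get_prealloc_iter_num above, including the KV_STRIDE debug override (a hypothetical reproduction for clarity, not the plugin source itself):

#include <cstdlib>
#include <iostream>
#include <sstream>

// Staggered prealloc sizing: each kv_cache id gets a distinct offset in
// [0, 64), so buffer reallocation does not hit every cache on the same
// iteration, unless the KV_STRIDE override forces a fixed stride.
static int get_prealloc_iter_num(size_t kv_cache_id) {
    if (const char* env_var = std::getenv("KV_STRIDE")) {
        int stride = 0;
        std::istringstream ss(env_var);
        ss >> stride;
        if (stride != 0)
            return stride;  // debug override: same stride for all caches
    }
    return 128 + static_cast<int>(kv_cache_id % 64);
}

int main() {
    for (size_t id : {0, 1, 63, 64})  // prints 128, 129, 191, 128
        std::cout << "kv_cache " << id << " -> " << get_prealloc_iter_num(id) << "\n";
    return 0;
}

Unset, ids 0..63 map to strides 128..191 and then wrap, so different caches reallocate on different iterations; running with e.g. KV_STRIDE=256 forces the same stride for every cache.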
90 changes: 86 additions & 4 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_micro.cl
@@ -151,10 +151,10 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG
#endif
int d, int k, int q
#ifdef KV_COMPRESSED
, const global KEY_ATTR_SCALES_DATA_T *K_scales
, const global KEY_ATTR_ZP_DATA_T *K_zp
, const global VAL_ATTR_SCALES_DATA_T *V_scales
, const global VAL_ATTR_ZP_DATA_T *V_zp
, global KEY_ATTR_SCALES_DATA_T *K_scales
, global KEY_ATTR_ZP_DATA_T *K_zp
, global VAL_ATTR_SCALES_DATA_T *V_scales
, global VAL_ATTR_ZP_DATA_T *V_zp
#endif
) {
uint sg_ij = sub_group_broadcast(get_local_id(1), 0);
@@ -164,6 +164,87 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG

uint wg_j0 = get_group_id(0) * ugemm_kq_wg_tile_n;

// if (get_global_id(0) == 0 && get_global_id(1) == 0 && get_global_id(2) == 0) {
// for (int j = 0; j < 32; j++) {
// #ifdef KV_COMPRESSED
// for (int i = k; i < k + KEY_SCALE_PAD_AFTER_SIZE_Y; i++) {
// K_scales[(j * (k + KEY_SCALE_PAD_AFTER_SIZE_Y)) + i] = 0;
// K_zp[(j * (k + KEY_SCALE_PAD_AFTER_SIZE_Y)) + i] = 0;
// }

// for (int i = k; i < k + VAL_SCALE_PAD_AFTER_SIZE_Y; i++) {
// V_scales[(j * (k + VAL_SCALE_PAD_AFTER_SIZE_Y)) + i] = 0;
// V_zp[(j * (k + VAL_SCALE_PAD_AFTER_SIZE_Y)) + i] = 0;
// }
// #endif
// }

// printf("h=0 key scales[%p]={", K_scales);
// for (int i = 0; i < max(k, k + KEY_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%f, ", K_scales[i]);
// printf("total_num=%d, pad=%d\n", k, KEY_SCALE_PAD_AFTER_SIZE_Y);

// printf("h=0 key zp[%p]={", K_zp);
// for (int i = 0; i < max(k, k + KEY_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%d, ", K_zp[i]);
// printf("total_num=%d, pad=%d\n", k, KEY_SCALE_PAD_AFTER_SIZE_Y);

// printf("h=0 val scales[%p]={", V_scales);
// for (int i = 0; i < max(k, k + VAL_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%f, ", V_scales[i]);
// printf("total_num=%d, pad=%d\n", k, VAL_SCALE_PAD_AFTER_SIZE_Y);

// printf("h=0 val zp[%p]={", V_zp);
// for (int i = 0; i < max(k, k + VAL_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%d, ", V_zp[i]);
// printf("total_num=%d, pad=%d\n", k, VAL_SCALE_PAD_AFTER_SIZE_Y);


// printf("h=1 key scales[%p]={", K_scales);
// for (int i = 0; i < max(k, k + KEY_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%f, ", K_scales[k + KEY_SCALE_PAD_AFTER_SIZE_Y + i]);
// printf("total_num=%d, pad=%d\n", k, KEY_SCALE_PAD_AFTER_SIZE_Y);

// printf("h=1 key zp[%p]={", K_zp);
// for (int i = 0; i < max(k, k + KEY_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%d, ", K_zp[k + KEY_SCALE_PAD_AFTER_SIZE_Y + i]);
// printf("total_num=%d, pad=%d\n", k, KEY_SCALE_PAD_AFTER_SIZE_Y);

// printf("h=1 val scales[%p]={", V_scales);
// for (int i = 0; i < max(k, k + VAL_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%f, ", V_scales[k + VAL_SCALE_PAD_AFTER_SIZE_Y + i]);
// printf("total_num=%d, pad=%d\n", k, VAL_SCALE_PAD_AFTER_SIZE_Y);

// printf("h=1 val zp[%p]={", V_zp);
// for (int i = 0; i < max(k, k + VAL_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%d, ", V_zp[k + VAL_SCALE_PAD_AFTER_SIZE_Y + i]);
// printf("total_num=%d, pad=%d\n", k, VAL_SCALE_PAD_AFTER_SIZE_Y);


// printf("h=31 key scales[%p]={", K_scales);
// for (int i = 0; i < max(k, k + KEY_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%f, ", K_scales[(31 * (k + KEY_SCALE_PAD_AFTER_SIZE_Y)) + i]);
// printf("total_num=%d, pad=%d\n", k, KEY_SCALE_PAD_AFTER_SIZE_Y);

// printf("h=31 key zp[%p]={", K_zp);
// for (int i = 0; i < max(k, k + KEY_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%d, ", K_zp[(31 * (k + KEY_SCALE_PAD_AFTER_SIZE_Y)) + i]);
// printf("total_num=%d, pad=%d\n", k, KEY_SCALE_PAD_AFTER_SIZE_Y);

// printf("h=31 val scales[%p]={", V_scales);
// for (int i = 0; i < max(k, k + VAL_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%f, ", V_scales[(31 * (k + VAL_SCALE_PAD_AFTER_SIZE_Y)) + i]);
// printf("total_num=%d, pad=%d\n", k, VAL_SCALE_PAD_AFTER_SIZE_Y);

// printf("h=31 val zp[%p]={", V_zp);
// for (int i = 0; i < max(k, k + VAL_SCALE_PAD_AFTER_SIZE_Y); i++)
// printf("%d, ", V_zp[(31 * (k + VAL_SCALE_PAD_AFTER_SIZE_Y)) + i]);
// printf("total_num=%d, pad=%d\n", k, VAL_SCALE_PAD_AFTER_SIZE_Y);
// }
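// Note: the commented-out dumps above assume the scale/zp buffers are laid
// out per head with a row pitch of (k + *_SCALE_PAD_AFTER_SIZE_Y) elements,
// i.e. element (h, i) sits at h * (k + PAD_AFTER) + i, with the pad-after
// region following each head's k valid entries.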
/* Leading dimension for matrices */
uint ldk = TRANSPOSE_K ? KEY_S3 : KEY_S2;
uint ldq = QRY_S2;
@@ -307,6 +388,7 @@ KERNEL(micro_sdpa)(OPTIONAL_SHAPE_INFO_ARG

#if WITH_ATTN_MASK
mask_tile_type mask_tile;
// tile_load_t(&mask_tile, msk, q, k, sg_j0_kq + wg_j0, k0 + sg_i0_kq);
tile_load_t(&mask_tile, msk, q, k, q, sg_j0_kq + wg_j0, k0 + sg_i0_kq);
#endif

@@ -68,6 +68,15 @@ Tensor::Dim get_seq_length(const DataTensor& qkv, const std::vector<int64_t>& or
return normalize_dims(qkv)[order[2]];
}

std::string get_kernel_name(const std::string& kernel_name, const sdpa_params& params, bool is_prefill) {
auto name = kernel_name + (is_prefill ? "_prefill" : "_generate");

if (params.conf.is_kv_compressed)
name += "_compressed";

return name;
}
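// For example (illustrative name): with kernel_name == "sdpa_micro",
// is_prefill == true and params.conf.is_kv_compressed set, this returns
// "sdpa_micro_prefill_compressed", keeping the compressed variants'
// entry points distinct from the uncompressed ones.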

struct sdpa_config_t {
int unroll_m_kq, unroll_n_kq; // Subgroup tile sizes for K*Q GEMM
int unroll_m_vs, unroll_n_vs; // Subgroup tile sizes for V*S GEMM
@@ -272,8 +281,10 @@ void SDPAKernelMicro::init_microkernels(const sdpa_params& params, micro::Packag
opts_kq.scaleA = params.conf.is_kv_compressed && !kq_common_scales;
opts_kq.offsetA = params.conf.is_kv_compressed && params.conf.use_asymmetric_quantization;

// auto key_dt_size = micro::data_type_size(convert_type(params.inputs[1].GetDType()));
problem_kq.B.layout = micro::MatrixLayout::Pr;
problem_kq.C.layout = micro::MatrixLayout::T;
// problem_kq.A.setAlignment(micro::alignment_for_ld(head_size * key_dt_size));
problem_kq.A.setAlignment(micro::alignment_for_ld(head_size * problem.Ta));
problem_kq.B.setAlignment(64); // Q is packed in VNNI format in SLM
problem_kq.B.crosspack = 2;
@@ -337,8 +348,10 @@ void SDPAKernelMicro::init_microkernels(const sdpa_params& params, micro::Packag
opts_vs.scaleA = params.conf.is_kv_compressed && !vs_common_scales;
opts_vs.offsetA = params.conf.is_kv_compressed && params.conf.use_asymmetric_quantization;

// auto val_dt_size = micro::data_type_size(convert_type(params.inputs[2].GetDType()));
problem_vs.B.layout = micro::MatrixLayout::Pr;
problem_vs.C.layout = micro::MatrixLayout::N;
// problem_vs.A.setAlignment(micro::alignment_for_ld(head_size * val_dt_size));
problem_vs.A.setAlignment(micro::alignment_for_ld(head_size * problem.Ta));
problem_vs.B.setAlignment(64); // S is packed in SLM
problem_vs.B.crosspack = 16;
@@ -536,6 +549,12 @@ JitConstants SDPAKernelMicro::GetJitConstants(const sdpa_params& params, const m
// TODO: Causes accuracy drop for static SD model. Enable back once the issue is resolved
// if (lda % 4 == 0 && v_full)
// jit.AddConstant(MakeJitConstant("BLOCK_A", 1));
// if (params.inputs.size() > 3 && !params.inputs[3].is_dynamic()) {
// auto ldmsk = params.inputs[3].X().v * params.inputs[3].ElementSize();
// if (ldmsk % 4 == 0)
// jit.AddConstant(MakeJitConstant("BLOCK_MSK", 1));
// }
// if (ldmsk % 4 == 0) kernel_ctx.define_int("BLOCK_MSK", 1);
jit.AddConstant(MakeJitConstant("REMAINDER_Q", !q_full));
} else if (params.engineInfo.arch >= gpu_arch::xe_hpc) {
auto vbytes = n_values.v * V.ElementSize();
@@ -629,7 +648,7 @@ CommonDispatchData SDPAKernelMicro::SetDefault(const sdpa_params& params, const
}

clKernelData SDPAKernelMicro::get_kernel_data(const sdpa_params& params, bool is_prefill) const {
auto name = kernelName + (is_prefill ? "_prefill" : "_generate");
auto name = get_kernel_name(kernelName, params, is_prefill);

std::vector<micro::Package> gemms(2); // KQ and VS
init_microkernels(params, gemms[kq_id], gemms[vs_id], is_prefill);
@@ -753,6 +772,12 @@ void SDPAKernelMicro::GetUpdateDispatchDataFunc(KernelData& kd) const {
const auto n_queries = get_seq_length(Q, prim_params.input0_order);
const auto n_keys = get_seq_length(K, prim_params.input1_order);

GPU_DEBUG_TRACE_DETAIL << "Key scale pad_before=" << prim_params.key_cache_comp_scale.Y().pad.before
<< " pad_after=" << prim_params.key_cache_comp_scale.Y().pad.after << "\n";

GPU_DEBUG_TRACE_DETAIL << "Value scale pad_before=" << prim_params.value_cache_comp_scale.Y().pad.before
<< " pad_after=" << prim_params.value_cache_comp_scale.Y().pad.after << "\n";

auto head_size = prim_params.conf.head_size;

ScalarDescriptor s_d;
@@ -100,6 +100,10 @@ static std::string GetKernelName(std::string base_name, KernelsTypes type, const
kernel_name += "_finalization";
}

if (params.conf.is_kv_compressed) {
kernel_name += "_compressed";
}

return kernel_name;
}

@@ -62,10 +62,14 @@ KernelsData SDPAKernelRef::GetKernelsData(const Params& params) const {
return {};
}

auto kernel_name = kernelName;
if (prim_params.conf.is_kv_compressed)
kernel_name += "_compressed";
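// Mirrors the "_compressed" suffix added to the micro/opt SDPA kernels in
// this commit, presumably so the KV-compressed variant gets its own entry
// point and cached JIT rather than colliding with the regular kernel.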

auto dispatchData = SetDefault(prim_params);
auto entry_point = GetEntryPoint(kernelName, prim_params.layerID, params);
auto entry_point = GetEntryPoint(kernel_name, prim_params.layerID, params);
auto cldnn_jit = GetJitConstants(prim_params);
auto jit = CreateJit(kernelName, cldnn_jit, entry_point);
auto jit = CreateJit(kernel_name, cldnn_jit, entry_point);

auto& kernel = kd.kernels[0];
