Skip to content

Commit

Permalink
[GPU] Add missing code in dynamic fc impl
Browse files Browse the repository at this point in the history
* Add acc_tmp in general calc in fc function in common include file
* Set DECOMPRESSION_SCALE_POST_OP=ON for dynamic quantization case
  • Loading branch information
ahnyoung-paul committed Dec 12, 2024
1 parent 2d78f2a commit 3f907fd
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,7 @@ inline void (FUNC_NAME)(
// NOTE: Manually unrolling multiplication loop leads to lower register pressure and allows for bigger block sizes,
// but significantly degrades readability and generality of code.
// It also doesn't show noticeable performance improvement on tested configurations.
#if DECOMPRESSION_SCALE_POST_OP
ACCUMULATOR_VEC_TYPE acc_tmp[FORCED_TILE_B] = { };
#endif
ACCUMULATOR_VEC_TYPE acc_tmp[FORCED_TILE_B] = { };

unroll_for(uint ki = 0; ki < (TILE_IFM * SIMD) / TILE_K; ++ki) {
#if COMPRESSED_WEIGHTS_INT4
Expand Down Expand Up @@ -201,11 +199,7 @@ inline void (FUNC_NAME)(
unroll_for (uint bi = 0; bi < FORCED_TILE_B; ++bi) {
INPUT0_TYPE in_val = _sub_group_shuffle(((INPUT0_TYPE*)(&in_0[bi]))[total_k / SIMD], total_k % SIMD);
unroll_for (uint fi = 0; fi < TILE_OFM; ++fi) {
#if DECOMPRESSION_SCALE_POST_OP
((ACCUMULATOR_TYPE*)(&acc_tmp[bi]))[fi] += in_val * ((ACCUMULATOR_TYPE*)(&wei))[W_IDX];
#else
((ACCUMULATOR_TYPE*)(&acc[bi]))[fi] += in_val * ((ACCUMULATOR_TYPE*)(&wei))[W_IDX];
#endif
}
}
}
Expand Down Expand Up @@ -243,6 +237,16 @@ inline void (FUNC_NAME)(
}
}
#endif

#if !DECOMPRESSION_SCALE_POST_OP
unroll_for (uint bi = 0; bi < FORCED_TILE_B; ++bi) {
unroll_for(uint fi = 0; fi < TILE_OFM; ++fi) {
((ACCUMULATOR_TYPE*)(&acc[bi]))[fi] += ((ACCUMULATOR_TYPE*)(&acc_tmp[bi]))[fi];
}
}
#endif


}
// =====================================================================================================================================
// Leftovers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -698,14 +698,14 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
jit.AddConstant(MakeJitConstant("USE_SLM", 0));
}

if (add_decompress_scale_post_op)
jit.AddConstant(MakeJitConstant("DECOMPRESSION_SCALE_POST_OP", 1));
// Validated perf gain, Dynamic quantize force enable SCALE_POST_OP for char type multiplication
if (should_dynamic_quantize(params)) {
jit.AddConstant(MakeJitConstant("DYNAMIC_QUANTIZE", 1));
jit.AddConstant(MakeJitConstant("DQ_DECOMPRESSION_SCALE_POST_OP", 1));
jit.AddConstant(MakeJitConstant("QUANTIZE_GROUP_SIZE", quantize_grp_size));
} else {
if (add_decompress_scale_post_op)
jit.AddConstant(MakeJitConstant("DECOMPRESSION_SCALE_POST_OP", 1));
jit.AddConstant(MakeJitConstant("DYNAMIC_QUANTIZE", 0));
jit.AddConstant(MakeJitConstant("QUANTIZE_GROUP_SIZE", min_quantize_grp_size));
}
Expand Down

0 comments on commit 3f907fd

Please sign in to comment.