[GPU] Copy weights to SLM and load from it
sshlyapn committed Oct 31, 2023
1 parent 9894ca7 commit 8428338
Showing 5 changed files with 628 additions and 28 deletions.
8 changes: 8 additions & 0 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
@@ -1058,6 +1058,14 @@ primitive_inst::primitive_inst(network& network, program_node const& node, bool
             _shape_info_memory = _network.get_engine().allocate_memory(layout{{shape_elements}, data_types::i32, format::bfyx});
         }
     }
+    if (_node) {
+        std::stringstream ss;
+        try {
+            ss << _node->type()->to_string(*_node);
+        } catch (const std::exception& e) {  // best-effort debug trace: ignore to_string() failures
+        }
+        GPU_DEBUG_TRACE_DETAIL << id() << ": node's parameters:\n" << ss.str() << std::endl;
+    }
     _impl_params->strm = _network.get_stream_ptr();
     if (_outputs[0])
         max_output_layout_size = _outputs[0]->get_layout().get_tensor().count();
@@ -104,7 +104,11 @@ KERNEL(fc)(
     , FUSED_OPS_DECLS
 #endif
 ) {
-    uint gid = (uint)get_group_id(0);
+    // TODO: check old HW to allocate less than the 4K SLM restriction
+
+    uint gid = (uint)get_group_id(0) * 8;
+    uint local_id = (uint)get_local_id(2);
+    gid += local_id;
     uint sglid = (uint)get_sub_group_local_id();
 
     // Dispatch as bs_fs_bsv_fsv, where bsv = DISPATCH_BSV and fsv = DISPATCH_FSV.
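(With the lws = (simd, 1, 8) work-group shape introduced in SetDefault below, gid rebuilds the flat index the old one-line dispatch produced directly: e.g. group 3 with local_id 5 gives gid = 3 * 8 + 5 = 29. The hard-coded 8 has to stay in sync with lws[2].)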
@@ -123,7 +127,7 @@ KERNEL(fc)(
     ACCUMULATOR_VEC_TYPE acc[TILE_B] = { };
     INPUT_VEC_TYPE in_0[TILE_B] = { };
 
-    FILTER_VEC_TYPE wei = 0;
+    __local ACCUMULATOR_TYPE wei_local[512];
     uint input_offset = out_b * TILE_IN_B_PITCH + INPUT0_OFFSET;
 #if COMPRESSED_WEIGHTS_INT4
     uint weights_offset = out_f * (INPUT_ELEMENTS_COUNT / 2);
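(The fixed 512 matches the staging in the main loop below: TILE_K_OFM is pinned to 4 by an #error guard, and the inline "64 = SIMD * TILE_K_OFM" comment implies SIMD = 16, so each of the 8 sub-groups block-reads 16 * 4 = 64 weights per outer iteration, and 8 * 64 = 512 elements land in SLM, i.e. 1 KB as fp16, within the 4K SLM concern from the TODO above.)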
@@ -168,6 +172,7 @@ KERNEL(fc)(
         // For fp16 we need to ensure that all block reads are aligned to 4 byte (2 words) boundary.
         // To do this solve first input feature separately.
         {
+#error "REALIGN NEEDED"
             INPUT0_TYPE tmp_input = input[input_offset + get_sub_group_local_id() % TILE_B * TILE_IN_B_PITCH];
             ACCUMULATOR_VEC_TYPE tmp_wei = TO_ACCUMULATOR_VEC_TYPE(BLOCK_READN(FILTER_TYPE, TILE_OFM, weights, weights_offset));
 #if COMPRESSED_WEIGHTS
@@ -202,12 +207,49 @@ KERNEL(fc)(
         ACCUMULATOR_VEC_TYPE acc_tmp[TILE_B] = { };
 #endif
 
+        FILTER_VEC_TYPE wei;
+        wei = TO_FILTER_VEC_TYPE(FILTER_BLOCK_READ(weights, weights_offset + local_id * (SIMD * TILE_K_OFM))); // 64 = SIMD * TILE_K_OFM
+
+#if TILE_K_OFM != 4
+#error "Can not re-arrange weights"
+#endif
+
+        // __local wei layout: os_iyx_osv16_osv2
+
+        uint wei_local_idx = (local_id * (SIMD * TILE_OFM)) + sglid;
+
+#define FILTER_VEC2 MAKE_VECTOR_TYPE(ACCUMULATOR_TYPE, 2)
+        __local FILTER_VEC2* wei2_local = (__local FILTER_VEC2*)&wei_local;
+#undef FILTER_VEC2
+
+        // if ((sglid == 0 || sglid == 15 || sglid == 14) && ni == 0) {
+        //     printf("gid=%d; sglid=%d, out_b=%d, out_f=%d, weights_offset=%d, iterations=%d, wei_local_idx=%d, wei(%f,%f), weights_read_from=%d\n", gid, sglid, out_b, out_f, weights_offset, iterations, wei_local_idx, wei.s0, wei.s1, weights_offset + local_id * (SIMD * TILE_K_OFM));
+        // }
+
+        wei2_local[wei_local_idx] = wei.s01;  // park this sub-group's weight vectors in SLM
+        wei_local_idx += SIMD;
+        wei2_local[wei_local_idx] = wei.s23;
+        wei_local_idx -= SIMD;
+
+        wei_local_idx = sglid;  // rewind to the start of the staged tile for reading
+
+        barrier(CLK_LOCAL_MEM_FENCE);
+
         unroll_for(uint ki = 0; ki < (TILE_IFM * SIMD) / TILE_K; ++ki) {
 #if COMPRESSED_WEIGHTS_INT4
             FILTER_PACKED_VEC_TYPE wei_packed = FILTER_BLOCK_READ(weights, weights_offset);
             wei = UNPACK_INT4x2(ACCUMULATOR_TYPE, *((INT4_PACKED_TYPE*)&wei_packed));
+#error "Unexpected here"
 #else
-            wei = TO_FILTER_VEC_TYPE(FILTER_BLOCK_READ(weights, weights_offset));
+            // wei = TO_FILTER_VEC_TYPE(FILTER_BLOCK_READ(weights, weights_offset));
+            // Load the weights from SLM instead of issuing a global read.
+            wei.s01 = wei2_local[wei_local_idx];
+            wei_local_idx += SIMD;
+            wei.s23 = wei2_local[wei_local_idx];
+            wei_local_idx += SIMD;
+            // wei_local_idx -= SIMD;
+
+            // wei_local_idx += 2 * SIMD;
 #endif
 
 #if COMPRESSED_WEIGHTS
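To make the staging idea easier to follow outside the macro-heavy kernel above, here is a minimal, self-contained OpenCL sketch of the same pattern. The kernel name, buffer layouts, and the fixed SIMD/sub-group counts are illustrative assumptions for this note, not the real fully connected kernel:

    // fc_slm_sketch.cl -- hypothetical illustration of the commit's idea:
    // the 8 sub-groups of a work-group cooperatively stage a weight tile
    // in SLM, then every sub-group consumes the whole tile for its own
    // batch row, so each weight crosses the global-memory bus once per
    // work-group instead of once per sub-group.
    // Assumes an Intel GPU with OpenCL sub-groups and the
    // cl_intel_required_subgroup_size extension.
    #define SIMD       16
    #define SUB_GROUPS 8
    #define K_STEPS    SUB_GROUPS              // reduction depth staged per tile

    __attribute__((reqd_work_group_size(SIMD, 1, SUB_GROUPS)))
    __attribute__((intel_reqd_sub_group_size(SIMD)))
    __kernel void fc_slm_sketch(const __global float* weights,   // [K_STEPS][F]
                                const __global float* input,     // [B][K_STEPS]
                                __global float* output) {        // [B][F]
        const uint F        = get_global_size(0);     // one lane per output feature
        const uint sglid    = get_sub_group_local_id();
        const uint local_id = get_local_id(2);         // sub-group index, 0..7
        const uint f        = get_group_id(0) * SIMD + sglid;  // this lane's feature
        const uint b        = get_global_id(2);                // this sub-group's batch row

        // Stage: sub-group `local_id` copies k-row `local_id` of this
        // group's feature tile from global memory into SLM.
        __local float wei_local[K_STEPS * SIMD];
        wei_local[local_id * SIMD + sglid] = weights[local_id * F + f];

        barrier(CLK_LOCAL_MEM_FENCE);

        // Consume: every sub-group walks the full staged tile.
        float acc = 0.0f;
        for (uint k = 0; k < K_STEPS; ++k)
            acc += wei_local[k * SIMD + sglid] * input[b * K_STEPS + k];

        output[b * F + f] = acc;
    }

The barrier is what makes the cooperative staging legal, and it is also why the real kernel must keep the outer loop trip count uniform across all eight sub-groups of a work-group.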
@@ -112,6 +112,7 @@ KernelsData FullyConnectedKernelBase::GetCommonKernelsData(const Params &params,
         GetSupportedKey());
 
     if (!succeed) {
+        std::cout << "Not succeeded\n";
         return {};
     }
 
@@ -92,6 +92,8 @@ bool FullyConnected_bf_tiled::Validate(const Params& params, const optional_para
         return false;
     }
 
+    std::cout << "FC tiled " << static_cast<int>(input.GetLayout()) << " " << static_cast<int>(output.GetLayout()) << "\n";
+
     if (input.GetLayout() == DataLayout::bfyx) {
         // Padding on input is not supported.
         // TODO: Enable by mirroring the padding in weights.
@@ -112,6 +114,7 @@ bool FullyConnected_bf_tiled::Validate(const Params& params, const optional_para
         return false;
     }
 
+    std::cout << "FC tiled - OK" << "\n";
     return true;
 }
 
@@ -126,6 +129,7 @@ struct TuneParamsSelector {
     TuneParamsSelector& Case(const tune_params& tparams) {
         if (!selected && VerifyTuneParams(params, tparams)) {
             result = tparams;
+            std::cout << "Selected\n";
             selected = true;
         }
         return *this;
@@ -241,7 +245,7 @@ FullyConnected_bf_tiled::GetAutoTuneParams(const fully_connected_params& params,
     // tune_params(tile_b, tile_ofm, tile_ifm, tile_k, dispatch_bsv, dispatch_fsv, exec_options)
     selector.Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 2, 16, 2, EXE_MODE_AGE_BASED))
             .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 2, 16, 1, EXE_MODE_AGE_BASED))
-            .Case(tune_params(16, std::min(max_tile_ofm, 2u), 1, 2, 4, 2, EXE_MODE_AGE_BASED))
+            // .Case(tune_params(16, std::min(max_tile_ofm, 2u), 1, 2, 4, 2, EXE_MODE_AGE_BASED))
             .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 2, 8, 1, EXE_MODE_AGE_BASED))
             .Case(tune_params(16, std::min(max_tile_ofm, 2u), 1, 2, 2, 2, EXE_MODE_AGE_BASED))
            .Case(tune_params(8, std::min(max_tile_ofm, 2u), 1, 2, 4, 1, EXE_MODE_AGE_BASED))
@@ -277,7 +281,11 @@ FullyConnected_bf_tiled::GetAutoTuneParams(const fully_connected_params& params,
         });
     }
 
-    return selector.Default(tune_params(1, 1, 1, 1, 1, 1, EXE_MODE_DEFAULT));
+    auto tuning_res = selector.Default(tune_params(1, 1, 1, 1, 1, 1, EXE_MODE_DEFAULT));
+
+    std::cout << "Tuning params:" << " tile_b=" << tuning_res.tile_b << " tile_ofm=" << tuning_res.tile_ofm << " tile_ifm=" << tuning_res.tile_ifm << " tile_k=" << tuning_res.tile_k << " dispatch_bsv=" << tuning_res.dispatch_bsv << " dispatch_fsv=" << tuning_res.dispatch_fsv << "\n";
+
+    return tuning_res;
 }
 
 FullyConnected_bf_tiled::DispatchData
@@ -294,13 +302,15 @@ FullyConnected_bf_tiled::SetDefault(const fully_connected_params& params, int au
 
     batch_threads = CeilDiv(batch_threads, tparams.tile_b);
 
-    dispatchData.gws[0] = feature_threads * batch_threads * simd;
+    std::cout << "Kernel params:" << " feature_threads=" << feature_threads << " batch_threads=" << batch_threads << " simd=" << simd << "\n";
+
+    dispatchData.gws[0] = feature_threads * simd;
     dispatchData.gws[1] = 1;
-    dispatchData.gws[2] = 1;
+    dispatchData.gws[2] = 8; // 8
 
     dispatchData.lws[0] = simd;
     dispatchData.lws[1] = 1;
-    dispatchData.lws[2] = 1;
+    dispatchData.lws[2] = 8; // = batch / batch_threads = 64 / 8 = 8
 
     dispatchData.tile_m = tparams.tile_b;
     dispatchData.tile_n = tparams.tile_ofm;
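(Worked through for the prototype's target shape: lws = (16, 1, 8) is 128 work-items per work-group, i.e. 8 SIMD-16 sub-groups sharing one SLM tile. gws[0] now counts only feature_threads * simd, so the kernel's gid = get_group_id(0) * 8 + get_local_id(2) reproduces the old flat index only when batch_threads == 8, the 64 / 8 = 8 case in the comment; that assumption is why these values are hard-coded rather than derived from tparams.)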
@@ -358,6 +368,7 @@ JitConstants FullyConnected_bf_tiled::GetJitConstants(const fully_connected_para
 
     auto activation_dt = GetActivationType(params);
     auto accumulator_dt = GetAccumulatorType(params);
+    std::cout << "accumulator_dt= " << static_cast<int>(accumulator_dt) << "\n";
     jit.Merge(MakeTypeJitConstants(activation_dt, "ACTIVATION"));
     jit.Merge(MakeActivationJitConstants(params.activations, activation_dt, "_TYPED"));
     jit.Merge(MakeTypeJitConstants(accumulator_dt, "ACCUMULATOR"));
@@ -439,6 +450,7 @@ KernelsData FullyConnected_bf_tiled::GetKernelsDataForAutoTune(const Params& par
         }
     }
 
+    std::cout << "Res2 " << res.size() << '\n';
     return res;
 }
 
@@ -452,6 +464,8 @@ KernelsData FullyConnected_bf_tiled::GetKernelsData(const Params& params, const
         res.emplace_back(kds[0]);
     }
 
+    std::cout << "Res " << res.size() << '\n';
+
     return res;
 }
 