[GPU] Gemm tiled opt: add dynamic padding support
sshlyapn committed Dec 13, 2023
1 parent a3de4c1 commit d575d2b
Showing 4 changed files with 222 additions and 12 deletions.
@@ -117,7 +117,7 @@ KERNEL(gemm_tiled_opt)(

// Start pointers offsets
#if !TRANSPOSE_INPUT0
const __global INPUT0_TYPE* a_ptr = input0 + batch_offset_input0 + tile_m_offset * K;
const __global INPUT0_TYPE* a_ptr = input0 + batch_offset_input0 + tile_m_offset * K_PADDED_IN0;
#else // !TRANSPOSE_INPUT0
const __global INPUT0_TYPE* a_ptr = input0 + batch_offset_input0 + tile_m_offset;
#endif // !TRANSPOSE_INPUT0
@@ -153,7 +153,13 @@ KERNEL(gemm_tiled_opt)(
// Loading B tile
unroll_for (uint b_load_id = 0; b_load_id < TILE_K; b_load_id++) {
#if IS_DYNAMIC
#if HAS_DYNAMIC_N_PADDING
// With dynamic padding we cannot guarantee the 4-byte alignment required by
// block reads, so use a scattered read instead
b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
#else
b_tile[b_load_id] = TILE_N_NOT_DIVISIBLE ? (b_raw_global_id > N - 1 ? 0 : b_ptr[sglid]) : BLOCK_READ_B(b_ptr, 0);
#endif
#else // IS_DYNAMIC
#if TILE_N_NOT_DIVISIBLE
b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
@@ -162,7 +168,7 @@ KERNEL(gemm_tiled_opt)(
#endif // TILE_N_NOT_DIVISIBLE
#endif // IS_DYNAMIC
#if !TRANSPOSE_INPUT1
b_ptr += N;
b_ptr += N_PADDED;
#else // !TRANSPOSE_INPUT1
b_ptr += K;
#endif // !TRANSPOSE_INPUT1
@@ -203,7 +209,13 @@ KERNEL(gemm_tiled_opt)(
unroll_for (uint dot_id = 0; dot_id < tile_m_iterations; dot_id++) {
#if !TRANSPOSE_INPUT0
#if IS_DYNAMIC
A_FLOATN a_read = TILE_K_NOT_DIVISIBLE ? a_ptr[dot_id * K + sglid] : BLOCK_READ_A(a_ptr, dot_id * K);
#if HAS_DYNAMIC_K_PADDING
// With dynamic padding we cannot guarantee the 4-byte alignment required by
// block reads, so use a scattered read instead
A_FLOATN a_read = a_ptr[dot_id * K_PADDED_IN0 + sglid];
#else
A_FLOATN a_read = TILE_K_NOT_DIVISIBLE ? a_ptr[dot_id * K_PADDED_IN0 + sglid] : BLOCK_READ_A(a_ptr, dot_id * K);
#endif
#else // IS_DYNAMIC
#if TILE_K_NOT_DIVISIBLE
A_FLOATN a_read = a_ptr[dot_id * K + sglid];
@@ -273,13 +285,17 @@ KERNEL(gemm_tiled_opt)(
if (TILE_K_NOT_DIVISIBLE) {
// Loading leftovers of the matrix B
unroll_for (uint b_load_id = 0; b_load_id < TILE_K_LEFTOVER; b_load_id++) {
#if HAS_DYNAMIC_N_PADDING
b_tile[b_load_id] = b_raw_global_id > N - 1 ? 0 : b_ptr[sglid];
#else
b_tile[b_load_id] = TILE_N_NOT_DIVISIBLE ? (b_raw_global_id > N - 1 ? 0 : b_ptr[sglid]) : BLOCK_READ_B(b_ptr, 0);
b_ptr += N;
#endif
b_ptr += N_PADDED;
} // Loading leftovers of the matrix B end

// Loading leftovers of the matrix A and tile C calculation
unroll_for (uint dot_id = 0; dot_id < tile_m_iterations; dot_id++) {
INPUT0_TYPE a_read = a_ptr[dot_id * K + sglid];
INPUT0_TYPE a_read = a_ptr[dot_id * K_PADDED_IN0 + sglid];

unroll_for (uint simd_id = 0; simd_id < TILE_K_LEFTOVER; simd_id++) {
c_tile[dot_id] = mad((INPUT0_TYPE)(sub_group_broadcast(a_read, simd_id)), b_tile[simd_id], c_tile[dot_id]);
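Taken together, the kernel-side changes do two things: pointer arithmetic switches from the logical sizes K/N to the padded row pitches K_PADDED_IN0/N_PADDED, and per-lane (scattered) loads replace sub-group block reads whenever the padding is dynamic, since a runtime pad means the 4-byte alignment of each row start can no longer be proven at compile time. A minimal host-side sketch of the same addressing, with illustrative helper names that are not part of the kernel:

```cpp
#include <cstddef>

// Illustrative only: linear element offsets into row-major A (M x K) and B (K x N)
// buffers whose rows are padded out to k_padded / n_padded elements. With dynamic
// padding the pitch is a runtime value, so the kernel cannot assume the alignment
// that block reads require and falls back to per-lane reads such as b_ptr[sglid].
inline std::size_t a_offset(std::size_t batch_offset, std::size_t m, std::size_t k,
                            std::size_t k_padded) {
    return batch_offset + m * k_padded + k;   // consecutive A rows are k_padded apart
}

inline std::size_t b_offset(std::size_t batch_offset, std::size_t k, std::size_t n,
                            std::size_t n_padded) {
    return batch_offset + k * n_padded + n;   // consecutive B rows are n_padded apart
}
```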
7 changes: 5 additions & 2 deletions src/plugins/intel_gpu/src/kernel_selector/jitter.cpp
@@ -197,10 +197,13 @@ std::string toCodeString(const Tensor::Dim& dim, size_t offset, bool padded, boo
pad_str = " + " + std::to_string(dim.pad.Total());
}
}
if (dim.is_dynamic || pad_is_dynamic) {
if (dim.is_dynamic) {
snprintf(buf, sizeof(buf), "(shape_info[%zu] %s)", offset, pad_str.c_str());
} else {
snprintf(buf, sizeof(buf), "%zu", dim.v + (padded ? dim.pad.Total() : 0));
if (pad_is_dynamic)
snprintf(buf, sizeof(buf), "(%zu %s)", dim.v, pad_str.c_str());
else
snprintf(buf, sizeof(buf), "%zu", dim.v + (padded ? dim.pad.Total() : 0));
}
return buf;
}
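The new branch above keeps a statically known dimension value as a literal while still appending the (possibly runtime) padding term, instead of routing the whole expression through shape_info as the dynamic-dimension case does. A simplified sketch of the three cases, assuming pad_str has already been built by the caller (e.g. " + 3" for a static pad); this is illustrative, not the actual OpenVINO helper:

```cpp
#include <cstddef>
#include <cstdio>
#include <string>

// Sketch of the branching in toCodeString: which string ends up in the generated kernel source.
std::string to_code_string_sketch(bool dim_is_dynamic, bool pad_is_dynamic, bool padded,
                                  std::size_t dim_v, std::size_t pad_total,
                                  std::size_t shape_info_offset, const std::string& pad_str) {
    char buf[256];
    if (dim_is_dynamic) {
        // Dynamic dimension: read the value from the runtime shape_info buffer.
        std::snprintf(buf, sizeof(buf), "(shape_info[%zu] %s)", shape_info_offset, pad_str.c_str());
    } else if (pad_is_dynamic) {
        // Static dimension with dynamic padding: keep the literal value, append the runtime pad term.
        std::snprintf(buf, sizeof(buf), "(%zu %s)", dim_v, pad_str.c_str());
    } else {
        // Fully static: fold the dimension and (if requested) its padding into one literal.
        std::snprintf(buf, sizeof(buf), "%zu", dim_v + (padded ? pad_total : 0));
    }
    return buf;
}
```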
@@ -24,6 +24,7 @@ ParamsKey GemmKernelTiledOpt::GetSupportedKey() const {
k.EnableOutputLayout(DataLayout::bfwzyx);

k.EnableTensorOffset();
k.EnableTensorPitches();
k.EnableBatching();
k.EnableDifferentTypes();
k.EnableDynamicShapesSupport();
@@ -64,7 +65,7 @@ GemmKernelTiledOpt::GemmTuningData GemmKernelTiledOpt::SetTuningParams(const gem

GemmKernelTiledOpt::GemmTuningData tuning_data;

if (!params.is_shape_agnostic) {
auto m_size = output.Y().v;
auto n_size = output.X().v;
auto k_size = params.transpose_input0 ? params.inputs[0].Y().v : params.inputs[0].X().v;
@@ -112,11 +113,17 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
if (params.has_dynamic_tensors()) {
DimensionAccessHelper dims0(params.inputs[0]);
DimensionAccessHelper dims1(params.inputs[1]);
DimensionAccessHelper dims0_padded(params.inputs[0], true);
DimensionAccessHelper dims1_padded(params.inputs[1], true);
// Note: this kernel is currently not selected for the shape-agnostic implementation with transposed inputs,
// because the original rank cannot be recovered
auto m_size = params.transpose_input0 ? dims0.x() : dims0.y();
auto n_size = params.transpose_input1 ? dims1.y() : dims1.x();
auto n_padded_size = params.transpose_input1 ? "(" + dims1_padded.y() + ")"
: "(" + dims1_padded.x() + ")";
auto k_size = params.transpose_input0 ? dims0.y() : dims0.x();
auto k_padded_size_in0 = params.transpose_input0 ? "(" + dims0_padded.y() + ")"
: "(" + dims0_padded.x() + ")";
const std::string leftover_m = "(" + m_size + "%" + std::to_string(tuning_data.tile_m_size) + ")";
const std::string leftover_n = "(" + n_size + "%" + std::to_string(tuning_data.tile_n_size) + ")";
const std::string leftover_k = "(" + k_size + "%" + std::to_string(tuning_data.tile_k_size) + ")";
@@ -129,6 +136,8 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
MakeJitConstant("M", m_size),
MakeJitConstant("K", k_size),
MakeJitConstant("N", n_size),
MakeJitConstant("K_PADDED_IN0", k_padded_size_in0),
MakeJitConstant("N_PADDED", n_padded_size),
MakeJitConstant("SIMD_WIDTH", tuning_data.simd_size),
MakeJitConstant("TILE_M", tuning_data.tile_m_size),
MakeJitConstant("TILE_K", tuning_data.tile_k_size),
@@ -141,6 +150,15 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
MakeJitConstant("TILE_K_LEFTOVER", leftover_k),
MakeJitConstant("TILE_N_LEFTOVER", leftover_n),
});

bool has_dynamic_k_padding = params.transpose_input0 ? params.inputs[0].Y().pad.is_dynamic
: params.inputs[0].X().pad.is_dynamic;
bool has_dynamic_n_padding = params.transpose_input1 ? params.inputs[1].Y().pad.is_dynamic
: params.inputs[1].X().pad.is_dynamic;
if (has_dynamic_k_padding)
jit.AddConstant(MakeJitConstant("HAS_DYNAMIC_K_PADDING", 1));
if (has_dynamic_n_padding)
jit.AddConstant(MakeJitConstant("HAS_DYNAMIC_N_PADDING", 1));
} else {
auto m_size = output.Y().v;
auto n_size = output.X().v;
@@ -153,6 +171,8 @@ JitConstants GemmKernelTiledOpt::GetJitConstants(const gemm_params& params) cons
MakeJitConstant("M", m_size),
MakeJitConstant("K", k_size),
MakeJitConstant("N", n_size),
MakeJitConstant("K_PADDED_IN0", k_size),
MakeJitConstant("N_PADDED", n_size),
MakeJitConstant("SIMD_WIDTH", tuning_data.simd_size),
MakeJitConstant("TILE_M", tuning_data.tile_m_size),
MakeJitConstant("TILE_K", tuning_data.tile_k_size),
@@ -235,10 +255,24 @@ bool GemmKernelTiledOpt::Validate(const Params& params, const optional_params& o
return false;

const auto& gmm_params = static_cast<const gemm_params&>(params);
for (auto input : gmm_params.inputs) {
// Only supports outer padding as first element offset
if (input.X().pad.Total() != 0 || input.Y().pad.Total() != 0 || input.Z().pad.Total() != 0 ||
input.Feature().pad.Total() != 0)

if (gmm_params.outputs[0].PitchesDifferFromLogicalDims())
return false;

for (size_t input_idx = 0; input_idx < gmm_params.inputs.size(); ++input_idx) {
auto& input = gmm_params.inputs[input_idx];
// Supports outer padding as a first-element offset, and dynamic padding of the Batch, Feature, X and Y dimensions
// of the first and second inputs when the kernel is shape-agnostic
bool proper_pad_f = input.Feature().pad.is_dynamic ? false : input.Feature().pad.Total() == 0;
bool proper_pad_x = input.X().pad.is_dynamic ? false : input.X().pad.Total() == 0;
bool proper_pad_y = input.Y().pad.is_dynamic ? false : input.Y().pad.Total() == 0;
if (gmm_params.is_shape_agnostic && input_idx < 2) {
proper_pad_f |= input.Feature().pad.is_dynamic;
proper_pad_x |= input.X().pad.is_dynamic;
proper_pad_y |= input.Y().pad.is_dynamic;
}

if (!proper_pad_x || !proper_pad_y || input.Z().pad.Total() != 0 || !proper_pad_f)
return false;
}

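On the validation side, the relaxed rule amounts to: a Feature/X/Y pad is acceptable only when it is zero, or when it is dynamic and the kernel is shape-agnostic and the tensor is one of the first two inputs; Z must stay unpadded and the output pitches must match the logical dims. A compact restatement of the per-dimension predicate, with illustrative names rather than the production code:

```cpp
#include <cstddef>

// Illustrative restatement of the per-dimension padding check added to Validate().
bool pad_acceptable(bool pad_is_dynamic, std::size_t pad_total,
                    bool is_shape_agnostic, std::size_t input_idx) {
    if (pad_is_dynamic)
        return is_shape_agnostic && input_idx < 2;  // dynamic pads only for inputs 0/1 of the shape-agnostic kernel
    return pad_total == 0;                          // static pads on F/X/Y are still rejected
}
```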
157 changes: 157 additions & 0 deletions src/plugins/intel_gpu/tests/unit/test_cases/gemm_gpu_test.cpp
@@ -359,6 +359,159 @@ class gemm_gpu_tests: public ::testing::Test {
}
}

void test_dynamic_padding(bool is_caching_test) {
tests::random_generator rg;
rg.set_seed(GET_SUITE_NAME);

auto& engine = get_test_engine();

const unsigned long BATCH_SIZE = 31;
const unsigned long M_SIZE = 11;
const unsigned long K_SIZE = 37;
const unsigned long N_SIZE = 49;

auto fill_mem = [&](cldnn::memory_ptr mem, std::vector<ov::float16>& data) {
cldnn::mem_lock<ov::float16> mem_ptr(mem, get_test_stream());
auto&& l = mem->get_layout();
auto data_idx = 0;
for (cldnn::tensor::value_type b = 0; b < l.batch(); ++b) {
for (cldnn::tensor::value_type f = 0; f < l.feature(); ++f) {
for (cldnn::tensor::value_type y = 0; y < l.spatial(1); ++y) {
for (cldnn::tensor::value_type x = 0; x < l.spatial(0); ++x) {
auto tensor_coord = cldnn::tensor{{b, f, x, y}, 0};
auto buffer_idx = l.get_linear_offset(tensor_coord);
mem_ptr[buffer_idx] = data[data_idx++];
}
}
}
}
};

const auto align_size_m = 13;
const auto align_size_k = 16;
const auto align_size_n = 15;
const auto align_size_b1 = 3;
const auto align_size_b2 = 19;

const auto aligned_batch1_size = align_to(1ul, align_size_b1);
auto padding_size_batch1 = static_cast<int>(aligned_batch1_size - 1);

const auto aligned_batch2_size = align_to(BATCH_SIZE, align_size_b2);
auto padding_size_batch2 = static_cast<int>(aligned_batch2_size - BATCH_SIZE);

const auto aligned_m_size = align_to(M_SIZE, align_size_m);
auto padding_size_m = static_cast<int>(aligned_m_size - M_SIZE);
const auto aligned_k_size = align_to(K_SIZE, align_size_k);
auto padding_size_k = static_cast<int>(aligned_k_size - K_SIZE);
const auto aligned_n_size = align_to(N_SIZE, align_size_n);
auto padding_size_n = static_cast<int>(aligned_n_size - N_SIZE);

ov::Shape in1_shape = { 1, BATCH_SIZE, M_SIZE, K_SIZE };
ov::Shape in2_shape = { 1, BATCH_SIZE, K_SIZE, N_SIZE };
ov::Shape in1_shape_aligned = { aligned_batch1_size, aligned_batch2_size, aligned_m_size, aligned_k_size };
ov::Shape in2_shape_aligned = { aligned_batch1_size, aligned_batch2_size, aligned_k_size, aligned_n_size };

// Use dynamic padding for all BFYX dimensions
tensor dyn_pad_dims_input({1, 1, 1, 1}, 0);

auto in1_layout = layout{ {-1, -1, -1, -1}, data_types::f16, format::bfyx, padding({0, 0, 0, 0}, {0, 0, 0, 0}, 0.0f, dyn_pad_dims_input)};
auto in2_layout = layout{ {-1, -1, -1, -1}, data_types::f16, format::bfyx, padding({0, 0, 0, 0}, {0, 0, 0, 0}, 0.0f, dyn_pad_dims_input)};

auto aligned_input1_mem = engine.allocate_memory({ov::PartialShape(in1_shape_aligned), data_types::f16, format::bfyx});
auto aligned_input2_mem = engine.allocate_memory({ov::PartialShape(in2_shape_aligned), data_types::f16, format::bfyx});

auto input1_mem = engine.reinterpret_buffer(*aligned_input1_mem, layout{ov::PartialShape(in1_shape),
data_types::f16,
format::bfyx,
padding({padding_size_batch1, 0, 0, 0},
{0, padding_size_batch2, padding_size_k, padding_size_m}, 0.0f, dyn_pad_dims_input)});

auto input2_mem = engine.reinterpret_buffer(*aligned_input2_mem, layout{ov::PartialShape(in2_shape),
data_types::f16,
format::bfyx,
padding({0, padding_size_batch2, 0, 0},
{padding_size_batch1, 0, padding_size_n, padding_size_k}, 0.0f, dyn_pad_dims_input)});

auto input_1_data = rg.generate_random_1d<ov::float16>(ov::shape_size(in1_shape), -2, 2);
auto input_2_data = rg.generate_random_1d<ov::float16>(ov::shape_size(in2_shape), -2, 2);

fill_mem(input1_mem, input_1_data);
fill_mem(input2_mem, input_2_data);

auto get_ref_results = [&]() {
ov::Shape in1_shape = { 1, BATCH_SIZE, M_SIZE, K_SIZE };
ov::Shape in2_shape = { 1, BATCH_SIZE, K_SIZE, N_SIZE };
auto in1_layout = layout{ {-1, -1, -1, -1}, data_types::f16, format::bfyx};
auto in2_layout = layout{ {-1, -1, -1, -1}, data_types::f16, format::bfyx};

auto input1_mem = engine.allocate_memory(layout{ov::PartialShape(in1_shape), data_types::f16, format::bfyx});
auto input2_mem = engine.allocate_memory(layout{ov::PartialShape(in2_shape), data_types::f16, format::bfyx});

fill_mem(input1_mem, input_1_data);
fill_mem(input2_mem, input_2_data);

topology topology;
topology.add(input_layout("input1", in1_layout),
input_layout("input2", in2_layout),
gemm("gemm_ref", { input_info("input1"), input_info("input2") }, data_types::f16, false, false, 1.0f, 0.0f, 4, 4)
);

auto config = get_test_default_config(engine);
config.set_property(ov::intel_gpu::optimize_data(true));
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
config.set_property(ov::enable_profiling(true));

network network(engine, topology, config);
network.set_input_data("input1", input1_mem);
network.set_input_data("input2", input2_mem);

auto outputs = network.execute();
OPENVINO_ASSERT(outputs.size() == 1);
OPENVINO_ASSERT(outputs.begin()->first == "gemm_ref");

auto inst = network.get_primitive("gemm_ref");

auto output_mem = outputs.at("gemm_ref").get_memory();
auto output_layout = outputs.at("gemm_ref").get_layout();

return engine.reinterpret_buffer(*output_mem, output_layout);
};

topology topology;
topology.add(input_layout("input1", in1_layout),
input_layout("input2", in2_layout),
gemm("gemm", { input_info("input1"), input_info("input2") }, data_types::f16, false, false, 1.0f, 0.0f, 4, 4)
);

ExecutionConfig config = get_test_default_config(engine);
config.set_property(ov::intel_gpu::optimize_data(true));
config.set_property(ov::intel_gpu::allow_new_shape_infer(true));
config.set_property(ov::enable_profiling(true));
network::ptr network = get_network(engine, topology, config, get_test_stream_ptr(), is_caching_test);
network->set_input_data("input1", input1_mem);
network->set_input_data("input2", input2_mem);

auto inst = network->get_primitive("gemm");
auto impl = inst->get_impl();
ASSERT_TRUE(impl != nullptr);
ASSERT_TRUE(impl->is_dynamic());

auto outputs = network->execute();

auto output_mem = outputs.at("gemm").get_memory();
auto output_layout = outputs.at("gemm").get_layout();

auto res = engine.reinterpret_buffer(*output_mem, output_layout);

auto ref_res = get_ref_results();

mem_lock<ov::float16> res_lock(res, get_test_stream());
mem_lock<ov::float16> res_ref_lock(ref_res, get_test_stream());
for (size_t i = 0; i < res->count(); i++) {
ASSERT_EQ(res_lock[i], res_ref_lock[i]) << i;
}
}

void test_dynamic_multi_inference_same_shape(bool is_caching_test) {
auto& engine = get_test_engine();

@@ -549,6 +702,10 @@ TEST_F(gemm_gpu_tests, dynamic) {
this->test_dynamic(false);
}

TEST_F(gemm_gpu_tests, dynamic_padding) {
this->test_dynamic_padding(false);
}

TEST_F(gemm_gpu_tests, dynamic_multi_inference_same_shape) {
this->test_dynamic_multi_inference_same_shape(false);
}
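The new test rounds every dimension up to an arbitrary multiple, allocates the enlarged buffers, and then reinterprets them with padded layouts so that only the logical region is written; the reference result is computed from unpadded copies of the same data. The rounding helper is assumed to behave like this sketch (align_to itself comes from the test utilities):

```cpp
#include <cstddef>

// Illustrative round-up to the next multiple; assumed equivalent to the align_to
// helper used by the test for these unsigned sizes.
inline std::size_t align_up(std::size_t value, std::size_t multiple) {
    return ((value + multiple - 1) / multiple) * multiple;  // e.g. align_up(37, 16) == 48
}
```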
