Skip to content

Commit

Permalink
Fix unit-test failure.
Browse files Browse the repository at this point in the history
Signed-off-by: hyunback <[email protected]>
  • Loading branch information
hyunback committed Oct 10, 2024
1 parent 5f56898 commit 2ec0ae1
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
2 changes: 2 additions & 0 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2124,6 +2124,8 @@ memory::ptr primitive_inst::allocate_output(engine& _engine,
}
} else if (!_node.can_share_buffer() || _node.is_output()
|| (((impl_params.can_be_optimized() || (_node.can_be_optimized() != impl_params.can_be_optimized())) && !_node.is_runtime_skippable()))) {
// To use a memory pool, skippable should always be true.
// Concat and Crop should not use a memory pool if optimized_out changes at runtime because skippable is always false.
GPU_DEBUG_LOG << "[" << _node.id() << ": output]" << std::endl;
return ov::intel_gpu::allocate_memory_evenif_zero_bytes(_engine, layout, alloc_type, reset);
} else {
Expand Down
16 changes: 12 additions & 4 deletions src/plugins/intel_gpu/src/graph/program_node.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1548,12 +1548,20 @@ void program_node::create_onednn_primitive_attributes(
mem_desc.get_dims(), mem_desc.get_data_type());
} else if (is_type<gemm>()) {
size_t rank = cldnn::format::dimension(in.format);
auto in_pshape = in.get_partial_shape();
auto out_pshape = get_output_layout().get_partial_shape();
size_t ones_to_add = std::max(out_pshape.size(), static_cast<size_t>(4)) - in_pshape.size();
if (ones_to_add > 0) {
layout new_layout = in;
ov::PartialShape new_input_pshape;
std::vector<ov::Dimension> dims(in_pshape.begin(), in_pshape.begin() + in_pshape.size());
new_input_pshape = ov::PartialShape(dims);
new_input_pshape.insert(new_input_pshape.begin(), ones_to_add, 1ul);
new_layout.set_partial_shape(new_input_pshape);
in = new_layout;
}
size_t in_batched_size = in.count() / (in.spatial(0) * in.spatial(1));
dnnl::memory::dims dims = onednn::convert_gemm_tensor(in.get_tensor(), rank, in_batched_size == 1);
bool spatial_dims_can_be_removed = (in.spatial(0) * in.spatial(1) == 1);
if (dims.size() == 4 && spatial_dims_can_be_removed) {
dims.erase(dims.begin() + 2, dims.begin() + 4);
}
dnnl::memory::data_type dt = onednn::convert_data_type(in.data_type);
dnnl::memory::format_tag fmt = onednn::convert_gemm_data_format(dims, in.format);
post_ops.append_binary(alg, dnnl::memory::desc(dims, dt, fmt));
Expand Down

0 comments on commit 2ec0ae1

Please sign in to comment.