Skip to content

Commit

Permalink
[GPU] Eliminate Reshape sequences
Browse files Browse the repository at this point in the history
  • Loading branch information
sshlyapn committed Oct 29, 2024
1 parent dced7dd commit 09a6b20
Show file tree
Hide file tree
Showing 8 changed files with 289 additions and 49 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -346,6 +346,7 @@ static bool can_reshape_be_optimized(const reshape_node& node) {
if (!node.is_runtime_propagatable_padding()
&& node.get_input_layout(0).data_padding.is_dynamic()
&& !node.get_output_layout(0).data_padding.is_dynamic()) {
GPU_DEBUG_TRACE_DETAIL << "Check " << node.id() << " false\n";
return false;
}

Expand All @@ -357,9 +358,12 @@ static bool can_reshape_be_optimized(const reshape_node& node) {
node.get_users().front()->get_preferred_impl_type() == impl_types::onednn)
return true;

if (node.is_in_place())
if (node.is_in_place()) {
GPU_DEBUG_TRACE_DETAIL << "Check " << node.id() << " true\n";
return true;
}

GPU_DEBUG_TRACE_DETAIL << "Check " << node.id() << " false\n";
return false;
}

Expand Down Expand Up @@ -493,13 +497,18 @@ bool crop_in_place_optimization::match(const program_node& node,
}
}
if (user->is_type<reshape>()) {
GPU_DEBUG_TRACE_DETAIL << "Check " << user->id() << "\n";
// runtime buffer fusing is only handled when there is only one reshape user
if (node.is_dynamic() && node.get_users().size() != 1)
if (node.is_dynamic() && node.get_users().size() != 1) {
GPU_DEBUG_TRACE_DETAIL << "Check " << user->id() << " false\n";
return false;
}
auto& reshape_node = user->as<reshape>();
if (can_reshape_be_optimized(reshape_node) &&
(!node.is_dynamic() || !reshape_node.is_runtime_propagatable_padding()))
return false;
(!node.is_dynamic() || !reshape_node.is_runtime_propagatable_padding())) {
GPU_DEBUG_TRACE_DETAIL << "Check " << user->id() << " false\n";
return false;
}
}
if (user->is_type<experimental_detectron_roi_feature_extractor>() && user->get_dependency_index(node) == 0)
return false;
Expand Down Expand Up @@ -531,14 +540,19 @@ bool crop_in_place_optimization::match(const program_node& node,
// if output padding has defined padding across features already it wouldn't
// work because it expects to have zeros in the padded area.
if ((!node.is_dynamic() || is_runtime) &&
!is_optimizable_padding_for_crop(node, crop_layout, input_layout, crop_params.input_offsets[0]))
return false;
!is_optimizable_padding_for_crop(node, crop_layout, input_layout, crop_params.input_offsets[0])) {
GPU_DEBUG_TRACE_DETAIL << "Check " << node.id() << " false\n";
return false;
}
if (!(((!node.is_dynamic() || is_runtime) && can_crop_be_optimized_along_feature(crop_layout, input_layout))
|| can_crop_be_optimized_simple_data_format(crop_layout, input_layout)))
return false;
|| can_crop_be_optimized_simple_data_format(crop_layout, input_layout))) {
GPU_DEBUG_TRACE_DETAIL << "Check " << node.id() << " false\n";
return false;
}
} else {
return false;
}
GPU_DEBUG_TRACE_DETAIL << "Check " << node.id() << " true\n";
return true;
}

Expand Down Expand Up @@ -659,8 +673,21 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format
if (user_info.first && user_info.first->is_type<reshape>()) {
auto reshape_desc = user_info.first->as<reshape>().get_primitive();
auto reshape_mode = reshape_desc->mode;
auto reshape_axis = crop_axis;
auto reshape_dyn_pad_mask = padding::DynamicDimsMask();

if (reshape_mode == reshape::reshape_mode::base) {
user_info.second.data_padding._dynamic_dims_mask = dyn_pad_sizes;
auto reshape_ps = user_info.second.get_partial_shape();
auto crop_dim_val = crop_layout.get_partial_shape()[crop_axis].get_length();

int64_t mul = 1;
for (size_t i = reshape_ps.size(); i > 1; i--) {
if (reshape_ps[i - 1].is_dynamic() || mul == crop_dim_val)
break;

mul *= reshape_ps[i - 1].get_length();
reshape_axis = i - 1;
}
} else if (reshape_mode == reshape::reshape_mode::unsqueeze || reshape_mode == reshape::reshape_mode::squeeze) {
auto reshape_ps = user_info.second.get_partial_shape();
auto output_pattern = reshape_desc->output_pattern;
Expand All @@ -671,11 +698,10 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format
reshape_axis += reshape_mode == reshape::reshape_mode::unsqueeze ? 1 : -1;
}
}

padding::DynamicDimsMask dyn_pad_mask;
dyn_pad_mask[reshape_axis] = 1;
user_info.second.data_padding._dynamic_dims_mask = dyn_pad_mask;
}

reshape_dyn_pad_mask[reshape_axis] = 1;
user_info.second.data_padding._dynamic_dims_mask = reshape_dyn_pad_mask;
}
return;
}
Expand All @@ -695,21 +721,61 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format
upper_sizes.push_back(input_layout.spatial(i) - offsets.spatial[i] - crop_size.spatial[i]);
}

// Formats a vector of padding sizes for debug tracing. Every element is followed by ", "
// (including the last one), so {1, 2} becomes "1, 2, " and an empty vector becomes "".
// The lambda reads nothing from the enclosing scope, so it deliberately captures nothing.
auto print_arr = [](const std::vector<int32_t>& vec) {
    std::ostringstream ss;
    for (const auto value : vec) {
        ss << value << ", ";
    }
    return ss.str();
};

if (is_runtime) {
padding::DynamicDimsMask dyn_pad_sizes;
dyn_pad_sizes[crop_axis] = 1;
crop_layout.data_padding = padding(lower_sizes, upper_sizes, dyn_pad_sizes);
GPU_DEBUG_TRACE_DETAIL << "Set crop paddings: " << print_arr(lower_sizes) << " " << print_arr(upper_sizes) << "\n";
if (user_info.first) {
auto reshape_desc = user_info.first->as<reshape>().get_primitive();
auto reshape_mode = reshape_desc->mode;
if (reshape_mode == reshape::reshape_mode::base) {
auto reshape_rank = user_info.second.get_partial_shape().size();
auto reshape_last_dim = user_info.second.get_partial_shape().to_shape()[reshape_rank - 1];
if (lower_sizes[crop_axis])
lower_sizes[crop_axis] /= reshape_last_dim;
if (upper_sizes[crop_axis])
upper_sizes[crop_axis] /= reshape_last_dim;
user_info.second.data_padding = padding(lower_sizes, upper_sizes, dyn_pad_sizes);
// auto reshape_rank = user_info.second.get_partial_shape().size();
// auto reshape_last_dim = user_info.second.get_partial_shape().to_shape()[reshape_rank - 1];

auto reshape_ps = user_info.second.get_partial_shape();
auto crop_dim_val = crop_layout.get_partial_shape()[crop_axis].get_length();
auto reshape_axis = crop_axis;

int64_t divider = 1;
for (size_t i = reshape_ps.size(); i > 1; i--) {
const auto& dim_value = reshape_ps[i - 1].get_length();
if (reshape_ps[i - 1].is_dynamic() || divider * dim_value == crop_dim_val)
break;

divider *= reshape_ps[i - 1].get_length();
reshape_axis = i - 1;
}
reshape_axis -= 1;

const auto output_rank = std::max(reshape_ps.size(), static_cast<size_t>(4));
std::vector<int32_t> reshape_lower_sizes(output_rank, 0);
std::vector<int32_t> reshape_upper_sizes(output_rank, 0);
padding::DynamicDimsMask reshape_dyn_pad_mask;

reshape_lower_sizes[reshape_axis] = lower_sizes[crop_axis];
reshape_upper_sizes[reshape_axis] = upper_sizes[crop_axis];
reshape_dyn_pad_mask[reshape_axis] = 1;

if (reshape_lower_sizes[reshape_axis])
reshape_lower_sizes[reshape_axis] /= divider;
if (reshape_upper_sizes[reshape_axis])
reshape_upper_sizes[reshape_axis] /= divider;

user_info.second.data_padding = padding(reshape_lower_sizes, reshape_upper_sizes, reshape_dyn_pad_mask);
// if (lower_sizes[crop_axis])
// lower_sizes[crop_axis] /= reshape_last_dim;
// if (upper_sizes[crop_axis])
// upper_sizes[crop_axis] /= reshape_last_dim;
// user_info.second.data_padding = padding(lower_sizes, upper_sizes, dyn_pad_sizes);
} else {
auto reshape_ps = user_info.second.get_partial_shape();
auto output_pattern = reshape_desc->output_pattern;
Expand Down Expand Up @@ -810,6 +876,7 @@ void prepare_buffer_fusing::run(program& p) {
node.get_primitive()->axis,
false);
if (user_info.first) {
GPU_DEBUG_TRACE_DETAIL << "Update output layout with dynamic padding: " << user_info.second << "\n";
node.get_users().front()->set_output_layout(user_info.second);
}
}
Expand Down
39 changes: 29 additions & 10 deletions src/plugins/intel_gpu/src/graph/include/reshape_inst.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,39 +51,58 @@ struct typed_program_node<reshape> : public typed_program_node_base<reshape> {
}

// TODO: This function is to limit condition to a specific case (crop + reshape) among cases for the base mode
if (!input().is_type<crop>())
if (!input().is_type<crop>()) {
GPU_DEBUG_TRACE_DETAIL << " can't propagate " << id() << " " << "\n";
return false;
}

// oneDNN supports padded input of outer axis only for buffer fusing on static shape
if (!has_outer_padding_offset() && get_users().size() == 1 && get_users().front()->get_preferred_impl_type() == impl_types::onednn)
if (!has_outer_padding_offset() && get_users().size() == 1 && get_users().front()->get_preferred_impl_type() == impl_types::onednn) {
GPU_DEBUG_TRACE_DETAIL << " can't propagate " << id() << " " << "\n";
return false;
}

// TODO: If user is RoPE or MVN and dynamic padding exists, output padding propagation is not supported in the base mode
if (get_users().size() == 1 && (get_users().front()->is_type<rope>() || get_users().front()->is_type<mvn>()))
if (get_users().size() == 1 && get_users().front()->is_type<mvn>()) {
GPU_DEBUG_TRACE_DETAIL << " can't propagate " << id() << " " << "\n";
return false;
}

auto axis = input().as<crop>().get_primitive()->axis;
const auto& input_pshape = input().get_output_layout(false).get_partial_shape();
auto input_rank = input_pshape.size();
auto input_last_dim = static_cast<int64_t>(input_rank - 1);
if (axis != input_last_dim || input_pshape[input_last_dim].is_dynamic())
if (axis != input_last_dim || input_pshape[input_last_dim].is_dynamic()) {
GPU_DEBUG_TRACE_DETAIL << " can't propagate " << id() << " " << "\n";
return false;
}

auto input_last_dim_val = input_pshape[input_last_dim].get_length();
const auto& output_pshape = prim->output_partial_shape;
// TODO: If the reshape's output shape is not constant, an issue occurs
// during shape inference due to execution order at runtime
if ((output_pshape.size() != input_rank + 1) || prim->output_pattern.empty())
if ((output_pshape.size() != input_rank + 1) || prim->output_pattern.empty()) {
GPU_DEBUG_TRACE_DETAIL << " can't propagate " << id() << " " << "\n";
return false;
}

// TODO: fix comment
// Iteratively check the total product of all static innermost dimensions
// until the crop dimension value matches or the first dynamic dimension is met
int64_t mul = 1;
for (size_t i = input_rank - 1; i < output_pshape.size() ; i++) {
if (output_pshape[i].is_dynamic())
return false;
mul *= output_pshape[i].get_length();
for (size_t i = output_pshape.size(); i > 1 ; i--) {
if (output_pshape[i - 1].is_dynamic() || mul == input_last_dim_val)
break;

mul *= output_pshape[i - 1].get_length();
}
if (input_last_dim_val != mul)

if (input_last_dim_val != mul) {
GPU_DEBUG_TRACE_DETAIL << " can't propagate " << id() << " " << input_last_dim_val << " vs " << mul << "\n";
return false;
}

GPU_DEBUG_TRACE_DETAIL << " can propagate " << id() << "\n";

return true;
}
Expand Down
5 changes: 5 additions & 0 deletions src/plugins/intel_gpu/src/graph/primitive_inst.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1810,6 +1810,11 @@ primitive_inst::primitive_inst(network & network, program_node const& node, bool
_outputs = allocate_outputs();
}
}

if (_node) {
GPU_DEBUG_TRACE_DETAIL << _node->type()->to_string(*_node) << "\n";
}

_impls_factory = std::make_shared<ImplementationsFactory>(_node);
_impl_params->strm = _network.get_stream_ptr();
for (size_t i = 0; i < get_node().get_output_layouts().size(); ++i) {
Expand Down
30 changes: 11 additions & 19 deletions src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,11 @@ KERNEL(rope_ref)(
uint r = rf < HALF_ROTARY_NDIMS ? rf * 2 : 0;
uint f = rf < HEAD_SIZE - ROTARY_NDIMS ? rf * 2 : 0;

#ifdef ENABLE_SLICE
uint input_idx = GET_DATA_INDEX(SLICED_INPUT0, p, b, h * HEAD_SIZE, 0);

input_idx += SLICED_FROM_START * (p * INPUT0_FEATURE_NUM + b + 1)
+ SLICED_FROM_END * (p * INPUT0_FEATURE_NUM + b);
#else
uint input_idx = INPUT0_GET_INDEX(p, b, h * HEAD_SIZE, 0);
#ifdef ENABLE_SLICE
input_idx += SLICED_FROM_START;
#endif

uint cos_sin_p = p < INPUT1_BATCH_NUM ? p : 0;
uint cos_sin_b = b < INPUT1_FEATURE_NUM ? b : 0;
uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_p, cos_sin_b, 0, 0);
Expand Down Expand Up @@ -69,14 +66,11 @@ KERNEL(rope_ref)(
const uint h = (uint)get_global_id(2) / HALF_ROTARY_NDIMS;
const uint r = (uint)get_global_id(2) % HALF_ROTARY_NDIMS;

#ifdef ENABLE_SLICE
uint input_idx = GET_DATA_INDEX(SLICED_INPUT0, b, p, h * HEAD_SIZE, 0);

input_idx += SLICED_FROM_START * (b * INPUT0_FEATURE_NUM + p + 1)
+ SLICED_FROM_END * (b * INPUT0_FEATURE_NUM + p);
#else
uint input_idx = INPUT0_GET_INDEX(b, p, h * HEAD_SIZE, 0);
#ifdef ENABLE_SLICE
input_idx += SLICED_FROM_START;
#endif

uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0;
uint cos_sin_p = p + INPUT1_FEATURE_NUM - INPUT0_FEATURE_NUM < INPUT1_FEATURE_NUM ? p + INPUT1_FEATURE_NUM - INPUT0_FEATURE_NUM : 0;
uint cos_sin_h = h < INPUT1_SIZE_Y ? h : 0;
Expand Down Expand Up @@ -119,15 +113,13 @@ KERNEL(rope_ref)(
const uint p = (uint)get_global_id(2) / HALF_ROTARY_NDIMS;
const uint r = (uint)get_global_id(2) % HALF_ROTARY_NDIMS;

#ifdef ENABLE_SLICE
uint input_idx = GET_DATA_INDEX(SLICED_INPUT0, b, h, p, 0);

input_idx += SLICED_FROM_START * (b * INPUT0_FEATURE_NUM + h + 1)
+ SLICED_FROM_END * (b * INPUT0_FEATURE_NUM + h);
#elif ENABLE_TRANSPOSE
uint input_idx = GET_DATA_INDEX(TRANSPOSED_INPUT0, b, h, p, 0);
#if ENABLE_TRANSPOSE
uint input_idx = INPUT0_GET_INDEX(b, p, h, 0);
#else
uint input_idx = INPUT0_GET_INDEX(b, h, p, 0);
#ifdef ENABLE_SLICE
input_idx += SLICED_FROM_START;
#endif
#endif

uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0;
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
diff a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl (rejected hunks)
@@ -28,11 +28,14 @@ KERNEL(rope_ref)(
uint r = rf < HALF_ROTARY_NDIMS ? rf * 2 : 0;
uint f = rf < HEAD_SIZE - ROTARY_NDIMS ? rf * 2 : 0;

- uint input_idx = INPUT0_GET_INDEX(p, b, h * HEAD_SIZE, 0);
#ifdef ENABLE_SLICE
- input_idx += SLICED_FROM_START;
-#endif
+ uint input_idx = GET_DATA_INDEX(SLICED_INPUT0, p, b, h * HEAD_SIZE, 0);

+ input_idx += SLICED_FROM_START * (p * INPUT0_FEATURE_NUM + b + 1)
+ + SLICED_FROM_END * (p * INPUT0_FEATURE_NUM + b);
+#else
+ uint input_idx = INPUT0_GET_INDEX(p, b, h * HEAD_SIZE, 0);
+#endif
uint cos_sin_p = p < INPUT1_BATCH_NUM ? p : 0;
uint cos_sin_b = b < INPUT1_FEATURE_NUM ? b : 0;
uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_p, cos_sin_b, 0, 0);
@@ -66,11 +69,14 @@ KERNEL(rope_ref)(
const uint h = (uint)get_global_id(2) / HALF_ROTARY_NDIMS;
const uint r = (uint)get_global_id(2) % HALF_ROTARY_NDIMS;

- uint input_idx = INPUT0_GET_INDEX(b, p, h * HEAD_SIZE, 0);
#ifdef ENABLE_SLICE
- input_idx += SLICED_FROM_START;
-#endif
+ uint input_idx = GET_DATA_INDEX(SLICED_INPUT0, b, p, h * HEAD_SIZE, 0);

+ input_idx += SLICED_FROM_START * (b * INPUT0_FEATURE_NUM + p + 1)
+ + SLICED_FROM_END * (b * INPUT0_FEATURE_NUM + p);
+#else
+ uint input_idx = INPUT0_GET_INDEX(b, p, h * HEAD_SIZE, 0);
+#endif
uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0;
uint cos_sin_p = p + INPUT1_FEATURE_NUM - INPUT0_FEATURE_NUM < INPUT1_FEATURE_NUM ? p + INPUT1_FEATURE_NUM - INPUT0_FEATURE_NUM : 0;
uint cos_sin_h = h < INPUT1_SIZE_Y ? h : 0;
@@ -113,13 +119,15 @@ KERNEL(rope_ref)(
const uint p = (uint)get_global_id(2) / HALF_ROTARY_NDIMS;
const uint r = (uint)get_global_id(2) % HALF_ROTARY_NDIMS;

-#if ENABLE_TRANSPOSE
- uint input_idx = INPUT0_GET_INDEX(b, p, h, 0);
+#ifdef ENABLE_SLICE
+ uint input_idx = GET_DATA_INDEX(SLICED_INPUT0, b, h, p, 0);
+
+ input_idx += SLICED_FROM_START * (b * INPUT0_FEATURE_NUM + h + 1)
+ + SLICED_FROM_END * (b * INPUT0_FEATURE_NUM + h);
+#elif ENABLE_TRANSPOSE
+ uint input_idx = GET_DATA_INDEX(TRANSPOSED_INPUT0, b, h, p, 0);
#else
uint input_idx = INPUT0_GET_INDEX(b, h, p, 0);
-#ifdef ENABLE_SLICE
- input_idx += SLICED_FROM_START;
-#endif
#endif

uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0;
Loading

0 comments on commit 09a6b20

Please sign in to comment.