From 42fbf61b15de9b738781bda5b12705ec36d6223a Mon Sep 17 00:00:00 2001 From: Sergey Shlyapnikov Date: Wed, 30 Oct 2024 11:19:47 +0400 Subject: [PATCH] [GPU] Add subsequent reshapes optimization and dynamic paddings support for RoPE and PagedAttention --- .../graph_optimizer/prepare_buffer_fusing.cpp | 59 ++++++++--- .../src/graph/include/reshape_inst.h | 17 ++-- .../cl_kernels/pa_kv_cache_update_ref.cl | 47 +++++---- .../kernel_selector/cl_kernels/rope_ref.cl | 30 +++--- .../kernel_selector/cl_kernels/sdpa_opt.cl | 32 ++++-- .../optimize_subsequent_reshapes.cpp | 96 ++++++++++++++++++ .../optimize_subsequent_reshapes.hpp | 23 +++++ .../src/plugin/transformations_pipeline.cpp | 3 + .../optimize_subsequent_reshapes_test.cpp | 97 +++++++++++++++++++ 9 files changed, 338 insertions(+), 66 deletions(-) create mode 100644 src/plugins/intel_gpu/src/plugin/transformations/optimize_subsequent_reshapes.cpp create mode 100644 src/plugins/intel_gpu/src/plugin/transformations/optimize_subsequent_reshapes.hpp create mode 100644 src/plugins/intel_gpu/tests/unit/transformations/optimize_subsequent_reshapes_test.cpp diff --git a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp index 7bdbc53ad54d16..d4683f5a4f3667 100644 --- a/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp +++ b/src/plugins/intel_gpu/src/graph/graph_optimizer/prepare_buffer_fusing.cpp @@ -659,23 +659,33 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format if (user_info.first && user_info.first->is_type()) { auto reshape_desc = user_info.first->as().get_primitive(); auto reshape_mode = reshape_desc->mode; + auto reshape_axis = crop_axis; if (reshape_mode == reshape::reshape_mode::base) { - user_info.second.data_padding._dynamic_dims_mask = dyn_pad_sizes; + auto reshape_ps = user_info.second.get_partial_shape(); + auto crop_dim_val = crop_layout.get_partial_shape()[crop_axis].get_length(); + + int64_t mul = 1; + for (size_t i = reshape_ps.size(); i > 1; i--) { + if (reshape_ps[i - 1].is_dynamic() || mul == crop_dim_val) + break; + + mul *= reshape_ps[i - 1].get_length(); + reshape_axis = i - 1; + } } else if (reshape_mode == reshape::reshape_mode::unsqueeze || reshape_mode == reshape::reshape_mode::squeeze) { auto reshape_ps = user_info.second.get_partial_shape(); auto output_pattern = reshape_desc->output_pattern; - auto reshape_axis = crop_axis; for (size_t i = 0; i < output_pattern.size(); i++) { if (output_pattern[i] <= static_cast(reshape_axis)) { reshape_axis += reshape_mode == reshape::reshape_mode::unsqueeze ? 1 : -1; } } - - padding::DynamicDimsMask dyn_pad_mask; - dyn_pad_mask[reshape_axis] = 1; - user_info.second.data_padding._dynamic_dims_mask = dyn_pad_mask; } + + auto reshape_dyn_pad_mask = padding::DynamicDimsMask(); + reshape_dyn_pad_mask[reshape_axis] = 1; + user_info.second.data_padding._dynamic_dims_mask = reshape_dyn_pad_mask; } return; } @@ -703,13 +713,36 @@ void crop_in_place_optimization::update_in_place_crop_padding_simple_data_format auto reshape_desc = user_info.first->as().get_primitive(); auto reshape_mode = reshape_desc->mode; if (reshape_mode == reshape::reshape_mode::base) { - auto reshape_rank = user_info.second.get_partial_shape().size(); - auto reshape_last_dim = user_info.second.get_partial_shape().to_shape()[reshape_rank - 1]; - if (lower_sizes[crop_axis]) - lower_sizes[crop_axis] /= reshape_last_dim; - if (upper_sizes[crop_axis]) - upper_sizes[crop_axis] /= reshape_last_dim; - user_info.second.data_padding = padding(lower_sizes, upper_sizes, dyn_pad_sizes); + auto reshape_ps = user_info.second.get_partial_shape(); + auto crop_dim_val = crop_layout.get_partial_shape()[crop_axis].get_length(); + + auto divider = 1; + auto reshape_axis = reshape_ps.size(); + for (size_t i = reshape_ps.size(); i > 1; i--) { + const auto& dim_value = reshape_ps[i - 1].get_length(); + if (divider * dim_value == crop_dim_val) + break; + + divider *= dim_value; + reshape_axis = i - 1; + } + reshape_axis -= 1; + + const auto output_rank = std::max(reshape_ps.size(), static_cast(4)); + std::vector reshape_lower_sizes(output_rank, 0); + std::vector reshape_upper_sizes(output_rank, 0); + padding::DynamicDimsMask reshape_dyn_pad_mask; + + reshape_lower_sizes[reshape_axis] = lower_sizes[crop_axis]; + reshape_upper_sizes[reshape_axis] = upper_sizes[crop_axis]; + reshape_dyn_pad_mask[reshape_axis] = 1; + + if (reshape_lower_sizes[reshape_axis]) + reshape_lower_sizes[reshape_axis] /= divider; + if (reshape_upper_sizes[reshape_axis]) + reshape_upper_sizes[reshape_axis] /= divider; + + user_info.second.data_padding = padding(reshape_lower_sizes, reshape_upper_sizes, reshape_dyn_pad_mask); } else { auto reshape_ps = user_info.second.get_partial_shape(); auto output_pattern = reshape_desc->output_pattern; diff --git a/src/plugins/intel_gpu/src/graph/include/reshape_inst.h b/src/plugins/intel_gpu/src/graph/include/reshape_inst.h index 1bbfd94256a50c..1ceb491af497eb 100644 --- a/src/plugins/intel_gpu/src/graph/include/reshape_inst.h +++ b/src/plugins/intel_gpu/src/graph/include/reshape_inst.h @@ -9,6 +9,7 @@ #include "crop_inst.h" #include "rope_inst.h" #include "mvn_inst.h" +#include "paged_attention_inst.h" #include "primitive_inst.h" #include @@ -59,7 +60,7 @@ struct typed_program_node : public typed_program_node_base { return false; // TODO: If user is RoPE or MVN and dynamic padding exists, ouput padding propagation is not supported in the base mode - if (get_users().size() == 1 && (get_users().front()->is_type() || get_users().front()->is_type())) + if (get_users().size() == 1 && get_users().front()->is_type()) return false; auto axis = input().as().get_primitive()->axis; @@ -73,15 +74,19 @@ struct typed_program_node : public typed_program_node_base { const auto& output_pshape = prim->output_partial_shape; // TODO: If the reshape's output shape is non constant, issue occurs // during shape inference due to execution order at runtime - if ((output_pshape.size() != input_rank + 1) || prim->output_pattern.empty()) + if (prim->output_pattern.empty()) return false; + // Iteratively check the total product of all static innermost dimensions + // until the crop dimension value matches or the first dynamic dimension is encountered int64_t mul = 1; - for (size_t i = input_rank - 1; i < output_pshape.size() ; i++) { - if (output_pshape[i].is_dynamic()) - return false; - mul *= output_pshape[i].get_length(); + for (size_t i = output_pshape.size(); i > 1 ; i--) { + if (output_pshape[i - 1].is_dynamic() || mul == input_last_dim_val) + break; + + mul *= output_pshape[i - 1].get_length(); } + if (input_last_dim_val != mul) return false; diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_update_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_update_ref.cl index ef2f78496b2cf2..ec7bdb6f9f9209 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_update_ref.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/pa_kv_cache_update_ref.cl @@ -34,10 +34,14 @@ KERNEL(pa_kv_cache_update)( const uint seq_block_idx = block_indices_begins[seq_idx] + seq_len / PAGED_ATTENTION_BLOCK_SIZE; const uint block_idx = block_indices[seq_block_idx]; - uint key_value_in_offset = seq_idx * KV_HEADS_NUM * HEAD_SIZE + head_idx * HEAD_SIZE; + uint key_in_offset = INPUT0_PAD_BEFORE_FEATURE_NUM + + seq_idx * (KV_HEADS_NUM * HEAD_SIZE + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM) + + head_idx * HEAD_SIZE; + uint value_in_offset = INPUT1_PAD_BEFORE_FEATURE_NUM + + seq_idx * (KV_HEADS_NUM * HEAD_SIZE + INPUT1_PAD_BEFORE_FEATURE_NUM + INPUT1_PAD_AFTER_FEATURE_NUM) + + head_idx * HEAD_SIZE; uint key_out_offset = block_idx * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + current_token_pos_in_block; - uint value_out_offset = block_idx * KV_HEADS_NUM * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + head_idx * HEAD_SIZE * PAGED_ATTENTION_BLOCK_SIZE + current_token_pos_in_block * HEAD_SIZE; #define READ_BLOCK_SIZE GENERATE_STAGE_BLOCK_SIZE @@ -45,7 +49,7 @@ KERNEL(pa_kv_cache_update)( #define BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, READ_BLOCK_SIZE, ptr, offset); #define DATA_VEC MAKE_VECTOR_TYPE(INPUT0_TYPE, READ_BLOCK_SIZE) - DATA_VEC input_data = BLOCK_READ(key_data, key_value_in_offset + head_idx_index); + DATA_VEC input_data = BLOCK_READ(key_data, key_in_offset + head_idx_index); unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) { uint key_offset = key_out_offset + (head_idx_index + sglid + SUBGROUP_SIZE * i) * PAGED_ATTENTION_BLOCK_SIZE; @@ -56,7 +60,7 @@ KERNEL(pa_kv_cache_update)( #endif } - input_data = BLOCK_READ(value_data, key_value_in_offset + head_idx_index); + input_data = BLOCK_READ(value_data, value_in_offset + head_idx_index); unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) { uint value_offset = value_out_offset + head_idx_index + sglid + SUBGROUP_SIZE * i; @@ -83,8 +87,13 @@ KERNEL(pa_kv_cache_update)( const uint token_start_pos = (past_len + block_start_pos - subsequence_begin_idx) % PAGED_ATTENTION_BLOCK_SIZE; - uint key_value_in_offset = block_start_pos * KV_HEADS_NUM * HEAD_SIZE + - head_idx * HEAD_SIZE; + uint key_in_offset = INPUT0_PAD_BEFORE_FEATURE_NUM + + block_start_pos * (KV_HEADS_NUM * HEAD_SIZE + INPUT0_PAD_AFTER_FEATURE_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM) + + head_idx * HEAD_SIZE; + + uint value_in_offset = INPUT1_PAD_BEFORE_FEATURE_NUM + + block_start_pos * (KV_HEADS_NUM * HEAD_SIZE + INPUT1_PAD_AFTER_FEATURE_NUM + INPUT1_PAD_BEFORE_FEATURE_NUM) + + head_idx * HEAD_SIZE; const uint current_block_idx = (past_len + block_start_pos - subsequence_begin_idx) / PAGED_ATTENTION_BLOCK_SIZE; @@ -106,14 +115,14 @@ KERNEL(pa_kv_cache_update)( #define BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, READ_BLOCK_SIZE, ptr, offset); #define DATA_VEC MAKE_VECTOR_TYPE(INPUT0_TYPE, READ_BLOCK_SIZE) - DATA_VEC input_data = BLOCK_READ(key_data, key_value_in_offset + head_idx_index); + DATA_VEC input_data = BLOCK_READ(key_data, key_in_offset + head_idx_index); unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) { uint key_offset = key_out_offset + (head_idx_index + sglid + SUBGROUP_SIZE * i) * PAGED_ATTENTION_BLOCK_SIZE; key_cache_data[key_offset] = input_data[i]; } - input_data = BLOCK_READ(value_data, key_value_in_offset + head_idx_index); + input_data = BLOCK_READ(value_data, value_in_offset + head_idx_index); unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) { uint value_offset = value_out_offset + head_idx_index + sglid + SUBGROUP_SIZE * i; @@ -126,14 +135,14 @@ KERNEL(pa_kv_cache_update)( #define BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, READ_BLOCK_SIZE, ptr, offset); #define DATA_VEC MAKE_VECTOR_TYPE(INPUT0_TYPE, READ_BLOCK_SIZE) - DATA_VEC input_data = BLOCK_READ(key_data, key_value_in_offset + head_idx_index); + DATA_VEC input_data = BLOCK_READ(key_data, key_in_offset + head_idx_index); unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) { uint key_offset = key_out_offset + (head_idx_index + sglid + SUBGROUP_SIZE * i) * PAGED_ATTENTION_BLOCK_SIZE; key_cache_data[key_offset] = input_data[i]; } - input_data = BLOCK_READ(value_data, key_value_in_offset + head_idx_index); + input_data = BLOCK_READ(value_data, value_in_offset + head_idx_index); unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) { uint value_offset = value_out_offset + head_idx_index + sglid + SUBGROUP_SIZE * i; @@ -146,14 +155,14 @@ KERNEL(pa_kv_cache_update)( #define BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, READ_BLOCK_SIZE, ptr, offset); #define DATA_VEC MAKE_VECTOR_TYPE(INPUT0_TYPE, READ_BLOCK_SIZE) - DATA_VEC input_data = BLOCK_READ(key_data, key_value_in_offset + head_idx_index); + DATA_VEC input_data = BLOCK_READ(key_data, key_in_offset + head_idx_index); unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) { uint key_offset = key_out_offset + (head_idx_index + sglid + SUBGROUP_SIZE * i) * PAGED_ATTENTION_BLOCK_SIZE; key_cache_data[key_offset] = input_data[i]; } - input_data = BLOCK_READ(value_data, key_value_in_offset + head_idx_index); + input_data = BLOCK_READ(value_data, value_in_offset + head_idx_index); unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) { uint value_offset = value_out_offset + head_idx_index + sglid + SUBGROUP_SIZE * i; @@ -166,14 +175,14 @@ KERNEL(pa_kv_cache_update)( #define BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, READ_BLOCK_SIZE, ptr, offset); #define DATA_VEC MAKE_VECTOR_TYPE(INPUT0_TYPE, READ_BLOCK_SIZE) - DATA_VEC input_data = BLOCK_READ(key_data, key_value_in_offset + head_idx_index); + DATA_VEC input_data = BLOCK_READ(key_data, key_in_offset + head_idx_index); unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) { uint key_offset = key_out_offset + (head_idx_index + sglid + SUBGROUP_SIZE * i) * PAGED_ATTENTION_BLOCK_SIZE; key_cache_data[key_offset] = input_data; } - input_data = BLOCK_READ(value_data, key_value_in_offset + head_idx_index); + input_data = BLOCK_READ(value_data, value_in_offset + head_idx_index); unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) { uint value_offset = value_out_offset + head_idx_index + sglid + SUBGROUP_SIZE * i; @@ -181,7 +190,8 @@ KERNEL(pa_kv_cache_update)( } } - key_value_in_offset += KV_HEADS_NUM * HEAD_SIZE; + key_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT0_PAD_AFTER_FEATURE_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM); + value_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT1_PAD_AFTER_FEATURE_NUM + INPUT1_PAD_BEFORE_FEATURE_NUM); key_out_offset += 1; value_out_offset += HEAD_SIZE; } @@ -194,14 +204,14 @@ KERNEL(pa_kv_cache_update)( #define BLOCK_READ(ptr, offset) BLOCK_READN(INPUT0_TYPE, READ_BLOCK_SIZE, ptr, offset); #define DATA_VEC MAKE_VECTOR_TYPE(INPUT0_TYPE, READ_BLOCK_SIZE) - DATA_VEC input_data = BLOCK_READ(key_data, key_value_in_offset + head_idx_index); + DATA_VEC input_data = BLOCK_READ(key_data, key_in_offset + head_idx_index); unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) { uint key_offset = key_out_offset + (head_idx_index + sglid + SUBGROUP_SIZE * i) * PAGED_ATTENTION_BLOCK_SIZE; key_cache_data[key_offset] = input_data; } - input_data = BLOCK_READ(value_data, key_value_in_offset + head_idx_index); + input_data = BLOCK_READ(value_data, value_in_offset + head_idx_index); unroll_for (uint i = 0; i < READ_BLOCK_SIZE; i++) { uint value_offset = value_out_offset + head_idx_index + sglid + SUBGROUP_SIZE * i; @@ -209,7 +219,8 @@ KERNEL(pa_kv_cache_update)( } } - key_value_in_offset += KV_HEADS_NUM * HEAD_SIZE; + key_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT0_PAD_AFTER_FEATURE_NUM + INPUT0_PAD_BEFORE_FEATURE_NUM); + value_in_offset += (KV_HEADS_NUM * HEAD_SIZE + INPUT1_PAD_AFTER_FEATURE_NUM + INPUT1_PAD_BEFORE_FEATURE_NUM); key_out_offset += 1; value_out_offset += HEAD_SIZE; } diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl index 38066b4461def4..133440a21301f2 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/rope_ref.cl @@ -28,14 +28,11 @@ KERNEL(rope_ref)( uint r = rf < HALF_ROTARY_NDIMS ? rf * 2 : 0; uint f = rf < HEAD_SIZE - ROTARY_NDIMS ? rf * 2 : 0; -#ifdef ENABLE_SLICE - uint input_idx = GET_DATA_INDEX(SLICED_INPUT0, p, b, h * HEAD_SIZE, 0); - - input_idx += SLICED_FROM_START * (p * INPUT0_FEATURE_NUM + b + 1) - + SLICED_FROM_END * (p * INPUT0_FEATURE_NUM + b); -#else uint input_idx = INPUT0_GET_INDEX(p, b, h * HEAD_SIZE, 0); +#ifdef ENABLE_SLICE + input_idx += SLICED_FROM_START; #endif + uint cos_sin_p = p < INPUT1_BATCH_NUM ? p : 0; uint cos_sin_b = b < INPUT1_FEATURE_NUM ? b : 0; uint cos_sin_idx = INPUT1_GET_INDEX(cos_sin_p, cos_sin_b, 0, 0); @@ -69,14 +66,11 @@ KERNEL(rope_ref)( const uint h = (uint)get_global_id(2) / HALF_ROTARY_NDIMS; const uint r = (uint)get_global_id(2) % HALF_ROTARY_NDIMS; -#ifdef ENABLE_SLICE - uint input_idx = GET_DATA_INDEX(SLICED_INPUT0, b, p, h * HEAD_SIZE, 0); - - input_idx += SLICED_FROM_START * (b * INPUT0_FEATURE_NUM + p + 1) - + SLICED_FROM_END * (b * INPUT0_FEATURE_NUM + p); -#else uint input_idx = INPUT0_GET_INDEX(b, p, h * HEAD_SIZE, 0); +#ifdef ENABLE_SLICE + input_idx += SLICED_FROM_START; #endif + uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0; uint cos_sin_p = p + INPUT1_FEATURE_NUM - INPUT0_FEATURE_NUM < INPUT1_FEATURE_NUM ? p + INPUT1_FEATURE_NUM - INPUT0_FEATURE_NUM : 0; uint cos_sin_h = h < INPUT1_SIZE_Y ? h : 0; @@ -119,15 +113,13 @@ KERNEL(rope_ref)( const uint p = (uint)get_global_id(2) / HALF_ROTARY_NDIMS; const uint r = (uint)get_global_id(2) % HALF_ROTARY_NDIMS; -#ifdef ENABLE_SLICE - uint input_idx = GET_DATA_INDEX(SLICED_INPUT0, b, h, p, 0); - - input_idx += SLICED_FROM_START * (b * INPUT0_FEATURE_NUM + h + 1) - + SLICED_FROM_END * (b * INPUT0_FEATURE_NUM + h); -#elif ENABLE_TRANSPOSE - uint input_idx = GET_DATA_INDEX(TRANSPOSED_INPUT0, b, h, p, 0); +#if ENABLE_TRANSPOSE + uint input_idx = INPUT0_GET_INDEX(b, p, h, 0); #else uint input_idx = INPUT0_GET_INDEX(b, h, p, 0); +#ifdef ENABLE_SLICE + input_idx += SLICED_FROM_START; +#endif #endif uint cos_sin_b = b < INPUT1_BATCH_NUM ? b : 0; diff --git a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl index 748f79115262e0..2a7d2a58b7383a 100644 --- a/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl +++ b/src/plugins/intel_gpu/src/kernel_selector/cl_kernels/sdpa_opt.cl @@ -733,9 +733,10 @@ KERNEL(sdpa_opt)( #if IS_PAGED_ATTENTION const uint block_start_pos = blocked_indexes_start[target_seq_dim]; const uint block_end_pos = blocked_indexes_end[target_seq_dim]; - - uint query_offset = block_start_pos * HEAD_SIZE * NUM_HEADS + num_heads_dim * HEAD_SIZE + head_size_idx; - const uint query_pitch = HEAD_SIZE * NUM_HEADS; + uint query_offset = INPUT0_PAD_BEFORE_FEATURE_NUM + + block_start_pos * (HEAD_SIZE * NUM_HEADS + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM) + + num_heads_dim * HEAD_SIZE + head_size_idx; + const uint query_pitch = (HEAD_SIZE * NUM_HEADS + INPUT0_PAD_BEFORE_FEATURE_NUM + INPUT0_PAD_AFTER_FEATURE_NUM); const uint cur_target_seq_len_size = block_end_pos - block_start_pos; #else @@ -834,8 +835,11 @@ KERNEL(sdpa_opt)( const uint heads_dim = num_heads_dim; #endif #define KEY_SEQ_OFFSET subsequence_begins[gws_seq_indexes_correspondence[target_seq_dim]] - uint key_offset = KEY_SEQ_OFFSET * HEAD_SIZE * NUM_KV_HEADS + heads_dim * HEAD_SIZE + seq_len * HEAD_SIZE * NUM_KV_HEADS; - const uint key_pitch = HEAD_SIZE * NUM_KV_HEADS; + uint key_offset = INPUT1_PAD_BEFORE_FEATURE_NUM + + KEY_SEQ_OFFSET * (HEAD_SIZE * NUM_KV_HEADS + INPUT1_PAD_BEFORE_FEATURE_NUM + INPUT1_PAD_AFTER_FEATURE_NUM) + + heads_dim * HEAD_SIZE + + seq_len * (HEAD_SIZE * NUM_KV_HEADS + INPUT1_PAD_BEFORE_FEATURE_NUM + INPUT1_PAD_AFTER_FEATURE_NUM); + const uint key_pitch = (HEAD_SIZE * NUM_KV_HEADS + INPUT1_PAD_BEFORE_FEATURE_NUM + INPUT1_PAD_AFTER_FEATURE_NUM); #else #ifdef BEAM_TABLE_TYPE const uint b_idx = beam_table[FUNC_CALL(get_bt_index_key)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, seq_len + sglid, 0)]; @@ -1018,7 +1022,7 @@ KERNEL(sdpa_opt)( // QK*V calculation MAKE_VECTOR_TYPE(OUTPUT_TYPE, TARGET_SEQ_LEN_BLOCK_SIZE) acc_output_res = OUTPUT_VAL_ZERO; #if IS_PAGED_ATTENTION - const uint value_pitch = HEAD_SIZE * NUM_KV_HEADS; + const uint value_pitch = (HEAD_SIZE * NUM_KV_HEADS + INPUT2_PAD_BEFORE_FEATURE_NUM + INPUT2_PAD_AFTER_FEATURE_NUM); #else #ifdef INPUT2_DIMS_ORDER uint value_offset_base = FUNC_CALL(get_input2_index)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, 0, 0); @@ -1028,7 +1032,6 @@ KERNEL(sdpa_opt)( const uint value_pitch = HEAD_SIZE; #endif #endif - if (partition_seq_len == SEQ_LEN_PARTITION_SIZE) { uint seq_len_start = (sgid / (SUBGROUPS_PER_WG / SG_SCALE_FACTOR)) * (SEQ_LEN_PARTITION_SIZE / SG_SCALE_FACTOR); for (uint seq_len = seq_len_start; seq_len < seq_len_start + (SEQ_LEN_PARTITION_SIZE / SG_SCALE_FACTOR); seq_len += SUBGROUP_SIZE) { @@ -1039,7 +1042,10 @@ KERNEL(sdpa_opt)( const uint heads_dim = num_heads_dim; #endif const uint value_seq_offset = subsequence_begins[gws_seq_indexes_correspondence[target_seq_dim]]; - uint value_offset = value_seq_offset * HEAD_SIZE * NUM_KV_HEADS + heads_dim * HEAD_SIZE + (start_partition_idx + (seq_len)) * HEAD_SIZE * NUM_KV_HEADS + head_size_idx; + uint value_offset = INPUT2_PAD_BEFORE_FEATURE_NUM + + value_seq_offset * (HEAD_SIZE * NUM_KV_HEADS + INPUT2_PAD_BEFORE_FEATURE_NUM + INPUT2_PAD_AFTER_FEATURE_NUM) + + heads_dim * HEAD_SIZE + + (start_partition_idx + (seq_len)) * (HEAD_SIZE * NUM_KV_HEADS + INPUT2_PAD_BEFORE_FEATURE_NUM + INPUT2_PAD_AFTER_FEATURE_NUM) + head_size_idx; #else #ifdef BEAM_TABLE_TYPE const uint b_idx = beam_table[FUNC_CALL(get_bt_index_value)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + (seq_len) + sglid, sgid * SUBGROUP_SIZE)]; @@ -1087,7 +1093,10 @@ KERNEL(sdpa_opt)( const uint heads_dim = num_heads_dim; #endif const uint value_seq_offset = subsequence_begins[gws_seq_indexes_correspondence[target_seq_dim]]; - uint value_offset = value_seq_offset * HEAD_SIZE * NUM_KV_HEADS + heads_dim * HEAD_SIZE + (start_partition_idx + (seq_len * SUBGROUP_SIZE)) * HEAD_SIZE * NUM_KV_HEADS + head_size_idx; + uint value_offset = INPUT2_PAD_BEFORE_FEATURE_NUM + + value_seq_offset * (HEAD_SIZE * NUM_KV_HEADS + INPUT2_PAD_BEFORE_FEATURE_NUM + INPUT2_PAD_AFTER_FEATURE_NUM) + + heads_dim * HEAD_SIZE + + (start_partition_idx + (seq_len * SUBGROUP_SIZE)) * (HEAD_SIZE * NUM_KV_HEADS + INPUT2_PAD_BEFORE_FEATURE_NUM + INPUT2_PAD_AFTER_FEATURE_NUM) + head_size_idx; #else #ifdef BEAM_TABLE_TYPE const uint b_idx = beam_table[FUNC_CALL(get_bt_index_value)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + (seq_len * SUBGROUP_SIZE) + sglid, sgid * SUBGROUP_SIZE)]; @@ -1138,7 +1147,10 @@ KERNEL(sdpa_opt)( const uint heads_dim = num_heads_dim; #endif const uint value_seq_offset = subsequence_begins[gws_seq_indexes_correspondence[target_seq_dim]]; - uint value_offset = value_seq_offset * HEAD_SIZE * NUM_KV_HEADS + heads_dim * HEAD_SIZE + (start_partition_idx + seq_len_leftovers_start) * HEAD_SIZE * NUM_KV_HEADS + head_size_idx; + uint value_offset = INPUT2_PAD_BEFORE_FEATURE_NUM + + value_seq_offset * (HEAD_SIZE * NUM_KV_HEADS + INPUT2_PAD_BEFORE_FEATURE_NUM + INPUT2_PAD_AFTER_FEATURE_NUM) + + heads_dim * HEAD_SIZE + + (start_partition_idx + seq_len_leftovers_start) * (HEAD_SIZE * NUM_KV_HEADS + INPUT2_PAD_BEFORE_FEATURE_NUM + INPUT2_PAD_AFTER_FEATURE_NUM) + head_size_idx; #else #ifdef BEAM_TABLE_TYPE const uint b_idx = beam_table[FUNC_CALL(get_bt_index_value)(OPTIONAL_SHAPE_INFO_TENSOR b0_idx, b1_idx, 0, 0, start_partition_idx + seq_len_leftovers_start + sglid, sgid * SUBGROUP_SIZE)]; diff --git a/src/plugins/intel_gpu/src/plugin/transformations/optimize_subsequent_reshapes.cpp b/src/plugins/intel_gpu/src/plugin/transformations/optimize_subsequent_reshapes.cpp new file mode 100644 index 00000000000000..8aece52897aef9 --- /dev/null +++ b/src/plugins/intel_gpu/src/plugin/transformations/optimize_subsequent_reshapes.cpp @@ -0,0 +1,96 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "optimize_subsequent_reshapes.hpp" + +#include "openvino/core/rt_info.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/concat.hpp" +#include "openvino/op/reshape.hpp" +#include "openvino/pass/pattern/op/or.hpp" +#include "openvino/pass/pattern/op/wrap_type.hpp" +#include "transformations/utils/utils.hpp" + +namespace ov { +namespace intel_gpu { + +OptimizeSubsequentReshapes::OptimizeSubsequentReshapes() { + using namespace ov::pass::pattern; + using ov::pass::pattern::op::Or; + + auto dynamic_batch_only = [](Output output) { + const auto& shape = output.get_partial_shape(); + + if (shape.size() <= 1) + return false; + + if (shape[0].is_static()) + return false; + + for (size_t i = 1; i < shape.size(); i++) + if (shape[i].is_dynamic()) + return false; + + return true; + }; + + auto first_reshape_data = any_input(ov::pass::pattern::all_of({ dynamic_batch_only, ov::pass::pattern::consumers_count(1) })); + auto first_reshape_pattern = ov::pass::pattern::wrap_type(); + auto first_reshape = wrap_type({ first_reshape_data, first_reshape_pattern }, + ov::pass::pattern::all_of({ dynamic_batch_only, ov::pass::pattern::consumers_count(1) })); + + auto second_reshape_pattern = ov::pass::pattern::wrap_type(); + auto second_reshape = wrap_type({ first_reshape, second_reshape_pattern }, dynamic_batch_only); + + ov::matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) { + const auto& pattern_map = m.get_pattern_value_map(); + + auto input_node = pattern_map.at(first_reshape_data).get_node_shared_ptr(); + auto first_reshape_node = pattern_map.at(first_reshape).get_node_shared_ptr(); + auto second_reshape_node = pattern_map.at(second_reshape).get_node_shared_ptr(); + + auto input_ps = first_reshape_node->input(0).get_partial_shape(); + auto first_reshape_ps = first_reshape_node->get_output_partial_shape(0); + auto second_reshape_ps = second_reshape_node->get_output_partial_shape(0); + + auto static_dims_product = [](ov::PartialShape& ps) { + int64_t total_dims = 1; + + for (auto& dim : ps) { + if (dim.is_static()) + total_dims *= dim.get_length(); + } + + return total_dims; + }; + + if (static_dims_product(input_ps) != static_dims_product(first_reshape_ps) || + static_dims_product(first_reshape_ps) != static_dims_product(second_reshape_ps)) + return false; + + std::vector new_pattern; + for (auto& dim : second_reshape_ps) { + if (dim.is_dynamic()) { + new_pattern.push_back(0); + } else { + new_pattern.push_back(dim.get_length()); + } + } + + auto new_pattern_const = std::make_shared(ov::element::i32, ov::Shape{new_pattern.size()}, new_pattern); + auto new_reshape = std::make_shared(first_reshape_node->input(0).get_source_output(), new_pattern_const, true); + new_reshape->set_friendly_name(second_reshape_node->get_friendly_name()); + + ov::replace_node(second_reshape_node, new_reshape); + copy_runtime_info(first_reshape_node, new_reshape); + + return true; + }; + + auto m = std::make_shared(second_reshape, "OptimizeSubsequentReshapes"); + this->register_matcher(m, callback); +} + +} // namespace intel_gpu +} // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations/optimize_subsequent_reshapes.hpp b/src/plugins/intel_gpu/src/plugin/transformations/optimize_subsequent_reshapes.hpp new file mode 100644 index 00000000000000..a194e230463ca9 --- /dev/null +++ b/src/plugins/intel_gpu/src/plugin/transformations/optimize_subsequent_reshapes.hpp @@ -0,0 +1,23 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include "openvino/pass/graph_rewrite.hpp" + +namespace ov { +namespace intel_gpu { + +/** + * @brief This pass looks for specific patterns of subsequent reshapes that can be + * merged together. + */ +class OptimizeSubsequentReshapes : public ov::pass::MatcherPass { +public: + OPENVINO_RTTI("OptimizeSubsequentReshapes", "0"); + OptimizeSubsequentReshapes(); +}; + +} // namespace intel_gpu +} // namespace ov diff --git a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp index 4b72385663bf9d..c74c29b85b29b3 100644 --- a/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp +++ b/src/plugins/intel_gpu/src/plugin/transformations_pipeline.cpp @@ -80,6 +80,7 @@ #include "plugin/transformations/increase_position_ids_precision.hpp" #include "plugin/transformations/group_norm_composition.hpp" #include "plugin/transformations/dynamic_quantize_fully_connected.hpp" +#include "plugin/transformations/optimize_subsequent_reshapes.hpp" #include "transformations/common_optimizations/nop_elimination.hpp" #include "transformations/common_optimizations/rms_fusion.hpp" #include "transformations/common_optimizations/broadcast_elementwise_fusion.hpp" @@ -901,6 +902,8 @@ void TransformationsPipeline::apply(std::shared_ptr func) { pass_config->disable(); pass_config->disable(); + manager.register_pass(); + manager.register_pass(); // This Validate is needed for proper data type propagation after applying IncreasePositionIdsPrecision pass manager.register_pass(); diff --git a/src/plugins/intel_gpu/tests/unit/transformations/optimize_subsequent_reshapes_test.cpp b/src/plugins/intel_gpu/tests/unit/transformations/optimize_subsequent_reshapes_test.cpp new file mode 100644 index 00000000000000..732a14be03bf39 --- /dev/null +++ b/src/plugins/intel_gpu/tests/unit/transformations/optimize_subsequent_reshapes_test.cpp @@ -0,0 +1,97 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include + +#include +#include +#include + +#include "openvino/pass/manager.hpp" +#include "openvino/core/model.hpp" +#include "openvino/core/coordinate_diff.hpp" +#include "openvino/core/type/element_type.hpp" +#include "openvino/op/constant.hpp" +#include "openvino/op/reshape.hpp" + +#include +#include +#include + +#include "common_test_utils/ov_test_utils.hpp" + +using namespace testing; +using namespace ov::intel_gpu; + +TEST_F(TransformationTestsF, OptimizeSubsequentReshapes1) { + { + auto input = std::make_shared(ov::element::i64, ov::PartialShape{ -1, 1, 4096 }); + auto first_reshape_pattern = std::make_shared(ov::element::i32, ov::Shape{4}, std::vector{ 0, 0, 32, 128 }); + auto first_reshape = std::make_shared(input, first_reshape_pattern, true); + + auto second_reshape_pattern = std::make_shared(ov::element::i32, ov::Shape{2}, std::vector{ 0, -1 }); + auto second_reshape = std::make_shared(first_reshape, second_reshape_pattern, true); + auto result = std::make_shared(second_reshape); + + model = std::make_shared(ov::NodeVector{ result }, ov::ParameterVector{ input }); + manager.register_pass(); + } + { + auto input = std::make_shared(ov::element::i64, ov::PartialShape{ -1, 1, 4096 }); + auto reshape_pattern = std::make_shared(ov::element::i32, ov::Shape{2}, std::vector{ 0, 4096 }); + auto reshape = std::make_shared(input, reshape_pattern, true); + auto result = std::make_shared(reshape); + + model_ref = std::make_shared(ov::NodeVector{ result }, ov::ParameterVector{ input }); + } + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, OptimizeSubsequentReshapes2) { + { + auto input = std::make_shared(ov::element::i64, ov::PartialShape{ -1, 1, 4096 }); + auto first_reshape_pattern = std::make_shared(ov::element::i32, ov::Shape{4}, std::vector{ 0, 0, 32, 128 }); + auto first_reshape = std::make_shared(input, first_reshape_pattern, true); + + auto second_reshape_pattern = std::make_shared(ov::element::i32, ov::Shape{4}, std::vector{ 0, 32, 1, 0 }); + auto second_reshape = std::make_shared(first_reshape, second_reshape_pattern, true); + auto result = std::make_shared(second_reshape); + + model = std::make_shared(ov::NodeVector{ result }, ov::ParameterVector{ input }); + manager.register_pass(); + } + { + auto input = std::make_shared(ov::element::i64, ov::PartialShape{ -1, 1, 4096 }); + auto reshape_pattern = std::make_shared(ov::element::i32, ov::Shape{4}, std::vector{ 0, 32, 1, 128 }); + auto reshape = std::make_shared(input, reshape_pattern, true); + auto result = std::make_shared(reshape); + + model_ref = std::make_shared(ov::NodeVector{ result }, ov::ParameterVector{ input }); + } + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +} + +TEST_F(TransformationTestsF, OptimizeSubsequentReshapes3) { + { + auto input = std::make_shared(ov::element::i64, ov::PartialShape{ -1, 32, 1, 128 }); + auto first_reshape_pattern = std::make_shared(ov::element::i32, ov::Shape{4}, std::vector{ 0, 1, 32, 0 }); + auto first_reshape = std::make_shared(input, first_reshape_pattern, true); + + auto second_reshape_pattern = std::make_shared(ov::element::i32, ov::Shape{2}, std::vector{ 0, -1 }); + auto second_reshape = std::make_shared(first_reshape, second_reshape_pattern, true); + auto result = std::make_shared(second_reshape); + + model = std::make_shared(ov::NodeVector{ result }, ov::ParameterVector{ input }); + manager.register_pass(); + } + { + auto input = std::make_shared(ov::element::i64, ov::PartialShape{ -1, 32, 1, 128 }); + auto reshape_pattern = std::make_shared(ov::element::i32, ov::Shape{2}, std::vector{ 0, 4096 }); + auto reshape = std::make_shared(input, reshape_pattern, true); + auto result = std::make_shared(reshape); + + model_ref = std::make_shared(ov::NodeVector{ result }, ov::ParameterVector{ input }); + } + comparator.enable(FunctionsComparator::CmpValues::ATTRIBUTES); +}