diff --git a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp index ec49dd7152fed1..b0e94891719b95 100644 --- a/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/src/transformations/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -575,7 +575,8 @@ ov::pass::RoPEFusionChatGLM::RoPEFusionChatGLM(int split_output_id, const bool s auto const_target_shape_2 = makeConst({batch, 1, seq_len, ndims / 2, 2}); // Slice cos_sin_cache to support 2-dimentional RoPE - auto ScatterUpdate = makePattern({{0, 0}, {1}, seq_length, {0}}, {}); + auto zero_seqlen = makePattern({{0}, seq_length}, {{"axis", 0}}); + auto ScatterUpdate = makePattern({{0, 0}, {1}, seq_length, {0}}, {}) | zero_seqlen; auto slice_Slice_449_1d = makePattern({cos_sin_cache, {0}, seq_length, {1}, {1}}); auto slice_Slice_449_2d = makePattern({cos_sin_cache, {0, 0}, ScatterUpdate, {1, 1}, {0}}); auto slice_StridedSlice_449 = GenStridedSlice(cos_sin_cache, {0, 0}, ScatterUpdate, {1, 1}, 1); diff --git a/src/common/transformations/src/transformations/symbolic_transformations/symbol_optimization.cpp b/src/common/transformations/src/transformations/symbolic_transformations/symbol_optimization.cpp index 55f0794e0ee008..a7e8c5044ff111 100644 --- a/src/common/transformations/src/transformations/symbolic_transformations/symbol_optimization.cpp +++ b/src/common/transformations/src/transformations/symbolic_transformations/symbol_optimization.cpp @@ -18,6 +18,7 @@ #include "openvino/op/squeeze.hpp" #include "openvino/op/util/multi_subgraph_base.hpp" #include "openvino/op/util/op_types.hpp" +#include "transformations/symbolic_transformations/utils.hpp" #include "transformations/utils/utils.hpp" namespace { @@ -84,27 +85,28 @@ int64_t get_idx_of_symbol_in_source(const ov::Output& source, const st } ov::Output alternative_source_from_existing_value(const std::shared_ptr& symbol, - const ov::Output& original_output, + const ov::Shape& original_shape, + const ov::element::Type& original_et, + const std::shared_ptr& node_to_copy_rt_info, STS_map& symbol_value_source) { auto alternative_source = ov::Output(); if (symbol_value_source.count(symbol)) { alternative_source = symbol_value_source[symbol]; - const auto &original_shape = original_output.get_shape(), &alternative_shape = alternative_source.get_shape(); - const auto &original_et = original_output.get_element_type(), - &alternative_et = alternative_source.get_element_type(); + const auto& alternative_shape = alternative_source.get_shape(); + const auto& alternative_et = alternative_source.get_element_type(); if (alternative_shape != original_shape && (original_shape.empty() || original_shape == ov::Shape{0})) { auto squeeze = std::make_shared(alternative_source); - ov::copy_runtime_info(original_output.get_node_shared_ptr(), squeeze); + ov::copy_runtime_info(node_to_copy_rt_info, squeeze); alternative_source = squeeze->output(0); } else if (alternative_shape != original_shape) { auto shape = ov::op::v0::Constant::create(ov::element::i64, {original_shape.size()}, original_shape); auto reshape = std::make_shared(alternative_source, shape, false); - ov::copy_runtime_info(original_output.get_node_shared_ptr(), reshape); + ov::copy_runtime_info(node_to_copy_rt_info, reshape); alternative_source = reshape->output(0); } if (alternative_et != original_et) { auto convert = std::make_shared(alternative_source, original_et); - ov::copy_runtime_info(original_output.get_node_shared_ptr(), convert); + ov::copy_runtime_info(node_to_copy_rt_info, convert); alternative_source = convert->output(0); } } @@ -113,7 +115,9 @@ ov::Output alternative_source_from_existing_value(const std::shared_pt ov::Output alternative_source_from_shape_source(const STS_map& symbol_shape_source, const std::shared_ptr& symbol, - const ov::Output& original_output, + const ov::Shape& original_shape, + const ov::element::Type& original_et, + const std::shared_ptr& node_to_copy_rt_info, STS_map& symbol_value_source) { auto alternative_source = ov::Output(); if (symbol_shape_source.count(symbol)) { @@ -122,39 +126,61 @@ ov::Output alternative_source_from_shape_source(const STS_map& symbol_ const int64_t& idx = get_idx_of_symbol_in_source(source, symbol); if (idx == -1) return alternative_source; - const auto& original_et = original_output.get_element_type(); std::shared_ptr shape; if (original_et == ov::element::i32 || original_et == ov::element::i64) { shape = std::make_shared(source, original_et); } else { shape = std::make_shared(source); - ov::copy_runtime_info(original_output.get_node_shared_ptr(), shape); + ov::copy_runtime_info(node_to_copy_rt_info, shape); shape = std::make_shared(shape, original_et); } - auto indices = ov::op::v0::Constant::create(ov::element::i64, original_output.get_shape(), {idx}); + auto indices = ov::op::v0::Constant::create(ov::element::i64, original_shape, {idx}); auto axis = ov::op::v0::Constant::create(ov::element::i64, {}, {0}); auto gather = std::make_shared(shape, indices, axis); - ov::copy_runtime_info(original_output.get_node_shared_ptr(), {shape, indices, axis, gather}); + ov::copy_runtime_info(node_to_copy_rt_info, {shape, indices, axis, gather}); alternative_source = gather; symbol_value_source[symbol] = alternative_source; } return alternative_source; } -ov::Output get_alternative_source_from_value_or_shape_source(const STS_map& symbol_shape_source, - const std::shared_ptr& symbol, - const ov::Output& original_output, - STS_map& symbol_value_source) { +ov::Output get_alternative_source_from_value_or_shape_source( + const STS_map& symbol_shape_source, + const std::shared_ptr& symbol, + const ov::Shape& original_shape, + const ov::element::Type& original_et, + const std::shared_ptr& node_to_copy_rt_info, + STS_map& symbol_value_source) { auto alternative_source = ov::Output(); if (symbol == nullptr) return alternative_source; - alternative_source = alternative_source_from_existing_value(symbol, original_output, symbol_value_source); + alternative_source = alternative_source_from_existing_value(symbol, + original_shape, + original_et, + node_to_copy_rt_info, + symbol_value_source); if (!alternative_source.get_node_shared_ptr()) - alternative_source = - alternative_source_from_shape_source(symbol_shape_source, symbol, original_output, symbol_value_source); + alternative_source = alternative_source_from_shape_source(symbol_shape_source, + symbol, + original_shape, + original_et, + node_to_copy_rt_info, + symbol_value_source); return alternative_source; } +ov::Output get_alternative_source_from_value_or_shape_source(const STS_map& symbol_shape_source, + const std::shared_ptr& symbol, + const ov::Output& original_output, + STS_map& symbol_value_source) { + return get_alternative_source_from_value_or_shape_source(symbol_shape_source, + symbol, + original_output.get_shape(), + original_output.get_element_type(), + original_output.get_node_shared_ptr(), + symbol_value_source); +} + ov::Output alternative_source_from_concat_input_sources(const STS_map& symbol_shape_source, const std::shared_ptr& symbol, const ov::Output& original_output, @@ -198,7 +224,9 @@ ov::Output alternative_source_from_concat_input_sources(const STS_map& return alternative_source; } -void optimize_value_usage(ov::Output& output, STS_map& symbol_shape_source, STS_map& symbol_value_source) { +void optimize_single_value_usage(ov::Output& output, + STS_map& symbol_shape_source, + STS_map& symbol_value_source) { auto value_symbols = output.get_tensor().get_value_symbol(); if (value_symbols.size() != 1) return; @@ -316,16 +344,16 @@ std::vector> topological_order(const std::shared_ptr& op, STS_map& symbol_shape_source) { - if (ov::is_type(op) || ov::is_type(op)) { - const auto& output = op->input_value(0); + const auto is_shape_of = ov::is_type(op); + const auto is_parameter = ov::is_type(op); + if (is_shape_of || is_parameter) { + const auto& output = is_shape_of ? op->input_value(0) : op->output(0); if (output.get_partial_shape().rank().is_dynamic()) return; for (const auto& d : output.get_partial_shape()) { - if (d.is_static()) - continue; - auto symbol = d.get_symbol(); - if (symbol == nullptr) + if (d.is_static() || d.get_symbol() == nullptr) continue; + auto symbol = ov::symbol::ancestor_of(d.get_symbol()); if (symbol_shape_source.count(symbol)) continue; symbol_shape_source[symbol] = output; @@ -344,11 +372,9 @@ void save_shape_sources(const std::shared_ptr& op, STS_map& symbol_sha if (input.get_partial_shape().rank().is_dynamic()) continue; const auto dimension = input.get_partial_shape()[axis]; - if (dimension.is_static()) - continue; - auto symbol = dimension.get_symbol(); - if (symbol == nullptr) + if (dimension.is_static() || dimension.get_symbol() == nullptr) continue; + auto symbol = ov::symbol::ancestor_of(dimension.get_symbol()); if (symbol_shape_source.count(symbol)) continue; symbol_shape_source[symbol] = input; @@ -402,27 +428,73 @@ struct OutputValue { } }; -void save_and_update_value_sources(const std::shared_ptr& op, - std::map>& multi_symbol_source) { - for (auto& output : op->outputs()) { - if (output.get_tensor().get_value_symbol().size() < 2) - continue; // singular values are handled by optimize_value_usage helper - - if (auto result = OutputValue::make(output)) { - if (multi_symbol_source.count(*result)) { - auto alternative_source = multi_symbol_source[*result]; - if (output.get_element_type() != alternative_source.get_element_type()) { - auto convert = std::make_shared(alternative_source, output.get_element_type()); - ov::copy_runtime_info(output.get_node_shared_ptr(), convert); - alternative_source = convert->output(0); - } - if (output.get_partial_shape().is_dynamic() || - output.get_partial_shape() != alternative_source.get_partial_shape()) - continue; - output.replace(alternative_source); +void optimize_multi_value_usage(ov::Output& output, + std::map>& multi_symbol_source, + STS_map& symbol_shape_source, + STS_map& symbol_value_source) { + if (output.get_tensor().get_value_symbol().size() < 2) + return; // singular values are handled by optimize_single_value_usage helper + const auto result = OutputValue::make(output); + if (!result) + return; + if (multi_symbol_source.count(*result)) { + // multiple value source have been seen before + auto alternative_source = multi_symbol_source[*result]; + if (output.get_element_type() != alternative_source.get_element_type()) { + auto convert = std::make_shared(alternative_source, output.get_element_type()); + ov::copy_runtime_info(output.get_node_shared_ptr(), convert); + alternative_source = convert->output(0); + } + if (output.get_partial_shape() != alternative_source.get_partial_shape()) { + const auto& shape = ov::op::v0::Constant::create(ov::element::i32, + ov::Shape{output.get_shape().size()}, + output.get_shape()); + alternative_source = std::make_shared(alternative_source, shape, false)->output(0); + } + output.replace(alternative_source); + } else { + // new instance of multiple value source + ov::OutputVector to_be_concated; + for (const auto& el : result->value) { + if (el.is()) { + const auto& value = el.as(); + const auto& constant = ov::op::v0::Constant::create(output.get_element_type(), ov::Shape{1}, {value}); + to_be_concated.push_back(constant->output(0)); + } else if (el.is>()) { + const auto& symbol = el.as>(); + auto alternative_output = + get_alternative_source_from_value_or_shape_source(symbol_shape_source, + symbol, + ov::Shape{1}, + output.get_element_type(), + output.get_node_shared_ptr(), + symbol_value_source); + if (alternative_output.get_node_shared_ptr()) + to_be_concated.push_back(alternative_output); + else + break; } else { - multi_symbol_source[*result] = output; + break; + } + } + if (to_be_concated.size() != ov::shape_size(output.get_shape())) { + multi_symbol_source[*result] = output; + } else { + auto alternative_output = std::make_shared(to_be_concated, 0)->output(0); + ov::copy_runtime_info(output.get_node_shared_ptr(), alternative_output.get_node_shared_ptr()); + if (output.get_partial_shape() != alternative_output.get_partial_shape()) { + alternative_output = std::make_shared( + alternative_output, + ov::op::v0::Constant::create(ov::element::i32, + ov::Shape{output.get_shape().size()}, + output.get_shape()), + false) + ->output(0); + ov::copy_runtime_info(output.get_node_shared_ptr(), alternative_output.get_node_shared_ptr()); } + ov::util::evaluate_both_bounds(alternative_output); + output.replace(alternative_output); + multi_symbol_source[*result] = alternative_output; } } } @@ -443,13 +515,14 @@ bool ov::pass::OptimizeSymbolsUsedAsValues::run_on_model(const std::shared_ptr(op)) continue; - // LTS maps aren't shared with sub-graphs because inner graph can not access outer graph for label sources + // LTS maps aren't shared with sub-graphs because inner graph can not access outer graph for symbol sources ov::op::util::process_subgraph(*this, op); - for (auto& output : op->outputs()) - optimize_value_usage(output, symbol_shape_source, symbol_value_source); + for (auto& output : op->outputs()) { + optimize_single_value_usage(output, symbol_shape_source, symbol_value_source); + optimize_multi_value_usage(output, multi_symbol_source, symbol_shape_source, symbol_value_source); + } save_shape_sources(op, symbol_shape_source); - save_and_update_value_sources(op, multi_symbol_source); } return true; } diff --git a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp index a42e11120d7276..e93eae340713d7 100644 --- a/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp +++ b/src/common/transformations/tests/common_optimizations/fuse_rotary_positional_embeddings.cpp @@ -13,6 +13,7 @@ #include "openvino/opsets/opset3.hpp" #include "ov_ops/rotary_positional_embeddings.hpp" #include "ov_ops/type_relaxed.hpp" +#include "transformations/symbolic_transformations/symbolic_optimizations.hpp" #include "transformations/utils/gen_pattern.hpp" using namespace testing; @@ -124,6 +125,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_LLama2_no_gather) { const size_t num_head = 32; model = buildROPE_Llama2(batch, seq_length, max_position_embeddings, ndims, false); + manager.register_pass(); manager.register_pass(); { @@ -159,6 +161,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_LLama2_with_gather) { const size_t num_head = 32; model = buildROPE_Llama2(batch, seq_length, max_position_embeddings, ndims, true); + manager.register_pass(); manager.register_pass(); { @@ -300,6 +303,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_GPTNEOX_no_gather) { const int max_position_embeddings = 2048; model = buildROPE_GPTNEOX(batch, seq_len, max_position_embeddings, ndims, num_heads, rotary_ndims, false); + manager.register_pass(); manager.register_pass(); { auto input = @@ -335,6 +339,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_GPTNEOX_with_gather) { const int max_position_embeddings = 2048; model = buildROPE_GPTNEOX(batch, seq_len, max_position_embeddings, ndims, num_heads, rotary_ndims, true); + manager.register_pass(); manager.register_pass(); { auto cos_sin = makeCosSinCache(max_position_embeddings, rotary_ndims); @@ -456,6 +461,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_GPTJ) { model = std::make_shared(ov::NodeVector{permute_Transpose_828}, ov::ParameterVector{input, gather_sin_cos}); } + manager.register_pass(); manager.register_pass(); { auto input = @@ -643,6 +649,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_chatGML_Slice) { model = std::make_shared(ov::NodeVector{cat_Concat}, ov::ParameterVector{input, seq_length, cos_sin_cache}); } + manager.register_pass(); manager.register_pass(); { auto input = std::make_shared(ov::element::f32, ov::Shape{seq_len, batch, 4608}); @@ -728,6 +735,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_GPTJ_Slice) { model = std::make_shared(ov::NodeVector{permute_Transpose}, ov::ParameterVector{input, gather_sin_cos}); } + manager.register_pass(); manager.register_pass(); { auto input = @@ -1007,6 +1015,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_Flux_mul) { model = std::make_shared(ov::NodeVector{y}, ov::ParameterVector{x, t_cos, t_sin}); } + manager.register_pass(); manager.register_pass(true); { auto x = @@ -1061,6 +1070,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_Flux_squeeze_mul_unsqueeze) { model = std::make_shared(ov::NodeVector{y}, ov::ParameterVector{x, t_cos, t_sin}); } + manager.register_pass(); manager.register_pass(true); { auto x = @@ -1115,6 +1125,7 @@ TEST_F(TransformationTestsF, ConvertToROPE_Flux_mul_squeeze_unsqueeze) { model = std::make_shared(ov::NodeVector{y}, ov::ParameterVector{x, t_cos, t_sin}); } + manager.register_pass(); manager.register_pass(true); { auto x = diff --git a/src/common/transformations/tests/symbolic_transformations/symbol_optimization.cpp b/src/common/transformations/tests/symbolic_transformations/symbol_optimization.cpp index e4653ec084bafb..2070a2bce7d349 100644 --- a/src/common/transformations/tests/symbolic_transformations/symbol_optimization.cpp +++ b/src/common/transformations/tests/symbolic_transformations/symbol_optimization.cpp @@ -19,7 +19,6 @@ #include "openvino/pass/visualize_tree.hpp" #include "transformations/common_optimizations/shared_ops_optimization.hpp" #include "transformations/symbolic_transformations/symbolic_optimizations.hpp" -#include "transformations/symbolic_transformations/utils.hpp" using namespace ov; using namespace ov::op; @@ -174,7 +173,7 @@ TEST_F(TransformationTestsF, ValueOptimizationDoubleValue) { auto input = make_shared(element::f32, PartialShape::dynamic(4)); auto dim_0 = get_dim_by_idx(input, {-1, -2}, element::i64); - auto dim_1 = get_dim_by_idx(input, {3, 2}, element::i32); + auto dim_1 = get_dim_by_idx(input, {3, 2}, element::i64); auto reshape_0 = make_shared( input, @@ -182,28 +181,25 @@ TEST_F(TransformationTestsF, ValueOptimizationDoubleValue) { false); auto reshape_1 = make_shared( input, - make_shared(OutputVector{v0::Constant::create(element::i32, {1}, {0}), dim_1}, 0), + make_shared(OutputVector{v0::Constant::create(element::i64, {1}, {0}), dim_1}, 0), false); model = make_shared(NodeVector{reshape_0, reshape_1}, ParameterVector{input}); manager.set_per_pass_validation(false); - manager.register_pass(); - manager.register_pass(); - manager.register_pass(); + manager.register_pass(); } { auto input = make_shared(element::f32, PartialShape::dynamic(4)); - auto dim_0 = get_dim_by_idx(input, {3, 2}, element::i32); - auto dim_1 = std::make_shared(dim_0, element::i64); + auto dim_0 = get_dim_by_idx(input, {3, 2}, element::i64); auto reshape_0 = make_shared( input, - make_shared(OutputVector{v0::Constant::create(element::i64, {1}, {-1}), dim_1}, 0), + make_shared(OutputVector{v0::Constant::create(element::i64, {1}, {-1}), dim_0}, 0), false); auto reshape_1 = make_shared( input, - make_shared(OutputVector{v0::Constant::create(element::i32, {1}, {0}), dim_0}, 0), + make_shared(OutputVector{v0::Constant::create(element::i64, {1}, {0}), dim_0}, 0), false); model_ref = make_shared(NodeVector{reshape_0, reshape_1}, ParameterVector{input}); @@ -216,7 +212,7 @@ TEST_F(TransformationTestsF, ValueOptimizationSymbolAndValue) { auto input = make_shared(element::f32, PartialShape({-1, -1, 4, -1})); auto dim_0 = get_dim_by_idx(input, {-1, -2}, element::i64); - auto dim_1 = get_dim_by_idx(input, {3, 2}, element::i32); + auto dim_1 = get_dim_by_idx(input, {3, 2}, element::i64); auto reshape_0 = make_shared( input, @@ -224,7 +220,7 @@ TEST_F(TransformationTestsF, ValueOptimizationSymbolAndValue) { false); auto reshape_1 = make_shared( input, - make_shared(OutputVector{v0::Constant::create(element::i32, {1}, {-1}), dim_1}, 0), + make_shared(OutputVector{v0::Constant::create(element::i64, {1}, {-1}), dim_1}, 0), false); model = make_shared(NodeVector{reshape_0, reshape_1}, ParameterVector{input}); @@ -236,12 +232,12 @@ TEST_F(TransformationTestsF, ValueOptimizationSymbolAndValue) { } { auto input = make_shared(element::f32, PartialShape({-1, -1, 4, -1})); - auto dim_0 = make_shared( - OutputVector{v0::Constant::create(element::i32, {1}, {-1}), get_dim_by_idx(input, {3, 2}, element::i32)}, - 0); - auto dim_1 = std::make_shared(dim_0, element::i64); + auto dim_0 = make_shared(OutputVector{v0::Constant::create(element::i64, {1}, {-1}), + get_dim_by_idx(input, {3}, element::i64), + v0::Constant::create(element::i64, {1}, {4})}, + 0); - auto reshape_0 = make_shared(input, dim_1, false); + auto reshape_0 = make_shared(input, dim_0, false); auto reshape_1 = make_shared(input, dim_0, false); model_ref = make_shared(NodeVector{reshape_0, reshape_1}, ParameterVector{input}); diff --git a/src/core/src/pass/sdpa_to_paged_attention.cpp b/src/core/src/pass/sdpa_to_paged_attention.cpp index e6fc744bb5ef4f..912f7bb7f37d7c 100644 --- a/src/core/src/pass/sdpa_to_paged_attention.cpp +++ b/src/core/src/pass/sdpa_to_paged_attention.cpp @@ -25,7 +25,7 @@ ov::pass::SDPAToPagedAttention::SDPAToPagedAttention(bool use_block_indices_inpu m_use_score_outputs(use_score_outputs) {} static std::shared_ptr setName(std::shared_ptr node, const char* name) { - // Set name for both node and output tensor (should be only one tensor, and any other names will be overriden by a + // Set name for both node and output tensor (should be only one tensor, and any other names will be overridden by a // given single name) node->set_friendly_name(name); OPENVINO_ASSERT(node->get_output_size() == 1); @@ -149,19 +149,25 @@ bool ov::pass::SDPAToPagedAttention::run_on_model(const std::shared_ptroutput(0).get_target_inputs(); + if (!strcmp(param_name, "attention_mask") && target_inputs.size() == 1 && + ov::is_type(target_inputs.begin()->get_node())) { + target_inputs.begin()->replace_source_output(unsqueezed_input_ids->output(0)); + target_inputs = param->output(0).get_target_inputs(); + } model->remove_parameter(param); - if (param->output(0).get_target_inputs().size() == 0) { + if (!target_inputs.empty()) { std::stringstream consumers; consumers << std::endl; - for (auto& input : param->output(0).get_target_inputs()) { + for (auto& input : target_inputs) { consumers << *input.get_node() << std::endl; } - OPENVINO_ASSERT(param->output(0).get_target_inputs().size() == 0, + OPENVINO_ASSERT(target_inputs.empty(), "PagedAttention transformation failed: couldn't remove ", - param->output(0).get_target_inputs().size(), + target_inputs.size(), " inputs of ", param_name, " input: ", diff --git a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/causal_mask_preprocess_fusion.cpp b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/causal_mask_preprocess_fusion.cpp index 5f3058429a8497..61fa875df603cf 100644 --- a/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/causal_mask_preprocess_fusion.cpp +++ b/src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/causal_mask_preprocess_fusion.cpp @@ -123,8 +123,9 @@ CausalMaskPreprocess::CausalMaskPreprocess() { auto ShapeOf_49034 = makePattern({attention_mask}); // tensor_array auto Gather_41642 = makePattern({ShapeOf_49034, {1}, 0}, {{"batch_dims", 0}}); // tensor_array - auto ScatterUpdate_93502 = - makePattern({{0, 0, 0, 0}, {3}, Gather_41642, {0}}); // tensor_array + auto alternative_concat = makePattern({{0}, {0}, {0}, Gather_41642}, {{"axis", 0}}); + auto ScatterUpdate_93502 = makePattern({{0, 0, 0, 0}, {3}, Gather_41642, {0}}) | + alternative_concat; // tensor_array auto SliceAssign_201_Slice = makePattern({SliceAssign_201_Reshape, {0}, Gather_41642, {1}, {3}}); auto SliceAssign_201_StridedSlice = GenStridedSlice(SliceAssign_201_Reshape, {0, 0, 0, 0}, @@ -184,8 +185,9 @@ CausalMaskPreprocess::CausalMaskPreprocess() { auto SliceAssign_201_Reshape_3 = makePattern({SliceAssign_201_ScatterNDUpdate, {-1, 1, max_seq_len, max_seq_len}}, {{"special_zero", true}}); // tensor_array - auto ScatterUpdate_93554 = - makePattern({{0, 0, 0, 0}, {3}, kvLen, {0}}); // tensor_array + auto alternative_concat_1 = makePattern({{0}, {0}, {0}, Gather_41642}, {{"axis", 0}}); + auto ScatterUpdate_93554 = makePattern({{0, 0, 0, 0}, {3}, kvLen, {0}}) | + alternative_concat_1; // tensor_array auto slice_StridedSlice_14 = GenStridedSlice(SliceAssign_201_Reshape_3, {0, 0, 0, 0}, ScatterUpdate_93554,