diff --git a/src/common/snippets/src/pass/split_dimension_m.cpp b/src/common/snippets/src/pass/split_dimension_m.cpp index 0f50ad27931e04..a263fb8de0a87a 100644 --- a/src/common/snippets/src/pass/split_dimension_m.cpp +++ b/src/common/snippets/src/pass/split_dimension_m.cpp @@ -34,23 +34,23 @@ bool SplitDimensionM::is_supported_matmul(const std::shared_ptr& std::pair SplitDimensionM::get_splited_dimensions(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) { std::pair splited = { 1, m_dim }; - const size_t lower_bound = optimal_parallelism_work_amount / batch_dim; - if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0) { - splited.first = lower_bound; - splited.second = m_dim / lower_bound; - OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!"); - return splited; - } - - const size_t upper_bound = utils::div_up(2 * optimal_parallelism_work_amount, batch_dim); - for (size_t divisor_0 = upper_bound - 1; divisor_0 > 1; divisor_0--) { - size_t divisor_1 = m_dim / divisor_0; - if (divisor_1 * divisor_0 == m_dim) { - splited.first = divisor_0; - splited.second = divisor_1; - break; + // TODO: should we limit minimal kernel_m? + const size_t min_kernel_m = 4; + // Strategy 1: Find a combination such that (batch_dim * splited.first) % optimal_parallelism_work_amount == 0 + for (size_t divisor = 1; divisor <= m_dim; ++divisor) { + if (m_dim % divisor == 0) { + const auto m_batch = divisor; + const auto m_kernel = m_dim / divisor; + if (m_kernel < min_kernel_m) + break; + splited = { m_batch, m_kernel }; + if ((batch_dim * splited.first) % optimal_parallelism_work_amount == 0) { + OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!"); + return splited; + } } } + OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!"); return splited; } diff --git a/src/common/snippets/tests/src/pass/mha_tokenization.cpp b/src/common/snippets/tests/src/pass/mha_tokenization.cpp index 040982feb4e0ec..65f2bb3d51f127 100644 --- a/src/common/snippets/tests/src/pass/mha_tokenization.cpp +++ b/src/common/snippets/tests/src/pass/mha_tokenization.cpp @@ -175,7 +175,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_withMul) { #endif const auto& f = MHASplitMFunction(std::vector{{128, 12, 64}, {128, 12, 64}, {12, 128, 128}, {128, 12, 64}}, std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{2, 64, 12, 64}, {128, 12, 1, 64}, {12, 2, 64, 128}, {1, 128, 12, 64}, {128, 12, 64}}, + std::vector{{4, 32, 12, 64}, {128, 12, 1, 64}, {12, 4, 32, 128}, {1, 128, 12, 64}, {128, 12, 64}}, true); model = f.getOriginal(); model_ref = f.getReference(); @@ -186,7 +186,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_withMul) { TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM) { const auto& f = MHASplitMFunction(std::vector{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}}, std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{1, 6, 64, 16, 64}, {1, 384, 16, 1, 64}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}}, + std::vector{{1, 96, 4, 16, 64}, {1, 384, 16, 1, 64}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}}, false); model = f.getOriginal(); model_ref = f.getReference(); @@ -201,7 +201,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM_withMul) { #endif const auto& f = MHASplitMFunction(std::vector{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}}, std::vector({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{1, 6, 64, 16, 64}, {1, 384, 16, 1, 64}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}}, + std::vector{{1, 96, 4, 16, 64}, {1, 384, 16, 1, 64}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}}, true); model = f.getOriginal(); model_ref = f.getReference(); @@ -212,7 +212,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM_withMul) { TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHAWOTranspose_SplitM) { const auto& f = MHAWOTransposeSplitMFunction(std::vector{{10, 9216, 128}, {10, 128, 9216}, {10, 9216, 128}}, std::vector({ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{10, 3, 3072, 128}, {10, 1, 128, 9216}, {10, 1, 9216, 128}, {10, 9216, 128}}); + std::vector{{10, 9, 1024, 128}, {10, 1, 128, 9216}, {10, 1, 9216, 128}, {10, 9216, 128}}); model = f.getOriginal(); model_ref = f.getReference(); config.set_concurrency(18); @@ -220,9 +220,9 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHAWOTranspose_SplitM) { } TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_SplitM_AlmostAllThreads) { - const auto& f = MHAWOTransposeSplitMFunction(std::vector{{5, 30, 32}, {5, 32, 30}, {5, 30, 32}}, + const auto& f = MHAWOTransposeSplitMFunction(std::vector{{5, 60, 32}, {5, 32, 30}, {5, 30, 32}}, std::vector({ov::element::f32, ov::element::f32, ov::element::f32}), - std::vector{{5, 10, 3, 32}, {5, 1, 32, 30}, {5, 1, 30, 32}, {5, 30, 32}}); + std::vector{{5, 15, 4, 32}, {5, 1, 32, 30}, {5, 1, 30, 32}, {5, 60, 32}}); model = f.getOriginal(); model_ref = f.getReference(); config.set_concurrency(32); diff --git a/src/common/snippets/tests/src/utils/split_dim_m.cpp b/src/common/snippets/tests/src/utils/split_dim_m.cpp index 69a04da6f1263f..db574a38f54685 100644 --- a/src/common/snippets/tests/src/utils/split_dim_m.cpp +++ b/src/common/snippets/tests/src/utils/split_dim_m.cpp @@ -48,16 +48,15 @@ TEST_P(SplitDimensionMTest, SplitDimensionM) { namespace SplitDimensionMInstantiation { const std::vector split_dimension_cases = { // Negative test cases: split is not needed - {InputData{40 /*cur_batch*/, 32 /*cur_m*/, 40 /*concurrency*/}, ReferenceData{false /*is_split*/}}, - {InputData{65, 32, 40}, ReferenceData{false}}, + {InputData{32 /*cur_batch*/, 32 /*cur_m*/, 32 /*concurrency*/}, ReferenceData{false /*is_split*/}}, + {InputData{50, 32, 32}, ReferenceData{false}}, // Positive test cases - {InputData{20 /*cur_batch*/, 32 /*cur_m*/, 40 /*concurrency*/}, ReferenceData{true /*is_split*/, 2 /*batch_m*/, 16 /*kernel_m*/}}, - {InputData{30, 60, 40}, ReferenceData{true, 2, 30}}, - {InputData{10, 100, 40}, ReferenceData{true, 4, 25}}, - {InputData{15, 45, 40}, ReferenceData{true, 5, 9}}, - {InputData{25, 50, 40}, ReferenceData{true, 2, 25}}, - {InputData{5, 16384, 40}, ReferenceData{true, 8, 2048}}, + {InputData{20 /*cur_batch*/, 32 /*cur_m*/, 32 /*concurrency*/}, ReferenceData{true /*is_split*/, 8 /*batch_m*/, 4 /*kernel_m*/}}, + {InputData{16, 60, 32}, ReferenceData{true, 2, 30}}, + {InputData{10, 100, 32}, ReferenceData{true, 25, 4}}, + {InputData{25, 50, 32}, ReferenceData{true, 10, 5}}, + {InputData{5, 16384, 32}, ReferenceData{true, 32, 512}}, }; INSTANTIATE_TEST_SUITE_P(smoke_Snippets_SplitDimensionM,