Skip to content

Commit

Permalink
[WIP] Change splitM heuristic
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Nov 19, 2024
1 parent 7e04427 commit fb62330
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 29 deletions.
30 changes: 15 additions & 15 deletions src/common/snippets/src/pass/split_dimension_m.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -34,23 +34,23 @@ bool SplitDimensionM::is_supported_matmul(const std::shared_ptr<const ov::Node>&
std::pair<size_t, size_t> SplitDimensionM::get_splited_dimensions(size_t batch_dim, size_t m_dim, size_t optimal_parallelism_work_amount) {
std::pair<size_t, size_t> splited = { 1, m_dim };

const size_t lower_bound = optimal_parallelism_work_amount / batch_dim;
if (lower_bound * batch_dim == optimal_parallelism_work_amount && m_dim % lower_bound == 0) {
splited.first = lower_bound;
splited.second = m_dim / lower_bound;
OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!");
return splited;
}

const size_t upper_bound = utils::div_up(2 * optimal_parallelism_work_amount, batch_dim);
for (size_t divisor_0 = upper_bound - 1; divisor_0 > 1; divisor_0--) {
size_t divisor_1 = m_dim / divisor_0;
if (divisor_1 * divisor_0 == m_dim) {
splited.first = divisor_0;
splited.second = divisor_1;
break;
// TODO: should we limit minimal kernel_m?
const size_t min_kernel_m = 4;
// Strategy 1: Find a combination such that (batch_dim * splited.first) % optimal_parallelism_work_amount == 0
for (size_t divisor = 1; divisor <= m_dim; ++divisor) {
if (m_dim % divisor == 0) {
const auto m_batch = divisor;
const auto m_kernel = m_dim / divisor;
if (m_kernel < min_kernel_m)
break;
splited = { m_batch, m_kernel };
if ((batch_dim * splited.first) % optimal_parallelism_work_amount == 0) {
OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!");
return splited;
}
}
}

OPENVINO_ASSERT(splited.first * splited.second == m_dim, "Incorrect dimension M splitting!");
return splited;
}
Expand Down
12 changes: 6 additions & 6 deletions src/common/snippets/tests/src/pass/mha_tokenization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_withMul) {
#endif
const auto& f = MHASplitMFunction(std::vector<PartialShape>{{128, 12, 64}, {128, 12, 64}, {12, 128, 128}, {128, 12, 64}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{2, 64, 12, 64}, {128, 12, 1, 64}, {12, 2, 64, 128}, {1, 128, 12, 64}, {128, 12, 64}},
std::vector<Shape>{{4, 32, 12, 64}, {128, 12, 1, 64}, {12, 4, 32, 128}, {1, 128, 12, 64}, {128, 12, 64}},
true);
model = f.getOriginal();
model_ref = f.getReference();
Expand All @@ -186,7 +186,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA3D_SplitM_withMul) {
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM) {
const auto& f = MHASplitMFunction(std::vector<PartialShape>{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{1, 6, 64, 16, 64}, {1, 384, 16, 1, 64}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
std::vector<Shape>{{1, 96, 4, 16, 64}, {1, 384, 16, 1, 64}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
false);
model = f.getOriginal();
model_ref = f.getReference();
Expand All @@ -201,7 +201,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM_withMul) {
#endif
const auto& f = MHASplitMFunction(std::vector<PartialShape>{{1, 384, 16, 64}, {1, 384, 16, 64}, {1, 1, 1, 384}, {1, 384, 16, 64}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{1, 6, 64, 16, 64}, {1, 384, 16, 1, 64}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
std::vector<Shape>{{1, 96, 4, 16, 64}, {1, 384, 16, 1, 64}, {1, 1, 1, 1, 384}, {1, 1, 384, 16, 64}, {1, 384, 16, 64}},
true);
model = f.getOriginal();
model_ref = f.getReference();
Expand All @@ -212,17 +212,17 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA4D_SplitM_withMul) {
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHAWOTranspose_SplitM) {
const auto& f = MHAWOTransposeSplitMFunction(std::vector<PartialShape>{{10, 9216, 128}, {10, 128, 9216}, {10, 9216, 128}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{10, 3, 3072, 128}, {10, 1, 128, 9216}, {10, 1, 9216, 128}, {10, 9216, 128}});
std::vector<Shape>{{10, 9, 1024, 128}, {10, 1, 128, 9216}, {10, 1, 9216, 128}, {10, 9216, 128}});
model = f.getOriginal();
model_ref = f.getReference();
config.set_concurrency(18);
run();
}

TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_SplitM_AlmostAllThreads) {
const auto& f = MHAWOTransposeSplitMFunction(std::vector<PartialShape>{{5, 30, 32}, {5, 32, 30}, {5, 30, 32}},
const auto& f = MHAWOTransposeSplitMFunction(std::vector<PartialShape>{{5, 60, 32}, {5, 32, 30}, {5, 30, 32}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32}),
std::vector<Shape>{{5, 10, 3, 32}, {5, 1, 32, 30}, {5, 1, 30, 32}, {5, 30, 32}});
std::vector<Shape>{{5, 15, 4, 32}, {5, 1, 32, 30}, {5, 1, 30, 32}, {5, 60, 32}});
model = f.getOriginal();
model_ref = f.getReference();
config.set_concurrency(32);
Expand Down
15 changes: 7 additions & 8 deletions src/common/snippets/tests/src/utils/split_dim_m.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,15 @@ TEST_P(SplitDimensionMTest, SplitDimensionM) {
namespace SplitDimensionMInstantiation {
const std::vector<SplitDimensionMParams> split_dimension_cases = {
// Negative test cases: split is not needed
{InputData{40 /*cur_batch*/, 32 /*cur_m*/, 40 /*concurrency*/}, ReferenceData{false /*is_split*/}},
{InputData{65, 32, 40}, ReferenceData{false}},
{InputData{32 /*cur_batch*/, 32 /*cur_m*/, 32 /*concurrency*/}, ReferenceData{false /*is_split*/}},
{InputData{50, 32, 32}, ReferenceData{false}},

// Positive test cases
{InputData{20 /*cur_batch*/, 32 /*cur_m*/, 40 /*concurrency*/}, ReferenceData{true /*is_split*/, 2 /*batch_m*/, 16 /*kernel_m*/}},
{InputData{30, 60, 40}, ReferenceData{true, 2, 30}},
{InputData{10, 100, 40}, ReferenceData{true, 4, 25}},
{InputData{15, 45, 40}, ReferenceData{true, 5, 9}},
{InputData{25, 50, 40}, ReferenceData{true, 2, 25}},
{InputData{5, 16384, 40}, ReferenceData{true, 8, 2048}},
{InputData{20 /*cur_batch*/, 32 /*cur_m*/, 32 /*concurrency*/}, ReferenceData{true /*is_split*/, 8 /*batch_m*/, 4 /*kernel_m*/}},
{InputData{16, 60, 32}, ReferenceData{true, 2, 30}},
{InputData{10, 100, 32}, ReferenceData{true, 25, 4}},
{InputData{25, 50, 32}, ReferenceData{true, 10, 5}},
{InputData{5, 16384, 32}, ReferenceData{true, 32, 512}},
};

INSTANTIATE_TEST_SUITE_P(smoke_Snippets_SplitDimensionM,
Expand Down

0 comments on commit fb62330

Please sign in to comment.