From 7e04427b76964fa2a637cd20bf7b2c0ba9b0e97e Mon Sep 17 00:00:00 2001 From: Vladislav Golubev Date: Mon, 18 Nov 2024 19:15:48 +0100 Subject: [PATCH] Cover SplitDimensionM heuristic by unit tests --- .../tests/include/utils/split_dim_m.hpp | 37 ++++++++++ .../snippets/tests/src/utils/split_dim_m.cpp | 71 +++++++++++++++++++ 2 files changed, 108 insertions(+) create mode 100644 src/common/snippets/tests/include/utils/split_dim_m.hpp create mode 100644 src/common/snippets/tests/src/utils/split_dim_m.cpp diff --git a/src/common/snippets/tests/include/utils/split_dim_m.hpp b/src/common/snippets/tests/include/utils/split_dim_m.hpp new file mode 100644 index 00000000000000..3e04c2a911d76a --- /dev/null +++ b/src/common/snippets/tests/include/utils/split_dim_m.hpp @@ -0,0 +1,37 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#pragma once + +#include + +namespace ov { +namespace test { +namespace snippets { + +struct InputData { + size_t cur_batch; + size_t cur_m; + size_t concurrency; +}; + +struct ReferenceData { + bool is_split; + size_t batch_m; + size_t kernel_m; +}; + +struct SplitDimensionMParams { + InputData input; + ReferenceData reference; +}; + +class SplitDimensionMTest : public testing::TestWithParam { +public: + static std::string getTestCaseName(testing::TestParamInfo obj); +}; + +} // namespace snippets +} // namespace test +} // namespace ov diff --git a/src/common/snippets/tests/src/utils/split_dim_m.cpp b/src/common/snippets/tests/src/utils/split_dim_m.cpp new file mode 100644 index 00000000000000..69a04da6f1263f --- /dev/null +++ b/src/common/snippets/tests/src/utils/split_dim_m.cpp @@ -0,0 +1,71 @@ +// Copyright (C) 2024 Intel Corporation +// SPDX-License-Identifier: Apache-2.0 +// + +#include "utils/split_dim_m.hpp" + +#include "common_test_utils/ov_test_utils.hpp" +#include "snippets/pass/split_dimension_m.hpp" +#include "snippets/utils/utils.hpp" + +namespace ov { +namespace test { +namespace snippets { + +std::string SplitDimensionMTest::getTestCaseName(testing::TestParamInfo obj) { + const auto& input = obj.param.input; + const auto& reference = obj.param.reference; + std::ostringstream result; + result << "Batch=" << input.cur_batch << "_"; + result << "CurM=" << input.cur_m << "_"; + result << "OptimalParallelWorkAmount=" << input.concurrency << "_"; + result << "IsSplit=" << reference.is_split << "_"; + result << "BatchM=" << reference.batch_m << "_"; + result << "KernelM=" << reference.kernel_m; + return result.str(); +} + +TEST_P(SplitDimensionMTest, SplitDimensionM) { + const auto& input = GetParam().input; + const auto& reference = GetParam().reference; + + // last_dim is fixed since it doesn't affect the SplitDimensionM result. + static const size_t last_dim = 1024; + ov::Shape shape = {input.cur_batch, input.cur_m, last_dim}; + size_t batch_m_dim, new_m_dim; + bool result = ov::snippets::pass::SplitDimensionM::split(shape, + input.concurrency, + batch_m_dim, + new_m_dim); + + ASSERT_EQ(result, reference.is_split); + if (result) { + ASSERT_EQ(batch_m_dim, reference.batch_m); + ASSERT_EQ(new_m_dim, reference.kernel_m); + } +} + +namespace SplitDimensionMInstantiation { +const std::vector split_dimension_cases = { + // Negative test cases: split is not needed + {InputData{40 /*cur_batch*/, 32 /*cur_m*/, 40 /*concurrency*/}, ReferenceData{false /*is_split*/}}, + {InputData{65, 32, 40}, ReferenceData{false}}, + + // Positive test cases + {InputData{20 /*cur_batch*/, 32 /*cur_m*/, 40 /*concurrency*/}, ReferenceData{true /*is_split*/, 2 /*batch_m*/, 16 /*kernel_m*/}}, + {InputData{30, 60, 40}, ReferenceData{true, 2, 30}}, + {InputData{10, 100, 40}, ReferenceData{true, 4, 25}}, + {InputData{15, 45, 40}, ReferenceData{true, 5, 9}}, + {InputData{25, 50, 40}, ReferenceData{true, 2, 25}}, + {InputData{5, 16384, 40}, ReferenceData{true, 8, 2048}}, +}; + +INSTANTIATE_TEST_SUITE_P(smoke_Snippets_SplitDimensionM, + SplitDimensionMTest, + ::testing::ValuesIn(split_dimension_cases), + SplitDimensionMTest::getTestCaseName); + +} // namespace SplitDimensionMInstantiation +} // namespace snippets +} // namespace test +} // namespace ov \ No newline at end of file