Skip to content

Commit

Permalink
Reenable MHA tokenization tests
Browse files Browse the repository at this point in the history
  • Loading branch information
v-Golubev committed Oct 6, 2024
1 parent 99f601a commit 619b05b
Show file tree
Hide file tree
Showing 2 changed files with 47 additions and 63 deletions.
20 changes: 4 additions & 16 deletions src/common/snippets/tests/src/pass/mha_tokenization.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,6 @@ namespace ov {
namespace test {
namespace snippets {

class SKIP_TokenizeMHASnippetsTests : public TokenizeMHASnippetsTests {
public:
void SetUp() override {
GTEST_SKIP();
}
void TearDown() override{};
};

void TokenizeMHASnippetsTests::run() {
ASSERT_TRUE(model);
manager.register_pass<ov::snippets::pass::ExtractReshapesFromMHA>();
Expand Down Expand Up @@ -103,8 +95,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_with_MatMul0_Transpose) {
run();
}

TEST_F(SKIP_TokenizeMHASnippetsTests /* CVS-142098 */, smoke_Snippets_MHA_with_MatMul0_Transpose_Dynamic) {
GTEST_SKIP();
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_with_MatMul0_Transpose_Dynamic) {
const auto &f = MHAMatMul0TransposeFunction(std::vector<PartialShape>{{-1, -1, -1, -1}, {-1, -1, -1, -1}, {-1, -1, -1, -1}, {-1, -1, -1, -1}},
std::vector<ov::element::Type>({ov::element::f32, ov::element::f32, ov::element::f32, ov::element::f32}),
false);
Expand All @@ -113,8 +104,7 @@ TEST_F(SKIP_TokenizeMHASnippetsTests /* CVS-142098 */, smoke_Snippets_MHA_with_M
run();
}

TEST_F(SKIP_TokenizeMHASnippetsTests /* CVS-114607 */, smoke_Snippets_MHA_with_int_Matmuls) {
GTEST_SKIP();
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_with_int_Matmuls) {
const auto &f = MHAINT8MatMulTypeRelaxedFunction(std::vector<PartialShape>{{1, 128, 12, 64}, {1, 128, 12, 64}, {1, 12, 128, 128}, {1, 128, 12, 64}});
model = f.getOriginal();
model_ref = f.getReference();
Expand All @@ -128,8 +118,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Transpose_extraction) {
run();
}

TEST_F(SKIP_TokenizeMHASnippetsTests /* CVS-142098 */, smoke_Snippets_MHA_Dynamic_Transpose_extraction) {
GTEST_SKIP();
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Dynamic_Transpose_extraction) {
const auto& f = MHATransposedInputFunction(std::vector<PartialShape>{{-1, -1, -1, -1}, {-1, -1, -1, -1}, {-1, -1, -1, -1}}, true);
model = f.getOriginal();
model_ref = f.getReference();
Expand All @@ -144,8 +133,7 @@ TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Transpose_extraction_and_uns
run();
}

TEST_F(SKIP_TokenizeMHASnippetsTests /* CVS-142098 */, smoke_Snippets_MHA_Dynamic_Transpose_extraction_and_unsupported_existing_transpose) {
GTEST_SKIP();
TEST_F(TokenizeMHASnippetsTests, smoke_Snippets_MHA_Dynamic_Transpose_extraction_and_unsupported_existing_transpose) {
const auto& f = MHATransposedInputFunction(std::vector<PartialShape>{{-1, -1, -1, -1}, {-1, -1, -1, -1}, {-1, -1, -1, -1}}, true,
std::vector<int64_t>{0, 3, 1, 2});
model = f.getOriginal();
Expand Down
90 changes: 43 additions & 47 deletions src/tests/ov_helpers/ov_snippets_models/src/subgraph_mha.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -393,8 +393,8 @@ std::shared_ptr<ov::Model> MHASelectFunction::initOriginal() const {
// Value is equal to '1' - to avoid situation e^(-1000) / (sum(e^(-1000)) = 0/0 = NAN
auto selectConst = ov::op::v0::Constant::create(precisions[2], ov::Shape{1}, std::vector<float>{1});

float transA = false;
float transB = false;
bool transA = false;
bool transB = false;
const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(transpose0Param, transpose0Const);
const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(transpose1Param, transpose1Const);
const auto matMul0 = std::make_shared<ov::op::v0::MatMul>(transpose0, transpose1, transA, transB);
Expand Down Expand Up @@ -531,8 +531,8 @@ std::shared_ptr<ov::Model> MHAWOTransposeOnInputsFunction::initOriginal() const

auto transpose3Const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape({4}), std::vector<int64_t>{0, 2, 1, 3});

float transA = false;
float transB = false;
bool transA = false;
bool transB = false;
const auto mulConst = ov::test::utils::make_constant(precision, ov::Shape({1}));
const auto mul = std::make_shared<ov::op::v1::Multiply>(param1, mulConst);
const auto matMul0 = std::make_shared<ov::op::v0::MatMul>(param0, mul, transA, transB);
Expand All @@ -550,8 +550,8 @@ std::shared_ptr<ov::Model> MHAWOTransposeFunction::initOriginal() const {
auto param2 = std::make_shared<ov::opset1::Parameter>(precisions[2], input_shapes[2]);
ov::ParameterVector ngraphParam = {param0, param1, param2};

float transA = false;
float transB = false;
bool transA = false;
bool transB = false;
const auto matMul0 = std::make_shared<ov::op::v0::MatMul>(param0, param1, transA, transB);
const auto softmax = std::make_shared<ov::op::v8::Softmax>(matMul0, -1);
const auto matMul1 = std::make_shared<ov::op::v0::MatMul>(softmax, param2, transA, transB);
Expand Down Expand Up @@ -615,8 +615,8 @@ std::shared_ptr<ov::Model> MHAFQAfterMatMulFunction::initOriginal() const {
static_cast<int64_t>(input_shapes[0].get_shape()[1])};
auto reshape1Const = ov::op::v0::Constant::create(ov::element::i64, {reshape1ConstData.size()}, reshape1ConstData);

float transA = false;
float transB = false;
bool transA = false;
bool transB = false;
const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(transpose0Param, transpose0Const);
const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(transpose1Param, transpose1Const);
const auto matMul0 = std::make_shared<ov::op::v0::MatMul>(transpose0, transpose1, transA, transB);
Expand Down Expand Up @@ -665,8 +665,8 @@ std::shared_ptr<ov::Model> MHAINT8MatMulFunction::initOriginal() const {
{-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
auto fq2 = ov::test::utils::make_fake_quantize(transpose2Param, ov::element::f32, 256, {1},
{-35.0172004}, {34.7436294}, {-35.0172004}, {34.7436294});
float transA = false;
float transB = false;
bool transA = false;
bool transB = false;
const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(fq0, transpose0Const);
const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(fq1, transpose1Const);
const auto matMul0 = std::make_shared<ov::op::v0::MatMul>(transpose0, transpose1, transA, transB);
Expand Down Expand Up @@ -756,8 +756,8 @@ std::shared_ptr<ov::Model> MHAFQFunction::initOriginal() const {
const auto fq_add = ov::test::utils::make_fake_quantize(addParam, ov::element::f32, 256, {1},
{-1000}, {0}, {-1000}, {0});

float transA = false;
float transB = false;
bool transA = false;
bool transB = false;
const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(fq0, transpose0Const);
const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(fq1, transpose1Const);
const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(transpose2Param, transpose2Const);
Expand Down Expand Up @@ -806,12 +806,12 @@ std::shared_ptr<ov::Model> MHAINT8MatMulTypeRelaxedFunction::initOriginal() cons
auto reshape1Const = ov::op::v0::Constant::create(ov::element::i64, {reshape1ConstData.size()}, reshape1ConstData);

const auto fq_signed_params = ov::builder::subgraph::FakeQuantizeOnData(256, {1}, {-36912.66015625}, {36624.28125}, {-128}, {127}, ov::element::i8);
const auto fq0 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(transpose0Param, ov::element::i8, fq_signed_params);
const auto fq1 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(transpose1Param, ov::element::i8, fq_signed_params);
const auto fq2 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(transpose2Param, ov::element::i8, fq_signed_params);
const auto fq0 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(transpose0Param, ov::element::f32, fq_signed_params);
const auto fq1 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(transpose1Param, ov::element::f32, fq_signed_params);
const auto fq2 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(transpose2Param, ov::element::f32, fq_signed_params);

float transA = false;
float transB = false;
bool transA = false;
bool transB = false;
const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(fq0, transpose0Const);
const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(fq1, transpose1Const);
const auto matMul0 = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
Expand All @@ -820,7 +820,7 @@ std::shared_ptr<ov::Model> MHAINT8MatMulTypeRelaxedFunction::initOriginal() cons
ov::op::TemporaryReplaceOutputType(transpose0, element::f32).get(),
ov::op::TemporaryReplaceOutputType(transpose1, element::f32).get(), transA, transB);

const auto fq3 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul0, ov::element::i8, fq_signed_params);
const auto fq3 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul0, ov::element::f32, fq_signed_params);
const auto add = std::make_shared<op::TypeRelaxed<ov::op::v1::Add>>(
std::vector<element::Type>{ element::f32, element::f32 },
std::vector<element::Type>{ element::f32 },
Expand All @@ -833,20 +833,20 @@ std::shared_ptr<ov::Model> MHAINT8MatMulTypeRelaxedFunction::initOriginal() cons
ov::op::TemporaryReplaceOutputType(add, element::f32).get(),
ov::op::TemporaryReplaceOutputType(deq, element::f32).get());

const auto reshape0 = std::make_shared<ov::opset1::Reshape>(add, reshape0Const, true);
const auto reshape0 = std::make_shared<ov::opset1::Reshape>(deq_mul, reshape0Const, true);
const auto softMax = std::make_shared<ov::opset1::Softmax>(reshape0, 1);
const auto reshape1 = std::make_shared<ov::opset1::Reshape>(softMax, reshape1Const, true);

const auto fq_unsigned_params = ov::builder::subgraph::FakeQuantizeOnData(256, {1}, {0}, {0.245}, {0}, {255}, ov::element::u8);
const auto fq4 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(reshape1, ov::element::u8, fq_unsigned_params);
const auto fq4 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(reshape1, ov::element::f32, fq_unsigned_params);

const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(fq2, transpose2Const);
const auto matMul1 = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
std::vector<element::Type>{ element::f32, element::f32 },
std::vector<element::Type>{ element::f32 },
ov::op::TemporaryReplaceOutputType(fq4, element::f32).get(),
ov::op::TemporaryReplaceOutputType(transpose2, element::f32).get(), transA, transB);
const auto fq5 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul1, ov::element::i8, fq_signed_params);
const auto fq5 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul1, ov::element::f32, fq_signed_params);
const auto transpose3 = std::make_shared<ov::op::v1::Transpose>(fq5, transpose3Const);

ov::ResultVector results{std::make_shared<ov::opset1::Result>(transpose3)};
Expand All @@ -860,9 +860,9 @@ std::shared_ptr<ov::Model> MHAINT8MatMulTypeRelaxedFunction::initReference() con
ov::ParameterVector ngraphParams = {data0, data1, data2, data3};

const auto fq_signed_params = ov::builder::subgraph::FakeQuantizeOnData(256, {1}, {-36912.66015625}, {36624.28125}, {-128}, {127}, ov::element::i8);
const auto fq0 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(data0, ov::element::i8, fq_signed_params);
const auto fq1 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(data1, ov::element::i8, fq_signed_params);
const auto fq2 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(data3, ov::element::i8, fq_signed_params);
const auto fq0 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(data0, ov::element::f32, fq_signed_params);
const auto fq1 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(data1, ov::element::f32, fq_signed_params);
const auto fq2 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(data3, ov::element::f32, fq_signed_params);
NodeVector subgraph_inputs = {fq0, fq1, data2, fq2};

auto transpose0Param = std::make_shared<ov::opset1::Parameter>(precision, input_shapes[0]);
Expand All @@ -877,19 +877,8 @@ std::shared_ptr<ov::Model> MHAINT8MatMulTypeRelaxedFunction::initReference() con
auto transpose2Const = ov::op::v0::Constant::create(ov::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});
auto transpose3Const = ov::op::v0::Constant::create(ov::element::i64, {shape_rank}, std::vector<int64_t>{0, 2, 1, 3});

std::vector<int64_t> reshape0ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0] *
input_shapes[0].get_shape()[1] * input_shapes[0].get_shape()[2]),
-1};
auto reshape0Const = ov::op::v0::Constant::create(ov::element::i64, {reshape0ConstData.size()}, reshape0ConstData);

std::vector<int64_t> reshape1ConstData = {static_cast<int64_t>(input_shapes[0].get_shape()[0]),
static_cast<int64_t>(input_shapes[0].get_shape()[2]),
static_cast<int64_t>(input_shapes[0].get_shape()[1]),
static_cast<int64_t>(input_shapes[0].get_shape()[1])};
auto reshape1Const = ov::op::v0::Constant::create(ov::element::i64, {reshape1ConstData.size()}, reshape1ConstData);

float transA = false;
float transB = false;
bool transA = false;
bool transB = false;
const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(transpose0Param, transpose0Const);
const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(transpose1Param, transpose1Const);
const auto matMul0 = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
Expand All @@ -898,7 +887,18 @@ std::shared_ptr<ov::Model> MHAINT8MatMulTypeRelaxedFunction::initReference() con
ov::op::TemporaryReplaceOutputType(transpose0, element::f32).get(),
ov::op::TemporaryReplaceOutputType(transpose1, element::f32).get(), transA, transB);

const auto fq3 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul0, ov::element::i8, fq_signed_params);
auto decomposed_fq =
[](const ov::Output<ov::Node>& input, const ov::element::Type& out_precision, float il, float ih, float scale) {
const auto input_low = ov::op::v0::Constant::create(ov::element::f32, {1}, {il});
const auto input_high = ov::op::v0::Constant::create(ov::element::f32, {1}, {ih});
const auto output_scale = ov::op::v0::Constant::create(ov::element::f32, {1}, {scale});
const auto max = std::make_shared<ov::op::v1::Maximum>(input, input_low);
const auto min = std::make_shared<ov::op::v1::Minimum>(max, input_high);
const auto mul = std::make_shared<ov::op::v1::Multiply>(min, output_scale);
return std::make_shared<ov::snippets::op::ConvertSaturation>(mul, out_precision);
};

const auto fq3 = decomposed_fq(matMul0, ov::element::i8, fq_signed_params.inputLowValues[0], fq_signed_params.inputHighValues[0], 0.00346764503f);
const auto add = std::make_shared<op::TypeRelaxed<ov::op::v1::Add>>(
std::vector<element::Type>{ element::f32, element::f32 },
std::vector<element::Type>{ element::f32 },
Expand All @@ -911,20 +911,16 @@ std::shared_ptr<ov::Model> MHAINT8MatMulTypeRelaxedFunction::initReference() con
ov::op::TemporaryReplaceOutputType(add, element::f32).get(),
ov::op::TemporaryReplaceOutputType(deq, element::f32).get());

const auto reshape0 = std::make_shared<ov::opset1::Reshape>(add, reshape0Const, true);
const auto softMax = std::make_shared<ov::opset1::Softmax>(reshape0, 1);
const auto reshape1 = std::make_shared<ov::opset1::Reshape>(softMax, reshape1Const, true);

const auto fq_unsigned_params = ov::builder::subgraph::FakeQuantizeOnData(256, {1}, {0}, {0.245}, {0}, {255}, ov::element::u8);
const auto fq4 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(reshape1, ov::element::u8, fq_unsigned_params);
const auto softMax = std::make_shared<ov::opset1::Softmax>(deq_mul, 3);
const auto fq4 = decomposed_fq(softMax, ov::element::u8, 0.f, 0.245f, 1040.81628f);

const auto transpose2 = std::make_shared<ov::op::v1::Transpose>(transpose2Param, transpose2Const);
const auto matMul1 = std::make_shared<op::TypeRelaxed<op::v0::MatMul>>(
std::vector<element::Type>{ element::f32, element::f32 },
std::vector<element::Type>{ element::f32 },
ov::op::TemporaryReplaceOutputType(fq4, element::f32).get(),
ov::op::TemporaryReplaceOutputType(transpose2, element::f32).get(), transA, transB);
const auto fq5 = ov::builder::subgraph::makeFakeQuantizeTypeRelaxed(matMul1, ov::element::i8, fq_signed_params);
const auto fq5 = decomposed_fq(matMul1, ov::element::i8, fq_signed_params.inputLowValues[0], fq_signed_params.inputHighValues[0], 0.00346764503f);

auto subgraph = std::make_shared<ov::snippets::op::Subgraph>(subgraph_inputs,
std::make_shared<ov::Model>(NodeVector{fq5}, subgraph_params));
Expand All @@ -946,8 +942,8 @@ std::shared_ptr<ov::Model> MHAMulAddFunction::initOriginal() const {
auto transpose2Const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{input_shapes[2].size()}, std::vector<int64_t>{0, 2, 1, 3});
auto transpose3Const = ov::op::v0::Constant::create(ov::element::i64, ov::Shape{input_shapes[2].size()}, std::vector<int64_t>{0, 2, 1, 3});

float transA = false;
float transB = false;
bool transA = false;
bool transB = false;
const auto transpose0 = std::make_shared<ov::op::v1::Transpose>(transpose0Param, transpose0Const);
const auto transpose1 = std::make_shared<ov::op::v1::Transpose>(transpose1Param, transpose1Const);
const auto matMul0 = std::make_shared<ov::op::v0::MatMul>(transpose0, transpose1, transA, transB);
Expand Down

0 comments on commit 619b05b

Please sign in to comment.