Skip to content

Commit

Permalink
Remove SDPA Transpose fusing transformation
Browse files Browse the repository at this point in the history
  • Loading branch information
maxnick committed Nov 21, 2024
1 parent 7656b9c commit 3395f05
Show file tree
Hide file tree
Showing 4 changed files with 2 additions and 218 deletions.

This file was deleted.

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -142,7 +142,6 @@
#include "transformations/cpu_opset/common/pass/swap_convert_transpose.hpp"
#include "transformations/cpu_opset/common/pass/causal_mask_preprocess_fusion.hpp"
#include "transformations/cpu_opset/common/pass/stateful_sdpa_fusion.hpp"
-#include "transformations/cpu_opset/x64/pass/sdpa_fuse_transpose_reshape.hpp"

// Snippets
#include "snippets/pass/tokenization.hpp"
Expand Down Expand Up @@ -701,8 +700,6 @@ void Transformations::PreLpt(const std::vector<ov::element::Type>& defaultPrecis
CPU_REGISTER_PASS_COMMON(sdpa_manager, ov::pass::ConstantFolding);
CPU_REGISTER_PASS_COMMON(sdpa_manager, ov::pass::TransposeSinking);
CPU_REGISTER_PASS_COMMON(sdpa_manager, ov::pass::TransposeMatMul);
-CPU_REGISTER_PASS_COMMON(sdpa_manager, ov::pass::VisualizeTree, "decomposed_sdpa_cf_tt.svg");
-CPU_REGISTER_PASS_X64(sdpa_manager, ov::intel_cpu::SDPAFuseTransposeReshape);

sdpa_manager.run_passes(model);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -205,17 +205,10 @@ class FuseSDPAReshapeTransposeTest : virtual public ov::test::SubgraphBaseTest,

TEST_P(FuseSDPAReshapeTransposeTest, CompareWithRefs) {
SKIP_IF_CURRENT_TEST_IS_DISABLED();
-bool reshape_transpose_fused = false;
auto actualOutputs = run_test(function);
-CheckNumberOfNodesWithType(compiledModel, "ScaledDotProductAttention", 1);
-CheckNumberOfNodesWithType(compiledModel, "Reshape", 0);
+CheckNumberOfNodesWithType(compiledModel, "Subgraph", 1);
+CheckNumberOfNodesWithType(compiledModel, "Reshape", 4);
CheckNumberOfNodesWithType(compiledModel, "Transpose", 0);
-for (const auto& n : compiledModel.get_runtime_model()->get_ordered_ops()) {
-if (n->get_friendly_name() == "mha/fused_reshape_transpose") {
-reshape_transpose_fused = true;
-}
-}
-ASSERT_TRUE(reshape_transpose_fused);

auto expectedOutputs = run_test(functionRefs);
for (size_t i = 0; i < actualOutputs.size(); i++) {
Expand Down

0 comments on commit 3395f05

Please sign in to comment.