Skip to content

Commit

Permalink
[Snippets][Port to 2025.0] Implemented SetDynamicWAToOuterMostLoop pa…
Browse files Browse the repository at this point in the history
…ss (#28506)

### Details:
- *Dynamic MHA Subgraphs may have only dynamic batch. Then the pass
`MHAParallelWAOptimizer` cannot be applied to this subgraph to increase
parallel work amount since outermost Loop By M in MHA has static work
amount. Then Subgraph may be inefficiently executed. This PR implemented
the pass `SetDynamicWAToOuterMostLoop ` which sets dynamic work amount
to outmost Loop in dynamic MHA to make applicable
`MHAParallelWAOptimizer` in runtime.*
 - *Original PR: #28505

### Tickets:
 - *160647*
  • Loading branch information
a-sidorova authored Jan 17, 2025
1 parent bd1764d commit 105004b
Show file tree
Hide file tree
Showing 6 changed files with 123 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@ namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

class SetDynamicWAToOuterMostLoop;
/**
* @class MHAParallelWAOptimizer
* @brief Optimizes the dynamic MHA execution increasing parallel work amount dy dividing Brgemm's "M" dimension to "parallel_m"
Expand All @@ -22,6 +24,7 @@ namespace pass {
* - Determines loops that should be adjusted.
*/
class MHAParallelWAOptimizer : public lowered::pass::RuntimeOptimizer {
friend class SetDynamicWAToOuterMostLoop;
public:
MHAParallelWAOptimizer() = default;
MHAParallelWAOptimizer(const lowered::LinearIRCPtr& linear_ir, const RuntimeConfigurator* configurator);
Expand All @@ -30,10 +33,14 @@ class MHAParallelWAOptimizer : public lowered::pass::RuntimeOptimizer {
bool applicable() const override { return !m_loops_to_split.empty(); }

private:
static std::unordered_set<lowered::ExpressionPtr> find_applicable_brgemms(const lowered::LinearIRCPtr& linear_ir);
static std::unordered_set<lowered::ExpressionPtr> find_applicable_brgemms(
const lowered::LinearIRCPtr& linear_ir,
bool check_dynamic_wa = true);

static std::unordered_set<size_t> find_unsqueezed_params(
const lowered::LinearIRCPtr& linear_ir,
const std::unordered_set<lowered::ExpressionPtr>& brgemms);

static std::vector<lowered::ExpandedLoopInfoPtr> find_loops_to_split(
const lowered::LinearIRCPtr& linear_ir,
const std::unordered_set<size_t>& unsqueezed_params);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
// Copyright (C) 2023-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "pass.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

/**
* @interface SetDynamicWAToOuterMostLoop
* @brief The pass set dynamic work amount to outermost Loop by M in dynamic MHA Subgraphs
* to allow MHAParallelWAOptimizer optimizes parallel work amount in runtime.
* @ingroup snippets
*/
class SetDynamicWAToOuterMostLoop : public Pass {
public:
OPENVINO_RTTI("SetDynamicWAToOuterMostLoop", "", Pass);
SetDynamicWAToOuterMostLoop() = default;
bool run(LinearIR& linear_ir) override;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,9 @@ bool MHAParallelWAOptimizer::run(const lowered::LinearIR& linear_ir) {
return true;
}

std::unordered_set<lowered::ExpressionPtr> MHAParallelWAOptimizer::find_applicable_brgemms(const lowered::LinearIRCPtr& linear_ir) {
std::unordered_set<lowered::ExpressionPtr> MHAParallelWAOptimizer::find_applicable_brgemms(
const lowered::LinearIRCPtr& linear_ir,
bool check_dynamic_wa) {
auto is_brgemm = [](const lowered::ExpressionPtr& expr) {
return ov::is_type<op::Brgemm>(expr->get_node());
};
Expand All @@ -96,12 +98,12 @@ std::unordered_set<lowered::ExpressionPtr> MHAParallelWAOptimizer::find_applicab
brgemm_it = std::find_if(std::next(brgemm_it), linear_ir->end(), is_brgemm);
}
const auto& loop_manager = linear_ir->get_loop_manager();
auto applicable_brgemm = [&loop_manager](const lowered::ExpressionPtr& expr) {
auto applicable_brgemm = [&loop_manager, check_dynamic_wa](const lowered::ExpressionPtr& expr) {
const auto& loop_idces = expr->get_loop_ids();
if (loop_idces.empty())
return false;
const auto& outermost_loop = loop_manager->get_loop_info(loop_idces[0]);
if (!snippets::utils::is_dynamic_value(outermost_loop->get_work_amount()))
if (check_dynamic_wa && !snippets::utils::is_dynamic_value(outermost_loop->get_work_amount()))
return false;
bool loop_by_m = true;
outermost_loop->iterate_through_ports([&loop_by_m](const lowered::LoopPort& port) {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
// Copyright (C) 2023-2025 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#include "snippets/lowered/pass/set_dynamic_wa_to_outermost_loop.hpp"

#include "snippets/lowered/pass/mha_parallel_wa_optimizer.hpp"
#include "snippets/itt.hpp"
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/loop_manager.hpp"
#include "snippets/op/brgemm.hpp"
#include "snippets/utils/loop_utils.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {

bool SetDynamicWAToOuterMostLoop::run(LinearIR& linear_ir) {
OV_ITT_SCOPED_TASK(ov::pass::itt::domains::SnippetsTransform, "Snippets::SetDynamicWAToOuterMostLoop")
if (linear_ir.empty() || !linear_ir.is_dynamic() || linear_ir.get_config().m_enable_domain_optimization)
return false;

const auto linear_ir_ptr = std::make_shared<const LinearIR>(linear_ir);
const auto brgemms = MHAParallelWAOptimizer::find_applicable_brgemms(linear_ir_ptr, false);
if (brgemms.empty())
return false;

const auto unsqueezed_params = MHAParallelWAOptimizer::find_unsqueezed_params(linear_ir_ptr, brgemms);
OPENVINO_ASSERT(!unsqueezed_params.empty(), "unsqueezed_params mustn't be empty after initialization");


const auto& loop_manager = linear_ir_ptr->get_loop_manager();
std::unordered_set<lowered::UnifiedLoopInfoPtr> affected_loops;
size_t prev_loop_id = std::numeric_limits<size_t>::max();
static const size_t dim_M_idx = 1;

auto add_affected_loop = [&](const lowered::ExpressionPtr& expr) {
const auto& loop_idces = expr->get_loop_ids();
if (loop_idces.empty() || loop_idces.front() == prev_loop_id)
return;

prev_loop_id = loop_idces.front();
const auto loop_info = loop_manager->get_loop_info<lowered::UnifiedLoopInfo>(prev_loop_id);
if (loop_info->get_dim_idx() == dim_M_idx) {
affected_loops.insert(loop_info);
}
};

size_t i = 0;
std::unordered_set<lowered::ExpressionPtr> visited;
for (const auto& param : linear_ir_ptr->get_parameters()) {
if (unsqueezed_params.count(i++))
continue;
utils::visit_path(param, visited, add_affected_loop, false);
}

bool modified = false;
for (const auto& loop : affected_loops) {
if (!utils::is_dynamic_value(loop->get_work_amount())) {
loop->set_work_amount(utils::get_dynamic_value<size_t>());
ov::snippets::utils::update_data_pointer_shifts(loop);
modified = true;
}
}

return modified;
}

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
2 changes: 2 additions & 0 deletions src/common/snippets/src/op/subgraph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
#include "snippets/lowered/pass/validate_expanded_loops.hpp"
#include "snippets/lowered/pass/set_load_store_scalar.hpp"
#include "snippets/lowered/pass/extract_loop_invariants.hpp"
#include "snippets/lowered/pass/set_dynamic_wa_to_outermost_loop.hpp"

#include "transformations/utils/utils.hpp"

Expand Down Expand Up @@ -468,6 +469,7 @@ void Subgraph::control_flow_transformations(size_t min_parallel_work_amount, siz
pipeline.register_pass<lowered::pass::ValidateShapes>();
pipeline.register_pass<lowered::pass::ValidateUnifiedLoops>();
pipeline.register_pass<lowered::pass::InitLoops>();
pipeline.register_pass<lowered::pass::SetDynamicWAToOuterMostLoop>();
pipeline.register_pass<lowered::pass::InsertLoops>();
pipeline.register_pass<lowered::pass::AllocateBuffers>(m_linear_ir->get_config().m_are_buffers_optimized);
pipeline.register_pass<lowered::pass::CleanRepeatedDataPointerShifts>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,11 @@ std::vector<std::vector<ov::test::InputShape>> originalShape_3D {
{PartialShape{2, -1, 64}, {{2, 9, 64}, {2, 4, 64}, {2, 9, 64}}},
{PartialShape{2, 64, -1}, {{2, 64, 9}, {2, 64, 4}, {2, 64, 9}}},
{PartialShape{2, -1, 64}, {{2, 9, 64}, {2, 4, 64}, {2, 9, 64}}},
},
{
{PartialShape{-1, 128, 64}, {{1, 128, 64}, {2, 128, 64}, {1, 128, 64}}},
{PartialShape{-1, 64, 128}, {{1, 64, 128}, {2, 64, 128}, {1, 64, 128}}},
{PartialShape{-1, 128, 64}, {{1, 128, 64}, {2, 128, 64}, {1, 128, 64}}},
}
};

Expand Down

0 comments on commit 105004b

Please sign in to comment.