[Snippets] Move BrgemmCopyB repacking logic outside the Subgraph #27007

Merged

Commits (42, all by v-Golubev):
63ba621  [Snippets][WIP] Move CopyB repacking out from Subgraph (Oct 16, 2024)
74c4557  Disable CopyB moving out for i8i8 case (Nov 7, 2024)
8e66e61  [TMP] Avoid createScratchPadMem usage (Nov 7, 2024)
0753911  update_ptrs fix for dynamic (Nov 7, 2024)
6a7ad32  Fix original memptrs corruption in dynamic scenario (Nov 7, 2024)
bdf1e9a  Compilation fix (Nov 8, 2024)
a8c558e  Propagate updated shapes from SplitM to desc adjuster (Nov 8, 2024)
00dd68f  Move brgemm repacking out fix (Nov 8, 2024)
fa14643  Codestyle (Nov 10, 2024)
013cc2f  Scratchpad reused for intermediate repackings (Nov 11, 2024)
43c939d  Recreate memory object for external repacking on each inference (Nov 11, 2024)
89be0f3  Cleanup (Nov 11, 2024)
f36b323  Store descs in SubgraphExecutor (Nov 11, 2024)
a863817  get_copy_b_expr helper (Nov 11, 2024)
1d1d605  Match AdjustBrgemmCopyBLoopPorts on BrgemmCPU instead of repacking (Nov 12, 2024)
a7cb0fa  Cleanup (Nov 12, 2024)
db09212  Introduced BrgemmExternalRepackingAdjuster (Nov 12, 2024)
e885a38  [WIP] Move adjuster to a separate file (Nov 12, 2024)
aef1ecb  [WIP] Use shapes from config in optimizers (Nov 12, 2024)
cdc9636  [WIP] introduced RuntimeOptimizer base class (Nov 12, 2024)
710d64e  [WIP] RuntimeOptimizer inherited from lowered pass (Nov 12, 2024)
cf667d8  Introduced RuntimeOptimizersPipeline (Nov 13, 2024)
c04366f  All optimizers are rewritten to RuntimeOptimizers (Nov 13, 2024)
32fded1  Serialization passes updated (Nov 13, 2024)
aef985f  Docs and cleanup (Nov 13, 2024)
66773fd  Further cleanup (Nov 13, 2024)
64a9fb9  compute_offsets refactoring (Nov 13, 2024)
8666703  Correct MHA tokenization (Nov 15, 2024)
435cf45  Cover SplitDimensionM heuristic by unit tests (Nov 18, 2024)
1c64d03  [WIP] Change splitM heuristic (Nov 18, 2024)
63e9876  Correct Transpose tokenization in tests (Nov 19, 2024)
ef2a1a6  Enable u8i8 and bf16 MHA tokenization with transpose_b=true (Nov 19, 2024)
4bc76e2  Alexandra's comments applied (Nov 19, 2024)
6860c67  Ivan's comments applied (Nov 19, 2024)
2ad10d2  Rest review comments (Nov 20, 2024)
11865aa  Further refactoring in accordance to review suggestions (Nov 21, 2024)
8a391f1  Revert "Enable u8i8 and bf16 MHA tokenization with transpose_b=true" (Nov 21, 2024)
bbd607d  Revert "[WIP] Change splitM heuristic" (Nov 21, 2024)
504dbb2  Revert "Correct MHA tokenization" (Nov 21, 2024)
01115d3  Conservatively extend SplitDimensionM::get_splited_dimensions (Nov 21, 2024)
d21a099  Revert changes in BF16 tests (Nov 21, 2024)
6df0b31  Finilize snippets tests (Nov 21, 2024)
@@ -0,0 +1,53 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/loop_info.hpp"
#include "snippets/lowered/pass/runtime_optimizer.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {
/**
* @class MHAParallelWAOptimizer
* @brief Optimizes the dynamic MHA execution by increasing the parallel work amount: Brgemm's "M" dimension is divided into
* "parallel_m" and "kernel_m". Uses heuristics from snippets::pass::SplitDimensionM for dimension splitting.
* The optimizer performs the following steps:
* - Identifies applicable Brgemm operations within the LinearIR.
* - Finds parameters whose shapes and layouts need to be adjusted after the split.
* - Determines loops that should be adjusted.
*/
class MHAParallelWAOptimizer : public lowered::pass::RuntimeOptimizer {
public:
MHAParallelWAOptimizer() = default;
MHAParallelWAOptimizer(const lowered::LinearIRCPtr& linear_ir, const RuntimeConfigurator* configurator);

bool run(const lowered::LinearIR& linear_ir) override;
bool applicable() const override { return !m_loops_to_split.empty(); }

private:
static std::unordered_set<lowered::ExpressionPtr> find_applicable_brgemms(const lowered::LinearIRCPtr& linear_ir);
static std::unordered_set<size_t> find_unsqueezed_params(
const lowered::LinearIRCPtr& linear_ir,
const std::unordered_set<lowered::ExpressionPtr>& brgemms);
static std::vector<lowered::ExpandedLoopInfoPtr> find_loops_to_split(
const lowered::LinearIRCPtr& linear_ir,
const std::unordered_set<size_t>& unsqueezed_params);

std::vector<lowered::ExpandedLoopInfoPtr> m_loops_to_split{};
std::unordered_set<size_t> m_unsqueezed_params{};
std::vector<std::vector<size_t>> m_optimized_layouts{};
std::vector<size_t> m_dim_M_idces{};
size_t m_concurrency = 0;

static const size_t m_dim_M_idx;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
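
To make the split concrete, here is a minimal sketch of the underlying idea: given "M", the available concurrency, and the work amount already contributed by the batch dimensions, pick a divisor of "M" as "parallel_m" and keep the quotient as "kernel_m". The helper name and the divisor search below are simplifications invented for this example; the real heuristics live in snippets::pass::SplitDimensionM.

#include <algorithm>
#include <cstddef>
#include <utility>

// Returns {parallel_m, kernel_m} with parallel_m * kernel_m == m.
// Simplified sketch: take the largest divisor of m that does not
// oversubscribe the threads once batch-level parallelism is accounted for.
std::pair<std::size_t, std::size_t> split_m(std::size_t m, std::size_t batch_work, std::size_t concurrency) {
    const std::size_t target = concurrency / std::max<std::size_t>(batch_work, 1);
    for (std::size_t parallel_m = std::min(target, m); parallel_m > 1; --parallel_m) {
        if (m % parallel_m == 0)
            return {parallel_m, m / parallel_m};
    }
    return {1, m};  // no beneficial split found
}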
16 changes: 16 additions & 0 deletions src/common/snippets/include/snippets/lowered/pass/pass.hpp
@@ -67,6 +67,21 @@ class Pass : public PassBase {
virtual bool run(lowered::LinearIR& linear_ir) = 0;
};

/**
* @interface ConstPass
* @brief Base class for LIR passes which are performed on a full LIR body but don't change it
* @ingroup snippets
*/
class ConstPass : public PassBase {
public:
/**
* @brief Apply the pass to the Linear IR
* @param linear_ir the target Linear IR
* @return status of the pass
*/
virtual bool run(const lowered::LinearIR& linear_ir) = 0;
};

/**
* @interface RangedPass
* @brief Base class for LIR passes which are performed on a range of a LIR body
@@ -114,6 +129,7 @@ class PassPipeline {
void register_positioned_passes(const std::vector<PositionedPassLowered>& pos_passes);

void run(lowered::LinearIR& linear_ir) const;
void run(const lowered::LinearIR& linear_ir) const;
void run(lowered::LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) const;

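As a sketch of how the new base class is meant to be used: a ConstPass subclass implements run(const LinearIR&) and can be executed through the new const PassPipeline::run overload. The pass below is invented for illustration and assumes LinearIR exposes const iterators (the constExprIt type above suggests it does); RTTI/registration macros are omitted for brevity.

// Hypothetical analysis-only pass: reads the LIR, never mutates it.
class CountExpressionsPass : public ov::snippets::lowered::pass::ConstPass {
public:
    bool run(const ov::snippets::lowered::LinearIR& linear_ir) override {
        std::size_t count = 0;
        for (auto it = linear_ir.cbegin(); it != linear_ir.cend(); ++it)
            ++count;  // e.g. collect statistics, validate invariants, serialize
        return false;  // a const pass reports "no changes" by construction
    }
};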
@@ -0,0 +1,52 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/pass/pass.hpp"
#include "snippets/runtime_configurator.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {
/**
* @class RuntimeOptimizer
* @brief Base class for runtime optimizers that operate on LinearIR and RuntimeConfigurator during
* RuntimeConfigurator::update stage.
*/
class RuntimeOptimizer : public ConstPass {
public:
RuntimeOptimizer() = default;
RuntimeOptimizer(const RuntimeConfigurator* configurator) : m_configurator(configurator) {
OPENVINO_ASSERT(configurator, "RuntimeConfigurator mustn't be nullptr");
}
/**
* @brief Defines whether this pass is applicable. If it is not, its registration in the pass pipeline can be skipped.
*/
virtual bool applicable() const = 0;

/**
* @brief Creates an instance of the specified pass type and checks if it is applicable.
* If the pass is applicable, it is registered in the provided pipeline.
* @param pipeline The pipeline in which the pass should be registered.
* @param args The arguments to be forwarded to the pass constructor.
*/
template <typename OptimizerType, typename... Args, typename = typename std::enable_if<std::is_base_of<RuntimeOptimizer, OptimizerType>::value>::type>
static void register_if_applicable(PassPipeline& pipeline, Args&&... args) {
auto pass = std::make_shared<OptimizerType>(std::forward<Args>(args)...);
if (pass->applicable()) {
pipeline.register_pass(pass);
}
}

protected:
const RuntimeConfigurator* m_configurator = nullptr;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
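
A usage sketch for register_if_applicable: the wrapper function below is hypothetical, but the MHAParallelWAOptimizer constructor signature matches the header shown earlier, so this is roughly how a configurator would populate its optimizer pipelines.

// Illustrative only: how a RuntimeConfigurator might register an optimizer.
void init_runtime_optimizers(const ov::snippets::lowered::LinearIRCPtr& linear_ir,
                             ov::snippets::RuntimeConfigurator* configurator,
                             ov::snippets::lowered::pass::PassPipeline& pipeline) {
    using ov::snippets::lowered::pass::RuntimeOptimizer;
    using ov::snippets::lowered::pass::MHAParallelWAOptimizer;
    // The optimizer is constructed once; if applicable() returns false
    // (no loops to split), it never enters the pipeline, so
    // RuntimeConfigurator::update pays no per-inference cost for it.
    RuntimeOptimizer::register_if_applicable<MHAParallelWAOptimizer>(pipeline, linear_ir, configurator);
}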
@@ -16,9 +16,9 @@ namespace pass {
* @brief Base class for LinearIR serialization passes
* @ingroup snippets
*/
class SerializeBase : public Pass {
class SerializeBase : public ConstPass {
public:
OPENVINO_RTTI("SerializeBase", "Pass")
OPENVINO_RTTI("SerializeBase", "ConstPass")
SerializeBase(const std::string& xml_path);

protected:
@@ -22,12 +22,7 @@ class SerializeControlFlow : public SerializeBase {
OPENVINO_RTTI("SerializeControlFlow", "Pass", SerializeBase)
SerializeControlFlow(const std::string& xml_path, bool update_dynamic_ops = false) :
SerializeBase(xml_path), m_update_dynamic_ops{update_dynamic_ops} {}

bool run(LinearIR& linear_ir) override {
return run(const_cast<const LinearIR&>(linear_ir));
}
// We need a const method to run from functions that can't change LIR
bool run(const LinearIR& linear_ir);
bool run(const LinearIR& linear_ir) override;

private:
const bool m_update_dynamic_ops = false;
@@ -23,12 +23,7 @@ class SerializeDataFlow : public SerializeBase {
public:
OPENVINO_RTTI("SerializeDataFlow", "Pass", SerializeBase)
SerializeDataFlow(const std::string& xml_path) : SerializeBase(xml_path) {}

bool run(LinearIR& linear_ir) override {
return run(const_cast<const LinearIR&>(linear_ir));
}
// We need a const method to run from functions that can't change LIR
bool run(const LinearIR& linear_ir);
bool run(const LinearIR& linear_ir) override;
};

} // namespace pass
123 changes: 62 additions & 61 deletions src/common/snippets/include/snippets/runtime_configurator.hpp
@@ -4,9 +4,9 @@

#pragma once

#include "snippets/kernel_executor_table.hpp"
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/loop_info.hpp"
#include "snippets/kernel_executor_table.hpp"
#include "snippets/lowered/pass/pass.hpp"

namespace ov {
@@ -44,12 +44,15 @@ class RuntimeConfig {
size_t tensor_rank = 0;
size_t tile_rank = 0;

std::vector<ov::snippets::VectorDims> io_shapes = {};
std::vector<ov::snippets::VectorDims> io_layouts = {};
std::vector<ov::snippets::VectorDims> io_data_offsets = {};
ov::snippets::VectorDims master_shape = {};

size_t buffer_scratchpad_size = 0;
std::vector<size_t> buffer_cluster_offsets {};
KernelExecutorTablePtr kernel_executor_table = std::make_shared<ov::snippets::KernelExecutorTable>();
std::vector<ov::snippets::VectorDims> latest_shapes = {};
};

/**
@@ -83,18 +86,62 @@ class RuntimeConfigurator {
*/
void reset_kernel_executor_table() const;

protected:
// Getters for private members
std::shared_ptr<RuntimeConfig> get_config() const { return m_config; }
size_t get_io_num() const { return m_io_num; }
size_t get_in_num() const { return m_in_num; }
const std::vector<snippets::lowered::PortDescriptorPtr>& get_io_descs() const { return m_io_descs; }
const std::vector<size_t>& get_io_data_sizes() const { return m_io_data_sizes; }
const std::map<size_t, std::set<lowered::BufferExpressionPtr>>& get_dynamic_buffer_clusters() const { return m_dynamic_buffer_clusters; }

/**
* @brief Update RuntimeConfig based on LinearIR
* @brief Computes the offsets for each dimension of a tensor shape.
*
* This function calculates the offsets for each dimension of a tensor shape. Each offset is the distance (in elements)
* between consecutive elements along the corresponding dimension. If a dimension has size 1, the next dimension starts
* immediately and the stride is 0.
* @param shape The shape for offset computation.
* @param idx The index to get the corresponding offsets and io_data_sizes.
* @param idx_stride Defines the number of dimensions that should be skipped in the offsets vector.
*/
void compute_offsets(const ov::snippets::VectorDims& shape, size_t idx, size_t idx_stride) const;
struct UnifiedLoopInfoRtParams {
size_t work_amount = 0;
std::vector<int64_t> ptr_increments;
std::vector<int64_t> finalization_offsets;
};
/**
* @brief Retrieves the runtime parameters for a given UnifiedLoopInfo.
* @param unified_loop_info The UnifiedLoopInfo for which the runtime parameters are to be retrieved.
* @return A LoopInfoRuntimeParams object containing the runtime parameters.
*/
static UnifiedLoopInfoRtParams get_loop_runtime_params(const lowered::UnifiedLoopInfoPtr& unified_loop_info);
using LoopInfoRuntimeParamsMap = std::unordered_map<lowered::UnifiedLoopInfoPtr, UnifiedLoopInfoRtParams>;
/**
* @brief Update Loop information in LinearIR: Unified and ExpandedLoopInfo
* @param linear_ir LinearIR
* @todo Ticket 148891: Rewrite on PassPipeline
*/
virtual void update(const lowered::LinearIRCPtr& linear_ir);
static void update_loop_info(const lowered::LinearIRCPtr& linear_ir);
/**
* @brief Updates the ExpandedLoopInfo based on the initialized runtime parameters.
* @param expanded_loop_info The ExpandedLoopInfo to be updated.
* @param initialized_info_map A map containing the initialized runtime parameters for UnifiedLoopInfo.
*/
static void update_expanded_loop_info(const lowered::ExpandedLoopInfoPtr& expanded_loop_info,
LoopInfoRuntimeParamsMap& initializated_info_map);
/**
* @brief Update tensor rank based on master shape
* @param master_shape Master shape
*/
virtual void update_tensor_rank(const ov::snippets::VectorDims& master_shape);
virtual void update_tensor_rank(const ov::snippets::VectorDims& master_shape) const;

protected:
/**
* @brief Update RuntimeConfig based on LinearIR
* @param linear_ir LinearIR
* @todo Ticket 148891: Rewrite on PassPipeline
*/
virtual void update(const lowered::LinearIRCPtr& linear_ir);
/**
* @brief Allocate and initialize fields in RuntimeConfig and RuntimeConfigurator
* @param linear_ir LinearIR
@@ -120,21 +167,6 @@
* @param linear_ir LinearIR
*/
virtual void init_tensor_rank(const lowered::LinearIRCPtr& linear_ir) const;

struct UnifiedLoopInfoRtParams {
size_t work_amount = 0;
std::vector<int64_t> ptr_increments;
std::vector<int64_t> finalization_offsets;
};
static UnifiedLoopInfoRtParams get_loop_runtime_params(const lowered::UnifiedLoopInfoPtr& unified_loop_info);
using LoopInfoRuntimeParamsMap = std::unordered_map<lowered::UnifiedLoopInfoPtr, UnifiedLoopInfoRtParams>;
/**
* @brief Update Loop information in LinearIR: Unified and ExpandedLoopInfo
* @param linear_ir LinearIR
*/
static void update_loop_info(const lowered::LinearIRCPtr& linear_ir);
static void update_expanded_loop_info(const lowered::ExpandedLoopInfoPtr& expanded_loop_info,
LoopInfoRuntimeParamsMap& initializated_info_map);
/**
* @brief Update Buffer scratchpad size and offsets if needed
* Note: `update_loop_info` must be called before
@@ -146,8 +178,7 @@
* @param shapes shapes used in offsets computation
* @param layouts layouts used in offsets computation
*/
void update_data_offsets(const std::vector<ov::snippets::VectorDims>& shapes,
const std::vector<std::vector<size_t>>& layouts) const;
void update_data_offsets() const;
/**
* @brief Extract shapes from m_io_descs
*/
@@ -157,43 +188,6 @@
*/
std::vector<std::vector<size_t>> extract_layouts() const;

class MHAParallelWAOptimizer {
public:
MHAParallelWAOptimizer() = default;
MHAParallelWAOptimizer(const ov::snippets::lowered::LinearIRCPtr& linear_ir, RuntimeConfigurator* configurator);
/**
* @brief Checks if the current master shape can be optimized, and if yes, updates all the necessary runtime information
* @return status if the optimization is applied
*/
bool optimize();

private:
/**
* @brief Checks if optimizer is enabled
* @todo Ticket 148891: when RuntimeConfigurator::update will be rewritten on PassPipeline, this method should be removed
* We will not just register MHAParallelWAOptimizer in case if it is not needed
*/
bool enabled() const;

static std::unordered_set<snippets::lowered::ExpressionPtr> find_applicable_brgemms(const ov::snippets::lowered::LinearIRCPtr& linear_ir);
static std::unordered_set<size_t> find_unsqueezed_params(
const ov::snippets::lowered::LinearIRCPtr& linear_ir,
const std::unordered_set<snippets::lowered::ExpressionPtr>& brgemms);
static std::vector<ov::snippets::lowered::ExpandedLoopInfoPtr> find_loops_to_split(
const ov::snippets::lowered::LinearIRCPtr& linear_ir,
const std::unordered_set<size_t>& unsqueezed_params);

RuntimeConfigurator* configurator = nullptr;

std::vector<ov::snippets::lowered::ExpandedLoopInfoPtr> loops_to_split{};
std::unordered_set<size_t> unsqueezed_params{};
std::vector<std::vector<size_t>> optimized_layouts{};
std::vector<size_t> m_dim_idces{};
size_t concurrency = 0;

static const size_t m_dim_idx;
} m_optimizer;

std::shared_ptr<RuntimeConfig> m_config = nullptr;

size_t m_io_num = 0;
@@ -203,7 +197,14 @@
// [cluster_id -> buffer expressions ]
std::map<size_t, std::set<lowered::BufferExpressionPtr>> m_dynamic_buffer_clusters = {};

std::vector<ov::snippets::VectorDims> m_latest_shapes = {};
// WA: until ticket 148891 is implemented, two pass pipelines for runtime optimizers are necessary, since different
// optimizers must be called at different pipeline stages:
// - Intermediate optimizers must be called right after `update_loop_info`
// - Final optimizers must be called after all other RuntimeConfigurator's update methods
// Once all updates are rewritten as a PassPipeline, PositionedPasses can be used to precisely define the place of
// the additional optimizers
lowered::pass::PassPipeline m_intermediate_optimizers;
lowered::pass::PassPipeline m_final_optimizers;
};

} // namespace snippets
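
A worked illustration of the stride rule documented for compute_offsets above: offsets are dense element strides, except that size-1 dimensions get stride 0 so broadcast inputs do not advance. The standalone helper below (its name and the element-based units are assumptions for the example) reproduces that rule.

#include <cstddef>
#include <vector>

// Dense strides in elements; dimensions of size 1 get stride 0.
std::vector<std::size_t> dense_offsets(const std::vector<std::size_t>& shape) {
    std::vector<std::size_t> offsets(shape.size(), 0);
    std::size_t stride = 1;
    for (std::size_t i = shape.size(); i-- > 0;) {
        offsets[i] = (shape[i] == 1) ? 0 : stride;
        stride *= shape[i];
    }
    return offsets;
}

// Example: shape {2, 1, 3} -> offsets {3, 0, 1}: one step along dim 0 skips
// 3 elements, dim 1 is broadcast (stride 0), dim 2 is contiguous (stride 1).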
@@ -16,7 +16,7 @@ namespace snippets {

class LIRPassDump {
public:
explicit LIRPassDump(lowered::LinearIR& linear_ir, std::string pass_name)
explicit LIRPassDump(const lowered::LinearIR& linear_ir, std::string pass_name)
: linear_ir(linear_ir), pass_name(std::move(pass_name)), debug_config(linear_ir.get_config().debug_config) {
dump("_in");
}
@@ -44,7 +44,7 @@ class LIRPassDump {
num++;
}

lowered::LinearIR& linear_ir;
const lowered::LinearIR& linear_ir;
const std::string pass_name;
const DebugCapsConfig& debug_config;
};
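
Since the dumper now takes a const reference, it can be used from ConstPass::run as well. A usage sketch follows; the pass name is arbitrary, and the "_out" dump on destruction is presumed from the usual RAII pattern rather than shown in this excerpt.

// Illustrative scoped dump around a read-only pass.
void inspect(const ov::snippets::lowered::LinearIR& linear_ir) {
    LIRPassDump dump(linear_ir, "MyConstPass");  // constructor emits the "_in" dump
    // ... read-only analysis of linear_ir ...
}  // presumably the matching "_out" dump is emitted when `dump` is destroyed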