Skip to content

Commit

Permalink
[Snippets] Move BrgemmCopyB repacking logic outside the Subgraph (ope…
Browse files Browse the repository at this point in the history
…nvinotoolkit#27007)

### Details:
Currently, CopyB repacking is always performed inside Subgraph. In the
case when batch on B Matmul input is significantly smaller than batch on
A Matmul input, and parallel work amount is big enough, this may lead to
ineffective execution, since repacking for B input is performed in each
parallel task whereas only one repacking iteration for each B batch is
enough.

Within this PR, CopyB repacking is moved outside the snippets kernel and
performed via common reorder primitive just before the snippets kernel
execution.

### Tickets:
 - *CVS-154383*
  • Loading branch information
v-Golubev authored and NishantPrabhuFujitsu committed Nov 26, 2024
1 parent 631ad3d commit 8cc715c
Show file tree
Hide file tree
Showing 34 changed files with 969 additions and 465 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/loop_info.hpp"
#include "snippets/lowered/pass/runtime_optimizer.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {
/**
* @class MHAParallelWAOptimizer
* @brief Optimizes the dynamic MHA execution increasing parallel work amount dy dividing Brgemm's "M" dimension to "parallel_m"
* and "kernel_m". Uses heuristics from snippets::pass::SplitDimensionM for dimension splitting.
* The optimizer performs the following steps:
* - Identifies applicable Brgemm operations within the LinearIR.
* - Finds parameters whose shapes and layouts need to be adjusted after the split.
* - Determines loops that should be adjusted.
*/
class MHAParallelWAOptimizer : public lowered::pass::RuntimeOptimizer {
public:
MHAParallelWAOptimizer() = default;
MHAParallelWAOptimizer(const lowered::LinearIRCPtr& linear_ir, const RuntimeConfigurator* configurator);

bool run(const lowered::LinearIR& linear_ir) override;
bool applicable() const override { return !m_loops_to_split.empty(); }

private:
static std::unordered_set<lowered::ExpressionPtr> find_applicable_brgemms(const lowered::LinearIRCPtr& linear_ir);
static std::unordered_set<size_t> find_unsqueezed_params(
const lowered::LinearIRCPtr& linear_ir,
const std::unordered_set<lowered::ExpressionPtr>& brgemms);
static std::vector<lowered::ExpandedLoopInfoPtr> find_loops_to_split(
const lowered::LinearIRCPtr& linear_ir,
const std::unordered_set<size_t>& unsqueezed_params);

std::vector<lowered::ExpandedLoopInfoPtr> m_loops_to_split{};
std::unordered_set<size_t> m_unsqueezed_params{};
std::vector<std::vector<size_t>> m_optimized_layouts{};
std::vector<size_t> m_dim_M_idces{};
size_t m_concurrency = 0;

static const size_t m_dim_M_idx;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
16 changes: 16 additions & 0 deletions src/common/snippets/include/snippets/lowered/pass/pass.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,21 @@ class Pass : public PassBase {
virtual bool run(lowered::LinearIR& linear_ir) = 0;
};

/**
* @interface ConstPass
* @brief Base class for LIR passes which are performed on a full LIR body but doesn't change it
* @ingroup snippets
*/
class ConstPass : public PassBase {
public:
/**
* @brief Apply the pass to the Linear IR
* @param linear_ir the target Linear IR
* @return status of the pass
*/
virtual bool run(const lowered::LinearIR& linear_ir) = 0;
};

/**
* @interface RangedPass
* @brief Base class for LIR passes which are performed on a range of a LIR body
Expand Down Expand Up @@ -114,6 +129,7 @@ class PassPipeline {
void register_positioned_passes(const std::vector<PositionedPassLowered>& pos_passes);

void run(lowered::LinearIR& linear_ir) const;
void run(const lowered::LinearIR& linear_ir) const;
void run(lowered::LinearIR& linear_ir, lowered::LinearIR::constExprIt begin, lowered::LinearIR::constExprIt end) const;

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
// Copyright (C) 2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//

#pragma once

#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/pass/pass.hpp"
#include "snippets/runtime_configurator.hpp"

namespace ov {
namespace snippets {
namespace lowered {
namespace pass {
/**
* @class RuntimeOptimizer
* @brief Base class for runtime optimizers that operate on LinearIR and RuntimeConfigurator during
* RuntimeConfigurator::update stage.
*/
class RuntimeOptimizer : public ConstPass {
public:
RuntimeOptimizer() = default;
RuntimeOptimizer(const RuntimeConfigurator* configurator) : m_configurator(configurator) {
OPENVINO_ASSERT(configurator, "RuntimeConfigurator musn't be nullptr");
}
/**
* @brief Defines if this pass is applicable. If it is not applicable, its registration in pass pipeline can be skipped.
*/
virtual bool applicable() const = 0;

/**
* @brief Creates an instance of the specified pass type and checks if it is applicable.
* If the pass is applicable, it is registered in the provided pipeline.
* @param pipeline The pipeline in which the pass should be registered.
* @param args The arguments to be forwarded to the pass constructor.
*/
template <typename OptimizerType, typename... Args, typename = std::enable_if<std::is_base_of<RuntimeOptimizer, OptimizerType>::value>>
static void register_if_applicable(PassPipeline& pipeline, Args&&... args) {
auto pass = std::make_shared<OptimizerType>(std::forward<Args>(args)...);
if (pass->applicable()) {
pipeline.register_pass(pass);
}
}

protected:
const RuntimeConfigurator* m_configurator = nullptr;
};

} // namespace pass
} // namespace lowered
} // namespace snippets
} // namespace ov
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,9 @@ namespace pass {
* @brief Base class for LinearIR serialization passes
* @ingroup snippets
*/
class SerializeBase : public Pass {
class SerializeBase : public ConstPass {
public:
OPENVINO_RTTI("SerializeBase", "Pass")
OPENVINO_RTTI("SerializeBase", "ConstPass")
SerializeBase(const std::string& xml_path);

protected:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,12 +22,7 @@ class SerializeControlFlow : public SerializeBase {
OPENVINO_RTTI("SerializeControlFlow", "Pass", SerializeBase)
SerializeControlFlow(const std::string& xml_path, bool update_dynamic_ops = false) :
SerializeBase(xml_path), m_update_dynamic_ops{update_dynamic_ops} {}

bool run(LinearIR& linear_ir) override {
return run(const_cast<const LinearIR&>(linear_ir));
}
// We need a const method to run from functions that can't change LIR
bool run(const LinearIR& linear_ir);
bool run(const LinearIR& linear_ir) override;

private:
const bool m_update_dynamic_ops = false;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,7 @@ class SerializeDataFlow : public SerializeBase {
public:
OPENVINO_RTTI("SerializeDataFlow", "Pass", SerializeBase)
SerializeDataFlow(const std::string& xml_path) : SerializeBase(xml_path) {}

bool run(LinearIR& linear_ir) override {
return run(const_cast<const LinearIR&>(linear_ir));
}
// We need a const method to run from functions that can't change LIR
bool run(const LinearIR& linear_ir);
bool run(const LinearIR& linear_ir) override;
};

} // namespace pass
Expand Down
123 changes: 62 additions & 61 deletions src/common/snippets/include/snippets/runtime_configurator.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,9 @@

#pragma once

#include "snippets/kernel_executor_table.hpp"
#include "snippets/lowered/linear_ir.hpp"
#include "snippets/lowered/loop_info.hpp"
#include "snippets/kernel_executor_table.hpp"
#include "snippets/lowered/pass/pass.hpp"

namespace ov {
Expand Down Expand Up @@ -44,12 +44,15 @@ class RuntimeConfig {
size_t tensor_rank = 0;
size_t tile_rank = 0;

std::vector<ov::snippets::VectorDims> io_shapes = {};
std::vector<ov::snippets::VectorDims> io_layouts = {};
std::vector<ov::snippets::VectorDims> io_data_offsets = {};
ov::snippets::VectorDims master_shape = {};

size_t buffer_scratchpad_size = 0;
std::vector<size_t> buffer_cluster_offsets {};
KernelExecutorTablePtr kernel_executor_table = std::make_shared<ov::snippets::KernelExecutorTable>();
std::vector<ov::snippets::VectorDims> latest_shapes = {};
};

/**
Expand Down Expand Up @@ -83,18 +86,62 @@ class RuntimeConfigurator {
*/
void reset_kernel_executor_table() const;

protected:
// Getters for private members
std::shared_ptr<RuntimeConfig> get_config() const { return m_config; }
size_t get_io_num() const { return m_io_num; }
size_t get_in_num() const { return m_in_num; }
const std::vector<snippets::lowered::PortDescriptorPtr>& get_io_descs() const { return m_io_descs; }
const std::vector<size_t>& get_io_data_sizes() const { return m_io_data_sizes; }
const std::map<size_t, std::set<lowered::BufferExpressionPtr>>& get_dynamic_buffer_clusters() const { return m_dynamic_buffer_clusters; }

/**
* @brief Update RuntimeConfig based on LinearIR
* @brief Computes the offsets for each dimension of a tensor shape.
*
* This function calculates the offsets for each dimension of a tensor shape, which represent the distance between
* consecutive elements of the corresponding dimension. If a dimension size is 1, the next dimension starts
* immediately, and the stride is 0.
* @param shape The shape for offset computation.
* @param idx The index to get the corresponding offsets and io_data_sizes.
* @param idx_stride Defines the number of dimensions that should be skipped in the offsets vector.
*/
void compute_offsets(const ov::snippets::VectorDims& shape, size_t idx, size_t idx_stride) const;
struct UnifiedLoopInfoRtParams {
size_t work_amount = 0;
std::vector<int64_t> ptr_increments;
std::vector<int64_t> finalization_offsets;
};
/**
* @brief Retrieves the runtime parameters for a given UnifiedLoopInfo.
* @param unified_loop_info The UnifiedLoopInfo for which the runtime parameters are to be retrieved.
* @return A LoopInfoRuntimeParams object containing the runtime parameters.
*/
static UnifiedLoopInfoRtParams get_loop_runtime_params(const lowered::UnifiedLoopInfoPtr& unified_loop_info);
using LoopInfoRuntimeParamsMap = std::unordered_map<lowered::UnifiedLoopInfoPtr, UnifiedLoopInfoRtParams>;
/**
* @brief Update Loop information in LinearIR: Unified and ExpandedLoopInfo
* @param linear_ir LinearIR
* @todo Ticket 148891: Rewrite on PassPipeline
*/
virtual void update(const lowered::LinearIRCPtr& linear_ir);
static void update_loop_info(const lowered::LinearIRCPtr& linear_ir);
/**
* @brief Updates the ExpandedLoopInfo based on the initialized runtime parameters.
* @param expanded_loop_info The ExpandedLoopInfo to be updated.
* @param initialized_info_map A map containing the initialized runtime parameters for UnifiedLoopInfo.
*/
static void update_expanded_loop_info(const lowered::ExpandedLoopInfoPtr& expanded_loop_info,
LoopInfoRuntimeParamsMap& initializated_info_map);
/**
* @brief Update tensor rank based on master shape
* @param master_shape Master shape
*/
virtual void update_tensor_rank(const ov::snippets::VectorDims& master_shape);
virtual void update_tensor_rank(const ov::snippets::VectorDims& master_shape) const;

protected:
/**
* @brief Update RuntimeConfig based on LinearIR
* @param linear_ir LinearIR
* @todo Ticket 148891: Rewrite on PassPipeline
*/
virtual void update(const lowered::LinearIRCPtr& linear_ir);
/**
* @brief Allocate and intialize fields in RuntimeConfig and RuntimeConfigurator
* @param linear_ir LinearIR
Expand All @@ -120,21 +167,6 @@ class RuntimeConfigurator {
* @param linear_ir LinearIR
*/
virtual void init_tensor_rank(const lowered::LinearIRCPtr& linear_ir) const;

struct UnifiedLoopInfoRtParams {
size_t work_amount = 0;
std::vector<int64_t> ptr_increments;
std::vector<int64_t> finalization_offsets;
};
static UnifiedLoopInfoRtParams get_loop_runtime_params(const lowered::UnifiedLoopInfoPtr& unified_loop_info);
using LoopInfoRuntimeParamsMap = std::unordered_map<lowered::UnifiedLoopInfoPtr, UnifiedLoopInfoRtParams>;
/**
* @brief Update Loop informations in LinearIR: Unified and ExpandedLoopInfo
* @param linear_ir LinearIR
*/
static void update_loop_info(const lowered::LinearIRCPtr& linear_ir);
static void update_expanded_loop_info(const lowered::ExpandedLoopInfoPtr& expanded_loop_info,
LoopInfoRuntimeParamsMap& initializated_info_map);
/**
* @brief Update Buffer scratchpad size and offsets if needed
* Note: `update_loop_info` must be called before
Expand All @@ -146,8 +178,7 @@ class RuntimeConfigurator {
* @param shapes shapes used in offsets computation
* @param layouts layouts used in offsets computation
*/
void update_data_offsets(const std::vector<ov::snippets::VectorDims>& shapes,
const std::vector<std::vector<size_t>>& layouts) const;
void update_data_offsets() const;
/**
* @brief Extract shapes from m_io_descs
*/
Expand All @@ -157,43 +188,6 @@ class RuntimeConfigurator {
*/
std::vector<std::vector<size_t>> extract_layouts() const;

class MHAParallelWAOptimizer {
public:
MHAParallelWAOptimizer() = default;
MHAParallelWAOptimizer(const ov::snippets::lowered::LinearIRCPtr& linear_ir, RuntimeConfigurator* configurator);
/**
* @brief Checks if the current master shape can be optimized, and if yes, updates all the necessary runtime information
* @return status if the optimization is applied
*/
bool optimize();

private:
/**
* @brief Checks if optimizer is enabled
* @todo Ticket 148891: when RuntimeConfigurator::update will be rewritten on PassPipeline, this method should be removed
* We will not just register MHAParallelWAOptimizer in case if it is not needed
*/
bool enabled() const;

static std::unordered_set<snippets::lowered::ExpressionPtr> find_applicable_brgemms(const ov::snippets::lowered::LinearIRCPtr& linear_ir);
static std::unordered_set<size_t> find_unsqueezed_params(
const ov::snippets::lowered::LinearIRCPtr& linear_ir,
const std::unordered_set<snippets::lowered::ExpressionPtr>& brgemms);
static std::vector<ov::snippets::lowered::ExpandedLoopInfoPtr> find_loops_to_split(
const ov::snippets::lowered::LinearIRCPtr& linear_ir,
const std::unordered_set<size_t>& unsqueezed_params);

RuntimeConfigurator* configurator = nullptr;

std::vector<ov::snippets::lowered::ExpandedLoopInfoPtr> loops_to_split{};
std::unordered_set<size_t> unsqueezed_params{};
std::vector<std::vector<size_t>> optimized_layouts{};
std::vector<size_t> m_dim_idces{};
size_t concurrency = 0;

static const size_t m_dim_idx;
} m_optimizer;

std::shared_ptr<RuntimeConfig> m_config = nullptr;

size_t m_io_num = 0;
Expand All @@ -203,7 +197,14 @@ class RuntimeConfigurator {
// [cluster_id -> buffer expressions ]
std::map<size_t, std::set<lowered::BufferExpressionPtr>> m_dynamic_buffer_clusters = {};

std::vector<ov::snippets::VectorDims> m_latest_shapes = {};
// WA: until ticket 148891 is not implemented, 2 pass pipelines for runtime optimizers are necessary since different
// optimizers must be called at different pipeline stages.
// - Intermediate optimizers must be called right after `update_loop_info`
// - Final optimizers must be called after all other RuntimeConfigurator's update methods
// When all updates will be rewritten on PassPipeline, PositionedPasses can be used to precisely define the place of
// the additional optimizers
lowered::pass::PassPipeline m_intermediate_optimizers;
lowered::pass::PassPipeline m_final_optimizers;
};

} // namespace snippets
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ namespace snippets {

class LIRPassDump {
public:
explicit LIRPassDump(lowered::LinearIR& linear_ir, std::string pass_name)
explicit LIRPassDump(const lowered::LinearIR& linear_ir, std::string pass_name)
: linear_ir(linear_ir), pass_name(std::move(pass_name)), debug_config(linear_ir.get_config().debug_config) {
dump("_in");
}
Expand Down Expand Up @@ -44,7 +44,7 @@ class LIRPassDump {
num++;
}

lowered::LinearIR& linear_ir;
const lowered::LinearIR& linear_ir;
const std::string pass_name;
const DebugCapsConfig& debug_config;
};
Expand Down
Loading

0 comments on commit 8cc715c

Please sign in to comment.