[GPU][TRANSFORMATIONS] Rope for flux (#27719)
### Details:
- Enabled RoPE op fusion for the FLUX.1 model on GPU to improve performance.
- Added a new RoPE mode that covers the case where coordinates and angles are interleaved and the sin/cos tables are separate inputs. It is activated when `is_interleaved=true` and `output_trans0213=false` (see the sketch below).
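For reference, a minimal sketch of the elementwise math this mode implements, assuming the data is laid out as `[batch, head_cnt, seq_len, head_size]` with the coordinate pairs interleaved along the last axis and the cos/sin tables holding each angle twice (once per pair element); the function name is illustrative and not part of this change:

```cpp
#include <cstddef>

// x, y: one head_size-long row for a single (batch, head, position);
// cos_t, sin_t: the corresponding table rows for that position.
void rope_interleaved_ref(const float* x, const float* cos_t, const float* sin_t,
                          float* y, size_t head_size) {
    for (size_t i = 0; i < head_size; i += 2) {
        // cos_t[i] == cos_t[i + 1] and sin_t[i] == sin_t[i + 1] (each angle is repeated per pair)
        y[i]     = x[i] * cos_t[i]     - x[i + 1] * sin_t[i];
        y[i + 1] = x[i + 1] * cos_t[i + 1] + x[i] * sin_t[i + 1];
    }
}
```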

Signed-off-by: Vladimir Paramuzov <[email protected]>
vladimir-paramuzov authored Nov 26, 2024
1 parent 833c034 commit 2b7f48e
Showing 18 changed files with 436 additions and 12 deletions.
@@ -23,13 +23,14 @@ class TRANSFORMATIONS_API RoPE : public Op {
struct Config {
size_t slice_start = 0; // slice inner-most dimensions of input
size_t slice_stop = 0;
bool input_trans0213 = false; // transpose input dim 1&2
bool is_interleaved = false; // interleaved mode, implies trans0213 happens after RoPE
size_t rotary_ndims = 0; // dimensions to be embedded (d in the description)
bool is_chatglm = false; // chatglm is special and overrides other settings
bool support_2d_rope = false; // 2d rope mode, supports 2-dimensional rope which is independent of batch and
// each head. Change input order to [batch, head_cnt, 4608] to support 2d rope
bool is_qwen = false; // Qwen is special and overrides other settings
bool input_trans0213 = false; // transpose input dim 1&2
bool output_trans0213 = false; // implies trans0213 happens after RoPE
bool is_interleaved = false; // coordinates are interleaved
size_t rotary_ndims = 0; // dimensions to be embedded (d in the description)
bool is_chatglm = false; // chatglm is special and overrides other settings
bool support_2d_rope = false; // 2d rope mode, supports 2-dimensional rope which is independent of batch and
// each head. Change input order to [batch, head_cnt, 4608] to support 2d rope
bool is_qwen = false; // Qwen is special and overrides other settings
size_t head_cnt = 0;
size_t head_size = 0;
int gather_position_arg_id =
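To make the intent of the split flags concrete, a hedged sketch of how they combine for the two interleaved cases touched by this change (the values mirror the fusion passes further down in this diff):

```cpp
// GPT-J style: RoPE runs on [B,L,H,S] and the 0213 transpose is applied to the output.
op::internal::RoPE::Config gptj_cfg;
gptj_cfg.is_interleaved = true;
gptj_cfg.output_trans0213 = true;

// FLUX.1 style: the input is already [B,H,L,S]; coordinates are interleaved, no transpose needed.
op::internal::RoPE::Config flux_cfg;
flux_cfg.is_interleaved = true;
flux_cfg.output_trans0213 = false;
```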
@@ -12,6 +12,7 @@ namespace pass {

class TRANSFORMATIONS_API RoPEFusion;
class TRANSFORMATIONS_API RoPEFusionGPTNEOX;
class TRANSFORMATIONS_API RoPEFusionFlux;
class TRANSFORMATIONS_API RoPEFusionGPTJ;
class TRANSFORMATIONS_API RoPEFusionChatGLM;
class TRANSFORMATIONS_API RoPEFusionQwen;
@@ -29,6 +30,12 @@ class ov::pass::RoPEFusionGPTNEOX : public ov::pass::MatcherPass {
RoPEFusionGPTNEOX();
};

class ov::pass::RoPEFusionFlux : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("RoPEFusionFlux", "0");
RoPEFusionFlux();
};

class ov::pass::RoPEFusionGPTJ : public ov::pass::MatcherPass {
public:
OPENVINO_RTTI("RoPEFusionGPTJ", "0");
@@ -85,6 +92,7 @@ class ov::pass::RoPEFusion : public ov::pass::GraphRewrite {
public:
OPENVINO_RTTI("RoPEFusion", "0");
RoPEFusion(bool support_2d_rope = false) {
add_matcher<ov::pass::RoPEFusionFlux>();
add_matcher<ov::pass::RoPEFusionGPTNEOX>();
add_matcher<ov::pass::RoPEFusionGPTJ>();
// optional heads & tails are fused in separate matcher pass,
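For context, a minimal sketch (not part of this diff; the include path is an assumption) of how the composite RoPEFusion pass is typically run over a model through `ov::pass::Manager`:

```cpp
#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp"  // assumed path

void fuse_rope(const std::shared_ptr<ov::Model>& model) {
    ov::pass::Manager manager;
    // RoPEFusion is a GraphRewrite; with this change it also tries RoPEFusionFlux on every model.
    manager.register_pass<ov::pass::RoPEFusion>();
    manager.run_passes(model);
}
```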
@@ -76,7 +76,7 @@ void RoPE::validate_and_infer_types() {
if (m_config.input_trans0213) {
// transpose 0213 ([B,L,H,S]=>[B,H,L,S]) happens before RoPE
std::swap(input_pshape[2], input_pshape[1]);
} else if (m_config.is_interleaved) {
} else if (m_config.output_trans0213) {
// transpose 0213 ([B,L,H,S]=>[B,H,L,S]) happens after RoPE
std::swap(input_pshape[2], input_pshape[1]);
}
@@ -90,6 +90,7 @@ bool RoPE::visit_attributes(ov::AttributeVisitor& visitor) {
visitor.on_attribute("slice_start", m_config.slice_start);
visitor.on_attribute("slice_stop", m_config.slice_stop);
visitor.on_attribute("input_trans0213", m_config.input_trans0213);
visitor.on_attribute("output_trans0213", m_config.output_trans0213);
visitor.on_attribute("is_interleaved", m_config.is_interleaved);
visitor.on_attribute("rotary_ndims", m_config.rotary_ndims);
visitor.on_attribute("is_chatglm", m_config.is_chatglm);
@@ -23,6 +23,94 @@

using namespace ov::gen_pattern;

ov::pass::RoPEFusionFlux::RoPEFusionFlux() {
MATCHER_SCOPE(RoPEFusionFlux);
// x[?,24,?,128]
// x1 = reshape(x, [?,24,?,64,2])
// x1_0, x1_1 = split(x1, -1)
// x2 = concat(x1_1 * (-1), x1_0, -1)
// x3 = reshape(x2, [?,24,?,128])
// y1 = x * t_cos
// y2 = x3 * t_sin
// y = y1 + y2
auto x = makePattern(ov::Rank(4));
auto t_cos = makePattern(ov::Rank(4));
auto t_sin = makePattern(ov::Rank(4));

auto num_heads = ov::gen_pattern::Symbol("num_heads");
auto head_size = ov::gen_pattern::Symbol("head_size");

auto x1_target_shape = makeConst({0, num_heads, 0, -1, 2});
auto x1 = makePattern<opset1::Reshape>({x, x1_target_shape}, {{"special_zero", true}});
auto split = makePattern<opset1::Split>({x1, -1}, {{"num_splits", 2}});
split->set_output_size(2);

// 3 versions of multiply by -1, depending on which transformations were executed prior to this pass
auto x1_1_neg_1 = makePattern<opset1::Multiply>({split->output(1), -1.0f}, {{"auto_broadcast", "numpy"}});

auto squeeze_2 = makePattern<opset1::Squeeze>({split->output(1), -1});
auto x1_1_neg_2 = makePattern<opset1::Multiply>({squeeze_2, -1.0f}, {{"auto_broadcast", "numpy"}});
auto unsqueeze_2 = makePattern<opset1::Unsqueeze>({x1_1_neg_2, -1});

auto x1_1_neg_3 = makePattern<opset1::Multiply>({split->output(1), -1.0f}, {{"auto_broadcast", "numpy"}});
auto squeeze_3 = makePattern<opset1::Squeeze>({x1_1_neg_3, -1});
auto unsqueeze_3 = makePattern<opset1::Unsqueeze>({squeeze_3, -1});

auto x2 = makePattern<opset1::Concat>({x1_1_neg_1 | unsqueeze_2 | unsqueeze_3, split->output(0)}, {{"axis", -1}});
auto x3_target_shape = makeConst({0, num_heads, 0, head_size});
auto x3 = makePattern<opset1::Reshape>({x2, x3_target_shape}, {{"special_zero", true}});

auto y1 = makePattern<opset1::Multiply>({x, t_cos}, {{"auto_broadcast", "numpy"}});
auto y2 = makePattern<opset1::Multiply>({x3, t_sin}, {{"auto_broadcast", "numpy"}});

auto y = makePattern<opset1::Add>({y1, y2}, {{"auto_broadcast", "numpy"}});
auto result = y;

matcher_pass_callback callback = [OV_CAPTURE_CPY_AND_THIS](ov::pass::pattern::Matcher& m) {
PatternValidator validator(m);
if (!validator) {
return false;
}

const auto& pattern_map = m.get_pattern_value_map();
auto root = m.get_match_root();

op::internal::RoPE::Config config;
config.head_cnt = static_cast<size_t>(validator["num_heads"]);
config.head_size = static_cast<size_t>(validator["head_size"]);
config.rotary_ndims = config.head_size;
config.is_interleaved = true;
config.output_trans0213 = false;

OutputVector new_args;
new_args.push_back(pattern_map.at(x));
new_args.push_back(pattern_map.at(t_cos));
new_args.push_back(pattern_map.at(t_sin));

auto old_node = root;
auto new_node = std::make_shared<op::internal::RoPE>(new_args, config);
new_node->set_friendly_name(old_node->get_friendly_name());
ov::copy_runtime_info({pattern_map.at(x1).get_node_shared_ptr(),
pattern_map.at(split).get_node_shared_ptr(),
pattern_map.at(x2).get_node_shared_ptr(),
pattern_map.at(x3).get_node_shared_ptr(),
pattern_map.at(y1).get_node_shared_ptr(),
pattern_map.at(y2).get_node_shared_ptr(),
pattern_map.at(result).get_node_shared_ptr()},
new_node);

ov::replace_node(old_node, new_node);

// this new node may be matched by the following additional matchers
register_new_node(new_node);

return true;
};

auto m = std::make_shared<ov::pass::pattern::Matcher>(result, matcher_name);
this->register_matcher(m, callback);
}

ov::pass::RoPEFusionGPTNEOX::RoPEFusionGPTNEOX() {
MATCHER_SCOPE(RoPEFusionGPTNEOX);

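As an illustration of what the new RoPEFusionFlux matcher above looks for, here is a hedged sketch (not part of this commit; include paths and helper names are assumptions) that builds the Flux-style subgraph from the pattern comments and runs the pass over it:

```cpp
#include <memory>

#include "openvino/core/model.hpp"
#include "openvino/opsets/opset1.hpp"
#include "openvino/pass/manager.hpp"
#include "transformations/common_optimizations/fuse_rotary_positional_embeddings.hpp"  // assumed path

std::shared_ptr<ov::Model> make_flux_rope_subgraph() {
    using namespace ov::opset1;
    auto x     = std::make_shared<Parameter>(ov::element::f32, ov::PartialShape{-1, 24, -1, 128});
    auto t_cos = std::make_shared<Parameter>(ov::element::f32, ov::PartialShape{1, 1, -1, 128});
    auto t_sin = std::make_shared<Parameter>(ov::element::f32, ov::PartialShape{1, 1, -1, 128});

    // x1 = reshape(x, [?,24,?,64,2]); x1_0, x1_1 = split(x1, -1)
    auto x1 = std::make_shared<Reshape>(x, Constant::create(ov::element::i64, ov::Shape{5}, {0, 24, 0, -1, 2}), true);
    auto split = std::make_shared<Split>(x1, Constant::create(ov::element::i64, ov::Shape{}, {-1}), 2);

    // x2 = concat(x1_1 * (-1), x1_0, -1); x3 = reshape(x2, [?,24,?,128])
    auto neg = std::make_shared<Multiply>(split->output(1), Constant::create(ov::element::f32, ov::Shape{}, {-1.0f}));
    auto x2 = std::make_shared<Concat>(ov::OutputVector{neg, split->output(0)}, -1);
    auto x3 = std::make_shared<Reshape>(x2, Constant::create(ov::element::i64, ov::Shape{4}, {0, 24, 0, 128}), true);

    // y = x * t_cos + x3 * t_sin
    auto y = std::make_shared<Add>(std::make_shared<Multiply>(x, t_cos), std::make_shared<Multiply>(x3, t_sin));
    return std::make_shared<ov::Model>(ov::OutputVector{y}, ov::ParameterVector{x, t_cos, t_sin});
}

void check_flux_rope_fusion() {
    auto model = make_flux_rope_subgraph();
    ov::pass::Manager manager;
    manager.register_pass<ov::pass::RoPEFusionFlux>();
    manager.run_passes(model);
    // After the pass, the whole subgraph is expected to collapse into a single internal RoPE node
    // with x, t_cos and t_sin as its three inputs.
}
```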
@@ -373,6 +461,7 @@ ov::pass::RoPEFusionGPTJ::RoPEFusionGPTJ() {
OutputVector new_args;
config.rotary_ndims = static_cast<size_t>(validator["ndims"]);

config.output_trans0213 = true;
config.is_interleaved = true;

// input is [B,L,H,S]