Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use TE dpa for grok mqa #84

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions praxis/layers/grok.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ def GrokStackedTransformerHParams(
combine_qkv=False,
bidirectional=False,
use_fp8=False,
use_te_dpa=False,
) -> pax_fiddle.Config[transformers.StackedTransformer]:
"""Common setup for Grok-1 Transformer layers.

Expand Down Expand Up @@ -168,6 +169,7 @@ def GrokStackedTransformerHParams(
p.transformer_layer_params_tpl.tr_atten_tpl = pax_fiddle.Config(
multi_query_attention.MultiQueryDotProductAttention,
num_kv_heads=attention_num_groups,
use_te_dpa=use_te_dpa,
)
tr_atten_tpl = p.transformer_layer_params_tpl.tr_atten_tpl
tr_atten_tpl.combine_qkv = False
Expand Down Expand Up @@ -225,6 +227,7 @@ def GrokUniTransformerLmHParams(
model_type=LanguageModelType.CAUSAL,
checkpoint_policy=AutodiffCheckpointType.SAVE_NOTHING,
use_fp8=False,
use_te_dpa=False,
) -> pax_fiddle.Config[transformer_models.TransformerLm]:
"""Common setup for Grok-1 Decoder-only Transformer Model.

Expand Down Expand Up @@ -328,6 +331,7 @@ def GrokUniTransformerLmHParams(
bidirectional=bidirectional,
moe_gating_embedding_level=moe_gating_embedding_level,
use_fp8=use_fp8,
use_te_dpa=use_te_dpa,
)
num_blocks = num_transformer_layers

Expand Down
70 changes: 46 additions & 24 deletions praxis/layers/multi_query_attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@

import math
from typing import Any, Callable, Mapping, Sequence

from absl import logging
from flax import linen as nn
import jax
from jax import numpy as jnp
Expand All @@ -31,7 +31,7 @@
from praxis.layers import base_ops
from praxis.layers import embedding_softmax
from praxis.layers import stochastics

from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention

WeightInit = base_layer.WeightInit
WeightHParams = base_layer.WeightHParams
Expand Down Expand Up @@ -215,6 +215,7 @@ class MultiQueryDotProductAttention(base_layer.BaseLayer):
scale_query_by_dim_per_head: bool = False
chunked_attn_num_seq_split: int = 1
local_window_size: tuple[int, int] | None = None
use_te_dpa: bool = False # Experimental way to use TE flash attention when can't use standard TE

# SPMD partition related params.
#
Expand Down Expand Up @@ -353,6 +354,20 @@ def project_input_kv(input_dim, dim_per_head):
self.create_child('post', post_proj_p)
self.create_child('qk_einsum', self.qk_einsum_tpl.clone())
self.create_child('pv_einsum', self.pv_einsum_tpl.clone())
self.dpa_layer = TEDotProductAttention(
head_dim=dim_per_head,
num_attention_heads=self.num_heads,
num_gqa_groups=self.num_kv_heads,
attn_mask_type='causal', # 'causal' or 'padding'
attn_bias_type='no_bias', # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
attention_dropout=0.,
dropout_rng_name='aqt',
dtype=jnp.bfloat16,
float32_logits=False,
qkv_layout='BSHD_BSHD_BSHD', # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
scale_factor=1.0/math.sqrt(self.num_heads),
transpose_batch_sequence=False
)

def _shard_bnh(self, x: JTensor) -> JTensor:
"""Shards tensors of shape [b, n, h].
Expand Down Expand Up @@ -889,29 +904,36 @@ def _rep_d(x):
else:
key_proj = self._shard_blnh(key_proj)
value_proj = self._shard_blnh(value_proj)
b, t, n, h = query_proj.shape
_, s, nk, _ = key_proj.shape
assert n % nk == 0
v_q = jnp.reshape(query_proj, (b, t, nk, n // nk, h))
if relative_bias is not None:
v_rb = jnp.reshape(relative_bias, (b, nk, n // nk, t, s))
else:
v_rb = None
with self._context_for_kv_vmap():
encoded, atten_probs = jax.vmap(
self._dot_atten,
in_axes=(2, 2, 2, None, 1),
out_axes=(2, 1),
)(
v_q,
key_proj,
value_proj,
atten_mask,
v_rb,
if self.use_te_dpa:
logging.warning(
'use_te_dpa is set to True, so TE dpa is used as an experimental way to use TE flash attention.'
)
encoded = self._shard_blnh(jnp.reshape(encoded, (b, t, n, h)))
if atten_probs is not None:
atten_probs = jnp.reshape(atten_probs, (b, t, n, s))
atten_probs = None
encoded = self.dpa_layer(query_proj, key_proj, value_proj)
else:
b, t, n, h = query_proj.shape
_, s, nk, _ = key_proj.shape
assert n % nk == 0
v_q = jnp.reshape(query_proj, (b, t, nk, n // nk, h))
if relative_bias is not None:
v_rb = jnp.reshape(relative_bias, (b, nk, n // nk, t, s))
else:
v_rb = None
with self._context_for_kv_vmap():
encoded, atten_probs = jax.vmap(
self._dot_atten,
in_axes=(2, 2, 2, None, 1),
out_axes=(2, 1),
)(
v_q,
key_proj,
value_proj,
atten_mask,
v_rb,
)
encoded = self._shard_blnh(jnp.reshape(encoded, (b, t, n, h)))
if atten_probs is not None:
atten_probs = jnp.reshape(atten_probs, (b, t, n, s))

# Post projection
encoded = self.post(encoded)
Expand Down