diff --git a/praxis/layers/grok.py b/praxis/layers/grok.py
index f66d143f..d1effc49 100644
--- a/praxis/layers/grok.py
+++ b/praxis/layers/grok.py
@@ -59,6 +59,7 @@ def GrokStackedTransformerHParams(
     combine_qkv=False,
     bidirectional=False,
     use_fp8=False,
+    use_te_dpa=False,
 ) -> pax_fiddle.Config[transformers.StackedTransformer]:
   """Common setup for Grok-1 Transformer layers.

@@ -168,6 +169,7 @@ def GrokStackedTransformerHParams(
   p.transformer_layer_params_tpl.tr_atten_tpl = pax_fiddle.Config(
       multi_query_attention.MultiQueryDotProductAttention,
       num_kv_heads=attention_num_groups,
+      use_te_dpa=use_te_dpa,
   )
   tr_atten_tpl = p.transformer_layer_params_tpl.tr_atten_tpl
   tr_atten_tpl.combine_qkv = False
@@ -225,6 +227,7 @@ def GrokUniTransformerLmHParams(
     model_type=LanguageModelType.CAUSAL,
     checkpoint_policy=AutodiffCheckpointType.SAVE_NOTHING,
     use_fp8=False,
+    use_te_dpa=False,
 ) -> pax_fiddle.Config[transformer_models.TransformerLm]:
   """Common setup for Grok-1 Decoder-only Transformer Model.

@@ -328,6 +331,7 @@ def GrokUniTransformerLmHParams(
       bidirectional=bidirectional,
       moe_gating_embedding_level=moe_gating_embedding_level,
       use_fp8=use_fp8,
+      use_te_dpa=use_te_dpa,
   )
   num_blocks = num_transformer_layers

diff --git a/praxis/layers/multi_query_attention.py b/praxis/layers/multi_query_attention.py
index 23a074b6..ab08f28f 100644
--- a/praxis/layers/multi_query_attention.py
+++ b/praxis/layers/multi_query_attention.py
@@ -17,7 +17,7 @@
 import math
 from typing import Any, Callable, Mapping, Sequence
-
+from absl import logging
 from flax import linen as nn
 import jax
 from jax import numpy as jnp

@@ -31,7 +31,7 @@
 from praxis.layers import base_ops
 from praxis.layers import embedding_softmax
 from praxis.layers import stochastics
-
+from transformer_engine.jax.flax.transformer import DotProductAttention as TEDotProductAttention

 WeightInit = base_layer.WeightInit
 WeightHParams = base_layer.WeightHParams
@@ -215,6 +215,7 @@ class MultiQueryDotProductAttention(base_layer.BaseLayer):
   scale_query_by_dim_per_head: bool = False
   chunked_attn_num_seq_split: int = 1
   local_window_size: tuple[int, int] | None = None
+  use_te_dpa: bool = False  # Experimental way to use TE flash attention when standard TE can't be used.

   # SPMD partition related params.
   #
@@ -353,6 +354,20 @@ def project_input_kv(input_dim, dim_per_head):
     self.create_child('post', post_proj_p)
     self.create_child('qk_einsum', self.qk_einsum_tpl.clone())
     self.create_child('pv_einsum', self.pv_einsum_tpl.clone())
+    self.dpa_layer = TEDotProductAttention(
+        head_dim=dim_per_head,
+        num_attention_heads=self.num_heads,
+        num_gqa_groups=self.num_kv_heads,
+        attn_mask_type='causal',  # 'causal' or 'padding'
+        attn_bias_type='no_bias',  # 'no_bias', 'pre_scale_bias' or 'post_scale_bias'
+        attention_dropout=0.,
+        dropout_rng_name='aqt',
+        dtype=jnp.bfloat16,
+        float32_logits=False,
+        qkv_layout='BSHD_BSHD_BSHD',  # 'BS3HD', 'BSHD_BS2HD' or 'BSHD_BSHD_BSHD'
+        scale_factor=1.0/math.sqrt(self.num_heads),
+        transpose_batch_sequence=False
+    )

   def _shard_bnh(self, x: JTensor) -> JTensor:
     """Shards tensors of shape [b, n, h].
@@ -889,29 +904,36 @@ def _rep_d(x):
     else:
       key_proj = self._shard_blnh(key_proj)
       value_proj = self._shard_blnh(value_proj)
-    b, t, n, h = query_proj.shape
-    _, s, nk, _ = key_proj.shape
-    assert n % nk == 0
-    v_q = jnp.reshape(query_proj, (b, t, nk, n // nk, h))
-    if relative_bias is not None:
-      v_rb = jnp.reshape(relative_bias, (b, nk, n // nk, t, s))
-    else:
-      v_rb = None
-    with self._context_for_kv_vmap():
-      encoded, atten_probs = jax.vmap(
-          self._dot_atten,
-          in_axes=(2, 2, 2, None, 1),
-          out_axes=(2, 1),
-      )(
-          v_q,
-          key_proj,
-          value_proj,
-          atten_mask,
-          v_rb,
+    if self.use_te_dpa:
+      logging.warning(
+          'use_te_dpa is set to True, so TE dpa is used as an experimental way to use TE flash attention.'
       )
-    encoded = self._shard_blnh(jnp.reshape(encoded, (b, t, n, h)))
-    if atten_probs is not None:
-      atten_probs = jnp.reshape(atten_probs, (b, t, n, s))
+      atten_probs = None
+      encoded = self.dpa_layer(query_proj, key_proj, value_proj)
+    else:
+      b, t, n, h = query_proj.shape
+      _, s, nk, _ = key_proj.shape
+      assert n % nk == 0
+      v_q = jnp.reshape(query_proj, (b, t, nk, n // nk, h))
+      if relative_bias is not None:
+        v_rb = jnp.reshape(relative_bias, (b, nk, n // nk, t, s))
+      else:
+        v_rb = None
+      with self._context_for_kv_vmap():
+        encoded, atten_probs = jax.vmap(
+            self._dot_atten,
+            in_axes=(2, 2, 2, None, 1),
+            out_axes=(2, 1),
+        )(
+            v_q,
+            key_proj,
+            value_proj,
+            atten_mask,
+            v_rb,
+        )
+      encoded = self._shard_blnh(jnp.reshape(encoded, (b, t, n, h)))
+      if atten_probs is not None:
+        atten_probs = jnp.reshape(atten_probs, (b, t, n, s))

     # Post projection
     encoded = self.post(encoded)
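Usage sketch (illustrative, not part of the patch): the snippet below flips the new flag on an attention template directly, mirroring what GrokStackedTransformerHParams now does internally; num_kv_heads=8 is a placeholder value. Note that, as wired in this change, the TE path assumes a causal mask with no bias and zero attention dropout, and the atten_mask/relative_bias inputs are not forwarded to dpa_layer.

# Hypothetical usage sketch: enable the experimental TE DotProductAttention
# path on a MultiQueryDotProductAttention template. Only use_te_dpa comes from
# this change; num_kv_heads=8 is an illustrative GQA group count.
from praxis import pax_fiddle
from praxis.layers import multi_query_attention

atten_tpl = pax_fiddle.Config(
    multi_query_attention.MultiQueryDotProductAttention,
    num_kv_heads=8,    # illustrative number of KV head groups
    use_te_dpa=True,   # route __call__ through the TEDotProductAttention layer
)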