diff --git a/praxis/layers/BUILD b/praxis/layers/BUILD
index 2b29b3fd..154854de 100644
--- a/praxis/layers/BUILD
+++ b/praxis/layers/BUILD
@@ -117,6 +117,7 @@ pytype_strict_library(
         # Implicit jax dependency.
         # Implicit Pallas GPU dependency.  # BC internal suffix.
         # Implicit Pallas GPU Ops dependency.  # BC internal suffix.
+        # Implicit numpy dependency.
         "//praxis:asserts",
         "//praxis:base_layer",
         "//praxis:py_utils",
diff --git a/praxis/layers/gpu_fast_attention.py b/praxis/layers/gpu_fast_attention.py
index cc066964..9836e96d 100644
--- a/praxis/layers/gpu_fast_attention.py
+++ b/praxis/layers/gpu_fast_attention.py
@@ -25,6 +25,8 @@ import jax
 from jax.experimental.shard_map import shard_map
 
+import numpy as np
+
 from praxis import asserts
 from praxis import base_layer
 from praxis import py_utils
@@ -47,6 +49,91 @@ JTensor = pytypes.JTensor
 
 
+class GpuCudnnFusedDotProductAttention(attentions.DotProductAttention):
+  """Using Jax/Cudnn to call into a fused MHA kernel on NVIDIA GPU."""
+
+  is_causal: bool = False
+
+  def _shard_only_bn(self, x: JTensor) -> JTensor:
+    """Adds sharding annotations to tensors of shape [b, n, None, None]."""
+    ap = self.activation_split_dims_mapping
+    if self.mesh_axis_names is None or ap.blnh is None:
+      return x
+    assert len(ap.blnh) == 4
+    b = [ap.blnh[0], ap.blnh[2], None, None]
+    return base_layer.maybe_shard(x, b, self.mesh_axis_names)
+
+  def _dot_atten(
+      self,
+      query: JTensor,
+      key: JTensor,
+      value: JTensor,
+      atten_mask: JTensor,
+      relative_bias: JTensor | None = None,
+  ) -> tuple[JTensor, JTensor]:
+    """Main attention function.
+
+    Args:
+      query: JTensor of shape [B, T, N, H].
+      key: JTensor of shape [B, S, N, H].
+      value: JTensor of shape [B, S, N, H].
+      atten_mask: JTensor of shape [1|B, 1, 1|T, S] which is a mask that is
+        applied to prevent attention between unwanted pairs. This has already
+        been converted into large negative logits. Note that the first and
+        third dimensions allow size 1 if the mask is shared by every item in
+        the batch or every token in the target sequence.
+      relative_bias: Relative bias of shape [B, N, T, S].
+
+    Returns:
+      encoded: JTensor of shape [B, T, N, H].
+      atten_probs: JTensor of shape [B, N, T, S].
+    """
+    query = self._shard_blnh(query)
+    key = self._shard_blnh(key)
+    value = self._shard_blnh(value)
+
+    b, s, n, h = key.shape
+    base_layer.assert_has_shape(value, [b, s, n, h])
+    base_layer.assert_has_shape(query, [b, -1, n, h])
+    t = query.shape[1]
+    # If only padding bias is supplied, then atten_mask can be [B, 1, 1, S]
+    # since each target token is prohibited from attending to the same set of
+    # source tokens. In this case tiling is inefficient and unnecessary.
+    # If there is no padding mask, and only a causal mask, then the shape can
+    # be [1, 1, T, S].
+    base_layer.assert_has_shape(atten_mask, [-1, 1, -1, s])
+    asserts.in_set(atten_mask.shape[2], [t, 1])
+    asserts.in_set(atten_mask.shape[0], [b, 1])
+
+    assert self.attention_extra_logit is None
+    assert not self.zero_fully_masked
+    assert not self.atten_logit_cap or self.atten_logit_cap <= 0.0
+
+    if self.atten_dropout_prob > 0.0 and not self.do_eval:
+      raise NotImplementedError
+
+    query = self._scale_query(query)
+    logits_scale = 1.0 / np.sqrt(h) if self.scale_logits_by_head_dims else 1.0
+
+    # Explicitly shard relative_bias so that it has the same sharding as the
+    # query on the batch and num_heads dims; this is required by
+    # jax.nn.dot_product_attention.
+    if relative_bias is not None:
+      relative_bias = self._shard_only_bn(relative_bias)
+
+    encoded = jax.nn.dot_product_attention(
+        query,
+        key,
+        value,
+        bias=relative_bias,
+        scale=logits_scale,
+        is_causal=self.is_causal,
+        implementation='cudnn',
+    )
+    encoded = self._shard_blnh(encoded)
+    return encoded, None  # pytype: disable=bad-return-type
+
+
 class GpuTritonFusedDotProductAttention(attentions.DotProductAttention):
   """Using Jax/Pallas/Triton to call into a fused MHA kernel on NVIDIA GPU."""
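A minimal standalone sketch (not part of the diff) of the call the new layer delegates to. It assumes a recent JAX version that provides jax.nn.dot_product_attention and an NVIDIA GPU with cuDNN fused attention available; shapes follow the docstring above ([B, T, N, H] query, [B, S, N, H] key/value), while the dtypes and sizes are illustrative only.

# Illustrative only: mirrors the jax.nn.dot_product_attention call made by
# GpuCudnnFusedDotProductAttention. Sizes and dtypes are assumptions.
import jax
import jax.numpy as jnp
import numpy as np

B, T, S, N, H = 2, 16, 16, 8, 64  # batch, target len, source len, heads, head dim
q_rng, k_rng, v_rng = jax.random.split(jax.random.PRNGKey(0), 3)
query = jax.random.normal(q_rng, (B, T, N, H), dtype=jnp.bfloat16)
key = jax.random.normal(k_rng, (B, S, N, H), dtype=jnp.bfloat16)
value = jax.random.normal(v_rng, (B, S, N, H), dtype=jnp.bfloat16)

# implementation='cudnn' requires an NVIDIA GPU with cuDNN fused attention;
# implementation=None would let JAX fall back to the pure-XLA path.
encoded = jax.nn.dot_product_attention(
    query,
    key,
    value,
    bias=None,  # a relative bias of shape [B, N, T, S] could be passed here
    scale=1.0 / np.sqrt(H),
    is_causal=True,
    implementation='cudnn',
)
assert encoded.shape == (B, T, N, H)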