diff --git a/jax/experimental/pallas/ops/gpu/attention.py b/jax/experimental/pallas/ops/gpu/attention.py
index 66c9dea39734..198340ec0d11 100644
--- a/jax/experimental/pallas/ops/gpu/attention.py
+++ b/jax/experimental/pallas/ops/gpu/attention.py
@@ -177,13 +177,14 @@ def mha(
     debug: bool = False,
 ):
   del backward_pass_impl
-  batch_size, seq_len, num_heads, head_dim = q.shape
-  block_q = min(block_q, seq_len)
-  block_k = min(block_k, seq_len)
+  batch_size, q_seq_len, num_heads, head_dim = q.shape
+  kv_seq_len = k.shape[1]
+  block_q = min(block_q, q_seq_len)
+  block_k = min(block_k, kv_seq_len)
   # Heuristics.
   grid_ = grid
   if grid_ is None:
-    grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads)
+    grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads)
 
   num_warps_ = num_warps
   if num_warps_ is None:
@@ -198,16 +199,16 @@ def mha(
           (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
       ),
       pl.BlockSpec(
-          (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
+          (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
       ),
       pl.BlockSpec(
-          (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
+          (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
       ),
   ]
   in_specs.append(
       None  # type: ignore[arg-type]
       if segment_ids is None
-      else pl.BlockSpec((None, seq_len), lambda _, j, k: (j, 0))
+      else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0))
   )
   out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)
   return pl.pallas_call(
@@ -243,13 +244,14 @@ def _mha_forward(
     debug: bool,
 ):
   del backward_pass_impl
-  batch_size, seq_len, num_heads, head_dim = q.shape
-  block_q = min(block_q, seq_len)
-  block_k = min(block_k, seq_len)
+  batch_size, q_seq_len, num_heads, head_dim = q.shape
+  kv_seq_len = k.shape[1]
+  block_q = min(block_q, q_seq_len)
+  block_k = min(block_k, kv_seq_len)
   # Heuristics.
   grid_ = grid
   if grid_ is None:
-    grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads)
+    grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads)
 
   num_warps_ = num_warps
   if num_warps_ is None:
@@ -260,7 +262,7 @@ def _mha_forward(
   out_shape = [
       jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype),  # out
       jax.ShapeDtypeStruct(
-          shape=(batch_size, num_heads, seq_len), dtype=jnp.float32  # lse
+          shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32  # lse
       ),
   ]
   in_specs = [
@@ -268,16 +270,16 @@ def _mha_forward(
           (None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
       ),
       pl.BlockSpec(
-          (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
+          (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
       ),
       pl.BlockSpec(
-          (None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
+          (None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
       ),
   ]
   in_specs.append(
       None  # type: ignore[arg-type]
       if segment_ids is None
-      else pl.BlockSpec((None, seq_len), lambda _, j, k: (j, 0))
+      else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0))
   )
   out, lse = pl.pallas_call(
       kernel,
@@ -362,7 +364,8 @@ def mha_backward_kernel(
     block_d: int,
 ):
   del out_ref  # Not needed
-  seq_len = q_ref.shape[0]
+  q_seq_len = q_ref.shape[0]
+  kv_seq_len = k_ref.shape[0]
 
   # Scan #1: dK and dV
   #   1. Load a block of K and V of size (block_k1, head_dim) in SMEM.
@@ -423,7 +426,7 @@ def inner_loop_dkdv(start_q, carry):
 
   lower_bound = lax.div(start_k * block_k1, block_q1) if causal else 0
   dv, dk = lax.fori_loop(
-      lower_bound, pl.cdiv(seq_len, block_q1), inner_loop_dkdv, (dv, dk)
+      lower_bound, pl.cdiv(q_seq_len, block_q1), inner_loop_dkdv, (dv, dk)
   )
   dv_ref[...] = dv.astype(dv_ref.dtype)
   dk_ref[...] = dk.astype(dk_ref.dtype)
@@ -486,7 +489,7 @@ def inner_loop_dq(start_k, dq):
 
   if causal:
     upper_bound = lax.div((start_q + 1) * block_q2, block_k2)
   else:
-    upper_bound = pl.cdiv(seq_len, block_k2)
+    upper_bound = pl.cdiv(kv_seq_len, block_k2)
   dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq))
   dq_ref[...] = dq.astype(dq_ref.dtype)
@@ -508,9 +511,10 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
         segment_ids,
     )[1](do)
   elif backward_pass_impl == "triton":
-    batch_size, seq_len, num_heads, head_dim = q.shape
-    block_q = min(block_q, seq_len)
-    block_k = min(block_k, seq_len)
+    batch_size, q_seq_len, num_heads, head_dim = q.shape
+    kv_seq_len = k.shape[1]
+    block_q = min(block_q, q_seq_len)
+    block_k = min(block_k, kv_seq_len)
     delta = _preprocess_backward(out, do, lse, block_q, debug, interpret)
     out_shapes = [
         jax.ShapeDtypeStruct(q.shape, q.dtype),
@@ -520,29 +524,29 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
 
     in_specs = [
         pl.BlockSpec(
-            (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
+            (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
        ),
        pl.BlockSpec(
-            (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
+            (None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
        ),
        pl.BlockSpec(
-            (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
+            (None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
        ),
        pl.BlockSpec(
-            (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
+            (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
        ),
        pl.BlockSpec(
-            (None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
+            (None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
        ),
-        pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
-        pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
+        pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)),
+        pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)),
     ]
     if segment_ids is None:
       in_specs.insert(3, None)  # type: ignore[arg-type]
     else:
-      in_specs.insert(3, pl.BlockSpec((None, seq_len), lambda i, j, _: (i, 0)))
+      in_specs.insert(3, pl.BlockSpec((None, kv_seq_len), lambda i, j, _: (i, 0)))
 
-    grid = (batch_size, num_heads, pl.cdiv(seq_len, block_k))
+    grid = (batch_size, num_heads, pl.cdiv(kv_seq_len, block_k))
     num_warps = 8
     dq, dk, dv = pl.pallas_call(
         functools.partial(
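The patch splits the single `seq_len` into `q_seq_len` (taken from `q.shape`) and `kv_seq_len` (taken from `k.shape[1]`), so queries and keys/values may have different lengths. Below is a minimal usage sketch, not part of the patch: the shapes, dtypes, and the `segment_ids=None, causal=False` arguments are illustrative assumptions, and the kernel still needs a GPU with Triton support to run.

# Illustrative only; shapes and arguments are assumptions, not taken from the patch.
import jax
import jax.numpy as jnp
from jax.experimental.pallas.ops.gpu.attention import mha

batch, num_heads, head_dim = 2, 4, 64
q_seq_len, kv_seq_len = 128, 1024  # query length differs from the KV length

key_q, key_k, key_v = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(key_q, (batch, q_seq_len, num_heads, head_dim), jnp.float16)
k = jax.random.normal(key_k, (batch, kv_seq_len, num_heads, head_dim), jnp.float16)
v = jax.random.normal(key_v, (batch, kv_seq_len, num_heads, head_dim), jnp.float16)

# Non-causal attention over the longer KV sequence; the output keeps the query layout.
out = mha(q, k, v, segment_ids=None, causal=False)
assert out.shape == (batch, q_seq_len, num_heads, head_dim)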