Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix the wrong output of pallas attention kernel when q_len!=kv_len #24495

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 34 additions & 30 deletions jax/experimental/pallas/ops/gpu/attention.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,13 +177,14 @@ def mha(
debug: bool = False,
):
del backward_pass_impl
batch_size, seq_len, num_heads, head_dim = q.shape
block_q = min(block_q, seq_len)
block_k = min(block_k, seq_len)
batch_size, q_seq_len, num_heads, head_dim = q.shape
kv_seq_len = k.shape[1]
block_q = min(block_q, q_seq_len)
block_k = min(block_k, kv_seq_len)
# Heuristics.
grid_ = grid
if grid_ is None:
grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads)
grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads)

num_warps_ = num_warps
if num_warps_ is None:
Expand All @@ -198,16 +199,16 @@ def mha(
(None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
(None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
(None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
),
]
in_specs.append(
None # type: ignore[arg-type]
if segment_ids is None
else pl.BlockSpec((None, seq_len), lambda _, j, k: (j, 0))
else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0))
)
out_shape = jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype)
return pl.pallas_call(
Expand Down Expand Up @@ -243,13 +244,14 @@ def _mha_forward(
debug: bool,
):
del backward_pass_impl
batch_size, seq_len, num_heads, head_dim = q.shape
block_q = min(block_q, seq_len)
block_k = min(block_k, seq_len)
batch_size, q_seq_len, num_heads, head_dim = q.shape
kv_seq_len = k.shape[1]
block_q = min(block_q, q_seq_len)
block_k = min(block_k, kv_seq_len)
# Heuristics.
grid_ = grid
if grid_ is None:
grid_ = (pl.cdiv(seq_len, block_q), batch_size, num_heads)
grid_ = (pl.cdiv(q_seq_len, block_q), batch_size, num_heads)

num_warps_ = num_warps
if num_warps_ is None:
Expand All @@ -260,24 +262,24 @@ def _mha_forward(
out_shape = [
jax.ShapeDtypeStruct(shape=q.shape, dtype=q.dtype), # out
jax.ShapeDtypeStruct(
shape=(batch_size, num_heads, seq_len), dtype=jnp.float32 # lse
shape=(batch_size, num_heads, q_seq_len), dtype=jnp.float32 # lse
),
]
in_specs = [
pl.BlockSpec(
(None, block_q, None, head_dim), lambda i, j, k: (j, i, k, 0)
),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
(None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
(None, kv_seq_len, None, head_dim), lambda _, j, k: (j, 0, k, 0)
),
]
in_specs.append(
None # type: ignore[arg-type]
if segment_ids is None
else pl.BlockSpec((None, seq_len), lambda _, j, k: (j, 0))
else pl.BlockSpec((None, kv_seq_len), lambda _, j, k: (j, 0))
)
out, lse = pl.pallas_call(
kernel,
Expand Down Expand Up @@ -362,7 +364,8 @@ def mha_backward_kernel(
block_d: int,
):
del out_ref # Not needed
seq_len = q_ref.shape[0]
q_seq_len = q_ref.shape[0]
kv_seq_len = k_ref.shape[0]

# Scan #1: dK and dV
# 1. Load a block of K and V of size (block_k1, head_dim) in SMEM.
Expand Down Expand Up @@ -423,7 +426,7 @@ def inner_loop_dkdv(start_q, carry):

lower_bound = lax.div(start_k * block_k1, block_q1) if causal else 0
dv, dk = lax.fori_loop(
lower_bound, pl.cdiv(seq_len, block_q1), inner_loop_dkdv, (dv, dk)
lower_bound, pl.cdiv(q_seq_len, block_q1), inner_loop_dkdv, (dv, dk)
)
dv_ref[...] = dv.astype(dv_ref.dtype)
dk_ref[...] = dk.astype(dk_ref.dtype)
Expand Down Expand Up @@ -486,7 +489,7 @@ def inner_loop_dq(start_k, dq):
if causal:
upper_bound = lax.div((start_q + 1) * block_q2, block_k2)
else:
upper_bound = pl.cdiv(seq_len, block_k2)
upper_bound = pl.cdiv(kv_seq_len, block_k2)

dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq))
dq_ref[...] = dq.astype(dq_ref.dtype)
Expand All @@ -508,9 +511,10 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,
segment_ids,
)[1](do)
elif backward_pass_impl == "triton":
batch_size, seq_len, num_heads, head_dim = q.shape
block_q = min(block_q, seq_len)
block_k = min(block_k, seq_len)
batch_size, q_seq_len, num_heads, head_dim = q.shape
kv_seq_len = k.shape[1]
block_q = min(block_q, q_seq_len)
block_k = min(block_k, kv_seq_len)
delta = _preprocess_backward(out, do, lse, block_q, debug, interpret)
out_shapes = [
jax.ShapeDtypeStruct(q.shape, q.dtype),
Expand All @@ -520,29 +524,29 @@ def _mha_backward(sm_scale: float, causal: bool, block_q: int, block_k: int,

in_specs = [
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
(None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
(None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
(None, kv_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
(None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
),
pl.BlockSpec(
(None, seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
(None, q_seq_len, None, head_dim), lambda i, j, _: (i, 0, j, 0)
),
pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
pl.BlockSpec((None, None, seq_len), lambda i, j, _: (i, j, 0)),
pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)),
pl.BlockSpec((None, None, q_seq_len), lambda i, j, _: (i, j, 0)),
]
if segment_ids is None:
in_specs.insert(3, None) # type: ignore[arg-type]
else:
in_specs.insert(3, pl.BlockSpec((None, seq_len), lambda i, j, _: (i, 0)))
in_specs.insert(3, pl.BlockSpec((None, kv_seq_len), lambda i, j, _: (i, 0)))

grid = (batch_size, num_heads, pl.cdiv(seq_len, block_k))
grid = (batch_size, num_heads, pl.cdiv(kv_seq_len, block_k))
num_warps = 8
dq, dk, dv = pl.pallas_call(
functools.partial(
Expand Down