Upgrade gqa attention kernel. (#247)

hiworldwzj authored Dec 14, 2023
1 parent a9cf015 commit bda7154
Showing 2 changed files with 43 additions and 19 deletions.
40 changes: 29 additions & 11 deletions lightllm/models/llama/triton_kernel/gqa_flash_decoding.py
@@ -1,6 +1,9 @@
import time
import torch
import numpy as np
from lightllm.common.basemodel import InferStateInfo

def gqa_token_decode_attention_flash_decoding(q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None):
def gqa_token_decode_attention_flash_decoding(q, infer_state:InferStateInfo, q_head_num, head_dim, cache_k, cache_v, out=None):
BLOCK_SEQ = 128
batch_size = infer_state.batch_size
max_len_in_batch = infer_state.max_len_in_batch
@@ -12,6 +15,20 @@ def gqa_token_decode_attention_flash_decoding(q, infer_state, q_head_num, head_d
o_tensor = torch.empty_like(q) if out is None else out

if getattr(infer_state, 'mid_o', None) is None:
# start_time = time.time()
b_seq_len_numpy = infer_state.b_seq_len.cpu().numpy()

block_batch_ids = torch.from_numpy(np.concatenate([np.full(((b_seq_len_numpy[batch_id] + BLOCK_SEQ - 1) // BLOCK_SEQ,), fill_value=batch_id, dtype=np.int32)
for batch_id in range(len(b_seq_len_numpy))], axis=0)).cuda()

block_start_indexes = torch.from_numpy(np.concatenate([np.arange(0, seq_len, BLOCK_SEQ, dtype=np.int32)
for seq_len in b_seq_len_numpy], axis=0)).cuda()

assert len(block_batch_ids) == len(block_start_indexes)
infer_state.block_batch_ids = block_batch_ids
infer_state.block_start_indexes = block_start_indexes
# print("build block params cost:", (time.time() - start_time) * 1000)

infer_state.mid_o = torch.empty([batch_size,
q_head_num,
max_len_in_batch // BLOCK_SEQ + 1,
@@ -27,16 +44,17 @@ def gqa_token_decode_attention_flash_decoding(q, infer_state, q_head_num, head_d
mid_o = infer_state.mid_o
mid_o_logexpsum = infer_state.mid_o_logexpsum

flash_decode_stage1(q.view(calcu_shape1),
cache_k,
cache_v,
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_seq_len,
infer_state.max_len_in_batch,
mid_o,
mid_o_logexpsum,
BLOCK_SEQ)
flash_decode_stage1(infer_state.block_batch_ids,
infer_state.block_start_indexes,
q.view(calcu_shape1),
cache_k,
cache_v,
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_seq_len,
mid_o,
mid_o_logexpsum,
BLOCK_SEQ)
flash_decode_stage2(mid_o,
mid_o_logexpsum,
infer_state.b_seq_len,
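The main change in this file: the per-request sequence blocks are flattened into two lookup tables, block_batch_ids and block_start_indexes, cached on infer_state and handed to flash_decode_stage1. A minimal CPU-only sketch of how those tables fall out of b_seq_len, using made-up sequence lengths (tensor names follow the diff above):

import numpy as np

BLOCK_SEQ = 128
b_seq_len = np.array([100, 300, 50], dtype=np.int32)  # hypothetical per-request lengths

# One entry per (request, sequence-block) pair: which request the block belongs to ...
block_batch_ids = np.concatenate([
    np.full(((seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ,), batch_id, dtype=np.int32)
    for batch_id, seq_len in enumerate(b_seq_len)
])
# ... and the token offset where the block starts inside that request's sequence.
block_start_indexes = np.concatenate([
    np.arange(0, seq_len, BLOCK_SEQ, dtype=np.int32)
    for seq_len in b_seq_len
])

print(block_batch_ids)      # [0 1 1 1 2]
print(block_start_indexes)  # [  0   0 128 256   0]

Each entry describes exactly one real block of work, so the stage-1 kernel can be launched once per entry rather than once per (request, max-length block) pair.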
22 changes: 14 additions & 8 deletions lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py
@@ -4,6 +4,7 @@

@triton.jit
def _fwd_kernel_flash_decode_stage1(
block_batch_ids, block_start_indexes,
Q, K, V, sm_scale, Req_to_tokens, B_req_idx, B_Seqlen,
Mid_O, # [batch, head, seq_block_num, head_dim]
Mid_O_LogExpSum, #[batch, head, seq_block_num]
@@ -19,17 +20,19 @@ def _fwd_kernel_flash_decode_stage1(
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr
):
cur_batch = tl.program_id(0)
cur_block_id = tl.program_id(0)
cur_kv_head = tl.program_id(1)
seq_start_block = tl.program_id(2)

cur_batch = tl.load(block_batch_ids + cur_block_id)
seq_start_index = tl.load(block_start_indexes + cur_block_id)

cur_q_head_offs = tl.arange(0, Q_HEAD_NUM)
cur_q_head_range = cur_kv_head * gqa_group_size + cur_q_head_offs

offs_d = tl.arange(0, BLOCK_DMODEL)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
cur_batch_start_index = seq_start_block * BLOCK_SEQ
cur_batch_start_index = seq_start_index
cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ)

off_q = cur_batch * stride_qbs + cur_q_head_range[:, None] * stride_qh + offs_d[None, :]
@@ -67,15 +70,16 @@ def _fwd_kernel_flash_decode_stage1(

need_store = tl.where(block_n_size == 0, 0, 1)
for _ in range(0, need_store, 1):
off_mid_o = cur_batch * stride_mid_ob + cur_q_head_range[:, None] * stride_mid_oh + seq_start_block * stride_mid_os + offs_d[None, :]
off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + seq_start_block
seq_block_index = cur_batch_start_index // BLOCK_SEQ
off_mid_o = cur_batch * stride_mid_ob + cur_q_head_range[:, None] * stride_mid_oh + seq_block_index * stride_mid_os + offs_d[None, :]
off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + seq_block_index
tl.store(Mid_O + off_mid_o, acc / sum_exp[:, None], mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size)
tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp), mask=cur_q_head_range < (cur_kv_head + 1) * gqa_group_size)
return


@torch.no_grad()
def flash_decode_stage1(q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq):
def flash_decode_stage1(block_batch_ids, block_start_indexes, q, k, v, Req_to_tokens, B_req_idx, B_seq_len, mid_out, mid_out_logsumexp, block_seq):
BLOCK_SEQ = block_seq
BLOCK_N = 16
assert BLOCK_SEQ % BLOCK_N == 0
@@ -85,11 +89,13 @@ def flash_decode_stage1(q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_
assert Lk in {16, 32, 64, 128}
sm_scale = 1.0 / (Lk ** 0.5)
batch, kv_head_num = B_req_idx.shape[0], k.shape[1]
grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ))
block_nums = len(block_batch_ids)
grid = (block_nums, kv_head_num)
gqa_group_size = q.shape[1] // k.shape[1]

_fwd_kernel_flash_decode_stage1[grid](
q, k, v, sm_scale, Req_to_tokens, B_req_idx, B_Seqlen,
block_batch_ids, block_start_indexes,
q, k, v, sm_scale, Req_to_tokens, B_req_idx, B_seq_len,
mid_out,
mid_out_logsumexp,
Req_to_tokens.stride(0), Req_to_tokens.stride(1),
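With those tables in place, the stage-1 launch shrinks from a dense 3-D grid sized by the longest sequence in the batch to a 2-D grid sized by the actual number of blocks. A rough before/after comparison, reusing the toy lengths from the sketch above (100, 300, 50 tokens, BLOCK_SEQ = 128) and an assumed 8 KV heads:

import triton

BLOCK_SEQ = 128
batch, kv_head_num, max_len_in_batch = 3, 8, 300
block_nums = 5  # len(block_batch_ids) for lengths [100, 300, 50]

# Old launch: every request is padded out to cdiv(max_len_in_batch, BLOCK_SEQ) programs.
old_grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ))  # 3 * 8 * 3 = 72 programs
# New launch: exactly one program per real (request, block) pair.
new_grid = (block_nums, kv_head_num)                                       # 5 * 8 = 40 programs

Inside the kernel, each program now recovers its work item with tl.load(block_batch_ids + cur_block_id) and tl.load(block_start_indexes + cur_block_id) instead of reading it off tl.program_id, and the output slot is reconstructed as seq_block_index = cur_batch_start_index // BLOCK_SEQ so the Mid_O indexing matches the old seq_start_block layout.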
