From bda7154ae9e8210b56fbae4ac00b0182f2b90686 Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Thu, 14 Dec 2023 16:04:13 +0800 Subject: [PATCH] Upgrade gqa attention kernel. (#247) --- .../llama/triton_kernel/gqa_flash_decoding.py | 40 ++++++++++++++----- .../gqa_flash_decoding_stage1.py | 22 ++++++---- 2 files changed, 43 insertions(+), 19 deletions(-) diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py b/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py index f69a7bf9..f3a1eea0 100644 --- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py +++ b/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py @@ -1,6 +1,9 @@ +import time import torch +import numpy as np +from lightllm.common.basemodel import InferStateInfo -def gqa_token_decode_attention_flash_decoding(q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None): +def gqa_token_decode_attention_flash_decoding(q, infer_state:InferStateInfo, q_head_num, head_dim, cache_k, cache_v, out=None): BLOCK_SEQ = 128 batch_size = infer_state.batch_size max_len_in_batch = infer_state.max_len_in_batch @@ -12,6 +15,20 @@ def gqa_token_decode_attention_flash_decoding(q, infer_state, q_head_num, head_d o_tensor = torch.empty_like(q) if out is None else out if getattr(infer_state, 'mid_o', None) is None: + # start_time = time.time() + b_seq_len_numpy = infer_state.b_seq_len.cpu().numpy() + + block_batch_ids = torch.from_numpy(np.concatenate([np.full(((b_seq_len_numpy[batch_id] + BLOCK_SEQ - 1) // BLOCK_SEQ,), fill_value=batch_id, dtype=np.int32) + for batch_id in range(len(b_seq_len_numpy))], axis=0)).cuda() + + block_start_indexes = torch.from_numpy(np.concatenate([np.arange(0, seq_len, BLOCK_SEQ, dtype=np.int32) + for seq_len in b_seq_len_numpy], axis=0)).cuda() + + assert len(block_batch_ids) == len(block_start_indexes) + infer_state.block_batch_ids = block_batch_ids + infer_state.block_start_indexes = block_start_indexes + # print("build block params cost:", (time.time() - start_time) * 1000) + infer_state.mid_o = torch.empty([batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, @@ -27,16 +44,17 @@ def gqa_token_decode_attention_flash_decoding(q, infer_state, q_head_num, head_d mid_o = infer_state.mid_o mid_o_logexpsum = infer_state.mid_o_logexpsum - flash_decode_stage1(q.view(calcu_shape1), - cache_k, - cache_v, - infer_state.req_manager.req_to_token_indexs, - infer_state.b_req_idx, - infer_state.b_seq_len, - infer_state.max_len_in_batch, - mid_o, - mid_o_logexpsum, - BLOCK_SEQ) + flash_decode_stage1(infer_state.block_batch_ids, + infer_state.block_start_indexes, + q.view(calcu_shape1), + cache_k, + cache_v, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_seq_len, + mid_o, + mid_o_logexpsum, + BLOCK_SEQ) flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py b/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py index 6efbf47c..cb5c2e36 100644 --- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py +++ b/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py @@ -4,6 +4,7 @@ @triton.jit def _fwd_kernel_flash_decode_stage1( + block_batch_ids, block_start_indexes, Q, K, V, sm_scale, Req_to_tokens, B_req_idx, B_Seqlen, Mid_O, # [batch, head, seq_block_num, head_dim] Mid_O_LogExpSum, #[batch, head, seq_block_num] @@ -19,9 +20,11 @@ def _fwd_kernel_flash_decode_stage1( BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr ): - cur_batch = tl.program_id(0) + cur_block_id = tl.program_id(0) cur_kv_head = tl.program_id(1) - seq_start_block = tl.program_id(2) + + cur_batch = tl.load(block_batch_ids + cur_block_id) + seq_start_index = tl.load(block_start_indexes + cur_block_id) cur_q_head_offs = tl.arange(0, Q_HEAD_NUM) cur_q_head_range = cur_kv_head * gqa_group_size + cur_q_head_offs @@ -29,7 +32,7 @@ def _fwd_kernel_flash_decode_stage1( offs_d = tl.arange(0, BLOCK_DMODEL) cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_batch_start_index = seq_start_block * BLOCK_SEQ + cur_batch_start_index = seq_start_index cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) off_q = cur_batch * stride_qbs + cur_q_head_range[:, None] * stride_qh + offs_d[None, :] @@ -67,15 +70,16 @@ def _fwd_kernel_flash_decode_stage1( need_store = tl.where(block_n_size == 0, 0, 1) for _ in range(0, need_store, 1): - off_mid_o = cur_batch * stride_mid_ob + cur_q_head_range[:, None] * stride_mid_oh + seq_start_block * stride_mid_os + offs_d[None, :] - off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + seq_start_block + seq_block_index = cur_batch_start_index // BLOCK_SEQ + off_mid_o = cur_batch * stride_mid_ob + cur_q_head_range[:, None] * stride_mid_oh + seq_block_index * stride_mid_os + offs_d[None, :] + off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + seq_block_index tl.store(Mid_O + off_mid_o, acc / sum_exp[:, None], mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size) tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp), mask=cur_q_head_range < (cur_kv_head + 1) * gqa_group_size) return @torch.no_grad() -def flash_decode_stage1(q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq): +def flash_decode_stage1(block_batch_ids, block_start_indexes, q, k, v, Req_to_tokens, B_req_idx, B_seq_len, mid_out, mid_out_logsumexp, block_seq): BLOCK_SEQ = block_seq BLOCK_N = 16 assert BLOCK_SEQ % BLOCK_N == 0 @@ -85,11 +89,13 @@ def flash_decode_stage1(q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_ assert Lk in {16, 32, 64, 128} sm_scale = 1.0 / (Lk ** 0.5) batch, kv_head_num = B_req_idx.shape[0], k.shape[1] - grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) + block_nums = len(block_batch_ids) + grid = (block_nums, kv_head_num) gqa_group_size = q.shape[1] // k.shape[1] _fwd_kernel_flash_decode_stage1[grid]( - q, k, v, sm_scale, Req_to_tokens, B_req_idx, B_Seqlen, + block_batch_ids, block_start_indexes, + q, k, v, sm_scale, Req_to_tokens, B_req_idx, B_seq_len, mid_out, mid_out_logsumexp, Req_to_tokens.stride(0), Req_to_tokens.stride(1),