diff --git a/Dockerfile b/Dockerfile
index 22a2b72a..c2d04e8d 100644
--- a/Dockerfile
+++ b/Dockerfile
@@ -47,7 +47,7 @@ RUN mkdir ~/cuda-nvcc && cd ~/cuda-nvcc && \
 WORKDIR /root

 COPY ./requirements.txt /lightllm/requirements.txt
-RUN pip install -r /lightllm/requirements.txt --no-cache-dir --ignore-installed
+RUN pip install -r /lightllm/requirements.txt --no-cache-dir --ignore-installed --extra-index-url https://download.pytorch.org/whl/cu118

 COPY . /lightllm
 RUN pip install -e /lightllm --no-cache-dir
diff --git a/README.md b/README.md
index a2441ec1..e0e3609d 100644
--- a/README.md
+++ b/README.md
@@ -86,7 +86,8 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram
 The code has been tested with Pytorch>=1.3, CUDA 11.8, and Python 3.9. To install the necessary dependencies, please refer to the provided **requirements.txt** and follow the instructions as

 ~~~shell
-pip install -r requirements.txt
+# for cuda 11.8
+pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118
 ~~~

 ### Container
@@ -136,13 +137,12 @@ python setup.py install

 - Install Triton Package

-The code has been tested on a range of GPUs including V100, A100, A800, 4090, and H800. If you are running the code on A100, A800, etc., we recommend using triton==2.1.0.
+The code has been tested on a range of GPUs including V100, A100, A800, 4090, and H800. If you are running the code on A100, A800, etc., we recommend using triton==3.0.0.

 ~~~shell
-pip install triton==2.1.0 --no-deps
+pip install triton==3.0.0 --no-deps
 ~~~

-If you are running the code on H800 or V100., we recommend using triton-nightly, triton-nightly has a significant CPU bottleneck, leading to high decode latency at low concurrency levels. You can observe [this issue](https://github.com/openai/triton/issues/3619) and [fix PR](https://github.com/openai/triton/pull/3638).You can try modifying and compiling the
-source code yourself to resolve this issue.
+If you are running the code on H800 or V100, you can try triton-nightly to get better performance.
 ~~~shell
 pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly --no-deps
 ~~~
diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py b/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py
index f3a1eea0..9ade94c9 100644
--- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py
+++ b/lightllm/models/llama/triton_kernel/gqa_flash_decoding.py
@@ -1,9 +1,7 @@
-import time
 import torch
-import numpy as np
-from lightllm.common.basemodel import InferStateInfo

-def gqa_token_decode_attention_flash_decoding(q, infer_state:InferStateInfo, q_head_num, head_dim, cache_k, cache_v, out=None):
+
+def gqa_token_decode_attention_flash_decoding(q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None):
     BLOCK_SEQ = 128
     batch_size = infer_state.batch_size
     max_len_in_batch = infer_state.max_len_in_batch
@@ -13,51 +11,29 @@ def gqa_token_decode_attention_flash_decoding(q, infer_state:InferStateInfo, q_h
     from .gqa_flash_decoding_stage2 import flash_decode_stage2

     o_tensor = torch.empty_like(q) if out is None else out
-
-    if getattr(infer_state, 'mid_o', None) is None:
-        # start_time = time.time()
-        b_seq_len_numpy = infer_state.b_seq_len.cpu().numpy()
-        block_batch_ids = torch.from_numpy(np.concatenate([np.full(((b_seq_len_numpy[batch_id] + BLOCK_SEQ - 1) // BLOCK_SEQ,), fill_value=batch_id, dtype=np.int32)
-                                                            for batch_id in range(len(b_seq_len_numpy))], axis=0)).cuda()
-
-        block_start_indexes = torch.from_numpy(np.concatenate([np.arange(0, seq_len, BLOCK_SEQ, dtype=np.int32)
-                                                               for seq_len in b_seq_len_numpy], axis=0)).cuda()
-
-        assert len(block_batch_ids) == len(block_start_indexes)
-        infer_state.block_batch_ids = block_batch_ids
-        infer_state.block_start_indexes = block_start_indexes
-        # print("build block params cost:", (time.time() - start_time) * 1000)
+    if getattr(infer_state, "mid_o", None) is None:
+        infer_state.mid_o = torch.empty(
+            [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda"
+        )
+        infer_state.mid_o_logexpsum = torch.empty(
+            [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
+        )

-        infer_state.mid_o = torch.empty([batch_size,
-                                         q_head_num,
-                                         max_len_in_batch // BLOCK_SEQ + 1,
-                                         head_dim],
-                                        dtype=torch.float32,
-                                        device="cuda")
-        infer_state.mid_o_logexpsum = torch.empty([batch_size,
-                                                   q_head_num,
-                                                   max_len_in_batch // BLOCK_SEQ + 1],
-                                                  dtype=torch.float32,
-                                                  device="cuda")
-
     mid_o = infer_state.mid_o
     mid_o_logexpsum = infer_state.mid_o_logexpsum

-    flash_decode_stage1(infer_state.block_batch_ids,
-                        infer_state.block_start_indexes,
-                        q.view(calcu_shape1),
-                        cache_k,
-                        cache_v,
-                        infer_state.req_manager.req_to_token_indexs,
-                        infer_state.b_req_idx,
-                        infer_state.b_seq_len,
-                        mid_o,
-                        mid_o_logexpsum,
-                        BLOCK_SEQ)
-    flash_decode_stage2(mid_o,
-                        mid_o_logexpsum,
-                        infer_state.b_seq_len,
-                        o_tensor.view(calcu_shape1),
-                        BLOCK_SEQ)
+    flash_decode_stage1(
+        q.view(calcu_shape1),
+        cache_k,
+        cache_v,
+        infer_state.req_manager.req_to_token_indexs,
+        infer_state.b_req_idx,
+        infer_state.b_seq_len,
+        infer_state.max_len_in_batch,
+        mid_o,
+        mid_o_logexpsum,
+        BLOCK_SEQ,
+    )
+    flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
     return o_tensor
diff --git a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py b/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py
index cb5c2e36..00e649ab 100644
--- a/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py
+++ b/lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py
@@ -2,45 +2,68 @@
 import triton
 import triton.language as tl

+
 @triton.jit
 def _fwd_kernel_flash_decode_stage1(
-    block_batch_ids, block_start_indexes,
-    Q, K, V, sm_scale, Req_to_tokens, B_req_idx, B_Seqlen,
-    Mid_O, # [batch, head, seq_block_num, head_dim]
-    Mid_O_LogExpSum, #[batch, head, seq_block_num]
-    stride_req_to_tokens_b, stride_req_to_tokens_s,
-    stride_qbs, stride_qh, stride_qd,
-    stride_kbs, stride_kh, stride_kd,
-    stride_vbs, stride_vh, stride_vd,
-    stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od,
-    stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es,
+    Q,
+    K,
+    V,
+    sm_scale,
+    Req_to_tokens,
+    B_req_idx,
+    B_Seqlen,
+    Mid_O,  # [batch, head, seq_block_num, head_dim]
+    Mid_O_LogExpSum,  # [batch, head, seq_block_num]
+    stride_req_to_tokens_b,
+    stride_req_to_tokens_s,
+    stride_qbs,
+    stride_qh,
+    stride_qd,
+    stride_kbs,
+    stride_kh,
+    stride_kd,
+    stride_vbs,
+    stride_vh,
+    stride_vd,
+    stride_mid_ob,
+    stride_mid_oh,
+    stride_mid_os,
+    stride_mid_od,
+    stride_mid_o_eb,
+    stride_mid_o_eh,
+    stride_mid_o_es,
     gqa_group_size,
     Q_HEAD_NUM: tl.constexpr,
-    BLOCK_SEQ: tl.constexpr,
+    BLOCK_SEQ: tl.constexpr,
     BLOCK_DMODEL: tl.constexpr,
-    BLOCK_N: tl.constexpr
+    BLOCK_N: tl.constexpr,
 ):
-    cur_block_id = tl.program_id(0)
+    cur_batch = tl.program_id(0)
     cur_kv_head = tl.program_id(1)
-
-    cur_batch = tl.load(block_batch_ids + cur_block_id)
-    seq_start_index = tl.load(block_start_indexes + cur_block_id)
+    seq_start_block = tl.program_id(2)
     cur_q_head_offs = tl.arange(0, Q_HEAD_NUM)
     cur_q_head_range = cur_kv_head * gqa_group_size + cur_q_head_offs
-
+
     offs_d = tl.arange(0, BLOCK_DMODEL)
     cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
     cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
-    cur_batch_start_index = seq_start_index
+    cur_batch_start_index = seq_start_block * BLOCK_SEQ
     cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ)

     off_q = cur_batch * stride_qbs + cur_q_head_range[:, None] * stride_qh + offs_d[None, :]
-
-    block_n_size = tl.where(cur_batch_end_index - cur_batch_start_index <= 0, 0, cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1) // BLOCK_N
-
+
+    block_n_size = (
+        tl.where(
+            cur_batch_end_index - cur_batch_start_index <= 0,
+            0,
+            cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1,
+        )
+        // BLOCK_N
+    )
+
     offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N)
-
+
     q = tl.load(Q + off_q, mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size, other=0.0)

     sum_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32)
@@ -49,14 +72,22 @@ def _fwd_kernel_flash_decode_stage1(

     for start_n in range(0, block_n_size, 1):
         offs_n_new = start_n * BLOCK_N + offs_n
-        k_loc = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, mask=offs_n_new < cur_batch_end_index, other=0)
+        k_loc = tl.load(
+            Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
+            mask=offs_n_new < cur_batch_end_index,
+            other=0,
+        )
         off_k = k_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
         k = tl.load(K + off_k, mask=offs_n_new[None, :] < cur_batch_end_index, other=0.0)
         att_value = tl.dot(q, k)
         att_value *= sm_scale
         att_value = tl.where(offs_n_new[None, :] < cur_batch_end_index, att_value, float("-inf"))
-        v = tl.load(V + k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :], mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)
-
+        v = tl.load(
+            V + k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :],
+            mask=offs_n_new[:, None] < cur_batch_end_index,
+            other=0.0,
+        )
+
         cur_max_logic = tl.max(att_value, axis=1)
         new_max_logic = tl.maximum(cur_max_logic, max_logic)
@@ -67,19 +98,33 @@ def _fwd_kernel_flash_decode_stage1(
         sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1)
         max_logic = new_max_logic
-
+
     need_store = tl.where(block_n_size == 0, 0, 1)
     for _ in range(0, need_store, 1):
-        seq_block_index = cur_batch_start_index // BLOCK_SEQ
-        off_mid_o = cur_batch * stride_mid_ob + cur_q_head_range[:, None] * stride_mid_oh + seq_block_index * stride_mid_os + offs_d[None, :]
-        off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + seq_block_index
-        tl.store(Mid_O + off_mid_o, acc / sum_exp[:, None], mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size)
-        tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp), mask=cur_q_head_range < (cur_kv_head + 1) * gqa_group_size)
+        off_mid_o = (
+            cur_batch * stride_mid_ob
+            + cur_q_head_range[:, None] * stride_mid_oh
+            + seq_start_block * stride_mid_os
+            + offs_d[None, :]
+        )
+        off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + seq_start_block
+        tl.store(
+            Mid_O + off_mid_o,
+            acc / sum_exp[:, None],
+            mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size,
+        )
+        tl.store(
+            Mid_O_LogExpSum + off_mid_o_logexpsum,
+            max_logic + tl.log(sum_exp),
+            mask=cur_q_head_range < (cur_kv_head + 1) * gqa_group_size,
+        )
     return


 @torch.no_grad()
-def flash_decode_stage1(block_batch_ids, block_start_indexes, q, k, v, Req_to_tokens, B_req_idx, B_seq_len, mid_out, mid_out_logsumexp, block_seq):
+def flash_decode_stage1(
+    q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq
+):
     BLOCK_SEQ = block_seq
     BLOCK_N = 16
     assert BLOCK_SEQ % BLOCK_N == 0
@@ -89,21 +134,37 @@ def flash_decode_stage1(block_batch_ids, block_start_indexes, q, k, v, Req_to_to
     assert Lk in {16, 32, 64, 128}
     sm_scale = 1.0 / (Lk ** 0.5)
     batch, kv_head_num = B_req_idx.shape[0], k.shape[1]
-    block_nums = len(block_batch_ids)
-    grid = (block_nums, kv_head_num)
+    grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ))
     gqa_group_size = q.shape[1] // k.shape[1]
-
+
     _fwd_kernel_flash_decode_stage1[grid](
-        block_batch_ids, block_start_indexes,
-        q, k, v, sm_scale, Req_to_tokens, B_req_idx, B_seq_len,
+        q,
+        k,
+        v,
+        sm_scale,
+        Req_to_tokens,
+        B_req_idx,
+        B_Seqlen,
         mid_out,
         mid_out_logsumexp,
-        Req_to_tokens.stride(0), Req_to_tokens.stride(1),
-        q.stride(0), q.stride(1), q.stride(2),
-        k.stride(0), k.stride(1), k.stride(2),
-        v.stride(0), v.stride(1), v.stride(2),
-        mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3),
-        mid_out_logsumexp.stride(0), mid_out_logsumexp.stride(1), mid_out_logsumexp.stride(2),
+        Req_to_tokens.stride(0),
+        Req_to_tokens.stride(1),
+        q.stride(0),
+        q.stride(1),
+        q.stride(2),
+        k.stride(0),
+        k.stride(1),
+        k.stride(2),
+        v.stride(0),
+        v.stride(1),
+        v.stride(2),
+        mid_out.stride(0),
+        mid_out.stride(1),
+        mid_out.stride(2),
+        mid_out.stride(3),
+        mid_out_logsumexp.stride(0),
+        mid_out_logsumexp.stride(1),
+        mid_out_logsumexp.stride(2),
         gqa_group_size,
         Q_HEAD_NUM=max(16, triton.next_power_of_2(gqa_group_size)),
         BLOCK_SEQ=BLOCK_SEQ,
@@ -112,4 +173,4 @@ def flash_decode_stage1(block_batch_ids, block_start_indexes, q, k, v, Req_to_to
         num_warps=2,
         num_stages=2,
     )
-    return
\ No newline at end of file
+    return
diff --git a/requirements.txt b/requirements.txt
index c99e4625..af973b2c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -20,7 +20,6 @@ filelock==3.12.2
 fsspec==2023.6.0
 gmpy2==2.1.2
 h11==0.14.0
-huggingface-hub==0.24.6
 humanfriendly==10.0
 humanize==4.7.0
 idna==3.4
@@ -61,11 +60,13 @@ sniffio==1.3.0
 sympy==1.12
 sortedcontainers==2.4.0
 toolz==0.12.0
-torch==2.1.0
+torch==2.4.0
+torchvision==0.19.0
 tqdm==4.65.0
 transformers==4.43.2
 tokenizers==0.19.1
-triton==2.1.0
+huggingface-hub==0.24.6
+triton==3.0.0
 urllib3==1.26.16
 uvicorn==0.19.0
 uvloop==0.17.0
@@ -75,8 +76,6 @@ safetensors==0.4.3
 Pillow==10.2.0
 tiktoken==0.5.2
 matplotlib==3.8.2
---extra-index-url https://download.pytorch.org/whl/cu118
-torchvision==0.16.0
 psutil==5.9.4
 prometheus_client==0.20.0
 outlines==0.0.46
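
Note on the kernel change above: `flash_decode_stage1` no longer consumes host-precomputed `block_batch_ids` / `block_start_indexes` tensors; each Triton program now derives its work item from a 3D launch grid of `(batch, kv_head, seq_block)`, and blocks that fall past a sequence's length end up with `block_n_size == 0` and skip the store. Below is a minimal sketch (plain Python, not part of the diff; the example sizes are hypothetical) of the grid-sizing arithmetic the wrapper now performs.

~~~python
# Sketch of the new launch-grid partitioning, assuming BLOCK_SEQ = 128 as in the wrapper.


def cdiv(a: int, b: int) -> int:
    # Same ceiling division that triton.cdiv performs.
    return (a + b - 1) // b


def launch_grid(batch_size: int, kv_head_num: int, max_len_in_batch: int, block_seq: int = 128):
    # One kernel instance per (batch, kv_head, seq_block) triple.
    return (batch_size, kv_head_num, cdiv(max_len_in_batch, block_seq))


if __name__ == "__main__":
    # Hypothetical example: 3 requests, 8 KV heads, longest sequence in the batch is 300 tokens.
    print(launch_grid(3, 8, 300))  # (3, 8, 3): sequence blocks cover [0,128), [128,256), [256,300)
~~~

Compared with the removed path, this trades a few no-op programs (short sequences in a batch dominated by one long sequence) for dropping the per-step NumPy concatenation that built the block-id tensors on the CPU.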