update torch version to 2.4.0 and update gqa flash decoding kernel #536

Merged
merged 1 commit on Sep 12, 2024
2 changes: 1 addition & 1 deletion Dockerfile
@@ -47,7 +47,7 @@ RUN mkdir ~/cuda-nvcc && cd ~/cuda-nvcc && \
WORKDIR /root

COPY ./requirements.txt /lightllm/requirements.txt
RUN pip install -r /lightllm/requirements.txt --no-cache-dir --ignore-installed
RUN pip install -r /lightllm/requirements.txt --no-cache-dir --ignore-installed --extra-index-url https://download.pytorch.org/whl/cu118

COPY . /lightllm
RUN pip install -e /lightllm --no-cache-dir
10 changes: 5 additions & 5 deletions README.md
@@ -86,7 +86,8 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram
The code has been tested with PyTorch>=1.3, CUDA 11.8, and Python 3.9. To install the necessary dependencies, refer to the provided **requirements.txt** and follow the instructions below:

~~~shell
pip install -r requirements.txt
# for cuda 11.8
pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118
~~~

### Container
@@ -136,13 +137,12 @@ python setup.py install

- Install Triton Package

The code has been tested on a range of GPUs including V100, A100, A800, 4090, and H800. If you are running the code on A100, A800, etc., we recommend using triton==2.1.0.
The code has been tested on a range of GPUs including V100, A100, A800, 4090, and H800. If you are running the code on A100, A800, etc., we recommend using triton==3.0.0.

~~~shell
pip install triton==2.1.0 --no-deps
pip install triton==3.0.0 --no-deps
~~~
If you are running the code on H800 or V100, we recommend using triton-nightly; however, triton-nightly has a significant CPU bottleneck, leading to high decode latency at low concurrency levels. See [this issue](https://github.com/openai/triton/issues/3619) and the [fix PR](https://github.com/openai/triton/pull/3638). You can try modifying and compiling the source code yourself to resolve this issue.
If you are running the code on H800 or V100, you can try triton-nightly to get better performance.
~~~shell
pip install -U --index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/Triton-Nightly/pypi/simple/ triton-nightly --no-deps
~~~
68 changes: 22 additions & 46 deletions lightllm/models/llama/triton_kernel/gqa_flash_decoding.py
@@ -1,9 +1,7 @@
import time
import torch
import numpy as np
from lightllm.common.basemodel import InferStateInfo

def gqa_token_decode_attention_flash_decoding(q, infer_state:InferStateInfo, q_head_num, head_dim, cache_k, cache_v, out=None):

def gqa_token_decode_attention_flash_decoding(q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None):
BLOCK_SEQ = 128
batch_size = infer_state.batch_size
max_len_in_batch = infer_state.max_len_in_batch
@@ -13,51 +11,29 @@ def gqa_token_decode_attention_flash_decoding(q, infer_state:InferStateInfo, q_h
from .gqa_flash_decoding_stage2 import flash_decode_stage2

o_tensor = torch.empty_like(q) if out is None else out

if getattr(infer_state, 'mid_o', None) is None:
# start_time = time.time()
b_seq_len_numpy = infer_state.b_seq_len.cpu().numpy()

block_batch_ids = torch.from_numpy(np.concatenate([np.full(((b_seq_len_numpy[batch_id] + BLOCK_SEQ - 1) // BLOCK_SEQ,), fill_value=batch_id, dtype=np.int32)
for batch_id in range(len(b_seq_len_numpy))], axis=0)).cuda()

block_start_indexes = torch.from_numpy(np.concatenate([np.arange(0, seq_len, BLOCK_SEQ, dtype=np.int32)
for seq_len in b_seq_len_numpy], axis=0)).cuda()

assert len(block_batch_ids) == len(block_start_indexes)
infer_state.block_batch_ids = block_batch_ids
infer_state.block_start_indexes = block_start_indexes
# print("build block params cost:", (time.time() - start_time) * 1000)
if getattr(infer_state, "mid_o", None) is None:
infer_state.mid_o = torch.empty(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda"
)
infer_state.mid_o_logexpsum = torch.empty(
[batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda"
)

infer_state.mid_o = torch.empty([batch_size,
q_head_num,
max_len_in_batch // BLOCK_SEQ + 1,
head_dim],
dtype=torch.float32,
device="cuda")
infer_state.mid_o_logexpsum = torch.empty([batch_size,
q_head_num,
max_len_in_batch // BLOCK_SEQ + 1],
dtype=torch.float32,
device="cuda")

mid_o = infer_state.mid_o
mid_o_logexpsum = infer_state.mid_o_logexpsum

flash_decode_stage1(infer_state.block_batch_ids,
infer_state.block_start_indexes,
q.view(calcu_shape1),
cache_k,
cache_v,
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_seq_len,
mid_o,
mid_o_logexpsum,
BLOCK_SEQ)
flash_decode_stage2(mid_o,
mid_o_logexpsum,
infer_state.b_seq_len,
o_tensor.view(calcu_shape1),
BLOCK_SEQ)
flash_decode_stage1(
q.view(calcu_shape1),
cache_k,
cache_v,
infer_state.req_manager.req_to_token_indexs,
infer_state.b_req_idx,
infer_state.b_seq_len,
infer_state.max_len_in_batch,
mid_o,
mid_o_logexpsum,
BLOCK_SEQ,
)
flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ)
return o_tensor
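
For readers reviewing this change: the old path precomputed `block_batch_ids` and `block_start_indexes` on the host with numpy (forcing a `b_seq_len.cpu()` sync) and launched a 2D grid over those flattened blocks, while the new path launches a dense 3D grid of `(batch, kv_head, seq_block)` and lets each program skip blocks past its sequence length. A minimal sketch of the work decomposition — illustrative only, not repository code, and the helper names are hypothetical:

~~~python
import math

def old_work_items(b_seq_len, block_seq):
    # Old scheme: the host enumerates one (batch_id, block_start) pair per occupied
    # sequence block and passes the two index tensors to the kernel.
    items = []
    for batch_id, seq_len in enumerate(b_seq_len):
        for start in range(0, seq_len, block_seq):
            items.append((batch_id, start))
    return items

def new_work_items(b_seq_len, max_len_in_batch, block_seq):
    # New scheme: a dense (batch, seq_block) grid; programs whose block start lies
    # beyond the sequence length compute block_n_size == 0 and store nothing.
    items = []
    for batch_id, seq_len in enumerate(b_seq_len):
        for blk in range(math.ceil(max_len_in_batch / block_seq)):
            start = blk * block_seq
            if start < seq_len:
                items.append((batch_id, start))
    return items

b_seq_len = [300, 17, 129]
assert old_work_items(b_seq_len, 128) == new_work_items(b_seq_len, max(b_seq_len), 128)
~~~

The dense grid over-launches idle programs for short sequences, but it removes the per-step host work of building the block index tensors.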
149 changes: 105 additions & 44 deletions lightllm/models/llama/triton_kernel/gqa_flash_decoding_stage1.py
@@ -2,45 +2,68 @@
import triton
import triton.language as tl


@triton.jit
def _fwd_kernel_flash_decode_stage1(
block_batch_ids, block_start_indexes,
Q, K, V, sm_scale, Req_to_tokens, B_req_idx, B_Seqlen,
Mid_O, # [batch, head, seq_block_num, head_dim]
Mid_O_LogExpSum, #[batch, head, seq_block_num]
stride_req_to_tokens_b, stride_req_to_tokens_s,
stride_qbs, stride_qh, stride_qd,
stride_kbs, stride_kh, stride_kd,
stride_vbs, stride_vh, stride_vd,
stride_mid_ob, stride_mid_oh, stride_mid_os, stride_mid_od,
stride_mid_o_eb, stride_mid_o_eh, stride_mid_o_es,
Q,
K,
V,
sm_scale,
Req_to_tokens,
B_req_idx,
B_Seqlen,
Mid_O, # [batch, head, seq_block_num, head_dim]
Mid_O_LogExpSum, # [batch, head, seq_block_num]
stride_req_to_tokens_b,
stride_req_to_tokens_s,
stride_qbs,
stride_qh,
stride_qd,
stride_kbs,
stride_kh,
stride_kd,
stride_vbs,
stride_vh,
stride_vd,
stride_mid_ob,
stride_mid_oh,
stride_mid_os,
stride_mid_od,
stride_mid_o_eb,
stride_mid_o_eh,
stride_mid_o_es,
gqa_group_size,
Q_HEAD_NUM: tl.constexpr,
BLOCK_SEQ: tl.constexpr,
BLOCK_SEQ: tl.constexpr,
BLOCK_DMODEL: tl.constexpr,
BLOCK_N: tl.constexpr
BLOCK_N: tl.constexpr,
):
cur_block_id = tl.program_id(0)
cur_batch = tl.program_id(0)
cur_kv_head = tl.program_id(1)

cur_batch = tl.load(block_batch_ids + cur_block_id)
seq_start_index = tl.load(block_start_indexes + cur_block_id)
seq_start_block = tl.program_id(2)

cur_q_head_offs = tl.arange(0, Q_HEAD_NUM)
cur_q_head_range = cur_kv_head * gqa_group_size + cur_q_head_offs

offs_d = tl.arange(0, BLOCK_DMODEL)
cur_batch_seq_len = tl.load(B_Seqlen + cur_batch)
cur_batch_req_idx = tl.load(B_req_idx + cur_batch)
cur_batch_start_index = seq_start_index
cur_batch_start_index = seq_start_block * BLOCK_SEQ
cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ)

off_q = cur_batch * stride_qbs + cur_q_head_range[:, None] * stride_qh + offs_d[None, :]

block_n_size = tl.where(cur_batch_end_index - cur_batch_start_index <= 0, 0, cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1) // BLOCK_N


block_n_size = (
tl.where(
cur_batch_end_index - cur_batch_start_index <= 0,
0,
cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1,
)
// BLOCK_N
)

offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N)

q = tl.load(Q + off_q, mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size, other=0.0)

sum_exp = tl.zeros([Q_HEAD_NUM], dtype=tl.float32)
@@ -49,14 +72,22 @@ def _fwd_kernel_flash_decode_stage1(

for start_n in range(0, block_n_size, 1):
offs_n_new = start_n * BLOCK_N + offs_n
k_loc = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, mask=offs_n_new < cur_batch_end_index, other=0)
k_loc = tl.load(
Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new,
mask=offs_n_new < cur_batch_end_index,
other=0,
)
off_k = k_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None]
k = tl.load(K + off_k, mask=offs_n_new[None, :] < cur_batch_end_index, other=0.0)
att_value = tl.dot(q, k)
att_value *= sm_scale
att_value = tl.where(offs_n_new[None, :] < cur_batch_end_index, att_value, float("-inf"))
v = tl.load(V + k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :], mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0)

v = tl.load(
V + k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :],
mask=offs_n_new[:, None] < cur_batch_end_index,
other=0.0,
)

cur_max_logic = tl.max(att_value, axis=1)
new_max_logic = tl.maximum(cur_max_logic, max_logic)

@@ -67,19 +98,33 @@

sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=1)
max_logic = new_max_logic

need_store = tl.where(block_n_size == 0, 0, 1)
for _ in range(0, need_store, 1):
seq_block_index = cur_batch_start_index // BLOCK_SEQ
off_mid_o = cur_batch * stride_mid_ob + cur_q_head_range[:, None] * stride_mid_oh + seq_block_index * stride_mid_os + offs_d[None, :]
off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + seq_block_index
tl.store(Mid_O + off_mid_o, acc / sum_exp[:, None], mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size)
tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp), mask=cur_q_head_range < (cur_kv_head + 1) * gqa_group_size)
off_mid_o = (
cur_batch * stride_mid_ob
+ cur_q_head_range[:, None] * stride_mid_oh
+ seq_start_block * stride_mid_os
+ offs_d[None, :]
)
off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_q_head_range * stride_mid_o_eh + seq_start_block
tl.store(
Mid_O + off_mid_o,
acc / sum_exp[:, None],
mask=cur_q_head_range[:, None] < (cur_kv_head + 1) * gqa_group_size,
)
tl.store(
Mid_O_LogExpSum + off_mid_o_logexpsum,
max_logic + tl.log(sum_exp),
mask=cur_q_head_range < (cur_kv_head + 1) * gqa_group_size,
)
return


@torch.no_grad()
def flash_decode_stage1(block_batch_ids, block_start_indexes, q, k, v, Req_to_tokens, B_req_idx, B_seq_len, mid_out, mid_out_logsumexp, block_seq):
def flash_decode_stage1(
q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq
):
BLOCK_SEQ = block_seq
BLOCK_N = 16
assert BLOCK_SEQ % BLOCK_N == 0
@@ -89,21 +134,37 @@ def flash_decode_stage1(block_batch_ids, block_start_indexes, q, k, v, Req_to_to
assert Lk in {16, 32, 64, 128}
sm_scale = 1.0 / (Lk ** 0.5)
batch, kv_head_num = B_req_idx.shape[0], k.shape[1]
block_nums = len(block_batch_ids)
grid = (block_nums, kv_head_num)
grid = (batch, kv_head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ))
gqa_group_size = q.shape[1] // k.shape[1]

_fwd_kernel_flash_decode_stage1[grid](
block_batch_ids, block_start_indexes,
q, k, v, sm_scale, Req_to_tokens, B_req_idx, B_seq_len,
q,
k,
v,
sm_scale,
Req_to_tokens,
B_req_idx,
B_Seqlen,
mid_out,
mid_out_logsumexp,
Req_to_tokens.stride(0), Req_to_tokens.stride(1),
q.stride(0), q.stride(1), q.stride(2),
k.stride(0), k.stride(1), k.stride(2),
v.stride(0), v.stride(1), v.stride(2),
mid_out.stride(0), mid_out.stride(1), mid_out.stride(2), mid_out.stride(3),
mid_out_logsumexp.stride(0), mid_out_logsumexp.stride(1), mid_out_logsumexp.stride(2),
Req_to_tokens.stride(0),
Req_to_tokens.stride(1),
q.stride(0),
q.stride(1),
q.stride(2),
k.stride(0),
k.stride(1),
k.stride(2),
v.stride(0),
v.stride(1),
v.stride(2),
mid_out.stride(0),
mid_out.stride(1),
mid_out.stride(2),
mid_out.stride(3),
mid_out_logsumexp.stride(0),
mid_out_logsumexp.stride(1),
mid_out_logsumexp.stride(2),
gqa_group_size,
Q_HEAD_NUM=max(16, triton.next_power_of_2(gqa_group_size)),
BLOCK_SEQ=BLOCK_SEQ,
@@ -112,4 +173,4 @@ def flash_decode_stage1(block_batch_ids, block_start_indexes, q, k, v, Req_to_to
num_warps=2,
num_stages=2,
)
return
return
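
As a sanity check on what stage 1 produces, here is a naive PyTorch sketch of the per-block partial attention it computes: a per-block softmax-weighted value sum plus a log-sum-exp that stage 2 later combines across blocks. This is an illustrative reference under assumed layouts — `q` as `[batch, q_head_num, head_dim]`, the caches as `[num_tokens, kv_head_num, head_dim]` — not the repository's test code:

~~~python
import torch

def ref_gqa_decode_stage1(q, k_cache, v_cache, req_to_token, b_req_idx, b_seq_len, block_seq):
    # q: [batch, q_heads, d]; k_cache/v_cache: [num_tokens, kv_heads, d]
    batch, q_heads, d = q.shape
    kv_heads = k_cache.shape[1]
    group = q_heads // kv_heads                     # gqa_group_size
    max_blocks = (int(b_seq_len.max()) + block_seq - 1) // block_seq
    mid_o = torch.zeros(batch, q_heads, max_blocks, d, dtype=torch.float32)
    mid_lse = torch.full((batch, q_heads, max_blocks), float("-inf"))
    scale = 1.0 / d ** 0.5
    for b in range(batch):
        seq_len = int(b_seq_len[b])
        token_ids = req_to_token[int(b_req_idx[b]), :seq_len]
        for blk in range((seq_len + block_seq - 1) // block_seq):
            ids = token_ids[blk * block_seq:(blk + 1) * block_seq]
            k = k_cache[ids].float()                # [n, kv_heads, d]
            v = v_cache[ids].float()
            for h in range(q_heads):
                kvh = h // group                    # the kv head this q head attends with
                logits = (k[:, kvh] @ q[b, h].float()) * scale       # [n]
                m = logits.max()
                p = torch.exp(logits - m)
                mid_o[b, h, blk] = (p @ v[:, kvh]) / p.sum()         # acc / sum_exp
                mid_lse[b, h, blk] = m + torch.log(p.sum())          # max_logic + log(sum_exp)
    return mid_o, mid_lse
~~~

Stage 2 then merges the per-block partial outputs using the stored log-sum-exp values to recover the full-sequence attention result.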
9 changes: 4 additions & 5 deletions requirements.txt
@@ -20,7 +20,6 @@ filelock==3.12.2
fsspec==2023.6.0
gmpy2==2.1.2
h11==0.14.0
huggingface-hub==0.24.6
humanfriendly==10.0
humanize==4.7.0
idna==3.4
@@ -61,11 +60,13 @@ sniffio==1.3.0
sympy==1.12
sortedcontainers==2.4.0
toolz==0.12.0
torch==2.1.0
torch==2.4.0
torchvision==0.19.0
tqdm==4.65.0
transformers==4.43.2
tokenizers==0.19.1
triton==2.1.0
huggingface-hub==0.24.6
triton==3.0.0
urllib3==1.26.16
uvicorn==0.19.0
uvloop==0.17.0
@@ -75,8 +76,6 @@ safetensors==0.4.3
Pillow==10.2.0
tiktoken==0.5.2
matplotlib==3.8.2
--extra-index-url https://download.pytorch.org/whl/cu118
torchvision==0.16.0
psutil==5.9.4
prometheus_client==0.20.0
outlines==0.0.46
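
Since this PR moves the pins to torch 2.4.0, torchvision 0.19.0, and triton 3.0.0 (installed from the cu118 extra index), a quick post-install check can catch a mismatched wheel early. A minimal sketch; the exact version suffixes depend on the index used:

~~~python
import torch
import torchvision
import triton

# Expected after installing with --extra-index-url https://download.pytorch.org/whl/cu118
print("torch:", torch.__version__)              # 2.4.0 (typically 2.4.0+cu118)
print("torchvision:", torchvision.__version__)  # 0.19.0
print("triton:", triton.__version__)            # 3.0.0
print("cuda available:", torch.cuda.is_available())
~~~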