add a llama_ppl model that uses the ppl fast decode attention kernel. (#132)
hiworldwzj authored Sep 14, 2023
1 parent feff041 commit bf8a829
Showing 8 changed files with 208 additions and 3 deletions.
15 changes: 15 additions & 0 deletions lightllm/common/ppl_int8kv_mem_manager.py
@@ -0,0 +1,15 @@
import torch

from .mem_manager import MemoryManager


class PPLINT8KVMemoryManager(MemoryManager):
    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True):
        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=True)

    def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
        # int8 KV cache: keys/values are stored as int8, with one scale (in the original dtype) per group of 8 channels.
        group_quant_size = 8
        self.key_buffer = [torch.empty((size, head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)]
        self.value_buffer = [torch.empty((size, head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)]
        self.key_scale_buffer = [torch.empty((size, head_num, head_dim // group_quant_size), dtype=dtype, device="cuda") for _ in range(layer_num)]
        self.value_scale_buffer = [torch.empty((size, head_num, head_dim // group_quant_size), dtype=dtype, device="cuda") for _ in range(layer_num)]
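
Each cached key/value row is therefore held as int8 plus head_dim // 8 per-group scales per head. A minimal dequantization sketch (illustrative only, not part of this commit; tensor names are assumptions):

import torch

def dequantize_kv_row(int8_row: torch.Tensor, scales: torch.Tensor, group_quant_size: int = 8) -> torch.Tensor:
    # int8_row: (head_num, head_dim) int8 entries from key_buffer / value_buffer
    # scales:   (head_num, head_dim // group_quant_size) entries from the matching scale buffer
    head_num, head_dim = int8_row.shape
    grouped = int8_row.view(head_num, head_dim // group_quant_size, group_quant_size).to(scales.dtype)
    return (grouped * scales.unsqueeze(-1)).view(head_num, head_dim)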
Empty file.
Empty file.
63 changes: 63 additions & 0 deletions lightllm/models/llama_ppl/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,63 @@
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np
from typing import Tuple

from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.llama2.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
from lightllm.models.llama.infer_struct import LlamaInferStateInfo

class LlamaPPlTransformerLayerInfer(LlamaTransformerLayerInfer):
    """Llama transformer layer inference that decodes through the PPL int8-KV fast attention kernel."""

    def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
        super().__init__(layer_num, tp_rank, world_size, network_config, mode)
        return

    def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight) -> torch.Tensor:
        gate_out = torch.mm(input.view(-1, self.embed_dim_), layer_weight.gate_proj)
        torch.nn.functional.silu(gate_out, inplace=True)
        up_out = torch.mm(input.view(-1, self.embed_dim_), layer_weight.up_proj)
        input = None
        ffn1_out = gate_out * up_out
        gate_out, up_out = None, None
        ffn2_out = torch.mm(ffn1_out, layer_weight.down_proj)
        ffn1_out = None
        return ffn2_out

    def _copy_kv_to_mem_cache(self, key_buffer, value_buffer, mem_index, mem_manager):
        from lightllm.models.llama_ppl.triton_kernel.quant_copy_kv import destindex_copy_quantize_kv
        destindex_copy_quantize_kv(key_buffer,
                                   mem_index,
                                   mem_manager.key_buffer[self.layer_num_],
                                   mem_manager.key_scale_buffer[self.layer_num_])
        destindex_copy_quantize_kv(value_buffer,
                                   mem_index,
                                   mem_manager.value_buffer[self.layer_num_],
                                   mem_manager.value_scale_buffer[self.layer_num_])

    def _token_decode_attention_int8kv(self, q, infer_state: LlamaInferStateInfo):
        total_token_num = infer_state.total_token_num
        batch_size = infer_state.batch_size
        calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)
        att_m_tensor = torch.empty((self.tp_q_head_num_, total_token_num), dtype=q.dtype, device="cuda")
        o_tensor = torch.empty_like(q)

        from lightllm_ppl_kernel import group8_int8kv_decode_attention
        # group8_int8kv_decode_attention(at::Tensor o, at::Tensor q, at::Tensor k, at::Tensor k_s, at::Tensor v, at::Tensor v_s, at::Tensor b_loc, at::Tensor b_seq_len, int max_len_in_batch)
        group8_int8kv_decode_attention(o_tensor.view(calcu_shape1),
                                       q.view(calcu_shape1),
                                       infer_state.mem_manager.key_buffer[self.layer_num_],
                                       infer_state.mem_manager.key_scale_buffer[self.layer_num_],
                                       infer_state.mem_manager.value_buffer[self.layer_num_],
                                       infer_state.mem_manager.value_scale_buffer[self.layer_num_],
                                       infer_state.b_loc,
                                       infer_state.b_seq_len,
                                       infer_state.max_len_in_batch)

        return o_tensor

    def _token_decode_attention_mode(self, q, infer_state: LlamaInferStateInfo):
        return self._token_decode_attention_int8kv(q, infer_state)
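
group8_int8kv_decode_attention fuses dequantization of the int8 cache with single-token decode attention. A rough, non-fused equivalent for one sequence might look like the sketch below (illustrative only; the real kernel gathers cache rows through b_loc / b_seq_len, handles the whole batch, and its exact numerics, such as the 1/sqrt(head_dim) scaling and softmax precision, are assumptions here):

import torch

def reference_int8kv_decode_attention(q, k_int8, k_scale, v_int8, v_scale, group=8):
    # q:       (head_num, head_dim) query for the token being decoded
    # k_int8:  (seq_len, head_num, head_dim) int8 cached keys for this sequence
    # k_scale: (seq_len, head_num, head_dim // group) per-group key scales
    # v_int8 / v_scale: same layout for the cached values
    def dequant(x_int8, scale):
        s, h, d = x_int8.shape
        x = x_int8.view(s, h, d // group, group).to(scale.dtype) * scale.unsqueeze(-1)
        return x.view(s, h, d)

    k = dequant(k_int8, k_scale)
    v = dequant(v_int8, v_scale)
    head_dim = q.shape[-1]
    att = torch.einsum("hd,shd->hs", q, k) / (head_dim ** 0.5)   # (head_num, seq_len)
    p = torch.softmax(att.float(), dim=-1).to(v.dtype)
    return torch.einsum("hs,shd->hd", p, v)                      # (head_num, head_dim)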
26 changes: 26 additions & 0 deletions lightllm/models/llama_ppl/model.py
@@ -0,0 +1,26 @@
import torch
from lightllm.models.llama.model import LlamaTpPartModel
from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
from lightllm.models.llama_ppl.layer_infer.transformer_layer_infer import LlamaPPlTransformerLayerInfer

class LlamaPPlTpPartModel(LlamaTpPartModel):

    transformer_layer_infer_class = LlamaPPlTransformerLayerInfer

    memory_manager_class = PPLINT8KVMemoryManager

    def __init__(self, tp_rank, world_size, weight_dir, max_total_token_num, load_way="HF", mode=[]):
        super().__init__(tp_rank, world_size, weight_dir, max_total_token_num, load_way, mode)
        return

    def _verify_params(self):
        assert self.load_way == "HF", "llama only supports loading weights in HF format for now!"
        assert "int8kv" in self.mode, "only the int8kv mode is supported"

    def _init_mem_manager(self):
        self.mem_manager = self.memory_manager_class(self.max_total_token_num,
                                                     dtype=torch.float16,
                                                     head_num=self.config["num_attention_heads"] // self.world_size_,
                                                     head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
                                                     layer_num=self.config["num_hidden_layers"])
        return
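
A usage sketch (weight_dir and max_total_token_num are placeholders; the mode list follows the checks above and the "ppl" dispatch added in model_rpc.py below):

import torch
from lightllm.models.llama_ppl.model import LlamaPPlTpPartModel

# Hypothetical values for illustration only.
model = LlamaPPlTpPartModel(tp_rank=0,
                            world_size=1,
                            weight_dir="/path/to/llama-hf-weights",
                            max_total_token_num=120000,
                            load_way="HF",
                            mode=["ppl", "int8kv"])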
Empty file.
97 changes: 97 additions & 0 deletions lightllm/models/llama_ppl/triton_kernel/quant_copy_kv.py
@@ -0,0 +1,97 @@
import torch

import triton
import triton.language as tl


@triton.jit
def _fwd_kernel_destindex_copy_quantize_kv(
    K, Dest_loc, Out, Out_scale,
    stride_k_bs, stride_k_h, stride_k_g, stride_k_d,
    stride_o_bs, stride_o_h, stride_o_g, stride_o_d,
    stride_os_bs, stride_os_h, stride_os_g,
    group_size,
    BLOCK_GROUP_NUM: tl.constexpr,
    BLOCK_GROUP_DIM: tl.constexpr
):
    # One program instance quantizes one (token, head) slice: each group of
    # BLOCK_GROUP_DIM channels is scaled symmetrically to int8, and its scale
    # is written to Out_scale at the destination cache slot.
    cur_index = tl.program_id(0)
    cur_head = tl.program_id(1)

    offs_g = tl.arange(0, BLOCK_GROUP_NUM)
    offs_d = tl.arange(0, BLOCK_GROUP_DIM)

    dest_index = tl.load(Dest_loc + cur_index)

    src_data = tl.load(K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :],
                       mask=offs_g[:, None] < group_size, other=0.0)
    abs_data = tl.abs(src_data)
    data_scale = (tl.max(abs_data, axis=1) / 127.).to(tl.float16)
    q_src_data = (src_data / data_scale[:, None]).to(tl.int8)

    o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :]
    os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g
    tl.store(o_ptrs, q_src_data, mask=offs_g[:, None] < group_size)
    tl.store(os_ptrs, data_scale)
    return


@torch.no_grad()
def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale):
    seq_len = DestLoc.shape[0]
    head_num = K.shape[1]
    head_dim = K.shape[2]
    quant_group_dim = 8

    assert head_dim % quant_group_dim == 0, "head_dim must be divisible by the quant group size"
    grid = (seq_len, head_num)
    num_warps = 1

    group_size = head_dim // quant_group_dim   # number of quant groups per head
    group_dim = quant_group_dim                # channels per quant group

    K = K.view((K.shape[0], K.shape[1], group_size, group_dim))
    Out = Out.view(Out.shape[0], Out.shape[1], group_size, group_dim)

    _fwd_kernel_destindex_copy_quantize_kv[grid](
        K, DestLoc, Out, Out_scale,
        K.stride(0), K.stride(1), K.stride(2), K.stride(3),
        Out.stride(0), Out.stride(1), Out.stride(2), Out.stride(3),
        Out_scale.stride(0), Out_scale.stride(1), Out_scale.stride(2),
        group_size,
        BLOCK_GROUP_NUM=triton.next_power_of_2(group_size),
        BLOCK_GROUP_DIM=group_dim,
        num_warps=num_warps,
        num_stages=1,
    )
    return


def test2():
    import time

    B, N_CTX, H, D = 32, 1024, 12, 128
    src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda()
    dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32).cuda()
    value_dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda().to(torch.int8)
    scale_dest = torch.randn((B * N_CTX, H, D // 8), dtype=torch.float16).cuda()

    for _ in range(10):
        destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest)
    torch.cuda.synchronize()
    t1 = time.time()
    for _ in range(1000):
        destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest)
    torch.cuda.synchronize()
    t2 = time.time()

    print("Time cost ", t2 - t1)
    value_dest = value_dest.view((B * N_CTX, H, D // 8, 8))
    scale_dest = scale_dest.view((B * N_CTX, H, D // 8, 1))
    print("max ", torch.max(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src)))
    print("mean ", torch.mean(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src)))
    cos = torch.nn.CosineSimilarity(0)
    print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32)))


if __name__ == '__main__':
    test2()
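
For checking the kernel, a plain-PyTorch equivalent of destindex_copy_quantize_kv can serve as a reference (a sketch, not part of this commit; it mirrors the per-group symmetric int8 scheme and the truncating float-to-int8 cast above):

import torch

@torch.no_grad()
def destindex_copy_quantize_kv_torch(K, DestLoc, Out, Out_scale, group=8):
    # K: (seq_len, head_num, head_dim) float input; Out: int8 cache; Out_scale: per-group scales
    seq_len, head_num, head_dim = K.shape
    grouped = K.view(seq_len, head_num, head_dim // group, group)
    scale = (grouped.abs().amax(dim=-1) / 127.0).to(K.dtype)       # (seq_len, head_num, head_dim // group)
    quant = (grouped / scale.unsqueeze(-1)).to(torch.int8)         # truncates toward zero, like the Triton cast
    Out[DestLoc.long()] = quant.view(seq_len, head_num, head_dim)
    Out_scale[DestLoc.long()] = scale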
10 changes: 7 additions & 3 deletions lightllm/server/router/model_infer/model_rpc.py
@@ -12,6 +12,7 @@
 from lightllm.models.bloom.model import BloomTpPartModel
 from lightllm.models.llama.model import LlamaTpPartModel
 from lightllm.models.llama_quantized.model import LlamaTpPartModelQuantized
+from lightllm.models.llama_ppl.model import LlamaPPlTpPartModel
 from lightllm.models.llama2.model import Llama2TpPartModel
 from lightllm.models.starcoder.model import StarcoderTpPartModel
 from lightllm.models.qwen.model import QWenTpPartModel
@@ -54,10 +55,13 @@ def exposed_init_model(self, rank_id, world_size, weight_dir, max_total_token_num
if "num_key_value_heads" in model_cfg.keys():
self.model = Llama2TpPartModel(rank_id, world_size, weight_dir, max_total_token_num, load_way, mode)
else:
if 'int8weight' in mode or 'int4weight' in mode:
self.model = LlamaTpPartModelQuantized(rank_id, world_size, weight_dir, max_total_token_num, load_way, mode)
if "ppl" not in mode:
if 'int8weight' in mode or 'int4weight' in mode:
self.model = LlamaTpPartModelQuantized(rank_id, world_size, weight_dir, max_total_token_num, load_way, mode)
else:
self.model = LlamaTpPartModel(rank_id, world_size, weight_dir, max_total_token_num, load_way, mode)
else:
self.model = LlamaTpPartModel(rank_id, world_size, weight_dir, max_total_token_num, load_way, mode)
self.model = LlamaPPlTpPartModel(rank_id, world_size, weight_dir, max_total_token_num, load_way, mode)
elif self.model_type == "qwen":
self.model = QWenTpPartModel(rank_id, world_size, weight_dir, max_total_token_num, load_way, mode)
elif self.model_type == "baichuan":
