Add a llama_ppl model that uses the ppl fast decode attention kernel. (#132)
1 parent feff041 · commit bf8a829
Showing 8 changed files with 208 additions and 3 deletions.
15 changes: 15 additions & 0 deletions
lightllm/common/ppl_int8kv_mem_manager.py
@@ -0,0 +1,15 @@
import torch

from .mem_manager import MemoryManager


class PPLINT8KVMemoryManager(MemoryManager):
    def __init__(self, size, dtype, head_num, head_dim, layer_num, always_copy=True):
        super().__init__(size, dtype, head_num, head_dim, layer_num, always_copy=True)

    def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
        group_quant_size = 8
        self.key_buffer = [torch.empty((size, head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)]
        self.value_buffer = [torch.empty((size, head_num, head_dim), dtype=torch.int8, device="cuda") for _ in range(layer_num)]
        self.key_scale_buffer = [torch.empty((size, head_num, head_dim // group_quant_size), dtype=dtype, device="cuda") for _ in range(layer_num)]
        self.value_scale_buffer = [torch.empty((size, head_num, head_dim // group_quant_size), dtype=dtype, device="cuda") for _ in range(layer_num)]
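The manager above stores keys and values as int8 plus one scale (in the cache dtype, fp16 here) per group of 8 channels. A rough, illustrative sketch of the per-token KV-cache footprint this layout implies; the shape numbers below are hypothetical LLaMA-7B-like values, not taken from this commit:

```python
# Back-of-the-envelope KV-cache bytes per token, following the buffer shapes above.
# head_num / head_dim / layer_num are illustrative placeholders, not values from this commit.
head_num, head_dim, layer_num = 32, 128, 32
group_quant_size = 8

int8_kv = 2 * head_num * head_dim * 1                        # key + value stored as int8
scales = 2 * head_num * (head_dim // group_quant_size) * 2   # per-group fp16 scales
fp16_kv = 2 * head_num * head_dim * 2                        # plain fp16 cache, for comparison

print(layer_num * (int8_kv + scales))  # 327680 bytes per token with the int8kv layout
print(layer_num * fp16_kv)             # 524288 bytes per token with a plain fp16 cache
```

So the quantized cache needs roughly 5/8 of the fp16 footprint for the same token budget.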
Empty file.
Empty file.
63 changes: 63 additions & 0 deletions
lightllm/models/llama_ppl/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,63 @@
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np
from typing import Tuple

from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.llama2.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight
from lightllm.models.llama.infer_struct import LlamaInferStateInfo

class LlamaPPlTransformerLayerInfer(LlamaTransformerLayerInfer):
    """
    """

    def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
        super().__init__(layer_num, tp_rank, world_size, network_config, mode)
        return

    def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight) -> torch.Tensor:
        gate_out = torch.mm(input.view(-1, self.embed_dim_), layer_weight.gate_proj)
        torch.nn.functional.silu(gate_out, inplace=True)
        up_out = torch.mm(input.view(-1, self.embed_dim_), layer_weight.up_proj)
        input = None
        ffn1_out = gate_out * up_out
        gate_out, up_out = None, None
        ffn2_out = torch.mm(ffn1_out, layer_weight.down_proj)
        ffn1_out = None
        return ffn2_out

    def _copy_kv_to_mem_cache(self, key_buffer, value_buffer, mem_index, mem_manager):
        from lightllm.models.llama_ppl.triton_kernel.quant_copy_kv import destindex_copy_quantize_kv
        destindex_copy_quantize_kv(key_buffer,
                                   mem_index,
                                   mem_manager.key_buffer[self.layer_num_],
                                   mem_manager.key_scale_buffer[self.layer_num_])
        destindex_copy_quantize_kv(value_buffer,
                                   mem_index,
                                   mem_manager.value_buffer[self.layer_num_],
                                   mem_manager.value_scale_buffer[self.layer_num_])

    def _token_decode_attention_int8kv(self, q, infer_state: LlamaInferStateInfo):
        total_token_num = infer_state.total_token_num
        batch_size = infer_state.batch_size
        calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)
        att_m_tensor = torch.empty((self.tp_q_head_num_, total_token_num), dtype=q.dtype, device="cuda")
        o_tensor = torch.empty_like(q)

        from lightllm_ppl_kernel import group8_int8kv_decode_attention
        # group_int8kv_decode_attention(at::Tensor o, at::Tensor q, at::Tensor k, at::Tensor k_s, at::Tensor v, at::Tensor v_s, at::Tensor b_loc, at::Tensor b_seq_len, int max_len_in_batch)
        group8_int8kv_decode_attention(o_tensor.view(calcu_shape1),
                                       q.view(calcu_shape1),
                                       infer_state.mem_manager.key_buffer[self.layer_num_],
                                       infer_state.mem_manager.key_scale_buffer[self.layer_num_],
                                       infer_state.mem_manager.value_buffer[self.layer_num_],
                                       infer_state.mem_manager.value_scale_buffer[self.layer_num_],
                                       infer_state.b_loc,
                                       infer_state.b_seq_len,
                                       infer_state.max_len_in_batch)

        return o_tensor

    def _token_decode_attention_mode(self, q, infer_state: LlamaInferStateInfo):
        return self._token_decode_attention_int8kv(q, infer_state)
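All of the heavy lifting in `_token_decode_attention_int8kv` happens inside the external `group8_int8kv_decode_attention` kernel from `lightllm_ppl_kernel`, which is not part of this commit. For intuition only, a naive PyTorch sketch of the math such a kernel computes for one contiguous sequence is shown below: dequantize the int8 K/V cache with its per-group scales, then run ordinary single-token attention. The function name, the single-sequence layout, and the omission of the paged `b_loc`/`b_seq_len` indexing are all simplifications, not the kernel's real interface:

```python
import torch

def naive_int8kv_decode_attention(q, k_int8, k_scale, v_int8, v_scale, group=8):
    # q:       (head_num, head_dim)                     fp16 query for one decode step
    # k_int8:  (seq_len, head_num, head_dim)            int8 quantized keys
    # k_scale: (seq_len, head_num, head_dim // group)   per-group scales (cache dtype)
    # v_int8 / v_scale: same layout for the values
    seq_len, head_num, head_dim = k_int8.shape
    # dequantize: broadcast each group's scale over its `group` channels
    k = (k_int8.view(seq_len, head_num, -1, group).float() * k_scale.float().unsqueeze(-1)).view(seq_len, head_num, head_dim)
    v = (v_int8.view(seq_len, head_num, -1, group).float() * v_scale.float().unsqueeze(-1)).view(seq_len, head_num, head_dim)
    att = torch.einsum("hd,shd->hs", q.float(), k) / (head_dim ** 0.5)
    prob = torch.softmax(att, dim=-1)
    return torch.einsum("hs,shd->hd", prob, v).to(q.dtype)
```

In the real kernel the cache rows are gathered through `infer_state.b_loc` and masked by `b_seq_len`, so it operates directly on LightLLM's token-indexed paged cache instead of a dense per-sequence layout.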
26 changes: 26 additions & 0 deletions
lightllm/models/llama_ppl/model.py
@@ -0,0 +1,26 @@
import torch
from lightllm.models.llama.model import LlamaTpPartModel
from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
from lightllm.models.llama_ppl.layer_infer.transformer_layer_infer import LlamaPPlTransformerLayerInfer

class LlamaPPlTpPartModel(LlamaTpPartModel):

    transformer_layer_infer_class = LlamaPPlTransformerLayerInfer

    memory_manager_class = PPLINT8KVMemoryManager

    def __init__(self, tp_rank, world_size, weight_dir, max_total_token_num, load_way="HF", mode=[]):
        super().__init__(tp_rank, world_size, weight_dir, max_total_token_num, load_way, mode)
        return

    def _verify_params(self):
        assert self.load_way == "HF", "llama only supports loading HF-format weights for now!"
        assert "int8kv" in self.mode, "the llama_ppl model only supports int8kv mode"

    def _init_mem_manager(self):
        self.mem_manager = self.memory_manager_class(self.max_total_token_num,
                                                     dtype=torch.float16,
                                                     head_num=self.config["num_attention_heads"] // self.world_size_,
                                                     head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
                                                     layer_num=self.config["num_hidden_layers"])
        return
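A minimal construction sketch for this model class, assuming it is importable as `lightllm.models.llama_ppl.model` and assuming a single-GPU setup; the weight directory and token budget below are placeholders, and in practice LightLLM's router/model RPC layer builds the model rather than user code. Note that `mode` must contain "int8kv", otherwise `_verify_params` asserts:

```python
import torch
from lightllm.models.llama_ppl.model import LlamaPPlTpPartModel

# Placeholder values: adjust weight_dir and max_total_token_num to your deployment.
model = LlamaPPlTpPartModel(tp_rank=0,
                            world_size=1,
                            weight_dir="/path/to/llama-hf-weights",
                            max_total_token_num=120000,
                            load_way="HF",
                            mode=["int8kv"])
```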
Empty file.
97 changes: 97 additions & 0 deletions
lightllm/models/llama_ppl/triton_kernel/quant_copy_kv.py
@@ -0,0 +1,97 @@
import torch

import triton
import triton.language as tl


@triton.jit
def _fwd_kernel_destindex_copy_quantize_kv(
    K, Dest_loc, Out, Out_scale,
    stride_k_bs, stride_k_h, stride_k_g, stride_k_d,
    stride_o_bs, stride_o_h, stride_o_g, stride_o_d,
    stride_os_bs, stride_os_h, stride_os_g,
    group_size,
    BLOCK_GROUP_NUM: tl.constexpr,
    BLOCK_GROUP_DIM: tl.constexpr
):
    cur_index = tl.program_id(0)
    cur_head = tl.program_id(1)

    offs_g = tl.arange(0, BLOCK_GROUP_NUM)
    offs_d = tl.arange(0, BLOCK_GROUP_DIM)

    dest_index = tl.load(Dest_loc + cur_index)

    src_data = tl.load(K + cur_index * stride_k_bs + cur_head * stride_k_h + offs_g[:, None] * stride_k_g + offs_d[None, :],
                       mask=offs_g[:, None] < group_size, other=0.0)
    abs_data = tl.abs(src_data)
    data_scale = (tl.max(abs_data, axis=1) / 127.).to(tl.float16)
    q_src_data = (src_data / data_scale[:, None]).to(tl.int8)

    o_ptrs = Out + dest_index * stride_o_bs + cur_head * stride_o_h + offs_g[:, None] * stride_o_g + offs_d[None, :]
    os_ptrs = Out_scale + dest_index * stride_os_bs + cur_head * stride_os_h + offs_g
    tl.store(o_ptrs, q_src_data, mask=offs_g[:, None] < group_size)
    tl.store(os_ptrs, data_scale)
    return


@torch.no_grad()
def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale):
    seq_len = DestLoc.shape[0]
    head_num = K.shape[1]
    head_dim = K.shape[2]
    quant_group_dim = 8

    assert head_dim % quant_group_dim == 0, "head_dim must be divisible by the quant group size for quantized kv copy"
    grid = (seq_len, head_num)
    num_warps = 1

    group_size = head_dim // quant_group_dim
    group_dim = quant_group_dim

    K = K.view((K.shape[0], K.shape[1], group_size, group_dim))
    Out = Out.view(Out.shape[0], Out.shape[1], group_size, group_dim)

    _fwd_kernel_destindex_copy_quantize_kv[grid](
        K, DestLoc, Out, Out_scale,
        K.stride(0), K.stride(1), K.stride(2), K.stride(3),
        Out.stride(0), Out.stride(1), Out.stride(2), Out.stride(3),
        Out_scale.stride(0), Out_scale.stride(1), Out_scale.stride(2),
        group_size,
        BLOCK_GROUP_NUM=triton.next_power_of_2(group_size),
        BLOCK_GROUP_DIM=group_dim,
        num_warps=num_warps,
        num_stages=1,
    )
    return


def test2():
    import time

    B, N_CTX, H, D = 32, 1024, 12, 128
    src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda()
    dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32).cuda()
    value_dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda().to(torch.int8)
    scale_dest = torch.randn((B * N_CTX, H, D // 8), dtype=torch.float16).cuda()

    for _ in range(10):
        destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest)
    torch.cuda.synchronize()
    t1 = time.time()
    for _ in range(1000):
        destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest)
    torch.cuda.synchronize()
    t2 = time.time()

    print("Time cost ", t2 - t1)
    value_dest = value_dest.view((B * N_CTX, H, D // 8, 8))
    scale_dest = scale_dest.view((B * N_CTX, H, D // 8, 1))
    print("max ", torch.max(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src)))
    print("mean ", torch.mean(torch.abs((value_dest * scale_dest).view(B * N_CTX, H, D) - src)))
    cos = torch.nn.CosineSimilarity(0)
    print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32)))


if __name__ == '__main__':
    test2()
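To sanity-check the Triton kernel, a slow pure-PyTorch reference of the same destination-indexed group quantization can be handy. The helper below is an assumed test aid, not part of the commit; it mirrors the kernel's per-group absmax/127 scaling and int8 truncation:

```python
import torch

def destindex_copy_quantize_kv_ref(K, DestLoc, Out, Out_scale, group=8):
    # K:         (seq_len, head_num, head_dim)          fp16 source K or V
    # DestLoc:   (seq_len,)                              destination rows in the cache
    # Out:       (size, head_num, head_dim)              int8 cache buffer
    # Out_scale: (size, head_num, head_dim // group)     per-group scale buffer
    seq_len, head_num, head_dim = K.shape
    grouped = K.view(seq_len, head_num, head_dim // group, group)
    scale = (grouped.abs().amax(dim=-1).float() / 127.0).to(Out_scale.dtype)
    quant = (grouped / scale.unsqueeze(-1)).to(torch.int8)  # truncates toward zero, like the kernel
    # note: like the kernel, this does not special-case all-zero groups
    Out[DestLoc.long()] = quant.view(seq_len, head_num, head_dim)
    Out_scale[DestLoc.long()] = scale
```

Comparing its output with the kernel's on random data (as `test2` does for the reconstruction error) should agree up to fp16 rounding.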