starcoder support ppl int8kv (#141)
Co-authored-by: wangzaijun <[email protected]>
hiworldwzj and wangzaijun authored Sep 21, 2023
1 parent 814af33 commit 5836afc
Showing 5 changed files with 94 additions and 1 deletion.
Empty file.
Empty file.
@@ -0,0 +1,61 @@
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np

from lightllm.utils.infer_utils import mark_cost_time
from lightllm.models.starcoder.layer_infer.infer_struct import StarcoderInferStateInfo
from lightllm.models.starcoder.layer_infer.transformer_layer_infer import StarcoderTransformerLayerInfer



class StarcoderPPlTransformerLayerInfer(StarcoderTransformerLayerInfer):
"""
"""
def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
super().__init__(layer_num, tp_rank, world_size, network_config, mode)
return

    def _post_cache_kv(self, cache_k, cache_v, infer_state: StarcoderInferStateInfo, layer_weight):
        # Quantize the newly computed K/V to int8 and copy them, together with their scales,
        # into the shared cache buffers at the destination indices chosen by the memory manager.
        mem_manager = infer_state.mem_manager
        from lightllm.models.llama_ppl.triton_kernel.quant_copy_kv import destindex_copy_quantize_kv
if infer_state.is_prefill:
destindex_copy_quantize_kv(cache_k,
infer_state.prefill_mem_index,
mem_manager.key_buffer[self.layer_num_],
mem_manager.key_scale_buffer[self.layer_num_])
destindex_copy_quantize_kv(cache_v,
infer_state.prefill_mem_index,
mem_manager.value_buffer[self.layer_num_],
mem_manager.value_scale_buffer[self.layer_num_])
return
else:
if not infer_state.decode_is_contiguous:
destindex_copy_quantize_kv(cache_k,
infer_state.decode_mem_index,
mem_manager.key_buffer[self.layer_num_],
mem_manager.key_scale_buffer[self.layer_num_])
destindex_copy_quantize_kv(cache_v,
infer_state.decode_mem_index,
mem_manager.value_buffer[self.layer_num_],
mem_manager.value_scale_buffer[self.layer_num_])
return
return

def _token_attention_kernel(self, q, infer_state:StarcoderInferStateInfo, layer_weight)->torch.Tensor:
batch_size = infer_state.batch_size
calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)
o_tensor = torch.empty_like(q)

        # Fused PPL CUDA kernel: decode attention computed directly over the int8 KV-cache and its scale buffers.
        from lightllm_ppl_kernel import group8_int8kv_decode_attention
group8_int8kv_decode_attention(o_tensor.view(calcu_shape1),
q.view(calcu_shape1),
infer_state.mem_manager.key_buffer[self.layer_num_],
infer_state.mem_manager.key_scale_buffer[self.layer_num_],
infer_state.mem_manager.value_buffer[self.layer_num_],
infer_state.mem_manager.value_scale_buffer[self.layer_num_],
infer_state.b_loc,
infer_state.b_seq_len,
infer_state.max_len_in_batch)

return o_tensor
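
For readers unfamiliar with the PPL int8kv layout: destindex_copy_quantize_kv (reused from the llama_ppl Triton kernels) quantizes each new K/V row to int8 with accompanying scales and scatters the result into the preallocated cache buffers, and group8_int8kv_decode_attention later consumes those buffers plus their scales. Below is a minimal PyTorch sketch of the per-group symmetric quantization idea, assuming a group size of 8 along the head dimension (suggested by the group8 kernel name); the tensor names and grouping here are illustrative, not the actual Triton/CUDA kernel implementation.

import torch

def quantize_kv_groupwise(kv: torch.Tensor, group_size: int = 8):
    # kv: [num_tokens, head_num, head_dim] float tensor (illustrative shape).
    num_tokens, head_num, head_dim = kv.shape
    assert head_dim % group_size == 0
    groups = kv.float().view(num_tokens, head_num, head_dim // group_size, group_size)
    # One symmetric scale per group: max(|x|) / 127, guarded against division by zero.
    scale = groups.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / 127.0
    q = torch.round(groups / scale).clamp(-127, 127).to(torch.int8)
    return q.view(num_tokens, head_num, head_dim), scale.squeeze(-1)

def dequantize_kv_groupwise(q: torch.Tensor, scale: torch.Tensor, group_size: int = 8):
    # Inverse of the sketch above: broadcast each group's scale back over its values.
    num_tokens, head_num, head_dim = q.shape
    groups = q.view(num_tokens, head_num, head_dim // group_size, group_size).float()
    return (groups * scale.unsqueeze(-1)).view(num_tokens, head_num, head_dim)

The real kernels additionally take a destination-index tensor (prefill_mem_index / decode_mem_index) so quantized rows land directly in the cache slots handed out by the memory manager.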
28 changes: 28 additions & 0 deletions lightllm/models/starcoder_ppl/model.py
@@ -0,0 +1,28 @@
import os
import json
import torch

from lightllm.models.starcoder.model import StarcoderTpPartModel
from lightllm.models.starcoder_ppl.layer_infer.transformer_layer_infer import StarcoderPPlTransformerLayerInfer
from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager

class StarcoderPPlTpPartModel(StarcoderTpPartModel):

# infer class
transformer_layer_infer_class = StarcoderPPlTransformerLayerInfer

def __init__(self, tp_rank, world_size, weight_dir, max_total_token_num, load_way="HF", mode=[]):
super().__init__(tp_rank, world_size, weight_dir, max_total_token_num, load_way, mode)

def _verify_params(self):
        assert self.load_way == "HF", "StarCoder currently only supports loading weights in HF format!"
        assert "int8kv" in self.mode, "the PPL StarCoder model only supports int8kv mode"
return

def _init_mem_manager(self):
self.mem_manager = PPLINT8KVMemoryManager(self.max_total_token_num,
dtype=torch.float16,
head_num=self.config["num_key_value_heads"],
head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
layer_num=self.config["num_hidden_layers"])
return
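
The memory manager parameters above come straight from the HF config: one key and one value buffer per transformer layer, each sized for max_total_token_num cache slots of num_key_value_heads x head_dim entries, plus matching scale buffers for the int8 quantization. A small sketch of that sizing arithmetic, using StarCoder-like config values for illustration; the actual buffer layout inside PPLINT8KVMemoryManager is not shown in this diff, and the group size of 8 is an assumption taken from the kernel naming.

# Illustrative sizing only; exact PPLINT8KVMemoryManager internals are not part of this commit.
config = {"num_key_value_heads": 1,        # StarCoder uses multi-query attention
          "hidden_size": 6144,
          "num_attention_heads": 48,
          "num_hidden_layers": 40}
max_total_token_num = 120000
group_size = 8                             # assumed from the group8 kernel naming

head_num = config["num_key_value_heads"]
head_dim = config["hidden_size"] // config["num_attention_heads"]       # 6144 // 48 = 128
value_shape = (max_total_token_num, head_num, head_dim)                 # int8 values, per layer
scale_shape = (max_total_token_num, head_num, head_dim // group_size)   # per-group scales, per layer
print(value_shape, scale_shape)  # (120000, 1, 128) (120000, 1, 16)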
6 changes: 5 additions & 1 deletion lightllm/server/router/model_infer/model_rpc.py
@@ -16,6 +16,7 @@
from lightllm.models.llama2_ppl.model import Llama2PPlTpPartModel
from lightllm.models.llama2.model import Llama2TpPartModel
from lightllm.models.starcoder.model import StarcoderTpPartModel
from lightllm.models.starcoder_ppl.model import StarcoderPPlTpPartModel
from lightllm.models.qwen.model import QWenTpPartModel
from lightllm.models.baichuan7b.model import Baichuan7bTpPartModel
from lightllm.models.baichuan13b.model import Baichuan13bTpPartModel
@@ -76,7 +77,10 @@ def exposed_init_model(self, rank_id, world_size, weight_dir, max_total_token_nu
else:
raise Exception('can not support baichuan format')
elif self.model_type == 'gpt_bigcode':
-self.model = StarcoderTpPartModel(rank_id, world_size, weight_dir, max_total_token_num, load_way, mode)
+if "ppl" not in mode:
+    self.model = StarcoderTpPartModel(rank_id, world_size, weight_dir, max_total_token_num, load_way, mode)
+else:
+    self.model = StarcoderPPlTpPartModel(rank_id, world_size, weight_dir, max_total_token_num, load_way, mode)
elif self.model_type == 'chatglm':
self.model = ChatGlm2TpPartModel(rank_id, world_size, weight_dir, max_total_token_num, load_way, mode)
elif self.model_type == 'internlm':
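
With this routing in place, a gpt_bigcode checkpoint takes the new PPL int8kv path whenever the launch mode contains a ppl entry, and StarcoderPPlTpPartModel then insists on int8kv in _verify_params. A tiny sketch of the combined condition; the mode value shown is hypothetical, since the exact strings accepted by the lightllm launcher are not part of this diff.

# Hypothetical mode value; only the membership checks below are taken from the diff.
mode = ["ppl", "int8kv"]

if "ppl" not in mode:
    chosen = "StarcoderTpPartModel"      # default fp16 KV-cache implementation
else:
    chosen = "StarcoderPPlTpPartModel"   # int8 KV-cache via PPL kernels
assert "int8kv" in mode, "the PPL StarCoder model only supports int8kv mode"  # mirrors _verify_params
print(chosen)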
