From 5836afcb1104f78facef69ccd7a29612ebcb1313 Mon Sep 17 00:00:00 2001
From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
Date: Thu, 21 Sep 2023 18:12:13 +0800
Subject: [PATCH] starcoder support ppl int8kv (#141)

Co-authored-by: wangzaijun
---
 lightllm/models/starcoder_ppl/__init__.py     |  0
 .../starcoder_ppl/layer_infer/__init__.py     |  0
 .../layer_infer/transformer_layer_infer.py    | 64 +++++++++++++++++++
 lightllm/models/starcoder_ppl/model.py        | 28 ++++++++
 .../server/router/model_infer/model_rpc.py    |  6 +-
 5 files changed, 97 insertions(+), 1 deletion(-)
 create mode 100644 lightllm/models/starcoder_ppl/__init__.py
 create mode 100644 lightllm/models/starcoder_ppl/layer_infer/__init__.py
 create mode 100644 lightllm/models/starcoder_ppl/layer_infer/transformer_layer_infer.py
 create mode 100644 lightllm/models/starcoder_ppl/model.py

diff --git a/lightllm/models/starcoder_ppl/__init__.py b/lightllm/models/starcoder_ppl/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/lightllm/models/starcoder_ppl/layer_infer/__init__.py b/lightllm/models/starcoder_ppl/layer_infer/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/lightllm/models/starcoder_ppl/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder_ppl/layer_infer/transformer_layer_infer.py
new file mode 100644
index 00000000..09f97e1b
--- /dev/null
+++ b/lightllm/models/starcoder_ppl/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,64 @@
+import torch
+import torch.nn.functional as F
+import torch.distributed as dist
+import numpy as np
+
+from lightllm.utils.infer_utils import mark_cost_time
+from lightllm.models.starcoder.layer_infer.infer_struct import StarcoderInferStateInfo
+from lightllm.models.starcoder.layer_infer.transformer_layer_infer import StarcoderTransformerLayerInfer
+
+
+class StarcoderPPlTransformerLayerInfer(StarcoderTransformerLayerInfer):
+    """StarCoder transformer layer inference backed by PPL int8 KV-cache kernels.
+    """
+    def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
+        super().__init__(layer_num, tp_rank, world_size, network_config, mode)
+        return
+
+    def _post_cache_kv(self, cache_k, cache_v, infer_state: StarcoderInferStateInfo, layer_weight):
+        mem_manager = infer_state.mem_manager
+        from lightllm.models.llama_ppl.triton_kernel.quant_copy_kv import destindex_copy_quantize_kv
+        if infer_state.is_prefill:
+            # Prefill: quantize the new K/V to int8 and scatter them, together with
+            # their per-group scales, into the cache at the prefill token positions.
+            destindex_copy_quantize_kv(cache_k,
+                                       infer_state.prefill_mem_index,
+                                       mem_manager.key_buffer[self.layer_num_],
+                                       mem_manager.key_scale_buffer[self.layer_num_])
+            destindex_copy_quantize_kv(cache_v,
+                                       infer_state.prefill_mem_index,
+                                       mem_manager.value_buffer[self.layer_num_],
+                                       mem_manager.value_scale_buffer[self.layer_num_])
+            return
+        else:
+            # Decode: a scatter copy is only needed when the destination slots are
+            # not contiguous; contiguous slots are already written in place.
+            if not infer_state.decode_is_contiguous:
+                destindex_copy_quantize_kv(cache_k,
+                                           infer_state.decode_mem_index,
+                                           mem_manager.key_buffer[self.layer_num_],
+                                           mem_manager.key_scale_buffer[self.layer_num_])
+                destindex_copy_quantize_kv(cache_v,
+                                           infer_state.decode_mem_index,
+                                           mem_manager.value_buffer[self.layer_num_],
+                                           mem_manager.value_scale_buffer[self.layer_num_])
+            return
+
+    def _token_attention_kernel(self, q, infer_state: StarcoderInferStateInfo, layer_weight) -> torch.Tensor:
+        batch_size = infer_state.batch_size
+        calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)
+        o_tensor = torch.empty_like(q)
+
+        # Fused decode attention that reads the int8-quantized KV cache directly.
+        from lightllm_ppl_kernel import group8_int8kv_decode_attention
+        group8_int8kv_decode_attention(o_tensor.view(calcu_shape1),
+                                       q.view(calcu_shape1),
+                                       infer_state.mem_manager.key_buffer[self.layer_num_],
+                                       infer_state.mem_manager.key_scale_buffer[self.layer_num_],
+                                       infer_state.mem_manager.value_buffer[self.layer_num_],
+                                       infer_state.mem_manager.value_scale_buffer[self.layer_num_],
+                                       infer_state.b_loc,
+                                       infer_state.b_seq_len,
+                                       infer_state.max_len_in_batch)
+
+        return o_tensor
diff --git a/lightllm/models/starcoder_ppl/model.py b/lightllm/models/starcoder_ppl/model.py
new file mode 100644
index 00000000..2b5b122b
--- /dev/null
+++ b/lightllm/models/starcoder_ppl/model.py
@@ -0,0 +1,28 @@
+import os
+import json
+import torch
+
+from lightllm.models.starcoder.model import StarcoderTpPartModel
+from lightllm.models.starcoder_ppl.layer_infer.transformer_layer_infer import StarcoderPPlTransformerLayerInfer
+from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager
+
+class StarcoderPPlTpPartModel(StarcoderTpPartModel):
+
+    # infer class
+    transformer_layer_infer_class = StarcoderPPlTransformerLayerInfer
+
+    def __init__(self, tp_rank, world_size, weight_dir, max_total_token_num, load_way="HF", mode=[]):
+        super().__init__(tp_rank, world_size, weight_dir, max_total_token_num, load_way, mode)
+
+    def _verify_params(self):
+        assert self.load_way == "HF", "StarCoder only supports loading weights in HF format for now!"
+        assert "int8kv" in self.mode, "PPL StarCoder only supports the int8kv mode"
+        return
+
+    def _init_mem_manager(self):
+        self.mem_manager = PPLINT8KVMemoryManager(self.max_total_token_num,
+                                                  dtype=torch.float16,
+                                                  head_num=self.config["num_key_value_heads"],
+                                                  head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
+                                                  layer_num=self.config["num_hidden_layers"])
+        return
\ No newline at end of file
diff --git a/lightllm/server/router/model_infer/model_rpc.py b/lightllm/server/router/model_infer/model_rpc.py
index 472c30b8..7ac3c82d 100644
--- a/lightllm/server/router/model_infer/model_rpc.py
+++ b/lightllm/server/router/model_infer/model_rpc.py
@@ -16,6 +16,7 @@
 from lightllm.models.llama2_ppl.model import Llama2PPlTpPartModel
 from lightllm.models.llama2.model import Llama2TpPartModel
 from lightllm.models.starcoder.model import StarcoderTpPartModel
+from lightllm.models.starcoder_ppl.model import StarcoderPPlTpPartModel
 from lightllm.models.qwen.model import QWenTpPartModel
 from lightllm.models.baichuan7b.model import Baichuan7bTpPartModel
 from lightllm.models.baichuan13b.model import Baichuan13bTpPartModel
@@ -76,7 +77,10 @@ def exposed_init_model(self, rank_id, world_size, weight_dir, max_total_token_nu
             else:
                 raise Exception('can not support baichuan format')
         elif self.model_type == 'gpt_bigcode':
-            self.model = StarcoderTpPartModel(rank_id, world_size, weight_dir, max_total_token_num, load_way, mode)
+            if "ppl" not in mode:
+                self.model = StarcoderTpPartModel(rank_id, world_size, weight_dir, max_total_token_num, load_way, mode)
+            else:
+                self.model = StarcoderPPlTpPartModel(rank_id, world_size, weight_dir, max_total_token_num, load_way, mode)
         elif self.model_type == 'chatglm':
             self.model = ChatGlm2TpPartModel(rank_id, world_size, weight_dir, max_total_token_num, load_way, mode)
         elif self.model_type == 'internlm':
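
---

Reviewer note: a minimal sketch of how the new routing behaves once this patch is
applied. The mode list ["ppl", "int8kv"] is inferred from the membership checks
added here ('"ppl" not in mode' in model_rpc.py and '"int8kv" in self.mode' in
_verify_params); the weight path and token budget below are placeholder values,
and how the serving CLI populates `mode` is outside the scope of this patch.

    from lightllm.models.starcoder.model import StarcoderTpPartModel
    from lightllm.models.starcoder_ppl.model import StarcoderPPlTpPartModel

    def build_starcoder(rank_id, world_size, weight_dir, max_total_token_num, load_way, mode):
        # Mirrors the new branch in exposed_init_model(): a "ppl" entry in the
        # mode list selects the int8 KV-cache implementation.
        if "ppl" not in mode:
            return StarcoderTpPartModel(rank_id, world_size, weight_dir,
                                        max_total_token_num, load_way, mode)
        return StarcoderPPlTpPartModel(rank_id, world_size, weight_dir,
                                       max_total_token_num, load_way, mode)

    # _verify_params() additionally asserts "int8kv" in mode, so both flags are
    # required; "/path/to/starcoder" and 120000 are placeholders.
    model = build_starcoder(0, 1, "/path/to/starcoder", 120000, "HF", ["ppl", "int8kv"])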