Commit
Co-authored-by: wangzaijun <[email protected]>
1 parent 814af33 · commit 5836afc
Showing 5 changed files with 94 additions and 1 deletion.
Empty file.
Empty file.
61 changes: 61 additions & 0 deletions
lightllm/models/starcoder_ppl/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,61 @@
import torch
import torch.nn.functional as F
import torch.distributed as dist
import numpy as np

from lightllm.utils.infer_utils import mark_cost_time
from lightllm.models.starcoder.layer_infer.infer_struct import StarcoderInferStateInfo
from lightllm.models.starcoder.layer_infer.transformer_layer_infer import StarcoderTransformerLayerInfer


class StarcoderPPlTransformerLayerInfer(StarcoderTransformerLayerInfer):
    """Starcoder transformer layer inference using the PPL int8 KV-cache kernels."""

    def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
        super().__init__(layer_num, tp_rank, world_size, network_config, mode)
        return

    def _post_cache_kv(self, cache_k, cache_v, infer_state: StarcoderInferStateInfo, layer_weight):
        mem_manager = infer_state.mem_manager
        from lightllm.models.llama_ppl.triton_kernel.quant_copy_kv import destindex_copy_quantize_kv
        if infer_state.is_prefill:
            # Quantize the new K/V to int8 and scatter them into this layer's cache buffers.
            destindex_copy_quantize_kv(cache_k,
                                       infer_state.prefill_mem_index,
                                       mem_manager.key_buffer[self.layer_num_],
                                       mem_manager.key_scale_buffer[self.layer_num_])
            destindex_copy_quantize_kv(cache_v,
                                       infer_state.prefill_mem_index,
                                       mem_manager.value_buffer[self.layer_num_],
                                       mem_manager.value_scale_buffer[self.layer_num_])
            return
        else:
            # During decode, scatter-quantize only when the destination slots are not contiguous.
            if not infer_state.decode_is_contiguous:
                destindex_copy_quantize_kv(cache_k,
                                           infer_state.decode_mem_index,
                                           mem_manager.key_buffer[self.layer_num_],
                                           mem_manager.key_scale_buffer[self.layer_num_])
                destindex_copy_quantize_kv(cache_v,
                                           infer_state.decode_mem_index,
                                           mem_manager.value_buffer[self.layer_num_],
                                           mem_manager.value_scale_buffer[self.layer_num_])
                return
            return

    def _token_attention_kernel(self, q, infer_state: StarcoderInferStateInfo, layer_weight) -> torch.Tensor:
        batch_size = infer_state.batch_size
        calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)
        o_tensor = torch.empty_like(q)

        # Decode attention that reads the int8 KV cache and its per-group scales directly.
        from lightllm_ppl_kernel import group8_int8kv_decode_attention
        group8_int8kv_decode_attention(o_tensor.view(calcu_shape1),
                                       q.view(calcu_shape1),
                                       infer_state.mem_manager.key_buffer[self.layer_num_],
                                       infer_state.mem_manager.key_scale_buffer[self.layer_num_],
                                       infer_state.mem_manager.value_buffer[self.layer_num_],
                                       infer_state.mem_manager.value_scale_buffer[self.layer_num_],
                                       infer_state.b_loc,
                                       infer_state.b_seq_len,
                                       infer_state.max_len_in_batch)

        return o_tensor
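
For context on the quantized copy above: below is a minimal pure-PyTorch sketch of what the destindex_copy_quantize_kv Triton kernel is assumed to do, namely quantizing each group of channels of the new K/V rows to int8 with a per-group scale and scattering them into the cache at the given destination indices. The group size of 8 and the symmetric max/127 scaling are assumptions inferred from the "group8_int8kv" kernel naming, not details confirmed by this commit.

import torch

def destindex_copy_quantize_kv_ref(kv, dest_index, out_buffer, scale_buffer, group_size=8):
    # Hypothetical reference implementation (not the commit's kernel).
    # kv:           [num_tokens, head_num, head_dim] new keys or values
    # dest_index:   [num_tokens] destination slot indices into the cache
    # out_buffer:   [cache_size, head_num, head_dim] int8 cache buffer
    # scale_buffer: [cache_size, head_num, head_dim // group_size] per-group scales
    num_tokens, head_num, head_dim = kv.shape
    grouped = kv.view(num_tokens, head_num, head_dim // group_size, group_size).float()
    # Symmetric per-group quantization: scale maps the group's max magnitude to 127.
    scale = grouped.abs().amax(dim=-1, keepdim=True) / 127.0
    scale = torch.clamp(scale, min=1e-8)
    quantized = torch.clamp(torch.round(grouped / scale), -127, 127).to(torch.int8)
    out_buffer[dest_index] = quantized.view(num_tokens, head_num, head_dim)
    scale_buffer[dest_index] = scale.squeeze(-1).to(scale_buffer.dtype)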
@@ -0,0 +1,28 @@
import os
import json
import torch

from lightllm.models.starcoder.model import StarcoderTpPartModel
from lightllm.models.starcoder_ppl.layer_infer.transformer_layer_infer import StarcoderPPlTransformerLayerInfer
from lightllm.common.ppl_int8kv_mem_manager import PPLINT8KVMemoryManager


class StarcoderPPlTpPartModel(StarcoderTpPartModel):

    # infer class
    transformer_layer_infer_class = StarcoderPPlTransformerLayerInfer

    def __init__(self, tp_rank, world_size, weight_dir, max_total_token_num, load_way="HF", mode=[]):
        super().__init__(tp_rank, world_size, weight_dir, max_total_token_num, load_way, mode)

    def _verify_params(self):
        assert self.load_way == "HF", "StarCoder only supports loading weights in HF format for now!"
        assert "int8kv" in self.mode, "ppl Starcoder only supports int8kv mode"
        return

    def _init_mem_manager(self):
        # Allocate the int8 KV cache (values plus per-group scales) for the full token budget.
        self.mem_manager = PPLINT8KVMemoryManager(self.max_total_token_num,
                                                  dtype=torch.float16,
                                                  head_num=self.config["num_key_value_heads"],
                                                  head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
                                                  layer_num=self.config["num_hidden_layers"])
        return
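
A hedged usage sketch of the new model class: the constructor arguments below follow the __init__ signature added in this commit, the "int8kv" mode flag is required by _verify_params, and load_way must be "HF". The module path in the import and the weight directory are assumptions/placeholders, not taken from this diff.

# NOTE: the module path below is assumed from lightllm's usual per-model layout; adjust if it differs.
from lightllm.models.starcoder_ppl.model import StarcoderPPlTpPartModel

model = StarcoderPPlTpPartModel(tp_rank=0,
                                world_size=1,
                                weight_dir="/path/to/starcoder_hf_weights",  # placeholder path
                                max_total_token_num=120000,
                                load_way="HF",
                                mode=["int8kv"])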