From 0163eb5ab325dbd5a5d1d7546c120e96bcb88823 Mon Sep 17 00:00:00 2001 From: Yunqian Fan Date: Wed, 12 Jun 2024 10:05:01 +0800 Subject: [PATCH 1/3] feat: re-impl the cohere with template (#429) --- ...transformer_layer_infer_cohere_template.py | 173 ++++++++++++++++++ lightllm/models/cohere/__init__.py | 0 lightllm/models/cohere/infer_struct.py | 8 + .../models/cohere/layer_infer/__init__.py | 0 .../cohere/layer_infer/post_layer_infer.py | 138 ++++++++++++++ .../layer_infer/transformer_layer_infer.py | 77 ++++++++ .../models/cohere/layer_weights/__init__.py | 0 .../pre_and_post_layer_weight.py | 36 ++++ .../layer_weights/transformer_layer_weight.py | 107 +++++++++++ lightllm/models/cohere/model.py | 24 +++ .../models/cohere/splitfuse_infer_struct.py | 11 ++ .../models/cohere/triton_kernels/__init__.py | 0 .../models/cohere/triton_kernels/layernorm.py | 15 ++ .../model_infer/mode_backend/base_backend.py | 5 +- 14 files changed, 593 insertions(+), 1 deletion(-) create mode 100755 lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py create mode 100644 lightllm/models/cohere/__init__.py create mode 100644 lightllm/models/cohere/infer_struct.py create mode 100644 lightllm/models/cohere/layer_infer/__init__.py create mode 100644 lightllm/models/cohere/layer_infer/post_layer_infer.py create mode 100644 lightllm/models/cohere/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/cohere/layer_weights/__init__.py create mode 100644 lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py create mode 100644 lightllm/models/cohere/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/cohere/model.py create mode 100644 lightllm/models/cohere/splitfuse_infer_struct.py create mode 100644 lightllm/models/cohere/triton_kernels/__init__.py create mode 100644 lightllm/models/cohere/triton_kernels/layernorm.py diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py new file mode 100755 index 00000000..54deffa2 --- /dev/null +++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py @@ -0,0 +1,173 @@ +from functools import partial +from typing import Tuple + +import torch +import torch.distributed as dist + +from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_template import TransformerLayerInferTpl +from lightllm.utils.infer_utils import mark_cost_time + +from ...infer_struct import InferStateInfo +from ...splitfuse_infer_struct import SplitFuseInferStateInfo +from ..transformer_layer_infer import TransformerLayerInfer + + +class TransformerLayerCohereInferTpl(TransformerLayerInferTpl): + """ """ + + def __init__(self, layer_num, tp_rank, world_size, network_config, mode): + super().__init__(layer_num, tp_rank, world_size, network_config, mode) + + self.use_qk_norm_ = self.network_config_.get("use_qk_norm", False) + return + + def _att_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: + raise Exception("need to impl") + + def _q_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: + raise Exception("need to impl") + + def _k_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: + raise Exception("need to impl") + + def _bind_norm(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor: + self._att_norm = 
partial(TransformerLayerCohereInferTpl._att_norm, self)
+        self._q_norm = partial(TransformerLayerCohereInferTpl._q_norm, self)
+        self._k_norm = partial(TransformerLayerCohereInferTpl._k_norm, self)
+
+    def _rotary_emb_fwd(self, q, kv, position_cos, position_sin):
+        raise Exception("need to impl")
+
+    def _bind_rotary_emb_fwd(self):
+        raise Exception("need to impl")
+
+    def _get_qkv(
+        self, input, cache_kv, infer_state: InferStateInfo, layer_weight
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+        q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_)
+        torch.mm(
+            input.view(-1, self.embed_dim_),
+            layer_weight.kv_weight_,
+            out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
+        )
+        if self.use_qk_norm_:
+            q = q.view(-1, self.tp_q_head_num_, self.head_dim_)
+            k = cache_kv[:, 0 : self.tp_k_head_num_, :]
+            q = self._q_norm(q, infer_state, layer_weight)
+            cache_kv[:, 0 : self.tp_k_head_num_, :] = self._k_norm(k, infer_state, layer_weight)
+        self._rotary_emb_fwd(q, cache_kv, infer_state.position_cos, infer_state.position_sin)
+        return q, cache_kv
+
+    def _context_attention_kernel(self, q, kv, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor:
+        raise Exception("need to impl")
+
+    def _token_attention_kernel(self, q, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor:
+        raise Exception("need to impl")
+
+    def _splitfuse_attention_kernel(
+        self, q, infer_state: SplitFuseInferStateInfo, layer_weight, out=None
+    ) -> torch.Tensor:
+        raise Exception("need to impl")
+
+    def _get_o(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
+        raise Exception("need to impl")
+
+    def _ffn(self, input, infer_state: InferStateInfo, layer_weight) -> torch.Tensor:
+        raise Exception("need to impl")
+
+    @mark_cost_time(
+        "trans context flash forward time cost"
+    )  # don't remove this, it will hurt performance; reason unknown
+    def _context_attention(self, input_embding, infer_state: InferStateInfo, layer_weight):
+        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
+        q, cache_kv = self._get_qkv(input_embding, cache_kv, infer_state, layer_weight)
+        self._post_cache_kv(cache_kv, infer_state, layer_weight)
+        o = self._context_attention_kernel(q, cache_kv, infer_state, layer_weight)
+        q = None
+        o = self._get_o(o, infer_state, layer_weight)
+        if self.world_size_ > 1:
+            dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
+        infer_state._attn_out = o
+        return
+
+    @mark_cost_time(
+        "trans context ffn forward time cost"
+    )  # don't remove this, it will hurt performance; reason unknown
+    def _context_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight):
+        ffn_out = self._ffn(input_embdings, infer_state, layer_weight)
+        if self.world_size_ > 1:
+            dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
+        infer_state._ffn_out = ffn_out
+        return
+
+    # this impl does not use @mark_cost_time
+    def _token_attention(self, input_embding, infer_state: InferStateInfo, layer_weight):
+        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
+        q, cache_kv = self._get_qkv(input_embding, cache_kv, infer_state, layer_weight)
+        self._post_cache_kv(cache_kv, infer_state, layer_weight)
+        o = self._token_attention_kernel(q, infer_state, layer_weight)
+        q = None
+        o = self._get_o(o, infer_state, layer_weight)
+        if self.world_size_ > 1:
+            dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
+        infer_state._attn_out = o
+        return
+
+    # this impl does not use @mark_cost_time
+    def 
_token_ffn(self, input_embdings, infer_state: InferStateInfo, layer_weight): + ffn_out = self._ffn(input_embdings, infer_state, layer_weight) + if self.world_size_ > 1: + dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False) + infer_state._ffn_out = ffn_out + return + + def _splitfuse_attention(self, input_embding, infer_state: SplitFuseInferStateInfo, layer_weight): + cache_kv = self._pre_cache_kv(infer_state, layer_weight) + q, cache_kv = self._get_qkv(input_embding, cache_kv, infer_state, layer_weight) + self._post_cache_kv(cache_kv, infer_state, layer_weight) + o = self._splitfuse_attention_kernel(q, infer_state, layer_weight) + q = None + o = self._get_o(o, infer_state, layer_weight) + if self.world_size_ > 1: + dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False) + infer_state._attn_out = o + return + + def _splitfuse_ffn(self, input_embdings, infer_state: SplitFuseInferStateInfo, layer_weight): + ffn_out = self._ffn(input_embdings, infer_state, layer_weight) + if self.world_size_ > 1: + dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False) + infer_state._ffn_out = ffn_out + return + + def _cohere_residual(self, input_embdings, infer_state: InferStateInfo): + # emb_addr = input_embdings.data_ptr() + # attn_out_addr = infer_state._attn_out.data_ptr() + # ffn_addr = infer_state._ffn_out.data_ptr() + # assert emb_addr != attn_out_addr + # assert emb_addr != ffn_addr + # assert attn_out_addr != ffn_addr + input_embdings.add_( + infer_state._attn_out.view(-1, self.embed_dim_) + infer_state._ffn_out.view(-1, self.embed_dim_) + ) + + def context_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + self._context_attention(input1, infer_state, layer_weight=layer_weight) + self._context_ffn(input1, infer_state, layer_weight) + self._cohere_residual(input_embdings, infer_state) + return input_embdings + + def token_forward(self, input_embdings, infer_state: InferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + self._token_attention(input1, infer_state, layer_weight=layer_weight) + self._token_ffn(input1, infer_state, layer_weight) + self._cohere_residual(input_embdings, infer_state) + return input_embdings + + def splitfuse_forward(self, input_embdings, infer_state: SplitFuseInferStateInfo, layer_weight): + input1 = self._att_norm(input_embdings, infer_state, layer_weight) + self._splitfuse_attention(input1, infer_state, layer_weight=layer_weight) + self._splitfuse_ffn(input1, infer_state, layer_weight) + self._cohere_residual(input_embdings, infer_state) + return input_embdings diff --git a/lightllm/models/cohere/__init__.py b/lightllm/models/cohere/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lightllm/models/cohere/infer_struct.py b/lightllm/models/cohere/infer_struct.py new file mode 100644 index 00000000..d9571af9 --- /dev/null +++ b/lightllm/models/cohere/infer_struct.py @@ -0,0 +1,8 @@ +from lightllm.models.llama.infer_struct import LlamaInferStateInfo + + +class CohereInferStateInfo(LlamaInferStateInfo): + def __init__(self): + super().__init__() + self._attn_out = None + self._ffn_out = None diff --git a/lightllm/models/cohere/layer_infer/__init__.py b/lightllm/models/cohere/layer_infer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lightllm/models/cohere/layer_infer/post_layer_infer.py b/lightllm/models/cohere/layer_infer/post_layer_infer.py new file mode 100644 index 
00000000..2c9a9cf2 --- /dev/null +++ b/lightllm/models/cohere/layer_infer/post_layer_infer.py @@ -0,0 +1,138 @@ +import torch +import torch.distributed as dist +import numpy as np + +from lightllm.models.cohere.infer_struct import CohereInferStateInfo +from lightllm.models.cohere.layer_weights.pre_and_post_layer_weight import CoherePreAndPostLayerWeight +from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, multi_head_layernorm_forward +from lightllm.common.basemodel.layer_weights.base_layer_weight import BaseLayerWeight +from lightllm.common.basemodel.splitfuse_infer_struct import SplitFuseInferStateInfo + +from einops import rearrange +from lightllm.common.basemodel import PostLayerInferTpl + + +class CoherePostLayerInfer(PostLayerInferTpl): + def __init__(self, tp_rank, world_size, network_config, mode): + super().__init__(tp_rank, world_size, network_config, mode) + self.eps_ = network_config["layer_norm_eps"] + self.vocab_size_ = network_config["vocab_size"] + self.embed_dim_ = network_config["n_embed"] + self.logits_scale = network_config["logit_scale"] + return + + def _norm(self, input, infer_state, layer_weight: CoherePreAndPostLayerWeight) -> torch.Tensor: + return layernorm_forward(input, layer_weight.final_norm_weight_, eps=self.eps_) + + def _slice_get_last_input(self, input_embdings, infer_state: CohereInferStateInfo): + if infer_state.is_splitfuse: + # for SplitFuse + batch_size = infer_state.batch_size + last_input = torch.empty( + (batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype + ) + tmp_ = torch.cat( + [ + torch.ones(infer_state.decode_req_num, dtype=torch.int32, device="cuda"), + infer_state.prefill_b_seq_len - infer_state.prefill_b_split_ready_cache_len, + ], + dim=0, + ) + last_index = torch.cumsum(tmp_, dim=0, dtype=torch.long) - 1 + last_input[:, :] = input_embdings[last_index, :] + return last_input, batch_size + + if infer_state.is_prefill and infer_state.is_token_healing: + batch_size = infer_state.batch_size + b_seq_len_numpy = (infer_state.b_seq_len - infer_state.b_ready_cache_len).detach().cpu().numpy() + select_index = [] + start_index = 0 + select_token_num = 0 + for cur_len in b_seq_len_numpy: + if cur_len == 1: + select_index.append(start_index + cur_len - 1) + start_index += cur_len + select_token_num += 1 + else: + select_index.append(start_index + cur_len - 2) + select_index.append(start_index + cur_len - 1) + start_index += cur_len + select_token_num += 2 + + last_index = torch.tensor(select_index, dtype=torch.long, device=input_embdings.device) + last_input = torch.empty( + (select_token_num, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype + ) + + last_input[:, :] = input_embdings[last_index, :] + return last_input, select_token_num + + if not infer_state.is_splitfuse and infer_state.is_prefill and not infer_state.return_all_prompt_logics: + batch_size = infer_state.batch_size + last_input = torch.empty( + (batch_size, self.embed_dim_), device=input_embdings.device, dtype=input_embdings.dtype + ) + last_index = ( + torch.cumsum(infer_state.b_seq_len - infer_state.b_ready_cache_len, dim=0, dtype=torch.long) - 1 + ) + last_input[:, :] = input_embdings[last_index, :] + return last_input, batch_size + + if not infer_state.is_splitfuse and infer_state.is_prefill and infer_state.return_all_prompt_logics: + total_tokens = infer_state.total_token_num + return input_embdings, total_tokens + + if not infer_state.is_splitfuse and not infer_state.is_prefill: + batch_size = 
infer_state.batch_size + return input_embdings[-batch_size:, :], batch_size + + assert False, "Error State" + + def soft_max(self, data): + return torch.softmax(data.permute(1, 0).float(), dim=-1) + + def token_forward( + self, + input_embdings, + infer_state: CohereInferStateInfo, + layer_weight: CoherePreAndPostLayerWeight, + return_logics=False, + ): + last_input, token_num = self._slice_get_last_input(input_embdings, infer_state) + input_embdings_dtype = input_embdings.dtype + input_embdings = None + last_input = self._norm(last_input, infer_state, layer_weight) + last_input = rearrange(last_input, "batch embed_dim -> embed_dim batch").contiguous().reshape(-1, token_num) + logic_batch = torch.mm(layer_weight.lm_head_weight_, last_input) + + last_input = None + if self.world_size_ == 1: + gather_data = logic_batch + else: + gather_data = torch.empty( + (self.vocab_size_, token_num), device=logic_batch.device, dtype=input_embdings_dtype + ) + split_indexes = np.linspace(0, self.vocab_size_, self.world_size_ + 1, dtype=np.int64) + dist.all_gather( + [gather_data[split_indexes[i] : split_indexes[i + 1], :] for i in range(self.world_size_)], + logic_batch, + group=None, + async_op=False, + ) + gather_data = gather_data * self.logits_scale + logic_batch = None + + if not return_logics: + prob_out = self.soft_max(gather_data) + gather_data = None + return prob_out + else: + ans_logics = gather_data.permute(1, 0).float() + gather_data = None + return ans_logics + + # @mark_cost_time("splitfuse post forward") + def splitfuse_forward( + self, input_embdings, infer_state: SplitFuseInferStateInfo, layer_weight: BaseLayerWeight, return_logics=False + ): + return self.token_forward(input_embdings, infer_state, layer_weight, return_logics=return_logics) diff --git a/lightllm/models/cohere/layer_infer/transformer_layer_infer.py b/lightllm/models/cohere/layer_infer/transformer_layer_infer.py new file mode 100644 index 00000000..22738bcc --- /dev/null +++ b/lightllm/models/cohere/layer_infer/transformer_layer_infer.py @@ -0,0 +1,77 @@ +import torch +from functools import partial + +from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_cohere_template import ( + TransformerLayerCohereInferTpl, +) +from lightllm.models.cohere.infer_struct import CohereInferStateInfo +from lightllm.models.cohere.layer_weights.transformer_layer_weight import CohereTransformerLayerWeight +from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, multi_head_layernorm_forward +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd + + +class CohereTransformerLayerInfer(TransformerLayerCohereInferTpl): + def __init__(self, layer_num, tp_rank, world_size, network_config, mode): + super().__init__(layer_num, tp_rank, world_size, network_config, mode) + self.tp_q_head_num_ = network_config["num_attention_heads"] // self.world_size_ + self.tp_k_head_num_ = network_config["num_key_value_heads"] // self.world_size_ + self.tp_v_head_num_ = network_config["num_key_value_heads"] // self.world_size_ + self.tp_o_head_num_ = self.tp_q_head_num_ + self.head_dim_ = network_config["hidden_size"] // network_config["num_attention_heads"] + self.embed_dim_ = network_config["hidden_size"] + self.eps_ = self.network_config_["layer_norm_eps"] + self.use_qk_norm_ = network_config.get("use_qk_norm", False) + 
self._bind_func() + + def _bind_func(self): + self._bind_rotary_emb_fwd() + self._bind_norm() + self._bind_attn() + + def _rotary_emb_fwd(self, q, kv, position_cos, position_sin): + return rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + kv, + position_cos, + position_sin, + ) + + def _bind_rotary_emb_fwd(self): + self._rotary_emb_fwd = partial(CohereTransformerLayerInfer._rotary_emb_fwd, self) + + def _att_norm(self, input, infer_state, layer_weight): + return layernorm_forward(input, layer_weight.att_norm_weight_, self.eps_) + + def _q_norm(self, input, infer_state, layer_weight): + return multi_head_layernorm_forward(input, layer_weight.q_norm_weight_, self.eps_) + + def _k_norm(self, input, infer_state, layer_weight): + return multi_head_layernorm_forward(input, layer_weight.k_norm_weight_, self.eps_) + + def _bind_norm(self): + self._att_norm = partial(CohereTransformerLayerInfer._att_norm, self) + self._q_norm = partial(CohereTransformerLayerInfer._q_norm, self) + self._k_norm = partial(CohereTransformerLayerInfer._k_norm, self) + + def _bind_attn(self): + # no need to re-impl + LlamaTransformerLayerInfer._bind_attention(self) + + def _get_o( + self, input, infer_state: CohereInferStateInfo, layer_weight: CohereTransformerLayerWeight + ) -> torch.Tensor: + o_tensor = torch.mm(input.view(-1, self.tp_o_head_num_ * self.head_dim_), layer_weight.o_weight_) + return o_tensor + + def _ffn( + self, input, infer_state: CohereInferStateInfo, layer_weight: CohereTransformerLayerWeight + ) -> torch.Tensor: + up_gate_out = torch.mm(input.view(-1, self.embed_dim_), layer_weight.gate_up_proj) + ffn1_out = silu_and_mul_fwd(up_gate_out) + input = None + up_gate_out = None + ffn2_out = torch.mm(ffn1_out, layer_weight.down_proj) + ffn1_out = None + return ffn2_out diff --git a/lightllm/models/cohere/layer_weights/__init__.py b/lightllm/models/cohere/layer_weights/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py new file mode 100644 index 00000000..1adf3394 --- /dev/null +++ b/lightllm/models/cohere/layer_weights/pre_and_post_layer_weight.py @@ -0,0 +1,36 @@ +import torch +import numpy as np + +from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight + + +class CoherePreAndPostLayerWeight(LlamaPreAndPostLayerWeight): + def load_hf_weights(self, weights): + vob_size = self.network_config_["vocab_size"] + tie_weight = self.network_config_.get("tie_word_embeddings", True) + split_indexes = np.linspace(0, vob_size, self.world_size_ + 1, dtype=np.int64) + split_start = split_indexes[self.tp_rank_] + split_end = split_indexes[self.tp_rank_ + 1] + if "model.embed_tokens.weight" in weights: + # print(weights['model.embed_tokens.weight'].shape) + self.wte_weight_ = self._cuda(weights["model.embed_tokens.weight"][split_start:split_end, :]) + if tie_weight: + self.lm_head_weight_ = self.wte_weight_ + if "model.norm.weight" in weights: + self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) + if "model.lm_head.weight" in weights: + self.lm_head_weight_ = self._cuda(weights["model.lm_head.weight"]) + return + + def verify_load(self): + super().verify_load() + + errors = "tie weights load not ok" + tie_weight = self.network_config_.get("tie_word_embeddings", True) + if tie_weight: + assert self.lm_head_weight_ is not None, errors + assert self.wte_weight_ is self.lm_head_weight_, errors 
+ else: + assert self.lm_head_weight_ is not None, errors + assert self.wte_weight_ is not None, errors + assert self.wte_weight_ is not self.lm_head_weight_, errors diff --git a/lightllm/models/cohere/layer_weights/transformer_layer_weight.py b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py new file mode 100644 index 00000000..7920de6c --- /dev/null +++ b/lightllm/models/cohere/layer_weights/transformer_layer_weight.py @@ -0,0 +1,107 @@ +from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight + + +class CohereTransformerLayerWeight(LlamaTransformerLayerWeight): + def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]): + super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode) + self.use_qk_norm = network_config.get("use_qk_norm", False) + return + + def load_hf_weights(self, weights): + self._load_qkvo_weights(weights) + self._load_ffn_weights(weights) + return + + def verify_load(self): + errors = "weights load not ok" + weights = [ + self.att_norm_weight_, + self.q_weight_, + self.kv_weight_, + self.o_weight_, + self.gate_up_proj, + self.down_proj, + ] + for i in range(len(weights)): + assert weights[i] is not None, "index:" + str(i) + " " + errors + if self.use_qk_norm: + qk_weights = [self.q_norm_weight_, self.k_norm_weight_] + for i in range(len(qk_weights)): + assert qk_weights[i] is not None, "index:" + str(i + len(weights)) + " " + errors + return + + def _load_qkvo_weights(self, weights): + # input layernorm params + if f"model.layers.{self.layer_num_}.input_layernorm.weight" in weights: + self.att_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.input_layernorm.weight"]) + + n_embed = self.network_config_["hidden_size"] + q_split_n_embed = n_embed // self.world_size_ + kv_split_n_embed = ( + n_embed + // self.network_config_["num_attention_heads"] + * self.network_config_["num_key_value_heads"] + // self.world_size_ + ) + q_split_head = self.network_config_["num_attention_heads"] // self.world_size_ + kv_split_head = self.network_config_["num_key_value_heads"] // self.world_size_ + + # q k v weights for llama + if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights: + self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"] + self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :] + self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1)) + + if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights: + k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] + k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] + self.k_weight_ = k_weight_.transpose(0, 1) + + if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights: + v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] + v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] + self.v_weight_ = v_weight_.transpose(0, 1) + + if f"model.layers.{self.layer_num_}.self_attn.q_norm.weight" in weights: + q_norm_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_norm.weight"] + q_norm_weight_ = q_norm_weight_[q_split_head * self.tp_rank_ : q_split_head * (self.tp_rank_ + 1)] + self.q_norm_weight_ = self._cuda(q_norm_weight_) + if f"model.layers.{self.layer_num_}.self_attn.k_norm.weight" in weights: + k_norm_weight_ = 
weights[f"model.layers.{self.layer_num_}.self_attn.k_norm.weight"] + k_norm_weight_ = k_norm_weight_[kv_split_head * self.tp_rank_ : kv_split_head * (self.tp_rank_ + 1)] + self.k_norm_weight_ = self._cuda(k_norm_weight_) + + # attention output dense params + if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: + self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] + self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] + self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) + + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + + return + + def _load_ffn_weights(self, weights): + inter_size = self.network_config_["intermediate_size"] + split_inter_size = inter_size // self.world_size_ + + if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights: + up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.up_proj = up_proj.transpose(0, 1) + + if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights: + gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.gate_proj = gate_proj.transpose(0, 1) + + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1) + + if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights: + self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][ + :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) + ] + self.down_proj = self._cuda(self.down_proj.transpose(0, 1)) + return diff --git a/lightllm/models/cohere/model.py b/lightllm/models/cohere/model.py new file mode 100644 index 00000000..7bb0ae23 --- /dev/null +++ b/lightllm/models/cohere/model.py @@ -0,0 +1,24 @@ +from lightllm.common.basemodel.basemodel import TpPartBaseModel +from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_cohere_template import ( + TransformerLayerCohereInferTpl, +) +from lightllm.models.cohere.infer_struct import CohereInferStateInfo +from lightllm.models.cohere.layer_infer.post_layer_infer import CoherePostLayerInfer +from lightllm.models.cohere.layer_infer.transformer_layer_infer import CohereTransformerLayerInfer +from lightllm.models.cohere.layer_weights.pre_and_post_layer_weight import CoherePreAndPostLayerWeight +from lightllm.models.cohere.layer_weights.transformer_layer_weight import CohereTransformerLayerWeight +from lightllm.models.cohere.splitfuse_infer_struct import CohereSplitFuseInferStateInfo +from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer +from lightllm.models.llama.model import LlamaTpPartModel + + +class CohereTpPartModel(LlamaTpPartModel): + pre_and_post_weight_class = CoherePreAndPostLayerWeight + transformer_weight_class = CohereTransformerLayerWeight + + pre_layer_infer_class = LlamaPreLayerInfer + transformer_layer_infer_class = CohereTransformerLayerInfer + post_layer_infer_class = CoherePostLayerInfer + + infer_state_class = CohereInferStateInfo + splitfuse_infer_state_class = CohereSplitFuseInferStateInfo diff --git a/lightllm/models/cohere/splitfuse_infer_struct.py b/lightllm/models/cohere/splitfuse_infer_struct.py new file mode 100644 index 00000000..d642b5ed --- /dev/null +++ b/lightllm/models/cohere/splitfuse_infer_struct.py @@ -0,0 +1,11 @@ +from 
lightllm.models.cohere.infer_struct import CohereInferStateInfo +from lightllm.models.llama.splitfuse_infer_struct import LlamaSplitFuseInferStateInfo + + +class CohereSplitFuseInferStateInfo(LlamaSplitFuseInferStateInfo): + inner_decode_infer_state_class = CohereInferStateInfo + + def __init__(self): + super().__init__() + self._attn_out = None + self._ffn_out = None diff --git a/lightllm/models/cohere/triton_kernels/__init__.py b/lightllm/models/cohere/triton_kernels/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lightllm/models/cohere/triton_kernels/layernorm.py b/lightllm/models/cohere/triton_kernels/layernorm.py new file mode 100644 index 00000000..e6008432 --- /dev/null +++ b/lightllm/models/cohere/triton_kernels/layernorm.py @@ -0,0 +1,15 @@ +import torch + + +def layernorm_forward(x, weight, eps): + return torch.layer_norm(x, (x.shape[-1],), weight, bias=None, eps=eps) + + +def multi_head_layernorm_forward(x, weight, eps): + inp_dtype = x.dtype + x = x.to(torch.float32) + mean = x.mean(-1, keepdim=True) + variance = (x - mean).pow(2).mean(-1, keepdim=True) + x = (x - mean) * torch.rsqrt(variance + eps) + x = weight.to(torch.float32) * x + return x.to(inp_dtype) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index fcacb8b8..29895d96 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -6,6 +6,7 @@ from datetime import timedelta from typing import Dict, List, Tuple from transformers.configuration_utils import PretrainedConfig +from lightllm.models.cohere.model import CohereTpPartModel from lightllm.models.mixtral.model import MixtralTpPartModel from lightllm.models.qwen2.model import Qwen2TpPartModel from rpyc.utils.classic import obtain @@ -171,6 +172,8 @@ def init_model(self, kvargs): self.model = Qwen2TpPartModel(model_kvargs) elif self.model_type == "gemma": self.model = Gemma_2bTpPartModel(model_kvargs) + elif self.model_type == "cohere": + self.model = CohereTpPartModel(model_kvargs) else: raise Exception(f"can not support {self.model_type} now") except Exception as e: @@ -192,7 +195,7 @@ def init_model(self, kvargs): return def init_custom(self): - pass + pass # @calculate_time(show=False, min_cost_ms=300) def prefill_batch(self, batch_id): From 0bf7ec9f24f80467ff41348d421ea4b5f32a45cc Mon Sep 17 00:00:00 2001 From: and_gate <38602277+senbeiasano@users.noreply.github.com> Date: Wed, 12 Jun 2024 10:42:18 +0800 Subject: [PATCH 2/3] use triton to init window info (#420) --- lightllm/models/mistral/infer_struct.py | 29 ++--- .../layer_infer/transformer_layer_infer.py | 3 +- .../init_att_sliding_window_info.py | 45 +++++++ .../token_attention_nopad_att1.py | 94 ++++++++++----- .../token_attention_nopad_reduceV.py | 8 +- .../token_attention_softmax_and_reducev.py | 114 ++++++++++++------ lightllm/models/mixtral/infer_struct.py | 29 ++--- lightllm/models/qwen2/infer_struct.py | 19 +-- .../layer_infer/transformer_layer_infer.py | 3 +- .../layer_infer/transformer_layer_infer.py | 3 +- 10 files changed, 221 insertions(+), 126 deletions(-) create mode 100644 lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py diff --git a/lightllm/models/mistral/infer_struct.py b/lightllm/models/mistral/infer_struct.py index e64df83c..b53a92d6 100644 --- a/lightllm/models/mistral/infer_struct.py +++ b/lightllm/models/mistral/infer_struct.py @@ -2,23 +2,25 @@ import numpy as 
np from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.common.req_manager import ReqManager +from lightllm.models.mistral.triton_kernel.init_att_sliding_window_info import init_att_window_info_fwd + class MistralInferStateInfo(LlamaInferStateInfo): def __init__(self): super().__init__() self.sliding_window = None - self.b_start_loc_window = None self.b_att_seq_len = None self.b_att_start_loc = None self.total_cache_num = None # self.window_postion = None - def init_some_extra_state(self, model, input_ids : torch.Tensor): + def init_some_extra_state(self, model, input_ids: torch.Tensor): self.sliding_window = model.config["sliding_window"] if self.is_prefill: b_seq_len_numpy = self.b_seq_len.cpu().numpy() - position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i]) - for i in range(len(b_seq_len_numpy))], axis=0)).cuda() + position_ids = torch.from_numpy( + np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0) + ).cuda() self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1) self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1) position_ids = None @@ -30,17 +32,8 @@ def init_some_extra_state(self, model, input_ids : torch.Tensor): # b_loc[0, max_len_in_batch - 1].item() # [SYM] still reserve all kv cache - self.b_att_seq_len = self.b_seq_len.clone() - self.b_att_start_loc = self.b_start_loc.clone() - self.b_start_loc_window = self.b_start_loc.clone() - self.total_cache_num = 0 - for i in range(0, self.batch_size): - if self.sliding_window < self.b_seq_len[i]: - self.b_start_loc_window[i] = self.b_seq_len[i] - self.sliding_window - self.b_att_seq_len[i] = self.sliding_window - else: - self.b_start_loc_window[i] = 0 - self.b_att_seq_len[i] = self.b_seq_len[i] - self.b_att_start_loc[i] = self.total_cache_num - self.total_cache_num += self.b_att_seq_len[i] - return \ No newline at end of file + self.b_att_seq_len = torch.zeros_like(self.b_seq_len) + init_att_window_info_fwd(self.batch_size, self.b_seq_len, self.b_att_seq_len, self.sliding_window) + self.b_att_start_loc = torch.cumsum(self.b_att_seq_len, 0) - self.b_att_seq_len + self.total_cache_num = torch.sum(self.b_att_seq_len).item() + return diff --git a/lightllm/models/mistral/layer_infer/transformer_layer_infer.py b/lightllm/models/mistral/layer_infer/transformer_layer_infer.py index 8ed5a7fb..51cb4817 100755 --- a/lightllm/models/mistral/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/mistral/layer_infer/transformer_layer_infer.py @@ -59,7 +59,6 @@ def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo, infer_state.b_req_idx, infer_state.b_start_loc, infer_state.b_seq_len, - infer_state.b_start_loc_window, infer_state.b_att_start_loc, infer_state.b_att_seq_len, infer_state.sliding_window, @@ -79,9 +78,9 @@ def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo, infer_state.b_req_idx, infer_state.b_start_loc, infer_state.b_seq_len, - infer_state.b_start_loc_window, infer_state.b_att_start_loc, infer_state.b_att_seq_len, infer_state.other_kv_index, + infer_state.sliding_window, ) return o_tensor diff --git a/lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py b/lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py new file mode 100644 index 00000000..a60fe970 --- /dev/null +++ b/lightllm/models/mistral/triton_kernel/init_att_sliding_window_info.py @@ 
-0,0 +1,45 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_init_att_window_info( + b_seq_len, + b_att_seq_len, + batch_size, + sliding_window, + BLOCK_SIZE: tl.constexpr, +): + cur_index = tl.program_id(0) + cur_start = cur_index * BLOCK_SIZE + offsets = cur_start + tl.arange(0, BLOCK_SIZE) + mask = offsets < batch_size + + cur_seq_len = tl.load(b_seq_len + offsets, mask=mask) + b_att_seq_len_data = tl.minimum(cur_seq_len, sliding_window) + + tl.store(b_att_seq_len + offsets, b_att_seq_len_data, mask=mask) + return + + +@torch.no_grad() +def init_att_window_info_fwd(batch_size, b_seq_len, b_att_seq_len, sliding_window): + # shape constraints + assert batch_size == b_seq_len.shape[0] == b_att_seq_len.shape[0] + + BLOCK_SIZE = 32 + num_warps = 1 + grid = (triton.cdiv(batch_size, BLOCK_SIZE),) + + _fwd_kernel_init_att_window_info[grid]( + b_seq_len, + b_att_seq_len, + batch_size=batch_size, + sliding_window=sliding_window, + BLOCK_SIZE=BLOCK_SIZE, + num_warps=num_warps, + num_stages=1, + ) + return diff --git a/lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py b/lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py index 1dec71d4..09ce9d2a 100644 --- a/lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py +++ b/lightllm/models/mistral/triton_kernel/token_attention_nopad_att1.py @@ -7,49 +7,68 @@ @triton.jit def _fwd_kernel_token_att1( - Q, K, sm_scale, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, - B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen, + Q, + K, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + B_Att_Start_Loc, + B_Att_Seqlen, Att_Out, - stride_req_to_tokens_b, stride_req_to_tokens_s, - stride_qbs, stride_qh, stride_qd, - stride_kbs, stride_kh, stride_kd, - att_stride_h, att_stride_bs, - kv_group_num, sliding_window, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + att_stride_h, + att_stride_bs, + kv_group_num, + sliding_window, BLOCK_DMODEL: tl.constexpr, - BLOCK_N: tl.constexpr + BLOCK_N: tl.constexpr, ): cur_batch = tl.program_id(0) cur_head = tl.program_id(1) start_n = tl.program_id(2) - + cur_kv_head = cur_head // kv_group_num - offs_d = tl.arange(0, BLOCK_DMODEL) # [D] + offs_d = tl.arange(0, BLOCK_DMODEL) # [D] cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # use window index + cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # use window index cur_batch_req_idx = tl.load(B_req_idx + cur_batch) cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch) # use new start index of k value - cur_batch_start_index = tl.load(B_Start_Loc_Window + cur_batch) + cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0) cur_batch_end_index = cur_batch_seq_len - off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd # [D] + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d * stride_qd # [D] - offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) # [32] + offs_n = start_n * BLOCK_N + tl.arange(0, BLOCK_N) # [32] # use new value to decide block mask block_stard_index = start_n * BLOCK_N - block_mask = tl.where(block_stard_index < cur_att_seq_len, 1, 0) # a number + block_mask = tl.where(block_stard_index < cur_att_seq_len, 1, 0) # a number for start_mark in range(0, block_mask, 1): - q = tl.load(Q + off_q + start_mark) # [SYM] why here add start_mark - 
offs_n_new = cur_batch_start_index + offs_n # the latest window of token - k_loc = tl.load(Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new, - mask=offs_n_new < cur_batch_end_index, other=0) - off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd # [32, D], find token index + q = tl.load(Q + off_q + start_mark) # [SYM] why here add start_mark + offs_n_new = cur_batch_start_index + offs_n # the latest window of token + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) + off_k = ( + k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] * stride_kd + ) # [32, D], find token index k = tl.load(K + off_k, mask=offs_n_new[:, None] < cur_batch_end_index, other=0.0) - att_value = tl.sum(q[None, :] * k, 1) # [1, D] * [32, D] = [32, D] -> [32] + att_value = tl.sum(q[None, :] * k, 1) # [1, D] * [32, D] = [32, D] -> [32] att_value *= sm_scale off_o = cur_head * att_stride_h + (cur_batch_in_all_start_index + offs_n) * att_stride_bs tl.store(Att_Out + off_o, att_value, mask=offs_n_new < cur_batch_end_index) @@ -58,8 +77,8 @@ def _fwd_kernel_token_att1( @torch.no_grad() def token_att_fwd( - q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, - B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen, sliding_window): + q, k, att_out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Att_Start_Loc, B_Att_Seqlen, sliding_window +): BLOCK = 32 # shape constraints Lq, Lk = q.shape[-1], k.shape[-1] @@ -71,20 +90,33 @@ def token_att_fwd( grid = (batch, head_num, triton.cdiv(sliding_window, BLOCK)) kv_group_num = q.shape[1] // k.shape[1] - + if kv_group_num == 1: num_warps = 4 else: num_warps = 2 _fwd_kernel_token_att1[grid]( - q, k, sm_scale, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, - B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen, + q, + k, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + B_Att_Start_Loc, + B_Att_Seqlen, att_out, - Req_to_tokens.stride(0), Req_to_tokens.stride(1), - q.stride(0), q.stride(1), q.stride(2), - k.stride(0), k.stride(1), k.stride(2), - att_out.stride(0), att_out.stride(1), + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + att_out.stride(0), + att_out.stride(1), kv_group_num=kv_group_num, sliding_window=sliding_window, BLOCK_DMODEL=Lk, @@ -92,4 +124,4 @@ def token_att_fwd( num_warps=num_warps, num_stages=1, ) - return \ No newline at end of file + return diff --git a/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py b/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py index 60e3d13b..acf4923f 100644 --- a/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py +++ b/lightllm/models/mistral/triton_kernel/token_attention_nopad_reduceV.py @@ -13,7 +13,6 @@ def _fwd_kernel_token_att2( B_req_idx, B_Start_Loc, B_Seqlen, - B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen, stride_req_to_tokens_b, @@ -27,6 +26,7 @@ def _fwd_kernel_token_att2( stride_oh, stride_od, kv_group_num, + sliding_window, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -38,7 +38,7 @@ def _fwd_kernel_token_att2( offs_n = tl.arange(0, BLOCK_N) # [64] offs_d = tl.arange(0, BLOCK_DMODEL) # [D] cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_index = tl.load(B_Start_Loc_Window + 
cur_batch) # new index + cur_batch_start_index = tl.maximum(cur_batch_seq_len - sliding_window, 0) # new index # cur_batch_end_index = cur_batch_seq_len cur_batch_in_all_start_index = tl.load(B_Att_Start_Loc + cur_batch) # new index cur_batch_req_idx = tl.load(B_req_idx + cur_batch) @@ -75,7 +75,7 @@ def _fwd_kernel_token_att2( @torch.no_grad() def token_att_fwd2( - prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen + prob, v, out, Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, B_Att_Start_Loc, B_Att_Seqlen, sliding_window ): BLOCK = 128 # BLOCK = 64 # for triton 2.0.0dev @@ -94,7 +94,6 @@ def token_att_fwd2( B_req_idx, B_Start_Loc, B_Seqlen, - B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen, Req_to_tokens.stride(0), @@ -108,6 +107,7 @@ def token_att_fwd2( out.stride(1), out.stride(2), kv_group_num=kv_group_num, + siliding_window=sliding_window, BLOCK_DMODEL=dim, BLOCK_N=BLOCK, num_warps=num_warps, diff --git a/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py b/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py index a620706e..c37013f1 100644 --- a/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py +++ b/lightllm/models/mistral/triton_kernel/token_attention_softmax_and_reducev.py @@ -6,15 +6,28 @@ @triton.jit def _fwd_kernel( - Logics, V, Out, - Req_to_tokens, B_req_idx, B_Start_Loc, B_Seqlen, - B_Start_Loc_Window, B_Att_Start_Loc, B_Att_Seqlen, - stride_logic_h, stride_logic_bs, - stride_vbs, stride_vh, stride_vd, - stride_obs, stride_oh, stride_od, - stride_req_to_token_b, stride_req_to_token_s, - other_kv_index, # 避免读取到nan的数据 + Logics, + V, + Out, + Req_to_tokens, + B_req_idx, + B_Start_Loc, + B_Seqlen, + B_Att_Start_Loc, + B_Att_Seqlen, + stride_logic_h, + stride_logic_bs, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_req_to_token_b, + stride_req_to_token_s, + other_kv_index, # 避免读取到nan的数据 kv_group_num, + sliding_window, BLOCK_DMODEL: tl.constexpr, BLOCK_N: tl.constexpr, ): @@ -24,36 +37,43 @@ def _fwd_kernel( cur_kv_head = cur_head // kv_group_num cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - cur_batch_start_loc = tl.load(B_Att_Start_Loc + cur_batch) # new index + cur_batch_start_loc = tl.load(B_Att_Start_Loc + cur_batch) # new index cur_batch_req_idx = tl.load(B_req_idx + cur_batch) - cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch) # new index - cur_cache_start_loc = tl.load(B_Start_Loc_Window + cur_batch) # new index + cur_att_seq_len = tl.load(B_Att_Seqlen + cur_batch) # new index + cur_cache_start_loc = tl.maximum(cur_batch_seq_len - sliding_window, 0) # new index - offs_n = tl.arange(0, BLOCK_N) # [64] - offs_d = tl.arange(0, BLOCK_DMODEL) # [D] + offs_n = tl.arange(0, BLOCK_N) # [64] + offs_d = tl.arange(0, BLOCK_DMODEL) # [D] - off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd # [1, D] + off_v = cur_kv_head * stride_vh + offs_d[None, :] * stride_vd # [1, D] v_ptrs = V + off_v e_max = float("-inf") e_sum = 0.0 - acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) # [D] + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) # [D] for start_n in range(0, cur_att_seq_len, BLOCK_N): - start_n = tl.multiple_of(start_n, BLOCK_N) # check - v_index = tl.load(Req_to_tokens + cur_batch_req_idx * stride_req_to_token_b + - (cur_cache_start_loc + start_n + offs_n) * stride_req_to_token_s, - mask=(cur_cache_start_loc + start_n + offs_n) < cur_batch_seq_len, other=other_kv_index) # [64] + 
start_n = tl.multiple_of(start_n, BLOCK_N) # check + v_index = tl.load( + Req_to_tokens + + cur_batch_req_idx * stride_req_to_token_b + + (cur_cache_start_loc + start_n + offs_n) * stride_req_to_token_s, + mask=(start_n + offs_n) < cur_att_seq_len, + other=other_kv_index, + ) # [64] - qk = tl.load(Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs, - mask=start_n + offs_n < cur_batch_seq_len, other=float("-inf")) # [64] - - n_e_max = tl.maximum(tl.max(qk, 0), e_max) + qk = tl.load( + Logics + cur_head * stride_logic_h + (cur_batch_start_loc + start_n + offs_n) * stride_logic_bs, + mask=(start_n + offs_n) < cur_att_seq_len, + other=float("-inf"), + ) # [64] + + n_e_max = tl.maximum(tl.max(qk, 0), e_max) old_scale = tl.exp(e_max - n_e_max) p = tl.exp(qk - n_e_max) e_sum = e_sum * old_scale + tl.sum(p, 0) - v = tl.load(v_ptrs + v_index[:, None] * stride_vbs) # [1, D] + [64, 1] = [64, D] - acc = acc * old_scale + tl.sum(p[:, None] * v, 0) # [64, 1] * [64, D] = [64, D] -> [D] + v = tl.load(v_ptrs + v_index[:, None] * stride_vbs) # [1, D] + [64, 1] = [64, D] + acc = acc * old_scale + tl.sum(p[:, None] * v, 0) # [64, 1] * [64, D] = [64, D] -> [D] e_max = n_e_max acc = acc / e_sum @@ -65,8 +85,18 @@ def _fwd_kernel( @torch.no_grad() def token_softmax_reducev_fwd( - logics, v, o, req_to_tokens, b_req_idx, b_start_loc, b_seq_len, - b_start_loc_window, b_att_start_loc, b_att_seq_len, other_kv_index): + logics, + v, + o, + req_to_tokens, + b_req_idx, + b_start_loc, + b_seq_len, + b_att_start_loc, + b_att_seq_len, + other_kv_index, + sliding_window, +): BLOCK = 64 batch, head = b_seq_len.shape[0], logics.shape[0] grid = (batch, head) @@ -74,17 +104,31 @@ def token_softmax_reducev_fwd( num_warps = 1 _fwd_kernel[grid]( - logics, v, o, req_to_tokens, b_req_idx, b_start_loc, b_seq_len, - b_start_loc_window, b_att_start_loc, b_att_seq_len, - logics.stride(0), logics.stride(1), - v.stride(0), v.stride(1), v.stride(2), - o.stride(0), o.stride(1), o.stride(2), - req_to_tokens.stride(0), req_to_tokens.stride(1), + logics, + v, + o, + req_to_tokens, + b_req_idx, + b_start_loc, + b_seq_len, + b_att_start_loc, + b_att_seq_len, + logics.stride(0), + logics.stride(1), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + req_to_tokens.stride(0), + req_to_tokens.stride(1), other_kv_index, kv_group_num, + sliding_window, BLOCK_DMODEL=v.shape[-1], BLOCK_N=BLOCK, num_warps=num_warps, - num_stages=3 + num_stages=3, ) - return \ No newline at end of file + return diff --git a/lightllm/models/mixtral/infer_struct.py b/lightllm/models/mixtral/infer_struct.py index 19303be3..426b28c5 100644 --- a/lightllm/models/mixtral/infer_struct.py +++ b/lightllm/models/mixtral/infer_struct.py @@ -3,30 +3,32 @@ from lightllm.models.llama.infer_struct import LlamaInferStateInfo from lightllm.common.req_manager import ReqManager from lightllm.models.mistral.infer_struct import MistralInferStateInfo +from lightllm.models.mistral.triton_kernel.init_att_sliding_window_info import init_att_window_info_fwd from lightllm.utils.log_utils import init_logger logger = init_logger(__name__) + class MixtralInferStateInfo(MistralInferStateInfo): def __init__(self): super().__init__() self.sliding_window = None - self.b_start_loc_window = None self.b_att_seq_len = None self.b_att_start_loc = None self.total_cache_num = None self.experts_topk = None self.num_local_experts = None - def init_some_extra_state(self, model, input_ids : torch.Tensor): + def 
init_some_extra_state(self, model, input_ids: torch.Tensor): # sliding_window is not used in Mixtral 8x7b, ignore it self.sliding_window = 4096 if model.config["sliding_window"] is None else model.config["sliding_window"] self.experts_topk = model.config["num_experts_per_tok"] self.num_local_experts = model.config["num_local_experts"] if self.is_prefill: b_seq_len_numpy = self.b_seq_len.cpu().numpy() - position_ids = torch.from_numpy(np.concatenate([np.arange(0, b_seq_len_numpy[i]) - for i in range(len(b_seq_len_numpy))], axis=0)).cuda() + position_ids = torch.from_numpy( + np.concatenate([np.arange(0, b_seq_len_numpy[i]) for i in range(len(b_seq_len_numpy))], axis=0) + ).cuda() self.position_cos = torch.index_select(model._cos_cached, 0, position_ids).view(position_ids.shape[0], -1) self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(position_ids.shape[0], -1) position_ids = None @@ -38,17 +40,8 @@ def init_some_extra_state(self, model, input_ids : torch.Tensor): # b_loc[0, max_len_in_batch - 1].item() # [SYM] still reserve all kv cache - self.b_att_seq_len = self.b_seq_len.clone() - self.b_att_start_loc = self.b_start_loc.clone() - self.b_start_loc_window = self.b_start_loc.clone() - self.total_cache_num = 0 - for i in range(0, self.batch_size): - if self.sliding_window < self.b_seq_len[i]: - self.b_start_loc_window[i] = self.b_seq_len[i] - self.sliding_window - self.b_att_seq_len[i] = self.sliding_window - else: - self.b_start_loc_window[i] = 0 - self.b_att_seq_len[i] = self.b_seq_len[i] - self.b_att_start_loc[i] = self.total_cache_num - self.total_cache_num += self.b_att_seq_len[i] - return \ No newline at end of file + self.b_att_seq_len = torch.zeros_like(self.b_seq_len) + init_att_window_info_fwd(self.batch_size, self.b_seq_len, self.b_att_seq_len, self.sliding_window) + self.b_att_start_loc = torch.cumsum(self.b_att_seq_len, 0) - self.b_att_seq_len + self.total_cache_num = torch.sum(self.b_att_seq_len).item() + return diff --git a/lightllm/models/qwen2/infer_struct.py b/lightllm/models/qwen2/infer_struct.py index 074c457d..4cb1b61a 100644 --- a/lightllm/models/qwen2/infer_struct.py +++ b/lightllm/models/qwen2/infer_struct.py @@ -1,6 +1,7 @@ import torch import numpy as np from lightllm.common.basemodel import InferStateInfo +from lightllm.models.mistral.triton_kernel.init_att_sliding_window_info import init_att_window_info_fwd from lightllm.common.req_manager import ReqManager @@ -8,7 +9,6 @@ class Qwen2InferStateInfo(InferStateInfo): def __init__(self): super().__init__() self.sliding_window = None - self.b_start_loc_window = None self.b_att_seq_len = None self.b_att_start_loc = None self.total_cache_num = None @@ -30,17 +30,8 @@ def init_some_extra_state(self, model, input_ids: torch.Tensor): self.position_sin = torch.index_select(model._sin_cached, 0, position_ids).view(self.b_seq_len.shape[0], -1) self.other_kv_index = self.req_manager.req_to_token_indexs[self.b_req_idx[0], 0].item() - self.b_att_seq_len = self.b_seq_len.clone() - self.b_att_start_loc = self.b_start_loc.clone() - self.b_start_loc_window = self.b_start_loc.clone() - self.total_cache_num = 0 - for i in range(0, self.batch_size): - if self.sliding_window < self.b_seq_len[i]: - self.b_start_loc_window[i] = self.b_seq_len[i] - self.sliding_window - self.b_att_seq_len[i] = self.sliding_window - else: - self.b_start_loc_window[i] = 0 - self.b_att_seq_len[i] = self.b_seq_len[i] - self.b_att_start_loc[i] = self.total_cache_num - self.total_cache_num += self.b_att_seq_len[i] + 
self.b_att_seq_len = torch.zeros_like(self.b_seq_len) + init_att_window_info_fwd(self.batch_size, self.b_seq_len, self.b_att_seq_len, self.sliding_window) + self.b_att_start_loc = torch.cumsum(self.b_att_seq_len, 0) - self.b_att_seq_len + self.total_cache_num = torch.sum(self.b_att_seq_len).item() return diff --git a/lightllm/models/qwen2/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen2/layer_infer/transformer_layer_infer.py index 003f38dc..01808513 100644 --- a/lightllm/models/qwen2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen2/layer_infer/transformer_layer_infer.py @@ -85,7 +85,6 @@ def _token_decode_attention_normal(self, q, infer_state: Qwen2InferStateInfo, la infer_state.b_req_idx, infer_state.b_start_loc, infer_state.b_seq_len, - infer_state.b_start_loc_window, infer_state.b_att_start_loc, infer_state.b_att_seq_len, infer_state.sliding_window, @@ -107,9 +106,9 @@ def _token_decode_attention_normal(self, q, infer_state: Qwen2InferStateInfo, la infer_state.b_req_idx, infer_state.b_start_loc, infer_state.b_seq_len, - infer_state.b_start_loc_window, infer_state.b_att_start_loc, infer_state.b_att_seq_len, infer_state.other_kv_index, + infer_state.sliding_window, ) return o_tensor diff --git a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py index 5ddc8faa..5aed58a5 100644 --- a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py @@ -119,7 +119,6 @@ def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo, infer_state.b_req_idx, infer_state.b_start_loc, infer_state.b_seq_len, - infer_state.b_start_loc_window, infer_state.b_att_start_loc, infer_state.b_att_seq_len, infer_state.sliding_window, @@ -141,9 +140,9 @@ def _token_decode_attention_normal(self, q, infer_state: MistralInferStateInfo, infer_state.b_req_idx, infer_state.b_start_loc, infer_state.b_seq_len, - infer_state.b_start_loc_window, infer_state.b_att_start_loc, infer_state.b_att_seq_len, infer_state.other_kv_index, + infer_state.sliding_window, ) return o_tensor From 62c006c83ca5caafeb5a545707c9865464e9e842 Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Wed, 12 Jun 2024 11:13:27 +0800 Subject: [PATCH 3/3] update tokenizer_mode help info. (#431) --- lightllm/server/api_server.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index 08918361..4b5e08c8 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -302,8 +302,9 @@ def main(): "--tokenizer_mode", type=str, default="slow", - help="""tokenizer load mode, can be slow or auto, slow mode load fast but run slow, slow mode is - good for debug and test, when you want to get best performance, try auto mode""", + help="""tokenizer load mode, can be slow, fast or auto, slow mode load fast but run slow, + slow mode is good for debug and test, fast mode get best performance, auto mode will + try to use fast mode, if failed will use slow mode""", ) parser.add_argument( "--load_way",