diff --git a/lightllm/common/deepseek2_mem_manager.py b/lightllm/common/deepseek2_mem_manager.py
new file mode 100644
index 00000000..97ee50a0
--- /dev/null
+++ b/lightllm/common/deepseek2_mem_manager.py
@@ -0,0 +1,18 @@
+import torch
+
+from .mem_manager import MemoryManager
+
+
+class Deepseek2MemoryManager(MemoryManager):
+    def __init__(self, size, dtype, head_num, key_head_dim, value_head_dim, layer_num, always_copy=True):
+        self.key_head_dim = key_head_dim
+        self.value_head_dim = value_head_dim
+        super().__init__(size, dtype, head_num, -1, layer_num, always_copy=always_copy)
+
+    def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
+        self.k_buffer = [
+            torch.empty((size, head_num, self.key_head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)
+        ]
+        self.v_buffer = [
+            torch.empty((size, head_num, self.value_head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)
+        ]
diff --git a/lightllm/models/deepseek2/__init__.py b/lightllm/models/deepseek2/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/lightllm/models/deepseek2/infer_struct.py b/lightllm/models/deepseek2/infer_struct.py
new file mode 100644
index 00000000..31a4f003
--- /dev/null
+++ b/lightllm/models/deepseek2/infer_struct.py
@@ -0,0 +1,9 @@
+import torch
+import numpy as np
+from lightllm.common.basemodel import InferStateInfo
+from lightllm.common.req_manager import ReqManager
+
+
+class Deepseek2InferStateInfo(InferStateInfo):
+    def __init__(self):
+        super().__init__()
diff --git a/lightllm/models/deepseek2/layer_infer/__init__.py b/lightllm/models/deepseek2/layer_infer/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py b/lightllm/models/deepseek2/layer_infer/transformer_layer_infer.py
new file mode 100644
index 00000000..e69de29b
diff --git a/lightllm/models/deepseek2/layer_weights/__init__.py b/lightllm/models/deepseek2/layer_weights/__init__.py
new file mode 100644
index 00000000..e69de29b
diff --git a/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
new file mode 100644
index 00000000..280435fe
--- /dev/null
+++ b/lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,170 @@
+import torch
+import math
+import numpy as np
+from lightllm.common.basemodel import TransformerLayerWeight
+
+
+class Deepseek2TransformerLayerWeight(TransformerLayerWeight):
+    def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]):
+        super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode)
+        self.is_moe = (
+            self.network_config_["n_routed_experts"] is not None
+            and self.layer_num_ >= self.network_config_["first_k_dense_replace"]
+            and self.layer_num_ % self.network_config_["moe_layer_freq"] == 0
+        )
+        self.n_routed_experts = self.network_config_["n_routed_experts"]
+        return
+
+    def load_hf_weights(self, weights):
+        self._load_qkvo_weights(weights)
+        self._load_ffn_weights(weights)
+        return
+
+    def verify_load(self):
+        errors = "weights load not ok"
+        weights = [
+            self.att_norm_weight_,
+            self.q_weight_,
+            self.kv_a_proj_with_mqa_,
+            self.kv_a_layernorm_,
+            self.kv_b_proj_,
+            self.o_weight_,
+            self.ffn_norm_weight_,
+        ]
+        if self.is_moe:
+            assert len(self.experts) == self.n_routed_experts // self.world_size_, "experts weight load not ok"
+            weights.extend([
+                self.moe_gate,
+                self.shared_experts_gate_up_proj,
+                self.shared_experts_down_proj
+            ])
+        else:
+            weights.extend([
+                self.gate_up_proj,
+                self.down_proj,
+            ])
+
+        for i in range(len(weights)):
+            assert weights[i] is not None, "index:" + str(i) + " " + errors
+        return
+
+    def _load_qkvo_weights(self, weights):
+        # input layernorm params
+        if f"model.layers.{self.layer_num_}.input_layernorm.weight" in weights:
+            self.att_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.input_layernorm.weight"])
+
+        n_embed = self.network_config_["hidden_size"]
+        q_split_n_embed = n_embed // self.world_size_
+        kv_split_n_embed = (
+            n_embed
+            // self.network_config_["num_attention_heads"]
+            * self.network_config_["num_key_value_heads"]
+            // self.world_size_
+        )
+        # q k v weights for deepseek2
+        if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights:
+            self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"]
+            self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :]
+            self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1))
+
+        if f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight" in weights:
+            kv_a_proj_with_mqa_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight"]
+            print("kv_a_proj_with_mqa shape", kv_a_proj_with_mqa_.shape)
+            kv_a_proj_with_mqa_ = kv_a_proj_with_mqa_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
+            self.kv_a_proj_with_mqa_ = kv_a_proj_with_mqa_.transpose(0, 1)
+
+        if f"model.layers.{self.layer_num_}.self_attn.kv_a_layernorm.weight" in weights:
+            kv_a_layernorm_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_a_layernorm.weight"]
+            kv_a_layernorm_ = kv_a_layernorm_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
+            self.kv_a_layernorm_ = kv_a_layernorm_.transpose(0, 1)
+
+        if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights:
+            kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"]
+            kv_b_proj_ = kv_b_proj_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
+            self.kv_b_proj_ = kv_b_proj_.transpose(0, 1)
+
+        # attention output dense params
+        if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights:
+            self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"]
+            self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)]
+            self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1))
+
+        return
+
+    def _load_ffn_weights(self, weights):
+        if f"model.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights:
+            self.ffn_norm_weight_ = self._cuda(
+                weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"]
+            )
+
+        if self.is_moe:
+            experts_per_rank = self.n_routed_experts // self.world_size_
+
+            if f"model.layers.{self.layer_num_}.mlp.gate.weight" in weights:
+                moe_gate = weights[f"model.layers.{self.layer_num_}.mlp.gate.weight"]
+                self.moe_gate = self._cuda(moe_gate.transpose(0, 1))
+
+            if f"model.layers.{self.layer_num_}.mlp.shared_experts.up_proj.weight" in weights:
+                shared_experts_up_proj = weights[f"model.layers.{self.layer_num_}.mlp.shared_experts.up_proj.weight"]
+                self.shared_experts_up_proj = shared_experts_up_proj.transpose(0, 1)
+
+            if f"model.layers.{self.layer_num_}.mlp.shared_experts.gate_proj.weight" in weights:
+                shared_experts_gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.shared_experts.gate_proj.weight"]
+                self.shared_experts_gate_proj = shared_experts_gate_proj.transpose(0, 1)
+
+            self._try_cat_to(["shared_experts_gate_proj", "shared_experts_up_proj"], "shared_experts_gate_up_proj", cat_dim=1)
+
+            if f"model.layers.{self.layer_num_}.mlp.shared_experts.down_proj.weight" in weights:
+                self.shared_experts_down_proj = weights[f"model.layers.{self.layer_num_}.mlp.shared_experts.down_proj.weight"]
+                self.shared_experts_down_proj = self._cuda(self.shared_experts_down_proj.transpose(0, 1))
+
+            self.experts = []
+            for i_experts in range(experts_per_rank * self.tp_rank_, experts_per_rank * (self.tp_rank_ + 1)):
+                self.expert_up_proj = None
+                self.expert_gate_proj = None
+                self.expert_down_proj = None
+
+                if f"model.layers.{self.layer_num_}.mlp.experts.{i_experts}.up_proj.weight" in weights:
+                    expert_up_proj = weights[f"model.layers.{self.layer_num_}.mlp.experts.{i_experts}.up_proj.weight"]
+                    self.expert_up_proj = expert_up_proj.transpose(0, 1)
+
+                if f"model.layers.{self.layer_num_}.mlp.experts.{i_experts}.gate_proj.weight" in weights:
+                    expert_gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.experts.{i_experts}.gate_proj.weight"]
+                    self.expert_gate_proj = expert_gate_proj.transpose(0, 1)
+
+                self._try_cat_to(["expert_gate_proj", "expert_up_proj"], "expert_gate_up_proj", cat_dim=1)
+
+                if f"model.layers.{self.layer_num_}.mlp.experts.{i_experts}.down_proj.weight" in weights:
+                    self.expert_down_proj = weights[f"model.layers.{self.layer_num_}.mlp.experts.{i_experts}.down_proj.weight"]
+                    self.expert_down_proj = self._cuda(self.expert_down_proj.transpose(0, 1))
+
+                if self.expert_gate_up_proj is not None and self.expert_down_proj is not None:
+                    self.experts.append({
+                        "expert_gate_up_proj": self.expert_gate_up_proj,
+                        "expert_down_proj": self.expert_down_proj
+                    })
+
+        else:
+            inter_size = self.network_config_["intermediate_size"]
+            split_inter_size = inter_size // self.world_size_
+
+            if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights:
+                up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][
+                    split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
+                ]
+                self.up_proj = up_proj.transpose(0, 1)
+
+            if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights:
+                gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][
+                    split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
+                ]
+                self.gate_proj = gate_proj.transpose(0, 1)
+
+            self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1)
+
+            if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights:
+                self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][
+                    :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)
+                ]
+                self.down_proj = self._cuda(self.down_proj.transpose(0, 1))
+        return
diff --git a/lightllm/models/deepseek2/model.py b/lightllm/models/deepseek2/model.py
new file mode 100644
index 00000000..e30440f9
--- /dev/null
+++ b/lightllm/models/deepseek2/model.py
@@ -0,0 +1,65 @@
+import os
+import json
+import torch
+from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer
+from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
+from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
+from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights
+
+from lightllm.models.llama.model import LlamaTpPartModel
+from lightllm.common.mem_utils import select_mem_manager_class
+from lightllm.utils.log_utils import init_logger
+
+
+logger = init_logger(__name__)
+
+class Deepseek2TpPartModel(LlamaTpPartModel):
+    # weight class
+    transformer_weight_class = Deepseek2TransformerLayerWeight
+
+    # infer class
+    transformer_layer_infer_class = Deepseek2TransformerLayerInfer
+
+    # infer state class
+    infer_state_class = Deepseek2InferStateInfo
+
+    def __init__(self, kvargs):
+        super().__init__(kvargs)
+        return
+
+    def _init_config(self):
+        super()._init_config()
+        return
+
+    def _verify_params(self):
+        return super()._verify_params()
+
+
+    def _init_mem_manager(self):
+        self.mem_manager = select_mem_manager_class(self.mode, "deepseek2")(self.max_total_token_num,
+            dtype=self.data_type,
+            head_num=self.config["num_key_value_heads"] // self.world_size_,
+            key_head_dim=self.config["qk_nope_head_dim"] + self.config["qk_rope_head_dim"],
+            value_head_dim=self.config["qk_nope_head_dim"],
+            layer_num=self.config["num_hidden_layers"])
+        return
+
+    def _init_custom(self):
+        return
+
+    def _init_weights(self):
+        self.pre_post_weight = self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode)
+        self.trans_layers_weight = [
+            self.transformer_weight_class(i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode)
+            for i in range(self.config["n_layer"])
+        ]
+        load_hf_weights(
+            self.data_type,
+            weight_dir=self.weight_dir_,
+            pre_post_layer=self.pre_post_weight,
+            transformer_layer_list=self.trans_layers_weight,
+            weight_dict=self.weight_dict
+        )
+        self.pre_post_weight.verify_load()
+        [weight.verify_load() for weight in self.trans_layers_weight]
+        return
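
Note on the cache shapes (not part of the diff): _init_mem_manager above derives the asymmetric key/value head dimensions for Deepseek2MemoryManager directly from the model config: the per-head key cache stores the concatenated no-RoPE and RoPE parts, while the per-head value cache stores only the no-RoPE width. A minimal runnable sketch of that arithmetic follows; the config numbers and the world_size value are illustrative assumptions, only the formulas mirror the code above.

# Sketch only, not part of the patch; values below are assumed, DeepSeek-V2-style numbers.
config = {
    "num_key_value_heads": 128,
    "qk_nope_head_dim": 128,
    "qk_rope_head_dim": 64,
    "num_hidden_layers": 60,
}
world_size = 8  # assumed tensor-parallel degree

head_num = config["num_key_value_heads"] // world_size                   # KV heads cached on one TP rank
key_head_dim = config["qk_nope_head_dim"] + config["qk_rope_head_dim"]   # per-head key cache width (nope + rope)
value_head_dim = config["qk_nope_head_dim"]                              # per-head value cache width
layer_num = config["num_hidden_layers"]                                  # one k/v buffer pair per layer

print(head_num, key_head_dim, value_head_dim, layer_num)  # 16 192 128 60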