Showing 8 changed files with 262 additions and 0 deletions.
@@ -0,0 +1,18 @@
import torch

from .mem_manager import MemoryManager


class Deepseek2MemoryManager(MemoryManager):
    def __init__(self, size, dtype, head_num, key_head_dim, value_head_dim, layer_num, always_copy=True):
        self.key_head_dim = key_head_dim
        self.value_head_dim = value_head_dim
        # head_dim is passed as -1 because key and value buffers use different head dims
        super().__init__(size, dtype, head_num, -1, layer_num, always_copy=always_copy)

    def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
        # the base-class head_dim is ignored; key and value caches are sized separately
        self.k_buffer = [
            torch.empty((size, head_num, self.key_head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)
        ]
        self.v_buffer = [
            torch.empty((size, head_num, self.value_head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)
        ]
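The point of this subclass is that MLA caches keys and values with different widths: the cached key carries both the nope and rope parts, while the cached value only needs the nope width, so the base manager's single head_dim no longer applies. A minimal usage sketch, assuming the class above is importable; the sizes are assumed DeepSeek-V2-style values, not taken from this commit:

import torch

# Assumed, illustrative sizes (not from this commit).
size = 16384                # total number of cacheable token slots
head_num = 16               # kv heads on this tensor-parallel rank
key_head_dim = 128 + 64     # qk_nope_head_dim + qk_rope_head_dim
value_head_dim = 128        # qk_nope_head_dim
layer_num = 60              # num_hidden_layers

mgr = Deepseek2MemoryManager(size, torch.float16, head_num, key_head_dim, value_head_dim, layer_num)
# Each layer now holds a (16384, 16, 192) key cache and a (16384, 16, 128) value cache,
# i.e. keys and values no longer share one head_dim as in the base MemoryManager.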
Empty file.
@@ -0,0 +1,9 @@
import torch
import numpy as np
from lightllm.common.basemodel import InferStateInfo
from lightllm.common.req_manager import ReqManager


class Deepseek2InferStateInfo(InferStateInfo):
    def __init__(self):
        super().__init__()
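The infer-state class is still an empty stub that only forwards to the base constructor. In lightllm, such subclasses usually override a per-batch hook on `InferStateInfo` (recalled here as `init_some_extra_state`; treat the name, signature, and the attributes used below as assumptions) to precompute positions for rotary embedding. A hedged sketch of a possible follow-up, reusing the imports of the stub above:

# Hypothetical extension, not part of this commit; attribute names are assumed.
class Deepseek2InferStateInfoSketch(InferStateInfo):
    def init_some_extra_state(self, model, input_ids):
        if self.is_prefill:
            # one position per prompt token for every request in the batch
            self.position_ids = torch.cat(
                [torch.arange(0, n, device="cuda") for n in self.b_seq_len.tolist()]
            )
        else:
            # decode step: one new token per request, position = current length - 1
            self.position_ids = self.b_seq_len - 1
        return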
Empty file.
Empty file.
Empty file.
lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py: 170 additions & 0 deletions
@@ -0,0 +1,170 @@
import torch
import math
import numpy as np
from lightllm.common.basemodel import TransformerLayerWeight


class Deepseek2TransformerLayerWeight(TransformerLayerWeight):
    def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]):
        super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode)
        # a layer is MoE if routed experts are configured, the layer index is past the
        # leading dense layers, and the index matches the MoE layer frequency
        self.is_moe = (
            self.network_config_["n_routed_experts"] is not None
            and self.layer_num_ >= self.network_config_["first_k_dense_replace"]
            and self.layer_num_ % self.network_config_["moe_layer_freq"] == 0
        )
        self.n_routed_experts = self.network_config_["n_routed_experts"]
        return

    def load_hf_weights(self, weights):
        self._load_qkvo_weights(weights)
        self._load_ffn_weights(weights)
        return

    def verify_load(self):
        errors = "weights load not ok"
        weights = [
            self.att_norm_weight_,
            self.q_weight_,
            self.kv_a_proj_with_mqa_,
            self.kv_a_layernorm_,
            self.kv_b_proj_,
            self.o_weight_,
            self.ffn_norm_weight_,
        ]
        if self.is_moe:
            assert len(self.experts) == self.n_routed_experts // self.world_size_, "experts weight load not ok"
            weights += [
                self.moe_gate,
                self.shared_experts_gate_up_proj,
                self.shared_experts_down_proj,
            ]
        else:
            weights += [
                self.gate_up_proj,
                self.down_proj,
            ]

        for i in range(len(weights)):
            assert weights[i] is not None, "index:" + str(i) + " " + errors
        return

    def _load_qkvo_weights(self, weights):
        # input layernorm params
        if f"model.layers.{self.layer_num_}.input_layernorm.weight" in weights:
            self.att_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.input_layernorm.weight"])

        n_embed = self.network_config_["hidden_size"]
        q_split_n_embed = n_embed // self.world_size_
        kv_split_n_embed = (
            n_embed
            // self.network_config_["num_attention_heads"]
            * self.network_config_["num_key_value_heads"]
            // self.world_size_
        )
        # q projection weight, split by rows across tensor-parallel ranks
        if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights:
            self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"]
            self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :]
            self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1))

        if f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight" in weights:
            kv_a_proj_with_mqa_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight"]
            print("kv_a_proj_with_mqa shape", kv_a_proj_with_mqa_.shape)
            kv_a_proj_with_mqa_ = kv_a_proj_with_mqa_[
                kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :
            ]
            self.kv_a_proj_with_mqa_ = kv_a_proj_with_mqa_.transpose(0, 1)

        if f"model.layers.{self.layer_num_}.self_attn.kv_a_layernorm.weight" in weights:
            kv_a_layernorm_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_a_layernorm.weight"]
            kv_a_layernorm_ = kv_a_layernorm_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
            self.kv_a_layernorm_ = kv_a_layernorm_.transpose(0, 1)

        if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights:
            kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"]
            kv_b_proj_ = kv_b_proj_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
            self.kv_b_proj_ = kv_b_proj_.transpose(0, 1)

        # attention output dense params, split by columns across tensor-parallel ranks
        if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights:
            self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"]
            self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)]
            self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1))

        return

    def _load_ffn_weights(self, weights):
        if f"model.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights:
            self.ffn_norm_weight_ = self._cuda(
                weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"]
            )

        if self.is_moe:
            # routed experts are split evenly across tensor-parallel ranks
            experts_per_rank = self.n_routed_experts // self.world_size_

            if f"model.layers.{self.layer_num_}.mlp.gate.weight" in weights:
                moe_gate = weights[f"model.layers.{self.layer_num_}.mlp.gate.weight"]
                self.moe_gate = self._cuda(moe_gate.transpose(0, 1))

            if f"model.layers.{self.layer_num_}.shared_experts.up_proj.weight" in weights:
                shared_experts_up_proj = weights[f"model.layers.{self.layer_num_}.shared_experts.up_proj.weight"]
                self.shared_experts_up_proj = shared_experts_up_proj.transpose(0, 1)

            if f"model.layers.{self.layer_num_}.shared_experts.gate_proj.weight" in weights:
                shared_experts_gate_proj = weights[f"model.layers.{self.layer_num_}.shared_experts.gate_proj.weight"]
                self.shared_experts_gate_proj = shared_experts_gate_proj.transpose(0, 1)

            self._try_cat_to(
                ["shared_experts_gate_proj", "shared_experts_up_proj"], "shared_experts_gate_up_proj", cat_dim=1
            )

            if f"model.layers.{self.layer_num_}.shared_experts.down_proj.weight" in weights:
                self.shared_experts_down_proj = weights[f"model.layers.{self.layer_num_}.shared_experts.down_proj.weight"]
                self.shared_experts_down_proj = self._cuda(self.shared_experts_down_proj.transpose(0, 1))

            self.experts = []
            for i_experts in range(experts_per_rank * self.tp_rank_, experts_per_rank * (self.tp_rank_ + 1)):
                self.expert_up_proj = None
                self.expert_gate_proj = None
                self.expert_down_proj = None

                if f"model.layers.{self.layer_num_}.experts.{i_experts}.up_proj.weight" in weights:
                    expert_up_proj = weights[f"model.layers.{self.layer_num_}.experts.{i_experts}.up_proj.weight"]
                    self.expert_up_proj = expert_up_proj.transpose(0, 1)

                if f"model.layers.{self.layer_num_}.experts.{i_experts}.gate_proj.weight" in weights:
                    expert_gate_proj = weights[f"model.layers.{self.layer_num_}.experts.{i_experts}.gate_proj.weight"]
                    self.expert_gate_proj = expert_gate_proj.transpose(0, 1)

                self._try_cat_to(["expert_gate_proj", "expert_up_proj"], "expert_gate_up_proj", cat_dim=1)

                if f"model.layers.{self.layer_num_}.experts.{i_experts}.down_proj.weight" in weights:
                    self.expert_down_proj = weights[f"model.layers.{self.layer_num_}.experts.{i_experts}.down_proj.weight"]
                    self.expert_down_proj = self._cuda(self.expert_down_proj.transpose(0, 1))

                if self.expert_gate_up_proj is not None and self.expert_down_proj is not None:
                    self.experts.append(
                        {"expert_gate_up_proj": self.expert_gate_up_proj, "expert_down_proj": self.expert_down_proj}
                    )

        else:
            inter_size = self.network_config_["intermediate_size"]
            split_inter_size = inter_size // self.world_size_

            if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights:
                up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][
                    split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
                ]
                self.up_proj = up_proj.transpose(0, 1)

            if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights:
                gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][
                    split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
                ]
                self.gate_proj = gate_proj.transpose(0, 1)

            self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1)

            if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights:
                self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][
                    :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)
                ]
                self.down_proj = self._cuda(self.down_proj.transpose(0, 1))
        return
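Two details above are worth spelling out. First, routed experts are sharded by index rather than by slicing: with an assumed 160 routed experts and world_size 8, rank 3 would load experts 60 through 79, and `verify_load` expects exactly 20 per rank. Second, `_try_cat_to` is inherited from the base `TransformerLayerWeight` and is not shown in this commit; the sketch below is only an assumed approximation of the gate/up fusion it performs, so one matmul can produce both halves of the SiLU-gated MLP:

import torch

# Assumed approximation of the fusion (not the real _try_cat_to helper): the two
# (hidden_size, split_inter_size) matrices are concatenated along dim=1.
def fuse_gate_up(gate_proj: torch.Tensor, up_proj: torch.Tensor) -> torch.Tensor:
    return torch.cat([gate_proj, up_proj], dim=1).contiguous()

def ffn_sketch(x: torch.Tensor, gate_up_proj: torch.Tensor, down_proj: torch.Tensor) -> torch.Tensor:
    # x: (num_tokens, hidden_size); gate_up_proj: (hidden_size, 2 * split_inter_size)
    gate, up = torch.mm(x, gate_up_proj).chunk(2, dim=1)
    # SiLU gating, then project back with the (split_inter_size, hidden_size) down matrix
    return torch.mm(torch.nn.functional.silu(gate) * up, down_proj)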
@@ -0,0 +1,65 @@
import os
import json
import torch
from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer
from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights

from lightllm.common.basemodel import LlamaTpPartModel
from lightllm.common.mem_utils import select_mem_manager_class
from lightllm.utils.log_utils import init_logger


logger = init_logger(__name__)


class Deepseek2TpPartModel(LlamaTpPartModel):
    # weight class
    transformer_weight_class = Deepseek2TransformerLayerWeight

    # infer class
    transformer_layer_infer_class = Deepseek2TransformerLayerInfer

    # infer state class
    infer_state_class = Deepseek2InferStateInfo

    def __init__(self, kvargs):
        super().__init__(kvargs)
        return

    def _init_config(self):
        super()._init_config()
        return

    def _verify_params(self):
        return super()._verify_params()

    def _init_mem_manager(self):
        # MLA cache: the key entry holds nope + rope parts, the value entry only the nope part
        self.mem_manager = select_mem_manager_class(self.mode, "deepseek2")(
            self.max_total_token_num,
            dtype=self.data_type,
            head_num=self.config["num_key_value_heads"] // self.world_size_,
            key_head_dim=self.config["qk_nope_head_dim"] + self.config["qk_rope_head_dim"],
            value_head_dim=self.config["qk_nope_head_dim"],
            layer_num=self.config["num_hidden_layers"],
        )
        return

    def _init_custom(self):
        return

    def _init_weights(self):
        self.pre_post_weight = self.pre_and_post_weight_class(
            self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode
        )
        self.trans_layers_weight = [
            self.transformer_weight_class(
                i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode
            )
            for i in range(self.config["n_layer"])
        ]
        load_hf_weights(
            self.data_type,
            weight_dir=self.weight_dir_,
            pre_post_layer=self.pre_post_weight,
            transformer_layer_list=self.trans_layers_weight,
            weight_dict=self.weight_dict,
        )
        self.pre_post_weight.verify_load()
        [weight.verify_load() for weight in self.trans_layers_weight]
        return
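`_init_mem_manager` is where the asymmetric KV cache from the memory-manager file gets its sizes, straight from the HF config. A worked sketch of that mapping with assumed DeepSeek-V2-style config values (not read from this commit):

# Assumed config values, for illustration only.
config = {
    "num_key_value_heads": 128,
    "qk_nope_head_dim": 128,
    "qk_rope_head_dim": 64,
    "num_hidden_layers": 60,
}
world_size = 8

head_num = config["num_key_value_heads"] // world_size                    # 16 heads per rank
key_head_dim = config["qk_nope_head_dim"] + config["qk_rope_head_dim"]    # 128 + 64 = 192
value_head_dim = config["qk_nope_head_dim"]                               # 128
layer_num = config["num_hidden_layers"]                                   # 60
# These are exactly the keyword arguments _init_mem_manager forwards to the
# deepseek2 memory manager returned by select_mem_manager_class.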