Commit

init

WANDY666 committed Jun 14, 2024
1 parent a47942e commit a2be0e4
Showing 8 changed files with 262 additions and 0 deletions.
18 changes: 18 additions & 0 deletions lightllm/common/deepseek2_mem_manager.py
@@ -0,0 +1,18 @@
import torch

from .mem_manager import MemoryManager


class Deepseek2MemoryManager(MemoryManager):
    def __init__(self, size, dtype, head_num, key_head_dim, value_head_dim, layer_num, always_copy=True):
        # MLA-style cache: keys and values have different per-head widths, so the base
        # head_dim argument is unused (-1) and the buffers are built in _init_buffers.
        self.key_head_dim = key_head_dim
        self.value_head_dim = value_head_dim
        super().__init__(size, dtype, head_num, -1, layer_num, always_copy=always_copy)

def _init_buffers(self, size, dtype, head_num, head_dim, layer_num):
self.k_buffer = [
torch.empty((size, head_num, self.key_head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)
]
self.v_buffer = [
torch.empty((size, head_num, self.value_head_dim), dtype=dtype, device="cuda") for _ in range(layer_num)
]
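
A minimal usage sketch of the manager above (not part of this commit). The head sizes are assumed DeepSeek-V2-style placeholder values, and a CUDA device is required because the buffers are allocated with device="cuda"; the base MemoryManager is assumed to set up its usual bookkeeping around _init_buffers.

import torch
from lightllm.common.deepseek2_mem_manager import Deepseek2MemoryManager

qk_nope_head_dim = 128   # assumed value, for illustration only
qk_rope_head_dim = 64    # assumed value, for illustration only

mem = Deepseek2MemoryManager(
    size=1024,                                          # cacheable token slots
    dtype=torch.float16,
    head_num=16,                                        # per-rank KV head count (assumed)
    key_head_dim=qk_nope_head_dim + qk_rope_head_dim,   # keys store nope + rope parts -> 192
    value_head_dim=qk_nope_head_dim,                    # values store only the nope part -> 128
    layer_num=2,
)
# Each layer then gets differently shaped key and value buffers:
#   mem.k_buffer[0].shape == (1024, 16, 192)
#   mem.v_buffer[0].shape == (1024, 16, 128)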
Empty file.
9 changes: 9 additions & 0 deletions lightllm/models/deepseek2/infer_struct.py
@@ -0,0 +1,9 @@
import torch
import numpy as np
from lightllm.common.basemodel import InferStateInfo
from lightllm.common.req_manager import ReqManager


class Deepseek2InferStateInfo(InferStateInfo):
def __init__(self):
super().__init__()
Empty file.
Empty file.
Empty file.
170 changes: 170 additions & 0 deletions lightllm/models/deepseek2/layer_weights/transformer_layer_weight.py
@@ -0,0 +1,170 @@
import torch
import math
import numpy as np
from lightllm.common.basemodel import TransformerLayerWeight


class Deepseek2TransformerLayerWeight(TransformerLayerWeight):
def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]):
super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode)
self.is_moe = (
self.network_config_["n_routed_experts"] is not None
and self.layer_num_ >= self.network_config_["first_k_dense_replace"]
and self.layer_num_ % self.network_config_["moe_layer_freq"] == 0
)
self.n_routed_experts = self.network_config_["n_routed_experts"]
return

def load_hf_weights(self, weights):
self._load_qkvo_weights(weights)
self._load_ffn_weights(weights)
return

def verify_load(self):
errors = "weights load not ok"
weights = [
self.att_norm_weight_,
self.q_weight_,
self.kv_a_proj_with_mqa_,
self.kv_a_layernorm_,
self.kv_b_proj_,
self.o_weight_,
self.ffn_norm_weight_,
]
if self.is_moe:
assert len(self.experts) == self.n_routed_experts // self.world_size_, "experts weight load not ok"
            weights += [
                self.moe_gate,
                self.shared_experts_gate_up_proj,
                self.shared_experts_down_proj,
            ]
        else:
            weights += [
                self.gate_up_proj,
                self.down_proj,
            ]

for i in range(len(weights)):
assert weights[i] is not None, "index:" + str(i) + " " + errors
return

def _load_qkvo_weights(self, weights):
# input layernorm params
if f"model.layers.{self.layer_num_}.input_layernorm.weight" in weights:
self.att_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.input_layernorm.weight"])

n_embed = self.network_config_["hidden_size"]
q_split_n_embed = n_embed // self.world_size_
kv_split_n_embed = (
n_embed
// self.network_config_["num_attention_heads"]
* self.network_config_["num_key_value_heads"]
// self.world_size_
)
        # q / kv projection weights for DeepSeek-V2 MLA attention
if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights:
self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"]
self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :]
self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1))

if f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight" in weights:
kv_a_proj_with_mqa_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_a_proj_with_mqa.weight"]
print("kv_a_proj_with_mqa shape", kv_a_proj_with_mqa_.shape)
kv_a_proj_with_mqa_ = kv_a_proj_with_mqa_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
self.kv_a_proj_with_mqa_ = kv_a_proj_with_mqa_.transpose(0, 1)

if f"model.layers.{self.layer_num_}.self_attn.kv_a_layernorm.weight" in weights:
kv_a_layernorm_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_a_layernorm.weight"]
kv_a_layernorm_ = kv_a_layernorm_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
self.kv_a_layernorm_ = kv_a_layernorm_.transpose(0, 1)

if f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight" in weights:
kv_b_proj_ = weights[f"model.layers.{self.layer_num_}.self_attn.kv_b_proj.weight"]
kv_b_proj_ = kv_b_proj_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
self.kv_b_proj_ = kv_b_proj_.transpose(0, 1)

# attention output dense params
if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights:
self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"]
self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)]
self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1))

return

def _load_ffn_weights(self, weights):
if f"model.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights:
self.ffn_norm_weight_ = self._cuda(
weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"]
)

if self.is_moe:
experts_per_rank = self.n_routed_experts // self.world_size_

if f"model.layers.{self.layer_num_}.mlp.gate.weight" in weights:
moe_gate = weights[f"model.layers.{self.layer_num_}.mlp.gate.weight"]
self.moe_gate = self._cuda(moe_gate.transpose(0, 1))

if f"model.layers.{self.layer_num_}.shared_experts.up_proj.weight" in weights:
shared_experts_up_proj = weights[f"model.layers.{self.layer_num_}.shared_experts.up_proj.weight"]
self.shared_experts_up_proj = shared_experts_up_proj.transpose(0, 1)

if f"model.layers.{self.layer_num_}.shared_experts.gate_proj.weight" in weights:
shared_experts_gate_proj = weights[f"model.layers.{self.layer_num_}.shared_experts.gate_proj.weight"]
self.shared_experts_gate_proj = shared_experts_gate_proj.transpose(0, 1)

self._try_cat_to(["shared_experts_gate_proj", "shared_experts_up_proj"], "shared_experts_gate_up_proj", cat_dim=1)

if f"model.layers.{self.layer_num_}.shared_experts.down_proj.weight" in weights:
self.shared_experts_down_proj = weights[f"model.layers.{self.layer_num_}.shared_experts.down_proj.weight"]
self.shared_experts_down_proj = self._cuda(self.shared_experts_down_proj.transpose(0, 1))

            self.experts = []
            # Each rank loads only its own contiguous slice of the routed experts.
            for i_experts in range(experts_per_rank * self.tp_rank_, experts_per_rank * (self.tp_rank_ + 1)):
                self.expert_up_proj = None
                self.expert_gate_proj = None
                self.expert_gate_up_proj = None
                self.expert_down_proj = None

                if f"model.layers.{self.layer_num_}.mlp.experts.{i_experts}.up_proj.weight" in weights:
                    expert_up_proj = weights[f"model.layers.{self.layer_num_}.mlp.experts.{i_experts}.up_proj.weight"]
                    self.expert_up_proj = expert_up_proj.transpose(0, 1)

                if f"model.layers.{self.layer_num_}.mlp.experts.{i_experts}.gate_proj.weight" in weights:
                    expert_gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.experts.{i_experts}.gate_proj.weight"]
                    self.expert_gate_proj = expert_gate_proj.transpose(0, 1)

                self._try_cat_to(["expert_gate_proj", "expert_up_proj"], "expert_gate_up_proj", cat_dim=1)

                if f"model.layers.{self.layer_num_}.mlp.experts.{i_experts}.down_proj.weight" in weights:
                    expert_down_proj = weights[f"model.layers.{self.layer_num_}.mlp.experts.{i_experts}.down_proj.weight"]
                    self.expert_down_proj = self._cuda(expert_down_proj.transpose(0, 1))

                if self.expert_gate_up_proj is not None and self.expert_down_proj is not None:
                    self.experts.append(
                        {"expert_gate_up_proj": self.expert_gate_up_proj, "expert_down_proj": self.expert_down_proj}
                    )

else:
inter_size = self.network_config_["intermediate_size"]
split_inter_size = inter_size // self.world_size_

if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights:
up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][
split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
]
self.up_proj = up_proj.transpose(0, 1)

if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights:
gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][
split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :
]
self.gate_proj = gate_proj.transpose(0, 1)

self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1)

if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights:
self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][
:, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)
]
self.down_proj = self._cuda(self.down_proj.transpose(0, 1))
return
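
_load_ffn_weights above stores gate_proj and up_proj transposed to (hidden, inter) and fuses them column-wise into a single gate_up matrix, and it assigns each rank a contiguous slice of the routed experts. Below is a standalone sketch (not part of this commit) of how such a fused matrix is typically consumed by a SwiGLU MLP, using plain torch and made-up sizes; _try_cat_to itself is assumed to simply concatenate along the given dim.

import torch
import torch.nn.functional as F

hidden, inter, tokens = 64, 256, 4                   # made-up sizes for illustration

# HF checkpoints store linear weights as (out_features, in_features).
gate_w = torch.randn(inter, hidden)
up_w = torch.randn(inter, hidden)
down_w = torch.randn(hidden, inter)

# Mirror the loader: transpose to (in, out), then fuse gate|up along dim 1.
gate_up = torch.cat([gate_w.t(), up_w.t()], dim=1)   # (hidden, 2 * inter)
down = down_w.t()                                    # (inter, hidden)

x = torch.randn(tokens, hidden)
gu = x @ gate_up                                     # (tokens, 2 * inter)
h = F.silu(gu[:, :inter]) * gu[:, inter:]            # SwiGLU: silu(gate) * up
y = h @ down                                         # back to (tokens, hidden)
assert y.shape == (tokens, hidden)

# Expert sharding follows the same per-rank slicing as the loader:
# rank r owns experts [r * experts_per_rank, (r + 1) * experts_per_rank).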
65 changes: 65 additions & 0 deletions lightllm/models/deepseek2/model.py
@@ -0,0 +1,65 @@
import os
import json
import torch
from lightllm.models.deepseek2.layer_infer.transformer_layer_infer import Deepseek2TransformerLayerInfer
from lightllm.models.deepseek2.layer_weights.transformer_layer_weight import Deepseek2TransformerLayerWeight
from lightllm.models.deepseek2.infer_struct import Deepseek2InferStateInfo
from lightllm.common.basemodel.layer_weights.hf_load_utils import load_hf_weights

from lightllm.models.llama.model import LlamaTpPartModel
from lightllm.common.mem_utils import select_mem_manager_class
from lightllm.utils.log_utils import init_logger


logger = init_logger(__name__)

class Deepseek2TpPartModel(LlamaTpPartModel):
# weight class
transformer_weight_class = Deepseek2TransformerLayerWeight

# infer class
transformer_layer_infer_class = Deepseek2TransformerLayerInfer

# infer state class
infer_state_class = Deepseek2InferStateInfo

def __init__(self, kvargs):
super().__init__(kvargs)
return

def _init_config(self):
super()._init_config()
return

def _verify_params(self):
return super()._verify_params()


    def _init_mem_manager(self):
        self.mem_manager = select_mem_manager_class(self.mode, "deepseek2")(
            self.max_total_token_num,
            dtype=self.data_type,
            head_num=self.config["num_key_value_heads"] // self.world_size_,
            key_head_dim=self.config["qk_nope_head_dim"] + self.config["qk_rope_head_dim"],
            value_head_dim=self.config["qk_nope_head_dim"],
            layer_num=self.config["num_hidden_layers"],
        )
        return

def _init_custom(self):
return

    def _init_weights(self):
        self.pre_post_weight = self.pre_and_post_weight_class(
            self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode
        )
        self.trans_layers_weight = [
            self.transformer_weight_class(
                i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode
            )
            for i in range(self.config["num_hidden_layers"])
        ]
        load_hf_weights(
            self.data_type,
            weight_dir=self.weight_dir_,
            pre_post_layer=self.pre_post_weight,
            transformer_layer_list=self.trans_layers_weight,
            weight_dict=self.weight_dict,
        )
        self.pre_post_weight.verify_load()
        [weight.verify_load() for weight in self.trans_layers_weight]
        return
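
_init_mem_manager sizes the KV cache with key_head_dim = qk_nope_head_dim + qk_rope_head_dim and value_head_dim = qk_nope_head_dim. A back-of-the-envelope sketch (not part of this commit) of the resulting per-token cache footprint on one tensor-parallel rank; every config number below is an assumed placeholder, not read from this diff.

# Illustrative per-token KV-cache footprint for one tensor-parallel rank.
qk_nope_head_dim = 128      # assumed
qk_rope_head_dim = 64       # assumed
num_key_value_heads = 128   # assumed
world_size = 8              # assumed TP degree
num_hidden_layers = 60      # assumed
bytes_per_elem = 2          # fp16 / bf16

head_num = num_key_value_heads // world_size          # KV heads kept on this rank
key_head_dim = qk_nope_head_dim + qk_rope_head_dim    # 192
value_head_dim = qk_nope_head_dim                     # 128

per_token_bytes = head_num * (key_head_dim + value_head_dim) * bytes_per_elem * num_hidden_layers
print(f"{per_token_bytes / 1024:.1f} KiB per cached token on this rank")   # 600.0 KiB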
