Commit: modify all models

baishihao committed Nov 4, 2024
1 parent ed40f35 commit cd92bf6
Showing 26 changed files with 259 additions and 327 deletions.
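
Every hunk in this commit follows the same pattern: direct torch.mm / torch.addmm calls and manual self._cuda(weight.transpose(0, 1)) weight handling are replaced by a matmul operator attached to the layer weights (layer_weight.mm_op), which owns weight preprocessing at load time (preprocess_weight) and the forward matmul (apply, with optional bias and out arguments). The snippet below is only a rough sketch of the interface these call sites assume, not the repository's implementation; a real operator could also quantize or repack the weight inside preprocess_weight.

import torch


class DefaultMMOp:
    # Minimal sketch of the operator interface assumed by this commit's call
    # sites; the class name and behavior here are illustrative, not lightllm's code.

    def preprocess_weight(self, weight: torch.Tensor) -> torch.Tensor:
        # Checkpoints store Linear weights as [out_features, in_features];
        # keep a matmul-ready [in_features, out_features] copy on the GPU
        # (or CPU when CUDA is unavailable, so the sketch stays runnable).
        device = "cuda" if torch.cuda.is_available() else "cpu"
        return weight.transpose(0, 1).contiguous().to(device)

    def apply(self, input, weight, bias=None, out=None):
        # Flatten the token dimension, compute input @ weight (optionally into
        # a caller-provided buffer such as a kv-cache view), then add the bias.
        input = input.view(-1, weight.shape[0])
        result = torch.mm(input, weight, out=out)
        if bias is not None:
            result.add_(bias)
        return result

With such an operator, a call like layer_weight.mm_op.apply(input, layer_weight.q_weight_) reproduces the old torch.mm(input.view(-1, embed_dim), q_weight_) behavior, while a quantized backend can swap in its own preprocess_weight and apply without touching the model code.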
@@ -26,10 +26,10 @@ def _bind_func(self):
         return
 
     def _get_qkv(self, input, cache_kv, infer_state, layer_weight: BaiChuan13bTransformerLayerWeight) -> torch.Tensor:
-        q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_)
-        torch.mm(
-            input.view(-1, self.embed_dim_),
+        q = layer_weight.mm_op.apply(input, layer_weight.q_weight_)
+        cache_kv = layer_weight.mm_op.apply(
+            input,
             layer_weight.kv_weight_,
             out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
-        )
+        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
         return q, cache_kv
@@ -15,14 +15,14 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
     def _get_qkv(
         self, input, cache_kv: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight
     ) -> torch.Tensor:
-        q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_).view(
-            -1, self.tp_q_head_num_, self.head_dim_
-        )
-        torch.mm(
-            input.view(-1, self.embed_dim_),
+
+        q = layer_weight.mm_op.apply(input, layer_weight.q_weight_)
+        cache_kv = layer_weight.mm_op.apply(
+            input,
             layer_weight.kv_weight_,
             out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
-        )
+        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
+
         q_ = q.float()
         cache_k_ = cache_kv[:, 0 : self.tp_k_head_num_, :].float()
         rotary_emb_fwd(q_, cache_k_, infer_state.position_cos, infer_state.position_sin)
@@ -22,19 +22,16 @@ def _load_qkvo_weights(self, weights):
             qkv_weights = weights[f"model.layers.{self.layer_num_}.self_attn.W_pack.weight"]
             split_size = qkv_weights.shape[0] // 3
             q_weights, k_weights, v_weights = torch.split(qkv_weights, split_size, dim=0)
-            self.q_weight_ = q_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :]
-            self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1))
-            k_weight_ = k_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :]
-            self.k_weight_ = k_weight_.transpose(0, 1)
-            v_weight_ = v_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :]
-            self.v_weight_ = v_weight_.transpose(0, 1)
-
-            self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1)
+            q_weight_ = q_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :]
+            self.q_weight_ = self.mm_op.preprocess_weight(q_weight_)
+            self.k_weight_ = k_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :]
+            self.v_weight_ = v_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :]
+
+            self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=0, handle_func=self.mm_op.preprocess_weight)
 
         # attention output dense params
         if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights:
-            self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"][
+            o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"][
                 :, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)
             ]
-            self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1))
+            self.o_weight_ = self.mm_op.preprocess_weight(o_weight_)
         return
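
The cat_dim change above reflects that weights now stay in the checkpoint's [out_features, in_features] layout until preprocess_weight runs, so the k and v slices are concatenated along dim 0 (output features) rather than dim 1, and the combined kv weight is preprocessed once through the new handle_func argument. A hypothetical sketch of what a _try_cat_to helper with handle_func could look like (the actual helper lives in the shared base weight class and may differ):

import torch


# Written as a standalone function for illustration; in the repo this would be
# a method on the layer-weight base class.
def _try_cat_to(self, source_names, dest_name, cat_dim, handle_func=None):
    # Once every source attribute has been loaded, concatenate them, optionally
    # post-process the result (for example with mm_op.preprocess_weight),
    # store it under dest_name, and drop the per-part attributes.
    if all(hasattr(self, name) for name in source_names):
        combined = torch.cat([getattr(self, name) for name in source_names], dim=cat_dim)
        if handle_func is not None:
            combined = handle_func(combined)
        setattr(self, dest_name, combined)
        for name in source_names:
            delattr(self, name)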
35 changes: 20 additions & 15 deletions lightllm/models/bloom/layer_infer/transformer_layer_infer.py
@@ -46,17 +46,17 @@ def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight: BloomTrans
     def _get_qkv(
         self, input, cache_kv, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight
     ) -> torch.Tensor:
-        q = torch.addmm(
-            layer_weight.q_bias_, input.view(-1, self.embed_dim_), layer_weight.q_weight_, beta=1.0, alpha=1.0
+        q = layer_weight.mm_op.apply(
+            input,
+            layer_weight.q_weight_,
+            bias=layer_weight.q_bias_,
         )
-        torch.addmm(
-            layer_weight.kv_bias_,
-            input.view(-1, self.embed_dim_),
+        cache_kv = layer_weight.mm_op.apply(
+            input,
             layer_weight.kv_weight_,
-            beta=1.0,
-            alpha=1.0,
+            bias=layer_weight.kv_bias_,
             out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
-        )
+        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
         return q, cache_kv
 
     def _context_attention_kernel(
@@ -103,21 +103,26 @@ def _token_attention_kernel(
         return o_tensor
 
     def _get_o(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor:
-        o = torch.addmm(
-            layer_weight.o_bias_,
-            input.view(-1, self.tp_q_head_num_ * self.head_dim_),
+        o = layer_weight.mm_op.apply(
+            input,
             layer_weight.o_weight_,
-            beta=1.0 / self.world_size_,
+            bias=layer_weight.o_bias_,
         )
         return o
 
     def _ffn(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor:
-        ffn1_out = torch.addmm(layer_weight.ffn_1_bias_, input.view(-1, self.embed_dim_), layer_weight.ffn_1_weight_)
+        ffn1_out = layer_weight.mm_op.apply(
+            input.view(-1, self.embed_dim_),
+            layer_weight.ffn_1_weight_,
+            bias=layer_weight.ffn_1_bias_,
+        )
         input = None
         gelu_out = torch.nn.functional.gelu(ffn1_out, approximate="tanh")
         ffn1_out = None
-        ffn2_out = torch.addmm(
-            layer_weight.ffn_2_bias_, gelu_out, layer_weight.ffn_2_weight_, beta=1.0 / self.world_size_
+        ffn2_out = layer_weight.mm_op.apply(
+            gelu_out,
+            layer_weight.ffn_2_weight_,
+            bias=layer_weight.ffn_2_bias_,
         )
         gelu_out = None
         return ffn2_out
46 changes: 19 additions & 27 deletions lightllm/models/bloom/layer_weights/transformer_layer_weight.py
@@ -61,23 +61,19 @@ def _load_qkvo_weights(self, weights):
             att_qkv_dense_weight = weights[f"h.{self.layer_num_}.self_attention.query_key_value.weight"].reshape(
                 head_num, 3, -1, n_embed
             )
-            self.q_weight_ = self._cuda(
-                att_qkv_dense_weight[:, 0, :, :]
-                .reshape(-1, n_embed)[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :]
-                .transpose(0, 1)
-            )
-            self.k_weight_ = (
-                att_qkv_dense_weight[:, 1, :, :]
-                .reshape(-1, n_embed)[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :]
-                .transpose(0, 1)
-            )
-            self.v_weight_ = (
-                att_qkv_dense_weight[:, 2, :, :]
-                .reshape(-1, n_embed)[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :]
-                .transpose(0, 1)
+            self.q_weight_ = self.mm_op.preprocess_weight(
+                att_qkv_dense_weight[:, 0, :, :].reshape(-1, n_embed)[
+                    split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :
+                ]
             )
+            self.k_weight_ = att_qkv_dense_weight[:, 1, :, :].reshape(-1, n_embed)[
+                split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :
+            ]
+            self.v_weight_ = att_qkv_dense_weight[:, 2, :, :].reshape(-1, n_embed)[
+                split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :
+            ]
 
-            self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1)
+            self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=0, handle_func=self.mm_op.preprocess_weight)
 
         if f"h.{self.layer_num_}.self_attention.query_key_value.bias" in weights:
             n_embed = self.network_config_["n_embed"]
@@ -103,13 +99,13 @@ def _load_qkvo_weights(self, weights):
         if f"h.{self.layer_num_}.self_attention.dense.weight" in weights:
             n_embed = self.network_config_["n_embed"]
             split_n_embed = n_embed // self.world_size_
-            self.o_weight_ = self._cuda(
+            self.o_weight_ = self.mm_op.preprocess_weight(
                 weights[f"h.{self.layer_num_}.self_attention.dense.weight"][
                     :, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)
-                ].transpose(0, 1)
+                ]
             )
         if f"h.{self.layer_num_}.self_attention.dense.bias" in weights:
-            self.o_bias_ = self._cuda(weights[f"h.{self.layer_num_}.self_attention.dense.bias"])
+            self.o_bias_ = self._cuda(weights[f"h.{self.layer_num_}.self_attention.dense.bias"]) / self.world_size_
         return
 
     def _load_ffn_weights(self, weights):
@@ -123,10 +119,8 @@ def _load_ffn_weights(self, weights):
             n_embed = self.network_config_["n_embed"] * 4
             split_n_embed = n_embed // self.world_size_
             self.ffn_1_weight_ = weights[f"h.{self.layer_num_}.mlp.dense_h_to_4h.weight"]
-            self.ffn_1_weight_ = self._cuda(
-                self.ffn_1_weight_[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :].transpose(
-                    0, 1
-                )
+            self.ffn_1_weight_ = self.mm_op.preprocess_weight(
+                self.ffn_1_weight_[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :]
             )
 
         if f"h.{self.layer_num_}.mlp.dense_h_to_4h.bias" in weights:
@@ -141,14 +135,12 @@ def _load_ffn_weights(self, weights):
             n_embed = self.network_config_["n_embed"] * 4
             split_n_embed = n_embed // self.world_size_
             self.ffn_2_weight_ = weights[f"h.{self.layer_num_}.mlp.dense_4h_to_h.weight"]
-            self.ffn_2_weight_ = self._cuda(
-                self.ffn_2_weight_[:, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)].transpose(
-                    0, 1
-                )
+            self.ffn_2_weight_ = self.mm_op.preprocess_weight(
+                self.ffn_2_weight_[:, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)]
             )
 
         if f"h.{self.layer_num_}.mlp.dense_4h_to_h.bias" in weights:
-            self.ffn_2_bias_ = self._cuda(weights[f"h.{self.layer_num_}.mlp.dense_4h_to_h.bias"])
+            self.ffn_2_bias_ = self._cuda(weights[f"h.{self.layer_num_}.mlp.dense_4h_to_h.bias"]) / self.world_size_
         return
 
     def _generate_alibi(self, n_head, dtype=torch.float16):
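
The two bias edits above (o_bias_ and ffn_2_bias_ divided by self.world_size_ at load time) make up for dropping beta=1.0 / self.world_size_ from the old torch.addmm calls in the inference code: the output and second FFN projections are row-parallel, their per-rank partial results are summed across world_size ranks, and the bias should only be counted once in that sum. A small check of the equivalence, assuming the cross-rank reduction is a plain sum:

import torch

world_size = 4
partials = [torch.randn(2, 8) for _ in range(world_size)]  # per-rank partial projections
bias = torch.randn(8)

# Old style: each rank did addmm(bias, x, w, beta=1.0 / world_size).
old = sum(p + bias / world_size for p in partials)

# New style: the bias stored on each rank is pre-divided at load time,
# and the inference path just adds it.
pre_divided_bias = bias / world_size
new = sum(p + pre_divided_bias for p in partials)

assert torch.allclose(old, new)  # both variants add the bias exactly once overall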
24 changes: 15 additions & 9 deletions lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py
@@ -27,17 +27,17 @@ def swiglu(self, x):
     def _get_qkv(
         self, input_emb, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: ChatGLM2TransformerLayerWeight
     ):
-        q = torch.addmm(
-            layer_weight.q_bias_, input_emb.view(-1, self.embed_dim_), layer_weight.q_weight_, beta=1.0, alpha=1.0
+        q = layer_weight.mm_op.apply(
+            input_emb.view(-1, self.embed_dim_),
+            layer_weight.q_weight_,
+            bias=layer_weight.q_bias_,
         )
-        torch.addmm(
-            layer_weight.kv_bias_,
+        cache_kv = layer_weight.mm_op.apply(
             input_emb.view(-1, self.embed_dim_),
             layer_weight.kv_weight_,
-            beta=1.0,
-            alpha=1.0,
+            bias=layer_weight.kv_bias_,
             out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
-        )
+        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
         rotary_emb_fwd(
             q.view(-1, self.tp_q_head_num_, self.head_dim_),
             cache_kv[:, 0 : self.tp_k_head_num_, :],
@@ -48,8 +48,14 @@ def _get_qkv(
 
     def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: ChatGLM2TransformerLayerWeight):
 
-        ffn1_out = torch.mm(input.view(-1, self.embed_dim_), layer_weight.gate_up_proj)
+        ffn1_out = layer_weight.mm_op.apply(
+            input.view(-1, self.embed_dim_),
+            layer_weight.gate_up_proj,
+        )
         act_out = self.swiglu(ffn1_out)
         ffn1_out = None
-        ffn2_out = torch.mm(act_out, layer_weight.down_proj)
+        ffn2_out = layer_weight.mm_op.apply(
+            act_out,
+            layer_weight.down_proj,
+        )
         return ffn2_out
30 changes: 13 additions & 17 deletions lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py
@@ -39,16 +39,13 @@ def _load_qkvo_weights(self, weights):
         tp_kv_head_dim = multi_query_group_num // self.world_size_ * head_dim
         split_n_embed = n_embed // self.world_size_
         if f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.weight" in weights:
-            qkv_weight_ = (
-                weights[f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.weight"]
-                .transpose(0, 1)
-                .contiguous()
-                .to(self.data_type_)
-            )
+            qkv_weight_ = weights[
+                f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.weight"
+            ].to(self.data_type_)
             self.q_weight_ = qkv_weight_[:, :n_embed][
                 :, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)
             ]
-            self.q_weight_ = self._cuda(self.q_weight_)
+            self.q_weight_ = self.mm_op.preprocess_weight(self.q_weight_)
             k_weight_ = qkv_weight_[:, n_embed : n_embed + head_dim * multi_query_group_num]
             self.k_weight_ = k_weight_[:, tp_kv_head_dim * self.tp_rank_ : tp_kv_head_dim * (self.tp_rank_ + 1)]
 
@@ -57,7 +54,7 @@ def _load_qkvo_weights(self, weights):
             ]
             self.v_weight_ = v_weight_[:, tp_kv_head_dim * self.tp_rank_ : tp_kv_head_dim * (self.tp_rank_ + 1)]
 
-            self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1)
+            self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=0, handle_func=self.mm_op.preprocess_weight)
 
         if f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.bias" in weights:
 
@@ -79,8 +76,7 @@ def _load_qkvo_weights(self, weights):
         if f"transformer.encoder.layers.{self.layer_num_}.self_attention.dense.weight" in weights:
             self.o_weight_ = weights[f"transformer.encoder.layers.{self.layer_num_}.self_attention.dense.weight"]
             self.o_weight_ = self.o_weight_[:, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)]
-            self.o_weight_ = self.o_weight_.transpose(0, 1)
-            self.o_weight_ = self._cuda(self.o_weight_)
+            self.o_weight_ = self.mm_op.preprocess_weight(self.o_weight_)
 
     def _load_ffn_weights(self, weights):
         if f"transformer.encoder.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights:
@@ -97,21 +93,21 @@ def _load_ffn_weights(self, weights):
                 self.data_type_
             )
             gate_proj = tweights[:ffn_hidden_size, :]
-            gate_proj = gate_proj[split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :]
-            self.gate_proj = gate_proj.transpose(0, 1)
+            self.gate_proj = gate_proj[split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :]
 
             up_proj = tweights[ffn_hidden_size : 2 * ffn_hidden_size, :]
-            up_proj = up_proj[split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :]
-            self.up_proj = up_proj.transpose(0, 1)
+            self.up_proj = up_proj[split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :]
 
-            self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1)
+            self._try_cat_to(
+                ["gate_proj", "up_proj"], "gate_up_proj", cat_dim=0, handle_func=self.mm_op.preprocess_weight
+            )
 
         if f"transformer.encoder.layers.{self.layer_num_}.mlp.dense_4h_to_h.weight" in weights:
             self.down_proj = weights[f"transformer.encoder.layers.{self.layer_num_}.mlp.dense_4h_to_h.weight"].to(
                 self.data_type_
             )
             self.down_proj = self.down_proj[
                 :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)
-            ].transpose(0, 1)
-            self.down_proj = self._cuda(self.down_proj)
+            ]
+            self.down_proj = self.mm_op.preprocess_weight(self.down_proj)
         return
18 changes: 13 additions & 5 deletions lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py
@@ -18,15 +18,23 @@ class Gemma_2bTransformerLayerInfer(LlamaTransformerLayerInfer):
 
     def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
        super().__init__(layer_num, tp_rank, world_size, network_config, mode)
-        self.tp_k_head_num_ = network_config["num_key_value_heads"] # [SYM] always == 1
+        self.tp_k_head_num_ = network_config["num_key_value_heads"]  # [SYM] always == 1
         self.tp_v_head_num_ = network_config["num_key_value_heads"]
         return
 
-    def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bTransformerLayerWeight) -> torch.Tensor:
-        up_gate_out = torch.mm(input.view(-1, self.embed_dim_), layer_weight.gate_up_proj)
+    def _ffn(
+        self, input, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bTransformerLayerWeight
+    ) -> torch.Tensor:
+        up_gate_out = layer_weight.mm_op.apply(
+            input.view(-1, self.embed_dim_),
+            layer_weight.gate_up_proj,
+        )
         ffn1_out = gelu_and_mul_fwd(up_gate_out)
         input = None
         up_gate_out = None
-        ffn2_out = torch.mm(ffn1_out, layer_weight.down_proj)
+        ffn2_out = layer_weight.mm_op.apply(
+            ffn1_out,
+            layer_weight.down_proj,
+        )
         ffn1_out = None
-        return ffn2_out
+        return ffn2_out