From cd92bf6d02adeee179a4a3980692434f67164f14 Mon Sep 17 00:00:00 2001 From: baishihao Date: Mon, 4 Nov 2024 10:14:57 +0800 Subject: [PATCH] modify all models --- .../layer_infer/transformer_layer_infer.py | 8 +-- .../layer_infer/transformer_layer_infer.py | 12 ++--- .../layer_weights/transformer_layer_weight.py | 17 +++--- .../layer_infer/transformer_layer_infer.py | 35 +++++++------ .../layer_weights/transformer_layer_weight.py | 46 +++++++--------- .../layer_infer/transformer_layer_infer.py | 24 +++++---- .../layer_weights/transformer_layer_weight.py | 30 +++++------ .../layer_infer/transformer_layer_infer.py | 18 +++++-- .../layer_weights/transformer_layer_weight.py | 24 ++++----- .../layer_infer/transformer_layer_infer.py | 18 +++---- .../layer_weights/transformer_layer_weight.py | 14 +++-- .../layer_weights/transformer_layer_weight.py | 52 ++++++++----------- .../layer_infer/transformer_layer_infer.py | 17 +++--- .../layer_weights/transformer_layer_weight.py | 24 ++++----- .../layer_infer/transformer_layer_infer.py | 6 +-- .../layer_weights/transformer_layer_weight.py | 23 +++----- .../layer_infer/transformer_layer_infer.py | 12 ++--- .../layer_weights/transformer_layer_weight.py | 22 ++++---- .../layer_infer/transformer_layer_infer.py | 12 ++--- .../layer_weights/transformer_layer_weight.py | 22 ++++---- .../layer_infer/transformer_layer_infer.py | 26 +++++----- .../layer_weights/transformer_layer_weight.py | 26 ++++------ .../layer_weights/transformer_layer_weight.py | 27 +++------- .../layer_infer/transformer_layer_infer.py | 23 ++++---- .../layer_weights/transformer_layer_weight.py | 26 ++++------ .../layer_weights/transformer_layer_weight.py | 22 ++++---- 26 files changed, 259 insertions(+), 327 deletions(-) diff --git a/lightllm/models/baichuan13b/layer_infer/transformer_layer_infer.py b/lightllm/models/baichuan13b/layer_infer/transformer_layer_infer.py index 18d408a4..33451d95 100755 --- a/lightllm/models/baichuan13b/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/baichuan13b/layer_infer/transformer_layer_infer.py @@ -26,10 +26,10 @@ def _bind_func(self): return def _get_qkv(self, input, cache_kv, infer_state, layer_weight: BaiChuan13bTransformerLayerWeight) -> torch.Tensor: - q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_) - torch.mm( - input.view(-1, self.embed_dim_), + q = layer_weight.mm_op.apply(input, layer_weight.q_weight_) + cache_kv = layer_weight.mm_op.apply( + input, layer_weight.kv_weight_, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), - ) + ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) return q, cache_kv diff --git a/lightllm/models/baichuan2_7b/layer_infer/transformer_layer_infer.py b/lightllm/models/baichuan2_7b/layer_infer/transformer_layer_infer.py index c22a7f23..f07d974d 100755 --- a/lightllm/models/baichuan2_7b/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/baichuan2_7b/layer_infer/transformer_layer_infer.py @@ -15,14 +15,14 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): def _get_qkv( self, input, cache_kv: torch.Tensor, infer_state: LlamaInferStateInfo, layer_weight: LlamaTransformerLayerWeight ) -> torch.Tensor: - q = torch.mm(input.view(-1, self.embed_dim_), layer_weight.q_weight_).view( - -1, self.tp_q_head_num_, self.head_dim_ - ) - torch.mm( - input.view(-1, self.embed_dim_), + + q = layer_weight.mm_op.apply(input, layer_weight.q_weight_) + cache_kv = layer_weight.mm_op.apply( + input, 
layer_weight.kv_weight_, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), - ) + ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) + q_ = q.float() cache_k_ = cache_kv[:, 0 : self.tp_k_head_num_, :].float() rotary_emb_fwd(q_, cache_k_, infer_state.position_cos, infer_state.position_sin) diff --git a/lightllm/models/baichuan7b/layer_weights/transformer_layer_weight.py b/lightllm/models/baichuan7b/layer_weights/transformer_layer_weight.py index fe4aa736..0082155c 100644 --- a/lightllm/models/baichuan7b/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/baichuan7b/layer_weights/transformer_layer_weight.py @@ -22,19 +22,16 @@ def _load_qkvo_weights(self, weights): qkv_weights = weights[f"model.layers.{self.layer_num_}.self_attn.W_pack.weight"] split_size = qkv_weights.shape[0] // 3 q_weights, k_weights, v_weights = torch.split(qkv_weights, split_size, dim=0) - self.q_weight_ = q_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] - self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1)) - k_weight_ = k_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] - self.k_weight_ = k_weight_.transpose(0, 1) - v_weight_ = v_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] - self.v_weight_ = v_weight_.transpose(0, 1) - - self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + q_weight_ = q_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] + self.q_weight_ = self.mm_op.preprocess_weight(q_weight_) + self.k_weight_ = k_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] + self.v_weight_ = v_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=0, handle_func=self.mm_op.preprocess_weight) # attention output dense params if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: - self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"][ + o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"][ :, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1) ] - self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) + self.o_weight_ = self.mm_op.preprocess_weight(o_weight_) return diff --git a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py index 3dd8654c..57b5faa8 100755 --- a/lightllm/models/bloom/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/bloom/layer_infer/transformer_layer_infer.py @@ -46,17 +46,17 @@ def _ffn_norm(self, input, infer_state: InferStateInfo, layer_weight: BloomTrans def _get_qkv( self, input, cache_kv, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight ) -> torch.Tensor: - q = torch.addmm( - layer_weight.q_bias_, input.view(-1, self.embed_dim_), layer_weight.q_weight_, beta=1.0, alpha=1.0 + q = layer_weight.mm_op.apply( + input, + layer_weight.q_weight_, + bias=layer_weight.q_bias_, ) - torch.addmm( - layer_weight.kv_bias_, - input.view(-1, self.embed_dim_), + cache_kv = layer_weight.mm_op.apply( + input, layer_weight.kv_weight_, - beta=1.0, - alpha=1.0, + bias=layer_weight.kv_bias_, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), - ) + ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) return q, cache_kv def 
_context_attention_kernel( @@ -103,21 +103,26 @@ def _token_attention_kernel( return o_tensor def _get_o(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor: - o = torch.addmm( - layer_weight.o_bias_, - input.view(-1, self.tp_q_head_num_ * self.head_dim_), + o = layer_weight.mm_op.apply( + input, layer_weight.o_weight_, - beta=1.0 / self.world_size_, + bias=layer_weight.o_bias_, ) return o def _ffn(self, input, infer_state: InferStateInfo, layer_weight: BloomTransformerLayerWeight) -> torch.Tensor: - ffn1_out = torch.addmm(layer_weight.ffn_1_bias_, input.view(-1, self.embed_dim_), layer_weight.ffn_1_weight_) + ffn1_out = layer_weight.mm_op.apply( + input.view(-1, self.embed_dim_), + layer_weight.ffn_1_weight_, + bias=layer_weight.ffn_1_bias_, + ) input = None gelu_out = torch.nn.functional.gelu(ffn1_out, approximate="tanh") ffn1_out = None - ffn2_out = torch.addmm( - layer_weight.ffn_2_bias_, gelu_out, layer_weight.ffn_2_weight_, beta=1.0 / self.world_size_ + ffn2_out = layer_weight.mm_op.apply( + gelu_out, + layer_weight.ffn_2_weight_, + bias=layer_weight.ffn_2_bias_, ) gelu_out = None return ffn2_out diff --git a/lightllm/models/bloom/layer_weights/transformer_layer_weight.py b/lightllm/models/bloom/layer_weights/transformer_layer_weight.py index a43b8a45..8d449ad6 100644 --- a/lightllm/models/bloom/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/bloom/layer_weights/transformer_layer_weight.py @@ -61,23 +61,19 @@ def _load_qkvo_weights(self, weights): att_qkv_dense_weight = weights[f"h.{self.layer_num_}.self_attention.query_key_value.weight"].reshape( head_num, 3, -1, n_embed ) - self.q_weight_ = self._cuda( - att_qkv_dense_weight[:, 0, :, :] - .reshape(-1, n_embed)[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] - .transpose(0, 1) - ) - self.k_weight_ = ( - att_qkv_dense_weight[:, 1, :, :] - .reshape(-1, n_embed)[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] - .transpose(0, 1) - ) - self.v_weight_ = ( - att_qkv_dense_weight[:, 2, :, :] - .reshape(-1, n_embed)[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] - .transpose(0, 1) + self.q_weight_ = self.mm_op.preprocess_weight( + att_qkv_dense_weight[:, 0, :, :].reshape(-1, n_embed)[ + split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : + ] ) + self.k_weight_ = att_qkv_dense_weight[:, 1, :, :].reshape(-1, n_embed)[ + split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : + ] + self.v_weight_ = att_qkv_dense_weight[:, 2, :, :].reshape(-1, n_embed)[ + split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), : + ] - self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=0, handle_func=self.mm_op.preprocess_weight) if f"h.{self.layer_num_}.self_attention.query_key_value.bias" in weights: n_embed = self.network_config_["n_embed"] @@ -103,13 +99,13 @@ def _load_qkvo_weights(self, weights): if f"h.{self.layer_num_}.self_attention.dense.weight" in weights: n_embed = self.network_config_["n_embed"] split_n_embed = n_embed // self.world_size_ - self.o_weight_ = self._cuda( + self.o_weight_ = self.mm_op.preprocess_weight( weights[f"h.{self.layer_num_}.self_attention.dense.weight"][ :, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1) - ].transpose(0, 1) + ] ) if f"h.{self.layer_num_}.self_attention.dense.bias" in weights: - self.o_bias_ = 
self._cuda(weights[f"h.{self.layer_num_}.self_attention.dense.bias"]) + self.o_bias_ = self._cuda(weights[f"h.{self.layer_num_}.self_attention.dense.bias"]) / self.world_size_ return def _load_ffn_weights(self, weights): @@ -123,10 +119,8 @@ def _load_ffn_weights(self, weights): n_embed = self.network_config_["n_embed"] * 4 split_n_embed = n_embed // self.world_size_ self.ffn_1_weight_ = weights[f"h.{self.layer_num_}.mlp.dense_h_to_4h.weight"] - self.ffn_1_weight_ = self._cuda( - self.ffn_1_weight_[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :].transpose( - 0, 1 - ) + self.ffn_1_weight_ = self.mm_op.preprocess_weight( + self.ffn_1_weight_[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] ) if f"h.{self.layer_num_}.mlp.dense_h_to_4h.bias" in weights: @@ -141,14 +135,12 @@ def _load_ffn_weights(self, weights): n_embed = self.network_config_["n_embed"] * 4 split_n_embed = n_embed // self.world_size_ self.ffn_2_weight_ = weights[f"h.{self.layer_num_}.mlp.dense_4h_to_h.weight"] - self.ffn_2_weight_ = self._cuda( - self.ffn_2_weight_[:, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)].transpose( - 0, 1 - ) + self.ffn_2_weight_ = self.mm_op.preprocess_weight( + self.ffn_2_weight_[:, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] ) if f"h.{self.layer_num_}.mlp.dense_4h_to_h.bias" in weights: - self.ffn_2_bias_ = self._cuda(weights[f"h.{self.layer_num_}.mlp.dense_4h_to_h.bias"]) + self.ffn_2_bias_ = self._cuda(weights[f"h.{self.layer_num_}.mlp.dense_4h_to_h.bias"]) / self.world_size_ return def _generate_alibi(self, n_head, dtype=torch.float16): diff --git a/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py b/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py index 60620b95..f2778a46 100755 --- a/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/chatglm2/layer_infer/transformer_layer_infer.py @@ -27,17 +27,17 @@ def swiglu(self, x): def _get_qkv( self, input_emb, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: ChatGLM2TransformerLayerWeight ): - q = torch.addmm( - layer_weight.q_bias_, input_emb.view(-1, self.embed_dim_), layer_weight.q_weight_, beta=1.0, alpha=1.0 + q = layer_weight.mm_op.apply( + input_emb.view(-1, self.embed_dim_), + layer_weight.q_weight_, + bias=layer_weight.q_bias_, ) - torch.addmm( - layer_weight.kv_bias_, + cache_kv = layer_weight.mm_op.apply( input_emb.view(-1, self.embed_dim_), layer_weight.kv_weight_, - beta=1.0, - alpha=1.0, + bias=layer_weight.kv_bias_, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), - ) + ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), cache_kv[:, 0 : self.tp_k_head_num_, :], @@ -48,8 +48,14 @@ def _get_qkv( def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: ChatGLM2TransformerLayerWeight): - ffn1_out = torch.mm(input.view(-1, self.embed_dim_), layer_weight.gate_up_proj) + ffn1_out = layer_weight.mm_op.apply( + input.view(-1, self.embed_dim_), + layer_weight.gate_up_proj, + ) act_out = self.swiglu(ffn1_out) ffn1_out = None - ffn2_out = torch.mm(act_out, layer_weight.down_proj) + ffn2_out = layer_weight.mm_op.apply( + act_out, + layer_weight.down_proj, + ) return ffn2_out diff --git a/lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py b/lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py index 19e41304..2df8dc5c 100755 
--- a/lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/chatglm2/layer_weights/transformer_layer_weight.py @@ -39,16 +39,13 @@ def _load_qkvo_weights(self, weights): tp_kv_head_dim = multi_query_group_num // self.world_size_ * head_dim split_n_embed = n_embed // self.world_size_ if f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.weight" in weights: - qkv_weight_ = ( - weights[f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.weight"] - .transpose(0, 1) - .contiguous() - .to(self.data_type_) - ) + qkv_weight_ = weights[ + f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.weight" + ].to(self.data_type_) self.q_weight_ = qkv_weight_[:, :n_embed][ :, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1) ] - self.q_weight_ = self._cuda(self.q_weight_) + self.q_weight_ = self.mm_op.preprocess_weight(self.q_weight_) k_weight_ = qkv_weight_[:, n_embed : n_embed + head_dim * multi_query_group_num] self.k_weight_ = k_weight_[:, tp_kv_head_dim * self.tp_rank_ : tp_kv_head_dim * (self.tp_rank_ + 1)] @@ -57,7 +54,7 @@ def _load_qkvo_weights(self, weights): ] self.v_weight_ = v_weight_[:, tp_kv_head_dim * self.tp_rank_ : tp_kv_head_dim * (self.tp_rank_ + 1)] - self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=0, handle_func=self.mm_op.preprocess_weight) if f"transformer.encoder.layers.{self.layer_num_}.self_attention.query_key_value.bias" in weights: @@ -79,8 +76,7 @@ def _load_qkvo_weights(self, weights): if f"transformer.encoder.layers.{self.layer_num_}.self_attention.dense.weight" in weights: self.o_weight_ = weights[f"transformer.encoder.layers.{self.layer_num_}.self_attention.dense.weight"] self.o_weight_ = self.o_weight_[:, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1)] - self.o_weight_ = self.o_weight_.transpose(0, 1) - self.o_weight_ = self._cuda(self.o_weight_) + self.o_weight_ = self.mm_op.preprocess_weight(self.o_weight_) def _load_ffn_weights(self, weights): if f"transformer.encoder.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights: @@ -97,14 +93,14 @@ def _load_ffn_weights(self, weights): self.data_type_ ) gate_proj = tweights[:ffn_hidden_size, :] - gate_proj = gate_proj[split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :] - self.gate_proj = gate_proj.transpose(0, 1) + self.gate_proj = gate_proj[split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :] up_proj = tweights[ffn_hidden_size : 2 * ffn_hidden_size, :] - up_proj = up_proj[split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :] - self.up_proj = up_proj.transpose(0, 1) + self.up_proj = up_proj[split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :] - self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1) + self._try_cat_to( + ["gate_proj", "up_proj"], "gate_up_proj", cat_dim=0, handle_func=self.mm_op.preprocess_weight + ) if f"transformer.encoder.layers.{self.layer_num_}.mlp.dense_4h_to_h.weight" in weights: self.down_proj = weights[f"transformer.encoder.layers.{self.layer_num_}.mlp.dense_4h_to_h.weight"].to( @@ -112,6 +108,6 @@ def _load_ffn_weights(self, weights): ) self.down_proj = self.down_proj[ :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) - ].transpose(0, 1) - self.down_proj = self._cuda(self.down_proj) + ] + self.down_proj = 
self.mm_op.preprocess_weight(self.down_proj) return diff --git a/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py b/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py index 1c3e6f33..847efb48 100644 --- a/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/gemma_2b/layer_infer/transformer_layer_infer.py @@ -18,15 +18,23 @@ class Gemma_2bTransformerLayerInfer(LlamaTransformerLayerInfer): def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, network_config, mode) - self.tp_k_head_num_ = network_config["num_key_value_heads"] # [SYM] always == 1 + self.tp_k_head_num_ = network_config["num_key_value_heads"] # [SYM] always == 1 self.tp_v_head_num_ = network_config["num_key_value_heads"] return - def _ffn(self, input, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bTransformerLayerWeight) -> torch.Tensor: - up_gate_out = torch.mm(input.view(-1, self.embed_dim_), layer_weight.gate_up_proj) + def _ffn( + self, input, infer_state: LlamaInferStateInfo, layer_weight: Gemma_2bTransformerLayerWeight + ) -> torch.Tensor: + up_gate_out = layer_weight.mm_op.apply( + input.view(-1, self.embed_dim_), + layer_weight.gate_up_proj, + ) ffn1_out = gelu_and_mul_fwd(up_gate_out) input = None up_gate_out = None - ffn2_out = torch.mm(ffn1_out, layer_weight.down_proj) + ffn2_out = layer_weight.mm_op.apply( + ffn1_out, + layer_weight.down_proj, + ) ffn1_out = None - return ffn2_out \ No newline at end of file + return ffn2_out diff --git a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py index 4597938e..be8e4950 100644 --- a/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/gemma_2b/layer_weights/transformer_layer_weight.py @@ -22,23 +22,21 @@ def _load_qkvo_weights(self, weights): if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights: self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"] self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :] - self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1)) + self.q_weight_ = self.mm_op.preprocess_weight(self.q_weight_) if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights: - k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] - self.k_weight_ = k_weight_.transpose(0, 1) + self.k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights: - v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] - self.v_weight_ = v_weight_.transpose(0, 1) + self.v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] # attention output dense params if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] - self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) + self.o_weight_ = self.mm_op.preprocess_weight(self.o_weight_) - self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=0, handle_func=self.mm_op.preprocess_weight)
return @@ -53,22 +51,20 @@ def _load_ffn_weights(self, weights): split_inter_size = inter_size // self.world_size_ if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights: - up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][ + self.up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][ split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : ] - self.up_proj = up_proj.transpose(0, 1) if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights: - gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][ + self.gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][ split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : ] - self.gate_proj = gate_proj.transpose(0, 1) - self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1) + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=0, handle_func=self.mm_op.preprocess_weight) if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights: self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][ :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) ] - self.down_proj = self._cuda(self.down_proj.transpose(0, 1)) - return \ No newline at end of file + self.down_proj = self.mm_op.preprocess_weight(self.down_proj) + return diff --git a/lightllm/models/internlm/layer_infer/transformer_layer_infer.py b/lightllm/models/internlm/layer_infer/transformer_layer_infer.py index bf594f97..78ceea0a 100755 --- a/lightllm/models/internlm/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/internlm/layer_infer/transformer_layer_infer.py @@ -18,16 +18,13 @@ def _get_qkv( self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeight ) -> torch.Tensor: input = input.view(-1, self.embed_dim_) - q = self.alloc_tensor((input.size(0), layer_weight.q_weight_.size(1)), dtype=input.dtype) - torch.addmm(layer_weight.q_bias_, input, layer_weight.q_weight_, beta=1.0, alpha=1.0, out=q) - torch.addmm( - layer_weight.kv_bias_, + q = layer_weight.mm_op.apply(input, layer_weight.q_weight_, bias=layer_weight.q_bias_) + cache_kv = layer_weight.mm_op.apply( input, layer_weight.kv_weight_, - beta=1.0, - alpha=1.0, + bias=layer_weight.kv_bias_, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), - ) + ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), cache_kv[:, 0 : self.tp_k_head_num_, :], @@ -40,6 +37,9 @@ def _get_o( self, input, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeight ) -> torch.Tensor: input = input.view(-1, self.tp_o_head_num_ * self.head_dim_) - o_tensor = self.alloc_tensor((input.size(0), layer_weight.o_weight_.size(1)), input.dtype) - torch.addmm(layer_weight.o_bias_, input, layer_weight.o_weight_, beta=1.0 / self.world_size_, out=o_tensor) + o_tensor = layer_weight.mm_op.apply( + input, + layer_weight.o_weight_, + bias=layer_weight.o_bias_, + ) return o_tensor diff --git a/lightllm/models/internlm/layer_weights/transformer_layer_weight.py b/lightllm/models/internlm/layer_weights/transformer_layer_weight.py index ccb09ef1..d9cd494f 100755 --- a/lightllm/models/internlm/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/internlm/layer_weights/transformer_layer_weight.py @@ -55,7 +55,7 @@ def _load_qkvo_weights(self, weights):
if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights: self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"] self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :] - self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1)) + self.q_weight_ = self.mm_op.preprocess_weight(self.q_weight_) if f"model.layers.{self.layer_num_}.self_attn.q_proj.bias" in weights: self.q_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.bias"][ q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1) @@ -63,22 +63,20 @@ def _load_qkvo_weights(self, weights): self.q_bias_ = self._cuda(self.q_bias_) if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights: k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] - k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] - self.k_weight_ = k_weight_.transpose(0, 1) + self.k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] if f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" in weights: self.k_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.bias"][ kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) ] if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights: v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] - v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] - self.v_weight_ = v_weight_.transpose(0, 1) + self.v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] if f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" in weights: self.v_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"][ kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) ] - self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=0, handle_func=self.mm_op.preprocess_weight) self._try_cat_to(["k_bias_", "v_bias_"], "kv_bias_", cat_dim=0) @@ -86,8 +84,8 @@ def _load_qkvo_weights(self, weights): if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] - self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) + self.o_weight_ = self.mm_op.preprocess_weight(self.o_weight_) if f"model.layers.{self.layer_num_}.self_attn.o_proj.bias" in weights: self.o_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.bias"] - self.o_bias_ = self._cuda(self.o_bias_) + self.o_bias_ = self._cuda(self.o_bias_) / self.world_size_ return diff --git a/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py b/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py index ef98a1ae..5ad381ee 100755 --- a/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/internlm2/layer_weights/transformer_layer_weight.py @@ -13,27 +13,27 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo def verify_load(self): errors = "weights load not ok" - - # handle internlm 20b, which has no bias, so set q k v o bias to zero - if not self.network_config_.get("bias", True): - 
for layer_type in ("q", "kv", "o"): - attr_name = f"{layer_type}_bias_" - if hasattr(self, attr_name): - continue - setattr(self, attr_name, self._cuda(torch.zeros(1))) - weights = [ self.att_norm_weight_, self.q_weight_, self.kv_weight_, self.o_weight_, - self.q_bias_, - self.kv_bias_, - self.o_bias_, self.ffn_norm_weight_, self.gate_up_proj, self.down_proj, ] + # handle internlm 20b, which has no bias, so set q k v o bias to zero + if not self.network_config_.get("bias", True): + for layer_type in ("q", "kv", "o"): + attr_name = f"{layer_type}_bias_" + if hasattr(self, attr_name): + continue + setattr(self, attr_name, None) + else: + weights.append(self.q_bias_) + weights.append(self.kv_bias_) + weights.append(self.o_bias_) + for i in range(len(weights)): assert weights[i] is not None, "index:" + str(i) + " " + errors return @@ -58,29 +58,25 @@ def _load_qkvo_weights(self, weights): q_groups = self.network_config_["num_attention_heads"] // self.network_config_["num_key_value_heads"] qkv_weight_ = qkv_weight_.reshape(self.network_config_["num_key_value_heads"], q_groups + 2, head_dim, -1) q_weight_ = qkv_weight_[:, :q_groups, :, :].reshape(-1, qkv_weight_.shape[-1]) - self.q_weight_ = self._cuda( - q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1) :].transpose(0, 1) + self.q_weight_ = self.mm_op.preprocess_weight( + q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1) :] ) k_weight_ = qkv_weight_[:, -2, :, :].reshape(-1, qkv_weight_.shape[-1]) - self.k_weight_ = k_weight_[ - kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) : - ].transpose(0, 1) + self.k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) :] v_weight_ = qkv_weight_[:, -1, :, :].reshape(-1, qkv_weight_.shape[-1]) - self.v_weight_ = v_weight_[ - kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) : - ].transpose(0, 1) + self.v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) :] - self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=0, handle_func=self.mm_op.preprocess_weight) # attention output dense params if f"model.layers.{self.layer_num_}.attention.wo.weight" in weights: self.o_weight_ = weights[f"model.layers.{self.layer_num_}.attention.wo.weight"] self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] - self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) + self.o_weight_ = self.mm_op.preprocess_weight(self.o_weight_) if f"model.layers.{self.layer_num_}.attention.wo.bias" in weights: self.o_bias_ = weights[f"model.layers.{self.layer_num_}.attention.wo.bias"] - self.o_bias_ = self._cuda(self.o_bias_) + self.o_bias_ = self._cuda(self.o_bias_) / self.world_size_ return def _load_ffn_weights(self, weights): @@ -91,22 +87,20 @@ def _load_ffn_weights(self, weights): split_inter_size = inter_size // self.world_size_ if f"model.layers.{self.layer_num_}.feed_forward.w3.weight" in weights: - up_proj = weights[f"model.layers.{self.layer_num_}.feed_forward.w3.weight"][ + self.up_proj = weights[f"model.layers.{self.layer_num_}.feed_forward.w3.weight"][ split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : ] - self.up_proj = up_proj.transpose(0, 1) if f"model.layers.{self.layer_num_}.feed_forward.w1.weight" in weights: - gate_proj = 
weights[f"model.layers.{self.layer_num_}.feed_forward.w1.weight"][ + self.gate_proj = weights[f"model.layers.{self.layer_num_}.feed_forward.w1.weight"][ split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : ] - self.gate_proj = gate_proj.transpose(0, 1) - self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1) + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=0, handle_func=self.mm_op.preprocess_weight) if f"model.layers.{self.layer_num_}.feed_forward.w2.weight" in weights: self.down_proj = weights[f"model.layers.{self.layer_num_}.feed_forward.w2.weight"][ :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) ] - self.down_proj = self._cuda(self.down_proj.transpose(0, 1)) + self.down_proj = self.mm_op.preprocess_weight(self.down_proj) return diff --git a/lightllm/models/minicpm/layer_infer/transformer_layer_infer.py b/lightllm/models/minicpm/layer_infer/transformer_layer_infer.py index 10dd0c4d..c56a0334 100755 --- a/lightllm/models/minicpm/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/minicpm/layer_infer/transformer_layer_infer.py @@ -17,17 +17,13 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): def _get_qkv( self, input, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeight ) -> torch.Tensor: - q = torch.addmm( - layer_weight.q_bias_, input.view(-1, self.embed_dim_), layer_weight.q_weight_, beta=1.0, alpha=1.0 - ) - torch.addmm( - layer_weight.kv_bias_, + q = layer_weight.mm_op.apply(input.view(-1, self.embed_dim_), layer_weight.q_weight_, bias=layer_weight.q_bias_) + cache_kv = layer_weight.mm_op.apply( input.view(-1, self.embed_dim_), layer_weight.kv_weight_, - beta=1.0, - alpha=1.0, + bias=layer_weight.kv_bias_, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), - ) + ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), cache_kv[:, 0 : self.tp_k_head_num_, :], @@ -39,10 +35,9 @@ def _get_qkv( def _get_o( self, input, infer_state: LlamaInferStateInfo, layer_weight: InternlmTransformerLayerWeight ) -> torch.Tensor: - o_tensor = torch.addmm( - layer_weight.o_bias_, + o_tensor = layer_weight.mm_op.applys( input.view(-1, self.tp_o_head_num_ * self.head_dim_), layer_weight.o_weight_, - beta=1.0 / self.world_size_, + bias=layer_weight.o_bias_, ) return o_tensor diff --git a/lightllm/models/minicpm/layer_weights/transformer_layer_weight.py b/lightllm/models/minicpm/layer_weights/transformer_layer_weight.py index 135c3738..5c115060 100755 --- a/lightllm/models/minicpm/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/minicpm/layer_weights/transformer_layer_weight.py @@ -11,7 +11,7 @@ def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mo super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode) num_hidden_layers = self.network_config_["num_hidden_layers"] scale_depth = self.network_config_.get("scale_depth", math.sqrt(num_hidden_layers)) - self.layer_scale =scale_depth / math.sqrt(num_hidden_layers) + self.layer_scale = scale_depth / math.sqrt(num_hidden_layers) return def _load_qkvo_weights(self, weights): @@ -30,25 +30,23 @@ def _load_qkvo_weights(self, weights): if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights: self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"] self.q_weight_ = 
self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :] - self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1)) + self.q_weight_ = self.mm_op.preprocess_weight(self.q_weight_) if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights: k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] - k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] - self.k_weight_ = k_weight_.transpose(0, 1) + self.k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights: v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] - v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] - self.v_weight_ = v_weight_.transpose(0, 1) + self.v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] # attention output dense params if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] - self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) * self.layer_scale + self.o_weight_ = self.mm_op.preprocess_weight(self.o_weight_) * self.layer_scale - self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=0, handle_func=self.mm_op.preprocess_weight) return @@ -62,22 +60,20 @@ def _load_ffn_weights(self, weights): split_inter_size = inter_size // self.world_size_ if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights: - up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][ + self.up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][ split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : ] - self.up_proj = up_proj.transpose(0, 1) if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights: - gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][ + self.gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][ split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : ] - self.gate_proj = gate_proj.transpose(0, 1) - self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1) + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=0, handle_func=self.mm_op.preprocess_weight) if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights: self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][ :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) ] - self.down_proj = self._cuda(self.down_proj.transpose(0, 1)) * self.layer_scale + self.down_proj = self.mm_op.preprocess_weight(self.down_proj) * self.layer_scale return diff --git a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py index 7ac3019c..a9db29a4 100755 --- a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py @@ -29,12 +29,12 @@ def _bind_attention(self): return def _get_qkv(self, input_emb, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: Phi3TransformerLayerWeight): - q = 
torch.mm(input_emb.view(-1, self.embed_dim_), layer_weight.q_weight_) - torch.mm( + q = layer_weight.mm_op.apply(input_emb.view(-1, self.embed_dim_), layer_weight.q_weight_) + cache_kv = layer_weight.mm_op.apply( input_emb.view(-1, self.embed_dim_), layer_weight.kv_weight_, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), - ) + ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), cache_kv[:, 0 : self.tp_k_head_num_, :], diff --git a/lightllm/models/phi3/layer_weights/transformer_layer_weight.py b/lightllm/models/phi3/layer_weights/transformer_layer_weight.py index 1400eaba..41fb4333 100755 --- a/lightllm/models/phi3/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/phi3/layer_weights/transformer_layer_weight.py @@ -42,29 +42,24 @@ def _load_qkvo_weights(self, weights): // self.world_size_ ) if f"model.layers.{self.layer_num_}.self_attn.qkv_proj.weight" in weights: - qkv_weight_ = ( - weights[f"model.layers.{self.layer_num_}.self_attn.qkv_proj.weight"] - .transpose(0, 1) - .contiguous() - .to(self.data_type_) - ) + qkv_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.qkv_proj.weight"].to(self.data_type_) self.q_weight_ = qkv_weight_[:, :n_embed][ :, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1) ] - self.q_weight_ = self._cuda(self.q_weight_) + self.q_weight_ = self.mm_op.preprocess_weight(self.q_weight_) k_weight_ = qkv_weight_[:, n_embed : n_embed + kv_n_embed] self.k_weight_ = k_weight_[:, kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)] v_weight_ = qkv_weight_[:, n_embed + kv_n_embed : n_embed + 2 * kv_n_embed] self.v_weight_ = v_weight_[:, kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)] - self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=0, handle_func=self.mm_op.preprocess_weight) # attention output dense params if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] - self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) + self.o_weight_ = self.mm_op.preprocess_weight(self.o_weight_) return @@ -79,21 +74,19 @@ def _load_ffn_weights(self, weights): if f"model.layers.{self.layer_num_}.mlp.gate_up_proj.weight" in weights: gate_up_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_up_proj.weight"] - gate_proj = gate_up_proj[0:inter_size][ + self.gate_proj = gate_up_proj[0:inter_size][ split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : ] - self.gate_proj = gate_proj.transpose(0, 1) - up_proj = gate_up_proj[inter_size:][ + self.up_proj = gate_up_proj[inter_size:][ split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : ] - self.up_proj = up_proj.transpose(0, 1) - self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1) + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=0, handle_func=self.mm_op.preprocess_weight) if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights: self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][ :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) ] - self.down_proj = 
self._cuda(self.down_proj.transpose(0, 1)) + self.down_proj = self.mm_op.preprocess_weight(self.down_proj) return diff --git a/lightllm/models/qwen/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen/layer_infer/transformer_layer_infer.py index a4c5313b..faa11e2f 100755 --- a/lightllm/models/qwen/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen/layer_infer/transformer_layer_infer.py @@ -17,17 +17,15 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): return def _get_qkv(self, input_emb, cache_kv, infer_state: QwenInferStateInfo, layer_weight: QwenTransformerLayerWeight): - q = torch.addmm( - layer_weight.q_bias_, input_emb.view(-1, self.embed_dim_), layer_weight.q_weight_, beta=1.0, alpha=1.0 + q = layer_weight.mm_op.apply( + input_emb.view(-1, self.embed_dim_), layer_weight.q_weight_, bias=layer_weight.q_bias_ ) - torch.addmm( - layer_weight.kv_bias_, + cache_kv = layer_weight.mm_op.apply( input_emb.view(-1, self.embed_dim_), layer_weight.kv_weight_, - beta=1.0, - alpha=1.0, + bias=layer_weight.kv_bias_, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), - ) + ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), cache_kv[:, 0 : self.tp_k_head_num_, :], diff --git a/lightllm/models/qwen/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen/layer_weights/transformer_layer_weight.py index 039d1516..c4a89342 100755 --- a/lightllm/models/qwen/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen/layer_weights/transformer_layer_weight.py @@ -22,13 +22,11 @@ def load_hf_weights(self, weights): q_weights, k_weights, v_weights = torch.split(qkv_weights, split_size, dim=0) self.q_weight_ = q_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] - self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1)) - k_weight_ = k_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] - self.k_weight_ = k_weight_.transpose(0, 1) - v_weight_ = v_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] - self.v_weight_ = v_weight_.transpose(0, 1) + self.q_weight_ = self.mm_op.preprocess_weight(self.q_weight_) + self.k_weight_ = k_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] + self.v_weight_ = v_weights[split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1), :] - self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=0, handle_func=self.mm_op.preprocess_weight) if f"transformer.h.{self.layer_num_}.attn.c_attn.bias" in weights: qkv_bias = weights[f"transformer.h.{self.layer_num_}.attn.c_attn.bias"] @@ -45,7 +43,7 @@ def load_hf_weights(self, weights): self.o_weight_ = weights[f"transformer.h.{self.layer_num_}.attn.c_proj.weight"][ :, split_n_embed * self.tp_rank_ : split_n_embed * (self.tp_rank_ + 1) ] - self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) + self.o_weight_ = self.mm_op.preprocess_weight(self.o_weight_) if f"transformer.h.{self.layer_num_}.ln_2.weight" in weights: self.ffn_norm_weight_ = self._cuda(weights[f"transformer.h.{self.layer_num_}.ln_2.weight"]) @@ -54,24 +52,22 @@ def load_hf_weights(self, weights): split_inter_size = inter_size // self.world_size_ if f"transformer.h.{self.layer_num_}.mlp.w1.weight" in weights: - up_proj = weights[f"transformer.h.{self.layer_num_}.mlp.w1.weight"][ + 
self.up_proj = weights[f"transformer.h.{self.layer_num_}.mlp.w1.weight"][ split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : ] - self.up_proj = up_proj.transpose(0, 1) if f"transformer.h.{self.layer_num_}.mlp.w2.weight" in weights: - gate_proj = weights[f"transformer.h.{self.layer_num_}.mlp.w2.weight"][ + self.gate_proj = weights[f"transformer.h.{self.layer_num_}.mlp.w2.weight"][ split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : ] - self.gate_proj = gate_proj.transpose(0, 1) if f"transformer.h.{self.layer_num_}.mlp.c_proj.weight" in weights: self.down_proj = weights[f"transformer.h.{self.layer_num_}.mlp.c_proj.weight"][ :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) ] - self.down_proj = self._cuda(self.down_proj.transpose(0, 1)) + self.down_proj = self.mm_op.preprocess_weight(self.down_proj) - self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1) + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=0, handle_func=self.mm_op.preprocess_weight) return diff --git a/lightllm/models/qwen2/layer_infer/transformer_layer_infer.py b/lightllm/models/qwen2/layer_infer/transformer_layer_infer.py index d50ec869..18add543 100644 --- a/lightllm/models/qwen2/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/qwen2/layer_infer/transformer_layer_infer.py @@ -24,17 +24,13 @@ def _get_qkv( layer_weight: Qwen2TransformerLayerWeight, ) -> torch.Tensor: input = input.view(-1, self.embed_dim_) - dtype = input.dtype - q = self.alloc_tensor((input.shape[0], layer_weight.q_weight_.shape[1]), dtype=dtype) - torch.addmm(layer_weight.q_bias_, input, layer_weight.q_weight_, beta=1.0, alpha=1.0, out=q) - torch.addmm( - layer_weight.kv_bias_, + q = layer_weight.mm_op.apply(input, layer_weight.q_weight_, bias=layer_weight.q_bias_) + cache_kv = torch.addmm( input, layer_weight.kv_weight_, - beta=1.0, - alpha=1.0, + bias=layer_weight.kv_bias_, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), - ) + ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), cache_kv[:, 0 : self.tp_k_head_num_, :], diff --git a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py index 06eb86ff..fe403104 100644 --- a/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/qwen2/layer_weights/transformer_layer_weight.py @@ -26,23 +26,21 @@ def _load_qkvo_weights(self, weights): if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights: self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"] self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :] - self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1)) + self.q_weight_ = self.mm_op.preprocess_weight(self.q_weight_) if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights: k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] - k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] - self.k_weight_ = k_weight_.transpose(0, 1) + self.k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights: v_weight_ = 
weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] - v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] - self.v_weight_ = v_weight_.transpose(0, 1) + self.v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] # attention output dense params if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] - self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) + self.o_weight_ = self.mm_op.preprocess_weight(self.o_weight_) # q k v bias if f"model.layers.{self.layer_num_}.self_attn.q_proj.bias" in weights: @@ -57,7 +55,7 @@ def _load_qkvo_weights(self, weights): v_bias = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"] self.v_bias_ = v_bias[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)] - self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=0, handle_func=self.mm_op.preprocess_weight) self._try_cat_to(["k_bias_", "v_bias_"], "kv_bias_", cat_dim=0) def _load_ffn_weights(self, weights): @@ -70,24 +68,22 @@ def _load_ffn_weights(self, weights): split_inter_size = inter_size // self.world_size_ if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights: - up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][ + self.up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][ split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : ] - self.up_proj = up_proj.transpose(0, 1) if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights: - gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][ + self.gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][ split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : ] - self.gate_proj = gate_proj.transpose(0, 1) - self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1) + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=0, handle_func=self.mm_op.preprocess_weight) if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights: self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][ :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) ] - self.down_proj = self._cuda(self.down_proj.transpose(0, 1)) + self.down_proj = self.mm_op.preprocess_weight(self.down_proj) return def load_hf_weights(self, weights): diff --git a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py index 1651d8fe..c0a70dfb 100755 --- a/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py +++ b/lightllm/models/stablelm/layer_infer/transformer_layer_infer.py @@ -16,7 +16,7 @@ def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): super().__init__(layer_num, tp_rank, world_size, network_config, mode) self.partial_rotary_factor = self.network_config_.get("partial_rotary_factor", 1) return - + def _bind_norm(self): self._att_norm = partial(StablelmTransformerLayerInfer._att_norm, self) self._ffn_norm = partial(StablelmTransformerLayerInfer._ffn_norm, self) @@ -25,17 +25,15 @@ def _bind_norm(self): def _get_qkv( self, input, 
cache_kv, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight ) -> torch.Tensor: - q = torch.addmm( - layer_weight.q_bias_, input.view(-1, self.embed_dim_), layer_weight.q_weight_, beta=1.0, alpha=1.0 + q = layer_weight.mm_op.apply( + input.view(-1, self.embed_dim_), layer_weight.q_weight_, bias=layer_weight.q_bias_ ) - torch.addmm( - layer_weight.kv_bias_, + cache_kv = layer_weight.mm_op.apply( input.view(-1, self.embed_dim_), layer_weight.kv_weight_, - beta=1.0, - alpha=1.0, + bias=layer_weight.kv_bias_, out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), - ) + ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_) rotary_emb_fwd( q.view(-1, self.tp_q_head_num_, self.head_dim_), cache_kv[:, 0 : self.tp_k_head_num_, :], @@ -48,13 +46,15 @@ def _get_qkv( def _get_o( self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight ) -> torch.Tensor: - o_tensor = torch.mm( + o_tensor = layer_weight.mm_op.apply( input.view(-1, self.tp_o_head_num_ * self.head_dim_), layer_weight.o_weight_, ) return o_tensor - def _att_norm(self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight) -> torch.Tensor: + def _att_norm( + self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight + ) -> torch.Tensor: return layernorm_forward( input.view(-1, self.embed_dim_), weight=layer_weight.att_norm_weight_, @@ -62,10 +62,12 @@ def _att_norm(self, input, infer_state: LlamaInferStateInfo, layer_weight: Stabl eps=self.eps_, ) - def _ffn_norm(self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight) -> torch.Tensor: + def _ffn_norm( + self, input, infer_state: LlamaInferStateInfo, layer_weight: StablelmTransformerLayerWeight + ) -> torch.Tensor: return layernorm_forward( input.view(-1, self.embed_dim_), weight=layer_weight.ffn_norm_weight_, bias=layer_weight.ffn_norm_bias_, eps=self.eps_, - ) \ No newline at end of file + ) diff --git a/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py b/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py index 8bb4c7a5..0a4e5529 100755 --- a/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/stablelm/layer_weights/transformer_layer_weight.py @@ -58,7 +58,7 @@ def _load_qkvo_weights(self, weights): if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights: self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"] self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :] - self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1)) + self.q_weight_ = self.mm_op.preprocess_weight(self.q_weight_) if f"model.layers.{self.layer_num_}.self_attn.q_proj.bias" in weights: self.q_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.bias"][ q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1) @@ -66,22 +66,20 @@ def _load_qkvo_weights(self, weights): self.q_bias_ = self._cuda(self.q_bias_) if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights: k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] - k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] - self.k_weight_ = k_weight_.transpose(0, 1) + self.k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] if
f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" in weights: self.k_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.bias"][ kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) ] if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights: v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] - v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] - self.v_weight_ = v_weight_.transpose(0, 1) + self.v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] if f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" in weights: self.v_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"][ kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) ] - self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=0, handle_func=self.mm_op.preprocess_weight) self._try_cat_to(["k_bias_", "v_bias_"], "kv_bias_", cat_dim=0) @@ -89,7 +87,7 @@ def _load_qkvo_weights(self, weights): if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] - self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) + self.o_weight_ = self.mm_op.preprocess_weight(self.o_weight_) return @@ -99,30 +97,26 @@ def _load_ffn_weights(self, weights): weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"] ) if f"model.layers.{self.layer_num_}.post_attention_layernorm.bias" in weights: - self.ffn_norm_bias_ = self._cuda( - weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.bias"] - ) + self.ffn_norm_bias_ = self._cuda(weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.bias"]) inter_size = self.network_config_["intermediate_size"] split_inter_size = inter_size // self.world_size_ if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights: - up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][ + self.up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][ split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : ] - self.up_proj = up_proj.transpose(0, 1) if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights: - gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][ + self.gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][ split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : ] - self.gate_proj = gate_proj.transpose(0, 1) - self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1) + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=0, handle_func=self.mm_op.preprocess_weight) if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights: self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][ :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) ] - self.down_proj = self._cuda(self.down_proj.transpose(0, 1)) + self.down_proj = self.mm_op.preprocess_weight(self.down_proj) return diff --git a/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py b/lightllm/models/starcoder/layer_weights/transformer_layer_weight.py index 913e3e6c..37960c96 100644 --- 
diff --git a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py
index 44ed85bc..ea7aee08 100644
--- a/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/starcoder2/layer_infer/transformer_layer_infer.py
@@ -48,13 +48,13 @@ def _ffn_norm(
     def _get_qkv(
         self, input, cache_kv, infer_state: MistralInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight
     ) -> torch.Tensor:
-        q = torch.addmm(layer_weight.q_bias_, input.view(-1, self.embed_dim_), layer_weight.q_weight_)
-        torch.addmm(
-            layer_weight.kv_bias_,
+        q = layer_weight.mm_op.apply(input.view(-1, self.embed_dim_), layer_weight.q_weight_, bias=layer_weight.q_bias_)
+        cache_kv = layer_weight.mm_op.apply(
             input.view(-1, self.embed_dim_),
             layer_weight.kv_weight_,
+            bias=layer_weight.kv_bias_,
             out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
-        )
+        ).view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_), self.head_dim_)
         rotary_emb_fwd(
             q.view(-1, self.tp_q_head_num_, self.head_dim_),
             cache_kv[:, 0 : self.tp_k_head_num_, :],
@@ -66,24 +66,21 @@ def _get_qkv(

     def _get_o(
         self, input, infer_state: MistralInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight
     ) -> torch.Tensor:
-        o_tensor = torch.addmm(
-            layer_weight.o_bias_,
-            input.view(-1, self.tp_o_head_num_ * self.head_dim_),
-            layer_weight.o_weight_,
-            beta=1.0 / self.world_size_,
+        o_tensor = layer_weight.mm_op.apply(
+            input.view(-1, self.tp_o_head_num_ * self.head_dim_), layer_weight.o_weight_, bias=layer_weight.o_bias_
         )
         return o_tensor

     def _ffn(
         self, input, infer_state: MistralInferStateInfo, layer_weight: Starcoder2TransformerLayerWeight
     ) -> torch.Tensor:
-        ffn1_out = torch.addmm(layer_weight.ffn_1_bias_, input.view(-1, self.embed_dim_), layer_weight.ffn_1_weight_)
+        ffn1_out = layer_weight.mm_op.apply(
+            input.view(-1, self.embed_dim_), layer_weight.ffn_1_weight_, bias=layer_weight.ffn_1_bias_
+        )
         input = None
         gelu_out = torch.nn.functional.gelu(ffn1_out, approximate="tanh")
         ffn1_out = None
-        ffn2_out = torch.addmm(
-            layer_weight.ffn_2_bias_, gelu_out, layer_weight.ffn_2_weight_, beta=1.0 / self.world_size_
-        )
+        ffn2_out = layer_weight.mm_op.apply(gelu_out, layer_weight.ffn_2_weight_, bias=layer_weight.ffn_2_bias_)
         gelu_out = None
         return ffn2_out
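Note: in _get_qkv the fused K/V projection now writes straight into the pre-allocated KV buffer through out=, and the returned tensor is only a re-view of that buffer, so no extra copy is made. A shape-only sketch of what the two view(...) calls assume (all sizes below are made up for illustration):

import torch

tokens, embed_dim, head_dim = 5, 128, 64
tp_k_head_num_, tp_v_head_num_ = 2, 2

cache_kv = torch.empty(tokens, tp_k_head_num_ + tp_v_head_num_, head_dim)
flat = cache_kv.view(-1, (tp_k_head_num_ + tp_v_head_num_) * head_dim)   # target of out=

hidden = torch.randn(tokens, embed_dim)
kv_weight = torch.randn(embed_dim, (tp_k_head_num_ + tp_v_head_num_) * head_dim)
torch.mm(hidden, kv_weight, out=flat)      # fills the cache buffer in place

k = cache_kv[:, :tp_k_head_num_, :]        # leading heads are K
v = cache_kv[:, tp_k_head_num_:, :]        # remaining heads are V
assert k.shape == (tokens, tp_k_head_num_, head_dim) and v.shape == k.shape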
diff --git a/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py b/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py
index edc94f71..8ff12c54 100644
--- a/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py
+++ b/lightllm/models/starcoder2/layer_weights/transformer_layer_weight.py
@@ -56,7 +56,7 @@ def _load_qkvo_weights(self, weights):
         if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights:
             self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"]
             self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :]
-            self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1))
+            self.q_weight_ = self.mm_op.preprocess_weight(self.q_weight_)
         if f"model.layers.{self.layer_num_}.self_attn.q_proj.bias" in weights:
             self.q_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.bias"][
                 q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)
@@ -64,22 +64,20 @@ def _load_qkvo_weights(self, weights):
             self.q_bias_ = self._cuda(self.q_bias_)
         if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights:
             k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"]
-            k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
-            self.k_weight_ = k_weight_.transpose(0, 1)
+            self.k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
         if f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" in weights:
             self.k_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.bias"][
                 kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)
             ]
         if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights:
             v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"]
-            v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
-            self.v_weight_ = v_weight_.transpose(0, 1)
+            self.v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :]
         if f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" in weights:
             self.v_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"][
                 kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)
             ]

-        self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1)
+        self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=0, handle_func=self.mm_op.preprocess_weight)

         self._try_cat_to(["k_bias_", "v_bias_"], "kv_bias_", cat_dim=0)

@@ -87,10 +85,10 @@ def _load_qkvo_weights(self, weights):
         if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights:
             self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"]
             self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)]
-            self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1))
+            self.o_weight_ = self.mm_op.preprocess_weight(self.o_weight_)
         if f"model.layers.{self.layer_num_}.self_attn.o_proj.bias" in weights:
             self.o_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.bias"]
-            self.o_bias_ = self._cuda(self.o_bias_)
+            self.o_bias_ = self._cuda(self.o_bias_) / self.world_size_
         return

     def _load_ffn_weights(self, weights):
@@ -107,11 +105,8 @@ def _load_ffn_weights(self, weights):
         split_inter_size = intermediate_size // self.world_size_
         if f"model.layers.{self.layer_num_}.mlp.c_fc.weight" in weights:
             self.ffn_1_weight_ = weights[f"model.layers.{self.layer_num_}.mlp.c_fc.weight"].to(self.data_type_)
-            self.ffn_1_weight_ = (
+            self.ffn_1_weight_ = self.mm_op.preprocess_weight(
                 self.ffn_1_weight_[split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :]
-                .transpose(0, 1)
-                .contiguous()
-                .cuda()
             )

         if f"model.layers.{self.layer_num_}.mlp.c_fc.bias" in weights:
@@ -126,16 +121,13 @@ def _load_ffn_weights(self, weights):

         if f"model.layers.{self.layer_num_}.mlp.c_proj.weight" in weights:
             self.ffn_2_weight_ = weights[f"model.layers.{self.layer_num_}.mlp.c_proj.weight"].to(self.data_type_)
-            self.ffn_2_weight_ = (
+            self.ffn_2_weight_ = self.mm_op.preprocess_weight(
                 self.ffn_2_weight_[:, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)]
-                .transpose(0, 1)
-                .contiguous()
-                .cuda()
             )

         if f"model.layers.{self.layer_num_}.mlp.c_proj.bias" in weights:
             self.ffn_2_bias_ = (
                 weights[f"model.layers.{self.layer_num_}.mlp.c_proj.bias"].to(self.data_type_).contiguous().cuda()
-            )
+            ) / self.world_size_
         return
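Note: _try_cat_to is defined on the base weight class elsewhere in lightllm. These call sites imply that it fuses the listed member tensors into one attribute, now concatenating along dim 0 (output features, since the per-tensor transposes were removed) and passing the result through mm_op.preprocess_weight via handle_func. Roughly the behavior the call sites assume (a sketch, not the actual helper):

import torch

def _try_cat_to(self, src_names, dst_name, cat_dim, handle_func=None):
    # Fuse only when every source tensor was actually loaded for this layer.
    if all(hasattr(self, name) for name in src_names):
        fused = torch.cat([getattr(self, name) for name in src_names], dim=cat_dim)
        if handle_func is not None:
            fused = handle_func(fused)      # e.g. self.mm_op.preprocess_weight
        setattr(self, dst_name, fused)
        for name in src_names:
            delattr(self, name)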
f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" in weights: self.k_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.bias"][ kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) ] if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights: v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] - v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] - self.v_weight_ = v_weight_.transpose(0, 1) + self.v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] if f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" in weights: self.v_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"][ kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1) ] - self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=0, handle_func=self.mm_op.preprocess_weight) self._try_cat_to(["k_bias_", "v_bias_"], "kv_bias_", cat_dim=0) @@ -87,10 +85,10 @@ def _load_qkvo_weights(self, weights): if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] - self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) + self.o_weight_ = self.mm_op.preprocess_weight(self.o_weight_) if f"model.layers.{self.layer_num_}.self_attn.o_proj.bias" in weights: self.o_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.bias"] - self.o_bias_ = self._cuda(self.o_bias_) + self.o_bias_ = self._cuda(self.o_bias_) / self.world_size_ return def _load_ffn_weights(self, weights): @@ -107,11 +105,8 @@ def _load_ffn_weights(self, weights): split_inter_size = intermediate_size // self.world_size_ if f"model.layers.{self.layer_num_}.mlp.c_fc.weight" in weights: self.ffn_1_weight_ = weights[f"model.layers.{self.layer_num_}.mlp.c_fc.weight"].to(self.data_type_) - self.ffn_1_weight_ = ( + self.ffn_1_weight_ = self.mm_op.preprocess_weight( self.ffn_1_weight_[split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), :] - .transpose(0, 1) - .contiguous() - .cuda() ) if f"model.layers.{self.layer_num_}.mlp.c_fc.bias" in weights: @@ -126,16 +121,13 @@ def _load_ffn_weights(self, weights): if f"model.layers.{self.layer_num_}.mlp.c_proj.weight" in weights: self.ffn_2_weight_ = weights[f"model.layers.{self.layer_num_}.mlp.c_proj.weight"].to(self.data_type_) - self.ffn_2_weight_ = ( + self.ffn_2_weight_ = self.mm_op.preprocess_weight( self.ffn_2_weight_[:, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1)] - .transpose(0, 1) - .contiguous() - .cuda() ) if f"model.layers.{self.layer_num_}.mlp.c_proj.bias" in weights: self.ffn_2_bias_ = ( weights[f"model.layers.{self.layer_num_}.mlp.c_proj.bias"].to(self.data_type_).contiguous().cuda() - ) + ) / self.world_size_ return diff --git a/lightllm/models/yi/layer_weights/transformer_layer_weight.py b/lightllm/models/yi/layer_weights/transformer_layer_weight.py index c22c1997..6de8e50b 100644 --- a/lightllm/models/yi/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/yi/layer_weights/transformer_layer_weight.py @@ -27,25 +27,23 @@ def _load_qkvo_weights(self, weights): if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights: self.q_weight_ = 
weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"] self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1), :] - self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1)) + self.q_weight_ = self.mm_op.preprocess_weight(self.q_weight_) if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights: k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] - k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] - self.k_weight_ = k_weight_.transpose(0, 1) + self.k_weight_ = k_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights: v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] - v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] - self.v_weight_ = v_weight_.transpose(0, 1) + self.v_weight_ = v_weight_[kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1), :] - self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=0, handle_func=self.mm_op.preprocess_weight) # attention output dense params if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] - self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) + self.o_weight_ = self.mm_op.preprocess_weight(self.o_weight_) return def _load_ffn_weights(self, weights): @@ -56,22 +54,20 @@ def _load_ffn_weights(self, weights): split_inter_size = inter_size // self.world_size_ if f"model.layers.{self.layer_num_}.mlp.up_proj.weight" in weights: - up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][ + self.up_proj = weights[f"model.layers.{self.layer_num_}.mlp.up_proj.weight"][ split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : ] - self.up_proj = up_proj.transpose(0, 1) if f"model.layers.{self.layer_num_}.mlp.gate_proj.weight" in weights: - gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][ + self.gate_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_proj.weight"][ split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : ] - self.gate_proj = gate_proj.transpose(0, 1) - self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1) + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=0, handle_func=self.mm_op.preprocess_weight) if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights: self.down_proj = weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][ :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) ] - self.down_proj = self._cuda(self.down_proj.transpose(0, 1)) + self.down_proj = self.mm_op.preprocess_weight(self.down_proj) return