diff --git a/lightllm/models/internlm/layer_weights/transformer_layer_weight.py b/lightllm/models/internlm/layer_weights/transformer_layer_weight.py index 2fa42ece..5169c99f 100644 --- a/lightllm/models/internlm/layer_weights/transformer_layer_weight.py +++ b/lightllm/models/internlm/layer_weights/transformer_layer_weight.py @@ -45,36 +45,37 @@ def _load_qkvo_weights(self, weights): self.att_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.input_layernorm.weight"]) n_embed = self.network_config_["hidden_size"] - split_n_embed = n_embed // self.world_size_ + q_split_n_embed = n_embed // self.world_size_ + kv_split_n_embed = n_embed // self.network_config_["num_attention_heads"] * self.network_config_["num_key_value_heads"] // self.world_size_ # q k v weights for llama if f"model.layers.{self.layer_num_}.self_attn.q_proj.weight" in weights: - self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"][split_n_embed * - self.tp_rank_: split_n_embed * (self.tp_rank_ + 1), :] + self.q_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.weight"] + self.q_weight_ = self.q_weight_[q_split_n_embed * self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1), :] self.q_weight_ = self._cuda(self.q_weight_.transpose(0, 1)) if f"model.layers.{self.layer_num_}.self_attn.q_proj.bias" in weights: - self.q_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.bias"][split_n_embed * - self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] + self.q_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.q_proj.bias"][q_split_n_embed * + self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1)] self.q_bias_ = self._cuda(self.q_bias_) if f"model.layers.{self.layer_num_}.self_attn.k_proj.weight" in weights: - self.k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"][split_n_embed * - self.tp_rank_: split_n_embed * (self.tp_rank_ + 1), :] + self.k_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.weight"] + self.k_weight_ = self.k_weight_[kv_split_n_embed * self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1), :] self.k_weight_ = self._cuda(self.k_weight_.transpose(0, 1)) if f"model.layers.{self.layer_num_}.self_attn.k_proj.bias" in weights: - self.k_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.bias"][split_n_embed * - self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] + self.k_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.k_proj.bias"][kv_split_n_embed * + self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1)] self.k_bias_ = self._cuda(self.k_bias_) if f"model.layers.{self.layer_num_}.self_attn.v_proj.weight" in weights: - self.v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"][split_n_embed * - self.tp_rank_: split_n_embed * (self.tp_rank_ + 1), :] + self.v_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.weight"] + self.v_weight_ = self.v_weight_[kv_split_n_embed * self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1), :] self.v_weight_ = self._cuda(self.v_weight_.transpose(0, 1)) if f"model.layers.{self.layer_num_}.self_attn.v_proj.bias" in weights: - self.v_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"][split_n_embed * - self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] + self.v_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.v_proj.bias"][kv_split_n_embed * + self.tp_rank_: kv_split_n_embed * (self.tp_rank_ + 1)] self.v_bias_ = self._cuda(self.v_bias_) # attention output dense params if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: - self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"][:, - split_n_embed * self.tp_rank_: split_n_embed * (self.tp_rank_ + 1)] + self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] + self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_: q_split_n_embed * (self.tp_rank_ + 1)] self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) if f"model.layers.{self.layer_num_}.self_attn.o_proj.bias" in weights: self.o_bias_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.bias"]