
Commit

add support for phi3-mini (#433) (#435)
Co-authored-by: shihaobai <[email protected]>
hiworldwzj and shihaobai authored Jun 13, 2024
1 parent c8160a4 commit a2206a0
Showing 16 changed files with 1,454 additions and 31 deletions.
3 changes: 3 additions & 0 deletions README.md
@@ -49,6 +49,7 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
- [Stablelm](https://huggingface.co/stabilityai/stablelm-2-1_6b)
- [MiniCPM](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16)
- [Phi-3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3)
- [CohereForAI](https://huggingface.co/CohereForAI/c4ai-command-r-plus)

> When you start Qwen-7b, you need to set the parameter '--eos_id 151643 --trust_remote_code'.
@@ -61,6 +62,8 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram
> Stablelm needs to set the parameter '--trust_remote_code'.
> For Phi-3, only the Mini and Small variants are currently supported.
## Get started

### Requirements

@@ -19,6 +19,9 @@ def load_hf_weights(self, weights):
if "lm_head.weight" in weights:
# print(weights['lm_head.weight'].shape)
self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :])
tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False)
if tie_word_embeddings:
self.lm_head_weight_ = self.wte_weight_
if "model.norm.weight" in weights:
self.final_norm_weight_ = self._cuda(weights["model.norm.weight"])
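
The new tie_word_embeddings branch above reuses the token-embedding matrix as the output projection when the checkpoint sets that flag, instead of loading a separate lm_head.weight. A minimal stand-alone sketch of the idea (illustrative sizes, not the LightLLM weight classes):

import torch

vocab_size, hidden_size = 32064, 3072              # illustrative Phi-3-mini-like sizes
wte_weight = torch.randn(vocab_size, hidden_size)  # token embedding table
lm_head_weight = wte_weight                        # tied: the same tensor, no extra copy
hidden_states = torch.randn(4, hidden_size)        # e.g. final hidden states of 4 tokens
logits = hidden_states @ lm_head_weight.t()        # output projection reuses the embedding weights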

142 changes: 111 additions & 31 deletions lightllm/models/llama/model.py
@@ -1,6 +1,7 @@
import os
import json
import torch
import math
from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
@@ -17,6 +18,7 @@

logger = init_logger(__name__)


class LlamaTpPartModel(TpPartBaseModel):
# weight class
pre_and_post_weight_class = LlamaPreAndPostLayerWeight
@@ -34,14 +36,14 @@ class LlamaTpPartModel(TpPartBaseModel):
def __init__(self, kvargs):
super().__init__(kvargs)
return

def _init_config(self):
super()._init_config()
# rename key
# repair_config()
self._reset_num_key_value_heads()
return

def _reset_num_key_value_heads(self):
if "num_key_value_heads" not in self.config:
self.config["num_key_value_heads"] = self.config["num_attention_heads"]
@@ -52,13 +54,15 @@ def _verify_params(self):
assert self.config["num_key_value_heads"] % self.world_size_ == 0
assert self.config["num_attention_heads"] % self.world_size_ == 0
return

def _init_mem_manager(self):
self.mem_manager = select_mem_manager_class(self.mode)(
self.max_total_token_num,
dtype=self.data_type,
head_num=self.config["num_key_value_heads"] // self.world_size_,
head_dim=self.config["hidden_size"] // self.config["num_attention_heads"],
layer_num=self.config["num_hidden_layers"],
)
return

def _init_custom(self):
@@ -67,37 +71,51 @@ def _init_custom(self):
"""
if self.config.get("use_rope_yarn", False):
self._init_to_get_yarn_rotary()
elif self.config.get("use_dynamic_ntk", False) or (self.config.get("rope_scaling", None) is not None and self.config.get("rope_scaling", {}).get("type", "base") == "dynamic"):
elif self.config.get("use_dynamic_ntk", False) or (
self.config.get("rope_scaling", None) is not None
and self.config.get("rope_scaling", {}).get("type", "base") == "dynamic"
):
self._init_to_get_dynamic_ntk_rotary()
elif (
self.config.get("rope_scaling", None) is not None
and self.config.get("rope_scaling", {}).get("type", "base") == "su"
):
self._init_to_su_rotary()
else:
self._init_to_get_rotary()
return
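
For reference, the new "su" branch above is selected by a rope_scaling block shaped like the following (a hypothetical excerpt based on the keys read in _init_to_su_rotary further down; the real short_factor/long_factor lists have head_dim // 2 entries and are checkpoint-specific):

phi3_like_config = {
    "max_position_embeddings": 131072,             # extended context window (illustrative)
    "original_max_position_embeddings": 4096,      # pretraining context window (illustrative)
    "rope_theta": 10000.0,
    "rope_scaling": {
        "type": "su",                              # routes into _init_to_su_rotary()
        "short_factor": [1.0, 1.05, 1.1],          # per-frequency scaling inside the original window (truncated)
        "long_factor": [1.0, 2.0, 4.0],            # per-frequency scaling beyond it (truncated)
    },
}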

def _init_weights(self):
self.pre_post_weight = self.pre_and_post_weight_class(
self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode
)
self.trans_layers_weight = [
self.transformer_weight_class(
i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode
)
for i in range(self.config["n_layer"])
]
if self.load_way == "HF":
load_hf_weights(
self.data_type,
weight_dir=self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
weight_dict=self.weight_dict,
)
else:
load_ds_weights(
self.data_type,
weight_dir=self.weight_dir_,
pre_post_layer=self.pre_post_weight,
transformer_layer_list=self.trans_layers_weight,
weight_dict=self.weight_dict,
prefix="model.layers.",
num_layer=self.config["n_layer"],
)
self.pre_post_weight.verify_load()
[weight.verify_load() for weight in self.trans_layers_weight]
return

def _init_to_get_rotary(self, default_base=10000):
partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
@@ -112,8 +130,7 @@ def _init_to_get_rotary(self, default_base=10000):
max_seq_len = self.config["max_sequence_length"]
else:
max_position_embeddings = self.config.get(
"max_position_embeddings",
2048 if base <= 10000.0 + 1e-5 else 16384
"max_position_embeddings", 2048 if base <= 10000.0 + 1e-5 else 16384
)
max_seq_len = max_position_embeddings * rope_scaling_factor

@@ -124,11 +141,13 @@ def _init_to_get_rotary(self, default_base=10000):
if ntk_alpha > 1:
logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
max_seq_len *= ntk_alpha
base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula
except:
pass

inv_freq = 1.0 / (
base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
)
t = torch.arange(max_seq_len + 1024 * 128, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)

@@ -147,24 +166,37 @@ def _init_to_get_dynamic_ntk_rotary(self):
max_seq_len = max(self.max_seq_length, max_position_embeddings)
self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device="cuda")
self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device="cuda")

inv_freq = 1.0 / (
base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
)
t = torch.arange(max_position_embeddings, device="cpu", dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
self._cos_cached[0:max_position_embeddings, :] = torch.cos(freqs).to(self.data_type).cuda()
self._sin_cached[0:max_position_embeddings, :] = torch.sin(freqs).to(self.data_type).cuda()

for seq_loc_index in range(max_position_embeddings, max_seq_len, 1):
new_base = base * (
(scaling_factor * (seq_loc_index + 1) / max_position_embeddings) - (scaling_factor - 1)
) ** (partial_head_dim / (partial_head_dim - 2))
inv_freq = 1.0 / (
new_base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
)
t = torch.tensor(
[
seq_loc_index,
],
device="cpu",
dtype=torch.float32,
)
freqs = torch.outer(t, inv_freq)
self._cos_cached[seq_loc_index : seq_loc_index + 1, :] = torch.cos(freqs).to(self.data_type).cuda()
self._sin_cached[seq_loc_index : seq_loc_index + 1, :] = torch.sin(freqs).to(self.data_type).cuda()
return
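
The per-position base change in the loop above is the dynamic-NTK rule: once a position passes max_position_embeddings, the RoPE base is inflated so the rotary wavelengths stretch to cover it. A self-contained sketch of that computation with made-up numbers (not LightLLM code):

import torch

base, dim = 10000.0, 128                 # rope_theta and per-head dimension (illustrative)
scaling_factor, max_pos = 2.0, 4096      # rope scaling factor and original context length (illustrative)
seq_loc_index = 6000                     # a position beyond the original window
new_base = base * ((scaling_factor * (seq_loc_index + 1) / max_pos) - (scaling_factor - 1)) ** (dim / (dim - 2))
inv_freq = 1.0 / (new_base ** (torch.arange(0, dim, 2, dtype=torch.float32) / dim))
cos_row = torch.cos(seq_loc_index * inv_freq)   # one row of the cos cache for this position
sin_row = torch.sin(seq_loc_index * inv_freq)   # one row of the sin cache for this position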

def _init_to_get_yarn_rotary(self):
from .yarn_rotary_utils import find_correction_range, linear_ramp_mask, get_mscale

dim = self.head_dim_
max_position_embeddings = self.config.get("max_position_embeddings", 2048)
base = self.config.get("rope_theta", 10000.0)
@@ -183,10 +215,12 @@ def _init_to_get_yarn_rotary(self):
inv_freq_interpolation = 1.0 / (scale * pos_freqs)

low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings)
inv_freq_mask = (
1 - linear_ramp_mask(low, high, dim // 2).float().cuda()
) * extrapolation_factor # Get n-d rotational scaling corrected for extrapolation
inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask

mscale = float(get_mscale(scale) * attn_factor) # Get n-d magnitude scaling corrected for interpolation

# Build here to make `torch.jit.trace` work.
max_seq_len_cached = max_position_embeddings
@@ -199,4 +233,50 @@ def _init_to_get_yarn_rotary(self):

return

def _init_to_su_rotary(self):
rope_scaling = self.config["rope_scaling"]
short_factor = rope_scaling["short_factor"]
long_factor = rope_scaling["long_factor"]
original_max_position_embeddings = self.config["original_max_position_embeddings"]
max_position_embeddings = self.config.get("max_position_embeddings", original_max_position_embeddings)
base = self.config.get("rope_theta", 10000.0)
short_factor = torch.tensor(short_factor, dtype=torch.float32, device="cpu")
long_factor = torch.tensor(long_factor, dtype=torch.float32, device="cpu")

scale = max_position_embeddings / original_max_position_embeddings
if scale <= 1.0:
rope_scaling_factor = 1.0
else:
rope_scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings))

max_seq_len = max(self.max_seq_length, max_position_embeddings)
self._cos_cached = torch.zeros((max_seq_len, self.head_dim_ // 2), dtype=self.data_type, device="cuda")
self._sin_cached = torch.zeros((max_seq_len, self.head_dim_ // 2), dtype=self.data_type, device="cuda")

inv_freq = 1.0 / (
short_factor
* base ** (torch.arange(0, self.head_dim_, 2, device="cpu", dtype=torch.float32) / self.head_dim_)
)
t = torch.arange(original_max_position_embeddings, device="cpu", dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
self._cos_cached[0:original_max_position_embeddings, :] = (
(torch.cos(freqs) * rope_scaling_factor).to(self.data_type).cuda()
)
self._sin_cached[0:original_max_position_embeddings, :] = (
(torch.sin(freqs) * rope_scaling_factor).to(self.data_type).cuda()
)

inv_freq = 1.0 / (
long_factor
* base ** (torch.arange(0, self.head_dim_, 2, device="cpu", dtype=torch.float32) / self.head_dim_)
)
t = torch.arange(original_max_position_embeddings, max_seq_len, device="cpu", dtype=torch.float32)
freqs = torch.outer(t, inv_freq)
self._cos_cached[original_max_position_embeddings:, :] = (
(torch.cos(freqs) * rope_scaling_factor).to(self.data_type).cuda()
)
self._sin_cached[original_max_position_embeddings:, :] = (
(torch.sin(freqs) * rope_scaling_factor).to(self.data_type).cuda()
)

return
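
The rope_scaling_factor above compensates for the stretched context: when the extended window exceeds the original one, both the cos and sin caches are scaled by sqrt(1 + ln(scale) / ln(original_max_position_embeddings)). A quick worked example using the illustrative numbers from the config sketch earlier:

import math

original_max = 4096
extended_max = 131072
scale = extended_max / original_max                                             # 32.0
rope_scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max))  # ~1.19
# positions < original_max use short_factor, later positions use long_factor,
# and every cached cos/sin row is multiplied by rope_scaling_factor
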
90 changes: 90 additions & 0 deletions lightllm/models/phi3/layer_infer/transformer_layer_infer.py
@@ -0,0 +1,90 @@
import torch
import torch.functional as F
import torch.distributed as dist
import numpy as np
from functools import partial

from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.phi3.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.models.phi3.triton_kernel.context_flashattention_nopad import (
context_attention_fwd,
context_attention_fwd_no_prompt_cache,
)
from lightllm.models.phi3.triton_kernel.destindex_copy_kv import destindex_copy_kv
from lightllm.models.phi3.layer_weights.transformer_layer_weight import Phi3TransformerLayerWeight
from lightllm.models.llama.infer_struct import LlamaInferStateInfo


class Phi3TransformerLayerInfer(LlamaTransformerLayerInfer):
""" """

def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]):
super().__init__(layer_num, tp_rank, world_size, network_config, mode)
return

def _bind_attention(self):
self._context_attention_kernel = partial(Phi3TransformerLayerInfer._context_attention_kernel, self)
self._copy_kv_to_mem_cache = partial(Phi3TransformerLayerInfer._copy_kv_to_mem_cache_normal, self)
self._token_attention_kernel = partial(Phi3TransformerLayerInfer._token_decode_attention_flashdecoding, self)
return

def _get_qkv(self, input_emb, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: Phi3TransformerLayerWeight):
q = torch.mm(input_emb.view(-1, self.embed_dim_), layer_weight.q_weight_)
torch.mm(
input_emb.view(-1, self.embed_dim_),
layer_weight.kv_weight_,
out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_),
)
rotary_emb_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
cache_kv[:, 0 : self.tp_k_head_num_, :],
infer_state.position_cos,
infer_state.position_sin,
)
return q, cache_kv

def _copy_kv_to_mem_cache(self, buffer, mem_index, mem_manager):
destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_])
return

def _context_attention_kernel(
self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None
) -> torch.Tensor:
o_tensor = torch.empty_like(q) if out is None else out
if infer_state.use_dynamic_prompt_cache:
kv = infer_state.mem_manager.kv_buffer[self.layer_num_]
context_attention_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
kv[:, 0 : self.tp_k_head_num_, :],
kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
infer_state.b_req_idx,
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.b_ready_cache_len,
infer_state.max_len_in_batch,
infer_state.req_manager.req_to_token_indexs,
)
else:
context_attention_fwd_no_prompt_cache(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
kv[:, 0 : self.tp_k_head_num_, :],
kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :],
o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_),
infer_state.b_start_loc,
infer_state.b_seq_len,
infer_state.max_len_in_batch,
)

return o_tensor

def _token_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
from lightllm.models.phi3.triton_kernel.flash_decoding import token_decode_attention_flash_decoding

cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :]
cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][
:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :
]
return token_decode_attention_flash_decoding(
q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_v, out=out
)
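
Throughout this layer K and V share a single kv buffer and are recovered by slicing along the head dimension, as in _context_attention_kernel and _token_decode_attention_flashdecoding above. A minimal sketch of that layout with hypothetical sizes:

import torch

tokens, k_heads, v_heads, head_dim = 8, 4, 4, 96                # illustrative GQA sizes
kv_buffer = torch.randn(tokens, k_heads + v_heads, head_dim)    # K heads first, then V heads
cache_k = kv_buffer[:, 0:k_heads, :]                            # view, no copy
cache_v = kv_buffer[:, k_heads:k_heads + v_heads, :]            # view, no copy
assert cache_k.shape == (tokens, k_heads, head_dim)
assert cache_v.shape == (tokens, v_heads, head_dim)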