From a2206a060b9bba8b699e55060e0bbc88f65c6c23 Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Thu, 13 Jun 2024 11:52:21 +0800 Subject: [PATCH] add support for phi3-mini (#433) (#435) Co-authored-by: shihaobai <42648726+shihaobai@users.noreply.github.com> --- README.md | 3 + .../pre_and_post_layer_weight.py | 3 + lightllm/models/llama/model.py | 142 ++++-- lightllm/models/phi3/__init__.py | 0 lightllm/models/phi3/layer_infer/__init__.py | 0 .../layer_infer/transformer_layer_infer.py | 90 ++++ .../models/phi3/layer_weights/__init__.py | 0 .../layer_weights/transformer_layer_weight.py | 99 ++++ lightllm/models/phi3/model.py | 17 + .../context_flashattention_nopad.py | 433 ++++++++++++++++++ .../phi3/triton_kernel/destindex_copy_kv.py | 192 ++++++++ .../phi3/triton_kernel/flash_decoding.py | 39 ++ .../triton_kernel/flash_decoding_stage1.py | 162 +++++++ .../triton_kernel/flash_decoding_stage2.py | 85 ++++ .../models/phi3/triton_kernel/rotary_emb.py | 217 +++++++++ .../model_infer/mode_backend/base_backend.py | 3 + 16 files changed, 1454 insertions(+), 31 deletions(-) create mode 100644 lightllm/models/phi3/__init__.py create mode 100644 lightllm/models/phi3/layer_infer/__init__.py create mode 100755 lightllm/models/phi3/layer_infer/transformer_layer_infer.py create mode 100644 lightllm/models/phi3/layer_weights/__init__.py create mode 100755 lightllm/models/phi3/layer_weights/transformer_layer_weight.py create mode 100644 lightllm/models/phi3/model.py create mode 100644 lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py create mode 100644 lightllm/models/phi3/triton_kernel/destindex_copy_kv.py create mode 100644 lightllm/models/phi3/triton_kernel/flash_decoding.py create mode 100644 lightllm/models/phi3/triton_kernel/flash_decoding_stage1.py create mode 100644 lightllm/models/phi3/triton_kernel/flash_decoding_stage2.py create mode 100755 lightllm/models/phi3/triton_kernel/rotary_emb.py diff --git a/README.md b/README.md index b1e1b2f9..3f9107a6 100644 --- a/README.md +++ b/README.md @@ -49,6 +49,7 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram - [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1) - [Stablelm](https://huggingface.co/stabilityai/stablelm-2-1_6b) - [MiniCPM](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16) +- [Phi-3](https://huggingface.co/collections/microsoft/phi-3-6626e15e9585a200d2d761e3) - [CohereForAI](https://huggingface.co/CohereForAI/c4ai-command-r-plus) > When you start Qwen-7b, you need to set the parameter '--eos_id 151643 --trust_remote_code'. @@ -61,6 +62,8 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram > Stablelm needs to set the parameter '--trust_remote_code'. +> Phi-3 only supports Mini and Small. 
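> Once a Phi-3 model directory is served with lightllm's `api_server` (the same launch flow used for the other models listed above), it can be queried over the standard `/generate` HTTP endpoint. The snippet below is a minimal client sketch, not part of this patch; the host, port, prompt, and generation parameters are placeholders to adjust for your deployment:

```python
import json

import requests

# Illustrative client call against a locally running lightllm api_server
# that was started with a Phi-3 model directory; adjust host/port as needed.
url = "http://localhost:8000/generate"
data = {
    "inputs": "What is the capital of France?",
    "parameters": {"do_sample": False, "max_new_tokens": 64},
}
response = requests.post(url, headers={"Content-Type": "application/json"}, data=json.dumps(data))
print(response.json())
```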
+ ## Get started ### Requirements diff --git a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py index dba0e7ca..b6ed8fc0 100644 --- a/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py +++ b/lightllm/models/llama/layer_weights/pre_and_post_layer_weight.py @@ -19,6 +19,9 @@ def load_hf_weights(self, weights): if "lm_head.weight" in weights: # print(weights['lm_head.weight'].shape) self.lm_head_weight_ = self._cuda(weights["lm_head.weight"][split_start:split_end, :]) + tie_word_embeddings = self.network_config_.get("tie_word_embeddings", False) + if tie_word_embeddings: + self.lm_head_weight_ = self.wte_weight_ if "model.norm.weight" in weights: self.final_norm_weight_ = self._cuda(weights["model.norm.weight"]) diff --git a/lightllm/models/llama/model.py b/lightllm/models/llama/model.py index 7ec4ab3b..9f8e1aba 100644 --- a/lightllm/models/llama/model.py +++ b/lightllm/models/llama/model.py @@ -1,6 +1,7 @@ import os import json import torch +import math from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer @@ -17,6 +18,7 @@ logger = init_logger(__name__) + class LlamaTpPartModel(TpPartBaseModel): # weight class pre_and_post_weight_class = LlamaPreAndPostLayerWeight @@ -34,14 +36,14 @@ class LlamaTpPartModel(TpPartBaseModel): def __init__(self, kvargs): super().__init__(kvargs) return - + def _init_config(self): super()._init_config() # rename key # repair_config() self._reset_num_key_value_heads() - return - + return + def _reset_num_key_value_heads(self): if "num_key_value_heads" not in self.config: self.config["num_key_value_heads"] = self.config["num_attention_heads"] @@ -52,13 +54,15 @@ def _verify_params(self): assert self.config["num_key_value_heads"] % self.world_size_ == 0 assert self.config["num_attention_heads"] % self.world_size_ == 0 return - + def _init_mem_manager(self): - self.mem_manager = select_mem_manager_class(self.mode)(self.max_total_token_num, - dtype=self.data_type, - head_num=self.config["num_key_value_heads"] // self.world_size_, - head_dim=self.config["hidden_size"] // self.config["num_attention_heads"], - layer_num=self.config["num_hidden_layers"]) + self.mem_manager = select_mem_manager_class(self.mode)( + self.max_total_token_num, + dtype=self.data_type, + head_num=self.config["num_key_value_heads"] // self.world_size_, + head_dim=self.config["hidden_size"] // self.config["num_attention_heads"], + layer_num=self.config["num_hidden_layers"], + ) return def _init_custom(self): @@ -67,25 +71,38 @@ def _init_custom(self): """ if self.config.get("use_rope_yarn", False): self._init_to_get_yarn_rotary() - elif self.config.get("use_dynamic_ntk", False) or (self.config.get("rope_scaling", None) is not None and self.config.get("rope_scaling", {}).get("type", "base") == "dynamic"): + elif self.config.get("use_dynamic_ntk", False) or ( + self.config.get("rope_scaling", None) is not None + and self.config.get("rope_scaling", {}).get("type", "base") == "dynamic" + ): self._init_to_get_dynamic_ntk_rotary() + elif ( + self.config.get("rope_scaling", None) is not None + and self.config.get("rope_scaling", {}).get("type", "base") == "su" + ): + self._init_to_su_rotary() else: self._init_to_get_rotary() return def _init_weights(self): - self.pre_post_weight = 
self.pre_and_post_weight_class(self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode) + self.pre_post_weight = self.pre_and_post_weight_class( + self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode + ) self.trans_layers_weight = [ - self.transformer_weight_class(i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode) + self.transformer_weight_class( + i, self.tp_rank_, self.world_size_, self.data_type, network_config=self.config, mode=self.mode + ) for i in range(self.config["n_layer"]) ] - if self.load_way == 'HF': + if self.load_way == "HF": load_hf_weights( self.data_type, weight_dir=self.weight_dir_, pre_post_layer=self.pre_post_weight, transformer_layer_list=self.trans_layers_weight, - weight_dict=self.weight_dict) + weight_dict=self.weight_dict, + ) else: load_ds_weights( self.data_type, @@ -93,11 +110,12 @@ def _init_weights(self): pre_post_layer=self.pre_post_weight, transformer_layer_list=self.trans_layers_weight, weight_dict=self.weight_dict, - prefix='model.layers.', - num_layer=self.config["n_layer"]) + prefix="model.layers.", + num_layer=self.config["n_layer"], + ) self.pre_post_weight.verify_load() - [weight.verify_load() for weight in self.trans_layers_weight] - return + [weight.verify_load() for weight in self.trans_layers_weight] + return def _init_to_get_rotary(self, default_base=10000): partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_) @@ -112,8 +130,7 @@ def _init_to_get_rotary(self, default_base=10000): max_seq_len = self.config["max_sequence_length"] else: max_position_embeddings = self.config.get( - "max_position_embeddings", - 2048 if base <= 10000.0 + 1e-5 else 16384 + "max_position_embeddings", 2048 if base <= 10000.0 + 1e-5 else 16384 ) max_seq_len = max_position_embeddings * rope_scaling_factor @@ -124,11 +141,13 @@ def _init_to_get_rotary(self, default_base=10000): if ntk_alpha > 1: logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}") max_seq_len *= ntk_alpha - base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim-2))) #Base change formula + base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula except: pass - inv_freq = 1.0 / (base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)) + inv_freq = 1.0 / ( + base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) + ) t = torch.arange(max_seq_len + 1024 * 128, device="cpu", dtype=torch.float32) / rope_scaling_factor freqs = torch.outer(t, inv_freq) @@ -147,24 +166,37 @@ def _init_to_get_dynamic_ntk_rotary(self): max_seq_len = max(self.max_seq_length, max_position_embeddings) self._cos_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device="cuda") self._sin_cached = torch.zeros((max_seq_len, partial_head_dim // 2), dtype=self.data_type, device="cuda") - - inv_freq = 1.0 / (base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)) + + inv_freq = 1.0 / ( + base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) + ) t = torch.arange(max_position_embeddings, device="cpu", dtype=torch.float32) freqs = torch.outer(t, inv_freq) self._cos_cached[0:max_position_embeddings, :] = torch.cos(freqs).to(self.data_type).cuda() self._sin_cached[0:max_position_embeddings, :] = 
torch.sin(freqs).to(self.data_type).cuda() for seq_loc_index in range(max_position_embeddings, max_seq_len, 1): - new_base = base * ((scaling_factor * (seq_loc_index + 1) / max_position_embeddings) -(scaling_factor - 1)) ** (partial_head_dim / (partial_head_dim - 2)) - inv_freq = 1.0 / (new_base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)) - t = torch.tensor([seq_loc_index,], device="cpu", dtype=torch.float32) + new_base = base * ( + (scaling_factor * (seq_loc_index + 1) / max_position_embeddings) - (scaling_factor - 1) + ) ** (partial_head_dim / (partial_head_dim - 2)) + inv_freq = 1.0 / ( + new_base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim) + ) + t = torch.tensor( + [ + seq_loc_index, + ], + device="cpu", + dtype=torch.float32, + ) freqs = torch.outer(t, inv_freq) - self._cos_cached[seq_loc_index:seq_loc_index + 1, :] = torch.cos(freqs).to(self.data_type).cuda() - self._sin_cached[seq_loc_index:seq_loc_index + 1, :] = torch.sin(freqs).to(self.data_type).cuda() + self._cos_cached[seq_loc_index : seq_loc_index + 1, :] = torch.cos(freqs).to(self.data_type).cuda() + self._sin_cached[seq_loc_index : seq_loc_index + 1, :] = torch.sin(freqs).to(self.data_type).cuda() return def _init_to_get_yarn_rotary(self): from .yarn_rotary_utils import find_correction_range, linear_ramp_mask, get_mscale + dim = self.head_dim_ max_position_embeddings = self.config.get("max_position_embeddings", 2048) base = self.config.get("rope_theta", 10000.0) @@ -183,10 +215,12 @@ def _init_to_get_yarn_rotary(self): inv_freq_interpolation = 1.0 / (scale * pos_freqs) low, high = find_correction_range(beta_fast, beta_slow, dim, base, original_max_position_embeddings) - inv_freq_mask = (1 - linear_ramp_mask(low, high, dim // 2).float().cuda()) * extrapolation_factor # Get n-d rotational scaling corrected for extrapolation + inv_freq_mask = ( + 1 - linear_ramp_mask(low, high, dim // 2).float().cuda() + ) * extrapolation_factor # Get n-d rotational scaling corrected for extrapolation inv_freq = inv_freq_interpolation * (1 - inv_freq_mask) + inv_freq_extrapolation * inv_freq_mask - mscale = float(get_mscale(scale) * attn_factor) # Get n-d magnitude scaling corrected for interpolation + mscale = float(get_mscale(scale) * attn_factor) # Get n-d magnitude scaling corrected for interpolation # Build here to make `torch.jit.trace` work. 
max_seq_len_cached = max_position_embeddings @@ -199,4 +233,50 @@ def _init_to_get_yarn_rotary(self): return + def _init_to_su_rotary(self): + rope_scaling = self.config["rope_scaling"] + short_factor = rope_scaling["short_factor"] + long_factor = rope_scaling["long_factor"] + original_max_position_embeddings = self.config["original_max_position_embeddings"] + max_position_embeddings = self.config.get("max_position_embeddings", original_max_position_embeddings) + base = self.config.get("rope_theta", 10000.0) + short_factor = torch.tensor(short_factor, dtype=torch.float32, device="cpu") + long_factor = torch.tensor(long_factor, dtype=torch.float32, device="cpu") + + scale = max_position_embeddings / original_max_position_embeddings + if scale <= 1.0: + rope_scaling_factor = 1.0 + else: + rope_scaling_factor = math.sqrt(1 + math.log(scale) / math.log(original_max_position_embeddings)) + + max_seq_len = max(self.max_seq_length, max_position_embeddings) + self._cos_cached = torch.zeros((max_seq_len, self.head_dim_ // 2), dtype=self.data_type, device="cuda") + self._sin_cached = torch.zeros((max_seq_len, self.head_dim_ // 2), dtype=self.data_type, device="cuda") + + inv_freq = 1.0 / ( + short_factor + * base ** (torch.arange(0, self.head_dim_, 2, device="cpu", dtype=torch.float32) / self.head_dim_) + ) + t = torch.arange(original_max_position_embeddings, device="cpu", dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self._cos_cached[0:original_max_position_embeddings, :] = ( + (torch.cos(freqs) * rope_scaling_factor).to(self.data_type).cuda() + ) + self._sin_cached[0:original_max_position_embeddings, :] = ( + (torch.sin(freqs) * rope_scaling_factor).to(self.data_type).cuda() + ) + inv_freq = 1.0 / ( + long_factor + * base ** (torch.arange(0, self.head_dim_, 2, device="cpu", dtype=torch.float32) / self.head_dim_) + ) + t = torch.arange(original_max_position_embeddings, max_seq_len, device="cpu", dtype=torch.float32) + freqs = torch.outer(t, inv_freq) + self._cos_cached[original_max_position_embeddings:, :] = ( + (torch.cos(freqs) * rope_scaling_factor).to(self.data_type).cuda() + ) + self._sin_cached[original_max_position_embeddings:, :] = ( + (torch.sin(freqs) * rope_scaling_factor).to(self.data_type).cuda() + ) + + return diff --git a/lightllm/models/phi3/__init__.py b/lightllm/models/phi3/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lightllm/models/phi3/layer_infer/__init__.py b/lightllm/models/phi3/layer_infer/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lightllm/models/phi3/layer_infer/transformer_layer_infer.py b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py new file mode 100755 index 00000000..66614a3f --- /dev/null +++ b/lightllm/models/phi3/layer_infer/transformer_layer_infer.py @@ -0,0 +1,90 @@ +import torch +import torch.functional as F +import torch.distributed as dist +import numpy as np +from functools import partial + +from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer +from lightllm.models.phi3.triton_kernel.rotary_emb import rotary_emb_fwd +from lightllm.models.phi3.triton_kernel.context_flashattention_nopad import ( + context_attention_fwd, + context_attention_fwd_no_prompt_cache, +) +from lightllm.models.phi3.triton_kernel.destindex_copy_kv import destindex_copy_kv +from lightllm.models.phi3.layer_weights.transformer_layer_weight import Phi3TransformerLayerWeight +from lightllm.models.llama.infer_struct import LlamaInferStateInfo + + +class 
Phi3TransformerLayerInfer(LlamaTransformerLayerInfer): + """ """ + + def __init__(self, layer_num, tp_rank, world_size, network_config, mode=[]): + super().__init__(layer_num, tp_rank, world_size, network_config, mode) + return + + def _bind_attention(self): + self._context_attention_kernel = partial(Phi3TransformerLayerInfer._context_attention_kernel, self) + self._copy_kv_to_mem_cache = partial(Phi3TransformerLayerInfer._copy_kv_to_mem_cache_normal, self) + self._token_attention_kernel = partial(Phi3TransformerLayerInfer._token_decode_attention_flashdecoding, self) + return + + def _get_qkv(self, input_emb, cache_kv, infer_state: LlamaInferStateInfo, layer_weight: Phi3TransformerLayerWeight): + q = torch.mm(input_emb.view(-1, self.embed_dim_), layer_weight.q_weight_) + torch.mm( + input_emb.view(-1, self.embed_dim_), + layer_weight.kv_weight_, + out=cache_kv.view(-1, (self.tp_k_head_num_ + self.tp_v_head_num_) * self.head_dim_), + ) + rotary_emb_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + cache_kv[:, 0 : self.tp_k_head_num_, :], + infer_state.position_cos, + infer_state.position_sin, + ) + return q, cache_kv + + def _copy_kv_to_mem_cache(self, buffer, mem_index, mem_manager): + destindex_copy_kv(buffer, mem_index, mem_manager.kv_buffer[self.layer_num_]) + return + + def _context_attention_kernel( + self, q, kv, infer_state: LlamaInferStateInfo, layer_weight, out=None + ) -> torch.Tensor: + o_tensor = torch.empty_like(q) if out is None else out + if infer_state.use_dynamic_prompt_cache: + kv = infer_state.mem_manager.kv_buffer[self.layer_num_] + context_attention_fwd( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + kv[:, 0 : self.tp_k_head_num_, :], + kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), + infer_state.b_req_idx, + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.b_ready_cache_len, + infer_state.max_len_in_batch, + infer_state.req_manager.req_to_token_indexs, + ) + else: + context_attention_fwd_no_prompt_cache( + q.view(-1, self.tp_q_head_num_, self.head_dim_), + kv[:, 0 : self.tp_k_head_num_, :], + kv[:, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, :], + o_tensor.view(-1, self.tp_q_head_num_, self.head_dim_), + infer_state.b_start_loc, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + ) + + return o_tensor + + def _token_decode_attention_flashdecoding(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None): + from lightllm.models.phi3.triton_kernel.flash_decoding import token_decode_attention_flash_decoding + + cache_k = infer_state.mem_manager.kv_buffer[self.layer_num_][:, 0 : self.tp_k_head_num_, :] + cache_v = infer_state.mem_manager.kv_buffer[self.layer_num_][ + :, self.tp_k_head_num_ : self.tp_k_head_num_ + self.tp_v_head_num_, : + ] + return token_decode_attention_flash_decoding( + q, infer_state, self.tp_q_head_num_, self.head_dim_, cache_k, cache_v, out=out + ) diff --git a/lightllm/models/phi3/layer_weights/__init__.py b/lightllm/models/phi3/layer_weights/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/lightllm/models/phi3/layer_weights/transformer_layer_weight.py b/lightllm/models/phi3/layer_weights/transformer_layer_weight.py new file mode 100755 index 00000000..1400eaba --- /dev/null +++ b/lightllm/models/phi3/layer_weights/transformer_layer_weight.py @@ -0,0 +1,99 @@ +import torch +import math +import numpy as np +from lightllm.common.basemodel import TransformerLayerWeight +from 
lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight + + +class Phi3TransformerLayerWeight(LlamaTransformerLayerWeight): + def __init__(self, layer_num, tp_rank, world_size, data_type, network_config, mode=[]): + super().__init__(layer_num, tp_rank, world_size, data_type, network_config, mode) + return + + def verify_load(self): + errors = "weights load not ok" + weights = [ + self.att_norm_weight_, + self.q_weight_, + self.kv_weight_, + self.o_weight_, + self.ffn_norm_weight_, + self.gate_up_proj, + self.down_proj, + ] + for i in range(len(weights)): + assert weights[i] is not None, "index:" + str(i) + " " + errors + return + + def _load_qkvo_weights(self, weights): + # input layernorm params + if f"model.layers.{self.layer_num_}.input_layernorm.weight" in weights: + self.att_norm_weight_ = self._cuda(weights[f"model.layers.{self.layer_num_}.input_layernorm.weight"]) + + n_embed = self.network_config_["hidden_size"] + q_split_n_embed = n_embed // self.world_size_ + kv_n_embed = ( + n_embed // self.network_config_["num_attention_heads"] * self.network_config_["num_key_value_heads"] + ) + kv_split_n_embed = ( + n_embed + // self.network_config_["num_attention_heads"] + * self.network_config_["num_key_value_heads"] + // self.world_size_ + ) + if f"model.layers.{self.layer_num_}.self_attn.qkv_proj.weight" in weights: + qkv_weight_ = ( + weights[f"model.layers.{self.layer_num_}.self_attn.qkv_proj.weight"] + .transpose(0, 1) + .contiguous() + .to(self.data_type_) + ) + self.q_weight_ = qkv_weight_[:, :n_embed][ + :, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1) + ] + self.q_weight_ = self._cuda(self.q_weight_) + k_weight_ = qkv_weight_[:, n_embed : n_embed + kv_n_embed] + self.k_weight_ = k_weight_[:, kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)] + + v_weight_ = qkv_weight_[:, n_embed + kv_n_embed : n_embed + 2 * kv_n_embed] + self.v_weight_ = v_weight_[:, kv_split_n_embed * self.tp_rank_ : kv_split_n_embed * (self.tp_rank_ + 1)] + + self._try_cat_to(["k_weight_", "v_weight_"], "kv_weight_", cat_dim=1) + + # attention output dense params + if f"model.layers.{self.layer_num_}.self_attn.o_proj.weight" in weights: + self.o_weight_ = weights[f"model.layers.{self.layer_num_}.self_attn.o_proj.weight"] + self.o_weight_ = self.o_weight_[:, q_split_n_embed * self.tp_rank_ : q_split_n_embed * (self.tp_rank_ + 1)] + self.o_weight_ = self._cuda(self.o_weight_.transpose(0, 1)) + + return + + def _load_ffn_weights(self, weights): + if f"model.layers.{self.layer_num_}.post_attention_layernorm.weight" in weights: + self.ffn_norm_weight_ = self._cuda( + weights[f"model.layers.{self.layer_num_}.post_attention_layernorm.weight"] + ) + + inter_size = self.network_config_["intermediate_size"] + split_inter_size = inter_size // self.world_size_ + + if f"model.layers.{self.layer_num_}.mlp.gate_up_proj.weight" in weights: + gate_up_proj = weights[f"model.layers.{self.layer_num_}.mlp.gate_up_proj.weight"] + gate_proj = gate_up_proj[0:inter_size][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.gate_proj = gate_proj.transpose(0, 1) + + up_proj = gate_up_proj[inter_size:][ + split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1), : + ] + self.up_proj = up_proj.transpose(0, 1) + + self._try_cat_to(["gate_proj", "up_proj"], "gate_up_proj", cat_dim=1) + + if f"model.layers.{self.layer_num_}.mlp.down_proj.weight" in weights: + self.down_proj = 
weights[f"model.layers.{self.layer_num_}.mlp.down_proj.weight"][ + :, split_inter_size * self.tp_rank_ : split_inter_size * (self.tp_rank_ + 1) + ] + self.down_proj = self._cuda(self.down_proj.transpose(0, 1)) + return diff --git a/lightllm/models/phi3/model.py b/lightllm/models/phi3/model.py new file mode 100644 index 00000000..79c573bf --- /dev/null +++ b/lightllm/models/phi3/model.py @@ -0,0 +1,17 @@ +import os +import json +import torch + +from lightllm.models.phi3.layer_weights.transformer_layer_weight import Phi3TransformerLayerWeight +from lightllm.models.phi3.layer_infer.transformer_layer_infer import Phi3TransformerLayerInfer +from lightllm.models.llama.model import LlamaTpPartModel + + +class Phi3TpPartModel(LlamaTpPartModel): + # weight class + transformer_weight_class = Phi3TransformerLayerWeight + + transformer_layer_infer_class = Phi3TransformerLayerInfer + + def __init__(self, kvargs): + super().__init__(kvargs) diff --git a/lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py b/lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py new file mode 100644 index 00000000..0538fe9e --- /dev/null +++ b/lightllm/models/phi3/triton_kernel/context_flashattention_nopad.py @@ -0,0 +1,433 @@ +import torch + +import triton +import triton.language as tl +import math +import torch.nn.functional as F + +TESLA = "Tesla" in torch.cuda.get_device_name(0) + + +@triton.jit +def _fwd_kernel( + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 + Out, + Req_to_tokens, + B_req_idx, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + stride_req_to_tokens_b, + stride_req_to_tokens_s, + kv_group_num, + b_prompt_cache_len, + head_dim: tl.constexpr, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + prompt_cache_len = tl.load(b_prompt_cache_len + cur_batch) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) - prompt_cache_len + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + + q = tl.load(Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim), other=0.0) + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + block_end_loc = tl.minimum((start_m + 1) * BLOCK_M + prompt_cache_len, cur_batch_seq_len + prompt_cache_len) + + for start_n in range(0, block_mask * block_end_loc, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + kv_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + stride_req_to_tokens_s * (start_n + offs_n), + mask=(start_n + offs_n) < block_end_loc, + other=0, + ) + off_k = kv_loc[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * 
stride_kd + k = tl.load( + K + off_k, mask=((start_n + offs_n[None, :]) < block_end_loc) & (offs_d[:, None] < head_dim), other=0.0 + ) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] + prompt_cache_len >= start_n + offs_n[None, :], qk, float("-100000000.0")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc_scale = tl.where(offs_m + prompt_cache_len >= start_n, acc_scale, 1.0) + acc = acc * acc_scale[:, None] + # update acc + off_v = kv_loc[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + v = tl.load( + V + off_v, mask=((start_n + offs_n[:, None]) < block_end_loc) & (offs_d[None, :] < head_dim), other=0.0 + ) + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim)) + return + + +@torch.no_grad() +def context_attention_fwd( + q, k, v, o, b_req_idx, b_start_loc, b_seq_len, b_prompt_cache_len, max_input_len, req_to_token_indexs +): + BLOCK = 128 if not TESLA else 64 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + head_dim = Lq + BLOCK_DMODEL = triton.next_power_of_2(head_dim) + + sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数 + batch, head = b_seq_len.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + + num_warps = 4 if Lk <= 64 else 8 + _fwd_kernel[grid]( + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, + o, + req_to_token_indexs, + b_req_idx, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + req_to_token_indexs.stride(0), + req_to_token_indexs.stride(1), + kv_group_num=kv_group_num, + b_prompt_cache_len=b_prompt_cache_len, + head_dim=head_dim, + BLOCK_M=BLOCK, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + +@triton.jit +def _fwd_kernel_no_prompt_cache( + Q, + K, + V, + sm_scale, + B_Start_Loc, + B_Seqlen, # B_LOC 内部记录每个batch 输入的真实位置, B_SEQ_len 记录当前输入的真实长度 + Out, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_obs, + stride_oh, + stride_od, + kv_group_num, + head_dim, + BLOCK_M: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + start_m = tl.program_id(2) + + cur_kv_head = cur_head // kv_group_num + + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_in_all_start_index = tl.load(B_Start_Loc + cur_batch) + + block_start_loc = BLOCK_M * start_m + + # initialize offsets + offs_n = tl.arange(0, BLOCK_N) + offs_d = tl.arange(0, BLOCK_DMODEL) + offs_m = start_m * BLOCK_M + 
tl.arange(0, BLOCK_M) + off_q = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_qbs + + cur_head * stride_qh + + offs_d[None, :] * stride_qd + ) + off_k = offs_n[None, :] * stride_kbs + cur_kv_head * stride_kh + offs_d[:, None] * stride_kd + off_v = offs_n[:, None] * stride_vbs + cur_kv_head * stride_vh + offs_d[None, :] * stride_vd + + q = tl.load(Q + off_q, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim), other=0.0) + + k_ptrs = K + off_k + v_ptrs = V + off_v + + # initialize pointer to m and l + m_i = tl.zeros([BLOCK_M], dtype=tl.float32) - float("inf") + l_i = tl.zeros([BLOCK_M], dtype=tl.float32) + acc = tl.zeros([BLOCK_M, BLOCK_DMODEL], dtype=tl.float32) + + block_mask = tl.where(block_start_loc < cur_batch_seq_len, 1, 0) + + for start_n in range(0, block_mask * (start_m + 1) * BLOCK_M, BLOCK_N): + start_n = tl.multiple_of(start_n, BLOCK_N) + # -- compute qk ---- + k = tl.load( + k_ptrs + (cur_batch_in_all_start_index + start_n) * stride_kbs, + mask=((start_n + offs_n[None, :]) < cur_batch_seq_len) & (offs_d[:, None] < head_dim), + other=0.0, + ) + # mask = tl.load(mask_ptrs + start_n, mask=start_n + offs_n < cur_batch_end_loc, other=0.0) + + qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32) + qk += tl.dot(q, k) + qk *= sm_scale + qk = tl.where(offs_m[:, None] >= (start_n + offs_n[None, :]), qk, float("-inf")) + + # -- compute m_ij, p, l_ij + m_ij = tl.max(qk, 1) + p = tl.exp(qk - m_ij[:, None]) + l_ij = tl.sum(p, 1) + # -- update m_i and l_i + m_i_new = tl.maximum(m_i, m_ij) + alpha = tl.exp(m_i - m_i_new) + beta = tl.exp(m_ij - m_i_new) + l_i_new = alpha * l_i + beta * l_ij + # -- update output accumulator -- + # scale p + p_scale = beta / l_i_new + p = p * p_scale[:, None] + # scale acc + acc_scale = l_i / l_i_new * alpha + acc = acc * acc_scale[:, None] + # update acc + v = tl.load( + v_ptrs + (cur_batch_in_all_start_index + start_n) * stride_vbs, + mask=((start_n + offs_n[:, None]) < cur_batch_seq_len) & (offs_d[None, :] < head_dim), + other=0.0, + ) + + p = p.to(v.dtype) + acc += tl.dot(p, v) + # update m_i and l_i + l_i = l_i_new + m_i = m_i_new + # initialize pointers to output + off_o = ( + (cur_batch_in_all_start_index + offs_m[:, None]) * stride_obs + + cur_head * stride_oh + + offs_d[None, :] * stride_od + ) + out_ptrs = Out + off_o + tl.store(out_ptrs, acc, mask=(offs_m[:, None] < cur_batch_seq_len) & (offs_d[None, :] < head_dim)) + return + + +@torch.no_grad() +def context_attention_fwd_no_prompt_cache(q, k, v, o, b_start_loc, b_seq_len, max_input_len): + BLOCK = 128 if not TESLA else 64 + # shape constraints + Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1] + assert Lq == Lk and Lk == Lv + head_dim = Lq + BLOCK_DMODEL = triton.next_power_of_2(head_dim) + sm_scale = 1.0 / (Lq ** 0.5) # 计算scale系数 + batch, head = b_seq_len.shape[0], q.shape[1] + kv_group_num = q.shape[1] // k.shape[1] + + grid = (batch, head, triton.cdiv(max_input_len, BLOCK)) # batch, head, + + num_warps = 4 if Lk <= 64 else 8 + _fwd_kernel_no_prompt_cache[grid]( + q, + k, + v, + sm_scale, + b_start_loc, + b_seq_len, + o, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + o.stride(0), + o.stride(1), + o.stride(2), + kv_group_num=kv_group_num, + head_dim=head_dim, + BLOCK_M=BLOCK, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_N=BLOCK, + num_warps=num_warps, + num_stages=1, + ) + return + + +def torch_att(xq, xk, xv, bs, seqlen, num_head, head_dim, prompt_cache_len): + xq = xq.view(bs, 
seqlen, num_head, head_dim) + xk = xk.view(bs, seqlen + prompt_cache_len, num_head, head_dim) + xv = xv.view(bs, seqlen + prompt_cache_len, num_head, head_dim) + mask_cache = torch.ones((seqlen, prompt_cache_len)).cuda().unsqueeze(0).unsqueeze(0).cuda() + mask = torch.tril(torch.ones(seqlen, seqlen), diagonal=0).unsqueeze(0).unsqueeze(0).cuda() + mask[mask == 0.0] = -100000000.0 + mask = torch.cat([mask_cache, mask], dim=-1) + mask = mask.repeat(bs, num_head, 1, 1) + keys = xk + values = xv + xq = xq.transpose(1, 2) + keys = keys.transpose(1, 2) + values = values.transpose(1, 2) + scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(head_dim) + scores = F.softmax(scores.float() + mask, dim=-1).type_as(xq) + output = torch.matmul(scores, values).transpose(1, 2).contiguous().reshape(-1, num_head, head_dim) + return output + + +def test(): + import torch + import numpy as np + + Z, H, N_CTX, D_HEAD = 10, 6, 500, 96 + dtype = torch.float16 + Z = 1 + q = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.1, std=0.2) + k = torch.empty((Z * N_CTX + 7000, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.4, std=0.2) + v = torch.empty((Z * N_CTX + 7000, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) + o = torch.empty((Z * N_CTX, H, D_HEAD), dtype=dtype, device="cuda").normal_(mean=0.3, std=0.2) + req_to_token_indexs = torch.zeros((10, Z * N_CTX + 7000), dtype=torch.int32, device="cuda") + max_input_len = N_CTX + Z = 1 + b_start_loc = torch.zeros((Z,), dtype=torch.int32, device="cuda") + b_seq_len = torch.ones((Z,), dtype=torch.int32, device="cuda") + b_req_idx = torch.ones((Z,), dtype=torch.int32, device="cuda") + b_prompt_cache_len = torch.zeros(1, dtype=torch.int32, device="cuda") + b_prompt_cache_len[0] = 0 + prompt_cache_len = 0 + + b_seq_len[0] = 500 + b_req_idx[0] = 0 + req_to_token_indexs[0][: prompt_cache_len + N_CTX] = torch.tensor( + np.arange(prompt_cache_len + N_CTX), dtype=torch.int32 + ).cuda() + + torch_out = [] + start = 0 + for i in range(Z): + end = start + b_seq_len[i] + torch_o = torch_att( + q[start:end], + k[start : end + prompt_cache_len], + v[start : end + prompt_cache_len], + 1, + b_seq_len[i], + H, + D_HEAD, + prompt_cache_len, + ) + start = end + torch_out.append(torch_o) + + torch_out = torch.cat(torch_out, dim=0) + + context_attention_fwd( + q, + k, + v, + o, + b_req_idx, + b_start_loc, + b_seq_len + prompt_cache_len, + b_prompt_cache_len, + max_input_len, + req_to_token_indexs, + ) + + # context_attention_fwd_no_prompt_cache( + # q, k, v, o, b_start_loc, b_seq_len, max_input_len + # ) + + print("max ", torch.max(torch.abs(torch_out - o))) + print("mean ", torch.mean(torch.abs(torch_out - o))) + assert torch.allclose(torch_out, o, atol=1e-2, rtol=0) diff --git a/lightllm/models/phi3/triton_kernel/destindex_copy_kv.py b/lightllm/models/phi3/triton_kernel/destindex_copy_kv.py new file mode 100644 index 00000000..4f31895a --- /dev/null +++ b/lightllm/models/phi3/triton_kernel/destindex_copy_kv.py @@ -0,0 +1,192 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_destindex_copy_kv( + K, + Dest_loc, + Out, + stride_k_bs, + stride_k_h, + stride_k_d, + stride_o_bs, + stride_o_h, + stride_o_d, + head_num, + head_dim, + BLOCK_DMODEL: tl.constexpr, + BLOCK_HEAD: tl.constexpr, +): + cur_index = tl.program_id(0) + offs_h = tl.arange(0, BLOCK_HEAD) + offs_d = tl.arange(0, BLOCK_DMODEL) + + dest_index = tl.load(Dest_loc + cur_index) + + k_ptrs = K + cur_index * stride_k_bs + 
stride_k_h * offs_h[:, None] + stride_k_d * offs_d[None, :] + o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] + + k = tl.load(k_ptrs, mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim), other=0.0) + tl.store(o_ptrs, k, mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim)) + return + + +@torch.no_grad() +def destindex_copy_kv(K, DestLoc, Out): + seq_len = DestLoc.shape[0] + head_num = K.shape[1] + head_dim = K.shape[2] + assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2] + BLOCK_HEAD = triton.next_power_of_2(head_num) + BLOCK_DMODEL = triton.next_power_of_2(head_dim) + grid = (seq_len,) + num_warps = 1 + + _fwd_kernel_destindex_copy_kv[grid]( + K, + DestLoc, + Out, + K.stride(0), + K.stride(1), + K.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + head_num, + head_dim, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_HEAD=BLOCK_HEAD, + num_warps=num_warps, + num_stages=1, + ) + return + + +@triton.jit +def _fwd_kernel_destindex_copy_quantize_kv( + K, + Dest_loc, + Out, + Out_scale, + stride_k_bs, + stride_k_h, + stride_k_d, + stride_o_bs, + stride_o_h, + stride_o_d, + stride_os_bs, + stride_os_h, + stride_os_d, + head_num, + head_dim, + BLOCK_DMODEL: tl.constexpr, + BLOCK_HEAD: tl.constexpr, +): + cur_index = tl.program_id(0) + offs_h = tl.arange(0, BLOCK_HEAD) + offs_d = tl.arange(0, BLOCK_DMODEL) + + dest_index = tl.load(Dest_loc + cur_index) + src_data = tl.load( + K + cur_index * stride_k_bs + offs_h[:, None] * stride_k_h + stride_k_d * offs_d[None, :], + mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim), + other=0.0, + ) + abs_data = tl.abs(src_data) + data_scale = (tl.max(abs_data, axis=1) / 127.0).to(Out_scale.dtype.element_ty)[:, None] + q_src_data = (src_data / data_scale).to(tl.int8) + o_ptrs = Out + dest_index * stride_o_bs + stride_o_h * offs_h[:, None] + stride_o_d * offs_d[None, :] + os_ptrs = Out_scale + dest_index * stride_os_bs + stride_os_h * offs_h[:, None] + tl.store(o_ptrs, q_src_data, mask=(offs_h[:, None] < head_num) & (offs_d[None, :] < head_dim)) + tl.store(os_ptrs, data_scale, mask=(offs_h[:, None] < head_num)) + + +@torch.no_grad() +def destindex_copy_quantize_kv(K, DestLoc, Out, Out_scale): + seq_len = DestLoc.shape[0] + head_num = K.shape[1] + head_dim = K.shape[2] + assert K.shape[1] == Out.shape[1] and K.shape[2] == Out.shape[2] + BLOCK_HEAD = triton.next_power_of_2(head_num) + BLOCK_DMODEL = triton.next_power_of_2(head_dim) + grid = (seq_len,) + num_warps = 1 + + _fwd_kernel_destindex_copy_quantize_kv[grid]( + K, + DestLoc, + Out, + Out_scale, + K.stride(0), + K.stride(1), + K.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + Out_scale.stride(0), + Out_scale.stride(1), + Out_scale.stride(2), + head_num, + head_dim, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_HEAD=BLOCK_HEAD, + num_warps=num_warps, + num_stages=1, + ) + return + + +def test1(): + import time + + B, N_CTX, H, D = 32, 1024, 12, 96 + dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() + src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() + dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32, device="cuda") + + for _ in range(10): + destindex_copy_kv(src, dest_loc, dest) + torch.cuda.synchronize() + t1 = time.time() + for _ in range(1000): + destindex_copy_kv(src, dest_loc, dest) + torch.cuda.synchronize() + t2 = time.time() + + print("Time cost ", t2 - t1) + print("max ", torch.max(torch.abs(dest - src))) + print("mean ", torch.mean(torch.abs(dest - 
src))) + assert torch.allclose(src, dest, atol=1e-2, rtol=0) + + +def test2(): + import time + + B, N_CTX, H, D = 32, 1024, 12, 96 + src = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda() + dest_loc = torch.arange(0, B * N_CTX, dtype=torch.int32).cuda() + value_dest = torch.randn((B * N_CTX, H, D), dtype=torch.float16).cuda().to(torch.int8) + scale_dest = torch.randn((B * N_CTX, H, 1), dtype=torch.float16).cuda() + + for _ in range(10): + destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) + torch.cuda.synchronize() + t1 = time.time() + for _ in range(1000): + destindex_copy_quantize_kv(src, dest_loc, value_dest, scale_dest) + torch.cuda.synchronize() + t2 = time.time() + + print("Time cost ", t2 - t1) + print("max ", torch.max(torch.abs(value_dest * scale_dest - src))) + print("mean ", torch.mean(torch.abs(value_dest * scale_dest - src))) + cos = torch.nn.CosineSimilarity(0) + print("cos ", cos(src.flatten().to(torch.float32), (value_dest * scale_dest).flatten().to(torch.float32))) + + +if __name__ == "__main__": + test1() + test2() diff --git a/lightllm/models/phi3/triton_kernel/flash_decoding.py b/lightllm/models/phi3/triton_kernel/flash_decoding.py new file mode 100644 index 00000000..94a3daca --- /dev/null +++ b/lightllm/models/phi3/triton_kernel/flash_decoding.py @@ -0,0 +1,39 @@ +import torch + + +def token_decode_attention_flash_decoding(q, infer_state, q_head_num, head_dim, cache_k, cache_v, out=None): + BLOCK_SEQ = 256 + batch_size = infer_state.batch_size + max_len_in_batch = infer_state.max_len_in_batch + calcu_shape1 = (batch_size, q_head_num, head_dim) + + from .flash_decoding_stage1 import flash_decode_stage1 + from .flash_decoding_stage2 import flash_decode_stage2 + + o_tensor = torch.empty_like(q) if out is None else out + + if getattr(infer_state, "mid_o", None) is None: + infer_state.mid_o = torch.empty( + [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1, head_dim], dtype=torch.float32, device="cuda" + ) + infer_state.mid_o_logexpsum = torch.empty( + [batch_size, q_head_num, max_len_in_batch // BLOCK_SEQ + 1], dtype=torch.float32, device="cuda" + ) + + mid_o = infer_state.mid_o + mid_o_logexpsum = infer_state.mid_o_logexpsum + + flash_decode_stage1( + q.view(calcu_shape1), + cache_k, + cache_v, + infer_state.req_manager.req_to_token_indexs, + infer_state.b_req_idx, + infer_state.b_seq_len, + infer_state.max_len_in_batch, + mid_o, + mid_o_logexpsum, + BLOCK_SEQ, + ) + flash_decode_stage2(mid_o, mid_o_logexpsum, infer_state.b_seq_len, o_tensor.view(calcu_shape1), BLOCK_SEQ) + return o_tensor diff --git a/lightllm/models/phi3/triton_kernel/flash_decoding_stage1.py b/lightllm/models/phi3/triton_kernel/flash_decoding_stage1.py new file mode 100644 index 00000000..f6d8b5ab --- /dev/null +++ b/lightllm/models/phi3/triton_kernel/flash_decoding_stage1.py @@ -0,0 +1,162 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage1( + Q, + K, + V, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + stride_req_to_tokens_b, + stride_req_to_tokens_s, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_vbs, + stride_vh, + stride_vd, + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + gqa_group_size, + head_dim, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, + BLOCK_N: 
tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + seq_start_block = tl.program_id(2) + cur_kv_head = cur_head // gqa_group_size + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + cur_batch_req_idx = tl.load(B_req_idx + cur_batch) + cur_batch_start_index = seq_start_block * BLOCK_SEQ + cur_batch_end_index = tl.minimum(cur_batch_seq_len, cur_batch_start_index + BLOCK_SEQ) + + off_q = cur_batch * stride_qbs + cur_head * stride_qh + offs_d + + block_n_size = ( + tl.where( + cur_batch_end_index - cur_batch_start_index <= 0, + 0, + cur_batch_end_index - cur_batch_start_index + BLOCK_N - 1, + ) + // BLOCK_N + ) + + offs_n = cur_batch_start_index + tl.arange(0, BLOCK_N) + + q = tl.load(Q + off_q, mask=offs_d < head_dim, other=0.0) + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + for start_n in range(0, block_n_size, 1): + offs_n_new = start_n * BLOCK_N + offs_n + k_loc = tl.load( + Req_to_tokens + stride_req_to_tokens_b * cur_batch_req_idx + offs_n_new, + mask=offs_n_new < cur_batch_end_index, + other=0, + ) + off_k = k_loc[:, None] * stride_kbs + cur_kv_head * stride_kh + offs_d[None, :] + k = tl.load( + K + off_k, mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < head_dim), other=0.0 + ) + att_value = tl.sum(q[None, :] * k, 1) + att_value *= sm_scale + att_value = tl.where(offs_n_new < cur_batch_end_index, att_value, float("-inf")) + v = tl.load( + V + off_k, mask=(offs_n_new[:, None] < cur_batch_end_index) & (offs_d[None, :] < head_dim), other=0.0 + ) + + cur_max_logic = tl.max(att_value, axis=0) + new_max_logic = tl.maximum(cur_max_logic, max_logic) + + exp_logic = tl.exp(att_value - new_max_logic) + logic_scale = tl.exp(max_logic - new_max_logic) + acc *= logic_scale + acc += tl.sum(exp_logic[:, None] * v, axis=0) + + sum_exp = sum_exp * logic_scale + tl.sum(exp_logic, axis=0) + max_logic = new_max_logic + + need_store = tl.where(block_n_size == 0, 0, 1) + for _ in range(0, need_store, 1): + off_mid_o = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + seq_start_block * stride_mid_os + offs_d + off_mid_o_logexpsum = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + seq_start_block + tl.store(Mid_O + off_mid_o, acc / sum_exp, mask=offs_d < head_dim) + tl.store(Mid_O_LogExpSum + off_mid_o_logexpsum, max_logic + tl.log(sum_exp)) + return + + +@torch.no_grad() +def flash_decode_stage1( + q, k, v, Req_to_tokens, B_req_idx, B_Seqlen, max_len_in_batch, mid_out, mid_out_logsumexp, block_seq +): + BLOCK_SEQ = block_seq + BLOCK_N = 16 + assert BLOCK_SEQ % BLOCK_N == 0 + # shape constraints + Lq, Lk = q.shape[-1], k.shape[-1] + assert Lq == Lk + head_dim = Lq + BLOCK_DMODEL = triton.next_power_of_2(head_dim) + sm_scale = 1.0 / (Lk ** 0.5) + batch, head_num = B_req_idx.shape[0], q.shape[1] + grid = (batch, head_num, triton.cdiv(max_len_in_batch, BLOCK_SEQ)) + gqa_group_size = q.shape[1] // k.shape[1] + + _fwd_kernel_flash_decode_stage1[grid]( + q, + k, + v, + sm_scale, + Req_to_tokens, + B_req_idx, + B_Seqlen, + mid_out, + mid_out_logsumexp, + Req_to_tokens.stride(0), + Req_to_tokens.stride(1), + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + v.stride(0), + v.stride(1), + v.stride(2), + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + mid_out.stride(3), + mid_out_logsumexp.stride(0), + mid_out_logsumexp.stride(1), + mid_out_logsumexp.stride(2), + gqa_group_size, + head_dim, + 
BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + BLOCK_N=BLOCK_N, + num_warps=1, + num_stages=2, + ) + return diff --git a/lightllm/models/phi3/triton_kernel/flash_decoding_stage2.py b/lightllm/models/phi3/triton_kernel/flash_decoding_stage2.py new file mode 100644 index 00000000..a06ee545 --- /dev/null +++ b/lightllm/models/phi3/triton_kernel/flash_decoding_stage2.py @@ -0,0 +1,85 @@ +import torch +import triton +import triton.language as tl + + +@triton.jit +def _fwd_kernel_flash_decode_stage2( + B_Seqlen, + Mid_O, # [batch, head, seq_block_num, head_dim] + Mid_O_LogExpSum, # [batch, head, seq_block_num] + Out, # [batch, head, head_dim] + stride_mid_ob, + stride_mid_oh, + stride_mid_os, + stride_mid_od, + stride_mid_o_eb, + stride_mid_o_eh, + stride_mid_o_es, + stride_obs, + stride_oh, + stride_od, + head_dim, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_batch = tl.program_id(0) + cur_head = tl.program_id(1) + + offs_d = tl.arange(0, BLOCK_DMODEL) + cur_batch_seq_len = tl.load(B_Seqlen + cur_batch) + + block_n_size = tl.where(cur_batch_seq_len <= 0, 0, cur_batch_seq_len + BLOCK_SEQ - 1) // BLOCK_SEQ + + sum_exp = 0.0 + max_logic = -float("inf") + acc = tl.zeros([BLOCK_DMODEL], dtype=tl.float32) + + offs_v = cur_batch * stride_mid_ob + cur_head * stride_mid_oh + offs_d + offs_logic = cur_batch * stride_mid_o_eb + cur_head * stride_mid_o_eh + for block_seq_n in range(0, block_n_size, 1): + tv = tl.load(Mid_O + offs_v + block_seq_n * stride_mid_os, mask=offs_d < head_dim, other=0.0) + tlogic = tl.load(Mid_O_LogExpSum + offs_logic + block_seq_n) + new_max_logic = tl.maximum(tlogic, max_logic) + + old_scale = tl.exp(max_logic - new_max_logic) + acc *= old_scale + exp_logic = tl.exp(tlogic - new_max_logic) + acc += exp_logic * tv + sum_exp = sum_exp * old_scale + exp_logic + max_logic = new_max_logic + + tl.store(Out + cur_batch * stride_obs + cur_head * stride_oh + offs_d, acc / sum_exp, mask=offs_d < head_dim) + return + + +@torch.no_grad() +def flash_decode_stage2(mid_out, mid_out_logexpsum, B_Seqlen, Out, block_seq): + Lk = mid_out.shape[-1] + head_dim = Lk + batch, head_num = mid_out.shape[0], mid_out.shape[1] + BLOCK_DMODEL = triton.next_power_of_2(head_dim) + grid = (batch, head_num) + + _fwd_kernel_flash_decode_stage2[grid]( + B_Seqlen, + mid_out, + mid_out_logexpsum, + Out, + mid_out.stride(0), + mid_out.stride(1), + mid_out.stride(2), + mid_out.stride(3), + mid_out_logexpsum.stride(0), + mid_out_logexpsum.stride(1), + mid_out_logexpsum.stride(2), + Out.stride(0), + Out.stride(1), + Out.stride(2), + head_dim, + BLOCK_SEQ=block_seq, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=4, + num_stages=2, + ) + return diff --git a/lightllm/models/phi3/triton_kernel/rotary_emb.py b/lightllm/models/phi3/triton_kernel/rotary_emb.py new file mode 100755 index 00000000..d0eab854 --- /dev/null +++ b/lightllm/models/phi3/triton_kernel/rotary_emb.py @@ -0,0 +1,217 @@ +import torch + +import triton +import triton.language as tl + + +@triton.jit +def _rotary_kernel( + Q, + K, + Cos, + Sin, + stride_qbs, + stride_qh, + stride_qd, + stride_kbs, + stride_kh, + stride_kd, + stride_cosbs, + stride_cosd, + stride_sinbs, + stride_sind, + max_total_len, + HEAD_Q, + HEAD_K, # N_CTX 代表要计算的上下文长度 + rot_dim, + head_dim, + BLOCK_HEAD: tl.constexpr, + BLOCK_SEQ: tl.constexpr, + BLOCK_DMODEL: tl.constexpr, +): + cur_head_index = tl.program_id(0) + cur_seq_index = tl.program_id(1) + + cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD) + cur_seq_range = cur_seq_index * BLOCK_SEQ + 
tl.arange(0, BLOCK_SEQ) + + # dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) + # dim_range1 = tl.arange(BLOCK_DMODEL // 2, BLOCK_DMODEL) + + dim_range0 = tl.arange(0, BLOCK_DMODEL) + dim_range1 = rot_dim + tl.arange(0, BLOCK_DMODEL) + + off_q0 = ( + cur_seq_range[:, None, None] * stride_qbs + + cur_head_range[None, :, None] * stride_qh + + dim_range0[None, None, :] * stride_qd + ) + off_q1 = ( + cur_seq_range[:, None, None] * stride_qbs + + cur_head_range[None, :, None] * stride_qh + + dim_range1[None, None, :] * stride_qd + ) + + off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd + + q0 = tl.load( + Q + off_q0, + mask=(cur_seq_range[:, None, None] < max_total_len) + & (cur_head_range[None, :, None] < HEAD_Q) + & (dim_range0[None, None, :] < rot_dim), + other=0.0, + ) + q1 = tl.load( + Q + off_q1, + mask=(cur_seq_range[:, None, None] < max_total_len) + & (cur_head_range[None, :, None] < HEAD_Q) + & (dim_range1[None, None, :] < head_dim), + other=0.0, + ) + + cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + + out0 = q0 * cos - q1 * sin + out1 = q0 * sin + q1 * cos + + tl.store( + Q + off_q0, + out0, + mask=(cur_seq_range[:, None, None] < max_total_len) + & (cur_head_range[None, :, None] < HEAD_Q) + & (dim_range0[None, None, :] < rot_dim), + ) + tl.store( + Q + off_q1, + out1, + mask=(cur_seq_range[:, None, None] < max_total_len) + & (cur_head_range[None, :, None] < HEAD_Q) + & (dim_range1[None, None, :] < head_dim), + ) + + off_k0 = ( + cur_seq_range[:, None, None] * stride_kbs + + cur_head_range[None, :, None] * stride_kh + + dim_range0[None, None, :] * stride_kd + ) + off_k1 = ( + cur_seq_range[:, None, None] * stride_kbs + + cur_head_range[None, :, None] * stride_kh + + dim_range1[None, None, :] * stride_kd + ) + + off_dimcos_sin = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd + + k0 = tl.load( + K + off_k0, + mask=(cur_seq_range[:, None, None] < max_total_len) + & (cur_head_range[None, :, None] < HEAD_K) + & (dim_range0[None, None, :] < rot_dim), + other=0.0, + ) + k1 = tl.load( + K + off_k1, + mask=(cur_seq_range[:, None, None] < max_total_len) + & (cur_head_range[None, :, None] < HEAD_K) + & (dim_range1[None, None, :] < head_dim), + other=0.0, + ) + cos = tl.load(Cos + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + sin = tl.load(Sin + off_dimcos_sin, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0) + + out_k0 = k0 * cos - k1 * sin + out_k1 = k0 * sin + k1 * cos + + tl.store( + K + off_k0, + out_k0, + mask=(cur_seq_range[:, None, None] < max_total_len) + & (cur_head_range[None, :, None] < HEAD_K) + & (dim_range0[None, None, :] < rot_dim), + ) + tl.store( + K + off_k1, + out_k1, + mask=(cur_seq_range[:, None, None] < max_total_len) + & (cur_head_range[None, :, None] < HEAD_K) + & (dim_range1[None, None, :] < head_dim), + ) + return + + +@torch.no_grad() +def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.0): + total_len = q.shape[0] + head_num_q, head_num_k = q.shape[1], k.shape[1] + head_dim = int(q.shape[2] * partial_rotary_factor) + rot_dim = head_dim // 2 + assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}" + assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}" + + BLOCK_SEQ = 16 + 
BLOCK_HEAD = 4 + if head_dim >= 128: + num_warps = 8 + else: + num_warps = 4 + BLOCK_DMODEL = triton.next_power_of_2(rot_dim) + grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ)) + _rotary_kernel[grid]( + q, + k, + cos, + sin, + q.stride(0), + q.stride(1), + q.stride(2), + k.stride(0), + k.stride(1), + k.stride(2), + cos.stride(0), + cos.stride(1), + sin.stride(0), + sin.stride(1), + total_len, + head_num_q, + head_num_k, + rot_dim, + head_dim, + BLOCK_HEAD=BLOCK_HEAD, + BLOCK_SEQ=BLOCK_SEQ, + BLOCK_DMODEL=BLOCK_DMODEL, + num_warps=num_warps, + num_stages=1, + ) + return + + +def torch_rotary_emb(x, cos, sin): + seq_len, h, dim = x.shape + # dim = dim // 4 + x0 = x[:, :, 0 : dim // 2] + x1 = x[:, :, dim // 2 : dim] + cos = cos.view((seq_len, 1, dim // 2)) + sin = sin.view((seq_len, 1, dim // 2)) + o0 = x0 * cos - x1 * sin + o1 = x0 * sin + x1 * cos + return torch.cat((o0, o1), dim=-1) + + +def test_rotary_emb(SEQ_LEN, H, D, dtype, eps=1e-5, device="cuda"): + # create data + x_shape = (SEQ_LEN, H, D) + x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") + y = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda") + cos_shape = (SEQ_LEN, D // 2) + cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda") + # forward pass + y_tri = torch_rotary_emb(x, cos, sin) + rotary_emb_fwd(y, x, cos, sin) + y_ref = x + + # compare + print("type:", y_tri.dtype, y_ref.dtype) + print("max delta:", torch.max(torch.abs(y_tri - y_ref))) + assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0) diff --git a/lightllm/server/router/model_infer/mode_backend/base_backend.py b/lightllm/server/router/model_infer/mode_backend/base_backend.py index 29895d96..ca1f069d 100644 --- a/lightllm/server/router/model_infer/mode_backend/base_backend.py +++ b/lightllm/server/router/model_infer/mode_backend/base_backend.py @@ -38,6 +38,7 @@ from lightllm.models.qwen_vl.model import QWenVLTpPartModel from lightllm.models.internlm_xcomposer.model import InternlmComposerTpPartModel from lightllm.models.gemma_2b.model import Gemma_2bTpPartModel +from lightllm.models.phi3.model import Phi3TpPartModel from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end from lightllm.utils.log_utils import init_logger @@ -174,6 +175,8 @@ def init_model(self, kvargs): self.model = Gemma_2bTpPartModel(model_kvargs) elif self.model_type == "cohere": self.model = CohereTpPartModel(model_kvargs) + elif self.model_type == "phi3": + self.model = Phi3TpPartModel(model_kvargs) else: raise Exception(f"can not support {self.model_type} now") except Exception as e:
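Note on the last hunk: the backend selects the model implementation from the checkpoint's `model_type` (taken from its Hugging Face config), so a directory whose `config.json` declares `"model_type": "phi3"` is now routed to `Phi3TpPartModel`. The sketch below is an illustrative reduction of that dispatch, not the backend's actual code path; `select_model_class` and the config-reading detail are assumptions made for the example:

```python
import json
import os

from lightllm.models.llama.model import LlamaTpPartModel
from lightllm.models.phi3.model import Phi3TpPartModel


def select_model_class(weight_dir: str):
    # Hypothetical helper: read the Hugging Face config shipped with the
    # checkpoint and branch on its model_type, mirroring init_model() above.
    with open(os.path.join(weight_dir, "config.json")) as f:
        model_type = json.load(f).get("model_type", "")
    if model_type == "phi3":
        return Phi3TpPartModel
    if model_type == "llama":
        return LlamaTpPartModel
    raise Exception(f"can not support {model_type} now")
```

Because `Phi3TpPartModel` subclasses `LlamaTpPartModel` and only overrides the transformer weight and layer-infer classes, everything outside the fused qkv/gate-up weight unpacking and the Phi-3 attention kernels is inherited from the Llama implementation unchanged.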