diff --git a/README.md b/README.md
index 25798c39..b1e1b2f9 100644
--- a/README.md
+++ b/README.md
@@ -46,9 +46,10 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram
 - [Qwen-VL-Chat](https://huggingface.co/Qwen/Qwen-VL-Chat)
 - [Llava-7b](https://huggingface.co/liuhaotian/llava-v1.5-7b)
 - [Llava-13b](https://huggingface.co/liuhaotian/llava-v1.5-13b)
-- [Mixtral]()
+- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
 - [Stablelm](https://huggingface.co/stabilityai/stablelm-2-1_6b)
 - [MiniCPM](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16)
+- [CohereForAI](https://huggingface.co/CohereForAI/c4ai-command-r-plus)
 
 > When you start Qwen-7b, you need to set the parameter '--eos_id 151643 --trust_remote_code'.
 
diff --git a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py
index 54deffa2..906a64df 100755
--- a/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py
+++ b/lightllm/common/basemodel/layer_infer/template/transformer_layer_infer_cohere_template.py
@@ -55,7 +55,12 @@ def _get_qkv(
         k = cache_kv[:, 0 : self.tp_k_head_num_, :]
         q = self._q_norm(q, infer_state, layer_weight)
         cache_kv[:, 0 : self.tp_k_head_num_, :] = self._k_norm(k, infer_state, layer_weight)
-        self._rotary_emb_fwd(q, cache_kv, infer_state.position_cos, infer_state.position_sin)
+        self._rotary_emb_fwd(
+            q.view(-1, self.tp_q_head_num_, self.head_dim_),
+            cache_kv[:, 0 : self.tp_k_head_num_, :],
+            infer_state.position_cos,
+            infer_state.position_sin,
+        )
         return q, cache_kv
 
     def _context_attention_kernel(self, q, kv, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor:
diff --git a/lightllm/models/cohere/layer_infer/post_layer_infer.py b/lightllm/models/cohere/layer_infer/post_layer_infer.py
index 2c9a9cf2..0ce12ad5 100644
--- a/lightllm/models/cohere/layer_infer/post_layer_infer.py
+++ b/lightllm/models/cohere/layer_infer/post_layer_infer.py
@@ -4,7 +4,7 @@
 from lightllm.models.cohere.infer_struct import CohereInferStateInfo
 from lightllm.models.cohere.layer_weights.pre_and_post_layer_weight import CoherePreAndPostLayerWeight
-from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, multi_head_layernorm_forward
+from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward
 from lightllm.common.basemodel.layer_weights.base_layer_weight import BaseLayerWeight
 from lightllm.common.basemodel.splitfuse_infer_struct import SplitFuseInferStateInfo
@@ -22,7 +22,9 @@ def __init__(self, tp_rank, world_size, network_config, mode):
         return
 
     def _norm(self, input, infer_state, layer_weight: CoherePreAndPostLayerWeight) -> torch.Tensor:
-        return layernorm_forward(input, layer_weight.final_norm_weight_, eps=self.eps_)
+        return layernorm_forward(
+            input.unsqueeze(1), layer_weight.final_norm_weight_.unsqueeze(0), eps=self.eps_
+        ).squeeze(1)
 
     def _slice_get_last_input(self, input_embdings, infer_state: CohereInferStateInfo):
         if infer_state.is_splitfuse:
diff --git a/lightllm/models/cohere/layer_infer/transformer_layer_infer.py b/lightllm/models/cohere/layer_infer/transformer_layer_infer.py
index 22738bcc..2b255ffd 100644
--- a/lightllm/models/cohere/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/cohere/layer_infer/transformer_layer_infer.py
@@ -6,9 +6,9 @@
 )
 from lightllm.models.cohere.infer_struct import CohereInferStateInfo
 from lightllm.models.cohere.layer_weights.transformer_layer_weight import CohereTransformerLayerWeight
-from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, multi_head_layernorm_forward
+from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, torch_layernorm
+from lightllm.models.cohere.triton_kernels.rotary_emb import rotary_emb_fwd
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
-from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd
@@ -42,13 +42,13 @@ def _bind_rotary_emb_fwd(self):
         self._rotary_emb_fwd = partial(CohereTransformerLayerInfer._rotary_emb_fwd, self)
 
     def _att_norm(self, input, infer_state, layer_weight):
-        return layernorm_forward(input, layer_weight.att_norm_weight_, self.eps_)
+        return layernorm_forward(input.unsqueeze(1), layer_weight.att_norm_weight_.unsqueeze(0), self.eps_).squeeze(1)
 
     def _q_norm(self, input, infer_state, layer_weight):
-        return multi_head_layernorm_forward(input, layer_weight.q_norm_weight_, self.eps_)
+        return layernorm_forward(input, layer_weight.q_norm_weight_, self.eps_)
 
     def _k_norm(self, input, infer_state, layer_weight):
-        return multi_head_layernorm_forward(input, layer_weight.k_norm_weight_, self.eps_)
+        return layernorm_forward(input, layer_weight.k_norm_weight_, self.eps_)
 
     def _bind_norm(self):
         self._att_norm = partial(CohereTransformerLayerInfer._att_norm, self)
diff --git a/lightllm/models/cohere/model.py b/lightllm/models/cohere/model.py
index 7bb0ae23..e6cffe76 100644
--- a/lightllm/models/cohere/model.py
+++ b/lightllm/models/cohere/model.py
@@ -1,7 +1,10 @@
+import os
+import torch
 from lightllm.common.basemodel.basemodel import TpPartBaseModel
 from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_cohere_template import (
     TransformerLayerCohereInferTpl,
 )
+from lightllm.common.mem_manager import MemoryManager
 from lightllm.models.cohere.infer_struct import CohereInferStateInfo
 from lightllm.models.cohere.layer_infer.post_layer_infer import CoherePostLayerInfer
 from lightllm.models.cohere.layer_infer.transformer_layer_infer import CohereTransformerLayerInfer
@@ -10,6 +13,9 @@
 from lightllm.models.cohere.splitfuse_infer_struct import CohereSplitFuseInferStateInfo
 from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
 from lightllm.models.llama.model import LlamaTpPartModel
+from lightllm.utils.log_utils import init_logger
+
+logger = init_logger(__name__)
 
 
 class CohereTpPartModel(LlamaTpPartModel):
@@ -22,3 +28,42 @@ class CohereTpPartModel(LlamaTpPartModel):
 
     infer_state_class = CohereInferStateInfo
     splitfuse_infer_state_class = CohereSplitFuseInferStateInfo
+
+    def _init_to_get_rotary(self, default_base=10000):
+        partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
+        if self.config.get("rope_scaling", {}) is None:
+            rope_scaling_factor = 1.0
+        else:
+            rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)
+
+        base = self.config.get("rope_theta", float(default_base))
+
+        if "max_sequence_length" in self.config:
+            max_seq_len = self.config["max_sequence_length"]
+        else:
+            max_position_embeddings = self.config.get(
+                "max_position_embeddings", 2048 if base <= 10000.0 + 1e-5 else 16384
+            )
+            max_seq_len = max_position_embeddings * rope_scaling_factor
+
+        # NTK
+        try:
+            ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
+            assert ntk_alpha >= 1
+            if ntk_alpha > 1:
+                logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
+            max_seq_len *= ntk_alpha
+            base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2)))  # Base change formula
+        except:
+            pass
+
+        inv_freq = 1.0 / (
+            base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
+        )
+        t = torch.arange(max_seq_len + 1024 * 128, device="cpu", dtype=torch.float32) / rope_scaling_factor
+        freqs = torch.outer(t, inv_freq)
+        freqs = torch.repeat_interleave(freqs, 2, dim=-1)
+
+        self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
+        self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
+        return
diff --git a/lightllm/models/cohere/triton_kernels/layernorm.py b/lightllm/models/cohere/triton_kernels/layernorm.py
index e6008432..c1d5ff4c 100644
--- a/lightllm/models/cohere/triton_kernels/layernorm.py
+++ b/lightllm/models/cohere/triton_kernels/layernorm.py
@@ -1,11 +1,109 @@
 import torch
+import triton
+import triton.language as tl
 
 
+# LayerNorm adapted from the Triton tutorial; used for the Cohere q/k norm
+# X [N, head_num, head_dim]
+# W [head_num, head_dim]
+@triton.jit
+def _layer_norm_fwd_kernel(
+    X,  # pointer to the input
+    W,  # pointer to the weights
+    Y,  # pointer to the output
+    stride_x_N,
+    stride_x_hn,
+    stride_x_hd,
+    stride_y_N,
+    stride_y_hn,
+    stride_y_hd,
+    stride_w_hn,
+    stride_w_hd,
+    N,  # number of columns in X
+    eps,  # epsilon to avoid division by zero
+    BLOCK_SIZE: tl.constexpr,
+):
+    Seq = tl.program_id(0)
+    H = tl.program_id(1)
-def layernorm_forward(x, weight, eps):
-    return torch.layer_norm(x, (x.shape[-1],), weight, bias=None, eps=eps)
+    X += Seq * stride_x_N + H * stride_x_hn
+    Y += Seq * stride_y_N + H * stride_y_hn
+    W += H * stride_w_hn
 
+    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+        _mean += a
+    mean = tl.sum(_mean, axis=0) / N
 
-def multi_head_layernorm_forward(x, weight, eps):
+    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
+        x = tl.where(cols < N, x - mean, 0.0)
+        _var += x * x
+    var = tl.sum(_var, axis=0) / N
+    rstd = 1 / tl.sqrt(var + eps)
+
+    for off in range(0, N, BLOCK_SIZE):
+        cols = off + tl.arange(0, BLOCK_SIZE)
+        mask = cols < N
+        w = tl.load(W + cols, mask=mask).to(tl.float32)
+        x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
+        x_hat = (x - mean) * rstd
+        y = x_hat * w
+
+        tl.store(Y + cols, y.to(X.dtype.element_ty), mask=mask)
+
+
+def layernorm_forward(
+    X,  # input tensor of shape [N, head_num, head_dim]
+    W,  # per-head weight of shape [head_num, head_dim]
+    eps,  # epsilon to avoid division by zero
+):
+    assert len(X.shape) == 3
+    assert len(W.shape) == 2
+    assert X.shape[-1] == W.shape[-1]
+    assert X.shape[-2] == W.shape[-2]
+
+    y = torch.empty_like(X)
+
+    stride_x_N = X.stride(0)
+    stride_x_hn = X.stride(1)
+    stride_x_hd = X.stride(2)
+
+    stride_y_N = y.stride(0)
+    stride_y_hn = y.stride(1)
+    stride_y_hd = y.stride(2)
+
+    stride_w_hn = W.stride(0)
+    stride_w_hd = W.stride(1)
+
+    N = X.shape[-1]
+    BLOCK_SIZE = 128
+
+    grid = (X.shape[0], X.shape[1])
+    _layer_norm_fwd_kernel[grid](
+        X,
+        W,
+        y,
+        stride_x_N,
+        stride_x_hn,
+        stride_x_hd,
+        stride_y_N,
+        stride_y_hn,
+        stride_y_hd,
+        stride_w_hn,
+        stride_w_hd,
+        N,
+        eps,
+        BLOCK_SIZE,
+    )
+
+    return y
+
+
+def torch_layernorm(x, weight, eps):
     inp_dtype = x.dtype
     x = x.to(torch.float32)
     mean = x.mean(-1, keepdim=True)
@@ -13,3 +111,21 @@ def multi_head_layernorm_forward(x, weight, eps):
     x = (x - mean) * torch.rsqrt(variance + eps)
     x = weight.to(torch.float32) * x
     return x.to(inp_dtype)
+
+
+def test_layernorm(eps=1e-5):
+    # create data
+    dtype = torch.float16
+    x_shape = (5, 1, 128)
+    w_shape = (x_shape[-2], x_shape[-1])
+    weight = torch.rand(w_shape, dtype=dtype, device="cuda")
+    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
+    # forward pass
+    y_ref = torch_layernorm(x, weight, eps).to(dtype)
+    y_out = layernorm_forward(x, weight, eps)
+
+    # compare
+    print("type:", y_out.dtype, y_ref.dtype)
+    print("max delta:", torch.max(torch.abs(y_out - y_ref)))
+    assert torch.allclose(y_out, y_ref, atol=1e-2, rtol=0)
+    return
diff --git a/lightllm/models/cohere/triton_kernels/rotary_emb.py b/lightllm/models/cohere/triton_kernels/rotary_emb.py
new file mode 100644
index 00000000..ac338e71
--- /dev/null
+++ b/lightllm/models/cohere/triton_kernels/rotary_emb.py
@@ -0,0 +1,199 @@
+import torch
+
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def _rotary_kernel(
+    Q,
+    K,
+    Cos,
+    Sin,
+    stride_qbs,
+    stride_qh,
+    stride_qd,
+    stride_kbs,
+    stride_kh,
+    stride_kd,
+    stride_cosbs,
+    stride_cosd,
+    stride_sinbs,
+    stride_sind,
+    max_total_len,
+    HEAD_Q,
+    HEAD_K,  # N_CTX denotes the context length to be computed
+    BLOCK_HEAD: tl.constexpr,
+    BLOCK_SEQ: tl.constexpr,
+    BLOCK_DMODEL: tl.constexpr,
+):
+    cur_head_index = tl.program_id(0)
+    cur_seq_index = tl.program_id(1)
+
+    cur_head_range = cur_head_index * BLOCK_HEAD + tl.arange(0, BLOCK_HEAD)
+    cur_seq_range = cur_seq_index * BLOCK_SEQ + tl.arange(0, BLOCK_SEQ)
+
+    dim_range0 = tl.arange(0, BLOCK_DMODEL // 2) * 2
+    dim_range1 = tl.arange(0, BLOCK_DMODEL // 2) * 2 + 1
+
+    off_q0 = (
+        cur_seq_range[:, None, None] * stride_qbs
+        + cur_head_range[None, :, None] * stride_qh
+        + dim_range0[None, None, :] * stride_qd
+    )
+    off_q1 = (
+        cur_seq_range[:, None, None] * stride_qbs
+        + cur_head_range[None, :, None] * stride_qh
+        + dim_range1[None, None, :] * stride_qd
+    )
+
+    off_dimcos_sin0 = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd
+    off_dimcos_sin1 = cur_seq_range[:, None, None] * stride_cosbs + dim_range1[None, None, :] * stride_cosd
+
+    q0 = tl.load(
+        Q + off_q0,
+        mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q),
+        other=0.0,
+    )
+    q1 = tl.load(
+        Q + off_q1,
+        mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q),
+        other=0.0,
+    )
+
+    cos0 = tl.load(Cos + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)
+    sin0 = tl.load(Sin + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)
+
+    cos1 = tl.load(Cos + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)
+    sin1 = tl.load(Sin + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)
+
+    out0 = q0 * cos0 - q1 * sin0
+    out1 = q0 * sin1 + q1 * cos1
+
+    tl.store(
+        Q + off_q0, out0, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q)
+    )
+    tl.store(
+        Q + off_q1, out1, mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_Q)
+    )
+
+    off_k0 = (
+        cur_seq_range[:, None, None] * stride_kbs
+        + cur_head_range[None, :, None] * stride_kh
+        + dim_range0[None, None, :] * stride_kd
+    )
+    off_k1 = (
+        cur_seq_range[:, None, None] * stride_kbs
+        + cur_head_range[None, :, None] * stride_kh
+        + dim_range1[None, None, :] * stride_kd
+    )
+
+    off_dimcos_sin0 = cur_seq_range[:, None, None] * stride_cosbs + dim_range0[None, None, :] * stride_cosd
+    off_dimcos_sin1 = cur_seq_range[:, None, None] * stride_cosbs + dim_range1[None, None, :] * stride_cosd
+
+    k0 = tl.load(
+        K + off_k0,
+        mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),
+        other=0.0,
+    )
+    k1 = tl.load(
+        K + off_k1,
+        mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),
+        other=0.0,
+    )
+
+    cos0 = tl.load(Cos + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)
+    sin0 = tl.load(Sin + off_dimcos_sin0, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)
+
+    cos1 = tl.load(Cos + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)
+    sin1 = tl.load(Sin + off_dimcos_sin1, mask=cur_seq_range[:, None, None] < max_total_len, other=0.0)
+
+    out_k0 = k0 * cos0 - k1 * sin0
+    out_k1 = k0 * sin1 + k1 * cos1
+
+    tl.store(
+        K + off_k0,
+        out_k0,
+        mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),
+    )
+    tl.store(
+        K + off_k1,
+        out_k1,
+        mask=(cur_seq_range[:, None, None] < max_total_len) & (cur_head_range[None, :, None] < HEAD_K),
+    )
+    return
+
+
+def torch_cohere_rotary_emb(x, cos, sin):
+    dtype = x.dtype
+    seq_len, h, dim = x.shape
+    x = x.float()
+    x1 = x[:, :, ::2]
+    x2 = x[:, :, 1::2]
+    rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
+    cos = cos.view((seq_len, 1, dim))
+    sin = sin.view((seq_len, 1, dim))
+    o = (x * cos) + (rot_x * sin)
+    return o.to(dtype=dtype)
+
+
+@torch.no_grad()
+def rotary_emb_fwd(q, k, cos, sin, partial_rotary_factor=1.0):
+    total_len = q.shape[0]
+    head_num_q, head_num_k = q.shape[1], k.shape[1]
+    head_dim = int(q.shape[2] * partial_rotary_factor)
+    assert q.shape[0] == cos.shape[0] and q.shape[0] == sin.shape[0], f"q shape {q.shape} cos shape {cos.shape}"
+    assert k.shape[0] == cos.shape[0] and k.shape[0] == sin.shape[0], f"k shape {k.shape} cos shape {cos.shape}"
+
+    BLOCK_SEQ = 16
+    BLOCK_HEAD = 4
+    if head_dim >= 128:
+        num_warps = 8
+    else:
+        num_warps = 4
+
+    grid = (triton.cdiv(head_num_q, BLOCK_HEAD), triton.cdiv(total_len, BLOCK_SEQ))
+    _rotary_kernel[grid](
+        q,
+        k,
+        cos,
+        sin,
+        q.stride(0),
+        q.stride(1),
+        q.stride(2),
+        k.stride(0),
+        k.stride(1),
+        k.stride(2),
+        cos.stride(0),
+        cos.stride(1),
+        sin.stride(0),
+        sin.stride(1),
+        total_len,
+        head_num_q,
+        head_num_k,
+        BLOCK_HEAD=BLOCK_HEAD,
+        BLOCK_SEQ=BLOCK_SEQ,
+        BLOCK_DMODEL=head_dim,
+        num_warps=num_warps,
+        num_stages=1,
+    )
+    return
+
+
+def test_rotary_emb(SEQ_LEN, H, D, dtype, eps=1e-5, device="cuda"):
+    # create data
+    x_shape = (SEQ_LEN, H, D)
+    x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
+    y = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
+    cos_shape = (SEQ_LEN, D)
+    cos = -1.2 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+    sin = -2.0 + 0.5 * torch.randn(cos_shape, dtype=dtype, device="cuda")
+    # forward pass: rotary_emb_fwd rotates q (x) and k (y) in place
+    y_ref = torch_cohere_rotary_emb(x, cos, sin)
+    rotary_emb_fwd(x, y, cos, sin)
+    y_tri = x
+
+    # compare the Triton kernel output against the torch reference
+    print("type:", y_tri.dtype, y_ref.dtype)
+    print("max delta:", torch.max(torch.abs(y_tri - y_ref)))
+    assert torch.allclose(y_tri, y_ref, atol=1e-2, rtol=0)
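
For reference, a minimal usage sketch (not part of the patch) of how the two new Cohere kernels above are expected to compose, assuming a CUDA device with Triton available and illustrative sizes head_num=8, head_dim=128: activations are laid out as [num_tokens, head_num, head_dim], the q/k norm weights as [head_num, head_dim], and cos/sin follow the repeat_interleave layout built in `_init_to_get_rotary`.

```python
# Sketch only: drives the kernels added in this patch on random data.
# Assumes the patch is applied and a CUDA device is available; sizes are illustrative.
import torch

from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, torch_layernorm
from lightllm.models.cohere.triton_kernels.rotary_emb import rotary_emb_fwd, torch_cohere_rotary_emb

num_tokens, head_num, head_dim = 16, 8, 128
q = torch.randn(num_tokens, head_num, head_dim, dtype=torch.float16, device="cuda")
k = torch.randn(num_tokens, head_num, head_dim, dtype=torch.float16, device="cuda")
w = torch.rand(head_num, head_dim, dtype=torch.float16, device="cuda")

# Per-head LayerNorm without bias: one (mean, variance) pair per token and per head,
# matching the Cohere q/k norm convention (weights are [head_num, head_dim]).
q_normed = layernorm_forward(q, w, eps=1e-5)
assert torch.allclose(q_normed, torch_layernorm(q, w, 1e-5), atol=1e-2, rtol=0)

# Interleaved (Cohere-style) RoPE: cos/sin are repeat_interleaved over even/odd pairs,
# mirroring _init_to_get_rotary above; rotary_emb_fwd rotates q and k in place.
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
t = torch.arange(num_tokens, dtype=torch.float32)
freqs = torch.repeat_interleave(torch.outer(t, inv_freq), 2, dim=-1)
cos = freqs.cos().to(torch.float16).cuda()
sin = freqs.sin().to(torch.float16).cuda()

q_ref = torch_cohere_rotary_emb(q_normed, cos, sin)
rotary_emb_fwd(q_normed, k, cos, sin)  # in-place update of q_normed and k
assert torch.allclose(q_normed, q_ref, atol=1e-2, rtol=0)
```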