add high performance layernorm triton kernels. (#432)
flyinglandlord authored Jun 13, 2024
1 parent 62c006c commit c8160a4
Showing 7 changed files with 380 additions and 12 deletions.
3 changes: 2 additions & 1 deletion README.md
@@ -46,9 +46,10 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram
- [Qwen-VL-Chat](https://huggingface.co/Qwen/Qwen-VL-Chat)
- [Llava-7b](https://huggingface.co/liuhaotian/llava-v1.5-7b)
- [Llava-13b](https://huggingface.co/liuhaotian/llava-v1.5-13b)
- [Mixtral]()
- [Mixtral](https://huggingface.co/mistralai/Mixtral-8x7B-Instruct-v0.1)
- [Stablelm](https://huggingface.co/stabilityai/stablelm-2-1_6b)
- [MiniCPM](https://huggingface.co/openbmb/MiniCPM-2B-sft-bf16)
- [CohereForAI](https://huggingface.co/CohereForAI/c4ai-command-r-plus)

> When you start Qwen-7b, you need to set the parameter '--eos_id 151643 --trust_remote_code'.
@@ -55,7 +55,12 @@ def _get_qkv(
k = cache_kv[:, 0 : self.tp_k_head_num_, :]
q = self._q_norm(q, infer_state, layer_weight)
cache_kv[:, 0 : self.tp_k_head_num_, :] = self._k_norm(k, infer_state, layer_weight)
self._rotary_emb_fwd(q, cache_kv, infer_state.position_cos, infer_state.position_sin)
self._rotary_emb_fwd(
q.view(-1, self.tp_q_head_num_, self.head_dim_),
cache_kv[:, 0 : self.tp_k_head_num_, :],
infer_state.position_cos,
infer_state.position_sin,
)
return q, cache_kv

def _context_attention_kernel(self, q, kv, infer_state: InferStateInfo, layer_weight, out=None) -> torch.Tensor:
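The only functional change in this hunk is that q is passed to the rotary kernel as an explicit per-head view instead of a flat tensor, matching the [tokens, head_num, head_dim] layout the per-head norms use. A minimal sketch of what that view does (shapes are made up, and that q is flat at this point is an assumption; the hunk does not show its prior shape):

import torch

tokens, heads, head_dim = 4, 8, 128
q_flat = torch.randn(tokens, heads * head_dim)

# .view only reinterprets the memory layout; no copy is made, so in-place
# rotary updates through the view are also visible in the original tensor.
q_per_head = q_flat.view(-1, heads, head_dim)
assert q_per_head.data_ptr() == q_flat.data_ptr()
assert q_per_head.shape == (tokens, heads, head_dim)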
6 changes: 4 additions & 2 deletions lightllm/models/cohere/layer_infer/post_layer_infer.py
@@ -4,7 +4,7 @@

from lightllm.models.cohere.infer_struct import CohereInferStateInfo
from lightllm.models.cohere.layer_weights.pre_and_post_layer_weight import CoherePreAndPostLayerWeight
from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, multi_head_layernorm_forward
from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward
from lightllm.common.basemodel.layer_weights.base_layer_weight import BaseLayerWeight
from lightllm.common.basemodel.splitfuse_infer_struct import SplitFuseInferStateInfo

@@ -22,7 +22,9 @@ def __init__(self, tp_rank, world_size, network_config, mode):
return

def _norm(self, input, infer_state, layer_weight: CoherePreAndPostLayerWeight) -> torch.Tensor:
return layernorm_forward(input, layer_weight.final_norm_weight_, eps=self.eps_)
return layernorm_forward(
input.unsqueeze(1), layer_weight.final_norm_weight_.unsqueeze(0), eps=self.eps_
).squeeze(1)

def _slice_get_last_input(self, input_embdings, infer_state: CohereInferStateInfo):
if infer_state.is_splitfuse:
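The unsqueeze/squeeze pair above only adapts the 2-D final-norm case to the new kernel's [N, head_num, head_dim] input and [head_num, head_dim] weight contract, with head_num = 1. A minimal equivalence sketch against the previous torch.layer_norm path, assuming a CUDA device and this commit's kernel installed:

import torch
from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward

tokens, hidden, eps = 16, 4096, 1e-5
x = torch.randn(tokens, hidden, dtype=torch.float16, device="cuda")
w = torch.randn(hidden, dtype=torch.float16, device="cuda")

# Old path: bias-free layernorm over the last dimension.
ref = torch.layer_norm(x, (hidden,), w, None, eps)

# New path: reuse the per-head Triton kernel with a single pseudo-head.
out = layernorm_forward(x.unsqueeze(1), w.unsqueeze(0), eps=eps).squeeze(1)

print("max abs diff:", (out - ref).abs().max().item())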
10 changes: 5 additions & 5 deletions lightllm/models/cohere/layer_infer/transformer_layer_infer.py
@@ -6,9 +6,9 @@
)
from lightllm.models.cohere.infer_struct import CohereInferStateInfo
from lightllm.models.cohere.layer_weights.transformer_layer_weight import CohereTransformerLayerWeight
from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, multi_head_layernorm_forward
from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, torch_layernorm
from lightllm.models.cohere.triton_kernels.rotary_emb import rotary_emb_fwd
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
from lightllm.models.llama.triton_kernel.silu_and_mul import silu_and_mul_fwd


@@ -42,13 +42,13 @@ def _bind_rotary_emb_fwd(self):
self._rotary_emb_fwd = partial(CohereTransformerLayerInfer._rotary_emb_fwd, self)

def _att_norm(self, input, infer_state, layer_weight):
return layernorm_forward(input, layer_weight.att_norm_weight_, self.eps_)
return layernorm_forward(input.unsqueeze(1), layer_weight.att_norm_weight_.unsqueeze(0), self.eps_).squeeze(1)

def _q_norm(self, input, infer_state, layer_weight):
return multi_head_layernorm_forward(input, layer_weight.q_norm_weight_, self.eps_)
return layernorm_forward(input, layer_weight.q_norm_weight_, self.eps_)

def _k_norm(self, input, infer_state, layer_weight):
return multi_head_layernorm_forward(input, layer_weight.k_norm_weight_, self.eps_)
return layernorm_forward(input, layer_weight.k_norm_weight_, self.eps_)

def _bind_norm(self):
self._att_norm = partial(CohereTransformerLayerInfer._att_norm, self)
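For the q/k norms, layernorm_forward computes a bias-free layernorm over head_dim independently for every (token, head) pair, scaled by that head's own weight row. A plain-torch reference for that contract, mirroring torch_layernorm from the kernel file (shapes are made up):

import torch

def per_head_layernorm_ref(x, weight, eps=1e-5):
    # x: [tokens, head_num, head_dim], weight: [head_num, head_dim]
    x32 = x.float()
    mean = x32.mean(-1, keepdim=True)
    var = (x32 - mean).pow(2).mean(-1, keepdim=True)
    y = (x32 - mean) * torch.rsqrt(var + eps)
    return (weight.float() * y).to(x.dtype)  # each head's weight row broadcasts over tokens

tokens, heads, head_dim = 4, 8, 128
x = torch.randn(tokens, heads, head_dim, dtype=torch.float16)
w = torch.randn(heads, head_dim, dtype=torch.float16)
assert per_head_layernorm_ref(x, w).shape == (tokens, heads, head_dim)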
45 changes: 45 additions & 0 deletions lightllm/models/cohere/model.py
@@ -1,7 +1,10 @@
import os
import torch
from lightllm.common.basemodel.basemodel import TpPartBaseModel
from lightllm.common.basemodel.layer_infer.template.transformer_layer_infer_cohere_template import (
TransformerLayerCohereInferTpl,
)
from lightllm.common.mem_manager import MemoryManager
from lightllm.models.cohere.infer_struct import CohereInferStateInfo
from lightllm.models.cohere.layer_infer.post_layer_infer import CoherePostLayerInfer
from lightllm.models.cohere.layer_infer.transformer_layer_infer import CohereTransformerLayerInfer
@@ -10,6 +13,9 @@
from lightllm.models.cohere.splitfuse_infer_struct import CohereSplitFuseInferStateInfo
from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer
from lightllm.models.llama.model import LlamaTpPartModel
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class CohereTpPartModel(LlamaTpPartModel):
@@ -22,3 +28,42 @@ class CohereTpPartModel(LlamaTpPartModel):

infer_state_class = CohereInferStateInfo
splitfuse_infer_state_class = CohereSplitFuseInferStateInfo

def _init_to_get_rotary(self, default_base=10000):
partial_head_dim = int(self.config.get("partial_rotary_factor", 1) * self.head_dim_)
if self.config.get("rope_scaling", {}) is None:
rope_scaling_factor = 1.0
else:
rope_scaling_factor = self.config.get("rope_scaling", {}).get("factor", 1.0)

base = self.config.get("rope_theta", float(default_base))

if "max_sequence_length" in self.config:
max_seq_len = self.config["max_sequence_length"]
else:
max_position_embeddings = self.config.get(
"max_position_embeddings", 2048 if base <= 10000.0 + 1e-5 else 16384
)
max_seq_len = max_position_embeddings * rope_scaling_factor

# NTK
try:
ntk_alpha = float(os.environ.get("LIGHTLLM_NTK_ALPHA", 1))
assert ntk_alpha >= 1
if ntk_alpha > 1:
logger.info(f"Note: NTK enabled, alpha set to {ntk_alpha}")
max_seq_len *= ntk_alpha
base = base * (ntk_alpha ** (partial_head_dim / (partial_head_dim - 2))) # Base change formula
except:
pass

inv_freq = 1.0 / (
base ** (torch.arange(0, partial_head_dim, 2, device="cpu", dtype=torch.float32) / partial_head_dim)
)
t = torch.arange(max_seq_len + 1024 * 128, device="cpu", dtype=torch.float32) / rope_scaling_factor
freqs = torch.outer(t, inv_freq)
freqs = torch.repeat_interleave(freqs, 2, dim=-1)

self._cos_cached = torch.cos(freqs).to(self.data_type).cuda()
self._sin_cached = torch.sin(freqs).to(self.data_type).cuda()
return
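For a concrete feel of the cos/sin cache this builds, here is the same construction at toy size (head_dim = 8, 16 positions, no rope scaling), plus the NTK base change used when LIGHTLLM_NTK_ALPHA > 1; the numbers are illustrative only:

import torch

head_dim, max_seq_len, base = 8, 16, 10000.0

inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
t = torch.arange(max_seq_len, dtype=torch.float32)
freqs = torch.outer(t, inv_freq)                    # [max_seq_len, head_dim // 2]
freqs = torch.repeat_interleave(freqs, 2, dim=-1)   # duplicate each angle for its (x0, x1) pair

cos_cached = torch.cos(freqs)  # [max_seq_len, head_dim]
sin_cached = torch.sin(freqs)  # [max_seq_len, head_dim]
assert cos_cached.shape == (max_seq_len, head_dim)

# NTK base change: a larger base lowers the rotation frequencies,
# stretching the usable context roughly by ntk_alpha.
ntk_alpha = 2.0
scaled_base = base * ntk_alpha ** (head_dim / (head_dim - 2))
assert scaled_base > base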
122 changes: 119 additions & 3 deletions lightllm/models/cohere/triton_kernels/layernorm.py
@@ -1,15 +1,131 @@
import torch
import triton
import triton.language as tl

# LayerNorm adapted from triton tutorial, used for Cohere q, k norm
# X [N, head_num, head_dim]
# W [head_num, head_dim]
@triton.jit
def _layer_norm_fwd_kernel(
X, # pointer to the input
W, # pointer to the weights
Y,
stride_x_N,
stride_x_hn,
stride_x_hd,
stride_y_N,
stride_y_hn,
stride_y_hd,
stride_w_hn,
stride_w_hd,
N, # number of columns in X
eps, # epsilon to avoid division by zero
BLOCK_SIZE: tl.constexpr,
):
    Seq = tl.program_id(0)
    H = tl.program_id(1)

    X += Seq * stride_x_N + H * stride_x_hn
    Y += Seq * stride_y_N + H * stride_y_hn
    W += H * stride_w_hn

    _mean = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        a = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
        _mean += a
    mean = tl.sum(_mean, axis=0) / N

    _var = tl.zeros([BLOCK_SIZE], dtype=tl.float32)
    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        x = tl.load(X + cols, mask=cols < N, other=0.0).to(tl.float32)
        x = tl.where(cols < N, x - mean, 0.0)
        _var += x * x
    var = tl.sum(_var, axis=0) / N
    rstd = 1 / tl.sqrt(var + eps)

    for off in range(0, N, BLOCK_SIZE):
        cols = off + tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        w = tl.load(W + cols, mask=mask).to(tl.float32)
        x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
        x_hat = (x - mean) * rstd
        y = x_hat * w

        tl.store(Y + cols, y.to(X.dtype.element_ty), mask=mask)


# Removed by this commit (previous torch-based implementations):
def layernorm_forward(x, weight, eps):
    return torch.layer_norm(x, (x.shape[-1],), weight, bias=None, eps=eps)


def multi_head_layernorm_forward(x, weight, eps):


def layernorm_forward(
X, # pointer to the input
W, # pointer to the weights
eps, # epsilon to avoid division by zero
):
assert len(X.shape) == 3
assert len(W.shape) == 2
assert X.shape[-1] == W.shape[-1]
assert X.shape[-2] == W.shape[-2]

y = torch.empty_like(X)

stride_x_N = X.stride(0)
stride_x_hn = X.stride(1)
stride_x_hd = X.stride(2)

stride_y_N = y.stride(0)
stride_y_hn = y.stride(1)
stride_y_hd = y.stride(2)

stride_w_hn = W.stride(0)
stride_w_hd = W.stride(1)

N = X.shape[-1]
BLOCK_SIZE = 128

grid = (X.shape[0], X.shape[1])
_layer_norm_fwd_kernel[grid](
X,
W,
y,
stride_x_N,
stride_x_hn,
stride_x_hd,
stride_y_N,
stride_y_hn,
stride_y_hd,
stride_w_hn,
stride_w_hd,
N,
eps,
BLOCK_SIZE,
)

return y


def torch_layernorm(x, weight, eps):
inp_dtype = x.dtype
x = x.to(torch.float32)
mean = x.mean(-1, keepdim=True)
variance = (x - mean).pow(2).mean(-1, keepdim=True)
x = (x - mean) * torch.rsqrt(variance + eps)
x = weight.to(torch.float32) * x
return x.to(inp_dtype)


def test_layernorm(eps=1e-5):
# create data
dtype = torch.float16
x_shape = (5, 1, 128)
w_shape = (x_shape[-2], x_shape[-1])
weight = torch.rand(w_shape, dtype=dtype, device="cuda")
x = -2.3 + 0.5 * torch.randn(x_shape, dtype=dtype, device="cuda")
# forward pass
y_ref = torch_layernorm(x, weight, eps).to(dtype)
y_out = layernorm_forward(x, weight, eps)

# compare
print("type:", y_out.dtype, y_ref.dtype)
print("max delta:", torch.max(torch.abs(y_out - y_ref)))
assert torch.allclose(y_out, y_ref, atol=1e-2, rtol=0)
return
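A rough way to check the "high performance" claim locally is to time the Triton kernel against the torch reference above. A minimal benchmark sketch, assuming a CUDA GPU, this commit installed, and a Triton version that ships triton.testing.do_bench; the shapes are made up:

import torch
import triton
from lightllm.models.cohere.triton_kernels.layernorm import layernorm_forward, torch_layernorm

tokens, heads, head_dim, eps = 4096, 64, 128, 1e-5
x = torch.randn(tokens, heads, head_dim, dtype=torch.float16, device="cuda")
w = torch.randn(heads, head_dim, dtype=torch.float16, device="cuda")

triton_ms = triton.testing.do_bench(lambda: layernorm_forward(x, w, eps))
torch_ms = triton.testing.do_bench(lambda: torch_layernorm(x, w, eps))
print(f"triton: {triton_ms:.3f} ms  |  torch reference: {torch_ms:.3f} ms")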