From 50d99a62103057fac8ffe94487667607d09acc07 Mon Sep 17 00:00:00 2001
From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
Date: Fri, 1 Dec 2023 15:13:59 +0800
Subject: [PATCH] support ppl fp16 attention (#235)

---
 .../layer_infer/transformer_layer_infer.py | 21 +++++++++++++++++++
 lightllm/server/api_server.py              |  3 ++-
 lightllm/server/router/manager.py          |  2 +-
 3 files changed, 24 insertions(+), 2 deletions(-)

diff --git a/lightllm/models/llama/layer_infer/transformer_layer_infer.py b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
index 596a9b63..dfdb36d5 100644
--- a/lightllm/models/llama/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -39,6 +39,9 @@ def _bind_func(self):
         if "ppl_int8kv" in self.mode:
             self._token_attention_kernel = self._token_decode_attention_ppl_int8kv
             self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_ppl_int8kv
+        elif "ppl_fp16" in self.mode:
+            self._token_attention_kernel = self._token_decode_attention_ppl_fp16
+            self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal
         elif "triton_int8kv" in self.mode:
             self._token_attention_kernel = self._token_decode_attention_int8kv
             self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_int8kv
@@ -273,4 +276,22 @@ def _token_decode_attention_ppl_int8kv(self, q, infer_state: LlamaInferStateInfo
                                        infer_state.b_seq_len,
                                        infer_state.max_len_in_batch)
 
+        return o_tensor
+
+    def _token_decode_attention_ppl_fp16(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
+        batch_size = infer_state.batch_size
+        calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)
+        o_tensor = torch.empty_like(q) if out is None else out
+        from lightllm_ppl_fp16_kernel import fp16_decode_attention
+        # fp16_decode_attention(o, att_scale, q, k, v, req_to_token_indexs, b_req_idx, b_seq_len, max_len_in_batch)
+        fp16_decode_attention(o_tensor.view(calcu_shape1),
+                              1.0 / (self.head_dim_**0.5),
+                              q.view(calcu_shape1),
+                              infer_state.mem_manager.key_buffer[self.layer_num_],
+                              infer_state.mem_manager.value_buffer[self.layer_num_],
+                              infer_state.req_manager.req_to_token_indexs,
+                              infer_state.b_req_idx,
+                              infer_state.b_seq_len,
+                              infer_state.max_len_in_batch)
+
         return o_tensor
\ No newline at end of file
diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py
index f6f1f287..e314b276 100644
--- a/lightllm/server/api_server.py
+++ b/lightllm/server/api_server.py
@@ -296,11 +296,12 @@ def main():
     parser.add_argument("--nccl_port", type=int, default=28765,
                         help="the nccl_port to build a distributed environment for PyTorch")
     parser.add_argument("--mode", type=str, default=[], nargs='+',
-                        help="""Model mode: [triton_int8kv | ppl_int8kv | triton_flashdecoding]
+                        help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding]
                         [triton_int8weight | triton_int4weight | lmdeploy_int4weight | ppl_int4weight],
                         triton_flashdecoding mode is for long context, current support llama llama2 qwen;
                         triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel;
                         ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel;
+                        ppl_fp16 mode uses the ppl fast fp16 decode attention kernel;
                         triton_int8weight and triton_int4weight and lmdeploy_int4weight or ppl_int4weight mode use int8 and int4 to store weights;
                         you need to read source code to make sure the supported detail mode for all models""")
     parser.add_argument("--trust_remote_code", action='store_true',
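
Note: with this patch, the ppl fp16 path is enabled by passing ppl_fp16 to the server's --mode argument; the KV cache stays in fp16 and is written through the existing _copy_kv_to_mem_cache_normal path. As a reading aid, the sketch below is a minimal PyTorch reference of what a single-request decode-attention call computes (standard scaled dot-product attention over the cached K/V, using the same 1.0 / sqrt(head_dim) scale as above). The function name and layout are illustrative assumptions, not the lightllm_ppl_fp16_kernel API, which operates on the whole paged batch on the GPU.

    import torch

    def decode_attention_reference(q, k_cache, v_cache, seq_len):
        # Illustrative reference only, not the ppl kernel.
        # q:       [head_num, head_dim]          query of the newly generated token
        # k_cache: [max_len, head_num, head_dim] cached keys for this request
        # v_cache: [max_len, head_num, head_dim] cached values for this request
        head_dim = q.shape[-1]
        att_scale = 1.0 / (head_dim ** 0.5)
        k = k_cache[:seq_len].float()
        v = v_cache[:seq_len].float()
        # [head_num, seq_len] score of the query against every cached position
        scores = torch.einsum("hd,shd->hs", q.float(), k) * att_scale
        probs = torch.softmax(scores, dim=-1)
        # [head_num, head_dim] attention output, cast back to the query dtype
        return torch.einsum("hs,shd->hd", probs, v).to(q.dtype)

    q = torch.randn(32, 128, dtype=torch.float16)
    k_cache = torch.randn(2048, 32, 128, dtype=torch.float16)
    v_cache = torch.randn(2048, 32, 128, dtype=torch.float16)
    out = decode_attention_reference(q, k_cache, v_cache, seq_len=17)  # -> [32, 128]
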
diff --git a/lightllm/server/router/manager.py b/lightllm/server/router/manager.py
index 948788d0..1a85fb57 100644
--- a/lightllm/server/router/manager.py
+++ b/lightllm/server/router/manager.py
@@ -170,13 +170,13 @@ async def _step(self):
         # There are running requests, but it is now time to schedule new requests and merge them into the inference batch
         if self.has_wait_tokens >= self.max_wait_tokens:
             new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
+            self.has_wait_tokens = 0
             if new_mini_batch is not None:
                 self.stats_tool.count_prompt_tokens(new_mini_batch)
                 await self._prefill_batch(new_mini_batch)
                 if not new_mini_batch.is_clear():
                     await self._merge_batch(self.running_batch, new_mini_batch)
                     self.running_batch.merge(new_mini_batch)
-                self.has_wait_tokens = 0
             return
 
         # Normal decode phase: decode directly if possible, otherwise pause some requests via the pause strategy
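
Note: the router change moves self.has_wait_tokens = 0 so the wait counter is reset whenever a scheduling attempt is made, not only when generate_new_batch() actually returns a mini-batch. Assuming the counter is incremented once per decode step elsewhere in _step, the old placement left the counter saturated after a failed attempt, so the scheduling attempt was repeated on every subsequent step. The toy sketch below uses hypothetical names to illustrate that difference; it is not router code.

    def steps_until_next_attempt(reset_on_failure: bool, max_wait_tokens: int = 10) -> int:
        has_wait_tokens = max_wait_tokens   # counter is saturated, so an attempt happens now
        new_mini_batch = None               # pretend no new mini-batch could be formed
        if reset_on_failure:                # behaviour after this patch: always reset
            has_wait_tokens = 0
        elif new_mini_batch is not None:    # behaviour before this patch: reset only on success
            has_wait_tokens = 0
        # decode steps remaining before the counter saturates again
        return max_wait_tokens - has_wait_tokens

    print(steps_until_next_attempt(reset_on_failure=False))  # 0  -> retries on the very next step
    print(steps_until_next_attempt(reset_on_failure=True))   # 10 -> waits max_wait_tokens steps
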