support ppl fp16 attention (#235)
hiworldwzj authored Dec 1, 2023
1 parent f2a9b45 commit 50d99a6
Showing 3 changed files with 24 additions and 2 deletions.
21 changes: 21 additions & 0 deletions lightllm/models/llama/layer_infer/transformer_layer_infer.py
@@ -39,6 +39,9 @@ def _bind_func(self):
         if "ppl_int8kv" in self.mode:
             self._token_attention_kernel = self._token_decode_attention_ppl_int8kv
             self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_ppl_int8kv
+        elif "ppl_fp16" in self.mode:
+            self._token_attention_kernel = self._token_decode_attention_ppl_fp16
+            self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_normal
         elif "triton_int8kv" in self.mode:
             self._token_attention_kernel = self._token_decode_attention_int8kv
             self._copy_kv_to_mem_cache = self._copy_kv_to_mem_cache_int8kv
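Worth noting: unlike the int8 KV modes, the new ppl_fp16 branch reuses `_copy_kv_to_mem_cache_normal`, i.e. the K/V tensors are written into the cache in fp16 with no quantization scales. As a rough, illustrative sketch of what such a "normal" copy amounts to (the real helper is not part of this diff and is usually a small Triton kernel; names and shapes here are assumptions):

    import torch

    def copy_kv_to_mem_cache_normal_sketch(k, v, dest_index, key_buffer, value_buffer):
        # k, v:        [num_tokens, kv_heads, head_dim] fp16 tensors produced this step
        # dest_index:  [num_tokens] slot ids allocated for these tokens in the shared cache
        key_buffer[dest_index] = k
        value_buffer[dest_index] = v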
@@ -273,4 +276,22 @@ def _token_decode_attention_ppl_int8kv(self, q, infer_state: LlamaInferStateInfo
                                   infer_state.b_seq_len,
                                   infer_state.max_len_in_batch)
 
         return o_tensor
+
+    def _token_decode_attention_ppl_fp16(self, q, infer_state: LlamaInferStateInfo, layer_weight, out=None):
+        batch_size = infer_state.batch_size
+        calcu_shape1 = (batch_size, self.tp_q_head_num_, self.head_dim_)
+        o_tensor = torch.empty_like(q) if out is None else out
+        from lightllm_ppl_fp16_kernel import fp16_decode_attention
+        # group_int8kv_decode_attention(at::Tensor o, at::Tensor q, at::Tensor k, at::Tensor k_s, at::Tensor v, at::Tensor v_s, at::Tensor b_loc, at::Tensor b_seq_len, int max_len_in_batch)
+        fp16_decode_attention(o_tensor.view(calcu_shape1),
+                              1.0 / (self.head_dim_**0.5),
+                              q.view(calcu_shape1),
+                              infer_state.mem_manager.key_buffer[self.layer_num_],
+                              infer_state.mem_manager.value_buffer[self.layer_num_],
+                              infer_state.req_manager.req_to_token_indexs,
+                              infer_state.b_req_idx,
+                              infer_state.b_seq_len,
+                              infer_state.max_len_in_batch)
+
+        return o_tensor
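For orientation, the arguments above describe a standard paged-KV decode attention: `req_to_token_indexs[b_req_idx[i], :b_seq_len[i]]` gives request i's slots in `key_buffer`/`value_buffer`, and the kernel evaluates softmax(q K^T / sqrt(head_dim)) V for the single newest token of every request. A naive PyTorch reference of that computation (a sketch under assumed shapes, with q heads equal to kv heads; not the ppl kernel itself):

    import torch

    def decode_attention_reference(q, key_buffer, value_buffer,
                                   req_to_token_indexs, b_req_idx, b_seq_len, scale):
        # q:                       [batch, heads, head_dim], newest token of each request
        # key_buffer/value_buffer: [num_cache_slots, heads, head_dim]
        # req_to_token_indexs:     [max_requests, max_seq_len] slot table; b_req_idx, b_seq_len: [batch]
        out = torch.empty_like(q)
        for i in range(q.shape[0]):
            slots = req_to_token_indexs[b_req_idx[i], : int(b_seq_len[i])]   # this request's cache slots
            k = key_buffer[slots].float()                                    # [seq_len, heads, head_dim]
            v = value_buffer[slots].float()
            scores = torch.einsum("hd,shd->hs", q[i].float(), k) * scale     # per-head attention scores
            probs = torch.softmax(scores, dim=-1)
            out[i] = torch.einsum("hs,shd->hd", probs, v).to(q.dtype)
        return out

    # scale mirrors the call above: 1.0 / (head_dim ** 0.5)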
3 changes: 2 additions & 1 deletion lightllm/server/api_server.py
@@ -296,11 +296,12 @@ def main():
     parser.add_argument("--nccl_port", type=int, default=28765,
                         help="the nccl_port to build a distributed environment for PyTorch")
     parser.add_argument("--mode", type=str, default=[], nargs='+',
-                        help="""Model mode: [triton_int8kv | ppl_int8kv | triton_flashdecoding]
+                        help="""Model mode: [triton_int8kv | ppl_int8kv | ppl_fp16 | triton_flashdecoding]
                         [triton_int8weight | triton_int4weight | lmdeploy_int4weight | ppl_int4weight],
                         triton_flashdecoding mode is for long context, current support llama llama2 qwen;
                         triton_int8kv mode use int8 to store kv cache, can increase token capacity, use triton kernel;
                         ppl_int8kv mode use int8 to store kv cache, and use ppl fast kernel;
+                        ppl_fp16 mode use ppl fast fp16 decode attention kernel;
                         triton_int8weight and triton_int4weight and lmdeploy_int4weight or ppl_int4weight mode use int8 and int4 to store weights;
                         you need to read source code to make sure the supported detail mode for all models""")
     parser.add_argument("--trust_remote_code", action='store_true',
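Usage note: the new kernel is opted into like any other mode through this flag; for example (the launch module matches the file above, while the model path and the --model_dir flag are assumptions, and the lightllm_ppl_fp16_kernel extension imported in the previous file must be installed):

    python -m lightllm.server.api_server --model_dir /path/to/llama-2-7b --mode ppl_fp16 --nccl_port 28765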
2 changes: 1 addition & 1 deletion lightllm/server/router/manager.py
@@ -170,13 +170,13 @@ async def _step(self):
        # There are running requests, but it is now time to schedule new requests and merge them into the inference batch.
        if self.has_wait_tokens >= self.max_wait_tokens:
            new_mini_batch = self.req_queue.generate_new_batch(self.running_batch)
            self.has_wait_tokens = 0
            if new_mini_batch is not None:
                self.stats_tool.count_prompt_tokens(new_mini_batch)
                await self._prefill_batch(new_mini_batch)
                if not new_mini_batch.is_clear():
                    await self._merge_batch(self.running_batch, new_mini_batch)
                    self.running_batch.merge(new_mini_batch)
                self.has_wait_tokens = 0
            return

        # Normal decode phase: decode directly when possible, otherwise pause some requests via the pause strategy.
