From a47942e0cdb21fb06e4f0e798ebfa7663cd7bf86 Mon Sep 17 00:00:00 2001
From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com>
Date: Thu, 13 Jun 2024 12:06:04 +0800
Subject: [PATCH] ppl w8a8 support splitfuse mode. (#437)

Use w8a8 + splitfuse mode to start a llama2 7b server.

input:
data = {
    'inputs': 'San Francisco is a',
    'parameters': {
        'do_sample': False,
    }
}

output:
{'generated_text': ['city in California, United States. Unterscheidung between the two is not always clear-'], 'count_output_tokens': 16, 'finish_reason': 'length'}

The test function works as expected.

---------

Co-authored-by: wangzaijun
Co-authored-by: shihaobai <42648726+shihaobai@users.noreply.github.com>
---
 .../layer_infer/transformer_layer_infer.py    | 28 +++++++++++++++++++
 lightllm/server/httpserver/manager.py         |  3 +-
 2 files changed, 30 insertions(+), 1 deletion(-)

diff --git a/lightllm/models/llama_awquant/layer_infer/transformer_layer_infer.py b/lightllm/models/llama_awquant/layer_infer/transformer_layer_infer.py
index 517b1ce4..774dbe5f 100755
--- a/lightllm/models/llama_awquant/layer_infer/transformer_layer_infer.py
+++ b/lightllm/models/llama_awquant/layer_infer/transformer_layer_infer.py
@@ -13,6 +13,7 @@
 )
 from lightllm.models.llama.triton_kernel.rotary_emb import rotary_emb_fwd
 from lightllm.models.llama.infer_struct import LlamaInferStateInfo
+from lightllm.models.llama.splitfuse_infer_struct import LlamaSplitFuseInferStateInfo
 from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
 from lightllm.common.basemodel import TransformerLayerInferActivationWeightQuantTpl
 from lightllm.common.basemodel.cuda_kernel.ppl_awquant import (
@@ -220,6 +221,33 @@ def _token_ffn(self, input_embdings, infer_state: LlamaInferStateInfo, layer_wei
         input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
         return
 
+    def _splitfuse_attention(self, input_embding, infer_state: LlamaSplitFuseInferStateInfo, layer_weight):
+        # LlamaSplitFuseInferStateInfo has no is_prefill member, but the downstream matmul operator
+        # entry functions take it as an input, so add a default is_prefill member here and set it to True.
+        infer_state.is_prefill = True
+
+        input1, token_scale, skip_out = self._awquant_att_norm(input_embding, infer_state, layer_weight)
+        cache_kv = self._pre_cache_kv(infer_state, layer_weight)
+        q, cache_kv = self._get_qkv(input1, cache_kv, token_scale, infer_state, layer_weight)
+        input1 = None
+        self._post_cache_kv(cache_kv, infer_state, layer_weight)
+        o = self._splitfuse_attention_kernel(q, infer_state, layer_weight)
+        q = None
+        o = self._get_o(o, infer_state, layer_weight)
+        if self.world_size_ > 1:
+            dist.all_reduce(o, op=dist.ReduceOp.SUM, async_op=False)
+        input_embding.add_(o.view(-1, self.embed_dim_))
+        return
+
+    def _splitfuse_ffn(self, input_embdings, infer_state: LlamaSplitFuseInferStateInfo, layer_weight):
+        input1, token_scale, skip_out = self._awquant_ffn_norm(input_embdings, infer_state, layer_weight)
+        ffn_out = self._ffn(input1, token_scale, infer_state, layer_weight)
+        input1 = None
+        if self.world_size_ > 1:
+            dist.all_reduce(ffn_out, op=dist.ReduceOp.SUM, async_op=False)
+        input_embdings.add_(ffn_out.view(-1, self.embed_dim_))
+        return
+
     def _awquant_matmul_ppl_int8_quant_dequant(
         self, input, quant_weight_params, is_prefill, token_scale=None, out=None, bias=None, has_act=False
     ):
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index eb679146..9cbe638c 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -183,7 +183,8 @@ async def generate(
         logger.debug(
             f"req_id:{group_request_id},start:{start_time}s,first_token_cost:{first_token_cost_ms}ms\n"
             f"total_cost_time:{total_cost_time_ms}ms,out_token_counter:{out_token_counter}\n"
-            f"mean_per_token_cost_time: {total_cost_time_ms/out_token_counter}ms"
+            f"mean_per_token_cost_time: {total_cost_time_ms/out_token_counter}ms\n"
+            f"prompt_token_num:{prompt_tokens}"
         )
         monitor.histogram_observe("lightllm_request_inference_duration", total_cost_time_ms)
         monitor.histogram_observe(
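
Testing note: the request/response shown in the commit message can be reproduced with a short
client script. This is a minimal sketch, assuming the server has already been started in
w8a8 + splitfuse mode and is reachable at http://localhost:8000 via the /generate endpoint;
the host, port, and server launch flags are assumptions and are not part of this patch.

import requests

# Hypothetical client for the test described in the commit message.
url = "http://localhost:8000/generate"  # assumed default host and port
data = {
    "inputs": "San Francisco is a",
    "parameters": {
        "do_sample": False,
    },
}
resp = requests.post(url, json=data)
# The response body is expected to contain 'generated_text',
# 'count_output_tokens', and 'finish_reason', as in the commit message.
print(resp.json())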