From aceeb6e2d6e6549a3daf3254553b3a525f011235 Mon Sep 17 00:00:00 2001 From: hiworldwzj <30762946+hiworldwzj@users.noreply.github.com> Date: Thu, 16 Nov 2023 17:26:28 +0800 Subject: [PATCH] fix bug for lose token (#216) Co-authored-by: wangzaijun --- lightllm/server/httpserver/manager.py | 54 ++++++++++++++++----------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py index 0e8073d5..196fe499 100644 --- a/lightllm/server/httpserver/manager.py +++ b/lightllm/server/httpserver/manager.py @@ -7,6 +7,12 @@ from ..tokenizer import get_tokenizer from ..io_struct import BatchStrOut, AbortReq +class ReqStatus: + def __init__(self, req_id) -> None: + self.req_id = req_id + self.lock = asyncio.Lock() + self.event = asyncio.Event() + self.out_token_info_list = [] class HttpServerManager: def __init__( @@ -56,26 +62,34 @@ async def generate(self, prompt, sampling_params, request_id): sampling_params.stop_sentences_to_token_ids(self.tokenizer) + req_status = ReqStatus(request_id) + event = req_status.event + self.req_id_to_out_inf[request_id] = req_status + self.send_to_router.send_pyobj((prompt_ids, sampling_params, request_id)) - event = asyncio.Event() - self.req_id_to_out_inf[request_id] = ("", {}, False, event) + while True: try: await asyncio.wait_for(event.wait(), timeout=5) except asyncio.TimeoutError: pass - event.clear() - out_str, metadata, finished, _ = self.req_id_to_out_inf[request_id] - if len(metadata) != 0: - self.req_id_to_out_inf[request_id] = ("", {}, finished, event) - metadata["prompt_tokens"] = prompt_tokens - yield out_str, metadata, finished - if finished: - try: - del self.req_id_to_out_inf[request_id] - except: - pass - break + + async with req_status.lock: + event.clear() + if len(req_status.out_token_info_list) == 0: + continue + + for out_str, metadata, finished in req_status.out_token_info_list: + metadata["prompt_tokens"] = prompt_tokens + yield out_str, metadata, finished + + if finished: + try: + del self.req_id_to_out_inf[request_id] + except: + pass + return + req_status.out_token_info_list.clear() return async def abort(self, request_id): @@ -96,14 +110,10 @@ async def handle_loop(self): for req_id, text, metadata, finished, abort in recv_ans.reqs_infs: try: if not abort: - _, _, _, event = self.req_id_to_out_inf[req_id] - self.req_id_to_out_inf[req_id] = ( - text, - metadata, - finished, - event, - ) - event.set() + req_status : ReqStatus = self.req_id_to_out_inf[req_id] + async with req_status.lock: + req_status.out_token_info_list.append((text, metadata, finished)) + req_status.event.set() else: del self.req_id_to_out_inf[req_id] except: