diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py
index a863ab55..c2fe98e4 100644
--- a/lightllm/server/api_server.py
+++ b/lightllm/server/api_server.py
@@ -187,11 +187,6 @@ async def chat_completions(
             HTTPStatus.BAD_REQUEST, "The function call feature is not supported"
         )
 
-    if request.stop is not None:
-        return create_error_response(
-            HTTPStatus.BAD_REQUEST, "The stop parameter is not currently supported"
-        )
-
     created_time = int(time.time())
     prompt = await build_prompt(request)
     sampling_params = SamplingParams(
@@ -203,6 +198,7 @@
         top_k=request.top_k,
         ignore_eos=request.ignore_eos,
         max_new_tokens=request.max_tokens,
+        stop_sequences=request.stop
     )
     sampling_params.verify()
 
diff --git a/lightllm/server/httpserver/manager.py b/lightllm/server/httpserver/manager.py
index 8fd62897..d7f5f4d0 100644
--- a/lightllm/server/httpserver/manager.py
+++ b/lightllm/server/httpserver/manager.py
@@ -53,6 +53,8 @@ async def generate(self, prompt, sampling_params, request_id):
             raise ValueError(
                 f"the req token total len + 1 (input len + output len + 1) is too long > max_total_token_num:{self.total_token_num}"
             )
+
+        sampling_params.stop_sentences_to_token_ids(self.tokenizer)
         self.send_to_router.send_pyobj((prompt_ids, sampling_params, request_id))
 
         event = asyncio.Event()
diff --git a/lightllm/server/io_struct.py b/lightllm/server/io_struct.py
index a6cdde7f..5324ee26 100644
--- a/lightllm/server/io_struct.py
+++ b/lightllm/server/io_struct.py
@@ -26,6 +26,15 @@ def to_req_detokenization_state(self):
         if self.output_metadata_list:
             out.gen_metadata.update(self.output_metadata_list[-1])
         return out
+
+    def stop_sequences_matched(self):
+        for stop_token_ids in self.sample_params.stop_sequences:
+            stop_len = len(stop_token_ids)
+            if stop_len > 0:
+                if len(self.output_ids) >= stop_len:
+                    if all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)):
+                        return True
+        return False
 
     def __repr__(self):
         return (f"request_id(n={self.request_id}, "
@@ -78,6 +87,9 @@ def calcu_used_tokens(self):
     def mark_finished_req(self, eos_id):
         has_new_finish = False
         for req in self.reqs:
+            if req.stop_sequences_matched():
+                req.has_generate_finished = True
+                has_new_finish = True
             if req.output_ids[-1] == eos_id and req.sample_params.ignore_eos == False:
                 req.has_generate_finished = True
                 has_new_finish = True
diff --git a/lightllm/server/sampling_params.py b/lightllm/server/sampling_params.py
index dbe4dd80..8af532df 100644
--- a/lightllm/server/sampling_params.py
+++ b/lightllm/server/sampling_params.py
@@ -15,7 +15,8 @@ def __init__(
         top_p: float = 1.0,
         top_k: int = -1, # -1 is for all
         ignore_eos: bool = False,
-        max_new_tokens: int = 16
+        max_new_tokens: int = 16,
+        stop_sequences: Optional[Union[str, List[str]]] = None  # stop sequence condition
     ) -> None:
         self.do_sample = do_sample
         self.presence_penalty = presence_penalty
@@ -25,6 +26,7 @@
         self.top_k = top_k
         self.ignore_eos = ignore_eos
         self.max_new_tokens = max_new_tokens
+        self.stop_sequences = stop_sequences
         if self.do_sample == False:
             self.temperature = 1.0
             self.top_p = 1.0
@@ -47,7 +49,23 @@ def verify(self):
             raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
         if self.max_new_tokens < 1:
             raise ValueError(f"max_new_tokens must be at least 1 , got {self.max_new_tokens}.")
-        return
+        return
+
+    def stop_sentences_to_token_ids(self, tokenizer):
+        if self.stop_sequences is None:
+            self.stop_sequences = []
+        else:
+            if isinstance(self.stop_sequences, str):
+                self.stop_sequences = [self.stop_sequences]
+        new_stop_sequences = []
+        for stop_str in self.stop_sequences:
+            stop_str_ids = tokenizer.encode(stop_str)
+            if stop_str_ids is not None and len(stop_str_ids) >= 1:  # remove bos_token_id
+                stop_str_ids = stop_str_ids[1:]
+                if len(stop_str_ids) > 0:
+                    new_stop_sequences.append(stop_str_ids)
+        self.stop_sequences = new_stop_sequences
+        return
 
     def to_dict(self):
         ret = {}
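
Note (not part of the patch): below is a minimal standalone sketch of the suffix check that Req.stop_sequences_matched performs, using made-up token ids. The helper name and values are hypothetical, and the suffix comparison is written with a slice, which is equivalent to the per-index comparison used in the patch.

# Standalone sketch of the stop-sequence suffix check: a request is considered
# finished once its generated token ids end with any encoded stop sequence.
from typing import List


def ends_with_stop_sequence(output_ids: List[int], stop_sequences: List[List[int]]) -> bool:
    # stop_sequences holds token-id lists, mirroring the output of
    # stop_sentences_to_token_ids (tokenizer.encode with the leading bos stripped).
    for stop_token_ids in stop_sequences:
        stop_len = len(stop_token_ids)
        if stop_len > 0 and len(output_ids) >= stop_len:
            if output_ids[-stop_len:] == stop_token_ids:  # suffix comparison via slicing
                return True
    return False


if __name__ == "__main__":
    generated = [12, 7, 98, 345, 11]   # hypothetical generated token ids
    stops = [[345, 11], [99]]          # hypothetical encoded stop strings
    print(ends_with_stop_sequence(generated, stops))       # True: ends with [345, 11]
    print(ends_with_stop_sequence(generated, [[7, 98]]))   # False: [7, 98] is not a suffix

Matching on token ids rather than decoded text keeps the per-step check cheap, but it also means a stop string only triggers when the model emits the same tokenization that tokenizer.encode produced for it.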