add stop_sequences as a way to stop token generation (#137)
Co-authored-by: wangzaijun <[email protected]>
hiworldwzj and wangzaijun authored Sep 19, 2023
1 parent 2b6f5d1 commit 05540ea
Showing 4 changed files with 35 additions and 7 deletions.
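
In short, the commit threads a user-supplied stop value end to end: the API layer forwards request.stop into SamplingParams, the HTTP server converts the stop strings into token-id sequences, and the router marks a request finished once its output ends with one of those sequences. A hedged sketch of how a client might exercise the feature; the endpoint path, port, and model name are assumptions for illustration, not taken from this commit:

# Sketch only: URL and model name are placeholders. Per the OpenAI-style API,
# "stop" may be a single string or a list of strings.
import requests

payload = {
    "model": "llama-7b",  # placeholder
    "messages": [{"role": "user", "content": "Count to ten."}],
    "max_tokens": 64,
    "stop": ["\n\n", "10."],  # generation halts once output ends with either
}
resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json())
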
6 changes: 1 addition & 5 deletions lightllm/server/api_server.py
@@ -187,11 +187,6 @@ async def chat_completions(
             HTTPStatus.BAD_REQUEST, "The function call feature is not supported"
         )

-    if request.stop is not None:
-        return create_error_response(
-            HTTPStatus.BAD_REQUEST, "The stop parameter is not currently supported"
-        )
-
     created_time = int(time.time())
     prompt = await build_prompt(request)
     sampling_params = SamplingParams(
@@ -203,6 +198,7 @@
         top_k=request.top_k,
         ignore_eos=request.ignore_eos,
         max_new_tokens=request.max_tokens,
+        stop_sequences=request.stop
     )
     sampling_params.verify()
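
With the guard above removed, request.stop flows straight into the sampler config. Because the OpenAI-style API allows stop to be either a string or a list of strings, the new field is typed Union[str, List[str]]; a minimal construction sketch, assuming the remaining constructor arguments keep defaults (as the signature in sampling_params.py below suggests):

# Minimal sketch; field values are illustrative only.
from lightllm.server.sampling_params import SamplingParams

p1 = SamplingParams(max_new_tokens=32, stop_sequences="###")            # single string
p2 = SamplingParams(max_new_tokens=32, stop_sequences=["###", "\n\n"])  # list of strings
p2.verify()
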
2 changes: 2 additions & 0 deletions lightllm/server/httpserver/manager.py
@@ -53,6 +53,8 @@ async def generate(self, prompt, sampling_params, request_id):
             raise ValueError(
                 f"the req token total len + 1 (input len + output len + 1) is too long > max_total_token_num:{self.total_token_num}"
             )

+        sampling_params.stop_sentences_to_token_ids(self.tokenizer)
+
         self.send_to_router.send_pyobj((prompt_ids, sampling_params, request_id))
         event = asyncio.Event()
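
The conversion runs here, on the HTTP-server side, because this process owns the tokenizer; everything past send_to_router compares token ids only. A rough sketch of that ordering, with the transport stubbed out:

# Rough sketch; "send" stands in for send_to_router.send_pyobj.
def dispatch(prompt_ids, sampling_params, request_id, tokenizer, send):
    # 1) Turn stop strings into token-id lists while a tokenizer is at hand.
    sampling_params.stop_sentences_to_token_ids(tokenizer)
    # 2) Only then hand the request off to the router process.
    send((prompt_ids, sampling_params, request_id))
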
12 changes: 12 additions & 0 deletions lightllm/server/io_struct.py
@@ -26,6 +26,15 @@ def to_req_detokenization_state(self):
         if self.output_metadata_list:
             out.gen_metadata.update(self.output_metadata_list[-1])
         return out

+    def stop_sequences_matched(self):
+        for stop_token_ids in self.sample_params.stop_sequences:
+            stop_len = len(stop_token_ids)
+            if stop_len > 0:
+                if len(self.output_ids) >= stop_len:
+                    if all(self.output_ids[-(stop_len - i)] == stop_token_ids[i] for i in range(stop_len)):
+                        return True
+        return False
+
     def __repr__(self):
         return (f"request_id(n={self.request_id}, "
@@ -78,6 +87,9 @@ def calcu_used_tokens(self):
     def mark_finished_req(self, eos_id):
         has_new_finish = False
         for req in self.reqs:
+            if req.stop_sequences_matched():
+                req.has_generate_finished = True
+                has_new_finish = True
             if req.output_ids[-1] == eos_id and req.sample_params.ignore_eos == False:
                 req.has_generate_finished = True
                 has_new_finish = True
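
The generator expression in stop_sequences_matched is an index-shifted way of asking whether output_ids ends with stop_token_ids. A standalone check with the same result, written with plain lists rather than the Req class:

def ends_with(output_ids, stop_token_ids):
    """Equivalent suffix check: True if output_ids ends with stop_token_ids."""
    stop_len = len(stop_token_ids)
    if stop_len == 0 or len(output_ids) < stop_len:
        return False
    return output_ids[-stop_len:] == stop_token_ids

assert ends_with([5, 9, 2, 7, 7], [7, 7]) is True
assert ends_with([5, 9, 2, 7], [7, 7]) is False   # suffix does not match
assert ends_with([7, 7], [7, 7, 7]) is False      # output shorter than stop sequence
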
22 changes: 20 additions & 2 deletions lightllm/server/sampling_params.py
@@ -15,7 +15,8 @@ def __init__(
         top_p: float = 1.0,
         top_k: int = -1,  # -1 is for all
         ignore_eos: bool = False,
-        max_new_tokens: int = 16
+        max_new_tokens: int = 16,
+        stop_sequences: Optional[Union[str, List[str]]] = None  # stop-sequence condition
     ) -> None:
         self.do_sample = do_sample
         self.presence_penalty = presence_penalty
@@ -25,6 +26,7 @@ def __init__(
         self.top_k = top_k
         self.ignore_eos = ignore_eos
         self.max_new_tokens = max_new_tokens
+        self.stop_sequences = stop_sequences
         if self.do_sample == False:
             self.temperature = 1.0
             self.top_p = 1.0
@@ -47,7 +49,23 @@ def verify(self):
             raise ValueError(f"top_k must be -1 (disable), or at least 1, got {self.top_k}.")
         if self.max_new_tokens < 1:
             raise ValueError(f"max_new_tokens must be at least 1, got {self.max_new_tokens}.")
-        return
+        return
+
+    def stop_sentences_to_token_ids(self, tokenizer):
+        if self.stop_sequences is None:
+            self.stop_sequences = []
+        else:
+            if isinstance(self.stop_sequences, str):
+                self.stop_sequences = [self.stop_sequences]
+            new_stop_sequences = []
+            for stop_str in self.stop_sequences:
+                stop_str_ids = tokenizer.encode(stop_str)
+                if stop_str_ids is not None and len(stop_str_ids) >= 1:  # remove the bos_token_id
+                    stop_str_ids = stop_str_ids[1:]
+                if len(stop_str_ids) > 0:
+                    new_stop_sequences.append(stop_str_ids)
+            self.stop_sequences = new_stop_sequences
+        return

     def to_dict(self):
         ret = {}
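
The stop_str_ids[1:] slice assumes the tokenizer prepends a BOS id to every encode() result, as LLaMA-style tokenizers do; a stop string whose encoding is only the BOS token is dropped entirely. A toy illustration of what the conversion produces, using a fake tokenizer invented for this sketch (and assuming the other SamplingParams arguments default, per the signature above):

# FakeBosTokenizer is a stand-in, not a real tokenizer.
from lightllm.server.sampling_params import SamplingParams

class FakeBosTokenizer:
    def encode(self, text):
        return [1] + [100 + ord(c) for c in text]  # id 1 plays the BOS role

params = SamplingParams(stop_sequences=["##", ""])
params.stop_sentences_to_token_ids(FakeBosTokenizer())
print(params.stop_sequences)  # [[135, 135]]: BOS stripped, empty string dropped
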
