From e502de8c8c79df724d86ced99c9fbed3b8eb2da4 Mon Sep 17 00:00:00 2001
From: wangzaijun
Date: Fri, 18 Oct 2024 13:42:40 +0800
Subject: [PATCH] support reading eos_id from config.json

---
 lightllm/server/api_server.py  | 24 ++++++++++++++++--------
 lightllm/utils/config_utils.py | 18 ++++++++++++++++++
 2 files changed, 34 insertions(+), 8 deletions(-)
 create mode 100644 lightllm/utils/config_utils.py

diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py
index d29b5e67..91126425 100755
--- a/lightllm/server/api_server.py
+++ b/lightllm/server/api_server.py
@@ -365,7 +365,9 @@ def make_argument_parser() -> argparse.ArgumentParser:
         default=None,
         help="max tokens num for new cat batch, it controls prefill batch size to prevent OOM",
     )
-    parser.add_argument("--eos_id", nargs="+", type=int, default=[2], help="eos stop token id")
+    parser.add_argument(
+        "--eos_id", nargs="+", type=int, default=None, help="eos stop token id; if None, it is loaded from config.json"
+    )
     parser.add_argument(
         "--running_max_req_size", type=int, default=1000, help="the max size for forward requests at the same time"
     )
@@ -505,7 +507,7 @@ def main():
     assert args.max_req_input_len < args.max_req_total_len
     assert args.max_req_total_len <= args.max_total_token_num
     assert not (args.beam_mode and args.use_dynamic_prompt_cache), "Beam mode incompatible with dynamic prompt cache"
-
+
     # splitfuse_mode and cuda_graph cannot be enabled at the same time
     if args.splitfuse_mode:
         assert args.disable_cudagraph
@@ -537,6 +539,18 @@ def main():
         batch_max_tokens = max(batch_max_tokens, args.splitfuse_block_size)
         args.batch_max_tokens = batch_max_tokens

+    # help to manage data stored on Ceph
+    if "s3://" in args.model_dir:
+        from lightllm.utils.petrel_helper import s3_model_prepare
+
+        s3_model_prepare(args.model_dir)
+
+    # if args.eos_id is None, read the eos_token_id info from config.json and assign it to args
+    if args.eos_id is None:
+        from lightllm.utils.config_utils import get_eos_token_ids
+
+        args.eos_id = get_eos_token_ids(args.model_dir)
+
     logger.info(f"all start args:{args}")

     can_use_ports = alloc_can_use_network_port(num=6 + args.tp, used_nccl_port=args.nccl_port)
@@ -560,12 +574,6 @@ def main():
     global metric_client
     metric_client = MetricClient(metric_port)

-    # help to manage data stored on Ceph
-    if "s3://" in args.model_dir:
-        from lightllm.utils.petrel_helper import s3_model_prepare
-
-        s3_model_prepare(args.model_dir)
-
     global httpserver_manager
     httpserver_manager = HttpServerManager(
         args,
diff --git a/lightllm/utils/config_utils.py b/lightllm/utils/config_utils.py
new file mode 100644
index 00000000..dd92d271
--- /dev/null
+++ b/lightllm/utils/config_utils.py
@@ -0,0 +1,18 @@
+import json
+import os
+
+
+def get_config_json(model_path: str):
+    with open(os.path.join(model_path, "config.json"), "r") as file:
+        json_obj = json.load(file)
+    return json_obj
+
+
+def get_eos_token_ids(model_path: str):
+    config_json = get_config_json(model_path)
+    eos_token_id = config_json["eos_token_id"]
+    if isinstance(eos_token_id, int):
+        return [eos_token_id]
+    if isinstance(eos_token_id, list):
+        return eos_token_id
+    assert False, "unsupported eos_token_id format in config.json"
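
Reviewer note (not part of the patch): a minimal sketch of how the new get_eos_token_ids helper normalizes the two eos_token_id shapes that appear in model config.json files. The throwaway model directories and sample configs below are hypothetical, and the snippet assumes lightllm is importable:

    import json
    import os
    import tempfile

    from lightllm.utils.config_utils import get_eos_token_ids

    # eos_token_id may be a bare int or a list of ints, depending on the model
    for sample_config in ({"eos_token_id": 2}, {"eos_token_id": [2, 32000]}):
        with tempfile.TemporaryDirectory() as model_dir:
            # write a minimal config.json into a throwaway model dir
            with open(os.path.join(model_dir, "config.json"), "w") as f:
                json.dump(sample_config, f)
            # the helper always returns a list of token ids
            print(get_eos_token_ids(model_dir))  # [2], then [2, 32000]

With --eos_id left at its new default of None, main() falls back to this helper, so models whose eos token is not 2 stop correctly without any extra flags.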