diff --git a/tests/utilization/fixtures.py b/tests/utilization/fixtures.py
index 4366f37c..e93b7757 100644
--- a/tests/utilization/fixtures.py
+++ b/tests/utilization/fixtures.py
@@ -72,7 +72,7 @@ def dataset(
     chat_template: Optional[str] = None,
     ranking_type: Optional[str] = None,
     batch_size: int = 1,
-    model_backend = "huggingface",
+    model_backend="huggingface",
     **kwargs
 ):

diff --git a/tests/utilization/utils/test_parse_arguments.py b/tests/utilization/utils/test_parse_arguments.py
index 32c0d132..e097325e 100644
--- a/tests/utilization/utils/test_parse_arguments.py
+++ b/tests/utilization/utils/test_parse_arguments.py
@@ -9,11 +9,11 @@ def test_default_vllm():
     model_args, dataset_args, evaluation_args = parse_argument(['-m', 'a-random-fake-model', '-d', 'nq', 'quac'])
     assert model_args.model_backend == "vllm"
-    assert model_args.prefix_caching is False
+    assert model_args.prefix_caching is None  # vllm default is False


 def test_no_prefix_caching():
-    # currently vllm doesn't support returning logprob for prefix caching
+    # batch size is 1, so prefix caching is not used
     model_args, dataset_args, evaluation_args = parse_argument([
         '-m', 'a-random-fake-model', '-d', 'nq', 'mmlu', '-b', '1'
     ])

@@ -27,7 +27,7 @@ def test_default_prefix_caching():
         '-m', 'a-random-fake-model', '-d', 'nq', 'mmlu', '-b', '16'
     ])
     assert model_args.model_backend == "huggingface"
-    assert model_args.prefix_caching is True
+    assert model_args.prefix_caching is None  # huggingface default is True


 def test_default_no_efficient():
diff --git a/utilization/dataset/dataset.py b/utilization/dataset/dataset.py
index b629e23b..0ea8705b 100644
--- a/utilization/dataset/dataset.py
+++ b/utilization/dataset/dataset.py
@@ -148,6 +148,8 @@ def __init__(
         self.ranking_type = args.ranking_type
         self.model_type = model.model_type
         self.prefix_caching = model.args.prefix_caching
+        if self.prefix_caching is None:
+            self.prefix_caching = True
         self.instance_format = "{source}{target}"
         if args.instruction:
             self.instruction = args.instruction
diff --git a/utilization/load_dataset.py b/utilization/load_dataset.py
index 8fd44445..76ed19bf 100644
--- a/utilization/load_dataset.py
+++ b/utilization/load_dataset.py
@@ -369,12 +369,6 @@ def load_datasets(
         args.auto_batch_size = False
         logger.info("Setting batch_size to -1, since vllm can automatically planning the optimal batch and order.")

-    if model.args.prefix_caching and model.model_backend != "huggingface":
-        logger.warning(
-            "Prefix caching is only available for HuggingFaceModel. Automatically set prefix_caching to False"
-        )
-        model.args.prefix_caching = False
-
     # get all the dataset classes
     datasets = []
     for d in args.dataset_names:
diff --git a/utilization/model/vllm_model.py b/utilization/model/vllm_model.py
index 2713314f..c0ed60d5 100644
--- a/utilization/model/vllm_model.py
+++ b/utilization/model/vllm_model.py
@@ -48,15 +48,8 @@ def __init__(self, args: "ModelArguments", **kwargs):
         self.args = args
         logger.info(f"Trying to load {args.model_name_or_path} using vllm...")

-        self.vllm_version = version.parse(vllm.__version__)
-        if args.prefix_caching:
-            if self.is_legacy_vllm():
-                logger.warning(
-                    f"vllm version ({vllm.__version__}) is lower than 0.4.0, prefix_caching is not supported."
-                )
-            else:
-                kwargs["enable_prefix_caching"] = True
-                self.use_cache = True
+        if args.prefix_caching is not None:
+            kwargs["enable_prefix_caching"] = args.prefix_caching

         self.model = LLM(
             model=args.model_name_or_path,
@@ -77,10 +70,21 @@ def __init__(self, args: "ModelArguments", **kwargs):
         )
         self.tokenizer.chat_template = args.chat_template

-    def is_legacy_vllm(self):
-        return self.vllm_version < version.parse("0.4.0")
+    @property
+    def use_cache(self):
+        return self.model.llm_engine.cache_config.enable_prefix_caching
+
+    @use_cache.setter
+    def use_cache(self, value):
+        self.model.llm_engine.cache_config.enable_prefix_caching = value

     def set_ppl_args(self, **extra_model_args):
+        if self.use_cache:
+            logger.warning(
+                "Prefix caching is enabled for vllm. However, it is a known issue for vllm to return logprobs with prefix caching enabled. See https://github.com/vllm-project/vllm/issues/3914 for details."
+            )
+            self.use_cache = False
+
         self.ppl_kwargs = SamplingParams(max_tokens=1, prompt_logprobs=0)
         if len(extra_model_args) > 0:
             logger.warning(f"Unused generation arguments: {extra_model_args}")
@@ -144,6 +148,12 @@ def generation(self, batched_inputs: List[Conversation]) -> List[str]:
         return [c.get_generation_results() for c in batched_inputs]

     def set_prob_args(self, **extra_model_args):
+        if self.use_cache:
+            logger.warning(
+                "Prefix caching is enabled for vllm. However, it is a known issue for vllm to return logprobs with prefix caching enabled. See https://github.com/vllm-project/vllm/issues/3914 for details."
+            )
+            self.use_cache = False
+
         self.prob_kwargs = SamplingParams(max_tokens=1, temperature=0)
         self.candidate_ids = extra_model_args.pop("candidate_ids", None)

diff --git a/utilization/utils/arguments.py b/utilization/utils/arguments.py
index ae45f6ca..320117aa 100644
--- a/utilization/utils/arguments.py
+++ b/utilization/utils/arguments.py
@@ -82,7 +82,7 @@ class ModelArguments(ModelBackendMixin):
         default="auto",
         help="The device map for model and data",
     )
-    prefix_caching: bool = HfArg(
+    prefix_caching: Optional[bool] = HfArg(
         default=None,
         help="Whether to cache prefix in get_ppl mode",
     )
@@ -369,13 +369,6 @@ def __post_init__(self):
         if self.is_vllm_model():
             self.vllm_gpu_memory_utilization = 0.9

-            if self.prefix_caching is None:
-                # prefix_caching is still experimental
-                self.prefix_caching = False
-
-        elif self.is_huggingface_model():
-            if self.prefix_caching is None:
-                self.prefix_caching = True

         # argparse encodes string with unicode_escape, decode it to normal string, e.g., "\\n" -> "\n"
         if self.stop is not None:
@@ -626,16 +619,11 @@ def check_args(model_args: ModelArguments, dataset_args: DatasetArguments, evalu
         d not in DEFAULT_VLLM_DATASETS for d in dataset_args.dataset_names
     ):
         model_args.model_backend = "huggingface"
-        if not model_args.passed_in_commandline("prefix_caching"):
-            model_args.prefix_caching = True

     model_args.seed = int(evaluation_args.seed)

-    if dataset_args.batch_size == 1 and model_args.prefix_caching:
-        if model_args.is_local_model():
-            logger.warning(
-                "Prefix caching is not supported for batch_size=1, automatically set prefix_caching to False."
-            )
+    if dataset_args.batch_size == 1 and model_args.prefix_caching is None and model_args.is_huggingface_model():
+        logger.warning("Prefix caching is not supported for batch_size=1, automatically set prefix_caching to False.")
         model_args.prefix_caching = False

     # check models
@@ -646,14 +634,6 @@ def check_args(model_args: ModelArguments, dataset_args: DatasetArguments, evalu
             f"chat/completions endpoint model {model_args.model_name_or_path} doesn't support batch_size > 1, automatically set batch_size to 1."
         )

-    # vllm has its own prefix caching mechanism
-    if model_args.prefix_caching and "expandable_segments" not in os.environ.get(
-        "PYTORCH_CUDA_ALLOC_CONF", ""
-    ) and model_args.is_huggingface_model():
-        logger.warning(
-            f"Prefix caching might results in cuda memory fragmentation, which can be mitigated by setting `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`. See https://pytorch.org/docs/stable/notes/cuda.html#environment-variables for details."
-        )
-
     # check dataset
     if "vicuna_bench" in dataset_args.dataset_names and model_args.openai_api_key is None:
         raise ValueError(
@@ -675,7 +655,7 @@ def check_args(model_args: ModelArguments, dataset_args: DatasetArguments, evalu
             "Instruction does not include any variable, so the input remains unchanged across the insatnces. Try to use f-string or jinja2 format to include variables like `{source}` or `{problem}`. See dataset documentation for details."
        )

-    if evaluation_args.dry_run and model_args.prefix_caching:
+    if evaluation_args.dry_run:
         model_args.prefix_caching = False

     args_ignored = set()
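
Reviewer note: the patch turns prefix_caching into a tri-state flag, and the defaulting rules end up spread across check_args, Dataset.__init__, and the vllm wrapper. The standalone sketch below is not project code (the helper name and signature are made up for illustration); it only condenses the check_args rules shown above.

    from typing import Optional

    def resolved_prefix_caching(
        prefix_caching: Optional[bool],  # tri-state CLI flag; None means "not passed"
        is_huggingface_model: bool,
        batch_size: int,
        dry_run: bool,
    ) -> Optional[bool]:
        # check_args: with batch_size=1 on the HuggingFace backend there is no shared
        # prefix to reuse, so an unset flag is pinned to False.
        if batch_size == 1 and prefix_caching is None and is_huggingface_model:
            prefix_caching = False
        # check_args: dry runs never cache prefixes, even when requested explicitly.
        if dry_run:
            prefix_caching = False
        return prefix_caching

    # Mirrors the updated tests: an unset flag stays None after argument parsing.
    assert resolved_prefix_caching(None, is_huggingface_model=True, batch_size=16, dry_run=False) is None
    assert resolved_prefix_caching(None, is_huggingface_model=True, batch_size=1, dry_run=False) is False
    assert resolved_prefix_caching(True, is_huggingface_model=False, batch_size=8, dry_run=True) is False

A value that is still None afterwards is resolved lazily: Dataset.__init__ falls back to True, while the vllm wrapper only forwards enable_prefix_caching to LLM(...) when the flag was set, leaving vllm's own default in place, and set_ppl_args/set_prob_args switch the engine cache off again because of the logprob issue tracked in vllm-project/vllm#3914.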