add vllm backend (#274)
What this PR does:

- Adds vLLM as a backend for faster inference.

How to use:

```
lighteval accelerate --model_args="pretrained=meta-llama/Meta-Llama-3.1-8B-Instruct,dtype=bfloat16,vllm,data_parallel_size=2" --use_chat_template --tasks "leaderboard|arc:challenge|0|0,extended|ifeval|0|0,lighteval|gsm8k|5|1" --output_dir="./evals/"
```
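
Note that the bare `vllm` token inside `--model_args` (a key with no value) is what routes model loading to the new backend, while the `key=value` pairs become fields of the vLLM config shown further down in `model_config.py`. As a rough illustration only, and not lighteval's actual parser, such a string could be turned into keyword arguments like this:

```python
# Illustration only: how a comma-separated --model_args string can map to kwargs.
# This is NOT lighteval's actual parsing code, just a sketch of the idea.
def parse_model_args(model_args: str) -> dict:
    kwargs = {}
    for item in model_args.split(","):
        if "=" in item:
            key, value = item.split("=", 1)
            kwargs[key] = value
        else:
            kwargs[item] = True  # bare flags such as "vllm" become booleans
    return kwargs


args = parse_model_args(
    "pretrained=meta-llama/Meta-Llama-3.1-8B-Instruct,dtype=bfloat16,vllm,data_parallel_size=2"
)
assert args["vllm"] is True  # this flag is what create_model_config checks for
```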

---------

Co-authored-by: Clémentine Fourrier <[email protected]>
NathanHB and clefourrier authored Sep 3, 2024
1 parent 8c787df commit 21934d5
Showing 8 changed files with 436 additions and 6 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -87,6 +87,7 @@ nanotron = [
"tensorboardX"
]
tensorboardX = ["tensorboardX"]
vllm = ["vllm", "ray", "more_itertools"]
quality = ["ruff==v0.2.2","pre-commit"]
tests = ["pytest==7.4.0"]
dev = ["lighteval[accelerate,quality,tests]"]
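
vLLM support ships as an optional extra, so it presumably has to be installed explicitly (e.g. `pip install lighteval[vllm]`, pulling in `vllm`, `ray`, and `more_itertools`). A minimal sketch of guarding against a missing install with the `is_vllm_available` / `NO_VLLM_ERROR_MSG` helpers this PR wires into `model_loader.py` below:

```python
# Sketch: fail early with lighteval's own error message if the vllm extra is missing.
# is_vllm_available and NO_VLLM_ERROR_MSG are the helpers imported in model_loader.py below.
from lighteval.utils.imports import NO_VLLM_ERROR_MSG, is_vllm_available

if not is_vllm_available():
    raise ImportError(NO_VLLM_ERROR_MSG)
```
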
14 changes: 14 additions & 0 deletions src/lighteval/data.py
@@ -161,6 +161,20 @@ def __len__(self) -> int:
"""
return self.split_end - self.split_start

def __iter__(self) -> Iterator[Request]:
"""
Iterator that yields the items of the dataset depending on the split we
are currently in. For instance, if we are in split 0, we get the items from
index 0 to self.split_size; if we are in split 1, we get the items from
index self.split_size to 2 * self.split_size, and so on. Used for dynamic
batching.
Yields:
Any: The items of the dataset.
"""
for i in range(self.split_start, self.split_end):
yield self.sorted_data[i]

def _sorting_criteria(self, request) -> int:
raise NotImplementedError()

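
The new `__iter__` walks only the slice of the sorted data belonging to the current split, which the dynamic batching relies on. A self-contained toy version of the same idea (not the actual lighteval class):

```python
from typing import Any, Iterator, List


class ToySplitDataset:
    """Toy illustration of split-based iteration; not the actual lighteval dataset."""

    def __init__(self, sorted_data: List[Any], num_splits: int):
        self.sorted_data = sorted_data
        self.split_size = len(sorted_data) // num_splits
        self.split_start, self.split_end = 0, self.split_size

    def set_split(self, split_index: int) -> None:
        # Split 0 covers [0, split_size), split 1 covers [split_size, 2 * split_size), etc.
        self.split_start = split_index * self.split_size
        self.split_end = min((split_index + 1) * self.split_size, len(self.sorted_data))

    def __iter__(self) -> Iterator[Any]:
        # Same pattern as the diff: yield only the items of the current split.
        for i in range(self.split_start, self.split_end):
            yield self.sorted_data[i]


dataset = ToySplitDataset(list("abcdef"), num_splits=3)
dataset.set_split(1)
assert list(dataset) == ["c", "d"]
```
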
8 changes: 6 additions & 2 deletions src/lighteval/models/base_model.py
@@ -93,6 +93,7 @@ def __init__(
hlog(f"Using Data Parallelism, putting model on device {self._device}")
self.model = self.model.to(self._device)
if config.compile:
hlog("Compiling the model")
self.model.model.compile()

self.model_name = _simplify_name(config.pretrained)
@@ -549,9 +550,9 @@ def greedy_until(
tokenized = self.tokenizer(
context,
truncation="longest_first", # we truncate to the model max length if needed
padding="longest", # we pad to the longest sequence
padding="max_length", # we pad to the longest sequence
return_tensors="pt",
max_length=self.max_length - 1, # we always allow minimum one token of generation
max_length=max_context_continuation_size_allowed, # we always allow minimum one token of generation
add_special_tokens=self.add_special_tokens,
).to(self.device)

@@ -573,7 +574,10 @@
if max_new_tokens is None: # If generation size is not set, we go all the way
max_new_tokens = self.max_length - context_size
else:
print(self.max_length, context_size, max_new_tokens)
max_new_tokens = min(self.max_length - context_size, max_new_tokens)
if max_new_tokens < 1:
max_new_tokens = 1

prepared_batch = Batch(
input_ids=tokenized["input_ids"],
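
The `greedy_until` changes above cap a user-supplied `max_new_tokens` at the room left in the context window and clamp it to at least one token. A standalone sketch of that clamping (an extraction for illustration, not the actual method):

```python
from typing import Optional


def clamp_max_new_tokens(max_length: int, context_size: int, max_new_tokens: Optional[int]) -> int:
    """Sketch of the greedy_until clamping: fill the remaining window, but never ask for < 1 token."""
    if max_new_tokens is None:  # generation size not set: use all remaining room
        return max_length - context_size
    max_new_tokens = min(max_length - context_size, max_new_tokens)
    return max(max_new_tokens, 1)  # always allow at least one generated token


assert clamp_max_new_tokens(max_length=4096, context_size=4000, max_new_tokens=256) == 96
assert clamp_max_new_tokens(max_length=2048, context_size=2048, max_new_tokens=100) == 1
```
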
23 changes: 23 additions & 0 deletions src/lighteval/models/model_config.py
@@ -204,6 +204,25 @@ def init_configs(self, env_config: EnvConfig):
return self._init_configs(self.base_model, env_config)


@dataclass
class VLLMModelConfig:
pretrained: str
gpu_memory_utilisation: float = 0.8
batch_size: int = -1
revision: str = "main"
dtype: str | None = None
tensor_parallel_size: int = 1
data_parallel_size: int = 1
max_model_length: int = 1024
swap_space: int = 4 # CPU swap space size (GiB) per GPU.
seed: int = 1234
trust_remote_code: bool = False
use_chat_template: bool = False
add_special_tokens: bool = True
multichoice_continuations_start_space: bool = True
subfolder: Optional[str] = None


@dataclass
class TGIModelConfig:
inference_server_address: str
@@ -279,6 +298,7 @@ def create_model_config( # noqa: C901
TGIModelConfig,
InferenceEndpointModelConfig,
DummyModelConfig,
VLLMModelConfig,
]:
"""
Create a model configuration based on the provided arguments.
@@ -313,6 +333,9 @@ if model_args.pop("dummy", False):
if model_args.pop("dummy", False):
return DummyModelConfig(**model_args)

if model_args.pop("vllm", False):
return VLLMModelConfig(**model_args)

model_args["accelerator"] = accelerator
model_args["use_chat_template"] = use_chat_template
model_args["compile"] = bool(model_args["compile"]) if "compile" in model_args else False
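
Since `create_model_config` pops the bare `vllm` flag and forwards the remaining keyword arguments to the dataclass, the command above ends up building a config roughly like the sketch below (fields not supplied fall back to the defaults visible in the diff; the import path follows the file location `src/lighteval/models/model_config.py`):

```python
from lighteval.models.model_config import VLLMModelConfig

# Roughly what --model_args "pretrained=...,dtype=bfloat16,vllm,data_parallel_size=2"
# resolves to once create_model_config has popped the "vllm" flag.
config = VLLMModelConfig(
    pretrained="meta-llama/Meta-Llama-3.1-8B-Instruct",
    dtype="bfloat16",
    data_parallel_size=2,
)

print(config.gpu_memory_utilisation)  # 0.8, the default from the dataclass
print(config.max_model_length)        # 1024, the default from the dataclass
```
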
15 changes: 13 additions & 2 deletions src/lighteval/models/model_loader.py
@@ -33,13 +33,15 @@
BaseModelConfig,
DeltaModelConfig,
DummyModelConfig,
EnvConfig,
InferenceEndpointModelConfig,
InferenceModelConfig,
TGIModelConfig,
VLLMModelConfig,
)
from lighteval.models.tgi_model import ModelClient
from lighteval.utils.imports import NO_TGI_ERROR_MSG, is_tgi_available
from lighteval.models.vllm_model import VLLMModel
from lighteval.utils.imports import NO_TGI_ERROR_MSG, NO_VLLM_ERROR_MSG, is_tgi_available, is_vllm_available
from lighteval.utils.utils import EnvConfig


def load_model( # noqa: C901
Expand All @@ -50,6 +52,7 @@ def load_model( # noqa: C901
TGIModelConfig,
InferenceEndpointModelConfig,
DummyModelConfig,
VLLMModelConfig,
],
env_config: EnvConfig,
) -> Union[BaseModel, AdapterModel, DeltaModel, ModelClient, DummyModel]:
@@ -81,6 +84,9 @@ def load_model( # noqa: C901
if isinstance(config, DummyModelConfig):
return load_dummy_model(config=config, env_config=env_config)

if isinstance(config, VLLMModelConfig):
return load_model_with_accelerate_or_default(config=config, env_config=env_config)


def load_model_with_tgi(config: TGIModelConfig):
if not is_tgi_available():
@@ -106,6 +112,11 @@ def load_model_with_accelerate_or_default(
model = AdapterModel(config=config, env_config=env_config)
elif isinstance(config, DeltaModelConfig):
model = DeltaModel(config=config, env_config=env_config)
elif isinstance(config, VLLMModelConfig):
if not is_vllm_available():
raise ImportError(NO_VLLM_ERROR_MSG)
model = VLLMModel(config=config, env_config=env_config)
return model
else:
model = BaseModel(config=config, env_config=env_config)

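
Putting the loader pieces together, programmatic use would presumably look roughly like the sketch below; `load_model`, `VLLMModelConfig`, and `EnvConfig` are names taken from this diff, while `build_vllm_model` is a hypothetical helper and the `EnvConfig` instance is assumed to be supplied by the caller:

```python
from lighteval.models.model_config import VLLMModelConfig
from lighteval.models.model_loader import load_model
from lighteval.utils.utils import EnvConfig


def build_vllm_model(env_config: EnvConfig):
    """Sketch: dispatch through load_model, which returns a VLLMModel for a VLLMModelConfig."""
    config = VLLMModelConfig(
        pretrained="meta-llama/Meta-Llama-3.1-8B-Instruct",
        dtype="bfloat16",
        tensor_parallel_size=1,
    )
    # load_model raises ImportError(NO_VLLM_ERROR_MSG) if the vllm extra is not installed.
    return load_model(config=config, env_config=env_config)
```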