-
Notifications
You must be signed in to change notification settings - Fork 38
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Adding vllm speculative decoding example #317
base: main
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
base_image: | ||
image: nvcr.io/nvidia/pytorch:23.11-py3 | ||
python_executable_path: /usr/bin/python3 | ||
build_commands: [] | ||
environment_variables: | ||
HF_TOKEN: "" | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. why here over secrets? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. vLLM reads only this specific environment variable for the access token. It doesn't work with secrets |
||
external_package_dirs: [] | ||
model_metadata: | ||
main_model: meta-llama/Meta-Llama-3-8B-Instruct | ||
assistant_model: ibm-fms/llama3-8b-accelerator | ||
tensor_parallel: 1 | ||
max_num_seqs: 16 | ||
model_name: vLLM Speculative Decoding | ||
python_version: py310 | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. drop python version if using base image |
||
requirements: | ||
- git+https://github.com/vllm-project/vllm@9def10664e8b54dcc5c6114f2895bc9e712bf182 | ||
resources: | ||
accelerator: A100 | ||
use_gpu: true | ||
system_packages: | ||
- python3.10-venv | ||
runtime: | ||
predict_concurrency: 128 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,65 @@ | ||
import logging | ||
import subprocess | ||
import uuid | ||
|
||
from vllm import SamplingParams | ||
from vllm.engine.arg_utils import AsyncEngineArgs | ||
from vllm.engine.async_llm_engine import AsyncLLMEngine | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
class Model: | ||
def __init__(self, **kwargs): | ||
self._config = kwargs["config"] | ||
self.model = None | ||
self.llm_engine = None | ||
self.model_args = None | ||
|
||
num_gpus = self._config["model_metadata"]["tensor_parallel"] | ||
logger.info(f"num GPUs ray: {num_gpus}") | ||
command = f"ray start --head --num-gpus={num_gpus}" | ||
subprocess.check_output(command, shell=True, text=True) | ||
|
||
def load(self): | ||
model_metadata = self._config["model_metadata"] | ||
logger.info(f"main model: {model_metadata['main_model']}") | ||
logger.info(f"assistant model: {model_metadata['assistant_model']}") | ||
logger.info(f"tensor parallelism: {model_metadata['tensor_parallel']}") | ||
logger.info(f"max num seqs: {model_metadata['max_num_seqs']}") | ||
|
||
self.model_args = AsyncEngineArgs( | ||
model=model_metadata["main_model"], | ||
speculative_model=model_metadata["assistant_model"], | ||
trust_remote_code=True, | ||
tensor_parallel_size=model_metadata["tensor_parallel"], | ||
max_num_seqs=model_metadata["max_num_seqs"], | ||
dtype="half", | ||
use_v2_block_manager=True, | ||
enforce_eager=True, | ||
) | ||
self.llm_engine = AsyncLLMEngine.from_engine_args(self.model_args) | ||
|
||
async def predict(self, model_input): | ||
prompt = model_input.pop("prompt") | ||
stream = model_input.pop("stream", True) | ||
|
||
sampling_params = SamplingParams(**model_input) | ||
idx = str(uuid.uuid4().hex) | ||
vllm_generator = self.llm_engine.generate(prompt, sampling_params, idx) | ||
|
||
async def generator(): | ||
full_text = "" | ||
async for output in vllm_generator: | ||
text = output.outputs[0].text | ||
delta = text[len(full_text) :] | ||
full_text = text | ||
yield delta | ||
|
||
if stream: | ||
return generator() | ||
else: | ||
full_text = "" | ||
async for delta in generator(): | ||
full_text += delta | ||
return {"text": full_text} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
just cuious why we need this base image? Can you add a cooment?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Without this base image the build does not succeed. The baseten base image does not have
nvcc
, which is required for the developer build of vLLM.