Adding vllm speculative decoding example #317

Open · wants to merge 2 commits into base: main
Changes from 1 commit
23 changes: 23 additions & 0 deletions vllm-speculative-decoding/config.yaml
@@ -0,0 +1,23 @@
base_image:
image: nvcr.io/nvidia/pytorch:23.11-py3
python_executable_path: /usr/bin/python3
Comment on lines +2 to +4
Contributor

Just curious, why do we need this base image? Can you add a comment?

Contributor Author

Without this base image the build does not succeed. The Baseten base image does not have nvcc, which is required for the developer build of vLLM.
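
For illustration only (not part of this PR), a minimal sketch of how one could fail fast if the image lacks nvcc; the check itself is hypothetical and simply shells out to `nvcc --version`:

```python
import shutil
import subprocess

# Hypothetical pre-flight check: the developer build of vLLM compiles CUDA
# kernels, so nvcc must be available on PATH inside the image.
if shutil.which("nvcc") is None:
    raise RuntimeError(
        "nvcc not found on PATH; use a CUDA devel image such as nvcr.io/nvidia/pytorch."
    )
print(subprocess.check_output(["nvcc", "--version"], text=True))
```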

build_commands: []
environment_variables:
HF_TOKEN: ""
Contributor

Why here over secrets?

Contributor Author

vLLM reads only this specific environment variable for the access token; it doesn't work with secrets.
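
For context, a minimal sketch (untested here) of one possible workaround: forwarding a Truss secret into the `HF_TOKEN` environment variable that vLLM reads. The secret name `hf_access_token` and the corresponding `secrets` entry in config.yaml are assumptions about this deployment, not part of the PR:

```python
import os


class Model:
    def __init__(self, **kwargs):
        self._config = kwargs["config"]
        # Assumes `secrets: {hf_access_token: null}` is declared in config.yaml
        # and the secret value is set in the workspace (hypothetical names).
        secrets = kwargs.get("secrets") or {}
        token = secrets.get("hf_access_token")
        if token:
            # vLLM / huggingface_hub read the access token from this env var.
            os.environ["HF_TOKEN"] = token
```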

external_package_dirs: []
model_metadata:
main_model: meta-llama/Meta-Llama-3-8B-Instruct
assistant_model: ibm-fms/llama3-8b-accelerator
tensor_parallel: 1
max_num_seqs: 16
model_name: vLLM Speculative Decoding
python_version: py310
Contributor

Drop python_version if using a base image.

requirements:
- git+https://github.com/vllm-project/vllm@9def10664e8b54dcc5c6114f2895bc9e712bf182
resources:
accelerator: A100
use_gpu: true
system_packages:
- python3.10-venv
runtime:
predict_concurrency: 128
Empty file.
65 changes: 65 additions & 0 deletions vllm-speculative-decoding/model/model.py
@@ -0,0 +1,65 @@
import logging
import subprocess
import uuid

from vllm import SamplingParams
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

logger = logging.getLogger(__name__)


class Model:
def __init__(self, **kwargs):
self._config = kwargs["config"]
self.model = None
self.llm_engine = None
self.model_args = None

num_gpus = self._config["model_metadata"]["tensor_parallel"]
logger.info(f"num GPUs ray: {num_gpus}")
command = f"ray start --head --num-gpus={num_gpus}"
subprocess.check_output(command, shell=True, text=True)

def load(self):
model_metadata = self._config["model_metadata"]
logger.info(f"main model: {model_metadata['main_model']}")
logger.info(f"assistant model: {model_metadata['assistant_model']}")
logger.info(f"tensor parallelism: {model_metadata['tensor_parallel']}")
logger.info(f"max num seqs: {model_metadata['max_num_seqs']}")

self.model_args = AsyncEngineArgs(
model=model_metadata["main_model"],
speculative_model=model_metadata["assistant_model"],
trust_remote_code=True,
tensor_parallel_size=model_metadata["tensor_parallel"],
max_num_seqs=model_metadata["max_num_seqs"],
dtype="half",
use_v2_block_manager=True,
enforce_eager=True,
)
self.llm_engine = AsyncLLMEngine.from_engine_args(self.model_args)

async def predict(self, model_input):
prompt = model_input.pop("prompt")
stream = model_input.pop("stream", True)

sampling_params = SamplingParams(**model_input)
idx = str(uuid.uuid4().hex)
vllm_generator = self.llm_engine.generate(prompt, sampling_params, idx)

async def generator():
full_text = ""
async for output in vllm_generator:
text = output.outputs[0].text
delta = text[len(full_text) :]
full_text = text
yield delta

if stream:
return generator()
else:
full_text = ""
async for delta in generator():
full_text += delta
return {"text": full_text}
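
For reference, a hypothetical client call against a deployed instance of this model. The model ID and API key are placeholders, the endpoint URL assumes Baseten's standard predict route, and any fields other than `prompt` and `stream` are passed straight through to vLLM's SamplingParams:

```python
import requests

# Placeholder identifiers; substitute your own deployment ID and API key.
MODEL_ID = "YOUR_MODEL_ID"
API_KEY = "YOUR_API_KEY"

resp = requests.post(
    f"https://model-{MODEL_ID}.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {API_KEY}"},
    json={
        "prompt": "Explain speculative decoding in two sentences.",
        "stream": True,
        "max_tokens": 256,
        "temperature": 0.7,
    },
    stream=True,
)
resp.raise_for_status()

# The model yields incremental text deltas, so print them as they arrive.
for chunk in resp.iter_content(chunk_size=None, decode_unicode=True):
    print(chunk, end="", flush=True)
```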