From f04901b38c9cd80921c92d1891b4d4c36c17f9f4 Mon Sep 17 00:00:00 2001 From: Chendi Xue Date: Thu, 17 Oct 2024 01:11:17 +0300 Subject: [PATCH] Apply fixes based on review comments Signed-off-by: Chendi Xue Signed-off-by: Chendi Xue --- examples/offline_inference_spec_decode.py | 6 ++---- vllm/worker/hpu_model_runner.py | 2 ++ vllm/worker/selector.py | 24 +++++++++++++++++------ 3 files changed, 22 insertions(+), 10 deletions(-) diff --git a/examples/offline_inference_spec_decode.py b/examples/offline_inference_spec_decode.py index 04a5bab9fa362..03543ff47de69 100644 --- a/examples/offline_inference_spec_decode.py +++ b/examples/offline_inference_spec_decode.py @@ -56,11 +56,9 @@ def time_generation(llm: LLM, prompts: List[str], del llm gc.collect() print("================= Summary =====================") - print("input is ", prompts) - print() + print("input is ", prompts, "\n") print("Non Spec Decode - latency_per_token is ", latency_per_token_non_spec) - print("Generated Text is :", ret_non_spec) - print() + print("Generated Text is :", ret_non_spec, "\n") print("Spec Decode - latency_per_token is ", latency_per_token_spec) print("Generated Text is :", ret_spec) \ No newline at end of file diff --git a/vllm/worker/hpu_model_runner.py b/vllm/worker/hpu_model_runner.py index 73ad49c214636..6046576854ca6 100644 --- a/vllm/worker/hpu_model_runner.py +++ b/vllm/worker/hpu_model_runner.py @@ -338,6 +338,8 @@ def compute_logits(self, *args, **kwargs): def sample(self, *args, **kwargs): return self.model.sample(*args, **kwargs) + # sampler property will be used by spec_decode_worker + # don't rename @property def sampler(self): return self.model.sampler diff --git a/vllm/worker/selector.py b/vllm/worker/selector.py index 5874121cef81f..b06122f9139c2 100644 --- a/vllm/worker/selector.py +++ b/vllm/worker/selector.py @@ -1,13 +1,25 @@ from vllm.config import DeviceConfig -from vllm.worker.worker import Worker -from vllm.worker.hpu_worker import HPUWorker - def init_worker(*args, 
**kwargs): device_config: DeviceConfig = kwargs.get("device_config") - if device_config.device_type == 'cuda': - return Worker(*args, **kwargs) + if device_config.device_type == 'neuron': + from vllm.worker.neuron_worker import NeuronWorker + return NeuronWorker(*args, **kwargs) + elif device_config.device_type == 'tpu': + from vllm.worker.tpu_worker import TPUWorker + return TPUWorker(*args, **kwargs) + elif device_config.device_type == 'cpu': + from vllm.worker.cpu_worker import CPUWorker + return CPUWorker(*args, **kwargs) elif device_config.device_type == 'hpu': + from vllm.worker.hpu_worker import HPUWorker return HPUWorker(*args, **kwargs) + elif device_config.device_type == 'openvino': + from vllm.worker.openvino_worker import OpenVINOWorker + return OpenVINOWorker(*args, **kwargs) + elif device_config.device_type == 'xpu': + from vllm.worker.xpu_worker import XPUWorker + return XPUWorker(*args, **kwargs) else: - raise NotImplementedError("Please help to add your preferred backend") \ No newline at end of file + from vllm.worker.worker import Worker + return Worker(*args, **kwargs) \ No newline at end of file