diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index 1b78c6e67dc4d..38fe170775575 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -12,11 +12,25 @@ SpeculativeProposer) from vllm.spec_decode.proposer_worker_base import ProposerWorkerBase from vllm.spec_decode.top1_proposer import Top1Proposer -from vllm.worker.worker import Worker -from vllm.worker.hpu_worker import HPUWorker - - -class MultiStepWorker(HPUWorker, ProposerWorkerBase): +from vllm.platforms import current_platform +from vllm.utils import is_neuron, is_openvino, is_xpu +if is_neuron(): + from vllm.worker.neuron_worker import NeuronWorker as WorkerBaseCls +elif current_platform.is_hpu(): + from vllm.worker.hpu_worker import HPUWorker as WorkerBaseCls +elif is_openvino: + from vllm.worker.openvino_worker import OpenVINOWorker as WorkerBaseCls +elif current_platform.is_cpu(): + from vllm.worker.cpu_worker import CPUWorker as WorkerBaseCls +elif current_platform.is_tpu(): + from vllm.worker.tpu_worker import TPUWorker as WorkerBaseCls +elif is_xpu(): + from vllm.worker.xpu_worker import XPUWorker as WorkerBaseCls +else: + from vllm.worker.worker import Worker as WorkerBaseCls + + +class MultiStepWorker(WorkerBaseCls, ProposerWorkerBase): """The MultiStepWorker is equivalent to a Worker except that it allows multiple forward passes in a single call, assuming the scheduler has allocated enough space to store the additional KV. This reduces overhead