Skip to content

Commit

Permalink
Review comments based fixings
Browse files Browse the repository at this point in the history
Signed-off-by: Chendi Xue <[email protected]>
Signed-off-by: Chendi Xue <[email protected]>
  • Loading branch information
Chendi Xue authored and xuechendi committed Oct 16, 2024
1 parent efc17b7 commit f04901b
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 10 deletions.
6 changes: 2 additions & 4 deletions examples/offline_inference_spec_decode.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,11 +56,9 @@ def time_generation(llm: LLM, prompts: List[str],
del llm
gc.collect()
print("================= Summary =====================")
print("input is ", prompts)
print()
print("input is ", prompts, "\n")
print("Non Spec Decode - latency_per_token is ", latency_per_token_non_spec)
print("Generated Text is :", ret_non_spec)
print()
print("Generated Text is :", ret_non_spec, "\n")
print("Spec Decode - latency_per_token is ", latency_per_token_spec)
print("Generated Text is :", ret_spec)

2 changes: 2 additions & 0 deletions vllm/worker/hpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,8 @@ def compute_logits(self, *args, **kwargs):
def sample(self, *args, **kwargs):
    """Forward a sampling call to the wrapped model.

    All positional and keyword arguments are passed through unchanged, and
    whatever ``self.model.sample`` returns is handed back as-is.
    """
    wrapped_model = self.model
    return wrapped_model.sample(*args, **kwargs)

# NOTE: spec_decode_worker relies on this property, so its name must stay
# `sampler` -- do not rename it.
@property
def sampler(self):
    """Return the sampler object of the wrapped model (``self.model.sampler``)."""
    wrapped_model = self.model
    return wrapped_model.sampler
Expand Down
24 changes: 18 additions & 6 deletions vllm/worker/selector.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,25 @@
from vllm.config import DeviceConfig

from vllm.worker.worker import Worker
from vllm.worker.hpu_worker import HPUWorker

def init_worker(*args, **kwargs):
    """Create the worker implementation that matches the configured device.

    The worker class is selected from ``kwargs["device_config"].device_type``.
    Each backend module is imported inside the branch that needs it, so
    selecting one backend never imports the modules of the others.

    Args:
        *args: Positional arguments forwarded unchanged to the worker class.
        **kwargs: Keyword arguments forwarded unchanged to the worker class.
            Must contain ``device_config``.

    Returns:
        An instance of the selected worker class, constructed with
        ``*args`` and ``**kwargs``.

    Raises:
        ValueError: If ``device_config`` is not supplied as a keyword
            argument (previously this surfaced as an opaque
            ``AttributeError`` on ``None``).
    """
    device_config: DeviceConfig = kwargs.get("device_config")
    if device_config is None:
        raise ValueError(
            "init_worker() requires a 'device_config' keyword argument")

    device_type = device_config.device_type
    if device_type == 'neuron':
        from vllm.worker.neuron_worker import NeuronWorker
        worker_cls = NeuronWorker
    elif device_type == 'tpu':
        from vllm.worker.tpu_worker import TPUWorker
        worker_cls = TPUWorker
    elif device_type == 'cpu':
        from vllm.worker.cpu_worker import CPUWorker
        worker_cls = CPUWorker
    elif device_type == 'hpu':
        from vllm.worker.hpu_worker import HPUWorker
        worker_cls = HPUWorker
    elif device_type == 'openvino':
        from vllm.worker.openvino_worker import OpenVINOWorker
        worker_cls = OpenVINOWorker
    elif device_type == 'xpu':
        from vllm.worker.xpu_worker import XPUWorker
        worker_cls = XPUWorker
    else:
        # Every remaining device type (including 'cuda') uses the default
        # worker implementation.
        from vllm.worker.worker import Worker
        worker_cls = Worker

    return worker_cls(*args, **kwargs)

0 comments on commit f04901b

Please sign in to comment.