[New Feature][Habana-Main] speculative_decoding HPU support #375

Open
wants to merge 7 commits into base: habana_main
Changes from 6 commits
67 changes: 67 additions & 0 deletions examples/offline_inference_spec_decode.py
@@ -0,0 +1,67 @@
import gc
import time
from typing import List

from vllm import LLM, SamplingParams


def time_generation(llm: LLM, prompts: List[str],
                    sampling_params: SamplingParams):
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    # Warmup first
    llm.generate(prompts, sampling_params)
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    latency_per_token = (end - start) / sum(
        [len(o.outputs[0].token_ids) for o in outputs])
    # Collect the generated texts.
    ret = []
    for output in outputs:
        generated_text = output.outputs[0].text
        ret.append(generated_text)
    return ret, latency_per_token


if __name__ == "__main__":

    # Sample prompts.
    prompts = [
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # Create an LLM without spec decoding
    print("==============Without speculation==================")
    llm = LLM(model="facebook/opt-6.7b")

    ret_non_spec, latency_per_token_non_spec = time_generation(
        llm, prompts, sampling_params)

    del llm
    gc.collect()

    # Create an LLM with spec decoding
    print("==============With speculation=====================")
    llm = LLM(
        model="facebook/opt-6.7b",
        speculative_model="facebook/opt-125m",
        num_speculative_tokens=5,
        # These are currently required for MLPSpeculator decoding
        use_v2_block_manager=True,
    )

    ret_spec, latency_per_token_spec = time_generation(llm, prompts,
                                                       sampling_params)

    del llm
    gc.collect()
    print("================= Summary =====================")
    print("input is ", prompts, "\n")
    print("Non Spec Decode - latency_per_token is ",
          latency_per_token_non_spec)
    print("Generated Text is :", ret_non_spec, "\n")
    print("Spec Decode - latency_per_token is ", latency_per_token_spec)
    print("Generated Text is :", ret_spec)
12 changes: 6 additions & 6 deletions tests/samplers/test_rejection_sampler.py
@@ -84,7 +84,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
dtype=torch.int64)

rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
rejection_sampler.init_tensors(device=device)
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
accepted,
recovered_token_ids,
@@ -133,7 +133,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
device: str, use_flashinfer: bool):
torch.set_default_device(device)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
rejection_sampler.init_tensors(device=device)

draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
@@ -166,7 +166,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
use_flashinfer: bool):
torch.set_default_device(device)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
rejection_sampler.init_tensors(device=device)

draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
@@ -239,7 +239,7 @@ def get_seeded_seqs():

for use_flashinfer in [True, False]:
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
rejection_sampler.init_tensors(device=device)
# We use seeded sequences to ensure the same tokens are accepted
# for both flashinfer and nonflashinfer backends.
seeded_seqs = get_seeded_seqs()
@@ -270,7 +270,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,

rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer,
strict_mode=True)
rejection_sampler.init_gpu_tensors(device=device)
rejection_sampler.init_tensors(device=device)

draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
@@ -401,7 +401,7 @@ def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler):
self.vocab_size = vocab_size
self.vocab_range = (0, vocab_size)

self.rejection_sampler.init_gpu_tensors(device=0)
self.rejection_sampler.init_tensors(device=0)

# Keep test simple, use k=1
self.k = 1
18 changes: 9 additions & 9 deletions tests/samplers/test_typical_acceptance_sampler.py
@@ -77,7 +77,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
"""
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler()
typical_acceptance_sampler.init_gpu_tensors(device=device)
typical_acceptance_sampler.init_tensors(device=device)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
@@ -113,7 +113,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
typical_acceptance_sampler.init_tensors(device=device)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
@@ -172,7 +172,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
typical_acceptance_sampler.init_tensors(device=device)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
@@ -222,7 +222,7 @@ def test_temperature_zero_target_distribution(seed: int, device: str):
torch.set_default_device(device)

typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
typical_acceptance_sampler.init_tensors(device=device)
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
# probability 1.0
@@ -278,7 +278,7 @@ def test_mixed_target_distribution(seed: int, device: str):
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
typical_acceptance_sampler.init_tensors(device=device)
# For sequences 0 and 2 set the distribution to a temperature
# zero distribution. For sequences 1 and 3 set it to a uniform
# distribution.
@@ -341,7 +341,7 @@ def test_accept_tokens_partially(seed: int, device: str):
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
typical_acceptance_sampler.init_tensors(device=device)
# Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability.
# Verify that all of them are accepted.
@@ -399,7 +399,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
typical_acceptance_sampler.init_tensors(device=device)
# Simulate temperature 0 probability distribution for target
# probabilities and create target probabilities such that only 1 token
# id has probability 1.0 and others have a very low probability of
@@ -430,7 +430,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
# target distribution. Simulate and verify the same.
typical_acceptance_sampler = TypicalAcceptanceSampler(
strict_mode=True, posterior_threshold=0.0, posterior_alpha=0.0)
typical_acceptance_sampler.init_gpu_tensors(device=device)
typical_acceptance_sampler.init_tensors(device=device)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
@@ -462,7 +462,7 @@ def test_get_recovered_token_ids(seed: int, device: str):
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
typical_acceptance_sampler.init_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
expected_replacement_tokens = torch.argmax(target_probs, dim=-1)
actual_replacement_tokens = (
12 changes: 6 additions & 6 deletions tests/spec_decode/test_metrics.py
@@ -20,7 +20,7 @@ def test_initial_call_returns_none():
spec_decode_sampler.num_draft_tokens = 0

collector = AsyncMetricsCollector(spec_decode_sampler)
collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)
maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert maybe_metrics is None

@@ -46,7 +46,7 @@ def test_second_call_returns_metrics():
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is not None
@@ -66,7 +66,7 @@ def test_nonzero_rank_noop(rank):
spec_decode_sampler.num_draft_tokens = 0

collector = AsyncMetricsCollector(spec_decode_sampler)
collector.init_gpu_tensors(rank=rank)
collector.init_tensors(rank=rank)
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is None
@@ -94,7 +94,7 @@ def test_noop_until_time():
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)

_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
@@ -133,7 +133,7 @@ def test_timer_is_reset():
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)

_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
@@ -183,7 +183,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k)
metrics = collector.maybe_collect_rejsample_metrics(k)

4 changes: 4 additions & 0 deletions vllm/attention/backends/hpu_attn.py
@@ -22,6 +22,10 @@

class HPUAttentionBackend(AttentionBackend):

@staticmethod
def get_name() -> str:
return "hpu-attn"

@staticmethod
def get_impl_cls() -> Type["HPUAttentionImpl"]:
return HPUAttentionImpl
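
The only functional addition here is the backend name. A quick sanity check, sketched under the assumption of an HPU-enabled vLLM build and the import path shown in this diff, exercises both static hooks:

from vllm.attention.backends.hpu_attn import HPUAttentionBackend

# The new hook reports the backend under the "hpu-attn" name.
assert HPUAttentionBackend.get_name() == "hpu-attn"
# The existing hook still resolves to the HPU attention implementation.
assert HPUAttentionBackend.get_impl_cls().__name__ == "HPUAttentionImpl"
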
18 changes: 13 additions & 5 deletions vllm/executor/hpu_executor.py
@@ -48,16 +48,23 @@ def _get_worker_kwargs(
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
is_driver_worker=rank == 0,
speculative_config=self.speculative_config,
)

def _create_worker(self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):
wrapper = WorkerWrapperBase(
worker_module_name="vllm.worker.hpu_worker",
worker_class_name="HPUWorker",
)
if self.speculative_config is None:
wrapper = WorkerWrapperBase(
worker_module_name="vllm.worker.hpu_worker",
worker_class_name="HPUWorker",
)
else:
wrapper = WorkerWrapperBase(
worker_module_name="vllm.spec_decode.spec_decode_worker",
worker_class_name="create_spec_worker",
)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return wrapper.worker
@@ -197,7 +204,8 @@ def stop_profile(self) -> None:
self.driver_worker.stop_profile()

def shutdown(self) -> None:
self.driver_worker.shutdown_inc()
if hasattr(self.driver_worker, 'shutdown_inc'):
self.driver_worker.shutdown_inc()


class HPUExecutorAsync(HPUExecutor, ExecutorAsyncBase):
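
Taken together, the executor changes reduce worker selection to a single condition, and the shutdown guard tolerates workers that do not expose shutdown_inc. A condensed, hypothetical sketch of the selection rule (the helper name is illustrative; module and class names are taken from the diff above):

def _select_hpu_worker(speculative_config):
    # Mirrors the branch added to HPUExecutor._create_worker (sketch only).
    if speculative_config is None:
        return ("vllm.worker.hpu_worker", "HPUWorker")
    return ("vllm.spec_decode.spec_decode_worker", "create_spec_worker")
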
11 changes: 6 additions & 5 deletions vllm/model_executor/layers/spec_decode_base_sampler.py
@@ -30,12 +30,12 @@ def __init__(self, strict_mode: bool = False):
self.num_emitted_tokens: Optional[torch.Tensor] = None
self.num_draft_tokens: int = 0

def init_gpu_tensors(self, device: Union[int, str]) -> None:
def init_tensors(self,
device: Union[int, str],
device_type: str = 'cuda') -> None:
assert self.num_accepted_tokens is None
if isinstance(device, int):
device = f"cuda:{device}"
elif not isinstance(device, str):
raise ValueError(f"Device must be int or str, get {type(device)}")
device = f"{device_type}:{device}"
self.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
@@ -77,7 +77,8 @@ def _create_output(
tensor is [batch_size, k + num_bonus_tokens]
"""
batch_size, k = substitute_token_ids.shape
bonus_token_ids = bonus_token_ids.squeeze()
        # TODO: HPU has an issue handling squeeze() when all dims are 1,
        # so squeeze only the last dim here.
bonus_token_ids = bonus_token_ids.squeeze(-1)
# Determine the index of the first False value for each row.
limits = (accepted == 0).max(1).indices
limits[~(accepted == 0).any(1)] = k
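
With the rename, CUDA callers keep the old behavior through the 'cuda' default, while HPU callers pass device_type explicitly: an integer device index is prefixed with the device type (0 becomes "hpu:0"), and a string device is used as-is. A minimal usage sketch, assuming RejectionSampler inherits this initializer from the base sampler (as the updated tests suggest) and that the target device runtime is available:

from vllm.model_executor.layers.rejection_sampler import RejectionSampler

# Default device_type preserves the previous CUDA behavior: 0 -> "cuda:0".
cuda_sampler = RejectionSampler()
cuda_sampler.init_tensors(device=0)

# HPU callers opt in explicitly: 0 -> "hpu:0".
hpu_sampler = RejectionSampler()
hpu_sampler.init_tensors(device=0, device_type="hpu")
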
14 changes: 14 additions & 0 deletions vllm/spec_decode/batch_expansion.py
@@ -86,6 +86,7 @@ def score_proposals(
contracted = self._contract_batch_all_spec(
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
)
else:
# Batch has a mix of spec decode enabled and disabled seq groups
@@ -216,6 +217,7 @@ def _contract_batch_all_spec(
self,
target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals,
num_scoring_tokens: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""Contract the expanded batch back into its original size.
@@ -229,6 +231,18 @@
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
contracted_bs, k = proposals.proposal_token_ids.shape

(
target_sampler_output.sampled_token_ids,
target_sampler_output.sampled_token_probs,
target_sampler_output.logprobs,
target_sampler_output.hidden_states,
_,
_,
_,
_,
) = self._split_scoring_output(target_sampler_output,
num_scoring_tokens)

# Reshape tensors to original batch size
target_token_ids = target_sampler_output.sampled_token_ids.reshape(
contracted_bs, k + 1)