remove speculative_decoding hardcode to CUDA and add worker selector
There is still one hardcode to HPUWorker that needs to be removed.

Signed-off-by: Chendi.Xue <[email protected]>
xuechendi committed Oct 8, 2024
1 parent 563184a commit efc17b7
Showing 15 changed files with 166 additions and 45 deletions.
66 changes: 66 additions & 0 deletions examples/offline_inference_spec_decode.py
@@ -0,0 +1,66 @@
import gc
import time
from typing import List

from vllm import LLM, SamplingParams


def time_generation(llm: LLM, prompts: List[str],
                    sampling_params: SamplingParams):
    # Generate texts from the prompts. The output is a list of RequestOutput
    # objects that contain the prompt, generated text, and other information.
    # Warmup first
    llm.generate(prompts, sampling_params)
    llm.generate(prompts, sampling_params)
    start = time.time()
    outputs = llm.generate(prompts, sampling_params)
    end = time.time()
    latency_per_token = (end - start) / sum([len(o.outputs[0].token_ids) for o in outputs])
    # Print the outputs.
    ret = []
    for output in outputs:
        generated_text = output.outputs[0].text
        ret.append(generated_text)
    return ret, latency_per_token


if __name__ == "__main__":

    # Sample prompts.
    prompts = [
        "The future of AI is",
    ]
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # Create an LLM without spec decoding
    print("==============Without speculation==================")
    llm = LLM(model="facebook/opt-6.7b")

    ret_non_spec, latency_per_token_non_spec = time_generation(llm, prompts, sampling_params)

    del llm
    gc.collect()

    # Create an LLM with spec decoding
    print("==============With speculation=====================")
    llm = LLM(
        model="facebook/opt-6.7b",
        speculative_model="facebook/opt-125m",
        num_speculative_tokens=5,
        # This is currently required for speculative decoding
        use_v2_block_manager=True,
    )

    ret_spec, latency_per_token_spec = time_generation(llm, prompts, sampling_params)

    del llm
    gc.collect()
    print("================= Summary =====================")
    print("input is ", prompts)
    print()
    print("Non Spec Decode - latency_per_token is ", latency_per_token_non_spec)
    print("Generated Text is :", ret_non_spec)
    print()
    print("Spec Decode - latency_per_token is ", latency_per_token_spec)
    print("Generated Text is :", ret_spec)
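
The example above can be run directly from the repository root, e.g. `python examples/offline_inference_spec_decode.py`; the summary section prints the per-token latency and generated text for both runs so the speculative and non-speculative configurations can be compared on the same prompt (exact numbers depend on the hardware and on the models being available locally).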

12 changes: 6 additions & 6 deletions tests/samplers/test_rejection_sampler.py
@@ -84,7 +84,7 @@ def test_correct_output_format(which_tokens_accepted: str, seed: int,
dtype=torch.int64)

rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
rejection_sampler.init_tensors(device=device)
output_token_ids = rejection_sampler._create_output( # pylint: disable=protected-access
accepted,
recovered_token_ids,
@@ -133,7 +133,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
device: str, use_flashinfer: bool):
torch.set_default_device(device)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
rejection_sampler.init_tensors(device=device)

draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
@@ -166,7 +166,7 @@ def test_deterministic_when_seeded(k: int, vocab_size: int, batch_size: int,
use_flashinfer: bool):
torch.set_default_device(device)
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
rejection_sampler.init_tensors(device=device)

draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
@@ -239,7 +239,7 @@ def get_seeded_seqs():

for use_flashinfer in [True, False]:
rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer)
rejection_sampler.init_gpu_tensors(device=device)
rejection_sampler.init_tensors(device=device)
# We use seeded sequences to ensure the same tokens are accepted
# for both flashinfer and nonflashinfer backends.
seeded_seqs = get_seeded_seqs()
@@ -270,7 +270,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,

rejection_sampler = RejectionSampler(use_flashinfer=use_flashinfer,
strict_mode=True)
rejection_sampler.init_gpu_tensors(device=device)
rejection_sampler.init_tensors(device=device)

draft_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
target_probs = torch.rand(batch_size,
@@ -401,7 +401,7 @@ def __init__(self, vocab_size: int, rejection_sampler: RejectionSampler):
self.vocab_size = vocab_size
self.vocab_range = (0, vocab_size)

self.rejection_sampler.init_gpu_tensors(device=0)
self.rejection_sampler.init_tensors(device=0)

# Keep test simple, use k=1
self.k = 1
18 changes: 9 additions & 9 deletions tests/samplers/test_typical_acceptance_sampler.py
@@ -77,7 +77,7 @@ def test_no_crash_with_varying_dims(k: int, vocab_size: int, batch_size: int,
"""
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler()
typical_acceptance_sampler.init_gpu_tensors(device=device)
typical_acceptance_sampler.init_tensors(device=device)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
@@ -113,7 +113,7 @@ def test_raises_when_vocab_oob(above_or_below_vocab_range: str,
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
typical_acceptance_sampler.init_tensors(device=device)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
@@ -172,7 +172,7 @@ def test_uniform_target_distribution_accepts_all_tokens(
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
typical_acceptance_sampler.init_tensors(device=device)
target_with_bonus_probs = torch.rand(batch_size,
k + 1,
vocab_size,
@@ -222,7 +222,7 @@ def test_temperature_zero_target_distribution(seed: int, device: str):
torch.set_default_device(device)

typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
typical_acceptance_sampler.init_tensors(device=device)
# Simulate temperature 0 probability distribution for target probabilities
# and create target probabilities such that only 1 token id has
# probability 1.0
@@ -278,7 +278,7 @@ def test_mixed_target_distribution(seed: int, device: str):
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
typical_acceptance_sampler.init_tensors(device=device)
# For sequences 0 and 2 set the distribution to a temperature
# zero distribution. For sequences 1 and 3 set it to a uniform
# distribution.
@@ -341,7 +341,7 @@ def test_accept_tokens_partially(seed: int, device: str):
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
typical_acceptance_sampler.init_tensors(device=device)
# Create a temperature zero target probability distribution and ensure
# all draft token ids correspond to the tokens with 1.0 probability.
# Verify that all of them are accepted.
@@ -399,7 +399,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
typical_acceptance_sampler.init_tensors(device=device)
# Simulate temperature 0 probability distribution for target
# probabilities and create target probabilities such that only 1 token
# id has probability 1.0 and others have a very low probability of
@@ -430,7 +430,7 @@ def test_accept_tokens_set_non_default_posteriors(seed: int, device: str):
# target distribution. Simulate and verify the same.
typical_acceptance_sampler = TypicalAcceptanceSampler(
strict_mode=True, posterior_threshold=0.0, posterior_alpha=0.0)
typical_acceptance_sampler.init_gpu_tensors(device=device)
typical_acceptance_sampler.init_tensors(device=device)
output_token_ids = typical_acceptance_sampler(
target_probs,
bonus_token_ids,
@@ -462,7 +462,7 @@ def test_get_recovered_token_ids(seed: int, device: str):
vocab_size = 30_000
torch.set_default_device(device)
typical_acceptance_sampler = get_acceptance_sampler(strict_mode=True)
typical_acceptance_sampler.init_gpu_tensors(device=device)
typical_acceptance_sampler.init_tensors(device=device)
target_probs = torch.rand(batch_size, k, vocab_size, dtype=torch.float32)
expected_replacement_tokens = torch.argmax(target_probs, dim=-1)
actual_replacement_tokens = (
12 changes: 6 additions & 6 deletions tests/spec_decode/test_metrics.py
@@ -20,7 +20,7 @@ def test_initial_call_returns_none():
spec_decode_sampler.num_draft_tokens = 0

collector = AsyncMetricsCollector(spec_decode_sampler)
collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)
maybe_metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert maybe_metrics is None

@@ -46,7 +46,7 @@ def test_second_call_returns_metrics():
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is not None
@@ -66,7 +66,7 @@ def test_nonzero_rank_noop(rank):
spec_decode_sampler.num_draft_tokens = 0

collector = AsyncMetricsCollector(spec_decode_sampler)
collector.init_gpu_tensors(rank=rank)
collector.init_tensors(rank=rank)
_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
assert metrics is None
@@ -94,7 +94,7 @@ def test_noop_until_time():
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)

_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
@@ -133,7 +133,7 @@ def test_timer_is_reset():
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)

_ = collector.maybe_collect_rejsample_metrics(k=5)
metrics = collector.maybe_collect_rejsample_metrics(k=5)
@@ -183,7 +183,7 @@ def test_initial_metrics_has_correct_values(has_data: bool):
collector = AsyncMetricsCollector(spec_decode_sampler=spec_decode_sampler,
timer=timer,
collect_interval_s=collect_interval_s)
collector.init_gpu_tensors(rank=0)
collector.init_tensors(rank=0)
_ = collector.maybe_collect_rejsample_metrics(k)
metrics = collector.maybe_collect_rejsample_metrics(k)

4 changes: 4 additions & 0 deletions vllm/attention/backends/hpu_attn.py
@@ -22,6 +22,10 @@

class HPUAttentionBackend(AttentionBackend):

@staticmethod
def get_name() -> str:
return "hpu-attn"

@staticmethod
def get_impl_cls() -> Type["HPUAttentionImpl"]:
return HPUAttentionImpl
18 changes: 13 additions & 5 deletions vllm/executor/hpu_executor.py
@@ -48,16 +48,23 @@ def _get_worker_kwargs(
distributed_init_method=distributed_init_method,
lora_config=self.lora_config,
is_driver_worker=rank == 0,
speculative_config=self.speculative_config,
)

def _create_worker(self,
local_rank: int = 0,
rank: int = 0,
distributed_init_method: Optional[str] = None):
wrapper = WorkerWrapperBase(
worker_module_name="vllm.worker.hpu_worker",
worker_class_name="HPUWorker",
)
if self.speculative_config is None:
wrapper = WorkerWrapperBase(
worker_module_name="vllm.worker.hpu_worker",
worker_class_name="HPUWorker",
)
else:
wrapper = WorkerWrapperBase(
worker_module_name="vllm.spec_decode.spec_decode_worker",
worker_class_name="create_spec_worker",
)
wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
distributed_init_method))
return wrapper.worker
@@ -197,7 +204,8 @@ def stop_profile(self) -> None:
self.driver_worker.stop_profile()

def shutdown(self) -> None:
self.driver_worker.shutdown_inc()
if hasattr(self.driver_worker, 'shutdown_inc'):
self.driver_worker.shutdown_inc()


class HPUExecutorAsync(HPUExecutor, ExecutorAsyncBase):
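
The _create_worker hunk above replaces the hardcoded HPUWorker with a selection on speculative_config. A minimal sketch of that selector in isolation (illustrative only; the helper name select_worker is not part of the commit):

```python
from typing import Optional, Tuple


def select_worker(speculative_config: Optional[object]) -> Tuple[str, str]:
    """Pick the worker module and class the executor wrapper should load.

    Mirrors the branch added to HPUExecutor._create_worker: without a
    speculative config the plain HPU worker is used; otherwise the
    spec-decode factory builds a worker around the target and draft models.
    """
    if speculative_config is None:
        return ("vllm.worker.hpu_worker", "HPUWorker")
    return ("vllm.spec_decode.spec_decode_worker", "create_spec_worker")


# WorkerWrapperBase would then be constructed with whichever pair was selected.
assert select_worker(None) == ("vllm.worker.hpu_worker", "HPUWorker")
```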
9 changes: 4 additions & 5 deletions vllm/model_executor/layers/spec_decode_base_sampler.py
@@ -30,12 +30,10 @@ def __init__(self, strict_mode: bool = False):
self.num_emitted_tokens: Optional[torch.Tensor] = None
self.num_draft_tokens: int = 0

def init_gpu_tensors(self, device: Union[int, str]) -> None:
def init_tensors(self, device: Union[int, str], device_type: str = 'cuda') -> None:
assert self.num_accepted_tokens is None
if isinstance(device, int):
device = f"cuda:{device}"
elif not isinstance(device, str):
raise ValueError(f"Device must be int or str, get {type(device)}")
device = f"{device_type}:{device}"
self.num_accepted_tokens = torch.tensor(0,
dtype=torch.long,
device=device)
@@ -77,7 +75,8 @@ def _create_output(
tensor is [batch_size, k + num_bonus_tokens]
"""
batch_size, k = substitute_token_ids.shape
bonus_token_ids = bonus_token_ids.squeeze()
# TODO: HPU has an issue handling squeeze() when all dims are 1
bonus_token_ids = bonus_token_ids.squeeze(-1)
# Determine the index of the first False value for each row.
limits = (accepted == 0).max(1).indices
limits[~(accepted == 0).any(1)] = k
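
With init_gpu_tensors generalized to init_tensors, the sampler's counters can be placed on a non-CUDA device by passing a device type. A minimal usage sketch, assuming the rejection sampler import path is unchanged and that a CUDA device (or, for the commented-out call, an HPU build) is actually available:

```python
from vllm.model_executor.layers.rejection_sampler import RejectionSampler

sampler = RejectionSampler()

# Previously, init_gpu_tensors(device=rank) always produced "cuda:{rank}".
# The new signature threads the device type through instead of assuming CUDA:
sampler.init_tensors(device=0)                       # counters land on "cuda:0"
# sampler.init_tensors(device=0, device_type="hpu")  # would target "hpu:0" on Gaudi
```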
9 changes: 9 additions & 0 deletions vllm/spec_decode/batch_expansion.py
@@ -86,6 +86,7 @@ def score_proposals(
contracted = self._contract_batch_all_spec(
target_sampler_output=target_sampler_output,
proposals=proposals,
num_scoring_tokens=num_scoring_tokens,
)
else:
# Batch has a mix of spec decode enabled and disabled seq groups
@@ -216,6 +217,7 @@ def _contract_batch_all_spec(
self,
target_sampler_output: SamplerOutput,
proposals: SpeculativeProposals,
num_scoring_tokens: int,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor,
Optional[torch.Tensor]]:
"""Contract the expanded batch back into its original size.
@@ -229,6 +231,13 @@
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
contracted_bs, k = proposals.proposal_token_ids.shape

(target_sampler_output.sampled_token_ids,
target_sampler_output.sampled_token_probs,
target_sampler_output.logprobs,
target_sampler_output.hidden_states,
_, _, _, _,) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens)

# Reshape tensors to original batch size
target_token_ids = target_sampler_output.sampled_token_ids.reshape(
contracted_bs, k + 1)
12 changes: 8 additions & 4 deletions vllm/spec_decode/draft_model_runner.py
@@ -7,10 +7,14 @@

try:
from vllm.attention.backends.flash_attn import FlashAttentionMetadata
except ModuleNotFoundError:
# vllm_flash_attn is not installed, use the identical ROCm FA metadata
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
except ImportError:
try:
# vllm_flash_attn is not installed, use the identical ROCm FA metadata
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
except ImportError:
pass


from vllm.config import (CacheConfig, DeviceConfig, LoadConfig, LoRAConfig,
ModelConfig, ObservabilityConfig, ParallelConfig,