Commit d06d252: Fix format issue reported by yapf
Signed-off-by: Chendi Xue <[email protected]>
xuechendi committed Oct 16, 2024
1 parent 00b230a commit d06d252
Showing 10 changed files with 49 additions and 36 deletions.
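
The changes below are mechanical, formatting-only rewrites flagged by yapf: over-long statements are wrapped, comparison operators move to the continuation line, and spacing around keyword arguments is normalized. As a rough illustration (not part of the commit), the same kind of rewrite can be reproduced with yapf's Python API. This is a minimal sketch assuming a recent yapf release in which yapf_api.FormatCode returns a (formatted_source, changed) tuple; it uses the built-in "pep8" style, which may differ from the repository's own .style.yapf settings.

from yapf.yapflib.yapf_api import FormatCode

# An over-long line similar to the one fixed in
# examples/offline_inference_spec_decode.py below.
source = ("latency_per_token = (end - start) / sum("
          "[len(o.outputs[0].token_ids) for o in outputs])\n")

# FormatCode parses the snippet and rewraps anything that exceeds the
# style's column limit; it returns the reformatted source plus a flag
# indicating whether anything changed.
formatted, changed = FormatCode(source, style_config="pep8")
print(changed)    # True here, since the original line is wider than the limit
print(formatted)  # the wrapped version, matching the style of the diffs below
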
15 changes: 9 additions & 6 deletions examples/offline_inference_spec_decode.py
@@ -15,7 +15,8 @@ def time_generation(llm: LLM, prompts: List[str],
start = time.time()
outputs = llm.generate(prompts, sampling_params)
end = time.time()
latency_per_token = (end - start) / sum([len(o.outputs[0].token_ids) for o in outputs])
latency_per_token = (end - start) / sum(
[len(o.outputs[0].token_ids) for o in outputs])
# Print the outputs.
ret = []
for output in outputs:
@@ -36,7 +37,8 @@ def time_generation(llm: LLM, prompts: List[str],
print("==============Without speculation==================")
llm = LLM(model="facebook/opt-6.7b")

ret_non_spec,latency_per_token_non_spec = time_generation(llm, prompts, sampling_params)
ret_non_spec, latency_per_token_non_spec = time_generation(
llm, prompts, sampling_params)

del llm
gc.collect()
@@ -46,19 +48,20 @@ def time_generation(llm: LLM, prompts: List[str],
llm = LLM(
model="facebook/opt-6.7b",
speculative_model="facebook/opt-125m",
num_speculative_tokens = 5,
num_speculative_tokens=5,
# These are currently required for MLPSpeculator decoding
use_v2_block_manager=True,
)

ret_spec,latency_per_token_spec = time_generation(llm, prompts, sampling_params)
ret_spec, latency_per_token_spec = time_generation(llm, prompts,
sampling_params)

del llm
gc.collect()
print("================= Summary =====================")
print("input is ", prompts, "\n")
print("Non Spec Decode - latency_per_token is ", latency_per_token_non_spec)
print("Non Spec Decode - latency_per_token is ",
latency_per_token_non_spec)
print("Generated Text is :", ret_non_spec, "\n")
print("Spec Decode - latency_per_token is ", latency_per_token_spec)
print("Generated Text is :", ret_spec)

15 changes: 8 additions & 7 deletions tests/samplers/test_rejection_sampler.py
@@ -23,16 +23,17 @@ def mock_causal_accepted_tensor(
"""
batch_size = last_accepted_indices.shape[0]

accepted = (torch.arange(k).expand(batch_size, k) <=
last_accepted_indices.unsqueeze(-1).broadcast_to(
accepted = (torch.arange(k).expand(batch_size, k)
<= last_accepted_indices.unsqueeze(-1).broadcast_to(
batch_size, k))

# Sprinkle accepted values after the contiguous initial accepted values.
# This replicates the behavior of rejection sampling, which may "accept"
# a token that cannot be accepted because of causality.
sprinkle_candidates = (
torch.arange(k).expand(batch_size, k) >
last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) + 1)
sprinkle_candidates = (torch.arange(k).expand(
batch_size,
k) > last_accepted_indices.unsqueeze(-1).broadcast_to(batch_size, k) +
1)
sprinkle = torch.rand(batch_size, k) > 0.5
accepted[sprinkle_candidates] = sprinkle[sprinkle_candidates]
return accepted
@@ -382,8 +383,8 @@ def test_rejection_sampling_approximates_target_distribution(
distance_wrt_reference)

expected_improvement_multiplier = 20
assert (relative_change_in_distance_wrt_target >
relative_change_in_distance_wrt_reference *
assert (relative_change_in_distance_wrt_target
> relative_change_in_distance_wrt_reference *
expected_improvement_multiplier)


2 changes: 1 addition & 1 deletion vllm/attention/backends/hpu_attn.py
@@ -25,7 +25,7 @@ class HPUAttentionBackend(AttentionBackend):
@staticmethod
def get_name() -> str:
return "hpu-attn"

@staticmethod
def get_impl_cls() -> Type["HPUAttentionImpl"]:
return HPUAttentionImpl
8 changes: 4 additions & 4 deletions vllm/executor/hpu_executor.py
@@ -144,13 +144,13 @@ def execute_model(
with gc_ctx as gc_local_metric, \
cpu_fallback_ctx as cpu_fallback_local_metric:
output = self.driver_worker.execute_model(execute_model_req)
if (log_graph_compilation and gc_local_metric.stats()[0][1] > 0
) or log_graph_compilation_all:
if (log_graph_compilation and gc_local_metric.stats()[0][1]
> 0) or log_graph_compilation_all:
msg = ("VLLM_HPU_STEP_GRAPH_COMPILATION: "
f"{gc_local_metric.stats()}, {input_stats}")
logger.warning(msg)
if (log_cpu_fallbacks and cpu_fallback_local_metric.stats()[0][1] >
0) or log_cpu_fallbacks_all:
if (log_cpu_fallbacks and cpu_fallback_local_metric.stats()[0][1]
> 0) or log_cpu_fallbacks_all:
msg = ("VLLM_HPU_STEP_CPU_FALLBACK: "
f"{cpu_fallback_local_metric.stats()}, {input_stats}")
logger.warning(msg)
4 changes: 3 additions & 1 deletion vllm/model_executor/layers/spec_decode_base_sampler.py
@@ -30,7 +30,9 @@ def __init__(self, strict_mode: bool = False):
self.num_emitted_tokens: Optional[torch.Tensor] = None
self.num_draft_tokens: int = 0

def init_tensors(self, device: Union[int, str], device_type: str = 'cuda') -> None:
def init_tensors(self,
device: Union[int, str],
device_type: str = 'cuda') -> None:
assert self.num_accepted_tokens is None
if isinstance(device, int):
device = f"{device_type}:{device}"
17 changes: 11 additions & 6 deletions vllm/spec_decode/batch_expansion.py
@@ -231,12 +231,17 @@ def _contract_batch_all_spec(
# of shape [batch_size * k + 1] back to [batch_size, k + 1].
contracted_bs, k = proposals.proposal_token_ids.shape

(target_sampler_output.sampled_token_ids,
target_sampler_output.sampled_token_probs,
target_sampler_output.logprobs,
target_sampler_output.hidden_states,
_, _, _, _,) = self._split_scoring_output(
target_sampler_output, num_scoring_tokens)
(
target_sampler_output.sampled_token_ids,
target_sampler_output.sampled_token_probs,
target_sampler_output.logprobs,
target_sampler_output.hidden_states,
_,
_,
_,
_,
) = self._split_scoring_output(target_sampler_output,
num_scoring_tokens)

# Reshape tensors to original batch size
target_token_ids = target_sampler_output.sampled_token_ids.reshape(
5 changes: 2 additions & 3 deletions vllm/spec_decode/draft_model_runner.py
@@ -24,9 +24,8 @@
from vllm.attention.backends.rocm_flash_attn import (
ROCmFlashAttentionMetadata as FlashAttentionMetadata)
except:
logger.warning(
"Draft model speculative decoding currently only supports"
"CUDA and ROCm flash attention backend.")
logger.warning("Draft model speculative decoding currently only supports"
"CUDA and ROCm flash attention backend.")

# A flag to enable debug prints for the updated input tensors
# before each step.
4 changes: 2 additions & 2 deletions vllm/spec_decode/spec_decode_worker.py
@@ -461,8 +461,8 @@ def _should_disable_all_speculation(
self, execute_model_req: ExecuteModelRequest) -> bool:
# When the batch size is too large, disable speculative decoding
# to stop trading off throughput for latency.
return (execute_model_req.running_queue_size >=
self.disable_by_batch_size)
return (execute_model_req.running_queue_size
>= self.disable_by_batch_size)

def _maybe_disable_speculative_tokens(
self, disable_all_speculation: bool,
12 changes: 7 additions & 5 deletions vllm/spec_decode/util.py
@@ -40,13 +40,15 @@ def get_sampled_token_logprobs(
"""
num_steps, batch_size, vocab_size = logprob_tensor.shape

selected_logprobs = logprob_tensor[torch.arange(num_steps).unsqueeze(1),
torch.arange(batch_size),
sampled_token_ids, ]
selected_logprobs = logprob_tensor[
torch.arange(num_steps).unsqueeze(1),
torch.arange(batch_size),
sampled_token_ids,
]
expanded_selected_logprobs = selected_logprobs.unsqueeze(-1).expand(
-1, -1, vocab_size)
sampled_token_ids_ranks = (logprob_tensor >
expanded_selected_logprobs).sum(-1).add_(1)
sampled_token_ids_ranks = (logprob_tensor
> expanded_selected_logprobs).sum(-1).add_(1)

return sampled_token_ids_ranks, selected_logprobs

3 changes: 2 additions & 1 deletion vllm/worker/selector.py
@@ -1,5 +1,6 @@
from vllm.config import DeviceConfig


def init_worker(*args, **kwargs):
device_config: DeviceConfig = kwargs.get("device_config")
if device_config.device_type == 'neuron':
@@ -22,4 +23,4 @@ def init_worker(*args, **kwargs):
return XPUWorker(*args, **kwargs)
else:
from vllm.worker.worker import Worker
return Worker(*args, **kwargs)
return Worker(*args, **kwargs)
