From 9f4b74108a12f0b6732807be36d754965e76a3df Mon Sep 17 00:00:00 2001
From: Boyuan Feng
Date: Tue, 14 Jan 2025 16:42:40 -0800
Subject: [PATCH] ruff reformat

---
 attn_gym/paged_attention/latency.py    |  69 ++++++++++++----
 attn_gym/paged_attention/model.py      | 110 +++++++++++++++++--------
 attn_gym/paged_attention/throughput.py | 110 ++++++++++++++++++-------
 attn_gym/paged_attention/utils.py      |  22 ++---
 4 files changed, 221 insertions(+), 90 deletions(-)

diff --git a/attn_gym/paged_attention/latency.py b/attn_gym/paged_attention/latency.py
index a93bd4b..b623351 100644
--- a/attn_gym/paged_attention/latency.py
+++ b/attn_gym/paged_attention/latency.py
@@ -17,16 +17,33 @@
 dtype = torch.bfloat16
 
 
-def benchmark_layer(bsz, n_heads, max_seq_len, head_dim, paged_attention, batch_idx, input_pos, block_mask, score_mod, converted_block_mask, converted_score_mod, dtype=torch.bfloat16):
+def benchmark_layer(
+    bsz,
+    n_heads,
+    max_seq_len,
+    head_dim,
+    paged_attention,
+    batch_idx,
+    input_pos,
+    block_mask,
+    score_mod,
+    converted_block_mask,
+    converted_score_mod,
+    dtype=torch.bfloat16,
+):
     from model import NonPagedAttentionLayer, PagedAttentionLayer
 
     # compile model
-    non_paged_foo = torch.compile(NonPagedAttentionLayer(bsz, n_heads, max_seq_len, head_dim, dtype), fullgraph=True)
-    paged_foo = torch.compile(PagedAttentionLayer(n_heads, head_dim, dtype, paged_attention), fullgraph=True)
+    non_paged_foo = torch.compile(
+        NonPagedAttentionLayer(bsz, n_heads, max_seq_len, head_dim, dtype), fullgraph=True
+    )
+    paged_foo = torch.compile(
+        PagedAttentionLayer(n_heads, head_dim, dtype, paged_attention), fullgraph=True
+    )
 
     with torch.no_grad():
         # randomize a token embedding
-        x = torch.randn(bsz, 1, n_heads*head_dim, device="cuda", dtype=dtype)
+        x = torch.randn(bsz, 1, n_heads * head_dim, device="cuda", dtype=dtype)
 
         # warmup
         for _ in range(10):
@@ -34,15 +51,25 @@ def benchmark_layer(bsz, n_heads, max_seq_len, head_dim, paged_attention, batch_
             paged_foo(batch_idx, input_pos, x, converted_block_mask, converted_score_mod)
 
         # benchmark
-        non_paged_latency = benchmarker.benchmark_gpu(lambda: non_paged_foo(batch_idx, input_pos, x, block_mask, score_mod))
-        paged_latency = benchmarker.benchmark_gpu(lambda: paged_foo(batch_idx, input_pos, x, converted_block_mask, converted_score_mod))
-        print(f"non_paged_latency: {non_paged_latency} ms, paged_latency: {paged_latency} ms, overhead: {round((paged_latency/non_paged_latency-1.0)*100, 2)}%")
-
-
-def benchmark(attn_type: str, page_size: int, bsz: int, max_seq_len: int, n_heads: int, head_dim: int):
+        non_paged_latency = benchmarker.benchmark_gpu(
+            lambda: non_paged_foo(batch_idx, input_pos, x, block_mask, score_mod)
+        )
+        paged_latency = benchmarker.benchmark_gpu(
+            lambda: paged_foo(batch_idx, input_pos, x, converted_block_mask, converted_score_mod)
+        )
+        print(
+            f"non_paged_latency: {non_paged_latency} ms, paged_latency: {paged_latency} ms, overhead: {round((paged_latency/non_paged_latency-1.0)*100, 2)}%"
+        )
+
+
+def benchmark(
+    attn_type: str, page_size: int, bsz: int, max_seq_len: int, n_heads: int, head_dim: int
+):
     # For decoding benchmark, we set input_pos to be half of max_seq_len
-    input_pos = torch.tensor([max_seq_len // 2]*bsz, device="cuda", dtype=torch.int32).view(bsz, 1) # [bsz, 1]
-    batch_idx = torch.arange(bsz, device="cuda", dtype=torch.int32) # [bsz]
+    input_pos = torch.tensor([max_seq_len // 2] * bsz, device="cuda", dtype=torch.int32).view(
+        bsz, 1
+    )  # [bsz, 1]
+    batch_idx = torch.arange(bsz, device="cuda", dtype=torch.int32)  # [bsz]
 
     # init paged attention
     n_pages = (max_seq_len + page_size - 1) // page_size * bsz
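As a quick check of the page-count arithmetic above: the expression rounds each sequence's maximum length up to whole pages and then reserves that many pages for every sequence in the batch. The sketch below just redoes that computation with made-up sizes; they are not the benchmark's defaults.

# Worked example of the n_pages formula above; the sizes are illustrative only.
max_seq_len, page_size, bsz = 8192, 128, 8
pages_per_sequence = (max_seq_len + page_size - 1) // page_size  # ceil(8192 / 128) = 64
n_pages = pages_per_sequence * bsz                               # 64 * 8 = 512 pages in total
assert n_pages == 512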
@@ -50,7 +77,9 @@ def benchmark(attn_type: str, page_size: int, bsz: int, max_seq_len: int, n_head
 
     # Block mask
     if attn_type == "causal":
-        mask_mod = gen_offset(torch.tensor([max_seq_len // 2]*bsz, device="cuda", dtype=torch.int32))
+        mask_mod = gen_offset(
+            torch.tensor([max_seq_len // 2] * bsz, device="cuda", dtype=torch.int32)
+        )
     else:
         mask_mod = noop_mask
     block_mask = create_block_mask(mask_mod, bsz, 1, 1, max_seq_len, BLOCK_SIZE=page_size)
@@ -60,7 +89,19 @@ def benchmark(attn_type: str, page_size: int, bsz: int, max_seq_len: int, n_head
     score_mod = generate_score_mod(attn_type)
     converted_score_mod = paged_attention.get_score_mod(score_mod)
 
-    benchmark_layer(bsz, n_heads, max_seq_len, head_dim, paged_attention, batch_idx, input_pos, block_mask, score_mod, converted_block_mask, converted_score_mod)
+    benchmark_layer(
+        bsz,
+        n_heads,
+        max_seq_len,
+        head_dim,
+        paged_attention,
+        batch_idx,
+        input_pos,
+        block_mask,
+        score_mod,
+        converted_block_mask,
+        converted_score_mod,
+    )
 
 
 if __name__ == "__main__":
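For completeness, a minimal sketch of how the latency benchmark's entry point might be invoked; the argument values are assumptions for illustration (the actual __main__ body is not part of this hunk), and the flat import assumes the script is run from inside attn_gym/paged_attention/, mirroring the file's own flat imports.

# Hypothetical driver for latency.py's benchmark(); the values below are illustrative.
from latency import benchmark  # assumed import path

benchmark(attn_type="causal", page_size=128, bsz=8, max_seq_len=4096, n_heads=16, head_dim=64)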
diff --git a/attn_gym/paged_attention/model.py b/attn_gym/paged_attention/model.py
index cc9300b..0730f7d 100644
--- a/attn_gym/paged_attention/model.py
+++ b/attn_gym/paged_attention/model.py
@@ -7,17 +7,22 @@
 
 class NonPagedAttentionLayer(torch.nn.Module):
     """An attention layer without paged attention, ported from GPT-Fast:
-    https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L180-L227
+    https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L180-L227
     """
-    def __init__(self, bsz, n_heads, max_seq_len, head_dim, dtype, block_size: int=32768):
+
+    def __init__(self, bsz, n_heads, max_seq_len, head_dim, dtype, block_size: int = 32768):
         super().__init__()
         self.n_head = n_heads
         self.head_dim = head_dim
 
         # key, query, value projections for all heads, but in a batch
         total_head_dim = n_heads * head_dim
-        self.wqkv = torch.nn.Linear(total_head_dim, 3*total_head_dim, bias=False, device="cuda", dtype=dtype)
-        self.wo = torch.nn.Linear(total_head_dim, total_head_dim, bias=False, device="cuda", dtype=dtype)
+        self.wqkv = torch.nn.Linear(
+            total_head_dim, 3 * total_head_dim, bias=False, device="cuda", dtype=dtype
+        )
+        self.wo = torch.nn.Linear(
+            total_head_dim, total_head_dim, bias=False, device="cuda", dtype=dtype
+        )
         self.k_cache = torch.randn(
             (bsz, n_heads, max_seq_len, head_dim), device="cuda", dtype=dtype
         )
@@ -26,7 +31,14 @@ def __init__(self, bsz, n_heads, max_seq_len, head_dim, dtype, block_size: int=3
         )
         self.freqs_cis = precompute_freqs_cis(block_size, self.head_dim, dtype=dtype)
 
-    def forward(self, batch_idx: Tensor, input_pos: Tensor, x: Tensor, block_mask: BlockMask, score_mod: _score_mod_signature) -> Tensor:
+    def forward(
+        self,
+        batch_idx: Tensor,
+        input_pos: Tensor,
+        x: Tensor,
+        block_mask: BlockMask,
+        score_mod: _score_mod_signature,
+    ) -> Tensor:
         # input_pos: [B, S], batch_idx: [B], x: [B, S, D]
         B, S, _ = x.shape
 
@@ -37,7 +49,9 @@ def forward(self, batch_idx: Tensor, input_pos: Tensor, x: Tensor, block_mask: B
         k = k.view(B, S, self.n_head, self.head_dim)
         v = v.view(B, S, self.n_head, self.head_dim)
 
-        freqs_cis = self.freqs_cis.unsqueeze(0)[torch.zeros((B, 1), dtype=torch.int), input_pos] # [B, S, D//2, 2]
+        freqs_cis = self.freqs_cis.unsqueeze(0)[
+            torch.zeros((B, 1), dtype=torch.int), input_pos
+        ]  # [B, S, D//2, 2]
 
         q = apply_rotary_emb(q, freqs_cis)
         k = apply_rotary_emb(k, freqs_cis)
@@ -46,7 +60,9 @@ def forward(self, batch_idx: Tensor, input_pos: Tensor, x: Tensor, block_mask: B
         self.k_cache[batch_idx.view(B, 1), :, input_pos] = k
         self.v_cache[batch_idx.view(B, 1), :, input_pos] = v
 
-        y = flex_attention(q, self.k_cache, self.v_cache, block_mask=block_mask, score_mod=score_mod)
+        y = flex_attention(
+            q, self.k_cache, self.v_cache, block_mask=block_mask, score_mod=score_mod
+        )
 
         y = y.transpose(1, 2).contiguous().view(B, S, -1)
 
@@ -56,39 +72,53 @@ def forward(self, batch_idx: Tensor, input_pos: Tensor, x: Tensor, block_mask: B
 
 
 class PagedAttentionLayer(torch.nn.Module):
     """An attention layer with paged attention"""
-    def __init__(self, n_heads, head_dim, dtype, paged_attention, block_size: int=65536):
+
+    def __init__(self, n_heads, head_dim, dtype, paged_attention, block_size: int = 65536):
         super().__init__()
         self.n_head = n_heads
         self.head_dim = head_dim
 
         # key, query, value projections for all heads, but in a batch
         total_head_dim = n_heads * head_dim
-        self.wqkv = torch.nn.Linear(total_head_dim, 3*total_head_dim, bias=False, device="cuda", dtype=dtype)
-        self.wo = torch.nn.Linear(total_head_dim, total_head_dim, bias=False, device="cuda", dtype=dtype)
+        self.wqkv = torch.nn.Linear(
+            total_head_dim, 3 * total_head_dim, bias=False, device="cuda", dtype=dtype
+        )
+        self.wo = torch.nn.Linear(
+            total_head_dim, total_head_dim, bias=False, device="cuda", dtype=dtype
+        )
 
         # allocate kv cache with batch size=1 for paged attention
         max_cached_seq_len = paged_attention.n_pages * paged_attention.page_size
         self.k_cache_paged = torch.randn(
-                1,
-                n_heads,
-                max_cached_seq_len,
-                head_dim,
-                device="cuda",
-                dtype=dtype,
-            )
+            1,
+            n_heads,
+            max_cached_seq_len,
+            head_dim,
+            device="cuda",
+            dtype=dtype,
+        )
         self.v_cache_paged = torch.randn(
-                1,
-                n_heads,
-                max_cached_seq_len,
-                head_dim,
-                device="cuda",
-                dtype=dtype,
-            )
+            1,
+            n_heads,
+            max_cached_seq_len,
+            head_dim,
+            device="cuda",
+            dtype=dtype,
+        )
         self.paged_attention = paged_attention
-        self.freqs_cis = precompute_freqs_cis(block_size, self.head_dim, dtype=dtype) # [block_size, D//2, 2]
-
-    def forward(self, batch_idx: Tensor, input_pos: Tensor, x: Tensor, converted_block_mask: BlockMask, converted_score_mod: _score_mod_signature) -> Tensor:
+        self.freqs_cis = precompute_freqs_cis(
+            block_size, self.head_dim, dtype=dtype
+        )  # [block_size, D//2, 2]
+
+    def forward(
+        self,
+        batch_idx: Tensor,
+        input_pos: Tensor,
+        x: Tensor,
+        converted_block_mask: BlockMask,
+        converted_score_mod: _score_mod_signature,
+    ) -> Tensor:
         # input_pos: [B, S], batch_idx: [B], x: [B, S, D]
         B, S, _ = x.shape
         kv_size = self.n_head * self.head_dim
@@ -98,7 +128,9 @@ def forward(self, batch_idx: Tensor, input_pos: Tensor, x: Tensor, converted_blo
         k = k.view(B, S, self.n_head, self.head_dim)
         v = v.view(B, S, self.n_head, self.head_dim)
 
-        freqs_cis = self.freqs_cis.unsqueeze(0)[torch.zeros((B, 1), dtype=torch.int), input_pos] # [B, S, D//2, 2]
+        freqs_cis = self.freqs_cis.unsqueeze(0)[
+            torch.zeros((B, 1), dtype=torch.int), input_pos
+        ]  # [B, S, D//2, 2]
 
         q = apply_rotary_emb(q, freqs_cis)
         k = apply_rotary_emb(k, freqs_cis)
@@ -110,7 +142,13 @@ def forward(self, batch_idx: Tensor, input_pos: Tensor, x: Tensor, converted_blo
             batch_idx, input_pos, k, v, self.k_cache_paged, self.v_cache_paged
         )
 
-        y = flex_attention(q, self.k_cache_paged, self.v_cache_paged, block_mask=converted_block_mask, score_mod=converted_score_mod)
+        y = flex_attention(
+            q,
+            self.k_cache_paged,
+            self.v_cache_paged,
+            block_mask=converted_block_mask,
+            score_mod=converted_score_mod,
+        )
 
         y = y.transpose(1, 2).contiguous().view(B, S, -1)
 
@@ -120,8 +158,10 @@ def forward(self, batch_idx: Tensor, input_pos: Tensor, x: Tensor, converted_blo
 
 def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
     # x: [B, S, H, D], freqs_cis: [B, S, D//2, 2]
-    xshaped = x.float().reshape(*x.shape[:-1], -1, 2) # [B, S, H, D//2, 2]
-    freqs_cis = freqs_cis.view(xshaped.size(0), xshaped.size(1), 1, xshaped.size(3), 2) # [B, S, 1, D//2, 2]
+    xshaped = x.float().reshape(*x.shape[:-1], -1, 2)  # [B, S, H, D//2, 2]
+    freqs_cis = freqs_cis.view(
+        xshaped.size(0), xshaped.size(1), 1, xshaped.size(3), 2
+    )  # [B, S, 1, D//2, 2]
     x_out2 = torch.stack(
         [
             xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
@@ -151,13 +191,17 @@ def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: Dict):
             new_freqs.append(freq / factor)
         else:
             assert low_freq_wavelen != high_freq_wavelen
-            smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
+            smooth = (old_context_len / wavelen - low_freq_factor) / (
+                high_freq_factor - low_freq_factor
+            )
             new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
     return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)
 
 
 def precompute_freqs_cis(
-    seq_len: int, n_elem: int, base: int = 10000,
+    seq_len: int,
+    n_elem: int,
+    base: int = 10000,
     dtype: torch.dtype = torch.bfloat16,
     rope_scaling: Optional[dict] = None,
 ) -> Tensor:
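The rotary helpers in model.py are easiest to read as a shape transformation: activations are [B, S, H, D], the per-position frequencies are [B, S, D//2, 2], and apply_rotary_emb returns the activations with the same shape. The sketch below walks through those shapes with made-up sizes; the flat import mirrors how latency.py imports from model.py and assumes the working directory is attn_gym/paged_attention/.

# Shape walk-through for precompute_freqs_cis / apply_rotary_emb (sizes are made up).
import torch
from model import apply_rotary_emb, precompute_freqs_cis  # assumed to run next to model.py

B, S, H, D = 2, 1, 4, 64
freqs_cis = precompute_freqs_cis(1024, D)  # [1024, D//2, 2]
input_pos = torch.tensor([[5], [17]])  # [B, S] decode positions
per_token = freqs_cis.unsqueeze(0)[
    torch.zeros((B, 1), dtype=torch.int), input_pos
]  # [B, S, D//2, 2], the same gather used in the forward passes above
q = torch.randn(B, S, H, D, dtype=freqs_cis.dtype)
q_rot = apply_rotary_emb(q, per_token)
assert q_rot.shape == (B, S, H, D)  # rotation changes values, not shape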
diff --git a/attn_gym/paged_attention/throughput.py b/attn_gym/paged_attention/throughput.py
index da31694..14d5ec2 100644
--- a/attn_gym/paged_attention/throughput.py
+++ b/attn_gym/paged_attention/throughput.py
@@ -17,7 +17,7 @@
 
 
 [Paged Attention] kv cache requires 2 * (2 * h * n_pages * page_size * d)
-bytes. Assuming a page size of 128, there could be at most 32768 pages. 
+bytes. Assuming a page size of 128, there could be at most 32768 pages.
 We empirically observe that the max batch size to serve is 2448, which is
 76x of the max batch size without paged attention.
 """
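To make the docstring's byte count concrete, the back-of-the-envelope below plugs in the sizes used later in this file's __main__ block (n_heads=4, head_dim=64, page_size=128, n_pages=32768) and reads the two factors of 2 as K-plus-V and 2-byte bfloat16 elements; that reading of the constants is an assumption.

# Back-of-the-envelope check of the paged kv cache formula above.
h, d, page_size, n_pages = 4, 64, 128, 32768
kv_tensors = 2       # one K cache and one V cache
bytes_per_elem = 2   # bfloat16 (assumed meaning of the formula's inner factor of 2)
paged_kv_bytes = kv_tensors * (bytes_per_elem * h * n_pages * page_size * d)
print(paged_kv_bytes)       # 4294967296 bytes, i.e. about 4 GiB shared by all requests
print(n_pages * page_size)  # 4194304 cacheable tokens across the whole batch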
""" @@ -31,23 +31,24 @@ ) from datasets import load_dataset import random -from collections import deque +from collections import deque from typing import Tuple from utils import gen_offset, slice_block_mask from model import PagedAttentionLayer create_block_mask = torch.compile(create_block_mask) + class Requests: def __init__(self): - self.data = load_dataset("Open-Orca/OpenOrca")['train'] + self.data = load_dataset("Open-Orca/OpenOrca")["train"] def sample_request(self): # sample a prompt len and response len from openorca dataset # to simulate a real world use case - idx = random.randint(0, len(self.data)-1) - prompt_len = len(self.data[idx]['system_prompt']) + len(self.data[idx]['question']) - response_len = len(self.data[idx]['response']) + idx = random.randint(0, len(self.data) - 1) + prompt_len = len(self.data[idx]["system_prompt"]) + len(self.data[idx]["question"]) + response_len = len(self.data[idx]["response"]) return prompt_len, response_len @@ -55,53 +56,82 @@ class Server: def __init__(self, batch_size: int, n_pages: int, page_size: int, n_heads: int, head_dim: int): self.paged_attention = PagedAttention(n_pages, page_size, batch_size) - self.model = torch.compile(PagedAttentionLayer(n_heads, head_dim, torch.bfloat16, self.paged_attention)) + self.model = torch.compile( + PagedAttentionLayer(n_heads, head_dim, torch.bfloat16, self.paged_attention) + ) self.batch_size = batch_size self.n_heads = n_heads self.head_dim = head_dim - self.bsz_watermark = 0 # max batch size served during benchmark + self.bsz_watermark = 0 # max batch size served during benchmark self.available_batch_idx = list(range(batch_size))[::-1] self.request_queue = deque([]) self.batch_idx = [] self.input_pos = torch.zeros(batch_size, device="cuda", dtype=torch.int64) - self.request_length = torch.tensor([float('inf')] * batch_size, device="cuda") # decide whether a request is completed + self.request_length = torch.tensor( + [float("inf")] * batch_size, device="cuda" + ) # decide whether a request is completed - self.token_embedding = torch.randn((batch_size, 1, n_heads*head_dim), device="cuda", dtype=torch.bfloat16) # [B, 1, n_heads*head_dim] + self.token_embedding = torch.randn( + (batch_size, 1, n_heads * head_dim), device="cuda", dtype=torch.bfloat16 + ) # [B, 1, n_heads*head_dim] - self.block_mask = create_block_mask(lambda b,h,q,kv: q >= kv, batch_size, 1, 64*1024, 64*1024, BLOCK_SIZE=page_size) + self.block_mask = create_block_mask( + lambda b, h, q, kv: q >= kv, batch_size, 1, 64 * 1024, 64 * 1024, BLOCK_SIZE=page_size + ) def receive_request(self, prompt_len: int, response_len: int): # assume we know prompt length and response length in advance. self.request_queue.append((prompt_len, response_len)) def can_schedule(self, request: Tuple[int, int]) -> bool: - return len(self.paged_attention.empty_pages) * self.paged_attention.page_size >= sum(request) + return len(self.paged_attention.empty_pages) * self.paged_attention.page_size >= sum( + request + ) def prefill_one_token(self, batch_idx: int, prompt_len: int, response_len: int): # allocate page table # in practice we don't know response length in advance. A good way is to use a heuristic to estimate response length - # and allocate page table accordingly. We may also allocate pages on the fly. For simplicity, we assume we know + # and allocate page table accordingly. We may also allocate pages on the fly. For simplicity, we assume we know # response length in advance. 
- self.paged_attention.reserve(torch.tensor(batch_idx, device="cuda"), torch.tensor(prompt_len+response_len, device="cuda")) + self.paged_attention.reserve( + torch.tensor(batch_idx, device="cuda"), + torch.tensor(prompt_len + response_len, device="cuda"), + ) # simulate input token embedding - token_embedding = torch.randn(1, prompt_len, self.head_dim * self.n_heads, device="cuda", dtype=torch.bfloat16) + token_embedding = torch.randn( + 1, prompt_len, self.head_dim * self.n_heads, device="cuda", dtype=torch.bfloat16 + ) # generate block mask. The same block mask is used for all layers. new_block_mask = slice_block_mask(self.block_mask, batch_idx, prompt_len, prompt_len) - converted_block_mask = self.paged_attention.convert_logical_block_mask(new_block_mask, torch.tensor([batch_idx], device="cuda")) + converted_block_mask = self.paged_attention.convert_logical_block_mask( + new_block_mask, torch.tensor([batch_idx], device="cuda") + ) converted_score_mod = self.paged_attention.get_score_mod(_identity) prefill_input_pos = torch.arange(prompt_len, device="cuda").view(1, -1) - token_embedding = self.model(torch.tensor([batch_idx], device="cuda"), prefill_input_pos, token_embedding, converted_block_mask, converted_score_mod) + token_embedding = self.model( + torch.tensor([batch_idx], device="cuda"), + prefill_input_pos, + token_embedding, + converted_block_mask, + converted_score_mod, + ) return token_embedding def prefill(self): - while self.request_queue and self.can_schedule(self.request_queue[0]) and self.available_batch_idx: + while ( + self.request_queue + and self.can_schedule(self.request_queue[0]) + and self.available_batch_idx + ): prompt_len, response_len = self.request_queue.popleft() - print(f"serving a new request with prompt_len: {prompt_len}, response_len: {response_len}") + print( + f"serving a new request with prompt_len: {prompt_len}, response_len: {response_len}" + ) new_batch_idx = self.available_batch_idx.pop() token_embedding = self.prefill_one_token(new_batch_idx, prompt_len, response_len) self.token_embedding[new_batch_idx] = token_embedding[:, -1].view(1, -1) @@ -114,42 +144,62 @@ def prefill(self): def get_decode_mask(self, batch_idx: torch.Tensor, input_pos: torch.Tensor): # batch_idx: [B], input_pos: [B] - B, = batch_idx.shape - input_block_idx = (input_pos // self.block_mask.BLOCK_SIZE[0]) # [B] + (B,) = batch_idx.shape + input_block_idx = input_pos // self.block_mask.BLOCK_SIZE[0] # [B] kv_num_blocks = self.block_mask.kv_num_blocks[batch_idx, :, input_block_idx].view(B, 1, 1) kv_indices = self.block_mask.kv_indices[batch_idx, :, input_block_idx].view(B, 1, 1, -1) full_kv_num_blocks, full_kv_indices = None, None if self.block_mask.full_kv_num_blocks is not None: - full_kv_num_blocks = self.block_mask.full_kv_num_blocks[batch_idx, :, input_block_idx].view(B, 1, 1) - full_kv_indices = self.block_mask.full_kv_indices[batch_idx, :, input_block_idx].view(B, 1, 1, -1) + full_kv_num_blocks = self.block_mask.full_kv_num_blocks[ + batch_idx, :, input_block_idx + ].view(B, 1, 1) + full_kv_indices = self.block_mask.full_kv_indices[batch_idx, :, input_block_idx].view( + B, 1, 1, -1 + ) seq_length = (1, self.block_mask.seq_lengths[1]) - return BlockMask.from_kv_blocks(kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices, BLOCK_SIZE=self.block_mask.BLOCK_SIZE, mask_mod=gen_offset(input_pos), seq_lengths=seq_length) + return BlockMask.from_kv_blocks( + kv_num_blocks, + kv_indices, + full_kv_num_blocks, + full_kv_indices, + BLOCK_SIZE=self.block_mask.BLOCK_SIZE, + 
@@ -114,42 +144,62 @@ def prefill(self):
 
     def get_decode_mask(self, batch_idx: torch.Tensor, input_pos: torch.Tensor):
         # batch_idx: [B], input_pos: [B]
-        B, = batch_idx.shape
-        input_block_idx = (input_pos // self.block_mask.BLOCK_SIZE[0]) # [B]
+        (B,) = batch_idx.shape
+        input_block_idx = input_pos // self.block_mask.BLOCK_SIZE[0]  # [B]
         kv_num_blocks = self.block_mask.kv_num_blocks[batch_idx, :, input_block_idx].view(B, 1, 1)
         kv_indices = self.block_mask.kv_indices[batch_idx, :, input_block_idx].view(B, 1, 1, -1)
         full_kv_num_blocks, full_kv_indices = None, None
         if self.block_mask.full_kv_num_blocks is not None:
-            full_kv_num_blocks = self.block_mask.full_kv_num_blocks[batch_idx, :, input_block_idx].view(B, 1, 1)
-            full_kv_indices = self.block_mask.full_kv_indices[batch_idx, :, input_block_idx].view(B, 1, 1, -1)
+            full_kv_num_blocks = self.block_mask.full_kv_num_blocks[
+                batch_idx, :, input_block_idx
+            ].view(B, 1, 1)
+            full_kv_indices = self.block_mask.full_kv_indices[batch_idx, :, input_block_idx].view(
+                B, 1, 1, -1
+            )
         seq_length = (1, self.block_mask.seq_lengths[1])
-        return BlockMask.from_kv_blocks(kv_num_blocks, kv_indices, full_kv_num_blocks, full_kv_indices, BLOCK_SIZE=self.block_mask.BLOCK_SIZE, mask_mod=gen_offset(input_pos), seq_lengths=seq_length)
+        return BlockMask.from_kv_blocks(
+            kv_num_blocks,
+            kv_indices,
+            full_kv_num_blocks,
+            full_kv_indices,
+            BLOCK_SIZE=self.block_mask.BLOCK_SIZE,
+            mask_mod=gen_offset(input_pos),
+            seq_lengths=seq_length,
+        )
 
     def decode(self):
         B = len(self.batch_idx)
-        batch_idx = torch.tensor(self.batch_idx, device="cuda").view(-1) # [B].
+        batch_idx = torch.tensor(self.batch_idx, device="cuda").view(-1)  # [B].
         input_pos = self.input_pos[batch_idx]  # [B]
         mask = self.get_decode_mask(batch_idx, input_pos)
         converted_block_mask = self.paged_attention.convert_logical_block_mask(mask, batch_idx)
         converted_score_mod = self.paged_attention.get_score_mod(_identity)
-        self.token_embedding[batch_idx] = self.model(batch_idx, input_pos.view(B, 1), self.token_embedding[batch_idx], converted_block_mask, converted_score_mod)
+        self.token_embedding[batch_idx] = self.model(
+            batch_idx,
+            input_pos.view(B, 1),
+            self.token_embedding[batch_idx],
+            converted_block_mask,
+            converted_score_mod,
+        )
         self.input_pos[batch_idx] += 1
 
     def clean(self):
         completed_batch_indices = torch.where(self.input_pos >= self.request_length)[0]
         self.available_batch_idx += completed_batch_indices.tolist()
-        self.batch_idx = [idx for idx in self.batch_idx if idx not in completed_batch_indices.tolist()]
+        self.batch_idx = [
+            idx for idx in self.batch_idx if idx not in completed_batch_indices.tolist()
+        ]
         for b in completed_batch_indices:
             self.paged_attention.erase(torch.tensor([b]))
-        self.request_length[completed_batch_indices] = float('inf')
+        self.request_length[completed_batch_indices] = float("inf")
 
 
 if __name__ == "__main__":
     # serving loop
-    num_requests = 10 # total number of requests during benchmark
-    gap = 3 # get a new request after `gap` number of decoding tokens
+    num_requests = 10  # total number of requests during benchmark
+    gap = 3  # get a new request after `gap` number of decoding tokens
     batch_size, n_pages, page_size, n_heads, head_dim = 4096, 32768, 128, 4, 64
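The hunk above stops right after the benchmark constants, so the body of the serving loop is not shown in this patch. The sketch below is only a guess at how the pieces of Server are meant to be driven (enqueue a request every `gap` decode steps, then prefill, decode, and clean); it is an illustration, not the actual __main__ body from throughput.py.

# Hypothetical serving loop -- an assumed driver, not code from this patch.
requests = Requests()
server = Server(batch_size, n_pages, page_size, n_heads, head_dim)
served = 0
for step in range(1000):  # fixed horizon keeps the sketch simple
    if served < num_requests and step % gap == 0:
        server.receive_request(*requests.sample_request())
        served += 1
    server.prefill()              # admit queued requests while free pages remain
    if server.batch_idx:
        server.decode()           # one decoding step for every in-flight request
        server.clean()            # recycle batch slots and pages of finished requests
print(f"max batch size served (bsz_watermark): {server.bsz_watermark}")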
""" + def relative_bias(score, b, h, m, n): return score + (m - n) @@ -115,7 +109,9 @@ def _adjust_num_blocks_and_indices( return num_blocks.clone(), indices.clone() -def slice_block_mask(block_mask: BlockMask, batch_idx: int, new_q_len: int, new_kv_len: int) -> BlockMask: +def slice_block_mask( + block_mask: BlockMask, batch_idx: int, new_q_len: int, new_kv_len: int +) -> BlockMask: """Slice the block mask based on the new query and key/value lengths. Args: