Commit: ruff reformat
BoyuanFeng committed Jan 15, 2025
1 parent 9c6a480 commit 9f4b741
Showing 4 changed files with 221 additions and 90 deletions.
69 changes: 55 additions & 14 deletions attn_gym/paged_attention/latency.py
@@ -17,40 +17,69 @@
dtype = torch.bfloat16


def benchmark_layer(bsz, n_heads, max_seq_len, head_dim, paged_attention, batch_idx, input_pos, block_mask, score_mod, converted_block_mask, converted_score_mod, dtype=torch.bfloat16):
def benchmark_layer(
bsz,
n_heads,
max_seq_len,
head_dim,
paged_attention,
batch_idx,
input_pos,
block_mask,
score_mod,
converted_block_mask,
converted_score_mod,
dtype=torch.bfloat16,
):
from model import NonPagedAttentionLayer, PagedAttentionLayer

# compile model
non_paged_foo = torch.compile(NonPagedAttentionLayer(bsz, n_heads, max_seq_len, head_dim, dtype), fullgraph=True)
paged_foo = torch.compile(PagedAttentionLayer(n_heads, head_dim, dtype, paged_attention), fullgraph=True)
non_paged_foo = torch.compile(
NonPagedAttentionLayer(bsz, n_heads, max_seq_len, head_dim, dtype), fullgraph=True
)
paged_foo = torch.compile(
PagedAttentionLayer(n_heads, head_dim, dtype, paged_attention), fullgraph=True
)

with torch.no_grad():
# randomize a token embedding
x = torch.randn(bsz, 1, n_heads*head_dim, device="cuda", dtype=dtype)
x = torch.randn(bsz, 1, n_heads * head_dim, device="cuda", dtype=dtype)

# warmup
for _ in range(10):
non_paged_foo(batch_idx, input_pos, x, block_mask, score_mod)
paged_foo(batch_idx, input_pos, x, converted_block_mask, converted_score_mod)

# benchmark
non_paged_latency = benchmarker.benchmark_gpu(lambda: non_paged_foo(batch_idx, input_pos, x, block_mask, score_mod))
paged_latency = benchmarker.benchmark_gpu(lambda: paged_foo(batch_idx, input_pos, x, converted_block_mask, converted_score_mod))
print(f"non_paged_latency: {non_paged_latency} ms, paged_latency: {paged_latency} ms, overhead: {round((paged_latency/non_paged_latency-1.0)*100, 2)}%")


def benchmark(attn_type: str, page_size: int, bsz: int, max_seq_len: int, n_heads: int, head_dim: int):
non_paged_latency = benchmarker.benchmark_gpu(
lambda: non_paged_foo(batch_idx, input_pos, x, block_mask, score_mod)
)
paged_latency = benchmarker.benchmark_gpu(
lambda: paged_foo(batch_idx, input_pos, x, converted_block_mask, converted_score_mod)
)
print(
f"non_paged_latency: {non_paged_latency} ms, paged_latency: {paged_latency} ms, overhead: {round((paged_latency/non_paged_latency-1.0)*100, 2)}%"
)


def benchmark(
attn_type: str, page_size: int, bsz: int, max_seq_len: int, n_heads: int, head_dim: int
):
# For decoding benchmark, we set input_pos to be half of max_seq_len
input_pos = torch.tensor([max_seq_len // 2]*bsz, device="cuda", dtype=torch.int32).view(bsz, 1) # [bsz, 1]
batch_idx = torch.arange(bsz, device="cuda", dtype=torch.int32) # [bsz]
input_pos = torch.tensor([max_seq_len // 2] * bsz, device="cuda", dtype=torch.int32).view(
bsz, 1
) # [bsz, 1]
batch_idx = torch.arange(bsz, device="cuda", dtype=torch.int32) # [bsz]

# init paged attention
n_pages = (max_seq_len + page_size - 1) // page_size * bsz
paged_attention = random_init_paged_attention(n_pages, page_size, bsz, max_seq_len)

# Block mask
if attn_type == "causal":
mask_mod = gen_offset(torch.tensor([max_seq_len // 2]*bsz, device="cuda", dtype=torch.int32))
mask_mod = gen_offset(
torch.tensor([max_seq_len // 2] * bsz, device="cuda", dtype=torch.int32)
)
else:
mask_mod = noop_mask
block_mask = create_block_mask(mask_mod, bsz, 1, 1, max_seq_len, BLOCK_SIZE=page_size)
@@ -60,7 +89,19 @@ def benchmark(attn_type: str, page_size: int, bsz: int, max_seq_len: int, n_head
score_mod = generate_score_mod(attn_type)
converted_score_mod = paged_attention.get_score_mod(score_mod)

benchmark_layer(bsz, n_heads, max_seq_len, head_dim, paged_attention, batch_idx, input_pos, block_mask, score_mod, converted_block_mask, converted_score_mod)
benchmark_layer(
bsz,
n_heads,
max_seq_len,
head_dim,
paged_attention,
batch_idx,
input_pos,
block_mask,
score_mod,
converted_block_mask,
converted_score_mod,
)


if __name__ == "__main__":
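The body of the __main__ entry point in latency.py is collapsed in this diff view. As a rough illustration only, the sketch below shows how the reformatted benchmark() above could be driven on a CUDA machine; the attention-type list (other than "causal", which appears above) and all sizes are assumptions for illustration, not values taken from the hidden block.

# Hypothetical driver for benchmark(); parameter values are illustrative
# assumptions, not the ones in the collapsed __main__ block.
if __name__ == "__main__":
    for attn_type in ["causal", "noop"]:  # "causal" is used above; "noop" is an assumed label
        benchmark(
            attn_type=attn_type,
            page_size=128,      # assumed page size
            bsz=32,             # assumed batch size
            max_seq_len=4096,   # decoding is benchmarked at input_pos = max_seq_len // 2
            n_heads=16,
            head_dim=64,
        )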
110 changes: 77 additions & 33 deletions attn_gym/paged_attention/model.py
@@ -7,17 +7,22 @@

class NonPagedAttentionLayer(torch.nn.Module):
"""An attention layer without paged attention, ported from GPT-Fast:
https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L180-L227
https://github.com/pytorch-labs/gpt-fast/blob/main/model.py#L180-L227
"""
def __init__(self, bsz, n_heads, max_seq_len, head_dim, dtype, block_size: int=32768):

def __init__(self, bsz, n_heads, max_seq_len, head_dim, dtype, block_size: int = 32768):
super().__init__()
self.n_head = n_heads
self.head_dim = head_dim

# key, query, value projections for all heads, but in a batch
total_head_dim = n_heads * head_dim
self.wqkv = torch.nn.Linear(total_head_dim, 3*total_head_dim, bias=False, device="cuda", dtype=dtype)
self.wo = torch.nn.Linear(total_head_dim, total_head_dim, bias=False, device="cuda", dtype=dtype)
self.wqkv = torch.nn.Linear(
total_head_dim, 3 * total_head_dim, bias=False, device="cuda", dtype=dtype
)
self.wo = torch.nn.Linear(
total_head_dim, total_head_dim, bias=False, device="cuda", dtype=dtype
)
self.k_cache = torch.randn(
(bsz, n_heads, max_seq_len, head_dim), device="cuda", dtype=dtype
)
@@ -26,7 +31,14 @@ def __init__(self, bsz, n_heads, max_seq_len, head_dim, dtype, block_size: int=3
)
self.freqs_cis = precompute_freqs_cis(block_size, self.head_dim, dtype=dtype)

def forward(self, batch_idx: Tensor, input_pos: Tensor, x: Tensor, block_mask: BlockMask, score_mod: _score_mod_signature) -> Tensor:
def forward(
self,
batch_idx: Tensor,
input_pos: Tensor,
x: Tensor,
block_mask: BlockMask,
score_mod: _score_mod_signature,
) -> Tensor:
# input_pos: [B, S], batch_idx: [B], x: [B, S, D]
B, S, _ = x.shape

@@ -37,7 +49,9 @@ def forward(self, batch_idx: Tensor, input_pos: Tensor, x: Tensor, block_mask: B
k = k.view(B, S, self.n_head, self.head_dim)
v = v.view(B, S, self.n_head, self.head_dim)

freqs_cis = self.freqs_cis.unsqueeze(0)[torch.zeros((B, 1), dtype=torch.int), input_pos] # [B, S, D//2, 2]
freqs_cis = self.freqs_cis.unsqueeze(0)[
torch.zeros((B, 1), dtype=torch.int), input_pos
] # [B, S, D//2, 2]

q = apply_rotary_emb(q, freqs_cis)
k = apply_rotary_emb(k, freqs_cis)
@@ -46,7 +60,9 @@ def forward(self, batch_idx: Tensor, input_pos: Tensor, x: Tensor, block_mask: B
self.k_cache[batch_idx.view(B, 1), :, input_pos] = k
self.v_cache[batch_idx.view(B, 1), :, input_pos] = v

y = flex_attention(q, self.k_cache, self.v_cache, block_mask=block_mask, score_mod=score_mod)
y = flex_attention(
q, self.k_cache, self.v_cache, block_mask=block_mask, score_mod=score_mod
)

y = y.transpose(1, 2).contiguous().view(B, S, -1)

Expand All @@ -56,39 +72,53 @@ def forward(self, batch_idx: Tensor, input_pos: Tensor, x: Tensor, block_mask: B

class PagedAttentionLayer(torch.nn.Module):
"""An attention layer with paged attention"""
def __init__(self, n_heads, head_dim, dtype, paged_attention, block_size: int=65536):

def __init__(self, n_heads, head_dim, dtype, paged_attention, block_size: int = 65536):
super().__init__()
self.n_head = n_heads
self.head_dim = head_dim

# key, query, value projections for all heads, but in a batch
total_head_dim = n_heads * head_dim
self.wqkv = torch.nn.Linear(total_head_dim, 3*total_head_dim, bias=False, device="cuda", dtype=dtype)
self.wo = torch.nn.Linear(total_head_dim, total_head_dim, bias=False, device="cuda", dtype=dtype)
self.wqkv = torch.nn.Linear(
total_head_dim, 3 * total_head_dim, bias=False, device="cuda", dtype=dtype
)
self.wo = torch.nn.Linear(
total_head_dim, total_head_dim, bias=False, device="cuda", dtype=dtype
)

# allocate kv cache with batch size=1 for paged attention
max_cached_seq_len = paged_attention.n_pages * paged_attention.page_size
self.k_cache_paged = torch.randn(
1,
n_heads,
max_cached_seq_len,
head_dim,
device="cuda",
dtype=dtype,
)
1,
n_heads,
max_cached_seq_len,
head_dim,
device="cuda",
dtype=dtype,
)
self.v_cache_paged = torch.randn(
1,
n_heads,
max_cached_seq_len,
head_dim,
device="cuda",
dtype=dtype,
)
1,
n_heads,
max_cached_seq_len,
head_dim,
device="cuda",
dtype=dtype,
)
self.paged_attention = paged_attention

self.freqs_cis = precompute_freqs_cis(block_size, self.head_dim, dtype=dtype) # [block_size, D//2, 2]

def forward(self, batch_idx: Tensor, input_pos: Tensor, x: Tensor, converted_block_mask: BlockMask, converted_score_mod: _score_mod_signature) -> Tensor:
self.freqs_cis = precompute_freqs_cis(
block_size, self.head_dim, dtype=dtype
) # [block_size, D//2, 2]

def forward(
self,
batch_idx: Tensor,
input_pos: Tensor,
x: Tensor,
converted_block_mask: BlockMask,
converted_score_mod: _score_mod_signature,
) -> Tensor:
# input_pos: [B, S], batch_idx: [B], x: [B, S, D]
B, S, _ = x.shape
kv_size = self.n_head * self.head_dim
@@ -98,7 +128,9 @@ def forward(self, batch_idx: Tensor, input_pos: Tensor, x: Tensor, converted_blo
k = k.view(B, S, self.n_head, self.head_dim)
v = v.view(B, S, self.n_head, self.head_dim)

freqs_cis = self.freqs_cis.unsqueeze(0)[torch.zeros((B, 1), dtype=torch.int), input_pos] # [B, S, D//2, 2]
freqs_cis = self.freqs_cis.unsqueeze(0)[
torch.zeros((B, 1), dtype=torch.int), input_pos
] # [B, S, D//2, 2]

q = apply_rotary_emb(q, freqs_cis)
k = apply_rotary_emb(k, freqs_cis)
Expand All @@ -110,7 +142,13 @@ def forward(self, batch_idx: Tensor, input_pos: Tensor, x: Tensor, converted_blo
batch_idx, input_pos, k, v, self.k_cache_paged, self.v_cache_paged
)

y = flex_attention(q, self.k_cache_paged, self.v_cache_paged, block_mask=converted_block_mask, score_mod=converted_score_mod)
y = flex_attention(
q,
self.k_cache_paged,
self.v_cache_paged,
block_mask=converted_block_mask,
score_mod=converted_score_mod,
)

y = y.transpose(1, 2).contiguous().view(B, S, -1)

Expand All @@ -120,8 +158,10 @@ def forward(self, batch_idx: Tensor, input_pos: Tensor, x: Tensor, converted_blo

def apply_rotary_emb(x: Tensor, freqs_cis: Tensor) -> Tensor:
# x: [B, S, H, D], freqs_cis: [B, S, D//2, 2]
xshaped = x.float().reshape(*x.shape[:-1], -1, 2) # [B, S, H, D//2, 2]
freqs_cis = freqs_cis.view(xshaped.size(0), xshaped.size(1), 1, xshaped.size(3), 2) # [B, S, 1, D//2, 2]
xshaped = x.float().reshape(*x.shape[:-1], -1, 2) # [B, S, H, D//2, 2]
freqs_cis = freqs_cis.view(
xshaped.size(0), xshaped.size(1), 1, xshaped.size(3), 2
) # [B, S, 1, D//2, 2]
x_out2 = torch.stack(
[
xshaped[..., 0] * freqs_cis[..., 0] - xshaped[..., 1] * freqs_cis[..., 1],
@@ -151,13 +191,17 @@ def apply_rope_scaling(freqs: torch.Tensor, rope_scaling: Dict):
new_freqs.append(freq / factor)
else:
assert low_freq_wavelen != high_freq_wavelen
smooth = (old_context_len / wavelen - low_freq_factor) / (high_freq_factor - low_freq_factor)
smooth = (old_context_len / wavelen - low_freq_factor) / (
high_freq_factor - low_freq_factor
)
new_freqs.append((1 - smooth) * freq / factor + smooth * freq)
return torch.tensor(new_freqs, dtype=freqs.dtype, device=freqs.device)


def precompute_freqs_cis(
seq_len: int, n_elem: int, base: int = 10000,
seq_len: int,
n_elem: int,
base: int = 10000,
dtype: torch.dtype = torch.bfloat16,
rope_scaling: Optional[dict] = None,
) -> Tensor:
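Both forward() methods above gather per-position rotary frequencies from a [block_size, D//2, 2] table with the same advanced-indexing expression. The standalone sketch below is not part of the commit: it uses made-up sizes and a random stand-in for precompute_freqs_cis (whose body is collapsed in this view) purely to illustrate that indexing and the resulting [B, S, D//2, 2] shape.

# Standalone sketch of the freqs_cis gather used in both forward() methods;
# sizes and the random table are illustrative assumptions.
import torch

B, S, D = 2, 1, 64                          # one decode step per sequence: S == 1
block_size = 32768
# Stand-in for precompute_freqs_cis(block_size, D): shape [block_size, D//2, 2]
freqs_cis = torch.randn(block_size, D // 2, 2, dtype=torch.bfloat16)
input_pos = torch.tensor([[100], [200]], dtype=torch.int)  # [B, S]

# Same indexing as in forward(): a zero "batch" index broadcast against input_pos
gathered = freqs_cis.unsqueeze(0)[torch.zeros((B, 1), dtype=torch.int), input_pos]
print(gathered.shape)  # torch.Size([2, 1, 32, 2])  -> [B, S, D//2, 2]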
