flex_attention ver. #192

Open · wants to merge 2 commits into main
29 changes: 14 additions & 15 deletions generate.py
@@ -56,7 +56,7 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):

 def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
     # input_pos: [B, S]
-    logits = model(x, input_pos)
+    logits = model.prefill(x, input_pos)
     return sample(logits, **sampling_kwargs)[0]

 def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -68,15 +68,14 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
 def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
     new_tokens, new_probs = [], []
     for i in range(num_new_tokens):
-        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
-            next_token, next_prob = decode_one_token(
-                model, cur_token, input_pos, **sampling_kwargs
-            )
-            input_pos += 1
-            new_tokens.append(next_token.clone())
-            callback(new_tokens[-1])
-            new_probs.append(next_prob.clone())
-            cur_token = next_token.view(1, -1)
+        next_token, next_prob = decode_one_token(
+            model, cur_token, input_pos, **sampling_kwargs
+        )
+        input_pos += 1
+        new_tokens.append(next_token.clone())
+        callback(new_tokens[-1])
+        new_probs.append(next_prob.clone())
+        cur_token = next_token.view(1, -1)

     return new_tokens, new_probs
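
Note: dropping the `sdp_kernel` context manager follows from the model change. SDPA backend selection no longer applies once attention goes through `flex_attention` in model.py; the flex kernel is generated by `torch.compile` instead. For reference, a minimal sketch of compiling the flex kernel directly (illustrative only; the existing `--compile` path in gpt-fast, which compiles `decode_one_token`, should already cover this):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Illustrative: eager flex_attention runs a slow reference path;
# wrapping it in torch.compile produces the fused kernel.
compiled_flex_attention = torch.compile(flex_attention)
```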

@@ -154,10 +153,11 @@ def generate(
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(0)
     T_new = T + max_new_tokens
+    T_buf = ((T_new - 1) // 128 + 1) * 128  # round up to multiple of 128 to use flex_attention
     if interactive:
         max_seq_length = 350
     else:
-        max_seq_length = min(T_new, model.config.block_size)
+        max_seq_length = min(T_buf, model.config.block_size)

     device, dtype = prompt.device, prompt.dtype
     max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
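
The round-up to a multiple of 128 lines the buffer up with the default `BLOCK_SIZE` of `create_block_mask`, so every cache position falls inside a whole mask block; the unused tail is discarded later by the `seq[:T_new]` slice. A small sketch of the arithmetic (helper name is illustrative, not part of the PR):

```python
# Illustrative helper: round a length up to the flex_attention block size
# (128 is the default BLOCK_SIZE used by create_block_mask).
def round_up_to_block(n: int, block: int = 128) -> int:
    return ((n - 1) // block + 1) * block

assert round_up_to_block(1) == 128
assert round_up_to_block(300) == 384   # e.g. 100-token prompt + 200 new tokens
assert round_up_to_block(384) == 384
```
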
@@ -167,11 +167,10 @@
             draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)

     # create an empty tensor of the expected final shape and fill in the current tokens
-    empty = torch.empty(T_new, dtype=dtype, device=device)
+    empty = torch.empty(T_buf, dtype=dtype, device=device)
     empty[:T] = prompt
     seq = empty
     input_pos = torch.arange(0, T, device=device)
-
     next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
     if is_speculative:
         prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
@@ -198,12 +197,12 @@ def generate(
             next_token = next_tokens[-1]
     else:
         generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
-        seq[T + 1:] = torch.cat(generated_tokens)
+        seq[T + 1:T_new] = torch.cat(generated_tokens)

     generate_stats = {
         'accept_counts': accept_counts
     }
-    return seq, generate_stats
+    return seq[:T_new], generate_stats

 def encode_tokens(tokenizer, string, bos=True, device=default_device):
     tokens = tokenizer.encode(string)
53 changes: 48 additions & 5 deletions model.py
@@ -11,6 +11,8 @@
 from torch import Tensor
 from torch.nn import functional as F

+from torch.nn.attention.flex_attention import (flex_attention, create_block_mask, BlockMask)
+from typing import Callable

 def find_multiple(n: int, k: int) -> int:
     if n % k == 0:
@@ -88,6 +90,33 @@ def update(self, input_pos, k_val, v_val):

         return k_out, v_out

+
+def causal_mask(b, h, q, kv):
+    return q >= kv
+
+class Att_Mask:
+    def mask_mod(self, b, h, q, kv):
+        offset = self.input_pos[0]
+        return offset + q >= kv
+
+    def __init__(self, context_len):
+        self.input_pos = None
+        self.context_len = context_len
+        self.block_masks = create_block_mask(causal_mask, 1, 1, context_len, context_len)
+
+    def get_mask(self, kv_len, input_pos):
+        self.input_pos = input_pos
+        offset = self.input_pos // self.block_masks.BLOCK_SIZE[0]
+        new_kv_num_blocks = self.block_masks.kv_num_blocks[:, :, offset]
+        new_kv_indices = self.block_masks.kv_indices[:, :, offset, :kv_len]
+        new_full_kv_num_blocks = self.block_masks.full_kv_num_blocks[:, :, offset]
+        new_full_kv_indices = self.block_masks.full_kv_indices[:, :, offset, :kv_len]
+        layer_mask = BlockMask(new_kv_num_blocks, new_kv_indices, new_full_kv_num_blocks, new_full_kv_indices, BLOCK_SIZE=self.block_masks.BLOCK_SIZE, mask_mod=self.mask_mod)
+        return layer_mask
+
+    def gen_prefill_mask(self, kv_len, q_len):
+        return create_block_mask(causal_mask, 1, 1, q_len, kv_len)
+
+
 class Transformer(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
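
For context on what `Att_Mask` does: `create_block_mask` precomputes, at (by default) 128 x 128 block granularity, which KV blocks each query block may attend to, and `flex_attention` consumes the resulting `BlockMask` (its attached `mask_mod` handles the partially masked blocks). `get_mask` then slices out the single query-block row containing the current decode position from this precomputed full-length mask instead of rebuilding a mask every token; the slicing relies on the `kv_num_blocks` / `kv_indices` attributes being laid out as `[B, H, num_q_blocks(, num_kv_blocks)]`. A minimal, self-contained sketch of the underlying API (shapes, dtypes and device here are illustrative, not taken from this PR):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

def causal(b, h, q_idx, kv_idx):
    # True means "this query position may attend to this kv position".
    return q_idx >= kv_idx

B, H, S, D = 1, 8, 256, 64
q = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, H, S, D, device="cuda", dtype=torch.bfloat16)

# Block-sparse causal mask over a 256-token sequence (a 2x2 grid of 128-wide blocks).
block_mask = create_block_mask(causal, B=None, H=None, Q_LEN=S, KV_LEN=S, device="cuda")
out = flex_attention(q, k, v, block_mask=block_mask)
```
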
@@ -102,6 +131,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.mask_cache: Optional[Tensor] = None
         self.max_batch_size = -1
         self.max_seq_length = -1
+        self.fa_mod = Att_Mask(config.block_size)

     def setup_caches(self, max_batch_size, max_seq_length):
         if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
@@ -120,11 +150,22 @@ def setup_caches(self, max_batch_size, max_seq_length):
             b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)

         self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
-        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))

     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
-        mask = self.causal_mask[None, None, input_pos]
+        mask = self.fa_mod.get_mask(self.max_seq_length, input_pos)
         freqs_cis = self.freqs_cis[input_pos]
         x = self.tok_embeddings(idx)

+        for i, layer in enumerate(self.layers):
+            x = layer(x, input_pos, freqs_cis, mask)
+        x = self.norm(x)
+        logits = self.output(x)
+        return logits
+
+    def prefill(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+        assert self.freqs_cis is not None, "Caches must be initialized first"
+        mask = self.fa_mod.gen_prefill_mask(self.max_seq_length, input_pos.shape[0])
+        freqs_cis = self.freqs_cis[input_pos]
+        x = self.tok_embeddings(idx)
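
With the two entry points split, `prefill` builds a fresh prompt-length block mask once, while `forward` (now the per-token decode path) reuses a slice of the full-length mask precomputed in `Att_Mask`. A rough driver sketch, assuming batch size 1 and greedy sampling (it mirrors the generate.py calls shown earlier; `model` is the Transformer from this PR and is not constructed here):

```python
import torch

def generate_greedy(model, prompt: torch.Tensor, max_new_tokens: int) -> torch.Tensor:
    # Prompt phase: Transformer.prefill() builds a dense prompt-sized BlockMask.
    T = prompt.size(0)
    input_pos = torch.arange(0, T, device=prompt.device)
    logits = model.prefill(prompt.view(1, -1), input_pos)
    next_token = logits[0, -1].argmax(dim=-1, keepdim=True)

    # Decode phase: Transformer.forward() slices the precomputed mask at the
    # query block containing input_pos via Att_Mask.get_mask().
    tokens = [next_token]
    input_pos = torch.tensor([T], device=prompt.device)
    for _ in range(max_new_tokens - 1):
        logits = model(next_token.view(1, -1), input_pos)
        next_token = logits[0, -1].argmax(dim=-1, keepdim=True)
        tokens.append(next_token)
        input_pos += 1
    return torch.cat(tokens)
```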

@@ -147,7 +188,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)

-    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
+    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: BlockMask) -> Tensor:
         h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
@@ -177,7 +218,7 @@ def load_hook(self, state_dict, prefix, *args):
         wv = state_dict.pop(prefix + "wv.weight")
         state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])

-    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, freqs_cis: Tensor, mask: BlockMask, input_pos: Optional[Tensor] = None) -> Tensor:
         bsz, seqlen, _ = x.shape

         kv_size = self.n_local_heads * self.head_dim
@@ -197,7 +238,7 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:

         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+        y = flex_attention(q, k, v, block_mask=mask, enable_gqa=(self.n_head != self.n_local_heads))

         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)

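
A small observation on the new call: since `enable_gqa` tells `flex_attention` to broadcast grouped KV heads internally, the explicit `repeat_interleave` of k and v kept above may be redundant (it materialises the repeated heads that the flag would otherwise avoid). A hedged sketch of the GQA path without the manual repeat, for comparison only (not what this PR does):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

B, Hq, Hkv, S, D = 1, 32, 8, 128, 64
q = torch.randn(B, Hq,  S, D, device="cuda", dtype=torch.bfloat16)
k = torch.randn(B, Hkv, S, D, device="cuda", dtype=torch.bfloat16)
v = torch.randn(B, Hkv, S, D, device="cuda", dtype=torch.bfloat16)

# enable_gqa=True lets flex_attention handle Hq != Hkv (Hq must be a multiple
# of Hkv) without first expanding k/v to Hq heads.
out = flex_attention(q, k, v, enable_gqa=True)
```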