From cb2e3a262ced633ebc081aa1faaf4dc9c010d1f2 Mon Sep 17 00:00:00 2001
From: joydddd
Date: Fri, 19 Jul 2024 16:39:30 -0700
Subject: [PATCH 1/2] Use flex_attention

---
 generate.py | 28 ++++++++++++++--------------
 model.py    | 52 +++++++++++++++++++++++++++++++++++++++++++++++-----
 2 files changed, 61 insertions(+), 19 deletions(-)

diff --git a/generate.py b/generate.py
index c58a2249..17e2daac 100644
--- a/generate.py
+++ b/generate.py
@@ -56,7 +56,7 @@ def sample(logits, temperature: float = 1.0, top_k: Optional[int] = None):
 
 def prefill(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> torch.Tensor:
     # input_pos: [B, S]
-    logits = model(x, input_pos)
+    logits = model.prefill(x, input_pos)
     return sample(logits, **sampling_kwargs)[0]
 
 def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tensor, **sampling_kwargs) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -68,15 +68,14 @@ def decode_one_token(model: Transformer, x: torch.Tensor, input_pos: torch.Tenso
 def decode_n_tokens(model: Transformer, cur_token: torch.Tensor, input_pos: torch.Tensor, num_new_tokens: int, callback=lambda _: _, **sampling_kwargs):
     new_tokens, new_probs = [], []
     for i in range(num_new_tokens):
-        with torch.backends.cuda.sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True): # Actually better for Inductor to codegen attention here
-            next_token, next_prob = decode_one_token(
-                model, cur_token, input_pos, **sampling_kwargs
-            )
-            input_pos += 1
-            new_tokens.append(next_token.clone())
-            callback(new_tokens[-1])
-            new_probs.append(next_prob.clone())
-            cur_token = next_token.view(1, -1)
+        next_token, next_prob = decode_one_token(
+            model, cur_token, input_pos, **sampling_kwargs
+        )
+        input_pos += 1
+        new_tokens.append(next_token.clone())
+        callback(new_tokens[-1])
+        new_probs.append(next_prob.clone())
+        cur_token = next_token.view(1, -1)
 
     return new_tokens, new_probs
 
@@ -154,10 +153,11 @@ def generate(
     # create an empty tensor of the expected final shape and fill in the current tokens
     T = prompt.size(0)
     T_new = T + max_new_tokens
+    T_buf = ((T_new - 1) // 128 + 1) * 128 # round up to multiple of 128 to use flex_attention
     if interactive:
         max_seq_length = 350
     else:
-        max_seq_length = min(T_new, model.config.block_size)
+        max_seq_length = min(T_buf, model.config.block_size)
 
     device, dtype = prompt.device, prompt.dtype
     max_seq_length = max_seq_length + speculate_k + 1 if is_speculative else max_seq_length
@@ -167,7 +167,7 @@ def generate(
         draft_model.setup_caches(max_batch_size=1, max_seq_length=max_seq_length)
 
     # create an empty tensor of the expected final shape and fill in the current tokens
-    empty = torch.empty(T_new, dtype=dtype, device=device)
+    empty = torch.empty(T_buf, dtype=dtype, device=device)
     empty[:T] = prompt
     seq = empty
     input_pos = torch.arange(0, T, device=device)
@@ -198,12 +198,12 @@ def generate(
             next_token = next_tokens[-1]
     else:
         generated_tokens, _ = decode_n_tokens(model, next_token.view(1, -1), input_pos, max_new_tokens - 1, callback=callback, **sampling_kwargs)
-        seq[T + 1:] = torch.cat(generated_tokens)
+        seq[T + 1:T_new] = torch.cat(generated_tokens)
 
     generate_stats = {
         'accept_counts': accept_counts
     }
-    return seq, generate_stats
+    return seq[:T_new], generate_stats
 
 def encode_tokens(tokenizer, string, bos=True, device=default_device):
     tokens = tokenizer.encode(string)
diff --git a/model.py b/model.py
index b89a19a0..0f38aad6 100644
--- a/model.py
+++ b/model.py
@@ -11,6 +11,8 @@
 from torch import Tensor
 from torch.nn import functional as F
+from torch.nn.attention.flex_attention import (flex_attention, create_block_mask, BlockMask)
+from typing import Callable
 
 
 def find_multiple(n: int, k: int) -> int:
     if n % k == 0:
@@ -88,6 +90,33 @@ def update(self, input_pos, k_val, v_val):
 
         return k_out, v_out
 
+
+def causal_mask(b, h, q, kv):
+    return q >= kv
+class Att_Mask:
+    def mask_mod(self, b, h, q, kv):
+        offset = self.input_pos[0]
+        return offset + q >= kv
+
+    def __init__(self, context_len):
+        self.input_pos = None
+        self.context_len = context_len
+        self.block_masks = create_block_mask(causal_mask, 1, 1, context_len, context_len)
+
+    def get_mask(self, kv_len, input_pos):
+        self.input_pos = input_pos
+        offset = self.input_pos // self.block_masks.BLOCK_SIZE[0]
+        new_kv_num_blocks = self.block_masks.kv_num_blocks[:, :, offset]
+        new_kv_indices = self.block_masks.kv_indices[:, :, offset, :kv_len]
+        new_full_kv_num_blocks = self.block_masks.full_kv_num_blocks[:, :, offset]
+        new_full_kv_indices = self.block_masks.full_kv_indices[:, :, offset, :kv_len]
+        layer_mask = BlockMask(new_kv_num_blocks, new_kv_indices, new_full_kv_num_blocks, new_full_kv_indices, self.block_masks.BLOCK_SIZE, self.mask_mod)
+        return layer_mask
+
+    def gen_prefill_mask(self, kv_len, q_len):
+        return create_block_mask(causal_mask, 1, 1, q_len, kv_len)
+
+
 class Transformer(nn.Module):
     def __init__(self, config: ModelArgs) -> None:
         super().__init__()
@@ -102,6 +131,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.mask_cache: Optional[Tensor] = None
         self.max_batch_size = -1
         self.max_seq_length = -1
+        self.fa_mod = Att_Mask(config.block_size)
 
     def setup_caches(self, max_batch_size, max_seq_length):
         if self.max_seq_length >= max_seq_length and self.max_batch_size >= max_batch_size:
@@ -120,11 +150,22 @@ def setup_caches(self, max_batch_size, max_seq_length):
             b.attention.kv_cache = KVCache(max_batch_size, max_seq_length, self.config.n_local_heads, head_dim, dtype)
 
         self.freqs_cis = precompute_freqs_cis(self.config.block_size, self.config.dim // self.config.n_head, self.config.rope_base, dtype)
-        self.causal_mask = torch.tril(torch.ones(self.max_seq_length, self.max_seq_length, dtype=torch.bool))
 
     def forward(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
         assert self.freqs_cis is not None, "Caches must be initialized first"
-        mask = self.causal_mask[None, None, input_pos]
+        mask = self.fa_mod.get_mask(self.max_seq_length, input_pos)
+        freqs_cis = self.freqs_cis[input_pos]
+        x = self.tok_embeddings(idx)
+
+        for i, layer in enumerate(self.layers):
+            x = layer(x, input_pos, freqs_cis, mask)
+        x = self.norm(x)
+        logits = self.output(x)
+        return logits
+
+    def prefill(self, idx: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+        assert self.freqs_cis is not None, "Caches must be initialized first"
+        mask = self.fa_mod.gen_prefill_mask(self.max_seq_length, input_pos.shape[0])
         freqs_cis = self.freqs_cis[input_pos]
         x = self.tok_embeddings(idx)
 
@@ -147,7 +188,7 @@ def __init__(self, config: ModelArgs) -> None:
         self.ffn_norm = RMSNorm(config.dim, config.norm_eps)
         self.attention_norm = RMSNorm(config.dim, config.norm_eps)
 
-    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: Tensor) -> Tensor:
+    def forward(self, x: Tensor, input_pos: Tensor, freqs_cis: Tensor, mask: BlockMask) -> Tensor:
         h = x + self.attention(self.attention_norm(x), freqs_cis, mask, input_pos)
         out = h + self.feed_forward(self.ffn_norm(h))
         return out
@@ -177,7 +218,7 @@ def load_hook(self, state_dict, prefix, *args):
             wv = state_dict.pop(prefix + "wv.weight")
             state_dict[prefix + "wqkv.weight"] = torch.cat([wq, wk, wv])
 
-    def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optional[Tensor] = None) -> Tensor:
+    def forward(self, x: Tensor, freqs_cis: Tensor, mask: BlockMask, input_pos: Optional[Tensor] = None) -> Tensor:
         bsz, seqlen, _ = x.shape
 
         kv_size = self.n_local_heads * self.head_dim
@@ -197,7 +238,8 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: Tensor, input_pos: Optiona
 
         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
-        y = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0.0)
+
+        y = flex_attention(q, k, v, block_mask=mask)
 
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)
 

From 8c869015d70f0d8174de1526e79772ee6c9a2a64 Mon Sep 17 00:00:00 2001
From: joydddd
Date: Wed, 31 Jul 2024 15:28:03 -0700
Subject: [PATCH 2/2] Adapt to gqa and new block_mask gen mode

---
 generate.py | 1 -
 model.py    | 5 +++--
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/generate.py b/generate.py
index 17e2daac..a4b84112 100644
--- a/generate.py
+++ b/generate.py
@@ -171,7 +171,6 @@ def generate(
     empty[:T] = prompt
     seq = empty
     input_pos = torch.arange(0, T, device=device)
-
     next_token = prefill(model, prompt.view(1, -1), input_pos, **sampling_kwargs).clone()
     if is_speculative:
         prefill(draft_model, prompt.view(1, -1), input_pos, **sampling_kwargs)
diff --git a/model.py b/model.py
index 0f38aad6..11c743de 100644
--- a/model.py
+++ b/model.py
@@ -110,7 +110,7 @@ def get_mask(self, kv_len, input_pos):
         new_kv_indices = self.block_masks.kv_indices[:, :, offset, :kv_len]
         new_full_kv_num_blocks = self.block_masks.full_kv_num_blocks[:, :, offset]
         new_full_kv_indices = self.block_masks.full_kv_indices[:, :, offset, :kv_len]
-        layer_mask = BlockMask(new_kv_num_blocks, new_kv_indices, new_full_kv_num_blocks, new_full_kv_indices, self.block_masks.BLOCK_SIZE, self.mask_mod)
+        layer_mask = BlockMask(new_kv_num_blocks, new_kv_indices, new_full_kv_num_blocks, new_full_kv_indices, BLOCK_SIZE=self.block_masks.BLOCK_SIZE, mask_mod=self.mask_mod)
         return layer_mask
 
     def gen_prefill_mask(self, kv_len, q_len):
         return create_block_mask(causal_mask, 1, 1, q_len, kv_len)
@@ -239,7 +239,8 @@ def forward(self, x: Tensor, freqs_cis: Tensor, mask: BlockMask, input_pos: Opti
         k = k.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
         v = v.repeat_interleave(self.n_head // self.n_local_heads, dim=1)
 
-        y = flex_attention(q, k, v, block_mask=mask)
+        y = flex_attention(q, k, v, block_mask=mask, enable_gqa= (self.n_head != self.n_local_heads))
+
 
         y = y.transpose(1, 2).contiguous().view(bsz, seqlen, self.dim)