From d783866587f2289b03f5da0e7bde1cd0529d5b26 Mon Sep 17 00:00:00 2001 From: drisspg Date: Wed, 31 Jul 2024 16:42:01 -0700 Subject: [PATCH] add document mask --- attn_gym/masks/document_mask.py | 100 ++++++++++++++++++++++++++++++++ attn_gym/utils.py | 5 +- 2 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 attn_gym/masks/document_mask.py diff --git a/attn_gym/masks/document_mask.py b/attn_gym/masks/document_mask.py new file mode 100644 index 0000000..31e3baf --- /dev/null +++ b/attn_gym/masks/document_mask.py @@ -0,0 +1,100 @@ +"""Generates a document causal attention mask based on a document ID tensor""" + +import torch +from torch import Tensor +from torch.nn.attention.flex_attention import _mask_mod_signature +from attn_gym.masks import causal_mask + + +def _offsets_to_doc_ids_tensor(offsets): + device = offsets.device + counts = offsets[1:] - offsets[:-1] + return torch.repeat_interleave( + torch.arange(len(counts), device=device, dtype=torch.int32), counts + ) + + +def generate_doc_mask_mod(mask_mod: _mask_mod_signature, offsets: Tensor) -> _mask_mod_signature: + """Generates mask mods that apply to inputs to flex attention in the sequence stacked + format. + + Args: + mask_mod: The mask mod to apply to the documents + offsets: This tensor should be of shape(num_documents + 1) + this should contain the cumulative counts of document tokens. + e.g. if you have 3 documents of length 2, 4, 3 then + offsets = [0, 2, 6, 9] + + Note: + What is the sequence stacked format? When assembling batches of inputs, we + take multiple sequences and stack them together to form 1 large sequence. We then + use masking to ensure that the attention scores are only applied to tokens within + the same document. + """ + document_id = _offsets_to_doc_ids_tensor(offsets) + + def doc_mask_mod(b, h, q_idx, kv_idx): + same_doc = document_id[q_idx] == document_id[kv_idx] + q_logical = q_idx - offsets[document_id[q_idx]] + kv_logical = kv_idx - offsets[document_id[kv_idx]] + inner_mask = mask_mod(b, h, q_logical, kv_logical) + return same_doc & inner_mask + + return doc_mask_mod + + +def main(device: str = "cpu"): + """Visualize the attention scores of document causal mask mod. + + Args: + device (str): Device to use for computation. Defaults to "cpu". + """ + from attn_gym import visualize_attention_scores + import random + + random.seed(0) + + def generate_random_lengths(total_length, num_documents): + # Initialize all lengths to 1 to ensure each document has at least one token + lengths = [1] * num_documents + remaining_length = total_length - num_documents + + # Randomly distribute the remaining length + for _ in range(remaining_length): + index = random.randint(0, num_documents - 1) + lengths[index] += 1 + + return lengths + + max_seq_len, doc_count = 21, 4 + B, H, SEQ_LEN, HEAD_DIM = 1, 1, max_seq_len, 8 + + lengths = generate_random_lengths(max_seq_len, doc_count) + + offsets = [0] + offsets.extend(lengths) + offsets = torch.tensor(offsets, device=device, dtype=torch.int32) + offsets = torch.cumsum(offsets, dim=-1) + + def make_tensor(): + return torch.ones(B, H, SEQ_LEN, HEAD_DIM, device=device) + + query, key = make_tensor(), make_tensor() + document_causal_mask = generate_doc_mask_mod(causal_mask, offsets) + + visualize_attention_scores( + query, + key, + mask_mod=document_causal_mask, + device=device, + name="document_causal_mask", + ) + + +if __name__ == "__main__": + try: + from jsonargparse import CLI + except ImportError: + raise ImportError("Be sure to run: pip install -e .[viz]") + + CLI(main) diff --git a/attn_gym/utils.py b/attn_gym/utils.py index 1c5948d..a31a235 100644 --- a/attn_gym/utils.py +++ b/attn_gym/utils.py @@ -132,7 +132,10 @@ def visualize_attention_scores( num_query_tokens, num_kv_tokens = scores_viz.shape[-2:] if num_query_tokens <= 32 and num_kv_tokens <= 32: ax.set_xticks(range(num_kv_tokens)) - ax.set_xticklabels([f"KV{i}" for i in range(num_kv_tokens)], fontsize=16) + rotation = 45 if num_kv_tokens > 12 else 0 + ax.set_xticklabels( + [f"KV{i}" for i in range(num_kv_tokens)], fontsize=16, rotation=rotation + ) ax.set_yticks(range(num_query_tokens)) ax.set_yticklabels([f"Q{i}" for i in range(num_query_tokens)], fontsize=16) # Align grid with pixel boundaries