Building a composite mask with attention_mask from tokenizers #98

Open
lhallee opened this issue Jan 2, 2025 · 0 comments

lhallee commented Jan 2, 2025

I'm trying to compare whether document masking is any faster than using the HF transformers attention_mask convention. I'm really confused about why the two examples below are not equivalent; I'm probably constructing the mask_mods wrong. Would appreciate any help.

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from transformers import EsmTokenizer

torch.set_default_device("cuda")
torch.manual_seed(0)
torch._dynamo.config.cache_size_limit = 1000
flex_attention = torch.compile(flex_attention, dynamic=False)

tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")


heads = 8
hidden_dim = 512
sliding_window_size = 64
seq_a = 'MYDSNIFEKVNQYKFLYIWWLIMINVNH'
seq_b = 'VPAALSAHAVLVDHHDQLLAALRVLLPVRHWEGAVVTPEQTRRQHVAQRDMVVRSLLLAPTPEERGEERPRKPRQVTVVPVVSREPGAHAYRLPIVREARYL'
batch = [seq_a, seq_b, seq_a, seq_b]


### Example 1: packed layout with shape (1, total_seq_len)
input_ids = []
for seq in batch:
    input_ids.extend(tokenizer(seq, add_special_tokens=False, truncation=True, max_length=1024, padding=False).input_ids)

input_ids = torch.tensor(input_ids).flatten()

def doc_mask_mod(b, h, q_idx, kv_idx): # BERT style document mask with sliding window
    bidirectional_sliding_window_mask = torch.abs(q_idx - kv_idx) < sliding_window_size
    doc_mask = docs[q_idx] == docs[kv_idx]
    return bidirectional_sliding_window_mask & doc_mask

docs = (input_ids == tokenizer.cls_token_id).cumsum(dim=0)  # shape: [S]
S = len(input_ids)
block_mask = create_block_mask(doc_mask_mod, None, None, S, S)

# BHLD
q = torch.randn(1, heads, S, hidden_dim // heads).cuda()
k = torch.randn(1, heads, S, hidden_dim // heads).cuda()
v = torch.randn(1, heads, S, hidden_dim // heads).cuda()

att_1 = flex_attention(q, k, v, block_mask=block_mask).detach().cpu()
print(att_1.shape) # (1, 8, S, 64)


### Example 2: the more traditional (batch_size, seq_len) layout
tokenized = tokenizer(batch, return_tensors='pt', padding=True, add_special_tokens=False, truncation=True, max_length=1024)
input_ids = tokenized.input_ids.cuda() # (B, S)
attention_mask = tokenized.attention_mask.bool().cuda() # (B, S)

def batched_mask_mod(b, h, q_idx, kv_idx): # BERT style document mask based on attention mask with sliding window
    bidirectional_sliding_window_mask = torch.abs(q_idx - kv_idx) < sliding_window_size
    doc_mask = attention_mask[b, q_idx] & attention_mask[b, kv_idx]
    return bidirectional_sliding_window_mask & doc_mask

B = input_ids.shape[0]
S = input_ids.shape[1]

q = torch.randn(B, heads, S, hidden_dim // heads).cuda()
k = torch.randn(B, heads, S, hidden_dim // heads).cuda()
v = torch.randn(B, heads, S, hidden_dim // heads).cuda()

block_mask = create_block_mask(batched_mask_mod, None, None, S, S)

att_2 = flex_attention(q, k, v, block_mask=block_mask).detach().cpu()
print(att_2.shape) # (4, 8, S, 64)

# The first len(seq_a) positions of each output should be the attention output for the first sequence
close = att_1[0, :, :len(seq_a), :].isclose(att_2[0, :, :len(seq_a), :])
print(close.any()) # False, expecting True
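
For reference, here is a minimal sketch of how I have been trying to debug this, assuming create_mask from torch.nn.attention.flex_attention works the way I think it does (it materializes a mask_mod into a dense boolean mask); the helper names below (len_a, S_packed, dense_packed, dense_batched) are just mine. The idea is to compare the masks themselves over the region covering the first sequence, independent of the attention outputs:

from torch.nn.attention.flex_attention import create_mask

# Number of tokens in seq_a (ESM tokenizes per character, so this equals len(seq_a))
len_a = len(tokenizer(seq_a, add_special_tokens=False).input_ids)

# Dense mask for the packed layout of Example 1: shape (1, 1, S_packed, S_packed)
S_packed = docs.shape[0]
dense_packed = create_mask(doc_mask_mod, None, None, S_packed, S_packed)

# Dense mask for the batched layout of Example 2: shape (B, 1, S, S)
S_batched = attention_mask.shape[1]
dense_batched = create_mask(batched_mask_mod, B, None, S_batched, S_batched)

# If the two constructions agree, the block covering the first sequence should be identical
print(torch.equal(
    dense_packed[0, 0, :len_a, :len_a],
    dense_batched[0, 0, :len_a, :len_a],
))

If that already prints False, then the discrepancy is in how I'm building the mask_mods rather than in flex_attention itself.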