I'm trying to compare whether document masking (packing all sequences into a single row) is any faster than the usual HF transformers padding + attention-mask convention. I'm confused about why the two examples below are not equivalent; I am probably constructing the mask_mods wrong. Would appreciate any help.
```python
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention
from transformers import EsmTokenizer

torch.set_default_device("cuda")
torch.manual_seed(0)
torch._dynamo.config.cache_size_limit = 1000
flex_attention = torch.compile(flex_attention, dynamic=False)

tokenizer = EsmTokenizer.from_pretrained("facebook/esm2_t33_650M_UR50D")

heads = 8
hidden_dim = 512
sliding_window_size = 64
seq_a = 'MYDSNIFEKVNQYKFLYIWWLIMINVNH'
seq_b = 'VPAALSAHAVLVDHHDQLLAALRVLLPVRHWEGAVVTPEQTRRQHVAQRDMVVRSLLLAPTPEERGEERPRKPRQVTVVPVVSREPGAHAYRLPIVREARYL'
batch = [seq_a, seq_b, seq_a, seq_b]

### Example 1 with (1, seq_len * batch_size)
input_ids = []
for seq in batch:
    input_ids.extend(tokenizer(seq, add_special_tokens=False, truncation=True, max_length=1024, padding=False).input_ids)
input_ids = torch.tensor(input_ids).flatten()

def doc_mask_mod(b, h, q_idx, kv_idx):
    # BERT style document mask with sliding window
    bidirectional_sliding_window_mask = torch.abs(q_idx - kv_idx) < sliding_window_size
    doc_mask = docs[q_idx] == docs[kv_idx]
    return bidirectional_sliding_window_mask & doc_mask

docs = (input_ids == tokenizer.cls_token_id).cumsum(dim=0)  # shape: [S]
S = len(input_ids)
block_mask = create_block_mask(doc_mask_mod, None, None, S, S)

# BHLD
q = torch.randn(1, heads, S, hidden_dim // heads).cuda()
k = torch.randn(1, heads, S, hidden_dim // heads).cuda()
v = torch.randn(1, heads, S, hidden_dim // heads).cuda()
att_1 = flex_attention(q, k, v, block_mask=block_mask).detach().cpu()
print(att_1.shape)  # (1, 8, S, 64)
```
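As a sanity check on the packed layout, the snippet below (just a diagnostic sketch reusing `docs` and `block_mask` from above; `BlockMask.sparsity()` should be available in recent PyTorch releases, but verify against your version) prints how many document IDs the CLS-based cumsum produced and how sparse the resulting block mask is:

```python
# Diagnostic sketch: verify that the CLS-based cumsum actually separates the packed
# sequences (with add_special_tokens=False the tokenizer may not insert CLS tokens),
# and inspect the block mask produced by create_block_mask.
doc_ids, doc_lengths = docs.unique(return_counts=True)
print(doc_ids)                 # expected: one ID per packed sequence
print(doc_lengths)             # expected: the token length of each sequence
print(block_mask.sparsity())   # how sparse the computed block mask is
```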
```python
### Example 2 with more traditional (batch_size, seq_len)
tokenized = tokenizer(batch, return_tensors='pt', padding=True, add_special_tokens=False, truncation=True, max_length=1024)
input_ids = tokenized.input_ids.cuda()                    # (B, S)
attention_mask = tokenized.attention_mask.bool().cuda()   # (B, S)

def batched_mask_mod(b, h, q_idx, kv_idx):
    # BERT style document mask based on attention mask with sliding window
    bidirectional_sliding_window_mask = torch.abs(q_idx - kv_idx) < sliding_window_size
    doc_mask = attention_mask[b, q_idx] & attention_mask[b, kv_idx]
    return bidirectional_sliding_window_mask & doc_mask

B = input_ids.shape[0]
S = input_ids.shape[1]
q = torch.randn(B, heads, S, hidden_dim // heads).cuda()
k = torch.randn(B, heads, S, hidden_dim // heads).cuda()
v = torch.randn(B, heads, S, hidden_dim // heads).cuda()
block_mask = create_block_mask(batched_mask_mod, None, None, S, S)
att_2 = flex_attention(q, k, v, block_mask=block_mask).detach().cpu()
print(att_2.shape)  # (4, 8, S, 64)

# The first len(seq_a) positions of each should be the attention output of the first sequence
close = att_1[0, :, :len(seq_a), :].isclose(att_2[0, :, :len(seq_a), :])
print(close.any())  # False, expecting True
```
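One caveat about this final comparison: `q`, `k`, and `v` are drawn independently (and with different shapes) in the two examples, so an elementwise `isclose` on the outputs is not a clean equivalence test even if the two mask_mods agree. A more direct check is to materialize both mask_mods as dense boolean masks and compare the region covering the first copy of `seq_a`. A minimal sketch, assuming the variables from both examples are still in scope and that `create_mask` is available in your PyTorch build:

```python
# Sketch: compare the token-level masks directly instead of the attention outputs.
# create_mask materializes a mask_mod into a dense boolean (B, H, Q_LEN, KV_LEN) tensor.
from torch.nn.attention.flex_attention import create_mask

S_packed = docs.shape[0]                      # packed length from Example 1
dense_doc = create_mask(doc_mask_mod, None, None, S_packed, S_packed, device="cuda")
dense_batched = create_mask(batched_mask_mod, B, None, S, S, device="cuda")

n = int(attention_mask[0].sum())              # real (un-padded) token count of row 0
print(torch.equal(dense_doc[0, 0, :n, :n], dense_batched[0, 0, :n, :n]))
# Per-query counts of allowed key positions for the first sequence in each layout
print(dense_doc[0, 0, :n].sum(-1))
print(dense_batched[0, 0, :n].sum(-1))
```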