FlexAttention slower than eager in HF transformers #95
When running FlexAttention vs. SDPA alone (with compile), I get:
```python
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
from tabulate import tabulate
from torch.nn.attention.flex_attention import (
    flex_attention,
    create_block_mask,
    create_mask,
)

torch._dynamo.config.cache_size_limit = 1000
flex_attention = torch.compile(flex_attention, dynamic=False)

print(f"Torch version: {torch.__version__}")

from torch._inductor.utils import do_bench_using_profiling
from typing import Callable


def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3


benchmark_fn = benchmark_cuda_function_in_microseconds

WINDOW_SIZE = 64


def generate_block_mask(sequence_ids, cu_seqlens, WINDOW_SIZE):
    def sliding_window_seq_mask_mod(b, h, q_idx, kv_idx):
        # only allow attention within the same sequence
        same_seq = sequence_ids[q_idx] == sequence_ids[kv_idx]
        # get position within the sequence
        q_pos = q_idx - cu_seqlens[sequence_ids[q_idx]]
        kv_pos = kv_idx - cu_seqlens[sequence_ids[kv_idx]]
        # sliding window within each sequence
        in_window = (q_pos - kv_pos).abs() <= WINDOW_SIZE
        return same_seq & in_window

    return sliding_window_seq_mask_mod


def SWA_mask(b, h, q_idx, kv_idx):
    # sliding window within each sequence
    in_window = (q_idx - kv_idx).abs() <= WINDOW_SIZE
    return in_window


# Benchmarking function
def run_benchmark(batch_sizes, sequence_lengths, num_heads=16, hidden_dim=64, n_runs=3):
    results = []
    for batch_size in batch_sizes:
        for seq_len in sequence_lengths:
            q = torch.randn(
                batch_size, num_heads, seq_len, hidden_dim, dtype=torch.bfloat16
            ).to("cuda")
            k = torch.randn(
                batch_size, num_heads, seq_len, hidden_dim, dtype=torch.bfloat16
            ).to("cuda")
            v = torch.randn(
                batch_size, num_heads, seq_len, hidden_dim, dtype=torch.bfloat16
            ).to("cuda")

            # each example in the batch has length seq_len; use a local name so the
            # outer sequence_lengths list is not overwritten
            per_example_lengths = [seq_len] * batch_size
            sequence_ids = torch.cat(
                [
                    torch.full((length,), i, dtype=torch.long)
                    for i, length in enumerate(per_example_lengths)
                ]
            ).to("cuda")
            _, counts = torch.unique_consecutive(sequence_ids, return_counts=True)
            cu_seqlens = torch.cat(
                [torch.tensor([0], device=sequence_ids.device), counts.cumsum(0)]
            )

            block_mask = create_block_mask(
                generate_block_mask(sequence_ids, cu_seqlens, WINDOW_SIZE),
                B=None,
                H=None,
                Q_LEN=int(cu_seqlens[-1]),  # total packed length
                KV_LEN=int(cu_seqlens[-1]),
                device="cuda",
            )
            mask = create_mask(SWA_mask, None, None, seq_len, seq_len, device="cuda")

            # Benchmark flex_attention
            flex_times = []
            for _ in range(n_runs):
                flex_time = benchmark_fn(
                    flex_attention,
                    q.reshape(1, num_heads, -1, hidden_dim),
                    k.reshape(1, num_heads, -1, hidden_dim),
                    v.reshape(1, num_heads, -1, hidden_dim),
                    score_mod=None,
                    block_mask=block_mask,
                )
                flex_times.append(flex_time)
            flex_avg_time = (sum(flex_times) / n_runs) * 1000  # Convert to ms

            # Benchmark scaled_dot_product_attention with mask
            sdpa_times = []
            with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                for _ in range(n_runs):
                    sdpa_time = benchmark_fn(
                        scaled_dot_product_attention,
                        q,
                        k,
                        v,
                        attn_mask=mask,
                    )
                    sdpa_times.append(sdpa_time)
            sdpa_avg_time = (sum(sdpa_times) / n_runs) * 1000  # Convert to ms

            results.append(
                {
                    "Batch Size": batch_size,
                    "Seq Length": seq_len,
                    "FLEX Avg Time (ms)": f"{flex_avg_time:.2f}",
                    "SDPA Avg Time (ms)": f"{sdpa_avg_time:.2f}",
                }
            )
    return results


if __name__ == "__main__":
    batch_sizes = [1, 2, 4]
    sequence_lengths = [128, 256, 512, 1024, 2048, 4096, 8192]
    n_runs = 5

    results = run_benchmark(batch_sizes, sequence_lengths, n_runs=n_runs)

    # Generate table
    print("\n=== Benchmark Results ===")
    print(tabulate(results, headers="keys", tablefmt="grid"))
```

So my question is: how do I cleanly integrate

```python
torch._dynamo.config.cache_size_limit = 1000
flex_attention = torch.compile(flex_attention, dynamic=False)
```

into transformers?
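For context, one possible shape of that integration (just a sketch of the compile-once-at-module-level idea, not the actual transformers code; the diff further down in this thread does essentially this inside modeling_modernbert.py):

```python
import torch
from torch.nn.attention.flex_attention import flex_attention

# Raise the dynamo cache limit so many distinct mask/shape combinations do not
# trigger recompilation thrash, then compile flex_attention once so every
# attention layer reuses the same compiled kernel.
torch._dynamo.config.cache_size_limit = 1000
compiled_flex_attention = torch.compile(flex_attention, dynamic=False)


def flex_attention_forward(query, key, value, block_mask=None):
    # Hypothetical thin wrapper an attention module could call instead of the
    # raw, uncompiled flex_attention.
    return compiled_flex_attention(query, key, value, block_mask=block_mask)
```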
Without a doubt it will be slower than eager when it is not compiled. Let me ping some HF folks to see if we can raise a warning / ensure it is easy to compile.
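For illustration, the warning could be as simple as the sketch below (a hypothetical helper, not existing transformers code; `attn_implementation` is the usual `from_pretrained` argument):

```python
import warnings


def warn_if_flex_attention_uncompiled(attn_implementation: str) -> None:
    # Hypothetical check a library could run at model-construction time:
    # uncompiled flex_attention is generally slower than SDPA or even eager,
    # so nudge the user toward torch.compile.
    if attn_implementation == "flex_attention":
        warnings.warn(
            "flex_attention is only fast when compiled; consider wrapping it "
            "with torch.compile or compiling the whole model.",
            UserWarning,
        )
```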
Is there a way to make the above code work with the latest PyTorch release? I'm not sure, but it seems we need the latest nightly because the default BLOCK_SIZE changed. Will it work if we specify it manually?
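If it is only the default block size that changed, it can be pinned explicitly; a minimal sketch, assuming a PyTorch version whose `create_block_mask` accepts a `BLOCK_SIZE` argument (the causal mask and the sizes here are placeholders):

```python
import torch
from torch.nn.attention.flex_attention import create_block_mask


def causal_mask(b, h, q_idx, kv_idx):
    return q_idx >= kv_idx


# Pin the block size instead of relying on the version-dependent default;
# a tuple can be used for separate Q and KV block sizes, e.g. (128, 128).
block_mask = create_block_mask(
    causal_mask,
    B=None,
    H=None,
    Q_LEN=4096,
    KV_LEN=4096,
    device="cuda",
    BLOCK_SIZE=128,
)
```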
Taking a look now.
I just installed the 2.6 RC. Using this script:

```python
import time
import torch
import gc
from transformers import AutoTokenizer, AutoModelForMaskedLM

torch.set_float32_matmul_precision("high")

model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)

texts = [
    "The capital of France is [MASK]." * 100,
    # "The largest city in Canada is [MASK]." * 200,
    # "The currency of Japan is [MASK]." * 300,
    # "The highest mountain in the world is [MASK]." * 500,
]

implementations = ["flex_attention", "sdpa", "eager"]
num_repeats = 3
num_warmup = 2  # Number of warmup runs before timing


def time_model(attn_implementation, text):
    model = AutoModelForMaskedLM.from_pretrained(
        model_id, attn_implementation=attn_implementation
    ).to("cuda")
    inputs = tokenizer(text, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    print(f"Sequence length : {inputs['input_ids'].shape}")

    # Warmup runs
    print(" Performing warmup runs...")
    for _ in range(num_warmup):
        with torch.no_grad():
            _ = model(**inputs)

    # Actual timing runs
    total_time = 0
    for i in range(num_repeats):
        torch.cuda.synchronize()  # Ensure previous run is complete
        start_time = time.time()
        with torch.no_grad():  # Added no_grad for inference
            outputs = model(**inputs)
        torch.cuda.synchronize()  # Ensure run is complete before timing
        end_time = time.time()
        total_time += end_time - start_time
        print(f" Run {i+1}: {end_time - start_time:.4f}s")
    return total_time / num_repeats


# Clear GPU memory before starting
torch.cuda.empty_cache()
gc.collect()

print("Starting benchmark with warmup...")
for attn_implementation in implementations:
    print(f"\nUsing attn_implementation={attn_implementation}")
    for text in texts:
        avg_time = time_model(attn_implementation, text)
        print(f" Average time: {avg_time:.4f} seconds")
        torch.cuda.empty_cache()
        gc.collect()
    print()
```

and this diff to your PR (`g diff` output):
```diff
diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py
index 8c2de6fec..a21a94a83 100644
--- a/src/transformers/models/modernbert/modeling_modernbert.py
+++ b/src/transformers/models/modernbert/modeling_modernbert.py
@@ -52,6 +52,8 @@ else:
 # NOTE : the ModernBERT flexattention implementation is not compatible with torch < 2.6
 if is_torch_flex_attn_available():
     from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention
+    create_block_mask = torch.compile(create_block_mask)
+    flex_attention = torch.compile(flex_attention)
 else:
     BlockMask, create_block_mask, flex_attention = object, object, object
 with
```

Gives me:

```
Starting benchmark with warmup...

Using attn_implementation=flex_attention
Sequence length : torch.Size([1, 702])
 Performing warmup runs...
 Run 1: 0.0264s
 Run 2: 0.0265s
 Run 3: 0.0255s
 Average time: 0.0261 seconds

Using attn_implementation=sdpa
Sequence length : torch.Size([1, 702])
 Performing warmup runs...
 Run 1: 0.0575s
 Run 2: 0.0593s
 Run 3: 0.0574s
 Average time: 0.0581 seconds

Using attn_implementation=eager
Sequence length : torch.Size([1, 702])
 Performing warmup runs...
 Run 1: 0.0620s
 Run 2: 0.0588s
 Run 3: 0.0625s
 Average time: 0.0611 seconds
```
Related PR: huggingface/transformers#35423
Repro gist: https://gist.github.com/staghado/c3688a51aadec9e0b63316d8a7227064
The implementation combines a sliding window mask with a document mask. The masks are created once for each input and re-used for subsequent layers.
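For reference, a minimal sketch of that combination using `and_masks` from `torch.nn.attention.flex_attention` (available in recent PyTorch releases); the `document_ids` tensor and the sizes below are placeholders, not the values used in the gist:

```python
import torch
from torch.nn.attention.flex_attention import and_masks, create_block_mask

WINDOW_SIZE = 64
SEQ_LEN = 4096
# Placeholder: one document id per packed token position.
document_ids = torch.zeros(SEQ_LEN, dtype=torch.long, device="cuda")


def sliding_window_mask(b, h, q_idx, kv_idx):
    return (q_idx - kv_idx).abs() <= WINDOW_SIZE


def document_mask(b, h, q_idx, kv_idx):
    return document_ids[q_idx] == document_ids[kv_idx]


# Build the combined block mask once per input and reuse it for every layer.
combined_mask = and_masks(sliding_window_mask, document_mask)
block_mask = create_block_mask(
    combined_mask, B=None, H=None, Q_LEN=SEQ_LEN, KV_LEN=SEQ_LEN, device="cuda"
)
```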
One thing that might be the issue is that the flex_attention function is not compiled in transformers.
I might be missing something; thanks in advance for your help.