FlexAttention slower than eager in HF transformers #95

Open
staghado opened this issue Dec 27, 2024 · 5 comments

@staghado

Related PR: huggingface/transformers#35423
Repro gist: https://gist.github.com/staghado/c3688a51aadec9e0b63316d8a7227064

The implementation combines a sliding window mask with a document mask. The masks are created once per input and reused across subsequent layers (a minimal sketch of this construction is included after the timings below).
One thing that might be the issue is that the flex_attention function is not compiled in transformers.
I might be missing something; thanks in advance for your help.

Using attn_implementation=flex_attention
Sequence length : torch.Size([1, 702])
  Time taken: 0.7820 seconds

Using attn_implementation=sdpa
Sequence length : torch.Size([1, 702])
  Time taken: 0.0748 seconds

Using attn_implementation=eager
Sequence length : torch.Size([1, 702])
  Time taken: 0.0679 seconds
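
For reference, a minimal sketch of the mask construction in isolation (the document_ids tensor, window size, and shapes below are made up for illustration, and the sliding window here uses global positions rather than per-document offsets for brevity):

import torch
from torch.nn.attention.flex_attention import and_masks, create_block_mask, flex_attention

WINDOW = 128  # sliding-window radius (example value)
# one packed input containing a single document; real inputs would have varying ids
document_ids = torch.zeros(702, dtype=torch.long, device="cuda")

def sliding_window(b, h, q_idx, kv_idx):
    return (q_idx - kv_idx).abs() <= WINDOW

def same_document(b, h, q_idx, kv_idx):
    return document_ids[q_idx] == document_ids[kv_idx]

# built once per input and reused for every layer
block_mask = create_block_mask(
    and_masks(sliding_window, same_document),
    B=None, H=None, Q_LEN=702, KV_LEN=702, device="cuda",
)

# compiling flex_attention is what makes it competitive
compiled_flex = torch.compile(flex_attention)
q = k = v = torch.randn(1, 12, 702, 64, device="cuda", dtype=torch.bfloat16)
out = compiled_flex(q, k, v, block_mask=block_mask)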
@staghado
Author

When running FlexAttention vs. SDPA alone (with compile), I get:

Torch version: 2.6.0.dev20241112+cu121

=== Benchmark Results ===
+--------------+--------------+----------------------+----------------------+
|   Batch Size |   Seq Length |   FLEX Avg Time (ms) |   SDPA Avg Time (ms) |
+==============+==============+======================+======================+
|            1 |          128 |      19157.4         |      18323.5         |
+--------------+--------------+----------------------+----------------------+
|            1 |          256 |      29308.9         |      26515.9         |
+--------------+--------------+----------------------+----------------------+
|            1 |          512 |      42290.9         |      43449.5         |
+--------------+--------------+----------------------+----------------------+
|            1 |         1024 |      47303.1         |      85003.2         |
+--------------+--------------+----------------------+----------------------+
|            1 |         2048 |      89719.6         |     221348           |
+--------------+--------------+----------------------+----------------------+
|            1 |         4096 |     170735           |     842239           |
+--------------+--------------+----------------------+----------------------+
|            1 |         8192 |     331801           |          3.34551e+06 |
+--------------+--------------+----------------------+----------------------+
|            2 |         8192 |     645975           |          6.44274e+06 |
+--------------+--------------+----------------------+----------------------+
|            4 |         8192 |          1.26423e+06 |          1.24122e+07 |
+--------------+--------------+----------------------+----------------------+
|            4 |         8192 |          1.26883e+06 |          1.24035e+07 |
+--------------+--------------+----------------------+----------------------+
import torch
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
from tabulate import tabulate

from torch.nn.attention.flex_attention import (
    flex_attention,
    create_block_mask,
    create_mask,
)

torch._dynamo.config.cache_size_limit = 1000
flex_attention = torch.compile(flex_attention, dynamic=False)
print(f"Torch version: {torch.__version__}")

from torch._inductor.utils import do_bench_using_profiling
from typing import Callable
def benchmark_cuda_function_in_microseconds(func: Callable, *args, **kwargs) -> float:
    """Thin wrapper around do_bench_using_profiling"""
    no_args = lambda: func(*args, **kwargs)
    time = do_bench_using_profiling(no_args)
    return time * 1e3

benchmark_fn = benchmark_cuda_function_in_microseconds

WINDOW_SIZE = 64

def generate_block_mask(sequence_ids, cu_seqlens, WINDOW_SIZE):
    def sliding_window_seq_mask_mod(b, h, q_idx, kv_idx):
        # only allow attention within the same sequence
        same_seq = sequence_ids[q_idx] == sequence_ids[kv_idx]

        # get position within the sequence
        q_pos = q_idx - cu_seqlens[sequence_ids[q_idx]]
        kv_pos = kv_idx - cu_seqlens[sequence_ids[kv_idx]]

        # sliding window within each sequence
        in_window = (q_pos - kv_pos).abs() <= WINDOW_SIZE

        return same_seq & in_window
    return sliding_window_seq_mask_mod

def SWA_mask(b, h, q_idx, kv_idx):
    # sliding window within each sequence
    in_window = (q_idx - kv_idx).abs() <= WINDOW_SIZE
    return in_window

# Benchmarking function
def run_benchmark(batch_sizes, sequence_lengths, num_heads=16, hidden_dim=64, n_runs=3):
    results = []

    for batch_size in batch_sizes:
        for seq_len in sequence_lengths:
            q = torch.randn(
                batch_size, num_heads, seq_len, hidden_dim, dtype=torch.bfloat16
            ).to("cuda")
            k = torch.randn(
                batch_size, num_heads, seq_len, hidden_dim, dtype=torch.bfloat16
            ).to("cuda")
            v = torch.randn(
                batch_size, num_heads, seq_len, hidden_dim, dtype=torch.bfloat16
            ).to("cuda")

            
            # use a separate name here: rebinding `sequence_lengths` would
            # clobber the outer loop's list on the next batch size
            per_sample_lengths = [seq_len] * batch_size
            sequence_ids = torch.cat(
                [torch.full((length,), i, dtype=torch.long) for i, length in enumerate(per_sample_lengths)]
            ).to("cuda")
            _, counts = torch.unique_consecutive(sequence_ids, return_counts=True)
            cu_seqlens = torch.cat([torch.tensor([0], device=sequence_ids.device), counts.cumsum(0)])

            block_mask = create_block_mask(
                generate_block_mask(sequence_ids, cu_seqlens, WINDOW_SIZE),
                B=None,
                H=None,
                Q_LEN=int(cu_seqlens[-1]),
                KV_LEN=int(cu_seqlens[-1]),
                device="cuda",
            )
            mask = create_mask(SWA_mask, None, None, seq_len, seq_len, device="cuda")

            # Benchmark flex_attention
            flex_times = []
            for _ in range(n_runs):
                flex_time = benchmark_fn(
                    flex_attention,
                    q.reshape(1, num_heads, -1, hidden_dim),
                    k.reshape(1, num_heads, -1, hidden_dim),
                    v.reshape(1, num_heads, -1, hidden_dim),
                    score_mod=None,
                    block_mask=block_mask,
                )
                flex_times.append(flex_time)
            flex_avg_time = (sum(flex_times) / n_runs) * 1000  # note: benchmark_fn already returns microseconds

            # Benchmark scaled_dot_product_attention with mask
            sdpa_times = []
            with sdpa_kernel([SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]):
                for _ in range(n_runs):
                    sdpa_time = benchmark_fn(
                        scaled_dot_product_attention,
                        q,
                        k,
                        v,
                        attn_mask=mask,
                    )
                    sdpa_times.append(sdpa_time)
                sdpa_avg_time = (sum(sdpa_times) / n_runs) * 1000  # note: benchmark_fn already returns microseconds

            results.append(
                {
                    "Batch Size": batch_size,
                    "Seq Length": seq_len,
                    "FLEX Avg Time (ms)": f"{flex_avg_time:.2f}",
                    "SDPA Avg Time (ms)": f"{sdpa_avg_time:.2f}",
                }
            )

    return results


if __name__ == "__main__":
    batch_sizes = [
        1,
        2,
        4,
    ]
    sequence_lengths = [128, 256, 512, 1024, 2048, 4096, 8192]
    n_runs = 5

    results = run_benchmark(batch_sizes, sequence_lengths, n_runs=n_runs)

    # Generate table
    print("\n=== Benchmark Results ===")
    print(tabulate(results, headers="keys", tablefmt="grid"))

So my question is how to cleanly integrate

torch._dynamo.config.cache_size_limit = 1000
flex_attention = torch.compile(flex_attention, dynamic=False)

into transformers?
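
The pattern I had in mind is something like this (a sketch only; get_compiled_flex_attention is a hypothetical helper, not an existing transformers API):

import torch
from torch.nn.attention.flex_attention import flex_attention

_compiled_flex_attention = None

def get_compiled_flex_attention():
    # compile lazily on first use so that importing the module stays cheap,
    # and keep a single compiled callable shared by all layers
    global _compiled_flex_attention
    if _compiled_flex_attention is None:
        torch._dynamo.config.cache_size_limit = 1000
        _compiled_flex_attention = torch.compile(flex_attention, dynamic=False)
    return _compiled_flex_attention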

@drisspg
Contributor

drisspg commented Dec 27, 2024

Without a doubt it will be slower than eager when it is not compiled. Let me ping some HF folks to see if we can raise a warning / ensure it is easy to compile.

@staghado
Author

staghado commented Jan 9, 2025

Is there a way to make the above code work with the latest PyTorch release? I'm not sure, but it seems we need the latest nightly because the default BLOCK_SIZE changed. Will it work if we specify it manually?
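
Concretely, I mean pinning the block size when building the mask, reusing the names from the benchmark script above (128 is just an example value, and I'm assuming this is the right knob to match the older default):

block_mask = create_block_mask(
    generate_block_mask(sequence_ids, cu_seqlens, WINDOW_SIZE),
    B=None,
    H=None,
    Q_LEN=int(cu_seqlens[-1]),
    KV_LEN=int(cu_seqlens[-1]),
    device="cuda",
    BLOCK_SIZE=128,  # pass explicitly instead of relying on the default
)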

@drisspg
Contributor

drisspg commented Jan 10, 2025

Taking a look now

@drisspg
Contributor

drisspg commented Jan 10, 2025

I just installed the 2.6 RC via
pip3 install torch==2.6.0 --index-url https://download.pytorch.org/whl/test/cu124

Using this script

import time
import torch
import gc
from transformers import AutoTokenizer, AutoModelForMaskedLM

torch.set_float32_matmul_precision("high")
model_id = "answerdotai/ModernBERT-base"
tokenizer = AutoTokenizer.from_pretrained(model_id)

texts = [
    "The capital of France is [MASK]." * 100,
    # "The largest city in Canada is [MASK]."*200,
    # "The currency of Japan is [MASK]."*300,
    # "The highest mountain in the world is [MASK]."*500
]

implementations = ["flex_attention", "sdpa", "eager"]
num_repeats = 3
num_warmup = 2  # Number of warmup runs before timing


def time_model(attn_implementation, text):
    model = AutoModelForMaskedLM.from_pretrained(
        model_id, attn_implementation=attn_implementation
    ).to("cuda")
    inputs = tokenizer(text, return_tensors="pt")
    inputs = {k: v.to(model.device) for k, v in inputs.items()}
    print(f"Sequence length : {inputs['input_ids'].shape}")

    # Warmup runs
    print("  Performing warmup runs...")
    for _ in range(num_warmup):
        with torch.no_grad():
            _ = model(**inputs)

    # Actual timing runs
    total_time = 0
    for i in range(num_repeats):
        torch.cuda.synchronize()  # Ensure previous run is complete
        start_time = time.time()
        with torch.no_grad():  # Added no_grad for inference
            outputs = model(**inputs)
        torch.cuda.synchronize()  # Ensure run is complete before timing
        end_time = time.time()
        total_time += end_time - start_time
        print(f"    Run {i+1}: {end_time - start_time:.4f}s")

    return total_time / num_repeats


# Clear GPU memory before starting
torch.cuda.empty_cache()
gc.collect()

print("Starting benchmark with warmup...")
for attn_implementation in implementations:
    print(f"\nUsing attn_implementation={attn_implementation}")
    for text in texts:
        avg_time = time_model(attn_implementation, text)
        print(f"  Average time: {avg_time:.4f} seconds")
    torch.cuda.empty_cache()
    gc.collect()
    print()

and this diff to your PR

❯ git diff
diff --git a/src/transformers/models/modernbert/modeling_modernbert.py b/src/transformers/models/modernbert/modeling_modernbert.py
index 8c2de6fec..a21a94a83 100644
--- a/src/transformers/models/modernbert/modeling_modernbert.py
+++ b/src/transformers/models/modernbert/modeling_modernbert.py
@@ -52,6 +52,8 @@ else:
 # NOTE : the ModernBERT flexattention implementation is not compatible with torch < 2.6
 if is_torch_flex_attn_available():
     from torch.nn.attention.flex_attention import BlockMask, create_block_mask, flex_attention
+    create_block_mask = torch.compile(create_block_mask)
+    flex_attention = torch.compile(flex_attention)
 else:
     BlockMask, create_block_mask, flex_attention = object, object, object
 

gives me:

Starting benchmark with warmup...

Using attn_implementation=flex_attention
Sequence length : torch.Size([1, 702])
  Performing warmup runs...
    Run 1: 0.0264s
    Run 2: 0.0265s
    Run 3: 0.0255s
  Average time: 0.0261 seconds


Using attn_implementation=sdpa
Sequence length : torch.Size([1, 702])
  Performing warmup runs...
    Run 1: 0.0575s
    Run 2: 0.0593s
    Run 3: 0.0574s
  Average time: 0.0581 seconds


Using attn_implementation=eager
Sequence length : torch.Size([1, 702])
  Performing warmup runs...
    Run 1: 0.0620s
    Run 2: 0.0588s
    Run 3: 0.0625s
  Average time: 0.0611 seconds
