Illegal memory access on backward when there are unused block masks (nightly build) #96

Open
timt51 opened this issue Dec 28, 2024 · 2 comments

timt51 commented Dec 28, 2024

FlexAttention backward can fail with "RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered" when one block mask is created but never used in a FlexAttention call, and a second block mask is then created and used in one.

Script to reproduce:

import argparse
from typing import cast

import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

torch.set_default_dtype(torch.bfloat16)
torch.set_default_device("cuda")
create_block_mask = torch.compile(create_block_mask)
flex_attention = torch.compile(flex_attention, dynamic=False)


def mask_mod(b, h, q_idx, kv_idx):
    return q_idx >= 0


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--skip_first_block_mask", action="store_true")
    args = parser.parse_args()

    # define problem size
    B, H, D = 1, 1, 64
    for i, S in enumerate([712, 1478]):
        # create block mask
        block_mask = create_block_mask(mask_mod=mask_mod, B=1, H=1, Q_LEN=S, KV_LEN=S)
        if i == 0 and args.skip_first_block_mask:
            continue

        # run forward and backward pass
        q = torch.rand(B, H, S, D, requires_grad=True)
        k = torch.rand(B, H, S, D, requires_grad=True)
        v = torch.rand(B, H, S, D, requires_grad=True)
        grad_out = torch.rand(B, H, S, D)
        flex_out = flex_attention(q, k, v, score_mod=None, block_mask=block_mask)
        flex_out = cast(torch.Tensor, flex_out)
        flex_out.backward(grad_out)


if __name__ == "__main__":
    main()

The script is run on an A100 GPU with the environment variable TORCHINDUCTOR_FORCE_DISABLE_CACHES=1. It fails when the --skip_first_block_mask flag is set and succeeds otherwise. It always succeeds if create_block_mask is not compiled, or if it is compiled with dynamic=False.
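
The two modes differ only in which calls receive which sequence lengths. As a torch-free sketch (trace_calls is a hypothetical helper that mirrors the loop in the script and does not call PyTorch), the sizes reaching each compiled function are:

```python
def trace_calls(skip_first_block_mask: bool, sizes=(712, 1478)):
    """Mirror the repro loop, recording the sequence length S passed to each compiled function."""
    calls = {"create_block_mask": [], "flex_attention": []}
    for i, S in enumerate(sizes):
        calls["create_block_mask"].append(S)  # a block mask is always built
        if i == 0 and skip_first_block_mask:
            continue  # the first mask is built but never used
        calls["flex_attention"].append(S)  # forward and backward run with this mask
    return calls


failing = trace_calls(skip_first_block_mask=True)
passing = trace_calls(skip_first_block_mask=False)
print("with --skip_first_block_mask:   ", failing)
print("without --skip_first_block_mask:", passing)
```

With --skip_first_block_mask, create_block_mask sees both lengths (712, then 1478) while flex_attention sees only 1478; without it, both functions see both lengths. Per the torch.compile documentation, the default dynamic=None recompiles with dynamic shapes once a changed shape is detected, so in both runs the 1478 mask is likely produced by a dynamic-shape recompile of create_block_mask. What differs is that in the failing run this mask reaches flex_attention's first and only compilation. Compiling create_block_mask with dynamic=False, which the report says avoids the error, rules out that dynamic recompile. This is an observation about the repro, not a confirmed root cause.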

The issue was observed with torch==2.6.0.dev20241228. It was not observed with torch==2.5.1.

Stack trace:

Traceback (most recent call last):
  File "/home/ttruong/code/attention-gym/examples/nested_fail.py", line 41, in <module>
    main()
  File "/home/ttruong/code/attention-gym/examples/nested_fail.py", line 37, in main
    flex_out.backward(grad_out)                                                                                                                                                                                        
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_tensor.py", line 639, in backward                                                                                                   
    return handle_torch_function(
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/overrides.py", line 1720, in handle_torch_function
    result = mode.__torch_function__(public_api, types, args, kwargs)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/utils/_device.py", line 104, in __torch_function__
    return func(*args, **kwargs)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_tensor.py", line 648, in backward
    torch.autograd.backward(
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1958, in backward
    return impl_fn()
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1944, in impl_fn
    out = CompiledFunction._backward_impl(ctx, all_args)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2079, in _backward_impl
    out = call_func_at_runtime_with_args(
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 755, in _fn
    return fn(*args, **kwargs)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/output_code.py", line 465, in __call__
    return self.current_callable(inputs)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/utils.py", line 2196, in run
    return model(new_inputs)
  File "/tmp/torchinductor_ttruong/tmpbe4pdvcp/f5/cf5iwxj2ahgdeei6lzukpi2sr67mpw3sucjttx3ut7pnus6x2x4o.py", line 914, in call
    triton_per_fused_zeros_0.run(getitem, tangents_1, buf1, 1478, 64, grid=grid(1478), stream=stream0)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 918, in run
    self.autotune_to_one_config(*args, grid=grid, **kwargs)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 795, in autotune_to_one_config
    timings = self.benchmark_all_configs(*args, **kwargs)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 769, in benchmark_all_configs
    timings = {
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 770, in <dictcomp>
    launcher: self.bench(launcher, *args, **kwargs)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 666, in bench
    return benchmarker.benchmark_gpu(kernel_call, rep=40)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/runtime/benchmarking.py", line 66, in wrapper
    return fn(self, *args, **kwargs)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/runtime/benchmarking.py", line 202, in benchmark_gpu
    return self.triton_do_bench(_callable, **kwargs, return_mode="median")
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/triton/testing.py", line 118, in do_bench
    di.synchronize()
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/cuda/__init__.py", line 987, in synchronize
    return torch._C._cuda_synchronize()
RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1 
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Stack trace with CUDA_LAUNCH_BLOCKING=1:

Traceback (most recent call last):
  File "/home/ttruong/code/attention-gym/examples/nested_fail.py", line 41, in <module>
    main()
  File "/home/ttruong/code/attention-gym/examples/nested_fail.py", line 37, in main
    flex_out.backward(grad_out)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_tensor.py", line 639, in backward
    return handle_torch_function(
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/overrides.py", line 1720, in handle_torch_function
    result = mode.__torch_function__(public_api, types, args, kwargs)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/utils/_device.py", line 104, in __torch_function__
    return func(*args, **kwargs)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_tensor.py", line 648, in backward
    torch.autograd.backward(
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/autograd/graph.py", line 823, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1958, in backward
    return impl_fn()
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 1944, in impl_fn
    out = CompiledFunction._backward_impl(ctx, all_args)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/runtime_wrappers.py", line 2079, in _backward_impl
    out = call_func_at_runtime_with_args(
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_functorch/_aot_autograd/utils.py", line 126, in call_func_at_runtime_with_args
    out = normalize_as_list(f(args))
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py", line 755, in _fn
    return fn(*args, **kwargs)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/output_code.py", line 465, in __call__
    return self.current_callable(inputs)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/utils.py", line 2196, in run
    return model(new_inputs)
  File "/tmp/torchinductor_ttruong/tmp4_acgb21/uq/cuqj5gwwrinhvkoezg5w6nbbi2trkgz7qn22ykn6f5sx6ze76o5a.py", line 914, in call
    triton_per_fused_zeros_0.run(getitem, tangents_1, buf1, 1478, 64, grid=grid(1478), stream=stream0)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 918, in run
    self.autotune_to_one_config(*args, grid=grid, **kwargs)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 795, in autotune_to_one_config
    timings = self.benchmark_all_configs(*args, **kwargs)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 769, in benchmark_all_configs
    timings = {
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 770, in <dictcomp>
    launcher: self.bench(launcher, *args, **kwargs)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 666, in bench
    return benchmarker.benchmark_gpu(kernel_call, rep=40)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/runtime/benchmarking.py", line 66, in wrapper
    return fn(self, *args, **kwargs)
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/runtime/benchmarking.py", line 202, in benchmark_gpu
    return self.triton_do_bench(_callable, **kwargs, return_mode="median")
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/triton/testing.py", line 117, in do_bench
    fn()
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/torch/_inductor/runtime/triton_heuristics.py", line 650, in kernel_call
    launcher(
  File "<string>", line 6, in launcher
  File "/home/ttruong/code/attention-gym/.venv/lib/python3.10/site-packages/triton/backends/nvidia/driver.py", line 435, in __call__
    self.launch(*args, **kwargs)
RuntimeError: Triton Error [CUDA]: an illegal memory access was encountered

Output of pip freeze:

filelock==3.16.1
fsspec==2024.10.0
Jinja2==3.1.4
MarkupSafe==2.1.5
mpmath==1.3.0
networkx==3.4.2
numpy==2.1.2
nvidia-cublas-cu11==11.11.3.6
nvidia-cuda-cupti-cu11==11.8.87
nvidia-cuda-nvrtc-cu11==11.8.89
nvidia-cuda-runtime-cu11==11.8.89
nvidia-cudnn-cu11==9.1.0.70
nvidia-cufft-cu11==10.9.0.58
nvidia-curand-cu11==10.3.0.86
nvidia-cusolver-cu11==11.4.1.48
nvidia-cusparse-cu11==11.7.5.86
nvidia-nccl-cu11==2.21.5
nvidia-nvtx-cu11==11.8.86
pytorch-triton==3.2.0+git0d4682f0
sympy==1.13.1
torch==2.6.0.dev20241228+cu118
typing_extensions==4.12.2

drisspg (Contributor) commented Dec 30, 2024

Will take a look

timt51 (Author) commented Jan 24, 2025

The script no longer reproduces the issue for me on the latest PyTorch nightly. The last nightly on which it fails for me is 2.7.0.dev20250118+cu118.
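
For readers checking their own environment, here is a small sketch that extracts the date stamp from a nightly version string and tests whether it falls between the first (2.6.0.dev20241228) and last (2.7.0.dev20250118) nightlies reported failing in this thread. The names nightly_date and in_reported_window are hypothetical. A nightly inside the window is not guaranteed to fail, and one outside it is not known to be safe; it simply was not reported here.

```python
import re
from datetime import date

# First and last nightly dates at which the failure was reported in this issue.
FIRST_SEEN = date(2024, 12, 28)  # torch==2.6.0.dev20241228
LAST_SEEN = date(2025, 1, 18)    # torch==2.7.0.dev20250118+cu118


def nightly_date(version: str):
    """Return the date stamped in a nightly version string, or None for release builds."""
    m = re.search(r"\.dev(\d{4})(\d{2})(\d{2})", version)
    if m is None:
        return None  # e.g. "2.5.1", which the reporter found unaffected
    year, month, day = (int(g) for g in m.groups())
    return date(year, month, day)


def in_reported_window(version: str) -> bool:
    """True if the nightly falls within the range reported failing in this issue."""
    d = nightly_date(version)
    return d is not None and FIRST_SEEN <= d <= LAST_SEEN


for v in ["2.5.1", "2.6.0.dev20241228+cu118", "2.7.0.dev20250118+cu118"]:
    print(v, in_reported_window(v))
```

On an installed build, passing torch.__version__ to in_reported_window gives the same check.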
