
Not implemented: Non-trivial layouts unsupported #24415

Open
vanbasten23 opened this issue Oct 20, 2024 · 4 comments
Labels
bug (Something isn't working), pallas (Issues pertaining to Pallas (GPU or TPU))

Comments

@vanbasten23

Description

Hi. I am extending the Pallas paged attention kernel; the case is MQA (multi-query attention). When I run my kernel, I encounter the following error, which indicates an internal error and asks me to report it here.

root@t1v-n-f3643994-w-0:/workspaces/persist# rm -rf /workspaces/persist/tpu_logs && LIBTPU_INIT_ARGS="--xla_tpu_dump_logs_to_dir=/workspaces/persist/tpu_logs"  python pytorch/xla/test/test_pallas.py -v -k  PallasTest.test_extended_paged_attention_v1_multiple_queries 2>&1 | tee ~/out.txt
WARNING:root:libtpu.so and TPU device found. Setting PJRT_DEVICE=TPU.
test_extended_paged_attention_v1_multiple_queries (__main__.PallasTest) ... The test test_extended_paged_attention_multiple_queries begins with query_len=4
ERROR

======================================================================
ERROR: test_extended_paged_attention_v1_multiple_queries (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/jax/_src/compiler.py", line 266, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: Not implemented: Non-trivial layouts unsupported

at location: loc("/repeat"(callsite("_flash_attention"("/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py":180:0) at callsite("paged_flash_attention_kernel"("/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py":335:0) at callsite("paged_attention"("/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py":558:0) at callsite("test_extended_paged_attention_v1_multiple_queries"("/workspaces/persist/pytorch/xla/test/test_pallas.py":773:0) at "<module>"("/workspaces/persist/pytorch/xla/test/test_pallas.py":1669:0)))))))

The MLIR operation involved:
  %186 = "tpu.repeat"(%177) <{dimension = 1 : i32, times = 1 : i32}> {in_layout = [#tpu.vpad<"32,{0,0},(4,128)">], out_layout = [#tpu.vpad<"32,{0,0},(4,128)">]} : (vector<4x128xf32>) -> vector<4x128xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke


The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/workspaces/persist/pytorch/xla/test/test_pallas.py", line 773, in test_extended_paged_attention_v1_multiple_queries
    out = jax_extended_paged_attention1(
  File "/workspaces/persist/pytorch/xla/test/test_pallas.py", line 1669, in <module>
    test = unittest.main()
  File "/workspaces/persist/pytorch/xla/test/test_pallas.py", line 773, in test_extended_paged_attention_v1_multiple_queries
    out = jax_extended_paged_attention1(
  File "/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py", line 558, in paged_attention
    out = pl.pallas_call(
  File "/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py", line 335, in paged_flash_attention_kernel
    out_q_head_idx = _flash_attention(
  File "/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py", line 180, in _flash_attention
    acc_scratch_ref[:] *= pltpu.repeat(acc_scale, acc_scale_repeats, axis=1)
jax._src.pallas.mosaic.error_handling.MosaicError: INTERNAL: Mosaic failed to compile TPU kernel: Not implemented: Non-trivial layouts unsupported

The MLIR operation involved:
  %186 = "tpu.repeat"(%177) <{dimension = 1 : i32, times = 1 : i32}> {in_layout = [#tpu.vpad<"32,{0,0},(4,128)">], out_layout = [#tpu.vpad<"32,{0,0},(4,128)">]} : (vector<4x128xf32>) -> vector<4x128xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke


----------------------------------------------------------------------
Ran 1 test in 0.592s

FAILED (errors=1)

Here are my Pallas kernel and the test code that calls it.

Please let me know if you need more info.

cc @miladm @WoosukKwon

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.33.dev20240913
jaxlib: 0.4.33.dev20240913
numpy:  2.1.1
python: 3.10.15 (main, Sep 27 2024, 06:06:16) [GCC 10.2.1 20210110]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) ... TpuDevice(id=6, process_index=0, coords=(2,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(3,1,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-f3643994-w-0', release='5.19.0-1030-gcp', version='#32~22.04.1-Ubuntu SMP Thu Jul 13 09:36:23 UTC 2023', machine='x86_64')

@vanbasten23
Author

What I am trying to figure out is: in the flash attention kernel,

    m_scratch = pltpu.VMEM((block_b, 1, block_q, MIN_BLOCK_SIZE), jnp.float32)
    l_scratch = pltpu.VMEM((block_b, 1, block_q, MIN_BLOCK_SIZE), jnp.float32)

why is the last dimension MIN_BLOCK_SIZE? If we stick to the original FlashAttention v1 paper, it seems the shape should be (block_b, 1, block_q, 1), IIUC. I guess it has something to do with the Pallas internal constraint that:

Furthermore, the last two dimensions of your block shape must be equal to the respective dimension of the overall array, or be divisible by 8 and 128 respectively.
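
As a rough illustration of that constraint applied to the shapes above (a sketch only; block_b and block_q are hypothetical values, and MIN_BLOCK_SIZE is 128, the TPU lane width, IIUC):

    import jax.numpy as jnp
    from jax.experimental.pallas import tpu as pltpu

    MIN_BLOCK_SIZE = 128      # TPU lane width
    block_b, block_q = 1, 8   # hypothetical block sizes

    # A last dimension of 1, as in (block_b, 1, block_q, 1), is neither the full
    # array dimension nor divisible by 128, so that block shape would be rejected.
    # Padding the statistics out to a full lane vector satisfies the rule:
    m_scratch = pltpu.VMEM((block_b, 1, block_q, MIN_BLOCK_SIZE), jnp.float32)
    l_scratch = pltpu.VMEM((block_b, 1, block_q, MIN_BLOCK_SIZE), jnp.float32)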

@vanbasten23
Author

It seems the assignee is not set when I use the link https://github.com/google/jax/issues/new?assignees=apaszke from the error message to create the issue, so manually cc @apaszke.

@justinjfu
Collaborator

justinjfu commented Oct 21, 2024

Try using a broadcast (lax.broadcast_in_dim) instead of repeat - e.g. #23318

Repeat as currently implemented only works for very limited cases.
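
To make that concrete, here is a minimal sketch of the broadcast-based workaround. It assumes (as in the reference flash-attention kernel) that acc_scale holds one value per row, replicated across its MIN_BLOCK_SIZE lanes; the function and argument names are hypothetical.

    from jax import lax

    def apply_acc_scale(acc_scratch_ref, acc_scale, head_dim):
      # acc_scale: (block_q, MIN_BLOCK_SIZE), each row holding one replicated value.
      block_q = acc_scale.shape[0]
      per_row_scale = acc_scale[:, 0:1]          # (block_q, 1)
      # Broadcast across the head dimension instead of tiling with pltpu.repeat.
      scale = lax.broadcast_in_dim(
          per_row_scale,
          shape=(block_q, head_dim),
          broadcast_dimensions=(0, 1))           # (block_q, head_dim)
      acc_scratch_ref[:] *= scale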

@justinjfu added the pallas (Issues pertaining to Pallas (GPU or TPU)) label on Oct 21, 2024
@vanbasten23
Author

Thanks, Justin. Actually, I just figured out it can also be fixed by making the second-to-last dimension 8 instead of 4 (which is what I had used), due to the Pallas kernel constraint:

On TPU, only blocks with rank at least 1 are supported. Furthermore, the last two dimensions of your block shape must be equal to the respective dimension of the overall array, or be divisible by 8 and 128 respectively.
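
For illustration, a minimal sketch of that fix (the shapes and index map are hypothetical; the point is that the trailing block dimensions are divisible by 8 and 128 respectively):

    from jax.experimental import pallas as pl

    # A block shape of (4, 128) triggered "Non-trivial layouts unsupported";
    # rounding the second-to-last dimension up to 8 satisfies the tiling rule.
    q_block_spec = pl.BlockSpec(
        block_shape=(8, 128),         # divisible by 8 and 128 respectively
        index_map=lambda i: (i, 0),   # hypothetical index map
    )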
