INTERNAL: Mosaic failed to compile TPU kernel: unsupported shape cast #24404

Open

vanbasten23 opened this issue Oct 19, 2024 · 4 comments

Labels: bug (Something isn't working), pallas (Issues pertaining to Pallas (GPU or TPU))

Comments

vanbasten23 commented Oct 19, 2024

Description

Hi. I am extending the Pallas paged attention kernel; the use case is MQA (multi-query attention). When I run my kernel, I hit the following error, which says it is an internal error and asks me to report it here.

======================================================================
ERROR: test_extended_paged_attention_v1_multiple_queries (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/site-packages/jax/_src/compiler.py", line 266, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: unsupported shape cast

at location: loc("/swap"(callsite("_flash_attention"("/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py":188:0) at callsite("paged_flash_attention_kernel"("/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py":331:0) at callsite("paged_attention"("/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py":547:0) at callsite("test_extended_paged_attention_v1_multiple_queries"("/workspaces/persist/pytorch/xla/test/test_pallas.py":773:0) at "<module>"("/workspaces/persist/pytorch/xla/test/test_pallas.py":1669:0)))))))

The MLIR operation involved:
  %61 = "vector.shape_cast"(%60) : (vector<4x128xf32>) -> vector<1x4x1x128xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke


The above exception was the direct cause of the following exception:

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/workspaces/persist/pytorch/xla/test/test_pallas.py", line 773, in test_extended_paged_attention_v1_multiple_queries
    out = jax_extended_paged_attention1(
  File "/workspaces/persist/pytorch/xla/test/test_pallas.py", line 1669, in <module>
    test = unittest.main()
  File "/workspaces/persist/pytorch/xla/test/test_pallas.py", line 773, in test_extended_paged_attention_v1_multiple_queries
    out = jax_extended_paged_attention1(
  File "/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py", line 547, in paged_attention
    out = pl.pallas_call(
  File "/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py", line 331, in paged_flash_attention_kernel
    _flash_attention(
  File "/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py", line 188, in _flash_attention
    o_ref[:, q_head_idx, :] = acc_scratch_ref[:].astype(o_ref.dtype)
jax._src.pallas.mosaic.error_handling.MosaicError: INTERNAL: Mosaic failed to compile TPU kernel: unsupported shape cast

The MLIR operation involved:
  %61 = "vector.shape_cast"(%60) : (vector<4x128xf32>) -> vector<1x4x1x128xf32>

Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke


----------------------------------------------------------------------
Ran 1 test in 0.607s

FAILED (errors=1)

Here is my Pallas kernel and the test code that calls it.

Please let me know if you need more info.

System info (python version, jaxlib version, accelerator, etc.)

>>> import jax; jax.print_environment_info()
jax:    0.4.33.dev20240913
jaxlib: 0.4.33.dev20240913
numpy:  2.1.1
python: 3.10.15 (main, Sep 27 2024, 06:06:16) [GCC 10.2.1 20210110]
jax.devices (8 total, 8 local): [TpuDevice(id=0, process_index=0, coords=(0,0,0), core_on_chip=0) TpuDevice(id=1, process_index=0, coords=(1,0,0), core_on_chip=0) ... TpuDevice(id=6, process_index=0, coords=(2,1,0), core_on_chip=0) TpuDevice(id=7, process_index=0, coords=(3,1,0), core_on_chip=0)]
process_count: 1
platform: uname_result(system='Linux', node='t1v-n-f3643994-w-0', release='5.19.0-1030-gcp', version='#32~22.04.1-Ubuntu SMP Thu Jul 13 09:36:23 UTC 2023', machine='x86_64')

cc: @miladm @WoosukKwon

vanbasten23 (Author) commented

The problematic line is `o_ref[:, q_head_idx, :] = acc_scratch_ref[:].astype(o_ref.dtype)`. I found a way to work around the problem (the code is in #24415), but I'm still trying to figure out why the flash attention example does something similar and works fine.

vanbasten23 (Author) commented

It seems the assignee was not set when I used the link https://github.com/google/jax/issues/new?assignees=apaszke from the error message to create this issue, so manually cc'ing @apaszke.

justinjfu (Collaborator) commented

#22938 should in principle address this. It was checked in on Sep 20, which is newer than the build you are running (0.4.33.dev20240913), so you would need a newer build to pick it up.

Some explanation of the error: the last two dimensions of an array are special because they are physically tiled into VREGs (which is also the reason for the special 8x128 block size noted here: https://jax.readthedocs.io/en/latest/pallas/tpu/details.html#blockspecs-and-grid-iteration), so certain reshapes require additional work under the hood.

Because of the tiling, it is generally more efficient to keep singleton dimensions in front rather than in the last two dimensions if you can afford to do so. For example, reshaping from 4x128 to 4x1x128 requires 4 copy operations, one to copy each row of the original VREG into the first row of 4 new VREGs, whereas reshaping from 4x128 to 1x4x128 is effectively free: it just adds an extra logical dimension in front that can be handled at compile time.

vanbasten23 (Author) commented

Thanks Justin for the explanation!
