Hi. I am extending the Pallas paged attention kernel to the multi-query attention (MQA) case. When I run my kernel, I encounter the following error, which says it is an internal error and that I should report it here.
======================================================================
ERROR: test_extended_paged_attention_v1_multiple_queries (__main__.PallasTest)
----------------------------------------------------------------------
Traceback (most recent call last):
File "/usr/local/lib/python3.10/site-packages/jax/_src/compiler.py", line 266, in backend_compile
return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Mosaic failed to compile TPU kernel: unsupported shape cast
at location: loc("/swap"(callsite("_flash_attention"("/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py":188:0) at callsite("paged_flash_attention_kernel"("/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py":331:0) at callsite("paged_attention"("/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py":547:0) at callsite("test_extended_paged_attention_v1_multiple_queries"("/workspaces/persist/pytorch/xla/test/test_pallas.py":773:0) at "<module>"("/workspaces/persist/pytorch/xla/test/test_pallas.py":1669:0)))))))
The MLIR operation involved:
%61 = "vector.shape_cast"(%60) : (vector<4x128xf32>) -> vector<1x4x1x128xf32>
Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
The above exception was the direct cause of the following exception:
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "/workspaces/persist/pytorch/xla/test/test_pallas.py", line 773, in test_extended_paged_attention_v1_multiple_queries
out = jax_extended_paged_attention1(
File "/workspaces/persist/pytorch/xla/test/test_pallas.py", line 1669, in <module>
test = unittest.main()
File "/workspaces/persist/pytorch/xla/test/test_pallas.py", line 773, in test_extended_paged_attention_v1_multiple_queries
out = jax_extended_paged_attention1(
File "/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py", line 547, in paged_attention
out = pl.pallas_call(
File "/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py", line 331, in paged_flash_attention_kernel
_flash_attention(
File "/workspaces/persist/pytorch/xla/torch_xla/experimental/pallas_kernels/extended_paged_attention_kernel1.py", line 188, in _flash_attention
o_ref[:, q_head_idx, :] = acc_scratch_ref[:].astype(o_ref.dtype)
jax._src.pallas.mosaic.error_handling.MosaicError: INTERNAL: Mosaic failed to compile TPU kernel: unsupported shape cast
The MLIR operation involved:
%61 = "vector.shape_cast"(%60) : (vector<4x128xf32>) -> vector<1x4x1x128xf32>
Please report a bug at: https://github.com/google/jax/issues/new?assignees=apaszke
----------------------------------------------------------------------
Ran 1 test in 0.607s
FAILED (errors=1)
The problematic line is o_ref[:, q_head_idx, :] = acc_scratch_ref[:].astype(o_ref.dtype). I found a way to work around the problem (the code is in #24415), but I'm trying to figure out why the flash attention example, which does something similar, works fine.
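For reference, here is a minimal sketch of just that store pattern (the 4 and 128 come from the vector<4x128xf32> -> vector<1x4x1x128xf32> cast in the traceback; the kernel body, names, and the two-head output shape are illustrative, not the actual kernel or the #24415 workaround):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

num_queries, head_dim = 4, 128  # from the 4x128 vector in the MLIR error

def store_kernel(acc_ref, o_ref):
  # Unrolled loop over query heads, as the MQA kernel does for the query
  # heads that share a single KV head.
  for q_head_idx in range(o_ref.shape[1]):
    # The store in question: the integer index sits in the middle axis, so
    # on TPU the (4, 128) value has to be cast into a vector with a
    # singleton inside the tiled trailing dimensions.
    o_ref[:, q_head_idx, :] = acc_ref[...].astype(o_ref.dtype)

out = pl.pallas_call(
    store_kernel,
    out_shape=jax.ShapeDtypeStruct((num_queries, 2, head_dim), jnp.float32),
    interpret=True,  # interpret mode runs without a TPU; on the TPU backend
                     # this store is the pattern that produced the error above
)(jnp.ones((num_queries, head_dim), jnp.float32))
print(out.shape)  # (4, 2, 128)
```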
Because of the tiling, it is generally more efficient to leave singleton dimensions in front rather than in the last two dimensions, if you can afford to do so. For example, reshaping from 4x128 to 4x1x128 requires four copy operations to copy each row of the original VREG into the first row of four new VREGs, whereas reshaping from 4x128 to 1x4x128 is effectively "free", since it just adds an extra logical dimension in front that can be handled at compile time.
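To make that concrete, here is a hedged sketch (shapes and names are illustrative) of the layout this suggests: with the head axis leading in the output block, the per-head store keeps the trailing (num_queries, head_dim) = (4, 128) tile intact and only involves leading dimensions:

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

num_q_heads, num_queries, head_dim = 2, 4, 128

def store_kernel(acc_ref, o_ref):
  # o_ref is (num_q_heads, num_queries, head_dim): indexing the leading
  # head axis leaves the last two dimensions as the full (4, 128) VREG
  # tile, i.e. the "free" leading-singleton case described above.
  for q_head_idx in range(o_ref.shape[0]):
    o_ref[q_head_idx, :, :] = acc_ref[...].astype(o_ref.dtype)

out = pl.pallas_call(
    store_kernel,
    out_shape=jax.ShapeDtypeStruct((num_q_heads, num_queries, head_dim),
                                   jnp.float32),
    interpret=True,  # sketch runs in interpret mode without a TPU
)(jnp.ones((num_queries, head_dim), jnp.float32))
print(out.shape)  # (2, 4, 128)
```

If downstream code needs the (num_queries, num_q_heads, head_dim) layout, the output can be transposed outside the kernel.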
Here is my Pallas kernel and the test code that calls the kernel. Please let me know if you need more info.
System info (python version, jaxlib version, accelerator, etc.)
cc: @miladm @WoosukKwon