Not implemented: Non-trivial layouts unsupported #24415
Comments
What I am trying to figure out is, in the flash attention, why is the last dimension
It seems the assignee is not set when I use the link https://github.com/google/jax/issues/new?assignees=apaszke from the error message to create the issue, so manually cc @apaszke.
Try using a broadcast instead; Repeat as currently implemented only works for very limited cases.
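The suggestion above can be sketched as follows. This is a hedged illustration with hypothetical MQA shapes (one shared KV head expanded across query heads); NumPy is used here, and `jnp.broadcast_to` in JAX behaves the same way:

```python
import numpy as np  # jnp.broadcast_to mirrors np.broadcast_to

# Hypothetical MQA shapes: one shared KV head, 8 query heads.
num_q_heads, seq_len, head_dim = 8, 16, 128
k = np.ones((1, seq_len, head_dim), np.float32)  # single KV head

# Instead of np.repeat(k, num_q_heads, axis=0), which materializes copies
# (and, per the comment above, is only supported in limited cases), a
# broadcast produces a view with the target shape:
k_b = np.broadcast_to(k, (num_q_heads, seq_len, head_dim))
print(k_b.shape)  # (8, 16, 128)
```

A broadcast also avoids allocating `num_q_heads` physical copies of the KV data, which matters inside a memory-constrained kernel.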
Thanks Justin. Actually, I just figured out it can also be fixed by making the second-to-last dimension 8, instead of the 4 I had used, due to the Pallas kernel tiling constraint.
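The dimension fix mentioned above follows from the TPU tiling rule: for float32, blocks are tiled as (8, 128), so the second-to-last dimension of a block should be a multiple of 8 and the last a multiple of 128. A minimal sketch of padding a non-conforming array (the shapes here are illustrative, not from the original kernel):

```python
import numpy as np

# TPU float32 tiles are (8, 128): second-to-last block dimension should be
# a multiple of 8, the last a multiple of 128.
SUBLANE, LANE = 8, 128

x = np.zeros((4, 128), np.float32)   # second-to-last dim of 4 violates the rule
pad_rows = (-x.shape[-2]) % SUBLANE  # rows of zero padding needed: 4
x_padded = np.pad(x, ((0, pad_rows), (0, 0)))
print(x_padded.shape)  # (8, 128)
```

Smaller dtypes pack more rows per tile (e.g. bfloat16 uses (16, 128)), so the required multiple depends on the element type.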
Description
Hi. I am extending the Pallas paged attention kernel to the MQA (multi-query attention) case. When I run my kernel, I encounter the following error, which suggests it is an internal error that I should report here.
Here is my Pallas kernel and the test code that calls it.
Please let me know if you need more info.
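Since the kernel itself is not reproduced in this page, here is a minimal, hypothetical sketch of the kind of `pallas_call` setup involved, with a trivial kernel body and `interpret=True` so it runs without a TPU (the kernel name and shapes are placeholders, not the reporter's code):

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def copy_kernel(x_ref, o_ref):
    # Trivial body standing in for the real paged-attention computation.
    o_ref[...] = x_ref[...]

# Block shape respects the float32 (8, 128) tile constraint.
x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
out = pl.pallas_call(
    copy_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # interpreter mode: debuggable on CPU
)(x)
print(bool((out == x).all()))
```

Running a failing kernel in interpreter mode is often a quick way to separate logic bugs from backend layout restrictions like the one reported here.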
cc @miladm @WoosukKwon
System info (python version, jaxlib version, accelerator, etc.)