How to do KV Cache with FlexAttention and BlockMask by slicing? #60
Comments
For prefill, it might be worth just regenerating the block mask. In general, though, indexing at a position just gives you the block mask corresponding to that position. I believe we also support passing a sequence that is shorter than the one the BlockMask was built for. So if you generate a BlockMask with S=2048, for example, you can pass in a sequence of length 1001.
Thank you so much for your reply! I understand that with a BlockMask built for S=2048, we can process sequences up to length 1000 in the prefill stage. I am sorry if I did not make this clear in the original comment. I'm facing an issue during the decoding stage:
How can we implement a KV cache in this scenario, where we can't slice individual rows from the BlockMask?
We integrated FlexAttention into gpt-fast for decoding: pytorch-labs/gpt-fast#196. You only need to build a new BlockMask every 128 tokens generated. The BlockMask stays the same for tokens #1024 to #1024 + 128.
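To make the "every 128 tokens" point concrete, here is a minimal sketch of the block arithmetic in plain Python. It assumes a block size of 128 (taken from the figure quoted above); `block_index` and `needs_new_block_mask` are hypothetical helper names for illustration, not part of PyTorch or gpt-fast.

```python
BLOCK_SIZE = 128  # assumption: matches the "every 128 tokens" figure above


def block_index(pos: int, block_size: int = BLOCK_SIZE) -> int:
    """Return the index of the block that token position `pos` falls into."""
    return pos // block_size


def needs_new_block_mask(prev_pos: int, pos: int, block_size: int = BLOCK_SIZE) -> bool:
    """True when decoding moved from `prev_pos` to `pos` across a block boundary."""
    return block_index(prev_pos, block_size) != block_index(pos, block_size)


# Positions 1024..1151 all live in block 8, so one block-level mask row covers them.
same_block = [block_index(p) for p in (1024, 1100, 1151)]
crossed = needs_new_block_mask(1151, 1152)
```

Under this assumption, a decode loop would only rebuild (or re-index) the BlockMask when `needs_new_block_mask` returns true, and reuse it for every other step.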
Is there any example code for this? Should I generate a new BlockMask every time?
Thanks!
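One possible shape for this, sketched under assumptions: during decoding the query has length 1, so its local index is always 0, and the mask function can be wrapped to shift that index by the number of tokens already in the KV cache. The names `causal_mask` and `with_offset` are hypothetical, and plain integers stand in for the index tensors that a real mask function would receive. The `(b, h, q_idx, kv_idx)` signature follows the mask function convention used by FlexAttention.

```python
def causal_mask(b, h, q_idx, kv_idx):
    """Standard causal rule: a query may attend to keys at or before its position."""
    return q_idx >= kv_idx


def with_offset(mask_mod, offset):
    """Wrap `mask_mod` so that local query index 0 maps to absolute position `offset`."""
    def shifted(b, h, q_idx, kv_idx):
        return mask_mod(b, h, q_idx + offset, kv_idx)
    return shifted


# After a 1000-token prefill, the next query token sits at absolute position 1000.
decode_mask = with_offset(causal_mask, 1000)

# The single decode query should see cache slots 0..1000 and nothing after them.
visible = [kv for kv in range(1003) if decode_mask(0, 0, 0, kv)]
```

In a real setup the wrapped function would be handed to the mask-building step in place of the original one. Whether that matches the gpt-fast integration exactly is an assumption here, so treat this as an illustration of the indexing idea rather than the reference implementation.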
Essentially, I have a problem with slicing the BlockMask. For example, if we have a prompt of 1000 tokens (prefill stage), I have the following code for attention, which may be wrong. My question is: if I need to generate the 1001st token (a single token for Q), how do I slice the exact position in the BlockMask for it?
Another question: if I use a prefix mask for the prompt tokens, it works when I set H=None, but it raises errors when I set H=H.