
Fix logic for determining the number of cache blocks #98

Merged: 10 commits into IBM:main on May 31, 2024

Conversation

tdoublep (Member)

Motivation

When we deploy speculative decoding in production, we frequently see the servers running out of free blocks. We have determined that this is due to two issues:

  1. The constraint on SPECULATOR_MAX_BATCH_SIZE is not enough to avoid running into memory pressure due to speculation; we need to be able to ensure that we do not speculate on batches that may have a small "size" but a very large weight.
  2. The existing computation of the number of KV cache blocks is significantly inaccurate in most cases.

Modifications

  1. I have introduced an additional constraint: we only speculate on batches with weight up to 75% of the weight limit. This should ensure that we never speculate when we are close to the memory limits (a small illustrative sketch follows this list).
  2. I have written new code to calculate the number of KV cache blocks. This calculation uses the memory scaling coefficients that we learn at startup; in particular, it uses the learned coefficients to figure out what percentage of the memory capacity needs to be set aside for cache blocks.
  3. In the above calculation, I use the next-token coefficient rather than the prefill coefficient, since during the next-token phase the KV cache blocks typically comprise a relatively large percentage of total memory consumption, and we need to be able to handle this worst case. However, this means that during prefill steps we may not have enough memory left over to store the auxiliary data structures needed for a forward pass. There isn't really a clean way to handle this other than rewriting the router logic to be block-aware, but what we can do is recommend that the user increase the batch safety margin to a level that ensures prefills will not run OOM. I've added a print statement to provide this guidance.
  4. I now load the speculator before learning the memory scaling model, since the speculator's memory footprint also needs to be taken into account when measuring the amount of free memory.
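
To make the first constraint concrete, here is a minimal sketch of the kind of gate described above. The names and default values are hypothetical and are not the actual server code; only SPECULATOR_MAX_BATCH_SIZE is an existing setting.

```python
# Illustrative sketch only; names and defaults are hypothetical, not the real server code.
SPECULATOR_MAX_BATCH_SIZE = 16        # existing batch-size cap (value illustrative)
SPECULATION_WEIGHT_FRACTION = 0.75    # new constraint: speculate only below 75% of the weight limit

def should_speculate(batch_size: int, batch_weight: float, weight_limit: float) -> bool:
    """Speculate only on batches that are both small and comfortably below the weight limit."""
    if batch_size > SPECULATOR_MAX_BATCH_SIZE:
        return False
    return batch_weight <= SPECULATION_WEIGHT_FRACTION * weight_limit
```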

Result

These changes, together with setting BATCH_SAFETY_MARGIN=35, seem to result in robust behaviour for both llama3-8b and granite-20b. We no longer need to manually set the number of KV cache blocks in the latter case.

Related Issues

n/a

Signed-off-by: Thomas Parnell <[email protected]>
@JRosenkranz (Collaborator) left a comment:

Looks good, but a few comments before approval. Also, do we have an image available so we can try this out?

# fraction of memory to set aside for KV cache blocks, derived from the learned
# next-token memory scaling coefficient (the phase where cache blocks dominate memory use)
nt_cache_block_ratio = cache_block_size / block_size / memory_scaling_model.next_token_params[1]
total_num_gpu_blocks = int(nt_cache_block_ratio * memory_scaling_model.free_memory // cache_block_size)
# we may need to increase the safety margin a bit to ensure that prefill forward does not run OOM
recommend_safety_margin = 5 + int(100*(1.0 - (1.0 - nt_cache_block_ratio)/(1.0 - pf_cache_block_ratio)))
Collaborator:

Can we have a comment around this line explaining what is being done?

Member Author (tdoublep):

I will add something

Member Author (tdoublep):

I added some explanation now. This approach isn't ideal, and might affect the maximum throughput we can achieve with the server. However, I can't see any other way to ensure robustness without re-implementing the batching logic to interact with the KVCacheManager.
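
For intuition, here is a tiny worked example of the two formulas in the excerpt above. The numbers are made up and chosen only so the arithmetic comes out exact; none of them are measured values from a real deployment.

```python
# Hypothetical inputs (not measured values).
free_memory = 40 * 2**30          # 40 GiB reported free after loading model + speculator
cache_block_size = 16 * 2**20     # 16 MiB of KV cache per block
nt_cache_block_ratio = 0.75       # share of memory to reserve for cache blocks (next-token coefficient)
pf_cache_block_ratio = 0.50       # same share implied by the prefill coefficient

total_num_gpu_blocks = int(nt_cache_block_ratio * free_memory // cache_block_size)
recommend_safety_margin = 5 + int(100 * (1.0 - (1.0 - nt_cache_block_ratio)
                                               / (1.0 - pf_cache_block_ratio)))

print(total_num_gpu_blocks)      # 1920 blocks reserved for the KV cache
print(recommend_safety_margin)   # 55 -> suggested safety margin for these (made-up) ratios
```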

# speculator model name or path
SPECULATOR_NAME = os.getenv("SPECULATOR_NAME", None)

# speculator revision
SPECULATOR_REVISION = os.getenv("SPECULATOR_REVISION", None)
Collaborator:

What is this used for when loading?

Collaborator:

Looks like it's analogous to MODEL_REVISION, the specific commit hash of the model to load. Like one of these I think: https://huggingface.co/ibm/granite-7b-lab-accelerator/commits/main
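
For reference, a minimal sketch of how such a revision is typically threaded through a Hugging Face-style loader. The loader class below is an assumption for illustration and is not necessarily what this repository uses.

```python
import os
from transformers import AutoModelForCausalLM  # assumed HF-style loader, for illustration

SPECULATOR_NAME = os.getenv("SPECULATOR_NAME", None)
SPECULATOR_REVISION = os.getenv("SPECULATOR_REVISION", None)  # branch, tag, or commit hash

if SPECULATOR_NAME is not None:
    # `revision` pins the exact commit of the model repo to load, analogous to MODEL_REVISION.
    speculator = AutoModelForCausalLM.from_pretrained(
        SPECULATOR_NAME,
        revision=SPECULATOR_REVISION,
    )
```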

)
except:
# if something goes wrong during forward, we still need to set the sequence ids
batch.sequence_ids = cache_data.sequence_ids
Member:

Would it make sense to just move this from the bottom of the method to above the call to self.model(...)?

Member Author (tdoublep):

I think cache_data.sequence_ids only gets populated within the call to self.model(...) so we can't move it beforehand.

Member:

This feels a bit fragile; I wonder if it would be better to revert to the prior state (if possible) when the call to self.model fails, ideally within that call, i.e. avoid partial success.

Member Author (tdoublep):

I agree that it's fragile, and there might be a better way to address it from within the function. I'm not sure whether to prioritize that at this stage, though.
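
To illustrate the pattern being discussed, here is a small, self-contained sketch: it records the allocated sequence ids whether or not the forward pass succeeds, and re-raises on failure rather than leaving the batch half-updated. The names mirror the discussion above but this is not the repository's actual implementation.

```python
def forward_and_record_ids(model, batch, cache_data):
    """Run a forward pass and always attach the allocated sequence ids to the batch."""
    try:
        output = model(batch.input_ids, cache_data=cache_data)
    except Exception:
        # cache_data.sequence_ids is only populated inside model(...), so it must be
        # copied back even when the forward pass fails; otherwise the allocated
        # KV cache blocks could never be freed.
        batch.sequence_ids = cache_data.sequence_ids
        raise
    batch.sequence_ids = cache_data.sequence_ids
    return output
```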

@JRosenkranz (Collaborator) left a comment:

lgtm

@njhill merged commit c265390 into IBM:main on May 31, 2024
5 checks passed
Xaenalt pushed a commit to Xaenalt/text-generation-inference that referenced this pull request Aug 7, 2024
Xaenalt pushed a commit to Xaenalt/text-generation-inference that referenced this pull request Aug 12, 2024