Fix logic for determining the number of cache blocks (IBM#98)
When we deploy speculative decoding in production, we frequently see the
servers running out of free blocks. We have determined that this is due
to two issues:
1. The constraint on `SPECULATOR_MAX_BATCH_SIZE` is not enough to avoid
running into memory pressure due to speculation - we need to be able to
ensure that we do not speculate on batches that may have a small "size"
but a very large weight.
2. The computation of the number of blocks is incorrect in most cases.

1. I have introduced an additional constraint that says we should only
speculate on batches whose weight is at most 75% of the weight limit.
This should ensure that we never speculate when we are close to the
memory limits (a minimal sketch of this check follows this list).
2. I have written new code to calculate the number of KV cache blocks.
This calculation uses the memory scaling coefficients that we have
learned at startup. In particular, it uses the learned coefficients
to figure out what percentage of the memory capacity needs to be set
aside for cache blocks (a sketch of the full calculation is included at
the end of this message).
3. In the above calculation, I use the next-token coefficient rather
than the prefill coefficient, since during the next-token phase the KV
cache blocks typically comprise a relatively large percentage of the
total memory consumption, and we need to be able to handle this worst
case. However, this means that during prefill steps we may not have
enough memory left over to store the auxiliary data structures needed
for a forward pass. There isn't really a clean way to handle this other
than re-writing the router logic to be block-aware, but what we can do
is recommend that the user increase the batch safety margin to a level
that ensures prefills will not run OOM. I've added a print statement to
provide this guidance.
4. I now load the speculator before learning the memory scaling model,
since the speculator also needs to be taken into account when measuring
the amount of free memory.
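
For reference, here is a minimal sketch of the new speculation gate from
point 1. It condenses the condition added to `generate_token()` in
`paged_causal_lm.py`; the wrapper function name is hypothetical, while
`spec_ind`, `next_token_params` and `weight_limit` refer to the names used
in the diff below.

```python
# Sketch only: mirrors the speculation condition added in this commit.
def should_speculate(batch, spec_ind, memory_scaling_model,
                     speculator_max_batch_size=16):
    bsize = batch.input_ids.shape[0]
    # estimated memory "weight" of the batch during next-token steps
    weight = sum(batch.total_lengths) * memory_scaling_model.next_token_params[1]
    return (
        len(spec_ind) > 0
        and bsize <= speculator_max_batch_size
        and batch.next_token_chooser.repetition_processor is None
        and (weight / memory_scaling_model.weight_limit) <= 0.75
    )
```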

These changes, together with setting `BATCH_SAFETY_MARGIN=35`, seem to
result in robust behaviour for both `llama3-8b` and `granite-20b`. We no
longer need to manually set the number of KV cache blocks in the latter
case.
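
For reference, the block-count calculation from points 2 and 3 boils down
to roughly the following sketch. The function name is hypothetical;
`linear_fit_params[0]` and `next_token_params[1]` are assumed to be the
learned per-token prefill and next-token coefficients, as used in the diff
below.

```python
import torch

# Sketch only: condenses the block-count logic added to paged_causal_lm.py.
def compute_num_gpu_blocks(model, model_config, memory_scaling_model,
                           dtype, block_size=16):
    # size of one cache block in bytes, across all layers
    dtype_size = torch.tensor([], dtype=dtype).element_size()
    cache_block_size = (dtype_size * model_config.num_hidden_layers
                        * model.get_kv_cache_block_size(block_size))
    # fraction of per-token memory growth attributable to cache blocks
    pf_ratio = cache_block_size / block_size / memory_scaling_model.linear_fit_params[0]  # prefill
    nt_ratio = cache_block_size / block_size / memory_scaling_model.next_token_params[1]  # next-token
    # reserve enough blocks for the worst case (next-token phase)
    total_num_gpu_blocks = int(nt_ratio * memory_scaling_model.free_memory // cache_block_size)
    # safety margin needed so prefill still has room for non-cache tensors
    recommended_margin = 5 + int(100 * (1.0 - (1.0 - nt_ratio) / (1.0 - pf_ratio)))
    return total_num_gpu_blocks, recommended_margin
```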

n/a

---------

Signed-off-by: Thomas Parnell <[email protected]>
tdoublep authored and Xaenalt committed Aug 7, 2024
1 parent 9fa25b5 commit db97ac5
Showing 6 changed files with 95 additions and 35 deletions.
2 changes: 2 additions & 0 deletions server/text_generation_server/models/__init__.py
@@ -35,6 +35,7 @@ def get_model(
dtype_str: str,
quantize: Optional[str],
max_sequence_length: Optional[int],
memory_scaling_model: Optional[int] = None,
) -> Model:
dtype = get_torch_dtype(dtype_str)
model_path = get_model_path(model_name, revision)
@@ -74,6 +75,7 @@ def get_model(
dtype, quantize,
model_config,
max_sequence_length=max_sequence_length,
memory_scaling_model=memory_scaling_model,
)

if FLASH_ATTENTION:
@@ -434,6 +434,9 @@ def __init__(self, config, weights):
weights=weights,
)

def get_kv_cache_block_size(self, block_size: int) -> int:
return block_size * self.model.num_key_value_heads * self.model.head_size * 2

def get_input_embeddings(self) -> nn.Module:
return self.model.embed_tokens

@@ -407,6 +407,9 @@ def __init__(self, config, weights):
config, prefix="transformer.wte", weights=weights
)

def get_kv_cache_block_size(self, block_size: int) -> int:
return block_size * self.transformer.head_size * 2

def get_input_embeddings(self) -> nn.Module:
return self.transformer.wte

82 changes: 52 additions & 30 deletions server/text_generation_server/models/paged_causal_lm.py
@@ -18,16 +18,14 @@
from text_generation_server.utils.token_types import TokenInfo, InputTokens
from text_generation_server.utils.tokens import HeterogeneousNextTokenChooser, get_token_info, get_input_tokens_info
from text_generation_server.utils.paged import (
load_speculator,
prepare_inputs_without_speculation,
prepare_inputs_with_speculation,
process_outputs_with_speculation,
prepare_inputs_for_prefill
)
from text_generation_server.inference_engine import get_inference_engine_class

# HF name or path to speculator model (None means no speculation will be used)
SPECULATOR_NAME = os.getenv("SPECULATOR_NAME", None)

# we will only do speculation if the batch size is <= this parameter
SPECULATOR_MAX_BATCH_SIZE = int(os.getenv("SPECULATOR_MAX_BATCH_SIZE", "16"))

@@ -277,6 +275,7 @@ def __init__(
quantize: Optional[str],
model_config: Union[Any] = None,
max_sequence_length: Optional[int] = None,
memory_scaling_model: Optional["MemoryScalingModel"] = None,
):
model_path = get_model_path(model_name, revision)

@@ -300,27 +299,41 @@ def __init__(

from fms_extras.utils.cache.paged import PagedKVCacheManager

if SPECULATOR_NAME is not None:
from fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorPreTrainedModel
speculator_revision = os.getenv("SPECULATOR_REVISION", None)
speculator_model_path = get_model_path(SPECULATOR_NAME, speculator_revision)
print_rank_n(f"Loading speculator model from: {speculator_model_path}")
# load speculator
self.speculator = load_speculator(self.device, dtype)

if self.speculator is not None:
print_rank_n(f"Speculation will be enabled up to batch size {SPECULATOR_MAX_BATCH_SIZE}")
kwargs = {
"pretrained_model_name_or_path": speculator_model_path,
"local_files_only": True,
"torch_dtype": dtype,
}
with self.device:
self.speculator = MLPSpeculatorPreTrainedModel.from_pretrained(**kwargs)
self.speculator.to(device=self.device)
else:
self.speculator = None

block_size = 16

if KV_CACHE_MANAGER_NUM_GPU_BLOCKS is not None:
total_num_gpu_blocks = int(KV_CACHE_MANAGER_NUM_GPU_BLOCKS)
else:
total_num_gpu_blocks = None
# Firstly, let's compute the size of a cache block in bytes
kv_cache_block_size = self.model.get_kv_cache_block_size(block_size)
total_size = model_config.num_hidden_layers * kv_cache_block_size
dtype_size = torch.tensor([], dtype=dtype).element_size()
cache_block_size = dtype_size * total_size
# We then use our memory scaling model to determine the fraction of the prefill memory
# usage that is due to cache blocks (as opposed to the other stuff needed for forward):
pf_cache_block_ratio = cache_block_size / block_size / memory_scaling_model.linear_fit_params[0]
# We can then do the same for the next token (decoding) step:
nt_cache_block_ratio = cache_block_size / block_size / memory_scaling_model.next_token_params[1]
# In general we know that the next token phase can use many more cache blocks
# relative to the prefill phase (e.g., nt_cache_block_ratio > pf_cache_block_ratio).
# Thus, we need to allocate enough cache blocks to handle the more extreme case:
total_num_gpu_blocks = int(nt_cache_block_ratio * memory_scaling_model.free_memory // cache_block_size)
# This creates an issue though, because if we then try to perform a large prefill, while we
# will certainly have enough cache blocks available, we may not have enough memory leftover
# to allocate the other data structures needed during a forward pass.
# To overcome this, we can set the batch_safety_margin a bit higher to ensure that:
# free_memory * (1.0-batch_safety_margin/100-0.05) * (1.0-pf_cache_block_ratio) <
# free_memory * (1.0-nt_cache_block_ratio)
# This should ensure that our prefill batches can never get so big as to cause OOM.
recommend_safety_margin = 5 + int(100*(1.0 - (1.0 - nt_cache_block_ratio)/(1.0 - pf_cache_block_ratio)))
if memory_scaling_model.safety_margin < recommend_safety_margin:
print(f"WARN: We recommend increasing the value of BATCH_SAFETY_MARGIN to: {recommend_safety_margin}")

self.kv_cache_manager = PagedKVCacheManager(
model_config.num_hidden_layers,
@@ -331,8 +344,14 @@ def __init__(
dtype=dtype,
device=self.device,
total_num_gpu_blocks=total_num_gpu_blocks,
block_size=block_size,
)

self.memory_scaling_model = memory_scaling_model

# log number of free blocks at init
print("[PagedKVCacheManager] number of free blocks: %d" % (len(self.kv_cache_manager.free_blocks)))

@property
def batch_type(self) -> Type[PagedCausalLMBatch]:
return self._batch_type
@@ -410,12 +429,18 @@ def _prefill(
)

t0 = time.time_ns()
output = self.model(
input_ids,
position_ids=position_ids,
cache_data=cache_data,
return_embeds=True,
)
try:
output = self.model(
input_ids,
position_ids=position_ids,
cache_data=cache_data,
return_embeds=True,
)
except:
# if something goes wrong during forward, we still need to set the sequence ids
# TODO: it would be better to fix the forward method to avoid the possibility of partial failures
batch.sequence_ids = cache_data.sequence_ids
raise
t_forward_ns = time.time_ns()-t0
logits, embeds = output

@@ -600,10 +625,7 @@ def generate_token(
)
else:
bsize = batch.input_ids.shape[0]

tokens_remaining = 0
for i in range(len(batch.total_lengths)):
tokens_remaining += batch.total_lengths[i] - batch.input_lengths[i]
weight = sum(batch.total_lengths) * self.memory_scaling_model.next_token_params[1]

spec_ind = []
for i, sample in enumerate(batch.next_token_chooser.do_sample):
@@ -615,7 +637,7 @@ def generate_token(
len(spec_ind) > 0 and
bsize <= SPECULATOR_MAX_BATCH_SIZE and
batch.next_token_chooser.repetition_processor is None and
tokens_remaining < 0.25*len(self.kv_cache_manager.free_blocks)*self.kv_cache_manager.block_size
(weight/self.memory_scaling_model.weight_limit) <= 0.75
)

if speculate:
13 changes: 8 additions & 5 deletions server/text_generation_server/server.py
@@ -56,7 +56,7 @@ async def func_with_log(*args, **kwargs):


class TextGenerationService(generate_pb2_grpc.TextGenerationServiceServicer):
def __init__(self, model: Model, cache: Cache, server_urls: List[str], memory_scaling_model: MemoryScalingModelPB):
def __init__(self, model: Model, cache: Cache, server_urls: List[str], memory_scaling_model: MemoryScalingModel):
self.cache = cache
self.model = model
self.server_urls = server_urls
@@ -81,7 +81,7 @@ async def ModelInfo(self, request: generate_pb2.ModelInfoRequest, context) -> ge
if isinstance(self.model, Seq2SeqLM) else ModelInfoResponse.ModelType.CAUSAL_LM,
eos_token=self.model.config.eos_token_id,
batch_padding=not isinstance(self.model, FlashCausalLM),
memory_scaling_model=self.memory_scaling_model,
memory_scaling_model=self.memory_scaling_model.as_pb(),
)

@log_rpc_handler_errors
@@ -234,8 +234,9 @@ def _free_paged_sequences(self, batch: "Batch", completed_ids: Optional[List[int
]
else:
return
self.model.kv_cache_manager.free_sequences(sequence_ids_to_free, recursive=True)

if sequence_ids_to_free is not None:
self.model.kv_cache_manager.free_sequences(sequence_ids_to_free, recursive=True)

def serve(
model_name: str,
@@ -276,6 +277,8 @@ async def serve_inner(
proc.start()
memory_scaling_model_ext = q_out.get()
proc.join()
else:
memory_scaling_model_ext = None

unix_socket_template = "unix://{}-{}"
world_size = int(os.getenv("WORLD_SIZE", "1"))
@@ -307,7 +310,7 @@ async def serve_inner(
torch.cuda.set_per_process_memory_fraction(cuda_process_memory_fraction)

model = get_model(
model_name, revision, deployment_framework, dtype_str, quantize, max_sequence_length
model_name, revision, deployment_framework, dtype_str, quantize, max_sequence_length, memory_scaling_model_ext,
)

device = model.engine.get_device()
@@ -415,7 +418,7 @@ def estimate_memory():

server = aio.server()
generate_pb2_grpc.add_TextGenerationServiceServicer_to_server(
TextGenerationService(model, Cache(), server_urls, memory_scaling_model.as_pb()), server
TextGenerationService(model, Cache(), server_urls, memory_scaling_model), server
)
# SERVICE_NAMES = (
# generate_pb2.DESCRIPTOR.services_by_name["TextGenerationService"].full_name,
27 changes: 27 additions & 0 deletions server/text_generation_server/utils/paged.py
@@ -5,12 +5,37 @@

from fms_extras.models.speculator import flatten_batch, apply_index_map

# HF name or path to speculator model (None means no speculation will be used)
SPECULATOR_NAME = os.getenv("SPECULATOR_NAME", None)

# speculator revision
SPECULATOR_REVISION = os.getenv("SPECULATOR_REVISION", None)

# number of candidates during speculation
SPECULATOR_N_CANDIDATES = os.getenv("SPECULATOR_N_CANDIDATES", None)

# number of candidates per head
SPECULATOR_TOP_K_TOKENS_PER_HEAD = os.getenv("SPECULATOR_TOP_K_TOKENS_PER_HEAD", None)

def load_speculator(device, dtype):

if SPECULATOR_NAME is not None:
from fms_extras.models.hf.modeling_mlp_speculator import MLPSpeculatorPreTrainedModel
from text_generation_server.utils.hub import get_model_path
from text_generation_server.utils import print_rank_n
speculator_model_path = get_model_path(SPECULATOR_NAME, SPECULATOR_REVISION)
print_rank_n(f"Loading speculator model from: {speculator_model_path}")
kwargs = {
"pretrained_model_name_or_path": speculator_model_path,
"local_files_only": True,
"torch_dtype": dtype,
}
with device:
speculator = MLPSpeculatorPreTrainedModel.from_pretrained(**kwargs)
speculator.to(device=device)
return speculator
else:
return None

def fit_memory_scaling_model(
model_name: str,
@@ -38,6 +63,8 @@ def fit_memory_scaling_model(
model_name, revision, deployment_framework, dtype_str, quantize, max_sequence_length
)

speculator = load_speculator(model.device, model.dtype)

memory_scaling_model = Estimator.build_from_env(
model,
batch_safety_margin,
