Example code to run batched inference? #3
Nice work! In the paper I saw the batched results, but examples like https://github.com/Infini-AI-Lab/TriForce/blob/main/test/on_chip.py only use batch size 1. Does the code support batched speculative inference?
Hello, thanks for your interest in our work! Here is an example. Basically, you can define a `BatchRetrievalCache` class:

```python
import torch
# repeat_kv is the grouped-query-attention helper from HuggingFace's Llama implementation
from transformers.models.llama.modeling_llama import repeat_kv


class BatchRetrievalCache:
    def __init__(self, model, max_budget=1024, bsz=2, prefill=1024, chunk_size=8, gamma=6) -> None:
        self.chunk_size = chunk_size
        self.prefill = prefill
        self.chunks = prefill // self.chunk_size
        self.select_sets = max_budget // self.chunk_size
        self.gamma = gamma
        self.max_budget = max_budget
        assert prefill % self.chunk_size == 0, f"prefill should be a multiple of chunk_size, got {prefill} % {self.chunk_size}"
        assert max_budget % self.chunk_size == 0, f"max_budget should be a multiple of chunk_size, got {max_budget} % {self.chunk_size}"
        # gamma extra slots at the tail hold the draft tokens of the current speculation round
        self.real_budget = max_budget + gamma

        self.hidden_size = model.config.hidden_size
        self.num_heads = model.config.num_key_value_heads
        self.num_key_value_groups = model.config.num_attention_heads // model.config.num_key_value_heads
        self.head_dim = self.hidden_size // model.config.num_attention_heads
        self.layers = model.config.num_hidden_layers
        self.bsz = bsz
        self.seq_len = torch.zeros(bsz, dtype=torch.int32).to(model.device)

        device = model.device
        dtype = torch.float16
        self.key_cache = torch.zeros([self.layers, bsz, self.real_budget, self.num_heads, self.head_dim], dtype=dtype).to(device)
        self.value_cache = torch.zeros([self.layers, bsz, self.real_budget, self.num_heads, self.head_dim], dtype=dtype).to(device)

    def print_status(self):
        print("Budget:", self.max_budget, f"| bsz: {self.bsz}", "| Real Budget:", self.real_budget,
              "| PreFill:", self.prefill, "| Chunk Size:", self.chunk_size,
              "| Chunks:", self.chunks, "| Select Sets:", self.select_sets)
    def init_graph_cache(self, kv_cache, query_states, layer_idx):
        # query_states: (bsz, 1, 32, head_dim) --> (bsz, 32, 1, head_dim)
        # key_cache: (bsz, seq_len, 32, head_dim) --> (bsz, 32, head_dim, seq_len)
        assert query_states.shape[1] == 1, "query_states should be 1 for init"

        # chunk_k: (bsz, chunks, chunk_size, kv_heads, head_dim) --> (bsz, chunks, kv_heads, head_dim)
        chunk_k = kv_cache.key_cache[layer_idx, :, :self.prefill].cuda().view(
            self.bsz, self.chunks, self.chunk_size, self.num_heads, self.head_dim).mean(dim=-3)
        chunk_k = repeat_kv(chunk_k.permute(0, 2, 1, 3), self.num_key_value_groups)  # (bsz, kv_heads, chunks, head_dim)
        chunk_attn = torch.matmul(query_states.permute(0, 2, 1, 3), chunk_k.permute(0, 1, 3, 2)).squeeze(2)  # (bsz, 32, chunks)
        # top-(select_sets - 1) over chunks 1..end; chunk 0 is always kept
        _, topk_idx_rest = torch.topk(chunk_attn[:, :, 1:], k=self.select_sets - 1, dim=-1)  # (bsz, 32, select_sets - 1)
        topk_idx_rest += 1
        topk_idx_first = torch.zeros((topk_idx_rest.shape[0], topk_idx_rest.shape[1], 1),
                                     device=topk_idx_rest.device, dtype=topk_idx_rest.dtype)
        topk_idx = torch.cat([topk_idx_first, topk_idx_rest], dim=-1)  # (bsz, 32, select_sets)
        expanded_index_tensor = topk_idx.permute(0, 2, 1).unsqueeze(-1).unsqueeze(-1).expand(-1, -1, -1, self.chunk_size, self.head_dim)

        # (bsz, prefill, 32, head_dim) --> (bsz, chunks, chunk_size, 32, head_dim) --> (bsz, chunks, 32, chunk_size, head_dim)
        key_ = kv_cache.key_cache[layer_idx][:, :self.prefill].reshape(
            self.bsz, self.chunks, self.chunk_size, self.num_heads, self.head_dim).cuda()
        key_ = key_.permute(0, 1, 3, 2, 4)
        result_tensor = torch.gather(key_, 1, expanded_index_tensor)  # (bsz, select_sets, 32, chunk_size, head_dim)
        # (bsz, select_sets, 32, chunk_size, head_dim) --> (bsz, select_sets*chunk_size, 32, head_dim)
        self.key_cache[layer_idx][:, :self.max_budget] = result_tensor.permute(0, 1, 3, 2, 4).reshape(
            self.bsz, self.select_sets * self.chunk_size, self.num_heads, self.head_dim).clone()

        value_ = kv_cache.value_cache[layer_idx][:, :self.prefill].reshape(
            self.bsz, self.chunks, self.chunk_size, self.num_heads, self.head_dim).cuda()
        value_ = value_.permute(0, 1, 3, 2, 4)
        result_tensor = torch.gather(value_, 1, expanded_index_tensor)
        self.value_cache[layer_idx][:, :self.max_budget] = result_tensor.permute(0, 1, 3, 2, 4).reshape(
            self.bsz, self.select_sets * self.chunk_size, self.num_heads, self.head_dim).clone()

        if layer_idx == self.layers - 1:
            self.init_graph = True

    def update(self, key_states: torch.Tensor, value_states: torch.Tensor, layer_idx: int, gamma_offset: int):
        # write the current draft token's KV into slot gamma_offset of the tail region
        self.key_cache[layer_idx][:, self.real_budget - self.gamma + gamma_offset] = key_states.clone().squeeze(1)
        self.value_cache[layer_idx][:, self.real_budget - self.gamma + gamma_offset] = value_states.clone().squeeze(1)
        return (self.key_cache[layer_idx][:, :self.real_budget - self.gamma + gamma_offset + 1],
                self.value_cache[layer_idx][:, :self.real_budget - self.gamma + gamma_offset + 1])

    def update_graph_cache(self, kv_cache=None):
        # per-sample loop because seq_len can differ across the batch
        bsz = kv_cache.seq_len.shape[0]
        for i in range(bsz):
            self.value_cache[:, i, self.max_budget - (kv_cache.seq_len[i] - self.prefill):self.max_budget] = \
                kv_cache.value_cache[:, i, self.prefill:kv_cache.seq_len[i]]
            self.key_cache[:, i, self.max_budget - (kv_cache.seq_len[i] - self.prefill):self.max_budget] = \
                kv_cache.key_cache[:, i, self.prefill:kv_cache.seq_len[i]]

    def reset(self):
        self.key_cache.zero_()
        self.value_cache.zero_()
```

During the verification phase, you can use:

```python
from flash_attn import flash_attn_with_kvcache

attn_output = flash_attn_with_kvcache(
    q=query_states, k_cache=key_states, v_cache=value_states,
    cache_seqlens=kv_cache.seq_len,
    softmax_scale=1 / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float16)),
    causal=True,
)
```
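
For completeness, here is a minimal, runnable driver sketch showing how the pieces fit together. Everything in it is a hypothetical stand-in (a `SimpleNamespace` fake model config, random tensors instead of real prefill states, toy sizes), not the repo's actual API; it only assumes a CUDA device and the `repeat_kv` import above.

```python
import types
import torch

# Hypothetical stand-ins so the sketch runs without a real checkpoint.
config = types.SimpleNamespace(hidden_size=128, num_attention_heads=4,
                               num_key_value_heads=4, num_hidden_layers=2)
model = types.SimpleNamespace(config=config, device=torch.device("cuda"))

bsz, prefill, max_budget, gamma = 2, 64, 32, 6
head_dim = config.hidden_size // config.num_attention_heads

# A minimal full KV cache shaped the way init_graph_cache / update_graph_cache
# index it: (layers, bsz, seq, kv_heads, head_dim) plus a per-sample seq_len.
kv_cache = types.SimpleNamespace(
    key_cache=torch.randn(config.num_hidden_layers, bsz, prefill + gamma,
                          config.num_key_value_heads, head_dim,
                          dtype=torch.float16, device="cuda"),
    value_cache=torch.randn(config.num_hidden_layers, bsz, prefill + gamma,
                            config.num_key_value_heads, head_dim,
                            dtype=torch.float16, device="cuda"),
    seq_len=torch.full((bsz,), prefill + 2, dtype=torch.int32, device="cuda"),
)

cache = BatchRetrievalCache(model, max_budget=max_budget, bsz=bsz,
                            prefill=prefill, gamma=gamma)
cache.print_status()

# 1) After prefill, build the retrieval cache layer by layer from the
#    current query states (bsz, 1, heads, head_dim).
for layer_idx in range(config.num_hidden_layers):
    q = torch.randn(bsz, 1, config.num_attention_heads, head_dim,
                    dtype=torch.float16, device="cuda")
    cache.init_graph_cache(kv_cache, q, layer_idx)

# 2) During drafting, each speculated token's KV lands in the gamma tail,
#    and update() returns the KV window to attend over.
k = torch.randn(bsz, 1, config.num_key_value_heads, head_dim, dtype=torch.float16, device="cuda")
v = torch.randn(bsz, 1, config.num_key_value_heads, head_dim, dtype=torch.float16, device="cuda")
k_win, v_win = cache.update(k, v, layer_idx=0, gamma_offset=0)

# 3) After verification has appended accepted tokens to the full cache,
#    copy that fresh tail into the retrieval cache for the next round.
cache.update_graph_cache(kv_cache)
```

Keeping the cache at a fixed `real_budget = max_budget + gamma` gives every decode step static shapes, which is the property CUDA-graph capture needs.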
Great, thanks very much!