Support Llama Multimodal with repad embedding mode

Co-authored-by: liuliang1 <[email protected]>
1 parent bf8a829 · commit 87557d1
Showing 7 changed files with 368 additions and 0 deletions.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,21 @@
import torch
from lightllm.models.llama.infer_struct import LlamaInferStateInfo


class LlamaMultiModalInferStateInfo(LlamaInferStateInfo):
    def __init__(self):
        super().__init__()

    def init_some_extra_state(self,
                              model,
                              batch_size,
                              total_token_num,
                              max_len_in_batch,
                              input_ids: torch.Tensor,
                              b_loc: torch.Tensor,
                              b_start_loc: torch.Tensor,
                              b_seq_len: torch.Tensor,
                              is_prefill,
                              **kwargs):
        super().init_some_extra_state(model, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, is_prefill)
        self.kwargs = kwargs
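The only functional change over the base state class is the final `self.kwargs = kwargs` line: any extra keyword argument handed to the model's forward() ends up stored on the infer state. A minimal, self-contained sketch of that pass-through pattern (the toy class names below are illustrative only, not part of lightllm):

    import torch

    class _BaseState:
        def init_some_extra_state(self, model, batch_size):
            # the base class only records the fixed arguments
            self.batch_size = batch_size

    class _MultiModalState(_BaseState):
        def init_some_extra_state(self, model, batch_size, **kwargs):
            super().init_some_extra_state(model, batch_size)
            # everything extra (e.g. repad_embeds) is kept for later layers to read
            self.kwargs = kwargs

    state = _MultiModalState()
    state.init_some_extra_state(None, batch_size=2,
                                repad_embeds=[(torch.zeros(3, 8, dtype=torch.float16), 1), (None, 0)])
    print(state.kwargs["repad_embeds"][0][0].shape)  # torch.Size([3, 8])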
Empty file.
lightllm/models/llama_multimodal/layer_infer/pre_layer_infer.py (53 additions & 0 deletions)
@@ -0,0 +1,53 @@
import torch
import torch.distributed as dist

from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
from lightllm.models.llama_multimodal.infer_struct import LlamaMultiModalInferStateInfo
from lightllm.utils.infer_utils import mark_cost_time
from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer


"""
infer_state.kwargs['repad_embeds'] = [(embeds, offset), ...]
    embeds: torch.Tensor or None
    offset: int
"""
def repad_input_embeds(input_embeds, infer_state: LlamaMultiModalInferStateInfo):
    assert isinstance(infer_state, LlamaMultiModalInferStateInfo)

    if infer_state.kwargs and 'repad_embeds' in infer_state.kwargs:
        repad_embeds = infer_state.kwargs['repad_embeds']
        assert len(repad_embeds) == infer_state.batch_size, "length of repad_embeds != batch_size: {} vs {}!".format(len(repad_embeds), infer_state.batch_size)

        for i, (embeds, offset) in enumerate(repad_embeds):
            # nothing to repad for this request if no embeds were given
            if embeds is None:
                continue
            assert isinstance(embeds, torch.Tensor), "given repad embeds should be torch.Tensor but got {}!".format(type(embeds))

            # overwrite the token embeddings of request i, starting at offset, with pad_len rows of embeds
            start_idx = infer_state.b_start_loc[i]
            seq_len = infer_state.b_seq_len[i]
            pad_len, pad_dim = embeds.shape
            dim = input_embeds.shape[1]
            assert pad_dim == dim, "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
            assert offset + pad_len <= seq_len, "invalid seq_len={}, offset={}, pad_len={}!".format(seq_len, offset, pad_len)
            input_embeds[start_idx + offset: start_idx + offset + pad_len] = embeds
            print("repad input_embeds start_idx={} offset={} pad_len={}".format(start_idx, offset, pad_len))
    return input_embeds


class LlamaMultiModalPreLayerInfer(LlamaPreLayerInfer):

    def __init__(self, tp_rank, world_size, network_config, mode):
        super().__init__(tp_rank, world_size, network_config, mode)
        return

    @mark_cost_time("pre context forward")
    def context_forward(self, input_ids, infer_state: LlamaMultiModalInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
        input_embeds = super().context_forward(input_ids, infer_state, layer_weight)
        return repad_input_embeds(input_embeds, infer_state)

    def token_forward(self, input_ids, infer_state: LlamaMultiModalInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
        input_embeds = super().token_forward(input_ids, infer_state, layer_weight)
        return repad_input_embeds(input_embeds, infer_state)
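Reduced to plain tensor arithmetic, repad_input_embeds overwrites, for each request i, the rows of the flattened embedding tensor from b_start_loc[i] + offset for pad_len positions with the externally computed (e.g. visual) embeddings. A small self-contained sketch with made-up sizes, no lightllm objects involved:

    import torch

    dim = 8
    b_start_loc = torch.tensor([0, 10])       # two requests of 10 tokens each, flattened back to back
    b_seq_len = torch.tensor([10, 10])
    input_embeds = torch.zeros(20, dim)       # stand-in for the output of the embedding lookup

    repad_embeds = [(torch.ones(4, dim), 2),  # request 0: 4 rows of image features written at offset 2
                    (None, 0)]                # request 1: nothing to repad

    for i, (embeds, offset) in enumerate(repad_embeds):
        if embeds is None:
            continue
        pad_len = embeds.shape[0]
        assert offset + pad_len <= b_seq_len[i]
        start = b_start_loc[i] + offset
        input_embeds[start: start + pad_len] = embeds

    print(input_embeds[:10, 0])  # tensor([0., 0., 1., 1., 1., 1., 0., 0., 0., 0.])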
@@ -0,0 +1,102 @@
import torch

from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight

from lightllm.models.llama_multimodal.layer_infer.pre_layer_infer import LlamaMultiModalPreLayerInfer
from lightllm.models.llama_multimodal.infer_struct import LlamaMultiModalInferStateInfo

from lightllm.models.llama.model import LlamaTpPartModel
from lightllm.common.infer_utils import init_bloc


class LlamaTpPartMulitModal(LlamaTpPartModel):
    # weight class
    pre_and_post_weight_class = LlamaPreAndPostLayerWeight
    transformer_weight_class = LlamaTransformerLayerWeight

    # infer class
    pre_layer_infer_class = LlamaMultiModalPreLayerInfer
    post_layer_infer_class = LlamaPostLayerInfer
    transformer_layer_infer_class = LlamaTransformerLayerInfer

    # infer state class
    infer_state_class = LlamaMultiModalInferStateInfo

    def __init__(self, tp_rank, world_size, weight_dir, max_total_token_num, load_way="HF", mode=""):
        super().__init__(tp_rank, world_size, weight_dir, max_total_token_num, load_way, mode)
        return

    @torch.no_grad()
    def forward(
            self,
            batch_size,
            total_token_num,
            max_len_in_batch,
            input_ids: torch.Tensor,
            b_loc: torch.Tensor,
            b_start_loc: torch.Tensor,
            b_seq_len: torch.Tensor,
            is_prefill=True,
            **kwargs):
        if is_prefill:
            return self._prefill(batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, **kwargs)
        else:
            return self._decode(batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, **kwargs)

    def _prefill(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, **kwargs):
        infer_state = self.infer_state_class()
        infer_state.is_prefill = True
        infer_state.batch_size = batch_size
        infer_state.total_token_num = total_token_num
        infer_state.max_len_in_batch = max_len_in_batch
        assert (input_ids.shape[0] == total_token_num), "{} vs {}".format(input_ids.shape, total_token_num)
        assert (b_loc.shape[0] == b_start_loc.shape[0] == b_seq_len.shape[0])
        infer_state.b_loc = b_loc
        infer_state.b_start_loc = b_start_loc
        infer_state.b_seq_len = b_seq_len

        infer_state.mem_manager = self.mem_manager
        infer_state.prefill_mem_index = self.mem_manager.alloc(infer_state.total_token_num)
        infer_state.prefill_key_buffer = torch.empty((infer_state.total_token_num, self.tp_k_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
        infer_state.prefill_value_buffer = torch.empty((infer_state.total_token_num, self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
        init_bloc(b_loc, b_seq_len, max_len_in_batch, infer_state.prefill_mem_index)

        # extra kwargs (e.g. repad_embeds) are forwarded into the infer state here
        infer_state.init_some_extra_state(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, True, **kwargs)
        predict_logics = self._context_forward(input_ids, infer_state)
        return predict_logics

    def _decode(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, **kwargs):
        infer_state = self.infer_state_class()
        infer_state.is_prefill = False
        infer_state.batch_size = batch_size
        infer_state.total_token_num = total_token_num
        infer_state.max_len_in_batch = max_len_in_batch
        assert (b_loc.shape[0] == b_start_loc.shape[0] == b_seq_len.shape[0])
        infer_state.b_loc = b_loc
        infer_state.b_start_loc = b_start_loc
        infer_state.b_seq_len = b_seq_len

        infer_state.mem_manager = self.mem_manager

        alloc_mem = self.mem_manager.alloc_contiguous(batch_size)
        if alloc_mem is not None:
            infer_state.decode_is_contiguous = True
            infer_state.decode_mem_index = alloc_mem[0]
            infer_state.decode_mem_start = alloc_mem[1]
            infer_state.decode_mem_end = alloc_mem[2]
            b_loc[:, max_len_in_batch - 1] = infer_state.decode_mem_index
        else:
            infer_state.decode_is_contiguous = False
            alloc_mem = self.mem_manager.alloc(batch_size)
            infer_state.decode_mem_index = alloc_mem
            infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_k_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
            infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
            b_loc[:, max_len_in_batch - 1] = infer_state.decode_mem_index

        infer_state.init_some_extra_state(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, False, **kwargs)
        predict_logics = self._token_forward(input_ids, infer_state)
        return predict_logics
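For reference, a hedged sketch of the batch metadata a caller assembles before the prefill call: input_ids carry placeholder ids at the repad positions, and repad_embeds holds one (embeds, offset) pair per request. The sizes below are made up and the tensors stay on the CPU; the test script below does the same thing against a real model on the GPU.

    import torch

    batch_size, text_len, pad_len, offset, output_len, dim = 2, 6, 3, 1, 4, 8

    all_input_ids, repad_embeds = [], []
    for _ in range(batch_size):
        ids = torch.arange(5, 5 + text_len)               # the "real" token ids
        pad_ids = torch.zeros(pad_len, dtype=ids.dtype)   # placeholders where the image embeds will go
        all_input_ids.append(torch.cat([ids[:offset], pad_ids, ids[offset:]]))
        repad_embeds.append((torch.rand(pad_len, dim).half(), offset))

    input_len = text_len + pad_len                        # per-request length after padding
    input_ids = torch.cat(all_input_ids)                  # flattened over the batch
    b_loc = torch.zeros(batch_size, input_len + output_len, dtype=torch.long)
    b_start_loc = torch.arange(0, batch_size * input_len, input_len, dtype=torch.int32)
    b_seq_len = torch.full((batch_size,), input_len, dtype=torch.int32)
    total_token_num = batch_size * input_len

    # With a real model the tensors live on the GPU and dim equals the model's hidden size
    # (4096 for llama-7b); the call then looks like:
    # logics = model_part.forward(batch_size, total_token_num, input_len, input_ids,
    #                             b_loc, b_start_loc, b_seq_len,
    #                             is_prefill=True, repad_embeds=repad_embeds)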
@@ -0,0 +1,169 @@
import numpy as np
from multiprocessing import Queue
import multiprocessing


def test_multimodal_inference(world_size, model_dir, model_class, batch_size, input_len, output_len, repad_embeds_args):
    ans_queue = Queue()
    workers = []
    for rank_id in range(world_size):
        proc = multiprocessing.Process(target=tppart_multimodal_infer, args=(rank_id, world_size, ans_queue, model_dir, model_class, batch_size, input_len, output_len, repad_embeds_args))
        proc.start()
        workers.append(proc)

    for proc in workers:
        proc.join()

    assert not ans_queue.empty()
    while not ans_queue.empty():
        assert ans_queue.get()
    return


# for multimodal inference, the repad_embeds have to be passed to forward()
def gen_repad_embeds(batch_size, model, input_len, repad_embeds_args):
    import torch
    all_input_ids = []
    all_repad_embeds = []
    pad_len, pad_dim_size, offset = repad_embeds_args

    for i in range(batch_size):
        # shape = [input_len]
        input_ids = torch.from_numpy(np.arange(5, input_len + 5).reshape(-1)).cuda()
        # shape = [pad_len, pad_dim_size]
        pad_embeds = torch.rand(
            size=(pad_len, pad_dim_size),
            dtype=torch.float16,
            device=input_ids.device,
        ) - 0.5
        all_repad_embeds.append((pad_embeds, offset))

        # input_ids have to be padded with placeholder ids at the repad position
        # shape = [pad_len]
        pad_ids = torch.zeros(
            size=(pad_len,),
            dtype=input_ids.dtype,
            device=input_ids.device,
        )
        # shape = [input_len + pad_len]
        input_ids = torch.cat([input_ids[:offset], pad_ids, input_ids[offset:]], dim=0)
        all_input_ids.append(input_ids)

    all_input_ids = torch.cat(all_input_ids)
    return all_input_ids, all_repad_embeds


def tppart_multimodal_infer(rank_id, world_size, ans_queue, model_dir, model_class, batch_size, input_len, output_len, repad_embeds_args):
    import torch
    import torch.distributed as dist
    dist.init_process_group('nccl', init_method='tcp://127.0.0.1:28765', rank=rank_id, world_size=world_size)
    torch.cuda.set_device(rank_id)

    dist.barrier()
    torch.cuda.empty_cache()

    model_part = model_class(dist.get_rank(),
                             dist.get_world_size(),
                             max_total_token_num=batch_size * (input_len + output_len + repad_embeds_args[0]),
                             weight_dir=model_dir,
                             load_way="HF")
    # warm up
    test_data, test_embeds = gen_repad_embeds(batch_size, model_part, input_len, repad_embeds_args)
    # after gen_repad_embeds, the real input_len has grown by repad_embeds_args[0] (pad_len)
    input_len += repad_embeds_args[0]

    b_loc = torch.zeros(batch_size, input_len + output_len, dtype=torch.long, device="cuda")
    b_start_loc = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
    b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
    for i in range(batch_size):
        b_loc[i, 0:input_len] = i * input_len + torch.arange(0, input_len, dtype=torch.int32, device="cuda")
        b_start_loc[i] = i * input_len
        b_seq_len[i] = input_len

    total_token_num = input_len * batch_size
    logics = model_part.forward(batch_size,
                                total_token_num,
                                input_len,
                                test_data,
                                b_loc,
                                b_start_loc,
                                b_seq_len,
                                is_prefill=True,
                                repad_embeds=test_embeds)
    prob_out = torch.softmax(logics, dim=-1)
    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
    predict_ids = predict_ids.detach().cpu().numpy()

    for i in range(output_len):
        b_loc[:, input_len + i] = total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
        b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
        total_token_num += batch_size
        b_seq_len += 1
        logics = model_part.forward(batch_size, total_token_num, input_len + i + 1, torch.from_numpy(
            predict_ids).cuda().reshape(-1), b_loc, b_start_loc, b_seq_len, is_prefill=False)
        prob_out = torch.softmax(logics, dim=-1)
        predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
        predict_ids = predict_ids.detach().cpu().numpy()

    max_len_in_batch = input_len + output_len
    for i in range(batch_size):
        model_part.mem_manager.free(b_loc[i, max_len_in_batch - b_seq_len[i]:max_len_in_batch])
    if rank_id == 0:
        print("can use mem size:", model_part.mem_manager.can_use_mem_size)

    b_loc = None
    b_start_loc = None
    b_seq_len = None

    dist.barrier()
    import time
    torch.cuda.synchronize()
    start_time = time.time()

    prefill_start_time = time.time()

    b_loc = torch.zeros(batch_size, input_len + output_len, dtype=torch.long, device="cuda")
    b_start_loc = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
    b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
    for i in range(batch_size):
        b_start_loc[i] = i * input_len
        b_seq_len[i] = input_len

    total_token_num = batch_size * input_len
    logics = model_part.forward(batch_size, total_token_num, input_len, test_data,
                                b_loc, b_start_loc, b_seq_len, is_prefill=True, repad_embeds=test_embeds)
    prob_out = torch.softmax(logics, dim=-1)
    predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
    predict_ids = predict_ids.detach().cpu().numpy()

    torch.cuda.synchronize()
    if rank_id == 0:
        print("prefill time cost:", (time.time() - prefill_start_time) * 1000)

    for i in range(output_len):
        torch.cuda.synchronize()
        step_start = time.time()
        b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
        total_token_num += batch_size
        b_seq_len += 1

        logics = model_part.forward(batch_size, total_token_num, input_len + i + 1, torch.from_numpy(
            predict_ids).cuda().reshape(-1), b_loc, b_start_loc, b_seq_len, is_prefill=False)
        prob_out = torch.softmax(logics, dim=-1)
        predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
        predict_ids = predict_ids.detach().cpu().numpy()
        torch.cuda.synchronize()
        if i % 100 == 0 or i == output_len - 1:
            if rank_id == 0:
                print(i, "step cost time:", (time.time() - step_start) * 1000)

    torch.cuda.synchronize()
    end_time = time.time()

    if rank_id == 0:
        print("time total cost(ms):", (end_time - start_time) * 1000)
    ans_queue.put(True)

    return
@@ -0,0 +1,23 @@
import os
import sys
import unittest
from model_infer_multimodal import test_multimodal_inference
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))


class TestLlamaMultiModalInfer(unittest.TestCase):

    def test_llama_infer(self):
        from lightllm.models.llama_multimodal.model import LlamaTpPartMulitModal
        test_multimodal_inference(world_size=4,
                                  model_dir="/path/to/llama-7b",
                                  model_class=LlamaTpPartMulitModal,
                                  batch_size=10,
                                  input_len=1024,
                                  output_len=1024,
                                  # (pad_len, pad_dim_size, offset)
                                  repad_embeds_args=(36, 4096, 5))
        return


if __name__ == '__main__':
    unittest.main()