Support Llama Multimodal (#133)
Support Llama Multimodal with repad embedding mode

---------

Co-authored-by: liuliang1 <[email protected]>
huochaitiantang and liuliang1 authored Sep 18, 2023
1 parent bf8a829 commit 87557d1
Showing 7 changed files with 368 additions and 0 deletions.
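In this mode the caller supplies one (embeds, offset) pair per request through the extra repad_embeds keyword of forward(); the pre-layer then overwrites that request's token embeddings starting at offset. A minimal sketch of the call, assuming a LlamaTpPartMulitModal instance and the usual lightllm batch tensors are already set up as in the test below (shapes here are placeholders):

import torch

# `model`, `batch_size`, `total_token_num`, `max_len_in_batch`, `input_ids`,
# `b_loc`, `b_start_loc`, `b_seq_len` are assumed to exist, built as in
# test/model/model_infer_multimodal.py; input_ids already contains pad_len
# placeholder tokens at `offset` for each request.
pad_len, hidden_dim, offset = 36, 4096, 5
image_embeds = torch.rand(pad_len, hidden_dim, dtype=torch.float16, device="cuda") - 0.5
repad_embeds = [(image_embeds, offset)]  # one (embeds, offset) per request; None skips a request

logits = model.forward(batch_size, total_token_num, max_len_in_batch,
                       input_ids, b_loc, b_start_loc, b_seq_len,
                       is_prefill=True, repad_embeds=repad_embeds)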
0 changes: 0 additions & 0 deletions lightllm/models/llama_multimodal/__init__.py
Empty file.
21 changes: 21 additions & 0 deletions lightllm/models/llama_multimodal/infer_struct.py
@@ -0,0 +1,21 @@
import torch
from lightllm.models.llama.infer_struct import LlamaInferStateInfo


class LlamaMultiModalInferStateInfo(LlamaInferStateInfo):
def __init__(self):
super().__init__()

    def init_some_extra_state(self,
                              model,
                              batch_size,
                              total_token_num,
                              max_len_in_batch,
                              input_ids: torch.Tensor,
                              b_loc: torch.Tensor,
                              b_start_loc: torch.Tensor,
                              b_seq_len: torch.Tensor,
                              is_prefill,
                              **kwargs):
        super().init_some_extra_state(model, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, is_prefill)
        # keep extra multimodal kwargs (e.g. repad_embeds) for the pre-layer infer
        self.kwargs = kwargs
0 changes: 0 additions & 0 deletions lightllm/models/llama_multimodal/layer_infer/__init__.py
Empty file.
53 changes: 53 additions & 0 deletions lightllm/models/llama_multimodal/layer_infer/pre_layer_infer.py
@@ -0,0 +1,53 @@
import torch
import torch.distributed as dist

from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
from lightllm.models.llama_multimodal.infer_struct import LlamaMultiModalInferStateInfo
from lightllm.utils.infer_utils import mark_cost_time
from lightllm.models.llama.layer_infer.pre_layer_infer import LlamaPreLayerInfer


"""
infer_state.kwargs['repad_embeds'] = [(embeds, offset), ...]
embeds: torch.Tensor or None
offset: int
"""
def repad_input_embeds(input_embeds, infer_state: LlamaMultiModalInferStateInfo):
assert isinstance(infer_state, LlamaMultiModalInferStateInfo)

if infer_state.kwargs and 'repad_embeds' in infer_state.kwargs:
repad_embeds = infer_state.kwargs['repad_embeds']
assert len(repad_embeds) == infer_state.batch_size, "length of repad_embeds != batch_size: {} vs {}!".format(len(repad_embeds), infer_state.batch_size)

for i, (embeds, offset) in enumerate(repad_embeds):
            # skip this request if no repad embeds were given
if embeds is None:
continue
            assert isinstance(embeds, torch.Tensor), "given repad embeds should be torch.Tensor but got {}!".format(type(embeds))

start_idx = infer_state.b_start_loc[i]
seq_len = infer_state.b_seq_len[i]
pad_len, pad_dim = embeds.shape
dim = input_embeds.shape[1]
assert pad_dim == dim, "invalid pad_dim={}, input_embed_dim={}!".format(pad_dim, dim)
assert offset + pad_len <= seq_len, "invalid seq_len={}, offset={}, pad_len={}!".format(seq_len, offset, pad_len)
input_embeds[start_idx + offset: start_idx + offset + pad_len] = embeds
print("repad input_embeds start_idx={} offset={} pad_len={}".format(start_idx, offset, pad_len))
return input_embeds


class LlamaMultiModalPreLayerInfer(LlamaPreLayerInfer):

def __init__(self, tp_rank, world_size, network_config, mode):
super().__init__(tp_rank, world_size, network_config, mode)
return

@mark_cost_time("pre context forward")
def context_forward(self, input_ids, infer_state: LlamaMultiModalInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
input_embeds = super().context_forward(input_ids, infer_state, layer_weight)
return repad_input_embeds(input_embeds, infer_state)


def token_forward(self, input_ids, infer_state: LlamaMultiModalInferStateInfo, layer_weight: LlamaPreAndPostLayerWeight):
input_embeds = super().token_forward(input_ids, infer_state, layer_weight)
return repad_input_embeds(input_embeds, infer_state)
102 changes: 102 additions & 0 deletions lightllm/models/llama_multimodal/model.py
@@ -0,0 +1,102 @@
import torch

from lightllm.models.llama.layer_infer.post_layer_infer import LlamaPostLayerInfer
from lightllm.models.llama.layer_infer.transformer_layer_infer import LlamaTransformerLayerInfer
from lightllm.models.llama.layer_weights.pre_and_post_layer_weight import LlamaPreAndPostLayerWeight
from lightllm.models.llama.layer_weights.transformer_layer_weight import LlamaTransformerLayerWeight

from lightllm.models.llama_multimodal.layer_infer.pre_layer_infer import LlamaMultiModalPreLayerInfer
from lightllm.models.llama_multimodal.infer_struct import LlamaMultiModalInferStateInfo

from lightllm.models.llama.model import LlamaTpPartModel
from lightllm.common.infer_utils import init_bloc


class LlamaTpPartMulitModal(LlamaTpPartModel):
# weight class
pre_and_post_weight_class = LlamaPreAndPostLayerWeight
transformer_weight_class = LlamaTransformerLayerWeight

# infer class
pre_layer_infer_class = LlamaMultiModalPreLayerInfer
post_layer_infer_class = LlamaPostLayerInfer
transformer_layer_infer_class = LlamaTransformerLayerInfer

# infer state class
infer_state_class = LlamaMultiModalInferStateInfo

def __init__(self, tp_rank, world_size, weight_dir, max_total_token_num, load_way="HF", mode=""):
super().__init__(tp_rank, world_size, weight_dir, max_total_token_num, load_way, mode)
return

@torch.no_grad()
def forward(
self,
batch_size,
total_token_num,
max_len_in_batch,
input_ids : torch.Tensor,
b_loc : torch.Tensor,
b_start_loc : torch.Tensor,
b_seq_len : torch.Tensor,
is_prefill=True,
**kwargs):
        # extra kwargs (e.g. repad_embeds) are forwarded to infer_state.init_some_extra_state
        if is_prefill:
return self._prefill(batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, **kwargs)
else:
return self._decode(batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, **kwargs)


def _prefill(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, **kwargs):
infer_state = self.infer_state_class()
infer_state.is_prefill = True
infer_state.batch_size = batch_size
infer_state.total_token_num = total_token_num
infer_state.max_len_in_batch = max_len_in_batch
assert (input_ids.shape[0] == total_token_num), "{} vs {}".format(input_ids.shape, total_token_num)
assert (b_loc.shape[0] == b_start_loc.shape[0] == b_seq_len.shape[0])
infer_state.b_loc = b_loc
infer_state.b_start_loc = b_start_loc
infer_state.b_seq_len = b_seq_len

infer_state.mem_manager = self.mem_manager
infer_state.prefill_mem_index = self.mem_manager.alloc(infer_state.total_token_num)
infer_state.prefill_key_buffer = torch.empty((infer_state.total_token_num, self.tp_k_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
infer_state.prefill_value_buffer = torch.empty((infer_state.total_token_num, self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
init_bloc(b_loc, b_seq_len, max_len_in_batch, infer_state.prefill_mem_index)

infer_state.init_some_extra_state(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, True, **kwargs)
predict_logics = self._context_forward(input_ids, infer_state)
return predict_logics

def _decode(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, **kwargs):
infer_state = self.infer_state_class()
infer_state.is_prefill = False
infer_state.batch_size = batch_size
infer_state.total_token_num = total_token_num
infer_state.max_len_in_batch = max_len_in_batch
assert (b_loc.shape[0] == b_start_loc.shape[0] == b_seq_len.shape[0])
infer_state.b_loc = b_loc
infer_state.b_start_loc = b_start_loc
infer_state.b_seq_len = b_seq_len

infer_state.mem_manager = self.mem_manager

alloc_mem = self.mem_manager.alloc_contiguous(batch_size)
if alloc_mem is not None:
infer_state.decode_is_contiguous = True
infer_state.decode_mem_index = alloc_mem[0]
infer_state.decode_mem_start = alloc_mem[1]
infer_state.decode_mem_end = alloc_mem[2]
b_loc[:, max_len_in_batch - 1] = infer_state.decode_mem_index
else:
infer_state.decode_is_contiguous = False
alloc_mem = self.mem_manager.alloc(batch_size)
infer_state.decode_mem_index = alloc_mem
infer_state.decode_key_buffer = torch.empty((batch_size, self.tp_k_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
infer_state.decode_value_buffer = torch.empty((batch_size, self.tp_v_head_num_, self.head_dim_), dtype=torch.float16, device="cuda")
b_loc[:, max_len_in_batch - 1] = infer_state.decode_mem_index

infer_state.init_some_extra_state(self, batch_size, total_token_num, max_len_in_batch, input_ids, b_loc, b_start_loc, b_seq_len, False, **kwargs)
predict_logics = self._token_forward(input_ids, infer_state)
return predict_logics
169 changes: 169 additions & 0 deletions test/model/model_infer_multimodal.py
@@ -0,0 +1,169 @@
import numpy as np
from multiprocessing import Queue
import multiprocessing

def test_multimodal_inference(world_size, model_dir, model_class, batch_size, input_len, output_len, repad_embeds_args):
ans_queue = Queue()
workers = []
for rank_id in range(world_size):
proc = multiprocessing.Process(target=tppart_multimodal_infer, args=(rank_id, world_size, ans_queue, model_dir, model_class, batch_size, input_len, output_len, repad_embeds_args))
proc.start()
workers.append(proc)

for proc in workers:
proc.join()

assert not ans_queue.empty()
while not ans_queue.empty():
assert ans_queue.get()
return


# for multimodal, we need to pass the repad_embeds to forward
def gen_repad_embeds(batch_size, model, input_len, repad_embeds_args):
import torch
all_input_ids = []
all_repad_embeds = []
pad_len, pad_dim_size, offset = repad_embeds_args

for i in range(batch_size):
# shape = [input_len]
input_ids = torch.from_numpy(np.arange(5, input_len + 5).reshape(-1)).cuda()
# shape = [pad_len, pad_dim_size]
pad_embeds = torch.rand(
size=(pad_len, pad_dim_size),
dtype=torch.float16,
device=input_ids.device,
) - 0.5
all_repad_embeds.append((pad_embeds, offset))

# input_ids should be padded
# shape = [pad_len]
pad_ids = torch.zeros(
size=(pad_len,),
dtype=input_ids.dtype,
device=input_ids.device,
)
# shape = [input_len + pad_len]
input_ids = torch.cat([input_ids[:offset], pad_ids, input_ids[offset:]], dim=0)
all_input_ids.append(input_ids)

all_input_ids = torch.cat(all_input_ids)
return all_input_ids, all_repad_embeds


def tppart_multimodal_infer(rank_id, world_size, ans_queue, model_dir, model_class, batch_size, input_len, output_len, repad_embeds_args):
import torch
import torch.distributed as dist
dist.init_process_group('nccl', init_method='tcp://127.0.0.1:28765', rank=rank_id, world_size=world_size)
torch.cuda.set_device(rank_id)

dist.barrier()
torch.cuda.empty_cache()

    model_part = model_class(dist.get_rank(),
                             dist.get_world_size(),
                             max_total_token_num=batch_size * (input_len + output_len + repad_embeds_args[0]),
                             weight_dir=model_dir,
                             load_way="HF")
# warm up
test_data, test_embeds = gen_repad_embeds(batch_size, model_part, input_len, repad_embeds_args)
    # after gen_repad_embeds, the effective input_len grows by pad_len (repad_embeds_args[0])
input_len += repad_embeds_args[0]

b_loc = torch.zeros(batch_size, input_len + output_len, dtype=torch.long, device="cuda")
b_start_loc = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
for i in range(batch_size):
b_loc[i, 0:input_len] = i * input_len + torch.arange(0, input_len, dtype=torch.int32, device="cuda")
b_start_loc[i] = i * input_len
b_seq_len[i] = input_len

total_token_num = input_len * batch_size
logics = model_part.forward(batch_size,
total_token_num,
input_len,
test_data,
b_loc,
b_start_loc,
b_seq_len,
is_prefill=True,
repad_embeds=test_embeds)
prob_out = torch.softmax(logics, dim=-1)
predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
predict_ids = predict_ids.detach().cpu().numpy()

for i in range(output_len):
b_loc[:, input_len + i] = total_token_num + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
total_token_num += batch_size
b_seq_len += 1
logics = model_part.forward(batch_size, total_token_num, input_len + i + 1, torch.from_numpy(
predict_ids).cuda().reshape(-1), b_loc, b_start_loc, b_seq_len, is_prefill=False)
prob_out = torch.softmax(logics, dim=-1)
predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
predict_ids = predict_ids.detach().cpu().numpy()

max_len_in_batch = input_len + output_len
for i in range(batch_size):
model_part.mem_manager.free(b_loc[i, max_len_in_batch - b_seq_len[i]:max_len_in_batch])
if rank_id == 0:
print("can use mem size:", model_part.mem_manager.can_use_mem_size)

b_loc = None
b_start_loc = None
b_seq_len = None

dist.barrier()
import time
torch.cuda.synchronize()
start_time = time.time()

prefill_start_time = time.time()

b_loc = torch.zeros(batch_size, input_len + output_len, dtype=torch.long, device="cuda")
b_start_loc = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
b_seq_len = torch.zeros(batch_size, dtype=torch.int32, device="cuda")
for i in range(batch_size):
b_start_loc[i] = i * input_len
b_seq_len[i] = input_len

total_token_num = batch_size * input_len
logics = model_part.forward(batch_size, total_token_num, input_len, test_data,
b_loc, b_start_loc, b_seq_len, is_prefill=True, repad_embeds=test_embeds)
prob_out = torch.softmax(logics, dim=-1)
predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
predict_ids = predict_ids.detach().cpu().numpy()

torch.cuda.synchronize()
if rank_id == 0:
print("prefill time cost:", (time.time() - prefill_start_time) * 1000)

for i in range(output_len):
torch.cuda.synchronize()
step_start = time.time()
b_start_loc = b_start_loc + torch.arange(0, batch_size, dtype=torch.int32, device="cuda")
total_token_num += batch_size
b_seq_len += 1

logics = model_part.forward(batch_size, total_token_num, input_len + i + 1, torch.from_numpy(
predict_ids).cuda().reshape(-1), b_loc, b_start_loc, b_seq_len, is_prefill=False)
prob_out = torch.softmax(logics, dim=-1)
predict_ids = torch.argmax(prob_out, dim=1, keepdim=True)
predict_ids = predict_ids.detach().cpu().numpy()
torch.cuda.synchronize()
if i % 100 == 0 or i == output_len - 1:
if rank_id == 0:
print(i, "step cost time:", (time.time() - step_start) * 1000)

torch.cuda.synchronize()
end_time = time.time()

if rank_id == 0:
print("time total cost(ms):", (end_time - start_time) * 1000)
ans_queue.put(True)

return


23 changes: 23 additions & 0 deletions test/model/test_llama_multimodal.py
@@ -0,0 +1,23 @@
import os
import sys
import unittest
from model_infer_multimodal import test_multimodal_inference
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))

class TestLlamaMultiModalInfer(unittest.TestCase):

def test_llama_infer(self):
from lightllm.models.llama_multimodal.model import LlamaTpPartMulitModal
test_multimodal_inference(world_size=4,
model_dir="/path/to/llama-7b",
model_class=LlamaTpPartMulitModal,
batch_size=10,
input_len=1024,
output_len=1024,
# (pad_len, pad_dim_size, offset)
repad_embeds_args=(36, 4096, 5))
return


if __name__ == '__main__':
unittest.main()
