
Vit parallel #577

Closed
wants to merge 14 commits
2 changes: 1 addition & 1 deletion README.md
@@ -71,7 +71,7 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram

> InternVL-Chat(InternLM2) needs to set the parameter '--eos_id 92542 --trust_remote_code'.

> Qwen2-VL-7b needs to set the parameter '--eos_id 151645 --trust_remote_code'.
> Qwen2-VL-7b needs to set the parameter '--eos_id 151645 --trust_remote_code', and use 'pip install git+https://github.com/huggingface/transformers' to upgrade to the latest version.

> Stablelm needs to set the parameter '--trust_remote_code'.

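For context, the notes above are flags for the LightLLM API server launch command. A minimal sketch of starting Qwen2-VL-7b with the flags from this README change, assuming the usual `lightllm.server.api_server` entry point (the model path is a placeholder):

```python
# Hedged sketch, not part of this PR: start LightLLM for Qwen2-VL-7b with the flags
# quoted in the README note above. Entry point and model path are assumptions.
import subprocess

subprocess.run(
    [
        "python", "-m", "lightllm.server.api_server",
        "--model_dir", "/path/to/Qwen2-VL-7B-Instruct",  # placeholder path
        "--eos_id", "151645",
        "--trust_remote_code",
    ],
    check=True,
)
```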
140 changes: 84 additions & 56 deletions lightllm/models/internlm_xcomposer/internlm_visual.py
@@ -7,72 +7,80 @@
from typing import List, Union
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from rpyc.utils.classic import obtain
from io import BytesIO
import rpyc
from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed
from lightllm.utils.log_utils import init_logger


class InternVisionModel:

def __init__(self):
def __init__(self, kvargs):
self.visual_gpu = kvargs["visual_gpu"]
self.vit_tp = kvargs["vit_tp"]
self.device = torch.device(f"cuda:{self.visual_gpu}")
pass

def load_projector_update(self, config, weight_dir):
projector_type = config.get("projector_type", 'mlp2x_gelu')
projector_type = config.get("projector_type", "mlp2x_gelu")
projector_weights = []
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
if mlp_gelu_match:
self.mlp_depth = int(mlp_gelu_match.group(1))
new_dict = {}
for f in os.listdir(weight_dir):
if f.endswith(".bin"):
d = torch.load(os.path.join(weight_dir, f), "cpu")
for k, v in d.items():
if 'vision_proj' in k:
if "vision_proj" in k:
projector_weights.append(v.half())
elif "vit.vision_tower." in k:
new_dict[k[len("vit.vision_tower."):]] = v.half()
new_dict[k[len("vit.vision_tower.") :]] = v.half()
self.vision_tower.load_state_dict(new_dict, strict=True)
return projector_weights
if projector_type == 'identity':
if projector_type == "identity":
return []
raise ValueError(f'Unknown projector type: {projector_type}')
raise ValueError(f"Unknown projector type: {projector_type}")

def load_model(self, weight_dir):
config_file = os.path.join(weight_dir, "config.json")
config = json.load(open(config_file))
self.select_layer = config.get('mm_vision_select_layer', -1)
self.select_feature = config.get('mm_vision_select_feature', 'patch')
self.select_layer = config.get("mm_vision_select_layer", -1)
self.select_feature = config.get("mm_vision_select_feature", "patch")

# load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir
vision_path = config.get('mm_vision_tower', 'openai/clip-vit-large-patch14-336')
vision_path = config.get("mm_vision_tower", "openai/clip-vit-large-patch14-336")
if isinstance(vision_path, list):
vision_path = vision_path[0]
if vision_path.startswith("./"):
vision_path = os.path.join(weight_dir, vision_path)
self.image_processor = transforms.Compose([
transforms.Resize((config["img_size"], config["img_size"]),
interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])

self.image_processor = transforms.Compose(
[
transforms.Resize((config["img_size"], config["img_size"]), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
]
)
from transformers import CLIPVisionModel

self.vision_tower = CLIPVisionModel.from_pretrained(vision_path)
self.vision_tower.requires_grad_(False)
self.resize_pos(config, vision_path)
self.projector_weights = self.load_projector_update(config, weight_dir)
self.vision_tower = self.vision_tower.half()
self.device = torch.device('cpu')
self.device = torch.device("cpu")

# load projector weights
assert len(self.projector_weights) == self.mlp_depth * 2

def resize_pos(self, config, vision_path):
mm_vision_tower = vision_path.split('/')[-1]
vision_tower_match = re.match(r'^clip-vit-large-patch(\d+)-(\d+)$', mm_vision_tower)
mm_vision_tower = vision_path.split("/")[-1]
vision_tower_match = re.match(r"^clip-vit-large-patch(\d+)-(\d+)$", mm_vision_tower)
patch_size = int(vision_tower_match.group(1))
clip_imge_size = int(vision_tower_match.group(2))

orig_size = clip_imge_size // patch_size
new_size = config["img_size"] // patch_size
if orig_size == new_size:
@@ -82,43 +90,37 @@ def resize_pos(self, config, vision_path):
pos_embed_checkpoint = self.vision_tower.vision_model.embeddings.position_embedding.weight
pos_embed_checkpoint = pos_embed_checkpoint.unsqueeze(0)

if pos_embed_checkpoint.shape[1] == new_size**2 + 1:
if pos_embed_checkpoint.shape[1] == new_size ** 2 + 1:
self.is_resize_pos = True
else:
embedding_size = pos_embed_checkpoint.shape[-1]
num_extra_tokens = 1
new_num = new_size**2 + num_extra_tokens
print('Position interpolate from %dx%d to %dx%d' %
(orig_size, orig_size, new_size, new_size))
new_num = new_size ** 2 + num_extra_tokens
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
embedding_size).permute(
0, 3, 1, 2)
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens,
size=(new_size, new_size),
mode='bicubic',
align_corners=False)
pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)

new_pos_embed = new_pos_embed.squeeze(0)

self.vision_tower.vision_model.embeddings.position_embedding = torch.nn.Embedding(
new_num, 1024)
self.vision_tower.vision_model.embeddings.position_embedding = torch.nn.Embedding(new_num, 1024)
self.vision_tower.vision_model.embeddings.position_embedding.weight = torch.nn.Parameter(
new_pos_embed.to(pos_embed_checkpoint.dtype))
self.vision_tower.vision_model.embeddings.position_ids = torch.arange(
new_num).expand((1, -1))
new_pos_embed.to(pos_embed_checkpoint.dtype)
)
self.vision_tower.vision_model.embeddings.position_ids = torch.arange(new_num).expand((1, -1))
self.is_resize_pos = True

def cuda(self):
self.vision_tower = self.vision_tower.cuda()
def cuda(self, device):
self.vision_tower = self.vision_tower.cuda(device)
for i in range(len(self.projector_weights)):
self.projector_weights[i] = self.projector_weights[i].cuda()
self.device = torch.device('cuda')
self.projector_weights[i] = self.projector_weights[i].cuda(device)
self.device = torch.device(f"cuda:{self.visual_gpu}")
return self

# batch images infer
@@ -127,13 +129,13 @@ def forward(self, x):

x = self.vision_tower(x, output_hidden_states=True)
x = x.hidden_states[self.select_layer]
if self.select_feature == 'patch':
if self.select_feature == "patch":
x = x[:, 1:].contiguous()

if len(self.projector_weights) == 0:
return x
B, L, N = x.shape

B, L, N = x.shape
x = x.view(-1, N)
# mm_project
x = F.linear(
@@ -151,16 +153,42 @@
x = x.view(B, L, -1)
return x

def encode(self, image_items: List[Union[str, Image.Image]]):
images = []
for item in image_items:
def encode(self, image_items: List[Union[int, str, torch.Tensor, Image.Image]]):
img_tensors = []
uuids = []
valid_id = 0
valid_ids = []
for i, item in enumerate(image_items):
if self.vit_tp != 1:
item = obtain(item)
if isinstance(item, Image.Image):
image = item
image = item.convert("RGB")
t = self.image_processor.preprocess(image, return_tensors="pt")["pixel_values"]
img_tensors.append(t)
elif isinstance(item, torch.Tensor):
img_tensors.append(item)
elif isinstance(item, int):
uuids.append(item)
image_data = read_shm(get_shm_name_data(item))
image_data = Image.open(BytesIO(image_data)).convert("RGB")
t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
img_tensors.append(t)
elif item.startswith("http://") or item.startswith("https://"):
import requests

image = Image.open(requests.get(item, stream=True).raw)
else:
image = Image.open(item)
image = self.image_processor(image.convert('RGB')).unsqueeze(0).to(self.device)
images.append(image)
images = torch.cat(images, dim=0)
return self.forward(images)
raise Exception("Unsupport input types: {} for {}".format(type(item), item))

cur_num = img_tensors[-1].shape[0]

valid_ids.append([valid_id, valid_id + cur_num])
valid_id += cur_num

if len(img_tensors) <= 0:
return None

img = torch.cat(img_tensors, dim=0)
all_img_embeds = self.forward(img)

return all_img_embeds, uuids, valid_ids
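The reworked `encode` now returns the batched embeddings together with the shared-memory `uuids` and the per-image `valid_ids` row ranges instead of a single tensor. A hedged sketch of how a caller might split that batch back into per-image embeddings, assuming every item was passed as a cache uuid (an int) so `uuids` and `valid_ids` align one-to-one; the helper name is illustrative and not part of this PR:

```python
# Illustrative only: split the batched output of encode() into per-image tensors.
# valid_ids holds [start, end) row ranges into all_img_embeds, one per input image.
from typing import Dict, List
import torch

def split_image_embeds(
    all_img_embeds: torch.Tensor, uuids: List[int], valid_ids: List[List[int]]
) -> Dict[int, torch.Tensor]:
    per_image = {}
    for uuid, (start, end) in zip(uuids, valid_ids):
        per_image[uuid] = all_img_embeds[start:end]  # rows belonging to this image
    return per_image
```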
37 changes: 28 additions & 9 deletions lightllm/models/internvl/internvl_visual.py
@@ -13,17 +13,23 @@
import rpyc
from io import BytesIO
from lightllm.models.internvl.img_process import load_image
from rpyc.utils.classic import obtain
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class InternVLVisionModel:
def __init__(self, kvargs):
self.cache_port = kvargs["client_port"]
self.cache_client = None
self.tp_rank_id = kvargs["tp_rank_id"]
self.vit_tp = kvargs["vit_tp"]
self.visual_gpu = kvargs["visual_gpu"]
self.device = torch.device(f"cuda:{self.visual_gpu}")
print(f"self.device is {self.device}")
pass

def load_model(self, weight_dir):
assert torch.cuda.is_available()
self.device = torch.device("cuda")
self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
self.config = json.load(open(os.path.join(weight_dir, "config.json")))
self.model = AutoModel.from_pretrained(
@@ -32,20 +38,32 @@ def load_model(self, weight_dir):
trust_remote_code=True,
language_model="fake_language_model",
)
self.model.eval().cuda()
self.model.eval().cuda(self.device)

def cuda(self):
def cuda(self, device):
return self

def encode(self, image_items: List[Union[str, torch.Tensor, Image.Image]]):
def encode(self, image_items: List[Union[int, str, torch.Tensor, Image.Image]]):
img_tensors = []
valid_ids = []
valid_id = 0
uuids = []
# load images to batch tensor

for i, url in enumerate(image_items):
if self.vit_tp != 1:
url = obtain(url)
if isinstance(url, Image.Image):
t = load_image(url, max_num=6)
img_tensors.append(t)
elif isinstance(url, torch.Tensor):
img_tensors.append(url)
elif isinstance(url, int):
uuids.append(url)
image_data = read_shm(get_shm_name_data(url))
image_data = Image.open(BytesIO(image_data))
t = load_image(image_data)
img_tensors.append(t)
else:
raise Exception("Unsupport input types: {} for {}".format(type(url), url))

@@ -56,9 +74,10 @@ def encode(self, image_items: List[Union[str, torch.Tensor, Image.Image]]):

if len(img_tensors) <= 0:
return None

# (b, 3, 224, 224)
torch.cuda.set_device(self.device)
imgs = torch.cat(img_tensors, dim=0)
pixel_values = imgs.to(self.device, dtype=self.dtype)
pixel_values = imgs.to(device=self.device, dtype=self.dtype)
all_img_embeds = self.model.extract_feature(pixel_values)
return [all_img_embeds[start:end] for start, end in valid_ids]

return all_img_embeds, uuids, valid_ids
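Both vision models are now constructed from a `kvargs` dict rather than taking no arguments. A hedged sketch of wiring up `InternVLVisionModel` with the keys read in `__init__` above; all values are placeholders for a single-GPU setup with `vit_tp == 1` (so `obtain()` is skipped):

```python
# Illustrative only: kvargs keys come from InternVLVisionModel.__init__ in this PR;
# the values below are placeholders for a single-GPU, non-tensor-parallel ViT.
from PIL import Image
from lightllm.models.internvl.internvl_visual import InternVLVisionModel

kvargs = {
    "client_port": 12345,  # rpyc port of the embed cache service
    "tp_rank_id": 0,       # rank of this ViT worker
    "vit_tp": 1,           # ViT tensor-parallel size; obtain() is only used when != 1
    "visual_gpu": 0,       # CUDA device index the ViT runs on
}
model = InternVLVisionModel(kvargs)
model.load_model("/path/to/InternVL-Chat")  # placeholder weight directory
image_items = [Image.open("example.jpg")]   # may also contain cache uuids (ints) or tensors
all_img_embeds, uuids, valid_ids = model.encode(image_items)
```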