Vit parallel #590

Closed
wants to merge 15 commits
README.md (2 changes: 1 addition & 1 deletion)
@@ -71,7 +71,7 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving framework

> InternVL-Chat(InternLM2) needs to set the parameter '--eos_id 92542 --trust_remote_code'.

> Qwen2-VL-7b needs to set the parameter '--eos_id 151645 --trust_remote_code'.
> Qwen2-VL-7b needs to set the parameter '--eos_id 151645 --trust_remote_code', and use 'pip install git+https://github.com/huggingface/transformers' to upgrade to the latest version.

> Stablelm needs to set the parameter '--trust_remote_code'.

lightllm/models/internlm_xcomposer/internlm_visual.py (123 changes: 67 additions & 56 deletions)
@@ -7,72 +7,75 @@
from typing import List, Union
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from rpyc.utils.classic import obtain
from io import BytesIO
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data


class InternVisionModel:

def __init__(self):
pass

def load_projector_update(self, config, weight_dir):
projector_type = config.get("projector_type", 'mlp2x_gelu')
projector_type = config.get("projector_type", "mlp2x_gelu")
projector_weights = []
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
if mlp_gelu_match:
self.mlp_depth = int(mlp_gelu_match.group(1))
new_dict = {}
for f in os.listdir(weight_dir):
if f.endswith(".bin"):
d = torch.load(os.path.join(weight_dir, f), "cpu")
for k, v in d.items():
if 'vision_proj' in k:
if "vision_proj" in k:
projector_weights.append(v.half())
elif "vit.vision_tower." in k:
new_dict[k[len("vit.vision_tower."):]] = v.half()
new_dict[k[len("vit.vision_tower.") :]] = v.half()
self.vision_tower.load_state_dict(new_dict, strict=True)
return projector_weights
if projector_type == 'identity':
if projector_type == "identity":
return []
raise ValueError(f'Unknown projector type: {projector_type}')
raise ValueError(f"Unknown projector type: {projector_type}")

def load_model(self, weight_dir):
config_file = os.path.join(weight_dir, "config.json")
config = json.load(open(config_file))
self.select_layer = config.get('mm_vision_select_layer', -1)
self.select_feature = config.get('mm_vision_select_feature', 'patch')
self.select_layer = config.get("mm_vision_select_layer", -1)
self.select_feature = config.get("mm_vision_select_feature", "patch")

# load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir
vision_path = config.get('mm_vision_tower', 'openai/clip-vit-large-patch14-336')
vision_path = config.get("mm_vision_tower", "openai/clip-vit-large-patch14-336")
if isinstance(vision_path, list):
vision_path = vision_path[0]
if vision_path.startswith("./"):
vision_path = os.path.join(weight_dir, vision_path)
self.image_processor = transforms.Compose([
transforms.Resize((config["img_size"], config["img_size"]),
interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])

self.image_processor = transforms.Compose(
[
transforms.Resize((config["img_size"], config["img_size"]), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
]
)
from transformers import CLIPVisionModel

self.vision_tower = CLIPVisionModel.from_pretrained(vision_path)
self.vision_tower.requires_grad_(False)
self.resize_pos(config, vision_path)
self.projector_weights = self.load_projector_update(config, weight_dir)
self.vision_tower = self.vision_tower.half()
self.device = torch.device('cpu')
self.device = torch.device("cpu")

# load projector weights
assert len(self.projector_weights) == self.mlp_depth * 2

def resize_pos(self, config, vision_path):
mm_vision_tower = vision_path.split('/')[-1]
vision_tower_match = re.match(r'^clip-vit-large-patch(\d+)-(\d+)$', mm_vision_tower)
mm_vision_tower = vision_path.split("/")[-1]
vision_tower_match = re.match(r"^clip-vit-large-patch(\d+)-(\d+)$", mm_vision_tower)
patch_size = int(vision_tower_match.group(1))
clip_imge_size = int(vision_tower_match.group(2))

orig_size = clip_imge_size // patch_size
new_size = config["img_size"] // patch_size
if orig_size == new_size:
@@ -82,58 +85,51 @@ def resize_pos(self, config, vision_path):
pos_embed_checkpoint = self.vision_tower.vision_model.embeddings.position_embedding.weight
pos_embed_checkpoint = pos_embed_checkpoint.unsqueeze(0)

if pos_embed_checkpoint.shape[1] == new_size**2 + 1:
if pos_embed_checkpoint.shape[1] == new_size ** 2 + 1:
self.is_resize_pos = True
else:
embedding_size = pos_embed_checkpoint.shape[-1]
num_extra_tokens = 1
new_num = new_size**2 + num_extra_tokens
print('Position interpolate from %dx%d to %dx%d' %
(orig_size, orig_size, new_size, new_size))
new_num = new_size ** 2 + num_extra_tokens
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
embedding_size).permute(
0, 3, 1, 2)
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens,
size=(new_size, new_size),
mode='bicubic',
align_corners=False)
pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)

new_pos_embed = new_pos_embed.squeeze(0)

self.vision_tower.vision_model.embeddings.position_embedding = torch.nn.Embedding(
new_num, 1024)
self.vision_tower.vision_model.embeddings.position_embedding = torch.nn.Embedding(new_num, 1024)
self.vision_tower.vision_model.embeddings.position_embedding.weight = torch.nn.Parameter(
new_pos_embed.to(pos_embed_checkpoint.dtype))
self.vision_tower.vision_model.embeddings.position_ids = torch.arange(
new_num).expand((1, -1))
new_pos_embed.to(pos_embed_checkpoint.dtype)
)
self.vision_tower.vision_model.embeddings.position_ids = torch.arange(new_num).expand((1, -1))
self.is_resize_pos = True
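
The interpolation performed by resize_pos can be hard to follow inside the diff, so here is a minimal, self-contained sketch of the same idea, assuming a CLIP-style table with one class token followed by a square grid of patch tokens; interpolate_pos_embed and its arguments are illustrative names, not part of this patch.

import torch
import torch.nn.functional as F

def interpolate_pos_embed(pos_embed: torch.Tensor, orig_size: int, new_size: int) -> torch.Tensor:
    # pos_embed: (1 + orig_size**2, dim); row 0 is the class-token embedding.
    cls_token, patch_tokens = pos_embed[:1], pos_embed[1:]
    dim = pos_embed.shape[-1]
    # Reshape the patch tokens into a 2-D grid and resample it bicubically.
    grid = patch_tokens.reshape(1, orig_size, orig_size, dim).permute(0, 3, 1, 2)
    grid = F.interpolate(grid, size=(new_size, new_size), mode="bicubic", align_corners=False)
    patch_tokens = grid.permute(0, 2, 3, 1).reshape(new_size * new_size, dim)
    return torch.cat((cls_token, patch_tokens), dim=0)

# Example: a 336px / patch-14 table (24x24 grid) resized for 448px inputs (32x32 grid).
new_table = interpolate_pos_embed(torch.randn(24 * 24 + 1, 1024), orig_size=24, new_size=32)
assert new_table.shape == (32 * 32 + 1, 1024)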

def cuda(self):
self.vision_tower = self.vision_tower.cuda()
for i in range(len(self.projector_weights)):
self.projector_weights[i] = self.projector_weights[i].cuda()
self.device = torch.device('cuda')
return self

# batch images infer
def forward(self, x):
x = x.to(device=self.device, dtype=self.vision_tower.dtype)
x = x.cuda().to(dtype=self.vision_tower.dtype)

x = self.vision_tower(x, output_hidden_states=True)
x = x.hidden_states[self.select_layer]
if self.select_feature == 'patch':
if self.select_feature == "patch":
x = x[:, 1:].contiguous()

if len(self.projector_weights) == 0:
return x
B, L, N = x.shape

B, L, N = x.shape
x = x.view(-1, N)
# mm_project
x = F.linear(
@@ -151,16 +147,31 @@ def forward(self, x):
x = x.view(B, L, -1)
return x

def encode(self, image_items: List[Union[str, Image.Image]]):
images = []
for item in image_items:
if isinstance(item, Image.Image):
image = item
elif item.startswith("http://") or item.startswith("https://"):
image = Image.open(requests.get(item, stream=True).raw)
def encode(self, image_uuids: List):
img_tensors = []
uuids = []
valid_id = 0
valid_ids = []

for i, item in enumerate(image_uuids):
item = obtain(item)
if isinstance(item, int):
uuids.append(item)
image_data = read_shm(get_shm_name_data(item))
image_data = Image.open(BytesIO(image_data)).convert("RGB")
t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
img_tensors.append(t)
else:
image = Image.open(item)
image = self.image_processor(image.convert('RGB')).unsqueeze(0).to(self.device)
images.append(image)
images = torch.cat(images, dim=0)
return self.forward(images)
raise Exception("Unsupport input types: {} for {}".format(type(item), item))

cur_num = img_tensors[-1].shape[0]
valid_ids.append([valid_id, valid_id + cur_num])
valid_id += cur_num

if len(img_tensors) <= 0:
return None

img = torch.cat(img_tensors, dim=0)
all_img_embeds = self.forward(img)

return all_img_embeds, uuids, valid_ids
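
For reference, a minimal sketch of how the (all_img_embeds, uuids, valid_ids) tuple returned by the new encode() can be split back into per-image embeddings; the shapes and ids below are made up for illustration and are not part of this patch.

import torch

# Pretend encode() returned 7 stacked image tiles belonging to two images with uuids 101 and 102.
embeds = torch.randn(7, 1225, 4096)    # concatenated visual features, one row range per image
uuids = [101, 102]                     # one shared-memory uuid per input image
valid_ids = [[0, 3], [3, 7]]           # [start, end) slices of `embeds` along dim 0

per_image = {uid: embeds[start:end] for uid, (start, end) in zip(uuids, valid_ids)}
assert per_image[101].shape[0] == 3 and per_image[102].shape[0] == 4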
lightllm/models/internvl/internvl_visual.py (35 changes: 19 additions & 16 deletions)
@@ -8,22 +8,21 @@
from torchvision import transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
import requests
from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed
import rpyc
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
from io import BytesIO
from lightllm.models.internvl.img_process import load_image
from rpyc.utils.classic import obtain
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class InternVLVisionModel:
def __init__(self, kvargs):
self.cache_port = kvargs["client_port"]
self.cache_client = None
def __init__(self):
pass

def load_model(self, weight_dir):
assert torch.cuda.is_available()
self.device = torch.device("cuda")
self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
self.config = json.load(open(os.path.join(weight_dir, "config.json")))
self.model = AutoModel.from_pretrained(
@@ -37,28 +36,32 @@ def load_model(self, weight_dir):
def cuda(self):
return self

def encode(self, image_items: List[Union[str, torch.Tensor, Image.Image]]):
def encode(self, image_uuids: List):
img_tensors = []
valid_ids = []
valid_id = 0
# load images to batch tensor
for i, url in enumerate(image_items):
if isinstance(url, Image.Image):
t = load_image(url, max_num=6)
uuids = []

for i, url in enumerate(image_uuids):
url = obtain(url)
if isinstance(url, int):
uuids.append(url)
image_data = read_shm(get_shm_name_data(url))
image_data = Image.open(BytesIO(image_data))
t = load_image(image_data)
img_tensors.append(t)
else:
raise Exception("Unsupport input types: {} for {}".format(type(url), url))

cur_num = img_tensors[-1].shape[0]

valid_ids.append([valid_id, valid_id + cur_num])
valid_id += cur_num

if len(img_tensors) <= 0:
return None

# (b, 3, 224, 224)
imgs = torch.cat(img_tensors, dim=0)
pixel_values = imgs.to(self.device, dtype=self.dtype)
pixel_values = imgs.cuda().to(dtype=self.dtype)
all_img_embeds = self.model.extract_feature(pixel_values)
return [all_img_embeds[start:end] for start, end in valid_ids]

return all_img_embeds, uuids, valid_ids