Vit parallel (#592)
Co-authored-by: WANDY666 <[email protected]>
Co-authored-by: sangchengmeng <[email protected]>
Co-authored-by: baishihao <[email protected]>
Co-authored-by: hiworldwzj <[email protected]>
5 people authored Oct 28, 2024
1 parent 96c5e23 commit 00569d0
Showing 12 changed files with 463 additions and 323 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -71,7 +71,7 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving framework
> InternVL-Chat(InternLM2) needs to set the parameter '--eos_id 92542 --trust_remote_code'.
> Qwen2-VL-7b needs to set the parameter '--eos_id 151645 --trust_remote_code'.
> Qwen2-VL-7b needs to set the parameter '--eos_id 151645 --trust_remote_code', and use 'pip install git+https://github.com/huggingface/transformers' to upgrade to the latest version.
> Stablelm needs to set the parameter '--trust_remote_code'.
121 changes: 65 additions & 56 deletions lightllm/models/internlm_xcomposer/internlm_visual.py
@@ -7,72 +7,74 @@
from typing import List, Union
from torchvision import transforms
from torchvision.transforms.functional import InterpolationMode
from io import BytesIO
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data


class InternVisionModel:

def __init__(self):
pass

def load_projector_update(self, config, weight_dir):
projector_type = config.get("projector_type", 'mlp2x_gelu')
projector_type = config.get("projector_type", "mlp2x_gelu")
projector_weights = []
mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type)
mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type)
if mlp_gelu_match:
self.mlp_depth = int(mlp_gelu_match.group(1))
new_dict = {}
for f in os.listdir(weight_dir):
if f.endswith(".bin"):
d = torch.load(os.path.join(weight_dir, f), "cpu")
for k, v in d.items():
if 'vision_proj' in k:
if "vision_proj" in k:
projector_weights.append(v.half())
elif "vit.vision_tower." in k:
new_dict[k[len("vit.vision_tower."):]] = v.half()
new_dict[k[len("vit.vision_tower.") :]] = v.half()
self.vision_tower.load_state_dict(new_dict, strict=True)
return projector_weights
if projector_type == 'identity':
if projector_type == "identity":
return []
raise ValueError(f'Unknown projector type: {projector_type}')
raise ValueError(f"Unknown projector type: {projector_type}")

def load_model(self, weight_dir):
config_file = os.path.join(weight_dir, "config.json")
config = json.load(open(config_file))
self.select_layer = config.get('mm_vision_select_layer', -1)
self.select_feature = config.get('mm_vision_select_feature', 'patch')
self.select_layer = config.get("mm_vision_select_layer", -1)
self.select_feature = config.get("mm_vision_select_feature", "patch")

# load clip vision model by cfg['mm_vision_tower']:
# huggingface_name or path_of_clip_relative_to_llava_model_dir
vision_path = config.get('mm_vision_tower', 'openai/clip-vit-large-patch14-336')
vision_path = config.get("mm_vision_tower", "openai/clip-vit-large-patch14-336")
if isinstance(vision_path, list):
vision_path = vision_path[0]
if vision_path.startswith("./"):
vision_path = os.path.join(weight_dir, vision_path)
self.image_processor = transforms.Compose([
transforms.Resize((config["img_size"], config["img_size"]),
interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073),
(0.26862954, 0.26130258, 0.27577711)),
])

self.image_processor = transforms.Compose(
[
transforms.Resize((config["img_size"], config["img_size"]), interpolation=InterpolationMode.BICUBIC),
transforms.ToTensor(),
transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
]
)
from transformers import CLIPVisionModel

self.vision_tower = CLIPVisionModel.from_pretrained(vision_path)
self.vision_tower.requires_grad_(False)
self.resize_pos(config, vision_path)
self.projector_weights = self.load_projector_update(config, weight_dir)
self.vision_tower = self.vision_tower.half()
self.device = torch.device('cpu')
self.device = torch.device("cpu")

# load projector weights
assert len(self.projector_weights) == self.mlp_depth * 2

def resize_pos(self, config, vision_path):
mm_vision_tower = vision_path.split('/')[-1]
vision_tower_match = re.match(r'^clip-vit-large-patch(\d+)-(\d+)$', mm_vision_tower)
mm_vision_tower = vision_path.split("/")[-1]
vision_tower_match = re.match(r"^clip-vit-large-patch(\d+)-(\d+)$", mm_vision_tower)
patch_size = int(vision_tower_match.group(1))
clip_imge_size = int(vision_tower_match.group(2))

orig_size = clip_imge_size // patch_size
new_size = config["img_size"] // patch_size
if orig_size == new_size:
@@ -82,58 +84,51 @@ def resize_pos(self, config, vision_path):
pos_embed_checkpoint = self.vision_tower.vision_model.embeddings.position_embedding.weight
pos_embed_checkpoint = pos_embed_checkpoint.unsqueeze(0)

if pos_embed_checkpoint.shape[1] == new_size**2 + 1:
if pos_embed_checkpoint.shape[1] == new_size ** 2 + 1:
self.is_resize_pos = True
else:
embedding_size = pos_embed_checkpoint.shape[-1]
num_extra_tokens = 1
new_num = new_size**2 + num_extra_tokens
print('Position interpolate from %dx%d to %dx%d' %
(orig_size, orig_size, new_size, new_size))
new_num = new_size ** 2 + num_extra_tokens
print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size))
extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens]
# only the position tokens are interpolated
pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:]
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size,
embedding_size).permute(
0, 3, 1, 2)
pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2)
pos_tokens = torch.nn.functional.interpolate(
pos_tokens,
size=(new_size, new_size),
mode='bicubic',
align_corners=False)
pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False
)
pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2)
new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1)

new_pos_embed = new_pos_embed.squeeze(0)

self.vision_tower.vision_model.embeddings.position_embedding = torch.nn.Embedding(
new_num, 1024)
self.vision_tower.vision_model.embeddings.position_embedding = torch.nn.Embedding(new_num, 1024)
self.vision_tower.vision_model.embeddings.position_embedding.weight = torch.nn.Parameter(
new_pos_embed.to(pos_embed_checkpoint.dtype))
self.vision_tower.vision_model.embeddings.position_ids = torch.arange(
new_num).expand((1, -1))
new_pos_embed.to(pos_embed_checkpoint.dtype)
)
self.vision_tower.vision_model.embeddings.position_ids = torch.arange(new_num).expand((1, -1))
self.is_resize_pos = True
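For reference, a worked example of the grid arithmetic in resize_pos above, written as a standalone Python sketch; the img_size value of 490 is a hypothetical config entry, while the patch size and 336-pixel input come from the default "clip-vit-large-patch14-336" checkpoint name parsed by the regex:

patch_size, clip_img_size = 14, 336      # parsed from "clip-vit-large-patch14-336"
img_size = 490                           # hypothetical value of config["img_size"]
orig_size = clip_img_size // patch_size  # 24  -> original 24x24 position grid
new_size = img_size // patch_size        # 35  -> target 35x35 position grid
new_num = new_size ** 2 + 1              # 1226 positions, including one CLS token
# The 24x24 position grid is bicubically interpolated to 35x35, the CLS token
# embedding is concatenated back in front, and the result becomes the new
# (1226, 1024) position_embedding table installed above.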

def cuda(self):
self.vision_tower = self.vision_tower.cuda()
for i in range(len(self.projector_weights)):
self.projector_weights[i] = self.projector_weights[i].cuda()
self.device = torch.device('cuda')
return self

# batch images infer
def forward(self, x):
x = x.to(device=self.device, dtype=self.vision_tower.dtype)
x = x.cuda().to(dtype=self.vision_tower.dtype)

x = self.vision_tower(x, output_hidden_states=True)
x = x.hidden_states[self.select_layer]
if self.select_feature == 'patch':
if self.select_feature == "patch":
x = x[:, 1:].contiguous()

if len(self.projector_weights) == 0:
return x
B, L, N = x.shape

B, L, N = x.shape
x = x.view(-1, N)
# mm_project
x = F.linear(
@@ -151,16 +146,30 @@ def forward(self, x):
x = x.view(B, L, -1)
return x

def encode(self, image_items: List[Union[str, Image.Image]]):
images = []
for item in image_items:
if isinstance(item, Image.Image):
image = item
elif item.startswith("http://") or item.startswith("https://"):
image = Image.open(requests.get(item, stream=True).raw)
def encode(self, image_uuids: List):
img_tensors = []
uuids = []
valid_id = 0
valid_ids = []

for i, item in enumerate(image_uuids):
if isinstance(item, int):
uuids.append(item)
image_data = read_shm(get_shm_name_data(item))
image_data = Image.open(BytesIO(image_data)).convert("RGB")
t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
img_tensors.append(t)
else:
image = Image.open(item)
image = self.image_processor(image.convert('RGB')).unsqueeze(0).to(self.device)
images.append(image)
images = torch.cat(images, dim=0)
return self.forward(images)
raise Exception("Unsupport input types: {} for {}".format(type(item), item))

cur_num = img_tensors[-1].shape[0]
valid_ids.append([valid_id, valid_id + cur_num])
valid_id += cur_num

if len(img_tensors) <= 0:
return None

img = torch.cat(img_tensors, dim=0)
all_img_embeds = self.forward(img)

return all_img_embeds, uuids, valid_ids
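Across these vision modules, encode now takes a list of shared-memory image uuids and returns a single (all_img_embeds, uuids, valid_ids) tuple instead of one result per image. A minimal caller-side sketch of how those ranges could be split back into per-image embeddings; visual_model and the uuid values are hypothetical stand-ins for an already-loaded vision model and images previously written to the embed cache:

image_uuids = [101, 102]                   # hypothetical embed-cache uuids
result = visual_model.encode(image_uuids)  # visual_model: an already-loaded vision model
if result is not None:                     # encode returns None when no tensors were built
    all_img_embeds, uuids, valid_ids = result
    # valid_ids holds one [start, end) row range per uuid, so inputs that expand
    # to several tiles/tensors stay grouped with their image.
    per_image = {uid: all_img_embeds[start:end]
                 for uid, (start, end) in zip(uuids, valid_ids)}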
33 changes: 17 additions & 16 deletions lightllm/models/internvl/internvl_visual.py
@@ -8,22 +8,20 @@
from torchvision import transforms as T
from torchvision.transforms.functional import InterpolationMode
from transformers import AutoModel, AutoTokenizer
import requests
from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed
import rpyc
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
from io import BytesIO
from lightllm.models.internvl.img_process import load_image
from lightllm.utils.log_utils import init_logger

logger = init_logger(__name__)


class InternVLVisionModel:
def __init__(self, kvargs):
self.cache_port = kvargs["client_port"]
self.cache_client = None
def __init__(self):
pass

def load_model(self, weight_dir):
assert torch.cuda.is_available()
self.device = torch.device("cuda")
self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32
self.config = json.load(open(os.path.join(weight_dir, "config.json")))
self.model = AutoModel.from_pretrained(
@@ -37,28 +35,31 @@ def load_model(self, weight_dir):
def cuda(self):
return self

def encode(self, image_items: List[Union[str, torch.Tensor, Image.Image]]):
def encode(self, image_uuids: List):
img_tensors = []
valid_ids = []
valid_id = 0
# load images to batch tensor
for i, url in enumerate(image_items):
if isinstance(url, Image.Image):
t = load_image(url, max_num=6)
uuids = []

for i, url in enumerate(image_uuids):
if isinstance(url, int):
uuids.append(url)
image_data = read_shm(get_shm_name_data(url))
image_data = Image.open(BytesIO(image_data))
t = load_image(image_data)
img_tensors.append(t)
else:
raise Exception("Unsupport input types: {} for {}".format(type(url), url))

cur_num = img_tensors[-1].shape[0]

valid_ids.append([valid_id, valid_id + cur_num])
valid_id += cur_num

if len(img_tensors) <= 0:
return None

# (b, 3, 224, 224)
imgs = torch.cat(img_tensors, dim=0)
pixel_values = imgs.to(self.device, dtype=self.dtype)
pixel_values = imgs.cuda().to(dtype=self.dtype)
all_img_embeds = self.model.extract_feature(pixel_values)
return [all_img_embeds[start:end] for start, end in valid_ids]

return all_img_embeds, uuids, valid_ids
59 changes: 41 additions & 18 deletions lightllm/models/llava/llava_visual.py
@@ -5,6 +5,12 @@
from PIL import Image
from typing import List, Union
from safetensors import safe_open
from io import BytesIO
from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data
from lightllm.utils.log_utils import init_logger


logger = init_logger(__name__)


class LlavaVisionModel:
@@ -31,6 +37,7 @@ def load_hf_model(self, config, weight_dir):

def load_hf_model(self, config, weight_dir):
from transformers import AutoConfig, AutoProcessor, LlavaForConditionalGeneration

config = AutoConfig.from_pretrained(weight_dir, trust_remote_code=True)
self.select_layer = config.vision_feature_layer
self.select_feature = config.vision_feature_select_strategy
@@ -48,12 +55,16 @@ def load_hf_model(self, config, weight_dir):
self.projector_weights = {}
for f in os.listdir(weight_dir):
if f.endswith(".safetensors"):
d = safe_open(os.path.join(weight_dir, f), 'pt', 'cpu')
d = safe_open(os.path.join(weight_dir, f), "pt", "cpu")
for k in d.keys():
if "multi_modal_projector.linear_1" in k:
self.projector_weights[k.replace("multi_modal_projector.linear_1", "model.mm_projector.0")] = d.get_tensor(k).half()
self.projector_weights[
k.replace("multi_modal_projector.linear_1", "model.mm_projector.0")
] = d.get_tensor(k).half()
if "multi_modal_projector.linear_2" in k:
self.projector_weights[k.replace("multi_modal_projector.linear_2", "model.mm_projector.2")] = d.get_tensor(k).half()
self.projector_weights[
k.replace("multi_modal_projector.linear_2", "model.mm_projector.2")
] = d.get_tensor(k).half()

def load_bin_model(self, config, weight_dir):
self.select_layer = config.get("mm_vision_select_layer", -2)
@@ -68,6 +79,7 @@ def load_bin_model(self, config, weight_dir):
vision_path = os.path.join(weight_dir, vision_path)

from transformers import CLIPVisionModel, CLIPImageProcessor

self.image_processor = CLIPImageProcessor.from_pretrained(vision_path)
self.vision_tower = CLIPVisionModel.from_pretrained(vision_path).half()

@@ -84,13 +96,11 @@ def cuda(self):
self.vision_tower = self.vision_tower.cuda()
for k, v in self.projector_weights.items():
self.projector_weights[k] = v.cuda()
self.device = torch.device("cuda")
return self

# batch images infer
def forward(self, x):
x = x.half().to(device=self.device)

x = x.half().cuda()
x = self.vision_tower(x, output_hidden_states=True)
x = x.hidden_states[self.select_layer]
if self.select_feature == "patch" or self.select_feature == "default":
@@ -113,17 +123,30 @@ def forward(self, x):
x = x.view(B, L, -1)
return x

def encode(self, image_items: List[Union[str, Image.Image]]):
images = []
for item in image_items:
if isinstance(item, Image.Image):
image = item
elif item.startswith("http://") or item.startswith("https://"):
import requests
image = Image.open(requests.get(item, stream=True).raw)
def encode(self, image_uuids: List):
img_tensors = []
uuids = []
valid_id = 0
valid_ids = []

for i, item in enumerate(image_uuids):
if isinstance(item, int):
uuids.append(item)
image_data = read_shm(get_shm_name_data(item))
image_data = Image.open(BytesIO(image_data)).convert("RGB")
t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"]
img_tensors.append(t)
else:
image = Image.open(item)
images.append(image.convert("RGB"))
raise Exception("Unsupport input types: {} for {}".format(type(item), item))

cur_num = img_tensors[-1].shape[0]
valid_ids.append([valid_id, valid_id + cur_num])
valid_id += cur_num

if len(img_tensors) <= 0:
return None

img = torch.cat(img_tensors, dim=0)
all_img_embeds = self.forward(img)

images = self.image_processor.preprocess(images, return_tensors="pt")["pixel_values"]
return self.forward(images)
return all_img_embeds, uuids, valid_ids