diff --git a/README.md b/README.md index 3776a1de..faf78964 100644 --- a/README.md +++ b/README.md @@ -71,7 +71,7 @@ LightLLM is a Python-based LLM (Large Language Model) inference and serving fram > InternVL-Chat(InternLM2) needs to set the parameter '--eos_id 92542 --trust_remote_code'. -> Qwen2-VL-7b needs to set the parameter '--eos_id 151645 --trust_remote_code'. +> Qwen2-VL-7b needs to set the parameter '--eos_id 151645 --trust_remote_code', and use 'pip install git+https://github.com/huggingface/transformers' to upgrade to the latest version. > Stablelm needs to set the parameter '--trust_remote_code'. diff --git a/lightllm/models/internlm_xcomposer/internlm_visual.py b/lightllm/models/internlm_xcomposer/internlm_visual.py index 868aab0d..15133e2d 100644 --- a/lightllm/models/internlm_xcomposer/internlm_visual.py +++ b/lightllm/models/internlm_xcomposer/internlm_visual.py @@ -7,17 +7,19 @@ from typing import List, Union from torchvision import transforms from torchvision.transforms.functional import InterpolationMode +from rpyc.utils.classic import obtain +from io import BytesIO +from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data class InternVisionModel: - def __init__(self): pass def load_projector_update(self, config, weight_dir): - projector_type = config.get("projector_type", 'mlp2x_gelu') + projector_type = config.get("projector_type", "mlp2x_gelu") projector_weights = [] - mlp_gelu_match = re.match(r'^mlp(\d+)x_gelu$', projector_type) + mlp_gelu_match = re.match(r"^mlp(\d+)x_gelu$", projector_type) if mlp_gelu_match: self.mlp_depth = int(mlp_gelu_match.group(1)) new_dict = {} @@ -25,54 +27,55 @@ def load_projector_update(self, config, weight_dir): if f.endswith(".bin"): d = torch.load(os.path.join(weight_dir, f), "cpu") for k, v in d.items(): - if 'vision_proj' in k: + if "vision_proj" in k: projector_weights.append(v.half()) elif "vit.vision_tower." 
in k: - new_dict[k[len("vit.vision_tower."):]] = v.half() + new_dict[k[len("vit.vision_tower.") :]] = v.half() self.vision_tower.load_state_dict(new_dict, strict=True) return projector_weights - if projector_type == 'identity': + if projector_type == "identity": return [] - raise ValueError(f'Unknown projector type: {projector_type}') + raise ValueError(f"Unknown projector type: {projector_type}") def load_model(self, weight_dir): config_file = os.path.join(weight_dir, "config.json") config = json.load(open(config_file)) - self.select_layer = config.get('mm_vision_select_layer', -1) - self.select_feature = config.get('mm_vision_select_feature', 'patch') + self.select_layer = config.get("mm_vision_select_layer", -1) + self.select_feature = config.get("mm_vision_select_feature", "patch") # load clip vision model by cfg['mm_vision_tower']: # huggingface_name or path_of_clip_relative_to_llava_model_dir - vision_path = config.get('mm_vision_tower', 'openai/clip-vit-large-patch14-336') + vision_path = config.get("mm_vision_tower", "openai/clip-vit-large-patch14-336") if isinstance(vision_path, list): vision_path = vision_path[0] if vision_path.startswith("./"): vision_path = os.path.join(weight_dir, vision_path) - - self.image_processor = transforms.Compose([ - transforms.Resize((config["img_size"], config["img_size"]), - interpolation=InterpolationMode.BICUBIC), - transforms.ToTensor(), - transforms.Normalize((0.48145466, 0.4578275, 0.40821073), - (0.26862954, 0.26130258, 0.27577711)), - ]) + + self.image_processor = transforms.Compose( + [ + transforms.Resize((config["img_size"], config["img_size"]), interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)), + ] + ) from transformers import CLIPVisionModel + self.vision_tower = CLIPVisionModel.from_pretrained(vision_path) self.vision_tower.requires_grad_(False) self.resize_pos(config, vision_path) self.projector_weights = self.load_projector_update(config, weight_dir) self.vision_tower = self.vision_tower.half() - self.device = torch.device('cpu') + self.device = torch.device("cpu") # load projector weights assert len(self.projector_weights) == self.mlp_depth * 2 def resize_pos(self, config, vision_path): - mm_vision_tower = vision_path.split('/')[-1] - vision_tower_match = re.match(r'^clip-vit-large-patch(\d+)-(\d+)$', mm_vision_tower) + mm_vision_tower = vision_path.split("/")[-1] + vision_tower_match = re.match(r"^clip-vit-large-patch(\d+)-(\d+)$", mm_vision_tower) patch_size = int(vision_tower_match.group(1)) clip_imge_size = int(vision_tower_match.group(2)) - + orig_size = clip_imge_size // patch_size new_size = config["img_size"] // patch_size if orig_size == new_size: @@ -82,58 +85,51 @@ def resize_pos(self, config, vision_path): pos_embed_checkpoint = self.vision_tower.vision_model.embeddings.position_embedding.weight pos_embed_checkpoint = pos_embed_checkpoint.unsqueeze(0) - if pos_embed_checkpoint.shape[1] == new_size**2 + 1: + if pos_embed_checkpoint.shape[1] == new_size ** 2 + 1: self.is_resize_pos = True else: embedding_size = pos_embed_checkpoint.shape[-1] num_extra_tokens = 1 - new_num = new_size**2 + num_extra_tokens - print('Position interpolate from %dx%d to %dx%d' % - (orig_size, orig_size, new_size, new_size)) + new_num = new_size ** 2 + num_extra_tokens + print("Position interpolate from %dx%d to %dx%d" % (orig_size, orig_size, new_size, new_size)) extra_tokens = pos_embed_checkpoint[:, :num_extra_tokens] # only the 
position tokens are interpolated pos_tokens = pos_embed_checkpoint[:, num_extra_tokens:] - pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, - embedding_size).permute( - 0, 3, 1, 2) + pos_tokens = pos_tokens.reshape(-1, orig_size, orig_size, embedding_size).permute(0, 3, 1, 2) pos_tokens = torch.nn.functional.interpolate( - pos_tokens, - size=(new_size, new_size), - mode='bicubic', - align_corners=False) + pos_tokens, size=(new_size, new_size), mode="bicubic", align_corners=False + ) pos_tokens = pos_tokens.permute(0, 2, 3, 1).flatten(1, 2) new_pos_embed = torch.cat((extra_tokens, pos_tokens), dim=1) new_pos_embed = new_pos_embed.squeeze(0) - self.vision_tower.vision_model.embeddings.position_embedding = torch.nn.Embedding( - new_num, 1024) + self.vision_tower.vision_model.embeddings.position_embedding = torch.nn.Embedding(new_num, 1024) self.vision_tower.vision_model.embeddings.position_embedding.weight = torch.nn.Parameter( - new_pos_embed.to(pos_embed_checkpoint.dtype)) - self.vision_tower.vision_model.embeddings.position_ids = torch.arange( - new_num).expand((1, -1)) + new_pos_embed.to(pos_embed_checkpoint.dtype) + ) + self.vision_tower.vision_model.embeddings.position_ids = torch.arange(new_num).expand((1, -1)) self.is_resize_pos = True def cuda(self): self.vision_tower = self.vision_tower.cuda() for i in range(len(self.projector_weights)): self.projector_weights[i] = self.projector_weights[i].cuda() - self.device = torch.device('cuda') return self # batch images infer def forward(self, x): - x = x.to(device=self.device, dtype=self.vision_tower.dtype) + x = x.cuda().to(dtype=self.vision_tower.dtype) x = self.vision_tower(x, output_hidden_states=True) x = x.hidden_states[self.select_layer] - if self.select_feature == 'patch': + if self.select_feature == "patch": x = x[:, 1:].contiguous() - + if len(self.projector_weights) == 0: return x - - B, L, N = x.shape + + B, L, N = x.shape x = x.view(-1, N) # mm_project x = F.linear( @@ -151,16 +147,31 @@ def forward(self, x): x = x.view(B, L, -1) return x - def encode(self, image_items: List[Union[str, Image.Image]]): - images = [] - for item in image_items: - if isinstance(item, Image.Image): - image = item - elif item.startswith("http://") or item.startswith("https://"): - image = Image.open(requests.get(item, stream=True).raw) + def encode(self, image_uuids: List): + img_tensors = [] + uuids = [] + valid_id = 0 + valid_ids = [] + + for i, item in enumerate(image_uuids): + item = obtain(item) + if isinstance(item, int): + uuids.append(item) + image_data = read_shm(get_shm_name_data(item)) + image_data = Image.open(BytesIO(image_data)).convert("RGB") + t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"] + img_tensors.append(t) else: - image = Image.open(item) - image = self.image_processor(image.convert('RGB')).unsqueeze(0).to(self.device) - images.append(image) - images = torch.cat(images, dim=0) - return self.forward(images) \ No newline at end of file + raise Exception("Unsupport input types: {} for {}".format(type(item), item)) + + cur_num = img_tensors[-1].shape[0] + valid_ids.append([valid_id, valid_id + cur_num]) + valid_id += cur_num + + if len(img_tensors) <= 0: + return None + + img = torch.cat(img_tensors, dim=0) + all_img_embeds = self.forward(img) + + return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/internvl/internvl_visual.py b/lightllm/models/internvl/internvl_visual.py index b773141e..f0b31456 100644 --- a/lightllm/models/internvl/internvl_visual.py +++ 
b/lightllm/models/internvl/internvl_visual.py @@ -8,22 +8,21 @@ from torchvision import transforms as T from torchvision.transforms.functional import InterpolationMode from transformers import AutoModel, AutoTokenizer -import requests -from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed -import rpyc +from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from io import BytesIO from lightllm.models.internvl.img_process import load_image +from rpyc.utils.classic import obtain +from lightllm.utils.log_utils import init_logger + +logger = init_logger(__name__) class InternVLVisionModel: - def __init__(self, kvargs): - self.cache_port = kvargs["client_port"] - self.cache_client = None + def __init__(self): pass def load_model(self, weight_dir): assert torch.cuda.is_available() - self.device = torch.device("cuda") self.dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float32 self.config = json.load(open(os.path.join(weight_dir, "config.json"))) self.model = AutoModel.from_pretrained( @@ -37,28 +36,32 @@ def load_model(self, weight_dir): def cuda(self): return self - def encode(self, image_items: List[Union[str, torch.Tensor, Image.Image]]): + def encode(self, image_uuids: List): img_tensors = [] valid_ids = [] valid_id = 0 - # load images to batch tensor - for i, url in enumerate(image_items): - if isinstance(url, Image.Image): - t = load_image(url, max_num=6) + uuids = [] + + for i, url in enumerate(image_uuids): + url = obtain(url) + if isinstance(url, int): + uuids.append(url) + image_data = read_shm(get_shm_name_data(url)) + image_data = Image.open(BytesIO(image_data)) + t = load_image(image_data) img_tensors.append(t) else: raise Exception("Unsupport input types: {} for {}".format(type(url), url)) cur_num = img_tensors[-1].shape[0] - valid_ids.append([valid_id, valid_id + cur_num]) valid_id += cur_num if len(img_tensors) <= 0: return None - # (b, 3, 224, 224) imgs = torch.cat(img_tensors, dim=0) - pixel_values = imgs.to(self.device, dtype=self.dtype) + pixel_values = imgs.cuda().to(dtype=self.dtype) all_img_embeds = self.model.extract_feature(pixel_values) - return [all_img_embeds[start:end] for start, end in valid_ids] + + return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/llava/llava_visual.py b/lightllm/models/llava/llava_visual.py index 776f2959..42deffa2 100644 --- a/lightllm/models/llava/llava_visual.py +++ b/lightllm/models/llava/llava_visual.py @@ -5,6 +5,13 @@ from PIL import Image from typing import List, Union from safetensors import safe_open +from rpyc.utils.classic import obtain +from io import BytesIO +from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data +from lightllm.utils.log_utils import init_logger + + +logger = init_logger(__name__) class LlavaVisionModel: @@ -31,6 +38,7 @@ def load_model(self, weight_dir): def load_hf_model(self, config, weight_dir): from transformers import AutoConfig, AutoProcessor, LlavaForConditionalGeneration + config = AutoConfig.from_pretrained(weight_dir, trust_remote_code=True) self.select_layer = config.vision_feature_layer self.select_feature = config.vision_feature_select_strategy @@ -48,12 +56,16 @@ def load_hf_model(self, config, weight_dir): self.projector_weights = {} for f in os.listdir(weight_dir): if f.endswith(".safetensors"): - d = safe_open(os.path.join(weight_dir, f), 'pt', 'cpu') + d = safe_open(os.path.join(weight_dir, f), "pt", "cpu") for k in d.keys(): if 
"multi_modal_projector.linear_1" in k: - self.projector_weights[k.replace("multi_modal_projector.linear_1", "model.mm_projector.0")] = d.get_tensor(k).half() + self.projector_weights[ + k.replace("multi_modal_projector.linear_1", "model.mm_projector.0") + ] = d.get_tensor(k).half() if "multi_modal_projector.linear_2" in k: - self.projector_weights[k.replace("multi_modal_projector.linear_2", "model.mm_projector.2")] = d.get_tensor(k).half() + self.projector_weights[ + k.replace("multi_modal_projector.linear_2", "model.mm_projector.2") + ] = d.get_tensor(k).half() def load_bin_model(self, config, weight_dir): self.select_layer = config.get("mm_vision_select_layer", -2) @@ -68,6 +80,7 @@ def load_bin_model(self, config, weight_dir): vision_path = os.path.join(weight_dir, vision_path) from transformers import CLIPVisionModel, CLIPImageProcessor + self.image_processor = CLIPImageProcessor.from_pretrained(vision_path) self.vision_tower = CLIPVisionModel.from_pretrained(vision_path).half() @@ -84,13 +97,11 @@ def cuda(self): self.vision_tower = self.vision_tower.cuda() for k, v in self.projector_weights.items(): self.projector_weights[k] = v.cuda() - self.device = torch.device("cuda") return self # batch images infer def forward(self, x): - x = x.half().to(device=self.device) - + x = x.half().cuda() x = self.vision_tower(x, output_hidden_states=True) x = x.hidden_states[self.select_layer] if self.select_feature == "patch" or self.select_feature == "default": @@ -113,17 +124,31 @@ def forward(self, x): x = x.view(B, L, -1) return x - def encode(self, image_items: List[Union[str, Image.Image]]): - images = [] - for item in image_items: - if isinstance(item, Image.Image): - image = item - elif item.startswith("http://") or item.startswith("https://"): - import requests - image = Image.open(requests.get(item, stream=True).raw) + def encode(self, image_uuids: List): + img_tensors = [] + uuids = [] + valid_id = 0 + valid_ids = [] + + for i, item in enumerate(image_uuids): + item = obtain(item) + if isinstance(item, int): + uuids.append(item) + image_data = read_shm(get_shm_name_data(item)) + image_data = Image.open(BytesIO(image_data)).convert("RGB") + t = self.image_processor.preprocess(image_data, return_tensors="pt")["pixel_values"] + img_tensors.append(t) else: - image = Image.open(item) - images.append(image.convert("RGB")) + raise Exception("Unsupport input types: {} for {}".format(type(item), item)) + + cur_num = img_tensors[-1].shape[0] + valid_ids.append([valid_id, valid_id + cur_num]) + valid_id += cur_num + + if len(img_tensors) <= 0: + return None + + img = torch.cat(img_tensors, dim=0) + all_img_embeds = self.forward(img) - images = self.image_processor.preprocess(images, return_tensors="pt")["pixel_values"] - return self.forward(images) + return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/qwen2_vl/qwen2_visual.py b/lightllm/models/qwen2_vl/qwen2_visual.py index eced443a..dbe4cb18 100644 --- a/lightllm/models/qwen2_vl/qwen2_visual.py +++ b/lightllm/models/qwen2_vl/qwen2_visual.py @@ -28,10 +28,9 @@ from torchvision import transforms as T from torchvision.transforms.functional import InterpolationMode from transformers import AutoModel, AutoTokenizer -import requests -from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed -import rpyc +from lightllm.server.embed_cache.utils import read_shm, get_shm_name_data from io import BytesIO +from rpyc.utils.classic import obtain from transformers.configuration_utils 
import PretrainedConfig from transformers.utils import logging from transformers.modeling_utils import PreTrainedModel @@ -112,7 +111,7 @@ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: target_dtype = self.proj.weight.dtype hidden_states = hidden_states.view( -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size - ) + ).cuda() hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim) return hidden_states @@ -382,11 +381,11 @@ def rot_pos_emb(self, grid_thw): def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor) -> torch.Tensor: hidden_states = hidden_states.to( dtype=self.get_dtype(), - device=torch.device("cuda"), + device=self.device, ) grid_thw = grid_thw.to( dtype=torch.int32, - device=torch.device("cuda"), + device=self.device, ) hidden_states = self.patch_embed(hidden_states) @@ -427,16 +426,22 @@ def load_model(self, weight_dir): self.load_state_dict(weight_dict) - def encode(self, image_items: List[Union[str, Image.Image]]): + def encode(self, image_uuids: List): img_tensors = [] valid_ids = [] valid_id = 0 img_grids = [] - for i, url in enumerate(image_items): - if isinstance(url, Image.Image): - t = get_image(url) - image_inputs = self.processor.preprocess(images=t, return_tensors="pt") - pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16, device="cuda") + uuids = [] + + for i, url in enumerate(image_uuids): + url = obtain(url) + if isinstance(url, int): + uuids.append(url) + image_data = read_shm(get_shm_name_data(url)) + image_data = Image.open(BytesIO(image_data)) + image_data = get_image(image_data) + image_inputs = self.processor.preprocess(images=image_data, return_tensors="pt") + pixel_values = image_inputs["pixel_values"].to(dtype=torch.bfloat16) image_grid_thw = image_inputs["image_grid_thw"] img_tensors.append(pixel_values) img_grids.append(image_grid_thw) @@ -455,10 +460,10 @@ def encode(self, image_items: List[Union[str, Image.Image]]): imgs = torch.cat(img_tensors, dim=0) grid_thw = torch.cat(img_grids, dim=0) - pixel_values = imgs.to(self.device, dtype=torch.float32) - image_grid_thw = grid_thw.to(self.device) + pixel_values = imgs.cuda().to(dtype=torch.float32) + image_grid_thw = grid_thw.cuda() pixel_values = pixel_values.type(self.get_dtype()) - all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw).to(self.device) + all_img_embeds = self.forward(pixel_values, grid_thw=image_grid_thw) - return [all_img_embeds[start:end] for start, end in valid_ids] + return all_img_embeds, uuids, valid_ids diff --git a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py index 616703e3..6515ddc4 100644 --- a/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py +++ b/lightllm/models/qwen_vl/layer_infer/pre_layer_infer.py @@ -38,6 +38,9 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei img_start_loc = 0 img_start_locs = [] + device = layer_weight.wte_weight_.device + dtype = layer_weight.wte_weight_.dtype + hidden_size = layer_weight.wte_weight_.shape[1] for batch_id, p in enumerate(infer_state.multimodal_params): for img in p["images"]: # skip the same image @@ -50,10 +53,6 @@ def context_forward(self, input_ids, infer_state: LlamaInferStateInfo, layer_wei img_token_lens.append(img["token_num"]) img_start_locs.append(img_start_loc) img_start_loc += img["token_num"] - - device = layer_weight.wte_weight_.device - dtype = layer_weight.wte_weight_.dtype - hidden_size = 
layer_weight.wte_weight_.shape[1] out = torch.zeros((len(input_ids), hidden_size), dtype=dtype, device=device) if len(img_weight) > 0: img_weight = torch.cat(img_weight, dim=0).to(device=device, dtype=dtype) diff --git a/lightllm/models/qwen_vl/model.py b/lightllm/models/qwen_vl/model.py index aa94d0b1..8eb91370 100644 --- a/lightllm/models/qwen_vl/model.py +++ b/lightllm/models/qwen_vl/model.py @@ -8,7 +8,6 @@ # Warp of the origal tokenizer class QWenVLTokenizer: - def __init__(self, tokenizer, model_cfg): self.tokenizer = tokenizer # : 151857 @@ -18,7 +17,7 @@ def __init__(self, tokenizer, model_cfg): self.image_end_tag = tokenizer.image_end_tag self.image_end_id = tokenizer.img_end_id # : 151859 - self.image_length = model_cfg['visual'].get("n_queries", 256) + self.image_length = model_cfg["visual"].get("n_queries", 256) def _list_find(self, input_list, target, start_idx): cur_list = input_list[start_idx:] @@ -34,18 +33,18 @@ def _format_prompt(self, prompt): parts = prompt.split(self.image_start_tag) prompt = parts[0] for idx, part in enumerate(parts[1:]): - prompt += f'Picture {idx + 1}:' + self.image_start_tag + part + prompt += f"Picture {idx + 1}:" + self.image_start_tag + part parts = prompt.split(self.image_end_tag) prompt = parts[0] for part in parts[1:]: - prompt += self.image_end_tag + '\n' + part + prompt += self.image_end_tag + "\n" + part return prompt # only change the impl of the encode func: def encode(self, prompt, multimodal_params: MultimodalParams = None): prompt = unicodedata.normalize("NFC", prompt) prompt = self._format_prompt(prompt) - origin_ids = self.tokenizer.tokenizer.encode(prompt, allowed_special='all', disallowed_special=()) + origin_ids = self.tokenizer.tokenizer.encode(prompt, allowed_special="all", disallowed_special=()) input_ids = [] image_id = 0 @@ -55,7 +54,7 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None): start = self._list_find(origin_ids, self.image_start_id, end) if start == -1: break - input_ids.extend(origin_ids[end: start]) + input_ids.extend(origin_ids[end:start]) end = self._list_find(origin_ids, self.image_end_id, start) if end == -1: raise ValueError("Unclosed image token") @@ -70,14 +69,14 @@ def encode(self, prompt, multimodal_params: MultimodalParams = None): end += 1 image_id += 1 - input_ids.extend(origin_ids[end: ]) + input_ids.extend(origin_ids[end:]) if multimodal_params: image_cnt = len(multimodal_params.images) assert image_cnt == image_id, "invalid image tag num: {} vs {}!".format(image_cnt, image_id) return input_ids def __getattr__(self, name): - if name != 'encode': + if name != "encode": return getattr(self.tokenizer, name) return self.encode diff --git a/lightllm/models/qwen_vl/qwen_visual.py b/lightllm/models/qwen_vl/qwen_visual.py index fa372d41..5cd6b54f 100644 --- a/lightllm/models/qwen_vl/qwen_visual.py +++ b/lightllm/models/qwen_vl/qwen_visual.py @@ -11,13 +11,15 @@ from PIL import Image from typing import Callable, Optional, Sequence, Tuple, List, Union import numpy as np - +import rpyc +from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed import torch from torch import nn from torch.nn import functional as F from torch.nn.init import trunc_normal_ from torchvision import transforms from torchvision.transforms import InterpolationMode +from rpyc.utils.classic import obtain def get_abs_pos(abs_pos, tgt_size): @@ -29,15 +31,21 @@ def get_abs_pos(abs_pos, tgt_size): dtype = abs_pos.dtype if src_size != tgt_size: - return 
F.interpolate( - abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), - size=(tgt_size, tgt_size), - mode="bicubic", - align_corners=False, - ).permute(0, 2, 3, 1).flatten(0, 2).to(dtype=dtype) + return ( + F.interpolate( + abs_pos.float().reshape(1, src_size, src_size, -1).permute(0, 3, 1, 2), + size=(tgt_size, tgt_size), + mode="bicubic", + align_corners=False, + ) + .permute(0, 2, 3, 1) + .flatten(0, 2) + .to(dtype=dtype) + ) else: return abs_pos + # https://github.com/facebookresearch/mae/blob/efb2a8062c206524e35e47d04501ed4f544c0ae8/util/pos_embed.py#L20 def get_2d_sincos_pos_embed(embed_dim, grid_size, cls_token=False): """ @@ -64,7 +72,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid): emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2) emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2) - emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) + emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D) return emb @@ -76,14 +84,14 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): """ assert embed_dim % 2 == 0 omega = np.arange(embed_dim // 2, dtype=np.float32) - omega /= embed_dim / 2. - omega = 1. / 10000**omega # (D/2,) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000 ** omega # (D/2,) pos = pos.reshape(-1) # (M,) - out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product - emb_sin = np.sin(out) # (M, D/2) - emb_cos = np.cos(out) # (M, D/2) + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D) return emb @@ -96,14 +104,8 @@ class Resampler(nn.Module): Outputs: A tensor with the shape of (grid_size**2, embed_dim) """ - def __init__( - self, - grid_size, - embed_dim, - num_heads, - kv_dim=None, - norm_layer=nn.LayerNorm - ): + + def __init__(self, grid_size, embed_dim, num_heads, kv_dim=None, norm_layer=nn.LayerNorm): super().__init__() self.num_queries = grid_size ** 2 self.embed_dim = embed_dim @@ -114,7 +116,7 @@ def __init__( ).requires_grad_(False) self.query = nn.Parameter(torch.zeros(self.num_queries, embed_dim)) - trunc_normal_(self.query, std=.02) + trunc_normal_(self.query, std=0.02) if kv_dim is not None and kv_dim != embed_dim: self.kv_proj = nn.Linear(kv_dim, embed_dim, bias=False) @@ -124,12 +126,12 @@ def __init__( self.attn = nn.MultiheadAttention(embed_dim, num_heads) self.ln_q = norm_layer(embed_dim) self.ln_kv = norm_layer(embed_dim) - + # self.apply(self._init_weights) def _init_weights(self, m): if isinstance(m, nn.Linear): - trunc_normal_(m.weight, std=.02) + trunc_normal_(m.weight, std=0.02) if isinstance(m, nn.Linear) and m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.LayerNorm): @@ -146,10 +148,8 @@ def forward(self, x, attn_mask=None): N = x.shape[1] q = self.ln_q(self.query) out = self.attn( - self._repeat(q, N) + self.pos_embed.unsqueeze(1), - x + pos_embed.unsqueeze(1), - x, - attn_mask=attn_mask)[0] + self._repeat(q, N) + self.pos_embed.unsqueeze(1), x + pos_embed.unsqueeze(1), x, attn_mask=attn_mask + )[0] return out.permute(1, 0, 2) def _repeat(self, query, N: int): @@ -163,8 +163,7 @@ class VisualAttention(nn.Module): and returns output of the same size. 
""" - def __init__(self, embed_dim, num_heads, - bias=True, kdim=None, vdim=None): + def __init__(self, embed_dim, num_heads, bias=True, kdim=None, vdim=None): super(VisualAttention, self).__init__() self.embed_dim = embed_dim self.kdim = kdim if kdim is not None else embed_dim @@ -180,37 +179,37 @@ def __init__(self, embed_dim, num_heads, self.hidden_size_per_partition = embed_dim # Strided linear layer. - assert self._qkv_same_embed_dim, 'Only Support SelfAttention Currently' + assert self._qkv_same_embed_dim, "Only Support SelfAttention Currently" self.in_proj = nn.Linear(embed_dim, 3 * embed_dim) self.out_proj = nn.Linear(embed_dim, embed_dim) self.norm_factor = math.sqrt(self.hidden_size_per_attention_head) - def forward(self, query, key, value, attn_mask = None): + def forward(self, query, key, value, attn_mask=None): # query/key/value: [sq, b, h] sq, b, _ = query.size() - assert torch.allclose(query, key), 'Only Support Self-Attention Currently' + assert torch.allclose(query, key), "Only Support Self-Attention Currently" sk = sq mixed_x_layer = self.in_proj(query) # [sq, b, (np * 3 * hn)] --> [sq, b, np, 3 * hn] - new_tensor_shape = mixed_x_layer.size()[:-1] + \ - (self.num_attention_heads_per_partition, - 3 * self.hidden_size_per_attention_head) + new_tensor_shape = mixed_x_layer.size()[:-1] + ( + self.num_attention_heads_per_partition, + 3 * self.hidden_size_per_attention_head, + ) mixed_x_layer = mixed_x_layer.view(*new_tensor_shape) # [sq, b, np, 3 * hn] --> 3 [sq, b, np, hn] - query_layer, key_layer, value_layer = mixed_x_layer.split( - self.hidden_size_per_attention_head, dim=-1) + query_layer, key_layer, value_layer = mixed_x_layer.split(self.hidden_size_per_attention_head, dim=-1) # [sq, b, np, hn] -> [sq, b * np, hn] - query_layer = query_layer.view(sq, - b * self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head).transpose(0, 1) + query_layer = query_layer.view( + sq, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ).transpose(0, 1) # [sk, b, np, hn] -> [sk, b * np, hn] - key_layer = key_layer.view(sk, - b * self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head).transpose(0, 1) + key_layer = key_layer.view( + sk, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ).transpose(0, 1) q_scaled = query_layer / self.norm_factor if attn_mask is not None: @@ -219,24 +218,23 @@ def forward(self, query, key, value, attn_mask = None): attention_probs = torch.bmm(q_scaled, key_layer.transpose(-2, -1)) attention_probs = attention_probs.softmax(dim=-1) - value_layer = value_layer.view(sk, - b * self.num_attention_heads_per_partition, - self.hidden_size_per_attention_head).transpose(0, 1) + value_layer = value_layer.view( + sk, b * self.num_attention_heads_per_partition, self.hidden_size_per_attention_head + ).transpose(0, 1) # matmul: [b * np, sq, hn] context_layer = torch.bmm(attention_probs, value_layer) # change view [b, np, sq, hn] - context_layer = context_layer.view(b, - self.num_attention_heads_per_partition, - sq, self.hidden_size_per_attention_head) + context_layer = context_layer.view( + b, self.num_attention_heads_per_partition, sq, self.hidden_size_per_attention_head + ) # [b, np, sq, hn] --> [sq, b, np, hn] context_layer = context_layer.permute(2, 0, 1, 3).contiguous() # [sq, b, np, hn] --> [sq, b, hp] - new_context_layer_shape = context_layer.size()[:-2] + \ - (self.hidden_size_per_partition,) + new_context_layer_shape = context_layer.size()[:-2] + 
(self.hidden_size_per_partition,) context_layer = context_layer.view(*new_context_layer_shape) output = self.out_proj(context_layer) @@ -246,13 +244,13 @@ def forward(self, query, key, value, attn_mask = None): class VisualAttentionBlock(nn.Module): def __init__( - self, - d_model: int, - n_head: int, - mlp_ratio: float = 4.0, - act_layer: Callable = nn.GELU, - norm_layer: Callable = nn.LayerNorm, - is_cross_attention: bool = False, + self, + d_model: int, + n_head: int, + mlp_ratio: float = 4.0, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, + is_cross_attention: bool = False, ): super().__init__() @@ -263,18 +261,22 @@ def __init__( self.ln_2 = norm_layer(d_model) mlp_width = int(d_model * mlp_ratio) self.attn = VisualAttention(d_model, n_head) - self.mlp = nn.Sequential(OrderedDict([ - ("c_fc", nn.Linear(d_model, mlp_width)), - ("gelu", act_layer()), - ("c_proj", nn.Linear(mlp_width, d_model)) - ])) + self.mlp = nn.Sequential( + OrderedDict( + [ + ("c_fc", nn.Linear(d_model, mlp_width)), + ("gelu", act_layer()), + ("c_proj", nn.Linear(mlp_width, d_model)), + ] + ) + ) def attention( - self, - q_x: torch.Tensor, - k_x: Optional[torch.Tensor] = None, - v_x: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None, + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, ): k_x = k_x if k_x is not None else q_x v_x = v_x if v_x is not None else q_x @@ -283,11 +285,11 @@ def attention( return self.attn(q_x, k_x, v_x, attn_mask=attn_mask) def forward( - self, - q_x: torch.Tensor, - k_x: Optional[torch.Tensor] = None, - v_x: Optional[torch.Tensor] = None, - attn_mask: Optional[torch.Tensor] = None, + self, + q_x: torch.Tensor, + k_x: Optional[torch.Tensor] = None, + v_x: Optional[torch.Tensor] = None, + attn_mask: Optional[torch.Tensor] = None, ): k_x = self.ln_1_kv(k_x) if hasattr(self, "ln_1_kv") and k_x is not None else None v_x = self.ln_1_kv(v_x) if hasattr(self, "ln_1_kv") and v_x is not None else None @@ -299,23 +301,24 @@ def forward( class TransformerBlock(nn.Module): def __init__( - self, - width: int, - layers: int, - heads: int, - mlp_ratio: float = 4.0, - act_layer: Callable = nn.GELU, - norm_layer: Callable = nn.LayerNorm, + self, + width: int, + layers: int, + heads: int, + mlp_ratio: float = 4.0, + act_layer: Callable = nn.GELU, + norm_layer: Callable = nn.LayerNorm, ): super().__init__() self.width = width self.layers = layers - self.resblocks = nn.ModuleList([ - VisualAttentionBlock( - width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer) - for _ in range(layers) - ]) + self.resblocks = nn.ModuleList( + [ + VisualAttentionBlock(width, heads, mlp_ratio, act_layer=act_layer, norm_layer=norm_layer) + for _ in range(layers) + ] + ) def get_cast_dtype(self) -> torch.dtype: return self.resblocks[0].mlp.c_fc.weight.dtype @@ -330,18 +333,17 @@ def forward(self, x: torch.Tensor, attn_mask: Optional[torch.Tensor] = None): class QWenVisionTransformer(nn.Module): - def __init__( - self, - image_size: int, - patch_size: int, - width: int, - layers: int, - heads: int, - mlp_ratio: float, - n_queries: int = 256, - output_dim: int = 512, - **kwargs + self, + image_size: int, + patch_size: int, + width: int, + layers: int, + heads: int, + mlp_ratio: float, + n_queries: int = 256, + output_dim: int = 512, + **kwargs, ): super().__init__() image_height, image_width = self.image_size = (image_size, image_size) @@ -351,14 +353,13 @@ def __init__( 
mean = (0.48145466, 0.4578275, 0.40821073) std = (0.26862954, 0.26130258, 0.27577711) - self.image_transform = transforms.Compose([ - transforms.Resize( - (image_size, image_size), - interpolation=InterpolationMode.BICUBIC - ), - transforms.ToTensor(), - transforms.Normalize(mean=mean, std=std), - ]) + self.image_transform = transforms.Compose( + [ + transforms.Resize((image_size, image_size), interpolation=InterpolationMode.BICUBIC), + transforms.ToTensor(), + transforms.Normalize(mean=mean, std=std), + ] + ) self.conv1 = nn.Conv2d(in_channels=3, out_channels=width, kernel_size=patch_size, stride=patch_size, bias=False) @@ -387,7 +388,7 @@ def __init__( norm_layer=norm_layer, ) self.ln_post = norm_layer(output_dim) - self.proj = nn.Parameter((output_dim** -0.5) * torch.randn(output_dim, output_dim)) + self.proj = nn.Parameter((output_dim ** -0.5) * torch.randn(output_dim, output_dim)) def forward(self, x: torch.Tensor): x = x.to( @@ -413,28 +414,42 @@ def forward(self, x: torch.Tensor): x = x.to(dtype=torch.float16) return x - - def encode(self, image_items: List[Union[str, Image.Image]]): - images = [] - for item in image_items: - if isinstance(item, Image.Image): - image = item - elif item.startswith("http://") or item.startswith("https://"): - image = Image.open(requests.get(item, stream=True).raw) + + def encode(self, image_uuids: List): + img_tensors = [] + uuids = [] + valid_id = 0 + valid_ids = [] + + for i, item in enumerate(image_uuids): + item = obtain(item) + if isinstance(item, int): + uuids.append(item) + image_data = read_shm(get_shm_name_data(item)) + image_data = Image.open(BytesIO(image_data)).convert("RGB") + t = self.image_transform(image_data) + img_tensors.append(t) else: - image = Image.open(item) - image = image.convert("RGB") - images.append(self.image_transform(image)) - images = torch.stack(images, dim=0) - return self(images) - + raise Exception("Unsupport input types: {} for {}".format(type(item), item)) + + valid_ids.append([valid_id, valid_id + 1]) + valid_id += 1 + if len(img_tensors) <= 0: + return None + + pixel_values = torch.stack(img_tensors, dim=0) + all_img_embeds = self(pixel_values) + + return all_img_embeds, uuids, valid_ids + def load_model(self, weight_dir): import os - weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin") ] + + weight_files = [file_ for file_ in os.listdir(weight_dir) if file_.endswith(".bin")] weight_dict = {} for file_ in weight_files: f_weight_dict = torch.load(os.path.join(weight_dir, file_), "cpu") for k, v in f_weight_dict.items(): if "visual" in k: - weight_dict[k[len("transformer.visual."):]] = v + weight_dict[k[len("transformer.visual.") :]] = v self.load_state_dict(weight_dict) diff --git a/lightllm/server/api_server.py b/lightllm/server/api_server.py index f3df77ab..b43e0f28 100755 --- a/lightllm/server/api_server.py +++ b/lightllm/server/api_server.py @@ -476,6 +476,21 @@ def make_argument_parser() -> argparse.ArgumentParser: "--grouping_key", action="append", default=[], help="grouping_key for the monitor in the form key=value" ) parser.add_argument("--push_interval", type=int, default=10, help="interval of pushing monitoring metrics") + parser.add_argument( + "--visual_infer_batch_size", type=int, default=4, help="number of images to process in each inference batch" + ) + parser.add_argument( + "--visual_gpu_ids", nargs="+", type=int, default=[0], help="List of GPU IDs to use, e.g., 0 1 2" + ) + parser.add_argument("--visual_tp", type=int, default=1, help="number of tensort parallel 
instances for ViT") parser.add_argument("--visual_dp", type=int, default=1, help="number of data parallel instances for ViT") parser.add_argument( "--visual_nccl_ports", nargs="+", type=int, default=[29500], help="List of NCCL ports to build a distributed environment for ViT, e.g., 29500 29501 29502", ) parser.add_argument( "--enable_monitor_auth", action="store_true", help="Whether to open authentication for push_gateway" ) @@ -534,6 +549,24 @@ def main(): if args.beam_mode or args.diverse_mode: assert args.router_token_ratio == 0.0 + # check that enough GPUs are specified + total_required_gpus = args.visual_dp * args.visual_tp + if len(args.visual_gpu_ids) < total_required_gpus: + raise ValueError( + f"Not enough GPUs specified. You need at least {total_required_gpus}, but got {len(args.visual_gpu_ids)}." + ) + else: + args.visual_gpu_ids = args.visual_gpu_ids[:total_required_gpus] + + # check that enough visual_nccl_ports are specified + if len(args.visual_nccl_ports) < args.visual_dp: + raise ValueError( + f"Not enough visual_nccl_ports specified. You need at least {args.visual_dp}, " + f"but got ({len(args.visual_nccl_ports)})." + ) + else: + args.visual_nccl_ports = args.visual_nccl_ports[: args.visual_dp] + if not args.splitfuse_mode: # normal (non-splitfuse) mode if args.batch_max_tokens is None: @@ -570,9 +603,19 @@ def main(): logger.info(f"all start args:{args}") - can_use_ports = alloc_can_use_network_port(num=6 + args.tp, used_nccl_port=args.nccl_port) + already_used_ports = args.visual_nccl_ports + [args.nccl_port] + can_use_ports = alloc_can_use_network_port( + num=6 + args.tp + args.visual_dp * args.visual_tp, used_nccl_ports=already_used_ports + ) router_port, detokenization_port, httpserver_port, visual_port, cache_port, metric_port = can_use_ports[0:6] - model_rpc_ports = can_use_ports[6:] + model_rpc_ports = can_use_ports[6 : 6 + args.tp] + can_use_ports = can_use_ports[6 + args.tp :] + + visual_model_tp_ports = [] + for _ in range(args.visual_dp): + tp_ports_for_dp = can_use_ports[0 : args.visual_tp] + can_use_ports = can_use_ports[args.visual_tp :] + visual_model_tp_ports.append(tp_ports_for_dp) if args.enable_multimodal: start_submodule_processes( @@ -586,7 +629,7 @@ def main(): start_visual_process, ], start_args=[ - (args, router_port, visual_port, cache_port), + (args, router_port, visual_port, cache_port, visual_model_tp_ports), ], ) @@ -617,7 +660,6 @@ def main(): (args, detokenization_port, httpserver_port), ], ) - if "s3://" in args.model_dir: from lightllm.utils.petrel_helper import s3_model_clear diff --git a/lightllm/server/visualserver/manager.py b/lightllm/server/visualserver/manager.py index f3158dcb..228f7561 100644 --- a/lightllm/server/visualserver/manager.py +++ b/lightllm/server/visualserver/manager.py @@ -23,8 +23,8 @@ def __init__( args, router_port, visual_port, - client_port, - infer_batch_size=4, + cache_port, + visual_model_rpc_ports, ): context = zmq.asyncio.Context(2) self.send_to_router = context.socket(zmq.PUSH) @@ -32,33 +32,45 @@ def __init__( self.recv_from_httpserver = context.socket(zmq.PULL) self.recv_from_httpserver.bind(f"tcp://127.0.0.1:{visual_port}") - self.cache_client = rpyc.connect("localhost", client_port) - self.client_port = client_port + self.cache_client = rpyc.connect("localhost", cache_port) + self.cache_port = cache_port self.waiting_reqs = [] self.model_weightdir = args.model_dir self.tp_world_size = args.tp - self.world_size = 1 - self.infer_batch_size = infer_batch_size + self.vit_dp = args.visual_dp + self.vit_tp = args.visual_tp + self.infer_batch_size = 
args.visual_infer_batch_size self.trust_remote_code = args.trust_remote_code self.args = args + self.visual_model_rpc_ports = visual_model_rpc_ports async def wait_to_model_ready(self): - self.model_rpcs: List[VisualModelRpcClient] = [] - for rank_id in range(self.world_size): - rpc_model = await start_model_process(world_size=self.world_size) - self.model_rpcs.append(rpc_model) + self.model_rpcs: List[List[VisualModelRpcClient]] = [[] for _ in range(self.vit_dp)] + + for dp_rank_id in range(self.vit_dp): + tp_ports_each_dp = self.visual_model_rpc_ports[dp_rank_id] + for tp_rank_id in range(self.vit_tp): + rpc_model = await start_model_process(port=tp_ports_each_dp[tp_rank_id], vit_tp=self.vit_tp) + self.model_rpcs[dp_rank_id].append(rpc_model) init_model_ret = [] - for rank_id in range(self.world_size): # async init model process - kvargs = { - "weight_dir": self.model_weightdir, - "trust_remote_code": self.trust_remote_code, - "client_port": self.client_port, - "rank_id": rank_id, - "data_type": self.args.data_type, - } - init_model_ret.append(self.model_rpcs[rank_id].init_model(kvargs)) + for dp_rank_id in range(self.vit_dp): # async init model process + for tp_rank_id in range(self.vit_tp): + kvargs = { + "weight_dir": self.model_weightdir, + "trust_remote_code": self.trust_remote_code, + "vit_dp": self.vit_dp, + "vit_tp": self.vit_tp, + "cache_port": self.cache_port, + "tp_rank_id": tp_rank_id, + "dp_rank_id": dp_rank_id, + "vit_rank_id": dp_rank_id * self.vit_tp + tp_rank_id, + "data_type": self.args.data_type, + "visual_nccl_port": self.args.visual_nccl_ports[dp_rank_id], + "visual_gpu_ids": self.args.visual_gpu_ids, + } + init_model_ret.append(self.model_rpcs[dp_rank_id][tp_rank_id].init_model(kvargs)) await asyncio.gather(*init_model_ret) return @@ -73,25 +85,16 @@ async def infer_imgs(self, uuids): if len(uuids) == 0: return # uuids -> PIL Images - images = [] - for uid in uuids: - image_data = read_shm(get_shm_name_data(uid)) - images.append(Image.open(BytesIO(image_data))) - # print(" + got pil image:", images[-1].size, images[-1].mode) - rets = [self.model_rpcs[tp_rank].encode(images) for tp_rank in range(self.world_size)] - ans = await asyncio.gather(*rets) - if self.world_size != 1: - img_embed = obtain(ans[0]) - else: - img_embed = ans[0] - torch.cuda.synchronize() - # b = time.time() - for i in range(len(uuids)): - # print(" + set_item_embed:", uuids[i], img_embed[i].shape) - if not self.cache_client.root.get_item_embed(uuids[i]): - cur_embed_bytes = tensor2bytes(img_embed[i]) - create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes) - self.cache_client.root.set_item_embed(uuids[i]) + tasks = [] + for vit_dp_rank in range(self.vit_dp): + assigned_uuids = [uuids[i] for i in range(vit_dp_rank, len(uuids), self.vit_dp)] + if assigned_uuids: + for vit_tp_rank in range(self.vit_tp): + task = asyncio.create_task(self.model_rpcs[vit_dp_rank][vit_tp_rank].encode(assigned_uuids)) + tasks.append(task) + + # rets = [self.model_rpcs[tp_rank].encode(images) for tp_rank in range(self.world_size)] + await asyncio.gather(*tasks) return async def loop_for_fwd(self): @@ -112,12 +115,10 @@ async def loop_for_fwd(self): if len(uuids_need_infer) > 0: reqs_need_infer.append(req) else: - # print(" + no need need infer, send to router...") self.send_to_router.send_pyobj(req) await self.infer_imgs(uuids_need_infer) for req in reqs_need_infer: - # print(" + after infer_imgs, send to router...") self.send_to_router.send_pyobj(req) async def loop_for_netio_req(self): @@ -140,7 +141,7 @@ def 
clean_up(self): return -def start_visual_process(args, router_port, visual_port, client_port, pipe_writer): +def start_visual_process(args, router_port, visual_port, cache_port, model_rpc_ports, pipe_writer): # 注册graceful 退出的处理 from lightllm.utils.graceful_utils import graceful_registry import inspect @@ -148,7 +149,7 @@ def start_visual_process(args, router_port, visual_port, client_port, pipe_write graceful_registry(inspect.currentframe().f_code.co_name) try: - visualserver = VisualManager(args, router_port, visual_port, client_port) + visualserver = VisualManager(args, router_port, visual_port, cache_port, model_rpc_ports) asyncio.run(visualserver.wait_to_model_ready()) except Exception as e: import traceback diff --git a/lightllm/server/visualserver/model_infer/model_rpc.py b/lightllm/server/visualserver/model_infer/model_rpc.py index 55aebdef..be341d49 100644 --- a/lightllm/server/visualserver/model_infer/model_rpc.py +++ b/lightllm/server/visualserver/model_infer/model_rpc.py @@ -2,6 +2,7 @@ import numpy as np import rpyc import torch +import os from datetime import timedelta from typing import Dict, List, Tuple from transformers.configuration_utils import PretrainedConfig @@ -12,29 +13,39 @@ from lightllm.models.internlm_xcomposer.internlm_visual import InternVisionModel from lightllm.models.internvl.internvl_visual import InternVLVisionModel from lightllm.models.qwen2_vl.qwen2_visual import Qwen2VisionTransformerPretrainedModel +from lightllm.server.embed_cache.utils import tensor2bytes, read_shm, create_shm, get_shm_name_data, get_shm_name_embed from lightllm.utils.infer_utils import set_random_seed from lightllm.utils.infer_utils import calculate_time, mark_start, mark_end class VisualModelRpcServer(rpyc.Service): def exposed_init_model(self, kvargs): - # 注册graceful 退出的处理 - from lightllm.utils.graceful_utils import graceful_registry - import inspect - - graceful_registry(inspect.currentframe().f_code.co_name) - - # import torch - # import torch.distributed as dist - # world_size = kvargs["world_size"] - # if world_size != 1: - # kvargs = obtain(kvargs) - # world_size = kvargs["world_size"] - # dist.init_process_group('nccl', init_method=f'tcp://127.0.0.1:{kvargs["nccl_port"]}', - # rank=self.tp_rank, world_size=world_size) - # torch.cuda.set_device(self.tp_rank) + import torch + import torch.distributed as dist + + self.vit_dp = kvargs["vit_dp"] + self.vit_tp = kvargs["vit_tp"] + self.dp_rank_id = kvargs["dp_rank_id"] + self.tp_rank_id = kvargs["tp_rank_id"] + self.cache_port = kvargs["cache_port"] weight_dir = kvargs["weight_dir"] + visual_gpu_ids = kvargs["visual_gpu_ids"] + visual_nccl_port = kvargs["visual_nccl_port"] + self.vit_rank_id = kvargs["vit_rank_id"] + self.cache_client = rpyc.connect("localhost", self.cache_port) + + torch.cuda.set_device(visual_gpu_ids[self.vit_rank_id]) + if self.vit_tp != 1: + dist.init_process_group( + backend="nccl", + init_method=f"tcp://127.0.0.1:{visual_nccl_port}", + rank=self.tp_rank_id, + world_size=self.vit_tp, + ) model_cfg, _ = PretrainedConfig.get_config_dict(weight_dir) + + if self.vit_tp != 1: + raise ValueError(f"ERROR: Not support vit_tp value: {self.vit_tp}") try: self.model_type = model_cfg["model_type"] if self.model_type == "qwen": @@ -46,14 +57,10 @@ def exposed_init_model(self, kvargs): elif self.model_type == "internlmxcomposer2": self.model = InternVisionModel() elif self.model_type == "internvl_chat": - # tp_rank = kvargs['rank_id'] - client_port = kvargs["client_port"] - data_type = kvargs["data_type"] - model_kvargs = 
{"weight_dir": weight_dir, "client_port": client_port, "data_type": data_type} - self.model = InternVLVisionModel(model_kvargs) - + self.model = InternVLVisionModel() else: raise Exception(f"can not support {self.model_type} now") + self.model.load_model(weight_dir) self.model = self.model.cuda() except Exception as e: @@ -69,20 +76,33 @@ def exposed_init_model(self, kvargs): # @calculate_time(show=True, min_cost_ms=150) @torch.no_grad() - def forward(self, images): - return self.model.encode(images) + def forward(self, images_uuids): + return self.model.encode(images_uuids) # @calculate_time(show=False, min_cost_ms=300) - def exposed_encode(self, images): - return self.forward(images) + def exposed_encode(self, images_uuids): + all_img_embeds, uuids, valid_ids = self.forward(images_uuids) + all_img_embeds = all_img_embeds.to(torch.device("cpu")) + if self.tp_rank_id == 0: + if len(uuids) == 0: + return [all_img_embeds[start:end] for start, end in valid_ids] + else: + for i in range(len(uuids)): + uid = uuids[i] + if not self.cache_client.root.get_item_embed(uid): + start, end = valid_ids[i] + cur_embed_bytes = tensor2bytes(all_img_embeds[start:end]) + create_shm(get_shm_name_embed(uuids[i]), cur_embed_bytes) + self.cache_client.root.set_item_embed(uuids[i]) + return class VisualModelRpcClient: - def __init__(self, model_rpc, world_size, rpc_server_process=None): + def __init__(self, model_rpc, vit_tp, rpc_server_process=None): self.model: VisualModelRpcServer = model_rpc - self.world_size = world_size + self.vit_tp = vit_tp self.rpc_server_process = rpc_server_process - self.use_rpc = self.world_size != 1 + self.use_rpc = True if self.use_rpc: def async_wrap(f): @@ -111,14 +131,44 @@ async def init_model(self, kvargs): else: return - async def encode(self, images): - ans = self._encode(images) + async def encode(self, uuids): + ans = self._encode(uuids) if self.use_rpc: return await ans else: return ans -async def start_model_process(world_size): - if world_size == 1: - return VisualModelRpcClient(VisualModelRpcServer(), world_size) +def _init_env(port): + # 注册graceful 退出的处理 + from lightllm.utils.graceful_utils import graceful_registry + import inspect + + graceful_registry(inspect.currentframe().f_code.co_name) + + from rpyc.utils.server import ThreadedServer + + t = ThreadedServer(VisualModelRpcServer(), port=port, protocol_config={"allow_pickle": True}) + t.start() + return + + +async def start_model_process(port, vit_tp): + import multiprocessing + + proc = multiprocessing.Process(target=_init_env, args=(port,)) + proc.start() + await asyncio.sleep(2) + repeat_count = 0 + while repeat_count < 20: + try: + con = rpyc.connect("localhost", port, config={"allow_pickle": True}) + break + except BaseException: + await asyncio.sleep(1) + repeat_count += 1 + if repeat_count == 20: + raise Exception("init rpc env error!") + + assert proc.is_alive() + return VisualModelRpcClient(con.root, vit_tp, rpc_server_process=proc) diff --git a/lightllm/utils/net_utils.py b/lightllm/utils/net_utils.py index acac8fda..b58e0b09 100644 --- a/lightllm/utils/net_utils.py +++ b/lightllm/utils/net_utils.py @@ -1,12 +1,12 @@ import socket -def alloc_can_use_network_port(num=3, used_nccl_port=None): +def alloc_can_use_network_port(num=3, used_nccl_ports=None): port_list = [] for port in range(10000, 65536): with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: result = s.connect_ex(("localhost", port)) - if result != 0 and port != used_nccl_port: + if result != 0 and port not in used_nccl_ports: 
port_list.append(port) if len(port_list) == num:
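Note on the new visual DP/TP startup logic in `api_server.py` above: the server now reserves `visual_dp * visual_tp` GPUs, one NCCL port per data-parallel group, and one RPC port per ViT rank. Below is a minimal standalone sketch of that bookkeeping; `plan_visual_workers` and the example values are illustrative only and not part of the patch.

```python
# Illustrative sketch (not part of the patch): mirrors how main() validates
# --visual_gpu_ids / --visual_nccl_ports and partitions RPC ports per ViT rank.

def plan_visual_workers(visual_dp, visual_tp, visual_gpu_ids, visual_nccl_ports, free_ports):
    total_required_gpus = visual_dp * visual_tp
    if len(visual_gpu_ids) < total_required_gpus:
        raise ValueError(f"Not enough GPUs specified. You need at least {total_required_gpus}.")
    visual_gpu_ids = visual_gpu_ids[:total_required_gpus]

    if len(visual_nccl_ports) < visual_dp:
        raise ValueError(f"Not enough visual_nccl_ports specified. You need at least {visual_dp}.")
    visual_nccl_ports = visual_nccl_ports[:visual_dp]

    # One RPC port per (dp_rank, tp_rank) pair, grouped by DP rank,
    # matching the visual_model_tp_ports structure passed to start_visual_process.
    assert len(free_ports) >= total_required_gpus
    visual_model_tp_ports = []
    for dp_rank in range(visual_dp):
        visual_model_tp_ports.append(free_ports[dp_rank * visual_tp : (dp_rank + 1) * visual_tp])

    return visual_gpu_ids, visual_nccl_ports, visual_model_tp_ports


# Example: 2 data-parallel ViT workers, each with tensor parallelism 1.
print(plan_visual_workers(2, 1, [0, 1, 2, 3], [29500, 29501, 29502], [12000, 12001]))
# -> ([0, 1], [29500, 29501], [[12000], [12001]])
```

With `visual_dp=2, visual_tp=1`, two GPUs and two NCCL ports are kept and each DP worker receives its own single-element RPC port group, which is the shape of `visual_model_tp_ports` handed to `start_visual_process`.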
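The visual models' `encode()` methods now return a flat `(all_img_embeds, uuids, valid_ids)` tuple instead of a list of per-image tensors; `exposed_encode` slices the batched tensor by `valid_ids` and caches each slice in shared memory under its image uuid. A small sketch of that slicing is shown below, with random tensors standing in for real ViT outputs (shapes and uuids are made up for illustration):

```python
# Illustrative only: shows how (all_img_embeds, uuids, valid_ids) is consumed.
# Real embeddings come from the ViT forward pass; here we fake them with torch.randn.
import torch

def split_image_embeds(all_img_embeds, uuids, valid_ids):
    """Slice the batched embedding tensor back into one tensor per image uuid."""
    return {uid: all_img_embeds[start:end] for uid, (start, end) in zip(uuids, valid_ids)}

# Two images: the first contributed 1 row group, the second 3 (e.g. tiled patches).
all_img_embeds = torch.randn(4, 256, 4096)
uuids = [101, 102]
valid_ids = [[0, 1], [1, 4]]

per_image = split_image_embeds(all_img_embeds, uuids, valid_ids)
print(per_image[101].shape)  # torch.Size([1, 256, 4096])
print(per_image[102].shape)  # torch.Size([3, 256, 4096])
```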